"""Custom middleware for request tracking and rate limiting."""

import hashlib
import logging
import time
import uuid
from typing import Optional, Tuple

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.core.config import get_settings

logger = logging.getLogger("middleware")


class RequestIdMiddleware(BaseHTTPMiddleware):
    """Adds X-Request-ID header to every response.

    A non-empty incoming ``X-Request-ID`` header is reused; otherwise a new
    UUID4 string is generated. The ID is stored on ``request.state.request_id``
    so downstream code can read it, and echoed on the response.
    """

    async def dispatch(self, request: Request, call_next):
        # `or` instead of a .get() default: an empty header value also gets a
        # fresh UUID, and no UUID is generated when the client supplied one.
        request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Redis-based fixed-window rate limiting middleware.

    Default: 100 requests/minute per Authorization header value, or per
    client IP when no Authorization header is sent.
    Auth endpoints: 10 requests/minute per IP.
    Graceful fallback when Redis is disabled, unavailable, or raises
    (allows all requests).
    """

    DEFAULT_LIMIT = 100  # requests per minute
    AUTH_LIMIT = 10  # requests per minute for auth endpoints
    WINDOW_SECONDS = 60

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the peer address reported by the server, or "unknown"."""
        return request.client.host if request.client else "unknown"

    def _bucket(self, request: Request) -> Tuple[str, int]:
        """Return the Redis counter key and request limit for *request*."""
        # NOTE(review): this prefix also matches paths such as
        # "/api/v1/authors"; confirm whether only "/api/v1/auth/..." is meant.
        if request.url.path.startswith("/api/v1/auth"):
            return f"rl:auth:{self._client_ip(request)}", self.AUTH_LIMIT

        auth_header = request.headers.get("Authorization", "")
        if auth_header:
            # A stable digest is required: builtin hash() of a str is salted
            # per process, so workers/restarts would each get their own bucket.
            # NOTE(review): the header is not validated here, so a client can
            # rotate arbitrary values to obtain fresh buckets; consider keying
            # on the authenticated user id instead.
            digest = hashlib.sha256(auth_header.encode("utf-8")).hexdigest()
            return f"rl:user:{digest}", self.DEFAULT_LIMIT

        return f"rl:anon:{self._client_ip(request)}", self.DEFAULT_LIMIT

    async def _retry_after(self, redis, key: str) -> int:
        """Return the seconds until *key*'s window resets (at least 1).

        Repairs a counter left without an expiry, which would otherwise
        block the client permanently.
        """
        ttl = await redis.ttl(key)
        if ttl == -1:
            # Key exists without a TTL: the EXPIRE after the first INCR failed
            # or never ran. Re-arm the window so the block can lift.
            await redis.expire(key, self.WINDOW_SECONDS)
            return self.WINDOW_SECONDS
        # ttl == -2 means the key just expired; retry almost immediately.
        return max(ttl, 1)

    async def dispatch(self, request: Request, call_next):
        settings = get_settings()
        if not settings.REDIS_ENABLED:
            return await call_next(request)

        try:
            from app.core.cache import get_redis

            redis = await get_redis()
        except Exception:
            redis = None

        if not redis:
            return await call_next(request)

        try:
            key, limit = self._bucket(request)
            current = await redis.incr(key)
            if current == 1:
                # The first request in a window starts the expiry clock.
                await redis.expire(key, self.WINDOW_SECONDS)
            if current > limit:
                retry_after = await self._retry_after(redis, key)
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Too many requests",
                        "retry_after": retry_after,
                    },
                    headers={"Retry-After": str(retry_after)},
                )
        except Exception as e:
            # Best-effort limiter: a Redis failure must not take the API down.
            logger.warning("Rate limiting error (allowing request): %s", e)

        return await call_next(request)