87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
"""Custom middleware for request tracking and rate limiting."""
|
|
import hashlib
import logging
import time
import uuid
from typing import Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.core.config import get_settings
|
|
|
|
# Module-wide logger; the explicit name "middleware" is kept so existing
# logging configuration that targets it continues to apply.
logger = logging.getLogger(name="middleware")
|
|
|
|
|
|
class RequestIdMiddleware(BaseHTTPMiddleware):
    """Adds X-Request-ID header to every response.

    If the incoming request carries a usable ``X-Request-ID`` header, that
    value is propagated; otherwise a fresh UUID4 string is generated. The
    chosen ID is stored on ``request.state.request_id`` so downstream handlers
    can read it, and is echoed back on the response.
    """

    # Client-supplied IDs longer than this are ignored and replaced with a
    # generated UUID, so an untrusted header cannot push an arbitrarily long
    # value into request state, logs, and response headers.
    MAX_REQUEST_ID_LENGTH = 128

    async def dispatch(self, request: Request, call_next):
        """Resolve the request ID, run the rest of the stack, tag the response."""
        incoming = request.headers.get("X-Request-ID", "").strip()
        if incoming and len(incoming) <= self.MAX_REQUEST_ID_LENGTH:
            request_id = incoming
        else:
            # Missing, blank, or oversized header: generate one. The UUID is
            # only built on this path, not on every request.
            request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Redis-based rate limiting middleware.

    Default: 100 requests/minute per user.
    Auth endpoints: 10 requests/minute per IP.
    Graceful fallback when Redis is unavailable (allows all requests).

    Uses a fixed window: each bucket is an integer counter in Redis that is
    given a ``WINDOW_SECONDS`` expiry when it is first created.
    """

    DEFAULT_LIMIT = 100  # requests per minute
    AUTH_LIMIT = 10  # requests per minute for auth endpoints
    WINDOW_SECONDS = 60
    # Paths equal to this, or nested under it ("<prefix>/..."), get AUTH_LIMIT.
    AUTH_PATH_PREFIX = "/api/v1/auth"

    async def dispatch(self, request: Request, call_next):
        """Return 429 if the request's bucket is over its limit, else pass it on.

        Fails open: if Redis is disabled, unreachable, or errors while
        counting, the request is allowed. Exceptions raised by downstream
        handlers (``call_next``) are not caught here.
        """
        settings = get_settings()
        if not settings.REDIS_ENABLED:
            return await call_next(request)

        try:
            from app.core.cache import get_redis

            redis = await get_redis()
        except Exception:
            logger.debug("Redis unavailable; skipping rate limiting", exc_info=True)
            redis = None

        if not redis:
            return await call_next(request)

        try:
            key, limit = self._bucket_for(request)

            current = await redis.incr(key)
            if current == 1:
                await redis.expire(key, self.WINDOW_SECONDS)

            if current > limit:
                ttl = await redis.ttl(key)
                if ttl == -1:
                    # TTL -1 means the counter exists without an expiry, e.g.
                    # the process died or EXPIRE failed right after INCR.
                    # Re-arm it, otherwise this client stays blocked forever.
                    await redis.expire(key, self.WINDOW_SECONDS)
                    ttl = self.WINDOW_SECONDS
                return self._too_many_requests(ttl)
        except Exception as e:
            logger.warning("Rate limiting error (allowing request): %s", e)

        return await call_next(request)

    def _is_auth_path(self, path: str) -> bool:
        """True if ``path`` is the auth prefix itself or a sub-path of it."""
        prefix = self.AUTH_PATH_PREFIX
        return path == prefix or path.startswith(prefix + "/")

    def _bucket_for(self, request: Request):
        """Return ``(redis_key, limit)`` identifying the bucket for ``request``."""
        client_ip = request.client.host if request.client else "unknown"
        if self._is_auth_path(request.url.path):
            return f"rl:auth:{client_ip}", self.AUTH_LIMIT

        # Use user token hash or client IP for rate limiting.
        auth_header = request.headers.get("Authorization", "")
        if auth_header:
            # Stable digest: built-in hash() of str is salted per process
            # (PYTHONHASHSEED), so workers sharing Redis would each count the
            # same token in a different bucket, and keys would change on restart.
            # NOTE(review): the header is not validated here, so a client can
            # rotate bogus tokens to get fresh buckets; consider keying on the
            # authenticated user id instead.
            digest = hashlib.sha256(auth_header.encode("utf-8")).hexdigest()
            return f"rl:user:{digest}", self.DEFAULT_LIMIT
        return f"rl:anon:{client_ip}", self.DEFAULT_LIMIT

    @staticmethod
    def _too_many_requests(ttl: int) -> JSONResponse:
        """Build the 429 response; ``ttl`` is clamped to at least one second."""
        retry_after = max(ttl, 1)
        return JSONResponse(
            status_code=429,
            content={
                "detail": "Too many requests",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )
|