# Source: ems-core/backend/app/core/middleware.py (87 lines, 2.9 KiB, Python)
"""Custom middleware for request tracking and rate limiting."""
import hashlib
import logging
import time
import uuid
from typing import Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.core.config import get_settings
# Module-level logger. It uses the fixed name "middleware" (not __name__),
# so logging configuration must target that name to tune this module's output.
logger: logging.Logger = logging.getLogger("middleware")
class RequestIdMiddleware(BaseHTTPMiddleware):
    """Attach a request ID to every request and echo it in the response.

    A client-supplied ``X-Request-ID`` header is reused when it meets all of
    these conditions:

    - non-empty after stripping whitespace;
    - at most ``MAX_REQUEST_ID_LENGTH`` characters long;
    - made only of printable characters.

    Otherwise a fresh UUID4 string is generated. The chosen ID is stored on
    ``request.state.request_id`` for downstream handlers. It is also returned
    in the ``X-Request-ID`` response header.
    """

    # Upper bound on client-supplied IDs. Longer values are replaced, so an
    # untrusted header cannot push arbitrarily large strings into request
    # state and responses.
    MAX_REQUEST_ID_LENGTH = 128

    def _resolve_request_id(self, request: Request) -> str:
        """Return the usable incoming ``X-Request-ID`` or a new UUID4 string."""
        candidate = request.headers.get("X-Request-ID", "").strip()
        if (
            candidate
            and len(candidate) <= self.MAX_REQUEST_ID_LENGTH
            and candidate.isprintable()
        ):
            return candidate
        # Only generate an ID when the client did not provide a usable one.
        return str(uuid.uuid4())

    async def dispatch(self, request: Request, call_next):
        """Assign the request ID, run the downstream app, and tag the response."""
        request_id = self._resolve_request_id(request)
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Redis-based fixed-window rate limiting middleware.

    Default: 100 requests/minute per client. The bucket is keyed on a SHA-256
    digest of the ``Authorization`` header, or on the client IP when no
    header is sent.
    Auth endpoints (paths starting with ``/api/v1/auth``): 10 requests/minute
    per IP.
    Graceful fallback: requests are allowed through in any of these cases:

    - Redis is disabled;
    - Redis is unavailable;
    - a Redis command fails.

    Over-limit requests receive a 429 JSON response with a ``Retry-After``
    header.
    """

    DEFAULT_LIMIT = 100  # requests per minute
    AUTH_LIMIT = 10  # requests per minute for auth endpoints
    WINDOW_SECONDS = 60

    def _bucket(self, request: Request):
        """Return ``(redis_key, limit)`` for the bucket this request counts against."""
        client_ip = request.client.host if request.client else "unknown"
        if request.url.path.startswith("/api/v1/auth"):
            return f"rl:auth:{client_ip}", self.AUTH_LIMIT

        auth_header = request.headers.get("Authorization", "")
        if auth_header:
            # Use a stable digest, not the built-in hash(). hash() of a str is
            # salted per process (PYTHONHASHSEED), so each worker or restart
            # would count the same token under a different Redis key.
            # NOTE(review): the header is not validated here, so a client can
            # rotate arbitrary Authorization values to obtain fresh buckets.
            digest = hashlib.sha256(auth_header.encode("utf-8")).hexdigest()
            return f"rl:user:{digest}", self.DEFAULT_LIMIT

        return f"rl:anon:{client_ip}", self.DEFAULT_LIMIT

    async def _retry_after(self, redis, key: str) -> int:
        """Return the number of seconds until *key*'s window resets (at least 1).

        A TTL of -1 means the key exists without an expiry. This happens, for
        example, when the EXPIRE after the first INCR failed. The key is then
        re-armed so the client is not blocked forever.
        """
        ttl = await redis.ttl(key)
        if ttl == -1:
            await redis.expire(key, self.WINDOW_SECONDS)
            ttl = self.WINDOW_SECONDS
        # -2 (key already gone) and 0 both collapse to a 1-second retry hint.
        return max(ttl, 1)

    async def dispatch(self, request: Request, call_next):
        """Count the request against its bucket; reject with 429 when over the limit."""
        settings = get_settings()
        if not settings.REDIS_ENABLED:
            return await call_next(request)

        try:
            from app.core.cache import get_redis

            redis = await get_redis()
        except Exception:
            redis = None
        if not redis:
            return await call_next(request)

        try:
            key, limit = self._bucket(request)
            current = await redis.incr(key)
            if current == 1:
                # The first hit opens the window.
                await redis.expire(key, self.WINDOW_SECONDS)
            if current > limit:
                retry_after = await self._retry_after(redis, key)
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Too many requests",
                        "retry_after": retry_after,
                    },
                    headers={"Retry-After": str(retry_after)},
                )
        except Exception as e:
            # Best effort: a Redis failure must not block traffic.
            logger.warning("Rate limiting error (allowing request): %s", e)
        return await call_next(request)