"""Redis caching layer with graceful fallback when Redis is unavailable.""" import json import logging from functools import wraps from typing import Any, Optional import redis.asyncio as aioredis from app.core.config import get_settings logger = logging.getLogger("cache") _redis_pool: Optional[aioredis.Redis] = None async def get_redis() -> Optional[aioredis.Redis]: """Get or create a global Redis connection pool. Returns None if Redis is disabled or connection fails. """ global _redis_pool settings = get_settings() if not settings.REDIS_ENABLED: return None if _redis_pool is not None: return _redis_pool try: _redis_pool = aioredis.from_url( settings.REDIS_URL, decode_responses=True, max_connections=20, ) # Verify connectivity await _redis_pool.ping() logger.info("Redis connection established: %s", settings.REDIS_URL) return _redis_pool except Exception as e: logger.warning("Redis unavailable, caching disabled: %s", e) _redis_pool = None return None async def close_redis(): """Close the global Redis connection pool.""" global _redis_pool if _redis_pool: await _redis_pool.close() _redis_pool = None logger.info("Redis connection closed.") class RedisCache: """Async Redis cache with JSON serialization and graceful fallback.""" def __init__(self, redis_client: Optional[aioredis.Redis] = None): self._redis = redis_client async def _get_client(self) -> Optional[aioredis.Redis]: if self._redis is not None: return self._redis return await get_redis() async def get(self, key: str) -> Optional[Any]: """Get a value from cache. Returns None on miss or error.""" client = await self._get_client() if not client: return None try: raw = await client.get(key) if raw is None: return None return json.loads(raw) except (json.JSONDecodeError, TypeError): return raw except Exception as e: logger.warning("Cache get error for key=%s: %s", key, e) return None async def set(self, key: str, value: Any, ttl: int = 300) -> bool: """Set a value in cache with TTL in seconds.""" client = await self._get_client() if not client: return False try: serialized = json.dumps(value, ensure_ascii=False, default=str) await client.set(key, serialized, ex=ttl) return True except Exception as e: logger.warning("Cache set error for key=%s: %s", key, e) return False async def delete(self, key: str) -> bool: """Delete a key from cache.""" client = await self._get_client() if not client: return False try: await client.delete(key) return True except Exception as e: logger.warning("Cache delete error for key=%s: %s", key, e) return False async def exists(self, key: str) -> bool: """Check if a key exists in cache.""" client = await self._get_client() if not client: return False try: return bool(await client.exists(key)) except Exception as e: logger.warning("Cache exists error for key=%s: %s", key, e) return False def cache_response(prefix: str, ttl_seconds: int = 300): """Decorator to cache FastAPI endpoint responses in Redis. Builds cache key from prefix + sorted query params. Falls through to the endpoint when Redis is unavailable. """ def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): # Build cache key from all keyword arguments sorted_params = "&".join( f"{k}={v}" for k, v in sorted(kwargs.items()) if v is not None and k != "db" and k != "user" ) cache_key = f"{prefix}:{sorted_params}" if sorted_params else prefix cache = RedisCache() # Try cache hit cached = await cache.get(cache_key) if cached is not None: return cached # Call the actual endpoint result = await func(*args, **kwargs) # Store result in cache await cache.set(cache_key, result, ttl=ttl_seconds) return result return wrapper return decorator