149 lines
4.5 KiB
Python
149 lines
4.5 KiB
Python
"""Redis caching layer with graceful fallback when Redis is unavailable."""
|
|
import json
import logging
from functools import wraps
from typing import Any, Iterable, Optional
from urllib.parse import urlencode, urlsplit, urlunsplit

import redis.asyncio as aioredis

from app.core.config import get_settings
|
|
|
|
# Logger shared by every function and class in this module.
logger: logging.Logger = logging.getLogger("cache")

# Process-wide Redis client: populated lazily by get_redis() and reset to
# None by close_redis() (or by a failed connection attempt).
_redis_pool: Optional[aioredis.Redis] = None
|
|
|
|
|
|
def _redact_url(url: str) -> str:
    """Return *url* with any password replaced by ``***`` so it is safe to log."""
    try:
        parts = urlsplit(url)
    except ValueError:
        # Malformed URL: print nothing from it rather than risk echoing secrets.
        return "<unparseable URL>"
    if parts.password is None:
        return url
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


async def _close_quietly(client: aioredis.Redis) -> None:
    """Best-effort close of *client*; errors are logged at DEBUG and ignored."""
    # Newer redis-py exposes aclose() and deprecates close(); support both.
    closer = getattr(client, "aclose", None) or client.close
    try:
        await closer()
    except Exception as e:
        logger.debug("Ignoring error while closing Redis client: %s", e)


async def get_redis() -> Optional[aioredis.Redis]:
    """Get or create the global Redis client.

    The client is built from ``settings.REDIS_URL`` with
    ``decode_responses=True`` and at most 20 pooled connections. It is stored
    in the module-level ``_redis_pool`` only after it has answered ``PING``,
    so callers never receive an unverified client.

    Returns:
        The shared client, or ``None`` if ``settings.REDIS_ENABLED`` is false
        or the client could not be created / did not answer ``PING``. A failed
        attempt is not remembered: the next call tries to connect again.
    """
    global _redis_pool

    settings = get_settings()
    if not settings.REDIS_ENABLED:
        return None
    if _redis_pool is not None:
        return _redis_pool

    # NOTE(review): no connect timeout is passed here; while Redis is down each
    # call may wait on a TCP connect. Consider socket_connect_timeout.
    client: Optional[aioredis.Redis] = None
    try:
        client = aioredis.from_url(
            settings.REDIS_URL,
            decode_responses=True,
            max_connections=20,
        )
        # Verify connectivity before publishing the client to other callers.
        await client.ping()
    except Exception as e:
        logger.warning("Redis unavailable, caching disabled: %s", e)
        if client is not None:
            # Release whatever the half-initialised client holds.
            await _close_quietly(client)
        return None

    existing = _redis_pool
    if existing is not None:
        # Another coroutine connected while we awaited PING; keep its client.
        await _close_quietly(client)
        return existing

    _redis_pool = client
    logger.info(
        "Redis connection established: %s", _redact_url(settings.REDIS_URL)
    )
    return _redis_pool
|
|
|
|
|
|
async def close_redis():
    """Close the global Redis client, if one exists, and forget it.

    Safe to call when no client was ever created. The module-level reference
    is cleared before closing, so it is reset even if closing raises and a
    later ``get_redis()`` can reconnect. Errors from closing propagate.
    """
    global _redis_pool

    client = _redis_pool
    if client is None:
        return

    # Clear first so nobody picks up a client that is being shut down.
    _redis_pool = None
    # Newer redis-py exposes aclose() and deprecates close(); support both.
    closer = getattr(client, "aclose", None) or client.close
    await closer()
    logger.info("Redis connection closed.")
|
|
|
|
|
|
class RedisCache:
    """Async Redis cache with JSON serialization and graceful fallback.

    No operation raises for Redis problems. When no client is available, or
    the Redis call fails, the operation returns the "miss" value instead:
    ``None`` for ``get``, ``False`` for ``set``/``delete``/``exists``.
    Failures are logged at WARNING level.
    """

    def __init__(self, redis_client: Optional[aioredis.Redis] = None):
        # Explicitly injected client; when None, the shared client from
        # get_redis() is looked up on every operation.
        self._redis = redis_client

    async def _get_client(self) -> Optional[aioredis.Redis]:
        """Return the injected client, else the shared one (may be None)."""
        if self._redis is not None:
            return self._redis
        return await get_redis()

    async def get(self, key: str) -> Optional[Any]:
        """Get a value from cache. Returns None on miss or error.

        Stored values are decoded as JSON; a value that is not valid JSON is
        returned exactly as Redis returned it. A cached JSON ``null`` is
        indistinguishable from a miss.
        """
        client = await self._get_client()
        if client is None:
            return None

        # Keep the network call and the decoding in separate try blocks so a
        # failed fetch can never fall into the "return raw" path unbound.
        try:
            raw = await client.get(key)
        except Exception as e:
            logger.warning("Cache get error for key=%s: %s", key, e)
            return None

        if raw is None:
            return None

        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            # Not JSON (e.g. written by another producer): hand back as-is.
            return raw
        except Exception as e:
            logger.warning("Cache get error for key=%s: %s", key, e)
            return None

    async def set(self, key: str, value: Any, ttl: int = 300) -> bool:
        """Set a value in cache with TTL in seconds.

        ``value`` is stored as JSON; objects ``json`` cannot serialize natively
        are stored as ``str(value)``. ``ttl`` is passed to Redis as ``ex``.

        Returns:
            True if the write was issued successfully, False otherwise.
        """
        client = await self._get_client()
        if client is None:
            return False
        try:
            serialized = json.dumps(value, ensure_ascii=False, default=str)
            await client.set(key, serialized, ex=ttl)
        except Exception as e:
            logger.warning("Cache set error for key=%s: %s", key, e)
            return False
        return True

    async def delete(self, key: str) -> bool:
        """Delete a key from cache.

        Returns:
            True if the delete command succeeded (whether or not the key
            existed), False if no client is available or the call failed.
        """
        client = await self._get_client()
        if client is None:
            return False
        try:
            await client.delete(key)
        except Exception as e:
            logger.warning("Cache delete error for key=%s: %s", key, e)
            return False
        return True

    async def exists(self, key: str) -> bool:
        """Check if a key exists in cache; False when unavailable or on error."""
        client = await self._get_client()
        if client is None:
            return False
        try:
            return bool(await client.exists(key))
        except Exception as e:
            logger.warning("Cache exists error for key=%s: %s", key, e)
            return False
|
|
|
|
|
|
def cache_response(
    prefix: str,
    ttl_seconds: int = 300,
    *,
    exclude: Iterable[str] = ("db", "user"),
):
    """Decorator to cache FastAPI endpoint responses in Redis.

    Cache key construction:
        The key is ``prefix``, then ``:``, then the URL-encoded keyword
        arguments sorted by name. Arguments whose value is ``None`` and
        arguments named in ``exclude`` are skipped. When no arguments remain,
        the key is just ``prefix``. Values are converted with ``str()`` and
        then URL-encoded, so a value containing ``&`` or ``=`` cannot collide
        with a different combination of parameters.

    Falls through to the endpoint when Redis is unavailable: ``RedisCache``
    then reports a miss and ``set`` does nothing. ``None`` results are never
    stored, because ``RedisCache.get`` cannot tell a cached ``null`` from a
    miss.

    Args:
        prefix: Namespace for this endpoint's cache keys.
        ttl_seconds: Expiry of each cached entry, in seconds.
        exclude: Keyword-argument names left out of the key. The default is
            ``db`` and ``user``. A single string is treated as one name.

    NOTE(review): Positional arguments are not part of the key. Because
    ``user`` is excluded, every caller shares the same cached response; use
    this only on endpoints whose output does not depend on the caller.
    Arguments whose ``str()`` differs on every call make each key unique and
    should be excluded as well.
    """
    # Guard against exclude="db" being split into single characters.
    excluded = frozenset((exclude,) if isinstance(exclude, str) else exclude)

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Build the cache key from the keyword arguments, sorted by name.
            params = sorted(
                (
                    (name, str(value))
                    for name, value in kwargs.items()
                    if value is not None and name not in excluded
                ),
                key=lambda item: item[0],
            )
            query = urlencode(params)
            cache_key = f"{prefix}:{query}" if query else prefix

            cache = RedisCache()

            # Try cache hit
            cached = await cache.get(cache_key)
            if cached is not None:
                return cached

            # Call the actual endpoint
            result = await func(*args, **kwargs)

            # Store result in cache. Skip None: get() would report it as a miss.
            if result is not None:
                await cache.set(cache_key, result, ttl=ttl_seconds)
            return result

        return wrapper

    return decorator
|