# File: ems-core/backend/app/core/cache.py
# 149 lines, 4.5 KiB, Python

"""Redis caching layer with graceful fallback when Redis is unavailable."""
import json
import logging
from functools import wraps
from typing import Any, Optional
from urllib.parse import urlsplit

import redis.asyncio as aioredis

from app.core.config import get_settings
# Module logger, registered under the fixed name "cache".
logger = logging.getLogger("cache")
# Process-wide Redis client shared by all callers: created lazily by
# get_redis() and reset to None by close_redis().
_redis_pool: Optional[aioredis.Redis] = None
def _redact_url(url: str) -> str:
    """Return *url* with any ``user:password@`` part replaced by ``***@``.

    Used only for log output, so that credentials embedded in the Redis URL
    never end up in the logs.
    """
    try:
        parts = urlsplit(str(url))
    except ValueError:
        # Malformed URL (e.g. unbalanced IPv6 brackets): do not echo it.
        return "<unparseable url>"
    _userinfo, at_sign, host = parts.netloc.rpartition("@")
    if not at_sign:
        return url
    return parts._replace(netloc=f"***@{host}").geturl()


async def _discard_client(client) -> None:
    """Best-effort close of a client that will not become the shared one."""
    try:
        # redis-py >= 5.0.1 deprecates close() in favour of aclose().
        await (getattr(client, "aclose", None) or client.close)()
    except Exception:
        logger.debug("Error while closing discarded Redis client", exc_info=True)


async def get_redis() -> Optional[aioredis.Redis]:
    """Get or create the global Redis client.

    Returns:
        The shared client, or None if Redis is disabled in the settings or
        the connection attempt fails. A failure is logged as a warning and
        never raised; it is not remembered, so the next call tries again.
    """
    global _redis_pool
    settings = get_settings()
    if not settings.REDIS_ENABLED:
        return None
    if _redis_pool is not None:
        return _redis_pool

    # Keep the new client local until it has answered a PING. Publishing it
    # in the global first would let a concurrent caller pick up an
    # unverified client while this coroutine is suspended in ping().
    client = None
    try:
        client = aioredis.from_url(
            settings.REDIS_URL,
            decode_responses=True,
            max_connections=20,
        )
        # Verify connectivity
        await client.ping()
    except Exception as e:
        logger.warning("Redis unavailable, caching disabled: %s", e)
        if client is not None:
            # Release any socket the failed client may still hold.
            await _discard_client(client)
        return None

    if _redis_pool is not None:
        # Another coroutine finished connecting while we awaited the PING;
        # keep its client and drop ours so only one is ever shared.
        await _discard_client(client)
        return _redis_pool

    _redis_pool = client
    # The URL may carry a password, so only a redacted form is logged.
    logger.info(
        "Redis connection established: %s", _redact_url(settings.REDIS_URL)
    )
    return _redis_pool
async def close_redis():
    """Close the global Redis client and clear the module-level reference.

    Does nothing when no client has been created. Errors raised while
    closing propagate to the caller, but the module-level reference is
    cleared regardless.
    """
    global _redis_pool
    pool = _redis_pool
    if pool is None:
        return
    # Clear the reference before awaiting: if closing fails, get_redis()
    # must not keep handing out a half-closed client.
    _redis_pool = None
    # redis-py >= 5.0.1 deprecates close() in favour of aclose(); fall back
    # to close() on older versions.
    close = getattr(pool, "aclose", None) or pool.close
    await close()
    logger.info("Redis connection closed.")
class RedisCache:
    """Async Redis cache with JSON serialization and graceful fallback.

    Every operation is best-effort: when no client is available or Redis
    raises, the method logs a warning and returns a neutral value (None or
    False) instead of propagating the error.
    """

    def __init__(self, redis_client: Optional[aioredis.Redis] = None):
        """Create a cache facade.

        Args:
            redis_client: Client to use for every operation. When omitted,
                each operation resolves the shared client via get_redis().
        """
        self._redis = redis_client

    async def _get_client(self) -> Optional[aioredis.Redis]:
        """Return the injected client, or the shared one from get_redis()."""
        if self._redis is not None:
            return self._redis
        return await get_redis()

    async def get(self, key: str) -> Optional[Any]:
        """Get a value from cache.

        Returns:
            The JSON-decoded value; the raw stored value if it is not valid
            JSON; None on a miss or on any error. A stored JSON ``null``
            also decodes to None and so cannot be told apart from a miss.
        """
        client = await self._get_client()
        if client is None:
            return None
        # Fetch and decode are separate steps so that the "not JSON" fallback
        # below can only run once `raw` is actually bound. A TypeError raised
        # by the fetch itself is reported as a cache error, not as raw data.
        try:
            raw = await client.get(key)
        except Exception as e:
            logger.warning("Cache get error for key=%s: %s", key, e)
            return None
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            # Value was not written as JSON; return it as stored.
            return raw
        except Exception as e:
            logger.warning("Cache get error for key=%s: %s", key, e)
            return None

    async def set(self, key: str, value: Any, ttl: int = 300) -> bool:
        """Set a value in cache with TTL in seconds.

        The value is stored as JSON. Objects that JSON cannot encode are
        converted with ``str()``, so they come back from get() as strings.

        Returns:
            True if the value was written, False if no client is available
            or the write failed.
        """
        client = await self._get_client()
        if client is None:
            return False
        try:
            serialized = json.dumps(value, ensure_ascii=False, default=str)
            await client.set(key, serialized, ex=ttl)
            return True
        except Exception as e:
            logger.warning("Cache set error for key=%s: %s", key, e)
            return False

    async def delete(self, key: str) -> bool:
        """Delete a key from cache.

        Returns:
            True if the delete command completed (whether or not the key
            existed), False if no client is available or the command failed.
        """
        client = await self._get_client()
        if client is None:
            return False
        try:
            await client.delete(key)
            return True
        except Exception as e:
            logger.warning("Cache delete error for key=%s: %s", key, e)
            return False

    async def exists(self, key: str) -> bool:
        """Check if a key exists in cache.

        Returns:
            True if Redis reports the key, False otherwise, including when
            no client is available or the command failed.
        """
        client = await self._get_client()
        if client is None:
            return False
        try:
            return bool(await client.exists(key))
        except Exception as e:
            logger.warning("Cache exists error for key=%s: %s", key, e)
            return False
def _cache_key_part(value) -> str:
    """Render *value* for a cache key, escaping the key's own delimiters.

    Only ``%``, ``&`` and ``=`` are rewritten, so values that contain none
    of them produce exactly the same key text as plain ``str(value)``.
    """
    text = str(value)
    return text.replace("%", "%25").replace("&", "%26").replace("=", "%3D")


def cache_response(prefix: str, ttl_seconds: int = 300):
    """Decorator to cache async endpoint responses in Redis.

    The cache key is ``prefix`` followed by the endpoint's keyword
    arguments, sorted by name and joined as ``name=value`` pairs with ``&``.
    Arguments whose value is None, and the arguments named ``db`` and
    ``user``, are left out of the key. When Redis is unavailable the
    endpoint is simply called every time.

    Args:
        prefix: Key prefix identifying the endpoint.
        ttl_seconds: Lifetime of a cached response, in seconds.

    Notes:
        * Positional arguments do not contribute to the key; calls that
          differ only in positional arguments share one cache entry.
        * NOTE(review): because ``user`` is excluded, one user's cached
          response is served to every other user sending the same
          parameters. Confirm no decorated endpoint returns per-user data.
        * A cache hit returns the JSON-decoded value from RedisCache.get(),
          not the object the endpoint originally returned.
        * A None result is never stored, since a hit is recognised by a
          non-None cached value.
    """
    excluded = ("db", "user")

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Build cache key from all keyword arguments. Values are escaped
            # so that a value containing "&" or "=" cannot imitate another
            # parameter and collide with a different request's key.
            sorted_params = "&".join(
                f"{k}={_cache_key_part(v)}"
                for k, v in sorted(kwargs.items())
                if v is not None and k not in excluded
            )
            cache_key = f"{prefix}:{sorted_params}" if sorted_params else prefix
            cache = RedisCache()

            # Try cache hit
            cached = await cache.get(cache_key)
            if cached is not None:
                return cached

            # Call the actual endpoint
            result = await func(*args, **kwargs)

            # Store result in cache. None is skipped: it would be stored as
            # JSON null, which get() reports as a miss, so the write would
            # cost a round-trip without ever producing a hit.
            if result is not None:
                await cache.set(cache_key, result, ttl=ttl_seconds)
            return result

        return wrapper

    return decorator