Squashed 'core/' content from commit 92ec910

git-subtree-dir: core
git-subtree-split: 92ec910a132e379a3a6e442a75bcb07cac0f0010
This commit is contained in:
Du Wenbo
2026-04-04 18:17:10 +08:00
commit 026c837b91
227 changed files with 39179 additions and 0 deletions

View File

148
backend/app/core/cache.py Normal file
View File

@@ -0,0 +1,148 @@
"""Redis caching layer with graceful fallback when Redis is unavailable."""
import json
import logging
from functools import wraps
from typing import Any, Optional
import redis.asyncio as aioredis
from app.core.config import get_settings
# Module-level logger for cache operations.
logger = logging.getLogger("cache")
# Lazily-created shared Redis client; stays None while Redis is disabled or unreachable.
_redis_pool: Optional[aioredis.Redis] = None
async def get_redis() -> Optional[aioredis.Redis]:
    """Get or create a global Redis connection pool.

    Returns:
        The shared ``aioredis.Redis`` client, or ``None`` when Redis is
        disabled in settings or the connection attempt fails.
    """
    global _redis_pool
    settings = get_settings()
    if not settings.REDIS_ENABLED:
        return None
    if _redis_pool is not None:
        return _redis_pool
    try:
        _redis_pool = aioredis.from_url(
            settings.REDIS_URL,
            decode_responses=True,
            max_connections=20,
        )
        # Verify connectivity up front so callers get a definitive answer.
        await _redis_pool.ping()
        logger.info("Redis connection established: %s", settings.REDIS_URL)
        return _redis_pool
    except Exception as e:
        logger.warning("Redis unavailable, caching disabled: %s", e)
        # Close the half-initialized client: a failed ping would otherwise
        # leave its connection pool open and leak on every retry.
        if _redis_pool is not None:
            try:
                await _redis_pool.close()
            except Exception:
                pass
        _redis_pool = None
        return None
async def close_redis():
    """Shut down and discard the shared Redis client, if one exists."""
    global _redis_pool
    if not _redis_pool:
        return
    await _redis_pool.close()
    _redis_pool = None
    logger.info("Redis connection closed.")
class RedisCache:
    """Async Redis cache with JSON serialization and graceful fallback.

    Every operation degrades to a no-op (returning None/False) when Redis
    is unavailable or errors out, so callers never need to special-case a
    missing cache.
    """

    def __init__(self, redis_client: Optional[aioredis.Redis] = None):
        # Optional injected client (useful for tests); otherwise the shared
        # module-level pool from get_redis() is used.
        self._redis = redis_client

    async def _get_client(self) -> Optional[aioredis.Redis]:
        """Return the injected client or the shared global one (may be None)."""
        if self._redis is not None:
            return self._redis
        return await get_redis()

    async def get(self, key: str) -> Optional[Any]:
        """Get a value from cache. Returns None on miss or error."""
        client = await self._get_client()
        if not client:
            return None
        # Keep the Redis round-trip and the JSON parsing in separate try
        # blocks: the original combined form returned `raw` from the
        # TypeError branch, which raised NameError when client.get itself
        # was the statement that failed (raw never bound).
        try:
            raw = await client.get(key)
        except Exception as e:
            logger.warning("Cache get error for key=%s: %s", key, e)
            return None
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            # Not valid JSON (e.g. written by another client): return as-is.
            return raw
        except Exception as e:
            logger.warning("Cache get error for key=%s: %s", key, e)
            return None

    async def set(self, key: str, value: Any, ttl: int = 300) -> bool:
        """Set a value in cache with TTL in seconds.

        Values are JSON-serialized; non-JSON types fall back to str().
        Returns True on success, False when Redis is unavailable or errors.
        """
        client = await self._get_client()
        if not client:
            return False
        try:
            serialized = json.dumps(value, ensure_ascii=False, default=str)
            await client.set(key, serialized, ex=ttl)
            return True
        except Exception as e:
            logger.warning("Cache set error for key=%s: %s", key, e)
            return False

    async def delete(self, key: str) -> bool:
        """Delete a key from cache. Returns True on success."""
        client = await self._get_client()
        if not client:
            return False
        try:
            await client.delete(key)
            return True
        except Exception as e:
            logger.warning("Cache delete error for key=%s: %s", key, e)
            return False

    async def exists(self, key: str) -> bool:
        """Check if a key exists in cache (False on error or no Redis)."""
        client = await self._get_client()
        if not client:
            return False
        try:
            return bool(await client.exists(key))
        except Exception as e:
            logger.warning("Cache exists error for key=%s: %s", key, e)
            return False
def cache_response(prefix: str, ttl_seconds: int = 300):
    """Decorator to cache FastAPI endpoint responses in Redis.

    The cache key is *prefix* plus the endpoint's sorted keyword arguments;
    ``db``/``user`` dependencies and ``None`` values are excluded. When the
    cache is unavailable the wrapped endpoint simply runs uncached.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Key parts: every non-None kwarg except injected dependencies.
            parts = [
                f"{name}={value}"
                for name, value in sorted(kwargs.items())
                if value is not None and name not in ("db", "user")
            ]
            sorted_params = "&".join(parts)
            cache_key = f"{prefix}:{sorted_params}" if sorted_params else prefix

            cache = RedisCache()
            hit = await cache.get(cache_key)
            if hit is not None:
                return hit

            # Cache miss: run the endpoint and store its result.
            result = await func(*args, **kwargs)
            await cache.set(cache_key, result, ttl=ttl_seconds)
            return result
        return wrapper
    return decorator

View File

@@ -0,0 +1,88 @@
from pydantic_settings import BaseSettings
from functools import lru_cache
import os
import yaml
class Settings(BaseSettings):
    """Application settings, loaded from the environment / ``.env`` file."""

    APP_NAME: str = "TianpuEMS"
    DEBUG: bool = True
    API_V1_PREFIX: str = "/api/v1"

    # Customer configuration
    CUSTOMER: str = "tianpu"  # tianpu, zpark, etc.
    CUSTOMER_DISPLAY_NAME: str = ""  # Loaded from customer config

    # Database: set DATABASE_URL in .env to override.
    # Default: SQLite for local dev. Docker sets PostgreSQL via env var.
    # Examples:
    #   SQLite:     sqlite+aiosqlite:///./tianpu_ems.db
    #   PostgreSQL: postgresql+asyncpg://tianpu:tianpu2026@localhost:5432/tianpu_ems
    DATABASE_URL: str = "sqlite+aiosqlite:///./tianpu_ems.db"
    REDIS_URL: str = "redis://localhost:6379/0"
    SECRET_KEY: str = "tianpu-ems-secret-key-change-in-production-2026"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 480
    CELERY_ENABLED: bool = False  # Set True when Celery worker is running
    USE_SIMULATOR: bool = True  # True=simulator mode, False=real IoT collectors

    # Infrastructure flags
    TIMESCALE_ENABLED: bool = False
    REDIS_ENABLED: bool = True
    INGESTION_QUEUE_ENABLED: bool = False
    AGGREGATION_ENABLED: bool = True

    # SMTP Email settings
    SMTP_HOST: str = ""
    SMTP_PORT: int = 587
    SMTP_USER: str = ""
    SMTP_PASSWORD: str = ""
    SMTP_FROM: str = "noreply@tianpu-ems.com"
    SMTP_ENABLED: bool = False

    # Platform URL for links in emails
    PLATFORM_URL: str = "http://localhost:3000"

    @property
    def DATABASE_URL_SYNC(self) -> str:
        """Derive synchronous URL from async DATABASE_URL for Alembic."""
        return (
            self.DATABASE_URL
            .replace("+aiosqlite", "")
            .replace("+asyncpg", "+psycopg2")
        )

    @property
    def is_sqlite(self) -> bool:
        """True when the configured database URL points at SQLite."""
        return "sqlite" in self.DATABASE_URL

    @property
    def customer_config_path(self) -> str:
        """Search for customer config in multiple locations."""
        backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        candidates = [
            # Standalone layout: project_root/customers/{CUSTOMER}/
            os.path.join(backend_dir, "..", "customers", self.CUSTOMER),
            # Subtree layout: core is nested two levels below the customer root.
            os.path.join(backend_dir, "..", "..", "customers", self.CUSTOMER),
        ]
        for candidate in candidates:
            if os.path.isdir(candidate):
                return os.path.abspath(candidate)
        # Fall back to the standalone location even when it does not exist.
        return os.path.abspath(candidates[0])

    def load_customer_config(self) -> dict:
        """Load customer-specific config from customers/{CUSTOMER}/config.yaml"""
        config_file = os.path.join(self.customer_config_path, "config.yaml")
        if not os.path.exists(config_file):
            return {}
        with open(config_file, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f) or {}

    class Config:
        env_file = ".env"
        extra = "ignore"
@lru_cache
def get_settings() -> Settings:
    """Return the process-wide, cached Settings instance."""
    settings = Settings()
    return settings

View File

@@ -0,0 +1,27 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import DeclarativeBase
from app.core.config import get_settings
settings = get_settings()

# Engine options: echo SQL in debug mode; pool sizing is only applied for
# non-SQLite databases (the conditional below skips it for SQLite).
engine_kwargs = {"echo": settings.DEBUG}
if not settings.is_sqlite:
    engine_kwargs.update(pool_size=20, max_overflow=10)

engine = create_async_engine(settings.DATABASE_URL, **engine_kwargs)
# Session factory; expire_on_commit=False keeps ORM objects usable after commit.
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
class Base(DeclarativeBase):
    """Declarative base class for all ORM models in the application."""
    pass
async def get_db():
    """FastAPI dependency that yields an AsyncSession per request.

    Commits after the request handler finishes normally; rolls back and
    re-raises if the handler (or the commit itself) raises. The session is
    closed by the async context manager in either case.
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise

34
backend/app/core/deps.py Normal file
View File

@@ -0,0 +1,34 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.security import decode_access_token
from app.models.user import User
# OAuth2 bearer scheme; tokenUrl points at the login endpoint that issues tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the bearer token to an active User.

    Raises:
        HTTPException 401: invalid/expired token, malformed "sub" claim,
        or user missing/disabled.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭据"
    )
    payload = decode_access_token(token)
    if payload is None:
        raise credentials_error
    user_id = payload.get("sub")
    if user_id is None:
        raise credentials_error
    # A non-numeric "sub" claim must produce 401, not an unhandled 500
    # (int() raises ValueError/TypeError on malformed tokens).
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        raise credentials_error
    result = await db.execute(select(User).where(User.id == user_pk))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已禁用")
    return user
def require_roles(*roles: str):
    """Dependency factory: allow only users whose role is in *roles*.

    Returns an async dependency that yields the current user, or raises
    HTTP 403 when the user's role is not one of the given roles.
    """
    async def checker(user: User = Depends(get_current_user)):
        if user.role in roles:
            return user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="权限不足")

    return checker

View File

@@ -0,0 +1,86 @@
"""Custom middleware for request tracking and rate limiting."""
import hashlib
import logging
import time
import uuid
from typing import Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.core.config import get_settings
# Shared logger for all middleware classes in this module.
logger = logging.getLogger("middleware")
class RequestIdMiddleware(BaseHTTPMiddleware):
    """Adds X-Request-ID header to every response.

    Reuses the inbound X-Request-ID when the client supplies one, otherwise
    generates a fresh UUID4. The id is also stashed on ``request.state`` so
    downstream handlers can log with it.
    """

    async def dispatch(self, request: Request, call_next):
        incoming = request.headers.get("X-Request-ID")
        request_id = incoming if incoming is not None else str(uuid.uuid4())
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Redis-based rate limiting middleware.

    Default: 100 requests/minute per user.
    Auth endpoints: 10 requests/minute per IP.
    Graceful fallback when Redis is unavailable (allows all requests).
    """

    DEFAULT_LIMIT = 100  # requests per minute
    AUTH_LIMIT = 10  # requests per minute for auth endpoints
    WINDOW_SECONDS = 60

    async def dispatch(self, request: Request, call_next):
        settings = get_settings()
        if not settings.REDIS_ENABLED:
            return await call_next(request)
        try:
            # Local import avoids a circular import at module load time.
            from app.core.cache import get_redis
            redis = await get_redis()
        except Exception:
            redis = None
        if not redis:
            # Redis down: fail open rather than blocking traffic.
            return await call_next(request)
        try:
            is_auth = request.url.path.startswith("/api/v1/auth")
            limit = self.AUTH_LIMIT if is_auth else self.DEFAULT_LIMIT
            if is_auth:
                client_ip = request.client.host if request.client else "unknown"
                key = f"rl:auth:{client_ip}"
            else:
                # Key on a stable digest of the bearer token, NOT builtin
                # hash(): str hashing is randomized per process
                # (PYTHONHASHSEED), so with multiple workers each process
                # would count against its own key and the effective limit
                # would be multiplied by the worker count.
                auth_header = request.headers.get("Authorization", "")
                if auth_header:
                    digest = hashlib.sha256(auth_header.encode("utf-8")).hexdigest()[:32]
                    key = f"rl:user:{digest}"
                else:
                    client_ip = request.client.host if request.client else "unknown"
                    key = f"rl:anon:{client_ip}"
            # Fixed-window counter: first INCR creates the key, then set TTL.
            current = await redis.incr(key)
            if current == 1:
                await redis.expire(key, self.WINDOW_SECONDS)
            if current > limit:
                ttl = await redis.ttl(key)
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Too many requests",
                        "retry_after": max(ttl, 1),
                    },
                    headers={"Retry-After": str(max(ttl, 1))},
                )
        except Exception as e:
            # Any Redis/runtime error fails open (request is allowed through).
            logger.warning("Rate limiting error (allowing request): %s", e)
        return await call_next(request)

View File

@@ -0,0 +1,29 @@
from datetime import datetime, timedelta, timezone
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.core.config import get_settings
# Cached application settings (provides SECRET_KEY, ALGORITHM, token lifetime).
settings = get_settings()
# Passlib context using bcrypt for password hashing.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
    """Hash a plaintext password with the configured bcrypt context."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain: str, hashed: str) -> bool:
    """Check a plaintext password against its stored bcrypt hash."""
    result = pwd_context.verify(plain, hashed)
    return result
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT from *data* with an ``exp`` claim.

    Uses *expires_delta* when given (and non-zero), otherwise the configured
    ACCESS_TOKEN_EXPIRE_MINUTES default.
    """
    delta = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = dict(data)
    claims["exp"] = datetime.now(timezone.utc) + delta
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_access_token(token: str) -> dict | None:
    """Decode and verify a JWT; return its claims, or None when invalid/expired."""
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
    return claims