"""Cost-analysis API router: pricing configuration CRUD and cost reports."""

from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Depends, Query, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_
from pydantic import BaseModel

from app.core.database import get_db
from app.core.deps import get_current_user
from app.models.pricing import ElectricityPricing, PricingPeriod
from app.models.energy import EnergyDailySummary
from app.models.user import User
from app.services.cost_calculator import get_cost_summary, get_cost_breakdown

router = APIRouter(prefix="/cost", tags=["费用分析"])


# ---- Helpers ----

def _parse_iso(value: str, field: str) -> datetime:
    """Parse an ISO-8601 date/datetime string supplied by the client.

    Raises:
        HTTPException(400): ``value`` is not a valid ISO-8601 string. Without
            this mapping, ``datetime.fromisoformat`` would raise ``ValueError``
            and the client would see a 500 Internal Server Error.
    """
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"日期格式无效 ({field}): {value}"
        ) from None


def _parse_optional_iso(value: str | None, field: str) -> datetime | None:
    """Like ``_parse_iso`` but maps ``None`` / empty string to ``None``."""
    return _parse_iso(value, field) if value else None


def _shift_years(dt: datetime, years: int) -> datetime:
    """Return ``dt`` moved by ``years`` calendar years.

    Feb 29 has no counterpart in a non-leap target year, so it is clamped
    to Feb 28 instead of raising ``ValueError``.
    """
    try:
        return dt.replace(year=dt.year + years)
    except ValueError:
        # Only Feb 29 -> non-leap year can fail here.
        return dt.replace(year=dt.year + years, day=28)


def _serialize_period(pp: PricingPeriod) -> dict:
    """Convert a PricingPeriod row into its JSON response shape."""
    return {
        "id": pp.id,
        "period_name": pp.period_name,
        "start_time": pp.start_time,
        "end_time": pp.end_time,
        "price_per_unit": pp.price_per_unit,
        "applicable_months": pp.applicable_months,
    }


async def _get_pricing_or_404(db: AsyncSession, pricing_id: int) -> ElectricityPricing:
    """Load an ElectricityPricing by id.

    Raises:
        HTTPException(404): no pricing configuration has this id.
    """
    result = await db.execute(
        select(ElectricityPricing).where(ElectricityPricing.id == pricing_id)
    )
    pricing = result.scalar_one_or_none()
    if not pricing:
        raise HTTPException(status_code=404, detail="电价配置不存在")
    return pricing


# ---- Schemas ----

class PricingPeriodCreate(BaseModel):
    """Request body describing one pricing period of a configuration."""

    period_name: str
    start_time: str
    end_time: str
    price_per_unit: float
    applicable_months: list[int] | None = None


class PricingCreate(BaseModel):
    """Request body for creating a pricing configuration with its periods."""

    name: str
    energy_type: str = "electricity"
    pricing_type: str  # flat, tou, tiered
    effective_from: str | None = None  # ISO-8601 string
    effective_to: str | None = None  # ISO-8601 string
    periods: list[PricingPeriodCreate] = []


class PricingUpdate(BaseModel):
    """Partial update of a pricing configuration; ``None`` fields are left unchanged."""

    name: str | None = None
    energy_type: str | None = None
    pricing_type: str | None = None
    effective_from: str | None = None  # ISO-8601 string
    effective_to: str | None = None  # ISO-8601 string
    is_active: bool | None = None


# ---- Pricing CRUD ----

@router.get("/pricing")
async def list_pricing(
    energy_type: str | None = None,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """List pricing configurations with their periods (获取电价配置列表).

    Results are ordered newest first; ``energy_type`` optionally filters them.
    """
    q = select(ElectricityPricing).order_by(ElectricityPricing.created_at.desc())
    if energy_type:
        q = q.where(ElectricityPricing.energy_type == energy_type)
    pricings = (await db.execute(q)).scalars().all()

    # Load every period in a single query rather than one query per pricing (N+1).
    periods_by_pricing: dict[int, list] = {}
    if pricings:
        pq = await db.execute(
            select(PricingPeriod)
            .where(PricingPeriod.pricing_id.in_([p.id for p in pricings]))
            .order_by(PricingPeriod.id)
        )
        for pp in pq.scalars().all():
            periods_by_pricing.setdefault(pp.pricing_id, []).append(pp)

    return [
        {
            "id": p.id,
            "name": p.name,
            "energy_type": p.energy_type,
            "pricing_type": p.pricing_type,
            "is_active": p.is_active,
            "effective_from": str(p.effective_from) if p.effective_from else None,
            "effective_to": str(p.effective_to) if p.effective_to else None,
            "created_at": str(p.created_at),
            "periods": [_serialize_period(pp) for pp in periods_by_pricing.get(p.id, [])],
        }
        for p in pricings
    ]


@router.post("/pricing")
async def create_pricing(
    data: PricingCreate,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Create a pricing configuration and its periods (创建电价配置).

    Raises:
        HTTPException(400): ``effective_from`` / ``effective_to`` is not ISO-8601.
    """
    pricing = ElectricityPricing(
        name=data.name,
        energy_type=data.energy_type,
        pricing_type=data.pricing_type,
        effective_from=_parse_optional_iso(data.effective_from, "effective_from"),
        effective_to=_parse_optional_iso(data.effective_to, "effective_to"),
        created_by=user.id,
    )
    db.add(pricing)
    # Flush so pricing.id is populated before the periods reference it.
    await db.flush()
    for period in data.periods:
        db.add(
            PricingPeriod(
                pricing_id=pricing.id,
                period_name=period.period_name,
                start_time=period.start_time,
                end_time=period.end_time,
                price_per_unit=period.price_per_unit,
                applicable_months=period.applicable_months,
            )
        )
    return {"id": pricing.id, "message": "电价配置创建成功"}


@router.put("/pricing/{pricing_id}")
async def update_pricing(
    pricing_id: int,
    data: PricingUpdate,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Apply a partial update to a pricing configuration (更新电价配置).

    Only fields that are not ``None`` in the request are written.

    Raises:
        HTTPException(404): the pricing configuration does not exist.
        HTTPException(400): a supplied effective date is not ISO-8601.
    """
    pricing = await _get_pricing_or_404(db, pricing_id)
    if data.name is not None:
        pricing.name = data.name
    if data.energy_type is not None:
        pricing.energy_type = data.energy_type
    if data.pricing_type is not None:
        pricing.pricing_type = data.pricing_type
    if data.effective_from is not None:
        pricing.effective_from = _parse_iso(data.effective_from, "effective_from")
    if data.effective_to is not None:
        pricing.effective_to = _parse_iso(data.effective_to, "effective_to")
    if data.is_active is not None:
        pricing.is_active = data.is_active
    return {"message": "电价配置更新成功"}


@router.delete("/pricing/{pricing_id}")
async def deactivate_pricing(
    pricing_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Soft-delete a pricing configuration by setting ``is_active = False`` (停用电价配置).

    Raises:
        HTTPException(404): the pricing configuration does not exist.
    """
    pricing = await _get_pricing_or_404(db, pricing_id)
    pricing.is_active = False
    return {"message": "电价配置已停用"}


# ---- Pricing Periods ----

@router.get("/pricing/{pricing_id}/periods")
async def list_periods(
    pricing_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """List the periods of one pricing configuration (获取电价时段列表).

    Returns an empty list when the pricing id has no periods (or does not exist).
    """
    result = await db.execute(
        select(PricingPeriod).where(PricingPeriod.pricing_id == pricing_id)
    )
    return [_serialize_period(p) for p in result.scalars().all()]


@router.post("/pricing/{pricing_id}/periods")
async def add_period(
    pricing_id: int,
    data: PricingPeriodCreate,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Add a period to an existing pricing configuration (添加电价时段).

    Raises:
        HTTPException(404): the pricing configuration does not exist.
    """
    await _get_pricing_or_404(db, pricing_id)
    period = PricingPeriod(
        pricing_id=pricing_id,
        period_name=data.period_name,
        start_time=data.start_time,
        end_time=data.end_time,
        price_per_unit=data.price_per_unit,
        applicable_months=data.applicable_months,
    )
    db.add(period)
    # Flush so period.id is populated for the response.
    await db.flush()
    return {"id": period.id, "message": "时段添加成功"}


# ---- Cost Analysis ----

@router.get("/summary")
async def cost_summary(
    start_date: str = Query(..., description="开始日期, e.g. 2026-01-01"),
    end_date: str = Query(..., description="结束日期, e.g. 2026-03-31"),
    group_by: str = Query("day", pattern="^(day|month|device)$"),
    energy_type: str = Query("electricity"),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Aggregate cost between two dates, grouped by day, month or device (费用汇总).

    Raises:
        HTTPException(400): ``start_date`` / ``end_date`` is not ISO-8601.
    """
    start_dt = _parse_iso(start_date, "start_date")
    end_dt = _parse_iso(end_date, "end_date")
    return await get_cost_summary(db, start_dt, end_dt, group_by, energy_type)


@router.get("/comparison")
async def cost_comparison(
    energy_type: str = "electricity",
    period: str = Query("month", pattern="^(day|week|month|year)$"),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Compare current-period cost with the previous period and the same period last year (费用同比环比).

    The three windows are:

    - ``current``: start of the current period up to now (partial period).
    - ``previous``: the whole preceding period.
    - ``yoy``: the whole same period one year earlier.

    ``mom_change`` and ``yoy_change`` are percentage changes of ``current``
    relative to those, or 0 when the baseline is zero.
    """
    # NOTE(review): period boundaries are computed in UTC; confirm this matches
    # how EnergyDailySummary.date is bucketed.
    now = datetime.now(timezone.utc)
    midnight = dict(hour=0, minute=0, second=0, microsecond=0)

    if period == "day":
        current_start = now.replace(**midnight)
        prev_start = current_start - timedelta(days=1)
        yoy_start = _shift_years(current_start, -1)
        yoy_end = yoy_start + timedelta(days=1)
    elif period == "week":
        # Weeks start on Monday (datetime.weekday() == 0).
        current_start = (now - timedelta(days=now.weekday())).replace(**midnight)
        prev_start = current_start - timedelta(weeks=1)
        yoy_start = _shift_years(current_start, -1)
        yoy_end = yoy_start + timedelta(weeks=1)
    elif period == "month":
        current_start = now.replace(day=1, **midnight)
        prev_start = (current_start - timedelta(days=1)).replace(day=1)
        # +32 days from the 1st always lands in the next month.
        next_month_start = (current_start + timedelta(days=32)).replace(day=1)
        yoy_start = _shift_years(current_start, -1)
        yoy_end = _shift_years(next_month_start, -1)
    else:  # year
        current_start = now.replace(month=1, day=1, **midnight)
        prev_start = _shift_years(current_start, -1)
        yoy_start = prev_start
        yoy_end = current_start

    async def sum_cost(start, end):
        """Sum EnergyDailySummary.cost over [start, end) for this energy type (0 if none)."""
        q = select(func.sum(EnergyDailySummary.cost)).where(
            and_(
                EnergyDailySummary.date >= start,
                EnergyDailySummary.date < end,
                EnergyDailySummary.energy_type == energy_type,
            )
        )
        r = await db.execute(q)
        return r.scalar() or 0

    current = await sum_cost(current_start, now)
    previous = await sum_cost(prev_start, current_start)
    yoy = await sum_cost(yoy_start, yoy_end)

    return {
        "current": round(current, 2),
        "previous": round(previous, 2),
        "yoy": round(yoy, 2),
        "mom_change": round((current - previous) / previous * 100, 1) if previous else 0,
        "yoy_change": round((current - yoy) / yoy * 100, 1) if yoy else 0,
    }


@router.get("/breakdown")
async def cost_breakdown_api(
    start_date: str = Query(..., description="开始日期"),
    end_date: str = Query(..., description="结束日期"),
    energy_type: str = Query("electricity"),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return the cost distribution across peak/valley/flat periods (峰谷平费用分布).

    Raises:
        HTTPException(400): ``start_date`` / ``end_date`` is not ISO-8601.
    """
    start_dt = _parse_iso(start_date, "start_date")
    end_dt = _parse_iso(end_date, "end_date")
    return await get_cost_breakdown(db, start_dt, end_dt, energy_type)