from datetime import datetime, timezone
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.core.deps import get_current_user, require_roles
from app.models.maintenance import BillingRecord
from app.models.user import User

router = APIRouter(prefix="/billing", tags=["电费结算"])


# ── Pydantic Schemas ────────────────────────────────────────────────

# Calendar month; rejecting out-of-range values keeps /stats by_month sane.
Month = Annotated[int, Field(ge=1, le=12)]


class BillingCreate(BaseModel):
    """Request body for creating a billing (settlement) record."""

    station_name: str
    billing_type: str  # "generation", "consumption", "grid_feed"
    year: int
    month: Month
    generation_kwh: float | None = None
    consumption_kwh: float | None = None
    grid_feed_kwh: float | None = None
    unit_price: float | None = None
    total_amount: float | None = None
    status: str = "draft"
    invoice_number: str | None = None
    invoice_date: str | None = None  # ISO-8601 text, parsed with datetime.fromisoformat
    payment_date: str | None = None  # ISO-8601 text, parsed with datetime.fromisoformat
    notes: str | None = None


class BillingUpdate(BaseModel):
    """Partial-update body; only fields the client actually sends are applied."""

    station_name: str | None = None
    billing_type: str | None = None
    year: int | None = None
    month: Month | None = None
    generation_kwh: float | None = None
    consumption_kwh: float | None = None
    grid_feed_kwh: float | None = None
    unit_price: float | None = None
    total_amount: float | None = None
    status: str | None = None
    invoice_number: str | None = None
    invoice_date: str | None = None  # explicit null / "" clears the stored date
    payment_date: str | None = None  # explicit null / "" clears the stored date
    notes: str | None = None


# ── Helpers ─────────────────────────────────────────────────────────

# Request fields that arrive as ISO-8601 strings and are stored as datetimes.
_DATE_FIELDS = ("invoice_date", "payment_date")

# Fields BillingCreate declares as non-optional; an update must not null them out.
_REQUIRED_FIELDS = frozenset({"station_name", "billing_type", "year", "month", "status"})


def _parse_iso_datetime(value: str | None, field: str) -> datetime | None:
    """Parse a client-supplied ISO-8601 string; ``None`` or ``""`` yields ``None``.

    Raises:
        HTTPException: 422 when the text is not a valid ISO-8601 value, instead
            of letting ``ValueError`` surface as a 500.
    """
    if not value:
        return None
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        raise HTTPException(status_code=422, detail=f"{field} 日期格式无效: {value}") from None


def _apply_filters(
    query,
    *,
    station_name: str | None = None,
    billing_type: str | None = None,
    status: str | None = None,
    year: int | None = None,
    month: int | None = None,
):
    """Return *query* (a SQLAlchemy ``Select``) narrowed by the given filters.

    Empty strings are treated as "no filter" for text fields; integer filters
    apply whenever they are not ``None``.
    """
    if station_name:
        query = query.where(BillingRecord.station_name == station_name)
    if billing_type:
        query = query.where(BillingRecord.billing_type == billing_type)
    if status:
        query = query.where(BillingRecord.status == status)
    if year is not None:
        query = query.where(BillingRecord.year == year)
    if month is not None:
        query = query.where(BillingRecord.month == month)
    return query


async def _get_or_404(db: AsyncSession, billing_id: int) -> BillingRecord:
    """Load one BillingRecord by primary key or raise HTTP 404."""
    result = await db.execute(select(BillingRecord).where(BillingRecord.id == billing_id))
    record = result.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="结算记录不存在")
    return record


def _billing_to_dict(b: BillingRecord) -> dict:
    """Serialize a BillingRecord to a JSON-friendly dict (dates as ``str`` or ``None``)."""
    return {
        "id": b.id,
        "station_name": b.station_name,
        "billing_type": b.billing_type,
        "year": b.year,
        "month": b.month,
        "generation_kwh": b.generation_kwh,
        "consumption_kwh": b.consumption_kwh,
        "grid_feed_kwh": b.grid_feed_kwh,
        "unit_price": b.unit_price,
        "total_amount": b.total_amount,
        "status": b.status,
        "invoice_number": b.invoice_number,
        "invoice_date": str(b.invoice_date) if b.invoice_date else None,
        "payment_date": str(b.payment_date) if b.payment_date else None,
        "notes": b.notes,
        "created_by": b.created_by,
        "created_at": str(b.created_at) if b.created_at else None,
        "updated_at": str(b.updated_at) if b.updated_at else None,
    }


# ── Billing CRUD ───────────────────────────────────────────────────

@router.get("")
async def list_billing(
    station_name: str | None = None,
    billing_type: str | None = None,
    status: str | None = None,
    year: int | None = None,
    month: int | None = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return one page of billing records plus the total matching count.

    Records are ordered newest period first (year, month, then id, all desc).
    """
    query = _apply_filters(
        select(BillingRecord),
        station_name=station_name,
        billing_type=billing_type,
        status=status,
        year=year,
        month=month,
    )

    # Count over the filtered (unpaginated) query.
    count_q = select(func.count()).select_from(query.subquery())
    total = (await db.execute(count_q)).scalar()

    query = (
        query.order_by(BillingRecord.year.desc(), BillingRecord.month.desc(), BillingRecord.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(query)
    return {
        "total": total,
        "items": [_billing_to_dict(b) for b in result.scalars().all()],
    }


@router.get("/stats")
async def billing_stats(
    year: int | None = None,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Aggregate generation and amounts overall, per month and per billing type.

    When ``year`` is given, every aggregate is restricted to that year.
    """
    # Both grand totals in a single round-trip; an aggregate without GROUP BY
    # always yields exactly one row (SUM is NULL when no rows match).
    totals_q = _apply_filters(
        select(func.sum(BillingRecord.generation_kwh), func.sum(BillingRecord.total_amount)),
        year=year,
    )
    total_generation, total_amount = (await db.execute(totals_q)).one()

    month_q = _apply_filters(
        select(
            BillingRecord.month,
            func.sum(BillingRecord.generation_kwh).label("generation"),
            func.sum(BillingRecord.total_amount).label("amount"),
        )
        .group_by(BillingRecord.month)
        .order_by(BillingRecord.month),
        year=year,
    )
    month_result = await db.execute(month_q)
    by_month = [
        {"month": row[0], "generation_kwh": float(row[1] or 0), "total_amount": float(row[2] or 0)}
        for row in month_result.all()
    ]

    type_q = _apply_filters(
        select(
            BillingRecord.billing_type,
            func.sum(BillingRecord.total_amount).label("amount"),
        ).group_by(BillingRecord.billing_type),
        year=year,
    )
    type_result = await db.execute(type_q)
    by_type = {row[0]: float(row[1] or 0) for row in type_result.all()}

    return {
        "total_generation_kwh": float(total_generation or 0),
        "total_amount": float(total_amount or 0),
        "by_month": by_month,
        "by_type": by_type,
    }


@router.get("/export")
async def export_billing(
    station_name: str | None = None,
    year: int | None = None,
    month: int | None = None,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return all matching records plus the column list used for export."""
    query = _apply_filters(
        select(BillingRecord),
        station_name=station_name,
        year=year,
        month=month,
    ).order_by(
        BillingRecord.year.desc(),
        BillingRecord.month.desc(),
        # Tie-breaker so export order is deterministic and matches list_billing.
        BillingRecord.id.desc(),
    )
    result = await db.execute(query)
    records = [_billing_to_dict(b) for b in result.scalars().all()]
    return {
        "columns": [
            "station_name", "billing_type", "year", "month",
            "generation_kwh", "consumption_kwh", "grid_feed_kwh",
            "unit_price", "total_amount", "status",
            "invoice_number", "invoice_date", "payment_date",
        ],
        "data": records,
    }


@router.get("/{billing_id}")
async def get_billing(
    billing_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return a single billing record, or 404 if it does not exist."""
    return _billing_to_dict(await _get_or_404(db, billing_id))


@router.post("")
async def create_billing(
    data: BillingCreate,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_roles("admin", "energy_manager")),
):
    """Create a billing record owned by the current user (admin/energy_manager only).

    Raises:
        HTTPException: 422 if ``invoice_date`` / ``payment_date`` is not ISO-8601.
    """
    # Validate dates before building the ORM object so a bad value never
    # leaves a half-initialised record in the session.
    parsed_dates = {field: _parse_iso_datetime(getattr(data, field), field) for field in _DATE_FIELDS}

    record = BillingRecord(
        **data.model_dump(exclude=set(_DATE_FIELDS)),
        created_by=user.id,
    )
    for field, value in parsed_dates.items():
        if value is not None:
            setattr(record, field, value)

    db.add(record)
    await db.flush()
    # Reload DB-generated columns (id and any server defaults such as
    # created_at) so serialization does not hit an expired attribute, which
    # AsyncSession cannot lazy-load implicitly.
    await db.refresh(record)
    return _billing_to_dict(record)


@router.put("/{billing_id}")
async def update_billing(
    billing_id: int,
    data: BillingUpdate,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_roles("admin", "energy_manager")),
):
    """Apply a partial update to a billing record (admin/energy_manager only).

    Only fields present in the request body are touched. Sending ``null`` (or
    ``""``) for a date field clears it; sending ``null`` for a field that is
    mandatory at creation time is rejected.

    Raises:
        HTTPException: 404 if the record does not exist; 422 for a malformed
            date or a null required field.
    """
    record = await _get_or_404(db, billing_id)

    # Resolve and validate everything first so an invalid field cannot leave
    # the record partially modified.
    updates: dict = {}
    for field, value in data.model_dump(exclude_unset=True).items():
        if field in _DATE_FIELDS:
            value = _parse_iso_datetime(value, field)
        elif value is None and field in _REQUIRED_FIELDS:
            raise HTTPException(status_code=422, detail=f"{field} 不能为空")
        updates[field] = value

    for field, value in updates.items():
        setattr(record, field, value)

    await db.flush()
    # Pick up any columns the database recomputes on UPDATE (e.g. updated_at).
    await db.refresh(record)
    return _billing_to_dict(record)