"""AI预测引擎 API - 光伏/负荷/热泵预测 & 自发自用优化""" from datetime import datetime, timezone, timedelta from typing import Optional from fastapi import APIRouter, Depends, Query, HTTPException from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_ from app.core.database import get_db from app.core.deps import get_current_user, require_roles from app.models.user import User from app.models.prediction import PredictionTask, PredictionResult, OptimizationSchedule from app.services.ai_prediction import ( forecast_pv, forecast_load, forecast_heatpump_cop, optimize_self_consumption, get_prediction_accuracy, run_prediction, ) router = APIRouter(prefix="/prediction", tags=["AI预测"]) # ── Schemas ──────────────────────────────────────────────────────────── class RunPredictionRequest(BaseModel): device_id: Optional[int] = None prediction_type: str # pv, load, heatpump, optimization horizon_hours: int = 24 parameters: Optional[dict] = None # ── Endpoints ────────────────────────────────────────────────────────── @router.get("/forecast") async def get_forecast( device_id: Optional[int] = None, type: str = Query("pv", pattern="^(pv|load|heatpump)$"), horizon: int = Query(24, ge=1, le=168), building_type: str = Query("office", pattern="^(office|factory)$"), db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user), ): """获取预测结果 - PV发电/负荷/热泵COP""" if type == "pv": if not device_id: raise HTTPException(400, "光伏预测需要指定device_id") return await forecast_pv(db, device_id, horizon) elif type == "load": return await forecast_load(db, device_id, building_type, horizon) elif type == "heatpump": if not device_id: raise HTTPException(400, "热泵预测需要指定device_id") return await forecast_heatpump_cop(db, device_id, horizon) @router.post("/run") async def trigger_prediction( req: RunPredictionRequest, db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user), ): """触发新的预测任务""" task = await run_prediction( db, req.device_id, req.prediction_type, req.horizon_hours, req.parameters, ) return { "task_id": task.id, "status": task.status, "prediction_type": task.prediction_type, "horizon_hours": task.horizon_hours, "error_message": task.error_message, } @router.get("/accuracy") async def prediction_accuracy( type: Optional[str] = Query(None, pattern="^(pv|load|heatpump|optimization)$"), days: int = Query(7, ge=1, le=90), db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user), ): """获取预测精度指标 (MAE, RMSE, MAPE)""" return await get_prediction_accuracy(db, type, days) @router.get("/optimization") async def get_optimization( horizon: int = Query(24, ge=1, le=72), db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user), ): """获取自发自用优化建议""" return await optimize_self_consumption(db, horizon) @router.post("/optimization/{schedule_id}/approve") async def approve_optimization( schedule_id: int, db: AsyncSession = Depends(get_db), user: User = Depends(require_roles("admin", "energy_manager")), ): """审批优化调度方案""" result = await db.execute( select(OptimizationSchedule).where(OptimizationSchedule.id == schedule_id) ) schedule = result.scalar_one_or_none() if not schedule: raise HTTPException(404, "优化方案不存在") if schedule.status != "pending": raise HTTPException(400, f"方案状态为 {schedule.status},无法审批") schedule.status = "approved" schedule.approved_by = user.id schedule.approved_at = datetime.now(timezone.utc) return {"id": schedule.id, "status": "approved"} @router.get("/history") async def prediction_history( type: Optional[str] = Query(None), days: int = Query(7, ge=1, le=30), page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user), ): """历史预测任务列表""" cutoff = datetime.now(timezone.utc) - timedelta(days=days) conditions = [PredictionTask.created_at >= cutoff] if type: conditions.append(PredictionTask.prediction_type == type) query = ( select(PredictionTask) .where(and_(*conditions)) .order_by(PredictionTask.created_at.desc()) .offset((page - 1) * page_size) .limit(page_size) ) result = await db.execute(query) tasks = result.scalars().all() return [{ "id": t.id, "device_id": t.device_id, "prediction_type": t.prediction_type, "horizon_hours": t.horizon_hours, "status": t.status, "created_at": str(t.created_at) if t.created_at else None, "completed_at": str(t.completed_at) if t.completed_at else None, "error_message": t.error_message, } for t in tasks] @router.get("/schedules") async def list_schedules( status: Optional[str] = Query(None, pattern="^(pending|approved|executed|rejected)$"), days: int = Query(7, ge=1, le=30), db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user), ): """获取优化调度方案列表""" cutoff = datetime.now(timezone.utc) - timedelta(days=days) conditions = [OptimizationSchedule.created_at >= cutoff] if status: conditions.append(OptimizationSchedule.status == status) result = await db.execute( select(OptimizationSchedule) .where(and_(*conditions)) .order_by(OptimizationSchedule.created_at.desc()) ) schedules = result.scalars().all() return [{ "id": s.id, "device_id": s.device_id, "date": str(s.date) if s.date else None, "expected_savings_kwh": s.expected_savings_kwh, "expected_savings_yuan": s.expected_savings_yuan, "status": s.status, "schedule_data": s.schedule_data, "created_at": str(s.created_at) if s.created_at else None, } for s in schedules]