Files
tp-ems/backend/app/api/v1/prediction.py
Du Wenbo d8e4449f10 Squashed 'core/' content from commit 92ec910
git-subtree-dir: core
git-subtree-split: 92ec910a132e379a3a6e442a75bcb07cac0f0010
2026-04-04 18:16:49 +08:00

186 lines
6.4 KiB
Python

"""AI预测引擎 API - 光伏/负荷/热泵预测 & 自发自用优化"""
from datetime import datetime, timezone, timedelta
from typing import Literal, Optional

from fastapi import APIRouter, Depends, Query, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.core.deps import get_current_user, require_roles
from app.models.user import User
from app.models.prediction import PredictionTask, PredictionResult, OptimizationSchedule
from app.services.ai_prediction import (
    forecast_pv, forecast_load, forecast_heatpump_cop,
    optimize_self_consumption, get_prediction_accuracy, run_prediction,
)
# Router for the AI prediction endpoints; every path registered on it is
# prefixed with "/prediction" and grouped under one OpenAPI tag.
router = APIRouter(tags=["AI预测"], prefix="/prediction")
# ── Schemas ────────────────────────────────────────────────────────────
class RunPredictionRequest(BaseModel):
    """Request body for ``POST /prediction/run``.

    ``prediction_type`` is limited to the kinds the ``/accuracy`` endpoint
    accepts, and ``horizon_hours`` to the 1-168 hour window that
    ``/forecast`` allows. Malformed requests are rejected with a validation
    error instead of being handed to ``run_prediction``.
    """

    # Device the prediction targets; None is forwarded unchanged.
    device_id: Optional[int] = None
    # Which prediction to run.
    prediction_type: Literal["pv", "load", "heatpump", "optimization"]
    # Forecast window in hours; capped at one week like GET /forecast.
    horizon_hours: int = Field(24, ge=1, le=168)
    # Free-form options forwarded untouched to run_prediction.
    parameters: Optional[dict] = None
# ── Endpoints ──────────────────────────────────────────────────────────
@router.get("/forecast")
async def get_forecast(
    device_id: Optional[int] = None,
    type: str = Query("pv", pattern="^(pv|load|heatpump)$"),
    horizon: int = Query(24, ge=1, le=168),
    building_type: str = Query("office", pattern="^(office|factory)$"),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return a PV generation, load, or heat-pump COP forecast.

    Args:
        device_id: Target device. Required for ``pv`` and ``heatpump``;
            optional (forwarded as-is) for ``load``.
        type: Forecast kind: ``pv``, ``load`` or ``heatpump`` (enforced by
            the query pattern). Named ``type`` because it is the public
            query-parameter name.
        horizon: Forecast window in hours, 1-168.
        building_type: ``office`` or ``factory``; only used by the load
            forecast.
        db: Database session from ``get_db``.
        user: Resolved via ``get_current_user``; not otherwise used here.

    Returns:
        Whatever the matching ``forecast_*`` service function returns.

    Raises:
        HTTPException: 400 when a device-specific forecast (``pv`` or
            ``heatpump``) is requested without ``device_id``.
    """
    if type == "pv":
        # Compare with None, not truthiness, so a device_id of 0 is treated
        # as a real id rather than as "missing".
        if device_id is None:
            raise HTTPException(400, "光伏预测需要指定device_id")
        return await forecast_pv(db, device_id, horizon)
    elif type == "load":
        return await forecast_load(db, device_id, building_type, horizon)
    elif type == "heatpump":
        if device_id is None:
            raise HTTPException(400, "热泵预测需要指定device_id")
        return await forecast_heatpump_cop(db, device_id, horizon)
    # Not reachable over HTTP: the `type` pattern admits only the values above.
@router.post("/run")
async def trigger_prediction(
    req: RunPredictionRequest,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Start a new prediction task and report the task's resulting state.

    The request fields are passed positionally to ``run_prediction``. The
    response echoes the returned task's id, status, prediction type,
    horizon and any error message recorded on it.
    """
    new_task = await run_prediction(
        db,
        req.device_id,
        req.prediction_type,
        req.horizon_hours,
        req.parameters,
    )
    summary = dict(
        task_id=new_task.id,
        status=new_task.status,
        prediction_type=new_task.prediction_type,
        horizon_hours=new_task.horizon_hours,
        error_message=new_task.error_message,
    )
    return summary
@router.get("/accuracy")
async def prediction_accuracy(
    type: Optional[str] = Query(None, pattern="^(pv|load|heatpump|optimization)$"),
    days: int = Query(7, ge=1, le=90),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report prediction accuracy metrics (MAE, RMSE, MAPE).

    ``type`` (one of pv/load/heatpump/optimization, or None) and ``days``
    (1-90) are forwarded unchanged to ``get_prediction_accuracy``;
    ``days`` is presumably the look-back window.
    """
    metrics = await get_prediction_accuracy(db, type, days)
    return metrics
@router.get("/optimization")
async def get_optimization(
    horizon: int = Query(24, ge=1, le=72),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return self-consumption optimisation suggestions.

    Delegates to ``optimize_self_consumption`` with a horizon of 1-72 hours.
    """
    suggestion = await optimize_self_consumption(db, horizon)
    return suggestion
@router.post("/optimization/{schedule_id}/approve")
async def approve_optimization(
    schedule_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_roles("admin", "energy_manager")),
):
    """Approve a pending optimisation schedule.

    The caller is resolved through ``require_roles("admin",
    "energy_manager")`` (presumably a role check). On success the schedule's
    status becomes ``approved``, and the approver id and a UTC timestamp are
    recorded on it.

    Raises:
        HTTPException: 404 when no schedule has ``schedule_id``; 400 when the
            schedule is not in the ``pending`` state.
    """
    stmt = select(OptimizationSchedule).where(OptimizationSchedule.id == schedule_id)
    schedule = (await db.execute(stmt)).scalar_one_or_none()
    if not schedule:
        raise HTTPException(404, "优化方案不存在")

    current_status = schedule.status
    if current_status != "pending":
        raise HTTPException(400, f"方案状态为 {current_status},无法审批")

    # NOTE(review): nothing here commits; this presumably relies on get_db
    # committing the session after the request. Confirm, otherwise the
    # approval is never persisted.
    schedule.status = "approved"
    schedule.approved_by = user.id
    schedule.approved_at = datetime.now(timezone.utc)
    return {"id": schedule.id, "status": "approved"}
@router.get("/history")
async def prediction_history(
    type: Optional[str] = Query(None),
    days: int = Query(7, ge=1, le=30),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """List prediction tasks created in the last ``days`` days, newest first.

    Args:
        type: Optional exact match on ``PredictionTask.prediction_type``;
            an empty value applies no filter.
        days: Look-back window in days (1-30), measured from now in UTC.
        page: 1-based page number.
        page_size: Rows per page (1-100).
        db: Database session from ``get_db``.
        user: Resolved via ``get_current_user``; not otherwise used here.

    Returns:
        A list of task summaries. Timestamps are rendered with ``str()``,
        or None when unset.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    conditions = [PredictionTask.created_at >= cutoff]
    if type:
        conditions.append(PredictionTask.prediction_type == type)
    query = (
        select(PredictionTask)
        .where(and_(*conditions))
        # created_at alone is not unique. The id tie-breaker makes the order
        # total, so OFFSET/LIMIT pages neither repeat nor skip rows that
        # share a timestamp.
        .order_by(PredictionTask.created_at.desc(), PredictionTask.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    tasks = (await db.execute(query)).scalars().all()
    return [
        {
            "id": t.id,
            "device_id": t.device_id,
            "prediction_type": t.prediction_type,
            "horizon_hours": t.horizon_hours,
            "status": t.status,
            "created_at": str(t.created_at) if t.created_at else None,
            "completed_at": str(t.completed_at) if t.completed_at else None,
            "error_message": t.error_message,
        }
        for t in tasks
    ]
@router.get("/schedules")
async def list_schedules(
    status: Optional[str] = Query(None, pattern="^(pending|approved|executed|rejected)$"),
    days: int = Query(7, ge=1, le=30),
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """List optimisation schedules created within the last ``days`` days.

    Results are ordered newest first and optionally filtered by ``status``.
    The list is not paginated. Date/time fields are rendered with ``str()``,
    or None when unset.
    """

    def as_text(value):
        # Stringify a date/datetime column while keeping unset values as None.
        return str(value) if value else None

    def serialize(row):
        # Flatten one OptimizationSchedule row into the response shape.
        return {
            "id": row.id,
            "device_id": row.device_id,
            "date": as_text(row.date),
            "expected_savings_kwh": row.expected_savings_kwh,
            "expected_savings_yuan": row.expected_savings_yuan,
            "status": row.status,
            "schedule_data": row.schedule_data,
            "created_at": as_text(row.created_at),
        }

    since = datetime.now(timezone.utc) - timedelta(days=days)
    filters = [OptimizationSchedule.created_at >= since]
    if status:
        filters.append(OptimizationSchedule.status == status)

    stmt = (
        select(OptimizationSchedule)
        .where(and_(*filters))
        .order_by(OptimizationSchedule.created_at.desc())
    )
    rows = (await db.execute(stmt)).scalars().all()
    return [serialize(row) for row in rows]