Files
tp-ems/backend/app/api/v1/websocket.py
Du Wenbo d8e4449f10 Squashed 'core/' content from commit 92ec910
git-subtree-dir: core
git-subtree-split: 92ec910a132e379a3a6e442a75bcb07cac0f0010
2026-04-04 18:16:49 +08:00

228 lines
7.3 KiB
Python

"""
WebSocket endpoint for real-time data push.
Provides instant updates to connected clients (BigScreen, dashboards)
instead of relying solely on polling.
"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy import select, func, and_
from app.core.security import decode_access_token
from app.core.database import async_session
from app.models.device import Device
from app.models.energy import EnergyData
from app.models.alarm import AlarmEvent
# Module logger. It uses the explicit name "app.websocket" (not __name__), so
# its records propagate through the "app" logger hierarchy.
logger = logging.getLogger("app.websocket")
# Router holding the WebSocket endpoint defined in this module; the "WebSocket"
# tag groups it in the generated OpenAPI docs.
router = APIRouter(tags=["WebSocket"])
class ConnectionManager:
    """Registry of accepted WebSocket connections with a JSON fan-out helper.

    Connections are kept in a plain list on the instance; no locking is used.
    """

    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and register it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*.

        Safe to call more than once for the same socket (both broadcast() and
        the endpoint may do so); only an actual removal is logged, so a
        repeated call does not emit a duplicate "disconnected" line.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(
                "WebSocket disconnected. Total: %d", len(self.active_connections)
            )

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every connected client, one at a time.

        A client whose send raises is treated as dead and unregistered after
        the fan-out loop finishes.
        """
        disconnected = []
        # Iterate over a snapshot: every send awaits, and connect()/disconnect()
        # can mutate active_connections meanwhile. Mutating the list being
        # iterated would make this loop skip clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                logger.debug("Dropping WebSocket after failed send", exc_info=True)
                disconnected.append(connection)
        for conn in disconnected:
            self.disconnect(conn)
# Module-level connection registry used by the endpoint and broadcast helpers.
manager = ConnectionManager()
# Handle to the background broadcast_loop() task; None until
# start_broadcast_task() runs for the first time.
_broadcast_task: Optional[asyncio.Task] = None
async def _active_device_ids(db, device_type: str) -> set:
    """Return the ids of active devices whose device_type equals *device_type*.

    *db* is presumably the session yielded by async_session() — confirm.
    """
    result = await db.execute(
        select(Device.id).where(
            Device.device_type == device_type,
            # `== True` (not `is True`) is required so SQLAlchemy builds SQL.
            Device.is_active == True,  # noqa: E712
        )
    )
    return {row[0] for row in result.fetchall()}


async def get_realtime_snapshot() -> dict:
    """Fetch the latest realtime power/alarm snapshot from the database.

    Mirrors the logic in dashboard.get_realtime_data:
    - Query recent power data points (last 5 min, newest first, max 50 rows)
    - Aggregate by device type (PV inverters vs heat pumps)
    - Count alarms whose status is 'active'

    Never raises on database errors: the failure is logged with its traceback
    and an all-zero snapshot stamped with the current UTC time is returned.

    Returns:
        dict with keys "pv_power", "heatpump_power", "total_load",
        "grid_power" (each rounded to 1 decimal), "active_alarms" (int) and
        "timestamp" (ISO 8601 string, UTC).
    """
    try:
        async with async_session() as db:
            now = datetime.now(timezone.utc)
            five_min_ago = now - timedelta(minutes=5)

            # Newest "power" samples from the last 5 minutes, capped at 50 rows.
            latest_q = await db.execute(
                select(EnergyData).where(
                    and_(
                        EnergyData.timestamp >= five_min_ago,
                        EnergyData.data_type == "power",
                    )
                ).order_by(EnergyData.timestamp.desc()).limit(50)
            )
            data_points = latest_q.scalars().all()

            pv_ids = await _active_device_ids(db, "pv_inverter")
            hp_ids = await _active_device_ids(db, "heat_pump")

            # NOTE(review): every row in the window is summed, so a device that
            # reported several times in the last 5 minutes contributes several
            # samples, and limit(50) can drop devices when many are reporting.
            # If the intent is "current power", take the latest row per device
            # instead — confirm against dashboard.get_realtime_data first,
            # since this function is meant to mirror it.
            pv_power = sum(d.value for d in data_points if d.device_id in pv_ids)
            heatpump_power = sum(
                d.value for d in data_points if d.device_id in hp_ids
            )
            total_load = pv_power + heatpump_power
            # Heat-pump power in excess of PV power, clamped at 0.
            grid_power = max(0, heatpump_power - pv_power)

            alarm_count_q = await db.execute(
                select(func.count(AlarmEvent.id)).where(
                    AlarmEvent.status == "active"
                )
            )
            active_alarms = alarm_count_q.scalar() or 0

            return {
                "pv_power": round(pv_power, 1),
                "heatpump_power": round(heatpump_power, 1),
                "total_load": round(total_load, 1),
                "grid_power": round(grid_power, 1),
                "active_alarms": active_alarms,
                "timestamp": now.isoformat(),
            }
    except Exception as e:
        # Best-effort by design (callers forward the result to clients), but
        # keep the traceback so DB failures remain diagnosable.
        logger.exception("Error fetching realtime snapshot: %s", e)
        return {
            "pv_power": 0,
            "heatpump_power": 0,
            "total_load": 0,
            "grid_power": 0,
            "active_alarms": 0,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
async def broadcast_loop(interval: float = 15.0, error_backoff: float = 5.0):
    """Background task: broadcast a realtime snapshot every *interval* seconds.

    Each cycle sleeps for *interval* seconds and then, only if at least one
    client is connected, fetches a snapshot and broadcasts it as a
    "realtime_update" message. An unexpected error is logged with its
    traceback, and the loop waits an extra *error_backoff* seconds before
    continuing.

    Cancellation (see stop_broadcast_task) propagates, so the task ends in
    the cancelled state rather than finishing normally.

    Args:
        interval: Seconds between broadcasts (default 15).
        error_backoff: Extra seconds to wait after an error (default 5).
    """
    while True:
        try:
            await asyncio.sleep(interval)
            if manager.active_connections:
                data = await get_realtime_snapshot()
                await manager.broadcast({
                    "type": "realtime_update",
                    "data": data,
                })
        except asyncio.CancelledError:
            # Re-raise instead of `break` so Task.cancelled() is True and
            # awaiting the task raises CancelledError, as asyncio expects.
            raise
        except Exception as e:
            logger.exception("Broadcast loop error: %s", e)
            await asyncio.sleep(error_backoff)
async def broadcast_alarm_event(alarm_data: dict):
    """Push one alarm to every connected client as an "alarm_event" message.

    Meant to be called externally when a new alarm is triggered. Returns
    immediately when no client is connected.
    """
    if not manager.active_connections:
        return
    envelope = {"type": "alarm_event", "data": alarm_data}
    await manager.broadcast(envelope)
def start_broadcast_task():
    """Launch broadcast_loop() in the background unless it is already running.

    Call during app startup (the realtime endpoint also calls it per
    connection). A task that is still running is left alone; a missing or
    finished one is replaced. Uses asyncio.create_task, so it must be called
    while an event loop is running.
    """
    global _broadcast_task
    still_running = _broadcast_task is not None and not _broadcast_task.done()
    if still_running:
        return
    _broadcast_task = asyncio.create_task(broadcast_loop())
    logger.info("WebSocket broadcast task started")
def stop_broadcast_task():
    """Request cancellation of the background broadcast loop, if it is running.

    Call during app shutdown. This only calls Task.cancel(); it does not await
    the task, so the loop may still be unwinding when this returns.
    """
    global _broadcast_task
    task = _broadcast_task
    if task is None or task.done():
        return
    task.cancel()
    logger.info("WebSocket broadcast task stopped")
@router.websocket("/ws/realtime")
async def websocket_realtime(
    websocket: WebSocket,
    token: str = Query(default=""),
):
    """
    WebSocket endpoint for real-time energy data.

    Connect with: ws://host/api/v1/ws/realtime?token=<jwt_token>

    Messages sent to clients:
    - type: "realtime_update" - sent once on connect, then periodically
      (every 15s) by the background broadcast loop
    - type: "alarm_event" - when a new alarm triggers
    - type: "pong" - reply to a client text message "ping"

    The connection is closed with code 4001 when the token is missing or
    decode_access_token() returns None. Once accepted, the socket is always
    unregistered from the manager when this handler exits, however it exits.
    """
    # Authenticate before accepting the handshake.
    # NOTE(review): close() before accept() rejects the handshake; with
    # Starlette/uvicorn the client likely sees an HTTP 403 rather than close
    # code 4001 — confirm if the frontend relies on 4001.
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return
    payload = decode_access_token(token)
    if payload is None:
        await websocket.close(code=4001, reason="Invalid token")
        return

    await manager.connect(websocket)
    try:
        # Ensure broadcast task is running
        start_broadcast_task()

        # Send initial data immediately; a failure here is logged but does not
        # drop the connection, except when the client has already gone away.
        try:
            initial_data = await get_realtime_snapshot()
            await websocket.send_json({
                "type": "realtime_update",
                "data": initial_data,
            })
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.exception("Error sending initial data: %s", e)

        # Keep the connection open; the only client message acted on is "ping".
        while True:
            message = await websocket.receive_text()
            if message == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        # Normal client disconnect.
        pass
    except Exception:
        # Previously swallowed silently; log it so real failures are visible.
        logger.exception("Unexpected error on realtime WebSocket")
    finally:
        # Runs on every exit path, including task cancellation (a BaseException
        # the except clauses above do not catch), so a socket is never left
        # registered in manager.active_connections.
        manager.disconnect(websocket)