"""
WebSocket endpoint for real-time data push.

Provides instant updates to connected clients (BigScreen, dashboards)
instead of relying solely on polling.
"""

import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy import select, func, and_

from app.core.security import decode_access_token
from app.core.database import async_session
from app.models.device import Device
from app.models.energy import EnergyData
from app.models.alarm import AlarmEvent

logger = logging.getLogger("app.websocket")

router = APIRouter(tags=["WebSocket"])

# Seconds between periodic "realtime_update" pushes.
BROADCAST_INTERVAL_SECONDS = 15
# Pause after an unexpected error in the broadcast loop so a persistent
# failure cannot spin the loop.
BROADCAST_ERROR_BACKOFF_SECONDS = 5


class ConnectionManager:
    """Track accepted WebSocket connections and fan messages out to them."""

    def __init__(self):
        # Accepted client sockets that have not been deregistered yet.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register *websocket* for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Deregister *websocket*; calling it more than once is harmless."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(
                "WebSocket disconnected. Total: %d", len(self.active_connections)
            )

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every client, removing dead connections.

        A client whose send raises is dropped from ``active_connections``;
        the error is not propagated to the caller.
        """
        disconnected = []
        # Iterate over a snapshot: each send awaits, and while this coroutine
        # is suspended other coroutines may connect/disconnect and mutate the
        # live list, which would otherwise make this loop skip clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                logger.debug("Dropping WebSocket after failed send", exc_info=True)
                disconnected.append(connection)
        for conn in disconnected:
            self.disconnect(conn)


manager = ConnectionManager()

# Background task reference (set by start_broadcast_task).
_broadcast_task: Optional[asyncio.Task] = None


async def get_realtime_snapshot() -> dict:
    """Fetch latest realtime data from the database.

    Mirrors the logic in dashboard.get_realtime_data:
    - Query recent power data points (last 5 min)
    - Aggregate by device type (PV inverters vs heat pumps)

    Returns:
        A dict with ``pv_power``, ``heatpump_power``, ``total_load`` and
        ``grid_power`` (rounded to one decimal), ``active_alarms`` (count of
        alarms with status 'active') and an ISO-8601 UTC ``timestamp``.
        On any error the failure is logged and a zero-valued snapshot is
        returned instead of raising.
    """
    try:
        async with async_session() as db:
            now = datetime.now(timezone.utc)
            five_min_ago = now - timedelta(minutes=5)

            # Recent power samples, newest first, capped at 50 rows.
            latest_q = await db.execute(
                select(EnergyData)
                .where(
                    and_(
                        EnergyData.timestamp >= five_min_ago,
                        EnergyData.data_type == "power",
                    )
                )
                .order_by(EnergyData.timestamp.desc())
                .limit(50)
            )
            data_points = latest_q.scalars().all()

            # IDs of active devices per type. `== True` (not `is True`) is
            # needed so SQLAlchemy builds a SQL comparison expression.
            pv_q = await db.execute(
                select(Device.id).where(
                    Device.device_type == "pv_inverter",
                    Device.is_active == True,  # noqa: E712
                )
            )
            pv_ids = {r[0] for r in pv_q.fetchall()}

            hp_q = await db.execute(
                select(Device.id).where(
                    Device.device_type == "heat_pump",
                    Device.is_active == True,  # noqa: E712
                )
            )
            hp_ids = {r[0] for r in hp_q.fetchall()}

            # NOTE(review): every sample in the 5-minute window is summed, so
            # a device that reported more than once is counted several times,
            # and the 50-row cap can drop devices entirely. Kept as-is to stay
            # consistent with dashboard.get_realtime_data — confirm intent.
            pv_power = sum(d.value for d in data_points if d.device_id in pv_ids)
            heatpump_power = sum(
                d.value for d in data_points if d.device_id in hp_ids
            )
            total_load = pv_power + heatpump_power
            # Heat-pump demand not covered by PV, clamped at zero.
            grid_power = max(0, heatpump_power - pv_power)

            # Count active alarms
            alarm_count_q = await db.execute(
                select(func.count(AlarmEvent.id)).where(
                    AlarmEvent.status == 'active'
                )
            )
            active_alarms = alarm_count_q.scalar() or 0

            return {
                "pv_power": round(pv_power, 1),
                "heatpump_power": round(heatpump_power, 1),
                "total_load": round(total_load, 1),
                "grid_power": round(grid_power, 1),
                "active_alarms": active_alarms,
                "timestamp": now.isoformat(),
            }
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("Error fetching realtime snapshot: %s", e)
        return {
            "pv_power": 0,
            "heatpump_power": 0,
            "total_load": 0,
            "grid_power": 0,
            "active_alarms": 0,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }


async def broadcast_loop():
    """Background task: broadcast realtime data every 15 seconds.

    Runs until cancelled (see ``stop_broadcast_task``). Unexpected errors
    are logged and followed by a short back-off before the next iteration.
    """
    while True:
        try:
            await asyncio.sleep(BROADCAST_INTERVAL_SECONDS)
            # Skip the database round-trip when nobody is listening.
            if manager.active_connections:
                data = await get_realtime_snapshot()
                await manager.broadcast({
                    "type": "realtime_update",
                    "data": data,
                })
        except asyncio.CancelledError:
            # Propagate cancellation so the task is reported as cancelled.
            raise
        except Exception as e:
            logger.exception("Broadcast loop error: %s", e)
            await asyncio.sleep(BROADCAST_ERROR_BACKOFF_SECONDS)


async def broadcast_alarm_event(alarm_data: dict):
    """Push an ``alarm_event`` message to all clients.

    Called externally when a new alarm is triggered; does nothing when no
    clients are connected.
    """
    if manager.active_connections:
        await manager.broadcast({
            "type": "alarm_event",
            "data": alarm_data,
        })


def start_broadcast_task():
    """Start the background broadcast loop. Call during app startup.

    Idempotent: a new task is only created when none exists or the previous
    one has finished. Must be called with a running event loop.
    """
    global _broadcast_task
    if _broadcast_task is None or _broadcast_task.done():
        _broadcast_task = asyncio.create_task(broadcast_loop())
        logger.info("WebSocket broadcast task started")


def stop_broadcast_task():
    """Stop the background broadcast loop. Call during app shutdown.

    Only requests cancellation; it does not wait for the task to finish.
    """
    global _broadcast_task
    if _broadcast_task and not _broadcast_task.done():
        _broadcast_task.cancel()
        logger.info("WebSocket broadcast task stopped")


@router.websocket("/ws/realtime")
async def websocket_realtime(
    websocket: WebSocket,
    token: str = Query(default=""),
):
    """
    WebSocket endpoint for real-time energy data.

    Connect with: ws://host/api/v1/ws/realtime?token=<access_token>

    Messages sent to clients:
    - type: "realtime_update" - periodic snapshot every 15s
    - type: "alarm_event" - when a new alarm triggers

    A client may send the text "ping" and receives {"type": "pong"}.
    Requests with a missing or invalid token are closed with code 4001
    before the connection is accepted.
    """
    # Authenticate before accepting so unauthenticated clients never join.
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return

    payload = decode_access_token(token)
    if payload is None:
        await websocket.close(code=4001, reason="Invalid token")
        return

    await manager.connect(websocket)
    # From here on, always deregister the socket, whatever happens.
    try:
        # Ensure broadcast task is running
        start_broadcast_task()

        # Send initial data immediately rather than waiting for the next tick.
        try:
            initial_data = await get_realtime_snapshot()
            await websocket.send_json({
                "type": "realtime_update",
                "data": initial_data,
            })
        except Exception as e:
            logger.error("Error sending initial data: %s", e)

        # Keep the connection open and answer client keep-alive pings.
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass  # Normal client disconnect.
    except Exception:
        logger.exception("Unexpected error on realtime WebSocket")
    finally:
        manager.disconnect(websocket)