228 lines
7.3 KiB
Python
228 lines
7.3 KiB
Python
|
|
"""
|
||
|
|
WebSocket endpoint for real-time data push.
|
||
|
|
|
||
|
|
Provides instant updates to connected clients (BigScreen, dashboards)
|
||
|
|
instead of relying solely on polling.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import logging
|
||
|
|
from datetime import datetime, timedelta, timezone
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||
|
|
from sqlalchemy import select, func, and_
|
||
|
|
|
||
|
|
from app.core.security import decode_access_token
|
||
|
|
from app.core.database import async_session
|
||
|
|
from app.models.device import Device
|
||
|
|
from app.models.energy import EnergyData
|
||
|
|
from app.models.alarm import AlarmEvent
|
||
|
|
|
||
|
|
# Router that exposes this module's WebSocket endpoint(s).
router = APIRouter(tags=["WebSocket"])

# Named logger for this module's WebSocket activity.
logger = logging.getLogger("app.websocket")
|
||
|
|
|
||
|
|
|
||
|
|
class ConnectionManager:
    """Track active WebSocket connections and fan messages out to them."""

    def __init__(self):
        # Accepted sockets that should receive broadcasts.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"WebSocket connected. Total: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*.

        Safe to call more than once for the same socket (e.g. once from
        ``broadcast`` after a failed send and again from the endpoint's
        cleanup); only the call that actually removes it logs.
        """
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)
        logger.info(f"WebSocket disconnected. Total: {len(self.active_connections)}")

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked client.

        Clients whose send raises are treated as dead and removed from the
        registry after the fan-out completes.
        """
        disconnected = []
        # Iterate over a snapshot: each ``send_json`` yields to the event loop,
        # and connect()/disconnect() running in other coroutines may mutate the
        # list meanwhile. Removing an earlier element from a list that is being
        # iterated shifts the remaining items left and silently skips a client.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as exc:
                logger.debug("Dropping WebSocket after failed send: %s", exc)
                disconnected.append(connection)
        for conn in disconnected:
            self.disconnect(conn)
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level registry shared by the endpoint and the broadcast helpers.
manager: ConnectionManager = ConnectionManager()

# Handle to the periodic broadcast task; stays None until
# start_broadcast_task() creates it.
_broadcast_task: Optional[asyncio.Task] = None
|
||
|
|
|
||
|
|
|
||
|
|
async def get_realtime_snapshot() -> dict:
    """Fetch the latest realtime power/alarm figures from the database.

    Mirrors the logic in dashboard.get_realtime_data:

    - query up to 50 of the most recent ``"power"`` data points from the
      last 5 minutes;
    - sum them by device type (active PV inverters vs. active heat pumps);
    - count alarm events whose status is ``'active'``.

    Returns:
        A dict with ``pv_power``, ``heatpump_power``, ``total_load`` and
        ``grid_power`` (each rounded to one decimal), ``active_alarms`` and
        an ISO-8601 UTC ``timestamp``. Never raises: on any error the failure
        is logged with its traceback and an all-zero snapshot is returned.
    """
    try:
        async with async_session() as db:
            now = datetime.now(timezone.utc)
            five_min_ago = now - timedelta(minutes=5)

            # Most recent power samples (all devices) inside the window.
            latest_q = await db.execute(
                select(EnergyData).where(
                    and_(
                        EnergyData.timestamp >= five_min_ago,
                        EnergyData.data_type == "power",
                    )
                ).order_by(EnergyData.timestamp.desc()).limit(50)
            )
            data_points = latest_q.scalars().all()

            # IDs of active PV inverters.
            pv_q = await db.execute(
                select(Device.id).where(
                    Device.device_type == "pv_inverter",
                    Device.is_active == True,  # noqa: E712 - SQL expression, not a Python bool test
                )
            )
            pv_ids = {r[0] for r in pv_q.fetchall()}

            # IDs of active heat pumps.
            hp_q = await db.execute(
                select(Device.id).where(
                    Device.device_type == "heat_pump",
                    Device.is_active == True,  # noqa: E712 - SQL expression, not a Python bool test
                )
            )
            hp_ids = {r[0] for r in hp_q.fetchall()}

            # NOTE(review): every sample in the window is summed, so a device
            # that reported several times in the last 5 minutes contributes
            # more than once. Kept as-is to stay consistent with the dashboard
            # endpoint this mirrors; confirm whether "latest sample per device"
            # is the intended semantics.
            # NULL readings are skipped so a single missing value cannot raise
            # a TypeError and zero out the whole snapshot.
            pv_power = sum(
                d.value for d in data_points
                if d.device_id in pv_ids and d.value is not None
            )
            heatpump_power = sum(
                d.value for d in data_points
                if d.device_id in hp_ids and d.value is not None
            )
            total_load = pv_power + heatpump_power
            # Only the heat-pump demand not covered by PV is drawn from the grid.
            grid_power = max(0, heatpump_power - pv_power)

            # Count active alarms
            alarm_count_q = await db.execute(
                select(func.count(AlarmEvent.id)).where(
                    AlarmEvent.status == 'active'
                )
            )
            active_alarms = alarm_count_q.scalar() or 0

            return {
                "pv_power": round(pv_power, 1),
                "heatpump_power": round(heatpump_power, 1),
                "total_load": round(total_load, 1),
                "grid_power": round(grid_power, 1),
                "active_alarms": active_alarms,
                "timestamp": now.isoformat(),
            }
    except Exception as e:
        # Include the traceback: callers only ever see the zeroed fallback,
        # so the log is the sole place a DB/ORM failure becomes visible.
        logger.exception("Error fetching realtime snapshot: %s", e)
        return {
            "pv_power": 0,
            "heatpump_power": 0,
            "total_load": 0,
            "grid_power": 0,
            "active_alarms": 0,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
|
||
|
|
|
||
|
|
|
||
|
|
async def broadcast_loop(interval: float = 15.0, error_backoff: float = 5.0) -> None:
    """Background task: broadcast a realtime snapshot every *interval* seconds.

    Args:
        interval: Seconds to wait between broadcasts (default 15).
        error_backoff: Extra seconds to wait after a failed iteration before
            resuming (default 5).

    The loop runs until the task is cancelled. Cancellation is propagated
    rather than swallowed, so the task ends in the *cancelled* state as
    asyncio expects.
    """
    while True:
        await asyncio.sleep(interval)
        # Skip the DB query entirely when nobody is listening.
        if not manager.active_connections:
            continue
        try:
            data = await get_realtime_snapshot()
            await manager.broadcast({
                "type": "realtime_update",
                "data": data,
            })
        except asyncio.CancelledError:
            # Explicit for Python < 3.8, where CancelledError subclasses
            # Exception and would otherwise be caught and logged below.
            raise
        except Exception as e:
            logger.exception("Broadcast loop error: %s", e)
            await asyncio.sleep(error_backoff)
|
||
|
|
|
||
|
|
|
||
|
|
async def broadcast_alarm_event(alarm_data: dict):
    """Push a newly triggered alarm to every connected client.

    Meant to be called from outside this module when an alarm fires. It is a
    no-op when no client is connected.
    """
    if not manager.active_connections:
        return
    envelope = {"type": "alarm_event", "data": alarm_data}
    await manager.broadcast(envelope)
|
||
|
|
|
||
|
|
|
||
|
|
def start_broadcast_task():
    """Launch the periodic broadcast loop unless one is already running.

    Call during app startup (or from a request handler). It relies on
    ``asyncio.create_task``, so an event loop must be running.
    """
    global _broadcast_task
    already_running = _broadcast_task is not None and not _broadcast_task.done()
    if already_running:
        return
    _broadcast_task = asyncio.create_task(broadcast_loop())
    logger.info("WebSocket broadcast task started")
|
||
|
|
|
||
|
|
|
||
|
|
def stop_broadcast_task():
    """Request cancellation of the background broadcast loop.

    Call during app shutdown. ``Task.cancel()`` only schedules cancellation:
    the task finishes at its next await point. The module-level handle is
    therefore cleared immediately. Otherwise a ``start_broadcast_task()``
    issued before the cancellation lands would see the dying task as
    "not done", start nothing, and leave no broadcaster running.
    """
    global _broadcast_task
    task, _broadcast_task = _broadcast_task, None
    if task is not None and not task.done():
        task.cancel()
        logger.info("WebSocket broadcast task stopped")
|
||
|
|
|
||
|
|
|
||
|
|
@router.websocket("/ws/realtime")
async def websocket_realtime(
    websocket: WebSocket,
    token: str = Query(default=""),
):
    """
    WebSocket endpoint for real-time energy data.

    Connect with: ws://host/api/v1/ws/realtime?token=<jwt_token>

    Messages sent to clients:
    - type: "realtime_update" - periodic snapshot every 15s
    - type: "alarm_event" - when a new alarm triggers

    A client may send the text frame "ping"; the server replies {"type": "pong"}.

    Authentication failures close the socket with application code 4001.
    The handshake is accepted first because, per the ASGI spec, a close sent
    before accept is turned into an HTTP 403 rejection. The 4001 code and
    reason would then never reach the client.
    """
    # Authenticate
    if not token:
        await websocket.accept()
        await websocket.close(code=4001, reason="Missing token")
        return

    payload = decode_access_token(token)
    if payload is None:
        await websocket.accept()
        await websocket.close(code=4001, reason="Invalid token")
        return

    await manager.connect(websocket)
    try:
        # Ensure broadcast task is running
        start_broadcast_task()

        # Send initial data immediately so the client need not wait for the
        # first periodic tick.
        try:
            initial_data = await get_realtime_snapshot()
            await websocket.send_json({
                "type": "realtime_update",
                "data": initial_data,
            })
        except Exception as e:
            logger.error(f"Error sending initial data: {e}")

        # Keep the connection open and answer keep-alive pings; other client
        # messages are read and ignored.
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.warning("WebSocket connection error, closing: %s", e)
    finally:
        # Runs on every exit path, including task cancellation at shutdown
        # (a BaseException), so the socket never lingers in the registry.
        manager.disconnect(websocket)
|