Squashed 'core/' content from commit 92ec910
git-subtree-dir: core git-subtree-split: 92ec910a132e379a3a6e442a75bcb07cac0f0010
This commit is contained in:
5
backend/app/collectors/__init__.py
Normal file
5
backend/app/collectors/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""IoT data collection framework with protocol-specific collectors."""
|
||||
from app.collectors.base import BaseCollector
|
||||
from app.collectors.manager import CollectorManager, COLLECTOR_REGISTRY
|
||||
|
||||
__all__ = ["BaseCollector", "CollectorManager", "COLLECTOR_REGISTRY"]
|
||||
160
backend/app/collectors/base.py
Normal file
160
backend/app/collectors/base.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Base collector abstract class for IoT data collection."""
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import async_session
|
||||
from app.models.device import Device
|
||||
from app.models.energy import EnergyData
|
||||
|
||||
|
||||
class BaseCollector(ABC):
    """Abstract base class for all protocol collectors.

    Subclasses implement connect/disconnect/collect for one protocol;
    this class owns the asyncio task lifecycle, the reconnect/backoff
    policy, and persistence of collected data points.
    """

    MAX_BACKOFF = 300  # seconds; cap for exponential reconnect backoff (5 minutes)

    def __init__(
        self,
        device_id: int,
        device_code: str,
        connection_params: dict,
        collect_interval: int = 15,
    ):
        self.device_id = device_id
        self.device_code = device_code
        self.connection_params = connection_params or {}
        self.collect_interval = collect_interval
        self.status = "disconnected"  # "disconnected" | "connected" | "error"
        self.last_error: Optional[str] = None
        self.last_collect_time: Optional[datetime] = None
        self._task: Optional[asyncio.Task] = None
        self._running = False
        self._backoff = 1  # current backoff in seconds; doubles on each failure
        self.logger = logging.getLogger(f"collector.{device_code}")

    @abstractmethod
    async def connect(self):
        """Establish connection to the device."""

    @abstractmethod
    async def disconnect(self):
        """Clean up connection resources."""

    @abstractmethod
    async def collect(self) -> dict:
        """Collect data points from the device.

        Returns a dict mapping data_type -> (value, unit), e.g.:
        {"power": (105.3, "kW"), "voltage": (220.1, "V")}
        """

    async def start(self):
        """Start the collector loop (idempotent: a second call is a no-op).

        Fix: the original unconditionally created a new task, so calling
        start() twice spawned two concurrent loops for the same device.
        """
        if self._task is not None and not self._task.done():
            self.logger.warning("Collector already running for %s", self.device_code)
            return
        self._running = True
        self._task = asyncio.create_task(self._run(), name=f"collector-{self.device_code}")
        self.logger.info("Collector started for %s", self.device_code)

    async def stop(self):
        """Stop the collector loop and disconnect."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            # Fix: drop the finished task so a later start() is not rejected
            # by the double-start guard.
            self._task = None
        try:
            await self.disconnect()
        except Exception as e:
            # Best-effort cleanup: a failing disconnect must not block shutdown.
            self.logger.warning("Error during disconnect: %s", e)
        self.status = "disconnected"
        self.logger.info("Collector stopped for %s", self.device_code)

    async def _run(self):
        """Main loop: connect, collect at interval, save to DB."""
        while self._running:
            # Connect phase: (re)establish the device connection if needed.
            if self.status != "connected":
                try:
                    await self.connect()
                    self.status = "connected"
                    self.last_error = None
                    self._backoff = 1  # reset backoff after a successful connect
                    self.logger.info("Connected to %s", self.device_code)
                except Exception as e:
                    self.status = "error"
                    self.last_error = str(e)
                    self.logger.error("Connection failed for %s: %s", self.device_code, e)
                    await self._wait_backoff()
                    continue

            # Collect phase: poll the device and persist any returned points.
            try:
                data = await self.collect()
                if data:
                    await self._save_data(data)
                self.last_collect_time = datetime.now(timezone.utc)
                self._backoff = 1
            except Exception as e:
                self.status = "error"
                self.last_error = str(e)
                self.logger.error("Collect error for %s: %s", self.device_code, e)
                try:
                    await self.disconnect()
                except Exception:
                    pass  # connection is already broken; ignore cleanup errors
                self.status = "disconnected"  # force a reconnect on the next pass
                await self._wait_backoff()
                continue

            await asyncio.sleep(self.collect_interval)

    async def _wait_backoff(self):
        """Sleep for the current backoff, then double it (capped at MAX_BACKOFF)."""
        wait_time = min(self._backoff, self.MAX_BACKOFF)
        self.logger.debug("Backing off %ds for %s", wait_time, self.device_code)
        await asyncio.sleep(wait_time)
        self._backoff = min(self._backoff * 2, self.MAX_BACKOFF)

    async def _save_data(self, data: dict):
        """Persist collected points and mark the device online.

        `data` maps data_type -> (value, unit); all points in one call share
        one capture timestamp and are committed in a single transaction.
        """
        now = datetime.now(timezone.utc)
        async with async_session() as session:
            points = [
                EnergyData(
                    device_id=self.device_id,
                    timestamp=now,
                    data_type=data_type,
                    value=float(value),
                    unit=unit,
                )
                for data_type, (value, unit) in data.items()
            ]
            # Update device status in the same transaction.
            result = await session.execute(
                select(Device).where(Device.id == self.device_id)
            )
            device = result.scalar_one_or_none()
            if device:
                device.status = "online"
                device.last_data_time = now

            session.add_all(points)
            await session.commit()
        self.logger.debug("Saved %d points for %s", len(points), self.device_code)

    def get_status(self) -> dict:
        """Return a JSON-serializable snapshot of collector state."""
        return {
            "device_id": self.device_id,
            "device_code": self.device_code,
            "status": self.status,
            "last_error": self.last_error,
            "last_collect_time": self.last_collect_time.isoformat() if self.last_collect_time else None,
            "collect_interval": self.collect_interval,
        }
|
||||
107
backend/app/collectors/http_collector.py
Normal file
107
backend/app/collectors/http_collector.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""HTTP API protocol collector."""
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class HttpCollector(BaseCollector):
    """Collect data by polling HTTP API endpoints.

    connection_params example:
    {
        "url": "http://api.example.com/device/data",
        "method": "GET",
        "headers": {"X-API-Key": "abc123"},
        "auth": {"type": "basic", "username": "user", "password": "pass"},
        "data_mapping": {
            "active_power": {"key": "data.power", "unit": "kW"},
            "voltage": {"key": "data.voltage", "unit": "V"}
        },
        "timeout": 10
    }
    """

    def __init__(self, device_id, device_code, connection_params, collect_interval=15):
        super().__init__(device_id, device_code, connection_params, collect_interval)
        self._url = connection_params.get("url", "")
        self._method = connection_params.get("method", "GET").upper()
        self._headers = connection_params.get("headers", {})
        self._auth_config = connection_params.get("auth", {})
        self._data_mapping = connection_params.get("data_mapping", {})
        self._timeout = connection_params.get("timeout", 10)
        self._client: Optional[httpx.AsyncClient] = None

    async def connect(self):
        """Create the HTTP client and verify connectivity with one test request.

        Fix: if the verification request fails, close the freshly created
        client before re-raising. The original left it open; since the base
        loop does not call disconnect() on a failed connect, that leaked one
        AsyncClient (and its connection pool) per reconnect attempt.
        """
        auth = None
        auth_type = self._auth_config.get("type", "")
        if auth_type == "basic":
            auth = httpx.BasicAuth(
                self._auth_config.get("username", ""),
                self._auth_config.get("password", ""),
            )

        headers = dict(self._headers)
        if auth_type == "token":
            token = self._auth_config.get("token", "")
            headers["Authorization"] = f"Bearer {token}"

        self._client = httpx.AsyncClient(
            headers=headers,
            auth=auth,
            timeout=self._timeout,
        )
        # Verify connectivity with a test request
        try:
            response = await self._client.request(self._method, self._url)
            response.raise_for_status()
        except Exception:
            await self._client.aclose()
            self._client = None
            raise

    async def disconnect(self):
        """Close the HTTP client if open."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def collect(self) -> dict:
        """Poll the configured endpoint once and parse it into data points."""
        if not self._client:
            raise ConnectionError("HTTP client not initialized")

        response = await self._client.request(self._method, self._url)
        response.raise_for_status()
        payload = response.json()

        return self._parse_response(payload)

    def _parse_response(self, payload: dict) -> dict:
        """Parse HTTP JSON response into data points.

        Supports dotted key paths like "data.power" to navigate nested JSON.
        Without a data_mapping, all numeric top-level fields are collected
        with empty units.
        """
        data = {}
        if self._data_mapping:
            for data_type, mapping in self._data_mapping.items():
                key_path = mapping.get("key", data_type)
                unit = mapping.get("unit", "")
                value = self._resolve_path(payload, key_path)
                if value is not None:
                    try:
                        data[data_type] = (float(value), unit)
                    except (TypeError, ValueError):
                        pass  # non-numeric field; skip it rather than abort the batch
        else:
            # Auto-detect numeric fields at top level
            for key, value in payload.items():
                if isinstance(value, (int, float)):
                    data[key] = (float(value), "")
        return data

    @staticmethod
    def _resolve_path(obj: dict, path: str):
        """Resolve a dotted path like 'data.power' in a nested dict; None if missing."""
        parts = path.split(".")
        current = obj
        for part in parts:
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return None
        return current
|
||||
154
backend/app/collectors/manager.py
Normal file
154
backend/app/collectors/manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Collector Manager - orchestrates all device collectors."""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import async_session
|
||||
from app.models.device import Device
|
||||
from app.collectors.base import BaseCollector
|
||||
from app.collectors.modbus_tcp import ModbusTcpCollector
|
||||
from app.collectors.mqtt_collector import MqttCollector
|
||||
from app.collectors.http_collector import HttpCollector
|
||||
|
||||
logger = logging.getLogger("collector.manager")


# Full registry mapping protocol names to collector classes.
# get_enabled_collectors() filters this down per customer config; keys must
# match Device.protocol values stored in the database.
COLLECTOR_REGISTRY: dict[str, type[BaseCollector]] = {
    "modbus_tcp": ModbusTcpCollector,
    "mqtt": MqttCollector,
    "http_api": HttpCollector,
}
|
||||
|
||||
|
||||
def get_enabled_collectors() -> dict[str, type[BaseCollector]]:
    """Return collector registry filtered by customer config.

    If the customer config specifies a 'collectors' list, only those
    protocols are enabled. Otherwise fall back to the full registry.
    """
    customer_config = get_settings().load_customer_config()
    requested = customer_config.get("collectors")
    if requested is None:
        # No explicit list configured: every known protocol is enabled.
        return COLLECTOR_REGISTRY

    selected: dict[str, type[BaseCollector]] = {}
    for protocol in requested:
        collector_cls = COLLECTOR_REGISTRY.get(protocol)
        if collector_cls is None:
            logger.warning("Customer config references unknown collector '%s', skipping", protocol)
        else:
            selected[protocol] = collector_cls
    return selected
|
||||
|
||||
|
||||
class CollectorManager:
    """Manages lifecycle of all device collectors."""

    def __init__(self):
        self._collectors: dict[int, BaseCollector] = {}  # device_id -> collector
        self._running = False

    async def start(self):
        """Load active devices from DB and start their collectors."""
        self._running = True
        await self._load_and_start_collectors()
        logger.info("CollectorManager started with %d collectors", len(self._collectors))

    async def stop(self):
        """Stop all collectors; a failure stopping one does not block the rest."""
        self._running = False
        for device_id, collector in self._collectors.items():
            try:
                await collector.stop()
            except Exception as e:
                logger.error("Error stopping collector for device %d: %s", device_id, e)
        self._collectors.clear()
        logger.info("CollectorManager stopped")

    async def _load_and_start_collectors(self):
        """Load active devices with supported protocols and start collectors."""
        enabled = get_enabled_collectors()
        logger.info("Enabled collectors: %s", list(enabled.keys()))
        async with async_session() as session:
            result = await session.execute(
                select(Device).where(
                    Device.is_active == True,  # noqa: E712 — SQLAlchemy needs `==` to build SQL
                    Device.protocol.in_(list(enabled.keys())),
                )
            )
            devices = result.scalars().all()

            for device in devices:
                await self.start_collector(
                    device.id,
                    device.code,
                    device.protocol,
                    device.connection_params or {},
                    device.collect_interval or 15,
                )

    async def start_collector(
        self,
        device_id: int,
        device_code: str,
        protocol: str,
        connection_params: dict,
        collect_interval: int,
    ) -> bool:
        """Start a single collector for a device.

        Returns True if a new collector was started, False otherwise.
        """
        if device_id in self._collectors:
            logger.warning("Collector already running for device %d", device_id)
            return False

        collector_cls = COLLECTOR_REGISTRY.get(protocol)
        if not collector_cls:
            logger.warning("No collector for protocol '%s' (device %s)", protocol, device_code)
            return False

        collector = collector_cls(device_id, device_code, connection_params, collect_interval)
        # Fix: register only after start() succeeds. The original registered
        # first, so a failing start() left a dead entry in the map that
        # permanently blocked any later start_collector() for this device.
        try:
            await collector.start()
        except Exception as e:
            logger.error("Failed to start %s collector for %s: %s", protocol, device_code, e)
            return False
        self._collectors[device_id] = collector
        logger.info("Started %s collector for %s", protocol, device_code)
        return True

    async def stop_collector(self, device_id: int) -> bool:
        """Stop collector for a specific device; False if none was running."""
        collector = self._collectors.pop(device_id, None)
        if not collector:
            return False
        await collector.stop()
        return True

    async def restart_collector(self, device_id: int) -> bool:
        """Restart collector for a device by reloading its config from DB.

        Returns False when the device no longer exists or is inactive.
        """
        await self.stop_collector(device_id)
        async with async_session() as session:
            result = await session.execute(
                select(Device).where(Device.id == device_id)
            )
            device = result.scalar_one_or_none()
            if not device or not device.is_active:
                return False
            return await self.start_collector(
                device.id,
                device.code,
                device.protocol,
                device.connection_params or {},
                device.collect_interval or 15,
            )

    def get_collector(self, device_id: int) -> Optional[BaseCollector]:
        """Return the running collector for device_id, or None."""
        return self._collectors.get(device_id)

    def get_all_status(self) -> list[dict]:
        """Return status of all collectors."""
        return [c.get_status() for c in self._collectors.values()]

    @property
    def collector_count(self) -> int:
        # Number of currently registered collectors.
        return len(self._collectors)

    @property
    def is_running(self) -> bool:
        # True between start() and stop().
        return self._running
|
||||
87
backend/app/collectors/modbus_tcp.py
Normal file
87
backend/app/collectors/modbus_tcp.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Modbus TCP protocol collector."""
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
from pymodbus.client import AsyncModbusTcpClient
|
||||
|
||||
from app.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class ModbusTcpCollector(BaseCollector):
    """Collect data from devices via Modbus TCP.

    connection_params example:
    {
        "host": "192.168.1.100",
        "port": 502,
        "slave_id": 1,
        "registers": [
            {"address": 0, "count": 2, "data_type": "active_power", "scale": 0.1, "unit": "kW"},
            {"address": 2, "count": 2, "data_type": "voltage", "scale": 0.1, "unit": "V"}
        ]
    }
    """

    def __init__(self, device_id, device_code, connection_params, collect_interval=15):
        super().__init__(device_id, device_code, connection_params, collect_interval)
        self._client: Optional[AsyncModbusTcpClient] = None
        self._host = connection_params.get("host", "127.0.0.1")
        self._port = connection_params.get("port", 502)
        self._slave_id = connection_params.get("slave_id", 1)
        self._registers = connection_params.get("registers", [])

    async def connect(self):
        """Open the TCP connection; raises ConnectionError when unreachable."""
        self._client = AsyncModbusTcpClient(
            self._host,
            port=self._port,
            timeout=5,
        )
        if not await self._client.connect():
            raise ConnectionError(f"Cannot connect to Modbus TCP {self._host}:{self._port}")

    async def disconnect(self):
        """Close and forget the Modbus client."""
        if self._client:
            self._client.close()
            self._client = None

    async def collect(self) -> dict:
        """Read every configured register block; returns data_type -> (value, unit)."""
        if not self._client or not self._client.connected:
            raise ConnectionError("Modbus client not connected")

        readings = {}
        for spec in self._registers:
            address = spec["address"]
            word_count = spec.get("count", 1)
            data_type = spec["data_type"]

            result = await self._client.read_holding_registers(
                address, count=word_count, slave=self._slave_id
            )
            if result.isError():
                # Log and move on so one bad register does not abort the batch.
                self.logger.warning(
                    "Modbus read error at address %d for %s: %s",
                    address, self.device_code, result,
                )
                continue

            raw_value = self._decode_registers(result.registers, word_count)
            readings[data_type] = (
                round(raw_value * spec.get("scale", 1.0), 4),
                spec.get("unit", ""),
            )

        return readings

    @staticmethod
    def _decode_registers(registers: list, count: int) -> float:
        """Decode raw register words into a numeric value."""
        if count == 2:
            # Two 16-bit registers -> 32-bit IEEE float (big-endian word order).
            packed = struct.pack(">HH", *registers[:2])
            return struct.unpack(">f", packed)[0]
        # Single register, and the fallback for unsupported widths: first word.
        return float(registers[0])
|
||||
117
backend/app/collectors/mqtt_collector.py
Normal file
117
backend/app/collectors/mqtt_collector.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""MQTT protocol collector."""
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import aiomqtt
|
||||
|
||||
from app.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class MqttCollector(BaseCollector):
    """Collect data from devices via MQTT subscription.

    connection_params example:
    {
        "broker": "localhost",
        "port": 1883,
        "topic": "device/INV-001/data",
        "username": "",
        "password": "",
        "data_mapping": {
            "active_power": {"key": "power", "unit": "kW"},
            "voltage": {"key": "voltage", "unit": "V"}
        }
    }
    """

    def __init__(self, device_id, device_code, connection_params, collect_interval=15):
        super().__init__(device_id, device_code, connection_params, collect_interval)
        self._broker = connection_params.get("broker", "localhost")
        self._port = connection_params.get("port", 1883)
        self._topic = connection_params.get("topic", f"device/{device_code}/data")
        # Empty strings are normalized to None so aiomqtt skips authentication.
        self._username = connection_params.get("username", "") or None
        self._password = connection_params.get("password", "") or None
        self._data_mapping = connection_params.get("data_mapping", {})
        self._client: Optional[aiomqtt.Client] = None
        self._latest_data: dict = {}

    async def connect(self):
        # Connection is established in the run loop via context manager
        pass

    async def disconnect(self):
        """Drop the client reference; aiomqtt's context manager closes it."""
        self._client = None

    async def collect(self) -> dict:
        """Return and clear the latest pushed data.

        Note: _run() is overridden below, so the base loop never calls this;
        it is kept to satisfy the BaseCollector interface.
        """
        data = self._latest_data.copy()
        self._latest_data.clear()
        return data

    async def _run(self):
        """Override run loop to use MQTT's push-based model."""
        # Fix: hoisted out of the message loop — the original executed
        # `from datetime import datetime, timezone` once per received message.
        from datetime import datetime, timezone

        while self._running:
            try:
                async with aiomqtt.Client(
                    self._broker,
                    port=self._port,
                    username=self._username,
                    password=self._password,
                ) as client:
                    self._client = client
                    self.status = "connected"
                    self.last_error = None
                    self._backoff = 1
                    self.logger.info("MQTT connected to %s:%d", self._broker, self._port)

                    await client.subscribe(self._topic)
                    self.logger.info("Subscribed to %s", self._topic)

                    async for message in client.messages:
                        if not self._running:
                            break
                        try:
                            payload = json.loads(message.payload.decode())
                            data = self._parse_payload(payload)
                            if data:
                                self._latest_data.update(data)
                                await self._save_data(data)
                                self.last_collect_time = datetime.now(timezone.utc)
                        except (json.JSONDecodeError, ValueError) as e:
                            self.logger.warning("Bad MQTT payload on %s: %s", message.topic, e)

            except aiomqtt.MqttError as e:
                self.status = "error"
                self.last_error = str(e)
                self.logger.error("MQTT error for %s: %s", self.device_code, e)
                await self._wait_backoff()
            except Exception as e:
                self.status = "error"
                self.last_error = str(e)
                self.logger.error("Unexpected MQTT error for %s: %s", self.device_code, e)
                await self._wait_backoff()
            finally:
                # Fix: the original kept a stale reference to the client after
                # the context manager had already closed the connection.
                self._client = None

        self.status = "disconnected"

    def _parse_payload(self, payload: dict) -> dict:
        """Parse MQTT JSON payload into data points.

        If data_mapping is configured, use it. Otherwise, treat all
        numeric top-level keys as data points with empty units.
        """
        data = {}
        if self._data_mapping:
            for data_type, mapping in self._data_mapping.items():
                key = mapping.get("key", data_type)
                unit = mapping.get("unit", "")
                if key in payload:
                    try:
                        data[data_type] = (float(payload[key]), unit)
                    except (TypeError, ValueError):
                        pass  # non-numeric value; skip rather than drop the message
        else:
            for key, value in payload.items():
                if isinstance(value, (int, float)):
                    data[key] = (float(value), "")
        return data
|
||||
185
backend/app/collectors/queue.py
Normal file
185
backend/app/collectors/queue.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Redis Streams-based data ingestion buffer for high-throughput device data."""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.cache import get_redis
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import async_session
|
||||
from app.models.energy import EnergyData
|
||||
|
||||
logger = logging.getLogger("ingestion.queue")


# Redis Stream holding raw device data points awaiting bulk DB insertion.
STREAM_KEY = "ems:ingestion:stream"
# Consumer-group / consumer names used when draining the stream.
# NOTE(review): CONSUMER_NAME is fixed, so running more than one worker
# process would share one consumer identity — confirm single-worker deployment.
CONSUMER_GROUP = "ems:ingestion:workers"
CONSUMER_NAME = "worker-1"
|
||||
|
||||
|
||||
class IngestionQueue:
    """Push device data into a Redis Stream for buffered ingestion."""

    def __init__(self):
        # Fix: remember whether the consumer group exists so consume_batch()
        # issues XGROUP CREATE once per instance instead of on every call
        # (the original sent — and swallowed the error of — an extra Redis
        # command per batch).
        self._group_created = False

    async def push(
        self,
        device_id: int,
        data_type: str,
        value: float,
        unit: str,
        timestamp: Optional[str] = None,
        raw_data: Optional[dict] = None,
    ) -> Optional[str]:
        """Add a data point to the ingestion stream.

        Returns the message ID on success, None on failure (Redis missing
        or errored) — the queue is deliberately best-effort.
        """
        redis = await get_redis()
        if not redis:
            return None
        try:
            fields = {
                "device_id": str(device_id),
                "data_type": data_type,
                "value": str(value),
                "unit": unit,
                "timestamp": timestamp or datetime.now(timezone.utc).isoformat(),
            }
            if raw_data:
                fields["raw_data"] = json.dumps(raw_data, ensure_ascii=False, default=str)
            msg_id = await redis.xadd(STREAM_KEY, fields)
            return msg_id
        except Exception as e:
            logger.error("Failed to push to ingestion stream: %s", e)
            return None

    async def _ensure_group(self, redis) -> None:
        """Create the consumer group once; later calls are no-ops."""
        if self._group_created:
            return
        try:
            await redis.xgroup_create(STREAM_KEY, CONSUMER_GROUP, id="0", mkstream=True)
        except Exception:
            # Group already exists (Redis answers BUSYGROUP) — safe to ignore.
            pass
        self._group_created = True

    async def consume_batch(self, count: int = 100) -> list[tuple[str, dict]]:
        """Read up to `count` messages from the stream via consumer group.

        Returns list of (message_id, fields) tuples; blocks up to 1s when
        the stream is empty.
        """
        redis = await get_redis()
        if not redis:
            return []
        try:
            await self._ensure_group(redis)

            messages = await redis.xreadgroup(
                CONSUMER_GROUP,
                CONSUMER_NAME,
                {STREAM_KEY: ">"},
                count=count,
                block=1000,
            )
            if not messages:
                return []
            # messages format: [(stream_key, [(msg_id, fields), ...])]
            return messages[0][1]
        except Exception as e:
            logger.error("Failed to consume from ingestion stream: %s", e)
            return []

    async def ack(self, message_ids: list[str]) -> int:
        """Acknowledge processed messages.

        Returns number of successfully acknowledged messages (0 on error).
        """
        redis = await get_redis()
        if not redis or not message_ids:
            return 0
        try:
            return await redis.xack(STREAM_KEY, CONSUMER_GROUP, *message_ids)
        except Exception as e:
            logger.error("Failed to ack messages: %s", e)
            return 0
|
||||
|
||||
|
||||
class IngestionWorker:
    """Background worker that drains the ingestion stream and bulk-inserts to DB."""

    def __init__(self, batch_size: int = 100, interval: float = 2.0):
        self.batch_size = batch_size  # max messages consumed per batch
        self.interval = interval      # idle/error sleep in seconds
        self._queue = IngestionQueue()
        self._running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the background ingestion worker (idempotent).

        Fix: the original unconditionally created a new task, so calling
        start() twice spawned two concurrent workers draining the stream.
        """
        if self._task is not None and not self._task.done():
            logger.warning("IngestionWorker already running")
            return
        self._running = True
        self._task = asyncio.create_task(self._run(), name="ingestion-worker")
        logger.info(
            "IngestionWorker started (batch_size=%d, interval=%.1fs)",
            self.batch_size,
            self.interval,
        )

    async def stop(self):
        """Stop the ingestion worker gracefully."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            # Fix: drop the finished task so a later start() is not rejected.
            self._task = None
        logger.info("IngestionWorker stopped.")

    async def _run(self):
        """Main loop: consume batches from stream and insert to DB."""
        while self._running:
            try:
                messages = await self._queue.consume_batch(count=self.batch_size)
                if messages:
                    await self._process_batch(messages)
                else:
                    # Stream empty: idle briefly before polling again.
                    await asyncio.sleep(self.interval)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error("IngestionWorker error: %s", e, exc_info=True)
                await asyncio.sleep(self.interval)

    async def _process_batch(self, messages: list[tuple[str, dict]]):
        """Parse messages and bulk-insert EnergyData rows.

        Malformed messages are logged and skipped, but still acknowledged
        so they are not redelivered forever.
        """
        msg_ids = []
        rows = []
        for msg_id, fields in messages:
            msg_ids.append(msg_id)
            try:
                ts_str = fields.get("timestamp", "")
                timestamp = datetime.fromisoformat(ts_str) if ts_str else datetime.now(timezone.utc)
                raw = None
                if "raw_data" in fields:
                    try:
                        raw = json.loads(fields["raw_data"])
                    except (json.JSONDecodeError, TypeError):
                        raw = None  # raw payload is optional; drop it when unparsable
                rows.append(
                    EnergyData(
                        device_id=int(fields["device_id"]),
                        timestamp=timestamp,
                        data_type=fields["data_type"],
                        value=float(fields["value"]),
                        unit=fields.get("unit", ""),
                        raw_data=raw,
                    )
                )
            except (KeyError, ValueError) as e:
                logger.warning("Skipping malformed message %s: %s", msg_id, e)

        if rows:
            async with async_session() as session:
                session.add_all(rows)
                await session.commit()
            logger.debug("Bulk-inserted %d rows from ingestion stream.", len(rows))

        # Acknowledge all messages (including malformed ones to avoid reprocessing)
        if msg_ids:
            await self._queue.ack(msg_ids)
|
||||
204
backend/app/collectors/sungrow_collector.py
Normal file
204
backend/app/collectors/sungrow_collector.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""阳光电源 iSolarCloud API 数据采集器"""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class SungrowCollector(BaseCollector):
|
||||
"""Collect data from Sungrow inverters via iSolarCloud OpenAPI.
|
||||
|
||||
connection_params example:
|
||||
{
|
||||
"api_base": "https://gateway.isolarcloud.com",
|
||||
"app_key": "1BF313B6A9F919A6FB6A90BD43D23395",
|
||||
"sys_code": "901",
|
||||
"x_access_key": "qpthtsf287zvtmr6t3q9hsc0k70f3tay",
|
||||
"user_account": "13911211695",
|
||||
"user_password": "123456#ABC",
|
||||
"ps_id": "power_station_id",
|
||||
"device_sn": "optional_device_serial"
|
||||
}
|
||||
"""
|
||||
|
||||
TOKEN_LIFETIME = 23 * 3600 # Refresh before 24h expiry
|
||||
|
||||
def __init__(self, device_id, device_code, connection_params, collect_interval=900):
|
||||
super().__init__(device_id, device_code, connection_params, collect_interval)
|
||||
self._api_base = connection_params.get("api_base", "https://gateway.isolarcloud.com").rstrip("/")
|
||||
self._app_key = connection_params.get("app_key", "")
|
||||
self._sys_code = connection_params.get("sys_code", "901")
|
||||
self._x_access_key = connection_params.get("x_access_key", "")
|
||||
self._user_account = connection_params.get("user_account", "")
|
||||
self._user_password = connection_params.get("user_password", "")
|
||||
self._ps_id = connection_params.get("ps_id", "")
|
||||
self._device_sn = connection_params.get("device_sn", "")
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
self._token: Optional[str] = None
|
||||
self._token_obtained_at: float = 0
|
||||
|
||||
async def connect(self):
|
||||
"""Establish HTTP client and authenticate with iSolarCloud."""
|
||||
self._client = httpx.AsyncClient(timeout=30)
|
||||
await self._login()
|
||||
self.logger.info("Authenticated with iSolarCloud for %s", self.device_code)
|
||||
|
||||
async def disconnect(self):
|
||||
"""Close HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
self._token = None
|
||||
|
||||
async def collect(self) -> dict:
|
||||
"""Collect real-time data from the Sungrow inverter.
|
||||
|
||||
Returns a dict mapping data_type -> (value, unit).
|
||||
"""
|
||||
if not self._client:
|
||||
raise ConnectionError("HTTP client not initialized")
|
||||
|
||||
# Refresh token if close to expiry
|
||||
if self._token_needs_refresh():
|
||||
await self._login()
|
||||
|
||||
data = {}
|
||||
|
||||
# Fetch power station overview for power/energy data
|
||||
if self._ps_id:
|
||||
ps_data = await self._get_station_data()
|
||||
if ps_data:
|
||||
data.update(ps_data)
|
||||
|
||||
# Fetch device list for per-device metrics
|
||||
if self._ps_id:
|
||||
dev_data = await self._get_device_data()
|
||||
if dev_data:
|
||||
data.update(dev_data)
|
||||
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal API methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _login(self):
|
||||
"""POST /openapi/login to obtain access token."""
|
||||
payload = {
|
||||
"appkey": self._app_key,
|
||||
"sys_code": self._sys_code,
|
||||
"user_account": self._user_account,
|
||||
"user_password": self._user_password,
|
||||
}
|
||||
result = await self._api_call("/openapi/login", payload, auth=False)
|
||||
|
||||
token = result.get("token")
|
||||
if not token:
|
||||
raise ConnectionError(f"Login failed: {result.get('msg', 'no token returned')}")
|
||||
|
||||
self._token = token
|
||||
self._token_obtained_at = time.monotonic()
|
||||
self.logger.info("iSolarCloud login successful for account %s", self._user_account)
|
||||
|
||||
async def _get_station_data(self) -> dict:
|
||||
"""Fetch power station real-time data."""
|
||||
payload = {"ps_id": self._ps_id}
|
||||
result = await self._api_call("/openapi/getPowerStationList", payload)
|
||||
|
||||
data = {}
|
||||
stations = result.get("pageList", [])
|
||||
for station in stations:
|
||||
if str(station.get("ps_id")) == str(self._ps_id):
|
||||
# Map station-level fields
|
||||
if "curr_power" in station:
|
||||
data["power"] = (float(station["curr_power"]), "kW")
|
||||
if "today_energy" in station:
|
||||
data["daily_energy"] = (float(station["today_energy"]), "kWh")
|
||||
if "total_energy" in station:
|
||||
data["total_energy"] = (float(station["total_energy"]), "kWh")
|
||||
break
|
||||
|
||||
return data
|
||||
|
||||
async def _get_device_data(self) -> dict:
|
||||
"""Fetch device-level real-time data for the target inverter."""
|
||||
payload = {"ps_id": self._ps_id}
|
||||
result = await self._api_call("/openapi/getDeviceList", payload)
|
||||
|
||||
data = {}
|
||||
devices = result.get("pageList", [])
|
||||
for device in devices:
|
||||
# Match by serial number if specified, otherwise use first inverter
|
||||
if self._device_sn and device.get("device_sn") != self._device_sn:
|
||||
continue
|
||||
|
||||
device_type = device.get("device_type", 0)
|
||||
# device_type 1 = inverter in Sungrow API
|
||||
if device_type in (1, "1") or not self._device_sn:
|
||||
if "device_power" in device:
|
||||
data["power"] = (float(device["device_power"]), "kW")
|
||||
if "today_energy" in device:
|
||||
data["daily_energy"] = (float(device["today_energy"]), "kWh")
|
||||
if "total_energy" in device:
|
||||
data["total_energy"] = (float(device["total_energy"]), "kWh")
|
||||
if "temperature" in device:
|
||||
data["temperature"] = (float(device["temperature"]), "°C")
|
||||
if "dc_voltage" in device:
|
||||
data["voltage"] = (float(device["dc_voltage"]), "V")
|
||||
if "ac_current" in device:
|
||||
data["current"] = (float(device["ac_current"]), "A")
|
||||
if "frequency" in device:
|
||||
data["frequency"] = (float(device["frequency"]), "Hz")
|
||||
if self._device_sn:
|
||||
break
|
||||
|
||||
return data
|
||||
|
||||
async def _api_call(self, path: str, payload: dict, auth: bool = True) -> dict:
|
||||
"""Make an API call to iSolarCloud.
|
||||
|
||||
Args:
|
||||
path: API endpoint path (e.g. /openapi/login).
|
||||
payload: Request body parameters.
|
||||
auth: Whether to include the auth token.
|
||||
|
||||
Returns:
|
||||
The 'result_data' dict from the response, or raises on error.
|
||||
"""
|
||||
url = f"{self._api_base}{path}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-access-key": self._x_access_key,
|
||||
"sys_code": self._sys_code,
|
||||
}
|
||||
if auth and self._token:
|
||||
headers["token"] = self._token
|
||||
|
||||
body = {
|
||||
"appkey": self._app_key,
|
||||
"lang": "_zh_CN",
|
||||
**payload,
|
||||
}
|
||||
|
||||
self.logger.debug("API call: %s %s", "POST", url)
|
||||
response = await self._client.post(url, json=body, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
resp_json = response.json()
|
||||
result_code = resp_json.get("result_code", -1)
|
||||
if result_code != 1 and str(result_code) != "1":
|
||||
msg = resp_json.get("result_msg", "Unknown error")
|
||||
self.logger.error("API error on %s: code=%s msg=%s", path, result_code, msg)
|
||||
raise RuntimeError(f"iSolarCloud API error: {msg} (code={result_code})")
|
||||
|
||||
return resp_json.get("result_data", {})
|
||||
|
||||
def _token_needs_refresh(self) -> bool:
|
||||
"""Check if the token is close to expiry."""
|
||||
if not self._token:
|
||||
return True
|
||||
elapsed = time.monotonic() - self._token_obtained_at
|
||||
return elapsed >= self.TOKEN_LIFETIME
|
||||
Reference in New Issue
Block a user