"""Redis Streams-based data ingestion buffer for high-throughput device data."""
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from sqlalchemy import select
|
||
|
|
|
||
|
|
from app.core.cache import get_redis
|
||
|
|
from app.core.config import get_settings
|
||
|
|
from app.core.database import async_session
|
||
|
|
from app.models.energy import EnergyData
|
||
|
|
|
||
|
|
# Dedicated logger for the ingestion pipeline (kept distinct from __name__
# so all queue/worker messages share one logger hierarchy).
logger = logging.getLogger("ingestion.queue")

# Redis Stream that buffers incoming device data points.
STREAM_KEY = "ems:ingestion:stream"
# Consumer group used by ingestion workers to share the stream.
CONSUMER_GROUP = "ems:ingestion:workers"
# NOTE(review): consumer name is hard-coded; running multiple worker
# processes with the same name would share one pending-entries list --
# confirm single-worker deployment is intended.
CONSUMER_NAME = "worker-1"
class IngestionQueue:
    """Push device data into a Redis Stream for buffered ingestion.

    Producers call :meth:`push` to enqueue data points; a background worker
    drains them via :meth:`consume_batch` (consumer-group reads) and confirms
    processing with :meth:`ack`. All methods degrade gracefully when Redis is
    unavailable (returning None / [] / 0 rather than raising).
    """

    def __init__(self) -> None:
        # Cache whether the consumer group has been created so we do not
        # issue an XGROUP CREATE round-trip on every consume_batch() call.
        self._group_ready: bool = False

    async def push(
        self,
        device_id: int,
        data_type: str,
        value: float,
        unit: str,
        timestamp: Optional[str] = None,
        raw_data: Optional[dict] = None,
    ) -> Optional[str]:
        """Add a data point to the ingestion stream.

        Args:
            device_id: Numeric device identifier.
            data_type: Measurement kind (e.g. power, voltage).
            value: Measured value.
            unit: Unit string for the value.
            timestamp: Optional ISO-8601 timestamp; defaults to now (UTC).
            raw_data: Optional original payload, stored as JSON.

        Returns:
            The Redis message ID on success, None if Redis is unavailable
            or the write fails.
        """
        redis = await get_redis()
        if not redis:
            return None
        try:
            # Stream fields must be strings; numeric values are stringified.
            fields = {
                "device_id": str(device_id),
                "data_type": data_type,
                "value": str(value),
                "unit": unit,
                "timestamp": timestamp or datetime.now(timezone.utc).isoformat(),
            }
            if raw_data:
                # default=str keeps non-JSON-native values (e.g. datetimes)
                # serializable rather than failing the whole push.
                fields["raw_data"] = json.dumps(raw_data, ensure_ascii=False, default=str)
            msg_id = await redis.xadd(STREAM_KEY, fields)
            return msg_id
        except Exception as e:
            logger.error("Failed to push to ingestion stream: %s", e)
            return None

    async def consume_batch(self, count: int = 100) -> list[tuple[str, dict]]:
        """Read up to `count` messages from the stream via consumer group.

        Blocks up to 1 second waiting for new messages.

        Returns:
            List of (message_id, fields) tuples; empty list when Redis is
            unavailable, no messages arrive, or the read fails.
        """
        redis = await get_redis()
        if not redis:
            return []
        try:
            # Ensure the consumer group exists -- but only attempt creation
            # once per queue instance instead of on every poll.
            if not self._group_ready:
                try:
                    await redis.xgroup_create(STREAM_KEY, CONSUMER_GROUP, id="0", mkstream=True)
                except Exception:
                    # BUSYGROUP: group already exists -- safe to ignore.
                    pass
                self._group_ready = True

            messages = await redis.xreadgroup(
                CONSUMER_GROUP,
                CONSUMER_NAME,
                {STREAM_KEY: ">"},  # ">" = only never-delivered messages
                count=count,
                block=1000,  # block up to 1s for new entries
            )
            if not messages:
                return []
            # messages format: [(stream_key, [(msg_id, fields), ...])]
            return messages[0][1]
        except Exception as e:
            # If the read failed (e.g. NOGROUP after a Redis flush), force a
            # group re-create on the next call so we can self-heal.
            self._group_ready = False
            logger.error("Failed to consume from ingestion stream: %s", e)
            return []

    async def ack(self, message_ids: list[str]) -> int:
        """Acknowledge processed messages.

        Args:
            message_ids: Redis stream message IDs to acknowledge.

        Returns:
            Number of successfully acknowledged messages (0 on failure,
            empty input, or unavailable Redis).
        """
        redis = await get_redis()
        if not redis or not message_ids:
            return 0
        try:
            return await redis.xack(STREAM_KEY, CONSUMER_GROUP, *message_ids)
        except Exception as e:
            logger.error("Failed to ack messages: %s", e)
            return 0
class IngestionWorker:
    """Background worker that drains the ingestion stream and bulk-inserts to DB.

    Runs a single asyncio task that repeatedly pulls batches from
    :class:`IngestionQueue`, converts them to ``EnergyData`` rows, commits
    them in one transaction, and acknowledges the stream messages.
    """

    def __init__(self, batch_size: int = 100, interval: float = 2.0):
        """
        Args:
            batch_size: Max messages pulled per stream read.
            interval: Sleep (seconds) between polls when the stream is idle
                and after an unexpected error.
        """
        self.batch_size = batch_size
        self.interval = interval
        self._queue = IngestionQueue()
        self._running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the background ingestion worker (idempotent)."""
        # Guard against double-start: a second create_task would overwrite
        # self._task and leak the still-running first worker task.
        if self._running and self._task and not self._task.done():
            logger.warning("IngestionWorker already running; start() ignored.")
            return
        self._running = True
        self._task = asyncio.create_task(self._run(), name="ingestion-worker")
        logger.info(
            "IngestionWorker started (batch_size=%d, interval=%.1fs)",
            self.batch_size,
            self.interval,
        )

    async def stop(self):
        """Stop the ingestion worker gracefully."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            # Drop the reference so start() can cleanly create a new task.
            self._task = None
        logger.info("IngestionWorker stopped.")

    async def _run(self):
        """Main loop: consume batches from stream and insert to DB."""
        while self._running:
            try:
                messages = await self._queue.consume_batch(count=self.batch_size)
                if messages:
                    await self._process_batch(messages)
                else:
                    # Idle: back off before polling again.
                    await asyncio.sleep(self.interval)
            except asyncio.CancelledError:
                # Propagate cancellation so stop() can await completion.
                raise
            except Exception as e:
                logger.error("IngestionWorker error: %s", e, exc_info=True)
                await asyncio.sleep(self.interval)

    async def _process_batch(self, messages: list[tuple[str, dict]]):
        """Parse messages and bulk-insert EnergyData rows.

        Malformed messages are logged, skipped, and still acknowledged so
        they are not redelivered forever. If the DB commit raises, nothing
        is acked and the exception propagates to _run() for logging.
        """
        msg_ids = []
        rows = []
        for msg_id, fields in messages:
            msg_ids.append(msg_id)
            try:
                ts_str = fields.get("timestamp", "")
                # NOTE(review): fromisoformat yields a naive datetime when the
                # stored string lacks an offset, while the fallback is
                # UTC-aware -- confirm EnergyData.timestamp tolerates both.
                timestamp = datetime.fromisoformat(ts_str) if ts_str else datetime.now(timezone.utc)
                raw = None
                if "raw_data" in fields:
                    try:
                        raw = json.loads(fields["raw_data"])
                    except (json.JSONDecodeError, TypeError):
                        # Keep the row; just drop the unparseable payload.
                        raw = None
                rows.append(
                    EnergyData(
                        device_id=int(fields["device_id"]),
                        timestamp=timestamp,
                        data_type=fields["data_type"],
                        value=float(fields["value"]),
                        unit=fields.get("unit", ""),
                        raw_data=raw,
                    )
                )
            except (KeyError, ValueError) as e:
                logger.warning("Skipping malformed message %s: %s", msg_id, e)

        if rows:
            # One session/commit per batch keeps this a single transaction.
            async with async_session() as session:
                session.add_all(rows)
                await session.commit()
            logger.debug("Bulk-inserted %d rows from ingestion stream.", len(rows))

        # Acknowledge all messages (including malformed ones to avoid reprocessing)
        if msg_ids:
            await self._queue.ack(msg_ids)