"""Redis Streams-based data ingestion buffer for high-throughput device data.""" import asyncio import json import logging from datetime import datetime, timezone from typing import Optional from sqlalchemy import select from app.core.cache import get_redis from app.core.config import get_settings from app.core.database import async_session from app.models.energy import EnergyData logger = logging.getLogger("ingestion.queue") STREAM_KEY = "ems:ingestion:stream" CONSUMER_GROUP = "ems:ingestion:workers" CONSUMER_NAME = "worker-1" class IngestionQueue: """Push device data into a Redis Stream for buffered ingestion.""" async def push( self, device_id: int, data_type: str, value: float, unit: str, timestamp: Optional[str] = None, raw_data: Optional[dict] = None, ) -> Optional[str]: """Add a data point to the ingestion stream. Returns the message ID on success, None on failure. """ redis = await get_redis() if not redis: return None try: fields = { "device_id": str(device_id), "data_type": data_type, "value": str(value), "unit": unit, "timestamp": timestamp or datetime.now(timezone.utc).isoformat(), } if raw_data: fields["raw_data"] = json.dumps(raw_data, ensure_ascii=False, default=str) msg_id = await redis.xadd(STREAM_KEY, fields) return msg_id except Exception as e: logger.error("Failed to push to ingestion stream: %s", e) return None async def consume_batch(self, count: int = 100) -> list[tuple[str, dict]]: """Read up to `count` messages from the stream via consumer group. Returns list of (message_id, fields) tuples. """ redis = await get_redis() if not redis: return [] try: # Ensure consumer group exists try: await redis.xgroup_create(STREAM_KEY, CONSUMER_GROUP, id="0", mkstream=True) except Exception: # Group already exists pass messages = await redis.xreadgroup( CONSUMER_GROUP, CONSUMER_NAME, {STREAM_KEY: ">"}, count=count, block=1000, ) if not messages: return [] # messages format: [(stream_key, [(msg_id, fields), ...])] return messages[0][1] except Exception as e: logger.error("Failed to consume from ingestion stream: %s", e) return [] async def ack(self, message_ids: list[str]) -> int: """Acknowledge processed messages. Returns number of successfully acknowledged messages. """ redis = await get_redis() if not redis or not message_ids: return 0 try: return await redis.xack(STREAM_KEY, CONSUMER_GROUP, *message_ids) except Exception as e: logger.error("Failed to ack messages: %s", e) return 0 class IngestionWorker: """Background worker that drains the ingestion stream and bulk-inserts to DB.""" def __init__(self, batch_size: int = 100, interval: float = 2.0): self.batch_size = batch_size self.interval = interval self._queue = IngestionQueue() self._running = False self._task: Optional[asyncio.Task] = None async def start(self): """Start the background ingestion worker.""" self._running = True self._task = asyncio.create_task(self._run(), name="ingestion-worker") logger.info( "IngestionWorker started (batch_size=%d, interval=%.1fs)", self.batch_size, self.interval, ) async def stop(self): """Stop the ingestion worker gracefully.""" self._running = False if self._task: self._task.cancel() try: await self._task except asyncio.CancelledError: pass logger.info("IngestionWorker stopped.") async def _run(self): """Main loop: consume batches from stream and insert to DB.""" while self._running: try: messages = await self._queue.consume_batch(count=self.batch_size) if messages: await self._process_batch(messages) else: await asyncio.sleep(self.interval) except asyncio.CancelledError: raise except Exception as e: logger.error("IngestionWorker error: %s", e, exc_info=True) await asyncio.sleep(self.interval) async def _process_batch(self, messages: list[tuple[str, dict]]): """Parse messages and bulk-insert EnergyData rows.""" msg_ids = [] rows = [] for msg_id, fields in messages: msg_ids.append(msg_id) try: ts_str = fields.get("timestamp", "") timestamp = datetime.fromisoformat(ts_str) if ts_str else datetime.now(timezone.utc) raw = None if "raw_data" in fields: try: raw = json.loads(fields["raw_data"]) except (json.JSONDecodeError, TypeError): raw = None rows.append( EnergyData( device_id=int(fields["device_id"]), timestamp=timestamp, data_type=fields["data_type"], value=float(fields["value"]), unit=fields.get("unit", ""), raw_data=raw, ) ) except (KeyError, ValueError) as e: logger.warning("Skipping malformed message %s: %s", msg_id, e) if rows: async with async_session() as session: session.add_all(rows) await session.commit() logger.debug("Bulk-inserted %d rows from ingestion stream.", len(rows)) # Acknowledge all messages (including malformed ones to avoid reprocessing) if msg_ids: await self._queue.ack(msg_ids)