"""
Celery tasks for asynchronous report generation.

Also provides a synchronous fallback for demo/dev environments.
"""
import asyncio
import logging
from datetime import date, datetime

from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session as SyncSession, sessionmaker

from app.core.config import get_settings
from app.models.report import ReportTemplate, ReportTask

logger = logging.getLogger(__name__)
settings = get_settings()

# Synchronous DB engine for Celery workers (Celery cannot use async)
_sync_url = settings.DATABASE_URL_SYNC
if not _sync_url:
    # Derive sync URL from async URL by stripping the async driver name
    _sync_url = (
        settings.DATABASE_URL
        .replace("+aiosqlite", "")
        .replace("+asyncpg", "")
        .replace("+aiomysql", "")
    )
_sync_engine = create_engine(_sync_url, echo=False)
SyncSessionLocal = sessionmaker(bind=_sync_engine)

# Report type -> generator method name mapping
REPORT_TYPE_METHODS = {
    "daily": "generate_energy_daily_report",
    "monthly": "generate_monthly_summary",
    "device_status": "generate_device_status_report",
    "alarm": "generate_alarm_report",
    "carbon": "generate_carbon_report",
}


class ReportConfigError(ValueError):
    """A report task is misconfigured (missing template or unknown report type).

    Subclasses ValueError so existing ``except ValueError`` callers keep
    working. The Celery task does not retry these errors, because retrying
    cannot fix a configuration problem.
    """


def _run_report_sync(task_id: int) -> str:
    """
    Synchronous report generation logic.

    Used both by Celery tasks and by the synchronous fallback in the API.
    Moves the ReportTask through "running" -> "completed" (or "failed") and
    records the generated file path.

    Returns:
        The generated file path.

    Raises:
        ValueError: the ReportTask does not exist.
        ReportConfigError: the template is missing or its report type is unknown.
        Exception: anything raised by the generator or the database. The task
            is marked "failed" (best effort) before the error propagates.
    """
    db: SyncSession = SyncSessionLocal()
    try:
        task = db.execute(select(ReportTask).where(ReportTask.id == task_id)).scalar_one_or_none()
        if not task:
            raise ValueError(f"ReportTask {task_id} not found")

        task.status = "running"
        db.commit()

        template = db.execute(
            select(ReportTemplate).where(ReportTemplate.id == task.template_id)
        ).scalar_one_or_none()
        if not template:
            task.status = "failed"
            db.commit()
            raise ReportConfigError(f"ReportTemplate {task.template_id} not found")

        report_type = template.report_type
        method_name = REPORT_TYPE_METHODS.get(report_type)
        if not method_name:
            task.status = "failed"
            db.commit()
            raise ReportConfigError(f"Unknown report type: {report_type}")

        kwargs = _build_generator_kwargs(
            report_type,
            template.filters or {},
            task.export_format or "xlsx",
        )
        filepath = _run_generator(method_name, kwargs)

        task.status = "completed"
        task.file_path = filepath
        task.last_run = datetime.now()
        db.commit()
        logger.info("Report task %s completed: %s", task_id, filepath)
        return filepath
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Report task %s failed: %s", task_id, e)
        _mark_task_failed(db, task_id)
        raise
    finally:
        db.close()


def _build_generator_kwargs(report_type: str, filters: dict, export_format: str) -> dict:
    """Translate template filters into keyword arguments for the generator method.

    - "monthly" takes month/year (defaulting to the current month and year).
    - "device_status" takes only export_format.
    - Every other type takes a start/end date range, which defaults to the
      first of this month through today. "daily" additionally receives
      device_ids when the filter provides any.
    """
    today = date.today()
    if report_type == "monthly":
        # Filter values may arrive as strings (JSON/form input); coerce explicitly.
        return {
            "month": int(filters.get("month") or today.month),
            "year": int(filters.get("year") or today.year),
            "export_format": export_format,
        }
    if report_type == "device_status":
        return {"export_format": export_format}

    kwargs = {
        "start_date": _parse_date(filters.get("start_date"), default=today.replace(day=1)),
        "end_date": _parse_date(filters.get("end_date"), default=today),
        "export_format": export_format,
    }
    device_ids = filters.get("device_ids")
    if device_ids and report_type == "daily":
        kwargs["device_ids"] = device_ids
    return kwargs


def _run_generator(method_name: str, kwargs: dict) -> str:
    """Run ``ReportGenerator.<method_name>(**kwargs)`` to completion from sync code."""
    # Local imports, as in the original code (presumably to keep async DB setup
    # out of Celery import time / avoid import cycles -- confirm).
    from app.core.database import async_session
    from app.services.report_generator import ReportGenerator

    async def _generate():
        async with async_session() as adb:
            gen = ReportGenerator(adb)
            return await getattr(gen, method_name)(**kwargs)

    # asyncio.run creates a fresh loop and, on exit, cancels leftover tasks,
    # shuts down async generators and closes the loop.
    # NOTE(review): if async_session's engine pools connections (e.g. asyncpg),
    # those connections are bound to the loop that created them. Reusing them
    # from a later asyncio.run() loop in the same worker process may fail.
    # Consider NullPool or disposing the async engine for worker use -- confirm.
    return asyncio.run(_generate())


def _mark_task_failed(db: SyncSession, task_id: int) -> None:
    """Best effort: set the ReportTask's status to "failed" after an error.

    The session is rolled back first. If the original error came from the
    database (e.g. a failed flush or commit), SQLAlchemy refuses further
    queries until rollback() is called. Errors here are logged, not raised,
    so the original exception is the one that propagates.
    """
    try:
        db.rollback()
        task = db.execute(select(ReportTask).where(ReportTask.id == task_id)).scalar_one_or_none()
        if task:
            task.status = "failed"
            db.commit()
    except Exception:
        logger.warning("Could not mark report task %s as failed", task_id, exc_info=True)


def _parse_date(val, default: date) -> date:
    """Coerce a filter value to a ``date``, or return *default*.

    Accepts ``date`` objects, ``datetime`` objects (truncated to their date),
    and ISO 8601 strings: either a bare date ("2024-05-01") or a timestamp
    ("2024-05-01T08:00:00"). Empty values return *default* silently.
    Unparseable values return *default* with a logged warning.
    """
    if not val:
        return default
    # datetime subclasses date; normalise so callers always get a plain date
    # (comparing a datetime with a date raises TypeError).
    if isinstance(val, datetime):
        return val.date()
    if isinstance(val, date):
        return val
    text = str(val)
    try:
        return date.fromisoformat(text)
    except (ValueError, TypeError):
        pass
    try:
        return datetime.fromisoformat(text).date()
    except (ValueError, TypeError):
        logger.warning("Ignoring unparseable date filter %r; using %s", val, default)
        return default


# ---------- Celery task ---------- #
try:
    from app.tasks.celery_app import celery_app

    @celery_app.task(name="app.tasks.report_tasks.generate_report_task", bind=True, max_retries=2)
    def generate_report_task(self, task_id: int) -> str:
        """Celery entry point: generate the report for ReportTask *task_id*.

        Other failures are retried up to max_retries=2, 10 seconds apart.
        This includes a plain ValueError such as "ReportTask not found".
        ReportConfigError (missing template, unknown report type) is
        re-raised immediately, because retrying cannot fix it.
        """
        try:
            return _run_report_sync(task_id)
        except ReportConfigError:
            raise
        except Exception as exc:
            logger.error("Celery report task failed: %s", exc)
            raise self.retry(exc=exc, countdown=10)

    CELERY_AVAILABLE = True
except Exception as _celery_exc:
    # Celery is optional (demo/dev); callers fall back to run_report_sync.
    # Log the reason so a broken Celery config is not mistaken for "not installed".
    logger.warning("Celery unavailable, using synchronous report fallback: %s", _celery_exc)
    CELERY_AVAILABLE = False


def run_report_sync(task_id: int) -> str:
    """Public synchronous entry point for fallback mode."""
    return _run_report_sync(task_id)