158 lines
5.2 KiB
Python
158 lines
5.2 KiB
Python
"""
|
|
Celery tasks for asynchronous report generation.
|
|
Also provides a synchronous fallback for demo/dev environments.
|
|
"""
|
|
import logging
|
|
from datetime import date, datetime
|
|
|
|
from sqlalchemy import select, create_engine
|
|
from sqlalchemy.orm import Session as SyncSession, sessionmaker
|
|
|
|
from app.core.config import get_settings
|
|
from app.models.report import ReportTemplate, ReportTask
|
|
|
|
logger = logging.getLogger(__name__)

settings = get_settings()


def _derive_sync_url(async_url: str) -> str:
    """Return *async_url* with the async driver markers removed.

    e.g. ``postgresql+asyncpg://...`` becomes ``postgresql://...``.
    """
    for marker in ("+aiosqlite", "+asyncpg", "+aiomysql"):
        async_url = async_url.replace(marker, "")
    return async_url


# Celery workers run plain synchronous code, so they get their own sync engine.
# An explicit DATABASE_URL_SYNC wins; otherwise it is derived from DATABASE_URL.
_sync_url = settings.DATABASE_URL_SYNC or _derive_sync_url(settings.DATABASE_URL)

_sync_engine = create_engine(_sync_url, echo=False)
SyncSessionLocal = sessionmaker(bind=_sync_engine)
|
|
|
|
|
|
# Maps a template's ``report_type`` to the name of the async ReportGenerator
# method that produces that kind of report (looked up with getattr and awaited).
REPORT_TYPE_METHODS = dict(
    daily="generate_energy_daily_report",
    monthly="generate_monthly_summary",
    device_status="generate_device_status_report",
    alarm="generate_alarm_report",
    carbon="generate_carbon_report",
)
|
|
|
|
|
|
def _run_report_sync(task_id: int) -> str:
    """
    Synchronous report generation logic.

    Used both by Celery tasks and by the synchronous fallback in the API.
    Loads ``ReportTask`` *task_id* and its ``ReportTemplate``, marks the task
    "running", runs the matching async ``ReportGenerator`` method on a fresh
    event loop, then stores status "completed", the file path and ``last_run``.

    Returns the generated file path.

    Raises:
        ValueError: the task, its template, or the template's report type
            is unknown.
        Exception: any error from the generator or the database is re-raised
            unchanged after the task row is (best effort) marked "failed".
    """
    db: SyncSession = SyncSessionLocal()
    try:
        task = db.execute(select(ReportTask).where(ReportTask.id == task_id)).scalar_one_or_none()
        if not task:
            raise ValueError(f"ReportTask {task_id} not found")

        task.status = "running"
        db.commit()

        template = db.execute(
            select(ReportTemplate).where(ReportTemplate.id == task.template_id)
        ).scalar_one_or_none()
        if not template:
            # The except-handler below marks the task "failed".
            raise ValueError(f"ReportTemplate {task.template_id} not found")

        # Determine date range from template filters
        filters = template.filters or {}
        today = date.today()
        start_date = _parse_date(filters.get("start_date"), default=today.replace(day=1))
        end_date = _parse_date(filters.get("end_date"), default=today)
        device_ids = filters.get("device_ids")
        export_format = task.export_format or "xlsx"
        report_type = template.report_type

        method_name = REPORT_TYPE_METHODS.get(report_type)
        if not method_name:
            raise ValueError(f"Unknown report type: {report_type}")

        # Use synchronous wrapper around async generator
        import asyncio

        from app.core.database import async_session
        from app.services.report_generator import ReportGenerator

        async def _generate():
            async with async_session() as adb:
                method = getattr(ReportGenerator(adb), method_name)
                if report_type == "monthly":
                    month = filters.get("month", today.month)
                    year = filters.get("year", today.year)
                    return await method(month=month, year=year, export_format=export_format)
                if report_type == "device_status":
                    return await method(export_format=export_format)
                # device_ids is only forwarded to the daily report generator.
                extra = {"device_ids": device_ids} if device_ids and report_type == "daily" else {}
                return await method(
                    start_date=start_date,
                    end_date=end_date,
                    export_format=export_format,
                    **extra,
                )

        # NOTE(review): run_until_complete raises RuntimeError if this thread is
        # already running an event loop, so callers must invoke this from a
        # synchronous context (worker/thread) — confirm for the API fallback.
        loop = asyncio.new_event_loop()
        try:
            filepath = loop.run_until_complete(_generate())
        finally:
            loop.close()

        task.status = "completed"
        task.file_path = filepath
        task.last_run = datetime.now()
        db.commit()

        logger.info("Report task %s completed: %s", task_id, filepath)
        return filepath

    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Report task %s failed: %s", task_id, e)
        _mark_task_failed(db, task_id)
        raise
    finally:
        db.close()


def _mark_task_failed(db: "SyncSession", task_id: int) -> None:
    """Best effort: set ReportTask *task_id* to status "failed" and commit.

    Never raises, so the caller's original exception is what propagates.
    """
    try:
        # After a failed flush/commit the session refuses further work until it
        # is rolled back; without this the status update below could never run.
        db.rollback()
        task = db.execute(select(ReportTask).where(ReportTask.id == task_id)).scalar_one_or_none()
        if task:
            task.status = "failed"
            db.commit()
    except Exception:
        logger.exception("Could not mark report task %s as failed", task_id)
|
|
|
|
|
|
def _parse_date(val, default: date) -> date:
|
|
if not val:
|
|
return default
|
|
if isinstance(val, date):
|
|
return val
|
|
try:
|
|
return date.fromisoformat(str(val))
|
|
except (ValueError, TypeError):
|
|
return default
|
|
|
|
|
|
# ---------- Celery task ---------- #

try:
    from app.tasks.celery_app import celery_app

    @celery_app.task(name="app.tasks.report_tasks.generate_report_task", bind=True, max_retries=2)
    def generate_report_task(self, task_id: int) -> str:
        """Celery entry point: generate the report for ``ReportTask`` *task_id*.

        Returns the generated file path. Any failure is retried via
        ``self.retry`` (at most ``max_retries=2`` times, 10 seconds apart);
        per Celery's ``Task.retry`` semantics, *exc* is re-raised once the
        retry budget is exhausted.
        """
        try:
            return _run_report_sync(task_id)
        except Exception as exc:
            logger.error("Celery report task failed: %s", exc)
            raise self.retry(exc=exc, countdown=10)

    CELERY_AVAILABLE = True

except Exception as exc:
    # Deliberate fallback for environments without Celery, but record why so a
    # broken Celery/broker configuration does not go unnoticed.
    logger.warning("Celery unavailable, generate_report_task not registered: %s", exc)
    CELERY_AVAILABLE = False
|
|
|
|
|
|
def run_report_sync(task_id: int) -> str:
    """Generate the report for *task_id* in-process and return its file path.

    Public entry point for the synchronous (non-Celery) fallback mode; it
    simply delegates to the module's private implementation.
    """
    generated_path = _run_report_sync(task_id)
    return generated_path
|