Files
zpark-ems/backend/app/tasks/report_tasks.py

158 lines
5.2 KiB
Python
Raw Normal View History

"""
Celery tasks for asynchronous report generation.
Also provides a synchronous fallback for demo/dev environments.
"""
import logging
from datetime import date, datetime
from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session as SyncSession, sessionmaker
from app.core.config import get_settings
from app.models.report import ReportTemplate, ReportTask
logger = logging.getLogger(__name__)
settings = get_settings()

# Celery workers get their own synchronous engine (Celery cannot use async).
# The explicitly configured sync URL wins; when it is unset or empty, one is
# derived by stripping the async driver suffix from the async URL.
_sync_url = settings.DATABASE_URL_SYNC or (
    settings.DATABASE_URL
    .replace("+aiosqlite", "")
    .replace("+asyncpg", "")
    .replace("+aiomysql", "")
)
_sync_engine = create_engine(_sync_url, echo=False)
SyncSessionLocal = sessionmaker(bind=_sync_engine)
# Lookup table: template ``report_type`` value -> name of the generator method
# that is resolved with ``getattr`` when the report is produced.
REPORT_TYPE_METHODS = {
    "daily": "generate_energy_daily_report",
    "monthly": "generate_monthly_summary",
    "device_status": "generate_device_status_report",
    "alarm": "generate_alarm_report",
    "carbon": "generate_carbon_report",
}
def _mark_task_failed(db: "SyncSession", task_id: int) -> None:
    """Best-effort: persist ``status = "failed"`` for the given report task.

    The session is rolled back first so that a transaction left unusable by
    the original error (for example a failed flush or commit) cannot block
    the status update. Errors raised here are logged and suppressed so they
    never mask the exception that triggered the failure handling.
    """
    try:
        db.rollback()
        task = db.execute(
            select(ReportTask).where(ReportTask.id == task_id)
        ).scalar_one_or_none()
        if task:
            task.status = "failed"
            db.commit()
    except Exception:
        logger.exception(f"Could not mark report task {task_id} as failed")


def _run_report_sync(task_id: int) -> str:
    """Generate the report for one ``ReportTask`` row and return its file path.

    Used both by the Celery task and by the synchronous fallback in the API.

    Steps:
      1. Load the task and commit ``status = "running"``.
      2. Load its template and read the date range / device filter from
         ``template.filters``.
      3. Resolve the generator method through ``REPORT_TYPE_METHODS`` and run
         it (it is a coroutine) on a private event loop.
      4. Commit ``status = "completed"``, ``file_path`` and ``last_run``.

    Args:
        task_id: Primary key of the ``ReportTask`` to run.

    Returns:
        The value returned by the generator method, which is also stored in
        ``task.file_path``.

    Raises:
        ValueError: If the task or its template does not exist, or the
            template's report type has no entry in ``REPORT_TYPE_METHODS``.
        Exception: Any other error is re-raised unchanged after the task has
            been marked ``"failed"`` (best effort).
    """
    db: SyncSession = SyncSessionLocal()
    try:
        task = db.execute(
            select(ReportTask).where(ReportTask.id == task_id)
        ).scalar_one_or_none()
        if not task:
            raise ValueError(f"ReportTask {task_id} not found")

        task.status = "running"
        db.commit()

        # The validation errors below do not set the "failed" status
        # themselves: the ``except`` handler at the bottom does that for
        # every failure path.
        template = db.execute(
            select(ReportTemplate).where(ReportTemplate.id == task.template_id)
        ).scalar_one_or_none()
        if not template:
            raise ValueError(f"ReportTemplate {task.template_id} not found")

        # Determine date range from template filters; defaults are the first
        # day of the current month through today.
        filters = template.filters or {}
        today = date.today()
        start_date = _parse_date(filters.get("start_date"), default=today.replace(day=1))
        end_date = _parse_date(filters.get("end_date"), default=today)
        device_ids = filters.get("device_ids")
        export_format = task.export_format or "xlsx"
        report_type = template.report_type

        method_name = REPORT_TYPE_METHODS.get(report_type)
        if not method_name:
            raise ValueError(f"Unknown report type: {report_type}")

        # Imported here (not at module level), exactly as before, so the
        # async stack is only loaded once the task has been validated.
        import asyncio
        from app.core.database import async_session
        from app.services.report_generator import ReportGenerator

        async def _generate():
            """Call the resolved generator method with the arguments it takes."""
            async with async_session() as adb:
                method = getattr(ReportGenerator(adb), method_name)
                if report_type == "monthly":
                    return await method(
                        month=filters.get("month", today.month),
                        year=filters.get("year", today.year),
                        export_format=export_format,
                    )
                if report_type == "device_status":
                    return await method(export_format=export_format)

                kwargs = {
                    "start_date": start_date,
                    "end_date": end_date,
                    "export_format": export_format,
                }
                # Only the daily report is given a device filter.
                if device_ids and report_type == "daily":
                    kwargs["device_ids"] = device_ids
                return await method(**kwargs)

        # Run the coroutine on a dedicated loop that is always closed.
        loop = asyncio.new_event_loop()
        try:
            filepath = loop.run_until_complete(_generate())
        finally:
            loop.close()

        task.status = "completed"
        task.file_path = filepath
        task.last_run = datetime.now()
        db.commit()
        logger.info(f"Report task {task_id} completed: {filepath}")
        return filepath
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Report task {task_id} failed: {e}")
        _mark_task_failed(db, task_id)
        raise
    finally:
        db.close()
def _parse_date(val, default: date) -> date:
if not val:
return default
if isinstance(val, date):
return val
try:
return date.fromisoformat(str(val))
except (ValueError, TypeError):
return default
# ---------- Celery task ---------- #
try:
    from app.tasks.celery_app import celery_app

    @celery_app.task(name="app.tasks.report_tasks.generate_report_task", bind=True, max_retries=2)
    def generate_report_task(self, task_id: int) -> str:
        """Celery entry point: run the report and return its file path.

        Any exception is logged and turned into a retry request with a
        10 second countdown; the task is registered with ``max_retries=2``.
        """
        try:
            result = _run_report_sync(task_id)
        except Exception as err:
            logger.error(f"Celery report task failed: {err}")
            raise self.retry(exc=err, countdown=10)
        return result

    CELERY_AVAILABLE = True
except Exception:
    # Importing the Celery app or registering the task failed; the flag lets
    # other code detect that the Celery task is not available.
    CELERY_AVAILABLE = False
def run_report_sync(task_id: int) -> str:
    """Run a report task synchronously and return the generated file path.

    Public wrapper for fallback mode; it delegates directly to
    ``_run_report_sync`` and lets any exception propagate.
    """
    filepath = _run_report_sync(task_id)
    return filepath