"""Shared pytest fixtures: test database, HTTP client, users, tokens and seed rows."""

import asyncio
from datetime import datetime, timezone

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

from app.core.database import Base, get_db
from app.core.security import hash_password, create_access_token
from app.models.user import User, Role
from app.models.device import Device, DeviceType, DeviceGroup
from app.models.energy import EnergyData, EnergyDailySummary
from app.models.alarm import AlarmRule, AlarmEvent
from app.models.carbon import CarbonEmission, EmissionFactor
from app.models.report import ReportTemplate, ReportTask

# SQLite through the aiosqlite driver; the URL carries no database path.
TEST_DB_URL = "sqlite+aiosqlite://"
engine = create_async_engine(TEST_DB_URL, echo=False)
# expire_on_commit=False keeps attributes of returned objects loaded after
# the fixtures commit, so tests can read them without another round trip.
TestSession = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


@pytest.fixture(scope="session")
def event_loop():
    """Yield a freshly created event loop for the test session, then close it."""
    session_loop = asyncio.new_event_loop()
    yield session_loop
    session_loop.close()


@pytest.fixture(autouse=True)
async def setup_db():
    """Create every table on ``Base.metadata`` before a test and drop them after."""
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield

    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.drop_all)


@pytest.fixture
async def db_session():
    """Yield a ``TestSession``; commit on teardown, roll back and re-raise on error."""
    async with TestSession() as db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            raise


async def _override_get_db():
    """Stand-in for ``get_db`` that serves sessions from ``TestSession``.

    Commits once the consumer is done; rolls back and re-raises on error.
    """
    async with TestSession() as db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            raise


def _create_test_app():
    """Build a FastAPI app with the API router and ``get_db`` overridden."""
    # Imported inside the function, exactly as the original does.
    from fastapi import FastAPI
    from app.api.router import api_router

    application = FastAPI()
    application.include_router(api_router)
    application.dependency_overrides[get_db] = _override_get_db
    return application


@pytest.fixture
async def client():
    """Yield an ``AsyncClient`` bound to the test app via ``ASGITransport``.

    The app's dependency overrides are cleared after the client is closed.
    """
    app = _create_test_app()
    asgi_transport = ASGITransport(app=app, raise_app_exceptions=False)
    async with AsyncClient(transport=asgi_transport, base_url="http://test") as http_client:
        yield http_client
    app.dependency_overrides.clear()


@pytest.fixture
async def admin_user(db_session: AsyncSession):
    """Persist and return an active user whose role is ``"admin"``."""
    profile = {
        "username": "testadmin",
        "hashed_password": hash_password("admin123"),
        "full_name": "Test Admin",
        "email": "admin@test.com",
        "phone": "13800000001",
        "role": "admin",
        "is_active": True,
    }
    account = User(**profile)
    db_session.add(account)
    await db_session.commit()
    await db_session.refresh(account)
    return account


@pytest.fixture
async def normal_user(db_session: AsyncSession):
    """Persist and return an active user whose role is ``"visitor"``."""
    profile = {
        "username": "testuser",
        "hashed_password": hash_password("user123"),
        "full_name": "Test User",
        "email": "user@test.com",
        "phone": "13800000002",
        "role": "visitor",
        "is_active": True,
    }
    account = User(**profile)
    db_session.add(account)
    await db_session.commit()
    await db_session.refresh(account)
    return account


@pytest.fixture
async def admin_token(admin_user: User) -> str:
    """Return an access token carrying the admin user's id and role."""
    claims = {"sub": str(admin_user.id), "role": admin_user.role}
    return create_access_token(claims)


@pytest.fixture
async def user_token(normal_user: User) -> str:
    """Return an access token carrying the normal user's id and role."""
    claims = {"sub": str(normal_user.id), "role": normal_user.role}
    return create_access_token(claims)


def auth_header(token: str) -> dict:
    """Return an ``Authorization`` header dict using the Bearer scheme."""
    bearer_value = f"Bearer {token}"
    return {"Authorization": bearer_value}


@pytest.fixture
async def seed_roles(db_session: AsyncSession):
    """Insert three roles (admin, energy_manager, visitor) and return them."""
    role_specs = (
        ("admin", "管理员", "系统管理员"),
        ("energy_manager", "能源管理员", "能源管理"),
        ("visitor", "访客", "只读访客"),
    )
    roles = [
        Role(name=name, display_name=display_name, description=description)
        for name, display_name, description in role_specs
    ]
    db_session.add_all(roles)
    await db_session.commit()
    return roles


@pytest.fixture
async def seed_device_types(db_session: AsyncSession):
    """Insert three device types (pv_inverter, heat_pump, meter) and return them."""
    type_specs = (
        ("pv_inverter", "光伏逆变器", "solar"),
        ("heat_pump", "热泵机组", "heat"),
        ("meter", "电表", "meter"),
    )
    types = [
        DeviceType(code=code, name=name, icon=icon)
        for code, name, icon in type_specs
    ]
    db_session.add_all(types)
    await db_session.commit()
    return types


@pytest.fixture
async def seed_device_groups(db_session: AsyncSession):
    """Insert two device groups and return them."""
    group_specs = (
        ("A区", "大兴园区A区"),
        ("B区", "大兴园区B区"),
    )
    groups = [
        DeviceGroup(name=name, location=location)
        for name, location in group_specs
    ]
    db_session.add_all(groups)
    await db_session.commit()
    return groups


@pytest.fixture
async def seed_devices(db_session: AsyncSession, seed_device_types):
    """Insert three devices, refresh each after commit, and return them.

    Depends on ``seed_device_types`` so the device types are inserted first.
    """
    inverter = Device(
        name="光伏逆变器1号",
        code="PV-INV-001",
        device_type="pv_inverter",
        status="online",
        rated_power=100.0,
        is_active=True,
    )
    heat_pump = Device(
        name="热泵机组1号",
        code="HP-001",
        device_type="heat_pump",
        status="online",
        rated_power=50.0,
        is_active=True,
    )
    # The meter is the only offline device and is given no rated_power.
    meter = Device(
        name="电表1号",
        code="MTR-001",
        device_type="meter",
        status="offline",
        is_active=True,
    )
    devices = [inverter, heat_pump, meter]
    db_session.add_all(devices)
    await db_session.commit()
    for device in devices:
        await db_session.refresh(device)
    return devices


@pytest.fixture
async def seed_energy_data(db_session: AsyncSession, seed_devices):
    """Insert one ``power`` reading (42.5 kW) per seeded device and return them."""
    # One timestamp is shared by all readings.
    now = datetime.now(timezone.utc)
    data = [
        EnergyData(
            device_id=device.id,
            timestamp=now,
            data_type="power",
            value=42.5,
            unit="kW",
        )
        for device in seed_devices
    ]
    db_session.add_all(data)
    await db_session.commit()
    return data


@pytest.fixture
async def seed_daily_summary(db_session: AsyncSession, seed_devices):
    """Insert one electricity daily summary for the first seeded device."""
    now = datetime.now(timezone.utc)
    first_device = seed_devices[0]
    summary = EnergyDailySummary(
        device_id=first_device.id,
        date=now,
        energy_type="electricity",
        total_consumption=100.0,
        total_generation=80.0,
        peak_power=50.0,
        avg_power=30.0,
        operating_hours=8.0,
        cost=50.0,
        carbon_emission=40.0,
    )
    summaries = [summary]
    db_session.add_all(summaries)
    await db_session.commit()
    return summaries


@pytest.fixture
async def seed_alarm_rule(db_session: AsyncSession, admin_user):
    """Persist and return an active temperature rule (``gt`` 80.0, warning)."""
    alarm_rule = AlarmRule(
        name="高温报警",
        data_type="temperature",
        condition="gt",
        threshold=80.0,
        severity="warning",
        created_by=admin_user.id,
        is_active=True,
    )
    db_session.add(alarm_rule)
    await db_session.commit()
    await db_session.refresh(alarm_rule)
    return alarm_rule


@pytest.fixture
async def seed_alarm_event(db_session: AsyncSession, seed_devices, seed_alarm_rule):
    """Persist and return an active alarm event on the first seeded device."""
    target_device = seed_devices[0]
    alarm_event = AlarmEvent(
        rule_id=seed_alarm_rule.id,
        device_id=target_device.id,
        severity="warning",
        title="温度过高",
        description="设备温度超过阈值",
        value=85.0,
        threshold=80.0,
        status="active",
    )
    db_session.add(alarm_event)
    await db_session.commit()
    await db_session.refresh(alarm_event)
    return alarm_event


@pytest.fixture
async def seed_carbon(db_session: AsyncSession):
    """Insert one scope-2 electricity carbon emission record and return it in a list."""
    now = datetime.now(timezone.utc)
    record = CarbonEmission(
        date=now,
        scope=2,
        category="electricity",
        emission=100.0,
        reduction=20.0,
    )
    records = [record]
    db_session.add_all(records)
    await db_session.commit()
    return records


@pytest.fixture
async def seed_emission_factors(db_session: AsyncSession):
    """Insert one electricity emission factor and return it in a list."""
    grid_factor = EmissionFactor(
        name="华北电网",
        energy_type="electricity",
        factor=0.8843,
        unit="kWh",
        scope=2,
        region="north_china",
        source="生态环境部",
    )
    factors = [grid_factor]
    db_session.add_all(factors)
    await db_session.commit()
    return factors


@pytest.fixture
async def seed_report_template(db_session: AsyncSession, admin_user):
    """Persist and return a daily report template created by the admin user."""
    field_specs = [{"name": "consumption", "label": "能耗"}]
    report_template = ReportTemplate(
        name="日报模板",
        report_type="daily",
        description="每日能耗报表",
        fields=field_specs,
        created_by=admin_user.id,
    )
    db_session.add(report_template)
    await db_session.commit()
    await db_session.refresh(report_template)
    return report_template


@pytest.fixture
async def seed_report_task(db_session: AsyncSession, seed_report_template, admin_user):
    """Persist and return an xlsx report task bound to the seeded template."""
    report_task = ReportTask(
        template_id=seed_report_template.id,
        name="测试任务",
        export_format="xlsx",
        created_by=admin_user.id,
    )
    db_session.add(report_task)
    await db_session.commit()
    await db_session.refresh(report_task)
    return report_task