# scenarios.py
"""
API endpoints for predefined scenarios.
"""
from typing import List
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from ..core.scenarios import SCENARIOS, list_scenarios, get_scenario
from ..models.config import SimulationConfigModel, ServerConfig
from ..core.simulation import Simulator


router = APIRouter(prefix="/api/scenarios", tags=["scenarios"])


class ScenarioInfo(BaseModel):
    """Summary entry for one predefined scenario.

    Item type of the scenario listing endpoint; each instance is built by
    unpacking one entry returned by ``list_scenarios()`` as keyword
    arguments, so those entries must provide exactly these fields.
    """
    # Scenario identifier (e.g. "scenario_1").
    id: str
    # Scenario name.
    name: str
    # Scenario description.
    description: str


class ScenarioConfigResponse(BaseModel):
    """Response body for a single scenario: its metadata plus its configuration.

    Returned by the ``GET /api/scenarios/{scenario_id}`` endpoint.
    """
    # Scenario identifier, echoed back from the request path.
    id: str
    # Scenario name, taken from the scenario's SCENARIOS entry.
    name: str
    # Scenario description, taken from the scenario's SCENARIOS entry.
    description: str
    # Simulation configuration produced by the scenario's factory,
    # converted to the API's pydantic model.
    config: SimulationConfigModel


@router.get("/", response_model=List[ScenarioInfo])
async def get_scenarios():
    """
    Return the catalogue of predefined scenarios.

    Returns:
        One ScenarioInfo per entry produced by list_scenarios()
    """
    catalogue = []
    for entry in list_scenarios():
        # Each entry is unpacked as keyword arguments of ScenarioInfo.
        catalogue.append(ScenarioInfo(**entry))
    return catalogue


@router.get("/{scenario_id}", response_model=ScenarioConfigResponse)
async def get_scenario_config(scenario_id: str):
    """
    Get configuration for a specific scenario.

    For "scenario_5", whose factory returns several (label, config)
    variations, the variation labelled "medium_lambda" is used; if no
    variation carries that label, the second variation is used instead.

    Args:
        scenario_id: Scenario identifier (e.g., "scenario_1")

    Returns:
        Scenario configuration

    Raises:
        HTTPException: 404 if the scenario is unknown; 500 if scenario 5
            offers neither a "medium_lambda" variation nor a second one.
    """
    if scenario_id not in SCENARIOS:
        raise HTTPException(
            status_code=404,
            detail=f"Scenario not found. Available: {', '.join(SCENARIOS.keys())}"
        )

    name, description, factory = SCENARIOS[scenario_id]

    if scenario_id == "scenario_5":
        # Scenario 5 returns multiple variations; use the medium_lambda one as default
        variations = factory()
        chosen = next((v for v in variations if v[0] == "medium_lambda"), None)
        if chosen is None:
            # Resolve the positional fallback only when it is needed, so a
            # short variation list cannot fail while "medium_lambda" exists.
            if len(variations) < 2:
                raise HTTPException(
                    status_code=500,
                    detail="Scenario 5 defines no default variation"
                )
            chosen = variations[1]
        _, config = chosen
    else:
        config = factory()

    # Convert to Pydantic model: one ServerConfig per (service rate,
    # routing probability) pair, with ids numbered from 1.
    servers = [
        ServerConfig(
            id=f"server_{i+1}",
            service_rate=mu,
            routing_probability=q
        )
        for i, (mu, q) in enumerate(zip(config.server_service_rates, config.server_routing_probs))
    ]

    config_model = SimulationConfigModel(
        arrival_rate=config.arrival_rate,
        coordinator_service_rate=config.coordinator_service_rate,
        coordinator_exit_probability=config.coordinator_exit_probability,
        servers=servers,
        warmup_time=config.warmup_time,
        simulation_time=config.simulation_time,
        random_seed=config.random_seed
    )

    return ScenarioConfigResponse(
        id=scenario_id,
        name=name,
        description=description,
        config=config_model
    )


@router.get("/scenario_5/variations")
async def get_scenario_5_variations():
    """
    Return every variation of scenario 5 (parameter sensitivity).

    Returns:
        Dict with "variations" (a list of {"variation_id", "config"}
        entries, one per variation) and "count" (the number of entries)
    """
    # Imported locally under an alias: the core function shares this
    # handler's name.
    from ..core.scenarios import get_scenario_5_variations as load_variations

    def as_model(cfg):
        # Build the API model for one core configuration object.
        paired = zip(cfg.server_service_rates, cfg.server_routing_probs)
        server_models = []
        for pos, (rate, prob) in enumerate(paired):
            server_models.append(
                ServerConfig(
                    id=f"server_{pos + 1}",
                    service_rate=rate,
                    routing_probability=prob,
                )
            )
        return SimulationConfigModel(
            arrival_rate=cfg.arrival_rate,
            coordinator_service_rate=cfg.coordinator_service_rate,
            coordinator_exit_probability=cfg.coordinator_exit_probability,
            servers=server_models,
            warmup_time=cfg.warmup_time,
            simulation_time=cfg.simulation_time,
            random_seed=cfg.random_seed,
        )

    all_variations = load_variations()
    payload = [
        {"variation_id": label, "config": as_model(cfg)}
        for label, cfg in all_variations
    ]

    return {"variations": payload, "count": len(payload)}


@router.post("/{scenario_id}/run")
async def run_scenario(scenario_id: str):
    """
    Run a predefined scenario and store its results under a new session.

    For "scenario_5" the variation labelled "medium_lambda" is run; if no
    variation carries that label, the second variation is run instead.

    Args:
        scenario_id: Scenario identifier

    Returns:
        Dict with the new "session_id", the "scenario_id", the "status"
        ("completed") and a "message". The simulation results are not part
        of the response; they are stored in simulation_sessions under the
        returned session id.

    Raises:
        HTTPException: 404 if the scenario is unknown; 400 (with the
            underlying error text as detail) if loading, running or
            storing the simulation fails.
    """
    import uuid

    if scenario_id not in SCENARIOS:
        raise HTTPException(status_code=404, detail="Scenario not found")

    try:
        if scenario_id == "scenario_5":
            from ..core.scenarios import get_scenario_5_variations
            variations = get_scenario_5_variations()
            chosen = next((v for v in variations if v[0] == "medium_lambda"), None)
            if chosen is None:
                # Resolve the positional fallback only when it is needed, so a
                # short variation list cannot fail while "medium_lambda" exists.
                # An IndexError here is reported by the handler below.
                chosen = variations[1]
            _, config = chosen
        else:
            config = get_scenario(scenario_id)

        # Run simulation
        # NOTE(review): run() is called synchronously inside an async handler;
        # if it is long-running it will block the event loop - confirm whether
        # it should be moved to a worker thread.
        simulator = Simulator(config)
        results = simulator.run()

        # Generate session ID
        session_id = str(uuid.uuid4())

        # Import session storage from simulation module
        from .simulation import simulation_sessions

        # Convert config to dict for storage
        servers = [
            {
                "id": f"server_{i+1}",
                "service_rate": mu,
                "routing_probability": q
            }
            for i, (mu, q) in enumerate(zip(config.server_service_rates, config.server_routing_probs))
        ]

        config_dict = {
            "arrival_rate": config.arrival_rate,
            "coordinator_service_rate": config.coordinator_service_rate,
            "coordinator_exit_probability": config.coordinator_exit_probability,
            "servers": servers,
            "warmup_time": config.warmup_time,
            "simulation_time": config.simulation_time,
            "random_seed": config.random_seed
        }

        # Store results
        simulation_sessions[session_id] = {
            "scenario_id": scenario_id,
            "config": config_dict,
            "results": results,
            "status": "completed"
        }

        return {
            "session_id": session_id,
            "scenario_id": scenario_id,
            "status": "completed",
            "message": f"Scenario {scenario_id} completed successfully"
        }

    except Exception as e:
        # Status 400 is kept for backward compatibility; the original
        # exception is chained so the traceback is preserved.
        raise HTTPException(status_code=400, detail=str(e)) from e