# test_simulation.py
"""
Tests for the simulation engine.
"""
import pytest
from src.core.simulation import Simulator, SimulationConfig
from src.core.random_utils import set_seed


class TestSimulationConfig:
    """Validation tests for SimulationConfig."""

    def test_valid_config(self):
        """A fully specified, internally consistent configuration validates cleanly."""
        cfg = SimulationConfig(
            arrival_rate=0.1,
            coordinator_service_rate=0.15,
            coordinator_exit_probability=0.5,
            server_service_rates=[0.2],
            server_routing_probs=[0.5],
            simulation_time=1000.0,
        )
        # If validate() raised, pytest would record this test as a failure.
        cfg.validate()

    def test_invalid_arrival_rate(self):
        """An arrival rate of zero must be rejected with ValueError."""
        cfg = SimulationConfig(
            arrival_rate=0,  # rates must be strictly positive
            coordinator_service_rate=0.15,
            coordinator_exit_probability=0.5,
            server_service_rates=[0.2],
            server_routing_probs=[0.5],
        )
        with pytest.raises(ValueError):
            cfg.validate()

    def test_probability_conservation(self):
        """Exit probability plus routing probabilities must total 1.0."""
        cfg = SimulationConfig(
            arrival_rate=0.1,
            coordinator_service_rate=0.15,
            coordinator_exit_probability=0.5,
            server_service_rates=[0.2],
            # 0.5 (exit) + 0.3 (routing) = 0.8 != 1.0, so validation must fail
            server_routing_probs=[0.3],
        )
        with pytest.raises(ValueError):
            cfg.validate()


class TestSimulator:
    """End-to-end tests driving the Simulator event loop."""

    def test_simple_simulation(self):
        """A lightly loaded (stable) system yields sane aggregate statistics."""
        cfg = SimulationConfig(
            arrival_rate=0.08,  # λ = 0.08 (mean inter-arrival = 12.5)
            coordinator_service_rate=0.1,  # μc = 0.1 (mean service = 10)
            coordinator_exit_probability=0.5,  # p = 0.5
            server_service_rates=[0.1],  # μ1 = 0.1
            server_routing_probs=[0.5],  # q1 = 0.5
            warmup_time=1000.0,
            simulation_time=5000.0,
            random_seed=42,
        )

        res = Simulator(cfg).run()

        # Counters and the timing aggregate must all be strictly positive.
        assert res.total_requests_arrived > 0
        assert res.total_requests_completed > 0
        assert res.average_system_time > 0

        # In a stable system utilization should be ≤ 1; allow a small margin
        # because discrete-time measurement may push it slightly past 1.0.
        assert res.coordinator_stats["utilization"] < 1.1
        assert res.server_stats["server_1"]["utilization"] < 1.1

        print(f"\nSimulation results:")
        print(f"  Requests arrived: {res.total_requests_arrived}")
        print(f"  Requests completed: {res.total_requests_completed}")
        print(f"  Average system time: {res.average_system_time:.2f}")
        print(f"  Coordinator utilization: {res.coordinator_stats['utilization']:.3f}")
        print(f"  Server 1 utilization: {res.server_stats['server_1']['utilization']:.3f}")

    def test_unstable_system(self):
        """An overloaded system should drive at least one queue near saturation."""
        cfg = SimulationConfig(
            arrival_rate=0.15,  # High arrival rate
            coordinator_service_rate=0.1,  # Lower service rate
            coordinator_exit_probability=0.2,  # Low exit probability
            server_service_rates=[0.1],
            server_routing_probs=[0.8],  # High routing to server
            warmup_time=500.0,
            simulation_time=2000.0,
            random_seed=42,
        )

        res = Simulator(cfg).run()

        coord_util = res.coordinator_stats["utilization"]
        srv_util = res.server_stats["server_1"]["utilization"]

        print(f"\nUnstable system results:")
        print(f"  Coordinator utilization: {res.coordinator_stats['utilization']:.3f}")
        print(f"  Server 1 utilization: {res.server_stats['server_1']['utilization']:.3f}")

        # Over a finite horizon an unstable queue's utilization approaches
        # (or exceeds) 1.0 — require at least one to be very high.
        assert coord_util > 0.9 or srv_util > 0.9

    def test_multiple_servers(self):
        """Three configured servers all appear in results and all see traffic."""
        cfg = SimulationConfig(
            arrival_rate=0.1,
            coordinator_service_rate=0.15,
            coordinator_exit_probability=0.4,
            server_service_rates=[0.2, 0.15, 0.1],  # 3 servers
            server_routing_probs=[0.2, 0.2, 0.2],  # Equal routing
            warmup_time=1000.0,
            simulation_time=5000.0,
            random_seed=42,
        )

        res = Simulator(cfg).run()

        # Exactly one stats entry per configured server.
        assert len(res.server_stats) == 3
        assert "server_1" in res.server_stats
        assert "server_2" in res.server_stats
        assert "server_3" in res.server_stats

        # Every server must have both received and finished at least one request.
        for server_id, stats in res.server_stats.items():
            assert stats["total_arrivals"] > 0
            assert stats["total_departures"] > 0
            print(f"  {server_id}: {stats['total_departures']} requests, "
                  f"utilization={stats['utilization']:.3f}")

    def test_reproducibility(self):
        """Two runs with the same seed produce identical statistics."""
        cfg = SimulationConfig(
            arrival_rate=0.1,
            coordinator_service_rate=0.15,
            coordinator_exit_probability=0.5,
            server_service_rates=[0.2],
            server_routing_probs=[0.5],
            warmup_time=500.0,
            simulation_time=2000.0,
            random_seed=123,
        )

        first = Simulator(cfg).run()
        second = Simulator(cfg).run()

        # Fixed seed ⇒ deterministic event stream ⇒ identical aggregates.
        assert first.total_requests_arrived == second.total_requests_arrived
        assert first.total_requests_completed == second.total_requests_completed
        assert abs(first.average_system_time - second.average_system_time) < 0.01