# test_random_utils.py (2.72 KB)
"""
Tests for random number generation utilities.
"""
import pytest
import math
from src.core.random_utils import exponential_random, set_seed, choice_random


class TestExponentialRandom:
    """Checks for ``exponential_random`` (exponentially distributed draws)."""

    def test_exponential_mean(self):
        """The sample mean of many draws should approximate 1 / rate."""
        set_seed(42)
        lam = 0.1
        target = 1 / lam  # theoretical mean: 10.0
        n_draws = 10000

        draws = [exponential_random(lam) for _ in range(n_draws)]
        observed = sum(draws) / len(draws)

        # Accept up to 5% relative deviation from the theoretical mean.
        relative_error = abs(observed - target) / target
        assert relative_error < 0.05

    def test_exponential_positive(self):
        """Every draw should be strictly greater than zero."""
        set_seed(42)
        draws = [exponential_random(0.5) for _ in range(1000)]

        # Collect offenders so a failure shows exactly which values were bad.
        non_positive = [value for value in draws if not value > 0]
        assert not non_positive

    def test_exponential_invalid_rate(self):
        """A zero or negative rate must be rejected with ValueError."""
        for bad_rate in (0, -0.5):
            with pytest.raises(ValueError):
                exponential_random(bad_rate)

    def test_exponential_reproducibility(self):
        """Re-seeding with the same value must replay the same sequence."""

        def draw_sequence():
            # Reset the generator, then take a short run of draws.
            set_seed(123)
            return [exponential_random(0.2) for _ in range(10)]

        first_run = draw_sequence()
        second_run = draw_sequence()
        assert first_run == second_run


class TestChoiceRandom:
    """Tests for probabilistic choice (``choice_random``)."""

    def test_choice_distribution(self):
        """Test that choice frequencies follow the probability distribution.

        A fixed seed keeps this statistical check deterministic.
        """
        set_seed(42)
        probs = [0.5, 0.3, 0.2]
        counts = [0] * len(probs)

        # Generate many samples
        n_samples = 10000
        for _ in range(n_samples):
            choice = choice_random(probs)
            # Validate the index before using it: a negative result such as
            # -1 would otherwise be silently tallied under the last category
            # by Python's negative indexing, hiding an out-of-range bug.
            assert 0 <= choice < len(probs)
            counts[choice] += 1

        # Check that frequencies are close to probabilities
        for i, prob in enumerate(probs):
            expected = prob * n_samples
            actual = counts[i]
            # Within 10% relative error (roughly 5 standard deviations for
            # the smallest category, p=0.2, at this sample size).
            assert abs(actual - expected) / expected < 0.1

    def test_choice_invalid_probabilities(self):
        """Test that probabilities not summing to 1 raise ValueError."""
        with pytest.raises(ValueError):
            choice_random([0.5, 0.3, 0.1])  # Sum = 0.9
        with pytest.raises(ValueError):
            choice_random([0.6, 0.6])  # Sum = 1.2

    def test_choice_valid_range(self):
        """Test that every choice is a valid index into ``probs``."""
        set_seed(42)
        probs = [0.25, 0.25, 0.25, 0.25]
        samples = [choice_random(probs) for _ in range(1000)]

        assert all(0 <= s < len(probs) for s in samples)