"""Tests for random number generation utilities."""

import math

import pytest

from src.core.random_utils import exponential_random, set_seed, choice_random


class TestExponentialRandom:
    """Tests for exponential random variable generation."""

    def test_exponential_mean(self):
        """Test that exponential random variable has correct mean."""
        set_seed(42)
        rate = 0.1
        expected_mean = 1 / rate  # 10.0

        # Generate many samples
        samples = [exponential_random(rate) for _ in range(10000)]
        actual_mean = sum(samples) / len(samples)

        # For Exp(rate) the standard deviation equals the mean (1 / rate), so
        # the standard error over 10000 samples is 10 / 100 = 0.1. The 5%
        # relative tolerance (0.5) is therefore ~5 standard errors wide.
        assert actual_mean == pytest.approx(expected_mean, rel=0.05)

    def test_exponential_positive(self):
        """Test that all generated values are positive."""
        set_seed(42)
        samples = [exponential_random(0.5) for _ in range(1000)]
        assert all(x > 0 for x in samples)

    # Each invalid rate is its own test case so one failure cannot mask another.
    @pytest.mark.parametrize("rate", [0, -0.5])
    def test_exponential_invalid_rate(self, rate):
        """Test that a non-positive rate raises ValueError."""
        with pytest.raises(ValueError):
            exponential_random(rate)

    def test_exponential_reproducibility(self):
        """Test that setting seed produces reproducible results."""
        set_seed(123)
        samples1 = [exponential_random(0.2) for _ in range(10)]

        set_seed(123)
        samples2 = [exponential_random(0.2) for _ in range(10)]

        assert samples1 == samples2


class TestChoiceRandom:
    """Tests for probabilistic choice."""

    def test_choice_distribution(self):
        """Test that choices follow probability distribution."""
        set_seed(42)
        probs = [0.5, 0.3, 0.2]
        n_samples = 10000

        # Sized from probs so the tally cannot drift out of sync with it.
        # Plain list indexing (rather than a Counter) makes an index past the
        # end fail loudly with IndexError.
        counts = [0] * len(probs)
        for _ in range(n_samples):
            counts[choice_random(probs)] += 1

        # Check that frequencies are close to probabilities (within 10%).
        # The tightest case, p = 0.2, has a binomial std of
        # sqrt(10000 * 0.2 * 0.8) = 40 against a tolerance of 200 (~5 std).
        for i, (prob, actual) in enumerate(zip(probs, counts)):
            expected = prob * n_samples
            assert actual == pytest.approx(expected, rel=0.1), f"index {i}"

    # Each invalid distribution is its own test case so one failure cannot
    # mask another.
    @pytest.mark.parametrize(
        "probs",
        [
            [0.5, 0.3, 0.1],  # Sum = 0.9
            [0.6, 0.6],  # Sum = 1.2
        ],
    )
    def test_choice_invalid_probabilities(self, probs):
        """Test that probabilities not summing to 1 raise ValueError."""
        with pytest.raises(ValueError):
            choice_random(probs)

    def test_choice_valid_range(self):
        """Test that choice is always in valid range."""
        set_seed(42)
        probs = [0.25, 0.25, 0.25, 0.25]
        samples = [choice_random(probs) for _ in range(1000)]
        assert all(0 <= s < len(probs) for s in samples)