Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Tests for random number generation utilities.
"""
import pytest
import math
from src.core.random_utils import exponential_random, set_seed, choice_random
class TestExponentialRandom:
    """Checks on the values returned by ``exponential_random``."""

    def test_exponential_mean(self):
        """The mean of 10000 seeded draws lies within 5% of ``1 / rate``."""
        set_seed(42)
        rate = 0.1
        expected_mean = 1 / rate  # 10.0

        # Draw a large batch so the sample mean is a stable estimate.
        draws = [exponential_random(rate) for _ in range(10000)]
        observed_mean = sum(draws) / len(draws)

        # Relative deviation from the expected mean must stay below 5%.
        relative_error = abs(observed_mean - expected_mean) / expected_mean
        assert relative_error < 0.05

    def test_exponential_positive(self):
        """Each of 1000 seeded draws at rate 0.5 is strictly greater than zero."""
        set_seed(42)
        draws = [exponential_random(0.5) for _ in range(1000)]
        for value in draws:
            assert value > 0

    def test_exponential_invalid_rate(self):
        """A rate of zero or a negative rate raises ``ValueError``."""
        for bad_rate in (0, -0.5):
            with pytest.raises(ValueError):
                exponential_random(bad_rate)

    def test_exponential_reproducibility(self):
        """Re-seeding with the same value reproduces the same ten draws."""

        def draw_batch():
            # Reset the generator, then take a fresh batch of ten samples.
            set_seed(123)
            return [exponential_random(0.2) for _ in range(10)]

        first_batch = draw_batch()
        second_batch = draw_batch()
        assert first_batch == second_batch
class TestChoiceRandom:
    """Checks on the indices returned by ``choice_random``."""

    def test_choice_distribution(self):
        """Observed index frequencies stay within 10% of the given probabilities."""
        set_seed(42)
        probs = [0.5, 0.3, 0.2]
        n_samples = 10000

        # Tally how often each index is returned over many draws.
        tallies = [0] * len(probs)
        for _ in range(n_samples):
            tallies[choice_random(probs)] += 1

        # Each tally must be within 10% (relative) of its expected count.
        for prob, observed in zip(probs, tallies):
            expected = prob * n_samples
            assert abs(observed - expected) / expected < 0.1

    def test_choice_invalid_probabilities(self):
        """Probability lists that do not sum to one raise ``ValueError``."""
        invalid_inputs = (
            [0.5, 0.3, 0.1],  # Sum = 0.9
            [0.6, 0.6],  # Sum = 1.2
        )
        for bad_probs in invalid_inputs:
            with pytest.raises(ValueError):
                choice_random(bad_probs)

    def test_choice_valid_range(self):
        """Every returned index satisfies ``0 <= index < len(probs)``."""
        set_seed(42)
        probs = [0.25, 0.25, 0.25, 0.25]
        picks = [choice_random(probs) for _ in range(1000)]
        for pick in picks:
            assert 0 <= pick < len(probs)