1#![forbid(unsafe_code)]
2use use_seed::SimulationSeed;
21
/// Summary statistics produced by a Monte Carlo run.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MonteCarloEstimate {
    /// Arithmetic mean of the observed values.
    pub mean: f64,
    /// Population variance of the observed values: the average squared
    /// deviation from `mean`, divided by `samples`.
    pub variance: f64,
    /// Number of observations that contributed to the estimate.
    pub samples: usize,
}
28
/// Reasons a Monte Carlo estimate cannot be produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonteCarloError {
    /// The requested number of samples was zero.
    InvalidSampleCount,
    /// An observation evaluated to NaN or an infinity.
    NonFiniteObservation,
}

impl std::fmt::Display for MonteCarloError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidSampleCount => "sample count must be greater than zero",
            Self::NonFiniteObservation => "observation produced a non-finite value",
        };
        f.write_str(message)
    }
}

// Implementing `Error` lets callers propagate this type with `?` into
// `Box<dyn std::error::Error>` and similar error containers.
impl std::error::Error for MonteCarloError {}
34
35pub fn monte_carlo_mean<F>(
36 seed: SimulationSeed,
37 samples: usize,
38 observation: F,
39) -> Result<MonteCarloEstimate, MonteCarloError>
40where
41 F: Fn(f64) -> f64,
42{
43 if samples == 0 {
44 return Err(MonteCarloError::InvalidSampleCount);
45 }
46
47 let mut values = Vec::with_capacity(samples);
48 for index in 0..samples {
49 let sample = seed.mix(index + 1).to_unit_f64();
50 let value = observation(sample);
51 if !value.is_finite() {
52 return Err(MonteCarloError::NonFiniteObservation);
53 }
54
55 values.push(value);
56 }
57
58 let mean = values.iter().sum::<f64>() / samples as f64;
59 let variance = values
60 .iter()
61 .map(|value| {
62 let diff = *value - mean;
63 diff * diff
64 })
65 .sum::<f64>()
66 / samples as f64;
67
68 Ok(MonteCarloEstimate {
69 mean,
70 variance,
71 samples,
72 })
73}
74
#[cfg(test)]
mod tests {
    use super::{MonteCarloError, monte_carlo_mean};
    use use_seed::SimulationSeed;

    /// A constant observation has that constant as its mean and no spread.
    #[test]
    fn estimates_constant_observations() {
        let seed = SimulationSeed::new(3);
        let result = monte_carlo_mean(seed, 5, |_sample| 2.0).unwrap();

        assert_eq!(result.mean, 2.0);
        assert_eq!(result.variance, 0.0);
        assert_eq!(result.samples, 5);
    }

    /// Two runs with an identical seed, sample count and observation agree exactly.
    #[test]
    fn stays_repeatable_for_the_same_seed() {
        let run = || monte_carlo_mean(SimulationSeed::new(11), 8, |sample| sample * 2.0).unwrap();

        let first = run();
        let second = run();

        assert_eq!(first, second);
    }

    /// Zero samples and non-finite observations are reported as errors.
    #[test]
    fn rejects_invalid_inputs() {
        let zero_samples = monte_carlo_mean(SimulationSeed::new(1), 0, |sample| sample);
        assert_eq!(zero_samples, Err(MonteCarloError::InvalidSampleCount));

        let nan_observation = monte_carlo_mean(SimulationSeed::new(1), 2, |_sample| f64::NAN);
        assert_eq!(nan_observation, Err(MonteCarloError::NonFiniteObservation));
    }
}