1use crate::Uncertain;
2use num_traits::{identities, Float};
3use rand_pcg::Pcg32;
4use std::error::Error;
5use std::fmt;
6
/// Number of samples drawn by `compute` between two consecutive convergence checks.
const STEP: usize = 10;
/// Maximum number of batches (each of `STEP` samples) `compute` draws before giving up.
const MAXS: usize = 1_000;
9
/// Error returned by [`compute`] when the sample mean does not reach the requested
/// precision within the sampling budget.
///
/// It carries the running accumulators at the point sampling stopped, so the caller can
/// still inspect the best estimate and its estimated uncertainty.
///
/// The type itself places no bounds on `F`; the accessors and trait impls add the bounds
/// they need.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvergenceError<F> {
    /// Running sample mean at the moment sampling stopped.
    sample_mean: F,
    /// Running sum of squared deviations from the mean (Welford's `M2`).
    diff_sum: F,
    /// Number of samples drawn, stored in the value type.
    steps: F,
    /// Precision that the caller requested.
    precision: F,
}
24
25impl<F: Float> ConvergenceError<F> {
26 pub fn non_converged_value(&self) -> F {
31 self.sample_mean
32 }
33
34 pub fn two_sigma_error(&self) -> F {
49 let std = mean_standard_deviation(self.diff_sum, self.steps);
50 std + std
51 }
52
53 pub fn desired_precision(&self) -> F {
56 self.precision
57 }
58}
59
60impl<F: Float + fmt::Display> fmt::Display for ConvergenceError<F> {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 write!(
63 f,
64 "Expected value {} +/- {} did not converge to desired precision {}",
65 self.non_converged_value(),
66 self.two_sigma_error(),
67 self.desired_precision()
68 )
69 }
70}
71
72impl<F: Float + fmt::Debug + fmt::Display> Error for ConvergenceError<F> {}
73
74fn mean_standard_deviation<F: Float>(diff_sum: F, steps: F) -> F {
75 diff_sum.sqrt() / steps }
77
78pub fn compute<U>(src: &U, precision: U::Value) -> Result<U::Value, ConvergenceError<U::Value>>
80where
81 U: Uncertain + ?Sized,
82 U::Value: Float,
83{
84 let mut rng = Pcg32::new(0xcafef00dd15ea5e5, 0xa02bdbf7bb3c0a7);
85
86 let mut sample_mean = identities::zero();
87 let mut diff_sum = identities::zero();
88 let mut steps = identities::zero();
89
90 for batch in 0..MAXS {
91 for batch_step in 0..STEP {
92 let epoch = STEP * batch + batch_step;
93 let sample = src.sample(&mut rng, epoch);
94 let prev_sample_mean = sample_mean;
95
96 steps = steps + identities::one();
99 sample_mean = prev_sample_mean + (sample - prev_sample_mean) / steps;
100 diff_sum = diff_sum + (sample - prev_sample_mean) * (sample - sample_mean);
101 }
102
103 let std = mean_standard_deviation(diff_sum, steps);
104 if std + std <= precision {
105 return Ok(sample_mean);
106 }
107 }
108
109 Err(ConvergenceError {
110 sample_mean,
111 diff_sum,
112 steps,
113 precision,
114 })
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Distribution;
    use rand_distr::Normal;

    /// Unit-std-dev normals converge to within 0.1 of their true mean.
    #[test]
    fn simple_expectation() {
        let means: [f64; 5] = [0.0, 1.0, 5.0, 17.0, 23525.108213];
        for &mean in &means {
            let x = Distribution::from(Normal::new(mean, 1.0).unwrap());

            let mu = compute(&x, 0.1).expect("unit std dev should converge to precision 0.1");
            assert!((mu - mean).abs() < 0.1, "{} is not within 0.1 of {}", mu, mean);
        }
    }

    /// A wide normal cannot hit a tight precision, but does hit a loose one.
    #[test]
    fn failed_expectation() {
        let x = Distribution::from(Normal::new(0.0, 1000.0).unwrap());

        // 2 * 1000 / sqrt(STEP * MAXS) = 20 > 0.1, so the sampling budget runs out.
        let err = compute(&x, 0.1).unwrap_err();
        assert!(err.two_sigma_error() > 0.1);

        let mu = compute(&x, 100.0).expect("precision 100 should be reachable for std dev 1000");
        assert!(mu.abs() < 100.0);
    }

    /// The reported error of a non-converged run matches the theoretical standard error.
    #[test]
    fn errors_are_correct() {
        // Standard deviations (not variances) of the zero-mean normals under test.
        let std_devs: [f64; 5] = [1000.0, 5000.0, 10_000.0, 23452345.0, 23245.0];
        for &std_dev in &std_devs {
            let x = Distribution::from(Normal::new(0.0, std_dev).unwrap());
            let err = x.expect(0.1).unwrap_err();

            // None of these converge, so all STEP * MAXS samples are drawn and the standard
            // error of the mean should be close to std_dev / sqrt(n).
            let have_err = err.two_sigma_error() / 2.0;
            let want_err = std_dev / ((STEP * MAXS) as f64).sqrt();
            let tolerance = 0.01;
            assert!(
                (have_err - want_err).abs() / want_err.abs() < tolerance,
                "{} is not approximately {}",
                have_err,
                want_err
            );

            // The estimate should lie within two standard errors of the true mean (0).
            let val = err.non_converged_value();
            assert!(
                val.abs() < err.two_sigma_error(),
                "{} is not close enough to 0",
                val
            );

            assert!(err.two_sigma_error() > err.desired_precision());
        }
    }
}