uncertain/
expectation.rs

1use crate::Uncertain;
2use num_traits::{identities, Float};
3use rand_pcg::Pcg32;
4use std::error::Error;
5use std::fmt;
6
/// Number of samples drawn between two successive convergence checks.
const STEP: usize = 10;
/// Maximum number of batches of `STEP` samples drawn before giving up;
/// at most `STEP * MAXS` samples are taken in total.
const MAXS: usize = 1_000;
9
/// Information about a failed call to [`expect`](Uncertain::expect).
///
/// This struct allows introspection of a failed attempt to calculate
/// the expected value of an [`Uncertain`](Uncertain): the estimate that
/// was reached, how uncertain that estimate still is, and the precision
/// that was asked for.
///
/// Trait bounds live on the `impl` blocks rather than on the struct
/// itself, so only operations that need floating point arithmetic
/// require `F: Float`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConvergenceError<F> {
    /// Running mean of all samples drawn (the non-converged estimate).
    sample_mean: F,
    /// Running sum of squared deviations from the mean (Welford's `M2`).
    diff_sum: F,
    /// Number of samples drawn, stored as `F` for arithmetic.
    steps: F,
    /// Precision requested by the caller.
    precision: F,
}
24
25impl<F: Float> ConvergenceError<F> {
26    /// The expected value estimate obtained.
27    ///
28    /// This value is less precise than desired
29    /// and should be used with caution.
30    pub fn non_converged_value(&self) -> F {
31        self.sample_mean
32    }
33
34    /// The two sigma confidence interval around the
35    /// computed value.
36    ///
37    /// # Details
38    ///
39    /// Mathematically, this is an estimate for the
40    /// standard deviation of the expected value, i.e.
41    /// `2 * sqrt(var(E(x)))`.
42    ///
43    /// This value is calculated under the assumption
44    /// that samples from the original [`Uncertain`](Uncertain)
45    /// are [identically and independently distributed][iid].
46    ///
47    /// [iid]: https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables
48    pub fn two_sigma_error(&self) -> F {
49        let std = mean_standard_deviation(self.diff_sum, self.steps);
50        std + std
51    }
52
53    /// The precision which was originally mandated by the
54    /// call to [`Uncertain::expect`](Uncertain::expect).
55    pub fn desired_precision(&self) -> F {
56        self.precision
57    }
58}
59
60impl<F: Float + fmt::Display> fmt::Display for ConvergenceError<F> {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        write!(
63            f,
64            "Expected value {} +/- {} did not converge to desired precision {}",
65            self.non_converged_value(),
66            self.two_sigma_error(),
67            self.desired_precision()
68        )
69    }
70}
71
72impl<F: Float + fmt::Debug + fmt::Display> Error for ConvergenceError<F> {}
73
74fn mean_standard_deviation<F: Float>(diff_sum: F, steps: F) -> F {
75    diff_sum.sqrt() / steps // = sqrt( sigma^2 / n ) i.e. sqrt(var(E(x)))
76}
77
78/// Compute the sample expectation.
79pub fn compute<U>(src: &U, precision: U::Value) -> Result<U::Value, ConvergenceError<U::Value>>
80where
81    U: Uncertain + ?Sized,
82    U::Value: Float,
83{
84    let mut rng = Pcg32::new(0xcafef00dd15ea5e5, 0xa02bdbf7bb3c0a7);
85
86    let mut sample_mean = identities::zero();
87    let mut diff_sum = identities::zero();
88    let mut steps = identities::zero();
89
90    for batch in 0..MAXS {
91        for batch_step in 0..STEP {
92            let epoch = STEP * batch + batch_step;
93            let sample = src.sample(&mut rng, epoch);
94            let prev_sample_mean = sample_mean;
95
96            // Using Welford's online algorithm:
97            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
98            steps = steps + identities::one();
99            sample_mean = prev_sample_mean + (sample - prev_sample_mean) / steps;
100            diff_sum = diff_sum + (sample - prev_sample_mean) * (sample - sample_mean);
101        }
102
103        let std = mean_standard_deviation(diff_sum, steps);
104        if std + std <= precision {
105            return Ok(sample_mean);
106        }
107    }
108
109    Err(ConvergenceError {
110        sample_mean,
111        diff_sum,
112        steps,
113        precision,
114    })
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Distribution;
    use rand_distr::Normal;

    /// Unit-variance normals converge to within 0.1 of their mean.
    #[test]
    fn simple_expectation() {
        let means = [0.0, 1.0, 5.0, 17.0, 23525.108213];
        for &mean in &means {
            let x = Distribution::from(Normal::new(mean, 1.0).unwrap());

            let mu = compute(&x, 0.1).expect("a unit-variance normal should converge");
            assert!((mu - mean).abs() < 0.1, "{} is not within 0.1 of {}", mu, mean);
        }
    }

    /// A wide normal fails at a tight precision but succeeds at a loose one.
    #[test]
    fn failed_expectation() {
        let x = Distribution::from(Normal::new(0.0, 1000.0).unwrap());

        let err = compute(&x, 0.1).unwrap_err();
        assert!(err.two_sigma_error() > 0.1);

        let mu = compute(&x, 100.0).expect("precision 100 should be reachable");
        assert!(mu.abs() < 100.0);
    }

    /// The reported error matches the theoretical standard error of the mean.
    #[test]
    fn errors_are_correct() {
        // `Normal::new(mean, std_dev)` takes the standard deviation, not the variance.
        let std_devs: [f64; 5] = [1000.0, 5000.0, 10_000.0, 23452345.0, 23245.0];
        for &std_dev in &std_devs {
            let x = Distribution::from(Normal::new(0.0, std_dev).unwrap());
            let err = x.expect(0.1).unwrap_err();

            // one sigma should be std(x) / sqrt(N), N = STEP * MAXS;
            // two_sigma_error() reports two sigma, so divide by two
            let have_err = err.two_sigma_error() / 2.0;
            let want_err = std_dev / ((STEP * MAXS) as f64).sqrt();
            let tolerance = 0.01; // plus minus 1%
            assert!(
                (have_err - want_err).abs() / want_err.abs() < tolerance,
                "{} is not approximately {}",
                have_err,
                want_err
            );

            // the value reported by the error should still be good to
            // within the reported two sigma interval
            let val = err.non_converged_value();
            assert!(
                val.abs() < err.two_sigma_error(),
                "{} is not close enough to 0",
                val
            );

            // but the desired precision should not have been reached
            assert!(err.two_sigma_error() > err.desired_precision());
        }
    }
}