welford/
lib.rs

1use num_traits::{cast, Num, NumAssign, NumCast, Zero};
2
/// An online calculator for both mean and variance.
///
/// `T` is the type of the observed values and of the computed statistics; `W` is the type
/// used to accumulate weights and defaults to `usize` for the unweighted case.
///
/// References:
/// - https://doi.org/10.1080/00401706.1962.10490022
/// - https://stats.stackexchange.com/a/235151/146964
#[derive(Debug, Clone, PartialEq)]
pub struct Welford<T, W = usize> {
    /// Running mean; `None` until a value has been recorded.
    mean: Option<T>,
    /// Sum of the weights of all recorded values.
    total: W,
    /// Running weighted sum of squared deviations from the mean.
    msq: T,
}
13
14impl<T> Welford<T>
15where
16    T: Zero,
17{
18    /// Create a new unweighted Welford calculator.
19    pub fn new() -> Self {
20        Self {
21            mean: None,
22            total: 0,
23            msq: T::zero(),
24        }
25    }
26}
27
28impl<T> Default for Welford<T>
29where
30    T: Zero,
31{
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<T> Welford<T>
38where
39    T: Copy + Num + NumAssign + NumCast,
40{
41    pub fn push(&mut self, value: T) {
42        self.push_weighted(value, 1)
43    }
44}
45
46impl<T, W> Welford<T, W>
47where
48    T: Zero,
49    W: Zero,
50{
51    /// Create a new weighted Welford calculator.
52    pub fn with_weights() -> Self {
53        Self {
54            mean: None,
55            total: W::zero(),
56            msq: T::zero(),
57        }
58    }
59}
60
61impl<T, W> Welford<T, W>
62where
63    T: Copy + Num + NumAssign + NumCast,
64    W: Copy + Num + NumAssign + NumCast + PartialOrd,
65{
66    pub fn push_weighted(&mut self, value: T, weight: W) {
67        self.total += weight;
68
69        if self.mean.is_none() {
70            self.mean = Some(value);
71        }
72
73        // self.mean is Some(T) from here on.
74        let delta = value - self.mean.unwrap();
75
76        let total = cast(self.total).expect("failed to cast W to T");
77        let weighted_delta = delta * cast(weight).expect("failed to cast W to T");
78
79        *self.mean.as_mut().unwrap() += weighted_delta / total;
80
81        let delta2 = value - self.mean.unwrap();
82        self.msq += weighted_delta * delta2;
83    }
84
85    /// Mean.
86    pub fn mean(&self) -> Option<T> {
87        self.mean
88    }
89
90    /// Sample variance.
91    pub fn var(&self) -> Option<T> {
92        if self.total > W::one() {
93            let total: T = cast(self.total).expect("failed to cast W to T");
94            Some(self.msq / (total - T::one()))
95        } else {
96            None
97        }
98    }
99
100    fn merge(&mut self, other: Self) {
101        let weight = other.total;
102
103        if weight == W::zero() {
104            return;
105        } else if self.total == W::zero() {
106            *self = other;
107            return;
108        }
109
110        // self.mean is Some(T) from here on since totals have been updated.
111        // WARN: Probably unstable, see <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>.
112        let delta = other.mean.unwrap() - self.mean.unwrap();
113
114        let total = self.total + weight;
115        let weighted_delta = delta * cast(weight).expect("failed to cast W to T");
116
117        let mean_corr = weighted_delta / cast(total).expect("failed to cast W to T");
118        *self.mean.as_mut().unwrap() += mean_corr;
119
120        self.msq +=
121            other.msq + delta * cast(self.total).expect("failed to cast W to T") * mean_corr;
122
123        self.total = total;
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Unweighted pushes: mean and sample variance are checked after every value.
    #[test]
    fn test_welford() {
        let mut acc: Welford<f64> = Welford::default();
        assert_eq!(acc.mean(), None);
        assert_eq!(acc.var(), None);

        // (pushed value, expected mean, expected variance)
        let steps = [
            (1.0, Some(1.0), None),
            (3.0, Some(2.0), Some(2.0)),
            (5.0, Some(3.0), Some(4.0)),
        ];
        for &(value, mean, var) in &steps {
            acc.push(value);
            assert_eq!(acc.mean(), mean);
            assert_eq!(acc.var(), var);
        }
    }

    /// Weighted pushes: mean and sample variance are checked after every value.
    #[test]
    fn test_weighted_welford() {
        let mut acc: Welford<f64, f64> = Welford::with_weights();
        assert_eq!(acc.mean(), None);
        assert_eq!(acc.var(), None);

        // (pushed value, weight, expected mean, expected variance)
        let steps = [
            (1.0, 3.0, 1.0, 0.0),
            (3.0, 2.0, 1.8, 1.2),
            (5.0, 1.0, 2.3333333333333335, 2.6666666666666665),
        ];
        for &(value, weight, mean, var) in &steps {
            acc.push_weighted(value, weight);
            assert_eq!(acc.mean(), Some(mean));
            assert_eq!(acc.var(), Some(var));
        }
    }

    /// Merging two unweighted calculators yields the statistics of the combined data.
    #[test]
    fn test_merge() {
        let mut odds: Welford<f64> = Welford::new();
        let mut evens: Welford<f64> = Welford::new();

        for &x in &[1.0, 3.0, 5.0, 7.0] {
            odds.push(x);
        }
        for &x in &[2.0, 4.0, 6.0, 8.0] {
            evens.push(x);
        }

        odds.merge(evens);
        assert_eq!(odds.mean(), Some(4.5));
        assert_eq!(odds.var(), Some(6.0));
    }

    /// Merging two weighted calculators yields the statistics of the combined data.
    #[test]
    fn test_weighted_merge() {
        let mut left: Welford<f64, f64> = Welford::with_weights();
        let mut right: Welford<f64, f64> = Welford::with_weights();

        for &(value, weight) in &[(1.0, 4.0), (3.0, 3.0), (5.0, 2.0), (7.0, 1.0)] {
            left.push_weighted(value, weight);
        }
        for &(value, weight) in &[(2.0, 4.0), (4.0, 3.0), (6.0, 2.0), (8.0, 1.0)] {
            right.push_weighted(value, weight);
        }

        left.merge(right);
        assert_eq!(left.mean(), Some(3.5));
        assert_eq!(left.var(), Some(4.473684210526316));
    }
}