1use num_traits::{cast, Num, NumAssign, NumCast, Zero};
2
/// Running accumulator for the mean and variance of a stream of samples,
/// updated one observation at a time (Welford's online algorithm).
///
/// `T` is the sample type; `W` is the type used to accumulate observation
/// weights (a plain `usize` count by default).
#[derive(Debug, Clone, PartialEq)]
pub struct Welford<T, W = usize> {
    /// Running mean; `None` until the first observation is recorded.
    mean: Option<T>,
    /// Sum of the weights of all recorded observations.
    total: W,
    /// Running sum of weighted squared deviations from the mean; the
    /// variance is derived from this and `total`.
    msq: T,
}
13
14impl<T> Welford<T>
15where
16 T: Zero,
17{
18 pub fn new() -> Self {
20 Self {
21 mean: None,
22 total: 0,
23 msq: T::zero(),
24 }
25 }
26}
27
28impl<T> Default for Welford<T>
29where
30 T: Zero,
31{
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<T> Welford<T>
38where
39 T: Copy + Num + NumAssign + NumCast,
40{
41 pub fn push(&mut self, value: T) {
42 self.push_weighted(value, 1)
43 }
44}
45
46impl<T, W> Welford<T, W>
47where
48 T: Zero,
49 W: Zero,
50{
51 pub fn with_weights() -> Self {
53 Self {
54 mean: None,
55 total: W::zero(),
56 msq: T::zero(),
57 }
58 }
59}
60
61impl<T, W> Welford<T, W>
62where
63 T: Copy + Num + NumAssign + NumCast,
64 W: Copy + Num + NumAssign + NumCast + PartialOrd,
65{
66 pub fn push_weighted(&mut self, value: T, weight: W) {
67 self.total += weight;
68
69 if self.mean.is_none() {
70 self.mean = Some(value);
71 }
72
73 let delta = value - self.mean.unwrap();
75
76 let total = cast(self.total).expect("failed to cast W to T");
77 let weighted_delta = delta * cast(weight).expect("failed to cast W to T");
78
79 *self.mean.as_mut().unwrap() += weighted_delta / total;
80
81 let delta2 = value - self.mean.unwrap();
82 self.msq += weighted_delta * delta2;
83 }
84
85 pub fn mean(&self) -> Option<T> {
87 self.mean
88 }
89
90 pub fn var(&self) -> Option<T> {
92 if self.total > W::one() {
93 let total: T = cast(self.total).expect("failed to cast W to T");
94 Some(self.msq / (total - T::one()))
95 } else {
96 None
97 }
98 }
99
100 fn merge(&mut self, other: Self) {
101 let weight = other.total;
102
103 if weight == W::zero() {
104 return;
105 } else if self.total == W::zero() {
106 *self = other;
107 return;
108 }
109
110 let delta = other.mean.unwrap() - self.mean.unwrap();
113
114 let total = self.total + weight;
115 let weighted_delta = delta * cast(weight).expect("failed to cast W to T");
116
117 let mean_corr = weighted_delta / cast(total).expect("failed to cast W to T");
118 *self.mean.as_mut().unwrap() += mean_corr;
119
120 self.msq +=
121 other.msq + delta * cast(self.total).expect("failed to cast W to T") * mean_corr;
122
123 self.total = total;
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_welford() {
        let mut w = Welford::default();
        assert_eq!(w.mean(), None);
        assert_eq!(w.var(), None);

        w.push(1.0);
        assert_eq!(w.mean(), Some(1.0));
        assert_eq!(w.var(), None);

        w.push(3.0);
        assert_eq!(w.mean(), Some(2.0));
        assert_eq!(w.var(), Some(2.0));

        w.push(5.0);
        assert_eq!(w.mean(), Some(3.0));
        assert_eq!(w.var(), Some(4.0));
    }

    #[test]
    fn test_weighted_welford() {
        let mut w = Welford::with_weights();
        assert_eq!(w.mean(), None);
        assert_eq!(w.var(), None);

        w.push_weighted(1.0, 3.0);
        assert_eq!(w.mean(), Some(1.0));
        assert_eq!(w.var(), Some(0.0));

        w.push_weighted(3.0, 2.0);
        assert_eq!(w.mean(), Some(1.8));
        assert_eq!(w.var(), Some(1.2));

        w.push_weighted(5.0, 1.0);
        assert_eq!(w.mean(), Some(2.3333333333333335));
        assert_eq!(w.var(), Some(2.6666666666666665));
    }

    #[test]
    fn test_merge() {
        let mut w1 = Welford::new();
        let mut w2 = Welford::new();

        w1.push(1.0);
        w1.push(3.0);
        w1.push(5.0);
        w1.push(7.0);

        w2.push(2.0);
        w2.push(4.0);
        w2.push(6.0);
        w2.push(8.0);

        w1.merge(w2);
        assert_eq!(w1.mean(), Some(4.5));
        assert_eq!(w1.var(), Some(6.0));
    }

    #[test]
    fn test_weighted_merge() {
        let mut w1 = Welford::with_weights();
        let mut w2 = Welford::with_weights();

        w1.push_weighted(1.0, 4.0);
        w1.push_weighted(3.0, 3.0);
        w1.push_weighted(5.0, 2.0);
        w1.push_weighted(7.0, 1.0);

        w2.push_weighted(2.0, 4.0);
        w2.push_weighted(4.0, 3.0);
        w2.push_weighted(6.0, 2.0);
        w2.push_weighted(8.0, 1.0);

        w1.merge(w2);
        assert_eq!(w1.mean(), Some(3.5));
        assert_eq!(w1.var(), Some(4.473684210526316));
    }

    #[test]
    fn test_merge_with_empty() {
        // Merging an empty accumulator into a populated one is a no-op.
        let mut populated = Welford::new();
        populated.push(1.0);
        populated.push(3.0);
        populated.merge(Welford::new());
        assert_eq!(populated.mean(), Some(2.0));
        assert_eq!(populated.var(), Some(2.0));

        // Merging a populated accumulator into an empty one adopts its state.
        let mut empty = Welford::new();
        let mut other = Welford::new();
        other.push(1.0);
        other.push(3.0);
        empty.merge(other);
        assert_eq!(empty.mean(), Some(2.0));
        assert_eq!(empty.var(), Some(2.0));
    }
}