// weighted_median/lib.rs

mod partition;
use partition::partition;
3
/// One weighted observation that can take part in a weighted-median computation.
///
/// Implementors expose a numeric value together with the weight that value carries.
pub trait Data {
    /// How much this observation counts toward the total weight.
    ///
    /// NOTE(review): the median search compares each side's share of the total weight
    /// against one half, so weights are presumably expected to be finite and
    /// non-negative — confirm before passing anything else.
    fn get_weight(&self) -> f64;

    /// The observation's value; the median is chosen from (or averaged between) these.
    fn get_value(&self) -> f64;
}
8
9#[inline]
10fn weight_sum<T: Data>(input: &mut [T]) -> f64 {
11    input
12        .into_iter()
13        .fold(0.0, |accum, item| accum + item.get_weight())
14}
15
16pub fn calculate<T: Data>(
17    data: &mut [T],
18    lower_weight_delta: f64,
19    higher_weight_delta: f64,
20) -> Option<f64> {
21    match data.len() {
22        0 => None,
23        1 => Some(data[0].get_value()),
24        2 => {
25            let lower = lower_weight_delta + data[0].get_weight();
26            let higher = data[1].get_weight() + higher_weight_delta;
27            if lower == higher {
28                Some((data[0].get_value() + data[1].get_value()) / 2.0)
29            } else if lower > higher {
30                Some(data[0].get_value())
31            } else {
32                Some(data[1].get_value())
33            }
34        }
35        _ => {
36            let (pivot_index, new_data, pivot_extra_weight) = partition(data, data.len() / 2);
37
38            let pivot_weight = new_data[pivot_index].get_weight() + pivot_extra_weight;
39            let lower_weight_sum = lower_weight_delta + weight_sum(&mut new_data[..pivot_index]);
40            let higher_weight_sum =
41                higher_weight_delta + weight_sum(&mut new_data[pivot_index + 1..]);
42            let weight_sum = lower_weight_sum + pivot_weight + higher_weight_sum;
43
44            if lower_weight_sum / weight_sum < 0.5 && higher_weight_sum / weight_sum < 0.5 {
45                Some(new_data[pivot_index].get_value())
46            } else if lower_weight_sum / weight_sum >= 0.5 {
47                let next_data = &mut new_data[..pivot_index + 1];
48                calculate(
49                    next_data,
50                    lower_weight_delta,
51                    higher_weight_sum + pivot_extra_weight,
52                )
53            } else {
54                let next_data = &mut new_data[pivot_index..];
55                calculate(
56                    next_data,
57                    lower_weight_sum + pivot_extra_weight,
58                    higher_weight_delta,
59                )
60            }
61        }
62    }
63}
64
65#[inline]
66pub fn weighted_median<T: Data>(data: &mut [T]) -> Option<f64> {
67    calculate(data, 0.0, 0.0)
68}
69
#[cfg(test)]
mod tests {
    use crate::{weighted_median, Data};

    /// Minimal `Data` implementation shared by every test below.
    #[derive(Debug, PartialEq)]
    pub struct TestData {
        value: f64,
        weight: f64,
    }

    impl Data for TestData {
        fn get_value(&self) -> f64 {
            self.value
        }

        fn get_weight(&self) -> f64 {
            self.weight
        }
    }

    /// Builds one `TestData` per `(value, weight)` pair and returns their weighted median.
    fn median_of(pairs: &[(i32, i32)]) -> Option<f64> {
        let mut items: Vec<TestData> = pairs
            .iter()
            .map(|&(value, weight)| TestData {
                value: f64::from(value),
                weight: f64::from(weight),
            })
            .collect();
        weighted_median(items.as_mut_slice())
    }

    #[test]
    fn empty_slice() {
        assert_eq!(weighted_median::<TestData>(&mut []), None);
        assert_eq!(median_of(&[]), None);
    }

    #[test]
    fn one_element() {
        assert_eq!(median_of(&[(7, 9)]), Some(7.0));
    }

    #[test]
    fn two_elements_different_weight() {
        // The heavier element wins regardless of input order.
        assert_eq!(median_of(&[(7, 1), (8, 2)]), Some(8.0));
        assert_eq!(median_of(&[(8, 2), (7, 1)]), Some(8.0));
    }

    #[test]
    fn two_elements_same_weight() {
        assert_eq!(median_of(&[(7, 1), (8, 1)]), Some(7.5));
    }

    #[test]
    fn three_elements_is_first_element() {
        assert_eq!(median_of(&[(2, 1), (3, 1), (1, 1)]), Some(2.0));
    }

    #[test]
    fn three_elements_is_middle_element() {
        assert_eq!(median_of(&[(3, 1), (2, 1), (1, 1)]), Some(2.0));
    }

    #[test]
    fn three_elements_is_last_element() {
        assert_eq!(median_of(&[(3, 1), (1, 1), (2, 1)]), Some(2.0));
    }

    #[test]
    fn three_elements_is_smallest_element() {
        assert_eq!(median_of(&[(3, 1), (2, 1), (1, 5)]), Some(1.0));
    }

    #[test]
    fn three_elements_is_biggest_element() {
        assert_eq!(median_of(&[(3, 5), (2, 1), (1, 1)]), Some(3.0));
    }

    #[test]
    fn three_elements_is_even() {
        // Total weight 4 splits exactly between values 2 and 3.
        assert_eq!(median_of(&[(3, 2), (2, 1), (1, 1)]), Some(2.5));
        assert_eq!(median_of(&[(1, 1), (2, 1), (3, 2)]), Some(2.5));
    }

    #[test]
    fn four_elements_is_even() {
        // Cumulative weights 49, 50, 75, 100: the half-way point falls between 2 and 3.
        assert_eq!(
            median_of(&[(1, 49), (2, 1), (3, 25), (1000, 25)]),
            Some(2.5)
        );
    }

    #[test]
    fn five_elements_is_pivot_value() {
        assert_eq!(
            median_of(&[(2, 5), (1, 5), (3, 1), (10, 8), (8, 2)]),
            Some(3.0)
        );
    }

    #[test]
    fn duplicated_values() {
        assert_eq!(median_of(&[(1, 1), (1, 1), (2, 2)]), Some(1.5));
        assert_eq!(median_of(&[(1, 2), (2, 1), (2, 1)]), Some(1.5));
    }
}