1mod partition;
2use partition::partition;
3
/// A weighted sample fed to the weighted-median search.
///
/// Implementors expose a value and a weight; the median code in this file only
/// reads them through these two accessors.
pub trait Data {
    /// Weight of this sample.
    ///
    /// The search sums these and compares each side's share of the total.
    /// NOTE(review): presumably weights are expected to be non-negative with a
    /// positive total — confirm with callers.
    fn get_weight(&self) -> f64;

    /// Value of this sample.
    ///
    /// The weighted median returned is one of these values, or the mean of two
    /// of them when both sides carry exactly equal weight.
    fn get_value(&self) -> f64;
}
8
9#[inline]
10fn weight_sum<T: Data>(input: &mut [T]) -> f64 {
11 input
12 .into_iter()
13 .fold(0.0, |accum, item| accum + item.get_weight())
14}
15
16pub fn calculate<T: Data>(
17 data: &mut [T],
18 lower_weight_delta: f64,
19 higher_weight_delta: f64,
20) -> Option<f64> {
21 match data.len() {
22 0 => None,
23 1 => Some(data[0].get_value()),
24 2 => {
25 let lower = lower_weight_delta + data[0].get_weight();
26 let higher = data[1].get_weight() + higher_weight_delta;
27 if lower == higher {
28 Some((data[0].get_value() + data[1].get_value()) / 2.0)
29 } else if lower > higher {
30 Some(data[0].get_value())
31 } else {
32 Some(data[1].get_value())
33 }
34 }
35 _ => {
36 let (pivot_index, new_data, pivot_extra_weight) = partition(data, data.len() / 2);
37
38 let pivot_weight = new_data[pivot_index].get_weight() + pivot_extra_weight;
39 let lower_weight_sum = lower_weight_delta + weight_sum(&mut new_data[..pivot_index]);
40 let higher_weight_sum =
41 higher_weight_delta + weight_sum(&mut new_data[pivot_index + 1..]);
42 let weight_sum = lower_weight_sum + pivot_weight + higher_weight_sum;
43
44 if lower_weight_sum / weight_sum < 0.5 && higher_weight_sum / weight_sum < 0.5 {
45 Some(new_data[pivot_index].get_value())
46 } else if lower_weight_sum / weight_sum >= 0.5 {
47 let next_data = &mut new_data[..pivot_index + 1];
48 calculate(
49 next_data,
50 lower_weight_delta,
51 higher_weight_sum + pivot_extra_weight,
52 )
53 } else {
54 let next_data = &mut new_data[pivot_index..];
55 calculate(
56 next_data,
57 lower_weight_sum + pivot_extra_weight,
58 higher_weight_delta,
59 )
60 }
61 }
62 }
63}
64
65#[inline]
66pub fn weighted_median<T: Data>(data: &mut [T]) -> Option<f64> {
67 calculate(data, 0.0, 0.0)
68}
69
#[cfg(test)]
mod tests {
    use crate::{weighted_median, Data};

    /// Plain `Data` implementation used as the element type in these tests.
    #[derive(Debug, PartialEq)]
    pub struct TestData {
        value: f64,
        weight: f64,
    }

    impl Data for TestData {
        fn get_value(&self) -> f64 {
            self.value
        }

        fn get_weight(&self) -> f64 {
            self.weight
        }
    }

    /// Runs `weighted_median` over integer `(value, weight)` pairs, each converted
    /// losslessly to `f64`.
    fn median_of(pairs: &[(i32, i32)]) -> Option<f64> {
        let mut samples: Vec<TestData> = pairs
            .iter()
            .map(|&(value, weight)| TestData {
                value: f64::from(value),
                weight: f64::from(weight),
            })
            .collect();
        weighted_median(samples.as_mut_slice())
    }

    #[test]
    fn empty_slice() {
        assert_eq!(median_of(&[]), None);
    }

    #[test]
    fn one_element() {
        assert_eq!(median_of(&[(7, 9)]), Some(7.0));
    }

    #[test]
    fn two_elements_different_weight() {
        assert_eq!(median_of(&[(7, 1), (8, 2)]), Some(8.0));
        assert_eq!(median_of(&[(8, 2), (7, 1)]), Some(8.0));
    }

    #[test]
    fn two_elements_same_weight() {
        assert_eq!(median_of(&[(7, 1), (8, 1)]), Some(7.5));
    }

    #[test]
    fn three_elements_is_first_element() {
        assert_eq!(median_of(&[(2, 1), (3, 1), (1, 1)]), Some(2.0));
    }

    #[test]
    fn three_elements_is_middle_element() {
        assert_eq!(median_of(&[(3, 1), (2, 1), (1, 1)]), Some(2.0));
    }

    #[test]
    fn three_elements_is_last_element() {
        assert_eq!(median_of(&[(3, 1), (1, 1), (2, 1)]), Some(2.0));
    }

    #[test]
    fn three_elements_is_smallest_element() {
        assert_eq!(median_of(&[(3, 1), (2, 1), (1, 5)]), Some(1.0));
    }

    #[test]
    fn three_elements_is_biggest_element() {
        assert_eq!(median_of(&[(3, 5), (2, 1), (1, 1)]), Some(3.0));
    }

    #[test]
    fn three_elements_is_even() {
        assert_eq!(median_of(&[(3, 2), (2, 1), (1, 1)]), Some(2.5));
        assert_eq!(median_of(&[(1, 1), (2, 1), (3, 2)]), Some(2.5));
    }

    #[test]
    fn four_elements_is_even() {
        assert_eq!(
            median_of(&[(1, 49), (2, 1), (3, 25), (1000, 25)]),
            Some(2.5)
        );
    }

    #[test]
    fn five_elements_is_pivot_value() {
        assert_eq!(
            median_of(&[(2, 5), (1, 5), (3, 1), (10, 8), (8, 2)]),
            Some(3.0)
        );
    }

    #[test]
    fn duplicated_values() {
        assert_eq!(median_of(&[(1, 1), (1, 1), (2, 2)]), Some(1.5));
        assert_eq!(median_of(&[(1, 2), (2, 1), (2, 1)]), Some(1.5));
    }
}