1use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
8use std::collections::HashMap;
9
/// Selector for how an ensemble combines its members' outputs.
///
/// `Hard` is the [`Default`] variant.
///
/// NOTE(review): not referenced in the visible part of this file; presumably `Hard` maps to
/// `EnsembleVoting::hard_vote` and `Soft` to `EnsembleVoting::soft_vote` — confirm at call sites.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum VotingStrategy {
    /// Vote over predicted class labels. This is the default.
    #[default]
    Hard,
    /// Vote over per-class probabilities.
    Soft,
}
23
24#[derive(Debug, Clone)]
29pub struct EnsembleVoting {
30 metadata: KernelMetadata,
31}
32
33impl Default for EnsembleVoting {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl EnsembleVoting {
40 #[must_use]
42 pub fn new() -> Self {
43 Self {
44 metadata: KernelMetadata::batch("ml/ensemble-voting", Domain::StatisticalML)
45 .with_description("Weighted majority voting ensemble")
46 .with_throughput(100_000)
47 .with_latency_us(10.0),
48 }
49 }
50
51 pub fn hard_vote(predictions: &[Vec<i32>], weights: Option<&[f64]>) -> Vec<i32> {
57 if predictions.is_empty() || predictions[0].is_empty() {
58 return Vec::new();
59 }
60
61 let n_classifiers = predictions.len();
62 let n_samples = predictions[0].len();
63
64 let default_weights: Vec<f64> = vec![1.0 / n_classifiers as f64; n_classifiers];
66 let weights = weights.unwrap_or(&default_weights);
67
68 (0..n_samples)
69 .map(|i| {
70 let mut class_weights: HashMap<i32, f64> = HashMap::new();
71
72 for (j, pred) in predictions.iter().enumerate() {
73 let class = pred[i];
74 *class_weights.entry(class).or_insert(0.0) += weights[j];
75 }
76
77 *class_weights
78 .iter()
79 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
80 .map(|(class, _)| class)
81 .unwrap_or(&0)
82 })
83 .collect()
84 }
85
86 pub fn soft_vote(probabilities: &[Vec<Vec<f64>>], weights: Option<&[f64]>) -> Vec<usize> {
93 if probabilities.is_empty() || probabilities[0].is_empty() {
94 return Vec::new();
95 }
96
97 let n_classifiers = probabilities.len();
98 let n_samples = probabilities[0].len();
99 let n_classes = probabilities[0][0].len();
100
101 let default_weights: Vec<f64> = vec![1.0 / n_classifiers as f64; n_classifiers];
103 let weights = weights.unwrap_or(&default_weights);
104
105 (0..n_samples)
106 .map(|sample_idx| {
107 let mut avg_probs = vec![0.0f64; n_classes];
109
110 for (classifier_idx, probs) in probabilities.iter().enumerate() {
111 for (class_idx, &prob) in probs[sample_idx].iter().enumerate() {
112 avg_probs[class_idx] += weights[classifier_idx] * prob;
113 }
114 }
115
116 avg_probs
118 .iter()
119 .enumerate()
120 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
121 .map(|(idx, _)| idx)
122 .unwrap_or(0)
123 })
124 .collect()
125 }
126
127 pub fn weighted_average(predictions: &[Vec<f64>], weights: Option<&[f64]>) -> Vec<f64> {
133 if predictions.is_empty() || predictions[0].is_empty() {
134 return Vec::new();
135 }
136
137 let n_regressors = predictions.len();
138 let n_samples = predictions[0].len();
139
140 let default_weights: Vec<f64> = vec![1.0; n_regressors];
142 let weights = weights.unwrap_or(&default_weights);
143 let weight_sum: f64 = weights.iter().sum();
144
145 (0..n_samples)
146 .map(|i| {
147 let weighted_sum: f64 = predictions
148 .iter()
149 .zip(weights.iter())
150 .map(|(preds, &w)| preds[i] * w)
151 .sum();
152 weighted_sum / weight_sum
153 })
154 .collect()
155 }
156
157 pub fn median_prediction(predictions: &[Vec<f64>]) -> Vec<f64> {
162 if predictions.is_empty() || predictions[0].is_empty() {
163 return Vec::new();
164 }
165
166 let n_samples = predictions[0].len();
167
168 (0..n_samples)
169 .map(|i| {
170 let mut values: Vec<f64> = predictions.iter().map(|p| p[i]).collect();
171 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
172
173 let n = values.len();
174 if n % 2 == 0 {
175 (values[n / 2 - 1] + values[n / 2]) / 2.0
176 } else {
177 values[n / 2]
178 }
179 })
180 .collect()
181 }
182}
183
184impl GpuKernel for EnsembleVoting {
185 fn metadata(&self) -> &KernelMetadata {
186 &self.metadata
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ensemble_voting_metadata() {
        let kernel = EnsembleVoting::new();
        assert_eq!(kernel.metadata().id, "ml/ensemble-voting");
        assert_eq!(kernel.metadata().domain, Domain::StatisticalML);
    }

    #[test]
    fn test_hard_vote() {
        // Three equally weighted classifiers: every sample has a strict majority (no ties).
        let predictions = vec![
            vec![0, 1, 0, 1, 0],
            vec![0, 0, 0, 1, 1],
            vec![1, 1, 0, 1, 0],
        ];

        let result = EnsembleVoting::hard_vote(&predictions, None);

        assert_eq!(result, vec![0, 1, 0, 1, 0]);
    }

    #[test]
    fn test_hard_vote_weighted() {
        let predictions = vec![vec![0, 0, 0], vec![1, 1, 1]];

        // The second classifier outweighs the first on every sample.
        let weights = vec![0.3, 0.7];
        let result = EnsembleVoting::hard_vote(&predictions, Some(&weights));

        assert_eq!(result, vec![1, 1, 1]);
    }

    #[test]
    fn test_soft_vote() {
        let probabilities = vec![
            // Classifier 1: per-sample [p(class 0), p(class 1)].
            vec![vec![0.9, 0.1], vec![0.4, 0.6], vec![0.5, 0.5]],
            // Classifier 2.
            vec![vec![0.8, 0.2], vec![0.3, 0.7], vec![0.2, 0.8]],
        ];

        let result = EnsembleVoting::soft_vote(&probabilities, None);

        assert_eq!(result, vec![0, 1, 1]);
    }

    #[test]
    fn test_weighted_average() {
        let predictions = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
        ];

        let result = EnsembleVoting::weighted_average(&predictions, None);

        assert_eq!(result.len(), 3);
        assert!((result[0] - 2.0).abs() < 0.01);
        assert!((result[1] - 3.0).abs() < 0.01);
        assert!((result[2] - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_weighted_average_explicit_weights() {
        let predictions = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let weights = vec![1.0, 3.0];

        let result = EnsembleVoting::weighted_average(&predictions, Some(&weights));

        // (1*1 + 3*3) / 4 = 2.5 and (1*2 + 3*4) / 4 = 3.5.
        assert!((result[0] - 2.5).abs() < 1e-12);
        assert!((result[1] - 3.5).abs() < 1e-12);
    }

    #[test]
    fn test_median_prediction() {
        // The 100.0 outlier in sample 1 must not affect the median.
        let predictions = vec![
            vec![1.0, 100.0, 3.0],
            vec![2.0, 2.0, 4.0],
            vec![3.0, 3.0, 5.0],
        ];

        let result = EnsembleVoting::median_prediction(&predictions);

        assert_eq!(result.len(), 3);
        assert!((result[0] - 2.0).abs() < 0.01);
        assert!((result[1] - 3.0).abs() < 0.01);
        assert!((result[2] - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_median_prediction_even_count() {
        // Four regressors: the median is the mean of the two middle values.
        let predictions = vec![vec![1.0], vec![2.0], vec![3.0], vec![10.0]];

        let result = EnsembleVoting::median_prediction(&predictions);

        assert!((result[0] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_empty_predictions() {
        let empty: Vec<Vec<i32>> = vec![];
        assert!(EnsembleVoting::hard_vote(&empty, None).is_empty());

        let no_samples: Vec<Vec<i32>> = vec![vec![]];
        assert!(EnsembleVoting::hard_vote(&no_samples, None).is_empty());

        let empty_probs: Vec<Vec<Vec<f64>>> = vec![];
        assert!(EnsembleVoting::soft_vote(&empty_probs, None).is_empty());

        let empty_reg: Vec<Vec<f64>> = vec![];
        assert!(EnsembleVoting::weighted_average(&empty_reg, None).is_empty());
        assert!(EnsembleVoting::median_prediction(&empty_reg).is_empty());
    }
}