1use crate::error::{MetricsError, Result};
59use std::collections::HashMap;
60use std::fmt;
61
/// A single metric that participates in model selection.
#[derive(Debug, Clone)]
pub struct MetricCriterion {
    /// Metric name, matched against the names in each model's score list.
    pub name: String,
    /// Relative importance of this metric. Before aggregation, weights are
    /// divided by the total weight of the metrics a model actually reports.
    pub weight: f64,
    /// `true` if larger values are better (e.g. accuracy), `false` if smaller
    /// values are better (e.g. error or latency).
    pub higher_isbetter: bool,
}
72
/// Strategy for combining per-metric scores into a single model score.
///
/// Only the metrics a model actually reports are aggregated; for the weighted
/// strategies, the weights of those metrics are rescaled to sum to one.
// `PartialEq`/`Eq`/`Hash` let callers compare, match on equality and key maps
// by strategy; all variants are fieldless, so these derives are trivially sound.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// Weighted arithmetic mean of the direction-adjusted per-metric scores.
    WeightedSum,
    /// Weighted geometric mean of the per-metric score magnitudes.
    WeightedGeometricMean,
    /// Weighted harmonic mean of the per-metric score magnitudes.
    WeightedHarmonicMean,
    /// Aggregate by the lowest per-metric score.
    MinScore,
    /// Aggregate by the highest per-metric score.
    MaxScore,
}
87
/// The full configuration used by [`ModelSelector`] to score and filter models.
#[derive(Debug, Clone)]
pub struct SelectionCriteria {
    /// Metrics that contribute to a model's aggregated score.
    pub metrics: Vec<MetricCriterion>,
    /// How the per-metric scores are combined into a single score.
    pub aggregation: AggregationStrategy,
    /// Per-metric cut-offs keyed by metric name. A model must reach at least the
    /// threshold for higher-is-better metrics and at most the threshold for
    /// lower-is-better metrics. A model that does not report a thresholded
    /// metric is rejected. A threshold whose metric has no entry in `metrics`
    /// has no known direction and is not enforced.
    pub thresholds: HashMap<String, f64>,
}
98
99impl Default for SelectionCriteria {
100 fn default() -> Self {
101 Self {
102 metrics: Vec::new(),
103 aggregation: AggregationStrategy::WeightedSum,
104 thresholds: HashMap::new(),
105 }
106 }
107}
108
109pub struct ModelSelector {
111 criteria: SelectionCriteria,
112}
113
114impl Default for ModelSelector {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl ModelSelector {
121 pub fn new() -> Self {
123 Self {
124 criteria: SelectionCriteria::default(),
125 }
126 }
127
128 pub fn add_metric(&mut self, name: &str, weight: f64, higher_isbetter: bool) -> &mut Self {
130 self.criteria.metrics.push(MetricCriterion {
131 name: name.to_string(),
132 weight,
133 higher_isbetter,
134 });
135 self
136 }
137
138 pub fn with_aggregation(&mut self, strategy: AggregationStrategy) -> &mut Self {
140 self.criteria.aggregation = strategy;
141 self
142 }
143
144 pub fn add_threshold(&mut self, metricname: &str, threshold: f64) -> &mut Self {
146 self.criteria
147 .thresholds
148 .insert(metricname.to_string(), threshold);
149 self
150 }
151
152 pub fn select_best(&self, modelscores: &HashMap<String, Vec<(&str, f64)>>) -> Result<String> {
154 if modelscores.is_empty() {
155 return Err(MetricsError::InvalidInput("No models provided".to_string()));
156 }
157
158 let rankings = self.rank_models(modelscores)?;
159
160 rankings
161 .into_iter()
162 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
163 .map(|(model_name, _)| model_name)
164 .ok_or_else(|| MetricsError::ComputationError("No valid models found".to_string()))
165 }
166
167 pub fn rank_models(
169 &self,
170 modelscores: &HashMap<String, Vec<(&str, f64)>>,
171 ) -> Result<Vec<(String, f64)>> {
172 let mut rankings = Vec::new();
173
174 for (model_name, scores) in modelscores {
175 if let Ok(aggregated_score) = self.compute_aggregated_score(scores) {
176 if self.meets_thresholds(scores) {
177 rankings.push((model_name.clone(), aggregated_score));
178 }
179 }
180 }
181
182 rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
183 Ok(rankings)
184 }
185
186 pub fn find_pareto_optimal(
188 &self,
189 modelscores: &HashMap<String, Vec<(&str, f64)>>,
190 ) -> Vec<String> {
191 let mut pareto_optimal = Vec::new();
192
193 for (model_name, scores) in modelscores {
194 let mut is_dominated = false;
195
196 for (other_name, other_scores) in modelscores {
197 if model_name == other_name {
198 continue;
199 }
200
201 if self.dominates(scores, other_scores) {
202 is_dominated = true;
203 break;
204 }
205 }
206
207 if !is_dominated {
208 pareto_optimal.push(model_name.clone());
209 }
210 }
211
212 pareto_optimal
213 }
214
215 fn compute_aggregated_score(&self, scores: &[(&str, f64)]) -> Result<f64> {
217 let score_map: HashMap<&str, f64> = scores.iter().cloned().collect();
218
219 let mut normalized_scores = Vec::new();
221 let mut total_weight = 0.0;
222
223 for criterion in &self.criteria.metrics {
224 if let Some(&score) = score_map.get(criterion.name.as_str()) {
225 let normalized = if criterion.higher_isbetter {
226 score
227 } else {
228 -score };
230
231 normalized_scores.push((normalized, criterion.weight));
232 total_weight += criterion.weight;
233 }
234 }
235
236 if normalized_scores.is_empty() {
237 return Err(MetricsError::InvalidInput(
238 "No matching metrics found".to_string(),
239 ));
240 }
241
242 for (_, weight) in &mut normalized_scores {
244 *weight /= total_weight;
245 }
246
247 let aggregated = match self.criteria.aggregation {
249 AggregationStrategy::WeightedSum => normalized_scores
250 .iter()
251 .map(|(score, weight)| score * weight)
252 .sum(),
253 AggregationStrategy::WeightedGeometricMean => {
254 let product: f64 = normalized_scores
255 .iter()
256 .map(|(score, weight)| score.abs().powf(*weight))
257 .product();
258 product
259 }
260 AggregationStrategy::WeightedHarmonicMean => {
261 let weighted_reciprocal_sum: f64 = normalized_scores
262 .iter()
263 .map(|(score, weight)| weight / score.abs())
264 .sum();
265 total_weight / weighted_reciprocal_sum
266 }
267 AggregationStrategy::MinScore => normalized_scores
268 .iter()
269 .map(|(_, score)| *score)
270 .fold(f64::INFINITY, f64::min),
271 AggregationStrategy::MaxScore => normalized_scores
272 .iter()
273 .map(|(_, score)| *score)
274 .fold(f64::NEG_INFINITY, f64::max),
275 };
276
277 Ok(aggregated)
278 }
279
280 fn meets_thresholds(&self, scores: &[(&str, f64)]) -> bool {
282 let score_map: HashMap<&str, f64> = scores.iter().cloned().collect();
283
284 for (metricname, threshold) in &self.criteria.thresholds {
285 if let Some(&score) = score_map.get(metricname.as_str()) {
286 if let Some(criterion) =
288 self.criteria.metrics.iter().find(|c| c.name == *metricname)
289 {
290 let meets_threshold = if criterion.higher_isbetter {
291 score >= *threshold
292 } else {
293 score <= *threshold
294 };
295
296 if !meets_threshold {
297 return false;
298 }
299 }
300 } else {
301 return false;
303 }
304 }
305
306 true
307 }
308
309 fn dominates(&self, scoresa: &[(&str, f64)], scores_b: &[(&str, f64)]) -> bool {
311 let map_a: HashMap<&str, f64> = scores_b.iter().cloned().collect();
312 let map_b: HashMap<&str, f64> = scores_b.iter().cloned().collect();
313
314 let mut at_least_one_better = false;
315
316 for criterion in &self.criteria.metrics {
317 let metricname = criterion.name.as_str();
318
319 if let (Some(&score_a), Some(&score_b)) = (map_a.get(metricname), map_b.get(metricname))
320 {
321 let a_better_than_b = if criterion.higher_isbetter {
322 score_a > score_b
323 } else {
324 score_a < score_b
325 };
326
327 let a_worse_than_b = if criterion.higher_isbetter {
328 score_a < score_b
329 } else {
330 score_a > score_b
331 };
332
333 if a_worse_than_b {
334 return false; }
336
337 if a_better_than_b {
338 at_least_one_better = true;
339 }
340 }
341 }
342
343 at_least_one_better
344 }
345}
346
/// Outcome of a model-selection run, as produced by [`ModelSelectionBuilder::select`].
#[derive(Debug, Clone)]
pub struct SelectionResult {
    /// Name of the chosen model.
    pub selected_model: String,
    /// Aggregated score of `selected_model`.
    pub score: f64,
    /// `(model name, aggregated score)` for every model that passed the
    /// thresholds, ordered from best to worst score.
    pub rankings: Vec<(String, f64)>,
    /// Pareto-optimal model names, as computed by [`ModelSelector::find_pareto_optimal`].
    pub pareto_optimal: Vec<String>,
}
359
360impl fmt::Display for SelectionResult {
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 writeln!(f, "Model Selection Results")?;
363 writeln!(f, "======================")?;
364 writeln!(
365 f,
366 "Selected Model: {} (Score: {:.4})",
367 self.selected_model, self.score
368 )?;
369 writeln!(f)?;
370
371 writeln!(f, "Complete Rankings:")?;
372 writeln!(f, "------------------")?;
373 for (i, (model, score)) in self.rankings.iter().enumerate() {
374 writeln!(f, "{}: {} ({:.4})", i + 1, model, score)?;
375 }
376
377 writeln!(f)?;
378 writeln!(f, "Pareto Optimal Models: {:?}", self.pareto_optimal)?;
379
380 Ok(())
381 }
382}
383
/// Fluent, by-value builder around a [`ModelSelector`] that produces a
/// [`SelectionResult`].
pub struct ModelSelectionBuilder {
    /// The selector being configured.
    selector: ModelSelector,
}
388
389impl ModelSelectionBuilder {
390 pub fn new() -> Self {
392 Self {
393 selector: ModelSelector::new(),
394 }
395 }
396
397 pub fn metric(mut self, name: &str, weight: f64, higher_isbetter: bool) -> Self {
399 self.selector.add_metric(name, weight, higher_isbetter);
400 self
401 }
402
403 pub fn aggregation(mut self, strategy: AggregationStrategy) -> Self {
405 self.selector.with_aggregation(strategy);
406 self
407 }
408
409 pub fn threshold(mut self, metricname: &str, threshold: f64) -> Self {
411 self.selector.add_threshold(metricname, threshold);
412 self
413 }
414
415 pub fn select(
417 self,
418 modelscores: &HashMap<String, Vec<(&str, f64)>>,
419 ) -> Result<SelectionResult> {
420 let selected_model = self.selector.select_best(modelscores)?;
421 let rankings = self.selector.rank_models(modelscores)?;
422 let pareto_optimal = self.selector.find_pareto_optimal(modelscores);
423
424 let score = rankings
425 .iter()
426 .find(|(name, _)| name == &selected_model)
427 .map(|(_, score)| *score)
428 .unwrap_or(0.0);
429
430 Ok(SelectionResult {
431 selected_model,
432 score,
433 rankings,
434 pareto_optimal,
435 })
436 }
437}
438
439impl Default for ModelSelectionBuilder {
440 fn default() -> Self {
441 Self::new()
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Three models with distinct trade-offs across accuracy, precision and speed.
    fn create_test_scores() -> HashMap<String, Vec<(&'static str, f64)>> {
        let mut scores = HashMap::new();
        scores.insert(
            "model_a".to_string(),
            vec![("accuracy", 0.85), ("precision", 0.82), ("speed", 100.0)],
        );
        scores.insert(
            "model_b".to_string(),
            vec![("accuracy", 0.80), ("precision", 0.90), ("speed", 200.0)],
        );
        scores.insert(
            "model_c".to_string(),
            vec![("accuracy", 0.88), ("precision", 0.85), ("speed", 150.0)],
        );
        scores
    }

    #[test]
    fn test_basic_selection() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 0.6, true)
            .add_metric("precision", 0.4, true);

        // Weighted sums: a = 0.838, b = 0.840, c = 0.868.
        let best = selector.select_best(&scores).unwrap();
        assert_eq!(best, "model_c");
    }

    #[test]
    fn test_ranking() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 0.5, true)
            .add_metric("precision", 0.5, true);

        let rankings = selector.rank_models(&scores).unwrap();
        assert_eq!(rankings.len(), 3);

        // Scores must be non-increasing.
        for pair in rankings.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }

        // Equal weights: a = 0.835, b = 0.850, c = 0.865.
        let order: Vec<&str> = rankings.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(order, ["model_c", "model_b", "model_a"]);
    }

    #[test]
    fn test_pareto_optimal() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 1.0, true)
            .add_metric("speed", 1.0, true);

        let pareto = selector.find_pareto_optimal(&scores);
        // model_b has the best speed and model_c the best accuracy, so neither
        // can be dominated.
        assert!(pareto.iter().any(|m| m == "model_b"));
        assert!(pareto.iter().any(|m| m == "model_c"));
    }

    #[test]
    fn test_thresholds() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 1.0, true)
            .add_threshold("accuracy", 0.87);

        let rankings = selector.rank_models(&scores).unwrap();
        assert_eq!(rankings.len(), 1);
        assert_eq!(rankings[0].0, "model_c");
    }

    #[test]
    fn test_threshold_on_missing_metric_rejects_model() {
        let mut scores = create_test_scores();
        scores.insert("model_d".to_string(), vec![("accuracy", 0.99)]);

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 1.0, true)
            .add_metric("precision", 1.0, true)
            .add_threshold("precision", 0.0);

        // model_d does not report precision, so the threshold rejects it.
        let rankings = selector.rank_models(&scores).unwrap();
        assert_eq!(rankings.len(), 3);
        assert!(rankings.iter().all(|(name, _)| name != "model_d"));
    }

    #[test]
    fn test_different_aggregation_strategies() {
        let scores = create_test_scores();

        let strategies = [
            AggregationStrategy::WeightedSum,
            AggregationStrategy::WeightedGeometricMean,
            AggregationStrategy::WeightedHarmonicMean,
            AggregationStrategy::MinScore,
            AggregationStrategy::MaxScore,
        ];

        for strategy in &strategies {
            let mut selector = ModelSelector::new();
            selector
                .add_metric("accuracy", 0.5, true)
                .add_metric("precision", 0.5, true)
                .with_aggregation(*strategy);

            let best = selector.select_best(&scores).unwrap();
            assert!(scores.contains_key(&best));
        }
    }

    #[test]
    fn test_builder_pattern() {
        let scores = create_test_scores();

        let result = ModelSelectionBuilder::new()
            .metric("accuracy", 0.6, true)
            .metric("precision", 0.4, true)
            .threshold("accuracy", 0.8)
            .aggregation(AggregationStrategy::WeightedSum)
            .select(&scores)
            .unwrap();

        assert_eq!(result.selected_model, "model_c");
        assert!((result.score - 0.868).abs() < 1e-9);
        // All three models reach the 0.8 accuracy threshold.
        assert_eq!(result.rankings.len(), 3);
        assert_eq!(result.rankings[0].0, result.selected_model);
        assert!(result.pareto_optimal.iter().any(|m| m == "model_c"));
    }

    #[test]
    fn test_display_summary() {
        let result = ModelSelectionBuilder::new()
            .metric("accuracy", 1.0, true)
            .select(&create_test_scores())
            .unwrap();

        let text = result.to_string();
        assert!(text.contains("Selected Model: model_c (Score: 0.8800)"));
        assert!(text.contains("1: model_c (0.8800)"));
    }

    #[test]
    fn test_empty_models() {
        let scores = HashMap::new();
        let selector = ModelSelector::new();

        assert!(selector.select_best(&scores).is_err());
    }

    #[test]
    fn test_no_matching_metrics() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector.add_metric("recall", 1.0, true);

        // No model reports "recall", so nothing can be ranked.
        assert!(selector.rank_models(&scores).unwrap().is_empty());
        assert!(selector.select_best(&scores).is_err());
    }

    #[test]
    fn test_minimization_metrics() {
        let mut scores = HashMap::new();
        scores.insert("model_a".to_string(), vec![("error", 0.1), ("time", 5.0)]);
        scores.insert("model_b".to_string(), vec![("error", 0.2), ("time", 3.0)]);

        let mut selector = ModelSelector::new();
        selector
            .add_metric("error", 0.7, false)
            .add_metric("time", 0.3, false);

        // Negated weighted sums: a = -1.57, b = -1.04, so b wins.
        let best = selector.select_best(&scores).unwrap();
        assert_eq!(best, "model_b");
    }
}