scirs2_metrics/selection/mod.rs

//! Automated model selection based on multiple metrics
//!
//! This module provides utilities for automatically selecting the best model
//! from a set of candidates based on multiple evaluation metrics.
//!
//! # Features
//!
//! - **Multi-metric evaluation**: Combine multiple metrics with custom weights
//! - **Flexible scoring**: Support different aggregation strategies
//! - **Pareto optimal selection**: Find models that are not dominated by others
//! - **Cross-validation integration**: Work with aggregated cross-validation scores for robust selection
//! - **Custom criteria**: Define custom selection criteria such as per-metric thresholds
//!
//! # Examples
//!
//! ## Basic Model Selection
//!
//! ```
//! use scirs2_metrics::selection::ModelSelector;
//! use std::collections::HashMap;
//!
//! // Define models and their metric scores
//! let mut modelscores = HashMap::new();
//! modelscores.insert("model_a".to_string(), vec![("accuracy", 0.85), ("precision", 0.82)]);
//! modelscores.insert("model_b".to_string(), vec![("accuracy", 0.80), ("precision", 0.90)]);
//! modelscores.insert("model_c".to_string(), vec![("accuracy", 0.88), ("precision", 0.85)]);
//!
//! // Create selector with weighted criteria
//! let mut selector = ModelSelector::new();
//! selector.add_metric("accuracy", 0.6, true)  // 60% weight, higher is better
//!         .add_metric("precision", 0.4, true); // 40% weight, higher is better
//!
//! // Select best model
//! let best_model = selector.select_best(&modelscores).unwrap();
//! println!("Best model: {}", best_model);
//! ```
//!
//! ## Pareto Optimal Selection
//!
//! ```
//! use scirs2_metrics::selection::ModelSelector;
//! use std::collections::HashMap;
//!
//! let mut modelscores = HashMap::new();
//! modelscores.insert("model_a".to_string(), vec![("accuracy", 0.85), ("speed", 100.0)]);
//! modelscores.insert("model_b".to_string(), vec![("accuracy", 0.80), ("speed", 200.0)]);
//! modelscores.insert("model_c".to_string(), vec![("accuracy", 0.90), ("speed", 50.0)]);
//!
//! let mut selector = ModelSelector::new();
//! selector
//!     .add_metric("accuracy", 1.0, true)   // higher is better
//!     .add_metric("speed", 1.0, true);     // higher is better (faster inference)
//!
//! let pareto_optimal = selector.find_pareto_optimal(&modelscores);
//! println!("Pareto optimal models: {:?}", pareto_optimal);
//! ```
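//!
//! ## Threshold-Based Selection
//!
//! A minimal sketch combining the builder API with a per-metric threshold
//! (model names and scores are illustrative):
//!
//! ```
//! use scirs2_metrics::selection::{AggregationStrategy, ModelSelectionBuilder};
//! use std::collections::HashMap;
//!
//! let mut modelscores = HashMap::new();
//! modelscores.insert("model_a".to_string(), vec![("accuracy", 0.85), ("recall", 0.70)]);
//! modelscores.insert("model_b".to_string(), vec![("accuracy", 0.80), ("recall", 0.92)]);
//!
//! let result = ModelSelectionBuilder::new()
//!     .metric("accuracy", 0.5, true)
//!     .metric("recall", 0.5, true)
//!     .threshold("accuracy", 0.75) // discard models below 0.75 accuracy
//!     .aggregation(AggregationStrategy::WeightedSum)
//!     .select(&modelscores)
//!     .unwrap();
//! println!("{}", result);
//! ```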

use crate::error::{MetricsError, Result};
use std::collections::HashMap;
use std::fmt;

/// Represents a metric with its weight and optimization direction
#[derive(Debug, Clone)]
pub struct MetricCriterion {
    /// Name of the metric
    pub name: String,
    /// Relative weight of the metric in the final score (weights are
    /// normalized to sum to 1.0 across all configured metrics)
    pub weight: f64,
    /// Whether higher values are better
    pub higher_isbetter: bool,
}

/// Aggregation strategies for combining multiple metrics
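///
/// # Examples
///
/// A minimal sketch showing that the choice of strategy changes which model
/// wins (model names and scores are illustrative):
///
/// ```
/// use scirs2_metrics::selection::{AggregationStrategy, ModelSelector};
/// use std::collections::HashMap;
///
/// let mut modelscores = HashMap::new();
/// // A balanced model vs. one with a strong and a weak metric.
/// modelscores.insert("balanced".to_string(), vec![("accuracy", 0.85), ("recall", 0.85)]);
/// modelscores.insert("skewed".to_string(), vec![("accuracy", 0.99), ("recall", 0.72)]);
///
/// let mut conservative = ModelSelector::new();
/// conservative
///     .add_metric("accuracy", 0.5, true)
///     .add_metric("recall", 0.5, true)
///     .with_aggregation(AggregationStrategy::MinScore);
/// // MinScore judges each model by its weakest metric, favoring "balanced"
/// // (0.85) over "skewed" (0.72).
/// assert_eq!(conservative.select_best(&modelscores).unwrap(), "balanced");
/// ```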
#[derive(Debug, Clone, Copy)]
pub enum AggregationStrategy {
    /// Weighted sum of normalized scores
    WeightedSum,
    /// Weighted geometric mean
    WeightedGeometricMean,
    /// Weighted harmonic mean
    WeightedHarmonicMean,
    /// Minimum score across all metrics (conservative)
    MinScore,
    /// Maximum score across all metrics (optimistic)
    MaxScore,
}

/// Model selection criteria configuration
#[derive(Debug, Clone)]
pub struct SelectionCriteria {
    /// List of metrics to consider
    pub metrics: Vec<MetricCriterion>,
    /// Strategy for aggregating metric scores
    pub aggregation: AggregationStrategy,
    /// Per-metric thresholds a model must meet (a minimum for
    /// higher-is-better metrics, a maximum otherwise)
    pub thresholds: HashMap<String, f64>,
}

impl Default for SelectionCriteria {
    fn default() -> Self {
        Self {
            metrics: Vec::new(),
            aggregation: AggregationStrategy::WeightedSum,
            thresholds: HashMap::new(),
        }
    }
}

/// Main model selector that evaluates and ranks models
pub struct ModelSelector {
    criteria: SelectionCriteria,
}

impl Default for ModelSelector {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelSelector {
    /// Creates a new model selector
    pub fn new() -> Self {
        Self {
            criteria: SelectionCriteria::default(),
        }
    }

    /// Adds a metric to the selection criteria
    pub fn add_metric(&mut self, name: &str, weight: f64, higher_isbetter: bool) -> &mut Self {
        self.criteria.metrics.push(MetricCriterion {
            name: name.to_string(),
            weight,
            higher_isbetter,
        });
        self
    }

    /// Sets the aggregation strategy
    pub fn with_aggregation(&mut self, strategy: AggregationStrategy) -> &mut Self {
        self.criteria.aggregation = strategy;
        self
    }

    /// Adds a threshold for a specific metric
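    ///
    /// For higher-is-better metrics the threshold is a minimum; for
    /// lower-is-better metrics it is a maximum. Models that miss any
    /// threshold are excluded from `rank_models` and `select_best` results.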
    pub fn add_threshold(&mut self, metricname: &str, threshold: f64) -> &mut Self {
        self.criteria
            .thresholds
            .insert(metricname.to_string(), threshold);
        self
    }

    /// Selects the best model from a set of candidates
    pub fn select_best(&self, modelscores: &HashMap<String, Vec<(&str, f64)>>) -> Result<String> {
        if modelscores.is_empty() {
            return Err(MetricsError::InvalidInput("No models provided".to_string()));
        }

        let rankings = self.rank_models(modelscores)?;

        rankings
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(model_name, _)| model_name)
            .ok_or_else(|| MetricsError::ComputationError("No valid models found".to_string()))
    }

    /// Ranks all models and returns them sorted by score (descending)
    pub fn rank_models(
        &self,
        modelscores: &HashMap<String, Vec<(&str, f64)>>,
    ) -> Result<Vec<(String, f64)>> {
        let mut rankings = Vec::new();

        for (model_name, scores) in modelscores {
            // Models that fail a threshold or report none of the configured
            // metrics are silently excluded from the ranking.
            if let Ok(aggregated_score) = self.compute_aggregated_score(scores) {
                if self.meets_thresholds(scores) {
                    rankings.push((model_name.clone(), aggregated_score));
                }
            }
        }

        rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(rankings)
    }

    /// Finds Pareto optimal models (not dominated by any other model)
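    ///
    /// A model is Pareto optimal when no other candidate is at least as good
    /// on every configured metric and strictly better on at least one. For
    /// example, a model with accuracy 0.88 and speed 150 dominates one with
    /// accuracy 0.85 and speed 100 when both metrics are higher-is-better.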
    pub fn find_pareto_optimal(
        &self,
        modelscores: &HashMap<String, Vec<(&str, f64)>>,
    ) -> Vec<String> {
        let mut pareto_optimal = Vec::new();

        for (model_name, scores) in modelscores {
            let mut is_dominated = false;

            for (other_name, other_scores) in modelscores {
                if model_name == other_name {
                    continue;
                }

                // The current model is dominated if any other model dominates it.
                if self.dominates(other_scores, scores) {
                    is_dominated = true;
                    break;
                }
            }

            if !is_dominated {
                pareto_optimal.push(model_name.clone());
            }
        }

        pareto_optimal
    }

    /// Computes the aggregated score for a model based on the selection criteria
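    ///
    /// Lower-is-better metrics are sign-flipped so that "higher is better"
    /// holds uniformly before aggregation. The geometric and harmonic means
    /// operate on absolute values, so they are most meaningful when all
    /// oriented scores are positive.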
    fn compute_aggregated_score(&self, scores: &[(&str, f64)]) -> Result<f64> {
        let score_map: HashMap<&str, f64> = scores.iter().cloned().collect();

        // Orient scores so that higher is always better
        let mut normalized_scores = Vec::new();
        let mut total_weight = 0.0;

        for criterion in &self.criteria.metrics {
            if let Some(&score) = score_map.get(criterion.name.as_str()) {
                let normalized = if criterion.higher_isbetter {
                    score
                } else {
                    -score // Flip for minimization metrics
                };

                normalized_scores.push((normalized, criterion.weight));
                total_weight += criterion.weight;
            }
        }

        if normalized_scores.is_empty() {
            return Err(MetricsError::InvalidInput(
                "No matching metrics found".to_string(),
            ));
        }

        // Normalize weights so they sum to 1.0
        for (_, weight) in &mut normalized_scores {
            *weight /= total_weight;
        }

        // Apply aggregation strategy
        let aggregated = match self.criteria.aggregation {
            AggregationStrategy::WeightedSum => normalized_scores
                .iter()
                .map(|(score, weight)| score * weight)
                .sum(),
            AggregationStrategy::WeightedGeometricMean => normalized_scores
                .iter()
                .map(|(score, weight)| score.abs().powf(*weight))
                .product(),
            AggregationStrategy::WeightedHarmonicMean => {
                let weighted_reciprocal_sum: f64 = normalized_scores
                    .iter()
                    .map(|(score, weight)| weight / score.abs())
                    .sum();
                // Weights were normalized above, so their sum is 1.0.
                1.0 / weighted_reciprocal_sum
            }
            AggregationStrategy::MinScore => normalized_scores
                .iter()
                .map(|(score, _)| *score)
                .fold(f64::INFINITY, f64::min),
            AggregationStrategy::MaxScore => normalized_scores
                .iter()
                .map(|(score, _)| *score)
                .fold(f64::NEG_INFINITY, f64::max),
        };

        Ok(aggregated)
    }

    /// Checks if a model meets all threshold requirements
    fn meets_thresholds(&self, scores: &[(&str, f64)]) -> bool {
        let score_map: HashMap<&str, f64> = scores.iter().cloned().collect();

        for (metricname, threshold) in &self.criteria.thresholds {
            if let Some(&score) = score_map.get(metricname.as_str()) {
                // Find the metric criterion to check optimization direction
                if let Some(criterion) =
                    self.criteria.metrics.iter().find(|c| c.name == *metricname)
                {
                    let meets_threshold = if criterion.higher_isbetter {
                        score >= *threshold
                    } else {
                        score <= *threshold
                    };

                    if !meets_threshold {
                        return false;
                    }
                }
            } else {
                // Metric not found, consider as not meeting threshold
                return false;
            }
        }

        true
    }

    /// Checks if model A dominates model B (Pareto dominance)
    fn dominates(&self, scores_a: &[(&str, f64)], scores_b: &[(&str, f64)]) -> bool {
        let map_a: HashMap<&str, f64> = scores_a.iter().cloned().collect();
        let map_b: HashMap<&str, f64> = scores_b.iter().cloned().collect();

        let mut at_least_one_better = false;

        for criterion in &self.criteria.metrics {
            let metricname = criterion.name.as_str();

            if let (Some(&score_a), Some(&score_b)) = (map_a.get(metricname), map_b.get(metricname))
            {
                let a_better_than_b = if criterion.higher_isbetter {
                    score_a > score_b
                } else {
                    score_a < score_b
                };

                let a_worse_than_b = if criterion.higher_isbetter {
                    score_a < score_b
                } else {
                    score_a > score_b
                };

                if a_worse_than_b {
                    return false; // A is worse in at least one metric
                }

                if a_better_than_b {
                    at_least_one_better = true;
                }
            }
        }

        at_least_one_better
    }
}

/// Represents the result of model selection with detailed information
#[derive(Debug, Clone)]
pub struct SelectionResult {
    /// Name of the selected model
    pub selected_model: String,
    /// Final aggregated score
    pub score: f64,
    /// All models ranked by score
    pub rankings: Vec<(String, f64)>,
    /// Pareto optimal models
    pub pareto_optimal: Vec<String>,
}

impl fmt::Display for SelectionResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Model Selection Results")?;
        writeln!(f, "=======================")?;
        writeln!(
            f,
            "Selected Model: {} (Score: {:.4})",
            self.selected_model, self.score
        )?;
        writeln!(f)?;

        writeln!(f, "Complete Rankings:")?;
        writeln!(f, "------------------")?;
        for (i, (model, score)) in self.rankings.iter().enumerate() {
            writeln!(f, "{}: {} ({:.4})", i + 1, model, score)?;
        }

        writeln!(f)?;
        writeln!(f, "Pareto Optimal Models: {:?}", self.pareto_optimal)?;

        Ok(())
    }
}

/// Builder for creating complex model selection scenarios
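///
/// # Examples
///
/// A minimal sketch of the fluent API (model names and scores are
/// illustrative):
///
/// ```
/// use scirs2_metrics::selection::{AggregationStrategy, ModelSelectionBuilder};
/// use std::collections::HashMap;
///
/// let mut modelscores = HashMap::new();
/// modelscores.insert("model_a".to_string(), vec![("accuracy", 0.85)]);
/// modelscores.insert("model_b".to_string(), vec![("accuracy", 0.80)]);
///
/// let result = ModelSelectionBuilder::new()
///     .metric("accuracy", 1.0, true)
///     .aggregation(AggregationStrategy::WeightedSum)
///     .select(&modelscores)
///     .unwrap();
/// assert_eq!(result.selected_model, "model_a");
/// ```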
pub struct ModelSelectionBuilder {
    selector: ModelSelector,
}

impl ModelSelectionBuilder {
    /// Creates a new builder
    pub fn new() -> Self {
        Self {
            selector: ModelSelector::new(),
        }
    }

    /// Adds a metric with weight and direction
    pub fn metric(mut self, name: &str, weight: f64, higher_isbetter: bool) -> Self {
        self.selector.add_metric(name, weight, higher_isbetter);
        self
    }

    /// Sets the aggregation strategy
    pub fn aggregation(mut self, strategy: AggregationStrategy) -> Self {
        self.selector.with_aggregation(strategy);
        self
    }

    /// Adds a threshold for a metric
    pub fn threshold(mut self, metricname: &str, threshold: f64) -> Self {
        self.selector.add_threshold(metricname, threshold);
        self
    }

    /// Builds the selector and performs complete model selection
    pub fn select(
        self,
        modelscores: &HashMap<String, Vec<(&str, f64)>>,
    ) -> Result<SelectionResult> {
        let selected_model = self.selector.select_best(modelscores)?;
        let rankings = self.selector.rank_models(modelscores)?;
        let pareto_optimal = self.selector.find_pareto_optimal(modelscores);

        let score = rankings
            .iter()
            .find(|(name, _)| name == &selected_model)
            .map(|(_, score)| *score)
            .unwrap_or(0.0);

        Ok(SelectionResult {
            selected_model,
            score,
            rankings,
            pareto_optimal,
        })
    }
}

impl Default for ModelSelectionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_scores() -> HashMap<String, Vec<(&'static str, f64)>> {
        let mut scores = HashMap::new();
        scores.insert(
            "model_a".to_string(),
            vec![("accuracy", 0.85), ("precision", 0.82), ("speed", 100.0)],
        );
        scores.insert(
            "model_b".to_string(),
            vec![("accuracy", 0.80), ("precision", 0.90), ("speed", 200.0)],
        );
        scores.insert(
            "model_c".to_string(),
            vec![("accuracy", 0.88), ("precision", 0.85), ("speed", 150.0)],
        );
        scores
    }

    #[test]
    fn test_basic_selection() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 0.6, true)
            .add_metric("precision", 0.4, true);

        let best = selector.select_best(&scores).unwrap();
        assert!(!best.is_empty());
    }

    #[test]
    fn test_ranking() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 0.5, true)
            .add_metric("precision", 0.5, true);

        let rankings = selector.rank_models(&scores).unwrap();
        assert_eq!(rankings.len(), 3);

        // Rankings should be sorted by score (descending)
        for i in 1..rankings.len() {
            assert!(rankings[i - 1].1 >= rankings[i].1);
        }
    }

    #[test]
    fn test_pareto_optimal() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 1.0, true)
            .add_metric("speed", 1.0, true);

        let pareto = selector.find_pareto_optimal(&scores);
        assert!(!pareto.is_empty());
    }
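
    // A sketch of a dominance check against the fixture above: with accuracy
    // and speed both higher-is-better, model_c (0.88, 150.0) dominates
    // model_a (0.85, 100.0), so model_a should not be on the Pareto front.
    #[test]
    fn test_pareto_dominance_excludes_dominated_model() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 1.0, true)
            .add_metric("speed", 1.0, true);

        let pareto = selector.find_pareto_optimal(&scores);
        assert!(!pareto.contains(&"model_a".to_string()));
        assert!(pareto.contains(&"model_c".to_string()));
    }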

    #[test]
    fn test_thresholds() {
        let scores = create_test_scores();

        let mut selector = ModelSelector::new();
        selector
            .add_metric("accuracy", 1.0, true)
            .add_threshold("accuracy", 0.87); // Only model_c meets this

        let rankings = selector.rank_models(&scores).unwrap();
        assert_eq!(rankings.len(), 1);
        assert_eq!(rankings[0].0, "model_c");
    }

    #[test]
    fn test_different_aggregation_strategies() {
        let scores = create_test_scores();

        let strategies = [
            AggregationStrategy::WeightedSum,
            AggregationStrategy::WeightedGeometricMean,
            AggregationStrategy::MinScore,
            AggregationStrategy::MaxScore,
        ];

        for strategy in &strategies {
            let mut selector = ModelSelector::new();
            selector
                .add_metric("accuracy", 0.5, true)
                .add_metric("precision", 0.5, true)
                .with_aggregation(*strategy);

            let best = selector.select_best(&scores).unwrap();
            assert!(!best.is_empty());
        }
    }

    #[test]
    fn test_builder_pattern() {
        let scores = create_test_scores();

        let result = ModelSelectionBuilder::new()
            .metric("accuracy", 0.6, true)
            .metric("precision", 0.4, true)
            .threshold("accuracy", 0.8)
            .aggregation(AggregationStrategy::WeightedSum)
            .select(&scores)
            .unwrap();

        assert!(!result.selected_model.is_empty());
        assert!(!result.rankings.is_empty());
        assert!(!result.pareto_optimal.is_empty());
    }

    #[test]
    fn test_empty_models() {
        let scores = HashMap::new();
        let selector = ModelSelector::new();

        assert!(selector.select_best(&scores).is_err());
    }

    #[test]
    fn test_minimization_metrics() {
        let mut scores = HashMap::new();
        scores.insert("model_a".to_string(), vec![("error", 0.1), ("time", 5.0)]);
        scores.insert("model_b".to_string(), vec![("error", 0.2), ("time", 3.0)]);

        let mut selector = ModelSelector::new();
        selector
            .add_metric("error", 0.7, false) // lower is better
            .add_metric("time", 0.3, false); // lower is better

        let best = selector.select_best(&scores).unwrap();
        assert!(!best.is_empty());
    }
}
586}