// sklears_model_selection/information_criteria.rs
//! Information Criteria for Model Comparison
//!
//! This module implements various information criteria for model selection and comparison.
//! Information criteria balance model fit (likelihood) with model complexity (number of
//! parameters) to prevent overfitting and enable fair comparison between models.
//!
//! Implemented criteria:
//! - AIC (Akaike Information Criterion)
//! - AICc (Corrected AIC for finite samples)
//! - BIC (Bayesian Information Criterion)
//! - DIC (Deviance Information Criterion)
//! - WAIC (Watanabe-Akaike Information Criterion)
//! - LOOIC (Leave-One-Out Information Criterion)
//! - TIC (Takeuchi Information Criterion)
use scirs2_core::ndarray::{Array1, Array2};
// use scirs2_core::numeric::Float as FloatTrait;
use sklears_core::error::{Result, SklearsError};
use std::fmt::Debug;
20
/// Result of a single information-criterion calculation for one model.
#[derive(Debug, Clone)]
pub struct InformationCriterionResult {
    /// Name of the criterion (e.g. "AIC", "BIC", "WAIC").
    pub criterion_name: String,
    /// Value of the information criterion (lower is better).
    pub value: f64,
    /// Log-likelihood of the model used in the calculation.
    pub log_likelihood: f64,
    /// Number of free parameters; 0 when a nominal count does not apply
    /// (DIC/WAIC/LOOIC/TIC report `effective_parameters` instead).
    pub n_parameters: usize,
    /// Number of data points the model was evaluated on.
    pub n_data_points: usize,
    /// Effective number of parameters (populated for DIC, WAIC, LOOIC, TIC).
    pub effective_parameters: Option<f64>,
    /// Standard error of the criterion estimate, if available.
    pub standard_error: Option<f64>,
    /// Model weight relative to other compared models (e.g. Akaike weight).
    pub weight: Option<f64>,
}
41
42impl InformationCriterionResult {
43    /// Create new IC result
44    pub fn new(
45        criterion_name: String,
46        value: f64,
47        log_likelihood: f64,
48        n_parameters: usize,
49        n_data_points: usize,
50    ) -> Self {
51        Self {
52            criterion_name,
53            value,
54            log_likelihood,
55            n_parameters,
56            n_data_points,
57            effective_parameters: None,
58            standard_error: None,
59            weight: None,
60        }
61    }
62
63    /// Set effective number of parameters
64    pub fn with_effective_parameters(mut self, p_eff: f64) -> Self {
65        self.effective_parameters = Some(p_eff);
66        self
67    }
68
69    /// Set standard error
70    pub fn with_standard_error(mut self, se: f64) -> Self {
71        self.standard_error = Some(se);
72        self
73    }
74
75    /// Set model weight
76    pub fn with_weight(mut self, weight: f64) -> Self {
77        self.weight = Some(weight);
78        self
79    }
80}
81
/// Comparison result for multiple models evaluated under one criterion.
#[derive(Debug, Clone)]
pub struct ModelComparisonResult {
    /// Model names, in the order the models were supplied.
    pub model_names: Vec<String>,
    /// Per-model information-criterion results (same order as `model_names`).
    pub results: Vec<InformationCriterionResult>,
    /// Delta values: each model's IC value minus the best (lowest) IC value.
    pub delta_values: Vec<f64>,
    /// Model weights (Akaike weights); all zeros when weight calculation
    /// was disabled on the calculator.
    pub weights: Vec<f64>,
    /// Index of the best (lowest-IC) model.
    pub best_model_index: usize,
    /// Evidence ratio of the best model versus the second best.
    pub evidence_ratio: f64,
}
98
99impl ModelComparisonResult {
100    /// Get best model name
101    pub fn best_model(&self) -> &str {
102        &self.model_names[self.best_model_index]
103    }
104
105    /// Get model ranking by IC value
106    pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
107        let mut ranking: Vec<(usize, &str, f64)> = self
108            .model_names
109            .iter()
110            .enumerate()
111            .map(|(i, name)| (i, name.as_str(), self.results[i].value))
112            .collect();
113
114        ranking.sort_by(|a, b| a.2.partial_cmp(&b.2).expect("operation should succeed"));
115        ranking
116    }
117
118    /// Interpret model strength using Burnham & Anderson guidelines
119    pub fn model_strength_interpretation(&self, model_idx: usize) -> String {
120        let delta = self.delta_values[model_idx];
121        match delta {
122            d if d <= 2.0 => "Substantial support".to_string(),
123            d if d <= 4.0 => "Considerably less support".to_string(),
124            d if d <= 7.0 => "Little support".to_string(),
125            _ => "No support".to_string(),
126        }
127    }
128}
129
/// Information criterion calculator.
pub struct InformationCriterionCalculator {
    /// Whether to use bias correction for finite samples.
    /// NOTE(review): this flag is not read anywhere in this module —
    /// confirm whether AICc selection is meant to depend on it.
    pub use_bias_correction: bool,
    /// Whether to calculate Akaike model weights during `compare_models`.
    pub calculate_weights: bool,
}
137
138impl Default for InformationCriterionCalculator {
139    fn default() -> Self {
140        Self {
141            use_bias_correction: true,
142            calculate_weights: true,
143        }
144    }
145}
146
147impl InformationCriterionCalculator {
148    /// Create new calculator
149    pub fn new() -> Self {
150        Self::default()
151    }
152
153    /// Calculate AIC (Akaike Information Criterion)
154    /// AIC = 2k - 2ln(L)
155    pub fn aic(
156        &self,
157        log_likelihood: f64,
158        n_parameters: usize,
159        n_data_points: usize,
160    ) -> InformationCriterionResult {
161        let k = n_parameters as f64;
162        let aic_value = 2.0 * k - 2.0 * log_likelihood;
163
164        InformationCriterionResult::new(
165            "AIC".to_string(),
166            aic_value,
167            log_likelihood,
168            n_parameters,
169            n_data_points,
170        )
171    }
172
173    /// Calculate AICc (Corrected AIC for finite samples)
174    /// AICc = AIC + 2k(k+1)/(n-k-1)
175    pub fn aicc(
176        &self,
177        log_likelihood: f64,
178        n_parameters: usize,
179        n_data_points: usize,
180    ) -> Result<InformationCriterionResult> {
181        let k = n_parameters as f64;
182        let n = n_data_points as f64;
183
184        if n <= k + 1.0 {
185            return Err(SklearsError::InvalidInput(
186                "AICc requires n > k + 1".to_string(),
187            ));
188        }
189
190        let aic_value = 2.0 * k - 2.0 * log_likelihood;
191        let correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
192        let aicc_value = aic_value + correction;
193
194        Ok(InformationCriterionResult::new(
195            "AICc".to_string(),
196            aicc_value,
197            log_likelihood,
198            n_parameters,
199            n_data_points,
200        ))
201    }
202
203    /// Calculate BIC (Bayesian Information Criterion)
204    /// BIC = k*ln(n) - 2ln(L)
205    pub fn bic(
206        &self,
207        log_likelihood: f64,
208        n_parameters: usize,
209        n_data_points: usize,
210    ) -> InformationCriterionResult {
211        let k = n_parameters as f64;
212        let n = n_data_points as f64;
213        let bic_value = k * n.ln() - 2.0 * log_likelihood;
214
215        InformationCriterionResult::new(
216            "BIC".to_string(),
217            bic_value,
218            log_likelihood,
219            n_parameters,
220            n_data_points,
221        )
222    }
223
224    /// Calculate DIC (Deviance Information Criterion)
225    /// DIC = D(θ̄) + 2p_D, where p_D is effective number of parameters
226    pub fn dic(
227        &self,
228        log_likelihood_mean: f64,
229        log_likelihood_samples: &[f64],
230        n_data_points: usize,
231    ) -> Result<InformationCriterionResult> {
232        if log_likelihood_samples.is_empty() {
233            return Err(SklearsError::InvalidInput(
234                "DIC requires posterior samples".to_string(),
235            ));
236        }
237
238        // Deviance at posterior mean
239        let deviance_mean = -2.0 * log_likelihood_mean;
240
241        // Mean deviance
242        let mean_deviance =
243            -2.0 * log_likelihood_samples.iter().sum::<f64>() / log_likelihood_samples.len() as f64;
244
245        // Effective number of parameters
246        let p_d = mean_deviance - deviance_mean;
247
248        // DIC value
249        let dic_value = deviance_mean + p_d;
250
251        Ok(InformationCriterionResult::new(
252            "DIC".to_string(),
253            dic_value,
254            log_likelihood_mean,
255            0, // Not applicable for DIC
256            n_data_points,
257        )
258        .with_effective_parameters(p_d))
259    }
260
261    /// Calculate WAIC (Watanabe-Akaike Information Criterion)
262    /// WAIC = -2 * (lppd - p_WAIC)
263    pub fn waic(
264        &self,
265        pointwise_log_likelihoods: &Array2<f64>, // Rows: samples, Columns: data points
266    ) -> Result<InformationCriterionResult> {
267        let (n_samples, n_data) = pointwise_log_likelihoods.dim();
268
269        if n_samples == 0 || n_data == 0 {
270            return Err(SklearsError::InvalidInput(
271                "WAIC requires non-empty likelihood matrix".to_string(),
272            ));
273        }
274
275        // Log pointwise predictive density (lppd)
276        let mut lppd = 0.0;
277        let mut p_waic = 0.0;
278
279        for j in 0..n_data {
280            let column = pointwise_log_likelihoods.column(j);
281
282            // Log of mean of likelihoods for data point j
283            let column_data: Vec<f64> = column.iter().copied().collect();
284            let log_mean_likelihood = log_mean_exp(&column_data);
285            lppd += log_mean_likelihood;
286
287            // Variance of log-likelihoods for data point j
288            let mean_log_likelihood = column.mean().expect("operation should succeed");
289            let variance = column
290                .iter()
291                .map(|&x| (x - mean_log_likelihood).powi(2))
292                .sum::<f64>()
293                / (n_samples - 1) as f64;
294            p_waic += variance;
295        }
296
297        let waic_value = -2.0 * (lppd - p_waic);
298
299        // Calculate total log likelihood (approximate)
300        let total_log_likelihood = pointwise_log_likelihoods.sum();
301
302        Ok(InformationCriterionResult::new(
303            "WAIC".to_string(),
304            waic_value,
305            total_log_likelihood,
306            0, // Not directly applicable
307            n_data,
308        )
309        .with_effective_parameters(p_waic))
310    }
311
312    /// Calculate LOOIC (Leave-One-Out Information Criterion) using Pareto smoothed importance sampling
313    pub fn looic(
314        &self,
315        pointwise_log_likelihoods: &Array2<f64>,
316        pareto_k_diagnostics: Option<&Array1<f64>>,
317    ) -> Result<InformationCriterionResult> {
318        let (n_samples, n_data) = pointwise_log_likelihoods.dim();
319
320        if n_samples == 0 || n_data == 0 {
321            return Err(SklearsError::InvalidInput(
322                "LOOIC requires non-empty likelihood matrix".to_string(),
323            ));
324        }
325
326        let mut elpd_loo = 0.0; // Expected log pointwise predictive density
327        let mut p_loo = 0.0; // Effective number of parameters
328
329        for j in 0..n_data {
330            let column = pointwise_log_likelihoods.column(j);
331            let log_likes = column.as_slice().expect("operation should succeed");
332
333            // Check Pareto k diagnostic if available
334            if let Some(k_values) = pareto_k_diagnostics {
335                if k_values[j] > 0.7 {
336                    eprintln!(
337                        "Warning: High Pareto k ({:.3}) for observation {}",
338                        k_values[j], j
339                    );
340                }
341            }
342
343            // Importance sampling LOO estimate
344            let max_log_like = log_likes.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
345            let rel_log_likes: Vec<f64> = log_likes.iter().map(|&x| x - max_log_like).collect();
346
347            // Calculate importance weights (simplified PSIS)
348            let weights: Vec<f64> = rel_log_likes.iter().map(|&x| x.exp()).collect();
349
350            let sum_weights: f64 = weights.iter().sum();
351            if sum_weights == 0.0 {
352                return Err(SklearsError::InvalidInput(
353                    "Zero importance weights".to_string(),
354                ));
355            }
356
357            let _normalized_weights: Vec<f64> = weights.iter().map(|&w| w / sum_weights).collect();
358
359            // LOO predictive density
360            let loo_lpd = (sum_weights / n_samples as f64).ln() + max_log_like;
361            elpd_loo += loo_lpd;
362
363            // Effective number of parameters contribution
364            let mean_log_like = log_likes.iter().sum::<f64>() / n_samples as f64;
365            p_loo += mean_log_like - loo_lpd;
366        }
367
368        let looic_value = -2.0 * elpd_loo;
369
370        Ok(InformationCriterionResult::new(
371            "LOOIC".to_string(),
372            looic_value,
373            0.0, // Not directly applicable
374            0,   // Not directly applicable
375            n_data,
376        )
377        .with_effective_parameters(p_loo))
378    }
379
380    /// Calculate TIC (Takeuchi Information Criterion)
381    /// TIC = -2ln(L) + 2tr(J^{-1}K), where J is Fisher information and K is outer product of scores
382    pub fn tic(
383        &self,
384        log_likelihood: f64,
385        fisher_information_trace: f64,
386        n_data_points: usize,
387    ) -> Result<InformationCriterionResult> {
388        if fisher_information_trace <= 0.0 {
389            return Err(SklearsError::InvalidInput(
390                "Fisher information trace must be positive".to_string(),
391            ));
392        }
393
394        let tic_value = -2.0 * log_likelihood + 2.0 * fisher_information_trace;
395
396        Ok(InformationCriterionResult::new(
397            "TIC".to_string(),
398            tic_value,
399            log_likelihood,
400            0, // Parameters counted via Fisher information
401            n_data_points,
402        )
403        .with_effective_parameters(fisher_information_trace))
404    }
405
406    /// Compare multiple models using specified criterion
407    pub fn compare_models(
408        &self,
409        models: &[(String, f64, usize, usize)], // (name, log_likelihood, n_params, n_data)
410        criterion: InformationCriterion,
411    ) -> Result<ModelComparisonResult> {
412        if models.is_empty() {
413            return Err(SklearsError::InvalidInput("No models provided".to_string()));
414        }
415
416        let mut results = Vec::new();
417        let model_names: Vec<String> = models.iter().map(|(name, _, _, _)| name.clone()).collect();
418
419        // Calculate IC for each model
420        for (_name, log_likelihood, n_params, n_data) in models {
421            let result = match criterion {
422                InformationCriterion::AIC => self.aic(*log_likelihood, *n_params, *n_data),
423                InformationCriterion::AICc => self.aicc(*log_likelihood, *n_params, *n_data)?,
424                InformationCriterion::BIC => self.bic(*log_likelihood, *n_params, *n_data),
425            };
426            results.push(result);
427        }
428
429        // Find best model (lowest IC value)
430        let best_idx = results
431            .iter()
432            .enumerate()
433            .min_by(|(_, a), (_, b)| {
434                a.value
435                    .partial_cmp(&b.value)
436                    .expect("operation should succeed")
437            })
438            .map(|(i, _)| i)
439            .expect("operation should succeed");
440
441        let best_value = results[best_idx].value;
442
443        // Calculate delta values
444        let delta_values: Vec<f64> = results.iter().map(|r| r.value - best_value).collect();
445
446        // Calculate Akaike weights
447        let weights = if self.calculate_weights {
448            self.calculate_akaike_weights(&delta_values)
449        } else {
450            vec![0.0; results.len()]
451        };
452
453        // Evidence ratio (best vs second best)
454        let mut sorted_deltas = delta_values.clone();
455        sorted_deltas.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
456        let evidence_ratio = if sorted_deltas.len() > 1 {
457            (-0.5 * sorted_deltas[1]).exp()
458        } else {
459            1.0
460        };
461
462        Ok(ModelComparisonResult {
463            model_names,
464            results,
465            delta_values,
466            weights,
467            best_model_index: best_idx,
468            evidence_ratio,
469        })
470    }
471
472    /// Calculate Akaike weights from delta values
473    fn calculate_akaike_weights(&self, delta_values: &[f64]) -> Vec<f64> {
474        // Akaike weights: w_i = exp(-Δ_i/2) / Σ exp(-Δ_j/2)
475        let weights: Vec<f64> = delta_values
476            .iter()
477            .map(|&delta| (-0.5 * delta).exp())
478            .collect();
479
480        let sum_weights: f64 = weights.iter().sum();
481        if sum_weights == 0.0 {
482            return vec![1.0 / weights.len() as f64; weights.len()];
483        }
484
485        weights.iter().map(|&w| w / sum_weights).collect()
486    }
487
488    /// Calculate model-averaged prediction using IC weights
489    pub fn model_averaged_prediction(
490        &self,
491        predictions: &[Array1<f64>],
492        weights: &[f64],
493    ) -> Result<Array1<f64>> {
494        if predictions.is_empty() {
495            return Err(SklearsError::InvalidInput(
496                "No predictions provided".to_string(),
497            ));
498        }
499
500        if predictions.len() != weights.len() {
501            return Err(SklearsError::InvalidInput(
502                "Number of predictions must match number of weights".to_string(),
503            ));
504        }
505
506        let n_samples = predictions[0].len();
507        for pred in predictions {
508            if pred.len() != n_samples {
509                return Err(SklearsError::InvalidInput(
510                    "All predictions must have the same length".to_string(),
511                ));
512            }
513        }
514
515        let mut averaged = Array1::zeros(n_samples);
516        for (pred, &weight) in predictions.iter().zip(weights.iter()) {
517            averaged = averaged + pred * weight;
518        }
519
520        Ok(averaged)
521    }
522}
523
/// Types of information criteria usable with `compare_models`.
#[derive(Debug, Clone, Copy)]
pub enum InformationCriterion {
    /// Akaike Information Criterion: 2k - 2 ln(L).
    AIC,
    /// Small-sample-corrected AIC; requires n > k + 1.
    AICc,
    /// Bayesian Information Criterion: k ln(n) - 2 ln(L).
    BIC,
}
534
// Utility functions

/// Numerically stable `ln(mean(exp(values)))` via the log-sum-exp trick:
/// shifting by the maximum before exponentiating prevents overflow for
/// large-magnitude inputs.
fn log_mean_exp(values: &[f64]) -> f64 {
    let shift = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mean_exp: f64 =
        values.iter().map(|&v| (v - shift).exp()).sum::<f64>() / values.len() as f64;
    shift + mean_exp.ln()
}
541
/// Model selection using information criteria aggregated over
/// cross-validation folds.
pub struct CrossValidatedIC {
    /// Criterion used for the final model comparison.
    criterion: InformationCriterion,
    /// Expected number of folds in each model's per-fold result vectors.
    n_folds: usize,
}
547
548impl CrossValidatedIC {
549    /// Create new cross-validated IC selector
550    pub fn new(criterion: InformationCriterion, n_folds: usize) -> Self {
551        Self { criterion, n_folds }
552    }
553
554    /// Select best model using cross-validated IC
555    pub fn select_model(
556        &self,
557        cv_results: &[(String, Vec<f64>, Vec<usize>, Vec<usize>)], // (name, cv_log_likes, cv_n_params, cv_n_data)
558    ) -> Result<ModelComparisonResult> {
559        let calculator = InformationCriterionCalculator::new();
560        let mut aggregated_models = Vec::new();
561
562        for (name, cv_log_likes, cv_n_params, cv_n_data) in cv_results {
563            if cv_log_likes.len() != self.n_folds {
564                return Err(SklearsError::InvalidInput(
565                    "CV results must match number of folds".to_string(),
566                ));
567            }
568
569            // Aggregate across folds
570            let total_log_likelihood: f64 = cv_log_likes.iter().sum();
571            let avg_n_params = cv_n_params.iter().sum::<usize>() / cv_n_params.len();
572            let total_n_data = cv_n_data.iter().sum::<usize>();
573
574            aggregated_models.push((
575                name.clone(),
576                total_log_likelihood,
577                avg_n_params,
578                total_n_data,
579            ));
580        }
581
582        calculator.compare_models(&aggregated_models, self.criterion)
583    }
584}
585
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// AIC = 2k - 2 ln(L); checks the exact hand-computed value.
    #[test]
    fn test_aic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.aic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "AIC");
        assert_eq!(result.value, 210.0); // 2*5 - 2*(-100)
        assert_eq!(result.n_parameters, 5);
        assert_eq!(result.n_data_points, 200);
    }

    /// The small-sample correction is strictly positive when n > k + 1,
    /// so AICc must exceed the plain AIC value of 210.
    #[test]
    fn test_aicc_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator
            .aicc(-100.0, 5, 20)
            .expect("operation should succeed");

        assert_eq!(result.criterion_name, "AICc");
        assert!(result.value > 210.0); // Should be higher than AIC due to correction
    }

    /// With n = 200, ln(n) > 2, so BIC's penalty exceeds AIC's.
    #[test]
    fn test_bic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.bic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "BIC");
        assert!(result.value > 210.0); // BIC penalizes complexity more than AIC
    }

    /// Model1 has both the best log-likelihood and the fewest parameters,
    /// so it must win; Akaike weights must sum to 1.
    #[test]
    fn test_model_comparison() {
        let calculator = InformationCriterionCalculator::new();
        let models = vec![
            ("Model1".to_string(), -95.0, 3, 100),
            ("Model2".to_string(), -100.0, 5, 100),
            ("Model3".to_string(), -98.0, 4, 100),
        ];

        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .expect("operation should succeed");

        assert_eq!(result.model_names.len(), 3);
        assert_eq!(result.best_model(), "Model1"); // Best log-likelihood with fewest parameters
        assert!((result.weights.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    /// Weights decrease monotonically with delta and are normalized.
    #[test]
    fn test_akaike_weights() {
        let calculator = InformationCriterionCalculator::new();
        let delta_values = vec![0.0, 2.0, 4.0]; // Best model, 2 AIC units worse, 4 AIC units worse
        let weights = calculator.calculate_akaike_weights(&delta_values);

        assert!(weights[0] > weights[1]); // Best model should have highest weight
        assert!(weights[1] > weights[2]); // Second best should be better than worst
        assert!((weights.iter().sum::<f64>() - 1.0).abs() < 1e-6); // Weights sum to 1
    }

    /// Smoke test: WAIC runs on a small sample-by-point matrix and reports
    /// an effective parameter count.
    #[test]
    fn test_waic_calculation() {
        let calculator = InformationCriterionCalculator::new();

        // Mock pointwise log-likelihoods: 3 samples, 5 data points
        let pointwise_ll = array![
            [-1.0, -1.2, -0.9, -1.1, -1.0],
            [-1.1, -1.0, -1.0, -1.0, -0.9],
            [-0.9, -1.1, -1.1, -0.9, -1.1]
        ];

        let result = calculator
            .waic(&pointwise_ll)
            .expect("operation should succeed");
        assert_eq!(result.criterion_name, "WAIC");
        assert!(result.effective_parameters.is_some());
    }

    /// Ranking is ascending by IC value: ModelB (best) first, ModelA last.
    #[test]
    fn test_model_ranking() {
        let models = vec![
            ("ModelA".to_string(), -100.0, 5, 100),
            ("ModelB".to_string(), -95.0, 3, 100),
            ("ModelC".to_string(), -98.0, 4, 100),
        ];

        let calculator = InformationCriterionCalculator::new();
        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .expect("operation should succeed");
        let ranking = result.model_ranking();

        assert_eq!(ranking[0].1, "ModelB"); // Best model
        assert_eq!(ranking[2].1, "ModelA"); // Worst model
    }

    /// Element-wise weighted average of the prediction vectors.
    #[test]
    fn test_model_averaged_prediction() {
        let calculator = InformationCriterionCalculator::new();

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let pred3 = array![0.9, 1.9, 2.9];

        let predictions = vec![pred1, pred2, pred3];
        let weights = vec![0.5, 0.3, 0.2];

        let averaged = calculator
            .model_averaged_prediction(&predictions, &weights)
            .expect("operation should succeed");
        assert_eq!(averaged.len(), 3);

        // Check weighted average
        let expected_0 = 1.0 * 0.5 + 1.1 * 0.3 + 0.9 * 0.2;
        assert!((averaged[0] - expected_0).abs() < 1e-10);
    }
}