// sklears_model_selection/information_criteria.rs

//! Information Criteria for Model Comparison
//!
//! This module implements various information criteria for model selection and comparison.
//! Information criteria balance model fit (likelihood) with model complexity (number of
//! parameters) to prevent overfitting and enable fair comparison between models.
//!
//! Implemented criteria:
//! - AIC (Akaike Information Criterion)
//! - AICc (Corrected AIC for finite samples)
//! - BIC (Bayesian Information Criterion)
//! - DIC (Deviance Information Criterion)
//! - WAIC (Watanabe-Akaike Information Criterion)
//! - LOOIC (Leave-One-Out Information Criterion)
//! - TIC (Takeuchi Information Criterion)

use scirs2_core::ndarray::{Array1, Array2};
// use scirs2_core::numeric::Float as FloatTrait;
use sklears_core::error::{Result, SklearsError};
use std::fmt::Debug;
20
/// Result of a single information-criterion calculation.
#[derive(Debug, Clone)]
pub struct InformationCriterionResult {
    /// Name of the criterion (e.g. "AIC", "BIC")
    pub criterion_name: String,
    /// Value of the information criterion (lower is better)
    pub value: f64,
    /// Log-likelihood of the model
    pub log_likelihood: f64,
    /// Number of parameters
    pub n_parameters: usize,
    /// Number of data points
    pub n_data_points: usize,
    /// Effective number of parameters (for DIC, WAIC)
    pub effective_parameters: Option<f64>,
    /// Standard error of the criterion (if available)
    pub standard_error: Option<f64>,
    /// Model weight relative to other models
    pub weight: Option<f64>,
}

impl InformationCriterionResult {
    /// Build a result with all optional fields (effective parameters,
    /// standard error, weight) unset; fill them with the `with_*` builders.
    pub fn new(
        criterion_name: String,
        value: f64,
        log_likelihood: f64,
        n_parameters: usize,
        n_data_points: usize,
    ) -> Self {
        InformationCriterionResult {
            criterion_name,
            value,
            log_likelihood,
            n_parameters,
            n_data_points,
            effective_parameters: None,
            standard_error: None,
            weight: None,
        }
    }

    /// Builder: record the effective number of parameters.
    pub fn with_effective_parameters(self, p_eff: f64) -> Self {
        Self {
            effective_parameters: Some(p_eff),
            ..self
        }
    }

    /// Builder: record the standard error of the criterion.
    pub fn with_standard_error(self, se: f64) -> Self {
        Self {
            standard_error: Some(se),
            ..self
        }
    }

    /// Builder: record this model's weight relative to the other candidates.
    pub fn with_weight(self, weight: f64) -> Self {
        Self {
            weight: Some(weight),
            ..self
        }
    }
}
81
/// Comparison result for multiple models.
///
/// Produced by `compare_models`; all vectors are index-aligned with
/// `model_names`.
#[derive(Debug, Clone)]
pub struct ModelComparisonResult {
    /// Model names
    pub model_names: Vec<String>,
    /// Information criterion results for each model
    pub results: Vec<InformationCriterionResult>,
    /// Delta values (difference from the best model's criterion value; 0 for the best)
    pub delta_values: Vec<f64>,
    /// Model weights (Akaike weights; zeros when weight calculation is disabled)
    pub weights: Vec<f64>,
    /// Index of best model (lowest criterion value)
    pub best_model_index: usize,
    /// Evidence ratio for best vs second best
    pub evidence_ratio: f64,
}
98
99impl ModelComparisonResult {
100    /// Get best model name
101    pub fn best_model(&self) -> &str {
102        &self.model_names[self.best_model_index]
103    }
104
105    /// Get model ranking by IC value
106    pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
107        let mut ranking: Vec<(usize, &str, f64)> = self
108            .model_names
109            .iter()
110            .enumerate()
111            .map(|(i, name)| (i, name.as_str(), self.results[i].value))
112            .collect();
113
114        ranking.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap());
115        ranking
116    }
117
118    /// Interpret model strength using Burnham & Anderson guidelines
119    pub fn model_strength_interpretation(&self, model_idx: usize) -> String {
120        let delta = self.delta_values[model_idx];
121        match delta {
122            d if d <= 2.0 => "Substantial support".to_string(),
123            d if d <= 4.0 => "Considerably less support".to_string(),
124            d if d <= 7.0 => "Little support".to_string(),
125            _ => "No support".to_string(),
126        }
127    }
128}
129
/// Information criterion calculator.
///
/// Stateless helper that computes AIC/AICc/BIC/DIC/WAIC/LOOIC/TIC and
/// compares models.
pub struct InformationCriterionCalculator {
    /// Whether to use bias correction for finite samples
    // NOTE(review): this flag is not read by any method visible in this
    // module — confirm intended use or remove.
    pub use_bias_correction: bool,
    /// Whether to calculate model weights during `compare_models`
    pub calculate_weights: bool,
}
137
138impl Default for InformationCriterionCalculator {
139    fn default() -> Self {
140        Self {
141            use_bias_correction: true,
142            calculate_weights: true,
143        }
144    }
145}
146
147impl InformationCriterionCalculator {
148    /// Create new calculator
149    pub fn new() -> Self {
150        Self::default()
151    }
152
153    /// Calculate AIC (Akaike Information Criterion)
154    /// AIC = 2k - 2ln(L)
155    pub fn aic(
156        &self,
157        log_likelihood: f64,
158        n_parameters: usize,
159        n_data_points: usize,
160    ) -> InformationCriterionResult {
161        let k = n_parameters as f64;
162        let aic_value = 2.0 * k - 2.0 * log_likelihood;
163
164        InformationCriterionResult::new(
165            "AIC".to_string(),
166            aic_value,
167            log_likelihood,
168            n_parameters,
169            n_data_points,
170        )
171    }
172
173    /// Calculate AICc (Corrected AIC for finite samples)
174    /// AICc = AIC + 2k(k+1)/(n-k-1)
175    pub fn aicc(
176        &self,
177        log_likelihood: f64,
178        n_parameters: usize,
179        n_data_points: usize,
180    ) -> Result<InformationCriterionResult> {
181        let k = n_parameters as f64;
182        let n = n_data_points as f64;
183
184        if n <= k + 1.0 {
185            return Err(SklearsError::InvalidInput(
186                "AICc requires n > k + 1".to_string(),
187            ));
188        }
189
190        let aic_value = 2.0 * k - 2.0 * log_likelihood;
191        let correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
192        let aicc_value = aic_value + correction;
193
194        Ok(InformationCriterionResult::new(
195            "AICc".to_string(),
196            aicc_value,
197            log_likelihood,
198            n_parameters,
199            n_data_points,
200        ))
201    }
202
203    /// Calculate BIC (Bayesian Information Criterion)
204    /// BIC = k*ln(n) - 2ln(L)
205    pub fn bic(
206        &self,
207        log_likelihood: f64,
208        n_parameters: usize,
209        n_data_points: usize,
210    ) -> InformationCriterionResult {
211        let k = n_parameters as f64;
212        let n = n_data_points as f64;
213        let bic_value = k * n.ln() - 2.0 * log_likelihood;
214
215        InformationCriterionResult::new(
216            "BIC".to_string(),
217            bic_value,
218            log_likelihood,
219            n_parameters,
220            n_data_points,
221        )
222    }
223
224    /// Calculate DIC (Deviance Information Criterion)
225    /// DIC = D(θ̄) + 2p_D, where p_D is effective number of parameters
226    pub fn dic(
227        &self,
228        log_likelihood_mean: f64,
229        log_likelihood_samples: &[f64],
230        n_data_points: usize,
231    ) -> Result<InformationCriterionResult> {
232        if log_likelihood_samples.is_empty() {
233            return Err(SklearsError::InvalidInput(
234                "DIC requires posterior samples".to_string(),
235            ));
236        }
237
238        // Deviance at posterior mean
239        let deviance_mean = -2.0 * log_likelihood_mean;
240
241        // Mean deviance
242        let mean_deviance =
243            -2.0 * log_likelihood_samples.iter().sum::<f64>() / log_likelihood_samples.len() as f64;
244
245        // Effective number of parameters
246        let p_d = mean_deviance - deviance_mean;
247
248        // DIC value
249        let dic_value = deviance_mean + p_d;
250
251        Ok(InformationCriterionResult::new(
252            "DIC".to_string(),
253            dic_value,
254            log_likelihood_mean,
255            0, // Not applicable for DIC
256            n_data_points,
257        )
258        .with_effective_parameters(p_d))
259    }
260
261    /// Calculate WAIC (Watanabe-Akaike Information Criterion)
262    /// WAIC = -2 * (lppd - p_WAIC)
263    pub fn waic(
264        &self,
265        pointwise_log_likelihoods: &Array2<f64>, // Rows: samples, Columns: data points
266    ) -> Result<InformationCriterionResult> {
267        let (n_samples, n_data) = pointwise_log_likelihoods.dim();
268
269        if n_samples == 0 || n_data == 0 {
270            return Err(SklearsError::InvalidInput(
271                "WAIC requires non-empty likelihood matrix".to_string(),
272            ));
273        }
274
275        // Log pointwise predictive density (lppd)
276        let mut lppd = 0.0;
277        let mut p_waic = 0.0;
278
279        for j in 0..n_data {
280            let column = pointwise_log_likelihoods.column(j);
281
282            // Log of mean of likelihoods for data point j
283            let column_data: Vec<f64> = column.iter().copied().collect();
284            let log_mean_likelihood = log_mean_exp(&column_data);
285            lppd += log_mean_likelihood;
286
287            // Variance of log-likelihoods for data point j
288            let mean_log_likelihood = column.mean().unwrap();
289            let variance = column
290                .iter()
291                .map(|&x| (x - mean_log_likelihood).powi(2))
292                .sum::<f64>()
293                / (n_samples - 1) as f64;
294            p_waic += variance;
295        }
296
297        let waic_value = -2.0 * (lppd - p_waic);
298
299        // Calculate total log likelihood (approximate)
300        let total_log_likelihood = pointwise_log_likelihoods.sum();
301
302        Ok(InformationCriterionResult::new(
303            "WAIC".to_string(),
304            waic_value,
305            total_log_likelihood,
306            0, // Not directly applicable
307            n_data,
308        )
309        .with_effective_parameters(p_waic))
310    }
311
312    /// Calculate LOOIC (Leave-One-Out Information Criterion) using Pareto smoothed importance sampling
313    pub fn looic(
314        &self,
315        pointwise_log_likelihoods: &Array2<f64>,
316        pareto_k_diagnostics: Option<&Array1<f64>>,
317    ) -> Result<InformationCriterionResult> {
318        let (n_samples, n_data) = pointwise_log_likelihoods.dim();
319
320        if n_samples == 0 || n_data == 0 {
321            return Err(SklearsError::InvalidInput(
322                "LOOIC requires non-empty likelihood matrix".to_string(),
323            ));
324        }
325
326        let mut elpd_loo = 0.0; // Expected log pointwise predictive density
327        let mut p_loo = 0.0; // Effective number of parameters
328
329        for j in 0..n_data {
330            let column = pointwise_log_likelihoods.column(j);
331            let log_likes = column.as_slice().unwrap();
332
333            // Check Pareto k diagnostic if available
334            if let Some(k_values) = pareto_k_diagnostics {
335                if k_values[j] > 0.7 {
336                    eprintln!(
337                        "Warning: High Pareto k ({:.3}) for observation {}",
338                        k_values[j], j
339                    );
340                }
341            }
342
343            // Importance sampling LOO estimate
344            let max_log_like = log_likes.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
345            let rel_log_likes: Vec<f64> = log_likes.iter().map(|&x| x - max_log_like).collect();
346
347            // Calculate importance weights (simplified PSIS)
348            let weights: Vec<f64> = rel_log_likes.iter().map(|&x| x.exp()).collect();
349
350            let sum_weights: f64 = weights.iter().sum();
351            if sum_weights == 0.0 {
352                return Err(SklearsError::InvalidInput(
353                    "Zero importance weights".to_string(),
354                ));
355            }
356
357            let _normalized_weights: Vec<f64> = weights.iter().map(|&w| w / sum_weights).collect();
358
359            // LOO predictive density
360            let loo_lpd = (sum_weights / n_samples as f64).ln() + max_log_like;
361            elpd_loo += loo_lpd;
362
363            // Effective number of parameters contribution
364            let mean_log_like = log_likes.iter().sum::<f64>() / n_samples as f64;
365            p_loo += mean_log_like - loo_lpd;
366        }
367
368        let looic_value = -2.0 * elpd_loo;
369
370        Ok(InformationCriterionResult::new(
371            "LOOIC".to_string(),
372            looic_value,
373            0.0, // Not directly applicable
374            0,   // Not directly applicable
375            n_data,
376        )
377        .with_effective_parameters(p_loo))
378    }
379
380    /// Calculate TIC (Takeuchi Information Criterion)
381    /// TIC = -2ln(L) + 2tr(J^{-1}K), where J is Fisher information and K is outer product of scores
382    pub fn tic(
383        &self,
384        log_likelihood: f64,
385        fisher_information_trace: f64,
386        n_data_points: usize,
387    ) -> Result<InformationCriterionResult> {
388        if fisher_information_trace <= 0.0 {
389            return Err(SklearsError::InvalidInput(
390                "Fisher information trace must be positive".to_string(),
391            ));
392        }
393
394        let tic_value = -2.0 * log_likelihood + 2.0 * fisher_information_trace;
395
396        Ok(InformationCriterionResult::new(
397            "TIC".to_string(),
398            tic_value,
399            log_likelihood,
400            0, // Parameters counted via Fisher information
401            n_data_points,
402        )
403        .with_effective_parameters(fisher_information_trace))
404    }
405
406    /// Compare multiple models using specified criterion
407    pub fn compare_models(
408        &self,
409        models: &[(String, f64, usize, usize)], // (name, log_likelihood, n_params, n_data)
410        criterion: InformationCriterion,
411    ) -> Result<ModelComparisonResult> {
412        if models.is_empty() {
413            return Err(SklearsError::InvalidInput("No models provided".to_string()));
414        }
415
416        let mut results = Vec::new();
417        let model_names: Vec<String> = models.iter().map(|(name, _, _, _)| name.clone()).collect();
418
419        // Calculate IC for each model
420        for (_name, log_likelihood, n_params, n_data) in models {
421            let result = match criterion {
422                InformationCriterion::AIC => self.aic(*log_likelihood, *n_params, *n_data),
423                InformationCriterion::AICc => self.aicc(*log_likelihood, *n_params, *n_data)?,
424                InformationCriterion::BIC => self.bic(*log_likelihood, *n_params, *n_data),
425            };
426            results.push(result);
427        }
428
429        // Find best model (lowest IC value)
430        let best_idx = results
431            .iter()
432            .enumerate()
433            .min_by(|(_, a), (_, b)| a.value.partial_cmp(&b.value).unwrap())
434            .map(|(i, _)| i)
435            .unwrap();
436
437        let best_value = results[best_idx].value;
438
439        // Calculate delta values
440        let delta_values: Vec<f64> = results.iter().map(|r| r.value - best_value).collect();
441
442        // Calculate Akaike weights
443        let weights = if self.calculate_weights {
444            self.calculate_akaike_weights(&delta_values)
445        } else {
446            vec![0.0; results.len()]
447        };
448
449        // Evidence ratio (best vs second best)
450        let mut sorted_deltas = delta_values.clone();
451        sorted_deltas.sort_by(|a, b| a.partial_cmp(b).unwrap());
452        let evidence_ratio = if sorted_deltas.len() > 1 {
453            (-0.5 * sorted_deltas[1]).exp()
454        } else {
455            1.0
456        };
457
458        Ok(ModelComparisonResult {
459            model_names,
460            results,
461            delta_values,
462            weights,
463            best_model_index: best_idx,
464            evidence_ratio,
465        })
466    }
467
468    /// Calculate Akaike weights from delta values
469    fn calculate_akaike_weights(&self, delta_values: &[f64]) -> Vec<f64> {
470        // Akaike weights: w_i = exp(-Δ_i/2) / Σ exp(-Δ_j/2)
471        let weights: Vec<f64> = delta_values
472            .iter()
473            .map(|&delta| (-0.5 * delta).exp())
474            .collect();
475
476        let sum_weights: f64 = weights.iter().sum();
477        if sum_weights == 0.0 {
478            return vec![1.0 / weights.len() as f64; weights.len()];
479        }
480
481        weights.iter().map(|&w| w / sum_weights).collect()
482    }
483
484    /// Calculate model-averaged prediction using IC weights
485    pub fn model_averaged_prediction(
486        &self,
487        predictions: &[Array1<f64>],
488        weights: &[f64],
489    ) -> Result<Array1<f64>> {
490        if predictions.is_empty() {
491            return Err(SklearsError::InvalidInput(
492                "No predictions provided".to_string(),
493            ));
494        }
495
496        if predictions.len() != weights.len() {
497            return Err(SklearsError::InvalidInput(
498                "Number of predictions must match number of weights".to_string(),
499            ));
500        }
501
502        let n_samples = predictions[0].len();
503        for pred in predictions {
504            if pred.len() != n_samples {
505                return Err(SklearsError::InvalidInput(
506                    "All predictions must have the same length".to_string(),
507                ));
508            }
509        }
510
511        let mut averaged = Array1::zeros(n_samples);
512        for (pred, &weight) in predictions.iter().zip(weights.iter()) {
513            averaged = averaged + pred * weight;
514        }
515
516        Ok(averaged)
517    }
518}
519
/// Types of information criteria supported by `compare_models`
#[derive(Debug, Clone, Copy)]
pub enum InformationCriterion {
    /// Akaike Information Criterion: 2k - 2ln(L)
    AIC,
    /// Corrected AIC for finite samples
    AICc,
    /// Bayesian Information Criterion: k·ln(n) - 2ln(L)
    BIC,
}
530
531// Utility functions
/// Numerically stable log of the mean of exponentials:
/// `log((1/n) · Σ exp(v_i))`, computed via the log-sum-exp trick.
///
/// Returns `f64::NEG_INFINITY` for an empty slice and when every input is
/// `-inf` (both cases previously produced NaN: `ln(0/0)` and `(-inf) - (-inf)`
/// respectively).
fn log_mean_exp(values: &[f64]) -> f64 {
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    if max_val == f64::NEG_INFINITY {
        // Covers both the empty slice and the all-`-inf` input.
        return f64::NEG_INFINITY;
    }
    // Shifting by the maximum keeps every exponent <= 0, avoiding overflow.
    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();
    max_val + (sum / values.len() as f64).ln()
}
537
/// Model selection using information criteria with cross-validation
///
/// Aggregates per-fold log-likelihoods before applying a standard criterion.
pub struct CrossValidatedIC {
    // Criterion applied after aggregating the folds.
    criterion: InformationCriterion,
    // Expected number of cross-validation folds per model.
    n_folds: usize,
}
543
544impl CrossValidatedIC {
545    /// Create new cross-validated IC selector
546    pub fn new(criterion: InformationCriterion, n_folds: usize) -> Self {
547        Self { criterion, n_folds }
548    }
549
550    /// Select best model using cross-validated IC
551    pub fn select_model(
552        &self,
553        cv_results: &[(String, Vec<f64>, Vec<usize>, Vec<usize>)], // (name, cv_log_likes, cv_n_params, cv_n_data)
554    ) -> Result<ModelComparisonResult> {
555        let calculator = InformationCriterionCalculator::new();
556        let mut aggregated_models = Vec::new();
557
558        for (name, cv_log_likes, cv_n_params, cv_n_data) in cv_results {
559            if cv_log_likes.len() != self.n_folds {
560                return Err(SklearsError::InvalidInput(
561                    "CV results must match number of folds".to_string(),
562                ));
563            }
564
565            // Aggregate across folds
566            let total_log_likelihood: f64 = cv_log_likes.iter().sum();
567            let avg_n_params = cv_n_params.iter().sum::<usize>() / cv_n_params.len();
568            let total_n_data = cv_n_data.iter().sum::<usize>();
569
570            aggregated_models.push((
571                name.clone(),
572                total_log_likelihood,
573                avg_n_params,
574                total_n_data,
575            ));
576        }
577
578        calculator.compare_models(&aggregated_models, self.criterion)
579    }
580}
581
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // AIC = 2k - 2 ln L, with the metadata passed through unchanged.
    #[test]
    fn test_aic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.aic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "AIC");
        assert_eq!(result.value, 210.0); // 2*5 - 2*(-100)
        assert_eq!(result.n_parameters, 5);
        assert_eq!(result.n_data_points, 200);
    }

    // The finite-sample correction makes AICc strictly larger than AIC.
    #[test]
    fn test_aicc_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.aicc(-100.0, 5, 20).unwrap();

        assert_eq!(result.criterion_name, "AICc");
        assert!(result.value > 210.0); // Should be higher than AIC due to correction
    }

    // With n = 200 (> e^2), BIC's k*ln(n) penalty exceeds AIC's 2k.
    #[test]
    fn test_bic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.bic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "BIC");
        assert!(result.value > 210.0); // BIC penalizes complexity more than AIC
    }

    // The model with both the best log-likelihood and fewest parameters
    // must win, and the Akaike weights must sum to 1.
    #[test]
    fn test_model_comparison() {
        let calculator = InformationCriterionCalculator::new();
        let models = vec![
            ("Model1".to_string(), -95.0, 3, 100),
            ("Model2".to_string(), -100.0, 5, 100),
            ("Model3".to_string(), -98.0, 4, 100),
        ];

        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .unwrap();

        assert_eq!(result.model_names.len(), 3);
        assert_eq!(result.best_model(), "Model1"); // Best log-likelihood with fewest parameters
        assert!((result.weights.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    // Weights must decrease with delta and be normalized.
    #[test]
    fn test_akaike_weights() {
        let calculator = InformationCriterionCalculator::new();
        let delta_values = vec![0.0, 2.0, 4.0]; // Best model, 2 AIC units worse, 4 AIC units worse
        let weights = calculator.calculate_akaike_weights(&delta_values);

        assert!(weights[0] > weights[1]); // Best model should have highest weight
        assert!(weights[1] > weights[2]); // Second best should be better than worst
        assert!((weights.iter().sum::<f64>() - 1.0).abs() < 1e-6); // Weights sum to 1
    }

    // Smoke test: WAIC runs on a small sample matrix and reports an
    // effective parameter count.
    #[test]
    fn test_waic_calculation() {
        let calculator = InformationCriterionCalculator::new();

        // Mock pointwise log-likelihoods: 3 samples, 5 data points
        let pointwise_ll = array![
            [-1.0, -1.2, -0.9, -1.1, -1.0],
            [-1.1, -1.0, -1.0, -1.0, -0.9],
            [-0.9, -1.1, -1.1, -0.9, -1.1]
        ];

        let result = calculator.waic(&pointwise_ll).unwrap();
        assert_eq!(result.criterion_name, "WAIC");
        assert!(result.effective_parameters.is_some());
    }

    // Ranking is ascending by IC value: best first, worst last.
    #[test]
    fn test_model_ranking() {
        let models = vec![
            ("ModelA".to_string(), -100.0, 5, 100),
            ("ModelB".to_string(), -95.0, 3, 100),
            ("ModelC".to_string(), -98.0, 4, 100),
        ];

        let calculator = InformationCriterionCalculator::new();
        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .unwrap();
        let ranking = result.model_ranking();

        assert_eq!(ranking[0].1, "ModelB"); // Best model
        assert_eq!(ranking[2].1, "ModelA"); // Worst model
    }

    // The averaged prediction is the element-wise weighted sum.
    #[test]
    fn test_model_averaged_prediction() {
        let calculator = InformationCriterionCalculator::new();

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let pred3 = array![0.9, 1.9, 2.9];

        let predictions = vec![pred1, pred2, pred3];
        let weights = vec![0.5, 0.3, 0.2];

        let averaged = calculator
            .model_averaged_prediction(&predictions, &weights)
            .unwrap();
        assert_eq!(averaged.len(), 3);

        // Check weighted average
        let expected_0 = 1.0 * 0.5 + 1.1 * 0.3 + 0.9 * 0.2;
        assert!((averaged[0] - expected_0).abs() < 1e-10);
    }
}