use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use std::fmt::Debug;
20
/// Result of evaluating a single information criterion for one model.
#[derive(Debug, Clone)]
pub struct InformationCriterionResult {
    /// Name of the criterion, e.g. "AIC", "BIC", "WAIC".
    pub criterion_name: String,
    /// Computed criterion value (lower is better).
    pub value: f64,
    /// Log-likelihood used in the computation.
    pub log_likelihood: f64,
    /// Number of free parameters (0 when not applicable, e.g. DIC/WAIC).
    pub n_parameters: usize,
    /// Number of data points the model was evaluated on.
    pub n_data_points: usize,
    /// Estimated effective number of parameters (e.g. p_D for DIC,
    /// p_WAIC for WAIC), when the criterion provides one.
    pub effective_parameters: Option<f64>,
    /// Standard error of the criterion value, when available.
    pub standard_error: Option<f64>,
    /// Akaike-style model weight, when computed during model comparison.
    pub weight: Option<f64>,
}
41
42impl InformationCriterionResult {
43 pub fn new(
45 criterion_name: String,
46 value: f64,
47 log_likelihood: f64,
48 n_parameters: usize,
49 n_data_points: usize,
50 ) -> Self {
51 Self {
52 criterion_name,
53 value,
54 log_likelihood,
55 n_parameters,
56 n_data_points,
57 effective_parameters: None,
58 standard_error: None,
59 weight: None,
60 }
61 }
62
63 pub fn with_effective_parameters(mut self, p_eff: f64) -> Self {
65 self.effective_parameters = Some(p_eff);
66 self
67 }
68
69 pub fn with_standard_error(mut self, se: f64) -> Self {
71 self.standard_error = Some(se);
72 self
73 }
74
75 pub fn with_weight(mut self, weight: f64) -> Self {
77 self.weight = Some(weight);
78 self
79 }
80}
81
/// Outcome of comparing several models under one information criterion.
#[derive(Debug, Clone)]
pub struct ModelComparisonResult {
    /// Names of the compared models, in input order.
    pub model_names: Vec<String>,
    /// Per-model criterion results, aligned with `model_names`.
    pub results: Vec<InformationCriterionResult>,
    /// Criterion differences Δᵢ = valueᵢ − best value (0 for the best model).
    pub delta_values: Vec<f64>,
    /// Akaike weights (all zeros when weight calculation is disabled).
    pub weights: Vec<f64>,
    /// Index of the model with the lowest criterion value.
    pub best_model_index: usize,
    /// Evidence ratio derived from the second-smallest Δ value.
    pub evidence_ratio: f64,
}
98
99impl ModelComparisonResult {
100 pub fn best_model(&self) -> &str {
102 &self.model_names[self.best_model_index]
103 }
104
105 pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
107 let mut ranking: Vec<(usize, &str, f64)> = self
108 .model_names
109 .iter()
110 .enumerate()
111 .map(|(i, name)| (i, name.as_str(), self.results[i].value))
112 .collect();
113
114 ranking.sort_by(|a, b| a.2.partial_cmp(&b.2).expect("operation should succeed"));
115 ranking
116 }
117
118 pub fn model_strength_interpretation(&self, model_idx: usize) -> String {
120 let delta = self.delta_values[model_idx];
121 match delta {
122 d if d <= 2.0 => "Substantial support".to_string(),
123 d if d <= 4.0 => "Considerably less support".to_string(),
124 d if d <= 7.0 => "Little support".to_string(),
125 _ => "No support".to_string(),
126 }
127 }
128}
129
/// Computes information criteria (AIC, AICc, BIC, DIC, WAIC, LOOIC, TIC)
/// and compares models with them.
pub struct InformationCriterionCalculator {
    /// Whether small-sample bias corrections should be applied.
    /// NOTE(review): not read by any method in this file — confirm intended use.
    pub use_bias_correction: bool,
    /// Whether `compare_models` computes Akaike weights (zeros otherwise).
    pub calculate_weights: bool,
}
137
138impl Default for InformationCriterionCalculator {
139 fn default() -> Self {
140 Self {
141 use_bias_correction: true,
142 calculate_weights: true,
143 }
144 }
145}
146
147impl InformationCriterionCalculator {
148 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn aic(
156 &self,
157 log_likelihood: f64,
158 n_parameters: usize,
159 n_data_points: usize,
160 ) -> InformationCriterionResult {
161 let k = n_parameters as f64;
162 let aic_value = 2.0 * k - 2.0 * log_likelihood;
163
164 InformationCriterionResult::new(
165 "AIC".to_string(),
166 aic_value,
167 log_likelihood,
168 n_parameters,
169 n_data_points,
170 )
171 }
172
173 pub fn aicc(
176 &self,
177 log_likelihood: f64,
178 n_parameters: usize,
179 n_data_points: usize,
180 ) -> Result<InformationCriterionResult> {
181 let k = n_parameters as f64;
182 let n = n_data_points as f64;
183
184 if n <= k + 1.0 {
185 return Err(SklearsError::InvalidInput(
186 "AICc requires n > k + 1".to_string(),
187 ));
188 }
189
190 let aic_value = 2.0 * k - 2.0 * log_likelihood;
191 let correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
192 let aicc_value = aic_value + correction;
193
194 Ok(InformationCriterionResult::new(
195 "AICc".to_string(),
196 aicc_value,
197 log_likelihood,
198 n_parameters,
199 n_data_points,
200 ))
201 }
202
203 pub fn bic(
206 &self,
207 log_likelihood: f64,
208 n_parameters: usize,
209 n_data_points: usize,
210 ) -> InformationCriterionResult {
211 let k = n_parameters as f64;
212 let n = n_data_points as f64;
213 let bic_value = k * n.ln() - 2.0 * log_likelihood;
214
215 InformationCriterionResult::new(
216 "BIC".to_string(),
217 bic_value,
218 log_likelihood,
219 n_parameters,
220 n_data_points,
221 )
222 }
223
224 pub fn dic(
227 &self,
228 log_likelihood_mean: f64,
229 log_likelihood_samples: &[f64],
230 n_data_points: usize,
231 ) -> Result<InformationCriterionResult> {
232 if log_likelihood_samples.is_empty() {
233 return Err(SklearsError::InvalidInput(
234 "DIC requires posterior samples".to_string(),
235 ));
236 }
237
238 let deviance_mean = -2.0 * log_likelihood_mean;
240
241 let mean_deviance =
243 -2.0 * log_likelihood_samples.iter().sum::<f64>() / log_likelihood_samples.len() as f64;
244
245 let p_d = mean_deviance - deviance_mean;
247
248 let dic_value = deviance_mean + p_d;
250
251 Ok(InformationCriterionResult::new(
252 "DIC".to_string(),
253 dic_value,
254 log_likelihood_mean,
255 0, n_data_points,
257 )
258 .with_effective_parameters(p_d))
259 }
260
261 pub fn waic(
264 &self,
265 pointwise_log_likelihoods: &Array2<f64>, ) -> Result<InformationCriterionResult> {
267 let (n_samples, n_data) = pointwise_log_likelihoods.dim();
268
269 if n_samples == 0 || n_data == 0 {
270 return Err(SklearsError::InvalidInput(
271 "WAIC requires non-empty likelihood matrix".to_string(),
272 ));
273 }
274
275 let mut lppd = 0.0;
277 let mut p_waic = 0.0;
278
279 for j in 0..n_data {
280 let column = pointwise_log_likelihoods.column(j);
281
282 let column_data: Vec<f64> = column.iter().copied().collect();
284 let log_mean_likelihood = log_mean_exp(&column_data);
285 lppd += log_mean_likelihood;
286
287 let mean_log_likelihood = column.mean().expect("operation should succeed");
289 let variance = column
290 .iter()
291 .map(|&x| (x - mean_log_likelihood).powi(2))
292 .sum::<f64>()
293 / (n_samples - 1) as f64;
294 p_waic += variance;
295 }
296
297 let waic_value = -2.0 * (lppd - p_waic);
298
299 let total_log_likelihood = pointwise_log_likelihoods.sum();
301
302 Ok(InformationCriterionResult::new(
303 "WAIC".to_string(),
304 waic_value,
305 total_log_likelihood,
306 0, n_data,
308 )
309 .with_effective_parameters(p_waic))
310 }
311
312 pub fn looic(
314 &self,
315 pointwise_log_likelihoods: &Array2<f64>,
316 pareto_k_diagnostics: Option<&Array1<f64>>,
317 ) -> Result<InformationCriterionResult> {
318 let (n_samples, n_data) = pointwise_log_likelihoods.dim();
319
320 if n_samples == 0 || n_data == 0 {
321 return Err(SklearsError::InvalidInput(
322 "LOOIC requires non-empty likelihood matrix".to_string(),
323 ));
324 }
325
326 let mut elpd_loo = 0.0; let mut p_loo = 0.0; for j in 0..n_data {
330 let column = pointwise_log_likelihoods.column(j);
331 let log_likes = column.as_slice().expect("operation should succeed");
332
333 if let Some(k_values) = pareto_k_diagnostics {
335 if k_values[j] > 0.7 {
336 eprintln!(
337 "Warning: High Pareto k ({:.3}) for observation {}",
338 k_values[j], j
339 );
340 }
341 }
342
343 let max_log_like = log_likes.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
345 let rel_log_likes: Vec<f64> = log_likes.iter().map(|&x| x - max_log_like).collect();
346
347 let weights: Vec<f64> = rel_log_likes.iter().map(|&x| x.exp()).collect();
349
350 let sum_weights: f64 = weights.iter().sum();
351 if sum_weights == 0.0 {
352 return Err(SklearsError::InvalidInput(
353 "Zero importance weights".to_string(),
354 ));
355 }
356
357 let _normalized_weights: Vec<f64> = weights.iter().map(|&w| w / sum_weights).collect();
358
359 let loo_lpd = (sum_weights / n_samples as f64).ln() + max_log_like;
361 elpd_loo += loo_lpd;
362
363 let mean_log_like = log_likes.iter().sum::<f64>() / n_samples as f64;
365 p_loo += mean_log_like - loo_lpd;
366 }
367
368 let looic_value = -2.0 * elpd_loo;
369
370 Ok(InformationCriterionResult::new(
371 "LOOIC".to_string(),
372 looic_value,
373 0.0, 0, n_data,
376 )
377 .with_effective_parameters(p_loo))
378 }
379
380 pub fn tic(
383 &self,
384 log_likelihood: f64,
385 fisher_information_trace: f64,
386 n_data_points: usize,
387 ) -> Result<InformationCriterionResult> {
388 if fisher_information_trace <= 0.0 {
389 return Err(SklearsError::InvalidInput(
390 "Fisher information trace must be positive".to_string(),
391 ));
392 }
393
394 let tic_value = -2.0 * log_likelihood + 2.0 * fisher_information_trace;
395
396 Ok(InformationCriterionResult::new(
397 "TIC".to_string(),
398 tic_value,
399 log_likelihood,
400 0, n_data_points,
402 )
403 .with_effective_parameters(fisher_information_trace))
404 }
405
406 pub fn compare_models(
408 &self,
409 models: &[(String, f64, usize, usize)], criterion: InformationCriterion,
411 ) -> Result<ModelComparisonResult> {
412 if models.is_empty() {
413 return Err(SklearsError::InvalidInput("No models provided".to_string()));
414 }
415
416 let mut results = Vec::new();
417 let model_names: Vec<String> = models.iter().map(|(name, _, _, _)| name.clone()).collect();
418
419 for (_name, log_likelihood, n_params, n_data) in models {
421 let result = match criterion {
422 InformationCriterion::AIC => self.aic(*log_likelihood, *n_params, *n_data),
423 InformationCriterion::AICc => self.aicc(*log_likelihood, *n_params, *n_data)?,
424 InformationCriterion::BIC => self.bic(*log_likelihood, *n_params, *n_data),
425 };
426 results.push(result);
427 }
428
429 let best_idx = results
431 .iter()
432 .enumerate()
433 .min_by(|(_, a), (_, b)| {
434 a.value
435 .partial_cmp(&b.value)
436 .expect("operation should succeed")
437 })
438 .map(|(i, _)| i)
439 .expect("operation should succeed");
440
441 let best_value = results[best_idx].value;
442
443 let delta_values: Vec<f64> = results.iter().map(|r| r.value - best_value).collect();
445
446 let weights = if self.calculate_weights {
448 self.calculate_akaike_weights(&delta_values)
449 } else {
450 vec![0.0; results.len()]
451 };
452
453 let mut sorted_deltas = delta_values.clone();
455 sorted_deltas.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
456 let evidence_ratio = if sorted_deltas.len() > 1 {
457 (-0.5 * sorted_deltas[1]).exp()
458 } else {
459 1.0
460 };
461
462 Ok(ModelComparisonResult {
463 model_names,
464 results,
465 delta_values,
466 weights,
467 best_model_index: best_idx,
468 evidence_ratio,
469 })
470 }
471
472 fn calculate_akaike_weights(&self, delta_values: &[f64]) -> Vec<f64> {
474 let weights: Vec<f64> = delta_values
476 .iter()
477 .map(|&delta| (-0.5 * delta).exp())
478 .collect();
479
480 let sum_weights: f64 = weights.iter().sum();
481 if sum_weights == 0.0 {
482 return vec![1.0 / weights.len() as f64; weights.len()];
483 }
484
485 weights.iter().map(|&w| w / sum_weights).collect()
486 }
487
488 pub fn model_averaged_prediction(
490 &self,
491 predictions: &[Array1<f64>],
492 weights: &[f64],
493 ) -> Result<Array1<f64>> {
494 if predictions.is_empty() {
495 return Err(SklearsError::InvalidInput(
496 "No predictions provided".to_string(),
497 ));
498 }
499
500 if predictions.len() != weights.len() {
501 return Err(SklearsError::InvalidInput(
502 "Number of predictions must match number of weights".to_string(),
503 ));
504 }
505
506 let n_samples = predictions[0].len();
507 for pred in predictions {
508 if pred.len() != n_samples {
509 return Err(SklearsError::InvalidInput(
510 "All predictions must have the same length".to_string(),
511 ));
512 }
513 }
514
515 let mut averaged = Array1::zeros(n_samples);
516 for (pred, &weight) in predictions.iter().zip(weights.iter()) {
517 averaged = averaged + pred * weight;
518 }
519
520 Ok(averaged)
521 }
522}
523
/// Closed-form information criteria available for model comparison.
#[derive(Debug, Clone, Copy)]
pub enum InformationCriterion {
    /// Akaike Information Criterion: 2k − 2·ln(L).
    AIC,
    /// Small-sample corrected AIC (requires n > k + 1).
    AICc,
    /// Bayesian Information Criterion: k·ln(n) − 2·ln(L).
    BIC,
}
534
/// Computes `log(mean(exp(values)))` in a numerically stable way by
/// shifting everything by the maximum before exponentiating
/// (the log-sum-exp trick). Assumes `values` is non-empty.
fn log_mean_exp(values: &[f64]) -> f64 {
    let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mean_shifted =
        values.iter().map(|&v| (v - max_val).exp()).sum::<f64>() / values.len() as f64;
    max_val + mean_shifted.ln()
}
541
/// Model selection that aggregates per-fold cross-validation results and
/// compares the aggregates with a single information criterion.
pub struct CrossValidatedIC {
    /// Criterion used for the final comparison.
    criterion: InformationCriterion,
    /// Expected number of CV folds per model.
    n_folds: usize,
}
547
548impl CrossValidatedIC {
549 pub fn new(criterion: InformationCriterion, n_folds: usize) -> Self {
551 Self { criterion, n_folds }
552 }
553
554 pub fn select_model(
556 &self,
557 cv_results: &[(String, Vec<f64>, Vec<usize>, Vec<usize>)], ) -> Result<ModelComparisonResult> {
559 let calculator = InformationCriterionCalculator::new();
560 let mut aggregated_models = Vec::new();
561
562 for (name, cv_log_likes, cv_n_params, cv_n_data) in cv_results {
563 if cv_log_likes.len() != self.n_folds {
564 return Err(SklearsError::InvalidInput(
565 "CV results must match number of folds".to_string(),
566 ));
567 }
568
569 let total_log_likelihood: f64 = cv_log_likes.iter().sum();
571 let avg_n_params = cv_n_params.iter().sum::<usize>() / cv_n_params.len();
572 let total_n_data = cv_n_data.iter().sum::<usize>();
573
574 aggregated_models.push((
575 name.clone(),
576 total_log_likelihood,
577 avg_n_params,
578 total_n_data,
579 ));
580 }
581
582 calculator.compare_models(&aggregated_models, self.criterion)
583 }
584}
585
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_aic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.aic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "AIC");
        // AIC = 2k - 2 ln L = 2*5 - 2*(-100) = 210.
        assert_eq!(result.value, 210.0);
        assert_eq!(result.n_parameters, 5);
        assert_eq!(result.n_data_points, 200);
    }

    #[test]
    fn test_aicc_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator
            .aicc(-100.0, 5, 20)
            .expect("operation should succeed");

        assert_eq!(result.criterion_name, "AICc");
        // The small-sample correction term is strictly positive, so AICc > AIC.
        assert!(result.value > 210.0);
    }

    #[test]
    fn test_bic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.bic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "BIC");
        // ln(200) > 2, so the BIC penalty exceeds AIC's penalty here.
        assert!(result.value > 210.0);
    }

    #[test]
    fn test_model_comparison() {
        let calculator = InformationCriterionCalculator::new();
        // (name, log-likelihood, n_parameters, n_data_points)
        let models = vec![
            ("Model1".to_string(), -95.0, 3, 100),
            ("Model2".to_string(), -100.0, 5, 100),
            ("Model3".to_string(), -98.0, 4, 100),
        ];

        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .expect("operation should succeed");

        assert_eq!(result.model_names.len(), 3);
        // Model1 has both the highest likelihood and the fewest parameters.
        assert_eq!(result.best_model(), "Model1");
        // Akaike weights are normalized to sum to one.
        assert!((result.weights.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_akaike_weights() {
        let calculator = InformationCriterionCalculator::new();
        let delta_values = vec![0.0, 2.0, 4.0];
        let weights = calculator.calculate_akaike_weights(&delta_values);

        // Smaller delta => larger weight; weights sum to one.
        assert!(weights[0] > weights[1]);
        assert!(weights[1] > weights[2]);
        assert!((weights.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_waic_calculation() {
        let calculator = InformationCriterionCalculator::new();

        // 3 posterior samples x 5 observations of pointwise log-likelihood.
        let pointwise_ll = array![
            [-1.0, -1.2, -0.9, -1.1, -1.0],
            [-1.1, -1.0, -1.0, -1.0, -0.9],
            [-0.9, -1.1, -1.1, -0.9, -1.1]
        ];

        let result = calculator
            .waic(&pointwise_ll)
            .expect("operation should succeed");
        assert_eq!(result.criterion_name, "WAIC");
        assert!(result.effective_parameters.is_some());
    }

    #[test]
    fn test_model_ranking() {
        let models = vec![
            ("ModelA".to_string(), -100.0, 5, 100),
            ("ModelB".to_string(), -95.0, 3, 100),
            ("ModelC".to_string(), -98.0, 4, 100),
        ];

        let calculator = InformationCriterionCalculator::new();
        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .expect("operation should succeed");
        let ranking = result.model_ranking();

        // Ranking is ascending by criterion value: best first, worst last.
        assert_eq!(ranking[0].1, "ModelB");
        assert_eq!(ranking[2].1, "ModelA");
    }

    #[test]
    fn test_model_averaged_prediction() {
        let calculator = InformationCriterionCalculator::new();

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let pred3 = array![0.9, 1.9, 2.9];

        let predictions = vec![pred1, pred2, pred3];
        let weights = vec![0.5, 0.3, 0.2];

        let averaged = calculator
            .model_averaged_prediction(&predictions, &weights)
            .expect("operation should succeed");
        assert_eq!(averaged.len(), 3);

        // Element-wise weighted average of the three predictions.
        let expected_0 = 1.0 * 0.5 + 1.1 * 0.3 + 0.9 * 0.2;
        assert!((averaged[0] - expected_0).abs() < 1e-10);
    }
}