1use scirs2_core::ndarray::{Array1, Array2};
17use sklears_core::error::{Result, SklearsError};
19use std::fmt::Debug;
20
/// Result of evaluating a single information criterion for one model.
#[derive(Debug, Clone)]
pub struct InformationCriterionResult {
    /// Name of the criterion (e.g. "AIC", "BIC", "WAIC").
    pub criterion_name: String,
    /// Computed criterion value (lower is better).
    pub value: f64,
    /// Log-likelihood the criterion was computed from.
    pub log_likelihood: f64,
    /// Nominal number of model parameters (0 when only an effective count is known).
    pub n_parameters: usize,
    /// Number of data points the model was evaluated on.
    pub n_data_points: usize,
    /// Effective number of parameters (e.g. p_D for DIC, p_WAIC for WAIC), when computed.
    pub effective_parameters: Option<f64>,
    /// Standard error of the criterion value, when available.
    pub standard_error: Option<f64>,
    /// Akaike-style model weight, when set during model comparison.
    pub weight: Option<f64>,
}
41
42impl InformationCriterionResult {
43 pub fn new(
45 criterion_name: String,
46 value: f64,
47 log_likelihood: f64,
48 n_parameters: usize,
49 n_data_points: usize,
50 ) -> Self {
51 Self {
52 criterion_name,
53 value,
54 log_likelihood,
55 n_parameters,
56 n_data_points,
57 effective_parameters: None,
58 standard_error: None,
59 weight: None,
60 }
61 }
62
63 pub fn with_effective_parameters(mut self, p_eff: f64) -> Self {
65 self.effective_parameters = Some(p_eff);
66 self
67 }
68
69 pub fn with_standard_error(mut self, se: f64) -> Self {
71 self.standard_error = Some(se);
72 self
73 }
74
75 pub fn with_weight(mut self, weight: f64) -> Self {
77 self.weight = Some(weight);
78 self
79 }
80}
81
/// Outcome of comparing several models under one information criterion.
#[derive(Debug, Clone)]
pub struct ModelComparisonResult {
    /// Names of the compared models, in input order.
    pub model_names: Vec<String>,
    /// Per-model criterion results, aligned with `model_names`.
    pub results: Vec<InformationCriterionResult>,
    /// Differences relative to the best model: delta_i = IC_i - IC_best.
    pub delta_values: Vec<f64>,
    /// Akaike-style weights (all zero when weight calculation is disabled).
    pub weights: Vec<f64>,
    /// Index of the best (lowest criterion value) model.
    pub best_model_index: usize,
    /// Evidence ratio comparing the best model against the runner-up.
    pub evidence_ratio: f64,
}
98
99impl ModelComparisonResult {
100 pub fn best_model(&self) -> &str {
102 &self.model_names[self.best_model_index]
103 }
104
105 pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
107 let mut ranking: Vec<(usize, &str, f64)> = self
108 .model_names
109 .iter()
110 .enumerate()
111 .map(|(i, name)| (i, name.as_str(), self.results[i].value))
112 .collect();
113
114 ranking.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap());
115 ranking
116 }
117
118 pub fn model_strength_interpretation(&self, model_idx: usize) -> String {
120 let delta = self.delta_values[model_idx];
121 match delta {
122 d if d <= 2.0 => "Substantial support".to_string(),
123 d if d <= 4.0 => "Considerably less support".to_string(),
124 d if d <= 7.0 => "Little support".to_string(),
125 _ => "No support".to_string(),
126 }
127 }
128}
129
/// Calculator for information criteria (AIC, AICc, BIC, DIC, WAIC, LOOIC, TIC).
pub struct InformationCriterionCalculator {
    /// Whether to apply small-sample bias corrections.
    /// NOTE(review): not read by any method visible in this file — confirm intended use.
    pub use_bias_correction: bool,
    /// Whether `compare_models` computes Akaike weights (otherwise weights are all zero).
    pub calculate_weights: bool,
}
137
138impl Default for InformationCriterionCalculator {
139 fn default() -> Self {
140 Self {
141 use_bias_correction: true,
142 calculate_weights: true,
143 }
144 }
145}
146
147impl InformationCriterionCalculator {
148 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn aic(
156 &self,
157 log_likelihood: f64,
158 n_parameters: usize,
159 n_data_points: usize,
160 ) -> InformationCriterionResult {
161 let k = n_parameters as f64;
162 let aic_value = 2.0 * k - 2.0 * log_likelihood;
163
164 InformationCriterionResult::new(
165 "AIC".to_string(),
166 aic_value,
167 log_likelihood,
168 n_parameters,
169 n_data_points,
170 )
171 }
172
173 pub fn aicc(
176 &self,
177 log_likelihood: f64,
178 n_parameters: usize,
179 n_data_points: usize,
180 ) -> Result<InformationCriterionResult> {
181 let k = n_parameters as f64;
182 let n = n_data_points as f64;
183
184 if n <= k + 1.0 {
185 return Err(SklearsError::InvalidInput(
186 "AICc requires n > k + 1".to_string(),
187 ));
188 }
189
190 let aic_value = 2.0 * k - 2.0 * log_likelihood;
191 let correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
192 let aicc_value = aic_value + correction;
193
194 Ok(InformationCriterionResult::new(
195 "AICc".to_string(),
196 aicc_value,
197 log_likelihood,
198 n_parameters,
199 n_data_points,
200 ))
201 }
202
203 pub fn bic(
206 &self,
207 log_likelihood: f64,
208 n_parameters: usize,
209 n_data_points: usize,
210 ) -> InformationCriterionResult {
211 let k = n_parameters as f64;
212 let n = n_data_points as f64;
213 let bic_value = k * n.ln() - 2.0 * log_likelihood;
214
215 InformationCriterionResult::new(
216 "BIC".to_string(),
217 bic_value,
218 log_likelihood,
219 n_parameters,
220 n_data_points,
221 )
222 }
223
224 pub fn dic(
227 &self,
228 log_likelihood_mean: f64,
229 log_likelihood_samples: &[f64],
230 n_data_points: usize,
231 ) -> Result<InformationCriterionResult> {
232 if log_likelihood_samples.is_empty() {
233 return Err(SklearsError::InvalidInput(
234 "DIC requires posterior samples".to_string(),
235 ));
236 }
237
238 let deviance_mean = -2.0 * log_likelihood_mean;
240
241 let mean_deviance =
243 -2.0 * log_likelihood_samples.iter().sum::<f64>() / log_likelihood_samples.len() as f64;
244
245 let p_d = mean_deviance - deviance_mean;
247
248 let dic_value = deviance_mean + p_d;
250
251 Ok(InformationCriterionResult::new(
252 "DIC".to_string(),
253 dic_value,
254 log_likelihood_mean,
255 0, n_data_points,
257 )
258 .with_effective_parameters(p_d))
259 }
260
261 pub fn waic(
264 &self,
265 pointwise_log_likelihoods: &Array2<f64>, ) -> Result<InformationCriterionResult> {
267 let (n_samples, n_data) = pointwise_log_likelihoods.dim();
268
269 if n_samples == 0 || n_data == 0 {
270 return Err(SklearsError::InvalidInput(
271 "WAIC requires non-empty likelihood matrix".to_string(),
272 ));
273 }
274
275 let mut lppd = 0.0;
277 let mut p_waic = 0.0;
278
279 for j in 0..n_data {
280 let column = pointwise_log_likelihoods.column(j);
281
282 let column_data: Vec<f64> = column.iter().copied().collect();
284 let log_mean_likelihood = log_mean_exp(&column_data);
285 lppd += log_mean_likelihood;
286
287 let mean_log_likelihood = column.mean().unwrap();
289 let variance = column
290 .iter()
291 .map(|&x| (x - mean_log_likelihood).powi(2))
292 .sum::<f64>()
293 / (n_samples - 1) as f64;
294 p_waic += variance;
295 }
296
297 let waic_value = -2.0 * (lppd - p_waic);
298
299 let total_log_likelihood = pointwise_log_likelihoods.sum();
301
302 Ok(InformationCriterionResult::new(
303 "WAIC".to_string(),
304 waic_value,
305 total_log_likelihood,
306 0, n_data,
308 )
309 .with_effective_parameters(p_waic))
310 }
311
312 pub fn looic(
314 &self,
315 pointwise_log_likelihoods: &Array2<f64>,
316 pareto_k_diagnostics: Option<&Array1<f64>>,
317 ) -> Result<InformationCriterionResult> {
318 let (n_samples, n_data) = pointwise_log_likelihoods.dim();
319
320 if n_samples == 0 || n_data == 0 {
321 return Err(SklearsError::InvalidInput(
322 "LOOIC requires non-empty likelihood matrix".to_string(),
323 ));
324 }
325
326 let mut elpd_loo = 0.0; let mut p_loo = 0.0; for j in 0..n_data {
330 let column = pointwise_log_likelihoods.column(j);
331 let log_likes = column.as_slice().unwrap();
332
333 if let Some(k_values) = pareto_k_diagnostics {
335 if k_values[j] > 0.7 {
336 eprintln!(
337 "Warning: High Pareto k ({:.3}) for observation {}",
338 k_values[j], j
339 );
340 }
341 }
342
343 let max_log_like = log_likes.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
345 let rel_log_likes: Vec<f64> = log_likes.iter().map(|&x| x - max_log_like).collect();
346
347 let weights: Vec<f64> = rel_log_likes.iter().map(|&x| x.exp()).collect();
349
350 let sum_weights: f64 = weights.iter().sum();
351 if sum_weights == 0.0 {
352 return Err(SklearsError::InvalidInput(
353 "Zero importance weights".to_string(),
354 ));
355 }
356
357 let _normalized_weights: Vec<f64> = weights.iter().map(|&w| w / sum_weights).collect();
358
359 let loo_lpd = (sum_weights / n_samples as f64).ln() + max_log_like;
361 elpd_loo += loo_lpd;
362
363 let mean_log_like = log_likes.iter().sum::<f64>() / n_samples as f64;
365 p_loo += mean_log_like - loo_lpd;
366 }
367
368 let looic_value = -2.0 * elpd_loo;
369
370 Ok(InformationCriterionResult::new(
371 "LOOIC".to_string(),
372 looic_value,
373 0.0, 0, n_data,
376 )
377 .with_effective_parameters(p_loo))
378 }
379
380 pub fn tic(
383 &self,
384 log_likelihood: f64,
385 fisher_information_trace: f64,
386 n_data_points: usize,
387 ) -> Result<InformationCriterionResult> {
388 if fisher_information_trace <= 0.0 {
389 return Err(SklearsError::InvalidInput(
390 "Fisher information trace must be positive".to_string(),
391 ));
392 }
393
394 let tic_value = -2.0 * log_likelihood + 2.0 * fisher_information_trace;
395
396 Ok(InformationCriterionResult::new(
397 "TIC".to_string(),
398 tic_value,
399 log_likelihood,
400 0, n_data_points,
402 )
403 .with_effective_parameters(fisher_information_trace))
404 }
405
406 pub fn compare_models(
408 &self,
409 models: &[(String, f64, usize, usize)], criterion: InformationCriterion,
411 ) -> Result<ModelComparisonResult> {
412 if models.is_empty() {
413 return Err(SklearsError::InvalidInput("No models provided".to_string()));
414 }
415
416 let mut results = Vec::new();
417 let model_names: Vec<String> = models.iter().map(|(name, _, _, _)| name.clone()).collect();
418
419 for (_name, log_likelihood, n_params, n_data) in models {
421 let result = match criterion {
422 InformationCriterion::AIC => self.aic(*log_likelihood, *n_params, *n_data),
423 InformationCriterion::AICc => self.aicc(*log_likelihood, *n_params, *n_data)?,
424 InformationCriterion::BIC => self.bic(*log_likelihood, *n_params, *n_data),
425 };
426 results.push(result);
427 }
428
429 let best_idx = results
431 .iter()
432 .enumerate()
433 .min_by(|(_, a), (_, b)| a.value.partial_cmp(&b.value).unwrap())
434 .map(|(i, _)| i)
435 .unwrap();
436
437 let best_value = results[best_idx].value;
438
439 let delta_values: Vec<f64> = results.iter().map(|r| r.value - best_value).collect();
441
442 let weights = if self.calculate_weights {
444 self.calculate_akaike_weights(&delta_values)
445 } else {
446 vec![0.0; results.len()]
447 };
448
449 let mut sorted_deltas = delta_values.clone();
451 sorted_deltas.sort_by(|a, b| a.partial_cmp(b).unwrap());
452 let evidence_ratio = if sorted_deltas.len() > 1 {
453 (-0.5 * sorted_deltas[1]).exp()
454 } else {
455 1.0
456 };
457
458 Ok(ModelComparisonResult {
459 model_names,
460 results,
461 delta_values,
462 weights,
463 best_model_index: best_idx,
464 evidence_ratio,
465 })
466 }
467
468 fn calculate_akaike_weights(&self, delta_values: &[f64]) -> Vec<f64> {
470 let weights: Vec<f64> = delta_values
472 .iter()
473 .map(|&delta| (-0.5 * delta).exp())
474 .collect();
475
476 let sum_weights: f64 = weights.iter().sum();
477 if sum_weights == 0.0 {
478 return vec![1.0 / weights.len() as f64; weights.len()];
479 }
480
481 weights.iter().map(|&w| w / sum_weights).collect()
482 }
483
484 pub fn model_averaged_prediction(
486 &self,
487 predictions: &[Array1<f64>],
488 weights: &[f64],
489 ) -> Result<Array1<f64>> {
490 if predictions.is_empty() {
491 return Err(SklearsError::InvalidInput(
492 "No predictions provided".to_string(),
493 ));
494 }
495
496 if predictions.len() != weights.len() {
497 return Err(SklearsError::InvalidInput(
498 "Number of predictions must match number of weights".to_string(),
499 ));
500 }
501
502 let n_samples = predictions[0].len();
503 for pred in predictions {
504 if pred.len() != n_samples {
505 return Err(SklearsError::InvalidInput(
506 "All predictions must have the same length".to_string(),
507 ));
508 }
509 }
510
511 let mut averaged = Array1::zeros(n_samples);
512 for (pred, &weight) in predictions.iter().zip(weights.iter()) {
513 averaged = averaged + pred * weight;
514 }
515
516 Ok(averaged)
517 }
518}
519
/// Information criteria supported by `compare_models`.
#[derive(Debug, Clone, Copy)]
pub enum InformationCriterion {
    /// Akaike Information Criterion.
    AIC,
    /// Small-sample corrected AIC.
    AICc,
    /// Bayesian (Schwarz) Information Criterion.
    BIC,
}
530
/// Numerically stable log of the arithmetic mean of exp(values):
/// ln((1/n) * sum_i exp(v_i)), shifted by the maximum to avoid overflow.
fn log_mean_exp(values: &[f64]) -> f64 {
    // Locate the peak; NaN entries never win the comparison, matching the
    // original fold over f64::max.
    let mut peak = f64::NEG_INFINITY;
    for &v in values {
        if v > peak {
            peak = v;
        }
    }
    // Sum the shifted exponentials, then undo the shift in log space.
    let mut scaled_total = 0.0;
    for &v in values {
        scaled_total += (v - peak).exp();
    }
    peak + (scaled_total / values.len() as f64).ln()
}
537
/// Model selection via an information criterion aggregated over CV folds.
pub struct CrossValidatedIC {
    /// Criterion used for the final model comparison.
    criterion: InformationCriterion,
    /// Expected number of cross-validation folds per model.
    n_folds: usize,
}
543
544impl CrossValidatedIC {
545 pub fn new(criterion: InformationCriterion, n_folds: usize) -> Self {
547 Self { criterion, n_folds }
548 }
549
550 pub fn select_model(
552 &self,
553 cv_results: &[(String, Vec<f64>, Vec<usize>, Vec<usize>)], ) -> Result<ModelComparisonResult> {
555 let calculator = InformationCriterionCalculator::new();
556 let mut aggregated_models = Vec::new();
557
558 for (name, cv_log_likes, cv_n_params, cv_n_data) in cv_results {
559 if cv_log_likes.len() != self.n_folds {
560 return Err(SklearsError::InvalidInput(
561 "CV results must match number of folds".to_string(),
562 ));
563 }
564
565 let total_log_likelihood: f64 = cv_log_likes.iter().sum();
567 let avg_n_params = cv_n_params.iter().sum::<usize>() / cv_n_params.len();
568 let total_n_data = cv_n_data.iter().sum::<usize>();
569
570 aggregated_models.push((
571 name.clone(),
572 total_log_likelihood,
573 avg_n_params,
574 total_n_data,
575 ));
576 }
577
578 calculator.compare_models(&aggregated_models, self.criterion)
579 }
580}
581
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // AIC = 2k - 2*lnL = 2*5 - 2*(-100) = 210.
    #[test]
    fn test_aic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.aic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "AIC");
        assert_eq!(result.value, 210.0); assert_eq!(result.n_parameters, 5);
        assert_eq!(result.n_data_points, 200);
    }

    // With small n (20) the correction term is positive, so AICc > AIC.
    #[test]
    fn test_aicc_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.aicc(-100.0, 5, 20).unwrap();

        assert_eq!(result.criterion_name, "AICc");
        assert!(result.value > 210.0); }

    // ln(200) > 2, so BIC's penalty exceeds AIC's for the same model.
    #[test]
    fn test_bic_calculation() {
        let calculator = InformationCriterionCalculator::new();
        let result = calculator.bic(-100.0, 5, 200);

        assert_eq!(result.criterion_name, "BIC");
        assert!(result.value > 210.0); }

    // Model1 has the highest log-likelihood AND fewest parameters,
    // so it must win under AIC; weights must be normalized to 1.
    #[test]
    fn test_model_comparison() {
        let calculator = InformationCriterionCalculator::new();
        let models = vec![
            ("Model1".to_string(), -95.0, 3, 100),
            ("Model2".to_string(), -100.0, 5, 100),
            ("Model3".to_string(), -98.0, 4, 100),
        ];

        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .unwrap();

        assert_eq!(result.model_names.len(), 3);
        assert_eq!(result.best_model(), "Model1"); assert!((result.weights.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    // Smaller delta => larger weight; weights sum to 1.
    #[test]
    fn test_akaike_weights() {
        let calculator = InformationCriterionCalculator::new();
        let delta_values = vec![0.0, 2.0, 4.0]; let weights = calculator.calculate_akaike_weights(&delta_values);

        assert!(weights[0] > weights[1]); assert!(weights[1] > weights[2]); assert!((weights.iter().sum::<f64>() - 1.0).abs() < 1e-6); }

    // Smoke test: 3 posterior samples x 5 observations; WAIC should
    // compute and report an effective parameter count.
    #[test]
    fn test_waic_calculation() {
        let calculator = InformationCriterionCalculator::new();

        let pointwise_ll = array![
            [-1.0, -1.2, -0.9, -1.1, -1.0],
            [-1.1, -1.0, -1.0, -1.0, -0.9],
            [-0.9, -1.1, -1.1, -0.9, -1.1]
        ];

        let result = calculator.waic(&pointwise_ll).unwrap();
        assert_eq!(result.criterion_name, "WAIC");
        assert!(result.effective_parameters.is_some());
    }

    // Ranking is ascending by criterion value: ModelB (best) first,
    // ModelA (worst) last.
    #[test]
    fn test_model_ranking() {
        let models = vec![
            ("ModelA".to_string(), -100.0, 5, 100),
            ("ModelB".to_string(), -95.0, 3, 100),
            ("ModelC".to_string(), -98.0, 4, 100),
        ];

        let calculator = InformationCriterionCalculator::new();
        let result = calculator
            .compare_models(&models, InformationCriterion::AIC)
            .unwrap();
        let ranking = result.model_ranking();

        assert_eq!(ranking[0].1, "ModelB"); assert_eq!(ranking[2].1, "ModelA"); }

    // Weighted average should match the hand-computed dot product
    // of predictions and weights at each position.
    #[test]
    fn test_model_averaged_prediction() {
        let calculator = InformationCriterionCalculator::new();

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let pred3 = array![0.9, 1.9, 2.9];

        let predictions = vec![pred1, pred2, pred3];
        let weights = vec![0.5, 0.3, 0.2];

        let averaged = calculator
            .model_averaged_prediction(&predictions, &weights)
            .unwrap();
        assert_eq!(averaged.len(), 3);

        let expected_0 = 1.0 * 0.5 + 1.1 * 0.3 + 0.9 * 0.2;
        assert!((averaged[0] - expected_0).abs() < 1e-10);
    }
}