use crate::common::CovarianceType;
use crate::variational::{VariationalBayesianGMM, VariationalBayesianGMMTrained};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::essentials::Normal;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Fit, Predict},
    types::Float,
};
use std::f64::consts::PI;

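/// Analyzes how sensitive a variational Bayesian GMM is to its prior
/// hyperparameters (weight concentration, mean precision, and degrees of
/// freedom) by sweeping a grid of priors and refitting under random
/// perturbations of a reference configuration. Configure it via the
/// builder methods, then call `analyze`.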
#[derive(Debug, Clone)]
pub struct PriorSensitivityAnalyzer {
    n_components: usize,
    covariance_type: CovarianceType,
    max_iter: usize,
    random_state: Option<u64>,

    // Grid-search ranges and resolutions for each prior hyperparameter.
    weight_concentration_range: (f64, f64),
    weight_concentration_steps: usize,
    mean_precision_range: (f64, f64),
    mean_precision_steps: usize,
    degrees_of_freedom_range: (f64, f64),
    degrees_of_freedom_steps: usize,

    // Random-perturbation study settings.
    n_random_perturbations: usize,
    perturbation_scale: f64,

    // Reference priors that all perturbations are measured against.
    reference_weight_concentration: f64,
    reference_mean_precision: f64,
    reference_degrees_of_freedom: f64,

    // Flags controlling which (potentially expensive) diagnostics to compute.
    compute_kl_divergence: bool,
    compute_parameter_variance: bool,
    compute_prediction_variance: bool,
    compute_influence_functions: bool,
}

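/// Complete output of a prior sensitivity analysis: the per-configuration
/// fits, the optional divergence/variance/influence diagnostics, and an
/// aggregate summary.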
#[derive(Debug, Clone)]
pub struct SensitivityAnalysisResult {
    grid_results: Vec<GridSearchResult>,
    perturbation_results: Vec<PerturbationResult>,
    reference_model: VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    kl_divergences: Vec<f64>,
    parameter_variances: ParameterVariances,
    prediction_variances: Array1<f64>,
    influence_scores: Vec<InfluenceScore>,
    summary: SensitivitySummary,
}

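/// One fitted configuration from the prior grid search, together with its
/// variational lower bound and effective component count.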
#[derive(Debug, Clone)]
pub struct GridSearchResult {
    weight_concentration: f64,
    mean_precision: f64,
    degrees_of_freedom: f64,
    model: VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    lower_bound: f64,
    effective_components: usize,
}

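/// A model refit under one random perturbation of the reference priors.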
#[derive(Debug, Clone)]
pub struct PerturbationResult {
    perturbation_id: usize,
    perturbed_weight_concentration: f64,
    perturbed_mean_precision: f64,
    perturbed_degrees_of_freedom: f64,
    model: VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    kl_divergence_from_reference: f64,
    parameter_distance_from_reference: f64,
}

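/// Element-wise variances of the fitted weights, means, and covariances
/// across all grid-search models.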
#[derive(Debug, Clone)]
pub struct ParameterVariances {
    weight_variances: Array1<f64>,
    mean_variances: Array2<f64>,
    covariance_variances: Vec<Array2<f64>>,
}

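/// Leave-one-out influence of a single data point on the fitted
/// parameters.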
#[derive(Debug, Clone)]
pub struct InfluenceScore {
    data_point_index: usize,
    weight_influence: Array1<f64>,
    mean_influence: Array2<f64>,
    covariance_influence: Vec<Array2<f64>>,
    total_influence: f64,
}

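/// Aggregate sensitivity statistics and a derived scalar robustness score.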
#[derive(Debug, Clone)]
pub struct SensitivitySummary {
    average_kl_divergence: f64,
    max_kl_divergence: f64,
    min_kl_divergence: f64,
    kl_divergence_std: f64,

    average_parameter_distance: f64,
    max_parameter_distance: f64,
    min_parameter_distance: f64,
    parameter_distance_std: f64,

    average_prediction_variance: f64,
    max_prediction_variance: f64,
    min_prediction_variance: f64,

    most_sensitive_parameters: Vec<String>,
    robustness_score: f64,
}

impl PriorSensitivityAnalyzer {
    /// Creates an analyzer with conservative defaults: a 2-component
    /// diagonal-covariance model, 5-step grids, 20 random perturbations,
    /// and unit reference priors.
    pub fn new() -> Self {
        Self {
            n_components: 2,
            covariance_type: CovarianceType::Diagonal,
            max_iter: 100,
            random_state: None,

            weight_concentration_range: (0.1, 5.0),
            weight_concentration_steps: 5,
            mean_precision_range: (0.1, 10.0),
            mean_precision_steps: 5,
            degrees_of_freedom_range: (1.0, 10.0),
            degrees_of_freedom_steps: 5,

            n_random_perturbations: 20,
            perturbation_scale: 0.2,

            reference_weight_concentration: 1.0,
            reference_mean_precision: 1.0,
            reference_degrees_of_freedom: 1.0,

            compute_kl_divergence: true,
            compute_parameter_variance: true,
            compute_prediction_variance: true,
            // Influence functions require a leave-one-out refit per data
            // point, so they are opt-in.
            compute_influence_functions: false,
        }
    }

    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn weight_concentration_range(mut self, range: (f64, f64), steps: usize) -> Self {
        self.weight_concentration_range = range;
        self.weight_concentration_steps = steps;
        self
    }

    pub fn mean_precision_range(mut self, range: (f64, f64), steps: usize) -> Self {
        self.mean_precision_range = range;
        self.mean_precision_steps = steps;
        self
    }

    pub fn degrees_of_freedom_range(mut self, range: (f64, f64), steps: usize) -> Self {
        self.degrees_of_freedom_range = range;
        self.degrees_of_freedom_steps = steps;
        self
    }

    pub fn n_random_perturbations(mut self, n: usize) -> Self {
        self.n_random_perturbations = n;
        self
    }

    pub fn perturbation_scale(mut self, scale: f64) -> Self {
        self.perturbation_scale = scale;
        self
    }

    pub fn reference_priors(
        mut self,
        weight_concentration: f64,
        mean_precision: f64,
        degrees_of_freedom: f64,
    ) -> Self {
        self.reference_weight_concentration = weight_concentration;
        self.reference_mean_precision = mean_precision;
        self.reference_degrees_of_freedom = degrees_of_freedom;
        self
    }

    pub fn compute_kl_divergence(mut self, compute: bool) -> Self {
        self.compute_kl_divergence = compute;
        self
    }

    pub fn compute_parameter_variance(mut self, compute: bool) -> Self {
        self.compute_parameter_variance = compute;
        self
    }

    pub fn compute_prediction_variance(mut self, compute: bool) -> Self {
        self.compute_prediction_variance = compute;
        self
    }

    pub fn compute_influence_functions(mut self, compute: bool) -> Self {
        self.compute_influence_functions = compute;
        self
    }

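    /// Runs the full sensitivity analysis on `X` (shape
    /// `(n_samples, n_features)`, `n_samples >= 2`): fits the reference
    /// model, sweeps the prior grid, applies random perturbations, and
    /// computes whichever diagnostics are enabled.
    ///
    /// A minimal sketch of the intended usage (illustrative only; the
    /// import paths and data setup depend on your crate layout, so the
    /// example is not compiled):
    ///
    /// ```ignore
    /// let X = scirs2_core::ndarray::array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];
    /// let analyzer = PriorSensitivityAnalyzer::new()
    ///     .n_components(2)
    ///     .max_iter(10)
    ///     .random_state(42);
    /// let result = analyzer.analyze(&X.view())?;
    /// println!("robustness = {:.3}", result.robustness_score());
    /// ```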
    #[allow(non_snake_case)]
    pub fn analyze(&self, X: &ArrayView2<'_, Float>) -> SklResult<SensitivityAnalysisResult> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples < 2 {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be at least 2".to_string(),
            ));
        }

        // Fit the baseline model under the reference priors.
        let reference_model = self.fit_reference_model(&X)?;

        // Sweep the prior grid and the random perturbations.
        let grid_results = self.grid_search_analysis(&X)?;
        let perturbation_results = self.random_perturbation_analysis(&X)?;

        // Optional diagnostics, each gated by its corresponding flag.
        let kl_divergences = if self.compute_kl_divergence {
            self.compute_kl_divergences(&reference_model, &grid_results, &perturbation_results)?
        } else {
            Vec::new()
        };

        let parameter_variances = if self.compute_parameter_variance {
            self.compute_parameter_variances(&grid_results)?
        } else {
            ParameterVariances {
                weight_variances: Array1::zeros(self.n_components),
                mean_variances: Array2::zeros((self.n_components, n_features)),
                covariance_variances: vec![
                    Array2::zeros((n_features, n_features));
                    self.n_components
                ],
            }
        };

        let prediction_variances = if self.compute_prediction_variance {
            self.compute_prediction_variances(&X, &grid_results)?
        } else {
            Array1::zeros(n_samples)
        };

        let influence_scores = if self.compute_influence_functions {
            self.compute_influence_scores(&X, &reference_model)?
        } else {
            Vec::new()
        };

        let summary = self.compute_summary_statistics(
            &kl_divergences,
            &perturbation_results,
            &prediction_variances,
        )?;

        Ok(SensitivityAnalysisResult {
            grid_results,
            perturbation_results,
            reference_model,
            kl_divergences,
            parameter_variances,
            prediction_variances,
            influence_scores,
            summary,
        })
    }

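    /// Fits the baseline model under the reference priors; all sensitivity
    /// measures are taken relative to this fit.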
    fn fit_reference_model(
        &self,
        X: &Array2<f64>,
    ) -> SklResult<VariationalBayesianGMM<VariationalBayesianGMMTrained>> {
        let model = VariationalBayesianGMM::new()
            .n_components(self.n_components)
            .covariance_type(self.covariance_type.clone())
            .max_iter(self.max_iter)
            .weight_concentration_prior(self.reference_weight_concentration)
            .mean_precision_prior(self.reference_mean_precision)
            .degrees_of_freedom_prior(self.reference_degrees_of_freedom)
            .random_state(self.random_state.unwrap_or(42));

        model.fit(&X.view(), &())
    }

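    /// Fits one model per point of the Cartesian grid over the configured
    /// weight-concentration, mean-precision, and degrees-of-freedom ranges.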
    fn grid_search_analysis(&self, X: &Array2<f64>) -> SklResult<Vec<GridSearchResult>> {
        let mut results = Vec::new();

        let weight_concentrations = self.linspace(
            self.weight_concentration_range.0,
            self.weight_concentration_range.1,
            self.weight_concentration_steps,
        );
        let mean_precisions = self.linspace(
            self.mean_precision_range.0,
            self.mean_precision_range.1,
            self.mean_precision_steps,
        );
        let degrees_of_freedom = self.linspace(
            self.degrees_of_freedom_range.0,
            self.degrees_of_freedom_range.1,
            self.degrees_of_freedom_steps,
        );

        for &weight_conc in &weight_concentrations {
            for &mean_prec in &mean_precisions {
                for &dof in &degrees_of_freedom {
                    let model = VariationalBayesianGMM::new()
                        .n_components(self.n_components)
                        .covariance_type(self.covariance_type.clone())
                        .max_iter(self.max_iter)
                        .weight_concentration_prior(weight_conc)
                        .mean_precision_prior(mean_prec)
                        .degrees_of_freedom_prior(dof)
                        .random_state(self.random_state.unwrap_or(42));

                    match model.fit(&X.view(), &()) {
                        Ok(fitted_model) => {
                            results.push(GridSearchResult {
                                weight_concentration: weight_conc,
                                mean_precision: mean_prec,
                                degrees_of_freedom: dof,
                                lower_bound: fitted_model.lower_bound(),
                                effective_components: fitted_model.effective_components(),
                                model: fitted_model,
                            });
                        }
                        Err(_) => {
                            // Skip prior configurations that fail to fit.
                            continue;
                        }
                    }
                }
            }
        }

        Ok(results)
    }

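    /// Refits the model under `n_random_perturbations` multiplicative
    /// perturbations of the reference priors, each factor drawn uniformly
    /// from `1.0 +/- perturbation_scale / 2`.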
    fn random_perturbation_analysis(&self, X: &Array2<f64>) -> SklResult<Vec<PerturbationResult>> {
        let mut rng = if let Some(seed) = self.random_state {
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        } else {
            scirs2_core::random::rngs::StdRng::from_rng(&mut thread_rng())
        };

        let mut results = Vec::new();

        for perturbation_id in 0..self.n_random_perturbations {
            // Multiplicative factors in [1 - scale/2, 1 + scale/2].
            let weight_conc_perturbation = 1.0 + (rng.gen::<f64>() - 0.5) * self.perturbation_scale;
            let mean_prec_perturbation = 1.0 + (rng.gen::<f64>() - 0.5) * self.perturbation_scale;
            let dof_perturbation = 1.0 + (rng.gen::<f64>() - 0.5) * self.perturbation_scale;

            // Clamp the perturbed priors away from zero to keep them valid.
            let perturbed_weight_concentration =
                (self.reference_weight_concentration * weight_conc_perturbation).max(0.01);
            let perturbed_mean_precision =
                (self.reference_mean_precision * mean_prec_perturbation).max(0.01);
            let perturbed_degrees_of_freedom =
                (self.reference_degrees_of_freedom * dof_perturbation).max(0.1);

            let model = VariationalBayesianGMM::new()
                .n_components(self.n_components)
                .covariance_type(self.covariance_type.clone())
                .max_iter(self.max_iter)
                .weight_concentration_prior(perturbed_weight_concentration)
                .mean_precision_prior(perturbed_mean_precision)
                .degrees_of_freedom_prior(perturbed_degrees_of_freedom)
                .random_state(self.random_state.unwrap_or(42 + perturbation_id as u64));

            match model.fit(&X.view(), &()) {
                Ok(fitted_model) => {
                    results.push(PerturbationResult {
                        perturbation_id,
                        perturbed_weight_concentration,
                        perturbed_mean_precision,
                        perturbed_degrees_of_freedom,
                        model: fitted_model,
                        // Placeholders; these distances are not currently
                        // populated by the analysis.
                        kl_divergence_from_reference: 0.0,
                        parameter_distance_from_reference: 0.0,
                    });
                }
                Err(_) => {
                    // Skip perturbations that fail to fit.
                    continue;
                }
            }
        }

        Ok(results)
    }

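    /// Collects Monte Carlo KL-divergence estimates from the reference
    /// model to every grid-search and perturbation model, in that order.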
    fn compute_kl_divergences(
        &self,
        reference_model: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
        grid_results: &[GridSearchResult],
        perturbation_results: &[PerturbationResult],
    ) -> SklResult<Vec<f64>> {
        let mut kl_divergences = Vec::new();

        for result in grid_results {
            let kl_div =
                self.compute_kl_divergence_between_models(reference_model, &result.model)?;
            kl_divergences.push(kl_div);
        }

        for result in perturbation_results {
            let kl_div =
                self.compute_kl_divergence_between_models(reference_model, &result.model)?;
            kl_divergences.push(kl_div);
        }

        Ok(kl_divergences)
    }

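    /// Estimates `KL(model1 || model2)` by Monte Carlo: draws samples from
    /// `model1` and averages `log p1(x) - log p2(x)` over the draws, i.e.
    /// `KL(p1 || p2) ~ (1/N) * sum_i [log p1(x_i) - log p2(x_i)]` with
    /// `x_i ~ p1`.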
    fn compute_kl_divergence_between_models(
        &self,
        model1: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
        model2: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    ) -> SklResult<f64> {
        let n_samples = 1000;
        // Fixed seed so repeated estimates are comparable across models.
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);

        let mut kl_sum = 0.0;
        let n_features = model1.means().ncols();

        for _ in 0..n_samples {
            // Pick a component uniformly at random (a simplification; exact
            // sampling would draw proportionally to the mixture weights).
            let component = (rng.gen::<f64>() * model1.weights().len() as f64) as usize;
            let component = component.min(model1.weights().len() - 1);

            let mean = model1.means().row(component);
            let cov = &model1.covariances()[component];

            // Sample each feature independently from the component's
            // diagonal-Gaussian approximation.
            let mut sample = Array1::zeros(n_features);
            for d in 0..n_features {
                let std_dev = cov[[d, d]].sqrt();
                let normal = Normal::new(mean[d], std_dev).unwrap();
                sample[d] = rng.sample(normal);
            }

            let log_p1 = self.log_probability_under_model(model1, &sample)?;
            let log_p2 = self.log_probability_under_model(model2, &sample)?;

            if log_p1.is_finite() && log_p2.is_finite() {
                kl_sum += log_p1 - log_p2;
            }
        }

        Ok(kl_sum / n_samples as f64)
    }

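    /// Simplified log-density of `sample` under the mixture; see the
    /// isotropic approximation noted in the body.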
    fn log_probability_under_model(
        &self,
        model: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
        sample: &Array1<f64>,
    ) -> SklResult<f64> {
        let mut total_prob = 0.0;

        for k in 0..model.weights().len() {
            let weight = model.weights()[k];
            let mean = model.means().row(k);
            let cov = &model.covariances()[k];

            // Isotropic approximation: use cov[[0, 0]] as a shared variance
            // for every feature, which is cheap but inexact for
            // non-spherical components.
            let diff = sample - &mean.to_owned();
            let mahalanobis_dist = diff.dot(&diff) / cov[[0, 0]];
            let component_prob =
                weight * (-0.5 * mahalanobis_dist).exp() / (2.0 * PI * cov[[0, 0]]).sqrt();

            total_prob += component_prob;
        }

        Ok(total_prob.ln())
    }

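    /// Computes the element-wise variance of the fitted weights, means, and
    /// covariances across all grid-search models.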
    fn compute_parameter_variances(
        &self,
        grid_results: &[GridSearchResult],
    ) -> SklResult<ParameterVariances> {
        if grid_results.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No grid results available".to_string(),
            ));
        }

        let n_features = grid_results[0].model.means().ncols();

        let mut all_weights = Vec::new();
        let mut all_means = Vec::new();
        let mut all_covariances = Vec::new();

        for result in grid_results {
            all_weights.push(result.model.weights().clone());
            all_means.push(result.model.means().clone());
            all_covariances.push(result.model.covariances().to_vec());
        }

        let weight_variances = self.compute_array1_variance(&all_weights);
        let mean_variances = self.compute_array2_variance(&all_means);
        let covariance_variances = self.compute_covariance_variance(&all_covariances, n_features);

        Ok(ParameterVariances {
            weight_variances,
            mean_variances,
            covariance_variances,
        })
    }

    fn compute_array1_variance(&self, arrays: &[Array1<f64>]) -> Array1<f64> {
        if arrays.is_empty() {
            return Array1::zeros(0);
        }

        let array_len = arrays[0].len();
        let mut variances = Array1::zeros(array_len);

        for i in 0..array_len {
            let values: Vec<f64> = arrays.iter().map(|arr| arr[i]).collect();
            let mean = values.iter().sum::<f64>() / values.len() as f64;
            let variance =
                values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
            variances[i] = variance;
        }

        variances
    }

    fn compute_array2_variance(&self, arrays: &[Array2<f64>]) -> Array2<f64> {
        if arrays.is_empty() {
            return Array2::zeros((0, 0));
        }

        let (n_rows, n_cols) = arrays[0].dim();
        let mut variances = Array2::zeros((n_rows, n_cols));

        for i in 0..n_rows {
            for j in 0..n_cols {
                let values: Vec<f64> = arrays.iter().map(|arr| arr[[i, j]]).collect();
                let mean = values.iter().sum::<f64>() / values.len() as f64;
                let variance =
                    values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
                variances[[i, j]] = variance;
            }
        }

        variances
    }

    fn compute_covariance_variance(
        &self,
        all_covariances: &[Vec<Array2<f64>>],
        n_features: usize,
    ) -> Vec<Array2<f64>> {
        if all_covariances.is_empty() {
            return vec![Array2::zeros((n_features, n_features)); self.n_components];
        }

        let mut variances = vec![Array2::zeros((n_features, n_features)); self.n_components];

        for k in 0..self.n_components {
            for i in 0..n_features {
                for j in 0..n_features {
                    let values: Vec<f64> = all_covariances
                        .iter()
                        .filter(|cov| cov.len() > k)
                        .map(|cov| cov[k][[i, j]])
                        .collect();

                    if !values.is_empty() {
                        let mean = values.iter().sum::<f64>() / values.len() as f64;
                        let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
                            / values.len() as f64;
                        variances[k][[i, j]] = variance;
                    }
                }
            }
        }

        variances
    }

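    /// For each sample, the variance of the predicted component label
    /// across all grid-search models; high values flag points whose
    /// assignment depends on the priors.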
    fn compute_prediction_variances(
        &self,
        X: &Array2<f64>,
        grid_results: &[GridSearchResult],
    ) -> SklResult<Array1<f64>> {
        let n_samples = X.nrows();
        let mut prediction_variances = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let x_i = X.row(i);
            let mut predictions = Vec::new();

            for result in grid_results {
                match result
                    .model
                    .predict(&x_i.to_owned().insert_axis(Axis(0)).view())
                {
                    Ok(pred) => {
                        if !pred.is_empty() {
                            predictions.push(pred[0] as f64);
                        }
                    }
                    Err(_) => continue,
                }
            }

            if !predictions.is_empty() {
                let mean_pred = predictions.iter().sum::<f64>() / predictions.len() as f64;
                let variance = predictions
                    .iter()
                    .map(|&pred| (pred - mean_pred).powi(2))
                    .sum::<f64>()
                    / predictions.len() as f64;
                prediction_variances[i] = variance;
            }
        }

        Ok(prediction_variances)
    }

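    /// Approximates per-point influence by leave-one-out refitting and
    /// recording the resulting change in weights, means, and covariances.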
    fn compute_influence_scores(
        &self,
        X: &Array2<f64>,
        reference_model: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    ) -> SklResult<Vec<InfluenceScore>> {
        let (n_samples, n_features) = X.dim();
        let mut influence_scores = Vec::new();

        // Leave-one-out refits are expensive, so only the first 10 samples
        // are scored.
        for i in 0..n_samples.min(10) {
            // Build the data matrix with sample i removed.
            let mut X_loo = Array2::zeros((n_samples - 1, n_features));
            let mut row_idx = 0;
            for j in 0..n_samples {
                if j != i {
                    X_loo.row_mut(row_idx).assign(&X.row(j));
                    row_idx += 1;
                }
            }

            let loo_model = VariationalBayesianGMM::new()
                .n_components(self.n_components)
                .covariance_type(self.covariance_type.clone())
                .max_iter(self.max_iter)
                .weight_concentration_prior(self.reference_weight_concentration)
                .mean_precision_prior(self.reference_mean_precision)
                .degrees_of_freedom_prior(self.reference_degrees_of_freedom)
                .random_state(self.random_state.unwrap_or(42));

            match loo_model.fit(&X_loo.view(), &()) {
                Ok(fitted_loo_model) => {
                    let weight_influence = reference_model.weights() - fitted_loo_model.weights();
                    let mean_influence = reference_model.means() - fitted_loo_model.means();

                    let mut covariance_influence = Vec::new();
                    for k in 0..self.n_components {
                        let cov_diff =
                            &reference_model.covariances()[k] - &fitted_loo_model.covariances()[k];
                        covariance_influence.push(cov_diff);
                    }

                    let total_influence = weight_influence.iter().map(|x| x.abs()).sum::<f64>()
                        + mean_influence.iter().map(|x| x.abs()).sum::<f64>();

                    influence_scores.push(InfluenceScore {
                        data_point_index: i,
                        weight_influence,
                        mean_influence,
                        covariance_influence,
                        total_influence,
                    });
                }
                Err(_) => continue,
            }
        }

        Ok(influence_scores)
    }

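    /// Aggregates the KL divergences, parameter distances, and prediction
    /// variances into summary statistics and a single robustness score.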
    fn compute_summary_statistics(
        &self,
        kl_divergences: &[f64],
        perturbation_results: &[PerturbationResult],
        prediction_variances: &Array1<f64>,
    ) -> SklResult<SensitivitySummary> {
        let (avg_kl, max_kl, min_kl, kl_std) = if !kl_divergences.is_empty() {
            let avg = kl_divergences.iter().sum::<f64>() / kl_divergences.len() as f64;
            let max_val = kl_divergences
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let min_val = kl_divergences.iter().fold(f64::INFINITY, |a, &b| a.min(b));
            let variance = kl_divergences
                .iter()
                .map(|&x| (x - avg).powi(2))
                .sum::<f64>()
                / kl_divergences.len() as f64;
            let std_dev = variance.sqrt();
            (avg, max_val, min_val, std_dev)
        } else {
            (0.0, 0.0, 0.0, 0.0)
        };

        let parameter_distances: Vec<f64> = perturbation_results
            .iter()
            .map(|result| result.parameter_distance_from_reference)
            .collect();

        let (avg_param_dist, max_param_dist, min_param_dist, param_dist_std) =
            if !parameter_distances.is_empty() {
                let avg =
                    parameter_distances.iter().sum::<f64>() / parameter_distances.len() as f64;
                let max_val = parameter_distances
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let min_val = parameter_distances
                    .iter()
                    .fold(f64::INFINITY, |a, &b| a.min(b));
                let variance = parameter_distances
                    .iter()
                    .map(|&x| (x - avg).powi(2))
                    .sum::<f64>()
                    / parameter_distances.len() as f64;
                let std_dev = variance.sqrt();
                (avg, max_val, min_val, std_dev)
            } else {
                (0.0, 0.0, 0.0, 0.0)
            };

        let (avg_pred_var, max_pred_var, min_pred_var) = if !prediction_variances.is_empty() {
            let avg = prediction_variances.mean().unwrap_or(0.0);
            let max_val = prediction_variances.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let min_val = prediction_variances.fold(f64::INFINITY, |a, &b| a.min(b));
            (avg, max_val, min_val)
        } else {
            (0.0, 0.0, 0.0)
        };

        // Heuristic flags: each diagnostic exceeding a fixed 0.1 threshold
        // marks one hyperparameter as potentially sensitive.
        let mut most_sensitive_parameters = Vec::new();
        if kl_std > 0.1 {
            most_sensitive_parameters.push("weight_concentration".to_string());
        }
        if param_dist_std > 0.1 {
            most_sensitive_parameters.push("mean_precision".to_string());
        }
        if avg_pred_var > 0.1 {
            most_sensitive_parameters.push("degrees_of_freedom".to_string());
        }

        // Collapse the average sensitivities into a single score; higher
        // means more robust.
        let robustness_score = 1.0 / (1.0 + avg_kl + avg_param_dist + avg_pred_var);

        Ok(SensitivitySummary {
            average_kl_divergence: avg_kl,
            max_kl_divergence: max_kl,
            min_kl_divergence: min_kl,
            kl_divergence_std: kl_std,

            average_parameter_distance: avg_param_dist,
            max_parameter_distance: max_param_dist,
            min_parameter_distance: min_param_dist,
            parameter_distance_std: param_dist_std,

            average_prediction_variance: avg_pred_var,
            max_prediction_variance: max_pred_var,
            min_prediction_variance: min_pred_var,

            most_sensitive_parameters,
            robustness_score,
        })
    }

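    /// Returns `steps` evenly spaced values from `start` to `end`
    /// inclusive (just `start` when `steps <= 1`).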
    fn linspace(&self, start: f64, end: f64, steps: usize) -> Vec<f64> {
        if steps <= 1 {
            return vec![start];
        }

        let step_size = (end - start) / (steps - 1) as f64;
        (0..steps).map(|i| start + i as f64 * step_size).collect()
    }
}

impl Default for PriorSensitivityAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

impl SensitivityAnalysisResult {
    pub fn average_kl_divergence(&self) -> f64 {
        self.summary.average_kl_divergence
    }

    pub fn max_kl_divergence(&self) -> f64 {
        self.summary.max_kl_divergence
    }

    pub fn robustness_score(&self) -> f64 {
        self.summary.robustness_score
    }

    pub fn most_sensitive_parameters(&self) -> &Vec<String> {
        &self.summary.most_sensitive_parameters
    }

    pub fn grid_results(&self) -> &[GridSearchResult] {
        &self.grid_results
    }

    pub fn perturbation_results(&self) -> &[PerturbationResult] {
        &self.perturbation_results
    }

    pub fn reference_model(&self) -> &VariationalBayesianGMM<VariationalBayesianGMMTrained> {
        &self.reference_model
    }

    pub fn parameter_variances(&self) -> &ParameterVariances {
        &self.parameter_variances
    }

    pub fn prediction_variances(&self) -> &Array1<f64> {
        &self.prediction_variances
    }

    pub fn influence_scores(&self) -> &[InfluenceScore] {
        &self.influence_scores
    }

    pub fn summary(&self) -> &SensitivitySummary {
        &self.summary
    }

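    /// The grid configuration with the highest variational lower bound
    /// (higher indicates a better fit).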
    pub fn most_robust_configuration(&self) -> Option<&GridSearchResult> {
        // A higher variational lower bound indicates a better fit, so the
        // most robust configuration is the one that maximizes it (the
        // previous min/max pairing was inverted).
        self.grid_results.iter().max_by(|a, b| {
            a.lower_bound
                .partial_cmp(&b.lower_bound)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    pub fn least_robust_configuration(&self) -> Option<&GridSearchResult> {
        self.grid_results.iter().min_by(|a, b| {
            a.lower_bound
                .partial_cmp(&b.lower_bound)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

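    /// Human-readable guidance derived from the summary statistics, e.g.
    /// flagging high prior sensitivity or listing the most sensitive
    /// hyperparameters.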
    pub fn prior_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();

        if self.summary.robustness_score > 0.8 {
            recommendations.push("Model appears robust to prior choice".to_string());
        } else if self.summary.robustness_score < 0.3 {
            recommendations.push(
                "Model is highly sensitive to prior choice - consider more informative priors"
                    .to_string(),
            );
        }

        if self.summary.average_kl_divergence > 1.0 {
            recommendations.push(
                "High variation in model predictions - consider reducing prior parameter ranges"
                    .to_string(),
            );
        }

        if !self.summary.most_sensitive_parameters.is_empty() {
            recommendations.push(format!(
                "Most sensitive parameters: {}",
                self.summary.most_sensitive_parameters.join(", ")
            ));
        }

        if recommendations.is_empty() {
            recommendations.push(
                "Model shows moderate sensitivity to priors - current configuration appears reasonable"
                    .to_string(),
            );
        }

        recommendations
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_prior_sensitivity_analyzer_creation() {
        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(3)
            .weight_concentration_range((0.1, 5.0), 3)
            .mean_precision_range((0.1, 10.0), 3)
            .n_random_perturbations(5);

        assert_eq!(analyzer.n_components, 3);
        assert_eq!(analyzer.weight_concentration_steps, 3);
        assert_eq!(analyzer.mean_precision_steps, 3);
        assert_eq!(analyzer.n_random_perturbations, 5);
    }

    #[test]
    fn test_prior_sensitivity_analyzer_linspace() {
        let analyzer = PriorSensitivityAnalyzer::new();
        let values = analyzer.linspace(0.0, 1.0, 5);

        assert_eq!(values.len(), 5);
        assert_abs_diff_eq!(values[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[4], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[2], 0.5, epsilon = 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_prior_sensitivity_analysis_simple() {
        let X = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2]
        ];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 3)
            .mean_precision_range((0.5, 2.0), 3)
            .degrees_of_freedom_range((1.0, 3.0), 3)
            .n_random_perturbations(3)
            .max_iter(5)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        assert!(!result.grid_results().is_empty());
        assert!(!result.perturbation_results().is_empty());
        assert!(result.robustness_score() >= 0.0);
        assert!(result.robustness_score() <= 1.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_prior_sensitivity_analysis_properties() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 2)
            .mean_precision_range((0.5, 2.0), 2)
            .n_random_perturbations(2)
            .max_iter(3)
            .compute_kl_divergence(true)
            .compute_parameter_variance(true)
            .compute_prediction_variance(true)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        assert!(result.average_kl_divergence().is_finite());
        assert!(result.summary().average_parameter_distance.is_finite());
        assert!(result.summary().average_prediction_variance.is_finite());
        assert!(!result.parameter_variances().weight_variances.is_empty());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_prior_sensitivity_analysis_recommendations() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 2)
            .mean_precision_range((0.5, 2.0), 2)
            .n_random_perturbations(2)
            .max_iter(3)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();
        let recommendations = result.prior_recommendations();

        assert!(!recommendations.is_empty());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_prior_sensitivity_analysis_configurations() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 3)
            .mean_precision_range((0.5, 2.0), 3)
            .n_random_perturbations(2)
            .max_iter(3)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        assert!(!result.grid_results().is_empty());

        let most_robust = result.most_robust_configuration();
        let least_robust = result.least_robust_configuration();

        assert!(most_robust.is_some());
        assert!(least_robust.is_some());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_prior_sensitivity_analysis_disabled_features() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 2)
            .mean_precision_range((0.5, 2.0), 2)
            .n_random_perturbations(2)
            .max_iter(3)
            .compute_kl_divergence(false)
            .compute_parameter_variance(false)
            .compute_prediction_variance(false)
            .compute_influence_functions(false)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        assert!(result.kl_divergences.is_empty());
        assert!(result.prediction_variances().iter().all(|&x| x == 0.0));
        assert!(result.influence_scores().is_empty());
    }
}