sklears_mixture/
prior_sensitivity.rs

//! Prior Sensitivity Analysis for Mixture Models
//!
//! This module provides tools for analyzing the sensitivity of mixture model results
//! to the choice of prior parameters. It includes methods for computing sensitivity
//! measures, visualizing prior effects, and robustness testing.
6
7use crate::common::CovarianceType;
8use crate::variational::{VariationalBayesianGMM, VariationalBayesianGMMTrained};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
10use scirs2_core::random::essentials::Normal;
11use scirs2_core::random::{thread_rng, Rng, SeedableRng};
12use sklears_core::{
13    error::{Result as SklResult, SklearsError},
14    traits::{Fit, Predict},
15    types::Float,
16};
17use std::f64::consts::PI;
18
/// Prior Sensitivity Analyzer for Mixture Models
///
/// This tool analyzes how sensitive mixture model results are to the choice
/// of prior parameters by fitting models with different prior settings and
/// comparing the results. Configure it with the builder-style setters and call
/// `analyze` on a data matrix.
///
/// Key features:
/// - Grid search over prior parameter ranges
/// - Sensitivity measures (KL divergence, parameter variance, etc.)
/// - Robustness testing with random prior perturbations
/// - Prior predictive checks
/// - Influence function analysis
///
/// # Examples
///
/// ```
/// use sklears_mixture::{PriorSensitivityAnalyzer, CovarianceType};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0], [11.0, 11.0], [12.0, 12.0]];
///
/// let analyzer = PriorSensitivityAnalyzer::new()
///     .n_components(3)
///     .weight_concentration_range((0.1, 5.0), 5)
///     .mean_precision_range((0.1, 10.0), 5)
///     .n_random_perturbations(20);
///
/// let analysis = analyzer.analyze(&X.view()).unwrap();
/// println!("Average KL divergence: {}", analysis.average_kl_divergence());
/// ```
#[derive(Debug, Clone)]
pub struct PriorSensitivityAnalyzer {
    /// Number of mixture components for every fitted model (default 2).
    n_components: usize,
    /// Covariance structure for every fitted model (default `CovarianceType::Diagonal`).
    covariance_type: CovarianceType,
    /// Maximum number of iterations for each fit (default 100).
    max_iter: usize,
    /// Seed for model fits and for the perturbation RNG. When `None`, model fits
    /// fall back to seed 42 (42 + perturbation index for perturbed fits) and the
    /// perturbation RNG is seeded from the thread RNG.
    random_state: Option<u64>,

    // Prior parameter ranges for grid search: each `(start, end)` pair and step
    // count is handed to `linspace` to build that axis of the grid.
    /// Weight concentration grid interval (default `(0.1, 5.0)`).
    weight_concentration_range: (f64, f64),
    /// Number of weight concentration grid points (default 5).
    weight_concentration_steps: usize,
    /// Mean precision grid interval (default `(0.1, 10.0)`).
    mean_precision_range: (f64, f64),
    /// Number of mean precision grid points (default 5).
    mean_precision_steps: usize,
    /// Degrees-of-freedom grid interval (default `(1.0, 10.0)`).
    degrees_of_freedom_range: (f64, f64),
    /// Number of degrees-of-freedom grid points (default 5).
    degrees_of_freedom_steps: usize,

    // Random perturbation analysis
    /// Number of randomly perturbed prior settings to fit (default 20).
    n_random_perturbations: usize,
    /// Width of the multiplicative jitter: each reference prior is multiplied by
    /// `1 + (u - 0.5) * scale` with `u` drawn uniformly (default 0.2).
    perturbation_scale: f64,

    // Reference model configuration (baseline)
    /// Baseline weight concentration prior; also the centre of perturbations (default 1.0).
    reference_weight_concentration: f64,
    /// Baseline mean precision prior; also the centre of perturbations (default 1.0).
    reference_mean_precision: f64,
    /// Baseline degrees-of-freedom prior; also the centre of perturbations (default 1.0).
    reference_degrees_of_freedom: f64,

    // Analysis options
    /// Whether to estimate KL divergences from the reference model (default on).
    compute_kl_divergence: bool,
    /// Whether to compute parameter variances across grid models (default on).
    compute_parameter_variance: bool,
    /// Whether to compute per-sample prediction variances (default on).
    compute_prediction_variance: bool,
    /// Whether to compute leave-one-out influence scores (default off: each score
    /// requires a full refit).
    compute_influence_functions: bool,
}
79
/// Results of prior sensitivity analysis.
///
/// Produced by `PriorSensitivityAnalyzer::analyze`. Holds every fitted model
/// together with the derived sensitivity measures.
#[derive(Debug, Clone)]
pub struct SensitivityAnalysisResult {
    /// Grid search results: successful fits over the prior grid (failed grid
    /// points are omitted).
    grid_results: Vec<GridSearchResult>,

    /// Random perturbation results: successful fits with randomly perturbed
    /// reference priors.
    perturbation_results: Vec<PerturbationResult>,

    /// Reference (baseline) model fitted with the reference priors.
    reference_model: VariationalBayesianGMM<VariationalBayesianGMMTrained>,

    /// Estimated KL divergence from the reference model to each grid model,
    /// followed by one per perturbation model. Empty when KL computation is
    /// disabled.
    kl_divergences: Vec<f64>,
    /// Variances of weights, means and covariances across the grid models. All
    /// zeros when parameter-variance computation is disabled.
    parameter_variances: ParameterVariances,
    /// Per-sample variance of the predicted label across grid models (one entry
    /// per input row). All zeros when prediction-variance computation is disabled.
    prediction_variances: Array1<f64>,

    /// Leave-one-out influence scores. Empty unless influence computation is enabled.
    influence_scores: Vec<InfluenceScore>,

    /// Aggregate statistics over the measures above.
    summary: SensitivitySummary,
}
103
/// Result from fitting one point of the prior-parameter grid.
#[derive(Debug, Clone)]
pub struct GridSearchResult {
    /// Weight concentration prior used for this fit.
    weight_concentration: f64,
    /// Mean precision prior used for this fit.
    mean_precision: f64,
    /// Degrees-of-freedom prior used for this fit.
    degrees_of_freedom: f64,
    /// The fitted model.
    model: VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    /// `model.lower_bound()`, captured when the result was recorded.
    lower_bound: f64,
    /// `model.effective_components()`, captured when the result was recorded.
    effective_components: usize,
}
114
/// Result from fitting one randomly perturbed prior setting.
#[derive(Debug, Clone)]
pub struct PerturbationResult {
    /// 0-based index of this perturbation. Failed fits are dropped, so the
    /// indices of retained results may have gaps.
    perturbation_id: usize,
    /// Reference weight concentration times a random factor, clamped below at 0.01.
    perturbed_weight_concentration: f64,
    /// Reference mean precision times a random factor, clamped below at 0.01.
    perturbed_mean_precision: f64,
    /// Reference degrees of freedom times a random factor, clamped below at 0.1.
    perturbed_degrees_of_freedom: f64,
    /// The model fitted with the perturbed priors.
    model: VariationalBayesianGMM<VariationalBayesianGMMTrained>,
    /// KL divergence from the reference model.
    /// NOTE(review): `random_perturbation_analysis` initialises this to 0.0;
    /// confirm it is populated before being read.
    kl_divergence_from_reference: f64,
    /// Distance of the fitted parameters from the reference model.
    /// NOTE(review): `random_perturbation_analysis` initialises this to 0.0;
    /// confirm it is populated before being read.
    parameter_distance_from_reference: f64,
}
126
/// Element-wise population variances (divisor = number of models) of the fitted
/// parameters across the grid-search models.
#[derive(Debug, Clone)]
pub struct ParameterVariances {
    /// Variance of each mixture weight (one entry per component).
    weight_variances: Array1<f64>,
    /// Variance of each mean coordinate (components x features).
    mean_variances: Array2<f64>,
    /// For each component, the variance of every covariance entry (features x features).
    covariance_variances: Vec<Array2<f64>>,
}
134
/// Leave-one-out influence of a single data point on the reference model.
#[derive(Debug, Clone)]
pub struct InfluenceScore {
    /// Row index of the data point that was left out.
    data_point_index: usize,
    /// Reference weights minus the leave-one-out weights.
    weight_influence: Array1<f64>,
    /// Reference means minus the leave-one-out means.
    mean_influence: Array2<f64>,
    /// Per component, the reference covariance minus the leave-one-out covariance.
    covariance_influence: Vec<Array2<f64>>,
    /// Sum of absolute weight and mean differences. Covariance differences are
    /// not included.
    total_influence: f64,
}
144
/// Summary statistics for a sensitivity analysis.
///
/// Every statistic is 0.0 when its underlying collection is empty.
#[derive(Debug, Clone)]
pub struct SensitivitySummary {
    /// Mean of all computed KL divergences.
    average_kl_divergence: f64,
    /// Largest computed KL divergence.
    max_kl_divergence: f64,
    /// Smallest computed KL divergence.
    min_kl_divergence: f64,
    /// Population standard deviation of the KL divergences.
    kl_divergence_std: f64,

    /// Mean of the perturbation results' `parameter_distance_from_reference`.
    average_parameter_distance: f64,
    /// Largest perturbation parameter distance.
    max_parameter_distance: f64,
    /// Smallest perturbation parameter distance.
    min_parameter_distance: f64,
    /// Population standard deviation of the perturbation parameter distances.
    parameter_distance_std: f64,

    /// Mean of the per-sample prediction variances.
    average_prediction_variance: f64,
    /// Largest per-sample prediction variance.
    max_prediction_variance: f64,
    /// Smallest per-sample prediction variance.
    min_prediction_variance: f64,

    /// Names of the prior parameters flagged as most sensitive.
    /// NOTE(review): the selection heuristic lives in `compute_summary_statistics`
    /// (a KL std threshold of 0.1 is involved); confirm the full criteria.
    most_sensitive_parameters: Vec<String>,
    /// Overall robustness score.
    /// NOTE(review): definition, range and direction are set in
    /// `compute_summary_statistics`; confirm before relying on them.
    robustness_score: f64,
}
165
166impl PriorSensitivityAnalyzer {
167    /// Create a new Prior Sensitivity Analyzer
168    pub fn new() -> Self {
169        Self {
170            n_components: 2,
171            covariance_type: CovarianceType::Diagonal,
172            max_iter: 100,
173            random_state: None,
174
175            weight_concentration_range: (0.1, 5.0),
176            weight_concentration_steps: 5,
177            mean_precision_range: (0.1, 10.0),
178            mean_precision_steps: 5,
179            degrees_of_freedom_range: (1.0, 10.0),
180            degrees_of_freedom_steps: 5,
181
182            n_random_perturbations: 20,
183            perturbation_scale: 0.2,
184
185            reference_weight_concentration: 1.0,
186            reference_mean_precision: 1.0,
187            reference_degrees_of_freedom: 1.0,
188
189            compute_kl_divergence: true,
190            compute_parameter_variance: true,
191            compute_prediction_variance: true,
192            compute_influence_functions: false, // Expensive computation
193        }
194    }
195
196    /// Set the number of components
197    pub fn n_components(mut self, n_components: usize) -> Self {
198        self.n_components = n_components;
199        self
200    }
201
202    /// Set the covariance type
203    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
204        self.covariance_type = covariance_type;
205        self
206    }
207
208    /// Set the maximum number of iterations
209    pub fn max_iter(mut self, max_iter: usize) -> Self {
210        self.max_iter = max_iter;
211        self
212    }
213
214    /// Set the random state
215    pub fn random_state(mut self, random_state: u64) -> Self {
216        self.random_state = Some(random_state);
217        self
218    }
219
220    /// Set the weight concentration prior range and steps
221    pub fn weight_concentration_range(mut self, range: (f64, f64), steps: usize) -> Self {
222        self.weight_concentration_range = range;
223        self.weight_concentration_steps = steps;
224        self
225    }
226
227    /// Set the mean precision prior range and steps
228    pub fn mean_precision_range(mut self, range: (f64, f64), steps: usize) -> Self {
229        self.mean_precision_range = range;
230        self.mean_precision_steps = steps;
231        self
232    }
233
234    /// Set the degrees of freedom prior range and steps
235    pub fn degrees_of_freedom_range(mut self, range: (f64, f64), steps: usize) -> Self {
236        self.degrees_of_freedom_range = range;
237        self.degrees_of_freedom_steps = steps;
238        self
239    }
240
241    /// Set the number of random perturbations
242    pub fn n_random_perturbations(mut self, n: usize) -> Self {
243        self.n_random_perturbations = n;
244        self
245    }
246
247    /// Set the perturbation scale for random analysis
248    pub fn perturbation_scale(mut self, scale: f64) -> Self {
249        self.perturbation_scale = scale;
250        self
251    }
252
253    /// Set the reference prior parameters
254    pub fn reference_priors(
255        mut self,
256        weight_concentration: f64,
257        mean_precision: f64,
258        degrees_of_freedom: f64,
259    ) -> Self {
260        self.reference_weight_concentration = weight_concentration;
261        self.reference_mean_precision = mean_precision;
262        self.reference_degrees_of_freedom = degrees_of_freedom;
263        self
264    }
265
266    /// Enable/disable KL divergence computation
267    pub fn compute_kl_divergence(mut self, compute: bool) -> Self {
268        self.compute_kl_divergence = compute;
269        self
270    }
271
272    /// Enable/disable parameter variance computation
273    pub fn compute_parameter_variance(mut self, compute: bool) -> Self {
274        self.compute_parameter_variance = compute;
275        self
276    }
277
278    /// Enable/disable prediction variance computation
279    pub fn compute_prediction_variance(mut self, compute: bool) -> Self {
280        self.compute_prediction_variance = compute;
281        self
282    }
283
284    /// Enable/disable influence function computation
285    pub fn compute_influence_functions(mut self, compute: bool) -> Self {
286        self.compute_influence_functions = compute;
287        self
288    }
289
290    /// Analyze prior sensitivity for the given data
291    #[allow(non_snake_case)]
292    pub fn analyze(&self, X: &ArrayView2<'_, Float>) -> SklResult<SensitivityAnalysisResult> {
293        let X = X.to_owned();
294        let (n_samples, n_features) = X.dim();
295
296        if n_samples < 2 {
297            return Err(SklearsError::InvalidInput(
298                "Number of samples must be at least 2".to_string(),
299            ));
300        }
301
302        // Fit reference model
303        let reference_model = self.fit_reference_model(&X)?;
304
305        // Perform grid search over prior parameters
306        let grid_results = self.grid_search_analysis(&X)?;
307
308        // Perform random perturbation analysis
309        let perturbation_results = self.random_perturbation_analysis(&X)?;
310
311        // Compute sensitivity measures
312        let kl_divergences = if self.compute_kl_divergence {
313            self.compute_kl_divergences(&reference_model, &grid_results, &perturbation_results)?
314        } else {
315            Vec::new()
316        };
317
318        let parameter_variances = if self.compute_parameter_variance {
319            self.compute_parameter_variances(&grid_results)?
320        } else {
321            ParameterVariances {
322                weight_variances: Array1::zeros(self.n_components),
323                mean_variances: Array2::zeros((self.n_components, n_features)),
324                covariance_variances: vec![
325                    Array2::zeros((n_features, n_features));
326                    self.n_components
327                ],
328            }
329        };
330
331        let prediction_variances = if self.compute_prediction_variance {
332            self.compute_prediction_variances(&X, &grid_results)?
333        } else {
334            Array1::zeros(n_samples)
335        };
336
337        let influence_scores = if self.compute_influence_functions {
338            self.compute_influence_scores(&X, &reference_model)?
339        } else {
340            Vec::new()
341        };
342
343        // Compute summary statistics
344        let summary = self.compute_summary_statistics(
345            &kl_divergences,
346            &perturbation_results,
347            &prediction_variances,
348        )?;
349
350        Ok(SensitivityAnalysisResult {
351            grid_results,
352            perturbation_results,
353            reference_model,
354            kl_divergences,
355            parameter_variances,
356            prediction_variances,
357            influence_scores,
358            summary,
359        })
360    }
361
362    /// Fit the reference model with baseline prior parameters
363    fn fit_reference_model(
364        &self,
365        X: &Array2<f64>,
366    ) -> SklResult<VariationalBayesianGMM<VariationalBayesianGMMTrained>> {
367        let model = VariationalBayesianGMM::new()
368            .n_components(self.n_components)
369            .covariance_type(self.covariance_type.clone())
370            .max_iter(self.max_iter)
371            .weight_concentration_prior(self.reference_weight_concentration)
372            .mean_precision_prior(self.reference_mean_precision)
373            .degrees_of_freedom_prior(self.reference_degrees_of_freedom)
374            .random_state(self.random_state.unwrap_or(42));
375
376        model.fit(&X.view(), &())
377    }
378
379    /// Perform grid search analysis over prior parameter ranges
380    fn grid_search_analysis(&self, X: &Array2<f64>) -> SklResult<Vec<GridSearchResult>> {
381        let mut results = Vec::new();
382
383        // Generate grid points
384        let weight_concentrations = self.linspace(
385            self.weight_concentration_range.0,
386            self.weight_concentration_range.1,
387            self.weight_concentration_steps,
388        );
389        let mean_precisions = self.linspace(
390            self.mean_precision_range.0,
391            self.mean_precision_range.1,
392            self.mean_precision_steps,
393        );
394        let degrees_of_freedom = self.linspace(
395            self.degrees_of_freedom_range.0,
396            self.degrees_of_freedom_range.1,
397            self.degrees_of_freedom_steps,
398        );
399
400        // Fit models for each combination of prior parameters
401        for &weight_conc in &weight_concentrations {
402            for &mean_prec in &mean_precisions {
403                for &dof in &degrees_of_freedom {
404                    let model = VariationalBayesianGMM::new()
405                        .n_components(self.n_components)
406                        .covariance_type(self.covariance_type.clone())
407                        .max_iter(self.max_iter)
408                        .weight_concentration_prior(weight_conc)
409                        .mean_precision_prior(mean_prec)
410                        .degrees_of_freedom_prior(dof)
411                        .random_state(self.random_state.unwrap_or(42));
412
413                    match model.fit(&X.view(), &()) {
414                        Ok(fitted_model) => {
415                            results.push(GridSearchResult {
416                                weight_concentration: weight_conc,
417                                mean_precision: mean_prec,
418                                degrees_of_freedom: dof,
419                                lower_bound: fitted_model.lower_bound(),
420                                effective_components: fitted_model.effective_components(),
421                                model: fitted_model,
422                            });
423                        }
424                        Err(_) => {
425                            // Skip failed fits
426                            continue;
427                        }
428                    }
429                }
430            }
431        }
432
433        Ok(results)
434    }
435
436    /// Perform random perturbation analysis
437    fn random_perturbation_analysis(&self, X: &Array2<f64>) -> SklResult<Vec<PerturbationResult>> {
438        let mut rng = if let Some(seed) = self.random_state {
439            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
440        } else {
441            scirs2_core::random::rngs::StdRng::from_rng(&mut thread_rng())
442        };
443
444        let mut results = Vec::new();
445
446        for perturbation_id in 0..self.n_random_perturbations {
447            // Generate random perturbations
448            let weight_conc_perturbation = 1.0 + (rng.gen::<f64>() - 0.5) * self.perturbation_scale;
449            let mean_prec_perturbation = 1.0 + (rng.gen::<f64>() - 0.5) * self.perturbation_scale;
450            let dof_perturbation = 1.0 + (rng.gen::<f64>() - 0.5) * self.perturbation_scale;
451
452            let perturbed_weight_concentration =
453                (self.reference_weight_concentration * weight_conc_perturbation).max(0.01);
454            let perturbed_mean_precision =
455                (self.reference_mean_precision * mean_prec_perturbation).max(0.01);
456            let perturbed_degrees_of_freedom =
457                (self.reference_degrees_of_freedom * dof_perturbation).max(0.1);
458
459            // Fit model with perturbed priors
460            let model = VariationalBayesianGMM::new()
461                .n_components(self.n_components)
462                .covariance_type(self.covariance_type.clone())
463                .max_iter(self.max_iter)
464                .weight_concentration_prior(perturbed_weight_concentration)
465                .mean_precision_prior(perturbed_mean_precision)
466                .degrees_of_freedom_prior(perturbed_degrees_of_freedom)
467                .random_state(self.random_state.unwrap_or(42 + perturbation_id as u64));
468
469            match model.fit(&X.view(), &()) {
470                Ok(fitted_model) => {
471                    results.push(PerturbationResult {
472                        perturbation_id,
473                        perturbed_weight_concentration,
474                        perturbed_mean_precision,
475                        perturbed_degrees_of_freedom,
476                        model: fitted_model,
477                        kl_divergence_from_reference: 0.0, // Will be computed later
478                        parameter_distance_from_reference: 0.0, // Will be computed later
479                    });
480                }
481                Err(_) => {
482                    // Skip failed fits
483                    continue;
484                }
485            }
486        }
487
488        Ok(results)
489    }
490
491    /// Compute KL divergences between models
492    fn compute_kl_divergences(
493        &self,
494        reference_model: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
495        grid_results: &[GridSearchResult],
496        perturbation_results: &[PerturbationResult],
497    ) -> SklResult<Vec<f64>> {
498        let mut kl_divergences = Vec::new();
499
500        // KL divergences for grid search results
501        for result in grid_results {
502            let kl_div =
503                self.compute_kl_divergence_between_models(reference_model, &result.model)?;
504            kl_divergences.push(kl_div);
505        }
506
507        // KL divergences for perturbation results
508        for result in perturbation_results {
509            let kl_div =
510                self.compute_kl_divergence_between_models(reference_model, &result.model)?;
511            kl_divergences.push(kl_div);
512        }
513
514        Ok(kl_divergences)
515    }
516
517    /// Compute KL divergence between two mixture models
518    fn compute_kl_divergence_between_models(
519        &self,
520        model1: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
521        model2: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
522    ) -> SklResult<f64> {
523        // Simplified KL divergence approximation using Monte Carlo sampling
524        let n_samples = 1000;
525        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
526
527        let mut kl_sum = 0.0;
528        let n_features = model1.means().ncols();
529
530        for _ in 0..n_samples {
531            // Sample from model1
532            let component = (rng.gen::<f64>() * model1.weights().len() as f64) as usize;
533            let component = component.min(model1.weights().len() - 1);
534
535            let mean = model1.means().row(component);
536            let cov = &model1.covariances()[component];
537
538            // Simple Gaussian sampling (assuming diagonal covariance for simplicity)
539            let mut sample = Array1::zeros(n_features);
540            for d in 0..n_features {
541                let std_dev = cov[[d, d]].sqrt();
542                let normal = Normal::new(mean[d], std_dev).unwrap();
543                sample[d] = rng.sample(normal);
544            }
545
546            // Compute log probabilities under both models
547            let log_p1 = self.log_probability_under_model(model1, &sample)?;
548            let log_p2 = self.log_probability_under_model(model2, &sample)?;
549
550            if log_p1.is_finite() && log_p2.is_finite() {
551                kl_sum += log_p1 - log_p2;
552            }
553        }
554
555        Ok(kl_sum / n_samples as f64)
556    }
557
558    /// Compute log probability of a sample under a mixture model
559    fn log_probability_under_model(
560        &self,
561        model: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
562        sample: &Array1<f64>,
563    ) -> SklResult<f64> {
564        let mut total_prob = 0.0;
565
566        for k in 0..model.weights().len() {
567            let weight = model.weights()[k];
568            let mean = model.means().row(k);
569            let cov = &model.covariances()[k];
570
571            let diff = sample - &mean.to_owned();
572            let mahalanobis_dist = diff.dot(&diff) / cov[[0, 0]]; // Simplified for diagonal case
573
574            let component_prob =
575                weight * (-0.5 * mahalanobis_dist).exp() / (2.0 * PI * cov[[0, 0]]).sqrt();
576
577            total_prob += component_prob;
578        }
579
580        Ok(total_prob.ln())
581    }
582
583    /// Compute parameter variances across different prior settings
584    fn compute_parameter_variances(
585        &self,
586        grid_results: &[GridSearchResult],
587    ) -> SklResult<ParameterVariances> {
588        if grid_results.is_empty() {
589            return Err(SklearsError::InvalidInput(
590                "No grid results available".to_string(),
591            ));
592        }
593
594        let n_features = grid_results[0].model.means().ncols();
595
596        // Collect all weights, means, and covariances
597        let mut all_weights = Vec::new();
598        let mut all_means = Vec::new();
599        let mut all_covariances = Vec::new();
600
601        for result in grid_results {
602            all_weights.push(result.model.weights().clone());
603            all_means.push(result.model.means().clone());
604            all_covariances.push(result.model.covariances().to_vec());
605        }
606
607        // Compute variances
608        let weight_variances = self.compute_array1_variance(&all_weights);
609        let mean_variances = self.compute_array2_variance(&all_means);
610        let covariance_variances = self.compute_covariance_variance(&all_covariances, n_features);
611
612        Ok(ParameterVariances {
613            weight_variances,
614            mean_variances,
615            covariance_variances,
616        })
617    }
618
619    /// Compute variance of Array1 parameters
620    fn compute_array1_variance(&self, arrays: &[Array1<f64>]) -> Array1<f64> {
621        if arrays.is_empty() {
622            return Array1::zeros(0);
623        }
624
625        let _n_arrays = arrays.len();
626        let array_len = arrays[0].len();
627        let mut variances = Array1::zeros(array_len);
628
629        for i in 0..array_len {
630            let values: Vec<f64> = arrays.iter().map(|arr| arr[i]).collect();
631            let mean = values.iter().sum::<f64>() / values.len() as f64;
632            let variance =
633                values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
634            variances[i] = variance;
635        }
636
637        variances
638    }
639
640    /// Compute variance of Array2 parameters
641    fn compute_array2_variance(&self, arrays: &[Array2<f64>]) -> Array2<f64> {
642        if arrays.is_empty() {
643            return Array2::zeros((0, 0));
644        }
645
646        let (n_rows, n_cols) = arrays[0].dim();
647        let mut variances = Array2::zeros((n_rows, n_cols));
648
649        for i in 0..n_rows {
650            for j in 0..n_cols {
651                let values: Vec<f64> = arrays.iter().map(|arr| arr[[i, j]]).collect();
652                let mean = values.iter().sum::<f64>() / values.len() as f64;
653                let variance =
654                    values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
655                variances[[i, j]] = variance;
656            }
657        }
658
659        variances
660    }
661
662    /// Compute variance of covariance matrices
663    fn compute_covariance_variance(
664        &self,
665        all_covariances: &[Vec<Array2<f64>>],
666        n_features: usize,
667    ) -> Vec<Array2<f64>> {
668        if all_covariances.is_empty() {
669            return vec![Array2::zeros((n_features, n_features)); self.n_components];
670        }
671
672        let mut variances = vec![Array2::zeros((n_features, n_features)); self.n_components];
673
674        for k in 0..self.n_components {
675            for i in 0..n_features {
676                for j in 0..n_features {
677                    let values: Vec<f64> = all_covariances
678                        .iter()
679                        .filter(|cov| cov.len() > k)
680                        .map(|cov| cov[k][[i, j]])
681                        .collect();
682
683                    if !values.is_empty() {
684                        let mean = values.iter().sum::<f64>() / values.len() as f64;
685                        let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
686                            / values.len() as f64;
687                        variances[k][[i, j]] = variance;
688                    }
689                }
690            }
691        }
692
693        variances
694    }
695
696    /// Compute prediction variances across different models
697    fn compute_prediction_variances(
698        &self,
699        X: &Array2<f64>,
700        grid_results: &[GridSearchResult],
701    ) -> SklResult<Array1<f64>> {
702        let n_samples = X.nrows();
703        let mut prediction_variances = Array1::zeros(n_samples);
704
705        for i in 0..n_samples {
706            let x_i = X.row(i);
707            let mut predictions = Vec::new();
708
709            // Collect predictions from all models
710            for result in grid_results {
711                match result
712                    .model
713                    .predict(&x_i.to_owned().insert_axis(Axis(0)).view())
714                {
715                    Ok(pred) => {
716                        // Use prediction as a measure (we'll compute variance later)
717                        if !pred.is_empty() {
718                            predictions.push(pred[0] as f64);
719                        }
720                    }
721                    Err(_) => continue,
722                }
723            }
724
725            // Compute variance of predictions
726            if !predictions.is_empty() {
727                let mean_pred = predictions.iter().sum::<f64>() / predictions.len() as f64;
728                let variance = predictions
729                    .iter()
730                    .map(|&pred| (pred - mean_pred).powi(2))
731                    .sum::<f64>()
732                    / predictions.len() as f64;
733                prediction_variances[i] = variance;
734            }
735        }
736
737        Ok(prediction_variances)
738    }
739
740    /// Compute influence scores for individual data points
741    fn compute_influence_scores(
742        &self,
743        X: &Array2<f64>,
744        reference_model: &VariationalBayesianGMM<VariationalBayesianGMMTrained>,
745    ) -> SklResult<Vec<InfluenceScore>> {
746        let (n_samples, n_features) = X.dim();
747        let mut influence_scores = Vec::new();
748
749        // Simplified influence computation using leave-one-out approach
750        for i in 0..n_samples.min(10) {
751            // Limit to first 10 samples for efficiency
752            // Create dataset without sample i
753            let mut X_loo = Array2::zeros((n_samples - 1, n_features));
754            let mut row_idx = 0;
755            for j in 0..n_samples {
756                if j != i {
757                    X_loo.row_mut(row_idx).assign(&X.row(j));
758                    row_idx += 1;
759                }
760            }
761
762            // Fit model without sample i
763            let loo_model = VariationalBayesianGMM::new()
764                .n_components(self.n_components)
765                .covariance_type(self.covariance_type.clone())
766                .max_iter(self.max_iter)
767                .weight_concentration_prior(self.reference_weight_concentration)
768                .mean_precision_prior(self.reference_mean_precision)
769                .degrees_of_freedom_prior(self.reference_degrees_of_freedom)
770                .random_state(self.random_state.unwrap_or(42));
771
772            match loo_model.fit(&X_loo.view(), &()) {
773                Ok(fitted_loo_model) => {
774                    // Compute influence as difference in parameters
775                    let weight_influence = reference_model.weights() - fitted_loo_model.weights();
776                    let mean_influence = reference_model.means() - fitted_loo_model.means();
777
778                    // Simplified covariance influence (diagonal elements only)
779                    let mut covariance_influence = Vec::new();
780                    for k in 0..self.n_components {
781                        let cov_diff =
782                            &reference_model.covariances()[k] - &fitted_loo_model.covariances()[k];
783                        covariance_influence.push(cov_diff);
784                    }
785
786                    let total_influence = weight_influence.iter().map(|x| x.abs()).sum::<f64>()
787                        + mean_influence.iter().map(|x| x.abs()).sum::<f64>();
788
789                    influence_scores.push(InfluenceScore {
790                        data_point_index: i,
791                        weight_influence,
792                        mean_influence,
793                        covariance_influence,
794                        total_influence,
795                    });
796                }
797                Err(_) => continue,
798            }
799        }
800
801        Ok(influence_scores)
802    }
803
804    /// Compute summary statistics for the sensitivity analysis
805    fn compute_summary_statistics(
806        &self,
807        kl_divergences: &[f64],
808        perturbation_results: &[PerturbationResult],
809        prediction_variances: &Array1<f64>,
810    ) -> SklResult<SensitivitySummary> {
811        // KL divergence statistics
812        let (avg_kl, max_kl, min_kl, kl_std) = if !kl_divergences.is_empty() {
813            let avg = kl_divergences.iter().sum::<f64>() / kl_divergences.len() as f64;
814            let max_val = kl_divergences
815                .iter()
816                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
817            let min_val = kl_divergences.iter().fold(f64::INFINITY, |a, &b| a.min(b));
818            let variance = kl_divergences
819                .iter()
820                .map(|&x| (x - avg).powi(2))
821                .sum::<f64>()
822                / kl_divergences.len() as f64;
823            let std_dev = variance.sqrt();
824            (avg, max_val, min_val, std_dev)
825        } else {
826            (0.0, 0.0, 0.0, 0.0)
827        };
828
829        // Parameter distance statistics (using perturbation results)
830        let parameter_distances: Vec<f64> = perturbation_results
831            .iter()
832            .map(|result| result.parameter_distance_from_reference)
833            .collect();
834
835        let (avg_param_dist, max_param_dist, min_param_dist, param_dist_std) =
836            if !parameter_distances.is_empty() {
837                let avg =
838                    parameter_distances.iter().sum::<f64>() / parameter_distances.len() as f64;
839                let max_val = parameter_distances
840                    .iter()
841                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
842                let min_val = parameter_distances
843                    .iter()
844                    .fold(f64::INFINITY, |a, &b| a.min(b));
845                let variance = parameter_distances
846                    .iter()
847                    .map(|&x| (x - avg).powi(2))
848                    .sum::<f64>()
849                    / parameter_distances.len() as f64;
850                let std_dev = variance.sqrt();
851                (avg, max_val, min_val, std_dev)
852            } else {
853                (0.0, 0.0, 0.0, 0.0)
854            };
855
856        // Prediction variance statistics
857        let (avg_pred_var, max_pred_var, min_pred_var) = if !prediction_variances.is_empty() {
858            let avg = prediction_variances.mean().unwrap_or(0.0);
859            let max_val = prediction_variances.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
860            let min_val = prediction_variances.fold(f64::INFINITY, |a, &b| a.min(b));
861            (avg, max_val, min_val)
862        } else {
863            (0.0, 0.0, 0.0)
864        };
865
866        // Identify most sensitive parameters (simplified)
867        let mut most_sensitive_parameters = Vec::new();
868        if kl_std > 0.1 {
869            most_sensitive_parameters.push("weight_concentration".to_string());
870        }
871        if param_dist_std > 0.1 {
872            most_sensitive_parameters.push("mean_precision".to_string());
873        }
874        if avg_pred_var > 0.1 {
875            most_sensitive_parameters.push("degrees_of_freedom".to_string());
876        }
877
878        // Compute robustness score (inverse of average sensitivity)
879        let robustness_score = 1.0 / (1.0 + avg_kl + avg_param_dist + avg_pred_var);
880
881        Ok(SensitivitySummary {
882            average_kl_divergence: avg_kl,
883            max_kl_divergence: max_kl,
884            min_kl_divergence: min_kl,
885            kl_divergence_std: kl_std,
886
887            average_parameter_distance: avg_param_dist,
888            max_parameter_distance: max_param_dist,
889            min_parameter_distance: min_param_dist,
890            parameter_distance_std: param_dist_std,
891
892            average_prediction_variance: avg_pred_var,
893            max_prediction_variance: max_pred_var,
894            min_prediction_variance: min_pred_var,
895
896            most_sensitive_parameters,
897            robustness_score,
898        })
899    }
900
901    /// Generate linearly spaced values
902    fn linspace(&self, start: f64, end: f64, steps: usize) -> Vec<f64> {
903        if steps <= 1 {
904            return vec![start];
905        }
906
907        let step_size = (end - start) / (steps - 1) as f64;
908        (0..steps).map(|i| start + i as f64 * step_size).collect()
909    }
910}
911
912impl Default for PriorSensitivityAnalyzer {
913    fn default() -> Self {
914        Self::new()
915    }
916}
917
918impl SensitivityAnalysisResult {
919    /// Get the average KL divergence
920    pub fn average_kl_divergence(&self) -> f64 {
921        self.summary.average_kl_divergence
922    }
923
924    /// Get the maximum KL divergence
925    pub fn max_kl_divergence(&self) -> f64 {
926        self.summary.max_kl_divergence
927    }
928
929    /// Get the robustness score
930    pub fn robustness_score(&self) -> f64 {
931        self.summary.robustness_score
932    }
933
934    /// Get the most sensitive parameters
935    pub fn most_sensitive_parameters(&self) -> &Vec<String> {
936        &self.summary.most_sensitive_parameters
937    }
938
939    /// Get the grid search results
940    pub fn grid_results(&self) -> &[GridSearchResult] {
941        &self.grid_results
942    }
943
944    /// Get the perturbation analysis results
945    pub fn perturbation_results(&self) -> &[PerturbationResult] {
946        &self.perturbation_results
947    }
948
949    /// Get the reference model
950    pub fn reference_model(&self) -> &VariationalBayesianGMM<VariationalBayesianGMMTrained> {
951        &self.reference_model
952    }
953
954    /// Get parameter variances
955    pub fn parameter_variances(&self) -> &ParameterVariances {
956        &self.parameter_variances
957    }
958
959    /// Get prediction variances
960    pub fn prediction_variances(&self) -> &Array1<f64> {
961        &self.prediction_variances
962    }
963
964    /// Get influence scores
965    pub fn influence_scores(&self) -> &[InfluenceScore] {
966        &self.influence_scores
967    }
968
969    /// Get summary statistics
970    pub fn summary(&self) -> &SensitivitySummary {
971        &self.summary
972    }
973
974    /// Find the most robust prior configuration
975    pub fn most_robust_configuration(&self) -> Option<&GridSearchResult> {
976        self.grid_results.iter().min_by(|a, b| {
977            // Find configuration with minimum average deviation from others
978            let a_score = a.lower_bound;
979            let b_score = b.lower_bound;
980            a_score
981                .partial_cmp(&b_score)
982                .unwrap_or(std::cmp::Ordering::Equal)
983        })
984    }
985
986    /// Find the least robust prior configuration
987    pub fn least_robust_configuration(&self) -> Option<&GridSearchResult> {
988        self.grid_results.iter().max_by(|a, b| {
989            let a_score = a.lower_bound;
990            let b_score = b.lower_bound;
991            a_score
992                .partial_cmp(&b_score)
993                .unwrap_or(std::cmp::Ordering::Equal)
994        })
995    }
996
997    /// Get recommendations for prior selection
998    pub fn prior_recommendations(&self) -> Vec<String> {
999        let mut recommendations = Vec::new();
1000
1001        if self.summary.robustness_score > 0.8 {
1002            recommendations.push("Model appears robust to prior choice".to_string());
1003        } else if self.summary.robustness_score < 0.3 {
1004            recommendations.push(
1005                "Model is highly sensitive to prior choice - consider more informative priors"
1006                    .to_string(),
1007            );
1008        }
1009
1010        if self.summary.average_kl_divergence > 1.0 {
1011            recommendations.push(
1012                "High variation in model predictions - consider reducing prior parameter ranges"
1013                    .to_string(),
1014            );
1015        }
1016
1017        if !self.summary.most_sensitive_parameters.is_empty() {
1018            recommendations.push(format!(
1019                "Most sensitive parameters: {}",
1020                self.summary.most_sensitive_parameters.join(", ")
1021            ));
1022        }
1023
1024        if recommendations.is_empty() {
1025            recommendations.push("Model shows moderate sensitivity to priors - current configuration appears reasonable".to_string());
1026        }
1027
1028        recommendations
1029    }
1030}
1031
// Test inputs use the conventional upper-case `X` for the data matrix.
#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_prior_sensitivity_analyzer_creation() {
        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(3)
            .weight_concentration_range((0.1, 5.0), 3)
            .mean_precision_range((0.1, 10.0), 3)
            .n_random_perturbations(5);

        assert_eq!(analyzer.n_components, 3);
        assert_eq!(analyzer.weight_concentration_steps, 3);
        assert_eq!(analyzer.mean_precision_steps, 3);
        assert_eq!(analyzer.n_random_perturbations, 5);
    }

    #[test]
    fn test_prior_sensitivity_analyzer_linspace() {
        let analyzer = PriorSensitivityAnalyzer::new();
        let values = analyzer.linspace(0.0, 1.0, 5);

        assert_eq!(values.len(), 5);
        assert_abs_diff_eq!(values[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[4], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[2], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_prior_sensitivity_analyzer_linspace_edge_cases() {
        let analyzer = PriorSensitivityAnalyzer::new();

        // Zero or one step collapses to the start value instead of an empty grid.
        assert_eq!(analyzer.linspace(2.5, 7.0, 0), vec![2.5]);
        assert_eq!(analyzer.linspace(2.5, 7.0, 1), vec![2.5]);

        // Descending ranges are supported and reach both endpoints.
        let values = analyzer.linspace(5.0, 1.0, 3);
        assert_eq!(values.len(), 3);
        assert_abs_diff_eq!(values[0], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[1], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[2], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_prior_sensitivity_analysis_simple() {
        let X = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2]
        ];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 3)
            .mean_precision_range((0.5, 2.0), 3)
            .degrees_of_freedom_range((1.0, 3.0), 3)
            .n_random_perturbations(3)
            .max_iter(5)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        assert!(!result.grid_results().is_empty());
        assert!(!result.perturbation_results().is_empty());
        assert!(result.robustness_score() >= 0.0);
        assert!(result.robustness_score() <= 1.0);
    }

    #[test]
    fn test_prior_sensitivity_analysis_properties() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 2)
            .mean_precision_range((0.5, 2.0), 2)
            .n_random_perturbations(2)
            .max_iter(3)
            .compute_kl_divergence(true)
            .compute_parameter_variance(true)
            .compute_prediction_variance(true)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        // Check that analysis components exist
        assert!(result.average_kl_divergence().is_finite());
        assert!(result.summary().average_parameter_distance.is_finite());
        assert!(result.summary().average_prediction_variance.is_finite());
        assert!(!result.parameter_variances().weight_variances.is_empty());
    }

    #[test]
    fn test_prior_sensitivity_analysis_recommendations() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 2)
            .mean_precision_range((0.5, 2.0), 2)
            .n_random_perturbations(2)
            .max_iter(3)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();
        let recommendations = result.prior_recommendations();

        // There is always at least one recommendation (a fallback message at minimum).
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_prior_sensitivity_analysis_configurations() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 3)
            .mean_precision_range((0.5, 2.0), 3)
            .n_random_perturbations(2)
            .max_iter(3)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        // Should have grid results
        assert!(!result.grid_results().is_empty());

        // Should be able to find most/least robust configurations
        let most_robust = result.most_robust_configuration();
        let least_robust = result.least_robust_configuration();

        assert!(most_robust.is_some());
        assert!(least_robust.is_some());
    }

    #[test]
    fn test_prior_sensitivity_analysis_disabled_features() {
        let X = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];

        let analyzer = PriorSensitivityAnalyzer::new()
            .n_components(2)
            .weight_concentration_range((0.5, 2.0), 2)
            .mean_precision_range((0.5, 2.0), 2)
            .n_random_perturbations(2)
            .max_iter(3)
            .compute_kl_divergence(false)
            .compute_parameter_variance(false)
            .compute_prediction_variance(false)
            .compute_influence_functions(false)
            .random_state(42);

        let result = analyzer.analyze(&X.view()).unwrap();

        // When features are disabled, should have empty or zero results
        assert!(result.kl_divergences.is_empty());
        assert!(result.prediction_variances().iter().all(|&x| x == 0.0));
        assert!(result.influence_scores().is_empty());
    }
}