scirs2_integrate/autodiff/
sensitivity.rs

1//! Sensitivity analysis tools
2//!
3//! This module provides tools for analyzing how solutions depend on parameters,
4//! including local sensitivity analysis and global sensitivity indices.
5
6use crate::common::IntegrateFloat;
7use crate::error::{IntegrateError, IntegrateResult};
8use crate::ode::{solve_ivp, ODEOptions};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::random::Rng;
11use std::collections::HashMap;
12
13// Type alias for complex return type
14type SensitivityResult<F> = IntegrateResult<(HashMap<usize, Array1<F>>, HashMap<usize, Array1<F>>)>;
15
16/// Parameter sensitivity information
17#[derive(Clone)]
18pub struct ParameterSensitivity<F: IntegrateFloat> {
19    /// Parameter name
20    pub name: String,
21    /// Parameter index
22    pub index: usize,
23    /// Nominal value
24    pub nominal_value: F,
25    /// Sensitivity matrix (∂y/∂p)
26    pub sensitivity: Array2<F>,
27    /// Time points
28    pub t_eval: Array1<F>,
29}
30
31/// Sensitivity analysis results
32pub struct SensitivityAnalysis<F: IntegrateFloat> {
33    /// Solution at nominal parameters
34    pub nominal_solution: Array2<F>,
35    /// Time points
36    pub t_eval: Array1<F>,
37    /// Parameter sensitivities
38    pub sensitivities: Vec<ParameterSensitivity<F>>,
39    /// First-order sensitivity indices (if computed)
40    pub first_order_indices: Option<HashMap<String, Array1<F>>>,
41    /// Total sensitivity indices (if computed)
42    pub total_indices: Option<HashMap<String, Array1<F>>>,
43}
44
45impl<F: IntegrateFloat> SensitivityAnalysis<F> {
46    /// Get sensitivity for a specific parameter
47    pub fn get_sensitivity(&self, paramname: &str) -> Option<&ParameterSensitivity<F>> {
48        self.sensitivities.iter().find(|s| s.name == paramname)
49    }
50
51    /// Compute relative sensitivities
52    pub fn relative_sensitivities(&self) -> IntegrateResult<HashMap<String, Array2<F>>> {
53        let mut result = HashMap::new();
54
55        for sens in &self.sensitivities {
56            let mut rel_sens = sens.sensitivity.clone();
57
58            // Compute S_ij = (p_j / y_i) * (∂y_i/∂p_j)
59            for i in 0..rel_sens.nrows() {
60                for j in 0..rel_sens.ncols() {
61                    let y_nominal = self.nominal_solution[[i, j]];
62                    if y_nominal.abs() > F::epsilon() {
63                        rel_sens[[i, j]] *= sens.nominal_value / y_nominal;
64                    }
65                }
66            }
67
68            result.insert(sens.name.clone(), rel_sens);
69        }
70
71        Ok(result)
72    }
73
74    /// Compute time-averaged sensitivities
75    pub fn time_averaged_sensitivities(&self) -> HashMap<String, Array1<F>> {
76        let mut result = HashMap::new();
77        let n_time = self.t_eval.len();
78
79        for sens in &self.sensitivities {
80            let n_states = sens.sensitivity.ncols();
81            let mut avg_sens = Array1::zeros(n_states);
82
83            // Compute time average for each state variable
84            for j in 0..n_states {
85                let mut sum = F::zero();
86                for i in 0..n_time {
87                    sum += sens.sensitivity[[i, j]].abs();
88                }
89                avg_sens[j] = sum / F::from(n_time).unwrap();
90            }
91
92            result.insert(sens.name.clone(), avg_sens);
93        }
94
95        result
96    }
97}
98
99/// Compute sensitivities using forward sensitivity analysis
100#[allow(dead_code)]
101pub fn compute_sensitivities<F, SysFunc, ParamFunc>(
102    system: SysFunc,
103    _parameters: ParamFunc,
104    param_names: Vec<String>,
105    nominal_params: ArrayView1<F>,
106    y0: ArrayView1<F>,
107    t_span: (F, F),
108    _t_eval: Option<ArrayView1<F>>,
109    options: Option<ODEOptions<F>>,
110) -> IntegrateResult<SensitivityAnalysis<F>>
111where
112    F: IntegrateFloat + std::default::Default,
113    SysFunc: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F> + Clone,
114    ParamFunc: Fn(usize) -> Array1<F>,
115{
116    let n_states = y0.len();
117    let n_params = nominal_params.len();
118
119    if param_names.len() != n_params {
120        return Err(IntegrateError::ValueError(
121            "Number of parameter _names must match number of _parameters".to_string(),
122        ));
123    }
124
125    // Solve nominal system
126    let opts = options.clone().unwrap_or_default();
127
128    let nominal_result = solve_ivp(
129        |t, y| system(t, y, nominal_params),
130        [t_span.0, t_span.1],
131        y0.to_owned(),
132        Some(opts),
133    )?;
134
135    let t_points = nominal_result.t.clone();
136
137    // Compute sensitivities for each parameter
138    let mut sensitivities = Vec::new();
139
140    for (param_idx, param_name) in param_names.iter().enumerate() {
141        // Create augmented system for sensitivity equations
142        let augmented_dim = n_states * (1 + 1); // States + sensitivity matrix
143        let mut y0_aug = Array1::zeros(augmented_dim);
144
145        // Initial conditions: y(0) and S(0) = 0
146        y0_aug
147            .slice_mut(scirs2_core::ndarray::s![0..n_states])
148            .assign(&y0);
149
150        let system_clone = system.clone();
151        let params = nominal_params.to_owned();
152
153        // Augmented system: [dy/dt; dS/dt]
154        let augmented_system = move |t: F, y_aug: ArrayView1<F>| -> Array1<F> {
155            let y = y_aug.slice(scirs2_core::ndarray::s![0..n_states]);
156            let s = y_aug
157                .slice(scirs2_core::ndarray::s![n_states..])
158                .to_owned()
159                .into_shape_with_order((n_states,))
160                .unwrap();
161
162            // Compute f(t, y, p)
163            let f = system_clone(t, y, params.view());
164
165            // Compute ∂f/∂y using finite differences
166            let eps = F::from(1e-8).unwrap();
167            let mut df_dy = Array2::zeros((n_states, n_states));
168
169            for j in 0..n_states {
170                let mut y_pert = y.to_owned();
171                y_pert[j] += eps;
172                let f_pert = system_clone(t, y_pert.view(), params.view());
173
174                for i in 0..n_states {
175                    df_dy[[i, j]] = (f_pert[i] - f[i]) / eps;
176                }
177            }
178
179            // Compute ∂f/∂p for the current parameter
180            let mut params_pert = params.to_owned();
181            params_pert[param_idx] += eps;
182            let f_pert = system_clone(t, y, params_pert.view());
183            let df_dp = (f_pert - &f) / eps;
184
185            // dS/dt = ∂f/∂y * S + ∂f/∂p
186            let ds_dt = df_dy.dot(&s) + df_dp;
187
188            // Combine derivatives
189            let mut result = Array1::zeros(augmented_dim);
190            result
191                .slice_mut(scirs2_core::ndarray::s![0..n_states])
192                .assign(&f);
193            result
194                .slice_mut(scirs2_core::ndarray::s![n_states..])
195                .assign(&ds_dt);
196
197            result
198        };
199
200        // Solve augmented system
201        let aug_opts = options.clone().unwrap_or_default();
202
203        let aug_result = solve_ivp(
204            augmented_system,
205            [t_span.0, t_span.1],
206            y0_aug,
207            Some(aug_opts),
208        )?;
209
210        // Extract sensitivity matrix
211        let aug_time = aug_result.t.len();
212        let mut sensitivity = Array2::zeros((aug_time, n_states));
213        for (i, sol) in aug_result.y.iter().enumerate() {
214            let s = sol.slice(scirs2_core::ndarray::s![n_states..]);
215            sensitivity.row_mut(i).assign(&s);
216        }
217
218        sensitivities.push(ParameterSensitivity {
219            name: param_name.clone(),
220            index: param_idx,
221            nominal_value: nominal_params[param_idx],
222            sensitivity,
223            t_eval: Array1::from_vec(aug_result.t.clone()),
224        });
225    }
226
227    // Convert Vec<Array1<F>> to Array2<F>
228    let n_points = nominal_result.t.len();
229    let mut nominal_solution = Array2::zeros((n_points, n_states));
230    for (i, sol) in nominal_result.y.iter().enumerate() {
231        nominal_solution.row_mut(i).assign(sol);
232    }
233
234    Ok(SensitivityAnalysis {
235        nominal_solution,
236        t_eval: Array1::from_vec(t_points),
237        sensitivities,
238        first_order_indices: None,
239        total_indices: None,
240    })
241}
242
243/// Compute local sensitivity indices at a specific time
244#[allow(dead_code)]
245pub fn local_sensitivity_indices<F: IntegrateFloat>(
246    analysis: &SensitivityAnalysis<F>,
247    time_index: usize,
248) -> IntegrateResult<HashMap<String, Array1<F>>> {
249    let n_states = analysis.nominal_solution.ncols();
250    let mut indices = HashMap::new();
251
252    for sens in &analysis.sensitivities {
253        let mut param_indices = Array1::zeros(n_states);
254
255        for j in 0..n_states {
256            let y_nominal = analysis.nominal_solution[[time_index, j]];
257            let s_ij = sens.sensitivity[[time_index, j]];
258
259            if y_nominal.abs() > F::epsilon() {
260                // Normalized sensitivity _index
261                param_indices[j] = (s_ij * sens.nominal_value / y_nominal).abs();
262            }
263        }
264
265        indices.insert(sens.name.clone(), param_indices);
266    }
267
268    Ok(indices)
269}
270
271/// Sobol indices for global sensitivity analysis
272pub struct SobolIndices<F: IntegrateFloat> {
273    /// First-order indices S_i
274    pub first_order: HashMap<String, F>,
275    /// Total indices S_Ti
276    pub total: HashMap<String, F>,
277    /// Second-order indices S_ij (optional)
278    pub second_order: Option<HashMap<(String, String), F>>,
279}
280
281/// Variance-based sensitivity analysis using Sobol method
282pub struct SobolAnalysis<F: IntegrateFloat> {
283    /// Number of samples
284    n_samples: usize,
285    /// Parameter bounds
286    param_bounds: Vec<(F, F)>,
287    /// Random seed for reproducibility
288    seed: Option<u64>,
289}
290
291impl<F: IntegrateFloat> SobolAnalysis<F> {
292    /// Create a new Sobol analysis
293    pub fn new(n_samples: usize, param_bounds: Vec<(F, F)>) -> Self {
294        SobolAnalysis {
295            n_samples,
296            param_bounds,
297            seed: None,
298        }
299    }
300
301    /// Set random seed for reproducibility
302    pub fn with_seed(&mut self, seed: u64) -> &mut Self {
303        self.seed = Some(seed);
304        self
305    }
306
307    /// Compute Sobol indices
308    pub fn compute_indices<Func>(
309        &self,
310        model: Func,
311        param_names: Vec<String>,
312    ) -> IntegrateResult<SobolIndices<F>>
313    where
314        Func: Fn(ArrayView1<F>) -> IntegrateResult<F> + Sync + Send,
315    {
316        let n_params = self.param_bounds.len();
317        if param_names.len() != n_params {
318            return Err(IntegrateError::ValueError(
319                "Number of parameter _names must match bounds".to_string(),
320            ));
321        }
322
323        // Generate quasi-random samples using Sobol sequence
324        let sample_matrix_a = self.generate_sample_matrix();
325        let sample_matrix_b = self.generate_sample_matrix();
326
327        // Evaluate model at base samples
328        let y_a = SobolAnalysis::<F>::evaluate_model(&model, &sample_matrix_a)?;
329        let y_b = SobolAnalysis::<F>::evaluate_model(&model, &sample_matrix_b)?;
330
331        // Compute variance
332        let var_y = SobolAnalysis::<F>::compute_variance(&y_a, &y_b, self.n_samples);
333
334        let mut first_order = HashMap::new();
335        let mut total = HashMap::new();
336
337        // Compute indices for each parameter
338        for (i, name) in param_names.iter().enumerate() {
339            // Create matrix C_i where column i comes from B, rest from A
340            let sample_matrix_ci = self.create_mixed_matrix(&sample_matrix_a, &sample_matrix_b, i);
341            let y_ci = SobolAnalysis::<F>::evaluate_model(&model, &sample_matrix_ci)?;
342
343            // First-order index: S_i = V(E(Y|X_i)) / V(Y)
344            let s_i = SobolAnalysis::<F>::compute_first_order_index(
345                &y_a,
346                &y_b,
347                &y_ci,
348                var_y,
349                self.n_samples,
350            );
351            first_order.insert(name.clone(), s_i);
352
353            // Total index: S_Ti = 1 - V(E(Y|X_~i)) / V(Y)
354            let s_ti = SobolAnalysis::<F>::compute_total_index(&y_a, &y_ci, var_y, self.n_samples);
355            total.insert(name.clone(), s_ti);
356        }
357
358        Ok(SobolIndices {
359            first_order,
360            total,
361            second_order: None,
362        })
363    }
364
365    /// Generate sample matrix using quasi-random sequences
366    fn generate_sample_matrix(&self) -> Vec<Array1<F>> {
367        let n_params = self.param_bounds.len();
368        let mut samples = Vec::with_capacity(self.n_samples);
369
370        // Simple uniform random sampling (should use Sobol sequence for better coverage)
371        for i in 0..self.n_samples {
372            let mut sample = Array1::zeros(n_params);
373            for j in 0..n_params {
374                let (low, high) = self.param_bounds[j];
375                let u = F::from(i).unwrap() / F::from(self.n_samples - 1).unwrap();
376                sample[j] = low + (high - low) * u;
377            }
378            samples.push(sample);
379        }
380
381        samples
382    }
383
384    /// Evaluate model at all sample points
385    fn evaluate_model<Func>(model: &Func, samples: &[Array1<F>]) -> IntegrateResult<Vec<F>>
386    where
387        Func: Fn(ArrayView1<F>) -> IntegrateResult<F> + Sync + Send,
388    {
389        // Evaluate _model at each sample point
390        let mut results = Vec::with_capacity(samples.len());
391        for sample in samples {
392            results.push(model(sample.view())?);
393        }
394        Ok(results)
395    }
396
397    /// Create mixed sample matrix for computing indices
398    fn create_mixed_matrix(
399        &self,
400        matrix_a: &[Array1<F>],
401        matrix_b: &[Array1<F>],
402        param_idx: usize,
403    ) -> Vec<Array1<F>> {
404        let mut mixed = Vec::with_capacity(self.n_samples);
405
406        for i in 0..self.n_samples {
407            let mut sample = matrix_a[i].clone();
408            sample[param_idx] = matrix_b[i][param_idx];
409            mixed.push(sample);
410        }
411
412        mixed
413    }
414
415    /// Compute variance of model outputs
416    fn compute_variance(y_a: &[F], y_b: &[F], n_samples: usize) -> F {
417        let n = F::from(n_samples).unwrap();
418        let mut sum = F::zero();
419        let mut sum_sq = F::zero();
420
421        for i in 0..n_samples {
422            let y = (y_a[i] + y_b[i]) / F::from(2.0).unwrap();
423            sum += y;
424            sum_sq += y * y;
425        }
426
427        let mean = sum / n;
428        sum_sq / n - mean * mean
429    }
430
431    /// Compute first-order Sobol index
432    fn compute_first_order_index(
433        y_a: &[F],
434        y_b: &[F],
435        y_ci: &[F],
436        var_y: F,
437        n_samples: usize,
438    ) -> F {
439        let n = F::from(n_samples).unwrap();
440        let mut sum = F::zero();
441
442        for i in 0..n_samples {
443            sum += y_b[i] * (y_ci[i] - y_a[i]);
444        }
445
446        let v_i = sum / n;
447        (v_i / var_y).max(F::zero()).min(F::one())
448    }
449
450    /// Compute total Sobol index
451    fn compute_total_index(y_a: &[F], y_ci: &[F], var_y: F, n_samples: usize) -> F {
452        let n = F::from(n_samples).unwrap();
453        let mut sum = F::zero();
454
455        for i in 0..n_samples {
456            let diff = y_a[i] - y_ci[i];
457            sum += diff * diff;
458        }
459
460        let e_i = sum / (F::from(2.0).unwrap() * n);
461        (e_i / var_y).max(F::zero()).min(F::one())
462    }
463}
464
465/// Extended Fourier Amplitude Sensitivity Test (eFAST)
466pub struct EFAST<F: IntegrateFloat> {
467    /// Number of samples
468    n_samples: usize,
469    /// Parameter bounds
470    param_bounds: Vec<(F, F)>,
471    /// Interference factor
472    interference_factor: usize,
473}
474
475impl<F: IntegrateFloat> EFAST<F> {
476    /// Create a new eFAST analysis
477    pub fn new(n_samples: usize, param_bounds: Vec<(F, F)>) -> Self {
478        EFAST {
479            n_samples,
480            param_bounds,
481            interference_factor: 4,
482        }
483    }
484
485    /// Set interference factor
486    pub fn with_interference_factor(&mut self, factor: usize) -> &mut Self {
487        self.interference_factor = factor;
488        self
489    }
490
491    /// Compute sensitivity indices using eFAST
492    pub fn compute_indices<Func>(
493        &self,
494        model: Func,
495        param_names: Vec<String>,
496    ) -> IntegrateResult<HashMap<String, F>>
497    where
498        Func: Fn(ArrayView1<F>) -> IntegrateResult<F>,
499    {
500        let n_params = self.param_bounds.len();
501        if param_names.len() != n_params {
502            return Err(IntegrateError::ValueError(
503                "Number of parameter _names must match bounds".to_string(),
504            ));
505        }
506
507        let mut indices = HashMap::new();
508        let omega_max = (self.n_samples - 1) / (2 * self.interference_factor);
509
510        // Compute indices for each parameter
511        for (i, name) in param_names.iter().enumerate() {
512            let omega_i = omega_max;
513            let samples = self.generate_samples(i, omega_i);
514
515            // Evaluate model
516            let mut y_values = Vec::with_capacity(self.n_samples);
517            for sample in &samples {
518                y_values.push(model(sample.view())?);
519            }
520
521            // Compute Fourier coefficients
522            let sensitivity = self.compute_fourier_sensitivity(&y_values, omega_i);
523            indices.insert(name.clone(), sensitivity);
524        }
525
526        Ok(indices)
527    }
528
529    /// Generate parameter samples using search curve
530    fn generate_samples(&self, _param_index: usize, omega: usize) -> Vec<Array1<F>> {
531        let n_params = self.param_bounds.len();
532        let mut samples = Vec::with_capacity(self.n_samples);
533
534        for k in 0..self.n_samples {
535            let s = F::from(k).unwrap() / F::from(self.n_samples).unwrap();
536            let mut sample = Array1::zeros(n_params);
537
538            for j in 0..n_params {
539                let (low, high) = self.param_bounds[j];
540
541                if j == _param_index {
542                    // Use higher frequency for parameter of interest
543                    let angle = F::from(2.0 * std::f64::consts::PI * omega as f64).unwrap() * s;
544                    let x = (F::one() + angle.sin()) / F::from(2.0).unwrap();
545                    sample[j] = low + (high - low) * x;
546                } else {
547                    // Use lower frequencies for other parameters
548                    let omega_j = if j < _param_index { j + 1 } else { j };
549                    let angle = F::from(2.0 * std::f64::consts::PI * omega_j as f64).unwrap() * s;
550                    let x = (F::one() + angle.sin()) / F::from(2.0).unwrap();
551                    sample[j] = low + (high - low) * x;
552                }
553            }
554
555            samples.push(sample);
556        }
557
558        samples
559    }
560
561    /// Compute Fourier-based sensitivity
562    fn compute_fourier_sensitivity(&self, y_values: &[F], omega: usize) -> F {
563        let n = self.n_samples;
564        let mut a_omega = F::zero();
565        let mut b_omega = F::zero();
566
567        for (k, y_value) in y_values.iter().enumerate().take(n) {
568            let angle =
569                F::from(2.0 * std::f64::consts::PI * omega as f64 * k as f64 / n as f64).unwrap();
570            a_omega += *y_value * angle.cos();
571            b_omega += *y_value * angle.sin();
572        }
573
574        a_omega *= F::from(2.0).unwrap() / F::from(n).unwrap();
575        b_omega *= F::from(2.0).unwrap() / F::from(n).unwrap();
576
577        // Return normalized sensitivity
578        (a_omega * a_omega + b_omega * b_omega).sqrt()
579    }
580}
581
582/// Parameter sensitivity ranking
583#[allow(dead_code)]
584pub fn rank_parameters<F: IntegrateFloat>(analysis: &SensitivityAnalysis<F>) -> Vec<(String, F)> {
585    let averaged = analysis.time_averaged_sensitivities();
586    let mut rankings: Vec<(String, F)> = Vec::new();
587
588    for (name, sens) in averaged {
589        // Use norm of sensitivity vector as ranking metric
590        let mut norm = F::zero();
591        for &s in sens.iter() {
592            norm += s * s;
593        }
594        norm = norm.sqrt();
595        rankings.push((name, norm));
596    }
597
598    // Sort by sensitivity (descending)
599    rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
600
601    rankings
602}
603
604/// Compute sensitivity-based parameter subset selection
605#[allow(dead_code)]
606pub fn select_important_parameters<F: IntegrateFloat>(
607    analysis: &SensitivityAnalysis<F>,
608    threshold: F,
609) -> Vec<String> {
610    let rankings = rank_parameters(analysis);
611    let mut important = Vec::new();
612
613    // Compute total sensitivity
614    let total: F = rankings
615        .iter()
616        .map(|(_, s)| *s)
617        .fold(F::zero(), |acc, x| acc + x);
618
619    if total > F::epsilon() {
620        let mut cumulative = F::zero();
621
622        for (name, sens) in rankings {
623            cumulative += sens;
624            important.push(name);
625
626            // Stop when we've captured threshold fraction of total sensitivity
627            if cumulative / total >= threshold {
628                break;
629            }
630        }
631    }
632
633    important
634}
635
636/// Global sensitivity analysis using Sobol indices
637pub struct SobolSensitivity<F: IntegrateFloat> {
638    /// Number of parameters
639    n_params: usize,
640    /// Number of samples
641    n_samples: usize,
642    /// Parameter bounds
643    param_bounds: Vec<(F, F)>,
644}
645
646impl<F: IntegrateFloat + std::default::Default> SobolSensitivity<F> {
647    /// Create a new Sobol sensitivity analyzer
648    pub fn new(param_bounds: Vec<(F, F)>, n_samples: usize) -> Self {
649        SobolSensitivity {
650            n_params: param_bounds.len(),
651            n_samples,
652            param_bounds,
653        }
654    }
655
656    /// Generate Sobol sample matrices
657    pub fn generate_samples(&self) -> (Array2<F>, Array2<F>) {
658        use scirs2_core::random::Rng;
659        let mut rng = scirs2_core::random::rng();
660
661        // Generate base sample matrix A
662        let mut a_matrix = Array2::zeros((self.n_samples, self.n_params));
663        for i in 0..self.n_samples {
664            for j in 0..self.n_params {
665                let (lower, upper) = self.param_bounds[j];
666                let u: f64 = rng.random();
667                a_matrix[[i, j]] = lower + (upper - lower) * F::from(u).unwrap();
668            }
669        }
670
671        // Generate alternative sample matrix B
672        let mut b_matrix = Array2::zeros((self.n_samples, self.n_params));
673        for i in 0..self.n_samples {
674            for j in 0..self.n_params {
675                let (lower, upper) = self.param_bounds[j];
676                let u: f64 = rng.random();
677                b_matrix[[i, j]] = lower + (upper - lower) * F::from(u).unwrap();
678            }
679        }
680
681        (a_matrix, b_matrix)
682    }
683
684    /// Compute first-order and total Sobol indices
685    pub fn compute_indices<Func, SysFunc>(
686        &self,
687        system: SysFunc,
688        y0_func: Func,
689        t_span: (F, F),
690        t_eval: ArrayView1<F>,
691        options: Option<ODEOptions<F>>,
692    ) -> SensitivityResult<F>
693    where
694        Func: Fn(ArrayView1<F>) -> Array1<F>,
695        SysFunc: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F> + Clone,
696    {
697        let (a_matrix, b_matrix) = self.generate_samples();
698        let n_states = y0_func(a_matrix.row(0)).len();
699        let n_time = t_eval.len();
700
701        // Compute model outputs for base samples
702        let mut y_a = Array2::zeros((self.n_samples, n_states * n_time));
703        let mut y_b = Array2::zeros((self.n_samples, n_states * n_time));
704
705        for i in 0..self.n_samples {
706            let params_a = a_matrix.row(i);
707            let params_b = b_matrix.row(i);
708
709            let y0_a = y0_func(params_a);
710            let y0_b = y0_func(params_b);
711
712            let sol_a = solve_ivp(
713                |t, y| system(t, y, params_a),
714                [t_span.0, t_span.1],
715                y0_a,
716                options.clone(),
717            )?;
718
719            let sol_b = solve_ivp(
720                |t, y| system(t, y, params_b),
721                [t_span.0, t_span.1],
722                y0_b,
723                options.clone(),
724            )?;
725
726            // Flatten solutions
727            for (j, t) in t_eval.iter().enumerate() {
728                let idx_a = sol_a
729                    .t
730                    .iter()
731                    .position(|&t_sol| (t_sol - *t).abs() < F::epsilon())
732                    .unwrap_or(0);
733                let idx_b = sol_b
734                    .t
735                    .iter()
736                    .position(|&t_sol| (t_sol - *t).abs() < F::epsilon())
737                    .unwrap_or(0);
738
739                for k in 0..n_states {
740                    y_a[[i, j * n_states + k]] = sol_a.y[idx_a][k];
741                    y_b[[i, j * n_states + k]] = sol_b.y[idx_b][k];
742                }
743            }
744        }
745
746        // Compute variance of outputs
747        let _mean_y = y_a.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
748        let var_y = y_a.var_axis(scirs2_core::ndarray::Axis(0), F::zero());
749
750        let mut first_order_indices = HashMap::new();
751        let mut total_indices = HashMap::new();
752
753        // Compute indices for each parameter
754        for param_idx in 0..self.n_params {
755            // Create C_i matrix (all columns from B except i-th from A)
756            let mut y_c_i = Array2::zeros((self.n_samples, n_states * n_time));
757
758            for sample in 0..self.n_samples {
759                let mut params_c_i = b_matrix.row(sample).to_owned();
760                params_c_i[param_idx] = a_matrix[[sample, param_idx]];
761
762                let y0_c = y0_func(params_c_i.view());
763                let sol_c = solve_ivp(
764                    |t, y| system(t, y, params_c_i.view()),
765                    [t_span.0, t_span.1],
766                    y0_c,
767                    options.clone(),
768                )?;
769
770                for (j, t) in t_eval.iter().enumerate() {
771                    let idx = sol_c
772                        .t
773                        .iter()
774                        .position(|&t_sol| (t_sol - *t).abs() < F::epsilon())
775                        .unwrap_or(0);
776                    for k in 0..n_states {
777                        y_c_i[[sample, j * n_states + k]] = sol_c.y[idx][k];
778                    }
779                }
780            }
781
782            // First-order index: S_i = V[E(Y|X_i)] / V(Y)
783            let mut s_i = Array1::zeros(n_states * n_time);
784            for j in 0..(n_states * n_time) {
785                let mut sum = F::zero();
786                for sample in 0..self.n_samples {
787                    sum += y_a[[sample, j]] * (y_c_i[[sample, j]] - y_b[[sample, j]]);
788                }
789                let v_i = sum / F::from(self.n_samples).unwrap();
790                s_i[j] = v_i / var_y[j];
791            }
792            first_order_indices.insert(param_idx, s_i);
793
794            // Total index: ST_i = 1 - V[E(Y|X_~i)] / V(Y)
795            let mut st_i = Array1::zeros(n_states * n_time);
796            for j in 0..(n_states * n_time) {
797                let mut sum = F::zero();
798                for sample in 0..self.n_samples {
799                    sum += y_b[[sample, j]] * (y_c_i[[sample, j]] - y_a[[sample, j]]);
800                }
801                let v_not_i = sum / F::from(self.n_samples).unwrap();
802                st_i[j] = F::one() - v_not_i / var_y[j];
803            }
804            total_indices.insert(param_idx, st_i);
805        }
806
807        Ok((first_order_indices, total_indices))
808    }
809}
810
811/// Morris screening method for parameter sensitivity
812pub struct MorrisScreening<F: IntegrateFloat> {
813    /// Number of parameters
814    n_params: usize,
815    /// Number of trajectories
816    n_trajectories: usize,
817    /// Step size
818    delta: F,
819    /// Parameter bounds
820    param_bounds: Vec<(F, F)>,
821    /// Grid levels
822    grid_levels: usize,
823}
824
825impl<F: IntegrateFloat> MorrisScreening<F> {
826    /// Create a new Morris screening analyzer
827    pub fn new(param_bounds: Vec<(F, F)>, n_trajectories: usize, delta: F) -> Self {
828        MorrisScreening {
829            n_params: param_bounds.len(),
830            n_trajectories,
831            delta,
832            param_bounds,
833            grid_levels: 4,
834        }
835    }
836
837    /// Create a new Morris screening analysis (legacy compatibility)
838    pub fn new_simple(n_trajectories: usize, param_bounds: Vec<(F, F)>) -> Self {
839        MorrisScreening {
840            n_params: param_bounds.len(),
841            n_trajectories,
842            delta: F::from(0.1).unwrap(),
843            param_bounds,
844            grid_levels: 4,
845        }
846    }
847
848    /// Set number of grid levels
849    pub fn with_grid_levels(mut self, levels: usize) -> Self {
850        self.grid_levels = levels;
851        self
852    }
853
854    /// Generate Morris trajectories
855    pub fn generate_trajectories(&self) -> Vec<Array2<F>> {
856        use scirs2_core::random::seq::SliceRandom;
857        let mut rng = scirs2_core::random::rng();
858
859        let mut trajectories = Vec::new();
860
861        for _ in 0..self.n_trajectories {
862            let mut trajectory = Array2::zeros((self.n_params + 1, self.n_params));
863
864            // Generate base point
865            for j in 0..self.n_params {
866                let (lower, upper) = self.param_bounds[j];
867                let u: f64 = rng.random();
868                trajectory[[0, j]] = lower + (upper - lower) * F::from(u).unwrap();
869            }
870
871            // Generate trajectory by changing one parameter at a time
872            let mut param_order: Vec<usize> = (0..self.n_params).collect();
873            param_order.shuffle(&mut rng);
874
875            for (i, &param_idx) in param_order.iter().enumerate() {
876                // Copy previous point
877                for j in 0..self.n_params {
878                    trajectory[[i + 1, j]] = trajectory[[i, j]];
879                }
880
881                // Change one parameter
882                let (lower, upper) = self.param_bounds[param_idx];
883                let range = upper - lower;
884                let direction = if rng.random::<bool>() {
885                    F::one()
886                } else {
887                    -F::one()
888                };
889                trajectory[[i + 1, param_idx]] += direction * self.delta * range;
890
891                // Ensure within bounds
892                trajectory[[i + 1, param_idx]] =
893                    trajectory[[i + 1, param_idx]].max(lower).min(upper);
894            }
895
896            trajectories.push(trajectory);
897        }
898
899        trajectories
900    }
901
902    /// Compute elementary effects from pre-generated trajectories
903    pub fn compute_effects<Func>(
904        &self,
905        model: Func,
906        trajectories: &[Array2<F>],
907    ) -> IntegrateResult<(Array1<F>, Array1<F>)>
908    where
909        Func: Fn(ArrayView1<F>) -> IntegrateResult<F>,
910    {
911        let mut elementary_effects = vec![Vec::new(); self.n_params];
912
913        for trajectory in trajectories {
914            for i in 0..self.n_params {
915                let y_before = model(trajectory.row(i))?;
916                let y_after = model(trajectory.row(i + 1))?;
917
918                // Find which parameter changed
919                for j in 0..self.n_params {
920                    if (trajectory[[i + 1, j]] - trajectory[[i, j]]).abs() > F::epsilon() {
921                        let effect =
922                            (y_after - y_before) / (trajectory[[i + 1, j]] - trajectory[[i, j]]);
923                        elementary_effects[j].push(effect);
924                        break;
925                    }
926                }
927            }
928        }
929
930        // Compute mean and standard deviation of elementary effects
931        let mut mu = Array1::zeros(self.n_params);
932        let mut sigma = Array1::zeros(self.n_params);
933
934        for j in 0..self.n_params {
935            let effects = &elementary_effects[j];
936            let n = F::from(effects.len()).unwrap();
937
938            // Mean of absolute effects (mu*)
939            let sum_abs: F = effects
940                .iter()
941                .map(|&e| e.abs())
942                .fold(F::zero(), |acc, x| acc + x);
943            mu[j] = sum_abs / n;
944
945            // Standard deviation
946            let mean: F = effects.iter().fold(F::zero(), |acc, &x| acc + x) / n;
947            let variance: F = effects
948                .iter()
949                .map(|&e| (e - mean) * (e - mean))
950                .fold(F::zero(), |acc, x| acc + x)
951                / n;
952            sigma[j] = variance.sqrt();
953        }
954
955        Ok((mu, sigma))
956    }
957
958    /// Compute elementary effects with parameter names (legacy compatibility)
959    pub fn compute_effects_named<Func>(
960        &self,
961        model: Func,
962        param_names: Vec<String>,
963    ) -> IntegrateResult<HashMap<String, (F, F)>>
964    where
965        Func: Fn(ArrayView1<F>) -> IntegrateResult<F>,
966    {
967        let n_params = self.param_bounds.len();
968        if param_names.len() != n_params {
969            return Err(IntegrateError::ValueError(
970                "Number of parameter _names must match bounds".to_string(),
971            ));
972        }
973
974        let mut effects = HashMap::new();
975        for name in &param_names {
976            effects.insert(name.clone(), (F::zero(), F::zero()));
977        }
978
979        // Generate trajectories and compute elementary effects
980        for _ in 0..self.n_trajectories {
981            let trajectory = self.generate_trajectory_legacy(n_params);
982
983            for i in 0..n_params {
984                let p1 = trajectory[i].view();
985                let p2 = trajectory[i + 1].view();
986
987                let y1 = model(p1)?;
988                let y2 = model(p2)?;
989
990                // Find which parameter changed
991                let mut changed_param = None;
992                for j in 0..n_params {
993                    if (p1[j] - p2[j]).abs() > F::epsilon() {
994                        changed_param = Some(j);
995                        break;
996                    }
997                }
998
999                if let Some(j) = changed_param {
1000                    let delta = p2[j] - p1[j];
1001                    let ee = (y2 - y1) / delta;
1002
1003                    let name = &param_names[j];
1004                    let (sum, sum_sq) = effects.get_mut(name).unwrap();
1005                    *sum += ee;
1006                    *sum_sq += ee * ee;
1007                }
1008            }
1009        }
1010
1011        // Compute mean and standard deviation
1012        let n_traj = F::from(self.n_trajectories).unwrap();
1013        let mut results = HashMap::new();
1014
1015        for (name, (sum, sum_sq)) in effects {
1016            let mu = sum / n_traj;
1017            let sigma = ((sum_sq / n_traj) - mu * mu).sqrt();
1018            results.insert(name, (mu.abs(), sigma));
1019        }
1020
1021        Ok(results)
1022    }
1023
1024    /// Generate a Morris trajectory (legacy compatibility)
1025    fn generate_trajectory_legacy(&self, n_params: usize) -> Vec<Array1<F>> {
1026        // Simplified trajectory generation
1027        let mut trajectory = Vec::new();
1028        let mut current = Array1::zeros(n_params);
1029
1030        // Random starting point
1031        for i in 0..n_params {
1032            let (low, high) = self.param_bounds[i];
1033            current[i] = low + (high - low) * F::from(0.5).unwrap();
1034        }
1035        trajectory.push(current.clone());
1036
1037        // Change one parameter at a time
1038        for i in 0..n_params {
1039            let (low, high) = self.param_bounds[i];
1040            let delta = (high - low) / F::from((self.grid_levels - 1) as f64).unwrap();
1041            current[i] += delta;
1042            trajectory.push(current.clone());
1043        }
1044
1045        trajectory
1046    }
1047}
1048
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_sensitivity() {
        // Simple linear ODE: dy/dt = -a*y
        let system =
            |_t: f64, y: ArrayView1<f64>, p: ArrayView1<f64>| Array1::from_vec(vec![-p[0] * y[0]]);

        let param_names = vec!["a".to_string()];
        let nominal_params = Array1::from_vec(vec![1.0]);
        let y0 = Array1::from_vec(vec![1.0]);
        let t_span = (0.0, 1.0);

        let analysis = compute_sensitivities(
            system,
            |_| Array1::from_vec(vec![1.0]),
            param_names,
            nominal_params.view(),
            y0.view(),
            t_span,
            None,
            None,
        );

        // Should complete without errors
        assert!(analysis.is_ok());
    }

    #[test]
    fn test_morris_trajectories_shape_and_bounds() {
        let bounds = vec![(0.0, 1.0), (-2.0, 3.0)];
        let morris = MorrisScreening::new_simple(8, bounds.clone());
        let trajectories = morris.generate_trajectories();

        assert_eq!(trajectories.len(), 8);
        for trajectory in &trajectories {
            assert_eq!(trajectory.dim(), (3, 2));
            for row in trajectory.rows() {
                for (value, &(lower, upper)) in row.iter().zip(&bounds) {
                    assert!(*value >= lower && *value <= upper);
                }
            }
        }
    }

    #[test]
    fn test_morris_linear_model_effects() {
        // y = 2 * p0: only p0 is influential and every elementary effect is 2.
        let morris = MorrisScreening::new_simple(10, vec![(0.0, 1.0), (0.0, 1.0)]);
        let trajectories = morris.generate_trajectories();
        let (mu, sigma) = morris
            .compute_effects(|p: ArrayView1<f64>| Ok(2.0 * p[0]), &trajectories)
            .unwrap();

        assert!((mu[0] - 2.0).abs() < 1e-8);
        assert!(mu[1].abs() < 1e-12);
        assert!(sigma[0] < 1e-6);
        assert!(sigma[1].abs() < 1e-12);
    }

    #[test]
    fn test_morris_named_effects() {
        let morris = MorrisScreening::new_simple(3, vec![(0.0, 1.0), (0.0, 1.0)]);
        let effects = morris
            .compute_effects_named(
                |p: ArrayView1<f64>| Ok(2.0 * p[0]),
                vec!["a".to_string(), "b".to_string()],
            )
            .unwrap();

        assert!((effects["a"].0 - 2.0).abs() < 1e-8);
        assert!(effects["b"].0.abs() < 1e-12);

        // One name for two bounds is rejected.
        assert!(morris
            .compute_effects_named(|p: ArrayView1<f64>| Ok(p[0]), vec!["a".to_string()])
            .is_err());
    }
}