1use scirs2_core::random::rngs::StdRng;
14use scirs2_core::random::{Rng, SeedableRng};
15use sklears_core::types::Float;
16use std::collections::HashMap;
17
/// Settings that control how [`SHAPAnalyzer`] estimates SHAP-style
/// attributions for hyperparameters.
#[derive(Clone, Debug)]
pub struct SHAPConfig {
    /// Number of coalitions drawn by the Kernel SHAP estimator.
    pub n_samples: usize,
    /// Optional cap applied by the sampled estimator. Despite the name, the
    /// analyzer uses it as an upper bound on the number of coalitions sampled
    /// per parameter; `None` falls back to a built-in cap of 1000.
    pub max_coalition_size: Option<usize>,
    /// Selects the Kernel SHAP estimator when `true`, the sampled
    /// marginal-contribution estimator otherwise.
    pub use_kernel_shap: bool,
    /// Size of the background set.
    /// NOTE(review): not read anywhere in the visible analyzer code.
    pub background_size: usize,
    /// Seed for the analyzer's random number generator; `None` means the
    /// analyzer seeds itself with 42.
    pub random_state: Option<u64>,
}
35
36impl Default for SHAPConfig {
37 fn default() -> Self {
38 Self {
39 n_samples: 1000,
40 max_coalition_size: None,
41 use_kernel_shap: true,
42 background_size: 100,
43 random_state: None,
44 }
45 }
46}
47
48pub struct SHAPAnalyzer {
50 config: SHAPConfig,
51 rng: StdRng,
52}
53
54impl SHAPAnalyzer {
55 pub fn new(config: SHAPConfig) -> Self {
56 let rng = StdRng::seed_from_u64(config.random_state.unwrap_or(42));
57 Self { config, rng }
58 }
59
60 pub fn compute_shap_values(
62 &mut self,
63 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
64 parameter_space: &HashMap<String, (Float, Float)>,
65 reference_config: &HashMap<String, Float>,
66 ) -> Result<SHAPResult, Box<dyn std::error::Error>> {
67 let param_names: Vec<_> = parameter_space.keys().cloned().collect();
68 let n_params = param_names.len();
69
70 if self.config.use_kernel_shap {
71 self.compute_kernel_shap(
72 evaluation_fn,
73 parameter_space,
74 reference_config,
75 ¶m_names,
76 )
77 } else {
78 self.compute_exact_shap(
79 evaluation_fn,
80 parameter_space,
81 reference_config,
82 ¶m_names,
83 n_params,
84 )
85 }
86 }
87
88 fn compute_exact_shap(
90 &mut self,
91 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
92 parameter_space: &HashMap<String, (Float, Float)>,
93 reference_config: &HashMap<String, Float>,
94 param_names: &[String],
95 n_params: usize,
96 ) -> Result<SHAPResult, Box<dyn std::error::Error>> {
97 let mut shap_values = HashMap::new();
98 let baseline_performance = evaluation_fn(reference_config);
99
100 for (i, param_name) in param_names.iter().enumerate() {
102 let mut marginal_contributions = Vec::new();
103
104 let max_coalitions = 2_usize.pow(n_params as u32 - 1);
106 let n_coalitions = if let Some(max_size) = self.config.max_coalition_size {
107 max_coalitions.min(max_size)
108 } else {
109 max_coalitions.min(1000) };
111
112 for _ in 0..n_coalitions {
113 let coalition = self.sample_coalition(n_params, i);
115
116 let perf_with = self.evaluate_coalition(
118 evaluation_fn,
119 parameter_space,
120 reference_config,
121 param_names,
122 &coalition,
123 Some(i),
124 )?;
125
126 let perf_without = self.evaluate_coalition(
127 evaluation_fn,
128 parameter_space,
129 reference_config,
130 param_names,
131 &coalition,
132 None,
133 )?;
134
135 marginal_contributions.push(perf_with - perf_without);
136 }
137
138 let shap_value = if marginal_contributions.is_empty() {
140 0.0
141 } else {
142 marginal_contributions.iter().sum::<Float>() / marginal_contributions.len() as Float
143 };
144
145 shap_values.insert(param_name.clone(), shap_value);
146 }
147
148 let rankings = self.rank_parameters(&shap_values);
149
150 Ok(SHAPResult {
151 shap_values,
152 baseline_performance,
153 parameter_rankings: rankings,
154 interaction_effects: HashMap::new(), })
156 }
157
158 fn compute_kernel_shap(
160 &mut self,
161 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
162 parameter_space: &HashMap<String, (Float, Float)>,
163 reference_config: &HashMap<String, Float>,
164 param_names: &[String],
165 ) -> Result<SHAPResult, Box<dyn std::error::Error>> {
166 let n_params = param_names.len();
167 let mut shap_values = HashMap::new();
168 let baseline_performance = evaluation_fn(reference_config);
169
170 let mut samples = Vec::new();
172 let mut performances = Vec::new();
173 let mut weights = Vec::new();
174
175 for _ in 0..self.config.n_samples {
176 let coalition_size = self.rng.gen_range(0..=n_params);
178 let coalition = self.sample_coalition_of_size(n_params, coalition_size);
179
180 let mut perturbed = reference_config.clone();
182 for (idx, &include) in coalition.iter().enumerate() {
183 if !include {
184 let param_name = ¶m_names[idx];
186 if let Some(&(min, max)) = parameter_space.get(param_name) {
187 let random_value = self.rng.gen_range(min..max);
188 perturbed.insert(param_name.clone(), random_value);
189 }
190 }
191 }
192
193 let perf = evaluation_fn(&perturbed);
194 let weight = self.shapley_kernel_weight(coalition_size, n_params);
195
196 samples.push(coalition);
197 performances.push(perf);
198 weights.push(weight);
199 }
200
201 let shap_coefficients =
203 self.solve_weighted_least_squares(&samples, &performances, &weights)?;
204
205 for (i, param_name) in param_names.iter().enumerate() {
206 shap_values.insert(
207 param_name.clone(),
208 shap_coefficients.get(i).cloned().unwrap_or(0.0),
209 );
210 }
211
212 let rankings = self.rank_parameters(&shap_values);
213
214 Ok(SHAPResult {
215 shap_values,
216 baseline_performance,
217 parameter_rankings: rankings,
218 interaction_effects: HashMap::new(),
219 })
220 }
221
222 fn sample_coalition(&mut self, n_params: usize, exclude_idx: usize) -> Vec<bool> {
225 (0..n_params)
226 .map(|i| i != exclude_idx && self.rng.gen_bool(0.5))
227 .collect()
228 }
229
230 fn sample_coalition_of_size(&mut self, n_params: usize, size: usize) -> Vec<bool> {
231 let mut coalition = vec![false; n_params];
232 let mut indices: Vec<_> = (0..n_params).collect();
233
234 for i in (1..n_params).rev() {
236 let j = self.rng.gen_range(0..=i);
237 indices.swap(i, j);
238 }
239
240 for &idx in indices.iter().take(size) {
241 coalition[idx] = true;
242 }
243
244 coalition
245 }
246
247 fn evaluate_coalition(
248 &mut self,
249 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
250 parameter_space: &HashMap<String, (Float, Float)>,
251 reference_config: &HashMap<String, Float>,
252 param_names: &[String],
253 coalition: &[bool],
254 include_idx: Option<usize>,
255 ) -> Result<Float, Box<dyn std::error::Error>> {
256 let mut config = reference_config.clone();
257
258 for (i, param_name) in param_names.iter().enumerate() {
259 let should_include = coalition[i] || include_idx == Some(i);
260 if !should_include {
261 if let Some(&(min, max)) = parameter_space.get(param_name) {
263 let random_value = self.rng.gen_range(min..max);
264 config.insert(param_name.clone(), random_value);
265 }
266 }
267 }
268
269 Ok(evaluation_fn(&config))
270 }
271
272 fn shapley_kernel_weight(&self, coalition_size: usize, n_params: usize) -> Float {
273 if coalition_size == 0 || coalition_size == n_params {
274 1e10 } else {
276 let numerator = (n_params - 1) as Float;
277 let denominator = (coalition_size * (n_params - coalition_size)) as Float;
278 numerator / denominator
279 }
280 }
281
282 fn solve_weighted_least_squares(
283 &self,
284 samples: &[Vec<bool>],
285 performances: &[Float],
286 weights: &[Float],
287 ) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
288 if samples.is_empty() {
289 return Ok(Vec::new());
290 }
291
292 let n_params = samples[0].len();
293
294 let mut coefficients = vec![0.0; n_params];
296
297 for param_idx in 0..n_params {
298 let mut weighted_sum = 0.0;
299 let mut total_weight = 0.0;
300
301 for (i, sample) in samples.iter().enumerate() {
302 if sample[param_idx] {
303 weighted_sum += performances[i] * weights[i];
304 total_weight += weights[i];
305 }
306 }
307
308 if total_weight > 0.0 {
309 coefficients[param_idx] = weighted_sum / total_weight;
310 }
311 }
312
313 Ok(coefficients)
314 }
315
316 fn rank_parameters(&self, shap_values: &HashMap<String, Float>) -> Vec<(String, Float)> {
317 let mut ranked: Vec<_> = shap_values
318 .iter()
319 .map(|(name, &value)| (name.clone(), value.abs()))
320 .collect();
321 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
322 ranked
323 }
324}
325
326#[derive(Debug, Clone)]
328pub struct SHAPResult {
329 pub shap_values: HashMap<String, Float>,
330 pub baseline_performance: Float,
331 pub parameter_rankings: Vec<(String, Float)>,
332 pub interaction_effects: HashMap<(String, String), Float>,
333}
334
/// Settings for [`FANOVAAnalyzer`].
///
/// NOTE(review): none of these fields is read by the analyzer code visible in
/// this file; the descriptions below follow the field names only.
#[derive(Clone, Debug)]
pub struct FANOVAConfig {
    /// Number of trees (presumably for a forest surrogate).
    pub n_trees: usize,
    /// Maximum tree depth.
    pub max_depth: usize,
    /// Minimum number of samples required to split a node.
    pub min_samples_split: usize,
    /// Number of samples.
    pub n_samples: usize,
    /// Optional random seed.
    pub random_state: Option<u64>,
}
348
349impl Default for FANOVAConfig {
350 fn default() -> Self {
351 Self {
352 n_trees: 16,
353 max_depth: 6,
354 min_samples_split: 10,
355 n_samples: 1000,
356 random_state: None,
357 }
358 }
359}
360
361pub struct FANOVAAnalyzer {
363 config: FANOVAConfig,
364}
365
366impl FANOVAAnalyzer {
367 pub fn new(config: FANOVAConfig) -> Self {
368 Self { config }
369 }
370
371 pub fn analyze(
373 &self,
374 evaluation_data: &[(HashMap<String, Float>, Float)],
375 param_names: &[String],
376 ) -> Result<FANOVAResult, Box<dyn std::error::Error>> {
377 let performances: Vec<_> = evaluation_data.iter().map(|(_, perf)| *perf).collect();
379 let mean_performance = performances.iter().sum::<Float>() / performances.len() as Float;
380 let total_variance = performances
381 .iter()
382 .map(|&p| (p - mean_performance).powi(2))
383 .sum::<Float>()
384 / performances.len() as Float;
385
386 let mut main_effects = HashMap::new();
388 let mut interaction_effects = HashMap::new();
389
390 for param_name in param_names {
391 let variance_contribution =
392 self.compute_variance_contribution(evaluation_data, param_name, mean_performance)?;
393 let importance = variance_contribution / total_variance;
394 main_effects.insert(param_name.clone(), importance);
395 }
396
397 for i in 0..param_names.len() {
399 for j in (i + 1)..param_names.len() {
400 let interaction_variance = self.compute_interaction_variance(
401 evaluation_data,
402 ¶m_names[i],
403 ¶m_names[j],
404 mean_performance,
405 )?;
406 let importance = interaction_variance / total_variance;
407 interaction_effects
408 .insert((param_names[i].clone(), param_names[j].clone()), importance);
409 }
410 }
411
412 let rankings = self.rank_by_importance(&main_effects);
413
414 Ok(FANOVAResult {
415 main_effects,
416 interaction_effects,
417 total_variance,
418 parameter_rankings: rankings,
419 })
420 }
421
422 fn compute_variance_contribution(
423 &self,
424 data: &[(HashMap<String, Float>, Float)],
425 param_name: &str,
426 mean: Float,
427 ) -> Result<Float, Box<dyn std::error::Error>> {
428 let mut groups: HashMap<String, Vec<Float>> = HashMap::new();
430
431 for (params, perf) in data {
432 if let Some(&value) = params.get(param_name) {
433 let bin = format!("{:.2}", value);
435 groups.entry(bin).or_default().push(*perf);
436 }
437 }
438
439 let mut variance_explained = 0.0;
441 for performances in groups.values() {
442 if performances.is_empty() {
443 continue;
444 }
445 let group_mean = performances.iter().sum::<Float>() / performances.len() as Float;
446 let group_size = performances.len() as Float;
447 variance_explained += group_size * (group_mean - mean).powi(2);
448 }
449
450 Ok(variance_explained / data.len() as Float)
451 }
452
453 fn compute_interaction_variance(
454 &self,
455 data: &[(HashMap<String, Float>, Float)],
456 param1: &str,
457 param2: &str,
458 mean: Float,
459 ) -> Result<Float, Box<dyn std::error::Error>> {
460 let mut groups: HashMap<(String, String), Vec<Float>> = HashMap::new();
462
463 for (params, perf) in data {
464 if let (Some(&v1), Some(&v2)) = (params.get(param1), params.get(param2)) {
465 let bin1 = format!("{:.2}", v1);
466 let bin2 = format!("{:.2}", v2);
467 groups.entry((bin1, bin2)).or_default().push(*perf);
468 }
469 }
470
471 let var1 = self.compute_variance_contribution(data, param1, mean)?;
473 let var2 = self.compute_variance_contribution(data, param2, mean)?;
474
475 let mut joint_variance = 0.0;
476 for performances in groups.values() {
477 if performances.is_empty() {
478 continue;
479 }
480 let group_mean = performances.iter().sum::<Float>() / performances.len() as Float;
481 let group_size = performances.len() as Float;
482 joint_variance += group_size * (group_mean - mean).powi(2);
483 }
484 joint_variance /= data.len() as Float;
485
486 Ok((joint_variance - var1 - var2).max(0.0))
488 }
489
490 fn rank_by_importance(&self, effects: &HashMap<String, Float>) -> Vec<(String, Float)> {
491 let mut ranked: Vec<_> = effects.iter().map(|(k, &v)| (k.clone(), v)).collect();
492 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
493 ranked
494 }
495}
496
497#[derive(Debug, Clone)]
499pub struct FANOVAResult {
500 pub main_effects: HashMap<String, Float>,
501 pub interaction_effects: HashMap<(String, String), Float>,
502 pub total_variance: Float,
503 pub parameter_rankings: Vec<(String, Float)>,
504}
505
506#[derive(Debug, Clone)]
512pub struct SensitivityConfig {
513 pub n_trajectories: usize,
515 pub n_levels: usize,
517 pub perturbation_delta: Float,
519 pub random_state: Option<u64>,
520}
521
522impl Default for SensitivityConfig {
523 fn default() -> Self {
524 Self {
525 n_trajectories: 10,
526 n_levels: 4,
527 perturbation_delta: 0.01,
528 random_state: None,
529 }
530 }
531}
532
533pub struct SensitivityAnalyzer {
535 config: SensitivityConfig,
536 rng: StdRng,
537}
538
539impl SensitivityAnalyzer {
540 pub fn new(config: SensitivityConfig) -> Self {
541 let rng = StdRng::seed_from_u64(config.random_state.unwrap_or(42));
542 Self { config, rng }
543 }
544
545 pub fn morris_analysis(
547 &mut self,
548 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
549 parameter_space: &HashMap<String, (Float, Float)>,
550 base_config: &HashMap<String, Float>,
551 ) -> Result<SensitivityResult, Box<dyn std::error::Error>> {
552 let param_names: Vec<_> = parameter_space.keys().cloned().collect();
553 let mut elementary_effects: HashMap<String, Vec<Float>> = HashMap::new();
554
555 for _ in 0..self.config.n_trajectories {
556 let mut current = base_config.clone();
558
559 for param_name in ¶m_names {
560 if let Some(&(min, max)) = parameter_space.get(param_name) {
561 let original_value = current.get(param_name).cloned().unwrap_or(min);
563 let delta = self.config.perturbation_delta * (max - min);
564
565 let perturbed_value = (original_value + delta).min(max);
566 let mut perturbed = current.clone();
567 perturbed.insert(param_name.clone(), perturbed_value);
568
569 let f_original = evaluation_fn(¤t);
571 let f_perturbed = evaluation_fn(&perturbed);
572 let effect = (f_perturbed - f_original) / delta;
573
574 elementary_effects
575 .entry(param_name.clone())
576 .or_default()
577 .push(effect);
578
579 current = perturbed;
580 }
581 }
582 }
583
584 let mut sensitivities = HashMap::new();
586 let interactions = HashMap::new();
587
588 for (param_name, effects) in &elementary_effects {
589 let mean = effects.iter().sum::<Float>() / effects.len() as Float;
590 let variance =
591 effects.iter().map(|&e| (e - mean).powi(2)).sum::<Float>() / effects.len() as Float;
592 let std_dev = variance.sqrt();
593
594 sensitivities.insert(
595 param_name.clone(),
596 ParameterSensitivity {
597 mean_effect: mean.abs(),
598 std_effect: std_dev,
599 mu_star: mean.abs(), sigma: std_dev,
601 },
602 );
603 }
604
605 let rankings = self.rank_sensitivities(&sensitivities);
606
607 Ok(SensitivityResult {
608 sensitivities,
609 interactions,
610 rankings,
611 })
612 }
613
614 pub fn oat_analysis(
616 &mut self,
617 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
618 parameter_space: &HashMap<String, (Float, Float)>,
619 base_config: &HashMap<String, Float>,
620 ) -> Result<SensitivityResult, Box<dyn std::error::Error>> {
621 let param_names: Vec<_> = parameter_space.keys().cloned().collect();
622 let mut sensitivities = HashMap::new();
623 let baseline_perf = evaluation_fn(base_config);
624
625 for param_name in ¶m_names {
626 if let Some(&(min, max)) = parameter_space.get(param_name) {
627 let base_value = base_config
628 .get(param_name)
629 .cloned()
630 .unwrap_or((min + max) / 2.0);
631
632 let n_points = 5;
634 let mut effects = Vec::new();
635
636 for i in 0..n_points {
637 let alpha = i as Float / (n_points - 1) as Float;
638 let value = min + alpha * (max - min);
639
640 if (value - base_value).abs() < 1e-6 {
641 continue;
642 }
643
644 let mut perturbed = base_config.clone();
645 perturbed.insert(param_name.clone(), value);
646
647 let perf = evaluation_fn(&perturbed);
648 let effect = (perf - baseline_perf).abs() / (value - base_value).abs();
649 effects.push(effect);
650 }
651
652 if !effects.is_empty() {
653 let mean_effect = effects.iter().sum::<Float>() / effects.len() as Float;
654 let variance = effects
655 .iter()
656 .map(|&e| (e - mean_effect).powi(2))
657 .sum::<Float>()
658 / effects.len() as Float;
659
660 sensitivities.insert(
661 param_name.clone(),
662 ParameterSensitivity {
663 mean_effect,
664 std_effect: variance.sqrt(),
665 mu_star: mean_effect,
666 sigma: variance.sqrt(),
667 },
668 );
669 }
670 }
671 }
672
673 let rankings = self.rank_sensitivities(&sensitivities);
674
675 Ok(SensitivityResult {
676 sensitivities,
677 interactions: HashMap::new(),
678 rankings,
679 })
680 }
681
682 fn rank_sensitivities(
683 &self,
684 sensitivities: &HashMap<String, ParameterSensitivity>,
685 ) -> Vec<(String, Float)> {
686 let mut ranked: Vec<_> = sensitivities
687 .iter()
688 .map(|(name, sens)| (name.clone(), sens.mu_star))
689 .collect();
690 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
691 ranked
692 }
693}
694
695#[derive(Debug, Clone)]
697pub struct ParameterSensitivity {
698 pub mean_effect: Float,
699 pub std_effect: Float,
700 pub mu_star: Float,
701 pub sigma: Float,
702}
703
704#[derive(Debug, Clone)]
706pub struct SensitivityResult {
707 pub sensitivities: HashMap<String, ParameterSensitivity>,
708 pub interactions: HashMap<(String, String), Float>,
709 pub rankings: Vec<(String, Float)>,
710}
711
/// Settings for [`AblationAnalyzer`].
#[derive(Clone, Debug)]
pub struct AblationConfig {
    /// Number of iterations.
    /// NOTE(review): not read by the analyzer code visible in this file.
    pub n_iterations: usize,
    /// `true` ablates each parameter on its own; `false` makes the analyzer
    /// ablate parameters cumulatively in greedy order.
    pub leave_one_out: bool,
    /// Cumulative-mode flag.
    /// NOTE(review): not read by the visible analyzer, which switches modes on
    /// `leave_one_out` alone.
    pub cumulative: bool,
    /// Optional random seed.
    /// NOTE(review): not read by the analyzer code visible in this file.
    pub random_state: Option<u64>,
}
727
728impl Default for AblationConfig {
729 fn default() -> Self {
730 Self {
731 n_iterations: 10,
732 leave_one_out: true,
733 cumulative: false,
734 random_state: None,
735 }
736 }
737}
738
739pub struct AblationAnalyzer {
741 config: AblationConfig,
742}
743
744impl AblationAnalyzer {
745 pub fn new(config: AblationConfig) -> Self {
746 Self { config }
747 }
748
749 pub fn analyze(
751 &self,
752 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
753 parameter_space: &HashMap<String, (Float, Float)>,
754 base_config: &HashMap<String, Float>,
755 ) -> Result<AblationResult, Box<dyn std::error::Error>> {
756 let param_names: Vec<_> = parameter_space.keys().cloned().collect();
757 let baseline_performance = evaluation_fn(base_config);
758 let mut ablation_effects = HashMap::new();
759
760 if self.config.leave_one_out {
761 for param_name in ¶m_names {
763 let mut ablated = base_config.clone();
764
765 if let Some(&(min, max)) = parameter_space.get(param_name) {
767 ablated.insert(param_name.clone(), (min + max) / 2.0);
768 }
769
770 let ablated_perf = evaluation_fn(&ablated);
771 let effect = baseline_performance - ablated_perf;
772
773 ablation_effects.insert(param_name.clone(), effect);
774 }
775 } else {
776 let mut current = base_config.clone();
778 let mut remaining_params = param_names.clone();
779
780 while !remaining_params.is_empty() {
781 let mut best_param = None;
782 let mut best_effect = f64::NEG_INFINITY;
783
784 for param_name in &remaining_params {
785 let mut test_config = current.clone();
786 if let Some(&(min, max)) = parameter_space.get(param_name) {
787 test_config.insert(param_name.clone(), (min + max) / 2.0);
788 }
789
790 let perf = evaluation_fn(&test_config);
791 let effect = baseline_performance - perf;
792
793 if effect > best_effect {
794 best_effect = effect;
795 best_param = Some(param_name.clone());
796 }
797 }
798
799 if let Some(param) = best_param {
800 ablation_effects.insert(param.clone(), best_effect);
801 if let Some(&(min, max)) = parameter_space.get(¶m) {
802 current.insert(param.clone(), (min + max) / 2.0);
803 }
804 remaining_params.retain(|p| p != ¶m);
805 }
806 }
807 }
808
809 let rankings = self.rank_ablation_effects(&ablation_effects);
810
811 Ok(AblationResult {
812 ablation_effects,
813 baseline_performance,
814 parameter_rankings: rankings,
815 })
816 }
817
818 fn rank_ablation_effects(&self, effects: &HashMap<String, Float>) -> Vec<(String, Float)> {
819 let mut ranked: Vec<_> = effects.iter().map(|(k, &v)| (k.clone(), v.abs())).collect();
820 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
821 ranked
822 }
823}
824
825#[derive(Debug, Clone)]
827pub struct AblationResult {
828 pub ablation_effects: HashMap<String, Float>,
829 pub baseline_performance: Float,
830 pub parameter_rankings: Vec<(String, Float)>,
831}
832
833pub struct HyperparameterImportanceAnalyzer {
839 shap_analyzer: SHAPAnalyzer,
840 fanova_analyzer: FANOVAAnalyzer,
841 sensitivity_analyzer: SensitivityAnalyzer,
842 ablation_analyzer: AblationAnalyzer,
843}
844
845impl HyperparameterImportanceAnalyzer {
846 pub fn new(
847 shap_config: SHAPConfig,
848 fanova_config: FANOVAConfig,
849 sensitivity_config: SensitivityConfig,
850 ablation_config: AblationConfig,
851 ) -> Self {
852 Self {
853 shap_analyzer: SHAPAnalyzer::new(shap_config),
854 fanova_analyzer: FANOVAAnalyzer::new(fanova_config),
855 sensitivity_analyzer: SensitivityAnalyzer::new(sensitivity_config),
856 ablation_analyzer: AblationAnalyzer::new(ablation_config),
857 }
858 }
859
860 pub fn analyze_comprehensive(
862 &mut self,
863 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
864 parameter_space: &HashMap<String, (Float, Float)>,
865 base_config: &HashMap<String, Float>,
866 evaluation_data: &[(HashMap<String, Float>, Float)],
867 ) -> Result<ComprehensiveImportanceResult, Box<dyn std::error::Error>> {
868 let param_names: Vec<_> = parameter_space.keys().cloned().collect();
869
870 let shap_result =
872 self.shap_analyzer
873 .compute_shap_values(evaluation_fn, parameter_space, base_config)?;
874
875 let fanova_result = self
876 .fanova_analyzer
877 .analyze(evaluation_data, ¶m_names)?;
878
879 let sensitivity_result = self.sensitivity_analyzer.morris_analysis(
880 evaluation_fn,
881 parameter_space,
882 base_config,
883 )?;
884
885 let ablation_result =
886 self.ablation_analyzer
887 .analyze(evaluation_fn, parameter_space, base_config)?;
888
889 let aggregated_rankings = self.aggregate_rankings(
891 &shap_result.parameter_rankings,
892 &fanova_result.parameter_rankings,
893 &sensitivity_result.rankings,
894 &ablation_result.parameter_rankings,
895 );
896
897 Ok(ComprehensiveImportanceResult {
898 shap_result,
899 fanova_result,
900 sensitivity_result,
901 ablation_result,
902 aggregated_rankings,
903 })
904 }
905
906 fn aggregate_rankings(
907 &self,
908 shap: &[(String, Float)],
909 fanova: &[(String, Float)],
910 sensitivity: &[(String, Float)],
911 ablation: &[(String, Float)],
912 ) -> Vec<(String, Float)> {
913 let mut scores: HashMap<String, Vec<Float>> = HashMap::new();
914
915 for (param, value) in shap {
917 scores.entry(param.clone()).or_default().push(*value);
918 }
919 for (param, value) in fanova {
920 scores.entry(param.clone()).or_default().push(*value);
921 }
922 for (param, value) in sensitivity {
923 scores.entry(param.clone()).or_default().push(*value);
924 }
925 for (param, value) in ablation {
926 scores.entry(param.clone()).or_default().push(*value);
927 }
928
929 let mut aggregated: Vec<_> = scores
930 .iter()
931 .map(|(param, values)| {
932 let avg = values.iter().sum::<Float>() / values.len() as Float;
933 (param.clone(), avg)
934 })
935 .collect();
936
937 aggregated.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
938 aggregated
939 }
940}
941
942#[derive(Debug, Clone)]
944pub struct ComprehensiveImportanceResult {
945 pub shap_result: SHAPResult,
946 pub fanova_result: FANOVAResult,
947 pub sensitivity_result: SensitivityResult,
948 pub ablation_result: AblationResult,
949 pub aggregated_rankings: Vec<(String, Float)>,
950}
951
952pub fn compute_shap_importance(
958 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
959 parameter_space: &HashMap<String, (Float, Float)>,
960 reference_config: &HashMap<String, Float>,
961) -> Result<SHAPResult, Box<dyn std::error::Error>> {
962 let config = SHAPConfig::default();
963 let mut analyzer = SHAPAnalyzer::new(config);
964 analyzer.compute_shap_values(evaluation_fn, parameter_space, reference_config)
965}
966
967pub fn analyze_parameter_sensitivity(
969 evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
970 parameter_space: &HashMap<String, (Float, Float)>,
971 base_config: &HashMap<String, Float>,
972) -> Result<SensitivityResult, Box<dyn std::error::Error>> {
973 let config = SensitivityConfig::default();
974 let mut analyzer = SensitivityAnalyzer::new(config);
975 analyzer.morris_analysis(evaluation_fn, parameter_space, base_config)
976}
977
#[cfg(test)]
mod tests {
    use super::*;

    /// Default SHAP settings select Kernel SHAP with 1000 samples.
    #[test]
    fn test_shap_config() {
        let SHAPConfig {
            n_samples,
            use_kernel_shap,
            ..
        } = SHAPConfig::default();
        assert_eq!(n_samples, 1000);
        assert!(use_kernel_shap);
    }

    /// Default fANOVA settings use 16 trees of depth 6.
    #[test]
    fn test_fanova_config() {
        let FANOVAConfig {
            n_trees, max_depth, ..
        } = FANOVAConfig::default();
        assert_eq!(n_trees, 16);
        assert_eq!(max_depth, 6);
    }

    /// Default sensitivity settings use 10 trajectories on 4 levels.
    #[test]
    fn test_sensitivity_config() {
        let SensitivityConfig {
            n_trajectories,
            n_levels,
            ..
        } = SensitivityConfig::default();
        assert_eq!(n_trajectories, 10);
        assert_eq!(n_levels, 4);
    }

    /// Default ablation settings run 10 leave-one-out iterations.
    #[test]
    fn test_ablation_config() {
        let AblationConfig {
            n_iterations,
            leave_one_out,
            ..
        } = AblationConfig::default();
        assert_eq!(n_iterations, 10);
        assert!(leave_one_out);
    }

    /// The SHAP analyzer keeps the configuration it was built with.
    #[test]
    fn test_shap_analyzer_creation() {
        let analyzer = SHAPAnalyzer::new(SHAPConfig::default());
        assert_eq!(analyzer.config.n_samples, 1000);
    }

    /// The sensitivity analyzer keeps the configuration it was built with.
    #[test]
    fn test_sensitivity_analyzer_creation() {
        let analyzer = SensitivityAnalyzer::new(SensitivityConfig::default());
        assert_eq!(analyzer.config.n_trajectories, 10);
    }
}