sklears_model_selection/
hyperparameter_importance.rs

//! Hyperparameter Importance Analysis
//!
//! This module provides comprehensive hyperparameter importance analysis including:
//! - SHAP (SHapley Additive exPlanations) values for hyperparameters
//! - Functional ANOVA (fANOVA) for parameter analysis
//! - Interaction effect analysis
//! - Parameter sensitivity analysis
//! - Ablation studies for parameters
//!
//! These techniques help understand which hyperparameters are most important and how they
//! interact, enabling better hyperparameter optimization strategies.
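//!
//! # Example
//!
//! A minimal usage sketch (not compiled: it assumes `Float = f64`, and the objective and
//! parameter bounds below are purely illustrative):
//!
//! ```ignore
//! use std::collections::HashMap;
//!
//! // Toy objective: only `learning_rate` really matters here.
//! let evaluate = |params: &HashMap<String, f64>| -> f64 {
//!     1.0 - (params["learning_rate"] - 0.1).abs()
//! };
//!
//! let mut space = HashMap::new();
//! space.insert("learning_rate".to_string(), (0.001, 1.0));
//! space.insert("momentum".to_string(), (0.0, 0.99));
//!
//! let mut base = HashMap::new();
//! base.insert("learning_rate".to_string(), 0.1);
//! base.insert("momentum".to_string(), 0.9);
//!
//! let result = compute_shap_importance(&evaluate, &space, &base)?;
//! for (name, importance) in &result.parameter_rankings {
//!     println!("{name}: {importance:.4}");
//! }
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```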

use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{Rng, SeedableRng};
use sklears_core::types::Float;
use std::collections::HashMap;

// ============================================================================
// SHAP Values for Hyperparameters
// ============================================================================

/// Configuration for SHAP value computation
#[derive(Debug, Clone)]
pub struct SHAPConfig {
    /// Number of samples for SHAP estimation
    pub n_samples: usize,
    /// Maximum number of coalitions to sample per parameter in the exact estimator
    /// (defaults to 1000 when unset)
    pub max_coalition_size: Option<usize>,
    /// Whether to use KernelSHAP approximation
    pub use_kernel_shap: bool,
    /// Background dataset size (reserved for TreeSHAP-style estimators; unused by the
    /// current implementation)
    pub background_size: usize,
    /// Random seed for reproducible sampling
    pub random_state: Option<u64>,
}

impl Default for SHAPConfig {
    fn default() -> Self {
        Self {
            n_samples: 1000,
            max_coalition_size: None,
            use_kernel_shap: true,
            background_size: 100,
            random_state: None,
        }
    }
}

/// SHAP value analyzer for hyperparameters
pub struct SHAPAnalyzer {
    config: SHAPConfig,
    rng: StdRng,
}

impl SHAPAnalyzer {
    pub fn new(config: SHAPConfig) -> Self {
        let rng = StdRng::seed_from_u64(config.random_state.unwrap_or(42));
        Self { config, rng }
    }

    /// Compute SHAP values for hyperparameters
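    ///
    /// Each value is the estimated average marginal contribution of a parameter to the
    /// objective, relative to `reference_config`. Depending on
    /// `SHAPConfig::use_kernel_shap`, either a KernelSHAP-style weighted regression or a
    /// sampled-coalition estimate is used.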
    pub fn compute_shap_values(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        reference_config: &HashMap<String, Float>,
    ) -> Result<SHAPResult, Box<dyn std::error::Error>> {
        let param_names: Vec<_> = parameter_space.keys().cloned().collect();
        let n_params = param_names.len();

        if self.config.use_kernel_shap {
            self.compute_kernel_shap(
                evaluation_fn,
                parameter_space,
                reference_config,
                &param_names,
            )
        } else {
            self.compute_exact_shap(
                evaluation_fn,
                parameter_space,
                reference_config,
                &param_names,
                n_params,
            )
        }
    }

    /// Estimate SHAP values by averaging marginal contributions over sampled coalitions
    /// (a capped approximation of full power-set enumeration; the cap is
    /// `max_coalition_size`, or 1000 when unset)
    fn compute_exact_shap(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        reference_config: &HashMap<String, Float>,
        param_names: &[String],
        n_params: usize,
    ) -> Result<SHAPResult, Box<dyn std::error::Error>> {
        let mut shap_values = HashMap::new();
        let baseline_performance = evaluation_fn(reference_config);

        // For each parameter, compute its SHAP value
        for (i, param_name) in param_names.iter().enumerate() {
            let mut marginal_contributions = Vec::new();

            // Size of the power set over the remaining parameters, bounded to avoid
            // overflow for large parameter spaces
            let max_coalitions = 1_usize << n_params.saturating_sub(1).min(20);
            let n_coalitions = if let Some(max_size) = self.config.max_coalition_size {
                max_coalitions.min(max_size)
            } else {
                max_coalitions.min(1000) // Limit for practicality
            };

            for _ in 0..n_coalitions {
                // Create random coalition
                let coalition = self.sample_coalition(n_params, i);

                // Evaluate with and without the parameter
                let perf_with = self.evaluate_coalition(
                    evaluation_fn,
                    parameter_space,
                    reference_config,
                    param_names,
                    &coalition,
                    Some(i),
                )?;

                let perf_without = self.evaluate_coalition(
                    evaluation_fn,
                    parameter_space,
                    reference_config,
                    param_names,
                    &coalition,
                    None,
                )?;

                marginal_contributions.push(perf_with - perf_without);
            }

            // Average marginal contributions
            let shap_value = if marginal_contributions.is_empty() {
                0.0
            } else {
                marginal_contributions.iter().sum::<Float>() / marginal_contributions.len() as Float
            };

            shap_values.insert(param_name.clone(), shap_value);
        }

        let rankings = self.rank_parameters(&shap_values);

        Ok(SHAPResult {
            shap_values,
            baseline_performance,
            parameter_rankings: rankings,
            interaction_effects: HashMap::new(), // Interaction effects are not estimated here
        })
    }

    /// Compute KernelSHAP approximation
    fn compute_kernel_shap(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        reference_config: &HashMap<String, Float>,
        param_names: &[String],
    ) -> Result<SHAPResult, Box<dyn std::error::Error>> {
        let n_params = param_names.len();
        let mut shap_values = HashMap::new();
        let baseline_performance = evaluation_fn(reference_config);

        // KernelSHAP uses weighted linear regression
        let mut samples = Vec::new();
        let mut performances = Vec::new();
        let mut weights = Vec::new();

        for _ in 0..self.config.n_samples {
            // Sample a coalition
            let coalition_size = self.rng.gen_range(0..=n_params);
            let coalition = self.sample_coalition_of_size(n_params, coalition_size);

            // Create perturbed configuration
            let mut perturbed = reference_config.clone();
            for (idx, &include) in coalition.iter().enumerate() {
                if !include {
                    // Replace with random value from parameter space
                    let param_name = &param_names[idx];
                    if let Some(&(min, max)) = parameter_space.get(param_name) {
                        let random_value = self.rng.gen_range(min..max);
                        perturbed.insert(param_name.clone(), random_value);
                    }
                }
            }

            let perf = evaluation_fn(&perturbed);
            let weight = self.shapley_kernel_weight(coalition_size, n_params);

            samples.push(coalition);
            performances.push(perf);
            weights.push(weight);
        }

        // Solve weighted least squares to get SHAP values
        let shap_coefficients =
            self.solve_weighted_least_squares(&samples, &performances, &weights)?;

        for (i, param_name) in param_names.iter().enumerate() {
            shap_values.insert(
                param_name.clone(),
                shap_coefficients.get(i).cloned().unwrap_or(0.0),
            );
        }

        let rankings = self.rank_parameters(&shap_values);

        Ok(SHAPResult {
            shap_values,
            baseline_performance,
            parameter_rankings: rankings,
            interaction_effects: HashMap::new(),
        })
    }

    // Helper methods

    fn sample_coalition(&mut self, n_params: usize, exclude_idx: usize) -> Vec<bool> {
        (0..n_params)
            .map(|i| i != exclude_idx && self.rng.gen_bool(0.5))
            .collect()
    }

    fn sample_coalition_of_size(&mut self, n_params: usize, size: usize) -> Vec<bool> {
        let mut coalition = vec![false; n_params];
        let mut indices: Vec<_> = (0..n_params).collect();

        // Fisher-Yates shuffle
        for i in (1..n_params).rev() {
            let j = self.rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        for &idx in indices.iter().take(size) {
            coalition[idx] = true;
        }

        coalition
    }

    fn evaluate_coalition(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        reference_config: &HashMap<String, Float>,
        param_names: &[String],
        coalition: &[bool],
        include_idx: Option<usize>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let mut config = reference_config.clone();

        for (i, param_name) in param_names.iter().enumerate() {
            let should_include = coalition[i] || include_idx == Some(i);
            if !should_include {
                // Use random value
                if let Some(&(min, max)) = parameter_space.get(param_name) {
                    let random_value = self.rng.gen_range(min..max);
                    config.insert(param_name.clone(), random_value);
                }
            }
        }

        Ok(evaluation_fn(&config))
    }

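    /// Weight assigned to a coalition of the given size in the KernelSHAP regression.
    ///
    /// The exact Shapley kernel is `(M - 1) / (C(M, |z|) * |z| * (M - |z|))` for a
    /// coalition `z` over `M` parameters; this implementation drops the binomial
    /// coefficient `C(M, |z|)` as a simplification and substitutes a large finite weight
    /// for the empty and full coalitions, whose exact kernel weight is infinite.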
    fn shapley_kernel_weight(&self, coalition_size: usize, n_params: usize) -> Float {
        if coalition_size == 0 || coalition_size == n_params {
            1e10 // Very large weight for empty and full coalitions
        } else {
            let numerator = (n_params - 1) as Float;
            let denominator = (coalition_size * (n_params - coalition_size)) as Float;
            numerator / denominator
        }
    }

    fn solve_weighted_least_squares(
        &self,
        samples: &[Vec<bool>],
        performances: &[Float],
        weights: &[Float],
    ) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
        if samples.is_empty() {
            return Ok(Vec::new());
        }

        let n_params = samples[0].len();

        // Simplified: a weighted average of performances per parameter stands in for a
        // full weighted least-squares solve
        let mut coefficients = vec![0.0; n_params];

        for param_idx in 0..n_params {
            let mut weighted_sum = 0.0;
            let mut total_weight = 0.0;

            for (i, sample) in samples.iter().enumerate() {
                if sample[param_idx] {
                    weighted_sum += performances[i] * weights[i];
                    total_weight += weights[i];
                }
            }

            if total_weight > 0.0 {
                coefficients[param_idx] = weighted_sum / total_weight;
            }
        }

        Ok(coefficients)
    }

    fn rank_parameters(&self, shap_values: &HashMap<String, Float>) -> Vec<(String, Float)> {
        let mut ranked: Vec<_> = shap_values
            .iter()
            .map(|(name, &value)| (name.clone(), value.abs()))
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        ranked
    }
}

/// Result of SHAP analysis
#[derive(Debug, Clone)]
pub struct SHAPResult {
    pub shap_values: HashMap<String, Float>,
    pub baseline_performance: Float,
    pub parameter_rankings: Vec<(String, Float)>,
    pub interaction_effects: HashMap<(String, String), Float>,
}

// ============================================================================
// Functional ANOVA (fANOVA)
// ============================================================================

/// Configuration for fANOVA analysis
#[derive(Debug, Clone)]
pub struct FANOVAConfig {
    /// Number of trees for a forest-based surrogate
    pub n_trees: usize,
    /// Maximum tree depth
    pub max_depth: usize,
    /// Minimum samples required to split a node
    pub min_samples_split: usize,
    /// Number of samples used for marginal estimation
    pub n_samples: usize,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

impl Default for FANOVAConfig {
    fn default() -> Self {
        Self {
            n_trees: 16,
            max_depth: 6,
            min_samples_split: 10,
            n_samples: 1000,
            random_state: None,
        }
    }
}

/// Functional ANOVA analyzer
///
/// This implementation estimates main and pairwise interaction effects by binning
/// parameter values and decomposing the performance variance of the supplied trials; it
/// does not fit the random-forest surrogate used by the original fANOVA method, so the
/// tree-related fields of `FANOVAConfig` are currently unused.
pub struct FANOVAAnalyzer {
    config: FANOVAConfig,
}

impl FANOVAAnalyzer {
    pub fn new(config: FANOVAConfig) -> Self {
        Self { config }
    }

    /// Perform fANOVA analysis
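    ///
    /// `evaluation_data` is a list of `(configuration, performance)` pairs, e.g. the
    /// trials of a finished hyperparameter search.
    ///
    /// # Example (illustrative sketch, not compiled; assumes `Float = f64`)
    ///
    /// ```ignore
    /// let data = vec![
    ///     (HashMap::from([("lr".to_string(), 0.1)]), 0.90),
    ///     (HashMap::from([("lr".to_string(), 0.5)]), 0.60),
    ///     (HashMap::from([("lr".to_string(), 0.9)]), 0.40),
    /// ];
    /// let analyzer = FANOVAAnalyzer::new(FANOVAConfig::default());
    /// let result = analyzer.analyze(&data, &["lr".to_string()])?;
    /// println!("main effect of lr: {:?}", result.main_effects.get("lr"));
    /// ```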
    pub fn analyze(
        &self,
        evaluation_data: &[(HashMap<String, Float>, Float)],
        param_names: &[String],
    ) -> Result<FANOVAResult, Box<dyn std::error::Error>> {
        // Compute total variance
        let performances: Vec<_> = evaluation_data.iter().map(|(_, perf)| *perf).collect();
        let mean_performance = performances.iter().sum::<Float>() / performances.len() as Float;
        let total_variance = performances
            .iter()
            .map(|&p| (p - mean_performance).powi(2))
            .sum::<Float>()
            / performances.len() as Float;

        // Compute variance contribution for each parameter
        let mut main_effects = HashMap::new();
        let mut interaction_effects = HashMap::new();

        for param_name in param_names {
            let variance_contribution =
                self.compute_variance_contribution(evaluation_data, param_name, mean_performance)?;
            let importance = variance_contribution / total_variance;
            main_effects.insert(param_name.clone(), importance);
        }

        // Compute pairwise interactions
        for i in 0..param_names.len() {
            for j in (i + 1)..param_names.len() {
                let interaction_variance = self.compute_interaction_variance(
                    evaluation_data,
                    &param_names[i],
                    &param_names[j],
                    mean_performance,
                )?;
                let importance = interaction_variance / total_variance;
                interaction_effects
                    .insert((param_names[i].clone(), param_names[j].clone()), importance);
            }
        }

        let rankings = self.rank_by_importance(&main_effects);

        Ok(FANOVAResult {
            main_effects,
            interaction_effects,
            total_variance,
            parameter_rankings: rankings,
        })
    }

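    /// Between-group variance explained by a single parameter: performances are grouped
    /// by the parameter's value (binned to two decimal places) and the size-weighted
    /// squared deviation of each group mean from the global mean is accumulated, i.e.
    /// `sum_g n_g * (mean_g - mean)^2 / N`.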
    fn compute_variance_contribution(
        &self,
        data: &[(HashMap<String, Float>, Float)],
        param_name: &str,
        mean: Float,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        // Group performances by binned parameter value
        let mut groups: HashMap<String, Vec<Float>> = HashMap::new();

        for (params, perf) in data {
            if let Some(&value) = params.get(param_name) {
                // Discretize continuous values into bins
                let bin = format!("{:.2}", value);
                groups.entry(bin).or_default().push(*perf);
            }
        }

        // Compute variance explained by this parameter
        let mut variance_explained = 0.0;
        for performances in groups.values() {
            if performances.is_empty() {
                continue;
            }
            let group_mean = performances.iter().sum::<Float>() / performances.len() as Float;
            let group_size = performances.len() as Float;
            variance_explained += group_size * (group_mean - mean).powi(2);
        }

        Ok(variance_explained / data.len() as Float)
    }

    fn compute_interaction_variance(
        &self,
        data: &[(HashMap<String, Float>, Float)],
        param1: &str,
        param2: &str,
        mean: Float,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        // Compute interaction effect
        let mut groups: HashMap<(String, String), Vec<Float>> = HashMap::new();

        for (params, perf) in data {
            if let (Some(&v1), Some(&v2)) = (params.get(param1), params.get(param2)) {
                let bin1 = format!("{:.2}", v1);
                let bin2 = format!("{:.2}", v2);
                groups.entry((bin1, bin2)).or_default().push(*perf);
            }
        }

        // Interaction variance
        let var1 = self.compute_variance_contribution(data, param1, mean)?;
        let var2 = self.compute_variance_contribution(data, param2, mean)?;

        let mut joint_variance = 0.0;
        for performances in groups.values() {
            if performances.is_empty() {
                continue;
            }
            let group_mean = performances.iter().sum::<Float>() / performances.len() as Float;
            let group_size = performances.len() as Float;
            joint_variance += group_size * (group_mean - mean).powi(2);
        }
        joint_variance /= data.len() as Float;

        // Interaction = joint - individual effects
        Ok((joint_variance - var1 - var2).max(0.0))
    }

    fn rank_by_importance(&self, effects: &HashMap<String, Float>) -> Vec<(String, Float)> {
        let mut ranked: Vec<_> = effects.iter().map(|(k, &v)| (k.clone(), v)).collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        ranked
    }
}

/// Result of fANOVA analysis
#[derive(Debug, Clone)]
pub struct FANOVAResult {
    pub main_effects: HashMap<String, Float>,
    pub interaction_effects: HashMap<(String, String), Float>,
    pub total_variance: Float,
    pub parameter_rankings: Vec<(String, Float)>,
}

// ============================================================================
// Parameter Sensitivity Analysis
// ============================================================================

/// Configuration for sensitivity analysis
#[derive(Debug, Clone)]
pub struct SensitivityConfig {
    /// Number of trajectories for the Morris method
    pub n_trajectories: usize,
    /// Grid levels for Sobol-style analysis (not used by the current methods)
    pub n_levels: usize,
    /// Perturbation delta, as a fraction of each parameter's range
    pub perturbation_delta: Float,
    /// Random seed for reproducible trajectories
    pub random_state: Option<u64>,
}

impl Default for SensitivityConfig {
    fn default() -> Self {
        Self {
            n_trajectories: 10,
            n_levels: 4,
            perturbation_delta: 0.01,
            random_state: None,
        }
    }
}

/// Sensitivity analyzer using various methods
pub struct SensitivityAnalyzer {
    config: SensitivityConfig,
    rng: StdRng,
}

impl SensitivityAnalyzer {
    pub fn new(config: SensitivityConfig) -> Self {
        let rng = StdRng::seed_from_u64(config.random_state.unwrap_or(42));
        Self { config, rng }
    }

    /// Perform Morris sensitivity analysis
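    ///
    /// For every parameter this reports the absolute mean elementary effect
    /// (`mean_effect`) and the mean of the absolute elementary effects (`mu_star`);
    /// `sigma`/`std_effect` is their standard deviation. A `sigma` that is large relative
    /// to `mu_star` points to nonlinearity or interaction with other parameters.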
    pub fn morris_analysis(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        base_config: &HashMap<String, Float>,
    ) -> Result<SensitivityResult, Box<dyn std::error::Error>> {
        let param_names: Vec<_> = parameter_space.keys().cloned().collect();
        let mut elementary_effects: HashMap<String, Vec<Float>> = HashMap::new();

        for _ in 0..self.config.n_trajectories {
            // Start each trajectory from a random point in the parameter space; starting
            // every trajectory from `base_config` would make all trajectories identical
            // and force sigma to zero.
            let mut current = base_config.clone();
            for param_name in &param_names {
                if let Some(&(min, max)) = parameter_space.get(param_name) {
                    current.insert(param_name.clone(), self.rng.gen_range(min..max));
                }
            }

            for param_name in &param_names {
                if let Some(&(min, max)) = parameter_space.get(param_name) {
                    // Perturb parameter
                    let original_value = current.get(param_name).cloned().unwrap_or(min);
                    let delta = self.config.perturbation_delta * (max - min);

                    let perturbed_value = (original_value + delta).min(max);
                    let mut perturbed = current.clone();
                    perturbed.insert(param_name.clone(), perturbed_value);

                    // Compute elementary effect using the realized step (the perturbation
                    // may be clamped at `max`)
                    let step = perturbed_value - original_value;
                    if step.abs() < Float::EPSILON {
                        continue;
                    }
                    let f_original = evaluation_fn(&current);
                    let f_perturbed = evaluation_fn(&perturbed);
                    let effect = (f_perturbed - f_original) / step;

                    elementary_effects
                        .entry(param_name.clone())
                        .or_default()
                        .push(effect);

                    current = perturbed;
                }
            }
        }

        // Compute statistics
        let mut sensitivities = HashMap::new();
        let interactions = HashMap::new();

        for (param_name, effects) in &elementary_effects {
            let mean = effects.iter().sum::<Float>() / effects.len() as Float;
            let variance =
                effects.iter().map(|&e| (e - mean).powi(2)).sum::<Float>() / effects.len() as Float;
            let std_dev = variance.sqrt();
            // μ* is the mean of the absolute elementary effects (not |μ|)
            let mu_star =
                effects.iter().map(|&e| e.abs()).sum::<Float>() / effects.len() as Float;

            sensitivities.insert(
                param_name.clone(),
                ParameterSensitivity {
                    mean_effect: mean.abs(),
                    std_effect: std_dev,
                    mu_star,
                    sigma: std_dev,
                },
            );
        }

        let rankings = self.rank_sensitivities(&sensitivities);

        Ok(SensitivityResult {
            sensitivities,
            interactions,
            rankings,
        })
    }

    /// One-at-a-time (OAT) sensitivity analysis
    pub fn oat_analysis(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        base_config: &HashMap<String, Float>,
    ) -> Result<SensitivityResult, Box<dyn std::error::Error>> {
        let param_names: Vec<_> = parameter_space.keys().cloned().collect();
        let mut sensitivities = HashMap::new();
        let baseline_perf = evaluation_fn(base_config);

        for param_name in &param_names {
            if let Some(&(min, max)) = parameter_space.get(param_name) {
                let base_value = base_config
                    .get(param_name)
                    .cloned()
                    .unwrap_or((min + max) / 2.0);

                // Evaluate at several points
                let n_points = 5;
                let mut effects = Vec::new();

                for i in 0..n_points {
                    let alpha = i as Float / (n_points - 1) as Float;
                    let value = min + alpha * (max - min);

                    if (value - base_value).abs() < 1e-6 {
                        continue;
                    }

                    let mut perturbed = base_config.clone();
                    perturbed.insert(param_name.clone(), value);

                    let perf = evaluation_fn(&perturbed);
                    let effect = (perf - baseline_perf).abs() / (value - base_value).abs();
                    effects.push(effect);
                }

                if !effects.is_empty() {
                    let mean_effect = effects.iter().sum::<Float>() / effects.len() as Float;
                    let variance = effects
                        .iter()
                        .map(|&e| (e - mean_effect).powi(2))
                        .sum::<Float>()
                        / effects.len() as Float;

                    sensitivities.insert(
                        param_name.clone(),
                        ParameterSensitivity {
                            mean_effect,
                            std_effect: variance.sqrt(),
                            mu_star: mean_effect,
                            sigma: variance.sqrt(),
                        },
                    );
                }
            }
        }

        let rankings = self.rank_sensitivities(&sensitivities);

        Ok(SensitivityResult {
            sensitivities,
            interactions: HashMap::new(),
            rankings,
        })
    }

    fn rank_sensitivities(
        &self,
        sensitivities: &HashMap<String, ParameterSensitivity>,
    ) -> Vec<(String, Float)> {
        let mut ranked: Vec<_> = sensitivities
            .iter()
            .map(|(name, sens)| (name.clone(), sens.mu_star))
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        ranked
    }
}

/// Sensitivity of a single parameter
#[derive(Debug, Clone)]
pub struct ParameterSensitivity {
    /// Absolute value of the mean elementary effect, |μ|
    pub mean_effect: Float,
    /// Standard deviation of the elementary effects
    pub std_effect: Float,
    /// Mean of the absolute elementary effects, μ* (used for ranking)
    pub mu_star: Float,
    /// Standard deviation of the elementary effects, σ
    pub sigma: Float,
}

/// Result of sensitivity analysis
#[derive(Debug, Clone)]
pub struct SensitivityResult {
    pub sensitivities: HashMap<String, ParameterSensitivity>,
    pub interactions: HashMap<(String, String), Float>,
    pub rankings: Vec<(String, Float)>,
}

// ============================================================================
// Ablation Studies
// ============================================================================

/// Configuration for ablation studies
#[derive(Debug, Clone)]
pub struct AblationConfig {
    /// Number of ablation iterations
    pub n_iterations: usize,
    /// Use the leave-one-out approach; otherwise a greedy cumulative ablation is run
    pub leave_one_out: bool,
    /// Whether to use cumulative importance
    pub cumulative: bool,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

impl Default for AblationConfig {
    fn default() -> Self {
        Self {
            n_iterations: 10,
            leave_one_out: true,
            cumulative: false,
            random_state: None,
        }
    }
}

/// Ablation study analyzer
pub struct AblationAnalyzer {
    config: AblationConfig,
}

impl AblationAnalyzer {
    pub fn new(config: AblationConfig) -> Self {
        Self { config }
    }

    /// Perform ablation study
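    ///
    /// In leave-one-out mode each parameter is individually reset to the midpoint of its
    /// range while the others keep their values from `base_config`; the resulting
    /// performance drop is that parameter's ablation effect. Otherwise a greedy
    /// cumulative ablation is run, repeatedly resetting the parameter whose removal
    /// changes performance the most.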
    pub fn analyze(
        &self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        base_config: &HashMap<String, Float>,
    ) -> Result<AblationResult, Box<dyn std::error::Error>> {
        let param_names: Vec<_> = parameter_space.keys().cloned().collect();
        let baseline_performance = evaluation_fn(base_config);
        let mut ablation_effects = HashMap::new();

        if self.config.leave_one_out {
            // Leave-one-out ablation
            for param_name in &param_names {
                let mut ablated = base_config.clone();

                // "Remove" the parameter by resetting it to the midpoint of its range
                if let Some(&(min, max)) = parameter_space.get(param_name) {
                    ablated.insert(param_name.clone(), (min + max) / 2.0);
                }

                let ablated_perf = evaluation_fn(&ablated);
                let effect = baseline_performance - ablated_perf;

                ablation_effects.insert(param_name.clone(), effect);
            }
        } else {
            // Greedy cumulative ablation
            let mut current = base_config.clone();
            let mut remaining_params = param_names.clone();

            while !remaining_params.is_empty() {
                let mut best_param = None;
                let mut best_effect = Float::NEG_INFINITY;

                for param_name in &remaining_params {
                    let mut test_config = current.clone();
                    if let Some(&(min, max)) = parameter_space.get(param_name) {
                        test_config.insert(param_name.clone(), (min + max) / 2.0);
                    }

                    let perf = evaluation_fn(&test_config);
                    let effect = baseline_performance - perf;

                    if effect > best_effect {
                        best_effect = effect;
                        best_param = Some(param_name.clone());
                    }
                }

                if let Some(param) = best_param {
                    ablation_effects.insert(param.clone(), best_effect);
                    if let Some(&(min, max)) = parameter_space.get(&param) {
                        current.insert(param.clone(), (min + max) / 2.0);
                    }
                    remaining_params.retain(|p| p != &param);
                }
            }
        }

        let rankings = self.rank_ablation_effects(&ablation_effects);

        Ok(AblationResult {
            ablation_effects,
            baseline_performance,
            parameter_rankings: rankings,
        })
    }

    fn rank_ablation_effects(&self, effects: &HashMap<String, Float>) -> Vec<(String, Float)> {
        let mut ranked: Vec<_> = effects.iter().map(|(k, &v)| (k.clone(), v.abs())).collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        ranked
    }
}

/// Result of ablation study
#[derive(Debug, Clone)]
pub struct AblationResult {
    pub ablation_effects: HashMap<String, Float>,
    pub baseline_performance: Float,
    pub parameter_rankings: Vec<(String, Float)>,
}

// ============================================================================
// Unified Importance Analysis
// ============================================================================

/// Comprehensive hyperparameter importance analyzer
pub struct HyperparameterImportanceAnalyzer {
    shap_analyzer: SHAPAnalyzer,
    fanova_analyzer: FANOVAAnalyzer,
    sensitivity_analyzer: SensitivityAnalyzer,
    ablation_analyzer: AblationAnalyzer,
}

impl HyperparameterImportanceAnalyzer {
    pub fn new(
        shap_config: SHAPConfig,
        fanova_config: FANOVAConfig,
        sensitivity_config: SensitivityConfig,
        ablation_config: AblationConfig,
    ) -> Self {
        Self {
            shap_analyzer: SHAPAnalyzer::new(shap_config),
            fanova_analyzer: FANOVAAnalyzer::new(fanova_config),
            sensitivity_analyzer: SensitivityAnalyzer::new(sensitivity_config),
            ablation_analyzer: AblationAnalyzer::new(ablation_config),
        }
    }

    /// Perform comprehensive importance analysis
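    ///
    /// Runs SHAP, fANOVA, Morris sensitivity, and ablation analysis, then averages each
    /// parameter's score across the methods to produce `aggregated_rankings`.
    /// `evaluation_data` is only used by fANOVA; the other analyses query
    /// `evaluation_fn` directly.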
    pub fn analyze_comprehensive(
        &mut self,
        evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        parameter_space: &HashMap<String, (Float, Float)>,
        base_config: &HashMap<String, Float>,
        evaluation_data: &[(HashMap<String, Float>, Float)],
    ) -> Result<ComprehensiveImportanceResult, Box<dyn std::error::Error>> {
        let param_names: Vec<_> = parameter_space.keys().cloned().collect();

        // Run all analyses
        let shap_result =
            self.shap_analyzer
                .compute_shap_values(evaluation_fn, parameter_space, base_config)?;

        let fanova_result = self
            .fanova_analyzer
            .analyze(evaluation_data, &param_names)?;

        let sensitivity_result = self.sensitivity_analyzer.morris_analysis(
            evaluation_fn,
            parameter_space,
            base_config,
        )?;

        let ablation_result =
            self.ablation_analyzer
                .analyze(evaluation_fn, parameter_space, base_config)?;

        // Aggregate rankings
        let aggregated_rankings = self.aggregate_rankings(
            &shap_result.parameter_rankings,
            &fanova_result.parameter_rankings,
            &sensitivity_result.rankings,
            &ablation_result.parameter_rankings,
        );

        Ok(ComprehensiveImportanceResult {
            shap_result,
            fanova_result,
            sensitivity_result,
            ablation_result,
            aggregated_rankings,
        })
    }

    fn aggregate_rankings(
        &self,
        shap: &[(String, Float)],
        fanova: &[(String, Float)],
        sensitivity: &[(String, Float)],
        ablation: &[(String, Float)],
    ) -> Vec<(String, Float)> {
        let mut scores: HashMap<String, Vec<Float>> = HashMap::new();

        // Collect each method's score per parameter (scores are averaged as-is, without
        // cross-method normalization)
        for (param, value) in shap {
            scores.entry(param.clone()).or_default().push(*value);
        }
        for (param, value) in fanova {
            scores.entry(param.clone()).or_default().push(*value);
        }
        for (param, value) in sensitivity {
            scores.entry(param.clone()).or_default().push(*value);
        }
        for (param, value) in ablation {
            scores.entry(param.clone()).or_default().push(*value);
        }

        let mut aggregated: Vec<_> = scores
            .iter()
            .map(|(param, values)| {
                let avg = values.iter().sum::<Float>() / values.len() as Float;
                (param.clone(), avg)
            })
            .collect();

        aggregated.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        aggregated
    }
}

/// Comprehensive importance analysis result
#[derive(Debug, Clone)]
pub struct ComprehensiveImportanceResult {
    pub shap_result: SHAPResult,
    pub fanova_result: FANOVAResult,
    pub sensitivity_result: SensitivityResult,
    pub ablation_result: AblationResult,
    pub aggregated_rankings: Vec<(String, Float)>,
}

// ============================================================================
// Convenience Functions
// ============================================================================

/// Compute SHAP importance values for hyperparameters using the default `SHAPConfig`
pub fn compute_shap_importance(
    evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
    parameter_space: &HashMap<String, (Float, Float)>,
    reference_config: &HashMap<String, Float>,
) -> Result<SHAPResult, Box<dyn std::error::Error>> {
    let config = SHAPConfig::default();
    let mut analyzer = SHAPAnalyzer::new(config);
    analyzer.compute_shap_values(evaluation_fn, parameter_space, reference_config)
}

/// Perform Morris sensitivity analysis using the default `SensitivityConfig`
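///
/// # Example (illustrative sketch, not compiled; assumes `Float = f64`)
///
/// ```ignore
/// let objective = |params: &HashMap<String, f64>| -> f64 {
///     -(params["alpha"].powi(2) + 0.1 * params["beta"].powi(2))
/// };
/// let space = HashMap::from([
///     ("alpha".to_string(), (0.0, 1.0)),
///     ("beta".to_string(), (0.0, 1.0)),
/// ]);
/// let base = HashMap::from([("alpha".to_string(), 0.5), ("beta".to_string(), 0.5)]);
/// let result = analyze_parameter_sensitivity(&objective, &space, &base)?;
/// println!("{:?}", result.rankings);
/// ```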
pub fn analyze_parameter_sensitivity(
    evaluation_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
    parameter_space: &HashMap<String, (Float, Float)>,
    base_config: &HashMap<String, Float>,
) -> Result<SensitivityResult, Box<dyn std::error::Error>> {
    let config = SensitivityConfig::default();
    let mut analyzer = SensitivityAnalyzer::new(config);
    analyzer.morris_analysis(evaluation_fn, parameter_space, base_config)
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shap_config() {
        let config = SHAPConfig::default();
        assert_eq!(config.n_samples, 1000);
        assert!(config.use_kernel_shap);
    }

    #[test]
    fn test_fanova_config() {
        let config = FANOVAConfig::default();
        assert_eq!(config.n_trees, 16);
        assert_eq!(config.max_depth, 6);
    }

    #[test]
    fn test_sensitivity_config() {
        let config = SensitivityConfig::default();
        assert_eq!(config.n_trajectories, 10);
        assert_eq!(config.n_levels, 4);
    }

    #[test]
    fn test_ablation_config() {
        let config = AblationConfig::default();
        assert_eq!(config.n_iterations, 10);
        assert!(config.leave_one_out);
    }

    #[test]
    fn test_shap_analyzer_creation() {
        let config = SHAPConfig::default();
        let analyzer = SHAPAnalyzer::new(config);
        assert_eq!(analyzer.config.n_samples, 1000);
    }

    #[test]
    fn test_sensitivity_analyzer_creation() {
        let config = SensitivityConfig::default();
        let analyzer = SensitivityAnalyzer::new(config);
        assert_eq!(analyzer.config.n_trajectories, 10);
    }
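
    // The two tests below are illustrative end-to-end examples on made-up toy
    // objectives; they exercise the public analysis entry points rather than any
    // property of a real model.

    #[test]
    fn test_oat_analysis_ranks_dominant_parameter() {
        // "alpha" dominates this toy objective, "beta" barely matters.
        let evaluate = |params: &HashMap<String, Float>| -> Float {
            let a = params.get("alpha").copied().unwrap_or(0.0);
            let b = params.get("beta").copied().unwrap_or(0.0);
            -(10.0 * a * a + 0.1 * b * b)
        };

        let mut space = HashMap::new();
        space.insert("alpha".to_string(), (0.0, 1.0));
        space.insert("beta".to_string(), (0.0, 1.0));

        let mut base = HashMap::new();
        base.insert("alpha".to_string(), 0.5);
        base.insert("beta".to_string(), 0.5);

        let mut analyzer = SensitivityAnalyzer::new(SensitivityConfig::default());
        let result = analyzer
            .oat_analysis(&evaluate, &space, &base)
            .expect("OAT analysis should succeed");

        assert_eq!(result.rankings[0].0, "alpha");
    }

    #[test]
    fn test_fanova_main_effect_of_relevant_parameter() {
        // Performance depends on "x" only, so fANOVA should attribute all variance to it.
        let configs = [(0.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0)];
        let data: Vec<(HashMap<String, Float>, Float)> = configs
            .iter()
            .map(|&(x, y, perf)| {
                let mut params = HashMap::new();
                params.insert("x".to_string(), x);
                params.insert("y".to_string(), y);
                (params, perf)
            })
            .collect();

        let analyzer = FANOVAAnalyzer::new(FANOVAConfig::default());
        let result = analyzer
            .analyze(&data, &["x".to_string(), "y".to_string()])
            .expect("fANOVA analysis should succeed");

        assert!((result.main_effects["x"] - 1.0).abs() < 1e-9);
        assert_eq!(result.parameter_rankings[0].0, "x");
    }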
}