Skip to main content

sklears_model_selection/
parameter_space.rs

1//! Enhanced parameter space definitions with categorical parameter handling
2//!
3//! This module provides advanced parameter space definitions including conditional parameters,
4//! constraints, and dependency handling for sophisticated hyperparameter optimization.
5
6use crate::grid_search::{ParameterSet, ParameterValue};
7use scirs2_core::rand_prelude::IndexedRandom;
8use scirs2_core::random::prelude::*;
9use sklears_core::error::{Result, SklearsError};
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
13/// Categorical parameter definition with enhanced features
14#[derive(Debug, Clone)]
15pub struct CategoricalParameter {
16    /// Name of the parameter
17    pub name: String,
18    /// Possible values
19    pub values: Vec<ParameterValue>,
20    /// Whether the categories have an ordering
21    pub ordered: bool,
22    /// Default value if any
23    pub default: Option<ParameterValue>,
24    /// Description of the parameter
25    pub description: Option<String>,
26}
27
28impl CategoricalParameter {
29    /// Create a new categorical parameter
30    pub fn new(name: String, values: Vec<ParameterValue>) -> Self {
31        Self {
32            name,
33            values,
34            ordered: false,
35            default: None,
36            description: None,
37        }
38    }
39
40    /// Create an ordered categorical parameter
41    pub fn ordered(name: String, values: Vec<ParameterValue>) -> Self {
42        Self {
43            name,
44            values,
45            ordered: true,
46            default: None,
47            description: None,
48        }
49    }
50
51    /// Set default value
52    pub fn with_default(mut self, default: ParameterValue) -> Self {
53        self.default = Some(default);
54        self
55    }
56
57    /// Set description
58    pub fn with_description(mut self, description: String) -> Self {
59        self.description = Some(description);
60        self
61    }
62
63    /// Sample a random value from this parameter
64    pub fn sample(&self, rng: &mut impl Rng) -> ParameterValue {
65        self.values
66            .choose(rng)
67            .expect("operation should succeed")
68            .clone()
69    }
70
71    /// Get the index of a value (useful for ordered categories)
72    pub fn get_index(&self, value: &ParameterValue) -> Option<usize> {
73        self.values.iter().position(|v| v == value)
74    }
75
76    /// Get neighboring values for ordered categories
77    pub fn get_neighbors(&self, value: &ParameterValue) -> Vec<ParameterValue> {
78        if !self.ordered {
79            return vec![];
80        }
81
82        if let Some(idx) = self.get_index(value) {
83            let mut neighbors = Vec::new();
84            if idx > 0 {
85                neighbors.push(self.values[idx - 1].clone());
86            }
87            if idx + 1 < self.values.len() {
88                neighbors.push(self.values[idx + 1].clone());
89            }
90            neighbors
91        } else {
92            vec![]
93        }
94    }
95}
96
97/// Parameter constraint type
98#[derive(Clone)]
99pub enum ParameterConstraint {
100    /// Equality constraint: param1 == value when param2 == condition
101    Equality {
102        param: String,
103
104        value: ParameterValue,
105
106        condition_param: String,
107
108        condition_value: ParameterValue,
109    },
110    /// Inequality constraint: param1 != value when param2 == condition
111    Inequality {
112        param: String,
113        value: ParameterValue,
114        condition_param: String,
115        condition_value: ParameterValue,
116    },
117    /// Range constraint: param1 in range when param2 == condition
118    Range {
119        param: String,
120        min_value: ParameterValue,
121        max_value: ParameterValue,
122        condition_param: String,
123        condition_value: ParameterValue,
124    },
125    /// Mutual exclusion: if param1 == value1, then param2 != value2
126    MutualExclusion {
127        param1: String,
128        value1: ParameterValue,
129        param2: String,
130        value2: ParameterValue,
131    },
132    /// Custom constraint function
133    Custom {
134        name: String,
135        constraint_fn: Arc<dyn Fn(&ParameterSet) -> bool + Send + Sync>,
136    },
137}
138
139impl std::fmt::Debug for ParameterConstraint {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        match self {
142            ParameterConstraint::Equality {
143                param,
144                value,
145                condition_param,
146                condition_value,
147            } => {
148                write!(f, "Equality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}", 
149                       param, value, condition_param, condition_value)
150            }
151            ParameterConstraint::Inequality {
152                param,
153                value,
154                condition_param,
155                condition_value,
156            } => {
157                write!(f, "Inequality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}", 
158                       param, value, condition_param, condition_value)
159            }
160            ParameterConstraint::Range {
161                param,
162                min_value,
163                max_value,
164                condition_param,
165                condition_value,
166            } => {
167                write!(f, "Range {{ param: {:?}, min_value: {:?}, max_value: {:?}, condition_param: {:?}, condition_value: {:?} }}", 
168                       param, min_value, max_value, condition_param, condition_value)
169            }
170            ParameterConstraint::MutualExclusion {
171                param1,
172                value1,
173                param2,
174                value2,
175            } => {
176                write!(
177                    f,
178                    "MutualExclusion {{ param1: {:?}, value1: {:?}, param2: {:?}, value2: {:?} }}",
179                    param1, value1, param2, value2
180                )
181            }
182            ParameterConstraint::Custom { name, .. } => {
183                write!(f, "Custom {{ name: {:?}, constraint_fn: <closure> }}", name)
184            }
185        }
186    }
187}
188
189impl ParameterConstraint {
190    /// Check if a parameter set satisfies this constraint
191    pub fn is_satisfied(&self, params: &ParameterSet) -> bool {
192        match self {
193            ParameterConstraint::Equality {
194                param,
195                value,
196                condition_param,
197                condition_value,
198            } => {
199                if let (Some(param_val), Some(condition_val)) =
200                    (params.get(param), params.get(condition_param))
201                {
202                    if condition_val == condition_value {
203                        param_val == value
204                    } else {
205                        true // Constraint doesn't apply
206                    }
207                } else {
208                    false // Missing parameters
209                }
210            }
211            ParameterConstraint::Inequality {
212                param,
213                value,
214                condition_param,
215                condition_value,
216            } => {
217                if let (Some(param_val), Some(condition_val)) =
218                    (params.get(param), params.get(condition_param))
219                {
220                    if condition_val == condition_value {
221                        param_val != value
222                    } else {
223                        true // Constraint doesn't apply
224                    }
225                } else {
226                    false // Missing parameters
227                }
228            }
229            ParameterConstraint::Range {
230                param,
231                min_value,
232                max_value,
233                condition_param,
234                condition_value,
235            } => {
236                if let (Some(param_val), Some(condition_val)) =
237                    (params.get(param), params.get(condition_param))
238                {
239                    if condition_val == condition_value {
240                        self.is_in_range(param_val, min_value, max_value)
241                    } else {
242                        true // Constraint doesn't apply
243                    }
244                } else {
245                    false // Missing parameters
246                }
247            }
248            ParameterConstraint::MutualExclusion {
249                param1,
250                value1,
251                param2,
252                value2,
253            } => {
254                if let (Some(param1_val), Some(param2_val)) =
255                    (params.get(param1), params.get(param2))
256                {
257                    !(param1_val == value1 && param2_val == value2)
258                } else {
259                    true // Missing parameters - constraint satisfied
260                }
261            }
262            ParameterConstraint::Custom { constraint_fn, .. } => constraint_fn(params),
263        }
264    }
265
266    fn is_in_range(
267        &self,
268        value: &ParameterValue,
269        min_value: &ParameterValue,
270        max_value: &ParameterValue,
271    ) -> bool {
272        match (value, min_value, max_value) {
273            (ParameterValue::Int(v), ParameterValue::Int(min), ParameterValue::Int(max)) => {
274                v >= min && v <= max
275            }
276            (ParameterValue::Float(v), ParameterValue::Float(min), ParameterValue::Float(max)) => {
277                v >= min && v <= max
278            }
279            _ => false, // Type mismatch
280        }
281    }
282}
283
284/// Conditional parameter definition
285#[derive(Debug, Clone)]
286pub struct ConditionalParameter {
287    /// Base parameter
288    pub parameter: CategoricalParameter,
289    /// Conditions under which this parameter is active
290    pub conditions: Vec<(String, ParameterValue)>,
291    /// Whether all conditions must be met (AND) or any (OR)
292    pub require_all_conditions: bool,
293}
294
295impl ConditionalParameter {
296    /// Create a new conditional parameter
297    pub fn new(parameter: CategoricalParameter, conditions: Vec<(String, ParameterValue)>) -> Self {
298        Self {
299            parameter,
300            conditions,
301            require_all_conditions: true,
302        }
303    }
304
305    /// Set whether all conditions must be met
306    pub fn require_all_conditions(mut self, require_all: bool) -> Self {
307        self.require_all_conditions = require_all;
308        self
309    }
310
311    /// Check if this parameter is active given the current parameter set
312    pub fn is_active(&self, params: &ParameterSet) -> bool {
313        if self.conditions.is_empty() {
314            return true;
315        }
316
317        let satisfied_conditions = self
318            .conditions
319            .iter()
320            .filter(|(param_name, expected_value)| {
321                params
322                    .get(param_name)
323                    .map(|value| value == expected_value)
324                    .unwrap_or(false)
325            })
326            .count();
327
328        if self.require_all_conditions {
329            satisfied_conditions == self.conditions.len()
330        } else {
331            satisfied_conditions > 0
332        }
333    }
334
335    /// Sample from this parameter if it's active
336    pub fn sample_if_active(
337        &self,
338        params: &ParameterSet,
339        rng: &mut impl Rng,
340    ) -> Option<ParameterValue> {
341        if self.is_active(params) {
342            Some(self.parameter.sample(rng))
343        } else {
344            None
345        }
346    }
347}
348
349/// Enhanced parameter space with categorical parameter support
350#[derive(Debug, Clone)]
351pub struct ParameterSpace {
352    /// Base categorical parameters
353    pub categorical_params: HashMap<String, CategoricalParameter>,
354    /// Conditional parameters
355    pub conditional_params: HashMap<String, ConditionalParameter>,
356    /// Parameter constraints
357    pub constraints: Vec<ParameterConstraint>,
358    /// Parameter dependencies (which parameters depend on which)
359    pub dependencies: HashMap<String, HashSet<String>>,
360}
361
362impl ParameterSpace {
363    /// Create a new parameter space
364    pub fn new() -> Self {
365        Self {
366            categorical_params: HashMap::new(),
367            conditional_params: HashMap::new(),
368            constraints: Vec::new(),
369            dependencies: HashMap::new(),
370        }
371    }
372
373    /// Add a categorical parameter
374    pub fn add_categorical_parameter(&mut self, param: CategoricalParameter) {
375        self.categorical_params.insert(param.name.clone(), param);
376    }
377
378    /// Add a conditional parameter
379    pub fn add_conditional_parameter(&mut self, param: ConditionalParameter) {
380        // Track dependencies
381        for (dep_param, _) in &param.conditions {
382            self.dependencies
383                .entry(param.parameter.name.clone())
384                .or_default()
385                .insert(dep_param.clone());
386        }
387        self.conditional_params
388            .insert(param.parameter.name.clone(), param);
389    }
390
391    /// Add a constraint
392    pub fn add_constraint(&mut self, constraint: ParameterConstraint) {
393        self.constraints.push(constraint);
394    }
395
396    /// Sample a valid parameter set from this space
397    pub fn sample(&self, rng: &mut impl Rng) -> Result<ParameterSet> {
398        let mut params = ParameterSet::new();
399        let mut attempts = 0;
400        const MAX_ATTEMPTS: usize = 1000;
401
402        while attempts < MAX_ATTEMPTS {
403            params.clear();
404
405            // Sample base categorical parameters first
406            for (name, param) in &self.categorical_params {
407                params.insert(name.clone(), param.sample(rng));
408            }
409
410            // Sample conditional parameters
411            for (name, conditional_param) in &self.conditional_params {
412                if let Some(value) = conditional_param.sample_if_active(&params, rng) {
413                    params.insert(name.clone(), value);
414                }
415            }
416
417            // Check constraints
418            if self.is_valid_parameter_set(&params) {
419                return Ok(params);
420            }
421
422            attempts += 1;
423        }
424
425        Err(SklearsError::InvalidInput(format!(
426            "Failed to sample valid parameter set after {} attempts",
427            MAX_ATTEMPTS
428        )))
429    }
430
431    /// Check if a parameter set is valid according to all constraints
432    pub fn is_valid_parameter_set(&self, params: &ParameterSet) -> bool {
433        self.constraints
434            .iter()
435            .all(|constraint| constraint.is_satisfied(params))
436    }
437
438    /// Get all possible parameter names
439    pub fn get_parameter_names(&self) -> HashSet<String> {
440        let mut names = HashSet::new();
441        names.extend(self.categorical_params.keys().cloned());
442        names.extend(self.conditional_params.keys().cloned());
443        names
444    }
445
446    /// Get parameters that depend on a given parameter
447    pub fn get_dependent_parameters(&self, param_name: &str) -> HashSet<String> {
448        self.dependencies
449            .iter()
450            .filter_map(|(dependent, dependencies)| {
451                if dependencies.contains(param_name) {
452                    Some(dependent.clone())
453                } else {
454                    None
455                }
456            })
457            .collect()
458    }
459
460    /// Get the dependencies of a parameter
461    pub fn get_parameter_dependencies(&self, param_name: &str) -> HashSet<String> {
462        self.dependencies
463            .get(param_name)
464            .cloned()
465            .unwrap_or_default()
466    }
467
468    /// Generate a smart sample that respects parameter importance
469    pub fn sample_with_importance(
470        &self,
471        rng: &mut impl Rng,
472        importance_weights: &HashMap<String, f64>,
473    ) -> Result<ParameterSet> {
474        let mut params = ParameterSet::new();
475        let mut attempts = 0;
476        const MAX_ATTEMPTS: usize = 1000;
477
478        while attempts < MAX_ATTEMPTS {
479            params.clear();
480
481            // Sort parameters by importance (higher importance first)
482            let mut sorted_params: Vec<_> = self.categorical_params.keys().collect();
483            sorted_params.sort_by(|a, b| {
484                let weight_a = importance_weights.get(*a).unwrap_or(&1.0);
485                let weight_b = importance_weights.get(*b).unwrap_or(&1.0);
486                weight_b
487                    .partial_cmp(weight_a)
488                    .expect("operation should succeed")
489            });
490
491            // Sample important parameters first
492            for name in sorted_params {
493                if let Some(param) = self.categorical_params.get(name) {
494                    params.insert(name.clone(), param.sample(rng));
495                }
496            }
497
498            // Sample conditional parameters
499            for (name, conditional_param) in &self.conditional_params {
500                if let Some(value) = conditional_param.sample_if_active(&params, rng) {
501                    params.insert(name.clone(), value);
502                }
503            }
504
505            // Check constraints
506            if self.is_valid_parameter_set(&params) {
507                return Ok(params);
508            }
509
510            attempts += 1;
511        }
512
513        Err(SklearsError::InvalidInput(format!(
514            "Failed to sample valid parameter set after {} attempts",
515            MAX_ATTEMPTS
516        )))
517    }
518
519    /// Convenience method to add a float parameter with min/max range
520    pub fn add_float_param(&mut self, name: &str, min: f64, max: f64) {
521        // Create a reasonable set of values across the range
522        let mut values = Vec::new();
523        let n_values = 10; // Default number of values to sample
524        for i in 0..n_values {
525            let ratio = i as f64 / (n_values - 1) as f64;
526            let value = min + ratio * (max - min);
527            values.push(ParameterValue::Float(value));
528        }
529
530        let param = CategoricalParameter::new(name.to_string(), values);
531        self.add_categorical_parameter(param);
532    }
533
534    /// Convenience method to add an integer parameter with min/max range
535    pub fn add_int_param(&mut self, name: &str, min: i64, max: i64) {
536        let mut values = Vec::new();
537        let range = max - min + 1;
538        let n_values = if range <= 20 {
539            // If small range, include all values
540            range as usize
541        } else {
542            // If large range, sample 10 values
543            10
544        };
545
546        for i in 0..n_values {
547            let value = if range <= 20 {
548                min + i as i64
549            } else {
550                let ratio = i as f64 / (n_values - 1) as f64;
551                min + (ratio * (max - min) as f64) as i64
552            };
553            values.push(ParameterValue::Int(value));
554        }
555
556        let param = CategoricalParameter::new(name.to_string(), values);
557        self.add_categorical_parameter(param);
558    }
559
560    /// Convenience method to add a categorical parameter from string slice
561    pub fn add_categorical_param(&mut self, name: &str, values: Vec<&str>) {
562        let param_values = values
563            .into_iter()
564            .map(|s| ParameterValue::String(s.to_string()))
565            .collect();
566
567        let param = CategoricalParameter::new(name.to_string(), param_values);
568        self.add_categorical_parameter(param);
569    }
570
571    /// Convenience method to add a boolean parameter
572    pub fn add_boolean_param(&mut self, name: &str) {
573        let values = vec![ParameterValue::Bool(false), ParameterValue::Bool(true)];
574        let param = CategoricalParameter::new(name.to_string(), values);
575        self.add_categorical_parameter(param);
576    }
577
578    /// Auto-detect parameter ranges from a dataset of parameter sets
579    pub fn auto_detect_ranges(parameter_sets: &[ParameterSet]) -> Result<Self> {
580        let mut space = ParameterSpace::new();
581
582        if parameter_sets.is_empty() {
583            return Ok(space);
584        }
585
586        // Collect all parameter names
587        let mut all_param_names = HashSet::new();
588        for param_set in parameter_sets {
589            all_param_names.extend(param_set.keys().cloned());
590        }
591
592        // For each parameter, detect its range/categories
593        for param_name in all_param_names {
594            let mut values = HashSet::new();
595            for param_set in parameter_sets {
596                if let Some(value) = param_set.get(&param_name) {
597                    values.insert(value.clone());
598                }
599            }
600
601            if !values.is_empty() {
602                let values_vec: Vec<ParameterValue> = values.into_iter().collect();
603                let categorical_param = CategoricalParameter::new(param_name, values_vec);
604                space.add_categorical_parameter(categorical_param);
605            }
606        }
607
608        Ok(space)
609    }
610}
611
612impl Default for ParameterSpace {
613    fn default() -> Self {
614        Self::new()
615    }
616}
617
618/// Parameter importance analyzer
619#[derive(Debug)]
620pub struct ParameterImportanceAnalyzer {
621    /// Historical parameter evaluations
622    evaluations: Vec<(ParameterSet, f64)>,
623}
624
625impl ParameterImportanceAnalyzer {
626    /// Create a new importance analyzer
627    pub fn new() -> Self {
628        Self {
629            evaluations: Vec::new(),
630        }
631    }
632
633    /// Add an evaluation result
634    pub fn add_evaluation(&mut self, params: ParameterSet, score: f64) {
635        self.evaluations.push((params, score));
636    }
637
638    /// Calculate parameter importance using variance-based analysis
639    pub fn calculate_importance(&self) -> HashMap<String, f64> {
640        let mut importance = HashMap::new();
641
642        if self.evaluations.len() < 2 {
643            return importance;
644        }
645
646        // Get all parameter names
647        let mut all_params = HashSet::new();
648        for (params, _) in &self.evaluations {
649            all_params.extend(params.keys().cloned());
650        }
651
652        // Calculate variance for each parameter
653        for param_name in all_params {
654            let variance = self.calculate_parameter_variance(&param_name);
655            importance.insert(param_name, variance);
656        }
657
658        // Normalize importances
659        let max_importance = importance.values().fold(0.0f64, |a, &b| a.max(b));
660        if max_importance > 0.0 {
661            for value in importance.values_mut() {
662                *value /= max_importance;
663            }
664        }
665
666        importance
667    }
668
669    fn calculate_parameter_variance(&self, param_name: &str) -> f64 {
670        // Group evaluations by parameter value
671        let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
672
673        for (params, score) in &self.evaluations {
674            if let Some(param_value) = params.get(param_name) {
675                let key = format!("{:?}", param_value);
676                groups.entry(key).or_default().push(*score);
677            }
678        }
679
680        if groups.len() < 2 {
681            return 0.0;
682        }
683
684        // Calculate within-group and between-group variance
685        let mut total_variance = 0.0;
686        let total_count = self.evaluations.len();
687
688        for scores in groups.values() {
689            if scores.len() > 1 {
690                let mean = scores.iter().sum::<f64>() / scores.len() as f64;
691                let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
692                    / (scores.len() - 1) as f64;
693                total_variance += variance * (scores.len() as f64 / total_count as f64);
694            }
695        }
696
697        total_variance
698    }
699}
700
701impl Default for ParameterImportanceAnalyzer {
702    fn default() -> Self {
703        Self::new()
704    }
705}
706
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A parameter built with `new` keeps its values and is unordered.
    #[test]
    fn test_categorical_parameter() {
        let candidates: Vec<ParameterValue> =
            vec!["svm".into(), "random_forest".into(), "neural_net".into()];
        let param = CategoricalParameter::new("algorithm".to_string(), candidates);

        assert_eq!(param.values.len(), 3);
        assert!(!param.ordered);
    }

    /// A middle value of an ordered parameter has two neighbors.
    #[test]
    fn test_ordered_categorical_parameter() {
        let levels: Vec<ParameterValue> = vec!["low".into(), "medium".into(), "high".into()];
        let param = CategoricalParameter::ordered("complexity".to_string(), levels);
        assert!(param.ordered);

        let middle: ParameterValue = "medium".into();
        assert_eq!(param.get_neighbors(&middle).len(), 2);
    }

    /// An equality constraint holds or fails depending on the constrained
    /// value while its condition is met.
    #[test]
    fn test_parameter_constraint() {
        let constraint = ParameterConstraint::Equality {
            param: "kernel".to_string(),
            value: "rbf".into(),
            condition_param: "algorithm".to_string(),
            condition_value: "svm".into(),
        };

        let mut params = ParameterSet::new();
        params.insert("algorithm".to_string(), "svm".into());

        params.insert("kernel".to_string(), "rbf".into());
        assert!(constraint.is_satisfied(&params));

        params.insert("kernel".to_string(), "linear".into());
        assert!(!constraint.is_satisfied(&params));
    }

    /// A conditional parameter is active only while its condition is met.
    #[test]
    fn test_conditional_parameter() {
        let kernel = CategoricalParameter::new(
            "kernel".to_string(),
            vec!["linear".into(), "rbf".into(), "poly".into()],
        );
        let only_for_svm = vec![("algorithm".to_string(), "svm".into())];
        let conditional_param = ConditionalParameter::new(kernel, only_for_svm);

        let mut params = ParameterSet::new();

        params.insert("algorithm".to_string(), "svm".into());
        assert!(conditional_param.is_active(&params));

        params.insert("algorithm".to_string(), "random_forest".into());
        assert!(!conditional_param.is_active(&params));
    }

    /// Sampling always yields the base parameter, and yields the conditional
    /// parameter whenever its condition is met.
    #[test]
    fn test_parameter_space_sampling() {
        let mut space = ParameterSpace::new();

        space.add_categorical_parameter(CategoricalParameter::new(
            "algorithm".to_string(),
            vec!["svm".into(), "random_forest".into()],
        ));

        space.add_conditional_parameter(ConditionalParameter::new(
            CategoricalParameter::new("kernel".to_string(), vec!["linear".into(), "rbf".into()]),
            vec![("algorithm".to_string(), "svm".into())],
        ));

        let mut rng = scirs2_core::random::thread_rng();
        let params = space.sample(&mut rng).expect("operation should succeed");

        assert!(params.contains_key("algorithm"));

        let svm: ParameterValue = "svm".into();
        let algorithm = params.get("algorithm").expect("operation should succeed");
        if algorithm == &svm {
            assert!(params.contains_key("kernel"));
        }
    }
}
797}