sklears_model_selection/
parameter_space.rs

1//! Enhanced parameter space definitions with categorical parameter handling
2//!
3//! This module provides advanced parameter space definitions including conditional parameters,
4//! constraints, and dependency handling for sophisticated hyperparameter optimization.
5
6use crate::grid_search::{ParameterSet, ParameterValue};
7use scirs2_core::rand_prelude::IndexedRandom;
8use scirs2_core::random::prelude::*;
9use sklears_core::error::{Result, SklearsError};
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
13/// Categorical parameter definition with enhanced features
14#[derive(Debug, Clone)]
15pub struct CategoricalParameter {
16    /// Name of the parameter
17    pub name: String,
18    /// Possible values
19    pub values: Vec<ParameterValue>,
20    /// Whether the categories have an ordering
21    pub ordered: bool,
22    /// Default value if any
23    pub default: Option<ParameterValue>,
24    /// Description of the parameter
25    pub description: Option<String>,
26}
27
28impl CategoricalParameter {
29    /// Create a new categorical parameter
30    pub fn new(name: String, values: Vec<ParameterValue>) -> Self {
31        Self {
32            name,
33            values,
34            ordered: false,
35            default: None,
36            description: None,
37        }
38    }
39
40    /// Create an ordered categorical parameter
41    pub fn ordered(name: String, values: Vec<ParameterValue>) -> Self {
42        Self {
43            name,
44            values,
45            ordered: true,
46            default: None,
47            description: None,
48        }
49    }
50
51    /// Set default value
52    pub fn with_default(mut self, default: ParameterValue) -> Self {
53        self.default = Some(default);
54        self
55    }
56
57    /// Set description
58    pub fn with_description(mut self, description: String) -> Self {
59        self.description = Some(description);
60        self
61    }
62
63    /// Sample a random value from this parameter
64    pub fn sample(&self, rng: &mut impl Rng) -> ParameterValue {
65        self.values.choose(rng).unwrap().clone()
66    }
67
68    /// Get the index of a value (useful for ordered categories)
69    pub fn get_index(&self, value: &ParameterValue) -> Option<usize> {
70        self.values.iter().position(|v| v == value)
71    }
72
73    /// Get neighboring values for ordered categories
74    pub fn get_neighbors(&self, value: &ParameterValue) -> Vec<ParameterValue> {
75        if !self.ordered {
76            return vec![];
77        }
78
79        if let Some(idx) = self.get_index(value) {
80            let mut neighbors = Vec::new();
81            if idx > 0 {
82                neighbors.push(self.values[idx - 1].clone());
83            }
84            if idx + 1 < self.values.len() {
85                neighbors.push(self.values[idx + 1].clone());
86            }
87            neighbors
88        } else {
89            vec![]
90        }
91    }
92}
93
94/// Parameter constraint type
95#[derive(Clone)]
96pub enum ParameterConstraint {
97    /// Equality constraint: param1 == value when param2 == condition
98    Equality {
99        param: String,
100
101        value: ParameterValue,
102
103        condition_param: String,
104
105        condition_value: ParameterValue,
106    },
107    /// Inequality constraint: param1 != value when param2 == condition
108    Inequality {
109        param: String,
110        value: ParameterValue,
111        condition_param: String,
112        condition_value: ParameterValue,
113    },
114    /// Range constraint: param1 in range when param2 == condition
115    Range {
116        param: String,
117        min_value: ParameterValue,
118        max_value: ParameterValue,
119        condition_param: String,
120        condition_value: ParameterValue,
121    },
122    /// Mutual exclusion: if param1 == value1, then param2 != value2
123    MutualExclusion {
124        param1: String,
125        value1: ParameterValue,
126        param2: String,
127        value2: ParameterValue,
128    },
129    /// Custom constraint function
130    Custom {
131        name: String,
132        constraint_fn: Arc<dyn Fn(&ParameterSet) -> bool + Send + Sync>,
133    },
134}
135
136impl std::fmt::Debug for ParameterConstraint {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        match self {
139            ParameterConstraint::Equality {
140                param,
141                value,
142                condition_param,
143                condition_value,
144            } => {
145                write!(f, "Equality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}", 
146                       param, value, condition_param, condition_value)
147            }
148            ParameterConstraint::Inequality {
149                param,
150                value,
151                condition_param,
152                condition_value,
153            } => {
154                write!(f, "Inequality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}", 
155                       param, value, condition_param, condition_value)
156            }
157            ParameterConstraint::Range {
158                param,
159                min_value,
160                max_value,
161                condition_param,
162                condition_value,
163            } => {
164                write!(f, "Range {{ param: {:?}, min_value: {:?}, max_value: {:?}, condition_param: {:?}, condition_value: {:?} }}", 
165                       param, min_value, max_value, condition_param, condition_value)
166            }
167            ParameterConstraint::MutualExclusion {
168                param1,
169                value1,
170                param2,
171                value2,
172            } => {
173                write!(
174                    f,
175                    "MutualExclusion {{ param1: {:?}, value1: {:?}, param2: {:?}, value2: {:?} }}",
176                    param1, value1, param2, value2
177                )
178            }
179            ParameterConstraint::Custom { name, .. } => {
180                write!(f, "Custom {{ name: {:?}, constraint_fn: <closure> }}", name)
181            }
182        }
183    }
184}
185
186impl ParameterConstraint {
187    /// Check if a parameter set satisfies this constraint
188    pub fn is_satisfied(&self, params: &ParameterSet) -> bool {
189        match self {
190            ParameterConstraint::Equality {
191                param,
192                value,
193                condition_param,
194                condition_value,
195            } => {
196                if let (Some(param_val), Some(condition_val)) =
197                    (params.get(param), params.get(condition_param))
198                {
199                    if condition_val == condition_value {
200                        param_val == value
201                    } else {
202                        true // Constraint doesn't apply
203                    }
204                } else {
205                    false // Missing parameters
206                }
207            }
208            ParameterConstraint::Inequality {
209                param,
210                value,
211                condition_param,
212                condition_value,
213            } => {
214                if let (Some(param_val), Some(condition_val)) =
215                    (params.get(param), params.get(condition_param))
216                {
217                    if condition_val == condition_value {
218                        param_val != value
219                    } else {
220                        true // Constraint doesn't apply
221                    }
222                } else {
223                    false // Missing parameters
224                }
225            }
226            ParameterConstraint::Range {
227                param,
228                min_value,
229                max_value,
230                condition_param,
231                condition_value,
232            } => {
233                if let (Some(param_val), Some(condition_val)) =
234                    (params.get(param), params.get(condition_param))
235                {
236                    if condition_val == condition_value {
237                        self.is_in_range(param_val, min_value, max_value)
238                    } else {
239                        true // Constraint doesn't apply
240                    }
241                } else {
242                    false // Missing parameters
243                }
244            }
245            ParameterConstraint::MutualExclusion {
246                param1,
247                value1,
248                param2,
249                value2,
250            } => {
251                if let (Some(param1_val), Some(param2_val)) =
252                    (params.get(param1), params.get(param2))
253                {
254                    !(param1_val == value1 && param2_val == value2)
255                } else {
256                    true // Missing parameters - constraint satisfied
257                }
258            }
259            ParameterConstraint::Custom { constraint_fn, .. } => constraint_fn(params),
260        }
261    }
262
263    fn is_in_range(
264        &self,
265        value: &ParameterValue,
266        min_value: &ParameterValue,
267        max_value: &ParameterValue,
268    ) -> bool {
269        match (value, min_value, max_value) {
270            (ParameterValue::Int(v), ParameterValue::Int(min), ParameterValue::Int(max)) => {
271                v >= min && v <= max
272            }
273            (ParameterValue::Float(v), ParameterValue::Float(min), ParameterValue::Float(max)) => {
274                v >= min && v <= max
275            }
276            _ => false, // Type mismatch
277        }
278    }
279}
280
281/// Conditional parameter definition
282#[derive(Debug, Clone)]
283pub struct ConditionalParameter {
284    /// Base parameter
285    pub parameter: CategoricalParameter,
286    /// Conditions under which this parameter is active
287    pub conditions: Vec<(String, ParameterValue)>,
288    /// Whether all conditions must be met (AND) or any (OR)
289    pub require_all_conditions: bool,
290}
291
292impl ConditionalParameter {
293    /// Create a new conditional parameter
294    pub fn new(parameter: CategoricalParameter, conditions: Vec<(String, ParameterValue)>) -> Self {
295        Self {
296            parameter,
297            conditions,
298            require_all_conditions: true,
299        }
300    }
301
302    /// Set whether all conditions must be met
303    pub fn require_all_conditions(mut self, require_all: bool) -> Self {
304        self.require_all_conditions = require_all;
305        self
306    }
307
308    /// Check if this parameter is active given the current parameter set
309    pub fn is_active(&self, params: &ParameterSet) -> bool {
310        if self.conditions.is_empty() {
311            return true;
312        }
313
314        let satisfied_conditions = self
315            .conditions
316            .iter()
317            .filter(|(param_name, expected_value)| {
318                params
319                    .get(param_name)
320                    .map(|value| value == expected_value)
321                    .unwrap_or(false)
322            })
323            .count();
324
325        if self.require_all_conditions {
326            satisfied_conditions == self.conditions.len()
327        } else {
328            satisfied_conditions > 0
329        }
330    }
331
332    /// Sample from this parameter if it's active
333    pub fn sample_if_active(
334        &self,
335        params: &ParameterSet,
336        rng: &mut impl Rng,
337    ) -> Option<ParameterValue> {
338        if self.is_active(params) {
339            Some(self.parameter.sample(rng))
340        } else {
341            None
342        }
343    }
344}
345
346/// Enhanced parameter space with categorical parameter support
347#[derive(Debug, Clone)]
348pub struct ParameterSpace {
349    /// Base categorical parameters
350    pub categorical_params: HashMap<String, CategoricalParameter>,
351    /// Conditional parameters
352    pub conditional_params: HashMap<String, ConditionalParameter>,
353    /// Parameter constraints
354    pub constraints: Vec<ParameterConstraint>,
355    /// Parameter dependencies (which parameters depend on which)
356    pub dependencies: HashMap<String, HashSet<String>>,
357}
358
359impl ParameterSpace {
360    /// Create a new parameter space
361    pub fn new() -> Self {
362        Self {
363            categorical_params: HashMap::new(),
364            conditional_params: HashMap::new(),
365            constraints: Vec::new(),
366            dependencies: HashMap::new(),
367        }
368    }
369
370    /// Add a categorical parameter
371    pub fn add_categorical_parameter(&mut self, param: CategoricalParameter) {
372        self.categorical_params.insert(param.name.clone(), param);
373    }
374
375    /// Add a conditional parameter
376    pub fn add_conditional_parameter(&mut self, param: ConditionalParameter) {
377        // Track dependencies
378        for (dep_param, _) in &param.conditions {
379            self.dependencies
380                .entry(param.parameter.name.clone())
381                .or_default()
382                .insert(dep_param.clone());
383        }
384        self.conditional_params
385            .insert(param.parameter.name.clone(), param);
386    }
387
388    /// Add a constraint
389    pub fn add_constraint(&mut self, constraint: ParameterConstraint) {
390        self.constraints.push(constraint);
391    }
392
393    /// Sample a valid parameter set from this space
394    pub fn sample(&self, rng: &mut impl Rng) -> Result<ParameterSet> {
395        let mut params = ParameterSet::new();
396        let mut attempts = 0;
397        const MAX_ATTEMPTS: usize = 1000;
398
399        while attempts < MAX_ATTEMPTS {
400            params.clear();
401
402            // Sample base categorical parameters first
403            for (name, param) in &self.categorical_params {
404                params.insert(name.clone(), param.sample(rng));
405            }
406
407            // Sample conditional parameters
408            for (name, conditional_param) in &self.conditional_params {
409                if let Some(value) = conditional_param.sample_if_active(&params, rng) {
410                    params.insert(name.clone(), value);
411                }
412            }
413
414            // Check constraints
415            if self.is_valid_parameter_set(&params) {
416                return Ok(params);
417            }
418
419            attempts += 1;
420        }
421
422        Err(SklearsError::InvalidInput(format!(
423            "Failed to sample valid parameter set after {} attempts",
424            MAX_ATTEMPTS
425        )))
426    }
427
428    /// Check if a parameter set is valid according to all constraints
429    pub fn is_valid_parameter_set(&self, params: &ParameterSet) -> bool {
430        self.constraints
431            .iter()
432            .all(|constraint| constraint.is_satisfied(params))
433    }
434
435    /// Get all possible parameter names
436    pub fn get_parameter_names(&self) -> HashSet<String> {
437        let mut names = HashSet::new();
438        names.extend(self.categorical_params.keys().cloned());
439        names.extend(self.conditional_params.keys().cloned());
440        names
441    }
442
443    /// Get parameters that depend on a given parameter
444    pub fn get_dependent_parameters(&self, param_name: &str) -> HashSet<String> {
445        self.dependencies
446            .iter()
447            .filter_map(|(dependent, dependencies)| {
448                if dependencies.contains(param_name) {
449                    Some(dependent.clone())
450                } else {
451                    None
452                }
453            })
454            .collect()
455    }
456
457    /// Get the dependencies of a parameter
458    pub fn get_parameter_dependencies(&self, param_name: &str) -> HashSet<String> {
459        self.dependencies
460            .get(param_name)
461            .cloned()
462            .unwrap_or_default()
463    }
464
465    /// Generate a smart sample that respects parameter importance
466    pub fn sample_with_importance(
467        &self,
468        rng: &mut impl Rng,
469        importance_weights: &HashMap<String, f64>,
470    ) -> Result<ParameterSet> {
471        let mut params = ParameterSet::new();
472        let mut attempts = 0;
473        const MAX_ATTEMPTS: usize = 1000;
474
475        while attempts < MAX_ATTEMPTS {
476            params.clear();
477
478            // Sort parameters by importance (higher importance first)
479            let mut sorted_params: Vec<_> = self.categorical_params.keys().collect();
480            sorted_params.sort_by(|a, b| {
481                let weight_a = importance_weights.get(*a).unwrap_or(&1.0);
482                let weight_b = importance_weights.get(*b).unwrap_or(&1.0);
483                weight_b.partial_cmp(weight_a).unwrap()
484            });
485
486            // Sample important parameters first
487            for name in sorted_params {
488                if let Some(param) = self.categorical_params.get(name) {
489                    params.insert(name.clone(), param.sample(rng));
490                }
491            }
492
493            // Sample conditional parameters
494            for (name, conditional_param) in &self.conditional_params {
495                if let Some(value) = conditional_param.sample_if_active(&params, rng) {
496                    params.insert(name.clone(), value);
497                }
498            }
499
500            // Check constraints
501            if self.is_valid_parameter_set(&params) {
502                return Ok(params);
503            }
504
505            attempts += 1;
506        }
507
508        Err(SklearsError::InvalidInput(format!(
509            "Failed to sample valid parameter set after {} attempts",
510            MAX_ATTEMPTS
511        )))
512    }
513
514    /// Convenience method to add a float parameter with min/max range
515    pub fn add_float_param(&mut self, name: &str, min: f64, max: f64) {
516        // Create a reasonable set of values across the range
517        let mut values = Vec::new();
518        let n_values = 10; // Default number of values to sample
519        for i in 0..n_values {
520            let ratio = i as f64 / (n_values - 1) as f64;
521            let value = min + ratio * (max - min);
522            values.push(ParameterValue::Float(value));
523        }
524
525        let param = CategoricalParameter::new(name.to_string(), values);
526        self.add_categorical_parameter(param);
527    }
528
529    /// Convenience method to add an integer parameter with min/max range
530    pub fn add_int_param(&mut self, name: &str, min: i64, max: i64) {
531        let mut values = Vec::new();
532        let range = max - min + 1;
533        let n_values = if range <= 20 {
534            // If small range, include all values
535            range as usize
536        } else {
537            // If large range, sample 10 values
538            10
539        };
540
541        for i in 0..n_values {
542            let value = if range <= 20 {
543                min + i as i64
544            } else {
545                let ratio = i as f64 / (n_values - 1) as f64;
546                min + (ratio * (max - min) as f64) as i64
547            };
548            values.push(ParameterValue::Int(value));
549        }
550
551        let param = CategoricalParameter::new(name.to_string(), values);
552        self.add_categorical_parameter(param);
553    }
554
555    /// Convenience method to add a categorical parameter from string slice
556    pub fn add_categorical_param(&mut self, name: &str, values: Vec<&str>) {
557        let param_values = values
558            .into_iter()
559            .map(|s| ParameterValue::String(s.to_string()))
560            .collect();
561
562        let param = CategoricalParameter::new(name.to_string(), param_values);
563        self.add_categorical_parameter(param);
564    }
565
566    /// Convenience method to add a boolean parameter
567    pub fn add_boolean_param(&mut self, name: &str) {
568        let values = vec![ParameterValue::Bool(false), ParameterValue::Bool(true)];
569        let param = CategoricalParameter::new(name.to_string(), values);
570        self.add_categorical_parameter(param);
571    }
572
573    /// Auto-detect parameter ranges from a dataset of parameter sets
574    pub fn auto_detect_ranges(parameter_sets: &[ParameterSet]) -> Result<Self> {
575        let mut space = ParameterSpace::new();
576
577        if parameter_sets.is_empty() {
578            return Ok(space);
579        }
580
581        // Collect all parameter names
582        let mut all_param_names = HashSet::new();
583        for param_set in parameter_sets {
584            all_param_names.extend(param_set.keys().cloned());
585        }
586
587        // For each parameter, detect its range/categories
588        for param_name in all_param_names {
589            let mut values = HashSet::new();
590            for param_set in parameter_sets {
591                if let Some(value) = param_set.get(&param_name) {
592                    values.insert(value.clone());
593                }
594            }
595
596            if !values.is_empty() {
597                let values_vec: Vec<ParameterValue> = values.into_iter().collect();
598                let categorical_param = CategoricalParameter::new(param_name, values_vec);
599                space.add_categorical_parameter(categorical_param);
600            }
601        }
602
603        Ok(space)
604    }
605}
606
607impl Default for ParameterSpace {
608    fn default() -> Self {
609        Self::new()
610    }
611}
612
613/// Parameter importance analyzer
614#[derive(Debug)]
615pub struct ParameterImportanceAnalyzer {
616    /// Historical parameter evaluations
617    evaluations: Vec<(ParameterSet, f64)>,
618}
619
620impl ParameterImportanceAnalyzer {
621    /// Create a new importance analyzer
622    pub fn new() -> Self {
623        Self {
624            evaluations: Vec::new(),
625        }
626    }
627
628    /// Add an evaluation result
629    pub fn add_evaluation(&mut self, params: ParameterSet, score: f64) {
630        self.evaluations.push((params, score));
631    }
632
633    /// Calculate parameter importance using variance-based analysis
634    pub fn calculate_importance(&self) -> HashMap<String, f64> {
635        let mut importance = HashMap::new();
636
637        if self.evaluations.len() < 2 {
638            return importance;
639        }
640
641        // Get all parameter names
642        let mut all_params = HashSet::new();
643        for (params, _) in &self.evaluations {
644            all_params.extend(params.keys().cloned());
645        }
646
647        // Calculate variance for each parameter
648        for param_name in all_params {
649            let variance = self.calculate_parameter_variance(&param_name);
650            importance.insert(param_name, variance);
651        }
652
653        // Normalize importances
654        let max_importance = importance.values().fold(0.0f64, |a, &b| a.max(b));
655        if max_importance > 0.0 {
656            for value in importance.values_mut() {
657                *value /= max_importance;
658            }
659        }
660
661        importance
662    }
663
664    fn calculate_parameter_variance(&self, param_name: &str) -> f64 {
665        // Group evaluations by parameter value
666        let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
667
668        for (params, score) in &self.evaluations {
669            if let Some(param_value) = params.get(param_name) {
670                let key = format!("{:?}", param_value);
671                groups.entry(key).or_default().push(*score);
672            }
673        }
674
675        if groups.len() < 2 {
676            return 0.0;
677        }
678
679        // Calculate within-group and between-group variance
680        let mut total_variance = 0.0;
681        let total_count = self.evaluations.len();
682
683        for scores in groups.values() {
684            if scores.len() > 1 {
685                let mean = scores.iter().sum::<f64>() / scores.len() as f64;
686                let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
687                    / (scores.len() - 1) as f64;
688                total_variance += variance * (scores.len() as f64 / total_count as f64);
689            }
690        }
691
692        total_variance
693    }
694}
695
696impl Default for ParameterImportanceAnalyzer {
697    fn default() -> Self {
698        Self::new()
699    }
700}
701
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps all values and produces an unordered parameter.
    #[test]
    fn test_categorical_parameter() {
        let candidates = vec!["svm".into(), "random_forest".into(), "neural_net".into()];
        let algorithm = CategoricalParameter::new("algorithm".to_string(), candidates);

        assert!(!algorithm.ordered);
        assert_eq!(algorithm.values.len(), 3);
    }

    /// The middle value of an ordered parameter has a neighbor on each side.
    #[test]
    fn test_ordered_categorical_parameter() {
        let levels = vec!["low".into(), "medium".into(), "high".into()];
        let complexity = CategoricalParameter::ordered("complexity".to_string(), levels);

        assert!(complexity.ordered);
        assert_eq!(complexity.get_neighbors(&"medium".into()).len(), 2);
    }

    /// An equality constraint holds for the required value and fails for another one.
    #[test]
    fn test_parameter_constraint() {
        let kernel_must_be_rbf = ParameterConstraint::Equality {
            param: "kernel".to_string(),
            value: "rbf".into(),
            condition_param: "algorithm".to_string(),
            condition_value: "svm".into(),
        };

        let mut candidate = ParameterSet::new();
        candidate.insert("algorithm".to_string(), "svm".into());
        candidate.insert("kernel".to_string(), "rbf".into());
        assert!(kernel_must_be_rbf.is_satisfied(&candidate));

        candidate.insert("kernel".to_string(), "linear".into());
        assert!(!kernel_must_be_rbf.is_satisfied(&candidate));
    }

    /// A conditional parameter follows the value of the parameter it depends on.
    #[test]
    fn test_conditional_parameter() {
        let kernel = ConditionalParameter::new(
            CategoricalParameter::new(
                "kernel".to_string(),
                vec!["linear".into(), "rbf".into(), "poly".into()],
            ),
            vec![("algorithm".to_string(), "svm".into())],
        );

        let mut current = ParameterSet::new();

        current.insert("algorithm".to_string(), "svm".into());
        assert!(kernel.is_active(&current));

        current.insert("algorithm".to_string(), "random_forest".into());
        assert!(!kernel.is_active(&current));
    }

    /// Sampling always yields the base parameter, and the conditional one
    /// whenever its condition is met.
    #[test]
    fn test_parameter_space_sampling() {
        let mut space = ParameterSpace::new();

        space.add_categorical_parameter(CategoricalParameter::new(
            "algorithm".to_string(),
            vec!["svm".into(), "random_forest".into()],
        ));

        space.add_conditional_parameter(ConditionalParameter::new(
            CategoricalParameter::new("kernel".to_string(), vec!["linear".into(), "rbf".into()]),
            vec![("algorithm".to_string(), "svm".into())],
        ));

        let mut rng = scirs2_core::random::thread_rng();
        let sampled = space.sample(&mut rng).unwrap();

        assert!(sampled.contains_key("algorithm"));

        let algorithm = sampled.get("algorithm").unwrap();
        if algorithm == &"svm".into() {
            assert!(sampled.contains_key("kernel"));
        }
    }
}