1use crate::grid_search::{ParameterSet, ParameterValue};
7use scirs2_core::rand_prelude::IndexedRandom;
8use scirs2_core::random::prelude::*;
9use sklears_core::error::{Result, SklearsError};
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
/// A parameter whose value is drawn from a fixed list of candidate values.
#[derive(Debug, Clone)]
pub struct CategoricalParameter {
    /// Parameter name; `ParameterSpace::add_categorical_parameter` uses it as the map key.
    pub name: String,
    /// Candidate values that `sample` chooses from.
    pub values: Vec<ParameterValue>,
    /// Whether the position of a value inside `values` is meaningful.
    /// `get_neighbors` only reports adjacent values when this is `true`.
    pub ordered: bool,
    /// Optional default value; within this file it is only written by `with_default`.
    pub default: Option<ParameterValue>,
    /// Optional free-form description; within this file it is only written by
    /// `with_description`.
    pub description: Option<String>,
}
27
28impl CategoricalParameter {
29 pub fn new(name: String, values: Vec<ParameterValue>) -> Self {
31 Self {
32 name,
33 values,
34 ordered: false,
35 default: None,
36 description: None,
37 }
38 }
39
40 pub fn ordered(name: String, values: Vec<ParameterValue>) -> Self {
42 Self {
43 name,
44 values,
45 ordered: true,
46 default: None,
47 description: None,
48 }
49 }
50
51 pub fn with_default(mut self, default: ParameterValue) -> Self {
53 self.default = Some(default);
54 self
55 }
56
57 pub fn with_description(mut self, description: String) -> Self {
59 self.description = Some(description);
60 self
61 }
62
63 pub fn sample(&self, rng: &mut impl Rng) -> ParameterValue {
65 self.values
66 .choose(rng)
67 .expect("operation should succeed")
68 .clone()
69 }
70
71 pub fn get_index(&self, value: &ParameterValue) -> Option<usize> {
73 self.values.iter().position(|v| v == value)
74 }
75
76 pub fn get_neighbors(&self, value: &ParameterValue) -> Vec<ParameterValue> {
78 if !self.ordered {
79 return vec![];
80 }
81
82 if let Some(idx) = self.get_index(value) {
83 let mut neighbors = Vec::new();
84 if idx > 0 {
85 neighbors.push(self.values[idx - 1].clone());
86 }
87 if idx + 1 < self.values.len() {
88 neighbors.push(self.values[idx + 1].clone());
89 }
90 neighbors
91 } else {
92 vec![]
93 }
94 }
95}
96
/// A rule that a sampled parameter set must satisfy (checked by `is_satisfied`).
#[derive(Clone)]
pub enum ParameterConstraint {
    /// When `condition_param` equals `condition_value`, `param` must equal `value`.
    Equality {
        /// Name of the constrained parameter.
        param: String,

        /// Value the constrained parameter must have.
        value: ParameterValue,

        /// Name of the parameter that triggers the constraint.
        condition_param: String,

        /// Trigger value of `condition_param`.
        condition_value: ParameterValue,
    },
    /// When `condition_param` equals `condition_value`, `param` must differ from `value`.
    Inequality {
        /// Name of the constrained parameter.
        param: String,
        /// Value the constrained parameter must not have.
        value: ParameterValue,
        /// Name of the parameter that triggers the constraint.
        condition_param: String,
        /// Trigger value of `condition_param`.
        condition_value: ParameterValue,
    },
    /// When `condition_param` equals `condition_value`, `param` must lie within
    /// `min_value..=max_value` (both bounds inclusive).
    Range {
        /// Name of the constrained parameter.
        param: String,
        /// Inclusive lower bound.
        min_value: ParameterValue,
        /// Inclusive upper bound.
        max_value: ParameterValue,
        /// Name of the parameter that triggers the constraint.
        condition_param: String,
        /// Trigger value of `condition_param`.
        condition_value: ParameterValue,
    },
    /// `param1 == value1` and `param2 == value2` must not hold at the same time.
    MutualExclusion {
        /// First parameter name.
        param1: String,
        /// Value of the first parameter that takes part in the exclusion.
        value1: ParameterValue,
        /// Second parameter name.
        param2: String,
        /// Value of the second parameter that takes part in the exclusion.
        value2: ParameterValue,
    },
    /// Arbitrary user-supplied predicate over the whole parameter set.
    Custom {
        /// Label for the constraint; it is what the `Debug` output shows.
        name: String,
        /// Predicate returning `true` when the parameter set is acceptable.
        constraint_fn: Arc<dyn Fn(&ParameterSet) -> bool + Send + Sync>,
    },
}
138
139impl std::fmt::Debug for ParameterConstraint {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 ParameterConstraint::Equality {
143 param,
144 value,
145 condition_param,
146 condition_value,
147 } => {
148 write!(f, "Equality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}",
149 param, value, condition_param, condition_value)
150 }
151 ParameterConstraint::Inequality {
152 param,
153 value,
154 condition_param,
155 condition_value,
156 } => {
157 write!(f, "Inequality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}",
158 param, value, condition_param, condition_value)
159 }
160 ParameterConstraint::Range {
161 param,
162 min_value,
163 max_value,
164 condition_param,
165 condition_value,
166 } => {
167 write!(f, "Range {{ param: {:?}, min_value: {:?}, max_value: {:?}, condition_param: {:?}, condition_value: {:?} }}",
168 param, min_value, max_value, condition_param, condition_value)
169 }
170 ParameterConstraint::MutualExclusion {
171 param1,
172 value1,
173 param2,
174 value2,
175 } => {
176 write!(
177 f,
178 "MutualExclusion {{ param1: {:?}, value1: {:?}, param2: {:?}, value2: {:?} }}",
179 param1, value1, param2, value2
180 )
181 }
182 ParameterConstraint::Custom { name, .. } => {
183 write!(f, "Custom {{ name: {:?}, constraint_fn: <closure> }}", name)
184 }
185 }
186 }
187}
188
189impl ParameterConstraint {
190 pub fn is_satisfied(&self, params: &ParameterSet) -> bool {
192 match self {
193 ParameterConstraint::Equality {
194 param,
195 value,
196 condition_param,
197 condition_value,
198 } => {
199 if let (Some(param_val), Some(condition_val)) =
200 (params.get(param), params.get(condition_param))
201 {
202 if condition_val == condition_value {
203 param_val == value
204 } else {
205 true }
207 } else {
208 false }
210 }
211 ParameterConstraint::Inequality {
212 param,
213 value,
214 condition_param,
215 condition_value,
216 } => {
217 if let (Some(param_val), Some(condition_val)) =
218 (params.get(param), params.get(condition_param))
219 {
220 if condition_val == condition_value {
221 param_val != value
222 } else {
223 true }
225 } else {
226 false }
228 }
229 ParameterConstraint::Range {
230 param,
231 min_value,
232 max_value,
233 condition_param,
234 condition_value,
235 } => {
236 if let (Some(param_val), Some(condition_val)) =
237 (params.get(param), params.get(condition_param))
238 {
239 if condition_val == condition_value {
240 self.is_in_range(param_val, min_value, max_value)
241 } else {
242 true }
244 } else {
245 false }
247 }
248 ParameterConstraint::MutualExclusion {
249 param1,
250 value1,
251 param2,
252 value2,
253 } => {
254 if let (Some(param1_val), Some(param2_val)) =
255 (params.get(param1), params.get(param2))
256 {
257 !(param1_val == value1 && param2_val == value2)
258 } else {
259 true }
261 }
262 ParameterConstraint::Custom { constraint_fn, .. } => constraint_fn(params),
263 }
264 }
265
266 fn is_in_range(
267 &self,
268 value: &ParameterValue,
269 min_value: &ParameterValue,
270 max_value: &ParameterValue,
271 ) -> bool {
272 match (value, min_value, max_value) {
273 (ParameterValue::Int(v), ParameterValue::Int(min), ParameterValue::Int(max)) => {
274 v >= min && v <= max
275 }
276 (ParameterValue::Float(v), ParameterValue::Float(min), ParameterValue::Float(max)) => {
277 v >= min && v <= max
278 }
279 _ => false, }
281 }
282}
283
/// A categorical parameter that is only sampled when conditions on other parameters hold.
#[derive(Debug, Clone)]
pub struct ConditionalParameter {
    /// The parameter that is sampled when the conditions hold.
    pub parameter: CategoricalParameter,
    /// `(parameter name, expected value)` pairs checked by `is_active`.
    pub conditions: Vec<(String, ParameterValue)>,
    /// `true`: every condition must hold; `false`: one matching condition is enough.
    pub require_all_conditions: bool,
}
294
295impl ConditionalParameter {
296 pub fn new(parameter: CategoricalParameter, conditions: Vec<(String, ParameterValue)>) -> Self {
298 Self {
299 parameter,
300 conditions,
301 require_all_conditions: true,
302 }
303 }
304
305 pub fn require_all_conditions(mut self, require_all: bool) -> Self {
307 self.require_all_conditions = require_all;
308 self
309 }
310
311 pub fn is_active(&self, params: &ParameterSet) -> bool {
313 if self.conditions.is_empty() {
314 return true;
315 }
316
317 let satisfied_conditions = self
318 .conditions
319 .iter()
320 .filter(|(param_name, expected_value)| {
321 params
322 .get(param_name)
323 .map(|value| value == expected_value)
324 .unwrap_or(false)
325 })
326 .count();
327
328 if self.require_all_conditions {
329 satisfied_conditions == self.conditions.len()
330 } else {
331 satisfied_conditions > 0
332 }
333 }
334
335 pub fn sample_if_active(
337 &self,
338 params: &ParameterSet,
339 rng: &mut impl Rng,
340 ) -> Option<ParameterValue> {
341 if self.is_active(params) {
342 Some(self.parameter.sample(rng))
343 } else {
344 None
345 }
346 }
347}
348
/// A search space made of categorical parameters, conditional parameters and constraints.
#[derive(Debug, Clone)]
pub struct ParameterSpace {
    /// Unconditional parameters, keyed by parameter name.
    pub categorical_params: HashMap<String, CategoricalParameter>,
    /// Conditional parameters, keyed by the name of the wrapped parameter.
    pub conditional_params: HashMap<String, ConditionalParameter>,
    /// Constraints that every sampled parameter set must satisfy.
    pub constraints: Vec<ParameterConstraint>,
    /// For each conditional parameter name, the names of the parameters its conditions
    /// refer to (filled in by `add_conditional_parameter`).
    pub dependencies: HashMap<String, HashSet<String>>,
}
361
362impl ParameterSpace {
363 pub fn new() -> Self {
365 Self {
366 categorical_params: HashMap::new(),
367 conditional_params: HashMap::new(),
368 constraints: Vec::new(),
369 dependencies: HashMap::new(),
370 }
371 }
372
373 pub fn add_categorical_parameter(&mut self, param: CategoricalParameter) {
375 self.categorical_params.insert(param.name.clone(), param);
376 }
377
378 pub fn add_conditional_parameter(&mut self, param: ConditionalParameter) {
380 for (dep_param, _) in ¶m.conditions {
382 self.dependencies
383 .entry(param.parameter.name.clone())
384 .or_default()
385 .insert(dep_param.clone());
386 }
387 self.conditional_params
388 .insert(param.parameter.name.clone(), param);
389 }
390
391 pub fn add_constraint(&mut self, constraint: ParameterConstraint) {
393 self.constraints.push(constraint);
394 }
395
396 pub fn sample(&self, rng: &mut impl Rng) -> Result<ParameterSet> {
398 let mut params = ParameterSet::new();
399 let mut attempts = 0;
400 const MAX_ATTEMPTS: usize = 1000;
401
402 while attempts < MAX_ATTEMPTS {
403 params.clear();
404
405 for (name, param) in &self.categorical_params {
407 params.insert(name.clone(), param.sample(rng));
408 }
409
410 for (name, conditional_param) in &self.conditional_params {
412 if let Some(value) = conditional_param.sample_if_active(¶ms, rng) {
413 params.insert(name.clone(), value);
414 }
415 }
416
417 if self.is_valid_parameter_set(¶ms) {
419 return Ok(params);
420 }
421
422 attempts += 1;
423 }
424
425 Err(SklearsError::InvalidInput(format!(
426 "Failed to sample valid parameter set after {} attempts",
427 MAX_ATTEMPTS
428 )))
429 }
430
431 pub fn is_valid_parameter_set(&self, params: &ParameterSet) -> bool {
433 self.constraints
434 .iter()
435 .all(|constraint| constraint.is_satisfied(params))
436 }
437
438 pub fn get_parameter_names(&self) -> HashSet<String> {
440 let mut names = HashSet::new();
441 names.extend(self.categorical_params.keys().cloned());
442 names.extend(self.conditional_params.keys().cloned());
443 names
444 }
445
446 pub fn get_dependent_parameters(&self, param_name: &str) -> HashSet<String> {
448 self.dependencies
449 .iter()
450 .filter_map(|(dependent, dependencies)| {
451 if dependencies.contains(param_name) {
452 Some(dependent.clone())
453 } else {
454 None
455 }
456 })
457 .collect()
458 }
459
460 pub fn get_parameter_dependencies(&self, param_name: &str) -> HashSet<String> {
462 self.dependencies
463 .get(param_name)
464 .cloned()
465 .unwrap_or_default()
466 }
467
468 pub fn sample_with_importance(
470 &self,
471 rng: &mut impl Rng,
472 importance_weights: &HashMap<String, f64>,
473 ) -> Result<ParameterSet> {
474 let mut params = ParameterSet::new();
475 let mut attempts = 0;
476 const MAX_ATTEMPTS: usize = 1000;
477
478 while attempts < MAX_ATTEMPTS {
479 params.clear();
480
481 let mut sorted_params: Vec<_> = self.categorical_params.keys().collect();
483 sorted_params.sort_by(|a, b| {
484 let weight_a = importance_weights.get(*a).unwrap_or(&1.0);
485 let weight_b = importance_weights.get(*b).unwrap_or(&1.0);
486 weight_b
487 .partial_cmp(weight_a)
488 .expect("operation should succeed")
489 });
490
491 for name in sorted_params {
493 if let Some(param) = self.categorical_params.get(name) {
494 params.insert(name.clone(), param.sample(rng));
495 }
496 }
497
498 for (name, conditional_param) in &self.conditional_params {
500 if let Some(value) = conditional_param.sample_if_active(¶ms, rng) {
501 params.insert(name.clone(), value);
502 }
503 }
504
505 if self.is_valid_parameter_set(¶ms) {
507 return Ok(params);
508 }
509
510 attempts += 1;
511 }
512
513 Err(SklearsError::InvalidInput(format!(
514 "Failed to sample valid parameter set after {} attempts",
515 MAX_ATTEMPTS
516 )))
517 }
518
519 pub fn add_float_param(&mut self, name: &str, min: f64, max: f64) {
521 let mut values = Vec::new();
523 let n_values = 10; for i in 0..n_values {
525 let ratio = i as f64 / (n_values - 1) as f64;
526 let value = min + ratio * (max - min);
527 values.push(ParameterValue::Float(value));
528 }
529
530 let param = CategoricalParameter::new(name.to_string(), values);
531 self.add_categorical_parameter(param);
532 }
533
534 pub fn add_int_param(&mut self, name: &str, min: i64, max: i64) {
536 let mut values = Vec::new();
537 let range = max - min + 1;
538 let n_values = if range <= 20 {
539 range as usize
541 } else {
542 10
544 };
545
546 for i in 0..n_values {
547 let value = if range <= 20 {
548 min + i as i64
549 } else {
550 let ratio = i as f64 / (n_values - 1) as f64;
551 min + (ratio * (max - min) as f64) as i64
552 };
553 values.push(ParameterValue::Int(value));
554 }
555
556 let param = CategoricalParameter::new(name.to_string(), values);
557 self.add_categorical_parameter(param);
558 }
559
560 pub fn add_categorical_param(&mut self, name: &str, values: Vec<&str>) {
562 let param_values = values
563 .into_iter()
564 .map(|s| ParameterValue::String(s.to_string()))
565 .collect();
566
567 let param = CategoricalParameter::new(name.to_string(), param_values);
568 self.add_categorical_parameter(param);
569 }
570
571 pub fn add_boolean_param(&mut self, name: &str) {
573 let values = vec![ParameterValue::Bool(false), ParameterValue::Bool(true)];
574 let param = CategoricalParameter::new(name.to_string(), values);
575 self.add_categorical_parameter(param);
576 }
577
578 pub fn auto_detect_ranges(parameter_sets: &[ParameterSet]) -> Result<Self> {
580 let mut space = ParameterSpace::new();
581
582 if parameter_sets.is_empty() {
583 return Ok(space);
584 }
585
586 let mut all_param_names = HashSet::new();
588 for param_set in parameter_sets {
589 all_param_names.extend(param_set.keys().cloned());
590 }
591
592 for param_name in all_param_names {
594 let mut values = HashSet::new();
595 for param_set in parameter_sets {
596 if let Some(value) = param_set.get(¶m_name) {
597 values.insert(value.clone());
598 }
599 }
600
601 if !values.is_empty() {
602 let values_vec: Vec<ParameterValue> = values.into_iter().collect();
603 let categorical_param = CategoricalParameter::new(param_name, values_vec);
604 space.add_categorical_parameter(categorical_param);
605 }
606 }
607
608 Ok(space)
609 }
610}
611
612impl Default for ParameterSpace {
613 fn default() -> Self {
614 Self::new()
615 }
616}
617
/// Collects `(parameter set, score)` evaluations and derives a per-parameter importance
/// score from them.
#[derive(Debug)]
pub struct ParameterImportanceAnalyzer {
    /// Recorded evaluations, in insertion order.
    evaluations: Vec<(ParameterSet, f64)>,
}
624
625impl ParameterImportanceAnalyzer {
626 pub fn new() -> Self {
628 Self {
629 evaluations: Vec::new(),
630 }
631 }
632
633 pub fn add_evaluation(&mut self, params: ParameterSet, score: f64) {
635 self.evaluations.push((params, score));
636 }
637
638 pub fn calculate_importance(&self) -> HashMap<String, f64> {
640 let mut importance = HashMap::new();
641
642 if self.evaluations.len() < 2 {
643 return importance;
644 }
645
646 let mut all_params = HashSet::new();
648 for (params, _) in &self.evaluations {
649 all_params.extend(params.keys().cloned());
650 }
651
652 for param_name in all_params {
654 let variance = self.calculate_parameter_variance(¶m_name);
655 importance.insert(param_name, variance);
656 }
657
658 let max_importance = importance.values().fold(0.0f64, |a, &b| a.max(b));
660 if max_importance > 0.0 {
661 for value in importance.values_mut() {
662 *value /= max_importance;
663 }
664 }
665
666 importance
667 }
668
669 fn calculate_parameter_variance(&self, param_name: &str) -> f64 {
670 let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
672
673 for (params, score) in &self.evaluations {
674 if let Some(param_value) = params.get(param_name) {
675 let key = format!("{:?}", param_value);
676 groups.entry(key).or_default().push(*score);
677 }
678 }
679
680 if groups.len() < 2 {
681 return 0.0;
682 }
683
684 let mut total_variance = 0.0;
686 let total_count = self.evaluations.len();
687
688 for scores in groups.values() {
689 if scores.len() > 1 {
690 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
691 let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
692 / (scores.len() - 1) as f64;
693 total_variance += variance * (scores.len() as f64 / total_count as f64);
694 }
695 }
696
697 total_variance
698 }
699}
700
701impl Default for ParameterImportanceAnalyzer {
702 fn default() -> Self {
703 Self::new()
704 }
705}
706
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the candidates and leaves the parameter unordered.
    #[test]
    fn test_categorical_parameter() {
        let candidates = vec!["svm".into(), "random_forest".into(), "neural_net".into()];
        let algorithm = CategoricalParameter::new("algorithm".to_string(), candidates);

        assert_eq!(algorithm.values.len(), 3);
        assert!(!algorithm.ordered);
    }

    /// An interior value of an ordered parameter has two neighbors.
    #[test]
    fn test_ordered_categorical_parameter() {
        let levels = vec!["low".into(), "medium".into(), "high".into()];
        let complexity = CategoricalParameter::ordered("complexity".to_string(), levels);

        assert!(complexity.ordered);
        let adjacent = complexity.get_neighbors(&"medium".into());
        assert_eq!(adjacent.len(), 2);
    }

    /// An equality constraint holds for the required value and fails for another one
    /// while its condition is triggered.
    #[test]
    fn test_parameter_constraint() {
        let rbf_for_svm = ParameterConstraint::Equality {
            param: "kernel".to_string(),
            value: "rbf".into(),
            condition_param: "algorithm".to_string(),
            condition_value: "svm".into(),
        };

        let mut assignment = ParameterSet::new();
        assignment.insert("algorithm".to_string(), "svm".into());
        assignment.insert("kernel".to_string(), "rbf".into());

        assert!(rbf_for_svm.is_satisfied(&assignment));

        assignment.insert("kernel".to_string(), "linear".into());
        assert!(!rbf_for_svm.is_satisfied(&assignment));
    }

    /// A conditional parameter is active only while its condition matches.
    #[test]
    fn test_conditional_parameter() {
        let kernel = CategoricalParameter::new(
            "kernel".to_string(),
            vec!["linear".into(), "rbf".into(), "poly".into()],
        );
        let kernel_if_svm =
            ConditionalParameter::new(kernel, vec![("algorithm".to_string(), "svm".into())]);

        let mut assignment = ParameterSet::new();
        assignment.insert("algorithm".to_string(), "svm".into());
        assert!(kernel_if_svm.is_active(&assignment));

        assignment.insert("algorithm".to_string(), "random_forest".into());
        assert!(!kernel_if_svm.is_active(&assignment));
    }

    /// Sampling always yields `algorithm`, and yields `kernel` whenever `algorithm`
    /// came out as `svm`.
    #[test]
    fn test_parameter_space_sampling() {
        let mut space = ParameterSpace::new();

        space.add_categorical_parameter(CategoricalParameter::new(
            "algorithm".to_string(),
            vec!["svm".into(), "random_forest".into()],
        ));

        let kernel =
            CategoricalParameter::new("kernel".to_string(), vec!["linear".into(), "rbf".into()]);
        space.add_conditional_parameter(ConditionalParameter::new(
            kernel,
            vec![("algorithm".to_string(), "svm".into())],
        ));

        let mut rng = scirs2_core::random::thread_rng();
        let sampled = space.sample(&mut rng).expect("operation should succeed");

        assert!(sampled.contains_key("algorithm"));

        let svm: ParameterValue = "svm".into();
        let chosen = sampled.get("algorithm").expect("operation should succeed");
        if chosen == &svm {
            assert!(sampled.contains_key("kernel"));
        }
    }
}