1use crate::grid_search::{ParameterSet, ParameterValue};
7use scirs2_core::rand_prelude::IndexedRandom;
8use scirs2_core::random::prelude::*;
9use sklears_core::error::{Result, SklearsError};
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
13#[derive(Debug, Clone)]
15pub struct CategoricalParameter {
16 pub name: String,
18 pub values: Vec<ParameterValue>,
20 pub ordered: bool,
22 pub default: Option<ParameterValue>,
24 pub description: Option<String>,
26}
27
28impl CategoricalParameter {
29 pub fn new(name: String, values: Vec<ParameterValue>) -> Self {
31 Self {
32 name,
33 values,
34 ordered: false,
35 default: None,
36 description: None,
37 }
38 }
39
40 pub fn ordered(name: String, values: Vec<ParameterValue>) -> Self {
42 Self {
43 name,
44 values,
45 ordered: true,
46 default: None,
47 description: None,
48 }
49 }
50
51 pub fn with_default(mut self, default: ParameterValue) -> Self {
53 self.default = Some(default);
54 self
55 }
56
57 pub fn with_description(mut self, description: String) -> Self {
59 self.description = Some(description);
60 self
61 }
62
63 pub fn sample(&self, rng: &mut impl Rng) -> ParameterValue {
65 self.values.choose(rng).unwrap().clone()
66 }
67
68 pub fn get_index(&self, value: &ParameterValue) -> Option<usize> {
70 self.values.iter().position(|v| v == value)
71 }
72
73 pub fn get_neighbors(&self, value: &ParameterValue) -> Vec<ParameterValue> {
75 if !self.ordered {
76 return vec![];
77 }
78
79 if let Some(idx) = self.get_index(value) {
80 let mut neighbors = Vec::new();
81 if idx > 0 {
82 neighbors.push(self.values[idx - 1].clone());
83 }
84 if idx + 1 < self.values.len() {
85 neighbors.push(self.values[idx + 1].clone());
86 }
87 neighbors
88 } else {
89 vec![]
90 }
91 }
92}
93
94#[derive(Clone)]
96pub enum ParameterConstraint {
97 Equality {
99 param: String,
100
101 value: ParameterValue,
102
103 condition_param: String,
104
105 condition_value: ParameterValue,
106 },
107 Inequality {
109 param: String,
110 value: ParameterValue,
111 condition_param: String,
112 condition_value: ParameterValue,
113 },
114 Range {
116 param: String,
117 min_value: ParameterValue,
118 max_value: ParameterValue,
119 condition_param: String,
120 condition_value: ParameterValue,
121 },
122 MutualExclusion {
124 param1: String,
125 value1: ParameterValue,
126 param2: String,
127 value2: ParameterValue,
128 },
129 Custom {
131 name: String,
132 constraint_fn: Arc<dyn Fn(&ParameterSet) -> bool + Send + Sync>,
133 },
134}
135
136impl std::fmt::Debug for ParameterConstraint {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 match self {
139 ParameterConstraint::Equality {
140 param,
141 value,
142 condition_param,
143 condition_value,
144 } => {
145 write!(f, "Equality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}",
146 param, value, condition_param, condition_value)
147 }
148 ParameterConstraint::Inequality {
149 param,
150 value,
151 condition_param,
152 condition_value,
153 } => {
154 write!(f, "Inequality {{ param: {:?}, value: {:?}, condition_param: {:?}, condition_value: {:?} }}",
155 param, value, condition_param, condition_value)
156 }
157 ParameterConstraint::Range {
158 param,
159 min_value,
160 max_value,
161 condition_param,
162 condition_value,
163 } => {
164 write!(f, "Range {{ param: {:?}, min_value: {:?}, max_value: {:?}, condition_param: {:?}, condition_value: {:?} }}",
165 param, min_value, max_value, condition_param, condition_value)
166 }
167 ParameterConstraint::MutualExclusion {
168 param1,
169 value1,
170 param2,
171 value2,
172 } => {
173 write!(
174 f,
175 "MutualExclusion {{ param1: {:?}, value1: {:?}, param2: {:?}, value2: {:?} }}",
176 param1, value1, param2, value2
177 )
178 }
179 ParameterConstraint::Custom { name, .. } => {
180 write!(f, "Custom {{ name: {:?}, constraint_fn: <closure> }}", name)
181 }
182 }
183 }
184}
185
186impl ParameterConstraint {
187 pub fn is_satisfied(&self, params: &ParameterSet) -> bool {
189 match self {
190 ParameterConstraint::Equality {
191 param,
192 value,
193 condition_param,
194 condition_value,
195 } => {
196 if let (Some(param_val), Some(condition_val)) =
197 (params.get(param), params.get(condition_param))
198 {
199 if condition_val == condition_value {
200 param_val == value
201 } else {
202 true }
204 } else {
205 false }
207 }
208 ParameterConstraint::Inequality {
209 param,
210 value,
211 condition_param,
212 condition_value,
213 } => {
214 if let (Some(param_val), Some(condition_val)) =
215 (params.get(param), params.get(condition_param))
216 {
217 if condition_val == condition_value {
218 param_val != value
219 } else {
220 true }
222 } else {
223 false }
225 }
226 ParameterConstraint::Range {
227 param,
228 min_value,
229 max_value,
230 condition_param,
231 condition_value,
232 } => {
233 if let (Some(param_val), Some(condition_val)) =
234 (params.get(param), params.get(condition_param))
235 {
236 if condition_val == condition_value {
237 self.is_in_range(param_val, min_value, max_value)
238 } else {
239 true }
241 } else {
242 false }
244 }
245 ParameterConstraint::MutualExclusion {
246 param1,
247 value1,
248 param2,
249 value2,
250 } => {
251 if let (Some(param1_val), Some(param2_val)) =
252 (params.get(param1), params.get(param2))
253 {
254 !(param1_val == value1 && param2_val == value2)
255 } else {
256 true }
258 }
259 ParameterConstraint::Custom { constraint_fn, .. } => constraint_fn(params),
260 }
261 }
262
263 fn is_in_range(
264 &self,
265 value: &ParameterValue,
266 min_value: &ParameterValue,
267 max_value: &ParameterValue,
268 ) -> bool {
269 match (value, min_value, max_value) {
270 (ParameterValue::Int(v), ParameterValue::Int(min), ParameterValue::Int(max)) => {
271 v >= min && v <= max
272 }
273 (ParameterValue::Float(v), ParameterValue::Float(min), ParameterValue::Float(max)) => {
274 v >= min && v <= max
275 }
276 _ => false, }
278 }
279}
280
281#[derive(Debug, Clone)]
283pub struct ConditionalParameter {
284 pub parameter: CategoricalParameter,
286 pub conditions: Vec<(String, ParameterValue)>,
288 pub require_all_conditions: bool,
290}
291
292impl ConditionalParameter {
293 pub fn new(parameter: CategoricalParameter, conditions: Vec<(String, ParameterValue)>) -> Self {
295 Self {
296 parameter,
297 conditions,
298 require_all_conditions: true,
299 }
300 }
301
302 pub fn require_all_conditions(mut self, require_all: bool) -> Self {
304 self.require_all_conditions = require_all;
305 self
306 }
307
308 pub fn is_active(&self, params: &ParameterSet) -> bool {
310 if self.conditions.is_empty() {
311 return true;
312 }
313
314 let satisfied_conditions = self
315 .conditions
316 .iter()
317 .filter(|(param_name, expected_value)| {
318 params
319 .get(param_name)
320 .map(|value| value == expected_value)
321 .unwrap_or(false)
322 })
323 .count();
324
325 if self.require_all_conditions {
326 satisfied_conditions == self.conditions.len()
327 } else {
328 satisfied_conditions > 0
329 }
330 }
331
332 pub fn sample_if_active(
334 &self,
335 params: &ParameterSet,
336 rng: &mut impl Rng,
337 ) -> Option<ParameterValue> {
338 if self.is_active(params) {
339 Some(self.parameter.sample(rng))
340 } else {
341 None
342 }
343 }
344}
345
346#[derive(Debug, Clone)]
348pub struct ParameterSpace {
349 pub categorical_params: HashMap<String, CategoricalParameter>,
351 pub conditional_params: HashMap<String, ConditionalParameter>,
353 pub constraints: Vec<ParameterConstraint>,
355 pub dependencies: HashMap<String, HashSet<String>>,
357}
358
359impl ParameterSpace {
360 pub fn new() -> Self {
362 Self {
363 categorical_params: HashMap::new(),
364 conditional_params: HashMap::new(),
365 constraints: Vec::new(),
366 dependencies: HashMap::new(),
367 }
368 }
369
370 pub fn add_categorical_parameter(&mut self, param: CategoricalParameter) {
372 self.categorical_params.insert(param.name.clone(), param);
373 }
374
375 pub fn add_conditional_parameter(&mut self, param: ConditionalParameter) {
377 for (dep_param, _) in ¶m.conditions {
379 self.dependencies
380 .entry(param.parameter.name.clone())
381 .or_default()
382 .insert(dep_param.clone());
383 }
384 self.conditional_params
385 .insert(param.parameter.name.clone(), param);
386 }
387
388 pub fn add_constraint(&mut self, constraint: ParameterConstraint) {
390 self.constraints.push(constraint);
391 }
392
393 pub fn sample(&self, rng: &mut impl Rng) -> Result<ParameterSet> {
395 let mut params = ParameterSet::new();
396 let mut attempts = 0;
397 const MAX_ATTEMPTS: usize = 1000;
398
399 while attempts < MAX_ATTEMPTS {
400 params.clear();
401
402 for (name, param) in &self.categorical_params {
404 params.insert(name.clone(), param.sample(rng));
405 }
406
407 for (name, conditional_param) in &self.conditional_params {
409 if let Some(value) = conditional_param.sample_if_active(¶ms, rng) {
410 params.insert(name.clone(), value);
411 }
412 }
413
414 if self.is_valid_parameter_set(¶ms) {
416 return Ok(params);
417 }
418
419 attempts += 1;
420 }
421
422 Err(SklearsError::InvalidInput(format!(
423 "Failed to sample valid parameter set after {} attempts",
424 MAX_ATTEMPTS
425 )))
426 }
427
428 pub fn is_valid_parameter_set(&self, params: &ParameterSet) -> bool {
430 self.constraints
431 .iter()
432 .all(|constraint| constraint.is_satisfied(params))
433 }
434
435 pub fn get_parameter_names(&self) -> HashSet<String> {
437 let mut names = HashSet::new();
438 names.extend(self.categorical_params.keys().cloned());
439 names.extend(self.conditional_params.keys().cloned());
440 names
441 }
442
443 pub fn get_dependent_parameters(&self, param_name: &str) -> HashSet<String> {
445 self.dependencies
446 .iter()
447 .filter_map(|(dependent, dependencies)| {
448 if dependencies.contains(param_name) {
449 Some(dependent.clone())
450 } else {
451 None
452 }
453 })
454 .collect()
455 }
456
457 pub fn get_parameter_dependencies(&self, param_name: &str) -> HashSet<String> {
459 self.dependencies
460 .get(param_name)
461 .cloned()
462 .unwrap_or_default()
463 }
464
465 pub fn sample_with_importance(
467 &self,
468 rng: &mut impl Rng,
469 importance_weights: &HashMap<String, f64>,
470 ) -> Result<ParameterSet> {
471 let mut params = ParameterSet::new();
472 let mut attempts = 0;
473 const MAX_ATTEMPTS: usize = 1000;
474
475 while attempts < MAX_ATTEMPTS {
476 params.clear();
477
478 let mut sorted_params: Vec<_> = self.categorical_params.keys().collect();
480 sorted_params.sort_by(|a, b| {
481 let weight_a = importance_weights.get(*a).unwrap_or(&1.0);
482 let weight_b = importance_weights.get(*b).unwrap_or(&1.0);
483 weight_b.partial_cmp(weight_a).unwrap()
484 });
485
486 for name in sorted_params {
488 if let Some(param) = self.categorical_params.get(name) {
489 params.insert(name.clone(), param.sample(rng));
490 }
491 }
492
493 for (name, conditional_param) in &self.conditional_params {
495 if let Some(value) = conditional_param.sample_if_active(¶ms, rng) {
496 params.insert(name.clone(), value);
497 }
498 }
499
500 if self.is_valid_parameter_set(¶ms) {
502 return Ok(params);
503 }
504
505 attempts += 1;
506 }
507
508 Err(SklearsError::InvalidInput(format!(
509 "Failed to sample valid parameter set after {} attempts",
510 MAX_ATTEMPTS
511 )))
512 }
513
514 pub fn add_float_param(&mut self, name: &str, min: f64, max: f64) {
516 let mut values = Vec::new();
518 let n_values = 10; for i in 0..n_values {
520 let ratio = i as f64 / (n_values - 1) as f64;
521 let value = min + ratio * (max - min);
522 values.push(ParameterValue::Float(value));
523 }
524
525 let param = CategoricalParameter::new(name.to_string(), values);
526 self.add_categorical_parameter(param);
527 }
528
529 pub fn add_int_param(&mut self, name: &str, min: i64, max: i64) {
531 let mut values = Vec::new();
532 let range = max - min + 1;
533 let n_values = if range <= 20 {
534 range as usize
536 } else {
537 10
539 };
540
541 for i in 0..n_values {
542 let value = if range <= 20 {
543 min + i as i64
544 } else {
545 let ratio = i as f64 / (n_values - 1) as f64;
546 min + (ratio * (max - min) as f64) as i64
547 };
548 values.push(ParameterValue::Int(value));
549 }
550
551 let param = CategoricalParameter::new(name.to_string(), values);
552 self.add_categorical_parameter(param);
553 }
554
555 pub fn add_categorical_param(&mut self, name: &str, values: Vec<&str>) {
557 let param_values = values
558 .into_iter()
559 .map(|s| ParameterValue::String(s.to_string()))
560 .collect();
561
562 let param = CategoricalParameter::new(name.to_string(), param_values);
563 self.add_categorical_parameter(param);
564 }
565
566 pub fn add_boolean_param(&mut self, name: &str) {
568 let values = vec![ParameterValue::Bool(false), ParameterValue::Bool(true)];
569 let param = CategoricalParameter::new(name.to_string(), values);
570 self.add_categorical_parameter(param);
571 }
572
573 pub fn auto_detect_ranges(parameter_sets: &[ParameterSet]) -> Result<Self> {
575 let mut space = ParameterSpace::new();
576
577 if parameter_sets.is_empty() {
578 return Ok(space);
579 }
580
581 let mut all_param_names = HashSet::new();
583 for param_set in parameter_sets {
584 all_param_names.extend(param_set.keys().cloned());
585 }
586
587 for param_name in all_param_names {
589 let mut values = HashSet::new();
590 for param_set in parameter_sets {
591 if let Some(value) = param_set.get(¶m_name) {
592 values.insert(value.clone());
593 }
594 }
595
596 if !values.is_empty() {
597 let values_vec: Vec<ParameterValue> = values.into_iter().collect();
598 let categorical_param = CategoricalParameter::new(param_name, values_vec);
599 space.add_categorical_parameter(categorical_param);
600 }
601 }
602
603 Ok(space)
604 }
605}
606
607impl Default for ParameterSpace {
608 fn default() -> Self {
609 Self::new()
610 }
611}
612
613#[derive(Debug)]
615pub struct ParameterImportanceAnalyzer {
616 evaluations: Vec<(ParameterSet, f64)>,
618}
619
620impl ParameterImportanceAnalyzer {
621 pub fn new() -> Self {
623 Self {
624 evaluations: Vec::new(),
625 }
626 }
627
628 pub fn add_evaluation(&mut self, params: ParameterSet, score: f64) {
630 self.evaluations.push((params, score));
631 }
632
633 pub fn calculate_importance(&self) -> HashMap<String, f64> {
635 let mut importance = HashMap::new();
636
637 if self.evaluations.len() < 2 {
638 return importance;
639 }
640
641 let mut all_params = HashSet::new();
643 for (params, _) in &self.evaluations {
644 all_params.extend(params.keys().cloned());
645 }
646
647 for param_name in all_params {
649 let variance = self.calculate_parameter_variance(¶m_name);
650 importance.insert(param_name, variance);
651 }
652
653 let max_importance = importance.values().fold(0.0f64, |a, &b| a.max(b));
655 if max_importance > 0.0 {
656 for value in importance.values_mut() {
657 *value /= max_importance;
658 }
659 }
660
661 importance
662 }
663
664 fn calculate_parameter_variance(&self, param_name: &str) -> f64 {
665 let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
667
668 for (params, score) in &self.evaluations {
669 if let Some(param_value) = params.get(param_name) {
670 let key = format!("{:?}", param_value);
671 groups.entry(key).or_default().push(*score);
672 }
673 }
674
675 if groups.len() < 2 {
676 return 0.0;
677 }
678
679 let mut total_variance = 0.0;
681 let total_count = self.evaluations.len();
682
683 for scores in groups.values() {
684 if scores.len() > 1 {
685 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
686 let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
687 / (scores.len() - 1) as f64;
688 total_variance += variance * (scores.len() as f64 / total_count as f64);
689 }
690 }
691
692 total_variance
693 }
694}
695
696impl Default for ParameterImportanceAnalyzer {
697 fn default() -> Self {
698 Self::new()
699 }
700}
701
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ParameterValue` from a string literal.
    fn val(s: &'static str) -> ParameterValue {
        s.into()
    }

    /// A set holding `algorithm = svm` plus the given `kernel`.
    fn svm_with_kernel(kernel: &'static str) -> ParameterSet {
        let mut params = ParameterSet::new();
        params.insert("algorithm".to_string(), val("svm"));
        params.insert("kernel".to_string(), val(kernel));
        params
    }

    #[test]
    fn test_categorical_parameter() {
        let candidates = vec![val("svm"), val("random_forest"), val("neural_net")];
        let param = CategoricalParameter::new("algorithm".to_string(), candidates);

        assert!(!param.ordered);
        assert_eq!(param.values.len(), 3);
    }

    #[test]
    fn test_ordered_categorical_parameter() {
        let levels = vec![val("low"), val("medium"), val("high")];
        let param = CategoricalParameter::ordered("complexity".to_string(), levels);

        assert!(param.ordered);
        assert_eq!(param.get_neighbors(&val("medium")).len(), 2);
    }

    #[test]
    fn test_parameter_constraint() {
        let rule = ParameterConstraint::Equality {
            param: "kernel".to_string(),
            value: val("rbf"),
            condition_param: "algorithm".to_string(),
            condition_value: val("svm"),
        };

        assert!(rule.is_satisfied(&svm_with_kernel("rbf")));
        assert!(!rule.is_satisfied(&svm_with_kernel("linear")));
    }

    #[test]
    fn test_conditional_parameter() {
        let kernel = CategoricalParameter::new(
            "kernel".to_string(),
            vec![val("linear"), val("rbf"), val("poly")],
        );
        let gated = ConditionalParameter::new(kernel, vec![("algorithm".to_string(), val("svm"))]);

        let mut params = ParameterSet::new();
        params.insert("algorithm".to_string(), val("svm"));
        assert!(gated.is_active(&params));

        params.insert("algorithm".to_string(), val("random_forest"));
        assert!(!gated.is_active(&params));
    }

    #[test]
    fn test_parameter_space_sampling() {
        let mut space = ParameterSpace::new();
        space.add_categorical_parameter(CategoricalParameter::new(
            "algorithm".to_string(),
            vec![val("svm"), val("random_forest")],
        ));
        space.add_conditional_parameter(ConditionalParameter::new(
            CategoricalParameter::new("kernel".to_string(), vec![val("linear"), val("rbf")]),
            vec![("algorithm".to_string(), val("svm"))],
        ));

        let mut rng = scirs2_core::random::thread_rng();
        let sampled = space.sample(&mut rng).unwrap();

        assert!(sampled.contains_key("algorithm"));
        if sampled.get("algorithm") == Some(&val("svm")) {
            assert!(sampled.contains_key("kernel"));
        }
    }
}
792}