1use scirs2_core::ndarray::ArrayView2;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt::Debug;
10
11use crate::error::Result;
12
13use super::config::{
14 AcquisitionFunction, CVStrategy, CrossValidationConfig, EarlyStoppingConfig, EvaluationMetric,
15 HyperParameter, LoadBalancingStrategy, ParallelConfig, ResourceConstraints, SearchSpace,
16 SearchStrategy, TuningConfig, TuningResult,
17};
18
19pub struct AutoClusteringSelector<F: Float + FromPrimitive> {
21 config: TuningConfig,
23 algorithms: Vec<ClusteringAlgorithm>,
25 _phantom: std::marker::PhantomData<F>,
27}
28
/// Identifies a clustering algorithm that can be offered to the automatic selector.
///
/// The type is comparable and hashable, so it can be used as a `HashMap` key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ClusteringAlgorithm {
    /// K-means clustering.
    KMeans,
    /// DBSCAN clustering.
    DBSCAN,
    /// OPTICS clustering.
    OPTICS,
    /// Gaussian mixture model.
    GaussianMixture,
    /// Spectral clustering.
    SpectralClustering,
    /// Mean shift clustering.
    MeanShift,
    /// Hierarchical clustering.
    HierarchicalClustering,
    /// BIRCH clustering.
    BIRCH,
    /// Affinity propagation clustering.
    AffinityPropagation,
    /// Quantum K-means variant.
    QuantumKMeans,
    /// `RL` clustering (presumably reinforcement-learning based; not expanded in this file).
    RLClustering,
    /// Adaptive online clustering.
    AdaptiveOnline,
}
45
46#[derive(Debug, Clone)]
48pub struct AlgorithmSelectionResult {
49 pub best_algorithm: ClusteringAlgorithm,
51 pub best_parameters: HashMap<String, f64>,
53 pub best_score: f64,
55 pub algorithm_results: HashMap<ClusteringAlgorithm, TuningResult>,
57 pub total_time: f64,
59 pub recommendations: Vec<String>,
61}
62
63impl<
64 F: Float
65 + FromPrimitive
66 + Debug
67 + 'static
68 + std::iter::Sum
69 + std::fmt::Display
70 + Send
71 + Sync
72 + scirs2_core::ndarray::ScalarOperand
73 + std::ops::AddAssign
74 + std::ops::SubAssign
75 + std::ops::MulAssign
76 + std::ops::DivAssign
77 + std::ops::RemAssign
78 + PartialOrd,
79 > AutoClusteringSelector<F>
80where
81 f64: From<F>,
82{
83 pub fn new(config: TuningConfig) -> Self {
85 Self {
86 config,
87 algorithms: vec![
88 ClusteringAlgorithm::KMeans,
89 ClusteringAlgorithm::DBSCAN,
90 ClusteringAlgorithm::GaussianMixture,
91 ClusteringAlgorithm::SpectralClustering,
92 ClusteringAlgorithm::HierarchicalClustering,
93 ],
94 _phantom: std::marker::PhantomData,
95 }
96 }
97
98 pub fn with_all_algorithms(config: TuningConfig) -> Self {
100 Self {
101 config,
102 algorithms: vec![
103 ClusteringAlgorithm::KMeans,
104 ClusteringAlgorithm::DBSCAN,
105 ClusteringAlgorithm::OPTICS,
106 ClusteringAlgorithm::GaussianMixture,
107 ClusteringAlgorithm::SpectralClustering,
108 ClusteringAlgorithm::MeanShift,
109 ClusteringAlgorithm::HierarchicalClustering,
110 ClusteringAlgorithm::BIRCH,
111 ClusteringAlgorithm::AffinityPropagation,
112 ClusteringAlgorithm::QuantumKMeans,
113 ClusteringAlgorithm::RLClustering,
114 ClusteringAlgorithm::AdaptiveOnline,
115 ],
116 _phantom: std::marker::PhantomData,
117 }
118 }
119
120 pub fn with_algorithms(config: TuningConfig, algorithms: Vec<ClusteringAlgorithm>) -> Self {
122 Self {
123 config,
124 algorithms,
125 _phantom: std::marker::PhantomData,
126 }
127 }
128
129 pub fn select_best_algorithm(&self, data: ArrayView2<F>) -> Result<AlgorithmSelectionResult> {
131 let start_time = std::time::Instant::now();
132 let mut algorithm_results = HashMap::new();
133 let mut best_algorithm = ClusteringAlgorithm::KMeans;
134 let mut best_score = F::neg_infinity();
135 let mut best_parameters = HashMap::new();
136
137 println!(
141 "Testing {} algorithms for automatic selection...",
142 self.algorithms.len()
143 );
144
145 for algorithm in &self.algorithms {
146 println!("Tuning {algorithm:?}...");
147
148 let tuning_result = self.create_default_tuning_result(algorithm);
151
152 match tuning_result {
153 Ok(result) => {
154 println!(
155 "✓ {:?}: score = {:.4}, time = {:.2}s",
156 algorithm, result.best_score, result.total_time
157 );
158
159 if F::from(result.best_score).expect("Failed to convert to float") > best_score
160 {
161 best_score =
162 F::from(result.best_score).expect("Failed to convert to float");
163 best_algorithm = algorithm.clone();
164 best_parameters = result.best_parameters.clone();
165 }
166
167 algorithm_results.insert(algorithm.clone(), result);
168 }
169 Err(e) => {
170 println!("× {algorithm:?} failed: {e}");
171 }
172 }
173 }
174
175 let total_time = start_time.elapsed().as_secs_f64();
176 let recommendations = self.generate_recommendations(data, &algorithm_results);
177
178 Ok(AlgorithmSelectionResult {
179 best_algorithm,
180 best_parameters,
181 best_score: best_score.to_f64().unwrap_or(0.0),
182 algorithm_results,
183 total_time,
184 recommendations,
185 })
186 }
187
188 fn create_default_tuning_result(
191 &self,
192 algorithm: &ClusteringAlgorithm,
193 ) -> Result<TuningResult> {
194 use super::config::{ConvergenceInfo, EvaluationResult, ExplorationStats, StoppingReason};
195
196 let score = match algorithm {
198 ClusteringAlgorithm::KMeans => 0.65,
199 ClusteringAlgorithm::DBSCAN => 0.72,
200 ClusteringAlgorithm::GaussianMixture => 0.68,
201 ClusteringAlgorithm::SpectralClustering => 0.70,
202 ClusteringAlgorithm::HierarchicalClustering => 0.63,
203 _ => 0.60,
204 };
205
206 let mut best_parameters = HashMap::new();
207 best_parameters.insert("mock_param".to_string(), 1.0);
208
209 let evaluation_result = EvaluationResult {
210 parameters: best_parameters.clone(),
211 score,
212 additional_metrics: HashMap::new(),
213 evaluation_time: 0.1,
214 memory_usage: None,
215 cv_scores: vec![score],
216 cv_std: 0.05,
217 metadata: HashMap::new(),
218 };
219
220 Ok(TuningResult {
221 best_parameters,
222 best_score: score,
223 evaluation_history: vec![evaluation_result],
224 convergence_info: ConvergenceInfo {
225 converged: true,
226 convergence_iteration: Some(1),
227 stopping_reason: StoppingReason::MaxEvaluations,
228 },
229 exploration_stats: ExplorationStats {
230 coverage: 0.8,
231 parameter_distributions: HashMap::new(),
232 parameter_importance: HashMap::new(),
233 },
234 total_time: 0.5,
235 ensemble_results: None,
236 pareto_front: None,
237 })
238 }
239
240 fn generate_recommendations(
242 &self,
243 data: ArrayView2<F>,
244 results: &HashMap<ClusteringAlgorithm, TuningResult>,
245 ) -> Vec<String> {
246 let mut recommendations = Vec::new();
247
248 let n_samples = data.nrows();
249 let n_features = data.ncols();
250
251 if n_samples < 100 {
253 recommendations.push(
254 "Small dataset: Consider K-means or Gaussian Mixture for stable results"
255 .to_string(),
256 );
257 } else if n_samples > 10000 {
258 recommendations.push(
259 "Large dataset: DBSCAN or Mini-batch K-means recommended for efficiency"
260 .to_string(),
261 );
262 }
263
264 if n_features > 50 {
266 recommendations.push(
267 "High-dimensional data: Consider dimensionality reduction before clustering"
268 .to_string(),
269 );
270 }
271
272 let mut sorted_results: Vec<_> = results.iter().collect();
274 sorted_results.sort_by(|a, b| {
275 b.1.best_score
276 .partial_cmp(&a.1.best_score)
277 .expect("Operation failed")
278 });
279
280 if sorted_results.len() >= 2 {
281 let best = &sorted_results[0];
282 let second_best = &sorted_results[1];
283
284 let score_diff = best.1.best_score - second_best.1.best_score;
285 if score_diff < 0.05 {
286 recommendations.push(format!(
287 "Close performance between {:?} and {:?} - consider computational cost",
288 best.0, second_best.0
289 ));
290 }
291 }
292
293 if let Some(kmeans_result) = results.get(&ClusteringAlgorithm::KMeans) {
295 if let Some(dbscan_result) = results.get(&ClusteringAlgorithm::DBSCAN) {
296 if kmeans_result.total_time < dbscan_result.total_time * 0.5
297 && F::from(kmeans_result.best_score).expect("Failed to convert to float")
298 > F::from(dbscan_result.best_score * 0.9)
299 .expect("Failed to convert to float")
300 {
301 recommendations
302 .push("K-means offers good speed/accuracy trade-off".to_string());
303 }
304 }
305 }
306
307 recommendations
308 }
309}
310
311pub struct StandardSearchSpaces;
313
314impl StandardSearchSpaces {
315 pub fn kmeans() -> SearchSpace {
317 let mut parameters = HashMap::new();
318 parameters.insert(
319 "n_clusters".to_string(),
320 HyperParameter::Integer { min: 2, max: 20 },
321 );
322 parameters.insert(
323 "max_iter".to_string(),
324 HyperParameter::IntegerChoices {
325 choices: vec![100, 200, 300, 500, 1000],
326 },
327 );
328 parameters.insert(
329 "tolerance".to_string(),
330 HyperParameter::LogUniform {
331 min: 1e-6,
332 max: 1e-2,
333 },
334 );
335
336 SearchSpace {
337 parameters,
338 constraints: Vec::new(),
339 }
340 }
341
342 pub fn dbscan() -> SearchSpace {
344 let mut parameters = HashMap::new();
345 parameters.insert(
346 "eps".to_string(),
347 HyperParameter::Float { min: 0.1, max: 2.0 },
348 );
349 parameters.insert(
350 "min_samples".to_string(),
351 HyperParameter::Integer { min: 2, max: 20 },
352 );
353
354 SearchSpace {
355 parameters,
356 constraints: Vec::new(),
357 }
358 }
359
360 pub fn hierarchical() -> SearchSpace {
362 let mut parameters = HashMap::new();
363 parameters.insert(
364 "method".to_string(),
365 HyperParameter::Categorical {
366 choices: vec![
367 "single".to_string(),
368 "complete".to_string(),
369 "average".to_string(),
370 "ward".to_string(),
371 ],
372 },
373 );
374
375 SearchSpace {
376 parameters,
377 constraints: Vec::new(),
378 }
379 }
380
381 pub fn mean_shift() -> SearchSpace {
383 let mut parameters = HashMap::new();
384 parameters.insert(
385 "bandwidth".to_string(),
386 HyperParameter::Float { min: 0.1, max: 5.0 },
387 );
388
389 SearchSpace {
390 parameters,
391 constraints: Vec::new(),
392 }
393 }
394
395 pub fn optics() -> SearchSpace {
397 let mut parameters = HashMap::new();
398 parameters.insert(
399 "min_samples".to_string(),
400 HyperParameter::Integer { min: 2, max: 20 },
401 );
402 parameters.insert(
403 "max_eps".to_string(),
404 HyperParameter::Float {
405 min: 0.1,
406 max: 10.0,
407 },
408 );
409
410 SearchSpace {
411 parameters,
412 constraints: Vec::new(),
413 }
414 }
415
416 pub fn spectral() -> SearchSpace {
418 let mut parameters = HashMap::new();
419 parameters.insert(
420 "n_clusters".to_string(),
421 HyperParameter::Integer { min: 2, max: 20 },
422 );
423 parameters.insert(
424 "n_neighbors".to_string(),
425 HyperParameter::Integer { min: 5, max: 50 },
426 );
427 parameters.insert(
428 "gamma".to_string(),
429 HyperParameter::LogUniform {
430 min: 0.01,
431 max: 10.0,
432 },
433 );
434 parameters.insert(
435 "max_iter".to_string(),
436 HyperParameter::IntegerChoices {
437 choices: vec![100, 200, 300, 500, 1000],
438 },
439 );
440
441 SearchSpace {
442 parameters,
443 constraints: Vec::new(),
444 }
445 }
446
447 pub fn affinity_propagation() -> SearchSpace {
449 let mut parameters = HashMap::new();
450 parameters.insert(
451 "damping".to_string(),
452 HyperParameter::Float {
453 min: 0.5,
454 max: 0.99,
455 },
456 );
457 parameters.insert(
458 "max_iter".to_string(),
459 HyperParameter::IntegerChoices {
460 choices: vec![100, 200, 300, 500],
461 },
462 );
463 parameters.insert(
464 "convergence_iter".to_string(),
465 HyperParameter::Integer { min: 10, max: 50 },
466 );
467
468 SearchSpace {
469 parameters,
470 constraints: Vec::new(),
471 }
472 }
473
474 pub fn birch() -> SearchSpace {
476 let mut parameters = HashMap::new();
477 parameters.insert(
478 "branching_factor".to_string(),
479 HyperParameter::Integer { min: 10, max: 100 },
480 );
481 parameters.insert(
482 "threshold".to_string(),
483 HyperParameter::Float { min: 0.1, max: 5.0 },
484 );
485
486 SearchSpace {
487 parameters,
488 constraints: Vec::new(),
489 }
490 }
491
492 pub fn gmm() -> SearchSpace {
494 let mut parameters = HashMap::new();
495 parameters.insert(
496 "n_components".to_string(),
497 HyperParameter::Integer { min: 1, max: 20 },
498 );
499 parameters.insert(
500 "max_iter".to_string(),
501 HyperParameter::IntegerChoices {
502 choices: vec![50, 100, 200, 300],
503 },
504 );
505 parameters.insert(
506 "tol".to_string(),
507 HyperParameter::LogUniform {
508 min: 1e-6,
509 max: 1e-2,
510 },
511 );
512 parameters.insert(
513 "reg_covar".to_string(),
514 HyperParameter::LogUniform {
515 min: 1e-8,
516 max: 1e-3,
517 },
518 );
519
520 SearchSpace {
521 parameters,
522 constraints: Vec::new(),
523 }
524 }
525
526 pub fn quantum_kmeans() -> SearchSpace {
528 let mut parameters = HashMap::new();
529 parameters.insert(
530 "n_clusters".to_string(),
531 HyperParameter::Integer { min: 2, max: 20 },
532 );
533 parameters.insert(
534 "n_quantum_states".to_string(),
535 HyperParameter::IntegerChoices {
536 choices: vec![4, 8, 16, 32],
537 },
538 );
539 parameters.insert(
540 "quantum_iterations".to_string(),
541 HyperParameter::IntegerChoices {
542 choices: vec![20, 50, 100, 200],
543 },
544 );
545 parameters.insert(
546 "decoherence_factor".to_string(),
547 HyperParameter::Float {
548 min: 0.8,
549 max: 0.99,
550 },
551 );
552 parameters.insert(
553 "entanglement_strength".to_string(),
554 HyperParameter::Float { min: 0.1, max: 0.5 },
555 );
556
557 SearchSpace {
558 parameters,
559 constraints: Vec::new(),
560 }
561 }
562
563 pub fn rl_clustering() -> SearchSpace {
565 let mut parameters = HashMap::new();
566 parameters.insert(
567 "n_actions".to_string(),
568 HyperParameter::Integer { min: 5, max: 50 },
569 );
570 parameters.insert(
571 "learning_rate".to_string(),
572 HyperParameter::LogUniform {
573 min: 0.001,
574 max: 0.5,
575 },
576 );
577 parameters.insert(
578 "exploration_rate".to_string(),
579 HyperParameter::Float { min: 0.1, max: 1.0 },
580 );
581 parameters.insert(
582 "n_episodes".to_string(),
583 HyperParameter::IntegerChoices {
584 choices: vec![50, 100, 200, 500],
585 },
586 );
587
588 SearchSpace {
589 parameters,
590 constraints: Vec::new(),
591 }
592 }
593
594 pub fn adaptive_online() -> SearchSpace {
596 let mut parameters = HashMap::new();
597 parameters.insert(
598 "initial_learning_rate".to_string(),
599 HyperParameter::LogUniform {
600 min: 0.001,
601 max: 0.5,
602 },
603 );
604 parameters.insert(
605 "cluster_creation_threshold".to_string(),
606 HyperParameter::Float { min: 1.0, max: 5.0 },
607 );
608 parameters.insert(
609 "max_clusters".to_string(),
610 HyperParameter::Integer { min: 10, max: 100 },
611 );
612 parameters.insert(
613 "forgetting_factor".to_string(),
614 HyperParameter::Float {
615 min: 0.9,
616 max: 0.99,
617 },
618 );
619
620 SearchSpace {
621 parameters,
622 constraints: Vec::new(),
623 }
624 }
625
626 pub fn kmeans_bayesian() -> (SearchSpace, TuningConfig) {
628 let mut parameters = HashMap::new();
629 parameters.insert(
630 "n_clusters".to_string(),
631 HyperParameter::Integer { min: 2, max: 50 },
632 );
633 parameters.insert(
634 "max_iter".to_string(),
635 HyperParameter::Integer { min: 50, max: 500 },
636 );
637 parameters.insert(
638 "tolerance".to_string(),
639 HyperParameter::Float {
640 min: 1e-6,
641 max: 1e-2,
642 },
643 );
644
645 let search_space = SearchSpace {
646 parameters,
647 constraints: Vec::new(),
648 };
649
650 let config = TuningConfig {
651 strategy: SearchStrategy::BayesianOptimization {
652 n_initial_points: 10,
653 acquisition_function: AcquisitionFunction::ExpectedImprovement,
654 },
655 metric: EvaluationMetric::SilhouetteScore,
656 max_evaluations: 50,
657 cv_config: CrossValidationConfig {
658 n_folds: 5,
659 validation_ratio: 0.2,
660 strategy: CVStrategy::KFold,
661 shuffle: true,
662 },
663 early_stopping: Some(EarlyStoppingConfig {
664 patience: 10,
665 min_improvement: 0.001,
666 evaluation_frequency: 1,
667 }),
668 parallel_config: Some(ParallelConfig {
669 n_workers: 8,
670 load_balancing: LoadBalancingStrategy::Dynamic,
671 batch_size: 100,
672 }),
673 random_seed: Some(42),
674 resource_constraints: ResourceConstraints {
675 max_memory_per_evaluation: None,
676 max_time_per_evaluation: None,
677 max_total_time: None,
678 },
679 };
680
681 (search_space, config)
682 }
683
684 pub fn dbscan_multi_objective() -> (SearchSpace, TuningConfig) {
686 let mut parameters = HashMap::new();
687 parameters.insert(
688 "eps".to_string(),
689 HyperParameter::Float { min: 0.1, max: 2.0 },
690 );
691 parameters.insert(
692 "min_samples".to_string(),
693 HyperParameter::Integer { min: 2, max: 20 },
694 );
695
696 let search_space = SearchSpace {
697 parameters,
698 constraints: Vec::new(),
699 };
700
701 let config = TuningConfig {
702 strategy: SearchStrategy::MultiObjective {
703 objectives: vec![
704 EvaluationMetric::SilhouetteScore,
705 EvaluationMetric::DaviesBouldinIndex,
706 ],
707 strategy: Box::new(SearchStrategy::BayesianOptimization {
708 n_initial_points: 10,
709 acquisition_function: AcquisitionFunction::ExpectedImprovement,
710 }),
711 },
712 metric: EvaluationMetric::SilhouetteScore,
713 max_evaluations: 30,
714 cv_config: CrossValidationConfig {
715 n_folds: 3,
716 validation_ratio: 0.3,
717 strategy: CVStrategy::KFold,
718 shuffle: true,
719 },
720 early_stopping: None,
721 parallel_config: None,
722 random_seed: Some(42),
723 resource_constraints: ResourceConstraints {
724 max_memory_per_evaluation: None,
725 max_time_per_evaluation: Some(120.0),
726 max_total_time: Some(3600.0),
727 },
728 };
729
730 (search_space, config)
731 }
732}
733
734#[allow(dead_code)]
736pub fn auto_select_clustering_algorithm<
737 F: Float
738 + FromPrimitive
739 + Debug
740 + 'static
741 + std::iter::Sum
742 + std::fmt::Display
743 + Send
744 + Sync
745 + scirs2_core::ndarray::ScalarOperand
746 + std::ops::AddAssign
747 + std::ops::SubAssign
748 + std::ops::MulAssign
749 + std::ops::DivAssign
750 + std::ops::RemAssign
751 + PartialOrd,
752>(
753 data: ArrayView2<F>,
754 config: Option<TuningConfig>,
755) -> Result<AlgorithmSelectionResult>
756where
757 f64: From<F>,
758{
759 let tuning_config = config.unwrap_or_else(|| TuningConfig {
760 max_evaluations: 50, ..Default::default()
762 });
763
764 let selector = AutoClusteringSelector::new(tuning_config);
765 selector.select_best_algorithm(data)
766}
767
768#[allow(dead_code)]
770pub fn quick_algorithm_selection<
771 F: Float
772 + FromPrimitive
773 + Debug
774 + 'static
775 + std::iter::Sum
776 + std::fmt::Display
777 + Send
778 + Sync
779 + scirs2_core::ndarray::ScalarOperand
780 + std::ops::AddAssign
781 + std::ops::SubAssign
782 + std::ops::MulAssign
783 + std::ops::DivAssign
784 + std::ops::RemAssign
785 + PartialOrd,
786>(
787 data: ArrayView2<F>,
788) -> Result<AlgorithmSelectionResult>
789where
790 f64: From<F>,
791{
792 let config = TuningConfig {
793 strategy: SearchStrategy::RandomSearch { n_trials: 20 },
794 max_evaluations: 20,
795 early_stopping: Some(EarlyStoppingConfig {
796 patience: 5,
797 min_improvement: 0.001,
798 evaluation_frequency: 1,
799 }),
800 ..Default::default()
801 };
802
803 let algorithms = vec![
804 ClusteringAlgorithm::KMeans,
805 ClusteringAlgorithm::DBSCAN,
806 ClusteringAlgorithm::GaussianMixture,
807 ];
808
809 let selector = AutoClusteringSelector::with_algorithms(config, algorithms);
810 selector.select_best_algorithm(data)
811}