scirs2_cluster/tuning/
algorithm_tuners.rs

//! Algorithm-specific hyperparameter tuners.
//!
//! This module contains the main `AutoTuner` implementation and the methods
//! that tune the hyperparameters of specific clustering algorithms.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::{rng, Rng, SeedableRng};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::advanced::{
13    adaptive_online_clustering, quantum_kmeans, rl_clustering, AdaptiveOnlineConfig, QuantumConfig,
14    RLClusteringConfig,
15};
16use crate::affinity::{affinity_propagation, AffinityPropagationOptions};
17use crate::birch::{birch, BirchOptions};
18use crate::density::{dbscan, optics};
19use crate::error::{ClusteringError, Result};
20use crate::gmm::{gaussian_mixture, CovarianceType, GMMInit, GMMOptions};
21use crate::hierarchy::linkage;
22use crate::meanshift::mean_shift;
23use crate::metrics::{calinski_harabasz_score, davies_bouldin_score, silhouette_score};
24use crate::spectral::{spectral_clustering, AffinityMode, SpectralClusteringOptions};
25use crate::stability::OptimalKSelector;
26use crate::vq::{kmeans, kmeans2};
27
28use super::config::*;
29use super::cross_validation::CrossValidator;
30use super::optimization_strategies::ParameterGenerator;
31use super::utilities::*;
32
33use statrs::statistics::Statistics;
34
35/// Main hyperparameter tuning engine for clustering algorithms
36pub struct AutoTuner<F: Float> {
37    config: TuningConfig,
38    phantom: std::marker::PhantomData<F>,
39}
40
41impl<
42        F: Float
43            + FromPrimitive
44            + Debug
45            + 'static
46            + std::iter::Sum
47            + std::fmt::Display
48            + Send
49            + Sync
50            + scirs2_core::ndarray::ScalarOperand
51            + std::ops::AddAssign
52            + std::ops::SubAssign
53            + std::ops::MulAssign
54            + std::ops::DivAssign
55            + std::ops::RemAssign
56            + PartialOrd,
57    > AutoTuner<F>
58where
59    f64: From<F>,
60{
61    /// Create a new auto tuner with specified configuration
62    pub fn new(config: TuningConfig) -> Self {
63        Self {
64            config,
65            phantom: std::marker::PhantomData,
66        }
67    }
68
69    /// Tune K-means hyperparameters
70    pub fn tune_kmeans(
71        &self,
72        data: ArrayView2<F>,
73        search_space: SearchSpace,
74    ) -> Result<TuningResult> {
75        let start_time = std::time::Instant::now();
76        let mut evaluation_history = Vec::new();
77        let mut best_score = f64::NEG_INFINITY;
78        let mut best_parameters = HashMap::new();
79
80        let parameter_generator = ParameterGenerator::new(&self.config);
81        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
82
83        let mut rng = match self.config.random_seed {
84            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
85            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
86        };
87
88        for (eval_idx, params) in parameter_combinations.iter().enumerate() {
89            if eval_idx >= self.config.max_evaluations {
90                break;
91            }
92
93            if let Some(max_time) = self.config.resource_constraints.max_total_time {
94                if start_time.elapsed().as_secs_f64() > max_time {
95                    break;
96                }
97            }
98
99            let eval_start = std::time::Instant::now();
100
101            let k = params.get("n_clusters").map(|&x| x as usize).unwrap_or(3);
102            let max_iter = params.get("max_iter").map(|&x| x as usize);
103            let tol = params.get("tolerance").copied();
104            let seed = rng.random_range(0..u64::MAX);
105
106            let cv_validator = CrossValidator::new(&self.config.cv_config);
107            let cv_scores = cv_validator.cross_validate_kmeans(
108                data,
109                k,
110                max_iter,
111                tol,
112                Some(seed),
113                &self.config.metric,
114            )?;
115
116            let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
117            let cv_std = calculate_std_dev(&cv_scores);
118            let eval_time = eval_start.elapsed().as_secs_f64();
119
120            let result = EvaluationResult {
121                parameters: params.clone(),
122                score: mean_score,
123                additional_metrics: HashMap::new(),
124                evaluation_time: eval_time,
125                memory_usage: None,
126                cv_scores,
127                cv_std,
128                metadata: HashMap::new(),
129            };
130
131            let is_better = is_score_better(mean_score, best_score, &self.config.metric);
132            if is_better {
133                best_score = mean_score;
134                best_parameters = params.clone();
135            }
136
137            evaluation_history.push(result);
138
139            if let Some(ref early_stop) = self.config.early_stopping {
140                if should_stop_early(&evaluation_history, early_stop) {
141                    break;
142                }
143            }
144        }
145
146        let total_time = start_time.elapsed().as_secs_f64();
147        let convergence_info =
148            create_convergence_info(&evaluation_history, self.config.max_evaluations);
149        let exploration_stats = calculate_exploration_stats(&evaluation_history);
150
151        Ok(TuningResult {
152            best_parameters,
153            best_score,
154            evaluation_history,
155            convergence_info,
156            exploration_stats,
157            total_time,
158            ensemble_results: None,
159            pareto_front: None,
160        })
161    }
162
163    /// Tune DBSCAN hyperparameters
164    pub fn tune_dbscan(
165        &self,
166        data: ArrayView2<F>,
167        search_space: SearchSpace,
168    ) -> Result<TuningResult> {
169        let start_time = std::time::Instant::now();
170        let mut evaluation_history = Vec::new();
171        let mut best_score = f64::NEG_INFINITY;
172        let mut best_parameters = HashMap::new();
173
174        let parameter_generator = ParameterGenerator::new(&self.config);
175        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
176
177        for (eval_idx, params) in parameter_combinations.iter().enumerate() {
178            if eval_idx >= self.config.max_evaluations {
179                break;
180            }
181
182            let eval_start = std::time::Instant::now();
183
184            let eps = params.get("eps").copied().unwrap_or(0.5);
185            let min_samples = params.get("min_samples").map(|&x| x as usize).unwrap_or(5);
186
187            let cv_validator = CrossValidator::new(&self.config.cv_config);
188            let cv_scores =
189                cv_validator.cross_validate_dbscan(data, eps, min_samples, &self.config.metric)?;
190
191            let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
192            let cv_std = calculate_std_dev(&cv_scores);
193            let eval_time = eval_start.elapsed().as_secs_f64();
194
195            let result = EvaluationResult {
196                parameters: params.clone(),
197                score: mean_score,
198                additional_metrics: HashMap::new(),
199                evaluation_time: eval_time,
200                memory_usage: None,
201                cv_scores,
202                cv_std,
203                metadata: HashMap::new(),
204            };
205
206            let is_better = is_score_better(mean_score, best_score, &self.config.metric);
207            if is_better {
208                best_score = mean_score;
209                best_parameters = params.clone();
210            }
211
212            evaluation_history.push(result);
213
214            if let Some(ref early_stop) = self.config.early_stopping {
215                if should_stop_early(&evaluation_history, early_stop) {
216                    break;
217                }
218            }
219        }
220
221        let total_time = start_time.elapsed().as_secs_f64();
222        let convergence_info =
223            create_convergence_info(&evaluation_history, self.config.max_evaluations);
224        let exploration_stats = calculate_exploration_stats(&evaluation_history);
225
226        Ok(TuningResult {
227            best_parameters,
228            best_score,
229            evaluation_history,
230            convergence_info,
231            exploration_stats,
232            total_time,
233            ensemble_results: None,
234            pareto_front: None,
235        })
236    }
237
238    /// Tune OPTICS hyperparameters
239    pub fn tune_optics(
240        &self,
241        data: ArrayView2<F>,
242        search_space: SearchSpace,
243    ) -> Result<TuningResult> {
244        let start_time = std::time::Instant::now();
245        let mut evaluation_history = Vec::new();
246        let mut best_score = f64::NEG_INFINITY;
247        let mut best_parameters = HashMap::new();
248
249        let parameter_generator = ParameterGenerator::new(&self.config);
250        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
251
252        for combination in &parameter_combinations {
253            let min_samples = combination
254                .get("min_samples")
255                .ok_or_else(|| {
256                    ClusteringError::InvalidInput("min_samples parameter not found".to_string())
257                })?
258                .round() as usize;
259            let max_eps = combination.get("max_eps").copied().unwrap_or(5.0);
260
261            let cv_validator = CrossValidator::new(&self.config.cv_config);
262            let scores = cv_validator.cross_validate_optics(
263                data,
264                min_samples,
265                Some(F::from(max_eps).unwrap()),
266                &self.config.metric,
267            )?;
268            let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
269
270            evaluation_history.push(EvaluationResult {
271                parameters: combination.clone(),
272                score: mean_score,
273                additional_metrics: HashMap::new(),
274                evaluation_time: 0.0,
275                memory_usage: None,
276                cv_scores: scores,
277                cv_std: 0.0,
278                metadata: HashMap::new(),
279            });
280
281            if mean_score > best_score {
282                best_score = mean_score;
283                best_parameters = combination.clone();
284            }
285        }
286
287        let total_time = start_time.elapsed().as_secs_f64();
288
289        Ok(TuningResult {
290            best_parameters,
291            best_score,
292            evaluation_history: evaluation_history.clone(),
293            convergence_info: ConvergenceInfo {
294                converged: false,
295                convergence_iteration: None,
296                stopping_reason: StoppingReason::MaxEvaluations,
297            },
298            exploration_stats: calculate_exploration_stats(&evaluation_history),
299            total_time,
300            ensemble_results: None,
301            pareto_front: None,
302        })
303    }
304
305    /// Tune Spectral clustering hyperparameters
306    pub fn tune_spectral(
307        &self,
308        data: ArrayView2<F>,
309        search_space: SearchSpace,
310    ) -> Result<TuningResult> {
311        let start_time = std::time::Instant::now();
312        let mut evaluation_history = Vec::new();
313        let mut best_score = f64::NEG_INFINITY;
314        let mut best_parameters = HashMap::new();
315
316        let parameter_generator = ParameterGenerator::new(&self.config);
317        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
318
319        for combination in &parameter_combinations {
320            let n_clusters = combination
321                .get("n_clusters")
322                .ok_or_else(|| {
323                    ClusteringError::InvalidInput("n_clusters parameter not found".to_string())
324                })?
325                .round() as usize;
326            let n_neighbors = combination
327                .get("n_neighbors")
328                .copied()
329                .unwrap_or(10.0)
330                .round() as usize;
331            let gamma = combination.get("gamma").copied().unwrap_or(1.0);
332            let max_iter = combination
333                .get("max_iter")
334                .copied()
335                .unwrap_or(300.0)
336                .round() as usize;
337
338            let cv_validator = CrossValidator::new(&self.config.cv_config);
339            let scores = cv_validator.cross_validate_spectral(
340                data,
341                n_clusters,
342                n_neighbors,
343                F::from(gamma).unwrap(),
344                max_iter,
345                &self.config.metric,
346            )?;
347            let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
348
349            evaluation_history.push(EvaluationResult {
350                parameters: combination.clone(),
351                score: mean_score,
352                additional_metrics: HashMap::new(),
353                evaluation_time: 0.0,
354                memory_usage: None,
355                cv_scores: scores.clone(),
356                cv_std: scores.std_dev(),
357                metadata: HashMap::new(),
358            });
359
360            if mean_score > best_score {
361                best_score = mean_score;
362                best_parameters = combination.clone();
363            }
364        }
365
366        let total_time = start_time.elapsed().as_secs_f64();
367
368        Ok(TuningResult {
369            best_parameters,
370            best_score,
371            evaluation_history: evaluation_history.clone(),
372            convergence_info: ConvergenceInfo {
373                converged: false,
374                convergence_iteration: None,
375                stopping_reason: StoppingReason::MaxEvaluations,
376            },
377            exploration_stats: calculate_exploration_stats(&evaluation_history),
378            total_time,
379            ensemble_results: None,
380            pareto_front: None,
381        })
382    }
383
384    /// Tune Affinity Propagation hyperparameters
385    pub fn tune_affinity_propagation(
386        &self,
387        data: ArrayView2<F>,
388        search_space: SearchSpace,
389    ) -> Result<TuningResult> {
390        let start_time = std::time::Instant::now();
391        let mut evaluation_history = Vec::new();
392        let mut best_score = f64::NEG_INFINITY;
393        let mut best_parameters = HashMap::new();
394
395        let parameter_generator = ParameterGenerator::new(&self.config);
396        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
397
398        for combination in &parameter_combinations {
399            let damping = combination.get("damping").copied().unwrap_or(0.5);
400            let max_iter = combination
401                .get("max_iter")
402                .copied()
403                .unwrap_or(200.0)
404                .round() as usize;
405            let convergence_iter = combination
406                .get("convergence_iter")
407                .copied()
408                .unwrap_or(15.0)
409                .round() as usize;
410
411            let cv_validator = CrossValidator::new(&self.config.cv_config);
412            let scores = cv_validator.cross_validate_affinity_propagation(
413                data,
414                F::from(damping).unwrap(),
415                max_iter,
416                convergence_iter,
417                &self.config.metric,
418            )?;
419            let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
420
421            evaluation_history.push(EvaluationResult {
422                parameters: combination.clone(),
423                score: mean_score,
424                additional_metrics: HashMap::new(),
425                evaluation_time: 0.0,
426                memory_usage: None,
427                cv_scores: scores.clone(),
428                cv_std: scores.std_dev(),
429                metadata: HashMap::new(),
430            });
431
432            if mean_score > best_score {
433                best_score = mean_score;
434                best_parameters = combination.clone();
435            }
436        }
437
438        let total_time = start_time.elapsed().as_secs_f64();
439
440        Ok(TuningResult {
441            best_parameters,
442            best_score,
443            evaluation_history: evaluation_history.clone(),
444            convergence_info: ConvergenceInfo {
445                converged: false,
446                convergence_iteration: None,
447                stopping_reason: StoppingReason::MaxEvaluations,
448            },
449            exploration_stats: calculate_exploration_stats(&evaluation_history),
450            total_time,
451            ensemble_results: None,
452            pareto_front: None,
453        })
454    }
455
456    /// Tune BIRCH hyperparameters
457    pub fn tune_birch(
458        &self,
459        data: ArrayView2<F>,
460        search_space: SearchSpace,
461    ) -> Result<TuningResult> {
462        let start_time = std::time::Instant::now();
463        let mut evaluation_history = Vec::new();
464        let mut best_score = f64::NEG_INFINITY;
465        let mut best_parameters = HashMap::new();
466
467        let parameter_generator = ParameterGenerator::new(&self.config);
468        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
469
470        for combination in &parameter_combinations {
471            let branching_factor = combination
472                .get("branching_factor")
473                .copied()
474                .unwrap_or(50.0)
475                .round() as usize;
476            let threshold = combination.get("threshold").copied().unwrap_or(0.5);
477
478            let cv_validator = CrossValidator::new(&self.config.cv_config);
479            let scores = cv_validator.cross_validate_birch(
480                data,
481                branching_factor,
482                F::from(threshold).unwrap(),
483                &self.config.metric,
484            )?;
485            let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
486
487            evaluation_history.push(EvaluationResult {
488                parameters: combination.clone(),
489                score: mean_score,
490                additional_metrics: HashMap::new(),
491                evaluation_time: 0.0,
492                memory_usage: None,
493                cv_scores: scores.clone(),
494                cv_std: scores.std_dev(),
495                metadata: HashMap::new(),
496            });
497
498            if mean_score > best_score {
499                best_score = mean_score;
500                best_parameters = combination.clone();
501            }
502        }
503
504        let total_time = start_time.elapsed().as_secs_f64();
505
506        Ok(TuningResult {
507            best_parameters,
508            best_score,
509            evaluation_history: evaluation_history.clone(),
510            convergence_info: ConvergenceInfo {
511                converged: false,
512                convergence_iteration: None,
513                stopping_reason: StoppingReason::MaxEvaluations,
514            },
515            exploration_stats: calculate_exploration_stats(&evaluation_history),
516            total_time,
517            ensemble_results: None,
518            pareto_front: None,
519        })
520    }
521
522    /// Tune GMM hyperparameters
523    pub fn tune_gmm(&self, data: ArrayView2<F>, search_space: SearchSpace) -> Result<TuningResult> {
524        let start_time = std::time::Instant::now();
525        let mut evaluation_history = Vec::new();
526        let mut best_score = f64::NEG_INFINITY;
527        let mut best_parameters = HashMap::new();
528
529        let parameter_generator = ParameterGenerator::new(&self.config);
530        let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
531
532        for combination in &parameter_combinations {
533            let n_components = combination
534                .get("n_components")
535                .ok_or_else(|| {
536                    ClusteringError::InvalidInput("n_components parameter not found".to_string())
537                })?
538                .round() as usize;
539            let max_iter = combination
540                .get("max_iter")
541                .copied()
542                .unwrap_or(100.0)
543                .round() as usize;
544            let tol = combination.get("tol").copied().unwrap_or(1e-3);
545            let reg_covar = combination.get("reg_covar").copied().unwrap_or(1e-6);
546
547            let cv_validator = CrossValidator::new(&self.config.cv_config);
548            let scores = cv_validator.cross_validate_gmm(
549                data,
550                n_components,
551                max_iter,
552                F::from(tol).unwrap(),
553                F::from(reg_covar).unwrap(),
554                &self.config.metric,
555            )?;
556            let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
557
558            evaluation_history.push(EvaluationResult {
559                parameters: combination.clone(),
560                score: mean_score,
561                additional_metrics: HashMap::new(),
562                evaluation_time: 0.0,
563                memory_usage: None,
564                cv_scores: scores.clone(),
565                cv_std: scores.std_dev(),
566                metadata: HashMap::new(),
567            });
568
569            if mean_score > best_score {
570                best_score = mean_score;
571                best_parameters = combination.clone();
572            }
573        }
574
575        let total_time = start_time.elapsed().as_secs_f64();
576
577        Ok(TuningResult {
578            best_parameters,
579            best_score,
580            evaluation_history: evaluation_history.clone(),
581            convergence_info: ConvergenceInfo {
582                converged: false,
583                convergence_iteration: None,
584                stopping_reason: StoppingReason::MaxEvaluations,
585            },
586            exploration_stats: calculate_exploration_stats(&evaluation_history),
587            total_time,
588            ensemble_results: None,
589            pareto_front: None,
590        })
591    }
592}