1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::{rng, Rng, SeedableRng};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::advanced::{
13 adaptive_online_clustering, quantum_kmeans, rl_clustering, AdaptiveOnlineConfig, QuantumConfig,
14 RLClusteringConfig,
15};
16use crate::affinity::{affinity_propagation, AffinityPropagationOptions};
17use crate::birch::{birch, BirchOptions};
18use crate::density::{dbscan, optics};
19use crate::error::{ClusteringError, Result};
20use crate::gmm::{gaussian_mixture, CovarianceType, GMMInit, GMMOptions};
21use crate::hierarchy::linkage;
22use crate::meanshift::mean_shift;
23use crate::metrics::{calinski_harabasz_score, davies_bouldin_score, silhouette_score};
24use crate::spectral::{spectral_clustering, AffinityMode, SpectralClusteringOptions};
25use crate::stability::OptimalKSelector;
26use crate::vq::{kmeans, kmeans2};
27
28use super::config::*;
29use super::cross_validation::CrossValidator;
30use super::optimization_strategies::ParameterGenerator;
31use super::utilities::*;
32
33use statrs::statistics::Statistics;
34
35pub struct AutoTuner<F: Float> {
37 config: TuningConfig,
38 phantom: std::marker::PhantomData<F>,
39}
40
41impl<
42 F: Float
43 + FromPrimitive
44 + Debug
45 + 'static
46 + std::iter::Sum
47 + std::fmt::Display
48 + Send
49 + Sync
50 + scirs2_core::ndarray::ScalarOperand
51 + std::ops::AddAssign
52 + std::ops::SubAssign
53 + std::ops::MulAssign
54 + std::ops::DivAssign
55 + std::ops::RemAssign
56 + PartialOrd,
57 > AutoTuner<F>
58where
59 f64: From<F>,
60{
61 pub fn new(config: TuningConfig) -> Self {
63 Self {
64 config,
65 phantom: std::marker::PhantomData,
66 }
67 }
68
69 pub fn tune_kmeans(
71 &self,
72 data: ArrayView2<F>,
73 search_space: SearchSpace,
74 ) -> Result<TuningResult> {
75 let start_time = std::time::Instant::now();
76 let mut evaluation_history = Vec::new();
77 let mut best_score = f64::NEG_INFINITY;
78 let mut best_parameters = HashMap::new();
79
80 let parameter_generator = ParameterGenerator::new(&self.config);
81 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
82
83 let mut rng = match self.config.random_seed {
84 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
85 None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
86 };
87
88 for (eval_idx, params) in parameter_combinations.iter().enumerate() {
89 if eval_idx >= self.config.max_evaluations {
90 break;
91 }
92
93 if let Some(max_time) = self.config.resource_constraints.max_total_time {
94 if start_time.elapsed().as_secs_f64() > max_time {
95 break;
96 }
97 }
98
99 let eval_start = std::time::Instant::now();
100
101 let k = params.get("n_clusters").map(|&x| x as usize).unwrap_or(3);
102 let max_iter = params.get("max_iter").map(|&x| x as usize);
103 let tol = params.get("tolerance").copied();
104 let seed = rng.random_range(0..u64::MAX);
105
106 let cv_validator = CrossValidator::new(&self.config.cv_config);
107 let cv_scores = cv_validator.cross_validate_kmeans(
108 data,
109 k,
110 max_iter,
111 tol,
112 Some(seed),
113 &self.config.metric,
114 )?;
115
116 let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
117 let cv_std = calculate_std_dev(&cv_scores);
118 let eval_time = eval_start.elapsed().as_secs_f64();
119
120 let result = EvaluationResult {
121 parameters: params.clone(),
122 score: mean_score,
123 additional_metrics: HashMap::new(),
124 evaluation_time: eval_time,
125 memory_usage: None,
126 cv_scores,
127 cv_std,
128 metadata: HashMap::new(),
129 };
130
131 let is_better = is_score_better(mean_score, best_score, &self.config.metric);
132 if is_better {
133 best_score = mean_score;
134 best_parameters = params.clone();
135 }
136
137 evaluation_history.push(result);
138
139 if let Some(ref early_stop) = self.config.early_stopping {
140 if should_stop_early(&evaluation_history, early_stop) {
141 break;
142 }
143 }
144 }
145
146 let total_time = start_time.elapsed().as_secs_f64();
147 let convergence_info =
148 create_convergence_info(&evaluation_history, self.config.max_evaluations);
149 let exploration_stats = calculate_exploration_stats(&evaluation_history);
150
151 Ok(TuningResult {
152 best_parameters,
153 best_score,
154 evaluation_history,
155 convergence_info,
156 exploration_stats,
157 total_time,
158 ensemble_results: None,
159 pareto_front: None,
160 })
161 }
162
163 pub fn tune_dbscan(
165 &self,
166 data: ArrayView2<F>,
167 search_space: SearchSpace,
168 ) -> Result<TuningResult> {
169 let start_time = std::time::Instant::now();
170 let mut evaluation_history = Vec::new();
171 let mut best_score = f64::NEG_INFINITY;
172 let mut best_parameters = HashMap::new();
173
174 let parameter_generator = ParameterGenerator::new(&self.config);
175 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
176
177 for (eval_idx, params) in parameter_combinations.iter().enumerate() {
178 if eval_idx >= self.config.max_evaluations {
179 break;
180 }
181
182 let eval_start = std::time::Instant::now();
183
184 let eps = params.get("eps").copied().unwrap_or(0.5);
185 let min_samples = params.get("min_samples").map(|&x| x as usize).unwrap_or(5);
186
187 let cv_validator = CrossValidator::new(&self.config.cv_config);
188 let cv_scores =
189 cv_validator.cross_validate_dbscan(data, eps, min_samples, &self.config.metric)?;
190
191 let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
192 let cv_std = calculate_std_dev(&cv_scores);
193 let eval_time = eval_start.elapsed().as_secs_f64();
194
195 let result = EvaluationResult {
196 parameters: params.clone(),
197 score: mean_score,
198 additional_metrics: HashMap::new(),
199 evaluation_time: eval_time,
200 memory_usage: None,
201 cv_scores,
202 cv_std,
203 metadata: HashMap::new(),
204 };
205
206 let is_better = is_score_better(mean_score, best_score, &self.config.metric);
207 if is_better {
208 best_score = mean_score;
209 best_parameters = params.clone();
210 }
211
212 evaluation_history.push(result);
213
214 if let Some(ref early_stop) = self.config.early_stopping {
215 if should_stop_early(&evaluation_history, early_stop) {
216 break;
217 }
218 }
219 }
220
221 let total_time = start_time.elapsed().as_secs_f64();
222 let convergence_info =
223 create_convergence_info(&evaluation_history, self.config.max_evaluations);
224 let exploration_stats = calculate_exploration_stats(&evaluation_history);
225
226 Ok(TuningResult {
227 best_parameters,
228 best_score,
229 evaluation_history,
230 convergence_info,
231 exploration_stats,
232 total_time,
233 ensemble_results: None,
234 pareto_front: None,
235 })
236 }
237
238 pub fn tune_optics(
240 &self,
241 data: ArrayView2<F>,
242 search_space: SearchSpace,
243 ) -> Result<TuningResult> {
244 let start_time = std::time::Instant::now();
245 let mut evaluation_history = Vec::new();
246 let mut best_score = f64::NEG_INFINITY;
247 let mut best_parameters = HashMap::new();
248
249 let parameter_generator = ParameterGenerator::new(&self.config);
250 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
251
252 for combination in ¶meter_combinations {
253 let min_samples = combination
254 .get("min_samples")
255 .ok_or_else(|| {
256 ClusteringError::InvalidInput("min_samples parameter not found".to_string())
257 })?
258 .round() as usize;
259 let max_eps = combination.get("max_eps").copied().unwrap_or(5.0);
260
261 let cv_validator = CrossValidator::new(&self.config.cv_config);
262 let scores = cv_validator.cross_validate_optics(
263 data,
264 min_samples,
265 Some(F::from(max_eps).unwrap()),
266 &self.config.metric,
267 )?;
268 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
269
270 evaluation_history.push(EvaluationResult {
271 parameters: combination.clone(),
272 score: mean_score,
273 additional_metrics: HashMap::new(),
274 evaluation_time: 0.0,
275 memory_usage: None,
276 cv_scores: scores,
277 cv_std: 0.0,
278 metadata: HashMap::new(),
279 });
280
281 if mean_score > best_score {
282 best_score = mean_score;
283 best_parameters = combination.clone();
284 }
285 }
286
287 let total_time = start_time.elapsed().as_secs_f64();
288
289 Ok(TuningResult {
290 best_parameters,
291 best_score,
292 evaluation_history: evaluation_history.clone(),
293 convergence_info: ConvergenceInfo {
294 converged: false,
295 convergence_iteration: None,
296 stopping_reason: StoppingReason::MaxEvaluations,
297 },
298 exploration_stats: calculate_exploration_stats(&evaluation_history),
299 total_time,
300 ensemble_results: None,
301 pareto_front: None,
302 })
303 }
304
305 pub fn tune_spectral(
307 &self,
308 data: ArrayView2<F>,
309 search_space: SearchSpace,
310 ) -> Result<TuningResult> {
311 let start_time = std::time::Instant::now();
312 let mut evaluation_history = Vec::new();
313 let mut best_score = f64::NEG_INFINITY;
314 let mut best_parameters = HashMap::new();
315
316 let parameter_generator = ParameterGenerator::new(&self.config);
317 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
318
319 for combination in ¶meter_combinations {
320 let n_clusters = combination
321 .get("n_clusters")
322 .ok_or_else(|| {
323 ClusteringError::InvalidInput("n_clusters parameter not found".to_string())
324 })?
325 .round() as usize;
326 let n_neighbors = combination
327 .get("n_neighbors")
328 .copied()
329 .unwrap_or(10.0)
330 .round() as usize;
331 let gamma = combination.get("gamma").copied().unwrap_or(1.0);
332 let max_iter = combination
333 .get("max_iter")
334 .copied()
335 .unwrap_or(300.0)
336 .round() as usize;
337
338 let cv_validator = CrossValidator::new(&self.config.cv_config);
339 let scores = cv_validator.cross_validate_spectral(
340 data,
341 n_clusters,
342 n_neighbors,
343 F::from(gamma).unwrap(),
344 max_iter,
345 &self.config.metric,
346 )?;
347 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
348
349 evaluation_history.push(EvaluationResult {
350 parameters: combination.clone(),
351 score: mean_score,
352 additional_metrics: HashMap::new(),
353 evaluation_time: 0.0,
354 memory_usage: None,
355 cv_scores: scores.clone(),
356 cv_std: scores.std_dev(),
357 metadata: HashMap::new(),
358 });
359
360 if mean_score > best_score {
361 best_score = mean_score;
362 best_parameters = combination.clone();
363 }
364 }
365
366 let total_time = start_time.elapsed().as_secs_f64();
367
368 Ok(TuningResult {
369 best_parameters,
370 best_score,
371 evaluation_history: evaluation_history.clone(),
372 convergence_info: ConvergenceInfo {
373 converged: false,
374 convergence_iteration: None,
375 stopping_reason: StoppingReason::MaxEvaluations,
376 },
377 exploration_stats: calculate_exploration_stats(&evaluation_history),
378 total_time,
379 ensemble_results: None,
380 pareto_front: None,
381 })
382 }
383
384 pub fn tune_affinity_propagation(
386 &self,
387 data: ArrayView2<F>,
388 search_space: SearchSpace,
389 ) -> Result<TuningResult> {
390 let start_time = std::time::Instant::now();
391 let mut evaluation_history = Vec::new();
392 let mut best_score = f64::NEG_INFINITY;
393 let mut best_parameters = HashMap::new();
394
395 let parameter_generator = ParameterGenerator::new(&self.config);
396 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
397
398 for combination in ¶meter_combinations {
399 let damping = combination.get("damping").copied().unwrap_or(0.5);
400 let max_iter = combination
401 .get("max_iter")
402 .copied()
403 .unwrap_or(200.0)
404 .round() as usize;
405 let convergence_iter = combination
406 .get("convergence_iter")
407 .copied()
408 .unwrap_or(15.0)
409 .round() as usize;
410
411 let cv_validator = CrossValidator::new(&self.config.cv_config);
412 let scores = cv_validator.cross_validate_affinity_propagation(
413 data,
414 F::from(damping).unwrap(),
415 max_iter,
416 convergence_iter,
417 &self.config.metric,
418 )?;
419 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
420
421 evaluation_history.push(EvaluationResult {
422 parameters: combination.clone(),
423 score: mean_score,
424 additional_metrics: HashMap::new(),
425 evaluation_time: 0.0,
426 memory_usage: None,
427 cv_scores: scores.clone(),
428 cv_std: scores.std_dev(),
429 metadata: HashMap::new(),
430 });
431
432 if mean_score > best_score {
433 best_score = mean_score;
434 best_parameters = combination.clone();
435 }
436 }
437
438 let total_time = start_time.elapsed().as_secs_f64();
439
440 Ok(TuningResult {
441 best_parameters,
442 best_score,
443 evaluation_history: evaluation_history.clone(),
444 convergence_info: ConvergenceInfo {
445 converged: false,
446 convergence_iteration: None,
447 stopping_reason: StoppingReason::MaxEvaluations,
448 },
449 exploration_stats: calculate_exploration_stats(&evaluation_history),
450 total_time,
451 ensemble_results: None,
452 pareto_front: None,
453 })
454 }
455
456 pub fn tune_birch(
458 &self,
459 data: ArrayView2<F>,
460 search_space: SearchSpace,
461 ) -> Result<TuningResult> {
462 let start_time = std::time::Instant::now();
463 let mut evaluation_history = Vec::new();
464 let mut best_score = f64::NEG_INFINITY;
465 let mut best_parameters = HashMap::new();
466
467 let parameter_generator = ParameterGenerator::new(&self.config);
468 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
469
470 for combination in ¶meter_combinations {
471 let branching_factor = combination
472 .get("branching_factor")
473 .copied()
474 .unwrap_or(50.0)
475 .round() as usize;
476 let threshold = combination.get("threshold").copied().unwrap_or(0.5);
477
478 let cv_validator = CrossValidator::new(&self.config.cv_config);
479 let scores = cv_validator.cross_validate_birch(
480 data,
481 branching_factor,
482 F::from(threshold).unwrap(),
483 &self.config.metric,
484 )?;
485 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
486
487 evaluation_history.push(EvaluationResult {
488 parameters: combination.clone(),
489 score: mean_score,
490 additional_metrics: HashMap::new(),
491 evaluation_time: 0.0,
492 memory_usage: None,
493 cv_scores: scores.clone(),
494 cv_std: scores.std_dev(),
495 metadata: HashMap::new(),
496 });
497
498 if mean_score > best_score {
499 best_score = mean_score;
500 best_parameters = combination.clone();
501 }
502 }
503
504 let total_time = start_time.elapsed().as_secs_f64();
505
506 Ok(TuningResult {
507 best_parameters,
508 best_score,
509 evaluation_history: evaluation_history.clone(),
510 convergence_info: ConvergenceInfo {
511 converged: false,
512 convergence_iteration: None,
513 stopping_reason: StoppingReason::MaxEvaluations,
514 },
515 exploration_stats: calculate_exploration_stats(&evaluation_history),
516 total_time,
517 ensemble_results: None,
518 pareto_front: None,
519 })
520 }
521
522 pub fn tune_gmm(&self, data: ArrayView2<F>, search_space: SearchSpace) -> Result<TuningResult> {
524 let start_time = std::time::Instant::now();
525 let mut evaluation_history = Vec::new();
526 let mut best_score = f64::NEG_INFINITY;
527 let mut best_parameters = HashMap::new();
528
529 let parameter_generator = ParameterGenerator::new(&self.config);
530 let parameter_combinations = parameter_generator.generate_combinations(&search_space)?;
531
532 for combination in ¶meter_combinations {
533 let n_components = combination
534 .get("n_components")
535 .ok_or_else(|| {
536 ClusteringError::InvalidInput("n_components parameter not found".to_string())
537 })?
538 .round() as usize;
539 let max_iter = combination
540 .get("max_iter")
541 .copied()
542 .unwrap_or(100.0)
543 .round() as usize;
544 let tol = combination.get("tol").copied().unwrap_or(1e-3);
545 let reg_covar = combination.get("reg_covar").copied().unwrap_or(1e-6);
546
547 let cv_validator = CrossValidator::new(&self.config.cv_config);
548 let scores = cv_validator.cross_validate_gmm(
549 data,
550 n_components,
551 max_iter,
552 F::from(tol).unwrap(),
553 F::from(reg_covar).unwrap(),
554 &self.config.metric,
555 )?;
556 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
557
558 evaluation_history.push(EvaluationResult {
559 parameters: combination.clone(),
560 score: mean_score,
561 additional_metrics: HashMap::new(),
562 evaluation_time: 0.0,
563 memory_usage: None,
564 cv_scores: scores.clone(),
565 cv_std: scores.std_dev(),
566 metadata: HashMap::new(),
567 });
568
569 if mean_score > best_score {
570 best_score = mean_score;
571 best_parameters = combination.clone();
572 }
573 }
574
575 let total_time = start_time.elapsed().as_secs_f64();
576
577 Ok(TuningResult {
578 best_parameters,
579 best_score,
580 evaluation_history: evaluation_history.clone(),
581 convergence_info: ConvergenceInfo {
582 converged: false,
583 convergence_iteration: None,
584 stopping_reason: StoppingReason::MaxEvaluations,
585 },
586 exploration_stats: calculate_exploration_stats(&evaluation_history),
587 total_time,
588 ensemble_results: None,
589 pareto_front: None,
590 })
591 }
592}