1use scirs2_core::ndarray::{Array1, Array2};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Configuration for building a clustering ensemble.
///
/// NOTE(review): the code that consumes this configuration is not part of this
/// file, so the field descriptions below are inferred from names and types and
/// should be confirmed against the consumer.
pub struct EnsembleConfig {
    /// Number of ensemble members (presumably base clusterings to run — confirm).
    pub n_estimators: usize,
    /// How the input data is sampled or perturbed for the ensemble members.
    pub sampling_strategy: SamplingStrategy,
    /// How the individual clusterings are combined into consensus labels.
    pub consensus_method: ConsensusMethod,
    /// Optional RNG seed.
    /// NOTE(review): presumably `None` means non-reproducible seeding — confirm.
    pub random_seed: Option<u64>,
    /// Optional strategy for making ensemble members differ from each other.
    pub diversity_strategy: Option<DiversityStrategy>,
    /// Optional quality cut-off.
    /// NOTE(review): presumably results scoring below it are dropped — confirm.
    pub quality_threshold: Option<f64>,
    /// Optional upper bound on the number of clusters (assumed — confirm).
    pub max_clusters: Option<usize>,
}
28
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Ways of sampling or perturbing the input data for an ensemble member.
///
/// NOTE(review): the sampling code itself is not in this file; the valid range
/// of each ratio (presumably `(0, 1]`) should be confirmed there.
pub enum SamplingStrategy {
    /// Bootstrap sampling; `sample_ratio` is presumably the fraction of samples drawn.
    Bootstrap { sample_ratio: f64 },
    /// Random feature subspace; `feature_ratio` is presumably the fraction of features kept.
    RandomSubspace { feature_ratio: f64 },
    /// Bootstrap sampling combined with a random feature subspace.
    BootstrapSubspace {
        /// Presumably the fraction of samples drawn — confirm.
        sample_ratio: f64,
        /// Presumably the fraction of features kept — confirm.
        feature_ratio: f64,
    },
    /// Random projection to `target_dimensions` dimensions.
    RandomProjection { target_dimensions: usize },
    /// Noise injected into the data.
    NoiseInjection {
        /// Noise magnitude. NOTE(review): scale/units are not visible here — confirm.
        noise_level: f64,
        /// Kind of noise to inject.
        noise_type: NoiseType,
    },
    /// No sampling or perturbation. Note that this variant shares its name with
    /// `Option::None`; qualify it as `SamplingStrategy::None` at use sites.
    None,
}
51
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Kind of noise used by [`SamplingStrategy::NoiseInjection`].
pub enum NoiseType {
    /// Gaussian-distributed noise.
    Gaussian,
    /// Uniformly distributed noise.
    Uniform,
    /// Outlier-style noise; `outlier_ratio` is presumably the fraction of
    /// affected points — confirm against the injection code.
    Outliers { outlier_ratio: f64 },
}
62
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Method used to combine individual clusterings into consensus labels.
///
/// NOTE(review): the consensus implementations are not in this file; parameter
/// semantics below are inferred from names and should be confirmed there.
pub enum ConsensusMethod {
    /// Majority voting across ensemble members.
    MajorityVoting,
    /// Weighted consensus (the weighting scheme is not visible here).
    WeightedConsensus,
    /// Graph-based consensus; `similarity_threshold` presumably gates graph edges.
    GraphBased { similarity_threshold: f64 },
    /// Hierarchical consensus; `linkage_method` is a free-form string whose
    /// accepted values are defined by the consumer.
    Hierarchical { linkage_method: String },
    /// Co-association based consensus with a cut-off `threshold`.
    CoAssociation { threshold: f64 },
    /// Evidence accumulation consensus.
    EvidenceAccumulation,
}
79
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Strategy for making ensemble members differ from each other.
pub enum DiversityStrategy {
    /// Diversity from using several different clustering algorithms.
    AlgorithmDiversity {
        /// Algorithms to draw ensemble members from.
        algorithms: Vec<ClusteringAlgorithm>,
    },
    /// Diversity from varying the parameters of a single algorithm.
    ParameterDiversity {
        /// Algorithm whose parameters are varied.
        algorithm: ClusteringAlgorithm,
        /// Range per parameter, keyed by a parameter name string
        /// (the accepted names are defined by the consumer — confirm).
        parameter_ranges: HashMap<String, ParameterRange>,
    },
    /// Diversity from varying how the data is sampled.
    DataDiversity {
        /// Sampling strategies to use across ensemble members.
        sampling_strategies: Vec<SamplingStrategy>,
    },
    /// Several strategies combined; the nesting is recursive.
    Combined { strategies: Vec<DiversityStrategy> },
}
99
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Clustering algorithms available to the ensemble, with their parameter ranges.
///
/// NOTE(review): each `*_range` tuple is presumably `(min, max)`; ordering and
/// inclusivity are not established in this file — confirm against the consumer.
pub enum ClusteringAlgorithm {
    /// K-means with a range for the number of clusters.
    KMeans { k_range: (usize, usize) },
    /// DBSCAN with ranges for its two parameters.
    DBSCAN {
        /// Range for the `eps` parameter.
        eps_range: (f64, f64),
        /// Range for the `min_samples` parameter.
        min_samples_range: (usize, usize),
    },
    /// Mean shift with a bandwidth range.
    MeanShift { bandwidth_range: (f64, f64) },
    /// Hierarchical clustering; `methods` are free-form strings defined by the consumer.
    Hierarchical { methods: Vec<String> },
    /// Spectral clustering with a range for the number of clusters.
    Spectral { k_range: (usize, usize) },
    /// Affinity propagation with a damping range.
    AffinityPropagation { damping_range: (f64, f64) },
}
119
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Range of values a single algorithm parameter may take.
///
/// NOTE(review): the two numbers of `Integer`/`Float` are presumably
/// `(min, max)`; inclusivity is not established in this file — confirm.
pub enum ParameterRange {
    /// Integer-valued range.
    Integer(i64, i64),
    /// Floating-point range.
    Float(f64, f64),
    /// One of a fixed list of string choices.
    Categorical(Vec<String>),
    /// A boolean parameter (carries no payload).
    Boolean,
}
132
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Outcome of a single clustering run inside the ensemble.
pub struct ClusteringResult {
    /// Cluster labels. Negative labels are treated as noise by
    /// `has_noise` and `noise_count`.
    pub labels: Array1<i32>,
    /// Algorithm identifier; used as the grouping key by
    /// `EnsembleResult::algorithm_distribution`.
    pub algorithm: String,
    /// Parameters of the run as string name/value pairs.
    pub parameters: HashMap<String, String>,
    /// Quality score; `EnsembleResult::best_individual_result` treats larger as better.
    pub quality_score: f64,
    /// Stability score; `None` until set through `with_stability_score`.
    pub stability_score: Option<f64>,
    /// Cluster count. `new` derives it as the largest non-negative label plus
    /// one (`0` if there is none), so gaps in the label ids are counted too.
    pub n_clusters: usize,
    /// Runtime of the clustering run.
    /// NOTE(review): the unit is not visible in this file (seconds?) — confirm.
    pub runtime: f64,
}
151
152impl ClusteringResult {
153 pub fn new(
155 labels: Array1<i32>,
156 algorithm: String,
157 parameters: HashMap<String, String>,
158 quality_score: f64,
159 runtime: f64,
160 ) -> Self {
161 let n_clusters = labels
162 .iter()
163 .copied()
164 .filter(|&x| x >= 0)
165 .max()
166 .map(|x| x as usize + 1)
167 .unwrap_or(0);
168
169 Self {
170 labels,
171 algorithm,
172 parameters,
173 quality_score,
174 stability_score: None,
175 n_clusters,
176 runtime,
177 }
178 }
179
180 pub fn with_stability_score(mut self, score: f64) -> Self {
182 self.stability_score = Some(score);
183 self
184 }
185
186 pub fn has_noise(&self) -> bool {
188 self.labels.iter().any(|&x| x < 0)
189 }
190
191 pub fn noise_count(&self) -> usize {
193 self.labels.iter().filter(|&&x| x < 0).count()
194 }
195
196 pub fn cluster_sizes(&self) -> Vec<usize> {
198 let mut sizes = vec![0; self.n_clusters];
199 for &label in self.labels.iter() {
200 if label >= 0 {
201 let cluster_id = label as usize;
202 if cluster_id < sizes.len() {
203 sizes[cluster_id] += 1;
204 }
205 }
206 }
207 sizes
208 }
209}
210
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Combined outcome of an ensemble clustering run.
pub struct EnsembleResult {
    /// Consensus labels. Negative labels are ignored by
    /// `n_consensus_clusters` and `consensus_cluster_sizes`.
    pub consensus_labels: Array1<i32>,
    /// Results of the individual ensemble members.
    pub individual_results: Vec<ClusteringResult>,
    /// Statistics describing the consensus.
    pub consensus_stats: ConsensusStatistics,
    /// Metrics describing the diversity among ensemble members.
    pub diversity_metrics: DiversityMetrics,
    /// Quality score of the ensemble as a whole
    /// (how it is computed is not visible in this file).
    pub ensemble_quality: f64,
    /// Stability score of the ensemble
    /// (how it is computed is not visible in this file).
    pub stability_score: f64,
}
227
228impl EnsembleResult {
229 pub fn new(
231 consensus_labels: Array1<i32>,
232 individual_results: Vec<ClusteringResult>,
233 consensus_stats: ConsensusStatistics,
234 diversity_metrics: DiversityMetrics,
235 ensemble_quality: f64,
236 stability_score: f64,
237 ) -> Self {
238 Self {
239 consensus_labels,
240 individual_results,
241 consensus_stats,
242 diversity_metrics,
243 ensemble_quality,
244 stability_score,
245 }
246 }
247
248 pub fn n_consensus_clusters(&self) -> usize {
250 self.consensus_labels
251 .iter()
252 .copied()
253 .filter(|&x| x >= 0)
254 .max()
255 .map(|x| x as usize + 1)
256 .unwrap_or(0)
257 }
258
259 pub fn consensus_cluster_sizes(&self) -> Vec<usize> {
261 let n_clusters = self.n_consensus_clusters();
262 let mut sizes = vec![0; n_clusters];
263 for &label in self.consensus_labels.iter() {
264 if label >= 0 {
265 let cluster_id = label as usize;
266 if cluster_id < sizes.len() {
267 sizes[cluster_id] += 1;
268 }
269 }
270 }
271 sizes
272 }
273
274 pub fn average_individual_quality(&self) -> f64 {
276 if self.individual_results.is_empty() {
277 0.0
278 } else {
279 self.individual_results
280 .iter()
281 .map(|r| r.quality_score)
282 .sum::<f64>()
283 / self.individual_results.len() as f64
284 }
285 }
286
287 pub fn best_individual_result(&self) -> Option<&ClusteringResult> {
289 self.individual_results.iter().max_by(|a, b| {
290 a.quality_score
291 .partial_cmp(&b.quality_score)
292 .unwrap_or(std::cmp::Ordering::Equal)
293 })
294 }
295
296 pub fn algorithm_distribution(&self) -> HashMap<String, usize> {
298 let mut distribution = HashMap::new();
299 for result in &self.individual_results {
300 *distribution.entry(result.algorithm.clone()).or_insert(0) += 1;
301 }
302 distribution
303 }
304}
305
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Statistics describing how strongly the ensemble members agree.
///
/// NOTE(review): the code that fills these fields is not in this file; the
/// per-field meanings below are inferred from names and types — confirm.
pub struct ConsensusStatistics {
    /// Pairwise agreement matrix (presumably square; what the axes index is
    /// not visible here).
    pub agreement_matrix: Array2<f64>,
    /// Consensus strength values; summarised by `average_consensus_strength`,
    /// `min_consensus_strength` and `max_consensus_strength`.
    pub consensus_strength: Array1<f64>,
    /// Stability values (presumably one per consensus cluster); averaged by
    /// `average_cluster_stability`.
    pub cluster_stability: Vec<f64>,
    /// Agreement counts (presumably the number of agreeing ensemble members
    /// per entry — confirm).
    pub agreement_counts: Array1<usize>,
}
318
319impl ConsensusStatistics {
320 pub fn new(
322 agreement_matrix: Array2<f64>,
323 consensus_strength: Array1<f64>,
324 cluster_stability: Vec<f64>,
325 agreement_counts: Array1<usize>,
326 ) -> Self {
327 Self {
328 agreement_matrix,
329 consensus_strength,
330 cluster_stability,
331 agreement_counts,
332 }
333 }
334
335 pub fn average_consensus_strength(&self) -> f64 {
337 self.consensus_strength.mean().unwrap_or(0.0)
338 }
339
340 pub fn min_consensus_strength(&self) -> f64 {
342 self.consensus_strength
343 .iter()
344 .cloned()
345 .fold(f64::INFINITY, f64::min)
346 }
347
348 pub fn max_consensus_strength(&self) -> f64 {
350 self.consensus_strength
351 .iter()
352 .cloned()
353 .fold(f64::NEG_INFINITY, f64::max)
354 }
355
356 pub fn average_cluster_stability(&self) -> f64 {
358 if self.cluster_stability.is_empty() {
359 0.0
360 } else {
361 self.cluster_stability.iter().sum::<f64>() / self.cluster_stability.len() as f64
362 }
363 }
364}
365
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Metrics describing how much the ensemble members differ from each other.
///
/// NOTE(review): the code that fills these fields is not in this file; the
/// per-field meanings below are inferred from names and types — confirm.
pub struct DiversityMetrics {
    /// Average diversity. It is stored as given, not recomputed from
    /// `diversity_matrix`; `diversity_variance` uses it as the mean and
    /// `has_good_diversity` compares it against a threshold.
    pub average_diversity: f64,
    /// Pairwise diversity matrix (presumably one row/column per ensemble member).
    pub diversity_matrix: Array2<f64>,
    /// Count per algorithm name (presumably the number of members using it).
    pub algorithm_distribution: HashMap<String, usize>,
    /// Diversity value per parameter name.
    pub parameter_diversity: HashMap<String, f64>,
}
378
379impl DiversityMetrics {
380 pub fn new(
382 average_diversity: f64,
383 diversity_matrix: Array2<f64>,
384 algorithm_distribution: HashMap<String, usize>,
385 parameter_diversity: HashMap<String, f64>,
386 ) -> Self {
387 Self {
388 average_diversity,
389 diversity_matrix,
390 algorithm_distribution,
391 parameter_diversity,
392 }
393 }
394
395 pub fn max_diversity(&self) -> f64 {
397 self.diversity_matrix
398 .iter()
399 .cloned()
400 .fold(f64::NEG_INFINITY, f64::max)
401 }
402
403 pub fn min_diversity(&self) -> f64 {
405 self.diversity_matrix
406 .iter()
407 .cloned()
408 .fold(f64::INFINITY, f64::min)
409 }
410
411 pub fn diversity_variance(&self) -> f64 {
413 let mean = self.average_diversity;
414 let variance = self
415 .diversity_matrix
416 .iter()
417 .map(|&x| (x - mean).powi(2))
418 .sum::<f64>()
419 / (self.diversity_matrix.len() as f64);
420 variance
421 }
422
423 pub fn has_good_diversity(&self, threshold: f64) -> bool {
425 self.average_diversity >= threshold
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// Tolerance for comparing computed floating-point values.
    const EPS: f64 = 1e-10;

    // Relies on a `Default` implementation for `EnsembleConfig` that is
    // provided outside the code shown in this module.
    #[test]
    fn test_ensemble_config_default() {
        let config = EnsembleConfig::default();
        assert_eq!(config.n_estimators, 10);
        assert!(matches!(
            config.sampling_strategy,
            SamplingStrategy::Bootstrap { .. }
        ));
        assert!(matches!(
            config.consensus_method,
            ConsensusMethod::MajorityVoting
        ));
    }

    #[test]
    fn test_clustering_result_creation() {
        let labels = arr1(&[0, 0, 1, 1, -1]);
        let mut params = HashMap::new();
        params.insert("k".to_string(), "2".to_string());

        let result = ClusteringResult::new(labels, "kmeans".to_string(), params, 0.8, 1.5);

        assert_eq!(result.n_clusters, 2);
        assert!(result.has_noise());
        assert_eq!(result.noise_count(), 1);
        assert_eq!(result.cluster_sizes(), vec![2, 2]);

        // The stability score is unset until supplied through the builder method.
        assert!(result.stability_score.is_none());
        let result = result.with_stability_score(0.9);
        assert_eq!(result.stability_score, Some(0.9));
    }

    #[test]
    fn test_ensemble_result_metrics() {
        let consensus_labels = arr1(&[0, 0, 1, 1]);
        let individual_results = vec![
            ClusteringResult::new(
                arr1(&[0, 0, 1, 1]),
                "kmeans".to_string(),
                HashMap::new(),
                0.8,
                1.0,
            ),
            ClusteringResult::new(
                arr1(&[1, 1, 0, 0]),
                "dbscan".to_string(),
                HashMap::new(),
                0.7,
                1.5,
            ),
        ];

        let consensus_stats = ConsensusStatistics::new(
            Array2::zeros((2, 2)),
            arr1(&[0.9, 0.9, 0.8, 0.8]),
            vec![0.9, 0.8],
            arr1(&[2, 2, 2, 2]),
        );

        let diversity_metrics =
            DiversityMetrics::new(0.5, Array2::zeros((2, 2)), HashMap::new(), HashMap::new());

        let result = EnsembleResult::new(
            consensus_labels,
            individual_results,
            consensus_stats,
            diversity_metrics,
            0.85,
            0.9,
        );

        assert_eq!(result.n_consensus_clusters(), 2);
        assert_eq!(result.consensus_cluster_sizes(), vec![2, 2]);
        // The average is a computed float, so compare with a tolerance like
        // the other tests in this module do.
        assert!((result.average_individual_quality() - 0.75).abs() < EPS);

        let best = result.best_individual_result();
        assert!(best.is_some());
        assert_eq!(best.map(|r| r.algorithm.as_str()), Some("kmeans"));

        let distribution = result.algorithm_distribution();
        assert_eq!(distribution.len(), 2);
        assert_eq!(distribution.get("kmeans"), Some(&1));
        assert_eq!(distribution.get("dbscan"), Some(&1));
    }

    #[test]
    fn test_consensus_statistics() {
        let stats = ConsensusStatistics::new(
            Array2::zeros((3, 3)),
            arr1(&[0.8, 0.9, 0.7]),
            vec![0.9, 0.8, 0.85],
            arr1(&[3, 2, 3]),
        );

        assert!((stats.average_consensus_strength() - 0.8).abs() < EPS);
        // Min and max return one of the stored elements unchanged, so exact
        // equality is safe here.
        assert_eq!(stats.min_consensus_strength(), 0.7);
        assert_eq!(stats.max_consensus_strength(), 0.9);
        assert!((stats.average_cluster_stability() - 0.85).abs() < EPS);
    }

    #[test]
    fn test_diversity_metrics() {
        let metrics = DiversityMetrics::new(
            0.6,
            Array2::from_shape_vec((2, 2), vec![0.0, 0.8, 0.8, 0.0]).unwrap(),
            HashMap::new(),
            HashMap::new(),
        );

        assert_eq!(metrics.max_diversity(), 0.8);
        assert_eq!(metrics.min_diversity(), 0.0);
        assert!(metrics.has_good_diversity(0.5));
        assert!(!metrics.has_good_diversity(0.7));
        // Deviations from the stored mean 0.6: (0.36 + 0.04 + 0.04 + 0.36) / 4 = 0.2.
        assert!((metrics.diversity_variance() - 0.2).abs() < EPS);
    }
}