scirs2_cluster/ensemble/
core.rs

1//! Core types and configurations for ensemble clustering
2//!
3//! This module provides the fundamental data structures and configurations
4//! used throughout the ensemble clustering system.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Configuration for ensemble clustering
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct EnsembleConfig {
13    /// Number of base clustering algorithms to use
14    pub n_estimators: usize,
15    /// Sampling strategy for data subsets
16    pub sampling_strategy: SamplingStrategy,
17    /// Consensus method for combining results
18    pub consensus_method: ConsensusMethod,
19    /// Random seed for reproducible results
20    pub random_seed: Option<u64>,
21    /// Diversity enforcement strategy
22    pub diversity_strategy: Option<DiversityStrategy>,
23    /// Quality threshold for including results
24    pub quality_threshold: Option<f64>,
25    /// Maximum number of clusters to consider
26    pub max_clusters: Option<usize>,
27}
28
/// Sampling strategies for creating diverse datasets
///
/// Each variant only *describes* a perturbation of the input data; nothing in
/// this type applies it, so the precise meaning of every parameter is fixed by
/// the code that consumes the strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Bootstrap sampling with replacement
    ///
    /// `sample_ratio` is presumably the subset size as a fraction of the
    /// number of samples — confirm against the sampler.
    Bootstrap { sample_ratio: f64 },
    /// Random subspace sampling (feature selection)
    ///
    /// `feature_ratio` is presumably the fraction of features kept — confirm.
    RandomSubspace { feature_ratio: f64 },
    /// Combined bootstrap and subspace sampling
    ///
    /// Carries both the bootstrap `sample_ratio` and the subspace
    /// `feature_ratio`.
    BootstrapSubspace {
        sample_ratio: f64,
        feature_ratio: f64,
    },
    /// Random projection to lower dimensions
    RandomProjection { target_dimensions: usize },
    /// Noise injection for robustness testing
    ///
    /// NOTE(review): whether `noise_level` is absolute or relative to the data
    /// spread is not defined here — confirm with the noise generator.
    NoiseInjection {
        noise_level: f64,
        noise_type: NoiseType,
    },
    /// No sampling (use full dataset)
    ///
    /// Spell it `SamplingStrategy::None` to avoid confusion with `Option::None`.
    None,
}
51
/// Types of noise for injection
///
/// Used by [`SamplingStrategy::NoiseInjection`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseType {
    /// Gaussian noise
    Gaussian,
    /// Uniform noise
    Uniform,
    /// Outlier injection
    ///
    /// `outlier_ratio` is presumably the fraction of samples turned into
    /// outliers — confirm against the noise generator.
    Outliers { outlier_ratio: f64 },
}
62
/// Methods for combining clustering results
///
/// Selected via `EnsembleConfig::consensus_method`; the consensus algorithms
/// themselves are implemented outside this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConsensusMethod {
    /// Simple majority voting
    MajorityVoting,
    /// Weighted consensus based on quality scores
    WeightedConsensus,
    /// Graph-based consensus clustering
    ///
    /// `similarity_threshold` is presumably the minimum pairwise similarity for
    /// an edge — confirm with the graph builder.
    GraphBased { similarity_threshold: f64 },
    /// Hierarchical consensus
    ///
    /// `linkage_method` is free-form text; the accepted names are defined by
    /// the consumer, not validated here.
    Hierarchical { linkage_method: String },
    /// Co-association matrix approach
    ///
    /// `threshold` presumably cuts the co-association matrix — confirm.
    CoAssociation { threshold: f64 },
    /// Evidence accumulation clustering
    EvidenceAccumulation,
}
79
/// Strategies for enforcing diversity among base clusterers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DiversityStrategy {
    /// Algorithm diversity (use different algorithms)
    AlgorithmDiversity {
        algorithms: Vec<ClusteringAlgorithm>,
    },
    /// Parameter diversity (same algorithm, different parameters)
    ///
    /// `parameter_ranges` is presumably keyed by parameter name — confirm with
    /// the code that samples parameters.
    ParameterDiversity {
        algorithm: ClusteringAlgorithm,
        parameter_ranges: HashMap<String, ParameterRange>,
    },
    /// Data diversity (different data subsets)
    DataDiversity {
        sampling_strategies: Vec<SamplingStrategy>,
    },
    /// Combined diversity strategy
    ///
    /// Recursive: each element may itself be any `DiversityStrategy`,
    /// including another `Combined`.
    Combined { strategies: Vec<DiversityStrategy> },
}
99
/// Supported clustering algorithms for ensemble
///
/// Tuple ranges are presumably `(low, high)` bounds; whether they are
/// inclusive is decided by the consumer — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringAlgorithm {
    /// K-means clustering
    KMeans { k_range: (usize, usize) },
    /// DBSCAN clustering
    DBSCAN {
        eps_range: (f64, f64),
        min_samples_range: (usize, usize),
    },
    /// Mean shift clustering
    MeanShift { bandwidth_range: (f64, f64) },
    /// Hierarchical clustering
    ///
    /// `methods` holds method names as free-form strings (not validated here).
    Hierarchical { methods: Vec<String> },
    /// Spectral clustering
    Spectral { k_range: (usize, usize) },
    /// Affinity propagation
    AffinityPropagation { damping_range: (f64, f64) },
}
119
/// Parameter ranges for diversity
///
/// Used as the values of `DiversityStrategy::ParameterDiversity`'s
/// `parameter_ranges` map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterRange {
    /// Integer range (presumably `(min, max)` — confirm inclusivity with the sampler)
    Integer(i64, i64),
    /// Float range (presumably `(min, max)`)
    Float(f64, f64),
    /// Categorical choices
    Categorical(Vec<String>),
    /// Boolean choice
    Boolean,
}
132
/// Result of a single clustering run
///
/// Usually built with `ClusteringResult::new`, which derives `n_clusters`
/// from `labels`. Because all fields are public, they can also be set directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
    /// Cluster labels, one per point; negative values mark noise points
    pub labels: Array1<i32>,
    /// Algorithm used (free-form name)
    pub algorithm: String,
    /// Parameters used, stored as string name/value pairs
    pub parameters: HashMap<String, String>,
    /// Quality score; the ensemble treats higher as better when picking a best result
    pub quality_score: f64,
    /// Stability score (if available)
    pub stability_score: Option<f64>,
    /// Number of clusters found (largest non-negative label + 1 when built via `new`)
    pub n_clusters: usize,
    /// Runtime in seconds
    pub runtime: f64,
}
151
152impl ClusteringResult {
153    /// Create a new clustering result
154    pub fn new(
155        labels: Array1<i32>,
156        algorithm: String,
157        parameters: HashMap<String, String>,
158        quality_score: f64,
159        runtime: f64,
160    ) -> Self {
161        let n_clusters = labels
162            .iter()
163            .copied()
164            .filter(|&x| x >= 0)
165            .max()
166            .map(|x| x as usize + 1)
167            .unwrap_or(0);
168
169        Self {
170            labels,
171            algorithm,
172            parameters,
173            quality_score,
174            stability_score: None,
175            n_clusters,
176            runtime,
177        }
178    }
179
180    /// Set stability score
181    pub fn with_stability_score(mut self, score: f64) -> Self {
182        self.stability_score = Some(score);
183        self
184    }
185
186    /// Check if this result has noise points
187    pub fn has_noise(&self) -> bool {
188        self.labels.iter().any(|&x| x < 0)
189    }
190
191    /// Get number of noise points
192    pub fn noise_count(&self) -> usize {
193        self.labels.iter().filter(|&&x| x < 0).count()
194    }
195
196    /// Get cluster sizes
197    pub fn cluster_sizes(&self) -> Vec<usize> {
198        let mut sizes = vec![0; self.n_clusters];
199        for &label in self.labels.iter() {
200            if label >= 0 {
201                let cluster_id = label as usize;
202                if cluster_id < sizes.len() {
203                    sizes[cluster_id] += 1;
204                }
205            }
206        }
207        sizes
208    }
209}
210
/// Ensemble clustering result
///
/// Bundles the consensus labelling with the individual runs and the
/// statistics computed over them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleResult {
    /// Final consensus labels (negative values are treated as noise by the helpers)
    pub consensus_labels: Array1<i32>,
    /// Individual clustering results
    pub individual_results: Vec<ClusteringResult>,
    /// Consensus statistics
    pub consensus_stats: ConsensusStatistics,
    /// Diversity metrics
    pub diversity_metrics: DiversityMetrics,
    /// Overall quality score
    pub ensemble_quality: f64,
    /// Stability score
    pub stability_score: f64,
}
227
228impl EnsembleResult {
229    /// Create a new ensemble result
230    pub fn new(
231        consensus_labels: Array1<i32>,
232        individual_results: Vec<ClusteringResult>,
233        consensus_stats: ConsensusStatistics,
234        diversity_metrics: DiversityMetrics,
235        ensemble_quality: f64,
236        stability_score: f64,
237    ) -> Self {
238        Self {
239            consensus_labels,
240            individual_results,
241            consensus_stats,
242            diversity_metrics,
243            ensemble_quality,
244            stability_score,
245        }
246    }
247
248    /// Get number of consensus clusters
249    pub fn n_consensus_clusters(&self) -> usize {
250        self.consensus_labels
251            .iter()
252            .copied()
253            .filter(|&x| x >= 0)
254            .max()
255            .map(|x| x as usize + 1)
256            .unwrap_or(0)
257    }
258
259    /// Get consensus cluster sizes
260    pub fn consensus_cluster_sizes(&self) -> Vec<usize> {
261        let n_clusters = self.n_consensus_clusters();
262        let mut sizes = vec![0; n_clusters];
263        for &label in self.consensus_labels.iter() {
264            if label >= 0 {
265                let cluster_id = label as usize;
266                if cluster_id < sizes.len() {
267                    sizes[cluster_id] += 1;
268                }
269            }
270        }
271        sizes
272    }
273
274    /// Get average quality of individual results
275    pub fn average_individual_quality(&self) -> f64 {
276        if self.individual_results.is_empty() {
277            0.0
278        } else {
279            self.individual_results
280                .iter()
281                .map(|r| r.quality_score)
282                .sum::<f64>()
283                / self.individual_results.len() as f64
284        }
285    }
286
287    /// Get best individual result
288    pub fn best_individual_result(&self) -> Option<&ClusteringResult> {
289        self.individual_results.iter().max_by(|a, b| {
290            a.quality_score
291                .partial_cmp(&b.quality_score)
292                .unwrap_or(std::cmp::Ordering::Equal)
293        })
294    }
295
296    /// Get algorithm distribution
297    pub fn algorithm_distribution(&self) -> HashMap<String, usize> {
298        let mut distribution = HashMap::new();
299        for result in &self.individual_results {
300            *distribution.entry(result.algorithm.clone()).or_insert(0) += 1;
301        }
302        distribution
303    }
304}
305
/// Statistics about the consensus process
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusStatistics {
    /// Agreement matrix between clusterers
    /// (presumably square, one row/column per base clusterer — confirm with the builder)
    pub agreement_matrix: Array2<f64>,
    /// Per-sample consensus strength
    pub consensus_strength: Array1<f64>,
    /// Cluster stability scores
    pub cluster_stability: Vec<f64>,
    /// Number of clusterers agreeing on each sample
    pub agreement_counts: Array1<usize>,
}
318
319impl ConsensusStatistics {
320    /// Create new consensus statistics
321    pub fn new(
322        agreement_matrix: Array2<f64>,
323        consensus_strength: Array1<f64>,
324        cluster_stability: Vec<f64>,
325        agreement_counts: Array1<usize>,
326    ) -> Self {
327        Self {
328            agreement_matrix,
329            consensus_strength,
330            cluster_stability,
331            agreement_counts,
332        }
333    }
334
335    /// Get average consensus strength
336    pub fn average_consensus_strength(&self) -> f64 {
337        self.consensus_strength.mean().unwrap_or(0.0)
338    }
339
340    /// Get minimum consensus strength
341    pub fn min_consensus_strength(&self) -> f64 {
342        self.consensus_strength
343            .iter()
344            .cloned()
345            .fold(f64::INFINITY, f64::min)
346    }
347
348    /// Get maximum consensus strength
349    pub fn max_consensus_strength(&self) -> f64 {
350        self.consensus_strength
351            .iter()
352            .cloned()
353            .fold(f64::NEG_INFINITY, f64::max)
354    }
355
356    /// Get average cluster stability
357    pub fn average_cluster_stability(&self) -> f64 {
358        if self.cluster_stability.is_empty() {
359            0.0
360        } else {
361            self.cluster_stability.iter().sum::<f64>() / self.cluster_stability.len() as f64
362        }
363    }
364}
365
/// Diversity metrics for the ensemble
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversityMetrics {
    /// Average pairwise diversity (1 - ARI)
    pub average_diversity: f64,
    /// Diversity matrix between all pairs
    ///
    /// The helper methods aggregate over every entry, including the diagonal.
    pub diversity_matrix: Array2<f64>,
    /// Algorithm distribution (algorithm name -> number of runs)
    pub algorithm_distribution: HashMap<String, usize>,
    /// Parameter diversity statistics
    pub parameter_diversity: HashMap<String, f64>,
}
378
379impl DiversityMetrics {
380    /// Create new diversity metrics
381    pub fn new(
382        average_diversity: f64,
383        diversity_matrix: Array2<f64>,
384        algorithm_distribution: HashMap<String, usize>,
385        parameter_diversity: HashMap<String, f64>,
386    ) -> Self {
387        Self {
388            average_diversity,
389            diversity_matrix,
390            algorithm_distribution,
391            parameter_diversity,
392        }
393    }
394
395    /// Get maximum pairwise diversity
396    pub fn max_diversity(&self) -> f64 {
397        self.diversity_matrix
398            .iter()
399            .cloned()
400            .fold(f64::NEG_INFINITY, f64::max)
401    }
402
403    /// Get minimum pairwise diversity
404    pub fn min_diversity(&self) -> f64 {
405        self.diversity_matrix
406            .iter()
407            .cloned()
408            .fold(f64::INFINITY, f64::min)
409    }
410
411    /// Get diversity variance
412    pub fn diversity_variance(&self) -> f64 {
413        let mean = self.average_diversity;
414        let variance = self
415            .diversity_matrix
416            .iter()
417            .map(|&x| (x - mean).powi(2))
418            .sum::<f64>()
419            / (self.diversity_matrix.len() as f64);
420        variance
421    }
422
423    /// Check if ensemble has good diversity
424    pub fn has_good_diversity(&self, threshold: f64) -> bool {
425        self.average_diversity >= threshold
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// Shorthand for a `ClusteringResult` that records no parameters.
    fn bare_run(labels: &[i32], algo: &str, quality: f64, secs: f64) -> ClusteringResult {
        ClusteringResult::new(arr1(labels), algo.to_string(), HashMap::new(), quality, secs)
    }

    #[test]
    fn test_ensemble_config_default() {
        let cfg = EnsembleConfig::default();

        assert_eq!(cfg.n_estimators, 10);
        assert!(matches!(cfg.sampling_strategy, SamplingStrategy::Bootstrap { .. }));
        assert!(matches!(cfg.consensus_method, ConsensusMethod::MajorityVoting));
    }

    #[test]
    fn test_clustering_result_creation() {
        let params: HashMap<String, String> = vec![("k".to_string(), "2".to_string())]
            .into_iter()
            .collect();
        let outcome = ClusteringResult::new(
            arr1(&[0, 0, 1, 1, -1]),
            "kmeans".to_string(),
            params,
            0.8,
            1.5,
        );

        assert_eq!(outcome.n_clusters, 2);
        assert!(outcome.has_noise());
        assert_eq!(outcome.noise_count(), 1);
        assert_eq!(outcome.cluster_sizes(), vec![2, 2]);
    }

    #[test]
    fn test_ensemble_result_metrics() {
        let members = vec![
            bare_run(&[0, 0, 1, 1], "kmeans", 0.8, 1.0),
            bare_run(&[1, 1, 0, 0], "dbscan", 0.7, 1.5),
        ];
        let stats = ConsensusStatistics::new(
            Array2::zeros((2, 2)),
            arr1(&[0.9, 0.9, 0.8, 0.8]),
            vec![0.9, 0.8],
            arr1(&[2, 2, 2, 2]),
        );
        let diversity =
            DiversityMetrics::new(0.5, Array2::zeros((2, 2)), HashMap::new(), HashMap::new());

        let ensemble =
            EnsembleResult::new(arr1(&[0, 0, 1, 1]), members, stats, diversity, 0.85, 0.9);

        assert_eq!(ensemble.n_consensus_clusters(), 2);
        assert_eq!(ensemble.average_individual_quality(), 0.75);
        assert!(ensemble.best_individual_result().is_some());
    }

    #[test]
    fn test_consensus_statistics() {
        let strengths = arr1(&[0.8, 0.9, 0.7]);
        let stats = ConsensusStatistics::new(
            Array2::zeros((3, 3)),
            strengths,
            vec![0.9, 0.8, 0.85],
            arr1(&[3, 2, 3]),
        );
        let close = |got: f64, want: f64| (got - want).abs() < 1e-10;

        assert!(close(stats.average_consensus_strength(), 0.8));
        assert_eq!(stats.min_consensus_strength(), 0.7);
        assert_eq!(stats.max_consensus_strength(), 0.9);
        assert!(close(stats.average_cluster_stability(), 0.85));
    }

    #[test]
    fn test_diversity_metrics() {
        let matrix = Array2::from_shape_vec((2, 2), vec![0.0, 0.8, 0.8, 0.0]).unwrap();
        let metrics = DiversityMetrics::new(0.6, matrix, HashMap::new(), HashMap::new());

        assert_eq!(metrics.max_diversity(), 0.8);
        assert_eq!(metrics.min_diversity(), 0.0);
        assert!(metrics.has_good_diversity(0.5));
        assert!(!metrics.has_good_diversity(0.7));
    }
}
533        assert!(!metrics.has_good_diversity(0.7));
534    }
535}