Skip to main content

scirs2_cluster/ensemble/
mod.rs

//! Ensemble clustering algorithms.
//!
//! This module provides ensemble clustering: it combines the results of multiple base
//! clustering runs into a single consensus labelling, which is typically more robust and
//! stable than the result of any individual run.
//!
//! # Examples
//!
//! ## Basic Ensemble Clustering
//!
//! ```rust
//! use scirs2_cluster::ensemble::{EnsembleClusterer, EnsembleConfig, SamplingStrategy};
//! use scirs2_core::ndarray::Array2;
//!
//! // Create sample data
//! let data = Array2::from_shape_vec((100, 2), (0..200).map(|x| x as f64).collect()).expect("Operation failed");
//!
//! // Configure ensemble
//! let config = EnsembleConfig {
//!     n_estimators: 10,
//!     sampling_strategy: SamplingStrategy::Bootstrap { sample_ratio: 0.8 },
//!     ..Default::default()
//! };
//!
//! // Create and fit ensemble
//! let ensemble = EnsembleClusterer::new(config);
//! let result = ensemble.fit(data.view()).expect("Operation failed");
//! println!("Ensemble quality: {}", result.ensemble_quality);
//! ```
//!
//! ## Convenience Functions
//!
//! ```rust
//! use scirs2_cluster::ensemble::convenience::ensemble_clustering;
//! use scirs2_core::ndarray::Array2;
//!
//! let data = Array2::from_shape_vec((50, 3), (0..150).map(|x| x as f64).collect()).expect("Operation failed");
//! let result = ensemble_clustering(data.view()).expect("Operation failed");
//! ```
//!
//! ## Advanced Ensemble Methods
//!
//! ```rust,no_run
//! use scirs2_cluster::ensemble::advanced::{AdvancedEnsembleClusterer, AdvancedEnsembleConfig};
//! use scirs2_cluster::ensemble::{
//!     EnsembleConfig,
//!     MetaLearningConfig, MetaLearningAlgorithm,
//!     BayesianAveragingConfig, PosteriorUpdateMethod,
//!     GeneticOptimizationConfig, SelectionMethod, FitnessFunction,
//!     BoostingConfig, ReweightingStrategy, ErrorFunction,
//!     StackingConfig, MetaClusteringAlgorithm, ClusteringAlgorithm,
//! };
//! use scirs2_core::ndarray::Array2;
//!
//! // Advanced ensemble with meta-learning
//! let data = Array2::from_shape_vec((100, 4), (0..400).map(|x| x as f64).collect()).expect("Operation failed");
//! let base_config = EnsembleConfig::default();
//! let advanced_config = AdvancedEnsembleConfig {
//!     meta_learning: MetaLearningConfig {
//!         n_meta_features: 4,
//!         learning_rate: 0.01,
//!         n_iterations: 10,
//!         algorithm: MetaLearningAlgorithm::Linear { regularization: 0.1 },
//!         validation_split: 0.2,
//!     },
//!     bayesian_averaging: BayesianAveragingConfig {
//!         prior_alpha: 1.0,
//!         prior_beta: 1.0,
//!         n_samples: 10,
//!         burn_in: 2,
//!         update_method: PosteriorUpdateMethod::MetropolisHastings,
//!         adaptive_sampling: false,
//!     },
//!     genetic_optimization: GeneticOptimizationConfig {
//!         population_size: 5,
//!         n_generations: 2,
//!         crossover_prob: 0.8,
//!         mutation_prob: 0.1,
//!         selection_method: SelectionMethod::Tournament { tournament_size: 2 },
//!         elite_percentage: 0.1,
//!         fitness_function: FitnessFunction::Silhouette,
//!     },
//!     boostingconfig: BoostingConfig {
//!         n_rounds: 3,
//!         learning_rate: 1.0,
//!         reweighting_strategy: ReweightingStrategy::Exponential,
//!         error_function: ErrorFunction::DisagreementRate,
//!         adaptive_boosting: false,
//!     },
//!     stackingconfig: StackingConfig {
//!         base_algorithms: vec![ClusteringAlgorithm::KMeans { k_range: (2, 4) }],
//!         meta_algorithm: MetaClusteringAlgorithm::Hierarchical { linkage: "ward".into() },
//!         cv_folds: 2,
//!         blending_ratio: 0.5,
//!         feature_engineering: false,
//!     },
//!     uncertainty_quantification: false,
//! };
//! let mut advanced_ensemble = AdvancedEnsembleClusterer::new(advanced_config, base_config);
//! let result = advanced_ensemble.fit_with_meta_learning(data.view()).expect("Operation failed");
//! ```
101
102pub mod advanced;
103pub mod algorithms;
104pub mod convenience;
105pub mod core;
106pub mod weighted;
107
108// Re-export main types for convenience
109pub use algorithms::EnsembleClusterer;
110pub use core::*;
111
112// Re-export convenience functions at module level for backward compatibility
113pub use convenience::{
114    adaptive_ensemble, bootstrap_ensemble, ensemble_clustering, federated_ensemble,
115    meta_clustering_ensemble, multi_algorithm_ensemble, AdaptationConfig, AdaptationStrategy,
116    AggregationMethod, FederationConfig,
117};
118
119// Re-export advanced types
120pub use advanced::{
121    AdvancedEnsembleClusterer, AdvancedEnsembleConfig, BayesianAveragingConfig, BoostingConfig,
122    ErrorFunction, FitnessFunction, GeneticOptimizationConfig, GeneticOptimizer,
123    MetaClusteringAlgorithm, MetaLearner, MetaLearningAlgorithm, MetaLearningConfig,
124    PosteriorUpdateMethod, ReweightingStrategy, SelectionMethod, StackingConfig,
125};
126
127// Maintain backward compatibility by re-exporting the convenience module
128pub mod convenience_functions {
129    pub use super::convenience::*;
130}
131
132/// Convenience function to create a default ensemble configuration
133pub fn default_ensemble_config() -> EnsembleConfig {
134    EnsembleConfig::default()
135}
136
137/// Convenience function to create a bootstrap ensemble configuration
138pub fn bootstrap_ensemble_config(n_estimators: usize, sample_ratio: f64) -> EnsembleConfig {
139    EnsembleConfig {
140        n_estimators,
141        sampling_strategy: SamplingStrategy::Bootstrap { sample_ratio },
142        ..Default::default()
143    }
144}
145
146/// Convenience function to create an algorithm diversity configuration
147pub fn algorithm_diversity_config(algorithms: Vec<ClusteringAlgorithm>) -> EnsembleConfig {
148    EnsembleConfig {
149        diversity_strategy: Some(DiversityStrategy::AlgorithmDiversity { algorithms }),
150        ..Default::default()
151    }
152}
153
154/// Convenience function to create a weighted consensus configuration
155pub fn weighted_consensus_config() -> EnsembleConfig {
156    EnsembleConfig {
157        consensus_method: ConsensusMethod::WeightedConsensus,
158        ..Default::default()
159    }
160}
161
162/// Convenience function to create a graph-based consensus configuration
163pub fn graph_based_consensus_config(similarity_threshold: f64) -> EnsembleConfig {
164    EnsembleConfig {
165        consensus_method: ConsensusMethod::GraphBased {
166            similarity_threshold,
167        },
168        ..Default::default()
169    }
170}
171
172/// Convenience function for quick ensemble clustering with default parameters
173pub fn quick_ensemble_clustering<F>(
174    data: scirs2_core::ndarray::ArrayView2<F>,
175    n_estimators: Option<usize>,
176) -> crate::error::Result<EnsembleResult>
177where
178    F: scirs2_core::numeric::Float
179        + scirs2_core::numeric::FromPrimitive
180        + std::fmt::Debug
181        + 'static
182        + std::iter::Sum
183        + std::fmt::Display
184        + Send
185        + Sync,
186    f64: From<F>,
187{
188    let config = EnsembleConfig {
189        n_estimators: n_estimators.unwrap_or(10),
190        ..Default::default()
191    };
192    let ensemble = EnsembleClusterer::new(config);
193    ensemble.fit(data)
194}
195
196/// Convenience function for multi-algorithm ensemble with common algorithms
197pub fn quick_multi_algorithm_ensemble<F>(
198    data: scirs2_core::ndarray::ArrayView2<F>,
199) -> crate::error::Result<EnsembleResult>
200where
201    F: scirs2_core::numeric::Float
202        + scirs2_core::numeric::FromPrimitive
203        + std::fmt::Debug
204        + 'static
205        + std::iter::Sum
206        + std::fmt::Display
207        + Send
208        + Sync,
209    f64: From<F>,
210{
211    let algorithms = vec![
212        ClusteringAlgorithm::KMeans { k_range: (2, 8) },
213        ClusteringAlgorithm::DBSCAN {
214            eps_range: (0.1, 1.0),
215            min_samples_range: (3, 10),
216        },
217        ClusteringAlgorithm::AffinityPropagation {
218            damping_range: (0.5, 0.9),
219        },
220    ];
221
222    multi_algorithm_ensemble(data, algorithms)
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a `rows x cols` matrix holding `0.0, 1.0, 2.0, ...` in row-major order.
    ///
    /// Shared fixture for the fitting tests below, so they all use identical input data.
    fn ramp(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), (0..rows * cols).map(|x| x as f64).collect())
            .expect("element count equals rows * cols")
    }

    #[test]
    fn test_default_ensemble_config() {
        let config = default_ensemble_config();
        assert_eq!(config.n_estimators, 10);
        assert!(matches!(
            config.sampling_strategy,
            SamplingStrategy::Bootstrap { .. }
        ));
        assert!(matches!(
            config.consensus_method,
            ConsensusMethod::MajorityVoting
        ));
    }

    #[test]
    fn test_bootstrap_ensemble_config() {
        let config = bootstrap_ensemble_config(15, 0.7);
        assert_eq!(config.n_estimators, 15);
        if let SamplingStrategy::Bootstrap { sample_ratio } = config.sampling_strategy {
            assert!((sample_ratio - 0.7).abs() < 1e-10);
        } else {
            panic!("Expected Bootstrap sampling strategy");
        }
    }

    #[test]
    fn test_algorithm_diversity_config() {
        let algorithms = vec![
            ClusteringAlgorithm::KMeans { k_range: (2, 5) },
            ClusteringAlgorithm::DBSCAN {
                eps_range: (0.1, 1.0),
                min_samples_range: (3, 10),
            },
        ];
        // `algorithms` is not needed afterwards, so it is moved rather than cloned.
        let config = algorithm_diversity_config(algorithms);

        if let Some(DiversityStrategy::AlgorithmDiversity { algorithms: algs }) =
            config.diversity_strategy
        {
            assert_eq!(algs.len(), 2);
        } else {
            panic!("Expected AlgorithmDiversity strategy");
        }
    }

    #[test]
    fn test_weighted_consensus_config() {
        let config = weighted_consensus_config();
        assert!(matches!(
            config.consensus_method,
            ConsensusMethod::WeightedConsensus
        ));
    }

    #[test]
    fn test_graph_based_consensus_config() {
        let config = graph_based_consensus_config(0.7);
        if let ConsensusMethod::GraphBased {
            similarity_threshold,
        } = config.consensus_method
        {
            assert!((similarity_threshold - 0.7).abs() < 1e-10);
        } else {
            panic!("Expected GraphBased consensus method");
        }
    }

    #[test]
    fn test_quick_ensemble_clustering() {
        let data = ramp(20, 2);
        // A single `expect` reports the error's Debug output on failure, which a bare
        // `assert!(result.is_ok())` would hide.
        let ensemble_result =
            quick_ensemble_clustering(data.view(), Some(5)).expect("ensemble fit failed");
        assert_eq!(ensemble_result.consensus_labels.len(), 20);
        assert_eq!(ensemble_result.individual_results.len(), 5);
    }

    #[test]
    fn test_quick_multi_algorithm_ensemble() {
        let data = ramp(30, 3);
        let ensemble_result =
            quick_multi_algorithm_ensemble(data.view()).expect("multi-algorithm fit failed");
        assert_eq!(ensemble_result.consensus_labels.len(), 30);
    }

    #[test]
    fn test_ensemble_result_metrics() {
        let data = ramp(15, 2);
        let ensemble_result =
            quick_ensemble_clustering(data.view(), Some(3)).expect("ensemble fit failed");
        assert!(ensemble_result.ensemble_quality >= -1.0);
        assert!(ensemble_result.ensemble_quality <= 1.0);
        assert!(ensemble_result.stability_score >= 0.0);
        assert!(ensemble_result.stability_score <= 1.0);
    }

    #[test]
    fn test_consensus_statistics() {
        let data = ramp(10, 2);
        let ensemble_result =
            quick_ensemble_clustering(data.view(), Some(3)).expect("ensemble fit failed");
        let consensus_stats = &ensemble_result.consensus_stats;

        assert_eq!(consensus_stats.consensus_strength.len(), 10);
        assert_eq!(consensus_stats.agreement_counts.len(), 10);
        assert!(consensus_stats.average_consensus_strength() >= 0.0);
        assert!(consensus_stats.average_consensus_strength() <= 1.0);
    }

    #[test]
    fn test_diversity_metrics() {
        let data = ramp(12, 2);
        let ensemble_result =
            quick_ensemble_clustering(data.view(), Some(4)).expect("ensemble fit failed");
        let diversity_metrics = &ensemble_result.diversity_metrics;

        assert!(diversity_metrics.average_diversity >= 0.0);
        assert!(diversity_metrics.average_diversity <= 1.0);
        assert_eq!(diversity_metrics.diversity_matrix.nrows(), 4);
        assert_eq!(diversity_metrics.diversity_matrix.ncols(), 4);
    }

    #[test]
    fn test_ensemble_clusterer_creation() {
        // Construction must succeed for both the default and a fully customised
        // configuration; reaching the end of this test without panicking is the check.
        let _default_ensemble: EnsembleClusterer<f64> =
            EnsembleClusterer::new(EnsembleConfig::default());

        let custom_config = EnsembleConfig {
            n_estimators: 20,
            sampling_strategy: SamplingStrategy::RandomSubspace { feature_ratio: 0.5 },
            consensus_method: ConsensusMethod::WeightedConsensus,
            random_seed: Some(42),
            diversity_strategy: Some(DiversityStrategy::AlgorithmDiversity {
                algorithms: vec![ClusteringAlgorithm::KMeans { k_range: (2, 10) }],
            }),
            quality_threshold: Some(0.1),
            max_clusters: Some(15),
        };
        let _custom_ensemble: EnsembleClusterer<f64> = EnsembleClusterer::new(custom_config);
    }
}