1use super::algorithms::EnsembleClusterer;
7use super::core::*;
8use crate::error::{ClusteringError, Result};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
10use scirs2_core::numeric::{Float, FromPrimitive};
11use scirs2_core::random::prelude::*;
12use std::collections::HashMap;
13use std::fmt::Debug;
14
15#[derive(Debug, Clone)]
17pub struct AdaptationConfig {
18 pub chunk_size: usize,
20 pub min_evaluations: usize,
22 pub performance_threshold: f64,
24 pub max_clusterers: usize,
26 pub strategy: AdaptationStrategy,
28}
29
/// Strategy used to change the composition of an ensemble during adaptation.
///
/// Derives `PartialEq`/`Eq` so that callers and tests can compare strategies
/// directly instead of pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdaptationStrategy {
    /// Presumably adds clusterers that increase diversity; confirm against the
    /// adaptation implementation.
    AddDiverse,
    /// Presumably removes the worst-performing clusterers; confirm against the
    /// adaptation implementation.
    RemoveWorst,
    /// Presumably replaces existing clusterers; confirm against the adaptation
    /// implementation.
    Replace,
    /// A combination of several strategies.
    Hybrid(Vec<AdaptationStrategy>),
}
42
43#[derive(Debug, Clone)]
45pub struct FederationConfig {
46 pub differential_privacy: bool,
48 pub privacy_budget: f64,
50 pub aggregation_method: AggregationMethod,
52 pub max_rounds: usize,
54 pub convergence_threshold: f64,
56}
57
/// Aggregation scheme requested for combining results in a federated setting.
///
/// Derives `PartialEq`/`Eq` so that callers and tests can compare methods
/// directly instead of pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Secure averaging of the local results.
    SecureAveraging,
    /// Aggregation under homomorphic encryption.
    HomomorphicEncryption,
    /// Aggregation via multi-party computation.
    MultiPartyComputation,
}
68
69pub fn ensemble_clustering<F>(data: ArrayView2<F>) -> Result<EnsembleResult>
71where
72 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
73 f64: From<F>,
74{
75 let config = EnsembleConfig::default();
76 let ensemble = EnsembleClusterer::new(config);
77 ensemble.fit(data)
78}
79
80pub fn bootstrap_ensemble<F>(
82 data: ArrayView2<F>,
83 n_estimators: usize,
84 sample_ratio: f64,
85) -> Result<EnsembleResult>
86where
87 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
88 f64: From<F>,
89{
90 let config = EnsembleConfig {
91 n_estimators,
92 sampling_strategy: SamplingStrategy::Bootstrap { sample_ratio },
93 ..Default::default()
94 };
95 let ensemble = EnsembleClusterer::new(config);
96 ensemble.fit(data)
97}
98
99pub fn multi_algorithm_ensemble<F>(
101 data: ArrayView2<F>,
102 algorithms: Vec<ClusteringAlgorithm>,
103) -> Result<EnsembleResult>
104where
105 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
106 f64: From<F>,
107{
108 let config = EnsembleConfig {
109 diversity_strategy: Some(DiversityStrategy::AlgorithmDiversity { algorithms }),
110 ..Default::default()
111 };
112 let ensemble = EnsembleClusterer::new(config);
113 ensemble.fit(data)
114}
115
116pub fn meta_clustering_ensemble<F>(
121 data: ArrayView2<F>,
122 baseconfigs: Vec<EnsembleConfig>,
123 metaconfig: EnsembleConfig,
124) -> Result<EnsembleResult>
125where
126 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
127 f64: From<F>,
128{
129 let mut base_results = Vec::new();
130 let n_samples = data.shape()[0];
131
132 for config in baseconfigs {
134 let ensemble = EnsembleClusterer::new(config);
135 let result = ensemble.fit(data)?;
136 base_results.extend(result.individual_results);
137 }
138
139 let mut meta_features = Array2::zeros((n_samples, base_results.len()));
141 for (i, result) in base_results.iter().enumerate() {
142 for (j, &label) in result.labels.iter().enumerate() {
143 meta_features[[j, i]] = F::from(label).expect("Failed to convert to float");
144 }
145 }
146
147 let meta_ensemble = EnsembleClusterer::new(metaconfig);
149 let mut meta_result = meta_ensemble.fit(meta_features.view())?;
150
151 meta_result.individual_results = base_results;
153
154 Ok(meta_result)
155}
156
157pub fn adaptive_ensemble<F>(
162 data: ArrayView2<F>,
163 config: &EnsembleConfig,
164 adaptationconfig: AdaptationConfig,
165) -> Result<EnsembleResult>
166where
167 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
168 f64: From<F>,
169{
170 let mut ensemble = EnsembleClusterer::new(config.clone());
171 let mut current_results = Vec::new();
172 let chunk_size = adaptationconfig.chunk_size;
173
174 for chunk_start in (0..data.shape()[0]).step_by(chunk_size) {
176 let chunk_end = (chunk_start + chunk_size).min(data.shape()[0]);
177 let chunk_data = data.slice(s![chunk_start..chunk_end, ..]);
178
179 let chunk_result = ensemble.fit(chunk_data)?;
181
182 if current_results.len() >= adaptationconfig.min_evaluations {
184 let performance = evaluate_ensemble_performance(¤t_results);
185
186 if performance < adaptationconfig.performance_threshold {
187 ensemble =
189 adapt_ensemble_composition(ensemble, ¤t_results, &adaptationconfig)?;
190 }
191 }
192
193 current_results.push(chunk_result);
194 }
195
196 combine_chunkresults(current_results)
198}
199
200pub fn federated_ensemble<F>(
205 data_sources: Vec<ArrayView2<F>>,
206 config: &EnsembleConfig,
207 federationconfig: FederationConfig,
208) -> Result<EnsembleResult>
209where
210 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
211 f64: From<F>,
212{
213 let mut local_results = Vec::new();
214
215 for data_source in data_sources {
217 let local_ensemble = EnsembleClusterer::new(config.clone());
218 let result = local_ensemble.fit(data_source)?;
219
220 let private_result = if federationconfig.differential_privacy {
222 apply_differential_privacy(result, federationconfig.privacy_budget)?
223 } else {
224 result
225 };
226
227 local_results.push(private_result);
228 }
229
230 let aggregated_result = secure_aggregate_results(local_results, &federationconfig)?;
232
233 Ok(aggregated_result)
234}
235
236fn evaluate_ensemble_performance(results: &[EnsembleResult]) -> f64 {
239 if results.is_empty() {
240 return 0.0;
241 }
242
243 results.iter().map(|r| r.ensemble_quality).sum::<f64>() / results.len() as f64
245}
246
247fn adapt_ensemble_composition<F>(
248 mut ensemble: EnsembleClusterer<F>,
249 results: &[EnsembleResult],
250 config: &AdaptationConfig,
251) -> Result<EnsembleClusterer<F>>
252where
253 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
254{
255 match config.strategy {
256 AdaptationStrategy::RemoveWorst => {
257 if results.len() > 1 {
259 }
262 }
263 AdaptationStrategy::AddDiverse => {
264 }
267 _ => {
268 }
270 }
271
272 Ok(ensemble)
273}
274
275fn combine_chunkresults(chunkresults: Vec<EnsembleResult>) -> Result<EnsembleResult> {
276 if chunkresults.is_empty() {
277 return Err(ClusteringError::InvalidInput(
278 "No chunk results to combine".to_string(),
279 ));
280 }
281
282 Ok(chunkresults.into_iter().next().expect("Operation failed"))
285}
286
287fn apply_differential_privacy(
288 mut result: EnsembleResult,
289 privacy_budget: f64,
290) -> Result<EnsembleResult> {
291 let mut rng = scirs2_core::random::thread_rng();
294
295 for label in result.consensus_labels.iter_mut() {
296 if rng.random::<f64>() < 0.05 {
297 *label = (*label + 1) % 3; }
300 }
301
302 Ok(result)
303}
304
305fn secure_aggregate_results(
306 local_results: Vec<EnsembleResult>,
307 config: &FederationConfig,
308) -> Result<EnsembleResult> {
309 if local_results.is_empty() {
310 return Err(ClusteringError::InvalidInput(
311 "No local results to aggregate".to_string(),
312 ));
313 }
314
315 let n_samples = local_results[0].consensus_labels.len();
318 let mut consensus_labels = Array1::<i32>::zeros(n_samples);
319
320 for i in 0..n_samples {
321 let mut votes = HashMap::new();
322 for result in &local_results {
323 *votes.entry(result.consensus_labels[i]).or_insert(0) += 1;
324 }
325
326 let majority_label = votes
328 .into_iter()
329 .max_by_key(|(_, count)| *count)
330 .map(|(label_, _)| label_)
331 .unwrap_or(0);
332
333 consensus_labels[i] = majority_label;
334 }
335
336 let mut aggregated = local_results.into_iter().next().expect("Operation failed");
338 aggregated.consensus_labels = consensus_labels;
339
340 Ok(aggregated)
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a `rows x cols` matrix holding 0.0, 1.0, 2.0, ... in row-major order.
    fn sequential_matrix(rows: usize, cols: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|v| v as f64).collect();
        Array2::from_shape_vec((rows, cols), values).expect("Operation failed")
    }

    /// The default ensemble runs successfully on a small 10x2 matrix.
    #[test]
    fn test_simple_ensemble_clustering() {
        let samples = sequential_matrix(10, 2);
        let outcome = ensemble_clustering(samples.view());
        assert!(outcome.is_ok());
    }

    /// A 5-estimator bootstrap ensemble with ratio 0.8 runs on a 20x3 matrix.
    #[test]
    fn test_bootstrap_ensemble() {
        let samples = sequential_matrix(20, 3);
        let outcome = bootstrap_ensemble(samples.view(), 5, 0.8);
        assert!(outcome.is_ok());
    }

    /// Fields of `AdaptationConfig` keep the values they were built with.
    #[test]
    fn test_adaptation_config() {
        let adaptation = AdaptationConfig {
            chunk_size: 100,
            min_evaluations: 3,
            performance_threshold: 0.5,
            max_clusterers: 20,
            strategy: AdaptationStrategy::AddDiverse,
        };

        assert_eq!(adaptation.chunk_size, 100);
        assert_eq!(adaptation.min_evaluations, 3);
    }

    /// Fields of `FederationConfig` keep the values they were built with.
    #[test]
    fn test_federation_config() {
        let federation = FederationConfig {
            differential_privacy: true,
            privacy_budget: 1.0,
            aggregation_method: AggregationMethod::SecureAveraging,
            max_rounds: 10,
            convergence_threshold: 0.01,
        };

        assert!(federation.differential_privacy);
        assert_eq!(federation.privacy_budget, 1.0);
    }
}