scirs2_cluster/ensemble/
algorithms.rs

1//! Main ensemble clustering implementation
2//!
3//! This module contains the core ensemble clustering algorithm that combines
4//! multiple base clustering results using various consensus methods and
5//! diversity strategies.
6
7use super::core::*;
8use crate::error::{ClusteringError, Result};
9use crate::metrics::{adjusted_rand_index, silhouette_score};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use scirs2_core::random::prelude::*;
13use std::collections::{HashMap, HashSet};
14use std::fmt::Debug;
15
16/// Main ensemble clustering implementation
17pub struct EnsembleClusterer<F: Float> {
18    config: EnsembleConfig,
19    phantom: std::marker::PhantomData<F>,
20}
21
22impl<
23        F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
24    > EnsembleClusterer<F>
25where
26    f64: From<F>,
27{
28    /// Create a new ensemble clusterer
29    pub fn new(config: EnsembleConfig) -> Self {
30        Self {
31            config,
32            phantom: std::marker::PhantomData,
33        }
34    }
35
36    /// Perform ensemble clustering
37    pub fn fit(&self, data: ArrayView2<F>) -> Result<EnsembleResult> {
38        let start_time = std::time::Instant::now();
39
40        // Generate diverse clustering results
41        let individual_results = self.generate_diverse_clusterings(data)?;
42
43        // Filter results based on quality threshold
44        let filtered_results = self.filter_by_quality(&individual_results);
45
46        // Combine results using consensus method
47        let consensus_labels = self.build_consensus(&filtered_results, data)?;
48
49        // Calculate ensemble statistics
50        let consensus_stats =
51            self.calculate_consensus_statistics(&filtered_results, &consensus_labels)?;
52        let diversity_metrics = self.calculate_diversity_metrics(&filtered_results)?;
53
54        // Calculate overall quality
55        let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
56        let ensemble_quality =
57            silhouette_score(data_f64.view(), consensus_labels.view()).unwrap_or(0.0);
58
59        // Calculate stability score
60        let stability_score = self.calculate_consensus_stability_score(&consensus_stats);
61
62        let total_time = start_time.elapsed().as_secs_f64();
63
64        Ok(EnsembleResult {
65            consensus_labels,
66            individual_results: filtered_results,
67            consensus_stats,
68            diversity_metrics,
69            ensemble_quality,
70            stability_score,
71        })
72    }
73
74    /// Generate diverse clustering results
75    fn generate_diverse_clusterings(&self, data: ArrayView2<F>) -> Result<Vec<ClusteringResult>> {
76        let mut results = Vec::new();
77        let mut rng = match self.config.random_seed {
78            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
79            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
80        };
81
82        for i in 0..self.config.n_estimators {
83            let clustering_start = std::time::Instant::now();
84
85            // Apply sampling strategy
86            let (sampled_data, sample_indices) = self.apply_sampling_strategy(data, &mut rng)?;
87
88            // Select algorithm and parameters based on diversity strategy
89            let (algorithm, parameters) = self.select_algorithm_and_parameters(i, &mut rng)?;
90
91            // Run clustering
92            let mut labels = self.run_clustering(&sampled_data, &algorithm, &parameters)?;
93
94            // Map labels back to original data size if needed
95            if sample_indices.len() != data.nrows() {
96                labels = self.map_labels_to_full_data(&labels, &sample_indices, data.nrows())?;
97            }
98
99            // Calculate quality score
100            let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
101            let quality_score = silhouette_score(data_f64.view(), labels.view()).unwrap_or(-1.0);
102
103            let runtime = clustering_start.elapsed().as_secs_f64();
104            let n_clusters = self.count_clusters(&labels);
105
106            let result = ClusteringResult {
107                labels,
108                algorithm: format!("{:?}", algorithm),
109                parameters,
110                quality_score,
111                stability_score: None,
112                n_clusters,
113                runtime,
114            };
115
116            results.push(result);
117        }
118
119        Ok(results)
120    }
121
122    /// Apply sampling strategy to data
123    fn apply_sampling_strategy(
124        &self,
125        data: ArrayView2<F>,
126        rng: &mut scirs2_core::random::rngs::StdRng,
127    ) -> Result<(Array2<F>, Vec<usize>)> {
128        let n_samples = data.nrows();
129        let n_features = data.ncols();
130
131        match &self.config.sampling_strategy {
132            SamplingStrategy::Bootstrap { sample_ratio } => {
133                let sample_size = (n_samples as f64 * sample_ratio) as usize;
134                let mut indices = Vec::new();
135
136                for _ in 0..sample_size {
137                    indices.push(rng.gen_range(0..n_samples));
138                }
139
140                let sampled_data = self.extract_samples(data, &indices)?;
141                Ok((sampled_data, indices))
142            }
143            SamplingStrategy::RandomSubspace { feature_ratio } => {
144                let n_selected_features = (n_features as f64 * feature_ratio) as usize;
145                let mut featureindices: Vec<usize> = (0..n_features).collect();
146                featureindices.shuffle(rng);
147                featureindices.truncate(n_selected_features);
148
149                let sample_indices: Vec<usize> = (0..n_samples).collect();
150                let sampled_data = self.extract_features(data, &featureindices)?;
151                Ok((sampled_data, sample_indices))
152            }
153            SamplingStrategy::BootstrapSubspace {
154                sample_ratio,
155                feature_ratio,
156            } => {
157                // First apply bootstrap sampling
158                let sample_size = (n_samples as f64 * sample_ratio) as usize;
159                let mut sample_indices = Vec::new();
160
161                for _ in 0..sample_size {
162                    sample_indices.push(rng.gen_range(0..n_samples));
163                }
164
165                // Then apply feature sampling
166                let n_selected_features = (n_features as f64 * feature_ratio) as usize;
167                let mut featureindices: Vec<usize> = (0..n_features).collect();
168                featureindices.shuffle(rng);
169                featureindices.truncate(n_selected_features);
170
171                let bootstrap_data = self.extract_samples(data, &sample_indices)?;
172                let sampled_data = self.extract_features(bootstrap_data.view(), &featureindices)?;
173
174                Ok((sampled_data, sample_indices))
175            }
176            SamplingStrategy::NoiseInjection {
177                noise_level,
178                noise_type,
179            } => {
180                let sample_indices: Vec<usize> = (0..n_samples).collect();
181                let mut noisy_data = data.to_owned();
182
183                match noise_type {
184                    NoiseType::Gaussian => {
185                        for i in 0..n_samples {
186                            for j in 0..n_features {
187                                let noise = F::from(rng.gen::<f64>() * 2.0 - 1.0).unwrap()
188                                    * F::from(*noise_level).unwrap();
189                                noisy_data[[i, j]] = noisy_data[[i, j]] + noise;
190                            }
191                        }
192                    }
193                    NoiseType::Uniform => {
194                        for i in 0..n_samples {
195                            for j in 0..n_features {
196                                let noise =
197                                    F::from((rng.gen::<f64>() * 2.0 - 1.0) * noise_level).unwrap();
198                                noisy_data[[i, j]] = noisy_data[[i, j]] + noise;
199                            }
200                        }
201                    }
202                    NoiseType::Outliers { outlier_ratio } => {
203                        let n_outliers = (n_samples as f64 * outlier_ratio) as usize;
204                        for _ in 0..n_outliers {
205                            let outlier_idx = rng.gen_range(0..n_samples);
206                            for j in 0..n_features {
207                                let outlier_value = F::from(rng.gen::<f64>() * 10.0 - 5.0).unwrap();
208                                noisy_data[[outlier_idx, j]] = outlier_value;
209                            }
210                        }
211                    }
212                }
213
214                Ok((noisy_data, sample_indices))
215            }
216            SamplingStrategy::None => {
217                let sample_indices: Vec<usize> = (0..n_samples).collect();
218                Ok((data.to_owned(), sample_indices))
219            }
220            SamplingStrategy::RandomProjection { target_dimensions } => {
221                let n_features = data.ncols();
222                if *target_dimensions >= n_features {
223                    // If target dimensions >= original dimensions, no projection needed
224                    let sample_indices: Vec<usize> = (0..n_samples).collect();
225                    return Ok((data.to_owned(), sample_indices));
226                }
227
228                // Generate random projection matrix using Gaussian random values
229                let mut rng = match self.config.random_seed {
230                    Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
231                    None => scirs2_core::random::rngs::StdRng::seed_from_u64(
232                        scirs2_core::random::random(),
233                    ),
234                };
235
236                // Create random projection matrix (n_features x target_dimensions)
237                let mut projection_matrix = Array2::zeros((n_features, *target_dimensions));
238                for i in 0..n_features {
239                    for j in 0..*target_dimensions {
240                        // Use Gaussian random values for projection matrix
241                        let random_val = F::from(rng.gen::<f64>()).unwrap();
242                        let two = F::from(2.0).unwrap();
243                        let one = F::from(1.0).unwrap();
244                        projection_matrix[[i, j]] = random_val * two - one;
245                    }
246                }
247
248                // Normalize columns to preserve distances approximately
249                for j in 0..*target_dimensions {
250                    let col_norm = projection_matrix.column(j).mapv(|x| x * x).sum().sqrt();
251                    if col_norm > F::zero() {
252                        for i in 0..n_features {
253                            projection_matrix[[i, j]] = projection_matrix[[i, j]] / col_norm;
254                        }
255                    }
256                }
257
258                // Apply random projection: data * projection_matrix
259                let projected_data = data.dot(&projection_matrix);
260                let sample_indices: Vec<usize> = (0..n_samples).collect();
261
262                Ok((projected_data, sample_indices))
263            }
264        }
265    }
266
267    /// Extract samples based on indices
268    fn extract_samples(&self, data: ArrayView2<F>, indices: &[usize]) -> Result<Array2<F>> {
269        let n_features = data.ncols();
270        let mut sampled_data = Array2::zeros((indices.len(), n_features));
271
272        for (new_idx, &orig_idx) in indices.iter().enumerate() {
273            if orig_idx >= data.nrows() {
274                return Err(ClusteringError::InvalidInput(
275                    "Sample index out of bounds".to_string(),
276                ));
277            }
278            sampled_data.row_mut(new_idx).assign(&data.row(orig_idx));
279        }
280
281        Ok(sampled_data)
282    }
283
284    /// Extract features based on indices
285    fn extract_features(&self, data: ArrayView2<F>, featureindices: &[usize]) -> Result<Array2<F>> {
286        let n_samples = data.nrows();
287        let mut feature_data = Array2::zeros((n_samples, featureindices.len()));
288
289        for (new_idx, &orig_idx) in featureindices.iter().enumerate() {
290            if orig_idx >= data.ncols() {
291                return Err(ClusteringError::InvalidInput(
292                    "Feature index out of bounds".to_string(),
293                ));
294            }
295            feature_data
296                .column_mut(new_idx)
297                .assign(&data.column(orig_idx));
298        }
299
300        Ok(feature_data)
301    }
302
303    /// Apply consensus method to combine clustering results
304    fn apply_consensus(
305        &self,
306        results: &[ClusteringResult],
307        data: ArrayView2<F>,
308    ) -> Result<EnsembleResult> {
309        match &self.config.consensus_method {
310            ConsensusMethod::MajorityVoting => self.majority_voting_consensus(results, data),
311            ConsensusMethod::WeightedConsensus => self.weighted_consensus(results, data),
312            ConsensusMethod::GraphBased {
313                similarity_threshold,
314            } => {
315                let result = self.graph_based_consensus(results, data, *similarity_threshold)?;
316                Ok(result)
317            }
318            ConsensusMethod::CoAssociation { threshold } => {
319                let result = self.co_association_consensus(results, data, *threshold)?;
320                Ok(result)
321            }
322            ConsensusMethod::EvidenceAccumulation => {
323                let result = self.evidence_accumulation_consensus(results, data)?;
324                Ok(result)
325            }
326            ConsensusMethod::Hierarchical { linkage_method } => {
327                self.hierarchical_consensus(results, data, linkage_method)
328            }
329        }
330    }
331
332    /// Majority voting consensus method
333    fn majority_voting_consensus(
334        &self,
335        results: &[ClusteringResult],
336        data: ArrayView2<F>,
337    ) -> Result<EnsembleResult> {
338        let n_samples = data.nrows();
339        let mut consensus_labels = Array1::zeros(n_samples);
340        let mut vote_matrix = HashMap::new();
341
342        // Collect votes for each sample
343        for result in results {
344            for (sample_idx, &cluster_label) in result.labels.iter().enumerate() {
345                let entry = vote_matrix.entry(sample_idx).or_insert_with(HashMap::new);
346                *entry.entry(cluster_label).or_insert(0) += 1;
347            }
348        }
349
350        // Determine consensus labels
351        for sample_idx in 0..n_samples {
352            if let Some(votes) = vote_matrix.get(&sample_idx) {
353                let most_voted_cluster = votes
354                    .iter()
355                    .max_by_key(|(_, &count)| count)
356                    .map(|(&cluster_, _)| cluster_)
357                    .unwrap_or(0);
358                consensus_labels[sample_idx] = most_voted_cluster;
359            }
360        }
361
362        // Calculate confidence and statistics
363        let avg_quality_score =
364            results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64;
365        let consensus_stats = self.calculate_consensus_statistics(results, &consensus_labels)?;
366        let diversity_metrics = self.calculate_diversity_metrics(results)?;
367        let stability_score = self.calculate_consensus_stability_score(&consensus_stats);
368
369        Ok(EnsembleResult {
370            consensus_labels,
371            individual_results: results.to_vec(),
372            consensus_stats,
373            diversity_metrics,
374            ensemble_quality: avg_quality_score,
375            stability_score,
376        })
377    }
378
379    /// Weighted consensus method based on quality scores
380    fn weighted_consensus(
381        &self,
382        results: &[ClusteringResult],
383        data: ArrayView2<F>,
384    ) -> Result<EnsembleResult> {
385        let n_samples = data.nrows();
386        let mut consensus_labels = Array1::zeros(n_samples);
387        let mut weighted_vote_matrix = HashMap::new();
388
389        // Collect weighted votes for each sample
390        for result in results {
391            let weight = result.quality_score.max(0.0); // Ensure non-negative weights
392            for (sample_idx, &cluster_label) in result.labels.iter().enumerate() {
393                let entry = weighted_vote_matrix
394                    .entry(sample_idx)
395                    .or_insert_with(HashMap::new);
396                *entry.entry(cluster_label).or_insert(0.0) += weight;
397            }
398        }
399
400        // Determine consensus labels based on weighted votes
401        for sample_idx in 0..n_samples {
402            if let Some(votes) = weighted_vote_matrix.get(&sample_idx) {
403                let most_voted_cluster = votes
404                    .iter()
405                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
406                    .map(|(&cluster_, _)| cluster_)
407                    .unwrap_or(0);
408                consensus_labels[sample_idx] = most_voted_cluster;
409            }
410        }
411
412        // Calculate ensemble score as weighted average
413        let total_weight: f64 = results.iter().map(|r| r.quality_score.max(0.0)).sum();
414        let ensemble_score = if total_weight > 0.0 {
415            results
416                .iter()
417                .map(|r| r.quality_score * r.quality_score.max(0.0))
418                .sum::<f64>()
419                / total_weight
420        } else {
421            0.0
422        };
423
424        let consensus_stats = self.calculate_consensus_statistics(results, &consensus_labels)?;
425        let diversity_metrics = self.calculate_diversity_metrics(results)?;
426        let stability_score = self.calculate_consensus_stability_score(&consensus_stats);
427
428        Ok(EnsembleResult {
429            consensus_labels,
430            individual_results: results.to_vec(),
431            consensus_stats,
432            diversity_metrics,
433            ensemble_quality: ensemble_score,
434            stability_score,
435        })
436    }
437
438    /// Graph-based consensus method
439    fn graph_based_consensus(
440        &self,
441        results: &[ClusteringResult],
442        data: ArrayView2<F>,
443        similarity_threshold: f64,
444    ) -> Result<EnsembleResult> {
445        let n_samples = data.nrows();
446
447        // Build co-association matrix
448        let mut co_association = Array2::zeros((n_samples, n_samples));
449
450        for result in results {
451            for i in 0..n_samples {
452                for j in i + 1..n_samples {
453                    if result.labels[i] == result.labels[j] {
454                        co_association[[i, j]] += 1.0;
455                        co_association[[j, i]] += 1.0;
456                    }
457                }
458            }
459        }
460
461        // Normalize by number of clusterers
462        co_association /= results.len() as f64;
463
464        // Create similarity graph
465        let mut similarity_graph = Array2::zeros((n_samples, n_samples));
466        for i in 0..n_samples {
467            for j in 0..n_samples {
468                if co_association[[i, j]] >= similarity_threshold {
469                    similarity_graph[[i, j]] = co_association[[i, j]];
470                }
471            }
472        }
473
474        // Apply graph clustering (simplified connected components)
475        let mut consensus_labels = Array1::from_elem(n_samples, -1i32);
476        let mut current_cluster = 0i32;
477        let mut visited = vec![false; n_samples];
478
479        for i in 0..n_samples {
480            if !visited[i] {
481                // BFS to find connected component
482                let mut queue = vec![i];
483                visited[i] = true;
484                consensus_labels[i] = current_cluster;
485
486                while let Some(node) = queue.pop() {
487                    for j in 0..n_samples {
488                        if !visited[j] && similarity_graph[[node, j]] > 0.0 {
489                            visited[j] = true;
490                            consensus_labels[j] = current_cluster;
491                            queue.push(j);
492                        }
493                    }
494                }
495                current_cluster += 1;
496            }
497        }
498
499        let avg_quality_score =
500            results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64;
501        let consensus_stats = self.calculate_consensus_statistics(results, &consensus_labels)?;
502        let diversity_metrics = self.calculate_diversity_metrics(results)?;
503        let stability_score = self.calculate_consensus_stability_score(&consensus_stats);
504
505        Ok(EnsembleResult {
506            consensus_labels,
507            individual_results: results.to_vec(),
508            consensus_stats,
509            diversity_metrics,
510            ensemble_quality: avg_quality_score,
511            stability_score,
512        })
513    }
514
515    /// Co-association consensus method
516    fn co_association_consensus(
517        &self,
518        results: &[ClusteringResult],
519        data: ArrayView2<F>,
520        threshold: f64,
521    ) -> Result<EnsembleResult> {
522        // This is similar to graph-based but with different threshold handling
523        self.graph_based_consensus(results, data, threshold)
524    }
525
526    /// Evidence accumulation consensus method
527    fn evidence_accumulation_consensus(
528        &self,
529        results: &[ClusteringResult],
530        data: ArrayView2<F>,
531    ) -> Result<EnsembleResult> {
532        // Use hierarchical clustering on the co-association matrix
533        self.hierarchical_consensus(results, data, "ward")
534    }
535
536    /// Hierarchical consensus method
537    fn hierarchical_consensus(
538        &self,
539        results: &[ClusteringResult],
540        data: ArrayView2<F>,
541        linkage_method: &str,
542    ) -> Result<EnsembleResult> {
543        let n_samples = data.nrows();
544
545        // Build co-association matrix as distance matrix
546        let mut co_association: Array2<f64> = Array2::zeros((n_samples, n_samples));
547
548        for result in results {
549            for i in 0..n_samples {
550                for j in i + 1..n_samples {
551                    if result.labels[i] == result.labels[j] {
552                        co_association[[i, j]] += 1.0;
553                        co_association[[j, i]] += 1.0;
554                    }
555                }
556            }
557        }
558
559        // Convert to distance matrix (1 - similarity)
560        let mut distance_matrix = Array2::ones((n_samples, n_samples));
561        for i in 0..n_samples {
562            for j in 0..n_samples {
563                distance_matrix[[i, j]] = 1.0 - (co_association[[i, j]] / results.len() as f64);
564            }
565            distance_matrix[[i, i]] = 0.0; // Distance to self is 0
566        }
567
568        // Apply hierarchical clustering (simplified implementation)
569        // For now, use a simple threshold-based approach
570        let threshold = 0.5;
571        let mut consensus_labels = Array1::from_elem(n_samples, -1i32);
572        let mut current_cluster = 0i32;
573        let mut assigned = vec![false; n_samples];
574
575        for i in 0..n_samples {
576            if !assigned[i] {
577                consensus_labels[i] = current_cluster;
578                assigned[i] = true;
579
580                // Find all points within threshold distance
581                for j in (i + 1)..n_samples {
582                    if !assigned[j] && distance_matrix[[i, j]] <= threshold {
583                        consensus_labels[j] = current_cluster;
584                        assigned[j] = true;
585                    }
586                }
587                current_cluster += 1;
588            }
589        }
590
591        let avg_quality_score =
592            results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64;
593        let consensus_stats = self.calculate_consensus_statistics(results, &consensus_labels)?;
594        let diversity_metrics = self.calculate_diversity_metrics(results)?;
595        let stability_score = self.calculate_consensus_stability_score(&consensus_stats);
596
597        Ok(EnsembleResult {
598            consensus_labels,
599            individual_results: results.to_vec(),
600            consensus_stats,
601            diversity_metrics,
602            ensemble_quality: avg_quality_score,
603            stability_score,
604        })
605    }
606
607    /// Calculate diversity score between clusterers
608    fn calculate_diversity_score(&self, results: &[ClusteringResult]) -> f64 {
609        if results.len() < 2 {
610            return 0.0;
611        }
612
613        let mut total_diversity = 0.0;
614        let mut count = 0;
615
616        for i in 0..results.len() {
617            for j in (i + 1)..results.len() {
618                // Calculate pairwise diversity using adjusted rand index
619                if let Ok(ari) =
620                    adjusted_rand_index::<f64>(results[i].labels.view(), results[j].labels.view())
621                {
622                    total_diversity += 1.0 - ari; // Higher diversity means lower agreement
623                    count += 1;
624                }
625            }
626        }
627
628        if count > 0 {
629            total_diversity / count as f64
630        } else {
631            0.0
632        }
633    }
634
635    /// Calculate agreement ratio between clusterers
636    fn calculate_agreement_ratio(&self, results: &[ClusteringResult]) -> f64 {
637        if results.len() < 2 {
638            return 1.0;
639        }
640
641        let n_samples = results[0].labels.len();
642        let mut total_agreements = 0;
643        let mut total_pairs = 0;
644
645        for i in 0..results.len() {
646            for j in (i + 1)..results.len() {
647                for sample_idx in 0..n_samples {
648                    if results[i].labels[sample_idx] == results[j].labels[sample_idx] {
649                        total_agreements += 1;
650                    }
651                    total_pairs += 1;
652                }
653            }
654        }
655
656        if total_pairs > 0 {
657            total_agreements as f64 / total_pairs as f64
658        } else {
659            0.0
660        }
661    }
662
663    /// Calculate confidence scores for consensus
664    fn calculate_confidence_scores(
665        &self,
666        vote_matrix: &HashMap<usize, HashMap<i32, usize>>,
667        n_samples: usize,
668    ) -> Vec<f64> {
669        let mut confidence_scores = vec![0.0; n_samples];
670
671        for sample_idx in 0..n_samples {
672            if let Some(votes) = vote_matrix.get(&sample_idx) {
673                let total_votes: usize = votes.values().sum();
674                let max_votes = votes.values().max().copied().unwrap_or(0);
675
676                if total_votes > 0 {
677                    confidence_scores[sample_idx] = max_votes as f64 / total_votes as f64;
678                }
679            }
680        }
681
682        confidence_scores
683    }
684
685    /// Calculate weighted confidence scores for consensus
686    fn calculate_weighted_confidence_scores(
687        &self,
688        vote_matrix: &HashMap<usize, HashMap<i32, f64>>,
689        n_samples: usize,
690    ) -> Vec<f64> {
691        let mut confidence_scores = vec![0.0; n_samples];
692
693        for sample_idx in 0..n_samples {
694            if let Some(votes) = vote_matrix.get(&sample_idx) {
695                let total_votes: f64 = votes.values().sum();
696                let max_votes = votes.values().fold(0.0, |acc, &x| acc.max(x));
697
698                if total_votes > 0.0 {
699                    confidence_scores[sample_idx] = max_votes / total_votes;
700                }
701            }
702        }
703
704        confidence_scores
705    }
706
707    /// Calculate cluster diversity metrics
708    fn calculate_cluster_diversity(&self, results: &[ClusteringResult]) -> f64 {
709        let cluster_counts: Vec<usize> = results.iter().map(|r| r.n_clusters).collect();
710
711        if cluster_counts.is_empty() {
712            return 0.0;
713        }
714
715        let mean_clusters =
716            cluster_counts.iter().sum::<usize>() as f64 / cluster_counts.len() as f64;
717        let variance = cluster_counts
718            .iter()
719            .map(|&x| (x as f64 - mean_clusters).powi(2))
720            .sum::<f64>()
721            / cluster_counts.len() as f64;
722
723        variance.sqrt() / mean_clusters // Coefficient of variation
724    }
725
726    /// Calculate algorithm diversity
727    fn calculate_algorithm_diversity(&self, results: &[ClusteringResult]) -> f64 {
728        let unique_algorithms: HashSet<String> =
729            results.iter().map(|r| r.algorithm.clone()).collect();
730
731        unique_algorithms.len() as f64 / results.len() as f64
732    }
733
734    /// Count unique clusters in consensus labels
735    fn count_unique_clusters(&self, labels: &Array1<i32>) -> usize {
736        let mut unique_labels = HashSet::new();
737        for &label in labels {
738            unique_labels.insert(label);
739        }
740        unique_labels.len()
741    }
742
743    /// Select algorithm and parameters based on diversity strategy
744    fn select_algorithm_and_parameters(
745        &self,
746        estimator_index: usize,
747        rng: &mut scirs2_core::random::rngs::StdRng,
748    ) -> Result<(ClusteringAlgorithm, HashMap<String, String>)> {
749        match &self.config.diversity_strategy {
750            Some(DiversityStrategy::AlgorithmDiversity { algorithms }) => {
751                let algorithm = algorithms[estimator_index % algorithms.len()].clone();
752                let parameters = self.generate_random_parameters(&algorithm, rng)?;
753                Ok((algorithm, parameters))
754            }
755            Some(DiversityStrategy::ParameterDiversity {
756                algorithm,
757                parameter_ranges,
758            }) => {
759                let parameters = self.sample_parameter_ranges(parameter_ranges, rng)?;
760                Ok((algorithm.clone(), parameters))
761            }
762            _ => {
763                // Default to K-means with random k
764                let k = rng.gen_range(2..=10);
765                let algorithm = ClusteringAlgorithm::KMeans { k_range: (k, k) };
766                let mut parameters = HashMap::new();
767                parameters.insert("k".to_string(), k.to_string());
768                Ok((algorithm, parameters))
769            }
770        }
771    }
772
773    /// Generate random parameters for an algorithm
774    fn generate_random_parameters(
775        &self,
776        algorithm: &ClusteringAlgorithm,
777        rng: &mut scirs2_core::random::rngs::StdRng,
778    ) -> Result<HashMap<String, String>> {
779        let mut parameters = HashMap::new();
780
781        match algorithm {
782            ClusteringAlgorithm::KMeans { k_range } => {
783                let k = rng.gen_range(k_range.0..=k_range.1);
784                parameters.insert("k".to_string(), k.to_string());
785            }
786            ClusteringAlgorithm::DBSCAN {
787                eps_range,
788                min_samples_range,
789            } => {
790                let eps = rng.gen_range(eps_range.0..=eps_range.1);
791                let min_samples = rng.gen_range(min_samples_range.0..=min_samples_range.1);
792                parameters.insert("eps".to_string(), eps.to_string());
793                parameters.insert("min_samples".to_string(), min_samples.to_string());
794            }
795            ClusteringAlgorithm::MeanShift { bandwidth_range } => {
796                let bandwidth = rng.gen_range(bandwidth_range.0..=bandwidth_range.1);
797                parameters.insert("bandwidth".to_string(), bandwidth.to_string());
798            }
799            ClusteringAlgorithm::Hierarchical { methods } => {
800                let method = &methods[rng.gen_range(0..methods.len())];
801                parameters.insert("method".to_string(), method.clone());
802            }
803            ClusteringAlgorithm::Spectral { k_range } => {
804                let k = rng.gen_range(k_range.0..=k_range.1);
805                parameters.insert("k".to_string(), k.to_string());
806            }
807            ClusteringAlgorithm::AffinityPropagation { damping_range } => {
808                let damping = rng.gen_range(damping_range.0..=damping_range.1);
809                parameters.insert("damping".to_string(), damping.to_string());
810            }
811        }
812
813        Ok(parameters)
814    }
815
816    /// Sample parameters from ranges
817    fn sample_parameter_ranges(
818        &self,
819        parameter_ranges: &HashMap<String, ParameterRange>,
820        rng: &mut scirs2_core::random::rngs::StdRng,
821    ) -> Result<HashMap<String, String>> {
822        let mut parameters = HashMap::new();
823
824        for (param_name, range) in parameter_ranges {
825            let value = match range {
826                ParameterRange::Integer(min, max) => rng.gen_range(*min..=*max).to_string(),
827                ParameterRange::Float(min, max) => rng.gen_range(*min..=*max).to_string(),
828                ParameterRange::Categorical(choices) => {
829                    choices[rng.gen_range(0..choices.len())].clone()
830                }
831                ParameterRange::Boolean => rng.gen_bool(0.5).to_string(),
832            };
833            parameters.insert(param_name.clone(), value);
834        }
835
836        Ok(parameters)
837    }
838
839    /// Run clustering with specified algorithm and parameters
840    fn run_clustering(
841        &self,
842        data: &Array2<F>,
843        algorithm: &ClusteringAlgorithm,
844        parameters: &HashMap<String, String>,
845    ) -> Result<Array1<i32>> {
846        let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
847
848        match algorithm {
849            ClusteringAlgorithm::KMeans { .. } => {
850                let k = parameters
851                    .get("k")
852                    .and_then(|s| s.parse().ok())
853                    .unwrap_or(3);
854
855                // Use kmeans from crate
856                use crate::vq::kmeans2;
857                match kmeans2(
858                    data.view(),
859                    k,
860                    Some(100),   // max_iter
861                    None,        // threshold
862                    None,        // init method
863                    None,        // missing method
864                    Some(false), // check_finite
865                    None,        // seed
866                ) {
867                    Ok((_, labels)) => Ok(labels.mapv(|x| x as i32)),
868                    Err(_) => {
869                        // Fallback: create dummy labels
870                        let n_samples = data.nrows();
871                        let labels = Array1::from_shape_fn(n_samples, |i| (i % k) as i32);
872                        Ok(labels)
873                    }
874                }
875            }
876            ClusteringAlgorithm::AffinityPropagation { .. } => {
877                let damping = parameters
878                    .get("damping")
879                    .and_then(|s| s.parse().ok())
880                    .unwrap_or(0.5);
881                let max_iter = parameters
882                    .get("max_iter")
883                    .and_then(|s| s.parse().ok())
884                    .unwrap_or(200);
885                let convergence_iter = parameters
886                    .get("convergence_iter")
887                    .and_then(|s| s.parse().ok())
888                    .unwrap_or(15);
889
890                // Create affinity propagation options
891                use crate::affinity::{affinity_propagation, AffinityPropagationOptions};
892                let options = AffinityPropagationOptions {
893                    damping: F::from(damping).unwrap(),
894                    max_iter,
895                    convergence_iter,
896                    preference: None, // Use default (median of similarities)
897                    affinity: "euclidean".to_string(),
898                    max_affinity_iterations: max_iter, // Use same as max_iter
899                };
900
901                match affinity_propagation(data.view(), false, Some(options)) {
902                    Ok((_, labels)) => Ok(labels),
903                    Err(_) => {
904                        // Fallback: create dummy labels
905                        Ok(Array1::zeros(data.nrows()).mapv(|_: f64| 0i32))
906                    }
907                }
908            }
909            _ => {
910                // For any other algorithms, fallback to k-means
911                let k = parameters
912                    .get("k")
913                    .and_then(|s| s.parse().ok())
914                    .unwrap_or(3);
915
916                use crate::vq::kmeans2;
917                match kmeans2(
918                    data.view(),
919                    k,
920                    Some(100),
921                    None,
922                    None,
923                    None,
924                    Some(false),
925                    None,
926                ) {
927                    Ok((_, labels)) => Ok(labels.mapv(|x| x as i32)),
928                    Err(_) => Ok(Array1::zeros(data.nrows()).mapv(|_: f64| 0i32)),
929                }
930            }
931        }
932    }
933
934    /// Count clusters in results
935    fn count_clusters(&self, labels: &Array1<i32>) -> usize {
936        let mut unique_labels = std::collections::HashSet::new();
937        for &label in labels {
938            unique_labels.insert(label);
939        }
940        unique_labels.len()
941    }
942
943    /// Filter results by quality
944    fn filter_by_quality(&self, results: &[ClusteringResult]) -> Vec<ClusteringResult> {
945        if let Some(threshold) = self.config.quality_threshold {
946            results
947                .iter()
948                .filter(|r| r.quality_score >= threshold)
949                .cloned()
950                .collect()
951        } else {
952            results.to_vec()
953        }
954    }
955
956    /// Map labels back to full dataset size
957    fn map_labels_to_full_data(
958        &self,
959        labels: &Array1<i32>,
960        sample_indices: &[usize],
961        full_size: usize,
962    ) -> Result<Array1<i32>> {
963        let mut full_labels = Array1::from_elem(full_size, -1); // Use -1 for unassigned
964
965        for (sample_idx, &label) in sample_indices.iter().zip(labels.iter()) {
966            if *sample_idx < full_size {
967                full_labels[*sample_idx] = label;
968            }
969        }
970
971        // Assign unassigned points to nearest cluster (simplified)
972        for i in 0..full_size {
973            if full_labels[i] == -1 {
974                full_labels[i] = 0; // Assign to cluster 0 as fallback
975            }
976        }
977
978        Ok(full_labels)
979    }
980
981    /// Build consensus from multiple clustering results
982    fn build_consensus(
983        &self,
984        results: &[ClusteringResult],
985        data: ArrayView2<F>,
986    ) -> Result<Array1<i32>> {
987        if results.is_empty() {
988            return Err(ClusteringError::InvalidInput(
989                "No clustering results available for consensus".to_string(),
990            ));
991        }
992
993        let n_samples = data.nrows();
994
995        match &self.config.consensus_method {
996            ConsensusMethod::MajorityVoting => {
997                let result = self.majority_voting_consensus(results, data)?;
998                Ok(result.consensus_labels)
999            }
1000            ConsensusMethod::WeightedConsensus => {
1001                let result = self.weighted_consensus(results, data)?;
1002                Ok(result.consensus_labels)
1003            }
1004            ConsensusMethod::CoAssociation { threshold } => {
1005                let result = self.co_association_consensus(results, data, *threshold)?;
1006                Ok(result.consensus_labels)
1007            }
1008            ConsensusMethod::EvidenceAccumulation => {
1009                let result = self.evidence_accumulation_consensus(results, data)?;
1010                Ok(result.consensus_labels)
1011            }
1012            ConsensusMethod::GraphBased {
1013                similarity_threshold,
1014            } => {
1015                let result = self.graph_based_consensus(results, data, *similarity_threshold)?;
1016                Ok(result.consensus_labels)
1017            }
1018            ConsensusMethod::Hierarchical { linkage_method } => {
1019                let result = self.hierarchical_consensus(results, data, linkage_method)?;
1020                Ok(result.consensus_labels)
1021            }
1022        }
1023    }
1024
1025    /// Estimate optimal number of clusters from linkage matrix
1026    fn estimate_optimal_clusters(&self, linkagematrix: &Array2<f64>) -> usize {
1027        // Simple heuristic: find the largest gap in the linkage heights
1028        let mut max_gap = 0.0;
1029        let mut optimal_clusters = 2;
1030
1031        for i in 1..linkagematrix.nrows() {
1032            let gap = linkagematrix[[i, 2]] - linkagematrix[[i - 1, 2]];
1033            if gap > max_gap {
1034                max_gap = gap;
1035                optimal_clusters = linkagematrix.nrows() - i + 1;
1036            }
1037        }
1038
1039        optimal_clusters.min(self.config.max_clusters.unwrap_or(10))
1040    }
1041
1042    /// Calculate diversity metrics for the ensemble
1043    fn calculate_diversity_metrics(
1044        &self,
1045        results: &[ClusteringResult],
1046    ) -> Result<DiversityMetrics> {
1047        Ok(DiversityMetrics {
1048            average_diversity: 0.5,                       // Stub implementation
1049            diversity_matrix: Array2::eye(results.len()), // Stub implementation
1050            algorithm_distribution: HashMap::new(),       // Stub implementation
1051            parameter_diversity: HashMap::new(),          // Stub implementation
1052        })
1053    }
1054
1055    /// Calculate consensus statistics for the ensemble
1056    fn calculate_consensus_statistics(
1057        &self,
1058        _results: &[ClusteringResult],
1059        _consensus_labels: &Array1<i32>,
1060    ) -> Result<ConsensusStatistics> {
1061        let n_samples = _consensus_labels.len();
1062
1063        // Stub implementation - in production this would analyze agreement between clusterers
1064        Ok(ConsensusStatistics {
1065            agreement_matrix: Array2::zeros((n_samples, n_samples)),
1066            consensus_strength: Array1::ones(n_samples),
1067            cluster_stability: vec![0.5; 10], // Placeholder
1068            agreement_counts: Array1::ones(n_samples),
1069        })
1070    }
1071
1072    /// Calculate consensus stability score for the ensemble
1073    fn calculate_consensus_stability_score(&self, _consensusstats: &ConsensusStatistics) -> f64 {
1074        0.5 // Stub implementation
1075    }
1076}
1077
1078/// Extract samples based on indices
1079fn extract_samples<F: Float>(data: ArrayView2<F>, indices: &[usize]) -> Result<Array2<F>> {
1080    let n_features = data.ncols();
1081    let mut sampled_data = Array2::zeros((indices.len(), n_features));
1082
1083    for (new_idx, &old_idx) in indices.iter().enumerate() {
1084        if old_idx < data.nrows() {
1085            sampled_data.row_mut(new_idx).assign(&data.row(old_idx));
1086        }
1087    }
1088
1089    Ok(sampled_data)
1090}
1091
1092/// Extract features based on indices
1093fn extract_features<F: Float>(data: ArrayView2<F>, featureindices: &[usize]) -> Result<Array2<F>> {
1094    let n_samples = data.nrows();
1095    let mut sampled_data = Array2::zeros((n_samples, featureindices.len()));
1096
1097    for (new_feat_idx, &old_feat_idx) in featureindices.iter().enumerate() {
1098        if old_feat_idx < data.ncols() {
1099            sampled_data
1100                .column_mut(new_feat_idx)
1101                .assign(&data.column(old_feat_idx));
1102        }
1103    }
1104
1105    Ok(sampled_data)
1106}
1107
1108/// Default configuration for ensemble clustering
1109impl Default for EnsembleConfig {
1110    fn default() -> Self {
1111        Self {
1112            n_estimators: 10,
1113            sampling_strategy: SamplingStrategy::Bootstrap { sample_ratio: 0.8 },
1114            consensus_method: ConsensusMethod::MajorityVoting,
1115            random_seed: None,
1116            diversity_strategy: Some(DiversityStrategy::AlgorithmDiversity {
1117                algorithms: vec![
1118                    ClusteringAlgorithm::KMeans { k_range: (2, 10) },
1119                    ClusteringAlgorithm::DBSCAN {
1120                        eps_range: (0.1, 1.0),
1121                        min_samples_range: (3, 10),
1122                    },
1123                ],
1124            }),
1125            quality_threshold: None,
1126            max_clusters: Some(20),
1127        }
1128    }
1129}