Skip to main content

scirs2_cluster/ensemble/
weighted.rs

1//! Weighted ensemble clustering.
2//!
3//! Provides:
4//!
5//! - [`WeightedVoting`] – evidence-accumulation ensemble with per-member
6//!   quality weights.
7//! - [`SelectiveEnsemble`] – selects a diverse subset of base clusterings
8//!   to form the final ensemble.
9//! - [`ClusterSimilarity`] – NMI, ARI, and Fowlkes-Mallows similarity
10//!   measures used to score ensemble diversity.
11//! - [`BootstrapEnsemble`] – bootstrap-based cluster ensemble.
//! - [`StackedClustering`] – stacked generalization for clustering (the
//!   labels produced by the individual base clusterings are used as
//!   meta-features for a second-level clustering).
15
16use scirs2_core::ndarray::{Array2, ArrayView2, Axis};
17use scirs2_core::numeric::{Float, FromPrimitive};
18use std::fmt::Debug;
19
20use crate::error::{ClusteringError, Result};
21
22// ---------------------------------------------------------------------------
23// ClusterSimilarity – pairwise similarity metrics
24// ---------------------------------------------------------------------------
25
/// Pairwise similarity metrics for comparing two cluster label vectors.
///
/// Labels are treated as nominal identifiers: only the partition they induce
/// matters, so label values need not be contiguous or start at zero (any
/// `usize`, including `usize::MAX`, is accepted).
pub struct ClusterSimilarity;

impl ClusterSimilarity {
    /// Adjusted Rand Index between two label vectors.
    ///
    /// Returns a value in [-0.5, 1.0] where 1.0 indicates perfect agreement
    /// and values near 0.0 indicate chance-level agreement.  Returns 0.0 when
    /// the vectors differ in length or are empty.
    pub fn adjusted_rand_index(labels_a: &[usize], labels_b: &[usize]) -> f64 {
        let (cells, row_sums, col_sums) = match Self::contingency_table(labels_a, labels_b) {
            Some(table) => table,
            None => return 0.0,
        };

        let sum_comb_c: f64 = cells.iter().flatten().map(|&v| comb2(v)).sum();
        let sum_comb_a: f64 = row_sums.iter().map(|&v| comb2(v)).sum();
        let sum_comb_b: f64 = col_sums.iter().map(|&v| comb2(v)).sum();
        let comb_n = comb2(labels_a.len());

        let expected = sum_comb_a * sum_comb_b / comb_n.max(1.0);
        let max_val = (sum_comb_a + sum_comb_b) / 2.0;
        let denom = max_val - expected;
        if denom.abs() < 1e-15 {
            // Degenerate case (e.g. both labelings trivial): report perfect
            // agreement exactly when the observed index equals its expectation.
            if (sum_comb_c - expected).abs() < 1e-15 {
                1.0
            } else {
                0.0
            }
        } else {
            (sum_comb_c - expected) / denom
        }
    }

    /// Normalized Mutual Information with arithmetic-mean normalisation,
    /// `MI / ((H(a) + H(b)) / 2)`, using natural logarithms.
    ///
    /// Returns 1.0 when both labelings have zero entropy (one cluster each),
    /// and 0.0 when the vectors differ in length or are empty.
    pub fn normalized_mutual_info(labels_a: &[usize], labels_b: &[usize]) -> f64 {
        let (cells, row_sums, col_sums) = match Self::contingency_table(labels_a, labels_b) {
            Some(table) => table,
            None => return 0.0,
        };
        let n = labels_a.len() as f64;

        // Mutual information over the non-empty contingency cells.
        let mut mi = 0.0_f64;
        for (row, &ri) in cells.iter().zip(&row_sums) {
            for (&count, &cj) in row.iter().zip(&col_sums) {
                if count > 0 {
                    let nij = count as f64;
                    mi += nij / n * (nij * n / (ri as f64 * cj as f64)).ln();
                }
            }
        }

        let h_a = Self::entropy(&row_sums, n);
        let h_b = Self::entropy(&col_sums, n);

        let denom = (h_a + h_b) / 2.0;
        if denom < 1e-15 {
            1.0
        } else {
            mi / denom
        }
    }

    /// Fowlkes-Mallows index between two label vectors,
    /// `TP / sqrt((TP + FP) * (TP + FN))` counted over pairs of points.
    ///
    /// Computed from the contingency table (`O(n log n + ka * kb)`) rather than
    /// by enumerating all `n * (n - 1) / 2` pairs.  Returns 0.0 when either
    /// labeling places every point in its own cluster, or when the vectors
    /// differ in length or are empty.
    pub fn fowlkes_mallows(labels_a: &[usize], labels_b: &[usize]) -> f64 {
        let (cells, row_sums, col_sums) = match Self::contingency_table(labels_a, labels_b) {
            Some(table) => table,
            None => return 0.0,
        };

        // Pairs together in both labelings.
        let tp: f64 = cells.iter().flatten().map(|&v| comb2(v)).sum();
        // Pairs together in `labels_a` (TP + FP) and in `labels_b` (TP + FN).
        let pairs_a: f64 = row_sums.iter().map(|&v| comb2(v)).sum();
        let pairs_b: f64 = col_sums.iter().map(|&v| comb2(v)).sum();

        let denom = (pairs_a * pairs_b).sqrt();
        if denom < 1e-15 {
            0.0
        } else {
            tp / denom
        }
    }

    /// Builds the contingency table of two labelings after remapping each to
    /// dense indices, returning `(cells, row_sums, col_sums)`.
    ///
    /// Returns `None` when the lengths differ or the input is empty.
    fn contingency_table(
        labels_a: &[usize],
        labels_b: &[usize],
    ) -> Option<(Vec<Vec<usize>>, Vec<usize>, Vec<usize>)> {
        if labels_a.len() != labels_b.len() || labels_a.is_empty() {
            return None;
        }
        let (dense_a, ka) = Self::dense_labels(labels_a);
        let (dense_b, kb) = Self::dense_labels(labels_b);

        let mut cells = vec![vec![0usize; kb]; ka];
        let mut row_sums = vec![0usize; ka];
        let mut col_sums = vec![0usize; kb];
        for (&a, &b) in dense_a.iter().zip(&dense_b) {
            cells[a][b] += 1;
            row_sums[a] += 1;
            col_sums[b] += 1;
        }
        Some((cells, row_sums, col_sums))
    }

    /// Remaps arbitrary label values to `0..k`, preserving their sorted order,
    /// and returns the remapped labels together with `k`.
    fn dense_labels(labels: &[usize]) -> (Vec<usize>, usize) {
        let mut distinct = labels.to_vec();
        distinct.sort_unstable();
        distinct.dedup();
        // Every label is present in `distinct`, so the search always succeeds.
        let dense: Vec<usize> = labels
            .iter()
            .map(|label| distinct.binary_search(label).unwrap_or_else(|pos| pos))
            .collect();
        (dense, distinct.len())
    }

    /// Shannon entropy (natural log) of a cluster-size histogram over `n` points.
    fn entropy(counts: &[usize], n: f64) -> f64 {
        counts
            .iter()
            .filter(|&&c| c > 0)
            .map(|&c| {
                let p = c as f64 / n;
                -p * p.ln()
            })
            .sum()
    }
}

/// Number of unordered pairs that can be drawn from `n` items (`n` choose 2).
fn comb2(n: usize) -> f64 {
    if n < 2 {
        0.0
    } else {
        // Multiply in f64 so the product cannot overflow `usize`.
        n as f64 * (n - 1) as f64 / 2.0
    }
}
180
181// ---------------------------------------------------------------------------
182// WeightedVoting
183// ---------------------------------------------------------------------------
184
/// Settings that control [`WeightedVoting`] ensemble clustering.
#[derive(Debug, Clone)]
pub struct WeightedVotingConfig {
    /// How many base clusterings are expected to be combined.
    pub n_base: usize,
    /// Metric used to derive a quality weight for every base clustering.
    pub quality_metric: EnsembleQualityMetric,
    /// Lowest quality a base clustering may have and still be included.
    pub min_quality: f64,
    /// Target number of clusters in the consensus partition.
    pub n_clusters: usize,
    /// Upper bound on iterations of the consensus step.
    pub max_iter: usize,
}

impl Default for WeightedVotingConfig {
    /// Ten base clusterings, NMI weighting, no quality cut-off, three
    /// consensus clusters and at most 100 consensus iterations.
    fn default() -> Self {
        let quality_metric = EnsembleQualityMetric::NMI;
        WeightedVotingConfig {
            max_iter: 100,
            n_clusters: 3,
            min_quality: 0.0,
            quality_metric,
            n_base: 10,
        }
    }
}

/// Scoring rule used to weight each base clustering.
///
/// The similarity-based variants score a base clustering by its average
/// similarity to the other base clusterings in the ensemble.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EnsembleQualityMetric {
    /// Normalized Mutual Information.
    NMI,
    /// Adjusted Rand Index.
    ARI,
    /// Fowlkes-Mallows index.
    FowlkesMallows,
    /// Every base clustering scores the same (no quality weighting).
    Uniform,
}
224
225/// Weighted voting ensemble with evidence accumulation.
226///
227/// Combines multiple base clusterings by building a weighted co-association
228/// matrix: `S[i,j] += w_k * (labels_k[i] == labels_k[j])`.  The final
229/// clustering is obtained by running k-means on the dissimilarity
230/// `D = 1 - S`.
231pub struct WeightedVoting {
232    config: WeightedVotingConfig,
233}
234
235impl WeightedVoting {
236    /// Create a new WeightedVoting instance.
237    pub fn new(config: WeightedVotingConfig) -> Self {
238        Self { config }
239    }
240
241    /// Combine base clusterings using weighted evidence accumulation.
242    ///
243    /// `base_labels`: each row is a labelling from one base clusterer (shape:
244    ///   `[n_base, n_samples]`).
245    /// `weights`: per-base-clusterer quality weight (length `n_base`).  If
246    ///   `None`, uniform weights are used.
247    pub fn combine(
248        &self,
249        base_labels: &[Vec<usize>],
250        weights: Option<&[f64]>,
251    ) -> Result<WeightedVotingResult> {
252        if base_labels.is_empty() {
253            return Err(ClusteringError::InvalidInput(
254                "No base clusterings provided".into(),
255            ));
256        }
257        let n = base_labels[0].len();
258        for bl in base_labels.iter() {
259            if bl.len() != n {
260                return Err(ClusteringError::InvalidInput(
261                    "All base clusterings must have the same length".into(),
262                ));
263            }
264        }
265
266        let m = base_labels.len();
267        let default_w = vec![1.0 / m as f64; m];
268        let w: &[f64] = weights.unwrap_or(&default_w);
269
270        // Filter by minimum quality (quality vs. the average / a reference)
271        // Here we use ARI between each pair and assign weight proportional
272        // to average similarity with other clusterings.
273        let mut effective_weights: Vec<f64> = match self.config.quality_metric {
274            EnsembleQualityMetric::Uniform => w.to_vec(),
275            EnsembleQualityMetric::NMI => self.compute_diversity_weights(base_labels, |a, b| {
276                ClusterSimilarity::normalized_mutual_info(a, b)
277            }),
278            EnsembleQualityMetric::ARI => self.compute_diversity_weights(base_labels, |a, b| {
279                ClusterSimilarity::adjusted_rand_index(a, b)
280            }),
281            EnsembleQualityMetric::FowlkesMallows => self
282                .compute_diversity_weights(base_labels, |a, b| {
283                    ClusterSimilarity::fowlkes_mallows(a, b)
284                }),
285        };
286
287        // Apply supplied weights multiplicatively
288        for (i, ew) in effective_weights.iter_mut().enumerate() {
289            *ew *= w[i];
290        }
291
292        // Normalise
293        let w_sum: f64 = effective_weights.iter().sum();
294        if w_sum < 1e-15 {
295            for ew in effective_weights.iter_mut() {
296                *ew = 1.0 / m as f64;
297            }
298        } else {
299            for ew in effective_weights.iter_mut() {
300                *ew /= w_sum;
301            }
302        }
303
304        // Build weighted co-association matrix
305        let mut co_assoc = vec![vec![0.0f64; n]; n];
306        for (k, bl) in base_labels.iter().enumerate() {
307            let wk = effective_weights[k];
308            if wk < 1e-15 {
309                continue;
310            }
311            for i in 0..n {
312                for j in (i + 1)..n {
313                    if bl[i] == bl[j] {
314                        co_assoc[i][j] += wk;
315                        co_assoc[j][i] += wk;
316                    }
317                }
318            }
319        }
320        // Self-similarity = 1
321        for i in 0..n {
322            co_assoc[i][i] = 1.0;
323        }
324
325        // Final consensus clustering: k-means on dissimilarity embedding rows
326        let labels = self.consensus_from_coassoc(&co_assoc, n)?;
327        let used_bases = base_labels.len();
328
329        Ok(WeightedVotingResult {
330            labels,
331            weights: effective_weights,
332            co_association: co_assoc,
333            n_clusters: self.config.n_clusters,
334            n_base_clusterings: used_bases,
335        })
336    }
337
338    /// Compute per-clustering weights as their average similarity to others.
339    fn compute_diversity_weights(
340        &self,
341        base_labels: &[Vec<usize>],
342        sim_fn: impl Fn(&[usize], &[usize]) -> f64,
343    ) -> Vec<f64> {
344        let m = base_labels.len();
345        let mut weights = vec![0.0f64; m];
346        if m == 1 {
347            weights[0] = 1.0;
348            return weights;
349        }
350        for i in 0..m {
351            let sum: f64 = (0..m)
352                .filter(|&j| j != i)
353                .map(|j| sim_fn(&base_labels[i], &base_labels[j]))
354                .sum();
355            weights[i] = sum / (m - 1) as f64;
356        }
357        weights
358    }
359
360    /// Simple k-means on the dissimilarity rows of the co-association matrix.
361    fn consensus_from_coassoc(&self, co_assoc: &[Vec<f64>], n: usize) -> Result<Vec<usize>> {
362        let k = self.config.n_clusters.min(n);
363        if k == 0 || n == 0 {
364            return Ok(vec![0; n]);
365        }
366
367        // Use co-assoc rows as feature vectors
368        let mut centroids: Vec<Vec<f64>> = (0..k).map(|i| co_assoc[i].clone()).collect();
369        let mut labels = vec![0usize; n];
370
371        for _ in 0..self.config.max_iter {
372            // Assign
373            for i in 0..n {
374                let mut best = 0;
375                let mut best_d = f64::MAX;
376                for (j, c) in centroids.iter().enumerate() {
377                    let d: f64 = co_assoc[i]
378                        .iter()
379                        .zip(c.iter())
380                        .map(|(&a, &b)| (a - b) * (a - b))
381                        .sum();
382                    if d < best_d {
383                        best_d = d;
384                        best = j;
385                    }
386                }
387                labels[i] = best;
388            }
389
390            // Update
391            let mut new_cents = vec![vec![0.0f64; n]; k];
392            let mut counts = vec![0usize; k];
393            for i in 0..n {
394                let j = labels[i];
395                counts[j] += 1;
396                for dim in 0..n {
397                    new_cents[j][dim] += co_assoc[i][dim];
398                }
399            }
400            for j in 0..k {
401                if counts[j] > 0 {
402                    let nf = counts[j] as f64;
403                    for dim in 0..n {
404                        new_cents[j][dim] /= nf;
405                    }
406                }
407            }
408            centroids = new_cents;
409        }
410        Ok(labels)
411    }
412}
413
/// Output of [`WeightedVoting::combine`].
#[derive(Debug, Clone)]
pub struct WeightedVotingResult {
    /// Consensus label assigned to every data point.
    pub labels: Vec<usize>,
    /// Normalised weight given to each base clustering.
    pub weights: Vec<f64>,
    /// Weighted co-association matrix, one row and column per data point.
    pub co_association: Vec<Vec<f64>>,
    /// Number of consensus clusters.
    pub n_clusters: usize,
    /// How many base clusterings were combined.
    pub n_base_clusterings: usize,
}

impl WeightedVotingResult {
    /// Arithmetic mean of the base-clustering weights (0.0 when there are none).
    pub fn mean_weight(&self) -> f64 {
        match self.weights.len() {
            0 => 0.0,
            len => self.weights.iter().sum::<f64>() / len as f64,
        }
    }

    /// Sample variance (divisor `len - 1`) of the weights; 0.0 when fewer
    /// than two weights are present.
    pub fn weight_variance(&self) -> f64 {
        let len = self.weights.len();
        if len < 2 {
            return 0.0;
        }
        let mean = self.mean_weight();
        let squared_deviations: f64 = self
            .weights
            .iter()
            .map(|&w| {
                let deviation = w - mean;
                deviation.powi(2)
            })
            .sum();
        squared_deviations / (len - 1) as f64
    }
}
453
454// ---------------------------------------------------------------------------
455// SelectiveEnsemble
456// ---------------------------------------------------------------------------
457
/// Settings for [`SelectiveEnsemble`].
#[derive(Debug, Clone)]
pub struct SelectiveEnsembleConfig {
    /// How many base clusterings to pick (capped at the number available).
    pub target_size: usize,
    /// Similarity measure whose complement defines diversity.
    pub diversity_metric: DiversityMeasure,
    /// Pairwise diversity below which two clusterings count as redundant.
    ///
    /// NOTE(review): `SelectiveEnsemble::select` as written does not read this
    /// field — confirm the intended use.
    pub diversity_threshold: f64,
}

impl Default for SelectiveEnsembleConfig {
    /// Five clusterings, NMI-based diversity, redundancy threshold 0.3.
    fn default() -> Self {
        let diversity_metric = DiversityMeasure::NMI;
        SelectiveEnsembleConfig {
            diversity_threshold: 0.3,
            diversity_metric,
            target_size: 5,
        }
    }
}

/// How diversity between two clusterings is measured (`1 - similarity`).
#[derive(Debug, Clone, Copy)]
pub enum DiversityMeasure {
    /// Diversity = 1 - Normalized Mutual Information.
    NMI,
    /// Diversity = 1 - Adjusted Rand Index.
    ARI,
    /// Diversity = 1 - Fowlkes-Mallows index.
    FowlkesMallows,
}
490
491/// SelectiveEnsemble: greedily selects a diverse subset of base clusterings.
492///
493/// Starting from the clustering with the highest average diversity, it
494/// iteratively adds the clustering that maximises the minimum pairwise
495/// diversity with those already selected.
496pub struct SelectiveEnsemble {
497    config: SelectiveEnsembleConfig,
498}
499
500impl SelectiveEnsemble {
501    /// Create a new SelectiveEnsemble.
502    pub fn new(config: SelectiveEnsembleConfig) -> Self {
503        Self { config }
504    }
505
506    /// Select a diverse subset of base clusterings.
507    ///
508    /// Returns the indices of the selected clusterings.
509    pub fn select(&self, base_labels: &[Vec<usize>]) -> Result<SelectiveEnsembleResult> {
510        let m = base_labels.len();
511        if m == 0 {
512            return Err(ClusteringError::InvalidInput(
513                "No base clusterings to select from".into(),
514            ));
515        }
516
517        let target = self.config.target_size.min(m);
518        let sim_fn: Box<dyn Fn(&[usize], &[usize]) -> f64> = match self.config.diversity_metric {
519            DiversityMeasure::NMI => {
520                Box::new(|a, b| ClusterSimilarity::normalized_mutual_info(a, b))
521            }
522            DiversityMeasure::ARI => Box::new(|a, b| ClusterSimilarity::adjusted_rand_index(a, b)),
523            DiversityMeasure::FowlkesMallows => {
524                Box::new(|a, b| ClusterSimilarity::fowlkes_mallows(a, b))
525            }
526        };
527
528        // Compute full diversity matrix (diversity = 1 - similarity)
529        let mut diversity = vec![vec![0.0f64; m]; m];
530        for i in 0..m {
531            for j in (i + 1)..m {
532                let d = 1.0 - sim_fn(&base_labels[i], &base_labels[j]).max(0.0);
533                diversity[i][j] = d;
534                diversity[j][i] = d;
535            }
536        }
537
538        // Greedy max-min selection
539        // Start with the clustering that has the highest average diversity
540        let avg_div: Vec<f64> = diversity
541            .iter()
542            .map(|row| row.iter().sum::<f64>() / (m - 1).max(1) as f64)
543            .collect();
544        let start = avg_div
545            .iter()
546            .enumerate()
547            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
548            .map(|(i, _)| i)
549            .unwrap_or(0);
550
551        let mut selected = vec![start];
552        let mut remaining: Vec<usize> = (0..m).filter(|&i| i != start).collect();
553
554        while selected.len() < target && !remaining.is_empty() {
555            // For each remaining, compute its minimum diversity with the selected set
556            let mut best_idx_in_remaining = 0;
557            let mut best_min_div = -1.0_f64;
558            for (ri, &cand) in remaining.iter().enumerate() {
559                let min_div = selected
560                    .iter()
561                    .map(|&s| diversity[cand][s])
562                    .fold(f64::MAX, f64::min);
563                if min_div > best_min_div {
564                    best_min_div = min_div;
565                    best_idx_in_remaining = ri;
566                }
567            }
568            let chosen = remaining.remove(best_idx_in_remaining);
569            selected.push(chosen);
570        }
571
572        // Compute average pairwise diversity of selected set
573        let avg_diversity = if selected.len() < 2 {
574            0.0
575        } else {
576            let pairs = selected.len() * (selected.len() - 1) / 2;
577            let sum: f64 = selected
578                .iter()
579                .enumerate()
580                .flat_map(|(i, &a)| {
581                    let div_ref = &diversity;
582                    selected[(i + 1)..]
583                        .iter()
584                        .map(move |&b| div_ref[a][b])
585                        .collect::<Vec<_>>()
586                })
587                .sum();
588            sum / pairs as f64
589        };
590
591        Ok(SelectiveEnsembleResult {
592            selected_indices: selected,
593            diversity_matrix: diversity,
594            average_diversity: avg_diversity,
595        })
596    }
597}
598
/// Output of [`SelectiveEnsemble::select`].
#[derive(Debug, Clone)]
pub struct SelectiveEnsembleResult {
    /// Indices (into the input) of the chosen base clusterings, in pick order.
    pub selected_indices: Vec<usize>,
    /// Pairwise diversity between every pair of input clusterings (m × m).
    pub diversity_matrix: Vec<Vec<f64>>,
    /// Mean pairwise diversity among the chosen clusterings.
    pub average_diversity: f64,
}
609
610// ---------------------------------------------------------------------------
611// BootstrapEnsemble
612// ---------------------------------------------------------------------------
613
/// Settings for [`BootstrapEnsemble`].
#[derive(Debug, Clone)]
pub struct BootstrapEnsembleConfig {
    /// How many sub-samples (base clusterings) to draw.
    pub n_bootstrap: usize,
    /// Share of the dataset placed in each sub-sample.
    pub sample_ratio: f64,
    /// Clusters fitted by every base k-means run.
    pub n_clusters: usize,
    /// Iteration budget of each base k-means run.
    pub max_iter: usize,
    /// Base value for the random seed.
    pub seed: u64,
}

impl Default for BootstrapEnsembleConfig {
    /// Ten sub-samples of 80% of the data, three clusters, 50 iterations,
    /// seed 42.
    fn default() -> Self {
        let sample_ratio = 0.8;
        BootstrapEnsembleConfig {
            seed: 42,
            max_iter: 50,
            n_clusters: 3,
            sample_ratio,
            n_bootstrap: 10,
        }
    }
}
640
641/// Bootstrap-based cluster ensemble.
642///
643/// Generates multiple bootstrap sub-samples of the data, clusters each
644/// with k-means, and then combines the base clusterings via a weighted
645/// co-association matrix.
646pub struct BootstrapEnsemble {
647    config: BootstrapEnsembleConfig,
648}
649
650impl BootstrapEnsemble {
651    /// Create a new BootstrapEnsemble.
652    pub fn new(config: BootstrapEnsembleConfig) -> Self {
653        Self { config }
654    }
655
656    /// Fit the bootstrap ensemble and return consensus labels.
657    pub fn fit<F>(&self, data: ArrayView2<F>) -> Result<BootstrapEnsembleResult>
658    where
659        F: Float + FromPrimitive + Debug + Clone,
660        f64: From<F>,
661    {
662        let (n, d) = (data.nrows(), data.ncols());
663        if n == 0 {
664            return Err(ClusteringError::InvalidInput("Empty dataset".into()));
665        }
666        let k = self.config.n_clusters.min(n);
667        let sample_n = ((n as f64 * self.config.sample_ratio) as usize).max(k);
668
669        let mut base_labels_all: Vec<Vec<usize>> = Vec::new();
670
671        for b in 0..self.config.n_bootstrap {
672            // Simple deterministic bootstrap: use modular stride
673            let stride = (b * 7 + 3) % n + 1;
674            let indices: Vec<usize> = (0..sample_n).map(|i| (i * stride) % n).collect();
675
676            // Extract sub-sample centroids via k-means
677            let sample_centroids = self.fit_kmeans_on_indices(data, &indices, k)?;
678
679            // Assign all points to nearest centroid
680            let labels: Vec<usize> = (0..n)
681                .map(|i| {
682                    let row: Vec<f64> = data.row(i).iter().map(|&v| f64::from(v)).collect();
683                    nearest_centroid_f64(&sample_centroids, &row)
684                })
685                .collect();
686            base_labels_all.push(labels);
687        }
688
689        // Combine via co-association
690        let voting = WeightedVoting::new(WeightedVotingConfig {
691            n_base: self.config.n_bootstrap,
692            quality_metric: EnsembleQualityMetric::NMI,
693            min_quality: 0.0,
694            n_clusters: k,
695            max_iter: self.config.max_iter,
696        });
697        let voting_result = voting.combine(&base_labels_all, None)?;
698
699        // Compute stability: average NMI across bootstrap pairs
700        let stability = compute_average_nmi(&base_labels_all);
701
702        Ok(BootstrapEnsembleResult {
703            labels: voting_result.labels,
704            base_labels: base_labels_all,
705            stability,
706            n_bootstrap: self.config.n_bootstrap,
707            n_clusters: k,
708        })
709    }
710
711    fn fit_kmeans_on_indices<F>(
712        &self,
713        data: ArrayView2<F>,
714        indices: &[usize],
715        k: usize,
716    ) -> Result<Vec<Vec<f64>>>
717    where
718        F: Float + FromPrimitive + Debug + Clone,
719        f64: From<F>,
720    {
721        let d = data.ncols();
722        let n = indices.len();
723        let k = k.min(n);
724        // Initial centroids: evenly spaced within the sample
725        let step = n / k;
726        let mut cents: Vec<Vec<f64>> = (0..k)
727            .map(|i| {
728                let idx = indices[i * step];
729                data.row(idx).iter().map(|&v| f64::from(v)).collect()
730            })
731            .collect();
732
733        for _ in 0..self.config.max_iter {
734            let mut new_cents = vec![vec![0.0f64; d]; k];
735            let mut counts = vec![0usize; k];
736            for &idx in indices {
737                let row: Vec<f64> = data.row(idx).iter().map(|&v| f64::from(v)).collect();
738                let best = nearest_centroid_f64(&cents, &row);
739                counts[best] += 1;
740                for dim in 0..d {
741                    new_cents[best][dim] += row[dim];
742                }
743            }
744            for j in 0..k {
745                if counts[j] > 0 {
746                    let nf = counts[j] as f64;
747                    for dim in 0..d {
748                        new_cents[j][dim] /= nf;
749                    }
750                } else {
751                    new_cents[j] = cents[j].clone();
752                }
753            }
754            cents = new_cents;
755        }
756        Ok(cents)
757    }
758}
759
/// Output of [`BootstrapEnsemble::fit`].
#[derive(Debug, Clone)]
pub struct BootstrapEnsembleResult {
    /// Consensus label of every data point.
    pub labels: Vec<usize>,
    /// Per-round labelling produced by each bootstrap base clustering.
    pub base_labels: Vec<Vec<usize>>,
    /// Stability estimate: mean NMI over pairs of bootstrap labellings.
    pub stability: f64,
    /// How many bootstrap rounds were run.
    pub n_bootstrap: usize,
    /// Clusters per base clustering.
    pub n_clusters: usize,
}
774
775// ---------------------------------------------------------------------------
776// StackedClustering
777// ---------------------------------------------------------------------------
778
/// Settings for [`StackedClustering`].
#[derive(Debug, Clone)]
pub struct StackedClusteringConfig {
    /// How many first-level (base) clusterings to run.
    pub n_base: usize,
    /// Clusters fitted by every base clustering.
    pub n_base_clusters: usize,
    /// Clusters produced by the second-level (meta) clustering.
    pub n_meta_clusters: usize,
    /// Iteration budget for each level.
    pub max_iter: usize,
    /// When true, the original features are appended to the meta-features.
    pub append_original: bool,
}

impl Default for StackedClusteringConfig {
    /// Five base clusterings of five clusters each, three meta-clusters,
    /// 100 iterations, no original features.
    fn default() -> Self {
        let append_original = false;
        StackedClusteringConfig {
            append_original,
            max_iter: 100,
            n_meta_clusters: 3,
            n_base_clusters: 5,
            n_base: 5,
        }
    }
}
805
806/// Stacked generalization for clustering.
807///
808/// Base clusterers produce soft or hard label vectors that are used as
809/// meta-features for a second-level k-means clustering.
810pub struct StackedClustering {
811    config: StackedClusteringConfig,
812}
813
814impl StackedClustering {
815    /// Create a new StackedClustering instance.
816    pub fn new(config: StackedClusteringConfig) -> Self {
817        Self { config }
818    }
819
820    /// Fit the stacked ensemble.
821    ///
822    /// Uses k-means with varied random offsets as base clusterers, then
823    /// clusters the label matrix with a second-level k-means.
824    pub fn fit<F>(&self, data: ArrayView2<F>) -> Result<StackedClusteringResult>
825    where
826        F: Float + FromPrimitive + Debug + Clone,
827        f64: From<F>,
828    {
829        let (n, d) = (data.nrows(), data.ncols());
830        if n == 0 {
831            return Err(ClusteringError::InvalidInput("Empty dataset".into()));
832        }
833        let kb = self.config.n_base_clusters.min(n);
834        let km = self.config.n_meta_clusters.min(n);
835
836        // Generate base label vectors with different centroid offsets
837        let mut meta_features: Vec<Vec<f64>> = vec![Vec::new(); n];
838
839        for b in 0..self.config.n_base {
840            let offset = b as f64 * 0.01; // slight deterministic perturbation
841            let labels = self.kmeans_with_offset(data, kb, offset)?;
842            for i in 0..n {
843                meta_features[i].push(labels[i] as f64);
844            }
845        }
846
847        // Optionally append normalised original features
848        if self.config.append_original && d > 0 {
849            // Compute per-dimension range for normalisation
850            let mut min_d = vec![f64::MAX; d];
851            let mut max_d = vec![f64::MIN; d];
852            for row in data.rows() {
853                for (j, &v) in row.iter().enumerate() {
854                    let vf = f64::from(v);
855                    if vf < min_d[j] {
856                        min_d[j] = vf;
857                    }
858                    if vf > max_d[j] {
859                        max_d[j] = vf;
860                    }
861                }
862            }
863            for (i, row) in data.rows().into_iter().enumerate() {
864                for (j, &v) in row.iter().enumerate() {
865                    let vf = f64::from(v);
866                    let range = (max_d[j] - min_d[j]).max(1e-15);
867                    meta_features[i].push((vf - min_d[j]) / range);
868                }
869            }
870        }
871
872        // Second-level k-means on meta-features
873        let meta_d = meta_features.first().map(|r| r.len()).unwrap_or(0);
874        let mut meta_cents: Vec<Vec<f64>> = (0..km).map(|i| meta_features[i % n].clone()).collect();
875        let mut final_labels = vec![0usize; n];
876
877        for _ in 0..self.config.max_iter {
878            for i in 0..n {
879                final_labels[i] = nearest_centroid_f64(&meta_cents, &meta_features[i]);
880            }
881            let mut new_cents = vec![vec![0.0; meta_d]; km];
882            let mut counts = vec![0usize; km];
883            for i in 0..n {
884                let j = final_labels[i];
885                counts[j] += 1;
886                for k in 0..meta_d {
887                    new_cents[j][k] += meta_features[i][k];
888                }
889            }
890            for j in 0..km {
891                if counts[j] > 0 {
892                    let nf = counts[j] as f64;
893                    for k in 0..meta_d {
894                        new_cents[j][k] /= nf;
895                    }
896                }
897            }
898            meta_cents = new_cents;
899        }
900
901        Ok(StackedClusteringResult {
902            labels: final_labels,
903            meta_features,
904            n_base: self.config.n_base,
905            n_meta_clusters: km,
906        })
907    }
908
909    fn kmeans_with_offset<F>(
910        &self,
911        data: ArrayView2<F>,
912        k: usize,
913        offset: f64,
914    ) -> Result<Vec<usize>>
915    where
916        F: Float + FromPrimitive + Debug + Clone,
917        f64: From<F>,
918    {
919        let (n, d) = (data.nrows(), data.ncols());
920        let k = k.min(n);
921        let offset_f = F::from_f64(offset).unwrap_or(F::zero());
922
923        let mut cents: Vec<Vec<f64>> = (0..k)
924            .map(|i| data.row(i).iter().map(|&v| f64::from(v) + offset).collect())
925            .collect();
926        let mut labels = vec![0usize; n];
927
928        for _ in 0..self.config.max_iter {
929            for i in 0..n {
930                let row: Vec<f64> = data.row(i).iter().map(|&v| f64::from(v)).collect();
931                labels[i] = nearest_centroid_f64(&cents, &row);
932            }
933            let mut new_cents = vec![vec![0.0; d]; k];
934            let mut counts = vec![0usize; k];
935            for i in 0..n {
936                let j = labels[i];
937                counts[j] += 1;
938                let row: Vec<f64> = data.row(i).iter().map(|&v| f64::from(v)).collect();
939                for dim in 0..d {
940                    new_cents[j][dim] += row[dim];
941                }
942            }
943            for j in 0..k {
944                if counts[j] > 0 {
945                    let nf = counts[j] as f64;
946                    for dim in 0..d {
947                        new_cents[j][dim] /= nf;
948                    }
949                }
950            }
951            cents = new_cents;
952        }
953        Ok(labels)
954    }
955}
956
/// Result from [`StackedClustering`].
#[derive(Debug, Clone)]
pub struct StackedClusteringResult {
    /// Final consensus cluster labels, one per input row.
    pub labels: Vec<usize>,
    /// Meta-features, one row per input point. Each row starts with the label
    /// assigned by each base clustering (stored as `f64`). When
    /// `config.append_original` is enabled, it is followed by the point's
    /// min-max normalised original features.
    pub meta_features: Vec<Vec<f64>>,
    /// Number of base clusterings used.
    pub n_base: usize,
    /// Number of meta-level clusters: the configured value, capped at the
    /// number of input rows.
    pub n_meta_clusters: usize,
}
969
970// ---------------------------------------------------------------------------
971// Internal helpers
972// ---------------------------------------------------------------------------
973
/// Index of the centroid closest to `point` under squared Euclidean distance.
///
/// Only the overlapping prefix of each centroid and `point` is compared. Ties
/// go to the lowest index. Returns 0 when `centroids` is empty, or when no
/// distance compares strictly below `f64::MAX` (e.g. every distance is NaN).
fn nearest_centroid_f64(centroids: &[Vec<f64>], point: &[f64]) -> usize {
    let squared_distance = |centroid: &[f64]| -> f64 {
        centroid
            .iter()
            .zip(point.iter())
            .map(|(&c, &p)| {
                let diff = c - p;
                diff * diff
            })
            .sum()
    };

    let (closest, _) = centroids.iter().enumerate().fold(
        (0usize, f64::MAX),
        |(best_idx, best_dist), (idx, centroid)| {
            let dist = squared_distance(centroid.as_slice());
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        },
    );
    closest
}
990
991fn compute_average_nmi(base_labels: &[Vec<usize>]) -> f64 {
992    let m = base_labels.len();
993    if m < 2 {
994        return 1.0;
995    }
996    let pairs = m * (m - 1) / 2;
997    let sum: f64 = (0..m)
998        .flat_map(|i| (i + 1..m).map(move |j| (i, j)))
999        .map(|(i, j)| ClusterSimilarity::normalized_mutual_info(&base_labels[i], &base_labels[j]))
1000        .sum();
1001    sum / pairs as f64
1002}
1003
1004// ---------------------------------------------------------------------------
1005// Tests
1006// ---------------------------------------------------------------------------
1007
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two identical label vectors: points 0..10 in cluster 0, points 10..20 in cluster 1.
    fn two_cluster_labels() -> (Vec<usize>, Vec<usize>) {
        let a: Vec<usize> = (0..20).map(|i| if i < 10 { 0 } else { 1 }).collect();
        let b = a.clone();
        (a, b)
    }

    /// 20 points in 2-D forming two well-separated groups: rows 0..10 lie on the
    /// diagonal near the origin, rows 10..20 are shifted by +10 in both axes.
    fn two_blob_data() -> Array2<f64> {
        let mut v = Vec::with_capacity(40);
        for i in 0..20 {
            let offset = if i < 10 { 0.0 } else { 10.0 };
            v.extend_from_slice(&[offset + i as f64 * 0.1, offset + i as f64 * 0.1]);
        }
        Array2::from_shape_vec((20, 2), v).expect("20 rows x 2 columns = 40 values")
    }

    #[test]
    fn test_ari_perfect() {
        let (a, b) = two_cluster_labels();
        let ari = ClusterSimilarity::adjusted_rand_index(&a, &b);
        assert!((ari - 1.0).abs() < 1e-9, "ARI = {}", ari);
    }

    #[test]
    fn test_nmi_perfect() {
        let (a, b) = two_cluster_labels();
        let nmi = ClusterSimilarity::normalized_mutual_info(&a, &b);
        assert!((nmi - 1.0).abs() < 1e-9, "NMI = {}", nmi);
    }

    #[test]
    fn test_fowlkes_mallows_perfect() {
        let (a, b) = two_cluster_labels();
        let fm = ClusterSimilarity::fowlkes_mallows(&a, &b);
        assert!((fm - 1.0).abs() < 1e-9, "FM = {}", fm);
    }

    #[test]
    fn test_weighted_voting() {
        let labels1: Vec<usize> = (0..20).map(|i| if i < 10 { 0 } else { 1 }).collect();
        let labels2: Vec<usize> = (0..20).map(|i| if i < 12 { 0 } else { 1 }).collect();
        let base = vec![labels1, labels2];

        let wv = WeightedVoting::new(WeightedVotingConfig {
            n_base: 2,
            n_clusters: 2,
            ..Default::default()
        });
        let result = wv.combine(&base, None).expect("combine ok");
        assert_eq!(result.labels.len(), 20);
        assert_eq!(result.n_clusters, 2);
    }

    #[test]
    fn test_selective_ensemble() {
        let labels: Vec<Vec<usize>> = (0..5)
            .map(|b| (0..20).map(|i| if i < 10 + b { 0 } else { 1 }).collect())
            .collect();
        let se = SelectiveEnsemble::new(SelectiveEnsembleConfig {
            target_size: 3,
            ..Default::default()
        });
        let result = se.select(&labels).expect("select ok");
        assert_eq!(result.selected_indices.len(), 3);
    }

    #[test]
    fn test_bootstrap_ensemble() {
        let data = two_blob_data();
        let be = BootstrapEnsemble::new(BootstrapEnsembleConfig {
            n_bootstrap: 3,
            n_clusters: 2,
            ..Default::default()
        });
        let result = be.fit(data.view()).expect("fit ok");
        assert_eq!(result.labels.len(), 20);
        assert_eq!(result.n_bootstrap, 3);
    }

    #[test]
    fn test_stacked_clustering() {
        let data = two_blob_data();
        let sc = StackedClustering::new(StackedClusteringConfig {
            n_base: 3,
            n_base_clusters: 2,
            n_meta_clusters: 2,
            ..Default::default()
        });
        let result = sc.fit(data.view()).expect("fit ok");
        assert_eq!(result.labels.len(), 20);
        // Every final label indexes one of the meta-level centroids.
        assert!(result.labels.iter().all(|&l| l < 2));
        assert_eq!(result.n_base, 3);
        assert_eq!(result.n_meta_clusters, 2);
        assert_eq!(result.meta_features.len(), 20);
        // One label per base clustering, possibly followed by the original features.
        assert!(result.meta_features.iter().all(|row| row.len() >= 3));
    }
}