scirs2_cluster/
stability.rs

1//! Cluster stability assessment tools
2//!
3//! This module provides various methods for assessing the stability and
4//! quality of clustering results, including bootstrap validation,
5//! consensus clustering, and stability indices.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::seq::SliceRandom;
10use scirs2_core::random::{Rng, SeedableRng};
11use std::collections::HashSet;
12use std::fmt::Debug;
13
14use crate::error::{ClusteringError, Result};
15use crate::metrics::adjusted_rand_index;
16use crate::vq::kmeans2;
17
/// Settings shared by the stability assessors in this module
///
/// Start from [`StabilityConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `StabilityConfig { n_bootstrap: 20, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct StabilityConfig {
    /// How many bootstrap / resampling iterations to perform
    pub n_bootstrap: usize,
    /// Fraction of the samples drawn in each bootstrap iteration
    pub subsample_ratio: f64,
    /// Seed for the random number generator; `None` leaves the choice of seed to each assessor
    pub random_seed: Option<u64>,
    /// How many clustering runs to perform on every bootstrap sample
    pub n_runs_per_bootstrap: usize,
    /// Inclusive `(min, max)` range of cluster counts to evaluate when choosing k
    pub k_range: Option<(usize, usize)>,
}

impl Default for StabilityConfig {
    /// 100 bootstrap iterations on 80 % subsamples with 10 runs each, no seed and no k range.
    fn default() -> Self {
        let n_bootstrap = 100;
        let n_runs_per_bootstrap = 10;
        StabilityConfig {
            n_bootstrap,
            n_runs_per_bootstrap,
            subsample_ratio: 0.8,
            random_seed: None,
            k_range: None,
        }
    }
}
44
45/// Results of stability assessment
46#[derive(Debug, Clone)]
47pub struct StabilityResult<F: Float> {
48    /// Stability scores for each tested configuration
49    pub stability_scores: Vec<F>,
50    /// Consensus clustering result
51    pub consensus_labels: Option<Array1<usize>>,
52    /// Optimal number of clusters (if k_range was provided)
53    pub optimal_k: Option<usize>,
54    /// Mean stability score across all bootstrap iterations
55    pub mean_stability: F,
56    /// Standard deviation of stability scores
57    pub std_stability: F,
58    /// Bootstrap stability matrix
59    pub bootstrap_matrix: Array2<F>,
60}
61
62/// Bootstrap validation for clustering stability
63///
64/// This method assesses the stability of clustering by running the algorithm
65/// on multiple bootstrap samples of the data and measuring the consistency
66/// of the results.
67pub struct BootstrapValidator<F: Float> {
68    config: StabilityConfig,
69    phantom: std::marker::PhantomData<F>,
70}
71
72impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
73    BootstrapValidator<F>
74{
75    /// Create a new bootstrap validator
76    pub fn new(config: StabilityConfig) -> Self {
77        Self {
78            config,
79            phantom: std::marker::PhantomData,
80        }
81    }
82
83    /// Assess K-means clustering stability
84    pub fn assess_kmeans_stability(
85        &self,
86        data: ArrayView2<F>,
87        k: usize,
88    ) -> Result<StabilityResult<F>> {
89        let n_samples = data.shape()[0];
90        let n_features = data.shape()[1];
91
92        if n_samples < 2 {
93            return Err(ClusteringError::InvalidInput(
94                "Need at least 2 samples for stability assessment".into(),
95            ));
96        }
97
98        let subsample_size = ((n_samples as f64) * self.config.subsample_ratio) as usize;
99        if subsample_size < k {
100            return Err(ClusteringError::InvalidInput(
101                "Subsample size must be at least k".into(),
102            ));
103        }
104
105        let mut rng = match self.config.random_seed {
106            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
107            None => {
108                // Use a default seed when no seed is provided
109                scirs2_core::random::rngs::StdRng::seed_from_u64(42)
110            }
111        };
112
113        let mut bootstrap_results = Vec::new();
114
115        // Perform bootstrap iterations
116        for _iteration in 0..self.config.n_bootstrap {
117            // Create bootstrap sample
118            let mut indices: Vec<usize> = (0..n_samples).collect();
119            indices.shuffle(&mut rng);
120            indices.truncate(subsample_size);
121
122            let mut bootstrap_data = Array2::zeros((subsample_size, n_features));
123            for (new_idx, &old_idx) in indices.iter().enumerate() {
124                bootstrap_data.row_mut(new_idx).assign(&data.row(old_idx));
125            }
126
127            // Run clustering multiple times on this bootstrap sample
128            let mut run_labels = Vec::new();
129            for _run in 0..self.config.n_runs_per_bootstrap {
130                let seed = rng.random::<u64>();
131
132                match kmeans2(
133                    bootstrap_data.view(),
134                    k,
135                    Some(100),   // max_iter
136                    None,        // threshold
137                    None,        // init method
138                    None,        // missing method
139                    Some(false), // check_finite
140                    Some(seed),
141                ) {
142                    Ok((_, labels)) => {
143                        let labels_usize: Array1<usize> = labels.mapv(|x| x);
144                        run_labels.push(labels_usize);
145                    }
146                    Err(_) => {
147                        // If clustering fails, create a dummy result
148                        let dummy_labels = Array1::zeros(subsample_size);
149                        run_labels.push(dummy_labels);
150                    }
151                }
152            }
153
154            bootstrap_results.push((indices, run_labels));
155        }
156
157        // Calculate stability metrics
158        let stability_scores = self.calculate_stability_scores(&bootstrap_results)?;
159        let mean_stability = stability_scores
160            .iter()
161            .copied()
162            .fold(F::zero(), |acc, x| acc + x)
163            / F::from(stability_scores.len()).unwrap();
164
165        let variance = stability_scores
166            .iter()
167            .map(|&x| {
168                let diff = x - mean_stability;
169                diff * diff
170            })
171            .fold(F::zero(), |acc, x| acc + x)
172            / F::from(stability_scores.len()).unwrap();
173        let std_stability = variance.sqrt();
174
175        // Create bootstrap stability matrix
176        let bootstrap_matrix = self.create_bootstrap_matrix(&bootstrap_results, n_samples)?;
177
178        Ok(StabilityResult {
179            stability_scores,
180            consensus_labels: None, // Would need consensus clustering implementation
181            optimal_k: None,
182            mean_stability,
183            std_stability,
184            bootstrap_matrix,
185        })
186    }
187
188    /// Calculate stability scores from bootstrap results
189    fn calculate_stability_scores(
190        &self,
191        bootstrap_results: &[(Vec<usize>, Vec<Array1<usize>>)],
192    ) -> Result<Vec<F>> {
193        let mut scores = Vec::new();
194
195        for (_, run_labels) in bootstrap_results {
196            if run_labels.len() < 2 {
197                continue;
198            }
199
200            // Calculate pairwise ARI between runs
201            let mut pairwise_aris = Vec::new();
202            for i in 0..run_labels.len() {
203                for j in (i + 1)..run_labels.len() {
204                    let labels1 = run_labels[i].mapv(|x| x as i32);
205                    let labels2 = run_labels[j].mapv(|x| x as i32);
206
207                    match adjusted_rand_index::<F>(labels1.view(), labels2.view()) {
208                        Ok(ari) => pairwise_aris.push(ari),
209                        Err(_) => pairwise_aris.push(F::zero()),
210                    }
211                }
212            }
213
214            if !pairwise_aris.is_empty() {
215                let mean_ari = pairwise_aris
216                    .iter()
217                    .copied()
218                    .fold(F::zero(), |acc, x| acc + x)
219                    / F::from(pairwise_aris.len()).unwrap();
220                scores.push(mean_ari);
221            }
222        }
223
224        Ok(scores)
225    }
226
227    /// Create bootstrap stability matrix
228    fn create_bootstrap_matrix(
229        &self,
230        bootstrap_results: &[(Vec<usize>, Vec<Array1<usize>>)],
231        n_samples: usize,
232    ) -> Result<Array2<F>> {
233        let mut co_occurrence_matrix: Array2<F> = Array2::zeros((n_samples, n_samples));
234        let mut count_matrix: Array2<F> = Array2::zeros((n_samples, n_samples));
235
236        for (indices, run_labels) in bootstrap_results {
237            if run_labels.is_empty() {
238                continue;
239            }
240
241            // Use the first run's labels for this bootstrap
242            let labels = &run_labels[0];
243
244            // Update co-occurrence matrix
245            for (i, &idx_i) in indices.iter().enumerate() {
246                for (j, &idx_j) in indices.iter().enumerate() {
247                    if i != j {
248                        count_matrix[[idx_i, idx_j]] = count_matrix[[idx_i, idx_j]] + F::one();
249
250                        if labels[i] == labels[j] {
251                            co_occurrence_matrix[[idx_i, idx_j]] =
252                                co_occurrence_matrix[[idx_i, idx_j]] + F::one();
253                        }
254                    }
255                }
256            }
257        }
258
259        // Convert to probabilities
260        let mut stability_matrix = Array2::zeros((n_samples, n_samples));
261        for i in 0..n_samples {
262            for j in 0..n_samples {
263                if count_matrix[[i, j]] > F::zero() {
264                    stability_matrix[[i, j]] = co_occurrence_matrix[[i, j]] / count_matrix[[i, j]];
265                }
266            }
267        }
268
269        Ok(stability_matrix)
270    }
271}
272
273/// Consensus clustering for robust cluster identification
274///
275/// This method combines multiple clustering results to identify
276/// stable cluster structures.
277pub struct ConsensusClusterer<F: Float> {
278    config: StabilityConfig,
279    phantom: std::marker::PhantomData<F>,
280}
281
282impl<F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display> ConsensusClusterer<F> {
283    /// Create a new consensus clusterer
284    pub fn new(config: StabilityConfig) -> Self {
285        Self {
286            config,
287            phantom: std::marker::PhantomData,
288        }
289    }
290
291    /// Find consensus clusters using multiple algorithm runs
292    pub fn find_consensus_clusters(&self, data: ArrayView2<F>, k: usize) -> Result<Array1<usize>> {
293        let n_samples = data.shape()[0];
294
295        if n_samples < 2 {
296            return Err(ClusteringError::InvalidInput(
297                "Need at least 2 samples for consensus clustering".into(),
298            ));
299        }
300
301        let mut rng = match self.config.random_seed {
302            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
303            None => {
304                // Use a default seed when no seed is provided
305                scirs2_core::random::rngs::StdRng::seed_from_u64(42)
306            }
307        };
308
309        let mut all_labels = Vec::new();
310
311        // Run clustering multiple times with different initializations
312        for _run in 0..self.config.n_bootstrap {
313            let seed = rng.random::<u64>();
314
315            match kmeans2(
316                data,
317                k,
318                Some(100),   // max_iter
319                None,        // threshold
320                None,        // init method
321                None,        // missing method
322                Some(false), // check_finite
323                Some(seed),
324            ) {
325                Ok((_, labels)) => {
326                    let labels_usize: Array1<usize> = labels.mapv(|x| x);
327                    all_labels.push(labels_usize);
328                }
329                Err(_) => {
330                    // Skip failed runs
331                    continue;
332                }
333            }
334        }
335
336        if all_labels.is_empty() {
337            return Err(ClusteringError::ComputationError(
338                "All clustering runs failed".into(),
339            ));
340        }
341
342        // Build consensus matrix
343        let mut consensus_matrix = Array2::zeros((n_samples, n_samples));
344
345        for labels in &all_labels {
346            for i in 0..n_samples {
347                for j in 0..n_samples {
348                    if labels[i] == labels[j] {
349                        consensus_matrix[[i, j]] = consensus_matrix[[i, j]] + F::one();
350                    }
351                }
352            }
353        }
354
355        // Normalize by number of runs
356        let n_runs = F::from(all_labels.len()).unwrap();
357        consensus_matrix.mapv_inplace(|x| x / n_runs);
358
359        // Extract consensus clusters using threshold
360        let threshold = F::from(0.5).unwrap();
361        self.extract_consensus_clusters(&consensus_matrix, threshold, k)
362    }
363
364    /// Extract clusters from consensus matrix
365    fn extract_consensus_clusters(
366        &self,
367        consensus_matrix: &Array2<F>,
368        threshold: F,
369        k: usize,
370    ) -> Result<Array1<usize>> {
371        let n_samples = consensus_matrix.shape()[0];
372        let mut labels = Array1::from_elem(n_samples, usize::MAX); // Unassigned
373        let mut current_cluster = 0;
374
375        // Use a greedy approach to find dense consensus regions
376        let mut unassigned: HashSet<usize> = (0..n_samples).collect();
377
378        while current_cluster < k && !unassigned.is_empty() {
379            // Find the pair with highest consensus that includes an unassigned point
380            let mut best_consensus = F::zero();
381            let mut best_seed = None;
382
383            for &i in &unassigned {
384                for &j in &unassigned {
385                    if i != j && consensus_matrix[[i, j]] > best_consensus {
386                        best_consensus = consensus_matrix[[i, j]];
387                        best_seed = Some(i);
388                    }
389                }
390            }
391
392            if let Some(seed) = best_seed {
393                // Grow cluster from seed
394                let mut cluster_members = Vec::new();
395                cluster_members.push(seed);
396
397                // Add all points with high consensus to the seed
398                for &candidate in &unassigned {
399                    if candidate != seed && consensus_matrix[[seed, candidate]] >= threshold {
400                        cluster_members.push(candidate);
401                    }
402                }
403
404                // Assign cluster label
405                for &member in &cluster_members {
406                    labels[member] = current_cluster;
407                    unassigned.remove(&member);
408                }
409
410                current_cluster += 1;
411            } else {
412                // No more good consensus pairs, assign remaining points to nearest cluster
413                break;
414            }
415        }
416
417        // Assign remaining unassigned points to the nearest existing cluster
418        for &unassigned_point in &unassigned {
419            let mut best_cluster = 0;
420            let mut best_avg_consensus = F::zero();
421
422            for cluster_id in 0..current_cluster {
423                let mut total_consensus = F::zero();
424                let mut count = 0;
425
426                for i in 0..n_samples {
427                    if labels[i] == cluster_id {
428                        total_consensus = total_consensus + consensus_matrix[[unassigned_point, i]];
429                        count += 1;
430                    }
431                }
432
433                if count > 0 {
434                    let avg_consensus = total_consensus / F::from(count).unwrap();
435                    if avg_consensus > best_avg_consensus {
436                        best_avg_consensus = avg_consensus;
437                        best_cluster = cluster_id;
438                    }
439                }
440            }
441
442            labels[unassigned_point] = best_cluster;
443        }
444
445        Ok(labels)
446    }
447}
448
449/// Optimal cluster number selection using stability criteria
450pub struct OptimalKSelector<F: Float> {
451    config: StabilityConfig,
452    phantom: std::marker::PhantomData<F>,
453}
454
455impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
456    OptimalKSelector<F>
457{
458    /// Create a new optimal k selector
459    pub fn new(config: StabilityConfig) -> Self {
460        Self {
461            config,
462            phantom: std::marker::PhantomData,
463        }
464    }
465
466    /// Find optimal number of clusters using stability gap statistic
467    pub fn find_optimal_k(&self, data: ArrayView2<F>) -> Result<(usize, Vec<F>)> {
468        let (k_min, k_max) = self.config.k_range.unwrap_or((2, 10));
469        let mut stability_scores = Vec::new();
470
471        for k in k_min..=k_max {
472            let validator = BootstrapValidator::new(self.config.clone());
473            match validator.assess_kmeans_stability(data, k) {
474                Ok(result) => stability_scores.push(result.mean_stability),
475                Err(_) => stability_scores.push(F::zero()),
476            }
477        }
478
479        // Find k with maximum stability
480        let mut best_k = k_min;
481        let mut best_score = F::neg_infinity();
482
483        for (i, &score) in stability_scores.iter().enumerate() {
484            if score > best_score {
485                best_score = score;
486                best_k = k_min + i;
487            }
488        }
489
490        Ok((best_k, stability_scores))
491    }
492
493    /// Find optimal k using gap statistic with reference distribution
494    pub fn gap_statistic(&self, data: ArrayView2<F>) -> Result<(usize, Vec<F>)> {
495        let (k_min, k_max) = self.config.k_range.unwrap_or((2, 10));
496        let n_samples = data.shape()[0];
497        let n_features = data.shape()[1];
498
499        let mut gap_scores = Vec::new();
500
501        // Find data bounds for reference distribution
502        let mut min_vals = Array1::from_elem(n_features, F::infinity());
503        let mut max_vals = Array1::from_elem(n_features, F::neg_infinity());
504
505        for i in 0..n_samples {
506            for j in 0..n_features {
507                let val = data[[i, j]];
508                if val < min_vals[j] {
509                    min_vals[j] = val;
510                }
511                if val > max_vals[j] {
512                    max_vals[j] = val;
513                }
514            }
515        }
516
517        for k in k_min..=k_max {
518            // Calculate log(W_k) for original data
519            let original_wk = self.calculate_within_cluster_dispersion(data, k)?;
520            let log_wk = original_wk.ln();
521
522            // Calculate expected log(W_k) from reference distribution
523            let mut reference_log_wks = Vec::new();
524            let mut rng = match self.config.random_seed {
525                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
526                None => {
527                    // Use a default seed when no seed is provided
528                    scirs2_core::random::rngs::StdRng::seed_from_u64(42)
529                }
530            };
531
532            for _b in 0..self.config.n_bootstrap {
533                // Generate reference data
534                let mut reference_data = Array2::zeros((n_samples, n_features));
535                for i in 0..n_samples {
536                    for j in 0..n_features {
537                        let range = max_vals[j] - min_vals[j];
538                        let random_val =
539                            min_vals[j] + range * F::from(rng.random::<f64>()).unwrap();
540                        reference_data[[i, j]] = random_val;
541                    }
542                }
543
544                let reference_wk =
545                    self.calculate_within_cluster_dispersion(reference_data.view(), k)?;
546                reference_log_wks.push(reference_wk.ln());
547            }
548
549            // Calculate gap statistic
550            let expected_log_wk = reference_log_wks
551                .iter()
552                .copied()
553                .fold(F::zero(), |acc, x| acc + x)
554                / F::from(reference_log_wks.len()).unwrap();
555            let gap = expected_log_wk - log_wk;
556            gap_scores.push(gap);
557        }
558
559        // Find optimal k (first k where gap(k) >= gap(k+1) - s_{k+1})
560        let mut optimal_k = k_min;
561        for i in 0..(gap_scores.len() - 1) {
562            if gap_scores[i] >= gap_scores[i + 1] {
563                optimal_k = k_min + i;
564                break;
565            }
566        }
567
568        Ok((optimal_k, gap_scores))
569    }
570
571    /// Calculate within-cluster dispersion W_k
572    fn calculate_within_cluster_dispersion(&self, data: ArrayView2<F>, k: usize) -> Result<F> {
573        // Run K-means clustering
574        match kmeans2(
575            data,
576            k,
577            Some(100),   // max_iter
578            None,        // threshold
579            None,        // init method
580            None,        // missing method
581            Some(false), // check_finite
582            self.config.random_seed,
583        ) {
584            Ok((centroids, labels)) => {
585                let mut total_dispersion = F::zero();
586
587                for cluster_id in 0..k {
588                    let mut cluster_dispersion = F::zero();
589                    let mut cluster_size = 0;
590
591                    // Calculate sum of squared distances within cluster
592                    for i in 0..data.shape()[0] {
593                        if labels[i] == cluster_id {
594                            let mut sq_dist = F::zero();
595                            for j in 0..data.shape()[1] {
596                                let diff = data[[i, j]] - centroids[[cluster_id, j]];
597                                sq_dist = sq_dist + diff * diff;
598                            }
599                            cluster_dispersion = cluster_dispersion + sq_dist;
600                            cluster_size += 1;
601                        }
602                    }
603
604                    // Normalize by cluster size
605                    if cluster_size > 1 {
606                        total_dispersion =
607                            total_dispersion + cluster_dispersion / F::from(cluster_size).unwrap();
608                    }
609                }
610
611                Ok(total_dispersion)
612            }
613            Err(e) => Err(e),
614        }
615    }
616}
617
618/// Advanced stability assessment methods
619pub mod advanced {
620    use super::*;
621    use crate::ensemble::{EnsembleClusterer, EnsembleConfig};
622    use crate::metrics::{mutual_info_score, silhouette_score};
623
624    /// Cross-validation based stability assessment
625    ///
626    /// This method uses k-fold cross-validation to assess clustering stability
627    /// by training on different subsets and testing on held-out data.
628    pub struct CrossValidationStability<F: Float> {
629        config: StabilityConfig,
630        n_folds: usize,
631        _phantom: std::marker::PhantomData<F>,
632    }
633
634    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
635        CrossValidationStability<F>
636    {
637        /// Create a new cross-validation stability assessor
638        pub fn new(config: StabilityConfig, n_folds: usize) -> Self {
639            Self {
640                config,
641                n_folds,
642                _phantom: std::marker::PhantomData,
643            }
644        }
645
646        /// Assess clustering stability using cross-validation
647        pub fn assess_stability(
648            &self,
649            data: ArrayView2<F>,
650            k: usize,
651        ) -> Result<StabilityResult<F>> {
652            let n_samples = data.shape()[0];
653            let fold_size = n_samples / self.n_folds;
654            let mut stability_scores = Vec::new();
655            let mut bootstrap_matrix = Array2::zeros((self.n_folds, self.n_folds));
656
657            // Perform k-fold cross-validation
658            for fold in 0..self.n_folds {
659                let start_idx = fold * fold_size;
660                let end_idx = if fold == self.n_folds - 1 {
661                    n_samples
662                } else {
663                    (fold + 1) * fold_size
664                };
665
666                // Create training set (excluding current fold)
667                let mut train_indices = Vec::new();
668                for i in 0..n_samples {
669                    if i < start_idx || i >= end_idx {
670                        train_indices.push(i);
671                    }
672                }
673
674                // Create training data
675                let train_data =
676                    Array2::from_shape_fn((train_indices.len(), data.shape()[1]), |(i, j)| {
677                        data[[train_indices[i], j]]
678                    });
679
680                // Run clustering on training data
681                let (train_centroids, train_labels) = kmeans2(
682                    train_data.view(),
683                    k,
684                    Some(100),                    // max_iter
685                    Some(F::from(1e-6).unwrap()), // threshold
686                    None,                         // init method
687                    None,                         // missing method
688                    None,                         // check_finite
689                    Some(42),                     // seed
690                )?;
691
692                // Assign test data to nearest centroids
693                let test_labels = Array1::from_shape_fn(end_idx - start_idx, |i| {
694                    let test_point = data.row(start_idx + i);
695                    let mut min_dist = F::infinity();
696                    let mut closest_cluster = 0;
697
698                    for (cluster_id, centroid) in train_centroids.outer_iter().enumerate() {
699                        let dist = test_point
700                            .iter()
701                            .zip(centroid.iter())
702                            .map(|(a, b)| (*a - *b) * (*a - *b))
703                            .sum::<F>()
704                            .sqrt();
705
706                        if dist < min_dist {
707                            min_dist = dist;
708                            closest_cluster = cluster_id;
709                        }
710                    }
711                    closest_cluster
712                });
713
714                // Calculate stability score for this fold
715                let stability = self.calculate_fold_stability(&test_labels, k)?;
716                stability_scores.push(stability);
717            }
718
719            // Calculate mean and standard deviation
720            let mean_stability = stability_scores.iter().fold(F::zero(), |acc, x| acc + *x)
721                / F::from(stability_scores.len()).unwrap();
722            let variance = stability_scores
723                .iter()
724                .map(|&s| (s - mean_stability) * (s - mean_stability))
725                .fold(F::zero(), |acc, x| acc + x)
726                / F::from(stability_scores.len()).unwrap();
727            let std_stability = variance.sqrt();
728
729            Ok(StabilityResult {
730                stability_scores,
731                consensus_labels: None,
732                optimal_k: None,
733                mean_stability,
734                std_stability,
735                bootstrap_matrix,
736            })
737        }
738
739        fn calculate_fold_stability(&self, labels: &Array1<usize>, k: usize) -> Result<F> {
740            // Calculate intra-cluster cohesion
741            let mut cluster_cohesion = F::zero();
742            let mut total_pairs = 0;
743
744            for cluster_id in 0..k {
745                let cluster_members: Vec<_> = labels
746                    .iter()
747                    .enumerate()
748                    .filter(|(_, &label)| label == cluster_id)
749                    .map(|(idx_, _)| idx_)
750                    .collect();
751
752                let cluster_size = cluster_members.len();
753                if cluster_size > 1 {
754                    let pairs = cluster_size * (cluster_size - 1) / 2;
755                    cluster_cohesion = cluster_cohesion + F::from(pairs).unwrap();
756                    total_pairs += pairs;
757                }
758            }
759
760            if total_pairs == 0 {
761                Ok(F::zero())
762            } else {
763                Ok(cluster_cohesion / F::from(total_pairs).unwrap())
764            }
765        }
766    }
767
768    /// Perturbation-based stability assessment
769    ///
770    /// This method assesses stability by introducing controlled perturbations
771    /// to the data and measuring how much the clustering results change.
772    pub struct PerturbationStability<F: Float> {
773        config: StabilityConfig,
774        perturbation_types: Vec<PerturbationType>,
775        _phantom: std::marker::PhantomData<F>,
776    }
777
778    /// Types of perturbations for stability testing
779    #[derive(Debug, Clone)]
780    pub enum PerturbationType {
781        /// Add Gaussian noise
782        GaussianNoise { std_dev: f64 },
783        /// Remove random samples
784        SampleRemoval { removal_rate: f64 },
785        /// Add random features
786        FeatureNoise { noise_level: f64 },
787        /// Outlier injection
788        OutlierInjection {
789            outlier_rate: f64,
790            outlier_magnitude: f64,
791        },
792    }
793
794    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
795        PerturbationStability<F>
796    {
797        /// Create a new perturbation stability assessor
798        pub fn new(config: StabilityConfig, perturbation_types: Vec<PerturbationType>) -> Self {
799            Self {
800                config,
801                perturbation_types,
802                _phantom: std::marker::PhantomData,
803            }
804        }
805
806        /// Assess clustering stability under perturbations
807        pub fn assess_stability(
808            &self,
809            data: ArrayView2<F>,
810            k: usize,
811        ) -> Result<StabilityResult<F>> {
812            let mut all_stability_scores = Vec::new();
813            let mut rng = scirs2_core::random::rng();
814
815            // Get baseline clustering
816            let (baseline_centroids, baseline_labels) = kmeans2(
817                data,
818                k,
819                Some(100),                    // max_iter
820                Some(F::from(1e-6).unwrap()), // threshold
821                None,                         // init method
822                None,                         // missing method
823                None,                         // check_finite
824                Some(42),                     // seed
825            )?;
826
827            // Test each perturbation type
828            for perturbation in &self.perturbation_types {
829                let mut perturbation_scores = Vec::new();
830
831                for _ in 0..self.config.n_bootstrap {
832                    // Apply perturbation
833                    let perturbed_data = self.apply_perturbation(data, perturbation, &mut rng)?;
834
835                    // Run clustering on perturbed data
836                    let (_, perturbed_labels) = kmeans2(
837                        perturbed_data.view(),
838                        k,
839                        Some(100),                    // max_iter
840                        Some(F::from(1e-6).unwrap()), // threshold
841                        None,                         // init method
842                        None,                         // missing method
843                        None,                         // check_finite
844                        None,                         // random seed
845                    )?;
846
847                    // Calculate similarity to baseline
848                    let similarity =
849                        self.calculate_label_similarity(&baseline_labels, &perturbed_labels)?;
850                    perturbation_scores.push(similarity);
851                }
852
853                all_stability_scores.extend(perturbation_scores);
854            }
855
856            // Calculate overall statistics
857            let mean_stability = all_stability_scores
858                .iter()
859                .fold(F::zero(), |acc, x| acc + *x)
860                / F::from(all_stability_scores.len()).unwrap();
861            let variance = all_stability_scores
862                .iter()
863                .map(|&s| (s - mean_stability) * (s - mean_stability))
864                .sum::<F>()
865                / F::from(all_stability_scores.len()).unwrap();
866            let std_stability = variance.sqrt();
867
868            let bootstrap_matrix =
869                Array2::zeros((self.config.n_bootstrap, self.perturbation_types.len()));
870
871            Ok(StabilityResult {
872                stability_scores: all_stability_scores,
873                consensus_labels: None,
874                optimal_k: None,
875                mean_stability,
876                std_stability,
877                bootstrap_matrix,
878            })
879        }
880
881        fn apply_perturbation(
882            &self,
883            data: ArrayView2<F>,
884            perturbation: &PerturbationType,
885            rng: &mut impl Rng,
886        ) -> Result<Array2<F>> {
887            let mut perturbed = data.to_owned();
888
889            match perturbation {
890                PerturbationType::GaussianNoise { std_dev } => {
891                    for elem in perturbed.iter_mut() {
892                        let noise = rng.random::<f64>() * std_dev;
893                        *elem = *elem + F::from(noise).unwrap();
894                    }
895                }
896                PerturbationType::SampleRemoval { removal_rate } => {
897                    let n_samples = data.shape()[0];
898                    let n_remove = (n_samples as f64 * removal_rate) as usize;
899                    let mut indices: Vec<_> = (0..n_samples).collect();
900                    indices.shuffle(rng);
901                    indices.truncate(n_samples - n_remove);
902                    indices.sort();
903
904                    let mut new_data = Array2::zeros((indices.len(), data.shape()[1]));
905                    for (new_i, &old_i) in indices.iter().enumerate() {
906                        new_data.row_mut(new_i).assign(&data.row(old_i));
907                    }
908                    perturbed = new_data;
909                }
910                PerturbationType::FeatureNoise { noise_level } => {
911                    for elem in perturbed.iter_mut() {
912                        let noise = (rng.random::<f64>() - 0.5) * 2.0 * noise_level;
913                        *elem = *elem + F::from(noise).unwrap();
914                    }
915                }
916                PerturbationType::OutlierInjection {
917                    outlier_rate,
918                    outlier_magnitude,
919                } => {
920                    let n_samples = data.shape()[0];
921                    let n_outliers = (n_samples as f64 * outlier_rate) as usize;
922
923                    for _ in 0..n_outliers {
924                        let sample_idx = rng.random_range(0..n_samples);
925                        let feature_idx = rng.random_range(0..data.shape()[1]);
926                        let outlier_value = rng.random::<f64>() * outlier_magnitude;
927                        perturbed[[sample_idx, feature_idx]] = F::from(outlier_value).unwrap();
928                    }
929                }
930            }
931
932            Ok(perturbed)
933        }
934
935        fn calculate_label_similarity(
936            &self,
937            labels1: &Array1<usize>,
938            labels2: &Array1<usize>,
939        ) -> Result<F> {
940            if labels1.len() != labels2.len() {
941                return Ok(F::zero());
942            }
943
944            // Convert to i32 for ARI calculation
945            let labels1_i32: Array1<i32> = labels1.mapv(|x| x as i32);
946            let labels2_i32: Array1<i32> = labels2.mapv(|x| x as i32);
947
948            // Use adjusted rand index as similarity measure
949            let ari: f64 = adjusted_rand_index(labels1_i32.view(), labels2_i32.view())?;
950            Ok(F::from(ari).unwrap())
951        }
952    }
953
    /// Multi-scale stability assessment
    ///
    /// This method assesses stability across different data scales and resolutions
    /// to understand how clustering behaves at different granularities.
    ///
    /// Each scale factor multiplies the whole data matrix element-wise before
    /// bootstrap k-means stability is assessed on the scaled data.
    pub struct MultiScaleStability<F: Float> {
        // Bootstrap settings forwarded to each `BootstrapValidator`.
        config: StabilityConfig,
        // Factors the data is multiplied by; one assessment per factor.
        scale_factors: Vec<f64>,
        // Marker tying the assessor to the float type `F`.
        _phantom: std::marker::PhantomData<F>,
    }
963
964    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
965        MultiScaleStability<F>
966    {
967        /// Create a new multi-scale stability assessor
968        pub fn new(config: StabilityConfig, scale_factors: Vec<f64>) -> Self {
969            Self {
970                config,
971                scale_factors,
972                _phantom: std::marker::PhantomData,
973            }
974        }
975
976        /// Assess clustering stability across multiple scales
977        pub fn assess_stability(
978            &self,
979            data: ArrayView2<F>,
980            k_range: (usize, usize),
981        ) -> Result<Vec<StabilityResult<F>>> {
982            let mut results = Vec::new();
983
984            for &scale_factor in &self.scale_factors {
985                // Scale the data
986                let scaled_data = data.mapv(|x| x * F::from(scale_factor).unwrap());
987
988                // Assess stability at this scale for different k values
989                for k in k_range.0..=k_range.1 {
990                    let validator = BootstrapValidator::new(self.config.clone());
991                    let stability_result =
992                        validator.assess_kmeans_stability(scaled_data.view(), k)?;
993                    results.push(stability_result);
994                }
995            }
996
997            Ok(results)
998        }
999
1000        /// Find the most stable scale and cluster count combination
1001        pub fn find_optimal_scale_and_k(
1002            &self,
1003            data: ArrayView2<F>,
1004            k_range: (usize, usize),
1005        ) -> Result<(f64, usize, F)> {
1006            let results = self.assess_stability(data, k_range)?;
1007
1008            let mut best_scale = self.scale_factors[0];
1009            let mut best_k = k_range.0;
1010            let mut best_stability = F::neg_infinity();
1011
1012            let mut result_idx = 0;
1013            for &scale_factor in &self.scale_factors {
1014                for k in k_range.0..=k_range.1 {
1015                    if result_idx < results.len() {
1016                        let stability = results[result_idx].mean_stability;
1017                        if stability > best_stability {
1018                            best_stability = stability;
1019                            best_scale = scale_factor;
1020                            best_k = k;
1021                        }
1022                        result_idx += 1;
1023                    }
1024                }
1025            }
1026
1027            Ok((best_scale, best_k, best_stability))
1028        }
1029    }
1030
    /// Prediction Strength Method for clustering validation
    ///
    /// This method assesses clustering stability by measuring how well cluster assignments
    /// from one dataset can predict assignments in another dataset. Based on Tibshirani & Walther's
    /// prediction strength criterion.
    ///
    /// The data is repeatedly split at random into a training part and a test part
    /// (see [`PredictionStrengthConfig`]). Both parts are clustered with k-means, and
    /// the test assignments are compared with those predicted from the training part.
    pub struct PredictionStrength<F: Float> {
        /// Configuration for prediction strength assessment
        pub config: PredictionStrengthConfig,
        // Marker tying the validator to the float type `F`.
        phantom: std::marker::PhantomData<F>,
    }

    /// Configuration for prediction strength method
    #[derive(Debug, Clone)]
    pub struct PredictionStrengthConfig {
        /// Number of random train/test splits whose scores are averaged
        pub n_bootstrap: usize,
        /// Fraction of data to use for training in each split; the remaining rows
        /// form the test set
        pub train_ratio: f64,
        /// Minimum prediction strength threshold for validation: `find_optimal_k`
        /// picks the largest `k` whose strength is at least this value
        pub strength_threshold: f64,
        /// Random seed for reproducible results (a random seed is drawn when `None`)
        pub random_seed: Option<u64>,
    }

    impl Default for PredictionStrengthConfig {
        /// 50 splits, a 50/50 train/test split, threshold 0.8 and no fixed seed
        fn default() -> Self {
            Self {
                n_bootstrap: 50,
                train_ratio: 0.5,
                strength_threshold: 0.8,
                random_seed: None,
            }
        }
    }
1065
1066    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
1067        PredictionStrength<F>
1068    {
1069        /// Create a new prediction strength validator
1070        pub fn new(config: PredictionStrengthConfig) -> Self {
1071            Self {
1072                config,
1073                phantom: std::marker::PhantomData,
1074            }
1075        }
1076
1077        /// Assess prediction strength for a range of cluster numbers
1078        pub fn assess_k_range(
1079            &self,
1080            data: ArrayView2<F>,
1081            k_range: (usize, usize),
1082        ) -> Result<Vec<F>> {
1083            let mut prediction_strengths = Vec::new();
1084
1085            for k in k_range.0..=k_range.1 {
1086                let strength = self.compute_prediction_strength(data, k)?;
1087                prediction_strengths.push(strength);
1088            }
1089
1090            Ok(prediction_strengths)
1091        }
1092
1093        /// Compute prediction strength for a specific number of clusters
1094        pub fn compute_prediction_strength(&self, data: ArrayView2<F>, k: usize) -> Result<F> {
1095            let mut rng = match self.config.random_seed {
1096                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
1097                None => scirs2_core::random::rngs::StdRng::seed_from_u64(
1098                    scirs2_core::random::rng().random(),
1099                ),
1100            };
1101
1102            let n_samples = data.nrows();
1103            let train_size = ((n_samples as f64) * self.config.train_ratio) as usize;
1104
1105            let mut prediction_scores = Vec::new();
1106
1107            for _ in 0..self.config.n_bootstrap {
1108                // Split data into training and test sets
1109                let mut indices: Vec<usize> = (0..n_samples).collect();
1110                indices.shuffle(&mut rng);
1111
1112                let train_indices = &indices[..train_size];
1113                let test_indices = &indices[train_size..];
1114
1115                if test_indices.is_empty() {
1116                    continue;
1117                }
1118
1119                // Create training and test data
1120                let train_data = data.select(scirs2_core::ndarray::Axis(0), train_indices);
1121                let test_data = data.select(scirs2_core::ndarray::Axis(0), test_indices);
1122
1123                // Cluster training data
1124                match kmeans2(train_data.view(), k, None, None, None, None, None, None) {
1125                    Ok((_, train_labels)) => {
1126                        // Cluster test data
1127                        match kmeans2(test_data.view(), k, None, None, None, None, None, None) {
1128                            Ok((_, test_labels)) => {
1129                                // Compute prediction strength
1130                                let strength = self.compute_pairwise_prediction_strength(
1131                                    &train_data,
1132                                    &test_data,
1133                                    &train_labels,
1134                                    &test_labels,
1135                                )?;
1136                                prediction_scores.push(strength);
1137                            }
1138                            Err(_) => continue,
1139                        }
1140                    }
1141                    Err(_) => continue,
1142                }
1143            }
1144
1145            if prediction_scores.is_empty() {
1146                return Ok(F::zero());
1147            }
1148
1149            // Return mean prediction strength
1150            let sum: F = prediction_scores.iter().fold(F::zero(), |acc, &x| acc + x);
1151            Ok(sum / F::from(prediction_scores.len()).unwrap())
1152        }
1153
1154        /// Compute pairwise prediction strength between training and test assignments
1155        fn compute_pairwise_prediction_strength(
1156            &self,
1157            train_data: &Array2<F>,
1158            test_data: &Array2<F>,
1159            train_labels: &Array1<usize>,
1160            test_labels: &Array1<usize>,
1161        ) -> Result<F> {
1162            let test_size = test_data.nrows();
1163            let mut correct_predictions = 0;
1164            let mut total_predictions = 0;
1165
1166            // For each pair of test points
1167            for i in 0..test_size {
1168                for j in (i + 1)..test_size {
1169                    // Find closest points in training _data
1170                    let closest_train_i = self.find_closest_point(&test_data.row(i), train_data)?;
1171                    let closest_train_j = self.find_closest_point(&test_data.row(j), train_data)?;
1172
1173                    // Predict whether test points should be in same cluster
1174                    let predicted_same =
1175                        train_labels[closest_train_i] == train_labels[closest_train_j];
1176                    let actual_same = test_labels[i] == test_labels[j];
1177
1178                    if predicted_same == actual_same {
1179                        correct_predictions += 1;
1180                    }
1181                    total_predictions += 1;
1182                }
1183            }
1184
1185            if total_predictions == 0 {
1186                return Ok(F::zero());
1187            }
1188
1189            Ok(F::from(correct_predictions as f64 / total_predictions as f64).unwrap())
1190        }
1191
1192        /// Find closest point in training data to a test point
1193        fn find_closest_point(
1194            &self,
1195            test_point: &scirs2_core::ndarray::ArrayView1<F>,
1196            train_data: &Array2<F>,
1197        ) -> Result<usize> {
1198            let mut min_distance = F::infinity();
1199            let mut closest_idx = 0;
1200
1201            for (idx, train_point) in train_data.rows().into_iter().enumerate() {
1202                let distance = test_point
1203                    .iter()
1204                    .zip(train_point.iter())
1205                    .map(|(a, b)| (*a - *b) * (*a - *b))
1206                    .fold(F::zero(), |acc, x| acc + x)
1207                    .sqrt();
1208
1209                if distance < min_distance {
1210                    min_distance = distance;
1211                    closest_idx = idx;
1212                }
1213            }
1214
1215            Ok(closest_idx)
1216        }
1217
1218        /// Find optimal number of clusters using prediction strength
1219        pub fn find_optimal_k(
1220            &self,
1221            data: ArrayView2<F>,
1222            k_range: (usize, usize),
1223        ) -> Result<usize> {
1224            let strengths = self.assess_k_range(data, k_range)?;
1225
1226            // Find largest k with prediction strength above threshold
1227            for (idx, &strength) in strengths.iter().enumerate().rev() {
1228                if strength >= F::from(self.config.strength_threshold).unwrap() {
1229                    return Ok(k_range.0 + idx);
1230                }
1231            }
1232
1233            // If no k meets threshold, return the one with highest strength
1234            let best_idx = strengths
1235                .iter()
1236                .enumerate()
1237                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1238                .map(|(idx_, _)| idx_)
1239                .unwrap_or(0);
1240
1241            Ok(k_range.0 + best_idx)
1242        }
1243    }
1244
    /// Jaccard Stability Index for clustering validation
    ///
    /// Measures stability using Jaccard similarity between cluster assignments
    /// across different bootstrap samples or parameter settings.
    ///
    /// Each iteration draws two random subsamples, clusters both with k-means, and
    /// compares the co-membership of the samples that appear in both.
    pub struct JaccardStability<F: Float> {
        /// Number of bootstrap iterations (each draws a pair of subsamples)
        pub n_bootstrap: usize,
        /// Subsample ratio for each bootstrap: fraction of the rows drawn, without
        /// replacement, into each subsample
        pub subsample_ratio: f64,
        /// Random seed for reproducible results (a random seed is drawn when `None`)
        pub random_seed: Option<u64>,
        // Marker tying the validator to the float type `F`.
        _phantom: std::marker::PhantomData<F>,
    }
1258
1259    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
1260        JaccardStability<F>
1261    {
1262        /// Create a new Jaccard stability validator
1263        pub fn new(n_bootstrap: usize, subsample_ratio: f64, random_seed: Option<u64>) -> Self {
1264            Self {
1265                n_bootstrap,
1266                subsample_ratio,
1267                random_seed,
1268                _phantom: std::marker::PhantomData,
1269            }
1270        }
1271
1272        /// Compute Jaccard stability index for given data and cluster number
1273        pub fn compute_stability(&self, data: ArrayView2<F>, k: usize) -> Result<F> {
1274            let mut rng = match self.random_seed {
1275                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
1276                None => scirs2_core::random::rngs::StdRng::seed_from_u64(
1277                    scirs2_core::random::rng().random(),
1278                ),
1279            };
1280
1281            let n_samples = data.nrows();
1282            let subsample_size = ((n_samples as f64) * self.subsample_ratio) as usize;
1283
1284            let mut jaccard_scores = Vec::new();
1285
1286            // Generate pairs of bootstrap samples and compute Jaccard similarity
1287            for _ in 0..self.n_bootstrap {
1288                // First bootstrap sample
1289                let mut indices1: Vec<usize> = (0..n_samples).collect();
1290                indices1.shuffle(&mut rng);
1291                let sample_indices1 = &indices1[..subsample_size];
1292                let sample_data1 = data.select(scirs2_core::ndarray::Axis(0), sample_indices1);
1293
1294                // Second bootstrap sample
1295                let mut indices2: Vec<usize> = (0..n_samples).collect();
1296                indices2.shuffle(&mut rng);
1297                let sample_indices2 = &indices2[..subsample_size];
1298                let sample_data2 = data.select(scirs2_core::ndarray::Axis(0), sample_indices2);
1299
1300                // Cluster both samples
1301                match (
1302                    kmeans2(sample_data1.view(), k, None, None, None, None, None, None),
1303                    kmeans2(sample_data2.view(), k, None, None, None, None, None, None),
1304                ) {
1305                    (Ok((_, labels1)), Ok((_, labels2))) => {
1306                        // Find overlapping samples
1307                        let overlap_indices: Vec<(usize, usize)> = sample_indices1
1308                            .iter()
1309                            .enumerate()
1310                            .filter_map(|(i1, &idx1)| {
1311                                sample_indices2
1312                                    .iter()
1313                                    .enumerate()
1314                                    .find(|(_, &idx2)| idx1 == idx2)
1315                                    .map(|(i2_, _)| (i1, i2_))
1316                            })
1317                            .collect();
1318
1319                        if overlap_indices.len() >= 2 {
1320                            let jaccard = self.compute_jaccard_similarity(
1321                                &labels1,
1322                                &labels2,
1323                                &overlap_indices,
1324                            )?;
1325                            jaccard_scores.push(jaccard);
1326                        }
1327                    }
1328                    _ => continue,
1329                }
1330            }
1331
1332            if jaccard_scores.is_empty() {
1333                return Ok(F::zero());
1334            }
1335
1336            // Return mean Jaccard similarity
1337            let sum: F = jaccard_scores.iter().fold(F::zero(), |acc, &x| acc + x);
1338            Ok(sum / F::from(jaccard_scores.len()).unwrap())
1339        }
1340
1341        /// Compute Jaccard similarity between two cluster assignments
1342        fn compute_jaccard_similarity(
1343            &self,
1344            labels1: &Array1<usize>,
1345            labels2: &Array1<usize>,
1346            overlap_indices: &[(usize, usize)],
1347        ) -> Result<F> {
1348            let mut same_cluster_both = 0;
1349            let mut same_cluster_either = 0;
1350
1351            let n_overlap = overlap_indices.len();
1352
1353            for i in 0..n_overlap {
1354                for j in (i + 1)..n_overlap {
1355                    let (idx1_i, idx2_i) = overlap_indices[i];
1356                    let (idx1_j, idx2_j) = overlap_indices[j];
1357
1358                    let same_in_clustering1 = labels1[idx1_i] == labels1[idx1_j];
1359                    let same_in_clustering2 = labels2[idx2_i] == labels2[idx2_j];
1360
1361                    if same_in_clustering1 && same_in_clustering2 {
1362                        same_cluster_both += 1;
1363                    }
1364                    if same_in_clustering1 || same_in_clustering2 {
1365                        same_cluster_either += 1;
1366                    }
1367                }
1368            }
1369
1370            if same_cluster_either == 0 {
1371                return Ok(F::one()); // All pairs are different in both clusterings
1372            }
1373
1374            Ok(F::from(same_cluster_both as f64 / same_cluster_either as f64).unwrap())
1375        }
1376
1377        /// Assess stability across a range of cluster numbers
1378        pub fn assess_k_range(
1379            &self,
1380            data: ArrayView2<F>,
1381            k_range: (usize, usize),
1382        ) -> Result<Vec<F>> {
1383            let mut stabilities = Vec::new();
1384
1385            for k in k_range.0..=k_range.1 {
1386                let stability = self.compute_stability(data, k)?;
1387                stabilities.push(stability);
1388            }
1389
1390            Ok(stabilities)
1391        }
1392    }
1393
    /// Cluster-Specific Stability Indices
    ///
    /// Provides stability metrics for individual clusters rather than
    /// global stability measures.
    pub struct ClusterSpecificStability<F: Float> {
        /// Configuration for cluster-specific stability assessment; the bootstrap is
        /// driven by `n_bootstrap`, `subsample_ratio` and `random_seed`
        pub config: StabilityConfig,
        // Marker tying the validator to the float type `F`.
        phantom: std::marker::PhantomData<F>,
    }

    /// Results of cluster-specific stability assessment
    #[derive(Debug, Clone)]
    pub struct ClusterStabilityResult<F: Float> {
        /// Stability score for each cluster, indexed by cluster id (`0..k`)
        pub cluster_stabilities: Vec<F>,
        /// Mean stability across all clusters
        pub mean_stability: F,
        /// Population standard deviation of cluster stabilities
        pub std_stability: F,
        /// Cluster size consistency across bootstrap samples, computed per cluster as
        /// `1 - CV` of the observed sizes (CV = coefficient of variation). It is zero
        /// when no size was observed and can be negative when CV exceeds 1.
        pub size_consistency: Vec<F>,
    }
1416
1417    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
1418        ClusterSpecificStability<F>
1419    {
1420        /// Create a new cluster-specific stability validator
1421        pub fn new(config: StabilityConfig) -> Self {
1422            Self {
1423                config,
1424                phantom: std::marker::PhantomData,
1425            }
1426        }
1427
1428        /// Assess stability for each cluster individually
1429        pub fn assess_cluster_stability(
1430            &self,
1431            data: ArrayView2<F>,
1432            k: usize,
1433        ) -> Result<ClusterStabilityResult<F>> {
1434            let mut rng = match self.config.random_seed {
1435                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
1436                None => scirs2_core::random::rngs::StdRng::seed_from_u64(
1437                    scirs2_core::random::rng().random(),
1438                ),
1439            };
1440
1441            let n_samples = data.nrows();
1442            let subsample_size = ((n_samples as f64) * self.config.subsample_ratio) as usize;
1443
1444            let mut cluster_memberships: Vec<Vec<HashSet<usize>>> = vec![Vec::new(); k];
1445            let mut cluster_sizes: Vec<Vec<usize>> = vec![Vec::new(); k];
1446
1447            // Bootstrap sampling and clustering
1448            for _ in 0..self.config.n_bootstrap {
1449                let mut indices: Vec<usize> = (0..n_samples).collect();
1450                indices.shuffle(&mut rng);
1451                let sample_indices = &indices[..subsample_size];
1452                let sample_data = data.select(scirs2_core::ndarray::Axis(0), sample_indices);
1453
1454                match kmeans2(sample_data.view(), k, None, None, None, None, None, None) {
1455                    Ok((_, labels)) => {
1456                        // Track cluster memberships
1457                        for cluster_id in 0..k {
1458                            let mut cluster_members = HashSet::new();
1459                            for (local_idx, &label) in labels.iter().enumerate() {
1460                                if label == cluster_id {
1461                                    cluster_members.insert(sample_indices[local_idx]);
1462                                }
1463                            }
1464                            cluster_memberships[cluster_id].push(cluster_members.clone());
1465                            cluster_sizes[cluster_id].push(cluster_members.len());
1466                        }
1467                    }
1468                    Err(_) => continue,
1469                }
1470            }
1471
1472            // Compute stability for each cluster
1473            let mut cluster_stabilities = Vec::new();
1474            let mut size_consistency = Vec::new();
1475
1476            for cluster_id in 0..k {
1477                let stability = self.compute_cluster_stability(&cluster_memberships[cluster_id])?;
1478                cluster_stabilities.push(stability);
1479
1480                let consistency = self.compute_size_consistency(&cluster_sizes[cluster_id])?;
1481                size_consistency.push(consistency);
1482            }
1483
1484            // Compute statistics
1485            let mean_stability = cluster_stabilities
1486                .iter()
1487                .fold(F::zero(), |acc, &x| acc + x)
1488                / F::from(cluster_stabilities.len()).unwrap();
1489
1490            let variance = cluster_stabilities
1491                .iter()
1492                .map(|&x| (x - mean_stability) * (x - mean_stability))
1493                .fold(F::zero(), |acc, x| acc + x)
1494                / F::from(cluster_stabilities.len()).unwrap();
1495            let std_stability = variance.sqrt();
1496
1497            Ok(ClusterStabilityResult {
1498                cluster_stabilities,
1499                mean_stability,
1500                std_stability,
1501                size_consistency,
1502            })
1503        }
1504
1505        /// Compute stability for a single cluster across bootstrap samples
1506        fn compute_cluster_stability(&self, cluster_samples: &[HashSet<usize>]) -> Result<F> {
1507            if cluster_samples.len() < 2 {
1508                return Ok(F::zero());
1509            }
1510
1511            let mut jaccard_scores = Vec::new();
1512
1513            // Compute pairwise Jaccard similarities
1514            for i in 0..cluster_samples.len() {
1515                for j in (i + 1)..cluster_samples.len() {
1516                    let intersection_size =
1517                        cluster_samples[i].intersection(&cluster_samples[j]).count();
1518                    let union_size = cluster_samples[i].union(&cluster_samples[j]).count();
1519
1520                    if union_size > 0 {
1521                        let jaccard = intersection_size as f64 / union_size as f64;
1522                        jaccard_scores.push(F::from(jaccard).unwrap());
1523                    }
1524                }
1525            }
1526
1527            if jaccard_scores.is_empty() {
1528                return Ok(F::zero());
1529            }
1530
1531            // Return mean Jaccard similarity
1532            let sum: F = jaccard_scores.iter().fold(F::zero(), |acc, &x| acc + x);
1533            Ok(sum / F::from(jaccard_scores.len()).unwrap())
1534        }
1535
1536        /// Compute size consistency for a cluster across bootstrap samples
1537        fn compute_size_consistency(&self, sizes: &[usize]) -> Result<F> {
1538            if sizes.is_empty() {
1539                return Ok(F::zero());
1540            }
1541
1542            let mean_size = sizes.iter().sum::<usize>() as f64 / sizes.len() as f64;
1543            let variance = sizes
1544                .iter()
1545                .map(|&size| (size as f64 - mean_size).powi(2))
1546                .sum::<f64>()
1547                / sizes.len() as f64;
1548
1549            let cv = if mean_size > 0.0 {
1550                variance.sqrt() / mean_size
1551            } else {
1552                0.0
1553            };
1554            Ok(F::one() - F::from(cv).unwrap()) // Consistency = 1 - CV
1555        }
1556    }
1557
    /// Parameter Stability Analysis
    ///
    /// Assesses how sensitive clustering results are to parameter changes
    /// across different algorithm settings.
    pub struct ParameterStabilityAnalyzer<F: Float> {
        /// Base parameters for analysis (the reference number of clusters)
        pub base_k: usize,
        /// Parameter perturbation ranges, one analysis level per entry
        /// (NOTE(review): confirm whether these are absolute or relative offsets)
        pub perturbation_ranges: Vec<f64>,
        /// Number of random parameter samples per range
        pub n_samples_per_range: usize,
        /// Random seed for reproducible results (a random seed is drawn when `None`)
        pub random_seed: Option<u64>,
        // Marker tying the analyzer to the float type `F`.
        _phantom: std::marker::PhantomData<F>,
    }

    /// Results of parameter stability analysis
    #[derive(Debug, Clone)]
    pub struct ParameterStabilityResult<F: Float> {
        /// Stability scores for different perturbation levels
        pub stability_by_perturbation: Vec<F>,
        /// Parameter sensitivity profile
        pub sensitivity_profile: Vec<F>,
        /// Robust parameter range recommendation, presumably `(low, high)`
        /// (TODO: confirm ordering against `analyze_stability`)
        pub robust_range: (f64, f64),
    }
1583    }
1584
1585    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
1586        ParameterStabilityAnalyzer<F>
1587    {
1588        /// Create a new parameter stability analyzer
1589        pub fn new(
1590            base_k: usize,
1591            perturbation_ranges: Vec<f64>,
1592            n_samples_per_range: usize,
1593            random_seed: Option<u64>,
1594        ) -> Self {
1595            Self {
1596                base_k,
1597                perturbation_ranges,
1598                n_samples_per_range,
1599                random_seed,
1600                _phantom: std::marker::PhantomData,
1601            }
1602        }
1603
1604        /// Analyze parameter stability across perturbation ranges
1605        pub fn analyze_stability(
1606            &self,
1607            data: ArrayView2<F>,
1608        ) -> Result<ParameterStabilityResult<F>> {
1609            let mut rng = match self.random_seed {
1610                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
1611                None => scirs2_core::random::rngs::StdRng::seed_from_u64(
1612                    scirs2_core::random::rng().random(),
1613                ),
1614            };
1615
1616            let mut stability_by_perturbation = Vec::new();
1617            let mut sensitivity_profile = Vec::new();
1618
1619            // Get baseline clustering
1620            let baseline_result = kmeans2(data, self.base_k, None, None, None, None, None, None)?;
1621
1622            for &perturbation_level in &self.perturbation_ranges {
1623                let mut stability_scores = Vec::new();
1624
1625                for _ in 0..self.n_samples_per_range {
1626                    // Perturb parameters (here we vary k as an example)
1627                    let k_perturbation = (F::from(rng.random::<f64>()).unwrap()
1628                        - F::from(0.5).unwrap())
1629                        * F::from(2.0).unwrap()
1630                        * F::from(perturbation_level).unwrap();
1631                    let perturbed_k = (self.base_k as f64
1632                        * (1.0 + k_perturbation.to_f64().unwrap()))
1633                    .round()
1634                    .max(1.0) as usize;
1635
1636                    match kmeans2(data, perturbed_k, None, None, None, None, None, None) {
1637                        Ok((_, perturbed_labels)) => {
1638                            // Compute stability using ARI with baseline
1639                            // Convert usize labels to i32 for ARI computation
1640                            let baseline_i32 = baseline_result.1.mapv(|x| x as i32);
1641                            let perturbed_i32 = perturbed_labels.mapv(|x| x as i32);
1642                            match adjusted_rand_index(baseline_i32.view(), perturbed_i32.view()) {
1643                                Ok(stability) => stability_scores.push(stability),
1644                                Err(_) => continue,
1645                            }
1646                        }
1647                        Err(_) => continue,
1648                    }
1649                }
1650
1651                if !stability_scores.is_empty() {
1652                    let mean_stability = stability_scores.iter().fold(F::zero(), |acc, &x| acc + x)
1653                        / F::from(stability_scores.len()).unwrap();
1654                    stability_by_perturbation.push(mean_stability);
1655
1656                    // Compute sensitivity (1 - stability)
1657                    sensitivity_profile.push(F::one() - mean_stability);
1658                }
1659            }
1660
1661            // Find robust parameter range (where sensitivity is low)
1662            let robust_range = self.find_robust_range(&sensitivity_profile);
1663
1664            Ok(ParameterStabilityResult {
1665                stability_by_perturbation,
1666                sensitivity_profile,
1667                robust_range,
1668            })
1669        }
1670
1671        /// Find the range of perturbations with lowest sensitivity
1672        fn find_robust_range(&self, sensitivity_profile: &[F]) -> (f64, f64) {
1673            if sensitivity_profile.is_empty() {
1674                return (0.0, 0.0);
1675            }
1676
1677            // Find minimum sensitivity
1678            let min_sensitivity = sensitivity_profile
1679                .iter()
1680                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1681                .unwrap();
1682
1683            // Define threshold as min + 10% of range
1684            let max_sensitivity = sensitivity_profile
1685                .iter()
1686                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1687                .unwrap();
1688            let threshold =
1689                *min_sensitivity + (*max_sensitivity - *min_sensitivity) * F::from(0.1).unwrap();
1690
1691            // Find first and last indices below threshold
1692            let mut start_idx = None;
1693            let mut end_idx = None;
1694
1695            for (idx, &sensitivity) in sensitivity_profile.iter().enumerate() {
1696                if sensitivity <= threshold {
1697                    if start_idx.is_none() {
1698                        start_idx = Some(idx);
1699                    }
1700                    end_idx = Some(idx);
1701                }
1702            }
1703
1704            let start_range = start_idx
1705                .map(|idx| self.perturbation_ranges[idx])
1706                .unwrap_or(0.0);
1707            let end_range = end_idx
1708                .map(|idx| self.perturbation_ranges[idx])
1709                .unwrap_or(0.0);
1710
1711            (start_range, end_range)
1712        }
1713    }
1714}
1715
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Reshapes a flat, row-major buffer into a `rows x cols` matrix.
    fn points(rows: usize, cols: usize, flat: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), flat).unwrap()
    }

    /// Builds a configuration seeded with 42 that overrides only the bootstrap count and
    /// k range; all other fields keep their defaults.
    fn seeded_config(n_bootstrap: usize, k_range: Option<(usize, usize)>) -> StabilityConfig {
        StabilityConfig {
            n_bootstrap,
            k_range,
            random_seed: Some(42),
            ..Default::default()
        }
    }

    #[test]
    fn test_stability_config_default() {
        let StabilityConfig {
            n_bootstrap,
            subsample_ratio,
            random_seed,
            n_runs_per_bootstrap,
            ..
        } = StabilityConfig::default();
        assert_eq!(n_bootstrap, 100);
        assert_eq!(subsample_ratio, 0.8);
        assert_eq!(n_runs_per_bootstrap, 10);
        assert!(random_seed.is_none());
    }

    #[test]
    fn test_bootstrap_validator() {
        // 20 two-dimensional points whose coordinates ramp from 0.0 to 3.9.
        let ramp = points(20, 2, (0_u8..40).map(|i| f64::from(i) / 10.0).collect());
        let config = StabilityConfig {
            subsample_ratio: 0.8,
            n_runs_per_bootstrap: 3,
            ..seeded_config(5, None)
        };

        let validator = BootstrapValidator::new(config);
        let outcome = validator.assess_kmeans_stability(ramp.view(), 2);

        assert!(outcome.is_ok());
        let report = outcome.unwrap();
        assert!((0.0..=1.0).contains(&report.mean_stability));
        assert_eq!(report.bootstrap_matrix.shape(), &[20, 20]);
    }

    #[test]
    fn test_consensus_clusterer() {
        // Two tight, well-separated groups of three points each.
        let two_groups = points(
            6,
            2,
            vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2],
        );

        let clusterer = ConsensusClusterer::new(seeded_config(10, None));
        let outcome = clusterer.find_consensus_clusters(two_groups.view(), 2);

        assert!(outcome.is_ok());
        let labels = outcome.unwrap();
        assert_eq!(labels.len(), 6);

        // The consensus must use exactly two distinct cluster labels.
        let distinct = labels
            .iter()
            .copied()
            .collect::<std::collections::HashSet<_>>();
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn test_optimal_k_selector() {
        let four_groups = points(
            12,
            2,
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, // Cluster 1
                5.0, 5.0, 5.1, 5.1, 5.2, 5.2, // Cluster 2
                10.0, 10.0, 10.1, 10.1, 10.2, 10.2, // Cluster 3
                15.0, 15.0, 15.1, 15.1, 15.2, 15.2, // Cluster 4
            ],
        );

        let selector = OptimalKSelector::new(seeded_config(5, Some((2, 5))));
        let outcome = selector.find_optimal_k(four_groups.view());

        assert!(outcome.is_ok());
        let (best_k, scores) = outcome.unwrap();
        assert!((2..=5).contains(&best_k));
        assert_eq!(scores.len(), 4); // one score each for k = 2, 3, 4, 5
    }

    #[test]
    fn test_gap_statistic() {
        let two_groups = points(
            8,
            2,
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2, 5.3, 5.3,
            ],
        );

        let selector = OptimalKSelector::new(seeded_config(5, Some((2, 4))));
        let outcome = selector.gap_statistic(two_groups.view());

        assert!(outcome.is_ok());
        let (best_k, gap_scores) = outcome.unwrap();
        assert!((2..=4).contains(&best_k));
        assert_eq!(gap_scores.len(), 3); // one gap value each for k = 2, 3, 4
    }
}