1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::seq::SliceRandom;
10use scirs2_core::random::{Rng, SeedableRng};
11use std::collections::HashSet;
12use std::fmt::Debug;
13
14use crate::error::{ClusteringError, Result};
15use crate::metrics::adjusted_rand_index;
16use crate::vq::kmeans2;
17
/// Configuration shared by the bootstrap-based stability assessment tools.
#[derive(Debug, Clone)]
pub struct StabilityConfig {
    /// Number of bootstrap iterations to run.
    pub n_bootstrap: usize,
    /// Fraction of samples drawn (without replacement) per bootstrap subsample.
    pub subsample_ratio: f64,
    /// Optional RNG seed for reproducibility; a fixed fallback seed is used when `None`.
    pub random_seed: Option<u64>,
    /// Number of clustering runs performed on each bootstrap subsample.
    pub n_runs_per_bootstrap: usize,
    /// Optional inclusive `(k_min, k_max)` range for optimal-k searches.
    pub k_range: Option<(usize, usize)>,
}

/// Defaults: 100 bootstraps, 80% subsampling, 10 runs per bootstrap,
/// no fixed seed, no explicit k range.
impl Default for StabilityConfig {
    fn default() -> Self {
        StabilityConfig {
            k_range: None,
            n_runs_per_bootstrap: 10,
            random_seed: None,
            subsample_ratio: 0.8,
            n_bootstrap: 100,
        }
    }
}
44
/// Aggregated results of a stability assessment run.
#[derive(Debug, Clone)]
pub struct StabilityResult<F: Float> {
    /// One score per bootstrap iteration (mean pairwise ARI across the
    /// repeated clustering runs of that iteration).
    pub stability_scores: Vec<F>,
    /// Consensus labels, when computed (the bootstrap validator itself
    /// always leaves this `None`).
    pub consensus_labels: Option<Array1<usize>>,
    /// Optimal number of clusters, when determined (always `None` here).
    pub optimal_k: Option<usize>,
    /// Mean of `stability_scores`.
    pub mean_stability: F,
    /// Population standard deviation of `stability_scores`.
    pub std_stability: F,
    /// Pairwise sample stability matrix accumulated over bootstraps
    /// (shape and meaning depend on the producing assessor).
    pub bootstrap_matrix: Array2<F>,
}
61
/// Bootstrap-based validator that quantifies k-means stability by repeatedly
/// clustering random subsamples of the data and comparing the runs.
pub struct BootstrapValidator<F: Float> {
    // Bootstrap/stability parameters.
    config: StabilityConfig,
    // Marker tying the validator to the float type `F`; no `F` is stored.
    phantom: std::marker::PhantomData<F>,
}
71
impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
    BootstrapValidator<F>
{
    /// Creates a validator with the given stability configuration.
    pub fn new(config: StabilityConfig) -> Self {
        Self {
            config,
            phantom: std::marker::PhantomData,
        }
    }

    /// Assesses k-means stability for a fixed `k`.
    ///
    /// For each of `n_bootstrap` iterations, a subsample of
    /// `subsample_ratio * n_samples` rows is drawn without replacement and
    /// clustered `n_runs_per_bootstrap` times with fresh seeds. The
    /// per-iteration stability is the mean pairwise adjusted Rand index of
    /// those runs.
    ///
    /// # Errors
    /// Returns `ClusteringError::InvalidInput` if there are fewer than 2
    /// samples or the subsample would be smaller than `k`.
    pub fn assess_kmeans_stability(
        &self,
        data: ArrayView2<F>,
        k: usize,
    ) -> Result<StabilityResult<F>> {
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];

        if n_samples < 2 {
            return Err(ClusteringError::InvalidInput(
                "Need at least 2 samples for stability assessment".into(),
            ));
        }

        let subsample_size = ((n_samples as f64) * self.config.subsample_ratio) as usize;
        if subsample_size < k {
            return Err(ClusteringError::InvalidInput(
                "Subsample size must be at least k".into(),
            ));
        }

        // Fixed fallback seed (42) keeps unseeded runs reproducible.
        let mut rng = match self.config.random_seed {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => {
                scirs2_core::random::rngs::StdRng::seed_from_u64(42)
            }
        };

        let mut bootstrap_results = Vec::new();

        for _iteration in 0..self.config.n_bootstrap {
            // Subsample without replacement: shuffle all indices, keep a prefix.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            indices.shuffle(&mut rng);
            indices.truncate(subsample_size);

            let mut bootstrap_data = Array2::zeros((subsample_size, n_features));
            for (new_idx, &old_idx) in indices.iter().enumerate() {
                bootstrap_data.row_mut(new_idx).assign(&data.row(old_idx));
            }

            // Cluster the same subsample several times with different seeds.
            let mut run_labels = Vec::new();
            for _run in 0..self.config.n_runs_per_bootstrap {
                let seed = rng.random::<u64>();

                // NOTE(review): positional options — presumably max-iter 100,
                // then threshold/init/missing, check-finite=false, seed;
                // confirm against `kmeans2`'s signature.
                match kmeans2(
                    bootstrap_data.view(),
                    k,
                    Some(100),
                    None,
                    None,
                    None,
                    Some(false),
                    Some(seed),
                ) {
                    Ok((_, labels)) => {
                        let labels_usize: Array1<usize> = labels.mapv(|x| x);
                        run_labels.push(labels_usize);
                    }
                    Err(_) => {
                        // A failed run contributes an all-zero labeling rather
                        // than aborting the whole assessment.
                        let dummy_labels = Array1::zeros(subsample_size);
                        run_labels.push(dummy_labels);
                    }
                }
            }

            bootstrap_results.push((indices, run_labels));
        }

        let stability_scores = self.calculate_stability_scores(&bootstrap_results)?;
        // NOTE(review): if `stability_scores` is empty (e.g. when
        // n_runs_per_bootstrap < 2), these divisions are 0/0 and yield NaN.
        let mean_stability = stability_scores
            .iter()
            .copied()
            .fold(F::zero(), |acc, x| acc + x)
            / F::from(stability_scores.len()).unwrap();

        // Population variance (divides by N, not N-1).
        let variance = stability_scores
            .iter()
            .map(|&x| {
                let diff = x - mean_stability;
                diff * diff
            })
            .fold(F::zero(), |acc, x| acc + x)
            / F::from(stability_scores.len()).unwrap();
        let std_stability = variance.sqrt();

        let bootstrap_matrix = self.create_bootstrap_matrix(&bootstrap_results, n_samples)?;

        Ok(StabilityResult {
            stability_scores,
            consensus_labels: None,
            optimal_k: None,
            mean_stability,
            std_stability,
            bootstrap_matrix,
        })
    }

    /// Computes one score per bootstrap iteration: the mean adjusted Rand
    /// index over all pairs of runs on that subsample. Iterations with fewer
    /// than two runs are skipped entirely.
    fn calculate_stability_scores(
        &self,
        bootstrap_results: &[(Vec<usize>, Vec<Array1<usize>>)],
    ) -> Result<Vec<F>> {
        let mut scores = Vec::new();

        for (_, run_labels) in bootstrap_results {
            if run_labels.len() < 2 {
                continue;
            }

            let mut pairwise_aris = Vec::new();
            for i in 0..run_labels.len() {
                for j in (i + 1)..run_labels.len() {
                    let labels1 = run_labels[i].mapv(|x| x as i32);
                    let labels2 = run_labels[j].mapv(|x| x as i32);

                    // An ARI computation failure counts as zero agreement.
                    match adjusted_rand_index::<F>(labels1.view(), labels2.view()) {
                        Ok(ari) => pairwise_aris.push(ari),
                        Err(_) => pairwise_aris.push(F::zero()),
                    }
                }
            }

            if !pairwise_aris.is_empty() {
                let mean_ari = pairwise_aris
                    .iter()
                    .copied()
                    .fold(F::zero(), |acc, x| acc + x)
                    / F::from(pairwise_aris.len()).unwrap();
                scores.push(mean_ari);
            }
        }

        Ok(scores)
    }

    /// Builds an `n_samples x n_samples` matrix whose `(i, j)` entry is the
    /// fraction of bootstrap iterations containing both samples in which
    /// they landed in the same cluster.
    ///
    /// NOTE(review): only the first run of each bootstrap iteration is used
    /// here — the remaining `n_runs_per_bootstrap - 1` runs are ignored.
    fn create_bootstrap_matrix(
        &self,
        bootstrap_results: &[(Vec<usize>, Vec<Array1<usize>>)],
        n_samples: usize,
    ) -> Result<Array2<F>> {
        let mut co_occurrence_matrix: Array2<F> = Array2::zeros((n_samples, n_samples));
        let mut count_matrix: Array2<F> = Array2::zeros((n_samples, n_samples));

        for (indices, run_labels) in bootstrap_results {
            if run_labels.is_empty() {
                continue;
            }

            let labels = &run_labels[0];

            for (i, &idx_i) in indices.iter().enumerate() {
                for (j, &idx_j) in indices.iter().enumerate() {
                    if i != j {
                        count_matrix[[idx_i, idx_j]] = count_matrix[[idx_i, idx_j]] + F::one();

                        if labels[i] == labels[j] {
                            co_occurrence_matrix[[idx_i, idx_j]] =
                                co_occurrence_matrix[[idx_i, idx_j]] + F::one();
                        }
                    }
                }
            }
        }

        // Normalize co-occurrence by how often each pair was subsampled
        // together; pairs never seen together stay at zero.
        let mut stability_matrix = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in 0..n_samples {
                if count_matrix[[i, j]] > F::zero() {
                    stability_matrix[[i, j]] = co_occurrence_matrix[[i, j]] / count_matrix[[i, j]];
                }
            }
        }

        Ok(stability_matrix)
    }
}
272
/// Builds consensus cluster labels from many repeated k-means runs on the
/// full data set.
pub struct ConsensusClusterer<F: Float> {
    // n_bootstrap controls the number of clustering runs; random_seed seeds them.
    config: StabilityConfig,
    // Marker tying the clusterer to the float type `F`; no `F` is stored.
    phantom: std::marker::PhantomData<F>,
}
281
impl<F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display> ConsensusClusterer<F> {
    /// Creates a consensus clusterer with the given configuration.
    pub fn new(config: StabilityConfig) -> Self {
        Self {
            config,
            phantom: std::marker::PhantomData,
        }
    }

    /// Derives consensus labels by running k-means `n_bootstrap` times,
    /// accumulating a normalized co-assignment matrix, and extracting
    /// clusters whose members co-occur at least 50% of the time.
    ///
    /// # Errors
    /// `InvalidInput` for fewer than 2 samples; `ComputationError` when every
    /// clustering run fails.
    pub fn find_consensus_clusters(&self, data: ArrayView2<F>, k: usize) -> Result<Array1<usize>> {
        let n_samples = data.shape()[0];

        if n_samples < 2 {
            return Err(ClusteringError::InvalidInput(
                "Need at least 2 samples for consensus clustering".into(),
            ));
        }

        // Fixed fallback seed (42) keeps unseeded runs reproducible.
        let mut rng = match self.config.random_seed {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => {
                scirs2_core::random::rngs::StdRng::seed_from_u64(42)
            }
        };

        let mut all_labels = Vec::new();

        for _run in 0..self.config.n_bootstrap {
            let seed = rng.random::<u64>();

            match kmeans2(
                data,
                k,
                Some(100),
                None,
                None,
                None,
                Some(false),
                Some(seed),
            ) {
                Ok((_, labels)) => {
                    let labels_usize: Array1<usize> = labels.mapv(|x| x);
                    all_labels.push(labels_usize);
                }
                Err(_) => {
                    // Skip failed runs; consensus is built from successes only.
                    continue;
                }
            }
        }

        if all_labels.is_empty() {
            return Err(ClusteringError::ComputationError(
                "All clustering runs failed".into(),
            ));
        }

        // consensus_matrix[[i, j]] counts runs in which i and j shared a cluster.
        let mut consensus_matrix = Array2::zeros((n_samples, n_samples));

        for labels in &all_labels {
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if labels[i] == labels[j] {
                        consensus_matrix[[i, j]] = consensus_matrix[[i, j]] + F::one();
                    }
                }
            }
        }

        // Normalize counts into co-assignment frequencies in [0, 1].
        let n_runs = F::from(all_labels.len()).unwrap();
        consensus_matrix.mapv_inplace(|x| x / n_runs);

        let threshold = F::from(0.5).unwrap();
        self.extract_consensus_clusters(&consensus_matrix, threshold, k)
    }

    /// Greedily extracts up to `k` clusters from the consensus matrix, then
    /// attaches any leftover points to the cluster with the highest average
    /// consensus.
    ///
    /// NOTE(review): if no cluster could be formed (`current_cluster == 0`),
    /// all leftovers are assigned label 0 by default.
    fn extract_consensus_clusters(
        &self,
        consensus_matrix: &Array2<F>,
        threshold: F,
        k: usize,
    ) -> Result<Array1<usize>> {
        let n_samples = consensus_matrix.shape()[0];
        // usize::MAX marks "not yet assigned".
        let mut labels = Array1::from_elem(n_samples, usize::MAX);
        let mut current_cluster = 0;

        let mut unassigned: HashSet<usize> = (0..n_samples).collect();

        while current_cluster < k && !unassigned.is_empty() {
            // Seed the next cluster from the strongest remaining consensus pair.
            let mut best_consensus = F::zero();
            let mut best_seed = None;

            for &i in &unassigned {
                for &j in &unassigned {
                    if i != j && consensus_matrix[[i, j]] > best_consensus {
                        best_consensus = consensus_matrix[[i, j]];
                        best_seed = Some(i);
                    }
                }
            }

            if let Some(seed) = best_seed {
                let mut cluster_members = Vec::new();
                cluster_members.push(seed);

                // Recruit every unassigned point that co-occurs with the seed
                // at least `threshold` of the time.
                for &candidate in &unassigned {
                    if candidate != seed && consensus_matrix[[seed, candidate]] >= threshold {
                        cluster_members.push(candidate);
                    }
                }

                for &member in &cluster_members {
                    labels[member] = current_cluster;
                    unassigned.remove(&member);
                }

                current_cluster += 1;
            } else {
                // No strictly positive consensus left — stop forming clusters.
                break;
            }
        }

        // Attach leftovers to the formed cluster with the highest average
        // consensus to the point.
        for &unassigned_point in &unassigned {
            let mut best_cluster = 0;
            let mut best_avg_consensus = F::zero();

            for cluster_id in 0..current_cluster {
                let mut total_consensus = F::zero();
                let mut count = 0;

                for i in 0..n_samples {
                    if labels[i] == cluster_id {
                        total_consensus = total_consensus + consensus_matrix[[unassigned_point, i]];
                        count += 1;
                    }
                }

                if count > 0 {
                    let avg_consensus = total_consensus / F::from(count).unwrap();
                    if avg_consensus > best_avg_consensus {
                        best_avg_consensus = avg_consensus;
                        best_cluster = cluster_id;
                    }
                }
            }

            labels[unassigned_point] = best_cluster;
        }

        Ok(labels)
    }
}
448
/// Selects the number of clusters via stability maximization or a
/// gap-statistic-style criterion.
pub struct OptimalKSelector<F: Float> {
    // Supplies k_range, n_bootstrap, and seeding for the searches.
    config: StabilityConfig,
    // Marker tying the selector to the float type `F`; no `F` is stored.
    phantom: std::marker::PhantomData<F>,
}
454
impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
    OptimalKSelector<F>
{
    /// Creates a selector with the given stability configuration.
    pub fn new(config: StabilityConfig) -> Self {
        Self {
            config,
            phantom: std::marker::PhantomData,
        }
    }

    /// Returns the k with the highest mean bootstrap stability together with
    /// the per-k scores. The k range defaults to (2, 10) when unset.
    pub fn find_optimal_k(&self, data: ArrayView2<F>) -> Result<(usize, Vec<F>)> {
        let (k_min, k_max) = self.config.k_range.unwrap_or((2, 10));
        let mut stability_scores = Vec::new();

        for k in k_min..=k_max {
            let validator = BootstrapValidator::new(self.config.clone());
            match validator.assess_kmeans_stability(data, k) {
                Ok(result) => stability_scores.push(result.mean_stability),
                // A failed assessment scores zero rather than aborting.
                Err(_) => stability_scores.push(F::zero()),
            }
        }

        let mut best_k = k_min;
        let mut best_score = F::neg_infinity();

        for (i, &score) in stability_scores.iter().enumerate() {
            if score > best_score {
                best_score = score;
                best_k = k_min + i;
            }
        }

        Ok((best_k, stability_scores))
    }

    /// Gap-statistic-style k selection: for each k, compares the observed
    /// log within-cluster dispersion against its mean under uniform
    /// reference data sampled from the data's per-feature bounding box.
    ///
    /// NOTE(review): this uses a simplified stopping rule (first k where the
    /// gap stops increasing) without the usual standard-error term.
    pub fn gap_statistic(&self, data: ArrayView2<F>) -> Result<(usize, Vec<F>)> {
        let (k_min, k_max) = self.config.k_range.unwrap_or((2, 10));
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];

        let mut gap_scores = Vec::new();

        // Per-feature bounding box, used to sample uniform reference data.
        let mut min_vals = Array1::from_elem(n_features, F::infinity());
        let mut max_vals = Array1::from_elem(n_features, F::neg_infinity());

        for i in 0..n_samples {
            for j in 0..n_features {
                let val = data[[i, j]];
                if val < min_vals[j] {
                    min_vals[j] = val;
                }
                if val > max_vals[j] {
                    max_vals[j] = val;
                }
            }
        }

        for k in k_min..=k_max {
            // NOTE(review): ln() yields -inf when the dispersion is zero.
            let original_wk = self.calculate_within_cluster_dispersion(data, k)?;
            let log_wk = original_wk.ln();

            let mut reference_log_wks = Vec::new();
            // Fixed fallback seed (42) keeps unseeded runs reproducible.
            let mut rng = match self.config.random_seed {
                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
                None => {
                    scirs2_core::random::rngs::StdRng::seed_from_u64(42)
                }
            };

            for _b in 0..self.config.n_bootstrap {
                // Uniform reference sample drawn inside the bounding box.
                let mut reference_data = Array2::zeros((n_samples, n_features));
                for i in 0..n_samples {
                    for j in 0..n_features {
                        let range = max_vals[j] - min_vals[j];
                        let random_val =
                            min_vals[j] + range * F::from(rng.random::<f64>()).unwrap();
                        reference_data[[i, j]] = random_val;
                    }
                }

                let reference_wk =
                    self.calculate_within_cluster_dispersion(reference_data.view(), k)?;
                reference_log_wks.push(reference_wk.ln());
            }

            // Gap = E[log W_k (reference)] - log W_k (observed).
            let expected_log_wk = reference_log_wks
                .iter()
                .copied()
                .fold(F::zero(), |acc, x| acc + x)
                / F::from(reference_log_wks.len()).unwrap();
            let gap = expected_log_wk - log_wk;
            gap_scores.push(gap);
        }

        // First k whose gap is not smaller than the next one; defaults to k_min.
        let mut optimal_k = k_min;
        for i in 0..(gap_scores.len() - 1) {
            if gap_scores[i] >= gap_scores[i + 1] {
                optimal_k = k_min + i;
                break;
            }
        }

        Ok((optimal_k, gap_scores))
    }

    /// Clusters `data` into `k` groups and sums, over clusters with more
    /// than one member, the per-member-averaged squared distance to the
    /// cluster centroid. Singleton and empty clusters contribute zero.
    fn calculate_within_cluster_dispersion(&self, data: ArrayView2<F>, k: usize) -> Result<F> {
        match kmeans2(
            data,
            k,
            Some(100),
            None,
            None,
            None,
            Some(false),
            self.config.random_seed,
        ) {
            Ok((centroids, labels)) => {
                let mut total_dispersion = F::zero();

                for cluster_id in 0..k {
                    let mut cluster_dispersion = F::zero();
                    let mut cluster_size = 0;

                    for i in 0..data.shape()[0] {
                        if labels[i] == cluster_id {
                            // Squared Euclidean distance to the centroid.
                            let mut sq_dist = F::zero();
                            for j in 0..data.shape()[1] {
                                let diff = data[[i, j]] - centroids[[cluster_id, j]];
                                sq_dist = sq_dist + diff * diff;
                            }
                            cluster_dispersion = cluster_dispersion + sq_dist;
                            cluster_size += 1;
                        }
                    }

                    if cluster_size > 1 {
                        total_dispersion =
                            total_dispersion + cluster_dispersion / F::from(cluster_size).unwrap();
                    }
                }

                Ok(total_dispersion)
            }
            Err(e) => Err(e),
        }
    }
}
617
618pub mod advanced {
620 use super::*;
621 use crate::ensemble::{EnsembleClusterer, EnsembleConfig};
622 use crate::metrics::{mutual_info_score, silhouette_score};
623
    /// K-fold cross-validation based stability assessment: fits k-means on
    /// the training folds and scores how held-out samples map to the trained
    /// centroids.
    pub struct CrossValidationStability<F: Float> {
        // Shared stability parameters (not all fields are used here).
        config: StabilityConfig,
        // Number of folds the data is split into.
        n_folds: usize,
        _phantom: std::marker::PhantomData<F>,
    }
633
    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
        CrossValidationStability<F>
    {
        /// Creates a cross-validation stability assessor with `n_folds` folds.
        pub fn new(config: StabilityConfig, n_folds: usize) -> Self {
            Self {
                config,
                n_folds,
                _phantom: std::marker::PhantomData,
            }
        }

        /// Splits the data into contiguous folds; for each fold, fits k-means
        /// on the remaining samples, assigns the held-out samples to the
        /// nearest trained centroid, and scores the resulting labeling.
        ///
        /// NOTE(review): panics if `n_folds == 0` (division by zero). The
        /// returned `bootstrap_matrix` is allocated but never written (all
        /// zeros), and `train_labels` below is unused.
        pub fn assess_stability(
            &self,
            data: ArrayView2<F>,
            k: usize,
        ) -> Result<StabilityResult<F>> {
            let n_samples = data.shape()[0];
            let fold_size = n_samples / self.n_folds;
            let mut stability_scores = Vec::new();
            let mut bootstrap_matrix = Array2::zeros((self.n_folds, self.n_folds));

            for fold in 0..self.n_folds {
                // The last fold absorbs any remainder samples.
                let start_idx = fold * fold_size;
                let end_idx = if fold == self.n_folds - 1 {
                    n_samples
                } else {
                    (fold + 1) * fold_size
                };

                // Train on everything outside [start_idx, end_idx).
                let mut train_indices = Vec::new();
                for i in 0..n_samples {
                    if i < start_idx || i >= end_idx {
                        train_indices.push(i);
                    }
                }

                let train_data =
                    Array2::from_shape_fn((train_indices.len(), data.shape()[1]), |(i, j)| {
                        data[[train_indices[i], j]]
                    });

                let (train_centroids, train_labels) = kmeans2(
                    train_data.view(),
                    k,
                    Some(100),
                    Some(F::from(1e-6).unwrap()),
                    None,
                    None,
                    None,
                    Some(42),
                )?;

                // Nearest-centroid (Euclidean) assignment for the held-out fold.
                let test_labels = Array1::from_shape_fn(end_idx - start_idx, |i| {
                    let test_point = data.row(start_idx + i);
                    let mut min_dist = F::infinity();
                    let mut closest_cluster = 0;

                    for (cluster_id, centroid) in train_centroids.outer_iter().enumerate() {
                        let dist = test_point
                            .iter()
                            .zip(centroid.iter())
                            .map(|(a, b)| (*a - *b) * (*a - *b))
                            .sum::<F>()
                            .sqrt();

                        if dist < min_dist {
                            min_dist = dist;
                            closest_cluster = cluster_id;
                        }
                    }
                    closest_cluster
                });

                let stability = self.calculate_fold_stability(&test_labels, k)?;
                stability_scores.push(stability);
            }

            let mean_stability = stability_scores.iter().fold(F::zero(), |acc, x| acc + *x)
                / F::from(stability_scores.len()).unwrap();
            // Population variance (divides by N, not N-1).
            let variance = stability_scores
                .iter()
                .map(|&s| (s - mean_stability) * (s - mean_stability))
                .fold(F::zero(), |acc, x| acc + x)
                / F::from(stability_scores.len()).unwrap();
            let std_stability = variance.sqrt();

            Ok(StabilityResult {
                stability_scores,
                consensus_labels: None,
                optimal_k: None,
                mean_stability,
                std_stability,
                bootstrap_matrix,
            })
        }

        /// Scores a fold's labeling by within-cluster pair counts.
        ///
        /// NOTE(review): `cluster_cohesion` accumulates exactly the same
        /// quantity as `total_pairs`, so this returns 1 whenever any cluster
        /// has >= 2 members and 0 otherwise — the metric is degenerate and
        /// likely not what was intended; confirm the intended formula.
        fn calculate_fold_stability(&self, labels: &Array1<usize>, k: usize) -> Result<F> {
            let mut cluster_cohesion = F::zero();
            let mut total_pairs = 0;

            for cluster_id in 0..k {
                let cluster_members: Vec<_> = labels
                    .iter()
                    .enumerate()
                    .filter(|(_, &label)| label == cluster_id)
                    .map(|(idx_, _)| idx_)
                    .collect();

                let cluster_size = cluster_members.len();
                if cluster_size > 1 {
                    // Number of unordered within-cluster pairs.
                    let pairs = cluster_size * (cluster_size - 1) / 2;
                    cluster_cohesion = cluster_cohesion + F::from(pairs).unwrap();
                    total_pairs += pairs;
                }
            }

            if total_pairs == 0 {
                Ok(F::zero())
            } else {
                Ok(cluster_cohesion / F::from(total_pairs).unwrap())
            }
        }
    }
767
    /// Measures how much clustering results change when the input data is
    /// deliberately perturbed in configurable ways.
    pub struct PerturbationStability<F: Float> {
        // n_bootstrap controls how many perturbed replicates per type.
        config: StabilityConfig,
        // Perturbations applied; each is assessed independently.
        perturbation_types: Vec<PerturbationType>,
        _phantom: std::marker::PhantomData<F>,
    }
777
    /// A data perturbation applied before re-clustering.
    #[derive(Debug, Clone)]
    pub enum PerturbationType {
        /// Additive noise scaled by `std_dev` (see `apply_perturbation` for
        /// the exact sampling actually used).
        GaussianNoise { std_dev: f64 },
        /// Randomly remove `removal_rate` of the samples.
        SampleRemoval { removal_rate: f64 },
        /// Additive uniform noise in `[-noise_level, noise_level)`.
        FeatureNoise { noise_level: f64 },
        /// Overwrite random single entries with values scaled by
        /// `outlier_magnitude`; `outlier_rate` sets how many entries.
        OutlierInjection {
            outlier_rate: f64,
            outlier_magnitude: f64,
        },
    }
793
    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
        PerturbationStability<F>
    {
        /// Creates a perturbation-based stability assessor.
        pub fn new(config: StabilityConfig, perturbation_types: Vec<PerturbationType>) -> Self {
            Self {
                config,
                perturbation_types,
                _phantom: std::marker::PhantomData,
            }
        }

        /// Clusters the original data once as a baseline (fixed seed 42),
        /// then re-clusters `n_bootstrap` perturbed copies per perturbation
        /// type, scoring each replicate by ARI against the baseline labels.
        ///
        /// NOTE(review): `config.random_seed` is ignored here — a fresh
        /// thread RNG is used, so perturbations are not reproducible. The
        /// returned `bootstrap_matrix` is all zeros, `baseline_centroids` is
        /// unused, and `SampleRemoval` replicates always score 0 because the
        /// label lengths no longer match the baseline.
        pub fn assess_stability(
            &self,
            data: ArrayView2<F>,
            k: usize,
        ) -> Result<StabilityResult<F>> {
            let mut all_stability_scores = Vec::new();
            let mut rng = scirs2_core::random::rng();

            let (baseline_centroids, baseline_labels) = kmeans2(
                data,
                k,
                Some(100),
                Some(F::from(1e-6).unwrap()),
                None,
                None,
                None,
                Some(42),
            )?;

            for perturbation in &self.perturbation_types {
                let mut perturbation_scores = Vec::new();

                for _ in 0..self.config.n_bootstrap {
                    let perturbed_data = self.apply_perturbation(data, perturbation, &mut rng)?;

                    let (_, perturbed_labels) = kmeans2(
                        perturbed_data.view(),
                        k,
                        Some(100),
                        Some(F::from(1e-6).unwrap()),
                        None,
                        None,
                        None,
                        None,
                    )?;

                    let similarity =
                        self.calculate_label_similarity(&baseline_labels, &perturbed_labels)?;
                    perturbation_scores.push(similarity);
                }

                all_stability_scores.extend(perturbation_scores);
            }

            let mean_stability = all_stability_scores
                .iter()
                .fold(F::zero(), |acc, x| acc + *x)
                / F::from(all_stability_scores.len()).unwrap();
            // Population variance (divides by N, not N-1).
            let variance = all_stability_scores
                .iter()
                .map(|&s| (s - mean_stability) * (s - mean_stability))
                .sum::<F>()
                / F::from(all_stability_scores.len()).unwrap();
            let std_stability = variance.sqrt();

            let bootstrap_matrix =
                Array2::zeros((self.config.n_bootstrap, self.perturbation_types.len()));

            Ok(StabilityResult {
                stability_scores: all_stability_scores,
                consensus_labels: None,
                optimal_k: None,
                mean_stability,
                std_stability,
                bootstrap_matrix,
            })
        }

        /// Returns a perturbed copy of `data` according to `perturbation`.
        fn apply_perturbation(
            &self,
            data: ArrayView2<F>,
            perturbation: &PerturbationType,
            rng: &mut impl Rng,
        ) -> Result<Array2<F>> {
            let mut perturbed = data.to_owned();

            match perturbation {
                PerturbationType::GaussianNoise { std_dev } => {
                    // NOTE(review): this draws uniform [0, 1) noise scaled by
                    // std_dev — strictly non-negative and not Gaussian;
                    // confirm whether a normal distribution was intended.
                    for elem in perturbed.iter_mut() {
                        let noise = rng.random::<f64>() * std_dev;
                        *elem = *elem + F::from(noise).unwrap();
                    }
                }
                PerturbationType::SampleRemoval { removal_rate } => {
                    // Keep a random subset of rows, preserving original order.
                    let n_samples = data.shape()[0];
                    let n_remove = (n_samples as f64 * removal_rate) as usize;
                    let mut indices: Vec<_> = (0..n_samples).collect();
                    indices.shuffle(rng);
                    indices.truncate(n_samples - n_remove);
                    indices.sort();

                    let mut new_data = Array2::zeros((indices.len(), data.shape()[1]));
                    for (new_i, &old_i) in indices.iter().enumerate() {
                        new_data.row_mut(new_i).assign(&data.row(old_i));
                    }
                    perturbed = new_data;
                }
                PerturbationType::FeatureNoise { noise_level } => {
                    // Symmetric uniform noise in [-noise_level, noise_level).
                    for elem in perturbed.iter_mut() {
                        let noise = (rng.random::<f64>() - 0.5) * 2.0 * noise_level;
                        *elem = *elem + F::from(noise).unwrap();
                    }
                }
                PerturbationType::OutlierInjection {
                    outlier_rate,
                    outlier_magnitude,
                } => {
                    // Overwrite random single entries with uniform values in
                    // [0, outlier_magnitude).
                    let n_samples = data.shape()[0];
                    let n_outliers = (n_samples as f64 * outlier_rate) as usize;

                    for _ in 0..n_outliers {
                        let sample_idx = rng.random_range(0..n_samples);
                        let feature_idx = rng.random_range(0..data.shape()[1]);
                        let outlier_value = rng.random::<f64>() * outlier_magnitude;
                        perturbed[[sample_idx, feature_idx]] = F::from(outlier_value).unwrap();
                    }
                }
            }

            Ok(perturbed)
        }

        /// Adjusted Rand index between two labelings; returns zero when the
        /// lengths differ (e.g. after sample removal).
        fn calculate_label_similarity(
            &self,
            labels1: &Array1<usize>,
            labels2: &Array1<usize>,
        ) -> Result<F> {
            if labels1.len() != labels2.len() {
                return Ok(F::zero());
            }

            let labels1_i32: Array1<i32> = labels1.mapv(|x| x as i32);
            let labels2_i32: Array1<i32> = labels2.mapv(|x| x as i32);

            let ari: f64 = adjusted_rand_index(labels1_i32.view(), labels2_i32.view())?;
            Ok(F::from(ari).unwrap())
        }
    }
953
    /// Repeats bootstrap stability analysis across multiplicative rescalings
    /// of the data to probe sensitivity to feature scale.
    pub struct MultiScaleStability<F: Float> {
        // Parameters forwarded to the per-(scale, k) bootstrap validators.
        config: StabilityConfig,
        // Multiplicative factors applied to the whole data matrix.
        scale_factors: Vec<f64>,
        _phantom: std::marker::PhantomData<F>,
    }
963
964 impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
965 MultiScaleStability<F>
966 {
967 pub fn new(config: StabilityConfig, scale_factors: Vec<f64>) -> Self {
969 Self {
970 config,
971 scale_factors,
972 _phantom: std::marker::PhantomData,
973 }
974 }
975
976 pub fn assess_stability(
978 &self,
979 data: ArrayView2<F>,
980 k_range: (usize, usize),
981 ) -> Result<Vec<StabilityResult<F>>> {
982 let mut results = Vec::new();
983
984 for &scale_factor in &self.scale_factors {
985 let scaled_data = data.mapv(|x| x * F::from(scale_factor).unwrap());
987
988 for k in k_range.0..=k_range.1 {
990 let validator = BootstrapValidator::new(self.config.clone());
991 let stability_result =
992 validator.assess_kmeans_stability(scaled_data.view(), k)?;
993 results.push(stability_result);
994 }
995 }
996
997 Ok(results)
998 }
999
1000 pub fn find_optimal_scale_and_k(
1002 &self,
1003 data: ArrayView2<F>,
1004 k_range: (usize, usize),
1005 ) -> Result<(f64, usize, F)> {
1006 let results = self.assess_stability(data, k_range)?;
1007
1008 let mut best_scale = self.scale_factors[0];
1009 let mut best_k = k_range.0;
1010 let mut best_stability = F::neg_infinity();
1011
1012 let mut result_idx = 0;
1013 for &scale_factor in &self.scale_factors {
1014 for k in k_range.0..=k_range.1 {
1015 if result_idx < results.len() {
1016 let stability = results[result_idx].mean_stability;
1017 if stability > best_stability {
1018 best_stability = stability;
1019 best_scale = scale_factor;
1020 best_k = k;
1021 }
1022 result_idx += 1;
1023 }
1024 }
1025 }
1026
1027 Ok((best_scale, best_k, best_stability))
1028 }
1029 }
1030
    /// Prediction-strength analysis: judges a choice of k by how well
    /// clusters fitted on a training split predict pair co-membership on a
    /// held-out test split.
    pub struct PredictionStrength<F: Float> {
        /// Split/threshold parameters for the analysis.
        pub config: PredictionStrengthConfig,
        // Marker tying the evaluator to the float type `F`; no `F` is stored.
        phantom: std::marker::PhantomData<F>,
    }
1041
1042 #[derive(Debug, Clone)]
1044 pub struct PredictionStrengthConfig {
1045 pub n_bootstrap: usize,
1047 pub train_ratio: f64,
1049 pub strength_threshold: f64,
1051 pub random_seed: Option<u64>,
1053 }
1054
1055 impl Default for PredictionStrengthConfig {
1056 fn default() -> Self {
1057 Self {
1058 n_bootstrap: 50,
1059 train_ratio: 0.5,
1060 strength_threshold: 0.8,
1061 random_seed: None,
1062 }
1063 }
1064 }
1065
    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
        PredictionStrength<F>
    {
        /// Creates a prediction-strength evaluator.
        pub fn new(config: PredictionStrengthConfig) -> Self {
            Self {
                config,
                phantom: std::marker::PhantomData,
            }
        }

        /// Computes the prediction strength for each k in the inclusive range.
        pub fn assess_k_range(
            &self,
            data: ArrayView2<F>,
            k_range: (usize, usize),
        ) -> Result<Vec<F>> {
            let mut prediction_strengths = Vec::new();

            for k in k_range.0..=k_range.1 {
                let strength = self.compute_prediction_strength(data, k)?;
                prediction_strengths.push(strength);
            }

            Ok(prediction_strengths)
        }

        /// Averages, over `n_bootstrap` random train/test splits, the rate at
        /// which nearest-train-point cluster assignments correctly predict
        /// whether two test points share a cluster. Splits with failed
        /// clustering or an empty test side are skipped; returns 0 when every
        /// split is skipped.
        pub fn compute_prediction_strength(&self, data: ArrayView2<F>, k: usize) -> Result<F> {
            // Unseeded runs draw a one-off seed from the thread RNG.
            let mut rng = match self.config.random_seed {
                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
                None => scirs2_core::random::rngs::StdRng::seed_from_u64(
                    scirs2_core::random::rng().random(),
                ),
            };

            let n_samples = data.nrows();
            let train_size = ((n_samples as f64) * self.config.train_ratio) as usize;

            let mut prediction_scores = Vec::new();

            for _ in 0..self.config.n_bootstrap {
                // Random disjoint train/test partition of the row indices.
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(&mut rng);

                let train_indices = &indices[..train_size];
                let test_indices = &indices[train_size..];

                if test_indices.is_empty() {
                    continue;
                }

                let train_data = data.select(scirs2_core::ndarray::Axis(0), train_indices);
                let test_data = data.select(scirs2_core::ndarray::Axis(0), test_indices);

                match kmeans2(train_data.view(), k, None, None, None, None, None, None) {
                    Ok((_, train_labels)) => {
                        match kmeans2(test_data.view(), k, None, None, None, None, None, None) {
                            Ok((_, test_labels)) => {
                                let strength = self.compute_pairwise_prediction_strength(
                                    &train_data,
                                    &test_data,
                                    &train_labels,
                                    &test_labels,
                                )?;
                                prediction_scores.push(strength);
                            }
                            Err(_) => continue,
                        }
                    }
                    Err(_) => continue,
                }
            }

            if prediction_scores.is_empty() {
                return Ok(F::zero());
            }

            let sum: F = prediction_scores.iter().fold(F::zero(), |acc, &x| acc + x);
            Ok(sum / F::from(prediction_scores.len()).unwrap())
        }

        /// Fraction of test-point pairs whose nearest-train-neighbor cluster
        /// agreement matches their test-side cluster agreement. O(m^2 * n)
        /// in test size m and train size n.
        fn compute_pairwise_prediction_strength(
            &self,
            train_data: &Array2<F>,
            test_data: &Array2<F>,
            train_labels: &Array1<usize>,
            test_labels: &Array1<usize>,
        ) -> Result<F> {
            let test_size = test_data.nrows();
            let mut correct_predictions = 0;
            let mut total_predictions = 0;

            for i in 0..test_size {
                for j in (i + 1)..test_size {
                    let closest_train_i = self.find_closest_point(&test_data.row(i), train_data)?;
                    let closest_train_j = self.find_closest_point(&test_data.row(j), train_data)?;

                    let predicted_same =
                        train_labels[closest_train_i] == train_labels[closest_train_j];
                    let actual_same = test_labels[i] == test_labels[j];

                    if predicted_same == actual_same {
                        correct_predictions += 1;
                    }
                    total_predictions += 1;
                }
            }

            if total_predictions == 0 {
                return Ok(F::zero());
            }

            Ok(F::from(correct_predictions as f64 / total_predictions as f64).unwrap())
        }

        /// Index of the Euclidean-nearest row of `train_data` to `test_point`
        /// (linear scan; index 0 when `train_data` is empty).
        fn find_closest_point(
            &self,
            test_point: &scirs2_core::ndarray::ArrayView1<F>,
            train_data: &Array2<F>,
        ) -> Result<usize> {
            let mut min_distance = F::infinity();
            let mut closest_idx = 0;

            for (idx, train_point) in train_data.rows().into_iter().enumerate() {
                let distance = test_point
                    .iter()
                    .zip(train_point.iter())
                    .map(|(a, b)| (*a - *b) * (*a - *b))
                    .fold(F::zero(), |acc, x| acc + x)
                    .sqrt();

                if distance < min_distance {
                    min_distance = distance;
                    closest_idx = idx;
                }
            }

            Ok(closest_idx)
        }

        /// Largest k whose strength meets `strength_threshold`; falls back to
        /// the k with the maximum strength when none qualifies.
        pub fn find_optimal_k(
            &self,
            data: ArrayView2<F>,
            k_range: (usize, usize),
        ) -> Result<usize> {
            let strengths = self.assess_k_range(data, k_range)?;

            // Scan from the largest k down; prefer the biggest acceptable k.
            for (idx, &strength) in strengths.iter().enumerate().rev() {
                if strength >= F::from(self.config.strength_threshold).unwrap() {
                    return Ok(k_range.0 + idx);
                }
            }

            // Fallback: argmax of the strengths (NaNs compare as equal here).
            let best_idx = strengths
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx_, _)| idx_)
                .unwrap_or(0);

            Ok(k_range.0 + best_idx)
        }
    }
1244
    /// Jaccard-index based stability: clusters two overlapping random
    /// subsamples and compares pair co-membership on the samples they share.
    pub struct JaccardStability<F: Float> {
        /// Number of subsample pairs to evaluate.
        pub n_bootstrap: usize,
        /// Fraction of samples kept per subsample.
        pub subsample_ratio: f64,
        /// Optional RNG seed; a random seed is drawn when `None`.
        pub random_seed: Option<u64>,
        // Marker tying the assessor to the float type `F`; no `F` is stored.
        _phantom: std::marker::PhantomData<F>,
    }
1258
    impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
        JaccardStability<F>
    {
        /// Creates a Jaccard stability assessor.
        pub fn new(n_bootstrap: usize, subsample_ratio: f64, random_seed: Option<u64>) -> Self {
            Self {
                n_bootstrap,
                subsample_ratio,
                random_seed,
                _phantom: std::marker::PhantomData,
            }
        }

        /// Mean Jaccard similarity between clusterings of paired random
        /// subsamples, computed on the samples both subsamples contain.
        /// Pairs with fewer than two shared samples or a failed clustering
        /// are skipped; returns 0 when no pair could be scored.
        pub fn compute_stability(&self, data: ArrayView2<F>, k: usize) -> Result<F> {
            // Unseeded runs draw a one-off seed from the thread RNG.
            let mut rng = match self.random_seed {
                Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
                None => scirs2_core::random::rngs::StdRng::seed_from_u64(
                    scirs2_core::random::rng().random(),
                ),
            };

            let n_samples = data.nrows();
            let subsample_size = ((n_samples as f64) * self.subsample_ratio) as usize;

            let mut jaccard_scores = Vec::new();

            for _ in 0..self.n_bootstrap {
                // Two independent subsamples drawn without replacement.
                let mut indices1: Vec<usize> = (0..n_samples).collect();
                indices1.shuffle(&mut rng);
                let sample_indices1 = &indices1[..subsample_size];
                let sample_data1 = data.select(scirs2_core::ndarray::Axis(0), sample_indices1);

                let mut indices2: Vec<usize> = (0..n_samples).collect();
                indices2.shuffle(&mut rng);
                let sample_indices2 = &indices2[..subsample_size];
                let sample_data2 = data.select(scirs2_core::ndarray::Axis(0), sample_indices2);

                match (
                    kmeans2(sample_data1.view(), k, None, None, None, None, None, None),
                    kmeans2(sample_data2.view(), k, None, None, None, None, None, None),
                ) {
                    (Ok((_, labels1)), Ok((_, labels2))) => {
                        // Map each shared original sample to its position in
                        // both subsamples (quadratic scan per pair).
                        let overlap_indices: Vec<(usize, usize)> = sample_indices1
                            .iter()
                            .enumerate()
                            .filter_map(|(i1, &idx1)| {
                                sample_indices2
                                    .iter()
                                    .enumerate()
                                    .find(|(_, &idx2)| idx1 == idx2)
                                    .map(|(i2_, _)| (i1, i2_))
                            })
                            .collect();

                        if overlap_indices.len() >= 2 {
                            let jaccard = self.compute_jaccard_similarity(
                                &labels1,
                                &labels2,
                                &overlap_indices,
                            )?;
                            jaccard_scores.push(jaccard);
                        }
                    }
                    _ => continue,
                }
            }

            if jaccard_scores.is_empty() {
                return Ok(F::zero());
            }

            let sum: F = jaccard_scores.iter().fold(F::zero(), |acc, &x| acc + x);
            Ok(sum / F::from(jaccard_scores.len()).unwrap())
        }

        /// Jaccard index over pairs of overlap samples: |pairs co-clustered
        /// in both clusterings| / |pairs co-clustered in either|. Label ids
        /// need not match between clusterings — only co-membership matters.
        fn compute_jaccard_similarity(
            &self,
            labels1: &Array1<usize>,
            labels2: &Array1<usize>,
            overlap_indices: &[(usize, usize)],
        ) -> Result<F> {
            let mut same_cluster_both = 0;
            let mut same_cluster_either = 0;

            let n_overlap = overlap_indices.len();

            for i in 0..n_overlap {
                for j in (i + 1)..n_overlap {
                    let (idx1_i, idx2_i) = overlap_indices[i];
                    let (idx1_j, idx2_j) = overlap_indices[j];

                    let same_in_clustering1 = labels1[idx1_i] == labels1[idx1_j];
                    let same_in_clustering2 = labels2[idx2_i] == labels2[idx2_j];

                    if same_in_clustering1 && same_in_clustering2 {
                        same_cluster_both += 1;
                    }
                    if same_in_clustering1 || same_in_clustering2 {
                        same_cluster_either += 1;
                    }
                }
            }

            if same_cluster_either == 0 {
                // No pair was co-clustered in either run: treated as stable.
                return Ok(F::one());
            }

            Ok(F::from(same_cluster_both as f64 / same_cluster_either as f64).unwrap())
        }

        /// Jaccard stability for every k in the inclusive range.
        pub fn assess_k_range(
            &self,
            data: ArrayView2<F>,
            k_range: (usize, usize),
        ) -> Result<Vec<F>> {
            let mut stabilities = Vec::new();

            for k in k_range.0..=k_range.1 {
                let stability = self.compute_stability(data, k)?;
                stabilities.push(stability);
            }

            Ok(stabilities)
        }
    }
1393
/// Per-cluster stability analysis: measures how reproducibly each individual
/// cluster is recovered across bootstrap subsamples (as opposed to
/// whole-partition stability).
pub struct ClusterSpecificStability<F: Float> {
    /// Bootstrap/subsampling parameters shared with the other validators.
    pub config: StabilityConfig,
    // Ties the generic float type to the struct without storing a value.
    phantom: std::marker::PhantomData<F>,
}
1403
/// Results of a per-cluster stability assessment.
#[derive(Debug, Clone)]
pub struct ClusterStabilityResult<F: Float> {
    /// Mean pairwise Jaccard stability of each cluster (index = cluster id).
    pub cluster_stabilities: Vec<F>,
    /// Mean of `cluster_stabilities`.
    pub mean_stability: F,
    /// Population standard deviation of `cluster_stabilities`.
    pub std_stability: F,
    /// Size-consistency score per cluster: 1 minus the coefficient of
    /// variation of the cluster's size across bootstrap runs.
    pub size_consistency: Vec<F>,
}
1416
1417 impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
1418 ClusterSpecificStability<F>
1419 {
1420 pub fn new(config: StabilityConfig) -> Self {
1422 Self {
1423 config,
1424 phantom: std::marker::PhantomData,
1425 }
1426 }
1427
1428 pub fn assess_cluster_stability(
1430 &self,
1431 data: ArrayView2<F>,
1432 k: usize,
1433 ) -> Result<ClusterStabilityResult<F>> {
1434 let mut rng = match self.config.random_seed {
1435 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
1436 None => scirs2_core::random::rngs::StdRng::seed_from_u64(
1437 scirs2_core::random::rng().random(),
1438 ),
1439 };
1440
1441 let n_samples = data.nrows();
1442 let subsample_size = ((n_samples as f64) * self.config.subsample_ratio) as usize;
1443
1444 let mut cluster_memberships: Vec<Vec<HashSet<usize>>> = vec![Vec::new(); k];
1445 let mut cluster_sizes: Vec<Vec<usize>> = vec![Vec::new(); k];
1446
1447 for _ in 0..self.config.n_bootstrap {
1449 let mut indices: Vec<usize> = (0..n_samples).collect();
1450 indices.shuffle(&mut rng);
1451 let sample_indices = &indices[..subsample_size];
1452 let sample_data = data.select(scirs2_core::ndarray::Axis(0), sample_indices);
1453
1454 match kmeans2(sample_data.view(), k, None, None, None, None, None, None) {
1455 Ok((_, labels)) => {
1456 for cluster_id in 0..k {
1458 let mut cluster_members = HashSet::new();
1459 for (local_idx, &label) in labels.iter().enumerate() {
1460 if label == cluster_id {
1461 cluster_members.insert(sample_indices[local_idx]);
1462 }
1463 }
1464 cluster_memberships[cluster_id].push(cluster_members.clone());
1465 cluster_sizes[cluster_id].push(cluster_members.len());
1466 }
1467 }
1468 Err(_) => continue,
1469 }
1470 }
1471
1472 let mut cluster_stabilities = Vec::new();
1474 let mut size_consistency = Vec::new();
1475
1476 for cluster_id in 0..k {
1477 let stability = self.compute_cluster_stability(&cluster_memberships[cluster_id])?;
1478 cluster_stabilities.push(stability);
1479
1480 let consistency = self.compute_size_consistency(&cluster_sizes[cluster_id])?;
1481 size_consistency.push(consistency);
1482 }
1483
1484 let mean_stability = cluster_stabilities
1486 .iter()
1487 .fold(F::zero(), |acc, &x| acc + x)
1488 / F::from(cluster_stabilities.len()).unwrap();
1489
1490 let variance = cluster_stabilities
1491 .iter()
1492 .map(|&x| (x - mean_stability) * (x - mean_stability))
1493 .fold(F::zero(), |acc, x| acc + x)
1494 / F::from(cluster_stabilities.len()).unwrap();
1495 let std_stability = variance.sqrt();
1496
1497 Ok(ClusterStabilityResult {
1498 cluster_stabilities,
1499 mean_stability,
1500 std_stability,
1501 size_consistency,
1502 })
1503 }
1504
1505 fn compute_cluster_stability(&self, cluster_samples: &[HashSet<usize>]) -> Result<F> {
1507 if cluster_samples.len() < 2 {
1508 return Ok(F::zero());
1509 }
1510
1511 let mut jaccard_scores = Vec::new();
1512
1513 for i in 0..cluster_samples.len() {
1515 for j in (i + 1)..cluster_samples.len() {
1516 let intersection_size =
1517 cluster_samples[i].intersection(&cluster_samples[j]).count();
1518 let union_size = cluster_samples[i].union(&cluster_samples[j]).count();
1519
1520 if union_size > 0 {
1521 let jaccard = intersection_size as f64 / union_size as f64;
1522 jaccard_scores.push(F::from(jaccard).unwrap());
1523 }
1524 }
1525 }
1526
1527 if jaccard_scores.is_empty() {
1528 return Ok(F::zero());
1529 }
1530
1531 let sum: F = jaccard_scores.iter().fold(F::zero(), |acc, &x| acc + x);
1533 Ok(sum / F::from(jaccard_scores.len()).unwrap())
1534 }
1535
1536 fn compute_size_consistency(&self, sizes: &[usize]) -> Result<F> {
1538 if sizes.is_empty() {
1539 return Ok(F::zero());
1540 }
1541
1542 let mean_size = sizes.iter().sum::<usize>() as f64 / sizes.len() as f64;
1543 let variance = sizes
1544 .iter()
1545 .map(|&size| (size as f64 - mean_size).powi(2))
1546 .sum::<f64>()
1547 / sizes.len() as f64;
1548
1549 let cv = if mean_size > 0.0 {
1550 variance.sqrt() / mean_size
1551 } else {
1552 0.0
1553 };
1554 Ok(F::one() - F::from(cv).unwrap()) }
1556 }
1557
/// Analyzes how sensitive a clustering is to perturbations of the cluster
/// count `k`, by re-running k-means with randomly perturbed k values and
/// comparing each result against the baseline partition.
pub struct ParameterStabilityAnalyzer<F: Float> {
    /// Baseline number of clusters to perturb around.
    pub base_k: usize,
    /// Relative perturbation magnitudes to evaluate (e.g. 0.1 means k is
    /// perturbed by up to ±10%).
    pub perturbation_ranges: Vec<f64>,
    /// Number of perturbed runs per perturbation magnitude.
    pub n_samples_per_range: usize,
    /// Optional seed for reproducible perturbations.
    pub random_seed: Option<u64>,
    // Ties the generic float type to the struct without storing a value.
    _phantom: std::marker::PhantomData<F>,
}
1573
/// Results of a parameter (k) stability analysis.
#[derive(Debug, Clone)]
pub struct ParameterStabilityResult<F: Float> {
    /// Mean adjusted-Rand stability vs. the baseline, one entry per
    /// perturbation level that produced at least one successful run.
    pub stability_by_perturbation: Vec<F>,
    /// 1 - stability for the same levels (higher = more sensitive).
    pub sensitivity_profile: Vec<F>,
    /// (first, last) perturbation level whose sensitivity stays near the
    /// observed minimum.
    pub robust_range: (f64, f64),
}
1584
1585 impl<F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display>
1586 ParameterStabilityAnalyzer<F>
1587 {
1588 pub fn new(
1590 base_k: usize,
1591 perturbation_ranges: Vec<f64>,
1592 n_samples_per_range: usize,
1593 random_seed: Option<u64>,
1594 ) -> Self {
1595 Self {
1596 base_k,
1597 perturbation_ranges,
1598 n_samples_per_range,
1599 random_seed,
1600 _phantom: std::marker::PhantomData,
1601 }
1602 }
1603
1604 pub fn analyze_stability(
1606 &self,
1607 data: ArrayView2<F>,
1608 ) -> Result<ParameterStabilityResult<F>> {
1609 let mut rng = match self.random_seed {
1610 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
1611 None => scirs2_core::random::rngs::StdRng::seed_from_u64(
1612 scirs2_core::random::rng().random(),
1613 ),
1614 };
1615
1616 let mut stability_by_perturbation = Vec::new();
1617 let mut sensitivity_profile = Vec::new();
1618
1619 let baseline_result = kmeans2(data, self.base_k, None, None, None, None, None, None)?;
1621
1622 for &perturbation_level in &self.perturbation_ranges {
1623 let mut stability_scores = Vec::new();
1624
1625 for _ in 0..self.n_samples_per_range {
1626 let k_perturbation = (F::from(rng.random::<f64>()).unwrap()
1628 - F::from(0.5).unwrap())
1629 * F::from(2.0).unwrap()
1630 * F::from(perturbation_level).unwrap();
1631 let perturbed_k = (self.base_k as f64
1632 * (1.0 + k_perturbation.to_f64().unwrap()))
1633 .round()
1634 .max(1.0) as usize;
1635
1636 match kmeans2(data, perturbed_k, None, None, None, None, None, None) {
1637 Ok((_, perturbed_labels)) => {
1638 let baseline_i32 = baseline_result.1.mapv(|x| x as i32);
1641 let perturbed_i32 = perturbed_labels.mapv(|x| x as i32);
1642 match adjusted_rand_index(baseline_i32.view(), perturbed_i32.view()) {
1643 Ok(stability) => stability_scores.push(stability),
1644 Err(_) => continue,
1645 }
1646 }
1647 Err(_) => continue,
1648 }
1649 }
1650
1651 if !stability_scores.is_empty() {
1652 let mean_stability = stability_scores.iter().fold(F::zero(), |acc, &x| acc + x)
1653 / F::from(stability_scores.len()).unwrap();
1654 stability_by_perturbation.push(mean_stability);
1655
1656 sensitivity_profile.push(F::one() - mean_stability);
1658 }
1659 }
1660
1661 let robust_range = self.find_robust_range(&sensitivity_profile);
1663
1664 Ok(ParameterStabilityResult {
1665 stability_by_perturbation,
1666 sensitivity_profile,
1667 robust_range,
1668 })
1669 }
1670
1671 fn find_robust_range(&self, sensitivity_profile: &[F]) -> (f64, f64) {
1673 if sensitivity_profile.is_empty() {
1674 return (0.0, 0.0);
1675 }
1676
1677 let min_sensitivity = sensitivity_profile
1679 .iter()
1680 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1681 .unwrap();
1682
1683 let max_sensitivity = sensitivity_profile
1685 .iter()
1686 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1687 .unwrap();
1688 let threshold =
1689 *min_sensitivity + (*max_sensitivity - *min_sensitivity) * F::from(0.1).unwrap();
1690
1691 let mut start_idx = None;
1693 let mut end_idx = None;
1694
1695 for (idx, &sensitivity) in sensitivity_profile.iter().enumerate() {
1696 if sensitivity <= threshold {
1697 if start_idx.is_none() {
1698 start_idx = Some(idx);
1699 }
1700 end_idx = Some(idx);
1701 }
1702 }
1703
1704 let start_range = start_idx
1705 .map(|idx| self.perturbation_ranges[idx])
1706 .unwrap_or(0.0);
1707 let end_range = end_idx
1708 .map(|idx| self.perturbation_ranges[idx])
1709 .unwrap_or(0.0);
1710
1711 (start_range, end_range)
1712 }
1713 }
1714}
1715
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Default configuration carries the documented default values.
    #[test]
    fn test_stability_config_default() {
        let cfg = StabilityConfig::default();
        assert_eq!(cfg.n_bootstrap, 100);
        assert_eq!(cfg.subsample_ratio, 0.8);
        assert_eq!(cfg.n_runs_per_bootstrap, 10);
        assert!(cfg.random_seed.is_none());
    }

    /// Bootstrap stability on a small grid of points yields a score in [0, 1]
    /// and a bootstrap matrix of shape (n, n).
    #[test]
    fn test_bootstrap_validator() {
        let values: Vec<f64> = (0..40).map(|i| i as f64 / 10.0).collect();
        let data = Array2::from_shape_vec((20, 2), values).unwrap();

        let cfg = StabilityConfig {
            n_bootstrap: 5,
            subsample_ratio: 0.8,
            n_runs_per_bootstrap: 3,
            random_seed: Some(42),
            k_range: None,
        };

        let outcome = BootstrapValidator::new(cfg).assess_kmeans_stability(data.view(), 2);
        assert!(outcome.is_ok());

        let stability = outcome.unwrap();
        assert!(stability.mean_stability >= 0.0);
        assert!(stability.mean_stability <= 1.0);
        assert_eq!(stability.bootstrap_matrix.shape(), &[20, 20]);
    }

    /// Consensus clustering of two well-separated blobs recovers two labels.
    #[test]
    fn test_consensus_clusterer() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2],
        )
        .unwrap();

        let cfg = StabilityConfig {
            n_bootstrap: 10,
            random_seed: Some(42),
            ..Default::default()
        };

        let outcome = ConsensusClusterer::new(cfg).find_consensus_clusters(data.view(), 2);
        assert!(outcome.is_ok());

        let labels = outcome.unwrap();
        assert_eq!(labels.len(), 6);

        let distinct: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 2);
    }

    /// Optimal-k selection over four separated blobs returns a k in range and
    /// one score per candidate k.
    #[test]
    fn test_optimal_k_selector() {
        let data = Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2, 10.0, 10.0, 10.1,
                10.1, 10.2, 10.2, 15.0, 15.0, 15.1, 15.1, 15.2, 15.2,
            ],
        )
        .unwrap();

        let cfg = StabilityConfig {
            k_range: Some((2, 5)),
            n_bootstrap: 5,
            random_seed: Some(42),
            ..Default::default()
        };

        let outcome = OptimalKSelector::new(cfg).find_optimal_k(data.view());
        assert!(outcome.is_ok());

        let (optimal_k, scores) = outcome.unwrap();
        assert!((2..=5).contains(&optimal_k));
        assert_eq!(scores.len(), 4);
    }

    /// Gap statistic on two blobs returns a k in range and one gap score per
    /// candidate k.
    #[test]
    fn test_gap_statistic() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2, 5.3, 5.3,
            ],
        )
        .unwrap();

        let cfg = StabilityConfig {
            k_range: Some((2, 4)),
            n_bootstrap: 5,
            random_seed: Some(42),
            ..Default::default()
        };

        let outcome = OptimalKSelector::new(cfg).gap_statistic(data.view());
        assert!(outcome.is_ok());

        let (optimal_k, gap_scores) = outcome.unwrap();
        assert!((2..=4).contains(&optimal_k));
        assert_eq!(gap_scores.len(), 3);
    }
}