1use std::collections::HashMap;
30
31use scirs2_core::ndarray::Array2;
32use scirs2_core::random::Random;
33use sklears_core::error::{Result, SklearsError};
34use sklears_core::prelude::*;
35
/// A dataset composed of several feature matrices ("views") that describe
/// the same set of samples.
#[derive(Debug, Clone)]
pub struct MultiViewData {
    /// One feature matrix per view; all views share the same row count.
    pub views: Vec<Array2<f64>>,
    /// Optional human-readable name for each view (one per view).
    pub view_names: Option<Vec<String>>,
    /// Number of samples (rows) shared by every view.
    pub n_samples: usize,
}
46
47impl MultiViewData {
48 pub fn new(views: Vec<Array2<f64>>) -> Result<Self> {
50 if views.is_empty() {
51 return Err(SklearsError::InvalidInput(
52 "At least one view is required".to_string(),
53 ));
54 }
55
56 let n_samples = views[0].nrows();
57 for (i, view) in views.iter().enumerate() {
58 if view.nrows() != n_samples {
59 return Err(SklearsError::InvalidInput(format!(
60 "View {} has different number of samples",
61 i
62 )));
63 }
64 }
65
66 Ok(Self {
67 views,
68 view_names: None,
69 n_samples,
70 })
71 }
72
73 pub fn with_view_names(mut self, names: Vec<String>) -> Result<Self> {
75 if names.len() != self.views.len() {
76 return Err(SklearsError::InvalidInput(
77 "Number of view names must match number of views".to_string(),
78 ));
79 }
80 self.view_names = Some(names);
81 Ok(self)
82 }
83
84 pub fn n_views(&self) -> usize {
86 self.views.len()
87 }
88
89 pub fn get_view(&self, index: usize) -> Result<&Array2<f64>> {
91 self.views.get(index).ok_or_else(|| {
92 SklearsError::InvalidInput(format!("View index {} out of bounds", index))
93 })
94 }
95
96 pub fn view_dimensions(&self) -> Vec<usize> {
98 self.views.iter().map(|v| v.ncols()).collect()
99 }
100}
101
/// Configuration for [`MultiViewKMeans`].
#[derive(Debug, Clone)]
pub struct MultiViewKMeansConfig {
    /// Number of clusters (k).
    pub k_clusters: usize,
    /// Maximum number of assignment/update iterations.
    pub max_iter: usize,
    /// Convergence threshold: the fit stops once the fraction of samples
    /// whose label changed between iterations falls below this value.
    pub tolerance: f64,
    /// Optional explicit per-view weights; uniform weights are used when `None`.
    pub view_weights: Option<Vec<f64>>,
    /// Strategy for re-estimating view weights during fitting.
    pub weight_learning: WeightLearning,
    /// Seed for centroid initialization; a fixed default seed is used when `None`.
    pub random_seed: Option<u64>,
}
118
impl Default for MultiViewKMeansConfig {
    /// Defaults: 2 clusters, 100 iterations, 1e-4 tolerance, uniform fixed
    /// view weights, no explicit seed.
    fn default() -> Self {
        Self {
            k_clusters: 2,
            max_iter: 100,
            tolerance: 1e-4,
            view_weights: None,
            weight_learning: WeightLearning::Fixed,
            random_seed: None,
        }
    }
}
131
/// Strategy for (re-)estimating per-view weights while fitting.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightLearning {
    /// Keep the initial weights unchanged.
    Fixed,
    /// Weight each view inversely proportional to its clustering inertia.
    Adaptive,
    /// Weight each view inversely proportional to a label-entropy score.
    Entropy,
}
142
/// Weighted multi-view k-means clusterer; see [`MultiViewKMeansConfig`].
pub struct MultiViewKMeans {
    // Fitting parameters supplied at construction time.
    config: MultiViewKMeansConfig,
}
147
/// Result of fitting a [`MultiViewKMeans`] model.
pub struct MultiViewKMeansFitted {
    /// Cluster assignment for each sample.
    pub labels: Vec<i32>,
    /// One `(k, n_features)` centroid matrix per view.
    pub centroids: Vec<Array2<f64>>,
    /// Final per-view weights.
    pub view_weights: Vec<f64>,
    /// Iteration count reported by `fit`.
    pub n_iterations: usize,
    /// Per-view sum of squared distances from samples to their centroids.
    pub view_inertias: Vec<f64>,
    /// View-weight-combined total inertia.
    pub total_inertia: f64,
}
163
164impl MultiViewKMeans {
165 pub fn new(config: MultiViewKMeansConfig) -> Self {
167 Self { config }
168 }
169
170 pub fn fit(&self, data: &MultiViewData) -> Result<MultiViewKMeansFitted> {
172 let n_views = data.n_views();
173 let n_samples = data.n_samples;
174 let k = self.config.k_clusters;
175
176 if k > n_samples {
177 return Err(SklearsError::InvalidInput(
178 "Number of clusters cannot exceed number of samples".to_string(),
179 ));
180 }
181
182 let mut view_weights = if let Some(weights) = &self.config.view_weights {
184 if weights.len() != n_views {
185 return Err(SklearsError::InvalidInput(
186 "View weights length must match number of views".to_string(),
187 ));
188 }
189 weights.clone()
190 } else {
191 vec![1.0 / n_views as f64; n_views]
192 };
193
194 let mut centroids = self.initialize_centroids(data)?;
196
197 let mut labels = vec![0; n_samples];
199 let mut prev_labels = vec![-1; n_samples];
200
201 let rng = match self.config.random_seed {
202 Some(seed) => Random::seed(seed),
203 None => Random::seed(42),
204 };
205
206 for iteration in 0..self.config.max_iter {
207 for i in 0..n_samples {
209 let mut min_distance = f64::INFINITY;
210 let mut best_cluster = 0;
211
212 for k_idx in 0..k {
213 let mut total_distance = 0.0;
214
215 for (v, view) in data.views.iter().enumerate() {
216 let point = view.row(i);
217 let centroid = centroids[v].row(k_idx);
218 let distance: f64 = point
219 .iter()
220 .zip(centroid.iter())
221 .map(|(a, b)| (a - b).powi(2))
222 .sum();
223 total_distance += view_weights[v] * distance;
224 }
225
226 if total_distance < min_distance {
227 min_distance = total_distance;
228 best_cluster = k_idx;
229 }
230 }
231
232 labels[i] = best_cluster as i32;
233 }
234
235 for v in 0..n_views {
237 let view = &data.views[v];
238 let n_features = view.ncols();
239
240 for k_idx in 0..k {
241 let cluster_points: Vec<usize> = labels
242 .iter()
243 .enumerate()
244 .filter(|(_, &label)| label == k_idx as i32)
245 .map(|(i, _)| i)
246 .collect();
247
248 if !cluster_points.is_empty() {
249 let mut new_centroid = vec![0.0; n_features];
250 for &point_idx in &cluster_points {
251 for j in 0..n_features {
252 new_centroid[j] += view[[point_idx, j]];
253 }
254 }
255 for val in new_centroid.iter_mut() {
256 *val /= cluster_points.len() as f64;
257 }
258
259 for j in 0..n_features {
260 centroids[v][[k_idx, j]] = new_centroid[j];
261 }
262 }
263 }
264 }
265
266 if self.config.weight_learning != WeightLearning::Fixed {
268 view_weights =
269 self.update_view_weights(data, &labels, ¢roids, &view_weights)?;
270 }
271
272 if self.has_converged(&labels, &prev_labels) {
274 break;
275 }
276
277 prev_labels = labels.clone();
278 }
279
280 let view_inertias = self.compute_view_inertias(data, &labels, ¢roids);
282 let total_inertia = view_inertias
283 .iter()
284 .zip(view_weights.iter())
285 .map(|(inertia, weight)| inertia * weight)
286 .sum();
287
288 Ok(MultiViewKMeansFitted {
289 labels,
290 centroids,
291 view_weights,
292 n_iterations: self.config.max_iter,
293 view_inertias,
294 total_inertia,
295 })
296 }
297
298 fn initialize_centroids(&self, data: &MultiViewData) -> Result<Vec<Array2<f64>>> {
300 let n_views = data.n_views();
301 let k = self.config.k_clusters;
302 let mut centroids = Vec::new();
303
304 let mut rng = match self.config.random_seed {
305 Some(seed) => Random::seed(seed),
306 None => Random::seed(42),
307 };
308
309 for v in 0..n_views {
310 let view = &data.views[v];
311 let n_features = view.ncols();
312 let n_samples = view.nrows();
313
314 let mut view_centroids = Array2::zeros((k, n_features));
315
316 let mut selected_indices = (0..n_samples).collect::<Vec<_>>();
318 for i in (1..selected_indices.len()).rev() {
320 let j = rng.gen_range(0..i + 1);
321 selected_indices.swap(i, j);
322 }
323
324 for (k_idx, &sample_idx) in selected_indices.iter().take(k).enumerate() {
325 for j in 0..n_features {
326 view_centroids[[k_idx, j]] = view[[sample_idx, j]];
327 }
328 }
329
330 centroids.push(view_centroids);
331 }
332
333 Ok(centroids)
334 }
335
336 fn update_view_weights(
338 &self,
339 data: &MultiViewData,
340 labels: &[i32],
341 centroids: &[Array2<f64>],
342 current_weights: &[f64],
343 ) -> Result<Vec<f64>> {
344 let n_views = data.n_views();
345 let mut new_weights = vec![0.0; n_views];
346
347 match self.config.weight_learning {
348 WeightLearning::Fixed => Ok(current_weights.to_vec()),
349 WeightLearning::Adaptive => {
350 let view_inertias = self.compute_view_inertias(data, labels, centroids);
352 let total_inv_inertia: f64 = view_inertias
353 .iter()
354 .map(|&inertia| 1.0 / (inertia + 1e-8))
355 .sum();
356
357 for v in 0..n_views {
358 new_weights[v] = (1.0 / (view_inertias[v] + 1e-8)) / total_inv_inertia;
359 }
360
361 Ok(new_weights)
362 }
363 WeightLearning::Entropy => {
364 for v in 0..n_views {
366 let entropy = self.compute_view_entropy(data, labels, v);
367 new_weights[v] = 1.0 / (entropy + 1e-8);
368 }
369
370 let total_weight: f64 = new_weights.iter().sum();
372 for weight in new_weights.iter_mut() {
373 *weight /= total_weight;
374 }
375
376 Ok(new_weights)
377 }
378 }
379 }
380
381 fn compute_view_inertias(
383 &self,
384 data: &MultiViewData,
385 labels: &[i32],
386 centroids: &[Array2<f64>],
387 ) -> Vec<f64> {
388 let n_views = data.n_views();
389 let mut view_inertias = vec![0.0; n_views];
390
391 for v in 0..n_views {
392 let view = &data.views[v];
393 let mut inertia = 0.0;
394
395 for (i, &label) in labels.iter().enumerate() {
396 let point = view.row(i);
397 let centroid = centroids[v].row(label as usize);
398 let distance: f64 = point
399 .iter()
400 .zip(centroid.iter())
401 .map(|(a, b)| (a - b).powi(2))
402 .sum();
403 inertia += distance;
404 }
405
406 view_inertias[v] = inertia;
407 }
408
409 view_inertias
410 }
411
412 fn compute_view_entropy(&self, data: &MultiViewData, labels: &[i32], view_index: usize) -> f64 {
414 let mut cluster_counts = HashMap::new();
416 for &label in labels {
417 *cluster_counts.entry(label).or_insert(0) += 1;
418 }
419
420 let total_points = labels.len() as f64;
421 let mut entropy = 0.0;
422
423 for count in cluster_counts.values() {
424 let p = *count as f64 / total_points;
425 if p > 0.0 {
426 entropy -= p * p.log2();
427 }
428 }
429
430 entropy
431 }
432
433 fn has_converged(&self, current_labels: &[i32], prev_labels: &[i32]) -> bool {
435 if current_labels.len() != prev_labels.len() {
436 return false;
437 }
438
439 let changes = current_labels
440 .iter()
441 .zip(prev_labels.iter())
442 .filter(|(curr, prev)| curr != prev)
443 .count();
444
445 (changes as f64 / current_labels.len() as f64) < self.config.tolerance
446 }
447}
448
// Integration with the shared sklears `Estimator` trait.
impl Estimator for MultiViewKMeans {
    type Config = MultiViewKMeansConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrows the clusterer's configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
458
/// Configuration for [`ConsensusClustering`].
#[derive(Debug, Clone)]
pub struct ConsensusClusteringConfig {
    /// Names of the base algorithms to run on every view
    /// (currently "kmeans" and "spectral" are accepted).
    pub base_algorithms: Vec<String>,
    /// Number of clusters requested from each base run.
    pub k_clusters: usize,
    /// How individual results are merged into the consensus matrix.
    pub consensus_method: ConsensusMethod,
    /// How each base clustering result is weighted.
    pub view_weighting: ViewWeighting,
    /// Seed for the base clustering runs; a fixed default is used when `None`.
    pub random_seed: Option<u64>,
}
473
impl Default for ConsensusClusteringConfig {
    /// Defaults: kmeans + spectral bases, 2 clusters, voting consensus,
    /// equal weighting, no explicit seed.
    fn default() -> Self {
        Self {
            base_algorithms: vec!["kmeans".to_string(), "spectral".to_string()],
            k_clusters: 2,
            consensus_method: ConsensusMethod::Voting,
            view_weighting: ViewWeighting::Equal,
            random_seed: None,
        }
    }
}
485
/// Method used to combine base clusterings into a consensus matrix.
///
/// NOTE(review): all variants currently share the same weighted
/// co-association accumulation in `compute_consensus_matrix`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConsensusMethod {
    /// Voting on pairwise co-membership.
    Voting,
    /// Weighted co-association of sample pairs.
    CoAssociation,
    /// Evidence accumulation over all base results.
    EvidenceAccumulation,
}
496
/// Strategy for weighting each base clustering result in the consensus.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ViewWeighting {
    /// Every result contributes equally.
    Equal,
    /// Weight by a label-entropy-based quality score.
    Quality,
    /// Weight by average pairwise disagreement with the other results.
    Diversity,
}
507
/// Ensemble clusterer that merges multiple base clusterings (one per
/// view/algorithm pair) into a single consensus partition.
pub struct ConsensusClustering {
    // Ensemble parameters supplied at construction time.
    config: ConsensusClusteringConfig,
}
512
/// Result of fitting a [`ConsensusClustering`] ensemble.
pub struct ConsensusClusteringFitted {
    /// Final consensus cluster label for each sample.
    pub labels: Vec<i32>,
    /// Labels from each base (view, algorithm) run, in generation order.
    pub individual_results: Vec<Vec<i32>>,
    /// `(n_samples, n_samples)` weighted co-membership matrix.
    pub consensus_matrix: Array2<f64>,
    /// Weight assigned to each base result.
    pub view_weights: Vec<f64>,
    /// Agreement of each base result with the consensus
    /// (1 minus pairwise disagreement).
    pub agreement_scores: Vec<f64>,
}
526
527impl ConsensusClustering {
528 pub fn new(config: ConsensusClusteringConfig) -> Self {
530 Self { config }
531 }
532
533 pub fn fit(&self, data: &MultiViewData) -> Result<ConsensusClusteringFitted> {
535 let n_views = data.n_views();
536 let n_samples = data.n_samples;
537
538 let mut individual_results = Vec::new();
540
541 for v in 0..n_views {
542 for algorithm in &self.config.base_algorithms {
543 let view_data = &data.views[v];
544 let labels = self.run_base_clustering(view_data, algorithm)?;
545 individual_results.push(labels);
546 }
547 }
548
549 let view_weights = self.compute_view_weights(&individual_results)?;
551
552 let consensus_matrix = self.compute_consensus_matrix(&individual_results, &view_weights)?;
554
555 let labels = self.generate_consensus_clustering(&consensus_matrix)?;
557
558 let agreement_scores = self.compute_agreement_scores(&individual_results, &labels);
560
561 Ok(ConsensusClusteringFitted {
562 labels,
563 individual_results,
564 consensus_matrix,
565 view_weights,
566 agreement_scores,
567 })
568 }
569
570 fn run_base_clustering(&self, data: &Array2<f64>, algorithm: &str) -> Result<Vec<i32>> {
572 let n_samples = data.nrows();
573 let k = self.config.k_clusters;
574
575 match algorithm {
576 "kmeans" => {
577 let mut rng = match self.config.random_seed {
579 Some(seed) => Random::seed(seed),
580 None => Random::seed(42),
581 };
582
583 let mut labels = vec![0; n_samples];
584 for i in 0..n_samples {
585 labels[i] = rng.gen_range(0..k) as i32;
586 }
587
588 Ok(labels)
590 }
591 "spectral" => {
592 let mut rng = match self.config.random_seed {
594 Some(seed) => Random::seed(seed),
595 None => Random::seed(42),
596 };
597
598 let mut labels = vec![0; n_samples];
599 for i in 0..n_samples {
600 labels[i] = rng.gen_range(0..k) as i32;
601 }
602
603 Ok(labels)
604 }
605 _ => Err(SklearsError::InvalidInput(format!(
606 "Unsupported algorithm: {}",
607 algorithm
608 ))),
609 }
610 }
611
612 fn compute_view_weights(&self, results: &[Vec<i32>]) -> Result<Vec<f64>> {
614 let n_results = results.len();
615
616 match self.config.view_weighting {
617 ViewWeighting::Equal => Ok(vec![1.0 / n_results as f64; n_results]),
618 ViewWeighting::Quality => {
619 let mut weights = vec![0.0; n_results];
621 for (i, labels) in results.iter().enumerate() {
622 weights[i] = self.compute_clustering_quality(labels);
623 }
624
625 let total_weight: f64 = weights.iter().sum();
627 if total_weight > 0.0 {
628 for weight in weights.iter_mut() {
629 *weight /= total_weight;
630 }
631 }
632
633 Ok(weights)
634 }
635 ViewWeighting::Diversity => {
636 let mut weights = vec![1.0; n_results];
638
639 for i in 0..n_results {
640 let mut diversity_score = 0.0;
641 for j in 0..n_results {
642 if i != j {
643 diversity_score +=
644 self.compute_clustering_distance(&results[i], &results[j]);
645 }
646 }
647 weights[i] = diversity_score / (n_results - 1) as f64;
648 }
649
650 let total_weight: f64 = weights.iter().sum();
652 if total_weight > 0.0 {
653 for weight in weights.iter_mut() {
654 *weight /= total_weight;
655 }
656 }
657
658 Ok(weights)
659 }
660 }
661 }
662
663 fn compute_consensus_matrix(
665 &self,
666 results: &[Vec<i32>],
667 weights: &[f64],
668 ) -> Result<Array2<f64>> {
669 if results.is_empty() {
670 return Err(SklearsError::InvalidInput(
671 "No clustering results provided".to_string(),
672 ));
673 }
674
675 let n_samples = results[0].len();
676 let mut consensus = Array2::zeros((n_samples, n_samples));
677
678 match self.config.consensus_method {
679 ConsensusMethod::CoAssociation => {
680 for (result_idx, labels) in results.iter().enumerate() {
682 let weight = weights[result_idx];
683
684 for i in 0..n_samples {
685 for j in i..n_samples {
686 if labels[i] == labels[j] {
687 consensus[[i, j]] += weight;
688 consensus[[j, i]] += weight;
689 }
690 }
691 }
692 }
693 }
694 ConsensusMethod::Voting | ConsensusMethod::EvidenceAccumulation => {
695 for (result_idx, labels) in results.iter().enumerate() {
697 let weight = weights[result_idx];
698
699 for i in 0..n_samples {
700 for j in i..n_samples {
701 if labels[i] == labels[j] {
702 consensus[[i, j]] += weight;
703 consensus[[j, i]] += weight;
704 }
705 }
706 }
707 }
708 }
709 }
710
711 Ok(consensus)
712 }
713
714 fn generate_consensus_clustering(&self, consensus_matrix: &Array2<f64>) -> Result<Vec<i32>> {
716 let n_samples = consensus_matrix.nrows();
717
718 let mut labels = vec![0; n_samples];
722 let mut current_cluster = 0;
723 let mut visited = vec![false; n_samples];
724
725 for i in 0..n_samples {
726 if !visited[i] {
727 let mut cluster_members = vec![i];
729 visited[i] = true;
730
731 let mut stack = vec![i];
733 while let Some(point) = stack.pop() {
734 for j in 0..n_samples {
735 if !visited[j] && consensus_matrix[[point, j]] > 0.5 {
736 visited[j] = true;
737 cluster_members.push(j);
738 stack.push(j);
739 }
740 }
741 }
742
743 for &member in &cluster_members {
745 labels[member] = current_cluster;
746 }
747 current_cluster += 1;
748 }
749 }
750
751 Ok(labels)
752 }
753
754 fn compute_clustering_quality(&self, labels: &[i32]) -> f64 {
756 let mut cluster_counts = HashMap::new();
758 for &label in labels {
759 *cluster_counts.entry(label).or_insert(0) += 1;
760 }
761
762 let total = labels.len() as f64;
764 let mut entropy = 0.0;
765
766 for count in cluster_counts.values() {
767 let p = *count as f64 / total;
768 if p > 0.0 {
769 entropy -= p * p.log2();
770 }
771 }
772
773 entropy
774 }
775
776 fn compute_clustering_distance(&self, labels1: &[i32], labels2: &[i32]) -> f64 {
778 if labels1.len() != labels2.len() {
779 return 0.0;
780 }
781
782 let n_samples = labels1.len();
783 let mut disagreements = 0;
784
785 for i in 0..n_samples {
786 for j in (i + 1)..n_samples {
787 let same_cluster_1 = labels1[i] == labels1[j];
788 let same_cluster_2 = labels2[i] == labels2[j];
789
790 if same_cluster_1 != same_cluster_2 {
791 disagreements += 1;
792 }
793 }
794 }
795
796 disagreements as f64 / ((n_samples * (n_samples - 1)) / 2) as f64
797 }
798
799 fn compute_agreement_scores(&self, results: &[Vec<i32>], consensus: &[i32]) -> Vec<f64> {
801 results
802 .iter()
803 .map(|labels| 1.0 - self.compute_clustering_distance(labels, consensus))
804 .collect()
805 }
806}
807
// Integration with the shared sklears `Estimator` trait.
impl Estimator for ConsensusClustering {
    type Config = ConsensusClusteringConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrows the clusterer's configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
817
// Removed an unnecessary `#[allow(non_snake_case)]`: every identifier in
// this module is already snake_case.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Valid construction: view count, shared sample count, per-view dims.
    #[test]
    fn test_multi_view_data_creation() {
        let view1 = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let view2 = array![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];

        let multi_view_data = MultiViewData::new(vec![view1, view2]).unwrap();

        assert_eq!(multi_view_data.n_views(), 2);
        assert_eq!(multi_view_data.n_samples, 3);
        assert_eq!(multi_view_data.view_dimensions(), vec![2, 3]);
    }

    /// Views with differing row counts must be rejected.
    #[test]
    fn test_multi_view_data_mismatched_samples() {
        let view1 = array![[1.0, 2.0], [3.0, 4.0]];
        let view2 = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];

        let result = MultiViewData::new(vec![view1, view2]);
        assert!(result.is_err());
    }

    /// A fully specified k-means config is stored unchanged.
    #[test]
    fn test_multi_view_kmeans_config() {
        let config = MultiViewKMeansConfig {
            k_clusters: 3,
            max_iter: 50,
            tolerance: 1e-3,
            view_weights: Some(vec![0.6, 0.4]),
            weight_learning: WeightLearning::Adaptive,
            random_seed: Some(42),
        };

        let clusterer = MultiViewKMeans::new(config);
        assert_eq!(clusterer.config.k_clusters, 3);
    }

    /// A fully specified consensus config is stored unchanged.
    #[test]
    fn test_consensus_clustering_creation() {
        let config = ConsensusClusteringConfig {
            base_algorithms: vec!["kmeans".to_string()],
            k_clusters: 2,
            consensus_method: ConsensusMethod::CoAssociation,
            view_weighting: ViewWeighting::Quality,
            random_seed: Some(42),
        };

        let clusterer = ConsensusClustering::new(config);
        assert_eq!(clusterer.config.k_clusters, 2);
    }
}