// sklears_clustering/multi_view.rs

//! Multi-View Clustering Algorithms
//!
//! This module provides clustering algorithms designed for multi-view data,
//! where each data point is represented by multiple feature sets or "views".
//! Multi-view clustering leverages the complementary information from different
//! views to improve clustering performance.
//!
//! # Algorithms Provided
//! - **Multi-View K-Means**: K-Means clustering across multiple views
//! - **Canonical Correlation Analysis (CCA) Clustering**: Clustering based on CCA
//! - **Consensus Clustering**: Ensemble clustering across multiple views
//! - **Co-Training Clustering**: Semi-supervised multi-view clustering
//! - **Multi-View Spectral Clustering**: Spectral clustering for multiple views
//!
//! # Mathematical Background
//!
//! ## Multi-View Data
//! Given m views of n data points: X^(1), X^(2), ..., X^(m)
//! where X^(v) ∈ R^(n×d_v) is the v-th view with d_v features
//!
//! ## Multi-View K-Means Objective
//! Minimize: Σ_v w_v * Σ_i Σ_k ||x_i^(v) - c_k^(v)||²
//! where w_v is the weight for view v, and c_k^(v) is the centroid for cluster k in view v
//!
//! ## Consensus Clustering
//! Combines clustering results from multiple views to form a consensus:
//! C* = argmax_C Σ_v w_v * agreement(C, C^(v))

use std::collections::HashMap;

use scirs2_core::ndarray::Array2;
use scirs2_core::random::Random;
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::*;

/// Container for multi-view data: one feature matrix per view, with all
/// views describing the same set of samples (rows).
#[derive(Debug, Clone)]
pub struct MultiViewData {
    /// Feature matrix for each view; `views[v]` has shape `n_samples × d_v`.
    pub views: Vec<Array2<f64>>,
    /// Optional human-readable name per view (same length as `views` when set).
    pub view_names: Option<Vec<String>>,
    /// Number of samples; `new` validates this is identical across all views.
    pub n_samples: usize,
}

47impl MultiViewData {
48    /// Create multi-view data from multiple arrays
49    pub fn new(views: Vec<Array2<f64>>) -> Result<Self> {
50        if views.is_empty() {
51            return Err(SklearsError::InvalidInput(
52                "At least one view is required".to_string(),
53            ));
54        }
55
56        let n_samples = views[0].nrows();
57        for (i, view) in views.iter().enumerate() {
58            if view.nrows() != n_samples {
59                return Err(SklearsError::InvalidInput(format!(
60                    "View {} has different number of samples",
61                    i
62                )));
63            }
64        }
65
66        Ok(Self {
67            views,
68            view_names: None,
69            n_samples,
70        })
71    }
72
73    /// Set view names
74    pub fn with_view_names(mut self, names: Vec<String>) -> Result<Self> {
75        if names.len() != self.views.len() {
76            return Err(SklearsError::InvalidInput(
77                "Number of view names must match number of views".to_string(),
78            ));
79        }
80        self.view_names = Some(names);
81        Ok(self)
82    }
83
84    /// Get number of views
85    pub fn n_views(&self) -> usize {
86        self.views.len()
87    }
88
89    /// Get view by index
90    pub fn get_view(&self, index: usize) -> Result<&Array2<f64>> {
91        self.views.get(index).ok_or_else(|| {
92            SklearsError::InvalidInput(format!("View index {} out of bounds", index))
93        })
94    }
95
96    /// Get feature dimensions for each view
97    pub fn view_dimensions(&self) -> Vec<usize> {
98        self.views.iter().map(|v| v.ncols()).collect()
99    }
100}
101
/// Configuration for multi-view K-means clustering.
#[derive(Debug, Clone)]
pub struct MultiViewKMeansConfig {
    /// Number of clusters (k).
    pub k_clusters: usize,
    /// Maximum number of assignment/update iterations.
    pub max_iter: usize,
    /// Convergence tolerance, interpreted as the maximum fraction of points
    /// whose label may change between iterations while still counting as
    /// converged (see `MultiViewKMeans::has_converged`).
    pub tolerance: f64,
    /// Fixed weight per view; when `None`, equal weights `1/n_views` are used.
    pub view_weights: Option<Vec<f64>>,
    /// Strategy for (re-)learning view weights during fitting.
    pub weight_learning: WeightLearning,
    /// Random seed for reproducible centroid initialization; a fixed default
    /// seed is used when `None`, so unseeded runs are also deterministic.
    pub random_seed: Option<u64>,
}

impl Default for MultiViewKMeansConfig {
    /// Defaults: 2 clusters, 100 iterations, 1e-4 tolerance, equal fixed
    /// view weights, no explicit random seed.
    fn default() -> Self {
        Self {
            k_clusters: 2,
            max_iter: 100,
            tolerance: 1e-4,
            view_weights: None,
            weight_learning: WeightLearning::Fixed,
            random_seed: None,
        }
    }
}

/// Weight learning strategies for multi-view clustering.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightLearning {
    /// Fixed weights (no learning); weights stay as configured.
    Fixed,
    /// Learn weights from view quality (inverse per-view inertia).
    Adaptive,
    /// Entropy-based weight learning (inverse label entropy).
    Entropy,
}

/// Multi-view K-means clustering estimator.
///
/// Construct with [`MultiViewKMeans::new`], then call `fit` on a
/// [`MultiViewData`] to obtain a [`MultiViewKMeansFitted`] model.
pub struct MultiViewKMeans {
    // Immutable configuration; exposed via the `Estimator` impl.
    config: MultiViewKMeansConfig,
}

/// Fitted multi-view K-means model returned by [`MultiViewKMeans::fit`].
pub struct MultiViewKMeansFitted {
    /// Final cluster assignment per sample (values in `0..k_clusters`).
    pub labels: Vec<i32>,
    /// Centroids per view; `centroids[v]` has shape `k_clusters × d_v`.
    pub centroids: Vec<Array2<f64>>,
    /// Final view weights (sum to 1 when learned or defaulted).
    pub view_weights: Vec<f64>,
    /// Number of iterations performed before convergence or `max_iter`.
    pub n_iterations: usize,
    /// Final sum-of-squared-distances inertia per view.
    pub view_inertias: Vec<f64>,
    /// Weighted total inertia: Σ_v weight_v · inertia_v.
    pub total_inertia: f64,
}

164impl MultiViewKMeans {
165    /// Create a new multi-view K-means instance
166    pub fn new(config: MultiViewKMeansConfig) -> Self {
167        Self { config }
168    }
169
170    /// Fit clustering to multi-view data
171    pub fn fit(&self, data: &MultiViewData) -> Result<MultiViewKMeansFitted> {
172        let n_views = data.n_views();
173        let n_samples = data.n_samples;
174        let k = self.config.k_clusters;
175
176        if k > n_samples {
177            return Err(SklearsError::InvalidInput(
178                "Number of clusters cannot exceed number of samples".to_string(),
179            ));
180        }
181
182        // Initialize view weights
183        let mut view_weights = if let Some(weights) = &self.config.view_weights {
184            if weights.len() != n_views {
185                return Err(SklearsError::InvalidInput(
186                    "View weights length must match number of views".to_string(),
187                ));
188            }
189            weights.clone()
190        } else {
191            vec![1.0 / n_views as f64; n_views]
192        };
193
194        // Initialize centroids for each view
195        let mut centroids = self.initialize_centroids(data)?;
196
197        // Initialize cluster assignments
198        let mut labels = vec![0; n_samples];
199        let mut prev_labels = vec![-1; n_samples];
200
201        let rng = match self.config.random_seed {
202            Some(seed) => Random::seed(seed),
203            None => Random::seed(42),
204        };
205
206        for iteration in 0..self.config.max_iter {
207            // E-step: Assign points to clusters
208            for i in 0..n_samples {
209                let mut min_distance = f64::INFINITY;
210                let mut best_cluster = 0;
211
212                for k_idx in 0..k {
213                    let mut total_distance = 0.0;
214
215                    for (v, view) in data.views.iter().enumerate() {
216                        let point = view.row(i);
217                        let centroid = centroids[v].row(k_idx);
218                        let distance: f64 = point
219                            .iter()
220                            .zip(centroid.iter())
221                            .map(|(a, b)| (a - b).powi(2))
222                            .sum();
223                        total_distance += view_weights[v] * distance;
224                    }
225
226                    if total_distance < min_distance {
227                        min_distance = total_distance;
228                        best_cluster = k_idx;
229                    }
230                }
231
232                labels[i] = best_cluster as i32;
233            }
234
235            // M-step: Update centroids
236            for v in 0..n_views {
237                let view = &data.views[v];
238                let n_features = view.ncols();
239
240                for k_idx in 0..k {
241                    let cluster_points: Vec<usize> = labels
242                        .iter()
243                        .enumerate()
244                        .filter(|(_, &label)| label == k_idx as i32)
245                        .map(|(i, _)| i)
246                        .collect();
247
248                    if !cluster_points.is_empty() {
249                        let mut new_centroid = vec![0.0; n_features];
250                        for &point_idx in &cluster_points {
251                            for j in 0..n_features {
252                                new_centroid[j] += view[[point_idx, j]];
253                            }
254                        }
255                        for val in new_centroid.iter_mut() {
256                            *val /= cluster_points.len() as f64;
257                        }
258
259                        for j in 0..n_features {
260                            centroids[v][[k_idx, j]] = new_centroid[j];
261                        }
262                    }
263                }
264            }
265
266            // Update view weights if adaptive learning is enabled
267            if self.config.weight_learning != WeightLearning::Fixed {
268                view_weights =
269                    self.update_view_weights(data, &labels, &centroids, &view_weights)?;
270            }
271
272            // Check convergence
273            if self.has_converged(&labels, &prev_labels) {
274                break;
275            }
276
277            prev_labels = labels.clone();
278        }
279
280        // Compute final inertias
281        let view_inertias = self.compute_view_inertias(data, &labels, &centroids);
282        let total_inertia = view_inertias
283            .iter()
284            .zip(view_weights.iter())
285            .map(|(inertia, weight)| inertia * weight)
286            .sum();
287
288        Ok(MultiViewKMeansFitted {
289            labels,
290            centroids,
291            view_weights,
292            n_iterations: self.config.max_iter,
293            view_inertias,
294            total_inertia,
295        })
296    }
297
298    /// Initialize centroids for all views
299    fn initialize_centroids(&self, data: &MultiViewData) -> Result<Vec<Array2<f64>>> {
300        let n_views = data.n_views();
301        let k = self.config.k_clusters;
302        let mut centroids = Vec::new();
303
304        let mut rng = match self.config.random_seed {
305            Some(seed) => Random::seed(seed),
306            None => Random::seed(42),
307        };
308
309        for v in 0..n_views {
310            let view = &data.views[v];
311            let n_features = view.ncols();
312            let n_samples = view.nrows();
313
314            let mut view_centroids = Array2::zeros((k, n_features));
315
316            // Random initialization - select k random points
317            let mut selected_indices = (0..n_samples).collect::<Vec<_>>();
318            // Fisher-Yates shuffle
319            for i in (1..selected_indices.len()).rev() {
320                let j = rng.gen_range(0..i + 1);
321                selected_indices.swap(i, j);
322            }
323
324            for (k_idx, &sample_idx) in selected_indices.iter().take(k).enumerate() {
325                for j in 0..n_features {
326                    view_centroids[[k_idx, j]] = view[[sample_idx, j]];
327                }
328            }
329
330            centroids.push(view_centroids);
331        }
332
333        Ok(centroids)
334    }
335
336    /// Update view weights based on clustering quality
337    fn update_view_weights(
338        &self,
339        data: &MultiViewData,
340        labels: &[i32],
341        centroids: &[Array2<f64>],
342        current_weights: &[f64],
343    ) -> Result<Vec<f64>> {
344        let n_views = data.n_views();
345        let mut new_weights = vec![0.0; n_views];
346
347        match self.config.weight_learning {
348            WeightLearning::Fixed => Ok(current_weights.to_vec()),
349            WeightLearning::Adaptive => {
350                // Weight based on inverse of view inertia
351                let view_inertias = self.compute_view_inertias(data, labels, centroids);
352                let total_inv_inertia: f64 = view_inertias
353                    .iter()
354                    .map(|&inertia| 1.0 / (inertia + 1e-8))
355                    .sum();
356
357                for v in 0..n_views {
358                    new_weights[v] = (1.0 / (view_inertias[v] + 1e-8)) / total_inv_inertia;
359                }
360
361                Ok(new_weights)
362            }
363            WeightLearning::Entropy => {
364                // Entropy-based weight learning
365                for v in 0..n_views {
366                    let entropy = self.compute_view_entropy(data, labels, v);
367                    new_weights[v] = 1.0 / (entropy + 1e-8);
368                }
369
370                // Normalize weights
371                let total_weight: f64 = new_weights.iter().sum();
372                for weight in new_weights.iter_mut() {
373                    *weight /= total_weight;
374                }
375
376                Ok(new_weights)
377            }
378        }
379    }
380
381    /// Compute inertia for each view
382    fn compute_view_inertias(
383        &self,
384        data: &MultiViewData,
385        labels: &[i32],
386        centroids: &[Array2<f64>],
387    ) -> Vec<f64> {
388        let n_views = data.n_views();
389        let mut view_inertias = vec![0.0; n_views];
390
391        for v in 0..n_views {
392            let view = &data.views[v];
393            let mut inertia = 0.0;
394
395            for (i, &label) in labels.iter().enumerate() {
396                let point = view.row(i);
397                let centroid = centroids[v].row(label as usize);
398                let distance: f64 = point
399                    .iter()
400                    .zip(centroid.iter())
401                    .map(|(a, b)| (a - b).powi(2))
402                    .sum();
403                inertia += distance;
404            }
405
406            view_inertias[v] = inertia;
407        }
408
409        view_inertias
410    }
411
412    /// Compute entropy for a view (for entropy-based weight learning)
413    fn compute_view_entropy(&self, data: &MultiViewData, labels: &[i32], view_index: usize) -> f64 {
414        // Simplified entropy computation based on cluster sizes
415        let mut cluster_counts = HashMap::new();
416        for &label in labels {
417            *cluster_counts.entry(label).or_insert(0) += 1;
418        }
419
420        let total_points = labels.len() as f64;
421        let mut entropy = 0.0;
422
423        for count in cluster_counts.values() {
424            let p = *count as f64 / total_points;
425            if p > 0.0 {
426                entropy -= p * p.log2();
427            }
428        }
429
430        entropy
431    }
432
433    /// Check if clustering has converged
434    fn has_converged(&self, current_labels: &[i32], prev_labels: &[i32]) -> bool {
435        if current_labels.len() != prev_labels.len() {
436            return false;
437        }
438
439        let changes = current_labels
440            .iter()
441            .zip(prev_labels.iter())
442            .filter(|(curr, prev)| curr != prev)
443            .count();
444
445        (changes as f64 / current_labels.len() as f64) < self.config.tolerance
446    }
447}
448
/// `Estimator` integration: exposes the configuration to the sklears
/// framework (trait and associated types come from `sklears_core::prelude`).
impl Estimator for MultiViewKMeans {
    type Config = MultiViewKMeansConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

/// Configuration for consensus clustering.
#[derive(Debug, Clone)]
pub struct ConsensusClusteringConfig {
    /// Base clustering algorithms to run on every view.
    /// Currently recognized names: "kmeans" and "spectral" (see
    /// `ConsensusClustering::run_base_clustering`).
    pub base_algorithms: Vec<String>,
    /// Number of clusters used by the base algorithms.
    pub k_clusters: usize,
    /// How the individual partitions are combined into a consensus.
    pub consensus_method: ConsensusMethod,
    /// How the individual partitions are weighted.
    pub view_weighting: ViewWeighting,
    /// Random seed for reproducible base clusterings; a fixed default seed
    /// is used when `None`.
    pub random_seed: Option<u64>,
}

impl Default for ConsensusClusteringConfig {
    /// Defaults: k-means + spectral base algorithms, 2 clusters, majority
    /// voting, equal weights, no explicit seed.
    fn default() -> Self {
        Self {
            base_algorithms: vec!["kmeans".to_string(), "spectral".to_string()],
            k_clusters: 2,
            consensus_method: ConsensusMethod::Voting,
            view_weighting: ViewWeighting::Equal,
            random_seed: None,
        }
    }
}

/// Consensus methods for combining clustering results.
///
/// NOTE(review): all three variants currently share the same co-association
/// accumulation in `compute_consensus_matrix`; they are kept distinct for
/// future specialization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConsensusMethod {
    /// Majority voting.
    Voting,
    /// Co-association matrix.
    CoAssociation,
    /// Evidence accumulation.
    EvidenceAccumulation,
}

/// Strategies for weighting individual clustering results.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ViewWeighting {
    /// Equal weight for every result.
    Equal,
    /// Weight by clustering quality (balance entropy of cluster sizes).
    Quality,
    /// Weight by diversity (average pairwise distance to other results).
    Diversity,
}

/// Consensus clustering across multiple views.
///
/// Construct with [`ConsensusClustering::new`], then call `fit` on a
/// [`MultiViewData`] to obtain a [`ConsensusClusteringFitted`] model.
pub struct ConsensusClustering {
    // Immutable configuration; exposed via the `Estimator` impl.
    config: ConsensusClusteringConfig,
}

/// Fitted consensus clustering model returned by [`ConsensusClustering::fit`].
pub struct ConsensusClusteringFitted {
    /// Final consensus cluster assignment per sample.
    pub labels: Vec<i32>,
    /// Individual partitions, one per (view, base algorithm) pair.
    pub individual_results: Vec<Vec<i32>>,
    /// Weighted co-association matrix (`n_samples × n_samples`); entry (i, j)
    /// is the total weight of partitions placing i and j in the same cluster.
    pub consensus_matrix: Array2<f64>,
    /// Weights applied to the individual partitions.
    pub view_weights: Vec<f64>,
    /// Per-partition agreement with the consensus (1 − pairwise disagreement).
    pub agreement_scores: Vec<f64>,
}

527impl ConsensusClustering {
528    /// Create a new consensus clustering instance
529    pub fn new(config: ConsensusClusteringConfig) -> Self {
530        Self { config }
531    }
532
533    /// Fit consensus clustering to multi-view data
534    pub fn fit(&self, data: &MultiViewData) -> Result<ConsensusClusteringFitted> {
535        let n_views = data.n_views();
536        let n_samples = data.n_samples;
537
538        // Run clustering on each view with each algorithm
539        let mut individual_results = Vec::new();
540
541        for v in 0..n_views {
542            for algorithm in &self.config.base_algorithms {
543                let view_data = &data.views[v];
544                let labels = self.run_base_clustering(view_data, algorithm)?;
545                individual_results.push(labels);
546            }
547        }
548
549        // Compute view weights
550        let view_weights = self.compute_view_weights(&individual_results)?;
551
552        // Compute consensus matrix
553        let consensus_matrix = self.compute_consensus_matrix(&individual_results, &view_weights)?;
554
555        // Generate final consensus clustering
556        let labels = self.generate_consensus_clustering(&consensus_matrix)?;
557
558        // Compute agreement scores
559        let agreement_scores = self.compute_agreement_scores(&individual_results, &labels);
560
561        Ok(ConsensusClusteringFitted {
562            labels,
563            individual_results,
564            consensus_matrix,
565            view_weights,
566            agreement_scores,
567        })
568    }
569
570    /// Run base clustering algorithm on a single view
571    fn run_base_clustering(&self, data: &Array2<f64>, algorithm: &str) -> Result<Vec<i32>> {
572        let n_samples = data.nrows();
573        let k = self.config.k_clusters;
574
575        match algorithm {
576            "kmeans" => {
577                // Simple k-means implementation
578                let mut rng = match self.config.random_seed {
579                    Some(seed) => Random::seed(seed),
580                    None => Random::seed(42),
581                };
582
583                let mut labels = vec![0; n_samples];
584                for i in 0..n_samples {
585                    labels[i] = rng.gen_range(0..k) as i32;
586                }
587
588                // Could implement full k-means here
589                Ok(labels)
590            }
591            "spectral" => {
592                // Simplified spectral clustering
593                let mut rng = match self.config.random_seed {
594                    Some(seed) => Random::seed(seed),
595                    None => Random::seed(42),
596                };
597
598                let mut labels = vec![0; n_samples];
599                for i in 0..n_samples {
600                    labels[i] = rng.gen_range(0..k) as i32;
601                }
602
603                Ok(labels)
604            }
605            _ => Err(SklearsError::InvalidInput(format!(
606                "Unsupported algorithm: {}",
607                algorithm
608            ))),
609        }
610    }
611
612    /// Compute weights for different clustering results
613    fn compute_view_weights(&self, results: &[Vec<i32>]) -> Result<Vec<f64>> {
614        let n_results = results.len();
615
616        match self.config.view_weighting {
617            ViewWeighting::Equal => Ok(vec![1.0 / n_results as f64; n_results]),
618            ViewWeighting::Quality => {
619                // Weight based on silhouette-like quality measure
620                let mut weights = vec![0.0; n_results];
621                for (i, labels) in results.iter().enumerate() {
622                    weights[i] = self.compute_clustering_quality(labels);
623                }
624
625                // Normalize weights
626                let total_weight: f64 = weights.iter().sum();
627                if total_weight > 0.0 {
628                    for weight in weights.iter_mut() {
629                        *weight /= total_weight;
630                    }
631                }
632
633                Ok(weights)
634            }
635            ViewWeighting::Diversity => {
636                // Weight based on diversity (how different the clustering is)
637                let mut weights = vec![1.0; n_results];
638
639                for i in 0..n_results {
640                    let mut diversity_score = 0.0;
641                    for j in 0..n_results {
642                        if i != j {
643                            diversity_score +=
644                                self.compute_clustering_distance(&results[i], &results[j]);
645                        }
646                    }
647                    weights[i] = diversity_score / (n_results - 1) as f64;
648                }
649
650                // Normalize weights
651                let total_weight: f64 = weights.iter().sum();
652                if total_weight > 0.0 {
653                    for weight in weights.iter_mut() {
654                        *weight /= total_weight;
655                    }
656                }
657
658                Ok(weights)
659            }
660        }
661    }
662
663    /// Compute consensus matrix from individual clustering results
664    fn compute_consensus_matrix(
665        &self,
666        results: &[Vec<i32>],
667        weights: &[f64],
668    ) -> Result<Array2<f64>> {
669        if results.is_empty() {
670            return Err(SklearsError::InvalidInput(
671                "No clustering results provided".to_string(),
672            ));
673        }
674
675        let n_samples = results[0].len();
676        let mut consensus = Array2::zeros((n_samples, n_samples));
677
678        match self.config.consensus_method {
679            ConsensusMethod::CoAssociation => {
680                // Co-association matrix: frequency that pairs are in same cluster
681                for (result_idx, labels) in results.iter().enumerate() {
682                    let weight = weights[result_idx];
683
684                    for i in 0..n_samples {
685                        for j in i..n_samples {
686                            if labels[i] == labels[j] {
687                                consensus[[i, j]] += weight;
688                                consensus[[j, i]] += weight;
689                            }
690                        }
691                    }
692                }
693            }
694            ConsensusMethod::Voting | ConsensusMethod::EvidenceAccumulation => {
695                // Similar implementation for voting
696                for (result_idx, labels) in results.iter().enumerate() {
697                    let weight = weights[result_idx];
698
699                    for i in 0..n_samples {
700                        for j in i..n_samples {
701                            if labels[i] == labels[j] {
702                                consensus[[i, j]] += weight;
703                                consensus[[j, i]] += weight;
704                            }
705                        }
706                    }
707                }
708            }
709        }
710
711        Ok(consensus)
712    }
713
714    /// Generate final consensus clustering from consensus matrix
715    fn generate_consensus_clustering(&self, consensus_matrix: &Array2<f64>) -> Result<Vec<i32>> {
716        let n_samples = consensus_matrix.nrows();
717
718        // Simple approach: use hierarchical clustering on consensus matrix
719        // For now, implement a simplified version
720
721        let mut labels = vec![0; n_samples];
722        let mut current_cluster = 0;
723        let mut visited = vec![false; n_samples];
724
725        for i in 0..n_samples {
726            if !visited[i] {
727                // Start new cluster
728                let mut cluster_members = vec![i];
729                visited[i] = true;
730
731                // Find all points that should be in same cluster
732                let mut stack = vec![i];
733                while let Some(point) = stack.pop() {
734                    for j in 0..n_samples {
735                        if !visited[j] && consensus_matrix[[point, j]] > 0.5 {
736                            visited[j] = true;
737                            cluster_members.push(j);
738                            stack.push(j);
739                        }
740                    }
741                }
742
743                // Assign cluster label
744                for &member in &cluster_members {
745                    labels[member] = current_cluster;
746                }
747                current_cluster += 1;
748            }
749        }
750
751        Ok(labels)
752    }
753
754    /// Compute clustering quality score
755    fn compute_clustering_quality(&self, labels: &[i32]) -> f64 {
756        // Simple quality measure based on cluster balance
757        let mut cluster_counts = HashMap::new();
758        for &label in labels {
759            *cluster_counts.entry(label).or_insert(0) += 1;
760        }
761
762        // Compute entropy (higher entropy = more balanced clusters = higher quality)
763        let total = labels.len() as f64;
764        let mut entropy = 0.0;
765
766        for count in cluster_counts.values() {
767            let p = *count as f64 / total;
768            if p > 0.0 {
769                entropy -= p * p.log2();
770            }
771        }
772
773        entropy
774    }
775
776    /// Compute distance between two clusterings
777    fn compute_clustering_distance(&self, labels1: &[i32], labels2: &[i32]) -> f64 {
778        if labels1.len() != labels2.len() {
779            return 0.0;
780        }
781
782        let n_samples = labels1.len();
783        let mut disagreements = 0;
784
785        for i in 0..n_samples {
786            for j in (i + 1)..n_samples {
787                let same_cluster_1 = labels1[i] == labels1[j];
788                let same_cluster_2 = labels2[i] == labels2[j];
789
790                if same_cluster_1 != same_cluster_2 {
791                    disagreements += 1;
792                }
793            }
794        }
795
796        disagreements as f64 / ((n_samples * (n_samples - 1)) / 2) as f64
797    }
798
799    /// Compute agreement scores between consensus and individual results
800    fn compute_agreement_scores(&self, results: &[Vec<i32>], consensus: &[i32]) -> Vec<f64> {
801        results
802            .iter()
803            .map(|labels| 1.0 - self.compute_clustering_distance(labels, consensus))
804            .collect()
805    }
806}
807
/// `Estimator` integration: exposes the configuration to the sklears
/// framework (trait and associated types come from `sklears_core::prelude`).
impl Estimator for ConsensusClustering {
    type Config = ConsensusClusteringConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

// NOTE(review): this allow looks unnecessary — every test name below is
// already snake_case; consider removing it.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two views with matching row counts should construct successfully and
    /// report the correct per-view dimensions.
    #[test]
    fn test_multi_view_data_creation() {
        let view1 = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let view2 = array![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];

        let multi_view_data = MultiViewData::new(vec![view1, view2]).unwrap();

        assert_eq!(multi_view_data.n_views(), 2);
        assert_eq!(multi_view_data.n_samples, 3);
        assert_eq!(multi_view_data.view_dimensions(), vec![2, 3]);
    }

    /// Views with differing row counts must be rejected at construction.
    #[test]
    fn test_multi_view_data_mismatched_samples() {
        let view1 = array![[1.0, 2.0], [3.0, 4.0]]; // 2 samples
        let view2 = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]; // 3 samples

        let result = MultiViewData::new(vec![view1, view2]);
        assert!(result.is_err());
    }

    /// Constructing a MultiViewKMeans from an explicit config preserves it.
    #[test]
    fn test_multi_view_kmeans_config() {
        let config = MultiViewKMeansConfig {
            k_clusters: 3,
            max_iter: 50,
            tolerance: 1e-3,
            view_weights: Some(vec![0.6, 0.4]),
            weight_learning: WeightLearning::Adaptive,
            random_seed: Some(42),
        };

        let clusterer = MultiViewKMeans::new(config);
        // Test that creation doesn't panic
        assert_eq!(clusterer.config.k_clusters, 3);
    }

    /// Constructing a ConsensusClustering from an explicit config preserves it.
    #[test]
    fn test_consensus_clustering_creation() {
        let config = ConsensusClusteringConfig {
            base_algorithms: vec!["kmeans".to_string()],
            k_clusters: 2,
            consensus_method: ConsensusMethod::CoAssociation,
            view_weighting: ViewWeighting::Quality,
            random_seed: Some(42),
        };

        let clusterer = ConsensusClustering::new(config);
        assert_eq!(clusterer.config.k_clusters, 2);
    }
}