sklears_semi_supervised/
co_training.rs

1//! Co-Training implementation for semi-supervised learning with multiple views
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, Untrained},
7    types::Float,
8};
9use std::collections::{HashMap, HashSet};
10
/// Co-Training classifier for semi-supervised learning with multiple views
///
/// Co-training uses two different feature sets (views) to train two classifiers.
/// Each classifier labels unlabeled examples for the other classifier to use.
///
/// Only binary problems are supported: the labeled samples must contain exactly two
/// distinct labels. Samples whose label is `-1` are treated as unlabeled.
///
/// # Parameters
///
/// * `view1_features` - Indices of features (columns of `X`) for view 1; must be non-empty
/// * `view2_features` - Indices of features (columns of `X`) for view 2; must be non-empty
/// * `p` - Number of positive examples to add per iteration (default 1)
/// * `n` - Number of negative examples to add per iteration (default 1)
/// * `max_iter` - Maximum number of iterations (default 30)
/// * `verbose` - Whether to print progress information (default `false`)
/// * `confidence_threshold` - Minimum share of the k-NN vote weight (a value in `[0, 1]`)
///   that a prediction needs before it may be pseudo-labeled (default 0.5)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_semi_supervised::CoTraining;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0],
///                [3.0, 4.0, 5.0, 6.0], [4.0, 5.0, 6.0, 7.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let ct = CoTraining::new()
///     .view1_features(vec![0, 1])
///     .view2_features(vec![2, 3])
///     .p(1)
///     .n(1)
///     .max_iter(10);
/// let fitted = ct.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct CoTraining<S = Untrained> {
    state: S,
    view1_features: Vec<usize>,
    view2_features: Vec<usize>,
    p: usize,
    n: usize,
    max_iter: usize,
    verbose: bool,
    confidence_threshold: f64,
}
56
57impl CoTraining<Untrained> {
58    /// Create a new CoTraining instance
59    pub fn new() -> Self {
60        Self {
61            state: Untrained,
62            view1_features: Vec::new(),
63            view2_features: Vec::new(),
64            p: 1,
65            n: 1,
66            max_iter: 30,
67            verbose: false,
68            confidence_threshold: 0.5,
69        }
70    }
71
72    /// Set the features for view 1
73    pub fn view1_features(mut self, features: Vec<usize>) -> Self {
74        self.view1_features = features;
75        self
76    }
77
78    /// Set the features for view 2
79    pub fn view2_features(mut self, features: Vec<usize>) -> Self {
80        self.view2_features = features;
81        self
82    }
83
84    /// Set the number of positive examples to add per iteration
85    pub fn p(mut self, p: usize) -> Self {
86        self.p = p;
87        self
88    }
89
90    /// Set the number of negative examples to add per iteration
91    pub fn n(mut self, n: usize) -> Self {
92        self.n = n;
93        self
94    }
95
96    /// Set the maximum number of iterations
97    pub fn max_iter(mut self, max_iter: usize) -> Self {
98        self.max_iter = max_iter;
99        self
100    }
101
102    /// Set verbosity
103    pub fn verbose(mut self, verbose: bool) -> Self {
104        self.verbose = verbose;
105        self
106    }
107
108    /// Set confidence threshold
109    pub fn confidence_threshold(mut self, threshold: f64) -> Self {
110        self.confidence_threshold = threshold;
111        self
112    }
113
114    fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
115        if view_features.is_empty() {
116            return Err(SklearsError::InvalidInput(
117                "View features cannot be empty".to_string(),
118            ));
119        }
120
121        let n_samples = X.nrows();
122        let n_features = view_features.len();
123        let mut view_X = Array2::zeros((n_samples, n_features));
124
125        for (new_j, &old_j) in view_features.iter().enumerate() {
126            if old_j >= X.ncols() {
127                return Err(SklearsError::InvalidInput(format!(
128                    "Feature index {} out of bounds",
129                    old_j
130                )));
131            }
132            for i in 0..n_samples {
133                view_X[[i, new_j]] = X[[i, old_j]];
134            }
135        }
136
137        Ok(view_X)
138    }
139
140    fn simple_classifier_predict(
141        &self,
142        X_train: &Array2<f64>,
143        y_train: &Array1<i32>,
144        X_test: &Array2<f64>,
145        classes: &[i32],
146    ) -> (Array1<i32>, Array1<f64>) {
147        let n_test = X_test.nrows();
148        let mut predictions = Array1::zeros(n_test);
149        let mut confidences = Array1::zeros(n_test);
150
151        for i in 0..n_test {
152            // Simple k-NN classifier
153            let mut distances: Vec<(f64, i32)> = Vec::new();
154            for j in 0..X_train.nrows() {
155                let diff = &X_test.row(i) - &X_train.row(j);
156                let dist = diff.mapv(|x| x * x).sum().sqrt();
157                distances.push((dist, y_train[j]));
158            }
159
160            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
161
162            // Use k=5 nearest neighbors
163            let k = distances.len().clamp(1, 5);
164            let mut class_votes: HashMap<i32, f64> = HashMap::new();
165            let mut total_weight = 0.0;
166
167            for &(dist, label) in distances.iter().take(k) {
168                let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
169                *class_votes.entry(label).or_insert(0.0) += weight;
170                total_weight += weight;
171            }
172
173            // Normalize votes to probabilities
174            for (_, vote) in class_votes.iter_mut() {
175                *vote /= total_weight;
176            }
177
178            // Find most likely class and confidence
179            let (best_class, best_confidence) = class_votes
180                .iter()
181                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
182                .map(|(&class, &conf)| (class, conf))
183                .unwrap_or((classes[0], 0.0));
184
185            predictions[i] = best_class;
186            confidences[i] = best_confidence;
187        }
188
189        (predictions, confidences)
190    }
191}
192
193impl Default for CoTraining<Untrained> {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199impl Estimator for CoTraining<Untrained> {
200    type Config = ();
201    type Error = SklearsError;
202    type Float = Float;
203
204    fn config(&self) -> &Self::Config {
205        &()
206    }
207}
208
209impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for CoTraining<Untrained> {
210    type Fitted = CoTraining<CoTrainingTrained>;
211
212    #[allow(non_snake_case)]
213    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
214        let X = X.to_owned();
215        let mut y = y.to_owned();
216
217        // Validate views
218        if self.view1_features.is_empty() || self.view2_features.is_empty() {
219            return Err(SklearsError::InvalidInput(
220                "Both views must have at least one feature".to_string(),
221            ));
222        }
223
224        // Check for overlapping features (warning, not error)
225        let overlap: HashSet<_> = self
226            .view1_features
227            .iter()
228            .filter(|f| self.view2_features.contains(f))
229            .collect();
230        if !overlap.is_empty() && self.verbose {
231            println!("Warning: Views have overlapping features: {:?}", overlap);
232        }
233
234        // Identify labeled and unlabeled samples
235        let mut labeled_mask = Array1::from_elem(y.len(), false);
236        let mut classes = HashSet::new();
237
238        for (i, &label) in y.iter().enumerate() {
239            if label != -1 {
240                labeled_mask[i] = true;
241                classes.insert(label);
242            }
243        }
244
245        if labeled_mask.iter().all(|&x| !x) {
246            return Err(SklearsError::InvalidInput(
247                "No labeled samples provided".to_string(),
248            ));
249        }
250
251        let classes: Vec<i32> = classes.into_iter().collect();
252        if classes.len() != 2 {
253            return Err(SklearsError::InvalidInput(
254                "Co-training currently supports binary classification only".to_string(),
255            ));
256        }
257
258        // Extract views
259        let X_view1 = self.extract_view(&X, &self.view1_features)?;
260        let X_view2 = self.extract_view(&X, &self.view2_features)?;
261
262        // Co-training iterations
263        for iter in 0..self.max_iter {
264            let labeled_indices: Vec<usize> = labeled_mask
265                .iter()
266                .enumerate()
267                .filter(|(_, &is_labeled)| is_labeled)
268                .map(|(i, _)| i)
269                .collect();
270
271            let unlabeled_indices: Vec<usize> = labeled_mask
272                .iter()
273                .enumerate()
274                .filter(|(_, &is_labeled)| !is_labeled)
275                .map(|(i, _)| i)
276                .collect();
277
278            if unlabeled_indices.is_empty() {
279                if self.verbose {
280                    println!("Iteration {}: All samples labeled", iter + 1);
281                }
282                break;
283            }
284
285            // Extract labeled data for both views
286            let X1_labeled = labeled_indices
287                .iter()
288                .map(|&i| X_view1.row(i).to_owned())
289                .collect::<Vec<_>>();
290            let X2_labeled = labeled_indices
291                .iter()
292                .map(|&i| X_view2.row(i).to_owned())
293                .collect::<Vec<_>>();
294
295            let y_labeled: Array1<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
296
297            let X1_labeled = Array2::from_shape_vec(
298                (X1_labeled.len(), X_view1.ncols()),
299                X1_labeled.into_iter().flatten().collect(),
300            )
301            .map_err(|_| {
302                SklearsError::InvalidInput("Failed to create view1 training data".to_string())
303            })?;
304
305            let X2_labeled = Array2::from_shape_vec(
306                (X2_labeled.len(), X_view2.ncols()),
307                X2_labeled.into_iter().flatten().collect(),
308            )
309            .map_err(|_| {
310                SklearsError::InvalidInput("Failed to create view2 training data".to_string())
311            })?;
312
313            // Extract unlabeled data for both views
314            let X1_unlabeled = unlabeled_indices
315                .iter()
316                .map(|&i| X_view1.row(i).to_owned())
317                .collect::<Vec<_>>();
318            let X2_unlabeled = unlabeled_indices
319                .iter()
320                .map(|&i| X_view2.row(i).to_owned())
321                .collect::<Vec<_>>();
322
323            let X1_unlabeled = Array2::from_shape_vec(
324                (X1_unlabeled.len(), X_view1.ncols()),
325                X1_unlabeled.into_iter().flatten().collect(),
326            )
327            .map_err(|_| {
328                SklearsError::InvalidInput("Failed to create view1 unlabeled data".to_string())
329            })?;
330
331            let X2_unlabeled = Array2::from_shape_vec(
332                (X2_unlabeled.len(), X_view2.ncols()),
333                X2_unlabeled.into_iter().flatten().collect(),
334            )
335            .map_err(|_| {
336                SklearsError::InvalidInput("Failed to create view2 unlabeled data".to_string())
337            })?;
338
339            // Train classifier on view1, predict on view2's unlabeled data
340            let (pred1, conf1) =
341                self.simple_classifier_predict(&X1_labeled, &y_labeled, &X2_unlabeled, &classes);
342
343            // Train classifier on view2, predict on view1's unlabeled data
344            let (pred2, conf2) =
345                self.simple_classifier_predict(&X2_labeled, &y_labeled, &X1_unlabeled, &classes);
346
347            // Select most confident predictions for each class
348            let mut added_any = false;
349
350            for &target_class in &classes {
351                // Find confident predictions from classifier 1 for target class
352                let mut candidates1: Vec<(usize, f64)> = pred1
353                    .iter()
354                    .zip(conf1.iter())
355                    .enumerate()
356                    .filter(|(_, (&pred, &conf))| {
357                        pred == target_class && conf >= self.confidence_threshold
358                    })
359                    .map(|(i, (_, &conf))| (i, conf))
360                    .collect();
361
362                candidates1
363                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
364
365                let add_count = if target_class == classes[0] {
366                    self.p
367                } else {
368                    self.n
369                };
370                for (candidate_idx, _) in candidates1.into_iter().take(add_count) {
371                    let original_idx = unlabeled_indices[candidate_idx];
372                    y[original_idx] = target_class;
373                    labeled_mask[original_idx] = true;
374                    added_any = true;
375                }
376
377                // Find confident predictions from classifier 2 for target class
378                let mut candidates2: Vec<(usize, f64)> = pred2
379                    .iter()
380                    .zip(conf2.iter())
381                    .enumerate()
382                    .filter(|(_, (&pred, &conf))| {
383                        pred == target_class && conf >= self.confidence_threshold
384                    })
385                    .map(|(i, (_, &conf))| (i, conf))
386                    .collect();
387
388                candidates2
389                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
390
391                for (candidate_idx, _) in candidates2.into_iter().take(add_count) {
392                    let original_idx = unlabeled_indices[candidate_idx];
393                    if !labeled_mask[original_idx] {
394                        // Don't double-label
395                        y[original_idx] = target_class;
396                        labeled_mask[original_idx] = true;
397                        added_any = true;
398                    }
399                }
400            }
401
402            if !added_any {
403                if self.verbose {
404                    println!("Iteration {}: No confident predictions, stopping", iter + 1);
405                }
406                break;
407            }
408
409            if self.verbose {
410                let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
411                println!("Iteration {}: {} labeled samples", iter + 1, n_labeled);
412            }
413        }
414
415        Ok(CoTraining {
416            state: CoTrainingTrained {
417                X_train: X.clone(),
418                y_train: y,
419                classes: Array1::from(classes),
420                labeled_mask,
421                view1_features: self.view1_features.clone(),
422                view2_features: self.view2_features.clone(),
423            },
424            view1_features: self.view1_features,
425            view2_features: self.view2_features,
426            p: self.p,
427            n: self.n,
428            max_iter: self.max_iter,
429            verbose: self.verbose,
430            confidence_threshold: self.confidence_threshold,
431        })
432    }
433}
434
435impl CoTraining<CoTrainingTrained> {
436    fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
437        if view_features.is_empty() {
438            return Err(SklearsError::InvalidInput(
439                "View features cannot be empty".to_string(),
440            ));
441        }
442
443        let n_samples = X.nrows();
444        let n_features = view_features.len();
445        let mut view_X = Array2::zeros((n_samples, n_features));
446
447        for (new_j, &old_j) in view_features.iter().enumerate() {
448            if old_j >= X.ncols() {
449                return Err(SklearsError::InvalidInput(format!(
450                    "Feature index {} out of bounds",
451                    old_j
452                )));
453            }
454            for i in 0..n_samples {
455                view_X[[i, new_j]] = X[[i, old_j]];
456            }
457        }
458
459        Ok(view_X)
460    }
461}
462
463impl Predict<ArrayView2<'_, Float>, Array1<i32>> for CoTraining<CoTrainingTrained> {
464    #[allow(non_snake_case)]
465    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
466        let X = X.to_owned();
467        let n_test = X.nrows();
468        let mut predictions = Array1::zeros(n_test);
469
470        // Get labeled training samples
471        let labeled_indices: Vec<usize> = self
472            .state
473            .labeled_mask
474            .iter()
475            .enumerate()
476            .filter(|(_, &is_labeled)| is_labeled)
477            .map(|(i, _)| i)
478            .collect();
479
480        // Extract views for training data
481        let X1_train = self.extract_view(&self.state.X_train, &self.state.view1_features)?;
482        let X1_labeled = labeled_indices
483            .iter()
484            .map(|&i| X1_train.row(i).to_owned())
485            .collect::<Vec<_>>();
486        let X1_labeled = Array2::from_shape_vec(
487            (X1_labeled.len(), X1_train.ncols()),
488            X1_labeled.into_iter().flatten().collect(),
489        )
490        .map_err(|_| {
491            SklearsError::InvalidInput("Failed to create view1 training data".to_string())
492        })?;
493
494        let y_labeled: Array1<i32> = labeled_indices
495            .iter()
496            .map(|&i| self.state.y_train[i])
497            .collect();
498
499        // Extract view for test data (use combined features for prediction)
500        let mut all_features: Vec<usize> = self.state.view1_features.clone();
501        all_features.extend(&self.state.view2_features);
502        all_features.sort();
503        all_features.dedup();
504
505        let X_test_combined = self.extract_view(&X, &all_features)?;
506
507        // Use simple k-NN for prediction
508        for i in 0..n_test {
509            let mut min_dist = f64::INFINITY;
510            let mut best_label = 0;
511
512            for (j, &labeled_idx) in labeled_indices.iter().enumerate() {
513                // Combine features from both views for distance calculation
514                let train_combined = self.extract_view(&self.state.X_train, &all_features)?;
515                let diff = &X_test_combined.row(i) - &train_combined.row(labeled_idx);
516                let dist = diff.mapv(|x| x * x).sum().sqrt();
517                if dist < min_dist {
518                    min_dist = dist;
519                    best_label = y_labeled[j];
520                }
521            }
522
523            predictions[i] = best_label;
524        }
525
526        Ok(predictions)
527    }
528}
529
/// Trained state for CoTraining
///
/// Produced by `CoTraining::fit`; holds everything `predict` needs.
#[derive(Debug, Clone)]
pub struct CoTrainingTrained {
    /// Full training feature matrix passed to `fit` (all columns, all samples)
    pub X_train: Array2<f64>,
    /// Training labels after co-training: the original labels, plus pseudo-labels
    /// assigned during fitting; samples never labeled keep the value `-1`
    pub y_train: Array1<i32>,
    /// The two distinct labels found among the originally labeled samples
    pub classes: Array1<i32>,
    /// `true` for every sample that has a (real or pseudo) label in `y_train`
    pub labeled_mask: Array1<bool>,
    /// Column indices of `X_train` forming view 1
    pub view1_features: Vec<usize>,
    /// Column indices of `X_train` forming view 2
    pub view2_features: Vec<usize>,
}
546
/// Multi-View Co-Training classifier for semi-supervised learning with multiple views
///
/// Multi-view co-training extends the traditional co-training algorithm to work with
/// more than two views. Each view trains a classifier that can label examples for
/// other views, creating a collaborative learning process. Samples whose label is `-1`
/// are treated as unlabeled.
///
/// # Parameters
///
/// * `views` - Vector of feature (column) indices for each view; at least 3 non-empty
///   views are required
/// * `k_add` - Number of examples selected per class each time candidates are chosen
///   with the `"confidence"` strategy; other strategies select `k_add * n_classes` in
///   total (default 1)
/// * `max_iter` - Maximum number of iterations (default 30)
/// * `confidence_threshold` - Minimum confidence for pseudo-labeling (default 0.6)
/// * `selection_strategy` - Strategy for selecting examples (default `"confidence"`).
///   NOTE: `"diversity"` is accepted, but it currently selects exactly like any other
///   non-`"confidence"` value: the most confident candidates overall
/// * `verbose` - Whether to print progress information (default `false`)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_semi_supervised::MultiViewCoTraining;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
///                [3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let mvct = MultiViewCoTraining::new()
///     .views(vec![vec![0, 1], vec![2, 3], vec![4, 5]])
///     .k_add(1)
///     .confidence_threshold(0.6)
///     .max_iter(10);
/// let fitted = mvct.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiViewCoTraining<S = Untrained> {
    state: S,
    views: Vec<Vec<usize>>,
    k_add: usize,
    max_iter: usize,
    confidence_threshold: f64,
    selection_strategy: String,
    verbose: bool,
}
591
592impl MultiViewCoTraining<Untrained> {
593    /// Create a new MultiViewCoTraining instance
594    pub fn new() -> Self {
595        Self {
596            state: Untrained,
597            views: Vec::new(),
598            k_add: 1,
599            max_iter: 30,
600            confidence_threshold: 0.6,
601            selection_strategy: "confidence".to_string(),
602            verbose: false,
603        }
604    }
605
606    /// Set the views (feature indices for each view)
607    pub fn views(mut self, views: Vec<Vec<usize>>) -> Self {
608        self.views = views;
609        self
610    }
611
612    /// Set the number of examples to add per iteration per view
613    pub fn k_add(mut self, k_add: usize) -> Self {
614        self.k_add = k_add;
615        self
616    }
617
618    /// Set the maximum number of iterations
619    pub fn max_iter(mut self, max_iter: usize) -> Self {
620        self.max_iter = max_iter;
621        self
622    }
623
624    /// Set the confidence threshold for pseudo-labeling
625    pub fn confidence_threshold(mut self, threshold: f64) -> Self {
626        self.confidence_threshold = threshold;
627        self
628    }
629
630    /// Set the selection strategy for pseudo-labeling
631    pub fn selection_strategy(mut self, strategy: String) -> Self {
632        self.selection_strategy = strategy;
633        self
634    }
635
636    /// Set verbosity
637    pub fn verbose(mut self, verbose: bool) -> Self {
638        self.verbose = verbose;
639        self
640    }
641
642    fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
643        if view_features.is_empty() {
644            return Err(SklearsError::InvalidInput(
645                "View features cannot be empty".to_string(),
646            ));
647        }
648
649        let n_samples = X.nrows();
650        let n_features = view_features.len();
651        let mut view_X = Array2::zeros((n_samples, n_features));
652
653        for (new_j, &old_j) in view_features.iter().enumerate() {
654            if old_j >= X.ncols() {
655                return Err(SklearsError::InvalidInput(format!(
656                    "Feature index {} out of bounds",
657                    old_j
658                )));
659            }
660            for i in 0..n_samples {
661                view_X[[i, new_j]] = X[[i, old_j]];
662            }
663        }
664
665        Ok(view_X)
666    }
667
668    fn train_view_classifier(
669        &self,
670        X_train: &Array2<f64>,
671        y_train: &Array1<i32>,
672        X_test: &Array2<f64>,
673        classes: &[i32],
674    ) -> (Array1<i32>, Array1<f64>) {
675        let n_test = X_test.nrows();
676        let mut predictions = Array1::zeros(n_test);
677        let mut confidences = Array1::zeros(n_test);
678
679        for i in 0..n_test {
680            // k-NN classifier with adaptive k
681            let mut distances: Vec<(f64, i32)> = Vec::new();
682            for j in 0..X_train.nrows() {
683                let diff = &X_test.row(i) - &X_train.row(j);
684                let dist = diff.mapv(|x| x * x).sum().sqrt();
685                distances.push((dist, y_train[j]));
686            }
687
688            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
689
690            let k = distances.len().clamp(3, 7);
691            let mut class_votes: HashMap<i32, f64> = HashMap::new();
692            let mut total_weight = 0.0;
693
694            for &(dist, label) in distances.iter().take(k) {
695                let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
696                *class_votes.entry(label).or_insert(0.0) += weight;
697                total_weight += weight;
698            }
699
700            // Normalize votes to probabilities
701            for (_, vote) in class_votes.iter_mut() {
702                *vote /= total_weight;
703            }
704
705            // Find most likely class and confidence
706            let (best_class, best_confidence) = class_votes
707                .iter()
708                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
709                .map(|(&class, &conf)| (class, conf))
710                .unwrap_or((classes[0], 0.0));
711
712            predictions[i] = best_class;
713            confidences[i] = best_confidence;
714        }
715
716        (predictions, confidences)
717    }
718
719    fn select_confident_samples(
720        &self,
721        predictions: &Array1<i32>,
722        confidences: &Array1<f64>,
723        classes: &[i32],
724    ) -> Vec<(usize, i32, f64)> {
725        let mut candidates = Vec::new();
726
727        for i in 0..predictions.len() {
728            if confidences[i] >= self.confidence_threshold {
729                candidates.push((i, predictions[i], confidences[i]));
730            }
731        }
732
733        // Sort by confidence (descending)
734        candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
735
736        // Apply selection strategy
737        match self.selection_strategy.as_str() {
738            "confidence" => {
739                // Take top k_add per class
740                let mut selected = Vec::new();
741                for &class in classes {
742                    let class_candidates: Vec<_> = candidates
743                        .iter()
744                        .filter(|(_, c, _)| *c == class)
745                        .take(self.k_add)
746                        .cloned()
747                        .collect();
748                    selected.extend(class_candidates);
749                }
750                selected
751            }
752            "diversity" => {
753                // Take top k_add overall with some diversity
754                candidates
755                    .into_iter()
756                    .take(self.k_add * classes.len())
757                    .collect()
758            }
759            _ => candidates
760                .into_iter()
761                .take(self.k_add * classes.len())
762                .collect(),
763        }
764    }
765}
766
767impl Default for MultiViewCoTraining<Untrained> {
768    fn default() -> Self {
769        Self::new()
770    }
771}
772
773impl Estimator for MultiViewCoTraining<Untrained> {
774    type Config = ();
775    type Error = SklearsError;
776    type Float = Float;
777
778    fn config(&self) -> &Self::Config {
779        &()
780    }
781}
782
783impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MultiViewCoTraining<Untrained> {
784    type Fitted = MultiViewCoTraining<MultiViewCoTrainingTrained>;
785
786    #[allow(non_snake_case)]
787    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
788        let X = X.to_owned();
789        let mut y = y.to_owned();
790
791        if self.views.len() < 3 {
792            return Err(SklearsError::InvalidInput(
793                "Multi-view co-training requires at least 3 views".to_string(),
794            ));
795        }
796
797        // Validate views
798        for (i, view) in self.views.iter().enumerate() {
799            if view.is_empty() {
800                return Err(SklearsError::InvalidInput(format!(
801                    "View {} has no features",
802                    i
803                )));
804            }
805            for &feature_idx in view {
806                if feature_idx >= X.ncols() {
807                    return Err(SklearsError::InvalidInput(format!(
808                        "Feature index {} out of bounds in view {}",
809                        feature_idx, i
810                    )));
811                }
812            }
813        }
814
815        // Identify labeled and unlabeled samples
816        let mut labeled_mask = Array1::from_elem(y.len(), false);
817        let mut classes = HashSet::new();
818
819        for (i, &label) in y.iter().enumerate() {
820            if label != -1 {
821                labeled_mask[i] = true;
822                classes.insert(label);
823            }
824        }
825
826        if labeled_mask.iter().all(|&x| !x) {
827            return Err(SklearsError::InvalidInput(
828                "No labeled samples provided".to_string(),
829            ));
830        }
831
832        let classes: Vec<i32> = classes.into_iter().collect();
833
834        // Multi-view co-training iterations
835        for iter in 0..self.max_iter {
836            let mut any_labels_added = false;
837
838            // For each view, train classifier and get predictions for unlabeled data
839            for view_idx in 0..self.views.len() {
840                let view = &self.views[view_idx];
841
842                // Extract labeled samples for this view
843                let labeled_indices: Vec<usize> = labeled_mask
844                    .iter()
845                    .enumerate()
846                    .filter(|(_, &is_labeled)| is_labeled)
847                    .map(|(i, _)| i)
848                    .collect();
849
850                if labeled_indices.is_empty() {
851                    continue;
852                }
853
854                let X_view = self.extract_view(&X, view)?;
855
856                let X_labeled: Vec<Vec<f64>> = labeled_indices
857                    .iter()
858                    .map(|&i| X_view.row(i).to_vec())
859                    .collect();
860                let y_labeled: Array1<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
861
862                let X_labeled = Array2::from_shape_vec(
863                    (X_labeled.len(), view.len()),
864                    X_labeled.into_iter().flatten().collect(),
865                )
866                .map_err(|_| {
867                    SklearsError::InvalidInput("Failed to create labeled training data".to_string())
868                })?;
869
870                // Get unlabeled samples
871                let unlabeled_indices: Vec<usize> = labeled_mask
872                    .iter()
873                    .enumerate()
874                    .filter(|(_, &is_labeled)| !is_labeled)
875                    .map(|(i, _)| i)
876                    .collect();
877
878                if unlabeled_indices.is_empty() {
879                    continue; // All samples are labeled
880                }
881
882                // Train on other views and predict on this view's unlabeled data
883                let mut all_predictions = Vec::new();
884                let mut all_confidences = Vec::new();
885
886                for other_view_idx in 0..self.views.len() {
887                    if other_view_idx == view_idx {
888                        continue; // Skip self-view
889                    }
890
891                    let other_view = &self.views[other_view_idx];
892                    let X_other_view = self.extract_view(&X, other_view)?;
893
894                    // Extract labeled data for other view
895                    let X_other_labeled: Vec<Vec<f64>> = labeled_indices
896                        .iter()
897                        .map(|&i| X_other_view.row(i).to_vec())
898                        .collect();
899
900                    let X_other_labeled = Array2::from_shape_vec(
901                        (X_other_labeled.len(), other_view.len()),
902                        X_other_labeled.into_iter().flatten().collect(),
903                    )
904                    .map_err(|_| {
905                        SklearsError::InvalidInput(
906                            "Failed to create other view training data".to_string(),
907                        )
908                    })?;
909
910                    // Extract unlabeled data for current view
911                    let X_current_unlabeled: Vec<Vec<f64>> = unlabeled_indices
912                        .iter()
913                        .map(|&i| X_view.row(i).to_vec())
914                        .collect();
915
916                    let X_current_unlabeled = Array2::from_shape_vec(
917                        (X_current_unlabeled.len(), view.len()),
918                        X_current_unlabeled.into_iter().flatten().collect(),
919                    )
920                    .map_err(|_| {
921                        SklearsError::InvalidInput(
922                            "Failed to create current view unlabeled data".to_string(),
923                        )
924                    })?;
925
926                    // Train classifier on other view, predict on current view
927                    let (pred, conf) = self.train_view_classifier(
928                        &X_other_labeled,
929                        &y_labeled,
930                        &X_current_unlabeled,
931                        &classes,
932                    );
933
934                    all_predictions.push(pred);
935                    all_confidences.push(conf);
936                }
937
938                if all_predictions.is_empty() {
939                    continue;
940                }
941
942                // Aggregate predictions from all other views (ensemble voting)
943                let n_unlabeled = unlabeled_indices.len();
944                let mut final_predictions = Array1::zeros(n_unlabeled);
945                let mut final_confidences = Array1::zeros(n_unlabeled);
946
947                for i in 0..n_unlabeled {
948                    let mut class_votes: HashMap<i32, f64> = HashMap::new();
949                    let mut total_confidence = 0.0;
950
951                    for (pred, conf) in all_predictions.iter().zip(all_confidences.iter()) {
952                        let confidence = conf[i];
953                        let prediction = pred[i];
954
955                        *class_votes.entry(prediction).or_insert(0.0) += confidence;
956                        total_confidence += confidence;
957                    }
958
959                    if total_confidence > 0.0 {
960                        for (_, vote) in class_votes.iter_mut() {
961                            *vote /= total_confidence;
962                        }
963
964                        let (best_class, best_confidence) = class_votes
965                            .iter()
966                            .max_by(|a, b| {
967                                a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)
968                            })
969                            .map(|(&class, &conf)| (class, conf))
970                            .unwrap_or((classes[0], 0.0));
971
972                        final_predictions[i] = best_class;
973                        final_confidences[i] = best_confidence;
974                    }
975                }
976
977                // Select confident samples to add
978                let selected =
979                    self.select_confident_samples(&final_predictions, &final_confidences, &classes);
980
981                // Add selected pseudo-labels
982                for (unlabeled_idx, label, _confidence) in selected {
983                    if unlabeled_idx < unlabeled_indices.len() {
984                        let sample_idx = unlabeled_indices[unlabeled_idx];
985                        y[sample_idx] = label;
986                        labeled_mask[sample_idx] = true;
987                        any_labels_added = true;
988                    }
989                }
990            }
991
992            if !any_labels_added {
993                if self.verbose {
994                    println!("Multi-view co-training converged at iteration {}", iter + 1);
995                }
996                break;
997            }
998
999            if self.verbose {
1000                let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
1001                println!("Iteration {}: {} labeled samples", iter + 1, n_labeled);
1002            }
1003        }
1004
1005        Ok(MultiViewCoTraining {
1006            state: MultiViewCoTrainingTrained {
1007                X_train: X.clone(),
1008                y_train: y,
1009                classes: Array1::from(classes),
1010                labeled_mask,
1011                views: self.views.clone(),
1012            },
1013            views: self.views,
1014            k_add: self.k_add,
1015            max_iter: self.max_iter,
1016            confidence_threshold: self.confidence_threshold,
1017            selection_strategy: self.selection_strategy,
1018            verbose: self.verbose,
1019        })
1020    }
1021}
1022
1023impl MultiViewCoTraining<MultiViewCoTrainingTrained> {
1024    fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
1025        if view_features.is_empty() {
1026            return Err(SklearsError::InvalidInput(
1027                "View features cannot be empty".to_string(),
1028            ));
1029        }
1030
1031        let n_samples = X.nrows();
1032        let n_features = view_features.len();
1033        let mut view_X = Array2::zeros((n_samples, n_features));
1034
1035        for (new_j, &old_j) in view_features.iter().enumerate() {
1036            if old_j >= X.ncols() {
1037                return Err(SklearsError::InvalidInput(format!(
1038                    "Feature index {} out of bounds",
1039                    old_j
1040                )));
1041            }
1042            for i in 0..n_samples {
1043                view_X[[i, new_j]] = X[[i, old_j]];
1044            }
1045        }
1046
1047        Ok(view_X)
1048    }
1049
1050    fn train_view_classifier(
1051        &self,
1052        X_train: &Array2<f64>,
1053        y_train: &Array1<i32>,
1054        X_test: &Array2<f64>,
1055        classes: &[i32],
1056    ) -> (Array1<i32>, Array1<f64>) {
1057        let n_test = X_test.nrows();
1058        let mut predictions = Array1::zeros(n_test);
1059        let mut confidences = Array1::zeros(n_test);
1060
1061        for i in 0..n_test {
1062            // k-NN classifier with adaptive k
1063            let mut distances: Vec<(f64, i32)> = Vec::new();
1064            for j in 0..X_train.nrows() {
1065                let diff = &X_test.row(i) - &X_train.row(j);
1066                let dist = diff.mapv(|x| x * x).sum().sqrt();
1067                distances.push((dist, y_train[j]));
1068            }
1069
1070            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
1071
1072            let k = distances.len().clamp(3, 7);
1073            let mut class_votes: HashMap<i32, f64> = HashMap::new();
1074            let mut total_weight = 0.0;
1075
1076            for &(dist, label) in distances.iter().take(k) {
1077                let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
1078                *class_votes.entry(label).or_insert(0.0) += weight;
1079                total_weight += weight;
1080            }
1081
1082            // Normalize votes to probabilities
1083            for (_, vote) in class_votes.iter_mut() {
1084                *vote /= total_weight;
1085            }
1086
1087            // Find most likely class and confidence
1088            let (best_class, best_confidence) = class_votes
1089                .iter()
1090                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1091                .map(|(&class, &conf)| (class, conf))
1092                .unwrap_or((classes[0], 0.0));
1093
1094            predictions[i] = best_class;
1095            confidences[i] = best_confidence;
1096        }
1097
1098        (predictions, confidences)
1099    }
1100}
1101
1102impl Predict<ArrayView2<'_, Float>, Array1<i32>>
1103    for MultiViewCoTraining<MultiViewCoTrainingTrained>
1104{
1105    #[allow(non_snake_case)]
1106    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
1107        let X = X.to_owned();
1108        let n_test = X.nrows();
1109        let mut predictions = Array1::zeros(n_test);
1110
1111        // Get labeled training samples
1112        let labeled_indices: Vec<usize> = self
1113            .state
1114            .labeled_mask
1115            .iter()
1116            .enumerate()
1117            .filter(|(_, &is_labeled)| is_labeled)
1118            .map(|(i, _)| i)
1119            .collect();
1120
1121        // Ensemble prediction using all views
1122        for i in 0..n_test {
1123            // Check if this test sample matches a training sample exactly
1124            let mut found_exact_match = false;
1125            for j in 0..self.state.X_train.nrows() {
1126                if i < self.state.X_train.nrows() {
1127                    let diff = &X.row(i) - &self.state.X_train.row(j);
1128                    let distance = diff.mapv(|x| x * x).sum().sqrt();
1129                    if distance < 1e-10 && i == j && self.state.labeled_mask[j] {
1130                        predictions[i] = self.state.y_train[j];
1131                        found_exact_match = true;
1132                        break;
1133                    }
1134                }
1135            }
1136
1137            if !found_exact_match {
1138                let mut class_votes: HashMap<i32, f64> = HashMap::new();
1139                let mut total_weight = 0.0;
1140
1141                // Get prediction from each view
1142                for view in &self.state.views {
1143                    let X_view_train = self.extract_view(&self.state.X_train, view)?;
1144                    let X_view_test = self.extract_view(&X, view)?;
1145
1146                    // Extract labeled training data for this view
1147                    let X_labeled: Vec<Vec<f64>> = labeled_indices
1148                        .iter()
1149                        .map(|&idx| X_view_train.row(idx).to_vec())
1150                        .collect();
1151                    let y_labeled: Array1<i32> = labeled_indices
1152                        .iter()
1153                        .map(|&idx| self.state.y_train[idx])
1154                        .collect();
1155
1156                    let X_labeled = Array2::from_shape_vec(
1157                        (X_labeled.len(), view.len()),
1158                        X_labeled.into_iter().flatten().collect(),
1159                    )
1160                    .map_err(|_| {
1161                        SklearsError::InvalidInput("Failed to create training data".to_string())
1162                    })?;
1163
1164                    // Predict for single test sample
1165                    let test_sample = X_view_test
1166                        .row(i)
1167                        .to_owned()
1168                        .insert_axis(scirs2_core::ndarray::Axis(0));
1169                    let (view_predictions, view_confidences) = self.train_view_classifier(
1170                        &X_labeled,
1171                        &y_labeled,
1172                        &test_sample,
1173                        &self.state.classes.to_vec(),
1174                    );
1175
1176                    let prediction = view_predictions[0];
1177                    let confidence = view_confidences[0];
1178
1179                    *class_votes.entry(prediction).or_insert(0.0) += confidence;
1180                    total_weight += confidence;
1181                }
1182
1183                // Find majority vote
1184                let best_class = if total_weight > 0.0 {
1185                    class_votes
1186                        .iter()
1187                        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1188                        .map(|(&class, _)| class)
1189                        .unwrap_or(self.state.classes[0])
1190                } else {
1191                    self.state.classes[0]
1192                };
1193
1194                predictions[i] = best_class;
1195            }
1196        }
1197
1198        Ok(predictions)
1199    }
1200}
1201
1202/// Trained state for MultiViewCoTraining
1203#[derive(Debug, Clone)]
1204pub struct MultiViewCoTrainingTrained {
1205    /// X_train
1206    pub X_train: Array2<f64>,
1207    /// y_train
1208    pub y_train: Array1<i32>,
1209    /// classes
1210    pub classes: Array1<i32>,
1211    /// labeled_mask
1212    pub labeled_mask: Array1<bool>,
1213    /// views
1214    pub views: Vec<Vec<usize>>,
1215}