1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, Untrained},
7 types::Float,
8};
9use std::collections::{HashMap, HashSet};
10
/// Two-view co-training estimator, parameterised by its training state `S`.
///
/// Hyper-parameters are set through the builder-style methods available while
/// the estimator is `Untrained`; fitting yields a `CoTraining<CoTrainingTrained>`.
#[derive(Debug, Clone)]
pub struct CoTraining<S = Untrained> {
    /// Training state: `Untrained` before fitting, `CoTrainingTrained` afterwards.
    state: S,
    /// Column indices of the input matrix that make up the first view.
    view1_features: Vec<usize>,
    /// Column indices of the input matrix that make up the second view.
    view2_features: Vec<usize>,
    /// Per-iteration cap on the pseudo-labels each view may add for the first
    /// entry of the class list built during fitting.
    p: usize,
    /// Per-iteration cap on the pseudo-labels each view may add for the other class.
    n: usize,
    /// Upper bound on the number of co-training iterations.
    max_iter: usize,
    /// When `true`, progress messages are printed to stdout during fitting.
    verbose: bool,
    /// Minimum prediction confidence an unlabeled sample needs to become a
    /// pseudo-label candidate.
    confidence_threshold: f64,
}
56
57impl CoTraining<Untrained> {
58 pub fn new() -> Self {
60 Self {
61 state: Untrained,
62 view1_features: Vec::new(),
63 view2_features: Vec::new(),
64 p: 1,
65 n: 1,
66 max_iter: 30,
67 verbose: false,
68 confidence_threshold: 0.5,
69 }
70 }
71
72 pub fn view1_features(mut self, features: Vec<usize>) -> Self {
74 self.view1_features = features;
75 self
76 }
77
78 pub fn view2_features(mut self, features: Vec<usize>) -> Self {
80 self.view2_features = features;
81 self
82 }
83
84 pub fn p(mut self, p: usize) -> Self {
86 self.p = p;
87 self
88 }
89
90 pub fn n(mut self, n: usize) -> Self {
92 self.n = n;
93 self
94 }
95
96 pub fn max_iter(mut self, max_iter: usize) -> Self {
98 self.max_iter = max_iter;
99 self
100 }
101
102 pub fn verbose(mut self, verbose: bool) -> Self {
104 self.verbose = verbose;
105 self
106 }
107
108 pub fn confidence_threshold(mut self, threshold: f64) -> Self {
110 self.confidence_threshold = threshold;
111 self
112 }
113
114 fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
115 if view_features.is_empty() {
116 return Err(SklearsError::InvalidInput(
117 "View features cannot be empty".to_string(),
118 ));
119 }
120
121 let n_samples = X.nrows();
122 let n_features = view_features.len();
123 let mut view_X = Array2::zeros((n_samples, n_features));
124
125 for (new_j, &old_j) in view_features.iter().enumerate() {
126 if old_j >= X.ncols() {
127 return Err(SklearsError::InvalidInput(format!(
128 "Feature index {} out of bounds",
129 old_j
130 )));
131 }
132 for i in 0..n_samples {
133 view_X[[i, new_j]] = X[[i, old_j]];
134 }
135 }
136
137 Ok(view_X)
138 }
139
140 fn simple_classifier_predict(
141 &self,
142 X_train: &Array2<f64>,
143 y_train: &Array1<i32>,
144 X_test: &Array2<f64>,
145 classes: &[i32],
146 ) -> (Array1<i32>, Array1<f64>) {
147 let n_test = X_test.nrows();
148 let mut predictions = Array1::zeros(n_test);
149 let mut confidences = Array1::zeros(n_test);
150
151 for i in 0..n_test {
152 let mut distances: Vec<(f64, i32)> = Vec::new();
154 for j in 0..X_train.nrows() {
155 let diff = &X_test.row(i) - &X_train.row(j);
156 let dist = diff.mapv(|x| x * x).sum().sqrt();
157 distances.push((dist, y_train[j]));
158 }
159
160 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
161
162 let k = distances.len().clamp(1, 5);
164 let mut class_votes: HashMap<i32, f64> = HashMap::new();
165 let mut total_weight = 0.0;
166
167 for &(dist, label) in distances.iter().take(k) {
168 let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
169 *class_votes.entry(label).or_insert(0.0) += weight;
170 total_weight += weight;
171 }
172
173 for (_, vote) in class_votes.iter_mut() {
175 *vote /= total_weight;
176 }
177
178 let (best_class, best_confidence) = class_votes
180 .iter()
181 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
182 .map(|(&class, &conf)| (class, conf))
183 .unwrap_or((classes[0], 0.0));
184
185 predictions[i] = best_class;
186 confidences[i] = best_confidence;
187 }
188
189 (predictions, confidences)
190 }
191}
192
193impl Default for CoTraining<Untrained> {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
/// `Estimator` implementation for the untrained co-training model.
impl Estimator for CoTraining<Untrained> {
    /// No dedicated configuration object: hyper-parameters are stored as
    /// fields on the estimator itself.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
208
209impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for CoTraining<Untrained> {
210 type Fitted = CoTraining<CoTrainingTrained>;
211
212 #[allow(non_snake_case)]
213 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
214 let X = X.to_owned();
215 let mut y = y.to_owned();
216
217 if self.view1_features.is_empty() || self.view2_features.is_empty() {
219 return Err(SklearsError::InvalidInput(
220 "Both views must have at least one feature".to_string(),
221 ));
222 }
223
224 let overlap: HashSet<_> = self
226 .view1_features
227 .iter()
228 .filter(|f| self.view2_features.contains(f))
229 .collect();
230 if !overlap.is_empty() && self.verbose {
231 println!("Warning: Views have overlapping features: {:?}", overlap);
232 }
233
234 let mut labeled_mask = Array1::from_elem(y.len(), false);
236 let mut classes = HashSet::new();
237
238 for (i, &label) in y.iter().enumerate() {
239 if label != -1 {
240 labeled_mask[i] = true;
241 classes.insert(label);
242 }
243 }
244
245 if labeled_mask.iter().all(|&x| !x) {
246 return Err(SklearsError::InvalidInput(
247 "No labeled samples provided".to_string(),
248 ));
249 }
250
251 let classes: Vec<i32> = classes.into_iter().collect();
252 if classes.len() != 2 {
253 return Err(SklearsError::InvalidInput(
254 "Co-training currently supports binary classification only".to_string(),
255 ));
256 }
257
258 let X_view1 = self.extract_view(&X, &self.view1_features)?;
260 let X_view2 = self.extract_view(&X, &self.view2_features)?;
261
262 for iter in 0..self.max_iter {
264 let labeled_indices: Vec<usize> = labeled_mask
265 .iter()
266 .enumerate()
267 .filter(|(_, &is_labeled)| is_labeled)
268 .map(|(i, _)| i)
269 .collect();
270
271 let unlabeled_indices: Vec<usize> = labeled_mask
272 .iter()
273 .enumerate()
274 .filter(|(_, &is_labeled)| !is_labeled)
275 .map(|(i, _)| i)
276 .collect();
277
278 if unlabeled_indices.is_empty() {
279 if self.verbose {
280 println!("Iteration {}: All samples labeled", iter + 1);
281 }
282 break;
283 }
284
285 let X1_labeled = labeled_indices
287 .iter()
288 .map(|&i| X_view1.row(i).to_owned())
289 .collect::<Vec<_>>();
290 let X2_labeled = labeled_indices
291 .iter()
292 .map(|&i| X_view2.row(i).to_owned())
293 .collect::<Vec<_>>();
294
295 let y_labeled: Array1<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
296
297 let X1_labeled = Array2::from_shape_vec(
298 (X1_labeled.len(), X_view1.ncols()),
299 X1_labeled.into_iter().flatten().collect(),
300 )
301 .map_err(|_| {
302 SklearsError::InvalidInput("Failed to create view1 training data".to_string())
303 })?;
304
305 let X2_labeled = Array2::from_shape_vec(
306 (X2_labeled.len(), X_view2.ncols()),
307 X2_labeled.into_iter().flatten().collect(),
308 )
309 .map_err(|_| {
310 SklearsError::InvalidInput("Failed to create view2 training data".to_string())
311 })?;
312
313 let X1_unlabeled = unlabeled_indices
315 .iter()
316 .map(|&i| X_view1.row(i).to_owned())
317 .collect::<Vec<_>>();
318 let X2_unlabeled = unlabeled_indices
319 .iter()
320 .map(|&i| X_view2.row(i).to_owned())
321 .collect::<Vec<_>>();
322
323 let X1_unlabeled = Array2::from_shape_vec(
324 (X1_unlabeled.len(), X_view1.ncols()),
325 X1_unlabeled.into_iter().flatten().collect(),
326 )
327 .map_err(|_| {
328 SklearsError::InvalidInput("Failed to create view1 unlabeled data".to_string())
329 })?;
330
331 let X2_unlabeled = Array2::from_shape_vec(
332 (X2_unlabeled.len(), X_view2.ncols()),
333 X2_unlabeled.into_iter().flatten().collect(),
334 )
335 .map_err(|_| {
336 SklearsError::InvalidInput("Failed to create view2 unlabeled data".to_string())
337 })?;
338
339 let (pred1, conf1) =
341 self.simple_classifier_predict(&X1_labeled, &y_labeled, &X2_unlabeled, &classes);
342
343 let (pred2, conf2) =
345 self.simple_classifier_predict(&X2_labeled, &y_labeled, &X1_unlabeled, &classes);
346
347 let mut added_any = false;
349
350 for &target_class in &classes {
351 let mut candidates1: Vec<(usize, f64)> = pred1
353 .iter()
354 .zip(conf1.iter())
355 .enumerate()
356 .filter(|(_, (&pred, &conf))| {
357 pred == target_class && conf >= self.confidence_threshold
358 })
359 .map(|(i, (_, &conf))| (i, conf))
360 .collect();
361
362 candidates1
363 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
364
365 let add_count = if target_class == classes[0] {
366 self.p
367 } else {
368 self.n
369 };
370 for (candidate_idx, _) in candidates1.into_iter().take(add_count) {
371 let original_idx = unlabeled_indices[candidate_idx];
372 y[original_idx] = target_class;
373 labeled_mask[original_idx] = true;
374 added_any = true;
375 }
376
377 let mut candidates2: Vec<(usize, f64)> = pred2
379 .iter()
380 .zip(conf2.iter())
381 .enumerate()
382 .filter(|(_, (&pred, &conf))| {
383 pred == target_class && conf >= self.confidence_threshold
384 })
385 .map(|(i, (_, &conf))| (i, conf))
386 .collect();
387
388 candidates2
389 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
390
391 for (candidate_idx, _) in candidates2.into_iter().take(add_count) {
392 let original_idx = unlabeled_indices[candidate_idx];
393 if !labeled_mask[original_idx] {
394 y[original_idx] = target_class;
396 labeled_mask[original_idx] = true;
397 added_any = true;
398 }
399 }
400 }
401
402 if !added_any {
403 if self.verbose {
404 println!("Iteration {}: No confident predictions, stopping", iter + 1);
405 }
406 break;
407 }
408
409 if self.verbose {
410 let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
411 println!("Iteration {}: {} labeled samples", iter + 1, n_labeled);
412 }
413 }
414
415 Ok(CoTraining {
416 state: CoTrainingTrained {
417 X_train: X.clone(),
418 y_train: y,
419 classes: Array1::from(classes),
420 labeled_mask,
421 view1_features: self.view1_features.clone(),
422 view2_features: self.view2_features.clone(),
423 },
424 view1_features: self.view1_features,
425 view2_features: self.view2_features,
426 p: self.p,
427 n: self.n,
428 max_iter: self.max_iter,
429 verbose: self.verbose,
430 confidence_threshold: self.confidence_threshold,
431 })
432 }
433}
434
435impl CoTraining<CoTrainingTrained> {
436 fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
437 if view_features.is_empty() {
438 return Err(SklearsError::InvalidInput(
439 "View features cannot be empty".to_string(),
440 ));
441 }
442
443 let n_samples = X.nrows();
444 let n_features = view_features.len();
445 let mut view_X = Array2::zeros((n_samples, n_features));
446
447 for (new_j, &old_j) in view_features.iter().enumerate() {
448 if old_j >= X.ncols() {
449 return Err(SklearsError::InvalidInput(format!(
450 "Feature index {} out of bounds",
451 old_j
452 )));
453 }
454 for i in 0..n_samples {
455 view_X[[i, new_j]] = X[[i, old_j]];
456 }
457 }
458
459 Ok(view_X)
460 }
461}
462
463impl Predict<ArrayView2<'_, Float>, Array1<i32>> for CoTraining<CoTrainingTrained> {
464 #[allow(non_snake_case)]
465 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
466 let X = X.to_owned();
467 let n_test = X.nrows();
468 let mut predictions = Array1::zeros(n_test);
469
470 let labeled_indices: Vec<usize> = self
472 .state
473 .labeled_mask
474 .iter()
475 .enumerate()
476 .filter(|(_, &is_labeled)| is_labeled)
477 .map(|(i, _)| i)
478 .collect();
479
480 let X1_train = self.extract_view(&self.state.X_train, &self.state.view1_features)?;
482 let X1_labeled = labeled_indices
483 .iter()
484 .map(|&i| X1_train.row(i).to_owned())
485 .collect::<Vec<_>>();
486 let X1_labeled = Array2::from_shape_vec(
487 (X1_labeled.len(), X1_train.ncols()),
488 X1_labeled.into_iter().flatten().collect(),
489 )
490 .map_err(|_| {
491 SklearsError::InvalidInput("Failed to create view1 training data".to_string())
492 })?;
493
494 let y_labeled: Array1<i32> = labeled_indices
495 .iter()
496 .map(|&i| self.state.y_train[i])
497 .collect();
498
499 let mut all_features: Vec<usize> = self.state.view1_features.clone();
501 all_features.extend(&self.state.view2_features);
502 all_features.sort();
503 all_features.dedup();
504
505 let X_test_combined = self.extract_view(&X, &all_features)?;
506
507 for i in 0..n_test {
509 let mut min_dist = f64::INFINITY;
510 let mut best_label = 0;
511
512 for (j, &labeled_idx) in labeled_indices.iter().enumerate() {
513 let train_combined = self.extract_view(&self.state.X_train, &all_features)?;
515 let diff = &X_test_combined.row(i) - &train_combined.row(labeled_idx);
516 let dist = diff.mapv(|x| x * x).sum().sqrt();
517 if dist < min_dist {
518 min_dist = dist;
519 best_label = y_labeled[j];
520 }
521 }
522
523 predictions[i] = best_label;
524 }
525
526 Ok(predictions)
527 }
528}
529
/// State stored by a fitted `CoTraining` estimator.
#[derive(Debug, Clone)]
pub struct CoTrainingTrained {
    /// The full feature matrix that was passed to `fit`.
    pub X_train: Array2<f64>,
    /// Training labels after co-training: the labels given to `fit` plus any
    /// pseudo-labels assigned during fitting.
    pub y_train: Array1<i32>,
    /// The distinct labels found among the initially labeled samples.
    pub classes: Array1<i32>,
    /// `true` for every training sample that carries a label, whether it was
    /// provided to `fit` or pseudo-labeled during fitting.
    pub labeled_mask: Array1<bool>,
    /// Column indices that make up the first view.
    pub view1_features: Vec<usize>,
    /// Column indices that make up the second view.
    pub view2_features: Vec<usize>,
}
546
/// Co-training estimator over several feature views, parameterised by its
/// training state `S`.
#[derive(Debug, Clone)]
pub struct MultiViewCoTraining<S = Untrained> {
    /// Training state: `Untrained` before fitting, `MultiViewCoTrainingTrained` afterwards.
    state: S,
    /// One list of column indices per view.
    views: Vec<Vec<usize>>,
    /// Number of pseudo-labels to select: per class with the `"confidence"`
    /// strategy, otherwise `k_add * number_of_classes` in total.
    k_add: usize,
    /// Upper bound on the number of co-training iterations.
    max_iter: usize,
    /// Minimum confidence an unlabeled sample needs to be selected.
    confidence_threshold: f64,
    /// Name of the sample selection strategy. `"confidence"` selects per class;
    /// every other value selects the overall most confident samples.
    selection_strategy: String,
    /// When `true`, progress messages are printed to stdout during fitting.
    verbose: bool,
}
591
592impl MultiViewCoTraining<Untrained> {
593 pub fn new() -> Self {
595 Self {
596 state: Untrained,
597 views: Vec::new(),
598 k_add: 1,
599 max_iter: 30,
600 confidence_threshold: 0.6,
601 selection_strategy: "confidence".to_string(),
602 verbose: false,
603 }
604 }
605
606 pub fn views(mut self, views: Vec<Vec<usize>>) -> Self {
608 self.views = views;
609 self
610 }
611
612 pub fn k_add(mut self, k_add: usize) -> Self {
614 self.k_add = k_add;
615 self
616 }
617
618 pub fn max_iter(mut self, max_iter: usize) -> Self {
620 self.max_iter = max_iter;
621 self
622 }
623
624 pub fn confidence_threshold(mut self, threshold: f64) -> Self {
626 self.confidence_threshold = threshold;
627 self
628 }
629
630 pub fn selection_strategy(mut self, strategy: String) -> Self {
632 self.selection_strategy = strategy;
633 self
634 }
635
636 pub fn verbose(mut self, verbose: bool) -> Self {
638 self.verbose = verbose;
639 self
640 }
641
642 fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
643 if view_features.is_empty() {
644 return Err(SklearsError::InvalidInput(
645 "View features cannot be empty".to_string(),
646 ));
647 }
648
649 let n_samples = X.nrows();
650 let n_features = view_features.len();
651 let mut view_X = Array2::zeros((n_samples, n_features));
652
653 for (new_j, &old_j) in view_features.iter().enumerate() {
654 if old_j >= X.ncols() {
655 return Err(SklearsError::InvalidInput(format!(
656 "Feature index {} out of bounds",
657 old_j
658 )));
659 }
660 for i in 0..n_samples {
661 view_X[[i, new_j]] = X[[i, old_j]];
662 }
663 }
664
665 Ok(view_X)
666 }
667
668 fn train_view_classifier(
669 &self,
670 X_train: &Array2<f64>,
671 y_train: &Array1<i32>,
672 X_test: &Array2<f64>,
673 classes: &[i32],
674 ) -> (Array1<i32>, Array1<f64>) {
675 let n_test = X_test.nrows();
676 let mut predictions = Array1::zeros(n_test);
677 let mut confidences = Array1::zeros(n_test);
678
679 for i in 0..n_test {
680 let mut distances: Vec<(f64, i32)> = Vec::new();
682 for j in 0..X_train.nrows() {
683 let diff = &X_test.row(i) - &X_train.row(j);
684 let dist = diff.mapv(|x| x * x).sum().sqrt();
685 distances.push((dist, y_train[j]));
686 }
687
688 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
689
690 let k = distances.len().clamp(3, 7);
691 let mut class_votes: HashMap<i32, f64> = HashMap::new();
692 let mut total_weight = 0.0;
693
694 for &(dist, label) in distances.iter().take(k) {
695 let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
696 *class_votes.entry(label).or_insert(0.0) += weight;
697 total_weight += weight;
698 }
699
700 for (_, vote) in class_votes.iter_mut() {
702 *vote /= total_weight;
703 }
704
705 let (best_class, best_confidence) = class_votes
707 .iter()
708 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
709 .map(|(&class, &conf)| (class, conf))
710 .unwrap_or((classes[0], 0.0));
711
712 predictions[i] = best_class;
713 confidences[i] = best_confidence;
714 }
715
716 (predictions, confidences)
717 }
718
719 fn select_confident_samples(
720 &self,
721 predictions: &Array1<i32>,
722 confidences: &Array1<f64>,
723 classes: &[i32],
724 ) -> Vec<(usize, i32, f64)> {
725 let mut candidates = Vec::new();
726
727 for i in 0..predictions.len() {
728 if confidences[i] >= self.confidence_threshold {
729 candidates.push((i, predictions[i], confidences[i]));
730 }
731 }
732
733 candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
735
736 match self.selection_strategy.as_str() {
738 "confidence" => {
739 let mut selected = Vec::new();
741 for &class in classes {
742 let class_candidates: Vec<_> = candidates
743 .iter()
744 .filter(|(_, c, _)| *c == class)
745 .take(self.k_add)
746 .cloned()
747 .collect();
748 selected.extend(class_candidates);
749 }
750 selected
751 }
752 "diversity" => {
753 candidates
755 .into_iter()
756 .take(self.k_add * classes.len())
757 .collect()
758 }
759 _ => candidates
760 .into_iter()
761 .take(self.k_add * classes.len())
762 .collect(),
763 }
764 }
765}
766
767impl Default for MultiViewCoTraining<Untrained> {
768 fn default() -> Self {
769 Self::new()
770 }
771}
772
/// `Estimator` implementation for the untrained multi-view co-training model.
impl Estimator for MultiViewCoTraining<Untrained> {
    /// No dedicated configuration object: hyper-parameters are stored as
    /// fields on the estimator itself.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
782
783impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MultiViewCoTraining<Untrained> {
784 type Fitted = MultiViewCoTraining<MultiViewCoTrainingTrained>;
785
786 #[allow(non_snake_case)]
787 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
788 let X = X.to_owned();
789 let mut y = y.to_owned();
790
791 if self.views.len() < 3 {
792 return Err(SklearsError::InvalidInput(
793 "Multi-view co-training requires at least 3 views".to_string(),
794 ));
795 }
796
797 for (i, view) in self.views.iter().enumerate() {
799 if view.is_empty() {
800 return Err(SklearsError::InvalidInput(format!(
801 "View {} has no features",
802 i
803 )));
804 }
805 for &feature_idx in view {
806 if feature_idx >= X.ncols() {
807 return Err(SklearsError::InvalidInput(format!(
808 "Feature index {} out of bounds in view {}",
809 feature_idx, i
810 )));
811 }
812 }
813 }
814
815 let mut labeled_mask = Array1::from_elem(y.len(), false);
817 let mut classes = HashSet::new();
818
819 for (i, &label) in y.iter().enumerate() {
820 if label != -1 {
821 labeled_mask[i] = true;
822 classes.insert(label);
823 }
824 }
825
826 if labeled_mask.iter().all(|&x| !x) {
827 return Err(SklearsError::InvalidInput(
828 "No labeled samples provided".to_string(),
829 ));
830 }
831
832 let classes: Vec<i32> = classes.into_iter().collect();
833
834 for iter in 0..self.max_iter {
836 let mut any_labels_added = false;
837
838 for view_idx in 0..self.views.len() {
840 let view = &self.views[view_idx];
841
842 let labeled_indices: Vec<usize> = labeled_mask
844 .iter()
845 .enumerate()
846 .filter(|(_, &is_labeled)| is_labeled)
847 .map(|(i, _)| i)
848 .collect();
849
850 if labeled_indices.is_empty() {
851 continue;
852 }
853
854 let X_view = self.extract_view(&X, view)?;
855
856 let X_labeled: Vec<Vec<f64>> = labeled_indices
857 .iter()
858 .map(|&i| X_view.row(i).to_vec())
859 .collect();
860 let y_labeled: Array1<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
861
862 let X_labeled = Array2::from_shape_vec(
863 (X_labeled.len(), view.len()),
864 X_labeled.into_iter().flatten().collect(),
865 )
866 .map_err(|_| {
867 SklearsError::InvalidInput("Failed to create labeled training data".to_string())
868 })?;
869
870 let unlabeled_indices: Vec<usize> = labeled_mask
872 .iter()
873 .enumerate()
874 .filter(|(_, &is_labeled)| !is_labeled)
875 .map(|(i, _)| i)
876 .collect();
877
878 if unlabeled_indices.is_empty() {
879 continue; }
881
882 let mut all_predictions = Vec::new();
884 let mut all_confidences = Vec::new();
885
886 for other_view_idx in 0..self.views.len() {
887 if other_view_idx == view_idx {
888 continue; }
890
891 let other_view = &self.views[other_view_idx];
892 let X_other_view = self.extract_view(&X, other_view)?;
893
894 let X_other_labeled: Vec<Vec<f64>> = labeled_indices
896 .iter()
897 .map(|&i| X_other_view.row(i).to_vec())
898 .collect();
899
900 let X_other_labeled = Array2::from_shape_vec(
901 (X_other_labeled.len(), other_view.len()),
902 X_other_labeled.into_iter().flatten().collect(),
903 )
904 .map_err(|_| {
905 SklearsError::InvalidInput(
906 "Failed to create other view training data".to_string(),
907 )
908 })?;
909
910 let X_current_unlabeled: Vec<Vec<f64>> = unlabeled_indices
912 .iter()
913 .map(|&i| X_view.row(i).to_vec())
914 .collect();
915
916 let X_current_unlabeled = Array2::from_shape_vec(
917 (X_current_unlabeled.len(), view.len()),
918 X_current_unlabeled.into_iter().flatten().collect(),
919 )
920 .map_err(|_| {
921 SklearsError::InvalidInput(
922 "Failed to create current view unlabeled data".to_string(),
923 )
924 })?;
925
926 let (pred, conf) = self.train_view_classifier(
928 &X_other_labeled,
929 &y_labeled,
930 &X_current_unlabeled,
931 &classes,
932 );
933
934 all_predictions.push(pred);
935 all_confidences.push(conf);
936 }
937
938 if all_predictions.is_empty() {
939 continue;
940 }
941
942 let n_unlabeled = unlabeled_indices.len();
944 let mut final_predictions = Array1::zeros(n_unlabeled);
945 let mut final_confidences = Array1::zeros(n_unlabeled);
946
947 for i in 0..n_unlabeled {
948 let mut class_votes: HashMap<i32, f64> = HashMap::new();
949 let mut total_confidence = 0.0;
950
951 for (pred, conf) in all_predictions.iter().zip(all_confidences.iter()) {
952 let confidence = conf[i];
953 let prediction = pred[i];
954
955 *class_votes.entry(prediction).or_insert(0.0) += confidence;
956 total_confidence += confidence;
957 }
958
959 if total_confidence > 0.0 {
960 for (_, vote) in class_votes.iter_mut() {
961 *vote /= total_confidence;
962 }
963
964 let (best_class, best_confidence) = class_votes
965 .iter()
966 .max_by(|a, b| {
967 a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)
968 })
969 .map(|(&class, &conf)| (class, conf))
970 .unwrap_or((classes[0], 0.0));
971
972 final_predictions[i] = best_class;
973 final_confidences[i] = best_confidence;
974 }
975 }
976
977 let selected =
979 self.select_confident_samples(&final_predictions, &final_confidences, &classes);
980
981 for (unlabeled_idx, label, _confidence) in selected {
983 if unlabeled_idx < unlabeled_indices.len() {
984 let sample_idx = unlabeled_indices[unlabeled_idx];
985 y[sample_idx] = label;
986 labeled_mask[sample_idx] = true;
987 any_labels_added = true;
988 }
989 }
990 }
991
992 if !any_labels_added {
993 if self.verbose {
994 println!("Multi-view co-training converged at iteration {}", iter + 1);
995 }
996 break;
997 }
998
999 if self.verbose {
1000 let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
1001 println!("Iteration {}: {} labeled samples", iter + 1, n_labeled);
1002 }
1003 }
1004
1005 Ok(MultiViewCoTraining {
1006 state: MultiViewCoTrainingTrained {
1007 X_train: X.clone(),
1008 y_train: y,
1009 classes: Array1::from(classes),
1010 labeled_mask,
1011 views: self.views.clone(),
1012 },
1013 views: self.views,
1014 k_add: self.k_add,
1015 max_iter: self.max_iter,
1016 confidence_threshold: self.confidence_threshold,
1017 selection_strategy: self.selection_strategy,
1018 verbose: self.verbose,
1019 })
1020 }
1021}
1022
1023impl MultiViewCoTraining<MultiViewCoTrainingTrained> {
1024 fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
1025 if view_features.is_empty() {
1026 return Err(SklearsError::InvalidInput(
1027 "View features cannot be empty".to_string(),
1028 ));
1029 }
1030
1031 let n_samples = X.nrows();
1032 let n_features = view_features.len();
1033 let mut view_X = Array2::zeros((n_samples, n_features));
1034
1035 for (new_j, &old_j) in view_features.iter().enumerate() {
1036 if old_j >= X.ncols() {
1037 return Err(SklearsError::InvalidInput(format!(
1038 "Feature index {} out of bounds",
1039 old_j
1040 )));
1041 }
1042 for i in 0..n_samples {
1043 view_X[[i, new_j]] = X[[i, old_j]];
1044 }
1045 }
1046
1047 Ok(view_X)
1048 }
1049
1050 fn train_view_classifier(
1051 &self,
1052 X_train: &Array2<f64>,
1053 y_train: &Array1<i32>,
1054 X_test: &Array2<f64>,
1055 classes: &[i32],
1056 ) -> (Array1<i32>, Array1<f64>) {
1057 let n_test = X_test.nrows();
1058 let mut predictions = Array1::zeros(n_test);
1059 let mut confidences = Array1::zeros(n_test);
1060
1061 for i in 0..n_test {
1062 let mut distances: Vec<(f64, i32)> = Vec::new();
1064 for j in 0..X_train.nrows() {
1065 let diff = &X_test.row(i) - &X_train.row(j);
1066 let dist = diff.mapv(|x| x * x).sum().sqrt();
1067 distances.push((dist, y_train[j]));
1068 }
1069
1070 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
1071
1072 let k = distances.len().clamp(3, 7);
1073 let mut class_votes: HashMap<i32, f64> = HashMap::new();
1074 let mut total_weight = 0.0;
1075
1076 for &(dist, label) in distances.iter().take(k) {
1077 let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
1078 *class_votes.entry(label).or_insert(0.0) += weight;
1079 total_weight += weight;
1080 }
1081
1082 for (_, vote) in class_votes.iter_mut() {
1084 *vote /= total_weight;
1085 }
1086
1087 let (best_class, best_confidence) = class_votes
1089 .iter()
1090 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1091 .map(|(&class, &conf)| (class, conf))
1092 .unwrap_or((classes[0], 0.0));
1093
1094 predictions[i] = best_class;
1095 confidences[i] = best_confidence;
1096 }
1097
1098 (predictions, confidences)
1099 }
1100}
1101
1102impl Predict<ArrayView2<'_, Float>, Array1<i32>>
1103 for MultiViewCoTraining<MultiViewCoTrainingTrained>
1104{
1105 #[allow(non_snake_case)]
1106 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
1107 let X = X.to_owned();
1108 let n_test = X.nrows();
1109 let mut predictions = Array1::zeros(n_test);
1110
1111 let labeled_indices: Vec<usize> = self
1113 .state
1114 .labeled_mask
1115 .iter()
1116 .enumerate()
1117 .filter(|(_, &is_labeled)| is_labeled)
1118 .map(|(i, _)| i)
1119 .collect();
1120
1121 for i in 0..n_test {
1123 let mut found_exact_match = false;
1125 for j in 0..self.state.X_train.nrows() {
1126 if i < self.state.X_train.nrows() {
1127 let diff = &X.row(i) - &self.state.X_train.row(j);
1128 let distance = diff.mapv(|x| x * x).sum().sqrt();
1129 if distance < 1e-10 && i == j && self.state.labeled_mask[j] {
1130 predictions[i] = self.state.y_train[j];
1131 found_exact_match = true;
1132 break;
1133 }
1134 }
1135 }
1136
1137 if !found_exact_match {
1138 let mut class_votes: HashMap<i32, f64> = HashMap::new();
1139 let mut total_weight = 0.0;
1140
1141 for view in &self.state.views {
1143 let X_view_train = self.extract_view(&self.state.X_train, view)?;
1144 let X_view_test = self.extract_view(&X, view)?;
1145
1146 let X_labeled: Vec<Vec<f64>> = labeled_indices
1148 .iter()
1149 .map(|&idx| X_view_train.row(idx).to_vec())
1150 .collect();
1151 let y_labeled: Array1<i32> = labeled_indices
1152 .iter()
1153 .map(|&idx| self.state.y_train[idx])
1154 .collect();
1155
1156 let X_labeled = Array2::from_shape_vec(
1157 (X_labeled.len(), view.len()),
1158 X_labeled.into_iter().flatten().collect(),
1159 )
1160 .map_err(|_| {
1161 SklearsError::InvalidInput("Failed to create training data".to_string())
1162 })?;
1163
1164 let test_sample = X_view_test
1166 .row(i)
1167 .to_owned()
1168 .insert_axis(scirs2_core::ndarray::Axis(0));
1169 let (view_predictions, view_confidences) = self.train_view_classifier(
1170 &X_labeled,
1171 &y_labeled,
1172 &test_sample,
1173 &self.state.classes.to_vec(),
1174 );
1175
1176 let prediction = view_predictions[0];
1177 let confidence = view_confidences[0];
1178
1179 *class_votes.entry(prediction).or_insert(0.0) += confidence;
1180 total_weight += confidence;
1181 }
1182
1183 let best_class = if total_weight > 0.0 {
1185 class_votes
1186 .iter()
1187 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1188 .map(|(&class, _)| class)
1189 .unwrap_or(self.state.classes[0])
1190 } else {
1191 self.state.classes[0]
1192 };
1193
1194 predictions[i] = best_class;
1195 }
1196 }
1197
1198 Ok(predictions)
1199 }
1200}
1201
/// State stored by a fitted `MultiViewCoTraining` estimator.
#[derive(Debug, Clone)]
pub struct MultiViewCoTrainingTrained {
    /// The full feature matrix that was passed to `fit`.
    pub X_train: Array2<f64>,
    /// Training labels after co-training: the labels given to `fit` plus any
    /// pseudo-labels assigned during fitting.
    pub y_train: Array1<i32>,
    /// The distinct labels found among the initially labeled samples.
    pub classes: Array1<i32>,
    /// `true` for every training sample that carries a label, whether it was
    /// provided to `fit` or pseudo-labeled during fitting.
    pub labeled_mask: Array1<bool>,
    /// One list of column indices per view, as configured before fitting.
    pub views: Vec<Vec<usize>>,
}