1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, Untrained},
7 types::Float,
8};
9use std::collections::{HashMap, HashSet};
10
11#[derive(Debug, Clone)]
48pub struct DemocraticCoLearning<S = Untrained> {
49 state: S,
50 views: Vec<Vec<usize>>,
51 k_add: usize,
52 max_iter: usize,
53 confidence_threshold: f64,
54 min_agreement: usize,
55 verbose: bool,
56 selection_strategy: String,
57}
58
59impl DemocraticCoLearning<Untrained> {
60 pub fn new() -> Self {
62 Self {
63 state: Untrained,
64 views: Vec::new(),
65 k_add: 5,
66 max_iter: 30,
67 confidence_threshold: 0.6,
68 min_agreement: 2,
69 verbose: false,
70 selection_strategy: "confidence".to_string(),
71 }
72 }
73
74 pub fn views(mut self, views: Vec<Vec<usize>>) -> Self {
76 self.views = views;
77 self
78 }
79
80 pub fn k_add(mut self, k_add: usize) -> Self {
82 self.k_add = k_add;
83 self
84 }
85
86 pub fn max_iter(mut self, max_iter: usize) -> Self {
88 self.max_iter = max_iter;
89 self
90 }
91
92 pub fn confidence_threshold(mut self, threshold: f64) -> Self {
94 self.confidence_threshold = threshold;
95 self
96 }
97
98 pub fn min_agreement(mut self, min_agreement: usize) -> Self {
100 self.min_agreement = min_agreement;
101 self
102 }
103
104 pub fn verbose(mut self, verbose: bool) -> Self {
106 self.verbose = verbose;
107 self
108 }
109
110 pub fn selection_strategy(mut self, strategy: String) -> Self {
112 self.selection_strategy = strategy;
113 self
114 }
115
116 fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
117 if view_features.is_empty() {
118 return Err(SklearsError::InvalidInput(
119 "View features cannot be empty".to_string(),
120 ));
121 }
122
123 let n_samples = X.nrows();
124 let n_features = view_features.len();
125 let mut view_X = Array2::zeros((n_samples, n_features));
126
127 for (new_j, &old_j) in view_features.iter().enumerate() {
128 if old_j >= X.ncols() {
129 return Err(SklearsError::InvalidInput(format!(
130 "Feature index {} out of bounds",
131 old_j
132 )));
133 }
134 for i in 0..n_samples {
135 view_X[[i, new_j]] = X[[i, old_j]];
136 }
137 }
138
139 Ok(view_X)
140 }
141
142 fn train_classifier(
143 &self,
144 X_train: &Array2<f64>,
145 y_train: &Array1<i32>,
146 X_test: &Array2<f64>,
147 classes: &[i32],
148 ) -> (Array1<i32>, Array1<f64>) {
149 let n_test = X_test.nrows();
150 let mut predictions = Array1::zeros(n_test);
151 let mut confidences = Array1::zeros(n_test);
152
153 for i in 0..n_test {
154 let mut distances: Vec<(f64, i32)> = Vec::new();
156 for j in 0..X_train.nrows() {
157 let diff = &X_test.row(i) - &X_train.row(j);
158 let dist = diff.mapv(|x| x * x).sum().sqrt();
159 distances.push((dist, y_train[j]));
160 }
161
162 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
163
164 let k = distances.len().clamp(3, 7).min(X_train.nrows());
166 let mut class_votes: HashMap<i32, f64> = HashMap::new();
167 let mut total_weight = 0.0;
168
169 for &(dist, label) in distances.iter().take(k) {
170 let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
172 *class_votes.entry(label).or_insert(0.0) += weight;
173 total_weight += weight;
174 }
175
176 for (_, vote) in class_votes.iter_mut() {
178 *vote /= total_weight;
179 }
180
181 let (best_class, best_confidence) = class_votes
183 .iter()
184 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
185 .map(|(&class, &conf)| (class, conf))
186 .unwrap_or((classes[0], 0.0));
187
188 predictions[i] = best_class;
189 confidences[i] = best_confidence;
190 }
191
192 (predictions, confidences)
193 }
194
195 fn democratic_vote(
196 &self,
197 predictions: &[Array1<i32>],
198 confidences: &[Array1<f64>],
199 classes: &[i32],
200 ) -> Vec<(usize, i32, f64)> {
201 let n_samples = predictions[0].len();
202 let n_classifiers = predictions.len();
203 let mut candidates = Vec::new();
204
205 for i in 0..n_samples {
206 let mut class_votes: HashMap<i32, usize> = HashMap::new();
208 let mut total_confidence = 0.0;
209 let mut voting_classifiers = 0;
210
211 for (classifier_idx, (pred, conf)) in
212 predictions.iter().zip(confidences.iter()).enumerate()
213 {
214 if conf[i] >= self.confidence_threshold {
215 *class_votes.entry(pred[i]).or_insert(0) += 1;
216 total_confidence += conf[i];
217 voting_classifiers += 1;
218 }
219 }
220
221 if let Some((&winning_class, &vote_count)) =
223 class_votes.iter().max_by_key(|(_, &count)| count)
224 {
225 let required_agreement = self.min_agreement.min(voting_classifiers);
228 if vote_count >= required_agreement && voting_classifiers >= 1 {
229 let avg_confidence = total_confidence / voting_classifiers as f64;
230
231 let consensus = vote_count as f64 / voting_classifiers as f64;
233 let agreement_bonus = (voting_classifiers as f64).ln() + 1.0;
235 let combined_score = avg_confidence * consensus * agreement_bonus;
236
237 candidates.push((i, winning_class, combined_score));
238 }
239 }
240 }
241
242 candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
244 candidates
245 }
246}
247
248impl Default for DemocraticCoLearning<Untrained> {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl Estimator for DemocraticCoLearning<Untrained> {
255 type Config = ();
256 type Error = SklearsError;
257 type Float = Float;
258
259 fn config(&self) -> &Self::Config {
260 &()
261 }
262}
263
264impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for DemocraticCoLearning<Untrained> {
265 type Fitted = DemocraticCoLearning<DemocraticCoLearningTrained>;
266
267 #[allow(non_snake_case)]
268 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
269 let X = X.to_owned();
270 let mut y = y.to_owned();
271
272 if self.views.len() < 2 {
274 return Err(SklearsError::InvalidInput(
275 "Democratic co-learning requires at least 2 views".to_string(),
276 ));
277 }
278
279 if self.min_agreement > self.views.len() {
280 return Err(SklearsError::InvalidInput(
281 "min_agreement cannot be greater than number of views".to_string(),
282 ));
283 }
284
285 for (view_idx, view) in self.views.iter().enumerate() {
286 if view.is_empty() {
287 return Err(SklearsError::InvalidInput(format!(
288 "View {} cannot be empty",
289 view_idx
290 )));
291 }
292 for &feature_idx in view {
293 if feature_idx >= X.ncols() {
294 return Err(SklearsError::InvalidInput(format!(
295 "Feature index {} in view {} is out of bounds",
296 feature_idx, view_idx
297 )));
298 }
299 }
300 }
301
302 let mut labeled_mask = Array1::from_elem(y.len(), false);
304 let mut classes = HashSet::new();
305
306 for (i, &label) in y.iter().enumerate() {
307 if label != -1 {
308 labeled_mask[i] = true;
309 classes.insert(label);
310 }
311 }
312
313 if labeled_mask.iter().all(|&x| !x) {
314 return Err(SklearsError::InvalidInput(
315 "No labeled samples provided".to_string(),
316 ));
317 }
318
319 let classes: Vec<i32> = classes.into_iter().collect();
320
321 for iter in 0..self.max_iter {
323 let labeled_indices: Vec<usize> = labeled_mask
324 .iter()
325 .enumerate()
326 .filter(|(_, &is_labeled)| is_labeled)
327 .map(|(i, _)| i)
328 .collect();
329
330 let unlabeled_indices: Vec<usize> = labeled_mask
331 .iter()
332 .enumerate()
333 .filter(|(_, &is_labeled)| !is_labeled)
334 .map(|(i, _)| i)
335 .collect();
336
337 if unlabeled_indices.is_empty() {
338 if self.verbose {
339 println!("Iteration {}: All samples labeled", iter + 1);
340 }
341 break;
342 }
343
344 let mut all_predictions = Vec::new();
346 let mut all_confidences = Vec::new();
347
348 for view in &self.views {
349 let X_view = self.extract_view(&X, view)?;
351
352 let X_labeled: Vec<Vec<f64>> = labeled_indices
354 .iter()
355 .map(|&i| X_view.row(i).to_vec())
356 .collect();
357 let y_labeled: Array1<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
358
359 let X_labeled = Array2::from_shape_vec(
360 (X_labeled.len(), view.len()),
361 X_labeled.into_iter().flatten().collect(),
362 )
363 .map_err(|_| {
364 SklearsError::InvalidInput("Failed to create labeled training data".to_string())
365 })?;
366
367 let X_unlabeled: Vec<Vec<f64>> = unlabeled_indices
369 .iter()
370 .map(|&i| X_view.row(i).to_vec())
371 .collect();
372
373 let X_unlabeled = Array2::from_shape_vec(
374 (X_unlabeled.len(), view.len()),
375 X_unlabeled.into_iter().flatten().collect(),
376 )
377 .map_err(|_| {
378 SklearsError::InvalidInput("Failed to create unlabeled data".to_string())
379 })?;
380
381 let (predictions, confidences) =
383 self.train_classifier(&X_labeled, &y_labeled, &X_unlabeled, &classes);
384 all_predictions.push(predictions);
385 all_confidences.push(confidences);
386 }
387
388 let candidates = self.democratic_vote(&all_predictions, &all_confidences, &classes);
390
391 if candidates.is_empty() {
392 if self.verbose {
393 println!(
394 "Iteration {}: No agreed-upon confident predictions, stopping",
395 iter + 1
396 );
397 }
398 break;
399 }
400
401 let selected_count = candidates.len().min(self.k_add);
403 let mut added_count = 0;
404
405 for (candidate_idx, predicted_label, _score) in
406 candidates.into_iter().take(selected_count)
407 {
408 let original_idx = unlabeled_indices[candidate_idx];
409 y[original_idx] = predicted_label;
410 labeled_mask[original_idx] = true;
411 added_count += 1;
412 }
413
414 if added_count == 0 {
415 if self.verbose {
416 println!("Iteration {}: No samples added, stopping", iter + 1);
417 }
418 break;
419 }
420
421 if self.verbose {
422 let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
423 println!(
424 "Iteration {}: {} samples added, {} total labeled",
425 iter + 1,
426 added_count,
427 n_labeled
428 );
429 }
430 }
431
432 Ok(DemocraticCoLearning {
433 state: DemocraticCoLearningTrained {
434 X_train: X.clone(),
435 y_train: y,
436 classes: Array1::from(classes),
437 labeled_mask,
438 views: self.views.clone(),
439 },
440 views: self.views,
441 k_add: self.k_add,
442 max_iter: self.max_iter,
443 confidence_threshold: self.confidence_threshold,
444 min_agreement: self.min_agreement,
445 verbose: self.verbose,
446 selection_strategy: self.selection_strategy,
447 })
448 }
449}
450
451impl DemocraticCoLearning<DemocraticCoLearningTrained> {
452 fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
453 if view_features.is_empty() {
454 return Err(SklearsError::InvalidInput(
455 "View features cannot be empty".to_string(),
456 ));
457 }
458
459 let n_samples = X.nrows();
460 let n_features = view_features.len();
461 let mut view_X = Array2::zeros((n_samples, n_features));
462
463 for (new_j, &old_j) in view_features.iter().enumerate() {
464 if old_j >= X.ncols() {
465 return Err(SklearsError::InvalidInput(format!(
466 "Feature index {} out of bounds",
467 old_j
468 )));
469 }
470 for i in 0..n_samples {
471 view_X[[i, new_j]] = X[[i, old_j]];
472 }
473 }
474
475 Ok(view_X)
476 }
477}
478
479impl Predict<ArrayView2<'_, Float>, Array1<i32>>
480 for DemocraticCoLearning<DemocraticCoLearningTrained>
481{
482 #[allow(non_snake_case)]
483 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
484 let X = X.to_owned();
485 let n_test = X.nrows();
486 let mut predictions = Array1::zeros(n_test);
487
488 let labeled_indices: Vec<usize> = self
490 .state
491 .labeled_mask
492 .iter()
493 .enumerate()
494 .filter(|(_, &is_labeled)| is_labeled)
495 .map(|(i, _)| i)
496 .collect();
497
498 for i in 0..n_test {
500 let mut found_exact_match = false;
502 for j in 0..self.state.X_train.nrows() {
503 if i < self.state.X_train.nrows() {
504 let diff = &X.row(i) - &self.state.X_train.row(j);
506 let distance = diff.mapv(|x| x * x).sum().sqrt();
507 if distance < 1e-10 && i == j {
508 if self.state.labeled_mask[j] {
510 predictions[i] = self.state.y_train[j];
511 found_exact_match = true;
512 break;
513 }
514 }
515 }
516 }
517
518 if !found_exact_match {
519 let mut class_votes: HashMap<i32, f64> = HashMap::new();
520 let mut total_weight = 0.0;
521
522 for view in &self.state.views {
524 let X_view_train = self.extract_view(&self.state.X_train, view)?;
526 let X_view_test = self.extract_view(&X, view)?;
527
528 let X_labeled: Vec<Vec<f64>> = labeled_indices
530 .iter()
531 .map(|&idx| X_view_train.row(idx).to_vec())
532 .collect();
533 let y_labeled: Array1<i32> = labeled_indices
534 .iter()
535 .map(|&idx| self.state.y_train[idx])
536 .collect();
537
538 let X_labeled = Array2::from_shape_vec(
539 (X_labeled.len(), view.len()),
540 X_labeled.into_iter().flatten().collect(),
541 )
542 .map_err(|_| {
543 SklearsError::InvalidInput("Failed to create training data".to_string())
544 })?;
545
546 let test_sample = X_view_test
548 .row(i)
549 .to_owned()
550 .insert_axis(scirs2_core::ndarray::Axis(0));
551 let mut distances: Vec<(f64, i32)> = Vec::new();
552
553 for j in 0..X_labeled.nrows() {
554 let diff = &test_sample.row(0) - &X_labeled.row(j);
555 let dist = diff.mapv(|x| x * x).sum().sqrt();
556 distances.push((dist, y_labeled[j]));
557 }
558
559 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
560
561 let k = distances.len().clamp(1, 5);
562 let mut view_votes: HashMap<i32, f64> = HashMap::new();
563 let mut view_weight = 0.0;
564
565 for &(dist, label) in distances.iter().take(k) {
566 let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
567 *view_votes.entry(label).or_insert(0.0) += weight;
568 view_weight += weight;
569 }
570
571 for (class, vote) in view_votes {
573 let normalized_vote = vote / view_weight;
574 *class_votes.entry(class).or_insert(0.0) += normalized_vote;
575 }
576 total_weight += 1.0; }
578
579 let best_class = class_votes
581 .iter()
582 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
583 .map(|(&class, _)| class)
584 .unwrap_or(self.state.classes[0]);
585
586 predictions[i] = best_class;
587 }
588 }
589
590 Ok(predictions)
591 }
592}
593
594#[derive(Debug, Clone)]
596pub struct DemocraticCoLearningTrained {
597 pub X_train: Array2<f64>,
599 pub y_train: Array1<i32>,
601 pub classes: Array1<i32>,
603 pub labeled_mask: Array1<bool>,
605 pub views: Vec<Vec<usize>>,
607}
608
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    /// Two-view fit; predicting on the training matrix reproduces the labelled rows.
    #[test]
    fn test_democratic_co_learning_basic() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
            [5.0, 6.0, 7.0, 8.0],
            [6.0, 7.0, 8.0, 9.0]
        ];
        let y = array![0, 1, -1, -1, -1, -1];
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(5);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    /// Every builder method stores its value.
    #[test]
    fn test_democratic_co_learning_parameters() {
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0], vec![1]])
            .k_add(2)
            .max_iter(10)
            .confidence_threshold(0.8)
            .min_agreement(1)
            .verbose(true)
            .selection_strategy("confidence".to_string());

        assert_eq!(dcl.k_add, 2);
        assert_eq!(dcl.max_iter, 10);
        assert_eq!(dcl.confidence_threshold, 0.8);
        assert_eq!(dcl.min_agreement, 1);
        assert!(dcl.verbose);
        assert_eq!(dcl.selection_strategy, "confidence");
    }

    /// Invalid configurations and fully unlabelled targets are rejected by `fit`.
    #[test]
    fn test_democratic_co_learning_error_cases() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0, 1];

        // Fewer than two views.
        let dcl = DemocraticCoLearning::new().views(vec![vec![0]]);
        assert!(dcl.fit(&X.view(), &y.view()).is_err());

        // `min_agreement` larger than the number of views.
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0], vec![1]])
            .min_agreement(3);
        assert!(dcl.fit(&X.view(), &y.view()).is_err());

        // Empty view.
        let dcl = DemocraticCoLearning::new().views(vec![vec![], vec![1]]);
        assert!(dcl.fit(&X.view(), &y.view()).is_err());

        // Out-of-bounds feature index.
        let dcl = DemocraticCoLearning::new().views(vec![vec![0], vec![5]]);
        assert!(dcl.fit(&X.view(), &y.view()).is_err());

        // No labelled samples.
        let y_unlabeled = array![-1, -1];
        let dcl = DemocraticCoLearning::new().views(vec![vec![0], vec![1]]);
        assert!(dcl.fit(&X.view(), &y_unlabeled.view()).is_err());
    }

    /// Three-view fit still reproduces the labelled rows.
    #[test]
    fn test_democratic_co_learning_with_three_views() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            [3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            [5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        ];
        let y = array![0, 1, -1, -1, -1];

        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3], vec![4, 5]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(3);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    /// With nothing to pseudo-label, fitting succeeds and predictions match the labels.
    #[test]
    fn test_democratic_co_learning_all_labeled() {
        let X = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![0, 1];
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(5);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    /// `extract_view` copies the requested columns and rejects bad/empty views.
    #[test]
    fn test_extract_view() {
        let X = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let dcl = DemocraticCoLearning::new();

        let view = dcl.extract_view(&X, &[0, 1]).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view[[0, 0]], 1.0);
        assert_eq!(view[[0, 1]], 2.0);
        assert_eq!(view[[1, 0]], 5.0);
        assert_eq!(view[[1, 1]], 6.0);

        let view = dcl.extract_view(&X, &[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view[[0, 0]], 3.0);
        assert_eq!(view[[0, 1]], 4.0);
        assert_eq!(view[[1, 0]], 7.0);
        assert_eq!(view[[1, 1]], 8.0);

        assert!(dcl.extract_view(&X, &[5]).is_err());
        assert!(dcl.extract_view(&X, &[]).is_err());
    }

    /// The unanimous, highly confident sample is ranked first.
    #[test]
    fn test_democratic_vote() {
        let dcl = DemocraticCoLearning::new()
            .confidence_threshold(0.5)
            .min_agreement(2);

        let predictions = [array![0, 1, 0], array![0, 1, 1], array![0, 0, 0]];
        let confidences = [
            array![0.8, 0.9, 0.6],
            array![0.7, 0.8, 0.7],
            array![0.6, 0.4, 0.8],
        ];
        let classes = [0, 1];

        let candidates = dcl.democratic_vote(&predictions, &confidences, &classes);

        assert!(!candidates.is_empty());

        let (sample_idx, predicted_class, _score) = candidates[0];
        assert_eq!(sample_idx, 0);
        assert_eq!(predicted_class, 0);
    }

    /// The kNN base learner returns known labels with confidences in `[0, 1]`.
    #[test]
    fn test_train_classifier() {
        let dcl = DemocraticCoLearning::new();

        let X_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y_train = array![0, 1, 0];
        let X_test = array![[2.0, 3.0], [4.0, 5.0]];
        let classes = [0, 1];

        let (predictions, confidences) =
            dcl.train_classifier(&X_train, &y_train, &X_test, &classes);

        assert_eq!(predictions.len(), 2);
        assert_eq!(confidences.len(), 2);

        for pred in predictions.iter() {
            assert!(classes.contains(pred));
        }

        for &conf in confidences.iter() {
            assert!((0.0..=1.0).contains(&conf));
        }
    }

    /// A near-impossible confidence threshold stops fitting early without failing.
    #[test]
    fn test_democratic_co_learning_high_confidence_threshold() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0]
        ];
        let y = array![0, 1, -1, -1];

        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3]])
            .k_add(1)
            .min_agreement(2)
            .confidence_threshold(0.99)
            .max_iter(2);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }
}