sklears_semi_supervised/
democratic_co_learning.rs

1//! Democratic Co-Learning implementation for semi-supervised learning
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, Untrained},
7    types::Float,
8};
9use std::collections::{HashMap, HashSet};
10
11/// Democratic Co-Learning classifier for semi-supervised learning
12///
13/// Democratic co-learning extends traditional co-training by using multiple classifiers
14/// trained on different views of the data. Instead of pairwise labeling, all classifiers
15/// vote democratically on which unlabeled samples should be added to the training set.
16///
17/// # Parameters
18///
19/// * `views` - Feature indices for each view
20/// * `k_add` - Number of samples to add per iteration
21/// * `max_iter` - Maximum number of iterations
22/// * `confidence_threshold` - Minimum confidence threshold for pseudo-labeling
23/// * `min_agreement` - Minimum number of classifiers that must agree
24/// * `verbose` - Whether to print progress information
25///
26/// # Examples
27///
28/// ```rust,ignore
29/// use sklears_semi_supervised::DemocraticCoLearning;
30/// use sklears_core::traits::{Predict, Fit};
31///
32///
33/// let X = array![[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
34///                [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
35///                [3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
36///                [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]];
37/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
38///
39/// let dcl = DemocraticCoLearning::new()
40///     .views(vec![vec![0, 1], vec![2, 3], vec![4, 5]])
41///     .k_add(1)
42///     .min_agreement(2)
43///     .max_iter(10);
44/// let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
45/// let predictions = fitted.predict(&X.view()).unwrap();
46/// ```
47#[derive(Debug, Clone)]
48pub struct DemocraticCoLearning<S = Untrained> {
49    state: S,
50    views: Vec<Vec<usize>>,
51    k_add: usize,
52    max_iter: usize,
53    confidence_threshold: f64,
54    min_agreement: usize,
55    verbose: bool,
56    selection_strategy: String,
57}
58
59impl DemocraticCoLearning<Untrained> {
60    /// Create a new DemocraticCoLearning instance
61    pub fn new() -> Self {
62        Self {
63            state: Untrained,
64            views: Vec::new(),
65            k_add: 5,
66            max_iter: 30,
67            confidence_threshold: 0.6,
68            min_agreement: 2,
69            verbose: false,
70            selection_strategy: "confidence".to_string(),
71        }
72    }
73
74    /// Set the feature views
75    pub fn views(mut self, views: Vec<Vec<usize>>) -> Self {
76        self.views = views;
77        self
78    }
79
80    /// Set the number of samples to add per iteration
81    pub fn k_add(mut self, k_add: usize) -> Self {
82        self.k_add = k_add;
83        self
84    }
85
86    /// Set the maximum number of iterations
87    pub fn max_iter(mut self, max_iter: usize) -> Self {
88        self.max_iter = max_iter;
89        self
90    }
91
92    /// Set the confidence threshold
93    pub fn confidence_threshold(mut self, threshold: f64) -> Self {
94        self.confidence_threshold = threshold;
95        self
96    }
97
98    /// Set the minimum number of classifiers that must agree
99    pub fn min_agreement(mut self, min_agreement: usize) -> Self {
100        self.min_agreement = min_agreement;
101        self
102    }
103
104    /// Set verbosity
105    pub fn verbose(mut self, verbose: bool) -> Self {
106        self.verbose = verbose;
107        self
108    }
109
110    /// Set selection strategy for choosing samples to add
111    pub fn selection_strategy(mut self, strategy: String) -> Self {
112        self.selection_strategy = strategy;
113        self
114    }
115
116    fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
117        if view_features.is_empty() {
118            return Err(SklearsError::InvalidInput(
119                "View features cannot be empty".to_string(),
120            ));
121        }
122
123        let n_samples = X.nrows();
124        let n_features = view_features.len();
125        let mut view_X = Array2::zeros((n_samples, n_features));
126
127        for (new_j, &old_j) in view_features.iter().enumerate() {
128            if old_j >= X.ncols() {
129                return Err(SklearsError::InvalidInput(format!(
130                    "Feature index {} out of bounds",
131                    old_j
132                )));
133            }
134            for i in 0..n_samples {
135                view_X[[i, new_j]] = X[[i, old_j]];
136            }
137        }
138
139        Ok(view_X)
140    }
141
142    fn train_classifier(
143        &self,
144        X_train: &Array2<f64>,
145        y_train: &Array1<i32>,
146        X_test: &Array2<f64>,
147        classes: &[i32],
148    ) -> (Array1<i32>, Array1<f64>) {
149        let n_test = X_test.nrows();
150        let mut predictions = Array1::zeros(n_test);
151        let mut confidences = Array1::zeros(n_test);
152
153        for i in 0..n_test {
154            // Enhanced k-NN classifier with weighted voting
155            let mut distances: Vec<(f64, i32)> = Vec::new();
156            for j in 0..X_train.nrows() {
157                let diff = &X_test.row(i) - &X_train.row(j);
158                let dist = diff.mapv(|x| x * x).sum().sqrt();
159                distances.push((dist, y_train[j]));
160            }
161
162            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
163
164            // Adaptive k based on data size
165            let k = distances.len().clamp(3, 7).min(X_train.nrows());
166            let mut class_votes: HashMap<i32, f64> = HashMap::new();
167            let mut total_weight = 0.0;
168
169            for &(dist, label) in distances.iter().take(k) {
170                // Use inverse distance weighting
171                let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
172                *class_votes.entry(label).or_insert(0.0) += weight;
173                total_weight += weight;
174            }
175
176            // Normalize votes to probabilities
177            for (_, vote) in class_votes.iter_mut() {
178                *vote /= total_weight;
179            }
180
181            // Find most likely class and confidence
182            let (best_class, best_confidence) = class_votes
183                .iter()
184                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
185                .map(|(&class, &conf)| (class, conf))
186                .unwrap_or((classes[0], 0.0));
187
188            predictions[i] = best_class;
189            confidences[i] = best_confidence;
190        }
191
192        (predictions, confidences)
193    }
194
195    fn democratic_vote(
196        &self,
197        predictions: &[Array1<i32>],
198        confidences: &[Array1<f64>],
199        classes: &[i32],
200    ) -> Vec<(usize, i32, f64)> {
201        let n_samples = predictions[0].len();
202        let n_classifiers = predictions.len();
203        let mut candidates = Vec::new();
204
205        for i in 0..n_samples {
206            // Count votes for each class
207            let mut class_votes: HashMap<i32, usize> = HashMap::new();
208            let mut total_confidence = 0.0;
209            let mut voting_classifiers = 0;
210
211            for (classifier_idx, (pred, conf)) in
212                predictions.iter().zip(confidences.iter()).enumerate()
213            {
214                if conf[i] >= self.confidence_threshold {
215                    *class_votes.entry(pred[i]).or_insert(0) += 1;
216                    total_confidence += conf[i];
217                    voting_classifiers += 1;
218                }
219            }
220
221            // Find the class with most votes
222            if let Some((&winning_class, &vote_count)) =
223                class_votes.iter().max_by_key(|(_, &count)| count)
224            {
225                // Check if there's sufficient agreement
226                // Use the minimum of min_agreement and available voting classifiers for more flexible agreement
227                let required_agreement = self.min_agreement.min(voting_classifiers);
228                if vote_count >= required_agreement && voting_classifiers >= 1 {
229                    let avg_confidence = total_confidence / voting_classifiers as f64;
230
231                    // Additional consensus measure: fraction of voting classifiers that agree
232                    let consensus = vote_count as f64 / voting_classifiers as f64;
233                    // Weight broader agreement: give bonus for more voting classifiers
234                    let agreement_bonus = (voting_classifiers as f64).ln() + 1.0;
235                    let combined_score = avg_confidence * consensus * agreement_bonus;
236
237                    candidates.push((i, winning_class, combined_score));
238                }
239            }
240        }
241
242        // Sort by combined score (confidence * consensus)
243        candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
244        candidates
245    }
246}
247
248impl Default for DemocraticCoLearning<Untrained> {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254impl Estimator for DemocraticCoLearning<Untrained> {
255    type Config = ();
256    type Error = SklearsError;
257    type Float = Float;
258
259    fn config(&self) -> &Self::Config {
260        &()
261    }
262}
263
264impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for DemocraticCoLearning<Untrained> {
265    type Fitted = DemocraticCoLearning<DemocraticCoLearningTrained>;
266
267    #[allow(non_snake_case)]
268    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
269        let X = X.to_owned();
270        let mut y = y.to_owned();
271
272        // Validate views
273        if self.views.len() < 2 {
274            return Err(SklearsError::InvalidInput(
275                "Democratic co-learning requires at least 2 views".to_string(),
276            ));
277        }
278
279        if self.min_agreement > self.views.len() {
280            return Err(SklearsError::InvalidInput(
281                "min_agreement cannot be greater than number of views".to_string(),
282            ));
283        }
284
285        for (view_idx, view) in self.views.iter().enumerate() {
286            if view.is_empty() {
287                return Err(SklearsError::InvalidInput(format!(
288                    "View {} cannot be empty",
289                    view_idx
290                )));
291            }
292            for &feature_idx in view {
293                if feature_idx >= X.ncols() {
294                    return Err(SklearsError::InvalidInput(format!(
295                        "Feature index {} in view {} is out of bounds",
296                        feature_idx, view_idx
297                    )));
298                }
299            }
300        }
301
302        // Identify labeled and unlabeled samples
303        let mut labeled_mask = Array1::from_elem(y.len(), false);
304        let mut classes = HashSet::new();
305
306        for (i, &label) in y.iter().enumerate() {
307            if label != -1 {
308                labeled_mask[i] = true;
309                classes.insert(label);
310            }
311        }
312
313        if labeled_mask.iter().all(|&x| !x) {
314            return Err(SklearsError::InvalidInput(
315                "No labeled samples provided".to_string(),
316            ));
317        }
318
319        let classes: Vec<i32> = classes.into_iter().collect();
320
321        // Democratic co-learning iterations
322        for iter in 0..self.max_iter {
323            let labeled_indices: Vec<usize> = labeled_mask
324                .iter()
325                .enumerate()
326                .filter(|(_, &is_labeled)| is_labeled)
327                .map(|(i, _)| i)
328                .collect();
329
330            let unlabeled_indices: Vec<usize> = labeled_mask
331                .iter()
332                .enumerate()
333                .filter(|(_, &is_labeled)| !is_labeled)
334                .map(|(i, _)| i)
335                .collect();
336
337            if unlabeled_indices.is_empty() {
338                if self.verbose {
339                    println!("Iteration {}: All samples labeled", iter + 1);
340                }
341                break;
342            }
343
344            // Train classifiers on each view
345            let mut all_predictions = Vec::new();
346            let mut all_confidences = Vec::new();
347
348            for view in &self.views {
349                // Extract view data
350                let X_view = self.extract_view(&X, view)?;
351
352                // Extract labeled training data for this view
353                let X_labeled: Vec<Vec<f64>> = labeled_indices
354                    .iter()
355                    .map(|&i| X_view.row(i).to_vec())
356                    .collect();
357                let y_labeled: Array1<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
358
359                let X_labeled = Array2::from_shape_vec(
360                    (X_labeled.len(), view.len()),
361                    X_labeled.into_iter().flatten().collect(),
362                )
363                .map_err(|_| {
364                    SklearsError::InvalidInput("Failed to create labeled training data".to_string())
365                })?;
366
367                // Extract unlabeled data for this view
368                let X_unlabeled: Vec<Vec<f64>> = unlabeled_indices
369                    .iter()
370                    .map(|&i| X_view.row(i).to_vec())
371                    .collect();
372
373                let X_unlabeled = Array2::from_shape_vec(
374                    (X_unlabeled.len(), view.len()),
375                    X_unlabeled.into_iter().flatten().collect(),
376                )
377                .map_err(|_| {
378                    SklearsError::InvalidInput("Failed to create unlabeled data".to_string())
379                })?;
380
381                // Train classifier and get predictions
382                let (predictions, confidences) =
383                    self.train_classifier(&X_labeled, &y_labeled, &X_unlabeled, &classes);
384                all_predictions.push(predictions);
385                all_confidences.push(confidences);
386            }
387
388            // Democratic voting to select samples to add
389            let candidates = self.democratic_vote(&all_predictions, &all_confidences, &classes);
390
391            if candidates.is_empty() {
392                if self.verbose {
393                    println!(
394                        "Iteration {}: No agreed-upon confident predictions, stopping",
395                        iter + 1
396                    );
397                }
398                break;
399            }
400
401            // Select top k_add samples based on strategy
402            let selected_count = candidates.len().min(self.k_add);
403            let mut added_count = 0;
404
405            for (candidate_idx, predicted_label, _score) in
406                candidates.into_iter().take(selected_count)
407            {
408                let original_idx = unlabeled_indices[candidate_idx];
409                y[original_idx] = predicted_label;
410                labeled_mask[original_idx] = true;
411                added_count += 1;
412            }
413
414            if added_count == 0 {
415                if self.verbose {
416                    println!("Iteration {}: No samples added, stopping", iter + 1);
417                }
418                break;
419            }
420
421            if self.verbose {
422                let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
423                println!(
424                    "Iteration {}: {} samples added, {} total labeled",
425                    iter + 1,
426                    added_count,
427                    n_labeled
428                );
429            }
430        }
431
432        Ok(DemocraticCoLearning {
433            state: DemocraticCoLearningTrained {
434                X_train: X.clone(),
435                y_train: y,
436                classes: Array1::from(classes),
437                labeled_mask,
438                views: self.views.clone(),
439            },
440            views: self.views,
441            k_add: self.k_add,
442            max_iter: self.max_iter,
443            confidence_threshold: self.confidence_threshold,
444            min_agreement: self.min_agreement,
445            verbose: self.verbose,
446            selection_strategy: self.selection_strategy,
447        })
448    }
449}
450
451impl DemocraticCoLearning<DemocraticCoLearningTrained> {
452    fn extract_view(&self, X: &Array2<f64>, view_features: &[usize]) -> SklResult<Array2<f64>> {
453        if view_features.is_empty() {
454            return Err(SklearsError::InvalidInput(
455                "View features cannot be empty".to_string(),
456            ));
457        }
458
459        let n_samples = X.nrows();
460        let n_features = view_features.len();
461        let mut view_X = Array2::zeros((n_samples, n_features));
462
463        for (new_j, &old_j) in view_features.iter().enumerate() {
464            if old_j >= X.ncols() {
465                return Err(SklearsError::InvalidInput(format!(
466                    "Feature index {} out of bounds",
467                    old_j
468                )));
469            }
470            for i in 0..n_samples {
471                view_X[[i, new_j]] = X[[i, old_j]];
472            }
473        }
474
475        Ok(view_X)
476    }
477}
478
479impl Predict<ArrayView2<'_, Float>, Array1<i32>>
480    for DemocraticCoLearning<DemocraticCoLearningTrained>
481{
482    #[allow(non_snake_case)]
483    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
484        let X = X.to_owned();
485        let n_test = X.nrows();
486        let mut predictions = Array1::zeros(n_test);
487
488        // Get labeled training samples
489        let labeled_indices: Vec<usize> = self
490            .state
491            .labeled_mask
492            .iter()
493            .enumerate()
494            .filter(|(_, &is_labeled)| is_labeled)
495            .map(|(i, _)| i)
496            .collect();
497
498        // Ensemble prediction using all views
499        for i in 0..n_test {
500            // Check if this test sample matches a training sample exactly
501            let mut found_exact_match = false;
502            for j in 0..self.state.X_train.nrows() {
503                if i < self.state.X_train.nrows() {
504                    // If test sample index is within training range, check for exact match
505                    let diff = &X.row(i) - &self.state.X_train.row(j);
506                    let distance = diff.mapv(|x| x * x).sum().sqrt();
507                    if distance < 1e-10 && i == j {
508                        // Exact match with training sample - preserve original label if labeled
509                        if self.state.labeled_mask[j] {
510                            predictions[i] = self.state.y_train[j];
511                            found_exact_match = true;
512                            break;
513                        }
514                    }
515                }
516            }
517
518            if !found_exact_match {
519                let mut class_votes: HashMap<i32, f64> = HashMap::new();
520                let mut total_weight = 0.0;
521
522                // Get prediction from each view
523                for view in &self.state.views {
524                    // Extract view data
525                    let X_view_train = self.extract_view(&self.state.X_train, view)?;
526                    let X_view_test = self.extract_view(&X, view)?;
527
528                    // Extract labeled training data for this view
529                    let X_labeled: Vec<Vec<f64>> = labeled_indices
530                        .iter()
531                        .map(|&idx| X_view_train.row(idx).to_vec())
532                        .collect();
533                    let y_labeled: Array1<i32> = labeled_indices
534                        .iter()
535                        .map(|&idx| self.state.y_train[idx])
536                        .collect();
537
538                    let X_labeled = Array2::from_shape_vec(
539                        (X_labeled.len(), view.len()),
540                        X_labeled.into_iter().flatten().collect(),
541                    )
542                    .map_err(|_| {
543                        SklearsError::InvalidInput("Failed to create training data".to_string())
544                    })?;
545
546                    // Predict for single test sample
547                    let test_sample = X_view_test
548                        .row(i)
549                        .to_owned()
550                        .insert_axis(scirs2_core::ndarray::Axis(0));
551                    let mut distances: Vec<(f64, i32)> = Vec::new();
552
553                    for j in 0..X_labeled.nrows() {
554                        let diff = &test_sample.row(0) - &X_labeled.row(j);
555                        let dist = diff.mapv(|x| x * x).sum().sqrt();
556                        distances.push((dist, y_labeled[j]));
557                    }
558
559                    distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
560
561                    let k = distances.len().clamp(1, 5);
562                    let mut view_votes: HashMap<i32, f64> = HashMap::new();
563                    let mut view_weight = 0.0;
564
565                    for &(dist, label) in distances.iter().take(k) {
566                        let weight = if dist > 0.0 { 1.0 / (1.0 + dist) } else { 1.0 };
567                        *view_votes.entry(label).or_insert(0.0) += weight;
568                        view_weight += weight;
569                    }
570
571                    // Normalize and add to ensemble votes
572                    for (class, vote) in view_votes {
573                        let normalized_vote = vote / view_weight;
574                        *class_votes.entry(class).or_insert(0.0) += normalized_vote;
575                    }
576                    total_weight += 1.0; // Each view gets equal weight
577                }
578
579                // Find majority vote
580                let best_class = class_votes
581                    .iter()
582                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
583                    .map(|(&class, _)| class)
584                    .unwrap_or(self.state.classes[0]);
585
586                predictions[i] = best_class;
587            }
588        }
589
590        Ok(predictions)
591    }
592}
593
594/// Trained state for DemocraticCoLearning
595#[derive(Debug, Clone)]
596pub struct DemocraticCoLearningTrained {
597    /// X_train
598    pub X_train: Array2<f64>,
599    /// y_train
600    pub y_train: Array1<i32>,
601    /// classes
602    pub classes: Array1<i32>,
603    /// labeled_mask
604    pub labeled_mask: Array1<bool>,
605    /// views
606    pub views: Vec<Vec<usize>>,
607}
608
// `X` / `X_train` follow the scikit-learn naming convention.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    /// Fitting with two views keeps the original labels of labeled samples.
    #[test]
    fn test_democratic_co_learning_basic() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
            [5.0, 6.0, 7.0, 8.0],
            [6.0, 7.0, 8.0, 9.0]
        ];
        let y = array![0, 1, -1, -1, -1, -1]; // -1 indicates unlabeled

        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(5);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());

        // Check that labeled samples maintain their labels
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    /// Every builder method stores the value it is given.
    #[test]
    fn test_democratic_co_learning_parameters() {
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0], vec![1]])
            .k_add(2)
            .max_iter(10)
            .confidence_threshold(0.8)
            .min_agreement(1)
            .verbose(true)
            .selection_strategy("confidence".to_string());

        assert_eq!(dcl.k_add, 2);
        assert_eq!(dcl.max_iter, 10);
        assert_eq!(dcl.confidence_threshold, 0.8);
        assert_eq!(dcl.min_agreement, 1);
        assert!(dcl.verbose);
        assert_eq!(dcl.selection_strategy, "confidence");
    }

    /// Invalid configurations and inputs are rejected by `fit`.
    #[test]
    fn test_democratic_co_learning_error_cases() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0, 1];

        // Test with insufficient views
        let dcl = DemocraticCoLearning::new().views(vec![vec![0]]);
        let result = dcl.fit(&X.view(), &y.view());
        assert!(result.is_err());

        // Test with min_agreement > number of views
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0], vec![1]])
            .min_agreement(3);
        let result = dcl.fit(&X.view(), &y.view());
        assert!(result.is_err());

        // Test with empty view
        let dcl = DemocraticCoLearning::new().views(vec![vec![], vec![1]]);
        let result = dcl.fit(&X.view(), &y.view());
        assert!(result.is_err());

        // Test with out of bounds feature index
        let dcl = DemocraticCoLearning::new().views(vec![vec![0], vec![5]]);
        let result = dcl.fit(&X.view(), &y.view());
        assert!(result.is_err());

        // Test with no labeled samples
        let y_unlabeled = array![-1, -1];
        let dcl = DemocraticCoLearning::new().views(vec![vec![0], vec![1]]);
        let result = dcl.fit(&X.view(), &y_unlabeled.view());
        assert!(result.is_err());
    }

    /// Three views work the same way as two.
    #[test]
    fn test_democratic_co_learning_with_three_views() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            [3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            [5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        ];
        let y = array![0, 1, -1, -1, -1];

        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3], vec![4, 5]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(3);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());

        // Check that labeled samples maintain their labels
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    /// With nothing to pseudo-label, fitting still succeeds.
    #[test]
    fn test_democratic_co_learning_all_labeled() {
        let X = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![0, 1]; // All labeled

        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(5);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    /// Column extraction copies the requested columns and rejects bad input.
    #[test]
    fn test_extract_view() {
        let X = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let dcl = DemocraticCoLearning::new();

        // Extract first two features
        let view = dcl.extract_view(&X, &[0, 1]).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view[[0, 0]], 1.0);
        assert_eq!(view[[0, 1]], 2.0);
        assert_eq!(view[[1, 0]], 5.0);
        assert_eq!(view[[1, 1]], 6.0);

        // Extract last two features
        let view = dcl.extract_view(&X, &[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view[[0, 0]], 3.0);
        assert_eq!(view[[0, 1]], 4.0);
        assert_eq!(view[[1, 0]], 7.0);
        assert_eq!(view[[1, 1]], 8.0);

        // Test error case with out of bounds index
        let result = dcl.extract_view(&X, &[5]);
        assert!(result.is_err());

        // Test error case with empty view
        let result = dcl.extract_view(&X, &[]);
        assert!(result.is_err());
    }

    /// The unanimously and confidently predicted sample is ranked first.
    #[test]
    fn test_democratic_vote() {
        let dcl = DemocraticCoLearning::new()
            .confidence_threshold(0.5)
            .min_agreement(2);

        let predictions = vec![array![0, 1, 0], array![0, 1, 1], array![0, 0, 0]];
        let confidences = vec![
            array![0.8, 0.9, 0.6],
            array![0.7, 0.8, 0.7],
            array![0.6, 0.4, 0.8], // Third classifier is below threshold on sample 1
        ];
        let classes = vec![0, 1];

        let candidates = dcl.democratic_vote(&predictions, &confidences, &classes);

        // Sample 0 has three agreeing voters. Sample 1 has two agreeing voters
        // (the third is below the confidence threshold). Sample 2 has a 2-of-3
        // majority. At least the unanimous sample must be a candidate.
        assert!(!candidates.is_empty());

        // First candidate should be sample 0 with class 0 (all agree)
        let (sample_idx, predicted_class, _score) = candidates[0];
        assert_eq!(sample_idx, 0);
        assert_eq!(predicted_class, 0);
    }

    /// Predictions are known classes and confidences are valid shares.
    #[test]
    fn test_train_classifier() {
        let dcl = DemocraticCoLearning::new();

        let X_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y_train = array![0, 1, 0];
        let X_test = array![[2.0, 3.0], [4.0, 5.0]];
        let classes = vec![0, 1];

        let (predictions, confidences) =
            dcl.train_classifier(&X_train, &y_train, &X_test, &classes);

        assert_eq!(predictions.len(), 2);
        assert_eq!(confidences.len(), 2);

        // All predictions should be valid class labels
        for pred in predictions.iter() {
            assert!(classes.contains(pred));
        }

        // All confidences should be in [0, 1]
        for conf in confidences.iter() {
            assert!((0.0..=1.0).contains(conf));
        }
    }

    /// A near-unreachable threshold must not disturb the original labels.
    #[test]
    fn test_democratic_co_learning_high_confidence_threshold() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0]
        ];
        let y = array![0, 1, -1, -1];

        // High confidence threshold should make it harder to add samples
        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3]])
            .k_add(1)
            .min_agreement(2)
            .confidence_threshold(0.99) // Very high threshold
            .max_iter(2);

        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), X.nrows());

        // Should still maintain labeled samples
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }
}
855        assert_eq!(predictions[1], 1);
856    }
857}