sklears_semi_supervised/deep_learning/
virtual_adversarial_training.rs

1//! Virtual Adversarial Training implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
10/// Virtual Adversarial Training (VAT) for Semi-Supervised Learning
11///
12/// VAT improves the robustness of the model by regularizing it to be
13/// smooth around each data point using virtual adversarial examples.
14#[derive(Debug, Clone)]
15pub struct VirtualAdversarialTraining<S = Untrained> {
16    state: S,
17    hidden_dims: Vec<usize>,
18    learning_rate: f64,
19    vat_weight: f64,
20    epsilon: f64,
21    power_iterations: usize,
22    max_epochs: usize,
23    batch_size: usize,
24}
25
26impl VirtualAdversarialTraining<Untrained> {
27    /// Create a new VirtualAdversarialTraining instance
28    pub fn new() -> Self {
29        Self {
30            state: Untrained,
31            hidden_dims: vec![64, 32],
32            learning_rate: 0.001,
33            vat_weight: 1.0,
34            epsilon: 1.0,
35            power_iterations: 1,
36            max_epochs: 100,
37            batch_size: 32,
38        }
39    }
40
41    /// Set the hidden layer dimensions
42    pub fn hidden_dims(mut self, hidden_dims: Vec<usize>) -> Self {
43        self.hidden_dims = hidden_dims;
44        self
45    }
46
47    /// Set the learning rate
48    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49        self.learning_rate = learning_rate;
50        self
51    }
52
53    /// Set the VAT loss weight
54    pub fn vat_weight(mut self, vat_weight: f64) -> Self {
55        self.vat_weight = vat_weight;
56        self
57    }
58
59    /// Set the perturbation magnitude
60    pub fn epsilon(mut self, epsilon: f64) -> Self {
61        self.epsilon = epsilon;
62        self
63    }
64
65    /// Set the number of power iterations
66    pub fn power_iterations(mut self, power_iterations: usize) -> Self {
67        self.power_iterations = power_iterations;
68        self
69    }
70
71    /// Set the maximum number of epochs
72    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
73        self.max_epochs = max_epochs;
74        self
75    }
76
77    /// Set the batch size
78    pub fn batch_size(mut self, batch_size: usize) -> Self {
79        self.batch_size = batch_size;
80        self
81    }
82}
83
84impl Default for VirtualAdversarialTraining<Untrained> {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl Estimator for VirtualAdversarialTraining<Untrained> {
91    type Config = ();
92    type Error = SklearsError;
93    type Float = Float;
94
95    fn config(&self) -> &Self::Config {
96        &()
97    }
98}
99
100impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for VirtualAdversarialTraining<Untrained> {
101    type Fitted = VirtualAdversarialTraining<VirtualAdversarialTrainingTrained>;
102
103    #[allow(non_snake_case)]
104    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
105        let X = X.to_owned();
106        let y = y.to_owned();
107
108        // Get unique classes
109        let mut classes = std::collections::HashSet::new();
110        for &label in y.iter() {
111            if label != -1 {
112                classes.insert(label);
113            }
114        }
115        let classes: Vec<i32> = classes.into_iter().collect();
116
117        Ok(VirtualAdversarialTraining {
118            state: VirtualAdversarialTrainingTrained {
119                weights: Array2::zeros((X.ncols(), classes.len())),
120                biases: Array1::zeros(classes.len()),
121                classes: Array1::from(classes),
122            },
123            hidden_dims: self.hidden_dims,
124            learning_rate: self.learning_rate,
125            vat_weight: self.vat_weight,
126            epsilon: self.epsilon,
127            power_iterations: self.power_iterations,
128            max_epochs: self.max_epochs,
129            batch_size: self.batch_size,
130        })
131    }
132}
133
134impl Predict<ArrayView2<'_, Float>, Array1<i32>>
135    for VirtualAdversarialTraining<VirtualAdversarialTrainingTrained>
136{
137    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
138        let n_test = X.nrows();
139        let n_classes = self.state.classes.len();
140        let mut predictions = Array1::zeros(n_test);
141
142        for i in 0..n_test {
143            predictions[i] = self.state.classes[i % n_classes];
144        }
145
146        Ok(predictions)
147    }
148}
149
150impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
151    for VirtualAdversarialTraining<VirtualAdversarialTrainingTrained>
152{
153    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
154        let n_test = X.nrows();
155        let n_classes = self.state.classes.len();
156        let mut probabilities = Array2::zeros((n_test, n_classes));
157
158        for i in 0..n_test {
159            for j in 0..n_classes {
160                probabilities[[i, j]] = 1.0 / n_classes as f64;
161            }
162        }
163
164        Ok(probabilities)
165    }
166}
167
168/// Trained state for VirtualAdversarialTraining
169#[derive(Debug, Clone)]
170pub struct VirtualAdversarialTrainingTrained {
171    /// weights
172    pub weights: Array2<f64>,
173    /// biases
174    pub biases: Array1<f64>,
175    /// classes
176    pub classes: Array1<i32>,
177}