sklears_semi_supervised/contrastive_learning/supervised_contrastive.rs

//! Supervised Contrastive Learning implementation for semi-supervised scenarios

use super::{ContrastiveLearningError, *};
use scirs2_core::random::Rng;
/// Supervised Contrastive Learning for semi-supervised scenarios
///
/// This method extends contrastive learning to utilize both labeled and unlabeled data
/// by pulling together samples from the same class while pushing apart samples from different classes.
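///
/// # Example
///
/// A minimal usage sketch (marked `ignore` because the exact import paths for
/// the `Fit`/`Predict` traits are assumed to come from the crate prelude):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let model = SupervisedContrastiveLearning::new()
///     .embedding_dim(8)
///     .temperature(0.1)?
///     .max_epochs(10)
///     .random_state(0);
///
/// let fitted = model.fit(&X.view(), &y.view())?;
/// let predictions = fitted.predict(&X.view())?;
/// ```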
#[derive(Debug, Clone)]
pub struct SupervisedContrastiveLearning {
    /// Dimensionality of the learned embedding space
    pub embedding_dim: usize,
    /// Temperature scaling applied to similarity scores in the contrastive loss
    pub temperature: f64,
    /// Learning rate for the encoder weight updates
    pub learning_rate: f64,
    /// Number of samples per training batch
    pub batch_size: usize,
    /// Maximum number of training epochs
    pub max_epochs: usize,
    /// Strength of the Gaussian-noise augmentation, in `[0.0, 1.0]`
    pub augmentation_strength: f64,
    /// Relative weight intended for labeled samples in the loss
    pub labeled_weight: f64,
    /// Seed for the random number generator; `None` falls back to a fixed seed
    pub random_state: Option<u64>,
}

impl Default for SupervisedContrastiveLearning {
    fn default() -> Self {
        Self {
            embedding_dim: 128,
            temperature: 0.07,
            learning_rate: 0.001,
            batch_size: 32,
            max_epochs: 100,
            augmentation_strength: 0.5,
            labeled_weight: 2.0,
            random_state: None,
        }
    }
}

impl SupervisedContrastiveLearning {
    /// Creates a model with the default hyperparameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the embedding dimensionality.
    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    /// Sets the temperature; errors if it is not strictly positive.
    pub fn temperature(mut self, temperature: f64) -> Result<Self> {
        if temperature <= 0.0 {
            return Err(ContrastiveLearningError::InvalidTemperature(temperature).into());
        }
        self.temperature = temperature;
        Ok(self)
    }

    /// Sets the learning rate.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the batch size; errors if it is zero.
    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(ContrastiveLearningError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    /// Sets the maximum number of training epochs.
    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
        self.max_epochs = max_epochs;
        self
    }

    /// Sets the augmentation strength; errors if it is outside `[0.0, 1.0]`.
    pub fn augmentation_strength(mut self, augmentation_strength: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&augmentation_strength) {
            return Err(ContrastiveLearningError::InvalidAugmentationStrength(
                augmentation_strength,
            )
            .into());
        }
        self.augmentation_strength = augmentation_strength;
        Ok(self)
    }

    /// Sets the weight for labeled samples.
    pub fn labeled_weight(mut self, labeled_weight: f64) -> Self {
        self.labeled_weight = labeled_weight;
        self
    }

    /// Sets the random seed for reproducible training.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

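    /// Draws one standard-normal sample via the Box-Muller transform.
    ///
    /// Extracted into a helper because the same sampling loop previously
    /// appeared three times in this file; `u1` is mapped into `(0, 1]` so
    /// that `u1.ln()` can never be evaluated at zero.
    fn sample_standard_normal<R: Rng>(rng: &mut Random<R>) -> f64 {
        let u1: f64 = 1.0 - rng.gen_range(0.0..1.0); // uniform in (0.0, 1.0]
        let u2: f64 = rng.gen_range(0.0..1.0);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
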
    fn augment_data<R>(&self, X: &ArrayView2<f64>, rng: &mut Random<R>) -> Result<Array2<f64>>
    where
        R: Rng,
    {
        let (n_samples, n_features) = X.dim();

        // Gaussian-noise augmentation: perturb every feature with zero-mean
        // noise whose standard deviation scales with the augmentation strength.
        let noise_std = self.augmentation_strength * 0.1;
        let mut noise = Array2::<f64>::zeros((n_samples, n_features));
        for i in 0..n_samples {
            for j in 0..n_features {
                noise[(i, j)] = Self::sample_standard_normal(rng) * noise_std;
            }
        }

        Ok(X.to_owned() + noise)
    }

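    /// Computes a SupCon-style loss over the labeled portion of the batch.
    ///
    /// For each labeled anchor `i`, with positives `P(i)` (the other samples
    /// sharing its class) and candidates `A(i)` (all other labeled samples),
    /// the per-anchor term is
    /// `-ln( sum_{p in P(i)} exp(s_ip) / sum_{a in A(i)} exp(s_ia) )`
    /// with `s_ij = z_i . z_j / temperature`. Unlabeled samples (label `-1`)
    /// are skipped, and the result is averaged over anchors that have at
    /// least one positive. Note that `labeled_weight` is not yet applied here.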
    fn compute_supervised_contrastive_loss(
        &self,
        embeddings: &ArrayView2<f64>,
        labels: &ArrayView1<i32>,
    ) -> Result<f64> {
        let n_samples = embeddings.dim().0;
        let mut total_loss = 0.0;
        let mut n_labeled = 0;

        for i in 0..n_samples {
            if labels[i] == -1 {
                continue; // Skip unlabeled samples
            }

            let anchor = embeddings.row(i);
            let anchor_label = labels[i];

            let mut positive_scores = Vec::new();
            let mut negative_scores = Vec::new();

            for j in 0..n_samples {
                if i == j {
                    continue;
                }

                let sample = embeddings.row(j);
                let score = anchor.dot(&sample) / self.temperature;

                // `anchor_label` is never -1 here, so an equal label means a
                // labeled positive; any other labeled sample is a negative.
                if labels[j] == anchor_label {
                    positive_scores.push(score);
                } else if labels[j] != -1 {
                    negative_scores.push(score);
                }
            }

            if positive_scores.is_empty() {
                continue;
            }

            // Log-sum-exp trick: subtract the maximum score before
            // exponentiating for numerical stability.
            let max_score = positive_scores
                .iter()
                .chain(negative_scores.iter())
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);

            let mut pos_exp_sum = 0.0;
            for &score in positive_scores.iter() {
                pos_exp_sum += (score - max_score).exp();
            }

            let mut all_exp_sum = pos_exp_sum;
            for &score in negative_scores.iter() {
                all_exp_sum += (score - max_score).exp();
            }

            if all_exp_sum > 0.0 {
                let loss = -(pos_exp_sum / all_exp_sum).ln();
                total_loss += loss;
                n_labeled += 1;
            }
        }

        if n_labeled > 0 {
            Ok(total_loss / n_labeled as f64)
        } else {
            Ok(0.0)
        }
    }
}

/// Fitted Supervised Contrastive Learning model
#[derive(Debug, Clone)]
pub struct FittedSupervisedContrastiveLearning {
    /// Hyperparameters of the model that produced this fit
    pub base_model: SupervisedContrastiveLearning,
    /// Learned linear encoder weights, shape `(n_features, embedding_dim)`
    pub encoder_weights: Array2<f64>,
    /// Class labels seen during fitting, in sorted order
    pub classes: Array1<i32>,
    /// Number of distinct classes
    pub n_classes: usize,
    /// Per-class centroids in the embedding space, shape `(n_classes, embedding_dim)`
    pub class_centroids: Array2<f64>,
}

impl Estimator for SupervisedContrastiveLearning {
    type Config = SupervisedContrastiveLearning;
    type Error = ContrastiveLearningError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        self
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for SupervisedContrastiveLearning {
    type Fitted = FittedSupervisedContrastiveLearning;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = X.dim();

        // Contrastive pairs require at least two labeled samples
        let labeled_count = y.iter().filter(|&&label| label != -1).count();
        if labeled_count < 2 {
            return Err(ContrastiveLearningError::InsufficientLabeledSamples.into());
        }

        // Fall back to a fixed seed when no random_state is provided
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        // Initialize encoder weights from N(0, 0.1)
        let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
        for i in 0..n_features {
            for j in 0..self.embedding_dim {
                encoder_weights[(i, j)] = Self::sample_standard_normal(&mut rng) * 0.1;
            }
        }

        // Collect the distinct class labels; sorting keeps the class order
        // (and hence the column order in predict_proba) deterministic.
        let mut unique_classes: Vec<i32> = y
            .iter()
            .cloned()
            .filter(|&label| label != -1)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        unique_classes.sort_unstable();
        let n_classes = unique_classes.len();

        // Training loop
        for epoch in 0..self.max_epochs {
            // Augment data
            let X_aug = self.augment_data(X, &mut rng)?;

            // Encode data with the linear encoder
            let embeddings = X_aug.dot(&encoder_weights);

            // Compute supervised contrastive loss
            let loss = self.compute_supervised_contrastive_loss(&embeddings.view(), y)?;

            // Placeholder update: perturb the weights with random noise scaled
            // by the loss; a full implementation would backpropagate through
            // the contrastive loss instead.
            let gradient_scale = self.learning_rate * loss;
            let noise_std = gradient_scale * 0.1;
            let mut encoder_grad = Array2::<f64>::zeros(encoder_weights.dim());
            for i in 0..encoder_weights.nrows() {
                for j in 0..encoder_weights.ncols() {
                    encoder_grad[(i, j)] = Self::sample_standard_normal(&mut rng) * noise_std;
                }
            }
            encoder_weights = encoder_weights - encoder_grad;

            if epoch % 10 == 0 {
                println!("Epoch {}: Loss = {:.6}", epoch, loss);
            }
        }

        // Compute class centroids in the final embedding space
        let final_embeddings = X.dot(&encoder_weights);

        let mut class_centroids = Array2::zeros((n_classes, self.embedding_dim));
        let mut class_counts = vec![0; n_classes];

        for i in 0..n_samples {
            if y[i] != -1 {
                if let Some(class_idx) = unique_classes.iter().position(|&c| c == y[i]) {
                    for j in 0..self.embedding_dim {
                        class_centroids[[class_idx, j]] += final_embeddings[[i, j]];
                    }
                    class_counts[class_idx] += 1;
                }
            }
        }

        // Average the accumulated sums to obtain the centroids
        for class_idx in 0..n_classes {
            if class_counts[class_idx] > 0 {
                for j in 0..self.embedding_dim {
                    class_centroids[[class_idx, j]] /= class_counts[class_idx] as f64;
                }
            }
        }

        Ok(FittedSupervisedContrastiveLearning {
            base_model: self.clone(),
            encoder_weights,
            classes: Array1::from_vec(unique_classes),
            n_classes,
            class_centroids,
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedSupervisedContrastiveLearning {
    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
        let embeddings = X.dot(&self.encoder_weights);

        let n_samples = X.dim().0;
        let mut predictions = Array1::zeros(n_samples);

        // Assign each sample to the class of its nearest centroid
        // (Euclidean distance in the embedding space).
        for i in 0..n_samples {
            let embedding = embeddings.row(i);
            let mut best_class = self.classes[0];
            let mut best_distance = f64::INFINITY;

            for (class_idx, &class) in self.classes.iter().enumerate() {
                let centroid = self.class_centroids.row(class_idx);
                let distance = embedding
                    .iter()
                    .zip(centroid.iter())
                    .map(|(e, c)| (e - c).powi(2))
                    .sum::<f64>()
                    .sqrt();

                if distance < best_distance {
                    best_distance = distance;
                    best_class = class;
                }
            }

            predictions[i] = best_class;
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedSupervisedContrastiveLearning {
    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
        let embeddings = X.dot(&self.encoder_weights);

        let n_samples = X.dim().0;
        let mut probabilities = Array2::zeros((n_samples, self.n_classes));

        for i in 0..n_samples {
            let embedding = embeddings.row(i);
            let mut distances = Vec::new();

            for class_idx in 0..self.n_classes {
                let centroid = self.class_centroids.row(class_idx);
                let distance = embedding
                    .iter()
                    .zip(centroid.iter())
                    .map(|(e, c)| (e - c).powi(2))
                    .sum::<f64>()
                    .sqrt();
                distances.push(-distance); // Negative distance so closer centroids score higher
            }

            // Softmax over the negative distances (max-subtracted for stability)
            let max_distance = distances.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_distances: Vec<f64> = distances
                .iter()
                .map(|&d| (d - max_distance).exp())
                .collect();
            let sum_exp: f64 = exp_distances.iter().sum();

            for (j, &exp_dist) in exp_distances.iter().enumerate() {
                probabilities[[i, j]] = exp_dist / sum_exp;
            }
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    fn test_supervised_contrastive_learning_creation() {
        let scl = SupervisedContrastiveLearning::new()
            .embedding_dim(32)
            .temperature(0.05)
            .unwrap()
            .augmentation_strength(0.3)
            .unwrap()
            .labeled_weight(3.0)
            .random_state(42);

        assert_eq!(scl.embedding_dim, 32);
        assert_eq!(scl.temperature, 0.05);
        assert_eq!(scl.augmentation_strength, 0.3);
        assert_eq!(scl.labeled_weight, 3.0);
        assert_eq!(scl.random_state, Some(42));
    }

    #[test]
    fn test_supervised_contrastive_learning_invalid_augmentation() {
        let result = SupervisedContrastiveLearning::new().augmentation_strength(1.5);
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_supervised_contrastive_learning_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![0, 1, 0, 1, -1, -1];

        let scl = SupervisedContrastiveLearning::new()
            .embedding_dim(4)
            .max_epochs(2)
            .batch_size(3)
            .unwrap()
            .random_state(42);

        let fitted = scl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 6);
        for &pred in predictions.iter() {
            assert!(pred == 0 || pred == 1);
        }

        let probabilities = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probabilities.dim(), (6, 2));

        // Check that probabilities sum to 1
        for i in 0..6 {
            let sum: f64 = probabilities.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // All unlabeled

        let scl = SupervisedContrastiveLearning::new();
        let result = scl.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }
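
    // Additional sanity checks, added as hedged sketches: they exercise the
    // Box-Muller helper above and the seeded-RNG determinism assumption.

    #[test]
    fn test_standard_normal_sampler_moments() {
        // With a fixed seed, draw many samples and check that the empirical
        // mean and variance are close to 0 and 1 respectively.
        let mut rng = Random::seed(7);
        let n = 10_000;
        let samples: Vec<f64> = (0..n)
            .map(|_| SupervisedContrastiveLearning::sample_standard_normal(&mut rng))
            .collect();

        let mean = samples.iter().sum::<f64>() / n as f64;
        let var = samples.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / n as f64;

        assert!(mean.abs() < 0.05, "mean = {}", mean);
        assert!((var - 1.0).abs() < 0.1, "var = {}", var);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_fit_is_deterministic_with_fixed_seed() {
        // Two fits with the same random_state should yield identical
        // predictions: the class list is sorted and all randomness flows
        // through the seeded RNG.
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, 0, -1];

        let build = || {
            SupervisedContrastiveLearning::new()
                .embedding_dim(4)
                .max_epochs(3)
                .random_state(11)
        };

        let p1 = build().fit(&X.view(), &y.view()).unwrap().predict(&X.view()).unwrap();
        let p2 = build().fit(&X.view(), &y.view()).unwrap().predict(&X.view()).unwrap();
        assert_eq!(p1, p2);
    }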
}