sklears_semi_supervised/contrastive_learning/simclr.rs

//! SimCLR (A Simple Framework for Contrastive Learning) implementation for semi-supervised learning

use super::{ContrastiveLearningError, *};
use scirs2_core::random::{rand_prelude::SliceRandom, Rng};

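/// Draw one standard-normal sample via the Box-Muller transform.
///
/// This factors out the inline Box-Muller code repeated at each sampling site
/// below; `u1` is kept strictly above zero so that `ln(u1)` stays finite.
fn sample_standard_normal<R: Rng>(rng: &mut Random<R>) -> f64 {
    let u1: f64 = rng.random_range(f64::EPSILON..1.0);
    let u2: f64 = rng.random_range(0.0..1.0);
    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
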
/// SimCLR (A Simple Framework for Contrastive Learning) adaptation for semi-supervised learning
///
/// This implements SimCLR's contrastive learning approach adapted for semi-supervised scenarios.
/// It learns representations by maximizing agreement between differently augmented views of the
/// same data, while a supervised contrastive term exploits whatever labels are available.
///
/// # Parameters
///
/// * `projection_dim` - Dimensionality of the projection head (typically smaller than `embedding_dim`)
/// * `embedding_dim` - Dimensionality of the learned embeddings
/// * `temperature` - Temperature parameter for the contrastive loss
/// * `augmentation_strength` - Strength of data augmentation
/// * `batch_size` - Batch size for training
/// * `max_epochs` - Maximum number of training epochs
/// * `learning_rate` - Learning rate for optimization
/// * `momentum` - Momentum for exponential moving averages
/// * `labeled_weight` - Weight for the supervised contrastive loss component
/// * `random_state` - Random seed for reproducibility
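///
/// # Examples
///
/// A minimal usage sketch (assuming the crate's `array!` macro and the
/// `Fit`/`Predict` traits are in scope; `-1` marks unlabeled samples):
///
/// ```ignore
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1];
///
/// let fitted = SimCLR::new()
///     .embedding_dim(8)
///     .projection_dim(4)
///     .max_epochs(5)
///     .random_state(42)
///     .fit(&X.view(), &y.view())?;
/// let predictions = fitted.predict(&X.view())?;
/// ```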
#[derive(Debug, Clone)]
pub struct SimCLR {
    /// Dimensionality of the projection head output
    pub projection_dim: usize,
    /// Dimensionality of the learned embeddings
    pub embedding_dim: usize,
    /// Temperature parameter for the contrastive loss
    pub temperature: f64,
    /// Strength of data augmentation
    pub augmentation_strength: f64,
    /// Batch size for training
    pub batch_size: usize,
    /// Maximum number of training epochs
    pub max_epochs: usize,
    /// Learning rate for optimization
    pub learning_rate: f64,
    /// Momentum for exponential moving averages (currently unused by the training loop)
    pub momentum: f64,
    /// Weight for the supervised contrastive loss component
    pub labeled_weight: f64,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

impl Default for SimCLR {
    fn default() -> Self {
        Self {
            projection_dim: 64,
            embedding_dim: 128,
            temperature: 0.5,
            augmentation_strength: 0.2,
            batch_size: 32,
            max_epochs: 100,
            learning_rate: 0.001,
            momentum: 0.999,
            labeled_weight: 1.0,
            random_state: None,
        }
    }
}

impl SimCLR {
    /// Create a new SimCLR instance
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the projection head dimensionality
    pub fn projection_dim(mut self, projection_dim: usize) -> Self {
        self.projection_dim = projection_dim;
        self
    }

    /// Set the embedding dimensionality
    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    /// Set the temperature parameter
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    /// Set the augmentation strength
    pub fn augmentation_strength(mut self, strength: f64) -> Self {
        self.augmentation_strength = strength;
        self
    }

    /// Set the batch size
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the maximum number of epochs
    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
        self.max_epochs = max_epochs;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the momentum parameter
    pub fn momentum(mut self, momentum: f64) -> Self {
        self.momentum = momentum;
        self
    }

    /// Set the labeled weight
    pub fn labeled_weight(mut self, labeled_weight: f64) -> Self {
        self.labeled_weight = labeled_weight;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

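    /// Produce one stochastically augmented view of `x`.
    ///
    /// Two tabular augmentations are applied: additive Gaussian noise with
    /// standard deviation `augmentation_strength`, and feature dropout with
    /// probability `0.1 * augmentation_strength`.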
    fn apply_augmentation<R>(&self, x: &Array2<f64>, rng: &mut Random<R>) -> Array2<f64>
    where
        R: Rng,
    {
        let mut augmented = x.clone();

        // Additive Gaussian noise, scaled by the augmentation strength
        let mut noise = Array2::<f64>::zeros(x.dim());
        for element in noise.iter_mut() {
            *element = sample_standard_normal(rng) * self.augmentation_strength;
        }
        augmented += &noise;

        // Feature dropout: randomly zero out individual features
        let dropout_prob = 0.1 * self.augmentation_strength;
        for element in augmented.iter_mut() {
            if rng.gen::<f64>() < dropout_prob {
                *element = 0.0;
            }
        }

        augmented
    }

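    /// NT-Xent-style contrastive loss over a batch of paired projections.
    ///
    /// For each positive pair `(z_i, z_j)` the loss is
    /// `-ln(exp(sim(z_i, z_j)/t) / D)`, where the denominator `D` sums the
    /// exponentiated, temperature-scaled similarities of the positive pair
    /// and of both views against all other samples in the batch. Inputs are
    /// expected to be L2-normalized, so dot products are cosine similarities.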
    fn compute_simclr_loss(&self, z_i: &ArrayView2<f64>, z_j: &ArrayView2<f64>) -> Result<f64> {
        let batch_size = z_i.nrows();
        if batch_size == 0 {
            return Ok(0.0);
        }

        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let zi = z_i.row(i);
            let zj = z_j.row(i);

            // Exponentiated, temperature-scaled similarity of the positive pair
            let pos_score = (zi.dot(&zj) / self.temperature).exp();

            // Negative scores: all other samples in the batch, pooled over
            // both augmented views of both anchors
            let mut neg_sum = 0.0;
            for k in 0..batch_size {
                if k != i {
                    let zk_i = z_i.row(k);
                    let zk_j = z_j.row(k);

                    neg_sum += (zi.dot(&zk_i) / self.temperature).exp();
                    neg_sum += (zi.dot(&zk_j) / self.temperature).exp();
                    neg_sum += (zj.dot(&zk_i) / self.temperature).exp();
                    neg_sum += (zj.dot(&zk_j) / self.temperature).exp();
                }
            }

            // The positive pair is included in the softmax denominator
            neg_sum += pos_score;

            // Loss for this positive pair
            let loss = -(pos_score / neg_sum).ln();
            total_loss += loss;
        }

        // Normalize by the total number of augmented views (2N)
        Ok(total_loss / (2.0 * batch_size as f64))
    }

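    /// Pooled variant of the supervised contrastive (SupCon) objective.
    ///
    /// For each labeled anchor, similarities to same-class samples form the
    /// positive mass and similarities to other-class samples the negative
    /// mass; unlabeled samples (label `-1`) are skipped entirely.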
    fn compute_supervised_contrastive_loss(
        &self,
        embeddings: &ArrayView2<f64>,
        labels: &ArrayView1<i32>,
    ) -> Result<f64> {
        let batch_size = embeddings.nrows();
        let mut total_loss = 0.0;
        let mut valid_pairs = 0;

        for i in 0..batch_size {
            if labels[i] == -1 {
                continue; // Skip unlabeled anchors
            }

            let zi = embeddings.row(i);
            let mut pos_sum = 0.0;
            let mut neg_sum = 0.0;
            let mut pos_count = 0;

            for j in 0..batch_size {
                if i == j || labels[j] == -1 {
                    continue;
                }

                let zj = embeddings.row(j);
                let similarity = (zi.dot(&zj) / self.temperature).exp();

                if labels[i] == labels[j] {
                    pos_sum += similarity;
                    pos_count += 1;
                } else {
                    neg_sum += similarity;
                }
            }

            if pos_count > 0 {
                let loss = -(pos_sum / (pos_sum + neg_sum)).ln();
                total_loss += loss;
                valid_pairs += 1;
            }
        }

        if valid_pairs > 0 {
            Ok(total_loss / valid_pairs as f64)
        } else {
            Ok(0.0)
        }
    }

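    /// Row-wise L2 normalization; rows with near-zero norm are left unchanged
    /// to avoid division by zero.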
    fn l2_normalize(&self, x: &Array2<f64>) -> Array2<f64> {
        let mut normalized = x.clone();
        for mut row in normalized.axis_iter_mut(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            if norm > 1e-12 {
                row /= norm;
            }
        }
        normalized
    }
}

/// Fitted SimCLR model
#[derive(Debug, Clone)]
pub struct FittedSimCLR {
    /// The configuration the model was fitted with
    pub base_model: SimCLR,
    /// Learned encoder weights, shape `(n_features, embedding_dim)`
    pub encoder_weights: Array2<f64>,
    /// Learned projection head weights, shape `(embedding_dim, projection_dim)`
    pub projection_weights: Array2<f64>,
    /// Distinct class labels seen during fitting, in sorted order
    pub classes: Array1<i32>,
    /// Number of distinct classes
    pub n_classes: usize,
}

impl Estimator for SimCLR {
    type Config = SimCLR;
    type Error = ContrastiveLearningError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        self
    }
}

#[allow(non_snake_case)]
impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for SimCLR {
    type Fitted = FittedSimCLR;

    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = X.dim();

        // Fall back to a fixed seed when no random_state is given, so runs
        // stay reproducible
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        // Initialize encoder and projection head with small random weights
        // drawn from N(0, 0.02^2)
        let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
        let mut projection_weights =
            Array2::<f64>::zeros((self.embedding_dim, self.projection_dim));

        for w in encoder_weights.iter_mut() {
            *w = sample_standard_normal(&mut rng) * 0.02;
        }
        for w in projection_weights.iter_mut() {
            *w = sample_standard_normal(&mut rng) * 0.02;
        }

        // Collect the unique labeled classes, sorted so that the class order
        // (and therefore prediction output) is deterministic
        let mut unique_classes: Vec<i32> = y
            .iter()
            .cloned()
            .filter(|&label| label != -1)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        unique_classes.sort_unstable();
        let n_classes = unique_classes.len();

        // Training loop
        for _epoch in 0..self.max_epochs {
            let mut epoch_loss = 0.0;
            let mut n_batches = 0;

            // Shuffle sample indices into batches
            let mut indices: Vec<usize> = (0..n_samples).collect();
            indices.shuffle(&mut rng);

            for batch_start in (0..n_samples).step_by(self.batch_size) {
                let batch_end = std::cmp::min(batch_start + self.batch_size, n_samples);
                let batch_size = batch_end - batch_start;

                // Contrastive losses need at least one negative per anchor
                if batch_size < 2 {
                    continue;
                }

                let batch_indices = &indices[batch_start..batch_end];

                // Extract batch data
                let mut batch_X = Array2::zeros((batch_size, n_features));
                let mut batch_y = Array1::zeros(batch_size);

                for (i, &idx) in batch_indices.iter().enumerate() {
                    batch_X.row_mut(i).assign(&X.row(idx));
                    batch_y[i] = y[idx];
                }

                // Generate two augmented views
                let X_aug1 = self.apply_augmentation(&batch_X, &mut rng);
                let X_aug2 = self.apply_augmentation(&batch_X, &mut rng);

                // Forward pass through encoder and projection head
                let h1 = X_aug1.dot(&encoder_weights);
                let h2 = X_aug2.dot(&encoder_weights);
                let z1 = h1.dot(&projection_weights);
                let z2 = h2.dot(&projection_weights);

                // L2-normalize projections so dot products are cosine similarities
                let z1_norm = self.l2_normalize(&z1);
                let z2_norm = self.l2_normalize(&z2);

                // Unsupervised SimCLR loss on the projections
                let simclr_loss = self.compute_simclr_loss(&z1_norm.view(), &z2_norm.view())?;

                // Supervised contrastive loss on the (unnormalized) encoder
                // embeddings of the first view, using only labeled samples
                let supervised_loss = if n_classes > 0 {
                    self.compute_supervised_contrastive_loss(&h1.view(), &batch_y.view())?
                } else {
                    0.0
                };

                // Combined loss
                let total_loss = simclr_loss + self.labeled_weight * supervised_loss;
                epoch_loss += total_loss;
                n_batches += 1;

                // Simulated gradient step: perturb the weights with small,
                // loss-scaled random noise (a placeholder for proper
                // backpropagation through the encoder and projection head)
                let gradient_scale = self.learning_rate * total_loss;
                let noise_std = gradient_scale * 0.01;

                let mut encoder_grad = Array2::<f64>::zeros(encoder_weights.dim());
                for g in encoder_grad.iter_mut() {
                    *g = sample_standard_normal(&mut rng) * noise_std;
                }

                let mut projection_grad = Array2::<f64>::zeros(projection_weights.dim());
                for g in projection_grad.iter_mut() {
                    *g = sample_standard_normal(&mut rng) * noise_std;
                }

                encoder_weights -= &encoder_grad;
                projection_weights -= &projection_grad;
            }

            if n_batches > 0 {
                epoch_loss /= n_batches as f64;
            }
            // The mean batch loss is tracked but currently unused; it could
            // drive logging or early stopping
            let _ = epoch_loss;
        }

        Ok(FittedSimCLR {
            base_model: self,
            encoder_weights,
            projection_weights,
            classes: Array1::from_vec(unique_classes),
            n_classes,
        })
    }
}

#[allow(non_snake_case)]
impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedSimCLR {
    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
        let embeddings = X.dot(&self.encoder_weights);
        let n_samples = X.nrows();
        let mut predictions = Array1::zeros(n_samples);

        if self.n_classes == 0 {
            return Ok(predictions);
        }

        // Placeholder classification: map the embedding magnitude to a class
        // index (a full implementation would classify in embedding space,
        // e.g. by nearest labeled neighbor)
        for i in 0..n_samples {
            let embedding = embeddings.row(i);

            let score = embedding.sum();
            let class_idx = ((score.abs() * self.n_classes as f64) as usize) % self.n_classes;
            predictions[i] = self.classes[class_idx];
        }

        Ok(predictions)
    }
}

#[allow(non_snake_case)]
impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedSimCLR {
    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
        let embeddings = X.dot(&self.encoder_weights);
        let n_samples = X.nrows();
        let mut probabilities = Array2::zeros((n_samples, self.n_classes.max(1)));

        if self.n_classes == 0 {
            probabilities.fill(1.0);
            return Ok(probabilities);
        }

        for i in 0..n_samples {
            let embedding = embeddings.row(i);

            // Placeholder per-class scores derived from the embedding sum
            let base_score = embedding.sum();
            let scores: Vec<f64> = (0..self.n_classes)
                .map(|j| base_score + j as f64 * 0.1)
                .collect();

            // Softmax normalization (max-shifted for numerical stability)
            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
            let sum_exp: f64 = exp_scores.iter().sum();

            for (j, &exp_score) in exp_scores.iter().enumerate() {
                probabilities[[i, j]] = exp_score / sum_exp;
            }
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_simclr_creation() {
        let simclr = SimCLR::new()
            .projection_dim(32)
            .embedding_dim(64)
            .temperature(0.3)
            .max_epochs(10);

        assert_eq!(simclr.projection_dim, 32);
        assert_eq!(simclr.embedding_dim, 64);
        assert_eq!(simclr.temperature, 0.3);
        assert_eq!(simclr.max_epochs, 10);
    }

    #[test]
    fn test_simclr_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![0, 1, 0, 1, -1, -1]; // -1 indicates unlabeled

        let simclr = SimCLR::new()
            .projection_dim(4)
            .embedding_dim(8)
            .max_epochs(2)
            .batch_size(3)
            .random_state(42);

        let fitted = simclr.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 6);
        for &pred in predictions.iter() {
            assert!(pred >= 0 && pred < 2);
        }

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (6, 2));

        // Each row of predicted probabilities should sum to 1
        for i in 0..6 {
            let sum: f64 = probas.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_simclr_augmentation() {
        let simclr = SimCLR::new().augmentation_strength(0.1);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = Random::seed(42);

        let augmented = simclr.apply_augmentation(&x, &mut rng);
        assert_eq!(augmented.dim(), x.dim());

        // Augmented data should differ from the original
        let diff = (&augmented - &x).mapv(|v| v.abs()).sum();
        assert!(diff > 0.0);
    }

    #[test]
    fn test_simclr_l2_normalize() {
        let simclr = SimCLR::new();
        let x = array![[3.0, 4.0], [1.0, 0.0]];

        let normalized = simclr.l2_normalize(&x);

        // Each row should have unit norm
        for row in normalized.axis_iter(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_simclr_all_unlabeled() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![-1, -1, -1]; // All unlabeled

        let simclr = SimCLR::new().max_epochs(2).batch_size(2);

        let fitted = simclr.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        // With no labeled classes, predictions default to 0
        for &pred in predictions.iter() {
            assert_eq!(pred, 0);
        }
    }
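
    #[test]
    fn test_sample_standard_normal_is_finite() {
        // Sanity check for the shared Box-Muller helper: every draw should be
        // finite, since the lower bound on u1 keeps ln(u1) away from -inf
        let mut rng = Random::seed(7);
        for _ in 0..1000 {
            let z = sample_standard_normal(&mut rng);
            assert!(z.is_finite());
        }
    }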
}