sklears_semi_supervised/contrastive_learning/simclr.rs

use super::{ContrastiveLearningError, *};
use scirs2_core::random::{rand_prelude::SliceRandom, Rng};
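
/// A semi-supervised contrastive learner in the style of SimCLR: two
/// stochastically augmented views of each sample are encoded, projected, and
/// pulled together by an NT-Xent loss, while labeled samples (any label other
/// than `-1`) additionally contribute a supervised contrastive term weighted
/// by `labeled_weight`.
///
/// A usage sketch (kept out of doctests; `X` and `y` stand in for the
/// caller's data, with `-1` marking unlabeled samples):
///
/// ```ignore
/// let model = SimCLR::new()
///     .embedding_dim(64)
///     .temperature(0.5)
///     .max_epochs(10)
///     .random_state(42);
/// let fitted = model.fit(&X.view(), &y.view())?;
/// let predictions = fitted.predict(&X.view())?;
/// ```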
#[derive(Debug, Clone)]
pub struct SimCLR {
    /// Output dimension of the projection head.
    pub projection_dim: usize,
    /// Dimension of the encoder's embedding space.
    pub embedding_dim: usize,
    /// Softmax temperature for the contrastive losses.
    pub temperature: f64,
    /// Scale of the Gaussian-noise and dropout augmentations.
    pub augmentation_strength: f64,
    /// Number of samples per training batch.
    pub batch_size: usize,
    /// Number of passes over the training data.
    pub max_epochs: usize,
    /// Step size scaling the per-batch weight updates.
    pub learning_rate: f64,
    /// Momentum coefficient (not used by the current update rule).
    pub momentum: f64,
    /// Weight of the supervised contrastive term relative to NT-Xent.
    pub labeled_weight: f64,
    /// Optional seed for reproducible runs.
    pub random_state: Option<u64>,
}

impl Default for SimCLR {
    fn default() -> Self {
        Self {
            projection_dim: 64,
            embedding_dim: 128,
            temperature: 0.5,
            augmentation_strength: 0.2,
            batch_size: 32,
            max_epochs: 100,
            learning_rate: 0.001,
            momentum: 0.999,
            labeled_weight: 1.0,
            random_state: None,
        }
    }
}

impl SimCLR {
    /// Creates a `SimCLR` with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the projection head's output dimension.
    pub fn projection_dim(mut self, projection_dim: usize) -> Self {
        self.projection_dim = projection_dim;
        self
    }

    /// Sets the encoder's embedding dimension.
    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    /// Sets the softmax temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the augmentation strength.
    pub fn augmentation_strength(mut self, strength: f64) -> Self {
        self.augmentation_strength = strength;
        self
    }

    /// Sets the batch size.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the maximum number of training epochs.
    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
        self.max_epochs = max_epochs;
        self
    }

    /// Sets the learning rate.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the momentum coefficient.
    pub fn momentum(mut self, momentum: f64) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets the weight of the supervised contrastive term.
    pub fn labeled_weight(mut self, labeled_weight: f64) -> Self {
        self.labeled_weight = labeled_weight;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
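
    /// Draws one standard-normal sample via the Box-Muller transform: with
    /// `u1`, `u2` ~ Uniform(0, 1), `sqrt(-2 ln u1) * cos(2*pi*u2)` is standard
    /// normal. Shared by the augmentation, initialization, and update loops
    /// below; relies only on the `Random<R>` wrapper's `random_range`, which
    /// is already used throughout this file.
    fn sample_standard_normal<R: Rng>(rng: &mut Random<R>) -> f64 {
        let u1: f64 = rng.random_range(0.0..1.0);
        let u2: f64 = rng.random_range(0.0..1.0);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }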

    /// Produces a stochastically augmented copy of `x`: additive Gaussian
    /// noise plus random feature dropout, both scaled by
    /// `augmentation_strength`.
    fn apply_augmentation<R>(&self, x: &Array2<f64>, rng: &mut Random<R>) -> Array2<f64>
    where
        R: Rng,
    {
        let mut augmented = x.clone();

        // Additive Gaussian noise.
        let mut noise = Array2::<f64>::zeros(x.dim());
        for i in 0..x.nrows() {
            for j in 0..x.ncols() {
                noise[(i, j)] = Self::sample_standard_normal(rng) * self.augmentation_strength;
            }
        }
        augmented = augmented + noise;

        // Random feature dropout; the dropout probability grows with the
        // augmentation strength.
        let dropout_prob = 0.1 * self.augmentation_strength;
        for mut row in augmented.axis_iter_mut(Axis(0)) {
            for element in row.iter_mut() {
                if rng.random_range(0.0..1.0) < dropout_prob {
                    *element = 0.0;
                }
            }
        }

        augmented
    }
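
    /// A simplified, symmetrized NT-Xent loss over paired, L2-normalized
    /// projections. With temperature `t`, the per-pair loss is
    /// `-ln(p / (p + n))`, where `p = exp(sim(z_i, z_j) / t)` for the positive
    /// pair and `n` pools `exp(sim(., z_k) / t)` against both views of every
    /// other sample; the total is averaged over `2 * batch_size`.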
    fn compute_simclr_loss(&self, z_i: &ArrayView2<f64>, z_j: &ArrayView2<f64>) -> Result<f64> {
        let batch_size = z_i.nrows();
        if batch_size == 0 {
            return Ok(0.0);
        }

        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let zi = z_i.row(i);
            let zj = z_j.row(i);

            // Similarity of the positive pair (the two views of sample i).
            let pos_score = (zi.dot(&zj) / self.temperature).exp();

            // Pool similarities against both views of every other sample.
            let mut neg_sum = 0.0;
            for k in 0..batch_size {
                if k != i {
                    let zk_i = z_i.row(k);
                    let zk_j = z_j.row(k);

                    neg_sum += (zi.dot(&zk_i) / self.temperature).exp();
                    neg_sum += (zi.dot(&zk_j) / self.temperature).exp();
                    neg_sum += (zj.dot(&zk_i) / self.temperature).exp();
                    neg_sum += (zj.dot(&zk_j) / self.temperature).exp();
                }
            }

            // The positive score also appears in the denominator.
            neg_sum += pos_score;

            let loss = -(pos_score / neg_sum).ln();
            total_loss += loss;
        }

        Ok(total_loss / (2.0 * batch_size as f64))
    }
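
    /// A supervised contrastive loss over the labeled subset of the batch:
    /// for each labeled anchor, same-label samples act as positives and
    /// differently labeled samples as negatives; anchors and candidates
    /// labeled `-1` are skipped. Returns 0.0 when no anchor has a positive.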
    fn compute_supervised_contrastive_loss(
        &self,
        embeddings: &ArrayView2<f64>,
        labels: &ArrayView1<i32>,
    ) -> Result<f64> {
        let batch_size = embeddings.nrows();
        let mut total_loss = 0.0;
        let mut valid_pairs = 0;

        for i in 0..batch_size {
            // Skip unlabeled anchors.
            if labels[i] == -1 {
                continue;
            }

            let zi = embeddings.row(i);
            let mut pos_sum = 0.0;
            let mut neg_sum = 0.0;
            let mut pos_count = 0;

            for j in 0..batch_size {
                if i == j || labels[j] == -1 {
                    continue;
                }

                let zj = embeddings.row(j);
                let similarity = (zi.dot(&zj) / self.temperature).exp();

                if labels[i] == labels[j] {
                    pos_sum += similarity;
                    pos_count += 1;
                } else {
                    neg_sum += similarity;
                }
            }

            if pos_count > 0 {
                let loss = -(pos_sum / (pos_sum + neg_sum)).ln();
                total_loss += loss;
                valid_pairs += 1;
            }
        }

        if valid_pairs > 0 {
            Ok(total_loss / valid_pairs as f64)
        } else {
            Ok(0.0)
        }
    }
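
    /// Returns a copy of `x` with each row scaled to unit L2 norm; rows with
    /// near-zero norm are left unchanged to avoid division by zero.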
    fn l2_normalize(&self, x: &Array2<f64>) -> Array2<f64> {
        let mut normalized = x.clone();
        for mut row in normalized.axis_iter_mut(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            if norm > 1e-12 {
                row /= norm;
            }
        }
        normalized
    }
}
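
/// A fitted SimCLR model: the learned linear encoder and projection weights
/// plus the class metadata collected during `fit`.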
#[derive(Debug, Clone)]
pub struct FittedSimCLR {
    /// The configuration this model was trained with.
    pub base_model: SimCLR,
    /// Linear encoder weights, shape `(n_features, embedding_dim)`.
    pub encoder_weights: Array2<f64>,
    /// Projection head weights, shape `(embedding_dim, projection_dim)`.
    pub projection_weights: Array2<f64>,
    /// Distinct labels seen during training (excluding `-1`).
    pub classes: Array1<i32>,
    /// Number of distinct classes.
    pub n_classes: usize,
}

impl Estimator for SimCLR {
    type Config = SimCLR;
    type Error = ContrastiveLearningError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        self
    }
}
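
// Note: `fit` below is a lightweight stand-in for gradient-based SimCLR
// training. The encoder and projection head are linear maps with small
// Gaussian initialization, and each batch perturbs the weights with
// loss-scaled Gaussian noise rather than backpropagated gradients.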
impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for SimCLR {
    type Fitted = FittedSimCLR;

    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = X.dim();

        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        // Initialize the linear encoder and projection head with small
        // Gaussian weights.
        let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
        let mut projection_weights =
            Array2::<f64>::zeros((self.embedding_dim, self.projection_dim));

        for i in 0..n_features {
            for j in 0..self.embedding_dim {
                encoder_weights[(i, j)] = Self::sample_standard_normal(&mut rng) * 0.02;
            }
        }

        for i in 0..self.embedding_dim {
            for j in 0..self.projection_dim {
                projection_weights[(i, j)] = Self::sample_standard_normal(&mut rng) * 0.02;
            }
        }

        // Collect the distinct labels, treating -1 as "unlabeled".
        let unique_classes: Vec<i32> = y
            .iter()
            .cloned()
            .filter(|&label| label != -1)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        let n_classes = unique_classes.len();

        for _epoch in 0..self.max_epochs {
            let mut epoch_loss = 0.0;
            let mut n_batches = 0;

            // Visit the samples in a fresh random order each epoch.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            indices.shuffle(&mut rng);

            for batch_start in (0..n_samples).step_by(self.batch_size) {
                let batch_end = std::cmp::min(batch_start + self.batch_size, n_samples);
                let batch_size = batch_end - batch_start;

                // Contrastive losses need at least two samples per batch.
                if batch_size < 2 {
                    continue;
                }

                let batch_indices = &indices[batch_start..batch_end];

                let mut batch_X = Array2::zeros((batch_size, n_features));
                let mut batch_y = Array1::zeros(batch_size);

                for (i, &idx) in batch_indices.iter().enumerate() {
                    batch_X.row_mut(i).assign(&X.row(idx));
                    batch_y[i] = y[idx];
                }

                // Two independent augmented views of the same batch.
                let X_aug1 = self.apply_augmentation(&batch_X, &mut rng);
                let X_aug2 = self.apply_augmentation(&batch_X, &mut rng);

                // Encode, project, and L2-normalize the projections.
                let h1 = X_aug1.dot(&encoder_weights);
                let h2 = X_aug2.dot(&encoder_weights);
                let z1 = h1.dot(&projection_weights);
                let z2 = h2.dot(&projection_weights);

                let z1_norm = self.l2_normalize(&z1);
                let z2_norm = self.l2_normalize(&z2);

                let simclr_loss = self.compute_simclr_loss(&z1_norm.view(), &z2_norm.view())?;

                // Add the supervised term whenever any labels were observed.
                let supervised_loss = if n_classes > 0 {
                    self.compute_supervised_contrastive_loss(&h1.view(), &batch_y.view())?
                } else {
                    0.0
                };

                let total_loss = simclr_loss + self.labeled_weight * supervised_loss;
                epoch_loss += total_loss;
                n_batches += 1;

                // Loss-scaled random perturbation in place of backpropagated
                // gradients (see the note above this impl).
                let gradient_scale = self.learning_rate * total_loss;
                let noise_std = gradient_scale * 0.01;

                let mut encoder_grad = Array2::<f64>::zeros(encoder_weights.dim());
                for i in 0..encoder_weights.nrows() {
                    for j in 0..encoder_weights.ncols() {
                        encoder_grad[(i, j)] = Self::sample_standard_normal(&mut rng) * noise_std;
                    }
                }

                let mut projection_grad = Array2::<f64>::zeros(projection_weights.dim());
                for i in 0..projection_weights.nrows() {
                    for j in 0..projection_weights.ncols() {
                        projection_grad[(i, j)] =
                            Self::sample_standard_normal(&mut rng) * noise_std;
                    }
                }

                encoder_weights = encoder_weights - encoder_grad;
                projection_weights = projection_weights - projection_grad;
            }

            // Average batch loss for the epoch; computed but not yet reported.
            if n_batches > 0 {
                epoch_loss /= n_batches as f64;
            }
        }

        Ok(FittedSimCLR {
            base_model: self,
            encoder_weights,
            projection_weights,
            classes: Array1::from_vec(unique_classes),
            n_classes,
        })
    }
}
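
// Prediction is a heuristic placeholder: each embedding is reduced to a
// scalar score whose magnitude is mapped onto a class index, so outputs are
// deterministic but not driven by a trained classification head.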
impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedSimCLR {
    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
        let embeddings = X.dot(&self.encoder_weights);
        let n_samples = X.nrows();
        let mut predictions = Array1::zeros(n_samples);

        // With no observed classes, default every prediction to 0.
        if self.n_classes == 0 {
            return Ok(predictions);
        }

        for i in 0..n_samples {
            let embedding = embeddings.row(i);

            // Reduce the embedding to a scalar and map it onto a class index.
            let score = embedding.sum();
            let class_idx = ((score.abs() * self.n_classes as f64) as usize) % self.n_classes;
            predictions[i] = self.classes[class_idx];
        }

        Ok(predictions)
    }
}
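
// Probabilities come from a softmax over the same heuristic scores, offset
// per class; rows therefore sum to 1 but do not reflect a learned decision
// boundary.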
impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedSimCLR {
    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
        let embeddings = X.dot(&self.encoder_weights);
        let n_samples = X.nrows();
        let mut probabilities = Array2::zeros((n_samples, self.n_classes.max(1)));

        // With no observed classes, return a single all-ones column.
        if self.n_classes == 0 {
            probabilities.fill(1.0);
            return Ok(probabilities);
        }

        for i in 0..n_samples {
            let embedding = embeddings.row(i);

            // One heuristic score per class: the embedding sum plus a small
            // class-dependent offset.
            let mut scores = Vec::new();
            for j in 0..self.n_classes {
                let score = embedding.sum() + j as f64 * 0.1;
                scores.push(score);
            }

            // Numerically stable softmax over the scores.
            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
            let sum_exp: f64 = exp_scores.iter().sum();

            for (j, &exp_score) in exp_scores.iter().enumerate() {
                probabilities[[i, j]] = exp_score / sum_exp;
            }
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_simclr_creation() {
        let simclr = SimCLR::new()
            .projection_dim(32)
            .embedding_dim(64)
            .temperature(0.3)
            .max_epochs(10);

        assert_eq!(simclr.projection_dim, 32);
        assert_eq!(simclr.embedding_dim, 64);
        assert_eq!(simclr.temperature, 0.3);
        assert_eq!(simclr.max_epochs, 10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_simclr_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        // Labels of -1 mark unlabeled samples.
        let y = array![0, 1, 0, 1, -1, -1];

        let simclr = SimCLR::new()
            .projection_dim(4)
            .embedding_dim(8)
            .max_epochs(2)
            .batch_size(3)
            .random_state(42);

        let fitted = simclr.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 6);
        for &pred in predictions.iter() {
            assert!(pred >= 0 && pred < 2);
        }

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (6, 2));

        // Each row of predicted probabilities should sum to 1.
        for i in 0..6 {
            let sum: f64 = probas.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_simclr_augmentation() {
        let simclr = SimCLR::new().augmentation_strength(0.1);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = Random::seed(42);

        let augmented = simclr.apply_augmentation(&x, &mut rng);
        assert_eq!(augmented.dim(), x.dim());

        // Augmentation should actually change the data.
        let diff = (&augmented - &x).mapv(|x| x.abs()).sum();
        assert!(diff > 0.0);
    }

    #[test]
    fn test_simclr_l2_normalize() {
        let simclr = SimCLR::new();
        let x = array![[3.0, 4.0], [1.0, 0.0]];

        let normalized = simclr.l2_normalize(&x);

        for row in normalized.axis_iter(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_simclr_all_unlabeled() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![-1, -1, -1];

        let simclr = SimCLR::new().max_epochs(2).batch_size(2);

        let fitted = simclr.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        // With no labels at all, every prediction defaults to class 0.
        for &pred in predictions.iter() {
            assert_eq!(pred, 0);
        }
    }
}