// sklears_semi_supervised/contrastive_learning/supervised_contrastive.rs

use super::{ContrastiveLearningError, *};
use scirs2_core::random::Rng;

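/// Semi-supervised contrastive representation learner.
///
/// Labeled samples carry non-negative class labels; unlabeled samples are
/// marked with `-1` and are ignored by the supervised contrastive loss.
///
/// A minimal usage sketch (assuming this crate's `Fit` and `Predict` traits
/// are in scope; marked `ignore` since it is illustrative, not a doctest):
///
/// ```ignore
/// let scl = SupervisedContrastiveLearning::new()
///     .embedding_dim(16)
///     .temperature(0.1)?
///     .random_state(0);
/// let fitted = scl.fit(&x.view(), &y.view())?; // y uses -1 for unlabeled
/// let predictions = fitted.predict(&x.view())?;
/// ```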
#[derive(Debug, Clone)]
pub struct SupervisedContrastiveLearning {
    /// Dimensionality of the learned embedding space.
    pub embedding_dim: usize,
    /// Softmax temperature for the contrastive loss; must be positive.
    pub temperature: f64,
    /// Step size for the encoder weight updates.
    pub learning_rate: f64,
    /// Number of samples per training batch; must be non-zero.
    pub batch_size: usize,
    /// Maximum number of training epochs.
    pub max_epochs: usize,
    /// Strength of the Gaussian-noise augmentation, in `[0, 1]`.
    pub augmentation_strength: f64,
    /// Relative weight given to labeled samples.
    pub labeled_weight: f64,
    /// Optional seed for reproducible random number generation.
    pub random_state: Option<u64>,
}

impl Default for SupervisedContrastiveLearning {
    fn default() -> Self {
        Self {
            embedding_dim: 128,
            temperature: 0.07,
            learning_rate: 0.001,
            batch_size: 32,
            max_epochs: 100,
            augmentation_strength: 0.5,
            labeled_weight: 2.0,
            random_state: None,
        }
    }
}

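// A small helper factored out as a sketch: the noise-generation loops below
// all draw standard-normal samples with the Box-Muller transform, so the
// draw lives in one place. It assumes `Random<R>` exposes `gen_range` the
// same way the original inline loops used it.
fn sample_standard_normal<R: Rng>(rng: &mut Random<R>) -> f64 {
    // u1 is drawn from (0, 1) rather than [0, 1) so that ln(u1) stays finite.
    let u1: f64 = rng.gen_range(f64::EPSILON..1.0);
    let u2: f64 = rng.gen_range(0.0..1.0);
    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
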
impl SupervisedContrastiveLearning {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    pub fn temperature(mut self, temperature: f64) -> Result<Self> {
        if temperature <= 0.0 {
            return Err(ContrastiveLearningError::InvalidTemperature(temperature).into());
        }
        self.temperature = temperature;
        Ok(self)
    }

    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(ContrastiveLearningError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
        self.max_epochs = max_epochs;
        self
    }

    pub fn augmentation_strength(mut self, augmentation_strength: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&augmentation_strength) {
            return Err(ContrastiveLearningError::InvalidAugmentationStrength(
                augmentation_strength,
            )
            .into());
        }
        self.augmentation_strength = augmentation_strength;
        Ok(self)
    }

    pub fn labeled_weight(mut self, labeled_weight: f64) -> Self {
        self.labeled_weight = labeled_weight;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Produces a noisy view of `X` by adding zero-mean Gaussian noise whose
    /// standard deviation scales with `augmentation_strength`.
    fn augment_data<R>(&self, X: &ArrayView2<f64>, rng: &mut Random<R>) -> Result<Array2<f64>>
    where
        R: Rng,
    {
        let (n_samples, n_features) = X.dim();

        // Noise std is a fixed fraction of the augmentation strength.
        let noise_std = self.augmentation_strength * 0.1;
        let mut noise = Array2::<f64>::zeros((n_samples, n_features));
        for i in 0..n_samples {
            for j in 0..n_features {
                noise[(i, j)] = sample_standard_normal(rng) * noise_std;
            }
        }

        Ok(X.to_owned() + noise)
    }
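
    /// Simplified supervised contrastive loss over the labeled samples.
    ///
    /// For each labeled anchor `i` with positives `P(i)` (other labeled
    /// samples sharing its label) and labeled candidates `A(i)` (all labeled
    /// samples except `i`), the per-anchor term is
    ///
    /// ```text
    /// L_i = -ln( Σ_{p ∈ P(i)} exp(z_i·z_p / τ)  /  Σ_{a ∈ A(i)} exp(z_i·z_a / τ) )
    /// ```
    ///
    /// and the returned value is the mean over anchors with at least one
    /// positive. Pooling the positives inside a single logarithm resembles
    /// the "L_in" variant of the SupCon loss (Khosla et al., 2020); the
    /// canonical "L_out" form instead averages per-positive log terms.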
    fn compute_supervised_contrastive_loss(
        &self,
        embeddings: &ArrayView2<f64>,
        labels: &ArrayView1<i32>,
    ) -> Result<f64> {
        let n_samples = embeddings.dim().0;
        let mut total_loss = 0.0;
        let mut n_labeled = 0;

        for i in 0..n_samples {
            // Unlabeled samples (label -1) cannot serve as anchors.
            if labels[i] == -1 {
                continue;
            }

            let anchor = embeddings.row(i);
            let anchor_label = labels[i];

            let mut positive_scores = Vec::new();
            let mut negative_scores = Vec::new();

            for j in 0..n_samples {
                if i == j {
                    continue;
                }

                let sample = embeddings.row(j);
                // Temperature-scaled similarity between anchor and candidate.
                let score = anchor.dot(&sample) / self.temperature;

                if labels[j] == anchor_label && labels[j] != -1 {
                    positive_scores.push(score);
                } else if labels[j] != -1 {
                    negative_scores.push(score);
                }
            }

            // Anchors without at least one positive contribute no loss term.
            if positive_scores.is_empty() {
                continue;
            }

            // Shift by the maximum score before exponentiating for numerical
            // stability (the log-sum-exp trick).
            let max_score = positive_scores
                .iter()
                .chain(negative_scores.iter())
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);

            let mut pos_exp_sum = 0.0;
            for &score in positive_scores.iter() {
                pos_exp_sum += (score - max_score).exp();
            }

            let mut all_exp_sum = pos_exp_sum;
            for &score in negative_scores.iter() {
                all_exp_sum += (score - max_score).exp();
            }

            if all_exp_sum > 0.0 {
                let loss = -(pos_exp_sum / all_exp_sum).ln();
                total_loss += loss;
                n_labeled += 1;
            }
        }

        if n_labeled > 0 {
            Ok(total_loss / n_labeled as f64)
        } else {
            Ok(0.0)
        }
    }
}

#[derive(Debug, Clone)]
pub struct FittedSupervisedContrastiveLearning {
    /// The configuration used to fit this model.
    pub base_model: SupervisedContrastiveLearning,
    /// Learned linear encoder, of shape (n_features, embedding_dim).
    pub encoder_weights: Array2<f64>,
    /// Distinct class labels observed during fitting.
    pub classes: Array1<i32>,
    /// Number of distinct classes.
    pub n_classes: usize,
    /// Per-class mean embeddings, of shape (n_classes, embedding_dim).
    pub class_centroids: Array2<f64>,
}

impl Estimator for SupervisedContrastiveLearning {
    type Config = SupervisedContrastiveLearning;
    type Error = ContrastiveLearningError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        self
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for SupervisedContrastiveLearning {
    type Fitted = FittedSupervisedContrastiveLearning;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = X.dim();

        // At least two labeled samples are needed to form an anchor-positive pair.
        let labeled_count = y.iter().filter(|&&label| label != -1).count();
        if labeled_count < 2 {
            return Err(ContrastiveLearningError::InsufficientLabeledSamples.into());
        }

        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        // Initialize the linear encoder with small Gaussian weights.
        let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
        for i in 0..n_features {
            for j in 0..self.embedding_dim {
                encoder_weights[(i, j)] = sample_standard_normal(&mut rng) * 0.1;
            }
        }

        // Collect the distinct labels, ignoring the -1 "unlabeled" marker.
        let mut unique_classes: Vec<i32> = y
            .iter()
            .cloned()
            .filter(|&label| label != -1)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        // HashSet iteration order is unspecified; sort so that class indices
        // (and thus fitted models) are reproducible for a fixed random_state.
        unique_classes.sort_unstable();
        let n_classes = unique_classes.len();

        for epoch in 0..self.max_epochs {
            // Each epoch trains on a freshly augmented (noisy) view of X.
            let X_aug = self.augment_data(X, &mut rng)?;

            // Linear encoder: embeddings are a single matrix product.
            let embeddings = X_aug.dot(&encoder_weights);

            let loss = self.compute_supervised_contrastive_loss(&embeddings.view(), y)?;

            // Simplified update: rather than backpropagating, perturb the
            // weights with zero-mean noise whose scale shrinks with the loss.
            let gradient_scale = self.learning_rate * loss;
            let noise_std = gradient_scale * 0.1;
            let mut encoder_grad = Array2::<f64>::zeros(encoder_weights.dim());
            for i in 0..encoder_weights.nrows() {
                for j in 0..encoder_weights.ncols() {
                    encoder_grad[(i, j)] = sample_standard_normal(&mut rng) * noise_std;
                }
            }
            encoder_weights = encoder_weights - encoder_grad;

            if epoch % 10 == 0 {
                println!("Epoch {}: Loss = {:.6}", epoch, loss);
            }
        }
293
294 let final_embeddings = X.dot(&encoder_weights);
296
297 let mut class_centroids = Array2::zeros((n_classes, self.embedding_dim));
298 let mut class_counts = vec![0; n_classes];
299
300 for i in 0..n_samples {
301 if y[i] != -1 {
302 if let Some(class_idx) = unique_classes.iter().position(|&c| c == y[i]) {
303 for j in 0..self.embedding_dim {
304 class_centroids[[class_idx, j]] += final_embeddings[[i, j]];
305 }
306 class_counts[class_idx] += 1;
307 }
308 }
309 }
310
311 for class_idx in 0..n_classes {
313 if class_counts[class_idx] > 0 {
314 for j in 0..self.embedding_dim {
315 class_centroids[[class_idx, j]] /= class_counts[class_idx] as f64;
316 }
317 }
318 }
319
320 Ok(FittedSupervisedContrastiveLearning {
321 base_model: self.clone(),
322 encoder_weights,
323 classes: Array1::from_vec(unique_classes),
324 n_classes,
325 class_centroids,
326 })
327 }
328}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedSupervisedContrastiveLearning {
    /// Nearest-centroid classification: each sample is assigned the class
    /// whose centroid is closest in the embedding space (Euclidean distance).
    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
        let embeddings = X.dot(&self.encoder_weights);

        let n_samples = X.dim().0;
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let embedding = embeddings.row(i);
            let mut best_class = self.classes[0];
            let mut best_distance = f64::INFINITY;

            for (class_idx, &class) in self.classes.iter().enumerate() {
                let centroid = self.class_centroids.row(class_idx);
                let distance = embedding
                    .iter()
                    .zip(centroid.iter())
                    .map(|(e, c)| (e - c).powi(2))
                    .sum::<f64>()
                    .sqrt();

                if distance < best_distance {
                    best_distance = distance;
                    best_class = class;
                }
            }

            predictions[i] = best_class;
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedSupervisedContrastiveLearning {
    /// Class probabilities from a softmax over negated centroid distances,
    /// so that closer centroids receive higher probability.
    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
        let embeddings = X.dot(&self.encoder_weights);

        let n_samples = X.dim().0;
        let mut probabilities = Array2::zeros((n_samples, self.n_classes));

        for i in 0..n_samples {
            let embedding = embeddings.row(i);
            let mut distances = Vec::new();

            for class_idx in 0..self.n_classes {
                let centroid = self.class_centroids.row(class_idx);
                let distance = embedding
                    .iter()
                    .zip(centroid.iter())
                    .map(|(e, c)| (e - c).powi(2))
                    .sum::<f64>()
                    .sqrt();
                // Negate so that smaller distances give larger softmax logits.
                distances.push(-distance);
            }

            // Standard max-shifted softmax for numerical stability.
            let max_distance = distances.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_distances: Vec<f64> = distances
                .iter()
                .map(|&d| (d - max_distance).exp())
                .collect();
            let sum_exp: f64 = exp_distances.iter().sum();

            for (j, &exp_dist) in exp_distances.iter().enumerate() {
                probabilities[[i, j]] = exp_dist / sum_exp;
            }
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    fn test_supervised_contrastive_learning_creation() {
        let scl = SupervisedContrastiveLearning::new()
            .embedding_dim(32)
            .temperature(0.05)
            .unwrap()
            .augmentation_strength(0.3)
            .unwrap()
            .labeled_weight(3.0)
            .random_state(42);

        assert_eq!(scl.embedding_dim, 32);
        assert_eq!(scl.temperature, 0.05);
        assert_eq!(scl.augmentation_strength, 0.3);
        assert_eq!(scl.labeled_weight, 3.0);
        assert_eq!(scl.random_state, Some(42));
    }

    #[test]
    fn test_supervised_contrastive_learning_invalid_augmentation() {
        let result = SupervisedContrastiveLearning::new().augmentation_strength(1.5);
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_supervised_contrastive_learning_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![0, 1, 0, 1, -1, -1];

        let scl = SupervisedContrastiveLearning::new()
            .embedding_dim(4)
            .max_epochs(2)
            .batch_size(3)
            .unwrap()
            .random_state(42);

        let fitted = scl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 6);
        for &pred in predictions.iter() {
            assert!(pred == 0 || pred == 1);
        }

        let probabilities = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probabilities.dim(), (6, 2));

        for i in 0..6 {
            let sum: f64 = probabilities.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1];

        let scl = SupervisedContrastiveLearning::new();
        let result = scl.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }
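
    // Additional sketch of a test (added): the temperature and batch_size
    // builders validate their inputs and should reject non-positive values,
    // mirroring the augmentation-strength check above.
    #[test]
    fn test_invalid_temperature_and_batch_size() {
        assert!(SupervisedContrastiveLearning::new().temperature(0.0).is_err());
        assert!(SupervisedContrastiveLearning::new().batch_size(0).is_err());
    }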
}