sklears_semi_supervised/contrastive_learning/
contrastive_predictive_coding.rs1use super::{ContrastiveLearningError, *};
4use scirs2_core::random::rand_prelude::SliceRandom;
5
#[derive(Debug, Clone)]
pub struct ContrastivePredictiveCoding {
    /// Dimensionality of the learned embedding space.
    pub embedding_dim: usize,
    /// Width of the context (hidden) projection.
    pub hidden_dim: usize,
    /// Configured context window length.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub context_length: usize,
    /// Configured number of future prediction steps.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub prediction_steps: usize,
    /// Softmax temperature for the contrastive loss (must be strictly positive).
    pub temperature: f64,
    /// Scale of the per-batch weight update during fitting.
    pub learning_rate: f64,
    /// Mini-batch size used during fitting (must be non-zero).
    pub batch_size: usize,
    /// Maximum number of training epochs.
    pub max_epochs: usize,
    /// Number of negative samples drawn per anchor.
    pub negative_samples: usize,
    /// Optional RNG seed for reproducible initialization (42 used when `None`).
    pub random_state: Option<u64>,
}
34
35impl Default for ContrastivePredictiveCoding {
36 fn default() -> Self {
37 Self {
38 embedding_dim: 128,
39 hidden_dim: 256,
40 context_length: 8,
41 prediction_steps: 4,
42 temperature: 0.1,
43 learning_rate: 0.001,
44 batch_size: 32,
45 max_epochs: 100,
46 negative_samples: 16,
47 random_state: None,
48 }
49 }
50}
51
52impl ContrastivePredictiveCoding {
53 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
58 self.embedding_dim = embedding_dim;
59 self
60 }
61
62 pub fn hidden_dim(mut self, hidden_dim: usize) -> Self {
63 self.hidden_dim = hidden_dim;
64 self
65 }
66
67 pub fn context_length(mut self, context_length: usize) -> Self {
68 self.context_length = context_length;
69 self
70 }
71
72 pub fn prediction_steps(mut self, prediction_steps: usize) -> Self {
73 self.prediction_steps = prediction_steps;
74 self
75 }
76
77 pub fn temperature(mut self, temperature: f64) -> Result<Self> {
78 if temperature <= 0.0 {
79 return Err(ContrastiveLearningError::InvalidTemperature(temperature).into());
80 }
81 self.temperature = temperature;
82 Ok(self)
83 }
84
85 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
86 self.learning_rate = learning_rate;
87 self
88 }
89
90 pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
91 if batch_size == 0 {
92 return Err(ContrastiveLearningError::InvalidBatchSize(batch_size).into());
93 }
94 self.batch_size = batch_size;
95 Ok(self)
96 }
97
98 pub fn max_epochs(mut self, max_epochs: usize) -> Self {
99 self.max_epochs = max_epochs;
100 self
101 }
102
103 pub fn negative_samples(mut self, negative_samples: usize) -> Self {
104 self.negative_samples = negative_samples;
105 self
106 }
107
108 pub fn random_state(mut self, random_state: u64) -> Self {
109 self.random_state = Some(random_state);
110 self
111 }
112
113 fn encode(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
114 let (n_samples, n_features) = x.dim();
115 let mut rng = match self.random_state {
116 Some(seed) => Random::seed(seed),
117 None => Random::seed(42),
118 };
119
120 let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
122 for i in 0..n_features {
123 for j in 0..self.embedding_dim {
124 let u1: f64 = rng.random_range(0.0..1.0);
126 let u2: f64 = rng.random_range(0.0..1.0);
127 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
128 encoder_weights[(i, j)] = z * 0.1; }
130 }
131
132 Ok(x.dot(&encoder_weights))
133 }
134
135 fn context_network(&self, embeddings: &ArrayView2<f64>) -> Result<Array2<f64>> {
136 let (n_samples, embedding_dim) = embeddings.dim();
137 if embedding_dim != self.embedding_dim {
138 return Err(ContrastiveLearningError::EmbeddingDimensionMismatch {
139 expected: self.embedding_dim,
140 actual: embedding_dim,
141 }
142 .into());
143 }
144
145 let mut rng = match self.random_state {
146 Some(seed) => Random::seed(seed),
147 None => Random::seed(42),
148 };
149
150 let mut context_weights = Array2::<f64>::zeros((self.embedding_dim, self.hidden_dim));
153 for i in 0..self.embedding_dim {
154 for j in 0..self.hidden_dim {
155 let u1: f64 = rng.random_range(0.0..1.0);
157 let u2: f64 = rng.random_range(0.0..1.0);
158 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
159 context_weights[(i, j)] = z * 0.1; }
161 }
162
163 Ok(embeddings.dot(&context_weights))
164 }
165
166 fn compute_contrastive_loss(
167 &self,
168 context: &ArrayView2<f64>,
169 positive: &ArrayView2<f64>,
170 negatives: &ArrayView2<f64>,
171 ) -> Result<f64> {
172 let batch_size = context.dim().0;
173 let mut total_loss = 0.0;
174
175 for i in 0..batch_size {
176 let ctx = context.row(i);
177 let pos = positive.row(i);
178
179 let pos_score = ctx.dot(&pos) / self.temperature;
181
182 let mut neg_scores = Vec::new();
184 for j in 0..self.negative_samples {
185 if j < negatives.dim().0 {
186 let neg = negatives.row(j);
187 let neg_score = ctx.dot(&neg) / self.temperature;
188 neg_scores.push(neg_score);
189 }
190 }
191
192 let max_score =
194 pos_score.max(neg_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max));
195 let exp_pos = (pos_score - max_score).exp();
196 let exp_neg_sum: f64 = neg_scores.iter().map(|&s| (s - max_score).exp()).sum();
197
198 let loss = -((exp_pos / (exp_pos + exp_neg_sum)).ln());
199 total_loss += loss;
200 }
201
202 Ok(total_loss / batch_size as f64)
203 }
204}
205
206#[derive(Debug, Clone)]
208pub struct FittedContrastivePredictiveCoding {
209 pub base_model: ContrastivePredictiveCoding,
211 pub encoder_weights: Array2<f64>,
213 pub context_weights: Array2<f64>,
215 pub classes: Array1<i32>,
217 pub n_classes: usize,
219}
220
221impl Estimator for ContrastivePredictiveCoding {
222 type Config = ContrastivePredictiveCoding;
223 type Error = ContrastiveLearningError;
224 type Float = f64;
225
226 fn config(&self) -> &Self::Config {
227 self
228 }
229}
230
231impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for ContrastivePredictiveCoding {
232 type Fitted = FittedContrastivePredictiveCoding;
233
234 fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
235 let (n_samples, n_features) = X.dim();
236
237 let labeled_count = y.iter().filter(|&&label| label != -1).count();
239 if labeled_count < 2 {
240 return Err(ContrastiveLearningError::InsufficientLabeledSamples.into());
241 }
242
243 let mut rng = match self.random_state {
244 Some(seed) => Random::seed(seed),
245 None => Random::seed(42),
246 };
247
248 let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
250 let mut context_weights = Array2::<f64>::zeros((self.embedding_dim, self.hidden_dim));
251
252 for i in 0..n_features {
254 for j in 0..self.embedding_dim {
255 let u1: f64 = rng.random_range(0.0..1.0);
256 let u2: f64 = rng.random_range(0.0..1.0);
257 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
258 encoder_weights[(i, j)] = z * 0.1;
259 }
260 }
261
262 for i in 0..self.embedding_dim {
264 for j in 0..self.hidden_dim {
265 let u1: f64 = rng.random_range(0.0..1.0);
266 let u2: f64 = rng.random_range(0.0..1.0);
267 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
268 context_weights[(i, j)] = z * 0.1;
269 }
270 }
271
272 let unique_classes: Vec<i32> = y
274 .iter()
275 .cloned()
276 .filter(|&label| label != -1)
277 .collect::<std::collections::HashSet<_>>()
278 .into_iter()
279 .collect();
280 let n_classes = unique_classes.len();
281
282 for epoch in 0..self.max_epochs {
284 let batch_indices: Vec<usize> = (0..n_samples).collect();
286 let mut batch_indices = batch_indices;
287 batch_indices.shuffle(&mut rng);
288
289 let mut epoch_loss = 0.0;
290 let mut n_batches = 0;
291
292 for batch_start in (0..n_samples).step_by(self.batch_size) {
293 let batch_end = std::cmp::min(batch_start + self.batch_size, n_samples);
294 let batch_size = batch_end - batch_start;
295
296 if batch_size < 2 {
297 continue;
298 }
299
300 let batch_X = X.slice(scirs2_core::ndarray::s![batch_start..batch_end, ..]);
302
303 let encoded = batch_X.dot(&encoder_weights);
305
306 let context = encoded.dot(&context_weights);
308
309 let mut positive_samples = Vec::new();
311 let mut negative_samples = Vec::new();
312
313 for i in 0..batch_size {
314 let pos_idx = if i + 1 < batch_size { i + 1 } else { 0 };
316 positive_samples.push(encoded.row(pos_idx).to_owned());
317
318 let max_negatives = std::cmp::min(self.negative_samples, batch_size - 1);
320 let mut neg_count = 0;
321 while neg_count < max_negatives {
322 let neg_idx = rng.gen_range(0..batch_size);
323 if neg_idx != i {
324 negative_samples.push(encoded.row(neg_idx).to_owned());
325 neg_count += 1;
326 }
327 }
328 }
329
330 let positive_array = Array2::from_shape_vec(
332 (batch_size, self.embedding_dim),
333 positive_samples.into_iter().flatten().collect(),
334 )
335 .map_err(|e| {
336 ContrastiveLearningError::MatrixOperationFailed(format!(
337 "Array creation failed: {}",
338 e
339 ))
340 })?;
341
342 let actual_negative_count = negative_samples.len();
343 let negative_array = Array2::from_shape_vec(
344 (actual_negative_count, self.embedding_dim),
345 negative_samples.into_iter().flatten().collect(),
346 )
347 .map_err(|e| {
348 ContrastiveLearningError::MatrixOperationFailed(format!(
349 "Array creation failed: {}",
350 e
351 ))
352 })?;
353
354 let loss = self.compute_contrastive_loss(
356 &encoded.view(),
357 &positive_array.view(),
358 &negative_array.view(),
359 )?;
360 epoch_loss += loss;
361 n_batches += 1;
362
363 let gradient_scale = self.learning_rate * loss;
365 let noise_std = gradient_scale * 0.1;
367 let mut encoder_grad = Array2::<f64>::zeros(encoder_weights.dim());
368 let mut context_grad = Array2::<f64>::zeros(context_weights.dim());
369
370 for i in 0..encoder_weights.nrows() {
372 for j in 0..encoder_weights.ncols() {
373 let u1: f64 = rng.random_range(0.0..1.0);
374 let u2: f64 = rng.random_range(0.0..1.0);
375 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
376 encoder_grad[(i, j)] = z * noise_std;
377 }
378 }
379
380 for i in 0..context_weights.nrows() {
382 for j in 0..context_weights.ncols() {
383 let u1: f64 = rng.random_range(0.0..1.0);
384 let u2: f64 = rng.random_range(0.0..1.0);
385 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
386 context_grad[(i, j)] = z * noise_std;
387 }
388 }
389
390 encoder_weights = encoder_weights - encoder_grad;
391 context_weights = context_weights - context_grad;
392 }
393
394 if n_batches > 0 {
395 epoch_loss /= n_batches as f64;
396 }
397
398 if epoch % 10 == 0 {
400 println!("Epoch {}: Loss = {:.6}", epoch, epoch_loss);
401 }
402 }
403
404 Ok(FittedContrastivePredictiveCoding {
405 base_model: self.clone(),
406 encoder_weights,
407 context_weights,
408 classes: Array1::from_vec(unique_classes),
409 n_classes,
410 })
411 }
412}
413
414impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedContrastivePredictiveCoding {
415 fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
416 let embeddings = X.dot(&self.encoder_weights);
417
418 let context = embeddings.dot(&self.context_weights);
419
420 let n_samples = X.dim().0;
422 let mut predictions = Array1::zeros(n_samples);
423
424 for i in 0..n_samples {
425 let ctx = context.row(i);
426 let mut best_class = self.classes[0];
427 let mut best_score = f64::NEG_INFINITY;
428
429 for &class in self.classes.iter() {
430 let score = ctx.sum() + class as f64 * 0.1;
432 if score > best_score {
433 best_score = score;
434 best_class = class;
435 }
436 }
437
438 predictions[i] = best_class;
439 }
440
441 Ok(predictions)
442 }
443}
444
445impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedContrastivePredictiveCoding {
446 fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
447 let embeddings = X.dot(&self.encoder_weights);
448
449 let context = embeddings.dot(&self.context_weights);
450
451 let n_samples = X.dim().0;
452 let mut probabilities = Array2::zeros((n_samples, self.n_classes));
453
454 for i in 0..n_samples {
455 let ctx = context.row(i);
456 let mut scores = Vec::new();
457
458 for &class in self.classes.iter() {
459 let score = ctx.sum() + class as f64 * 0.1;
461 scores.push(score);
462 }
463
464 let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
466 let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
467 let sum_exp: f64 = exp_scores.iter().sum();
468
469 for (j, &exp_score) in exp_scores.iter().enumerate() {
470 probabilities[[i, j]] = exp_score / sum_exp;
471 }
472 }
473
474 Ok(probabilities)
475 }
476}