sklears_semi_supervised/contrastive_learning/
contrastive_predictive_coding.rs

1//! Contrastive Predictive Coding (CPC) implementation for semi-supervised learning
2
3use super::{ContrastiveLearningError, *};
4use scirs2_core::random::rand_prelude::SliceRandom;
5
/// Contrastive Predictive Coding (CPC) for semi-supervised learning
///
/// CPC learns representations by predicting future observations from past contexts
/// in a contrastive manner. It maximizes mutual information between contexts and
/// positive samples while minimizing it for negative samples.
#[derive(Debug, Clone)]
pub struct ContrastivePredictiveCoding {
    /// Dimensionality of the encoder output (embedding space).
    pub embedding_dim: usize,
    /// Dimensionality of the context-network output.
    pub hidden_dim: usize,
    /// Number of past steps forming a context. NOTE(review): not read by the
    /// visible training code — confirm intended use.
    pub context_length: usize,
    /// Number of future steps to predict contrastively. NOTE(review): not read
    /// by the visible training code — confirm intended use.
    pub prediction_steps: usize,
    /// Softmax temperature for the contrastive (InfoNCE-style) loss; must be > 0.
    pub temperature: f64,
    /// Step size scaling the weight updates during `fit`.
    pub learning_rate: f64,
    /// Number of samples per mini-batch; must be > 0.
    pub batch_size: usize,
    /// Number of passes over the training data.
    pub max_epochs: usize,
    /// Negative samples drawn per anchor for the contrastive loss.
    pub negative_samples: usize,
    /// Seed for the internal RNG; `None` falls back to a fixed seed (42).
    pub random_state: Option<u64>,
}
34
35impl Default for ContrastivePredictiveCoding {
36    fn default() -> Self {
37        Self {
38            embedding_dim: 128,
39            hidden_dim: 256,
40            context_length: 8,
41            prediction_steps: 4,
42            temperature: 0.1,
43            learning_rate: 0.001,
44            batch_size: 32,
45            max_epochs: 100,
46            negative_samples: 16,
47            random_state: None,
48        }
49    }
50}
51
52impl ContrastivePredictiveCoding {
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
58        self.embedding_dim = embedding_dim;
59        self
60    }
61
62    pub fn hidden_dim(mut self, hidden_dim: usize) -> Self {
63        self.hidden_dim = hidden_dim;
64        self
65    }
66
67    pub fn context_length(mut self, context_length: usize) -> Self {
68        self.context_length = context_length;
69        self
70    }
71
72    pub fn prediction_steps(mut self, prediction_steps: usize) -> Self {
73        self.prediction_steps = prediction_steps;
74        self
75    }
76
77    pub fn temperature(mut self, temperature: f64) -> Result<Self> {
78        if temperature <= 0.0 {
79            return Err(ContrastiveLearningError::InvalidTemperature(temperature).into());
80        }
81        self.temperature = temperature;
82        Ok(self)
83    }
84
85    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
86        self.learning_rate = learning_rate;
87        self
88    }
89
90    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
91        if batch_size == 0 {
92            return Err(ContrastiveLearningError::InvalidBatchSize(batch_size).into());
93        }
94        self.batch_size = batch_size;
95        Ok(self)
96    }
97
98    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
99        self.max_epochs = max_epochs;
100        self
101    }
102
103    pub fn negative_samples(mut self, negative_samples: usize) -> Self {
104        self.negative_samples = negative_samples;
105        self
106    }
107
108    pub fn random_state(mut self, random_state: u64) -> Self {
109        self.random_state = Some(random_state);
110        self
111    }
112
113    fn encode(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
114        let (n_samples, n_features) = x.dim();
115        let mut rng = match self.random_state {
116            Some(seed) => Random::seed(seed),
117            None => Random::seed(42),
118        };
119
120        // Simple linear encoder for demonstration - create weights manually
121        let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
122        for i in 0..n_features {
123            for j in 0..self.embedding_dim {
124                // Generate normal distributed random number using Box-Muller transform
125                let u1: f64 = rng.random_range(0.0..1.0);
126                let u2: f64 = rng.random_range(0.0..1.0);
127                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
128                encoder_weights[(i, j)] = z * 0.1; // mean=0.0, std=0.1
129            }
130        }
131
132        Ok(x.dot(&encoder_weights))
133    }
134
135    fn context_network(&self, embeddings: &ArrayView2<f64>) -> Result<Array2<f64>> {
136        let (n_samples, embedding_dim) = embeddings.dim();
137        if embedding_dim != self.embedding_dim {
138            return Err(ContrastiveLearningError::EmbeddingDimensionMismatch {
139                expected: self.embedding_dim,
140                actual: embedding_dim,
141            }
142            .into());
143        }
144
145        let mut rng = match self.random_state {
146            Some(seed) => Random::seed(seed),
147            None => Random::seed(42),
148        };
149
150        // Simple context network (could be LSTM/GRU in practice)
151        // Create context weights manually
152        let mut context_weights = Array2::<f64>::zeros((self.embedding_dim, self.hidden_dim));
153        for i in 0..self.embedding_dim {
154            for j in 0..self.hidden_dim {
155                // Generate normal distributed random number using Box-Muller transform
156                let u1: f64 = rng.random_range(0.0..1.0);
157                let u2: f64 = rng.random_range(0.0..1.0);
158                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
159                context_weights[(i, j)] = z * 0.1; // mean=0.0, std=0.1
160            }
161        }
162
163        Ok(embeddings.dot(&context_weights))
164    }
165
166    fn compute_contrastive_loss(
167        &self,
168        context: &ArrayView2<f64>,
169        positive: &ArrayView2<f64>,
170        negatives: &ArrayView2<f64>,
171    ) -> Result<f64> {
172        let batch_size = context.dim().0;
173        let mut total_loss = 0.0;
174
175        for i in 0..batch_size {
176            let ctx = context.row(i);
177            let pos = positive.row(i);
178
179            // Compute positive score
180            let pos_score = ctx.dot(&pos) / self.temperature;
181
182            // Compute negative scores
183            let mut neg_scores = Vec::new();
184            for j in 0..self.negative_samples {
185                if j < negatives.dim().0 {
186                    let neg = negatives.row(j);
187                    let neg_score = ctx.dot(&neg) / self.temperature;
188                    neg_scores.push(neg_score);
189                }
190            }
191
192            // Compute softmax loss
193            let max_score =
194                pos_score.max(neg_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max));
195            let exp_pos = (pos_score - max_score).exp();
196            let exp_neg_sum: f64 = neg_scores.iter().map(|&s| (s - max_score).exp()).sum();
197
198            let loss = -((exp_pos / (exp_pos + exp_neg_sum)).ln());
199            total_loss += loss;
200        }
201
202        Ok(total_loss / batch_size as f64)
203    }
204}
205
/// Fitted Contrastive Predictive Coding model
///
/// Produced by `ContrastivePredictiveCoding::fit`; stores the learned weight
/// matrices and the label set observed during training.
#[derive(Debug, Clone)]
pub struct FittedContrastivePredictiveCoding {
    /// Hyper-parameters of the estimator that produced this fit.
    pub base_model: ContrastivePredictiveCoding,
    /// Linear encoder weights, shape `(n_features, embedding_dim)`.
    pub encoder_weights: Array2<f64>,
    /// Context-network weights, shape `(embedding_dim, hidden_dim)`.
    pub context_weights: Array2<f64>,
    /// Distinct labels seen during fitting (order unspecified — collected via
    /// a `HashSet` in `fit`).
    pub classes: Array1<i32>,
    /// Number of distinct classes (`classes.len()`).
    pub n_classes: usize,
}
220
impl Estimator for ContrastivePredictiveCoding {
    // The estimator doubles as its own configuration object.
    type Config = ContrastivePredictiveCoding;
    type Error = ContrastiveLearningError;
    type Float = f64;

    /// Returns the configuration, which is the estimator itself.
    fn config(&self) -> &Self::Config {
        self
    }
}
230
231impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for ContrastivePredictiveCoding {
232    type Fitted = FittedContrastivePredictiveCoding;
233
234    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
235        let (n_samples, n_features) = X.dim();
236
237        // Check for sufficient labeled samples
238        let labeled_count = y.iter().filter(|&&label| label != -1).count();
239        if labeled_count < 2 {
240            return Err(ContrastiveLearningError::InsufficientLabeledSamples.into());
241        }
242
243        let mut rng = match self.random_state {
244            Some(seed) => Random::seed(seed),
245            None => Random::seed(42),
246        };
247
248        // Initialize encoder and context networks
249        let mut encoder_weights = Array2::<f64>::zeros((n_features, self.embedding_dim));
250        let mut context_weights = Array2::<f64>::zeros((self.embedding_dim, self.hidden_dim));
251
252        // Fill encoder weights with normal distribution (mean=0.0, std=0.1)
253        for i in 0..n_features {
254            for j in 0..self.embedding_dim {
255                let u1: f64 = rng.random_range(0.0..1.0);
256                let u2: f64 = rng.random_range(0.0..1.0);
257                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
258                encoder_weights[(i, j)] = z * 0.1;
259            }
260        }
261
262        // Fill context weights with normal distribution (mean=0.0, std=0.1)
263        for i in 0..self.embedding_dim {
264            for j in 0..self.hidden_dim {
265                let u1: f64 = rng.random_range(0.0..1.0);
266                let u2: f64 = rng.random_range(0.0..1.0);
267                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
268                context_weights[(i, j)] = z * 0.1;
269            }
270        }
271
272        // Get unique classes
273        let unique_classes: Vec<i32> = y
274            .iter()
275            .cloned()
276            .filter(|&label| label != -1)
277            .collect::<std::collections::HashSet<_>>()
278            .into_iter()
279            .collect();
280        let n_classes = unique_classes.len();
281
282        // Training loop
283        for epoch in 0..self.max_epochs {
284            // Generate batches
285            let batch_indices: Vec<usize> = (0..n_samples).collect();
286            let mut batch_indices = batch_indices;
287            batch_indices.shuffle(&mut rng);
288
289            let mut epoch_loss = 0.0;
290            let mut n_batches = 0;
291
292            for batch_start in (0..n_samples).step_by(self.batch_size) {
293                let batch_end = std::cmp::min(batch_start + self.batch_size, n_samples);
294                let batch_size = batch_end - batch_start;
295
296                if batch_size < 2 {
297                    continue;
298                }
299
300                // Get batch data
301                let batch_X = X.slice(scirs2_core::ndarray::s![batch_start..batch_end, ..]);
302
303                // Encode batch
304                let encoded = batch_X.dot(&encoder_weights);
305
306                // Context network
307                let context = encoded.dot(&context_weights);
308
309                // Generate positive and negative samples
310                let mut positive_samples = Vec::new();
311                let mut negative_samples = Vec::new();
312
313                for i in 0..batch_size {
314                    // Use next sample as positive (temporal structure)
315                    let pos_idx = if i + 1 < batch_size { i + 1 } else { 0 };
316                    positive_samples.push(encoded.row(pos_idx).to_owned());
317
318                    // Random negative samples
319                    let max_negatives = std::cmp::min(self.negative_samples, batch_size - 1);
320                    let mut neg_count = 0;
321                    while neg_count < max_negatives {
322                        let neg_idx = rng.gen_range(0..batch_size);
323                        if neg_idx != i {
324                            negative_samples.push(encoded.row(neg_idx).to_owned());
325                            neg_count += 1;
326                        }
327                    }
328                }
329
330                // Convert to arrays
331                let positive_array = Array2::from_shape_vec(
332                    (batch_size, self.embedding_dim),
333                    positive_samples.into_iter().flatten().collect(),
334                )
335                .map_err(|e| {
336                    ContrastiveLearningError::MatrixOperationFailed(format!(
337                        "Array creation failed: {}",
338                        e
339                    ))
340                })?;
341
342                let actual_negative_count = negative_samples.len();
343                let negative_array = Array2::from_shape_vec(
344                    (actual_negative_count, self.embedding_dim),
345                    negative_samples.into_iter().flatten().collect(),
346                )
347                .map_err(|e| {
348                    ContrastiveLearningError::MatrixOperationFailed(format!(
349                        "Array creation failed: {}",
350                        e
351                    ))
352                })?;
353
354                // Compute loss using encoded representations
355                let loss = self.compute_contrastive_loss(
356                    &encoded.view(),
357                    &positive_array.view(),
358                    &negative_array.view(),
359                )?;
360                epoch_loss += loss;
361                n_batches += 1;
362
363                // Simple gradient update (in practice, would use proper backpropagation)
364                let gradient_scale = self.learning_rate * loss;
365                // Create gradient noise manually
366                let noise_std = gradient_scale * 0.1;
367                let mut encoder_grad = Array2::<f64>::zeros(encoder_weights.dim());
368                let mut context_grad = Array2::<f64>::zeros(context_weights.dim());
369
370                // Fill encoder grad with normal noise
371                for i in 0..encoder_weights.nrows() {
372                    for j in 0..encoder_weights.ncols() {
373                        let u1: f64 = rng.random_range(0.0..1.0);
374                        let u2: f64 = rng.random_range(0.0..1.0);
375                        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
376                        encoder_grad[(i, j)] = z * noise_std;
377                    }
378                }
379
380                // Fill context grad with normal noise
381                for i in 0..context_weights.nrows() {
382                    for j in 0..context_weights.ncols() {
383                        let u1: f64 = rng.random_range(0.0..1.0);
384                        let u2: f64 = rng.random_range(0.0..1.0);
385                        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
386                        context_grad[(i, j)] = z * noise_std;
387                    }
388                }
389
390                encoder_weights = encoder_weights - encoder_grad;
391                context_weights = context_weights - context_grad;
392            }
393
394            if n_batches > 0 {
395                epoch_loss /= n_batches as f64;
396            }
397
398            // Early stopping or convergence check could be added here
399            if epoch % 10 == 0 {
400                println!("Epoch {}: Loss = {:.6}", epoch, epoch_loss);
401            }
402        }
403
404        Ok(FittedContrastivePredictiveCoding {
405            base_model: self.clone(),
406            encoder_weights,
407            context_weights,
408            classes: Array1::from_vec(unique_classes),
409            n_classes,
410        })
411    }
412}
413
414impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedContrastivePredictiveCoding {
415    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
416        let embeddings = X.dot(&self.encoder_weights);
417
418        let context = embeddings.dot(&self.context_weights);
419
420        // Simple nearest class prediction based on context representations
421        let n_samples = X.dim().0;
422        let mut predictions = Array1::zeros(n_samples);
423
424        for i in 0..n_samples {
425            let ctx = context.row(i);
426            let mut best_class = self.classes[0];
427            let mut best_score = f64::NEG_INFINITY;
428
429            for &class in self.classes.iter() {
430                // Simple scoring based on context magnitude (placeholder)
431                let score = ctx.sum() + class as f64 * 0.1;
432                if score > best_score {
433                    best_score = score;
434                    best_class = class;
435                }
436            }
437
438            predictions[i] = best_class;
439        }
440
441        Ok(predictions)
442    }
443}
444
445impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedContrastivePredictiveCoding {
446    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
447        let embeddings = X.dot(&self.encoder_weights);
448
449        let context = embeddings.dot(&self.context_weights);
450
451        let n_samples = X.dim().0;
452        let mut probabilities = Array2::zeros((n_samples, self.n_classes));
453
454        for i in 0..n_samples {
455            let ctx = context.row(i);
456            let mut scores = Vec::new();
457
458            for &class in self.classes.iter() {
459                // Simple scoring based on context (placeholder)
460                let score = ctx.sum() + class as f64 * 0.1;
461                scores.push(score);
462            }
463
464            // Softmax normalization
465            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
466            let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
467            let sum_exp: f64 = exp_scores.iter().sum();
468
469            for (j, &exp_score) in exp_scores.iter().enumerate() {
470                probabilities[[i, j]] = exp_score / sum_exp;
471            }
472        }
473
474        Ok(probabilities)
475    }
476}