sklears_semi_supervised/deep_learning/autoregressive_models.rs

//! Autoregressive models for generative semi-supervised learning
//!
//! This module implements autoregressive models that generate data by modeling
//! the conditional probability distribution p(x_t | x_{1:t-1}) of sequential data.
//! These models support semi-supervised learning by learning the data distribution
//! from all samples while incorporating label information through a jointly
//! trained classification head.

use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::Random;
use sklears_core::error::SklearsError;
use sklears_core::traits::{Fit, Predict, PredictProba};

/// Autoregressive neural network for semi-supervised learning
///
/// This implements an autoregressive model that learns to generate sequences
/// by predicting the next element given the previous elements. For semi-supervised
/// learning, it combines generative modeling with discriminative classification.
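///
/// # Example
///
/// A minimal sketch of the builder API defined below (illustrative only):
///
/// ```ignore
/// let model = AutoregressiveModel::new()
///     .hidden_dims(vec![64, 32])
///     .n_classes(2)
///     .input_dim(10)
///     .learning_rate(0.001);
/// ```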
#[derive(Debug, Clone)]
pub struct AutoregressiveModel {
    /// Hidden layer dimensions
    hidden_dims: Vec<usize>,
    /// Number of classes for classification
    n_classes: usize,
    /// Input dimension
    input_dim: usize,
    /// Sequence length for autoregressive modeling
    sequence_length: usize,
    /// Learning rate for gradient descent
    learning_rate: f64,
    /// Number of training epochs
    epochs: usize,
    /// Regularization parameter
    regularization: f64,
    /// Temperature for softmax sampling
    temperature: f64,
    /// Weight for classification loss vs reconstruction loss
    classification_weight: f64,
    /// Model parameters
    weights: Vec<Array2<f64>>,
    biases: Vec<Array1<f64>>,
    /// Classification head parameters
    class_weights: Array2<f64>,
    class_bias: Array1<f64>,
    /// Whether the model has been fitted
    fitted: bool,
}

impl Default for AutoregressiveModel {
    fn default() -> Self {
        Self::new()
    }
}

impl AutoregressiveModel {
    /// Create a new autoregressive model
    pub fn new() -> Self {
        Self {
            hidden_dims: vec![64, 32],
            n_classes: 2,
            input_dim: 10,
            sequence_length: 10,
            learning_rate: 0.001,
            epochs: 100,
            regularization: 0.01,
            temperature: 1.0,
            classification_weight: 1.0,
            weights: Vec::new(),
            biases: Vec::new(),
            class_weights: Array2::zeros((0, 0)),
            class_bias: Array1::zeros(0),
            fitted: false,
        }
    }

    /// Set the hidden layer dimensions
    pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
        self.hidden_dims = dims;
        self
    }

    /// Set the number of classes
    pub fn n_classes(mut self, n_classes: usize) -> Self {
        self.n_classes = n_classes;
        self
    }

    /// Set the input dimension
    pub fn input_dim(mut self, input_dim: usize) -> Self {
        self.input_dim = input_dim;
        self
    }

    /// Set the sequence length
    pub fn sequence_length(mut self, length: usize) -> Self {
        self.sequence_length = length;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the number of epochs
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, reg: f64) -> Self {
        self.regularization = reg;
        self
    }

    /// Set the temperature for sampling
    pub fn temperature(mut self, temp: f64) -> Self {
        self.temperature = temp;
        self
    }

    /// Set the classification weight
    pub fn classification_weight(mut self, weight: f64) -> Self {
        self.classification_weight = weight;
        self
    }

    /// Initialize the model parameters
    fn initialize_parameters(&mut self) -> Result<(), SklearsError> {
        let mut layer_dims = vec![self.input_dim];
        layer_dims.extend_from_slice(&self.hidden_dims);
        layer_dims.push(self.input_dim); // Output dimension for reconstruction

        self.weights.clear();
        self.biases.clear();

        // Initialize weights using Xavier initialization
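        // Glorot/Xavier uniform: each weight is drawn from U[-b, b] with
        // b = sqrt(6 / (fan_in + fan_out)), which keeps activation variance
        // roughly constant across layers.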
        let mut rng = Random::default();
        for i in 0..layer_dims.len() - 1 {
            let fan_in = layer_dims[i];
            let fan_out = layer_dims[i + 1];
            let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();

            let mut weight = Array2::<f64>::zeros((fan_in, fan_out));
            for row in 0..fan_in {
                for col in 0..fan_out {
                    // Draw a uniformly distributed value in [-scale, scale]
                    let u: f64 = rng.random_range(0.0..1.0);
                    weight[(row, col)] = u * (2.0 * scale) - scale;
                }
            }
            let bias = Array1::zeros(fan_out);

            self.weights.push(weight);
            self.biases.push(bias);
        }

        // Initialize the classification head
        let last_hidden_dim = self.hidden_dims.last().unwrap_or(&self.input_dim);
        let class_scale = (6.0 / (last_hidden_dim + self.n_classes) as f64).sqrt();

        let mut class_weights = Array2::<f64>::zeros((*last_hidden_dim, self.n_classes));
        for row in 0..*last_hidden_dim {
            for col in 0..self.n_classes {
                // Draw a uniformly distributed value in [-class_scale, class_scale]
                let u: f64 = rng.random_range(0.0..1.0);
                class_weights[(row, col)] = u * (2.0 * class_scale) - class_scale;
            }
        }
        self.class_weights = class_weights;
        self.class_bias = Array1::zeros(self.n_classes);

        Ok(())
    }

    /// Apply ReLU activation function
    fn relu(&self, x: &Array1<f64>) -> Array1<f64> {
        x.mapv(|v| v.max(0.0))
    }

    /// Apply softmax activation function
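    ///
    /// Subtracts the maximum logit before exponentiating for numerical
    /// stability, and divides the shifted logits by `temperature` (values
    /// below 1.0 sharpen the distribution, values above 1.0 flatten it).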
    fn softmax(&self, x: &Array1<f64>) -> Array1<f64> {
        let max_val = x.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_x = x.mapv(|v| ((v - max_val) / self.temperature).exp());
        let sum_exp = exp_x.sum();
        exp_x / sum_exp
    }

    /// Forward pass through the autoregressive network
    ///
    /// Returns the reconstruction output and the class probabilities.
    fn forward(&self, input: &ArrayView1<f64>) -> Result<(Array1<f64>, Array1<f64>), SklearsError> {
        let mut activation = input.to_owned();
        let mut activations = vec![activation.clone()];

        // Forward pass through all layers
        for (i, (weight, bias)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            let linear = activation.dot(weight) + bias;

            if i < self.weights.len() - 1 {
                // ReLU for hidden layers
                activation = self.relu(&linear);
            } else {
                // Linear output for reconstruction
                activation = linear;
            }
            activations.push(activation.clone());
        }

        // Extract features from the last hidden layer for classification
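        // activations[k] holds the output of layer k (activations[0] is the
        // input), so index weights.len() - 1 is the last hidden activation,
        // i.e. the input to the final reconstruction layer.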
        let feature_layer_idx = self.weights.len() - 1;
        let features = &activations[feature_layer_idx];

        // Classification output
        let class_logits = features.dot(&self.class_weights) + &self.class_bias;
        let class_probs = self.softmax(&class_logits);

        Ok((activation, class_probs))
    }

    /// Compute autoregressive loss for a sequence
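    ///
    /// Each position t >= 1 is predicted from its prefix x_{1:t-1} (teacher
    /// forcing); the squared errors are averaged over the sequence, giving a
    /// mean squared one-step-ahead prediction error.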
    fn autoregressive_loss(&self, sequence: &ArrayView1<f64>) -> Result<f64, SklearsError> {
        let mut total_loss = 0.0;
        let seq_len = sequence.len();

        if seq_len < 2 {
            return Err(SklearsError::InvalidInput(
                "Sequence too short for autoregressive modeling".to_string(),
            ));
        }

        // Compute the loss at each position in the sequence
        for i in 1..seq_len {
            let context = sequence.slice(s![..i]);
            let target = sequence[i];

            // Pad (or truncate) the context to the input dimension
            let mut padded_context = Array1::zeros(self.input_dim);
            let copy_len = context.len().min(self.input_dim);
            padded_context
                .slice_mut(s![..copy_len])
                .assign(&context.slice(s![..copy_len]));

            let (reconstruction, _) = self.forward(&padded_context.view())?;
            let prediction = reconstruction[i % self.input_dim];

            // Mean squared error for reconstruction
            total_loss += (prediction - target).powi(2);
        }

        Ok(total_loss / (seq_len - 1) as f64)
    }

    /// Generate a sequence using the autoregressive model
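    ///
    /// Starting from `initial_context`, the model repeatedly predicts the
    /// next value and feeds it back into a sliding context window of at most
    /// `sequence_length` elements (a greedy autoregressive rollout).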
    pub fn generate_sequence(
        &self,
        initial_context: &ArrayView1<f64>,
        length: usize,
    ) -> Result<Array1<f64>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "generating sequences".to_string(),
            });
        }

        let mut sequence = Vec::new();
        let mut context = initial_context.to_owned();

        for _ in 0..length {
            // Pad (or truncate) the context to the input dimension
            let mut padded_context = Array1::zeros(self.input_dim);
            let copy_len = context.len().min(self.input_dim);
            padded_context
                .slice_mut(s![..copy_len])
                .assign(&context.slice(s![..copy_len]));

            let (reconstruction, _) = self.forward(&padded_context.view())?;
            let next_value = reconstruction[sequence.len() % self.input_dim];

            sequence.push(next_value);

            // Update the context with the new value
            if context.len() >= self.sequence_length {
                // Shift the context window left by one
                for i in 0..context.len() - 1 {
                    context[i] = context[i + 1];
                }
                let context_len = context.len();
                context[context_len - 1] = next_value;
            } else {
                // Append to the context
                let mut new_context = Array1::zeros(context.len() + 1);
                new_context.slice_mut(s![..context.len()]).assign(&context);
                new_context[context.len()] = next_value;
                context = new_context;
            }
        }

        Ok(Array1::from_vec(sequence))
    }

    /// Compute log-likelihood of a sequence
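    ///
    /// Each one-step-ahead prediction is scored under a unit-variance
    /// Gaussian: log p(x_t | x_{1:t-1}) = -0.5 (x̂_t - x_t)^2 - 0.5 ln(2π),
    /// summed over positions t = 2..=T.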
    pub fn log_likelihood(&self, sequence: &ArrayView1<f64>) -> Result<f64, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "computing log-likelihood".to_string(),
            });
        }

        let mut log_likelihood = 0.0;
        let seq_len = sequence.len();

        if seq_len < 2 {
            return Err(SklearsError::InvalidInput(
                "Sequence too short for log-likelihood computation".to_string(),
            ));
        }

        // Accumulate the log-likelihood at each position
        for i in 1..seq_len {
            let context = sequence.slice(s![..i]);
            let target = sequence[i];

            // Pad (or truncate) the context to the input dimension
            let mut padded_context = Array1::zeros(self.input_dim);
            let copy_len = context.len().min(self.input_dim);
            padded_context
                .slice_mut(s![..copy_len])
                .assign(&context.slice(s![..copy_len]));

            let (reconstruction, _) = self.forward(&padded_context.view())?;
            let prediction = reconstruction[i % self.input_dim];

            // Gaussian likelihood with unit variance
            let diff = prediction - target;
            log_likelihood -= 0.5 * diff * diff + 0.5 * (2.0 * std::f64::consts::PI).ln();
        }

        Ok(log_likelihood)
    }
}

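// Semi-supervised convention: a label of -1 marks an unlabeled sample. Every
// sample contributes the autoregressive reconstruction loss; labeled samples
// additionally contribute a cross-entropy term scaled by `classification_weight`.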
impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for AutoregressiveModel {
    type Fitted = AutoregressiveModel;

    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> Result<Self::Fitted, SklearsError> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        let mut model = self;
        model.input_dim = X.ncols();
        model.initialize_parameters()?;

        let n_samples = X.nrows();
        let labeled_mask: Vec<bool> = y.iter().map(|&label| label != -1).collect();
        let n_labeled = labeled_mask.iter().filter(|&&labeled| labeled).count();

        if n_labeled == 0 {
            return Err(SklearsError::InvalidInput(
                "At least one labeled sample is required".to_string(),
            ));
        }

        // Training loop
        for epoch in 0..model.epochs {
            let mut total_loss = 0.0;
            let mut n_processed = 0;

            for i in 0..n_samples {
                let sample = X.row(i);
                let label = y[i];

                // Reconstruction loss (unsupervised, computed for every sample)
                let reconstruction_loss = model.autoregressive_loss(&sample)?;
                total_loss += reconstruction_loss;

                // Classification loss (supervised, labeled samples only)
                if labeled_mask[i] {
                    let (_, class_probs) = model.forward(&sample)?;
                    let target_class = label as usize;

                    if target_class >= model.n_classes {
                        return Err(SklearsError::InvalidInput(format!(
                            "Label {} exceeds number of classes {}",
                            target_class, model.n_classes
                        )));
                    }

                    // Cross-entropy loss
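                    // Negative log-probability of the true class: -ln p(y_i | x_i)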
                    let class_loss = -class_probs[target_class].ln();
                    total_loss += model.classification_weight * class_loss;
                }

                n_processed += 1;
            }

            // NOTE: this simplified implementation does not backpropagate the
            // loss; only the weight decay below modifies the parameters. A full
            // implementation would use proper gradients of the combined objective.
            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Average loss = {:.4}",
                    epoch,
                    total_loss / n_processed as f64
                );
            }

            // Apply regularization
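            // L2 weight decay: w <- w * (1 - learning_rate * regularization)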
            for weight in &mut model.weights {
                weight.mapv_inplace(|w| w * (1.0 - model.learning_rate * model.regularization));
            }
        }

        model.fitted = true;
        Ok(model)
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for AutoregressiveModel {
    fn predict(&self, X: &ArrayView2<f64>) -> Result<Array1<i32>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "making predictions".to_string(),
            });
        }

        let mut predictions = Array1::zeros(X.nrows());

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let (_, class_probs) = self.forward(&sample)?;
            let predicted_class = class_probs
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .unwrap()
                .0;
            predictions[i] = predicted_class as i32;
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for AutoregressiveModel {
    fn predict_proba(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "making predictions".to_string(),
            });
        }

        let mut probabilities = Array2::zeros((X.nrows(), self.n_classes));

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let (_, class_probs) = self.forward(&sample)?;
            probabilities.row_mut(i).assign(&class_probs);
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_autoregressive_model_creation() {
        let model = AutoregressiveModel::new()
            .hidden_dims(vec![32, 16])
            .n_classes(3)
            .input_dim(5)
            .sequence_length(8)
            .learning_rate(0.01)
            .epochs(50)
            .regularization(0.1)
            .temperature(0.8)
            .classification_weight(2.0);

        assert_eq!(model.hidden_dims, vec![32, 16]);
        assert_eq!(model.n_classes, 3);
        assert_eq!(model.input_dim, 5);
        assert_eq!(model.sequence_length, 8);
        assert_eq!(model.learning_rate, 0.01);
        assert_eq!(model.epochs, 50);
        assert_eq!(model.regularization, 0.1);
        assert_eq!(model.temperature, 0.8);
        assert_eq!(model.classification_weight, 2.0);
    }

    #[test]
    fn test_autoregressive_model_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, -1, 0]; // -1 indicates unlabeled

        let model = AutoregressiveModel::new()
            .n_classes(2)
            .input_dim(3)
            .epochs(10)
            .learning_rate(0.01);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted_model.predict(&X.view()).unwrap();
        let probabilities = fitted_model.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_eq!(probabilities.dim(), (4, 2));

        // Check that each row of probabilities sums to 1
        for i in 0..4 {
            let sum: f64 = probabilities.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_autoregressive_model_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![-1, -1]; // All unlabeled

        let model = AutoregressiveModel::new()
            .n_classes(2)
            .input_dim(3)
            .epochs(10);

        let result = model.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_autoregressive_model_invalid_dimensions() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![0]; // Mismatched number of samples

        let model = AutoregressiveModel::new();
        let result = model.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_autoregressive_model_generate_sequence() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];
        let y = array![0, 1, 0];

        let model = AutoregressiveModel::new()
            .n_classes(2)
            .input_dim(3)
            .epochs(5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let initial_context = array![1.0, 2.0];
        let sequence = fitted_model
            .generate_sequence(&initial_context.view(), 5)
            .unwrap();

        assert_eq!(sequence.len(), 5);
    }

    #[test]
    fn test_autoregressive_model_log_likelihood() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![0, 1];

        let model = AutoregressiveModel::new()
            .n_classes(2)
            .input_dim(3)
            .epochs(5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let sequence = array![1.0, 2.0, 3.0, 4.0];
        let log_likelihood = fitted_model.log_likelihood(&sequence.view()).unwrap();

        assert!(log_likelihood.is_finite());
    }

    #[test]
    fn test_softmax_computation() {
        let model = AutoregressiveModel::new().temperature(1.0);
        let logits = array![1.0, 2.0, 3.0];
        let probs = model.softmax(&logits);

        let sum: f64 = probs.sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Ascending logits must give ascending probabilities
        assert!(probs[0] < probs[1]);
        assert!(probs[1] < probs[2]);
    }

    #[test]
    fn test_relu_activation() {
        let model = AutoregressiveModel::new();
        let input = array![-1.0, 0.0, 1.0, 2.0];
        let output = model.relu(&input);

        assert_eq!(output, array![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_autoregressive_model_not_fitted_error() {
        let model = AutoregressiveModel::new();
        let X = array![[1.0, 2.0, 3.0]];

        let result = model.predict(&X.view());
        assert!(result.is_err());

        let result = model.predict_proba(&X.view());
        assert!(result.is_err());

        let sequence = array![1.0, 2.0, 3.0];
        let result = model.generate_sequence(&sequence.view(), 5);
        assert!(result.is_err());

        let result = model.log_likelihood(&sequence.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_autoregressive_model_with_different_parameters() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, 2];

        let model = AutoregressiveModel::new()
            .hidden_dims(vec![8, 4])
            .n_classes(3)
            .input_dim(4)
            .sequence_length(6)
            .learning_rate(0.1)
            .epochs(3)
            .regularization(0.01)
            .temperature(0.5)
            .classification_weight(0.5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted_model.predict(&X.view()).unwrap();
        let probabilities = fitted_model.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        assert_eq!(probabilities.dim(), (3, 3));
    }

    #[test]
    fn test_autoregressive_loss_computation() {
        let mut model = AutoregressiveModel::new().input_dim(3).hidden_dims(vec![4]);
        model.initialize_parameters().unwrap();

        let sequence = array![1.0, 2.0, 3.0, 4.0];
        let loss = model.autoregressive_loss(&sequence.view()).unwrap();

        assert!(loss >= 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_autoregressive_loss_short_sequence() {
        let mut model = AutoregressiveModel::new().input_dim(3).hidden_dims(vec![4]);
        model.initialize_parameters().unwrap();

        let sequence = array![1.0]; // Too short
        let result = model.autoregressive_loss(&sequence.view());
        assert!(result.is_err());
    }
}