scirs2_series/advanced_training_modules/variational.rs

//! Variational Autoencoder for Time Series
//!
//! This module provides implementations for variational autoencoders specifically designed
//! for time series data, including uncertainty quantification and probabilistic modeling.
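//!
//! A minimal usage sketch (illustrative only; the `use` path below assumes this
//! module is exposed as `scirs2_series::advanced_training_modules::variational`):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array2;
//! use scirs2_series::advanced_training_modules::variational::TimeSeriesVAE;
//!
//! // 10 time steps, 3 features, a 5-dimensional latent space, and 16 hidden
//! // units in both the encoder and the decoder.
//! let vae = TimeSeriesVAE::<f64>::new(10, 3, 5, 16, 16);
//! let input = Array2::<f64>::zeros((10, 3));
//! let output = vae.forward(&input).unwrap();
//! // A training loss would combine both terms, e.g. recon + beta * KL.
//! let loss = output.reconstruction_loss + output.kl_divergence;
//! ```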

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;

/// Encoder weight tuple: `(w1, b1, w_mean, b_mean, w_logvar, b_logvar)`.
type EncoderWeights<F> = (
    Array2<F>,
    Array1<F>,
    Array2<F>,
    Array1<F>,
    Array2<F>,
    Array1<F>,
);

/// Variational Autoencoder for Time Series with Uncertainty Quantification
#[derive(Debug)]
pub struct TimeSeriesVAE<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Encoder parameters
    encoder_params: Array2<F>,
    /// Decoder parameters
    decoder_params: Array2<F>,
    /// Latent dimension
    latent_dim: usize,
    /// Input sequence length
    seq_len: usize,
    /// Feature dimension
    feature_dim: usize,
    /// Encoder hidden layer width
    encoder_hidden: usize,
    /// Decoder hidden layer width
    decoder_hidden: usize,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
    TimeSeriesVAE<F>
{
    /// Create a new Time Series VAE with the given dimensions.
    pub fn new(
        seq_len: usize,
        feature_dim: usize,
        latent_dim: usize,
        encoder_hidden: usize,
        decoder_hidden: usize,
    ) -> Self {
        let input_size = seq_len * feature_dim;

        // Initialize encoder parameters (input -> hidden -> latent_mean, latent_logvar)
        let encoder_param_count = input_size * encoder_hidden
            + encoder_hidden
            + encoder_hidden * latent_dim * 2
            + latent_dim * 2;
        let mut encoder_params = Array2::zeros((1, encoder_param_count));

        // Initialize decoder parameters (latent -> hidden -> output)
        let decoder_param_count =
            latent_dim * decoder_hidden + decoder_hidden + decoder_hidden * input_size + input_size;
        let mut decoder_params = Array2::zeros((1, decoder_param_count));

        // Xavier-style scaling, filled with a deterministic pseudo-random pattern
        // so that initialization is reproducible without an RNG dependency
        let encoder_scale = F::from(2.0).unwrap() / F::from(input_size + latent_dim).unwrap();
        let decoder_scale = F::from(2.0).unwrap() / F::from(latent_dim + input_size).unwrap();

        for i in 0..encoder_param_count {
            let val = ((i * 19) % 1000) as f64 / 1000.0 - 0.5;
            encoder_params[[0, i]] = F::from(val).unwrap() * encoder_scale.sqrt();
        }

        for i in 0..decoder_param_count {
            let val = ((i * 31) % 1000) as f64 / 1000.0 - 0.5;
            decoder_params[[0, i]] = F::from(val).unwrap() * decoder_scale.sqrt();
        }

        Self {
            encoder_params,
            decoder_params,
            latent_dim,
            seq_len,
            feature_dim,
            encoder_hidden,
            decoder_hidden,
        }
    }

    /// Encode a time series into the parameters of a diagonal Gaussian over the
    /// latent space, returning `(mean, logvar)`.
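    ///
    /// A shape-level sketch (left `ignore`d; imports are omitted and the public
    /// path of this type is an assumption):
    ///
    /// ```ignore
    /// let vae = TimeSeriesVAE::<f64>::new(5, 2, 3, 8, 8);
    /// let input = Array2::<f64>::zeros((5, 2));
    /// let (mean, logvar) = vae.encode(&input).unwrap();
    /// assert_eq!((mean.len(), logvar.len()), (3, 3));
    /// ```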
    pub fn encode(&self, input: &Array2<F>) -> Result<(Array1<F>, Array1<F>)> {
        // Flatten input
        let input_flat = self.flatten_input(input);

        // Extract encoder weights
        let (w1, b1, w_mean, b_mean, w_logvar, b_logvar) = self.extract_encoder_weights();

        // Forward through encoder
        let mut hidden = Array1::zeros(self.encoder_hidden);
        for i in 0..self.encoder_hidden {
            let mut sum = b1[i];
            for j in 0..input_flat.len() {
                sum = sum + w1[[i, j]] * input_flat[j];
            }
            hidden[i] = self.relu(sum);
        }

        // Compute latent mean and log variance
        let mut latent_mean = Array1::zeros(self.latent_dim);
        let mut latent_logvar = Array1::zeros(self.latent_dim);

        for i in 0..self.latent_dim {
            let mut mean_sum = b_mean[i];
            let mut logvar_sum = b_logvar[i];

            for j in 0..self.encoder_hidden {
                mean_sum = mean_sum + w_mean[[i, j]] * hidden[j];
                logvar_sum = logvar_sum + w_logvar[[i, j]] * hidden[j];
            }

            latent_mean[i] = mean_sum;
            latent_logvar[i] = logvar_sum;
        }

        Ok((latent_mean, latent_logvar))
    }

    /// Sample from the latent distribution using the reparameterization trick.
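    ///
    /// Implements `z = mean + exp(logvar / 2) * eps`. Note that `eps` below is a
    /// deterministic, index-based placeholder rather than a true draw from
    /// `N(0, 1)`, so repeated calls with the same inputs return the same sample.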
    pub fn reparameterize(&self, mean: &Array1<F>, logvar: &Array1<F>) -> Array1<F> {
        let mut sample = Array1::zeros(self.latent_dim);

        for i in 0..self.latent_dim {
            // Deterministic placeholder for a standard normal draw
            let eps = F::from(((i * 47) % 1000) as f64 / 1000.0 - 0.5).unwrap();
            let std = (logvar[i] / F::from(2.0).unwrap()).exp();
            sample[i] = mean[i] + std * eps;
        }

        sample
    }

    /// Decode a latent representation into a time series.
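    ///
    /// Maps a latent vector of length `latent_dim` through a ReLU hidden layer of
    /// `decoder_hidden` units to a linear output of size `seq_len * feature_dim`,
    /// reshaped to `(seq_len, feature_dim)`.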
    pub fn decode(&self, latent: &Array1<F>) -> Result<Array2<F>> {
        // Extract decoder weights
        let (w1, b1, w2, b2) = self.extract_decoder_weights();

        // Forward through decoder
        let mut hidden = Array1::zeros(self.decoder_hidden);
        for i in 0..self.decoder_hidden {
            let mut sum = b1[i];
            for j in 0..self.latent_dim {
                sum = sum + w1[[i, j]] * latent[j];
            }
            hidden[i] = self.relu(sum);
        }

        // Generate output
        let output_size = self.seq_len * self.feature_dim;
        let mut output_flat = Array1::zeros(output_size);

        for i in 0..output_size {
            let mut sum = b2[i];
            for j in 0..self.decoder_hidden {
                sum = sum + w2[[i, j]] * hidden[j];
            }
            output_flat[i] = sum;
        }

        // Reshape to time series format
        self.unflatten_output(&output_flat)
    }

    /// Full forward pass: encode, sample, decode, and compute the
    /// reconstruction and KL-divergence terms of the loss.
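    ///
    /// For a diagonal Gaussian posterior and a standard normal prior, the KL
    /// divergence has the closed form
    /// `KL = 0.5 * sum_i(mean_i^2 + exp(logvar_i) - logvar_i - 1)`,
    /// which is what the loop below accumulates. The reconstruction loss is the
    /// mean squared error between input and reconstruction.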
    pub fn forward(&self, input: &Array2<F>) -> Result<VAEOutput<F>> {
        let (latent_mean, latent_logvar) = self.encode(input)?;
        let latent_sample = self.reparameterize(&latent_mean, &latent_logvar);
        let reconstruction = self.decode(&latent_sample)?;

        // Compute KL divergence
        let mut kl_div = F::zero();
        for i in 0..self.latent_dim {
            let mean_sq = latent_mean[i] * latent_mean[i];
            let var = latent_logvar[i].exp();
            kl_div = kl_div + mean_sq + var - latent_logvar[i] - F::one();
        }
        kl_div = kl_div / F::from(2.0).unwrap();

        // Compute reconstruction loss
        let mut recon_loss = F::zero();
        let (seq_len, feature_dim) = input.dim();

        for i in 0..seq_len {
            for j in 0..feature_dim {
                let diff = reconstruction[[i, j]] - input[[i, j]];
                recon_loss = recon_loss + diff * diff;
            }
        }
        recon_loss = recon_loss / F::from(seq_len * feature_dim).unwrap();

        Ok(VAEOutput {
            reconstruction,
            latent_mean,
            latent_logvar,
            latent_sample,
            reconstruction_loss: recon_loss,
            kl_divergence: kl_div,
        })
    }

    /// Generate new time series by sampling from the latent space.
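    ///
    /// Each returned series has shape `(seq_len, feature_dim)`. The latent codes
    /// below are deterministic placeholder values in `[-0.5, 0.5)`; a true
    /// sampler would draw them from the standard normal prior.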
    pub fn generate(&self, num_samples: usize) -> Result<Vec<Array2<F>>> {
        let mut samples = Vec::new();

        for i in 0..num_samples {
            // Deterministic placeholder for a draw from the standard normal prior
            let mut latent = Array1::zeros(self.latent_dim);
            for j in 0..self.latent_dim {
                let val = ((i * 53 + j * 29) % 1000) as f64 / 1000.0 - 0.5;
                latent[j] = F::from(val).unwrap();
            }

            let generated = self.decode(&latent)?;
            samples.push(generated);
        }

        Ok(samples)
    }

    /// Estimate uncertainty by decoding multiple samples from the latent posterior.
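    ///
    /// Returns the element-wise mean and standard deviation over `num_samples`
    /// decoded reconstructions. Because `reparameterize` is currently
    /// deterministic, all samples coincide and the returned standard deviation
    /// degenerates to zero; with a stochastic `eps` it would quantify
    /// reconstruction uncertainty.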
    pub fn estimate_uncertainty(
        &self,
        input: &Array2<F>,
        num_samples: usize,
    ) -> Result<(Array2<F>, Array2<F>)> {
        let (latent_mean, latent_logvar) = self.encode(input)?;
        let mut reconstructions = Vec::new();

        // Generate multiple samples
        for _ in 0..num_samples {
            let latent_sample = self.reparameterize(&latent_mean, &latent_logvar);
            let reconstruction = self.decode(&latent_sample)?;
            reconstructions.push(reconstruction);
        }

        // Compute mean and standard deviation
        let (seq_len, feature_dim) = input.dim();
        let mut mean_recon = Array2::zeros((seq_len, feature_dim));
        let mut std_recon = Array2::zeros((seq_len, feature_dim));

        // Compute mean
        for recon in &reconstructions {
            for i in 0..seq_len {
                for j in 0..feature_dim {
                    mean_recon[[i, j]] = mean_recon[[i, j]] + recon[[i, j]];
                }
            }
        }

        let num_samples_f = F::from(num_samples).unwrap();
        for i in 0..seq_len {
            for j in 0..feature_dim {
                mean_recon[[i, j]] = mean_recon[[i, j]] / num_samples_f;
            }
        }

        // Compute standard deviation (population form, dividing by num_samples)
        for recon in &reconstructions {
            for i in 0..seq_len {
                for j in 0..feature_dim {
                    let diff = recon[[i, j]] - mean_recon[[i, j]];
                    std_recon[[i, j]] = std_recon[[i, j]] + diff * diff;
                }
            }
        }

        for i in 0..seq_len {
            for j in 0..feature_dim {
                let val: F = std_recon[[i, j]] / num_samples_f;
                std_recon[[i, j]] = val.sqrt();
            }
        }

        Ok((mean_recon, std_recon))
    }

    // Helper methods
    fn flatten_input(&self, input: &Array2<F>) -> Array1<F> {
        let (seq_len, feature_dim) = input.dim();
        let mut flat = Array1::zeros(seq_len * feature_dim);

        for i in 0..seq_len {
            for j in 0..feature_dim {
                flat[i * feature_dim + j] = input[[i, j]];
            }
        }

        flat
    }

    fn unflatten_output(&self, output: &Array1<F>) -> Result<Array2<F>> {
        let mut result = Array2::zeros((self.seq_len, self.feature_dim));

        for i in 0..self.seq_len {
            for j in 0..self.feature_dim {
                let idx = i * self.feature_dim + j;
                // Guard against a short output vector; missing entries stay zero
                if idx < output.len() {
                    result[[i, j]] = output[idx];
                }
            }
        }

        Ok(result)
    }

    fn extract_encoder_weights(&self) -> EncoderWeights<F> {
        let param_vec = self.encoder_params.row(0);
        let input_size = self.seq_len * self.feature_dim;
        let mut idx = 0;

        // W1: encoder_hidden x input_size
        let mut w1 = Array2::zeros((self.encoder_hidden, input_size));
        for i in 0..self.encoder_hidden {
            for j in 0..input_size {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b1: encoder_hidden
        let mut b1 = Array1::zeros(self.encoder_hidden);
        for i in 0..self.encoder_hidden {
            b1[i] = param_vec[idx];
            idx += 1;
        }

        // W_mean: latent_dim x encoder_hidden
        let mut w_mean = Array2::zeros((self.latent_dim, self.encoder_hidden));
        for i in 0..self.latent_dim {
            for j in 0..self.encoder_hidden {
                w_mean[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b_mean: latent_dim
        let mut b_mean = Array1::zeros(self.latent_dim);
        for i in 0..self.latent_dim {
            b_mean[i] = param_vec[idx];
            idx += 1;
        }

        // W_logvar: latent_dim x encoder_hidden
        let mut w_logvar = Array2::zeros((self.latent_dim, self.encoder_hidden));
        for i in 0..self.latent_dim {
            for j in 0..self.encoder_hidden {
                w_logvar[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b_logvar: latent_dim
        let mut b_logvar = Array1::zeros(self.latent_dim);
        for i in 0..self.latent_dim {
            b_logvar[i] = param_vec[idx];
            idx += 1;
        }

        (w1, b1, w_mean, b_mean, w_logvar, b_logvar)
    }

    fn extract_decoder_weights(&self) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
        let param_vec = self.decoder_params.row(0);
        let output_size = self.seq_len * self.feature_dim;
        let mut idx = 0;

        // W1: decoder_hidden x latent_dim
        let mut w1 = Array2::zeros((self.decoder_hidden, self.latent_dim));
        for i in 0..self.decoder_hidden {
            for j in 0..self.latent_dim {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b1: decoder_hidden
        let mut b1 = Array1::zeros(self.decoder_hidden);
        for i in 0..self.decoder_hidden {
            b1[i] = param_vec[idx];
            idx += 1;
        }

        // W2: output_size x decoder_hidden
        let mut w2 = Array2::zeros((output_size, self.decoder_hidden));
        for i in 0..output_size {
            for j in 0..self.decoder_hidden {
                w2[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b2: output_size
        let mut b2 = Array1::zeros(output_size);
        for i in 0..output_size {
            b2[i] = param_vec[idx];
            idx += 1;
        }

        (w1, b1, w2, b2)
    }

    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }
}

/// Output of a full VAE forward pass
#[derive(Debug, Clone)]
pub struct VAEOutput<F: Float + Debug> {
    /// Reconstructed time series
    pub reconstruction: Array2<F>,
    /// Latent mean
    pub latent_mean: Array1<F>,
    /// Latent log variance
    pub latent_logvar: Array1<F>,
    /// Latent sample
    pub latent_sample: Array1<F>,
    /// Reconstruction loss (mean squared error)
    pub reconstruction_loss: F,
    /// KL divergence from the standard normal prior
    pub kl_divergence: F,
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_vae_creation() {
        let vae = TimeSeriesVAE::<f64>::new(10, 3, 5, 16, 16);
        assert_eq!(vae.seq_len, 10);
        assert_eq!(vae.feature_dim, 3);
        assert_eq!(vae.latent_dim, 5);
        assert_eq!(vae.encoder_hidden, 16);
        assert_eq!(vae.decoder_hidden, 16);
    }

    #[test]
    fn test_vae_encode_decode() {
        let vae = TimeSeriesVAE::<f64>::new(5, 2, 3, 8, 8);
        let input =
            Array2::from_shape_vec((5, 2), (0..10).map(|i| i as f64 * 0.1).collect()).unwrap();

        let (mean, logvar) = vae.encode(&input).unwrap();
        assert_eq!(mean.len(), 3);
        assert_eq!(logvar.len(), 3);

        let sample = vae.reparameterize(&mean, &logvar);
        assert_eq!(sample.len(), 3);

        let decoded = vae.decode(&sample).unwrap();
        assert_eq!(decoded.dim(), (5, 2));
    }

    #[test]
    fn test_vae_forward() {
        let vae = TimeSeriesVAE::<f64>::new(4, 2, 3, 8, 8);
        let input =
            Array2::from_shape_vec((4, 2), (0..8).map(|i| i as f64 * 0.1).collect()).unwrap();

        let output = vae.forward(&input).unwrap();
        assert_eq!(output.reconstruction.dim(), (4, 2));
        assert_eq!(output.latent_mean.len(), 3);
        assert_eq!(output.latent_logvar.len(), 3);
        assert_eq!(output.latent_sample.len(), 3);
        assert!(output.reconstruction_loss >= 0.0);
        assert!(output.kl_divergence >= 0.0);
    }

    #[test]
    fn test_vae_uncertainty_estimation() {
        let vae = TimeSeriesVAE::<f64>::new(3, 2, 2, 6, 6);
        let input =
            Array2::from_shape_vec((3, 2), (0..6).map(|i| i as f64 * 0.2).collect()).unwrap();

        let (mean_recon, std_recon) = vae.estimate_uncertainty(&input, 5).unwrap();
        assert_eq!(mean_recon.dim(), (3, 2));
        assert_eq!(std_recon.dim(), (3, 2));

        // Check that standard deviations are non-negative
        for &val in std_recon.iter() {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn test_vae_generation() {
        let vae = TimeSeriesVAE::<f64>::new(4, 2, 3, 8, 8);
        let samples = vae.generate(3).unwrap();

        assert_eq!(samples.len(), 3);
        for sample in samples {
            assert_eq!(sample.dim(), (4, 2));
        }
    }
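
    #[test]
    fn test_vae_encode_is_deterministic() {
        // The encoder is a pure function of the stored parameters, so repeated
        // calls on the same input must agree exactly.
        let vae = TimeSeriesVAE::<f64>::new(4, 2, 3, 8, 8);
        let input =
            Array2::from_shape_vec((4, 2), (0..8).map(|i| i as f64 * 0.1).collect()).unwrap();

        let (m1, v1) = vae.encode(&input).unwrap();
        let (m2, v2) = vae.encode(&input).unwrap();
        for i in 0..3 {
            assert_abs_diff_eq!(m1[i], m2[i]);
            assert_abs_diff_eq!(v1[i], v2[i]);
        }
    }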
}