1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use crate::error::Result;
11
/// Encoder parameters unpacked from the flat storage row, in order:
/// `(w1, b1, w_mean, b_mean, w_logvar, b_logvar)`.
///
/// `w1`/`b1` map the flattened input to the hidden layer; the mean and
/// logvar pairs map the hidden layer to the latent Gaussian parameters.
/// Shapes follow `extract_encoder_weights`.
type EncoderWeights<F> = (
    Array2<F>, // w1: encoder_hidden × (seq_len * feature_dim)
    Array1<F>, // b1: encoder_hidden
    Array2<F>, // w_mean: latent_dim × encoder_hidden
    Array1<F>, // b_mean: latent_dim
    Array2<F>, // w_logvar: latent_dim × encoder_hidden
    Array1<F>, // b_logvar: latent_dim
);
21
/// Variational autoencoder for fixed-shape (`seq_len × feature_dim`)
/// time-series windows, with a single hidden layer on each side.
///
/// All weights are stored as flat `1 × N` parameter rows and unpacked on
/// demand by `extract_encoder_weights` / `extract_decoder_weights`.
#[derive(Debug)]
pub struct TimeSeriesVAE<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Flat encoder parameters, shape `1 × encoder_param_count`.
    encoder_params: Array2<F>,
    /// Flat decoder parameters, shape `1 × decoder_param_count`.
    decoder_params: Array2<F>,
    /// Dimensionality of the latent Gaussian.
    latent_dim: usize,
    /// Number of time steps per input window.
    seq_len: usize,
    /// Number of features per time step.
    feature_dim: usize,
    /// Hidden-layer width of the encoder.
    encoder_hidden: usize,
    /// Hidden-layer width of the decoder.
    decoder_hidden: usize,
}
39
40impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
41 TimeSeriesVAE<F>
42{
43 pub fn new(
45 seq_len: usize,
46 feature_dim: usize,
47 latent_dim: usize,
48 encoder_hidden: usize,
49 decoder_hidden: usize,
50 ) -> Self {
51 let input_size = seq_len * feature_dim;
52
53 let encoder_param_count = input_size * encoder_hidden
55 + encoder_hidden
56 + encoder_hidden * latent_dim * 2
57 + latent_dim * 2;
58 let mut encoder_params = Array2::zeros((1, encoder_param_count));
59
60 let decoder_param_count =
62 latent_dim * decoder_hidden + decoder_hidden + decoder_hidden * input_size + input_size;
63 let mut decoder_params = Array2::zeros((1, decoder_param_count));
64
65 let encoder_scale = F::from(2.0).unwrap() / F::from(input_size + latent_dim).unwrap();
67 let decoder_scale = F::from(2.0).unwrap() / F::from(latent_dim + input_size).unwrap();
68
69 for i in 0..encoder_param_count {
70 let val = ((i * 19) % 1000) as f64 / 1000.0 - 0.5;
71 encoder_params[[0, i]] = F::from(val).unwrap() * encoder_scale.sqrt();
72 }
73
74 for i in 0..decoder_param_count {
75 let val = ((i * 31) % 1000) as f64 / 1000.0 - 0.5;
76 decoder_params[[0, i]] = F::from(val).unwrap() * decoder_scale.sqrt();
77 }
78
79 Self {
80 encoder_params,
81 decoder_params,
82 latent_dim,
83 seq_len,
84 feature_dim,
85 encoder_hidden,
86 decoder_hidden,
87 }
88 }
89
90 pub fn encode(&self, input: &Array2<F>) -> Result<(Array1<F>, Array1<F>)> {
92 let input_flat = self.flatten_input(input);
94
95 let (w1, b1, w_mean, b_mean, w_logvar, b_logvar) = self.extract_encoder_weights();
97
98 let mut hidden = Array1::zeros(self.encoder_hidden);
100 for i in 0..self.encoder_hidden {
101 let mut sum = b1[i];
102 for j in 0..input_flat.len() {
103 sum = sum + w1[[i, j]] * input_flat[j];
104 }
105 hidden[i] = self.relu(sum);
106 }
107
108 let mut latent_mean = Array1::zeros(self.latent_dim);
110 let mut latent_logvar = Array1::zeros(self.latent_dim);
111
112 for i in 0..self.latent_dim {
113 let mut mean_sum = b_mean[i];
114 let mut logvar_sum = b_logvar[i];
115
116 for j in 0..self.encoder_hidden {
117 mean_sum = mean_sum + w_mean[[i, j]] * hidden[j];
118 logvar_sum = logvar_sum + w_logvar[[i, j]] * hidden[j];
119 }
120
121 latent_mean[i] = mean_sum;
122 latent_logvar[i] = logvar_sum;
123 }
124
125 Ok((latent_mean, latent_logvar))
126 }
127
    /// Reparameterization trick: `sample = mean + std * eps` with
    /// `std = exp(logvar / 2)`.
    ///
    /// NOTE(review): `eps` is a deterministic hash of the latent index
    /// only — repeated calls with the same `(mean, logvar)` return the
    /// *same* sample, so this is not a random draw. Confirm this is
    /// intended: it makes `forward` reproducible but defeats Monte Carlo
    /// sampling (see `estimate_uncertainty`).
    pub fn reparameterize(&self, mean: &Array1<F>, logvar: &Array1<F>) -> Array1<F> {
        let mut sample = Array1::zeros(self.latent_dim);

        for i in 0..self.latent_dim {
            // Pseudo-noise in [-0.5, 0.5), fixed per dimension.
            let eps = F::from(((i * 47) % 1000) as f64 / 1000.0 - 0.5).unwrap();
            let std = (logvar[i] / F::from(2.0).unwrap()).exp();
            sample[i] = mean[i] + std * eps;
        }

        sample
    }
141
142 pub fn decode(&self, latent: &Array1<F>) -> Result<Array2<F>> {
144 let (w1, b1, w2, b2) = self.extract_decoder_weights();
146
147 let mut hidden = Array1::zeros(self.decoder_hidden);
149 for i in 0..self.decoder_hidden {
150 let mut sum = b1[i];
151 for j in 0..self.latent_dim {
152 sum = sum + w1[[i, j]] * latent[j];
153 }
154 hidden[i] = self.relu(sum);
155 }
156
157 let output_size = self.seq_len * self.feature_dim;
159 let mut output_flat = Array1::zeros(output_size);
160
161 for i in 0..output_size {
162 let mut sum = b2[i];
163 for j in 0..self.decoder_hidden {
164 sum = sum + w2[[i, j]] * hidden[j];
165 }
166 output_flat[i] = sum;
167 }
168
169 self.unflatten_output(&output_flat)
171 }
172
173 pub fn forward(&self, input: &Array2<F>) -> Result<VAEOutput<F>> {
175 let (latent_mean, latent_logvar) = self.encode(input)?;
176 let latent_sample = self.reparameterize(&latent_mean, &latent_logvar);
177 let reconstruction = self.decode(&latent_sample)?;
178
179 let mut kl_div = F::zero();
181 for i in 0..self.latent_dim {
182 let mean_sq = latent_mean[i] * latent_mean[i];
183 let var = latent_logvar[i].exp();
184 kl_div = kl_div + mean_sq + var - latent_logvar[i] - F::one();
185 }
186 kl_div = kl_div / F::from(2.0).unwrap();
187
188 let mut recon_loss = F::zero();
190 let (seq_len, feature_dim) = input.dim();
191
192 for i in 0..seq_len {
193 for j in 0..feature_dim {
194 let diff = reconstruction[[i, j]] - input[[i, j]];
195 recon_loss = recon_loss + diff * diff;
196 }
197 }
198 recon_loss = recon_loss / F::from(seq_len * feature_dim).unwrap();
199
200 Ok(VAEOutput {
201 reconstruction,
202 latent_mean,
203 latent_logvar,
204 latent_sample,
205 reconstruction_loss: recon_loss,
206 kl_divergence: kl_div,
207 })
208 }
209
210 pub fn generate(&self, numsamples: usize) -> Result<Vec<Array2<F>>> {
212 let mut _samples = Vec::new();
213
214 for i in 0..numsamples {
215 let mut latent = Array1::zeros(self.latent_dim);
217 for j in 0..self.latent_dim {
218 let val = ((i * 53 + j * 29) % 1000) as f64 / 1000.0 - 0.5;
219 latent[j] = F::from(val).unwrap();
220 }
221
222 let generated = self.decode(&latent)?;
223 _samples.push(generated);
224 }
225
226 Ok(_samples)
227 }
228
229 pub fn estimate_uncertainty(
231 &self,
232 input: &Array2<F>,
233 num_samples: usize,
234 ) -> Result<(Array2<F>, Array2<F>)> {
235 let (latent_mean, latent_logvar) = self.encode(input)?;
236 let mut reconstructions = Vec::new();
237
238 for _ in 0..num_samples {
240 let latent_sample = self.reparameterize(&latent_mean, &latent_logvar);
241 let reconstruction = self.decode(&latent_sample)?;
242 reconstructions.push(reconstruction);
243 }
244
245 let (seq_len, feature_dim) = input.dim();
247 let mut mean_recon = Array2::zeros((seq_len, feature_dim));
248 let mut std_recon = Array2::zeros((seq_len, feature_dim));
249
250 for recon in &reconstructions {
252 for i in 0..seq_len {
253 for j in 0..feature_dim {
254 mean_recon[[i, j]] = mean_recon[[i, j]] + recon[[i, j]];
255 }
256 }
257 }
258
259 let num_samples_f = F::from(num_samples).unwrap();
260 for i in 0..seq_len {
261 for j in 0..feature_dim {
262 mean_recon[[i, j]] = mean_recon[[i, j]] / num_samples_f;
263 }
264 }
265
266 for recon in &reconstructions {
268 for i in 0..seq_len {
269 for j in 0..feature_dim {
270 let diff = recon[[i, j]] - mean_recon[[i, j]];
271 std_recon[[i, j]] = std_recon[[i, j]] + diff * diff;
272 }
273 }
274 }
275
276 for i in 0..seq_len {
277 for j in 0..feature_dim {
278 let val: F = std_recon[[i, j]] / num_samples_f;
279 std_recon[[i, j]] = val.sqrt();
280 }
281 }
282
283 Ok((mean_recon, std_recon))
284 }
285
286 fn flatten_input(&self, input: &Array2<F>) -> Array1<F> {
288 let (seq_len, feature_dim) = input.dim();
289 let mut flat = Array1::zeros(seq_len * feature_dim);
290
291 for i in 0..seq_len {
292 for j in 0..feature_dim {
293 flat[i * feature_dim + j] = input[[i, j]];
294 }
295 }
296
297 flat
298 }
299
300 fn unflatten_output(&self, output: &Array1<F>) -> Result<Array2<F>> {
301 let mut result = Array2::zeros((self.seq_len, self.feature_dim));
302
303 for i in 0..self.seq_len {
304 for j in 0..self.feature_dim {
305 let idx = i * self.feature_dim + j;
306 if idx < output.len() {
307 result[[i, j]] = output[idx];
308 }
309 }
310 }
311
312 Ok(result)
313 }
314
315 fn extract_encoder_weights(&self) -> EncoderWeights<F> {
316 let param_vec = self.encoder_params.row(0);
317 let input_size = self.seq_len * self.feature_dim;
318 let mut idx = 0;
319
320 let mut w1 = Array2::zeros((self.encoder_hidden, input_size));
322 for i in 0..self.encoder_hidden {
323 for j in 0..input_size {
324 w1[[i, j]] = param_vec[idx];
325 idx += 1;
326 }
327 }
328
329 let mut b1 = Array1::zeros(self.encoder_hidden);
331 for i in 0..self.encoder_hidden {
332 b1[i] = param_vec[idx];
333 idx += 1;
334 }
335
336 let mut w_mean = Array2::zeros((self.latent_dim, self.encoder_hidden));
338 for i in 0..self.latent_dim {
339 for j in 0..self.encoder_hidden {
340 w_mean[[i, j]] = param_vec[idx];
341 idx += 1;
342 }
343 }
344
345 let mut b_mean = Array1::zeros(self.latent_dim);
347 for i in 0..self.latent_dim {
348 b_mean[i] = param_vec[idx];
349 idx += 1;
350 }
351
352 let mut w_logvar = Array2::zeros((self.latent_dim, self.encoder_hidden));
354 for i in 0..self.latent_dim {
355 for j in 0..self.encoder_hidden {
356 w_logvar[[i, j]] = param_vec[idx];
357 idx += 1;
358 }
359 }
360
361 let mut b_logvar = Array1::zeros(self.latent_dim);
363 for i in 0..self.latent_dim {
364 b_logvar[i] = param_vec[idx];
365 idx += 1;
366 }
367
368 (w1, b1, w_mean, b_mean, w_logvar, b_logvar)
369 }
370
371 fn extract_decoder_weights(&self) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
372 let param_vec = self.decoder_params.row(0);
373 let output_size = self.seq_len * self.feature_dim;
374 let mut idx = 0;
375
376 let mut w1 = Array2::zeros((self.decoder_hidden, self.latent_dim));
378 for i in 0..self.decoder_hidden {
379 for j in 0..self.latent_dim {
380 w1[[i, j]] = param_vec[idx];
381 idx += 1;
382 }
383 }
384
385 let mut b1 = Array1::zeros(self.decoder_hidden);
387 for i in 0..self.decoder_hidden {
388 b1[i] = param_vec[idx];
389 idx += 1;
390 }
391
392 let mut w2 = Array2::zeros((output_size, self.decoder_hidden));
394 for i in 0..output_size {
395 for j in 0..self.decoder_hidden {
396 w2[[i, j]] = param_vec[idx];
397 idx += 1;
398 }
399 }
400
401 let mut b2 = Array1::zeros(output_size);
403 for i in 0..output_size {
404 b2[i] = param_vec[idx];
405 idx += 1;
406 }
407
408 (w1, b1, w2, b2)
409 }
410
411 fn relu(&self, x: F) -> F {
412 x.max(F::zero())
413 }
414}
415
/// Result bundle of a single `TimeSeriesVAE::forward` pass.
#[derive(Debug, Clone)]
pub struct VAEOutput<F: Float + Debug> {
    /// Decoded output, shaped like the input (`seq_len × feature_dim`).
    pub reconstruction: Array2<F>,
    /// Posterior mean of the latent Gaussian.
    pub latent_mean: Array1<F>,
    /// Posterior log-variance of the latent Gaussian.
    pub latent_logvar: Array1<F>,
    /// Latent vector actually decoded (`mean + std * eps`).
    pub latent_sample: Array1<F>,
    /// Mean squared error between reconstruction and input.
    pub reconstruction_loss: F,
    /// KL divergence of the posterior from the standard normal prior.
    pub kl_divergence: F,
}
432
#[cfg(test)]
mod tests {
    // Fix: removed `use approx::assert_abs_diff_eq;` — it was never used
    // anywhere in this module (only exact `assert_eq!`/`assert!` checks),
    // producing an unused-import warning.
    use super::*;

    #[test]
    fn test_vae_creation() {
        let vae = TimeSeriesVAE::<f64>::new(10, 3, 5, 16, 16);
        assert_eq!(vae.seq_len, 10);
        assert_eq!(vae.feature_dim, 3);
        assert_eq!(vae.latent_dim, 5);
        assert_eq!(vae.encoder_hidden, 16);
        assert_eq!(vae.decoder_hidden, 16);
    }

    #[test]
    fn test_vae_encode_decode() {
        let vae = TimeSeriesVAE::<f64>::new(5, 2, 3, 8, 8);
        let input =
            Array2::from_shape_vec((5, 2), (0..10).map(|i| i as f64 * 0.1).collect()).unwrap();

        let (mean, logvar) = vae.encode(&input).unwrap();
        assert_eq!(mean.len(), 3);
        assert_eq!(logvar.len(), 3);

        let sample = vae.reparameterize(&mean, &logvar);
        assert_eq!(sample.len(), 3);

        let decoded = vae.decode(&sample).unwrap();
        assert_eq!(decoded.dim(), (5, 2));
    }

    #[test]
    fn test_vae_forward() {
        let vae = TimeSeriesVAE::<f64>::new(4, 2, 3, 8, 8);
        let input =
            Array2::from_shape_vec((4, 2), (0..8).map(|i| i as f64 * 0.1).collect()).unwrap();

        let output = vae.forward(&input).unwrap();
        assert_eq!(output.reconstruction.dim(), (4, 2));
        assert_eq!(output.latent_mean.len(), 3);
        assert_eq!(output.latent_logvar.len(), 3);
        assert_eq!(output.latent_sample.len(), 3);
        assert!(output.reconstruction_loss >= 0.0);
        assert!(output.kl_divergence >= 0.0);
    }

    #[test]
    fn test_vae_uncertainty_estimation() {
        let vae = TimeSeriesVAE::<f64>::new(3, 2, 2, 6, 6);
        let input =
            Array2::from_shape_vec((3, 2), (0..6).map(|i| i as f64 * 0.2).collect()).unwrap();

        let (mean_recon, std_recon) = vae.estimate_uncertainty(&input, 5).unwrap();
        assert_eq!(mean_recon.dim(), (3, 2));
        assert_eq!(std_recon.dim(), (3, 2));

        // Standard deviations are non-negative by construction.
        for &val in std_recon.iter() {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn test_vae_generation() {
        let vae = TimeSeriesVAE::<f64>::new(4, 2, 3, 8, 8);
        let samples = vae.generate(3).unwrap();

        assert_eq!(samples.len(), 3);
        for sample in samples {
            assert_eq!(sample.dim(), (4, 2));
        }
    }
}