use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;

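/// Toy transformer model for time-series forecasting.
///
/// Note: the weights in `parameters` are filled by a fixed, deterministic
/// pseudo-random scheme at construction time rather than learned, so the
/// model is suitable for shape and numerical-stability checks, not for
/// producing trained forecasts.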
#[derive(Debug)]
pub struct TimeSeriesTransformer<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Number of stacked transformer layers.
    num_layers: usize,
    /// Number of attention heads per layer.
    num_heads: usize,
    /// Model (embedding) dimension.
    d_model: usize,
    /// Hidden dimension of the feed-forward sublayer.
    d_ff: usize,
    /// Length of the input window.
    seq_len: usize,
    /// Number of future steps to predict.
    pred_len: usize,
    /// All model weights, stored as a single flat `(1, total_params)` array.
    parameters: Array2<F>,
    /// Precomputed sinusoidal positional encoding, shape `(seq_len, d_model)`.
    positional_encoding: Array2<F>,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
    TimeSeriesTransformer<F>
{
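    /// Builds a transformer with the given dimensions, initializing weights
    /// deterministically with He-style scaling (standard deviation
    /// `sqrt(2 / d_model)`) and precomputing the sinusoidal positional
    /// encoding.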
    pub fn new(
        seq_len: usize,
        pred_len: usize,
        d_model: usize,
        num_heads: usize,
        num_layers: usize,
        d_ff: usize,
    ) -> Self {
        // Rough parameter budget: Q/K/V/output projections, the two
        // feed-forward matrices plus biases, two layer norms (gamma + beta),
        // the input embedding, and the output projection.
        let attention_params_per_layer = 4 * d_model * d_model;
        let ff_params_per_layer = 2 * d_model * d_ff + d_ff + d_model;
        let layer_norm_params_per_layer = 2 * d_model * 2;
        let embedding_params = seq_len * d_model;
        let output_params = d_model * pred_len;
        let total_params = num_layers
            * (attention_params_per_layer + ff_params_per_layer + layer_norm_params_per_layer)
            + embedding_params
            + output_params;

        // He-style initialization scale.
        let scale = F::from(2.0).unwrap() / F::from(d_model).unwrap();
        let std_dev = scale.sqrt();

        // Deterministic pseudo-random weights in roughly [-0.5, 0.5) * std_dev.
        let mut parameters = Array2::zeros((1, total_params));
        for i in 0..total_params {
            let val = ((i * 13) % 1000) as f64 / 1000.0 - 0.5;
            parameters[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        // Standard sinusoidal positional encoding:
        // PE(pos, 2i) = sin(pos / 10000^(2i / d_model)),
        // PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model)).
        let mut positional_encoding = Array2::zeros((seq_len, d_model));
        for pos in 0..seq_len {
            for i in 0..d_model {
                let angle = F::from(pos).unwrap()
                    / F::from(10000.0)
                        .unwrap()
                        .powf(F::from(2 * (i / 2)).unwrap() / F::from(d_model).unwrap());
                if i % 2 == 0 {
                    positional_encoding[[pos, i]] = angle.sin();
                } else {
                    positional_encoding[[pos, i]] = angle.cos();
                }
            }
        }

        Self {
            num_layers,
            num_heads,
            d_model,
            d_ff,
            seq_len,
            pred_len,
            parameters,
            positional_encoding,
        }
    }

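    /// Runs a full forward pass: embeds the input, adds positional
    /// encodings, applies each transformer layer in turn, and projects the
    /// final hidden state to `pred_len` forecast values.
    ///
    /// `input` has shape `(batch_size, seq_len)`; the output has shape
    /// `(batch_size, pred_len)`.
    ///
    /// A minimal usage sketch (the import path for the type is assumed and
    /// may need adjusting to the actual module layout):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let model = TimeSeriesTransformer::<f64>::new(6, 3, 32, 4, 2, 128);
    /// let input = Array2::from_shape_vec((2, 6), vec![0.0; 12]).unwrap();
    /// let forecast = model.forward(&input).unwrap();
    /// assert_eq!(forecast.dim(), (2, 3));
    /// ```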
    pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
        let batch_size = input.nrows();

        let mut x = self.input_embedding(input)?;

        // Add positional encodings; rows of `x` are laid out batch-major,
        // so sample `i`, position `j` lives at row `i * seq_len + j`.
        for i in 0..batch_size {
            for j in 0..self.seq_len {
                for k in 0..self.d_model {
                    x[[i * self.seq_len + j, k]] =
                        x[[i * self.seq_len + j, k]] + self.positional_encoding[[j, k]];
                }
            }
        }

        for layer in 0..self.num_layers {
            x = self.transformer_layer(&x, layer)?;
        }

        self.output_projection(&x, batch_size)
    }

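    /// Projects each scalar input value into `d_model` dimensions, producing
    /// a `(batch_size * seq_len, d_model)` array of token embeddings.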
    fn input_embedding(&self, input: &Array2<F>) -> Result<Array2<F>> {
        let batch_size = input.nrows();
        let input_dim = input.ncols();

        let mut embedded = Array2::zeros((batch_size * self.seq_len, self.d_model));

        // The embedding weights occupy the start of the flat parameter array.
        let param_start = 0;

        for i in 0..batch_size {
            for j in 0..self.seq_len.min(input_dim) {
                for k in 0..self.d_model {
                    let weight_idx = (j * self.d_model + k) % (self.seq_len * self.d_model);
                    let weight = if param_start + weight_idx < self.parameters.ncols() {
                        self.parameters[[0, param_start + weight_idx]]
                    } else {
                        F::zero()
                    };
                    embedded[[i * self.seq_len + j, k]] = input[[i, j]] * weight;
                }
            }
        }

        Ok(embedded)
    }

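    /// One post-norm transformer block: multi-head self-attention with a
    /// residual connection and layer norm, then a feed-forward sublayer with
    /// its own residual connection and layer norm.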
    fn transformer_layer(&self, input: &Array2<F>, layer_idx: usize) -> Result<Array2<F>> {
        let attention_output = self.multi_head_attention(input, layer_idx)?;

        let norm1_output =
            self.layer_norm(&self.add_residual(input, &attention_output)?, layer_idx, 0)?;

        let ff_output = self.feed_forward(&norm1_output, layer_idx)?;

        let final_output =
            self.layer_norm(&self.add_residual(&norm1_output, &ff_output)?, layer_idx, 1)?;

        Ok(final_output)
    }

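    /// Multi-head scaled dot-product self-attention. Each head projects the
    /// input down to `d_model / num_heads` dimensions, attends over the
    /// sequence, and the head outputs are concatenated back into `d_model`
    /// columns.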
    fn multi_head_attention(&self, input: &Array2<F>, layer_idx: usize) -> Result<Array2<F>> {
        let seq_len = input.nrows();
        let head_dim = self.d_model / self.num_heads;

        let mut output = Array2::zeros((seq_len, self.d_model));

        for head in 0..self.num_heads {
            let q = self.compute_qkv_projection(input, layer_idx, head, 0)?;
            let k = self.compute_qkv_projection(input, layer_idx, head, 1)?;
            let v = self.compute_qkv_projection(input, layer_idx, head, 2)?;

            let attention_scores = self.compute_attention_scores(&q, &k)?;

            let head_output = self.apply_attention(&attention_scores, &v)?;

            // Write this head's output into its slice of the columns.
            for i in 0..seq_len {
                for j in 0..head_dim {
                    if head * head_dim + j < self.d_model {
                        output[[i, head * head_dim + j]] = head_output[[i, j]];
                    }
                }
            }
        }

        Ok(output)
    }

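    /// Projects the input into a query, key, or value matrix
    /// (`projection_type` 0, 1, or 2) for one attention head. Weights are
    /// drawn from the flat parameter array via a deterministic index scheme
    /// keyed on layer, head, and projection type.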
    fn compute_qkv_projection(
        &self,
        input: &Array2<F>,
        layer_idx: usize,
        head: usize,
        projection_type: usize,
    ) -> Result<Array2<F>> {
        let seq_len = input.nrows();
        let head_dim = self.d_model / self.num_heads;
        let mut output = Array2::zeros((seq_len, head_dim));

        for i in 0..seq_len {
            for j in 0..head_dim {
                let mut sum = F::zero();
                for k in 0..self.d_model {
                    let weight_idx = (layer_idx * 1000
                        + head * 100
                        + projection_type * 10
                        + j * self.d_model
                        + k)
                        % self.parameters.ncols();
                    let weight = self.parameters[[0, weight_idx]];
                    sum = sum + input[[i, k]] * weight;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

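    /// Computes `softmax(Q K^T / sqrt(head_dim))` over the sequence, giving
    /// a `(seq_len, seq_len)` matrix of attention weights.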
    fn compute_attention_scores(&self, q: &Array2<F>, k: &Array2<F>) -> Result<Array2<F>> {
        let seq_len = q.nrows();
        let head_dim = q.ncols();
        let scale = F::one() / F::from(head_dim).unwrap().sqrt();

        let mut scores = Array2::zeros((seq_len, seq_len));

        for i in 0..seq_len {
            for j in 0..seq_len {
                let mut dot_product = F::zero();
                for dim in 0..head_dim {
                    dot_product = dot_product + q[[i, dim]] * k[[j, dim]];
                }
                scores[[i, j]] = dot_product * scale;
            }
        }

        self.softmax_2d(&scores)
    }

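    /// Computes the matrix product of the attention weights and the value
    /// matrix with explicit loops.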
    fn apply_attention(&self, attention: &Array2<F>, values: &Array2<F>) -> Result<Array2<F>> {
        let seq_len = attention.nrows();
        let head_dim = values.ncols();
        let mut output = Array2::zeros((seq_len, head_dim));

        for i in 0..seq_len {
            for j in 0..head_dim {
                let mut sum = F::zero();
                for k in 0..seq_len {
                    sum = sum + attention[[i, k]] * values[[k, j]];
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

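    /// Position-wise feed-forward sublayer: a ReLU-activated expansion to
    /// `d_ff` dimensions followed by a projection back to `d_model`.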
    fn feed_forward(&self, input: &Array2<F>, layer_idx: usize) -> Result<Array2<F>> {
        let seq_len = input.nrows();

        // Expansion: d_model -> d_ff with ReLU.
        let mut hidden = Array2::zeros((seq_len, self.d_ff));
        for i in 0..seq_len {
            for j in 0..self.d_ff {
                let mut sum = F::zero();
                for k in 0..self.d_model {
                    let weight_idx =
                        (layer_idx * 2000 + j * self.d_model + k) % self.parameters.ncols();
                    let weight = self.parameters[[0, weight_idx]];
                    sum = sum + input[[i, k]] * weight;
                }
                hidden[[i, j]] = self.relu(sum);
            }
        }

        // Contraction: d_ff -> d_model.
        let mut output = Array2::zeros((seq_len, self.d_model));
        for i in 0..seq_len {
            for j in 0..self.d_model {
                let mut sum = F::zero();
                for k in 0..self.d_ff {
                    let weight_idx =
                        (layer_idx * 3000 + j * self.d_ff + k) % self.parameters.ncols();
                    let weight = self.parameters[[0, weight_idx]];
                    sum = sum + hidden[[i, k]] * weight;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

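    /// Layer normalization over the feature dimension, with gamma/beta
    /// parameters looked up per layer and per norm (`norm_idx` 0 after
    /// attention, 1 after the feed-forward sublayer).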
    fn layer_norm(
        &self,
        input: &Array2<F>,
        layer_idx: usize,
        norm_idx: usize,
    ) -> Result<Array2<F>> {
        let seq_len = input.nrows();
        let mut output = Array2::zeros(input.dim());

        for i in 0..seq_len {
            let mut sum = F::zero();
            for j in 0..self.d_model {
                sum = sum + input[[i, j]];
            }
            let mean = sum / F::from(self.d_model).unwrap();

            let mut var_sum = F::zero();
            for j in 0..self.d_model {
                let diff = input[[i, j]] - mean;
                var_sum = var_sum + diff * diff;
            }
            let variance = var_sum / F::from(self.d_model).unwrap();
            let std_dev = (variance + F::from(1e-5).unwrap()).sqrt();

            for j in 0..self.d_model {
                let normalized = (input[[i, j]] - mean) / std_dev;

                let gamma_idx = (layer_idx * 100 + norm_idx * 50 + j) % self.parameters.ncols();
                let beta_idx = (layer_idx * 100 + norm_idx * 50 + j + 25) % self.parameters.ncols();

                let gamma = self.parameters[[0, gamma_idx]];
                let beta = self.parameters[[0, beta_idx]];

                output[[i, j]] = gamma * normalized + beta;
            }
        }

        Ok(output)
    }

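    /// Element-wise sum of two equally-shaped arrays (residual connection).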
    fn add_residual(&self, input1: &Array2<F>, input2: &Array2<F>) -> Result<Array2<F>> {
        let mut output = Array2::zeros(input1.dim());

        for i in 0..input1.nrows() {
            for j in 0..input1.ncols() {
                output[[i, j]] = input1[[i, j]] + input2[[i, j]];
            }
        }

        Ok(output)
    }

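    /// Maps the hidden state of each sample's last time step to `pred_len`
    /// forecast values, returning a `(batch_size, pred_len)` array.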
    fn output_projection(&self, input: &Array2<F>, batch_size: usize) -> Result<Array2<F>> {
        let mut output = Array2::zeros((batch_size, self.pred_len));

        for i in 0..batch_size {
            // Only the last token of each sequence feeds the forecast.
            let last_token_idx = i * self.seq_len + self.seq_len - 1;

            for j in 0..self.pred_len {
                let mut sum = F::zero();
                for k in 0..self.d_model {
                    let weight_idx = (j * self.d_model + k) % self.parameters.ncols();
                    let weight = self.parameters[[0, weight_idx]];
                    sum = sum + input[[last_token_idx, k]] * weight;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

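    /// Numerically stable row-wise softmax: subtracts each row's maximum
    /// before exponentiating, then normalizes so each row sums to one.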
    fn softmax_2d(&self, input: &Array2<F>) -> Result<Array2<F>> {
        let mut output = Array2::zeros(input.dim());

        for i in 0..input.nrows() {
            let mut max_val = input[[i, 0]];
            for j in 1..input.ncols() {
                if input[[i, j]] > max_val {
                    max_val = input[[i, j]];
                }
            }

            let mut sum = F::zero();
            for j in 0..input.ncols() {
                let exp_val = (input[[i, j]] - max_val).exp();
                output[[i, j]] = exp_val;
                sum = sum + exp_val;
            }

            for j in 0..input.ncols() {
                output[[i, j]] = output[[i, j]] / sum;
            }
        }

        Ok(output)
    }

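    /// ReLU activation: `max(x, 0)`.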
    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_transformer_creation() {
        let transformer = TimeSeriesTransformer::<f64>::new(10, 5, 64, 8, 4, 256);
        assert_eq!(transformer.seq_len, 10);
        assert_eq!(transformer.pred_len, 5);
        assert_eq!(transformer.d_model, 64);
        assert_eq!(transformer.num_heads, 8);
        assert_eq!(transformer.num_layers, 4);
        assert_eq!(transformer.d_ff, 256);
    }

    #[test]
    fn test_positional_encoding() {
        let transformer = TimeSeriesTransformer::<f64>::new(8, 4, 16, 4, 2, 64);
        let pe = &transformer.positional_encoding;

        assert_eq!(pe.dim(), (8, 16));

        // Sinusoidal values always lie in [-1, 1].
        for &val in pe.iter() {
            assert!(val >= -1.0 && val <= 1.0);
        }
    }

    #[test]
    fn test_transformer_forward() {
        let transformer = TimeSeriesTransformer::<f64>::new(6, 3, 32, 4, 2, 128);
        let input =
            Array2::from_shape_vec((2, 6), (0..12).map(|i| i as f64 * 0.1).collect()).unwrap();

        let output = transformer.forward(&input).unwrap();
        assert_eq!(output.dim(), (2, 3));

        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_softmax_properties() {
        let transformer = TimeSeriesTransformer::<f64>::new(4, 2, 8, 2, 1, 32);
        let input =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 0.5, 1.5, 2.5, 2.0, 1.0, 0.5])
                .unwrap();

        let output = transformer.softmax_2d(&input).unwrap();

        // Each row should sum to 1.
        for i in 0..output.nrows() {
            let row_sum: f64 = (0..output.ncols()).map(|j| output[[i, j]]).sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }

        // All softmax outputs should be non-negative.
        for &val in output.iter() {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn test_input_embedding() {
        let transformer = TimeSeriesTransformer::<f64>::new(5, 3, 16, 4, 2, 64);
        let input =
            Array2::from_shape_vec((2, 5), (0..10).map(|i| i as f64 * 0.2).collect()).unwrap();

        let embedded = transformer.input_embedding(&input).unwrap();
        assert_eq!(embedded.dim(), (10, 16));

        for &val in embedded.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_multi_head_attention() {
        let transformer = TimeSeriesTransformer::<f64>::new(4, 2, 16, 4, 1, 64);
        let input = Array2::zeros((4, 16));

        let output = transformer.multi_head_attention(&input, 0).unwrap();
        assert_eq!(output.dim(), (4, 16));

        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }
}