1#![allow(clippy::too_many_arguments)]
2
3use pyo3::prelude::*;
4use rayon::prelude::*;
5
/// Hyper-parameters for the temporal fusion transformer, exposed to Python
/// as a mutable configuration object (every field has a pyo3 getter/setter).
#[derive(Debug, Clone)]
#[pyclass]
pub struct TFTConfig {
    /// Width of all internal representations (encoders, GRN, attention).
    /// Must be positive and divisible by `num_heads`; validated in `new`.
    #[pyo3(get, set)]
    pub hidden_dim: usize,
    /// Number of attention heads; the hidden dimension is split into
    /// `num_heads` contiguous slices of `hidden_dim / num_heads` each.
    #[pyo3(get, set)]
    pub num_heads: usize,
    /// NOTE(review): not consumed by any code visible in this file — confirm
    /// whether encoder depth is used elsewhere.
    #[pyo3(get, set)]
    pub num_encoder_layers: usize,
    /// NOTE(review): not consumed by any code visible in this file.
    #[pyo3(get, set)]
    pub num_decoder_layers: usize,
    /// NOTE(review): not consumed by any code visible in this file.
    #[pyo3(get, set)]
    pub dropout_rate: f64,
    /// Number of discrete time bins; the output layer emits one logit per
    /// bin, and `num_time_bins + 1` bin edges are derived at fit time.
    #[pyo3(get, set)]
    pub num_time_bins: usize,
    /// Quantile levels used by `predict_quantiles`; defaults to
    /// `[0.1, 0.5, 0.9]` when `None` is passed to the constructor.
    #[pyo3(get, set)]
    pub quantiles: Vec<f64>,
    /// NOTE(review): not consumed by any code visible in this file.
    #[pyo3(get, set)]
    pub learning_rate: f64,
    /// NOTE(review): not consumed by any code visible in this file.
    #[pyo3(get, set)]
    pub batch_size: usize,
    /// NOTE(review): not consumed by any code visible in this file.
    #[pyo3(get, set)]
    pub n_epochs: usize,
    /// Optional seed for the fastrand RNG used in weight initialization;
    /// `None` leaves the RNG with its default (nondeterministic) seeding.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
}
32
33#[pymethods]
34impl TFTConfig {
35 #[new]
36 #[pyo3(signature = (
37 hidden_dim=64,
38 num_heads=4,
39 num_encoder_layers=2,
40 num_decoder_layers=2,
41 dropout_rate=0.1,
42 num_time_bins=20,
43 quantiles=None,
44 learning_rate=0.001,
45 batch_size=64,
46 n_epochs=100,
47 seed=None
48 ))]
49 pub fn new(
50 hidden_dim: usize,
51 num_heads: usize,
52 num_encoder_layers: usize,
53 num_decoder_layers: usize,
54 dropout_rate: f64,
55 num_time_bins: usize,
56 quantiles: Option<Vec<f64>>,
57 learning_rate: f64,
58 batch_size: usize,
59 n_epochs: usize,
60 seed: Option<u64>,
61 ) -> PyResult<Self> {
62 if !hidden_dim.is_multiple_of(num_heads) {
63 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
64 "hidden_dim must be divisible by num_heads",
65 ));
66 }
67 Ok(Self {
68 hidden_dim,
69 num_heads,
70 num_encoder_layers,
71 num_decoder_layers,
72 dropout_rate,
73 num_time_bins,
74 quantiles: quantiles.unwrap_or_else(|| vec![0.1, 0.5, 0.9]),
75 learning_rate,
76 batch_size,
77 n_epochs,
78 seed,
79 })
80 }
81}
82
/// Gated linear unit: the first half of `x` is gated element-wise by the
/// sigmoid of the second half scaled by `weights`.
///
/// Output length is `min(x.len() / 2, weights.len())` — the shortest of the
/// zipped iterators wins.
#[allow(dead_code)]
fn glu(x: &[f64], weights: &[f64]) -> Vec<f64> {
    let half = x.len() / 2;
    let (values, gates) = (&x[..half], &x[half..]);
    values
        .iter()
        .zip(gates)
        .zip(weights)
        .map(|((&v, &g), &w)| {
            let sigmoid = 1.0 / (1.0 + (-g * w).exp());
            v * sigmoid
        })
        .collect()
}
93
/// Gated residual network (simplified): one dense ReLU layer over `input`
/// — with `context`, when present, projected through the *same* weights and
/// added to the pre-activation — followed by a bias-free linear projection.
///
/// `weights1` holds one row per hidden unit (paired with `biases`);
/// `weights2` holds one row per output unit.
fn grn(
    input: &[f64],
    context: Option<&[f64]>,
    weights1: &[Vec<f64>],
    weights2: &[Vec<f64>],
    biases: &[f64],
) -> Vec<f64> {
    // Left-to-right dot product; summation order matches element order.
    let dot = |a: &[f64], b: &[f64]| -> f64 { a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum() };

    // Hidden layer: ReLU(W1·input [+ W1·context] + b).
    let hidden: Vec<f64> = weights1
        .iter()
        .zip(biases)
        .map(|(row, &bias)| {
            let mut pre = dot(input, row);
            if let Some(ctx) = context {
                pre += dot(ctx, row);
            }
            (pre + bias).max(0.0)
        })
        .collect();

    // Output layer: plain linear projection, no bias.
    weights2.iter().map(|row| dot(&hidden, row)).collect()
}
126
/// Scaled dot-product self-attention over a temporal sequence with a causal
/// mask: position `t` attends only to positions `0..=t`.
///
/// `queries`, `keys` and `values` each hold `seq_len` vectors of length
/// `d_model`; the feature dimension is split into `num_heads` contiguous
/// slices of `d_model / num_heads`, attended independently.
///
/// Returns a `seq_len x d_model` matrix of attended values. An empty input
/// sequence returns an empty output (previously this indexed `queries[0]`
/// and panicked).
fn temporal_self_attention(
    queries: &[Vec<f64>],
    keys: &[Vec<f64>],
    values: &[Vec<f64>],
    num_heads: usize,
) -> Vec<Vec<f64>> {
    // Guard: no time steps means nothing to attend over.
    if queries.is_empty() {
        return Vec::new();
    }

    let seq_len = queries.len();
    let d_model = queries[0].len();
    let d_head = d_model / num_heads;
    // Loop-invariant softmax temperature, hoisted out of the score loop.
    let scale = (d_head as f64).sqrt();

    let mut outputs = vec![vec![0.0; d_model]; seq_len];

    for h in 0..num_heads {
        let start = h * d_head;
        let end = start + d_head;

        for t in 0..seq_len {
            // Borrow the per-head slices; no per-score allocation needed.
            let q = &queries[t][start..end];

            // Causal scores: only positions s <= t are visible.
            let scores: Vec<f64> = (0..=t)
                .map(|s| {
                    let k = &keys[s][start..end];
                    q.iter().zip(k.iter()).map(|(&qi, &ki)| qi * ki).sum::<f64>() / scale
                })
                .collect();

            // Numerically stable softmax: subtract the max before exp.
            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
            let sum_exp: f64 = exp_scores.iter().sum();

            // Attention-weighted sum of values for this head's slice.
            for (s, &e) in exp_scores.iter().enumerate() {
                let att = e / sum_exp;
                for (j, &v) in values[s][start..end].iter().enumerate() {
                    outputs[t][start + j] += att * v;
                }
            }
        }
    }

    outputs
}
172
/// A temporal fusion transformer for discrete-time survival prediction.
///
/// Weight shapes below follow the initialization in
/// `fit_temporal_fusion_transformer`, which (in this file) only randomly
/// initializes them — see the NOTE on that function.
#[derive(Debug, Clone)]
#[pyclass]
pub struct TemporalFusionTransformer {
    // Static-covariate encoder: hidden_dim rows of n_static_features weights.
    static_encoder_weights: Vec<Vec<f64>>,
    static_encoder_biases: Vec<f64>,
    // Per-time-step encoder: hidden_dim rows of n_temporal_features weights.
    temporal_encoder_weights: Vec<Vec<f64>>,
    temporal_encoder_biases: Vec<f64>,
    // GRN layers (hidden_dim x hidden_dim each) that enrich temporal
    // encodings with the static context.
    grn_weights1: Vec<Vec<f64>>,
    grn_weights2: Vec<Vec<f64>>,
    grn_biases: Vec<f64>,
    // Initialized at fit time but never read by the prediction paths here.
    #[allow(dead_code)]
    attention_weights: Vec<Vec<f64>>,
    // Projection to per-time-bin logits: num_time_bins rows of hidden_dim.
    output_weights: Vec<Vec<f64>>,
    output_biases: Vec<f64>,
    // num_time_bins + 1 evenly spaced edges spanning [min(time), max(time)].
    time_bins: Vec<f64>,
    config: TFTConfig,
    n_static_features: usize,
    n_temporal_features: usize,
}
192
#[pymethods]
impl TemporalFusionTransformer {
    /// Predict a discrete survival curve per sample over the model's time
    /// bins.
    ///
    /// `static_features` holds one row per sample; `temporal_features[i]` is
    /// the per-time-step feature sequence for sample `i` (missing steps are
    /// zero-filled). Returns, per sample, one survival probability per time
    /// bin — a reversed cumulative sum of the softmax over bin logits, so
    /// values are non-increasing in time and capped at 1.0. Empty input
    /// yields an empty result.
    fn predict_survival(
        &self,
        static_features: Vec<Vec<f64>>,
        temporal_features: Vec<Vec<Vec<f64>>>,
    ) -> PyResult<Vec<Vec<f64>>> {
        if static_features.is_empty() {
            return Ok(Vec::new());
        }

        let n_samples = static_features.len();

        // Samples are independent, so score them in parallel with rayon.
        let survival: Vec<Vec<f64>> = (0..n_samples)
            .into_par_iter()
            .map(|i| {
                // Dense ReLU encoding of the static covariates.
                let static_encoded: Vec<f64> = self
                    .static_encoder_weights
                    .iter()
                    .zip(self.static_encoder_biases.iter())
                    .map(|(w, &b)| {
                        let sum: f64 = static_features[i]
                            .iter()
                            .zip(w.iter())
                            .map(|(&x, &wi)| x * wi)
                            .sum();
                        (sum + b).max(0.0)
                    })
                    .collect();

                // Fall back to a single zero-filled step when sample i has no
                // temporal sequence at all.
                let seq_len = temporal_features.get(i).map(|t| t.len()).unwrap_or(1);
                let mut temporal_encoded: Vec<Vec<f64>> = Vec::with_capacity(seq_len);

                for t in 0..seq_len {
                    // Zero-fill any missing time step rather than erroring.
                    let temporal_input = temporal_features
                        .get(i)
                        .and_then(|tf| tf.get(t))
                        .cloned()
                        .unwrap_or_else(|| vec![0.0; self.n_temporal_features]);

                    // Dense ReLU encoding of one time step.
                    let encoded: Vec<f64> = self
                        .temporal_encoder_weights
                        .iter()
                        .zip(self.temporal_encoder_biases.iter())
                        .map(|(w, &b)| {
                            let sum: f64 = temporal_input
                                .iter()
                                .zip(w.iter())
                                .map(|(&x, &wi)| x * wi)
                                .sum();
                            (sum + b).max(0.0)
                        })
                        .collect();
                    temporal_encoded.push(encoded);
                }

                // Enrich each temporal encoding with the static context via
                // the gated residual network.
                let enriched: Vec<Vec<f64>> = temporal_encoded
                    .iter()
                    .map(|te| {
                        grn(
                            te,
                            Some(&static_encoded),
                            &self.grn_weights1,
                            &self.grn_weights2,
                            &self.grn_biases,
                        )
                    })
                    .collect();

                // Causal multi-head self-attention over the enriched sequence.
                let attended =
                    temporal_self_attention(&enriched, &enriched, &enriched, self.config.num_heads);

                // Represent the sample by its last attended step; fall back
                // to the static encoding for an empty sequence.
                let final_repr = attended.last().unwrap_or(&static_encoded);

                // Linear projection to one logit per time bin.
                let logits: Vec<f64> = self
                    .output_weights
                    .iter()
                    .zip(self.output_biases.iter())
                    .map(|(w, &b)| {
                        let sum: f64 = final_repr
                            .iter()
                            .zip(w.iter())
                            .map(|(&h, &wi)| h * wi)
                            .sum();
                        sum + b
                    })
                    .collect();

                // Numerically stable softmax: per-bin event probabilities.
                let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let exp_logits: Vec<f64> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
                let sum_exp: f64 = exp_logits.iter().sum();
                let probs: Vec<f64> = exp_logits.iter().map(|&e| e / sum_exp).collect();

                // Survival at bin j = P(event in bin j or later), i.e. the
                // reversed cumulative sum, clamped against rounding drift.
                let mut surv = vec![0.0; probs.len()];
                let mut cumsum = 0.0;
                for j in (0..probs.len()).rev() {
                    cumsum += probs[j];
                    surv[j] = cumsum.min(1.0);
                }
                surv
            })
            .collect();

        Ok(survival)
    }

    /// Per-quantile curves derived from the point survival estimate.
    ///
    /// For each sample, returns one curve per configured quantile level `q`,
    /// computed as the blend `s*q + (1-s)*(1-q)` of the survival value `s`.
    ///
    /// NOTE(review): this blend is a heuristic interpolation, not a true
    /// predictive quantile of the survival distribution — confirm intent.
    fn predict_quantiles(
        &self,
        static_features: Vec<Vec<f64>>,
        temporal_features: Vec<Vec<Vec<f64>>>,
    ) -> PyResult<Vec<Vec<Vec<f64>>>> {
        let survival = self.predict_survival(static_features, temporal_features)?;

        let quantile_predictions: Vec<Vec<Vec<f64>>> = survival
            .iter()
            .map(|s| {
                self.config
                    .quantiles
                    .iter()
                    .map(|&q| {
                        s.iter()
                            .map(|&si| si * q + (1.0 - si) * (1.0 - q))
                            .collect()
                    })
                    .collect()
            })
            .collect();

        Ok(quantile_predictions)
    }

    /// Causal attention map (`seq_len x seq_len`, lower-triangular) for one
    /// sample, computed from the temporal encodings.
    ///
    /// Uses only the FIRST head's slice `[..d_head]` of each encoding; the
    /// static encoding is computed but not used in the scores (kept for
    /// parity with `predict_survival`'s pipeline).
    fn get_attention_weights(
        &self,
        static_features: Vec<f64>,
        temporal_features: Vec<Vec<f64>>,
    ) -> PyResult<Vec<Vec<f64>>> {
        let seq_len = temporal_features.len();
        let mut attention_weights = vec![vec![0.0; seq_len]; seq_len];

        let _static_encoded: Vec<f64> = self
            .static_encoder_weights
            .iter()
            .zip(self.static_encoder_biases.iter())
            .map(|(w, &b)| {
                let sum: f64 = static_features
                    .iter()
                    .zip(w.iter())
                    .map(|(&x, &wi)| x * wi)
                    .sum();
                (sum + b).max(0.0)
            })
            .collect();

        // Same dense ReLU temporal encoding as in predict_survival.
        let temporal_encoded: Vec<Vec<f64>> = temporal_features
            .iter()
            .map(|tf| {
                self.temporal_encoder_weights
                    .iter()
                    .zip(self.temporal_encoder_biases.iter())
                    .map(|(w, &b)| {
                        let sum: f64 = tf.iter().zip(w.iter()).map(|(&x, &wi)| x * wi).sum();
                        (sum + b).max(0.0)
                    })
                    .collect()
            })
            .collect();

        let d_head = self.config.hidden_dim / self.config.num_heads;

        for t in 0..seq_len {
            let q: Vec<f64> = temporal_encoded[t][..d_head].to_vec();

            // Causal scores over positions 0..=t only.
            let scores: Vec<f64> = (0..=t)
                .map(|s| {
                    let k: Vec<f64> = temporal_encoded[s][..d_head].to_vec();
                    q.iter()
                        .zip(k.iter())
                        .map(|(&qi, &ki)| qi * ki)
                        .sum::<f64>()
                        / (d_head as f64).sqrt()
                })
                .collect();

            // Stable softmax; entries for s > t stay 0.0 (causal mask).
            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
            let sum_exp: f64 = exp_scores.iter().sum();

            for (s, &e) in exp_scores.iter().enumerate() {
                attention_weights[t][s] = e / sum_exp;
            }
        }

        Ok(attention_weights)
    }

    /// The `num_time_bins + 1` bin edges established at fit time.
    fn get_time_bins(&self) -> Vec<f64> {
        self.time_bins.clone()
    }

    /// Python `repr()`: feature counts and hidden width.
    fn __repr__(&self) -> String {
        format!(
            "TemporalFusionTransformer(static={}, temporal={}, hidden={})",
            self.n_static_features, self.n_temporal_features, self.config.hidden_dim
        )
    }
}
399
/// Build a `TemporalFusionTransformer` from training data.
///
/// Validates input lengths, infers feature dimensions from the first rows,
/// randomly initializes every weight/bias uniformly in [-0.05, 0.05), and
/// derives `num_time_bins + 1` evenly spaced time-bin edges spanning the
/// observed time range.
///
/// NOTE(review): no gradient-based training loop is visible in this file —
/// the returned model's weights are random initializations only, and
/// `learning_rate`, `batch_size`, `n_epochs`, `dropout_rate` and the layer
/// counts from the config are never read here. Confirm whether actual
/// fitting happens elsewhere.
///
/// # Errors
///
/// Python `ValueError` when `static_features` is empty or `time`/`event`
/// lengths differ from it.
#[pyfunction]
#[pyo3(signature = (
    static_features,
    temporal_features,
    time,
    event,
    config=None
))]
pub fn fit_temporal_fusion_transformer(
    static_features: Vec<Vec<f64>>,
    temporal_features: Vec<Vec<Vec<f64>>>,
    time: Vec<f64>,
    event: Vec<i32>,
    config: Option<TFTConfig>,
) -> PyResult<TemporalFusionTransformer> {
    // Defaults mirror TFTConfig's Python-side signature defaults.
    let config = config.unwrap_or_else(|| {
        TFTConfig::new(64, 4, 2, 2, 0.1, 20, None, 0.001, 64, 100, None).unwrap()
    });

    // `event` is only length-checked; its values are never consumed here
    // (see NOTE above). `temporal_features.len()` is not checked against n;
    // the prediction paths zero-fill missing steps.
    let n = static_features.len();
    if n == 0 || time.len() != n || event.len() != n {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Input arrays must have the same non-zero length",
        ));
    }

    // Feature dimensions inferred from the first sample/step; a model with
    // no temporal data defaults to one temporal feature.
    let n_static = static_features[0].len();
    let n_temporal = temporal_features
        .first()
        .and_then(|t| t.first())
        .map(|f| f.len())
        .unwrap_or(1);

    // Seeded RNG for reproducible initialization when config.seed is set.
    // The exact draw order below determines every weight — do not reorder.
    let mut rng = fastrand::Rng::new();
    if let Some(seed) = config.seed {
        rng.seed(seed);
    }

    // All initializations draw uniformly from [-0.05, 0.05).
    let static_encoder_weights: Vec<Vec<f64>> = (0..config.hidden_dim)
        .map(|_| (0..n_static).map(|_| rng.f64() * 0.1 - 0.05).collect())
        .collect();
    let static_encoder_biases: Vec<f64> = (0..config.hidden_dim)
        .map(|_| rng.f64() * 0.1 - 0.05)
        .collect();

    let temporal_encoder_weights: Vec<Vec<f64>> = (0..config.hidden_dim)
        .map(|_| (0..n_temporal).map(|_| rng.f64() * 0.1 - 0.05).collect())
        .collect();
    let temporal_encoder_biases: Vec<f64> = (0..config.hidden_dim)
        .map(|_| rng.f64() * 0.1 - 0.05)
        .collect();

    let grn_weights1: Vec<Vec<f64>> = (0..config.hidden_dim)
        .map(|_| {
            (0..config.hidden_dim)
                .map(|_| rng.f64() * 0.1 - 0.05)
                .collect()
        })
        .collect();
    let grn_weights2: Vec<Vec<f64>> = (0..config.hidden_dim)
        .map(|_| {
            (0..config.hidden_dim)
                .map(|_| rng.f64() * 0.1 - 0.05)
                .collect()
        })
        .collect();
    let grn_biases: Vec<f64> = (0..config.hidden_dim)
        .map(|_| rng.f64() * 0.1 - 0.05)
        .collect();

    // Stored on the model but never read by the prediction paths here.
    let attention_weights: Vec<Vec<f64>> = (0..config.hidden_dim)
        .map(|_| {
            (0..config.hidden_dim)
                .map(|_| rng.f64() * 0.1 - 0.05)
                .collect()
        })
        .collect();

    // Output head: one logit row per time bin.
    let output_weights: Vec<Vec<f64>> = (0..config.num_time_bins)
        .map(|_| {
            (0..config.hidden_dim)
                .map(|_| rng.f64() * 0.1 - 0.05)
                .collect()
        })
        .collect();
    let output_biases: Vec<f64> = (0..config.num_time_bins)
        .map(|_| rng.f64() * 0.1 - 0.05)
        .collect();

    // num_time_bins + 1 evenly spaced edges over [min(time), max(time)].
    let min_time = time.iter().cloned().fold(f64::INFINITY, f64::min);
    let max_time = time.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let time_bins: Vec<f64> = (0..=config.num_time_bins)
        .map(|i| min_time + (max_time - min_time) * i as f64 / config.num_time_bins as f64)
        .collect();

    Ok(TemporalFusionTransformer {
        static_encoder_weights,
        static_encoder_biases,
        temporal_encoder_weights,
        temporal_encoder_biases,
        grn_weights1,
        grn_weights2,
        grn_biases,
        attention_weights,
        output_weights,
        output_biases,
        time_bins,
        config,
        n_static_features: n_static,
        n_temporal_features: n_temporal,
    })
}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// 64 is not divisible by 5 heads, so construction must fail.
    #[test]
    fn test_config_validation() {
        let result = TFTConfig::new(64, 5, 2, 2, 0.1, 20, None, 0.001, 64, 100, None);
        assert!(result.is_err());
    }

    /// `quantiles=None` must select the documented default levels.
    #[test]
    fn test_config_default_quantiles() {
        let config = TFTConfig::new(64, 4, 2, 2, 0.1, 20, None, 0.001, 64, 100, None).unwrap();
        assert_eq!(config.quantiles, vec![0.1, 0.5, 0.9]);
    }

    /// GLU gates the first half of the input by sigmoid(second half * w);
    /// sigmoid(0) == 0.5 exactly, so the expected values are exact.
    #[test]
    fn test_glu() {
        assert_eq!(glu(&[1.0, 0.0], &[1.0]), vec![0.5]);
        assert_eq!(glu(&[2.0, 4.0, 0.0, 0.0], &[1.0, 1.0]), vec![1.0, 2.0]);
        assert!(glu(&[], &[]).is_empty());
    }

    /// Exact-value checks (power-of-two weights keep arithmetic exact),
    /// not just a length check.
    #[test]
    fn test_grn() {
        let output = grn(
            &[1.0, 2.0],
            None,
            &[vec![0.5, 0.25], vec![1.0, 0.5]],
            &[vec![1.0, 0.0], vec![0.0, 2.0]],
            &[0.0, 0.0],
        );
        assert_eq!(output, vec![1.0, 4.0]);

        // ReLU clamps a negative pre-activation to zero.
        assert_eq!(grn(&[1.0], None, &[vec![-1.0]], &[vec![3.0]], &[0.0]), vec![0.0]);

        // Context is projected through the same weights and added.
        assert_eq!(
            grn(&[1.0], Some(&[2.0]), &[vec![0.5]], &[vec![1.0]], &[0.0]),
            vec![1.5]
        );
    }

    /// Attention is causal: step 0 sees only value 0; with zero queries and
    /// keys, step 1 averages values 0 and 1 uniformly.
    #[test]
    fn test_attention_causal() {
        let x = vec![vec![2.0]];
        assert_eq!(temporal_self_attention(&x, &x, &x, 1), vec![vec![2.0]]);

        let zeros = vec![vec![0.0], vec![0.0]];
        let vals = vec![vec![1.0], vec![3.0]];
        assert_eq!(
            temporal_self_attention(&zeros, &zeros, &vals, 1),
            vec![vec![1.0], vec![2.0]]
        );
    }
}