//! survival/ml/temporal_fusion.rs — simplified Temporal Fusion Transformer
//! for discrete-time survival prediction, exposed to Python via PyO3.
1#![allow(clippy::too_many_arguments)]
2
3use pyo3::prelude::*;
4use rayon::prelude::*;
5
/// Hyper-parameters for the Temporal Fusion Transformer.
///
/// All fields are readable and writable from Python via PyO3 getters/setters,
/// so downstream code should not assume values validated in `new` stay valid.
#[derive(Debug, Clone)]
#[pyclass]
pub struct TFTConfig {
    // Model (hidden) dimension; must be divisible by `num_heads`.
    #[pyo3(get, set)]
    pub hidden_dim: usize,
    // Number of attention heads; each head gets hidden_dim / num_heads dims.
    #[pyo3(get, set)]
    pub num_heads: usize,
    // NOTE(review): encoder/decoder layer counts and dropout_rate are stored
    // but not read by the forward pass in this file — presumably reserved
    // for a fuller implementation; confirm before relying on them.
    #[pyo3(get, set)]
    pub num_encoder_layers: usize,
    #[pyo3(get, set)]
    pub num_decoder_layers: usize,
    #[pyo3(get, set)]
    pub dropout_rate: f64,
    // Number of discrete time bins in the survival output head.
    #[pyo3(get, set)]
    pub num_time_bins: usize,
    // Quantile levels used by `predict_quantiles`; defaults to [0.1, 0.5, 0.9].
    #[pyo3(get, set)]
    pub quantiles: Vec<f64>,
    // NOTE(review): learning_rate / batch_size / n_epochs are unused in this
    // file (no training loop is visible here).
    #[pyo3(get, set)]
    pub learning_rate: f64,
    #[pyo3(get, set)]
    pub batch_size: usize,
    #[pyo3(get, set)]
    pub n_epochs: usize,
    // Optional RNG seed for reproducible weight initialisation.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
}
32
33#[pymethods]
34impl TFTConfig {
35    #[new]
36    #[pyo3(signature = (
37        hidden_dim=64,
38        num_heads=4,
39        num_encoder_layers=2,
40        num_decoder_layers=2,
41        dropout_rate=0.1,
42        num_time_bins=20,
43        quantiles=None,
44        learning_rate=0.001,
45        batch_size=64,
46        n_epochs=100,
47        seed=None
48    ))]
49    pub fn new(
50        hidden_dim: usize,
51        num_heads: usize,
52        num_encoder_layers: usize,
53        num_decoder_layers: usize,
54        dropout_rate: f64,
55        num_time_bins: usize,
56        quantiles: Option<Vec<f64>>,
57        learning_rate: f64,
58        batch_size: usize,
59        n_epochs: usize,
60        seed: Option<u64>,
61    ) -> PyResult<Self> {
62        if !hidden_dim.is_multiple_of(num_heads) {
63            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
64                "hidden_dim must be divisible by num_heads",
65            ));
66        }
67        Ok(Self {
68            hidden_dim,
69            num_heads,
70            num_encoder_layers,
71            num_decoder_layers,
72            dropout_rate,
73            num_time_bins,
74            quantiles: quantiles.unwrap_or_else(|| vec![0.1, 0.5, 0.9]),
75            learning_rate,
76            batch_size,
77            n_epochs,
78            seed,
79        })
80    }
81}
82
/// Gated Linear Unit: splits `x` into two halves and gates each element of
/// the first half by a sigmoid of the weighted matching element of the
/// second half. Output length is `x.len() / 2` (further truncated if
/// `weights` is shorter).
#[allow(dead_code)]
fn glu(x: &[f64], weights: &[f64]) -> Vec<f64> {
    let half = x.len() / 2;
    let (value_half, gate_half) = (&x[..half], &x[half..]);

    let mut gated = Vec::with_capacity(half);
    for ((&a, &b), &w) in value_half.iter().zip(gate_half).zip(weights) {
        // Logistic sigmoid of the weighted gate input.
        let gate = 1.0 / (1.0 + (-b * w).exp());
        gated.push(a * gate);
    }
    gated
}
93
/// Simplified Gated Residual Network: one ReLU hidden layer followed by a
/// linear projection.
///
/// Each row of `weights1` (paired with a bias from `biases`) maps the input
/// to one hidden unit; each row of `weights2` maps the hidden layer to one
/// output. When `context` is given it is folded into the hidden
/// pre-activation through the SAME `weights1` rows as the input
/// (NOTE(review): sharing input weights for the context is unusual for a
/// GRN — confirm this is intended; `zip` silently truncates if the context
/// is a different length than the weight rows).
///
/// Returns one value per row of `weights2`.
fn grn(
    input: &[f64],
    context: Option<&[f64]>,
    weights1: &[Vec<f64>],
    weights2: &[Vec<f64>],
    biases: &[f64],
) -> Vec<f64> {
    // Hidden layer: ReLU(w . input [+ w . context] + b) per row.
    let hidden: Vec<f64> = weights1
        .iter()
        .zip(biases.iter())
        .map(|(w, &b)| {
            let mut sum: f64 = input.iter().zip(w.iter()).map(|(&x, &wi)| x * wi).sum();
            if let Some(ctx) = context {
                sum += ctx
                    .iter()
                    .zip(w.iter())
                    .map(|(&c, &wi)| c * wi)
                    .sum::<f64>();
            }
            (sum + b).max(0.0)
        })
        .collect();

    // Linear output projection.
    weights2
        .iter()
        .map(|w| hidden.iter().zip(w.iter()).map(|(&h, &wi)| h * wi).sum())
        .collect()
}
126
/// Causal (masked) multi-head scaled dot-product self-attention.
///
/// Position `t` attends only to positions `0..=t`. Each head operates on
/// its own contiguous slice of the model dimension (`d_model / num_heads`
/// wide). Returns one attended vector of length `d_model` per position.
fn temporal_self_attention(
    queries: &[Vec<f64>],
    keys: &[Vec<f64>],
    values: &[Vec<f64>],
    num_heads: usize,
) -> Vec<Vec<f64>> {
    let seq_len = queries.len();
    let d_model = queries[0].len();
    let d_head = d_model / num_heads;

    let mut attended = vec![vec![0.0; d_model]; seq_len];

    for head in 0..num_heads {
        let lo = head * d_head;
        let hi = lo + d_head;

        for t in 0..seq_len {
            // Scaled dot-product scores against every causal position.
            let scores: Vec<f64> = keys[..=t]
                .iter()
                .map(|k| {
                    queries[t][lo..hi]
                        .iter()
                        .zip(&k[lo..hi])
                        .map(|(&qi, &ki)| qi * ki)
                        .sum::<f64>()
                        / (d_head as f64).sqrt()
                })
                .collect();

            // Numerically stable softmax over the causal window.
            let peak = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|&s| (s - peak).exp()).collect();
            let norm: f64 = exps.iter().sum();

            // Weighted sum of the value slices into this head's output span.
            for (s, e) in exps.iter().enumerate() {
                let weight = e / norm;
                for (j, &v) in values[s][lo..hi].iter().enumerate() {
                    attended[t][lo + j] += weight * v;
                }
            }
        }
    }

    attended
}
172
/// Discrete-time survival model with a TFT-style forward pass
/// (static/temporal encoders, GRN enrichment, causal self-attention,
/// softmax over time bins), exposed to Python via PyO3.
///
/// NOTE(review): `fit_temporal_fusion_transformer` only samples random
/// initial weights — no gradient training loop is visible in this file.
/// Confirm training happens elsewhere before relying on predictions.
#[derive(Debug, Clone)]
#[pyclass]
pub struct TemporalFusionTransformer {
    // Static-covariate encoder: hidden_dim rows of n_static_features weights.
    static_encoder_weights: Vec<Vec<f64>>,
    static_encoder_biases: Vec<f64>,
    // Per-timestep encoder: hidden_dim rows of n_temporal_features weights.
    temporal_encoder_weights: Vec<Vec<f64>>,
    temporal_encoder_biases: Vec<f64>,
    // Gated Residual Network parameters (hidden_dim x hidden_dim each).
    grn_weights1: Vec<Vec<f64>>,
    grn_weights2: Vec<Vec<f64>>,
    grn_biases: Vec<f64>,
    // Initialised in fit but never read by the prediction path.
    #[allow(dead_code)]
    attention_weights: Vec<Vec<f64>>,
    // Output head: num_time_bins rows of hidden_dim weights.
    output_weights: Vec<Vec<f64>>,
    output_biases: Vec<f64>,
    // num_time_bins + 1 evenly spaced edges over the observed time range.
    time_bins: Vec<f64>,
    config: TFTConfig,
    n_static_features: usize,
    n_temporal_features: usize,
}
192
#[pymethods]
impl TemporalFusionTransformer {
    /// Predict a discrete survival curve per sample, in parallel over
    /// samples via rayon.
    ///
    /// Pipeline per sample:
    /// 1. linear + ReLU encoding of the static features,
    /// 2. linear + ReLU encoding of each temporal step (missing steps are
    ///    zero-filled; a sample with no sequence gets one zero step),
    /// 3. GRN enrichment of each step with the static encoding as context,
    /// 4. causal multi-head self-attention over the sequence,
    /// 5. softmax over time-bin logits from the last attended step,
    /// 6. survival at bin j = tail sum of bin probabilities from j on,
    ///    so the first entry is always 1.0 (clamped via `min(1.0)`).
    ///
    /// Returns one vector of length `num_time_bins` per sample; an empty
    /// input yields an empty result. NOTE(review): static feature rows are
    /// not length-checked here — a short row silently truncates the dot
    /// products through `zip`.
    fn predict_survival(
        &self,
        static_features: Vec<Vec<f64>>,
        temporal_features: Vec<Vec<Vec<f64>>>,
    ) -> PyResult<Vec<Vec<f64>>> {
        if static_features.is_empty() {
            return Ok(Vec::new());
        }

        let n_samples = static_features.len();

        let survival: Vec<Vec<f64>> = (0..n_samples)
            .into_par_iter()
            .map(|i| {
                // Static encoder: ReLU(W x + b).
                let static_encoded: Vec<f64> = self
                    .static_encoder_weights
                    .iter()
                    .zip(self.static_encoder_biases.iter())
                    .map(|(w, &b)| {
                        let sum: f64 = static_features[i]
                            .iter()
                            .zip(w.iter())
                            .map(|(&x, &wi)| x * wi)
                            .sum();
                        (sum + b).max(0.0)
                    })
                    .collect();

                // Sequence length; default to 1 so attention always has
                // at least one (zero) step to work with.
                let seq_len = temporal_features.get(i).map(|t| t.len()).unwrap_or(1);
                let mut temporal_encoded: Vec<Vec<f64>> = Vec::with_capacity(seq_len);

                for t in 0..seq_len {
                    // Missing steps become zero vectors of the expected width.
                    let temporal_input = temporal_features
                        .get(i)
                        .and_then(|tf| tf.get(t))
                        .cloned()
                        .unwrap_or_else(|| vec![0.0; self.n_temporal_features]);

                    // Temporal encoder: ReLU(W x + b) per step.
                    let encoded: Vec<f64> = self
                        .temporal_encoder_weights
                        .iter()
                        .zip(self.temporal_encoder_biases.iter())
                        .map(|(w, &b)| {
                            let sum: f64 = temporal_input
                                .iter()
                                .zip(w.iter())
                                .map(|(&x, &wi)| x * wi)
                                .sum();
                            (sum + b).max(0.0)
                        })
                        .collect();
                    temporal_encoded.push(encoded);
                }

                // Enrich each step with the static context through the GRN.
                let enriched: Vec<Vec<f64>> = temporal_encoded
                    .iter()
                    .map(|te| {
                        grn(
                            te,
                            Some(&static_encoded),
                            &self.grn_weights1,
                            &self.grn_weights2,
                            &self.grn_biases,
                        )
                    })
                    .collect();

                // Causal self-attention (queries = keys = values = enriched).
                let attended =
                    temporal_self_attention(&enriched, &enriched, &enriched, self.config.num_heads);

                // Last timestep summarises the sequence; fall back to the
                // static encoding if attention produced nothing.
                let final_repr = attended.last().unwrap_or(&static_encoded);

                // Output head: one logit per time bin.
                let logits: Vec<f64> = self
                    .output_weights
                    .iter()
                    .zip(self.output_biases.iter())
                    .map(|(w, &b)| {
                        let sum: f64 = final_repr
                            .iter()
                            .zip(w.iter())
                            .map(|(&h, &wi)| h * wi)
                            .sum();
                        sum + b
                    })
                    .collect();

                // Numerically stable softmax over time bins.
                let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let exp_logits: Vec<f64> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
                let sum_exp: f64 = exp_logits.iter().sum();
                let probs: Vec<f64> = exp_logits.iter().map(|&e| e / sum_exp).collect();

                // Survival as the reversed cumulative sum of bin
                // probabilities: surv[j] = sum_{k >= j} probs[k], clamped
                // to 1.0 against floating-point overshoot.
                let mut surv = vec![0.0; probs.len()];
                let mut cumsum = 0.0;
                for j in (0..probs.len()).rev() {
                    cumsum += probs[j];
                    surv[j] = cumsum.min(1.0);
                }
                surv
            })
            .collect();

        Ok(survival)
    }

    /// Per-sample, per-quantile curves derived from the survival curve.
    ///
    /// Output shape: samples x quantiles x time bins.
    ///
    /// NOTE(review): `s*q + (1-s)*(1-q)` is a heuristic blend of the
    /// survival value and its complement, not a calibrated conditional
    /// quantile — confirm the intended semantics with the callers.
    fn predict_quantiles(
        &self,
        static_features: Vec<Vec<f64>>,
        temporal_features: Vec<Vec<Vec<f64>>>,
    ) -> PyResult<Vec<Vec<Vec<f64>>>> {
        let survival = self.predict_survival(static_features, temporal_features)?;

        let quantile_predictions: Vec<Vec<Vec<f64>>> = survival
            .iter()
            .map(|s| {
                self.config
                    .quantiles
                    .iter()
                    .map(|&q| {
                        s.iter()
                            .map(|&si| si * q + (1.0 - si) * (1.0 - q))
                            .collect()
                    })
                    .collect()
            })
            .collect();

        Ok(quantile_predictions)
    }

    /// Causal attention weights (seq_len x seq_len, lower triangular) for a
    /// single sample, computed from the raw temporal encodings using only
    /// the first head's slice (`..d_head`) of the hidden dimension.
    ///
    /// NOTE(review): unlike `predict_survival`, this skips the GRN
    /// enrichment step, so these weights will not match the attention
    /// actually applied during prediction; the static encoding is computed
    /// but unused (`_static_encoded`) — presumably kept for parity with
    /// the prediction path. Confirm this is intentional.
    fn get_attention_weights(
        &self,
        static_features: Vec<f64>,
        temporal_features: Vec<Vec<f64>>,
    ) -> PyResult<Vec<Vec<f64>>> {
        let seq_len = temporal_features.len();
        let mut attention_weights = vec![vec![0.0; seq_len]; seq_len];

        let _static_encoded: Vec<f64> = self
            .static_encoder_weights
            .iter()
            .zip(self.static_encoder_biases.iter())
            .map(|(w, &b)| {
                let sum: f64 = static_features
                    .iter()
                    .zip(w.iter())
                    .map(|(&x, &wi)| x * wi)
                    .sum();
                (sum + b).max(0.0)
            })
            .collect();

        // Temporal encoder: ReLU(W x + b) per step (same as prediction).
        let temporal_encoded: Vec<Vec<f64>> = temporal_features
            .iter()
            .map(|tf| {
                self.temporal_encoder_weights
                    .iter()
                    .zip(self.temporal_encoder_biases.iter())
                    .map(|(w, &b)| {
                        let sum: f64 = tf.iter().zip(w.iter()).map(|(&x, &wi)| x * wi).sum();
                        (sum + b).max(0.0)
                    })
                    .collect()
            })
            .collect();

        // Width of a single attention head; only head 0 is reported.
        let d_head = self.config.hidden_dim / self.config.num_heads;

        for t in 0..seq_len {
            let q: Vec<f64> = temporal_encoded[t][..d_head].to_vec();

            // Scaled dot-product scores against positions 0..=t (causal mask).
            let scores: Vec<f64> = (0..=t)
                .map(|s| {
                    let k: Vec<f64> = temporal_encoded[s][..d_head].to_vec();
                    q.iter()
                        .zip(k.iter())
                        .map(|(&qi, &ki)| qi * ki)
                        .sum::<f64>()
                        / (d_head as f64).sqrt()
                })
                .collect();

            // Numerically stable softmax row.
            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
            let sum_exp: f64 = exp_scores.iter().sum();

            for (s, &e) in exp_scores.iter().enumerate() {
                attention_weights[t][s] = e / sum_exp;
            }
        }

        Ok(attention_weights)
    }

    /// Edges of the discrete time grid (`num_time_bins + 1` values).
    fn get_time_bins(&self) -> Vec<f64> {
        self.time_bins.clone()
    }

    /// Concise summary used as the Python `repr()`.
    fn __repr__(&self) -> String {
        format!(
            "TemporalFusionTransformer(static={}, temporal={}, hidden={})",
            self.n_static_features, self.n_temporal_features, self.config.hidden_dim
        )
    }
}
399
400#[pyfunction]
401#[pyo3(signature = (
402    static_features,
403    temporal_features,
404    time,
405    event,
406    config=None
407))]
408pub fn fit_temporal_fusion_transformer(
409    static_features: Vec<Vec<f64>>,
410    temporal_features: Vec<Vec<Vec<f64>>>,
411    time: Vec<f64>,
412    event: Vec<i32>,
413    config: Option<TFTConfig>,
414) -> PyResult<TemporalFusionTransformer> {
415    let config = config.unwrap_or_else(|| {
416        TFTConfig::new(64, 4, 2, 2, 0.1, 20, None, 0.001, 64, 100, None).unwrap()
417    });
418
419    let n = static_features.len();
420    if n == 0 || time.len() != n || event.len() != n {
421        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
422            "Input arrays must have the same non-zero length",
423        ));
424    }
425
426    let n_static = static_features[0].len();
427    let n_temporal = temporal_features
428        .first()
429        .and_then(|t| t.first())
430        .map(|f| f.len())
431        .unwrap_or(1);
432
433    let mut rng = fastrand::Rng::new();
434    if let Some(seed) = config.seed {
435        rng.seed(seed);
436    }
437
438    let static_encoder_weights: Vec<Vec<f64>> = (0..config.hidden_dim)
439        .map(|_| (0..n_static).map(|_| rng.f64() * 0.1 - 0.05).collect())
440        .collect();
441    let static_encoder_biases: Vec<f64> = (0..config.hidden_dim)
442        .map(|_| rng.f64() * 0.1 - 0.05)
443        .collect();
444
445    let temporal_encoder_weights: Vec<Vec<f64>> = (0..config.hidden_dim)
446        .map(|_| (0..n_temporal).map(|_| rng.f64() * 0.1 - 0.05).collect())
447        .collect();
448    let temporal_encoder_biases: Vec<f64> = (0..config.hidden_dim)
449        .map(|_| rng.f64() * 0.1 - 0.05)
450        .collect();
451
452    let grn_weights1: Vec<Vec<f64>> = (0..config.hidden_dim)
453        .map(|_| {
454            (0..config.hidden_dim)
455                .map(|_| rng.f64() * 0.1 - 0.05)
456                .collect()
457        })
458        .collect();
459    let grn_weights2: Vec<Vec<f64>> = (0..config.hidden_dim)
460        .map(|_| {
461            (0..config.hidden_dim)
462                .map(|_| rng.f64() * 0.1 - 0.05)
463                .collect()
464        })
465        .collect();
466    let grn_biases: Vec<f64> = (0..config.hidden_dim)
467        .map(|_| rng.f64() * 0.1 - 0.05)
468        .collect();
469
470    let attention_weights: Vec<Vec<f64>> = (0..config.hidden_dim)
471        .map(|_| {
472            (0..config.hidden_dim)
473                .map(|_| rng.f64() * 0.1 - 0.05)
474                .collect()
475        })
476        .collect();
477
478    let output_weights: Vec<Vec<f64>> = (0..config.num_time_bins)
479        .map(|_| {
480            (0..config.hidden_dim)
481                .map(|_| rng.f64() * 0.1 - 0.05)
482                .collect()
483        })
484        .collect();
485    let output_biases: Vec<f64> = (0..config.num_time_bins)
486        .map(|_| rng.f64() * 0.1 - 0.05)
487        .collect();
488
489    let min_time = time.iter().cloned().fold(f64::INFINITY, f64::min);
490    let max_time = time.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
491    let time_bins: Vec<f64> = (0..=config.num_time_bins)
492        .map(|i| min_time + (max_time - min_time) * i as f64 / config.num_time_bins as f64)
493        .collect();
494
495    Ok(TemporalFusionTransformer {
496        static_encoder_weights,
497        static_encoder_biases,
498        temporal_encoder_weights,
499        temporal_encoder_biases,
500        grn_weights1,
501        grn_weights2,
502        grn_biases,
503        attention_weights,
504        output_weights,
505        output_biases,
506        time_bins,
507        config,
508        n_static_features: n_static,
509        n_temporal_features: n_temporal,
510    })
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// A hidden dimension that does not divide evenly across the heads
    /// must be rejected by the config constructor.
    #[test]
    fn test_config_validation() {
        assert!(TFTConfig::new(64, 5, 2, 2, 0.1, 20, None, 0.001, 64, 100, None).is_err());
    }

    /// The GRN output has one entry per row of `weights2`.
    #[test]
    fn test_grn() {
        let w1 = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let w2 = w1.clone();
        let output = grn(&[1.0, 2.0], None, &w1, &w2, &[0.0, 0.0]);
        assert_eq!(output.len(), 2);
    }
}