Skip to main content

survival/ml/
survtrace.rs

1#![allow(
2    unused_variables,
3    unused_imports,
4    unused_mut,
5    unused_assignments,
6    clippy::too_many_arguments,
7    clippy::needless_range_loop
8)]
9
10use burn::{
11    backend::{Autodiff, NdArray},
12    module::Module,
13    nn::{Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig},
14    optim::{AdamConfig, GradientsParams, Optimizer},
15    prelude::*,
16    tensor::{activation::softplus, backend::AutodiffBackend as AutodiffBackendTrait},
17};
18use pyo3::prelude::*;
19use rayon::prelude::*;
20
21use super::utils::{
22    compute_duration_bins, gelu_cpu, layer_norm_cpu, linear_forward, tensor_to_vec_f32,
23};
24
/// Inference backend: CPU ndarray tensors.
type Backend = NdArray;
/// Training backend: the ndarray backend wrapped with reverse-mode autodiff.
type AutodiffBackend = Autodiff<Backend>;
27
28fn gelu<B: burn::prelude::Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
29    let sqrt_2 = (2.0_f32).sqrt();
30    let cdf = (x.clone() / sqrt_2).erf().add_scalar(1.0) * 0.5;
31    x * cdf
32}
33
/// Layer normalization over the feature (last) dimension of a
/// `[batch, hidden]` tensor with learned scale (`gamma`) and shift (`beta`).
fn layer_norm<B: burn::prelude::Backend>(
    x: Tensor<B, 2>,
    gamma: Tensor<B, 1>,
    beta: Tensor<B, 1>,
    eps: f32,
) -> Tensor<B, 2> {
    let [batch, hidden] = x.dims();
    // `mean_dim` keeps the reduced dimension, so `mean` broadcasts against `x`.
    let mean = x.clone().mean_dim(1);
    // NOTE(review): burn's `var` applies Bessel's correction (divides by n-1),
    // while conventional LayerNorm uses the biased variance (`var_bias`) —
    // confirm which is intended, and that `layer_norm_cpu` in utils matches.
    let var = x.clone().var(1);
    let x_norm = (x - mean) / (var + eps).sqrt();
    // Reshape the affine parameters to `[1, hidden]` so they broadcast over
    // the batch dimension.
    let gamma_expanded: Tensor<B, 2> = gamma.reshape([1, hidden]);
    let beta_expanded: Tensor<B, 2> = beta.reshape([1, hidden]);
    x_norm * gamma_expanded + beta_expanded
}
48
/// Activation-function choice exposed to Python.
///
/// NOTE(review): the forward paths visible in this file apply GELU
/// unconditionally; this enum does not appear to be consulted here — confirm
/// where (or whether) it is used.
#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass]
pub enum SurvTraceActivation {
    /// Gaussian Error Linear Unit.
    GELU,
    /// Rectified Linear Unit.
    ReLU,
}
55
56#[pymethods]
57impl SurvTraceActivation {
58    #[new]
59    fn new(name: &str) -> PyResult<Self> {
60        match name.to_lowercase().as_str() {
61            "gelu" => Ok(SurvTraceActivation::GELU),
62            "relu" => Ok(SurvTraceActivation::ReLU),
63            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
64                "Unknown activation. Use 'gelu' or 'relu'",
65            )),
66        }
67    }
68}
69
/// Hyperparameters for the SurvTrace transformer survival model.
///
/// Exposed to Python with read/write access to every field.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvTraceConfig {
    /// Width of the shared hidden representation.
    #[pyo3(get, set)]
    pub hidden_size: usize,
    /// Number of stacked transformer layers.
    #[pyo3(get, set)]
    pub num_hidden_layers: usize,
    /// Attention heads per layer; must divide `hidden_size`.
    #[pyo3(get, set)]
    pub num_attention_heads: usize,
    /// Width of the feed-forward expansion inside each transformer layer.
    #[pyo3(get, set)]
    pub intermediate_size: usize,
    /// Dropout on hidden states (applied at train time only).
    #[pyo3(get, set)]
    pub hidden_dropout_prob: f64,
    /// Dropout on attention probabilities (applied at train time only).
    #[pyo3(get, set)]
    pub attention_dropout_prob: f64,
    /// Number of discrete time bins the follow-up time is split into.
    #[pyo3(get, set)]
    pub num_durations: usize,
    /// Number of event types; one output head is built per event.
    #[pyo3(get, set)]
    pub num_events: usize,
    /// Embedding rows per categorical feature.
    /// NOTE(review): the visible model construction sizes embeddings from the
    /// observed cardinalities, not from this field — confirm its use.
    #[pyo3(get, set)]
    pub vocab_size: usize,
    /// Adam learning rate.
    #[pyo3(get, set)]
    pub learning_rate: f64,
    /// Mini-batch size used during training.
    #[pyo3(get, set)]
    pub batch_size: usize,
    /// Maximum number of training epochs.
    #[pyo3(get, set)]
    pub n_epochs: usize,
    /// Adam weight-decay (L2) coefficient.
    #[pyo3(get, set)]
    pub weight_decay: f64,
    /// RNG seed; `None` falls back to 42 at fit time.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
    /// Stop after this many epochs without validation improvement.
    #[pyo3(get, set)]
    pub early_stopping_patience: Option<usize>,
    /// Fraction of samples held out for validation, in `[0, 1)`.
    #[pyo3(get, set)]
    pub validation_fraction: f64,
    /// Epsilon added to the variance inside layer normalization.
    #[pyo3(get, set)]
    pub layer_norm_eps: f32,
}
108
109#[pymethods]
110impl SurvTraceConfig {
111    #[new]
112    #[pyo3(signature = (
113        hidden_size=16,
114        num_hidden_layers=3,
115        num_attention_heads=2,
116        intermediate_size=64,
117        hidden_dropout_prob=0.0,
118        attention_dropout_prob=0.1,
119        num_durations=5,
120        num_events=1,
121        vocab_size=8,
122        learning_rate=0.001,
123        batch_size=64,
124        n_epochs=100,
125        weight_decay=0.0001,
126        seed=None,
127        early_stopping_patience=None,
128        validation_fraction=0.1,
129        layer_norm_eps=1e-12
130    ))]
131    pub fn new(
132        hidden_size: usize,
133        num_hidden_layers: usize,
134        num_attention_heads: usize,
135        intermediate_size: usize,
136        hidden_dropout_prob: f64,
137        attention_dropout_prob: f64,
138        num_durations: usize,
139        num_events: usize,
140        vocab_size: usize,
141        learning_rate: f64,
142        batch_size: usize,
143        n_epochs: usize,
144        weight_decay: f64,
145        seed: Option<u64>,
146        early_stopping_patience: Option<usize>,
147        validation_fraction: f64,
148        layer_norm_eps: f32,
149    ) -> PyResult<Self> {
150        if hidden_size == 0 {
151            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
152                "hidden_size must be positive",
153            ));
154        }
155        if num_hidden_layers == 0 {
156            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
157                "num_hidden_layers must be positive",
158            ));
159        }
160        if num_attention_heads == 0 {
161            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
162                "num_attention_heads must be positive",
163            ));
164        }
165        if !hidden_size.is_multiple_of(num_attention_heads) {
166            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
167                "hidden_size must be divisible by num_attention_heads",
168            ));
169        }
170        if num_durations == 0 {
171            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
172                "num_durations must be positive",
173            ));
174        }
175        if batch_size == 0 {
176            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
177                "batch_size must be positive",
178            ));
179        }
180        if n_epochs == 0 {
181            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
182                "n_epochs must be positive",
183            ));
184        }
185        if !(0.0..1.0).contains(&validation_fraction) {
186            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
187                "validation_fraction must be in [0, 1)",
188            ));
189        }
190
191        Ok(SurvTraceConfig {
192            hidden_size,
193            num_hidden_layers,
194            num_attention_heads,
195            intermediate_size,
196            hidden_dropout_prob,
197            attention_dropout_prob,
198            num_durations,
199            num_events,
200            vocab_size,
201            learning_rate,
202            batch_size,
203            n_epochs,
204            weight_decay,
205            seed,
206            early_stopping_patience,
207            validation_fraction,
208            layer_norm_eps,
209        })
210    }
211}
212
/// Multi-head self-attention block; in this model it always operates on
/// single-token sequences (one fused feature vector per sample).
#[derive(Module, Debug)]
struct MultiHeadAttention<B: burn::prelude::Backend> {
    /// Query projection (hidden -> hidden).
    query: Linear<B>,
    /// Key projection (hidden -> hidden).
    key: Linear<B>,
    /// Value projection (hidden -> hidden).
    value: Linear<B>,
    /// Output projection applied after head concatenation.
    output: Linear<B>,
    /// Dropout on attention probabilities (train time only).
    dropout: Dropout,
    /// Number of attention heads.
    num_heads: usize,
    /// hidden_size / num_heads.
    head_dim: usize,
}
223
impl<B: burn::prelude::Backend> MultiHeadAttention<B> {
    /// Build the Q/K/V and output projections (all `hidden_size ->
    /// hidden_size`) plus attention dropout. `hidden_size` must be divisible
    /// by `num_heads` (enforced by `SurvTraceConfig::new`).
    fn new(device: &B::Device, hidden_size: usize, num_heads: usize, dropout_prob: f64) -> Self {
        let head_dim = hidden_size / num_heads;

        Self {
            query: LinearConfig::new(hidden_size, hidden_size).init(device),
            key: LinearConfig::new(hidden_size, hidden_size).init(device),
            value: LinearConfig::new(hidden_size, hidden_size).init(device),
            output: LinearConfig::new(hidden_size, hidden_size).init(device),
            dropout: DropoutConfig::new(dropout_prob).init(),
            num_heads,
            head_dim,
        }
    }

    /// Self-attention over a `[batch, hidden]` input.
    ///
    /// Each sample is treated as a sequence of length 1, so the softmax runs
    /// over a single key position and always yields 1.0: attention reduces to
    /// the (optionally dropout-masked) value projection followed by the
    /// output projection. The CPU inference path relies on this
    /// (`apply_transformer_layer_cpu` hardcodes the attention weight to 1.0).
    fn forward(&self, x: Tensor<B, 2>, training: bool) -> Tensor<B, 2> {
        let [batch_size, hidden_size] = x.dims();
        // One fused token per sample.
        let seq_len = 1;

        let q = self.query.forward(x.clone());
        let k = self.key.forward(x.clone());
        let v = self.value.forward(x);

        // Split the hidden dim into heads: [batch, heads, seq, head_dim].
        let q = q
            .reshape([batch_size, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);
        let k = k
            .reshape([batch_size, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);
        let v = v
            .reshape([batch_size, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);

        // Scaled dot-product attention; scores are [batch, heads, 1, 1].
        let scale = (self.head_dim as f32).sqrt();
        let scores = q.matmul(k.swap_dims(2, 3)) / scale;
        let attn_weights = burn::tensor::activation::softmax(scores, 3);

        // Attention-probability dropout only while training.
        let attn_weights = if training {
            self.dropout.forward(attn_weights)
        } else {
            attn_weights
        };

        let context = attn_weights.matmul(v);
        // Merge the heads back into [batch, hidden].
        let context = context.swap_dims(1, 2).reshape([batch_size, hidden_size]);

        self.output.forward(context)
    }
}
273
/// One post-LN transformer encoder layer: self-attention and a position-wise
/// feed-forward block, each followed by a residual connection and LayerNorm.
#[derive(Module, Debug)]
struct TransformerLayer<B: burn::prelude::Backend> {
    /// Multi-head self-attention sub-layer.
    attention: MultiHeadAttention<B>,
    /// Feed-forward expansion (hidden -> intermediate).
    intermediate: Linear<B>,
    /// Feed-forward projection back (intermediate -> hidden).
    output_dense: Linear<B>,
    /// Scale for LayerNorm after attention.
    layer_norm1_gamma: burn::module::Param<Tensor<B, 1>>,
    /// Shift for LayerNorm after attention.
    layer_norm1_beta: burn::module::Param<Tensor<B, 1>>,
    /// Scale for LayerNorm after the feed-forward block.
    layer_norm2_gamma: burn::module::Param<Tensor<B, 1>>,
    /// Shift for LayerNorm after the feed-forward block.
    layer_norm2_beta: burn::module::Param<Tensor<B, 1>>,
    /// Dropout applied to both sub-layer outputs (train time only).
    dropout: Dropout,
    /// Epsilon used inside both LayerNorms.
    layer_norm_eps: f32,
}
286
287impl<B: burn::prelude::Backend> TransformerLayer<B> {
288    fn new(
289        device: &B::Device,
290        hidden_size: usize,
291        num_heads: usize,
292        intermediate_size: usize,
293        hidden_dropout_prob: f64,
294        attention_dropout_prob: f64,
295        layer_norm_eps: f32,
296    ) -> Self {
297        Self {
298            attention: MultiHeadAttention::new(
299                device,
300                hidden_size,
301                num_heads,
302                attention_dropout_prob,
303            ),
304            intermediate: LinearConfig::new(hidden_size, intermediate_size).init(device),
305            output_dense: LinearConfig::new(intermediate_size, hidden_size).init(device),
306            layer_norm1_gamma: burn::module::Param::from_tensor(Tensor::ones(
307                [hidden_size],
308                device,
309            )),
310            layer_norm1_beta: burn::module::Param::from_tensor(Tensor::zeros(
311                [hidden_size],
312                device,
313            )),
314            layer_norm2_gamma: burn::module::Param::from_tensor(Tensor::ones(
315                [hidden_size],
316                device,
317            )),
318            layer_norm2_beta: burn::module::Param::from_tensor(Tensor::zeros(
319                [hidden_size],
320                device,
321            )),
322            dropout: DropoutConfig::new(hidden_dropout_prob).init(),
323            layer_norm_eps,
324        }
325    }
326
327    fn forward(&self, x: Tensor<B, 2>, training: bool) -> Tensor<B, 2> {
328        let attn_output = self.attention.forward(x.clone(), training);
329        let attn_output = if training {
330            self.dropout.forward(attn_output)
331        } else {
332            attn_output
333        };
334        let x = layer_norm(
335            x + attn_output,
336            self.layer_norm1_gamma.val(),
337            self.layer_norm1_beta.val(),
338            self.layer_norm_eps,
339        );
340
341        let intermediate = self.intermediate.forward(x.clone());
342        let intermediate = gelu(intermediate);
343        let output = self.output_dense.forward(intermediate);
344        let output = if training {
345            self.dropout.forward(output)
346        } else {
347            output
348        };
349
350        layer_norm(
351            x + output,
352            self.layer_norm2_gamma.val(),
353            self.layer_norm2_beta.val(),
354            self.layer_norm_eps,
355        )
356    }
357}
358
/// The full SurvTrace network: feature embeddings summed into one token,
/// a transformer stack, and one duration-logit head per event type.
#[derive(Module, Debug)]
struct SurvTraceNetwork<B: burn::prelude::Backend> {
    /// One embedding table per categorical feature.
    cat_embeddings: Vec<Embedding<B>>,
    /// Linear projection of all numeric features into the hidden space.
    num_projection: Linear<B>,
    /// Stacked transformer encoder layers.
    transformer_layers: Vec<TransformerLayer<B>>,
    /// One `hidden -> num_durations` head per event type.
    output_heads: Vec<Linear<B>>,
    /// Width of the hidden representation.
    hidden_size: usize,
    /// Number of categorical input features.
    num_cat_features: usize,
    /// Number of numeric input features.
    num_num_features: usize,
    /// Number of event heads (at least 1).
    num_events: usize,
    /// Number of discrete duration bins per head.
    num_durations: usize,
}
371
372impl<B: burn::prelude::Backend> SurvTraceNetwork<B> {
373    fn new(
374        device: &B::Device,
375        num_cat_features: usize,
376        num_num_features: usize,
377        cat_cardinalities: &[usize],
378        config: &SurvTraceConfig,
379    ) -> Self {
380        let mut cat_embeddings = Vec::new();
381        for &card in cat_cardinalities {
382            cat_embeddings.push(EmbeddingConfig::new(card.max(2), config.hidden_size).init(device));
383        }
384
385        let num_projection = if num_num_features > 0 {
386            LinearConfig::new(num_num_features, config.hidden_size).init(device)
387        } else {
388            LinearConfig::new(1, config.hidden_size).init(device)
389        };
390
391        let mut transformer_layers = Vec::new();
392        for _ in 0..config.num_hidden_layers {
393            transformer_layers.push(TransformerLayer::new(
394                device,
395                config.hidden_size,
396                config.num_attention_heads,
397                config.intermediate_size,
398                config.hidden_dropout_prob,
399                config.attention_dropout_prob,
400                config.layer_norm_eps,
401            ));
402        }
403
404        let mut output_heads = Vec::new();
405        let num_events = config.num_events.max(1);
406        for _ in 0..num_events {
407            output_heads
408                .push(LinearConfig::new(config.hidden_size, config.num_durations).init(device));
409        }
410
411        Self {
412            cat_embeddings,
413            num_projection,
414            transformer_layers,
415            output_heads,
416            hidden_size: config.hidden_size,
417            num_cat_features,
418            num_num_features,
419            num_events,
420            num_durations: config.num_durations,
421        }
422    }
423
424    fn forward(
425        &self,
426        x_cat: Option<Tensor<B, 2, Int>>,
427        x_num: Tensor<B, 2>,
428        training: bool,
429    ) -> Vec<Tensor<B, 2>> {
430        let [batch_size, _] = x_num.dims();
431        let device = x_num.device();
432
433        let mut embeddings: Tensor<B, 2> = Tensor::zeros([batch_size, self.hidden_size], &device);
434
435        if let Some(x_cat) = x_cat {
436            for (i, emb) in self.cat_embeddings.iter().enumerate() {
437                let cat_slice: Tensor<B, 2, Int> = x_cat.clone().slice([0..batch_size, i..i + 1]);
438                let cat_emb_3d: Tensor<B, 3> = emb.forward(cat_slice);
439                let cat_emb: Tensor<B, 2> = cat_emb_3d.squeeze::<2>();
440                embeddings = embeddings + cat_emb;
441            }
442        }
443
444        if self.num_num_features > 0 {
445            let num_emb = self.num_projection.forward(x_num);
446            embeddings = embeddings + num_emb;
447        }
448
449        let mut hidden = embeddings;
450        for layer in &self.transformer_layers {
451            hidden = layer.forward(hidden, training);
452        }
453
454        let mut outputs = Vec::new();
455        for head in &self.output_heads {
456            let logits = head.forward(hidden.clone());
457            outputs.push(logits);
458        }
459
460        outputs
461    }
462
463    fn forward_inference(
464        &self,
465        x_cat: Option<Tensor<B, 2, Int>>,
466        x_num: Tensor<B, 2>,
467    ) -> Vec<Tensor<B, 2>> {
468        self.forward(x_cat, x_num, false)
469    }
470}
471
/// Mean negative log-likelihood of the discrete-time logistic-hazard model.
///
/// For each sample in `batch_indices`, every time bin up to and including its
/// (clamped) duration bin contributes a binary-cross-entropy term against the
/// hazard target: 0 everywhere, except 1 at the event bin for uncensored
/// samples (`event == 1`). The total is normalized by the number of events in
/// the batch, falling back to the batch size for fully censored batches.
///
/// `logits` is the flat `[batch, num_durations]` output for this batch;
/// `durations`/`events` are indexed by the *global* indices in
/// `batch_indices`. Requires `num_durations >= 1`.
fn compute_nll_logistic_hazard_loss(
    logits: &[f32],
    durations: &[usize],
    events: &[i32],
    num_durations: usize,
    batch_indices: &[usize],
) -> f64 {
    let mut total_loss = 0.0_f64;
    let mut n_events = 0_usize;

    for (local_idx, &global_idx) in batch_indices.iter().enumerate() {
        let duration_bin = durations[global_idx].min(num_durations - 1);
        let event = events[global_idx];

        for t in 0..=duration_bin {
            let logit = logits[local_idx * num_durations + t];
            let target: f32 = if t == duration_bin && event == 1 {
                1.0
            } else {
                0.0
            };

            // Numerically stable BCE-with-logits:
            //   max(z, 0) - z*y + ln(1 + exp(-|z|))
            // The naive ln(1 + exp(-z)) overflows to +inf once |z| exceeds
            // ~88 in f32; this form stays finite for any logit.
            let loss = logit.max(0.0) - logit * target + (1.0 + (-logit.abs()).exp()).ln();
            total_loss += loss as f64;
        }

        if event == 1 {
            n_events += 1;
        }
    }

    // Normalize by event count when any events exist; otherwise by batch size
    // so fully censored batches still yield a finite mean.
    if n_events > 0 {
        total_loss / n_events as f64
    } else {
        total_loss / batch_indices.len().max(1) as f64
    }
}
513
/// Gradient of `compute_nll_logistic_hazard_loss` with respect to the logits.
///
/// Returns a flat `[batch, num_durations]` buffer; entries for bins after a
/// sample's duration bin stay zero (they do not enter the loss). Each active
/// entry is `sigmoid(logit) - target`, scaled by the same normalizer the loss
/// uses. Requires `num_durations >= 1`.
fn compute_nll_logistic_hazard_gradient(
    logits: &[f32],
    durations: &[usize],
    events: &[i32],
    num_durations: usize,
    batch_indices: &[usize],
) -> Vec<f32> {
    let batch_size = batch_indices.len();
    let mut gradients = vec![0.0f32; batch_size * num_durations];

    for (local_idx, &global_idx) in batch_indices.iter().enumerate() {
        let duration_bin = durations[global_idx].min(num_durations - 1);
        let event = events[global_idx];

        for t in 0..=duration_bin {
            let logit = logits[local_idx * num_durations + t];
            // d/dz of BCE-with-logits is sigmoid(z) - y. (For very negative
            // z, exp(-z) saturates to +inf and the quotient cleanly gives 0.)
            let pred = 1.0 / (1.0 + (-logit).exp());
            let target = if t == duration_bin && event == 1 {
                1.0
            } else {
                0.0
            };
            gradients[local_idx * num_durations + t] = pred - target;
        }
    }

    // Count samples whose event code is exactly 1, matching the loss
    // normalization. (Summing raw codes, as before, diverged from the loss
    // whenever non-binary event codes — e.g. competing risks — appeared.)
    let n_events = batch_indices.iter().filter(|&&i| events[i] == 1).count();
    let divisor = if n_events > 0 {
        n_events as f32
    } else {
        batch_size.max(1) as f32
    };

    for g in &mut gradients {
        *g /= divisor;
    }

    gradients
}
553
/// All model parameters flattened into plain `f32` buffers for the pure-CPU
/// inference path (`predict_with_weights`). Empty bias vectors mean "no bias".
///
/// NOTE(review): the CPU path indexes weight buffers as `w[j * in_dim + k]`
/// (i.e. `[out_dim, in_dim]` flattening); verify that the extraction side
/// produces that layout for burn's `Linear` weights.
#[derive(Clone)]
struct StoredWeights {
    /// One flattened `[vocab, hidden]` table per categorical feature.
    cat_embeddings: Vec<Vec<f32>>,
    /// `(vocab_size, embedding_dim)` per categorical feature.
    cat_embedding_dims: Vec<(usize, usize)>,
    /// Flattened numeric-projection weight matrix.
    num_projection_weights: Vec<f32>,
    /// Numeric-projection bias (empty if absent).
    num_projection_bias: Vec<f32>,
    /// `(in_dim, out_dim)` of the numeric projection.
    num_projection_dims: (usize, usize),
    /// Per-layer transformer weights, in stack order.
    transformer_layers: Vec<TransformerLayerWeights>,
    /// `(weights, bias, in_dim, out_dim)` per event head.
    output_heads: Vec<(Vec<f32>, Vec<f32>, usize, usize)>,
    /// Width of the hidden representation.
    hidden_size: usize,
    /// Number of categorical input features.
    num_cat_features: usize,
    /// Number of numeric input features.
    num_num_features: usize,
    /// Number of event heads.
    num_events: usize,
}
568
/// Flattened parameters of one transformer layer for the CPU inference path.
/// All weight buffers follow the same flattening convention as
/// `StoredWeights`; empty bias vectors mean "no bias".
#[derive(Clone)]
struct TransformerLayerWeights {
    /// Attention query projection weight / bias.
    query_w: Vec<f32>,
    query_b: Vec<f32>,
    /// Attention key projection weight / bias.
    key_w: Vec<f32>,
    key_b: Vec<f32>,
    /// Attention value projection weight / bias.
    value_w: Vec<f32>,
    value_b: Vec<f32>,
    /// Attention output projection weight / bias.
    output_w: Vec<f32>,
    output_b: Vec<f32>,
    /// Feed-forward expansion weight / bias (hidden -> intermediate).
    intermediate_w: Vec<f32>,
    intermediate_b: Vec<f32>,
    /// Feed-forward projection weight / bias (intermediate -> hidden).
    output_dense_w: Vec<f32>,
    output_dense_b: Vec<f32>,
    /// LayerNorm scale/shift after attention.
    ln1_gamma: Vec<f32>,
    ln1_beta: Vec<f32>,
    /// LayerNorm scale/shift after the feed-forward block.
    ln2_gamma: Vec<f32>,
    ln2_beta: Vec<f32>,
    /// Hidden width of this layer.
    hidden_size: usize,
    /// Feed-forward expansion width.
    intermediate_size: usize,
    /// Number of attention heads (hidden_size must be divisible by this).
    num_heads: usize,
}
591
/// Hand-written `Debug`: the flattened weight buffers are large, so only
/// summary counts are printed instead of the raw data.
impl std::fmt::Debug for StoredWeights {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StoredWeights")
            .field("num_transformer_layers", &self.transformer_layers.len())
            .field("num_events", &self.num_events)
            .finish()
    }
}
600
601fn extract_weights(
602    model: &SurvTraceNetwork<AutodiffBackend>,
603    config: &SurvTraceConfig,
604    cat_cardinalities: &[usize],
605) -> StoredWeights {
606    let mut cat_embeddings = Vec::new();
607    let mut cat_embedding_dims = Vec::new();
608
609    for (i, emb) in model.cat_embeddings.iter().enumerate() {
610        let w: Vec<f32> = emb
611            .weight
612            .val()
613            .inner()
614            .into_data()
615            .to_vec()
616            .unwrap_or_default();
617        cat_embeddings.push(w);
618        cat_embedding_dims.push((
619            cat_cardinalities.get(i).copied().unwrap_or(2),
620            config.hidden_size,
621        ));
622    }
623
624    let num_proj_w: Tensor<AutodiffBackend, 2> = model.num_projection.weight.val();
625    let num_projection_weights: Vec<f32> = tensor_to_vec_f32(num_proj_w.inner());
626    let num_projection_bias: Vec<f32> = model
627        .num_projection
628        .bias
629        .as_ref()
630        .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
631        .unwrap_or_default();
632    let num_projection_dims = (model.num_num_features.max(1), config.hidden_size);
633
634    let mut transformer_layers = Vec::new();
635    for layer in &model.transformer_layers {
636        let tlw = TransformerLayerWeights {
637            query_w: tensor_to_vec_f32(layer.attention.query.weight.val().inner()),
638            query_b: layer
639                .attention
640                .query
641                .bias
642                .as_ref()
643                .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
644                .unwrap_or_default(),
645            key_w: tensor_to_vec_f32(layer.attention.key.weight.val().inner()),
646            key_b: layer
647                .attention
648                .key
649                .bias
650                .as_ref()
651                .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
652                .unwrap_or_default(),
653            value_w: tensor_to_vec_f32(layer.attention.value.weight.val().inner()),
654            value_b: layer
655                .attention
656                .value
657                .bias
658                .as_ref()
659                .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
660                .unwrap_or_default(),
661            output_w: tensor_to_vec_f32(layer.attention.output.weight.val().inner()),
662            output_b: layer
663                .attention
664                .output
665                .bias
666                .as_ref()
667                .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
668                .unwrap_or_default(),
669            intermediate_w: tensor_to_vec_f32(layer.intermediate.weight.val().inner()),
670            intermediate_b: layer
671                .intermediate
672                .bias
673                .as_ref()
674                .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
675                .unwrap_or_default(),
676            output_dense_w: tensor_to_vec_f32(layer.output_dense.weight.val().inner()),
677            output_dense_b: layer
678                .output_dense
679                .bias
680                .as_ref()
681                .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
682                .unwrap_or_default(),
683            ln1_gamma: layer
684                .layer_norm1_gamma
685                .val()
686                .inner()
687                .into_data()
688                .to_vec()
689                .unwrap_or_default(),
690            ln1_beta: layer
691                .layer_norm1_beta
692                .val()
693                .inner()
694                .into_data()
695                .to_vec()
696                .unwrap_or_default(),
697            ln2_gamma: layer
698                .layer_norm2_gamma
699                .val()
700                .inner()
701                .into_data()
702                .to_vec()
703                .unwrap_or_default(),
704            ln2_beta: layer
705                .layer_norm2_beta
706                .val()
707                .inner()
708                .into_data()
709                .to_vec()
710                .unwrap_or_default(),
711            hidden_size: config.hidden_size,
712            intermediate_size: config.intermediate_size,
713            num_heads: config.num_attention_heads,
714        };
715        transformer_layers.push(tlw);
716    }
717
718    let mut output_heads = Vec::new();
719    for head in &model.output_heads {
720        let w: Vec<f32> = tensor_to_vec_f32(head.weight.val().inner());
721        let b: Vec<f32> = head
722            .bias
723            .as_ref()
724            .map(|bias| bias.val().inner().into_data().to_vec().unwrap_or_default())
725            .unwrap_or_default();
726        output_heads.push((w, b, config.hidden_size, config.num_durations));
727    }
728
729    StoredWeights {
730        cat_embeddings,
731        cat_embedding_dims,
732        num_projection_weights,
733        num_projection_bias,
734        num_projection_dims,
735        transformer_layers,
736        output_heads,
737        hidden_size: config.hidden_size,
738        num_cat_features: model.num_cat_features,
739        num_num_features: model.num_num_features,
740        num_events: model.num_events,
741    }
742}
743
/// Pure-CPU forward pass over `n` samples using the extracted weights.
///
/// `x_cat` (when present) is a row-major `[n, num_cat_features]` buffer of
/// category codes; `x_num` is a row-major `[n, num_num_features]` buffer of
/// numeric features. Returns, per event head, a flat `[n * num_durations]`
/// vector of raw logits (no sigmoid/softmax is applied here).
fn predict_with_weights(
    x_cat: Option<&[i64]>,
    x_num: &[f64],
    n: usize,
    weights: &StoredWeights,
    layer_norm_eps: f32,
) -> Vec<Vec<f64>> {
    let hidden_size = weights.hidden_size;
    let num_num = weights.num_num_features;
    let num_cat = weights.num_cat_features;

    let mut all_outputs: Vec<Vec<f64>> = vec![Vec::new(); weights.num_events];

    for i in 0..n {
        // Fuse all of sample i's features into a single hidden vector.
        let mut hidden = vec![0.0f64; hidden_size];

        if let Some(cats) = x_cat {
            for (feat_idx, emb_weights) in weights.cat_embeddings.iter().enumerate() {
                let (vocab_size, emb_dim) = weights.cat_embedding_dims[feat_idx];
                // Negative codes wrap to very large values under `as usize`
                // and get clamped to the last row below.
                let cat_val = cats[i * num_cat + feat_idx] as usize;
                // NOTE(review): underflows if vocab_size == 0 — verify the
                // extraction side always records at least one row per table.
                let cat_val = cat_val.min(vocab_size - 1);
                for j in 0..emb_dim {
                    hidden[j] += emb_weights[cat_val * emb_dim + j] as f64;
                }
            }
        }

        if num_num > 0 {
            // Linear projection of the numeric features, added into `hidden`.
            // NOTE(review): indexing treats the weight buffer as flattened
            // [out_dim, in_dim]; confirm the extraction layout matches.
            let (in_dim, out_dim) = weights.num_projection_dims;
            for j in 0..out_dim {
                let mut sum = if !weights.num_projection_bias.is_empty() {
                    weights.num_projection_bias[j] as f64
                } else {
                    0.0
                };
                for k in 0..in_dim.min(num_num) {
                    sum += x_num[i * num_num + k]
                        * weights.num_projection_weights[j * in_dim + k] as f64;
                }
                hidden[j] += sum;
            }
        }

        // Run the transformer stack on the single fused token.
        for layer in &weights.transformer_layers {
            hidden = apply_transformer_layer_cpu(&hidden, layer, layer_norm_eps);
        }

        // One linear head per event type emits `num_durations` logits,
        // appended per sample so each output is [n * num_durations] long.
        for (event_idx, (w, b, in_dim, out_dim)) in weights.output_heads.iter().enumerate() {
            let mut logits = Vec::with_capacity(*out_dim);
            for j in 0..*out_dim {
                let mut sum = if !b.is_empty() { b[j] as f64 } else { 0.0 };
                for k in 0..*in_dim {
                    sum += hidden[k] * w[j * in_dim + k] as f64;
                }
                logits.push(sum);
            }
            all_outputs[event_idx].extend(logits);
        }
    }

    all_outputs
}
806
807fn apply_transformer_layer_cpu(
808    hidden: &[f64],
809    layer: &TransformerLayerWeights,
810    eps: f32,
811) -> Vec<f64> {
812    let h = layer.hidden_size;
813
814    let q = linear_forward(hidden, &layer.query_w, &layer.query_b, h, h);
815    let k = linear_forward(hidden, &layer.key_w, &layer.key_b, h, h);
816    let v = linear_forward(hidden, &layer.value_w, &layer.value_b, h, h);
817
818    let head_dim = h / layer.num_heads;
819    let mut attn_output = vec![0.0f64; h];
820
821    for head in 0..layer.num_heads {
822        let start = head * head_dim;
823        let end = start + head_dim;
824
825        let mut score = 0.0;
826        for i in start..end {
827            score += q[i] * k[i];
828        }
829        score /= (head_dim as f64).sqrt();
830        let attn_weight = 1.0;
831
832        for i in start..end {
833            attn_output[i] = attn_weight * v[i];
834        }
835    }
836
837    let attn_proj = linear_forward(&attn_output, &layer.output_w, &layer.output_b, h, h);
838
839    let mut residual1: Vec<f64> = hidden.iter().zip(&attn_proj).map(|(a, b)| a + b).collect();
840    residual1 = layer_norm_cpu(&residual1, &layer.ln1_gamma, &layer.ln1_beta, eps);
841
842    let intermediate = linear_forward(
843        &residual1,
844        &layer.intermediate_w,
845        &layer.intermediate_b,
846        h,
847        layer.intermediate_size,
848    );
849    let intermediate: Vec<f64> = intermediate.iter().map(|&x| gelu_cpu(x)).collect();
850
851    let output = linear_forward(
852        &intermediate,
853        &layer.output_dense_w,
854        &layer.output_dense_b,
855        layer.intermediate_size,
856        h,
857    );
858
859    let mut residual2: Vec<f64> = residual1.iter().zip(&output).map(|(a, b)| a + b).collect();
860    residual2 = layer_norm_cpu(&residual2, &layer.ln2_gamma, &layer.ln2_beta, eps);
861
862    residual2
863}
864
865fn fit_survtrace_inner(
866    x_cat: Option<&[i64]>,
867    x_num: &[f64],
868    n_obs: usize,
869    num_cat_features: usize,
870    num_num_features: usize,
871    cat_cardinalities: &[usize],
872    time: &[f64],
873    event: &[i32],
874    config: &SurvTraceConfig,
875) -> SurvTrace {
876    let device: <Backend as burn::prelude::Backend>::Device = Default::default();
877    let seed = config.seed.unwrap_or(42);
878
879    let (duration_bins, cuts) = compute_duration_bins(time, config.num_durations);
880
881    let mut model: SurvTraceNetwork<AutodiffBackend> = SurvTraceNetwork::new(
882        &device,
883        num_cat_features,
884        num_num_features,
885        cat_cardinalities,
886        config,
887    );
888
889    let mut optimizer = AdamConfig::new()
890        .with_weight_decay(Some(burn::optim::decay::WeightDecayConfig::new(
891            config.weight_decay as f32,
892        )))
893        .init();
894
895    let n_val = (n_obs as f64 * config.validation_fraction).floor() as usize;
896    let n_train = n_obs - n_val;
897
898    let mut rng = fastrand::Rng::with_seed(seed);
899    let mut shuffled_indices: Vec<usize> = (0..n_obs).collect();
900    for i in (1..n_obs).rev() {
901        let j = rng.usize(0..=i);
902        shuffled_indices.swap(i, j);
903    }
904
905    let train_indices: Vec<usize> = shuffled_indices[..n_train].to_vec();
906    let val_indices: Vec<usize> = shuffled_indices[n_train..].to_vec();
907
908    let mut train_loss_history = Vec::new();
909    let mut val_loss_history = Vec::new();
910    let mut best_val_loss = f64::INFINITY;
911    let mut epochs_without_improvement = 0;
912    let mut best_weights: Option<StoredWeights> = None;
913
914    for epoch in 0..config.n_epochs {
915        let mut epoch_indices = train_indices.clone();
916        for i in (1..epoch_indices.len()).rev() {
917            let j = rng.usize(0..=i);
918            epoch_indices.swap(i, j);
919        }
920
921        let mut epoch_loss = 0.0;
922        let mut n_batches = 0;
923
924        for batch_start in (0..n_train).step_by(config.batch_size) {
925            let batch_end = (batch_start + config.batch_size).min(n_train);
926            let batch_indices: Vec<usize> = epoch_indices[batch_start..batch_end].to_vec();
927            let batch_size = batch_indices.len();
928
929            let x_num_batch: Vec<f32> = batch_indices
930                .iter()
931                .flat_map(|&i| {
932                    (0..num_num_features).map(move |j| x_num[i * num_num_features + j] as f32)
933                })
934                .collect();
935
936            let x_num_data = burn::tensor::TensorData::new(
937                x_num_batch.clone(),
938                [batch_size, num_num_features.max(1)],
939            );
940            let x_num_tensor: Tensor<AutodiffBackend, 2> = Tensor::from_data(x_num_data, &device);
941
942            let x_cat_tensor: Option<Tensor<AutodiffBackend, 2, Int>> = if num_cat_features > 0 {
943                if let Some(cats) = x_cat {
944                    let x_cat_batch: Vec<i64> = batch_indices
945                        .iter()
946                        .flat_map(|&i| {
947                            (0..num_cat_features).map(move |j| cats[i * num_cat_features + j])
948                        })
949                        .collect();
950                    let x_cat_data =
951                        burn::tensor::TensorData::new(x_cat_batch, [batch_size, num_cat_features]);
952                    Some(Tensor::from_data(x_cat_data, &device))
953                } else {
954                    None
955                }
956            } else {
957                None
958            };
959
960            let outputs = model.forward(x_cat_tensor, x_num_tensor, true);
961
962            let mut total_loss = 0.0;
963            let mut all_grads: Vec<Vec<f32>> = Vec::new();
964
965            for (event_idx, logits_tensor) in outputs.iter().enumerate() {
966                let logits_vec: Vec<f32> = tensor_to_vec_f32(logits_tensor.clone().inner());
967
968                let loss = compute_nll_logistic_hazard_loss(
969                    &logits_vec,
970                    &duration_bins,
971                    event,
972                    config.num_durations,
973                    &batch_indices,
974                );
975                total_loss += loss;
976
977                let grads = compute_nll_logistic_hazard_gradient(
978                    &logits_vec,
979                    &duration_bins,
980                    event,
981                    config.num_durations,
982                    &batch_indices,
983                );
984                all_grads.push(grads);
985            }
986
987            epoch_loss += total_loss;
988            n_batches += 1;
989
990            if !all_grads.is_empty() {
991                let grad_data = burn::tensor::TensorData::new(
992                    all_grads[0].clone(),
993                    [batch_size, config.num_durations],
994                );
995                let grad_tensor: Tensor<AutodiffBackend, 2> = Tensor::from_data(grad_data, &device);
996
997                let pseudo_loss = (outputs[0].clone() * grad_tensor).mean();
998                let grads = pseudo_loss.backward();
999                let grads = GradientsParams::from_grads(grads, &model);
1000                model = optimizer.step(config.learning_rate, model, grads);
1001            }
1002        }
1003
1004        let avg_train_loss = if n_batches > 0 {
1005            epoch_loss / n_batches as f64
1006        } else {
1007            0.0
1008        };
1009        train_loss_history.push(avg_train_loss);
1010
1011        if !val_indices.is_empty() {
1012            let x_num_val: Vec<f32> = val_indices
1013                .iter()
1014                .flat_map(|&i| {
1015                    (0..num_num_features).map(move |j| x_num[i * num_num_features + j] as f32)
1016                })
1017                .collect();
1018
1019            let x_num_val_data =
1020                burn::tensor::TensorData::new(x_num_val, [n_val, num_num_features.max(1)]);
1021            let x_num_val_tensor: Tensor<AutodiffBackend, 2> =
1022                Tensor::from_data(x_num_val_data, &device);
1023
1024            let x_cat_val_tensor: Option<Tensor<AutodiffBackend, 2, Int>> = if num_cat_features > 0
1025            {
1026                if let Some(cats) = x_cat {
1027                    let x_cat_val: Vec<i64> = val_indices
1028                        .iter()
1029                        .flat_map(|&i| {
1030                            (0..num_cat_features).map(move |j| cats[i * num_cat_features + j])
1031                        })
1032                        .collect();
1033                    let x_cat_val_data =
1034                        burn::tensor::TensorData::new(x_cat_val, [n_val, num_cat_features]);
1035                    Some(Tensor::from_data(x_cat_val_data, &device))
1036                } else {
1037                    None
1038                }
1039            } else {
1040                None
1041            };
1042
1043            let val_outputs = model.forward_inference(x_cat_val_tensor, x_num_val_tensor);
1044
1045            let mut val_loss = 0.0;
1046            for logits_tensor in &val_outputs {
1047                let logits_vec: Vec<f32> = tensor_to_vec_f32(logits_tensor.clone().inner());
1048                val_loss += compute_nll_logistic_hazard_loss(
1049                    &logits_vec,
1050                    &duration_bins,
1051                    event,
1052                    config.num_durations,
1053                    &val_indices,
1054                );
1055            }
1056            val_loss_history.push(val_loss);
1057
1058            if val_loss < best_val_loss {
1059                best_val_loss = val_loss;
1060                epochs_without_improvement = 0;
1061                best_weights = Some(extract_weights(&model, config, cat_cardinalities));
1062            } else {
1063                epochs_without_improvement += 1;
1064            }
1065
1066            if let Some(patience) = config.early_stopping_patience
1067                && epochs_without_improvement >= patience
1068            {
1069                break;
1070            }
1071        }
1072    }
1073
1074    let final_weights =
1075        best_weights.unwrap_or_else(|| extract_weights(&model, config, cat_cardinalities));
1076
1077    SurvTrace {
1078        weights: final_weights,
1079        config: config.clone(),
1080        duration_cuts: cuts,
1081        train_loss: train_loss_history,
1082        val_loss: val_loss_history,
1083        cat_cardinalities: cat_cardinalities.to_vec(),
1084    }
1085}
1086
/// A fitted discrete-time SurvTrace survival model, exposed to Python.
///
/// Holds the learned network parameters plus everything needed to score new
/// observations: the time-bin edges used to discretize durations and the
/// cardinalities of the categorical features seen at fit time.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvTrace {
    // Learned network parameters, stored in a plain (non-burn) form so the
    // model can predict without an autodiff backend.
    weights: StoredWeights,
    // Hyperparameters the model was fitted with (num_durations, hidden_size, ...).
    config: SurvTraceConfig,
    /// Edges of the duration bins used to discretize event times
    /// (length num_durations + 1).
    #[pyo3(get)]
    pub duration_cuts: Vec<f64>,
    /// Average training loss per epoch.
    #[pyo3(get)]
    pub train_loss: Vec<f64>,
    /// Validation loss per epoch (empty when validation_fraction is 0).
    #[pyo3(get)]
    pub val_loss: Vec<f64>,
    /// Number of categories for each categorical feature, in feature order.
    #[pyo3(get)]
    pub cat_cardinalities: Vec<usize>,
}
1101
#[pymethods]
impl SurvTrace {
    /// Train a SurvTrace model.
    ///
    /// Features arrive flattened row-major: `x_num` holds
    /// `n_obs * num_num_features` values and `x_cat` (when present) holds
    /// `n_obs * num_cat_features` integer category codes. `time` and `event`
    /// give per-observation follow-up time and event label (0 = censored).
    /// All input lengths are validated up front with `ValueError`s; the
    /// training itself runs with the GIL released (`py.detach`).
    #[staticmethod]
    #[pyo3(signature = (x_cat, x_num, n_obs, num_cat_features, num_num_features, cat_cardinalities, time, event, config))]
    pub fn fit(
        py: Python<'_>,
        x_cat: Option<Vec<i64>>,
        x_num: Vec<f64>,
        n_obs: usize,
        num_cat_features: usize,
        num_num_features: usize,
        cat_cardinalities: Vec<usize>,
        time: Vec<f64>,
        event: Vec<i32>,
        config: &SurvTraceConfig,
    ) -> PyResult<Self> {
        // Length check only applies when numeric features exist; the
        // `.max(1)` mirrors the padded layout used during training.
        if x_num.len() != n_obs * num_num_features.max(1) && num_num_features > 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "x_num length must equal n_obs * num_num_features",
            ));
        }
        if time.len() != n_obs || event.len() != n_obs {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "time and event must have length n_obs",
            ));
        }
        if let Some(ref cats) = x_cat
            && cats.len() != n_obs * num_cat_features
        {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "x_cat length must equal n_obs * num_cat_features",
            ));
        }

        // Clone what the training closure needs so it can be `move`d across
        // the GIL-released boundary.
        let config = config.clone();
        let x_cat_clone = x_cat.clone();
        Ok(py.detach(move || {
            fit_survtrace_inner(
                x_cat_clone.as_deref(),
                &x_num,
                n_obs,
                num_cat_features,
                num_num_features,
                &cat_cardinalities,
                &time,
                &event,
                &config,
            )
        }))
    }

    /// Discrete-time hazard curves for `n_new` new observations.
    ///
    /// Returns one row per observation with `num_durations` hazards, obtained
    /// by applying the logistic sigmoid to the network logits of the
    /// requested competing event (`event_idx`). Raises `ValueError` when
    /// `event_idx` is out of range.
    #[pyo3(signature = (x_cat, x_num, n_new, event_idx=0))]
    pub fn predict_hazard(
        &self,
        x_cat: Option<Vec<i64>>,
        x_num: Vec<f64>,
        n_new: usize,
        event_idx: usize,
    ) -> PyResult<Vec<Vec<f64>>> {
        let outputs = predict_with_weights(
            x_cat.as_deref(),
            &x_num,
            n_new,
            &self.weights,
            self.config.layer_norm_eps,
        );

        if event_idx >= outputs.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "event_idx out of range",
            ));
        }

        // Logits are flat, indexed [observation * num_durations + bin].
        let logits = &outputs[event_idx];
        let num_durations = self.config.num_durations;

        let hazards: Vec<Vec<f64>> = (0..n_new)
            .map(|i| {
                (0..num_durations)
                    .map(|t| {
                        let logit = logits[i * num_durations + t];
                        // Sigmoid: hazard in (0, 1).
                        1.0 / (1.0 + (-logit).exp())
                    })
                    .collect()
            })
            .collect();

        Ok(hazards)
    }

    /// Survival curves for `n_new` new observations.
    ///
    /// S(t) is the running product of (1 - hazard) over the duration bins up
    /// to and including bin t, for the given competing event.
    #[pyo3(signature = (x_cat, x_num, n_new, event_idx=0))]
    pub fn predict_survival(
        &self,
        x_cat: Option<Vec<i64>>,
        x_num: Vec<f64>,
        n_new: usize,
        event_idx: usize,
    ) -> PyResult<Vec<Vec<f64>>> {
        let hazards = self.predict_hazard(x_cat, x_num, n_new, event_idx)?;

        let survival: Vec<Vec<f64>> = hazards
            .par_iter()
            .map(|h| {
                let mut surv = Vec::with_capacity(h.len());
                let mut cum_surv = 1.0;
                for &haz in h {
                    cum_surv *= 1.0 - haz;
                    surv.push(cum_surv);
                }
                surv
            })
            .collect();

        Ok(survival)
    }

    /// Scalar risk score per observation: 1 - S(T), where S(T) is the
    /// survival probability at the last duration bin (1.0 if the curve is
    /// empty, i.e. risk 0).
    #[pyo3(signature = (x_cat, x_num, n_new, event_idx=0))]
    pub fn predict_risk(
        &self,
        x_cat: Option<Vec<i64>>,
        x_num: Vec<f64>,
        n_new: usize,
        event_idx: usize,
    ) -> PyResult<Vec<f64>> {
        let survival = self.predict_survival(x_cat, x_num, n_new, event_idx)?;

        let risks: Vec<f64> = survival
            .par_iter()
            .map(|s| {
                let final_surv = s.last().copied().unwrap_or(1.0);
                1.0 - final_surv
            })
            .collect();

        Ok(risks)
    }

    /// Cumulative incidence functions under competing risks.
    ///
    /// Output is indexed `[observation][event][duration_bin]`. For each
    /// observation, overall survival is built from the total hazard summed
    /// across all events (clamped at 1), and
    /// CIF_k(t) = sum over s <= t of S_overall(s-) * hazard_k(s).
    #[pyo3(signature = (x_cat, x_num, n_new))]
    pub fn predict_cumulative_incidence(
        &self,
        x_cat: Option<Vec<i64>>,
        x_num: Vec<f64>,
        n_new: usize,
    ) -> PyResult<Vec<Vec<Vec<f64>>>> {
        let num_events = self.weights.num_events;
        let num_durations = self.config.num_durations;

        let outputs = predict_with_weights(
            x_cat.as_deref(),
            &x_num,
            n_new,
            &self.weights,
            self.config.layer_norm_eps,
        );

        // Per-event sigmoid hazards, indexed [event][observation][bin].
        let mut all_hazards: Vec<Vec<Vec<f64>>> = Vec::new();
        for event_idx in 0..num_events {
            let logits = &outputs[event_idx];
            let hazards: Vec<Vec<f64>> = (0..n_new)
                .map(|i| {
                    (0..num_durations)
                        .map(|t| {
                            let logit = logits[i * num_durations + t];
                            1.0 / (1.0 + (-logit).exp())
                        })
                        .collect()
                })
                .collect();
            all_hazards.push(hazards);
        }

        let cifs: Vec<Vec<Vec<f64>>> = (0..n_new)
            .into_par_iter()
            .map(|i| {
                // overall_surv[t] is survival just before bin t; entry 0 is 1.
                let mut overall_surv = vec![1.0; num_durations + 1];
                for t in 0..num_durations {
                    let mut total_haz = 0.0;
                    for event_idx in 0..num_events {
                        total_haz += all_hazards[event_idx][i][t];
                    }
                    // Clamp: summed hazards may exceed 1 numerically.
                    overall_surv[t + 1] = overall_surv[t] * (1.0 - total_haz.min(1.0));
                }

                let mut event_cifs = Vec::new();
                for event_idx in 0..num_events {
                    let mut cif = Vec::with_capacity(num_durations);
                    let mut cum_inc = 0.0;
                    for t in 0..num_durations {
                        cum_inc += overall_surv[t] * all_hazards[event_idx][i][t];
                        cif.push(cum_inc);
                    }
                    event_cifs.push(cif);
                }
                event_cifs
            })
            .collect();

        Ok(cifs)
    }

    // NOTE: PyO3 strips the `get_` prefix, so these appear in Python as the
    // properties `num_events`, `num_durations`, `hidden_size`, `num_layers`.

    /// Number of competing events the model was trained for.
    #[getter]
    pub fn get_num_events(&self) -> usize {
        self.weights.num_events
    }

    /// Number of discrete duration bins.
    #[getter]
    pub fn get_num_durations(&self) -> usize {
        self.config.num_durations
    }

    /// Transformer hidden size.
    #[getter]
    pub fn get_hidden_size(&self) -> usize {
        self.config.hidden_size
    }

    /// Number of transformer hidden layers.
    #[getter]
    pub fn get_num_layers(&self) -> usize {
        self.config.num_hidden_layers
    }
}
1322
1323#[pyfunction]
1324#[pyo3(signature = (x_cat, x_num, n_obs, num_cat_features, num_num_features, cat_cardinalities, time, event, config=None))]
1325pub fn survtrace(
1326    py: Python<'_>,
1327    x_cat: Option<Vec<i64>>,
1328    x_num: Vec<f64>,
1329    n_obs: usize,
1330    num_cat_features: usize,
1331    num_num_features: usize,
1332    cat_cardinalities: Vec<usize>,
1333    time: Vec<f64>,
1334    event: Vec<i32>,
1335    config: Option<&SurvTraceConfig>,
1336) -> PyResult<SurvTrace> {
1337    let cfg = config.cloned().unwrap_or_else(|| {
1338        SurvTraceConfig::new(
1339            16, 3, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12,
1340        )
1341        .unwrap()
1342    });
1343
1344    SurvTrace::fit(
1345        py,
1346        x_cat,
1347        x_num,
1348        n_obs,
1349        num_cat_features,
1350        num_num_features,
1351        cat_cardinalities,
1352        time,
1353        event,
1354        &cfg,
1355    )
1356}
1357
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared small/fast hyperparameter set for the end-to-end smoke tests.
    /// The three training tests previously triplicated this literal; only
    /// `num_events` differs between them.
    fn tiny_config(num_events: usize) -> SurvTraceConfig {
        SurvTraceConfig {
            hidden_size: 8,
            num_hidden_layers: 1,
            num_attention_heads: 2,
            intermediate_size: 16,
            hidden_dropout_prob: 0.0,
            attention_dropout_prob: 0.0,
            num_durations: 3,
            num_events,
            vocab_size: 4,
            learning_rate: 0.01,
            batch_size: 6,
            n_epochs: 3,
            weight_decay: 0.0,
            seed: Some(42),
            early_stopping_patience: None,
            validation_fraction: 0.0,
            layer_norm_eps: 1e-12,
        }
    }

    // Constructor stores the given hyperparameters.
    #[test]
    fn test_config_default() {
        let config = SurvTraceConfig::new(
            16,
            3,
            2,
            64,
            0.0,
            0.1,
            5,
            1,
            8,
            0.001,
            64,
            100,
            0.0001,
            Some(42),
            Some(5),
            0.1,
            1e-12,
        )
        .unwrap();
        assert_eq!(config.hidden_size, 16);
        assert_eq!(config.num_hidden_layers, 3);
        assert_eq!(config.num_attention_heads, 2);
    }

    // Invalid hyperparameters (zero hidden size, hidden size not divisible
    // by head count, zero layers) are rejected.
    #[test]
    fn test_config_validation() {
        assert!(
            SurvTraceConfig::new(
                0, 3, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12
            )
            .is_err()
        );
        assert!(
            SurvTraceConfig::new(
                15, 3, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12
            )
            .is_err()
        );
        assert!(
            SurvTraceConfig::new(
                16, 0, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12
            )
            .is_err()
        );
    }

    // Single-event training on numeric features only.
    #[test]
    fn test_survtrace_basic() {
        let x_num = vec![1.0, 0.5, 0.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 1.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let event = vec![1, 1, 0, 1, 0, 1];

        let config = tiny_config(1);

        let model = fit_survtrace_inner(None, &x_num, 6, 0, 2, &[], &time, &event, &config);
        assert_eq!(model.get_num_events(), 1);
        assert_eq!(model.get_num_durations(), 3);
        assert!(!model.train_loss.is_empty());
    }

    // Training with two binary categorical features alongside numerics.
    #[test]
    fn test_survtrace_with_categorical() {
        let x_cat = vec![0i64, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1];
        let x_num = vec![1.0, 0.5, 0.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 1.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let event = vec![1, 1, 0, 1, 0, 1];
        let cat_cardinalities = vec![2, 2];

        let config = tiny_config(1);

        let model = fit_survtrace_inner(
            Some(&x_cat),
            &x_num,
            6,
            2,
            2,
            &cat_cardinalities,
            &time,
            &event,
            &config,
        );
        assert_eq!(model.get_num_events(), 1);
    }

    // Competing risks: two event types plus censoring.
    #[test]
    fn test_survtrace_competing_risks() {
        let x_num = vec![1.0, 0.5, 0.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 1.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let event = vec![1, 2, 0, 1, 2, 1];

        let config = tiny_config(2);

        let model = fit_survtrace_inner(None, &x_num, 6, 0, 2, &[], &time, &event, &config);
        assert_eq!(model.get_num_events(), 2);
    }

    // Discretization yields one bin index per time and n_bins + 1 cut points.
    #[test]
    fn test_duration_bins() {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let (bins, cuts) = compute_duration_bins(&times, 5);

        assert_eq!(bins.len(), 10);
        assert_eq!(cuts.len(), 6);

        for &bin in &bins {
            assert!(bin < 5);
        }
    }

    // NLL logistic-hazard loss is finite and non-negative.
    #[test]
    fn test_nll_loss() {
        let logits = vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.4];
        let durations = vec![1, 0, 2];
        let events = vec![1, 0, 1];
        let indices: Vec<usize> = vec![0, 1, 2];

        let loss = compute_nll_logistic_hazard_loss(&logits, &durations, &events, 2, &indices);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
    }

    // GELU of a positive input lies strictly between 0 and the input.
    #[test]
    fn test_gelu_cpu() {
        let x = 0.5;
        let result = gelu_cpu(x);
        assert!(result > 0.0);
        assert!(result < x);
    }

    // Layer norm with unit gamma / zero beta produces zero-mean output.
    #[test]
    fn test_layer_norm_cpu() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0f32, 1.0, 1.0, 1.0];
        let beta = vec![0.0f32, 0.0, 0.0, 0.0];

        let result = layer_norm_cpu(&x, &gamma, &beta, 1e-12);

        assert_eq!(result.len(), 4);
        let mean: f64 = result.iter().sum::<f64>() / 4.0;
        assert!((mean).abs() < 1e-6);
    }
}