1#![allow(
2 unused_variables,
3 unused_imports,
4 unused_mut,
5 unused_assignments,
6 clippy::too_many_arguments,
7 clippy::needless_range_loop
8)]
9
10use burn::{
11 backend::{Autodiff, NdArray},
12 module::Module,
13 nn::{Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig},
14 optim::{AdamConfig, GradientsParams, Optimizer},
15 prelude::*,
16 tensor::{activation::softplus, backend::AutodiffBackend as AutodiffBackendTrait},
17};
18use pyo3::prelude::*;
19use rayon::prelude::*;
20
21use super::utils::{
22 compute_duration_bins, gelu_cpu, layer_norm_cpu, linear_forward, tensor_to_vec_f32,
23};
24
// CPU ndarray backend; training wraps it in Autodiff for reverse-mode gradients.
type Backend = NdArray;
type AutodiffBackend = Autodiff<Backend>;
27
28fn gelu<B: burn::prelude::Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
29 let sqrt_2 = (2.0_f32).sqrt();
30 let cdf = (x.clone() / sqrt_2).erf().add_scalar(1.0) * 0.5;
31 x * cdf
32}
33
34fn layer_norm<B: burn::prelude::Backend>(
35 x: Tensor<B, 2>,
36 gamma: Tensor<B, 1>,
37 beta: Tensor<B, 1>,
38 eps: f32,
39) -> Tensor<B, 2> {
40 let [batch, hidden] = x.dims();
41 let mean = x.clone().mean_dim(1);
42 let var = x.clone().var(1);
43 let x_norm = (x - mean) / (var + eps).sqrt();
44 let gamma_expanded: Tensor<B, 2> = gamma.reshape([1, hidden]);
45 let beta_expanded: Tensor<B, 2> = beta.reshape([1, hidden]);
46 x_norm * gamma_expanded + beta_expanded
47}
48
/// Activation choice exposed to Python.
///
/// NOTE(review): the network code visible in this file always applies GELU;
/// where the ReLU variant takes effect is not visible here — confirm in the
/// consuming code.
#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass]
pub enum SurvTraceActivation {
    GELU,
    ReLU,
}
55
56#[pymethods]
57impl SurvTraceActivation {
58 #[new]
59 fn new(name: &str) -> PyResult<Self> {
60 match name.to_lowercase().as_str() {
61 "gelu" => Ok(SurvTraceActivation::GELU),
62 "relu" => Ok(SurvTraceActivation::ReLU),
63 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
64 "Unknown activation. Use 'gelu' or 'relu'",
65 )),
66 }
67 }
68}
69
/// Hyper-parameters for the SurvTrace transformer and its training loop.
///
/// Every field is exposed to Python as a read/write attribute.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvTraceConfig {
    /// Width of the shared hidden representation.
    #[pyo3(get, set)]
    pub hidden_size: usize,
    /// Number of stacked transformer layers.
    #[pyo3(get, set)]
    pub num_hidden_layers: usize,
    /// Attention heads per layer; must divide `hidden_size`.
    #[pyo3(get, set)]
    pub num_attention_heads: usize,
    /// Width of the feed-forward (intermediate) layer.
    #[pyo3(get, set)]
    pub intermediate_size: usize,
    /// Dropout applied to attention/FFN outputs during training.
    #[pyo3(get, set)]
    pub hidden_dropout_prob: f64,
    /// Dropout applied to the attention weights during training.
    #[pyo3(get, set)]
    pub attention_dropout_prob: f64,
    /// Number of discrete time bins for the hazard model.
    #[pyo3(get, set)]
    pub num_durations: usize,
    /// Number of competing events (one output head each; 0 is treated as 1).
    #[pyo3(get, set)]
    pub num_events: usize,
    /// NOTE(review): not referenced by the training/inference code in this
    /// file — confirm whether it is still needed.
    #[pyo3(get, set)]
    pub vocab_size: usize,
    /// Adam learning rate.
    #[pyo3(get, set)]
    pub learning_rate: f64,
    /// Mini-batch size.
    #[pyo3(get, set)]
    pub batch_size: usize,
    /// Maximum number of training epochs.
    #[pyo3(get, set)]
    pub n_epochs: usize,
    /// Adam weight decay.
    #[pyo3(get, set)]
    pub weight_decay: f64,
    /// RNG seed; defaults to 42 when `None`.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
    /// Stop after this many epochs without validation improvement (if set).
    #[pyo3(get, set)]
    pub early_stopping_patience: Option<usize>,
    /// Fraction of rows held out for validation, in [0, 1).
    #[pyo3(get, set)]
    pub validation_fraction: f64,
    /// Epsilon added to the variance in layer normalization.
    #[pyo3(get, set)]
    pub layer_norm_eps: f32,
}
108
109#[pymethods]
110impl SurvTraceConfig {
111 #[new]
112 #[pyo3(signature = (
113 hidden_size=16,
114 num_hidden_layers=3,
115 num_attention_heads=2,
116 intermediate_size=64,
117 hidden_dropout_prob=0.0,
118 attention_dropout_prob=0.1,
119 num_durations=5,
120 num_events=1,
121 vocab_size=8,
122 learning_rate=0.001,
123 batch_size=64,
124 n_epochs=100,
125 weight_decay=0.0001,
126 seed=None,
127 early_stopping_patience=None,
128 validation_fraction=0.1,
129 layer_norm_eps=1e-12
130 ))]
131 pub fn new(
132 hidden_size: usize,
133 num_hidden_layers: usize,
134 num_attention_heads: usize,
135 intermediate_size: usize,
136 hidden_dropout_prob: f64,
137 attention_dropout_prob: f64,
138 num_durations: usize,
139 num_events: usize,
140 vocab_size: usize,
141 learning_rate: f64,
142 batch_size: usize,
143 n_epochs: usize,
144 weight_decay: f64,
145 seed: Option<u64>,
146 early_stopping_patience: Option<usize>,
147 validation_fraction: f64,
148 layer_norm_eps: f32,
149 ) -> PyResult<Self> {
150 if hidden_size == 0 {
151 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
152 "hidden_size must be positive",
153 ));
154 }
155 if num_hidden_layers == 0 {
156 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
157 "num_hidden_layers must be positive",
158 ));
159 }
160 if num_attention_heads == 0 {
161 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
162 "num_attention_heads must be positive",
163 ));
164 }
165 if !hidden_size.is_multiple_of(num_attention_heads) {
166 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
167 "hidden_size must be divisible by num_attention_heads",
168 ));
169 }
170 if num_durations == 0 {
171 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
172 "num_durations must be positive",
173 ));
174 }
175 if batch_size == 0 {
176 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
177 "batch_size must be positive",
178 ));
179 }
180 if n_epochs == 0 {
181 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
182 "n_epochs must be positive",
183 ));
184 }
185 if !(0.0..1.0).contains(&validation_fraction) {
186 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
187 "validation_fraction must be in [0, 1)",
188 ));
189 }
190
191 Ok(SurvTraceConfig {
192 hidden_size,
193 num_hidden_layers,
194 num_attention_heads,
195 intermediate_size,
196 hidden_dropout_prob,
197 attention_dropout_prob,
198 num_durations,
199 num_events,
200 vocab_size,
201 learning_rate,
202 batch_size,
203 n_epochs,
204 weight_decay,
205 seed,
206 early_stopping_patience,
207 validation_fraction,
208 layer_norm_eps,
209 })
210 }
211}
212
/// Multi-head self-attention over rank-2 `[batch, hidden]` inputs.
///
/// `forward` treats each row as a sequence of length 1 (see the comment
/// there), so the attention is degenerate by construction.
#[derive(Module, Debug)]
struct MultiHeadAttention<B: burn::prelude::Backend> {
    query: Linear<B>,
    key: Linear<B>,
    value: Linear<B>,
    output: Linear<B>,
    dropout: Dropout,
    num_heads: usize,
    /// `hidden_size / num_heads` (divisibility validated in the config).
    head_dim: usize,
}
223
impl<B: burn::prelude::Backend> MultiHeadAttention<B> {
    /// Build Q/K/V and output projections; `head_dim = hidden_size / num_heads`
    /// (the config validates divisibility before this is reached).
    fn new(device: &B::Device, hidden_size: usize, num_heads: usize, dropout_prob: f64) -> Self {
        let head_dim = hidden_size / num_heads;

        Self {
            query: LinearConfig::new(hidden_size, hidden_size).init(device),
            key: LinearConfig::new(hidden_size, hidden_size).init(device),
            value: LinearConfig::new(hidden_size, hidden_size).init(device),
            output: LinearConfig::new(hidden_size, hidden_size).init(device),
            dropout: DropoutConfig::new(dropout_prob).init(),
            num_heads,
            head_dim,
        }
    }

    /// Self-attention over `[batch, hidden]` rows treated as length-1 sequences.
    ///
    /// Because `seq_len` is fixed at 1, `scores` has shape
    /// `[batch, heads, 1, 1]` and the softmax over dim 3 is over a single
    /// element, so every attention weight is exactly 1 — the attention reduces
    /// to the V projection followed by dropout and the output projection.
    fn forward(&self, x: Tensor<B, 2>, training: bool) -> Tensor<B, 2> {
        let [batch_size, hidden_size] = x.dims();
        // Each sample is treated as a single "token".
        let seq_len = 1;

        let q = self.query.forward(x.clone());
        let k = self.key.forward(x.clone());
        let v = self.value.forward(x);

        // [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        let q = q
            .reshape([batch_size, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);
        let k = k
            .reshape([batch_size, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);
        let v = v
            .reshape([batch_size, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);

        // Scaled dot-product attention.
        let scale = (self.head_dim as f32).sqrt();
        let scores = q.matmul(k.swap_dims(2, 3)) / scale;
        let attn_weights = burn::tensor::activation::softmax(scores, 3);

        // Attention dropout only during training.
        let attn_weights = if training {
            self.dropout.forward(attn_weights)
        } else {
            attn_weights
        };

        let context = attn_weights.matmul(v);
        // Merge heads back into the hidden dimension.
        let context = context.swap_dims(1, 2).reshape([batch_size, hidden_size]);

        self.output.forward(context)
    }
}
273
/// One post-norm transformer encoder block.
///
/// Layer-norm scale/shift are stored as raw `Param<Tensor>` (rather than a
/// LayerNorm module) so they can be exported into `StoredWeights` for the
/// backend-free CPU inference path.
#[derive(Module, Debug)]
struct TransformerLayer<B: burn::prelude::Backend> {
    attention: MultiHeadAttention<B>,
    /// Feed-forward up-projection (hidden -> intermediate).
    intermediate: Linear<B>,
    /// Feed-forward down-projection (intermediate -> hidden).
    output_dense: Linear<B>,
    layer_norm1_gamma: burn::module::Param<Tensor<B, 1>>,
    layer_norm1_beta: burn::module::Param<Tensor<B, 1>>,
    layer_norm2_gamma: burn::module::Param<Tensor<B, 1>>,
    layer_norm2_beta: burn::module::Param<Tensor<B, 1>>,
    dropout: Dropout,
    /// Variance epsilon shared by both layer norms.
    layer_norm_eps: f32,
}
286
287impl<B: burn::prelude::Backend> TransformerLayer<B> {
288 fn new(
289 device: &B::Device,
290 hidden_size: usize,
291 num_heads: usize,
292 intermediate_size: usize,
293 hidden_dropout_prob: f64,
294 attention_dropout_prob: f64,
295 layer_norm_eps: f32,
296 ) -> Self {
297 Self {
298 attention: MultiHeadAttention::new(
299 device,
300 hidden_size,
301 num_heads,
302 attention_dropout_prob,
303 ),
304 intermediate: LinearConfig::new(hidden_size, intermediate_size).init(device),
305 output_dense: LinearConfig::new(intermediate_size, hidden_size).init(device),
306 layer_norm1_gamma: burn::module::Param::from_tensor(Tensor::ones(
307 [hidden_size],
308 device,
309 )),
310 layer_norm1_beta: burn::module::Param::from_tensor(Tensor::zeros(
311 [hidden_size],
312 device,
313 )),
314 layer_norm2_gamma: burn::module::Param::from_tensor(Tensor::ones(
315 [hidden_size],
316 device,
317 )),
318 layer_norm2_beta: burn::module::Param::from_tensor(Tensor::zeros(
319 [hidden_size],
320 device,
321 )),
322 dropout: DropoutConfig::new(hidden_dropout_prob).init(),
323 layer_norm_eps,
324 }
325 }
326
327 fn forward(&self, x: Tensor<B, 2>, training: bool) -> Tensor<B, 2> {
328 let attn_output = self.attention.forward(x.clone(), training);
329 let attn_output = if training {
330 self.dropout.forward(attn_output)
331 } else {
332 attn_output
333 };
334 let x = layer_norm(
335 x + attn_output,
336 self.layer_norm1_gamma.val(),
337 self.layer_norm1_beta.val(),
338 self.layer_norm_eps,
339 );
340
341 let intermediate = self.intermediate.forward(x.clone());
342 let intermediate = gelu(intermediate);
343 let output = self.output_dense.forward(intermediate);
344 let output = if training {
345 self.dropout.forward(output)
346 } else {
347 output
348 };
349
350 layer_norm(
351 x + output,
352 self.layer_norm2_gamma.val(),
353 self.layer_norm2_beta.val(),
354 self.layer_norm_eps,
355 )
356 }
357}
358
/// Full SurvTrace network: feature embedding + transformer stack + one
/// duration-logit head per event.
#[derive(Module, Debug)]
struct SurvTraceNetwork<B: burn::prelude::Backend> {
    /// One embedding table per categorical feature.
    cat_embeddings: Vec<Embedding<B>>,
    /// Linear projection of numeric features (dummy 1-wide layer when none).
    num_projection: Linear<B>,
    transformer_layers: Vec<TransformerLayer<B>>,
    /// One `[hidden -> num_durations]` head per event.
    output_heads: Vec<Linear<B>>,
    hidden_size: usize,
    num_cat_features: usize,
    num_num_features: usize,
    /// Always >= 1 (clamped in `new`).
    num_events: usize,
    num_durations: usize,
}
371
372impl<B: burn::prelude::Backend> SurvTraceNetwork<B> {
373 fn new(
374 device: &B::Device,
375 num_cat_features: usize,
376 num_num_features: usize,
377 cat_cardinalities: &[usize],
378 config: &SurvTraceConfig,
379 ) -> Self {
380 let mut cat_embeddings = Vec::new();
381 for &card in cat_cardinalities {
382 cat_embeddings.push(EmbeddingConfig::new(card.max(2), config.hidden_size).init(device));
383 }
384
385 let num_projection = if num_num_features > 0 {
386 LinearConfig::new(num_num_features, config.hidden_size).init(device)
387 } else {
388 LinearConfig::new(1, config.hidden_size).init(device)
389 };
390
391 let mut transformer_layers = Vec::new();
392 for _ in 0..config.num_hidden_layers {
393 transformer_layers.push(TransformerLayer::new(
394 device,
395 config.hidden_size,
396 config.num_attention_heads,
397 config.intermediate_size,
398 config.hidden_dropout_prob,
399 config.attention_dropout_prob,
400 config.layer_norm_eps,
401 ));
402 }
403
404 let mut output_heads = Vec::new();
405 let num_events = config.num_events.max(1);
406 for _ in 0..num_events {
407 output_heads
408 .push(LinearConfig::new(config.hidden_size, config.num_durations).init(device));
409 }
410
411 Self {
412 cat_embeddings,
413 num_projection,
414 transformer_layers,
415 output_heads,
416 hidden_size: config.hidden_size,
417 num_cat_features,
418 num_num_features,
419 num_events,
420 num_durations: config.num_durations,
421 }
422 }
423
424 fn forward(
425 &self,
426 x_cat: Option<Tensor<B, 2, Int>>,
427 x_num: Tensor<B, 2>,
428 training: bool,
429 ) -> Vec<Tensor<B, 2>> {
430 let [batch_size, _] = x_num.dims();
431 let device = x_num.device();
432
433 let mut embeddings: Tensor<B, 2> = Tensor::zeros([batch_size, self.hidden_size], &device);
434
435 if let Some(x_cat) = x_cat {
436 for (i, emb) in self.cat_embeddings.iter().enumerate() {
437 let cat_slice: Tensor<B, 2, Int> = x_cat.clone().slice([0..batch_size, i..i + 1]);
438 let cat_emb_3d: Tensor<B, 3> = emb.forward(cat_slice);
439 let cat_emb: Tensor<B, 2> = cat_emb_3d.squeeze::<2>();
440 embeddings = embeddings + cat_emb;
441 }
442 }
443
444 if self.num_num_features > 0 {
445 let num_emb = self.num_projection.forward(x_num);
446 embeddings = embeddings + num_emb;
447 }
448
449 let mut hidden = embeddings;
450 for layer in &self.transformer_layers {
451 hidden = layer.forward(hidden, training);
452 }
453
454 let mut outputs = Vec::new();
455 for head in &self.output_heads {
456 let logits = head.forward(hidden.clone());
457 outputs.push(logits);
458 }
459
460 outputs
461 }
462
463 fn forward_inference(
464 &self,
465 x_cat: Option<Tensor<B, 2, Int>>,
466 x_num: Tensor<B, 2>,
467 ) -> Vec<Tensor<B, 2>> {
468 self.forward(x_cat, x_num, false)
469 }
470}
471
/// Negative log-likelihood of the discrete-time logistic-hazard model
/// (BCE-with-logits summed over each subject's observed time bins).
///
/// `logits` is a flat `[batch, num_durations]` buffer for the rows named by
/// `batch_indices`; `durations`/`events` are indexed by the *global* indices.
/// The total is averaged over the number of events in the batch, or over the
/// batch size when no events are present.
fn compute_nll_logistic_hazard_loss(
    logits: &[f32],
    durations: &[usize],
    events: &[i32],
    num_durations: usize,
    batch_indices: &[usize],
) -> f64 {
    let mut total_loss = 0.0;
    let mut n_events = 0;

    for (local_idx, &global_idx) in batch_indices.iter().enumerate() {
        let duration_bin = durations[global_idx].min(num_durations - 1);
        let event = events[global_idx];

        for t in 0..=duration_bin {
            let logit = logits[local_idx * num_durations + t];
            // Hazard target: 1 only in the event bin of an uncensored subject.
            let target: f32 = if t == duration_bin && event == 1 {
                1.0
            } else {
                0.0
            };

            // Numerically stable BCE-with-logits:
            //   max(z, 0) - z*y + ln(1 + e^{-|z|})
            // The naive ln(1 + e^{-z}) overflows f32 to +inf for z << 0.
            let loss = logit.max(0.0) - logit * target + (1.0 + (-logit.abs()).exp()).ln();
            total_loss += loss as f64;
        }

        if event == 1 {
            n_events += 1;
        }
    }

    if n_events > 0 {
        total_loss / n_events as f64
    } else {
        total_loss / batch_indices.len().max(1) as f64
    }
}
513
/// Gradient of `compute_nll_logistic_hazard_loss` with respect to the logits.
///
/// Returns a flat `[batch, num_durations]` buffer; entries past a subject's
/// duration bin stay zero because those bins do not enter the loss.
fn compute_nll_logistic_hazard_gradient(
    logits: &[f32],
    durations: &[usize],
    events: &[i32],
    num_durations: usize,
    batch_indices: &[usize],
) -> Vec<f32> {
    let batch_size = batch_indices.len();
    let mut gradients = vec![0.0f32; batch_size * num_durations];

    for (local_idx, &global_idx) in batch_indices.iter().enumerate() {
        let duration_bin = durations[global_idx].min(num_durations - 1);
        let event = events[global_idx];

        for t in 0..=duration_bin {
            let logit = logits[local_idx * num_durations + t];
            // d/dz of BCE-with-logits is sigmoid(z) - target.
            let pred = 1.0 / (1.0 + (-logit).exp());
            let target = if t == duration_bin && event == 1 {
                1.0
            } else {
                0.0
            };
            gradients[local_idx * num_durations + t] = pred - target;
        }
    }

    // Count events exactly the way the loss does (event == 1) instead of
    // summing raw event codes, so loss and gradient stay consistent when
    // event codes other than 0/1 appear (e.g. competing risks).
    let n_events = batch_indices.iter().filter(|&&i| events[i] == 1).count();
    let divisor = if n_events > 0 {
        n_events as f32
    } else {
        batch_size.max(1) as f32
    };

    for g in &mut gradients {
        *g /= divisor;
    }

    gradients
}
553
/// Flattened, backend-free snapshot of every model parameter, used for
/// best-epoch checkpointing and for CPU inference after training.
#[derive(Clone)]
struct StoredWeights {
    /// One flat embedding table per categorical feature.
    cat_embeddings: Vec<Vec<f32>>,
    /// `(vocab_size, embedding_dim)` per categorical feature.
    cat_embedding_dims: Vec<(usize, usize)>,
    /// Flat numeric-projection weights.  NOTE(review): the CPU forward pass
    /// indexes these as `w[out * in + k]` — confirm burn's `Linear` weight
    /// layout matches that orientation.
    num_projection_weights: Vec<f32>,
    /// Numeric-projection bias (empty when the layer has no bias).
    num_projection_bias: Vec<f32>,
    /// `(in_features, out_features)` of the numeric projection.
    num_projection_dims: (usize, usize),
    transformer_layers: Vec<TransformerLayerWeights>,
    /// Per event head: (weights, bias, in_dim, out_dim).
    output_heads: Vec<(Vec<f32>, Vec<f32>, usize, usize)>,
    hidden_size: usize,
    num_cat_features: usize,
    num_num_features: usize,
    num_events: usize,
}
568
/// Flattened weights for one transformer layer, as plain f32 buffers.
#[derive(Clone)]
struct TransformerLayerWeights {
    // Attention projections (weight + bias each).
    query_w: Vec<f32>,
    query_b: Vec<f32>,
    key_w: Vec<f32>,
    key_b: Vec<f32>,
    value_w: Vec<f32>,
    value_b: Vec<f32>,
    output_w: Vec<f32>,
    output_b: Vec<f32>,
    // Feed-forward projections.
    intermediate_w: Vec<f32>,
    intermediate_b: Vec<f32>,
    output_dense_w: Vec<f32>,
    output_dense_b: Vec<f32>,
    // Layer-norm scale/shift pairs.
    ln1_gamma: Vec<f32>,
    ln1_beta: Vec<f32>,
    ln2_gamma: Vec<f32>,
    ln2_beta: Vec<f32>,
    // Shapes needed to interpret the flat buffers above.
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
}
591
// Hand-written Debug: the weight vectors are large flat buffers, so only
// summary fields are printed instead of dumping every f32.
impl std::fmt::Debug for StoredWeights {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StoredWeights")
            .field("num_transformer_layers", &self.transformer_layers.len())
            .field("num_events", &self.num_events)
            .finish()
    }
}
600
601fn extract_weights(
602 model: &SurvTraceNetwork<AutodiffBackend>,
603 config: &SurvTraceConfig,
604 cat_cardinalities: &[usize],
605) -> StoredWeights {
606 let mut cat_embeddings = Vec::new();
607 let mut cat_embedding_dims = Vec::new();
608
609 for (i, emb) in model.cat_embeddings.iter().enumerate() {
610 let w: Vec<f32> = emb
611 .weight
612 .val()
613 .inner()
614 .into_data()
615 .to_vec()
616 .unwrap_or_default();
617 cat_embeddings.push(w);
618 cat_embedding_dims.push((
619 cat_cardinalities.get(i).copied().unwrap_or(2),
620 config.hidden_size,
621 ));
622 }
623
624 let num_proj_w: Tensor<AutodiffBackend, 2> = model.num_projection.weight.val();
625 let num_projection_weights: Vec<f32> = tensor_to_vec_f32(num_proj_w.inner());
626 let num_projection_bias: Vec<f32> = model
627 .num_projection
628 .bias
629 .as_ref()
630 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
631 .unwrap_or_default();
632 let num_projection_dims = (model.num_num_features.max(1), config.hidden_size);
633
634 let mut transformer_layers = Vec::new();
635 for layer in &model.transformer_layers {
636 let tlw = TransformerLayerWeights {
637 query_w: tensor_to_vec_f32(layer.attention.query.weight.val().inner()),
638 query_b: layer
639 .attention
640 .query
641 .bias
642 .as_ref()
643 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
644 .unwrap_or_default(),
645 key_w: tensor_to_vec_f32(layer.attention.key.weight.val().inner()),
646 key_b: layer
647 .attention
648 .key
649 .bias
650 .as_ref()
651 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
652 .unwrap_or_default(),
653 value_w: tensor_to_vec_f32(layer.attention.value.weight.val().inner()),
654 value_b: layer
655 .attention
656 .value
657 .bias
658 .as_ref()
659 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
660 .unwrap_or_default(),
661 output_w: tensor_to_vec_f32(layer.attention.output.weight.val().inner()),
662 output_b: layer
663 .attention
664 .output
665 .bias
666 .as_ref()
667 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
668 .unwrap_or_default(),
669 intermediate_w: tensor_to_vec_f32(layer.intermediate.weight.val().inner()),
670 intermediate_b: layer
671 .intermediate
672 .bias
673 .as_ref()
674 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
675 .unwrap_or_default(),
676 output_dense_w: tensor_to_vec_f32(layer.output_dense.weight.val().inner()),
677 output_dense_b: layer
678 .output_dense
679 .bias
680 .as_ref()
681 .map(|b| b.val().inner().into_data().to_vec().unwrap_or_default())
682 .unwrap_or_default(),
683 ln1_gamma: layer
684 .layer_norm1_gamma
685 .val()
686 .inner()
687 .into_data()
688 .to_vec()
689 .unwrap_or_default(),
690 ln1_beta: layer
691 .layer_norm1_beta
692 .val()
693 .inner()
694 .into_data()
695 .to_vec()
696 .unwrap_or_default(),
697 ln2_gamma: layer
698 .layer_norm2_gamma
699 .val()
700 .inner()
701 .into_data()
702 .to_vec()
703 .unwrap_or_default(),
704 ln2_beta: layer
705 .layer_norm2_beta
706 .val()
707 .inner()
708 .into_data()
709 .to_vec()
710 .unwrap_or_default(),
711 hidden_size: config.hidden_size,
712 intermediate_size: config.intermediate_size,
713 num_heads: config.num_attention_heads,
714 };
715 transformer_layers.push(tlw);
716 }
717
718 let mut output_heads = Vec::new();
719 for head in &model.output_heads {
720 let w: Vec<f32> = tensor_to_vec_f32(head.weight.val().inner());
721 let b: Vec<f32> = head
722 .bias
723 .as_ref()
724 .map(|bias| bias.val().inner().into_data().to_vec().unwrap_or_default())
725 .unwrap_or_default();
726 output_heads.push((w, b, config.hidden_size, config.num_durations));
727 }
728
729 StoredWeights {
730 cat_embeddings,
731 cat_embedding_dims,
732 num_projection_weights,
733 num_projection_bias,
734 num_projection_dims,
735 transformer_layers,
736 output_heads,
737 hidden_size: config.hidden_size,
738 num_cat_features: model.num_cat_features,
739 num_num_features: model.num_num_features,
740 num_events: model.num_events,
741 }
742}
743
744fn predict_with_weights(
745 x_cat: Option<&[i64]>,
746 x_num: &[f64],
747 n: usize,
748 weights: &StoredWeights,
749 layer_norm_eps: f32,
750) -> Vec<Vec<f64>> {
751 let hidden_size = weights.hidden_size;
752 let num_num = weights.num_num_features;
753 let num_cat = weights.num_cat_features;
754
755 let mut all_outputs: Vec<Vec<f64>> = vec![Vec::new(); weights.num_events];
756
757 for i in 0..n {
758 let mut hidden = vec![0.0f64; hidden_size];
759
760 if let Some(cats) = x_cat {
761 for (feat_idx, emb_weights) in weights.cat_embeddings.iter().enumerate() {
762 let (vocab_size, emb_dim) = weights.cat_embedding_dims[feat_idx];
763 let cat_val = cats[i * num_cat + feat_idx] as usize;
764 let cat_val = cat_val.min(vocab_size - 1);
765 for j in 0..emb_dim {
766 hidden[j] += emb_weights[cat_val * emb_dim + j] as f64;
767 }
768 }
769 }
770
771 if num_num > 0 {
772 let (in_dim, out_dim) = weights.num_projection_dims;
773 for j in 0..out_dim {
774 let mut sum = if !weights.num_projection_bias.is_empty() {
775 weights.num_projection_bias[j] as f64
776 } else {
777 0.0
778 };
779 for k in 0..in_dim.min(num_num) {
780 sum += x_num[i * num_num + k]
781 * weights.num_projection_weights[j * in_dim + k] as f64;
782 }
783 hidden[j] += sum;
784 }
785 }
786
787 for layer in &weights.transformer_layers {
788 hidden = apply_transformer_layer_cpu(&hidden, layer, layer_norm_eps);
789 }
790
791 for (event_idx, (w, b, in_dim, out_dim)) in weights.output_heads.iter().enumerate() {
792 let mut logits = Vec::with_capacity(*out_dim);
793 for j in 0..*out_dim {
794 let mut sum = if !b.is_empty() { b[j] as f64 } else { 0.0 };
795 for k in 0..*in_dim {
796 sum += hidden[k] * w[j * in_dim + k] as f64;
797 }
798 logits.push(sum);
799 }
800 all_outputs[event_idx].extend(logits);
801 }
802 }
803
804 all_outputs
805}
806
807fn apply_transformer_layer_cpu(
808 hidden: &[f64],
809 layer: &TransformerLayerWeights,
810 eps: f32,
811) -> Vec<f64> {
812 let h = layer.hidden_size;
813
814 let q = linear_forward(hidden, &layer.query_w, &layer.query_b, h, h);
815 let k = linear_forward(hidden, &layer.key_w, &layer.key_b, h, h);
816 let v = linear_forward(hidden, &layer.value_w, &layer.value_b, h, h);
817
818 let head_dim = h / layer.num_heads;
819 let mut attn_output = vec![0.0f64; h];
820
821 for head in 0..layer.num_heads {
822 let start = head * head_dim;
823 let end = start + head_dim;
824
825 let mut score = 0.0;
826 for i in start..end {
827 score += q[i] * k[i];
828 }
829 score /= (head_dim as f64).sqrt();
830 let attn_weight = 1.0;
831
832 for i in start..end {
833 attn_output[i] = attn_weight * v[i];
834 }
835 }
836
837 let attn_proj = linear_forward(&attn_output, &layer.output_w, &layer.output_b, h, h);
838
839 let mut residual1: Vec<f64> = hidden.iter().zip(&attn_proj).map(|(a, b)| a + b).collect();
840 residual1 = layer_norm_cpu(&residual1, &layer.ln1_gamma, &layer.ln1_beta, eps);
841
842 let intermediate = linear_forward(
843 &residual1,
844 &layer.intermediate_w,
845 &layer.intermediate_b,
846 h,
847 layer.intermediate_size,
848 );
849 let intermediate: Vec<f64> = intermediate.iter().map(|&x| gelu_cpu(x)).collect();
850
851 let output = linear_forward(
852 &intermediate,
853 &layer.output_dense_w,
854 &layer.output_dense_b,
855 layer.intermediate_size,
856 h,
857 );
858
859 let mut residual2: Vec<f64> = residual1.iter().zip(&output).map(|(a, b)| a + b).collect();
860 residual2 = layer_norm_cpu(&residual2, &layer.ln2_gamma, &layer.ln2_beta, eps);
861
862 residual2
863}
864
865fn fit_survtrace_inner(
866 x_cat: Option<&[i64]>,
867 x_num: &[f64],
868 n_obs: usize,
869 num_cat_features: usize,
870 num_num_features: usize,
871 cat_cardinalities: &[usize],
872 time: &[f64],
873 event: &[i32],
874 config: &SurvTraceConfig,
875) -> SurvTrace {
876 let device: <Backend as burn::prelude::Backend>::Device = Default::default();
877 let seed = config.seed.unwrap_or(42);
878
879 let (duration_bins, cuts) = compute_duration_bins(time, config.num_durations);
880
881 let mut model: SurvTraceNetwork<AutodiffBackend> = SurvTraceNetwork::new(
882 &device,
883 num_cat_features,
884 num_num_features,
885 cat_cardinalities,
886 config,
887 );
888
889 let mut optimizer = AdamConfig::new()
890 .with_weight_decay(Some(burn::optim::decay::WeightDecayConfig::new(
891 config.weight_decay as f32,
892 )))
893 .init();
894
895 let n_val = (n_obs as f64 * config.validation_fraction).floor() as usize;
896 let n_train = n_obs - n_val;
897
898 let mut rng = fastrand::Rng::with_seed(seed);
899 let mut shuffled_indices: Vec<usize> = (0..n_obs).collect();
900 for i in (1..n_obs).rev() {
901 let j = rng.usize(0..=i);
902 shuffled_indices.swap(i, j);
903 }
904
905 let train_indices: Vec<usize> = shuffled_indices[..n_train].to_vec();
906 let val_indices: Vec<usize> = shuffled_indices[n_train..].to_vec();
907
908 let mut train_loss_history = Vec::new();
909 let mut val_loss_history = Vec::new();
910 let mut best_val_loss = f64::INFINITY;
911 let mut epochs_without_improvement = 0;
912 let mut best_weights: Option<StoredWeights> = None;
913
914 for epoch in 0..config.n_epochs {
915 let mut epoch_indices = train_indices.clone();
916 for i in (1..epoch_indices.len()).rev() {
917 let j = rng.usize(0..=i);
918 epoch_indices.swap(i, j);
919 }
920
921 let mut epoch_loss = 0.0;
922 let mut n_batches = 0;
923
924 for batch_start in (0..n_train).step_by(config.batch_size) {
925 let batch_end = (batch_start + config.batch_size).min(n_train);
926 let batch_indices: Vec<usize> = epoch_indices[batch_start..batch_end].to_vec();
927 let batch_size = batch_indices.len();
928
929 let x_num_batch: Vec<f32> = batch_indices
930 .iter()
931 .flat_map(|&i| {
932 (0..num_num_features).map(move |j| x_num[i * num_num_features + j] as f32)
933 })
934 .collect();
935
936 let x_num_data = burn::tensor::TensorData::new(
937 x_num_batch.clone(),
938 [batch_size, num_num_features.max(1)],
939 );
940 let x_num_tensor: Tensor<AutodiffBackend, 2> = Tensor::from_data(x_num_data, &device);
941
942 let x_cat_tensor: Option<Tensor<AutodiffBackend, 2, Int>> = if num_cat_features > 0 {
943 if let Some(cats) = x_cat {
944 let x_cat_batch: Vec<i64> = batch_indices
945 .iter()
946 .flat_map(|&i| {
947 (0..num_cat_features).map(move |j| cats[i * num_cat_features + j])
948 })
949 .collect();
950 let x_cat_data =
951 burn::tensor::TensorData::new(x_cat_batch, [batch_size, num_cat_features]);
952 Some(Tensor::from_data(x_cat_data, &device))
953 } else {
954 None
955 }
956 } else {
957 None
958 };
959
960 let outputs = model.forward(x_cat_tensor, x_num_tensor, true);
961
962 let mut total_loss = 0.0;
963 let mut all_grads: Vec<Vec<f32>> = Vec::new();
964
965 for (event_idx, logits_tensor) in outputs.iter().enumerate() {
966 let logits_vec: Vec<f32> = tensor_to_vec_f32(logits_tensor.clone().inner());
967
968 let loss = compute_nll_logistic_hazard_loss(
969 &logits_vec,
970 &duration_bins,
971 event,
972 config.num_durations,
973 &batch_indices,
974 );
975 total_loss += loss;
976
977 let grads = compute_nll_logistic_hazard_gradient(
978 &logits_vec,
979 &duration_bins,
980 event,
981 config.num_durations,
982 &batch_indices,
983 );
984 all_grads.push(grads);
985 }
986
987 epoch_loss += total_loss;
988 n_batches += 1;
989
990 if !all_grads.is_empty() {
991 let grad_data = burn::tensor::TensorData::new(
992 all_grads[0].clone(),
993 [batch_size, config.num_durations],
994 );
995 let grad_tensor: Tensor<AutodiffBackend, 2> = Tensor::from_data(grad_data, &device);
996
997 let pseudo_loss = (outputs[0].clone() * grad_tensor).mean();
998 let grads = pseudo_loss.backward();
999 let grads = GradientsParams::from_grads(grads, &model);
1000 model = optimizer.step(config.learning_rate, model, grads);
1001 }
1002 }
1003
1004 let avg_train_loss = if n_batches > 0 {
1005 epoch_loss / n_batches as f64
1006 } else {
1007 0.0
1008 };
1009 train_loss_history.push(avg_train_loss);
1010
1011 if !val_indices.is_empty() {
1012 let x_num_val: Vec<f32> = val_indices
1013 .iter()
1014 .flat_map(|&i| {
1015 (0..num_num_features).map(move |j| x_num[i * num_num_features + j] as f32)
1016 })
1017 .collect();
1018
1019 let x_num_val_data =
1020 burn::tensor::TensorData::new(x_num_val, [n_val, num_num_features.max(1)]);
1021 let x_num_val_tensor: Tensor<AutodiffBackend, 2> =
1022 Tensor::from_data(x_num_val_data, &device);
1023
1024 let x_cat_val_tensor: Option<Tensor<AutodiffBackend, 2, Int>> = if num_cat_features > 0
1025 {
1026 if let Some(cats) = x_cat {
1027 let x_cat_val: Vec<i64> = val_indices
1028 .iter()
1029 .flat_map(|&i| {
1030 (0..num_cat_features).map(move |j| cats[i * num_cat_features + j])
1031 })
1032 .collect();
1033 let x_cat_val_data =
1034 burn::tensor::TensorData::new(x_cat_val, [n_val, num_cat_features]);
1035 Some(Tensor::from_data(x_cat_val_data, &device))
1036 } else {
1037 None
1038 }
1039 } else {
1040 None
1041 };
1042
1043 let val_outputs = model.forward_inference(x_cat_val_tensor, x_num_val_tensor);
1044
1045 let mut val_loss = 0.0;
1046 for logits_tensor in &val_outputs {
1047 let logits_vec: Vec<f32> = tensor_to_vec_f32(logits_tensor.clone().inner());
1048 val_loss += compute_nll_logistic_hazard_loss(
1049 &logits_vec,
1050 &duration_bins,
1051 event,
1052 config.num_durations,
1053 &val_indices,
1054 );
1055 }
1056 val_loss_history.push(val_loss);
1057
1058 if val_loss < best_val_loss {
1059 best_val_loss = val_loss;
1060 epochs_without_improvement = 0;
1061 best_weights = Some(extract_weights(&model, config, cat_cardinalities));
1062 } else {
1063 epochs_without_improvement += 1;
1064 }
1065
1066 if let Some(patience) = config.early_stopping_patience
1067 && epochs_without_improvement >= patience
1068 {
1069 break;
1070 }
1071 }
1072 }
1073
1074 let final_weights =
1075 best_weights.unwrap_or_else(|| extract_weights(&model, config, cat_cardinalities));
1076
1077 SurvTrace {
1078 weights: final_weights,
1079 config: config.clone(),
1080 duration_cuts: cuts,
1081 train_loss: train_loss_history,
1082 val_loss: val_loss_history,
1083 cat_cardinalities: cat_cardinalities.to_vec(),
1084 }
1085}
1086
/// Trained SurvTrace model exposed to Python: backend-free weight snapshot
/// plus training metadata.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvTrace {
    /// CPU weight snapshot used by the predict_* methods.
    weights: StoredWeights,
    config: SurvTraceConfig,
    /// Duration-bin cut points from `compute_duration_bins`.
    #[pyo3(get)]
    pub duration_cuts: Vec<f64>,
    /// Average training loss per epoch.
    #[pyo3(get)]
    pub train_loss: Vec<f64>,
    /// Validation loss per epoch (empty when the validation split has 0 rows).
    #[pyo3(get)]
    pub val_loss: Vec<f64>,
    /// Cardinalities supplied at fit time, one per categorical feature.
    #[pyo3(get)]
    pub cat_cardinalities: Vec<usize>,
}
1101
1102#[pymethods]
1103impl SurvTrace {
1104 #[staticmethod]
1105 #[pyo3(signature = (x_cat, x_num, n_obs, num_cat_features, num_num_features, cat_cardinalities, time, event, config))]
1106 pub fn fit(
1107 py: Python<'_>,
1108 x_cat: Option<Vec<i64>>,
1109 x_num: Vec<f64>,
1110 n_obs: usize,
1111 num_cat_features: usize,
1112 num_num_features: usize,
1113 cat_cardinalities: Vec<usize>,
1114 time: Vec<f64>,
1115 event: Vec<i32>,
1116 config: &SurvTraceConfig,
1117 ) -> PyResult<Self> {
1118 if x_num.len() != n_obs * num_num_features.max(1) && num_num_features > 0 {
1119 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1120 "x_num length must equal n_obs * num_num_features",
1121 ));
1122 }
1123 if time.len() != n_obs || event.len() != n_obs {
1124 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1125 "time and event must have length n_obs",
1126 ));
1127 }
1128 if let Some(ref cats) = x_cat
1129 && cats.len() != n_obs * num_cat_features
1130 {
1131 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1132 "x_cat length must equal n_obs * num_cat_features",
1133 ));
1134 }
1135
1136 let config = config.clone();
1137 let x_cat_clone = x_cat.clone();
1138 Ok(py.detach(move || {
1139 fit_survtrace_inner(
1140 x_cat_clone.as_deref(),
1141 &x_num,
1142 n_obs,
1143 num_cat_features,
1144 num_num_features,
1145 &cat_cardinalities,
1146 &time,
1147 &event,
1148 &config,
1149 )
1150 }))
1151 }
1152
1153 #[pyo3(signature = (x_cat, x_num, n_new, event_idx=0))]
1154 pub fn predict_hazard(
1155 &self,
1156 x_cat: Option<Vec<i64>>,
1157 x_num: Vec<f64>,
1158 n_new: usize,
1159 event_idx: usize,
1160 ) -> PyResult<Vec<Vec<f64>>> {
1161 let outputs = predict_with_weights(
1162 x_cat.as_deref(),
1163 &x_num,
1164 n_new,
1165 &self.weights,
1166 self.config.layer_norm_eps,
1167 );
1168
1169 if event_idx >= outputs.len() {
1170 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1171 "event_idx out of range",
1172 ));
1173 }
1174
1175 let logits = &outputs[event_idx];
1176 let num_durations = self.config.num_durations;
1177
1178 let hazards: Vec<Vec<f64>> = (0..n_new)
1179 .map(|i| {
1180 (0..num_durations)
1181 .map(|t| {
1182 let logit = logits[i * num_durations + t];
1183 1.0 / (1.0 + (-logit).exp())
1184 })
1185 .collect()
1186 })
1187 .collect();
1188
1189 Ok(hazards)
1190 }
1191
1192 #[pyo3(signature = (x_cat, x_num, n_new, event_idx=0))]
1193 pub fn predict_survival(
1194 &self,
1195 x_cat: Option<Vec<i64>>,
1196 x_num: Vec<f64>,
1197 n_new: usize,
1198 event_idx: usize,
1199 ) -> PyResult<Vec<Vec<f64>>> {
1200 let hazards = self.predict_hazard(x_cat, x_num, n_new, event_idx)?;
1201
1202 let survival: Vec<Vec<f64>> = hazards
1203 .par_iter()
1204 .map(|h| {
1205 let mut surv = Vec::with_capacity(h.len());
1206 let mut cum_surv = 1.0;
1207 for &haz in h {
1208 cum_surv *= 1.0 - haz;
1209 surv.push(cum_surv);
1210 }
1211 surv
1212 })
1213 .collect();
1214
1215 Ok(survival)
1216 }
1217
1218 #[pyo3(signature = (x_cat, x_num, n_new, event_idx=0))]
1219 pub fn predict_risk(
1220 &self,
1221 x_cat: Option<Vec<i64>>,
1222 x_num: Vec<f64>,
1223 n_new: usize,
1224 event_idx: usize,
1225 ) -> PyResult<Vec<f64>> {
1226 let survival = self.predict_survival(x_cat, x_num, n_new, event_idx)?;
1227
1228 let risks: Vec<f64> = survival
1229 .par_iter()
1230 .map(|s| {
1231 let final_surv = s.last().copied().unwrap_or(1.0);
1232 1.0 - final_surv
1233 })
1234 .collect();
1235
1236 Ok(risks)
1237 }
1238
1239 #[pyo3(signature = (x_cat, x_num, n_new))]
1240 pub fn predict_cumulative_incidence(
1241 &self,
1242 x_cat: Option<Vec<i64>>,
1243 x_num: Vec<f64>,
1244 n_new: usize,
1245 ) -> PyResult<Vec<Vec<Vec<f64>>>> {
1246 let num_events = self.weights.num_events;
1247 let num_durations = self.config.num_durations;
1248
1249 let outputs = predict_with_weights(
1250 x_cat.as_deref(),
1251 &x_num,
1252 n_new,
1253 &self.weights,
1254 self.config.layer_norm_eps,
1255 );
1256
1257 let mut all_hazards: Vec<Vec<Vec<f64>>> = Vec::new();
1258 for event_idx in 0..num_events {
1259 let logits = &outputs[event_idx];
1260 let hazards: Vec<Vec<f64>> = (0..n_new)
1261 .map(|i| {
1262 (0..num_durations)
1263 .map(|t| {
1264 let logit = logits[i * num_durations + t];
1265 1.0 / (1.0 + (-logit).exp())
1266 })
1267 .collect()
1268 })
1269 .collect();
1270 all_hazards.push(hazards);
1271 }
1272
1273 let cifs: Vec<Vec<Vec<f64>>> = (0..n_new)
1274 .into_par_iter()
1275 .map(|i| {
1276 let mut overall_surv = vec![1.0; num_durations + 1];
1277 for t in 0..num_durations {
1278 let mut total_haz = 0.0;
1279 for event_idx in 0..num_events {
1280 total_haz += all_hazards[event_idx][i][t];
1281 }
1282 overall_surv[t + 1] = overall_surv[t] * (1.0 - total_haz.min(1.0));
1283 }
1284
1285 let mut event_cifs = Vec::new();
1286 for event_idx in 0..num_events {
1287 let mut cif = Vec::with_capacity(num_durations);
1288 let mut cum_inc = 0.0;
1289 for t in 0..num_durations {
1290 cum_inc += overall_surv[t] * all_hazards[event_idx][i][t];
1291 cif.push(cum_inc);
1292 }
1293 event_cifs.push(cif);
1294 }
1295 event_cifs
1296 })
1297 .collect();
1298
1299 Ok(cifs)
1300 }
1301
1302 #[getter]
1303 pub fn get_num_events(&self) -> usize {
1304 self.weights.num_events
1305 }
1306
1307 #[getter]
1308 pub fn get_num_durations(&self) -> usize {
1309 self.config.num_durations
1310 }
1311
1312 #[getter]
1313 pub fn get_hidden_size(&self) -> usize {
1314 self.config.hidden_size
1315 }
1316
1317 #[getter]
1318 pub fn get_num_layers(&self) -> usize {
1319 self.config.num_hidden_layers
1320 }
1321}
1322
1323#[pyfunction]
1324#[pyo3(signature = (x_cat, x_num, n_obs, num_cat_features, num_num_features, cat_cardinalities, time, event, config=None))]
1325pub fn survtrace(
1326 py: Python<'_>,
1327 x_cat: Option<Vec<i64>>,
1328 x_num: Vec<f64>,
1329 n_obs: usize,
1330 num_cat_features: usize,
1331 num_num_features: usize,
1332 cat_cardinalities: Vec<usize>,
1333 time: Vec<f64>,
1334 event: Vec<i32>,
1335 config: Option<&SurvTraceConfig>,
1336) -> PyResult<SurvTrace> {
1337 let cfg = config.cloned().unwrap_or_else(|| {
1338 SurvTraceConfig::new(
1339 16, 3, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12,
1340 )
1341 .unwrap()
1342 });
1343
1344 SurvTrace::fit(
1345 py,
1346 x_cat,
1347 x_num,
1348 n_obs,
1349 num_cat_features,
1350 num_num_features,
1351 cat_cardinalities,
1352 time,
1353 event,
1354 &cfg,
1355 )
1356}
1357
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny training configuration shared by the end-to-end tests; only
    /// the number of competing events varies between them.
    fn tiny_config(num_events: usize) -> SurvTraceConfig {
        SurvTraceConfig {
            hidden_size: 8,
            num_hidden_layers: 1,
            num_attention_heads: 2,
            intermediate_size: 16,
            hidden_dropout_prob: 0.0,
            attention_dropout_prob: 0.0,
            num_durations: 3,
            num_events,
            vocab_size: 4,
            learning_rate: 0.01,
            batch_size: 6,
            n_epochs: 3,
            weight_decay: 0.0,
            seed: Some(42),
            early_stopping_patience: None,
            validation_fraction: 0.0,
            layer_norm_eps: 1e-12,
        }
    }

    #[test]
    fn test_config_default() {
        // All-positional constructor; the asserts pin the first three
        // parameters (hidden_size, num_hidden_layers, num_attention_heads).
        let config = SurvTraceConfig::new(
            16,
            3,
            2,
            64,
            0.0,
            0.1,
            5,
            1,
            8,
            0.001,
            64,
            100,
            0.0001,
            Some(42),
            Some(5),
            0.1,
            1e-12,
        )
        .unwrap();
        assert_eq!(config.hidden_size, 16);
        assert_eq!(config.num_hidden_layers, 3);
        assert_eq!(config.num_attention_heads, 2);
    }

    #[test]
    fn test_config_validation() {
        // hidden_size of 0 must be rejected.
        assert!(
            SurvTraceConfig::new(
                0, 3, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12
            )
            .is_err()
        );
        // hidden_size 15 with 2 attention heads must be rejected
        // (presumably because it is not divisible by the head count).
        assert!(
            SurvTraceConfig::new(
                15, 3, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12
            )
            .is_err()
        );
        // Zero hidden layers must be rejected.
        assert!(
            SurvTraceConfig::new(
                16, 0, 2, 64, 0.0, 0.1, 5, 1, 8, 0.001, 64, 100, 0.0001, None, None, 0.1, 1e-12
            )
            .is_err()
        );
    }

    #[test]
    fn test_survtrace_basic() {
        // 6 observations x 2 numeric features, single event type.
        let x_num = vec![1.0, 0.5, 0.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 1.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let event = vec![1, 1, 0, 1, 0, 1];

        let model =
            fit_survtrace_inner(None, &x_num, 6, 0, 2, &[], &time, &event, &tiny_config(1));
        assert_eq!(model.get_num_events(), 1);
        assert_eq!(model.get_num_durations(), 3);
        assert!(!model.train_loss.is_empty());
    }

    #[test]
    fn test_survtrace_with_categorical() {
        // 6 observations x 2 categorical + 2 numeric features.
        let x_cat = vec![0i64, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1];
        let x_num = vec![1.0, 0.5, 0.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 1.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let event = vec![1, 1, 0, 1, 0, 1];
        let cat_cardinalities = vec![2, 2];

        let model = fit_survtrace_inner(
            Some(&x_cat),
            &x_num,
            6,
            2,
            2,
            &cat_cardinalities,
            &time,
            &event,
            &tiny_config(1),
        );
        assert_eq!(model.get_num_events(), 1);
    }

    #[test]
    fn test_survtrace_competing_risks() {
        // Event codes 1 and 2 encode two competing risks; 0 is censored.
        let x_num = vec![1.0, 0.5, 0.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 1.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let event = vec![1, 2, 0, 1, 2, 1];

        let model =
            fit_survtrace_inner(None, &x_num, 6, 0, 2, &[], &time, &event, &tiny_config(2));
        assert_eq!(model.get_num_events(), 2);
    }

    #[test]
    fn test_duration_bins() {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        // 5 bins should yield 6 cut points (bin edges) and a bin index
        // below 5 for every observation.
        let (bins, cuts) = compute_duration_bins(&times, 5);

        assert_eq!(bins.len(), 10);
        assert_eq!(cuts.len(), 6);
        assert!(bins.iter().all(|&bin| bin < 5));
    }

    #[test]
    fn test_nll_loss() {
        // 3 observations x 2 durations of logits, row-major.
        let logits = vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.4];
        let durations = vec![1, 0, 2];
        let events = vec![1, 0, 1];
        let indices: Vec<usize> = vec![0, 1, 2];

        let loss = compute_nll_logistic_hazard_loss(&logits, &durations, &events, 2, &indices);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_gelu_cpu() {
        // GELU(x) for x in (0, inf) lies strictly between 0 and x.
        let x = 0.5;
        let result = gelu_cpu(x);
        assert!(result > 0.0);
        assert!(result < x);
    }

    #[test]
    fn test_layer_norm_cpu() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0f32, 1.0, 1.0, 1.0];
        let beta = vec![0.0f32, 0.0, 0.0, 0.0];

        let result = layer_norm_cpu(&x, &gamma, &beta, 1e-12);

        // With unit gamma and zero beta the output must be zero-mean.
        assert_eq!(result.len(), 4);
        let mean: f64 = result.iter().sum::<f64>() / 4.0;
        assert!((mean).abs() < 1e-6);
    }
}