1use std::f32::consts::PI;
22
23use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
24use scirs2_core::RngExt;
25
/// Module-wide fallible-result alias: errors are human-readable `String`s.
type TabResult<T> = Result<T, String>;
31
#[inline]
/// Rectified linear unit: negatives (and NaN) collapse to zero.
fn relu(x: f32) -> f32 {
    f32::max(x, 0.0)
}
40
#[inline]
/// Tanh approximation of GELU: 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))).
fn gelu(x: f32) -> f32 {
    let inner = x * 0.797_884_6 * (1.0 + 0.044715 * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
45
#[inline]
/// Logistic sigmoid. The input is clamped to ±88 so `exp` stays finite in f32.
fn sigmoid(x: f32) -> f32 {
    let bounded = x.clamp(-88.0, 88.0);
    (1.0 + (-bounded).exp()).recip()
}
51
/// Numerically stable softmax: shift by the maximum before exponentiating,
/// and guard the denominator against underflow. Empty input yields empty output.
fn softmax(v: &[f32]) -> Vec<f32> {
    if v.is_empty() {
        return Vec::new();
    }
    let peak = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut out: Vec<f32> = v.iter().map(|&x| (x - peak).exp()).collect();
    // Floor the denominator so an all -inf input cannot divide by zero.
    let total = out.iter().sum::<f32>().max(1e-12);
    for e in out.iter_mut() {
        *e /= total;
    }
    out
}
62
/// Sparsemax (Martins & Astudillo): a softmax alternative that can assign
/// exactly zero probability. Outputs are `max(z_i - tau, 0)` where `tau` is
/// chosen so the support sums to one.
fn sparsemax(z: &[f32]) -> Vec<f32> {
    let n = z.len();
    if n == 0 {
        return Vec::new();
    }
    // Sort a copy in descending order; NaNs compare as equal.
    let mut desc = z.to_vec();
    desc.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    // Support size: the largest prefix whose members stay above the running
    // threshold (cumsum - 1) / (i + 1).
    let mut running = 0.0_f32;
    let mut support = n;
    for (i, &val) in desc.iter().enumerate() {
        running += val;
        if val > (running - 1.0) / (i + 1) as f32 {
            support = i + 1;
        }
    }
    // tau shifts the supported entries so the result sums to one.
    let tau = (desc[..support].iter().sum::<f32>() - 1.0) / support as f32;
    z.iter().map(|&zi| (zi - tau).max(0.0)).collect()
}
82
83fn linear(w: &[f32], b: &[f32], x: &[f32]) -> TabResult<Vec<f32>> {
86 let in_dim = x.len();
87 let out_dim = b.len();
88 if w.len() != out_dim * in_dim {
89 return Err(format!(
90 "linear: w.len()={} != out×in={}×{}",
91 w.len(),
92 out_dim,
93 in_dim
94 ));
95 }
96 let mut y = vec![0.0_f32; out_dim];
97 for o in 0..out_dim {
98 let row = &w[o * in_dim..(o + 1) * in_dim];
99 y[o] = b[o]
100 + row
101 .iter()
102 .zip(x.iter())
103 .map(|(&wi, &xi)| wi * xi)
104 .sum::<f32>();
105 }
106 Ok(y)
107}
108
109fn layer_norm(x: &[f32], gamma: &[f32], beta: &[f32]) -> TabResult<Vec<f32>> {
111 let n = x.len();
112 if gamma.len() != n || beta.len() != n {
113 return Err(format!(
114 "layer_norm: dim mismatch x={n}, γ={}, β={}",
115 gamma.len(),
116 beta.len()
117 ));
118 }
119 let mean = x.iter().copied().sum::<f32>() / n as f32;
120 let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n as f32;
121 let std_inv = (var + 1e-5_f32).sqrt().recip();
122 Ok(x.iter()
123 .enumerate()
124 .map(|(i, &v)| (v - mean) * std_inv * gamma[i] + beta[i])
125 .collect())
126}
127
128fn xavier_uniform(size: usize, fan_in: usize, fan_out: usize, rng: &mut StdRng) -> Vec<f32> {
130 let bound = (6.0_f32 / (fan_in + fan_out).max(1) as f32).sqrt();
131 (0..size)
132 .map(|_| {
133 let u: f32 = rng.random();
134 2.0 * bound * u - bound
135 })
136 .collect()
137}
138
139fn kaiming_uniform(size: usize, fan_in: usize, rng: &mut StdRng) -> Vec<f32> {
141 let bound = (2.0_f32 / fan_in.max(1) as f32).sqrt();
142 (0..size)
143 .map(|_| {
144 let u: f32 = rng.random();
145 2.0 * bound * u - bound
146 })
147 .collect()
148}
149
/// A parameter buffer of `size` zeros (bias initialisation).
fn zeros(size: usize) -> Vec<f32> {
    std::iter::repeat(0.0_f32).take(size).collect()
}
154
/// A parameter buffer of `size` ones (layer-norm gain initialisation).
fn ones(size: usize) -> Vec<f32> {
    std::iter::repeat(1.0_f32).take(size).collect()
}
159
160fn scaled_dot_product_attn(
163 q: &[f32],
164 k: &[f32],
165 v: &[f32],
166 seq_len: usize,
167 d_model: usize,
168 n_heads: usize,
169 wq: &[f32],
170 wk: &[f32],
171 wv: &[f32],
172 wo: &[f32],
173) -> TabResult<Vec<f32>> {
174 if n_heads == 0 || d_model % n_heads != 0 {
175 return Err(format!(
176 "d_model={d_model} not divisible by n_heads={n_heads}"
177 ));
178 }
179 let dh = d_model / n_heads;
180 let scale = (dh as f32).sqrt().recip();
181
182 let project = |w: &[f32], inp: &[f32]| -> TabResult<Vec<f32>> {
185 let mut out = vec![0.0_f32; seq_len * d_model];
186 for s in 0..seq_len {
187 for o in 0..d_model {
188 let mut acc = 0.0_f32;
189 for i in 0..d_model {
190 acc += w[o * d_model + i] * inp[s * d_model + i];
191 }
192 out[s * d_model + o] = acc;
193 }
194 }
195 Ok(out)
196 };
197
198 let pq = project(wq, q)?;
199 let pk = project(wk, k)?;
200 let pv = project(wv, v)?;
201
202 let mut output = vec![0.0_f32; seq_len * d_model];
203
204 for h in 0..n_heads {
205 let offset = h * dh;
206 let mut scores = vec![0.0_f32; seq_len * seq_len];
208 for i in 0..seq_len {
209 for j in 0..seq_len {
210 let mut dot = 0.0_f32;
211 for d in 0..dh {
212 dot += pq[i * d_model + offset + d] * pk[j * d_model + offset + d];
213 }
214 scores[i * seq_len + j] = dot * scale;
215 }
216 }
217 for i in 0..seq_len {
219 let row = softmax(&scores[i * seq_len..(i + 1) * seq_len]);
220 scores[i * seq_len..(i + 1) * seq_len].copy_from_slice(&row);
221 }
222 for i in 0..seq_len {
224 for d in 0..dh {
225 let mut acc = 0.0_f32;
226 for j in 0..seq_len {
227 acc += scores[i * seq_len + j] * pv[j * d_model + offset + d];
228 }
229 output[i * d_model + offset + d] = acc;
230 }
231 }
232 }
233
234 let mut result = vec![0.0_f32; seq_len * d_model];
236 for s in 0..seq_len {
237 for o in 0..d_model {
238 let mut acc = 0.0_f32;
239 for i in 0..d_model {
240 acc += wo[o * d_model + i] * output[s * d_model + i];
241 }
242 result[s * d_model + o] = acc;
243 }
244 }
245 Ok(result)
246}
247
/// Hyper-parameters for [`TabTransformer`].
#[derive(Debug, Clone)]
pub struct TabTransformerConfig {
    /// Number of categorical input columns.
    pub n_cat_features: usize,
    /// Number of numeric input columns (bypass the transformer, join at the head).
    pub n_num_features: usize,
    /// Vocabulary size per categorical column; must have `n_cat_features` entries.
    pub cat_vocab_sizes: Vec<usize>,
    /// Embedding / transformer width.
    pub embed_dim: usize,
    /// Attention heads; must divide `embed_dim`.
    pub n_heads: usize,
    /// Number of transformer encoder layers.
    pub n_layers: usize,
    /// Hidden width of each layer's feed-forward sub-network.
    pub ffn_dim: usize,
    /// Number of output classes (logit count).
    pub n_classes: usize,
}
272
/// TabTransformer: contextualises categorical embeddings with self-attention;
/// numeric features skip the transformer and join at the MLP head.
#[derive(Debug, Clone)]
pub struct TabTransformer {
    cfg: TabTransformerConfig,
    /// One `vocab × embed_dim` embedding table per categorical feature.
    embeddings: Vec<Vec<f32>>,
    /// Transformer encoder stack.
    layers: Vec<TabTransformerLayer>,
    /// NOTE(review): initialised in `new` but never applied in `forward` —
    /// confirm whether a final layer norm was intended.
    final_ln_g: Vec<f32>,
    final_ln_b: Vec<f32>,
    /// MLP head, layer 1: (n_cat·d + n_num) → hidden.
    head_w1: Vec<f32>,
    head_b1: Vec<f32>,
    /// MLP head, layer 2: hidden → n_classes.
    head_w2: Vec<f32>,
    head_b2: Vec<f32>,
}
294
/// Weights for one post-norm transformer encoder layer: four attention
/// projections, two layer norms, and a two-layer GELU feed-forward network.
#[derive(Debug, Clone)]
struct TabTransformerLayer {
    // Attention projections (each d × d, row-major).
    wq: Vec<f32>,
    wk: Vec<f32>,
    wv: Vec<f32>,
    wo: Vec<f32>,
    // Layer norm after the attention residual.
    ln1_g: Vec<f32>,
    ln1_b: Vec<f32>,
    // Feed-forward: d → ffn_dim → d.
    ffn_w1: Vec<f32>,
    ffn_b1: Vec<f32>,
    ffn_w2: Vec<f32>,
    ffn_b2: Vec<f32>,
    // Layer norm after the FFN residual.
    ln2_g: Vec<f32>,
    ln2_b: Vec<f32>,
}
310
impl TabTransformer {
    /// Build a TabTransformer with seeded random initialisation.
    ///
    /// # Errors
    /// Fails when `cat_vocab_sizes` does not have one entry per categorical
    /// feature.
    pub fn new(cfg: TabTransformerConfig, seed: u64) -> TabResult<Self> {
        if cfg.cat_vocab_sizes.len() != cfg.n_cat_features {
            return Err(format!(
                "TabTransformer: cat_vocab_sizes.len()={} != n_cat_features={}",
                cfg.cat_vocab_sizes.len(),
                cfg.n_cat_features
            ));
        }
        let mut rng = StdRng::seed_from_u64(seed);
        let d = cfg.embed_dim;

        // One embedding table per categorical feature (vocab_size × d).
        let embeddings = cfg
            .cat_vocab_sizes
            .iter()
            .map(|&v| kaiming_uniform(v * d, d, &mut rng))
            .collect();

        let layers = (0..cfg.n_layers)
            .map(|_| TabTransformerLayer {
                wq: xavier_uniform(d * d, d, d, &mut rng),
                wk: xavier_uniform(d * d, d, d, &mut rng),
                wv: xavier_uniform(d * d, d, d, &mut rng),
                wo: xavier_uniform(d * d, d, d, &mut rng),
                ln1_g: ones(d),
                ln1_b: zeros(d),
                ffn_w1: xavier_uniform(cfg.ffn_dim * d, d, cfg.ffn_dim, &mut rng),
                ffn_b1: zeros(cfg.ffn_dim),
                ffn_w2: xavier_uniform(d * cfg.ffn_dim, cfg.ffn_dim, d, &mut rng),
                ffn_b2: zeros(d),
                ln2_g: ones(d),
                ln2_b: zeros(d),
            })
            .collect();

        // Head input: flattened contextualised embeddings ++ raw numeric features.
        let head_in = cfg.n_cat_features * d + cfg.n_num_features;
        let head_h = (head_in * 2).max(64);
        Ok(Self {
            embeddings,
            layers,
            final_ln_g: ones(d),
            final_ln_b: zeros(d),
            head_w1: xavier_uniform(head_h * head_in, head_in, head_h, &mut rng),
            head_b1: zeros(head_h),
            head_w2: xavier_uniform(cfg.n_classes * head_h, head_h, cfg.n_classes, &mut rng),
            head_b2: zeros(cfg.n_classes),
            cfg,
        })
    }

    /// Forward pass for one sample: embed categorical ids, contextualise them
    /// with the transformer stack, then classify from
    /// [flattened embeddings | numeric features] via a two-layer ReLU MLP.
    /// Returns raw (unnormalised) class logits.
    pub fn forward(&self, cat_ids: &[usize], num_features: &[f32]) -> TabResult<Vec<f32>> {
        let d = self.cfg.embed_dim;
        if cat_ids.len() != self.cfg.n_cat_features {
            return Err(format!(
                "TabTransformer: expected {} cat ids, got {}",
                self.cfg.n_cat_features,
                cat_ids.len()
            ));
        }
        if num_features.len() != self.cfg.n_num_features {
            return Err(format!(
                "TabTransformer: expected {} num features, got {}",
                self.cfg.n_num_features,
                num_features.len()
            ));
        }

        // One token per categorical feature.
        let seq_len = self.cfg.n_cat_features;
        let mut seq = vec![0.0_f32; seq_len * d];
        for (i, &id) in cat_ids.iter().enumerate() {
            let v = self.cfg.cat_vocab_sizes[i];
            // Out-of-vocabulary ids are clamped to the last embedding row.
            let clamped = id.min(v.saturating_sub(1));
            let emb = &self.embeddings[i][clamped * d..(clamped + 1) * d];
            seq[i * d..(i + 1) * d].copy_from_slice(emb);
        }

        for layer in &self.layers {
            // Self-attention sub-layer: residual add, then post-layer-norm.
            let attn_out = scaled_dot_product_attn(
                &seq,
                &seq,
                &seq,
                seq_len,
                d,
                self.cfg.n_heads,
                &layer.wq,
                &layer.wk,
                &layer.wv,
                &layer.wo,
            )?;
            let mut h = vec![0.0_f32; seq_len * d];
            for (i, (&a, &s)) in attn_out.iter().zip(seq.iter()).enumerate() {
                h[i] = a + s;
            }
            let mut ln_out = vec![0.0_f32; seq_len * d];
            for s in 0..seq_len {
                let normed = layer_norm(&h[s * d..(s + 1) * d], &layer.ln1_g, &layer.ln1_b)?;
                ln_out[s * d..(s + 1) * d].copy_from_slice(&normed);
            }
            // Position-wise GELU FFN sub-layer: residual add, then layer norm.
            let mut ffn_out = vec![0.0_f32; seq_len * d];
            for s in 0..seq_len {
                let tok = &ln_out[s * d..(s + 1) * d];
                let h1 = linear(&layer.ffn_w1, &layer.ffn_b1, tok)?;
                let h1a: Vec<f32> = h1.iter().map(|&x| gelu(x)).collect();
                let h2 = linear(&layer.ffn_w2, &layer.ffn_b2, &h1a)?;
                let res2: Vec<f32> = h2.iter().zip(tok.iter()).map(|(&a, &b)| a + b).collect();
                let normed2 = layer_norm(&res2, &layer.ln2_g, &layer.ln2_b)?;
                ffn_out[s * d..(s + 1) * d].copy_from_slice(&normed2);
            }
            seq = ffn_out;
        }

        // NOTE(review): final_ln_g / final_ln_b are never applied here —
        // confirm whether a final layer norm was intended before the head.
        let mut head_input = vec![0.0_f32; self.cfg.n_cat_features * d + self.cfg.n_num_features];
        head_input[..seq.len()].copy_from_slice(&seq);
        head_input[seq.len()..].copy_from_slice(num_features);

        let h1 = linear(&self.head_w1, &self.head_b1, &head_input)?;
        let h1a: Vec<f32> = h1.iter().map(|&x| relu(x)).collect();
        linear(&self.head_w2, &self.head_b2, &h1a)
    }
}
440
/// Hyper-parameters for [`FTTransformer`].
#[derive(Debug, Clone)]
pub struct FTTransformerConfig {
    /// Number of categorical input columns.
    pub n_cat_features: usize,
    /// Number of numeric input columns (tokenised, unlike in TabTransformer).
    pub n_num_features: usize,
    /// Vocabulary size per categorical column; must have `n_cat_features` entries.
    pub cat_vocab_sizes: Vec<usize>,
    /// Token / transformer width.
    pub embed_dim: usize,
    /// Attention heads; must divide `embed_dim`.
    pub n_heads: usize,
    /// Number of transformer encoder layers.
    pub n_layers: usize,
    /// Hidden width of each layer's feed-forward sub-network.
    pub ffn_dim: usize,
    /// Number of output classes (logit count).
    pub n_classes: usize,
}
465
/// FT-Transformer: tokenises categorical AND numeric features, prepends a
/// learned CLS token, and classifies from the final CLS state.
#[derive(Debug, Clone)]
pub struct FTTransformer {
    cfg: FTTransformerConfig,
    /// One `vocab × embed_dim` embedding table per categorical feature.
    cat_embeddings: Vec<Vec<f32>>,
    // Per-numeric-feature affine tokenizers (token = x·num_w[i] + num_b[i])
    // and the learned classification token prepended to every sequence.
    num_w: Vec<Vec<f32>>, num_b: Vec<Vec<f32>>, cls_token: Vec<f32>,
    /// Transformer encoder stack (same layer weights layout as TabTransformer).
    layers: Vec<TabTransformerLayer>,
    /// Classification head over the final CLS state (embed_dim → n_classes).
    head_w: Vec<f32>,
    head_b: Vec<f32>,
}
487
impl FTTransformer {
    /// Build an FT-Transformer with seeded random initialisation.
    ///
    /// # Errors
    /// Fails when `cat_vocab_sizes` does not have one entry per categorical
    /// feature.
    pub fn new(cfg: FTTransformerConfig, seed: u64) -> TabResult<Self> {
        if cfg.cat_vocab_sizes.len() != cfg.n_cat_features {
            return Err(format!(
                "FTTransformer: cat_vocab_sizes.len()={} != n_cat_features={}",
                cfg.cat_vocab_sizes.len(),
                cfg.n_cat_features
            ));
        }
        let mut rng = StdRng::seed_from_u64(seed);
        let d = cfg.embed_dim;

        // One embedding table per categorical feature (vocab_size × d).
        let cat_embeddings = cfg
            .cat_vocab_sizes
            .iter()
            .map(|&v| kaiming_uniform(v * d, d, &mut rng))
            .collect();

        // Per-numeric-feature tokenizer weights: token = x·num_w[i] + num_b[i].
        let num_w = (0..cfg.n_num_features)
            .map(|_| xavier_uniform(d, 1, d, &mut rng))
            .collect();
        let num_b = (0..cfg.n_num_features).map(|_| zeros(d)).collect();

        // Small random CLS token in [-0.01, 0.01).
        let cls_token: Vec<f32> = (0..d)
            .map(|_| {
                let u: f32 = rng.random();
                u * 0.02 - 0.01
            })
            .collect();

        let layers = (0..cfg.n_layers)
            .map(|_| TabTransformerLayer {
                wq: xavier_uniform(d * d, d, d, &mut rng),
                wk: xavier_uniform(d * d, d, d, &mut rng),
                wv: xavier_uniform(d * d, d, d, &mut rng),
                wo: xavier_uniform(d * d, d, d, &mut rng),
                ln1_g: ones(d),
                ln1_b: zeros(d),
                ffn_w1: xavier_uniform(cfg.ffn_dim * d, d, cfg.ffn_dim, &mut rng),
                ffn_b1: zeros(cfg.ffn_dim),
                ffn_w2: xavier_uniform(d * cfg.ffn_dim, cfg.ffn_dim, d, &mut rng),
                ffn_b2: zeros(d),
                ln2_g: ones(d),
                ln2_b: zeros(d),
            })
            .collect();

        Ok(Self {
            cat_embeddings,
            num_w,
            num_b,
            cls_token,
            layers,
            // Head reads only the CLS token: d → n_classes.
            head_w: xavier_uniform(cfg.n_classes * d, d, cfg.n_classes, &mut rng),
            head_b: zeros(cfg.n_classes),
            cfg,
        })
    }

    /// Forward pass for one sample: build the token sequence
    /// [CLS | categorical embeddings | numeric tokens], run the transformer
    /// stack, and return class logits computed from the final CLS state.
    pub fn forward(&self, cat_ids: &[usize], num_features: &[f32]) -> TabResult<Vec<f32>> {
        let d = self.cfg.embed_dim;
        if cat_ids.len() != self.cfg.n_cat_features {
            return Err(format!(
                "FTTransformer: expected {} cat ids, got {}",
                self.cfg.n_cat_features,
                cat_ids.len()
            ));
        }
        if num_features.len() != self.cfg.n_num_features {
            return Err(format!(
                "FTTransformer: expected {} num features, got {}",
                self.cfg.n_num_features,
                num_features.len()
            ));
        }

        // Token layout: [CLS | cat tokens | num tokens], flattened row-major.
        let n_tokens = 1 + self.cfg.n_cat_features + self.cfg.n_num_features;
        let mut tokens = vec![0.0_f32; n_tokens * d];

        tokens[..d].copy_from_slice(&self.cls_token);

        for (i, &id) in cat_ids.iter().enumerate() {
            let v = self.cfg.cat_vocab_sizes[i];
            // Out-of-vocabulary ids are clamped to the last embedding row.
            let clamped = id.min(v.saturating_sub(1));
            let emb = &self.cat_embeddings[i][clamped * d..(clamped + 1) * d];
            let offset = (1 + i) * d;
            tokens[offset..offset + d].copy_from_slice(emb);
        }

        // Numeric features become learned affine tokens.
        for (i, &x) in num_features.iter().enumerate() {
            let offset = (1 + self.cfg.n_cat_features + i) * d;
            for j in 0..d {
                tokens[offset + j] = x * self.num_w[i][j] + self.num_b[i][j];
            }
        }

        let mut seq = tokens;
        for layer in &self.layers {
            // Self-attention sub-layer: residual add, then post-layer-norm.
            let attn_out = scaled_dot_product_attn(
                &seq,
                &seq,
                &seq,
                n_tokens,
                d,
                self.cfg.n_heads,
                &layer.wq,
                &layer.wk,
                &layer.wv,
                &layer.wo,
            )?;
            let mut h = vec![0.0_f32; n_tokens * d];
            for (i, (&a, &s)) in attn_out.iter().zip(seq.iter()).enumerate() {
                h[i] = a + s;
            }
            let mut ln_out = vec![0.0_f32; n_tokens * d];
            for s in 0..n_tokens {
                let normed = layer_norm(&h[s * d..(s + 1) * d], &layer.ln1_g, &layer.ln1_b)?;
                ln_out[s * d..(s + 1) * d].copy_from_slice(&normed);
            }
            // Position-wise GELU FFN sub-layer: residual add, then layer norm.
            let mut ffn_out = vec![0.0_f32; n_tokens * d];
            for s in 0..n_tokens {
                let tok = &ln_out[s * d..(s + 1) * d];
                let h1 = linear(&layer.ffn_w1, &layer.ffn_b1, tok)?;
                let h1a: Vec<f32> = h1.iter().map(|&x| gelu(x)).collect();
                let h2 = linear(&layer.ffn_w2, &layer.ffn_b2, &h1a)?;
                let res2: Vec<f32> = h2.iter().zip(tok.iter()).map(|(&a, &b)| a + b).collect();
                let normed2 = layer_norm(&res2, &layer.ln2_g, &layer.ln2_b)?;
                ffn_out[s * d..(s + 1) * d].copy_from_slice(&normed2);
            }
            seq = ffn_out;
        }

        // Classify from the contextualised CLS token only.
        let cls = &seq[..d];
        linear(&self.head_w, &self.head_b, cls)
    }
}
631
/// Differentiable oblivious decision tree (NODE-style): every level shares a
/// single soft split, so a depth-`d` tree has `2^d` leaves and routes
/// probability mass rather than making hard decisions.
#[derive(Debug, Clone)]
pub struct ObliviousTree {
    /// Number of levels (one shared split per level).
    pub depth: usize,
    /// Expected input width.
    pub n_features: usize,
    /// Per-level feature-selection logits (`depth × n_features`, row-major).
    pub feature_w: Vec<f32>,
    /// Per-level split thresholds.
    pub thresholds: Vec<f32>,
    /// Leaf outputs (`2^depth × output_dim`, row-major).
    pub leaf_responses: Vec<f32>,
    /// Width of each leaf response and of the tree output.
    pub output_dim: usize,
}
656
657impl ObliviousTree {
658 pub fn new(depth: usize, n_features: usize, output_dim: usize, rng: &mut StdRng) -> Self {
660 let n_leaves = 1usize << depth;
661 let feature_w = xavier_uniform(depth * n_features, n_features, depth, rng);
662 let thresholds: Vec<f32> = (0..depth)
663 .map(|_| {
664 let u: f32 = rng.random();
665 u * 2.0 - 1.0
666 })
667 .collect();
668 let leaf_responses = xavier_uniform(n_leaves * output_dim, n_leaves, output_dim, rng);
669 Self {
670 depth,
671 n_features,
672 feature_w,
673 thresholds,
674 leaf_responses,
675 output_dim,
676 }
677 }
678
679 pub fn forward(&self, x: &[f32]) -> TabResult<Vec<f32>> {
681 if x.len() != self.n_features {
682 return Err(format!(
683 "ObliviousTree: expected {} features, got {}",
684 self.n_features,
685 x.len()
686 ));
687 }
688 let d = self.depth;
689 let n_leaves = 1usize << d;
690
691 let mut leaf_probs = vec![1.0_f32; n_leaves];
693 for layer in 0..d {
694 let fw = &self.feature_w[layer * self.n_features..(layer + 1) * self.n_features];
696 let feature_attn = softmax(fw);
697 let projected: f32 = feature_attn
699 .iter()
700 .zip(x.iter())
701 .map(|(&a, &b)| a * b)
702 .sum();
703 let split_val = sigmoid(projected - self.thresholds[layer]);
704 for leaf in 0..n_leaves {
706 let bit = (leaf >> (d - 1 - layer)) & 1;
707 let p = if bit == 1 { split_val } else { 1.0 - split_val };
708 leaf_probs[leaf] *= p;
709 }
710 }
711
712 let mut output = vec![0.0_f32; self.output_dim];
714 for (leaf, &lp) in leaf_probs.iter().enumerate() {
715 for o in 0..self.output_dim {
716 output[o] += lp * self.leaf_responses[leaf * self.output_dim + o];
717 }
718 }
719 Ok(output)
720 }
721}
722
/// NODE-style model: an ensemble of oblivious trees whose outputs are averaged.
#[derive(Debug, Clone)]
pub struct NodeModel {
    /// Ensemble members.
    pub trees: Vec<ObliviousTree>,
    /// Expected input width.
    pub n_features: usize,
    /// Output width of each tree and of the averaged result.
    pub n_classes: usize,
}
733
734impl NodeModel {
735 pub fn new(
737 n_trees: usize,
738 depth: usize,
739 n_features: usize,
740 n_classes: usize,
741 seed: u64,
742 ) -> Self {
743 let mut rng = StdRng::seed_from_u64(seed);
744 let trees = (0..n_trees)
745 .map(|_| ObliviousTree::new(depth, n_features, n_classes, &mut rng))
746 .collect();
747 Self {
748 trees,
749 n_features,
750 n_classes,
751 }
752 }
753
754 pub fn forward(&self, x: &[f32]) -> TabResult<Vec<f32>> {
756 if self.trees.is_empty() {
757 return Err("NodeModel: no trees".into());
758 }
759 let mut sum = vec![0.0_f32; self.n_classes];
760 for tree in &self.trees {
761 let out = tree.forward(x)?;
762 for (s, &o) in sum.iter_mut().zip(out.iter()) {
763 *s += o;
764 }
765 }
766 let n = self.trees.len() as f32;
767 Ok(sum.iter().map(|&s| s / n).collect())
768 }
769}
770
/// Hyper-parameters for [`TabNet`].
#[derive(Debug, Clone)]
pub struct TabNetConfig {
    /// Number of sequential decision steps; must be > 0.
    pub n_steps: usize,
    /// Width of each step's decision output.
    pub n_d: usize,
    /// Width of the attention state fed to the mask generators.
    pub n_a: usize,
    /// Relaxation factor for the feature-reuse prior (prior *= gamma - mask).
    pub gamma: f32,
    /// NOTE(review): not referenced in the visible forward pass — presumably
    /// reserved for a sparsity-loss term; confirm before removing.
    pub epsilon: f32,
    /// Input feature count.
    pub n_features: usize,
    /// Output class count.
    pub n_classes: usize,
}
793
/// TabNet attentive transformer: projects the attention state to per-feature
/// logits which are normalised, prior-scaled, and sparsemax-ed into a mask.
#[derive(Debug, Clone)]
pub struct AttentiveTransformer {
    /// Projection weights (`n_features × n_a`, row-major).
    w: Vec<f32>,
    /// Projection bias (one entry per feature).
    b: Vec<f32>,
    /// NOTE(review): named like batch-norm parameters but applied through
    /// `layer_norm` in `forward` — effectively single-sample layer norm.
    bn_gamma: Vec<f32>,
    bn_beta: Vec<f32>,
}
802
803impl AttentiveTransformer {
804 fn new(n_features: usize, n_a: usize, rng: &mut StdRng) -> Self {
805 Self {
806 w: xavier_uniform(n_features * n_a, n_a, n_features, rng),
807 b: zeros(n_features),
808 bn_gamma: ones(n_features),
809 bn_beta: zeros(n_features),
810 }
811 }
812
813 fn forward(&self, h: &[f32], prior_scale: &[f32]) -> TabResult<Vec<f32>> {
815 let n_features = self.b.len();
816 let h_proj = linear(&self.w, &self.b, h)?;
817 let normed = layer_norm(&h_proj, &self.bn_gamma, &self.bn_beta)?;
819 let masked: Vec<f32> = normed
821 .iter()
822 .zip(prior_scale.iter())
823 .map(|(&n, &p)| n * p)
824 .collect();
825 Ok(sparsemax(&masked[..n_features.min(masked.len())]))
826 }
827}
828
/// Two-layer ReLU MLP used as TabNet's feature-transformer building block.
#[derive(Debug, Clone)]
struct FeatureTransformStep {
    // First layer: in_dim → out_dim.
    w1: Vec<f32>,
    b1: Vec<f32>,
    // Second layer: out_dim → out_dim.
    w2: Vec<f32>,
    b2: Vec<f32>,
}
837
838impl FeatureTransformStep {
839 fn new(in_dim: usize, out_dim: usize, rng: &mut StdRng) -> Self {
840 Self {
841 w1: xavier_uniform(out_dim * in_dim, in_dim, out_dim, rng),
842 b1: zeros(out_dim),
843 w2: xavier_uniform(out_dim * out_dim, out_dim, out_dim, rng),
844 b2: zeros(out_dim),
845 }
846 }
847
848 fn forward(&self, x: &[f32]) -> TabResult<Vec<f32>> {
849 let h = linear(&self.w1, &self.b1, x)?;
850 let ha: Vec<f32> = h.iter().map(|&v| relu(v)).collect();
851 let h2 = linear(&self.w2, &self.b2, &ha)?;
852 Ok(h2.iter().map(|&v| relu(v)).collect())
853 }
854}
855
/// TabNet-style model with sequential sparse feature selection.
#[derive(Debug, Clone)]
pub struct TabNet {
    cfg: TabNetConfig,
    /// Feature transformer shared across all decision steps.
    shared_layer: FeatureTransformStep,
    /// Step-specific feature transformers (one per decision step).
    step_layers: Vec<FeatureTransformStep>,
    /// Mask generators (one per decision step).
    attn_transformers: Vec<AttentiveTransformer>,
    /// Final linear head: n_d → n_classes.
    final_w: Vec<f32>,
    final_b: Vec<f32>,
}
866
impl TabNet {
    /// Build a TabNet with seeded random initialisation.
    ///
    /// # Errors
    /// Fails when `n_steps` is zero.
    pub fn new(cfg: TabNetConfig, seed: u64) -> TabResult<Self> {
        if cfg.n_steps == 0 {
            return Err("TabNet: n_steps must be > 0".into());
        }
        let mut rng = StdRng::seed_from_u64(seed);
        // Shared transformer maps raw features to n_d + n_a units.
        let shared_layer = FeatureTransformStep::new(cfg.n_features, cfg.n_d + cfg.n_a, &mut rng);
        // One step-specific transformer and one mask generator per step.
        let step_layers = (0..cfg.n_steps)
            .map(|_| FeatureTransformStep::new(cfg.n_d + cfg.n_a, cfg.n_d + cfg.n_a, &mut rng))
            .collect();
        let attn_transformers = (0..cfg.n_steps)
            .map(|_| AttentiveTransformer::new(cfg.n_features, cfg.n_a, &mut rng))
            .collect();
        let final_w = xavier_uniform(cfg.n_classes * cfg.n_d, cfg.n_d, cfg.n_classes, &mut rng);
        let final_b = zeros(cfg.n_classes);
        Ok(Self {
            cfg,
            shared_layer,
            step_layers,
            attn_transformers,
            final_w,
            final_b,
        })
    }

    /// Sequential-attention forward pass for one sample.
    ///
    /// Returns `(logits, masks)`: the class logits and, per step, the sparse
    /// feature mask chosen there (useful for feature-importance inspection).
    pub fn forward(&self, x: &[f32]) -> TabResult<(Vec<f32>, Vec<Vec<f32>>)> {
        if x.len() != self.cfg.n_features {
            return Err(format!(
                "TabNet: expected {} features, got {}",
                self.cfg.n_features,
                x.len()
            ));
        }
        let n = self.cfg.n_features;
        // The prior discourages reusing features already consumed by earlier
        // steps (TabNet's sequential-attention mechanism).
        let mut prior_scale = vec![1.0_f32; n];
        let mut aggregated_output = vec![0.0_f32; self.cfg.n_d];
        let mut masks = Vec::with_capacity(self.cfg.n_steps);

        for step in 0..self.cfg.n_steps {
            // NOTE(review): the attention input is the first min(n_d, n_a)
            // entries of the aggregated decision output, zero-padded to n_a —
            // confirm this matches the intended decision/attention split of
            // the step output.
            let h_for_attn: Vec<f32> = if aggregated_output.is_empty() {
                vec![0.0_f32; self.cfg.n_a]
            } else {
                let mut ha = vec![0.0_f32; self.cfg.n_a];
                let copy_len = aggregated_output.len().min(self.cfg.n_a);
                ha[..copy_len].copy_from_slice(&aggregated_output[..copy_len]);
                ha
            };

            // Sparse feature-selection mask for this step, applied to the input.
            let mask = self.attn_transformers[step].forward(&h_for_attn, &prior_scale)?;
            let masked_x: Vec<f32> = mask.iter().zip(x.iter()).map(|(&m, &xi)| m * xi).collect();

            let shared_out = self.shared_layer.forward(&masked_x)?;
            let step_out = self.step_layers[step].forward(&shared_out)?;

            // The first n_d units are this step's decision output.
            let split_pt = self.cfg.n_d.min(step_out.len());
            let decision = &step_out[..split_pt];
            let relu_decision: Vec<f32> = decision.iter().map(|&v| relu(v)).collect();

            for (a, &d) in aggregated_output.iter_mut().zip(relu_decision.iter()) {
                *a += d;
            }

            // Shrink the prior where attention was spent; gamma controls how
            // reusable a feature remains in later steps.
            for (p, &m) in prior_scale.iter_mut().zip(mask.iter()) {
                *p *= (self.cfg.gamma - m).max(0.0);
            }

            masks.push(mask);
        }

        let logits = linear(&self.final_w, &self.final_b, &aggregated_output)?;
        Ok((logits, masks))
    }
}
949
/// One SAINT block: intra-feature (per-sample) attention, inter-sample
/// attention, and a feed-forward sub-network, each with its own layer norm.
#[derive(Debug, Clone)]
pub struct SaintBlock {
    d_model: usize,
    n_heads: usize,
    // Projections for attention across a sample's feature tokens.
    wq_intra: Vec<f32>,
    wk_intra: Vec<f32>,
    wv_intra: Vec<f32>,
    wo_intra: Vec<f32>,
    // Layer norm after the intra-attention residual.
    ln1_g: Vec<f32>,
    ln1_b: Vec<f32>,
    // Projections for attention across samples in a batch.
    wq_inter: Vec<f32>,
    wk_inter: Vec<f32>,
    wv_inter: Vec<f32>,
    wo_inter: Vec<f32>,
    // Layer norm after the inter-attention residual.
    ln2_g: Vec<f32>,
    ln2_b: Vec<f32>,
    // GELU feed-forward: d_model → ffn_dim → d_model (intra path only).
    ffn_w1: Vec<f32>,
    ffn_b1: Vec<f32>,
    ffn_w2: Vec<f32>,
    ffn_b2: Vec<f32>,
    // Layer norm after the FFN residual.
    ln3_g: Vec<f32>,
    ln3_b: Vec<f32>,
}
981
impl SaintBlock {
    /// Build one SAINT block with random initialisation derived from `seed`.
    ///
    /// # Errors
    /// Fails when `d_model` is not divisible by `n_heads`.
    pub fn new(d_model: usize, n_heads: usize, ffn_dim: usize, seed: u64) -> TabResult<Self> {
        if d_model % n_heads != 0 {
            return Err(format!(
                "SaintBlock: d_model={d_model} not divisible by n_heads={n_heads}"
            ));
        }
        let mut rng = StdRng::seed_from_u64(seed);
        Ok(Self {
            d_model,
            n_heads,
            wq_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            wk_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            wv_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            wo_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            ln1_g: ones(d_model),
            ln1_b: zeros(d_model),
            wq_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            wk_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            wv_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            wo_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
            ln2_g: ones(d_model),
            ln2_b: zeros(d_model),
            ffn_w1: xavier_uniform(ffn_dim * d_model, d_model, ffn_dim, &mut rng),
            ffn_b1: zeros(ffn_dim),
            ffn_w2: xavier_uniform(d_model * ffn_dim, ffn_dim, d_model, &mut rng),
            ffn_b2: zeros(d_model),
            ln3_g: ones(d_model),
            ln3_b: zeros(d_model),
        })
    }

    /// Attention across one sample's feature tokens (`x` is a flattened
    /// `[seq_len × d_model]` buffer), with residual + `ln1`, followed by the
    /// GELU FFN sub-layer with residual + `ln3`.
    ///
    /// # Errors
    /// Fails when `x.len()` is not a multiple of `d_model`.
    pub fn intra_feature_attention(&self, x: &[f32]) -> TabResult<Vec<f32>> {
        let d = self.d_model;
        if x.len() % d != 0 {
            return Err(format!(
                "SaintBlock::intra: x.len()={} not divisible by d_model={d}",
                x.len()
            ));
        }
        let seq_len = x.len() / d;
        let attn_out = scaled_dot_product_attn(
            x,
            x,
            x,
            seq_len,
            d,
            self.n_heads,
            &self.wq_intra,
            &self.wk_intra,
            &self.wv_intra,
            &self.wo_intra,
        )?;
        // Residual add + layer norm per token.
        let mut ln_out = vec![0.0_f32; seq_len * d];
        for s in 0..seq_len {
            let res: Vec<f32> = attn_out[s * d..(s + 1) * d]
                .iter()
                .zip(x[s * d..(s + 1) * d].iter())
                .map(|(&a, &b)| a + b)
                .collect();
            let normed = layer_norm(&res, &self.ln1_g, &self.ln1_b)?;
            ln_out[s * d..(s + 1) * d].copy_from_slice(&normed);
        }
        // Position-wise GELU FFN with residual + ln3.
        let mut ffn_out = vec![0.0_f32; seq_len * d];
        for s in 0..seq_len {
            let tok = &ln_out[s * d..(s + 1) * d];
            let h1 = linear(&self.ffn_w1, &self.ffn_b1, tok)?;
            let h1a: Vec<f32> = h1.iter().map(|&v| gelu(v)).collect();
            let h2 = linear(&self.ffn_w2, &self.ffn_b2, &h1a)?;
            let res2: Vec<f32> = h2.iter().zip(tok.iter()).map(|(&a, &b)| a + b).collect();
            let normed2 = layer_norm(&res2, &self.ln3_g, &self.ln3_b)?;
            ffn_out[s * d..(s + 1) * d].copy_from_slice(&normed2);
        }
        Ok(ffn_out)
    }

    /// Attention across the samples of a batch (each sample is one `d_model`
    /// vector), with residual + `ln2` only — the FFN sub-layer belongs to the
    /// intra-feature path and is not applied here.
    ///
    /// # Errors
    /// Fails when any sample's length differs from `d_model`.
    pub fn inter_sample_attention(&self, batch: &[Vec<f32>]) -> TabResult<Vec<Vec<f32>>> {
        let d = self.d_model;
        let n_samples = batch.len();
        if n_samples == 0 {
            return Ok(Vec::new());
        }
        for (i, s) in batch.iter().enumerate() {
            if s.len() != d {
                return Err(format!(
                    "SaintBlock::inter: batch[{i}].len()={} != d_model={d}",
                    s.len()
                ));
            }
        }
        // Flatten the batch so each sample acts as one attention token.
        let flat: Vec<f32> = batch.iter().flat_map(|s| s.iter().copied()).collect();
        let attn_out = scaled_dot_product_attn(
            &flat,
            &flat,
            &flat,
            n_samples,
            d,
            self.n_heads,
            &self.wq_inter,
            &self.wk_inter,
            &self.wv_inter,
            &self.wo_inter,
        )?;
        // Residual add + layer norm per sample.
        let mut result = Vec::with_capacity(n_samples);
        for s in 0..n_samples {
            let res: Vec<f32> = attn_out[s * d..(s + 1) * d]
                .iter()
                .zip(flat[s * d..(s + 1) * d].iter())
                .map(|(&a, &b)| a + b)
                .collect();
            let normed = layer_norm(&res, &self.ln2_g, &self.ln2_b)?;
            result.push(normed);
        }
        Ok(result)
    }
}
1107
/// SAINT-style tabular model: feature tokenisation, a stack of
/// [`SaintBlock`]s, and a linear head over the flattened token sequence.
#[derive(Debug, Clone)]
pub struct SaintModel {
    /// Token width.
    pub d_model: usize,
    /// Number of blocks in the stack.
    pub n_blocks: usize,
    /// One `vocab × d_model` embedding table per categorical feature.
    cat_embeddings: Vec<Vec<f32>>,
    n_cat_features: usize,
    n_num_features: usize,
    cat_vocab_sizes: Vec<usize>,
    /// Per-numeric-feature affine tokenizers (token = x·num_w[i] + num_b[i]).
    num_w: Vec<Vec<f32>>,
    num_b: Vec<Vec<f32>>,
    blocks: Vec<SaintBlock>,
    /// Head over the flattened tokens: (n_cat + n_num)·d_model → n_classes.
    head_w: Vec<f32>,
    head_b: Vec<f32>,
    n_classes: usize,
}
1130
1131impl SaintModel {
1132 pub fn new(
1134 n_cat_features: usize,
1135 n_num_features: usize,
1136 cat_vocab_sizes: Vec<usize>,
1137 d_model: usize,
1138 n_heads: usize,
1139 n_blocks: usize,
1140 ffn_dim: usize,
1141 n_classes: usize,
1142 seed: u64,
1143 ) -> TabResult<Self> {
1144 if cat_vocab_sizes.len() != n_cat_features {
1145 return Err("SaintModel: cat_vocab_sizes.len() != n_cat_features".into());
1146 }
1147 let mut rng = StdRng::seed_from_u64(seed);
1148 let cat_embeddings = cat_vocab_sizes
1149 .iter()
1150 .map(|&v| kaiming_uniform(v * d_model, d_model, &mut rng))
1151 .collect();
1152 let num_w = (0..n_num_features)
1153 .map(|_| xavier_uniform(d_model, 1, d_model, &mut rng))
1154 .collect();
1155 let num_b = (0..n_num_features).map(|_| zeros(d_model)).collect();
1156 let blocks = (0..n_blocks)
1157 .map(|i| SaintBlock::new(d_model, n_heads, ffn_dim, seed.wrapping_add(i as u64 + 1)))
1158 .collect::<TabResult<Vec<_>>>()?;
1159 let n_features = n_cat_features + n_num_features;
1160 let head_w = xavier_uniform(
1161 n_classes * n_features * d_model,
1162 n_features * d_model,
1163 n_classes,
1164 &mut rng,
1165 );
1166 let head_b = zeros(n_classes);
1167 Ok(Self {
1168 d_model,
1169 n_blocks,
1170 cat_embeddings,
1171 n_cat_features,
1172 n_num_features,
1173 cat_vocab_sizes,
1174 num_w,
1175 num_b,
1176 blocks,
1177 head_w,
1178 head_b,
1179 n_classes,
1180 })
1181 }
1182
1183 pub fn forward(&self, cat_ids: &[usize], num_features: &[f32]) -> TabResult<Vec<f32>> {
1185 let d = self.d_model;
1186 let n_features = self.n_cat_features + self.n_num_features;
1187
1188 let mut tokens = vec![0.0_f32; n_features * d];
1190 for (i, &id) in cat_ids.iter().enumerate() {
1191 let v = self.cat_vocab_sizes[i];
1192 let clamped = id.min(v.saturating_sub(1));
1193 let emb = &self.cat_embeddings[i][clamped * d..(clamped + 1) * d];
1194 tokens[i * d..(i + 1) * d].copy_from_slice(emb);
1195 }
1196 for (i, &x) in num_features.iter().enumerate() {
1197 let offset = (self.n_cat_features + i) * d;
1198 for j in 0..d {
1199 tokens[offset + j] = x * self.num_w[i][j] + self.num_b[i][j];
1200 }
1201 }
1202
1203 let mut seq = tokens;
1205 for block in &self.blocks {
1206 seq = block.intra_feature_attention(&seq)?;
1207 }
1208
1209 linear(&self.head_w, &self.head_b, &seq)
1211 }
1212}
1213
/// Per-feature standardisation to zero mean / unit variance, fitted on rows.
#[derive(Debug, Clone, Default)]
pub struct StandardScaler {
    /// Fitted per-feature means.
    pub mean: Vec<f32>,
    /// Fitted per-feature standard deviations (population, eps = 1e-7).
    pub std: Vec<f32>,
}
1224
1225impl StandardScaler {
1226 pub fn fit(&mut self, data: &[&[f32]]) {
1228 if data.is_empty() {
1229 return;
1230 }
1231 let n_features = data[0].len();
1232 let n = data.len() as f32;
1233 self.mean = vec![0.0_f32; n_features];
1234 self.std = vec![1.0_f32; n_features];
1235 for row in data {
1236 for (i, &v) in row.iter().enumerate() {
1237 if i < n_features {
1238 self.mean[i] += v;
1239 }
1240 }
1241 }
1242 for m in self.mean.iter_mut() {
1243 *m /= n;
1244 }
1245 let mut var = vec![0.0_f32; n_features];
1246 for row in data {
1247 for (i, &v) in row.iter().enumerate() {
1248 if i < n_features {
1249 var[i] += (v - self.mean[i]).powi(2);
1250 }
1251 }
1252 }
1253 for (i, v) in var.iter().enumerate() {
1254 self.std[i] = (v / n.max(1.0) + 1e-7).sqrt();
1255 }
1256 }
1257
1258 pub fn transform(&self, x: &[f32]) -> Vec<f32> {
1260 x.iter()
1261 .enumerate()
1262 .map(|(i, &v)| {
1263 let m = self.mean.get(i).copied().unwrap_or(0.0);
1264 let s = self.std.get(i).copied().unwrap_or(1.0);
1265 (v - m) / s.max(1e-7)
1266 })
1267 .collect()
1268 }
1269}
1270
/// Per-feature min-max scaling to approximately [0, 1], fitted on rows.
#[derive(Debug, Clone, Default)]
pub struct MinMaxScaler {
    /// Fitted per-feature minimum.
    pub min: Vec<f32>,
    /// Fitted per-feature range (max - min), floored at 1e-7.
    pub range: Vec<f32>,
}
1277
1278impl MinMaxScaler {
1279 pub fn fit(&mut self, data: &[&[f32]]) {
1281 if data.is_empty() {
1282 return;
1283 }
1284 let n_features = data[0].len();
1285 let mut mins = vec![f32::INFINITY; n_features];
1286 let mut maxs = vec![f32::NEG_INFINITY; n_features];
1287 for row in data {
1288 for (i, &v) in row.iter().enumerate() {
1289 if i < n_features {
1290 if v < mins[i] {
1291 mins[i] = v;
1292 }
1293 if v > maxs[i] {
1294 maxs[i] = v;
1295 }
1296 }
1297 }
1298 }
1299 self.range = mins
1300 .iter()
1301 .zip(maxs.iter())
1302 .map(|(&lo, &hi)| (hi - lo).max(1e-7))
1303 .collect();
1304 self.min = mins;
1305 }
1306
1307 pub fn transform(&self, x: &[f32]) -> Vec<f32> {
1309 x.iter()
1310 .enumerate()
1311 .map(|(i, &v)| {
1312 let lo = self.min.get(i).copied().unwrap_or(0.0);
1313 let r = self.range.get(i).copied().unwrap_or(1.0);
1314 (v - lo) / r
1315 })
1316 .collect()
1317 }
1318}
1319
/// Maps features to an approximately standard-normal distribution via the
/// empirical CDF (rank among fitted values) followed by the probit function.
#[derive(Debug, Clone, Default)]
pub struct QuantileTransformer {
    /// Sorted training values per feature (the empirical CDF support).
    pub quantiles: Vec<Vec<f32>>,
}
1326
1327impl QuantileTransformer {
1328 pub fn fit(&mut self, data: &[&[f32]]) {
1330 if data.is_empty() {
1331 return;
1332 }
1333 let n_features = data[0].len();
1334 self.quantiles = vec![Vec::new(); n_features];
1335 for feat_idx in 0..n_features {
1336 let mut vals: Vec<f32> = data
1337 .iter()
1338 .filter_map(|row| row.get(feat_idx).copied())
1339 .collect();
1340 vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1341 self.quantiles[feat_idx] = vals;
1342 }
1343 }
1344
1345 pub fn transform(&self, x: &[f32]) -> Vec<f32> {
1347 x.iter()
1348 .enumerate()
1349 .map(|(i, &v)| {
1350 if i >= self.quantiles.len() || self.quantiles[i].is_empty() {
1351 return v;
1352 }
1353 let q = &self.quantiles[i];
1354 let n = q.len() as f32;
1355 let rank = q.partition_point(|&s| s <= v);
1357 let p = (rank as f32 + 0.5) / (n + 1.0);
1358 let p_clamped = p.clamp(1e-6, 1.0 - 1e-6);
1359 probit(p_clamped)
1361 })
1362 .collect()
1363 }
1364}
1365
1366fn probit(p: f32) -> f32 {
1368 let q = p - 0.5;
1370 if q.abs() < 0.425 {
1371 let r = 0.180625 - q * q;
1372 let num = ((2.509_081_f32 * r + 33.143_f32) * r + 85.44_f32) * r + 45.41_f32;
1373 let den = ((r + 15.159_f32) * r + 29.891_f32) * r + 1.0;
1374 q * (num / den)
1375 } else {
1376 let r = if q < 0.0 { p } else { 1.0 - p };
1377 let lr = (-r.ln()).sqrt().clamp(0.0, 10.0);
1378 let sign = if q < 0.0 { -1.0_f32 } else { 1.0_f32 };
1379 let num = (1.4234_f32 * lr + 4.6233_f32) * lr + 0.6806_f32;
1380 let den = (lr + 3.6575_f32) * lr + 1.0_f32;
1381 sign * (num / den)
1382 }
1383}
1384
/// Encodes periodic features (hour of day, month, …) as (sin, cos) pairs so
/// that values near the period boundary map to nearby points.
#[derive(Debug, Clone)]
pub struct CyclicEncoder {
    /// Period of each input feature; missing entries default to 1.0.
    pub periods: Vec<f32>,
}

impl CyclicEncoder {
    /// Creates an encoder with one period per feature.
    pub fn new(periods: Vec<f32>) -> Self {
        Self { periods }
    }

    /// Emits `[sin(2πx/T), cos(2πx/T)]` per feature, so the output holds
    /// `2 * x.len()` values. Periods are floored at 1e-7 to avoid division
    /// by zero.
    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
        x.iter()
            .enumerate()
            .flat_map(|(i, &v)| {
                let period = self.periods.get(i).copied().unwrap_or(1.0).max(1e-7);
                let angle = 2.0 * PI * v / period;
                [angle.sin(), angle.cos()]
            })
            .collect()
    }
}
1411
1412#[derive(Debug, Clone)]
1414pub struct FeatureEncoder {
1415 pub scaler: StandardScaler,
1416 pub minmax: MinMaxScaler,
1417 pub quantile: QuantileTransformer,
1418 pub cyclic: Option<CyclicEncoder>,
1419}
1420
1421impl FeatureEncoder {
1422 pub fn new(cyclic_periods: Option<Vec<f32>>) -> Self {
1424 Self {
1425 scaler: StandardScaler::default(),
1426 minmax: MinMaxScaler::default(),
1427 quantile: QuantileTransformer::default(),
1428 cyclic: cyclic_periods.map(CyclicEncoder::new),
1429 }
1430 }
1431
1432 pub fn fit(&mut self, data: &[&[f32]]) {
1434 self.scaler.fit(data);
1435 self.minmax.fit(data);
1436 self.quantile.fit(data);
1437 }
1438
1439 pub fn transform(&self, x: &[f32]) -> Vec<f32> {
1441 self.scaler.transform(x)
1442 }
1443
1444 pub fn quantile_transform(&self, x: &[f32]) -> Vec<f32> {
1446 self.quantile.transform(x)
1447 }
1448}
1449
1450#[derive(Debug, Clone)]
1459pub struct MixedInputHead {
1460 cat_dim: usize,
1461 num_dim: usize,
1462 out_dim: usize,
1463 gate_w: Vec<f32>,
1465 gate_b: Vec<f32>,
1466 cat_proj_w: Vec<f32>,
1468 cat_proj_b: Vec<f32>,
1469 num_proj_w: Vec<f32>,
1471 num_proj_b: Vec<f32>,
1472}
1473
1474impl MixedInputHead {
1475 pub fn new(cat_dim: usize, num_dim: usize, out_dim: usize, seed: u64) -> Self {
1477 let mut rng = StdRng::seed_from_u64(seed);
1478 let joint_dim = cat_dim + num_dim;
1479 Self {
1480 cat_dim,
1481 num_dim,
1482 out_dim,
1483 gate_w: xavier_uniform(out_dim * joint_dim, joint_dim, out_dim, &mut rng),
1484 gate_b: zeros(out_dim),
1485 cat_proj_w: xavier_uniform(out_dim * cat_dim, cat_dim, out_dim, &mut rng),
1486 cat_proj_b: zeros(out_dim),
1487 num_proj_w: xavier_uniform(out_dim * num_dim, num_dim, out_dim, &mut rng),
1488 num_proj_b: zeros(out_dim),
1489 }
1490 }
1491
1492 pub fn gate(&self, cat_repr: &[f32], num_repr: &[f32]) -> TabResult<Vec<f32>> {
1494 if cat_repr.len() != self.cat_dim {
1495 return Err(format!(
1496 "MixedInputHead: cat_repr.len()={} != cat_dim={}",
1497 cat_repr.len(),
1498 self.cat_dim
1499 ));
1500 }
1501 if num_repr.len() != self.num_dim {
1502 return Err(format!(
1503 "MixedInputHead: num_repr.len()={} != num_dim={}",
1504 num_repr.len(),
1505 self.num_dim
1506 ));
1507 }
1508 let joint: Vec<f32> = cat_repr.iter().chain(num_repr.iter()).copied().collect();
1509 let gate_logits = linear(&self.gate_w, &self.gate_b, &joint)?;
1510 let g: Vec<f32> = gate_logits.iter().map(|&v| sigmoid(v)).collect();
1511
1512 let cat_out = linear(&self.cat_proj_w, &self.cat_proj_b, cat_repr)?;
1513 let num_out = linear(&self.num_proj_w, &self.num_proj_b, num_repr)?;
1514
1515 Ok(g.iter()
1516 .zip(cat_out.iter())
1517 .zip(num_out.iter())
1518 .map(|((&gi, &ci), &ni)| gi * ci + (1.0 - gi) * ni)
1519 .collect())
1520 }
1521}
1522
/// Namespace for data-augmentation routines on tabular feature vectors.
pub struct TabularAugmentation;

impl TabularAugmentation {
    /// Mixup: returns `λ·x1 + (1−λ)·x2` together with the mixing weight λ,
    /// where λ comes from [`Self::beta_sample`] (an approximate Beta(α, α)).
    ///
    /// Errors when the two vectors differ in length.
    pub fn mixup(
        x1: &[f32],
        x2: &[f32],
        alpha: f32,
        rng: &mut StdRng,
    ) -> TabResult<(Vec<f32>, f32)> {
        if x1.len() != x2.len() {
            return Err(format!("mixup: len mismatch {} vs {}", x1.len(), x2.len()));
        }
        let lambda = Self::beta_sample(alpha, rng);
        let mixed: Vec<f32> = x1
            .iter()
            .zip(x2.iter())
            .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
            .collect();
        Ok((mixed, lambda))
    }

    /// CutMix for 1-D feature vectors: copies a random contiguous span of
    /// `x2` into a copy of `x1`. Returns the mixed vector and the effective
    /// λ — recomputed from the actual span length (which can differ from the
    /// sampled one because of rounding/clipping), i.e. the fraction of `x1`
    /// that survived.
    ///
    /// Errors when the two vectors differ in length.
    pub fn cutmix(
        x1: &[f32],
        x2: &[f32],
        alpha: f32,
        rng: &mut StdRng,
    ) -> TabResult<(Vec<f32>, f32)> {
        if x1.len() != x2.len() {
            return Err(format!("cutmix: len mismatch {} vs {}", x1.len(), x2.len()));
        }
        let lambda = Self::beta_sample(alpha, rng);
        let n = x1.len();
        // Span length implied by the sampled λ.
        let cut_len = (n as f32 * (1.0 - lambda)).round() as usize;
        let start_f: f32 = rng.random();
        // Uniform start in 0..=n−cut_len; saturating_sub keeps this safe if
        // rounding ever makes cut_len exceed n.
        let start = (start_f * (n.saturating_sub(cut_len) + 1) as f32) as usize;
        let end = (start + cut_len).min(n);

        let mut mixed = x1.to_vec();
        mixed[start..end].copy_from_slice(&x2[start..end]);
        // Effective λ from the span actually replaced (max(1) guards n == 0).
        let actual_lambda = 1.0 - (end - start) as f32 / n.max(1) as f32;
        Ok((mixed, actual_lambda))
    }

    /// SMOTE-style interpolation: picks a random neighbour and returns a
    /// point uniformly interpolated between `sample` and that neighbour.
    ///
    /// Errors when `neighbours` is empty or the chosen neighbour's length
    /// differs from `sample`'s.
    pub fn smote_like(
        sample: &[f32],
        neighbours: &[Vec<f32>],
        rng: &mut StdRng,
    ) -> TabResult<Vec<f32>> {
        if neighbours.is_empty() {
            return Err("smote_like: no neighbours provided".into());
        }
        let idx_f: f32 = rng.random();
        let idx = (idx_f * neighbours.len() as f32) as usize;
        // Clamp guards the (boundary) case where idx_f scales to len.
        let idx = idx.min(neighbours.len() - 1);
        let neighbour = &neighbours[idx];
        if sample.len() != neighbour.len() {
            return Err(format!(
                "smote_like: sample.len()={} != neighbour.len()={}",
                sample.len(),
                neighbour.len()
            ));
        }
        // gap ∈ [0, 1): position along the segment sample → neighbour.
        let gap: f32 = rng.random();
        Ok(sample
            .iter()
            .zip(neighbour.iter())
            .map(|(&s, &n)| s + gap * (n - s))
            .collect())
    }

    /// Approximates a Beta(α, α) draw with a moment-matched Gaussian:
    /// mean 0.5 and σ = 1/(2√(2α+1)) — the standard deviation of
    /// Beta(α, α) — sampled via Box–Muller and clamped to [0, 1].
    /// α is floored at 0.01 to keep σ finite.
    fn beta_sample(alpha: f32, rng: &mut StdRng) -> f32 {
        let alpha = alpha.max(0.01);
        // max(1e-7) avoids ln(0) in the Box–Muller transform.
        let u1: f32 = rng.random::<f32>().max(1e-7);
        let u2: f32 = rng.random();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
        let sigma = 1.0 / (2.0 * (2.0 * alpha + 1.0).sqrt());
        (0.5 + sigma * z).clamp(0.0, 1.0)
    }
}
1615
/// Namespace for classification and regression evaluation metrics.
pub struct TabularMetrics;

impl TabularMetrics {
    /// Top-1 accuracy over raw logits.
    ///
    /// Returns 0.0 for empty input or when `pred_logits` and `target`
    /// lengths disagree. Argmax ties resolve to the last maximal index; an
    /// empty logit row counts as class 0.
    pub fn accuracy(pred_logits: &[Vec<f32>], target: &[usize]) -> f32 {
        if pred_logits.is_empty() || pred_logits.len() != target.len() {
            return 0.0;
        }
        let correct = pred_logits
            .iter()
            .zip(target.iter())
            .filter(|(logits, &t)| Self::argmax(logits) == t)
            .count();
        correct as f32 / pred_logits.len() as f32
    }

    /// Macro-averaged F1 over `n_classes` classes.
    ///
    /// Returns 0.0 for empty input, zero classes, or a `pred`/`target`
    /// length mismatch. Classes with no support contribute F1 = 0 to the
    /// average. Targets or predicted classes ≥ `n_classes` are ignored in
    /// the per-class counts.
    pub fn macro_f1(pred: &[Vec<f32>], target: &[usize], n_classes: usize) -> f32 {
        // Length guard added for consistency with `accuracy`; previously a
        // mismatch silently truncated to the shorter input via `zip`.
        if pred.is_empty() || n_classes == 0 || pred.len() != target.len() {
            return 0.0;
        }
        let mut tp = vec![0u32; n_classes];
        let mut fp = vec![0u32; n_classes];
        let mut fn_ = vec![0u32; n_classes];

        for (logits, &t) in pred.iter().zip(target.iter()) {
            let pred_class = Self::argmax(logits);
            if pred_class == t {
                if t < n_classes {
                    tp[t] += 1;
                }
            } else {
                if pred_class < n_classes {
                    fp[pred_class] += 1;
                }
                if t < n_classes {
                    fn_[t] += 1;
                }
            }
        }

        let mut f1_sum = 0.0_f32;
        for c in 0..n_classes {
            // `.max(1)` avoids 0/0; an unsupported class yields
            // precision = recall = 0 and hence F1 = 0.
            let precision = tp[c] as f32 / (tp[c] + fp[c]).max(1) as f32;
            let recall = tp[c] as f32 / (tp[c] + fn_[c]).max(1) as f32;
            let denom = precision + recall;
            let f1 = if denom > 0.0 {
                2.0 * precision * recall / denom
            } else {
                0.0
            };
            f1_sum += f1;
        }
        f1_sum / n_classes as f32
    }

    /// Root-mean-square error; NaN for empty or mismatched inputs.
    pub fn rmse(pred: &[f32], target: &[f32]) -> f32 {
        if pred.is_empty() || pred.len() != target.len() {
            return f32::NAN;
        }
        let mse = pred
            .iter()
            .zip(target.iter())
            .map(|(&p, &t)| (p - t).powi(2))
            .sum::<f32>()
            / pred.len() as f32;
        mse.sqrt()
    }

    /// Coefficient of determination R²; NaN for empty or mismatched inputs.
    ///
    /// When the target is numerically constant (total sum of squares below
    /// 1e-12), returns 1.0 for a matching prediction and 0.0 otherwise
    /// instead of dividing by ~zero.
    pub fn r2_score(pred: &[f32], target: &[f32]) -> f32 {
        if pred.is_empty() || pred.len() != target.len() {
            return f32::NAN;
        }
        let mean_t = target.iter().sum::<f32>() / target.len() as f32;
        let ss_tot: f32 = target.iter().map(|&t| (t - mean_t).powi(2)).sum();
        let ss_res: f32 = pred
            .iter()
            .zip(target.iter())
            .map(|(&p, &t)| (t - p).powi(2))
            .sum();
        if ss_tot < 1e-12 {
            return if ss_res < 1e-12 { 1.0 } else { 0.0 };
        }
        1.0 - ss_res / ss_tot
    }

    /// Index of the largest logit; ties go to the last maximal index and an
    /// empty slice maps to 0 (matching the previous inline behaviour, now
    /// shared by `accuracy` and `macro_f1`).
    fn argmax(logits: &[f32]) -> usize {
        logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
}
1723
/// Target (mean) encoder for categorical features with leave-one-out
/// smoothing during fitting, in the spirit of CatBoost target statistics.
#[derive(Debug, Clone, Default)]
pub struct CatBoostEncoder {
    /// Per-category (observation count, target sum) gathered during fitting.
    pub category_stats: std::collections::HashMap<usize, (usize, f64)>,
    /// Global fallback value used for unseen categories.
    pub prior: f32,
}

impl CatBoostEncoder {
    /// Creates an empty encoder (no statistics, prior 0.0).
    pub fn new() -> Self {
        Self::default()
    }

    /// Fits on `(category, target)` pairs and returns the leave-one-out
    /// smoothed encoding of the training rows themselves.
    ///
    /// Each row is encoded as `λ·loo_mean + (1 − λ)·prior`, where the
    /// leave-one-out mean excludes the row's own target and λ = m/(m + 1)
    /// with m the number of *other* rows sharing the category. Returns an
    /// empty vector for empty or mismatched inputs.
    pub fn fit_transform(&mut self, categories: &[usize], targets: &[f32], prior: f32) -> Vec<f32> {
        if categories.len() != targets.len() || categories.is_empty() {
            return Vec::new();
        }
        self.prior = prior;
        self.category_stats.clear();

        // First pass: accumulate count and target sum per category.
        for (&cat, &t) in categories.iter().zip(targets.iter()) {
            let stats = self.category_stats.entry(cat).or_insert((0, 0.0));
            stats.0 += 1;
            stats.1 += f64::from(t);
        }

        // Second pass: encode each row from statistics excluding itself.
        let mut encoded = Vec::with_capacity(categories.len());
        for (&cat, &t) in categories.iter().zip(targets.iter()) {
            let (count, sum) = self.category_stats.get(&cat).copied().unwrap_or((0, 0.0));
            let others = count.saturating_sub(1);
            let shrink = others as f32 / (others as f32 + 1.0);
            // Singleton categories have no leave-one-out evidence → prior.
            let loo_mean = if others == 0 {
                prior
            } else {
                ((sum - f64::from(t)) / others as f64) as f32
            };
            encoded.push(shrink * loo_mean + (1.0 - shrink) * prior);
        }
        encoded
    }

    /// Encodes categories with the full (non-leave-one-out) statistics,
    /// shrunk toward the prior by λ = count/(count + 1); unseen categories
    /// map to the prior.
    pub fn transform(&self, categories: &[usize]) -> Vec<f32> {
        categories
            .iter()
            .map(|&cat| {
                if let Some(&(count, sum)) = self.category_stats.get(&cat) {
                    let mean = if count == 0 {
                        self.prior as f64
                    } else {
                        sum / count as f64
                    };
                    let shrink = count as f32 / (count as f32 + 1.0);
                    shrink * mean as f32 + (1.0 - shrink) * self.prior
                } else {
                    self.prior
                }
            })
            .collect()
    }
}
1802
1803#[cfg(test)]
1808mod tests;