// tenflowers_neural/tabular_learning/mod.rs

//! Deep Learning for Tabular Data.
//!
//! Implements state-of-the-art architectures for structured / tabular data:
//!
//! | Component | Reference |
//! |-----------|-----------|
//! | [`TabTransformer`] | Huang et al. (2020) — TabTransformer |
//! | [`FTTransformer`] | Gorishniy et al. (2021) — FT-Transformer |
//! | [`NodeModel`] | Popov et al. (2020) — NODE (oblivious decision trees) |
//! | [`TabNet`] | Arik & Pfister (2021) — TabNet |
//! | [`SaintModel`] | Somepalli et al. (2021) — SAINT |
//! | [`FeatureEncoder`] | preprocessing pipeline (scaler, quantile, cyclic) |
//! | [`MixedInputHead`] | gated fusion of categorical + numeric representations |
//! | [`TabularAugmentation`] | Mixup / CutMix / SMOTE-like oversampling |
//! | [`TabularMetrics`] | accuracy, macro-F1, RMSE, R² |
//! | [`CatBoostEncoder`] | leave-one-out target encoding |
//!
//! All weights are `Vec<f32>` buffers; no `Tensor`/autograd dependency.
//! No `unwrap()`, no `unsafe`, single file ≤ 1900 lines.

use std::f32::consts::PI;

use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
use scirs2_core::RngExt;

// ─────────────────────────────────────────────────────────────────────────────
// Shared error / result type
// ─────────────────────────────────────────────────────────────────────────────

30type TabResult<T> = Result<T, String>;
31
// ─────────────────────────────────────────────────────────────────────────────
// Shared math helpers
// ─────────────────────────────────────────────────────────────────────────────

36#[inline]
37fn relu(x: f32) -> f32 {
38    x.max(0.0)
39}
40
41#[inline]
42fn gelu(x: f32) -> f32 {
43    0.5 * x * (1.0 + (x * 0.797_884_6 * (1.0 + 0.044715 * x * x)).tanh())
44}
45
46#[inline]
47fn sigmoid(x: f32) -> f32 {
48    let c = x.clamp(-88.0, 88.0);
49    1.0 / (1.0 + (-c).exp())
50}
51
52/// Softmax over a slice; returns a new `Vec`.
53fn softmax(v: &[f32]) -> Vec<f32> {
54    if v.is_empty() {
55        return Vec::new();
56    }
57    let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
58    let exps: Vec<f32> = v.iter().map(|&x| (x - max).exp()).collect();
59    let sum: f32 = exps.iter().sum::<f32>().max(1e-12);
60    exps.iter().map(|&e| e / sum).collect()
61}
62
63/// Sparsemax: projects onto the probability simplex.
64fn sparsemax(z: &[f32]) -> Vec<f32> {
65    let n = z.len();
66    if n == 0 {
67        return Vec::new();
68    }
69    let mut sorted = z.to_vec();
70    sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
71    let mut cumsum = 0.0_f32;
72    let mut k = n;
73    for (i, &s) in sorted.iter().enumerate() {
74        cumsum += s;
75        if s > (cumsum - 1.0) / (i + 1) as f32 {
76            k = i + 1;
77        }
78    }
79    let tau = (sorted[..k].iter().sum::<f32>() - 1.0) / k as f32;
80    z.iter().map(|&zi| (zi - tau).max(0.0)).collect()
81}
82
83/// Dense layer forward: `y = W·x + b`.
84/// `w` is row-major `[out × in]`.
85fn linear(w: &[f32], b: &[f32], x: &[f32]) -> TabResult<Vec<f32>> {
86    let in_dim = x.len();
87    let out_dim = b.len();
88    if w.len() != out_dim * in_dim {
89        return Err(format!(
90            "linear: w.len()={} != out×in={}×{}",
91            w.len(),
92            out_dim,
93            in_dim
94        ));
95    }
96    let mut y = vec![0.0_f32; out_dim];
97    for o in 0..out_dim {
98        let row = &w[o * in_dim..(o + 1) * in_dim];
99        y[o] = b[o]
100            + row
101                .iter()
102                .zip(x.iter())
103                .map(|(&wi, &xi)| wi * xi)
104                .sum::<f32>();
105    }
106    Ok(y)
107}
108
109/// Layer normalisation: `(x − mean) / (std + ε) * γ + β`.
110fn layer_norm(x: &[f32], gamma: &[f32], beta: &[f32]) -> TabResult<Vec<f32>> {
111    let n = x.len();
112    if gamma.len() != n || beta.len() != n {
113        return Err(format!(
114            "layer_norm: dim mismatch x={n}, γ={}, β={}",
115            gamma.len(),
116            beta.len()
117        ));
118    }
119    let mean = x.iter().copied().sum::<f32>() / n as f32;
120    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n as f32;
121    let std_inv = (var + 1e-5_f32).sqrt().recip();
122    Ok(x.iter()
123        .enumerate()
124        .map(|(i, &v)| (v - mean) * std_inv * gamma[i] + beta[i])
125        .collect())
126}
127
128/// Xavier uniform initialisation.
129fn xavier_uniform(size: usize, fan_in: usize, fan_out: usize, rng: &mut StdRng) -> Vec<f32> {
130    let bound = (6.0_f32 / (fan_in + fan_out).max(1) as f32).sqrt();
131    (0..size)
132        .map(|_| {
133            let u: f32 = rng.random();
134            2.0 * bound * u - bound
135        })
136        .collect()
137}
138
139/// Kaiming uniform initialisation.
140fn kaiming_uniform(size: usize, fan_in: usize, rng: &mut StdRng) -> Vec<f32> {
141    let bound = (2.0_f32 / fan_in.max(1) as f32).sqrt();
142    (0..size)
143        .map(|_| {
144            let u: f32 = rng.random();
145            2.0 * bound * u - bound
146        })
147        .collect()
148}
149
150/// Zeros initialisation.
151fn zeros(size: usize) -> Vec<f32> {
152    vec![0.0_f32; size]
153}
154
155/// Ones initialisation.
156fn ones(size: usize) -> Vec<f32> {
157    vec![1.0_f32; size]
158}
159
160/// Scaled dot-product attention over a sequence of [seq_len × d_model] vectors.
161/// Returns [seq_len × d_model].
162fn scaled_dot_product_attn(
163    q: &[f32],
164    k: &[f32],
165    v: &[f32],
166    seq_len: usize,
167    d_model: usize,
168    n_heads: usize,
169    wq: &[f32],
170    wk: &[f32],
171    wv: &[f32],
172    wo: &[f32],
173) -> TabResult<Vec<f32>> {
174    if n_heads == 0 || d_model % n_heads != 0 {
175        return Err(format!(
176            "d_model={d_model} not divisible by n_heads={n_heads}"
177        ));
178    }
179    let dh = d_model / n_heads;
180    let scale = (dh as f32).sqrt().recip();
181
182    // Project: [seq_len × d_model] → [seq_len × d_model] for Q, K, V
183    // wq/wk/wv: [d_model × d_model]
184    let project = |w: &[f32], inp: &[f32]| -> TabResult<Vec<f32>> {
185        let mut out = vec![0.0_f32; seq_len * d_model];
186        for s in 0..seq_len {
187            for o in 0..d_model {
188                let mut acc = 0.0_f32;
189                for i in 0..d_model {
190                    acc += w[o * d_model + i] * inp[s * d_model + i];
191                }
192                out[s * d_model + o] = acc;
193            }
194        }
195        Ok(out)
196    };
197
198    let pq = project(wq, q)?;
199    let pk = project(wk, k)?;
200    let pv = project(wv, v)?;
201
202    let mut output = vec![0.0_f32; seq_len * d_model];
203
204    for h in 0..n_heads {
205        let offset = h * dh;
206        // Compute attention scores [seq_len × seq_len]
207        let mut scores = vec![0.0_f32; seq_len * seq_len];
208        for i in 0..seq_len {
209            for j in 0..seq_len {
210                let mut dot = 0.0_f32;
211                for d in 0..dh {
212                    dot += pq[i * d_model + offset + d] * pk[j * d_model + offset + d];
213                }
214                scores[i * seq_len + j] = dot * scale;
215            }
216        }
217        // Softmax row-wise
218        for i in 0..seq_len {
219            let row = softmax(&scores[i * seq_len..(i + 1) * seq_len]);
220            scores[i * seq_len..(i + 1) * seq_len].copy_from_slice(&row);
221        }
222        // Weighted sum of V
223        for i in 0..seq_len {
224            for d in 0..dh {
225                let mut acc = 0.0_f32;
226                for j in 0..seq_len {
227                    acc += scores[i * seq_len + j] * pv[j * d_model + offset + d];
228                }
229                output[i * d_model + offset + d] = acc;
230            }
231        }
232    }
233
234    // Output projection: [seq_len × d_model] × W_O [d_model × d_model]
235    let mut result = vec![0.0_f32; seq_len * d_model];
236    for s in 0..seq_len {
237        for o in 0..d_model {
238            let mut acc = 0.0_f32;
239            for i in 0..d_model {
240                acc += wo[o * d_model + i] * output[s * d_model + i];
241            }
242            result[s * d_model + o] = acc;
243        }
244    }
245    Ok(result)
246}
247
// ─────────────────────────────────────────────────────────────────────────────
// ══════════════════════  TabTransformer  ════════════════════════════════════
// ─────────────────────────────────────────────────────────────────────────────

252/// Configuration for [`TabTransformer`].
253#[derive(Debug, Clone)]
254pub struct TabTransformerConfig {
255    /// Number of categorical features.
256    pub n_cat_features: usize,
257    /// Number of numeric (continuous) features.
258    pub n_num_features: usize,
259    /// Vocabulary size for each categorical feature.
260    pub cat_vocab_sizes: Vec<usize>,
261    /// Embedding dimension for categorical features.
262    pub embed_dim: usize,
263    /// Number of attention heads per transformer layer.
264    pub n_heads: usize,
265    /// Number of transformer encoder layers.
266    pub n_layers: usize,
267    /// Feed-forward network hidden dimension.
268    pub ffn_dim: usize,
269    /// Number of output classes (1 for regression).
270    pub n_classes: usize,
271}
272
273/// Transformer for tabular data (Huang et al., 2020).
274///
275/// Categorical features are embedded and processed through column-wise
276/// multi-head self-attention layers. Numeric features are concatenated
277/// after the transformer and passed through an MLP classification head.
278#[derive(Debug, Clone)]
279pub struct TabTransformer {
280    cfg: TabTransformerConfig,
281    /// Embedding tables: one per categorical feature, each [vocab × embed_dim].
282    embeddings: Vec<Vec<f32>>,
283    /// Transformer layers (each has WQ, WK, WV, WO, FFN W1/b1/W2/b2, LN params).
284    layers: Vec<TabTransformerLayer>,
285    /// Final LN: gamma, beta over embed_dim.
286    final_ln_g: Vec<f32>,
287    final_ln_b: Vec<f32>,
288    /// MLP head: input = n_cat * embed_dim + n_num → hidden → n_classes.
289    head_w1: Vec<f32>,
290    head_b1: Vec<f32>,
291    head_w2: Vec<f32>,
292    head_b2: Vec<f32>,
293}
294
295#[derive(Debug, Clone)]
296struct TabTransformerLayer {
297    wq: Vec<f32>,
298    wk: Vec<f32>,
299    wv: Vec<f32>,
300    wo: Vec<f32>,
301    ln1_g: Vec<f32>,
302    ln1_b: Vec<f32>,
303    ffn_w1: Vec<f32>,
304    ffn_b1: Vec<f32>,
305    ffn_w2: Vec<f32>,
306    ffn_b2: Vec<f32>,
307    ln2_g: Vec<f32>,
308    ln2_b: Vec<f32>,
309}
310
311impl TabTransformer {
312    /// Create a new `TabTransformer` with Xavier-initialised weights.
313    pub fn new(cfg: TabTransformerConfig, seed: u64) -> TabResult<Self> {
314        if cfg.cat_vocab_sizes.len() != cfg.n_cat_features {
315            return Err(format!(
316                "TabTransformer: cat_vocab_sizes.len()={} != n_cat_features={}",
317                cfg.cat_vocab_sizes.len(),
318                cfg.n_cat_features
319            ));
320        }
321        let mut rng = StdRng::seed_from_u64(seed);
322        let d = cfg.embed_dim;
323
324        let embeddings = cfg
325            .cat_vocab_sizes
326            .iter()
327            .map(|&v| kaiming_uniform(v * d, d, &mut rng))
328            .collect();
329
330        let layers = (0..cfg.n_layers)
331            .map(|_| TabTransformerLayer {
332                wq: xavier_uniform(d * d, d, d, &mut rng),
333                wk: xavier_uniform(d * d, d, d, &mut rng),
334                wv: xavier_uniform(d * d, d, d, &mut rng),
335                wo: xavier_uniform(d * d, d, d, &mut rng),
336                ln1_g: ones(d),
337                ln1_b: zeros(d),
338                ffn_w1: xavier_uniform(cfg.ffn_dim * d, d, cfg.ffn_dim, &mut rng),
339                ffn_b1: zeros(cfg.ffn_dim),
340                ffn_w2: xavier_uniform(d * cfg.ffn_dim, cfg.ffn_dim, d, &mut rng),
341                ffn_b2: zeros(d),
342                ln2_g: ones(d),
343                ln2_b: zeros(d),
344            })
345            .collect();
346
347        let head_in = cfg.n_cat_features * d + cfg.n_num_features;
348        let head_h = (head_in * 2).max(64);
349        Ok(Self {
350            embeddings,
351            layers,
352            final_ln_g: ones(d),
353            final_ln_b: zeros(d),
354            head_w1: xavier_uniform(head_h * head_in, head_in, head_h, &mut rng),
355            head_b1: zeros(head_h),
356            head_w2: xavier_uniform(cfg.n_classes * head_h, head_h, cfg.n_classes, &mut rng),
357            head_b2: zeros(cfg.n_classes),
358            cfg,
359        })
360    }
361
362    /// Forward pass. `cat_ids` length must equal `n_cat_features`;
363    /// `num_features` length must equal `n_num_features`.
364    pub fn forward(&self, cat_ids: &[usize], num_features: &[f32]) -> TabResult<Vec<f32>> {
365        let d = self.cfg.embed_dim;
366        if cat_ids.len() != self.cfg.n_cat_features {
367            return Err(format!(
368                "TabTransformer: expected {} cat ids, got {}",
369                self.cfg.n_cat_features,
370                cat_ids.len()
371            ));
372        }
373        if num_features.len() != self.cfg.n_num_features {
374            return Err(format!(
375                "TabTransformer: expected {} num features, got {}",
376                self.cfg.n_num_features,
377                num_features.len()
378            ));
379        }
380
381        // Embed each categorical feature → [n_cat × embed_dim]
382        let seq_len = self.cfg.n_cat_features;
383        let mut seq = vec![0.0_f32; seq_len * d];
384        for (i, &id) in cat_ids.iter().enumerate() {
385            let v = self.cfg.cat_vocab_sizes[i];
386            let clamped = id.min(v.saturating_sub(1));
387            let emb = &self.embeddings[i][clamped * d..(clamped + 1) * d];
388            seq[i * d..(i + 1) * d].copy_from_slice(emb);
389        }
390
391        // Apply transformer layers
392        for layer in &self.layers {
393            let attn_out = scaled_dot_product_attn(
394                &seq,
395                &seq,
396                &seq,
397                seq_len,
398                d,
399                self.cfg.n_heads,
400                &layer.wq,
401                &layer.wk,
402                &layer.wv,
403                &layer.wo,
404            )?;
405            // Residual + LN
406            let mut h = vec![0.0_f32; seq_len * d];
407            for (i, (&a, &s)) in attn_out.iter().zip(seq.iter()).enumerate() {
408                h[i] = a + s;
409            }
410            let mut ln_out = vec![0.0_f32; seq_len * d];
411            for s in 0..seq_len {
412                let normed = layer_norm(&h[s * d..(s + 1) * d], &layer.ln1_g, &layer.ln1_b)?;
413                ln_out[s * d..(s + 1) * d].copy_from_slice(&normed);
414            }
415            // FFN per token
416            let mut ffn_out = vec![0.0_f32; seq_len * d];
417            for s in 0..seq_len {
418                let tok = &ln_out[s * d..(s + 1) * d];
419                let h1 = linear(&layer.ffn_w1, &layer.ffn_b1, tok)?;
420                let h1a: Vec<f32> = h1.iter().map(|&x| gelu(x)).collect();
421                let h2 = linear(&layer.ffn_w2, &layer.ffn_b2, &h1a)?;
422                // Residual + LN2
423                let res2: Vec<f32> = h2.iter().zip(tok.iter()).map(|(&a, &b)| a + b).collect();
424                let normed2 = layer_norm(&res2, &layer.ln2_g, &layer.ln2_b)?;
425                ffn_out[s * d..(s + 1) * d].copy_from_slice(&normed2);
426            }
427            seq = ffn_out;
428        }
429
430        // Pool: flatten cat embeddings
431        let mut head_input = vec![0.0_f32; self.cfg.n_cat_features * d + self.cfg.n_num_features];
432        head_input[..seq.len()].copy_from_slice(&seq);
433        head_input[seq.len()..].copy_from_slice(num_features);
434
435        let h1 = linear(&self.head_w1, &self.head_b1, &head_input)?;
436        let h1a: Vec<f32> = h1.iter().map(|&x| relu(x)).collect();
437        linear(&self.head_w2, &self.head_b2, &h1a)
438    }
439}
440
// ─────────────────────────────────────────────────────────────────────────────
// ══════════════════════  FT-Transformer  ════════════════════════════════════
// ─────────────────────────────────────────────────────────────────────────────

445/// Configuration for [`FTTransformer`].
446#[derive(Debug, Clone)]
447pub struct FTTransformerConfig {
448    /// Number of categorical features.
449    pub n_cat_features: usize,
450    /// Number of continuous numeric features.
451    pub n_num_features: usize,
452    /// Vocabulary sizes for each categorical feature.
453    pub cat_vocab_sizes: Vec<usize>,
454    /// Shared embedding dimension for all feature tokens.
455    pub embed_dim: usize,
456    /// Number of attention heads.
457    pub n_heads: usize,
458    /// Number of transformer encoder layers.
459    pub n_layers: usize,
460    /// FFN hidden dim.
461    pub ffn_dim: usize,
462    /// Number of output classes (1 = regression).
463    pub n_classes: usize,
464}
465
466/// Feature Tokenizer + Transformer (Gorishniy et al., 2021).
467///
468/// Both numeric and categorical features are projected into the same
469/// `embed_dim`-dimensional token space, then processed by a standard
470/// transformer encoder. The `[CLS]` token drives the output head.
471#[derive(Debug, Clone)]
472pub struct FTTransformer {
473    cfg: FTTransformerConfig,
474    /// Categorical embeddings [vocab × embed_dim] per feature.
475    cat_embeddings: Vec<Vec<f32>>,
476    /// Numeric tokenizer: one weight + bias per numeric feature → embed_dim.
477    num_w: Vec<Vec<f32>>, // [n_num × embed_dim]
478    num_b: Vec<Vec<f32>>, // [n_num × embed_dim]
479    /// Learnable [CLS] token embedding.
480    cls_token: Vec<f32>,
481    /// Transformer encoder layers (same structure as TabTransformer).
482    layers: Vec<TabTransformerLayer>,
483    /// Output head on CLS representation.
484    head_w: Vec<f32>,
485    head_b: Vec<f32>,
486}
487
488impl FTTransformer {
489    /// Create a new `FTTransformer` with Xavier-initialised weights.
490    pub fn new(cfg: FTTransformerConfig, seed: u64) -> TabResult<Self> {
491        if cfg.cat_vocab_sizes.len() != cfg.n_cat_features {
492            return Err(format!(
493                "FTTransformer: cat_vocab_sizes.len()={} != n_cat_features={}",
494                cfg.cat_vocab_sizes.len(),
495                cfg.n_cat_features
496            ));
497        }
498        let mut rng = StdRng::seed_from_u64(seed);
499        let d = cfg.embed_dim;
500
501        let cat_embeddings = cfg
502            .cat_vocab_sizes
503            .iter()
504            .map(|&v| kaiming_uniform(v * d, d, &mut rng))
505            .collect();
506
507        let num_w = (0..cfg.n_num_features)
508            .map(|_| xavier_uniform(d, 1, d, &mut rng))
509            .collect();
510        let num_b = (0..cfg.n_num_features).map(|_| zeros(d)).collect();
511
512        let cls_token: Vec<f32> = (0..d)
513            .map(|_| {
514                let u: f32 = rng.random();
515                u * 0.02 - 0.01
516            })
517            .collect();
518
519        let layers = (0..cfg.n_layers)
520            .map(|_| TabTransformerLayer {
521                wq: xavier_uniform(d * d, d, d, &mut rng),
522                wk: xavier_uniform(d * d, d, d, &mut rng),
523                wv: xavier_uniform(d * d, d, d, &mut rng),
524                wo: xavier_uniform(d * d, d, d, &mut rng),
525                ln1_g: ones(d),
526                ln1_b: zeros(d),
527                ffn_w1: xavier_uniform(cfg.ffn_dim * d, d, cfg.ffn_dim, &mut rng),
528                ffn_b1: zeros(cfg.ffn_dim),
529                ffn_w2: xavier_uniform(d * cfg.ffn_dim, cfg.ffn_dim, d, &mut rng),
530                ffn_b2: zeros(d),
531                ln2_g: ones(d),
532                ln2_b: zeros(d),
533            })
534            .collect();
535
536        Ok(Self {
537            cat_embeddings,
538            num_w,
539            num_b,
540            cls_token,
541            layers,
542            head_w: xavier_uniform(cfg.n_classes * d, d, cfg.n_classes, &mut rng),
543            head_b: zeros(cfg.n_classes),
544            cfg,
545        })
546    }
547
548    /// Forward pass. Returns logits of length `n_classes`.
549    pub fn forward(&self, cat_ids: &[usize], num_features: &[f32]) -> TabResult<Vec<f32>> {
550        let d = self.cfg.embed_dim;
551        if cat_ids.len() != self.cfg.n_cat_features {
552            return Err(format!(
553                "FTTransformer: expected {} cat ids, got {}",
554                self.cfg.n_cat_features,
555                cat_ids.len()
556            ));
557        }
558        if num_features.len() != self.cfg.n_num_features {
559            return Err(format!(
560                "FTTransformer: expected {} num features, got {}",
561                self.cfg.n_num_features,
562                num_features.len()
563            ));
564        }
565
566        let n_tokens = 1 + self.cfg.n_cat_features + self.cfg.n_num_features; // CLS + features
567        let mut tokens = vec![0.0_f32; n_tokens * d];
568
569        // CLS token at index 0
570        tokens[..d].copy_from_slice(&self.cls_token);
571
572        // Categorical tokens
573        for (i, &id) in cat_ids.iter().enumerate() {
574            let v = self.cfg.cat_vocab_sizes[i];
575            let clamped = id.min(v.saturating_sub(1));
576            let emb = &self.cat_embeddings[i][clamped * d..(clamped + 1) * d];
577            let offset = (1 + i) * d;
578            tokens[offset..offset + d].copy_from_slice(emb);
579        }
580
581        // Numeric tokens: scalar * w + b
582        for (i, &x) in num_features.iter().enumerate() {
583            let offset = (1 + self.cfg.n_cat_features + i) * d;
584            for j in 0..d {
585                tokens[offset + j] = x * self.num_w[i][j] + self.num_b[i][j];
586            }
587        }
588
589        // Transformer encoder
590        let mut seq = tokens;
591        for layer in &self.layers {
592            let attn_out = scaled_dot_product_attn(
593                &seq,
594                &seq,
595                &seq,
596                n_tokens,
597                d,
598                self.cfg.n_heads,
599                &layer.wq,
600                &layer.wk,
601                &layer.wv,
602                &layer.wo,
603            )?;
604            let mut h = vec![0.0_f32; n_tokens * d];
605            for (i, (&a, &s)) in attn_out.iter().zip(seq.iter()).enumerate() {
606                h[i] = a + s;
607            }
608            let mut ln_out = vec![0.0_f32; n_tokens * d];
609            for s in 0..n_tokens {
610                let normed = layer_norm(&h[s * d..(s + 1) * d], &layer.ln1_g, &layer.ln1_b)?;
611                ln_out[s * d..(s + 1) * d].copy_from_slice(&normed);
612            }
613            let mut ffn_out = vec![0.0_f32; n_tokens * d];
614            for s in 0..n_tokens {
615                let tok = &ln_out[s * d..(s + 1) * d];
616                let h1 = linear(&layer.ffn_w1, &layer.ffn_b1, tok)?;
617                let h1a: Vec<f32> = h1.iter().map(|&x| gelu(x)).collect();
618                let h2 = linear(&layer.ffn_w2, &layer.ffn_b2, &h1a)?;
619                let res2: Vec<f32> = h2.iter().zip(tok.iter()).map(|(&a, &b)| a + b).collect();
620                let normed2 = layer_norm(&res2, &layer.ln2_g, &layer.ln2_b)?;
621                ffn_out[s * d..(s + 1) * d].copy_from_slice(&normed2);
622            }
623            seq = ffn_out;
624        }
625
626        // Use CLS token for classification
627        let cls = &seq[..d];
628        linear(&self.head_w, &self.head_b, cls)
629    }
630}
631
// ─────────────────────────────────────────────────────────────────────────────
// ══════════════════════  NODE  ══════════════════════════════════════════════
// ─────────────────────────────────────────────────────────────────────────────

636/// Single differentiable oblivious decision tree (Popov et al., 2020).
637///
638/// Each internal node selects the same feature at every depth level.
639/// The `depth` parameter controls the number of splits; each tree
640/// produces `2^depth` leaves.
641#[derive(Debug, Clone)]
642pub struct ObliviousTree {
643    /// Number of split layers (depth of the tree).
644    pub depth: usize,
645    /// Total number of input features.
646    pub n_features: usize,
647    /// Feature selection weights: [depth × n_features] (softmax along dim 1).
648    pub feature_w: Vec<f32>,
649    /// Learned split thresholds: \[depth\].
650    pub thresholds: Vec<f32>,
651    /// Leaf response values: [2^depth × output_dim].
652    pub leaf_responses: Vec<f32>,
653    /// Output dimension per tree.
654    pub output_dim: usize,
655}
656
657impl ObliviousTree {
658    /// Create a new `ObliviousTree` with random initialisation.
659    pub fn new(depth: usize, n_features: usize, output_dim: usize, rng: &mut StdRng) -> Self {
660        let n_leaves = 1usize << depth;
661        let feature_w = xavier_uniform(depth * n_features, n_features, depth, rng);
662        let thresholds: Vec<f32> = (0..depth)
663            .map(|_| {
664                let u: f32 = rng.random();
665                u * 2.0 - 1.0
666            })
667            .collect();
668        let leaf_responses = xavier_uniform(n_leaves * output_dim, n_leaves, output_dim, rng);
669        Self {
670            depth,
671            n_features,
672            feature_w,
673            thresholds,
674            leaf_responses,
675            output_dim,
676        }
677    }
678
679    /// Differentiable forward: soft routing via entmax/sigmoid.
680    pub fn forward(&self, x: &[f32]) -> TabResult<Vec<f32>> {
681        if x.len() != self.n_features {
682            return Err(format!(
683                "ObliviousTree: expected {} features, got {}",
684                self.n_features,
685                x.len()
686            ));
687        }
688        let d = self.depth;
689        let n_leaves = 1usize << d;
690
691        // Compute one split value per depth level using the selected feature.
692        let mut leaf_probs = vec![1.0_f32; n_leaves];
693        for layer in 0..d {
694            // Soft feature selection via softmax
695            let fw = &self.feature_w[layer * self.n_features..(layer + 1) * self.n_features];
696            let feature_attn = softmax(fw);
697            // Compute feature projection
698            let projected: f32 = feature_attn
699                .iter()
700                .zip(x.iter())
701                .map(|(&a, &b)| a * b)
702                .sum();
703            let split_val = sigmoid(projected - self.thresholds[layer]);
704            // Update leaf probabilities: left branch = (1 - split), right = split
705            for leaf in 0..n_leaves {
706                let bit = (leaf >> (d - 1 - layer)) & 1;
707                let p = if bit == 1 { split_val } else { 1.0 - split_val };
708                leaf_probs[leaf] *= p;
709            }
710        }
711
712        // Weighted sum of leaf responses
713        let mut output = vec![0.0_f32; self.output_dim];
714        for (leaf, &lp) in leaf_probs.iter().enumerate() {
715            for o in 0..self.output_dim {
716                output[o] += lp * self.leaf_responses[leaf * self.output_dim + o];
717            }
718        }
719        Ok(output)
720    }
721}
722
723/// Ensemble of Oblivious Decision Trees (NODE — Popov et al., 2020).
724#[derive(Debug, Clone)]
725pub struct NodeModel {
726    /// The ensemble trees.
727    pub trees: Vec<ObliviousTree>,
728    /// Input feature dimension.
729    pub n_features: usize,
730    /// Number of output classes / regression targets.
731    pub n_classes: usize,
732}
733
734impl NodeModel {
735    /// Create a new `NodeModel` with `n_trees` oblivious trees.
736    pub fn new(
737        n_trees: usize,
738        depth: usize,
739        n_features: usize,
740        n_classes: usize,
741        seed: u64,
742    ) -> Self {
743        let mut rng = StdRng::seed_from_u64(seed);
744        let trees = (0..n_trees)
745            .map(|_| ObliviousTree::new(depth, n_features, n_classes, &mut rng))
746            .collect();
747        Self {
748            trees,
749            n_features,
750            n_classes,
751        }
752    }
753
754    /// Forward pass: averages tree outputs.
755    pub fn forward(&self, x: &[f32]) -> TabResult<Vec<f32>> {
756        if self.trees.is_empty() {
757            return Err("NodeModel: no trees".into());
758        }
759        let mut sum = vec![0.0_f32; self.n_classes];
760        for tree in &self.trees {
761            let out = tree.forward(x)?;
762            for (s, &o) in sum.iter_mut().zip(out.iter()) {
763                *s += o;
764            }
765        }
766        let n = self.trees.len() as f32;
767        Ok(sum.iter().map(|&s| s / n).collect())
768    }
769}
770
// ─────────────────────────────────────────────────────────────────────────────
// ══════════════════════  TabNet  ════════════════════════════════════════════
// ─────────────────────────────────────────────────────────────────────────────

775/// TabNet configuration (Arik & Pfister, 2021).
776#[derive(Debug, Clone)]
777pub struct TabNetConfig {
778    /// Number of sequential attention steps.
779    pub n_steps: usize,
780    /// Width of feature transform output (decision step output dim).
781    pub n_d: usize,
782    /// Width of attentive transformer output.
783    pub n_a: usize,
784    /// Coefficient for feature reusage penalty.
785    pub gamma: f32,
786    /// Epsilon for batch normalisation.
787    pub epsilon: f32,
788    /// Number of input features.
789    pub n_features: usize,
790    /// Number of output classes.
791    pub n_classes: usize,
792}
793
794/// Attentive transformer for one TabNet step.
795#[derive(Debug, Clone)]
796pub struct AttentiveTransformer {
797    w: Vec<f32>,
798    b: Vec<f32>,
799    bn_gamma: Vec<f32>,
800    bn_beta: Vec<f32>,
801}
802
803impl AttentiveTransformer {
804    fn new(n_features: usize, n_a: usize, rng: &mut StdRng) -> Self {
805        Self {
806            w: xavier_uniform(n_features * n_a, n_a, n_features, rng),
807            b: zeros(n_features),
808            bn_gamma: ones(n_features),
809            bn_beta: zeros(n_features),
810        }
811    }
812
813    /// Compute sparsemax-normalised feature mask.
814    fn forward(&self, h: &[f32], prior_scale: &[f32]) -> TabResult<Vec<f32>> {
815        let n_features = self.b.len();
816        let h_proj = linear(&self.w, &self.b, h)?;
817        // Batch-norm approximation (instance normalisation)
818        let normed = layer_norm(&h_proj, &self.bn_gamma, &self.bn_beta)?;
819        // Element-wise multiply by prior_scale
820        let masked: Vec<f32> = normed
821            .iter()
822            .zip(prior_scale.iter())
823            .map(|(&n, &p)| n * p)
824            .collect();
825        Ok(sparsemax(&masked[..n_features.min(masked.len())]))
826    }
827}
828
829/// Shared + step-specific feature transform layers.
830#[derive(Debug, Clone)]
831struct FeatureTransformStep {
832    w1: Vec<f32>,
833    b1: Vec<f32>,
834    w2: Vec<f32>,
835    b2: Vec<f32>,
836}
837
838impl FeatureTransformStep {
839    fn new(in_dim: usize, out_dim: usize, rng: &mut StdRng) -> Self {
840        Self {
841            w1: xavier_uniform(out_dim * in_dim, in_dim, out_dim, rng),
842            b1: zeros(out_dim),
843            w2: xavier_uniform(out_dim * out_dim, out_dim, out_dim, rng),
844            b2: zeros(out_dim),
845        }
846    }
847
848    fn forward(&self, x: &[f32]) -> TabResult<Vec<f32>> {
849        let h = linear(&self.w1, &self.b1, x)?;
850        let ha: Vec<f32> = h.iter().map(|&v| relu(v)).collect();
851        let h2 = linear(&self.w2, &self.b2, &ha)?;
852        Ok(h2.iter().map(|&v| relu(v)).collect())
853    }
854}
855
/// TabNet model (Arik & Pfister, 2021).
///
/// Sequential attentive feature selection: each step builds a sparsemax
/// mask over raw features, transforms the masked input, and adds ReLU'd
/// decision units into an aggregate that feeds the final linear head.
#[derive(Debug, Clone)]
pub struct TabNet {
    // Configuration (dimensions, step count, relaxation gamma, epsilon).
    cfg: TabNetConfig,
    // Feature transform applied at every step before the step-specific one.
    shared_layer: FeatureTransformStep,
    // One extra transform per decision step.
    step_layers: Vec<FeatureTransformStep>,
    // One mask-producing attentive transformer per decision step.
    attn_transformers: Vec<AttentiveTransformer>,
    // Classification head: [n_classes × n_d] weights plus bias.
    final_w: Vec<f32>,
    final_b: Vec<f32>,
}
866
867impl TabNet {
868    /// Create a new `TabNet`.
869    pub fn new(cfg: TabNetConfig, seed: u64) -> TabResult<Self> {
870        if cfg.n_steps == 0 {
871            return Err("TabNet: n_steps must be > 0".into());
872        }
873        let mut rng = StdRng::seed_from_u64(seed);
874        let shared_layer = FeatureTransformStep::new(cfg.n_features, cfg.n_d + cfg.n_a, &mut rng);
875        let step_layers = (0..cfg.n_steps)
876            .map(|_| FeatureTransformStep::new(cfg.n_d + cfg.n_a, cfg.n_d + cfg.n_a, &mut rng))
877            .collect();
878        let attn_transformers = (0..cfg.n_steps)
879            .map(|_| AttentiveTransformer::new(cfg.n_features, cfg.n_a, &mut rng))
880            .collect();
881        let final_w = xavier_uniform(cfg.n_classes * cfg.n_d, cfg.n_d, cfg.n_classes, &mut rng);
882        let final_b = zeros(cfg.n_classes);
883        Ok(Self {
884            cfg,
885            shared_layer,
886            step_layers,
887            attn_transformers,
888            final_w,
889            final_b,
890        })
891    }
892
893    /// Forward pass returning `(logits, per_step_masks)`.
894    pub fn forward(&self, x: &[f32]) -> TabResult<(Vec<f32>, Vec<Vec<f32>>)> {
895        if x.len() != self.cfg.n_features {
896            return Err(format!(
897                "TabNet: expected {} features, got {}",
898                self.cfg.n_features,
899                x.len()
900            ));
901        }
902        let n = self.cfg.n_features;
903        let mut prior_scale = vec![1.0_f32; n];
904        let mut aggregated_output = vec![0.0_f32; self.cfg.n_d];
905        let mut masks = Vec::with_capacity(self.cfg.n_steps);
906
907        for step in 0..self.cfg.n_steps {
908            // Attentive transformer: needs n_a-dim input (use current output aggregation)
909            let h_for_attn: Vec<f32> = if aggregated_output.is_empty() {
910                vec![0.0_f32; self.cfg.n_a]
911            } else {
912                // pad/trim aggregated_output to n_a
913                let mut ha = vec![0.0_f32; self.cfg.n_a];
914                let copy_len = aggregated_output.len().min(self.cfg.n_a);
915                ha[..copy_len].copy_from_slice(&aggregated_output[..copy_len]);
916                ha
917            };
918
919            let mask = self.attn_transformers[step].forward(&h_for_attn, &prior_scale)?;
920            // Masked input
921            let masked_x: Vec<f32> = mask.iter().zip(x.iter()).map(|(&m, &xi)| m * xi).collect();
922
923            // Feature transform: shared + step-specific
924            let shared_out = self.shared_layer.forward(&masked_x)?;
925            let step_out = self.step_layers[step].forward(&shared_out)?;
926
927            // Split into n_d (decision) and n_a (attention) parts
928            let split_pt = self.cfg.n_d.min(step_out.len());
929            let decision = &step_out[..split_pt];
930            let relu_decision: Vec<f32> = decision.iter().map(|&v| relu(v)).collect();
931
932            for (a, &d) in aggregated_output.iter_mut().zip(relu_decision.iter()) {
933                *a += d;
934            }
935
936            // Update prior scale
937            for (p, &m) in prior_scale.iter_mut().zip(mask.iter()) {
938                *p *= (self.cfg.gamma - m).max(0.0);
939            }
940
941            masks.push(mask);
942        }
943
944        // Final output
945        let logits = linear(&self.final_w, &self.final_b, &aggregated_output)?;
946        Ok((logits, masks))
947    }
948}
949
950// ─────────────────────────────────────────────────────────────────────────────
951// ══════════════════════  SAINT  ═════════════════════════════════════════════
952// ─────────────────────────────────────────────────────────────────────────────
953
/// Single SAINT block with inter-sample + intra-feature attention.
///
/// Holds flat `Vec<f32>` weight buffers for two multi-head self-attention
/// sub-layers — one over a sample's feature tokens, one across samples in
/// a batch — plus a feed-forward network and per-sub-layer layer norms.
#[derive(Debug, Clone)]
pub struct SaintBlock {
    // Token width; every attention weight matrix is d_model × d_model.
    d_model: usize,
    n_heads: usize,
    // Intra-feature (column) self-attention
    wq_intra: Vec<f32>,
    wk_intra: Vec<f32>,
    wv_intra: Vec<f32>,
    wo_intra: Vec<f32>,
    // Layer norm applied after the intra-feature attention residual.
    ln1_g: Vec<f32>,
    ln1_b: Vec<f32>,
    // Inter-sample (row) self-attention: same-dim
    wq_inter: Vec<f32>,
    wk_inter: Vec<f32>,
    wv_inter: Vec<f32>,
    wo_inter: Vec<f32>,
    // Layer norm applied after the inter-sample attention residual.
    ln2_g: Vec<f32>,
    ln2_b: Vec<f32>,
    // FFN: d_model → ffn_dim → d_model (GELU between, see intra path).
    ffn_w1: Vec<f32>,
    ffn_b1: Vec<f32>,
    ffn_w2: Vec<f32>,
    ffn_b2: Vec<f32>,
    // Layer norm applied after the FFN residual.
    ln3_g: Vec<f32>,
    ln3_b: Vec<f32>,
}
981
982impl SaintBlock {
983    /// Create a new `SaintBlock`.
984    pub fn new(d_model: usize, n_heads: usize, ffn_dim: usize, seed: u64) -> TabResult<Self> {
985        if d_model % n_heads != 0 {
986            return Err(format!(
987                "SaintBlock: d_model={d_model} not divisible by n_heads={n_heads}"
988            ));
989        }
990        let mut rng = StdRng::seed_from_u64(seed);
991        Ok(Self {
992            d_model,
993            n_heads,
994            wq_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
995            wk_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
996            wv_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
997            wo_intra: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
998            ln1_g: ones(d_model),
999            ln1_b: zeros(d_model),
1000            wq_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
1001            wk_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
1002            wv_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
1003            wo_inter: xavier_uniform(d_model * d_model, d_model, d_model, &mut rng),
1004            ln2_g: ones(d_model),
1005            ln2_b: zeros(d_model),
1006            ffn_w1: xavier_uniform(ffn_dim * d_model, d_model, ffn_dim, &mut rng),
1007            ffn_b1: zeros(ffn_dim),
1008            ffn_w2: xavier_uniform(d_model * ffn_dim, ffn_dim, d_model, &mut rng),
1009            ffn_b2: zeros(d_model),
1010            ln3_g: ones(d_model),
1011            ln3_b: zeros(d_model),
1012        })
1013    }
1014
1015    /// Intra-feature (column-wise) self-attention over a single sample's feature tokens.
1016    /// `x` is [n_features × d_model] flattened.
1017    pub fn intra_feature_attention(&self, x: &[f32]) -> TabResult<Vec<f32>> {
1018        let d = self.d_model;
1019        if x.len() % d != 0 {
1020            return Err(format!(
1021                "SaintBlock::intra: x.len()={} not divisible by d_model={d}",
1022                x.len()
1023            ));
1024        }
1025        let seq_len = x.len() / d;
1026        let attn_out = scaled_dot_product_attn(
1027            x,
1028            x,
1029            x,
1030            seq_len,
1031            d,
1032            self.n_heads,
1033            &self.wq_intra,
1034            &self.wk_intra,
1035            &self.wv_intra,
1036            &self.wo_intra,
1037        )?;
1038        // Residual + LN
1039        let mut ln_out = vec![0.0_f32; seq_len * d];
1040        for s in 0..seq_len {
1041            let res: Vec<f32> = attn_out[s * d..(s + 1) * d]
1042                .iter()
1043                .zip(x[s * d..(s + 1) * d].iter())
1044                .map(|(&a, &b)| a + b)
1045                .collect();
1046            let normed = layer_norm(&res, &self.ln1_g, &self.ln1_b)?;
1047            ln_out[s * d..(s + 1) * d].copy_from_slice(&normed);
1048        }
1049        // FFN
1050        let mut ffn_out = vec![0.0_f32; seq_len * d];
1051        for s in 0..seq_len {
1052            let tok = &ln_out[s * d..(s + 1) * d];
1053            let h1 = linear(&self.ffn_w1, &self.ffn_b1, tok)?;
1054            let h1a: Vec<f32> = h1.iter().map(|&v| gelu(v)).collect();
1055            let h2 = linear(&self.ffn_w2, &self.ffn_b2, &h1a)?;
1056            let res2: Vec<f32> = h2.iter().zip(tok.iter()).map(|(&a, &b)| a + b).collect();
1057            let normed2 = layer_norm(&res2, &self.ln3_g, &self.ln3_b)?;
1058            ffn_out[s * d..(s + 1) * d].copy_from_slice(&normed2);
1059        }
1060        Ok(ffn_out)
1061    }
1062
1063    /// Inter-sample (row-wise) attention over a batch of sample representations.
1064    /// `batch` is a list of per-sample vectors, each of length `d_model`.
1065    pub fn inter_sample_attention(&self, batch: &[Vec<f32>]) -> TabResult<Vec<Vec<f32>>> {
1066        let d = self.d_model;
1067        let n_samples = batch.len();
1068        if n_samples == 0 {
1069            return Ok(Vec::new());
1070        }
1071        for (i, s) in batch.iter().enumerate() {
1072            if s.len() != d {
1073                return Err(format!(
1074                    "SaintBlock::inter: batch[{i}].len()={} != d_model={d}",
1075                    s.len()
1076                ));
1077            }
1078        }
1079        // Flatten batch → [n_samples × d_model]
1080        let flat: Vec<f32> = batch.iter().flat_map(|s| s.iter().copied()).collect();
1081        let attn_out = scaled_dot_product_attn(
1082            &flat,
1083            &flat,
1084            &flat,
1085            n_samples,
1086            d,
1087            self.n_heads,
1088            &self.wq_inter,
1089            &self.wk_inter,
1090            &self.wv_inter,
1091            &self.wo_inter,
1092        )?;
1093        // Residual + LN per sample
1094        let mut result = Vec::with_capacity(n_samples);
1095        for s in 0..n_samples {
1096            let res: Vec<f32> = attn_out[s * d..(s + 1) * d]
1097                .iter()
1098                .zip(flat[s * d..(s + 1) * d].iter())
1099                .map(|(&a, &b)| a + b)
1100                .collect();
1101            let normed = layer_norm(&res, &self.ln2_g, &self.ln2_b)?;
1102            result.push(normed);
1103        }
1104        Ok(result)
1105    }
1106}
1107
/// Full SAINT model combining intra- and inter-sample attention blocks.
#[derive(Debug, Clone)]
pub struct SaintModel {
    /// Embedding dimension.
    pub d_model: usize,
    /// Number of SAINT blocks.
    pub n_blocks: usize,
    /// Categorical embeddings: one table per feature.
    cat_embeddings: Vec<Vec<f32>>,
    // Feature counts fixed at construction; `forward` inputs are
    // interpreted against these.
    n_cat_features: usize,
    n_num_features: usize,
    // Vocabulary size per categorical feature (rows of cat_embeddings[i]).
    cat_vocab_sizes: Vec<usize>,
    /// Numeric projection weights [n_num × d_model].
    num_w: Vec<Vec<f32>>,
    num_b: Vec<Vec<f32>>,
    /// SAINT blocks.
    blocks: Vec<SaintBlock>,
    /// Output head over the flattened token sequence
    /// ([n_classes × n_features·d_model] weights plus bias).
    head_w: Vec<f32>,
    head_b: Vec<f32>,
    n_classes: usize,
}
1130
1131impl SaintModel {
1132    /// Create a new `SaintModel`.
1133    pub fn new(
1134        n_cat_features: usize,
1135        n_num_features: usize,
1136        cat_vocab_sizes: Vec<usize>,
1137        d_model: usize,
1138        n_heads: usize,
1139        n_blocks: usize,
1140        ffn_dim: usize,
1141        n_classes: usize,
1142        seed: u64,
1143    ) -> TabResult<Self> {
1144        if cat_vocab_sizes.len() != n_cat_features {
1145            return Err("SaintModel: cat_vocab_sizes.len() != n_cat_features".into());
1146        }
1147        let mut rng = StdRng::seed_from_u64(seed);
1148        let cat_embeddings = cat_vocab_sizes
1149            .iter()
1150            .map(|&v| kaiming_uniform(v * d_model, d_model, &mut rng))
1151            .collect();
1152        let num_w = (0..n_num_features)
1153            .map(|_| xavier_uniform(d_model, 1, d_model, &mut rng))
1154            .collect();
1155        let num_b = (0..n_num_features).map(|_| zeros(d_model)).collect();
1156        let blocks = (0..n_blocks)
1157            .map(|i| SaintBlock::new(d_model, n_heads, ffn_dim, seed.wrapping_add(i as u64 + 1)))
1158            .collect::<TabResult<Vec<_>>>()?;
1159        let n_features = n_cat_features + n_num_features;
1160        let head_w = xavier_uniform(
1161            n_classes * n_features * d_model,
1162            n_features * d_model,
1163            n_classes,
1164            &mut rng,
1165        );
1166        let head_b = zeros(n_classes);
1167        Ok(Self {
1168            d_model,
1169            n_blocks,
1170            cat_embeddings,
1171            n_cat_features,
1172            n_num_features,
1173            cat_vocab_sizes,
1174            num_w,
1175            num_b,
1176            blocks,
1177            head_w,
1178            head_b,
1179            n_classes,
1180        })
1181    }
1182
1183    /// Forward pass for a single sample.
1184    pub fn forward(&self, cat_ids: &[usize], num_features: &[f32]) -> TabResult<Vec<f32>> {
1185        let d = self.d_model;
1186        let n_features = self.n_cat_features + self.n_num_features;
1187
1188        // Build feature token sequence [n_features × d]
1189        let mut tokens = vec![0.0_f32; n_features * d];
1190        for (i, &id) in cat_ids.iter().enumerate() {
1191            let v = self.cat_vocab_sizes[i];
1192            let clamped = id.min(v.saturating_sub(1));
1193            let emb = &self.cat_embeddings[i][clamped * d..(clamped + 1) * d];
1194            tokens[i * d..(i + 1) * d].copy_from_slice(emb);
1195        }
1196        for (i, &x) in num_features.iter().enumerate() {
1197            let offset = (self.n_cat_features + i) * d;
1198            for j in 0..d {
1199                tokens[offset + j] = x * self.num_w[i][j] + self.num_b[i][j];
1200            }
1201        }
1202
1203        // Apply SAINT blocks
1204        let mut seq = tokens;
1205        for block in &self.blocks {
1206            seq = block.intra_feature_attention(&seq)?;
1207        }
1208
1209        // Flatten and classify
1210        linear(&self.head_w, &self.head_b, &seq)
1211    }
1212}
1213
1214// ─────────────────────────────────────────────────────────────────────────────
1215// ══════════════════════  FeatureEncoder  ════════════════════════════════════
1216// ─────────────────────────────────────────────────────────────────────────────
1217
/// Standard z-score scaler: `(x − mean) / (std + ε)`.
#[derive(Debug, Clone, Default)]
pub struct StandardScaler {
    /// Per-feature mean learned by `fit`.
    pub mean: Vec<f32>,
    /// Per-feature standard deviation (small epsilon folded in).
    pub std: Vec<f32>,
}

impl StandardScaler {
    /// Fit per-feature mean and population std on a batch of feature rows;
    /// the feature count is taken from the first row. Empty input is a no-op.
    pub fn fit(&mut self, data: &[&[f32]]) {
        let n_features = match data.first() {
            Some(row) => row.len(),
            None => return,
        };
        let n = data.len() as f32;
        let mut mean = vec![0.0_f32; n_features];
        for row in data {
            for (acc, &v) in mean.iter_mut().zip(row.iter()) {
                *acc += v;
            }
        }
        for m in mean.iter_mut() {
            *m /= n;
        }
        let mut var = vec![0.0_f32; n_features];
        for row in data {
            for (i, &v) in row.iter().take(n_features).enumerate() {
                var[i] += (v - mean[i]).powi(2);
            }
        }
        self.std = var
            .iter()
            .map(|&v| (v / n.max(1.0) + 1e-7).sqrt())
            .collect();
        self.mean = mean;
    }

    /// Transform a single feature vector; features without fitted
    /// statistics pass through unchanged (mean 0, std 1).
    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(x.len());
        for (i, &v) in x.iter().enumerate() {
            let m = self.mean.get(i).copied().unwrap_or(0.0);
            let s = self.std.get(i).copied().unwrap_or(1.0);
            out.push((v - m) / s.max(1e-7));
        }
        out
    }
}
1270
/// Min-max scaler: `(x − min) / (range + ε)`.
#[derive(Debug, Clone, Default)]
pub struct MinMaxScaler {
    /// Per-feature minimum observed during `fit`.
    pub min: Vec<f32>,
    /// Per-feature range (`max − min`), floored at `1e-7`.
    pub range: Vec<f32>,
}

impl MinMaxScaler {
    /// Learn per-feature min and range from a batch of feature rows;
    /// the feature count is taken from the first row. Empty input is a no-op.
    pub fn fit(&mut self, data: &[&[f32]]) {
        let n_features = match data.first() {
            Some(row) => row.len(),
            None => return,
        };
        let mut lo = vec![f32::INFINITY; n_features];
        let mut hi = vec![f32::NEG_INFINITY; n_features];
        for row in data {
            for (i, &v) in row.iter().take(n_features).enumerate() {
                lo[i] = lo[i].min(v);
                hi[i] = hi[i].max(v);
            }
        }
        self.range = lo
            .iter()
            .zip(&hi)
            .map(|(&l, &h)| (h - l).max(1e-7))
            .collect();
        self.min = lo;
    }

    /// Transform a single feature vector towards `[0, 1]`; features
    /// without fitted statistics pass through unchanged (min 0, range 1).
    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(x.len());
        for (i, &v) in x.iter().enumerate() {
            let lo = self.min.get(i).copied().unwrap_or(0.0);
            let r = self.range.get(i).copied().unwrap_or(1.0);
            out.push((v - lo) / r);
        }
        out
    }
}
1319
1320/// Rank-based Gaussian quantile transformer.
1321#[derive(Debug, Clone, Default)]
1322pub struct QuantileTransformer {
1323    /// Per-feature sorted quantile values.
1324    pub quantiles: Vec<Vec<f32>>,
1325}
1326
1327impl QuantileTransformer {
1328    /// Fit the quantile mapping from sorted data.
1329    pub fn fit(&mut self, data: &[&[f32]]) {
1330        if data.is_empty() {
1331            return;
1332        }
1333        let n_features = data[0].len();
1334        self.quantiles = vec![Vec::new(); n_features];
1335        for feat_idx in 0..n_features {
1336            let mut vals: Vec<f32> = data
1337                .iter()
1338                .filter_map(|row| row.get(feat_idx).copied())
1339                .collect();
1340            vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1341            self.quantiles[feat_idx] = vals;
1342        }
1343    }
1344
1345    /// Transform: rank → standard normal quantile using erfinv approximation.
1346    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
1347        x.iter()
1348            .enumerate()
1349            .map(|(i, &v)| {
1350                if i >= self.quantiles.len() || self.quantiles[i].is_empty() {
1351                    return v;
1352                }
1353                let q = &self.quantiles[i];
1354                let n = q.len() as f32;
1355                // Find rank via binary search
1356                let rank = q.partition_point(|&s| s <= v);
1357                let p = (rank as f32 + 0.5) / (n + 1.0);
1358                let p_clamped = p.clamp(1e-6, 1.0 - 1e-6);
1359                // Probit via rational approximation (Beasley-Springer-Moro)
1360                probit(p_clamped)
1361            })
1362            .collect()
1363    }
1364}
1365
/// Inverse standard-normal CDF (probit) via Wichura's AS 241 "PPND7"
/// rational approximations (single-precision grade, ~7 significant digits).
///
/// The previous coefficient set was wrong: it gave probit(0.6) ≈ 0.93
/// (true value 0.2533) and probit(0.975) ≈ 1.26 (true value 1.95996).
///
/// `p` should lie in (0, 1); callers clamp to `[1e-6, 1 − 1e-6]` first.
/// Central region |p − 0.5| ≤ 0.425 and two tail regimes each use a
/// separate rational fit evaluated in Horner form.
fn probit(p: f32) -> f32 {
    let q = p - 0.5;
    if q.abs() <= 0.425 {
        // Central region: rational fit in r = 0.180625 − q².
        let r = 0.180625 - q * q;
        let num = ((59.109_375_f32 * r + 159.291_13) * r + 50.434_273) * r + 3.387_132_7;
        let den = ((67.187_564_f32 * r + 78.757_76) * r + 17.895_17) * r + 1.0;
        q * (num / den)
    } else {
        // Tail regions: work with r = sqrt(−ln(min(p, 1−p))).
        let tail_p = if q < 0.0 { p } else { 1.0 - p };
        // Floor avoids ln(0) → −inf for out-of-contract inputs.
        let mut r = (-tail_p.max(f32::MIN_POSITIVE).ln()).sqrt();
        let sign = if q < 0.0 { -1.0_f32 } else { 1.0_f32 };
        let x = if r <= 5.0 {
            // Near tail (covers all clamped inputs: r ≤ ~3.73 at p = 1e-6).
            r -= 1.6;
            let num = ((0.170_238_21_f32 * r + 1.306_728_5) * r + 2.756_815_4) * r + 1.423_437_3;
            let den = (0.120_211_33_f32 * r + 0.737_001_64) * r + 1.0;
            num / den
        } else {
            // Far tail (p below ~1.4e-11); unreachable after caller
            // clamping but kept for completeness.
            r -= 5.0;
            let num = ((0.017_337_204_f32 * r + 0.428_682_94) * r + 3.081_226_4) * r + 6.657_905;
            let den = (0.012_258_203_f32 * r + 0.241_978_94) * r + 1.0;
            num / den
        };
        sign * x
    }
}
1384
/// Cyclic encoder for periodic features (sin + cos encoding).
#[derive(Debug, Clone)]
pub struct CyclicEncoder {
    /// Period for each feature (e.g., 24.0 for hours, 7.0 for weekdays).
    pub periods: Vec<f32>,
}

impl CyclicEncoder {
    /// Build an encoder from one period per feature.
    pub fn new(periods: Vec<f32>) -> Self {
        Self { periods }
    }

    /// Encode each value `x` with period `T` as `[sin(2π x/T), cos(2π x/T)]`.
    /// Output length = 2 × input length. Missing periods default to 1.0;
    /// periods are floored at `1e-7` to avoid division by zero.
    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
        x.iter()
            .enumerate()
            .flat_map(|(i, &v)| {
                let period = self.periods.get(i).copied().unwrap_or(1.0).max(1e-7);
                let theta = 2.0 * PI * v / period;
                [theta.sin(), theta.cos()]
            })
            .collect()
    }
}
1411
/// Preprocessing pipeline combining multiple encoders.
///
/// Bundles a z-score scaler, a min-max scaler, a quantile transformer and
/// an optional cyclic (sin/cos) encoder behind one fit/transform surface.
#[derive(Debug, Clone)]
pub struct FeatureEncoder {
    // Z-score scaler; backs the default `transform`.
    pub scaler: StandardScaler,
    // Min-max scaler; fitted alongside the others, accessed directly.
    pub minmax: MinMaxScaler,
    // Rank-Gaussian transformer; backs `quantile_transform`.
    pub quantile: QuantileTransformer,
    // Optional cyclic encoder; periods are fixed at construction (not fitted).
    pub cyclic: Option<CyclicEncoder>,
}
1420
1421impl FeatureEncoder {
1422    /// Create a new `FeatureEncoder`.
1423    pub fn new(cyclic_periods: Option<Vec<f32>>) -> Self {
1424        Self {
1425            scaler: StandardScaler::default(),
1426            minmax: MinMaxScaler::default(),
1427            quantile: QuantileTransformer::default(),
1428            cyclic: cyclic_periods.map(CyclicEncoder::new),
1429        }
1430    }
1431
1432    /// Fit all sub-encoders.
1433    pub fn fit(&mut self, data: &[&[f32]]) {
1434        self.scaler.fit(data);
1435        self.minmax.fit(data);
1436        self.quantile.fit(data);
1437    }
1438
1439    /// Apply standard scaling (most common).
1440    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
1441        self.scaler.transform(x)
1442    }
1443
1444    /// Apply quantile transform.
1445    pub fn quantile_transform(&self, x: &[f32]) -> Vec<f32> {
1446        self.quantile.transform(x)
1447    }
1448}
1449
1450// ─────────────────────────────────────────────────────────────────────────────
1451// ══════════════════════  MixedInputHead  ════════════════════════════════════
1452// ─────────────────────────────────────────────────────────────────────────────
1453
/// Gated fusion of categorical and numeric representations.
///
/// Uses a learned sigmoid gate: `output = g * cat + (1 - g) * num`, where
/// `g = sigmoid(W_g [cat; num] + b_g)`.
#[derive(Debug, Clone)]
pub struct MixedInputHead {
    // Expected length of the categorical representation.
    cat_dim: usize,
    // Expected length of the numeric representation.
    num_dim: usize,
    // Length of the fused output (and of each projection below).
    out_dim: usize,
    /// Gate weights [out_dim × (cat_dim + num_dim)].
    gate_w: Vec<f32>,
    gate_b: Vec<f32>,
    /// Projection for cat representation.
    cat_proj_w: Vec<f32>,
    cat_proj_b: Vec<f32>,
    /// Projection for num representation.
    num_proj_w: Vec<f32>,
    num_proj_b: Vec<f32>,
}
1473
1474impl MixedInputHead {
1475    /// Create a new `MixedInputHead`.
1476    pub fn new(cat_dim: usize, num_dim: usize, out_dim: usize, seed: u64) -> Self {
1477        let mut rng = StdRng::seed_from_u64(seed);
1478        let joint_dim = cat_dim + num_dim;
1479        Self {
1480            cat_dim,
1481            num_dim,
1482            out_dim,
1483            gate_w: xavier_uniform(out_dim * joint_dim, joint_dim, out_dim, &mut rng),
1484            gate_b: zeros(out_dim),
1485            cat_proj_w: xavier_uniform(out_dim * cat_dim, cat_dim, out_dim, &mut rng),
1486            cat_proj_b: zeros(out_dim),
1487            num_proj_w: xavier_uniform(out_dim * num_dim, num_dim, out_dim, &mut rng),
1488            num_proj_b: zeros(out_dim),
1489        }
1490    }
1491
1492    /// Gated fusion: returns blended representation of length `out_dim`.
1493    pub fn gate(&self, cat_repr: &[f32], num_repr: &[f32]) -> TabResult<Vec<f32>> {
1494        if cat_repr.len() != self.cat_dim {
1495            return Err(format!(
1496                "MixedInputHead: cat_repr.len()={} != cat_dim={}",
1497                cat_repr.len(),
1498                self.cat_dim
1499            ));
1500        }
1501        if num_repr.len() != self.num_dim {
1502            return Err(format!(
1503                "MixedInputHead: num_repr.len()={} != num_dim={}",
1504                num_repr.len(),
1505                self.num_dim
1506            ));
1507        }
1508        let joint: Vec<f32> = cat_repr.iter().chain(num_repr.iter()).copied().collect();
1509        let gate_logits = linear(&self.gate_w, &self.gate_b, &joint)?;
1510        let g: Vec<f32> = gate_logits.iter().map(|&v| sigmoid(v)).collect();
1511
1512        let cat_out = linear(&self.cat_proj_w, &self.cat_proj_b, cat_repr)?;
1513        let num_out = linear(&self.num_proj_w, &self.num_proj_b, num_repr)?;
1514
1515        Ok(g.iter()
1516            .zip(cat_out.iter())
1517            .zip(num_out.iter())
1518            .map(|((&gi, &ci), &ni)| gi * ci + (1.0 - gi) * ni)
1519            .collect())
1520    }
1521}
1522
1523// ─────────────────────────────────────────────────────────────────────────────
1524// ══════════════════════  TabularAugmentation  ═══════════════════════════════
1525// ─────────────────────────────────────────────────────────────────────────────
1526
1527/// Data augmentation for tabular data.
1528pub struct TabularAugmentation;
1529
1530impl TabularAugmentation {
1531    /// Mixup: returns `(λ·x1 + (1−λ)·x2, λ)` where `λ ~ Beta(α, α)`.
1532    /// Approximation: `λ = clip(|N(0.5, 1/(12α))|, 0, 1)`.
1533    pub fn mixup(
1534        x1: &[f32],
1535        x2: &[f32],
1536        alpha: f32,
1537        rng: &mut StdRng,
1538    ) -> TabResult<(Vec<f32>, f32)> {
1539        if x1.len() != x2.len() {
1540            return Err(format!("mixup: len mismatch {} vs {}", x1.len(), x2.len()));
1541        }
1542        let lambda = Self::beta_sample(alpha, rng);
1543        let mixed: Vec<f32> = x1
1544            .iter()
1545            .zip(x2.iter())
1546            .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
1547            .collect();
1548        Ok((mixed, lambda))
1549    }
1550
1551    /// CutMix: randomly replace a contiguous block of features from x2 into x1.
1552    pub fn cutmix(
1553        x1: &[f32],
1554        x2: &[f32],
1555        alpha: f32,
1556        rng: &mut StdRng,
1557    ) -> TabResult<(Vec<f32>, f32)> {
1558        if x1.len() != x2.len() {
1559            return Err(format!("cutmix: len mismatch {} vs {}", x1.len(), x2.len()));
1560        }
1561        let lambda = Self::beta_sample(alpha, rng);
1562        let n = x1.len();
1563        let cut_len = (n as f32 * (1.0 - lambda)).round() as usize;
1564        let start_f: f32 = rng.random();
1565        let start = (start_f * (n.saturating_sub(cut_len) + 1) as f32) as usize;
1566        let end = (start + cut_len).min(n);
1567
1568        let mut mixed = x1.to_vec();
1569        mixed[start..end].copy_from_slice(&x2[start..end]);
1570        let actual_lambda = 1.0 - (end - start) as f32 / n.max(1) as f32;
1571        Ok((mixed, actual_lambda))
1572    }
1573
1574    /// SMOTE-like synthetic oversampling: generates a new sample between `sample`
1575    /// and a random neighbour from `neighbours`.
1576    pub fn smote_like(
1577        sample: &[f32],
1578        neighbours: &[Vec<f32>],
1579        rng: &mut StdRng,
1580    ) -> TabResult<Vec<f32>> {
1581        if neighbours.is_empty() {
1582            return Err("smote_like: no neighbours provided".into());
1583        }
1584        let idx_f: f32 = rng.random();
1585        let idx = (idx_f * neighbours.len() as f32) as usize;
1586        let idx = idx.min(neighbours.len() - 1);
1587        let neighbour = &neighbours[idx];
1588        if sample.len() != neighbour.len() {
1589            return Err(format!(
1590                "smote_like: sample.len()={} != neighbour.len()={}",
1591                sample.len(),
1592                neighbour.len()
1593            ));
1594        }
1595        let gap: f32 = rng.random();
1596        Ok(sample
1597            .iter()
1598            .zip(neighbour.iter())
1599            .map(|(&s, &n)| s + gap * (n - s))
1600            .collect())
1601    }
1602
1603    /// Sample λ from a symmetric Beta(α, α) using the approximation
1604    /// `λ ≈ 0.5 + N(0, 1) * σ` then clipped, where `σ = 1/(2*sqrt(2α+1))`.
1605    fn beta_sample(alpha: f32, rng: &mut StdRng) -> f32 {
1606        let alpha = alpha.max(0.01);
1607        // Box-Muller
1608        let u1: f32 = rng.random::<f32>().max(1e-7);
1609        let u2: f32 = rng.random();
1610        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
1611        let sigma = 1.0 / (2.0 * (2.0 * alpha + 1.0).sqrt());
1612        (0.5 + sigma * z).clamp(0.0, 1.0)
1613    }
1614}
1615
1616// ─────────────────────────────────────────────────────────────────────────────
1617// ══════════════════════  TabularMetrics  ════════════════════════════════════
1618// ─────────────────────────────────────────────────────────────────────────────
1619
/// Evaluation metrics for tabular ML tasks.
pub struct TabularMetrics;

impl TabularMetrics {
    // Index of the largest logit; ties resolve to the LAST maximum (the
    // behaviour of `Iterator::max_by`), and an empty slice yields 0.
    fn argmax(logits: &[f32]) -> usize {
        logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Classification accuracy: fraction of samples whose argmax logit
    /// equals the target class. Returns 0.0 on empty or mismatched input.
    pub fn accuracy(pred_logits: &[Vec<f32>], target: &[usize]) -> f32 {
        if pred_logits.is_empty() || pred_logits.len() != target.len() {
            return 0.0;
        }
        let hits = pred_logits
            .iter()
            .zip(target)
            .filter(|(logits, &t)| Self::argmax(logits) == t)
            .count();
        hits as f32 / pred_logits.len() as f32
    }

    /// Macro-averaged F1: per-class F1 from argmax predictions, averaged
    /// uniformly over the `n_classes` classes. Out-of-range labels and
    /// predictions are skipped rather than counted.
    pub fn macro_f1(pred: &[Vec<f32>], target: &[usize], n_classes: usize) -> f32 {
        if pred.is_empty() || n_classes == 0 {
            return 0.0;
        }
        let mut tp = vec![0u32; n_classes];
        let mut fp = vec![0u32; n_classes];
        let mut fn_ = vec![0u32; n_classes];

        for (logits, &t) in pred.iter().zip(target) {
            let p = Self::argmax(logits);
            if p == t {
                if t < n_classes {
                    tp[t] += 1;
                }
            } else {
                if p < n_classes {
                    fp[p] += 1;
                }
                if t < n_classes {
                    fn_[t] += 1;
                }
            }
        }

        // `.max(1)` avoids 0/0; a class with no predictions scores 0.
        let total: f32 = (0..n_classes)
            .map(|c| {
                let precision = tp[c] as f32 / (tp[c] + fp[c]).max(1) as f32;
                let recall = tp[c] as f32 / (tp[c] + fn_[c]).max(1) as f32;
                if precision + recall > 0.0 {
                    2.0 * precision * recall / (precision + recall)
                } else {
                    0.0
                }
            })
            .sum();
        total / n_classes as f32
    }

    /// Root Mean Squared Error; NaN on empty or mismatched input.
    pub fn rmse(pred: &[f32], target: &[f32]) -> f32 {
        if pred.is_empty() || pred.len() != target.len() {
            return f32::NAN;
        }
        let sq_sum: f32 = pred
            .iter()
            .zip(target)
            .map(|(&p, &t)| (p - t).powi(2))
            .sum();
        (sq_sum / pred.len() as f32).sqrt()
    }

    /// R² (coefficient of determination); NaN on empty or mismatched input.
    /// A constant target scores 1.0 only when residuals are also ~zero.
    pub fn r2_score(pred: &[f32], target: &[f32]) -> f32 {
        if pred.is_empty() || pred.len() != target.len() {
            return f32::NAN;
        }
        let mean_t = target.iter().sum::<f32>() / target.len() as f32;
        let ss_tot: f32 = target.iter().map(|&t| (t - mean_t).powi(2)).sum();
        let ss_res: f32 = pred
            .iter()
            .zip(target)
            .map(|(&p, &t)| (t - p).powi(2))
            .sum();
        if ss_tot < 1e-12 {
            if ss_res < 1e-12 {
                1.0
            } else {
                0.0
            }
        } else {
            1.0 - ss_res / ss_tot
        }
    }
}
1723
1724// ─────────────────────────────────────────────────────────────────────────────
1725// ══════════════════════  CatBoostEncoder  ═══════════════════════════════════
1726// ─────────────────────────────────────────────────────────────────────────────
1727
/// Target encoding with Leave-One-Out regularisation (à la CatBoost).
///
/// At fit time each sample `(cat_i, y_i)` is encoded from the statistics of
/// all *other* samples sharing its category:
/// `λ · mean_{j≠i, cat_j = cat_i}(y_j) + (1 − λ) · prior`,
/// where `λ = n / (n + 1)` grows with the leave-one-out count `n`, so
/// rare categories shrink towards the global prior.
#[derive(Debug, Clone, Default)]
pub struct CatBoostEncoder {
    /// Mapping from category id → (count, sum_of_targets).
    pub category_stats: std::collections::HashMap<usize, (usize, f64)>,
    /// Global prior (mean of all targets).
    pub prior: f32,
}

impl CatBoostEncoder {
    /// Construct an empty encoder (no fitted statistics, prior = 0).
    pub fn new() -> Self {
        Self::default()
    }

    /// Fit per-category statistics and return leave-one-out encodings in a
    /// single call, so the training encoding never sees its own target
    /// (avoids leakage).
    ///
    /// Returns an empty `Vec` on empty or mismatched-length inputs, leaving
    /// the encoder untouched.
    pub fn fit_transform(&mut self, categories: &[usize], targets: &[f32], prior: f32) -> Vec<f32> {
        if categories.is_empty() || categories.len() != targets.len() {
            return Vec::new();
        }
        self.prior = prior;
        self.category_stats.clear();

        // First pass: accumulate the full (count, sum) per category.
        for (&cat, &y) in categories.iter().zip(targets) {
            let slot = self.category_stats.entry(cat).or_insert((0, 0.0));
            slot.0 += 1;
            slot.1 += y as f64;
        }

        // Second pass: encode each sample from the *other* samples' stats.
        let mut encoded = Vec::with_capacity(categories.len());
        for (&cat, &y) in categories.iter().zip(targets) {
            let (count, sum) = match self.category_stats.get(&cat) {
                Some(&stats) => stats,
                None => (0, 0.0),
            };
            let rest = count.saturating_sub(1);
            let rest_mean = if rest == 0 {
                // Singleton category: nothing left after removing itself.
                prior
            } else {
                ((sum - y as f64) / rest as f64) as f32
            };
            // Shrinkage weight: approaches 1 for well-populated categories.
            let lambda = rest as f32 / (rest as f32 + 1.0);
            encoded.push(lambda * rest_mean + (1.0 - lambda) * prior);
        }
        encoded
    }

    /// Encode categories using the fitted statistics; categories never seen
    /// during fitting map to the stored prior.
    pub fn transform(&self, categories: &[usize]) -> Vec<f32> {
        let mut out = Vec::with_capacity(categories.len());
        for &cat in categories {
            let value = if let Some(&(count, sum)) = self.category_stats.get(&cat) {
                let mean = if count == 0 {
                    self.prior
                } else {
                    (sum / count as f64) as f32
                };
                let lambda = count as f32 / (count as f32 + 1.0);
                lambda * mean + (1.0 - lambda) * self.prior
            } else {
                self.prior
            };
            out.push(value);
        }
        out
    }
}
1802
1803// ─────────────────────────────────────────────────────────────────────────────
1804// Tests
1805// ─────────────────────────────────────────────────────────────────────────────
1806
// Unit tests live in the sibling `tests.rs` module file.
#[cfg(test)]
mod tests;