Skip to main content

tl_ai/
train.rs

1// ThinkingLanguage — Training dispatcher
2// Uses linfa for pure-Rust ML training.
3
4use std::collections::{HashMap, HashSet};
5
6use linfa::Dataset;
7use linfa::prelude::*;
8use ndarray::{Array1, Array2, Axis};
9
10use crate::model::{LinfaKind, ModelMeta, TlModel};
11use crate::tensor::TlTensor;
12
13/// Training configuration extracted from TL source.
14pub struct TrainConfig {
15    /// Feature data (2D: samples x features).
16    pub features: TlTensor,
17    /// Target data (1D: samples).
18    pub target: TlTensor,
19    /// Feature column names.
20    pub feature_names: Vec<String>,
21    /// Target column name.
22    pub target_name: String,
23    /// Model name.
24    pub model_name: String,
25    /// Train/test split ratio (0.0 to 1.0, fraction for training).
26    pub split_ratio: f64,
27    /// Hyperparameters.
28    pub hyperparams: HashMap<String, f64>,
29}
30
31/// Train a model using the specified algorithm.
32pub fn train(algorithm: &str, config: &TrainConfig) -> Result<TlModel, String> {
33    match algorithm {
34        "linear" => train_linear(config),
35        "logistic" => train_logistic(config),
36        "tree" | "decision_tree" => train_decision_tree(config),
37        "random_forest" | "forest" => train_random_forest(config),
38        "kmeans" | "k_means" => train_kmeans(config),
39        "knn" | "k_nearest_neighbors" => train_knn(config),
40        "naive_bayes" | "gaussian_nb" | "nb" => train_naive_bayes(config),
41        "dbscan" => train_dbscan(config),
42        "ridge" => train_ridge(config),
43        "gradient_boosting" | "gbt" | "gbm" | "xgboost" => train_gradient_boosting(config),
44        _ => Err(format!(
45            "Unknown training algorithm: '{algorithm}'. Supported: linear, ridge, logistic, \
46             tree, random_forest, gradient_boosting, knn, naive_bayes, kmeans, dbscan"
47        )),
48    }
49}
50
51// ---- shared helpers --------------------------------------------------------
52
53/// Apply a per-row prediction closure over a 1D (single sample) or 2D
54/// (samples × features) input tensor, returning a 1D tensor of predictions.
55fn apply_rowwise<P: Fn(&[f64]) -> f64>(
56    input: &TlTensor,
57    predict_row: P,
58) -> Result<TlTensor, String> {
59    let shape = input.shape();
60    let flat = input.to_vec();
61    if shape.len() == 1 {
62        Ok(TlTensor::from_list(vec![predict_row(&flat)]))
63    } else if shape.len() == 2 {
64        let (rows, cols) = (shape[0], shape[1]);
65        let mut preds = Vec::with_capacity(rows);
66        for i in 0..rows {
67            preds.push(predict_row(&flat[i * cols..(i + 1) * cols]));
68        }
69        Ok(TlTensor::from_list(preds))
70    } else {
71        Err(format!("Input must be 1D or 2D, got {}D", shape.len()))
72    }
73}
74
75/// Serialize a fitted linfa decision tree into a self-contained JSON node so it
76/// can be reloaded and used for inference (linfa's tree isn't serde-serializable).
77fn tree_node_to_json(node: &linfa_trees::TreeNode<f64, usize>) -> serde_json::Value {
78    if node.is_leaf() {
79        serde_json::json!({ "leaf": true, "value": node.prediction().unwrap_or(0) })
80    } else {
81        let (feature, threshold, _) = node.split();
82        let children = node.children(); // [left, right]
83        let left = children[0]
84            .as_ref()
85            .map(|c| tree_node_to_json(c))
86            .unwrap_or(serde_json::Value::Null);
87        let right = children[1]
88            .as_ref()
89            .map(|c| tree_node_to_json(c))
90            .unwrap_or(serde_json::Value::Null);
91        serde_json::json!({ "leaf": false, "feature": feature, "threshold": threshold, "left": left, "right": right })
92    }
93}
94
95/// Traverse a serialized tree (see `tree_node_to_json`) for one feature row.
96/// Matches linfa's split rule: `x[feature] < threshold` goes left, else right.
97fn predict_tree_json(node: &serde_json::Value, row: &[f64]) -> f64 {
98    if node["leaf"].as_bool().unwrap_or(true) {
99        return node["value"].as_f64().unwrap_or(0.0);
100    }
101    let f = node["feature"].as_u64().unwrap_or(0) as usize;
102    let thr = node["threshold"].as_f64().unwrap_or(0.0);
103    let xv = row.get(f).copied().unwrap_or(0.0);
104    if xv < thr {
105        predict_tree_json(&node["left"], row)
106    } else {
107        predict_tree_json(&node["right"], row)
108    }
109}
110
111/// Majority-vote class over a set of serialized trees (used by random forest;
112/// a single decision tree is the one-tree case).
113fn vote_trees(trees: &[serde_json::Value], row: &[f64]) -> f64 {
114    let mut counts: HashMap<i64, usize> = HashMap::new();
115    for t in trees {
116        *counts.entry(predict_tree_json(t, row) as i64).or_insert(0) += 1;
117    }
118    counts
119        .into_iter()
120        .max_by_key(|(_, c)| *c)
121        .map(|(v, _)| v as f64)
122        .unwrap_or(0.0)
123}
124
125fn features_to_array2(features: &TlTensor) -> Result<Array2<f64>, String> {
126    let shape = features.shape();
127    if shape.len() != 2 {
128        return Err(format!("Features must be 2D, got {}D", shape.len()));
129    }
130    let rows = shape[0];
131    let cols = shape[1];
132    let flat = features.to_vec();
133    Array2::from_shape_vec((rows, cols), flat).map_err(|e| format!("Shape error: {e}"))
134}
135
136fn target_to_array1(target: &TlTensor) -> Result<Array1<f64>, String> {
137    let shape = target.shape();
138    if shape.len() != 1 {
139        return Err(format!("Target must be 1D, got {}D", shape.len()));
140    }
141    Ok(Array1::from_vec(target.to_vec()))
142}
143
144fn train_linear(config: &TrainConfig) -> Result<TlModel, String> {
145    let x = features_to_array2(&config.features)?;
146    let y = target_to_array1(&config.target)?;
147    let dataset = Dataset::new(x, y);
148
149    let model = linfa_linear::LinearRegression::default()
150        .fit(&dataset)
151        .map_err(|e| format!("Linear regression training failed: {e}"))?;
152
153    // Compute R² on training data
154    let pred = model.predict(&dataset);
155    let r2 = pred
156        .r2(&dataset)
157        .map_err(|e| format!("R² computation failed: {e}"))?;
158
159    // Serialize model params
160    let params = model.params();
161    let intercept = model.intercept();
162    let model_data = serde_json::json!({
163        "params": params.as_slice().unwrap_or(&[]),
164        "intercept": intercept,
165    });
166    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
167
168    let mut metrics = HashMap::new();
169    metrics.insert("r2".to_string(), r2);
170
171    Ok(TlModel::Linfa {
172        kind: LinfaKind::LinearRegression,
173        data,
174        metadata: ModelMeta {
175            name: config.model_name.clone(),
176            version: "0.1.0".to_string(),
177            created_at: String::new(),
178            features: config.feature_names.clone(),
179            target: config.target_name.clone(),
180            metrics,
181        },
182    })
183}
184
185fn train_logistic(config: &TrainConfig) -> Result<TlModel, String> {
186    let x = features_to_array2(&config.features)?;
187    let y_float = target_to_array1(&config.target)?;
188
189    // Convert targets to bool for binary classification
190    let y_bool: Array1<bool> = y_float.mapv(|v| v > 0.5);
191
192    let dataset = Dataset::new(x, y_bool);
193
194    let model = linfa_logistic::LogisticRegression::default()
195        .max_iterations(100)
196        .fit(&dataset)
197        .map_err(|e| format!("Logistic regression training failed: {e}"))?;
198
199    // Compute accuracy
200    let pred = model.predict(&dataset);
201    let correct = pred
202        .iter()
203        .zip(dataset.targets().iter())
204        .filter(|(p, t)| p == t)
205        .count();
206    let accuracy = correct as f64 / dataset.targets().len() as f64;
207
208    // Serialize model params
209    let params = model.params();
210    let intercept = model.intercept();
211    let params_slice = params.as_slice().unwrap_or(&[]);
212
213    // Map the decision function (sigmoid(x·w + b) > 0.5) back to the original
214    // class labels. linfa's internal class ordering does NOT guarantee that the
215    // "1" class sits on the positive side, so derive the mapping from linfa's own
216    // predictions instead of hard-coding 0/1 — this is what caused inverted labels.
217    let (mut pos_label, mut neg_label) = (1.0_f64, 0.0_f64);
218    {
219        let records = dataset.records();
220        for (i, p) in pred.iter().enumerate() {
221            let row = records.row(i);
222            let logit: f64 = row
223                .iter()
224                .zip(params_slice.iter())
225                .map(|(a, b)| a * b)
226                .sum::<f64>()
227                + intercept;
228            let label = if *p { 1.0 } else { 0.0 };
229            if logit > 0.0 {
230                pos_label = label;
231            } else {
232                neg_label = label;
233            }
234        }
235    }
236
237    let model_data = serde_json::json!({
238        "params": params_slice,
239        "intercept": intercept,
240        "pos_label": pos_label,
241        "neg_label": neg_label,
242    });
243    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
244
245    let mut metrics = HashMap::new();
246    metrics.insert("accuracy".to_string(), accuracy);
247
248    Ok(TlModel::Linfa {
249        kind: LinfaKind::LogisticRegression,
250        data,
251        metadata: ModelMeta {
252            name: config.model_name.clone(),
253            version: "0.1.0".to_string(),
254            created_at: String::new(),
255            features: config.feature_names.clone(),
256            target: config.target_name.clone(),
257            metrics,
258        },
259    })
260}
261
262fn train_decision_tree(config: &TrainConfig) -> Result<TlModel, String> {
263    let x = features_to_array2(&config.features)?;
264    let y_float = target_to_array1(&config.target)?;
265
266    // Convert targets to usize for classification
267    let y_usize: Array1<usize> = y_float.mapv(|v| v as usize);
268
269    let max_depth = config
270        .hyperparams
271        .get("max_depth")
272        .copied()
273        .map(|d| d as usize);
274
275    let dataset = Dataset::new(x, y_usize);
276
277    let mut builder = linfa_trees::DecisionTree::params();
278    if let Some(depth) = max_depth {
279        builder = builder.max_depth(Some(depth));
280    }
281    let model = builder
282        .fit(&dataset)
283        .map_err(|e| format!("Decision tree training failed: {e}"))?;
284
285    // Compute accuracy
286    let pred = model.predict(&dataset);
287    let correct = pred
288        .iter()
289        .zip(dataset.targets().iter())
290        .filter(|(p, t)| p == t)
291        .count();
292    let accuracy = correct as f64 / dataset.targets().len() as f64;
293
294    // Serialize the full tree structure so inference works after reload.
295    let model_data = serde_json::json!({
296        "type": "decision_tree",
297        "accuracy": accuracy,
298        "tree": tree_node_to_json(model.root_node()),
299    });
300    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
301
302    let mut metrics = HashMap::new();
303    metrics.insert("accuracy".to_string(), accuracy);
304
305    Ok(TlModel::Linfa {
306        kind: LinfaKind::DecisionTree,
307        data,
308        metadata: ModelMeta {
309            name: config.model_name.clone(),
310            version: "0.1.0".to_string(),
311            created_at: String::new(),
312            features: config.feature_names.clone(),
313            target: config.target_name.clone(),
314            metrics,
315        },
316    })
317}
318
319/// Random forest: bootstrap-aggregated ensemble of linfa decision trees.
320/// Predicts by majority vote. Hyperparameters: `n_trees` (default 10),
321/// `max_depth` (optional).
322fn train_random_forest(config: &TrainConfig) -> Result<TlModel, String> {
323    let x = features_to_array2(&config.features)?;
324    let y_float = target_to_array1(&config.target)?;
325    let y_usize: Array1<usize> = y_float.mapv(|v| v as usize);
326
327    let n = x.nrows();
328    if n == 0 {
329        return Err("Random forest: no training samples".to_string());
330    }
331    let n_trees = config
332        .hyperparams
333        .get("n_trees")
334        .or_else(|| config.hyperparams.get("trees"))
335        .copied()
336        .map(|v| (v as usize).max(1))
337        .unwrap_or(10);
338    let max_depth = config
339        .hyperparams
340        .get("max_depth")
341        .copied()
342        .map(|d| d as usize);
343
344    // Deterministic xorshift RNG for bootstrap sampling (no extra dependency).
345    let mut seed: u64 = 0x2545F4914F6CDD1D;
346    let mut next = || {
347        seed ^= seed << 13;
348        seed ^= seed >> 7;
349        seed ^= seed << 17;
350        seed
351    };
352
353    let mut trees: Vec<serde_json::Value> = Vec::with_capacity(n_trees);
354    for _ in 0..n_trees {
355        let rows: Vec<usize> = (0..n).map(|_| (next() as usize) % n).collect();
356        let xb = x.select(Axis(0), &rows);
357        let yb = y_usize.select(Axis(0), &rows);
358        let ds = Dataset::new(xb, yb);
359        let mut builder = linfa_trees::DecisionTree::params();
360        if let Some(d) = max_depth {
361            builder = builder.max_depth(Some(d));
362        }
363        let tree = builder
364            .fit(&ds)
365            .map_err(|e| format!("Random forest tree training failed: {e}"))?;
366        trees.push(tree_node_to_json(tree.root_node()));
367    }
368
369    // Training accuracy via majority vote.
370    let flat = x.iter().copied().collect::<Vec<f64>>();
371    let cols = x.ncols();
372    let mut correct = 0usize;
373    for i in 0..n {
374        let row = &flat[i * cols..(i + 1) * cols];
375        if vote_trees(&trees, row) as usize == y_usize[i] {
376            correct += 1;
377        }
378    }
379    let accuracy = correct as f64 / n as f64;
380
381    let model_data = serde_json::json!({ "type": "random_forest", "trees": trees });
382    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
383
384    let mut metrics = HashMap::new();
385    metrics.insert("accuracy".to_string(), accuracy);
386    metrics.insert("n_trees".to_string(), n_trees as f64);
387
388    Ok(TlModel::Linfa {
389        kind: LinfaKind::RandomForest,
390        data,
391        metadata: ModelMeta {
392            name: config.model_name.clone(),
393            version: "0.1.0".to_string(),
394            created_at: String::new(),
395            features: config.feature_names.clone(),
396            target: config.target_name.clone(),
397            metrics,
398        },
399    })
400}
401
402/// K-means clustering (Lloyd's algorithm, pure Rust — unsupervised, the target
403/// column is ignored). Hyperparameters: `k` (default 3), `max_iter` (default
404/// 100). Predict returns the nearest-centroid cluster index per row.
405fn train_kmeans(config: &TrainConfig) -> Result<TlModel, String> {
406    let x = features_to_array2(&config.features)?;
407    let n = x.nrows();
408    let d = x.ncols();
409    if n == 0 {
410        return Err("K-means: no training samples".to_string());
411    }
412    let k = config
413        .hyperparams
414        .get("k")
415        .or_else(|| config.hyperparams.get("clusters"))
416        .copied()
417        .map(|v| (v as usize).max(1))
418        .unwrap_or(3)
419        .min(n);
420    let max_iter = config
421        .hyperparams
422        .get("max_iter")
423        .copied()
424        .map(|v| (v as usize).max(1))
425        .unwrap_or(100);
426
427    // Deterministic init: evenly spaced rows as initial centroids.
428    let mut centroids: Vec<Vec<f64>> = (0..k).map(|i| x.row((i * n) / k).to_vec()).collect();
429    let mut assign = vec![0usize; n];
430
431    for _ in 0..max_iter {
432        let mut changed = false;
433        for (i, slot) in assign.iter_mut().enumerate() {
434            let row = x.row(i);
435            let mut best = 0usize;
436            let mut best_d = f64::INFINITY;
437            for (c, cen) in centroids.iter().enumerate() {
438                let dist: f64 = row.iter().zip(cen).map(|(a, b)| (a - b) * (a - b)).sum();
439                if dist < best_d {
440                    best_d = dist;
441                    best = c;
442                }
443            }
444            if *slot != best {
445                *slot = best;
446                changed = true;
447            }
448        }
449        let mut sums = vec![vec![0.0f64; d]; k];
450        let mut counts = vec![0usize; k];
451        for i in 0..n {
452            let row = x.row(i);
453            counts[assign[i]] += 1;
454            for j in 0..d {
455                sums[assign[i]][j] += row[j];
456            }
457        }
458        for c in 0..k {
459            if counts[c] > 0 {
460                for j in 0..d {
461                    centroids[c][j] = sums[c][j] / counts[c] as f64;
462                }
463            }
464        }
465        if !changed {
466            break;
467        }
468    }
469
470    // Inertia (within-cluster sum of squares) as a quality metric.
471    let mut inertia = 0.0f64;
472    for i in 0..n {
473        let row = x.row(i);
474        let cen = &centroids[assign[i]];
475        inertia += row
476            .iter()
477            .zip(cen)
478            .map(|(a, b)| (a - b) * (a - b))
479            .sum::<f64>();
480    }
481
482    let model_data = serde_json::json!({ "type": "kmeans", "centroids": centroids });
483    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
484
485    let mut metrics = HashMap::new();
486    metrics.insert("k".to_string(), k as f64);
487    metrics.insert("inertia".to_string(), inertia);
488
489    Ok(TlModel::Linfa {
490        kind: LinfaKind::KMeans,
491        data,
492        metadata: ModelMeta {
493            name: config.model_name.clone(),
494            version: "0.1.0".to_string(),
495            created_at: String::new(),
496            features: config.feature_names.clone(),
497            target: config.target_name.clone(),
498            metrics,
499        },
500    })
501}
502
/// Squared Euclidean distance between two rows.
///
/// Coordinates are paired positionally; if the rows differ in length the
/// extra coordinates of the longer one are ignored.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
507
508/// k-NN majority vote of `ytrain` over the `k` nearest rows of `xtrain`.
509fn knn_vote(xtrain: &[Vec<f64>], ytrain: &[f64], k: usize, row: &[f64]) -> f64 {
510    let mut dists: Vec<(f64, f64)> = xtrain
511        .iter()
512        .zip(ytrain)
513        .map(|(p, &l)| (sq_dist(p, row), l))
514        .collect();
515    dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
516    let mut counts: HashMap<i64, usize> = HashMap::new();
517    for (_, l) in dists.iter().take(k.min(dists.len())) {
518        *counts.entry(*l as i64).or_insert(0) += 1;
519    }
520    counts
521        .into_iter()
522        .max_by_key(|(_, c)| *c)
523        .map(|(v, _)| v as f64)
524        .unwrap_or(0.0)
525}
526
/// Solve A·x = b (n×n) by Gauss-Jordan elimination with partial pivoting.
///
/// `n` is taken from `b.len()`. Returns `None` when
/// * `a` is not exactly n×n,
/// * the matrix is (numerically) singular — a pivot with absolute value below
///   `1e-12` — or a pivot is NaN/infinite, or
/// * the computed solution contains a non-finite value.
///
/// An empty system (`n == 0`) yields `Some(vec![])`.
fn solve_linear_system(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Option<Vec<f64>> {
    let n = b.len();
    // The elimination indexes `a[r][col]` for every r, col < n; reject a
    // mis-shaped matrix up front instead of panicking part-way through.
    if a.len() != n || a.iter().any(|row| row.len() != n) {
        return None;
    }

    for col in 0..n {
        // Partial pivoting: the first row at or below the diagonal holding the
        // largest absolute value in this column.
        let mut piv = col;
        for r in (col + 1)..n {
            if a[r][col].abs() > a[piv][col].abs() {
                piv = r;
            }
        }
        let pivot = a[piv][col];
        if !pivot.is_finite() || pivot.abs() < 1e-12 {
            return None;
        }
        a.swap(col, piv);
        b.swap(col, piv);

        // Scale the pivot row so the pivot becomes 1.
        for v in a[col].iter_mut() {
            *v /= pivot;
        }
        b[col] /= pivot;

        // Temporarily move the pivot row out of the matrix so the other rows
        // can be updated in place without cloning it.
        let pivot_row = std::mem::take(&mut a[col]);
        let pivot_b = b[col];
        for (r, row) in a.iter_mut().enumerate() {
            if r == col {
                continue;
            }
            let factor = row[col];
            if factor != 0.0 {
                for (v, p) in row.iter_mut().zip(&pivot_row) {
                    *v -= factor * p;
                }
                b[r] -= factor * pivot_b;
            }
        }
        a[col] = pivot_row;
    }

    // NaN/inf anywhere in the input propagates into the solution; report that
    // as "no solution" rather than handing back unusable numbers.
    let all_finite = b.iter().all(|v| v.is_finite());
    all_finite.then_some(b)
}
564
565/// k-nearest-neighbors classifier. Stores the training set; predicts by majority
566/// vote of the `k` (default 5) nearest rows.
567fn train_knn(config: &TrainConfig) -> Result<TlModel, String> {
568    let x = features_to_array2(&config.features)?;
569    let y = target_to_array1(&config.target)?;
570    let k = config
571        .hyperparams
572        .get("k")
573        .or_else(|| config.hyperparams.get("neighbors"))
574        .copied()
575        .map(|v| (v as usize).max(1))
576        .unwrap_or(5);
577    let rows: Vec<Vec<f64>> = (0..x.nrows()).map(|i| x.row(i).to_vec()).collect();
578    let labels: Vec<f64> = y.to_vec();
579
580    let mut correct = 0usize;
581    for i in 0..rows.len() {
582        if knn_vote(&rows, &labels, k, &rows[i]) == labels[i] {
583            correct += 1;
584        }
585    }
586    let accuracy = if rows.is_empty() {
587        0.0
588    } else {
589        correct as f64 / rows.len() as f64
590    };
591
592    let model_data = serde_json::json!({ "type": "knn", "k": k, "x": rows, "y": labels });
593    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
594    let mut metrics = HashMap::new();
595    metrics.insert("accuracy".to_string(), accuracy);
596    metrics.insert("k".to_string(), k as f64);
597    Ok(linfa_model(LinfaKind::Knn, data, config, metrics))
598}
599
600/// Gaussian Naive Bayes classifier. Stores per-class priors and per-feature
601/// mean/variance; predicts the maximum-a-posteriori class.
602fn train_naive_bayes(config: &TrainConfig) -> Result<TlModel, String> {
603    let x = features_to_array2(&config.features)?;
604    let y = target_to_array1(&config.target)?;
605    let n = x.nrows();
606    let d = x.ncols();
607    if n == 0 {
608        return Err("Naive Bayes: no training samples".to_string());
609    }
610    // Group row indices by class label.
611    let mut by_class: HashMap<i64, Vec<usize>> = HashMap::new();
612    for i in 0..n {
613        by_class.entry(y[i] as i64).or_default().push(i);
614    }
615    let mut classes: Vec<serde_json::Value> = Vec::new();
616    for (label, idxs) in &by_class {
617        let cnt = idxs.len();
618        let mut means = vec![0.0f64; d];
619        for &i in idxs {
620            let row = x.row(i);
621            for j in 0..d {
622                means[j] += row[j];
623            }
624        }
625        for m in &mut means {
626            *m /= cnt as f64;
627        }
628        let mut vars = vec![0.0f64; d];
629        for &i in idxs {
630            let row = x.row(i);
631            for j in 0..d {
632                vars[j] += (row[j] - means[j]).powi(2);
633            }
634        }
635        for v in &mut vars {
636            *v = (*v / cnt as f64).max(1e-9); // floor to avoid div-by-zero
637        }
638        classes.push(serde_json::json!({
639            "label": *label as f64,
640            "prior": (cnt as f64 / n as f64).ln(),
641            "means": means,
642            "vars": vars,
643        }));
644    }
645
646    // Training accuracy.
647    let nb = NaiveBayesModel::from_json(&classes);
648    let correct = (0..n)
649        .filter(|&i| nb.predict(&x.row(i).to_vec()) == y[i].round())
650        .count();
651    let accuracy = correct as f64 / n as f64;
652
653    let model_data = serde_json::json!({ "type": "naive_bayes", "classes": classes });
654    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
655    let mut metrics = HashMap::new();
656    metrics.insert("accuracy".to_string(), accuracy);
657    metrics.insert("classes".to_string(), by_class.len() as f64);
658    Ok(linfa_model(LinfaKind::NaiveBayes, data, config, metrics))
659}
660
661/// Parsed Gaussian-NB model used for scoring (train accuracy + inference).
662struct NaiveBayesModel {
663    classes: Vec<(f64, f64, Vec<f64>, Vec<f64>)>, // (label, log_prior, means, vars)
664}
665impl NaiveBayesModel {
666    fn from_json(classes: &[serde_json::Value]) -> Self {
667        let classes = classes
668            .iter()
669            .map(|c| {
670                let label = c["label"].as_f64().unwrap_or(0.0);
671                let prior = c["prior"].as_f64().unwrap_or(0.0);
672                let means: Vec<f64> =
673                    serde_json::from_value(c["means"].clone()).unwrap_or_default();
674                let vars: Vec<f64> = serde_json::from_value(c["vars"].clone()).unwrap_or_default();
675                (label, prior, means, vars)
676            })
677            .collect();
678        Self { classes }
679    }
680    fn predict(&self, row: &[f64]) -> f64 {
681        let mut best_label = 0.0;
682        let mut best_score = f64::NEG_INFINITY;
683        for (label, log_prior, means, vars) in &self.classes {
684            let mut score = *log_prior;
685            for j in 0..row.len().min(means.len()) {
686                let v = vars[j].max(1e-9);
687                score += -0.5
688                    * ((row[j] - means[j]).powi(2) / v + (2.0 * std::f64::consts::PI * v).ln());
689            }
690            if score > best_score {
691                best_score = score;
692                best_label = *label;
693            }
694        }
695        best_label
696    }
697}
698
699/// DBSCAN density clustering (unsupervised; `target` ignored). Hyperparameters:
700/// `eps` (neighborhood radius, default 0.5) and `min_samples` (default 3).
701/// Predict assigns a new point to the cluster of the nearest core point within
702/// `eps`, or -1 (noise) if none.
703fn train_dbscan(config: &TrainConfig) -> Result<TlModel, String> {
704    let x = features_to_array2(&config.features)?;
705    let n = x.nrows();
706    if n == 0 {
707        return Err("DBSCAN: no training samples".to_string());
708    }
709    let pts: Vec<Vec<f64>> = (0..n).map(|i| x.row(i).to_vec()).collect();
710    let eps = config.hyperparams.get("eps").copied().unwrap_or(0.5);
711    let min_samples = config
712        .hyperparams
713        .get("min_samples")
714        .or_else(|| config.hyperparams.get("min_points"))
715        .copied()
716        .map(|v| (v as usize).max(1))
717        .unwrap_or(3);
718    let eps2 = eps * eps;
719    let neighbors = |i: usize| -> Vec<usize> {
720        (0..n)
721            .filter(|&j| sq_dist(&pts[i], &pts[j]) <= eps2)
722            .collect()
723    };
724
725    let mut labels = vec![-1i64; n];
726    let mut visited = vec![false; n];
727    let mut cid = 0i64;
728    for i in 0..n {
729        if visited[i] {
730            continue;
731        }
732        visited[i] = true;
733        let nb = neighbors(i);
734        if nb.len() < min_samples {
735            continue; // provisional noise (may be absorbed as a border point)
736        }
737        labels[i] = cid;
738        let mut queue = nb;
739        let mut qi = 0;
740        while qi < queue.len() {
741            let q = queue[qi];
742            qi += 1;
743            if labels[q] < 0 {
744                labels[q] = cid;
745            }
746            if !visited[q] {
747                visited[q] = true;
748                let qnb = neighbors(q);
749                if qnb.len() >= min_samples {
750                    for m in qnb {
751                        if !queue.contains(&m) {
752                            queue.push(m);
753                        }
754                    }
755                }
756            }
757        }
758        cid += 1;
759    }
760
761    let mut cores: Vec<serde_json::Value> = Vec::new();
762    let mut n_noise = 0usize;
763    for i in 0..n {
764        if labels[i] < 0 {
765            n_noise += 1;
766        } else if neighbors(i).len() >= min_samples {
767            cores.push(serde_json::json!({ "p": pts[i], "c": labels[i] as f64 }));
768        }
769    }
770
771    let model_data = serde_json::json!({ "type": "dbscan", "eps": eps, "cores": cores });
772    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
773    let mut metrics = HashMap::new();
774    metrics.insert("clusters".to_string(), cid as f64);
775    metrics.insert("noise".to_string(), n_noise as f64);
776    Ok(linfa_model(LinfaKind::Dbscan, data, config, metrics))
777}
778
779/// Ridge regression (L2-regularized least squares) solved via the normal
780/// equations. Hyperparameter `alpha` (a.k.a. `lambda`, default 1.0); the
781/// intercept is not regularized.
782fn train_ridge(config: &TrainConfig) -> Result<TlModel, String> {
783    let x = features_to_array2(&config.features)?;
784    let y = target_to_array1(&config.target)?;
785    let n = x.nrows();
786    let d = x.ncols();
787    if n == 0 {
788        return Err("Ridge: no training samples".to_string());
789    }
790    let lambda = config
791        .hyperparams
792        .get("alpha")
793        .or_else(|| config.hyperparams.get("lambda"))
794        .copied()
795        .unwrap_or(1.0);
796
797    let p = d + 1; // trailing intercept column
798    let row_aug = |i: usize| -> Vec<f64> {
799        let mut r = x.row(i).to_vec();
800        r.push(1.0);
801        r
802    };
803    let mut a = vec![vec![0.0f64; p]; p];
804    let mut bvec = vec![0.0f64; p];
805    for i in 0..n {
806        let r = row_aug(i);
807        let yi = y[i];
808        for j in 0..p {
809            for k2 in 0..p {
810                a[j][k2] += r[j] * r[k2];
811            }
812            bvec[j] += r[j] * yi;
813        }
814    }
815    // Regularize the coefficients, not the intercept (the trailing row/col `d`).
816    for (j, row) in a.iter_mut().enumerate().take(d) {
817        row[j] += lambda;
818    }
819    let w = solve_linear_system(a, bvec)
820        .ok_or("Ridge: singular system — try a larger alpha or fewer collinear features")?;
821    let coef: Vec<f64> = w[0..d].to_vec();
822    let intercept = w[d];
823
824    // R² on the training data.
825    let mean_y = y.iter().sum::<f64>() / n as f64;
826    let (mut ss_res, mut ss_tot) = (0.0, 0.0);
827    for i in 0..n {
828        let row = x.row(i);
829        let pred: f64 = row.iter().zip(&coef).map(|(a, b)| a * b).sum::<f64>() + intercept;
830        ss_res += (y[i] - pred).powi(2);
831        ss_tot += (y[i] - mean_y).powi(2);
832    }
833    let r2 = if ss_tot > 0.0 {
834        1.0 - ss_res / ss_tot
835    } else {
836        0.0
837    };
838
839    let model_data = serde_json::json!({ "type": "ridge", "params": coef, "intercept": intercept });
840    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
841    let mut metrics = HashMap::new();
842    metrics.insert("r2".to_string(), r2);
843    Ok(linfa_model(LinfaKind::Ridge, data, config, metrics))
844}
845
846/// Build a weighted CART **regression** tree over sample indices `idx`, fitting
847/// targets `r` with weights `w` by minimizing weighted SSE. Emits the same JSON
848/// node format as the decision-tree builder, so `predict_tree_json` traverses it.
849/// (linfa-trees is classifier-only, so gradient boosting needs this.)
850fn build_reg_tree(
851    idx: &[usize],
852    x: &Array2<f64>,
853    r: &[f64],
854    w: &[f64],
855    depth: usize,
856    max_depth: usize,
857    min_leaf: usize,
858) -> serde_json::Value {
859    let (mut sw, mut swr, mut swr2) = (0.0f64, 0.0f64, 0.0f64);
860    for &i in idx {
861        sw += w[i];
862        swr += w[i] * r[i];
863        swr2 += w[i] * r[i] * r[i];
864    }
865    let leaf_val = if sw > 0.0 { swr / sw } else { 0.0 };
866    let leaf = serde_json::json!({ "leaf": true, "value": leaf_val });
867    if depth >= max_depth || idx.len() <= min_leaf.max(1) || sw <= 0.0 {
868        return leaf;
869    }
870    let parent_sse = swr2 - swr * swr / sw;
871
872    let d = x.ncols();
873    let mut best: Option<(usize, f64, f64)> = None; // (feature, threshold, sse)
874    for f in 0..d {
875        let mut order: Vec<usize> = idx.to_vec();
876        order.sort_by(|&a, &b| {
877            x[[a, f]]
878                .partial_cmp(&x[[b, f]])
879                .unwrap_or(std::cmp::Ordering::Equal)
880        });
881        let (mut lw, mut lwr, mut lwr2) = (0.0f64, 0.0f64, 0.0f64);
882        for k in 0..order.len() - 1 {
883            let i = order[k];
884            lw += w[i];
885            lwr += w[i] * r[i];
886            lwr2 += w[i] * r[i] * r[i];
887            let (xi, xnext) = (x[[order[k], f]], x[[order[k + 1], f]]);
888            if xi == xnext {
889                continue;
890            }
891            let rw = sw - lw;
892            if lw <= 0.0 || rw <= 0.0 {
893                continue;
894            }
895            let sse_l = lwr2 - lwr * lwr / lw;
896            let sse_r = (swr2 - lwr2) - (swr - lwr) * (swr - lwr) / rw;
897            let sse = sse_l + sse_r;
898            if best.is_none_or(|(_, _, bs)| sse < bs) {
899                best = Some((f, (xi + xnext) / 2.0, sse));
900            }
901        }
902    }
903
904    match best {
905        Some((f, thr, sse)) if sse < parent_sse - 1e-12 => {
906            let left: Vec<usize> = idx.iter().copied().filter(|&i| x[[i, f]] < thr).collect();
907            let right: Vec<usize> = idx.iter().copied().filter(|&i| x[[i, f]] >= thr).collect();
908            if left.is_empty() || right.is_empty() {
909                return leaf;
910            }
911            serde_json::json!({
912                "leaf": false, "feature": f, "threshold": thr,
913                "left": build_reg_tree(&left, x, r, w, depth + 1, max_depth, min_leaf),
914                "right": build_reg_tree(&right, x, r, w, depth + 1, max_depth, min_leaf),
915            })
916        }
917        _ => leaf,
918    }
919}
920
921/// Gradient-boosted decision trees (the XGBoost family). Fits shallow regression
922/// trees to the gradient/Hessian of the loss (second-order / Newton leaves, like
923/// XGBoost). Auto-detects the task: ≤2 distinct 0/1 targets ⇒ binary logistic
924/// classification, otherwise squared-error regression (override with hyperparam
925/// `objective`: 1 = binary, 0 = regression). Hyperparameters: `n_estimators`
926/// (default 100), `learning_rate` (0.1), `max_depth` (3), `min_leaf` (1).
927fn train_gradient_boosting(config: &TrainConfig) -> Result<TlModel, String> {
928    let x = features_to_array2(&config.features)?;
929    let y = target_to_array1(&config.target)?;
930    let n = x.nrows();
931    if n == 0 {
932        return Err("Gradient boosting: no training samples".to_string());
933    }
934    let hp_usize = |a: &str, b: &str, def: usize| -> usize {
935        config
936            .hyperparams
937            .get(a)
938            .or_else(|| config.hyperparams.get(b))
939            .copied()
940            .map(|v| (v as usize).max(1))
941            .unwrap_or(def)
942    };
943    let n_est = hp_usize("n_estimators", "trees", 100);
944    let max_depth = hp_usize("max_depth", "depth", 3);
945    let min_leaf = hp_usize("min_leaf", "min_samples_leaf", 1);
946    let lr = config
947        .hyperparams
948        .get("learning_rate")
949        .or_else(|| config.hyperparams.get("eta"))
950        .copied()
951        .unwrap_or(0.1);
952
953    let all01 = y.iter().all(|v| *v == 0.0 || *v == 1.0);
954    let distinct: HashSet<i64> = y.iter().map(|v| *v as i64).collect();
955    let binary = match config.hyperparams.get("objective") {
956        Some(o) => *o > 0.5,
957        None => all01 && distinct.len() <= 2,
958    };
959    if binary && !all01 {
960        return Err("Gradient boosting (binary objective) requires 0/1 targets".to_string());
961    }
962
963    // Initial raw score: mean (regression) or log-odds (classification).
964    let init = if binary {
965        let pos = y.iter().filter(|&&v| v == 1.0).count() as f64;
966        let p = (pos / n as f64).clamp(1e-6, 1.0 - 1e-6);
967        (p / (1.0 - p)).ln()
968    } else {
969        y.iter().sum::<f64>() / n as f64
970    };
971
972    let mut f_scores = vec![init; n];
973    let all_idx: Vec<usize> = (0..n).collect();
974    let mut trees: Vec<serde_json::Value> = Vec::with_capacity(n_est);
975
976    for _ in 0..n_est {
977        // Per-sample gradient g and Hessian h of the loss; Newton pseudo-residual
978        // r = -g/h with weight h reproduces XGBoost's optimal leaf weight -G/H.
979        let mut r = vec![0.0f64; n];
980        let mut w = vec![0.0f64; n];
981        for i in 0..n {
982            let (g, h) = if binary {
983                let p = 1.0 / (1.0 + (-f_scores[i]).exp());
984                (p - y[i], (p * (1.0 - p)).max(1e-6))
985            } else {
986                (f_scores[i] - y[i], 1.0)
987            };
988            r[i] = -g / h;
989            w[i] = h;
990        }
991        let tree = build_reg_tree(&all_idx, &x, &r, &w, 0, max_depth, min_leaf);
992        for (i, fs) in f_scores.iter_mut().enumerate() {
993            *fs += lr * predict_tree_json(&tree, &x.row(i).to_vec());
994        }
995        trees.push(tree);
996    }
997
998    // Training metric.
999    let mut metrics = HashMap::new();
1000    if binary {
1001        let correct = (0..n)
1002            .filter(|&i| ((1.0 / (1.0 + (-f_scores[i]).exp()) > 0.5) as i32 as f64) == y[i])
1003            .count();
1004        metrics.insert("accuracy".to_string(), correct as f64 / n as f64);
1005    } else {
1006        let mean_y = y.iter().sum::<f64>() / n as f64;
1007        let (mut ss_res, mut ss_tot) = (0.0, 0.0);
1008        for i in 0..n {
1009            ss_res += (y[i] - f_scores[i]).powi(2);
1010            ss_tot += (y[i] - mean_y).powi(2);
1011        }
1012        metrics.insert(
1013            "r2".to_string(),
1014            if ss_tot > 0.0 {
1015                1.0 - ss_res / ss_tot
1016            } else {
1017                0.0
1018            },
1019        );
1020    }
1021    metrics.insert("n_estimators".to_string(), n_est as f64);
1022
1023    let model_data = serde_json::json!({
1024        "type": "gradient_boosting", "binary": binary, "init": init, "lr": lr, "trees": trees,
1025    });
1026    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
1027    Ok(linfa_model(
1028        LinfaKind::GradientBoosting,
1029        data,
1030        config,
1031        metrics,
1032    ))
1033}
1034
1035/// Build a `TlModel::Linfa` with standard metadata (shared by the new algorithms).
1036fn linfa_model(
1037    kind: LinfaKind,
1038    data: Vec<u8>,
1039    config: &TrainConfig,
1040    metrics: HashMap<String, f64>,
1041) -> TlModel {
1042    TlModel::Linfa {
1043        kind,
1044        data,
1045        metadata: ModelMeta {
1046            name: config.model_name.clone(),
1047            version: "0.1.0".to_string(),
1048            created_at: String::new(),
1049            features: config.feature_names.clone(),
1050            target: config.target_name.clone(),
1051            metrics,
1052        },
1053    }
1054}
1055
1056/// Predict using a linfa model.
1057pub fn predict_linfa(model: &TlModel, input: &TlTensor) -> Result<TlTensor, String> {
1058    match model {
1059        TlModel::Linfa { kind, data, .. } => match kind {
1060            LinfaKind::LinearRegression | LinfaKind::Ridge => {
1061                let model_data: serde_json::Value = serde_json::from_slice(data)
1062                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1063                let params: Vec<f64> = model_data["params"]
1064                    .as_array()
1065                    .ok_or("Missing params")?
1066                    .iter()
1067                    .map(|v| v.as_f64().unwrap_or(0.0))
1068                    .collect();
1069                let intercept: f64 = model_data["intercept"].as_f64().unwrap_or(0.0);
1070
1071                let shape = input.shape();
1072                if shape.len() == 1 {
1073                    let x = input.to_vec();
1074                    let pred: f64 =
1075                        x.iter().zip(params.iter()).map(|(a, b)| a * b).sum::<f64>() + intercept;
1076                    Ok(TlTensor::from_list(vec![pred]))
1077                } else if shape.len() == 2 {
1078                    let rows = shape[0];
1079                    let cols = shape[1];
1080                    let flat = input.to_vec();
1081                    let mut preds = Vec::with_capacity(rows);
1082                    for i in 0..rows {
1083                        let row = &flat[i * cols..(i + 1) * cols];
1084                        let pred: f64 = row
1085                            .iter()
1086                            .zip(params.iter())
1087                            .map(|(a, b)| a * b)
1088                            .sum::<f64>()
1089                            + intercept;
1090                        preds.push(pred);
1091                    }
1092                    Ok(TlTensor::from_list(preds))
1093                } else {
1094                    Err(format!("Input must be 1D or 2D, got {}D", shape.len()))
1095                }
1096            }
1097            LinfaKind::LogisticRegression => {
1098                let model_data: serde_json::Value = serde_json::from_slice(data)
1099                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1100                let params: Vec<f64> = model_data["params"]
1101                    .as_array()
1102                    .ok_or("Missing params")?
1103                    .iter()
1104                    .map(|v| v.as_f64().unwrap_or(0.0))
1105                    .collect();
1106                let intercept: f64 = model_data["intercept"].as_f64().unwrap_or(0.0);
1107                // Class labels for each side of the decision boundary (persisted at
1108                // train time). Default to 1/0 for models saved before this fix.
1109                let pos_label = model_data["pos_label"].as_f64().unwrap_or(1.0);
1110                let neg_label = model_data["neg_label"].as_f64().unwrap_or(0.0);
1111
1112                apply_rowwise(input, |row| {
1113                    let logit: f64 = row
1114                        .iter()
1115                        .zip(params.iter())
1116                        .map(|(a, b)| a * b)
1117                        .sum::<f64>()
1118                        + intercept;
1119                    let prob = 1.0 / (1.0 + (-logit).exp());
1120                    if prob > 0.5 { pos_label } else { neg_label }
1121                })
1122            }
1123            LinfaKind::DecisionTree => {
1124                let model_data: serde_json::Value = serde_json::from_slice(data)
1125                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1126                let tree = model_data["tree"].clone();
1127                if tree.is_null() {
1128                    return Err(
1129                        "This decision-tree model was saved without its tree structure; retrain it."
1130                            .to_string(),
1131                    );
1132                }
1133                apply_rowwise(input, |row| predict_tree_json(&tree, row))
1134            }
1135            LinfaKind::RandomForest => {
1136                let model_data: serde_json::Value = serde_json::from_slice(data)
1137                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1138                let trees: Vec<serde_json::Value> = model_data["trees"]
1139                    .as_array()
1140                    .ok_or("Missing trees")?
1141                    .clone();
1142                apply_rowwise(input, |row| vote_trees(&trees, row))
1143            }
1144            LinfaKind::KMeans => {
1145                let model_data: serde_json::Value = serde_json::from_slice(data)
1146                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1147                let centroids: Vec<Vec<f64>> =
1148                    serde_json::from_value(model_data["centroids"].clone())
1149                        .map_err(|e| format!("Missing centroids: {e}"))?;
1150                apply_rowwise(input, |row| {
1151                    let mut best = 0usize;
1152                    let mut best_d = f64::INFINITY;
1153                    for (c, cen) in centroids.iter().enumerate() {
1154                        let dist: f64 = row.iter().zip(cen).map(|(a, b)| (a - b) * (a - b)).sum();
1155                        if dist < best_d {
1156                            best_d = dist;
1157                            best = c;
1158                        }
1159                    }
1160                    best as f64
1161                })
1162            }
1163            LinfaKind::Knn => {
1164                let model_data: serde_json::Value = serde_json::from_slice(data)
1165                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1166                let k = model_data["k"].as_u64().unwrap_or(5) as usize;
1167                let xtrain: Vec<Vec<f64>> = serde_json::from_value(model_data["x"].clone())
1168                    .map_err(|e| format!("Missing training data: {e}"))?;
1169                let ytrain: Vec<f64> = serde_json::from_value(model_data["y"].clone())
1170                    .map_err(|e| format!("Missing labels: {e}"))?;
1171                apply_rowwise(input, |row| knn_vote(&xtrain, &ytrain, k, row))
1172            }
1173            LinfaKind::NaiveBayes => {
1174                let model_data: serde_json::Value = serde_json::from_slice(data)
1175                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1176                let classes = model_data["classes"]
1177                    .as_array()
1178                    .ok_or("Missing classes")?
1179                    .clone();
1180                let nb = NaiveBayesModel::from_json(&classes);
1181                apply_rowwise(input, |row| nb.predict(row))
1182            }
1183            LinfaKind::Dbscan => {
1184                let model_data: serde_json::Value = serde_json::from_slice(data)
1185                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1186                let eps = model_data["eps"].as_f64().unwrap_or(0.5);
1187                let eps2 = eps * eps;
1188                let cores: Vec<(Vec<f64>, f64)> = model_data["cores"]
1189                    .as_array()
1190                    .ok_or("Missing cores")?
1191                    .iter()
1192                    .map(|c| {
1193                        let p: Vec<f64> =
1194                            serde_json::from_value(c["p"].clone()).unwrap_or_default();
1195                        (p, c["c"].as_f64().unwrap_or(-1.0))
1196                    })
1197                    .collect();
1198                apply_rowwise(input, |row| {
1199                    let mut best = -1.0;
1200                    let mut best_d = f64::INFINITY;
1201                    for (p, c) in &cores {
1202                        let dist = sq_dist(p, row);
1203                        if dist <= eps2 && dist < best_d {
1204                            best_d = dist;
1205                            best = *c;
1206                        }
1207                    }
1208                    best
1209                })
1210            }
1211            LinfaKind::GradientBoosting => {
1212                let model_data: serde_json::Value = serde_json::from_slice(data)
1213                    .map_err(|e| format!("Deserialization failed: {e}"))?;
1214                let binary = model_data["binary"].as_bool().unwrap_or(false);
1215                let init = model_data["init"].as_f64().unwrap_or(0.0);
1216                let lr = model_data["lr"].as_f64().unwrap_or(0.1);
1217                let trees: Vec<serde_json::Value> = model_data["trees"]
1218                    .as_array()
1219                    .ok_or("Missing trees")?
1220                    .clone();
1221                apply_rowwise(input, |row| {
1222                    let mut score = init;
1223                    for t in &trees {
1224                        score += lr * predict_tree_json(t, row);
1225                    }
1226                    if binary {
1227                        if 1.0 / (1.0 + (-score).exp()) > 0.5 {
1228                            1.0
1229                        } else {
1230                            0.0
1231                        }
1232                    } else {
1233                        score
1234                    }
1235                })
1236            }
1237        },
1238        _ => Err("predict_linfa called on non-Linfa model".to_string()),
1239    }
1240}
1241
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a `TrainConfig` that trains on all rows (`split_ratio` = 1.0).
    fn make_config(
        features: TlTensor,
        target: TlTensor,
        feature_names: &[&str],
        hyperparams: HashMap<String, f64>,
    ) -> TrainConfig {
        TrainConfig {
            features,
            target,
            feature_names: feature_names.iter().map(|s| s.to_string()).collect(),
            target_name: "y".to_string(),
            model_name: "test".to_string(),
            split_ratio: 1.0,
            hyperparams,
        }
    }

    /// Eight samples of a single feature with values 1..=8.
    fn single_feature_1_to_8() -> TlTensor {
        TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[8, 1]).unwrap()
    }

    #[test]
    fn test_train_linear_regression() {
        // y = 2*x1 + 3*x2 + 1
        let features = TlTensor::from_vec(
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 2.0, 1.0, 3.0, 2.0, 3.0,
                3.0, 3.0, 4.0, 4.0,
            ],
            &[10, 2],
        )
        .unwrap();
        let target = TlTensor::from_list(vec![
            6.0, 8.0, 10.0, 9.0, 11.0, 13.0, 12.0, 14.0, 16.0, 21.0,
        ]);
        let config = make_config(features, target, &["x1", "x2"], HashMap::new());

        let model = train("linear", &config).unwrap();
        if let TlModel::Linfa { metadata, .. } = &model {
            assert!(metadata.metrics["r2"] > 0.9, "R² should be > 0.9");
        } else {
            panic!("Expected Linfa model");
        }
    }

    #[test]
    fn test_predict_linear() {
        let features =
            TlTensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 0.0], &[4, 2]).unwrap();
        let target = TlTensor::from_list(vec![2.0, 3.0, 5.0, 4.0]);
        let config = make_config(features, target, &["x1", "x2"], HashMap::new());

        let model = train("linear", &config).unwrap();
        let input = TlTensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap();
        let pred = predict_linfa(&model, &input).unwrap();
        // Should be close to 2.0
        assert!((pred.to_vec()[0] - 2.0).abs() < 1.0);
    }

    #[test]
    fn test_unknown_algorithm_is_rejected() {
        let config = make_config(
            single_feature_1_to_8(),
            TlTensor::from_list(vec![0.0; 8]),
            &["x"],
            HashMap::new(),
        );
        match train("no_such_algorithm", &config) {
            Err(msg) => assert!(msg.contains("Unknown training algorithm")),
            Ok(_) => panic!("Expected an error for an unknown algorithm"),
        }
    }

    #[test]
    fn test_gradient_boosting_regression() {
        // A step function: a single split explains the target completely.
        let target = TlTensor::from_list(vec![0.0, 0.0, 0.0, 0.0, 10.0, 10.0, 10.0, 10.0]);
        let config = make_config(single_feature_1_to_8(), target, &["x"], HashMap::new());

        let model = train("gradient_boosting", &config).unwrap();
        if let TlModel::Linfa { metadata, .. } = &model {
            assert!(metadata.metrics["r2"] > 0.99, "R² should be > 0.99");
            assert_eq!(metadata.metrics["n_estimators"], 100.0);
        } else {
            panic!("Expected Linfa model");
        }

        let input = TlTensor::from_vec(vec![2.0, 7.0], &[2, 1]).unwrap();
        let pred = predict_linfa(&model, &input).unwrap().to_vec();
        assert!((pred[0] - 0.0).abs() < 0.1);
        assert!((pred[1] - 10.0).abs() < 0.1);
    }

    #[test]
    fn test_gradient_boosting_binary() {
        // All-0/1 targets are auto-detected as binary classification.
        let target = TlTensor::from_list(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
        let config = make_config(single_feature_1_to_8(), target, &["x"], HashMap::new());

        let model = train("gbt", &config).unwrap();
        if let TlModel::Linfa { metadata, .. } = &model {
            assert_eq!(metadata.metrics["accuracy"], 1.0);
        } else {
            panic!("Expected Linfa model");
        }

        let input = TlTensor::from_vec(vec![1.0, 8.0], &[2, 1]).unwrap();
        let pred = predict_linfa(&model, &input).unwrap().to_vec();
        assert_eq!(pred, vec![0.0, 1.0]);
    }

    #[test]
    fn test_gradient_boosting_binary_objective_requires_01_targets() {
        let target = TlTensor::from_list(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let mut hyperparams = HashMap::new();
        hyperparams.insert("objective".to_string(), 1.0);
        let config = make_config(single_feature_1_to_8(), target, &["x"], hyperparams);

        match train("gradient_boosting", &config) {
            Err(msg) => assert!(msg.contains("requires 0/1 targets")),
            Ok(_) => panic!("Expected an error for non-0/1 targets with the binary objective"),
        }
    }
}