sklears_tree/model_tree.rs

//! Model Trees - Decision Trees with Linear Models in Leaves
//!
//! Implementation of M5 Model Trees and variants that use linear regression
//! models in leaf nodes instead of constant predictions.
//!
//! # References
//!
//! - Quinlan, J. R. (1992). Learning with continuous classes.
//!   In Proceedings of the 5th Australian Joint Conference on AI (pp. 343-348).
//! - Wang, Y., & Witten, I. H. (1997). Induction of model trees for
//!   predicting continuous classes.
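//!
//! # Example
//!
//! A minimal sketch of the builder API defined below (marked `ignore`
//! because it assumes `x: Array2<Float>` and `y: Array1<Float>` supplied by
//! the caller):
//!
//! ```ignore
//! let model = ModelTree::new().max_depth(5).min_samples_leaf(5);
//! let fitted = model.fit(&x, &y)?;
//! let predictions = fitted.predict(&x)?;
//! ```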

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use sklears_core::error::{Result, SklearsError};
use sklears_core::traits::{Estimator, Fit, Predict, Trained, Untrained};
use sklears_core::types::Float;
use std::marker::PhantomData;

/// Node in a Model Tree
#[derive(Debug, Clone)]
pub struct ModelTreeNode {
    /// Feature index for splitting (None for leaf nodes)
    pub feature: Option<usize>,
    /// Split threshold value
    pub threshold: Float,
    /// Left child node (samples <= threshold)
    pub left: Option<Box<ModelTreeNode>>,
    /// Right child node (samples > threshold)
    pub right: Option<Box<ModelTreeNode>>,
    /// Linear model coefficients for leaf node
    pub coefficients: Option<Array1<Float>>,
    /// Intercept for linear model in leaf
    pub intercept: Option<Float>,
    /// Number of samples in this node
    pub n_samples: usize,
    /// Standard deviation of targets in this node
    pub std_dev: Float,
}

impl ModelTreeNode {
    /// Create a new leaf node with a linear model
    pub fn new_leaf(
        coefficients: Array1<Float>,
        intercept: Float,
        n_samples: usize,
        std_dev: Float,
    ) -> Self {
        Self {
            feature: None,
            threshold: 0.0,
            left: None,
            right: None,
            coefficients: Some(coefficients),
            intercept: Some(intercept),
            n_samples,
            std_dev,
        }
    }

    /// Create a new internal split node
    pub fn new_internal(
        feature: usize,
        threshold: Float,
        left: Self,
        right: Self,
        n_samples: usize,
        std_dev: Float,
    ) -> Self {
        Self {
            feature: Some(feature),
            threshold,
            left: Some(Box::new(left)),
            right: Some(Box::new(right)),
            coefficients: None,
            intercept: None,
            n_samples,
            std_dev,
        }
    }

    /// Check if this is a leaf node
    pub fn is_leaf(&self) -> bool {
        self.left.is_none() && self.right.is_none()
    }

    /// Predict value for a single sample
    pub fn predict_sample(&self, sample: &ArrayView1<Float>) -> Float {
        if self.is_leaf() {
            // Apply the leaf's linear model
            if let (Some(coef), Some(intercept)) = (&self.coefficients, &self.intercept) {
                sample.dot(coef) + intercept
            } else {
                0.0 // Fallback for invalid leaf
            }
        } else if let Some(feature_idx) = self.feature {
            // Navigate to child
            let value = sample[feature_idx];
            if value <= self.threshold {
                self.left
                    .as_ref()
                    .map(|node| node.predict_sample(sample))
                    .unwrap_or(0.0)
            } else {
                self.right
                    .as_ref()
                    .map(|node| node.predict_sample(sample))
                    .unwrap_or(0.0)
            }
        } else {
            0.0
        }
    }
}

/// Configuration for Model Tree
#[derive(Debug, Clone)]
pub struct ModelTreeConfig {
    /// Maximum depth of the tree
    pub max_depth: Option<usize>,
    /// Minimum samples required to split an internal node
    pub min_samples_split: usize,
    /// Minimum samples required to be at a leaf node
    pub min_samples_leaf: usize,
    /// Minimum standard deviation reduction required for a split
    pub min_std_dev_reduction: Float,
    /// Whether to prune the tree (reserved; not yet applied when building)
    pub prune: bool,
    /// Whether to smooth leaf predictions (reserved; not yet applied)
    pub smoothing: bool,
    /// Model type in leaves
    pub leaf_model: LeafModelType,
}

impl Default for ModelTreeConfig {
    fn default() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 4,
            min_samples_leaf: 2,
            min_std_dev_reduction: 0.05,
            prune: true,
            smoothing: true,
            leaf_model: LeafModelType::Linear,
        }
    }
}

/// Type of model to use in leaf nodes
#[derive(Debug, Clone, Copy)]
pub enum LeafModelType {
    /// Linear regression in leaves
    Linear,
    /// Constant (mean) prediction in leaves
    Constant,
    /// Polynomial regression (degree 2) in leaves; currently fitted with
    /// linear terms only (see `create_leaf_node`)
    Polynomial,
}

/// Model Tree for regression with linear models in leaves
pub struct ModelTree<State = Untrained> {
    config: ModelTreeConfig,
    state: PhantomData<State>,
    root: Option<ModelTreeNode>,
    n_features: Option<usize>,
    feature_importances: Option<Array1<Float>>,
}

impl ModelTree<Untrained> {
    /// Create a new Model Tree
    pub fn new() -> Self {
        Self {
            config: ModelTreeConfig::default(),
            state: PhantomData,
            root: None,
            n_features: None,
            feature_importances: None,
        }
    }

    /// Set the maximum tree depth
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.config.max_depth = Some(max_depth);
        self
    }

    /// Set the minimum samples required to split
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.config.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum samples required at a leaf
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.config.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the minimum standard deviation reduction
    pub fn min_std_dev_reduction(mut self, min_std_dev_reduction: Float) -> Self {
        self.config.min_std_dev_reduction = min_std_dev_reduction;
        self
    }

    /// Enable or disable pruning
    pub fn prune(mut self, prune: bool) -> Self {
        self.config.prune = prune;
        self
    }

    /// Set the leaf model type
    pub fn leaf_model(mut self, leaf_model: LeafModelType) -> Self {
        self.config.leaf_model = leaf_model;
        self
    }
}

impl Default for ModelTree<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ModelTree<Untrained> {
    type Config = ModelTreeConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<Float>> for ModelTree<Untrained> {
    type Fitted = ModelTree<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("X.shape[0] == y.len() ({})", x.nrows()),
                actual: format!("y.len() = {}", y.len()),
            });
        }

        // Build the tree
        let indices: Vec<usize> = (0..n_samples).collect();
        let root = build_model_tree(x, y, &indices, 0, &self.config)?;

        // Compute feature importances
        let mut feature_importances = Array1::zeros(n_features);
        compute_feature_importances(&root, &mut feature_importances);

        // Normalize importances
        let sum = feature_importances.sum();
        if sum > 0.0 {
            feature_importances /= sum;
        }

        Ok(ModelTree::<Trained> {
            config: self.config,
            state: PhantomData,
            root: Some(root),
            n_features: Some(n_features),
            feature_importances: Some(feature_importances),
        })
    }
}

impl ModelTree<Trained> {
    /// Get the number of features
    pub fn n_features(&self) -> usize {
        self.n_features.expect("Model should be fitted")
    }

    /// Get feature importances
    pub fn feature_importances(&self) -> &Array1<Float> {
        self.feature_importances
            .as_ref()
            .expect("Model should be fitted")
    }

    /// Get the tree structure (root node)
    pub fn tree(&self) -> &ModelTreeNode {
        self.root.as_ref().expect("Model should be fitted")
    }
}

impl Predict<Array2<Float>, Array1<Float>> for ModelTree<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let n_samples = x.nrows();

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let root = self.root.as_ref().ok_or(SklearsError::NotFitted {
            operation: "predict".to_string(),
        })?;

        let mut predictions = Array1::zeros(n_samples);
        for (i, sample) in x.axis_iter(Axis(0)).enumerate() {
            predictions[i] = root.predict_sample(&sample);
        }

        Ok(predictions)
    }
}

/// Build a model tree recursively
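///
/// A node becomes a leaf when any stopping rule fires: too few samples to
/// split, maximum depth reached, target standard deviation already below
/// `min_std_dev_reduction`, or too few samples to populate two leaves.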
fn build_model_tree(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
    depth: usize,
    config: &ModelTreeConfig,
) -> Result<ModelTreeNode> {
    let n_samples = indices.len();

    // Calculate standard deviation of targets
    let mean_y: Float = indices.iter().map(|&i| y[i]).sum::<Float>() / n_samples as Float;
    let variance: Float = indices
        .iter()
        .map(|&i| (y[i] - mean_y).powi(2))
        .sum::<Float>()
        / n_samples as Float;
    let std_dev = variance.sqrt();

    // Base cases for creating a leaf
    let should_create_leaf = n_samples < config.min_samples_split
        || depth >= config.max_depth.unwrap_or(usize::MAX)
        || std_dev < config.min_std_dev_reduction
        || n_samples < 2 * config.min_samples_leaf;

    if should_create_leaf {
        return create_leaf_node(x, y, indices, config);
    }

    // Find the best split
    let best_split = find_best_split(x, y, indices, std_dev, config)?;

    if let Some((feature, threshold, left_indices, right_indices, std_dev_reduction)) = best_split {
        // Check if split provides sufficient improvement
        if std_dev_reduction < config.min_std_dev_reduction {
            return create_leaf_node(x, y, indices, config);
        }

        // Recursively build subtrees
        let left_node = build_model_tree(x, y, &left_indices, depth + 1, config)?;
        let right_node = build_model_tree(x, y, &right_indices, depth + 1, config)?;

        Ok(ModelTreeNode::new_internal(
            feature, threshold, left_node, right_node, n_samples, std_dev,
        ))
    } else {
        // No valid split found, create leaf
        create_leaf_node(x, y, indices, config)
    }
}

/// Find the best split for the current node
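///
/// Candidate thresholds are midpoints between consecutive unique feature
/// values, scored by standard deviation reduction as in M5:
/// `SDR = sd(T) - Σ (|T_i| / |T|) · sd(T_i)` over the two partitions.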
fn find_best_split(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
    current_std_dev: Float,
    config: &ModelTreeConfig,
) -> Result<Option<(usize, Float, Vec<usize>, Vec<usize>, Float)>> {
    let n_features = x.ncols();
    let n_samples = indices.len();

    let mut best_split = None;
    let mut best_reduction = 0.0;

    // Try each feature
    for feature in 0..n_features {
        // Get unique values for this feature
        let mut feature_values: Vec<Float> = indices.iter().map(|&i| x[[i, feature]]).collect();
        // total_cmp gives a total order and avoids panicking on NaN
        feature_values.sort_by(|a, b| a.total_cmp(b));
        feature_values.dedup();

        if feature_values.len() < 2 {
            continue;
        }

        // Try split points between consecutive values
        for i in 0..feature_values.len() - 1 {
            let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;

            // Partition samples
            let mut left_indices = Vec::new();
            let mut right_indices = Vec::new();

            for &idx in indices {
                if x[[idx, feature]] <= threshold {
                    left_indices.push(idx);
                } else {
                    right_indices.push(idx);
                }
            }

            // Check minimum samples
            if left_indices.len() < config.min_samples_leaf
                || right_indices.len() < config.min_samples_leaf
            {
                continue;
            }

            // Calculate standard deviation reduction
            let left_std = calculate_std_dev(y, &left_indices);
            let right_std = calculate_std_dev(y, &right_indices);

            let weighted_std = (left_indices.len() as Float * left_std
                + right_indices.len() as Float * right_std)
                / n_samples as Float;

            let std_dev_reduction = current_std_dev - weighted_std;

            if std_dev_reduction > best_reduction {
                best_reduction = std_dev_reduction;
                best_split = Some((
                    feature,
                    threshold,
                    left_indices,
                    right_indices,
                    std_dev_reduction,
                ));
            }
        }
    }

    Ok(best_split)
}

/// Calculate standard deviation for a subset of targets
fn calculate_std_dev(y: &Array1<Float>, indices: &[usize]) -> Float {
    if indices.is_empty() {
        return 0.0;
    }

    let mean: Float = indices.iter().map(|&i| y[i]).sum::<Float>() / indices.len() as Float;
    let variance: Float = indices
        .iter()
        .map(|&i| (y[i] - mean).powi(2))
        .sum::<Float>()
        / indices.len() as Float;

    variance.sqrt()
}

/// Create a leaf node with a linear model
fn create_leaf_node(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
    config: &ModelTreeConfig,
) -> Result<ModelTreeNode> {
    let n_samples = indices.len();
    let n_features = x.ncols();

    // Calculate standard deviation
    let std_dev = calculate_std_dev(y, indices);

    match config.leaf_model {
        LeafModelType::Constant => {
            // Just use mean prediction
            let mean: Float = indices.iter().map(|&i| y[i]).sum::<Float>() / n_samples as Float;
            let coefficients = Array1::zeros(n_features);
            Ok(ModelTreeNode::new_leaf(
                coefficients,
                mean,
                n_samples,
                std_dev,
            ))
        }
        LeafModelType::Linear | LeafModelType::Polynomial => {
            // Fit a linear regression model for this leaf; the Polynomial
            // variant currently falls back to the same linear fit.
            let (coefficients, intercept) = fit_linear_model(x, y, indices)?;
            Ok(ModelTreeNode::new_leaf(
                coefficients,
                intercept,
                n_samples,
                std_dev,
            ))
        }
    }
}

/// Fit a linear model using least squares
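///
/// Solves the ridge-stabilized normal equations `(XᵀX + λI) β = Xᵀy` with
/// `λ = 1e-6`; the design matrix carries a trailing column of ones so that
/// the last entry of `β` is the intercept.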
fn fit_linear_model(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
) -> Result<(Array1<Float>, Float)> {
    let n_samples = indices.len();
    let n_features = x.ncols();

    if n_samples == 0 {
        return Ok((Array1::zeros(n_features), 0.0));
    }

    // Extract subset of data
    let mut x_subset = Array2::zeros((n_samples, n_features));
    let mut y_subset = Array1::zeros(n_samples);

    for (i, &idx) in indices.iter().enumerate() {
        x_subset.row_mut(i).assign(&x.row(idx));
        y_subset[i] = y[idx];
    }

    // Add column of ones for intercept
    let mut x_design = Array2::ones((n_samples, n_features + 1));
    for i in 0..n_samples {
        for j in 0..n_features {
            x_design[[i, j]] = x_subset[[i, j]];
        }
    }

    // Solve normal equations: (X^T X) β = X^T y
    let xt = x_design.t();
    let xtx = xt.dot(&x_design);
    let xty = xt.dot(&y_subset);

    // Add small ridge regularization for numerical stability
    let mut xtx_reg = xtx;
    for i in 0..n_features + 1 {
        xtx_reg[[i, i]] += 1e-6;
    }

    // Solve using simple Gaussian elimination (adequate for small systems)
    let beta = solve_linear_system(&xtx_reg, &xty)?;

    // Extract coefficients and intercept
    let coefficients = beta
        .slice(scirs2_core::ndarray::s![0..n_features])
        .to_owned();
    let intercept = beta[n_features];

    Ok((coefficients, intercept))
}

/// Solve linear system Ax = b using Gaussian elimination
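///
/// Uses partial pivoting; if a pivot is near zero (|pivot| < 1e-10) the
/// column is skipped and the corresponding unknown falls back to zero during
/// back substitution.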
fn solve_linear_system(a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
    let n = a.nrows();

    if n != b.len() {
        return Err(SklearsError::ShapeMismatch {
            expected: format!("A.nrows() == b.len() ({})", n),
            actual: format!("b.len() = {}", b.len()),
        });
    }

    // Create augmented matrix [A | b]
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination with partial pivoting
    for k in 0..n {
        // Find pivot
        let mut pivot_row = k;
        let mut max_val = aug[[k, k]].abs();
        for i in (k + 1)..n {
            if aug[[i, k]].abs() > max_val {
                max_val = aug[[i, k]].abs();
                pivot_row = i;
            }
        }

        // Swap rows if needed
        if pivot_row != k {
            for j in 0..=n {
                let temp = aug[[k, j]];
                aug[[k, j]] = aug[[pivot_row, j]];
                aug[[pivot_row, j]] = temp;
            }
        }

        // Eliminate column
        let pivot = aug[[k, k]];
        if pivot.abs() < 1e-10 {
            // Near-singular pivot: skip this column; back substitution
            // assigns zero to the corresponding unknown.
            continue;
        }

        for i in (k + 1)..n {
            let factor = aug[[i, k]] / pivot;
            for j in k..=n {
                aug[[i, j]] -= factor * aug[[k, j]];
            }
        }
    }

    // Back substitution
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum -= aug[[i, j]] * x[j];
        }

        let diag = aug[[i, i]];
        x[i] = if diag.abs() > 1e-10 {
            sum / diag
        } else {
            0.0 // Fallback for singular case
        };
    }

    Ok(x)
}

/// Compute feature importances from the tree
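///
/// Accumulates a raw score per split feature; the caller (`fit`) normalizes
/// the scores to sum to 1, so only relative magnitudes matter here.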
fn compute_feature_importances(node: &ModelTreeNode, importances: &mut Array1<Float>) {
    if let Some(feature) = node.feature {
        // Importance proxy: the node's target standard deviation weighted by
        // its sample count (not the exact reduction achieved by the split)
        let importance = node.std_dev * node.n_samples as Float;
        importances[feature] += importance;

        // Recursively compute for children
        if let Some(ref left) = node.left {
            compute_feature_importances(left, importances);
        }
        if let Some(ref right) = node.right {
            compute_feature_importances(right, importances);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_model_tree_basic() {
        // Create simple dataset: y = 2*x1 + 3*x2 + 1
        let mut x = Array2::zeros((100, 2));
        let mut y = Array1::zeros(100);

        for i in 0..100 {
            let x1 = (i as Float) / 50.0 - 1.0;
            let x2 = ((i * 2) as Float) / 50.0 - 2.0;
            x[[i, 0]] = x1;
            x[[i, 1]] = x2;
            y[i] = 2.0 * x1 + 3.0 * x2 + 1.0;
        }

        let model = ModelTree::new().max_depth(5).min_samples_leaf(5);

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        // Calculate R² score
        let y_mean = y.mean().unwrap();
        let ss_tot: Float = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
        let ss_res: Float = y
            .iter()
            .zip(predictions.iter())
            .map(|(&yi, &pred)| (yi - pred).powi(2))
            .sum();
        let r2 = 1.0 - ss_res / ss_tot;

        assert!(
            r2 > 0.8,
            "R² should be high for linear relationship: {}",
            r2
        );
    }

    #[test]
    fn test_model_tree_nonlinear() {
        // Create nonlinear dataset: y = x^2
        let mut x = Array2::zeros((50, 1));
        let mut y = Array1::zeros(50);

        for i in 0..50 {
            let xi = (i as Float) / 10.0 - 2.5;
            x[[i, 0]] = xi;
            y[i] = xi * xi;
        }

        let model = ModelTree::new().max_depth(4).min_samples_leaf(3);

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        // Model tree should approximate the quadratic function reasonably well
        let mse: Float = y
            .iter()
            .zip(predictions.iter())
            .map(|(&yi, &pred)| (yi - pred).powi(2))
            .sum::<Float>()
            / y.len() as Float;

        assert!(
            mse < 1.0,
            "MSE should be reasonable for piecewise linear approximation: {}",
            mse
        );
    }

    #[test]
    fn test_linear_model_fitting() {
        // Test linear model fitting on y = x1 + x2 (the two columns are
        // collinear, so y = 2*x1 + 1 fits equally well; we only check that
        // predictions match the targets)
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![3.0, 7.0, 11.0]);
        let indices = vec![0, 1, 2];

        let (coef, intercept) = fit_linear_model(&x, &y, &indices).unwrap();

        // Predictions should match targets closely
        let mut error = 0.0;
        for i in 0..3 {
            let pred = x.row(i).dot(&coef) + intercept;
            error += (y[i] - pred).abs();
        }

        assert!(error < 0.1, "Linear model should fit well: error={}", error);
    }

    #[test]
    fn test_solve_linear_system() {
        // Test solving the 2x2 system:
        // 2x + y = 5
        // x + 3y = 8
        // Solution: x = 1.4, y = 2.2
        let a = Array2::from_shape_vec((2, 2), vec![2.0, 1.0, 1.0, 3.0]).unwrap();
        let b = Array1::from_vec(vec![5.0, 8.0]);

        let x = solve_linear_system(&a, &b).unwrap();

        assert_relative_eq!(x[0], 1.4, epsilon = 1e-6);
        assert_relative_eq!(x[1], 2.2, epsilon = 1e-6);
    }

    #[test]
    fn test_constant_leaf_model() {
        let mut x = Array2::zeros((20, 1));
        let mut y = Array1::zeros(20);

        for i in 0..20 {
            x[[i, 0]] = (i as Float) / 10.0;
            y[i] = if i < 10 { 1.0 } else { 5.0 };
        }

        let model = ModelTree::new()
            .leaf_model(LeafModelType::Constant)
            .max_depth(2);

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        // With constant leaves, predictions in each region should be close to the mean
        assert!(predictions[0] > 0.5 && predictions[0] < 2.0);
        assert!(predictions[19] > 4.0 && predictions[19] < 6.0);
    }
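
    // A sanity-check sketch for feature importances: they should have one
    // entry per feature and be non-negative, and a constant column should
    // never be chosen for a split.
    #[test]
    fn test_feature_importances_sanity() {
        let mut x = Array2::zeros((40, 2));
        let mut y = Array1::zeros(40);
        for i in 0..40 {
            x[[i, 0]] = i as Float; // informative feature
            x[[i, 1]] = 0.0; // constant, uninformative feature
            y[i] = if i < 20 { 0.0 } else { 10.0 };
        }

        let fitted = ModelTree::new().max_depth(3).fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances.iter().all(|&v| v >= 0.0));
        assert_relative_eq!(importances[1], 0.0);
    }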
}