Module smartcore::tree::decision_tree_classifier

Classification tree for dependent variables that take a finite number of unordered values.

Decision Tree Classifier

The process of building a classification tree is similar to the task of building a regression tree. However, in the classification setting, one of the following criteria is used to make the binary splits:

  • Classification error rate, \(E = 1 - \max_k(p_{mk})\)

  • Gini index, \(G = \sum_{k=1}^K p_{mk}(1 - p_{mk})\)

  • Entropy, \(D = -\sum_{k=1}^K p_{mk}\log p_{mk}\)

where \(p_{mk}\) represents the proportion of training observations in the mth region that are from the kth class.

The classification error rate is simply the fraction of the training observations in that region that do not belong to the most common class. Classification error is not sufficiently sensitive for tree-growing, and in practice the Gini index or entropy is preferable.

The Gini index is referred to as a measure of node purity. A small value indicates that a node contains predominantly observations from a single class.

Entropy, like the Gini index, takes on a small value when the mth node is pure.
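The three criteria above can be computed directly from the class proportions of a node. The following is an illustrative sketch using only the Rust standard library; the helper names (`classification_error`, `gini`, `entropy`) are not part of smartcore's API:

```rust
// Split criteria for a node, given the class proportions p_mk.

// E = 1 - max_k(p_mk): fraction not in the most common class.
fn classification_error(p: &[f64]) -> f64 {
    1.0 - p.iter().cloned().fold(0.0, f64::max)
}

// G = sum_k p_mk * (1 - p_mk): small when the node is nearly pure.
fn gini(p: &[f64]) -> f64 {
    p.iter().map(|pk| pk * (1.0 - pk)).sum()
}

// D = -sum_k p_mk * ln(p_mk), with 0 * ln(0) taken as 0.
fn entropy(p: &[f64]) -> f64 {
    p.iter()
        .filter(|&&pk| pk > 0.0)
        .map(|pk| -pk * pk.ln())
        .sum()
}

fn main() {
    // A fairly pure node: 90% of observations belong to class 0.
    let pure = [0.9, 0.1];
    // A maximally impure two-class node: classes split evenly.
    let impure = [0.5, 0.5];

    println!(
        "pure:   E={:.3} G={:.3} D={:.3}",
        classification_error(&pure), gini(&pure), entropy(&pure)
    );
    println!(
        "impure: E={:.3} G={:.3} D={:.3}",
        classification_error(&impure), gini(&impure), entropy(&impure)
    );
}
```

Note how the Gini index and entropy shrink as one class comes to dominate the node, which is what makes them useful impurity measures for choosing splits.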

Example:

use smartcore::linalg::naive::dense_matrix::*;
use smartcore::tree::decision_tree_classifier::*;

// Iris dataset
let x = DenseMatrix::from_2d_array(&[
           &[5.1, 3.5, 1.4, 0.2],
           &[4.9, 3.0, 1.4, 0.2],
           &[4.7, 3.2, 1.3, 0.2],
           &[4.6, 3.1, 1.5, 0.2],
           &[5.0, 3.6, 1.4, 0.2],
           &[5.4, 3.9, 1.7, 0.4],
           &[4.6, 3.4, 1.4, 0.3],
           &[5.0, 3.4, 1.5, 0.2],
           &[4.4, 2.9, 1.4, 0.2],
           &[4.9, 3.1, 1.5, 0.1],
           &[7.0, 3.2, 4.7, 1.4],
           &[6.4, 3.2, 4.5, 1.5],
           &[6.9, 3.1, 4.9, 1.5],
           &[5.5, 2.3, 4.0, 1.3],
           &[6.5, 2.8, 4.6, 1.5],
           &[5.7, 2.8, 4.5, 1.3],
           &[6.3, 3.3, 4.7, 1.6],
           &[4.9, 2.4, 3.3, 1.0],
           &[6.6, 2.9, 4.6, 1.3],
           &[5.2, 2.7, 3.9, 1.4],
        ]);
let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];

let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();

let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction


Structs

DecisionTreeClassifier

Decision Tree

DecisionTreeClassifierParameters

Parameters of Decision Tree

Enums

SplitCriterion

The function to measure the quality of a split.
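As a configuration sketch (assuming the field and variant names from smartcore's public API, `DecisionTreeClassifierParameters::criterion` and `SplitCriterion::Entropy`), the split criterion can be selected when configuring the classifier:

```rust
use smartcore::tree::decision_tree_classifier::{
    DecisionTreeClassifierParameters, SplitCriterion,
};

// Build parameters that split on entropy rather than the default criterion,
// leaving the remaining parameters at their defaults.
let params = DecisionTreeClassifierParameters {
    criterion: SplitCriterion::Entropy,
    ..Default::default()
};
```

The resulting `params` value would then be passed as the third argument to `DecisionTreeClassifier::fit` in place of `Default::default()` in the example above.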