use std::collections::LinkedList;
use std::default::Default;
use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
/// Hyper-parameters that control how a [`DecisionTreeClassifier`] is grown.
///
/// When the `serde` feature is enabled, any field missing from the serialized input is
/// taken from this struct's own [`Default`] implementation (container-level
/// `serde(default)`). Field-level `serde(default)` would instead fall back to the field
/// type's default, e.g. `0` for `min_samples_leaf` / `min_samples_split`, which does not
/// match the documented defaults of `1` and `2`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifierParameters {
    /// Impurity measure used to score candidate splits.
    pub criterion: SplitCriterion,
    /// Upper bound used by the tree-growing loop; `None` means no explicit limit.
    pub max_depth: Option<u16>,
    /// Minimum total sample weight each side of a split must retain.
    pub min_samples_leaf: usize,
    /// Nodes whose total sample weight is less than or equal to this value are not split.
    pub min_samples_split: usize,
    /// Seed handed to the crate's random number generator factory.
    pub seed: Option<u64>,
}
/// Decision tree classifier over a feature matrix `X` (elements `TX`) and a label
/// vector `Y` (elements `TY`).
///
/// The tree is stored as a flat vector of nodes; children are referenced by their
/// index in that vector and the root is the node at index 0.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DecisionTreeClassifier<
TX: Number + PartialOrd,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
/// All nodes of the tree; index 0 is the root once the tree has been fitted.
nodes: Vec<Node>,
/// Parameters the tree was fitted with; `None` for an instance created by `new()`.
parameters: Option<DecisionTreeClassifierParameters>,
/// Number of distinct class labels seen during fitting.
num_classes: usize,
/// Distinct class labels; a node's `output` is an index into this vector.
classes: Vec<TY>,
/// Largest child level recorded while splitting nodes.
depth: u16,
/// Zero-sized markers tying the otherwise unused type parameters to the struct.
_phantom_tx: PhantomData<TX>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX, TY, X, Y> DecisionTreeClassifier<TX, TY, X, Y>
where
    TX: Number + PartialOrd,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
{
    /// Borrows the flat node storage of the tree.
    fn nodes(&self) -> &Vec<Node> {
        &self.nodes
    }

    /// Borrows the parameters the tree was fitted with.
    ///
    /// # Panics
    /// Panics when no parameters are stored, which is the state produced by `new()`.
    fn parameters(&self) -> &DecisionTreeClassifierParameters {
        self.parameters.as_ref().unwrap()
    }

    /// Borrows the distinct class labels.
    fn classes(&self) -> &Vec<TY> {
        &self.classes
    }

    /// Returns the depth value recorded while the tree was grown.
    pub fn depth(&self) -> u16 {
        self.depth
    }
}
/// Impurity measure used to evaluate candidate splits.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum SplitCriterion {
/// Gini impurity: one minus the sum of squared class proportions. This is the default.
#[default]
Gini,
/// Entropy: negated sum of `p * log2(p)` over the class proportions.
Entropy,
/// Classification error: one minus the largest class proportion.
ClassificationError,
}
/// A single node of the tree, stored inside the classifier's flat node vector.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
/// Index of the predicted class (into the classifier's `classes`) for this node.
output: usize,
/// Column index of the feature this node splits on.
split_feature: usize,
/// Split threshold; rows whose feature value is `<=` it go to `true_child`.
split_value: Option<f64>,
/// Impurity gain of the chosen split; `None` when no split has been selected.
split_score: Option<f64>,
/// Index of the child node taken when the split condition holds.
true_child: Option<usize>,
/// Index of the child node taken when the split condition does not hold.
false_child: Option<usize>,
}
impl<TX, TY, X, Y> PartialEq for DecisionTreeClassifier<TX, TY, X, Y>
where
    TX: Number + PartialOrd,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
{
    /// Two trees are equal when their depth, class count and node count match and
    /// their class labels and nodes are pairwise equal.
    fn eq(&self, other: &Self) -> bool {
        let same_shape = self.depth == other.depth
            && self.num_classes == other.num_classes
            && self.nodes().len() == other.nodes().len();
        if !same_shape {
            return false;
        }

        let same_classes = self
            .classes()
            .iter()
            .zip(other.classes())
            .all(|(lhs, rhs)| lhs == rhs);
        let same_nodes = self
            .nodes()
            .iter()
            .zip(other.nodes())
            .all(|(lhs, rhs)| lhs == rhs);

        same_classes && same_nodes
    }
}
impl PartialEq for Node {
    /// Compares two nodes field by field.
    ///
    /// The integer fields, including the child links, must match exactly; the
    /// floating-point `split_value` and `split_score` are compared with an absolute
    /// tolerance of `f64::EPSILON`. Child links are part of the comparison so that two
    /// nodes wired to different children are never considered equal.
    fn eq(&self, other: &Self) -> bool {
        /// Equality for optional floats: both absent, or both present and within epsilon.
        fn approx_eq(lhs: Option<f64>, rhs: Option<f64>) -> bool {
            match (lhs, rhs) {
                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
                (None, None) => true,
                _ => false,
            }
        }

        self.output == other.output
            && self.split_feature == other.split_feature
            && self.true_child == other.true_child
            && self.false_child == other.false_child
            && approx_eq(self.split_value, other.split_value)
            && approx_eq(self.split_score, other.split_score)
    }
}
impl DecisionTreeClassifierParameters {
    /// Sets the impurity measure used to score candidate splits.
    pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Sets an explicit depth limit for the tree.
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }

    /// Sets the minimum total sample weight each side of a split must retain.
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the sample-weight threshold at or below which a node is not split.
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the seed handed to the random number generator.
    ///
    /// Added so that every field of the parameter set can be configured through the
    /// builder chain, like the other fields.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}
impl Default for DecisionTreeClassifierParameters {
fn default() -> Self {
DecisionTreeClassifierParameters {
criterion: SplitCriterion::default(),
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
seed: Option::None,
}
}
}
/// Grid of candidate values for each [`DecisionTreeClassifierParameters`] field.
///
/// Iterating over this struct yields one parameter set per combination of values.
///
/// When the `serde` feature is enabled, any field missing from the serialized input is
/// taken from this struct's own [`Default`] implementation (container-level
/// `serde(default)`), i.e. a single default candidate. Field-level `serde(default)`
/// would instead produce an empty vector, which makes the whole grid empty.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifierSearchParameters {
    /// Candidate impurity measures.
    pub criterion: Vec<SplitCriterion>,
    /// Candidate depth limits.
    pub max_depth: Vec<Option<u16>>,
    /// Candidate values for `min_samples_leaf`.
    pub min_samples_leaf: Vec<usize>,
    /// Candidate values for `min_samples_split`.
    pub min_samples_split: Vec<usize>,
    /// Candidate seeds.
    pub seed: Vec<Option<u64>>,
}
/// Iterator over every combination of values in a
/// `DecisionTreeClassifierSearchParameters` grid.
///
/// Each `current_*` field is the index of the value that the next yielded parameter
/// set will take from the corresponding vector of the grid.
pub struct DecisionTreeClassifierSearchParametersIterator {
/// The grid being iterated.
decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
/// Cursor into the `criterion` candidates; this is the first cursor to advance.
current_criterion: usize,
/// Cursor into the `max_depth` candidates.
current_max_depth: usize,
/// Cursor into the `min_samples_leaf` candidates.
current_min_samples_leaf: usize,
/// Cursor into the `min_samples_split` candidates.
current_min_samples_split: usize,
/// Cursor into the `seed` candidates; this is the last cursor to advance.
current_seed: usize,
}
/// Turns a search grid into an iterator over concrete parameter sets.
impl IntoIterator for DecisionTreeClassifierSearchParameters {
type Item = DecisionTreeClassifierParameters;
type IntoIter = DecisionTreeClassifierSearchParametersIterator;
/// Consumes the grid and returns an iterator with every cursor at the first candidate.
fn into_iter(self) -> Self::IntoIter {
DecisionTreeClassifierSearchParametersIterator {
decision_tree_classifier_search_parameters: self,
current_criterion: 0,
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_seed: 0,
}
}
}
impl Iterator for DecisionTreeClassifierSearchParametersIterator {
    type Item = DecisionTreeClassifierParameters;

    /// Yields the next combination of the grid, or `None` once all were produced.
    ///
    /// The cursors behave like an odometer: `criterion` advances fastest, then
    /// `max_depth`, `min_samples_leaf`, `min_samples_split` and finally `seed`.
    /// If any dimension of the grid is empty there are no combinations at all and the
    /// iterator yields nothing (instead of panicking on an out-of-bounds index).
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.decision_tree_classifier_search_parameters;
        let criterion_len = grid.criterion.len();
        let max_depth_len = grid.max_depth.len();
        let min_samples_leaf_len = grid.min_samples_leaf.len();
        let min_samples_split_len = grid.min_samples_split.len();
        let seed_len = grid.seed.len();

        // The Cartesian product with an empty set is empty.
        if criterion_len == 0
            || max_depth_len == 0
            || min_samples_leaf_len == 0
            || min_samples_split_len == 0
            || seed_len == 0
        {
            return None;
        }

        // Every cursor sitting one past the end is the "exhausted" marker that the
        // final branch below establishes.
        if self.current_criterion == criterion_len
            && self.current_max_depth == max_depth_len
            && self.current_min_samples_leaf == min_samples_leaf_len
            && self.current_min_samples_split == min_samples_split_len
            && self.current_seed == seed_len
        {
            return None;
        }

        let next = DecisionTreeClassifierParameters {
            criterion: grid.criterion[self.current_criterion].clone(),
            max_depth: grid.max_depth[self.current_max_depth],
            min_samples_leaf: grid.min_samples_leaf[self.current_min_samples_leaf],
            min_samples_split: grid.min_samples_split[self.current_min_samples_split],
            seed: grid.seed[self.current_seed],
        };

        // Advance the first cursor that can still move and reset all faster ones.
        if self.current_criterion + 1 < criterion_len {
            self.current_criterion += 1;
        } else if self.current_max_depth + 1 < max_depth_len {
            self.current_criterion = 0;
            self.current_max_depth += 1;
        } else if self.current_min_samples_leaf + 1 < min_samples_leaf_len {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
        } else if self.current_min_samples_split + 1 < min_samples_split_len {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split += 1;
        } else if self.current_seed + 1 < seed_len {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_seed += 1;
        } else {
            // Last combination was just produced: push every cursor past the end.
            self.current_criterion += 1;
            self.current_max_depth += 1;
            self.current_min_samples_leaf += 1;
            self.current_min_samples_split += 1;
            self.current_seed += 1;
        }

        Some(next)
    }
}
impl Default for DecisionTreeClassifierSearchParameters {
fn default() -> Self {
let default_params = DecisionTreeClassifierParameters::default();
DecisionTreeClassifierSearchParameters {
criterion: vec![default_params.criterion],
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
seed: vec![default_params.seed],
}
}
}
impl Node {
fn new(output: usize) -> Self {
Node {
output,
split_feature: 0,
split_value: Option::None,
split_score: Option::None,
true_child: Option::None,
false_child: Option::None,
}
}
}
/// Working state for one tree node while the tree is being grown.
struct NodeVisitor<'a, TX: Number + PartialOrd, X: Array2<TX>> {
/// Feature matrix the tree is fitted on.
x: &'a X,
/// Class index of every row.
y: &'a [usize],
/// Index of the visited node in the tree's node vector.
node: usize,
/// Per-row sample weight; rows with weight 0 are ignored for this node.
samples: Vec<usize>,
/// For every feature, the row indices produced by arg-sorting that feature column.
order: &'a [Vec<usize>],
/// Class index the true child will predict if this node is split.
true_child_output: usize,
/// Class index the false child will predict if this node is split.
false_child_output: usize,
/// Level assigned to this node; children are created with `level + 1`.
level: u16,
/// Zero-sized marker tying the otherwise unused `TX` parameter to the struct.
phantom: PhantomData<&'a TX>,
}
/// Computes the impurity of a node under the given criterion.
///
/// `count` holds the per-class sample weight and `n` the total weight; classes with a
/// zero count are skipped.
///
/// * Gini: starts at 1 and subtracts the squared proportion of every class.
/// * Entropy: starts at 0 and subtracts `p * log2(p)` for every class.
/// * ClassificationError: absolute value of one minus the largest class proportion.
fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 {
    let total = n as f64;
    // Proportions of the classes that are actually present, in class order.
    let proportions = count
        .iter()
        .filter(|&&class_count| class_count > 0)
        .map(|&class_count| class_count as f64 / total);

    match criterion {
        SplitCriterion::Gini => proportions.fold(1f64, |acc, p| acc - p * p),
        SplitCriterion::Entropy => proportions.fold(0f64, |acc, p| acc - p * p.log2()),
        SplitCriterion::ClassificationError => {
            let largest = proportions.fold(0f64, f64::max);
            (1f64 - largest).abs()
        }
    }
}
impl<'a, TX, X> NodeVisitor<'a, TX, X>
where
    TX: Number + PartialOrd,
    X: Array2<TX>,
{
    /// Creates the working state for node `node_id` at the given `level`.
    ///
    /// `samples` carries the per-row weights belonging to the node and `order` the
    /// per-feature row orderings. Both child outputs start at class index 0 until a
    /// split is chosen.
    fn new(
        node_id: usize,
        samples: Vec<usize>,
        order: &'a [Vec<usize>],
        x: &'a X,
        y: &'a [usize],
        level: u16,
    ) -> Self {
        Self {
            node: node_id,
            level,
            samples,
            order,
            x,
            y,
            true_child_output: 0,
            false_child_output: 0,
            phantom: PhantomData,
        }
    }
}
/// Returns the index of the largest element of `x`; ties resolve to the first
/// occurrence.
///
/// # Panics
/// Panics if `x` is empty.
pub(crate) fn which_max(x: &[usize]) -> usize {
    let (best_index, _) = x.iter().enumerate().skip(1).fold(
        (0, x[0]),
        |(best_index, best), (index, &value)| {
            // Strict comparison keeps the earliest maximum.
            if value > best {
                (index, value)
            } else {
                (best_index, best)
            }
        },
    );
    best_index
}
impl<TX, TY, X, Y> SupervisedEstimator<X, Y, DecisionTreeClassifierParameters>
    for DecisionTreeClassifier<TX, TY, X, Y>
where
    TX: Number + PartialOrd,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
{
    /// Creates an empty, unfitted tree: no nodes, no parameters, no classes.
    fn new() -> Self {
        Self {
            nodes: Vec::new(),
            parameters: None,
            num_classes: 0,
            classes: Vec::new(),
            depth: 0,
            _phantom_tx: PhantomData,
            _phantom_x: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    /// Fits a tree by delegating to the inherent `DecisionTreeClassifier::fit`.
    fn fit(x: &X, y: &Y, parameters: DecisionTreeClassifierParameters) -> Result<Self, Failed> {
        DecisionTreeClassifier::fit(x, y, parameters)
    }
}
impl<TX, TY, X, Y> Predictor<X, Y> for DecisionTreeClassifier<TX, TY, X, Y>
where
    TX: Number + PartialOrd,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
{
    /// Predicts a label for every row of `x`.
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        // Method resolution prefers the inherent `predict`, so this does not recurse.
        self.predict(x)
    }
}
impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
DecisionTreeClassifier<TX, TY, X, Y>
{
pub fn fit(
x: &X,
y: &Y,
parameters: DecisionTreeClassifierParameters,
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}
let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
/// Fits a tree on weighted samples, considering at most `mtry` features per node.
///
/// * `samples` - per-row weight; a weight of 0 excludes the row from the tree.
/// * `mtry` - number of features tried at each node; when it is smaller than the
///   number of features, the candidates are drawn by shuffling with the seeded RNG.
///
/// # Errors
/// Returns a fit error when `x`, `y` and `samples` do not all describe the same number
/// of rows, or when `y` contains fewer than two distinct classes.
pub(crate) fn fit_weak_learner(
    x: &X,
    y: &Y,
    samples: Vec<usize>,
    mtry: usize,
    parameters: DecisionTreeClassifierParameters,
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
    let y_ncols = y.shape();
    let (x_nrows, num_attributes) = x.shape();

    // The loops below index `samples`, `y` and the rows of `x` with the same indices,
    // so mismatched lengths would otherwise surface as an out-of-bounds panic.
    if x_nrows != y_ncols {
        return Err(Failed::fit("Size of x should equal size of y"));
    }
    if samples.len() != x_nrows {
        return Err(Failed::fit(&format!(
            "Number of sample weights ({}) should equal number of rows in x ({x_nrows})",
            samples.len()
        )));
    }

    let classes = y.unique();
    let k = classes.len();
    if k < 2 {
        return Err(Failed::fit(&format!(
            "Incorrect number of classes: {k}. Should be >= 2."
        )));
    }

    let mut rng = get_rng_impl(parameters.seed);

    // Translate every label into its index within `classes`.
    let mut yi: Vec<usize> = vec![0; y_ncols];
    for (i, yi_i) in yi.iter_mut().enumerate() {
        let yc = y.get(i);
        *yi_i = classes
            .iter()
            .position(|c| yc == c)
            .expect("every label of y is contained in y.unique()");
    }

    // Weighted class histogram of the whole training set; the root predicts its mode.
    let mut count = vec![0; k];
    for (&class_index, &weight) in yi.iter().zip(samples.iter()) {
        count[class_index] += weight;
    }
    let root = Node::new(which_max(&count));

    // Row ordering of every feature column, computed once and shared by all nodes.
    let mut order: Vec<Vec<usize>> = Vec::with_capacity(num_attributes);
    for i in 0..num_attributes {
        let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
        order.push(col_i.argsort_mut());
    }

    let mut tree = DecisionTreeClassifier {
        nodes: vec![root],
        parameters: Some(parameters),
        num_classes: k,
        classes,
        depth: 0u16,
        _phantom_tx: PhantomData,
        _phantom_x: PhantomData,
        _phantom_y: PhantomData,
    };

    let mut visitor = NodeVisitor::<TX, X>::new(0, samples, &order, x, &yi, 1);
    let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, X>> = LinkedList::new();

    if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
        visitor_queue.push_back(visitor);
    }

    // Breadth-first growth: split queued nodes until the queue runs dry or the
    // recorded depth reaches the configured limit.
    let max_depth = tree.parameters().max_depth.unwrap_or(u16::MAX);
    while tree.depth() < max_depth {
        match visitor_queue.pop_front() {
            Some(node) => {
                tree.split(node, mtry, &mut visitor_queue, &mut rng);
            }
            None => break,
        }
    }

    Ok(tree)
}
/// Predicts a class label for every row of `x`.
///
/// Each row is routed through the tree and the class index of the reached leaf is
/// mapped back to the original label.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
    let (n_rows, _) = x.shape();
    let mut labels = Y::zeros(n_rows);

    for row in 0..n_rows {
        let class_index = self.predict_for_row(x, row);
        labels.set(row, self.classes()[class_index]);
    }

    Ok(labels)
}
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> usize {
let mut result = 0;
let mut queue: LinkedList<usize> = LinkedList::new();
queue.push_back(0);
while !queue.is_empty() {
match queue.pop_front() {
Some(node_id) => {
let node = &self.nodes()[node_id];
if node.true_child.is_none() && node.false_child.is_none() {
result = node.output;
} else if x.get((row, node.split_feature)).to_f64().unwrap()
<= node.split_value.unwrap_or(std::f64::NAN)
{
queue.push_back(node.true_child.unwrap());
} else {
queue.push_back(node.false_child.unwrap());
}
}
None => break,
};
}
result
}
/// Searches for the best split of the node described by `visitor`.
///
/// Returns `false` without searching when the node is pure (all rows with non-zero
/// weight share one class) or when its total weight is at most `min_samples_split`.
/// Otherwise up to `mtry` features are evaluated and the result tells whether a split
/// score was recorded on the node.
fn find_best_cutoff(
    &mut self,
    visitor: &mut NodeVisitor<'_, TX, X>,
    mtry: usize,
    rng: &mut impl Rng,
) -> bool {
    let (n_rows, n_attr) = visitor.x.shape();

    // Purity check: compare every active row's class with the first active row's.
    let mut active_labels = (0..n_rows)
        .filter(|&i| visitor.samples[i] > 0)
        .map(|i| visitor.y[i]);
    let is_pure = match active_labels.next() {
        None => true,
        Some(first) => active_labels.all(|label| label == first),
    };
    if is_pure {
        return false;
    }

    let n: usize = visitor.samples.iter().sum();
    // NOTE(review): a node whose weight equals `min_samples_split` is not split
    // (`<=`, not `<`) — confirm this is the intended meaning of the parameter.
    if n <= self.parameters().min_samples_split {
        return false;
    }

    // Weighted class histogram of the rows that belong to this node.
    let mut count = vec![0; self.num_classes];
    for i in (0..n_rows).filter(|&i| visitor.samples[i] > 0) {
        count[visitor.y[i]] += visitor.samples[i];
    }
    // Scratch buffer reused by every `find_best_split` call.
    let mut false_count = vec![0; self.num_classes];

    let parent_impurity = impurity(&self.parameters().criterion, &count, n);

    // Candidate features; shuffled only when a strict subset is requested.
    let mut variables: Vec<usize> = (0..n_attr).collect();
    if mtry < n_attr {
        variables.shuffle(rng);
    }

    for &variable in variables.iter().take(mtry) {
        self.find_best_split(
            visitor,
            n,
            &count,
            &mut false_count,
            parent_impurity,
            variable,
        );
    }

    self.nodes()[visitor.node].split_score.is_some()
}
/// Evaluates every candidate threshold of feature `j` for the visited node and
/// records the best one on the node when it beats the score stored so far.
///
/// * `n` - total sample weight of the node.
/// * `count` - per-class sample weight of the node.
/// * `false_count` - scratch buffer for the per-class weight of the false side.
/// * `parent_impurity` - impurity of the node before splitting.
///
/// Rows are visited in the order of feature `j`. A threshold is only considered
/// between two consecutive active rows that differ in both feature value and class,
/// and only if both sides keep at least `min_samples_leaf` weight. The threshold is
/// the midpoint of the two feature values, computed in `f64` so that narrow integer
/// feature types cannot overflow while being added.
fn find_best_split(
    &mut self,
    visitor: &mut NodeVisitor<'_, TX, X>,
    n: usize,
    count: &[usize],
    false_count: &mut [usize],
    parent_impurity: f64,
    j: usize,
) {
    let min_samples_leaf = self.parameters().min_samples_leaf;
    // Per-class weight of the rows already passed, i.e. the true side of a split
    // placed just before the current row.
    let mut true_count = vec![0; self.num_classes];
    // Feature value and class of the previously visited active row.
    let mut prev: Option<(TX, usize)> = None;

    for &i in visitor.order[j].iter() {
        let weight = visitor.samples[i];
        if weight == 0 {
            continue;
        }
        let x_ij = *visitor.x.get((i, j));
        let y_i = visitor.y[i];

        if let Some((prev_x, prev_y)) = prev {
            if x_ij != prev_x && y_i != prev_y {
                let tc: usize = true_count.iter().sum();
                let fc = n - tc;

                if tc >= min_samples_leaf && fc >= min_samples_leaf {
                    for ((false_c, &total_c), &true_c) in
                        false_count.iter_mut().zip(count).zip(&true_count)
                    {
                        *false_c = total_c - true_c;
                    }

                    let true_label = which_max(&true_count);
                    let false_label = which_max(false_count);
                    let gain = parent_impurity
                        - tc as f64 / n as f64
                            * impurity(&self.parameters().criterion, &true_count, tc)
                        - fc as f64 / n as f64
                            * impurity(&self.parameters().criterion, false_count, fc);

                    let node = &mut self.nodes[visitor.node];
                    if node.split_score.map_or(true, |best| gain > best) {
                        node.split_feature = j;
                        node.split_value = Some(
                            (x_ij.to_f64().unwrap() + prev_x.to_f64().unwrap()) / 2f64,
                        );
                        node.split_score = Some(gain);
                        visitor.true_child_output = true_label;
                        visitor.false_child_output = false_label;
                    }
                }
            }
        }

        // The current row joins the true side for all later thresholds.
        prev = Some((x_ij, y_i));
        true_count[y_i] += weight;
    }
}
/// Applies the split stored on the visited node and creates its two children.
///
/// Active rows are partitioned by the node's split: rows whose feature value is `<=`
/// the split value move to the true child, the rest stay for the false child. If
/// either side ends up below `min_samples_leaf`, the split is erased from the node and
/// `false` is returned. Otherwise both children are appended to the tree, the recorded
/// depth is updated, every child that can be split further is queued, and `true` is
/// returned.
fn split<'a>(
    &mut self,
    mut visitor: NodeVisitor<'a, TX, X>,
    mtry: usize,
    visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, X>>,
    rng: &mut impl Rng,
) -> bool {
    let (n, _) = visitor.x.shape();
    let node_id = visitor.node;
    let feature = self.nodes()[node_id].split_feature;
    // A missing split value becomes NaN, which sends every row to the false side.
    let threshold = self.nodes()[node_id]
        .split_value
        .unwrap_or(std::f64::NAN);

    let mut tc = 0;
    let mut fc = 0;
    let mut true_samples: Vec<usize> = vec![0; n];
    for i in 0..n {
        let weight = visitor.samples[i];
        if weight == 0 {
            continue;
        }
        if visitor.x.get((i, feature)).to_f64().unwrap() <= threshold {
            // Move the row's weight from the parent's list to the true child's.
            true_samples[i] = weight;
            tc += weight;
            visitor.samples[i] = 0;
        } else {
            fc += weight;
        }
    }

    let min_samples_leaf = self.parameters().min_samples_leaf;
    if tc < min_samples_leaf || fc < min_samples_leaf {
        let node = &mut self.nodes[node_id];
        node.split_feature = 0;
        node.split_value = None;
        node.split_score = None;
        return false;
    }

    let true_child_idx = self.nodes().len();
    self.nodes.push(Node::new(visitor.true_child_output));
    let false_child_idx = self.nodes().len();
    self.nodes.push(Node::new(visitor.false_child_output));

    let parent = &mut self.nodes[node_id];
    parent.true_child = Some(true_child_idx);
    parent.false_child = Some(false_child_idx);

    let child_level = visitor.level + 1;
    self.depth = self.depth.max(child_level);

    let mut true_visitor = NodeVisitor::<TX, X>::new(
        true_child_idx,
        true_samples,
        visitor.order,
        visitor.x,
        visitor.y,
        child_level,
    );
    if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
        visitor_queue.push_back(true_visitor);
    }

    // What is left in the parent's weight list is exactly the false child's share.
    let mut false_visitor = NodeVisitor::<TX, X>::new(
        false_child_idx,
        visitor.samples,
        visitor.order,
        visitor.x,
        visitor.y,
        child_level,
    );
    if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
        visitor_queue.push_back(false_visitor);
    }

    true
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
// Checks the order in which the search-parameter iterator walks a 2 x 2 grid:
// `max_depth` cycles through its values before `min_samples_split` advances, and the
// iterator ends after the fourth combination.
#[test]
fn search_parameters() {
let parameters = DecisionTreeClassifierSearchParameters {
max_depth: vec![Some(10), Some(100)],
min_samples_split: vec![1, 2],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.max_depth, Some(10));
assert_eq!(next.min_samples_split, 1);
let next = iter.next().unwrap();
assert_eq!(next.max_depth, Some(100));
assert_eq!(next.min_samples_split, 1);
let next = iter.next().unwrap();
assert_eq!(next.max_depth, Some(10));
assert_eq!(next.min_samples_split, 2);
let next = iter.next().unwrap();
assert_eq!(next.max_depth, Some(100));
assert_eq!(next.min_samples_split, 2);
assert!(iter.next().is_none());
}
// Despite its name, this test checks all three criteria (Gini, entropy and
// classification error) on the class counts [7, 3] against reference values.
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn gini_impurity() {
assert!((impurity(&SplitCriterion::Gini, &[7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
assert!(
(impurity(&SplitCriterion::Entropy, &[7, 3], 10) - 0.8812908992306927).abs()
< std::f64::EPSILON
);
assert!(
(impurity(&SplitCriterion::ClassificationError, &[7, 3], 10) - 0.3).abs()
< std::f64::EPSILON
);
}
// Fits a tree with default parameters on a 20-row, 4-feature, two-class sample and
// checks that it reproduces the training labels. Then fits an entropy tree with
// `max_depth` 3 and prints its depth. Only compiled with the `datasets` feature.
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "datasets")]
fn fit_predict_iris() {
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
// The fitted tree must classify its own training data perfectly.
assert_eq!(
y,
DecisionTreeClassifier::fit(&x, &y, Default::default())
.and_then(|t| t.predict(&x))
.unwrap()
);
// Informational only: the depth is printed, not asserted.
println!(
"{:?}",
DecisionTreeClassifier::fit(
&x,
&y,
DecisionTreeClassifierParameters {
criterion: SplitCriterion::Entropy,
max_depth: Some(3),
min_samples_leaf: 1,
min_samples_split: 2,
seed: Option::None
}
)
.unwrap()
.depth
);
}
// Fitting must fail when `x` has 21 rows but `y` has only 20 labels.
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());
assert!(fail.is_err());
}
// Fits a tree with default parameters on a 20-row binary-feature data set and checks
// that it reproduces the training labels.
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict_baloons() {
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[1., 1., 1., 0.],
&[1., 1., 1., 0.],
&[1., 1., 1., 1.],
&[1., 1., 0., 0.],
&[1., 1., 0., 1.],
&[1., 0., 1., 0.],
&[1., 0., 1., 0.],
&[1., 0., 1., 1.],
&[1., 0., 0., 0.],
&[1., 0., 0., 1.],
&[0., 1., 1., 0.],
&[0., 1., 1., 0.],
&[0., 1., 1., 1.],
&[0., 1., 0., 0.],
&[0., 1., 0., 1.],
&[0., 0., 1., 0.],
&[0., 0., 1., 0.],
&[0., 0., 1., 1.],
&[0., 0., 0., 0.],
&[0., 0., 0., 1.],
]);
let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
// The fitted tree must classify its own training data perfectly.
assert_eq!(
y,
DecisionTreeClassifier::fit(&x, &y, Default::default())
.and_then(|t| t.predict(&x))
.unwrap()
);
}
// Round-trips a fitted tree through bincode and checks that the deserialized tree
// compares equal to the original. Only compiled with the `serde` feature.
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[1., 1., 1., 0.],
&[1., 1., 1., 0.],
&[1., 1., 1., 1.],
&[1., 1., 0., 0.],
&[1., 1., 0., 1.],
&[1., 0., 1., 0.],
&[1., 0., 1., 0.],
&[1., 0., 1., 1.],
&[1., 0., 0., 0.],
&[1., 0., 0., 1.],
&[0., 1., 1., 0.],
&[0., 1., 1., 0.],
&[0., 1., 1., 1.],
&[0., 1., 0., 0.],
&[0., 1., 0., 1.],
&[0., 0., 1., 0.],
&[0., 0., 1., 0.],
&[0., 0., 1., 1.],
&[0., 0., 0., 0.],
&[0., 0., 0., 1.],
]);
let y = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
// The explicit type annotation fixes the label type of both trees to `i64`.
let deserialized_tree: DecisionTreeClassifier<f64, i64, DenseMatrix<f64>, Vec<i64>> =
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
assert_eq!(tree, deserialized_tree);
}
}