1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use std::error::Error;

/// Parameters common to decision-tree estimators.
///
/// Use the `set_*` methods to change values with validation. Both fields are
/// also `pub`, so assigning them directly skips those checks.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TreeParams {
    /// Minimum number of samples required to split; [`TreeParams::set_min_samples_split`]
    /// enforces a value of at least 2.
    pub min_samples_split: u16,
    /// Optional depth limit; [`TreeParams::set_max_depth`] rejects `Some(0)`.
    ///
    /// NOTE(review): `None` presumably means "unlimited depth" — confirm against the
    /// tree-building code.
    pub max_depth: Option<u16>,
}

impl Default for TreeParams {
    /// Same as [`TreeParams::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl TreeParams {
    /// Creates parameters with `min_samples_split = 2` and `max_depth = None`.
    ///
    /// This is a `const fn`, so it can also initialize `const`/`static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            min_samples_split: 2,
            max_depth: None,
        }
    }

    /// Sets the minimum number of samples required to split.
    ///
    /// # Errors
    ///
    /// Returns an error if `min_samples_split < 2`; the current value is left unchanged.
    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
        if min_samples_split < 2 {
            return Err("The minimum number of samples to split must be greater than 1.".into());
        }
        self.min_samples_split = min_samples_split;
        Ok(())
    }

    /// Sets the maximum depth; passing `None` clears any previously set depth.
    ///
    /// # Errors
    ///
    /// Returns an error for `Some(0)`; the current value is left unchanged.
    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
        // For an unsigned depth, "less than 1" is exactly zero.
        if max_depth == Some(0) {
            return Err("The maximum depth must be greater than 0.".into());
        }
        self.max_depth = max_depth;
        Ok(())
    }

    /// Returns the minimum number of samples required to split.
    #[must_use]
    pub const fn min_samples_split(&self) -> u16 {
        self.min_samples_split
    }

    /// Returns the configured maximum depth, if any.
    #[must_use]
    pub const fn max_depth(&self) -> Option<u16> {
        self.max_depth
    }
}

/// Parameters for a tree classifier: the shared [`TreeParams`] plus a split criterion.
///
/// The setters validate their input, but the fields are also `pub`, so assigning
/// them directly skips those checks.
#[derive(Clone, Debug)]
pub struct TreeClassifierParams {
    /// Shared tree parameters (`min_samples_split`, `max_depth`).
    pub base_params: TreeParams,
    /// Split criterion; [`TreeClassifierParams::set_criterion`] accepts only
    /// `"gini"` or `"entropy"`.
    pub criterion: String,
}

impl Default for TreeClassifierParams {
    /// Same as [`TreeClassifierParams::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl TreeClassifierParams {
    /// Creates parameters with [`TreeParams::new`] defaults and the `"gini"` criterion.
    #[must_use]
    pub fn new() -> Self {
        Self {
            base_params: TreeParams::new(),
            criterion: String::from("gini"),
        }
    }

    /// Sets the minimum number of samples required to split.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`TreeParams::set_min_samples_split`] (values below 2).
    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
        self.base_params.set_min_samples_split(min_samples_split)
    }

    /// Sets the maximum depth; passing `None` clears any previously set depth.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`TreeParams::set_max_depth`] (`Some(0)`).
    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
        self.base_params.set_max_depth(max_depth)
    }

    /// Sets the split criterion.
    ///
    /// The comparison is exact (case-sensitive).
    ///
    /// # Errors
    ///
    /// Returns an error unless `criterion` is `"gini"` or `"entropy"`; the current
    /// criterion is left unchanged in that case.
    pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
        if !matches!(criterion.as_str(), "gini" | "entropy") {
            return Err("The criterion must be either 'gini' or 'entropy'.".into());
        }
        self.criterion = criterion;
        Ok(())
    }

    /// Returns the minimum number of samples required to split.
    #[must_use]
    pub fn min_samples_split(&self) -> u16 {
        // Delegate to the base accessor rather than reading the field directly.
        self.base_params.min_samples_split()
    }

    /// Returns the configured maximum depth, if any.
    #[must_use]
    pub fn max_depth(&self) -> Option<u16> {
        self.base_params.max_depth()
    }

    /// Returns the current split criterion.
    #[must_use]
    pub fn criterion(&self) -> &str {
        &self.criterion
    }
}