tract_onnx_opl/ml/
tree.rs

1use std::convert::TryFrom;
2use std::convert::TryInto;
3use std::fmt::{self, Debug, Display};
4use std::iter;
5
6use tract_nnef::internal::*;
7
8use tract_ndarray::{
9    Array1, Array2, ArrayD, ArrayView1, ArrayView2, ArrayViewD, ArrayViewMut1, Axis, Ix1, Ix2,
10};
11
12use tract_num_traits::AsPrimitive;
13
// Local `ensure!`: when `$condition` evaluates to `false`, returns early through `bail!`
// with the remaining arguments used as the error's format string and values.
macro_rules! ensure {
    ($condition: expr, $($message: expr),* $(,)?) => {
        if !($condition) {
            bail!($($message),*)
        }
    };
}
21
/// Comparison a branch node applies between an input feature and its threshold
/// (`feature <cmp> threshold`).
///
/// The discriminants are the values stored in the low byte of a node's flag word, so they
/// must stay stable.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cmp {
    /// `==`
    Equal = 0x01,
    /// `!=`
    NotEqual = 0x02,
    /// `<`
    Less = 0x03,
    /// `>`
    Greater = 0x04,
    /// `<=`
    LessEqual = 0x05,
    /// `>=`
    GreaterEqual = 0x06,
}
32
33impl Cmp {
34    pub fn compare<T>(&self, x: T, y: T) -> bool
35    where
36        T: PartialOrd + Copy,
37    {
38        match self {
39            Cmp::LessEqual => x <= y,
40            Cmp::Less => x < y,
41            Cmp::GreaterEqual => x >= y,
42            Cmp::Greater => x > y,
43            Cmp::Equal => x == y,
44            Cmp::NotEqual => x != y,
45        }
46    }
47    pub fn to_u8(&self) -> u8 {
48        unsafe { std::mem::transmute(*self) }
49    }
50}
51
52impl TryFrom<u8> for Cmp {
53    type Error = TractError;
54    fn try_from(value: u8) -> Result<Self, Self::Error> {
55        if (1..=5).contains(&value) {
56            unsafe { Ok(std::mem::transmute::<u8, Cmp>(value)) }
57        } else {
58            bail!("Invalid value for Cmp: {}", value);
59        }
60    }
61}
62
63impl Display for Cmp {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        f.write_str(match self {
66            Cmp::LessEqual => "<=",
67            Cmp::Less => "<",
68            Cmp::GreaterEqual => ">=",
69            Cmp::Greater => ">",
70            Cmp::Equal => "==",
71            Cmp::NotEqual => "!=",
72        })
73    }
74}
75
/// Flattened, tensor-backed storage for a forest of decision trees.
#[derive(Debug, Clone, Hash)]
pub struct TreeEnsembleData {
    /// u32, [Ntrees]: root row of each tree in the `nodes` array (in rows).
    pub trees: Arc<Tensor>,
    /// u32, [_, 5]: one row per node.
    ///
    /// The 5th number is a flag word. Its low byte is the comparator: 0 for leaves, the
    /// `Cmp` discriminant for internal nodes. Bit 0x100 is the `nan_is_true` flag. A low
    /// byte that does not convert to a `Cmp` is decoded as a leaf.
    /// - internal nodes: feature_id, true_id, false_id, value.to_bits(),
    ///   comp | (0x100 if nan_is_true)
    /// - leaves: start row, end row (exclusive) in the `leaves` array, 3 zeros for padding
    pub nodes: Arc<Tensor>,
    /// u32, [_, 2]: one row per vote, `[category (class id), weight.to_bits()]`.
    /// A leaf's votes are the contiguous rows `start..end` it points at.
    pub leaves: Arc<Tensor>,
}
90
91impl Display for TreeEnsembleData {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        let tree = self.trees.as_slice::<u32>().unwrap();
94        for t in 0..tree.len() {
95            let last_node = tree.get(t + 1).cloned().unwrap_or(self.nodes.len() as u32 / 5);
96            writeln!(f, "Tree {}, nodes {:?}", t, tree[t]..last_node)?;
97            for n in tree[t]..last_node {
98                unsafe {
99                    let node = self.get_unchecked(n as _);
100                    if let TreeNode::Leaf(leaf) = node {
101                        for vote in leaf.start_id..leaf.end_id {
102                            let cat = self.leaves.as_slice::<u32>().unwrap()[vote * 2];
103                            let contrib = self.leaves.as_slice::<u32>().unwrap()[vote * 2 + 1];
104                            let contrib = f32::from_bits(contrib);
105                            writeln!(f, "{n} categ:{cat} add:{contrib}")?;
106                        }
107                    } else {
108                        writeln!(f, "{} {:?}", n, self.get_unchecked(n as _))?;
109                    }
110                }
111            }
112        }
113        Ok(())
114    }
115}
116
117impl TreeEnsembleData {
118    unsafe fn get_unchecked(&self, node: usize) -> TreeNode {
119        let row = &unsafe { self.nodes.as_slice_unchecked::<u32>() }[node * 5..][..5];
120        if let Ok(cmp) = ((row[4] & 0xFF) as u8).try_into() {
121            let feature_id = row[0];
122            let true_id = row[1];
123            let false_id = row[2];
124            let value = f32::from_bits(row[3]);
125            let nan_is_true = (row[4] & 0x0100) != 0;
126            TreeNode::Branch(BranchNode { cmp, feature_id, value, true_id, false_id, nan_is_true })
127        } else {
128            TreeNode::Leaf(LeafNode { start_id: row[0] as usize, end_id: row[1] as usize })
129        }
130    }
131
132    unsafe fn get_leaf_unchecked<T>(&self, tree: usize, input: &ArrayView1<T>) -> LeafNode
133    where
134        T: AsPrimitive<f32>,
135    {
136        unsafe {
137            let mut node_id = self.trees.as_slice_unchecked::<u32>()[tree] as usize;
138            loop {
139                let node = self.get_unchecked(node_id);
140                match node {
141                    TreeNode::Branch(ref b) => {
142                        let feature = *input.uget(b.feature_id as usize);
143                        node_id = b.get_child_id(feature.as_());
144                    }
145                    TreeNode::Leaf(l) => return l,
146                }
147            }
148        }
149    }
150
151    unsafe fn eval_unchecked<A, T>(
152        &self,
153        tree: usize,
154        input: &ArrayView1<T>,
155        output: &mut ArrayViewMut1<f32>,
156        aggs: &mut [A],
157    ) where
158        A: AggregateFn,
159        T: AsPrimitive<f32>,
160    {
161        unsafe {
162            let leaf = self.get_leaf_unchecked(tree, input);
163            for leaf in self
164                .leaves
165                .to_array_view_unchecked::<u32>()
166                .outer_iter()
167                .skip(leaf.start_id)
168                .take(leaf.end_id - leaf.start_id)
169            {
170                let class_id = leaf[0] as usize;
171                let weight = f32::from_bits(leaf[1]);
172                let agg_fn = aggs.get_unchecked_mut(class_id);
173                agg_fn.aggregate(weight, output.uget_mut(class_id));
174            }
175        }
176    }
177}
178
179#[derive(Copy, Clone)]
180struct BranchNode {
181    pub cmp: Cmp, // TODO: perf: most real forests have only 1 type of comparison
182    pub feature_id: u32,
183    pub value: f32,
184    pub true_id: u32,
185    pub false_id: u32,
186    pub nan_is_true: bool,
187}
188
189impl std::fmt::Debug for BranchNode {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        write!(
192            f,
193            "if feat({}) {} {} then {} else {}",
194            self.feature_id, self.cmp, self.value, self.true_id, self.false_id
195        )
196    }
197}
198
199impl BranchNode {
200    pub fn get_child_id(&self, feature: f32) -> usize {
201        let condition =
202            if feature.is_nan() { self.nan_is_true } else { self.cmp.compare(feature, self.value) };
203        if condition {
204            self.true_id as usize
205        } else {
206            self.false_id as usize
207        }
208    }
209}
210
/// Decoded leaf: the half-open range `start_id..end_id` of its vote rows in
/// `TreeEnsembleData::leaves`.
#[derive(Copy, Clone, Debug, Hash)]
struct LeafNode {
    /// First vote row (inclusive).
    pub start_id: usize,
    /// One past the last vote row (exclusive).
    pub end_id: usize,
}
216
217#[derive(Copy, Clone, Debug)]
218enum TreeNode {
219    Branch(BranchNode),
220    Leaf(LeafNode),
221}
222
/// Folds the vote weights that reach one output class into that class's running total.
///
/// `aggregate` is called once per vote. `post_aggregate` is called once per class after
/// every tree has voted, to finish the total.
pub trait AggregateFn: Default {
    /// Folds one vote weight `score` into `total`.
    fn aggregate(&mut self, score: f32, total: &mut f32);

    /// Finalizes `total` once all votes are in. The default leaves it untouched.
    fn post_aggregate(&mut self, _total: &mut f32) {}
}
228
229#[derive(Clone, Copy, Default, Debug)]
230pub struct SumFn;
231
232impl AggregateFn for SumFn {
233    fn aggregate(&mut self, score: f32, total: &mut f32) {
234        *total += score;
235    }
236}
237
238#[derive(Clone, Copy, Default, Debug)]
239pub struct AvgFn {
240    count: usize,
241}
242
243impl AggregateFn for AvgFn {
244    fn aggregate(&mut self, score: f32, total: &mut f32) {
245        *total += score;
246        self.count += 1;
247    }
248
249    fn post_aggregate(&mut self, total: &mut f32) {
250        if self.count > 1 {
251            *total /= self.count as f32;
252        }
253        self.count = 0;
254    }
255}
256
257#[derive(Clone, Copy, Default, Debug)]
258pub struct MaxFn;
259
260impl AggregateFn for MaxFn {
261    fn aggregate(&mut self, score: f32, total: &mut f32) {
262        *total = total.max(score);
263    }
264}
265
266#[derive(Clone, Copy, Default, Debug)]
267pub struct MinFn;
268
269impl AggregateFn for MinFn {
270    fn aggregate(&mut self, score: f32, total: &mut f32) {
271        *total = total.min(score);
272    }
273}
274
/// How the weights of all votes landing on the same class are combined.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum Aggregate {
    /// Sum of the weights (`SumFn`); the default.
    #[default]
    Sum,
    /// Sum divided by the number of votes (`AvgFn`).
    Avg,
    /// Folded with `f32::max` (`MaxFn`).
    Max,
    /// Folded with `f32::min` (`MinFn`).
    Min,
}
283
/// A tree ensemble evaluated on 1-D (single sample) or 2-D (one sample per row) inputs.
#[derive(Clone, Debug, Hash)]
pub struct TreeEnsemble {
    /// Flattened trees, nodes and leaf votes.
    pub data: TreeEnsembleData,
    /// Index of the highest feature the ensemble uses. Inputs must have more than this many
    /// features (see `check_n_features`).
    pub max_used_feature: usize,
    /// Number of output classes, i.e. the size of the last output axis.
    pub n_classes: usize,
    /// How votes for the same class are combined.
    pub aggregate_fn: Aggregate, // TODO: should this be an argument to eval()?
}
291
292impl TreeEnsemble {
293    pub fn build(
294        data: TreeEnsembleData,
295        max_used_feature: usize,
296        n_classes: usize,
297        aggregate_fn: Aggregate,
298    ) -> TractResult<Self> {
299        Ok(Self { data, max_used_feature, n_classes, aggregate_fn })
300    }
301
302    pub fn n_classes(&self) -> usize {
303        self.n_classes
304    }
305
306    unsafe fn eval_one_unchecked<A, T>(
307        &self,
308        input: &ArrayView1<T>,
309        output: &mut ArrayViewMut1<f32>,
310        aggs: &mut [A],
311    ) where
312        A: AggregateFn,
313        T: AsPrimitive<f32>,
314    {
315        unsafe {
316            for t in 0..self.data.trees.len() {
317                self.data.eval_unchecked(t, input, output, aggs)
318            }
319            for i in 0..self.n_classes {
320                aggs.get_unchecked_mut(i).post_aggregate(output.uget_mut(i));
321            }
322        }
323    }
324
325    pub fn check_n_features(&self, n_features: usize) -> TractResult<()> {
326        ensure!(
327            n_features > self.max_used_feature,
328            "Invalid input shape: input has {} features, tree ensemble use feature #{}",
329            n_features,
330            self.max_used_feature
331        );
332        Ok(())
333    }
334
335    fn eval_2d<A, T>(&self, input: &ArrayView2<T>) -> TractResult<Array2<f32>>
336    where
337        A: AggregateFn,
338        T: AsPrimitive<f32>,
339    {
340        self.check_n_features(input.shape()[1])?;
341        let n = input.shape()[0];
342        let mut output = Array2::zeros((n, self.n_classes));
343        let mut aggs: tract_smallvec::SmallVec<[A; 16]> =
344            iter::repeat_with(Default::default).take(self.n_classes).collect();
345        for i in 0..n {
346            unsafe {
347                self.eval_one_unchecked::<A, T>(
348                    &input.index_axis(Axis(0), i),
349                    &mut output.index_axis_mut(Axis(0), i),
350                    &mut aggs,
351                );
352            }
353        }
354        Ok(output)
355    }
356
357    fn eval_1d<A, T>(&self, input: &ArrayView1<T>) -> TractResult<Array1<f32>>
358    where
359        A: AggregateFn,
360        T: AsPrimitive<f32>,
361    {
362        self.check_n_features(input.len())?;
363        let mut output = Array1::zeros(self.n_classes);
364        let mut aggs: tract_smallvec::SmallVec<[A; 16]> =
365            iter::repeat_with(Default::default).take(self.n_classes).collect();
366        unsafe {
367            self.eval_one_unchecked::<A, T>(input, &mut output.view_mut(), &mut aggs);
368        }
369        Ok(output)
370    }
371
372    pub fn eval<'i, I, T>(&self, input: I) -> TractResult<ArrayD<f32>>
373    where
374        I: Into<ArrayViewD<'i, T>>, // TODO: accept generic dimensions, not just IxDyn
375        T: Datum + AsPrimitive<f32>,
376    {
377        let input = input.into();
378        if let Ok(input) = input.view().into_dimensionality::<Ix1>() {
379            Ok(match self.aggregate_fn {
380                Aggregate::Sum => self.eval_1d::<SumFn, T>(&input),
381                Aggregate::Avg => self.eval_1d::<AvgFn, T>(&input),
382                Aggregate::Min => self.eval_1d::<MinFn, T>(&input),
383                Aggregate::Max => self.eval_1d::<MaxFn, T>(&input),
384            }?
385            .into_dyn())
386        } else if let Ok(input) = input.view().into_dimensionality::<Ix2>() {
387            Ok(match self.aggregate_fn {
388                Aggregate::Sum => self.eval_2d::<SumFn, T>(&input),
389                Aggregate::Avg => self.eval_2d::<AvgFn, T>(&input),
390                Aggregate::Min => self.eval_2d::<MinFn, T>(&input),
391                Aggregate::Max => self.eval_2d::<MaxFn, T>(&input),
392            }?
393            .into_dyn())
394        } else {
395            bail!("Invalid input dimensionality for tree ensemble: {:?}", input.shape());
396        }
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use tract_ndarray::prelude::*;

    /// Encodes a branch row. `left`/`right` are the true/false children, relative to
    /// `node_offset` (the row of the tree's root).
    fn b(
        node_offset: usize,
        cmp: Cmp,
        feat: usize,
        v: f32,
        left: usize,
        right: usize,
        nan_is_true: bool,
    ) -> [u32; 5] {
        [
            feat as u32,
            (node_offset + left) as u32,
            (node_offset + right) as u32,
            v.to_bits(),
            cmp as u32 | if nan_is_true { 0x100 } else { 0 },
        ]
    }

    /// Encodes a leaf row covering votes `start_id..end_id`, relative to `leaf_offset`.
    fn l(leaf_offset: usize, start_id: usize, end_id: usize) -> [u32; 5] {
        [(leaf_offset + start_id) as u32, (leaf_offset + end_id) as u32, 0, 0, 0]
    }

    /// Encodes a vote row: (class id, weight bits).
    fn w(categ: usize, weight: f32) -> [u32; 2] {
        [categ as u32, weight.to_bits()]
    }

    fn generate_gbm_trees() -> TreeEnsembleData {
        let trees = rctensor1(&[0u32, 5u32, 14, 21, 30, 41]);
        let nodes = rctensor2(&[
            b(0, Cmp::LessEqual, 2, 3.15, 1, 2, true),
            b(0, Cmp::LessEqual, 1, 3.35, 3, 4, true),
            l(0, 0, 1),
            l(0, 1, 2),
            l(0, 2, 3),
            //
            b(5, Cmp::LessEqual, 2, 1.8, 1, 2, true),
            l(3, 0, 1),
            b(5, Cmp::LessEqual, 3, 1.65, 3, 4, true),
            b(5, Cmp::LessEqual, 2, 4.45, 5, 6, true),
            b(5, Cmp::LessEqual, 2, 5.35, 7, 8, true),
            l(3, 1, 2),
            l(3, 2, 3),
            l(3, 3, 4),
            l(3, 4, 5),
            //
            b(14, Cmp::LessEqual, 3, 1.65, 1, 2, true),
            b(14, Cmp::LessEqual, 2, 4.45, 3, 4, true),
            b(14, Cmp::LessEqual, 2, 5.35, 5, 6, true),
            l(8, 0, 1),
            l(8, 1, 2),
            l(8, 2, 3),
            l(8, 3, 4),
            //
            b(21, Cmp::LessEqual, 2, 3.15, 1, 2, true),
            b(21, Cmp::LessEqual, 1, 3.35, 3, 4, true),
            b(21, Cmp::LessEqual, 2, 4.45, 5, 6, true),
            l(12, 0, 1),
            l(12, 1, 2),
            l(12, 2, 3),
            b(21, Cmp::LessEqual, 2, 5.35, 7, 8, true),
            l(12, 3, 4),
            l(12, 4, 5),
            //
            b(30, Cmp::LessEqual, 3, 0.45, 1, 2, true),
            b(30, Cmp::LessEqual, 2, 1.45, 3, 4, true),
            b(30, Cmp::LessEqual, 3, 1.65, 5, 6, true),
            l(17, 0, 1),
            l(17, 1, 2),
            b(30, Cmp::LessEqual, 2, 4.45, 7, 8, true),
            b(30, Cmp::LessEqual, 2, 5.35, 9, 10, true),
            l(17, 2, 3),
            l(17, 3, 4),
            l(17, 4, 5),
            l(17, 5, 6),
            //
            b(41, Cmp::LessEqual, 2, 4.75, 1, 2, true),
            b(41, Cmp::LessEqual, 1, 2.75, 3, 4, true),
            b(41, Cmp::LessEqual, 2, 5.15, 7, 8, true),
            l(23, 0, 1),
            b(41, Cmp::LessEqual, 2, 4.15, 5, 6, true),
            l(23, 1, 2),
            l(23, 2, 3),
            l(23, 3, 4),
            l(23, 4, 5),
        ]);
        assert_eq!(nodes.shape(), &[50, 5]);
        let leaves = rctensor2(&[
            w(0, -0.075),
            w(0, 0.13928571),
            w(0, 0.15),
            //
            w(1, -0.075),
            w(1, 0.13548388),
            w(1, 0.110869564),
            w(1, -0.052500002),
            w(1, -0.075),
            //
            w(2, -0.075),
            w(2, -0.035869565),
            w(2, 0.1275),
            w(2, 0.15),
            //
            w(0, 0.12105576),
            w(0, 0.1304589),
            w(0, -0.07237862),
            w(0, -0.07226522),
            w(0, -0.07220469),
            //
            w(1, -0.07226842),
            w(1, -0.07268012),
            w(1, 0.119391434),
            w(1, 0.097440675),
            w(1, -0.049815115),
            w(1, -0.07219931),
            //
            w(2, -0.061642267),
            w(2, -0.0721846),
            w(2, -0.07319043),
            w(2, 0.076814815),
            w(2, 0.1315959),
        ]);
        assert_eq!(leaves.shape(), &[28, 2]);
        TreeEnsembleData { nodes, trees, leaves }
    }

    fn generate_gbm_ensemble() -> TreeEnsemble {
        // converted manually from LightGBM, fitted on iris dataset
        let trees = generate_gbm_trees();
        // Iris has 4 features (#0..=#3) and the branches read indices up to 3.
        // `check_n_features` requires `n_features > max_used_feature`, so the bound must
        // be 3 for 4-feature inputs to be accepted.
        TreeEnsemble::build(trees, 3, 3, Aggregate::Sum).unwrap()
    }

    fn generate_gbm_input() -> Array2<f32> {
        arr2(&[
            [5.1, 3.5, 1.4, 0.2],
            [5.4, 3.7, 1.5, 0.2],
            [5.4, 3.4, 1.7, 0.2],
            [4.8, 3.1, 1.6, 0.2],
            [5.0, 3.5, 1.3, 0.3],
            [7.0, 3.2, 4.7, 1.4],
            [5.0, 2.0, 3.5, 1.0],
            [5.9, 3.2, 4.8, 1.8],
            [5.5, 2.4, 3.8, 1.1],
            [5.5, 2.6, 4.4, 1.2],
            [6.3, 3.3, 6.0, 2.5],
            [6.5, 3.2, 5.1, 2.0],
            [6.9, 3.2, 5.7, 2.3],
            [7.4, 2.8, 6.1, 1.9],
            [6.7, 3.1, 5.6, 2.4],
        ])
    }

    fn generate_gbm_raw_output() -> Array2<f32> {
        arr2(&[
            [0.28045893, -0.14726841, -0.14718461],
            [0.28045893, -0.14768013, -0.14718461],
            [0.28045893, -0.14768013, -0.14718461],
            [0.26034147, -0.14768013, -0.14718461],
            [0.28045893, -0.14726841, -0.14718461],
            [-0.14726523, 0.20831025, -0.10905999],
            [-0.14737862, 0.254_875_3, -0.13664228],
            [-0.14726523, -0.10231511, 0.20431481],
            [-0.14737862, 0.254_875_3, -0.13664228],
            [-0.14737862, 0.254_875_3, -0.13664228],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
            [-0.14726523, -0.10231511, 0.20431481],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
        ])
    }

    /// Element-wise comparison with a small tolerance. The reference values are printed
    /// decimals (e.g. `-0.14726841`), which may differ from the f32 sums in the last bit.
    fn assert_close(got: &ArrayD<f32>, expected: &ArrayD<f32>) {
        assert_eq!(got.shape(), expected.shape());
        for (g, e) in got.iter().zip(expected.iter()) {
            assert!((g - e).abs() <= 1e-6, "got {got}, expected {expected}");
        }
    }

    #[test]
    fn test_tree_ensemble() {
        let ensemble = generate_gbm_ensemble();
        let input = generate_gbm_input();
        let output = ensemble.eval(input.view().into_dyn()).unwrap();
        assert_close(&output, &generate_gbm_raw_output().into_dyn());
    }

    #[test]
    fn test_tree_ensemble_single_sample() {
        let ensemble = generate_gbm_ensemble();
        let input = generate_gbm_input();
        let expected = generate_gbm_raw_output();
        for (sample, want) in input.outer_iter().zip(expected.outer_iter()) {
            let output = ensemble.eval(sample.into_dyn()).unwrap();
            assert_close(&output, &want.to_owned().into_dyn());
        }
    }

    #[test]
    fn test_tree_ensemble_rejects_missing_features() {
        let ensemble = generate_gbm_ensemble();
        let input = arr2(&[[5.1f32, 3.5, 1.4]]);
        assert!(ensemble.eval(input.view().into_dyn()).is_err());
    }
}