1use std::convert::TryFrom;
2use std::convert::TryInto;
3use std::fmt::{self, Debug, Display};
4use std::iter;
5
6use tract_nnef::internal::*;
7
8use tract_ndarray::{
9 Array1, Array2, ArrayD, ArrayView1, ArrayView2, ArrayViewD, ArrayViewMut1, Axis, Ix1, Ix2,
10};
11
12use tract_num_traits::AsPrimitive;
13
/// Returns early with `bail!(msg...)` unless `cond` holds.
///
/// The message arguments are forwarded to `bail!` unchanged, so they follow its
/// format-string conventions.
macro_rules! ensure {
    ($cond:expr, $($msg:expr),* $(,)?) => {
        if !($cond) {
            bail!($($msg),*)
        }
    };
}
21
/// Comparison operator of a branch node.
///
/// The explicit discriminants are the numeric codes returned by [`Cmp::to_u8`] and decoded
/// by `TryFrom<u8>`; branch nodes store them in the low byte of their flags word.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cmp {
    Equal = 1,
    NotEqual = 2,
    Less = 3,
    Greater = 4,
    LessEqual = 5,
    GreaterEqual = 6,
}

impl Cmp {
    /// Evaluates `x <op> y` for this operator.
    ///
    /// Comparisons go through `PartialOrd`/`PartialEq`, so for floating-point operands every
    /// operator except `NotEqual` yields `false` when either side is NaN (`NotEqual` yields
    /// `true`).
    pub fn compare<T>(&self, x: T, y: T) -> bool
    where
        T: PartialOrd + Copy,
    {
        match self {
            Cmp::LessEqual => x <= y,
            Cmp::Less => x < y,
            Cmp::GreaterEqual => x >= y,
            Cmp::Greater => x > y,
            Cmp::Equal => x == y,
            Cmp::NotEqual => x != y,
        }
    }

    /// Returns the numeric code of this operator (its `#[repr(u8)]` discriminant).
    pub fn to_u8(&self) -> u8 {
        // A fieldless `#[repr(u8)]` enum casts to its discriminant; no `transmute` needed.
        *self as u8
    }
}
51
52impl TryFrom<u8> for Cmp {
53 type Error = TractError;
54 fn try_from(value: u8) -> Result<Self, Self::Error> {
55 if (1..=5).contains(&value) {
56 unsafe { Ok(std::mem::transmute::<u8, Cmp>(value)) }
57 } else {
58 bail!("Invalid value for Cmp: {}", value);
59 }
60 }
61}
62
63impl Display for Cmp {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 f.write_str(match self {
66 Cmp::LessEqual => "<=",
67 Cmp::Less => "<",
68 Cmp::GreaterEqual => ">=",
69 Cmp::Greater => ">",
70 Cmp::Equal => "==",
71 Cmp::NotEqual => "!=",
72 })
73 }
74}
75
76#[derive(Debug, Clone, Hash)]
77pub struct TreeEnsembleData {
78 pub trees: Arc<Tensor>,
80 pub nodes: Arc<Tensor>,
87 pub leaves: Arc<Tensor>,
89}
90
91impl Display for TreeEnsembleData {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 let tree = self.trees.as_slice::<u32>().unwrap();
94 for t in 0..tree.len() {
95 let last_node = tree.get(t + 1).cloned().unwrap_or(self.nodes.len() as u32 / 5);
96 writeln!(f, "Tree {}, nodes {:?}", t, tree[t]..last_node)?;
97 for n in tree[t]..last_node {
98 unsafe {
99 let node = self.get_unchecked(n as _);
100 if let TreeNode::Leaf(leaf) = node {
101 for vote in leaf.start_id..leaf.end_id {
102 let cat = self.leaves.as_slice::<u32>().unwrap()[vote * 2];
103 let contrib = self.leaves.as_slice::<u32>().unwrap()[vote * 2 + 1];
104 let contrib = f32::from_bits(contrib);
105 writeln!(f, "{n} categ:{cat} add:{contrib}")?;
106 }
107 } else {
108 writeln!(f, "{} {:?}", n, self.get_unchecked(n as _))?;
109 }
110 }
111 }
112 }
113 Ok(())
114 }
115}
116
117impl TreeEnsembleData {
118 unsafe fn get_unchecked(&self, node: usize) -> TreeNode {
119 let row = &unsafe { self.nodes.as_slice_unchecked::<u32>() }[node * 5..][..5];
120 if let Ok(cmp) = ((row[4] & 0xFF) as u8).try_into() {
121 let feature_id = row[0];
122 let true_id = row[1];
123 let false_id = row[2];
124 let value = f32::from_bits(row[3]);
125 let nan_is_true = (row[4] & 0x0100) != 0;
126 TreeNode::Branch(BranchNode { cmp, feature_id, value, true_id, false_id, nan_is_true })
127 } else {
128 TreeNode::Leaf(LeafNode { start_id: row[0] as usize, end_id: row[1] as usize })
129 }
130 }
131
132 unsafe fn get_leaf_unchecked<T>(&self, tree: usize, input: &ArrayView1<T>) -> LeafNode
133 where
134 T: AsPrimitive<f32>,
135 {
136 unsafe {
137 let mut node_id = self.trees.as_slice_unchecked::<u32>()[tree] as usize;
138 loop {
139 let node = self.get_unchecked(node_id);
140 match node {
141 TreeNode::Branch(ref b) => {
142 let feature = *input.uget(b.feature_id as usize);
143 node_id = b.get_child_id(feature.as_());
144 }
145 TreeNode::Leaf(l) => return l,
146 }
147 }
148 }
149 }
150
151 unsafe fn eval_unchecked<A, T>(
152 &self,
153 tree: usize,
154 input: &ArrayView1<T>,
155 output: &mut ArrayViewMut1<f32>,
156 aggs: &mut [A],
157 ) where
158 A: AggregateFn,
159 T: AsPrimitive<f32>,
160 {
161 unsafe {
162 let leaf = self.get_leaf_unchecked(tree, input);
163 for leaf in self
164 .leaves
165 .to_array_view_unchecked::<u32>()
166 .outer_iter()
167 .skip(leaf.start_id)
168 .take(leaf.end_id - leaf.start_id)
169 {
170 let class_id = leaf[0] as usize;
171 let weight = f32::from_bits(leaf[1]);
172 let agg_fn = aggs.get_unchecked_mut(class_id);
173 agg_fn.aggregate(weight, output.uget_mut(class_id));
174 }
175 }
176 }
177}
178
179#[derive(Copy, Clone)]
180struct BranchNode {
181 pub cmp: Cmp, pub feature_id: u32,
183 pub value: f32,
184 pub true_id: u32,
185 pub false_id: u32,
186 pub nan_is_true: bool,
187}
188
189impl std::fmt::Debug for BranchNode {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 write!(
192 f,
193 "if feat({}) {} {} then {} else {}",
194 self.feature_id, self.cmp, self.value, self.true_id, self.false_id
195 )
196 }
197}
198
199impl BranchNode {
200 pub fn get_child_id(&self, feature: f32) -> usize {
201 let condition =
202 if feature.is_nan() { self.nan_is_true } else { self.cmp.compare(feature, self.value) };
203 if condition {
204 self.true_id as usize
205 } else {
206 self.false_id as usize
207 }
208 }
209}
210
/// Decoded form of a leaf node: the half-open range `start_id..end_id` of rows of the
/// vote table (`TreeEnsembleData::leaves`) that this leaf contributes.
#[derive(Copy, Clone, Debug, Hash)]
struct LeafNode {
    /// First vote row (inclusive).
    pub start_id: usize,
    /// One past the last vote row.
    pub end_id: usize,
}
216
217#[derive(Copy, Clone, Debug)]
218enum TreeNode {
219 Branch(BranchNode),
220 Leaf(LeafNode),
221}
222
/// Per-class strategy for combining tree votes into a score.
///
/// `aggregate` is called once per vote, with `total` pointing at the voted class's output
/// slot. `post_aggregate` is called once per class after all trees have been evaluated for a
/// sample; stateful implementations (e.g. `AvgFn`) finish and reset there.
pub trait AggregateFn: Default {
    /// Folds one vote `score` into the running `total`.
    fn aggregate(&mut self, score: f32, total: &mut f32);

    /// Finalizes `total` after the last vote of a sample. Does nothing by default.
    fn post_aggregate(&mut self, _total: &mut f32) {}
}
228
229#[derive(Clone, Copy, Default, Debug)]
230pub struct SumFn;
231
232impl AggregateFn for SumFn {
233 fn aggregate(&mut self, score: f32, total: &mut f32) {
234 *total += score;
235 }
236}
237
238#[derive(Clone, Copy, Default, Debug)]
239pub struct AvgFn {
240 count: usize,
241}
242
243impl AggregateFn for AvgFn {
244 fn aggregate(&mut self, score: f32, total: &mut f32) {
245 *total += score;
246 self.count += 1;
247 }
248
249 fn post_aggregate(&mut self, total: &mut f32) {
250 if self.count > 1 {
251 *total /= self.count as f32;
252 }
253 self.count = 0;
254 }
255}
256
257#[derive(Clone, Copy, Default, Debug)]
258pub struct MaxFn;
259
260impl AggregateFn for MaxFn {
261 fn aggregate(&mut self, score: f32, total: &mut f32) {
262 *total = total.max(score);
263 }
264}
265
266#[derive(Clone, Copy, Default, Debug)]
267pub struct MinFn;
268
269impl AggregateFn for MinFn {
270 fn aggregate(&mut self, score: f32, total: &mut f32) {
271 *total = total.min(score);
272 }
273}
274
/// Selects how the votes of all trees are combined per class when the ensemble is evaluated.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum Aggregate {
    /// Sum of the votes (via `SumFn`); the default.
    #[default]
    Sum,
    /// Average of the votes (via `AvgFn`).
    Avg,
    /// Running maximum of the votes (via `MaxFn`).
    Max,
    /// Running minimum of the votes (via `MinFn`).
    Min,
}
283
284#[derive(Clone, Debug, Hash)]
285pub struct TreeEnsemble {
286 pub data: TreeEnsembleData,
287 pub max_used_feature: usize,
288 pub n_classes: usize,
289 pub aggregate_fn: Aggregate, }
291
292impl TreeEnsemble {
293 pub fn build(
294 data: TreeEnsembleData,
295 max_used_feature: usize,
296 n_classes: usize,
297 aggregate_fn: Aggregate,
298 ) -> TractResult<Self> {
299 Ok(Self { data, max_used_feature, n_classes, aggregate_fn })
300 }
301
302 pub fn n_classes(&self) -> usize {
303 self.n_classes
304 }
305
306 unsafe fn eval_one_unchecked<A, T>(
307 &self,
308 input: &ArrayView1<T>,
309 output: &mut ArrayViewMut1<f32>,
310 aggs: &mut [A],
311 ) where
312 A: AggregateFn,
313 T: AsPrimitive<f32>,
314 {
315 unsafe {
316 for t in 0..self.data.trees.len() {
317 self.data.eval_unchecked(t, input, output, aggs)
318 }
319 for i in 0..self.n_classes {
320 aggs.get_unchecked_mut(i).post_aggregate(output.uget_mut(i));
321 }
322 }
323 }
324
325 pub fn check_n_features(&self, n_features: usize) -> TractResult<()> {
326 ensure!(
327 n_features > self.max_used_feature,
328 "Invalid input shape: input has {} features, tree ensemble use feature #{}",
329 n_features,
330 self.max_used_feature
331 );
332 Ok(())
333 }
334
335 fn eval_2d<A, T>(&self, input: &ArrayView2<T>) -> TractResult<Array2<f32>>
336 where
337 A: AggregateFn,
338 T: AsPrimitive<f32>,
339 {
340 self.check_n_features(input.shape()[1])?;
341 let n = input.shape()[0];
342 let mut output = Array2::zeros((n, self.n_classes));
343 let mut aggs: tract_smallvec::SmallVec<[A; 16]> =
344 iter::repeat_with(Default::default).take(self.n_classes).collect();
345 for i in 0..n {
346 unsafe {
347 self.eval_one_unchecked::<A, T>(
348 &input.index_axis(Axis(0), i),
349 &mut output.index_axis_mut(Axis(0), i),
350 &mut aggs,
351 );
352 }
353 }
354 Ok(output)
355 }
356
357 fn eval_1d<A, T>(&self, input: &ArrayView1<T>) -> TractResult<Array1<f32>>
358 where
359 A: AggregateFn,
360 T: AsPrimitive<f32>,
361 {
362 self.check_n_features(input.len())?;
363 let mut output = Array1::zeros(self.n_classes);
364 let mut aggs: tract_smallvec::SmallVec<[A; 16]> =
365 iter::repeat_with(Default::default).take(self.n_classes).collect();
366 unsafe {
367 self.eval_one_unchecked::<A, T>(input, &mut output.view_mut(), &mut aggs);
368 }
369 Ok(output)
370 }
371
372 pub fn eval<'i, I, T>(&self, input: I) -> TractResult<ArrayD<f32>>
373 where
374 I: Into<ArrayViewD<'i, T>>, T: Datum + AsPrimitive<f32>,
376 {
377 let input = input.into();
378 if let Ok(input) = input.view().into_dimensionality::<Ix1>() {
379 Ok(match self.aggregate_fn {
380 Aggregate::Sum => self.eval_1d::<SumFn, T>(&input),
381 Aggregate::Avg => self.eval_1d::<AvgFn, T>(&input),
382 Aggregate::Min => self.eval_1d::<MinFn, T>(&input),
383 Aggregate::Max => self.eval_1d::<MaxFn, T>(&input),
384 }?
385 .into_dyn())
386 } else if let Ok(input) = input.view().into_dimensionality::<Ix2>() {
387 Ok(match self.aggregate_fn {
388 Aggregate::Sum => self.eval_2d::<SumFn, T>(&input),
389 Aggregate::Avg => self.eval_2d::<AvgFn, T>(&input),
390 Aggregate::Min => self.eval_2d::<MinFn, T>(&input),
391 Aggregate::Max => self.eval_2d::<MaxFn, T>(&input),
392 }?
393 .into_dyn())
394 } else {
395 bail!("Invalid input dimensionality for tree ensemble: {:?}", input.shape());
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use tract_ndarray::prelude::*;

    /// Encodes a branch node. `left`/`right` are relative to `node_offset`, the id of the
    /// tree's first node.
    fn b(
        node_offset: usize,
        cmp: Cmp,
        feat: usize,
        v: f32,
        left: usize,
        right: usize,
        nan_is_true: bool,
    ) -> [u32; 5] {
        [
            feat as u32,
            (node_offset + left) as u32,
            (node_offset + right) as u32,
            v.to_bits(),
            cmp as u32 | if nan_is_true { 0x100 } else { 0 },
        ]
    }

    /// Encodes a leaf node whose votes are rows `leaf_offset + start_id .. leaf_offset + end_id`.
    fn l(leaf_offset: usize, start_id: usize, end_id: usize) -> [u32; 5] {
        [(leaf_offset + start_id) as u32, (leaf_offset + end_id) as u32, 0, 0, 0]
    }

    /// Encodes one vote row: `(class, weight bits)`.
    fn w(categ: usize, weight: f32) -> [u32; 2] {
        [categ as u32, weight.to_bits()]
    }

    fn generate_gbm_trees() -> TreeEnsembleData {
        let trees = rctensor1(&[0u32, 5u32, 14, 21, 30, 41]);
        let nodes = rctensor2(&[
            // tree 0: nodes 0..5
            b(0, Cmp::LessEqual, 2, 3.15, 1, 2, true),
            b(0, Cmp::LessEqual, 1, 3.35, 3, 4, true),
            l(0, 0, 1),
            l(0, 1, 2),
            l(0, 2, 3),
            // tree 1: nodes 5..14
            b(5, Cmp::LessEqual, 2, 1.8, 1, 2, true),
            l(3, 0, 1),
            b(5, Cmp::LessEqual, 3, 1.65, 3, 4, true),
            b(5, Cmp::LessEqual, 2, 4.45, 5, 6, true),
            b(5, Cmp::LessEqual, 2, 5.35, 7, 8, true),
            l(3, 1, 2),
            l(3, 2, 3),
            l(3, 3, 4),
            l(3, 4, 5),
            // tree 2: nodes 14..21
            b(14, Cmp::LessEqual, 3, 1.65, 1, 2, true),
            b(14, Cmp::LessEqual, 2, 4.45, 3, 4, true),
            b(14, Cmp::LessEqual, 2, 5.35, 5, 6, true),
            l(8, 0, 1),
            l(8, 1, 2),
            l(8, 2, 3),
            l(8, 3, 4),
            // tree 3: nodes 21..30
            b(21, Cmp::LessEqual, 2, 3.15, 1, 2, true),
            b(21, Cmp::LessEqual, 1, 3.35, 3, 4, true),
            b(21, Cmp::LessEqual, 2, 4.45, 5, 6, true),
            l(12, 0, 1),
            l(12, 1, 2),
            l(12, 2, 3),
            b(21, Cmp::LessEqual, 2, 5.35, 7, 8, true),
            l(12, 3, 4),
            l(12, 4, 5),
            // tree 4: nodes 30..41
            b(30, Cmp::LessEqual, 3, 0.45, 1, 2, true),
            b(30, Cmp::LessEqual, 2, 1.45, 3, 4, true),
            b(30, Cmp::LessEqual, 3, 1.65, 5, 6, true),
            l(17, 0, 1),
            l(17, 1, 2),
            b(30, Cmp::LessEqual, 2, 4.45, 7, 8, true),
            b(30, Cmp::LessEqual, 2, 5.35, 9, 10, true),
            l(17, 2, 3),
            l(17, 3, 4),
            l(17, 4, 5),
            l(17, 5, 6),
            // tree 5: nodes 41..50
            b(41, Cmp::LessEqual, 2, 4.75, 1, 2, true),
            b(41, Cmp::LessEqual, 1, 2.75, 3, 4, true),
            b(41, Cmp::LessEqual, 2, 5.15, 7, 8, true),
            l(23, 0, 1),
            b(41, Cmp::LessEqual, 2, 4.15, 5, 6, true),
            l(23, 1, 2),
            l(23, 2, 3),
            l(23, 3, 4),
            l(23, 4, 5),
        ]);
        assert_eq!(nodes.shape(), &[50, 5]);
        let leaves = rctensor2(&[
            w(0, -0.075),
            w(0, 0.13928571),
            w(0, 0.15),
            w(1, -0.075),
            w(1, 0.13548388),
            w(1, 0.110869564),
            w(1, -0.052500002),
            w(1, -0.075),
            w(2, -0.075),
            w(2, -0.035869565),
            w(2, 0.1275),
            w(2, 0.15),
            w(0, 0.12105576),
            w(0, 0.1304589),
            w(0, -0.07237862),
            w(0, -0.07226522),
            w(0, -0.07220469),
            w(1, -0.07226842),
            w(1, -0.07268012),
            w(1, 0.119391434),
            w(1, 0.097440675),
            w(1, -0.049815115),
            w(1, -0.07219931),
            w(2, -0.061642267),
            w(2, -0.0721846),
            w(2, -0.07319043),
            w(2, 0.076814815),
            w(2, 0.1315959),
        ]);
        assert_eq!(leaves.shape(), &[28, 2]);
        TreeEnsembleData { nodes, trees, leaves }
    }

    fn generate_gbm_ensemble() -> TreeEnsemble {
        let trees = generate_gbm_trees();
        // The trees read features 1, 2 and 3. `max_used_feature` is the highest feature
        // *index* used (inputs need more features than it), not the feature count. With the
        // previous value of 4, every 4-feature input was rejected.
        TreeEnsemble::build(trees, 3, 3, Aggregate::Sum).unwrap()
    }

    fn generate_gbm_input() -> Array2<f32> {
        arr2(&[
            [5.1, 3.5, 1.4, 0.2],
            [5.4, 3.7, 1.5, 0.2],
            [5.4, 3.4, 1.7, 0.2],
            [4.8, 3.1, 1.6, 0.2],
            [5.0, 3.5, 1.3, 0.3],
            [7.0, 3.2, 4.7, 1.4],
            [5.0, 2.0, 3.5, 1.0],
            [5.9, 3.2, 4.8, 1.8],
            [5.5, 2.4, 3.8, 1.1],
            [5.5, 2.6, 4.4, 1.2],
            [6.3, 3.3, 6.0, 2.5],
            [6.5, 3.2, 5.1, 2.0],
            [6.9, 3.2, 5.7, 2.3],
            [7.4, 2.8, 6.1, 1.9],
            [6.7, 3.1, 5.6, 2.4],
        ])
    }

    fn generate_gbm_raw_output() -> Array2<f32> {
        arr2(&[
            [0.28045893, -0.14726841, -0.14718461],
            [0.28045893, -0.14768013, -0.14718461],
            [0.28045893, -0.14768013, -0.14718461],
            [0.26034147, -0.14768013, -0.14718461],
            [0.28045893, -0.14726841, -0.14718461],
            [-0.14726523, 0.20831025, -0.10905999],
            [-0.14737862, 0.254_875_3, -0.13664228],
            [-0.14726523, -0.10231511, 0.20431481],
            [-0.14737862, 0.254_875_3, -0.13664228],
            [-0.14737862, 0.254_875_3, -0.13664228],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
            [-0.14726523, -0.10231511, 0.20431481],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
            [-0.147_204_7, -0.147_199_3, 0.281_595_9],
        ])
    }

    /// Compares two score arrays element-wise. The expected values are rounded decimal
    /// literals, so exact float equality would be too strict.
    fn assert_close(got: &ArrayD<f32>, expected: &ArrayD<f32>) {
        assert_eq!(got.shape(), expected.shape());
        for (g, e) in got.iter().zip(expected.iter()) {
            assert!((g - e).abs() <= 1e-6, "got {g}, expected {e}");
        }
    }

    #[test]
    fn test_tree_ensemble() {
        let ensemble = generate_gbm_ensemble();
        let input = generate_gbm_input();
        let output = ensemble.eval(input.view().into_dyn()).unwrap();
        assert_close(&output, &generate_gbm_raw_output().into_dyn());
    }

    #[test]
    fn test_tree_ensemble_single_sample() {
        let ensemble = generate_gbm_ensemble();
        let input = generate_gbm_input();
        let expected = generate_gbm_raw_output();
        for (row, expected_row) in input.outer_iter().zip(expected.outer_iter()) {
            let output = ensemble.eval(row.into_dyn()).unwrap();
            assert_close(&output, &expected_row.to_owned().into_dyn());
        }
    }

    #[test]
    fn test_tree_ensemble_rejects_missing_features() {
        let ensemble = generate_gbm_ensemble();
        let input = arr2(&[[5.1f32, 3.5, 1.4]]);
        assert!(ensemble.eval(input.view().into_dyn()).is_err());
    }
}
584}