tract_onnx_opl/ml/
tree_ensemble_classifier.rs1pub use super::tree::{Aggregate, Cmp, TreeEnsemble, TreeEnsembleData};
2use tract_nnef::internal::*;
3
4pub fn register(registry: &mut Registry) {
5 registry.register_primitive(
6 "tract_onnx_ml_tree_ensemble_classifier",
7 ¶meters(),
8 &[("output", TypeName::Scalar.tensor())],
9 load,
10 );
11 registry.register_dumper(dump);
12}
13
14pub fn parse_aggregate(s: &str) -> TractResult<Aggregate> {
15 match s {
16 "SUM" => Ok(Aggregate::Sum),
17 "AVERAGE" => Ok(Aggregate::Avg),
18 "MAX" => Ok(Aggregate::Max),
19 "MIN" => Ok(Aggregate::Min),
20 _ => bail!("Invalid aggregate function: {}", s),
21 }
22}
23
24#[derive(Debug, Clone, Hash)]
25pub struct TreeEnsembleClassifier {
26 pub ensemble: TreeEnsemble,
27}
28
29impl Op for TreeEnsembleClassifier {
30 fn name(&self) -> StaticName {
31 "TreeEnsembleClassifier".into()
32 }
33
34 op_as_typed_op!();
35}
36
37impl EvalOp for TreeEnsembleClassifier {
38 fn is_stateless(&self) -> bool {
39 true
40 }
41
42 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
43 let input = args_1!(inputs);
44 let input = input.cast_to::<f32>()?;
45 let input = input.to_array_view::<f32>()?;
46 let scores = self.ensemble.eval(input)?;
47 Ok(tvec!(scores.into_tvalue()))
48 }
49}
50
51impl TypedOp for TreeEnsembleClassifier {
52 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
53 let n = &inputs[0].shape[0];
54 Ok(tvec!(f32::fact(&[n.clone(), self.ensemble.n_classes().into()])))
55 }
56
57 as_op!();
58}
59
60fn parameters() -> Vec<Parameter> {
61 vec![
62 TypeName::Scalar.tensor().named("input"),
63 TypeName::Scalar.tensor().named("trees"),
64 TypeName::Scalar.tensor().named("nodes"),
65 TypeName::Scalar.tensor().named("leaves"),
66 TypeName::Integer.named("max_used_feature"),
67 TypeName::Integer.named("n_classes"),
68 TypeName::String.named("aggregate_fn"),
69 ]
70}
71
72fn dump(
73 ast: &mut IntoAst,
74 node: &TypedNode,
75 op: &TreeEnsembleClassifier,
76) -> TractResult<Option<Arc<RValue>>> {
77 let input = ast.mapping[&node.inputs[0]].clone();
78 let trees = ast.konst_variable(format!("{}_trees", node.name), &op.ensemble.data.trees)?;
79 let nodes = ast.konst_variable(format!("{}_nodes", node.name), &op.ensemble.data.nodes)?;
80 let leaves = ast.konst_variable(format!("{}_leaves", node.name), &op.ensemble.data.leaves)?;
81 let agg = match op.ensemble.aggregate_fn {
82 Aggregate::Min => "MIN",
83 Aggregate::Max => "MAX",
84 Aggregate::Sum => "SUM",
85 Aggregate::Avg => "AVERAGE",
86 };
87 Ok(Some(invocation(
88 "tract_onnx_ml_tree_ensemble_classifier",
89 &[input, trees, nodes, leaves],
90 &[
91 ("max_used_feature", numeric(op.ensemble.max_used_feature)),
92 ("n_classes", numeric(op.ensemble.n_classes)),
93 ("aggregate_fn", string(agg)),
94 ],
95 )))
96}
97
98fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
99 let input = invocation.named_arg_as(builder, "input")?;
100 let trees = invocation.named_arg_as(builder, "trees")?;
101 let nodes = invocation.named_arg_as(builder, "nodes")?;
102 let leaves = invocation.named_arg_as(builder, "leaves")?;
103 let max_used_feature = invocation.named_arg_as(builder, "max_used_feature")?;
104 let n_classes = invocation.named_arg_as(builder, "n_classes")?;
105 let aggregate_fn: String = invocation.named_arg_as(builder, "aggregate_fn")?;
106 let aggregate_fn = parse_aggregate(&aggregate_fn)?;
107 let data = TreeEnsembleData { trees, nodes, leaves };
108 let ensemble = TreeEnsemble { data, n_classes, max_used_feature, aggregate_fn };
109 let op = TreeEnsembleClassifier { ensemble };
110 builder.wire(op, &[input])
111}