tract_onnx_opl/ml/tree_ensemble_classifier.rs

1pub use super::tree::{Aggregate, Cmp, TreeEnsemble, TreeEnsembleData};
2use tract_nnef::internal::*;
3
4pub fn register(registry: &mut Registry) {
5    registry.register_primitive(
6        "tract_onnx_ml_tree_ensemble_classifier",
7        &parameters(),
8        &[("output", TypeName::Scalar.tensor())],
9        load,
10    );
11    registry.register_dumper(dump);
12}
13
14pub fn parse_aggregate(s: &str) -> TractResult<Aggregate> {
15    match s {
16        "SUM" => Ok(Aggregate::Sum),
17        "AVERAGE" => Ok(Aggregate::Avg),
18        "MAX" => Ok(Aggregate::Max),
19        "MIN" => Ok(Aggregate::Min),
20        _ => bail!("Invalid aggregate function: {}", s),
21    }
22}
23
24#[derive(Debug, Clone, Hash)]
25pub struct TreeEnsembleClassifier {
26    pub ensemble: TreeEnsemble,
27}
28
29impl Op for TreeEnsembleClassifier {
30    fn name(&self) -> StaticName {
31        "TreeEnsembleClassifier".into()
32    }
33
34    op_as_typed_op!();
35}
36
37impl EvalOp for TreeEnsembleClassifier {
38    fn is_stateless(&self) -> bool {
39        true
40    }
41
42    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
43        let input = args_1!(inputs);
44        let input = input.cast_to::<f32>()?;
45        let input = input.to_array_view::<f32>()?;
46        let scores = self.ensemble.eval(input)?;
47        Ok(tvec!(scores.into_tvalue()))
48    }
49}
50
51impl TypedOp for TreeEnsembleClassifier {
52    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
53        let n = &inputs[0].shape[0];
54        Ok(tvec!(f32::fact(&[n.clone(), self.ensemble.n_classes().into()])))
55    }
56
57    as_op!();
58}
59
60fn parameters() -> Vec<Parameter> {
61    vec![
62        TypeName::Scalar.tensor().named("input"),
63        TypeName::Scalar.tensor().named("trees"),
64        TypeName::Scalar.tensor().named("nodes"),
65        TypeName::Scalar.tensor().named("leaves"),
66        TypeName::Integer.named("max_used_feature"),
67        TypeName::Integer.named("n_classes"),
68        TypeName::String.named("aggregate_fn"),
69    ]
70}
71
72fn dump(
73    ast: &mut IntoAst,
74    node: &TypedNode,
75    op: &TreeEnsembleClassifier,
76) -> TractResult<Option<Arc<RValue>>> {
77    let input = ast.mapping[&node.inputs[0]].clone();
78    let trees = ast.konst_variable(format!("{}_trees", node.name), &op.ensemble.data.trees)?;
79    let nodes = ast.konst_variable(format!("{}_nodes", node.name), &op.ensemble.data.nodes)?;
80    let leaves = ast.konst_variable(format!("{}_leaves", node.name), &op.ensemble.data.leaves)?;
81    let agg = match op.ensemble.aggregate_fn {
82        Aggregate::Min => "MIN",
83        Aggregate::Max => "MAX",
84        Aggregate::Sum => "SUM",
85        Aggregate::Avg => "AVERAGE",
86    };
87    Ok(Some(invocation(
88        "tract_onnx_ml_tree_ensemble_classifier",
89        &[input, trees, nodes, leaves],
90        &[
91            ("max_used_feature", numeric(op.ensemble.max_used_feature)),
92            ("n_classes", numeric(op.ensemble.n_classes)),
93            ("aggregate_fn", string(agg)),
94        ],
95    )))
96}
97
98fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
99    let input = invocation.named_arg_as(builder, "input")?;
100    let trees = invocation.named_arg_as(builder, "trees")?;
101    let nodes = invocation.named_arg_as(builder, "nodes")?;
102    let leaves = invocation.named_arg_as(builder, "leaves")?;
103    let max_used_feature = invocation.named_arg_as(builder, "max_used_feature")?;
104    let n_classes = invocation.named_arg_as(builder, "n_classes")?;
105    let aggregate_fn: String = invocation.named_arg_as(builder, "aggregate_fn")?;
106    let aggregate_fn = parse_aggregate(&aggregate_fn)?;
107    let data = TreeEnsembleData { trees, nodes, leaves };
108    let ensemble = TreeEnsemble { data, n_classes, max_used_feature, aggregate_fn };
109    let op = TreeEnsembleClassifier { ensemble };
110    builder.wire(op, &[input])
111}