Skip to main content

tract_core/model/
helpers.rs

1use crate::ops::binary::{BinMiniOp, TypedBinOp};
2use crate::ops::konst::Const;
3use crate::prelude::*;
4use tract_data::internal::Approximation;
5
6pub trait TypedModelHelpers {
7    fn next_node(&self, node: &TypedNode) -> Option<&TypedNode>;
8    fn previous_node(&self, node: &TypedNode) -> Option<&TypedNode>;
9    fn previous_nodes(&self, node: &TypedNode) -> TVec<&TypedNode>;
10    fn collect_const_inputs<'a>(&'a self, node: &TypedNode) -> TVec<&'a Const>;
11    fn single_prev_node_as<O: TypedOp>(&self, node: &TypedNode) -> Option<(usize, &TypedNode)>;
12    fn matches_single_input_const(&self, node: &TypedNode, konst: f32) -> bool;
13    fn find_succ_bin_with_const<B: BinMiniOp>(
14        &self,
15        node: &TypedNode,
16        konst: f32,
17    ) -> Option<&TypedNode>;
18    fn find_succ_bin_with_outlet<B: BinMiniOp>(
19        &self,
20        node: &TypedNode,
21        outlet_id: &OutletId,
22    ) -> Option<&TypedNode>;
23}
24
25impl TypedModelHelpers for TypedModel {
26    fn next_node(&self, node: &TypedNode) -> Option<&TypedNode> {
27        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
28            return None;
29        }
30        let succ = node.outputs[0].successors[0];
31        Some(&self.nodes()[succ.node])
32    }
33
34    fn previous_node(&self, node: &TypedNode) -> Option<&TypedNode> {
35        if node.inputs.len() != 1 {
36            return None;
37        }
38        Some(&self.nodes()[node.inputs[0].node])
39    }
40
41    fn previous_nodes(&self, node: &TypedNode) -> TVec<&TypedNode> {
42        node.inputs.iter().map(|n| &self.nodes()[n.node]).collect()
43    }
44
45    fn collect_const_inputs<'a>(&'a self, node: &TypedNode) -> TVec<&'a Const> {
46        node.inputs
47            .iter()
48            .filter_map(|i| {
49                let prec = &self.nodes()[i.node];
50                prec.op_as::<Const>()
51            })
52            .collect::<TVec<_>>()
53    }
54
55    fn single_prev_node_as<O: TypedOp>(&self, node: &TypedNode) -> Option<(usize, &TypedNode)> {
56        let prev_nodes = node
57            .inputs
58            .iter()
59            .enumerate()
60            .filter_map(|(in_idx, i)| {
61                let prec = &self.nodes()[i.node];
62                prec.op_is::<O>().then_some((in_idx, prec))
63            })
64            .collect::<TVec<_>>();
65
66        if prev_nodes.len() != 1 { None } else { Some(prev_nodes[0]) }
67    }
68
69    fn matches_single_input_const(&self, node: &TypedNode, konst: f32) -> bool {
70        let consts = self.collect_const_inputs(node);
71        if consts.len() != 1 {
72            return false;
73        }
74        let Ok(in_const) = consts[0].val().cast_to_dt(DatumType::F32) else {
75            return false;
76        };
77        let Ok(in_const) = in_const.to_scalar_tensor() else {
78            return false;
79        };
80        in_const
81            .close_enough(&tract_data::prelude::tensor0(konst), Approximation::Approximate)
82            .is_ok()
83    }
84
85    fn find_succ_bin_with_const<B: BinMiniOp>(
86        &self,
87        node: &TypedNode,
88        konst: f32,
89    ) -> Option<&TypedNode> {
90        let succ = self.single_succ(node.id).ok()??;
91        let succ_op = succ.op_as::<TypedBinOp>()?;
92        (succ_op.0.is::<B>() && self.matches_single_input_const(succ, konst)).then_some(succ)
93    }
94
95    fn find_succ_bin_with_outlet<B: BinMiniOp>(
96        &self,
97        node: &TypedNode,
98        outlet_id: &OutletId,
99    ) -> Option<&TypedNode> {
100        let succ = self.single_succ(node.id).ok()??;
101        let succ_op = succ.op_as::<TypedBinOp>()?;
102        (succ_op.0.is::<B>() && succ.inputs.contains(outlet_id)).then_some(succ)
103    }
104}