tract_core/model/
helpers.rs1use crate::ops::binary::{BinMiniOp, TypedBinOp};
2use crate::ops::konst::Const;
3use crate::prelude::*;
4use tract_data::internal::Approximation;
5
6pub trait TypedModelHelpers {
7 fn next_node(&self, node: &TypedNode) -> Option<&TypedNode>;
8 fn previous_node(&self, node: &TypedNode) -> Option<&TypedNode>;
9 fn previous_nodes(&self, node: &TypedNode) -> TVec<&TypedNode>;
10 fn collect_const_inputs<'a>(&'a self, node: &TypedNode) -> TVec<&'a Const>;
11 fn single_prev_node_as<O: TypedOp>(&self, node: &TypedNode) -> Option<(usize, &TypedNode)>;
12 fn matches_single_input_const(&self, node: &TypedNode, konst: f32) -> bool;
13 fn find_succ_bin_with_const<B: BinMiniOp>(
14 &self,
15 node: &TypedNode,
16 konst: f32,
17 ) -> Option<&TypedNode>;
18 fn find_succ_bin_with_outlet<B: BinMiniOp>(
19 &self,
20 node: &TypedNode,
21 outlet_id: &OutletId,
22 ) -> Option<&TypedNode>;
23}
24
25impl TypedModelHelpers for TypedModel {
26 fn next_node(&self, node: &TypedNode) -> Option<&TypedNode> {
27 if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
28 return None;
29 }
30 let succ = node.outputs[0].successors[0];
31 Some(&self.nodes()[succ.node])
32 }
33
34 fn previous_node(&self, node: &TypedNode) -> Option<&TypedNode> {
35 if node.inputs.len() != 1 {
36 return None;
37 }
38 Some(&self.nodes()[node.inputs[0].node])
39 }
40
41 fn previous_nodes(&self, node: &TypedNode) -> TVec<&TypedNode> {
42 node.inputs.iter().map(|n| &self.nodes()[n.node]).collect()
43 }
44
45 fn collect_const_inputs<'a>(&'a self, node: &TypedNode) -> TVec<&'a Const> {
46 node.inputs
47 .iter()
48 .filter_map(|i| {
49 let prec = &self.nodes()[i.node];
50 prec.op_as::<Const>()
51 })
52 .collect::<TVec<_>>()
53 }
54
55 fn single_prev_node_as<O: TypedOp>(&self, node: &TypedNode) -> Option<(usize, &TypedNode)> {
56 let prev_nodes = node
57 .inputs
58 .iter()
59 .enumerate()
60 .filter_map(|(in_idx, i)| {
61 let prec = &self.nodes()[i.node];
62 prec.op_is::<O>().then_some((in_idx, prec))
63 })
64 .collect::<TVec<_>>();
65
66 if prev_nodes.len() != 1 { None } else { Some(prev_nodes[0]) }
67 }
68
69 fn matches_single_input_const(&self, node: &TypedNode, konst: f32) -> bool {
70 let consts = self.collect_const_inputs(node);
71 if consts.len() != 1 {
72 return false;
73 }
74 let Ok(in_const) = consts[0].val().cast_to_dt(DatumType::F32) else {
75 return false;
76 };
77 let Ok(in_const) = in_const.to_scalar_tensor() else {
78 return false;
79 };
80 in_const
81 .close_enough(&tract_data::prelude::tensor0(konst), Approximation::Approximate)
82 .is_ok()
83 }
84
85 fn find_succ_bin_with_const<B: BinMiniOp>(
86 &self,
87 node: &TypedNode,
88 konst: f32,
89 ) -> Option<&TypedNode> {
90 let succ = self.single_succ(node.id).ok()??;
91 let succ_op = succ.op_as::<TypedBinOp>()?;
92 (succ_op.0.is::<B>() && self.matches_single_input_const(succ, konst)).then_some(succ)
93 }
94
95 fn find_succ_bin_with_outlet<B: BinMiniOp>(
96 &self,
97 node: &TypedNode,
98 outlet_id: &OutletId,
99 ) -> Option<&TypedNode> {
100 let succ = self.single_succ(node.id).ok()??;
101 let succ_op = succ.op_as::<TypedBinOp>()?;
102 (succ_op.0.is::<B>() && succ.inputs.contains(outlet_id)).then_some(succ)
103 }
104}