tract_core/optim/
prop_const.rs1use tract_data::TooEarly;
2
3use crate::internal::*;
4use crate::ops::array::Slice;
5use crate::ops::dummy::Dummy;
6use crate::ops::konst::Const;
7use crate::ops::source::TypedSource;
8use crate::optim::OptimizerSession;
9
/// Typed optimisation pass that folds nodes whose inputs are all known constants into
/// `Const` nodes.
///
/// The wrapped `usize` is a scan cursor: the id of the node from which the next search for a
/// foldable node starts. It is rewound to 0 by `reset`.
#[derive(Default, Debug, Clone)]
pub struct PropConst(usize);
12
13impl super::TypedPass for PropConst {
14 fn reset(&mut self) -> TractResult<()> {
15 self.0 = 0;
16 Ok(())
17 }
18 fn next(
19 &mut self,
20 _session: &mut OptimizerSession,
21 model: &TypedModel,
22 ) -> TractResult<Option<TypedModelPatch>> {
23 for node in &model.nodes[self.0..] {
24 if node.op_is::<Const>() && node.outputs[0].fact.konst.is_none() {
25 self.0 = node.id;
26 let mut patch = TypedModelPatch::default();
27 let wire =
28 patch.add_const(&node.name, node.op_as::<Const>().unwrap().val().clone())?;
29 patch.shunt_outside(model, node.id.into(), wire)?;
30 return Ok(Some(patch));
31 }
32 let inputs = model.node_input_facts(node.id)?;
33 if !node.op_is::<Const>()
34 && !node.op_is::<Dummy>()
35 && !node.op_is::<TypedSource>()
36 && node.op.is_stateless()
37 && inputs.iter().zip(&node.inputs).all(|(fact, outlet)| {
38 fact.konst.is_some()
39 && (model.node(outlet.node).outputs[outlet.slot].successors.len() == 1
40 || node.op_is::<Slice>()
41 || (fact.datum_type.is_number()
42 && fact.shape.volume().as_i64().is_some_and(|d| d < 1024)))
43 })
44 {
45 let inputs =
46 inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
47 match node.op.eval_with_session(node.id, &SessionState::default(), inputs) {
48 Ok(mut res) => {
49 self.0 = node.id;
50 let mut node = node;
51 loop {
52 let Some(succ) = model.single_succ(node.id)? else {
53 break;
54 };
55 if succ.inputs.len() > 1 || !succ.op.is_stateless() {
56 break;
57 }
58 let Ok(succ_res) =
59 succ.op.eval_with_session(node.id, &SessionState::default(), res.clone())
60 else {
61 break;
62 };
63 res = succ_res;
64 node = succ;
65 }
66 let mut patch = TypedModelPatch::default();
67 for (ix, output) in res.into_iter().enumerate() {
68 let opaque_fact =
69 model.outlet_fact(OutletId::new(node.id, ix))?.opaque_fact.clone();
70
71 let name = if ix > 0 {
72 format!("{}.{ix}", node.name)
73 } else {
74 node.name.clone()
75 };
76 let wire = patch.wire_node(
77 name,
78 Const::new_with_opt_opaque_fact(
79 output.into_arc_tensor(),
80 opaque_fact,
81 )?,
82 &[],
83 )?[0];
84 patch.shunt_outside(model, (node.id, ix).into(), wire)?;
85 }
86 self.0 = node.id;
87 return Ok(Some(patch));
88 }
89 Err(e) => {
90 if !e.root_cause().is::<TooEarly>() {
91 Err(e).with_context(|| {
92 format!("Eager eval {node} during optimisation")
93 })?;
94 }
95 }
96 }
97 }
98 }
99 Ok(None)
100 }
101}