// tract_core/optim/prop_const.rs

1use tract_data::TooEarly;
2
3use crate::internal::*;
4use crate::ops::array::Slice;
5use crate::ops::dummy::Dummy;
6use crate::ops::konst::Const;
7use crate::ops::source::TypedSource;
8use crate::optim::OptimizerSession;
9
/// Constant-propagation optimizer pass.
///
/// The wrapped `usize` is a scan cursor: the index into the model's node list from which
/// the next search for a foldable node starts. `reset` puts it back to 0, and each emitted
/// patch moves it to the id of the node that patch rewrote.
#[derive(Clone, Debug, Default)]
pub struct PropConst(usize);
12
13impl super::TypedPass for PropConst {
14    fn reset(&mut self) -> TractResult<()> {
15        self.0 = 0;
16        Ok(())
17    }
18    fn next(
19        &mut self,
20        _session: &mut OptimizerSession,
21        model: &TypedModel,
22    ) -> TractResult<Option<TypedModelPatch>> {
23        for node in &model.nodes[self.0..] {
24            if node.op_is::<Const>() && node.outputs[0].fact.konst.is_none() {
25                self.0 = node.id;
26                let mut patch = TypedModelPatch::default();
27                let wire =
28                    patch.add_const(&node.name, node.op_as::<Const>().unwrap().val().clone())?;
29                patch.shunt_outside(model, node.id.into(), wire)?;
30                return Ok(Some(patch));
31            }
32            let inputs = model.node_input_facts(node.id)?;
33            if !node.op_is::<Const>()
34                && !node.op_is::<Dummy>()
35                && !node.op_is::<TypedSource>()
36                && node.op.is_stateless()
37                && inputs.iter().zip(&node.inputs).all(|(fact, outlet)| {
38                    fact.konst.is_some()
39                        && (model.node(outlet.node).outputs[outlet.slot].successors.len() == 1
40                            || node.op_is::<Slice>()
41                            || (fact.datum_type.is_number()
42                                && fact.shape.volume().as_i64().is_some_and(|d| d < 1024)))
43                })
44            {
45                let inputs =
46                    inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
47                match node.op.eval_with_session(node.id, &SessionState::default(), inputs) {
48                    Ok(mut res) => {
49                        self.0 = node.id;
50                        let mut node = node;
51                        loop {
52                            let Some(succ) = model.single_succ(node.id)? else {
53                                break;
54                            };
55                            if succ.inputs.len() > 1 || !succ.op.is_stateless() {
56                                break;
57                            }
58                            let Ok(succ_res) =
59                                succ.op.eval_with_session(node.id, &SessionState::default(), res.clone())
60                            else {
61                                break;
62                            };
63                            res = succ_res;
64                            node = succ;
65                        }
66                        let mut patch = TypedModelPatch::default();
67                        for (ix, output) in res.into_iter().enumerate() {
68                            let opaque_fact =
69                                model.outlet_fact(OutletId::new(node.id, ix))?.opaque_fact.clone();
70
71                            let name = if ix > 0 {
72                                format!("{}.{ix}", node.name)
73                            } else {
74                                node.name.clone()
75                            };
76                            let wire = patch.wire_node(
77                                name,
78                                Const::new_with_opt_opaque_fact(
79                                    output.into_arc_tensor(),
80                                    opaque_fact,
81                                )?,
82                                &[],
83                            )?[0];
84                            patch.shunt_outside(model, (node.id, ix).into(), wire)?;
85                        }
86                        self.0 = node.id;
87                        return Ok(Some(patch));
88                    }
89                    Err(e) => {
90                        if !e.root_cause().is::<TooEarly>() {
91                            Err(e).with_context(|| {
92                                format!("Eager eval {node} during optimisation")
93                            })?;
94                        }
95                    }
96                }
97            }
98        }
99        Ok(None)
100    }
101}