1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
use tract_data::UndeterminedSymbol;

use crate::internal::*;
use crate::ops::konst::Const;
use crate::optim::OptimizerSession;

/// Constant-propagation optimizer pass.
///
/// The pass carries no data of its own (unit struct). Its `TypedPass`
/// implementation eagerly evaluates stateless, non-`Const` nodes whose input
/// facts all carry a constant value, and records the results as constants in
/// a model patch.
#[derive(Clone, Debug)]
pub struct PropConst;

/// `TypedPass` implementation folding constant sub-expressions of a typed model.
impl super::TypedPass for PropConst {
    /// Nothing to reset: the pass keeps no state between invocations.
    fn reset(&mut self) -> TractResult<()> {
        Ok(())
    }

    /// Builds a patch replacing every eagerly-evaluable node by constants.
    ///
    /// Nodes are visited in the model evaluation order. A node is a candidate
    /// when its op is stateless, it is not already a `Const`, and every one of
    /// its input facts carries a constant value (`konst`). Candidates are
    /// evaluated on those constant inputs, and each resulting tensor is added
    /// to the patch as a constant that the corresponding node output is
    /// shunted to.
    ///
    /// Returns `Ok(Some(patch))` when the patch contains at least one node,
    /// `Ok(None)` when nothing could be folded.
    ///
    /// # Errors
    ///
    /// Propagates failures from computing the evaluation order, fetching input
    /// facts and building the patch. An evaluation failure is returned (with
    /// an "Eager eval ..." context) unless its root cause is an
    /// `UndeterminedSymbol`, in which case the node is simply left untouched.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for n in model.eval_order()? {
            let node = model.node(n);
            // Stateful ops cannot be evaluated out of context, and constants
            // are already in their final form.
            if !node.op.is_stateless() || node.op_is::<Const>() {
                continue;
            }
            // Collecting into an `Option` yields `None` as soon as one input
            // has no known constant value.
            let inputs = match model
                .node_input_facts(n)?
                .iter()
                .map(|f| f.konst.clone().map(|t| t.into_tvalue()))
                .collect()
            {
                Some(inputs) => inputs,
                None => continue,
            };
            let outputs = match node.op.eval(inputs) {
                Ok(outputs) => outputs,
                // Evaluation needs a symbol value that is not known yet:
                // not an error, the node just can not be folded for now.
                Err(e) if e.root_cause().is::<UndeterminedSymbol>() => continue,
                Err(e) => {
                    return Err(e)
                        .with_context(|| format!("Eager eval {} during optimisation", node))
                }
            };
            for (ix, output) in outputs.into_iter().enumerate() {
                // The first output reuses the node name, the following ones
                // get the output index as a suffix.
                let name =
                    if ix > 0 { format!("{}.{}", node.name, ix) } else { node.name.clone() };
                let wire = patch.add_const(name, output.into_arc_tensor())?;
                patch.shunt_outside(model, (n, ix).into(), wire)?;
            }
        }
        Ok(Some(patch).filter(|p| !p.nodes.is_empty()))
    }
}