1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
use tract_data::UndeterminedSymbol;
use crate::internal::*;
use crate::ops::konst::Const;
use crate::optim::OptimizerSession;
/// Typed optimisation pass that replaces nodes whose inputs are all known constants
/// with the constant values those nodes evaluate to.
///
/// This is a unit struct: the pass carries no configuration and no state.
#[derive(Debug, Clone)]
pub struct PropConst;
impl super::TypedPass for PropConst {
    /// Nothing to reset: this pass keeps no state between invocations.
    fn reset(&mut self) -> TractResult<()> {
        Ok(())
    }

    /// Builds a single patch that folds every foldable node of `model` into constants.
    ///
    /// A node is folded when its op is stateless, it is not already a [`Const`], and every
    /// one of its input facts carries a constant value (`konst`). The op is then evaluated
    /// eagerly. Each resulting output `ix` becomes a new constant node in the patch. That
    /// node is named after the original node, with `.{ix}` appended for every output past
    /// the first. The original outlet `(n, ix)` is then shunted onto it.
    ///
    /// Returns `Ok(None)` when the patch ends up containing no nodes, i.e. nothing was folded.
    ///
    /// # Errors
    ///
    /// * Evaluation failures whose root cause is an [`UndeterminedSymbol`] are swallowed and
    ///   the node is left in place.
    /// * Any other evaluation failure is returned with context naming the node.
    /// * Errors from `eval_order`, `node_input_facts`, `add_const` and `shunt_outside` are
    ///   propagated unchanged.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for n in model.eval_order()? {
            let node = model.node(n);
            // Only stateless ops can be evaluated once, ahead of time. Const nodes are
            // already folded.
            if !node.op.is_stateless() || node.op_is::<Const>() {
                continue;
            }
            // Collecting into an `Option` yields `None` as soon as one input fact has no
            // constant value. In that case the node cannot be folded.
            let inputs = match model
                .node_input_facts(n)?
                .iter()
                .map(|f| f.konst.clone().map(|t| t.into_tvalue()))
                .collect()
            {
                Some(inputs) => inputs,
                None => continue,
            };
            match node.op.eval(inputs) {
                Ok(outputs) => {
                    for (ix, output) in outputs.into_iter().enumerate() {
                        // The first output keeps the node's name; later ones get a suffix so
                        // the new constant nodes are distinguishable.
                        let name = if ix == 0 {
                            node.name.clone()
                        } else {
                            format!("{}.{}", node.name, ix)
                        };
                        let wire = patch.add_const(name, output.into_arc_tensor())?;
                        patch.shunt_outside(model, (n, ix).into(), wire)?;
                    }
                }
                // Deliberately tolerated: presumably the op's result still depends on a
                // symbol with no known value yet, so the node is kept as-is.
                Err(e) if e.root_cause().is::<UndeterminedSymbol>() => {}
                Err(e) => {
                    return Err(e).with_context(|| {
                        format!("Eager eval {} during optimisation", node)
                    });
                }
            }
        }
        Ok(Some(patch).filter(|p| !p.nodes.is_empty()))
    }
}