1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use crate::ops::prelude::*;
use crate::Model;

/// Optimizer pass that rewrites nodes through their ops' `reduce` hook.
///
/// The wrapped `ReductionPhase` is forwarded to every `reduce` call made by the
/// pass, so a single `Reduce` value applies the reductions of exactly one phase.
#[derive(Debug)]
pub struct Reduce(
    /// Phase handed to each op's `reduce` while this pass runs.
    pub ReductionPhase,
);

impl super::OptimizerPass for Reduce {
    /// Applies op-level reductions for phase `self.0` until the model reaches a fixed point.
    ///
    /// Each sweep visits nodes in the order returned by `model.eval_order()` and calls the
    /// op's `reduce` with the node's input and output facts. When an op returns a
    /// `ReducedOpRewire`, the node is rewritten:
    ///
    /// * one replacement op: it takes the node's place, and input `k` of the new op is the
    ///   node's original input `rewired[k]`;
    /// * several ops: the last one takes the node's place, and the others are inserted ahead
    ///   of it as a chain (named `<node name>-<k>`, counting down to `-1`). The first chain
    ///   node receives the rewired inputs, and the last chain node feeds input 0 of the
    ///   original node.
    ///
    /// Sweeps repeat until one performs no rewrite, so the pass only terminates once every
    /// op stops offering a reduction. The eval order is computed once per sweep, so nodes
    /// inserted during a sweep are not visited during that same sweep.
    ///
    /// # Errors
    ///
    /// Propagates failures from computing the eval order, looking up facts, the op's
    /// `reduce` (annotated with the phase and node), and graph edits. A reduction that
    /// returns an empty op list is reported as an error instead of panicking.
    ///
    /// Returns `Ok(true)` if at least one node was rewritten.
    fn pass(&self, model: &mut Model) -> TractResult<bool> {
        let mut done_something = false;
        loop {
            let mut done_something_this_time = false;
            for id in model.eval_order()? {
                let reduced = {
                    let node = &model.nodes()[id];
                    debug!(
                        "Consider {:?} {} #{} ({})",
                        self,
                        node.name,
                        node.id,
                        node.op().name()
                    );
                    let input_facts: TVec<&TensorFact> = node
                        .inputs
                        .iter()
                        .map(|o| model.fact(*o))
                        .inspect(|fact| trace!("   Input {:?}", fact))
                        .collect::<TractResult<_>>()?;
                    let output_facts: TVec<&TensorFact> =
                        node.outputs.iter().map(|o| &o.fact).collect();
                    node.op
                        .reduce(input_facts, output_facts, self.0)
                        .map_err(|e| format!("Reduce {:?} node {:?}, {:?}", self.0, node, e))?
                };
                if let Some(red) = reduced {
                    debug!("  Unarize to {:?}", red);
                    use crate::model::dsl::ModelDsl;
                    use crate::model::{InletId, OutletId};

                    let crate::ops::ReducedOpRewire { mut ops, rewired } = red;
                    // Input `k` of the (first) replacement op is the original input `rewired[k]`.
                    let inputs: Vec<OutletId> = rewired
                        .into_iter()
                        .map(|ix| model.node(id).inputs[ix])
                        .collect();
                    // The last op always takes over node `id`, so the node's outputs (and
                    // whatever consumes them) keep their existing wiring.
                    let last_op = ops.pop().ok_or_else(|| {
                        format!("Reduce {:?} node #{} returned an empty op list", self.0, id)
                    })?;
                    model.node_mut(id).op = last_op;
                    if ops.is_empty() {
                        // Single replacement op: rewire the node's own inputs directly.
                        model.clear_inputs(id)?;
                        for (ix, i) in inputs.iter().enumerate() {
                            model.add_edge(*i, InletId::new(id, ix))?;
                        }
                    } else {
                        // Several ops: insert all but the last as a chain in front of `id`.
                        // Names count down, so the chain node feeding `id` is "<name>-1".
                        let name = format!("{}-{}", model.node(id).name, ops.len());
                        let mut created_node_id = model.add_node(name, ops.remove(0))?;
                        for (ix, i) in inputs.iter().enumerate() {
                            model.add_edge(*i, InletId::new(created_node_id, ix))?;
                        }
                        let remaining = ops.len();
                        for (k, op) in ops.into_iter().enumerate() {
                            let name = format!("{}-{}", model.node(id).name, remaining - k);
                            // NOTE(review): relies on `chain` linking the previously added
                            // node to the new one, as the original loop did — confirm in ModelDsl.
                            created_node_id = model.chain(name, op)?;
                        }
                        model.clear_inputs(id)?;
                        model.add_edge(OutletId::new(created_node_id, 0), InletId::new(id, 0))?;
                    }
                    if cfg!(debug_assertions) {
                        model.check_edges()?;
                    }
                    done_something_this_time = true;
                }
            }
            if !done_something_this_time {
                break;
            }
            done_something = true;
        }
        Ok(done_something)
    }
}