1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
use crate::ops::prelude::*; use crate::Model; #[derive(Debug)] pub struct Reduce(pub ReductionPhase); impl super::OptimizerPass for Reduce { fn pass(&self, model: &mut Model) -> TractResult<bool> { let mut done_something = false; loop { let mut done_something_this_time = false; for id in model.eval_order()? { let reduced = { let node = &model.nodes()[id]; debug!( "Consider {:?} {} #{} ({})", self, node.name, node.id, node.op().name() ); let input_facts: TVec<&TensorFact> = node .inputs .iter() .map(|o| model.fact(*o)) .inspect(|fact| trace!(" Input {:?}", fact)) .collect::<TractResult<_>>()?; let output_facts: TVec<&TensorFact> = node.outputs.iter().map(|o| &o.fact).collect(); node.op .reduce(input_facts, output_facts, self.0) .map_err(|e| format!("Reduce {:?} node {:?}, {:?}", self.0, node, e))? }; if let Some(red) = reduced { debug!(" Unarize to {:?}", red); use crate::model::dsl::ModelDsl; use crate::model::{InletId, OutletId}; let crate::ops::ReducedOpRewire { mut ops, rewired } = red; let inputs: Vec<OutletId> = rewired .into_iter() .map(|ix| model.node(id).inputs[ix]) .collect(); if ops.len() == 1 { model.node_mut(id).op = ops.remove(0); model.clear_inputs(id)?; for (ix, i) in inputs.iter().enumerate() { model.add_edge(*i, InletId::new(id, ix))?; } } else { model.mut_nodes()[id].op = ops.pop().unwrap(); let name = format!("{}-{}", model.node(id).name, ops.len()); let mut created_node_id = model.add_node(name, ops.remove(0))?; for (ix, i) in inputs.iter().enumerate() { model.add_edge(*i, InletId::new(created_node_id, ix))?; } while ops.len() > 0 { let name = format!("{}-{}", model.node(id).name, ops.len()); created_node_id = model.chain(name, ops.remove(0))?; } model.clear_inputs(id)?; model.add_edge(OutletId::new(created_node_id, 0), InletId::new(id, 0))?; } if cfg!(debug_assertions) { model.check_edges()?; } done_something_this_time = true } } done_something = done_something || done_something_this_time; if !done_something_this_time { break; } } Ok(done_something) } }