1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
use std::collections::{HashMap, HashSet};

use crate::model::{InletId, OutletId};
use crate::ops::prelude::*;
use crate::Model;
use itertools::Itertools;

/// Optimizer pass that merges duplicate consumers of the same outlet.
///
/// Whenever two distinct successors of one outlet are `same_as` each other,
/// the one appearing later in the successor list is "killed": every edge it
/// fed is rewired to the same output slot of the earlier, "kept" node, and the
/// killed node's inputs are cleared. Rounds repeat until no duplicates remain.
#[derive(Debug)]
pub struct PushSplitDown;

impl super::OptimizerPass for PushSplitDown {
    /// Runs merge rounds until a fixed point is reached.
    ///
    /// Returns `Ok(true)` if at least one node was collapsed. Errors from
    /// `eval_order`, `add_edge`, `clear_inputs` and (debug builds only)
    /// `check_edges` are propagated.
    fn pass(&self, model: &mut Model) -> TractResult<bool> {
        let mut done_something = false;
        loop {
            let remap = find_duplicates(model)?;
            if remap.is_empty() {
                break;
            }
            for (&killed, &kept) in remap.iter() {
                trace!("collapsing {} into {}", killed, kept);
                collapse(model, killed, kept)?;
                done_something = true;
            }
        }
        Ok(done_something)
    }
}

/// Scans every outlet, in evaluation order, for pairs of equivalent successor
/// nodes and returns a `killed -> kept` map for one merge round.
///
/// Each node plays at most one role per round (kept or killed), so no entry
/// targets a node that is itself being collapsed. Any merge skipped for that
/// reason is found again by the caller's next round.
fn find_duplicates(model: &Model) -> TractResult<HashMap<usize, usize>> {
    let mut remap = HashMap::<usize, usize>::new();
    let mut kept = HashSet::<usize>::new();
    for node in model.eval_order()? {
        for output in &model.node(node).outputs {
            for (a, b) in output.successors.iter().tuple_combinations() {
                // One node may consume the same outlet on several inputs
                // (e.g. `x * x`). It is trivially `same_as` itself, and
                // "collapsing" it into itself would just clear its inputs.
                if a.node == b.node {
                    continue;
                }
                // Never kill an already-killed node (or kill into one), and
                // never kill a node that is already the kept target of another.
                if remap.contains_key(&a.node)
                    || remap.contains_key(&b.node)
                    || kept.contains(&b.node)
                {
                    continue;
                }
                let a = model.node(a.node);
                let b = model.node(b.node);
                if a.same_as(b) {
                    remap.insert(b.id, a.id);
                    kept.insert(a.id);
                }
            }
        }
    }
    Ok(remap)
}

/// Rewires every edge leaving `killed` to the same output slot of `kept`, then
/// clears `killed`'s inputs so it no longer consumes anything.
///
/// NOTE(review): only successor edges are rewired. If `killed` can also be a
/// model output, that output reference is not updated here — confirm this
/// cannot happen or handle it.
fn collapse(model: &mut Model, killed: usize, kept: usize) -> TractResult<()> {
    let successors: Vec<InletId> = model
        .node(killed)
        .outputs
        .iter()
        .flat_map(|output| output.successors.iter())
        .cloned()
        .collect();
    for succ in successors {
        // Index loop because `add_edge` borrows `model` mutably while we are
        // walking the successor's input list.
        for input_ix in 0..model.node(succ.node).inputs.len() {
            let outlet = model.node(succ.node).inputs[input_ix];
            if outlet.node == killed {
                model.add_edge(
                    OutletId::new(kept, outlet.slot),
                    InletId::new(succ.node, input_ix),
                )?;
            }
        }
    }
    model.clear_inputs(killed)?;
    if cfg!(debug_assertions) {
        model.check_edges()?;
    }
    Ok(())
}