tract_core/optim/
push_split_down.rs1use crate::internal::*;
2
3use crate::optim::OptimizerSession;
4use tract_itertools::Itertools;
5
/// Optimizer pass that merges sibling consumers of a shared outlet when they
/// compute the same thing (see its `TypedPass` implementation).
///
/// The pass carries no configuration or state.
#[derive(Debug, Clone)]
pub struct PushSplitDown;
8
9impl super::TypedPass for PushSplitDown {
10 fn reset(&mut self) -> TractResult<()> {
11 Ok(())
12 }
13 fn next(
14 &mut self,
15 _session: &mut OptimizerSession,
16 model: &TypedModel,
17 ) -> TractResult<Option<TypedModelPatch>> {
18 let mut patch = TypedModelPatch::default();
19 for node in model.eval_order()? {
20 for output in &model.node(node).outputs {
21 for (a, b) in output.successors.iter().tuple_combinations() {
22 if a.node == b.node {
23 continue;
25 }
26 if patch.obliterate.contains(&b.node) {
27 continue;
28 }
29 if model.outputs.contains(&a.node.into())
31 && model.outputs.contains(&b.node.into())
32 {
33 continue;
34 }
35 let a = model.node(a.node);
36 let b = model.node(b.node);
37 if a.same_as(b) {
38 for slot in 0..b.outputs.len() {
39 let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
40 patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
41 patch.obliterate(b.id)?;
42 }
43 }
44 }
45 }
46 }
47 Ok(Some(patch).filter(|p| !p.is_empty()))
48 }
49}