1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
use crate::internal::*;
use tract_itertools::Itertools;
use crate::optim::OptimizerSession;
/// Typed optimizer pass that merges identical sibling consumers of an outlet.
///
/// This is a stateless unit struct; all the work happens in its
/// `TypedPass` implementation.
#[derive(Debug, Clone)]
pub struct PushSplitDown;
/// De-duplicates sibling nodes in a typed model.
///
/// For every outlet, each pair of distinct successor nodes is compared with
/// `Node::same_as`. When two siblings match, every output of the later one is
/// shunted to the corresponding output of the earlier one, and the later node
/// is obliterated. Consumers of the duplicate therefore read from the surviving
/// node instead.
impl super::TypedPass for PushSplitDown {
    /// The pass keeps no state between iterations, so there is nothing to reset.
    fn reset(&mut self) -> TractResult<()> {
        Ok(())
    }

    /// Builds one patch that merges every duplicated sibling found in `model`.
    ///
    /// Returns `Ok(None)` when nothing was merged. Errors from `eval_order`,
    /// `tap_model`, `shunt_outside` and `obliterate` are propagated.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for node in model.eval_order()? {
            for output in &model.node(node).outputs {
                for (a, b) in output.successors.iter().tuple_combinations() {
                    // A node reading this outlet on several inputs (e.g. x * x)
                    // appears more than once in `successors`. Never merge a node
                    // with itself.
                    if a.node == b.node {
                        continue;
                    }
                    // Skip pairs where either side was already merged away.
                    // Merging into an obliterated node would rewire consumers to
                    // a dead node and could leave both nodes obliterated.
                    if patch.obliterate.contains(&a.node) || patch.obliterate.contains(&b.node) {
                        continue;
                    }
                    let a = model.node(a.node);
                    let b = model.node(b.node);
                    // Do not merge nodes that feed a model output. Otherwise
                    // distinct model outputs (and their labels) would collapse
                    // onto a single outlet.
                    let b_is_model_output = (0..b.outputs.len())
                        .any(|slot| model.outputs.contains(&OutletId::new(b.id, slot)));
                    if b_is_model_output || !a.same_as(b) {
                        continue;
                    }
                    for slot in 0..b.outputs.len() {
                        let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
                        patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
                    }
                    // Obliterate once, after all of `b`'s outputs are shunted.
                    patch.obliterate(b.id)?;
                }
            }
        }
        Ok(Some(patch).filter(|p| !p.is_empty()))
    }
}