1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use super::Downsample;
use crate::internal::*;
use crate::ops;
use crate::ops::scan::*;
/// Attempts to move a `Downsample` that consumes a `Scan` output so that it is applied to the
/// scan inputs instead, with the scan body rewritten to work on downsampled data.
///
/// The rewrite is tried on a clone of the scan body: `down_op` is wired onto every body
/// output, the body is decluttered, and the attempt is kept only if afterwards every consumer
/// of every body input is an op that is `same_as` `down_op`. Those ops are then shunted out
/// of the body, the body source facts are replaced by their downsampled versions, and the
/// scan input/output mappings are adjusted for the stride.
///
/// # Arguments
/// * `model` - model containing `scan_node`; the returned patch taps and shunts against it.
/// * `scan_node` - the scan node whose inputs are tapped and whose output consumers are
///   replaced.
/// * `scan_op` - the scan operator of `scan_node`; its body and mappings are cloned, never
///   mutated.
/// * `down_node` - the downsample node; only its name is read, to name the nodes created here.
/// * `down_op` - the downsample operator to pull over the scan.
///
/// # Returns
/// `Ok(Some(patch))` when the rewrite applies, `Ok(None)` when it does not:
/// * `down_op.stride` is negative;
/// * a scan output has no consumer, or its first consumer is not `same_as` `down_op`;
/// * after decluttering the body, a body input has a consumer that is not `same_as`
///   `down_op`;
/// * a scanned input has a positive chunk that is not a multiple of the stride;
/// * a scanned output has a chunk whose magnitude is not a multiple of the stride.
///
/// # Errors
/// Propagates any error from the model consistency checks, wiring, decluttering, fact
/// transformation, operator evaluation or patch construction.
///
/// # Panics
/// Panics if a body input node does not hold a `TypedSource` operator.
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    if down_op.stride < 0 {
        return Ok(None);
    }
    // Non-negative from here on, so the cast cannot wrap.
    let stride = down_op.stride as usize;

    // The patch built at the end substitutes the first consumer of every scan output with the
    // matching output of the rewritten scan. That is only meaningful if such a consumer exists
    // and is an equivalent downsample, so give up before doing any work otherwise (indexing
    // the first successor blindly would panic on an output that nobody consumes).
    let outputs_feed_downsamples = scan_node.outputs.iter().all(|output| {
        matches!(
            output.successors.first(),
            Some(succ) if model.node(succ.node).op().same_as(down_op)
        )
    });
    if !outputs_feed_downsamples {
        return Ok(None);
    }

    // Work on a copy of the body and append a downsample to each of its outputs.
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;

    // After decluttering, every consumer of every body input must be an equivalent downsample
    // (presumably because declutter moved the appended ops all the way up); otherwise the
    // downsampling cannot be taken out of the body.
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Make each body source produce the downsampled fact directly, then drop the downsample
    // ops that were consuming it.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>()
            .expect("scan body input nodes must be TypedSource")
            .fact = fact.clone();
        // Snapshot the consumers: applying a patch mutates the successor list.
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }

    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;

    // Build the replacement scan around the rewritten body and adapt its mappings.
    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                // A constant initial state has to be downsampled like the other inputs.
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // NOTE(review): negative chunks skip the divisibility test here, unlike on
                // the output side — confirm this asymmetry is intended.
                if info.chunk > 0 && info.chunk as usize % stride != 0 {
                    return Ok(None);
                }
                // Scale the magnitude, keep the sign.
                info.chunk = info.chunk.unsigned_abs().divceil(stride) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // Compare magnitudes: casting a negative chunk with `as usize` wraps around, which
            // makes the remainder unrelated to the actual chunk size.
            if info.chunk.unsigned_abs() % stride != 0 {
                return Ok(None);
            }
            info.chunk =
                info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
        }
    }

    // Assemble the patch: downsample every scan input, run the new scan on the results, and
    // substitute the first consumer of each original scan output with the new scan output.
    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for (ix, output) in scan_node.outputs.iter().enumerate() {
        // Checked at the top of the function: a first successor exists and is a downsample.
        let succ = output.successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}