1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
use super::Downsample;
use crate::internal::*;
use crate::ops;

// trivial cases (sampling on N, mat-mul-as-conv) is handled by invariants
pub fn fuse_downsample_into_conv(
    model: &TypedModel,
    conv_node: &TypedNode,
    conv_op: &ops::cnn::conv::ConvUnary,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    if down_op.stride < 0 {
        return Ok(None);
    }
    let input_fact = model.outlet_fact(conv_node.inputs[0])?;
    let input_shape =
        conv_op.pool_spec.data_format.shape(input_fact.shape.iter().collect::<TVec<_>>())?;
    if down_op.axis < input_shape.h_axis() {
        return Ok(None);
    }
    let geo_axis = down_op.axis - input_shape.h_axis();
    if geo_axis >= input_shape.rank() {
        return Ok(None);
    }
    let mut new_conv = conv_op.clone();
    if new_conv.pool_spec.strides.is_none() {
        new_conv.pool_spec.strides = Some(tvec!(1; input_shape.hw_rank()));
    }
    new_conv.pool_spec.strides.as_mut().unwrap()[geo_axis] *= down_op.stride as usize;

    let mut patch = TypedModelPatch::default();
    let tap = patch.tap_model(model, conv_node.inputs[0])?;
    let new_output = patch.wire_node(&*conv_node.name, new_conv, [tap].as_ref())?[0];
    patch.shunt_outside(model, OutletId::new(down_node.id, 0), new_output)?;
    Ok(Some(patch))
}