1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use crate::internal::*;

pub mod conv;
pub mod deconv;
mod maxpool;
mod padding;
mod patch_axis;
mod patches;
pub mod pools;
mod sumpool;

pub use self::conv::{Conv, KernelFormat};
pub use self::deconv::Deconv;
pub use self::maxpool::MaxPool;
pub use self::padding::PaddingSpec;
pub use self::patch_axis::PatchAxis;
pub use self::patches::{Patch, PatchSpec};
pub use self::pools::PoolSpec;
pub use self::sumpool::SumPool;

use super::array::MultiBroadcastTo;

/// Normalizes a bias wire towards a rank-1 tensor of length `output_channels`.
///
/// * An input of non-zero rank holding exactly one element is first reshaped to
///   a scalar (node `<name>.bias.make_scalar`).
/// * A scalar, whether it was one originally or after the step above, is then
///   broadcast to shape `[output_channels]` (node `<name>.bias.broadcast`).
/// * Any other input is returned unchanged. Its length is not checked against
///   `output_channels` here.
///
/// Returns the wire carrying the resulting bias.
pub fn wire_reshape_bias_as_vector(
    model: &mut TypedModel,
    name: impl AsRef<str>,
    outlet: OutletId,
    output_channels: usize,
) -> TractResult<TVec<OutletId>> {
    let name = name.as_ref();
    let fact = model.outlet_fact(outlet)?.clone();
    let holds_single_value = fact.rank() > 0 && fact.shape.volume().is_one();
    let wire = if holds_single_value {
        // Collapse every unit axis so the broadcast below starts from a scalar.
        model.wire_node(
            format!("{name}.bias.make_scalar"),
            AxisOp::Reshape(0, fact.shape.to_tvec(), tvec![]),
            &[outlet],
        )?
    } else {
        tvec!(outlet)
    };
    let is_scalar = model.outlet_fact(wire[0])?.rank() == 0;
    if !is_scalar {
        return Ok(wire);
    }
    model.wire_node(
        format!("{name}.bias.broadcast"),
        MultiBroadcastTo { shape: tvec!(output_channels).into() },
        &wire,
    )
}

/// Shapes a bias wire so it can be combined by a broadcasting binary op with a
/// tensor of rank `rank` whose channel axis is `c_axis`.
///
/// The bias is first normalized through [`wire_reshape_bias_as_vector`]. Unless
/// it already has the target shape, it is then reshaped (node `<name>.bias`).
/// In the target shape every axis has length 1, except `c_axis`, which has
/// length `output_channels`.
///
/// # Panics
/// Panics if `c_axis >= rank` (out-of-bounds index into the target shape).
pub fn wire_reshape_bias_for_bin(
    model: &mut TypedModel,
    name: impl AsRef<str>,
    outlet: OutletId,
    rank: usize,
    c_axis: usize,
    output_channels: usize,
) -> TractResult<TVec<OutletId>> {
    let name = name.as_ref();
    let vector = wire_reshape_bias_as_vector(model, name, outlet, output_channels)?;
    let current_shape = model.outlet_fact(vector[0])?.shape.to_tvec();
    let mut target_shape = tvec![1.to_dim(); rank];
    target_shape[c_axis] = output_channels.to_dim();
    if target_shape == current_shape {
        // Already laid out for broadcasting: no reshape node needed.
        return Ok(vector);
    }
    model.wire_node(
        format!("{name}.bias"),
        AxisOp::Reshape(0, current_shape, target_shape),
        &vector,
    )
}

/// Rewrite rule for a `Conv` whose data format lacks a batch (N) axis.
///
/// The convolution is replaced by an equivalent one using `data_format.with_n()`.
/// The new convolution is wrapped between an `AxisOp::Add(0)` on its first
/// (data) input (node `<name>.add_n`) and an `AxisOp::Rm(0)` on its output
/// (node `<name>.rm_n`). All other inputs are passed through unchanged.
///
/// Returns `Ok(None)` when the data format already has an N axis.
pub fn rewrite_conv_with_n_axis(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    conv: &Conv,
) -> TractResult<Option<TypedModelPatch>> {
    if conv.pool_spec.data_format.has_n() {
        return Ok(None);
    }
    let mut batched_conv = conv.clone();
    batched_conv.pool_spec.data_format = conv.pool_spec.data_format.with_n();
    let mut patch = TypedModelPatch::default();
    let mut inputs = patch.taps(model, &node.inputs)?;
    let batched_data = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &inputs[..1])?;
    inputs[0] = batched_data[0];
    let batched_output = patch.wire_node(name, batched_conv, &inputs)?;
    let output = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &batched_output)?;
    patch.shunt_outside(model, node.id.into(), output[0])?;
    Ok(Some(patch))
}

/// Rewrite rule for a `Deconv` whose data format lacks a batch (N) axis.
///
/// The deconvolution is replaced by an equivalent one using
/// `data_format.with_n()`. The new deconvolution is wrapped between an
/// `AxisOp::Add(0)` on its first (data) input (node `<name>.add_n`) and an
/// `AxisOp::Rm(0)` on its output (node `<name>.rm_n`). All other inputs are
/// passed through unchanged.
///
/// Returns `Ok(None)` when the data format already has an N axis.
pub fn rewrite_deconv_with_n_axis(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    deconv: &Deconv,
) -> TractResult<Option<TypedModelPatch>> {
    if deconv.pool_spec.data_format.has_n() {
        return Ok(None);
    }
    let mut batched_deconv = deconv.clone();
    batched_deconv.pool_spec.data_format = deconv.pool_spec.data_format.with_n();
    let mut patch = TypedModelPatch::default();
    let mut inputs = patch.taps(model, &node.inputs)?;
    let batched_data = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &inputs[..1])?;
    inputs[0] = batched_data[0];
    let batched_output = patch.wire_node(name, batched_deconv, &inputs)?;
    let output = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &batched_output)?;
    patch.shunt_outside(model, node.id.into(), output[0])?;
    Ok(Some(patch))
}