tract_core/ops/cnn/
mod.rs

1use crate::internal::*;
2
3pub mod conv;
4pub mod deconv;
5mod maxpool;
6mod padding;
7mod patch_axis;
8mod patches;
9pub mod pools;
10mod sumpool;
11
12pub use self::conv::{Conv, KernelFormat};
13pub use self::deconv::Deconv;
14pub use self::maxpool::MaxPool;
15pub use self::padding::PaddingSpec;
16pub use self::patch_axis::PatchAxis;
17pub use self::patches::{Patch, PatchSpec};
18pub use self::pools::PoolSpec;
19pub use self::sumpool::SumPool;
20
21use super::array::MultiBroadcastTo;
22
23pub fn wire_reshape_bias_as_vector(
24    model: &mut TypedModel,
25    name: impl AsRef<str>,
26    outlet: OutletId,
27    output_channels: usize,
28) -> TractResult<TVec<OutletId>> {
29    let name = name.as_ref();
30    let mut bias = tvec!(outlet);
31    let fact = model.outlet_fact(outlet)?.clone();
32    if fact.shape.volume().is_one() && fact.rank() > 0 {
33        bias = model.wire_node(
34            format!("{name}.bias.make_scalar"),
35            AxisOp::Reshape(0, fact.shape.to_tvec(), tvec![]),
36            &bias,
37        )?;
38    }
39    if model.outlet_fact(bias[0])?.rank() == 0 {
40        bias = model.wire_node(
41            format!("{name}.bias.broadcast"),
42            MultiBroadcastTo { shape: tvec!(output_channels).into() },
43            &bias,
44        )?;
45    }
46    Ok(bias)
47}
48
49pub fn wire_reshape_bias_for_bin(
50    model: &mut TypedModel,
51    name: impl AsRef<str>,
52    outlet: OutletId,
53    rank: usize,
54    c_axis: usize,
55    output_channels: usize,
56) -> TractResult<TVec<OutletId>> {
57    let name = name.as_ref();
58    let mut bias = wire_reshape_bias_as_vector(model, name, outlet, output_channels)?;
59    let fact = model.outlet_fact(bias[0])?.clone();
60    let mut bias_final_shape = tvec![1.to_dim(); rank];
61    bias_final_shape[c_axis] = output_channels.to_dim();
62    if *bias_final_shape != *fact.shape {
63        bias = model.wire_node(
64            format!("{name}.bias"),
65            AxisOp::Reshape(0, fact.shape.to_tvec(), bias_final_shape),
66            &bias,
67        )?;
68    }
69    Ok(bias)
70}
71
72pub fn rewrite_conv_with_n_axis(
73    _ctx: &(),
74    model: &TypedModel,
75    node: &TypedNode,
76    name: &str,
77    conv: &Conv,
78) -> TractResult<Option<TypedModelPatch>> {
79    if !conv.pool_spec.data_format.has_n() {
80        let mut new = conv.clone();
81        new.pool_spec.data_format = conv.pool_spec.data_format.with_n();
82        let mut patch = TypedModelPatch::default();
83        let mut wire = patch.taps(model, &node.inputs)?;
84        wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
85        wire = patch.wire_node(name, new, &wire)?;
86        wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
87        patch.shunt_outside(model, node.id.into(), wire[0])?;
88        return Ok(Some(patch));
89    }
90    Ok(None)
91}
92
93pub fn rewrite_deconv_with_n_axis(
94    _ctx: &(),
95    model: &TypedModel,
96    node: &TypedNode,
97    name: &str,
98    deconv: &Deconv,
99) -> TractResult<Option<TypedModelPatch>> {
100    if !deconv.pool_spec.data_format.has_n() {
101        let mut new = deconv.clone();
102        new.pool_spec.data_format = deconv.pool_spec.data_format.with_n();
103        let mut patch = TypedModelPatch::default();
104        let mut wire = patch.taps(model, &node.inputs)?;
105        wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
106        wire = patch.wire_node(name, new, &wire)?;
107        wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
108        patch.shunt_outside(model, node.id.into(), wire[0])?;
109        return Ok(Some(patch));
110    }
111    Ok(None)
112}
113