tract_core/ops/cnn/
mod.rs1use crate::internal::*;
2
3pub mod conv;
4pub mod deconv;
5mod maxpool;
6mod padding;
7mod patch_axis;
8mod patches;
9pub mod pools;
10mod sumpool;
11
12pub use self::conv::{Conv, KernelFormat};
13pub use self::deconv::Deconv;
14pub use self::maxpool::MaxPool;
15pub use self::padding::PaddingSpec;
16pub use self::patch_axis::PatchAxis;
17pub use self::patches::{Patch, PatchSpec};
18pub use self::pools::PoolSpec;
19pub use self::sumpool::SumPool;
20
21use super::array::MultiBroadcastTo;
22
23pub fn wire_reshape_bias_as_vector(
24 model: &mut TypedModel,
25 name: impl AsRef<str>,
26 outlet: OutletId,
27 output_channels: usize,
28) -> TractResult<TVec<OutletId>> {
29 let name = name.as_ref();
30 let mut bias = tvec!(outlet);
31 let fact = model.outlet_fact(outlet)?.clone();
32 if fact.shape.volume().is_one() && fact.rank() > 0 {
33 bias = model.wire_node(
34 format!("{name}.bias.make_scalar"),
35 AxisOp::Reshape(0, fact.shape.to_tvec(), tvec![]),
36 &bias,
37 )?;
38 }
39 if model.outlet_fact(bias[0])?.rank() == 0 {
40 bias = model.wire_node(
41 format!("{name}.bias.broadcast"),
42 MultiBroadcastTo { shape: tvec!(output_channels).into() },
43 &bias,
44 )?;
45 }
46 Ok(bias)
47}
48
49pub fn wire_reshape_bias_for_bin(
50 model: &mut TypedModel,
51 name: impl AsRef<str>,
52 outlet: OutletId,
53 rank: usize,
54 c_axis: usize,
55 output_channels: usize,
56) -> TractResult<TVec<OutletId>> {
57 let name = name.as_ref();
58 let mut bias = wire_reshape_bias_as_vector(model, name, outlet, output_channels)?;
59 let fact = model.outlet_fact(bias[0])?.clone();
60 let mut bias_final_shape = tvec![1.to_dim(); rank];
61 bias_final_shape[c_axis] = output_channels.to_dim();
62 if *bias_final_shape != *fact.shape {
63 bias = model.wire_node(
64 format!("{name}.bias"),
65 AxisOp::Reshape(0, fact.shape.to_tvec(), bias_final_shape),
66 &bias,
67 )?;
68 }
69 Ok(bias)
70}
71
72pub fn rewrite_conv_with_n_axis(
73 _ctx: &(),
74 model: &TypedModel,
75 node: &TypedNode,
76 name: &str,
77 conv: &Conv,
78) -> TractResult<Option<TypedModelPatch>> {
79 if !conv.pool_spec.data_format.has_n() {
80 let mut new = conv.clone();
81 new.pool_spec.data_format = conv.pool_spec.data_format.with_n();
82 let mut patch = TypedModelPatch::default();
83 let mut wire = patch.taps(model, &node.inputs)?;
84 wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
85 wire = patch.wire_node(name, new, &wire)?;
86 wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
87 patch.shunt_outside(model, node.id.into(), wire[0])?;
88 return Ok(Some(patch));
89 }
90 Ok(None)
91}
92
93pub fn rewrite_deconv_with_n_axis(
94 _ctx: &(),
95 model: &TypedModel,
96 node: &TypedNode,
97 name: &str,
98 deconv: &Deconv,
99) -> TractResult<Option<TypedModelPatch>> {
100 if !deconv.pool_spec.data_format.has_n() {
101 let mut new = deconv.clone();
102 new.pool_spec.data_format = deconv.pool_spec.data_format.with_n();
103 let mut patch = TypedModelPatch::default();
104 let mut wire = patch.taps(model, &node.inputs)?;
105 wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
106 wire = patch.wire_node(name, new, &wire)?;
107 wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
108 patch.shunt_outside(model, node.id.into(), wire[0])?;
109 return Ok(Some(patch));
110 }
111 Ok(None)
112}
113