Skip to main content

tract_core/ops/cnn/conv/
blocked.rs

//! Direct, register-blocked convolution for the "channel-mixing temporal conv"
//! shape class: NCHW, kernel width 1 (extent only on H), unit stride/dilation on
//! the contiguous W axis, grouped, with a *small* number of output channels per
//! group (`ocg`).
//!
//! For such convs the im2col + matmul lowering is inefficient: the per-group
//! matmul is `M = ocg` (tiny, e.g. 5) × `K = icg·KH` × `N = H·W`, so the matmul
//! kernel's m-tile is mostly wasted — exactly the same pathology as a low-M GEMV.
//! ONNX Runtime (ORT) side-steps it with a direct conv.
//!
//! This op computes the conv directly: for each (group, output row, block of the
//! contiguous W axis) it holds `ocg` accumulators in registers and reduces over
//! `(kh, icg)`, loading each input row ONCE and reusing it across all `ocg`
//! outputs (the same input reuse a GEMM gets). Measured on df_dec's `df_convp.1`
//! (group=2, 64→10 channels, kernel [5,1], 100×96): 0.77 ms native / 0.79 ms wasm
//! vs 1.72 / 2.42 ms for tract's lazy im2col and 1.13 ms for ORT — a 2.2–3.1× win
//! over im2col, bit-exact.
//!
//! Eligibility is checked in `Conv::codegen`; anything outside the supported
//! shape class falls back to im2col.
21
22use crate::internal::*;
23
/// Number of consecutive W-axis columns processed per register block in the
/// const-OCG fast path (16 × f32 = 64 bytes); any `w % WB` leftover columns
/// go through the scalar tail instead.
const WB: usize = 16;
26
/// Direct blocked conv. Inputs: X [N, C, H, W] (NCHW, f32), kernel
/// [OC, ICG·KH] (group-major: row `oc` holds its group's `icg·KH` weights,
/// i-major/h-minor), bias [OC]. Output [N, OC, H_out, W].
///
/// Invariants the evaluation code relies on: `group > 0`, and both `c_in` and
/// `oc` are multiples of `group`, so that `icg = c_in / group` and
/// `ocg = oc / group` are exact and every output channel gets written.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct BlockedConv {
    /// Batch size (N).
    pub n: usize,
    /// Total input channels (C), across all groups.
    pub c_in: usize,
    /// Input height (H).
    pub h_in: usize,
    /// Width (W), shared by input and output (kernel width 1, unit stride on W).
    pub w: usize,
    /// Total output channels (OC), across all groups.
    pub oc: usize,
    /// Number of channel groups.
    pub group: usize,
    /// Kernel extent along H.
    pub kh: usize,
    /// Stride along H.
    pub stride_h: usize,
    /// Dilation along H.
    pub dil_h: usize,
    /// Rows of implicit zero padding above the input; taps that land outside
    /// `0..h_in` contribute nothing.
    pub pad_before_h: usize,
    /// Output height (H_out), supplied by whoever builds the op.
    pub h_out: usize,
}
44
45impl BlockedConv {
46    #[inline]
47    fn icg(&self) -> usize {
48        self.c_in / self.group
49    }
50    #[inline]
51    fn ocg(&self) -> usize {
52        self.oc / self.group
53    }
54}
55
56impl Op for BlockedConv {
57    fn name(&self) -> StaticName {
58        "BlockedConv".into()
59    }
60
61    fn info(&self) -> TractResult<Vec<String>> {
62        Ok(vec![format!(
63            "N={} C={}->OC={} group={} kh={} (icg={} ocg={}) HxW={}x{} -> H_out={} pad_before={} stride_h={} dil_h={}",
64            self.n,
65            self.c_in,
66            self.oc,
67            self.group,
68            self.kh,
69            self.icg(),
70            self.ocg(),
71            self.h_in,
72            self.w,
73            self.h_out,
74            self.pad_before_h,
75            self.stride_h,
76            self.dil_h,
77        )])
78    }
79
80    op_as_typed_op!();
81}
82
83impl EvalOp for BlockedConv {
84    fn is_stateless(&self) -> bool {
85        true
86    }
87
88    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
89        let x_t = inputs[0].cast_to::<f32>()?;
90        let k_t = inputs[1].cast_to::<f32>()?;
91        let b_t = inputs[2].cast_to::<f32>()?;
92        // SAFETY: just cast to f32; conv I/O tensors are standard (contiguous) layout.
93        let x = unsafe { x_t.as_slice_unchecked::<f32>() };
94        let kernel = unsafe { k_t.as_slice_unchecked::<f32>() };
95        let bias_raw = unsafe { b_t.as_slice_unchecked::<f32>() };
96        // Normalise bias to a per-output-channel vector (it may arrive as a
97        // scalar zero, empty, or already [oc]).
98        let bias_vec: Vec<f32> = match bias_raw.len() {
99            0 => vec![0.0; self.oc],
100            1 => vec![bias_raw[0]; self.oc],
101            _ => bias_raw.to_vec(),
102        };
103        let bias = bias_vec.as_slice();
104
105        let mut output =
106            unsafe { Tensor::uninitialized::<f32>(&[self.n, self.oc, self.h_out, self.w])? };
107        let out = unsafe { output.as_slice_mut_unchecked::<f32>() };
108
109        let ocg = self.ocg();
110        match ocg {
111            1 => self.run::<1>(x, kernel, bias, out),
112            2 => self.run::<2>(x, kernel, bias, out),
113            3 => self.run::<3>(x, kernel, bias, out),
114            4 => self.run::<4>(x, kernel, bias, out),
115            5 => self.run::<5>(x, kernel, bias, out),
116            6 => self.run::<6>(x, kernel, bias, out),
117            8 => self.run::<8>(x, kernel, bias, out),
118            _ => self.run_generic(x, kernel, bias, out),
119        }
120
121        Ok(tvec!(output.into_tvalue()))
122    }
123}
124
125impl BlockedConv {
126    /// Const-OCG fast path: `ocg` accumulators of WB lanes held in registers.
127    ///
128    /// The hot loop (full WB-wide blocks) touches `acc` ONLY at compile-time
129    /// constant offsets `[ocl][j]` (ocl<OCG, j<WB, both const) and stores it
130    /// whole — so LLVM's SROA promotes the OCG·WB accumulators to SSA registers
131    /// and keeps them resident across the runtime `(kh, icg)` reduction (this is
132    /// what makes it ~2.4× faster than a runtime-length-access variant, matching
133    /// the standalone microbench). `get_unchecked` keeps the runtime-derived
134    /// input/kernel/output indices bounds-check-free; all are provably in range
135    /// from the shape invariants. The `w % WB` remainder uses a scalar tail.
136    // Index loops are deliberate here: const offsets into `acc` are what let SROA
137    // keep the accumulators register-resident; iterator forms regressed codegen.
138    #[allow(clippy::needless_range_loop)]
139    fn run<const OCG: usize>(&self, x: &[f32], kernel: &[f32], bias: &[f32], out: &mut [f32]) {
140        let (icg, w, h_in, h_out, kh) = (self.icg(), self.w, self.h_in, self.h_out, self.kh);
141        let (sh, dh, pb) =
142            (self.stride_h as isize, self.dil_h as isize, self.pad_before_h as isize);
143        let kstride_oc = icg * kh; // weights row stride per output channel
144        let n_full = w / WB; // full WB-wide blocks; remainder handled after
145        for ni in 0..self.n {
146            let x_n = &x[ni * self.c_in * h_in * w..];
147            let out_n = &mut out[ni * self.oc * h_out * w..];
148            for g in 0..self.group {
149                let oc0 = g * OCG;
150                let ic0 = g * icg;
151                for oh in 0..h_out {
152                    // ---- full WB blocks: all-const acc access -> register-resident ----
153                    for blk in 0..n_full {
154                        let wb = blk * WB;
155                        let mut acc = [[0f32; WB]; OCG];
156                        for ocl in 0..OCG {
157                            let b = bias[oc0 + ocl];
158                            for j in 0..WB {
159                                acc[ocl][j] = b;
160                            }
161                        }
162                        for kh_i in 0..kh {
163                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
164                            if ih < 0 || ih >= h_in as isize {
165                                continue;
166                            }
167                            let row0 = ((ic0 * h_in + ih as usize) * w + wb) as isize;
168                            for icl in 0..icg {
169                                let row_base = (row0 + (icl * h_in * w) as isize) as usize;
170                                for ocl in 0..OCG {
171                                    let wv = unsafe {
172                                        *kernel.get_unchecked(
173                                            (oc0 + ocl) * kstride_oc + icl * kh + kh_i,
174                                        )
175                                    };
176                                    let a = &mut acc[ocl];
177                                    for j in 0..WB {
178                                        a[j] += unsafe { *x_n.get_unchecked(row_base + j) } * wv;
179                                    }
180                                }
181                            }
182                        }
183                        for ocl in 0..OCG {
184                            let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
185                            for j in 0..WB {
186                                unsafe { *out_n.get_unchecked_mut(ob + j) = acc[ocl][j] };
187                            }
188                        }
189                    }
190                    // ---- remainder (w % WB != 0): scalar tail accumulated in place ----
191                    let wb = n_full * WB;
192                    if wb < w {
193                        let rem = w - wb;
194                        for ocl in 0..OCG {
195                            let b = bias[oc0 + ocl];
196                            let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
197                            for j in 0..rem {
198                                out_n[ob + j] = b;
199                            }
200                        }
201                        for kh_i in 0..kh {
202                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
203                            if ih < 0 || ih >= h_in as isize {
204                                continue;
205                            }
206                            let ih = ih as usize;
207                            for icl in 0..icg {
208                                let row_base = ((ic0 + icl) * h_in + ih) * w + wb;
209                                for ocl in 0..OCG {
210                                    let wv = kernel[(oc0 + ocl) * kstride_oc + icl * kh + kh_i];
211                                    let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
212                                    for j in 0..rem {
213                                        out_n[ob + j] += x_n[row_base + j] * wv;
214                                    }
215                                }
216                            }
217                        }
218                    }
219                }
220            }
221        }
222    }
223
224    /// Generic fallback for `ocg` outside the const-dispatched set. Correct but
225    /// not register-blocked (heap accumulators). Rarely hit for the eligible class.
226    #[allow(clippy::needless_range_loop)]
227    fn run_generic(&self, x: &[f32], kernel: &[f32], bias: &[f32], out: &mut [f32]) {
228        let (icg, ocg, w, h_in, h_out, kh) =
229            (self.icg(), self.ocg(), self.w, self.h_in, self.h_out, self.kh);
230        let (sh, dh, pb) =
231            (self.stride_h as isize, self.dil_h as isize, self.pad_before_h as isize);
232        let kstride_oc = icg * kh;
233        let mut acc = vec![0f32; ocg * w];
234        for ni in 0..self.n {
235            let x_n = &x[ni * self.c_in * h_in * w..];
236            let out_n = &mut out[ni * self.oc * h_out * w..];
237            for g in 0..self.group {
238                let oc0 = g * ocg;
239                let ic0 = g * icg;
240                for oh in 0..h_out {
241                    for ocl in 0..ocg {
242                        let b = bias[oc0 + ocl];
243                        for j in 0..w {
244                            acc[ocl * w + j] = b;
245                        }
246                    }
247                    for kh_i in 0..kh {
248                        let ih = oh as isize * sh + kh_i as isize * dh - pb;
249                        if ih < 0 || ih >= h_in as isize {
250                            continue;
251                        }
252                        let ih = ih as usize;
253                        for icl in 0..icg {
254                            let ic = ic0 + icl;
255                            let row = &x_n[(ic * h_in + ih) * w..(ic * h_in + ih) * w + w];
256                            for ocl in 0..ocg {
257                                let wv = kernel[(oc0 + ocl) * kstride_oc + icl * kh + kh_i];
258                                let a = &mut acc[ocl * w..ocl * w + w];
259                                for j in 0..w {
260                                    a[j] += row[j] * wv;
261                                }
262                            }
263                        }
264                    }
265                    for ocl in 0..ocg {
266                        let ob = ((oc0 + ocl) * h_out + oh) * w;
267                        out_n[ob..ob + w].copy_from_slice(&acc[ocl * w..ocl * w + w]);
268                    }
269                }
270            }
271        }
272    }
273}
274
275impl TypedOp for BlockedConv {
276    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
277        ensure!(inputs.len() == 3, "BlockedConv expects 3 inputs (X, kernel, bias)");
278        Ok(tvec!(f32::datum_type().fact([self.n, self.oc, self.h_out, self.w])))
279    }
280
281    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
282        let macs = self.n * self.oc * self.h_out * self.w * self.icg() * self.kh;
283        Ok(tvec!((Cost::FMA(f32::datum_type()), macs.to_dim())))
284    }
285
286    as_op!();
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Independent scalar reference for the eligible conv class (NCHW, kw=1,
    /// unit stride/dilation on W). `kernel` is `[oc, icg*kh]` (group-major,
    /// i-major/h-minor); input channel for output `oc` is `(oc/ocg)*icg + icl`.
    /// Honours `n`, `stride_h`, `dil_h` and `pad_before_h` from `op`.
    fn reference(op: &BlockedConv, x: &[f32], kernel: &[f32], bias: &[f32]) -> Vec<f32> {
        let (icg, ocg) = (op.icg(), op.ocg());
        let (h_in, w, kh) = (op.h_in, op.w, op.kh);
        let (sh, dh, pb) = (op.stride_h as isize, op.dil_h as isize, op.pad_before_h as isize);
        let mut out = vec![0f32; op.n * op.oc * op.h_out * w];
        for ni in 0..op.n {
            for oc in 0..op.oc {
                let g = oc / ocg;
                for oh in 0..op.h_out {
                    for wi in 0..w {
                        let mut acc = bias[oc];
                        for kh_i in 0..kh {
                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
                            if ih < 0 || ih >= h_in as isize {
                                continue;
                            }
                            let ih = ih as usize;
                            for icl in 0..icg {
                                let ic = g * icg + icl;
                                let xv = x[((ni * op.c_in + ic) * h_in + ih) * w + wi];
                                acc += xv * kernel[oc * (icg * kh) + icl * kh + kh_i];
                            }
                        }
                        out[((ni * op.oc + oc) * op.h_out + oh) * w + wi] = acc;
                    }
                }
            }
        }
        out
    }

    /// Builds an op for the eligible class. `h_out` is derived assuming no
    /// padding after; the padded height must cover at least one dilated kernel.
    // One argument per shape field of the op, so the arity is inherent.
    #[allow(clippy::too_many_arguments)]
    fn make_op(
        n: usize,
        c_in: usize,
        oc: usize,
        group: usize,
        kh: usize,
        h_in: usize,
        w: usize,
        pb: usize,
        stride: usize,
        dil: usize,
    ) -> BlockedConv {
        let h_out = (h_in + pb - dil * (kh - 1) - 1) / stride + 1;
        BlockedConv {
            n,
            c_in,
            h_in,
            w,
            oc,
            group,
            kh,
            stride_h: stride,
            dil_h: dil,
            pad_before_h: pb,
            h_out,
        }
    }

    /// Evaluates `op` on deterministic inputs through `EvalOp::eval`, feeding
    /// `bias_t` as the bias tensor. The result is compared against `reference`
    /// computed with `bias_ref` (the bias expanded to one value per channel).
    fn check_with_bias(op: &BlockedConv, bias_t: Tensor, bias_ref: &[f32]) {
        let icg = op.icg();
        let x: Vec<f32> = (0..op.n * op.c_in * op.h_in * op.w)
            .map(|i| ((i as f32 * 0.137).sin()) * 0.7)
            .collect();
        let kernel: Vec<f32> =
            (0..op.oc * icg * op.kh).map(|i| ((i as f32 * 0.091).cos()) * 0.3).collect();

        let want = reference(op, &x, &kernel, bias_ref);
        let got = op
            .eval(tvec![
                Tensor::from_shape(&[op.n, op.c_in, op.h_in, op.w], &x).unwrap().into_tvalue(),
                Tensor::from_shape(&[op.oc, icg * op.kh], &kernel).unwrap().into_tvalue(),
                bias_t.into_tvalue(),
            ])
            .unwrap();
        let got_view = got[0].to_plain_array_view::<f32>().unwrap();
        let got = got_view.as_slice().unwrap();
        assert_eq!(got.len(), want.len());
        // NaN would be silently dropped by `f32::max`, so check finiteness explicitly.
        assert!(got.iter().all(|v| v.is_finite()), "BlockedConv produced non-finite output");
        let max_abs = got.iter().zip(&want).map(|(a, b)| (a - b).abs()).fold(0.0, f32::max);
        assert!(max_abs < 1e-5, "BlockedConv mismatch ({op:?}): max_abs={max_abs}");
    }

    /// `check_with_bias` with a full per-channel `[oc]` bias.
    fn check(op: &BlockedConv) {
        let bias: Vec<f32> = (0..op.oc).map(|i| (i as f32 * 0.05) - 0.1).collect();
        let bias_t = Tensor::from_shape(&[op.oc], &bias).unwrap();
        check_with_bias(op, bias_t, &bias);
    }

    /// N=1, stride=dilation=1 case with `h_out = h_in + pb - (kh - 1)`.
    fn run_case(c_in: usize, oc: usize, group: usize, kh: usize, h_in: usize, w: usize, pb: usize) {
        check(&make_op(1, c_in, oc, group, kh, h_in, w, pb, 1, 1));
    }

    #[test]
    fn blocked_conv_matches_reference() {
        // df_convp.1-like: group=2, ocg=5, kh=5, causal pad, w multiple of WB.
        run_case(64, 10, 2, 5, 12, 96, 4);
        // full block + remainder (w=20 = 16 + 4), ocg=2.
        run_case(4, 4, 2, 3, 5, 20, 1);
        // remainder-only (w=5 < WB), ocg=3.
        run_case(8, 6, 2, 4, 7, 5, 2);
        // group=1, ocg=3, no padding.
        run_case(6, 3, 1, 3, 8, 33, 0);
        // ocg=1 edge.
        run_case(4, 2, 2, 2, 6, 17, 1);
    }

    #[test]
    fn every_ocg_dispatch_arm_matches_reference() {
        // ocg 1..=9 covers every const arm (1-6, 8) and the generic fallback (7, 9).
        for ocg in 1..=9 {
            run_case(6, 2 * ocg, 2, 3, 5, 21, 1);
        }
    }

    #[test]
    fn batched_strided_dilated_matches_reference() {
        // Const path (ocg=5) and generic path (ocg=7) with N=2, stride 2, dilation 2.
        check(&make_op(2, 8, 10, 2, 3, 9, 21, 3, 2, 2));
        check(&make_op(2, 14, 14, 2, 3, 9, 19, 3, 2, 2));
    }

    #[test]
    fn scalar_bias_is_broadcast() {
        let op = make_op(1, 4, 4, 2, 3, 5, 20, 1, 1, 1);
        let b = 0.25f32;
        check_with_bias(&op, Tensor::from_shape(&[1], &[b]).unwrap(), &vec![b; op.oc]);
    }
}