1use crate::internal::*;
23
/// Number of output columns accumulated together by the specialised kernel.
///
/// `BlockedConv::run` walks the width axis in chunks of `WB` columns, keeping one
/// `[f32; WB]` accumulator per output channel of the group; the trailing `w % WB`
/// columns are handled by a separate scalar loop.
const WB: usize = 16;
26
/// Grouped convolution along the height axis only, over `f32` data.
///
/// The kernels in this file read the input as a flat `[n, c_in, h_in, w]` buffer,
/// the weights as `[oc, c_in / group, kh]`, and write a flat `[n, oc, h_out, w]`
/// buffer. The width axis is never convolved: output column `j` only depends on
/// input column `j`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct BlockedConv {
    /// Batch size (outermost dimension of both input and output).
    pub n: usize,
    /// Total number of input channels, across all groups.
    pub c_in: usize,
    /// Input height.
    pub h_in: usize,
    /// Width, shared by input and output.
    pub w: usize,
    /// Total number of output channels, across all groups.
    pub oc: usize,
    /// Number of channel groups; `c_in` and `oc` are split evenly between them.
    pub group: usize,
    /// Number of kernel taps along the height axis.
    pub kh: usize,
    /// Step, in input rows, between two consecutive output rows.
    pub stride_h: usize,
    /// Spacing, in input rows, between two consecutive kernel taps.
    pub dil_h: usize,
    /// Number of implicit zero rows above the input; taps landing outside
    /// `0..h_in` are skipped by the kernels.
    pub pad_before_h: usize,
    /// Output height. Stored as given; the kernels do not re-derive it.
    pub h_out: usize,
}
44
45impl BlockedConv {
46 #[inline]
47 fn icg(&self) -> usize {
48 self.c_in / self.group
49 }
50 #[inline]
51 fn ocg(&self) -> usize {
52 self.oc / self.group
53 }
54}
55
56impl Op for BlockedConv {
57 fn name(&self) -> StaticName {
58 "BlockedConv".into()
59 }
60
61 fn info(&self) -> TractResult<Vec<String>> {
62 Ok(vec![format!(
63 "N={} C={}->OC={} group={} kh={} (icg={} ocg={}) HxW={}x{} -> H_out={} pad_before={} stride_h={} dil_h={}",
64 self.n,
65 self.c_in,
66 self.oc,
67 self.group,
68 self.kh,
69 self.icg(),
70 self.ocg(),
71 self.h_in,
72 self.w,
73 self.h_out,
74 self.pad_before_h,
75 self.stride_h,
76 self.dil_h,
77 )])
78 }
79
80 op_as_typed_op!();
81}
82
83impl EvalOp for BlockedConv {
84 fn is_stateless(&self) -> bool {
85 true
86 }
87
88 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
89 let x_t = inputs[0].cast_to::<f32>()?;
90 let k_t = inputs[1].cast_to::<f32>()?;
91 let b_t = inputs[2].cast_to::<f32>()?;
92 let x = unsafe { x_t.as_slice_unchecked::<f32>() };
94 let kernel = unsafe { k_t.as_slice_unchecked::<f32>() };
95 let bias_raw = unsafe { b_t.as_slice_unchecked::<f32>() };
96 let bias_vec: Vec<f32> = match bias_raw.len() {
99 0 => vec![0.0; self.oc],
100 1 => vec![bias_raw[0]; self.oc],
101 _ => bias_raw.to_vec(),
102 };
103 let bias = bias_vec.as_slice();
104
105 let mut output =
106 unsafe { Tensor::uninitialized::<f32>(&[self.n, self.oc, self.h_out, self.w])? };
107 let out = unsafe { output.as_slice_mut_unchecked::<f32>() };
108
109 let ocg = self.ocg();
110 match ocg {
111 1 => self.run::<1>(x, kernel, bias, out),
112 2 => self.run::<2>(x, kernel, bias, out),
113 3 => self.run::<3>(x, kernel, bias, out),
114 4 => self.run::<4>(x, kernel, bias, out),
115 5 => self.run::<5>(x, kernel, bias, out),
116 6 => self.run::<6>(x, kernel, bias, out),
117 8 => self.run::<8>(x, kernel, bias, out),
118 _ => self.run_generic(x, kernel, bias, out),
119 }
120
121 Ok(tvec!(output.into_tvalue()))
122 }
123}
124
125impl BlockedConv {
126 #[allow(clippy::needless_range_loop)]
139 fn run<const OCG: usize>(&self, x: &[f32], kernel: &[f32], bias: &[f32], out: &mut [f32]) {
140 let (icg, w, h_in, h_out, kh) = (self.icg(), self.w, self.h_in, self.h_out, self.kh);
141 let (sh, dh, pb) =
142 (self.stride_h as isize, self.dil_h as isize, self.pad_before_h as isize);
143 let kstride_oc = icg * kh; let n_full = w / WB; for ni in 0..self.n {
146 let x_n = &x[ni * self.c_in * h_in * w..];
147 let out_n = &mut out[ni * self.oc * h_out * w..];
148 for g in 0..self.group {
149 let oc0 = g * OCG;
150 let ic0 = g * icg;
151 for oh in 0..h_out {
152 for blk in 0..n_full {
154 let wb = blk * WB;
155 let mut acc = [[0f32; WB]; OCG];
156 for ocl in 0..OCG {
157 let b = bias[oc0 + ocl];
158 for j in 0..WB {
159 acc[ocl][j] = b;
160 }
161 }
162 for kh_i in 0..kh {
163 let ih = oh as isize * sh + kh_i as isize * dh - pb;
164 if ih < 0 || ih >= h_in as isize {
165 continue;
166 }
167 let row0 = ((ic0 * h_in + ih as usize) * w + wb) as isize;
168 for icl in 0..icg {
169 let row_base = (row0 + (icl * h_in * w) as isize) as usize;
170 for ocl in 0..OCG {
171 let wv = unsafe {
172 *kernel.get_unchecked(
173 (oc0 + ocl) * kstride_oc + icl * kh + kh_i,
174 )
175 };
176 let a = &mut acc[ocl];
177 for j in 0..WB {
178 a[j] += unsafe { *x_n.get_unchecked(row_base + j) } * wv;
179 }
180 }
181 }
182 }
183 for ocl in 0..OCG {
184 let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
185 for j in 0..WB {
186 unsafe { *out_n.get_unchecked_mut(ob + j) = acc[ocl][j] };
187 }
188 }
189 }
190 let wb = n_full * WB;
192 if wb < w {
193 let rem = w - wb;
194 for ocl in 0..OCG {
195 let b = bias[oc0 + ocl];
196 let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
197 for j in 0..rem {
198 out_n[ob + j] = b;
199 }
200 }
201 for kh_i in 0..kh {
202 let ih = oh as isize * sh + kh_i as isize * dh - pb;
203 if ih < 0 || ih >= h_in as isize {
204 continue;
205 }
206 let ih = ih as usize;
207 for icl in 0..icg {
208 let row_base = ((ic0 + icl) * h_in + ih) * w + wb;
209 for ocl in 0..OCG {
210 let wv = kernel[(oc0 + ocl) * kstride_oc + icl * kh + kh_i];
211 let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
212 for j in 0..rem {
213 out_n[ob + j] += x_n[row_base + j] * wv;
214 }
215 }
216 }
217 }
218 }
219 }
220 }
221 }
222 }
223
224 #[allow(clippy::needless_range_loop)]
227 fn run_generic(&self, x: &[f32], kernel: &[f32], bias: &[f32], out: &mut [f32]) {
228 let (icg, ocg, w, h_in, h_out, kh) =
229 (self.icg(), self.ocg(), self.w, self.h_in, self.h_out, self.kh);
230 let (sh, dh, pb) =
231 (self.stride_h as isize, self.dil_h as isize, self.pad_before_h as isize);
232 let kstride_oc = icg * kh;
233 let mut acc = vec![0f32; ocg * w];
234 for ni in 0..self.n {
235 let x_n = &x[ni * self.c_in * h_in * w..];
236 let out_n = &mut out[ni * self.oc * h_out * w..];
237 for g in 0..self.group {
238 let oc0 = g * ocg;
239 let ic0 = g * icg;
240 for oh in 0..h_out {
241 for ocl in 0..ocg {
242 let b = bias[oc0 + ocl];
243 for j in 0..w {
244 acc[ocl * w + j] = b;
245 }
246 }
247 for kh_i in 0..kh {
248 let ih = oh as isize * sh + kh_i as isize * dh - pb;
249 if ih < 0 || ih >= h_in as isize {
250 continue;
251 }
252 let ih = ih as usize;
253 for icl in 0..icg {
254 let ic = ic0 + icl;
255 let row = &x_n[(ic * h_in + ih) * w..(ic * h_in + ih) * w + w];
256 for ocl in 0..ocg {
257 let wv = kernel[(oc0 + ocl) * kstride_oc + icl * kh + kh_i];
258 let a = &mut acc[ocl * w..ocl * w + w];
259 for j in 0..w {
260 a[j] += row[j] * wv;
261 }
262 }
263 }
264 }
265 for ocl in 0..ocg {
266 let ob = ((oc0 + ocl) * h_out + oh) * w;
267 out_n[ob..ob + w].copy_from_slice(&acc[ocl * w..ocl * w + w]);
268 }
269 }
270 }
271 }
272 }
273}
274
275impl TypedOp for BlockedConv {
276 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
277 ensure!(inputs.len() == 3, "BlockedConv expects 3 inputs (X, kernel, bias)");
278 Ok(tvec!(f32::datum_type().fact([self.n, self.oc, self.h_out, self.w])))
279 }
280
281 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
282 let macs = self.n * self.oc * self.h_out * self.w * self.icg() * self.kh;
283 Ok(tvec!((Cost::FMA(f32::datum_type()), macs.to_dim())))
284 }
285
286 as_op!();
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward per-output-element convolution used as ground truth.
    ///
    /// Accumulates in the same order as the kernels under test (bias first, then
    /// kernel taps outermost and input channels innermost).
    fn reference(op: &BlockedConv, x: &[f32], kernel: &[f32], bias: &[f32]) -> Vec<f32> {
        let (icg, ocg) = (op.icg(), op.ocg());
        let (h_in, w, kh) = (op.h_in, op.w, op.kh);
        let (sh, dh, pb) = (op.stride_h as isize, op.dil_h as isize, op.pad_before_h as isize);
        let mut out = vec![0f32; op.n * op.oc * op.h_out * w];
        for ni in 0..op.n {
            for oc in 0..op.oc {
                let g = oc / ocg;
                for oh in 0..op.h_out {
                    for wi in 0..w {
                        let mut acc = bias[oc];
                        for kh_i in 0..kh {
                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
                            if ih < 0 || ih >= h_in as isize {
                                continue;
                            }
                            let ih = ih as usize;
                            for icl in 0..icg {
                                let ic = g * icg + icl;
                                let xv = x[((ni * op.c_in + ic) * h_in + ih) * w + wi];
                                acc += xv * kernel[oc * (icg * kh) + icl * kh + kh_i];
                            }
                        }
                        out[((ni * op.oc + oc) * op.h_out + oh) * w + wi] = acc;
                    }
                }
            }
        }
        out
    }

    /// Builds an op with the given geometry, evaluates it on deterministic data and
    /// compares the result with `reference`.
    ///
    /// Padding is applied before the input only, so `h_out` is the number of kernel
    /// placements that fit in `h_in + pb` rows for the given stride and dilation.
    #[allow(clippy::too_many_arguments)]
    fn run_case_with(
        n: usize,
        c_in: usize,
        oc: usize,
        group: usize,
        kh: usize,
        h_in: usize,
        w: usize,
        pb: usize,
        stride: usize,
        dil: usize,
    ) {
        let icg = c_in / group;
        let kernel_span = dil * (kh - 1) + 1;
        let h_out = (h_in + pb - kernel_span) / stride + 1;
        let op = BlockedConv {
            n,
            c_in,
            h_in,
            w,
            oc,
            group,
            kh,
            stride_h: stride,
            dil_h: dil,
            pad_before_h: pb,
            h_out,
        };
        let x: Vec<f32> =
            (0..n * c_in * h_in * w).map(|i| ((i as f32 * 0.137).sin()) * 0.7).collect();
        let kernel: Vec<f32> =
            (0..oc * icg * kh).map(|i| ((i as f32 * 0.091).cos()) * 0.3).collect();
        let bias: Vec<f32> = (0..oc).map(|i| (i as f32 * 0.05) - 0.1).collect();

        let want = reference(&op, &x, &kernel, &bias);
        let got = op
            .eval(tvec![
                Tensor::from_shape(&[n, c_in, h_in, w], &x).unwrap().into_tvalue(),
                Tensor::from_shape(&[oc, icg * kh], &kernel).unwrap().into_tvalue(),
                Tensor::from_shape(&[oc], &bias).unwrap().into_tvalue(),
            ])
            .unwrap();
        let got_view = got[0].to_plain_array_view::<f32>().unwrap();
        let got = got_view.as_slice().unwrap();
        assert_eq!(got.len(), want.len());
        let max_abs = got.iter().zip(&want).map(|(a, b)| (a - b).abs()).fold(0.0, f32::max);
        assert!(
            max_abs < 1e-5,
            "BlockedConv mismatch (n={n} c_in={c_in} oc={oc} g={group} kh={kh} h={h_in} w={w} pb={pb} stride={stride} dil={dil}): max_abs={max_abs}"
        );
    }

    /// Single-batch, unit-stride, unit-dilation shorthand for `run_case_with`.
    fn run_case(c_in: usize, oc: usize, group: usize, kh: usize, h_in: usize, w: usize, pb: usize) {
        run_case_with(1, c_in, oc, group, kh, h_in, w, pb, 1, 1);
    }

    #[test]
    fn blocked_conv_matches_reference() {
        // ocg = 5, width made of full blocks only.
        run_case(64, 10, 2, 5, 12, 96, 4);
        // ocg = 2, one full block plus a 4-column tail.
        run_case(4, 4, 2, 3, 5, 20, 1);
        // ocg = 3, width shorter than a block (tail only).
        run_case(8, 6, 2, 4, 7, 5, 2);
        // ocg = 3, ungrouped, no padding.
        run_case(6, 3, 1, 3, 8, 33, 0);
        // ocg = 1.
        run_case(4, 2, 2, 2, 6, 17, 1);
    }

    #[test]
    fn generic_fallback_matches_reference() {
        // ocg = 7 has no specialised kernel.
        run_case(14, 14, 2, 3, 6, 19, 1);
        // ocg = 9 is above the largest specialised kernel.
        run_case(9, 9, 1, 2, 5, 16, 0);
    }

    #[test]
    fn batch_stride_and_dilation_match_reference() {
        // Specialised kernel (ocg = 2) with two batches, stride 2 and dilation 2.
        run_case_with(2, 4, 4, 2, 3, 9, 20, 2, 2, 2);
        // Generic kernel (ocg = 7) with three batches and stride 2.
        run_case_with(3, 7, 7, 1, 2, 6, 5, 1, 2, 1);
    }
}