use rayon::{ThreadPool, prelude::*};
use yscv_tensor::{AlignedVec, Tensor, TensorError};

use super::super::error::KernelError;
use super::config::{
    Conv2dPlan, Conv2dSpec, DepthwiseConv2dPlan, DepthwiseConv2dSpec, ParallelElementwiseConfig,
    SeparableConv2dKernels, SeparableConv2dSpec, should_parallelize_len,
};

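/// 2D convolution over an NHWC `input` with an HWIO (`[kh, kw, c_in, c_out]`)
/// `kernel`, valid padding, and an optional per-output-channel `bias`.
///
/// Dispatches to a direct 3x3 SIMD kernel (NEON/AVX/SSE) for small
/// single-batch outputs, to im2col + GEMM when the `blas` feature is enabled,
/// and otherwise to a row-parallel fallback.
///
/// A minimal usage sketch (illustrative only; it assumes
/// `ParallelElementwiseConfig` implements `Default`, which is not shown in
/// this file):
///
/// ```ignore
/// use yscv_tensor::Tensor;
///
/// // 1x4x4x1 input of ones, 3x3x1x1 kernel of ones, stride 1 => 1x2x2x1
/// // output where every element is the 3x3 window sum, 9.0.
/// let input = Tensor::from_vec(vec![1, 4, 4, 1], vec![1.0; 16]).unwrap();
/// let kernel = Tensor::from_vec(vec![3, 3, 1, 1], vec![1.0; 9]).unwrap();
/// let out = conv2d_nhwc_with_config_and_pool(
///     &input,
///     &kernel,
///     None,
///     Conv2dSpec { stride_h: 1, stride_w: 1 },
///     ParallelElementwiseConfig::default(), // assumed Default impl
///     None,
/// )
/// .unwrap();
/// assert_eq!(out.shape(), &[1, 2, 2, 1]);
/// assert_eq!(out.data(), &[9.0, 9.0, 9.0, 9.0]);
/// ```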
pub fn conv2d_nhwc_with_config_and_pool(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    spec: Conv2dSpec,
    config: ParallelElementwiseConfig,
    thread_pool: Option<&ThreadPool>,
) -> Result<Tensor, KernelError> {
    let plan = build_conv2d_plan(input, kernel, bias, spec)?;

    // Fast path: direct 3x3 NEON kernel for small single-batch outputs.
    #[cfg(target_arch = "aarch64")]
    if plan.kernel_h == 3
        && plan.kernel_w == 3
        && plan.batch == 1
        && !cfg!(miri)
        && std::arch::is_aarch64_feature_detected!("neon")
        && plan.out_h * plan.out_w < 4096
    {
        #[allow(unsafe_code)]
        let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
        #[allow(unsafe_code)]
        unsafe {
            conv2d_3x3_direct_neon(
                input.data(),
                kernel.data(),
                &mut output,
                plan.in_w,
                plan.in_channels,
                plan.out_channels,
                plan.out_h,
                plan.out_w,
                plan.stride_h,
                plan.stride_w,
            );
        }
        if let Some(b) = bias {
            let bd = b.data();
            for i in 0..plan.out_h * plan.out_w {
                for c in 0..plan.out_channels {
                    output[i * plan.out_channels + c] += bd[c];
                }
            }
        }
        return Tensor::from_aligned(vec![1, plan.out_h, plan.out_w, plan.out_channels], output)
            .map_err(Into::into);
    }

    // Fast path: direct 3x3 AVX+FMA kernel for small single-batch outputs.
    #[cfg(target_arch = "x86_64")]
    if plan.kernel_h == 3
        && plan.kernel_w == 3
        && plan.batch == 1
        && !cfg!(miri)
        && is_x86_feature_detected!("avx")
        && is_x86_feature_detected!("fma")
        && plan.out_h * plan.out_w < 4096
    {
        #[allow(unsafe_code)]
        let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
        #[allow(unsafe_code)]
        unsafe {
            conv2d_3x3_direct_avx(
                input.data(),
                kernel.data(),
                &mut output,
                plan.in_w,
                plan.in_channels,
                plan.out_channels,
                plan.out_h,
                plan.out_w,
                plan.stride_h,
                plan.stride_w,
            );
        }
        if let Some(b) = bias {
            let bd = b.data();
            for i in 0..plan.out_h * plan.out_w {
                for c in 0..plan.out_channels {
                    output[i * plan.out_channels + c] += bd[c];
                }
            }
        }
        return Tensor::from_aligned(vec![1, plan.out_h, plan.out_w, plan.out_channels], output)
            .map_err(Into::into);
    }

    // Fast path: direct 3x3 SSE+FMA kernel when AVX was not detected.
    #[cfg(target_arch = "x86_64")]
    if plan.kernel_h == 3
        && plan.kernel_w == 3
        && plan.batch == 1
        && !cfg!(miri)
        && is_x86_feature_detected!("fma")
        && plan.out_h * plan.out_w < 4096
    {
        #[allow(unsafe_code)]
        let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
        #[allow(unsafe_code)]
        unsafe {
            conv2d_3x3_direct_sse(
                input.data(),
                kernel.data(),
                &mut output,
                plan.in_w,
                plan.in_channels,
                plan.out_channels,
                plan.out_h,
                plan.out_w,
                plan.stride_h,
                plan.stride_w,
            );
        }
        if let Some(b) = bias {
            let bd = b.data();
            for i in 0..plan.out_h * plan.out_w {
                for c in 0..plan.out_channels {
                    output[i * plan.out_channels + c] += bd[c];
                }
            }
        }
        return Tensor::from_aligned(vec![1, plan.out_h, plan.out_w, plan.out_channels], output)
            .map_err(Into::into);
    }

    // With BLAS available, im2col + GEMM is the preferred general path for a
    // single batch.
    #[cfg(feature = "blas")]
    if !cfg!(miri) && plan.batch == 1 {
        return conv2d_im2col_gemm(&plan, input.data(), kernel.data(), bias.map(Tensor::data));
    }

    let input_data = input.data();
    let kernel_data = kernel.data();
    let bias_data = bias.map(Tensor::data);
    let out_row_len = plan.out_w * plan.out_channels;
    if plan.output_len == 0 || out_row_len == 0 {
        return Tensor::from_vec(
            vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
            vec![],
        )
        .map_err(Into::into);
    }

    // Every element of `output` is written by the row loops below.
    #[allow(unsafe_code)]
    let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);

    if should_parallelize_len(plan.output_len, config.min_parallel_elements, thread_pool) {
        let mut work = || {
            output
                .par_chunks_mut(out_row_len)
                .enumerate()
                .for_each(|(row_idx, out_row)| {
                    conv2d_nhwc_row(input_data, kernel_data, bias_data, plan, row_idx, out_row);
                });
        };
        if let Some(pool) = thread_pool {
            pool.install(work);
        } else {
            work();
        }
    } else {
        for (row_idx, out_row) in output.chunks_mut(out_row_len).enumerate() {
            conv2d_nhwc_row(input_data, kernel_data, bias_data, plan, row_idx, out_row);
        }
    }

    Tensor::from_aligned(
        vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
        output,
    )
    .map_err(Into::into)
}

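/// Depthwise 2D convolution over an NHWC `input` with a
/// `[kh, kw, channels, depth_multiplier]` kernel, valid padding, and an
/// optional bias. Each input channel is convolved with its own
/// `depth_multiplier` filters, producing `channels * depth_multiplier`
/// output channels.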
pub fn depthwise_conv2d_nhwc_with_config_and_pool(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    spec: DepthwiseConv2dSpec,
    config: ParallelElementwiseConfig,
    thread_pool: Option<&ThreadPool>,
) -> Result<Tensor, KernelError> {
    let plan = build_depthwise_conv2d_plan(input, kernel, bias, spec)?;
    let input_data = input.data();
    let kernel_data = kernel.data();
    let bias_data = bias.map(Tensor::data);
    let out_row_len = plan.out_w * plan.out_channels;
    if plan.output_len == 0 || out_row_len == 0 {
        return Tensor::from_aligned(
            vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
            AlignedVec::<f32>::calloc(plan.output_len),
        )
        .map_err(Into::into);
    }

    // Every element of `output` is written by the row loops below.
    let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);

    if should_parallelize_len(plan.output_len, config.min_parallel_elements, thread_pool) {
        let mut work = || {
            output
                .par_chunks_mut(out_row_len)
                .enumerate()
                .for_each(|(row_idx, out_row)| {
                    depthwise_conv2d_nhwc_row(
                        input_data,
                        kernel_data,
                        bias_data,
                        plan,
                        row_idx,
                        out_row,
                    );
                });
        };
        if let Some(pool) = thread_pool {
            pool.install(work);
        } else {
            work();
        }
    } else {
        for (row_idx, out_row) in output.chunks_mut(out_row_len).enumerate() {
            depthwise_conv2d_nhwc_row(input_data, kernel_data, bias_data, plan, row_idx, out_row);
        }
    }

    Tensor::from_aligned(
        vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
        output,
    )
    .map_err(Into::into)
}

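/// Separable 2D convolution: a depthwise pass followed by a 1x1 pointwise
/// convolution. The pointwise kernel must have shape `[1, 1, c_mid, c_out]`;
/// the depthwise pass uses the stride from `spec`, while the pointwise pass
/// always runs with stride 1.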
pub fn separable_conv2d_nhwc_with_config_and_pool(
    input: &Tensor,
    kernels: SeparableConv2dKernels<'_>,
    spec: SeparableConv2dSpec,
    config: ParallelElementwiseConfig,
    thread_pool: Option<&ThreadPool>,
) -> Result<Tensor, KernelError> {
    if kernels.pointwise_kernel.rank() != 4
        || kernels.pointwise_kernel.shape()[0] != 1
        || kernels.pointwise_kernel.shape()[1] != 1
    {
        return Err(KernelError::InvalidSeparablePointwiseKernelShape {
            pointwise_shape: kernels.pointwise_kernel.shape().to_vec(),
        });
    }

    let depthwise_out = depthwise_conv2d_nhwc_with_config_and_pool(
        input,
        kernels.depthwise_kernel,
        kernels.depthwise_bias,
        DepthwiseConv2dSpec {
            stride_h: spec.stride_h,
            stride_w: spec.stride_w,
        },
        config,
        thread_pool,
    )?;

    conv2d_nhwc_with_config_and_pool(
        &depthwise_out,
        kernels.pointwise_kernel,
        kernels.pointwise_bias,
        Conv2dSpec {
            stride_h: 1,
            stride_w: 1,
        },
        config,
        thread_pool,
    )
}

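/// Validates shapes, strides, and bias for a standard convolution and
/// computes the output geometry (valid padding: `out = (in - k) / stride + 1`).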
fn build_conv2d_plan(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    spec: Conv2dSpec,
) -> Result<Conv2dPlan, KernelError> {
    let stride_h = spec.stride_h;
    let stride_w = spec.stride_w;
    if input.rank() != 4 || kernel.rank() != 4 {
        return Err(KernelError::InvalidConvRank {
            input_rank: input.rank(),
            kernel_rank: kernel.rank(),
        });
    }
    if stride_h == 0 || stride_w == 0 {
        return Err(KernelError::InvalidConvParameters {
            kernel_h: kernel.shape()[0],
            kernel_w: kernel.shape()[1],
            stride_h,
            stride_w,
        });
    }

    let batch = input.shape()[0];
    let in_h = input.shape()[1];
    let in_w = input.shape()[2];
    let in_channels = input.shape()[3];
    let kernel_h = kernel.shape()[0];
    let kernel_w = kernel.shape()[1];
    let kernel_in_channels = kernel.shape()[2];
    let out_channels = kernel.shape()[3];

    if kernel_h == 0 || kernel_w == 0 {
        return Err(KernelError::InvalidConvParameters {
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
        });
    }
    if kernel_in_channels != in_channels {
        return Err(KernelError::ConvChannelMismatch {
            input_channels: in_channels,
            kernel_in_channels,
        });
    }
    if kernel_h > in_h || kernel_w > in_w {
        return Err(KernelError::ConvKernelLargerThanInput {
            input_h: in_h,
            input_w: in_w,
            kernel_h,
            kernel_w,
        });
    }
    if let Some(bias_tensor) = bias
        && (bias_tensor.rank() != 1 || bias_tensor.shape()[0] != out_channels)
    {
        return Err(KernelError::ConvBiasShapeMismatch {
            bias_shape: bias_tensor.shape().to_vec(),
            out_channels,
        });
    }

    let out_h = (in_h - kernel_h) / stride_h + 1;
    let out_w = (in_w - kernel_w) / stride_w + 1;
    let output_len = batch
        .checked_mul(out_h)
        .and_then(|v| v.checked_mul(out_w))
        .and_then(|v| v.checked_mul(out_channels))
        .ok_or_else(|| {
            KernelError::Tensor(TensorError::SizeOverflow {
                shape: vec![batch, out_h, out_w, out_channels],
            })
        })?;

    Ok(Conv2dPlan {
        batch,
        in_h,
        in_w,
        in_channels,
        out_h,
        out_w,
        out_channels,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        output_len,
    })
}

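/// Validates shapes, strides, and bias for a depthwise convolution and
/// computes the output geometry, including the expanded channel count
/// `out_channels = channels * depth_multiplier`.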
fn build_depthwise_conv2d_plan(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    spec: DepthwiseConv2dSpec,
) -> Result<DepthwiseConv2dPlan, KernelError> {
    let stride_h = spec.stride_h;
    let stride_w = spec.stride_w;
    if input.rank() != 4 || kernel.rank() != 4 {
        return Err(KernelError::InvalidDepthwiseConvRank {
            input_rank: input.rank(),
            kernel_rank: kernel.rank(),
        });
    }
    if stride_h == 0 || stride_w == 0 {
        return Err(KernelError::InvalidDepthwiseConvParameters {
            kernel_h: kernel.shape()[0],
            kernel_w: kernel.shape()[1],
            stride_h,
            stride_w,
        });
    }

    let batch = input.shape()[0];
    let in_h = input.shape()[1];
    let in_w = input.shape()[2];
    let channels = input.shape()[3];
    let kernel_h = kernel.shape()[0];
    let kernel_w = kernel.shape()[1];
    let kernel_channels = kernel.shape()[2];
    let depth_multiplier = kernel.shape()[3];

    if kernel_h == 0 || kernel_w == 0 || depth_multiplier == 0 {
        return Err(KernelError::InvalidDepthwiseConvParameters {
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
        });
    }
    if kernel_channels != channels {
        return Err(KernelError::DepthwiseConvChannelMismatch {
            input_channels: channels,
            kernel_channels,
        });
    }
    if kernel_h > in_h || kernel_w > in_w {
        return Err(KernelError::DepthwiseConvKernelLargerThanInput {
            input_h: in_h,
            input_w: in_w,
            kernel_h,
            kernel_w,
        });
    }

    let out_channels = channels.checked_mul(depth_multiplier).ok_or_else(|| {
        KernelError::Tensor(TensorError::SizeOverflow {
            shape: vec![channels, depth_multiplier],
        })
    })?;
    if let Some(bias_tensor) = bias
        && (bias_tensor.rank() != 1 || bias_tensor.shape()[0] != out_channels)
    {
        return Err(KernelError::DepthwiseConvBiasShapeMismatch {
            bias_shape: bias_tensor.shape().to_vec(),
            out_channels,
        });
    }

    let out_h = (in_h - kernel_h) / stride_h + 1;
    let out_w = (in_w - kernel_w) / stride_w + 1;
    let output_len = batch
        .checked_mul(out_h)
        .and_then(|v| v.checked_mul(out_w))
        .and_then(|v| v.checked_mul(out_channels))
        .ok_or_else(|| {
            KernelError::Tensor(TensorError::SizeOverflow {
                shape: vec![batch, out_h, out_w, out_channels],
            })
        })?;

    Ok(DepthwiseConv2dPlan {
        batch,
        in_h,
        in_w,
        channels,
        depth_multiplier,
        out_h,
        out_w,
        out_channels,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        output_len,
    })
}

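/// Computes one output row (fixed batch index and output y) of the fallback
/// convolution. The innermost loop broadcasts a single input value across all
/// output channels via `conv_fma_row`, so the kernel is read contiguously.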
fn conv2d_nhwc_row(
    input: &[f32],
    kernel: &[f32],
    bias: Option<&[f32]>,
    plan: Conv2dPlan,
    row_idx: usize,
    out_row: &mut [f32],
) {
    let batch_idx = row_idx / plan.out_h;
    let out_y = row_idx % plan.out_h;
    let in_y0 = out_y * plan.stride_h;
    let batch_input_base = batch_idx * plan.in_h * plan.in_w * plan.in_channels;

    for out_x in 0..plan.out_w {
        let in_x0 = out_x * plan.stride_w;
        let out_cell_base = out_x * plan.out_channels;
        let out_slice = &mut out_row[out_cell_base..out_cell_base + plan.out_channels];

        // Seed the accumulators with the bias (or zero).
        if let Some(bias_values) = bias {
            out_slice.copy_from_slice(&bias_values[..plan.out_channels]);
        } else {
            out_slice.fill(0.0);
        }

        for ky in 0..plan.kernel_h {
            let in_y = in_y0 + ky;
            let input_row_base = batch_input_base + (in_y * plan.in_w + in_x0) * plan.in_channels;
            let kernel_row_base = ky * plan.kernel_w * plan.in_channels * plan.out_channels;

            for kx in 0..plan.kernel_w {
                let input_pixel_base = input_row_base + kx * plan.in_channels;
                let kernel_pixel_base = kernel_row_base + kx * plan.in_channels * plan.out_channels;

                for in_channel in 0..plan.in_channels {
                    let input_val = input[input_pixel_base + in_channel];
                    let k_base = kernel_pixel_base + in_channel * plan.out_channels;
                    // out[co] += input_val * kernel[ky, kx, in_channel, co] for all co.
                    conv_fma_row(
                        out_slice,
                        &kernel[k_base..k_base + plan.out_channels],
                        input_val,
                    );
                }
            }
        }
    }
}

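/// Direct 3x3 convolution for a single NHWC batch using NEON.
///
/// Output channels are register-blocked 8 wide (two `float32x4_t`
/// accumulators), then 4 wide, with a scalar tail. When `stride_w == 1`, two
/// horizontally adjacent output pixels are computed together so the three
/// kernel-column loads are shared between them.
///
/// # Safety
///
/// The caller must ensure NEON is available and that `input`, `kernel`, and
/// `output` are large enough for the given geometry (see the debug asserts).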
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn conv2d_3x3_direct_neon(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    w: usize,
    c_in: usize,
    c_out: usize,
    out_h: usize,
    out_w: usize,
    stride_h: usize,
    stride_w: usize,
) {
    use std::arch::aarch64::*;

    debug_assert!(
        input.len() >= ((out_h.saturating_sub(1)) * stride_h + 3) * w * c_in,
        "conv2d_3x3_direct_neon: input too small"
    );
    debug_assert!(
        output.len() >= out_h * out_w * c_out,
        "conv2d_3x3_direct_neon: output too small"
    );
    debug_assert!(
        kernel.len() >= 3 * 3 * c_in * c_out,
        "conv2d_3x3_direct_neon: kernel too small"
    );

    for oy in 0..out_h {
        let iy_base = oy * stride_h;

        let mut ox = 0usize;
        // With stride 1, adjacent output pixels share two of their three
        // input columns, so compute them in pairs and reuse the kernel loads.
        if stride_w == 1 {
            while ox + 2 <= out_w {
                let ix_base = ox;
                let out_off_a = (oy * out_w + ox) * c_out;
                let out_off_b = out_off_a + c_out;

                let mut co = 0;
                while co + 8 <= c_out {
                    let mut acc_a0 = vdupq_n_f32(0.0);
                    let mut acc_a1 = vdupq_n_f32(0.0);
                    let mut acc_b0 = vdupq_n_f32(0.0);
                    let mut acc_b1 = vdupq_n_f32(0.0);

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            // Four consecutive input columns cover both pixels.
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0_lo = vld1q_f32(kernel.as_ptr().add(k0_off));
                            let kw0_hi = vld1q_f32(kernel.as_ptr().add(k0_off + 4));
                            let kw1_lo = vld1q_f32(kernel.as_ptr().add(k1_off));
                            let kw1_hi = vld1q_f32(kernel.as_ptr().add(k1_off + 4));
                            let kw2_lo = vld1q_f32(kernel.as_ptr().add(k2_off));
                            let kw2_hi = vld1q_f32(kernel.as_ptr().add(k2_off + 4));

                            // Pixel A uses input columns ix_base..=ix_base + 2.
                            let va0 = vdupq_n_f32(in0);
                            let va1 = vdupq_n_f32(in1);
                            let va2 = vdupq_n_f32(in2);
                            acc_a0 = vfmaq_f32(acc_a0, va0, kw0_lo);
                            acc_a1 = vfmaq_f32(acc_a1, va0, kw0_hi);
                            acc_a0 = vfmaq_f32(acc_a0, va1, kw1_lo);
                            acc_a1 = vfmaq_f32(acc_a1, va1, kw1_hi);
                            acc_a0 = vfmaq_f32(acc_a0, va2, kw2_lo);
                            acc_a1 = vfmaq_f32(acc_a1, va2, kw2_hi);

                            // Pixel B uses input columns ix_base + 1..=ix_base + 3.
                            let vb3 = vdupq_n_f32(in3);
                            acc_b0 = vfmaq_f32(acc_b0, va1, kw0_lo);
                            acc_b1 = vfmaq_f32(acc_b1, va1, kw0_hi);
                            acc_b0 = vfmaq_f32(acc_b0, va2, kw1_lo);
                            acc_b1 = vfmaq_f32(acc_b1, va2, kw1_hi);
                            acc_b0 = vfmaq_f32(acc_b0, vb3, kw2_lo);
                            acc_b1 = vfmaq_f32(acc_b1, vb3, kw2_hi);
                        }
                    }

                    vst1q_f32(output.as_mut_ptr().add(out_off_a + co), acc_a0);
                    vst1q_f32(output.as_mut_ptr().add(out_off_a + co + 4), acc_a1);
                    vst1q_f32(output.as_mut_ptr().add(out_off_b + co), acc_b0);
                    vst1q_f32(output.as_mut_ptr().add(out_off_b + co + 4), acc_b1);
                    co += 8;
                }

                while co + 4 <= c_out {
                    let mut acc_a = vdupq_n_f32(0.0);
                    let mut acc_b = vdupq_n_f32(0.0);

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0 = vld1q_f32(kernel.as_ptr().add(k0_off));
                            let kw1 = vld1q_f32(kernel.as_ptr().add(k1_off));
                            let kw2 = vld1q_f32(kernel.as_ptr().add(k2_off));

                            let va0 = vdupq_n_f32(in0);
                            let va1 = vdupq_n_f32(in1);
                            let va2 = vdupq_n_f32(in2);
                            acc_a = vfmaq_f32(acc_a, va0, kw0);
                            acc_a = vfmaq_f32(acc_a, va1, kw1);
                            acc_a = vfmaq_f32(acc_a, va2, kw2);

                            let vb3 = vdupq_n_f32(in3);
                            acc_b = vfmaq_f32(acc_b, va1, kw0);
                            acc_b = vfmaq_f32(acc_b, va2, kw1);
                            acc_b = vfmaq_f32(acc_b, vb3, kw2);
                        }
                    }

                    vst1q_f32(output.as_mut_ptr().add(out_off_a + co), acc_a);
                    vst1q_f32(output.as_mut_ptr().add(out_off_b + co), acc_b);
                    co += 4;
                }

                // Scalar tail for the remaining output channels.
                while co < c_out {
                    let mut acc_a = 0.0f32;
                    let mut acc_b = 0.0f32;
                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = input[(row_base + ix_base) * c_in + ci];
                            let in1 = input[(row_base + ix_base + 1) * c_in + ci];
                            let in2 = input[(row_base + ix_base + 2) * c_in + ci];
                            let in3 = input[(row_base + ix_base + 3) * c_in + ci];
                            let k0 = kernel[ky * 3 * c_in * c_out + ci * c_out + co];
                            let k1 = kernel[(ky * 3 + 1) * c_in * c_out + ci * c_out + co];
                            let k2 = kernel[(ky * 3 + 2) * c_in * c_out + ci * c_out + co];
                            acc_a += in0 * k0 + in1 * k1 + in2 * k2;
                            acc_b += in1 * k0 + in2 * k1 + in3 * k2;
                        }
                    }
                    *output.get_unchecked_mut(out_off_a + co) = acc_a;
                    *output.get_unchecked_mut(out_off_b + co) = acc_b;
                    co += 1;
                }

                ox += 2;
            }
        }

        // Remainder (and all pixels when stride_w != 1): one pixel at a time.
        while ox < out_w {
            let ix_base = ox * stride_w;
            let out_off = (oy * out_w + ox) * c_out;

            let mut co = 0;
            while co + 8 <= c_out {
                let mut acc0 = vdupq_n_f32(0.0);
                let mut acc1 = vdupq_n_f32(0.0);

                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        let k_base = (ky * 3 + kx) * c_in * c_out;

                        for ci in 0..c_in {
                            let iv = vdupq_n_f32(*input.get_unchecked(in_off + ci));
                            let koff = k_base + ci * c_out + co;
                            acc0 = vfmaq_f32(acc0, iv, vld1q_f32(kernel.as_ptr().add(koff)));
                            acc1 = vfmaq_f32(acc1, iv, vld1q_f32(kernel.as_ptr().add(koff + 4)));
                        }
                    }
                }

                vst1q_f32(output.as_mut_ptr().add(out_off + co), acc0);
                vst1q_f32(output.as_mut_ptr().add(out_off + co + 4), acc1);
                co += 8;
            }

            while co + 4 <= c_out {
                let mut acc = vdupq_n_f32(0.0);
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        for ci in 0..c_in {
                            let iv = vdupq_n_f32(*input.get_unchecked(in_off + ci));
                            acc = vfmaq_f32(
                                acc,
                                iv,
                                vld1q_f32(
                                    kernel
                                        .as_ptr()
                                        .add((ky * 3 + kx) * c_in * c_out + ci * c_out + co),
                                ),
                            );
                        }
                    }
                }
                vst1q_f32(output.as_mut_ptr().add(out_off + co), acc);
                co += 4;
            }

            // Scalar tail for the remaining output channels.
            while co < c_out {
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        for ci in 0..c_in {
                            acc += input[(iy * w + ix) * c_in + ci]
                                * kernel[(ky * 3 + kx) * c_in * c_out + ci * c_out + co];
                        }
                    }
                }
                *output.get_unchecked_mut(out_off + co) = acc;
                co += 1;
            }

            ox += 1;
        }
    }
}

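/// Direct 3x3 convolution for a single NHWC batch using AVX + FMA.
///
/// Same structure as the NEON kernel: output channels are register-blocked
/// 16 wide (two `__m256` accumulators), then 8 wide, with a scalar tail, and
/// `stride_w == 1` pairs two adjacent output pixels.
///
/// # Safety
///
/// The caller must ensure AVX and FMA are available and that the slices are
/// large enough for the given geometry (see the debug asserts).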
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn conv2d_3x3_direct_avx(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    w: usize,
    c_in: usize,
    c_out: usize,
    out_h: usize,
    out_w: usize,
    stride_h: usize,
    stride_w: usize,
) {
    use std::arch::x86_64::*;

    // Mirror the NEON kernel's bounds checks; `get_unchecked` relies on them.
    debug_assert!(
        input.len() >= ((out_h.saturating_sub(1)) * stride_h + 3) * w * c_in,
        "conv2d_3x3_direct_avx: input too small"
    );
    debug_assert!(
        output.len() >= out_h * out_w * c_out,
        "conv2d_3x3_direct_avx: output too small"
    );
    debug_assert!(
        kernel.len() >= 3 * 3 * c_in * c_out,
        "conv2d_3x3_direct_avx: kernel too small"
    );

    for oy in 0..out_h {
        let iy_base = oy * stride_h;

        let mut ox = 0usize;
        // With stride 1, adjacent output pixels share two of their three
        // input columns, so compute them in pairs and reuse the kernel loads.
        if stride_w == 1 {
            while ox + 2 <= out_w {
                let ix_base = ox;
                let out_off_a = (oy * out_w + ox) * c_out;
                let out_off_b = out_off_a + c_out;

                let mut co = 0;
                while co + 16 <= c_out {
                    let mut acc_a0 = _mm256_setzero_ps();
                    let mut acc_a1 = _mm256_setzero_ps();
                    let mut acc_b0 = _mm256_setzero_ps();
                    let mut acc_b1 = _mm256_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0_lo = _mm256_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw0_hi = _mm256_loadu_ps(kernel.as_ptr().add(k0_off + 8));
                            let kw1_lo = _mm256_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw1_hi = _mm256_loadu_ps(kernel.as_ptr().add(k1_off + 8));
                            let kw2_lo = _mm256_loadu_ps(kernel.as_ptr().add(k2_off));
                            let kw2_hi = _mm256_loadu_ps(kernel.as_ptr().add(k2_off + 8));

                            // Pixel A uses input columns ix_base..=ix_base + 2.
                            let va0 = _mm256_set1_ps(in0);
                            let va1 = _mm256_set1_ps(in1);
                            let va2 = _mm256_set1_ps(in2);
                            acc_a0 = _mm256_fmadd_ps(va0, kw0_lo, acc_a0);
                            acc_a1 = _mm256_fmadd_ps(va0, kw0_hi, acc_a1);
                            acc_a0 = _mm256_fmadd_ps(va1, kw1_lo, acc_a0);
                            acc_a1 = _mm256_fmadd_ps(va1, kw1_hi, acc_a1);
                            acc_a0 = _mm256_fmadd_ps(va2, kw2_lo, acc_a0);
                            acc_a1 = _mm256_fmadd_ps(va2, kw2_hi, acc_a1);

                            // Pixel B uses input columns ix_base + 1..=ix_base + 3.
                            let vb3 = _mm256_set1_ps(in3);
                            acc_b0 = _mm256_fmadd_ps(va1, kw0_lo, acc_b0);
                            acc_b1 = _mm256_fmadd_ps(va1, kw0_hi, acc_b1);
                            acc_b0 = _mm256_fmadd_ps(va2, kw1_lo, acc_b0);
                            acc_b1 = _mm256_fmadd_ps(va2, kw1_hi, acc_b1);
                            acc_b0 = _mm256_fmadd_ps(vb3, kw2_lo, acc_b0);
                            acc_b1 = _mm256_fmadd_ps(vb3, kw2_hi, acc_b1);
                        }
                    }

                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a0);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_a + co + 8), acc_a1);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b0);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_b + co + 8), acc_b1);
                    co += 16;
                }

                while co + 8 <= c_out {
                    let mut acc_a = _mm256_setzero_ps();
                    let mut acc_b = _mm256_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0 = _mm256_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw1 = _mm256_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw2 = _mm256_loadu_ps(kernel.as_ptr().add(k2_off));

                            let va0 = _mm256_set1_ps(in0);
                            let va1 = _mm256_set1_ps(in1);
                            let va2 = _mm256_set1_ps(in2);
                            acc_a = _mm256_fmadd_ps(va0, kw0, acc_a);
                            acc_a = _mm256_fmadd_ps(va1, kw1, acc_a);
                            acc_a = _mm256_fmadd_ps(va2, kw2, acc_a);

                            let vb3 = _mm256_set1_ps(in3);
                            acc_b = _mm256_fmadd_ps(va1, kw0, acc_b);
                            acc_b = _mm256_fmadd_ps(va2, kw1, acc_b);
                            acc_b = _mm256_fmadd_ps(vb3, kw2, acc_b);
                        }
                    }

                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b);
                    co += 8;
                }

                // Scalar tail for the remaining output channels.
                while co < c_out {
                    let mut acc_a = 0.0f32;
                    let mut acc_b = 0.0f32;
                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = input[(row_base + ix_base) * c_in + ci];
                            let in1 = input[(row_base + ix_base + 1) * c_in + ci];
                            let in2 = input[(row_base + ix_base + 2) * c_in + ci];
                            let in3 = input[(row_base + ix_base + 3) * c_in + ci];
                            let k0 = kernel[ky * 3 * c_in * c_out + ci * c_out + co];
                            let k1 = kernel[(ky * 3 + 1) * c_in * c_out + ci * c_out + co];
                            let k2 = kernel[(ky * 3 + 2) * c_in * c_out + ci * c_out + co];
                            acc_a += in0 * k0 + in1 * k1 + in2 * k2;
                            acc_b += in1 * k0 + in2 * k1 + in3 * k2;
                        }
                    }
                    *output.get_unchecked_mut(out_off_a + co) = acc_a;
                    *output.get_unchecked_mut(out_off_b + co) = acc_b;
                    co += 1;
                }

                ox += 2;
            }
        }

        // Remainder (and all pixels when stride_w != 1): one pixel at a time.
        while ox < out_w {
            let ix_base = ox * stride_w;
            let out_off = (oy * out_w + ox) * c_out;

            let mut co = 0;
            while co + 8 <= c_out {
                let mut acc = _mm256_setzero_ps();

                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        let k_base = (ky * 3 + kx) * c_in * c_out;

                        for ci in 0..c_in {
                            let iv = _mm256_set1_ps(*input.get_unchecked(in_off + ci));
                            let koff = k_base + ci * c_out + co;
                            acc = _mm256_fmadd_ps(
                                iv,
                                _mm256_loadu_ps(kernel.as_ptr().add(koff)),
                                acc,
                            );
                        }
                    }
                }

                _mm256_storeu_ps(output.as_mut_ptr().add(out_off + co), acc);
                co += 8;
            }

            // Scalar tail for the remaining output channels.
            while co < c_out {
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        for ci in 0..c_in {
                            acc += input[(iy * w + ix) * c_in + ci]
                                * kernel[(ky * 3 + kx) * c_in * c_out + ci * c_out + co];
                        }
                    }
                }
                *output.get_unchecked_mut(out_off + co) = acc;
                co += 1;
            }

            ox += 1;
        }
    }
}

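/// Direct 3x3 convolution for a single NHWC batch using SSE + FMA, used when
/// AVX was not detected. Output channels are register-blocked 8 wide (two
/// `__m128` accumulators), then 4 wide, with a scalar tail.
///
/// # Safety
///
/// The caller must ensure SSE and FMA are available and that the slices are
/// large enough for the given geometry (see the debug asserts).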
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse", enable = "fma")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn conv2d_3x3_direct_sse(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    w: usize,
    c_in: usize,
    c_out: usize,
    out_h: usize,
    out_w: usize,
    stride_h: usize,
    stride_w: usize,
) {
    use std::arch::x86_64::*;

    // Mirror the NEON kernel's bounds checks; `get_unchecked` relies on them.
    debug_assert!(
        input.len() >= ((out_h.saturating_sub(1)) * stride_h + 3) * w * c_in,
        "conv2d_3x3_direct_sse: input too small"
    );
    debug_assert!(
        output.len() >= out_h * out_w * c_out,
        "conv2d_3x3_direct_sse: output too small"
    );
    debug_assert!(
        kernel.len() >= 3 * 3 * c_in * c_out,
        "conv2d_3x3_direct_sse: kernel too small"
    );

    for oy in 0..out_h {
        let iy_base = oy * stride_h;

        let mut ox = 0usize;
        // With stride 1, adjacent output pixels share two of their three
        // input columns, so compute them in pairs and reuse the kernel loads.
        if stride_w == 1 {
            while ox + 2 <= out_w {
                let ix_base = ox;
                let out_off_a = (oy * out_w + ox) * c_out;
                let out_off_b = out_off_a + c_out;

                let mut co = 0;
                while co + 8 <= c_out {
                    let mut acc_a0 = _mm_setzero_ps();
                    let mut acc_a1 = _mm_setzero_ps();
                    let mut acc_b0 = _mm_setzero_ps();
                    let mut acc_b1 = _mm_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0_lo = _mm_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw0_hi = _mm_loadu_ps(kernel.as_ptr().add(k0_off + 4));
                            let kw1_lo = _mm_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw1_hi = _mm_loadu_ps(kernel.as_ptr().add(k1_off + 4));
                            let kw2_lo = _mm_loadu_ps(kernel.as_ptr().add(k2_off));
                            let kw2_hi = _mm_loadu_ps(kernel.as_ptr().add(k2_off + 4));

                            // Pixel A uses input columns ix_base..=ix_base + 2.
                            let va0 = _mm_set1_ps(in0);
                            let va1 = _mm_set1_ps(in1);
                            let va2 = _mm_set1_ps(in2);
                            acc_a0 = _mm_fmadd_ps(va0, kw0_lo, acc_a0);
                            acc_a1 = _mm_fmadd_ps(va0, kw0_hi, acc_a1);
                            acc_a0 = _mm_fmadd_ps(va1, kw1_lo, acc_a0);
                            acc_a1 = _mm_fmadd_ps(va1, kw1_hi, acc_a1);
                            acc_a0 = _mm_fmadd_ps(va2, kw2_lo, acc_a0);
                            acc_a1 = _mm_fmadd_ps(va2, kw2_hi, acc_a1);

                            // Pixel B uses input columns ix_base + 1..=ix_base + 3.
                            let vb3 = _mm_set1_ps(in3);
                            acc_b0 = _mm_fmadd_ps(va1, kw0_lo, acc_b0);
                            acc_b1 = _mm_fmadd_ps(va1, kw0_hi, acc_b1);
                            acc_b0 = _mm_fmadd_ps(va2, kw1_lo, acc_b0);
                            acc_b1 = _mm_fmadd_ps(va2, kw1_hi, acc_b1);
                            acc_b0 = _mm_fmadd_ps(vb3, kw2_lo, acc_b0);
                            acc_b1 = _mm_fmadd_ps(vb3, kw2_hi, acc_b1);
                        }
                    }

                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a0);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_a + co + 4), acc_a1);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b0);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_b + co + 4), acc_b1);
                    co += 8;
                }

                while co + 4 <= c_out {
                    let mut acc_a = _mm_setzero_ps();
                    let mut acc_b = _mm_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0 = _mm_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw1 = _mm_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw2 = _mm_loadu_ps(kernel.as_ptr().add(k2_off));

                            let va0 = _mm_set1_ps(in0);
                            let va1 = _mm_set1_ps(in1);
                            let va2 = _mm_set1_ps(in2);
                            acc_a = _mm_fmadd_ps(va0, kw0, acc_a);
                            acc_a = _mm_fmadd_ps(va1, kw1, acc_a);
                            acc_a = _mm_fmadd_ps(va2, kw2, acc_a);

                            let vb3 = _mm_set1_ps(in3);
                            acc_b = _mm_fmadd_ps(va1, kw0, acc_b);
                            acc_b = _mm_fmadd_ps(va2, kw1, acc_b);
                            acc_b = _mm_fmadd_ps(vb3, kw2, acc_b);
                        }
                    }

                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b);
                    co += 4;
                }

                // Scalar tail for the remaining output channels.
                while co < c_out {
                    let mut acc_a = 0.0f32;
                    let mut acc_b = 0.0f32;
                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = input[(row_base + ix_base) * c_in + ci];
                            let in1 = input[(row_base + ix_base + 1) * c_in + ci];
                            let in2 = input[(row_base + ix_base + 2) * c_in + ci];
                            let in3 = input[(row_base + ix_base + 3) * c_in + ci];
                            let k0 = kernel[ky * 3 * c_in * c_out + ci * c_out + co];
                            let k1 = kernel[(ky * 3 + 1) * c_in * c_out + ci * c_out + co];
                            let k2 = kernel[(ky * 3 + 2) * c_in * c_out + ci * c_out + co];
                            acc_a += in0 * k0 + in1 * k1 + in2 * k2;
                            acc_b += in1 * k0 + in2 * k1 + in3 * k2;
                        }
                    }
                    *output.get_unchecked_mut(out_off_a + co) = acc_a;
                    *output.get_unchecked_mut(out_off_b + co) = acc_b;
                    co += 1;
                }

                ox += 2;
            }
        }

        // Remainder (and all pixels when stride_w != 1): one pixel at a time.
        while ox < out_w {
            let ix_base = ox * stride_w;
            let out_off = (oy * out_w + ox) * c_out;

            let mut co = 0;
            while co + 8 <= c_out {
                let mut acc0 = _mm_setzero_ps();
                let mut acc1 = _mm_setzero_ps();

                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        let k_base = (ky * 3 + kx) * c_in * c_out;

                        for ci in 0..c_in {
                            let iv = _mm_set1_ps(*input.get_unchecked(in_off + ci));
                            let koff = k_base + ci * c_out + co;
                            acc0 = _mm_fmadd_ps(iv, _mm_loadu_ps(kernel.as_ptr().add(koff)), acc0);
                            acc1 =
                                _mm_fmadd_ps(iv, _mm_loadu_ps(kernel.as_ptr().add(koff + 4)), acc1);
                        }
                    }
                }

                _mm_storeu_ps(output.as_mut_ptr().add(out_off + co), acc0);
                _mm_storeu_ps(output.as_mut_ptr().add(out_off + co + 4), acc1);
                co += 8;
            }

            while co + 4 <= c_out {
                let mut acc = _mm_setzero_ps();
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        for ci in 0..c_in {
                            let iv = _mm_set1_ps(*input.get_unchecked(in_off + ci));
                            acc = _mm_fmadd_ps(
                                iv,
                                _mm_loadu_ps(
                                    kernel
                                        .as_ptr()
                                        .add((ky * 3 + kx) * c_in * c_out + ci * c_out + co),
                                ),
                                acc,
                            );
                        }
                    }
                }
                _mm_storeu_ps(output.as_mut_ptr().add(out_off + co), acc);
                co += 4;
            }

            // Scalar tail for the remaining output channels.
            while co < c_out {
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        for ci in 0..c_in {
                            acc += input[(iy * w + ix) * c_in + ci]
                                * kernel[(ky * 3 + kx) * c_in * c_out + ci * c_out + co];
                        }
                    }
                }
                *output.get_unchecked_mut(out_off + co) = acc;
                co += 1;
            }

            ox += 1;
        }
    }
}

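/// `out[i] += kernel[i] * input_val` for every `i`, dispatching to NEON, AVX,
/// or SSE when available. Falls back to scalar code under Miri and for rows
/// too short to fill a SIMD register.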
#[allow(unsafe_code)]
fn conv_fma_row(out: &mut [f32], kernel: &[f32], input_val: f32) {
    let len = out.len();
    debug_assert_eq!(len, kernel.len());

    if cfg!(miri) || len < 4 {
        for i in 0..len {
            out[i] += kernel[i] * input_val;
        }
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe { conv_fma_neon(out, kernel, input_val) };
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe { conv_fma_avx(out, kernel, input_val) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe { conv_fma_sse(out, kernel, input_val) };
            return;
        }
    }

    // Scalar fallback when no SIMD feature was detected.
    for i in 0..len {
        out[i] += kernel[i] * input_val;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn conv_fma_neon(out: &mut [f32], kernel: &[f32], input_val: f32) {
    use std::arch::aarch64::*;
    let len = out.len();
    let op = out.as_mut_ptr();
    let kp = kernel.as_ptr();
    let v_input = vdupq_n_f32(input_val);
    let mut i = 0usize;
    while i + 4 <= len {
        let o = vld1q_f32(op.add(i));
        let k = vld1q_f32(kp.add(i));
        vst1q_f32(op.add(i), vfmaq_f32(o, k, v_input));
        i += 4;
    }
    while i < len {
        *op.add(i) += *kp.add(i) * input_val;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn conv_fma_sse(out: &mut [f32], kernel: &[f32], input_val: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = out.len();
    let op = out.as_mut_ptr();
    let kp = kernel.as_ptr();
    let v_input = _mm_set1_ps(input_val);
    let mut i = 0usize;
    while i + 4 <= len {
        let o = _mm_loadu_ps(op.add(i));
        let k = _mm_loadu_ps(kp.add(i));
        _mm_storeu_ps(op.add(i), _mm_add_ps(o, _mm_mul_ps(k, v_input)));
        i += 4;
    }
    while i < len {
        *op.add(i) += *kp.add(i) * input_val;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn conv_fma_avx(out: &mut [f32], kernel: &[f32], input_val: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = out.len();
    let op = out.as_mut_ptr();
    let kp = kernel.as_ptr();
    let v_input = _mm256_set1_ps(input_val);
    let mut i = 0usize;
    while i + 8 <= len {
        let o = _mm256_loadu_ps(op.add(i));
        let k = _mm256_loadu_ps(kp.add(i));
        _mm256_storeu_ps(op.add(i), _mm256_add_ps(o, _mm256_mul_ps(k, v_input)));
        i += 8;
    }
    // Handle the tail with the SSE helper (AVX implies SSE).
    if i < len {
        conv_fma_sse(&mut out[i..], &kernel[i..], input_val);
    }
}

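/// 3D convolution over a `[B, D, H, W, C_in]` input with a
/// `[KD, KH, KW, C_in, C_out]` kernel, zero padding `(pad_d, pad_h, pad_w)`,
/// and strides `(stride_d, stride_h, stride_w)`. Returns the output buffer
/// and its `[B, D', H', W', C_out]` shape, where
/// `D' = (D + 2 * pad_d - KD) / stride_d + 1` (likewise for H and W).
///
/// A minimal sketch of a call (illustrative values):
///
/// ```ignore
/// // 1x2x2x2x1 input of ones, 2x2x2x1x1 kernel of ones, stride 1, no
/// // padding => a single output voxel equal to the window sum, 8.0.
/// let input = vec![1.0f32; 8];
/// let kernel = vec![1.0f32; 8];
/// let (out, shape) = conv3d(
///     &input,
///     &[1, 2, 2, 2, 1],
///     &kernel,
///     &[2, 2, 2, 1, 1],
///     (1, 1, 1),
///     (0, 0, 0),
/// );
/// assert_eq!(shape, vec![1, 1, 1, 1, 1]);
/// assert_eq!(out, vec![8.0]);
/// ```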
pub fn conv3d(
    input: &[f32],
    input_shape: &[usize],
    kernel: &[f32],
    kernel_shape: &[usize],
    stride: (usize, usize, usize),
    padding: (usize, usize, usize),
) -> (Vec<f32>, Vec<usize>) {
    assert_eq!(
        input_shape.len(),
        5,
        "input_shape must be [B, D, H, W, C_in]"
    );
    assert_eq!(
        kernel_shape.len(),
        5,
        "kernel_shape must be [KD, KH, KW, C_in, C_out]"
    );

    let (batch, in_d, in_h, in_w, c_in) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        input_shape[4],
    );
    let (kd, kh, kw, k_cin, c_out) = (
        kernel_shape[0],
        kernel_shape[1],
        kernel_shape[2],
        kernel_shape[3],
        kernel_shape[4],
    );
    let (stride_d, stride_h, stride_w) = stride;
    let (pad_d, pad_h, pad_w) = padding;

    assert_eq!(c_in, k_cin, "input C_in must match kernel C_in");
    assert!(
        stride_d > 0 && stride_h > 0 && stride_w > 0,
        "strides must be positive"
    );
    assert_eq!(input.len(), batch * in_d * in_h * in_w * c_in);
    assert_eq!(kernel.len(), kd * kh * kw * c_in * c_out);
    // Guard against usize underflow in the output-size formulas below.
    assert!(
        kd <= in_d + 2 * pad_d && kh <= in_h + 2 * pad_h && kw <= in_w + 2 * pad_w,
        "kernel must fit within the padded input"
    );

    let out_d = (in_d + 2 * pad_d - kd) / stride_d + 1;
    let out_h = (in_h + 2 * pad_h - kh) / stride_h + 1;
    let out_w = (in_w + 2 * pad_w - kw) / stride_w + 1;

    let output_shape = vec![batch, out_d, out_h, out_w, c_out];
    let out_spatial = out_d * out_h * out_w;
    let output_len = batch * out_spatial * c_out;
    let k_spatial = kd * kh * kw;
    let col_k = k_spatial * c_in;

    // With BLAS, lower to im2col + GEMM per batch; otherwise fall through to
    // the direct nested-loop implementation below.
    #[cfg(feature = "blas")]
    let use_blas = !cfg!(miri) && batch == 1;
    #[cfg(not(feature = "blas"))]
    let use_blas = false;

    if use_blas {
        let mut output = vec![0.0f32; output_len];
        let in_hwc = in_h * in_w * c_in;
        let in_wc = in_w * c_in;

        for b in 0..batch {
            let b_in = b * in_d * in_hwc;
            // Zero-filled so padded (out-of-bounds) taps stay zero.
            let mut col = vec![0.0f32; out_spatial * col_k];
            let mut row = 0;
            for od in 0..out_d {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut col_idx = 0;
                        for fd in 0..kd {
                            let id_raw = od * stride_d + fd;
                            for fh in 0..kh {
                                let ih_raw = oh * stride_h + fh;
                                for fw in 0..kw {
                                    let iw_raw = ow * stride_w + fw;
                                    let in_bounds = id_raw >= pad_d
                                        && id_raw - pad_d < in_d
                                        && ih_raw >= pad_h
                                        && ih_raw - pad_h < in_h
                                        && iw_raw >= pad_w
                                        && iw_raw - pad_w < in_w;
                                    if in_bounds {
                                        let id = id_raw - pad_d;
                                        let ih = ih_raw - pad_h;
                                        let iw = iw_raw - pad_w;
                                        let base = b_in + id * in_hwc + ih * in_wc + iw * c_in;
                                        col[row * col_k + col_idx..row * col_k + col_idx + c_in]
                                            .copy_from_slice(&input[base..base + c_in]);
                                    }
                                    col_idx += c_in;
                                }
                            }
                        }
                        row += 1;
                    }
                }
            }

            let b_out = b * out_spatial * c_out;
            super::matmul::blas_sgemm(
                &col,
                kernel,
                &mut output[b_out..b_out + out_spatial * c_out],
                out_spatial,
                col_k,
                c_out,
            );
        }
        return (output, output_shape);
    }

    let mut output = vec![0.0f32; output_len];
    let in_dhwc = in_d * in_h * in_w * c_in;
    let in_hwc = in_h * in_w * c_in;
    let in_wc = in_w * c_in;
    let k_hwcico = kh * kw * c_in * c_out;
    let k_wcico = kw * c_in * c_out;
    let k_cico = c_in * c_out;
    let out_dhwco = out_d * out_h * out_w * c_out;
    let out_hwco = out_h * out_w * c_out;
    let out_wco = out_w * c_out;

    for b in 0..batch {
        let b_in = b * in_dhwc;
        let b_out = b * out_dhwco;
        for od in 0..out_d {
            for oh in 0..out_h {
                for ow in 0..out_w {
                    let out_base = b_out + od * out_hwco + oh * out_wco + ow * c_out;
                    for fd in 0..kd {
                        let id = od * stride_d + fd;
                        if id < pad_d || id - pad_d >= in_d {
                            continue;
                        }
                        let id = id - pad_d;
                        for fh in 0..kh {
                            let ih = oh * stride_h + fh;
                            if ih < pad_h || ih - pad_h >= in_h {
                                continue;
                            }
                            let ih = ih - pad_h;
                            for fw in 0..kw {
                                let iw = ow * stride_w + fw;
                                if iw < pad_w || iw - pad_w >= in_w {
                                    continue;
                                }
                                let iw = iw - pad_w;
                                let in_base = b_in + id * in_hwc + ih * in_wc + iw * c_in;
                                let k_base = fd * k_hwcico + fh * k_wcico + fw * k_cico;
                                for ci in 0..c_in {
                                    let input_val = input[in_base + ci];
                                    let k_offset = k_base + ci * c_out;
                                    for co in 0..c_out {
                                        output[out_base + co] += input_val * kernel[k_offset + co];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    (output, output_shape)
}

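/// Computes one output row (fixed batch index and output y) of the depthwise
/// convolution. Output channel `co` reads input channel
/// `co / depth_multiplier` and filter index `co % depth_multiplier`.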
fn depthwise_conv2d_nhwc_row(
    input: &[f32],
    kernel: &[f32],
    bias: Option<&[f32]>,
    plan: DepthwiseConv2dPlan,
    row_idx: usize,
    out_row: &mut [f32],
) {
    let batch_idx = row_idx / plan.out_h;
    let out_y = row_idx % plan.out_h;
    let in_y0 = out_y * plan.stride_h;
    let batch_input_base = batch_idx * plan.in_h * plan.in_w * plan.channels;

    for out_x in 0..plan.out_w {
        let in_x0 = out_x * plan.stride_w;
        let out_cell_base = out_x * plan.out_channels;

        for out_channel in 0..plan.out_channels {
            let mut acc = bias.map_or(0.0, |bias_values| bias_values[out_channel]);
            let in_channel = out_channel / plan.depth_multiplier;
            let depth_index = out_channel % plan.depth_multiplier;

            for ky in 0..plan.kernel_h {
                let in_y = in_y0 + ky;
                let input_row_base = batch_input_base + (in_y * plan.in_w + in_x0) * plan.channels;
                let kernel_row_base = ky * plan.kernel_w * plan.channels * plan.depth_multiplier;

                for kx in 0..plan.kernel_w {
                    let input_value = input[input_row_base + kx * plan.channels + in_channel];
                    let kernel_index = kernel_row_base
                        + kx * plan.channels * plan.depth_multiplier
                        + in_channel * plan.depth_multiplier
                        + depth_index;
                    acc += input_value * kernel[kernel_index];
                }
            }

            out_row[out_cell_base + out_channel] = acc;
        }
    }
}

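/// Unfolds an NHWC image into a `[out_h * out_w, kh * kw * c]` matrix so that
/// convolution becomes a single GEMM: each row holds the receptive field of
/// one output pixel, with channels contiguous per kernel tap.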
1550#[cfg(feature = "blas")]
1559fn im2col_nhwc(
1560 input: &[f32],
1561 in_w: usize,
1562 c: usize,
1563 kh: usize,
1564 kw: usize,
1565 stride_h: usize,
1566 stride_w: usize,
1567 out_h: usize,
1568 out_w: usize,
1569 col: &mut [f32],
1570) {
1571 let k = kh * kw * c;
1572 for oy in 0..out_h {
1573 for ox in 0..out_w {
1574 let row = oy * out_w + ox;
1575 let row_off = row * k;
1576 for ky in 0..kh {
1577 let iy = oy * stride_h + ky;
1578 for kx in 0..kw {
1579 let ix = ox * stride_w + kx;
1580 let src_off = (iy * in_w + ix) * c;
1581 let dst_off = row_off + (ky * kw + kx) * c;
1582 col[dst_off..dst_off + c].copy_from_slice(&input[src_off..src_off + c]);
1583 }
1584 }
1585 }
1586 }
1587}
1588
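/// Single-batch convolution via im2col + SGEMM: the unfolded input
/// (`[m, k] = [out_h * out_w, kh * kw * c_in]`) is multiplied by the kernel
/// viewed as a `[k, n] = [kh * kw * c_in, out_channels]` matrix, then the
/// bias is added per output pixel.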
1589#[cfg(feature = "blas")]
1595fn conv2d_im2col_gemm(
1596 plan: &Conv2dPlan,
1597 input: &[f32],
1598 kernel: &[f32],
1599 bias: Option<&[f32]>,
1600) -> Result<Tensor, KernelError> {
1601 let out_h = plan.out_h;
1602 let out_w = plan.out_w;
1603 let k = plan.kernel_h * plan.kernel_w * plan.in_channels;
1604 let m = out_h * out_w;
1605 let n = plan.out_channels;
1606
1607 #[allow(unsafe_code)]
1611 let mut col = AlignedVec::<f32>::uninitialized(m * k);
1612 im2col_nhwc(
1613 input,
1614 plan.in_w,
1615 plan.in_channels,
1616 plan.kernel_h,
1617 plan.kernel_w,
1618 plan.stride_h,
1619 plan.stride_w,
1620 out_h,
1621 out_w,
1622 &mut col,
1623 );
1624
1625 #[allow(unsafe_code)]
1629 let mut output = AlignedVec::<f32>::uninitialized(m * n);
1630 super::matmul::blas_sgemm(&col, kernel, &mut output, m, k, n);
1631
1632 if let Some(bias) = bias {
1634 for row in 0..m {
1635 let row_off = row * n;
1636 for c in 0..n {
1637 output[row_off + c] += bias[c];
1638 }
1639 }
1640 }
1641
1642 Tensor::from_aligned(vec![1, out_h, out_w, n], output).map_err(Into::into)
1643}