1use std::num::NonZeroUsize;
2
3use rayon::{ThreadPool, ThreadPoolBuilder};
4use yscv_tensor::Tensor;
5
6use super::{
7 error::KernelError,
8 ops::{
9 self, BatchNorm2dTensors, Conv2dSpec, DepthwiseConv2dSpec, GroupNorm2dTensors,
10 LayerNormLastDimTensors, ParallelElementwiseConfig, ParallelMatmulConfig, Pool2dSpec,
11 RmsNormLastDimTensors, SeparableConv2dKernels, SeparableConv2dSpec,
12 },
13};
14
/// Borrowed parameter bundle for a depthwise-separable convolution
/// (depthwise stage followed by a 1x1 pointwise stage).
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2dParams<'a> {
    // Kernel applied per input channel in the depthwise stage.
    pub depthwise_kernel: &'a Tensor,
    // Optional bias added after the depthwise stage.
    pub depthwise_bias: Option<&'a Tensor>,
    // 1x1 kernel mixing channels in the pointwise stage.
    pub pointwise_kernel: &'a Tensor,
    // Optional bias added after the pointwise stage.
    pub pointwise_bias: Option<&'a Tensor>,
}
24
/// Borrowed parameter bundle for inference-time 2-D batch normalization
/// (`gamma * (x - mean) / sqrt(variance + epsilon) + beta`, per channel).
#[derive(Debug, Clone, Copy)]
pub struct BatchNorm2dParams<'a> {
    // Per-channel scale.
    pub gamma: &'a Tensor,
    // Per-channel shift.
    pub beta: &'a Tensor,
    // Running (pre-computed) per-channel mean.
    pub mean: &'a Tensor,
    // Running (pre-computed) per-channel variance.
    pub variance: &'a Tensor,
    // Small constant added to the variance for numerical stability.
    pub epsilon: f32,
}
34
/// Borrowed parameter bundle for layer normalization over the last dimension.
#[derive(Debug, Clone, Copy)]
pub struct LayerNormLastDimParams<'a> {
    // Learned scale applied after normalization.
    pub gamma: &'a Tensor,
    // Learned shift applied after normalization.
    pub beta: &'a Tensor,
    // Small constant added to the variance for numerical stability.
    pub epsilon: f32,
}
42
/// Borrowed parameter bundle for group normalization on NHWC tensors.
#[derive(Debug, Clone, Copy)]
pub struct GroupNormNhwcParams<'a> {
    // Per-channel scale.
    pub gamma: &'a Tensor,
    // Per-channel shift.
    pub beta: &'a Tensor,
    // Number of channel groups normalized together.
    pub num_groups: usize,
    // Small constant added to the variance for numerical stability.
    pub epsilon: f32,
}
51
/// Borrowed parameter bundle for RMS normalization over the last dimension.
#[derive(Debug, Clone, Copy)]
pub struct RmsNormLastDimParams<'a> {
    // Learned scale applied after normalization.
    pub gamma: &'a Tensor,
    // Small constant added inside the root-mean-square for stability.
    pub epsilon: f32,
}
58
/// Forward-pass tensor kernels.
///
/// Required methods cover element-wise arithmetic, activations,
/// softmax-family reductions, normalization, pooling, convolution
/// (all image ops use NHWC layout, per the method names) and 2-D matmul.
/// The provided methods at the bottom delegate to `Tensor`'s own
/// implementations and may be overridden by faster backends.
pub trait Backend {
    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    fn relu(&self, input: &Tensor) -> Tensor;
    fn sigmoid(&self, input: &Tensor) -> Tensor;
    fn exp(&self, input: &Tensor) -> Tensor;
    // Named `tanh_act` (not `tanh`) to avoid clashing with f32::tanh-style names.
    fn tanh_act(&self, input: &Tensor) -> Tensor;
    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    fn layer_norm_last_dim(
        &self,
        input: &Tensor,
        params: LayerNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn max_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn avg_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn depthwise_conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn separable_conv2d_nhwc(
        &self,
        input: &Tensor,
        params: SeparableConv2dParams<'_>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn batch_norm2d_nhwc(
        &self,
        input: &Tensor,
        params: BatchNorm2dParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn group_norm_nhwc(
        &self,
        input: &Tensor,
        params: GroupNormNhwcParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn rms_norm_last_dim(
        &self,
        input: &Tensor,
        params: RmsNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;

    // --- Provided defaults: thin delegations to Tensor methods. ---

    /// Element-wise negation.
    fn neg(&self, input: &Tensor) -> Tensor {
        input.neg()
    }

    /// Element-wise division; errors are converted into `KernelError`.
    fn div(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        lhs.div(rhs).map_err(Into::into)
    }

    /// Element-wise square root.
    fn sqrt(&self, input: &Tensor) -> Tensor {
        input.sqrt()
    }

    /// 2-D matrix transpose; errors (e.g. non-2-D input) become `KernelError`.
    fn transpose_2d(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        input.transpose_2d().map_err(Into::into)
    }

    /// Sum of all elements, returned as a scalar tensor.
    fn sum_all(&self, input: &Tensor) -> Tensor {
        Tensor::scalar(input.sum())
    }

    /// Multiply every element by `scalar`.
    fn mul_scalar(&self, input: &Tensor, scalar: f32) -> Tensor {
        input.scale(scalar)
    }

    /// Element-wise reciprocal (1/x).
    fn reciprocal(&self, input: &Tensor) -> Tensor {
        input.reciprocal()
    }
}
169
170pub trait BackwardOps: Backend {
176 fn relu_backward(
178 &self,
179 upstream: &Tensor,
180 forward_input: &Tensor,
181 ) -> Result<Tensor, KernelError> {
182 let u = upstream.data();
183 let f = forward_input.data();
184 let out: Vec<f32> = u
185 .iter()
186 .zip(f.iter())
187 .map(|(&u, &x)| if x > 0.0 { u } else { 0.0 })
188 .collect();
189 Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
190 }
191
192 fn sigmoid_backward(
194 &self,
195 upstream: &Tensor,
196 forward_output: &Tensor,
197 ) -> Result<Tensor, KernelError> {
198 let u = upstream.data();
199 let s = forward_output.data();
200 let out: Vec<f32> = u
201 .iter()
202 .zip(s.iter())
203 .map(|(&u, &s)| u * s * (1.0 - s))
204 .collect();
205 Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
206 }
207
208 fn tanh_backward(
210 &self,
211 upstream: &Tensor,
212 forward_output: &Tensor,
213 ) -> Result<Tensor, KernelError> {
214 let u = upstream.data();
215 let t = forward_output.data();
216 let out: Vec<f32> = u
217 .iter()
218 .zip(t.iter())
219 .map(|(&u, &t)| u * (1.0 - t * t))
220 .collect();
221 Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
222 }
223
224 fn exp_backward(
226 &self,
227 upstream: &Tensor,
228 forward_output: &Tensor,
229 ) -> Result<Tensor, KernelError> {
230 let u = upstream.data();
231 let e = forward_output.data();
232 let out: Vec<f32> = u.iter().zip(e.iter()).map(|(&u, &e)| u * e).collect();
233 Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
234 }
235
236 fn reduce_sum_backward(
238 &self,
239 upstream: &Tensor,
240 original_shape: &[usize],
241 ) -> Result<Tensor, KernelError> {
242 let grad_val = upstream.data()[0];
243 let len: usize = original_shape.iter().product();
244 let out = vec![grad_val; len];
245 Tensor::from_vec(original_shape.to_vec(), out).map_err(Into::into)
246 }
247
248 fn matmul_backward(
250 &self,
251 upstream: &Tensor,
252 lhs: &Tensor,
253 rhs: &Tensor,
254 ) -> Result<(Tensor, Tensor), KernelError> {
255 let rt = self.transpose_2d(rhs)?;
256 let lt = self.transpose_2d(lhs)?;
257 let grad_lhs = self.matmul_2d(upstream, &rt)?;
258 let grad_rhs = self.matmul_2d(<, upstream)?;
259 Ok((grad_lhs, grad_rhs))
260 }
261
262 fn add_backward(
264 &self,
265 upstream: &Tensor,
266 _lhs: &Tensor,
267 _rhs: &Tensor,
268 ) -> Result<(Tensor, Tensor), KernelError> {
269 Ok((upstream.clone(), upstream.clone()))
270 }
271
272 fn sub_backward(
274 &self,
275 upstream: &Tensor,
276 _lhs: &Tensor,
277 _rhs: &Tensor,
278 ) -> Result<(Tensor, Tensor), KernelError> {
279 Ok((upstream.clone(), self.neg(upstream)))
280 }
281
282 fn mul_backward(
284 &self,
285 upstream: &Tensor,
286 lhs: &Tensor,
287 rhs: &Tensor,
288 ) -> Result<(Tensor, Tensor), KernelError> {
289 let grad_lhs = self.mul(upstream, rhs)?;
290 let grad_rhs = self.mul(upstream, lhs)?;
291 Ok((grad_lhs, grad_rhs))
292 }
293
294 fn conv2d_input_backward(
298 &self,
299 upstream: &Tensor,
300 kernel: &Tensor,
301 input_shape: &[usize],
302 stride_h: usize,
303 stride_w: usize,
304 ) -> Result<Tensor, KernelError> {
305 let us = upstream.shape();
307 let ks = kernel.shape();
308 if us.len() != 4 || ks.len() != 4 || input_shape.len() != 4 {
309 return Err(KernelError::InvalidConvRank {
310 input_rank: input_shape.len(),
311 kernel_rank: ks.len(),
312 });
313 }
314 let (n, ih, iw, ic) = (
315 input_shape[0],
316 input_shape[1],
317 input_shape[2],
318 input_shape[3],
319 );
320 let (_n, oh, ow, oc) = (us[0], us[1], us[2], us[3]);
321 let (kh, kw) = (ks[0], ks[1]);
322
323 let u_data = upstream.data();
324 let k_data = kernel.data();
325 let mut grad_input = vec![0.0f32; n * ih * iw * ic];
326
327 for b in 0..n {
328 for oy in 0..oh {
329 for ox in 0..ow {
330 for co in 0..oc {
331 let g = u_data[((b * oh + oy) * ow + ox) * oc + co];
332 if g == 0.0 {
333 continue;
334 }
335 for ky in 0..kh {
336 for kx in 0..kw {
337 let iy = oy * stride_h + ky;
338 let ix = ox * stride_w + kx;
339 if iy < ih && ix < iw {
340 for ci in 0..ic {
341 let k_val = k_data[((ky * kw + kx) * ic + ci) * oc + co];
342 grad_input[((b * ih + iy) * iw + ix) * ic + ci] +=
343 g * k_val;
344 }
345 }
346 }
347 }
348 }
349 }
350 }
351 }
352
353 Tensor::from_vec(input_shape.to_vec(), grad_input).map_err(Into::into)
354 }
355}
356
/// Zero-sized, single-threaded reference backend: every op runs sequentially
/// with parallelism disabled.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;
363
364impl BackwardOps for CpuBackend {}
365
// Sequential implementation: every `ops::*_with_config[_and_pool]` call is
// made with `ParallelElementwiseConfig::disabled()` and no rayon pool
// (`None`), so the work runs on the calling thread.
impl Backend for CpuBackend {
    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::add_with_config(lhs, rhs, ParallelElementwiseConfig::disabled())
    }

    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::sub_with_config(lhs, rhs, ParallelElementwiseConfig::disabled())
    }

    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::mul_with_config(lhs, rhs, ParallelElementwiseConfig::disabled())
    }

    fn relu(&self, input: &Tensor) -> Tensor {
        ops::relu(input)
    }

    fn sigmoid(&self, input: &Tensor) -> Tensor {
        ops::sigmoid(input)
    }

    fn exp(&self, input: &Tensor) -> Tensor {
        ops::exp(input)
    }

    fn tanh_act(&self, input: &Tensor) -> Tensor {
        ops::tanh_act(input)
    }

    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        ops::softmax_last_dim_with_config_and_pool(
            input,
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        ops::log_softmax_last_dim_with_config_and_pool(
            input,
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        ops::logsumexp_last_dim_with_config_and_pool(
            input,
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn layer_norm_last_dim(
        &self,
        input: &Tensor,
        params: LayerNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError> {
        // Repackage the public params struct into the ops-layer tensor bundle.
        ops::layer_norm_last_dim_with_config_and_pool(
            input,
            LayerNormLastDimTensors {
                gamma: params.gamma,
                beta: params.beta,
                epsilon: params.epsilon,
            },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn max_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::max_pool2d_nhwc_with_config_and_pool(
            input,
            Pool2dSpec {
                kernel_h,
                kernel_w,
                stride_h,
                stride_w,
            },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn avg_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::avg_pool2d_nhwc_with_config_and_pool(
            input,
            Pool2dSpec {
                kernel_h,
                kernel_w,
                stride_h,
                stride_w,
            },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::conv2d_nhwc_with_config_and_pool(
            input,
            kernel,
            bias,
            Conv2dSpec { stride_h, stride_w },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn depthwise_conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::depthwise_conv2d_nhwc_with_config_and_pool(
            input,
            kernel,
            bias,
            DepthwiseConv2dSpec { stride_h, stride_w },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn separable_conv2d_nhwc(
        &self,
        input: &Tensor,
        params: SeparableConv2dParams<'_>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::separable_conv2d_nhwc_with_config_and_pool(
            input,
            SeparableConv2dKernels {
                depthwise_kernel: params.depthwise_kernel,
                depthwise_bias: params.depthwise_bias,
                pointwise_kernel: params.pointwise_kernel,
                pointwise_bias: params.pointwise_bias,
            },
            SeparableConv2dSpec { stride_h, stride_w },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn batch_norm2d_nhwc(
        &self,
        input: &Tensor,
        params: BatchNorm2dParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::batch_norm2d_nhwc_with_config_and_pool(
            input,
            BatchNorm2dTensors {
                gamma: params.gamma,
                beta: params.beta,
                mean: params.mean,
                variance: params.variance,
                epsilon: params.epsilon,
            },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn group_norm_nhwc(
        &self,
        input: &Tensor,
        params: GroupNormNhwcParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::group_norm_nhwc_with_config_and_pool(
            input,
            GroupNorm2dTensors {
                gamma: params.gamma,
                beta: params.beta,
                num_groups: params.num_groups,
                epsilon: params.epsilon,
            },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn rms_norm_last_dim(
        &self,
        input: &Tensor,
        params: RmsNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::rms_norm_last_dim_with_config_and_pool(
            input,
            RmsNormLastDimTensors {
                gamma: params.gamma,
                epsilon: params.epsilon,
            },
            ParallelElementwiseConfig::disabled(),
            None,
        )
    }

    fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::matmul_2d(lhs, rhs)
    }
}
592
/// Multi-threaded CPU backend that owns a dedicated rayon [`ThreadPool`]
/// and per-op-family parallelism configuration.
#[derive(Debug)]
pub struct ThreadedCpuBackend {
    // Parallelism settings for 2-D matmul.
    matmul_config: ParallelMatmulConfig,
    // Parallelism settings for everything else (elementwise, conv, norm, pool).
    elementwise_config: ParallelElementwiseConfig,
    // Dedicated pool; ops run inside it rather than rayon's global pool.
    thread_pool: ThreadPool,
}
600
/// Combined configuration for [`ThreadedCpuBackend`]; `Default` uses each
/// sub-config's own default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ThreadedCpuBackendConfig {
    pub matmul: ParallelMatmulConfig,
    pub elementwise: ParallelElementwiseConfig,
}
607
impl ThreadedCpuBackend {
    /// Builds a backend with `num_threads` workers and default parallelism
    /// settings for both matmul and elementwise ops.
    ///
    /// # Errors
    /// Returns [`KernelError::ThreadPoolBuild`] if the rayon pool cannot
    /// be created.
    pub fn new(num_threads: NonZeroUsize) -> Result<Self, KernelError> {
        Self::with_full_config(num_threads, ThreadedCpuBackendConfig::default())
    }

    /// Like [`Self::new`] but with an explicit matmul configuration
    /// (elementwise settings stay at their default).
    ///
    /// # Errors
    /// Returns [`KernelError::ThreadPoolBuild`] if the rayon pool cannot
    /// be created.
    pub fn with_config(
        num_threads: NonZeroUsize,
        matmul_config: ParallelMatmulConfig,
    ) -> Result<Self, KernelError> {
        Self::with_full_config(
            num_threads,
            ThreadedCpuBackendConfig {
                matmul: matmul_config,
                elementwise: ParallelElementwiseConfig::default(),
            },
        )
    }

    /// Fully explicit constructor: builds the dedicated rayon pool and stores
    /// both configurations.
    ///
    /// # Errors
    /// Returns [`KernelError::ThreadPoolBuild`] (carrying rayon's message) if
    /// pool construction fails.
    pub fn with_full_config(
        num_threads: NonZeroUsize,
        config: ThreadedCpuBackendConfig,
    ) -> Result<Self, KernelError> {
        let thread_pool = ThreadPoolBuilder::new()
            .num_threads(num_threads.get())
            .build()
            .map_err(|error| KernelError::ThreadPoolBuild {
                message: error.to_string(),
            })?;
        Ok(Self {
            matmul_config: config.matmul,
            elementwise_config: config.elementwise,
            thread_pool,
        })
    }

    /// The matmul parallelism configuration this backend was built with.
    pub const fn matmul_config(&self) -> ParallelMatmulConfig {
        self.matmul_config
    }

    /// The elementwise parallelism configuration this backend was built with.
    pub const fn elementwise_config(&self) -> ParallelElementwiseConfig {
        self.elementwise_config
    }
}
656
657impl BackwardOps for ThreadedCpuBackend {}
658
// Parallel implementation: mirrors `impl Backend for CpuBackend` but passes
// the stored config and `Some(&self.thread_pool)` so the ops layer can run
// inside this backend's dedicated rayon pool.
impl Backend for ThreadedCpuBackend {
    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::add_with_config_and_pool(lhs, rhs, self.elementwise_config, Some(&self.thread_pool))
    }

    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::sub_with_config_and_pool(lhs, rhs, self.elementwise_config, Some(&self.thread_pool))
    }

    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::mul_with_config_and_pool(lhs, rhs, self.elementwise_config, Some(&self.thread_pool))
    }

    fn relu(&self, input: &Tensor) -> Tensor {
        ops::relu_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
    }

    fn sigmoid(&self, input: &Tensor) -> Tensor {
        ops::sigmoid_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
    }

    fn exp(&self, input: &Tensor) -> Tensor {
        ops::exp_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
    }

    fn tanh_act(&self, input: &Tensor) -> Tensor {
        ops::tanh_act_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
    }

    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        ops::softmax_last_dim_with_config_and_pool(
            input,
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        ops::log_softmax_last_dim_with_config_and_pool(
            input,
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        ops::logsumexp_last_dim_with_config_and_pool(
            input,
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn layer_norm_last_dim(
        &self,
        input: &Tensor,
        params: LayerNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::layer_norm_last_dim_with_config_and_pool(
            input,
            LayerNormLastDimTensors {
                gamma: params.gamma,
                beta: params.beta,
                epsilon: params.epsilon,
            },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn max_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::max_pool2d_nhwc_with_config_and_pool(
            input,
            Pool2dSpec {
                kernel_h,
                kernel_w,
                stride_h,
                stride_w,
            },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn avg_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::avg_pool2d_nhwc_with_config_and_pool(
            input,
            Pool2dSpec {
                kernel_h,
                kernel_w,
                stride_h,
                stride_w,
            },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::conv2d_nhwc_with_config_and_pool(
            input,
            kernel,
            bias,
            Conv2dSpec { stride_h, stride_w },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn depthwise_conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::depthwise_conv2d_nhwc_with_config_and_pool(
            input,
            kernel,
            bias,
            DepthwiseConv2dSpec { stride_h, stride_w },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn separable_conv2d_nhwc(
        &self,
        input: &Tensor,
        params: SeparableConv2dParams<'_>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        ops::separable_conv2d_nhwc_with_config_and_pool(
            input,
            SeparableConv2dKernels {
                depthwise_kernel: params.depthwise_kernel,
                depthwise_bias: params.depthwise_bias,
                pointwise_kernel: params.pointwise_kernel,
                pointwise_bias: params.pointwise_bias,
            },
            SeparableConv2dSpec { stride_h, stride_w },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn batch_norm2d_nhwc(
        &self,
        input: &Tensor,
        params: BatchNorm2dParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::batch_norm2d_nhwc_with_config_and_pool(
            input,
            BatchNorm2dTensors {
                gamma: params.gamma,
                beta: params.beta,
                mean: params.mean,
                variance: params.variance,
                epsilon: params.epsilon,
            },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn group_norm_nhwc(
        &self,
        input: &Tensor,
        params: GroupNormNhwcParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::group_norm_nhwc_with_config_and_pool(
            input,
            GroupNorm2dTensors {
                gamma: params.gamma,
                beta: params.beta,
                num_groups: params.num_groups,
                epsilon: params.epsilon,
            },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    fn rms_norm_last_dim(
        &self,
        input: &Tensor,
        params: RmsNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError> {
        ops::rms_norm_last_dim_with_config_and_pool(
            input,
            RmsNormLastDimTensors {
                gamma: params.gamma,
                epsilon: params.epsilon,
            },
            self.elementwise_config,
            Some(&self.thread_pool),
        )
    }

    // Matmul uses its own (matmul-specific) parallel configuration.
    fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        ops::matmul_2d_with_config_and_pool(lhs, rhs, self.matmul_config, Some(&self.thread_pool))
    }
}
885
// --- Module-level convenience wrappers: elementwise arithmetic. ---
// The plain variants run on the sequential `CpuBackend`; the `_with_config`
// variants forward an explicit parallelism config (no thread pool).

/// Element-wise addition on the sequential CPU backend.
pub fn add(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.add(lhs, rhs)
}

/// Element-wise addition with an explicit parallelism config.
pub fn add_with_config(
    lhs: &Tensor,
    rhs: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::add_with_config(lhs, rhs, config)
}

/// Element-wise subtraction on the sequential CPU backend.
pub fn sub(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.sub(lhs, rhs)
}

/// Element-wise subtraction with an explicit parallelism config.
pub fn sub_with_config(
    lhs: &Tensor,
    rhs: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::sub_with_config(lhs, rhs, config)
}

/// Element-wise multiplication on the sequential CPU backend.
pub fn mul(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.mul(lhs, rhs)
}

/// Element-wise multiplication with an explicit parallelism config.
pub fn mul_with_config(
    lhs: &Tensor,
    rhs: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::mul_with_config(lhs, rhs, config)
}
927
// --- Module-level convenience wrappers: activations. ---

/// ReLU on the sequential CPU backend.
pub fn relu(input: &Tensor) -> Tensor {
    CpuBackend.relu(input)
}

/// In-place ReLU: mutates `tensor` instead of allocating a new one.
pub fn relu_inplace(tensor: &mut Tensor) {
    ops::relu_inplace(tensor);
}

/// ReLU with an explicit parallelism config.
pub fn relu_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    ops::relu_with_config(input, config)
}

/// Sigmoid on the sequential CPU backend.
pub fn sigmoid(input: &Tensor) -> Tensor {
    CpuBackend.sigmoid(input)
}

/// Sigmoid with an explicit parallelism config.
pub fn sigmoid_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    ops::sigmoid_with_config(input, config)
}

/// Element-wise exponential on the sequential CPU backend.
pub fn exp(input: &Tensor) -> Tensor {
    CpuBackend.exp(input)
}

/// Element-wise exponential with an explicit parallelism config.
pub fn exp_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    ops::exp_with_config(input, config)
}

/// Tanh activation on the sequential CPU backend.
pub fn tanh_act(input: &Tensor) -> Tensor {
    CpuBackend.tanh_act(input)
}

/// Tanh activation with an explicit parallelism config.
pub fn tanh_act_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    ops::tanh_act_with_config(input, config)
}

/// GELU activation (sequential; delegates directly to the ops layer).
pub fn gelu(input: &Tensor) -> Tensor {
    ops::gelu(input)
}

/// SiLU (swish) activation (sequential; delegates directly to the ops layer).
pub fn silu(input: &Tensor) -> Tensor {
    ops::silu(input)
}

/// Mish activation (sequential; delegates directly to the ops layer).
pub fn mish(input: &Tensor) -> Tensor {
    ops::mish(input)
}
987
// --- Module-level convenience wrappers: softmax family and layer norm. ---

/// Softmax over the last dimension on the sequential CPU backend.
pub fn softmax_last_dim(input: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.softmax_last_dim(input)
}

/// Softmax over the last dimension with an explicit parallelism config.
pub fn softmax_last_dim_with_config(
    input: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::softmax_last_dim_with_config_and_pool(input, config, None)
}

/// Log-softmax over the last dimension on the sequential CPU backend.
pub fn log_softmax_last_dim(input: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.log_softmax_last_dim(input)
}

/// Log-softmax over the last dimension with an explicit parallelism config.
pub fn log_softmax_last_dim_with_config(
    input: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::log_softmax_last_dim_with_config_and_pool(input, config, None)
}

/// Log-sum-exp over the last dimension on the sequential CPU backend.
pub fn logsumexp_last_dim(input: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.logsumexp_last_dim(input)
}

/// Log-sum-exp over the last dimension with an explicit parallelism config.
pub fn logsumexp_last_dim_with_config(
    input: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::logsumexp_last_dim_with_config_and_pool(input, config, None)
}

/// Layer norm over the last dimension on the sequential CPU backend.
pub fn layer_norm_last_dim(
    input: &Tensor,
    params: LayerNormLastDimParams<'_>,
) -> Result<Tensor, KernelError> {
    CpuBackend.layer_norm_last_dim(input, params)
}

/// Layer norm over the last dimension with an explicit parallelism config.
pub fn layer_norm_last_dim_with_config(
    input: &Tensor,
    params: LayerNormLastDimParams<'_>,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::layer_norm_last_dim_with_config_and_pool(
        input,
        LayerNormLastDimTensors {
            gamma: params.gamma,
            beta: params.beta,
            epsilon: params.epsilon,
        },
        config,
        None,
    )
}
1054
// --- Module-level convenience wrappers: pooling and standard convolution. ---

/// NHWC max pooling on the sequential CPU backend.
pub fn max_pool2d_nhwc(
    input: &Tensor,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError> {
    CpuBackend.max_pool2d_nhwc(input, kernel_h, kernel_w, stride_h, stride_w)
}

/// NHWC max pooling with an explicit parallelism config.
pub fn max_pool2d_nhwc_with_config(
    input: &Tensor,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::max_pool2d_nhwc_with_config_and_pool(
        input,
        Pool2dSpec {
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
        },
        config,
        None,
    )
}

/// NHWC average pooling on the sequential CPU backend.
pub fn avg_pool2d_nhwc(
    input: &Tensor,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError> {
    CpuBackend.avg_pool2d_nhwc(input, kernel_h, kernel_w, stride_h, stride_w)
}

/// NHWC average pooling with an explicit parallelism config.
pub fn avg_pool2d_nhwc_with_config(
    input: &Tensor,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::avg_pool2d_nhwc_with_config_and_pool(
        input,
        Pool2dSpec {
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
        },
        config,
        None,
    )
}

/// NHWC 2-D convolution on the sequential CPU backend.
pub fn conv2d_nhwc(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError> {
    CpuBackend.conv2d_nhwc(input, kernel, bias, stride_h, stride_w)
}

/// NHWC 2-D convolution with an explicit parallelism config.
pub fn conv2d_nhwc_with_config(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::conv2d_nhwc_with_config_and_pool(
        input,
        kernel,
        bias,
        Conv2dSpec { stride_h, stride_w },
        config,
        None,
    )
}
1150
// --- Module-level convenience wrappers: specialized convolutions. ---

/// NHWC deformable convolution (note: single symmetric `stride`/`padding`,
/// unlike the per-axis strides used elsewhere in this module).
pub fn deformable_conv2d_nhwc(
    input: &Tensor,
    weight: &Tensor,
    offsets: &Tensor,
    bias: Option<&Tensor>,
    stride: usize,
    padding: usize,
) -> Result<Tensor, KernelError> {
    ops::deformable_conv2d_nhwc(input, weight, offsets, bias, stride, padding)
}

/// NHWC depthwise convolution on the sequential CPU backend.
pub fn depthwise_conv2d_nhwc(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError> {
    CpuBackend.depthwise_conv2d_nhwc(input, kernel, bias, stride_h, stride_w)
}

/// NHWC depthwise convolution with an explicit parallelism config.
pub fn depthwise_conv2d_nhwc_with_config(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::depthwise_conv2d_nhwc_with_config_and_pool(
        input,
        kernel,
        bias,
        DepthwiseConv2dSpec { stride_h, stride_w },
        config,
        None,
    )
}

/// NHWC depthwise-separable convolution on the sequential CPU backend.
pub fn separable_conv2d_nhwc(
    input: &Tensor,
    params: SeparableConv2dParams<'_>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError> {
    CpuBackend.separable_conv2d_nhwc(input, params, stride_h, stride_w)
}

/// NHWC depthwise-separable convolution with an explicit parallelism config.
pub fn separable_conv2d_nhwc_with_config(
    input: &Tensor,
    params: SeparableConv2dParams<'_>,
    stride_h: usize,
    stride_w: usize,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::separable_conv2d_nhwc_with_config_and_pool(
        input,
        SeparableConv2dKernels {
            depthwise_kernel: params.depthwise_kernel,
            depthwise_bias: params.depthwise_bias,
            pointwise_kernel: params.pointwise_kernel,
            pointwise_bias: params.pointwise_bias,
        },
        SeparableConv2dSpec { stride_h, stride_w },
        config,
        None,
    )
}
1228
// --- Module-level convenience wrappers: normalization. ---

/// NHWC batch normalization on the sequential CPU backend.
pub fn batch_norm2d_nhwc(
    input: &Tensor,
    params: BatchNorm2dParams<'_>,
) -> Result<Tensor, KernelError> {
    CpuBackend.batch_norm2d_nhwc(input, params)
}

/// NHWC batch normalization with an explicit parallelism config.
pub fn batch_norm2d_nhwc_with_config(
    input: &Tensor,
    params: BatchNorm2dParams<'_>,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::batch_norm2d_nhwc_with_config_and_pool(
        input,
        BatchNorm2dTensors {
            gamma: params.gamma,
            beta: params.beta,
            mean: params.mean,
            variance: params.variance,
            epsilon: params.epsilon,
        },
        config,
        None,
    )
}

/// NHWC group normalization on the sequential CPU backend.
pub fn group_norm_nhwc(
    input: &Tensor,
    params: GroupNormNhwcParams<'_>,
) -> Result<Tensor, KernelError> {
    CpuBackend.group_norm_nhwc(input, params)
}

/// NHWC group normalization with an explicit parallelism config.
pub fn group_norm_nhwc_with_config(
    input: &Tensor,
    params: GroupNormNhwcParams<'_>,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::group_norm_nhwc_with_config_and_pool(
        input,
        GroupNorm2dTensors {
            gamma: params.gamma,
            beta: params.beta,
            num_groups: params.num_groups,
            epsilon: params.epsilon,
        },
        config,
        None,
    )
}

/// RMS norm over the last dimension on the sequential CPU backend.
pub fn rms_norm_last_dim(
    input: &Tensor,
    params: RmsNormLastDimParams<'_>,
) -> Result<Tensor, KernelError> {
    CpuBackend.rms_norm_last_dim(input, params)
}

/// RMS norm over the last dimension with an explicit parallelism config.
pub fn rms_norm_last_dim_with_config(
    input: &Tensor,
    params: RmsNormLastDimParams<'_>,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    ops::rms_norm_last_dim_with_config_and_pool(
        input,
        RmsNormLastDimTensors {
            gamma: params.gamma,
            epsilon: params.epsilon,
        },
        config,
        None,
    )
}
1309
// --- Module-level convenience wrappers: matmul, embedding, dropout, attention. ---

/// 2-D matrix multiplication on the sequential CPU backend.
pub fn matmul_2d(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
    CpuBackend.matmul_2d(lhs, rhs)
}

/// Guaranteed-sequential 2-D matmul (bypasses any parallel path).
pub fn matmul_2d_sequential(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
    ops::matmul_2d_sequential(lhs, rhs)
}

/// 2-D matmul with an explicit matmul parallelism config (no thread pool).
pub fn matmul_2d_with_config(
    lhs: &Tensor,
    rhs: &Tensor,
    config: ParallelMatmulConfig,
) -> Result<Tensor, KernelError> {
    ops::matmul_2d_with_config(lhs, rhs, config)
}

/// 2-D matmul on a freshly built thread pool of `num_threads` workers.
///
/// NOTE(review): this constructs (and tears down) a `ThreadedCpuBackend` per
/// call; for repeated matmuls, build the backend once and reuse it.
///
/// # Errors
/// Returns [`KernelError::ThreadPoolBuild`] if the pool cannot be created,
/// plus any error from the matmul itself.
pub fn matmul_2d_with_threads(
    lhs: &Tensor,
    rhs: &Tensor,
    num_threads: NonZeroUsize,
    config: ParallelMatmulConfig,
) -> Result<Tensor, KernelError> {
    let backend = ThreadedCpuBackend::with_config(num_threads, config)?;
    backend.matmul_2d(lhs, rhs)
}

/// Embedding lookup: gathers rows of `weight` selected by `indices`.
pub fn embedding_lookup(weight: &Tensor, indices: &Tensor) -> Result<Tensor, KernelError> {
    ops::embedding_lookup(weight, indices)
}

/// Dropout with drop probability `p`, seeded for reproducibility; a no-op
/// pass-through when `training` is false (per the ops-layer contract).
pub fn dropout(input: &Tensor, p: f32, seed: u64, training: bool) -> Result<Tensor, KernelError> {
    ops::dropout(input, p, seed, training)
}

/// Standard scaled dot-product attention with an optional mask.
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
) -> Result<Tensor, KernelError> {
    ops::attention::scaled_dot_product_attention(query, key, value, mask)
}

/// Memory-efficient (flash-style) attention with an optional mask.
pub fn flash_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
) -> Result<Tensor, KernelError> {
    ops::attention::flash_attention(query, key, value, mask)
}
1387
/// NHWC transposed (fractionally-strided) 2-D convolution, reference
/// implementation.
///
/// Output spatial size is `(ih - 1) * stride + kh` per axis (no output
/// padding). Kernel layout is `[kh, kw, ic, oc]` per the indexing below;
/// `_kc` is not validated against the input's channel count — TODO confirm
/// callers guarantee it. `bias`, when present, is assumed to hold `oc`
/// values (indexing panics otherwise) — TODO confirm.
///
/// # Errors
/// Returns [`KernelError::InvalidConvRank`] unless both `input` and `kernel`
/// are rank 4.
#[allow(clippy::too_many_arguments)]
pub fn transpose_conv2d_nhwc(
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError> {
    let is = input.shape();
    let ks = kernel.shape();
    if is.len() != 4 || ks.len() != 4 {
        return Err(KernelError::InvalidConvRank {
            input_rank: is.len(),
            kernel_rank: ks.len(),
        });
    }
    let (n, ih, iw, ic) = (is[0], is[1], is[2], is[3]);
    let (kh, kw, _kc, oc) = (ks[0], ks[1], ks[2], ks[3]);
    // Transposed-conv output size (stride spreads inputs apart, kernel pads out).
    let oh = (ih - 1) * stride_h + kh;
    let ow = (iw - 1) * stride_w + kw;

    let in_d = input.data();
    let k_d = kernel.data();
    // Missing bias is treated as all zeros so the bias pass below is unconditional.
    let bias_d: Vec<f32> = bias.map_or_else(|| vec![0.0f32; oc], |b| b.data().to_vec());

    let mut out = vec![0.0f32; n * oh * ow * oc];

    for b in 0..n {
        // Scatter pass: each input element contributes a kernel-sized patch
        // to the output, anchored at (iy*stride_h, ix*stride_w).
        for iy in 0..ih {
            for ix in 0..iw {
                for ci in 0..ic {
                    let in_val = in_d[((b * ih + iy) * iw + ix) * ic + ci];
                    for ky in 0..kh {
                        for kx in 0..kw {
                            let oy = iy * stride_h + ky;
                            let ox = ix * stride_w + kx;
                            for co in 0..oc {
                                let k_val = k_d[((ky * kw + kx) * ic + ci) * oc + co];
                                out[((b * oh + oy) * ow + ox) * oc + co] += in_val * k_val;
                            }
                        }
                    }
                }
            }
        }
        // Bias pass: add the per-channel bias to every output position.
        for oy in 0..oh {
            for ox in 0..ow {
                for co in 0..oc {
                    out[((b * oh + oy) * ow + ox) * oc + co] += bias_d[co];
                }
            }
        }
    }

    Tensor::from_vec(vec![n, oh, ow, oc], out).map_err(Into::into)
}