//! Backend contract and deterministic CPU backend implementations
//! (`yscv_kernels/backend.rs`).

use std::num::NonZeroUsize;

use rayon::{ThreadPool, ThreadPoolBuilder};
use yscv_tensor::Tensor;

use super::{
    error::KernelError,
    ops::{
        self, BatchNorm2dTensors, Conv2dSpec, DepthwiseConv2dSpec, GroupNorm2dTensors,
        LayerNormLastDimTensors, ParallelElementwiseConfig, ParallelMatmulConfig, Pool2dSpec,
        RmsNormLastDimTensors, SeparableConv2dKernels, SeparableConv2dSpec,
    },
};
/// Tensor parameter bundle for NHWC separable convolution:
/// depthwise (`[KH, KW, C, depth_multiplier]`) then pointwise (`[1, 1, C*depth_multiplier, C_out]`).
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2dParams<'a> {
    /// Depthwise kernel, `[KH, KW, C, depth_multiplier]`.
    pub depthwise_kernel: &'a Tensor,
    /// Optional bias applied after the depthwise stage.
    pub depthwise_bias: Option<&'a Tensor>,
    /// Pointwise (1x1) kernel, `[1, 1, C*depth_multiplier, C_out]`.
    pub pointwise_kernel: &'a Tensor,
    /// Optional bias applied after the pointwise stage.
    pub pointwise_bias: Option<&'a Tensor>,
}
24
/// Tensor parameter bundle for NHWC batch-normalization inference.
#[derive(Debug, Clone, Copy)]
pub struct BatchNorm2dParams<'a> {
    /// Scale tensor (one value per channel — TODO confirm layout against ops impl).
    pub gamma: &'a Tensor,
    /// Shift tensor.
    pub beta: &'a Tensor,
    /// Mean statistics used at inference.
    pub mean: &'a Tensor,
    /// Variance statistics used at inference.
    pub variance: &'a Tensor,
    /// Small constant added to the variance for numerical stability.
    pub epsilon: f32,
}
34
/// Tensor parameter bundle for layer normalization over the last tensor dimension.
#[derive(Debug, Clone, Copy)]
pub struct LayerNormLastDimParams<'a> {
    /// Scale tensor.
    pub gamma: &'a Tensor,
    /// Shift tensor.
    pub beta: &'a Tensor,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
42
/// Tensor parameter bundle for NHWC group normalization.
#[derive(Debug, Clone, Copy)]
pub struct GroupNormNhwcParams<'a> {
    /// Scale tensor.
    pub gamma: &'a Tensor,
    /// Shift tensor.
    pub beta: &'a Tensor,
    /// Number of channel groups to normalize over.
    pub num_groups: usize,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
51
/// Tensor parameter bundle for RMS normalization over the last tensor dimension.
#[derive(Debug, Clone, Copy)]
pub struct RmsNormLastDimParams<'a> {
    /// Scale tensor.
    pub gamma: &'a Tensor,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
58
/// Runtime backend contract for core deterministic kernels.
///
/// Forward-pass operations only; backward-pass operations live in the
/// [`BackwardOps`] extension trait.
pub trait Backend {
    /// Element-wise addition.
    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    /// Element-wise subtraction.
    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    /// Element-wise multiplication.
    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    /// Element-wise ReLU activation.
    fn relu(&self, input: &Tensor) -> Tensor;
    /// Element-wise sigmoid activation.
    fn sigmoid(&self, input: &Tensor) -> Tensor;
    /// Element-wise exponential.
    fn exp(&self, input: &Tensor) -> Tensor;
    /// Element-wise tanh activation.
    fn tanh_act(&self, input: &Tensor) -> Tensor;
    /// Softmax along the last tensor dimension.
    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    /// Log-softmax along the last tensor dimension.
    fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    /// Log-sum-exp along the last tensor dimension.
    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    /// Layer normalization over the last tensor dimension.
    fn layer_norm_last_dim(
        &self,
        input: &Tensor,
        params: LayerNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError>;
    /// 2-D max pooling over NHWC input.
    fn max_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    /// 2-D average pooling over NHWC input.
    fn avg_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    /// 2-D convolution over NHWC input; kernel is `[KH, KW, IC, OC]`.
    fn conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    /// Depthwise 2-D convolution over NHWC input.
    fn depthwise_conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    /// Separable (depthwise then pointwise) 2-D convolution over NHWC input.
    fn separable_conv2d_nhwc(
        &self,
        input: &Tensor,
        params: SeparableConv2dParams<'_>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    /// Batch-normalization inference over NHWC input.
    fn batch_norm2d_nhwc(
        &self,
        input: &Tensor,
        params: BatchNorm2dParams<'_>,
    ) -> Result<Tensor, KernelError>;
    /// Group normalization over NHWC input.
    fn group_norm_nhwc(
        &self,
        input: &Tensor,
        params: GroupNormNhwcParams<'_>,
    ) -> Result<Tensor, KernelError>;
    /// RMS normalization over the last tensor dimension.
    fn rms_norm_last_dim(
        &self,
        input: &Tensor,
        params: RmsNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError>;
    /// 2-D matrix multiplication.
    fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;

    // ── Backward-relevant ops with default CPU implementations ──────
    // Defaults delegate to `Tensor` methods; backends may override with
    // accelerated versions.

    /// Element-wise negation.
    fn neg(&self, input: &Tensor) -> Tensor {
        input.neg()
    }

    /// Element-wise division with broadcast.
    fn div(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
        lhs.div(rhs).map_err(Into::into)
    }

    /// Element-wise square root.
    fn sqrt(&self, input: &Tensor) -> Tensor {
        input.sqrt()
    }

    /// Transpose a 2-D matrix.
    fn transpose_2d(&self, input: &Tensor) -> Result<Tensor, KernelError> {
        input.transpose_2d().map_err(Into::into)
    }

    /// Scalar sum of all elements (returns a scalar tensor).
    fn sum_all(&self, input: &Tensor) -> Tensor {
        Tensor::scalar(input.sum())
    }

    /// Multiply every element by a scalar.
    fn mul_scalar(&self, input: &Tensor, scalar: f32) -> Tensor {
        input.scale(scalar)
    }

    /// Element-wise reciprocal (1/x).
    fn reciprocal(&self, input: &Tensor) -> Tensor {
        input.reciprocal()
    }
}
169
/// Extension trait for backward-pass operations.
///
/// Separated from [`Backend`] so that forward-only consumers (e.g. ONNX inference)
/// need not depend on backward-related method signatures.  All methods have default
/// CPU implementations, so `impl BackwardOps for MyBackend {}` is sufficient.
pub trait BackwardOps: Backend {
    /// ReLU backward: `grad_input[i] = upstream[i] * (forward_input[i] > 0 ? 1 : 0)`.
    ///
    /// Assumes `upstream` and `forward_input` share a shape — TODO confirm;
    /// a mismatch surfaces as a `Tensor::from_vec` length error.
    fn relu_backward(
        &self,
        upstream: &Tensor,
        forward_input: &Tensor,
    ) -> Result<Tensor, KernelError> {
        let u = upstream.data();
        let f = forward_input.data();
        // Gradient passes through only where the forward input was strictly positive.
        let out: Vec<f32> = u
            .iter()
            .zip(f.iter())
            .map(|(&u, &x)| if x > 0.0 { u } else { 0.0 })
            .collect();
        Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
    }

    /// Sigmoid backward: `grad_input[i] = upstream[i] * s[i] * (1 - s[i])` where `s` = forward output.
    fn sigmoid_backward(
        &self,
        upstream: &Tensor,
        forward_output: &Tensor,
    ) -> Result<Tensor, KernelError> {
        let u = upstream.data();
        let s = forward_output.data();
        let out: Vec<f32> = u
            .iter()
            .zip(s.iter())
            .map(|(&u, &s)| u * s * (1.0 - s))
            .collect();
        Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
    }

    /// Tanh backward: `grad_input[i] = upstream[i] * (1 - t[i]^2)` where `t` = forward output.
    fn tanh_backward(
        &self,
        upstream: &Tensor,
        forward_output: &Tensor,
    ) -> Result<Tensor, KernelError> {
        let u = upstream.data();
        let t = forward_output.data();
        let out: Vec<f32> = u
            .iter()
            .zip(t.iter())
            .map(|(&u, &t)| u * (1.0 - t * t))
            .collect();
        Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
    }

    /// Exp backward: `grad_input[i] = upstream[i] * e[i]` where `e` = forward output.
    fn exp_backward(
        &self,
        upstream: &Tensor,
        forward_output: &Tensor,
    ) -> Result<Tensor, KernelError> {
        let u = upstream.data();
        let e = forward_output.data();
        let out: Vec<f32> = u.iter().zip(e.iter()).map(|(&u, &e)| u * e).collect();
        Tensor::from_vec(upstream.shape().to_vec(), out).map_err(Into::into)
    }

    /// Reduce-sum backward: broadcast scalar gradient to all elements of `original_shape`.
    ///
    /// `upstream` is expected to hold the scalar gradient in its first element;
    /// an empty tensor would panic here — TODO confirm callers guarantee this.
    fn reduce_sum_backward(
        &self,
        upstream: &Tensor,
        original_shape: &[usize],
    ) -> Result<Tensor, KernelError> {
        let grad_val = upstream.data()[0];
        let len: usize = original_shape.iter().product();
        let out = vec![grad_val; len];
        Tensor::from_vec(original_shape.to_vec(), out).map_err(Into::into)
    }

    /// MatMul backward: `grad_lhs = upstream @ rhs^T`, `grad_rhs = lhs^T @ upstream`.
    fn matmul_backward(
        &self,
        upstream: &Tensor,
        lhs: &Tensor,
        rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> {
        // Routed through the backend's own transpose/matmul so accelerated
        // backends get accelerated backward for free.
        let rt = self.transpose_2d(rhs)?;
        let lt = self.transpose_2d(lhs)?;
        let grad_lhs = self.matmul_2d(upstream, &rt)?;
        let grad_rhs = self.matmul_2d(&lt, upstream)?;
        Ok((grad_lhs, grad_rhs))
    }

    /// Add backward: gradient passes through unchanged to both operands.
    fn add_backward(
        &self,
        upstream: &Tensor,
        _lhs: &Tensor,
        _rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> {
        Ok((upstream.clone(), upstream.clone()))
    }

    /// Sub backward: `grad_lhs = upstream`, `grad_rhs = -upstream`.
    fn sub_backward(
        &self,
        upstream: &Tensor,
        _lhs: &Tensor,
        _rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> {
        Ok((upstream.clone(), self.neg(upstream)))
    }

    /// Mul backward: `grad_lhs = upstream * rhs`, `grad_rhs = upstream * lhs`.
    fn mul_backward(
        &self,
        upstream: &Tensor,
        lhs: &Tensor,
        rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> {
        let grad_lhs = self.mul(upstream, rhs)?;
        let grad_rhs = self.mul(upstream, lhs)?;
        Ok((grad_lhs, grad_rhs))
    }

    /// Conv2d backward (input gradient): compute dL/dInput from dL/dOutput and weights.
    ///
    /// Default CPU implementation via full convolution with flipped kernel.
    /// Scatter formulation: each output-gradient element is distributed back to
    /// every input position its receptive field covered.  No padding is assumed.
    fn conv2d_input_backward(
        &self,
        upstream: &Tensor,
        kernel: &Tensor,
        input_shape: &[usize],
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> {
        // upstream: [N, OH, OW, OC], kernel: [KH, KW, IC, OC], output: [N, IH, IW, IC]
        let us = upstream.shape();
        let ks = kernel.shape();
        if us.len() != 4 || ks.len() != 4 || input_shape.len() != 4 {
            return Err(KernelError::InvalidConvRank {
                input_rank: input_shape.len(),
                kernel_rank: ks.len(),
            });
        }
        let (n, ih, iw, ic) = (
            input_shape[0],
            input_shape[1],
            input_shape[2],
            input_shape[3],
        );
        let (_n, oh, ow, oc) = (us[0], us[1], us[2], us[3]);
        let (kh, kw) = (ks[0], ks[1]);

        let u_data = upstream.data();
        let k_data = kernel.data();
        let mut grad_input = vec![0.0f32; n * ih * iw * ic];

        for b in 0..n {
            for oy in 0..oh {
                for ox in 0..ow {
                    for co in 0..oc {
                        let g = u_data[((b * oh + oy) * ow + ox) * oc + co];
                        // Zero gradients contribute nothing; skip the inner loops.
                        if g == 0.0 {
                            continue;
                        }
                        for ky in 0..kh {
                            for kx in 0..kw {
                                // Input position this (output, kernel tap) pair read from
                                // in the forward pass.
                                let iy = oy * stride_h + ky;
                                let ix = ox * stride_w + kx;
                                if iy < ih && ix < iw {
                                    for ci in 0..ic {
                                        let k_val = k_data[((ky * kw + kx) * ic + ci) * oc + co];
                                        grad_input[((b * ih + iy) * iw + ix) * ic + ci] +=
                                            g * k_val;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        Tensor::from_vec(input_shape.to_vec(), grad_input).map_err(Into::into)
    }
}
356
// Explicit BackwardOps implementations using default methods (CPU fallback).
// GpuBackend provides GPU-accelerated overrides in gpu_backend.rs.

/// Deterministic CPU backend with fixed operation order.
///
/// Zero-sized and `Copy`; carries no configuration or thread pool.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

// All backward ops use the trait's default CPU implementations.
impl BackwardOps for CpuBackend {}
365
366impl Backend for CpuBackend {
367    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
368        ops::add_with_config(lhs, rhs, ParallelElementwiseConfig::disabled())
369    }
370
371    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
372        ops::sub_with_config(lhs, rhs, ParallelElementwiseConfig::disabled())
373    }
374
375    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
376        ops::mul_with_config(lhs, rhs, ParallelElementwiseConfig::disabled())
377    }
378
379    fn relu(&self, input: &Tensor) -> Tensor {
380        ops::relu(input)
381    }
382
383    fn sigmoid(&self, input: &Tensor) -> Tensor {
384        ops::sigmoid(input)
385    }
386
387    fn exp(&self, input: &Tensor) -> Tensor {
388        ops::exp(input)
389    }
390
391    fn tanh_act(&self, input: &Tensor) -> Tensor {
392        ops::tanh_act(input)
393    }
394
395    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
396        ops::softmax_last_dim_with_config_and_pool(
397            input,
398            ParallelElementwiseConfig::disabled(),
399            None,
400        )
401    }
402
403    fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
404        ops::log_softmax_last_dim_with_config_and_pool(
405            input,
406            ParallelElementwiseConfig::disabled(),
407            None,
408        )
409    }
410
411    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
412        ops::logsumexp_last_dim_with_config_and_pool(
413            input,
414            ParallelElementwiseConfig::disabled(),
415            None,
416        )
417    }
418
419    fn layer_norm_last_dim(
420        &self,
421        input: &Tensor,
422        params: LayerNormLastDimParams<'_>,
423    ) -> Result<Tensor, KernelError> {
424        ops::layer_norm_last_dim_with_config_and_pool(
425            input,
426            LayerNormLastDimTensors {
427                gamma: params.gamma,
428                beta: params.beta,
429                epsilon: params.epsilon,
430            },
431            ParallelElementwiseConfig::disabled(),
432            None,
433        )
434    }
435
436    fn max_pool2d_nhwc(
437        &self,
438        input: &Tensor,
439        kernel_h: usize,
440        kernel_w: usize,
441        stride_h: usize,
442        stride_w: usize,
443    ) -> Result<Tensor, KernelError> {
444        ops::max_pool2d_nhwc_with_config_and_pool(
445            input,
446            Pool2dSpec {
447                kernel_h,
448                kernel_w,
449                stride_h,
450                stride_w,
451            },
452            ParallelElementwiseConfig::disabled(),
453            None,
454        )
455    }
456
457    fn avg_pool2d_nhwc(
458        &self,
459        input: &Tensor,
460        kernel_h: usize,
461        kernel_w: usize,
462        stride_h: usize,
463        stride_w: usize,
464    ) -> Result<Tensor, KernelError> {
465        ops::avg_pool2d_nhwc_with_config_and_pool(
466            input,
467            Pool2dSpec {
468                kernel_h,
469                kernel_w,
470                stride_h,
471                stride_w,
472            },
473            ParallelElementwiseConfig::disabled(),
474            None,
475        )
476    }
477
478    fn conv2d_nhwc(
479        &self,
480        input: &Tensor,
481        kernel: &Tensor,
482        bias: Option<&Tensor>,
483        stride_h: usize,
484        stride_w: usize,
485    ) -> Result<Tensor, KernelError> {
486        ops::conv2d_nhwc_with_config_and_pool(
487            input,
488            kernel,
489            bias,
490            Conv2dSpec { stride_h, stride_w },
491            ParallelElementwiseConfig::disabled(),
492            None,
493        )
494    }
495
496    fn depthwise_conv2d_nhwc(
497        &self,
498        input: &Tensor,
499        kernel: &Tensor,
500        bias: Option<&Tensor>,
501        stride_h: usize,
502        stride_w: usize,
503    ) -> Result<Tensor, KernelError> {
504        ops::depthwise_conv2d_nhwc_with_config_and_pool(
505            input,
506            kernel,
507            bias,
508            DepthwiseConv2dSpec { stride_h, stride_w },
509            ParallelElementwiseConfig::disabled(),
510            None,
511        )
512    }
513
514    fn separable_conv2d_nhwc(
515        &self,
516        input: &Tensor,
517        params: SeparableConv2dParams<'_>,
518        stride_h: usize,
519        stride_w: usize,
520    ) -> Result<Tensor, KernelError> {
521        ops::separable_conv2d_nhwc_with_config_and_pool(
522            input,
523            SeparableConv2dKernels {
524                depthwise_kernel: params.depthwise_kernel,
525                depthwise_bias: params.depthwise_bias,
526                pointwise_kernel: params.pointwise_kernel,
527                pointwise_bias: params.pointwise_bias,
528            },
529            SeparableConv2dSpec { stride_h, stride_w },
530            ParallelElementwiseConfig::disabled(),
531            None,
532        )
533    }
534
535    fn batch_norm2d_nhwc(
536        &self,
537        input: &Tensor,
538        params: BatchNorm2dParams<'_>,
539    ) -> Result<Tensor, KernelError> {
540        ops::batch_norm2d_nhwc_with_config_and_pool(
541            input,
542            BatchNorm2dTensors {
543                gamma: params.gamma,
544                beta: params.beta,
545                mean: params.mean,
546                variance: params.variance,
547                epsilon: params.epsilon,
548            },
549            ParallelElementwiseConfig::disabled(),
550            None,
551        )
552    }
553
554    fn group_norm_nhwc(
555        &self,
556        input: &Tensor,
557        params: GroupNormNhwcParams<'_>,
558    ) -> Result<Tensor, KernelError> {
559        ops::group_norm_nhwc_with_config_and_pool(
560            input,
561            GroupNorm2dTensors {
562                gamma: params.gamma,
563                beta: params.beta,
564                num_groups: params.num_groups,
565                epsilon: params.epsilon,
566            },
567            ParallelElementwiseConfig::disabled(),
568            None,
569        )
570    }
571
572    fn rms_norm_last_dim(
573        &self,
574        input: &Tensor,
575        params: RmsNormLastDimParams<'_>,
576    ) -> Result<Tensor, KernelError> {
577        ops::rms_norm_last_dim_with_config_and_pool(
578            input,
579            RmsNormLastDimTensors {
580                gamma: params.gamma,
581                epsilon: params.epsilon,
582            },
583            ParallelElementwiseConfig::disabled(),
584            None,
585        )
586    }
587
588    fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
589        ops::matmul_2d(lhs, rhs)
590    }
591}
592
/// CPU backend with a dedicated rayon thread pool for predictable kernel threading depth.
#[derive(Debug)]
pub struct ThreadedCpuBackend {
    // Parallelism knobs applied to 2-D matmul kernels.
    matmul_config: ParallelMatmulConfig,
    // Parallelism knobs applied to elementwise/normalization/conv/pool kernels.
    elementwise_config: ParallelElementwiseConfig,
    // Dedicated rayon pool every parallel kernel of this backend runs on.
    thread_pool: ThreadPool,
}
600
/// Runtime knobs for threaded CPU backend execution behavior.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ThreadedCpuBackendConfig {
    /// Heuristics for parallel 2-D matrix multiplication.
    pub matmul: ParallelMatmulConfig,
    /// Heuristics for parallel elementwise kernels.
    pub elementwise: ParallelElementwiseConfig,
}
607
608impl ThreadedCpuBackend {
609    /// Build a threaded backend with default parallel matmul heuristics.
610    pub fn new(num_threads: NonZeroUsize) -> Result<Self, KernelError> {
611        Self::with_full_config(num_threads, ThreadedCpuBackendConfig::default())
612    }
613
614    /// Build a threaded backend with explicit parallel-matmul configuration.
615    pub fn with_config(
616        num_threads: NonZeroUsize,
617        matmul_config: ParallelMatmulConfig,
618    ) -> Result<Self, KernelError> {
619        Self::with_full_config(
620            num_threads,
621            ThreadedCpuBackendConfig {
622                matmul: matmul_config,
623                elementwise: ParallelElementwiseConfig::default(),
624            },
625        )
626    }
627
628    /// Build a threaded backend with explicit matmul and elementwise configuration.
629    pub fn with_full_config(
630        num_threads: NonZeroUsize,
631        config: ThreadedCpuBackendConfig,
632    ) -> Result<Self, KernelError> {
633        let thread_pool = ThreadPoolBuilder::new()
634            .num_threads(num_threads.get())
635            .build()
636            .map_err(|error| KernelError::ThreadPoolBuild {
637                message: error.to_string(),
638            })?;
639        Ok(Self {
640            matmul_config: config.matmul,
641            elementwise_config: config.elementwise,
642            thread_pool,
643        })
644    }
645
646    /// Matmul parallelism knobs used by this backend.
647    pub const fn matmul_config(&self) -> ParallelMatmulConfig {
648        self.matmul_config
649    }
650
651    /// Elementwise parallelism knobs used by this backend.
652    pub const fn elementwise_config(&self) -> ParallelElementwiseConfig {
653        self.elementwise_config
654    }
655}
656
// Backward ops use the trait's default implementations, which route matmul and
// elementwise work back through this backend's own `Backend` methods.
impl BackwardOps for ThreadedCpuBackend {}
658
659impl Backend for ThreadedCpuBackend {
660    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
661        ops::add_with_config_and_pool(lhs, rhs, self.elementwise_config, Some(&self.thread_pool))
662    }
663
664    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
665        ops::sub_with_config_and_pool(lhs, rhs, self.elementwise_config, Some(&self.thread_pool))
666    }
667
668    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
669        ops::mul_with_config_and_pool(lhs, rhs, self.elementwise_config, Some(&self.thread_pool))
670    }
671
672    fn relu(&self, input: &Tensor) -> Tensor {
673        ops::relu_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
674    }
675
676    fn sigmoid(&self, input: &Tensor) -> Tensor {
677        ops::sigmoid_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
678    }
679
680    fn exp(&self, input: &Tensor) -> Tensor {
681        ops::exp_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
682    }
683
684    fn tanh_act(&self, input: &Tensor) -> Tensor {
685        ops::tanh_act_with_config_and_pool(input, self.elementwise_config, Some(&self.thread_pool))
686    }
687
688    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
689        ops::softmax_last_dim_with_config_and_pool(
690            input,
691            self.elementwise_config,
692            Some(&self.thread_pool),
693        )
694    }
695
696    fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
697        ops::log_softmax_last_dim_with_config_and_pool(
698            input,
699            self.elementwise_config,
700            Some(&self.thread_pool),
701        )
702    }
703
704    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError> {
705        ops::logsumexp_last_dim_with_config_and_pool(
706            input,
707            self.elementwise_config,
708            Some(&self.thread_pool),
709        )
710    }
711
712    fn layer_norm_last_dim(
713        &self,
714        input: &Tensor,
715        params: LayerNormLastDimParams<'_>,
716    ) -> Result<Tensor, KernelError> {
717        ops::layer_norm_last_dim_with_config_and_pool(
718            input,
719            LayerNormLastDimTensors {
720                gamma: params.gamma,
721                beta: params.beta,
722                epsilon: params.epsilon,
723            },
724            self.elementwise_config,
725            Some(&self.thread_pool),
726        )
727    }
728
729    fn max_pool2d_nhwc(
730        &self,
731        input: &Tensor,
732        kernel_h: usize,
733        kernel_w: usize,
734        stride_h: usize,
735        stride_w: usize,
736    ) -> Result<Tensor, KernelError> {
737        ops::max_pool2d_nhwc_with_config_and_pool(
738            input,
739            Pool2dSpec {
740                kernel_h,
741                kernel_w,
742                stride_h,
743                stride_w,
744            },
745            self.elementwise_config,
746            Some(&self.thread_pool),
747        )
748    }
749
750    fn avg_pool2d_nhwc(
751        &self,
752        input: &Tensor,
753        kernel_h: usize,
754        kernel_w: usize,
755        stride_h: usize,
756        stride_w: usize,
757    ) -> Result<Tensor, KernelError> {
758        ops::avg_pool2d_nhwc_with_config_and_pool(
759            input,
760            Pool2dSpec {
761                kernel_h,
762                kernel_w,
763                stride_h,
764                stride_w,
765            },
766            self.elementwise_config,
767            Some(&self.thread_pool),
768        )
769    }
770
771    fn conv2d_nhwc(
772        &self,
773        input: &Tensor,
774        kernel: &Tensor,
775        bias: Option<&Tensor>,
776        stride_h: usize,
777        stride_w: usize,
778    ) -> Result<Tensor, KernelError> {
779        ops::conv2d_nhwc_with_config_and_pool(
780            input,
781            kernel,
782            bias,
783            Conv2dSpec { stride_h, stride_w },
784            self.elementwise_config,
785            Some(&self.thread_pool),
786        )
787    }
788
789    fn depthwise_conv2d_nhwc(
790        &self,
791        input: &Tensor,
792        kernel: &Tensor,
793        bias: Option<&Tensor>,
794        stride_h: usize,
795        stride_w: usize,
796    ) -> Result<Tensor, KernelError> {
797        ops::depthwise_conv2d_nhwc_with_config_and_pool(
798            input,
799            kernel,
800            bias,
801            DepthwiseConv2dSpec { stride_h, stride_w },
802            self.elementwise_config,
803            Some(&self.thread_pool),
804        )
805    }
806
807    fn separable_conv2d_nhwc(
808        &self,
809        input: &Tensor,
810        params: SeparableConv2dParams<'_>,
811        stride_h: usize,
812        stride_w: usize,
813    ) -> Result<Tensor, KernelError> {
814        ops::separable_conv2d_nhwc_with_config_and_pool(
815            input,
816            SeparableConv2dKernels {
817                depthwise_kernel: params.depthwise_kernel,
818                depthwise_bias: params.depthwise_bias,
819                pointwise_kernel: params.pointwise_kernel,
820                pointwise_bias: params.pointwise_bias,
821            },
822            SeparableConv2dSpec { stride_h, stride_w },
823            self.elementwise_config,
824            Some(&self.thread_pool),
825        )
826    }
827
828    fn batch_norm2d_nhwc(
829        &self,
830        input: &Tensor,
831        params: BatchNorm2dParams<'_>,
832    ) -> Result<Tensor, KernelError> {
833        ops::batch_norm2d_nhwc_with_config_and_pool(
834            input,
835            BatchNorm2dTensors {
836                gamma: params.gamma,
837                beta: params.beta,
838                mean: params.mean,
839                variance: params.variance,
840                epsilon: params.epsilon,
841            },
842            self.elementwise_config,
843            Some(&self.thread_pool),
844        )
845    }
846
847    fn group_norm_nhwc(
848        &self,
849        input: &Tensor,
850        params: GroupNormNhwcParams<'_>,
851    ) -> Result<Tensor, KernelError> {
852        ops::group_norm_nhwc_with_config_and_pool(
853            input,
854            GroupNorm2dTensors {
855                gamma: params.gamma,
856                beta: params.beta,
857                num_groups: params.num_groups,
858                epsilon: params.epsilon,
859            },
860            self.elementwise_config,
861            Some(&self.thread_pool),
862        )
863    }
864
865    fn rms_norm_last_dim(
866        &self,
867        input: &Tensor,
868        params: RmsNormLastDimParams<'_>,
869    ) -> Result<Tensor, KernelError> {
870        ops::rms_norm_last_dim_with_config_and_pool(
871            input,
872            RmsNormLastDimTensors {
873                gamma: params.gamma,
874                epsilon: params.epsilon,
875            },
876            self.elementwise_config,
877            Some(&self.thread_pool),
878        )
879    }
880
881    fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
882        ops::matmul_2d_with_config_and_pool(lhs, rhs, self.matmul_config, Some(&self.thread_pool))
883    }
884}
885
886/// Backend-agnostic convenience call for add.
887pub fn add(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
888    CpuBackend.add(lhs, rhs)
889}
890
891/// Backend-agnostic add with explicit elementwise parallelization heuristics.
892pub fn add_with_config(
893    lhs: &Tensor,
894    rhs: &Tensor,
895    config: ParallelElementwiseConfig,
896) -> Result<Tensor, KernelError> {
897    ops::add_with_config(lhs, rhs, config)
898}
899
900/// Backend-agnostic convenience call for sub.
901pub fn sub(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
902    CpuBackend.sub(lhs, rhs)
903}
904
905/// Backend-agnostic subtract with explicit elementwise parallelization heuristics.
906pub fn sub_with_config(
907    lhs: &Tensor,
908    rhs: &Tensor,
909    config: ParallelElementwiseConfig,
910) -> Result<Tensor, KernelError> {
911    ops::sub_with_config(lhs, rhs, config)
912}
913
914/// Backend-agnostic convenience call for mul.
915pub fn mul(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
916    CpuBackend.mul(lhs, rhs)
917}
918
919/// Backend-agnostic multiply with explicit elementwise parallelization heuristics.
920pub fn mul_with_config(
921    lhs: &Tensor,
922    rhs: &Tensor,
923    config: ParallelElementwiseConfig,
924) -> Result<Tensor, KernelError> {
925    ops::mul_with_config(lhs, rhs, config)
926}
927
928/// Elementwise ReLU activation.
929pub fn relu(input: &Tensor) -> Tensor {
930    CpuBackend.relu(input)
931}
932
933/// In-place ReLU activation: clamps negative values to zero.
934pub fn relu_inplace(tensor: &mut Tensor) {
935    ops::relu_inplace(tensor);
936}
937
938/// Elementwise ReLU with explicit elementwise parallelization heuristics.
939pub fn relu_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
940    ops::relu_with_config(input, config)
941}
942
943/// Elementwise sigmoid activation.
944pub fn sigmoid(input: &Tensor) -> Tensor {
945    CpuBackend.sigmoid(input)
946}
947
/// Elementwise sigmoid with explicit elementwise parallelization heuristics.
pub fn sigmoid_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    // Direct op call honoring the caller's parallelization heuristics.
    ops::sigmoid_with_config(input, config)
}
952
953/// Elementwise exp activation.
954pub fn exp(input: &Tensor) -> Tensor {
955    CpuBackend.exp(input)
956}
957
/// Elementwise exp with explicit elementwise parallelization heuristics.
pub fn exp_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    // Direct op call honoring the caller's parallelization heuristics.
    ops::exp_with_config(input, config)
}
962
963/// Elementwise tanh activation.
964pub fn tanh_act(input: &Tensor) -> Tensor {
965    CpuBackend.tanh_act(input)
966}
967
/// Elementwise tanh with explicit elementwise parallelization heuristics.
pub fn tanh_act_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
    // Direct op call honoring the caller's parallelization heuristics.
    ops::tanh_act_with_config(input, config)
}
972
/// Elementwise GELU activation (fast approximation): `x * sigmoid(1.702 * x)`.
pub fn gelu(input: &Tensor) -> Tensor {
    // Delegates directly to the shared CPU op.
    ops::gelu(input)
}
977
/// Elementwise SiLU (Swish) activation: `x * sigmoid(x)`.
pub fn silu(input: &Tensor) -> Tensor {
    // Delegates directly to the shared CPU op.
    ops::silu(input)
}
982
/// Elementwise Mish activation: `x * tanh(ln(1 + exp(x)))`.
pub fn mish(input: &Tensor) -> Tensor {
    // Delegates directly to the shared CPU op.
    ops::mish(input)
}
987
988/// Softmax along the last tensor dimension.
989pub fn softmax_last_dim(input: &Tensor) -> Result<Tensor, KernelError> {
990    CpuBackend.softmax_last_dim(input)
991}
992
/// Softmax along the last tensor dimension with explicit elementwise parallelization heuristics.
pub fn softmax_last_dim_with_config(
    input: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    // `None`: no dedicated thread pool is supplied to the underlying op.
    ops::softmax_last_dim_with_config_and_pool(input, config, None)
}
1000
1001/// Log-softmax along the last tensor dimension.
1002pub fn log_softmax_last_dim(input: &Tensor) -> Result<Tensor, KernelError> {
1003    CpuBackend.log_softmax_last_dim(input)
1004}
1005
/// Log-softmax along the last tensor dimension with explicit elementwise parallelization heuristics.
pub fn log_softmax_last_dim_with_config(
    input: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    // `None`: no dedicated thread pool is supplied to the underlying op.
    ops::log_softmax_last_dim_with_config_and_pool(input, config, None)
}
1013
1014/// Log-sum-exp reduction along the last tensor dimension.
1015///
1016/// Returns shape equal to input with the last dimension set to `1`.
1017pub fn logsumexp_last_dim(input: &Tensor) -> Result<Tensor, KernelError> {
1018    CpuBackend.logsumexp_last_dim(input)
1019}
1020
/// Log-sum-exp reduction along the last tensor dimension with explicit elementwise parallelization heuristics.
pub fn logsumexp_last_dim_with_config(
    input: &Tensor,
    config: ParallelElementwiseConfig,
) -> Result<Tensor, KernelError> {
    // `None`: no dedicated thread pool is supplied to the underlying op.
    ops::logsumexp_last_dim_with_config_and_pool(input, config, None)
}
1028
1029/// Layer normalization over the last tensor dimension.
1030pub fn layer_norm_last_dim(
1031    input: &Tensor,
1032    params: LayerNormLastDimParams<'_>,
1033) -> Result<Tensor, KernelError> {
1034    CpuBackend.layer_norm_last_dim(input, params)
1035}
1036
1037/// Layer normalization over the last tensor dimension with explicit elementwise parallelization heuristics.
1038pub fn layer_norm_last_dim_with_config(
1039    input: &Tensor,
1040    params: LayerNormLastDimParams<'_>,
1041    config: ParallelElementwiseConfig,
1042) -> Result<Tensor, KernelError> {
1043    ops::layer_norm_last_dim_with_config_and_pool(
1044        input,
1045        LayerNormLastDimTensors {
1046            gamma: params.gamma,
1047            beta: params.beta,
1048            epsilon: params.epsilon,
1049        },
1050        config,
1051        None,
1052    )
1053}
1054
1055/// NHWC max-pooling without padding.
1056pub fn max_pool2d_nhwc(
1057    input: &Tensor,
1058    kernel_h: usize,
1059    kernel_w: usize,
1060    stride_h: usize,
1061    stride_w: usize,
1062) -> Result<Tensor, KernelError> {
1063    CpuBackend.max_pool2d_nhwc(input, kernel_h, kernel_w, stride_h, stride_w)
1064}
1065
1066/// NHWC max-pooling without padding with explicit parallelization heuristics.
1067pub fn max_pool2d_nhwc_with_config(
1068    input: &Tensor,
1069    kernel_h: usize,
1070    kernel_w: usize,
1071    stride_h: usize,
1072    stride_w: usize,
1073    config: ParallelElementwiseConfig,
1074) -> Result<Tensor, KernelError> {
1075    ops::max_pool2d_nhwc_with_config_and_pool(
1076        input,
1077        Pool2dSpec {
1078            kernel_h,
1079            kernel_w,
1080            stride_h,
1081            stride_w,
1082        },
1083        config,
1084        None,
1085    )
1086}
1087
1088/// NHWC average-pooling without padding.
1089pub fn avg_pool2d_nhwc(
1090    input: &Tensor,
1091    kernel_h: usize,
1092    kernel_w: usize,
1093    stride_h: usize,
1094    stride_w: usize,
1095) -> Result<Tensor, KernelError> {
1096    CpuBackend.avg_pool2d_nhwc(input, kernel_h, kernel_w, stride_h, stride_w)
1097}
1098
1099/// NHWC average-pooling without padding with explicit parallelization heuristics.
1100pub fn avg_pool2d_nhwc_with_config(
1101    input: &Tensor,
1102    kernel_h: usize,
1103    kernel_w: usize,
1104    stride_h: usize,
1105    stride_w: usize,
1106    config: ParallelElementwiseConfig,
1107) -> Result<Tensor, KernelError> {
1108    ops::avg_pool2d_nhwc_with_config_and_pool(
1109        input,
1110        Pool2dSpec {
1111            kernel_h,
1112            kernel_w,
1113            stride_h,
1114            stride_w,
1115        },
1116        config,
1117        None,
1118    )
1119}
1120
1121/// NHWC convolution without padding using kernel shape `[KH, KW, C_in, C_out]`.
1122pub fn conv2d_nhwc(
1123    input: &Tensor,
1124    kernel: &Tensor,
1125    bias: Option<&Tensor>,
1126    stride_h: usize,
1127    stride_w: usize,
1128) -> Result<Tensor, KernelError> {
1129    CpuBackend.conv2d_nhwc(input, kernel, bias, stride_h, stride_w)
1130}
1131
1132/// NHWC convolution without padding with explicit parallelization heuristics.
1133pub fn conv2d_nhwc_with_config(
1134    input: &Tensor,
1135    kernel: &Tensor,
1136    bias: Option<&Tensor>,
1137    stride_h: usize,
1138    stride_w: usize,
1139    config: ParallelElementwiseConfig,
1140) -> Result<Tensor, KernelError> {
1141    ops::conv2d_nhwc_with_config_and_pool(
1142        input,
1143        kernel,
1144        bias,
1145        Conv2dSpec { stride_h, stride_w },
1146        config,
1147        None,
1148    )
1149}
1150
/// NHWC deformable convolution with learned offsets.
///
/// Input: `[N, H, W, C_in]`, Weight: `[kH, kW, C_in, C_out]`,
/// Offsets: `[N, H_out, W_out, kH*kW*2]`, Bias: `[C_out]` (optional).
pub fn deformable_conv2d_nhwc(
    input: &Tensor,
    weight: &Tensor,
    offsets: &Tensor,
    bias: Option<&Tensor>,
    stride: usize,
    padding: usize,
) -> Result<Tensor, KernelError> {
    // Delegates directly to the shared CPU op; unlike the other conv wrappers
    // this one takes a single `stride`/`padding` applied to both axes.
    ops::deformable_conv2d_nhwc(input, weight, offsets, bias, stride, padding)
}
1165
1166/// NHWC depthwise convolution without padding using kernel shape `[KH, KW, C, depth_multiplier]`.
1167pub fn depthwise_conv2d_nhwc(
1168    input: &Tensor,
1169    kernel: &Tensor,
1170    bias: Option<&Tensor>,
1171    stride_h: usize,
1172    stride_w: usize,
1173) -> Result<Tensor, KernelError> {
1174    CpuBackend.depthwise_conv2d_nhwc(input, kernel, bias, stride_h, stride_w)
1175}
1176
1177/// NHWC depthwise convolution without padding with explicit parallelization heuristics.
1178pub fn depthwise_conv2d_nhwc_with_config(
1179    input: &Tensor,
1180    kernel: &Tensor,
1181    bias: Option<&Tensor>,
1182    stride_h: usize,
1183    stride_w: usize,
1184    config: ParallelElementwiseConfig,
1185) -> Result<Tensor, KernelError> {
1186    ops::depthwise_conv2d_nhwc_with_config_and_pool(
1187        input,
1188        kernel,
1189        bias,
1190        DepthwiseConv2dSpec { stride_h, stride_w },
1191        config,
1192        None,
1193    )
1194}
1195
1196/// NHWC separable convolution without padding:
1197/// depthwise (`[KH, KW, C, depth_multiplier]`) then pointwise (`[1, 1, C*depth_multiplier, C_out]`).
1198pub fn separable_conv2d_nhwc(
1199    input: &Tensor,
1200    params: SeparableConv2dParams<'_>,
1201    stride_h: usize,
1202    stride_w: usize,
1203) -> Result<Tensor, KernelError> {
1204    CpuBackend.separable_conv2d_nhwc(input, params, stride_h, stride_w)
1205}
1206
1207/// NHWC separable convolution without padding with explicit parallelization heuristics.
1208pub fn separable_conv2d_nhwc_with_config(
1209    input: &Tensor,
1210    params: SeparableConv2dParams<'_>,
1211    stride_h: usize,
1212    stride_w: usize,
1213    config: ParallelElementwiseConfig,
1214) -> Result<Tensor, KernelError> {
1215    ops::separable_conv2d_nhwc_with_config_and_pool(
1216        input,
1217        SeparableConv2dKernels {
1218            depthwise_kernel: params.depthwise_kernel,
1219            depthwise_bias: params.depthwise_bias,
1220            pointwise_kernel: params.pointwise_kernel,
1221            pointwise_bias: params.pointwise_bias,
1222        },
1223        SeparableConv2dSpec { stride_h, stride_w },
1224        config,
1225        None,
1226    )
1227}
1228
1229/// NHWC per-channel batch normalization inference:
1230/// `out = ((x - mean) / sqrt(variance + epsilon)) * gamma + beta`.
1231pub fn batch_norm2d_nhwc(
1232    input: &Tensor,
1233    params: BatchNorm2dParams<'_>,
1234) -> Result<Tensor, KernelError> {
1235    CpuBackend.batch_norm2d_nhwc(input, params)
1236}
1237
1238/// NHWC per-channel batch normalization inference with explicit parallelization heuristics.
1239pub fn batch_norm2d_nhwc_with_config(
1240    input: &Tensor,
1241    params: BatchNorm2dParams<'_>,
1242    config: ParallelElementwiseConfig,
1243) -> Result<Tensor, KernelError> {
1244    ops::batch_norm2d_nhwc_with_config_and_pool(
1245        input,
1246        BatchNorm2dTensors {
1247            gamma: params.gamma,
1248            beta: params.beta,
1249            mean: params.mean,
1250            variance: params.variance,
1251            epsilon: params.epsilon,
1252        },
1253        config,
1254        None,
1255    )
1256}
1257
1258/// NHWC group normalization: normalize within groups of channels.
1259pub fn group_norm_nhwc(
1260    input: &Tensor,
1261    params: GroupNormNhwcParams<'_>,
1262) -> Result<Tensor, KernelError> {
1263    CpuBackend.group_norm_nhwc(input, params)
1264}
1265
1266/// NHWC group normalization with explicit parallelization heuristics.
1267pub fn group_norm_nhwc_with_config(
1268    input: &Tensor,
1269    params: GroupNormNhwcParams<'_>,
1270    config: ParallelElementwiseConfig,
1271) -> Result<Tensor, KernelError> {
1272    ops::group_norm_nhwc_with_config_and_pool(
1273        input,
1274        GroupNorm2dTensors {
1275            gamma: params.gamma,
1276            beta: params.beta,
1277            num_groups: params.num_groups,
1278            epsilon: params.epsilon,
1279        },
1280        config,
1281        None,
1282    )
1283}
1284
1285/// RMS normalization over the last tensor dimension.
1286pub fn rms_norm_last_dim(
1287    input: &Tensor,
1288    params: RmsNormLastDimParams<'_>,
1289) -> Result<Tensor, KernelError> {
1290    CpuBackend.rms_norm_last_dim(input, params)
1291}
1292
1293/// RMS normalization over the last tensor dimension with explicit parallelization heuristics.
1294pub fn rms_norm_last_dim_with_config(
1295    input: &Tensor,
1296    params: RmsNormLastDimParams<'_>,
1297    config: ParallelElementwiseConfig,
1298) -> Result<Tensor, KernelError> {
1299    ops::rms_norm_last_dim_with_config_and_pool(
1300        input,
1301        RmsNormLastDimTensors {
1302            gamma: params.gamma,
1303            epsilon: params.epsilon,
1304        },
1305        config,
1306        None,
1307    )
1308}
1309
1310/// Deterministic rank-2 matrix multiplication: `(m x k) * (k x n) -> (m x n)`.
1311pub fn matmul_2d(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
1312    CpuBackend.matmul_2d(lhs, rhs)
1313}
1314
/// Single-thread deterministic rank-2 matrix multiplication.
pub fn matmul_2d_sequential(lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> {
    // Delegates to the dedicated sequential op, bypassing backend heuristics.
    ops::matmul_2d_sequential(lhs, rhs)
}
1319
/// Rank-2 matrix multiplication with explicit parallelization heuristics.
pub fn matmul_2d_with_config(
    lhs: &Tensor,
    rhs: &Tensor,
    config: ParallelMatmulConfig,
) -> Result<Tensor, KernelError> {
    // Direct op call honoring the caller's matmul parallelization heuristics.
    ops::matmul_2d_with_config(lhs, rhs, config)
}
1328
1329/// Rank-2 matrix multiplication executed through a dedicated thread pool.
1330pub fn matmul_2d_with_threads(
1331    lhs: &Tensor,
1332    rhs: &Tensor,
1333    num_threads: NonZeroUsize,
1334    config: ParallelMatmulConfig,
1335) -> Result<Tensor, KernelError> {
1336    let backend = ThreadedCpuBackend::with_config(num_threads, config)?;
1337    backend.matmul_2d(lhs, rhs)
1338}
1339
/// Looks up embeddings from a weight matrix.
///
/// `weight`: `[vocab_size, embed_dim]`
/// `indices`: `[*]` — flat tensor of integer indices (stored as f32)
///
/// Returns: `[*indices_shape, embed_dim]`
pub fn embedding_lookup(weight: &Tensor, indices: &Tensor) -> Result<Tensor, KernelError> {
    // Delegates directly to the shared CPU op.
    ops::embedding_lookup(weight, indices)
}
1349
/// Applies dropout: randomly zeroes elements with probability `p`.
///
/// During inference (`training=false`), returns input unchanged.
/// Uses xorshift64 PRNG with given seed for deterministic masking.
pub fn dropout(input: &Tensor, p: f32, seed: u64, training: bool) -> Result<Tensor, KernelError> {
    // Same `seed` + same input => same mask, keeping the kernel deterministic.
    ops::dropout(input, p, seed, training)
}
1357
/// Scaled dot-product attention for 2-D (unbatched) inputs.
///
/// `Attention(Q, K, V, mask?) = softmax(Q @ K^T / sqrt(d_k) + mask) @ V`
///
/// * `query`:  `[seq_q, d_k]`
/// * `key`:    `[seq_k, d_k]`
/// * `value`:  `[seq_k, d_v]`
/// * `mask`:   optional `[seq_q, seq_k]` additive mask
///
/// Returns `[seq_q, d_v]`.
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
) -> Result<Tensor, KernelError> {
    // Delegates to the materialized-scores implementation in `ops::attention`.
    ops::attention::scaled_dot_product_attention(query, key, value, mask)
}
1376
/// Memory-efficient (flash) attention — same result as `scaled_dot_product_attention`
/// but uses O(Br×Bc) peak memory instead of O(seq_q×seq_k).
pub fn flash_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
) -> Result<Tensor, KernelError> {
    // Delegates to the tiled implementation in `ops::attention`.
    ops::attention::flash_attention(query, key, value, mask)
}
1387
1388/// CPU transposed convolution (deconvolution) in NHWC layout.
1389///
1390/// Input: `[N,H,W,C_in]`, kernel: `[KH,KW,C_in,C_out]`, bias: optional `[C_out]`.
1391/// Output: `[N, (H-1)*stride_h + KH, (W-1)*stride_w + KW, C_out]`.
1392#[allow(clippy::too_many_arguments)]
1393pub fn transpose_conv2d_nhwc(
1394    input: &Tensor,
1395    kernel: &Tensor,
1396    bias: Option<&Tensor>,
1397    stride_h: usize,
1398    stride_w: usize,
1399) -> Result<Tensor, KernelError> {
1400    let is = input.shape();
1401    let ks = kernel.shape();
1402    if is.len() != 4 || ks.len() != 4 {
1403        return Err(KernelError::InvalidConvRank {
1404            input_rank: is.len(),
1405            kernel_rank: ks.len(),
1406        });
1407    }
1408    let (n, ih, iw, ic) = (is[0], is[1], is[2], is[3]);
1409    let (kh, kw, _kc, oc) = (ks[0], ks[1], ks[2], ks[3]);
1410    let oh = (ih - 1) * stride_h + kh;
1411    let ow = (iw - 1) * stride_w + kw;
1412
1413    let in_d = input.data();
1414    let k_d = kernel.data();
1415    let bias_d: Vec<f32> = bias.map_or_else(|| vec![0.0f32; oc], |b| b.data().to_vec());
1416
1417    let mut out = vec![0.0f32; n * oh * ow * oc];
1418
1419    for b in 0..n {
1420        for iy in 0..ih {
1421            for ix in 0..iw {
1422                for ci in 0..ic {
1423                    let in_val = in_d[((b * ih + iy) * iw + ix) * ic + ci];
1424                    for ky in 0..kh {
1425                        for kx in 0..kw {
1426                            let oy = iy * stride_h + ky;
1427                            let ox = ix * stride_w + kx;
1428                            for co in 0..oc {
1429                                let k_val = k_d[((ky * kw + kx) * ic + ci) * oc + co];
1430                                out[((b * oh + oy) * ow + ox) * oc + co] += in_val * k_val;
1431                            }
1432                        }
1433                    }
1434                }
1435            }
1436        }
1437        for oy in 0..oh {
1438            for ox in 0..ow {
1439                for co in 0..oc {
1440                    out[((b * oh + oy) * ow + ox) * oc + co] += bias_d[co];
1441                }
1442            }
1443        }
1444    }
1445
1446    Tensor::from_vec(vec![n, oh, ow, oc], out).map_err(Into::into)
1447}