Skip to main content

scirs2_autograd/tensor_ops/
simd_ops.rs

1//! SIMD-accelerated tensor operations for autograd
2//!
3//! This module provides high-performance SIMD implementations of common tensor
4//! operations used in neural network training and inference. It leverages the
5//! `scirs2_core::simd` infrastructure for portable SIMD across x86_64 (AVX2/SSE)
6//! and aarch64 (NEON) architectures.
7//!
8//! ## Operations
9//!
10//! ### Element-wise arithmetic (forward + backward)
11//! - [`simd_elementwise_add`] / [`simd_elementwise_sub`]
12//! - [`simd_elementwise_mul`] / [`simd_elementwise_div`]
13//!
14//! ### Gradient accumulation (critical for backprop)
15//! - [`simd_gradient_accumulate`]
16//! - [`simd_scaled_gradient_accumulate`]
17//!
18//! ### Broadcasting operations
19//! - [`simd_broadcast_add`] / [`simd_broadcast_mul`]
20//!
21//! ### Activation functions (forward + backward)
22//! - [`simd_activation_relu`] / `simd_activation_relu_backward`
23//! - [`simd_activation_sigmoid`] / `simd_activation_sigmoid_backward`
24//! - [`simd_activation_tanh`] / `simd_activation_tanh_backward`
25//!
26//! ### Dot product / reduction
27//! - [`simd_dot_product`] / [`simd_reduction_sum`]
28//!
29//! ## Feature gating
30//!
31//! All operations in this module are gated behind the `simd` feature flag.
32//! When the feature is not enabled, scalar fallback implementations are used.
33
34use crate::op::{ComputeContext, GradientContext, Op, OpError};
35use crate::tensor::Tensor;
36use crate::Float;
37
38// ============================================================================
39// Internal SIMD dispatch helpers
40// ============================================================================
41
/// Run one element-wise binary SIMD kernel over two flat `f32` slices.
///
/// The caller guarantees equal lengths; the result is materialized as a
/// freshly allocated `Vec<f32>`. Only compiled when the `simd` feature is on.
#[cfg(feature = "simd")]
fn dispatch_binary_f32(a: &[f32], b: &[f32], op: SimdBinaryKind) -> Vec<f32> {
    use scirs2_core::ndarray::{Array1, ArrayView1};
    use scirs2_core::simd;

    let lhs = ArrayView1::from(a);
    let rhs = ArrayView1::from(b);

    // Route to the matching core SIMD kernel.
    let out: Array1<f32> = match op {
        SimdBinaryKind::Add => simd::simd_add_f32(&lhs, &rhs),
        SimdBinaryKind::Sub => simd::simd_sub_f32(&lhs, &rhs),
        SimdBinaryKind::Mul => simd::simd_mul_f32(&lhs, &rhs),
        SimdBinaryKind::Div => simd::simd_div_f32(&lhs, &rhs),
    };
    out.to_vec()
}
59
/// Run one element-wise binary SIMD kernel over two flat `f64` slices.
///
/// Mirrors [`dispatch_binary_f32`] for the double-precision kernels.
#[cfg(feature = "simd")]
fn dispatch_binary_f64(a: &[f64], b: &[f64], op: SimdBinaryKind) -> Vec<f64> {
    use scirs2_core::ndarray::{Array1, ArrayView1};
    use scirs2_core::simd;

    let lhs = ArrayView1::from(a);
    let rhs = ArrayView1::from(b);

    // Route to the matching core SIMD kernel.
    let out: Array1<f64> = match op {
        SimdBinaryKind::Add => simd::simd_add_f64(&lhs, &rhs),
        SimdBinaryKind::Sub => simd::simd_sub_f64(&lhs, &rhs),
        SimdBinaryKind::Mul => simd::simd_mul_f64(&lhs, &rhs),
        SimdBinaryKind::Div => simd::simd_div_f64(&lhs, &rhs),
    };
    out.to_vec()
}
76
/// Selector telling the `dispatch_binary_*` helpers which element-wise
/// kernel to invoke. Internal to this module.
#[cfg(feature = "simd")]
#[derive(Debug, Clone, Copy)]
enum SimdBinaryKind {
    /// Element-wise addition: `a + b`
    Add,
    /// Element-wise subtraction: `a - b`
    Sub,
    /// Element-wise multiplication: `a * b`
    Mul,
    /// Element-wise division: `a / b`
    Div,
}
85
86// ============================================================================
87// Element-wise Arithmetic Ops (Op trait implementations)
88// ============================================================================
89
90/// SIMD-accelerated element-wise addition operator.
91///
92/// Forward:  `y = a + b`
93/// Backward: `da = dy`, `db = dy`
94pub struct SimdElementwiseAdd;
95
96impl<F: Float> Op<F> for SimdElementwiseAdd {
97    fn name(&self) -> &'static str {
98        "SimdElementwiseAdd"
99    }
100
101    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
102        let a = ctx.input(0);
103        let b = ctx.input(1);
104
105        #[cfg(feature = "simd")]
106        {
107            if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
108                if a_slice.len() == b_slice.len() {
109                    if let Some(result) =
110                        try_simd_binary_op::<F>(a_slice, b_slice, SimdBinaryKind::Add)
111                    {
112                        let shape = a.shape().to_vec();
113                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
114                            scirs2_core::ndarray::IxDyn(&shape),
115                            result,
116                        )
117                        .map_err(|e| OpError::NdArrayError("SimdElementwiseAdd shape".into(), e))?;
118                        ctx.append_output(arr);
119                        return Ok(());
120                    }
121                }
122            }
123        }
124
125        // Scalar fallback
126        let result = &a.to_owned() + &b.to_owned();
127        ctx.append_output(result);
128        Ok(())
129    }
130
131    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
132        let gy = ctx.output_grad();
133        ctx.append_input_grad(0, Some(*gy));
134        ctx.append_input_grad(1, Some(*gy));
135    }
136}
137
138/// SIMD-accelerated element-wise subtraction operator.
139///
140/// Forward:  `y = a - b`
141/// Backward: `da = dy`, `db = -dy`
142pub struct SimdElementwiseSub;
143
144impl<F: Float> Op<F> for SimdElementwiseSub {
145    fn name(&self) -> &'static str {
146        "SimdElementwiseSub"
147    }
148
149    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
150        let a = ctx.input(0);
151        let b = ctx.input(1);
152
153        #[cfg(feature = "simd")]
154        {
155            if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
156                if a_slice.len() == b_slice.len() {
157                    if let Some(result) =
158                        try_simd_binary_op::<F>(a_slice, b_slice, SimdBinaryKind::Sub)
159                    {
160                        let shape = a.shape().to_vec();
161                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
162                            scirs2_core::ndarray::IxDyn(&shape),
163                            result,
164                        )
165                        .map_err(|e| OpError::NdArrayError("SimdElementwiseSub shape".into(), e))?;
166                        ctx.append_output(arr);
167                        return Ok(());
168                    }
169                }
170            }
171        }
172
173        let result = &a.to_owned() - &b.to_owned();
174        ctx.append_output(result);
175        Ok(())
176    }
177
178    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
179        let gy = ctx.output_grad();
180        ctx.append_input_grad(0, Some(*gy));
181        ctx.append_input_grad(1, Some(crate::tensor_ops::neg(*gy)));
182    }
183}
184
185/// SIMD-accelerated element-wise multiplication operator.
186///
187/// Forward:  `y = a * b`
188/// Backward: `da = dy * b`, `db = dy * a`
189pub struct SimdElementwiseMul;
190
191impl<F: Float> Op<F> for SimdElementwiseMul {
192    fn name(&self) -> &'static str {
193        "SimdElementwiseMul"
194    }
195
196    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
197        let a = ctx.input(0);
198        let b = ctx.input(1);
199
200        #[cfg(feature = "simd")]
201        {
202            if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
203                if a_slice.len() == b_slice.len() {
204                    if let Some(result) =
205                        try_simd_binary_op::<F>(a_slice, b_slice, SimdBinaryKind::Mul)
206                    {
207                        let shape = a.shape().to_vec();
208                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
209                            scirs2_core::ndarray::IxDyn(&shape),
210                            result,
211                        )
212                        .map_err(|e| OpError::NdArrayError("SimdElementwiseMul shape".into(), e))?;
213                        ctx.append_output(arr);
214                        return Ok(());
215                    }
216                }
217            }
218        }
219
220        let result = &a.to_owned() * &b.to_owned();
221        ctx.append_output(result);
222        Ok(())
223    }
224
225    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
226        let gy = ctx.output_grad();
227        let a = ctx.input(0);
228        let b = ctx.input(1);
229        // da = dy * b
230        ctx.append_input_grad(0, Some(*gy * b));
231        // db = dy * a
232        ctx.append_input_grad(1, Some(*gy * a));
233    }
234}
235
236/// SIMD-accelerated element-wise division operator.
237///
238/// Forward:  `y = a / b`
239/// Backward: `da = dy / b`, `db = -dy * a / b^2`
240pub struct SimdElementwiseDiv;
241
242impl<F: Float> Op<F> for SimdElementwiseDiv {
243    fn name(&self) -> &'static str {
244        "SimdElementwiseDiv"
245    }
246
247    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
248        let a = ctx.input(0);
249        let b = ctx.input(1);
250
251        #[cfg(feature = "simd")]
252        {
253            if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
254                if a_slice.len() == b_slice.len() {
255                    if let Some(result) =
256                        try_simd_binary_op::<F>(a_slice, b_slice, SimdBinaryKind::Div)
257                    {
258                        let shape = a.shape().to_vec();
259                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
260                            scirs2_core::ndarray::IxDyn(&shape),
261                            result,
262                        )
263                        .map_err(|e| OpError::NdArrayError("SimdElementwiseDiv shape".into(), e))?;
264                        ctx.append_output(arr);
265                        return Ok(());
266                    }
267                }
268            }
269        }
270
271        let result = &a.to_owned() / &b.to_owned();
272        ctx.append_output(result);
273        Ok(())
274    }
275
276    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
277        let gy = ctx.output_grad();
278        let a = ctx.input(0);
279        let b = ctx.input(1);
280        let g = ctx.graph();
281
282        // da = dy / b
283        ctx.append_input_grad(0, Some(*gy / b));
284
285        // db = -dy * a / b^2
286        let neg_one = crate::tensor_ops::scalar(-F::one(), g);
287        let b_sq = b * b;
288        ctx.append_input_grad(1, Some(neg_one * *gy * a / b_sq));
289    }
290}
291
292// ============================================================================
293// Gradient Accumulation Ops
294// ============================================================================
295
296/// SIMD-accelerated gradient accumulation operator.
297///
298/// Accumulates gradient `g` into existing gradient buffer `acc`:
299/// Forward: `y = acc + g`
300///
301/// This is the critical inner-loop operation during backpropagation.
302/// Using SIMD here gives substantial training speedups.
303pub struct SimdGradientAccumulate;
304
305impl<F: Float> Op<F> for SimdGradientAccumulate {
306    fn name(&self) -> &'static str {
307        "SimdGradientAccumulate"
308    }
309
310    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
311        let acc = ctx.input(0);
312        let grad = ctx.input(1);
313
314        #[cfg(feature = "simd")]
315        {
316            if let (Some(acc_slice), Some(grad_slice)) = (acc.as_slice(), grad.as_slice()) {
317                if acc_slice.len() == grad_slice.len() {
318                    if let Some(result) =
319                        try_simd_binary_op::<F>(acc_slice, grad_slice, SimdBinaryKind::Add)
320                    {
321                        let shape = acc.shape().to_vec();
322                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
323                            scirs2_core::ndarray::IxDyn(&shape),
324                            result,
325                        )
326                        .map_err(|e| {
327                            OpError::NdArrayError("SimdGradientAccumulate shape".into(), e)
328                        })?;
329                        ctx.append_output(arr);
330                        return Ok(());
331                    }
332                }
333            }
334        }
335
336        let result = &acc.to_owned() + &grad.to_owned();
337        ctx.append_output(result);
338        Ok(())
339    }
340
341    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
342        let gy = ctx.output_grad();
343        // Both inputs receive the upstream gradient as-is
344        ctx.append_input_grad(0, Some(*gy));
345        ctx.append_input_grad(1, Some(*gy));
346    }
347}
348
349/// SIMD-accelerated scaled gradient accumulation: `acc + scale * grad`
350///
351/// Combines scaling and accumulation into a single fused operation,
352/// which is common in optimizers (e.g., momentum SGD: `v = mu * v + lr * grad`).
353pub struct SimdScaledGradientAccumulate<F: Float> {
354    /// The scale factor applied to the gradient before accumulation
355    pub scale: F,
356}
357
358impl<F: Float> Op<F> for SimdScaledGradientAccumulate<F> {
359    fn name(&self) -> &'static str {
360        "SimdScaledGradientAccumulate"
361    }
362
363    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
364        let acc = ctx.input(0);
365        let grad = ctx.input(1);
366
367        // scale * grad + acc  (FMA pattern)
368        #[cfg(feature = "simd")]
369        {
370            if let (Some(acc_slice), Some(grad_slice)) = (acc.as_slice(), grad.as_slice()) {
371                if acc_slice.len() == grad_slice.len() {
372                    if let Some(result) = try_simd_fma::<F>(grad_slice, self.scale, acc_slice) {
373                        let shape = acc.shape().to_vec();
374                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
375                            scirs2_core::ndarray::IxDyn(&shape),
376                            result,
377                        )
378                        .map_err(|e| {
379                            OpError::NdArrayError("SimdScaledGradAccum shape".into(), e)
380                        })?;
381                        ctx.append_output(arr);
382                        return Ok(());
383                    }
384                }
385            }
386        }
387
388        // Scalar fallback: acc + scale * grad
389        let scaled = grad.mapv(|v| v * self.scale);
390        let result = &acc.to_owned() + &scaled;
391        ctx.append_output(result);
392        Ok(())
393    }
394
395    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
396        let gy = ctx.output_grad();
397        let g = ctx.graph();
398        // d(acc + scale * grad)/d(acc) = 1
399        ctx.append_input_grad(0, Some(*gy));
400        // d(acc + scale * grad)/d(grad) = scale
401        let scale_tensor = crate::tensor_ops::scalar(self.scale, g);
402        ctx.append_input_grad(1, Some(*gy * scale_tensor));
403    }
404}
405
406// ============================================================================
407// Broadcasting Operations
408// ============================================================================
409
410/// SIMD-accelerated broadcast addition.
411///
412/// Adds a bias vector (1D) to each row of a 2D tensor:
413/// `y[i, :] = x[i, :] + bias[:]`
414///
415/// This is the standard pattern for bias addition in dense layers.
416pub struct SimdBroadcastAdd;
417
418impl<F: Float> Op<F> for SimdBroadcastAdd {
419    fn name(&self) -> &'static str {
420        "SimdBroadcastAdd"
421    }
422
423    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
424        let x = ctx.input(0);
425        let bias = ctx.input(1);
426
427        let x_shape = x.shape().to_vec();
428        let bias_shape = bias.shape().to_vec();
429
430        // Handle the common case: x is [batch, features], bias is [features]
431        if x_shape.len() == 2 && bias_shape.len() == 1 && x_shape[1] == bias_shape[0] {
432            let rows = x_shape[0];
433            let cols = x_shape[1];
434
435            #[cfg(feature = "simd")]
436            {
437                if let (Some(x_slice), Some(bias_slice)) = (x.as_slice(), bias.as_slice()) {
438                    if let Some(mut result_vec) =
439                        try_simd_broadcast_add_2d::<F>(x_slice, bias_slice, rows, cols)
440                    {
441                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
442                            scirs2_core::ndarray::IxDyn(&x_shape),
443                            result_vec,
444                        )
445                        .map_err(|e| OpError::NdArrayError("SimdBroadcastAdd shape".into(), e))?;
446                        ctx.append_output(arr);
447                        return Ok(());
448                    }
449                }
450            }
451        }
452
453        // General fallback using ndarray broadcasting
454        let result = &x.to_owned() + &bias.to_owned();
455        ctx.append_output(result);
456        Ok(())
457    }
458
459    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
460        let gy = ctx.output_grad();
461        // dx = dy (same shape)
462        ctx.append_input_grad(0, Some(*gy));
463        // dbias = sum over batch dimension
464        let dbias = crate::tensor_ops::reduce_sum(*gy, &[0], false);
465        ctx.append_input_grad(1, Some(dbias));
466    }
467}
468
469/// SIMD-accelerated broadcast multiplication.
470///
471/// Multiplies each row of a 2D tensor by a 1D scale vector:
472/// `y[i, :] = x[i, :] * scale[:]`
473pub struct SimdBroadcastMul;
474
475impl<F: Float> Op<F> for SimdBroadcastMul {
476    fn name(&self) -> &'static str {
477        "SimdBroadcastMul"
478    }
479
480    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
481        let x = ctx.input(0);
482        let scale = ctx.input(1);
483
484        let x_shape = x.shape().to_vec();
485        let scale_shape = scale.shape().to_vec();
486
487        if x_shape.len() == 2 && scale_shape.len() == 1 && x_shape[1] == scale_shape[0] {
488            let rows = x_shape[0];
489            let cols = x_shape[1];
490
491            #[cfg(feature = "simd")]
492            {
493                if let (Some(x_slice), Some(scale_slice)) = (x.as_slice(), scale.as_slice()) {
494                    if let Some(result_vec) =
495                        try_simd_broadcast_mul_2d::<F>(x_slice, scale_slice, rows, cols)
496                    {
497                        let arr = scirs2_core::ndarray::Array::from_shape_vec(
498                            scirs2_core::ndarray::IxDyn(&x_shape),
499                            result_vec,
500                        )
501                        .map_err(|e| OpError::NdArrayError("SimdBroadcastMul shape".into(), e))?;
502                        ctx.append_output(arr);
503                        return Ok(());
504                    }
505                }
506            }
507        }
508
509        let result = &x.to_owned() * &scale.to_owned();
510        ctx.append_output(result);
511        Ok(())
512    }
513
514    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
515        let gy = ctx.output_grad();
516        let x = ctx.input(0);
517        let scale = ctx.input(1);
518
519        // dx = dy * scale (broadcast)
520        ctx.append_input_grad(0, Some(*gy * scale));
521        // dscale = sum_over_batch(dy * x)
522        let dscale = crate::tensor_ops::reduce_sum(*gy * x, &[0], false);
523        ctx.append_input_grad(1, Some(dscale));
524    }
525}
526
527// ============================================================================
528// Activation Functions (Forward + Backward)
529// ============================================================================
530
531/// SIMD-accelerated ReLU activation operator.
532///
533/// Forward:  `y = max(0, x)`
534/// Backward: `dx = dy * (x > 0)`
535pub struct SimdReLU;
536
537impl<F: Float> Op<F> for SimdReLU {
538    fn name(&self) -> &'static str {
539        "SimdReLU"
540    }
541
542    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
543        let x = ctx.input(0);
544
545        #[cfg(feature = "simd")]
546        {
547            if let Some(x_slice) = x.as_slice() {
548                if let Some(result) = try_simd_relu::<F>(x_slice) {
549                    let shape = x.shape().to_vec();
550                    let arr = scirs2_core::ndarray::Array::from_shape_vec(
551                        scirs2_core::ndarray::IxDyn(&shape),
552                        result,
553                    )
554                    .map_err(|e| OpError::NdArrayError("SimdReLU shape".into(), e))?;
555                    ctx.append_output(arr);
556                    return Ok(());
557                }
558            }
559        }
560
561        // Scalar fallback
562        let result = x.mapv(|v| if v > F::zero() { v } else { F::zero() });
563        ctx.append_output(result);
564        Ok(())
565    }
566
567    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
568        let gy = ctx.output_grad();
569        let x = ctx.input(0);
570        let g = ctx.graph();
571
572        // ReLU backward: dy * (x > 0)
573        let zero = crate::tensor_ops::scalar(F::zero(), g);
574        let mask = crate::tensor_ops::greater(x, zero);
575        ctx.append_input_grad(0, Some(*gy * mask));
576    }
577}
578
579/// SIMD-accelerated Sigmoid activation operator.
580///
581/// Forward:  `y = 1 / (1 + exp(-x))`
582/// Backward: `dx = dy * y * (1 - y)`
583pub struct SimdSigmoid;
584
585impl<F: Float> Op<F> for SimdSigmoid {
586    fn name(&self) -> &'static str {
587        "SimdSigmoid"
588    }
589
590    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
591        let x = ctx.input(0);
592
593        #[cfg(feature = "simd")]
594        {
595            if let Some(x_slice) = x.as_slice() {
596                if let Some(result) = try_simd_sigmoid::<F>(x_slice) {
597                    let shape = x.shape().to_vec();
598                    let arr = scirs2_core::ndarray::Array::from_shape_vec(
599                        scirs2_core::ndarray::IxDyn(&shape),
600                        result,
601                    )
602                    .map_err(|e| OpError::NdArrayError("SimdSigmoid shape".into(), e))?;
603                    ctx.append_output(arr);
604                    return Ok(());
605                }
606            }
607        }
608
609        // Scalar fallback: sigmoid(x) = 0.5 * (tanh(0.5*x) + 1)
610        let half = F::from(0.5).ok_or_else(|| OpError::ConversionError {
611            context: "SimdSigmoid half constant".into(),
612            from_type: "f64".into(),
613            to_type: std::any::type_name::<F>().into(),
614        })?;
615        let result = x.mapv(move |v| ((v * half).tanh() * half) + half);
616        ctx.append_output(result);
617        Ok(())
618    }
619
620    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
621        let gy = ctx.output_grad();
622        let y = ctx.output();
623        let g = ctx.graph();
624
625        // sigmoid backward: dy * y * (1 - y)
626        let one = crate::tensor_ops::scalar(F::one(), g);
627        let one_minus_y = one - y;
628        ctx.append_input_grad(0, Some(*gy * y * one_minus_y));
629    }
630}
631
632/// SIMD-accelerated Tanh activation operator.
633///
634/// Forward:  `y = tanh(x)`
635/// Backward: `dx = dy * (1 - y^2)`
636pub struct SimdTanh;
637
638impl<F: Float> Op<F> for SimdTanh {
639    fn name(&self) -> &'static str {
640        "SimdTanh"
641    }
642
643    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
644        let x = ctx.input(0);
645
646        #[cfg(feature = "simd")]
647        {
648            if let Some(x_slice) = x.as_slice() {
649                if let Some(result) = try_simd_tanh::<F>(x_slice) {
650                    let shape = x.shape().to_vec();
651                    let arr = scirs2_core::ndarray::Array::from_shape_vec(
652                        scirs2_core::ndarray::IxDyn(&shape),
653                        result,
654                    )
655                    .map_err(|e| OpError::NdArrayError("SimdTanh shape".into(), e))?;
656                    ctx.append_output(arr);
657                    return Ok(());
658                }
659            }
660        }
661
662        let result = x.mapv(|v| v.tanh());
663        ctx.append_output(result);
664        Ok(())
665    }
666
667    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
668        let gy = ctx.output_grad();
669        let y = ctx.output();
670        let g = ctx.graph();
671
672        // tanh backward: dy * (1 - y^2)
673        let one = crate::tensor_ops::scalar(F::one(), g);
674        let y_sq = y * y;
675        ctx.append_input_grad(0, Some(*gy * (one - y_sq)));
676    }
677}
678
679// ============================================================================
680// Dot Product / Reduction Operations
681// ============================================================================
682
683/// SIMD-accelerated dot product operator.
684///
685/// Computes the inner product of two 1-D tensors.
686/// Forward:  `y = sum(a * b)`
687/// Backward: `da = dy * b`, `db = dy * a`
688pub struct SimdDotProduct;
689
690impl<F: Float> Op<F> for SimdDotProduct {
691    fn name(&self) -> &'static str {
692        "SimdDotProduct"
693    }
694
695    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
696        let a = ctx.input(0);
697        let b = ctx.input(1);
698
699        if a.ndim() != 1 || b.ndim() != 1 {
700            return Err(OpError::IncompatibleShape(
701                "SimdDotProduct requires 1-D inputs".into(),
702            ));
703        }
704
705        if a.len() != b.len() {
706            return Err(OpError::IncompatibleShape(format!(
707                "SimdDotProduct: length mismatch: {} vs {}",
708                a.len(),
709                b.len()
710            )));
711        }
712
713        #[cfg(feature = "simd")]
714        {
715            if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
716                if let Some(dot_val) = try_simd_dot::<F>(a_slice, b_slice) {
717                    let arr = scirs2_core::ndarray::arr0(dot_val).into_dyn();
718                    ctx.append_output(arr);
719                    return Ok(());
720                }
721            }
722        }
723
724        // Scalar fallback
725        let mut sum = F::zero();
726        for (&ai, &bi) in a.iter().zip(b.iter()) {
727            sum += ai * bi;
728        }
729        let arr = scirs2_core::ndarray::arr0(sum).into_dyn();
730        ctx.append_output(arr);
731        Ok(())
732    }
733
734    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
735        let gy = ctx.output_grad();
736        let a = ctx.input(0);
737        let b = ctx.input(1);
738
739        // da = dy * b, db = dy * a
740        ctx.append_input_grad(0, Some(*gy * b));
741        ctx.append_input_grad(1, Some(*gy * a));
742    }
743}
744
745/// SIMD-accelerated sum reduction operator.
746///
747/// Computes the sum of all elements in a 1-D tensor.
748/// Forward:  `y = sum(x)`
749/// Backward: `dx = ones_like(x) * dy`
750pub struct SimdReductionSum;
751
752impl<F: Float> Op<F> for SimdReductionSum {
753    fn name(&self) -> &'static str {
754        "SimdReductionSum"
755    }
756
757    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
758        let x = ctx.input(0);
759
760        #[cfg(feature = "simd")]
761        {
762            if let Some(x_slice) = x.as_slice() {
763                if let Some(sum_val) = try_simd_sum::<F>(x_slice) {
764                    let arr = scirs2_core::ndarray::arr0(sum_val).into_dyn();
765                    ctx.append_output(arr);
766                    return Ok(());
767                }
768            }
769        }
770
771        // Scalar fallback
772        let sum = x.iter().fold(F::zero(), |acc, &v| acc + v);
773        let arr = scirs2_core::ndarray::arr0(sum).into_dyn();
774        ctx.append_output(arr);
775        Ok(())
776    }
777
778    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
779        let gy = ctx.output_grad();
780        let x = ctx.input(0);
781        let g = ctx.graph();
782
783        // dx = broadcast(dy) to the shape of x  => ones_like(x) * dy
784        let ones_shape = crate::tensor_ops::shape(x);
785        let ones_val = crate::tensor_ops::ones(&ones_shape, g);
786        ctx.append_input_grad(0, Some(ones_val * *gy));
787    }
788}
789
790// ============================================================================
791// Type-dispatched SIMD helpers (compile-time type routing)
792// ============================================================================
793
/// Attempt to dispatch a binary SIMD operation for the concrete float type.
/// Returns `None` if the type is not f32/f64 (meaning fallback is needed).
///
/// The input slices are reinterpreted in place (no copy) once the runtime
/// type check confirms `F` is exactly `f32` or `f64`; only the result
/// vector is a fresh allocation, which is then transmuted back to `Vec<F>`.
#[cfg(feature = "simd")]
fn try_simd_binary_op<F: Float>(a: &[F], b: &[F], kind: SimdBinaryKind) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: we checked the type is f32
        let a_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f32, a.len()) };
        let b_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f32, b.len()) };
        let result = dispatch_binary_f32(a_f32, b_f32, kind);
        // SAFETY: F == f32, so the element layout is identical. ManuallyDrop
        // prevents a double free: ownership of the buffer is transferred to
        // the rebuilt Vec<F> via from_raw_parts.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // SAFETY: same_type confirmed F == f64, so reinterpreting the
        // slices preserves length, alignment, and element layout.
        let a_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f64, a.len()) };
        let b_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f64, b.len()) };
        let result = dispatch_binary_f64(a_f64, b_f64, kind);
        // SAFETY: F == f64; ManuallyDrop + from_raw_parts transfers buffer
        // ownership exactly once (see f32 branch).
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        // Not a SIMD-supported element type; caller must use the scalar path.
        None
    }
}
828
/// Attempt to dispatch SIMD FMA: `a * scale + c`
///
/// Returns `None` when `F` is neither f32 nor f64, signalling that the
/// caller must use its scalar fallback. The f64 path has no fused kernel
/// and is emulated as a scalar-multiply followed by an add.
#[cfg(feature = "simd")]
fn try_simd_fma<F: Float>(a: &[F], scale: F, c: &[F]) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: same_type confirmed F == f32, so the slice reinterpretation
        // preserves length, alignment, and element layout.
        let a_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f32, a.len()) };
        let c_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(c.as_ptr() as *const f32, c.len()) };
        // SAFETY: F == f32, so reading the scalar through an f32 pointer is valid.
        let scale_f32: f32 = unsafe { *(&scale as *const F as *const f32) };

        // Create scale array for FMA: result = a * scale_arr + c
        // NOTE(review): this materializes a full broadcast array of the scalar;
        // presumably the core FMA kernel requires an array operand — confirm
        // whether a scalar-FMA variant exists to avoid the allocation.
        let scale_arr = scirs2_core::ndarray::Array1::from_elem(a.len(), scale_f32);
        let a_view = scirs2_core::ndarray::ArrayView1::from(a_f32);
        let scale_view = scale_arr.view();
        let c_view = scirs2_core::ndarray::ArrayView1::from(c_f32);

        let result = scirs2_core::simd::simd_fma_f32_ultra(&a_view, &scale_view, &c_view);
        let result_vec = result.to_vec();

        // SAFETY: F == f32; ManuallyDrop prevents a double free while the
        // buffer's ownership is transferred to the rebuilt Vec<F>.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // No f64 FMA in core, fall back to manual scale + add
        // SAFETY: same_type confirmed F == f64 (see f32 branch).
        let a_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f64, a.len()) };
        let c_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(c.as_ptr() as *const f64, c.len()) };
        // SAFETY: F == f64, so reading the scalar through an f64 pointer is valid.
        let scale_f64: f64 = unsafe { *(&scale as *const F as *const f64) };

        let a_view = scirs2_core::ndarray::ArrayView1::from(a_f64);
        let scale_arr = scirs2_core::simd::simd_scalar_mul_f64(&a_view, scale_f64);
        let scale_view = scale_arr.view();
        let c_view = scirs2_core::ndarray::ArrayView1::from(c_f64);
        let result = scirs2_core::simd::simd_add_f64(&scale_view, &c_view);
        let result_vec = result.to_vec();

        // SAFETY: F == f64; ownership of the buffer is transferred exactly once.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        // Unsupported element type; caller must use its scalar fallback.
        None
    }
}
879
/// SIMD broadcast add: add bias to each row of a 2D tensor
///
/// `x` is treated as a row-major `rows x cols` matrix and `bias` as a
/// length-`cols` vector added to every row. Returns `None` when `F` is
/// neither `f32` nor `f64` so the caller can fall back to scalar code.
///
/// NOTE(review): row slicing panics if `x.len() < rows * cols`, and the
/// kernel presumably expects `bias.len() == cols` — confirm callers validate
/// shapes first.
#[cfg(feature = "simd")]
fn try_simd_broadcast_add_2d<F: Float>(
    x: &[F],
    bias: &[F],
    rows: usize,
    cols: usize,
) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let x_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f32, x.len()) };
        // SAFETY: same guard — `F == f32`.
        let bias_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(bias.as_ptr() as *const f32, bias.len()) };

        let bias_view = scirs2_core::ndarray::ArrayView1::from(bias_f32);
        let mut result: Vec<f32> = Vec::with_capacity(rows * cols);
        // Process one row at a time so the single bias view is reused.
        for row in 0..rows {
            let row_start = row * cols;
            let row_end = row_start + cols;
            let row_slice = &x_f32[row_start..row_end];
            let row_view = scirs2_core::ndarray::ArrayView1::from(row_slice);
            let row_result = scirs2_core::simd::simd_add_f32(&row_view, &bias_view);
            result.extend(row_result.iter().copied());
        }

        // SAFETY: `F == f32`; `ManuallyDrop` prevents a double free while the
        // buffer is re-owned as `Vec<F>` with unchanged length/capacity.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let x_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f64, x.len()) };
        // SAFETY: same guard — `F == f64`.
        let bias_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(bias.as_ptr() as *const f64, bias.len()) };

        let bias_view = scirs2_core::ndarray::ArrayView1::from(bias_f64);
        let mut result: Vec<f64> = Vec::with_capacity(rows * cols);
        for row in 0..rows {
            let row_start = row * cols;
            let row_end = row_start + cols;
            let row_slice = &x_f64[row_start..row_end];
            let row_view = scirs2_core::ndarray::ArrayView1::from(row_slice);
            let row_result = scirs2_core::simd::simd_add_f64(&row_view, &bias_view);
            result.extend(row_result.iter().copied());
        }

        // SAFETY: `F == f64`; see the f32 branch.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        None
    }
}
938
/// SIMD broadcast mul: multiply each row of a 2D tensor by a scale vector
///
/// `x` is treated as a row-major `rows x cols` matrix; each row is multiplied
/// element-wise by the length-`cols` `scale` vector. Returns `None` when `F`
/// is neither `f32` nor `f64` so the caller can fall back to scalar code.
///
/// NOTE(review): row slicing panics if `x.len() < rows * cols`, and the
/// kernel presumably expects `scale.len() == cols` — confirm callers validate
/// shapes first.
#[cfg(feature = "simd")]
fn try_simd_broadcast_mul_2d<F: Float>(
    x: &[F],
    scale: &[F],
    rows: usize,
    cols: usize,
) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let x_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f32, x.len()) };
        // SAFETY: same guard — `F == f32`.
        let scale_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(scale.as_ptr() as *const f32, scale.len()) };

        let scale_view = scirs2_core::ndarray::ArrayView1::from(scale_f32);
        let mut result: Vec<f32> = Vec::with_capacity(rows * cols);
        // One SIMD multiply per row against the shared scale view.
        for row in 0..rows {
            let row_start = row * cols;
            let row_end = row_start + cols;
            let row_slice = &x_f32[row_start..row_end];
            let row_view = scirs2_core::ndarray::ArrayView1::from(row_slice);
            let row_result = scirs2_core::simd::simd_mul_f32(&row_view, &scale_view);
            result.extend(row_result.iter().copied());
        }

        // SAFETY: `F == f32`; `ManuallyDrop` prevents a double free while the
        // buffer is re-owned as `Vec<F>` with unchanged length/capacity.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let x_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f64, x.len()) };
        // SAFETY: same guard — `F == f64`.
        let scale_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(scale.as_ptr() as *const f64, scale.len()) };

        let scale_view = scirs2_core::ndarray::ArrayView1::from(scale_f64);
        let mut result: Vec<f64> = Vec::with_capacity(rows * cols);
        for row in 0..rows {
            let row_start = row * cols;
            let row_end = row_start + cols;
            let row_slice = &x_f64[row_start..row_end];
            let row_view = scirs2_core::ndarray::ArrayView1::from(row_slice);
            let row_result = scirs2_core::simd::simd_mul_f64(&row_view, &scale_view);
            result.extend(row_result.iter().copied());
        }

        // SAFETY: `F == f64`; see the f32 branch.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        None
    }
}
997
/// SIMD-accelerated ReLU for f32/f64
///
/// Dispatches to `scirs2_core::simd::simd_relu_*` when `F` is exactly `f32`
/// or `f64`; returns `None` otherwise so the caller can use a scalar fallback.
#[cfg(feature = "simd")]
fn try_simd_relu<F: Float>(x: &[F]) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let x_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f32, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f32);
        let result = scirs2_core::simd::simd_relu_f32(&x_view);
        let result_vec = result.to_vec();
        // SAFETY: `F == f32`; `ManuallyDrop` prevents a double free while the
        // buffer is re-owned as `Vec<F>` with unchanged length/capacity.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let x_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f64, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f64);
        let result = scirs2_core::simd::simd_relu_f64(&x_view);
        let result_vec = result.to_vec();
        // SAFETY: `F == f64`; see the f32 branch.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        None
    }
}
1029
/// SIMD-accelerated sigmoid for f32/f64
///
/// Dispatches to `scirs2_core::simd::simd_sigmoid_*` when `F` is exactly
/// `f32` or `f64`; returns `None` otherwise for a scalar fallback.
#[cfg(feature = "simd")]
fn try_simd_sigmoid<F: Float>(x: &[F]) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let x_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f32, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f32);
        let result = scirs2_core::simd::simd_sigmoid_f32(&x_view);
        let result_vec = result.to_vec();
        // SAFETY: `F == f32`; `ManuallyDrop` prevents a double free while the
        // buffer is re-owned as `Vec<F>` with unchanged length/capacity.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let x_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f64, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f64);
        let result = scirs2_core::simd::simd_sigmoid_f64(&x_view);
        let result_vec = result.to_vec();
        // SAFETY: `F == f64`; see the f32 branch.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        None
    }
}
1061
/// SIMD-accelerated tanh for f32/f64
///
/// Dispatches to `scirs2_core::simd::simd_tanh_*` when `F` is exactly `f32`
/// or `f64`; returns `None` otherwise for a scalar fallback.
#[cfg(feature = "simd")]
fn try_simd_tanh<F: Float>(x: &[F]) -> Option<Vec<F>> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let x_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f32, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f32);
        let result = scirs2_core::simd::simd_tanh_f32(&x_view);
        let result_vec = result.to_vec();
        // SAFETY: `F == f32`; `ManuallyDrop` prevents a double free while the
        // buffer is re-owned as `Vec<F>` with unchanged length/capacity.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let x_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f64, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f64);
        let result = scirs2_core::simd::simd_tanh_f64(&x_view);
        let result_vec = result.to_vec();
        // SAFETY: `F == f64`; see the f32 branch.
        let result_f: Vec<F> = unsafe {
            let mut v = std::mem::ManuallyDrop::new(result_vec);
            Vec::from_raw_parts(v.as_mut_ptr() as *mut F, v.len(), v.capacity())
        };
        Some(result_f)
    } else {
        None
    }
}
1093
/// SIMD-accelerated dot product for f32/f64
///
/// Returns the scalar dot product when `F` is exactly `f32` or `f64`;
/// returns `None` otherwise for a scalar fallback.
///
/// NOTE(review): presumably requires `a.len() == b.len()` — confirm the
/// kernel's behavior on mismatched lengths.
#[cfg(feature = "simd")]
fn try_simd_dot<F: Float>(a: &[F], b: &[F]) -> Option<F> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let a_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f32, a.len()) };
        // SAFETY: same guard — `F == f32`.
        let b_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f32, b.len()) };
        let a_view = scirs2_core::ndarray::ArrayView1::from(a_f32);
        let b_view = scirs2_core::ndarray::ArrayView1::from(b_f32);
        let result_f32 = scirs2_core::simd::simd_dot_f32(&a_view, &b_view);
        // SAFETY: F == f32
        let result: F = unsafe { *(&result_f32 as *const f32 as *const F) };
        Some(result)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let a_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f64, a.len()) };
        // SAFETY: same guard — `F == f64`.
        let b_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f64, b.len()) };
        let a_view = scirs2_core::ndarray::ArrayView1::from(a_f64);
        let b_view = scirs2_core::ndarray::ArrayView1::from(b_f64);
        let result_f64 = scirs2_core::simd::simd_dot_f64(&a_view, &b_view);
        // SAFETY: `F == f64`, so reading the scalar through an `F` pointer is valid.
        let result: F = unsafe { *(&result_f64 as *const f64 as *const F) };
        Some(result)
    } else {
        None
    }
}
1124
/// SIMD-accelerated sum for f32/f64
///
/// Returns the scalar sum of all elements when `F` is exactly `f32` or `f64`;
/// returns `None` otherwise for a scalar fallback.
#[cfg(feature = "simd")]
fn try_simd_sum<F: Float>(x: &[F]) -> Option<F> {
    use crate::same_type;

    if same_type::<F, f32>() {
        // SAFETY: guarded by `same_type::<F, f32>()` — `F` and `f32` have
        // identical size, alignment and bit representation.
        let x_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f32, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f32);
        let result_f32 = scirs2_core::simd::simd_sum_f32(&x_view);
        // SAFETY: `F == f32`, so reading the scalar through an `F` pointer is valid.
        let result: F = unsafe { *(&result_f32 as *const f32 as *const F) };
        Some(result)
    } else if same_type::<F, f64>() {
        // SAFETY: guarded by `same_type::<F, f64>()` — layout-identical types.
        let x_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr() as *const f64, x.len()) };
        let x_view = scirs2_core::ndarray::ArrayView1::from(x_f64);
        let result_f64 = scirs2_core::simd::simd_sum_f64(&x_view);
        // SAFETY: `F == f64`, so reading the scalar through an `F` pointer is valid.
        let result: F = unsafe { *(&result_f64 as *const f64 as *const F) };
        Some(result)
    } else {
        None
    }
}
1148
1149// ============================================================================
1150// Public API: Tensor-level functions
1151// ============================================================================
1152
1153/// SIMD-accelerated element-wise addition of two tensors.
1154///
1155/// Uses hardware SIMD instructions (AVX2/NEON) when the `simd` feature is enabled
1156/// and the tensors are contiguous in memory. Falls back to scalar otherwise.
1157///
1158/// # Arguments
1159/// * `a` - Left operand tensor
1160/// * `b` - Right operand tensor
1161///
1162/// # Returns
1163/// A new tensor containing the element-wise sum.
1164pub fn simd_elementwise_add<'g, F: Float>(a: &Tensor<'g, F>, b: &Tensor<'g, F>) -> Tensor<'g, F> {
1165    let g = a.graph();
1166    Tensor::builder(g)
1167        .append_input(a, false)
1168        .append_input(b, false)
1169        .build(SimdElementwiseAdd)
1170}
1171
1172/// SIMD-accelerated element-wise subtraction.
1173pub fn simd_elementwise_sub<'g, F: Float>(a: &Tensor<'g, F>, b: &Tensor<'g, F>) -> Tensor<'g, F> {
1174    let g = a.graph();
1175    Tensor::builder(g)
1176        .append_input(a, false)
1177        .append_input(b, false)
1178        .build(SimdElementwiseSub)
1179}
1180
1181/// SIMD-accelerated element-wise multiplication.
1182pub fn simd_elementwise_mul<'g, F: Float>(a: &Tensor<'g, F>, b: &Tensor<'g, F>) -> Tensor<'g, F> {
1183    let g = a.graph();
1184    Tensor::builder(g)
1185        .append_input(a, false)
1186        .append_input(b, false)
1187        .build(SimdElementwiseMul)
1188}
1189
1190/// SIMD-accelerated element-wise division.
1191pub fn simd_elementwise_div<'g, F: Float>(a: &Tensor<'g, F>, b: &Tensor<'g, F>) -> Tensor<'g, F> {
1192    let g = a.graph();
1193    Tensor::builder(g)
1194        .append_input(a, false)
1195        .append_input(b, false)
1196        .build(SimdElementwiseDiv)
1197}
1198
1199/// SIMD-accelerated gradient accumulation.
1200///
1201/// This is the critical inner-loop operation during backpropagation.
1202/// Accumulates `gradient` into the existing `accumulator`.
1203///
1204/// # Arguments
1205/// * `accumulator` - Existing gradient accumulator
1206/// * `gradient` - New gradient to accumulate
1207pub fn simd_gradient_accumulate<'g, F: Float>(
1208    accumulator: &Tensor<'g, F>,
1209    gradient: &Tensor<'g, F>,
1210) -> Tensor<'g, F> {
1211    let g = accumulator.graph();
1212    Tensor::builder(g)
1213        .append_input(accumulator, false)
1214        .append_input(gradient, false)
1215        .build(SimdGradientAccumulate)
1216}
1217
1218/// SIMD-accelerated scaled gradient accumulation: `acc + scale * grad`
1219///
1220/// Fused operation common in optimizers (momentum SGD, Adam, etc.)
1221pub fn simd_scaled_gradient_accumulate<'g, F: Float>(
1222    accumulator: &Tensor<'g, F>,
1223    gradient: &Tensor<'g, F>,
1224    scale: F,
1225) -> Tensor<'g, F> {
1226    let g = accumulator.graph();
1227    Tensor::builder(g)
1228        .append_input(accumulator, false)
1229        .append_input(gradient, false)
1230        .build(SimdScaledGradientAccumulate { scale })
1231}
1232
1233/// SIMD-accelerated broadcast addition (bias addition pattern).
1234///
1235/// Adds a 1-D bias to each row of a 2-D tensor.
1236pub fn simd_broadcast_add<'g, F: Float>(x: &Tensor<'g, F>, bias: &Tensor<'g, F>) -> Tensor<'g, F> {
1237    let g = x.graph();
1238    Tensor::builder(g)
1239        .append_input(x, false)
1240        .append_input(bias, false)
1241        .build(SimdBroadcastAdd)
1242}
1243
1244/// SIMD-accelerated broadcast multiplication.
1245///
1246/// Multiplies each row of a 2-D tensor by a 1-D scale vector.
1247pub fn simd_broadcast_mul<'g, F: Float>(x: &Tensor<'g, F>, scale: &Tensor<'g, F>) -> Tensor<'g, F> {
1248    let g = x.graph();
1249    Tensor::builder(g)
1250        .append_input(x, false)
1251        .append_input(scale, false)
1252        .build(SimdBroadcastMul)
1253}
1254
1255/// SIMD-accelerated ReLU activation.
1256pub fn simd_activation_relu<'g, F: Float>(x: &Tensor<'g, F>) -> Tensor<'g, F> {
1257    let g = x.graph();
1258    Tensor::builder(g).append_input(x, false).build(SimdReLU)
1259}
1260
1261/// SIMD-accelerated sigmoid activation.
1262pub fn simd_activation_sigmoid<'g, F: Float>(x: &Tensor<'g, F>) -> Tensor<'g, F> {
1263    let g = x.graph();
1264    Tensor::builder(g).append_input(x, false).build(SimdSigmoid)
1265}
1266
1267/// SIMD-accelerated tanh activation.
1268pub fn simd_activation_tanh<'g, F: Float>(x: &Tensor<'g, F>) -> Tensor<'g, F> {
1269    let g = x.graph();
1270    Tensor::builder(g).append_input(x, false).build(SimdTanh)
1271}
1272
1273/// SIMD-accelerated dot product of two 1-D tensors.
1274pub fn simd_dot_product<'g, F: Float>(a: &Tensor<'g, F>, b: &Tensor<'g, F>) -> Tensor<'g, F> {
1275    let g = a.graph();
1276    Tensor::builder(g)
1277        .append_input(a, false)
1278        .append_input(b, false)
1279        .build(SimdDotProduct)
1280}
1281
1282/// SIMD-accelerated sum reduction of a 1-D tensor.
1283pub fn simd_reduction_sum<'g, F: Float>(x: &Tensor<'g, F>) -> Tensor<'g, F> {
1284    let g = x.graph();
1285    Tensor::builder(g)
1286        .append_input(x, false)
1287        .build(SimdReductionSum)
1288}
1289
1290// ============================================================================
1291// Configuration
1292// ============================================================================
1293
/// Performance configuration for SIMD operations.
///
/// Tuning knobs that decide when the SIMD path is engaged instead of the
/// scalar fallback, and how kernels are selected.
#[derive(Debug, Clone)]
pub struct SimdConfig {
    /// Minimum number of elements before SIMD is used (default: 16)
    pub min_simd_length: usize,
    /// Whether to prefer FMA (fused multiply-add) when available (default: true)
    pub prefer_fma: bool,
    /// Whether to use adaptive algorithm selection (default: true)
    pub adaptive_dispatch: bool,
}

impl Default for SimdConfig {
    /// Defaults: 16-element SIMD threshold, FMA preferred, adaptive dispatch on.
    fn default() -> Self {
        SimdConfig {
            min_simd_length: 16,
            prefer_fma: true,
            adaptive_dispatch: true,
        }
    }
}
1317
1318// ============================================================================
1319// Tests
1320// ============================================================================
1321
1322#[cfg(test)]
1323mod tests {
1324    use super::*;
1325    use crate as ag;
1326    use scirs2_core::ndarray::{array, Array1, ArrayView1};
1327
1328    /// Helper to compare two float slices within epsilon
1329    fn assert_approx_eq_f32(actual: &[f32], expected: &[f32], epsilon: f32) {
1330        assert_eq!(actual.len(), expected.len(), "Length mismatch");
1331        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
1332            assert!(
1333                (a - e).abs() < epsilon,
1334                "Mismatch at index {}: actual={}, expected={}, diff={}",
1335                i,
1336                a,
1337                e,
1338                (a - e).abs()
1339            );
1340        }
1341    }
1342
1343    fn assert_approx_eq_f64(actual: &[f64], expected: &[f64], epsilon: f64) {
1344        assert_eq!(actual.len(), expected.len(), "Length mismatch");
1345        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
1346            assert!(
1347                (a - e).abs() < epsilon,
1348                "Mismatch at index {}: actual={}, expected={}, diff={}",
1349                i,
1350                a,
1351                e,
1352                (a - e).abs()
1353            );
1354        }
1355    }
1356
1357    // -------------------------------------------------------
1358    // Element-wise arithmetic: correctness (SIMD vs scalar)
1359    // -------------------------------------------------------
1360
1361    #[test]
1362    fn test_simd_elementwise_add_f32() {
1363        ag::run::<f32, _, _>(|ctx| {
1364            let a_arr = array![
1365                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
1366                16.0
1367            ];
1368            let b_arr = array![
1369                0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6
1370            ];
1371            let expected: Vec<f32> = a_arr
1372                .iter()
1373                .zip(b_arr.iter())
1374                .map(|(&a, &b)| a + b)
1375                .collect();
1376
1377            let a = ag::tensor_ops::convert_to_tensor(a_arr.clone(), ctx);
1378            let b = ag::tensor_ops::convert_to_tensor(b_arr.clone(), ctx);
1379            let y = simd_elementwise_add(&a, &b);
1380
1381            if let Ok(result) = y.eval(ctx) {
1382                if let Some(result_slice) = result.as_slice() {
1383                    assert_approx_eq_f32(result_slice, &expected, 1e-6);
1384                }
1385            }
1386        });
1387    }
1388
1389    #[test]
1390    fn test_simd_elementwise_sub_f64() {
1391        ag::run::<f64, _, _>(|ctx| {
1392            let a_arr = array![10.0f64, 20.0, 30.0, 40.0];
1393            let b_arr = array![1.0f64, 2.0, 3.0, 4.0];
1394            let expected: Vec<f64> = a_arr
1395                .iter()
1396                .zip(b_arr.iter())
1397                .map(|(&a, &b)| a - b)
1398                .collect();
1399
1400            let a = ag::tensor_ops::convert_to_tensor(a_arr, ctx);
1401            let b = ag::tensor_ops::convert_to_tensor(b_arr, ctx);
1402            let y = simd_elementwise_sub(&a, &b);
1403
1404            if let Ok(result) = y.eval(ctx) {
1405                if let Some(result_slice) = result.as_slice() {
1406                    assert_approx_eq_f64(result_slice, &expected, 1e-12);
1407                }
1408            }
1409        });
1410    }
1411
1412    #[test]
1413    fn test_simd_elementwise_mul_f32() {
1414        ag::run::<f32, _, _>(|ctx| {
1415            let a_arr = array![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1416            let b_arr = array![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
1417            let expected: Vec<f32> = a_arr
1418                .iter()
1419                .zip(b_arr.iter())
1420                .map(|(&a, &b)| a * b)
1421                .collect();
1422
1423            let a = ag::tensor_ops::convert_to_tensor(a_arr, ctx);
1424            let b = ag::tensor_ops::convert_to_tensor(b_arr, ctx);
1425            let y = simd_elementwise_mul(&a, &b);
1426
1427            if let Ok(result) = y.eval(ctx) {
1428                if let Some(result_slice) = result.as_slice() {
1429                    assert_approx_eq_f32(result_slice, &expected, 1e-6);
1430                }
1431            }
1432        });
1433    }
1434
1435    #[test]
1436    fn test_simd_elementwise_div_f64() {
1437        ag::run::<f64, _, _>(|ctx| {
1438            let a_arr = array![10.0f64, 20.0, 30.0, 40.0];
1439            let b_arr = array![2.0f64, 4.0, 5.0, 8.0];
1440            let expected: Vec<f64> = a_arr
1441                .iter()
1442                .zip(b_arr.iter())
1443                .map(|(&a, &b)| a / b)
1444                .collect();
1445
1446            let a = ag::tensor_ops::convert_to_tensor(a_arr, ctx);
1447            let b = ag::tensor_ops::convert_to_tensor(b_arr, ctx);
1448            let y = simd_elementwise_div(&a, &b);
1449
1450            if let Ok(result) = y.eval(ctx) {
1451                if let Some(result_slice) = result.as_slice() {
1452                    assert_approx_eq_f64(result_slice, &expected, 1e-12);
1453                }
1454            }
1455        });
1456    }
1457
1458    // -------------------------------------------------------
1459    // Gradient accumulation
1460    // -------------------------------------------------------
1461
1462    #[test]
1463    fn test_simd_gradient_accumulate_f32() {
1464        ag::run::<f32, _, _>(|ctx| {
1465            let acc_arr = array![1.0f32, 2.0, 3.0, 4.0];
1466            let grad_arr = array![0.1f32, 0.2, 0.3, 0.4];
1467            let expected = vec![1.1, 2.2, 3.3, 4.4];
1468
1469            let acc = ag::tensor_ops::convert_to_tensor(acc_arr, ctx);
1470            let grad = ag::tensor_ops::convert_to_tensor(grad_arr, ctx);
1471            let y = simd_gradient_accumulate(&acc, &grad);
1472
1473            if let Ok(result) = y.eval(ctx) {
1474                if let Some(result_slice) = result.as_slice() {
1475                    assert_approx_eq_f32(result_slice, &expected, 1e-6);
1476                }
1477            }
1478        });
1479    }
1480
1481    #[test]
1482    fn test_simd_scaled_gradient_accumulate_f32() {
1483        ag::run::<f32, _, _>(|ctx| {
1484            let acc_arr = array![1.0f32, 2.0, 3.0, 4.0];
1485            let grad_arr = array![10.0f32, 20.0, 30.0, 40.0];
1486            let scale = 0.1f32;
1487            // expected: acc + scale * grad = [2.0, 4.0, 6.0, 8.0]
1488            let expected = vec![2.0, 4.0, 6.0, 8.0];
1489
1490            let acc = ag::tensor_ops::convert_to_tensor(acc_arr, ctx);
1491            let grad = ag::tensor_ops::convert_to_tensor(grad_arr, ctx);
1492            let y = simd_scaled_gradient_accumulate(&acc, &grad, scale);
1493
1494            if let Ok(result) = y.eval(ctx) {
1495                if let Some(result_slice) = result.as_slice() {
1496                    assert_approx_eq_f32(result_slice, &expected, 1e-5);
1497                }
1498            }
1499        });
1500    }
1501
1502    // -------------------------------------------------------
1503    // Activation functions
1504    // -------------------------------------------------------
1505
1506    #[test]
1507    fn test_simd_relu_f32() {
1508        ag::run::<f32, _, _>(|ctx| {
1509            let x_arr = array![-3.0f32, -1.0, 0.0, 1.0, 3.0, -0.5, 2.0, -2.0];
1510            let expected = vec![0.0, 0.0, 0.0, 1.0, 3.0, 0.0, 2.0, 0.0];
1511
1512            let x = ag::tensor_ops::convert_to_tensor(x_arr, ctx);
1513            let y = simd_activation_relu(&x);
1514
1515            if let Ok(result) = y.eval(ctx) {
1516                if let Some(result_slice) = result.as_slice() {
1517                    assert_approx_eq_f32(result_slice, &expected, 1e-6);
1518                }
1519            }
1520        });
1521    }
1522
1523    #[test]
1524    fn test_simd_sigmoid_f32() {
1525        ag::run::<f32, _, _>(|ctx| {
1526            let x_arr = array![0.0f32, 1.0, -1.0, 5.0, -5.0, 0.5, -0.5, 2.0];
1527            let expected: Vec<f32> = x_arr.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect();
1528
1529            let x = ag::tensor_ops::convert_to_tensor(x_arr, ctx);
1530            let y = simd_activation_sigmoid(&x);
1531
1532            if let Ok(result) = y.eval(ctx) {
1533                if let Some(result_slice) = result.as_slice() {
1534                    assert_approx_eq_f32(result_slice, &expected, 1e-4);
1535                }
1536            }
1537        });
1538    }
1539
1540    #[test]
1541    fn test_simd_tanh_f64() {
1542        ag::run::<f64, _, _>(|ctx| {
1543            let x_arr = array![0.0f64, 1.0, -1.0, 2.0, -2.0, 0.5];
1544            let expected: Vec<f64> = x_arr.iter().map(|&v| v.tanh()).collect();
1545
1546            let x = ag::tensor_ops::convert_to_tensor(x_arr, ctx);
1547            let y = simd_activation_tanh(&x);
1548
1549            if let Ok(result) = y.eval(ctx) {
1550                if let Some(result_slice) = result.as_slice() {
1551                    assert_approx_eq_f64(result_slice, &expected, 1e-10);
1552                }
1553            }
1554        });
1555    }
1556
1557    // -------------------------------------------------------
1558    // Dot product / reduction
1559    // -------------------------------------------------------
1560
1561    #[test]
1562    fn test_simd_dot_product_f32() {
1563        ag::run::<f32, _, _>(|ctx| {
1564            let a_arr = array![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1565            let b_arr = array![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
1566            let expected: f32 = a_arr.iter().zip(b_arr.iter()).map(|(&a, &b)| a * b).sum();
1567
1568            let a = ag::tensor_ops::convert_to_tensor(a_arr, ctx);
1569            let b = ag::tensor_ops::convert_to_tensor(b_arr, ctx);
1570            let y = simd_dot_product(&a, &b);
1571
1572            if let Ok(result) = y.eval(ctx) {
1573                let val = result.iter().next().copied().unwrap_or(0.0);
1574                assert!(
1575                    (val - expected).abs() < 1e-3,
1576                    "dot product: got {}, expected {}",
1577                    val,
1578                    expected
1579                );
1580            }
1581        });
1582    }
1583
1584    #[test]
1585    fn test_simd_reduction_sum_f64() {
1586        ag::run::<f64, _, _>(|ctx| {
1587            let x_arr = array![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1588            let expected: f64 = x_arr.iter().sum();
1589
1590            let x = ag::tensor_ops::convert_to_tensor(x_arr, ctx);
1591            let y = simd_reduction_sum(&x);
1592
1593            if let Ok(result) = y.eval(ctx) {
1594                let val = result.iter().next().copied().unwrap_or(0.0);
1595                assert!(
1596                    (val - expected).abs() < 1e-10,
1597                    "sum: got {}, expected {}",
1598                    val,
1599                    expected
1600                );
1601            }
1602        });
1603    }
1604
1605    // -------------------------------------------------------
1606    // Edge cases
1607    // -------------------------------------------------------
1608
1609    #[test]
1610    fn test_simd_empty_array() {
1611        ag::run::<f32, _, _>(|ctx| {
1612            let empty = scirs2_core::ndarray::Array1::<f32>::zeros(0);
1613            let a = ag::tensor_ops::convert_to_tensor(empty.clone(), ctx);
1614            let b = ag::tensor_ops::convert_to_tensor(empty, ctx);
1615            let y = simd_elementwise_add(&a, &b);
1616            if let Ok(result) = y.eval(ctx) {
1617                assert_eq!(result.len(), 0);
1618            }
1619        });
1620    }
1621
1622    #[test]
1623    fn test_simd_single_element() {
1624        ag::run::<f64, _, _>(|ctx| {
1625            let a_arr = array![42.0f64];
1626            let b_arr = array![8.0f64];
1627            let a = ag::tensor_ops::convert_to_tensor(a_arr, ctx);
1628            let b = ag::tensor_ops::convert_to_tensor(b_arr, ctx);
1629            let y = simd_elementwise_mul(&a, &b);
1630            if let Ok(result) = y.eval(ctx) {
1631                if let Some(slice) = result.as_slice() {
1632                    assert_approx_eq_f64(slice, &[336.0], 1e-12);
1633                }
1634            }
1635        });
1636    }
1637
1638    #[test]
1639    fn test_simd_relu_all_negative() {
1640        ag::run::<f32, _, _>(|ctx| {
1641            let x_arr = array![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
1642            let expected = vec![0.0f32; 8];
1643
1644            let x = ag::tensor_ops::convert_to_tensor(x_arr, ctx);
1645            let y = simd_activation_relu(&x);
1646
1647            if let Ok(result) = y.eval(ctx) {
1648                if let Some(result_slice) = result.as_slice() {
1649                    assert_approx_eq_f32(result_slice, &expected, 1e-6);
1650                }
1651            }
1652        });
1653    }
1654
1655    #[test]
1656    fn test_simd_sigmoid_extreme_values() {
1657        ag::run::<f32, _, _>(|ctx| {
1658            let x_arr = array![-100.0f32, 100.0, 0.0, -50.0, 50.0, -10.0, 10.0, 0.0];
1659            // sigmoid(-100) ~ 0, sigmoid(100) ~ 1, sigmoid(0) = 0.5
1660            let x = ag::tensor_ops::convert_to_tensor(x_arr, ctx);
1661            let y = simd_activation_sigmoid(&x);
1662
1663            if let Ok(result) = y.eval(ctx) {
1664                if let Some(slice) = result.as_slice() {
1665                    // Very negative => ~0
1666                    assert!(
1667                        slice[0] < 1e-6,
1668                        "sigmoid(-100) should be near 0, got {}",
1669                        slice[0]
1670                    );
1671                    // Very positive => ~1
1672                    assert!(
1673                        (slice[1] - 1.0).abs() < 1e-6,
1674                        "sigmoid(100) should be near 1, got {}",
1675                        slice[1]
1676                    );
1677                    // Zero => 0.5
1678                    assert!(
1679                        (slice[2] - 0.5).abs() < 1e-4,
1680                        "sigmoid(0) should be 0.5, got {}",
1681                        slice[2]
1682                    );
1683                }
1684            }
1685        });
1686    }
1687
1688    #[test]
1689    fn test_simd_large_array_add() {
1690        // Test with a large array to exercise the full SIMD path (>32 elements)
1691        ag::run::<f32, _, _>(|ctx| {
1692            let n = 1024;
1693            let a_vec: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
1694            let b_vec: Vec<f32> = (0..n).map(|i| (n - i) as f32 * 0.01).collect();
1695            let expected: Vec<f32> = a_vec
1696                .iter()
1697                .zip(b_vec.iter())
1698                .map(|(&a, &b)| a + b)
1699                .collect();
1700
1701            let a_arr = Array1::from_vec(a_vec);
1702            let b_arr = Array1::from_vec(b_vec);
1703
1704            let a = ag::tensor_ops::convert_to_tensor(a_arr, ctx);
1705            let b = ag::tensor_ops::convert_to_tensor(b_arr, ctx);
1706            let y = simd_elementwise_add(&a, &b);
1707
1708            if let Ok(result) = y.eval(ctx) {
1709                if let Some(result_slice) = result.as_slice() {
1710                    assert_approx_eq_f32(result_slice, &expected, 1e-4);
1711                }
1712            }
1713        });
1714    }
1715
1716    // -------------------------------------------------------
1717    // Gradient correctness tests
1718    // -------------------------------------------------------
1719
1720    #[test]
1721    fn test_simd_add_gradient() {
1722        ag::run::<f64, _, _>(|ctx| {
1723            let x = ctx.placeholder("x", &[4]);
1724            let y = ctx.placeholder("y", &[4]);
1725            let z = simd_elementwise_add(&x, &y);
1726            let sum_z = ag::tensor_ops::sum_all(z);
1727
1728            let grads = ag::tensor_ops::grad(&[sum_z], &[x, y]);
1729
1730            let x_val = array![1.0f64, 2.0, 3.0, 4.0];
1731            let y_val = array![5.0f64, 6.0, 7.0, 8.0];
1732
1733            let results = ctx
1734                .evaluator()
1735                .push(&grads[0])
1736                .push(&grads[1])
1737                .feed(x, x_val.view().into_dyn())
1738                .feed(y, y_val.view().into_dyn())
1739                .run();
1740
1741            // d(sum(x+y))/dx = [1,1,1,1], d(sum(x+y))/dy = [1,1,1,1]
1742            if let Some(Ok(dx)) = results.first() {
1743                if let Some(dx_slice) = dx.as_slice() {
1744                    assert_approx_eq_f64(dx_slice, &[1.0, 1.0, 1.0, 1.0], 1e-10);
1745                }
1746            }
1747            if let Some(Ok(dy)) = results.get(1) {
1748                if let Some(dy_slice) = dy.as_slice() {
1749                    assert_approx_eq_f64(dy_slice, &[1.0, 1.0, 1.0, 1.0], 1e-10);
1750                }
1751            }
1752        });
1753    }
1754
1755    #[test]
1756    fn test_simd_mul_gradient() {
1757        ag::run::<f64, _, _>(|ctx| {
1758            let x = ctx.placeholder("x", &[4]);
1759            let y = ctx.placeholder("y", &[4]);
1760            let z = simd_elementwise_mul(&x, &y);
1761            let sum_z = ag::tensor_ops::sum_all(z);
1762
1763            let grads = ag::tensor_ops::grad(&[sum_z], &[x, y]);
1764
1765            let x_val = array![1.0f64, 2.0, 3.0, 4.0];
1766            let y_val = array![5.0f64, 6.0, 7.0, 8.0];
1767
1768            let results = ctx
1769                .evaluator()
1770                .push(&grads[0])
1771                .push(&grads[1])
1772                .feed(x, x_val.view().into_dyn())
1773                .feed(y, y_val.view().into_dyn())
1774                .run();
1775
1776            // d(sum(x*y))/dx = y, d(sum(x*y))/dy = x
1777            if let Some(Ok(dx)) = results.first() {
1778                if let Some(dx_slice) = dx.as_slice() {
1779                    assert_approx_eq_f64(dx_slice, &[5.0, 6.0, 7.0, 8.0], 1e-10);
1780                }
1781            }
1782            if let Some(Ok(dy)) = results.get(1) {
1783                if let Some(dy_slice) = dy.as_slice() {
1784                    assert_approx_eq_f64(dy_slice, &[1.0, 2.0, 3.0, 4.0], 1e-10);
1785                }
1786            }
1787        });
1788    }
1789}