train_station/tensor/ops/
add.rs

//! Addition operations for tensors
//!
//! Provides element-wise addition following PyTorch conventions with comprehensive
//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
//!
//! # Key Features
//!
//! - **Element-wise Addition**: `add_tensor()` - Addition with another tensor (PyTorch `add()` equivalent)
//! - **Scalar Broadcasting**: `add_scalar()` - Addition with scalar values
//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
//! - **SIMD Optimization**: Runtime-detected SSE2/AVX2/AVX-512 acceleration on x86_64 hardware
//! - **Automatic Differentiation**: Full gradtrack support for gradient tracking
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Zero-copy Operations**: Efficient memory usage where possible
//!
//! # Broadcasting Support
//!
//! All addition operations support automatic broadcasting following NumPy rules:
//! - Dimensions are aligned from the rightmost dimension
//! - Dimensions are compatible if they are equal, or one of them is 1
//! - Missing dimensions are treated as 1
//! - The result shape takes the larger extent along each aligned dimension
//!
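//! For example, a missing leading dimension is treated as 1, so a vector can be
//! added to each row of a matrix (an illustrative doc example using the same API
//! as the examples on `add_tensor` below):
//!
//! ```
//! use train_station::Tensor;
//!
//! // [2, 3] + [3] -> [2, 3]
//! let a = Tensor::ones(vec![2, 3]);
//! let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! let c = a.add_tensor(&b);
//! assert_eq!(c.shape().dims(), vec![2, 3]);
//! assert_eq!(c.get(&[0, 0]), 2.0);
//! assert_eq!(c.get(&[1, 2]), 4.0);
//! ```
//!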
//! # Performance Characteristics
//!
//! - **SIMD Acceleration**: 8-lane AVX2 or 16-lane AVX-512 vectorization, selected at runtime (with an SSE2 path for older hardware)
//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33#[cfg(target_arch = "x86_64")]
34use crate::tensor::core::memory::simd_alignment_bytes;
35use crate::tensor::core::memory::{detect_runtime_simd, SimdLevel};
// Thread pool support is invoked via aligned wrappers inside the functions below
// rather than imported here.
use crate::tensor::core::Tensor;
// SIMD kernels for the performance-critical paths are defined later in this file.
40
41// (Removed manual prefetching: simplifies hot path; modern CPUs prefetch effectively for linear access)
42
43/// OPTIMIZATION #6: Cached SIMD kernels and dispatch information for maximum performance
44struct CachedKernels {
45    simd_level: SimdLevel,
46    alignment: usize,
47
48    // Tensor + Tensor kernels
49    tensor_aligned: unsafe fn(*const f32, *const f32, *mut f32, usize),
50    tensor_unaligned: unsafe fn(*const f32, *const f32, *mut f32, usize),
51    tensor_stream: unsafe fn(*const f32, *const f32, *mut f32, usize),
52
53    // Tensor + Scalar kernels
54    scalar_aligned: unsafe fn(*const f32, *mut f32, usize, f32),
55    scalar_unaligned: unsafe fn(*const f32, *mut f32, usize, f32),
56    scalar_stream: unsafe fn(*const f32, *mut f32, usize, f32),
57
58    // Dispatch parameters
59    min_aligned_size: usize,
60    min_stream_size: usize,
61}
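// Note on usage: the function pointers above are resolved once in
// `get_cached_kernels()` (see below) and then invoked through an alignment-based
// dispatch, e.g. `(kernels.tensor_aligned)(a, b, dst, size)` in `try_add_simd_best`.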
62
63impl Tensor {
64    // ===== Streaming store thresholds =====
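    /// Minimum number of f32 elements before non-temporal (streaming) stores are
    /// attempted. Rationale (matching the inline comment below): 1 << 22 elements
    /// * 4 bytes/f32 = 16 MiB, large enough that keeping the output in cache is
    /// unlikely to pay off, which is the usual motivation for streaming stores.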
65    #[inline]
66    pub(crate) fn stream_min_elems() -> usize {
67        1 << 22 // ~16MB per chunk (f32) conservative threshold for streaming stores
68    }
69
    /// Element-wise addition with another tensor, with broadcasting support.
71    ///
72    /// Performs element-wise addition with automatic broadcasting: `output[i] = self[i] + other[i]`
73    ///
74    /// Broadcasting enables addition between tensors of different but compatible shapes.
75    /// Compatible shapes follow NumPy broadcasting rules:
76    /// - Dimensions are aligned from the rightmost dimension
77    /// - Dimensions are compatible if they are equal, or one of them is 1
78    /// - Missing dimensions are treated as 1
79    ///
80    /// # Arguments
81    /// * `other` - Tensor to add. Shapes must be broadcast-compatible.
82    ///
83    /// # Returns
84    /// A new tensor containing the element-wise sum with broadcast result shape
85    ///
86    /// # Examples
87    ///
88    /// ## Same Shape Addition
89    ///
90    /// ```
91    /// use train_station::Tensor;
92    ///
93    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
94    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
95    /// let c = a.add_tensor(&b);
96    /// assert_eq!(c.shape().dims(), vec![3]);
97    /// assert_eq!(c.get(&[0]), 5.0);
98    /// assert_eq!(c.get(&[1]), 7.0);
99    /// assert_eq!(c.get(&[2]), 9.0);
100    /// ```
101    ///
102    /// ## Broadcasting Addition
103    ///
104    /// ```
105    /// use train_station::Tensor;
106    ///
107    /// // Broadcasting: [2, 1] + [1, 3] -> [2, 3]
108    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
109    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
110    /// let c = a.add_tensor(&b);
111    /// assert_eq!(c.shape().dims(), vec![2, 3]);
112    /// assert_eq!(c.get(&[0, 0]), 11.0);
113    /// assert_eq!(c.get(&[0, 1]), 21.0);
114    /// assert_eq!(c.get(&[1, 0]), 12.0);
115    /// assert_eq!(c.get(&[1, 1]), 22.0);
116    /// ```
117    ///
118    /// ## Scalar Broadcasting
119    ///
120    /// ```
121    /// use train_station::Tensor;
122    ///
    /// // Single-element tensor broadcasting: [2, 3] + [1] -> [2, 3]
124    /// let a = Tensor::ones(vec![2, 3]);
125    /// let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
126    /// let c = a.add_tensor(&b);
127    /// assert_eq!(c.shape().dims(), vec![2, 3]);
128    /// assert_eq!(c.get(&[0, 0]), 6.0);
129    /// assert_eq!(c.get(&[1, 2]), 6.0);
130    /// ```
131    ///
132    /// # Panics
133    /// Panics if tensor shapes are not broadcast-compatible
134    #[inline]
135    #[track_caller]
136    pub fn add_tensor(&self, other: &Tensor) -> Tensor {
137        // Check if shapes are identical for fast path
138        if self.shape().dims() == other.shape().dims() {
139            return self.add_tensor_same_shape(other);
140        }
141
142        // Use broadcasting.rs for efficient broadcasting
143        let mut result = self.add_tensor_optimized(other);
144
145        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
146            result.set_requires_grad_internal(true);
147            let grad_fn = GradFn::Add {
148                is_tensor_add: true,
149                original_shapes: Some((
150                    self.shape().dims().to_vec(),
151                    other.shape().dims().to_vec(),
152                )),
153            };
154            result.set_grad_fn(grad_fn.clone());
155
            // Register both inputs so gradients flow back to self and other
            let input_ids = vec![self.id(), other.id()];
158            GradEngine::register_operation(result.id(), input_ids, grad_fn);
159        }
160
161        result
162    }
163
164    /// Element-wise addition for tensors with identical shapes (fast path).
165    #[inline]
166    fn add_tensor_same_shape(&self, other: &Tensor) -> Tensor {
167        assert_eq!(
168            self.shape(),
169            other.shape(),
170            "Tensor shapes must match for same-shape addition"
171        );
172        let mut result = self.add_tensor_same_shape_optimized(other);
173
174        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
175            result.set_requires_grad_internal(true);
176            let grad_fn = GradFn::Add {
177                is_tensor_add: true,
178                original_shapes: None, // Same shape case
179            };
180            result.set_grad_fn(grad_fn.clone());
181
            // Register both inputs so gradients flow back to self and other
            let input_ids = vec![self.id(), other.id()];
184            GradEngine::register_operation(result.id(), input_ids, grad_fn);
185        }
186
187        result
188    }
189
    /// Adds a scalar value to each element of the tensor (broadcast addition with a scalar).
    ///
    /// Computes `output[i] = self[i] + scalar` and supports gradient tracking.
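    ///
    /// # Examples
    ///
    /// A small illustrative example (using the same API as the `add_tensor`
    /// examples above):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let r = t.add_scalar(10.0);
    /// assert_eq!(r.shape().dims(), vec![3]);
    /// assert_eq!(r.get(&[0]), 11.0);
    /// assert_eq!(r.get(&[2]), 13.0);
    /// ```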
191    #[inline]
192    #[track_caller]
193    pub fn add_scalar(&self, scalar: f32) -> Tensor {
194        let mut result = self.add_scalar_optimized(scalar);
195
196        if self.requires_grad() && is_grad_enabled() {
197            result.set_requires_grad_internal(true);
198            let grad_fn = GradFn::Add {
199                is_tensor_add: false,
200                original_shapes: None, // Scalar case
201            };
202            result.set_grad_fn(grad_fn.clone());
            // Register the single tensor input for gradient tracking
            let input_ids = vec![self.id()];
205            GradEngine::register_operation(result.id(), input_ids, grad_fn);
206        }
207
208        result
209    }
210
211    /// Internal optimized tensor + tensor operation with broadcasting support
212    #[inline]
213    pub(crate) fn add_tensor_optimized(&self, other: &Tensor) -> Tensor {
214        // Check if shapes are identical for fast path
215        if self.shape() == other.shape() {
216            return self.add_tensor_same_shape_optimized(other);
217        }
218
219        // Use zero-copy broadcasting to create same-shape views, then reuse optimized kernels
220        use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
221
222        match broadcast_shapes_cow(self, other) {
223            Ok((broadcasted_self, broadcasted_other, _result_shape)) => {
224                debug_assert_eq!(
225                    broadcasted_self.shape().dims(),
226                    broadcasted_other.shape().dims()
227                );
228                broadcasted_self
229                    .as_ref()
230                    .add_tensor_same_shape_optimized(broadcasted_other.as_ref())
231            }
232            Err(BroadcastError::IncompatibleShapes { shape1, shape2, .. }) => {
233                panic!(
234                    "Cannot broadcast tensor shapes {:?} and {:?}: shapes are incompatible",
235                    shape1, shape2
236                );
237            }
238            Err(BroadcastError::AllocationFailed) => {
239                panic!("Memory allocation failed during broadcasting");
240            }
241        }
242    }
243
244    /// Optimized same-shape tensor addition (extracted from original add_tensor_optimized)
245    #[inline]
246    fn add_tensor_same_shape_optimized(&self, other: &Tensor) -> Tensor {
247        debug_assert_eq!(
248            self.shape().dims(),
249            other.shape().dims(),
250            "Tensor dims must match"
251        );
252
253        // Optimize contiguous handling - use views when possible
254        let (a_ptr, _a_keep): (*const f32, Option<Tensor>) = Self::get_optimized_tensor_ptr(self);
255        let (b_ptr, _b_keep): (*const f32, Option<Tensor>) = Self::get_optimized_tensor_ptr(other);
256
257        let mut output = Tensor::new(self.shape().dims().to_vec());
258
259        unsafe {
260            let a = a_ptr;
261            let b = b_ptr;
262            let dst = output.as_mut_ptr();
263            let n = self.size();
264
265            // Sequential execution with SIMD optimization
266            let stream_min = Self::stream_min_elems();
267            if n >= stream_min && Self::try_add_stream_best(a, b, dst, n) {
268                // done via streaming stores
269            } else if !Self::try_add_simd_best(a, b, dst, n) {
270                Self::add_tensors_scalar_chunk(a, b, dst, n);
271            }
272        }
273
274        output
275    }
276
277    // Replaced single AVX2 kernel with multi-level SIMD selection and kernels below
278
279    // ===== OPTIMIZATION #6: Cached SIMD tensor+tensor addition dispatch =====
280    #[inline]
281    unsafe fn try_add_simd_best(a: *const f32, b: *const f32, dst: *mut f32, size: usize) -> bool {
282        if size == 0 {
283            return true;
284        }
285
286        let kernels = Self::get_cached_kernels();
287
288        // Skip SIMD for scalar fallback
289        if matches!(kernels.simd_level, SimdLevel::Scalar) {
290            return false;
291        }
292
293        // Check alignment for optimal kernel selection
294        let a_mod = (a as usize) % kernels.alignment;
295        let b_mod = (b as usize) % kernels.alignment;
296        let d_mod = (dst as usize) % kernels.alignment;
297
298        // All aligned - use fastest kernel
299        if a_mod == 0 && b_mod == 0 && d_mod == 0 && size >= kernels.min_aligned_size {
300            (kernels.tensor_aligned)(a, b, dst, size);
301            return true;
302        }
303
304        // Same misalignment - align and use fast kernel
305        if a_mod == b_mod && b_mod == d_mod && size >= kernels.min_aligned_size {
306            let bytes_to_align = if a_mod == 0 {
307                0
308            } else {
309                kernels.alignment - a_mod
310            };
311            let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(size);
312
313            // Scalar prologue to achieve alignment
314            for i in 0..elems_to_align {
315                *dst.add(i) = *a.add(i) + *b.add(i);
316            }
317
            // Handle the remainder: hand it to the aligned kernel when it is
            // large enough, otherwise finish with a scalar epilogue so every
            // output element is written.
            let rem = size - elems_to_align;
            if rem >= kernels.min_aligned_size {
                (kernels.tensor_aligned)(
                    a.add(elems_to_align),
                    b.add(elems_to_align),
                    dst.add(elems_to_align),
                    rem,
                );
            } else {
                for i in elems_to_align..size {
                    *dst.add(i) = *a.add(i) + *b.add(i);
                }
            }
            return true;
328        }
329
330        // Mixed alignment - use unaligned kernel
331        (kernels.tensor_unaligned)(a, b, dst, size);
332        true
333    }
334
335    // OPTIMIZATION #6: Cached streaming store selection for tensor+tensor addition
336    #[inline]
337    unsafe fn try_add_stream_best(
338        a: *const f32,
339        b: *const f32,
340        dst: *mut f32,
341        size: usize,
342    ) -> bool {
343        let kernels = Self::get_cached_kernels();
344
345        if size < kernels.min_stream_size || size == 0 {
346            return false;
347        }
348
349        // Only use streaming if destination is aligned (critical for streaming stores)
350        if (dst as usize).is_multiple_of(kernels.alignment) {
351            (kernels.tensor_stream)(a, b, dst, size);
352            return true;
353        }
354
355        false
356    }
357
358    #[cfg(target_arch = "x86_64")]
359    #[inline]
360    #[target_feature(enable = "avx512f")]
361    unsafe fn add_simd_avx512_aligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
362        use std::arch::x86_64::*;
363        let mut offset = 0usize;
364        let block = 64usize; // 4x unroll of 16-wide vectors
365        while offset + block <= size {
366            let a1 = _mm512_load_ps(a.add(offset));
367            let b1 = _mm512_load_ps(b.add(offset));
368            _mm512_store_ps(dst.add(offset), _mm512_add_ps(a1, b1));
369
370            let a2 = _mm512_load_ps(a.add(offset + 16));
371            let b2 = _mm512_load_ps(b.add(offset + 16));
372            _mm512_store_ps(dst.add(offset + 16), _mm512_add_ps(a2, b2));
373
374            let a3 = _mm512_load_ps(a.add(offset + 32));
375            let b3 = _mm512_load_ps(b.add(offset + 32));
376            _mm512_store_ps(dst.add(offset + 32), _mm512_add_ps(a3, b3));
377
378            let a4 = _mm512_load_ps(a.add(offset + 48));
379            let b4 = _mm512_load_ps(b.add(offset + 48));
380            _mm512_store_ps(dst.add(offset + 48), _mm512_add_ps(a4, b4));
381            offset += block;
382        }
383        let mut rem = size - offset;
384        while rem >= 16 {
385            let av = _mm512_load_ps(a.add(offset));
386            let bv = _mm512_load_ps(b.add(offset));
387            _mm512_store_ps(dst.add(offset), _mm512_add_ps(av, bv));
388            offset += 16;
389            rem -= 16;
390        }
391        for i in offset..size {
392            *dst.add(i) = *a.add(i) + *b.add(i);
393        }
394    }
395
396    #[cfg(target_arch = "x86_64")]
397    #[inline]
398    #[target_feature(enable = "avx512f")]
399    unsafe fn add_simd_avx512_unaligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
400        use std::arch::x86_64::*;
401        let mut offset = 0usize;
402        let block = 64usize;
403        while offset + block <= size {
404            let a1 = _mm512_loadu_ps(a.add(offset));
405            let b1 = _mm512_loadu_ps(b.add(offset));
406            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(a1, b1));
407
408            let a2 = _mm512_loadu_ps(a.add(offset + 16));
409            let b2 = _mm512_loadu_ps(b.add(offset + 16));
410            _mm512_storeu_ps(dst.add(offset + 16), _mm512_add_ps(a2, b2));
411
412            let a3 = _mm512_loadu_ps(a.add(offset + 32));
413            let b3 = _mm512_loadu_ps(b.add(offset + 32));
414            _mm512_storeu_ps(dst.add(offset + 32), _mm512_add_ps(a3, b3));
415
416            let a4 = _mm512_loadu_ps(a.add(offset + 48));
417            let b4 = _mm512_loadu_ps(b.add(offset + 48));
418            _mm512_storeu_ps(dst.add(offset + 48), _mm512_add_ps(a4, b4));
419            offset += block;
420        }
421        let mut rem = size - offset;
422        while rem >= 16 {
423            let av = _mm512_loadu_ps(a.add(offset));
424            let bv = _mm512_loadu_ps(b.add(offset));
425            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(av, bv));
426            offset += 16;
427            rem -= 16;
428        }
429        for i in offset..size {
430            *dst.add(i) = *a.add(i) + *b.add(i);
431        }
432    }
433
434    #[cfg(target_arch = "x86_64")]
435    #[inline]
436    #[target_feature(enable = "avx512f")]
437    unsafe fn add_simd_avx512_stream(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
438        use std::arch::x86_64::*;
439        let mut offset = 0usize;
440        let block = 64usize;
441        while offset + block <= size {
442            let a1 = _mm512_loadu_ps(a.add(offset));
443            let b1 = _mm512_loadu_ps(b.add(offset));
444            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(a1, b1));
445
446            let a2 = _mm512_loadu_ps(a.add(offset + 16));
447            let b2 = _mm512_loadu_ps(b.add(offset + 16));
448            _mm512_stream_ps(dst.add(offset + 16), _mm512_add_ps(a2, b2));
449
450            let a3 = _mm512_loadu_ps(a.add(offset + 32));
451            let b3 = _mm512_loadu_ps(b.add(offset + 32));
452            _mm512_stream_ps(dst.add(offset + 32), _mm512_add_ps(a3, b3));
453
454            let a4 = _mm512_loadu_ps(a.add(offset + 48));
455            let b4 = _mm512_loadu_ps(b.add(offset + 48));
456            _mm512_stream_ps(dst.add(offset + 48), _mm512_add_ps(a4, b4));
457            offset += block;
458        }
459        let mut rem = size - offset;
460        while rem >= 16 {
461            let av = _mm512_loadu_ps(a.add(offset));
462            let bv = _mm512_loadu_ps(b.add(offset));
463            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(av, bv));
464            offset += 16;
465            rem -= 16;
466        }
467        for i in offset..size {
468            *dst.add(i) = *a.add(i) + *b.add(i);
469        }
470    }
471
472    #[cfg(target_arch = "x86_64")]
473    #[inline]
474    #[target_feature(enable = "avx2")]
475    unsafe fn add_simd_avx2_aligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
476        use std::arch::x86_64::*;
477        let mut offset = 0usize;
478        let block = 32usize; // 4x unroll of 8-wide vectors
479        while offset + block <= size {
480            let a1 = _mm256_load_ps(a.add(offset));
481            let b1 = _mm256_load_ps(b.add(offset));
482            _mm256_store_ps(dst.add(offset), _mm256_add_ps(a1, b1));
483
484            let a2 = _mm256_load_ps(a.add(offset + 8));
485            let b2 = _mm256_load_ps(b.add(offset + 8));
486            _mm256_store_ps(dst.add(offset + 8), _mm256_add_ps(a2, b2));
487
488            let a3 = _mm256_load_ps(a.add(offset + 16));
489            let b3 = _mm256_load_ps(b.add(offset + 16));
490            _mm256_store_ps(dst.add(offset + 16), _mm256_add_ps(a3, b3));
491
492            let a4 = _mm256_load_ps(a.add(offset + 24));
493            let b4 = _mm256_load_ps(b.add(offset + 24));
494            _mm256_store_ps(dst.add(offset + 24), _mm256_add_ps(a4, b4));
495            offset += block;
496        }
497        let mut rem = size - offset;
498        while rem >= 8 {
499            let av = _mm256_load_ps(a.add(offset));
500            let bv = _mm256_load_ps(b.add(offset));
501            _mm256_store_ps(dst.add(offset), _mm256_add_ps(av, bv));
502            offset += 8;
503            rem -= 8;
504        }
505        for i in offset..size {
506            *dst.add(i) = *a.add(i) + *b.add(i);
507        }
508    }
509
510    #[cfg(target_arch = "x86_64")]
511    #[inline]
512    #[target_feature(enable = "avx2")]
513    unsafe fn add_simd_avx2_unaligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
514        use std::arch::x86_64::*;
515        let mut offset = 0usize;
516        let block = 32usize;
517        while offset + block <= size {
518            let a1 = _mm256_loadu_ps(a.add(offset));
519            let b1 = _mm256_loadu_ps(b.add(offset));
520            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(a1, b1));
521
522            let a2 = _mm256_loadu_ps(a.add(offset + 8));
523            let b2 = _mm256_loadu_ps(b.add(offset + 8));
524            _mm256_storeu_ps(dst.add(offset + 8), _mm256_add_ps(a2, b2));
525
526            let a3 = _mm256_loadu_ps(a.add(offset + 16));
527            let b3 = _mm256_loadu_ps(b.add(offset + 16));
528            _mm256_storeu_ps(dst.add(offset + 16), _mm256_add_ps(a3, b3));
529
530            let a4 = _mm256_loadu_ps(a.add(offset + 24));
531            let b4 = _mm256_loadu_ps(b.add(offset + 24));
532            _mm256_storeu_ps(dst.add(offset + 24), _mm256_add_ps(a4, b4));
533            offset += block;
534        }
535        let mut rem = size - offset;
536        while rem >= 8 {
537            let av = _mm256_loadu_ps(a.add(offset));
538            let bv = _mm256_loadu_ps(b.add(offset));
539            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(av, bv));
540            offset += 8;
541            rem -= 8;
542        }
543        for i in offset..size {
544            *dst.add(i) = *a.add(i) + *b.add(i);
545        }
546    }
547
548    #[cfg(target_arch = "x86_64")]
549    #[inline]
550    #[target_feature(enable = "avx2")]
551    unsafe fn add_simd_avx2_stream(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
552        use std::arch::x86_64::*;
553        let mut offset = 0usize;
554        let block = 32usize;
555        while offset + block <= size {
556            let a1 = _mm256_loadu_ps(a.add(offset));
557            let b1 = _mm256_loadu_ps(b.add(offset));
558            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(a1, b1));
559
560            let a2 = _mm256_loadu_ps(a.add(offset + 8));
561            let b2 = _mm256_loadu_ps(b.add(offset + 8));
562            _mm256_stream_ps(dst.add(offset + 8), _mm256_add_ps(a2, b2));
563
564            let a3 = _mm256_loadu_ps(a.add(offset + 16));
565            let b3 = _mm256_loadu_ps(b.add(offset + 16));
566            _mm256_stream_ps(dst.add(offset + 16), _mm256_add_ps(a3, b3));
567
568            let a4 = _mm256_loadu_ps(a.add(offset + 24));
569            let b4 = _mm256_loadu_ps(b.add(offset + 24));
570            _mm256_stream_ps(dst.add(offset + 24), _mm256_add_ps(a4, b4));
571            offset += block;
572        }
573        let mut rem = size - offset;
574        while rem >= 8 {
575            let av = _mm256_loadu_ps(a.add(offset));
576            let bv = _mm256_loadu_ps(b.add(offset));
577            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(av, bv));
578            offset += 8;
579            rem -= 8;
580        }
581        for i in offset..size {
582            *dst.add(i) = *a.add(i) + *b.add(i);
583        }
584    }
585
586    #[cfg(target_arch = "x86_64")]
587    #[inline]
588    #[target_feature(enable = "sse2")]
589    unsafe fn add_simd_sse_aligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
590        use std::arch::x86_64::*;
591        let mut offset = 0usize;
592        let block = 16usize; // 4x unroll of 4-wide vectors
593        while offset + block <= size {
594            let a1 = _mm_load_ps(a.add(offset));
595            let b1 = _mm_load_ps(b.add(offset));
596            _mm_store_ps(dst.add(offset), _mm_add_ps(a1, b1));
597
598            let a2 = _mm_load_ps(a.add(offset + 4));
599            let b2 = _mm_load_ps(b.add(offset + 4));
600            _mm_store_ps(dst.add(offset + 4), _mm_add_ps(a2, b2));
601
602            let a3 = _mm_load_ps(a.add(offset + 8));
603            let b3 = _mm_load_ps(b.add(offset + 8));
604            _mm_store_ps(dst.add(offset + 8), _mm_add_ps(a3, b3));
605
606            let a4 = _mm_load_ps(a.add(offset + 12));
607            let b4 = _mm_load_ps(b.add(offset + 12));
608            _mm_store_ps(dst.add(offset + 12), _mm_add_ps(a4, b4));
609            offset += block;
610        }
611        let mut rem = size - offset;
612        while rem >= 4 {
613            let av = _mm_load_ps(a.add(offset));
614            let bv = _mm_load_ps(b.add(offset));
615            _mm_store_ps(dst.add(offset), _mm_add_ps(av, bv));
616            offset += 4;
617            rem -= 4;
618        }
619        for i in offset..size {
620            *dst.add(i) = *a.add(i) + *b.add(i);
621        }
622    }
623
624    #[cfg(target_arch = "x86_64")]
625    #[inline]
626    #[target_feature(enable = "sse2")]
627    unsafe fn add_simd_sse_unaligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
628        use std::arch::x86_64::*;
629        let mut offset = 0usize;
630        let block = 16usize;
631        while offset + block <= size {
632            let a1 = _mm_loadu_ps(a.add(offset));
633            let b1 = _mm_loadu_ps(b.add(offset));
634            _mm_storeu_ps(dst.add(offset), _mm_add_ps(a1, b1));
635
636            let a2 = _mm_loadu_ps(a.add(offset + 4));
637            let b2 = _mm_loadu_ps(b.add(offset + 4));
638            _mm_storeu_ps(dst.add(offset + 4), _mm_add_ps(a2, b2));
639
640            let a3 = _mm_loadu_ps(a.add(offset + 8));
641            let b3 = _mm_loadu_ps(b.add(offset + 8));
642            _mm_storeu_ps(dst.add(offset + 8), _mm_add_ps(a3, b3));
643
644            let a4 = _mm_loadu_ps(a.add(offset + 12));
645            let b4 = _mm_loadu_ps(b.add(offset + 12));
646            _mm_storeu_ps(dst.add(offset + 12), _mm_add_ps(a4, b4));
647            offset += block;
648        }
649        let mut rem = size - offset;
650        while rem >= 4 {
651            let s = _mm_loadu_ps(a.add(offset));
652            let t = _mm_loadu_ps(b.add(offset));
653            _mm_storeu_ps(dst.add(offset), _mm_add_ps(s, t));
654            offset += 4;
655            rem -= 4;
656        }
657        for i in offset..size {
658            *dst.add(i) = *a.add(i) + *b.add(i);
659        }
660    }
661
662    #[cfg(target_arch = "x86_64")]
663    #[inline]
664    #[target_feature(enable = "sse2")]
665    unsafe fn add_simd_sse_stream(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
666        use std::arch::x86_64::*;
667        let mut offset = 0usize;
668        let block = 16usize;
669        while offset + block <= size {
670            let a1 = _mm_loadu_ps(a.add(offset));
671            let b1 = _mm_loadu_ps(b.add(offset));
672            _mm_stream_ps(dst.add(offset), _mm_add_ps(a1, b1));
673
674            let a2 = _mm_loadu_ps(a.add(offset + 4));
675            let b2 = _mm_loadu_ps(b.add(offset + 4));
676            _mm_stream_ps(dst.add(offset + 4), _mm_add_ps(a2, b2));
677
678            let a3 = _mm_loadu_ps(a.add(offset + 8));
679            let b3 = _mm_loadu_ps(b.add(offset + 8));
680            _mm_stream_ps(dst.add(offset + 8), _mm_add_ps(a3, b3));
681
682            let a4 = _mm_loadu_ps(a.add(offset + 12));
683            let b4 = _mm_loadu_ps(b.add(offset + 12));
684            _mm_stream_ps(dst.add(offset + 12), _mm_add_ps(a4, b4));
685            offset += block;
686        }
687        let mut rem = size - offset;
688        while rem >= 4 {
689            let s = _mm_loadu_ps(a.add(offset));
690            let t = _mm_loadu_ps(b.add(offset));
691            _mm_stream_ps(dst.add(offset), _mm_add_ps(s, t));
692            offset += 4;
693            rem -= 4;
694        }
695        for i in offset..size {
696            *dst.add(i) = *a.add(i) + *b.add(i);
697        }
698    }
699
700    // ===== SIMD scalar addition selection and kernels =====
701    #[inline]
702    // OPTIMIZATION #6: Cached scalar SIMD dispatch
703    unsafe fn try_add_scalar_simd_best(
704        src: *const f32,
705        dst: *mut f32,
706        size: usize,
707        scalar: f32,
708    ) -> bool {
709        if size == 0 {
710            return true;
711        }
712
713        let kernels = Self::get_cached_kernels();
714
715        // Skip SIMD for scalar fallback
716        if matches!(kernels.simd_level, SimdLevel::Scalar) {
717            return false;
718        }
719
720        // Check alignment for optimal kernel selection
721        let s_mod = (src as usize) % kernels.alignment;
722        let d_mod = (dst as usize) % kernels.alignment;
723
724        // Both aligned - use fastest kernel
725        if s_mod == 0 && d_mod == 0 && size >= kernels.min_aligned_size {
726            (kernels.scalar_aligned)(src, dst, size, scalar);
727            return true;
728        }
729
730        // Same misalignment - align and use fast kernel
731        if s_mod == d_mod && size >= kernels.min_aligned_size {
732            let bytes_to_align = if s_mod == 0 {
733                0
734            } else {
735                kernels.alignment - s_mod
736            };
737            let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(size);
738
739            // Scalar prologue to achieve alignment
740            for i in 0..elems_to_align {
741                *dst.add(i) = *src.add(i) + scalar;
742            }
743
            // Handle the remainder: hand it to the aligned kernel when it is
            // large enough, otherwise finish with a scalar epilogue so every
            // output element is written.
            let rem = size - elems_to_align;
            if rem >= kernels.min_aligned_size {
                (kernels.scalar_aligned)(
                    src.add(elems_to_align),
                    dst.add(elems_to_align),
                    rem,
                    scalar,
                );
            } else {
                for i in elems_to_align..size {
                    *dst.add(i) = *src.add(i) + scalar;
                }
            }
            return true;
754        }
755
756        // Mixed alignment - use unaligned kernel
757        (kernels.scalar_unaligned)(src, dst, size, scalar);
758        true
759    }
760
761    // OPTIMIZATION #6: Cached streaming store selection for scalar addition
762    #[inline]
763    unsafe fn try_add_scalar_stream_best(
764        src: *const f32,
765        dst: *mut f32,
766        size: usize,
767        scalar: f32,
768    ) -> bool {
769        let kernels = Self::get_cached_kernels();
770
771        if size < kernels.min_stream_size || size == 0 {
772            return false;
773        }
774
775        // Only use streaming if destination is aligned (critical for streaming stores)
776        if (dst as usize).is_multiple_of(kernels.alignment) {
777            (kernels.scalar_stream)(src, dst, size, scalar);
778            return true;
779        }
780
781        false
782    }
783
784    #[cfg(target_arch = "x86_64")]
785    #[inline]
786    #[target_feature(enable = "avx512f")]
787    unsafe fn add_scalar_avx512_aligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
788        use std::arch::x86_64::*;
789        let sv = _mm512_set1_ps(scalar);
790        let mut offset = 0usize;
791        let block = 64usize; // 4x 16-wide
792        while offset + block <= size {
793            let s1 = _mm512_load_ps(src.add(offset));
794            _mm512_store_ps(dst.add(offset), _mm512_add_ps(s1, sv));
795            let s2 = _mm512_load_ps(src.add(offset + 16));
796            _mm512_store_ps(dst.add(offset + 16), _mm512_add_ps(s2, sv));
797            let s3 = _mm512_load_ps(src.add(offset + 32));
798            _mm512_store_ps(dst.add(offset + 32), _mm512_add_ps(s3, sv));
799            let s4 = _mm512_load_ps(src.add(offset + 48));
800            _mm512_store_ps(dst.add(offset + 48), _mm512_add_ps(s4, sv));
801            offset += block;
802        }
803        let mut rem = size - offset;
804        while rem >= 16 {
805            let s = _mm512_load_ps(src.add(offset));
806            _mm512_store_ps(dst.add(offset), _mm512_add_ps(s, sv));
807            offset += 16;
808            rem -= 16;
809        }
810        for i in offset..size {
811            *dst.add(i) = *src.add(i) + scalar;
812        }
813    }
814
815    #[cfg(target_arch = "x86_64")]
816    #[inline]
817    #[target_feature(enable = "avx512f")]
818    unsafe fn add_scalar_avx512_unaligned(
819        src: *const f32,
820        dst: *mut f32,
821        size: usize,
822        scalar: f32,
823    ) {
824        use std::arch::x86_64::*;
825        let sv = _mm512_set1_ps(scalar);
826        let mut offset = 0usize;
827        let block = 64usize;
828        while offset + block <= size {
829            let s1 = _mm512_loadu_ps(src.add(offset));
830            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(s1, sv));
831            let s2 = _mm512_loadu_ps(src.add(offset + 16));
832            _mm512_storeu_ps(dst.add(offset + 16), _mm512_add_ps(s2, sv));
833            let s3 = _mm512_loadu_ps(src.add(offset + 32));
834            _mm512_storeu_ps(dst.add(offset + 32), _mm512_add_ps(s3, sv));
835            let s4 = _mm512_loadu_ps(src.add(offset + 48));
836            _mm512_storeu_ps(dst.add(offset + 48), _mm512_add_ps(s4, sv));
837            offset += block;
838        }
839        let mut rem = size - offset;
840        while rem >= 16 {
841            let s = _mm512_loadu_ps(src.add(offset));
842            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(s, sv));
843            offset += 16;
844            rem -= 16;
845        }
846        for i in offset..size {
847            *dst.add(i) = *src.add(i) + scalar;
848        }
849    }
850
851    #[cfg(target_arch = "x86_64")]
852    #[inline]
853    #[target_feature(enable = "avx512f")]
854    unsafe fn add_scalar_avx512_stream(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
855        use std::arch::x86_64::*;
856        let sv = _mm512_set1_ps(scalar);
857        let mut offset = 0usize;
858        let block = 64usize;
859        while offset + block <= size {
860            let s1 = _mm512_loadu_ps(src.add(offset));
861            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(s1, sv));
862            let s2 = _mm512_loadu_ps(src.add(offset + 16));
863            _mm512_stream_ps(dst.add(offset + 16), _mm512_add_ps(s2, sv));
864            let s3 = _mm512_loadu_ps(src.add(offset + 32));
865            _mm512_stream_ps(dst.add(offset + 32), _mm512_add_ps(s3, sv));
866            let s4 = _mm512_loadu_ps(src.add(offset + 48));
867            _mm512_stream_ps(dst.add(offset + 48), _mm512_add_ps(s4, sv));
868            offset += block;
869        }
870        let mut rem = size - offset;
871        while rem >= 16 {
872            let s = _mm512_loadu_ps(src.add(offset));
873            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(s, sv));
874            offset += 16;
875            rem -= 16;
876        }
877        for i in offset..size {
878            *dst.add(i) = *src.add(i) + scalar;
879        }
880    }
881
882    #[cfg(target_arch = "x86_64")]
883    #[inline]
884    #[target_feature(enable = "avx2")]
885    unsafe fn add_scalar_avx2_aligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
886        use std::arch::x86_64::*;
887        let sv = _mm256_set1_ps(scalar);
888        let mut offset = 0usize;
889        let block = 32usize;
890        while offset + block <= size {
891            let s1 = _mm256_load_ps(src.add(offset));
892            _mm256_store_ps(dst.add(offset), _mm256_add_ps(s1, sv));
893            let s2 = _mm256_load_ps(src.add(offset + 8));
894            _mm256_store_ps(dst.add(offset + 8), _mm256_add_ps(s2, sv));
895            let s3 = _mm256_load_ps(src.add(offset + 16));
896            _mm256_store_ps(dst.add(offset + 16), _mm256_add_ps(s3, sv));
897            let s4 = _mm256_load_ps(src.add(offset + 24));
898            _mm256_store_ps(dst.add(offset + 24), _mm256_add_ps(s4, sv));
899            offset += block;
900        }
901        let mut rem = size - offset;
902        while rem >= 8 {
903            let s = _mm256_load_ps(src.add(offset));
904            _mm256_store_ps(dst.add(offset), _mm256_add_ps(s, sv));
905            offset += 8;
906            rem -= 8;
907        }
908        for i in offset..size {
909            *dst.add(i) = *src.add(i) + scalar;
910        }
911    }
912
913    #[cfg(target_arch = "x86_64")]
914    #[inline]
915    #[target_feature(enable = "avx2")]
916    unsafe fn add_scalar_avx2_unaligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
917        use std::arch::x86_64::*;
918        let sv = _mm256_set1_ps(scalar);
919        let mut offset = 0usize;
920        let block = 32usize;
921        while offset + block <= size {
922            let s1 = _mm256_loadu_ps(src.add(offset));
923            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(s1, sv));
924            let s2 = _mm256_loadu_ps(src.add(offset + 8));
925            _mm256_storeu_ps(dst.add(offset + 8), _mm256_add_ps(s2, sv));
926            let s3 = _mm256_loadu_ps(src.add(offset + 16));
927            _mm256_storeu_ps(dst.add(offset + 16), _mm256_add_ps(s3, sv));
928            let s4 = _mm256_loadu_ps(src.add(offset + 24));
929            _mm256_storeu_ps(dst.add(offset + 24), _mm256_add_ps(s4, sv));
930            offset += block;
931        }
932        let mut rem = size - offset;
933        while rem >= 8 {
934            let s = _mm256_loadu_ps(src.add(offset));
935            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(s, sv));
936            offset += 8;
937            rem -= 8;
938        }
939        for i in offset..size {
940            *dst.add(i) = *src.add(i) + scalar;
941        }
942    }
943
944    #[cfg(target_arch = "x86_64")]
945    #[inline]
946    #[target_feature(enable = "avx2")]
947    unsafe fn add_scalar_avx2_stream(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
948        use std::arch::x86_64::*;
949        let sv = _mm256_set1_ps(scalar);
950        let mut offset = 0usize;
951        let block = 32usize;
952        while offset + block <= size {
953            let s1 = _mm256_loadu_ps(src.add(offset));
954            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(s1, sv));
955            let s2 = _mm256_loadu_ps(src.add(offset + 8));
956            _mm256_stream_ps(dst.add(offset + 8), _mm256_add_ps(s2, sv));
957            let s3 = _mm256_loadu_ps(src.add(offset + 16));
958            _mm256_stream_ps(dst.add(offset + 16), _mm256_add_ps(s3, sv));
959            let s4 = _mm256_loadu_ps(src.add(offset + 24));
960            _mm256_stream_ps(dst.add(offset + 24), _mm256_add_ps(s4, sv));
961            offset += block;
962        }
963        let mut rem = size - offset;
964        while rem >= 8 {
965            let s = _mm256_loadu_ps(src.add(offset));
966            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(s, sv));
967            offset += 8;
968            rem -= 8;
969        }
970        for i in offset..size {
971            *dst.add(i) = *src.add(i) + scalar;
972        }
973    }
974
975    #[cfg(target_arch = "x86_64")]
976    #[inline]
977    #[target_feature(enable = "sse2")]
978    unsafe fn add_scalar_sse_aligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
979        use std::arch::x86_64::*;
980        let sv = _mm_set1_ps(scalar);
981        let mut offset = 0usize;
982        let block = 16usize; // 4x 4-wide
983        while offset + block <= size {
984            let s1 = _mm_load_ps(src.add(offset));
985            _mm_store_ps(dst.add(offset), _mm_add_ps(s1, sv));
986            let s2 = _mm_load_ps(src.add(offset + 4));
987            _mm_store_ps(dst.add(offset + 4), _mm_add_ps(s2, sv));
988            let s3 = _mm_load_ps(src.add(offset + 8));
989            _mm_store_ps(dst.add(offset + 8), _mm_add_ps(s3, sv));
990            let s4 = _mm_load_ps(src.add(offset + 12));
991            _mm_store_ps(dst.add(offset + 12), _mm_add_ps(s4, sv));
992            offset += block;
993        }
994        let mut rem = size - offset;
995        while rem >= 4 {
996            let s = _mm_load_ps(src.add(offset));
997            _mm_store_ps(dst.add(offset), _mm_add_ps(s, sv));
998            offset += 4;
999            rem -= 4;
1000        }
1001        for i in offset..size {
1002            *dst.add(i) = *src.add(i) + scalar;
1003        }
1004    }
1005
1006    #[cfg(target_arch = "x86_64")]
1007    #[inline]
1008    #[target_feature(enable = "sse2")]
1009    unsafe fn add_scalar_sse_unaligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
1010        use std::arch::x86_64::*;
1011        let sv = _mm_set1_ps(scalar);
1012        let mut offset = 0usize;
1013        let block = 16usize;
1014        while offset + block <= size {
1015            let s1 = _mm_loadu_ps(src.add(offset));
1016            _mm_storeu_ps(dst.add(offset), _mm_add_ps(s1, sv));
1017            let s2 = _mm_loadu_ps(src.add(offset + 4));
1018            _mm_storeu_ps(dst.add(offset + 4), _mm_add_ps(s2, sv));
1019            let s3 = _mm_loadu_ps(src.add(offset + 8));
1020            _mm_storeu_ps(dst.add(offset + 8), _mm_add_ps(s3, sv));
1021            let s4 = _mm_loadu_ps(src.add(offset + 12));
1022            _mm_storeu_ps(dst.add(offset + 12), _mm_add_ps(s4, sv));
1023            offset += block;
1024        }
1025        let mut rem = size - offset;
1026        while rem >= 4 {
1027            let s = _mm_loadu_ps(src.add(offset));
1028            _mm_storeu_ps(dst.add(offset), _mm_add_ps(s, sv));
1029            offset += 4;
1030            rem -= 4;
1031        }
1032        for i in offset..size {
1033            *dst.add(i) = *src.add(i) + scalar;
1034        }
1035    }
1036
1037    #[cfg(target_arch = "x86_64")]
1038    #[inline]
1039    #[target_feature(enable = "sse2")]
1040    unsafe fn add_scalar_sse_stream(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
1041        use std::arch::x86_64::*;
1042        let sv = _mm_set1_ps(scalar);
1043        let mut offset = 0usize;
1044        let block = 16usize;
1045        while offset + block <= size {
1046            let s1 = _mm_loadu_ps(src.add(offset));
1047            _mm_stream_ps(dst.add(offset), _mm_add_ps(s1, sv));
1048            let s2 = _mm_loadu_ps(src.add(offset + 4));
1049            _mm_stream_ps(dst.add(offset + 4), _mm_add_ps(s2, sv));
1050            let s3 = _mm_loadu_ps(src.add(offset + 8));
1051            _mm_stream_ps(dst.add(offset + 8), _mm_add_ps(s3, sv));
1052            let s4 = _mm_loadu_ps(src.add(offset + 12));
1053            _mm_stream_ps(dst.add(offset + 12), _mm_add_ps(s4, sv));
1054            offset += block;
1055        }
1056        let mut rem = size - offset;
1057        while rem >= 4 {
1058            let s = _mm_loadu_ps(src.add(offset));
1059            _mm_stream_ps(dst.add(offset), _mm_add_ps(s, sv));
1060            offset += 4;
1061            rem -= 4;
1062        }
1063        for i in offset..size {
1064            *dst.add(i) = *src.add(i) + scalar;
1065        }
1066    }
1067
1068    /// Optimized tensor+tensor scalar fallback (chunked)
1069    #[inline]
1070    unsafe fn add_tensors_scalar_chunk(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
1071        let unroll_count = size / 8;
1072        let mut offset = 0;
1073
1074        for _ in 0..unroll_count {
1075            *dst.add(offset) = *a.add(offset) + *b.add(offset);
1076            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
1077            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
1078            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
1079            *dst.add(offset + 4) = *a.add(offset + 4) + *b.add(offset + 4);
1080            *dst.add(offset + 5) = *a.add(offset + 5) + *b.add(offset + 5);
1081            *dst.add(offset + 6) = *a.add(offset + 6) + *b.add(offset + 6);
1082            *dst.add(offset + 7) = *a.add(offset + 7) + *b.add(offset + 7);
1083            offset += 8;
1084        }
1085        for i in offset..size {
1086            *dst.add(i) = *a.add(i) + *b.add(i);
1087        }
1088    }
1089
    /// Internal optimized tensor + scalar operation
1091    #[inline]
1092    pub(crate) fn add_scalar_optimized(&self, scalar: f32) -> Tensor {
1093        // Optimize contiguous handling - use views when possible
1094        let (src_ptr, _src_keep): (*const f32, Option<Tensor>) =
1095            Self::get_optimized_tensor_ptr(self);
1096        let mut output = Tensor::new(self.shape().dims().to_vec());
1097
1098        unsafe {
1099            let src = src_ptr;
1100            let dst = output.as_mut_ptr();
1101            let n = self.size();
1102
1103            // Sequential execution with SIMD optimization
1104            let stream_min = Self::stream_min_elems();
1105            if n >= stream_min && Self::try_add_scalar_stream_best(src, dst, n, scalar) {
1106                // done via streaming stores
1107            } else if !Self::try_add_scalar_simd_best(src, dst, n, scalar) {
1108                Self::add_scalar_fallback_chunk(src, dst, n, scalar);
1109            }
1110        }
1111
1112        output
1113    }
1114
1115    // Replaced single AVX2 scalar kernel with multi-level SIMD selection and kernels below
1116
1117    /// Optimized scalar addition fallback (chunked)
1118    #[inline]
1119    unsafe fn add_scalar_fallback_chunk(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
1120        let unroll_count = size / 8;
1121        let mut offset = 0;
1122
1123        for _ in 0..unroll_count {
1124            *dst.add(offset) = *src.add(offset) + scalar;
1125            *dst.add(offset + 1) = *src.add(offset + 1) + scalar;
1126            *dst.add(offset + 2) = *src.add(offset + 2) + scalar;
1127            *dst.add(offset + 3) = *src.add(offset + 3) + scalar;
1128            *dst.add(offset + 4) = *src.add(offset + 4) + scalar;
1129            *dst.add(offset + 5) = *src.add(offset + 5) + scalar;
1130            *dst.add(offset + 6) = *src.add(offset + 6) + scalar;
1131            *dst.add(offset + 7) = *src.add(offset + 7) + scalar;
1132            offset += 8;
1133        }
1134        for i in offset..size {
1135            *dst.add(i) = *src.add(i) + scalar;
1136        }
1137    }
1138
1139    /// Get optimized tensor pointer with intelligent contiguous handling
1140    ///
    /// This function avoids unnecessary contiguous copies by:
    /// 1. Using direct pointers for contiguous tensors
    /// 2. Creating a temporary contiguous copy only when the tensor is not contiguous
    ///
    /// Stride-based access for simple non-contiguous patterns is a planned
    /// optimization (see `can_use_stride_based_access` below, which currently
    /// always returns `false`).
1145    ///
1146    /// Returns (pointer, optional_owned_tensor_to_keep_alive)
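    ///
    /// Call pattern (as used in `add_tensor_same_shape_optimized` above):
    /// `let (ptr, _keep) = Self::get_optimized_tensor_ptr(tensor);` - the `_keep`
    /// binding holds any temporary contiguous copy alive while the raw pointer is
    /// in use.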
1147    #[inline]
1148    fn get_optimized_tensor_ptr(tensor: &Tensor) -> (*const f32, Option<Tensor>) {
1149        unsafe {
1150            if tensor.is_contiguous() {
1151                (tensor.as_ptr(), None)
1152            } else {
1153                // Materialize non-contiguous views to ensure linear access for SIMD kernels
1154                let tmp = tensor.contiguous();
1155                (tmp.as_ptr(), Some(tmp))
1156            }
1157        }
1158    }
1159
    /// Check if we can use stride-based access instead of copying
    ///
    /// OPTIMIZATION #2 (planned): stride-aware SIMD for common patterns would
    /// eliminate unnecessary `contiguous()` copies. Not implemented yet, so this
    /// always returns `false`.
1164    #[inline]
1165    #[allow(dead_code)]
1166    fn can_use_stride_based_access(_tensor: &Tensor) -> bool {
1167        false
1168    }
1169
1170    // ===== OPTIMIZATION #6: Comprehensive SIMD Kernel Caching =====
1171
1172    /// Get cached SIMD kernels - single initialization, maximum performance
1173    #[inline]
1174    fn get_cached_kernels() -> &'static CachedKernels {
1175        use std::sync::OnceLock;
1176
1177        static CACHED_KERNELS: OnceLock<CachedKernels> = OnceLock::new();
1178
1179        CACHED_KERNELS.get_or_init(|| {
1180            let simd_level = detect_runtime_simd();
1181            #[cfg(target_arch = "x86_64")]
1182            let alignment = simd_alignment_bytes(simd_level);
1183
1184            #[cfg(target_arch = "x86_64")]
1185            {
1186                match simd_level {
1187                    SimdLevel::Avx512 => CachedKernels {
1188                        simd_level,
1189                        alignment,
1190                        tensor_aligned: Self::add_simd_avx512_aligned,
1191                        tensor_unaligned: Self::add_simd_avx512_unaligned,
1192                        tensor_stream: Self::add_simd_avx512_stream,
1193                        scalar_aligned: Self::add_scalar_avx512_aligned,
1194                        scalar_unaligned: Self::add_scalar_avx512_unaligned,
1195                        scalar_stream: Self::add_scalar_avx512_stream,
1196                        min_aligned_size: 16,
1197                        min_stream_size: Self::stream_min_elems(),
1198                    },
1199                    SimdLevel::Avx2 => CachedKernels {
1200                        simd_level,
1201                        alignment,
1202                        tensor_aligned: Self::add_simd_avx2_aligned,
1203                        tensor_unaligned: Self::add_simd_avx2_unaligned,
1204                        tensor_stream: Self::add_simd_avx2_stream,
1205                        scalar_aligned: Self::add_scalar_avx2_aligned,
1206                        scalar_unaligned: Self::add_scalar_avx2_unaligned,
1207                        scalar_stream: Self::add_scalar_avx2_stream,
1208                        min_aligned_size: 8,
1209                        min_stream_size: Self::stream_min_elems(),
1210                    },
1211                    SimdLevel::Sse2 => CachedKernels {
1212                        simd_level,
1213                        alignment,
1214                        tensor_aligned: Self::add_simd_sse_aligned,
1215                        tensor_unaligned: Self::add_simd_sse_unaligned,
1216                        tensor_stream: Self::add_simd_sse_stream,
1217                        scalar_aligned: Self::add_scalar_sse_aligned,
1218                        scalar_unaligned: Self::add_scalar_sse_unaligned,
1219                        scalar_stream: Self::add_scalar_sse_stream,
1220                        min_aligned_size: 4,
1221                        min_stream_size: Self::stream_min_elems(),
1222                    },
1223                    SimdLevel::Scalar => CachedKernels {
1224                        simd_level,
1225                        alignment: 4, // f32 alignment
1226                        tensor_aligned: Self::add_tensors_scalar_chunk,
1227                        tensor_unaligned: Self::add_tensors_scalar_chunk,
1228                        tensor_stream: Self::add_tensors_scalar_chunk,
1229                        scalar_aligned: Self::add_scalar_fallback_chunk,
1230                        scalar_unaligned: Self::add_scalar_fallback_chunk,
1231                        scalar_stream: Self::add_scalar_fallback_chunk,
1232                        min_aligned_size: 1,
1233                        min_stream_size: usize::MAX, // Never use streaming for scalar
1234                    },
1235                }
1236            }
1237
1238            #[cfg(not(target_arch = "x86_64"))]
1239            {
1240                CachedKernels {
1241                    simd_level,
1242                    alignment: 4, // f32 alignment
1243                    tensor_aligned: Self::add_tensors_scalar_chunk,
1244                    tensor_unaligned: Self::add_tensors_scalar_chunk,
1245                    tensor_stream: Self::add_tensors_scalar_chunk,
1246                    scalar_aligned: Self::add_scalar_fallback_chunk,
1247                    scalar_unaligned: Self::add_scalar_fallback_chunk,
1248                    scalar_stream: Self::add_scalar_fallback_chunk,
1249                    min_aligned_size: 1,
1250                    min_stream_size: usize::MAX,
1251                }
1252            }
1253        })
1254    }
1255}
1256
1257#[cfg(test)]
1258mod tests {
1259    use super::*;
1260    use std::sync::Arc;
1261    use std::thread;
1262
1263    #[test]
1264    fn test_tensor_addition() {
1265        let a = Tensor::ones(vec![2, 3]);
1266        let b = Tensor::ones(vec![2, 3]);
1267        let result = a.add_tensor_optimized(&b);
1268
1269        assert_eq!(result.shape().dims(), vec![2, 3]);
1270        assert_eq!(result.size(), 6);
1271
1272        // Check that all values are 2.0 (1.0 + 1.0)
1273        unsafe {
1274            for i in 0..result.size() {
1275                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
1276            }
1277        }
1278    }
1279
1280    #[test]
1281    fn test_thread_safety_cross_thread_ops() {
1282        use crate::gradtrack::clear_gradients;
1283        clear_gradients();
1284
1285        // Create base tensors with gradient tracking
1286        let a = Arc::new(Tensor::ones(vec![2, 3]).with_requires_grad());
1287        let b = Arc::new(Tensor::ones(vec![2, 3]).with_requires_grad());
1288
1289        // Perform entire forward + backward in a single worker thread (TLS-bound grad graph)
1290        let a1 = a.clone();
1291        let b1 = b.clone();
1292        let handle = thread::spawn(move || {
1293            let t_local1 = Tensor::from_slice(&[2.0; 6], vec![2, 3]).unwrap();
1294            let r1 = (*a1).add_tensor(&t_local1); // (a + 2)
1295
1296            let t_local2 = Tensor::ones(vec![2, 3]);
1297            let r2 = t_local2.add_tensor(&b1); // (1 + b)
1298
1299            let combined = r1.add_tensor(&r2); // a + b + 3
1300            let mut loss = combined.sum();
1301            loss.backward(None);
1302
1303            let ga = (*a1).grad_owned().expect("grad for a (thread)");
1304            let gb = (*b1).grad_owned().expect("grad for b (thread)");
1305            let ga_sum = unsafe { (0..ga.size()).map(|i| *ga.as_ptr().add(i)).sum::<f32>() };
1306            let gb_sum = unsafe { (0..gb.size()).map(|i| *gb.as_ptr().add(i)).sum::<f32>() };
1307            (
1308                ga.shape().dims().to_vec(),
1309                gb.shape().dims().to_vec(),
1310                ga_sum,
1311                gb_sum,
1312            )
1313        });
1314
1315        let (ga_dims, gb_dims, ga_sum, gb_sum) = handle.join().expect("worker panicked");
1316        assert_eq!(ga_dims, vec![2, 3]);
1317        assert_eq!(gb_dims, vec![2, 3]);
1318        // All ones accumulated: 6 elements each
1319        assert!((ga_sum - 6.0).abs() < 1e-6);
1320        assert!((gb_sum - 6.0).abs() < 1e-6);
1321    }
1322
1323    #[test]
1324    fn test_thread_safety_parallel_large_add_backward() {
1325        use crate::gradtrack::clear_gradients;
1326        clear_gradients();
1327
1328        // Size chosen to exceed parallel threshold and align with SIMD-friendly chunking
1329        let n = 8_388_608; // 32MB of f32 data
1330        let a = Arc::new(Tensor::ones(vec![n]).with_requires_grad());
1331        let b = Arc::new(Tensor::ones(vec![n]).with_requires_grad());
1332
1333        // Perform large forward + backward entirely within the worker thread
1334        let at = a.clone();
1335        let bt = b.clone();
1336        let handle = thread::spawn(move || {
1337            let result = (*at).add_tensor(&bt);
1338            let mut loss = result.sum();
1339            loss.backward(None);
1340            let ga = (*at).grad_owned().expect("grad for a (thread)");
1341            let gb = (*bt).grad_owned().expect("grad for b (thread)");
1342            // Return sizes and simple sums to avoid moving huge tensors across threads
1343            let ga_sum = unsafe { (0..ga.size()).map(|i| *ga.as_ptr().add(i)).sum::<f32>() };
1344            let gb_sum = unsafe { (0..gb.size()).map(|i| *gb.as_ptr().add(i)).sum::<f32>() };
1345            (
1346                ga.shape().dims().to_vec(),
1347                gb.shape().dims().to_vec(),
1348                ga_sum,
1349                gb_sum,
1350            )
1351        });
1352
1353        let (ga_dims, gb_dims, ga_sum, gb_sum) = handle.join().expect("worker thread panicked");
1354        assert_eq!(ga_dims, vec![n]);
1355        assert_eq!(gb_dims, vec![n]);
1356        // Each gradient should be all ones: sum equals n
1357        assert!((ga_sum - n as f32).abs() < 1e-3);
1358        assert!((gb_sum - n as f32).abs() < 1e-3);
1359    }
1360
1361    #[test]
1362    fn test_scalar_addition() {
1363        let tensor = Tensor::ones(vec![2, 2]);
1364        let result = tensor.add_scalar_optimized(5.0);
1365
1366        assert_eq!(result.shape().dims(), vec![2, 2]);
1367        assert_eq!(result.size(), 4);
1368
1369        // Check that all values are 6.0 (1.0 + 5.0)
1370        unsafe {
1371            for i in 0..result.size() {
1372                assert!((result.as_ptr().add(i).read() - 6.0).abs() < 1e-6);
1373            }
1374        }
1375    }
1376
1377    #[test]
1378    #[should_panic(expected = "Cannot broadcast tensor shapes")]
1379    fn test_mismatched_shapes() {
1380        let a = Tensor::ones(vec![2, 3]);
1381        let b = Tensor::ones(vec![3, 2]);
1382        a.add_tensor_optimized(&b);
1383    }
1384
1385    #[test]
1386    fn test_add_with_no_grad_guard() {
1387        use crate::gradtrack::{is_grad_enabled, NoGradTrack};
1388
1389        // Create tensors with requires_grad enabled
1390        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
1391        let b = Tensor::ones(vec![2, 2]).with_requires_grad();
1392
1393        // Verify gradients are enabled by default
1394        assert!(is_grad_enabled());
1395
1396        // Normal addition with gradients
1397        let c1 = a.add_tensor(&b);
1398        assert!(
1399            c1.requires_grad(),
1400            "Result should require gradients normally"
1401        );
1402
1403        // Addition inside a NoGradTrack guard: gradients should be disabled
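        // NoGradTrack disables gradient tracking for the guard's lifetime; the previous state is
        // restored when the guard is dropped, as checked after this block.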
1404        {
1405            let _guard = NoGradTrack::new();
1406            assert!(
1407                !is_grad_enabled(),
1408                "Gradients should be disabled within guard"
1409            );
1410
1411            let c2 = a.add_tensor(&b);
1412            assert!(
1413                !c2.requires_grad(),
1414                "Result should not require gradients within NoGradTrack"
1415            );
1416
1417            // Test scalar addition as well
1418            let c3 = a.add_scalar(5.0);
1419            assert!(
1420                !c3.requires_grad(),
1421                "Scalar addition result should not require gradients within NoGradTrack"
1422            );
1423        }
1424
1425        // Gradients should be restored once the guard goes out of scope
1426        assert!(
1427            is_grad_enabled(),
1428            "Gradients should be restored after guard"
1429        );
1430
1431        let c4 = a.add_tensor(&b);
1432        assert!(
1433            c4.requires_grad(),
1434            "Result should require gradients after guard is dropped"
1435        );
1436    }
1437
1438    #[test]
1439    fn test_add_nested_no_grad_guards() {
1440        use crate::gradtrack::{is_grad_enabled, NoGradTrack};
1441
1442        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
1443        let b = Tensor::ones(vec![2, 2]).with_requires_grad();
1444
1445        assert!(is_grad_enabled());
1446
1447        {
1448            let _guard1 = NoGradTrack::new();
1449            assert!(!is_grad_enabled());
1450
1451            let c1 = a.add_tensor(&b);
1452            assert!(!c1.requires_grad());
1453
1454            {
1455                let _guard2 = NoGradTrack::new();
1456                assert!(!is_grad_enabled());
1457
1458                let c2 = a.add_tensor(&b);
1459                assert!(!c2.requires_grad());
1460            }
1461
1462            // Still disabled after inner guard drops
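            // The inner guard's drop restores the state captured at its creation, which was already
            // "disabled" under the outer guard, so tracking stays off until the outer guard drops.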
1463            assert!(!is_grad_enabled());
1464            let c3 = a.add_tensor(&b);
1465            assert!(!c3.requires_grad());
1466        }
1467
1468        // Restored after all guards drop
1469        assert!(is_grad_enabled());
1470        let c4 = a.add_tensor(&b);
1471        assert!(c4.requires_grad());
1472    }
1473
1474    #[test]
1475    fn test_add_with_mixed_requires_grad() {
1476        use crate::gradtrack::NoGradTrack;
1477
1478        let a = Tensor::ones(vec![2, 2]).with_requires_grad(); // requires_grad = true
1479        let b = Tensor::ones(vec![2, 2]); // requires_grad = false
1480
1481        // Without NoGradTrack, result should require gradients if any input does
1482        let c1 = a.add_tensor(&b);
1483        assert!(c1.requires_grad());
1484
1485        let c2 = b.add_tensor(&a);
1486        assert!(c2.requires_grad());
1487
1488        // With NoGradTrack, result should not require gradients regardless
1489        {
1490            let _guard = NoGradTrack::new();
1491
1492            let c3 = a.add_tensor(&b);
1493            assert!(!c3.requires_grad());
1494
1495            let c4 = b.add_tensor(&a);
1496            assert!(!c4.requires_grad());
1497        }
1498    }
1499
1500    #[test]
1501    fn test_broadcasting_gradients_basic() {
1502        use crate::gradtrack::clear_gradients;
1503        clear_gradients();
1504
1505        // Test case: [2, 3] + [1, 3] -> [2, 3]
1506        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)
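        // With the default upstream gradient of ones([2, 3]): grad_a = ones([2, 3]), and grad_b[j]
        // receives one contribution per broadcast row, 1.0 + 1.0 = 2.0, giving [2.0, 2.0, 2.0].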
1507
1508        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1509            .unwrap()
1510            .with_requires_grad();
1511        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
1512            .unwrap()
1513            .with_requires_grad();
1514
1515        let mut result = a.add_tensor(&b);
1516        assert_eq!(result.shape().dims(), vec![2, 3]);
1517
1518        // backward(None) seeds the upstream gradient with ones
1519        result.backward(None);
1520
1521        // Check gradients
1522        let grad_a = a.grad_owned().expect("grad_a should exist");
1523        let grad_b = b.grad_owned().expect("grad_b should exist");
1524
1525        println!(
1526            "Original shapes: a={:?}, b={:?}",
1527            a.shape().dims(),
1528            b.shape().dims()
1529        );
1530        println!(
1531            "Gradient shapes: grad_a={:?}, grad_b={:?}",
1532            grad_a.shape().dims(),
1533            grad_b.shape().dims()
1534        );
1535
1536        // grad_a should have same shape as a: [2, 3]
1537        assert_eq!(
1538            grad_a.shape().dims(),
1539            vec![2, 3],
1540            "grad_a should match original shape of a"
1541        );
1542
1543        // grad_b should have same shape as b: [1, 3]
1544        // This requires summing over the broadcasted dimension
1545        assert_eq!(
1546            grad_b.shape().dims(),
1547            vec![1, 3],
1548            "grad_b should match original shape of b"
1549        );
1550
1551        // All gradients should be 1.0 for grad_a
1552        for i in 0..grad_a.size() {
1553            let val = unsafe { *grad_a.as_ptr().add(i) };
1554            assert!(
1555                (val - 1.0).abs() < 1e-6,
1556                "grad_a[{}] = {} should be 1.0",
1557                i,
1558                val
1559            );
1560        }
1561
1562        // grad_b should be [2.0, 2.0, 2.0] (sum over broadcast dim)
1563        let expected_grad_b = [2.0, 2.0, 2.0];
1564        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
1565            let actual = unsafe { *grad_b.as_ptr().add(i) };
1566            assert!(
1567                (actual - val).abs() < 1e-6,
1568                "grad_b[{}] = {} should be {}",
1569                i,
1570                actual,
1571                val
1572            );
1573        }
1574    }
1575
1576    #[test]
1577    fn test_scalar_broadcasting_gradients() {
1578        use crate::gradtrack::clear_gradients;
1579        clear_gradients();
1580
1581        // Test case: [2, 3] + [1] -> [2, 3]
1582        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims)
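        // With the default upstream gradient of ones([2, 3]): grad_a = ones([2, 3]), and grad_b
        // collapses to a single value, the sum of all 2 * 3 = 6 upstream ones = 6.0.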
1583
1584        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1585            .unwrap()
1586            .with_requires_grad();
1587        let b = Tensor::from_slice(&[0.5], vec![1])
1588            .unwrap()
1589            .with_requires_grad();
1590
1591        let mut result = a.add_tensor(&b);
1592        result.backward(None);
1593
1594        let grad_a = a.grad_owned().expect("grad_a should exist");
1595        let grad_b = b.grad_owned().expect("grad_b should exist");
1596
1597        // grad_a should have same shape as a: [2, 3]
1598        assert_eq!(grad_a.shape().dims(), vec![2, 3]);
1599
1600        // grad_b should have same shape as b: [1] and sum to 6.0
1601        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims());
1602        assert_eq!(grad_b.shape().dims(), vec![1]);
1603
1604        // grad_b should be 6.0 (sum over all 6 elements)
1605        let val = unsafe { *grad_b.as_ptr() };
1606        assert!((val - 6.0).abs() < 1e-6, "grad_b = {} should be 6.0", val);
1607    }
1608
1609    #[test]
1610    fn test_linear_layer_bias_broadcasting() {
1611        use crate::gradtrack::clear_gradients;
1612        clear_gradients();
1613
1614        // Simulate linear layer bias broadcasting
1615        // input: [2, 3], weight: [3, 4], bias: [4]
1616        // matmul result: [2, 4], bias broadcast: [4] -> [2, 4]
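        // With loss = sum(input @ weight + bias), d(loss)/d(bias[j]) accumulates 1.0 per batch row,
        // so each of the 4 bias gradient entries should be 2.0.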
1617
1618        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1619            .unwrap()
1620            .with_requires_grad();
1621        let weight = Tensor::from_slice(
1622            &(1..=12).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
1623            vec![3, 4],
1624        )
1625        .unwrap()
1626        .with_requires_grad();
1627        let bias = Tensor::from_slice(&[0.1, 0.2, 0.3, 0.4], vec![4])
1628            .unwrap()
1629            .with_requires_grad();
1630
1631        // Forward pass: input @ weight + bias
1632        let matmul_result = input.matmul(&weight);
1633        println!("Matmul result shape: {:?}", matmul_result.shape().dims());
1634        println!("Bias shape: {:?}", bias.shape().dims());
1635
1636        let linear_output = matmul_result.add_tensor(&bias);
1637        println!("Linear output shape: {:?}", linear_output.shape().dims());
1638
1639        // Sum all outputs as loss
1640        let mut loss = linear_output.sum();
1641        loss.backward(None);
1642
1643        // Check bias gradient
1644        let bias_grad = bias.grad_owned().expect("bias gradient should exist");
1645        println!("Bias gradient shape: {:?}", bias_grad.shape().dims());
1646        assert_eq!(
1647            bias_grad.shape().dims(),
1648            vec![4],
1649            "bias gradient should match bias shape"
1650        );
1651
1652        // Bias gradient should be [2.0, 2.0, 2.0, 2.0] (sum over batch dimension)
1653        for i in 0..4 {
1654            let val = unsafe { *bias_grad.as_ptr().add(i) };
1655            assert!(
1656                (val - 2.0).abs() < 1e-6,
1657                "bias_grad[{}] = {} should be 2.0",
1658                i,
1659                val
1660            );
1661        }
1662
1663        println!("Linear layer bias broadcasting test passed!");
1664    }
1665}