train_station/tensor/ops/add.rs

1//! Addition operations for tensors
2//!
3//! Provides element-wise addition following PyTorch conventions with comprehensive
4//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Addition**: `add_tensor()` - Addition with another tensor (PyTorch `add()` equivalent)
9//! - **Scalar Broadcasting**: `add_scalar()` - Addition with scalar values
10//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
11//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
12//! - **Automatic Differentiation**: Full gradtrack support with gradient tracking
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Zero-copy Operations**: Efficient memory usage where possible
15//!
16//! # Broadcasting Support
17//!
18//! All addition operations support automatic broadcasting following NumPy rules:
19//! - Dimensions are aligned from the rightmost dimension
20//! - Dimensions are compatible if they are equal, or one of them is 1
21//! - Missing dimensions are treated as 1
22//! - Result shape follows broadcasting rules
23//!
24//! # Performance Characteristics
25//!
26//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
27//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
28//! - **Cache-friendly Access**: Linear memory access patterns
29//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
30//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// SIMD optimizations for performance-critical operations
36#[cfg(target_arch = "x86_64")]
37use std::arch::x86_64::*;
38
39// (Removed manual prefetching: simplifies hot path; modern CPUs prefetch effectively for linear access)
40
41impl Tensor {
42    /// Element-wise addition with another tensor with broadcasting support.
43    ///
44    /// Performs element-wise addition with automatic broadcasting: `output[i] = self[i] + other[i]`
45    ///
46    /// Broadcasting enables addition between tensors of different but compatible shapes.
47    /// Compatible shapes follow NumPy broadcasting rules:
48    /// - Dimensions are aligned from the rightmost dimension
49    /// - Dimensions are compatible if they are equal, or one of them is 1
50    /// - Missing dimensions are treated as 1
51    ///
52    /// # Arguments
53    /// * `other` - Tensor to add. Shapes must be broadcast-compatible.
54    ///
55    /// # Returns
56    /// A new tensor containing the element-wise sum with broadcast result shape
57    ///
58    /// # Examples
59    ///
60    /// ## Same Shape Addition
61    ///
62    /// ```
63    /// use train_station::Tensor;
64    ///
65    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
66    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
67    /// let c = a.add_tensor(&b);
68    /// assert_eq!(c.shape().dims, vec![3]);
69    /// assert_eq!(c.get(&[0]), 5.0);
70    /// assert_eq!(c.get(&[1]), 7.0);
71    /// assert_eq!(c.get(&[2]), 9.0);
72    /// ```
73    ///
74    /// ## Broadcasting Addition
75    ///
76    /// ```
77    /// use train_station::Tensor;
78    ///
79    /// // Broadcasting: [2, 1] + [1, 3] -> [2, 3]
80    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
81    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
82    /// let c = a.add_tensor(&b);
83    /// assert_eq!(c.shape().dims, vec![2, 3]);
84    /// assert_eq!(c.get(&[0, 0]), 11.0);
85    /// assert_eq!(c.get(&[0, 1]), 21.0);
86    /// assert_eq!(c.get(&[1, 0]), 12.0);
87    /// assert_eq!(c.get(&[1, 1]), 22.0);
88    /// ```
89    ///
90    /// ## Scalar Broadcasting
91    ///
92    /// ```
93    /// use train_station::Tensor;
94    ///
95    /// // Scalar broadcasting: [2, 3] + scalar -> [2, 3]
96    /// let a = Tensor::ones(vec![2, 3]);
97    /// let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
98    /// let c = a.add_tensor(&b);
99    /// assert_eq!(c.shape().dims, vec![2, 3]);
100    /// assert_eq!(c.get(&[0, 0]), 6.0);
101    /// assert_eq!(c.get(&[1, 2]), 6.0);
102    /// ```
103    ///
104    /// # Panics
105    /// Panics if tensor shapes are not broadcast-compatible
106    #[inline]
107    #[track_caller]
108    pub fn add_tensor(&self, other: &Tensor) -> Tensor {
109        // Check if shapes are identical for fast path
110        if self.shape().dims == other.shape().dims {
111            return self.add_tensor_same_shape(other);
112        }
113
114        // Use broadcasting for different shapes
115        let (broadcast_self, broadcast_other, _result_shape) =
116            self.broadcast_with(other).unwrap_or_else(|e| {
117                panic!(
118                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
119                    self.shape().dims,
120                    other.shape().dims,
121                    e
122                );
123            });
124
125        // Perform element-wise addition on broadcasted tensors
126        let mut result = broadcast_self.add_tensor_optimized(&broadcast_other);
127
128        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
129            result.set_requires_grad_internal(true);
130            let grad_fn = GradFn::Add {
131                is_tensor_add: true,
132                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
133            };
134            result.set_grad_fn(grad_fn.clone());
135
136            let mut input_ids = Vec::with_capacity(2);
137            if self.requires_grad() {
138                input_ids.push(self.id());
139            }
140            if other.requires_grad() {
141                input_ids.push(other.id());
142            }
143            GradEngine::register_operation(result.id(), input_ids, grad_fn);
144        }
145
146        result
147    }
148
149    /// Element-wise addition for tensors with identical shapes (fast path).
150    ///
151    /// This is an optimized path for tensors that already have the same shape,
152    /// avoiding the overhead of broadcasting computation. Used internally by
153    /// `add_tensor()` when shapes are identical.
154    ///
155    /// # Arguments
156    /// * `other` - Tensor to add, must have the same shape as self
157    ///
158    /// # Returns
159    /// A new tensor containing the element-wise sum
160    ///
161    /// # Performance Characteristics
162    ///
163    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
164    /// - **SIMD Optimization**: Uses optimized tensor addition with SIMD acceleration
165    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
166    ///
167    /// # Panics
168    ///
169    /// Panics if tensor shapes do not match
170    #[inline]
171    fn add_tensor_same_shape(&self, other: &Tensor) -> Tensor {
172        assert_eq!(
173            self.shape(),
174            other.shape(),
175            "Tensor shapes must match for same-shape addition"
176        );
177        let mut result = self.add_tensor_optimized(other);
178
179        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
180            result.set_requires_grad_internal(true);
181            let grad_fn = GradFn::Add {
182                is_tensor_add: true,
183                original_shapes: None, // Same shape case
184            };
185            result.set_grad_fn(grad_fn.clone());
186
187            let mut input_ids = Vec::with_capacity(2);
188            if self.requires_grad() {
189                input_ids.push(self.id());
190            }
191            if other.requires_grad() {
192                input_ids.push(other.id());
193            }
194            GradEngine::register_operation(result.id(), input_ids, grad_fn);
195        }
196
197        result
198    }
199
200    /// Broadcast addition with a scalar value.
201    ///
202    /// Adds the scalar to every element: `output[i] = self[i] + scalar`
203    ///
204    /// # Arguments
205    /// * `scalar` - Value to add to each element
206    ///
207    /// # Returns
208    /// A new tensor with the scalar added to each element
209    ///
210    /// # Examples
211    ///
212    /// ## Basic Scalar Addition
213    ///
214    /// ```
215    /// use train_station::Tensor;
216    ///
217    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
218    /// let b = a.add_scalar(10.0);
219    /// assert_eq!(b.shape().dims, vec![3]);
220    /// assert_eq!(b.get(&[0]), 11.0);
221    /// assert_eq!(b.get(&[1]), 12.0);
222    /// assert_eq!(b.get(&[2]), 13.0);
223    /// ```
224    ///
225    /// ## Multi-dimensional Scalar Addition
226    ///
227    /// ```
228    /// use train_station::Tensor;
229    ///
230    /// let a = Tensor::ones(vec![2, 3]);
231    /// let b = a.add_scalar(5.0);
232    /// assert_eq!(b.shape().dims, vec![2, 3]);
233    /// assert_eq!(b.get(&[0, 0]), 6.0);
234    /// assert_eq!(b.get(&[1, 2]), 6.0);
235    /// ```
236    #[inline]
237    #[track_caller]
238    pub fn add_scalar(&self, scalar: f32) -> Tensor {
239        let mut result = self.add_scalar_optimized(scalar);
240
241        if self.requires_grad() && is_grad_enabled() {
242            result.set_requires_grad_internal(true);
243            let grad_fn = GradFn::Add {
244                is_tensor_add: false,
245                original_shapes: None, // Scalar case
246            };
247            result.set_grad_fn(grad_fn.clone());
248            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
249        }
250
251        result
252    }
253    /// Internal optimized tensor + tensor operation
254    ///
255    /// Performs element-wise addition between two tensors with the same shape,
256    /// using SIMD acceleration when available. This is the core implementation
257    /// used by `add_tensor()` after broadcasting has been applied.
258    ///
259    /// # Arguments
260    ///
261    /// * `other` - Tensor to add, must have the same shape as self
262    ///
263    /// # Returns
264    ///
265    /// A new tensor containing the element-wise sum
266    ///
267    /// # Safety
268    ///
269    /// Assumes both tensors have the same shape and valid memory layouts.
270    /// Uses unsafe SIMD operations for performance optimization.
271    ///
272    /// # Performance Characteristics
273    ///
274    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
275    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
276    /// - **Cache-friendly**: Linear memory access patterns
277    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
278    #[inline]
279    pub(crate) fn add_tensor_optimized(&self, other: &Tensor) -> Tensor {
280        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
281
282        let mut output = Tensor::new(self.shape().dims.clone());
283
284        unsafe {
285            let a = self.as_ptr();
286            let b = other.as_ptr();
287            let dst = output.as_mut_ptr();
288
289            #[cfg(target_arch = "x86_64")]
290            {
291                // Use SIMD for better performance when available
292                if is_x86_feature_detected!("avx2") {
293                    self.add_tensors_simd_avx2_optimized(a, b, dst);
294                    return output;
295                }
296            }
297
298            // Fallback to scalar operations with better cache usage
299            self.add_tensors_scalar_optimized(a, b, dst);
300        }
301
302        output
303    }
304
    /// SIMD-optimized tensor addition using AVX2 instructions
    ///
    /// Performs element-wise addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput, then drains leftovers
    /// in 8-wide vectors, 4-wide scalar groups, and finally single elements.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available (checked via
    /// `is_x86_feature_detected!("avx2")` at the call site) and that all three
    /// pointers are valid for at least `self.size()` f32 elements. Unaligned
    /// loads/stores (`_mm256_loadu_ps`/`_mm256_storeu_ps`) are used throughout,
    /// so no alignment beyond that of `f32` is required.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sum_vec1 = _mm256_add_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining elements in blocks of 8 then tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sum_vec = _mm256_add_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }
        // Scalar cleanup: 4-at-a-time groups first (at most one iteration
        // since fewer than 8 elements remain), then single-element tail.
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            offset += 4;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }
382
383    /// Optimized scalar tensor addition fallback
384    ///
385    /// Performs element-wise addition using optimized scalar operations when
386    /// SIMD is not available. Uses 4x unrolling for better instruction-level
387    /// parallelism and cache efficiency.
388    ///
389    /// # Arguments
390    ///
391    /// * `a` - Pointer to first tensor data
392    /// * `b` - Pointer to second tensor data
393    /// * `dst` - Pointer to output tensor data
394    ///
395    /// # Safety
396    ///
397    /// Requires valid pointers with sufficient memory for the tensor size.
398    /// All pointers must point to valid tensor data.
399    ///
400    /// # Performance Characteristics
401    ///
402    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
403    /// - **Memory Access**: Linear access patterns for cache efficiency
404    /// - **Fallback**: Handles remaining elements with scalar operations
405    #[inline]
406    unsafe fn add_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
407        let size = self.size();
408        let unroll_count = size / 4;
409        let mut offset = 0;
410
411        // Unrolled scalar loop for better performance
412        for _ in 0..unroll_count {
413            *dst.add(offset) = *a.add(offset) + *b.add(offset);
414            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
415            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
416            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
417            offset += 4;
418        }
419
420        // Handle remaining elements
421        for i in offset..size {
422            *dst.add(i) = *a.add(i) + *b.add(i);
423        }
424    }
425
426    /// Internal optimized scalar + tensor operation
427    ///
428    /// Performs element-wise addition of a scalar to each element of the tensor,
429    /// using SIMD acceleration when available. This is the core implementation
430    /// used by `add_scalar()`.
431    ///
432    /// # Arguments
433    ///
434    /// * `scalar` - Scalar value to add to each element
435    ///
436    /// # Returns
437    ///
438    /// A new tensor with the scalar added to each element
439    ///
440    /// # Safety
441    ///
442    /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
443    /// performance optimization.
444    ///
445    /// # Performance Characteristics
446    ///
447    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
448    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
449    /// - **Cache-friendly**: Linear memory access patterns
450    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
451    #[inline]
452    pub(crate) fn add_scalar_optimized(&self, scalar: f32) -> Tensor {
453        let mut output = Tensor::new(self.shape().dims.clone());
454
455        unsafe {
456            let src = self.as_ptr();
457            let dst = output.as_mut_ptr();
458
459            #[cfg(target_arch = "x86_64")]
460            {
461                // Use SIMD for better performance when available
462                if is_x86_feature_detected!("avx2") {
463                    self.add_scalar_simd_avx2_optimized(src, dst, scalar);
464                    return output;
465                }
466            }
467
468            // Fallback to optimized scalar operations
469            self.add_scalar_fallback_optimized(src, dst, scalar);
470        }
471
472        output
473    }
474
    /// SIMD-optimized scalar addition using AVX2 instructions
    ///
    /// Performs element-wise scalar addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput, then drains leftovers
    /// in 8-wide vectors and finally single elements.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available (checked via
    /// `is_x86_feature_detected!("avx2")` at the call site) and that both
    /// pointers are valid for at least `self.size()` f32 elements. Unaligned
    /// loads/stores (`_mm256_loadu_ps`/`_mm256_storeu_ps`) are used throughout,
    /// so no alignment beyond that of `f32` is required.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        // Broadcast the scalar across all 8 lanes once, outside the loop.
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for instruction throughput
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sum_vec1 = _mm256_add_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sum_vec = _mm256_add_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }

        // Handle final elements (fewer than 8 remain at this point)
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }
542
543    /// Optimized scalar addition fallback
544    ///
545    /// Performs element-wise scalar addition using optimized scalar operations when
546    /// SIMD is not available. Uses 4x unrolling for better instruction-level
547    /// parallelism and cache efficiency.
548    ///
549    /// # Arguments
550    ///
551    /// * `src` - Pointer to source tensor data
552    /// * `dst` - Pointer to output tensor data
553    /// * `scalar` - Scalar value to add to each element
554    ///
555    /// # Safety
556    ///
557    /// Requires valid pointers with sufficient memory for the tensor size.
558    /// All pointers must point to valid tensor data.
559    ///
560    /// # Performance Characteristics
561    ///
562    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
563    /// - **Memory Access**: Linear access patterns for cache efficiency
564    /// - **Fallback**: Handles remaining elements with scalar operations
565    #[inline]
566    unsafe fn add_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
567        let size = self.size();
568        let unroll_count = size / 4;
569        let mut offset = 0;
570
571        // Unrolled scalar operations with while for clarity
572        for _ in 0..unroll_count {
573            *dst.add(offset) = *src.add(offset) + scalar;
574            *dst.add(offset + 1) = *src.add(offset + 1) + scalar;
575            *dst.add(offset + 2) = *src.add(offset + 2) + scalar;
576            *dst.add(offset + 3) = *src.add(offset + 3) + scalar;
577            offset += 4;
578        }
579        for i in offset..size {
580            *dst.add(i) = *src.add(i) + scalar;
581        }
582    }
583}
584
585#[cfg(test)]
586mod tests {
587    use super::*;
588
    #[test]
    fn test_tensor_addition() {
        // ones + ones should yield a same-shaped tensor filled with 2.0.
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![2, 3]);
        let result = a.add_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 + 1.0)
        // SAFETY: indices stay within the `result.size()` valid elements.
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }
605
    #[test]
    fn test_scalar_addition() {
        // ones + 5.0 should yield a same-shaped tensor filled with 6.0.
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.add_scalar_optimized(5.0);

        assert_eq!(result.shape().dims, vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 6.0 (1.0 + 5.0)
        // SAFETY: indices stay within the `result.size()` valid elements.
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 6.0).abs() < 1e-6);
            }
        }
    }
621
    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        // Same element count but different dims: add_tensor_optimized requires
        // strictly identical shapes and must panic.
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.add_tensor_optimized(&b);
    }
629
    #[test]
    fn test_add_with_no_grad_guard() {
        // Verifies that NoGradTrack suppresses gradient tracking for both
        // tensor and scalar addition, and that tracking resumes on drop.
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        // Create tensors with requires_grad enabled
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Verify gradients are enabled by default
        assert!(is_grad_enabled());

        // Normal addition with gradients
        let c1 = a.add_tensor(&b);
        assert!(
            c1.requires_grad(),
            "Result should require gradients normally"
        );

        // Addition with NoGradTrack - gradients should be disabled
        {
            let _guard = NoGradTrack::new();
            assert!(
                !is_grad_enabled(),
                "Gradients should be disabled within guard"
            );

            let c2 = a.add_tensor(&b);
            assert!(
                !c2.requires_grad(),
                "Result should not require gradients within NoGradTrack"
            );

            // Test scalar addition as well
            let c3 = a.add_scalar(5.0);
            assert!(
                !c3.requires_grad(),
                "Scalar addition result should not require gradients within NoGradTrack"
            );
        }

        // Gradients should be restored after guard goes out of scope
        assert!(
            is_grad_enabled(),
            "Gradients should be restored after guard"
        );

        let c4 = a.add_tensor(&b);
        assert!(
            c4.requires_grad(),
            "Result should require gradients after guard is dropped"
        );
    }
682
    #[test]
    fn test_add_nested_no_grad_guards() {
        // Verifies that nested NoGradTrack guards stack correctly: gradients
        // stay disabled until the outermost guard drops.
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            let c1 = a.add_tensor(&b);
            assert!(!c1.requires_grad());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                let c2 = a.add_tensor(&b);
                assert!(!c2.requires_grad());
            }

            // Still disabled after inner guard drops
            assert!(!is_grad_enabled());
            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());
        }

        // Restored after all guards drop
        assert!(is_grad_enabled());
        let c4 = a.add_tensor(&b);
        assert!(c4.requires_grad());
    }
718
    #[test]
    fn test_add_with_mixed_requires_grad() {
        // Verifies grad propagation when only one operand tracks gradients,
        // in both operand orders, with and without NoGradTrack.
        use crate::gradtrack::NoGradTrack;

        let a = Tensor::ones(vec![2, 2]).with_requires_grad(); // requires_grad = true
        let b = Tensor::ones(vec![2, 2]); // requires_grad = false

        // Without NoGradTrack, result should require gradients if any input does
        let c1 = a.add_tensor(&b);
        assert!(c1.requires_grad());

        let c2 = b.add_tensor(&a);
        assert!(c2.requires_grad());

        // With NoGradTrack, result should not require gradients regardless
        {
            let _guard = NoGradTrack::new();

            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());

            let c4 = b.add_tensor(&a);
            assert!(!c4.requires_grad());
        }
    }
744
    #[test]
    fn test_add_performance_no_overhead() {
        // NOTE(review): wall-clock comparisons are inherently noisy; the 3x
        // allowance below keeps this smoke test stable rather than precise.
        use crate::gradtrack::NoGradTrack;
        use std::time::Instant;

        let size = 1000; // Smaller size for test stability
        let a = Tensor::ones(vec![size]).with_requires_grad();
        let b = Tensor::ones(vec![size]);

        // Time normal addition (with potential grad overhead)
        let start = Instant::now();
        for _ in 0..10 {
            let _ = a.add_tensor(&b);
        }
        let normal_duration = start.elapsed();

        // Time addition with NoGradTrack (should be faster)
        let start = Instant::now();
        {
            let _guard = NoGradTrack::new();
            for _ in 0..10 {
                let _ = a.add_tensor(&b);
            }
        }
        let no_grad_duration = start.elapsed();

        // NoGradTrack should provide performance benefit by skipping gradtrack setup
        // Allow generous variance for timing inconsistencies in tests
        println!(
            "Normal: {:?}, NoGrad: {:?}",
            normal_duration, no_grad_duration
        );

        // The key verification is that NoGradTrack doesn't add overhead
        assert!(
            no_grad_duration <= normal_duration * 3,
            "NoGradTrack should not add significant overhead"
        );
    }
784
    #[test]
    fn test_broadcasting_gradients_basic() {
        // Verifies backward() through a broadcast add: the gradient for the
        // broadcast operand must be summed back down to its original shape.
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1, 3] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims,
            b.shape().dims
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims,
            grad_b.shape().dims
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims,
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        // This requires summing over the broadcasted dimension
        assert_eq!(
            grad_b.shape().dims,
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // All gradients should be 1.0 for grad_a
        // SAFETY: indices stay within the `grad_a.size()` valid elements.
        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (val - 1.0).abs() < 1e-6,
                "grad_a[{}] = {} should be 1.0",
                i,
                val
            );
        }

        // grad_b should be [2.0, 2.0, 2.0] (sum over broadcast dim)
        let expected_grad_b = [2.0, 2.0, 2.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }
860
861    #[test]
862    fn test_scalar_broadcasting_gradients() {
863        use crate::gradtrack::clear_gradients;
864        clear_gradients();
865
866        // Test case: [2, 3] + [1] -> [2, 3]
867        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims)
868
869        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
870            .unwrap()
871            .with_requires_grad();
872        let b = Tensor::from_slice(&[0.5], vec![1])
873            .unwrap()
874            .with_requires_grad();
875
876        let mut result = a.add_tensor(&b);
877        result.backward(None);
878
879        let grad_a = a.grad_by_value().expect("grad_a should exist");
880        let grad_b = b.grad_by_value().expect("grad_b should exist");
881
882        // grad_a should have same shape as a: [2, 3]
883        assert_eq!(grad_a.shape().dims, vec![2, 3]);
884
885        // grad_b should have same shape as b: [1] and sum to 6.0
886        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
887        assert_eq!(grad_b.shape().dims, vec![1]);
888
889        // grad_b should be 6.0 (sum over all 6 elements)
890        let val = unsafe { *grad_b.as_ptr() };
891        assert!((val - 6.0).abs() < 1e-6, "grad_b = {} should be 6.0", val);
892    }
893
894    #[test]
895    fn test_linear_layer_bias_broadcasting() {
896        use crate::gradtrack::clear_gradients;
897        clear_gradients();
898
899        // Simulate linear layer bias broadcasting
900        // input: [2, 3], weight: [3, 4], bias: [4]
901        // matmul result: [2, 4], bias broadcast: [4] -> [2, 4]
902
903        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
904            .unwrap()
905            .with_requires_grad();
906        let weight = Tensor::from_slice(
907            &(1..=12).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
908            vec![3, 4],
909        )
910        .unwrap()
911        .with_requires_grad();
912        let bias = Tensor::from_slice(&[0.1, 0.2, 0.3, 0.4], vec![4])
913            .unwrap()
914            .with_requires_grad();
915
916        // Forward pass: input @ weight + bias
917        let matmul_result = input.matmul(&weight);
918        println!("Matmul result shape: {:?}", matmul_result.shape().dims);
919        println!("Bias shape: {:?}", bias.shape().dims);
920
921        let linear_output = matmul_result.add_tensor(&bias);
922        println!("Linear output shape: {:?}", linear_output.shape().dims);
923
924        // Sum all outputs as loss
925        let mut loss = linear_output.sum();
926        loss.backward(None);
927
928        // Check bias gradient
929        let bias_grad = bias.grad_by_value().expect("bias gradient should exist");
930        println!("Bias gradient shape: {:?}", bias_grad.shape().dims);
931        assert_eq!(
932            bias_grad.shape().dims,
933            vec![4],
934            "bias gradient should match bias shape"
935        );
936
937        // Bias gradient should be [2.0, 2.0, 2.0, 2.0] (sum over batch dimension)
938        for i in 0..4 {
939            let val = unsafe { *bias_grad.as_ptr().add(i) };
940            assert!(
941                (val - 2.0).abs() < 1e-6,
942                "bias_grad[{}] = {} should be 2.0",
943                i,
944                val
945            );
946        }
947
948        println!("Linear layer bias broadcasting test passed!");
949    }
950}