train_station/tensor/ops/add.rs

//! Addition operations for tensors
//!
//! Provides element-wise addition following PyTorch conventions with comprehensive
//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
//!
//! # Key Features
//!
//! - **Element-wise Addition**: `add_tensor()` - Addition with another tensor (PyTorch `add()` equivalent)
//! - **Scalar Broadcasting**: `add_scalar()` - Addition with scalar values
//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
//! - **Automatic Differentiation**: Full gradient tracking via the gradtrack engine
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Zero-copy Operations**: Efficient memory usage where possible
//!
//! # Broadcasting Support
//!
//! All addition operations support automatic broadcasting following NumPy rules (see the example below):
//! - Dimensions are aligned from the rightmost dimension
//! - Dimensions are compatible if they are equal, or one of them is 1
//! - Missing dimensions are treated as 1
//! - The result shape takes the larger extent of each aligned dimension pair
//!
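//! For example, a 1-D tensor broadcasts against a 2-D tensor by aligning the trailing
//! dimension and treating the missing leading dimension as 1:
//!
//! ```
//! use train_station::Tensor;
//!
//! // [3] broadcasts against [2, 3] -> result shape [2, 3]
//! let row = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! let matrix = Tensor::ones(vec![2, 3]);
//! let sum = matrix.add_tensor(&row);
//! assert_eq!(sum.shape().dims, vec![2, 3]);
//! assert_eq!(sum.get(&[0, 0]), 2.0);
//! assert_eq!(sum.get(&[1, 2]), 4.0);
//! ```
//!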
//! # Performance Characteristics
//!
//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// No manual prefetching here: modern CPUs prefetch effectively for linear access,
// which keeps the hot path simple.

impl Tensor {
    /// Element-wise addition with another tensor, with broadcasting support.
    ///
    /// Performs element-wise addition with automatic broadcasting: `output[i] = self[i] + other[i]`
    ///
    /// Broadcasting enables addition between tensors of different but compatible shapes.
    /// Compatible shapes follow NumPy broadcasting rules:
    /// - Dimensions are aligned from the rightmost dimension
    /// - Dimensions are compatible if they are equal, or one of them is 1
    /// - Missing dimensions are treated as 1
    ///
    /// # Arguments
    /// * `other` - Tensor to add. Shapes must be broadcast-compatible.
    ///
    /// # Returns
    /// A new tensor containing the element-wise sum, with the broadcast result shape
    ///
    /// # Examples
    ///
    /// ## Same Shape Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
    /// let c = a.add_tensor(&b);
    /// assert_eq!(c.shape().dims, vec![3]);
    /// assert_eq!(c.get(&[0]), 5.0);
    /// assert_eq!(c.get(&[1]), 7.0);
    /// assert_eq!(c.get(&[2]), 9.0);
    /// ```
    ///
    /// ## Broadcasting Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Broadcasting: [2, 1] + [1, 3] -> [2, 3]
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
    /// let c = a.add_tensor(&b);
    /// assert_eq!(c.shape().dims, vec![2, 3]);
    /// assert_eq!(c.get(&[0, 0]), 11.0);
    /// assert_eq!(c.get(&[0, 1]), 21.0);
    /// assert_eq!(c.get(&[1, 0]), 12.0);
    /// assert_eq!(c.get(&[1, 1]), 22.0);
    /// ```
    ///
    /// ## Scalar Broadcasting
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Single-element tensor broadcasting: [2, 3] + [1] -> [2, 3]
    /// let a = Tensor::ones(vec![2, 3]);
    /// let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
    /// let c = a.add_tensor(&b);
    /// assert_eq!(c.shape().dims, vec![2, 3]);
    /// assert_eq!(c.get(&[0, 0]), 6.0);
    /// assert_eq!(c.get(&[1, 2]), 6.0);
    /// ```
    ///
    /// # Panics
    /// Panics if tensor shapes are not broadcast-compatible
    #[inline]
    pub fn add_tensor(&self, other: &Tensor) -> Tensor {
        // Check if shapes are identical for fast path
        if self.shape().dims == other.shape().dims {
            return self.add_tensor_same_shape(other);
        }

        // Use broadcasting for different shapes
        let (broadcast_self, broadcast_other, _result_shape) =
            self.broadcast_with(other).unwrap_or_else(|e| {
                panic!(
                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
                    self.shape().dims,
                    other.shape().dims,
                    e
                );
            });

        // Perform element-wise addition on broadcasted tensors
        let mut result = broadcast_self.add_tensor_optimized(&broadcast_other);

        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            // Record the original input shapes so gradients can later be reduced
            // back from the broadcast result shape to each input's shape.
            let grad_fn = GradFn::Add {
                is_tensor_add: true,
                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
            };
            result.set_grad_fn(grad_fn.clone());

            let mut input_ids = Vec::with_capacity(2);
            if self.requires_grad() {
                input_ids.push(self.id());
            }
            if other.requires_grad() {
                input_ids.push(other.id());
            }
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

    /// Element-wise addition for tensors with identical shapes (fast path).
    ///
    /// This is an optimized path for tensors that already have the same shape,
    /// avoiding the overhead of broadcasting computation. Used internally by
    /// `add_tensor()` when shapes are identical.
    ///
    /// # Arguments
    /// * `other` - Tensor to add, must have the same shape as self
    ///
    /// # Returns
    /// A new tensor containing the element-wise sum
    ///
    /// # Performance Characteristics
    ///
    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
    /// - **SIMD Optimization**: Uses optimized tensor addition with SIMD acceleration
    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
    ///
    /// # Panics
    ///
    /// Panics if tensor shapes do not match
    #[inline]
    fn add_tensor_same_shape(&self, other: &Tensor) -> Tensor {
        assert_eq!(
            self.shape(),
            other.shape(),
            "Tensor shapes must match for same-shape addition"
        );
        let mut result = self.add_tensor_optimized(other);

        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: true,
                original_shapes: None, // Same shape case
            };
            result.set_grad_fn(grad_fn.clone());

            let mut input_ids = Vec::with_capacity(2);
            if self.requires_grad() {
                input_ids.push(self.id());
            }
            if other.requires_grad() {
                input_ids.push(other.id());
            }
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

    /// Broadcast addition with a scalar value.
    ///
    /// Adds the scalar to every element: `output[i] = self[i] + scalar`
    ///
    /// # Arguments
    /// * `scalar` - Value to add to each element
    ///
    /// # Returns
    /// A new tensor with the scalar added to each element
    ///
    /// # Examples
    ///
    /// ## Basic Scalar Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.add_scalar(10.0);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 11.0);
    /// assert_eq!(b.get(&[1]), 12.0);
    /// assert_eq!(b.get(&[2]), 13.0);
    /// ```
    ///
    /// ## Multi-dimensional Scalar Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::ones(vec![2, 3]);
    /// let b = a.add_scalar(5.0);
    /// assert_eq!(b.shape().dims, vec![2, 3]);
    /// assert_eq!(b.get(&[0, 0]), 6.0);
    /// assert_eq!(b.get(&[1, 2]), 6.0);
    /// ```
    #[inline]
    pub fn add_scalar(&self, scalar: f32) -> Tensor {
        let mut result = self.add_scalar_optimized(scalar);

        if self.requires_grad() && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: false,
                original_shapes: None, // Scalar case
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
        }

        result
    }

    /// Internal optimized tensor + tensor operation
    ///
    /// Performs element-wise addition between two tensors with the same shape,
    /// using SIMD acceleration when available. This is the core implementation
    /// used by `add_tensor()` after broadcasting has been applied.
    ///
    /// # Arguments
    ///
    /// * `other` - Tensor to add, must have the same shape as self
    ///
    /// # Returns
    ///
    /// A new tensor containing the element-wise sum
    ///
    /// # Safety
    ///
    /// Assumes both tensors have the same shape and valid memory layouts.
    /// Uses unsafe SIMD operations for performance optimization.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
    #[inline]
    pub(crate) fn add_tensor_optimized(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");

        let mut output = Tensor::new(self.shape().dims.clone());

        unsafe {
            let a = self.as_ptr();
            let b = other.as_ptr();
            let dst = output.as_mut_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.add_tensors_simd_avx2_optimized(a, b, dst);
                    return output;
                }
            }

            // Fallback to scalar operations with better cache usage
            self.add_tensors_scalar_optimized(a, b, dst);
        }

        output
    }

    /// SIMD-optimized tensor addition using AVX2 instructions
    ///
    /// Performs element-wise addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// All pointers must point to valid tensor data; unaligned pointers are
    /// handled via unaligned loads and stores.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
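    ///
    /// For example, with `size = 77` (an illustrative count): the unrolled loop runs
    /// twice and covers elements 0..64, one 8-wide AVX2 block covers 64..72, the
    /// 4-wide scalar block covers 72..76, and the final scalar loop handles the
    /// remaining element.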
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sum_vec1 = _mm256_add_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining elements in blocks of 8 then tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sum_vec = _mm256_add_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            offset += 4;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    /// Optimized scalar tensor addition fallback
    ///
    /// Performs element-wise addition using optimized scalar operations when
    /// SIMD is not available. Uses 4x unrolling for better instruction-level
    /// parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[inline]
    unsafe fn add_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance
        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    /// Internal optimized scalar + tensor operation
    ///
    /// Performs element-wise addition of a scalar to each element of the tensor,
    /// using SIMD acceleration when available. This is the core implementation
    /// used by `add_scalar()`.
    ///
    /// # Arguments
    ///
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Returns
    ///
    /// A new tensor with the scalar added to each element
    ///
    /// # Safety
    ///
    /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
    /// performance optimization.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
    #[inline]
    pub(crate) fn add_scalar_optimized(&self, scalar: f32) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.add_scalar_simd_avx2_optimized(src, dst, scalar);
                    return output;
                }
            }

            // Fallback to optimized scalar operations
            self.add_scalar_fallback_optimized(src, dst, scalar);
        }

        output
    }

    /// SIMD-optimized scalar addition using AVX2 instructions
    ///
    /// Performs element-wise scalar addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// All pointers must point to valid tensor data; unaligned pointers are
    /// handled via unaligned loads and stores.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for instruction throughput
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sum_vec1 = _mm256_add_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sum_vec = _mm256_add_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }

        // Handle final elements
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    /// Optimized scalar addition fallback
    ///
    /// Performs element-wise scalar addition using optimized scalar operations when
    /// SIMD is not available. Uses 4x unrolling for better instruction-level
    /// parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[inline]
    unsafe fn add_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for instruction-level parallelism
        for _ in 0..unroll_count {
            *dst.add(offset) = *src.add(offset) + scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) + scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) + scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) + scalar;
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_addition() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![2, 3]);
        let result = a.add_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 + 1.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_scalar_addition() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.add_scalar_optimized(5.0);

        assert_eq!(result.shape().dims, vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 6.0 (1.0 + 5.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 6.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.add_tensor_optimized(&b);
    }

    #[test]
    fn test_add_with_no_grad_guard() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        // Create tensors with requires_grad enabled
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Verify gradients are enabled by default
        assert!(is_grad_enabled());

        // Normal addition with gradients
        let c1 = a.add_tensor(&b);
        assert!(
            c1.requires_grad(),
            "Result should require gradients normally"
        );

        // Addition with NoGradTrack - gradients should be disabled
        {
            let _guard = NoGradTrack::new();
            assert!(
                !is_grad_enabled(),
                "Gradients should be disabled within guard"
            );

            let c2 = a.add_tensor(&b);
            assert!(
                !c2.requires_grad(),
                "Result should not require gradients within NoGradTrack"
            );

            // Test scalar addition as well
            let c3 = a.add_scalar(5.0);
            assert!(
                !c3.requires_grad(),
                "Scalar addition result should not require gradients within NoGradTrack"
            );
        }

        // Gradients should be restored after guard goes out of scope
        assert!(
            is_grad_enabled(),
            "Gradients should be restored after guard"
        );

        let c4 = a.add_tensor(&b);
        assert!(
            c4.requires_grad(),
            "Result should require gradients after guard is dropped"
        );
    }

    #[test]
    fn test_add_nested_no_grad_guards() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            let c1 = a.add_tensor(&b);
            assert!(!c1.requires_grad());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                let c2 = a.add_tensor(&b);
                assert!(!c2.requires_grad());
            }

            // Still disabled after inner guard drops
            assert!(!is_grad_enabled());
            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());
        }

        // Restored after all guards drop
        assert!(is_grad_enabled());
        let c4 = a.add_tensor(&b);
        assert!(c4.requires_grad());
    }

    #[test]
    fn test_add_with_mixed_requires_grad() {
        use crate::gradtrack::NoGradTrack;

        let a = Tensor::ones(vec![2, 2]).with_requires_grad(); // requires_grad = true
        let b = Tensor::ones(vec![2, 2]); // requires_grad = false

        // Without NoGradTrack, result should require gradients if any input does
        let c1 = a.add_tensor(&b);
        assert!(c1.requires_grad());

        let c2 = b.add_tensor(&a);
        assert!(c2.requires_grad());

        // With NoGradTrack, result should not require gradients regardless
        {
            let _guard = NoGradTrack::new();

            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());

            let c4 = b.add_tensor(&a);
            assert!(!c4.requires_grad());
        }
    }

    #[test]
    fn test_add_performance_no_overhead() {
        use crate::gradtrack::NoGradTrack;
        use std::time::Instant;

        let size = 1000; // Smaller size for test stability
        let a = Tensor::ones(vec![size]).with_requires_grad();
        let b = Tensor::ones(vec![size]);

        // Time normal addition (with potential grad overhead)
        let start = Instant::now();
        for _ in 0..10 {
            let _ = a.add_tensor(&b);
        }
        let normal_duration = start.elapsed();

        // Time addition with NoGradTrack (should be faster)
        let start = Instant::now();
        {
            let _guard = NoGradTrack::new();
            for _ in 0..10 {
                let _ = a.add_tensor(&b);
            }
        }
        let no_grad_duration = start.elapsed();

        // NoGradTrack should provide performance benefit by skipping gradtrack setup
        // Allow generous variance for timing inconsistencies in tests
        println!(
            "Normal: {:?}, NoGrad: {:?}",
            normal_duration, no_grad_duration
        );

        // The key verification is that NoGradTrack doesn't add overhead
        assert!(
            no_grad_duration <= normal_duration * 3,
            "NoGradTrack should not add significant overhead"
        );
    }

    #[test]
    fn test_broadcasting_gradients_basic() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1, 3] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims,
            b.shape().dims
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims,
            grad_b.shape().dims
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims,
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        // This requires summing over the broadcasted dimension
        assert_eq!(
            grad_b.shape().dims,
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // All gradients should be 1.0 for grad_a
        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (val - 1.0).abs() < 1e-6,
                "grad_a[{}] = {} should be 1.0",
                i,
                val
            );
        }

        // grad_b should be [2.0, 2.0, 2.0] (sum over broadcast dim)
        let expected_grad_b = [2.0, 2.0, 2.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }

    #[test]
    fn test_scalar_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.5], vec![1])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        result.backward(None);

        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(grad_a.shape().dims, vec![2, 3]);

        // grad_b should have same shape as b: [1] and sum to 6.0
        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
        assert_eq!(grad_b.shape().dims, vec![1]);

        // grad_b should be 6.0 (sum over all 6 elements)
        let val = unsafe { *grad_b.as_ptr() };
        assert!((val - 6.0).abs() < 1e-6, "grad_b = {} should be 6.0", val);
    }

    #[test]
    fn test_linear_layer_bias_broadcasting() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Simulate linear layer bias broadcasting
        // input: [2, 3], weight: [3, 4], bias: [4]
        // matmul result: [2, 4], bias broadcast: [4] -> [2, 4]

        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let weight = Tensor::from_slice(
            &(1..=12).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
            vec![3, 4],
        )
        .unwrap()
        .with_requires_grad();
        let bias = Tensor::from_slice(&[0.1, 0.2, 0.3, 0.4], vec![4])
            .unwrap()
            .with_requires_grad();

        // Forward pass: input @ weight + bias
        let matmul_result = input.matmul(&weight);
        println!("Matmul result shape: {:?}", matmul_result.shape().dims);
        println!("Bias shape: {:?}", bias.shape().dims);

        let linear_output = matmul_result.add_tensor(&bias);
        println!("Linear output shape: {:?}", linear_output.shape().dims);

        // Sum all outputs as loss
        let mut loss = linear_output.sum();
        loss.backward(None);

        // Check bias gradient
        let bias_grad = bias.grad_by_value().expect("bias gradient should exist");
        println!("Bias gradient shape: {:?}", bias_grad.shape().dims);
        assert_eq!(
            bias_grad.shape().dims,
            vec![4],
            "bias gradient should match bias shape"
        );

        // Bias gradient should be [2.0, 2.0, 2.0, 2.0] (sum over batch dimension)
        for i in 0..4 {
            let val = unsafe { *bias_grad.as_ptr().add(i) };
            assert!(
                (val - 2.0).abs() < 1e-6,
                "bias_grad[{}] = {} should be 2.0",
                i,
                val
            );
        }

        println!("Linear layer bias broadcasting test passed!");
    }
}