// train_station/tensor/ops/add.rs
1//! Addition operations for tensors
2//!
3//! Provides element-wise addition following PyTorch conventions with comprehensive
4//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Addition**: `add_tensor()` - Addition with another tensor (PyTorch `add()` equivalent)
9//! - **Scalar Broadcasting**: `add_scalar()` - Addition with scalar values
10//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
11//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
12//! - **Automatic Differentiation**: Full gradtrack support with gradient tracking
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Zero-copy Operations**: Efficient memory usage where possible
15//!
16//! # Broadcasting Support
17//!
18//! All addition operations support automatic broadcasting following NumPy rules:
19//! - Dimensions are aligned from the rightmost dimension
20//! - Dimensions are compatible if they are equal, or one of them is 1
21//! - Missing dimensions are treated as 1
22//! - Result shape follows broadcasting rules
23//!
24//! # Performance Characteristics
25//!
26//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
27//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
28//! - **Cache-friendly Access**: Linear memory access patterns
29//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
30//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
31
32use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
33use crate::tensor::core::Tensor;
34
35// SIMD optimizations for performance-critical operations
36#[cfg(target_arch = "x86_64")]
37use std::arch::x86_64::*;
38
39// (Removed manual prefetching: simplifies hot path; modern CPUs prefetch effectively for linear access)
40
41impl Tensor {
42 /// Element-wise addition with another tensor with broadcasting support.
43 ///
44 /// Performs element-wise addition with automatic broadcasting: `output[i] = self[i] + other[i]`
45 ///
46 /// Broadcasting enables addition between tensors of different but compatible shapes.
47 /// Compatible shapes follow NumPy broadcasting rules:
48 /// - Dimensions are aligned from the rightmost dimension
49 /// - Dimensions are compatible if they are equal, or one of them is 1
50 /// - Missing dimensions are treated as 1
51 ///
52 /// # Arguments
53 /// * `other` - Tensor to add. Shapes must be broadcast-compatible.
54 ///
55 /// # Returns
56 /// A new tensor containing the element-wise sum with broadcast result shape
57 ///
58 /// # Examples
59 ///
60 /// ## Same Shape Addition
61 ///
62 /// ```
63 /// use train_station::Tensor;
64 ///
65 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
66 /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
67 /// let c = a.add_tensor(&b);
68 /// assert_eq!(c.shape().dims, vec![3]);
69 /// assert_eq!(c.get(&[0]), 5.0);
70 /// assert_eq!(c.get(&[1]), 7.0);
71 /// assert_eq!(c.get(&[2]), 9.0);
72 /// ```
73 ///
74 /// ## Broadcasting Addition
75 ///
76 /// ```
77 /// use train_station::Tensor;
78 ///
79 /// // Broadcasting: [2, 1] + [1, 3] -> [2, 3]
80 /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
81 /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
82 /// let c = a.add_tensor(&b);
83 /// assert_eq!(c.shape().dims, vec![2, 3]);
84 /// assert_eq!(c.get(&[0, 0]), 11.0);
85 /// assert_eq!(c.get(&[0, 1]), 21.0);
86 /// assert_eq!(c.get(&[1, 0]), 12.0);
87 /// assert_eq!(c.get(&[1, 1]), 22.0);
88 /// ```
89 ///
90 /// ## Scalar Broadcasting
91 ///
92 /// ```
93 /// use train_station::Tensor;
94 ///
95 /// // Scalar broadcasting: [2, 3] + scalar -> [2, 3]
96 /// let a = Tensor::ones(vec![2, 3]);
97 /// let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
98 /// let c = a.add_tensor(&b);
99 /// assert_eq!(c.shape().dims, vec![2, 3]);
100 /// assert_eq!(c.get(&[0, 0]), 6.0);
101 /// assert_eq!(c.get(&[1, 2]), 6.0);
102 /// ```
103 ///
104 /// # Panics
105 /// Panics if tensor shapes are not broadcast-compatible
106 #[inline]
107 #[track_caller]
108 pub fn add_tensor(&self, other: &Tensor) -> Tensor {
109 // Check if shapes are identical for fast path
110 if self.shape().dims == other.shape().dims {
111 return self.add_tensor_same_shape(other);
112 }
113
114 // Use broadcasting for different shapes
115 let (broadcast_self, broadcast_other, _result_shape) =
116 self.broadcast_with(other).unwrap_or_else(|e| {
117 panic!(
118 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
119 self.shape().dims,
120 other.shape().dims,
121 e
122 );
123 });
124
125 // Perform element-wise addition on broadcasted tensors
126 let mut result = broadcast_self.add_tensor_optimized(&broadcast_other);
127
128 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
129 result.set_requires_grad_internal(true);
130 let grad_fn = GradFn::Add {
131 is_tensor_add: true,
132 original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
133 };
134 result.set_grad_fn(grad_fn.clone());
135
136 let mut input_ids = Vec::with_capacity(2);
137 if self.requires_grad() {
138 input_ids.push(self.id());
139 }
140 if other.requires_grad() {
141 input_ids.push(other.id());
142 }
143 GradEngine::register_operation(result.id(), input_ids, grad_fn);
144 }
145
146 result
147 }
148
149 /// Element-wise addition for tensors with identical shapes (fast path).
150 ///
151 /// This is an optimized path for tensors that already have the same shape,
152 /// avoiding the overhead of broadcasting computation. Used internally by
153 /// `add_tensor()` when shapes are identical.
154 ///
155 /// # Arguments
156 /// * `other` - Tensor to add, must have the same shape as self
157 ///
158 /// # Returns
159 /// A new tensor containing the element-wise sum
160 ///
161 /// # Performance Characteristics
162 ///
163 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
164 /// - **SIMD Optimization**: Uses optimized tensor addition with SIMD acceleration
165 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
166 ///
167 /// # Panics
168 ///
169 /// Panics if tensor shapes do not match
170 #[inline]
171 fn add_tensor_same_shape(&self, other: &Tensor) -> Tensor {
172 assert_eq!(
173 self.shape(),
174 other.shape(),
175 "Tensor shapes must match for same-shape addition"
176 );
177 let mut result = self.add_tensor_optimized(other);
178
179 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
180 result.set_requires_grad_internal(true);
181 let grad_fn = GradFn::Add {
182 is_tensor_add: true,
183 original_shapes: None, // Same shape case
184 };
185 result.set_grad_fn(grad_fn.clone());
186
187 let mut input_ids = Vec::with_capacity(2);
188 if self.requires_grad() {
189 input_ids.push(self.id());
190 }
191 if other.requires_grad() {
192 input_ids.push(other.id());
193 }
194 GradEngine::register_operation(result.id(), input_ids, grad_fn);
195 }
196
197 result
198 }
199
200 /// Broadcast addition with a scalar value.
201 ///
202 /// Adds the scalar to every element: `output[i] = self[i] + scalar`
203 ///
204 /// # Arguments
205 /// * `scalar` - Value to add to each element
206 ///
207 /// # Returns
208 /// A new tensor with the scalar added to each element
209 ///
210 /// # Examples
211 ///
212 /// ## Basic Scalar Addition
213 ///
214 /// ```
215 /// use train_station::Tensor;
216 ///
217 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
218 /// let b = a.add_scalar(10.0);
219 /// assert_eq!(b.shape().dims, vec![3]);
220 /// assert_eq!(b.get(&[0]), 11.0);
221 /// assert_eq!(b.get(&[1]), 12.0);
222 /// assert_eq!(b.get(&[2]), 13.0);
223 /// ```
224 ///
225 /// ## Multi-dimensional Scalar Addition
226 ///
227 /// ```
228 /// use train_station::Tensor;
229 ///
230 /// let a = Tensor::ones(vec![2, 3]);
231 /// let b = a.add_scalar(5.0);
232 /// assert_eq!(b.shape().dims, vec![2, 3]);
233 /// assert_eq!(b.get(&[0, 0]), 6.0);
234 /// assert_eq!(b.get(&[1, 2]), 6.0);
235 /// ```
236 #[inline]
237 #[track_caller]
238 pub fn add_scalar(&self, scalar: f32) -> Tensor {
239 let mut result = self.add_scalar_optimized(scalar);
240
241 if self.requires_grad() && is_grad_enabled() {
242 result.set_requires_grad_internal(true);
243 let grad_fn = GradFn::Add {
244 is_tensor_add: false,
245 original_shapes: None, // Scalar case
246 };
247 result.set_grad_fn(grad_fn.clone());
248 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
249 }
250
251 result
252 }
253 /// Internal optimized tensor + tensor operation
254 ///
255 /// Performs element-wise addition between two tensors with the same shape,
256 /// using SIMD acceleration when available. This is the core implementation
257 /// used by `add_tensor()` after broadcasting has been applied.
258 ///
259 /// # Arguments
260 ///
261 /// * `other` - Tensor to add, must have the same shape as self
262 ///
263 /// # Returns
264 ///
265 /// A new tensor containing the element-wise sum
266 ///
267 /// # Safety
268 ///
269 /// Assumes both tensors have the same shape and valid memory layouts.
270 /// Uses unsafe SIMD operations for performance optimization.
271 ///
272 /// # Performance Characteristics
273 ///
274 /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
275 /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
276 /// - **Cache-friendly**: Linear memory access patterns
277 /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
278 #[inline]
279 pub(crate) fn add_tensor_optimized(&self, other: &Tensor) -> Tensor {
280 assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
281
282 let mut output = Tensor::new(self.shape().dims.clone());
283
284 unsafe {
285 let a = self.as_ptr();
286 let b = other.as_ptr();
287 let dst = output.as_mut_ptr();
288
289 #[cfg(target_arch = "x86_64")]
290 {
291 // Use SIMD for better performance when available
292 if is_x86_feature_detected!("avx2") {
293 self.add_tensors_simd_avx2_optimized(a, b, dst);
294 return output;
295 }
296 }
297
298 // Fallback to scalar operations with better cache usage
299 self.add_tensors_scalar_optimized(a, b, dst);
300 }
301
302 output
303 }
304
    /// SIMD-optimized tensor addition using AVX2 instructions
    ///
    /// Performs element-wise addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data (reads `self.size()` f32 values)
    /// * `b` - Pointer to second tensor data (reads `self.size()` f32 values)
    /// * `dst` - Pointer to output tensor data (writes `self.size()` f32 values)
    ///
    /// # Safety
    ///
    /// Requires AVX2 support (the caller performs the runtime feature check)
    /// and pointers valid for `self.size()` elements. Unaligned load/store
    /// intrinsics (`loadu`/`storeu`) are used throughout, so no particular
    /// alignment is required beyond pointer validity.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput: four independent
        // load/add/store chains per iteration keep the vector ports busy.
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sum_vec1 = _mm256_add_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining elements in blocks of 8 then tail
        // (fewer than 32 elements remain, so at most 3 full vectors here).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sum_vec = _mm256_add_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }
        // Scalar 4-wide block for a mid-sized tail (runs at most once, since
        // fewer than 8 elements remain at this point).
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            offset += 4;
        }
        // Final 0-3 leftover elements.
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }
382
383 /// Optimized scalar tensor addition fallback
384 ///
385 /// Performs element-wise addition using optimized scalar operations when
386 /// SIMD is not available. Uses 4x unrolling for better instruction-level
387 /// parallelism and cache efficiency.
388 ///
389 /// # Arguments
390 ///
391 /// * `a` - Pointer to first tensor data
392 /// * `b` - Pointer to second tensor data
393 /// * `dst` - Pointer to output tensor data
394 ///
395 /// # Safety
396 ///
397 /// Requires valid pointers with sufficient memory for the tensor size.
398 /// All pointers must point to valid tensor data.
399 ///
400 /// # Performance Characteristics
401 ///
402 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
403 /// - **Memory Access**: Linear access patterns for cache efficiency
404 /// - **Fallback**: Handles remaining elements with scalar operations
405 #[inline]
406 unsafe fn add_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
407 let size = self.size();
408 let unroll_count = size / 4;
409 let mut offset = 0;
410
411 // Unrolled scalar loop for better performance
412 for _ in 0..unroll_count {
413 *dst.add(offset) = *a.add(offset) + *b.add(offset);
414 *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
415 *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
416 *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
417 offset += 4;
418 }
419
420 // Handle remaining elements
421 for i in offset..size {
422 *dst.add(i) = *a.add(i) + *b.add(i);
423 }
424 }
425
426 /// Internal optimized scalar + tensor operation
427 ///
428 /// Performs element-wise addition of a scalar to each element of the tensor,
429 /// using SIMD acceleration when available. This is the core implementation
430 /// used by `add_scalar()`.
431 ///
432 /// # Arguments
433 ///
434 /// * `scalar` - Scalar value to add to each element
435 ///
436 /// # Returns
437 ///
438 /// A new tensor with the scalar added to each element
439 ///
440 /// # Safety
441 ///
442 /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
443 /// performance optimization.
444 ///
445 /// # Performance Characteristics
446 ///
447 /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
448 /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
449 /// - **Cache-friendly**: Linear memory access patterns
450 /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
451 #[inline]
452 pub(crate) fn add_scalar_optimized(&self, scalar: f32) -> Tensor {
453 let mut output = Tensor::new(self.shape().dims.clone());
454
455 unsafe {
456 let src = self.as_ptr();
457 let dst = output.as_mut_ptr();
458
459 #[cfg(target_arch = "x86_64")]
460 {
461 // Use SIMD for better performance when available
462 if is_x86_feature_detected!("avx2") {
463 self.add_scalar_simd_avx2_optimized(src, dst, scalar);
464 return output;
465 }
466 }
467
468 // Fallback to optimized scalar operations
469 self.add_scalar_fallback_optimized(src, dst, scalar);
470 }
471
472 output
473 }
474
    /// SIMD-optimized scalar addition using AVX2 instructions
    ///
    /// Performs element-wise scalar addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data (reads `self.size()` f32 values)
    /// * `dst` - Pointer to output tensor data (writes `self.size()` f32 values)
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Safety
    ///
    /// Requires AVX2 support (the caller performs the runtime feature check)
    /// and pointers valid for `self.size()` elements. Unaligned load/store
    /// intrinsics (`loadu`/`storeu`) are used throughout, so no particular
    /// alignment is required beyond pointer validity.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        // Broadcast the scalar into all 8 lanes once, outside the loop.
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for instruction throughput: four independent
        // load/add/store chains per iteration.
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sum_vec1 = _mm256_add_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        // (fewer than 32 elements remain, so at most 3 full vectors here).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sum_vec = _mm256_add_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }

        // Handle final elements (0-7 leftover after the vector blocks)
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }
542
543 /// Optimized scalar addition fallback
544 ///
545 /// Performs element-wise scalar addition using optimized scalar operations when
546 /// SIMD is not available. Uses 4x unrolling for better instruction-level
547 /// parallelism and cache efficiency.
548 ///
549 /// # Arguments
550 ///
551 /// * `src` - Pointer to source tensor data
552 /// * `dst` - Pointer to output tensor data
553 /// * `scalar` - Scalar value to add to each element
554 ///
555 /// # Safety
556 ///
557 /// Requires valid pointers with sufficient memory for the tensor size.
558 /// All pointers must point to valid tensor data.
559 ///
560 /// # Performance Characteristics
561 ///
562 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
563 /// - **Memory Access**: Linear access patterns for cache efficiency
564 /// - **Fallback**: Handles remaining elements with scalar operations
565 #[inline]
566 unsafe fn add_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
567 let size = self.size();
568 let unroll_count = size / 4;
569 let mut offset = 0;
570
571 // Unrolled scalar operations with while for clarity
572 for _ in 0..unroll_count {
573 *dst.add(offset) = *src.add(offset) + scalar;
574 *dst.add(offset + 1) = *src.add(offset + 1) + scalar;
575 *dst.add(offset + 2) = *src.add(offset + 2) + scalar;
576 *dst.add(offset + 3) = *src.add(offset + 3) + scalar;
577 offset += 4;
578 }
579 for i in offset..size {
580 *dst.add(i) = *src.add(i) + scalar;
581 }
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Same-shape addition through the optimized kernel: 1.0 + 1.0 == 2.0 everywhere.
    #[test]
    fn test_tensor_addition() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![2, 3]);
        let result = a.add_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 + 1.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

    /// Scalar addition through the optimized kernel adds the constant to every element.
    #[test]
    fn test_scalar_addition() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.add_scalar_optimized(5.0);

        assert_eq!(result.shape().dims, vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 6.0 (1.0 + 5.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 6.0).abs() < 1e-6);
            }
        }
    }

    /// The optimized kernel performs no broadcasting, so mismatched shapes must panic.
    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.add_tensor_optimized(&b);
    }

    /// NoGradTrack disables gradient tracking for additions inside its scope
    /// and tracking resumes once the guard drops.
    #[test]
    fn test_add_with_no_grad_guard() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        // Create tensors with requires_grad enabled
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Verify gradients are enabled by default
        assert!(is_grad_enabled());

        // Normal addition with gradients
        let c1 = a.add_tensor(&b);
        assert!(
            c1.requires_grad(),
            "Result should require gradients normally"
        );

        // Addition with NoGradTrack - gradients should be disabled
        {
            let _guard = NoGradTrack::new();
            assert!(
                !is_grad_enabled(),
                "Gradients should be disabled within guard"
            );

            let c2 = a.add_tensor(&b);
            assert!(
                !c2.requires_grad(),
                "Result should not require gradients within NoGradTrack"
            );

            // Test scalar addition as well
            let c3 = a.add_scalar(5.0);
            assert!(
                !c3.requires_grad(),
                "Scalar addition result should not require gradients within NoGradTrack"
            );
        }

        // Gradients should be restored after guard goes out of scope
        assert!(
            is_grad_enabled(),
            "Gradients should be restored after guard"
        );

        let c4 = a.add_tensor(&b);
        assert!(
            c4.requires_grad(),
            "Result should require gradients after guard is dropped"
        );
    }

    /// Nested NoGradTrack guards keep gradients disabled until the outermost
    /// guard drops.
    #[test]
    fn test_add_nested_no_grad_guards() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            let c1 = a.add_tensor(&b);
            assert!(!c1.requires_grad());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                let c2 = a.add_tensor(&b);
                assert!(!c2.requires_grad());
            }

            // Still disabled after inner guard drops
            assert!(!is_grad_enabled());
            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());
        }

        // Restored after all guards drop
        assert!(is_grad_enabled());
        let c4 = a.add_tensor(&b);
        assert!(c4.requires_grad());
    }

    /// The result requires gradients when either operand does, regardless of
    /// operand order, and never inside NoGradTrack.
    #[test]
    fn test_add_with_mixed_requires_grad() {
        use crate::gradtrack::NoGradTrack;

        let a = Tensor::ones(vec![2, 2]).with_requires_grad(); // requires_grad = true
        let b = Tensor::ones(vec![2, 2]); // requires_grad = false

        // Without NoGradTrack, result should require gradients if any input does
        let c1 = a.add_tensor(&b);
        assert!(c1.requires_grad());

        let c2 = b.add_tensor(&a);
        assert!(c2.requires_grad());

        // With NoGradTrack, result should not require gradients regardless
        {
            let _guard = NoGradTrack::new();

            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());

            let c4 = b.add_tensor(&a);
            assert!(!c4.requires_grad());
        }
    }

    /// Sanity check that NoGradTrack does not add overhead; uses a generous
    /// 3x bound to tolerate timing noise in CI environments.
    #[test]
    fn test_add_performance_no_overhead() {
        use crate::gradtrack::NoGradTrack;
        use std::time::Instant;

        let size = 1000; // Smaller size for test stability
        let a = Tensor::ones(vec![size]).with_requires_grad();
        let b = Tensor::ones(vec![size]);

        // Time normal addition (with potential grad overhead)
        let start = Instant::now();
        for _ in 0..10 {
            let _ = a.add_tensor(&b);
        }
        let normal_duration = start.elapsed();

        // Time addition with NoGradTrack (should be faster)
        let start = Instant::now();
        {
            let _guard = NoGradTrack::new();
            for _ in 0..10 {
                let _ = a.add_tensor(&b);
            }
        }
        let no_grad_duration = start.elapsed();

        // NoGradTrack should provide performance benefit by skipping gradtrack setup
        // Allow generous variance for timing inconsistencies in tests
        println!(
            "Normal: {:?}, NoGrad: {:?}",
            normal_duration, no_grad_duration
        );

        // The key verification is that NoGradTrack doesn't add overhead
        assert!(
            no_grad_duration <= normal_duration * 3,
            "NoGradTrack should not add significant overhead"
        );
    }

    /// Broadcasting backward pass: the smaller operand's gradient is summed
    /// over the broadcast dimension back down to its original shape.
    #[test]
    fn test_broadcasting_gradients_basic() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1, 3] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims,
            b.shape().dims
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims,
            grad_b.shape().dims
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims,
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        // This requires summing over the broadcasted dimension
        assert_eq!(
            grad_b.shape().dims,
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // All gradients should be 1.0 for grad_a
        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (val - 1.0).abs() < 1e-6,
                "grad_a[{}] = {} should be 1.0",
                i,
                val
            );
        }

        // grad_b should be [2.0, 2.0, 2.0] (sum over broadcast dim)
        let expected_grad_b = [2.0, 2.0, 2.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }

    /// Scalar-shaped ([1]) operand: its gradient is the sum over every
    /// broadcast element of the result.
    #[test]
    fn test_scalar_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.5], vec![1])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        result.backward(None);

        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(grad_a.shape().dims, vec![2, 3]);

        // grad_b should have same shape as b: [1] and sum to 6.0
        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
        assert_eq!(grad_b.shape().dims, vec![1]);

        // grad_b should be 6.0 (sum over all 6 elements)
        let val = unsafe { *grad_b.as_ptr() };
        assert!((val - 6.0).abs() < 1e-6, "grad_b = {} should be 6.0", val);
    }

    /// End-to-end linear-layer pattern: bias [4] broadcast over a [2, 4]
    /// matmul result; bias gradient reduces over the batch dimension.
    #[test]
    fn test_linear_layer_bias_broadcasting() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Simulate linear layer bias broadcasting
        // input: [2, 3], weight: [3, 4], bias: [4]
        // matmul result: [2, 4], bias broadcast: [4] -> [2, 4]

        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let weight = Tensor::from_slice(
            &(1..=12).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
            vec![3, 4],
        )
        .unwrap()
        .with_requires_grad();
        let bias = Tensor::from_slice(&[0.1, 0.2, 0.3, 0.4], vec![4])
            .unwrap()
            .with_requires_grad();

        // Forward pass: input @ weight + bias
        let matmul_result = input.matmul(&weight);
        println!("Matmul result shape: {:?}", matmul_result.shape().dims);
        println!("Bias shape: {:?}", bias.shape().dims);

        let linear_output = matmul_result.add_tensor(&bias);
        println!("Linear output shape: {:?}", linear_output.shape().dims);

        // Sum all outputs as loss
        let mut loss = linear_output.sum();
        loss.backward(None);

        // Check bias gradient
        let bias_grad = bias.grad_by_value().expect("bias gradient should exist");
        println!("Bias gradient shape: {:?}", bias_grad.shape().dims);
        assert_eq!(
            bias_grad.shape().dims,
            vec![4],
            "bias gradient should match bias shape"
        );

        // Bias gradient should be [2.0, 2.0, 2.0, 2.0] (sum over batch dimension)
        for i in 0..4 {
            let val = unsafe { *bias_grad.as_ptr().add(i) };
            assert!(
                (val - 2.0).abs() < 1e-6,
                "bias_grad[{}] = {} should be 2.0",
                i,
                val
            );
        }

        println!("Linear layer bias broadcasting test passed!");
    }
}