train_station/tensor/ops/add.rs
//! Addition operations for tensors
//!
//! Provides element-wise addition following PyTorch conventions with comprehensive
//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
//!
//! # Key Features
//!
//! - **Element-wise Addition**: `add_tensor()` - Addition with another tensor (PyTorch `add()` equivalent)
//! - **Scalar Broadcasting**: `add_scalar()` - Addition with scalar values
//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
//! - **Automatic Differentiation**: Full gradient tracking via the gradtrack engine
//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
//! - **Zero-copy Operations**: Efficient memory usage where possible
//!
//! # Broadcasting Support
//!
//! All addition operations support automatic broadcasting following NumPy rules:
//! - Dimensions are aligned from the rightmost dimension
//! - Dimensions are compatible if they are equal, or one of them is 1
//! - Missing dimensions are treated as 1
//! - Each result dimension is the larger of the two aligned dimensions
//!
//! # Performance Characteristics
//!
//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
//! - **Gradient Optimization**: Gradient tracking can be skipped entirely with `NoGradTrack`
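//!
//! # Quick Example
//!
//! A minimal end-to-end sketch of the two public entry points, `add_scalar` and
//! `add_tensor`; it mirrors the shapes and API used in the function-level examples below.
//!
//! ```
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! // (x + 1) + x, element-wise
//! let y = x.add_scalar(1.0).add_tensor(&x);
//! assert_eq!(y.shape().dims, vec![3]);
//! assert_eq!(y.get(&[2]), 7.0);
//! ```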

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

// SIMD optimizations for performance-critical operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// Note: no manual prefetching is used; modern CPUs prefetch effectively for the
// linear access patterns in these kernels, which keeps the hot path simple.

impl Tensor {
    /// Element-wise addition with another tensor, supporting broadcasting.
    ///
    /// Performs element-wise addition with automatic broadcasting: `output[i] = self[i] + other[i]`
    ///
    /// Broadcasting enables addition between tensors of different but compatible shapes.
    /// Compatible shapes follow NumPy broadcasting rules:
    /// - Dimensions are aligned from the rightmost dimension
    /// - Dimensions are compatible if they are equal, or one of them is 1
    /// - Missing dimensions are treated as 1
    ///
    /// # Arguments
    /// * `other` - Tensor to add. Shapes must be broadcast-compatible.
    ///
    /// # Returns
    /// A new tensor containing the element-wise sum, with the broadcast result shape
    ///
    /// # Examples
    ///
    /// ## Same Shape Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
    /// let c = a.add_tensor(&b);
    /// assert_eq!(c.shape().dims, vec![3]);
    /// assert_eq!(c.get(&[0]), 5.0);
    /// assert_eq!(c.get(&[1]), 7.0);
    /// assert_eq!(c.get(&[2]), 9.0);
    /// ```
    ///
    /// ## Broadcasting Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Broadcasting: [2, 1] + [1, 3] -> [2, 3]
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
    /// let c = a.add_tensor(&b);
    /// assert_eq!(c.shape().dims, vec![2, 3]);
    /// assert_eq!(c.get(&[0, 0]), 11.0);
    /// assert_eq!(c.get(&[0, 1]), 21.0);
    /// assert_eq!(c.get(&[1, 0]), 12.0);
    /// assert_eq!(c.get(&[1, 1]), 22.0);
    /// ```
    ///
    /// ## Scalar Broadcasting
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Scalar broadcasting: [2, 3] + scalar -> [2, 3]
    /// let a = Tensor::ones(vec![2, 3]);
    /// let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
    /// let c = a.add_tensor(&b);
    /// assert_eq!(c.shape().dims, vec![2, 3]);
    /// assert_eq!(c.get(&[0, 0]), 6.0);
    /// assert_eq!(c.get(&[1, 2]), 6.0);
    /// ```
    ///
    /// # Panics
    /// Panics if tensor shapes are not broadcast-compatible
    #[inline]
    pub fn add_tensor(&self, other: &Tensor) -> Tensor {
        // Check if shapes are identical for fast path
        if self.shape().dims == other.shape().dims {
            return self.add_tensor_same_shape(other);
        }

        // Use broadcasting for different shapes
        let (broadcast_self, broadcast_other, _result_shape) =
            self.broadcast_with(other).unwrap_or_else(|e| {
                panic!(
                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
                    self.shape().dims,
                    other.shape().dims,
                    e
                );
            });

        // Perform element-wise addition on broadcasted tensors
        let mut result = broadcast_self.add_tensor_optimized(&broadcast_other);

        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: true,
                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
            };
            result.set_grad_fn(grad_fn.clone());

            let mut input_ids = Vec::with_capacity(2);
            if self.requires_grad() {
                input_ids.push(self.id());
            }
            if other.requires_grad() {
                input_ids.push(other.id());
            }
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

    /// Element-wise addition for tensors with identical shapes (fast path).
    ///
    /// This is an optimized path for tensors that already have the same shape,
    /// avoiding the overhead of broadcasting computation. Used internally by
    /// `add_tensor()` when shapes are identical.
    ///
    /// # Arguments
    /// * `other` - Tensor to add, must have the same shape as self
    ///
    /// # Returns
    /// A new tensor containing the element-wise sum
    ///
    /// # Performance Characteristics
    ///
    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
    /// - **SIMD Optimization**: Uses optimized tensor addition with SIMD acceleration
    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
    ///
    /// # Panics
    ///
    /// Panics if tensor shapes do not match
    #[inline]
    fn add_tensor_same_shape(&self, other: &Tensor) -> Tensor {
        assert_eq!(
            self.shape(),
            other.shape(),
            "Tensor shapes must match for same-shape addition"
        );
        let mut result = self.add_tensor_optimized(other);

        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: true,
                original_shapes: None, // Same shape case
            };
            result.set_grad_fn(grad_fn.clone());

            let mut input_ids = Vec::with_capacity(2);
            if self.requires_grad() {
                input_ids.push(self.id());
            }
            if other.requires_grad() {
                input_ids.push(other.id());
            }
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

    /// Broadcast addition with a scalar value.
    ///
    /// Adds the scalar to every element: `output[i] = self[i] + scalar`
    ///
    /// # Arguments
    /// * `scalar` - Value to add to each element
    ///
    /// # Returns
    /// A new tensor with the scalar added to each element
    ///
    /// # Examples
    ///
    /// ## Basic Scalar Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let b = a.add_scalar(10.0);
    /// assert_eq!(b.shape().dims, vec![3]);
    /// assert_eq!(b.get(&[0]), 11.0);
    /// assert_eq!(b.get(&[1]), 12.0);
    /// assert_eq!(b.get(&[2]), 13.0);
    /// ```
    ///
    /// ## Multi-dimensional Scalar Addition
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::ones(vec![2, 3]);
    /// let b = a.add_scalar(5.0);
    /// assert_eq!(b.shape().dims, vec![2, 3]);
    /// assert_eq!(b.get(&[0, 0]), 6.0);
    /// assert_eq!(b.get(&[1, 2]), 6.0);
    /// ```
    #[inline]
    pub fn add_scalar(&self, scalar: f32) -> Tensor {
        let mut result = self.add_scalar_optimized(scalar);

        if self.requires_grad() && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: false,
                original_shapes: None, // Scalar case
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
        }

        result
    }

    /// Internal optimized tensor + tensor operation
    ///
    /// Performs element-wise addition between two tensors with the same shape,
    /// using SIMD acceleration when available. This is the core implementation
    /// used by `add_tensor()` after broadcasting has been applied.
    ///
    /// # Arguments
    ///
    /// * `other` - Tensor to add, must have the same shape as self
    ///
    /// # Returns
    ///
    /// A new tensor containing the element-wise sum
    ///
    /// # Safety
    ///
    /// Assumes both tensors have the same shape and valid memory layouts.
    /// Uses unsafe SIMD operations for performance optimization.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
    #[inline]
    pub(crate) fn add_tensor_optimized(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");

        let mut output = Tensor::new(self.shape().dims.clone());

        unsafe {
            let a = self.as_ptr();
            let b = other.as_ptr();
            let dst = output.as_mut_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.add_tensors_simd_avx2_optimized(a, b, dst);
                    return output;
                }
            }

            // Fallback to scalar operations with better cache usage
            self.add_tensors_scalar_optimized(a, b, dst);
        }

        output
    }

    /// SIMD-optimized tensor addition using AVX2 instructions
    ///
    /// Performs element-wise addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers covering at least `self.size()`
    /// elements; unaligned pointers are handled by the unaligned load/store intrinsics.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sum_vec1 = _mm256_add_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining elements in blocks of 8, then a scalar tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sum_vec = _mm256_add_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            offset += 4;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    /// Optimized scalar tensor addition fallback
    ///
    /// Performs element-wise addition using optimized scalar operations when
    /// SIMD is not available. Uses 4x unrolling for better instruction-level
    /// parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[inline]
    unsafe fn add_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance
        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    /// Internal optimized scalar + tensor operation
    ///
    /// Performs element-wise addition of a scalar to each element of the tensor,
    /// using SIMD acceleration when available. This is the core implementation
    /// used by `add_scalar()`.
    ///
    /// # Arguments
    ///
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Returns
    ///
    /// A new tensor with the scalar added to each element
    ///
    /// # Safety
    ///
    /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
    /// performance optimization.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
    /// - **Cache-friendly**: Linear memory access patterns
    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
    #[inline]
    pub(crate) fn add_scalar_optimized(&self, scalar: f32) -> Tensor {
        let mut output = Tensor::new(self.shape().dims.clone());

        unsafe {
            let src = self.as_ptr();
            let dst = output.as_mut_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.add_scalar_simd_avx2_optimized(src, dst, scalar);
                    return output;
                }
            }

            // Fallback to optimized scalar operations
            self.add_scalar_fallback_optimized(src, dst, scalar);
        }

        output
    }

    /// SIMD-optimized scalar addition using AVX2 instructions
    ///
    /// Performs element-wise scalar addition using AVX2 SIMD instructions for maximum
    /// performance on x86_64 hardware. Processes 32 elements per iteration with
    /// 4x unrolling for optimal instruction throughput.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers covering at least `self.size()`
    /// elements; unaligned pointers are handled by the unaligned load/store intrinsics.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Width**: 8 elements per AVX2 vector operation
    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for instruction throughput
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sum_vec1 = _mm256_add_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sum_vec2 = _mm256_add_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sum_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sum_vec3 = _mm256_add_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sum_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sum_vec4 = _mm256_add_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sum_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sum_vec = _mm256_add_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sum_vec);
            offset += 8;
        }

        // Handle final elements
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    /// Optimized scalar addition fallback
    ///
    /// Performs element-wise scalar addition using optimized scalar operations when
    /// SIMD is not available. Uses 4x unrolling for better instruction-level
    /// parallelism and cache efficiency.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to add to each element
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data.
    ///
    /// # Performance Characteristics
    ///
    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Fallback**: Handles remaining elements with scalar operations
    #[inline]
    unsafe fn add_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let size = self.size();
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop (4x) for instruction-level parallelism
        for _ in 0..unroll_count {
            *dst.add(offset) = *src.add(offset) + scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) + scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) + scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) + scalar;
            offset += 4;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_addition() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![2, 3]);
        let result = a.add_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 + 1.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

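    // A sketch exercising the remainder handling in the optimized kernels: 37 is not
    // a multiple of the 32/8/4-element block sizes, so the tail paths must run. Uses
    // only `ones`, `add_tensor_optimized`, and `add_scalar_optimized`, as above.
    #[test]
    fn test_addition_non_multiple_of_simd_width() {
        let a = Tensor::ones(vec![37]);
        let b = Tensor::ones(vec![37]);

        let sum = a.add_tensor_optimized(&b);
        let shifted = a.add_scalar_optimized(0.5);

        assert_eq!(sum.size(), 37);
        unsafe {
            for i in 0..sum.size() {
                assert!((sum.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
                assert!((shifted.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
            }
        }
    }
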
    #[test]
    fn test_scalar_addition() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.add_scalar_optimized(5.0);

        assert_eq!(result.shape().dims, vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 6.0 (1.0 + 5.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 6.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.add_tensor_optimized(&b);
    }

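    // A sketch of the broadcasting error path: [2, 3] and [4, 5] are not
    // broadcast-compatible, so `add_tensor` should hit the "Cannot broadcast"
    // panic raised in `add_tensor` above.
    #[test]
    #[should_panic(expected = "Cannot broadcast")]
    fn test_incompatible_broadcast_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![4, 5]);
        let _ = a.add_tensor(&b);
    }
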
    #[test]
    fn test_add_with_no_grad_guard() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        // Create tensors with requires_grad enabled
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Verify gradients are enabled by default
        assert!(is_grad_enabled());

        // Normal addition with gradients
        let c1 = a.add_tensor(&b);
        assert!(
            c1.requires_grad(),
            "Result should require gradients normally"
        );

        // Addition with NoGradTrack - gradients should be disabled
        {
            let _guard = NoGradTrack::new();
            assert!(
                !is_grad_enabled(),
                "Gradients should be disabled within guard"
            );

            let c2 = a.add_tensor(&b);
            assert!(
                !c2.requires_grad(),
                "Result should not require gradients within NoGradTrack"
            );

            // Test scalar addition as well
            let c3 = a.add_scalar(5.0);
            assert!(
                !c3.requires_grad(),
                "Scalar addition result should not require gradients within NoGradTrack"
            );
        }

        // Gradients should be restored after guard goes out of scope
        assert!(
            is_grad_enabled(),
            "Gradients should be restored after guard"
        );

        let c4 = a.add_tensor(&b);
        assert!(
            c4.requires_grad(),
            "Result should require gradients after guard is dropped"
        );
    }

    #[test]
    fn test_add_nested_no_grad_guards() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            let c1 = a.add_tensor(&b);
            assert!(!c1.requires_grad());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                let c2 = a.add_tensor(&b);
                assert!(!c2.requires_grad());
            }

            // Still disabled after inner guard drops
            assert!(!is_grad_enabled());
            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());
        }

        // Restored after all guards drop
        assert!(is_grad_enabled());
        let c4 = a.add_tensor(&b);
        assert!(c4.requires_grad());
    }

    #[test]
    fn test_add_with_mixed_requires_grad() {
        use crate::gradtrack::NoGradTrack;

        let a = Tensor::ones(vec![2, 2]).with_requires_grad(); // requires_grad = true
        let b = Tensor::ones(vec![2, 2]); // requires_grad = false

        // Without NoGradTrack, result should require gradients if any input does
        let c1 = a.add_tensor(&b);
        assert!(c1.requires_grad());

        let c2 = b.add_tensor(&a);
        assert!(c2.requires_grad());

        // With NoGradTrack, result should not require gradients regardless
        {
            let _guard = NoGradTrack::new();

            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());

            let c4 = b.add_tensor(&a);
            assert!(!c4.requires_grad());
        }
    }

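    // A sketch of the same-shape gradient path: addition is element-wise, so with
    // an upstream gradient of ones both inputs should receive gradients of ones in
    // their original shape. Mirrors the backward/grad_by_value usage of the
    // broadcasting gradient tests below.
    #[test]
    fn test_same_shape_addition_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        let mut c = a.add_tensor(&b);
        c.backward(None);

        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");
        assert_eq!(grad_a.shape().dims, vec![2, 2]);
        assert_eq!(grad_b.shape().dims, vec![2, 2]);

        for i in 0..grad_a.size() {
            let ga = unsafe { *grad_a.as_ptr().add(i) };
            let gb = unsafe { *grad_b.as_ptr().add(i) };
            assert!((ga - 1.0).abs() < 1e-6, "grad_a[{}] = {} should be 1.0", i, ga);
            assert!((gb - 1.0).abs() < 1e-6, "grad_b[{}] = {} should be 1.0", i, gb);
        }
    }
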
    #[test]
    fn test_add_performance_no_overhead() {
        use crate::gradtrack::NoGradTrack;
        use std::time::Instant;

        let size = 1000; // Smaller size for test stability
        let a = Tensor::ones(vec![size]).with_requires_grad();
        let b = Tensor::ones(vec![size]);

        // Time normal addition (with potential grad overhead)
        let start = Instant::now();
        for _ in 0..10 {
            let _ = a.add_tensor(&b);
        }
        let normal_duration = start.elapsed();

        // Time addition with NoGradTrack (should be faster)
        let start = Instant::now();
        {
            let _guard = NoGradTrack::new();
            for _ in 0..10 {
                let _ = a.add_tensor(&b);
            }
        }
        let no_grad_duration = start.elapsed();

        // NoGradTrack should provide performance benefit by skipping gradtrack setup
        // Allow generous variance for timing inconsistencies in tests
        println!(
            "Normal: {:?}, NoGrad: {:?}",
            normal_duration, no_grad_duration
        );

        // The key verification is that NoGradTrack doesn't add overhead
        assert!(
            no_grad_duration <= normal_duration * 3,
            "NoGradTrack should not add significant overhead"
        );
    }

    #[test]
    fn test_broadcasting_gradients_basic() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1, 3] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims,
            b.shape().dims
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims,
            grad_b.shape().dims
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims,
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        // This requires summing over the broadcasted dimension
        assert_eq!(
            grad_b.shape().dims,
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // All gradients should be 1.0 for grad_a
        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (val - 1.0).abs() < 1e-6,
                "grad_a[{}] = {} should be 1.0",
                i,
                val
            );
        }

        // grad_b should be [2.0, 2.0, 2.0] (sum over broadcast dim)
        let expected_grad_b = [2.0, 2.0, 2.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }

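    // A sketch of two-sided broadcasting: [2, 1] + [1, 3] -> [2, 3]. With an upstream
    // gradient of ones, grad_a should be reduced over the 3 broadcast columns (3.0 per
    // element) and grad_b over the 2 broadcast rows (2.0 per element), extending the
    // one-sided reduction checked in the test above.
    #[test]
    fn test_two_sided_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);
        result.backward(None);

        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");
        assert_eq!(grad_a.shape().dims, vec![2, 1]);
        assert_eq!(grad_b.shape().dims, vec![1, 3]);

        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!((val - 3.0).abs() < 1e-6, "grad_a[{}] = {} should be 3.0", i, val);
        }
        for i in 0..grad_b.size() {
            let val = unsafe { *grad_b.as_ptr().add(i) };
            assert!((val - 2.0).abs() < 1e-6, "grad_b[{}] = {} should be 2.0", i, val);
        }
    }
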
    #[test]
    fn test_scalar_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] + [1] -> [2, 3]
        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.5], vec![1])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        result.backward(None);

        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(grad_a.shape().dims, vec![2, 3]);

        // grad_b should have same shape as b: [1] and sum to 6.0
        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
        assert_eq!(grad_b.shape().dims, vec![1]);

        // grad_b should be 6.0 (sum over all 6 elements)
        let val = unsafe { *grad_b.as_ptr() };
        assert!((val - 6.0).abs() < 1e-6, "grad_b = {} should be 6.0", val);
    }

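    // A sketch of the scalar-addition gradient path: `add_scalar` is element-wise, so
    // with an upstream gradient of ones the input gradient should be ones in the
    // original shape. Uses the same gradtrack API as the surrounding tests.
    #[test]
    fn test_add_scalar_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_scalar(2.5);
        result.backward(None);

        let grad_a = a.grad_by_value().expect("grad_a should exist");
        assert_eq!(grad_a.shape().dims, vec![2, 2]);
        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!((val - 1.0).abs() < 1e-6, "grad_a[{}] = {} should be 1.0", i, val);
        }
    }
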
    #[test]
    fn test_linear_layer_bias_broadcasting() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Simulate linear layer bias broadcasting
        // input: [2, 3], weight: [3, 4], bias: [4]
        // matmul result: [2, 4], bias broadcast: [4] -> [2, 4]

        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let weight = Tensor::from_slice(
            &(1..=12).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
            vec![3, 4],
        )
        .unwrap()
        .with_requires_grad();
        let bias = Tensor::from_slice(&[0.1, 0.2, 0.3, 0.4], vec![4])
            .unwrap()
            .with_requires_grad();

        // Forward pass: input @ weight + bias
        let matmul_result = input.matmul(&weight);
        println!("Matmul result shape: {:?}", matmul_result.shape().dims);
        println!("Bias shape: {:?}", bias.shape().dims);

        let linear_output = matmul_result.add_tensor(&bias);
        println!("Linear output shape: {:?}", linear_output.shape().dims);

        // Sum all outputs as loss
        let mut loss = linear_output.sum();
        loss.backward(None);

        // Check bias gradient
        let bias_grad = bias.grad_by_value().expect("bias gradient should exist");
        println!("Bias gradient shape: {:?}", bias_grad.shape().dims);
        assert_eq!(
            bias_grad.shape().dims,
            vec![4],
            "bias gradient should match bias shape"
        );

        // Bias gradient should be [2.0, 2.0, 2.0, 2.0] (sum over batch dimension)
        for i in 0..4 {
            let val = unsafe { *bias_grad.as_ptr().add(i) };
            assert!(
                (val - 2.0).abs() < 1e-6,
                "bias_grad[{}] = {} should be 2.0",
                i,
                val
            );
        }

        println!("Linear layer bias broadcasting test passed!");
    }
}