// train_station/tensor/ops/sub.rs

1//! Subtraction operations for tensors
2//!
3//! Provides element-wise subtraction operations following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Tensor Subtraction**: `sub_tensor()` - Element-wise subtraction with broadcasting support
9//! - **Scalar Subtraction**: `sub_scalar()` - Subtraction of scalar from tensor
10//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The subtraction operations have the following properties:
18//! - **Tensor-Tensor**: `output[i] = a[i] - b[i]` with broadcasting
19//! - **Tensor-Scalar**: `output[i] = a[i] - scalar` for all elements
20//! - **Commutativity**: Subtraction is not commutative (a - b ≠ b - a)
21//! - **Associativity**: Subtraction is not associative ((a - b) - c ≠ a - (b - c))
22//! - **Gradient**: For tensor-tensor: ∂(a-b)/∂a = 1, ∂(a-b)/∂b = -1
23//! - **Broadcasting**: Follows NumPy broadcasting rules for shape compatibility
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Broadcasting Overhead**: Minimal overhead for compatible shapes
31//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
41
42impl Tensor {
43    /// Element-wise subtraction with another tensor with broadcasting support
44    ///
45    /// Performs element-wise subtraction with automatic broadcasting: `output[i] = self[i] - other[i]`
46    ///
47    /// Broadcasting enables subtraction between tensors of different but compatible shapes.
48    /// Compatible shapes follow NumPy broadcasting rules:
49    /// - Dimensions are aligned from the rightmost dimension
50    /// - Dimensions are compatible if they are equal, or one of them is 1
51    /// - Missing dimensions are treated as 1
52    ///
53    /// # Arguments
54    ///
55    /// * `other` - Tensor to subtract. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    ///
59    /// A new tensor containing the element-wise difference with broadcast result shape
60    ///
61    /// # Performance Characteristics
62    ///
63    /// - **Fast Path**: Optimized for identical shapes to avoid broadcasting overhead
64    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
65    /// - **Broadcasting**: Efficient broadcasting for compatible shapes
66    /// - **Cache-friendly**: Linear memory access patterns
67    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
68    ///
69    /// # Implementation Details
70    ///
71    /// Uses a fast path for identical shapes to avoid broadcasting overhead.
72    /// For different shapes, performs broadcasting followed by optimized element-wise subtraction.
73    /// Automatically selects between SIMD and scalar implementations based on hardware capabilities.
74    ///
75    /// # Examples
76    ///
77    /// ## Same Shape Subtraction
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
83    /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84    /// let c = a.sub_tensor(&b);
85    /// assert_eq!(c.shape().dims, vec![3]);
86    /// assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
87    /// assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
88    /// assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
89    /// ```
90    ///
91    /// ## Broadcasting Subtraction
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
97    /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
98    /// let c = a.sub_tensor(&b);
99    /// assert_eq!(c.shape().dims, vec![2, 3]);
100    /// // Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
101    /// assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
102    /// assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
103    /// assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
104    /// ```
105    ///
106    /// ## Scalar Subtraction
107    ///
108    /// ```
109    /// use train_station::Tensor;
110    ///
111    /// let a = Tensor::ones(vec![2, 3]);
112    /// let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
113    /// let c = a.sub_tensor(&b);
114    /// assert_eq!(c.shape().dims, vec![2, 3]);
115    /// assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
116    /// ```
117    ///
118    /// # Panics
119    /// Panics if tensor shapes are not broadcast-compatible
120    #[inline]
121    pub fn sub_tensor(&self, other: &Tensor) -> Tensor {
122        // Check if shapes are identical for fast path
123        if self.shape().dims == other.shape().dims {
124            return self.sub_tensor_same_shape(other);
125        }
126
127        // Use broadcasting for different shapes
128        let (broadcast_self, broadcast_other, _result_shape) =
129            self.broadcast_with(other).unwrap_or_else(|e| {
130                panic!(
131                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
132                    self.shape().dims,
133                    other.shape().dims,
134                    e
135                );
136            });
137
138        // Perform element-wise subtraction on broadcasted tensors
139        let mut result = broadcast_self.sub_tensor_optimized(&broadcast_other);
140
141        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
142            result.set_requires_grad_internal(true);
143            let grad_fn = GradFn::Sub {
144                is_tensor_sub: true,
145                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
146            };
147            result.set_grad_fn(grad_fn.clone());
148
149            let mut input_ids = Vec::with_capacity(2);
150            if self.requires_grad() {
151                input_ids.push(self.id());
152            }
153            if other.requires_grad() {
154                input_ids.push(other.id());
155            }
156            GradEngine::register_operation(result.id(), input_ids, grad_fn);
157        }
158
159        result
160    }
161
162    /// Element-wise subtraction for tensors with identical shapes (fast path)
163    ///
164    /// This is an optimized path for tensors that already have the same shape,
165    /// avoiding the overhead of broadcasting computation.
166    ///
167    /// # Arguments
168    ///
169    /// * `other` - Tensor to subtract, must have the same shape as self
170    ///
171    /// # Returns
172    ///
173    /// A new tensor containing the element-wise difference
174    ///
175    /// # Performance Characteristics
176    ///
177    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
178    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
179    /// - **Cache-friendly**: Linear memory access patterns
180    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
181    ///
182    /// # Implementation Details
183    ///
184    /// This method is used internally by `sub_tensor()` when tensors have identical shapes.
185    /// It bypasses the broadcasting logic and directly calls the optimized subtraction implementation.
186    #[inline]
187    fn sub_tensor_same_shape(&self, other: &Tensor) -> Tensor {
188        assert_eq!(
189            self.shape(),
190            other.shape(),
191            "Tensor shapes must match for same-shape subtraction"
192        );
193        let mut result = self.sub_tensor_optimized(other);
194
195        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
196            result.set_requires_grad_internal(true);
197            let grad_fn = GradFn::Sub {
198                is_tensor_sub: true,
199                original_shapes: None, // Same shape case
200            };
201            result.set_grad_fn(grad_fn.clone());
202
203            let mut input_ids = Vec::with_capacity(2);
204            if self.requires_grad() {
205                input_ids.push(self.id());
206            }
207            if other.requires_grad() {
208                input_ids.push(other.id());
209            }
210            GradEngine::register_operation(result.id(), input_ids, grad_fn);
211        }
212
213        result
214    }
215
216    /// Element-wise subtraction of a scalar from this tensor
217    ///
218    /// Performs element-wise subtraction of a scalar value: `output[i] = self[i] - scalar`
219    ///
220    /// # Arguments
221    ///
222    /// * `scalar` - The scalar value to subtract from each element
223    ///
224    /// # Returns
225    ///
226    /// A new tensor with the scalar subtracted from each element
227    ///
228    /// # Performance Characteristics
229    ///
230    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
231    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
232    /// - **Cache-friendly**: Linear memory access patterns
233    /// - **Mathematical Accuracy**: High-precision subtraction computation
234    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
235    ///
236    /// # Examples
237    ///
238    /// ## Basic Scalar Subtraction
239    ///
240    /// ```
241    /// use train_station::Tensor;
242    ///
243    /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
244    /// let b = a.sub_scalar(2.0);
245    /// assert_eq!(b.shape().dims, vec![3]);
246    /// assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
247    /// assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
248    /// assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
249    /// ```
250    ///
251    /// ## Negative Scalar Subtraction
252    ///
253    /// ```
254    /// use train_station::Tensor;
255    ///
256    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
257    /// let b = a.sub_scalar(-2.0); // Subtracting negative = adding
258    /// assert_eq!(b.shape().dims, vec![3]);
259    /// assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
260    /// assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
261    /// assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
262    /// ```
263    #[inline]
264    pub fn sub_scalar(&self, scalar: f32) -> Tensor {
265        let mut result = self.sub_scalar_optimized(scalar);
266
267        if self.requires_grad() && is_grad_enabled() {
268            result.set_requires_grad_internal(true);
269            let grad_fn = GradFn::Sub {
270                is_tensor_sub: false,
271                original_shapes: None, // Scalar case
272            };
273            result.set_grad_fn(grad_fn.clone());
274            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
275        }
276
277        result
278    }
279    /// Optimized tensor subtraction using SIMD when available
280    ///
281    /// Performs element-wise subtraction between tensors with identical shapes using
282    /// SIMD optimization when available and falling back to optimized scalar computation.
283    ///
284    /// # Arguments
285    ///
286    /// * `other` - The tensor to subtract from this tensor
287    ///
288    /// # Returns
289    ///
290    /// A new tensor with the result of the subtraction (self - other)
291    ///
292    /// # Performance Characteristics
293    ///
294    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
295    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
296    /// - **Cache-friendly**: Linear memory access patterns
297    /// - **Mathematical Accuracy**: High-precision subtraction computation
298    /// - **Zero-sized Handling**: Fast return for empty tensors
299    ///
300    /// # Implementation Details
301    ///
302    /// Automatically selects between SIMD and scalar implementations based on hardware
303    /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
304    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
305    /// parallelism.
306    ///
307    /// # Safety
308    ///
309    /// This operation assumes the tensors have the same shape.
310    #[inline]
311    pub(crate) fn sub_tensor_optimized(&self, other: &Tensor) -> Tensor {
312        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
313
314        let mut output = Tensor::new(self.shape().dims.clone());
315
316        unsafe {
317            let a = self.as_ptr();
318            let b = other.as_ptr();
319            let dst = output.as_mut_ptr();
320
321            #[cfg(target_arch = "x86_64")]
322            {
323                // Use SIMD for better performance when available
324                if is_x86_feature_detected!("avx2") {
325                    self.sub_tensors_simd_avx2_optimized(a, b, dst);
326                    return output;
327                }
328            }
329
330            // Fallback to scalar operations with better cache usage
331            self.sub_tensors_scalar_optimized(a, b, dst);
332        }
333
334        output
335    }
336
337    /// AVX2-optimized tensor subtraction implementation
338    ///
339    /// Performs element-wise subtraction using AVX2 SIMD instructions for maximum
340    /// performance on x86_64 architectures with AVX2 support.
341    ///
342    /// # Arguments
343    ///
344    /// * `a` - Pointer to first tensor data
345    /// * `b` - Pointer to second tensor data
346    /// * `dst` - Pointer to output tensor data
347    ///
348    /// # Safety
349    ///
350    /// Requires valid pointers with sufficient memory for the tensor size.
351    /// All pointers must point to valid tensor data. Requires AVX2 support.
352    ///
353    /// # Performance Characteristics
354    ///
355    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
356    /// - **Memory Access**: Linear access patterns for cache efficiency
357    /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
358    /// - **Fallback**: Handles remaining elements with scalar operations
359    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
360    ///
361    /// # Implementation Details
362    ///
363    /// Uses AVX2 vector subtraction operations to compute a - b efficiently.
364    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
365    /// Processes remaining elements with scalar operations for complete coverage.
366    #[cfg(target_arch = "x86_64")]
367    #[inline]
368    #[target_feature(enable = "avx2")]
369    unsafe fn sub_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
370        let size = self.size();
371        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
372        let mut offset = 0;
373
374        // Unrolled SIMD loop for throughput
375        for _ in 0..simd_count {
376            // Process 4 AVX2 vectors (32 elements) per iteration
377            let a_vec1 = _mm256_loadu_ps(a.add(offset));
378            let b_vec1 = _mm256_loadu_ps(b.add(offset));
379            let sub_vec1 = _mm256_sub_ps(a_vec1, b_vec1);
380            _mm256_storeu_ps(dst.add(offset), sub_vec1);
381
382            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
383            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
384            let sub_vec2 = _mm256_sub_ps(a_vec2, b_vec2);
385            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);
386
387            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
388            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
389            let sub_vec3 = _mm256_sub_ps(a_vec3, b_vec3);
390            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);
391
392            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
393            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
394            let sub_vec4 = _mm256_sub_ps(a_vec4, b_vec4);
395            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);
396
397            offset += 32;
398        }
399
400        // Handle remaining 8-element blocks
401        let remaining_full_blocks = (size - offset) / 8;
402        for _ in 0..remaining_full_blocks {
403            let a_vec = _mm256_loadu_ps(a.add(offset));
404            let b_vec = _mm256_loadu_ps(b.add(offset));
405            let sub_vec = _mm256_sub_ps(a_vec, b_vec);
406            _mm256_storeu_ps(dst.add(offset), sub_vec);
407            offset += 8;
408        }
409
410        // Handle final elements with unrolled loop
411        let remaining = size - offset;
412        let unroll_count = remaining / 4;
413        for _ in 0..unroll_count {
414            *dst.add(offset) = *a.add(offset) - *b.add(offset);
415            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
416            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
417            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
418            offset += 4;
419        }
420
421        for i in offset..size {
422            *dst.add(i) = *a.add(i) - *b.add(i);
423        }
424    }
425
426    /// Optimized scalar tensor subtraction fallback
427    ///
428    /// Performs element-wise subtraction using optimized scalar operations with
429    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
430    ///
431    /// # Arguments
432    ///
433    /// * `a` - Pointer to first tensor data
434    /// * `b` - Pointer to second tensor data
435    /// * `dst` - Pointer to output tensor data
436    ///
437    /// # Safety
438    ///
439    /// Requires valid pointers with sufficient memory for the tensor size.
440    /// All pointers must point to valid tensor data.
441    ///
442    /// # Performance Characteristics
443    ///
444    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
445    /// - **Memory Access**: Linear access patterns for cache efficiency
446    /// - **Fallback**: Handles remaining elements with scalar operations
447    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
448    /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
449    ///
450    /// # Implementation Details
451    ///
452    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
453    /// Processes elements in groups of 4 to improve instruction-level parallelism
454    /// and reduce loop overhead.
455    #[inline]
456    unsafe fn sub_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
457        let size = self.size();
458        let unroll_count = size / 4;
459        let mut offset = 0;
460
461        // Unrolled scalar loop for better performance
462        for _ in 0..unroll_count {
463            *dst.add(offset) = *a.add(offset) - *b.add(offset);
464            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
465            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
466            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
467            offset += 4;
468        }
469
470        // Handle remaining elements
471        for i in offset..size {
472            *dst.add(i) = *a.add(i) - *b.add(i);
473        }
474    }
475
476    /// Internal optimized scalar subtraction operation
477    ///
478    /// Performs element-wise subtraction of a scalar from tensor using SIMD optimization
479    /// when available and falling back to optimized scalar computation.
480    ///
481    /// # Arguments
482    ///
483    /// * `scalar` - The scalar value to subtract from each element
484    ///
485    /// # Returns
486    ///
487    /// A new tensor with the scalar subtracted from each element
488    ///
489    /// # Performance Characteristics
490    ///
491    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
492    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
493    /// - **Cache-friendly**: Linear memory access patterns
494    /// - **Mathematical Accuracy**: High-precision subtraction computation
495    /// - **Zero-sized Handling**: Fast return for empty tensors
496    ///
497    /// # Implementation Details
498    ///
499    /// Automatically selects between SIMD and scalar implementations based on hardware
500    /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
501    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
502    /// parallelism.
503    #[inline]
504    pub(crate) fn sub_scalar_optimized(&self, scalar: f32) -> Tensor {
505        let mut output = Tensor::new(self.shape().dims.clone());
506
507        unsafe {
508            let src = self.as_ptr();
509            let dst = output.as_mut_ptr();
510
511            #[cfg(target_arch = "x86_64")]
512            {
513                // Use SIMD for better performance when available
514                if is_x86_feature_detected!("avx2") {
515                    self.sub_scalar_simd_avx2_optimized(src, dst, scalar);
516                    return output;
517                }
518            }
519
520            // Fallback to optimized scalar operations
521            self.sub_scalar_fallback_optimized(src, dst, scalar);
522        }
523
524        output
525    }
526
527    /// AVX2-optimized scalar subtraction implementation
528    ///
529    /// Performs element-wise subtraction of a scalar using AVX2 SIMD instructions for maximum
530    /// performance on x86_64 architectures with AVX2 support.
531    ///
532    /// # Arguments
533    ///
534    /// * `src` - Pointer to source tensor data
535    /// * `dst` - Pointer to output tensor data
536    /// * `scalar` - The scalar value to subtract from each element
537    ///
538    /// # Safety
539    ///
540    /// Requires valid pointers with sufficient memory for the tensor size.
541    /// All pointers must point to valid tensor data. Requires AVX2 support.
542    ///
543    /// # Performance Characteristics
544    ///
545    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
546    /// - **Memory Access**: Linear access patterns for cache efficiency
547    /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
548    /// - **Fallback**: Handles remaining elements with scalar operations
549    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
550    ///
551    /// # Implementation Details
552    ///
553    /// Uses AVX2 vector subtraction operations to compute src - scalar efficiently.
554    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
555    /// Processes remaining elements with scalar operations for complete coverage.
556    #[cfg(target_arch = "x86_64")]
557    #[inline]
558    #[target_feature(enable = "avx2")]
559    unsafe fn sub_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
560        let scalar_vec = _mm256_set1_ps(scalar);
561        let size = self.size();
562        let simd_count = size / 32; // Process 32 elements per iteration
563        let mut offset = 0;
564
565        // Unrolled SIMD loop for better instruction throughput
566        for _ in 0..simd_count {
567            let src_vec1 = _mm256_loadu_ps(src.add(offset));
568            let sub_vec1 = _mm256_sub_ps(src_vec1, scalar_vec);
569            _mm256_storeu_ps(dst.add(offset), sub_vec1);
570
571            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
572            let sub_vec2 = _mm256_sub_ps(src_vec2, scalar_vec);
573            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);
574
575            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
576            let sub_vec3 = _mm256_sub_ps(src_vec3, scalar_vec);
577            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);
578
579            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
580            let sub_vec4 = _mm256_sub_ps(src_vec4, scalar_vec);
581            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);
582
583            offset += 32;
584        }
585
586        // Handle remaining 8-element blocks
587        let remaining_full_blocks = (size - offset) / 8;
588        for _ in 0..remaining_full_blocks {
589            let src_vec = _mm256_loadu_ps(src.add(offset));
590            let sub_vec = _mm256_sub_ps(src_vec, scalar_vec);
591            _mm256_storeu_ps(dst.add(offset), sub_vec);
592            offset += 8;
593        }
594
595        // Handle final elements
596        for i in offset..size {
597            *dst.add(i) = *src.add(i) - scalar;
598        }
599    }
600
601    /// Optimized scalar subtraction fallback
602    ///
603    /// Performs element-wise subtraction of a scalar using optimized scalar operations with
604    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
605    ///
606    /// # Arguments
607    ///
608    /// * `src` - Pointer to source tensor data
609    /// * `dst` - Pointer to output tensor data
610    /// * `scalar` - The scalar value to subtract from each element
611    ///
612    /// # Safety
613    ///
614    /// Requires valid pointers with sufficient memory for the tensor size.
615    /// All pointers must point to valid tensor data.
616    ///
617    /// # Performance Characteristics
618    ///
619    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
620    /// - **Memory Access**: Linear access patterns for cache efficiency
621    /// - **Fallback**: Handles remaining elements with scalar operations
622    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
623    /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
624    ///
625    /// # Implementation Details
626    ///
627    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
628    /// Processes elements in groups of 4 to improve instruction-level parallelism
629    /// and reduce loop overhead.
630    #[inline]
631    unsafe fn sub_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
632        let size = self.size();
633        let unroll_count = size / 4;
634        let mut offset = 0;
635
636        // Unrolled scalar operations
637        for _ in 0..unroll_count {
638            *dst.add(offset) = *src.add(offset) - scalar;
639            *dst.add(offset + 1) = *src.add(offset + 1) - scalar;
640            *dst.add(offset + 2) = *src.add(offset + 2) - scalar;
641            *dst.add(offset + 3) = *src.add(offset + 3) - scalar;
642            offset += 4;
643        }
644
645        for i in offset..size {
646            *dst.add(i) = *src.add(i) - scalar;
647        }
648    }
649}
650
651#[cfg(test)]
652mod tests {
653    use super::*;
654
655    #[test]
656    fn test_tensor_subtraction() {
657        let mut a = Tensor::ones(vec![2, 3]);
658        a.fill(5.0); // Create a tensor with all 5.0s
659        let mut b = Tensor::ones(vec![2, 3]);
660        b.fill(2.0); // Create a tensor with all 2.0s
661        let result = a.sub_tensor_optimized(&b);
662
663        assert_eq!(result.shape().dims, vec![2, 3]);
664        assert_eq!(result.size(), 6);
665
666        // Check that all values are 3.0 (5.0 - 2.0)
667        unsafe {
668            for i in 0..result.size() {
669                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
670            }
671        }
672    }
673
674    #[test]
675    fn test_scalar_subtraction() {
676        let mut tensor = Tensor::ones(vec![2, 2]);
677        tensor.fill(10.0); // Create a tensor with all 10.0s
678        let result = tensor.sub_scalar_optimized(3.0);
679
680        assert_eq!(result.shape().dims, vec![2, 2]);
681        assert_eq!(result.size(), 4);
682
683        // Check that all values are 7.0 (10.0 - 3.0)
684        unsafe {
685            for i in 0..result.size() {
686                assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
687            }
688        }
689    }
690
691    #[test]
692    fn test_negative_subtraction() {
693        let mut a = Tensor::ones(vec![2, 2]);
694        a.fill(2.0); // Create a tensor with all 2.0s
695        let mut b = Tensor::ones(vec![2, 2]);
696        b.fill(5.0); // Create a tensor with all 5.0s
697        let result = a.sub_tensor_optimized(&b);
698
699        // Check that all values are -3.0 (2.0 - 5.0)
700        unsafe {
701            for i in 0..result.size() {
702                assert!((result.as_ptr().add(i).read() - (-3.0)).abs() < 1e-6);
703            }
704        }
705    }
706
707    #[test]
708    fn test_scalar_negative_subtraction() {
709        let mut tensor = Tensor::ones(vec![2, 2]);
710        tensor.fill(3.0); // Create a tensor with all 3.0s
711        let result = tensor.sub_scalar_optimized(8.0);
712
713        // Check that all values are -5.0 (3.0 - 8.0)
714        unsafe {
715            for i in 0..result.size() {
716                assert!((result.as_ptr().add(i).read() - (-5.0)).abs() < 1e-6);
717            }
718        }
719    }
720
721    #[test]
722    #[should_panic(expected = "Tensor shapes must match")]
723    fn test_mismatched_shapes() {
724        let a = Tensor::ones(vec![2, 3]);
725        let b = Tensor::ones(vec![3, 2]);
726        a.sub_tensor_optimized(&b);
727    }
728
729    #[test]
730    fn test_edge_cases() {
731        // Test zero subtraction
732        let a = Tensor::ones(vec![3]);
733        let b = Tensor::zeros(vec![3]);
734        let result = a.sub_tensor_optimized(&b);
735
736        unsafe {
737            for i in 0..result.size() {
738                assert!((result.as_ptr().add(i).read() - 1.0).abs() < 1e-6);
739            }
740        }
741
742        // Test self subtraction
743        let mut tensor = Tensor::ones(vec![3]);
744        tensor.fill(5.0);
745        let result = tensor.sub_tensor_optimized(&tensor);
746
747        unsafe {
748            for i in 0..result.size() {
749                assert!(result.as_ptr().add(i).read().abs() < 1e-6);
750            }
751        }
752    }
753
754    #[test]
755    fn test_large_tensor_subtraction() {
756        let mut a = Tensor::ones(vec![100, 100]);
757        a.fill(10.0);
758        let mut b = Tensor::ones(vec![100, 100]);
759        b.fill(3.0);
760        let result = a.sub_tensor_optimized(&b);
761
762        assert_eq!(result.size(), 10000);
763
764        // Check some values are 7.0 (10.0 - 3.0)
765        unsafe {
766            for i in (0..result.size()).step_by(1000) {
767                assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
768            }
769        }
770    }
771
772    #[test]
773    fn test_negate_inplace() {
774        let mut tensor = Tensor::ones(vec![2, 2]);
775        tensor.fill(5.0);
776
777        // Check initial values
778        unsafe {
779            for i in 0..tensor.size() {
780                let val = tensor.as_ptr().add(i).read();
781                assert!((val - 5.0).abs() < 1e-6, "Expected 5.0, got {}", val);
782            }
783        }
784
785        tensor.negate_inplace();
786
787        // Check negated values
788        unsafe {
789            for i in 0..tensor.size() {
790                let val = tensor.as_ptr().add(i).read();
791                assert!(
792                    (val - (-5.0)).abs() < 1e-6,
793                    "Expected -5.0 after negation, got {}",
794                    val
795                );
796            }
797        }
798    }
799
800    #[test]
801    fn test_subtraction_with_gradtrack() {
802        // Test scalar subtraction with gradtrack
803        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
804        let mut result = a.sub_scalar(5.0);
805
806        // Check result values: 1.0 - 5.0 = -4.0
807        unsafe {
808            for i in 0..result.size() {
809                let val = result.as_ptr().add(i).read();
810                assert!((val - (-4.0)).abs() < 1e-6, "Expected -4.0, got {}", val);
811            }
812        }
813
814        result.backward(None);
815
816        // Check gradient: d/dx(x - c) = 1
817        if let Some(grad) = a.grad_by_value() {
818            unsafe {
819                for i in 0..grad.size() {
820                    let val = grad.as_ptr().add(i).read();
821                    assert!(
822                        (val - 1.0).abs() < 1e-6,
823                        "Expected gradient 1.0, got {}",
824                        val
825                    );
826                }
827            }
828        } else {
829            panic!("No gradient computed for scalar subtraction!");
830        }
831
832        // Test tensor subtraction with gradtrack
833        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
834        let mut b = Tensor::ones(vec![2, 2]);
835        b.fill(3.0);
836        let b = b.with_requires_grad();
837
838        let mut result = a.sub_tensor(&b);
839
840        // Check result values: 1.0 - 3.0 = -2.0
841        unsafe {
842            for i in 0..result.size() {
843                let val = result.as_ptr().add(i).read();
844                assert!((val - (-2.0)).abs() < 1e-6, "Expected -2.0, got {}", val);
845            }
846        }
847
848        result.backward(None);
849
850        // Check gradients: d/dx(x - y) = 1, d/dy(x - y) = -1
851        if let Some(grad_a) = a.grad_by_value() {
852            unsafe {
853                for i in 0..grad_a.size() {
854                    let val = grad_a.as_ptr().add(i).read();
855                    assert!(
856                        (val - 1.0).abs() < 1e-6,
857                        "Expected gradient A = 1.0, got {}",
858                        val
859                    );
860                }
861            }
862        } else {
863            panic!("No gradient A computed for tensor subtraction!");
864        }
865
866        if let Some(grad_b) = b.grad_by_value() {
867            unsafe {
868                for i in 0..grad_b.size() {
869                    let val = grad_b.as_ptr().add(i).read();
870                    println!("Debug: grad_b[{}] = {}", i, val);
871                    assert!(
872                        (val - (-1.0)).abs() < 1e-6,
873                        "Expected gradient B = -1.0, got {}",
874                        val
875                    );
876                }
877            }
878        } else {
879            panic!("No gradient B computed for tensor subtraction!");
880        }
881    }
882
883    #[test]
884    fn test_mixed_add_sub_operations_with_gradtrack() {
885        // Test complex computation graph: (a + scalar1) - b + c - scalar2
886        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
887        let mut b = Tensor::ones(vec![2, 2]);
888        b.fill(2.0);
889        let b = b.with_requires_grad();
890        let mut c = Tensor::ones(vec![2, 2]);
891        c.fill(3.0);
892        let c = c.with_requires_grad();
893
894        let scalar1 = 5.0;
895        let scalar2 = 1.0;
896
897        // Complex computation: (a + scalar1) - b + c - scalar2
898        // Expected: (1 + 5) - 2 + 3 - 1 = 6
899        let step1 = a.add_scalar(scalar1); // a + 5 = 6
900        let step2 = step1.sub_tensor(&b); // 6 - 2 = 4
901        let step3 = step2.add_tensor(&c); // 4 + 3 = 7
902        let mut result = step3.sub_scalar(scalar2); // 7 - 1 = 6
903
904        // Check result values
905        unsafe {
906            for i in 0..result.size() {
907                let val = result.as_ptr().add(i).read();
908                assert!((val - 6.0).abs() < 1e-6, "Expected 6.0, got {}", val);
909            }
910        }
911
912        result.backward(None);
913
914        // Check gradients
915        // For computation: f = (a + 5) - b + c - 1
916        // df/da = 1, df/db = -1, df/dc = 1
917
918        if let Some(grad_a) = a.grad_by_value() {
919            unsafe {
920                for i in 0..grad_a.size() {
921                    let val = grad_a.as_ptr().add(i).read();
922                    assert!(
923                        (val - 1.0).abs() < 1e-6,
924                        "Expected gradient A = 1.0, got {}",
925                        val
926                    );
927                }
928            }
929        } else {
930            panic!("No gradient A computed for mixed operations!");
931        }
932
933        if let Some(grad_b) = b.grad_by_value() {
934            unsafe {
935                for i in 0..grad_b.size() {
936                    let val = grad_b.as_ptr().add(i).read();
937                    assert!(
938                        (val - (-1.0)).abs() < 1e-6,
939                        "Expected gradient B = -1.0, got {}",
940                        val
941                    );
942                }
943            }
944        } else {
945            panic!("No gradient B computed for mixed operations!");
946        }
947
948        if let Some(grad_c) = c.grad_by_value() {
949            unsafe {
950                for i in 0..grad_c.size() {
951                    let val = grad_c.as_ptr().add(i).read();
952                    assert!(
953                        (val - 1.0).abs() < 1e-6,
954                        "Expected gradient C = 1.0, got {}",
955                        val
956                    );
957                }
958            }
959        } else {
960            panic!("No gradient C computed for mixed operations!");
961        }
962
963        println!("Mixed add/sub operations with gradtrack test passed!");
964        println!("✓ Complex computation graph: (a + 5) - b + c - 1 = 6");
965        println!("✓ Gradients: da/df = 1, db/df = -1, dc/df = 1");
966    }
967
968    #[test]
969    fn test_sub_broadcasting_gradients_basic() {
970        use crate::gradtrack::clear_gradients;
971        clear_gradients();
972
973        // Test case: [2, 3] - [1, 3] -> [2, 3]
974        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)
975        // For subtraction: d/da (a - b) = 1, d/db (a - b) = -1
976
977        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
978            .unwrap()
979            .with_requires_grad();
980        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
981            .unwrap()
982            .with_requires_grad();
983
984        let mut result = a.sub_tensor(&b);
985        assert_eq!(result.shape().dims, vec![2, 3]);
986
987        // Set upstream gradient as ones
988        result.backward(None);
989
990        let grad_a = a.grad_by_value().expect("grad_a should exist");
991        let grad_b = b.grad_by_value().expect("grad_b should exist");
992
993        println!(
994            "Original shapes: a={:?}, b={:?}",
995            a.shape().dims,
996            b.shape().dims
997        );
998        println!(
999            "Gradient shapes: grad_a={:?}, grad_b={:?}",
1000            grad_a.shape().dims,
1001            grad_b.shape().dims
1002        );
1003
1004        // grad_a should have same shape as a: [2, 3]
1005        assert_eq!(
1006            grad_a.shape().dims,
1007            vec![2, 3],
1008            "grad_a should match original shape of a"
1009        );
1010
1011        // grad_b should have same shape as b: [1, 3]
1012        // This requires summing over the broadcasted dimension
1013        assert_eq!(
1014            grad_b.shape().dims,
1015            vec![1, 3],
1016            "grad_b should match original shape of b"
1017        );
1018
1019        // All gradients should be 1.0 for grad_a (d/da (a - b) = 1)
1020        for i in 0..grad_a.size() {
1021            let val = unsafe { *grad_a.as_ptr().add(i) };
1022            assert!(
1023                (val - 1.0).abs() < 1e-6,
1024                "grad_a[{}] = {} should be 1.0",
1025                i,
1026                val
1027            );
1028        }
1029
1030        // grad_b should be [-2.0, -2.0, -2.0] (sum over broadcast dim, then negated)
1031        let expected_grad_b = [-2.0, -2.0, -2.0]; // -1 * 2 rows = -2
1032        for (i, &expected) in expected_grad_b.iter().enumerate() {
1033            let val = unsafe { *grad_b.as_ptr().add(i) };
1034            assert!(
1035                (val - expected).abs() < 1e-6,
1036                "grad_b[{}] = {} should be {}",
1037                i,
1038                val,
1039                expected
1040            );
1041        }
1042    }
1043
1044    #[test]
1045    fn test_sub_scalar_broadcasting_gradients() {
1046        use crate::gradtrack::clear_gradients;
1047        clear_gradients();
1048
1049        // Test case: [2, 3] - [1] -> [2, 3]
1050        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims, then negated)
1051
1052        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1053            .unwrap()
1054            .with_requires_grad();
1055        let b = Tensor::from_slice(&[0.5], vec![1])
1056            .unwrap()
1057            .with_requires_grad();
1058
1059        let mut result = a.sub_tensor(&b);
1060        result.backward(None);
1061
1062        let grad_a = a.grad_by_value().expect("grad_a should exist");
1063        let grad_b = b.grad_by_value().expect("grad_b should exist");
1064
1065        // grad_a should have same shape as a: [2, 3]
1066        assert_eq!(grad_a.shape().dims, vec![2, 3]);
1067
1068        // grad_b should have same shape as b: [1] and sum to -6.0
1069        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
1070        assert_eq!(grad_b.shape().dims, vec![1]);
1071
1072        // grad_b should be -6.0 (sum over all 6 elements, then negated)
1073        let val = unsafe { *grad_b.as_ptr() };
1074        assert!(
1075            (val - (-6.0)).abs() < 1e-6,
1076            "grad_b = {} should be -6.0",
1077            val
1078        );
1079    }
1080}