// train_station/tensor/ops/sub.rs

1//! Subtraction operations for tensors
2//!
3//! Provides element-wise subtraction operations following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Tensor Subtraction**: `sub_tensor()` - Element-wise subtraction with broadcasting support
9//! - **Scalar Subtraction**: `sub_scalar()` - Subtraction of scalar from tensor
10//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The subtraction operations have the following properties:
18//! - **Tensor-Tensor**: `output[i] = a[i] - b[i]` with broadcasting
19//! - **Tensor-Scalar**: `output[i] = a[i] - scalar` for all elements
20//! - **Commutativity**: Subtraction is not commutative (a - b ≠ b - a)
21//! - **Associativity**: Subtraction is not associative ((a - b) - c ≠ a - (b - c))
22//! - **Gradient**: For tensor-tensor: ∂(a-b)/∂a = 1, ∂(a-b)/∂b = -1
23//! - **Broadcasting**: Follows NumPy broadcasting rules for shape compatibility
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Broadcasting Overhead**: Minimal overhead for compatible shapes
31//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
41
42impl Tensor {
43    /// Element-wise subtraction with another tensor with broadcasting support
44    ///
45    /// Performs element-wise subtraction with automatic broadcasting: `output[i] = self[i] - other[i]`
46    ///
47    /// Broadcasting enables subtraction between tensors of different but compatible shapes.
48    /// Compatible shapes follow NumPy broadcasting rules:
49    /// - Dimensions are aligned from the rightmost dimension
50    /// - Dimensions are compatible if they are equal, or one of them is 1
51    /// - Missing dimensions are treated as 1
52    ///
53    /// # Arguments
54    ///
55    /// * `other` - Tensor to subtract. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    ///
59    /// A new tensor containing the element-wise difference with broadcast result shape
60    ///
61    /// # Performance Characteristics
62    ///
63    /// - **Fast Path**: Optimized for identical shapes to avoid broadcasting overhead
64    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
65    /// - **Broadcasting**: Efficient broadcasting for compatible shapes
66    /// - **Cache-friendly**: Linear memory access patterns
67    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
68    ///
69    /// # Implementation Details
70    ///
71    /// Uses a fast path for identical shapes to avoid broadcasting overhead.
72    /// For different shapes, performs broadcasting followed by optimized element-wise subtraction.
73    /// Automatically selects between SIMD and scalar implementations based on hardware capabilities.
74    ///
75    /// # Examples
76    ///
77    /// ## Same Shape Subtraction
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
83    /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84    /// let c = a.sub_tensor(&b);
85    /// assert_eq!(c.shape().dims, vec![3]);
86    /// assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
87    /// assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
88    /// assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
89    /// ```
90    ///
91    /// ## Broadcasting Subtraction
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
97    /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
98    /// let c = a.sub_tensor(&b);
99    /// assert_eq!(c.shape().dims, vec![2, 3]);
100    /// // Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
101    /// assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
102    /// assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
103    /// assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
104    /// ```
105    ///
106    /// ## Scalar Subtraction
107    ///
108    /// ```
109    /// use train_station::Tensor;
110    ///
111    /// let a = Tensor::ones(vec![2, 3]);
112    /// let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
113    /// let c = a.sub_tensor(&b);
114    /// assert_eq!(c.shape().dims, vec![2, 3]);
115    /// assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
116    /// ```
117    ///
118    /// # Panics
119    /// Panics if tensor shapes are not broadcast-compatible
120    #[inline]
121    #[track_caller]
122    pub fn sub_tensor(&self, other: &Tensor) -> Tensor {
123        // Check if shapes are identical for fast path
124        if self.shape().dims == other.shape().dims {
125            return self.sub_tensor_same_shape(other);
126        }
127
128        // Use broadcasting for different shapes
129        let (broadcast_self, broadcast_other, _result_shape) =
130            self.broadcast_with(other).unwrap_or_else(|e| {
131                panic!(
132                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
133                    self.shape().dims,
134                    other.shape().dims,
135                    e
136                );
137            });
138
139        // Perform element-wise subtraction on broadcasted tensors
140        let mut result = broadcast_self.sub_tensor_optimized(&broadcast_other);
141
142        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
143            result.set_requires_grad_internal(true);
144            let grad_fn = GradFn::Sub {
145                is_tensor_sub: true,
146                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
147            };
148            result.set_grad_fn(grad_fn.clone());
149
150            let mut input_ids = Vec::with_capacity(2);
151            if self.requires_grad() {
152                input_ids.push(self.id());
153            }
154            if other.requires_grad() {
155                input_ids.push(other.id());
156            }
157            GradEngine::register_operation(result.id(), input_ids, grad_fn);
158        }
159
160        result
161    }
162
163    /// Element-wise subtraction for tensors with identical shapes (fast path)
164    ///
165    /// This is an optimized path for tensors that already have the same shape,
166    /// avoiding the overhead of broadcasting computation.
167    ///
168    /// # Arguments
169    ///
170    /// * `other` - Tensor to subtract, must have the same shape as self
171    ///
172    /// # Returns
173    ///
174    /// A new tensor containing the element-wise difference
175    ///
176    /// # Performance Characteristics
177    ///
178    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
179    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
180    /// - **Cache-friendly**: Linear memory access patterns
181    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
182    ///
183    /// # Implementation Details
184    ///
185    /// This method is used internally by `sub_tensor()` when tensors have identical shapes.
186    /// It bypasses the broadcasting logic and directly calls the optimized subtraction implementation.
187    #[inline]
188    fn sub_tensor_same_shape(&self, other: &Tensor) -> Tensor {
189        assert_eq!(
190            self.shape(),
191            other.shape(),
192            "Tensor shapes must match for same-shape subtraction"
193        );
194        let mut result = self.sub_tensor_optimized(other);
195
196        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
197            result.set_requires_grad_internal(true);
198            let grad_fn = GradFn::Sub {
199                is_tensor_sub: true,
200                original_shapes: None, // Same shape case
201            };
202            result.set_grad_fn(grad_fn.clone());
203
204            let mut input_ids = Vec::with_capacity(2);
205            if self.requires_grad() {
206                input_ids.push(self.id());
207            }
208            if other.requires_grad() {
209                input_ids.push(other.id());
210            }
211            GradEngine::register_operation(result.id(), input_ids, grad_fn);
212        }
213
214        result
215    }
216
217    /// Element-wise subtraction of a scalar from this tensor
218    ///
219    /// Performs element-wise subtraction of a scalar value: `output[i] = self[i] - scalar`
220    ///
221    /// # Arguments
222    ///
223    /// * `scalar` - The scalar value to subtract from each element
224    ///
225    /// # Returns
226    ///
227    /// A new tensor with the scalar subtracted from each element
228    ///
229    /// # Performance Characteristics
230    ///
231    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
232    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
233    /// - **Cache-friendly**: Linear memory access patterns
234    /// - **Mathematical Accuracy**: High-precision subtraction computation
235    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
236    ///
237    /// # Examples
238    ///
239    /// ## Basic Scalar Subtraction
240    ///
241    /// ```
242    /// use train_station::Tensor;
243    ///
244    /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
245    /// let b = a.sub_scalar(2.0);
246    /// assert_eq!(b.shape().dims, vec![3]);
247    /// assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
248    /// assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
249    /// assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
250    /// ```
251    ///
252    /// ## Negative Scalar Subtraction
253    ///
254    /// ```
255    /// use train_station::Tensor;
256    ///
257    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
258    /// let b = a.sub_scalar(-2.0); // Subtracting negative = adding
259    /// assert_eq!(b.shape().dims, vec![3]);
260    /// assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
261    /// assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
262    /// assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
263    /// ```
264    #[inline]
265    #[track_caller]
266    pub fn sub_scalar(&self, scalar: f32) -> Tensor {
267        let mut result = self.sub_scalar_optimized(scalar);
268
269        if self.requires_grad() && is_grad_enabled() {
270            result.set_requires_grad_internal(true);
271            let grad_fn = GradFn::Sub {
272                is_tensor_sub: false,
273                original_shapes: None, // Scalar case
274            };
275            result.set_grad_fn(grad_fn.clone());
276            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
277        }
278
279        result
280    }
281    /// Optimized tensor subtraction using SIMD when available
282    ///
283    /// Performs element-wise subtraction between tensors with identical shapes using
284    /// SIMD optimization when available and falling back to optimized scalar computation.
285    ///
286    /// # Arguments
287    ///
288    /// * `other` - The tensor to subtract from this tensor
289    ///
290    /// # Returns
291    ///
292    /// A new tensor with the result of the subtraction (self - other)
293    ///
294    /// # Performance Characteristics
295    ///
296    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
297    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
298    /// - **Cache-friendly**: Linear memory access patterns
299    /// - **Mathematical Accuracy**: High-precision subtraction computation
300    /// - **Zero-sized Handling**: Fast return for empty tensors
301    ///
302    /// # Implementation Details
303    ///
304    /// Automatically selects between SIMD and scalar implementations based on hardware
305    /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
306    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
307    /// parallelism.
308    ///
309    /// # Safety
310    ///
311    /// This operation assumes the tensors have the same shape.
312    #[inline]
313    pub(crate) fn sub_tensor_optimized(&self, other: &Tensor) -> Tensor {
314        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
315
316        let mut output = Tensor::new(self.shape().dims.clone());
317
318        unsafe {
319            let a = self.as_ptr();
320            let b = other.as_ptr();
321            let dst = output.as_mut_ptr();
322
323            #[cfg(target_arch = "x86_64")]
324            {
325                // Use SIMD for better performance when available
326                if is_x86_feature_detected!("avx2") {
327                    self.sub_tensors_simd_avx2_optimized(a, b, dst);
328                    return output;
329                }
330            }
331
332            // Fallback to scalar operations with better cache usage
333            self.sub_tensors_scalar_optimized(a, b, dst);
334        }
335
336        output
337    }
338
    /// AVX2-optimized tensor subtraction implementation
    ///
    /// Performs element-wise subtraction using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    /// `dst` must not overlap `a` or `b` (callers pass a freshly allocated output).
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector subtraction operations to compute a - b efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        // SAFETY (bounds): simd_count = size / 32, so offset + 31 < size holds for
        // every load/store in this loop. Unaligned load/store intrinsics are used,
        // so no alignment requirement beyond f32 applies.
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sub_vec1 = _mm256_sub_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        // Fewer than 32 elements remain; each iteration consumes exactly 8 in-bounds elements.
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sub_vec = _mm256_sub_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Handle final elements with unrolled loop
        // At most 7 elements remain here, so unroll_count is 0 or 1.
        let remaining = size - offset;
        let unroll_count = remaining / 4;
        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) - *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
            offset += 4;
        }

        // Scalar tail for the last 0..=3 elements.
        for i in offset..size {
            *dst.add(i) = *a.add(i) - *b.add(i);
        }
    }
427
428    /// Optimized scalar tensor subtraction fallback
429    ///
430    /// Performs element-wise subtraction using optimized scalar operations with
431    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
432    ///
433    /// # Arguments
434    ///
435    /// * `a` - Pointer to first tensor data
436    /// * `b` - Pointer to second tensor data
437    /// * `dst` - Pointer to output tensor data
438    ///
439    /// # Safety
440    ///
441    /// Requires valid pointers with sufficient memory for the tensor size.
442    /// All pointers must point to valid tensor data.
443    ///
444    /// # Performance Characteristics
445    ///
446    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
447    /// - **Memory Access**: Linear access patterns for cache efficiency
448    /// - **Fallback**: Handles remaining elements with scalar operations
449    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
450    /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
451    ///
452    /// # Implementation Details
453    ///
454    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
455    /// Processes elements in groups of 4 to improve instruction-level parallelism
456    /// and reduce loop overhead.
457    #[inline]
458    unsafe fn sub_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
459        let size = self.size();
460        let unroll_count = size / 4;
461        let mut offset = 0;
462
463        // Unrolled scalar loop for better performance
464        for _ in 0..unroll_count {
465            *dst.add(offset) = *a.add(offset) - *b.add(offset);
466            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
467            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
468            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
469            offset += 4;
470        }
471
472        // Handle remaining elements
473        for i in offset..size {
474            *dst.add(i) = *a.add(i) - *b.add(i);
475        }
476    }
477
478    /// Internal optimized scalar subtraction operation
479    ///
480    /// Performs element-wise subtraction of a scalar from tensor using SIMD optimization
481    /// when available and falling back to optimized scalar computation.
482    ///
483    /// # Arguments
484    ///
485    /// * `scalar` - The scalar value to subtract from each element
486    ///
487    /// # Returns
488    ///
489    /// A new tensor with the scalar subtracted from each element
490    ///
491    /// # Performance Characteristics
492    ///
493    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
494    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
495    /// - **Cache-friendly**: Linear memory access patterns
496    /// - **Mathematical Accuracy**: High-precision subtraction computation
497    /// - **Zero-sized Handling**: Fast return for empty tensors
498    ///
499    /// # Implementation Details
500    ///
501    /// Automatically selects between SIMD and scalar implementations based on hardware
502    /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
503    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
504    /// parallelism.
505    #[inline]
506    pub(crate) fn sub_scalar_optimized(&self, scalar: f32) -> Tensor {
507        let mut output = Tensor::new(self.shape().dims.clone());
508
509        unsafe {
510            let src = self.as_ptr();
511            let dst = output.as_mut_ptr();
512
513            #[cfg(target_arch = "x86_64")]
514            {
515                // Use SIMD for better performance when available
516                if is_x86_feature_detected!("avx2") {
517                    self.sub_scalar_simd_avx2_optimized(src, dst, scalar);
518                    return output;
519                }
520            }
521
522            // Fallback to optimized scalar operations
523            self.sub_scalar_fallback_optimized(src, dst, scalar);
524        }
525
526        output
527    }
528
    /// AVX2-optimized scalar subtraction implementation
    ///
    /// Performs element-wise subtraction of a scalar using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - The scalar value to subtract from each element
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    /// `dst` must not overlap `src` (callers pass a freshly allocated output).
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector subtraction operations to compute src - scalar efficiently.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        // Broadcast the scalar into all 8 lanes once, reused by every iteration.
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for better instruction throughput
        // SAFETY (bounds): simd_count = size / 32, so offset + 31 < size holds for
        // every load/store in this loop; unaligned intrinsics impose no alignment
        // requirement beyond f32.
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sub_vec1 = _mm256_sub_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        // Fewer than 32 elements remain; each iteration consumes exactly 8 in-bounds elements.
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sub_vec = _mm256_sub_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Handle final elements
        // Scalar tail for the last 0..=7 elements.
        for i in offset..size {
            *dst.add(i) = *src.add(i) - scalar;
        }
    }
602
603    /// Optimized scalar subtraction fallback
604    ///
605    /// Performs element-wise subtraction of a scalar using optimized scalar operations with
606    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
607    ///
608    /// # Arguments
609    ///
610    /// * `src` - Pointer to source tensor data
611    /// * `dst` - Pointer to output tensor data
612    /// * `scalar` - The scalar value to subtract from each element
613    ///
614    /// # Safety
615    ///
616    /// Requires valid pointers with sufficient memory for the tensor size.
617    /// All pointers must point to valid tensor data.
618    ///
619    /// # Performance Characteristics
620    ///
621    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
622    /// - **Memory Access**: Linear access patterns for cache efficiency
623    /// - **Fallback**: Handles remaining elements with scalar operations
624    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
625    /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
626    ///
627    /// # Implementation Details
628    ///
629    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
630    /// Processes elements in groups of 4 to improve instruction-level parallelism
631    /// and reduce loop overhead.
632    #[inline]
633    unsafe fn sub_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
634        let size = self.size();
635        let unroll_count = size / 4;
636        let mut offset = 0;
637
638        // Unrolled scalar operations
639        for _ in 0..unroll_count {
640            *dst.add(offset) = *src.add(offset) - scalar;
641            *dst.add(offset + 1) = *src.add(offset + 1) - scalar;
642            *dst.add(offset + 2) = *src.add(offset + 2) - scalar;
643            *dst.add(offset + 3) = *src.add(offset + 3) - scalar;
644            offset += 4;
645        }
646
647        for i in offset..size {
648            *dst.add(i) = *src.add(i) - scalar;
649        }
650    }
651}
652
653#[cfg(test)]
654mod tests {
655    use super::*;
656
657    #[test]
658    fn test_tensor_subtraction() {
659        let mut a = Tensor::ones(vec![2, 3]);
660        a.fill(5.0); // Create a tensor with all 5.0s
661        let mut b = Tensor::ones(vec![2, 3]);
662        b.fill(2.0); // Create a tensor with all 2.0s
663        let result = a.sub_tensor_optimized(&b);
664
665        assert_eq!(result.shape().dims, vec![2, 3]);
666        assert_eq!(result.size(), 6);
667
668        // Check that all values are 3.0 (5.0 - 2.0)
669        unsafe {
670            for i in 0..result.size() {
671                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
672            }
673        }
674    }
675
676    #[test]
677    fn test_scalar_subtraction() {
678        let mut tensor = Tensor::ones(vec![2, 2]);
679        tensor.fill(10.0); // Create a tensor with all 10.0s
680        let result = tensor.sub_scalar_optimized(3.0);
681
682        assert_eq!(result.shape().dims, vec![2, 2]);
683        assert_eq!(result.size(), 4);
684
685        // Check that all values are 7.0 (10.0 - 3.0)
686        unsafe {
687            for i in 0..result.size() {
688                assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
689            }
690        }
691    }
692
693    #[test]
694    fn test_negative_subtraction() {
695        let mut a = Tensor::ones(vec![2, 2]);
696        a.fill(2.0); // Create a tensor with all 2.0s
697        let mut b = Tensor::ones(vec![2, 2]);
698        b.fill(5.0); // Create a tensor with all 5.0s
699        let result = a.sub_tensor_optimized(&b);
700
701        // Check that all values are -3.0 (2.0 - 5.0)
702        unsafe {
703            for i in 0..result.size() {
704                assert!((result.as_ptr().add(i).read() - (-3.0)).abs() < 1e-6);
705            }
706        }
707    }
708
709    #[test]
710    fn test_scalar_negative_subtraction() {
711        let mut tensor = Tensor::ones(vec![2, 2]);
712        tensor.fill(3.0); // Create a tensor with all 3.0s
713        let result = tensor.sub_scalar_optimized(8.0);
714
715        // Check that all values are -5.0 (3.0 - 8.0)
716        unsafe {
717            for i in 0..result.size() {
718                assert!((result.as_ptr().add(i).read() - (-5.0)).abs() < 1e-6);
719            }
720        }
721    }
722
723    #[test]
724    #[should_panic(expected = "Tensor shapes must match")]
725    fn test_mismatched_shapes() {
726        let a = Tensor::ones(vec![2, 3]);
727        let b = Tensor::ones(vec![3, 2]);
728        a.sub_tensor_optimized(&b);
729    }
730
731    #[test]
732    fn test_edge_cases() {
733        // Test zero subtraction
734        let a = Tensor::ones(vec![3]);
735        let b = Tensor::zeros(vec![3]);
736        let result = a.sub_tensor_optimized(&b);
737
738        unsafe {
739            for i in 0..result.size() {
740                assert!((result.as_ptr().add(i).read() - 1.0).abs() < 1e-6);
741            }
742        }
743
744        // Test self subtraction
745        let mut tensor = Tensor::ones(vec![3]);
746        tensor.fill(5.0);
747        let result = tensor.sub_tensor_optimized(&tensor);
748
749        unsafe {
750            for i in 0..result.size() {
751                assert!(result.as_ptr().add(i).read().abs() < 1e-6);
752            }
753        }
754    }
755
756    #[test]
757    fn test_large_tensor_subtraction() {
758        let mut a = Tensor::ones(vec![100, 100]);
759        a.fill(10.0);
760        let mut b = Tensor::ones(vec![100, 100]);
761        b.fill(3.0);
762        let result = a.sub_tensor_optimized(&b);
763
764        assert_eq!(result.size(), 10000);
765
766        // Check some values are 7.0 (10.0 - 3.0)
767        unsafe {
768            for i in (0..result.size()).step_by(1000) {
769                assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
770            }
771        }
772    }
773
774    #[test]
775    fn test_negate_inplace() {
776        let mut tensor = Tensor::ones(vec![2, 2]);
777        tensor.fill(5.0);
778
779        // Check initial values
780        unsafe {
781            for i in 0..tensor.size() {
782                let val = tensor.as_ptr().add(i).read();
783                assert!((val - 5.0).abs() < 1e-6, "Expected 5.0, got {}", val);
784            }
785        }
786
787        tensor.negate_inplace();
788
789        // Check negated values
790        unsafe {
791            for i in 0..tensor.size() {
792                let val = tensor.as_ptr().add(i).read();
793                assert!(
794                    (val - (-5.0)).abs() < 1e-6,
795                    "Expected -5.0 after negation, got {}",
796                    val
797                );
798            }
799        }
800    }
801
802    #[test]
803    fn test_subtraction_with_gradtrack() {
804        // Test scalar subtraction with gradtrack
805        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
806        let mut result = a.sub_scalar(5.0);
807
808        // Check result values: 1.0 - 5.0 = -4.0
809        unsafe {
810            for i in 0..result.size() {
811                let val = result.as_ptr().add(i).read();
812                assert!((val - (-4.0)).abs() < 1e-6, "Expected -4.0, got {}", val);
813            }
814        }
815
816        result.backward(None);
817
818        // Check gradient: d/dx(x - c) = 1
819        if let Some(grad) = a.grad_by_value() {
820            unsafe {
821                for i in 0..grad.size() {
822                    let val = grad.as_ptr().add(i).read();
823                    assert!(
824                        (val - 1.0).abs() < 1e-6,
825                        "Expected gradient 1.0, got {}",
826                        val
827                    );
828                }
829            }
830        } else {
831            panic!("No gradient computed for scalar subtraction!");
832        }
833
834        // Test tensor subtraction with gradtrack
835        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
836        let mut b = Tensor::ones(vec![2, 2]);
837        b.fill(3.0);
838        let b = b.with_requires_grad();
839
840        let mut result = a.sub_tensor(&b);
841
842        // Check result values: 1.0 - 3.0 = -2.0
843        unsafe {
844            for i in 0..result.size() {
845                let val = result.as_ptr().add(i).read();
846                assert!((val - (-2.0)).abs() < 1e-6, "Expected -2.0, got {}", val);
847            }
848        }
849
850        result.backward(None);
851
852        // Check gradients: d/dx(x - y) = 1, d/dy(x - y) = -1
853        if let Some(grad_a) = a.grad_by_value() {
854            unsafe {
855                for i in 0..grad_a.size() {
856                    let val = grad_a.as_ptr().add(i).read();
857                    assert!(
858                        (val - 1.0).abs() < 1e-6,
859                        "Expected gradient A = 1.0, got {}",
860                        val
861                    );
862                }
863            }
864        } else {
865            panic!("No gradient A computed for tensor subtraction!");
866        }
867
868        if let Some(grad_b) = b.grad_by_value() {
869            unsafe {
870                for i in 0..grad_b.size() {
871                    let val = grad_b.as_ptr().add(i).read();
872                    println!("Debug: grad_b[{}] = {}", i, val);
873                    assert!(
874                        (val - (-1.0)).abs() < 1e-6,
875                        "Expected gradient B = -1.0, got {}",
876                        val
877                    );
878                }
879            }
880        } else {
881            panic!("No gradient B computed for tensor subtraction!");
882        }
883    }
884
885    #[test]
886    fn test_mixed_add_sub_operations_with_gradtrack() {
887        // Test complex computation graph: (a + scalar1) - b + c - scalar2
888        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
889        let mut b = Tensor::ones(vec![2, 2]);
890        b.fill(2.0);
891        let b = b.with_requires_grad();
892        let mut c = Tensor::ones(vec![2, 2]);
893        c.fill(3.0);
894        let c = c.with_requires_grad();
895
896        let scalar1 = 5.0;
897        let scalar2 = 1.0;
898
899        // Complex computation: (a + scalar1) - b + c - scalar2
900        // Expected: (1 + 5) - 2 + 3 - 1 = 6
901        let step1 = a.add_scalar(scalar1); // a + 5 = 6
902        let step2 = step1.sub_tensor(&b); // 6 - 2 = 4
903        let step3 = step2.add_tensor(&c); // 4 + 3 = 7
904        let mut result = step3.sub_scalar(scalar2); // 7 - 1 = 6
905
906        // Check result values
907        unsafe {
908            for i in 0..result.size() {
909                let val = result.as_ptr().add(i).read();
910                assert!((val - 6.0).abs() < 1e-6, "Expected 6.0, got {}", val);
911            }
912        }
913
914        result.backward(None);
915
916        // Check gradients
917        // For computation: f = (a + 5) - b + c - 1
918        // df/da = 1, df/db = -1, df/dc = 1
919
920        if let Some(grad_a) = a.grad_by_value() {
921            unsafe {
922                for i in 0..grad_a.size() {
923                    let val = grad_a.as_ptr().add(i).read();
924                    assert!(
925                        (val - 1.0).abs() < 1e-6,
926                        "Expected gradient A = 1.0, got {}",
927                        val
928                    );
929                }
930            }
931        } else {
932            panic!("No gradient A computed for mixed operations!");
933        }
934
935        if let Some(grad_b) = b.grad_by_value() {
936            unsafe {
937                for i in 0..grad_b.size() {
938                    let val = grad_b.as_ptr().add(i).read();
939                    assert!(
940                        (val - (-1.0)).abs() < 1e-6,
941                        "Expected gradient B = -1.0, got {}",
942                        val
943                    );
944                }
945            }
946        } else {
947            panic!("No gradient B computed for mixed operations!");
948        }
949
950        if let Some(grad_c) = c.grad_by_value() {
951            unsafe {
952                for i in 0..grad_c.size() {
953                    let val = grad_c.as_ptr().add(i).read();
954                    assert!(
955                        (val - 1.0).abs() < 1e-6,
956                        "Expected gradient C = 1.0, got {}",
957                        val
958                    );
959                }
960            }
961        } else {
962            panic!("No gradient C computed for mixed operations!");
963        }
964
965        println!("Mixed add/sub operations with gradtrack test passed!");
966        println!("✓ Complex computation graph: (a + 5) - b + c - 1 = 6");
967        println!("✓ Gradients: da/df = 1, db/df = -1, dc/df = 1");
968    }
969
970    #[test]
971    fn test_sub_broadcasting_gradients_basic() {
972        use crate::gradtrack::clear_gradients;
973        clear_gradients();
974
975        // Test case: [2, 3] - [1, 3] -> [2, 3]
976        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)
977        // For subtraction: d/da (a - b) = 1, d/db (a - b) = -1
978
979        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
980            .unwrap()
981            .with_requires_grad();
982        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
983            .unwrap()
984            .with_requires_grad();
985
986        let mut result = a.sub_tensor(&b);
987        assert_eq!(result.shape().dims, vec![2, 3]);
988
989        // Set upstream gradient as ones
990        result.backward(None);
991
992        let grad_a = a.grad_by_value().expect("grad_a should exist");
993        let grad_b = b.grad_by_value().expect("grad_b should exist");
994
995        println!(
996            "Original shapes: a={:?}, b={:?}",
997            a.shape().dims,
998            b.shape().dims
999        );
1000        println!(
1001            "Gradient shapes: grad_a={:?}, grad_b={:?}",
1002            grad_a.shape().dims,
1003            grad_b.shape().dims
1004        );
1005
1006        // grad_a should have same shape as a: [2, 3]
1007        assert_eq!(
1008            grad_a.shape().dims,
1009            vec![2, 3],
1010            "grad_a should match original shape of a"
1011        );
1012
1013        // grad_b should have same shape as b: [1, 3]
1014        // This requires summing over the broadcasted dimension
1015        assert_eq!(
1016            grad_b.shape().dims,
1017            vec![1, 3],
1018            "grad_b should match original shape of b"
1019        );
1020
1021        // All gradients should be 1.0 for grad_a (d/da (a - b) = 1)
1022        for i in 0..grad_a.size() {
1023            let val = unsafe { *grad_a.as_ptr().add(i) };
1024            assert!(
1025                (val - 1.0).abs() < 1e-6,
1026                "grad_a[{}] = {} should be 1.0",
1027                i,
1028                val
1029            );
1030        }
1031
1032        // grad_b should be [-2.0, -2.0, -2.0] (sum over broadcast dim, then negated)
1033        let expected_grad_b = [-2.0, -2.0, -2.0]; // -1 * 2 rows = -2
1034        for (i, &expected) in expected_grad_b.iter().enumerate() {
1035            let val = unsafe { *grad_b.as_ptr().add(i) };
1036            assert!(
1037                (val - expected).abs() < 1e-6,
1038                "grad_b[{}] = {} should be {}",
1039                i,
1040                val,
1041                expected
1042            );
1043        }
1044    }
1045
1046    #[test]
1047    fn test_sub_scalar_broadcasting_gradients() {
1048        use crate::gradtrack::clear_gradients;
1049        clear_gradients();
1050
1051        // Test case: [2, 3] - [1] -> [2, 3]
1052        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims, then negated)
1053
1054        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1055            .unwrap()
1056            .with_requires_grad();
1057        let b = Tensor::from_slice(&[0.5], vec![1])
1058            .unwrap()
1059            .with_requires_grad();
1060
1061        let mut result = a.sub_tensor(&b);
1062        result.backward(None);
1063
1064        let grad_a = a.grad_by_value().expect("grad_a should exist");
1065        let grad_b = b.grad_by_value().expect("grad_b should exist");
1066
1067        // grad_a should have same shape as a: [2, 3]
1068        assert_eq!(grad_a.shape().dims, vec![2, 3]);
1069
1070        // grad_b should have same shape as b: [1] and sum to -6.0
1071        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
1072        assert_eq!(grad_b.shape().dims, vec![1]);
1073
1074        // grad_b should be -6.0 (sum over all 6 elements, then negated)
1075        let val = unsafe { *grad_b.as_ptr() };
1076        assert!(
1077            (val - (-6.0)).abs() < 1e-6,
1078            "grad_b = {} should be -6.0",
1079            val
1080        );
1081    }
1082}