train_station/tensor/ops/
sub.rs

1//! Subtraction operations for tensors
2//!
3//! Provides element-wise subtraction operations following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Tensor Subtraction**: `sub_tensor()` - Element-wise subtraction with broadcasting support
9//! - **Scalar Subtraction**: `sub_scalar()` - Subtraction of scalar from tensor
10//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The subtraction operations have the following properties:
18//! - **Tensor-Tensor**: `output[i] = a[i] - b[i]` with broadcasting
19//! - **Tensor-Scalar**: `output[i] = a[i] - scalar` for all elements
20//! - **Commutativity**: Subtraction is not commutative (a - b ≠ b - a)
21//! - **Associativity**: Subtraction is not associative ((a - b) - c ≠ a - (b - c))
22//! - **Gradient**: For tensor-tensor: ∂(a-b)/∂a = 1, ∂(a-b)/∂b = -1
23//! - **Broadcasting**: Follows NumPy broadcasting rules for shape compatibility
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Broadcasting Overhead**: Minimal overhead for compatible shapes
31//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
41
42impl Tensor {
43    /// Element-wise subtraction with another tensor with broadcasting support
44    ///
45    /// Performs element-wise subtraction with automatic broadcasting: `output[i] = self[i] - other[i]`
46    ///
47    /// Broadcasting enables subtraction between tensors of different but compatible shapes.
48    /// Compatible shapes follow NumPy broadcasting rules:
49    /// - Dimensions are aligned from the rightmost dimension
50    /// - Dimensions are compatible if they are equal, or one of them is 1
51    /// - Missing dimensions are treated as 1
52    ///
53    /// # Arguments
54    ///
55    /// * `other` - Tensor to subtract. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    ///
59    /// A new tensor containing the element-wise difference with broadcast result shape
60    ///
61    /// # Performance Characteristics
62    ///
63    /// - **Fast Path**: Optimized for identical shapes to avoid broadcasting overhead
64    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
65    /// - **Broadcasting**: Efficient broadcasting for compatible shapes
66    /// - **Cache-friendly**: Linear memory access patterns
67    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
68    ///
69    /// # Implementation Details
70    ///
71    /// Uses a fast path for identical shapes to avoid broadcasting overhead.
72    /// For different shapes, performs broadcasting followed by optimized element-wise subtraction.
73    /// Automatically selects between SIMD and scalar implementations based on hardware capabilities.
74    ///
75    /// # Examples
76    ///
77    /// ## Same Shape Subtraction
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
83    /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84    /// let c = a.sub_tensor(&b);
85    /// assert_eq!(c.shape().dims(), vec![3]);
86    /// assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
87    /// assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
88    /// assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
89    /// ```
90    ///
91    /// ## Broadcasting Subtraction
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
97    /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
98    /// let c = a.sub_tensor(&b);
99    /// assert_eq!(c.shape().dims(), vec![2, 3]);
100    /// // Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
101    /// assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
102    /// assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
103    /// assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
104    /// ```
105    ///
106    /// ## Scalar Subtraction
107    ///
108    /// ```
109    /// use train_station::Tensor;
110    ///
111    /// let a = Tensor::ones(vec![2, 3]);
112    /// let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
113    /// let c = a.sub_tensor(&b);
114    /// assert_eq!(c.shape().dims(), vec![2, 3]);
115    /// assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
116    /// ```
117    ///
118    /// # Panics
119    /// Panics if tensor shapes are not broadcast-compatible
120    #[inline]
121    #[track_caller]
122    pub fn sub_tensor(&self, other: &Tensor) -> Tensor {
123        // Check if shapes are identical for fast path
124        if self.shape().dims() == other.shape().dims() {
125            return self.sub_tensor_same_shape(other);
126        }
127
128        // Zero-copy broadcast views, then reuse same-shape optimized path
129        use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
130        let mut result = match broadcast_shapes_cow(self, other) {
131            Ok((a_b, b_b, _)) => a_b.as_ref().sub_tensor_optimized(b_b.as_ref()),
132            Err(BroadcastError::IncompatibleShapes { .. }) => {
133                panic!(
134                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
135                    self.shape().dims(),
136                    other.shape().dims(),
137                    "incompatible shapes"
138                );
139            }
140            Err(BroadcastError::AllocationFailed) => {
141                panic!("Memory allocation failed during broadcasting");
142            }
143        };
144
145        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
146            result.set_requires_grad_internal(true);
147            let grad_fn = GradFn::Sub {
148                is_tensor_sub: true,
149                original_shapes: Some((
150                    self.shape().dims().to_vec(),
151                    other.shape().dims().to_vec(),
152                )),
153            };
154            result.set_grad_fn(grad_fn.clone());
155
156            let mut input_ids = Vec::with_capacity(2);
157            if self.requires_grad() {
158                input_ids.push(self.id());
159            }
160            if other.requires_grad() {
161                input_ids.push(other.id());
162            }
163            GradEngine::register_operation(result.id(), input_ids, grad_fn);
164        }
165
166        result
167    }
168
169    /// Element-wise subtraction for tensors with identical shapes (fast path)
170    ///
171    /// This is an optimized path for tensors that already have the same shape,
172    /// avoiding the overhead of broadcasting computation.
173    ///
174    /// # Arguments
175    ///
176    /// * `other` - Tensor to subtract, must have the same shape as self
177    ///
178    /// # Returns
179    ///
180    /// A new tensor containing the element-wise difference
181    ///
182    /// # Performance Characteristics
183    ///
184    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
185    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
186    /// - **Cache-friendly**: Linear memory access patterns
187    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
188    ///
189    /// # Implementation Details
190    ///
191    /// This method is used internally by `sub_tensor()` when tensors have identical shapes.
192    /// It bypasses the broadcasting logic and directly calls the optimized subtraction implementation.
193    #[inline]
194    fn sub_tensor_same_shape(&self, other: &Tensor) -> Tensor {
195        assert_eq!(
196            self.shape(),
197            other.shape(),
198            "Tensor shapes must match for same-shape subtraction"
199        );
200        let mut result = self.sub_tensor_optimized(other);
201
202        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
203            result.set_requires_grad_internal(true);
204            let grad_fn = GradFn::Sub {
205                is_tensor_sub: true,
206                original_shapes: None, // Same shape case
207            };
208            result.set_grad_fn(grad_fn.clone());
209
210            let mut input_ids = Vec::with_capacity(2);
211            if self.requires_grad() {
212                input_ids.push(self.id());
213            }
214            if other.requires_grad() {
215                input_ids.push(other.id());
216            }
217            GradEngine::register_operation(result.id(), input_ids, grad_fn);
218        }
219
220        result
221    }
222
223    /// Element-wise subtraction of a scalar from this tensor
224    ///
225    /// Performs element-wise subtraction of a scalar value: `output[i] = self[i] - scalar`
226    ///
227    /// # Arguments
228    ///
229    /// * `scalar` - The scalar value to subtract from each element
230    ///
231    /// # Returns
232    ///
233    /// A new tensor with the scalar subtracted from each element
234    ///
235    /// # Performance Characteristics
236    ///
237    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
238    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
239    /// - **Cache-friendly**: Linear memory access patterns
240    /// - **Mathematical Accuracy**: High-precision subtraction computation
241    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
242    ///
243    /// # Examples
244    ///
245    /// ## Basic Scalar Subtraction
246    ///
247    /// ```
248    /// use train_station::Tensor;
249    ///
250    /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
251    /// let b = a.sub_scalar(2.0);
252    /// assert_eq!(b.shape().dims(), vec![3]);
253    /// assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
254    /// assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
255    /// assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
256    /// ```
257    ///
258    /// ## Negative Scalar Subtraction
259    ///
260    /// ```
261    /// use train_station::Tensor;
262    ///
263    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
264    /// let b = a.sub_scalar(-2.0); // Subtracting negative = adding
265    /// assert_eq!(b.shape().dims(), vec![3]);
266    /// assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
267    /// assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
268    /// assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
269    /// ```
270    #[inline]
271    #[track_caller]
272    pub fn sub_scalar(&self, scalar: f32) -> Tensor {
273        let mut result = self.sub_scalar_optimized(scalar);
274
275        if self.requires_grad() && is_grad_enabled() {
276            result.set_requires_grad_internal(true);
277            let grad_fn = GradFn::Sub {
278                is_tensor_sub: false,
279                original_shapes: None, // Scalar case
280            };
281            result.set_grad_fn(grad_fn.clone());
282            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
283        }
284
285        result
286    }
287    /// Optimized tensor subtraction using SIMD when available
288    ///
289    /// Performs element-wise subtraction between tensors with identical shapes using
290    /// SIMD optimization when available and falling back to optimized scalar computation.
291    ///
292    /// # Arguments
293    ///
294    /// * `other` - The tensor to subtract from this tensor
295    ///
296    /// # Returns
297    ///
298    /// A new tensor with the result of the subtraction (self - other)
299    ///
300    /// # Performance Characteristics
301    ///
302    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
303    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
304    /// - **Cache-friendly**: Linear memory access patterns
305    /// - **Mathematical Accuracy**: High-precision subtraction computation
306    /// - **Zero-sized Handling**: Fast return for empty tensors
307    ///
308    /// # Implementation Details
309    ///
310    /// Automatically selects between SIMD and scalar implementations based on hardware
311    /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
312    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
313    /// parallelism.
314    ///
315    /// # Safety
316    ///
317    /// This operation assumes the tensors have the same shape.
318    #[inline]
319    pub(crate) fn sub_tensor_optimized(&self, other: &Tensor) -> Tensor {
320        assert_eq!(
321            self.shape().dims(),
322            other.shape().dims(),
323            "Tensor dims must match"
324        );
325
326        // Ensure contiguous sources for correctness with broadcast views/strides
327        let a_src = if self.is_contiguous() {
328            self.clone()
329        } else {
330            self.contiguous()
331        };
332        let b_src = if other.is_contiguous() {
333            other.clone()
334        } else {
335            other.contiguous()
336        };
337
338        let mut output = Tensor::new(self.shape().dims().to_vec());
339
340        unsafe {
341            let a = a_src.as_ptr();
342            let b = b_src.as_ptr();
343            let dst = output.as_mut_ptr();
344
345            #[cfg(target_arch = "x86_64")]
346            {
347                // Use SIMD for better performance when available
348                if is_x86_feature_detected!("avx2") {
349                    self.sub_tensors_simd_avx2_optimized(a, b, dst);
350                    return output;
351                }
352            }
353
354            // Fallback to scalar operations with better cache usage
355            self.sub_tensors_scalar_optimized(a, b, dst);
356        }
357
358        output
359    }
360
361    /// AVX2-optimized tensor subtraction implementation
362    ///
363    /// Performs element-wise subtraction using AVX2 SIMD instructions for maximum
364    /// performance on x86_64 architectures with AVX2 support.
365    ///
366    /// # Arguments
367    ///
368    /// * `a` - Pointer to first tensor data
369    /// * `b` - Pointer to second tensor data
370    /// * `dst` - Pointer to output tensor data
371    ///
372    /// # Safety
373    ///
374    /// Requires valid pointers with sufficient memory for the tensor size.
375    /// All pointers must point to valid tensor data. Requires AVX2 support.
376    ///
377    /// # Performance Characteristics
378    ///
379    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
380    /// - **Memory Access**: Linear access patterns for cache efficiency
381    /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
382    /// - **Fallback**: Handles remaining elements with scalar operations
383    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
384    ///
385    /// # Implementation Details
386    ///
387    /// Uses AVX2 vector subtraction operations to compute a - b efficiently.
388    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
389    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        // SAFETY: caller guarantees `a`, `b`, and `dst` are each valid for
        // `self.size()` f32 reads/writes, and that AVX2 availability was
        // checked at runtime before dispatching here.
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sub_vec1 = _mm256_sub_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks (one AVX2 vector at a time)
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sub_vec = _mm256_sub_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Handle final elements with unrolled loop (groups of 4 scalars)
        let remaining = size - offset;
        let unroll_count = remaining / 4;
        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) - *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
            offset += 4;
        }

        // At most 3 trailing elements remain after the unrolled loop
        for i in offset..size {
            *dst.add(i) = *a.add(i) - *b.add(i);
        }
    }
449
450    /// Optimized scalar tensor subtraction fallback
451    ///
452    /// Performs element-wise subtraction using optimized scalar operations with
453    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
454    ///
455    /// # Arguments
456    ///
457    /// * `a` - Pointer to first tensor data
458    /// * `b` - Pointer to second tensor data
459    /// * `dst` - Pointer to output tensor data
460    ///
461    /// # Safety
462    ///
463    /// Requires valid pointers with sufficient memory for the tensor size.
464    /// All pointers must point to valid tensor data.
465    ///
466    /// # Performance Characteristics
467    ///
468    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
469    /// - **Memory Access**: Linear access patterns for cache efficiency
470    /// - **Fallback**: Handles remaining elements with scalar operations
471    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
472    /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
473    ///
474    /// # Implementation Details
475    ///
476    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
477    /// Processes elements in groups of 4 to improve instruction-level parallelism
478    /// and reduce loop overhead.
479    #[inline]
480    unsafe fn sub_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
481        let size = self.size();
482        let unroll_count = size / 4;
483        let mut offset = 0;
484
485        // Unrolled scalar loop for better performance
486        for _ in 0..unroll_count {
487            *dst.add(offset) = *a.add(offset) - *b.add(offset);
488            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
489            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
490            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
491            offset += 4;
492        }
493
494        // Handle remaining elements
495        for i in offset..size {
496            *dst.add(i) = *a.add(i) - *b.add(i);
497        }
498    }
499
500    /// Internal optimized scalar subtraction operation
501    ///
502    /// Performs element-wise subtraction of a scalar from tensor using SIMD optimization
503    /// when available and falling back to optimized scalar computation.
504    ///
505    /// # Arguments
506    ///
507    /// * `scalar` - The scalar value to subtract from each element
508    ///
509    /// # Returns
510    ///
511    /// A new tensor with the scalar subtracted from each element
512    ///
513    /// # Performance Characteristics
514    ///
515    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
516    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
517    /// - **Cache-friendly**: Linear memory access patterns
518    /// - **Mathematical Accuracy**: High-precision subtraction computation
519    /// - **Zero-sized Handling**: Fast return for empty tensors
520    ///
521    /// # Implementation Details
522    ///
523    /// Automatically selects between SIMD and scalar implementations based on hardware
524    /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
525    /// performance. Scalar implementation uses 4x unrolling for better instruction-level
526    /// parallelism.
527    #[inline]
528    pub(crate) fn sub_scalar_optimized(&self, scalar: f32) -> Tensor {
529        let mut output = Tensor::new(self.shape().dims().to_vec());
530
531        unsafe {
532            let src = self.as_ptr();
533            let dst = output.as_mut_ptr();
534
535            #[cfg(target_arch = "x86_64")]
536            {
537                // Use SIMD for better performance when available
538                if is_x86_feature_detected!("avx2") {
539                    self.sub_scalar_simd_avx2_optimized(src, dst, scalar);
540                    return output;
541                }
542            }
543
544            // Fallback to optimized scalar operations
545            self.sub_scalar_fallback_optimized(src, dst, scalar);
546        }
547
548        output
549    }
550
551    /// AVX2-optimized scalar subtraction implementation
552    ///
553    /// Performs element-wise subtraction of a scalar using AVX2 SIMD instructions for maximum
554    /// performance on x86_64 architectures with AVX2 support.
555    ///
556    /// # Arguments
557    ///
558    /// * `src` - Pointer to source tensor data
559    /// * `dst` - Pointer to output tensor data
560    /// * `scalar` - The scalar value to subtract from each element
561    ///
562    /// # Safety
563    ///
564    /// Requires valid pointers with sufficient memory for the tensor size.
565    /// All pointers must point to valid tensor data. Requires AVX2 support.
566    ///
567    /// # Performance Characteristics
568    ///
569    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
570    /// - **Memory Access**: Linear access patterns for cache efficiency
571    /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
572    /// - **Fallback**: Handles remaining elements with scalar operations
573    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
574    ///
575    /// # Implementation Details
576    ///
577    /// Uses AVX2 vector subtraction operations to compute src - scalar efficiently.
578    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
579    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        // SAFETY: caller guarantees `src` and `dst` are valid for `self.size()`
        // f32 reads/writes and that AVX2 support was detected at runtime.
        // Broadcast the scalar into all 8 lanes once, outside the loops.
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for better instruction throughput
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sub_vec1 = _mm256_sub_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks (one AVX2 vector at a time)
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sub_vec = _mm256_sub_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Handle final elements (at most 7) with plain scalar subtraction
        for i in offset..size {
            *dst.add(i) = *src.add(i) - scalar;
        }
    }
624
625    /// Optimized scalar subtraction fallback
626    ///
627    /// Performs element-wise subtraction of a scalar using optimized scalar operations with
628    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
629    ///
630    /// # Arguments
631    ///
632    /// * `src` - Pointer to source tensor data
633    /// * `dst` - Pointer to output tensor data
634    /// * `scalar` - The scalar value to subtract from each element
635    ///
636    /// # Safety
637    ///
638    /// Requires valid pointers with sufficient memory for the tensor size.
639    /// All pointers must point to valid tensor data.
640    ///
641    /// # Performance Characteristics
642    ///
643    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
644    /// - **Memory Access**: Linear access patterns for cache efficiency
645    /// - **Fallback**: Handles remaining elements with scalar operations
646    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
647    /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
648    ///
649    /// # Implementation Details
650    ///
651    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
652    /// Processes elements in groups of 4 to improve instruction-level parallelism
653    /// and reduce loop overhead.
654    #[inline]
655    unsafe fn sub_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
656        let size = self.size();
657        let unroll_count = size / 4;
658        let mut offset = 0;
659
660        // Unrolled scalar operations
661        for _ in 0..unroll_count {
662            *dst.add(offset) = *src.add(offset) - scalar;
663            *dst.add(offset + 1) = *src.add(offset + 1) - scalar;
664            *dst.add(offset + 2) = *src.add(offset + 2) - scalar;
665            *dst.add(offset + 3) = *src.add(offset + 3) - scalar;
666            offset += 4;
667        }
668
669        for i in offset..size {
670            *dst.add(i) = *src.add(i) - scalar;
671        }
672    }
673}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678
679    #[test]
680    fn test_tensor_subtraction() {
681        let mut a = Tensor::ones(vec![2, 3]);
682        a.fill(5.0); // Create a tensor with all 5.0s
683        let mut b = Tensor::ones(vec![2, 3]);
684        b.fill(2.0); // Create a tensor with all 2.0s
685        let result = a.sub_tensor_optimized(&b);
686
687        assert_eq!(result.shape().dims(), vec![2, 3]);
688        assert_eq!(result.size(), 6);
689
690        // Check that all values are 3.0 (5.0 - 2.0)
691        unsafe {
692            for i in 0..result.size() {
693                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
694            }
695        }
696    }
697
698    #[test]
699    fn test_scalar_subtraction() {
700        let mut tensor = Tensor::ones(vec![2, 2]);
701        tensor.fill(10.0); // Create a tensor with all 10.0s
702        let result = tensor.sub_scalar_optimized(3.0);
703
704        assert_eq!(result.shape().dims(), vec![2, 2]);
705        assert_eq!(result.size(), 4);
706
707        // Check that all values are 7.0 (10.0 - 3.0)
708        unsafe {
709            for i in 0..result.size() {
710                assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
711            }
712        }
713    }
714
715    #[test]
716    fn test_negative_subtraction() {
717        let mut a = Tensor::ones(vec![2, 2]);
718        a.fill(2.0); // Create a tensor with all 2.0s
719        let mut b = Tensor::ones(vec![2, 2]);
720        b.fill(5.0); // Create a tensor with all 5.0s
721        let result = a.sub_tensor_optimized(&b);
722
723        // Check that all values are -3.0 (2.0 - 5.0)
724        unsafe {
725            for i in 0..result.size() {
726                assert!((result.as_ptr().add(i).read() - (-3.0)).abs() < 1e-6);
727            }
728        }
729    }
730
731    #[test]
732    fn test_scalar_negative_subtraction() {
733        let mut tensor = Tensor::ones(vec![2, 2]);
734        tensor.fill(3.0); // Create a tensor with all 3.0s
735        let result = tensor.sub_scalar_optimized(8.0);
736
737        // Check that all values are -5.0 (3.0 - 8.0)
738        unsafe {
739            for i in 0..result.size() {
740                assert!((result.as_ptr().add(i).read() - (-5.0)).abs() < 1e-6);
741            }
742        }
743    }
744
745    #[test]
746    #[should_panic(expected = "must match")]
747    fn test_mismatched_shapes() {
748        let a = Tensor::ones(vec![2, 3]);
749        let b = Tensor::ones(vec![3, 2]);
750        a.sub_tensor_optimized(&b);
751    }
752
753    #[test]
754    fn test_edge_cases() {
755        // Test zero subtraction
756        let a = Tensor::ones(vec![3]);
757        let b = Tensor::zeros(vec![3]);
758        let result = a.sub_tensor_optimized(&b);
759
760        unsafe {
761            for i in 0..result.size() {
762                assert!((result.as_ptr().add(i).read() - 1.0).abs() < 1e-6);
763            }
764        }
765
766        // Test self subtraction
767        let mut tensor = Tensor::ones(vec![3]);
768        tensor.fill(5.0);
769        let result = tensor.sub_tensor_optimized(&tensor);
770
771        unsafe {
772            for i in 0..result.size() {
773                assert!(result.as_ptr().add(i).read().abs() < 1e-6);
774            }
775        }
776    }
777
778    #[test]
779    fn test_large_tensor_subtraction() {
780        let mut a = Tensor::ones(vec![100, 100]);
781        a.fill(10.0);
782        let mut b = Tensor::ones(vec![100, 100]);
783        b.fill(3.0);
784        let result = a.sub_tensor_optimized(&b);
785
786        assert_eq!(result.size(), 10000);
787
788        // Check some values are 7.0 (10.0 - 3.0)
789        unsafe {
790            for i in (0..result.size()).step_by(1000) {
791                assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
792            }
793        }
794    }
795
796    #[test]
797    fn test_negate_inplace() {
798        let mut tensor = Tensor::ones(vec![2, 2]);
799        tensor.fill(5.0);
800
801        // Check initial values
802        unsafe {
803            for i in 0..tensor.size() {
804                let val = tensor.as_ptr().add(i).read();
805                assert!((val - 5.0).abs() < 1e-6, "Expected 5.0, got {}", val);
806            }
807        }
808
809        tensor.negate_inplace();
810
811        // Check negated values
812        unsafe {
813            for i in 0..tensor.size() {
814                let val = tensor.as_ptr().add(i).read();
815                assert!(
816                    (val - (-5.0)).abs() < 1e-6,
817                    "Expected -5.0 after negation, got {}",
818                    val
819                );
820            }
821        }
822    }
823
824    #[test]
825    fn test_subtraction_with_gradtrack() {
826        // Test scalar subtraction with gradtrack
827        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
828        let mut result = a.sub_scalar(5.0);
829
830        // Check result values: 1.0 - 5.0 = -4.0
831        unsafe {
832            for i in 0..result.size() {
833                let val = result.as_ptr().add(i).read();
834                assert!((val - (-4.0)).abs() < 1e-6, "Expected -4.0, got {}", val);
835            }
836        }
837
838        result.backward(None);
839
840        // Check gradient: d/dx(x - c) = 1
841        if let Some(grad) = a.grad_owned() {
842            unsafe {
843                for i in 0..grad.size() {
844                    let val = grad.as_ptr().add(i).read();
845                    assert!(
846                        (val - 1.0).abs() < 1e-6,
847                        "Expected gradient 1.0, got {}",
848                        val
849                    );
850                }
851            }
852        } else {
853            panic!("No gradient computed for scalar subtraction!");
854        }
855
856        // Test tensor subtraction with gradtrack
857        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
858        let mut b = Tensor::ones(vec![2, 2]);
859        b.fill(3.0);
860        let b = b.with_requires_grad();
861
862        let mut result = a.sub_tensor(&b);
863
864        // Check result values: 1.0 - 3.0 = -2.0
865        unsafe {
866            for i in 0..result.size() {
867                let val = result.as_ptr().add(i).read();
868                assert!((val - (-2.0)).abs() < 1e-6, "Expected -2.0, got {}", val);
869            }
870        }
871
872        result.backward(None);
873
874        // Check gradients: d/dx(x - y) = 1, d/dy(x - y) = -1
875        if let Some(grad_a) = a.grad_owned() {
876            unsafe {
877                for i in 0..grad_a.size() {
878                    let val = grad_a.as_ptr().add(i).read();
879                    assert!(
880                        (val - 1.0).abs() < 1e-6,
881                        "Expected gradient A = 1.0, got {}",
882                        val
883                    );
884                }
885            }
886        } else {
887            panic!("No gradient A computed for tensor subtraction!");
888        }
889
890        if let Some(grad_b) = b.grad_owned() {
891            unsafe {
892                for i in 0..grad_b.size() {
893                    let val = grad_b.as_ptr().add(i).read();
894                    println!("Debug: grad_b[{}] = {}", i, val);
895                    assert!(
896                        (val - (-1.0)).abs() < 1e-6,
897                        "Expected gradient B = -1.0, got {}",
898                        val
899                    );
900                }
901            }
902        } else {
903            panic!("No gradient B computed for tensor subtraction!");
904        }
905    }
906
907    #[test]
908    fn test_mixed_add_sub_operations_with_gradtrack() {
909        // Test complex computation graph: (a + scalar1) - b + c - scalar2
910        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
911        let mut b = Tensor::ones(vec![2, 2]);
912        b.fill(2.0);
913        let b = b.with_requires_grad();
914        let mut c = Tensor::ones(vec![2, 2]);
915        c.fill(3.0);
916        let c = c.with_requires_grad();
917
918        let scalar1 = 5.0;
919        let scalar2 = 1.0;
920
921        // Complex computation: (a + scalar1) - b + c - scalar2
922        // Expected: (1 + 5) - 2 + 3 - 1 = 6
923        let step1 = a.add_scalar(scalar1); // a + 5 = 6
924        let step2 = step1.sub_tensor(&b); // 6 - 2 = 4
925        let step3 = step2.add_tensor(&c); // 4 + 3 = 7
926        let mut result = step3.sub_scalar(scalar2); // 7 - 1 = 6
927
928        // Check result values
929        unsafe {
930            for i in 0..result.size() {
931                let val = result.as_ptr().add(i).read();
932                assert!((val - 6.0).abs() < 1e-6, "Expected 6.0, got {}", val);
933            }
934        }
935
936        result.backward(None);
937
938        // Check gradients
939        // For computation: f = (a + 5) - b + c - 1
940        // df/da = 1, df/db = -1, df/dc = 1
941
942        if let Some(grad_a) = a.grad_owned() {
943            unsafe {
944                for i in 0..grad_a.size() {
945                    let val = grad_a.as_ptr().add(i).read();
946                    assert!(
947                        (val - 1.0).abs() < 1e-6,
948                        "Expected gradient A = 1.0, got {}",
949                        val
950                    );
951                }
952            }
953        } else {
954            panic!("No gradient A computed for mixed operations!");
955        }
956
957        if let Some(grad_b) = b.grad_owned() {
958            unsafe {
959                for i in 0..grad_b.size() {
960                    let val = grad_b.as_ptr().add(i).read();
961                    assert!(
962                        (val - (-1.0)).abs() < 1e-6,
963                        "Expected gradient B = -1.0, got {}",
964                        val
965                    );
966                }
967            }
968        } else {
969            panic!("No gradient B computed for mixed operations!");
970        }
971
972        if let Some(grad_c) = c.grad_owned() {
973            unsafe {
974                for i in 0..grad_c.size() {
975                    let val = grad_c.as_ptr().add(i).read();
976                    assert!(
977                        (val - 1.0).abs() < 1e-6,
978                        "Expected gradient C = 1.0, got {}",
979                        val
980                    );
981                }
982            }
983        } else {
984            panic!("No gradient C computed for mixed operations!");
985        }
986
987        println!("Mixed add/sub operations with gradtrack test passed!");
988        println!("✓ Complex computation graph: (a + 5) - b + c - 1 = 6");
989        println!("✓ Gradients: da/df = 1, db/df = -1, dc/df = 1");
990    }
991
992    #[test]
993    fn test_sub_broadcasting_gradients_basic() {
994        use crate::gradtrack::clear_gradients;
995        clear_gradients();
996
997        // Test case: [2, 3] - [1, 3] -> [2, 3]
998        // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)
999        // For subtraction: d/da (a - b) = 1, d/db (a - b) = -1
1000
1001        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1002            .unwrap()
1003            .with_requires_grad();
1004        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
1005            .unwrap()
1006            .with_requires_grad();
1007
1008        let mut result = a.sub_tensor(&b);
1009        assert_eq!(result.shape().dims(), vec![2, 3]);
1010
1011        // Set upstream gradient as ones
1012        result.backward(None);
1013
1014        let grad_a = a.grad_owned().expect("grad_a should exist");
1015        let grad_b = b.grad_owned().expect("grad_b should exist");
1016
1017        println!(
1018            "Original shapes: a={:?}, b={:?}",
1019            a.shape().dims(),
1020            b.shape().dims()
1021        );
1022        println!(
1023            "Gradient shapes: grad_a={:?}, grad_b={:?}",
1024            grad_a.shape().dims(),
1025            grad_b.shape().dims()
1026        );
1027
1028        // grad_a should have same shape as a: [2, 3]
1029        assert_eq!(
1030            grad_a.shape().dims(),
1031            vec![2, 3],
1032            "grad_a should match original shape of a"
1033        );
1034
1035        // grad_b should have same shape as b: [1, 3]
1036        // This requires summing over the broadcasted dimension
1037        assert_eq!(
1038            grad_b.shape().dims(),
1039            vec![1, 3],
1040            "grad_b should match original shape of b"
1041        );
1042
1043        // All gradients should be 1.0 for grad_a (d/da (a - b) = 1)
1044        for i in 0..grad_a.size() {
1045            let val = unsafe { *grad_a.as_ptr().add(i) };
1046            assert!(
1047                (val - 1.0).abs() < 1e-6,
1048                "grad_a[{}] = {} should be 1.0",
1049                i,
1050                val
1051            );
1052        }
1053
1054        // grad_b should be [-2.0, -2.0, -2.0] (sum over broadcast dim, then negated)
1055        let expected_grad_b = [-2.0, -2.0, -2.0]; // -1 * 2 rows = -2
1056        for (i, &expected) in expected_grad_b.iter().enumerate() {
1057            let val = unsafe { *grad_b.as_ptr().add(i) };
1058            assert!(
1059                (val - expected).abs() < 1e-6,
1060                "grad_b[{}] = {} should be {}",
1061                i,
1062                val,
1063                expected
1064            );
1065        }
1066    }
1067
1068    #[test]
1069    fn test_sub_scalar_broadcasting_gradients() {
1070        use crate::gradtrack::clear_gradients;
1071        clear_gradients();
1072
1073        // Test case: [2, 3] - [1] -> [2, 3]
1074        // grad_a should be [2, 3], grad_b should be [1] (summed over all dims, then negated)
1075
1076        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1077            .unwrap()
1078            .with_requires_grad();
1079        let b = Tensor::from_slice(&[0.5], vec![1])
1080            .unwrap()
1081            .with_requires_grad();
1082
1083        let mut result = a.sub_tensor(&b);
1084        result.backward(None);
1085
1086        let grad_a = a.grad_owned().expect("grad_a should exist");
1087        let grad_b = b.grad_owned().expect("grad_b should exist");
1088
1089        // grad_a should have same shape as a: [2, 3]
1090        assert_eq!(grad_a.shape().dims(), vec![2, 3]);
1091
1092        // grad_b should have same shape as b: [1] and sum to -6.0
1093        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims());
1094        assert_eq!(grad_b.shape().dims(), vec![1]);
1095
1096        // grad_b should be -6.0 (sum over all 6 elements, then negated)
1097        let val = unsafe { *grad_b.as_ptr() };
1098        assert!(
1099            (val - (-6.0)).abs() < 1e-6,
1100            "grad_b = {} should be -6.0",
1101            val
1102        );
1103    }
1104}