train_station/tensor/ops/sub.rs
1//! Subtraction operations for tensors
2//!
3//! Provides element-wise subtraction operations following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Tensor Subtraction**: `sub_tensor()` - Element-wise subtraction with broadcasting support
9//! - **Scalar Subtraction**: `sub_scalar()` - Subtraction of scalar from tensor
10//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The subtraction operations have the following properties:
18//! - **Tensor-Tensor**: `output[i] = a[i] - b[i]` with broadcasting
19//! - **Tensor-Scalar**: `output[i] = a[i] - scalar` for all elements
20//! - **Commutativity**: Subtraction is not commutative (a - b ≠ b - a)
21//! - **Associativity**: Subtraction is not associative ((a - b) - c ≠ a - (b - c))
22//! - **Gradient**: For tensor-tensor: ∂(a-b)/∂a = 1, ∂(a-b)/∂b = -1
23//! - **Broadcasting**: Follows NumPy broadcasting rules for shape compatibility
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Broadcasting Overhead**: Minimal overhead for compatible shapes
31//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
41
42impl Tensor {
43 /// Element-wise subtraction with another tensor with broadcasting support
44 ///
45 /// Performs element-wise subtraction with automatic broadcasting: `output[i] = self[i] - other[i]`
46 ///
47 /// Broadcasting enables subtraction between tensors of different but compatible shapes.
48 /// Compatible shapes follow NumPy broadcasting rules:
49 /// - Dimensions are aligned from the rightmost dimension
50 /// - Dimensions are compatible if they are equal, or one of them is 1
51 /// - Missing dimensions are treated as 1
52 ///
53 /// # Arguments
54 ///
55 /// * `other` - Tensor to subtract. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 ///
59 /// A new tensor containing the element-wise difference with broadcast result shape
60 ///
61 /// # Performance Characteristics
62 ///
63 /// - **Fast Path**: Optimized for identical shapes to avoid broadcasting overhead
64 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
65 /// - **Broadcasting**: Efficient broadcasting for compatible shapes
66 /// - **Cache-friendly**: Linear memory access patterns
67 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
68 ///
69 /// # Implementation Details
70 ///
71 /// Uses a fast path for identical shapes to avoid broadcasting overhead.
72 /// For different shapes, performs broadcasting followed by optimized element-wise subtraction.
73 /// Automatically selects between SIMD and scalar implementations based on hardware capabilities.
74 ///
75 /// # Examples
76 ///
77 /// ## Same Shape Subtraction
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
83 /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84 /// let c = a.sub_tensor(&b);
85 /// assert_eq!(c.shape().dims(), vec![3]);
86 /// assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
87 /// assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
88 /// assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
89 /// ```
90 ///
91 /// ## Broadcasting Subtraction
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
97 /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
98 /// let c = a.sub_tensor(&b);
99 /// assert_eq!(c.shape().dims(), vec![2, 3]);
100 /// // Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
101 /// assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
102 /// assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
103 /// assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
104 /// ```
105 ///
106 /// ## Scalar Subtraction
107 ///
108 /// ```
109 /// use train_station::Tensor;
110 ///
111 /// let a = Tensor::ones(vec![2, 3]);
112 /// let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
113 /// let c = a.sub_tensor(&b);
114 /// assert_eq!(c.shape().dims(), vec![2, 3]);
115 /// assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
116 /// ```
117 ///
118 /// # Panics
119 /// Panics if tensor shapes are not broadcast-compatible
120 #[inline]
121 #[track_caller]
122 pub fn sub_tensor(&self, other: &Tensor) -> Tensor {
123 // Check if shapes are identical for fast path
124 if self.shape().dims() == other.shape().dims() {
125 return self.sub_tensor_same_shape(other);
126 }
127
128 // Zero-copy broadcast views, then reuse same-shape optimized path
129 use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
130 let mut result = match broadcast_shapes_cow(self, other) {
131 Ok((a_b, b_b, _)) => a_b.as_ref().sub_tensor_optimized(b_b.as_ref()),
132 Err(BroadcastError::IncompatibleShapes { .. }) => {
133 panic!(
134 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
135 self.shape().dims(),
136 other.shape().dims(),
137 "incompatible shapes"
138 );
139 }
140 Err(BroadcastError::AllocationFailed) => {
141 panic!("Memory allocation failed during broadcasting");
142 }
143 };
144
145 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
146 result.set_requires_grad_internal(true);
147 let grad_fn = GradFn::Sub {
148 is_tensor_sub: true,
149 original_shapes: Some((
150 self.shape().dims().to_vec(),
151 other.shape().dims().to_vec(),
152 )),
153 };
154 result.set_grad_fn(grad_fn.clone());
155
156 let mut input_ids = Vec::with_capacity(2);
157 if self.requires_grad() {
158 input_ids.push(self.id());
159 }
160 if other.requires_grad() {
161 input_ids.push(other.id());
162 }
163 GradEngine::register_operation(result.id(), input_ids, grad_fn);
164 }
165
166 result
167 }
168
169 /// Element-wise subtraction for tensors with identical shapes (fast path)
170 ///
171 /// This is an optimized path for tensors that already have the same shape,
172 /// avoiding the overhead of broadcasting computation.
173 ///
174 /// # Arguments
175 ///
176 /// * `other` - Tensor to subtract, must have the same shape as self
177 ///
178 /// # Returns
179 ///
180 /// A new tensor containing the element-wise difference
181 ///
182 /// # Performance Characteristics
183 ///
184 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
185 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
186 /// - **Cache-friendly**: Linear memory access patterns
187 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
188 ///
189 /// # Implementation Details
190 ///
191 /// This method is used internally by `sub_tensor()` when tensors have identical shapes.
192 /// It bypasses the broadcasting logic and directly calls the optimized subtraction implementation.
193 #[inline]
194 fn sub_tensor_same_shape(&self, other: &Tensor) -> Tensor {
195 assert_eq!(
196 self.shape(),
197 other.shape(),
198 "Tensor shapes must match for same-shape subtraction"
199 );
200 let mut result = self.sub_tensor_optimized(other);
201
202 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
203 result.set_requires_grad_internal(true);
204 let grad_fn = GradFn::Sub {
205 is_tensor_sub: true,
206 original_shapes: None, // Same shape case
207 };
208 result.set_grad_fn(grad_fn.clone());
209
210 let mut input_ids = Vec::with_capacity(2);
211 if self.requires_grad() {
212 input_ids.push(self.id());
213 }
214 if other.requires_grad() {
215 input_ids.push(other.id());
216 }
217 GradEngine::register_operation(result.id(), input_ids, grad_fn);
218 }
219
220 result
221 }
222
223 /// Element-wise subtraction of a scalar from this tensor
224 ///
225 /// Performs element-wise subtraction of a scalar value: `output[i] = self[i] - scalar`
226 ///
227 /// # Arguments
228 ///
229 /// * `scalar` - The scalar value to subtract from each element
230 ///
231 /// # Returns
232 ///
233 /// A new tensor with the scalar subtracted from each element
234 ///
235 /// # Performance Characteristics
236 ///
237 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
238 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
239 /// - **Cache-friendly**: Linear memory access patterns
240 /// - **Mathematical Accuracy**: High-precision subtraction computation
241 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
242 ///
243 /// # Examples
244 ///
245 /// ## Basic Scalar Subtraction
246 ///
247 /// ```
248 /// use train_station::Tensor;
249 ///
250 /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
251 /// let b = a.sub_scalar(2.0);
252 /// assert_eq!(b.shape().dims(), vec![3]);
253 /// assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
254 /// assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
255 /// assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
256 /// ```
257 ///
258 /// ## Negative Scalar Subtraction
259 ///
260 /// ```
261 /// use train_station::Tensor;
262 ///
263 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
264 /// let b = a.sub_scalar(-2.0); // Subtracting negative = adding
265 /// assert_eq!(b.shape().dims(), vec![3]);
266 /// assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
267 /// assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
268 /// assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
269 /// ```
270 #[inline]
271 #[track_caller]
272 pub fn sub_scalar(&self, scalar: f32) -> Tensor {
273 let mut result = self.sub_scalar_optimized(scalar);
274
275 if self.requires_grad() && is_grad_enabled() {
276 result.set_requires_grad_internal(true);
277 let grad_fn = GradFn::Sub {
278 is_tensor_sub: false,
279 original_shapes: None, // Scalar case
280 };
281 result.set_grad_fn(grad_fn.clone());
282 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
283 }
284
285 result
286 }
287 /// Optimized tensor subtraction using SIMD when available
288 ///
289 /// Performs element-wise subtraction between tensors with identical shapes using
290 /// SIMD optimization when available and falling back to optimized scalar computation.
291 ///
292 /// # Arguments
293 ///
294 /// * `other` - The tensor to subtract from this tensor
295 ///
296 /// # Returns
297 ///
298 /// A new tensor with the result of the subtraction (self - other)
299 ///
300 /// # Performance Characteristics
301 ///
302 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
303 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
304 /// - **Cache-friendly**: Linear memory access patterns
305 /// - **Mathematical Accuracy**: High-precision subtraction computation
306 /// - **Zero-sized Handling**: Fast return for empty tensors
307 ///
308 /// # Implementation Details
309 ///
310 /// Automatically selects between SIMD and scalar implementations based on hardware
311 /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
312 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
313 /// parallelism.
314 ///
315 /// # Safety
316 ///
317 /// This operation assumes the tensors have the same shape.
318 #[inline]
319 pub(crate) fn sub_tensor_optimized(&self, other: &Tensor) -> Tensor {
320 assert_eq!(
321 self.shape().dims(),
322 other.shape().dims(),
323 "Tensor dims must match"
324 );
325
326 // Ensure contiguous sources for correctness with broadcast views/strides
327 let a_src = if self.is_contiguous() {
328 self.clone()
329 } else {
330 self.contiguous()
331 };
332 let b_src = if other.is_contiguous() {
333 other.clone()
334 } else {
335 other.contiguous()
336 };
337
338 let mut output = Tensor::new(self.shape().dims().to_vec());
339
340 unsafe {
341 let a = a_src.as_ptr();
342 let b = b_src.as_ptr();
343 let dst = output.as_mut_ptr();
344
345 #[cfg(target_arch = "x86_64")]
346 {
347 // Use SIMD for better performance when available
348 if is_x86_feature_detected!("avx2") {
349 self.sub_tensors_simd_avx2_optimized(a, b, dst);
350 return output;
351 }
352 }
353
354 // Fallback to scalar operations with better cache usage
355 self.sub_tensors_scalar_optimized(a, b, dst);
356 }
357
358 output
359 }
360
361 /// AVX2-optimized tensor subtraction implementation
362 ///
363 /// Performs element-wise subtraction using AVX2 SIMD instructions for maximum
364 /// performance on x86_64 architectures with AVX2 support.
365 ///
366 /// # Arguments
367 ///
368 /// * `a` - Pointer to first tensor data
369 /// * `b` - Pointer to second tensor data
370 /// * `dst` - Pointer to output tensor data
371 ///
372 /// # Safety
373 ///
374 /// Requires valid pointers with sufficient memory for the tensor size.
375 /// All pointers must point to valid tensor data. Requires AVX2 support.
376 ///
377 /// # Performance Characteristics
378 ///
379 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
380 /// - **Memory Access**: Linear access patterns for cache efficiency
381 /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
382 /// - **Fallback**: Handles remaining elements with scalar operations
383 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
384 ///
385 /// # Implementation Details
386 ///
387 /// Uses AVX2 vector subtraction operations to compute a - b efficiently.
388 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
389 /// Processes remaining elements with scalar operations for complete coverage.
390 #[cfg(target_arch = "x86_64")]
391 #[inline]
392 #[target_feature(enable = "avx2")]
393 unsafe fn sub_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
394 let size = self.size();
395 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
396 let mut offset = 0;
397
398 // Unrolled SIMD loop for throughput
399 for _ in 0..simd_count {
400 // Process 4 AVX2 vectors (32 elements) per iteration
401 let a_vec1 = _mm256_loadu_ps(a.add(offset));
402 let b_vec1 = _mm256_loadu_ps(b.add(offset));
403 let sub_vec1 = _mm256_sub_ps(a_vec1, b_vec1);
404 _mm256_storeu_ps(dst.add(offset), sub_vec1);
405
406 let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
407 let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
408 let sub_vec2 = _mm256_sub_ps(a_vec2, b_vec2);
409 _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);
410
411 let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
412 let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
413 let sub_vec3 = _mm256_sub_ps(a_vec3, b_vec3);
414 _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);
415
416 let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
417 let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
418 let sub_vec4 = _mm256_sub_ps(a_vec4, b_vec4);
419 _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);
420
421 offset += 32;
422 }
423
424 // Handle remaining 8-element blocks
425 let remaining_full_blocks = (size - offset) / 8;
426 for _ in 0..remaining_full_blocks {
427 let a_vec = _mm256_loadu_ps(a.add(offset));
428 let b_vec = _mm256_loadu_ps(b.add(offset));
429 let sub_vec = _mm256_sub_ps(a_vec, b_vec);
430 _mm256_storeu_ps(dst.add(offset), sub_vec);
431 offset += 8;
432 }
433
434 // Handle final elements with unrolled loop
435 let remaining = size - offset;
436 let unroll_count = remaining / 4;
437 for _ in 0..unroll_count {
438 *dst.add(offset) = *a.add(offset) - *b.add(offset);
439 *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
440 *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
441 *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
442 offset += 4;
443 }
444
445 for i in offset..size {
446 *dst.add(i) = *a.add(i) - *b.add(i);
447 }
448 }
449
450 /// Optimized scalar tensor subtraction fallback
451 ///
452 /// Performs element-wise subtraction using optimized scalar operations with
453 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
454 ///
455 /// # Arguments
456 ///
457 /// * `a` - Pointer to first tensor data
458 /// * `b` - Pointer to second tensor data
459 /// * `dst` - Pointer to output tensor data
460 ///
461 /// # Safety
462 ///
463 /// Requires valid pointers with sufficient memory for the tensor size.
464 /// All pointers must point to valid tensor data.
465 ///
466 /// # Performance Characteristics
467 ///
468 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
469 /// - **Memory Access**: Linear access patterns for cache efficiency
470 /// - **Fallback**: Handles remaining elements with scalar operations
471 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
472 /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
473 ///
474 /// # Implementation Details
475 ///
476 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
477 /// Processes elements in groups of 4 to improve instruction-level parallelism
478 /// and reduce loop overhead.
479 #[inline]
480 unsafe fn sub_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
481 let size = self.size();
482 let unroll_count = size / 4;
483 let mut offset = 0;
484
485 // Unrolled scalar loop for better performance
486 for _ in 0..unroll_count {
487 *dst.add(offset) = *a.add(offset) - *b.add(offset);
488 *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
489 *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
490 *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
491 offset += 4;
492 }
493
494 // Handle remaining elements
495 for i in offset..size {
496 *dst.add(i) = *a.add(i) - *b.add(i);
497 }
498 }
499
500 /// Internal optimized scalar subtraction operation
501 ///
502 /// Performs element-wise subtraction of a scalar from tensor using SIMD optimization
503 /// when available and falling back to optimized scalar computation.
504 ///
505 /// # Arguments
506 ///
507 /// * `scalar` - The scalar value to subtract from each element
508 ///
509 /// # Returns
510 ///
511 /// A new tensor with the scalar subtracted from each element
512 ///
513 /// # Performance Characteristics
514 ///
515 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
516 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
517 /// - **Cache-friendly**: Linear memory access patterns
518 /// - **Mathematical Accuracy**: High-precision subtraction computation
519 /// - **Zero-sized Handling**: Fast return for empty tensors
520 ///
521 /// # Implementation Details
522 ///
523 /// Automatically selects between SIMD and scalar implementations based on hardware
524 /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
525 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
526 /// parallelism.
527 #[inline]
528 pub(crate) fn sub_scalar_optimized(&self, scalar: f32) -> Tensor {
529 let mut output = Tensor::new(self.shape().dims().to_vec());
530
531 unsafe {
532 let src = self.as_ptr();
533 let dst = output.as_mut_ptr();
534
535 #[cfg(target_arch = "x86_64")]
536 {
537 // Use SIMD for better performance when available
538 if is_x86_feature_detected!("avx2") {
539 self.sub_scalar_simd_avx2_optimized(src, dst, scalar);
540 return output;
541 }
542 }
543
544 // Fallback to optimized scalar operations
545 self.sub_scalar_fallback_optimized(src, dst, scalar);
546 }
547
548 output
549 }
550
551 /// AVX2-optimized scalar subtraction implementation
552 ///
553 /// Performs element-wise subtraction of a scalar using AVX2 SIMD instructions for maximum
554 /// performance on x86_64 architectures with AVX2 support.
555 ///
556 /// # Arguments
557 ///
558 /// * `src` - Pointer to source tensor data
559 /// * `dst` - Pointer to output tensor data
560 /// * `scalar` - The scalar value to subtract from each element
561 ///
562 /// # Safety
563 ///
564 /// Requires valid pointers with sufficient memory for the tensor size.
565 /// All pointers must point to valid tensor data. Requires AVX2 support.
566 ///
567 /// # Performance Characteristics
568 ///
569 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
570 /// - **Memory Access**: Linear access patterns for cache efficiency
571 /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
572 /// - **Fallback**: Handles remaining elements with scalar operations
573 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
574 ///
575 /// # Implementation Details
576 ///
577 /// Uses AVX2 vector subtraction operations to compute src - scalar efficiently.
578 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
579 /// Processes remaining elements with scalar operations for complete coverage.
580 #[cfg(target_arch = "x86_64")]
581 #[inline]
582 #[target_feature(enable = "avx2")]
583 unsafe fn sub_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
584 let scalar_vec = _mm256_set1_ps(scalar);
585 let size = self.size();
586 let simd_count = size / 32; // Process 32 elements per iteration
587 let mut offset = 0;
588
589 // Unrolled SIMD loop for better instruction throughput
590 for _ in 0..simd_count {
591 let src_vec1 = _mm256_loadu_ps(src.add(offset));
592 let sub_vec1 = _mm256_sub_ps(src_vec1, scalar_vec);
593 _mm256_storeu_ps(dst.add(offset), sub_vec1);
594
595 let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
596 let sub_vec2 = _mm256_sub_ps(src_vec2, scalar_vec);
597 _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);
598
599 let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
600 let sub_vec3 = _mm256_sub_ps(src_vec3, scalar_vec);
601 _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);
602
603 let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
604 let sub_vec4 = _mm256_sub_ps(src_vec4, scalar_vec);
605 _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);
606
607 offset += 32;
608 }
609
610 // Handle remaining 8-element blocks
611 let remaining_full_blocks = (size - offset) / 8;
612 for _ in 0..remaining_full_blocks {
613 let src_vec = _mm256_loadu_ps(src.add(offset));
614 let sub_vec = _mm256_sub_ps(src_vec, scalar_vec);
615 _mm256_storeu_ps(dst.add(offset), sub_vec);
616 offset += 8;
617 }
618
619 // Handle final elements
620 for i in offset..size {
621 *dst.add(i) = *src.add(i) - scalar;
622 }
623 }
624
625 /// Optimized scalar subtraction fallback
626 ///
627 /// Performs element-wise subtraction of a scalar using optimized scalar operations with
628 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
629 ///
630 /// # Arguments
631 ///
632 /// * `src` - Pointer to source tensor data
633 /// * `dst` - Pointer to output tensor data
634 /// * `scalar` - The scalar value to subtract from each element
635 ///
636 /// # Safety
637 ///
638 /// Requires valid pointers with sufficient memory for the tensor size.
639 /// All pointers must point to valid tensor data.
640 ///
641 /// # Performance Characteristics
642 ///
643 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
644 /// - **Memory Access**: Linear access patterns for cache efficiency
645 /// - **Fallback**: Handles remaining elements with scalar operations
646 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
647 /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
648 ///
649 /// # Implementation Details
650 ///
651 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
652 /// Processes elements in groups of 4 to improve instruction-level parallelism
653 /// and reduce loop overhead.
654 #[inline]
655 unsafe fn sub_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
656 let size = self.size();
657 let unroll_count = size / 4;
658 let mut offset = 0;
659
660 // Unrolled scalar operations
661 for _ in 0..unroll_count {
662 *dst.add(offset) = *src.add(offset) - scalar;
663 *dst.add(offset + 1) = *src.add(offset + 1) - scalar;
664 *dst.add(offset + 2) = *src.add(offset + 2) - scalar;
665 *dst.add(offset + 3) = *src.add(offset + 3) - scalar;
666 offset += 4;
667 }
668
669 for i in offset..size {
670 *dst.add(i) = *src.add(i) - scalar;
671 }
672 }
673}
674
675#[cfg(test)]
676mod tests {
677 use super::*;
678
#[test]
fn test_tensor_subtraction() {
    // Minuend filled with 5.0, subtrahend filled with 2.0
    let mut minuend = Tensor::ones(vec![2, 3]);
    minuend.fill(5.0);
    let mut subtrahend = Tensor::ones(vec![2, 3]);
    subtrahend.fill(2.0);

    let difference = minuend.sub_tensor_optimized(&subtrahend);

    assert_eq!(difference.shape().dims(), vec![2, 3]);
    assert_eq!(difference.size(), 6);

    // Every element must equal 5.0 - 2.0 = 3.0
    for value in (0..difference.size()).map(|i| unsafe { difference.as_ptr().add(i).read() }) {
        assert!((value - 3.0).abs() < 1e-6);
    }
}
697
#[test]
fn test_scalar_subtraction() {
    // Source tensor filled with 10.0
    let mut source = Tensor::ones(vec![2, 2]);
    source.fill(10.0);

    let shifted = source.sub_scalar_optimized(3.0);

    assert_eq!(shifted.shape().dims(), vec![2, 2]);
    assert_eq!(shifted.size(), 4);

    // Every element must equal 10.0 - 3.0 = 7.0
    for value in (0..shifted.size()).map(|i| unsafe { shifted.as_ptr().add(i).read() }) {
        assert!((value - 7.0).abs() < 1e-6);
    }
}
714
#[test]
fn test_negative_subtraction() {
    // Smaller minuend (2.0) minus larger subtrahend (5.0) yields a negative result
    let mut minuend = Tensor::ones(vec![2, 2]);
    minuend.fill(2.0);
    let mut subtrahend = Tensor::ones(vec![2, 2]);
    subtrahend.fill(5.0);

    let difference = minuend.sub_tensor_optimized(&subtrahend);

    // Every element must equal 2.0 - 5.0 = -3.0
    for value in (0..difference.size()).map(|i| unsafe { difference.as_ptr().add(i).read() }) {
        assert!((value - (-3.0)).abs() < 1e-6);
    }
}
730
#[test]
fn test_scalar_negative_subtraction() {
    // Source tensor filled with 3.0; subtracting 8.0 drives it negative
    let mut source = Tensor::ones(vec![2, 2]);
    source.fill(3.0);

    let shifted = source.sub_scalar_optimized(8.0);

    // Every element must equal 3.0 - 8.0 = -5.0
    for value in (0..shifted.size()).map(|i| unsafe { shifted.as_ptr().add(i).read() }) {
        assert!((value - (-5.0)).abs() < 1e-6);
    }
}
744
#[test]
#[should_panic(expected = "must match")]
fn test_mismatched_shapes() {
    // [2, 3] and [3, 2] hold the same number of elements but have different dims,
    // so the same-shape kernel entry point must reject them
    let lhs = Tensor::ones(vec![2, 3]);
    let rhs = Tensor::ones(vec![3, 2]);
    let _ = lhs.sub_tensor_optimized(&rhs);
}
752
#[test]
fn test_edge_cases() {
    // Subtracting zeros leaves the operand unchanged
    let ones = Tensor::ones(vec![3]);
    let zeros = Tensor::zeros(vec![3]);
    let unchanged = ones.sub_tensor_optimized(&zeros);

    for value in (0..unchanged.size()).map(|i| unsafe { unchanged.as_ptr().add(i).read() }) {
        assert!((value - 1.0).abs() < 1e-6);
    }

    // Subtracting a tensor from itself yields zeros
    let mut fives = Tensor::ones(vec![3]);
    fives.fill(5.0);
    let cancelled = fives.sub_tensor_optimized(&fives);

    for value in (0..cancelled.size()).map(|i| unsafe { cancelled.as_ptr().add(i).read() }) {
        assert!(value.abs() < 1e-6);
    }
}
777
#[test]
fn test_large_tensor_subtraction() {
    // 100 x 100 operands filled with 10.0 and 3.0
    let mut minuend = Tensor::ones(vec![100, 100]);
    minuend.fill(10.0);
    let mut subtrahend = Tensor::ones(vec![100, 100]);
    subtrahend.fill(3.0);

    let difference = minuend.sub_tensor_optimized(&subtrahend);

    assert_eq!(difference.size(), 10000);

    // Spot-check every 1000th element: 10.0 - 3.0 = 7.0
    let sampled = (0..difference.size())
        .step_by(1000)
        .map(|i| unsafe { difference.as_ptr().add(i).read() });
    for value in sampled {
        assert!((value - 7.0).abs() < 1e-6);
    }
}
795
#[test]
fn test_negate_inplace() {
    let mut tensor = Tensor::ones(vec![2, 2]);
    tensor.fill(5.0);

    // Before negation every element is 5.0
    for idx in 0..tensor.size() {
        let val = unsafe { tensor.as_ptr().add(idx).read() };
        assert!((val - 5.0).abs() < 1e-6, "Expected 5.0, got {}", val);
    }

    tensor.negate_inplace();

    // After negation every element is -5.0
    for idx in 0..tensor.size() {
        let val = unsafe { tensor.as_ptr().add(idx).read() };
        assert!(
            (val - (-5.0)).abs() < 1e-6,
            "Expected -5.0 after negation, got {}",
            val
        );
    }
}
823
#[test]
fn test_subtraction_with_gradtrack() {
    // --- Scalar subtraction with gradtrack ---
    let a = Tensor::ones(vec![2, 3]).with_requires_grad();
    let mut result = a.sub_scalar(5.0);

    // Forward values: 1.0 - 5.0 = -4.0
    unsafe {
        for i in 0..result.size() {
            let val = result.as_ptr().add(i).read();
            assert!((val - (-4.0)).abs() < 1e-6, "Expected -4.0, got {}", val);
        }
    }

    result.backward(None);

    // Gradient: d/dx(x - c) = 1
    if let Some(grad) = a.grad_owned() {
        unsafe {
            for i in 0..grad.size() {
                let val = grad.as_ptr().add(i).read();
                assert!(
                    (val - 1.0).abs() < 1e-6,
                    "Expected gradient 1.0, got {}",
                    val
                );
            }
        }
    } else {
        panic!("No gradient computed for scalar subtraction!");
    }

    // --- Tensor subtraction with gradtrack ---
    let a = Tensor::ones(vec![2, 2]).with_requires_grad();
    let mut b = Tensor::ones(vec![2, 2]);
    b.fill(3.0);
    let b = b.with_requires_grad();

    let mut result = a.sub_tensor(&b);

    // Forward values: 1.0 - 3.0 = -2.0
    unsafe {
        for i in 0..result.size() {
            let val = result.as_ptr().add(i).read();
            assert!((val - (-2.0)).abs() < 1e-6, "Expected -2.0, got {}", val);
        }
    }

    result.backward(None);

    // Gradients: d/dx(x - y) = 1, d/dy(x - y) = -1
    if let Some(grad_a) = a.grad_owned() {
        unsafe {
            for i in 0..grad_a.size() {
                let val = grad_a.as_ptr().add(i).read();
                assert!(
                    (val - 1.0).abs() < 1e-6,
                    "Expected gradient A = 1.0, got {}",
                    val
                );
            }
        }
    } else {
        panic!("No gradient A computed for tensor subtraction!");
    }

    if let Some(grad_b) = b.grad_owned() {
        unsafe {
            for i in 0..grad_b.size() {
                let val = grad_b.as_ptr().add(i).read();
                // The assertion message already reports the offending value,
                // so no per-element debug printing is needed here
                assert!(
                    (val - (-1.0)).abs() < 1e-6,
                    "Expected gradient B = -1.0, got {}",
                    val
                );
            }
        }
    } else {
        panic!("No gradient B computed for tensor subtraction!");
    }
}
906
#[test]
fn test_mixed_add_sub_operations_with_gradtrack() {
    // Computation graph under test: f = (a + scalar1) - b + c - scalar2
    let a = Tensor::ones(vec![2, 2]).with_requires_grad();
    let mut b = Tensor::ones(vec![2, 2]);
    b.fill(2.0);
    let b = b.with_requires_grad();
    let mut c = Tensor::ones(vec![2, 2]);
    c.fill(3.0);
    let c = c.with_requires_grad();

    let scalar1 = 5.0;
    let scalar2 = 1.0;

    // Expected forward value: (1 + 5) - 2 + 3 - 1 = 6
    let step1 = a.add_scalar(scalar1); // a + 5 = 6
    let step2 = step1.sub_tensor(&b); // 6 - 2 = 4
    let step3 = step2.add_tensor(&c); // 4 + 3 = 7
    let mut result = step3.sub_scalar(scalar2); // 7 - 1 = 6

    // Check result values
    unsafe {
        for i in 0..result.size() {
            let val = result.as_ptr().add(i).read();
            assert!((val - 6.0).abs() < 1e-6, "Expected 6.0, got {}", val);
        }
    }

    result.backward(None);

    // Expected gradients for f = (a + 5) - b + c - 1:
    // df/da = 1, df/db = -1, df/dc = 1

    if let Some(grad_a) = a.grad_owned() {
        unsafe {
            for i in 0..grad_a.size() {
                let val = grad_a.as_ptr().add(i).read();
                assert!(
                    (val - 1.0).abs() < 1e-6,
                    "Expected gradient A = 1.0, got {}",
                    val
                );
            }
        }
    } else {
        panic!("No gradient A computed for mixed operations!");
    }

    if let Some(grad_b) = b.grad_owned() {
        unsafe {
            for i in 0..grad_b.size() {
                let val = grad_b.as_ptr().add(i).read();
                assert!(
                    (val - (-1.0)).abs() < 1e-6,
                    "Expected gradient B = -1.0, got {}",
                    val
                );
            }
        }
    } else {
        panic!("No gradient B computed for mixed operations!");
    }

    if let Some(grad_c) = c.grad_owned() {
        unsafe {
            for i in 0..grad_c.size() {
                let val = grad_c.as_ptr().add(i).read();
                assert!(
                    (val - 1.0).abs() < 1e-6,
                    "Expected gradient C = 1.0, got {}",
                    val
                );
            }
        }
    } else {
        panic!("No gradient C computed for mixed operations!");
    }

    println!("Mixed add/sub operations with gradtrack test passed!");
    println!("✓ Complex computation graph: (a + 5) - b + c - 1 = 6");
    // Derivatives of the output f with respect to each input (df/dx, not dx/df)
    println!("✓ Gradients: df/da = 1, df/db = -1, df/dc = 1");
}
991
#[test]
fn test_sub_broadcasting_gradients_basic() {
    use crate::gradtrack::clear_gradients;
    clear_gradients();

    // Scenario: a [2, 3] minuend minus a [1, 3] subtrahend, broadcast to [2, 3].
    // Since d/da (a - b) = 1 and d/db (a - b) = -1, the minuend's gradient keeps
    // shape [2, 3], while the subtrahend's gradient is reduced back to [1, 3] by
    // summing over the broadcast dimension.

    let minuend = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
        .unwrap()
        .with_requires_grad();
    let subtrahend = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
        .unwrap()
        .with_requires_grad();

    let mut difference = minuend.sub_tensor(&subtrahend);
    assert_eq!(difference.shape().dims(), vec![2, 3]);

    // No explicit upstream gradient is supplied; the expected values below
    // correspond to an upstream gradient of ones.
    difference.backward(None);

    let grad_a = minuend.grad_owned().expect("grad_a should exist");
    let grad_b = subtrahend.grad_owned().expect("grad_b should exist");

    println!(
        "Original shapes: a={:?}, b={:?}",
        minuend.shape().dims(),
        subtrahend.shape().dims()
    );
    println!(
        "Gradient shapes: grad_a={:?}, grad_b={:?}",
        grad_a.shape().dims(),
        grad_b.shape().dims()
    );

    // The minuend's gradient must keep the minuend's own shape.
    assert_eq!(
        grad_a.shape().dims(),
        vec![2, 3],
        "grad_a should match original shape of a"
    );

    // The subtrahend's gradient must collapse back to the un-broadcast shape.
    assert_eq!(
        grad_b.shape().dims(),
        vec![1, 3],
        "grad_b should match original shape of b"
    );

    // Every entry of grad_a is 1.0 because d/da (a - b) = 1.
    (0..grad_a.size())
        .map(|i| (i, unsafe { *grad_a.as_ptr().add(i) }))
        .for_each(|(i, val)| {
            assert!(
                (val - 1.0).abs() < 1e-6,
                "grad_a[{}] = {} should be 1.0",
                i,
                val
            );
        });

    // Each column of grad_b accumulates -1 from both rows: -1 * 2 rows = -2.
    let expected_grad_b = [-2.0, -2.0, -2.0];
    for (i, expected) in expected_grad_b.iter().copied().enumerate() {
        let val = unsafe { *grad_b.as_ptr().add(i) };
        assert!(
            (val - expected).abs() < 1e-6,
            "grad_b[{}] = {} should be {}",
            i,
            val,
            expected
        );
    }
}
1067
#[test]
fn test_sub_scalar_broadcasting_gradients() {
    use crate::gradtrack::clear_gradients;
    clear_gradients();

    // Scenario: a [2, 3] tensor minus a single-element [1] tensor, broadcast to
    // [2, 3]. The [1] operand's gradient is the negated sum over all elements.

    let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
        .unwrap()
        .with_requires_grad();
    let single = Tensor::from_slice(&[0.5], vec![1])
        .unwrap()
        .with_requires_grad();

    let mut difference = matrix.sub_tensor(&single);
    difference.backward(None);

    let grad_a = matrix.grad_owned().expect("grad_a should exist");
    let grad_b = single.grad_owned().expect("grad_b should exist");

    // The [2, 3] operand's gradient keeps that operand's shape.
    assert_eq!(grad_a.shape().dims(), vec![2, 3]);

    // The [1] operand's gradient keeps shape [1].
    println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims());
    assert_eq!(grad_b.shape().dims(), vec![1]);

    // Six broadcast elements each contribute -1, so the single entry is -6.0.
    let val = unsafe { grad_b.as_ptr().read() };
    let deviation = (val + 6.0).abs();
    assert!(deviation < 1e-6, "grad_b = {} should be -6.0", val);
}
1104}