train_station/tensor/ops/sub.rs
1//! Subtraction operations for tensors
2//!
3//! Provides element-wise subtraction operations following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Tensor Subtraction**: `sub_tensor()` - Element-wise subtraction with broadcasting support
9//! - **Scalar Subtraction**: `sub_scalar()` - Subtraction of scalar from tensor
10//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The subtraction operations have the following properties:
18//! - **Tensor-Tensor**: `output[i] = a[i] - b[i]` with broadcasting
19//! - **Tensor-Scalar**: `output[i] = a[i] - scalar` for all elements
//! - **Non-commutativity**: Subtraction is not commutative (a - b ≠ b - a)
//! - **Non-associativity**: Subtraction is not associative ((a - b) - c ≠ a - (b - c))
22//! - **Gradient**: For tensor-tensor: ∂(a-b)/∂a = 1, ∂(a-b)/∂b = -1
23//! - **Broadcasting**: Follows NumPy broadcasting rules for shape compatibility
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Broadcasting Overhead**: Minimal overhead for compatible shapes
31//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
41
42impl Tensor {
43 /// Element-wise subtraction with another tensor with broadcasting support
44 ///
45 /// Performs element-wise subtraction with automatic broadcasting: `output[i] = self[i] - other[i]`
46 ///
47 /// Broadcasting enables subtraction between tensors of different but compatible shapes.
48 /// Compatible shapes follow NumPy broadcasting rules:
49 /// - Dimensions are aligned from the rightmost dimension
50 /// - Dimensions are compatible if they are equal, or one of them is 1
51 /// - Missing dimensions are treated as 1
52 ///
53 /// # Arguments
54 ///
55 /// * `other` - Tensor to subtract. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 ///
59 /// A new tensor containing the element-wise difference with broadcast result shape
60 ///
61 /// # Performance Characteristics
62 ///
63 /// - **Fast Path**: Optimized for identical shapes to avoid broadcasting overhead
64 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
65 /// - **Broadcasting**: Efficient broadcasting for compatible shapes
66 /// - **Cache-friendly**: Linear memory access patterns
67 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
68 ///
69 /// # Implementation Details
70 ///
71 /// Uses a fast path for identical shapes to avoid broadcasting overhead.
72 /// For different shapes, performs broadcasting followed by optimized element-wise subtraction.
73 /// Automatically selects between SIMD and scalar implementations based on hardware capabilities.
74 ///
75 /// # Examples
76 ///
77 /// ## Same Shape Subtraction
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
83 /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84 /// let c = a.sub_tensor(&b);
85 /// assert_eq!(c.shape().dims, vec![3]);
86 /// assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
87 /// assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
88 /// assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
89 /// ```
90 ///
91 /// ## Broadcasting Subtraction
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
97 /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
98 /// let c = a.sub_tensor(&b);
99 /// assert_eq!(c.shape().dims, vec![2, 3]);
100 /// // Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
101 /// assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
102 /// assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
103 /// assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
104 /// ```
105 ///
106 /// ## Scalar Subtraction
107 ///
108 /// ```
109 /// use train_station::Tensor;
110 ///
111 /// let a = Tensor::ones(vec![2, 3]);
112 /// let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
113 /// let c = a.sub_tensor(&b);
114 /// assert_eq!(c.shape().dims, vec![2, 3]);
115 /// assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
116 /// ```
117 ///
118 /// # Panics
119 /// Panics if tensor shapes are not broadcast-compatible
120 #[inline]
121 pub fn sub_tensor(&self, other: &Tensor) -> Tensor {
122 // Check if shapes are identical for fast path
123 if self.shape().dims == other.shape().dims {
124 return self.sub_tensor_same_shape(other);
125 }
126
127 // Use broadcasting for different shapes
128 let (broadcast_self, broadcast_other, _result_shape) =
129 self.broadcast_with(other).unwrap_or_else(|e| {
130 panic!(
131 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
132 self.shape().dims,
133 other.shape().dims,
134 e
135 );
136 });
137
138 // Perform element-wise subtraction on broadcasted tensors
139 let mut result = broadcast_self.sub_tensor_optimized(&broadcast_other);
140
141 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
142 result.set_requires_grad_internal(true);
143 let grad_fn = GradFn::Sub {
144 is_tensor_sub: true,
145 original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
146 };
147 result.set_grad_fn(grad_fn.clone());
148
149 let mut input_ids = Vec::with_capacity(2);
150 if self.requires_grad() {
151 input_ids.push(self.id());
152 }
153 if other.requires_grad() {
154 input_ids.push(other.id());
155 }
156 GradEngine::register_operation(result.id(), input_ids, grad_fn);
157 }
158
159 result
160 }
161
162 /// Element-wise subtraction for tensors with identical shapes (fast path)
163 ///
164 /// This is an optimized path for tensors that already have the same shape,
165 /// avoiding the overhead of broadcasting computation.
166 ///
167 /// # Arguments
168 ///
169 /// * `other` - Tensor to subtract, must have the same shape as self
170 ///
171 /// # Returns
172 ///
173 /// A new tensor containing the element-wise difference
174 ///
175 /// # Performance Characteristics
176 ///
177 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
178 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
179 /// - **Cache-friendly**: Linear memory access patterns
180 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
181 ///
182 /// # Implementation Details
183 ///
184 /// This method is used internally by `sub_tensor()` when tensors have identical shapes.
185 /// It bypasses the broadcasting logic and directly calls the optimized subtraction implementation.
186 #[inline]
187 fn sub_tensor_same_shape(&self, other: &Tensor) -> Tensor {
188 assert_eq!(
189 self.shape(),
190 other.shape(),
191 "Tensor shapes must match for same-shape subtraction"
192 );
193 let mut result = self.sub_tensor_optimized(other);
194
195 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
196 result.set_requires_grad_internal(true);
197 let grad_fn = GradFn::Sub {
198 is_tensor_sub: true,
199 original_shapes: None, // Same shape case
200 };
201 result.set_grad_fn(grad_fn.clone());
202
203 let mut input_ids = Vec::with_capacity(2);
204 if self.requires_grad() {
205 input_ids.push(self.id());
206 }
207 if other.requires_grad() {
208 input_ids.push(other.id());
209 }
210 GradEngine::register_operation(result.id(), input_ids, grad_fn);
211 }
212
213 result
214 }
215
216 /// Element-wise subtraction of a scalar from this tensor
217 ///
218 /// Performs element-wise subtraction of a scalar value: `output[i] = self[i] - scalar`
219 ///
220 /// # Arguments
221 ///
222 /// * `scalar` - The scalar value to subtract from each element
223 ///
224 /// # Returns
225 ///
226 /// A new tensor with the scalar subtracted from each element
227 ///
228 /// # Performance Characteristics
229 ///
230 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
231 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
232 /// - **Cache-friendly**: Linear memory access patterns
233 /// - **Mathematical Accuracy**: High-precision subtraction computation
234 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
235 ///
236 /// # Examples
237 ///
238 /// ## Basic Scalar Subtraction
239 ///
240 /// ```
241 /// use train_station::Tensor;
242 ///
243 /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
244 /// let b = a.sub_scalar(2.0);
245 /// assert_eq!(b.shape().dims, vec![3]);
246 /// assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
247 /// assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
248 /// assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
249 /// ```
250 ///
251 /// ## Negative Scalar Subtraction
252 ///
253 /// ```
254 /// use train_station::Tensor;
255 ///
256 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
257 /// let b = a.sub_scalar(-2.0); // Subtracting negative = adding
258 /// assert_eq!(b.shape().dims, vec![3]);
259 /// assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
260 /// assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
261 /// assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
262 /// ```
263 #[inline]
264 pub fn sub_scalar(&self, scalar: f32) -> Tensor {
265 let mut result = self.sub_scalar_optimized(scalar);
266
267 if self.requires_grad() && is_grad_enabled() {
268 result.set_requires_grad_internal(true);
269 let grad_fn = GradFn::Sub {
270 is_tensor_sub: false,
271 original_shapes: None, // Scalar case
272 };
273 result.set_grad_fn(grad_fn.clone());
274 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
275 }
276
277 result
278 }
279 /// Optimized tensor subtraction using SIMD when available
280 ///
281 /// Performs element-wise subtraction between tensors with identical shapes using
282 /// SIMD optimization when available and falling back to optimized scalar computation.
283 ///
284 /// # Arguments
285 ///
286 /// * `other` - The tensor to subtract from this tensor
287 ///
288 /// # Returns
289 ///
290 /// A new tensor with the result of the subtraction (self - other)
291 ///
292 /// # Performance Characteristics
293 ///
294 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
295 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
296 /// - **Cache-friendly**: Linear memory access patterns
297 /// - **Mathematical Accuracy**: High-precision subtraction computation
298 /// - **Zero-sized Handling**: Fast return for empty tensors
299 ///
300 /// # Implementation Details
301 ///
302 /// Automatically selects between SIMD and scalar implementations based on hardware
303 /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
304 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
305 /// parallelism.
306 ///
307 /// # Safety
308 ///
309 /// This operation assumes the tensors have the same shape.
310 #[inline]
311 pub(crate) fn sub_tensor_optimized(&self, other: &Tensor) -> Tensor {
312 assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
313
314 let mut output = Tensor::new(self.shape().dims.clone());
315
316 unsafe {
317 let a = self.as_ptr();
318 let b = other.as_ptr();
319 let dst = output.as_mut_ptr();
320
321 #[cfg(target_arch = "x86_64")]
322 {
323 // Use SIMD for better performance when available
324 if is_x86_feature_detected!("avx2") {
325 self.sub_tensors_simd_avx2_optimized(a, b, dst);
326 return output;
327 }
328 }
329
330 // Fallback to scalar operations with better cache usage
331 self.sub_tensors_scalar_optimized(a, b, dst);
332 }
333
334 output
335 }
336
337 /// AVX2-optimized tensor subtraction implementation
338 ///
339 /// Performs element-wise subtraction using AVX2 SIMD instructions for maximum
340 /// performance on x86_64 architectures with AVX2 support.
341 ///
342 /// # Arguments
343 ///
344 /// * `a` - Pointer to first tensor data
345 /// * `b` - Pointer to second tensor data
346 /// * `dst` - Pointer to output tensor data
347 ///
348 /// # Safety
349 ///
350 /// Requires valid pointers with sufficient memory for the tensor size.
351 /// All pointers must point to valid tensor data. Requires AVX2 support.
352 ///
353 /// # Performance Characteristics
354 ///
355 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
356 /// - **Memory Access**: Linear access patterns for cache efficiency
357 /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
358 /// - **Fallback**: Handles remaining elements with scalar operations
359 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
360 ///
361 /// # Implementation Details
362 ///
363 /// Uses AVX2 vector subtraction operations to compute a - b efficiently.
364 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
365 /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        // SAFETY: caller guarantees `a`, `b`, and `dst` each point to at least
        // `self.size()` valid f32 elements; every offset below stays < size.
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        // Unaligned loads/stores (`loadu`/`storeu`) are used because tensor
        // buffers are not guaranteed to be 32-byte aligned.
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sub_vec1 = _mm256_sub_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        // (fewer than 32 elements left, so at most 3 iterations here)
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sub_vec = _mm256_sub_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Handle final elements with unrolled loop
        // (fewer than 8 elements left, so at most one group of 4)
        let remaining = size - offset;
        let unroll_count = remaining / 4;
        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) - *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
            offset += 4;
        }

        // Scalar tail: the last 0..=3 elements.
        for i in offset..size {
            *dst.add(i) = *a.add(i) - *b.add(i);
        }
    }
425
426 /// Optimized scalar tensor subtraction fallback
427 ///
428 /// Performs element-wise subtraction using optimized scalar operations with
429 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
430 ///
431 /// # Arguments
432 ///
433 /// * `a` - Pointer to first tensor data
434 /// * `b` - Pointer to second tensor data
435 /// * `dst` - Pointer to output tensor data
436 ///
437 /// # Safety
438 ///
439 /// Requires valid pointers with sufficient memory for the tensor size.
440 /// All pointers must point to valid tensor data.
441 ///
442 /// # Performance Characteristics
443 ///
444 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
445 /// - **Memory Access**: Linear access patterns for cache efficiency
446 /// - **Fallback**: Handles remaining elements with scalar operations
447 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
448 /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
449 ///
450 /// # Implementation Details
451 ///
452 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
453 /// Processes elements in groups of 4 to improve instruction-level parallelism
454 /// and reduce loop overhead.
455 #[inline]
456 unsafe fn sub_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
457 let size = self.size();
458 let unroll_count = size / 4;
459 let mut offset = 0;
460
461 // Unrolled scalar loop for better performance
462 for _ in 0..unroll_count {
463 *dst.add(offset) = *a.add(offset) - *b.add(offset);
464 *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
465 *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
466 *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
467 offset += 4;
468 }
469
470 // Handle remaining elements
471 for i in offset..size {
472 *dst.add(i) = *a.add(i) - *b.add(i);
473 }
474 }
475
476 /// Internal optimized scalar subtraction operation
477 ///
478 /// Performs element-wise subtraction of a scalar from tensor using SIMD optimization
479 /// when available and falling back to optimized scalar computation.
480 ///
481 /// # Arguments
482 ///
483 /// * `scalar` - The scalar value to subtract from each element
484 ///
485 /// # Returns
486 ///
487 /// A new tensor with the scalar subtracted from each element
488 ///
489 /// # Performance Characteristics
490 ///
491 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
492 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
493 /// - **Cache-friendly**: Linear memory access patterns
494 /// - **Mathematical Accuracy**: High-precision subtraction computation
495 /// - **Zero-sized Handling**: Fast return for empty tensors
496 ///
497 /// # Implementation Details
498 ///
499 /// Automatically selects between SIMD and scalar implementations based on hardware
500 /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
501 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
502 /// parallelism.
503 #[inline]
504 pub(crate) fn sub_scalar_optimized(&self, scalar: f32) -> Tensor {
505 let mut output = Tensor::new(self.shape().dims.clone());
506
507 unsafe {
508 let src = self.as_ptr();
509 let dst = output.as_mut_ptr();
510
511 #[cfg(target_arch = "x86_64")]
512 {
513 // Use SIMD for better performance when available
514 if is_x86_feature_detected!("avx2") {
515 self.sub_scalar_simd_avx2_optimized(src, dst, scalar);
516 return output;
517 }
518 }
519
520 // Fallback to optimized scalar operations
521 self.sub_scalar_fallback_optimized(src, dst, scalar);
522 }
523
524 output
525 }
526
527 /// AVX2-optimized scalar subtraction implementation
528 ///
529 /// Performs element-wise subtraction of a scalar using AVX2 SIMD instructions for maximum
530 /// performance on x86_64 architectures with AVX2 support.
531 ///
532 /// # Arguments
533 ///
534 /// * `src` - Pointer to source tensor data
535 /// * `dst` - Pointer to output tensor data
536 /// * `scalar` - The scalar value to subtract from each element
537 ///
538 /// # Safety
539 ///
540 /// Requires valid pointers with sufficient memory for the tensor size.
541 /// All pointers must point to valid tensor data. Requires AVX2 support.
542 ///
543 /// # Performance Characteristics
544 ///
545 /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
546 /// - **Memory Access**: Linear access patterns for cache efficiency
547 /// - **Vector Operations**: Uses AVX2 subtraction instructions for computation
548 /// - **Fallback**: Handles remaining elements with scalar operations
549 /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
550 ///
551 /// # Implementation Details
552 ///
553 /// Uses AVX2 vector subtraction operations to compute src - scalar efficiently.
554 /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
555 /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        // SAFETY: caller guarantees `src` and `dst` each point to at least
        // `self.size()` valid f32 elements; every offset below stays < size.
        // Broadcast the scalar into all 8 lanes once, outside the loops.
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration
        let mut offset = 0;

        // Unrolled SIMD loop for better instruction throughput
        // Unaligned loads/stores (`loadu`/`storeu`) are used because tensor
        // buffers are not guaranteed to be 32-byte aligned.
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sub_vec1 = _mm256_sub_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks
        // (fewer than 32 elements left, so at most 3 iterations here)
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sub_vec = _mm256_sub_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Handle final elements
        // (scalar tail: the last 0..=7 elements)
        for i in offset..size {
            *dst.add(i) = *src.add(i) - scalar;
        }
    }
600
601 /// Optimized scalar subtraction fallback
602 ///
603 /// Performs element-wise subtraction of a scalar using optimized scalar operations with
604 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
605 ///
606 /// # Arguments
607 ///
608 /// * `src` - Pointer to source tensor data
609 /// * `dst` - Pointer to output tensor data
610 /// * `scalar` - The scalar value to subtract from each element
611 ///
612 /// # Safety
613 ///
614 /// Requires valid pointers with sufficient memory for the tensor size.
615 /// All pointers must point to valid tensor data.
616 ///
617 /// # Performance Characteristics
618 ///
619 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
620 /// - **Memory Access**: Linear access patterns for cache efficiency
621 /// - **Fallback**: Handles remaining elements with scalar operations
622 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
623 /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
624 ///
625 /// # Implementation Details
626 ///
627 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
628 /// Processes elements in groups of 4 to improve instruction-level parallelism
629 /// and reduce loop overhead.
630 #[inline]
631 unsafe fn sub_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
632 let size = self.size();
633 let unroll_count = size / 4;
634 let mut offset = 0;
635
636 // Unrolled scalar operations
637 for _ in 0..unroll_count {
638 *dst.add(offset) = *src.add(offset) - scalar;
639 *dst.add(offset + 1) = *src.add(offset + 1) - scalar;
640 *dst.add(offset + 2) = *src.add(offset + 2) - scalar;
641 *dst.add(offset + 3) = *src.add(offset + 3) - scalar;
642 offset += 4;
643 }
644
645 for i in offset..size {
646 *dst.add(i) = *src.add(i) - scalar;
647 }
648 }
649}
650
651#[cfg(test)]
652mod tests {
653 use super::*;
654
655 #[test]
656 fn test_tensor_subtraction() {
657 let mut a = Tensor::ones(vec![2, 3]);
658 a.fill(5.0); // Create a tensor with all 5.0s
659 let mut b = Tensor::ones(vec![2, 3]);
660 b.fill(2.0); // Create a tensor with all 2.0s
661 let result = a.sub_tensor_optimized(&b);
662
663 assert_eq!(result.shape().dims, vec![2, 3]);
664 assert_eq!(result.size(), 6);
665
666 // Check that all values are 3.0 (5.0 - 2.0)
667 unsafe {
668 for i in 0..result.size() {
669 assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
670 }
671 }
672 }
673
674 #[test]
675 fn test_scalar_subtraction() {
676 let mut tensor = Tensor::ones(vec![2, 2]);
677 tensor.fill(10.0); // Create a tensor with all 10.0s
678 let result = tensor.sub_scalar_optimized(3.0);
679
680 assert_eq!(result.shape().dims, vec![2, 2]);
681 assert_eq!(result.size(), 4);
682
683 // Check that all values are 7.0 (10.0 - 3.0)
684 unsafe {
685 for i in 0..result.size() {
686 assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
687 }
688 }
689 }
690
691 #[test]
692 fn test_negative_subtraction() {
693 let mut a = Tensor::ones(vec![2, 2]);
694 a.fill(2.0); // Create a tensor with all 2.0s
695 let mut b = Tensor::ones(vec![2, 2]);
696 b.fill(5.0); // Create a tensor with all 5.0s
697 let result = a.sub_tensor_optimized(&b);
698
699 // Check that all values are -3.0 (2.0 - 5.0)
700 unsafe {
701 for i in 0..result.size() {
702 assert!((result.as_ptr().add(i).read() - (-3.0)).abs() < 1e-6);
703 }
704 }
705 }
706
707 #[test]
708 fn test_scalar_negative_subtraction() {
709 let mut tensor = Tensor::ones(vec![2, 2]);
710 tensor.fill(3.0); // Create a tensor with all 3.0s
711 let result = tensor.sub_scalar_optimized(8.0);
712
713 // Check that all values are -5.0 (3.0 - 8.0)
714 unsafe {
715 for i in 0..result.size() {
716 assert!((result.as_ptr().add(i).read() - (-5.0)).abs() < 1e-6);
717 }
718 }
719 }
720
721 #[test]
722 #[should_panic(expected = "Tensor shapes must match")]
723 fn test_mismatched_shapes() {
724 let a = Tensor::ones(vec![2, 3]);
725 let b = Tensor::ones(vec![3, 2]);
726 a.sub_tensor_optimized(&b);
727 }
728
729 #[test]
730 fn test_edge_cases() {
731 // Test zero subtraction
732 let a = Tensor::ones(vec![3]);
733 let b = Tensor::zeros(vec![3]);
734 let result = a.sub_tensor_optimized(&b);
735
736 unsafe {
737 for i in 0..result.size() {
738 assert!((result.as_ptr().add(i).read() - 1.0).abs() < 1e-6);
739 }
740 }
741
742 // Test self subtraction
743 let mut tensor = Tensor::ones(vec![3]);
744 tensor.fill(5.0);
745 let result = tensor.sub_tensor_optimized(&tensor);
746
747 unsafe {
748 for i in 0..result.size() {
749 assert!(result.as_ptr().add(i).read().abs() < 1e-6);
750 }
751 }
752 }
753
754 #[test]
755 fn test_large_tensor_subtraction() {
756 let mut a = Tensor::ones(vec![100, 100]);
757 a.fill(10.0);
758 let mut b = Tensor::ones(vec![100, 100]);
759 b.fill(3.0);
760 let result = a.sub_tensor_optimized(&b);
761
762 assert_eq!(result.size(), 10000);
763
764 // Check some values are 7.0 (10.0 - 3.0)
765 unsafe {
766 for i in (0..result.size()).step_by(1000) {
767 assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
768 }
769 }
770 }
771
772 #[test]
773 fn test_negate_inplace() {
774 let mut tensor = Tensor::ones(vec![2, 2]);
775 tensor.fill(5.0);
776
777 // Check initial values
778 unsafe {
779 for i in 0..tensor.size() {
780 let val = tensor.as_ptr().add(i).read();
781 assert!((val - 5.0).abs() < 1e-6, "Expected 5.0, got {}", val);
782 }
783 }
784
785 tensor.negate_inplace();
786
787 // Check negated values
788 unsafe {
789 for i in 0..tensor.size() {
790 let val = tensor.as_ptr().add(i).read();
791 assert!(
792 (val - (-5.0)).abs() < 1e-6,
793 "Expected -5.0 after negation, got {}",
794 val
795 );
796 }
797 }
798 }
799
800 #[test]
801 fn test_subtraction_with_gradtrack() {
802 // Test scalar subtraction with gradtrack
803 let a = Tensor::ones(vec![2, 3]).with_requires_grad();
804 let mut result = a.sub_scalar(5.0);
805
806 // Check result values: 1.0 - 5.0 = -4.0
807 unsafe {
808 for i in 0..result.size() {
809 let val = result.as_ptr().add(i).read();
810 assert!((val - (-4.0)).abs() < 1e-6, "Expected -4.0, got {}", val);
811 }
812 }
813
814 result.backward(None);
815
816 // Check gradient: d/dx(x - c) = 1
817 if let Some(grad) = a.grad_by_value() {
818 unsafe {
819 for i in 0..grad.size() {
820 let val = grad.as_ptr().add(i).read();
821 assert!(
822 (val - 1.0).abs() < 1e-6,
823 "Expected gradient 1.0, got {}",
824 val
825 );
826 }
827 }
828 } else {
829 panic!("No gradient computed for scalar subtraction!");
830 }
831
832 // Test tensor subtraction with gradtrack
833 let a = Tensor::ones(vec![2, 2]).with_requires_grad();
834 let mut b = Tensor::ones(vec![2, 2]);
835 b.fill(3.0);
836 let b = b.with_requires_grad();
837
838 let mut result = a.sub_tensor(&b);
839
840 // Check result values: 1.0 - 3.0 = -2.0
841 unsafe {
842 for i in 0..result.size() {
843 let val = result.as_ptr().add(i).read();
844 assert!((val - (-2.0)).abs() < 1e-6, "Expected -2.0, got {}", val);
845 }
846 }
847
848 result.backward(None);
849
850 // Check gradients: d/dx(x - y) = 1, d/dy(x - y) = -1
851 if let Some(grad_a) = a.grad_by_value() {
852 unsafe {
853 for i in 0..grad_a.size() {
854 let val = grad_a.as_ptr().add(i).read();
855 assert!(
856 (val - 1.0).abs() < 1e-6,
857 "Expected gradient A = 1.0, got {}",
858 val
859 );
860 }
861 }
862 } else {
863 panic!("No gradient A computed for tensor subtraction!");
864 }
865
866 if let Some(grad_b) = b.grad_by_value() {
867 unsafe {
868 for i in 0..grad_b.size() {
869 let val = grad_b.as_ptr().add(i).read();
870 println!("Debug: grad_b[{}] = {}", i, val);
871 assert!(
872 (val - (-1.0)).abs() < 1e-6,
873 "Expected gradient B = -1.0, got {}",
874 val
875 );
876 }
877 }
878 } else {
879 panic!("No gradient B computed for tensor subtraction!");
880 }
881 }
882
883 #[test]
884 fn test_mixed_add_sub_operations_with_gradtrack() {
885 // Test complex computation graph: (a + scalar1) - b + c - scalar2
886 let a = Tensor::ones(vec![2, 2]).with_requires_grad();
887 let mut b = Tensor::ones(vec![2, 2]);
888 b.fill(2.0);
889 let b = b.with_requires_grad();
890 let mut c = Tensor::ones(vec![2, 2]);
891 c.fill(3.0);
892 let c = c.with_requires_grad();
893
894 let scalar1 = 5.0;
895 let scalar2 = 1.0;
896
897 // Complex computation: (a + scalar1) - b + c - scalar2
898 // Expected: (1 + 5) - 2 + 3 - 1 = 6
899 let step1 = a.add_scalar(scalar1); // a + 5 = 6
900 let step2 = step1.sub_tensor(&b); // 6 - 2 = 4
901 let step3 = step2.add_tensor(&c); // 4 + 3 = 7
902 let mut result = step3.sub_scalar(scalar2); // 7 - 1 = 6
903
904 // Check result values
905 unsafe {
906 for i in 0..result.size() {
907 let val = result.as_ptr().add(i).read();
908 assert!((val - 6.0).abs() < 1e-6, "Expected 6.0, got {}", val);
909 }
910 }
911
912 result.backward(None);
913
914 // Check gradients
915 // For computation: f = (a + 5) - b + c - 1
916 // df/da = 1, df/db = -1, df/dc = 1
917
918 if let Some(grad_a) = a.grad_by_value() {
919 unsafe {
920 for i in 0..grad_a.size() {
921 let val = grad_a.as_ptr().add(i).read();
922 assert!(
923 (val - 1.0).abs() < 1e-6,
924 "Expected gradient A = 1.0, got {}",
925 val
926 );
927 }
928 }
929 } else {
930 panic!("No gradient A computed for mixed operations!");
931 }
932
933 if let Some(grad_b) = b.grad_by_value() {
934 unsafe {
935 for i in 0..grad_b.size() {
936 let val = grad_b.as_ptr().add(i).read();
937 assert!(
938 (val - (-1.0)).abs() < 1e-6,
939 "Expected gradient B = -1.0, got {}",
940 val
941 );
942 }
943 }
944 } else {
945 panic!("No gradient B computed for mixed operations!");
946 }
947
948 if let Some(grad_c) = c.grad_by_value() {
949 unsafe {
950 for i in 0..grad_c.size() {
951 let val = grad_c.as_ptr().add(i).read();
952 assert!(
953 (val - 1.0).abs() < 1e-6,
954 "Expected gradient C = 1.0, got {}",
955 val
956 );
957 }
958 }
959 } else {
960 panic!("No gradient C computed for mixed operations!");
961 }
962
963 println!("Mixed add/sub operations with gradtrack test passed!");
964 println!("✓ Complex computation graph: (a + 5) - b + c - 1 = 6");
965 println!("✓ Gradients: da/df = 1, db/df = -1, dc/df = 1");
966 }
967
968 #[test]
969 fn test_sub_broadcasting_gradients_basic() {
970 use crate::gradtrack::clear_gradients;
971 clear_gradients();
972
973 // Test case: [2, 3] - [1, 3] -> [2, 3]
974 // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)
975 // For subtraction: d/da (a - b) = 1, d/db (a - b) = -1
976
977 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
978 .unwrap()
979 .with_requires_grad();
980 let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
981 .unwrap()
982 .with_requires_grad();
983
984 let mut result = a.sub_tensor(&b);
985 assert_eq!(result.shape().dims, vec![2, 3]);
986
987 // Set upstream gradient as ones
988 result.backward(None);
989
990 let grad_a = a.grad_by_value().expect("grad_a should exist");
991 let grad_b = b.grad_by_value().expect("grad_b should exist");
992
993 println!(
994 "Original shapes: a={:?}, b={:?}",
995 a.shape().dims,
996 b.shape().dims
997 );
998 println!(
999 "Gradient shapes: grad_a={:?}, grad_b={:?}",
1000 grad_a.shape().dims,
1001 grad_b.shape().dims
1002 );
1003
1004 // grad_a should have same shape as a: [2, 3]
1005 assert_eq!(
1006 grad_a.shape().dims,
1007 vec![2, 3],
1008 "grad_a should match original shape of a"
1009 );
1010
1011 // grad_b should have same shape as b: [1, 3]
1012 // This requires summing over the broadcasted dimension
1013 assert_eq!(
1014 grad_b.shape().dims,
1015 vec![1, 3],
1016 "grad_b should match original shape of b"
1017 );
1018
1019 // All gradients should be 1.0 for grad_a (d/da (a - b) = 1)
1020 for i in 0..grad_a.size() {
1021 let val = unsafe { *grad_a.as_ptr().add(i) };
1022 assert!(
1023 (val - 1.0).abs() < 1e-6,
1024 "grad_a[{}] = {} should be 1.0",
1025 i,
1026 val
1027 );
1028 }
1029
1030 // grad_b should be [-2.0, -2.0, -2.0] (sum over broadcast dim, then negated)
1031 let expected_grad_b = [-2.0, -2.0, -2.0]; // -1 * 2 rows = -2
1032 for (i, &expected) in expected_grad_b.iter().enumerate() {
1033 let val = unsafe { *grad_b.as_ptr().add(i) };
1034 assert!(
1035 (val - expected).abs() < 1e-6,
1036 "grad_b[{}] = {} should be {}",
1037 i,
1038 val,
1039 expected
1040 );
1041 }
1042 }
1043
1044 #[test]
1045 fn test_sub_scalar_broadcasting_gradients() {
1046 use crate::gradtrack::clear_gradients;
1047 clear_gradients();
1048
1049 // Test case: [2, 3] - [1] -> [2, 3]
1050 // grad_a should be [2, 3], grad_b should be [1] (summed over all dims, then negated)
1051
1052 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1053 .unwrap()
1054 .with_requires_grad();
1055 let b = Tensor::from_slice(&[0.5], vec![1])
1056 .unwrap()
1057 .with_requires_grad();
1058
1059 let mut result = a.sub_tensor(&b);
1060 result.backward(None);
1061
1062 let grad_a = a.grad_by_value().expect("grad_a should exist");
1063 let grad_b = b.grad_by_value().expect("grad_b should exist");
1064
1065 // grad_a should have same shape as a: [2, 3]
1066 assert_eq!(grad_a.shape().dims, vec![2, 3]);
1067
1068 // grad_b should have same shape as b: [1] and sum to -6.0
1069 println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
1070 assert_eq!(grad_b.shape().dims, vec![1]);
1071
1072 // grad_b should be -6.0 (sum over all 6 elements, then negated)
1073 let val = unsafe { *grad_b.as_ptr() };
1074 assert!(
1075 (val - (-6.0)).abs() < 1e-6,
1076 "grad_b = {} should be -6.0",
1077 val
1078 );
1079 }
1080}