train_station/tensor/ops/sub.rs
1//! Subtraction operations for tensors
2//!
3//! Provides element-wise subtraction operations following PyTorch conventions with
4//! comprehensive GradTrack support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Tensor Subtraction**: `sub_tensor()` - Element-wise subtraction with broadcasting support
9//! - **Scalar Subtraction**: `sub_scalar()` - Subtraction of scalar from tensor
10//! - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Performance Optimization**: 4x unrolled SIMD operations with scalar fallback
14//!
15//! # Mathematical Properties
16//!
17//! The subtraction operations have the following properties:
18//! - **Tensor-Tensor**: `output[i] = a[i] - b[i]` with broadcasting
19//! - **Tensor-Scalar**: `output[i] = a[i] - scalar` for all elements
20//! - **Commutativity**: Subtraction is not commutative (a - b ≠ b - a)
21//! - **Associativity**: Subtraction is not associative ((a - b) - c ≠ a - (b - c))
22//! - **Gradient**: For tensor-tensor: ∂(a-b)/∂a = 1, ∂(a-b)/∂b = -1
23//! - **Broadcasting**: Follows NumPy broadcasting rules for shape compatibility
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
28//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Broadcasting Overhead**: Minimal overhead for compatible shapes
31//! - **GradTrack Optimization**: Efficient automatic differentiation with NoGradTrack support
32
33use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
34use crate::tensor::core::Tensor;
35
36// SIMD optimizations for performance-critical operations
37#[cfg(target_arch = "x86_64")]
38use std::arch::x86_64::*;
39
40// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
41
42impl Tensor {
43 /// Element-wise subtraction with another tensor with broadcasting support
44 ///
45 /// Performs element-wise subtraction with automatic broadcasting: `output[i] = self[i] - other[i]`
46 ///
47 /// Broadcasting enables subtraction between tensors of different but compatible shapes.
48 /// Compatible shapes follow NumPy broadcasting rules:
49 /// - Dimensions are aligned from the rightmost dimension
50 /// - Dimensions are compatible if they are equal, or one of them is 1
51 /// - Missing dimensions are treated as 1
52 ///
53 /// # Arguments
54 ///
55 /// * `other` - Tensor to subtract. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 ///
59 /// A new tensor containing the element-wise difference with broadcast result shape
60 ///
61 /// # Performance Characteristics
62 ///
63 /// - **Fast Path**: Optimized for identical shapes to avoid broadcasting overhead
64 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
65 /// - **Broadcasting**: Efficient broadcasting for compatible shapes
66 /// - **Cache-friendly**: Linear memory access patterns
67 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
68 ///
69 /// # Implementation Details
70 ///
71 /// Uses a fast path for identical shapes to avoid broadcasting overhead.
72 /// For different shapes, performs broadcasting followed by optimized element-wise subtraction.
73 /// Automatically selects between SIMD and scalar implementations based on hardware capabilities.
74 ///
75 /// # Examples
76 ///
77 /// ## Same Shape Subtraction
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
83 /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84 /// let c = a.sub_tensor(&b);
85 /// assert_eq!(c.shape().dims, vec![3]);
86 /// assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
87 /// assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
88 /// assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
89 /// ```
90 ///
91 /// ## Broadcasting Subtraction
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
97 /// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
98 /// let c = a.sub_tensor(&b);
99 /// assert_eq!(c.shape().dims, vec![2, 3]);
100 /// // Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
101 /// assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
102 /// assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
103 /// assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
104 /// ```
105 ///
106 /// ## Scalar Subtraction
107 ///
108 /// ```
109 /// use train_station::Tensor;
110 ///
111 /// let a = Tensor::ones(vec![2, 3]);
112 /// let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
113 /// let c = a.sub_tensor(&b);
114 /// assert_eq!(c.shape().dims, vec![2, 3]);
115 /// assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
116 /// ```
117 ///
118 /// # Panics
119 /// Panics if tensor shapes are not broadcast-compatible
120 #[inline]
121 #[track_caller]
122 pub fn sub_tensor(&self, other: &Tensor) -> Tensor {
123 // Check if shapes are identical for fast path
124 if self.shape().dims == other.shape().dims {
125 return self.sub_tensor_same_shape(other);
126 }
127
128 // Use broadcasting for different shapes
129 let (broadcast_self, broadcast_other, _result_shape) =
130 self.broadcast_with(other).unwrap_or_else(|e| {
131 panic!(
132 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
133 self.shape().dims,
134 other.shape().dims,
135 e
136 );
137 });
138
139 // Perform element-wise subtraction on broadcasted tensors
140 let mut result = broadcast_self.sub_tensor_optimized(&broadcast_other);
141
142 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
143 result.set_requires_grad_internal(true);
144 let grad_fn = GradFn::Sub {
145 is_tensor_sub: true,
146 original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
147 };
148 result.set_grad_fn(grad_fn.clone());
149
150 let mut input_ids = Vec::with_capacity(2);
151 if self.requires_grad() {
152 input_ids.push(self.id());
153 }
154 if other.requires_grad() {
155 input_ids.push(other.id());
156 }
157 GradEngine::register_operation(result.id(), input_ids, grad_fn);
158 }
159
160 result
161 }
162
163 /// Element-wise subtraction for tensors with identical shapes (fast path)
164 ///
165 /// This is an optimized path for tensors that already have the same shape,
166 /// avoiding the overhead of broadcasting computation.
167 ///
168 /// # Arguments
169 ///
170 /// * `other` - Tensor to subtract, must have the same shape as self
171 ///
172 /// # Returns
173 ///
174 /// A new tensor containing the element-wise difference
175 ///
176 /// # Performance Characteristics
177 ///
178 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
179 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
180 /// - **Cache-friendly**: Linear memory access patterns
181 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
182 ///
183 /// # Implementation Details
184 ///
185 /// This method is used internally by `sub_tensor()` when tensors have identical shapes.
186 /// It bypasses the broadcasting logic and directly calls the optimized subtraction implementation.
187 #[inline]
188 fn sub_tensor_same_shape(&self, other: &Tensor) -> Tensor {
189 assert_eq!(
190 self.shape(),
191 other.shape(),
192 "Tensor shapes must match for same-shape subtraction"
193 );
194 let mut result = self.sub_tensor_optimized(other);
195
196 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
197 result.set_requires_grad_internal(true);
198 let grad_fn = GradFn::Sub {
199 is_tensor_sub: true,
200 original_shapes: None, // Same shape case
201 };
202 result.set_grad_fn(grad_fn.clone());
203
204 let mut input_ids = Vec::with_capacity(2);
205 if self.requires_grad() {
206 input_ids.push(self.id());
207 }
208 if other.requires_grad() {
209 input_ids.push(other.id());
210 }
211 GradEngine::register_operation(result.id(), input_ids, grad_fn);
212 }
213
214 result
215 }
216
217 /// Element-wise subtraction of a scalar from this tensor
218 ///
219 /// Performs element-wise subtraction of a scalar value: `output[i] = self[i] - scalar`
220 ///
221 /// # Arguments
222 ///
223 /// * `scalar` - The scalar value to subtract from each element
224 ///
225 /// # Returns
226 ///
227 /// A new tensor with the scalar subtracted from each element
228 ///
229 /// # Performance Characteristics
230 ///
231 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
232 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
233 /// - **Cache-friendly**: Linear memory access patterns
234 /// - **Mathematical Accuracy**: High-precision subtraction computation
235 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
236 ///
237 /// # Examples
238 ///
239 /// ## Basic Scalar Subtraction
240 ///
241 /// ```
242 /// use train_station::Tensor;
243 ///
244 /// let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
245 /// let b = a.sub_scalar(2.0);
246 /// assert_eq!(b.shape().dims, vec![3]);
247 /// assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
248 /// assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
249 /// assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
250 /// ```
251 ///
252 /// ## Negative Scalar Subtraction
253 ///
254 /// ```
255 /// use train_station::Tensor;
256 ///
257 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
258 /// let b = a.sub_scalar(-2.0); // Subtracting negative = adding
259 /// assert_eq!(b.shape().dims, vec![3]);
260 /// assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
261 /// assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
262 /// assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
263 /// ```
264 #[inline]
265 #[track_caller]
266 pub fn sub_scalar(&self, scalar: f32) -> Tensor {
267 let mut result = self.sub_scalar_optimized(scalar);
268
269 if self.requires_grad() && is_grad_enabled() {
270 result.set_requires_grad_internal(true);
271 let grad_fn = GradFn::Sub {
272 is_tensor_sub: false,
273 original_shapes: None, // Scalar case
274 };
275 result.set_grad_fn(grad_fn.clone());
276 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
277 }
278
279 result
280 }
281 /// Optimized tensor subtraction using SIMD when available
282 ///
283 /// Performs element-wise subtraction between tensors with identical shapes using
284 /// SIMD optimization when available and falling back to optimized scalar computation.
285 ///
286 /// # Arguments
287 ///
288 /// * `other` - The tensor to subtract from this tensor
289 ///
290 /// # Returns
291 ///
292 /// A new tensor with the result of the subtraction (self - other)
293 ///
294 /// # Performance Characteristics
295 ///
296 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
297 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
298 /// - **Cache-friendly**: Linear memory access patterns
299 /// - **Mathematical Accuracy**: High-precision subtraction computation
300 /// - **Zero-sized Handling**: Fast return for empty tensors
301 ///
302 /// # Implementation Details
303 ///
304 /// Automatically selects between SIMD and scalar implementations based on hardware
305 /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
306 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
307 /// parallelism.
308 ///
309 /// # Safety
310 ///
311 /// This operation assumes the tensors have the same shape.
312 #[inline]
313 pub(crate) fn sub_tensor_optimized(&self, other: &Tensor) -> Tensor {
314 assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
315
316 let mut output = Tensor::new(self.shape().dims.clone());
317
318 unsafe {
319 let a = self.as_ptr();
320 let b = other.as_ptr();
321 let dst = output.as_mut_ptr();
322
323 #[cfg(target_arch = "x86_64")]
324 {
325 // Use SIMD for better performance when available
326 if is_x86_feature_detected!("avx2") {
327 self.sub_tensors_simd_avx2_optimized(a, b, dst);
328 return output;
329 }
330 }
331
332 // Fallback to scalar operations with better cache usage
333 self.sub_tensors_scalar_optimized(a, b, dst);
334 }
335
336 output
337 }
338
    /// AVX2 kernel for element-wise tensor subtraction: `dst[i] = a[i] - b[i]`.
    ///
    /// Drains the input in four stages so every element is written exactly once
    /// regardless of `self.size()`: 32-element blocks (four 256-bit vectors,
    /// 4x unrolled), single 8-element vectors, a 4x-unrolled scalar loop, and
    /// finally one element at a time.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to the minuend tensor's data (`self.size()` elements read)
    /// * `b` - Pointer to the subtrahend tensor's data (`self.size()` elements read)
    /// * `dst` - Pointer to the output buffer (`self.size()` elements written)
    ///
    /// # Safety
    ///
    /// * Each pointer must be valid for `self.size()` `f32` elements.
    /// * Requires AVX2; the dispatch site (`sub_tensor_optimized`) verifies this
    ///   with `is_x86_feature_detected!("avx2")` before calling.
    /// * Loads and stores are unaligned (`_mm256_loadu_ps` / `_mm256_storeu_ps`),
    ///   so no alignment beyond `f32` is required.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Number of full 32-element (four-vector) blocks
        let mut offset = 0;

        // Stage 1: 4x-unrolled SIMD loop, 32 elements per iteration.
        for _ in 0..simd_count {
            // Four independent AVX2 vectors per pass for instruction throughput.
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let sub_vec1 = _mm256_sub_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Stage 2: leftover full 8-element vectors (at most three after stage 1).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let sub_vec = _mm256_sub_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Stage 3: 4x-unrolled scalar loop over the (< 8) remaining elements.
        let remaining = size - offset;
        let unroll_count = remaining / 4;
        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) - *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
            offset += 4;
        }

        // Stage 4: final (< 4) tail elements.
        for i in offset..size {
            *dst.add(i) = *a.add(i) - *b.add(i);
        }
    }
427
428 /// Optimized scalar tensor subtraction fallback
429 ///
430 /// Performs element-wise subtraction using optimized scalar operations with
431 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
432 ///
433 /// # Arguments
434 ///
435 /// * `a` - Pointer to first tensor data
436 /// * `b` - Pointer to second tensor data
437 /// * `dst` - Pointer to output tensor data
438 ///
439 /// # Safety
440 ///
441 /// Requires valid pointers with sufficient memory for the tensor size.
442 /// All pointers must point to valid tensor data.
443 ///
444 /// # Performance Characteristics
445 ///
446 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
447 /// - **Memory Access**: Linear access patterns for cache efficiency
448 /// - **Fallback**: Handles remaining elements with scalar operations
449 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
450 /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
451 ///
452 /// # Implementation Details
453 ///
454 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
455 /// Processes elements in groups of 4 to improve instruction-level parallelism
456 /// and reduce loop overhead.
457 #[inline]
458 unsafe fn sub_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
459 let size = self.size();
460 let unroll_count = size / 4;
461 let mut offset = 0;
462
463 // Unrolled scalar loop for better performance
464 for _ in 0..unroll_count {
465 *dst.add(offset) = *a.add(offset) - *b.add(offset);
466 *dst.add(offset + 1) = *a.add(offset + 1) - *b.add(offset + 1);
467 *dst.add(offset + 2) = *a.add(offset + 2) - *b.add(offset + 2);
468 *dst.add(offset + 3) = *a.add(offset + 3) - *b.add(offset + 3);
469 offset += 4;
470 }
471
472 // Handle remaining elements
473 for i in offset..size {
474 *dst.add(i) = *a.add(i) - *b.add(i);
475 }
476 }
477
478 /// Internal optimized scalar subtraction operation
479 ///
480 /// Performs element-wise subtraction of a scalar from tensor using SIMD optimization
481 /// when available and falling back to optimized scalar computation.
482 ///
483 /// # Arguments
484 ///
485 /// * `scalar` - The scalar value to subtract from each element
486 ///
487 /// # Returns
488 ///
489 /// A new tensor with the scalar subtracted from each element
490 ///
491 /// # Performance Characteristics
492 ///
493 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
494 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
495 /// - **Cache-friendly**: Linear memory access patterns
496 /// - **Mathematical Accuracy**: High-precision subtraction computation
497 /// - **Zero-sized Handling**: Fast return for empty tensors
498 ///
499 /// # Implementation Details
500 ///
501 /// Automatically selects between SIMD and scalar implementations based on hardware
502 /// capabilities. SIMD implementation uses AVX2 vector subtraction operations for optimal
503 /// performance. Scalar implementation uses 4x unrolling for better instruction-level
504 /// parallelism.
505 #[inline]
506 pub(crate) fn sub_scalar_optimized(&self, scalar: f32) -> Tensor {
507 let mut output = Tensor::new(self.shape().dims.clone());
508
509 unsafe {
510 let src = self.as_ptr();
511 let dst = output.as_mut_ptr();
512
513 #[cfg(target_arch = "x86_64")]
514 {
515 // Use SIMD for better performance when available
516 if is_x86_feature_detected!("avx2") {
517 self.sub_scalar_simd_avx2_optimized(src, dst, scalar);
518 return output;
519 }
520 }
521
522 // Fallback to optimized scalar operations
523 self.sub_scalar_fallback_optimized(src, dst, scalar);
524 }
525
526 output
527 }
528
    /// AVX2 kernel for scalar subtraction: `dst[i] = src[i] - scalar`.
    ///
    /// The scalar is broadcast into a 256-bit vector once, then the input is
    /// drained in three stages: 32-element blocks (four vectors, 4x unrolled),
    /// single 8-element vectors, and a final one-at-a-time scalar loop.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to the source tensor's data (`self.size()` elements read)
    /// * `dst` - Pointer to the output buffer (`self.size()` elements written)
    /// * `scalar` - Value subtracted from each element
    ///
    /// # Safety
    ///
    /// * `src` and `dst` must each be valid for `self.size()` `f32` elements.
    /// * Requires AVX2; the dispatch site (`sub_scalar_optimized`) verifies this
    ///   with `is_x86_feature_detected!("avx2")` before calling.
    /// * Loads and stores are unaligned (`_mm256_loadu_ps` / `_mm256_storeu_ps`),
    ///   so no alignment beyond `f32` is required.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        // Broadcast the scalar across all 8 lanes once, up front.
        let scalar_vec = _mm256_set1_ps(scalar);
        let size = self.size();
        let simd_count = size / 32; // Number of full 32-element (four-vector) blocks
        let mut offset = 0;

        // Stage 1: 4x-unrolled SIMD loop, 32 elements per iteration.
        for _ in 0..simd_count {
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let sub_vec1 = _mm256_sub_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let sub_vec2 = _mm256_sub_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), sub_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let sub_vec3 = _mm256_sub_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), sub_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let sub_vec4 = _mm256_sub_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), sub_vec4);

            offset += 32;
        }

        // Stage 2: leftover full 8-element vectors (at most three after stage 1).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let sub_vec = _mm256_sub_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), sub_vec);
            offset += 8;
        }

        // Stage 3: final (< 8) tail elements, one at a time.
        for i in offset..size {
            *dst.add(i) = *src.add(i) - scalar;
        }
    }
602
603 /// Optimized scalar subtraction fallback
604 ///
605 /// Performs element-wise subtraction of a scalar using optimized scalar operations with
606 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
607 ///
608 /// # Arguments
609 ///
610 /// * `src` - Pointer to source tensor data
611 /// * `dst` - Pointer to output tensor data
612 /// * `scalar` - The scalar value to subtract from each element
613 ///
614 /// # Safety
615 ///
616 /// Requires valid pointers with sufficient memory for the tensor size.
617 /// All pointers must point to valid tensor data.
618 ///
619 /// # Performance Characteristics
620 ///
621 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
622 /// - **Memory Access**: Linear access patterns for cache efficiency
623 /// - **Fallback**: Handles remaining elements with scalar operations
624 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
625 /// - **Mathematical Accuracy**: High-precision scalar subtraction computation
626 ///
627 /// # Implementation Details
628 ///
629 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
630 /// Processes elements in groups of 4 to improve instruction-level parallelism
631 /// and reduce loop overhead.
632 #[inline]
633 unsafe fn sub_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
634 let size = self.size();
635 let unroll_count = size / 4;
636 let mut offset = 0;
637
638 // Unrolled scalar operations
639 for _ in 0..unroll_count {
640 *dst.add(offset) = *src.add(offset) - scalar;
641 *dst.add(offset + 1) = *src.add(offset + 1) - scalar;
642 *dst.add(offset + 2) = *src.add(offset + 2) - scalar;
643 *dst.add(offset + 3) = *src.add(offset + 3) - scalar;
644 offset += 4;
645 }
646
647 for i in offset..size {
648 *dst.add(i) = *src.add(i) - scalar;
649 }
650 }
651}
652
653#[cfg(test)]
654mod tests {
655 use super::*;
656
657 #[test]
658 fn test_tensor_subtraction() {
659 let mut a = Tensor::ones(vec![2, 3]);
660 a.fill(5.0); // Create a tensor with all 5.0s
661 let mut b = Tensor::ones(vec![2, 3]);
662 b.fill(2.0); // Create a tensor with all 2.0s
663 let result = a.sub_tensor_optimized(&b);
664
665 assert_eq!(result.shape().dims, vec![2, 3]);
666 assert_eq!(result.size(), 6);
667
668 // Check that all values are 3.0 (5.0 - 2.0)
669 unsafe {
670 for i in 0..result.size() {
671 assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
672 }
673 }
674 }
675
676 #[test]
677 fn test_scalar_subtraction() {
678 let mut tensor = Tensor::ones(vec![2, 2]);
679 tensor.fill(10.0); // Create a tensor with all 10.0s
680 let result = tensor.sub_scalar_optimized(3.0);
681
682 assert_eq!(result.shape().dims, vec![2, 2]);
683 assert_eq!(result.size(), 4);
684
685 // Check that all values are 7.0 (10.0 - 3.0)
686 unsafe {
687 for i in 0..result.size() {
688 assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
689 }
690 }
691 }
692
693 #[test]
694 fn test_negative_subtraction() {
695 let mut a = Tensor::ones(vec![2, 2]);
696 a.fill(2.0); // Create a tensor with all 2.0s
697 let mut b = Tensor::ones(vec![2, 2]);
698 b.fill(5.0); // Create a tensor with all 5.0s
699 let result = a.sub_tensor_optimized(&b);
700
701 // Check that all values are -3.0 (2.0 - 5.0)
702 unsafe {
703 for i in 0..result.size() {
704 assert!((result.as_ptr().add(i).read() - (-3.0)).abs() < 1e-6);
705 }
706 }
707 }
708
709 #[test]
710 fn test_scalar_negative_subtraction() {
711 let mut tensor = Tensor::ones(vec![2, 2]);
712 tensor.fill(3.0); // Create a tensor with all 3.0s
713 let result = tensor.sub_scalar_optimized(8.0);
714
715 // Check that all values are -5.0 (3.0 - 8.0)
716 unsafe {
717 for i in 0..result.size() {
718 assert!((result.as_ptr().add(i).read() - (-5.0)).abs() < 1e-6);
719 }
720 }
721 }
722
723 #[test]
724 #[should_panic(expected = "Tensor shapes must match")]
725 fn test_mismatched_shapes() {
726 let a = Tensor::ones(vec![2, 3]);
727 let b = Tensor::ones(vec![3, 2]);
728 a.sub_tensor_optimized(&b);
729 }
730
731 #[test]
732 fn test_edge_cases() {
733 // Test zero subtraction
734 let a = Tensor::ones(vec![3]);
735 let b = Tensor::zeros(vec![3]);
736 let result = a.sub_tensor_optimized(&b);
737
738 unsafe {
739 for i in 0..result.size() {
740 assert!((result.as_ptr().add(i).read() - 1.0).abs() < 1e-6);
741 }
742 }
743
744 // Test self subtraction
745 let mut tensor = Tensor::ones(vec![3]);
746 tensor.fill(5.0);
747 let result = tensor.sub_tensor_optimized(&tensor);
748
749 unsafe {
750 for i in 0..result.size() {
751 assert!(result.as_ptr().add(i).read().abs() < 1e-6);
752 }
753 }
754 }
755
756 #[test]
757 fn test_large_tensor_subtraction() {
758 let mut a = Tensor::ones(vec![100, 100]);
759 a.fill(10.0);
760 let mut b = Tensor::ones(vec![100, 100]);
761 b.fill(3.0);
762 let result = a.sub_tensor_optimized(&b);
763
764 assert_eq!(result.size(), 10000);
765
766 // Check some values are 7.0 (10.0 - 3.0)
767 unsafe {
768 for i in (0..result.size()).step_by(1000) {
769 assert!((result.as_ptr().add(i).read() - 7.0).abs() < 1e-6);
770 }
771 }
772 }
773
774 #[test]
775 fn test_negate_inplace() {
776 let mut tensor = Tensor::ones(vec![2, 2]);
777 tensor.fill(5.0);
778
779 // Check initial values
780 unsafe {
781 for i in 0..tensor.size() {
782 let val = tensor.as_ptr().add(i).read();
783 assert!((val - 5.0).abs() < 1e-6, "Expected 5.0, got {}", val);
784 }
785 }
786
787 tensor.negate_inplace();
788
789 // Check negated values
790 unsafe {
791 for i in 0..tensor.size() {
792 let val = tensor.as_ptr().add(i).read();
793 assert!(
794 (val - (-5.0)).abs() < 1e-6,
795 "Expected -5.0 after negation, got {}",
796 val
797 );
798 }
799 }
800 }
801
802 #[test]
803 fn test_subtraction_with_gradtrack() {
804 // Test scalar subtraction with gradtrack
805 let a = Tensor::ones(vec![2, 3]).with_requires_grad();
806 let mut result = a.sub_scalar(5.0);
807
808 // Check result values: 1.0 - 5.0 = -4.0
809 unsafe {
810 for i in 0..result.size() {
811 let val = result.as_ptr().add(i).read();
812 assert!((val - (-4.0)).abs() < 1e-6, "Expected -4.0, got {}", val);
813 }
814 }
815
816 result.backward(None);
817
818 // Check gradient: d/dx(x - c) = 1
819 if let Some(grad) = a.grad_by_value() {
820 unsafe {
821 for i in 0..grad.size() {
822 let val = grad.as_ptr().add(i).read();
823 assert!(
824 (val - 1.0).abs() < 1e-6,
825 "Expected gradient 1.0, got {}",
826 val
827 );
828 }
829 }
830 } else {
831 panic!("No gradient computed for scalar subtraction!");
832 }
833
834 // Test tensor subtraction with gradtrack
835 let a = Tensor::ones(vec![2, 2]).with_requires_grad();
836 let mut b = Tensor::ones(vec![2, 2]);
837 b.fill(3.0);
838 let b = b.with_requires_grad();
839
840 let mut result = a.sub_tensor(&b);
841
842 // Check result values: 1.0 - 3.0 = -2.0
843 unsafe {
844 for i in 0..result.size() {
845 let val = result.as_ptr().add(i).read();
846 assert!((val - (-2.0)).abs() < 1e-6, "Expected -2.0, got {}", val);
847 }
848 }
849
850 result.backward(None);
851
852 // Check gradients: d/dx(x - y) = 1, d/dy(x - y) = -1
853 if let Some(grad_a) = a.grad_by_value() {
854 unsafe {
855 for i in 0..grad_a.size() {
856 let val = grad_a.as_ptr().add(i).read();
857 assert!(
858 (val - 1.0).abs() < 1e-6,
859 "Expected gradient A = 1.0, got {}",
860 val
861 );
862 }
863 }
864 } else {
865 panic!("No gradient A computed for tensor subtraction!");
866 }
867
868 if let Some(grad_b) = b.grad_by_value() {
869 unsafe {
870 for i in 0..grad_b.size() {
871 let val = grad_b.as_ptr().add(i).read();
872 println!("Debug: grad_b[{}] = {}", i, val);
873 assert!(
874 (val - (-1.0)).abs() < 1e-6,
875 "Expected gradient B = -1.0, got {}",
876 val
877 );
878 }
879 }
880 } else {
881 panic!("No gradient B computed for tensor subtraction!");
882 }
883 }
884
885 #[test]
886 fn test_mixed_add_sub_operations_with_gradtrack() {
887 // Test complex computation graph: (a + scalar1) - b + c - scalar2
888 let a = Tensor::ones(vec![2, 2]).with_requires_grad();
889 let mut b = Tensor::ones(vec![2, 2]);
890 b.fill(2.0);
891 let b = b.with_requires_grad();
892 let mut c = Tensor::ones(vec![2, 2]);
893 c.fill(3.0);
894 let c = c.with_requires_grad();
895
896 let scalar1 = 5.0;
897 let scalar2 = 1.0;
898
899 // Complex computation: (a + scalar1) - b + c - scalar2
900 // Expected: (1 + 5) - 2 + 3 - 1 = 6
901 let step1 = a.add_scalar(scalar1); // a + 5 = 6
902 let step2 = step1.sub_tensor(&b); // 6 - 2 = 4
903 let step3 = step2.add_tensor(&c); // 4 + 3 = 7
904 let mut result = step3.sub_scalar(scalar2); // 7 - 1 = 6
905
906 // Check result values
907 unsafe {
908 for i in 0..result.size() {
909 let val = result.as_ptr().add(i).read();
910 assert!((val - 6.0).abs() < 1e-6, "Expected 6.0, got {}", val);
911 }
912 }
913
914 result.backward(None);
915
916 // Check gradients
917 // For computation: f = (a + 5) - b + c - 1
918 // df/da = 1, df/db = -1, df/dc = 1
919
920 if let Some(grad_a) = a.grad_by_value() {
921 unsafe {
922 for i in 0..grad_a.size() {
923 let val = grad_a.as_ptr().add(i).read();
924 assert!(
925 (val - 1.0).abs() < 1e-6,
926 "Expected gradient A = 1.0, got {}",
927 val
928 );
929 }
930 }
931 } else {
932 panic!("No gradient A computed for mixed operations!");
933 }
934
935 if let Some(grad_b) = b.grad_by_value() {
936 unsafe {
937 for i in 0..grad_b.size() {
938 let val = grad_b.as_ptr().add(i).read();
939 assert!(
940 (val - (-1.0)).abs() < 1e-6,
941 "Expected gradient B = -1.0, got {}",
942 val
943 );
944 }
945 }
946 } else {
947 panic!("No gradient B computed for mixed operations!");
948 }
949
950 if let Some(grad_c) = c.grad_by_value() {
951 unsafe {
952 for i in 0..grad_c.size() {
953 let val = grad_c.as_ptr().add(i).read();
954 assert!(
955 (val - 1.0).abs() < 1e-6,
956 "Expected gradient C = 1.0, got {}",
957 val
958 );
959 }
960 }
961 } else {
962 panic!("No gradient C computed for mixed operations!");
963 }
964
965 println!("Mixed add/sub operations with gradtrack test passed!");
966 println!("✓ Complex computation graph: (a + 5) - b + c - 1 = 6");
967 println!("✓ Gradients: da/df = 1, db/df = -1, dc/df = 1");
968 }
969
970 #[test]
971 fn test_sub_broadcasting_gradients_basic() {
972 use crate::gradtrack::clear_gradients;
973 clear_gradients();
974
975 // Test case: [2, 3] - [1, 3] -> [2, 3]
976 // grad_a should be [2, 3], grad_b should be [1, 3] (summed over broadcast dim)
977 // For subtraction: d/da (a - b) = 1, d/db (a - b) = -1
978
979 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
980 .unwrap()
981 .with_requires_grad();
982 let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
983 .unwrap()
984 .with_requires_grad();
985
986 let mut result = a.sub_tensor(&b);
987 assert_eq!(result.shape().dims, vec![2, 3]);
988
989 // Set upstream gradient as ones
990 result.backward(None);
991
992 let grad_a = a.grad_by_value().expect("grad_a should exist");
993 let grad_b = b.grad_by_value().expect("grad_b should exist");
994
995 println!(
996 "Original shapes: a={:?}, b={:?}",
997 a.shape().dims,
998 b.shape().dims
999 );
1000 println!(
1001 "Gradient shapes: grad_a={:?}, grad_b={:?}",
1002 grad_a.shape().dims,
1003 grad_b.shape().dims
1004 );
1005
1006 // grad_a should have same shape as a: [2, 3]
1007 assert_eq!(
1008 grad_a.shape().dims,
1009 vec![2, 3],
1010 "grad_a should match original shape of a"
1011 );
1012
1013 // grad_b should have same shape as b: [1, 3]
1014 // This requires summing over the broadcasted dimension
1015 assert_eq!(
1016 grad_b.shape().dims,
1017 vec![1, 3],
1018 "grad_b should match original shape of b"
1019 );
1020
1021 // All gradients should be 1.0 for grad_a (d/da (a - b) = 1)
1022 for i in 0..grad_a.size() {
1023 let val = unsafe { *grad_a.as_ptr().add(i) };
1024 assert!(
1025 (val - 1.0).abs() < 1e-6,
1026 "grad_a[{}] = {} should be 1.0",
1027 i,
1028 val
1029 );
1030 }
1031
1032 // grad_b should be [-2.0, -2.0, -2.0] (sum over broadcast dim, then negated)
1033 let expected_grad_b = [-2.0, -2.0, -2.0]; // -1 * 2 rows = -2
1034 for (i, &expected) in expected_grad_b.iter().enumerate() {
1035 let val = unsafe { *grad_b.as_ptr().add(i) };
1036 assert!(
1037 (val - expected).abs() < 1e-6,
1038 "grad_b[{}] = {} should be {}",
1039 i,
1040 val,
1041 expected
1042 );
1043 }
1044 }
1045
1046 #[test]
1047 fn test_sub_scalar_broadcasting_gradients() {
1048 use crate::gradtrack::clear_gradients;
1049 clear_gradients();
1050
1051 // Test case: [2, 3] - [1] -> [2, 3]
1052 // grad_a should be [2, 3], grad_b should be [1] (summed over all dims, then negated)
1053
1054 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1055 .unwrap()
1056 .with_requires_grad();
1057 let b = Tensor::from_slice(&[0.5], vec![1])
1058 .unwrap()
1059 .with_requires_grad();
1060
1061 let mut result = a.sub_tensor(&b);
1062 result.backward(None);
1063
1064 let grad_a = a.grad_by_value().expect("grad_a should exist");
1065 let grad_b = b.grad_by_value().expect("grad_b should exist");
1066
1067 // grad_a should have same shape as a: [2, 3]
1068 assert_eq!(grad_a.shape().dims, vec![2, 3]);
1069
1070 // grad_b should have same shape as b: [1] and sum to -6.0
1071 println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
1072 assert_eq!(grad_b.shape().dims, vec![1]);
1073
1074 // grad_b should be -6.0 (sum over all 6 elements, then negated)
1075 let val = unsafe { *grad_b.as_ptr() };
1076 assert!(
1077 (val - (-6.0)).abs() < 1e-6,
1078 "grad_b = {} should be -6.0",
1079 val
1080 );
1081 }
1082}