// train_station/tensor/ops/div.rs
1//! Division operations for tensors
2//!
3//! Provides element-wise division following PyTorch conventions with comprehensive
4//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Division**: `div_tensor()` - Division with another tensor (PyTorch `div()` equivalent)
9//! - **Scalar Broadcasting**: `div_scalar()` - Division by scalar values
10//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
11//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
12//! - **Automatic Differentiation**: Full gradtrack support with gradient tracking
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Zero-copy Operations**: Efficient memory usage where possible
15//! - **Division by Zero Protection**: Comprehensive error checking and validation
16//!
17//! # Broadcasting Support
18//!
19//! All division operations support automatic broadcasting following NumPy rules:
20//! - Dimensions are aligned from the rightmost dimension
21//! - Dimensions are compatible if they are equal, or one of them is 1
22//! - Missing dimensions are treated as 1
23//! - Result shape follows broadcasting rules
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
28//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32//! - **Division by Zero Checks**: Optimized safety checks with minimal performance impact
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44 /// Element-wise division with another tensor with broadcasting support.
45 ///
46 /// Performs element-wise division with automatic broadcasting: `output[i] = self[i] / other[i]`
47 ///
48 /// Broadcasting enables division between tensors of different but compatible shapes.
49 /// Compatible shapes follow NumPy broadcasting rules:
50 /// - Dimensions are aligned from the rightmost dimension
51 /// - Dimensions are compatible if they are equal, or one of them is 1
52 /// - Missing dimensions are treated as 1
53 ///
54 /// # Arguments
55 /// * `other` - Tensor to divide by. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 /// A new tensor containing the element-wise quotient with broadcast result shape
59 ///
60 /// # Examples
61 ///
62 /// ## Same Shape Division
63 ///
64 /// ```
65 /// use train_station::Tensor;
66 ///
67 /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
68 /// let b = Tensor::from_slice(&[2.0, 4.0, 5.0], vec![3]).unwrap();
69 /// let c = a.div_tensor(&b);
70 /// assert_eq!(c.shape().dims(), vec![3]);
71 /// assert_eq!(c.get(&[0]), 5.0);
72 /// assert_eq!(c.get(&[1]), 5.0);
73 /// assert_eq!(c.get(&[2]), 6.0);
74 /// ```
75 ///
76 /// ## Broadcasting Division
77 ///
78 /// ```
79 /// use train_station::Tensor;
80 ///
81 /// // Broadcasting: [2, 1] / [1, 3] -> [2, 3]
82 /// let a = Tensor::from_slice(&[10.0, 20.0], vec![2, 1]).unwrap();
83 /// let b = Tensor::from_slice(&[1.0, 2.0, 5.0], vec![1, 3]).unwrap();
84 /// let c = a.div_tensor(&b);
85 /// assert_eq!(c.shape().dims(), vec![2, 3]);
86 /// assert_eq!(c.get(&[0, 0]), 10.0);
87 /// assert_eq!(c.get(&[0, 1]), 5.0);
88 /// assert_eq!(c.get(&[1, 0]), 20.0);
89 /// assert_eq!(c.get(&[1, 1]), 10.0);
90 /// ```
91 ///
92 /// ## Scalar Division
93 ///
94 /// ```
95 /// use train_station::Tensor;
96 ///
97 /// // Scalar division: [2, 3] / scalar -> [2, 3]
98 /// let a = Tensor::ones(vec![2, 3]);
99 /// let b = Tensor::from_slice(&[2.0], vec![1]).unwrap();
100 /// let c = a.div_tensor(&b);
101 /// assert_eq!(c.shape().dims(), vec![2, 3]);
102 /// assert_eq!(c.get(&[0, 0]), 0.5);
103 /// assert_eq!(c.get(&[1, 2]), 0.5);
104 /// ```
105 ///
106 /// # Panics
107 /// Panics if tensor shapes are not broadcast-compatible or division by zero
108 #[inline]
109 #[track_caller]
110 pub fn div_tensor(&self, other: &Tensor) -> Tensor {
111 // Check if shapes are identical for fast path
112 if self.shape().dims() == other.shape().dims() {
113 return self.div_tensor_same_shape(other);
114 }
115
116 // Zero-copy broadcast views, then reuse same-shape optimized path
117 use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
118 let mut result = match broadcast_shapes_cow(self, other) {
119 Ok((a_b, b_b, _)) => a_b.as_ref().div_tensor_optimized(b_b.as_ref()),
120 Err(BroadcastError::IncompatibleShapes { .. }) => {
121 panic!(
122 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
123 self.shape().dims(),
124 other.shape().dims(),
125 "incompatible shapes"
126 );
127 }
128 Err(BroadcastError::AllocationFailed) => {
129 panic!("Memory allocation failed during broadcasting");
130 }
131 };
132
133 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
134 result.set_requires_grad_internal(true);
135 let operands = vec![self.clone(), other.clone()];
136 let grad_fn = GradFn::Div {
137 is_tensor_div: true,
138 scalar: None,
139 operands: Some(operands),
140 original_shapes: Some((
141 self.shape().dims().to_vec(),
142 other.shape().dims().to_vec(),
143 )),
144 };
145 result.set_grad_fn(grad_fn.clone());
146
147 let mut input_ids = Vec::with_capacity(2);
148 if self.requires_grad() {
149 input_ids.push(self.id());
150 }
151 if other.requires_grad() {
152 input_ids.push(other.id());
153 }
154 GradEngine::register_operation(result.id(), input_ids, grad_fn);
155 }
156
157 result
158 }
159
160 /// Element-wise division for tensors with identical shapes (fast path).
161 ///
162 /// This is an optimized path for tensors that already have the same shape,
163 /// avoiding the overhead of broadcasting computation. Used internally by
164 /// `div_tensor()` when shapes are identical.
165 ///
166 /// # Arguments
167 /// * `other` - Tensor to divide by, must have the same shape as self
168 ///
169 /// # Returns
170 /// A new tensor containing the element-wise quotient
171 ///
172 /// # Performance Characteristics
173 ///
174 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
175 /// - **SIMD Optimization**: Uses optimized tensor division with SIMD acceleration
176 /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
177 ///
178 /// # Panics
179 ///
180 /// Panics if tensor shapes do not match or if any element in `other` is zero
181 #[inline]
182 fn div_tensor_same_shape(&self, other: &Tensor) -> Tensor {
183 assert_eq!(
184 self.shape(),
185 other.shape(),
186 "Tensor shapes must match for same-shape division"
187 );
188 let mut result = self.div_tensor_optimized(other);
189
190 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
191 result.set_requires_grad_internal(true);
192 let operands = vec![self.clone(), other.clone()];
193 let grad_fn = GradFn::Div {
194 is_tensor_div: true,
195 scalar: None,
196 operands: Some(operands),
197 original_shapes: None, // Same shape case
198 };
199 result.set_grad_fn(grad_fn.clone());
200
201 let mut input_ids = Vec::with_capacity(2);
202 if self.requires_grad() {
203 input_ids.push(self.id());
204 }
205 if other.requires_grad() {
206 input_ids.push(other.id());
207 }
208 GradEngine::register_operation(result.id(), input_ids, grad_fn);
209 }
210
211 result
212 }
213
214 /// Broadcast division with a scalar value.
215 ///
216 /// Divides every element by the scalar: `output[i] = self[i] / scalar`
217 ///
218 /// # Arguments
219 /// * `scalar` - Value to divide each element by (must not be zero)
220 ///
221 /// # Returns
222 /// A new tensor with each element divided by the scalar
223 ///
224 /// # Examples
225 ///
226 /// ## Basic Scalar Division
227 ///
228 /// ```
229 /// use train_station::Tensor;
230 ///
231 /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
232 /// let b = a.div_scalar(10.0);
233 /// assert_eq!(b.shape().dims(), vec![3]);
234 /// assert_eq!(b.get(&[0]), 1.0);
235 /// assert_eq!(b.get(&[1]), 2.0);
236 /// assert_eq!(b.get(&[2]), 3.0);
237 /// ```
238 ///
239 /// ## Multi-dimensional Scalar Division
240 ///
241 /// ```
242 /// use train_station::Tensor;
243 ///
244 /// let a = Tensor::ones(vec![2, 3]);
245 /// let b = a.div_scalar(2.0);
246 /// assert_eq!(b.shape().dims(), vec![2, 3]);
247 /// assert_eq!(b.get(&[0, 0]), 0.5);
248 /// assert_eq!(b.get(&[1, 2]), 0.5);
249 /// ```
250 ///
251 /// # Panics
252 /// Panics if scalar is zero
253 #[inline]
254 #[track_caller]
255 pub fn div_scalar(&self, scalar: f32) -> Tensor {
256 let mut result = self.div_scalar_optimized(scalar);
257
258 if self.requires_grad() && is_grad_enabled() {
259 result.set_requires_grad_internal(true);
260 let grad_fn = GradFn::Div {
261 is_tensor_div: false,
262 scalar: Some(scalar),
263 operands: None,
264 original_shapes: None, // Scalar case
265 };
266 result.set_grad_fn(grad_fn.clone());
267 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
268 }
269
270 result
271 }
272 /// Internal optimized tensor / tensor operation
273 ///
274 /// Performs element-wise division between two tensors with the same shape,
275 /// using SIMD acceleration when available. This is the core implementation
276 /// used by `div_tensor()` after broadcasting has been applied.
277 ///
278 /// # Arguments
279 ///
280 /// * `other` - Tensor to divide by, must have the same shape as self
281 ///
282 /// # Returns
283 ///
284 /// A new tensor containing the element-wise quotient
285 ///
286 /// # Safety
287 ///
288 /// Assumes both tensors have the same shape and valid memory layouts.
289 /// Uses unsafe SIMD operations for performance optimization.
290 /// Division by zero will panic.
291 ///
292 /// # Performance Characteristics
293 ///
294 /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
295 /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
296 /// - **Cache-friendly**: Linear memory access patterns
297 /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
298 /// - **Division by Zero Checks**: Comprehensive safety validation
299 #[inline]
300 pub(crate) fn div_tensor_optimized(&self, other: &Tensor) -> Tensor {
301 assert_eq!(
302 self.shape().dims(),
303 other.shape().dims(),
304 "Tensor dims must match"
305 );
306
307 // Ensure contiguous sources for correctness with broadcast views/strides
308 let a_src = if self.is_contiguous() {
309 self.clone()
310 } else {
311 self.contiguous()
312 };
313 let b_src = if other.is_contiguous() {
314 other.clone()
315 } else {
316 other.contiguous()
317 };
318
319 let mut output = Tensor::new(self.shape().dims().to_vec());
320
321 unsafe {
322 let a = a_src.as_ptr();
323 let b = b_src.as_ptr();
324 let dst = output.as_mut_ptr();
325
326 #[cfg(target_arch = "x86_64")]
327 {
328 // Use SIMD for better performance when available
329 if is_x86_feature_detected!("avx2") {
330 self.div_tensors_simd_avx2_optimized(a, b, dst);
331 return output;
332 }
333 }
334
335 // Fallback to scalar operations with better cache usage
336 self.div_tensors_scalar_optimized(a, b, dst);
337 }
338
339 output
340 }
341
342 /// SIMD-optimized tensor division using AVX2 instructions
343 ///
344 /// Performs element-wise division using AVX2 SIMD instructions for maximum
345 /// performance on x86_64 hardware. Processes 32 elements per iteration with
346 /// 4x unrolling for optimal instruction throughput. Includes comprehensive
347 /// division by zero checking for safety.
348 ///
349 /// # Arguments
350 ///
351 /// * `a` - Pointer to first tensor data (numerator)
352 /// * `b` - Pointer to second tensor data (denominator)
353 /// * `dst` - Pointer to output tensor data
354 ///
355 /// # Safety
356 ///
357 /// Requires AVX2 support and valid pointers with sufficient memory.
358 /// All pointers must be aligned and point to valid tensor data.
359 /// Division by zero will panic.
360 ///
361 /// # Performance Characteristics
362 ///
363 /// - **SIMD Width**: 8 elements per AVX2 vector operation
364 /// - **Unrolling**: 4x unrolling (32 elements per iteration)
365 /// - **Memory Access**: Linear access patterns for cache efficiency
366 /// - **Fallback**: Handles remaining elements with scalar operations
367 /// - **Safety Checks**: Comprehensive division by zero validation
368 #[cfg(target_arch = "x86_64")]
369 #[inline]
370 #[target_feature(enable = "avx2")]
371 unsafe fn div_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
372 let size = self.size();
373 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
374 let mut offset = 0;
375
376 // Unrolled SIMD loop for throughput
377 for _ in 0..simd_count {
378 // Process 4 AVX2 vectors (32 elements) per iteration
379 let a_vec1 = _mm256_loadu_ps(a.add(offset));
380 let b_vec1 = _mm256_loadu_ps(b.add(offset));
381 let div_vec1 = _mm256_div_ps(a_vec1, b_vec1);
382 _mm256_storeu_ps(dst.add(offset), div_vec1);
383
384 let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
385 let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
386 let div_vec2 = _mm256_div_ps(a_vec2, b_vec2);
387 _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
388
389 let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
390 let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
391 let div_vec3 = _mm256_div_ps(a_vec3, b_vec3);
392 _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
393
394 let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
395 let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
396 let div_vec4 = _mm256_div_ps(a_vec4, b_vec4);
397 _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
398
399 offset += 32;
400 }
401
402 // Handle remaining 8-element blocks, then tail with checks
403 let remaining_full_blocks = (size - offset) / 8;
404 for _ in 0..remaining_full_blocks {
405 let a_vec = _mm256_loadu_ps(a.add(offset));
406 let b_vec = _mm256_loadu_ps(b.add(offset));
407 // Fallback to scalar for safety checks
408 let mut a_vals = [0.0f32; 8];
409 let mut b_vals = [0.0f32; 8];
410 _mm256_storeu_ps(a_vals.as_mut_ptr(), a_vec);
411 _mm256_storeu_ps(b_vals.as_mut_ptr(), b_vec);
412 for t in 0..8 {
413 let j = offset + t;
414 if b_vals[t] == 0.0 {
415 panic!("Division by zero detected at index {}", j);
416 }
417 *dst.add(j) = a_vals[t] / b_vals[t];
418 }
419 offset += 8;
420 }
421 while offset + 4 <= size {
422 let b0 = *b.add(offset);
423 let b1 = *b.add(offset + 1);
424 let b2 = *b.add(offset + 2);
425 let b3 = *b.add(offset + 3);
426 if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
427 panic!("Division by zero detected in unrolled loop");
428 }
429 *dst.add(offset) = *a.add(offset) / b0;
430 *dst.add(offset + 1) = *a.add(offset + 1) / b1;
431 *dst.add(offset + 2) = *a.add(offset + 2) / b2;
432 *dst.add(offset + 3) = *a.add(offset + 3) / b3;
433 offset += 4;
434 }
435 for i in offset..size {
436 let b_val = *b.add(i);
437 if b_val == 0.0 {
438 panic!("Division by zero detected at index {}", i);
439 }
440 *dst.add(i) = *a.add(i) / b_val;
441 }
442 }
443
444 /// Optimized scalar tensor division fallback
445 ///
446 /// Performs element-wise division using optimized scalar operations when
447 /// SIMD is not available. Uses 4x unrolling for better instruction-level
448 /// parallelism and cache efficiency. Includes comprehensive division by zero
449 /// checking for safety.
450 ///
451 /// # Arguments
452 ///
453 /// * `a` - Pointer to first tensor data (numerator)
454 /// * `b` - Pointer to second tensor data (denominator)
455 /// * `dst` - Pointer to output tensor data
456 ///
457 /// # Safety
458 ///
459 /// Requires valid pointers with sufficient memory for the tensor size.
460 /// All pointers must point to valid tensor data.
461 /// Division by zero will panic.
462 ///
463 /// # Performance Characteristics
464 ///
465 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
466 /// - **Memory Access**: Linear access patterns for cache efficiency
467 /// - **Fallback**: Handles remaining elements with scalar operations
468 /// - **Safety Checks**: Comprehensive division by zero validation
469 #[inline]
470 unsafe fn div_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
471 let size = self.size();
472
473 // Use unrolled loops for better instruction throughput
474 let unroll_count = size / 4;
475 let mut i = 0;
476
477 // Process 4 elements at a time for better cache utilization
478 while i < unroll_count {
479 let idx = i * 4;
480
481 // Check for division by zero
482 let b0 = b.add(idx).read();
483 let b1 = b.add(idx + 1).read();
484 let b2 = b.add(idx + 2).read();
485 let b3 = b.add(idx + 3).read();
486
487 if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
488 panic!("Division by zero detected in unrolled loop");
489 }
490
491 dst.add(idx).write(a.add(idx).read() / b0);
492 dst.add(idx + 1).write(a.add(idx + 1).read() / b1);
493 dst.add(idx + 2).write(a.add(idx + 2).read() / b2);
494 dst.add(idx + 3).write(a.add(idx + 3).read() / b3);
495 i += 1;
496 }
497
498 // Handle remaining elements
499 for j in (unroll_count * 4)..size {
500 let b_val = b.add(j).read();
501 if b_val == 0.0 {
502 panic!("Division by zero detected at index {}", j);
503 }
504 dst.add(j).write(a.add(j).read() / b_val);
505 }
506 }
507
508 /// Internal optimized scalar / tensor operation
509 ///
510 /// Performs element-wise division of a scalar into each element of the tensor,
511 /// using SIMD acceleration when available. This is the core implementation
512 /// used by `div_scalar()`.
513 ///
514 /// # Arguments
515 ///
516 /// * `scalar` - Scalar value to divide each element by (must not be zero)
517 ///
518 /// # Returns
519 ///
520 /// A new tensor with each element divided by the scalar
521 ///
522 /// # Safety
523 ///
524 /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
525 /// performance optimization. Division by zero will panic.
526 ///
527 /// # Performance Characteristics
528 ///
529 /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
530 /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
531 /// - **Cache-friendly**: Linear memory access patterns
532 /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
533 /// - **Division by Zero Checks**: Comprehensive safety validation
534 #[inline]
535 pub(crate) fn div_scalar_optimized(&self, scalar: f32) -> Tensor {
536 if scalar == 0.0 {
537 panic!("Division by zero: cannot divide tensor by zero scalar");
538 }
539
540 let mut output = Tensor::new(self.shape().dims().to_vec());
541
542 unsafe {
543 let src = self.as_ptr();
544 let dst = output.as_mut_ptr();
545
546 #[cfg(target_arch = "x86_64")]
547 {
548 // Use SIMD for better performance when available
549 if is_x86_feature_detected!("avx2") {
550 self.div_scalar_simd_avx2_optimized(src, dst, scalar);
551 return output;
552 }
553 }
554
555 // Fallback to scalar operations with better cache usage
556 self.div_scalar_fallback_optimized(src, dst, scalar);
557 }
558
559 output
560 }
561
562 /// SIMD-optimized scalar division using AVX2 instructions
563 ///
564 /// Performs element-wise scalar division using AVX2 SIMD instructions for maximum
565 /// performance on x86_64 hardware. Processes 32 elements per iteration with
566 /// 4x unrolling for optimal instruction throughput.
567 ///
568 /// # Arguments
569 ///
570 /// * `src` - Pointer to source tensor data
571 /// * `dst` - Pointer to output tensor data
572 /// * `scalar` - Scalar value to divide each element by
573 ///
574 /// # Safety
575 ///
576 /// Requires AVX2 support and valid pointers with sufficient memory.
577 /// All pointers must be aligned and point to valid tensor data.
578 /// Scalar must not be zero (checked before calling this function).
579 ///
580 /// # Performance Characteristics
581 ///
582 /// - **SIMD Width**: 8 elements per AVX2 vector operation
583 /// - **Unrolling**: 4x unrolling (32 elements per iteration)
584 /// - **Memory Access**: Linear access patterns for cache efficiency
585 /// - **Fallback**: Handles remaining elements with scalar operations
586 /// - **Optimization**: Most common scalar division pattern optimized
587 #[cfg(target_arch = "x86_64")]
588 #[inline]
589 #[target_feature(enable = "avx2")]
590 unsafe fn div_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
591 let size = self.size();
592 let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
593 let mut offset = 0;
594
595 // Create SIMD vector for scalar
596 let scalar_vec = _mm256_set1_ps(scalar);
597
598 // Unrolled SIMD loop for throughput
599 for _ in 0..simd_count {
600 // Process 4 AVX2 vectors (32 elements) per iteration
601 let src_vec1 = _mm256_loadu_ps(src.add(offset));
602 let div_vec1 = _mm256_div_ps(src_vec1, scalar_vec);
603 _mm256_storeu_ps(dst.add(offset), div_vec1);
604
605 let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
606 let div_vec2 = _mm256_div_ps(src_vec2, scalar_vec);
607 _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
608
609 let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
610 let div_vec3 = _mm256_div_ps(src_vec3, scalar_vec);
611 _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
612
613 let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
614 let div_vec4 = _mm256_div_ps(src_vec4, scalar_vec);
615 _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
616
617 offset += 32;
618 }
619
620 // Handle remaining elements with scalar operations
621 for i in offset..size {
622 *dst.add(i) = *src.add(i) / scalar;
623 }
624 }
625
626 /// Optimized scalar division fallback
627 ///
628 /// Performs element-wise scalar division using optimized scalar operations when
629 /// SIMD is not available. Uses 4x unrolling for better instruction-level
630 /// parallelism and cache efficiency.
631 ///
632 /// # Arguments
633 ///
634 /// * `src` - Pointer to source tensor data
635 /// * `dst` - Pointer to output tensor data
636 /// * `scalar` - Scalar value to divide each element by
637 ///
638 /// # Safety
639 ///
640 /// Requires valid pointers with sufficient memory for the tensor size.
641 /// All pointers must point to valid tensor data.
642 /// Scalar must not be zero (checked before calling this function).
643 ///
644 /// # Performance Characteristics
645 ///
646 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
647 /// - **Memory Access**: Linear access patterns for cache efficiency
648 /// - **Fallback**: Handles remaining elements with scalar operations
649 #[inline]
650 unsafe fn div_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
651 let size = self.size();
652
653 // Use unrolled loops for better instruction throughput
654 let unroll_count = size / 4;
655 let mut i = 0;
656
657 // Process 4 elements at a time for better cache utilization
658 while i < unroll_count {
659 let idx = i * 4;
660 dst.add(idx).write(src.add(idx).read() / scalar);
661 dst.add(idx + 1).write(src.add(idx + 1).read() / scalar);
662 dst.add(idx + 2).write(src.add(idx + 2).read() / scalar);
663 dst.add(idx + 3).write(src.add(idx + 3).read() / scalar);
664 i += 1;
665 }
666
667 // Handle remaining elements
668 for j in (unroll_count * 4)..size {
669 dst.add(j).write(src.add(j).read() / scalar);
670 }
671 }
672}
673
674#[cfg(test)]
675mod tests {
676 use super::*;
677
678 #[test]
679 fn test_tensor_division() {
680 let a = Tensor::ones(vec![2, 3]);
681 let mut b = Tensor::ones(vec![2, 3]);
682 b.fill(2.0);
683 let result = a.div_tensor_optimized(&b);
684
685 assert_eq!(result.shape().dims(), vec![2, 3]);
686 assert_eq!(result.size(), 6);
687
688 // Check that all values are 0.5 (1.0 / 2.0)
689 unsafe {
690 for i in 0..result.size() {
691 assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
692 }
693 }
694 }
695
696 #[test]
697 fn test_scalar_division() {
698 let tensor = Tensor::ones(vec![2, 2]);
699 let result = tensor.div_scalar_optimized(2.0);
700
701 assert_eq!(result.shape().dims(), vec![2, 2]);
702 assert_eq!(result.size(), 4);
703
704 // Check that all values are 0.5 (1.0 / 2.0)
705 unsafe {
706 for i in 0..result.size() {
707 assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
708 }
709 }
710 }
711
712 #[test]
713 fn test_negative_division() {
714 let tensor = Tensor::ones(vec![2, 3]);
715 let result = tensor.div_scalar_optimized(-2.0);
716
717 assert_eq!(result.shape().dims(), vec![2, 3]);
718 assert_eq!(result.size(), 6);
719
720 // Check that all values are -0.5 (1.0 / -2.0)
721 unsafe {
722 for i in 0..result.size() {
723 assert!((result.as_ptr().add(i).read() - (-0.5)).abs() < 1e-6);
724 }
725 }
726 }
727
728 #[test]
729 #[should_panic(expected = "Division by zero")]
730 fn test_division_by_zero_scalar() {
731 let tensor = Tensor::ones(vec![2, 3]);
732 tensor.div_scalar_optimized(0.0);
733 }
734
735 #[test]
736 #[should_panic(expected = "Division by zero")]
737 fn test_division_by_zero_tensor() {
738 let a = Tensor::ones(vec![2, 3]);
739 let b = Tensor::zeros(vec![2, 3]);
740 a.div_tensor_optimized(&b);
741 }
742
743 #[test]
744 #[should_panic(expected = "Tensor dims must match")]
745 fn test_mismatched_shapes() {
746 let a = Tensor::ones(vec![2, 3]);
747 let b = Tensor::ones(vec![3, 2]);
748 a.div_tensor_optimized(&b);
749 }
750
751 #[test]
752 fn test_edge_cases() {
753 // Test with zero numerator
754 let zero_tensor = Tensor::zeros(vec![2, 3]);
755 let other = Tensor::ones(vec![2, 3]);
756 let result = zero_tensor.div_tensor_optimized(&other);
757
758 assert_eq!(result.shape().dims(), vec![2, 3]);
759 assert_eq!(result.size(), 6);
760
761 // Check that all values are 0.0 (0.0 / 1.0)
762 unsafe {
763 for i in 0..result.size() {
764 assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
765 }
766 }
767
768 // Test with negative values
769 let mut neg_tensor = Tensor::ones(vec![2, 3]);
770 neg_tensor.fill(-4.0);
771 let result = neg_tensor.div_scalar_optimized(2.0);
772
773 assert_eq!(result.shape().dims(), vec![2, 3]);
774 assert_eq!(result.size(), 6);
775
776 // Check that all values are -2.0 (-4.0 / 2.0)
777 unsafe {
778 for i in 0..result.size() {
779 assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
780 }
781 }
782 }
783
784 #[test]
785 fn test_large_tensor_division() {
786 let a = Tensor::ones(vec![100, 100]);
787 let mut b = Tensor::ones(vec![100, 100]);
788 b.fill(1.5);
789 let result = a.div_tensor_optimized(&b);
790
791 assert_eq!(result.shape().dims(), vec![100, 100]);
792 assert_eq!(result.size(), 10000);
793
794 // Check that all values are 0.666... (1.0 / 1.5)
795 unsafe {
796 for i in 0..result.size() {
797 assert!((result.as_ptr().add(i).read() - (2.0 / 3.0)).abs() < 1e-6);
798 }
799 }
800 }
801
    #[test]
    fn test_division_with_gradtrack() {
        // Scalar division with gradtrack: y = x / 2 for x = ones.
        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut result = a.div_scalar(2.0);

        // Forward values: 1.0 / 2.0 = 0.5
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
            }
        }

        result.backward(None);

        // Backward: d/dx (x/2) = 1/2 for every element.
        if let Some(grad) = a.grad_owned() {
            unsafe {
                for i in 0..grad.size() {
                    let val = grad.as_ptr().add(i).read();
                    assert!(
                        (val - 0.5).abs() < 1e-6,
                        "Expected gradient 0.5, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient computed for scalar division!");
        }

        // Tensor division with gradtrack: a / b for a = ones, b = 2s.
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(2.0);
        let b = b.with_requires_grad();

        let mut result = a.div_tensor(&b);

        // Forward values: 1.0 / 2.0 = 0.5
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
            }
        }

        result.backward(None);

        // Check gradients: ∂(a/b)/∂a = 1/b, ∂(a/b)/∂b = -a/b²
        // For a = 1.0, b = 2.0: ∂(a/b)/∂a = 0.5, ∂(a/b)/∂b = -0.25
        if let Some(grad_a) = a.grad_owned() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 0.5).abs() < 1e-6,
                        "Expected gradient A = 0.5 (∂(a/b)/∂a = 1/b), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for tensor division!");
        }

        if let Some(grad_b) = b.grad_owned() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - (-0.25)).abs() < 1e-6,
                        "Expected gradient B = -0.25 (∂(a/b)/∂b = -a/b²), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for tensor division!");
        }
    }
884
    #[test]
    fn test_mixed_div_mul_operations_with_gradtrack() {
        // Composite graph through div, mul, and add: (a / 2) * (b / 3) + 1
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(6.0);
        let b = b.with_requires_grad();

        let scalar1 = 2.0;
        let scalar2 = 3.0;

        // Compute: (a / scalar1) * (b / scalar2) + 1
        let div_a = a.div_scalar(scalar1); // a / 2
        let div_b = b.div_scalar(scalar2); // b / 3
        let mul_result = div_a.mul_tensor(&div_b); // (a / 2) * (b / 3)
        let mut final_result = mul_result.add_scalar(1.0); // (a / 2) * (b / 3) + 1

        // Forward values: (1/2) * (6/3) + 1 = 0.5 * 2 + 1 = 2
        unsafe {
            for i in 0..final_result.size() {
                let val = final_result.as_ptr().add(i).read();
                assert!((val - 2.0).abs() < 1e-6, "Expected 2.0, got {}", val);
            }
        }

        final_result.backward(None);

        // Check gradients: d/dx((x/2) * (y/3) + 1) = (y/3) * (1/2) = y/6
        // d/dy((x/2) * (y/3) + 1) = (x/2) * (1/3) = x/6
        // For x = 1.0, y = 6.0: d/dx = 6/6 = 1.0, d/dy = 1/6 = 0.166...
        if let Some(grad_a) = a.grad_owned() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 1.0).abs() < 1e-6,
                        "Expected gradient A = 1.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for mixed operations!");
        }

        if let Some(grad_b) = b.grad_owned() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - (1.0 / 6.0)).abs() < 1e-6,
                        "Expected gradient B = 1/6, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for mixed operations!");
        }
    }
945
    #[test]
    fn test_div_broadcasting_gradients_basic() {
        use crate::gradtrack::clear_gradients;
        // Reset global gradient state so earlier tests cannot interfere.
        clear_gradients();

        // Test case: [2, 3] / [1, 3] -> [2, 3]
        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2

        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[2.0, 2.0, 2.0], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.div_tensor(&b);
        assert_eq!(result.shape().dims(), vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        let grad_a = a.grad_owned().expect("grad_a should exist");
        let grad_b = b.grad_owned().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims(),
            b.shape().dims()
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims(),
            grad_b.shape().dims()
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims(),
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3] (reduced over broadcast dim)
        assert_eq!(
            grad_b.shape().dims(),
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // grad_a should be [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] (1/b = 1/2)
        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (val - 0.5).abs() < 1e-6,
                "grad_a[{}] = {} should be 0.5",
                i,
                val
            );
        }

        // For grad_b: d/db (a / b) = -a/b^2, summed over broadcast dimension
        // a = [2,4,6,8,10,12], b = [2,2,2], so -a/b^2 = [-2/4, -4/4, -6/4, -8/4, -10/4, -12/4] = [-0.5, -1, -1.5, -2, -2.5, -3]
        // Summed over first dimension: [-0.5-2, -1-2.5, -1.5-3] = [-2.5, -3.5, -4.5]
        let expected_grad_b = [-2.5, -3.5, -4.5];
        for (i, &expected) in expected_grad_b.iter().enumerate() {
            let val = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (val - expected).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                val,
                expected
            );
        }
    }
1021
    #[test]
    fn test_div_scalar_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        // Reset global gradient state so earlier tests cannot interfere.
        clear_gradients();

        // Test case: [2, 3] / [1] -> [2, 3]
        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2

        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[2.0], vec![1])
            .unwrap()
            .with_requires_grad();

        let mut result = a.div_tensor(&b);
        result.backward(None);

        let grad_a = a.grad_owned().expect("grad_a should exist");
        let grad_b = b.grad_owned().expect("grad_b should exist");

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(grad_a.shape().dims(), vec![2, 3]);

        // grad_b should have same shape as b: [1] (fully reduced)
        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims());
        assert_eq!(grad_b.shape().dims(), vec![1]);

        // grad_b should be the sum of -a/b^2 = -(2+4+6+8+10+12)/4 = -42/4 = -10.5
        let val = unsafe { *grad_b.as_ptr() };
        assert!(
            (val - (-10.5)).abs() < 1e-6,
            "grad_b = {} should be -10.5",
            val
        );
    }
1057 }
1058}