train_station/tensor/ops/div.rs

1//! Division operations for tensors
2//!
3//! Provides element-wise division following PyTorch conventions with comprehensive
4//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Division**: `div_tensor()` - Division with another tensor (PyTorch `div()` equivalent)
9//! - **Scalar Broadcasting**: `div_scalar()` - Division by scalar values
10//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
11//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
12//! - **Automatic Differentiation**: Full gradtrack support with gradient tracking
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Zero-copy Operations**: Efficient memory usage where possible
15//! - **Division by Zero Protection**: Comprehensive error checking and validation
16//!
17//! # Broadcasting Support
18//!
19//! All division operations support automatic broadcasting following NumPy rules:
20//! - Dimensions are aligned from the rightmost dimension
21//! - Dimensions are compatible if they are equal, or one of them is 1
22//! - Missing dimensions are treated as 1
23//! - Result shape follows broadcasting rules
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
28//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32//! - **Division by Zero Checks**: Optimized safety checks with minimal performance impact
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44    /// Element-wise division with another tensor with broadcasting support.
45    ///
46    /// Performs element-wise division with automatic broadcasting: `output[i] = self[i] / other[i]`
47    ///
48    /// Broadcasting enables division between tensors of different but compatible shapes.
49    /// Compatible shapes follow NumPy broadcasting rules:
50    /// - Dimensions are aligned from the rightmost dimension
51    /// - Dimensions are compatible if they are equal, or one of them is 1
52    /// - Missing dimensions are treated as 1
53    ///
54    /// # Arguments
55    /// * `other` - Tensor to divide by. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    /// A new tensor containing the element-wise quotient with broadcast result shape
59    ///
60    /// # Examples
61    ///
62    /// ## Same Shape Division
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
68    /// let b = Tensor::from_slice(&[2.0, 4.0, 5.0], vec![3]).unwrap();
69    /// let c = a.div_tensor(&b);
70    /// assert_eq!(c.shape().dims, vec![3]);
71    /// assert_eq!(c.get(&[0]), 5.0);
72    /// assert_eq!(c.get(&[1]), 5.0);
73    /// assert_eq!(c.get(&[2]), 6.0);
74    /// ```
75    ///
76    /// ## Broadcasting Division
77    ///
78    /// ```
79    /// use train_station::Tensor;
80    ///
81    /// // Broadcasting: [2, 1] / [1, 3] -> [2, 3]
82    /// let a = Tensor::from_slice(&[10.0, 20.0], vec![2, 1]).unwrap();
83    /// let b = Tensor::from_slice(&[1.0, 2.0, 5.0], vec![1, 3]).unwrap();
84    /// let c = a.div_tensor(&b);
85    /// assert_eq!(c.shape().dims, vec![2, 3]);
86    /// assert_eq!(c.get(&[0, 0]), 10.0);
87    /// assert_eq!(c.get(&[0, 1]), 5.0);
88    /// assert_eq!(c.get(&[1, 0]), 20.0);
89    /// assert_eq!(c.get(&[1, 1]), 10.0);
90    /// ```
91    ///
92    /// ## Scalar Division
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // Scalar division: [2, 3] / scalar -> [2, 3]
98    /// let a = Tensor::ones(vec![2, 3]);
99    /// let b = Tensor::from_slice(&[2.0], vec![1]).unwrap();
100    /// let c = a.div_tensor(&b);
101    /// assert_eq!(c.shape().dims, vec![2, 3]);
102    /// assert_eq!(c.get(&[0, 0]), 0.5);
103    /// assert_eq!(c.get(&[1, 2]), 0.5);
104    /// ```
105    ///
106    /// # Panics
107    /// Panics if tensor shapes are not broadcast-compatible or division by zero
108    #[inline]
109    #[track_caller]
110    pub fn div_tensor(&self, other: &Tensor) -> Tensor {
111        // Check if shapes are identical for fast path
112        if self.shape().dims == other.shape().dims {
113            return self.div_tensor_same_shape(other);
114        }
115
116        // Use broadcasting for different shapes
117        let (broadcast_self, broadcast_other, _result_shape) =
118            self.broadcast_with(other).unwrap_or_else(|e| {
119                panic!(
120                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
121                    self.shape().dims,
122                    other.shape().dims,
123                    e
124                );
125            });
126
127        // Perform element-wise division on broadcasted tensors
128        let mut result = broadcast_self.div_tensor_optimized(&broadcast_other);
129
130        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
131            result.set_requires_grad_internal(true);
132            let operands = vec![self.clone(), other.clone()];
133            let grad_fn = GradFn::Div {
134                is_tensor_div: true,
135                scalar: None,
136                operands: Some(operands),
137                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
138            };
139            result.set_grad_fn(grad_fn.clone());
140
141            let mut input_ids = Vec::with_capacity(2);
142            if self.requires_grad() {
143                input_ids.push(self.id());
144            }
145            if other.requires_grad() {
146                input_ids.push(other.id());
147            }
148            GradEngine::register_operation(result.id(), input_ids, grad_fn);
149        }
150
151        result
152    }
153
154    /// Element-wise division for tensors with identical shapes (fast path).
155    ///
156    /// This is an optimized path for tensors that already have the same shape,
157    /// avoiding the overhead of broadcasting computation. Used internally by
158    /// `div_tensor()` when shapes are identical.
159    ///
160    /// # Arguments
161    /// * `other` - Tensor to divide by, must have the same shape as self
162    ///
163    /// # Returns
164    /// A new tensor containing the element-wise quotient
165    ///
166    /// # Performance Characteristics
167    ///
168    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
169    /// - **SIMD Optimization**: Uses optimized tensor division with SIMD acceleration
170    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
171    ///
172    /// # Panics
173    ///
174    /// Panics if tensor shapes do not match or if any element in `other` is zero
175    #[inline]
176    fn div_tensor_same_shape(&self, other: &Tensor) -> Tensor {
177        assert_eq!(
178            self.shape(),
179            other.shape(),
180            "Tensor shapes must match for same-shape division"
181        );
182        let mut result = self.div_tensor_optimized(other);
183
184        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
185            result.set_requires_grad_internal(true);
186            let operands = vec![self.clone(), other.clone()];
187            let grad_fn = GradFn::Div {
188                is_tensor_div: true,
189                scalar: None,
190                operands: Some(operands),
191                original_shapes: None, // Same shape case
192            };
193            result.set_grad_fn(grad_fn.clone());
194
195            let mut input_ids = Vec::with_capacity(2);
196            if self.requires_grad() {
197                input_ids.push(self.id());
198            }
199            if other.requires_grad() {
200                input_ids.push(other.id());
201            }
202            GradEngine::register_operation(result.id(), input_ids, grad_fn);
203        }
204
205        result
206    }
207
208    /// Broadcast division with a scalar value.
209    ///
210    /// Divides every element by the scalar: `output[i] = self[i] / scalar`
211    ///
212    /// # Arguments
213    /// * `scalar` - Value to divide each element by (must not be zero)
214    ///
215    /// # Returns
216    /// A new tensor with each element divided by the scalar
217    ///
218    /// # Examples
219    ///
220    /// ## Basic Scalar Division
221    ///
222    /// ```
223    /// use train_station::Tensor;
224    ///
225    /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
226    /// let b = a.div_scalar(10.0);
227    /// assert_eq!(b.shape().dims, vec![3]);
228    /// assert_eq!(b.get(&[0]), 1.0);
229    /// assert_eq!(b.get(&[1]), 2.0);
230    /// assert_eq!(b.get(&[2]), 3.0);
231    /// ```
232    ///
233    /// ## Multi-dimensional Scalar Division
234    ///
235    /// ```
236    /// use train_station::Tensor;
237    ///
238    /// let a = Tensor::ones(vec![2, 3]);
239    /// let b = a.div_scalar(2.0);
240    /// assert_eq!(b.shape().dims, vec![2, 3]);
241    /// assert_eq!(b.get(&[0, 0]), 0.5);
242    /// assert_eq!(b.get(&[1, 2]), 0.5);
243    /// ```
244    ///
245    /// # Panics
246    /// Panics if scalar is zero
247    #[inline]
248    #[track_caller]
249    pub fn div_scalar(&self, scalar: f32) -> Tensor {
250        let mut result = self.div_scalar_optimized(scalar);
251
252        if self.requires_grad() && is_grad_enabled() {
253            result.set_requires_grad_internal(true);
254            let grad_fn = GradFn::Div {
255                is_tensor_div: false,
256                scalar: Some(scalar),
257                operands: None,
258                original_shapes: None, // Scalar case
259            };
260            result.set_grad_fn(grad_fn.clone());
261            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
262        }
263
264        result
265    }
266    /// Internal optimized tensor / tensor operation
267    ///
268    /// Performs element-wise division between two tensors with the same shape,
269    /// using SIMD acceleration when available. This is the core implementation
270    /// used by `div_tensor()` after broadcasting has been applied.
271    ///
272    /// # Arguments
273    ///
274    /// * `other` - Tensor to divide by, must have the same shape as self
275    ///
276    /// # Returns
277    ///
278    /// A new tensor containing the element-wise quotient
279    ///
280    /// # Safety
281    ///
282    /// Assumes both tensors have the same shape and valid memory layouts.
283    /// Uses unsafe SIMD operations for performance optimization.
284    /// Division by zero will panic.
285    ///
286    /// # Performance Characteristics
287    ///
288    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
289    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
290    /// - **Cache-friendly**: Linear memory access patterns
291    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
292    /// - **Division by Zero Checks**: Comprehensive safety validation
293    #[inline]
294    pub(crate) fn div_tensor_optimized(&self, other: &Tensor) -> Tensor {
295        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
296
297        let mut output = Tensor::new(self.shape().dims.clone());
298
299        unsafe {
300            let a = self.as_ptr();
301            let b = other.as_ptr();
302            let dst = output.as_mut_ptr();
303
304            #[cfg(target_arch = "x86_64")]
305            {
306                // Use SIMD for better performance when available
307                if is_x86_feature_detected!("avx2") {
308                    self.div_tensors_simd_avx2_optimized(a, b, dst);
309                    return output;
310                }
311            }
312
313            // Fallback to scalar operations with better cache usage
314            self.div_tensors_scalar_optimized(a, b, dst);
315        }
316
317        output
318    }
319
320    /// SIMD-optimized tensor division using AVX2 instructions
321    ///
322    /// Performs element-wise division using AVX2 SIMD instructions for maximum
323    /// performance on x86_64 hardware. Processes 32 elements per iteration with
324    /// 4x unrolling for optimal instruction throughput. Includes comprehensive
325    /// division by zero checking for safety.
326    ///
327    /// # Arguments
328    ///
329    /// * `a` - Pointer to first tensor data (numerator)
330    /// * `b` - Pointer to second tensor data (denominator)
331    /// * `dst` - Pointer to output tensor data
332    ///
333    /// # Safety
334    ///
335    /// Requires AVX2 support and valid pointers with sufficient memory.
336    /// All pointers must be aligned and point to valid tensor data.
337    /// Division by zero will panic.
338    ///
339    /// # Performance Characteristics
340    ///
341    /// - **SIMD Width**: 8 elements per AVX2 vector operation
342    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
343    /// - **Memory Access**: Linear access patterns for cache efficiency
344    /// - **Fallback**: Handles remaining elements with scalar operations
345    /// - **Safety Checks**: Comprehensive division by zero validation
346    #[cfg(target_arch = "x86_64")]
347    #[inline]
348    #[target_feature(enable = "avx2")]
349    unsafe fn div_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
350        let size = self.size();
351        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
352        let mut offset = 0;
353
354        // Unrolled SIMD loop for throughput
355        for _ in 0..simd_count {
356            // Process 4 AVX2 vectors (32 elements) per iteration
357            let a_vec1 = _mm256_loadu_ps(a.add(offset));
358            let b_vec1 = _mm256_loadu_ps(b.add(offset));
359            let div_vec1 = _mm256_div_ps(a_vec1, b_vec1);
360            _mm256_storeu_ps(dst.add(offset), div_vec1);
361
362            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
363            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
364            let div_vec2 = _mm256_div_ps(a_vec2, b_vec2);
365            _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
366
367            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
368            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
369            let div_vec3 = _mm256_div_ps(a_vec3, b_vec3);
370            _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
371
372            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
373            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
374            let div_vec4 = _mm256_div_ps(a_vec4, b_vec4);
375            _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
376
377            offset += 32;
378        }
379
380        // Handle remaining 8-element blocks, then tail with checks
381        let remaining_full_blocks = (size - offset) / 8;
382        for _ in 0..remaining_full_blocks {
383            let a_vec = _mm256_loadu_ps(a.add(offset));
384            let b_vec = _mm256_loadu_ps(b.add(offset));
385            // Fallback to scalar for safety checks
386            let mut a_vals = [0.0f32; 8];
387            let mut b_vals = [0.0f32; 8];
388            _mm256_storeu_ps(a_vals.as_mut_ptr(), a_vec);
389            _mm256_storeu_ps(b_vals.as_mut_ptr(), b_vec);
390            for t in 0..8 {
391                let j = offset + t;
392                if b_vals[t] == 0.0 {
393                    panic!("Division by zero detected at index {}", j);
394                }
395                *dst.add(j) = a_vals[t] / b_vals[t];
396            }
397            offset += 8;
398        }
399        while offset + 4 <= size {
400            let b0 = *b.add(offset);
401            let b1 = *b.add(offset + 1);
402            let b2 = *b.add(offset + 2);
403            let b3 = *b.add(offset + 3);
404            if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
405                panic!("Division by zero detected in unrolled loop");
406            }
407            *dst.add(offset) = *a.add(offset) / b0;
408            *dst.add(offset + 1) = *a.add(offset + 1) / b1;
409            *dst.add(offset + 2) = *a.add(offset + 2) / b2;
410            *dst.add(offset + 3) = *a.add(offset + 3) / b3;
411            offset += 4;
412        }
413        for i in offset..size {
414            let b_val = *b.add(i);
415            if b_val == 0.0 {
416                panic!("Division by zero detected at index {}", i);
417            }
418            *dst.add(i) = *a.add(i) / b_val;
419        }
420    }
421
422    /// Optimized scalar tensor division fallback
423    ///
424    /// Performs element-wise division using optimized scalar operations when
425    /// SIMD is not available. Uses 4x unrolling for better instruction-level
426    /// parallelism and cache efficiency. Includes comprehensive division by zero
427    /// checking for safety.
428    ///
429    /// # Arguments
430    ///
431    /// * `a` - Pointer to first tensor data (numerator)
432    /// * `b` - Pointer to second tensor data (denominator)
433    /// * `dst` - Pointer to output tensor data
434    ///
435    /// # Safety
436    ///
437    /// Requires valid pointers with sufficient memory for the tensor size.
438    /// All pointers must point to valid tensor data.
439    /// Division by zero will panic.
440    ///
441    /// # Performance Characteristics
442    ///
443    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
444    /// - **Memory Access**: Linear access patterns for cache efficiency
445    /// - **Fallback**: Handles remaining elements with scalar operations
446    /// - **Safety Checks**: Comprehensive division by zero validation
447    #[inline]
448    unsafe fn div_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
449        let size = self.size();
450
451        // Use unrolled loops for better instruction throughput
452        let unroll_count = size / 4;
453        let mut i = 0;
454
455        // Process 4 elements at a time for better cache utilization
456        while i < unroll_count {
457            let idx = i * 4;
458
459            // Check for division by zero
460            let b0 = b.add(idx).read();
461            let b1 = b.add(idx + 1).read();
462            let b2 = b.add(idx + 2).read();
463            let b3 = b.add(idx + 3).read();
464
465            if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
466                panic!("Division by zero detected in unrolled loop");
467            }
468
469            dst.add(idx).write(a.add(idx).read() / b0);
470            dst.add(idx + 1).write(a.add(idx + 1).read() / b1);
471            dst.add(idx + 2).write(a.add(idx + 2).read() / b2);
472            dst.add(idx + 3).write(a.add(idx + 3).read() / b3);
473            i += 1;
474        }
475
476        // Handle remaining elements
477        for j in (unroll_count * 4)..size {
478            let b_val = b.add(j).read();
479            if b_val == 0.0 {
480                panic!("Division by zero detected at index {}", j);
481            }
482            dst.add(j).write(a.add(j).read() / b_val);
483        }
484    }
485
486    /// Internal optimized scalar / tensor operation
487    ///
488    /// Performs element-wise division of a scalar into each element of the tensor,
489    /// using SIMD acceleration when available. This is the core implementation
490    /// used by `div_scalar()`.
491    ///
492    /// # Arguments
493    ///
494    /// * `scalar` - Scalar value to divide each element by (must not be zero)
495    ///
496    /// # Returns
497    ///
498    /// A new tensor with each element divided by the scalar
499    ///
500    /// # Safety
501    ///
502    /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
503    /// performance optimization. Division by zero will panic.
504    ///
505    /// # Performance Characteristics
506    ///
507    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
508    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
509    /// - **Cache-friendly**: Linear memory access patterns
510    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
511    /// - **Division by Zero Checks**: Comprehensive safety validation
512    #[inline]
513    pub(crate) fn div_scalar_optimized(&self, scalar: f32) -> Tensor {
514        if scalar == 0.0 {
515            panic!("Division by zero: cannot divide tensor by zero scalar");
516        }
517
518        let mut output = Tensor::new(self.shape().dims.clone());
519
520        unsafe {
521            let src = self.as_ptr();
522            let dst = output.as_mut_ptr();
523
524            #[cfg(target_arch = "x86_64")]
525            {
526                // Use SIMD for better performance when available
527                if is_x86_feature_detected!("avx2") {
528                    self.div_scalar_simd_avx2_optimized(src, dst, scalar);
529                    return output;
530                }
531            }
532
533            // Fallback to scalar operations with better cache usage
534            self.div_scalar_fallback_optimized(src, dst, scalar);
535        }
536
537        output
538    }
539
540    /// SIMD-optimized scalar division using AVX2 instructions
541    ///
542    /// Performs element-wise scalar division using AVX2 SIMD instructions for maximum
543    /// performance on x86_64 hardware. Processes 32 elements per iteration with
544    /// 4x unrolling for optimal instruction throughput.
545    ///
546    /// # Arguments
547    ///
548    /// * `src` - Pointer to source tensor data
549    /// * `dst` - Pointer to output tensor data
550    /// * `scalar` - Scalar value to divide each element by
551    ///
552    /// # Safety
553    ///
554    /// Requires AVX2 support and valid pointers with sufficient memory.
555    /// All pointers must be aligned and point to valid tensor data.
556    /// Scalar must not be zero (checked before calling this function).
557    ///
558    /// # Performance Characteristics
559    ///
560    /// - **SIMD Width**: 8 elements per AVX2 vector operation
561    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
562    /// - **Memory Access**: Linear access patterns for cache efficiency
563    /// - **Fallback**: Handles remaining elements with scalar operations
564    /// - **Optimization**: Most common scalar division pattern optimized
565    #[cfg(target_arch = "x86_64")]
566    #[inline]
567    #[target_feature(enable = "avx2")]
568    unsafe fn div_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
569        let size = self.size();
570        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
571        let mut offset = 0;
572
573        // Create SIMD vector for scalar
574        let scalar_vec = _mm256_set1_ps(scalar);
575
576        // Unrolled SIMD loop for throughput
577        for _ in 0..simd_count {
578            // Process 4 AVX2 vectors (32 elements) per iteration
579            let src_vec1 = _mm256_loadu_ps(src.add(offset));
580            let div_vec1 = _mm256_div_ps(src_vec1, scalar_vec);
581            _mm256_storeu_ps(dst.add(offset), div_vec1);
582
583            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
584            let div_vec2 = _mm256_div_ps(src_vec2, scalar_vec);
585            _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
586
587            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
588            let div_vec3 = _mm256_div_ps(src_vec3, scalar_vec);
589            _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
590
591            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
592            let div_vec4 = _mm256_div_ps(src_vec4, scalar_vec);
593            _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
594
595            offset += 32;
596        }
597
598        // Handle remaining elements with scalar operations
599        for i in offset..size {
600            *dst.add(i) = *src.add(i) / scalar;
601        }
602    }
603
604    /// Optimized scalar division fallback
605    ///
606    /// Performs element-wise scalar division using optimized scalar operations when
607    /// SIMD is not available. Uses 4x unrolling for better instruction-level
608    /// parallelism and cache efficiency.
609    ///
610    /// # Arguments
611    ///
612    /// * `src` - Pointer to source tensor data
613    /// * `dst` - Pointer to output tensor data
614    /// * `scalar` - Scalar value to divide each element by
615    ///
616    /// # Safety
617    ///
618    /// Requires valid pointers with sufficient memory for the tensor size.
619    /// All pointers must point to valid tensor data.
620    /// Scalar must not be zero (checked before calling this function).
621    ///
622    /// # Performance Characteristics
623    ///
624    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
625    /// - **Memory Access**: Linear access patterns for cache efficiency
626    /// - **Fallback**: Handles remaining elements with scalar operations
627    #[inline]
628    unsafe fn div_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
629        let size = self.size();
630
631        // Use unrolled loops for better instruction throughput
632        let unroll_count = size / 4;
633        let mut i = 0;
634
635        // Process 4 elements at a time for better cache utilization
636        while i < unroll_count {
637            let idx = i * 4;
638            dst.add(idx).write(src.add(idx).read() / scalar);
639            dst.add(idx + 1).write(src.add(idx + 1).read() / scalar);
640            dst.add(idx + 2).write(src.add(idx + 2).read() / scalar);
641            dst.add(idx + 3).write(src.add(idx + 3).read() / scalar);
642            i += 1;
643        }
644
645        // Handle remaining elements
646        for j in (unroll_count * 4)..size {
647            dst.add(j).write(src.add(j).read() / scalar);
648        }
649    }
650}
651
652#[cfg(test)]
653mod tests {
654    use super::*;
655
656    #[test]
657    fn test_tensor_division() {
658        let a = Tensor::ones(vec![2, 3]);
659        let mut b = Tensor::ones(vec![2, 3]);
660        b.fill(2.0);
661        let result = a.div_tensor_optimized(&b);
662
663        assert_eq!(result.shape().dims, vec![2, 3]);
664        assert_eq!(result.size(), 6);
665
666        // Check that all values are 0.5 (1.0 / 2.0)
667        unsafe {
668            for i in 0..result.size() {
669                assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
670            }
671        }
672    }
673
674    #[test]
675    fn test_scalar_division() {
676        let tensor = Tensor::ones(vec![2, 2]);
677        let result = tensor.div_scalar_optimized(2.0);
678
679        assert_eq!(result.shape().dims, vec![2, 2]);
680        assert_eq!(result.size(), 4);
681
682        // Check that all values are 0.5 (1.0 / 2.0)
683        unsafe {
684            for i in 0..result.size() {
685                assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
686            }
687        }
688    }
689
690    #[test]
691    fn test_negative_division() {
692        let tensor = Tensor::ones(vec![2, 3]);
693        let result = tensor.div_scalar_optimized(-2.0);
694
695        assert_eq!(result.shape().dims, vec![2, 3]);
696        assert_eq!(result.size(), 6);
697
698        // Check that all values are -0.5 (1.0 / -2.0)
699        unsafe {
700            for i in 0..result.size() {
701                assert!((result.as_ptr().add(i).read() - (-0.5)).abs() < 1e-6);
702            }
703        }
704    }
705
706    #[test]
707    #[should_panic(expected = "Division by zero")]
708    fn test_division_by_zero_scalar() {
709        let tensor = Tensor::ones(vec![2, 3]);
710        tensor.div_scalar_optimized(0.0);
711    }
712
713    #[test]
714    #[should_panic(expected = "Division by zero")]
715    fn test_division_by_zero_tensor() {
716        let a = Tensor::ones(vec![2, 3]);
717        let b = Tensor::zeros(vec![2, 3]);
718        a.div_tensor_optimized(&b);
719    }
720
721    #[test]
722    #[should_panic(expected = "Tensor shapes must match")]
723    fn test_mismatched_shapes() {
724        let a = Tensor::ones(vec![2, 3]);
725        let b = Tensor::ones(vec![3, 2]);
726        a.div_tensor_optimized(&b);
727    }
728
729    #[test]
730    fn test_edge_cases() {
731        // Test with zero numerator
732        let zero_tensor = Tensor::zeros(vec![2, 3]);
733        let other = Tensor::ones(vec![2, 3]);
734        let result = zero_tensor.div_tensor_optimized(&other);
735
736        assert_eq!(result.shape().dims, vec![2, 3]);
737        assert_eq!(result.size(), 6);
738
739        // Check that all values are 0.0 (0.0 / 1.0)
740        unsafe {
741            for i in 0..result.size() {
742                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
743            }
744        }
745
746        // Test with negative values
747        let mut neg_tensor = Tensor::ones(vec![2, 3]);
748        neg_tensor.fill(-4.0);
749        let result = neg_tensor.div_scalar_optimized(2.0);
750
751        assert_eq!(result.shape().dims, vec![2, 3]);
752        assert_eq!(result.size(), 6);
753
754        // Check that all values are -2.0 (-4.0 / 2.0)
755        unsafe {
756            for i in 0..result.size() {
757                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
758            }
759        }
760    }
761
762    #[test]
763    fn test_large_tensor_division() {
764        let a = Tensor::ones(vec![100, 100]);
765        let mut b = Tensor::ones(vec![100, 100]);
766        b.fill(1.5);
767        let result = a.div_tensor_optimized(&b);
768
769        assert_eq!(result.shape().dims, vec![100, 100]);
770        assert_eq!(result.size(), 10000);
771
772        // Check that all values are 0.666... (1.0 / 1.5)
773        unsafe {
774            for i in 0..result.size() {
775                assert!((result.as_ptr().add(i).read() - (2.0 / 3.0)).abs() < 1e-6);
776            }
777        }
778    }
779
    #[test]
    fn test_division_with_gradtrack() {
        // Test scalar division with gradtrack: forward values first, then
        // backward() and the accumulated gradient on the input.
        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut result = a.div_scalar(2.0);

        // Check result values: 1.0 / 2.0 = 0.5
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
            }
        }

        // Seed gradient of ones (None => default upstream gradient).
        result.backward(None);

        // Check gradient: d/dx(x/2) = 1/2
        if let Some(grad) = a.grad_by_value() {
            unsafe {
                for i in 0..grad.size() {
                    let val = grad.as_ptr().add(i).read();
                    assert!(
                        (val - 0.5).abs() < 1e-6,
                        "Expected gradient 0.5, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient computed for scalar division!");
        }

        // Test tensor division with gradtrack: both operands require grad,
        // so backward() must populate gradients for a AND b.
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(2.0);
        // Enable grad tracking only after fill(), so the fill itself is not recorded.
        let b = b.with_requires_grad();

        let mut result = a.div_tensor(&b);

        // Check result values: 1.0 / 2.0 = 0.5
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
            }
        }

        result.backward(None);

        // Check gradients: ∂(a/b)/∂a = 1/b, ∂(a/b)/∂b = -a/b²
        // For a = 1.0, b = 2.0: ∂(a/b)/∂a = 0.5, ∂(a/b)/∂b = -0.25
        if let Some(grad_a) = a.grad_by_value() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 0.5).abs() < 1e-6,
                        "Expected gradient A = 0.5 (∂(a/b)/∂a = 1/b), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for tensor division!");
        }

        if let Some(grad_b) = b.grad_by_value() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - (-0.25)).abs() < 1e-6,
                        "Expected gradient B = -0.25 (∂(a/b)/∂b = -a/b²), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for tensor division!");
        }
    }
862
863    #[test]
864    fn test_mixed_div_mul_operations_with_gradtrack() {
865        // Test complex computation: (a / 2) * (b / 3) + 1
866        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
867        let mut b = Tensor::ones(vec![2, 2]);
868        b.fill(6.0);
869        let b = b.with_requires_grad();
870
871        let scalar1 = 2.0;
872        let scalar2 = 3.0;
873
874        // Compute: (a / scalar1) * (b / scalar2) + 1
875        let div_a = a.div_scalar(scalar1); // a / 2
876        let div_b = b.div_scalar(scalar2); // b / 3
877        let mul_result = div_a.mul_tensor(&div_b); // (a / 2) * (b / 3)
878        let mut final_result = mul_result.add_scalar(1.0); // (a / 2) * (b / 3) + 1
879
880        // Check result values: (1/2) * (6/3) + 1 = 0.5 * 2 + 1 = 2
881        unsafe {
882            for i in 0..final_result.size() {
883                let val = final_result.as_ptr().add(i).read();
884                assert!((val - 2.0).abs() < 1e-6, "Expected 2.0, got {}", val);
885            }
886        }
887
888        final_result.backward(None);
889
890        // Check gradients: d/dx((x/2) * (y/3) + 1) = (y/3) * (1/2) = y/6
891        // d/dy((x/2) * (y/3) + 1) = (x/2) * (1/3) = x/6
892        // For x = 1.0, y = 6.0: d/dx = 6/6 = 1.0, d/dy = 1/6 = 0.166...
893        if let Some(grad_a) = a.grad_by_value() {
894            unsafe {
895                for i in 0..grad_a.size() {
896                    let val = grad_a.as_ptr().add(i).read();
897                    assert!(
898                        (val - 1.0).abs() < 1e-6,
899                        "Expected gradient A = 1.0, got {}",
900                        val
901                    );
902                }
903            }
904        } else {
905            panic!("No gradient A computed for mixed operations!");
906        }
907
908        if let Some(grad_b) = b.grad_by_value() {
909            unsafe {
910                for i in 0..grad_b.size() {
911                    let val = grad_b.as_ptr().add(i).read();
912                    assert!(
913                        (val - (1.0 / 6.0)).abs() < 1e-6,
914                        "Expected gradient B = 1/6, got {}",
915                        val
916                    );
917                }
918            }
919        } else {
920            panic!("No gradient B computed for mixed operations!");
921        }
922    }
923
924    #[test]
925    fn test_div_broadcasting_gradients_basic() {
926        use crate::gradtrack::clear_gradients;
927        clear_gradients();
928
929        // Test case: [2, 3] / [1, 3] -> [2, 3]
930        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2
931
932        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
933            .unwrap()
934            .with_requires_grad();
935        let b = Tensor::from_slice(&[2.0, 2.0, 2.0], vec![1, 3])
936            .unwrap()
937            .with_requires_grad();
938
939        let mut result = a.div_tensor(&b);
940        assert_eq!(result.shape().dims, vec![2, 3]);
941
942        // Set upstream gradient as ones
943        result.backward(None);
944
945        let grad_a = a.grad_by_value().expect("grad_a should exist");
946        let grad_b = b.grad_by_value().expect("grad_b should exist");
947
948        println!(
949            "Original shapes: a={:?}, b={:?}",
950            a.shape().dims,
951            b.shape().dims
952        );
953        println!(
954            "Gradient shapes: grad_a={:?}, grad_b={:?}",
955            grad_a.shape().dims,
956            grad_b.shape().dims
957        );
958
959        // grad_a should have same shape as a: [2, 3]
960        assert_eq!(
961            grad_a.shape().dims,
962            vec![2, 3],
963            "grad_a should match original shape of a"
964        );
965
966        // grad_b should have same shape as b: [1, 3]
967        assert_eq!(
968            grad_b.shape().dims,
969            vec![1, 3],
970            "grad_b should match original shape of b"
971        );
972
973        // grad_a should be [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] (1/b = 1/2)
974        for i in 0..grad_a.size() {
975            let val = unsafe { *grad_a.as_ptr().add(i) };
976            assert!(
977                (val - 0.5).abs() < 1e-6,
978                "grad_a[{}] = {} should be 0.5",
979                i,
980                val
981            );
982        }
983
984        // For grad_b: d/db (a / b) = -a/b^2, summed over broadcast dimension
985        // a = [2,4,6,8,10,12], b = [2,2,2], so -a/b^2 = [-2/4, -4/4, -6/4, -8/4, -10/4, -12/4] = [-0.5, -1, -1.5, -2, -2.5, -3]
986        // Summed over first dimension: [-0.5-2, -1-2.5, -1.5-3] = [-2.5, -3.5, -4.5]
987        let expected_grad_b = [-2.5, -3.5, -4.5];
988        for (i, &expected) in expected_grad_b.iter().enumerate() {
989            let val = unsafe { *grad_b.as_ptr().add(i) };
990            assert!(
991                (val - expected).abs() < 1e-6,
992                "grad_b[{}] = {} should be {}",
993                i,
994                val,
995                expected
996            );
997        }
998    }
999
1000    #[test]
1001    fn test_div_scalar_broadcasting_gradients() {
1002        use crate::gradtrack::clear_gradients;
1003        clear_gradients();
1004
1005        // Test case: [2, 3] / [1] -> [2, 3]
1006        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2
1007
1008        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
1009            .unwrap()
1010            .with_requires_grad();
1011        let b = Tensor::from_slice(&[2.0], vec![1])
1012            .unwrap()
1013            .with_requires_grad();
1014
1015        let mut result = a.div_tensor(&b);
1016        result.backward(None);
1017
1018        let grad_a = a.grad_by_value().expect("grad_a should exist");
1019        let grad_b = b.grad_by_value().expect("grad_b should exist");
1020
1021        // grad_a should have same shape as a: [2, 3]
1022        assert_eq!(grad_a.shape().dims, vec![2, 3]);
1023
1024        // grad_b should have same shape as b: [1]
1025        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
1026        assert_eq!(grad_b.shape().dims, vec![1]);
1027
1028        // grad_b should be the sum of -a/b^2 = -(2+4+6+8+10+12)/4 = -42/4 = -10.5
1029        let val = unsafe { *grad_b.as_ptr() };
1030        assert!(
1031            (val - (-10.5)).abs() < 1e-6,
1032            "grad_b = {} should be -10.5",
1033            val
1034        );
1035    }
1036}