// train_station/tensor/ops/div.rs

1//! Division operations for tensors
2//!
3//! Provides element-wise division following PyTorch conventions with comprehensive
4//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Division**: `div_tensor()` - Division with another tensor (PyTorch `div()` equivalent)
9//! - **Scalar Broadcasting**: `div_scalar()` - Division by scalar values
10//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
11//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
12//! - **Automatic Differentiation**: Full gradtrack support with gradient tracking
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Zero-copy Operations**: Efficient memory usage where possible
15//! - **Division by Zero Protection**: Comprehensive error checking and validation
16//!
17//! # Broadcasting Support
18//!
19//! All division operations support automatic broadcasting following NumPy rules:
20//! - Dimensions are aligned from the rightmost dimension
21//! - Dimensions are compatible if they are equal, or one of them is 1
22//! - Missing dimensions are treated as 1
23//! - Result shape follows broadcasting rules
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
28//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32//! - **Division by Zero Checks**: Optimized safety checks with minimal performance impact
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44    /// Element-wise division with another tensor with broadcasting support.
45    ///
46    /// Performs element-wise division with automatic broadcasting: `output[i] = self[i] / other[i]`
47    ///
48    /// Broadcasting enables division between tensors of different but compatible shapes.
49    /// Compatible shapes follow NumPy broadcasting rules:
50    /// - Dimensions are aligned from the rightmost dimension
51    /// - Dimensions are compatible if they are equal, or one of them is 1
52    /// - Missing dimensions are treated as 1
53    ///
54    /// # Arguments
55    /// * `other` - Tensor to divide by. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    /// A new tensor containing the element-wise quotient with broadcast result shape
59    ///
60    /// # Examples
61    ///
62    /// ## Same Shape Division
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
68    /// let b = Tensor::from_slice(&[2.0, 4.0, 5.0], vec![3]).unwrap();
69    /// let c = a.div_tensor(&b);
70    /// assert_eq!(c.shape().dims, vec![3]);
71    /// assert_eq!(c.get(&[0]), 5.0);
72    /// assert_eq!(c.get(&[1]), 5.0);
73    /// assert_eq!(c.get(&[2]), 6.0);
74    /// ```
75    ///
76    /// ## Broadcasting Division
77    ///
78    /// ```
79    /// use train_station::Tensor;
80    ///
81    /// // Broadcasting: [2, 1] / [1, 3] -> [2, 3]
82    /// let a = Tensor::from_slice(&[10.0, 20.0], vec![2, 1]).unwrap();
83    /// let b = Tensor::from_slice(&[1.0, 2.0, 5.0], vec![1, 3]).unwrap();
84    /// let c = a.div_tensor(&b);
85    /// assert_eq!(c.shape().dims, vec![2, 3]);
86    /// assert_eq!(c.get(&[0, 0]), 10.0);
87    /// assert_eq!(c.get(&[0, 1]), 5.0);
88    /// assert_eq!(c.get(&[1, 0]), 20.0);
89    /// assert_eq!(c.get(&[1, 1]), 10.0);
90    /// ```
91    ///
92    /// ## Scalar Division
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // Scalar division: [2, 3] / scalar -> [2, 3]
98    /// let a = Tensor::ones(vec![2, 3]);
99    /// let b = Tensor::from_slice(&[2.0], vec![1]).unwrap();
100    /// let c = a.div_tensor(&b);
101    /// assert_eq!(c.shape().dims, vec![2, 3]);
102    /// assert_eq!(c.get(&[0, 0]), 0.5);
103    /// assert_eq!(c.get(&[1, 2]), 0.5);
104    /// ```
105    ///
106    /// # Panics
107    /// Panics if tensor shapes are not broadcast-compatible or division by zero
108    #[inline]
109    pub fn div_tensor(&self, other: &Tensor) -> Tensor {
110        // Check if shapes are identical for fast path
111        if self.shape().dims == other.shape().dims {
112            return self.div_tensor_same_shape(other);
113        }
114
115        // Use broadcasting for different shapes
116        let (broadcast_self, broadcast_other, _result_shape) =
117            self.broadcast_with(other).unwrap_or_else(|e| {
118                panic!(
119                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
120                    self.shape().dims,
121                    other.shape().dims,
122                    e
123                );
124            });
125
126        // Perform element-wise division on broadcasted tensors
127        let mut result = broadcast_self.div_tensor_optimized(&broadcast_other);
128
129        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
130            result.set_requires_grad_internal(true);
131            let operands = vec![self.clone(), other.clone()];
132            let grad_fn = GradFn::Div {
133                is_tensor_div: true,
134                scalar: None,
135                operands: Some(operands),
136                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
137            };
138            result.set_grad_fn(grad_fn.clone());
139
140            let mut input_ids = Vec::with_capacity(2);
141            if self.requires_grad() {
142                input_ids.push(self.id());
143            }
144            if other.requires_grad() {
145                input_ids.push(other.id());
146            }
147            GradEngine::register_operation(result.id(), input_ids, grad_fn);
148        }
149
150        result
151    }
152
153    /// Element-wise division for tensors with identical shapes (fast path).
154    ///
155    /// This is an optimized path for tensors that already have the same shape,
156    /// avoiding the overhead of broadcasting computation. Used internally by
157    /// `div_tensor()` when shapes are identical.
158    ///
159    /// # Arguments
160    /// * `other` - Tensor to divide by, must have the same shape as self
161    ///
162    /// # Returns
163    /// A new tensor containing the element-wise quotient
164    ///
165    /// # Performance Characteristics
166    ///
167    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
168    /// - **SIMD Optimization**: Uses optimized tensor division with SIMD acceleration
169    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
170    ///
171    /// # Panics
172    ///
173    /// Panics if tensor shapes do not match or if any element in `other` is zero
174    #[inline]
175    fn div_tensor_same_shape(&self, other: &Tensor) -> Tensor {
176        assert_eq!(
177            self.shape(),
178            other.shape(),
179            "Tensor shapes must match for same-shape division"
180        );
181        let mut result = self.div_tensor_optimized(other);
182
183        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
184            result.set_requires_grad_internal(true);
185            let operands = vec![self.clone(), other.clone()];
186            let grad_fn = GradFn::Div {
187                is_tensor_div: true,
188                scalar: None,
189                operands: Some(operands),
190                original_shapes: None, // Same shape case
191            };
192            result.set_grad_fn(grad_fn.clone());
193
194            let mut input_ids = Vec::with_capacity(2);
195            if self.requires_grad() {
196                input_ids.push(self.id());
197            }
198            if other.requires_grad() {
199                input_ids.push(other.id());
200            }
201            GradEngine::register_operation(result.id(), input_ids, grad_fn);
202        }
203
204        result
205    }
206
207    /// Broadcast division with a scalar value.
208    ///
209    /// Divides every element by the scalar: `output[i] = self[i] / scalar`
210    ///
211    /// # Arguments
212    /// * `scalar` - Value to divide each element by (must not be zero)
213    ///
214    /// # Returns
215    /// A new tensor with each element divided by the scalar
216    ///
217    /// # Examples
218    ///
219    /// ## Basic Scalar Division
220    ///
221    /// ```
222    /// use train_station::Tensor;
223    ///
224    /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
225    /// let b = a.div_scalar(10.0);
226    /// assert_eq!(b.shape().dims, vec![3]);
227    /// assert_eq!(b.get(&[0]), 1.0);
228    /// assert_eq!(b.get(&[1]), 2.0);
229    /// assert_eq!(b.get(&[2]), 3.0);
230    /// ```
231    ///
232    /// ## Multi-dimensional Scalar Division
233    ///
234    /// ```
235    /// use train_station::Tensor;
236    ///
237    /// let a = Tensor::ones(vec![2, 3]);
238    /// let b = a.div_scalar(2.0);
239    /// assert_eq!(b.shape().dims, vec![2, 3]);
240    /// assert_eq!(b.get(&[0, 0]), 0.5);
241    /// assert_eq!(b.get(&[1, 2]), 0.5);
242    /// ```
243    ///
244    /// # Panics
245    /// Panics if scalar is zero
246    #[inline]
247    pub fn div_scalar(&self, scalar: f32) -> Tensor {
248        let mut result = self.div_scalar_optimized(scalar);
249
250        if self.requires_grad() && is_grad_enabled() {
251            result.set_requires_grad_internal(true);
252            let grad_fn = GradFn::Div {
253                is_tensor_div: false,
254                scalar: Some(scalar),
255                operands: None,
256                original_shapes: None, // Scalar case
257            };
258            result.set_grad_fn(grad_fn.clone());
259            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
260        }
261
262        result
263    }
264    /// Internal optimized tensor / tensor operation
265    ///
266    /// Performs element-wise division between two tensors with the same shape,
267    /// using SIMD acceleration when available. This is the core implementation
268    /// used by `div_tensor()` after broadcasting has been applied.
269    ///
270    /// # Arguments
271    ///
272    /// * `other` - Tensor to divide by, must have the same shape as self
273    ///
274    /// # Returns
275    ///
276    /// A new tensor containing the element-wise quotient
277    ///
278    /// # Safety
279    ///
280    /// Assumes both tensors have the same shape and valid memory layouts.
281    /// Uses unsafe SIMD operations for performance optimization.
282    /// Division by zero will panic.
283    ///
284    /// # Performance Characteristics
285    ///
286    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
287    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
288    /// - **Cache-friendly**: Linear memory access patterns
289    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
290    /// - **Division by Zero Checks**: Comprehensive safety validation
291    #[inline]
292    pub(crate) fn div_tensor_optimized(&self, other: &Tensor) -> Tensor {
293        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
294
295        let mut output = Tensor::new(self.shape().dims.clone());
296
297        unsafe {
298            let a = self.as_ptr();
299            let b = other.as_ptr();
300            let dst = output.as_mut_ptr();
301
302            #[cfg(target_arch = "x86_64")]
303            {
304                // Use SIMD for better performance when available
305                if is_x86_feature_detected!("avx2") {
306                    self.div_tensors_simd_avx2_optimized(a, b, dst);
307                    return output;
308                }
309            }
310
311            // Fallback to scalar operations with better cache usage
312            self.div_tensors_scalar_optimized(a, b, dst);
313        }
314
315        output
316    }
317
318    /// SIMD-optimized tensor division using AVX2 instructions
319    ///
320    /// Performs element-wise division using AVX2 SIMD instructions for maximum
321    /// performance on x86_64 hardware. Processes 32 elements per iteration with
322    /// 4x unrolling for optimal instruction throughput. Includes comprehensive
323    /// division by zero checking for safety.
324    ///
325    /// # Arguments
326    ///
327    /// * `a` - Pointer to first tensor data (numerator)
328    /// * `b` - Pointer to second tensor data (denominator)
329    /// * `dst` - Pointer to output tensor data
330    ///
331    /// # Safety
332    ///
333    /// Requires AVX2 support and valid pointers with sufficient memory.
334    /// All pointers must be aligned and point to valid tensor data.
335    /// Division by zero will panic.
336    ///
337    /// # Performance Characteristics
338    ///
339    /// - **SIMD Width**: 8 elements per AVX2 vector operation
340    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
341    /// - **Memory Access**: Linear access patterns for cache efficiency
342    /// - **Fallback**: Handles remaining elements with scalar operations
343    /// - **Safety Checks**: Comprehensive division by zero validation
344    #[cfg(target_arch = "x86_64")]
345    #[inline]
346    #[target_feature(enable = "avx2")]
347    unsafe fn div_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
348        let size = self.size();
349        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
350        let mut offset = 0;
351
352        // Unrolled SIMD loop for throughput
353        for _ in 0..simd_count {
354            // Process 4 AVX2 vectors (32 elements) per iteration
355            let a_vec1 = _mm256_loadu_ps(a.add(offset));
356            let b_vec1 = _mm256_loadu_ps(b.add(offset));
357            let div_vec1 = _mm256_div_ps(a_vec1, b_vec1);
358            _mm256_storeu_ps(dst.add(offset), div_vec1);
359
360            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
361            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
362            let div_vec2 = _mm256_div_ps(a_vec2, b_vec2);
363            _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
364
365            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
366            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
367            let div_vec3 = _mm256_div_ps(a_vec3, b_vec3);
368            _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
369
370            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
371            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
372            let div_vec4 = _mm256_div_ps(a_vec4, b_vec4);
373            _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
374
375            offset += 32;
376        }
377
378        // Handle remaining 8-element blocks, then tail with checks
379        let remaining_full_blocks = (size - offset) / 8;
380        for _ in 0..remaining_full_blocks {
381            let a_vec = _mm256_loadu_ps(a.add(offset));
382            let b_vec = _mm256_loadu_ps(b.add(offset));
383            // Fallback to scalar for safety checks
384            let mut a_vals = [0.0f32; 8];
385            let mut b_vals = [0.0f32; 8];
386            _mm256_storeu_ps(a_vals.as_mut_ptr(), a_vec);
387            _mm256_storeu_ps(b_vals.as_mut_ptr(), b_vec);
388            for t in 0..8 {
389                let j = offset + t;
390                if b_vals[t] == 0.0 {
391                    panic!("Division by zero detected at index {}", j);
392                }
393                *dst.add(j) = a_vals[t] / b_vals[t];
394            }
395            offset += 8;
396        }
397        while offset + 4 <= size {
398            let b0 = *b.add(offset);
399            let b1 = *b.add(offset + 1);
400            let b2 = *b.add(offset + 2);
401            let b3 = *b.add(offset + 3);
402            if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
403                panic!("Division by zero detected in unrolled loop");
404            }
405            *dst.add(offset) = *a.add(offset) / b0;
406            *dst.add(offset + 1) = *a.add(offset + 1) / b1;
407            *dst.add(offset + 2) = *a.add(offset + 2) / b2;
408            *dst.add(offset + 3) = *a.add(offset + 3) / b3;
409            offset += 4;
410        }
411        for i in offset..size {
412            let b_val = *b.add(i);
413            if b_val == 0.0 {
414                panic!("Division by zero detected at index {}", i);
415            }
416            *dst.add(i) = *a.add(i) / b_val;
417        }
418    }
419
420    /// Optimized scalar tensor division fallback
421    ///
422    /// Performs element-wise division using optimized scalar operations when
423    /// SIMD is not available. Uses 4x unrolling for better instruction-level
424    /// parallelism and cache efficiency. Includes comprehensive division by zero
425    /// checking for safety.
426    ///
427    /// # Arguments
428    ///
429    /// * `a` - Pointer to first tensor data (numerator)
430    /// * `b` - Pointer to second tensor data (denominator)
431    /// * `dst` - Pointer to output tensor data
432    ///
433    /// # Safety
434    ///
435    /// Requires valid pointers with sufficient memory for the tensor size.
436    /// All pointers must point to valid tensor data.
437    /// Division by zero will panic.
438    ///
439    /// # Performance Characteristics
440    ///
441    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
442    /// - **Memory Access**: Linear access patterns for cache efficiency
443    /// - **Fallback**: Handles remaining elements with scalar operations
444    /// - **Safety Checks**: Comprehensive division by zero validation
445    #[inline]
446    unsafe fn div_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
447        let size = self.size();
448
449        // Use unrolled loops for better instruction throughput
450        let unroll_count = size / 4;
451        let mut i = 0;
452
453        // Process 4 elements at a time for better cache utilization
454        while i < unroll_count {
455            let idx = i * 4;
456
457            // Check for division by zero
458            let b0 = b.add(idx).read();
459            let b1 = b.add(idx + 1).read();
460            let b2 = b.add(idx + 2).read();
461            let b3 = b.add(idx + 3).read();
462
463            if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
464                panic!("Division by zero detected in unrolled loop");
465            }
466
467            dst.add(idx).write(a.add(idx).read() / b0);
468            dst.add(idx + 1).write(a.add(idx + 1).read() / b1);
469            dst.add(idx + 2).write(a.add(idx + 2).read() / b2);
470            dst.add(idx + 3).write(a.add(idx + 3).read() / b3);
471            i += 1;
472        }
473
474        // Handle remaining elements
475        for j in (unroll_count * 4)..size {
476            let b_val = b.add(j).read();
477            if b_val == 0.0 {
478                panic!("Division by zero detected at index {}", j);
479            }
480            dst.add(j).write(a.add(j).read() / b_val);
481        }
482    }
483
484    /// Internal optimized scalar / tensor operation
485    ///
486    /// Performs element-wise division of a scalar into each element of the tensor,
487    /// using SIMD acceleration when available. This is the core implementation
488    /// used by `div_scalar()`.
489    ///
490    /// # Arguments
491    ///
492    /// * `scalar` - Scalar value to divide each element by (must not be zero)
493    ///
494    /// # Returns
495    ///
496    /// A new tensor with each element divided by the scalar
497    ///
498    /// # Safety
499    ///
500    /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
501    /// performance optimization. Division by zero will panic.
502    ///
503    /// # Performance Characteristics
504    ///
505    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
506    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
507    /// - **Cache-friendly**: Linear memory access patterns
508    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
509    /// - **Division by Zero Checks**: Comprehensive safety validation
510    #[inline]
511    pub(crate) fn div_scalar_optimized(&self, scalar: f32) -> Tensor {
512        if scalar == 0.0 {
513            panic!("Division by zero: cannot divide tensor by zero scalar");
514        }
515
516        let mut output = Tensor::new(self.shape().dims.clone());
517
518        unsafe {
519            let src = self.as_ptr();
520            let dst = output.as_mut_ptr();
521
522            #[cfg(target_arch = "x86_64")]
523            {
524                // Use SIMD for better performance when available
525                if is_x86_feature_detected!("avx2") {
526                    self.div_scalar_simd_avx2_optimized(src, dst, scalar);
527                    return output;
528                }
529            }
530
531            // Fallback to scalar operations with better cache usage
532            self.div_scalar_fallback_optimized(src, dst, scalar);
533        }
534
535        output
536    }
537
538    /// SIMD-optimized scalar division using AVX2 instructions
539    ///
540    /// Performs element-wise scalar division using AVX2 SIMD instructions for maximum
541    /// performance on x86_64 hardware. Processes 32 elements per iteration with
542    /// 4x unrolling for optimal instruction throughput.
543    ///
544    /// # Arguments
545    ///
546    /// * `src` - Pointer to source tensor data
547    /// * `dst` - Pointer to output tensor data
548    /// * `scalar` - Scalar value to divide each element by
549    ///
550    /// # Safety
551    ///
552    /// Requires AVX2 support and valid pointers with sufficient memory.
553    /// All pointers must be aligned and point to valid tensor data.
554    /// Scalar must not be zero (checked before calling this function).
555    ///
556    /// # Performance Characteristics
557    ///
558    /// - **SIMD Width**: 8 elements per AVX2 vector operation
559    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
560    /// - **Memory Access**: Linear access patterns for cache efficiency
561    /// - **Fallback**: Handles remaining elements with scalar operations
562    /// - **Optimization**: Most common scalar division pattern optimized
563    #[cfg(target_arch = "x86_64")]
564    #[inline]
565    #[target_feature(enable = "avx2")]
566    unsafe fn div_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
567        let size = self.size();
568        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
569        let mut offset = 0;
570
571        // Create SIMD vector for scalar
572        let scalar_vec = _mm256_set1_ps(scalar);
573
574        // Unrolled SIMD loop for throughput
575        for _ in 0..simd_count {
576            // Process 4 AVX2 vectors (32 elements) per iteration
577            let src_vec1 = _mm256_loadu_ps(src.add(offset));
578            let div_vec1 = _mm256_div_ps(src_vec1, scalar_vec);
579            _mm256_storeu_ps(dst.add(offset), div_vec1);
580
581            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
582            let div_vec2 = _mm256_div_ps(src_vec2, scalar_vec);
583            _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
584
585            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
586            let div_vec3 = _mm256_div_ps(src_vec3, scalar_vec);
587            _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
588
589            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
590            let div_vec4 = _mm256_div_ps(src_vec4, scalar_vec);
591            _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
592
593            offset += 32;
594        }
595
596        // Handle remaining elements with scalar operations
597        for i in offset..size {
598            *dst.add(i) = *src.add(i) / scalar;
599        }
600    }
601
602    /// Optimized scalar division fallback
603    ///
604    /// Performs element-wise scalar division using optimized scalar operations when
605    /// SIMD is not available. Uses 4x unrolling for better instruction-level
606    /// parallelism and cache efficiency.
607    ///
608    /// # Arguments
609    ///
610    /// * `src` - Pointer to source tensor data
611    /// * `dst` - Pointer to output tensor data
612    /// * `scalar` - Scalar value to divide each element by
613    ///
614    /// # Safety
615    ///
616    /// Requires valid pointers with sufficient memory for the tensor size.
617    /// All pointers must point to valid tensor data.
618    /// Scalar must not be zero (checked before calling this function).
619    ///
620    /// # Performance Characteristics
621    ///
622    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
623    /// - **Memory Access**: Linear access patterns for cache efficiency
624    /// - **Fallback**: Handles remaining elements with scalar operations
625    #[inline]
626    unsafe fn div_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
627        let size = self.size();
628
629        // Use unrolled loops for better instruction throughput
630        let unroll_count = size / 4;
631        let mut i = 0;
632
633        // Process 4 elements at a time for better cache utilization
634        while i < unroll_count {
635            let idx = i * 4;
636            dst.add(idx).write(src.add(idx).read() / scalar);
637            dst.add(idx + 1).write(src.add(idx + 1).read() / scalar);
638            dst.add(idx + 2).write(src.add(idx + 2).read() / scalar);
639            dst.add(idx + 3).write(src.add(idx + 3).read() / scalar);
640            i += 1;
641        }
642
643        // Handle remaining elements
644        for j in (unroll_count * 4)..size {
645            dst.add(j).write(src.add(j).read() / scalar);
646        }
647    }
648}
649
650#[cfg(test)]
651mod tests {
652    use super::*;
653
654    #[test]
655    fn test_tensor_division() {
656        let a = Tensor::ones(vec![2, 3]);
657        let mut b = Tensor::ones(vec![2, 3]);
658        b.fill(2.0);
659        let result = a.div_tensor_optimized(&b);
660
661        assert_eq!(result.shape().dims, vec![2, 3]);
662        assert_eq!(result.size(), 6);
663
664        // Check that all values are 0.5 (1.0 / 2.0)
665        unsafe {
666            for i in 0..result.size() {
667                assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
668            }
669        }
670    }
671
672    #[test]
673    fn test_scalar_division() {
674        let tensor = Tensor::ones(vec![2, 2]);
675        let result = tensor.div_scalar_optimized(2.0);
676
677        assert_eq!(result.shape().dims, vec![2, 2]);
678        assert_eq!(result.size(), 4);
679
680        // Check that all values are 0.5 (1.0 / 2.0)
681        unsafe {
682            for i in 0..result.size() {
683                assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
684            }
685        }
686    }
687
688    #[test]
689    fn test_negative_division() {
690        let tensor = Tensor::ones(vec![2, 3]);
691        let result = tensor.div_scalar_optimized(-2.0);
692
693        assert_eq!(result.shape().dims, vec![2, 3]);
694        assert_eq!(result.size(), 6);
695
696        // Check that all values are -0.5 (1.0 / -2.0)
697        unsafe {
698            for i in 0..result.size() {
699                assert!((result.as_ptr().add(i).read() - (-0.5)).abs() < 1e-6);
700            }
701        }
702    }
703
    /// Dividing by a zero scalar must panic with a "Division by zero" message.
    #[test]
    #[should_panic(expected = "Division by zero")]
    fn test_division_by_zero_scalar() {
        let tensor = Tensor::ones(vec![2, 3]);
        tensor.div_scalar_optimized(0.0);
    }
710
    /// A zero anywhere in the denominator tensor must panic with a
    /// "Division by zero" message (exercises the kernel's safety checks).
    #[test]
    #[should_panic(expected = "Division by zero")]
    fn test_division_by_zero_tensor() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::zeros(vec![2, 3]);
        a.div_tensor_optimized(&b);
    }
718
    /// The optimized kernel requires identical shapes (broadcasting is the
    /// caller's responsibility) and must panic on a mismatch.
    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.div_tensor_optimized(&b);
    }
726
727    #[test]
728    fn test_edge_cases() {
729        // Test with zero numerator
730        let zero_tensor = Tensor::zeros(vec![2, 3]);
731        let other = Tensor::ones(vec![2, 3]);
732        let result = zero_tensor.div_tensor_optimized(&other);
733
734        assert_eq!(result.shape().dims, vec![2, 3]);
735        assert_eq!(result.size(), 6);
736
737        // Check that all values are 0.0 (0.0 / 1.0)
738        unsafe {
739            for i in 0..result.size() {
740                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
741            }
742        }
743
744        // Test with negative values
745        let mut neg_tensor = Tensor::ones(vec![2, 3]);
746        neg_tensor.fill(-4.0);
747        let result = neg_tensor.div_scalar_optimized(2.0);
748
749        assert_eq!(result.shape().dims, vec![2, 3]);
750        assert_eq!(result.size(), 6);
751
752        // Check that all values are -2.0 (-4.0 / 2.0)
753        unsafe {
754            for i in 0..result.size() {
755                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
756            }
757        }
758    }
759
760    #[test]
761    fn test_large_tensor_division() {
762        let a = Tensor::ones(vec![100, 100]);
763        let mut b = Tensor::ones(vec![100, 100]);
764        b.fill(1.5);
765        let result = a.div_tensor_optimized(&b);
766
767        assert_eq!(result.shape().dims, vec![100, 100]);
768        assert_eq!(result.size(), 10000);
769
770        // Check that all values are 0.666... (1.0 / 1.5)
771        unsafe {
772            for i in 0..result.size() {
773                assert!((result.as_ptr().add(i).read() - (2.0 / 3.0)).abs() < 1e-6);
774            }
775        }
776    }
777
    /// End-to-end gradtrack checks for both division forms.
    ///
    /// Scalar case: y = x / 2 with x = 1 gives y = 0.5 and dy/dx = 0.5.
    /// Tensor case: y = a / b with a = 1, b = 2 gives y = 0.5,
    /// dy/da = 1/b = 0.5 and dy/db = -a/b^2 = -0.25.
    #[test]
    fn test_division_with_gradtrack() {
        // Test scalar division with gradtrack
        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut result = a.div_scalar(2.0);

        // Check result values: 1.0 / 2.0 = 0.5
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
            }
        }

        result.backward(None);

        // Check gradient: d/dx(x/2) = 1/2
        if let Some(grad) = a.grad_by_value() {
            unsafe {
                for i in 0..grad.size() {
                    let val = grad.as_ptr().add(i).read();
                    assert!(
                        (val - 0.5).abs() < 1e-6,
                        "Expected gradient 0.5, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient computed for scalar division!");
        }

        // Test tensor division with gradtrack
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(2.0);
        let b = b.with_requires_grad();

        let mut result = a.div_tensor(&b);

        // Check result values: 1.0 / 2.0 = 0.5
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
            }
        }

        result.backward(None);

        // Check gradients: ∂(a/b)/∂a = 1/b, ∂(a/b)/∂b = -a/b²
        // For a = 1.0, b = 2.0: ∂(a/b)/∂a = 0.5, ∂(a/b)/∂b = -0.25
        if let Some(grad_a) = a.grad_by_value() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 0.5).abs() < 1e-6,
                        "Expected gradient A = 0.5 (∂(a/b)/∂a = 1/b), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for tensor division!");
        }

        if let Some(grad_b) = b.grad_by_value() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - (-0.25)).abs() < 1e-6,
                        "Expected gradient B = -0.25 (∂(a/b)/∂b = -a/b²), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for tensor division!");
        }
    }
860
861    #[test]
862    fn test_mixed_div_mul_operations_with_gradtrack() {
863        // Test complex computation: (a / 2) * (b / 3) + 1
864        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
865        let mut b = Tensor::ones(vec![2, 2]);
866        b.fill(6.0);
867        let b = b.with_requires_grad();
868
869        let scalar1 = 2.0;
870        let scalar2 = 3.0;
871
872        // Compute: (a / scalar1) * (b / scalar2) + 1
873        let div_a = a.div_scalar(scalar1); // a / 2
874        let div_b = b.div_scalar(scalar2); // b / 3
875        let mul_result = div_a.mul_tensor(&div_b); // (a / 2) * (b / 3)
876        let mut final_result = mul_result.add_scalar(1.0); // (a / 2) * (b / 3) + 1
877
878        // Check result values: (1/2) * (6/3) + 1 = 0.5 * 2 + 1 = 2
879        unsafe {
880            for i in 0..final_result.size() {
881                let val = final_result.as_ptr().add(i).read();
882                assert!((val - 2.0).abs() < 1e-6, "Expected 2.0, got {}", val);
883            }
884        }
885
886        final_result.backward(None);
887
888        // Check gradients: d/dx((x/2) * (y/3) + 1) = (y/3) * (1/2) = y/6
889        // d/dy((x/2) * (y/3) + 1) = (x/2) * (1/3) = x/6
890        // For x = 1.0, y = 6.0: d/dx = 6/6 = 1.0, d/dy = 1/6 = 0.166...
891        if let Some(grad_a) = a.grad_by_value() {
892            unsafe {
893                for i in 0..grad_a.size() {
894                    let val = grad_a.as_ptr().add(i).read();
895                    assert!(
896                        (val - 1.0).abs() < 1e-6,
897                        "Expected gradient A = 1.0, got {}",
898                        val
899                    );
900                }
901            }
902        } else {
903            panic!("No gradient A computed for mixed operations!");
904        }
905
906        if let Some(grad_b) = b.grad_by_value() {
907            unsafe {
908                for i in 0..grad_b.size() {
909                    let val = grad_b.as_ptr().add(i).read();
910                    assert!(
911                        (val - (1.0 / 6.0)).abs() < 1e-6,
912                        "Expected gradient B = 1/6, got {}",
913                        val
914                    );
915                }
916            }
917        } else {
918            panic!("No gradient B computed for mixed operations!");
919        }
920    }
921
922    #[test]
923    fn test_div_broadcasting_gradients_basic() {
924        use crate::gradtrack::clear_gradients;
925        clear_gradients();
926
927        // Test case: [2, 3] / [1, 3] -> [2, 3]
928        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2
929
930        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
931            .unwrap()
932            .with_requires_grad();
933        let b = Tensor::from_slice(&[2.0, 2.0, 2.0], vec![1, 3])
934            .unwrap()
935            .with_requires_grad();
936
937        let mut result = a.div_tensor(&b);
938        assert_eq!(result.shape().dims, vec![2, 3]);
939
940        // Set upstream gradient as ones
941        result.backward(None);
942
943        let grad_a = a.grad_by_value().expect("grad_a should exist");
944        let grad_b = b.grad_by_value().expect("grad_b should exist");
945
946        println!(
947            "Original shapes: a={:?}, b={:?}",
948            a.shape().dims,
949            b.shape().dims
950        );
951        println!(
952            "Gradient shapes: grad_a={:?}, grad_b={:?}",
953            grad_a.shape().dims,
954            grad_b.shape().dims
955        );
956
957        // grad_a should have same shape as a: [2, 3]
958        assert_eq!(
959            grad_a.shape().dims,
960            vec![2, 3],
961            "grad_a should match original shape of a"
962        );
963
964        // grad_b should have same shape as b: [1, 3]
965        assert_eq!(
966            grad_b.shape().dims,
967            vec![1, 3],
968            "grad_b should match original shape of b"
969        );
970
971        // grad_a should be [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] (1/b = 1/2)
972        for i in 0..grad_a.size() {
973            let val = unsafe { *grad_a.as_ptr().add(i) };
974            assert!(
975                (val - 0.5).abs() < 1e-6,
976                "grad_a[{}] = {} should be 0.5",
977                i,
978                val
979            );
980        }
981
982        // For grad_b: d/db (a / b) = -a/b^2, summed over broadcast dimension
983        // a = [2,4,6,8,10,12], b = [2,2,2], so -a/b^2 = [-2/4, -4/4, -6/4, -8/4, -10/4, -12/4] = [-0.5, -1, -1.5, -2, -2.5, -3]
984        // Summed over first dimension: [-0.5-2, -1-2.5, -1.5-3] = [-2.5, -3.5, -4.5]
985        let expected_grad_b = [-2.5, -3.5, -4.5];
986        for (i, &expected) in expected_grad_b.iter().enumerate() {
987            let val = unsafe { *grad_b.as_ptr().add(i) };
988            assert!(
989                (val - expected).abs() < 1e-6,
990                "grad_b[{}] = {} should be {}",
991                i,
992                val,
993                expected
994            );
995        }
996    }
997
998    #[test]
999    fn test_div_scalar_broadcasting_gradients() {
1000        use crate::gradtrack::clear_gradients;
1001        clear_gradients();
1002
1003        // Test case: [2, 3] / [1] -> [2, 3]
1004        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2
1005
1006        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
1007            .unwrap()
1008            .with_requires_grad();
1009        let b = Tensor::from_slice(&[2.0], vec![1])
1010            .unwrap()
1011            .with_requires_grad();
1012
1013        let mut result = a.div_tensor(&b);
1014        result.backward(None);
1015
1016        let grad_a = a.grad_by_value().expect("grad_a should exist");
1017        let grad_b = b.grad_by_value().expect("grad_b should exist");
1018
1019        // grad_a should have same shape as a: [2, 3]
1020        assert_eq!(grad_a.shape().dims, vec![2, 3]);
1021
1022        // grad_b should have same shape as b: [1]
1023        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims);
1024        assert_eq!(grad_b.shape().dims, vec![1]);
1025
1026        // grad_b should be the sum of -a/b^2 = -(2+4+6+8+10+12)/4 = -42/4 = -10.5
1027        let val = unsafe { *grad_b.as_ptr() };
1028        assert!(
1029            (val - (-10.5)).abs() < 1e-6,
1030            "grad_b = {} should be -10.5",
1031            val
1032        );
1033    }
1034}