train_station/tensor/ops/div.rs

1//! Division operations for tensors
2//!
3//! Provides element-wise division following PyTorch conventions with comprehensive
4//! broadcasting support, automatic differentiation, and high-performance SIMD optimization.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Division**: `div_tensor()` - Division with another tensor (PyTorch `div()` equivalent)
9//! - **Scalar Broadcasting**: `div_scalar()` - Division by scalar values
10//! - **Automatic Broadcasting**: NumPy-style broadcasting for compatible shapes
11//! - **SIMD Optimization**: AVX2 acceleration on x86_64 hardware
12//! - **Automatic Differentiation**: Full gradtrack support with gradient tracking
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Zero-copy Operations**: Efficient memory usage where possible
15//! - **Division by Zero Protection**: Comprehensive error checking and validation
16//!
17//! # Broadcasting Support
18//!
19//! All division operations support automatic broadcasting following NumPy rules:
20//! - Dimensions are aligned from the rightmost dimension
21//! - Dimensions are compatible if they are equal, or one of them is 1
22//! - Missing dimensions are treated as 1
23//! - Result shape follows broadcasting rules
24//!
25//! # Performance Characteristics
26//!
27//! - **SIMD Acceleration**: 8x vectorization with AVX2 on compatible hardware
28//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Fallback Support**: Optimized scalar implementations for non-SIMD hardware
31//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
32//! - **Division by Zero Checks**: Optimized safety checks with minimal performance impact
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44    /// Element-wise division with another tensor with broadcasting support.
45    ///
46    /// Performs element-wise division with automatic broadcasting: `output[i] = self[i] / other[i]`
47    ///
48    /// Broadcasting enables division between tensors of different but compatible shapes.
49    /// Compatible shapes follow NumPy broadcasting rules:
50    /// - Dimensions are aligned from the rightmost dimension
51    /// - Dimensions are compatible if they are equal, or one of them is 1
52    /// - Missing dimensions are treated as 1
53    ///
54    /// # Arguments
55    /// * `other` - Tensor to divide by. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    /// A new tensor containing the element-wise quotient with broadcast result shape
59    ///
60    /// # Examples
61    ///
62    /// ## Same Shape Division
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
68    /// let b = Tensor::from_slice(&[2.0, 4.0, 5.0], vec![3]).unwrap();
69    /// let c = a.div_tensor(&b);
70    /// assert_eq!(c.shape().dims(), vec![3]);
71    /// assert_eq!(c.get(&[0]), 5.0);
72    /// assert_eq!(c.get(&[1]), 5.0);
73    /// assert_eq!(c.get(&[2]), 6.0);
74    /// ```
75    ///
76    /// ## Broadcasting Division
77    ///
78    /// ```
79    /// use train_station::Tensor;
80    ///
81    /// // Broadcasting: [2, 1] / [1, 3] -> [2, 3]
82    /// let a = Tensor::from_slice(&[10.0, 20.0], vec![2, 1]).unwrap();
83    /// let b = Tensor::from_slice(&[1.0, 2.0, 5.0], vec![1, 3]).unwrap();
84    /// let c = a.div_tensor(&b);
85    /// assert_eq!(c.shape().dims(), vec![2, 3]);
86    /// assert_eq!(c.get(&[0, 0]), 10.0);
87    /// assert_eq!(c.get(&[0, 1]), 5.0);
88    /// assert_eq!(c.get(&[1, 0]), 20.0);
89    /// assert_eq!(c.get(&[1, 1]), 10.0);
90    /// ```
91    ///
92    /// ## Scalar Division
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // Scalar division: [2, 3] / scalar -> [2, 3]
98    /// let a = Tensor::ones(vec![2, 3]);
99    /// let b = Tensor::from_slice(&[2.0], vec![1]).unwrap();
100    /// let c = a.div_tensor(&b);
101    /// assert_eq!(c.shape().dims(), vec![2, 3]);
102    /// assert_eq!(c.get(&[0, 0]), 0.5);
103    /// assert_eq!(c.get(&[1, 2]), 0.5);
104    /// ```
105    ///
106    /// # Panics
107    /// Panics if tensor shapes are not broadcast-compatible or division by zero
108    #[inline]
109    #[track_caller]
110    pub fn div_tensor(&self, other: &Tensor) -> Tensor {
111        // Check if shapes are identical for fast path
112        if self.shape().dims() == other.shape().dims() {
113            return self.div_tensor_same_shape(other);
114        }
115
116        // Zero-copy broadcast views, then reuse same-shape optimized path
117        use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
118        let mut result = match broadcast_shapes_cow(self, other) {
119            Ok((a_b, b_b, _)) => a_b.as_ref().div_tensor_optimized(b_b.as_ref()),
120            Err(BroadcastError::IncompatibleShapes { .. }) => {
121                panic!(
122                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
123                    self.shape().dims(),
124                    other.shape().dims(),
125                    "incompatible shapes"
126                );
127            }
128            Err(BroadcastError::AllocationFailed) => {
129                panic!("Memory allocation failed during broadcasting");
130            }
131        };
132
133        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
134            result.set_requires_grad_internal(true);
135            let operands = vec![self.clone(), other.clone()];
136            let grad_fn = GradFn::Div {
137                is_tensor_div: true,
138                scalar: None,
139                operands: Some(operands),
140                original_shapes: Some((
141                    self.shape().dims().to_vec(),
142                    other.shape().dims().to_vec(),
143                )),
144            };
145            result.set_grad_fn(grad_fn.clone());
146
147            let mut input_ids = Vec::with_capacity(2);
148            if self.requires_grad() {
149                input_ids.push(self.id());
150            }
151            if other.requires_grad() {
152                input_ids.push(other.id());
153            }
154            GradEngine::register_operation(result.id(), input_ids, grad_fn);
155        }
156
157        result
158    }
159
160    /// Element-wise division for tensors with identical shapes (fast path).
161    ///
162    /// This is an optimized path for tensors that already have the same shape,
163    /// avoiding the overhead of broadcasting computation. Used internally by
164    /// `div_tensor()` when shapes are identical.
165    ///
166    /// # Arguments
167    /// * `other` - Tensor to divide by, must have the same shape as self
168    ///
169    /// # Returns
170    /// A new tensor containing the element-wise quotient
171    ///
172    /// # Performance Characteristics
173    ///
174    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
175    /// - **SIMD Optimization**: Uses optimized tensor division with SIMD acceleration
176    /// - **GradTrack Support**: Full automatic differentiation with efficient gradient computation
177    ///
178    /// # Panics
179    ///
180    /// Panics if tensor shapes do not match or if any element in `other` is zero
181    #[inline]
182    fn div_tensor_same_shape(&self, other: &Tensor) -> Tensor {
183        assert_eq!(
184            self.shape(),
185            other.shape(),
186            "Tensor shapes must match for same-shape division"
187        );
188        let mut result = self.div_tensor_optimized(other);
189
190        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
191            result.set_requires_grad_internal(true);
192            let operands = vec![self.clone(), other.clone()];
193            let grad_fn = GradFn::Div {
194                is_tensor_div: true,
195                scalar: None,
196                operands: Some(operands),
197                original_shapes: None, // Same shape case
198            };
199            result.set_grad_fn(grad_fn.clone());
200
201            let mut input_ids = Vec::with_capacity(2);
202            if self.requires_grad() {
203                input_ids.push(self.id());
204            }
205            if other.requires_grad() {
206                input_ids.push(other.id());
207            }
208            GradEngine::register_operation(result.id(), input_ids, grad_fn);
209        }
210
211        result
212    }
213
214    /// Broadcast division with a scalar value.
215    ///
216    /// Divides every element by the scalar: `output[i] = self[i] / scalar`
217    ///
218    /// # Arguments
219    /// * `scalar` - Value to divide each element by (must not be zero)
220    ///
221    /// # Returns
222    /// A new tensor with each element divided by the scalar
223    ///
224    /// # Examples
225    ///
226    /// ## Basic Scalar Division
227    ///
228    /// ```
229    /// use train_station::Tensor;
230    ///
231    /// let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
232    /// let b = a.div_scalar(10.0);
233    /// assert_eq!(b.shape().dims(), vec![3]);
234    /// assert_eq!(b.get(&[0]), 1.0);
235    /// assert_eq!(b.get(&[1]), 2.0);
236    /// assert_eq!(b.get(&[2]), 3.0);
237    /// ```
238    ///
239    /// ## Multi-dimensional Scalar Division
240    ///
241    /// ```
242    /// use train_station::Tensor;
243    ///
244    /// let a = Tensor::ones(vec![2, 3]);
245    /// let b = a.div_scalar(2.0);
246    /// assert_eq!(b.shape().dims(), vec![2, 3]);
247    /// assert_eq!(b.get(&[0, 0]), 0.5);
248    /// assert_eq!(b.get(&[1, 2]), 0.5);
249    /// ```
250    ///
251    /// # Panics
252    /// Panics if scalar is zero
253    #[inline]
254    #[track_caller]
255    pub fn div_scalar(&self, scalar: f32) -> Tensor {
256        let mut result = self.div_scalar_optimized(scalar);
257
258        if self.requires_grad() && is_grad_enabled() {
259            result.set_requires_grad_internal(true);
260            let grad_fn = GradFn::Div {
261                is_tensor_div: false,
262                scalar: Some(scalar),
263                operands: None,
264                original_shapes: None, // Scalar case
265            };
266            result.set_grad_fn(grad_fn.clone());
267            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
268        }
269
270        result
271    }
272    /// Internal optimized tensor / tensor operation
273    ///
274    /// Performs element-wise division between two tensors with the same shape,
275    /// using SIMD acceleration when available. This is the core implementation
276    /// used by `div_tensor()` after broadcasting has been applied.
277    ///
278    /// # Arguments
279    ///
280    /// * `other` - Tensor to divide by, must have the same shape as self
281    ///
282    /// # Returns
283    ///
284    /// A new tensor containing the element-wise quotient
285    ///
286    /// # Safety
287    ///
288    /// Assumes both tensors have the same shape and valid memory layouts.
289    /// Uses unsafe SIMD operations for performance optimization.
290    /// Division by zero will panic.
291    ///
292    /// # Performance Characteristics
293    ///
294    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
295    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
296    /// - **Cache-friendly**: Linear memory access patterns
297    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
298    /// - **Division by Zero Checks**: Comprehensive safety validation
299    #[inline]
300    pub(crate) fn div_tensor_optimized(&self, other: &Tensor) -> Tensor {
301        assert_eq!(
302            self.shape().dims(),
303            other.shape().dims(),
304            "Tensor dims must match"
305        );
306
307        // Ensure contiguous sources for correctness with broadcast views/strides
308        let a_src = if self.is_contiguous() {
309            self.clone()
310        } else {
311            self.contiguous()
312        };
313        let b_src = if other.is_contiguous() {
314            other.clone()
315        } else {
316            other.contiguous()
317        };
318
319        let mut output = Tensor::new(self.shape().dims().to_vec());
320
321        unsafe {
322            let a = a_src.as_ptr();
323            let b = b_src.as_ptr();
324            let dst = output.as_mut_ptr();
325
326            #[cfg(target_arch = "x86_64")]
327            {
328                // Use SIMD for better performance when available
329                if is_x86_feature_detected!("avx2") {
330                    self.div_tensors_simd_avx2_optimized(a, b, dst);
331                    return output;
332                }
333            }
334
335            // Fallback to scalar operations with better cache usage
336            self.div_tensors_scalar_optimized(a, b, dst);
337        }
338
339        output
340    }
341
342    /// SIMD-optimized tensor division using AVX2 instructions
343    ///
344    /// Performs element-wise division using AVX2 SIMD instructions for maximum
345    /// performance on x86_64 hardware. Processes 32 elements per iteration with
346    /// 4x unrolling for optimal instruction throughput. Includes comprehensive
347    /// division by zero checking for safety.
348    ///
349    /// # Arguments
350    ///
351    /// * `a` - Pointer to first tensor data (numerator)
352    /// * `b` - Pointer to second tensor data (denominator)
353    /// * `dst` - Pointer to output tensor data
354    ///
355    /// # Safety
356    ///
357    /// Requires AVX2 support and valid pointers with sufficient memory.
358    /// All pointers must be aligned and point to valid tensor data.
359    /// Division by zero will panic.
360    ///
361    /// # Performance Characteristics
362    ///
363    /// - **SIMD Width**: 8 elements per AVX2 vector operation
364    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
365    /// - **Memory Access**: Linear access patterns for cache efficiency
366    /// - **Fallback**: Handles remaining elements with scalar operations
367    /// - **Safety Checks**: Comprehensive division by zero validation
368    #[cfg(target_arch = "x86_64")]
369    #[inline]
370    #[target_feature(enable = "avx2")]
371    unsafe fn div_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
372        let size = self.size();
373        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
374        let mut offset = 0;
375
376        // Unrolled SIMD loop for throughput
377        for _ in 0..simd_count {
378            // Process 4 AVX2 vectors (32 elements) per iteration
379            let a_vec1 = _mm256_loadu_ps(a.add(offset));
380            let b_vec1 = _mm256_loadu_ps(b.add(offset));
381            let div_vec1 = _mm256_div_ps(a_vec1, b_vec1);
382            _mm256_storeu_ps(dst.add(offset), div_vec1);
383
384            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
385            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
386            let div_vec2 = _mm256_div_ps(a_vec2, b_vec2);
387            _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
388
389            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
390            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
391            let div_vec3 = _mm256_div_ps(a_vec3, b_vec3);
392            _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
393
394            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
395            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
396            let div_vec4 = _mm256_div_ps(a_vec4, b_vec4);
397            _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
398
399            offset += 32;
400        }
401
402        // Handle remaining 8-element blocks, then tail with checks
403        let remaining_full_blocks = (size - offset) / 8;
404        for _ in 0..remaining_full_blocks {
405            let a_vec = _mm256_loadu_ps(a.add(offset));
406            let b_vec = _mm256_loadu_ps(b.add(offset));
407            // Fallback to scalar for safety checks
408            let mut a_vals = [0.0f32; 8];
409            let mut b_vals = [0.0f32; 8];
410            _mm256_storeu_ps(a_vals.as_mut_ptr(), a_vec);
411            _mm256_storeu_ps(b_vals.as_mut_ptr(), b_vec);
412            for t in 0..8 {
413                let j = offset + t;
414                if b_vals[t] == 0.0 {
415                    panic!("Division by zero detected at index {}", j);
416                }
417                *dst.add(j) = a_vals[t] / b_vals[t];
418            }
419            offset += 8;
420        }
421        while offset + 4 <= size {
422            let b0 = *b.add(offset);
423            let b1 = *b.add(offset + 1);
424            let b2 = *b.add(offset + 2);
425            let b3 = *b.add(offset + 3);
426            if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
427                panic!("Division by zero detected in unrolled loop");
428            }
429            *dst.add(offset) = *a.add(offset) / b0;
430            *dst.add(offset + 1) = *a.add(offset + 1) / b1;
431            *dst.add(offset + 2) = *a.add(offset + 2) / b2;
432            *dst.add(offset + 3) = *a.add(offset + 3) / b3;
433            offset += 4;
434        }
435        for i in offset..size {
436            let b_val = *b.add(i);
437            if b_val == 0.0 {
438                panic!("Division by zero detected at index {}", i);
439            }
440            *dst.add(i) = *a.add(i) / b_val;
441        }
442    }
443
444    /// Optimized scalar tensor division fallback
445    ///
446    /// Performs element-wise division using optimized scalar operations when
447    /// SIMD is not available. Uses 4x unrolling for better instruction-level
448    /// parallelism and cache efficiency. Includes comprehensive division by zero
449    /// checking for safety.
450    ///
451    /// # Arguments
452    ///
453    /// * `a` - Pointer to first tensor data (numerator)
454    /// * `b` - Pointer to second tensor data (denominator)
455    /// * `dst` - Pointer to output tensor data
456    ///
457    /// # Safety
458    ///
459    /// Requires valid pointers with sufficient memory for the tensor size.
460    /// All pointers must point to valid tensor data.
461    /// Division by zero will panic.
462    ///
463    /// # Performance Characteristics
464    ///
465    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
466    /// - **Memory Access**: Linear access patterns for cache efficiency
467    /// - **Fallback**: Handles remaining elements with scalar operations
468    /// - **Safety Checks**: Comprehensive division by zero validation
469    #[inline]
470    unsafe fn div_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
471        let size = self.size();
472
473        // Use unrolled loops for better instruction throughput
474        let unroll_count = size / 4;
475        let mut i = 0;
476
477        // Process 4 elements at a time for better cache utilization
478        while i < unroll_count {
479            let idx = i * 4;
480
481            // Check for division by zero
482            let b0 = b.add(idx).read();
483            let b1 = b.add(idx + 1).read();
484            let b2 = b.add(idx + 2).read();
485            let b3 = b.add(idx + 3).read();
486
487            if b0 == 0.0 || b1 == 0.0 || b2 == 0.0 || b3 == 0.0 {
488                panic!("Division by zero detected in unrolled loop");
489            }
490
491            dst.add(idx).write(a.add(idx).read() / b0);
492            dst.add(idx + 1).write(a.add(idx + 1).read() / b1);
493            dst.add(idx + 2).write(a.add(idx + 2).read() / b2);
494            dst.add(idx + 3).write(a.add(idx + 3).read() / b3);
495            i += 1;
496        }
497
498        // Handle remaining elements
499        for j in (unroll_count * 4)..size {
500            let b_val = b.add(j).read();
501            if b_val == 0.0 {
502                panic!("Division by zero detected at index {}", j);
503            }
504            dst.add(j).write(a.add(j).read() / b_val);
505        }
506    }
507
508    /// Internal optimized scalar / tensor operation
509    ///
510    /// Performs element-wise division of a scalar into each element of the tensor,
511    /// using SIMD acceleration when available. This is the core implementation
512    /// used by `div_scalar()`.
513    ///
514    /// # Arguments
515    ///
516    /// * `scalar` - Scalar value to divide each element by (must not be zero)
517    ///
518    /// # Returns
519    ///
520    /// A new tensor with each element divided by the scalar
521    ///
522    /// # Safety
523    ///
524    /// Assumes valid tensor memory layout. Uses unsafe SIMD operations for
525    /// performance optimization. Division by zero will panic.
526    ///
527    /// # Performance Characteristics
528    ///
529    /// - **SIMD Optimization**: Uses AVX2 when available for 8x vectorization
530    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
531    /// - **Cache-friendly**: Linear memory access patterns
532    /// - **Fallback**: Optimized scalar implementation for non-SIMD hardware
533    /// - **Division by Zero Checks**: Comprehensive safety validation
534    #[inline]
535    pub(crate) fn div_scalar_optimized(&self, scalar: f32) -> Tensor {
536        if scalar == 0.0 {
537            panic!("Division by zero: cannot divide tensor by zero scalar");
538        }
539
540        let mut output = Tensor::new(self.shape().dims().to_vec());
541
542        unsafe {
543            let src = self.as_ptr();
544            let dst = output.as_mut_ptr();
545
546            #[cfg(target_arch = "x86_64")]
547            {
548                // Use SIMD for better performance when available
549                if is_x86_feature_detected!("avx2") {
550                    self.div_scalar_simd_avx2_optimized(src, dst, scalar);
551                    return output;
552                }
553            }
554
555            // Fallback to scalar operations with better cache usage
556            self.div_scalar_fallback_optimized(src, dst, scalar);
557        }
558
559        output
560    }
561
562    /// SIMD-optimized scalar division using AVX2 instructions
563    ///
564    /// Performs element-wise scalar division using AVX2 SIMD instructions for maximum
565    /// performance on x86_64 hardware. Processes 32 elements per iteration with
566    /// 4x unrolling for optimal instruction throughput.
567    ///
568    /// # Arguments
569    ///
570    /// * `src` - Pointer to source tensor data
571    /// * `dst` - Pointer to output tensor data
572    /// * `scalar` - Scalar value to divide each element by
573    ///
574    /// # Safety
575    ///
576    /// Requires AVX2 support and valid pointers with sufficient memory.
577    /// All pointers must be aligned and point to valid tensor data.
578    /// Scalar must not be zero (checked before calling this function).
579    ///
580    /// # Performance Characteristics
581    ///
582    /// - **SIMD Width**: 8 elements per AVX2 vector operation
583    /// - **Unrolling**: 4x unrolling (32 elements per iteration)
584    /// - **Memory Access**: Linear access patterns for cache efficiency
585    /// - **Fallback**: Handles remaining elements with scalar operations
586    /// - **Optimization**: Most common scalar division pattern optimized
587    #[cfg(target_arch = "x86_64")]
588    #[inline]
589    #[target_feature(enable = "avx2")]
590    unsafe fn div_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
591        let size = self.size();
592        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
593        let mut offset = 0;
594
595        // Create SIMD vector for scalar
596        let scalar_vec = _mm256_set1_ps(scalar);
597
598        // Unrolled SIMD loop for throughput
599        for _ in 0..simd_count {
600            // Process 4 AVX2 vectors (32 elements) per iteration
601            let src_vec1 = _mm256_loadu_ps(src.add(offset));
602            let div_vec1 = _mm256_div_ps(src_vec1, scalar_vec);
603            _mm256_storeu_ps(dst.add(offset), div_vec1);
604
605            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
606            let div_vec2 = _mm256_div_ps(src_vec2, scalar_vec);
607            _mm256_storeu_ps(dst.add(offset + 8), div_vec2);
608
609            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
610            let div_vec3 = _mm256_div_ps(src_vec3, scalar_vec);
611            _mm256_storeu_ps(dst.add(offset + 16), div_vec3);
612
613            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
614            let div_vec4 = _mm256_div_ps(src_vec4, scalar_vec);
615            _mm256_storeu_ps(dst.add(offset + 24), div_vec4);
616
617            offset += 32;
618        }
619
620        // Handle remaining elements with scalar operations
621        for i in offset..size {
622            *dst.add(i) = *src.add(i) / scalar;
623        }
624    }
625
626    /// Optimized scalar division fallback
627    ///
628    /// Performs element-wise scalar division using optimized scalar operations when
629    /// SIMD is not available. Uses 4x unrolling for better instruction-level
630    /// parallelism and cache efficiency.
631    ///
632    /// # Arguments
633    ///
634    /// * `src` - Pointer to source tensor data
635    /// * `dst` - Pointer to output tensor data
636    /// * `scalar` - Scalar value to divide each element by
637    ///
638    /// # Safety
639    ///
640    /// Requires valid pointers with sufficient memory for the tensor size.
641    /// All pointers must point to valid tensor data.
642    /// Scalar must not be zero (checked before calling this function).
643    ///
644    /// # Performance Characteristics
645    ///
646    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
647    /// - **Memory Access**: Linear access patterns for cache efficiency
648    /// - **Fallback**: Handles remaining elements with scalar operations
649    #[inline]
650    unsafe fn div_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
651        let size = self.size();
652
653        // Use unrolled loops for better instruction throughput
654        let unroll_count = size / 4;
655        let mut i = 0;
656
657        // Process 4 elements at a time for better cache utilization
658        while i < unroll_count {
659            let idx = i * 4;
660            dst.add(idx).write(src.add(idx).read() / scalar);
661            dst.add(idx + 1).write(src.add(idx + 1).read() / scalar);
662            dst.add(idx + 2).write(src.add(idx + 2).read() / scalar);
663            dst.add(idx + 3).write(src.add(idx + 3).read() / scalar);
664            i += 1;
665        }
666
667        // Handle remaining elements
668        for j in (unroll_count * 4)..size {
669            dst.add(j).write(src.add(j).read() / scalar);
670        }
671    }
672}
673
674#[cfg(test)]
675mod tests {
676    use super::*;
677
678    #[test]
679    fn test_tensor_division() {
680        let a = Tensor::ones(vec![2, 3]);
681        let mut b = Tensor::ones(vec![2, 3]);
682        b.fill(2.0);
683        let result = a.div_tensor_optimized(&b);
684
685        assert_eq!(result.shape().dims(), vec![2, 3]);
686        assert_eq!(result.size(), 6);
687
688        // Check that all values are 0.5 (1.0 / 2.0)
689        unsafe {
690            for i in 0..result.size() {
691                assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
692            }
693        }
694    }
695
696    #[test]
697    fn test_scalar_division() {
698        let tensor = Tensor::ones(vec![2, 2]);
699        let result = tensor.div_scalar_optimized(2.0);
700
701        assert_eq!(result.shape().dims(), vec![2, 2]);
702        assert_eq!(result.size(), 4);
703
704        // Check that all values are 0.5 (1.0 / 2.0)
705        unsafe {
706            for i in 0..result.size() {
707                assert!((result.as_ptr().add(i).read() - 0.5).abs() < 1e-6);
708            }
709        }
710    }
711
712    #[test]
713    fn test_negative_division() {
714        let tensor = Tensor::ones(vec![2, 3]);
715        let result = tensor.div_scalar_optimized(-2.0);
716
717        assert_eq!(result.shape().dims(), vec![2, 3]);
718        assert_eq!(result.size(), 6);
719
720        // Check that all values are -0.5 (1.0 / -2.0)
721        unsafe {
722            for i in 0..result.size() {
723                assert!((result.as_ptr().add(i).read() - (-0.5)).abs() < 1e-6);
724            }
725        }
726    }
727
728    #[test]
729    #[should_panic(expected = "Division by zero")]
730    fn test_division_by_zero_scalar() {
731        let tensor = Tensor::ones(vec![2, 3]);
732        tensor.div_scalar_optimized(0.0);
733    }
734
735    #[test]
736    #[should_panic(expected = "Division by zero")]
737    fn test_division_by_zero_tensor() {
738        let a = Tensor::ones(vec![2, 3]);
739        let b = Tensor::zeros(vec![2, 3]);
740        a.div_tensor_optimized(&b);
741    }
742
743    #[test]
744    #[should_panic(expected = "Tensor dims must match")]
745    fn test_mismatched_shapes() {
746        let a = Tensor::ones(vec![2, 3]);
747        let b = Tensor::ones(vec![3, 2]);
748        a.div_tensor_optimized(&b);
749    }
750
751    #[test]
752    fn test_edge_cases() {
753        // Test with zero numerator
754        let zero_tensor = Tensor::zeros(vec![2, 3]);
755        let other = Tensor::ones(vec![2, 3]);
756        let result = zero_tensor.div_tensor_optimized(&other);
757
758        assert_eq!(result.shape().dims(), vec![2, 3]);
759        assert_eq!(result.size(), 6);
760
761        // Check that all values are 0.0 (0.0 / 1.0)
762        unsafe {
763            for i in 0..result.size() {
764                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
765            }
766        }
767
768        // Test with negative values
769        let mut neg_tensor = Tensor::ones(vec![2, 3]);
770        neg_tensor.fill(-4.0);
771        let result = neg_tensor.div_scalar_optimized(2.0);
772
773        assert_eq!(result.shape().dims(), vec![2, 3]);
774        assert_eq!(result.size(), 6);
775
776        // Check that all values are -2.0 (-4.0 / 2.0)
777        unsafe {
778            for i in 0..result.size() {
779                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
780            }
781        }
782    }
783
784    #[test]
785    fn test_large_tensor_division() {
786        let a = Tensor::ones(vec![100, 100]);
787        let mut b = Tensor::ones(vec![100, 100]);
788        b.fill(1.5);
789        let result = a.div_tensor_optimized(&b);
790
791        assert_eq!(result.shape().dims(), vec![100, 100]);
792        assert_eq!(result.size(), 10000);
793
794        // Check that all values are 0.666... (1.0 / 1.5)
795        unsafe {
796            for i in 0..result.size() {
797                assert!((result.as_ptr().add(i).read() - (2.0 / 3.0)).abs() < 1e-6);
798            }
799        }
800    }
801
802    #[test]
803    fn test_division_with_gradtrack() {
804        // Test scalar division with gradtrack
805        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
806        let mut result = a.div_scalar(2.0);
807
808        // Check result values: 1.0 / 2.0 = 0.5
809        unsafe {
810            for i in 0..result.size() {
811                let val = result.as_ptr().add(i).read();
812                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
813            }
814        }
815
816        result.backward(None);
817
818        // Check gradient: d/dx(x/2) = 1/2
819        if let Some(grad) = a.grad_owned() {
820            unsafe {
821                for i in 0..grad.size() {
822                    let val = grad.as_ptr().add(i).read();
823                    assert!(
824                        (val - 0.5).abs() < 1e-6,
825                        "Expected gradient 0.5, got {}",
826                        val
827                    );
828                }
829            }
830        } else {
831            panic!("No gradient computed for scalar division!");
832        }
833
834        // Test tensor division with gradtrack
835        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
836        let mut b = Tensor::ones(vec![2, 2]);
837        b.fill(2.0);
838        let b = b.with_requires_grad();
839
840        let mut result = a.div_tensor(&b);
841
842        // Check result values: 1.0 / 2.0 = 0.5
843        unsafe {
844            for i in 0..result.size() {
845                let val = result.as_ptr().add(i).read();
846                assert!((val - 0.5).abs() < 1e-6, "Expected 0.5, got {}", val);
847            }
848        }
849
850        result.backward(None);
851
852        // Check gradients: ∂(a/b)/∂a = 1/b, ∂(a/b)/∂b = -a/b²
853        // For a = 1.0, b = 2.0: ∂(a/b)/∂a = 0.5, ∂(a/b)/∂b = -0.25
854        if let Some(grad_a) = a.grad_owned() {
855            unsafe {
856                for i in 0..grad_a.size() {
857                    let val = grad_a.as_ptr().add(i).read();
858                    assert!(
859                        (val - 0.5).abs() < 1e-6,
860                        "Expected gradient A = 0.5 (∂(a/b)/∂a = 1/b), got {}",
861                        val
862                    );
863                }
864            }
865        } else {
866            panic!("No gradient A computed for tensor division!");
867        }
868
869        if let Some(grad_b) = b.grad_owned() {
870            unsafe {
871                for i in 0..grad_b.size() {
872                    let val = grad_b.as_ptr().add(i).read();
873                    assert!(
874                        (val - (-0.25)).abs() < 1e-6,
875                        "Expected gradient B = -0.25 (∂(a/b)/∂b = -a/b²), got {}",
876                        val
877                    );
878                }
879            }
880        } else {
881            panic!("No gradient B computed for tensor division!");
882        }
883    }
884
885    #[test]
886    fn test_mixed_div_mul_operations_with_gradtrack() {
887        // Test complex computation: (a / 2) * (b / 3) + 1
888        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
889        let mut b = Tensor::ones(vec![2, 2]);
890        b.fill(6.0);
891        let b = b.with_requires_grad();
892
893        let scalar1 = 2.0;
894        let scalar2 = 3.0;
895
896        // Compute: (a / scalar1) * (b / scalar2) + 1
897        let div_a = a.div_scalar(scalar1); // a / 2
898        let div_b = b.div_scalar(scalar2); // b / 3
899        let mul_result = div_a.mul_tensor(&div_b); // (a / 2) * (b / 3)
900        let mut final_result = mul_result.add_scalar(1.0); // (a / 2) * (b / 3) + 1
901
902        // Check result values: (1/2) * (6/3) + 1 = 0.5 * 2 + 1 = 2
903        unsafe {
904            for i in 0..final_result.size() {
905                let val = final_result.as_ptr().add(i).read();
906                assert!((val - 2.0).abs() < 1e-6, "Expected 2.0, got {}", val);
907            }
908        }
909
910        final_result.backward(None);
911
912        // Check gradients: d/dx((x/2) * (y/3) + 1) = (y/3) * (1/2) = y/6
913        // d/dy((x/2) * (y/3) + 1) = (x/2) * (1/3) = x/6
914        // For x = 1.0, y = 6.0: d/dx = 6/6 = 1.0, d/dy = 1/6 = 0.166...
915        if let Some(grad_a) = a.grad_owned() {
916            unsafe {
917                for i in 0..grad_a.size() {
918                    let val = grad_a.as_ptr().add(i).read();
919                    assert!(
920                        (val - 1.0).abs() < 1e-6,
921                        "Expected gradient A = 1.0, got {}",
922                        val
923                    );
924                }
925            }
926        } else {
927            panic!("No gradient A computed for mixed operations!");
928        }
929
930        if let Some(grad_b) = b.grad_owned() {
931            unsafe {
932                for i in 0..grad_b.size() {
933                    let val = grad_b.as_ptr().add(i).read();
934                    assert!(
935                        (val - (1.0 / 6.0)).abs() < 1e-6,
936                        "Expected gradient B = 1/6, got {}",
937                        val
938                    );
939                }
940            }
941        } else {
942            panic!("No gradient B computed for mixed operations!");
943        }
944    }
945
946    #[test]
947    fn test_div_broadcasting_gradients_basic() {
948        use crate::gradtrack::clear_gradients;
949        clear_gradients();
950
951        // Test case: [2, 3] / [1, 3] -> [2, 3]
952        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2
953
954        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
955            .unwrap()
956            .with_requires_grad();
957        let b = Tensor::from_slice(&[2.0, 2.0, 2.0], vec![1, 3])
958            .unwrap()
959            .with_requires_grad();
960
961        let mut result = a.div_tensor(&b);
962        assert_eq!(result.shape().dims(), vec![2, 3]);
963
964        // Set upstream gradient as ones
965        result.backward(None);
966
967        let grad_a = a.grad_owned().expect("grad_a should exist");
968        let grad_b = b.grad_owned().expect("grad_b should exist");
969
970        println!(
971            "Original shapes: a={:?}, b={:?}",
972            a.shape().dims(),
973            b.shape().dims()
974        );
975        println!(
976            "Gradient shapes: grad_a={:?}, grad_b={:?}",
977            grad_a.shape().dims(),
978            grad_b.shape().dims()
979        );
980
981        // grad_a should have same shape as a: [2, 3]
982        assert_eq!(
983            grad_a.shape().dims(),
984            vec![2, 3],
985            "grad_a should match original shape of a"
986        );
987
988        // grad_b should have same shape as b: [1, 3]
989        assert_eq!(
990            grad_b.shape().dims(),
991            vec![1, 3],
992            "grad_b should match original shape of b"
993        );
994
995        // grad_a should be [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] (1/b = 1/2)
996        for i in 0..grad_a.size() {
997            let val = unsafe { *grad_a.as_ptr().add(i) };
998            assert!(
999                (val - 0.5).abs() < 1e-6,
1000                "grad_a[{}] = {} should be 0.5",
1001                i,
1002                val
1003            );
1004        }
1005
1006        // For grad_b: d/db (a / b) = -a/b^2, summed over broadcast dimension
1007        // a = [2,4,6,8,10,12], b = [2,2,2], so -a/b^2 = [-2/4, -4/4, -6/4, -8/4, -10/4, -12/4] = [-0.5, -1, -1.5, -2, -2.5, -3]
1008        // Summed over first dimension: [-0.5-2, -1-2.5, -1.5-3] = [-2.5, -3.5, -4.5]
1009        let expected_grad_b = [-2.5, -3.5, -4.5];
1010        for (i, &expected) in expected_grad_b.iter().enumerate() {
1011            let val = unsafe { *grad_b.as_ptr().add(i) };
1012            assert!(
1013                (val - expected).abs() < 1e-6,
1014                "grad_b[{}] = {} should be {}",
1015                i,
1016                val,
1017                expected
1018            );
1019        }
1020    }
1021
1022    #[test]
1023    fn test_div_scalar_broadcasting_gradients() {
1024        use crate::gradtrack::clear_gradients;
1025        clear_gradients();
1026
1027        // Test case: [2, 3] / [1] -> [2, 3]
1028        // For division: d/da (a / b) = 1/b, d/db (a / b) = -a/b^2
1029
1030        let a = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], vec![2, 3])
1031            .unwrap()
1032            .with_requires_grad();
1033        let b = Tensor::from_slice(&[2.0], vec![1])
1034            .unwrap()
1035            .with_requires_grad();
1036
1037        let mut result = a.div_tensor(&b);
1038        result.backward(None);
1039
1040        let grad_a = a.grad_owned().expect("grad_a should exist");
1041        let grad_b = b.grad_owned().expect("grad_b should exist");
1042
1043        // grad_a should have same shape as a: [2, 3]
1044        assert_eq!(grad_a.shape().dims(), vec![2, 3]);
1045
1046        // grad_b should have same shape as b: [1]
1047        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims());
1048        assert_eq!(grad_b.shape().dims(), vec![1]);
1049
1050        // grad_b should be the sum of -a/b^2 = -(2+4+6+8+10+12)/4 = -42/4 = -10.5
1051        let val = unsafe { *grad_b.as_ptr() };
1052        assert!(
1053            (val - (-10.5)).abs() < 1e-6,
1054            "grad_b = {} should be -10.5",
1055            val
1056        );
1057    }
1058}