// train_station/tensor/ops/mul.rs

1//! Multiplication operations for tensors
2//!
3//! Provides element-wise multiplication functions following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Multiplication**: `mul_tensor()` - Element-wise multiplication with another tensor (PyTorch `mul()` equivalent)
9//! - **Scalar Multiplication**: `mul_scalar()` - Broadcast multiplication with a scalar value
10//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Mathematical Accuracy**: High-precision multiplication computation
15//!
16//! # Mathematical Properties
17//!
18//! The multiplication operations have the following properties:
19//! - **Commutative**: a * b = b * a
20//! - **Associative**: (a * b) * c = a * (b * c)
21//! - **Distributive**: a * (b + c) = a * b + a * c
22//! - **Identity**: a * 1 = a
23//! - **Zero**: a * 0 = 0
24//! - **Gradient**: d/dx(a * b) = b, d/dy(a * b) = a
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44    /// Element-wise multiplication with another tensor with broadcasting support.
45    ///
46    /// Performs element-wise multiplication with automatic broadcasting: `output[i] = self[i] * other[i]`
47    ///
48    /// Broadcasting enables multiplication between tensors of different but compatible shapes.
49    /// Compatible shapes follow NumPy broadcasting rules:
50    /// - Dimensions are aligned from the rightmost dimension
51    /// - Dimensions are compatible if they are equal, or one of them is 1
52    /// - Missing dimensions are treated as 1
53    ///
54    /// # Arguments
55    /// * `other` - Tensor to multiply. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    /// A new tensor containing the element-wise product with broadcast result shape
59    ///
60    /// # Examples
61    ///
62    /// ## Same Shape Multiplication
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
68    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
69    /// let c = a.mul_tensor(&b);
70    /// assert_eq!(c.shape().dims, vec![3]);
71    /// assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
72    /// assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
73    /// assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
74    /// ```
75    ///
76    /// ## Broadcasting Multiplication
77    ///
78    /// ```
79    /// use train_station::Tensor;
80    ///
81    /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
82    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
83    /// let c = a.mul_tensor(&b);
84    /// assert_eq!(c.shape().dims, vec![2, 3]);
85    /// // Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
86    /// assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
87    /// assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
88    /// assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
89    /// ```
90    ///
91    /// # Panics
92    /// Panics if tensor shapes are not broadcast-compatible
93    #[inline]
94    pub fn mul_tensor(&self, other: &Tensor) -> Tensor {
95        // Check if shapes are identical for fast path
96        if self.shape().dims == other.shape().dims {
97            return self.mul_tensor_same_shape(other);
98        }
99
100        // Use broadcasting for different shapes
101        let (broadcast_self, broadcast_other, _result_shape) =
102            self.broadcast_with(other).unwrap_or_else(|e| {
103                panic!(
104                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
105                    self.shape().dims,
106                    other.shape().dims,
107                    e
108                );
109            });
110
111        // Perform element-wise multiplication on broadcasted tensors
112        let mut result = broadcast_self.mul_tensor_optimized(&broadcast_other);
113
114        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
115            result.set_requires_grad_internal(true);
116            let operands = vec![self.clone(), other.clone()];
117            let grad_fn = GradFn::Mul {
118                is_tensor_mul: true,
119                scalar: None,
120                operands: Some(operands),
121                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
122            };
123            result.set_grad_fn(grad_fn.clone());
124
125            let mut input_ids = Vec::with_capacity(2);
126            if self.requires_grad() {
127                input_ids.push(self.id());
128            }
129            if other.requires_grad() {
130                input_ids.push(other.id());
131            }
132            GradEngine::register_operation(result.id(), input_ids, grad_fn);
133        }
134
135        result
136    }
137
138    /// Element-wise multiplication for tensors with identical shapes (fast path)
139    ///
140    /// This is an optimized path for tensors that already have the same shape,
141    /// avoiding the overhead of broadcasting computation. This method provides
142    /// better performance when tensors have matching dimensions.
143    ///
144    /// # Arguments
145    ///
146    /// * `other` - Tensor to multiply, must have the same shape as self
147    ///
148    /// # Returns
149    ///
150    /// A new tensor containing the element-wise product
151    ///
152    /// # Performance Characteristics
153    ///
154    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
155    /// - **SIMD Optimization**: Uses optimized multiplication when available
156    /// - **Memory Efficiency**: Direct element-wise computation without shape conversion
157    /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
158    ///
159    /// # Panics
160    ///
161    /// Panics if tensor shapes do not match
162    ///
163    /// # Implementation Details
164    ///
165    /// This method is called internally by `mul_tensor()` when shapes are identical.
166    /// It provides a performance optimization by skipping the broadcasting logic
167    /// and directly calling the optimized multiplication implementation.
168    #[inline]
169    fn mul_tensor_same_shape(&self, other: &Tensor) -> Tensor {
170        assert_eq!(
171            self.shape(),
172            other.shape(),
173            "Tensor shapes must match for same-shape multiplication"
174        );
175        let mut result = self.mul_tensor_optimized(other);
176
177        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
178            result.set_requires_grad_internal(true);
179            let operands = vec![self.clone(), other.clone()];
180            let grad_fn = GradFn::Mul {
181                is_tensor_mul: true,
182                scalar: None,
183                operands: Some(operands),
184                original_shapes: None, // Same shape case
185            };
186            result.set_grad_fn(grad_fn.clone());
187
188            let mut input_ids = Vec::with_capacity(2);
189            if self.requires_grad() {
190                input_ids.push(self.id());
191            }
192            if other.requires_grad() {
193                input_ids.push(other.id());
194            }
195            GradEngine::register_operation(result.id(), input_ids, grad_fn);
196        }
197
198        result
199    }
200
201    /// Broadcast multiplication with a scalar value.
202    ///
203    /// Multiplies every element by the scalar: `output[i] = self[i] * scalar`
204    ///
205    /// # Arguments
206    /// * `scalar` - Value to multiply with each element
207    ///
208    /// # Returns
209    /// A new tensor with each element multiplied by the scalar
210    ///
211    /// # Examples
212    ///
213    /// ## Basic Scalar Multiplication
214    ///
215    /// ```
216    /// use train_station::Tensor;
217    ///
218    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
219    /// let b = a.mul_scalar(10.0);
220    /// assert_eq!(b.shape().dims, vec![3]);
221    /// assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
222    /// assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
223    /// assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
224    /// ```
225    ///
226    /// ## Negative Scalar Multiplication
227    ///
228    /// ```
229    /// use train_station::Tensor;
230    ///
231    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
232    /// let b = a.mul_scalar(-2.0);
233    /// assert_eq!(b.shape().dims, vec![3]);
234    /// assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
235    /// assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
236    /// assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
237    /// ```
238    #[inline]
239    pub fn mul_scalar(&self, scalar: f32) -> Tensor {
240        let mut result = self.mul_scalar_optimized(scalar);
241
242        if self.requires_grad() && is_grad_enabled() {
243            result.set_requires_grad_internal(true);
244            let grad_fn = GradFn::Mul {
245                is_tensor_mul: false,
246                scalar: Some(scalar),
247                operands: None,
248                original_shapes: None, // Scalar case
249            };
250            result.set_grad_fn(grad_fn.clone());
251            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
252        }
253
254        result
255    }
256    /// Optimized tensor multiplication using SIMD when available
257    ///
258    /// Performs element-wise multiplication using SIMD optimization when available
259    /// and falling back to optimized scalar computation. This is the core implementation
260    /// used by `mul_tensor()`.
261    ///
262    /// # Arguments
263    ///
264    /// * `other` - The tensor to multiply with
265    ///
266    /// # Returns
267    ///
268    /// A new tensor with the result of the multiplication
269    ///
270    /// # Performance Characteristics
271    ///
272    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
273    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
274    /// - **Cache-friendly**: Linear memory access patterns
275    /// - **Memory Safety**: Ensures contiguous memory layout for correctness
276    /// - **Zero-sized Handling**: Fast return for empty tensors
277    ///
278    /// # Safety
279    ///
280    /// This operation assumes the tensors have the same shape. The method ensures
281    /// contiguous memory layout for both input tensors to guarantee correctness
282    /// with view tensors.
283    ///
284    /// # Implementation Details
285    ///
286    /// Automatically selects between SIMD and scalar implementations based on hardware
287    /// capabilities. Ensures both input tensors are contiguous for memory safety.
288    /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
289    #[inline]
290    pub(crate) fn mul_tensor_optimized(&self, other: &Tensor) -> Tensor {
291        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
292        // Ensure contiguous sources for correctness with view tensors
293        let a_src = if self.is_contiguous() {
294            self.clone()
295        } else {
296            self.contiguous()
297        };
298        let b_src = if other.is_contiguous() {
299            other.clone()
300        } else {
301            other.contiguous()
302        };
303
304        let mut output = Tensor::new(self.shape().dims.clone());
305
306        unsafe {
307            let a = a_src.as_ptr();
308            let b = b_src.as_ptr();
309            let dst = output.as_mut_ptr();
310
311            #[cfg(target_arch = "x86_64")]
312            {
313                // Use SIMD for better performance when available
314                if is_x86_feature_detected!("avx2") {
315                    self.mul_tensors_simd_avx2_optimized(a, b, dst);
316                    return output;
317                }
318            }
319
320            // Fallback to scalar operations with better cache usage
321            self.mul_tensors_scalar_optimized(a, b, dst);
322        }
323
324        output
325    }
326
327    /// AVX2-optimized tensor multiplication implementation
328    ///
329    /// Performs element-wise multiplication using AVX2 SIMD instructions for maximum
330    /// performance on x86_64 architectures with AVX2 support.
331    ///
332    /// # Arguments
333    ///
334    /// * `a` - Pointer to first tensor data
335    /// * `b` - Pointer to second tensor data
336    /// * `dst` - Pointer to output tensor data
337    ///
338    /// # Safety
339    ///
340    /// Requires valid pointers with sufficient memory for the tensor size.
341    /// All pointers must point to valid tensor data. Requires AVX2 support.
342    ///
343    /// # Performance Characteristics
344    ///
345    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
346    /// - **Memory Access**: Linear access patterns for cache efficiency
347    /// - **Vector Operations**: Uses AVX2 multiplication instructions
348    /// - **Fallback**: Handles remaining elements with scalar operations
349    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
350    ///
351    /// # Implementation Details
352    ///
353    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
354    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
355    /// Processes remaining elements with scalar operations for complete coverage.
356    #[cfg(target_arch = "x86_64")]
357    #[inline]
358    #[target_feature(enable = "avx2")]
359    unsafe fn mul_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
360        let size = self.size();
361        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
362        let mut offset = 0;
363
364        // Unrolled SIMD loop for throughput
365        for _ in 0..simd_count {
366            // Process 4 AVX2 vectors (32 elements) per iteration
367            let a_vec1 = _mm256_loadu_ps(a.add(offset));
368            let b_vec1 = _mm256_loadu_ps(b.add(offset));
369            let mul_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
370            _mm256_storeu_ps(dst.add(offset), mul_vec1);
371
372            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
373            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
374            let mul_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
375            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);
376
377            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
378            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
379            let mul_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
380            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);
381
382            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
383            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
384            let mul_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
385            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);
386
387            offset += 32;
388        }
389
390        // Handle remaining 8-element blocks, then tail
391        let remaining_full_blocks = (size - offset) / 8;
392        for _ in 0..remaining_full_blocks {
393            let a_vec = _mm256_loadu_ps(a.add(offset));
394            let b_vec = _mm256_loadu_ps(b.add(offset));
395            let mul_vec = _mm256_mul_ps(a_vec, b_vec);
396            _mm256_storeu_ps(dst.add(offset), mul_vec);
397            offset += 8;
398        }
399        while offset + 4 <= size {
400            *dst.add(offset) = *a.add(offset) * *b.add(offset);
401            *dst.add(offset + 1) = *a.add(offset + 1) * *b.add(offset + 1);
402            *dst.add(offset + 2) = *a.add(offset + 2) * *b.add(offset + 2);
403            *dst.add(offset + 3) = *a.add(offset + 3) * *b.add(offset + 3);
404            offset += 4;
405        }
406        for i in offset..size {
407            *dst.add(i) = *a.add(i) * *b.add(i);
408        }
409    }
410
411    /// Optimized scalar tensor multiplication fallback
412    ///
413    /// Performs element-wise multiplication using optimized scalar operations with
414    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
415    ///
416    /// # Arguments
417    ///
418    /// * `a` - Pointer to first tensor data
419    /// * `b` - Pointer to second tensor data
420    /// * `dst` - Pointer to output tensor data
421    ///
422    /// # Safety
423    ///
424    /// Requires valid pointers with sufficient memory for the tensor size.
425    /// All pointers must point to valid tensor data.
426    ///
427    /// # Performance Characteristics
428    ///
429    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
430    /// - **Memory Access**: Linear access patterns for cache efficiency
431    /// - **Fallback**: Handles remaining elements with scalar operations
432    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
433    /// - **Mathematical Accuracy**: High-precision scalar multiplication
434    ///
435    /// # Implementation Details
436    ///
437    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
438    /// Processes elements in groups of 4 to improve instruction-level parallelism
439    /// and reduce loop overhead.
440    #[inline]
441    unsafe fn mul_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
442        let size = self.size();
443
444        // Use unrolled loops for better instruction throughput
445        let unroll_count = size / 4;
446        let mut i = 0;
447
448        // Process 4 elements at a time for better cache utilization
449        while i < unroll_count {
450            let idx = i * 4;
451            dst.add(idx).write(a.add(idx).read() * b.add(idx).read());
452            dst.add(idx + 1)
453                .write(a.add(idx + 1).read() * b.add(idx + 1).read());
454            dst.add(idx + 2)
455                .write(a.add(idx + 2).read() * b.add(idx + 2).read());
456            dst.add(idx + 3)
457                .write(a.add(idx + 3).read() * b.add(idx + 3).read());
458            i += 1;
459        }
460
461        // Handle remaining elements
462        for j in (unroll_count * 4)..size {
463            dst.add(j).write(a.add(j).read() * b.add(j).read());
464        }
465    }
466
467    /// Optimized scalar multiplication using SIMD when available
468    ///
469    /// Performs element-wise scalar multiplication using SIMD optimization when available
470    /// and falling back to optimized scalar computation. This is the core implementation
471    /// used by `mul_scalar()`.
472    ///
473    /// # Arguments
474    ///
475    /// * `scalar` - The scalar value to multiply by
476    ///
477    /// # Returns
478    ///
479    /// A new tensor with the result of the multiplication
480    ///
481    /// # Performance Characteristics
482    ///
483    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
484    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
485    /// - **Cache-friendly**: Linear memory access patterns
486    /// - **Memory Safety**: Ensures contiguous memory layout for correctness
487    /// - **Zero-sized Handling**: Fast return for empty tensors
488    ///
489    /// # Implementation Details
490    ///
491    /// Automatically selects between SIMD and scalar implementations based on hardware
492    /// capabilities. Ensures the input tensor is contiguous for memory safety.
493    /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
494    #[inline]
495    pub(crate) fn mul_scalar_optimized(&self, scalar: f32) -> Tensor {
496        // Ensure contiguous source for correctness with view tensors
497        let src_self = if self.is_contiguous() {
498            self.clone()
499        } else {
500            self.contiguous()
501        };
502        let mut output = Tensor::new(self.shape().dims.clone());
503
504        unsafe {
505            let src = src_self.as_ptr();
506            let dst = output.as_mut_ptr();
507
508            #[cfg(target_arch = "x86_64")]
509            {
510                // Use SIMD for better performance when available
511                if is_x86_feature_detected!("avx2") {
512                    self.mul_scalar_simd_avx2_optimized(src, dst, scalar);
513                    return output;
514                }
515            }
516
517            // Fallback to scalar operations with better cache usage
518            self.mul_scalar_fallback_optimized(src, dst, scalar);
519        }
520
521        output
522    }
523
524    /// AVX2-optimized scalar multiplication implementation
525    ///
526    /// Performs element-wise scalar multiplication using AVX2 SIMD instructions for maximum
527    /// performance on x86_64 architectures with AVX2 support.
528    ///
529    /// # Arguments
530    ///
531    /// * `src` - Pointer to source tensor data
532    /// * `dst` - Pointer to output tensor data
533    /// * `scalar` - Scalar value to multiply by
534    ///
535    /// # Safety
536    ///
537    /// Requires valid pointers with sufficient memory for the tensor size.
538    /// All pointers must point to valid tensor data. Requires AVX2 support.
539    ///
540    /// # Performance Characteristics
541    ///
542    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
543    /// - **Memory Access**: Linear access patterns for cache efficiency
544    /// - **Vector Operations**: Uses AVX2 multiplication instructions
545    /// - **Fallback**: Handles remaining elements with scalar operations
546    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
547    ///
548    /// # Implementation Details
549    ///
550    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
551    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
552    /// Processes remaining elements with scalar operations for complete coverage.
553    #[cfg(target_arch = "x86_64")]
554    #[inline]
555    #[target_feature(enable = "avx2")]
556    unsafe fn mul_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
557        let size = self.size();
558        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
559        let scalar_vec = _mm256_set1_ps(scalar);
560        let mut offset = 0;
561
562        // Unrolled SIMD loop for throughput
563        for _ in 0..simd_count {
564            // Process 4 AVX2 vectors (32 elements) per iteration
565            let src_vec1 = _mm256_loadu_ps(src.add(offset));
566            let mul_vec1 = _mm256_mul_ps(src_vec1, scalar_vec);
567            _mm256_storeu_ps(dst.add(offset), mul_vec1);
568
569            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
570            let mul_vec2 = _mm256_mul_ps(src_vec2, scalar_vec);
571            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);
572
573            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
574            let mul_vec3 = _mm256_mul_ps(src_vec3, scalar_vec);
575            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);
576
577            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
578            let mul_vec4 = _mm256_mul_ps(src_vec4, scalar_vec);
579            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);
580
581            offset += 32;
582        }
583
584        // Handle remaining 8-element blocks, then tail
585        let remaining_full_blocks = (size - offset) / 8;
586        for _ in 0..remaining_full_blocks {
587            let src_vec = _mm256_loadu_ps(src.add(offset));
588            let mul_vec = _mm256_mul_ps(src_vec, scalar_vec);
589            _mm256_storeu_ps(dst.add(offset), mul_vec);
590            offset += 8;
591        }
592        while offset + 4 <= size {
593            *dst.add(offset) = *src.add(offset) * scalar;
594            *dst.add(offset + 1) = *src.add(offset + 1) * scalar;
595            *dst.add(offset + 2) = *src.add(offset + 2) * scalar;
596            *dst.add(offset + 3) = *src.add(offset + 3) * scalar;
597            offset += 4;
598        }
599        for i in offset..size {
600            *dst.add(i) = *src.add(i) * scalar;
601        }
602    }
603
604    /// Optimized scalar multiplication fallback
605    ///
606    /// Performs element-wise scalar multiplication using optimized scalar operations with
607    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
608    ///
609    /// # Arguments
610    ///
611    /// * `src` - Pointer to source tensor data
612    /// * `dst` - Pointer to output tensor data
613    /// * `scalar` - Scalar value to multiply by
614    ///
615    /// # Safety
616    ///
617    /// Requires valid pointers with sufficient memory for the tensor size.
618    /// All pointers must point to valid tensor data.
619    ///
620    /// # Performance Characteristics
621    ///
622    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
623    /// - **Memory Access**: Linear access patterns for cache efficiency
624    /// - **Fallback**: Handles remaining elements with scalar operations
625    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
626    /// - **Mathematical Accuracy**: High-precision scalar multiplication
627    ///
628    /// # Implementation Details
629    ///
630    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
631    /// Processes elements in groups of 4 to improve instruction-level parallelism
632    /// and reduce loop overhead.
633    #[inline]
634    unsafe fn mul_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
635        let size = self.size();
636
637        // Use unrolled loops for better instruction throughput
638        let unroll_count = size / 4;
639        let mut i = 0;
640
641        // Process 4 elements at a time for better cache utilization
642        while i < unroll_count {
643            let idx = i * 4;
644            dst.add(idx).write(src.add(idx).read() * scalar);
645            dst.add(idx + 1).write(src.add(idx + 1).read() * scalar);
646            dst.add(idx + 2).write(src.add(idx + 2).read() * scalar);
647            dst.add(idx + 3).write(src.add(idx + 3).read() * scalar);
648            i += 1;
649        }
650
651        // Handle remaining elements
652        for j in (unroll_count * 4)..size {
653            dst.add(j).write(src.add(j).read() * scalar);
654        }
655    }
656}
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661
662    #[test]
663    fn test_tensor_multiplication() {
664        let a = Tensor::ones(vec![2, 3]);
665        let mut b = Tensor::ones(vec![2, 3]);
666        b.fill(2.0);
667        let result = a.mul_tensor_optimized(&b);
668
669        assert_eq!(result.shape().dims, vec![2, 3]);
670        assert_eq!(result.size(), 6);
671
672        // Check that all values are 2.0 (1.0 * 2.0)
673        unsafe {
674            for i in 0..result.size() {
675                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
676            }
677        }
678    }
679
680    #[test]
681    fn test_scalar_multiplication() {
682        let tensor = Tensor::ones(vec![2, 2]);
683        let result = tensor.mul_scalar_optimized(3.0);
684
685        assert_eq!(result.shape().dims, vec![2, 2]);
686        assert_eq!(result.size(), 4);
687
688        // Check that all values are 3.0 (1.0 * 3.0)
689        unsafe {
690            for i in 0..result.size() {
691                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
692            }
693        }
694    }
695
696    #[test]
697    fn test_zero_multiplication() {
698        let tensor = Tensor::ones(vec![2, 3]);
699        let result = tensor.mul_scalar_optimized(0.0);
700
701        assert_eq!(result.shape().dims, vec![2, 3]);
702        assert_eq!(result.size(), 6);
703
704        // Check that all values are 0.0 (1.0 * 0.0)
705        unsafe {
706            for i in 0..result.size() {
707                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
708            }
709        }
710    }
711
712    #[test]
713    fn test_negative_multiplication() {
714        let tensor = Tensor::ones(vec![2, 3]);
715        let result = tensor.mul_scalar_optimized(-2.0);
716
717        assert_eq!(result.shape().dims, vec![2, 3]);
718        assert_eq!(result.size(), 6);
719
720        // Check that all values are -2.0 (1.0 * -2.0)
721        unsafe {
722            for i in 0..result.size() {
723                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
724            }
725        }
726    }
727
728    #[test]
729    #[should_panic(expected = "Tensor shapes must match")]
730    fn test_mismatched_shapes() {
731        let a = Tensor::ones(vec![2, 3]);
732        let b = Tensor::ones(vec![3, 2]);
733        a.mul_tensor_optimized(&b);
734    }
735
736    #[test]
737    fn test_edge_cases() {
738        // Test with zero tensor
739        let zero_tensor = Tensor::zeros(vec![2, 3]);
740        let other = Tensor::ones(vec![2, 3]);
741        let result = zero_tensor.mul_tensor_optimized(&other);
742
743        assert_eq!(result.shape().dims, vec![2, 3]);
744        assert_eq!(result.size(), 6);
745
746        // Check that all values are 0.0 (0.0 * 1.0)
747        unsafe {
748            for i in 0..result.size() {
749                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
750            }
751        }
752
753        // Test with negative values
754        let mut neg_tensor = Tensor::ones(vec![2, 3]);
755        neg_tensor.fill(-1.0);
756        let result = neg_tensor.mul_scalar_optimized(2.0);
757
758        assert_eq!(result.shape().dims, vec![2, 3]);
759        assert_eq!(result.size(), 6);
760
761        // Check that all values are -2.0 (-1.0 * 2.0)
762        unsafe {
763            for i in 0..result.size() {
764                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
765            }
766        }
767    }
768
769    #[test]
770    fn test_large_tensor_multiplication() {
771        let a = Tensor::ones(vec![100, 100]);
772        let mut b = Tensor::ones(vec![100, 100]);
773        b.fill(1.5);
774        let result = a.mul_tensor_optimized(&b);
775
776        assert_eq!(result.shape().dims, vec![100, 100]);
777        assert_eq!(result.size(), 10000);
778
779        // Check that all values are 1.5 (1.0 * 1.5)
780        unsafe {
781            for i in 0..result.size() {
782                assert!((result.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
783            }
784        }
785    }
786
787    #[test]
788    fn test_multiplication_with_gradtrack() {
789        // Test scalar multiplication with gradtrack
790        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
791        let mut result = a.mul_scalar(3.0);
792
793        // Check result values: 1.0 * 3.0 = 3.0
794        unsafe {
795            for i in 0..result.size() {
796                let val = result.as_ptr().add(i).read();
797                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
798            }
799        }
800
801        result.backward(None);
802
803        // Check gradient: d/dx(3x) = 3
804        if let Some(grad) = a.grad_by_value() {
805            unsafe {
806                for i in 0..grad.size() {
807                    let val = grad.as_ptr().add(i).read();
808                    assert!(
809                        (val - 3.0).abs() < 1e-6,
810                        "Expected gradient 3.0, got {}",
811                        val
812                    );
813                }
814            }
815        } else {
816            panic!("No gradient computed for scalar multiplication!");
817        }
818
819        // Test tensor multiplication with gradtrack
820        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
821        let mut b = Tensor::ones(vec![2, 2]);
822        b.fill(3.0);
823        let b = b.with_requires_grad();
824
825        let mut result = a.mul_tensor(&b);
826
827        // Check result values: 1.0 * 3.0 = 3.0
828        unsafe {
829            for i in 0..result.size() {
830                let val = result.as_ptr().add(i).read();
831                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
832            }
833        }
834
835        result.backward(None);
836
837        // Check gradients: ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
838        // For a = 1.0, b = 3.0: ∂(a*b)/∂a = 3.0, ∂(a*b)/∂b = 1.0
839        if let Some(grad_a) = a.grad_by_value() {
840            unsafe {
841                for i in 0..grad_a.size() {
842                    let val = grad_a.as_ptr().add(i).read();
843                    assert!(
844                        (val - 3.0).abs() < 1e-6,
845                        "Expected gradient A = 3.0 (∂(a*b)/∂a = b), got {}",
846                        val
847                    );
848                }
849            }
850        } else {
851            panic!("No gradient A computed for tensor multiplication!");
852        }
853
854        if let Some(grad_b) = b.grad_by_value() {
855            unsafe {
856                for i in 0..grad_b.size() {
857                    let val = grad_b.as_ptr().add(i).read();
858                    assert!(
859                        (val - 1.0).abs() < 1e-6,
860                        "Expected gradient B = 1.0 (∂(a*b)/∂b = a), got {}",
861                        val
862                    );
863                }
864            }
865        } else {
866            panic!("No gradient B computed for tensor multiplication!");
867        }
868    }
869
870    #[test]
871    fn test_mixed_mul_add_operations_with_gradtrack() {
872        // Test complex computation: (a * 2) + (b * 3) - 1
873        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
874        let mut b = Tensor::ones(vec![2, 2]);
875        b.fill(3.0);
876        let b = b.with_requires_grad();
877
878        let scalar1 = 2.0;
879        let scalar2 = 3.0;
880
881        // Compute: (a * scalar1) + (b * scalar2) - 1
882        let mul_a = a.mul_scalar(scalar1); // a * 2
883        let mul_b = b.mul_scalar(scalar2); // b * 3
884        let add_result = mul_a.add_tensor(&mul_b); // (a * 2) + (b * 3)
885        let mut final_result = add_result.sub_scalar(1.0); // (a * 2) + (b * 3) - 1
886
887        // Check result values: (1*2) + (3*3) - 1 = 2 + 9 - 1 = 10
888        unsafe {
889            for i in 0..final_result.size() {
890                let val = final_result.as_ptr().add(i).read();
891                assert!((val - 10.0).abs() < 1e-6, "Expected 10.0, got {}", val);
892            }
893        }
894
895        final_result.backward(None);
896
897        // Check gradients: d/dx((2x + 3y - 1)) = 2, d/dy((2x + 3y - 1)) = 3
898        if let Some(grad_a) = a.grad_by_value() {
899            unsafe {
900                for i in 0..grad_a.size() {
901                    let val = grad_a.as_ptr().add(i).read();
902                    assert!(
903                        (val - 2.0).abs() < 1e-6,
904                        "Expected gradient A = 2.0, got {}",
905                        val
906                    );
907                }
908            }
909        } else {
910            panic!("No gradient A computed for mixed operations!");
911        }
912
913        if let Some(grad_b) = b.grad_by_value() {
914            unsafe {
915                for i in 0..grad_b.size() {
916                    let val = grad_b.as_ptr().add(i).read();
917                    assert!(
918                        (val - 3.0).abs() < 1e-6,
919                        "Expected gradient B = 3.0, got {}",
920                        val
921                    );
922                }
923            }
924        } else {
925            panic!("No gradient B computed for mixed operations!");
926        }
927    }
928
929    #[test]
930    fn test_mul_broadcasting_gradients() {
931        use crate::gradtrack::clear_gradients;
932        clear_gradients();
933
934        // Test case: [2, 3] * [1, 3] -> [2, 3]
935        // For multiplication: d/da (a * b) = b, d/db (a * b) = a
936        // So grad_a = grad_output * b, grad_b = grad_output * a (then reduced)
937
938        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
939            .unwrap()
940            .with_requires_grad();
941        let b = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![1, 3])
942            .unwrap()
943            .with_requires_grad();
944
945        let mut result = a.mul_tensor(&b);
946        assert_eq!(result.shape().dims, vec![2, 3]);
947
948        // Set upstream gradient as ones
949        result.backward(None);
950
951        // Check gradients
952        let grad_a = a.grad_by_value().expect("grad_a should exist");
953        let grad_b = b.grad_by_value().expect("grad_b should exist");
954
955        println!(
956            "Original shapes: a={:?}, b={:?}",
957            a.shape().dims,
958            b.shape().dims
959        );
960        println!(
961            "Gradient shapes: grad_a={:?}, grad_b={:?}",
962            grad_a.shape().dims,
963            grad_b.shape().dims
964        );
965
966        // grad_a should have same shape as a: [2, 3]
967        assert_eq!(
968            grad_a.shape().dims,
969            vec![2, 3],
970            "grad_a should match original shape of a"
971        );
972
973        // grad_b should have same shape as b: [1, 3]
974        assert_eq!(
975            grad_b.shape().dims,
976            vec![1, 3],
977            "grad_b should match original shape of b"
978        );
979
980        // For multiplication gradients:
981        // grad_a = grad_output * b = 1.0 * [2.0, 3.0, 4.0] broadcasted to [2, 3]
982        let expected_grad_a = [2.0, 3.0, 4.0, 2.0, 3.0, 4.0];
983        for (i, val) in expected_grad_a.iter().enumerate().take(grad_a.size()) {
984            let actual = unsafe { *grad_a.as_ptr().add(i) };
985            assert!(
986                (actual - val).abs() < 1e-6,
987                "grad_a[{}] = {} should be {}",
988                i,
989                actual,
990                val
991            );
992        }
993
994        // grad_b = grad_output * a summed over broadcasted dimension
995        // a = [1,2,3,4,5,6] -> grad_b should be [1+4, 2+5, 3+6] = [5, 7, 9]
996        let expected_grad_b = [5.0, 7.0, 9.0];
997        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
998            let actual = unsafe { *grad_b.as_ptr().add(i) };
999            assert!(
1000                (actual - val).abs() < 1e-6,
1001                "grad_b[{}] = {} should be {}",
1002                i,
1003                actual,
1004                val
1005            );
1006        }
1007    }
1008}