// train_station/tensor/ops/mul.rs

1//! Multiplication operations for tensors
2//!
3//! Provides element-wise multiplication functions following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Multiplication**: `mul_tensor()` - Element-wise multiplication with another tensor (PyTorch `mul()` equivalent)
9//! - **Scalar Multiplication**: `mul_scalar()` - Broadcast multiplication with a scalar value
10//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Mathematical Accuracy**: High-precision multiplication computation
15//!
16//! # Mathematical Properties
17//!
18//! The multiplication operations have the following properties:
19//! - **Commutative**: a * b = b * a
20//! - **Associative**: (a * b) * c = a * (b * c)
21//! - **Distributive**: a * (b + c) = a * b + a * c
22//! - **Identity**: a * 1 = a
23//! - **Zero**: a * 0 = 0
24//! - **Gradient**: d/dx(a * b) = b, d/dy(a * b) = a
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44    /// Element-wise multiplication with another tensor with broadcasting support.
45    ///
46    /// Performs element-wise multiplication with automatic broadcasting: `output[i] = self[i] * other[i]`
47    ///
48    /// Broadcasting enables multiplication between tensors of different but compatible shapes.
49    /// Compatible shapes follow NumPy broadcasting rules:
50    /// - Dimensions are aligned from the rightmost dimension
51    /// - Dimensions are compatible if they are equal, or one of them is 1
52    /// - Missing dimensions are treated as 1
53    ///
54    /// # Arguments
55    /// * `other` - Tensor to multiply. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    /// A new tensor containing the element-wise product with broadcast result shape
59    ///
60    /// # Examples
61    ///
62    /// ## Same Shape Multiplication
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
68    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
69    /// let c = a.mul_tensor(&b);
70    /// assert_eq!(c.shape().dims, vec![3]);
71    /// assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
72    /// assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
73    /// assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
74    /// ```
75    ///
76    /// ## Broadcasting Multiplication
77    ///
78    /// ```
79    /// use train_station::Tensor;
80    ///
81    /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
82    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
83    /// let c = a.mul_tensor(&b);
84    /// assert_eq!(c.shape().dims, vec![2, 3]);
85    /// // Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
86    /// assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
87    /// assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
88    /// assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
89    /// ```
90    ///
91    /// # Panics
92    /// Panics if tensor shapes are not broadcast-compatible
93    #[inline]
94    #[track_caller]
95    pub fn mul_tensor(&self, other: &Tensor) -> Tensor {
96        // Check if shapes are identical for fast path
97        if self.shape().dims == other.shape().dims {
98            return self.mul_tensor_same_shape(other);
99        }
100
101        // Use broadcasting for different shapes
102        let (broadcast_self, broadcast_other, _result_shape) =
103            self.broadcast_with(other).unwrap_or_else(|e| {
104                panic!(
105                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
106                    self.shape().dims,
107                    other.shape().dims,
108                    e
109                );
110            });
111
112        // Perform element-wise multiplication on broadcasted tensors
113        let mut result = broadcast_self.mul_tensor_optimized(&broadcast_other);
114
115        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
116            result.set_requires_grad_internal(true);
117            let operands = vec![self.clone(), other.clone()];
118            let grad_fn = GradFn::Mul {
119                is_tensor_mul: true,
120                scalar: None,
121                operands: Some(operands),
122                original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
123            };
124            result.set_grad_fn(grad_fn.clone());
125
126            let mut input_ids = Vec::with_capacity(2);
127            if self.requires_grad() {
128                input_ids.push(self.id());
129            }
130            if other.requires_grad() {
131                input_ids.push(other.id());
132            }
133            GradEngine::register_operation(result.id(), input_ids, grad_fn);
134        }
135
136        result
137    }
138
139    /// Element-wise multiplication for tensors with identical shapes (fast path)
140    ///
141    /// This is an optimized path for tensors that already have the same shape,
142    /// avoiding the overhead of broadcasting computation. This method provides
143    /// better performance when tensors have matching dimensions.
144    ///
145    /// # Arguments
146    ///
147    /// * `other` - Tensor to multiply, must have the same shape as self
148    ///
149    /// # Returns
150    ///
151    /// A new tensor containing the element-wise product
152    ///
153    /// # Performance Characteristics
154    ///
155    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
156    /// - **SIMD Optimization**: Uses optimized multiplication when available
157    /// - **Memory Efficiency**: Direct element-wise computation without shape conversion
158    /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
159    ///
160    /// # Panics
161    ///
162    /// Panics if tensor shapes do not match
163    ///
164    /// # Implementation Details
165    ///
166    /// This method is called internally by `mul_tensor()` when shapes are identical.
167    /// It provides a performance optimization by skipping the broadcasting logic
168    /// and directly calling the optimized multiplication implementation.
169    #[inline]
170    fn mul_tensor_same_shape(&self, other: &Tensor) -> Tensor {
171        assert_eq!(
172            self.shape(),
173            other.shape(),
174            "Tensor shapes must match for same-shape multiplication"
175        );
176        let mut result = self.mul_tensor_optimized(other);
177
178        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
179            result.set_requires_grad_internal(true);
180            let operands = vec![self.clone(), other.clone()];
181            let grad_fn = GradFn::Mul {
182                is_tensor_mul: true,
183                scalar: None,
184                operands: Some(operands),
185                original_shapes: None, // Same shape case
186            };
187            result.set_grad_fn(grad_fn.clone());
188
189            let mut input_ids = Vec::with_capacity(2);
190            if self.requires_grad() {
191                input_ids.push(self.id());
192            }
193            if other.requires_grad() {
194                input_ids.push(other.id());
195            }
196            GradEngine::register_operation(result.id(), input_ids, grad_fn);
197        }
198
199        result
200    }
201
202    /// Broadcast multiplication with a scalar value.
203    ///
204    /// Multiplies every element by the scalar: `output[i] = self[i] * scalar`
205    ///
206    /// # Arguments
207    /// * `scalar` - Value to multiply with each element
208    ///
209    /// # Returns
210    /// A new tensor with each element multiplied by the scalar
211    ///
212    /// # Examples
213    ///
214    /// ## Basic Scalar Multiplication
215    ///
216    /// ```
217    /// use train_station::Tensor;
218    ///
219    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
220    /// let b = a.mul_scalar(10.0);
221    /// assert_eq!(b.shape().dims, vec![3]);
222    /// assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
223    /// assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
224    /// assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
225    /// ```
226    ///
227    /// ## Negative Scalar Multiplication
228    ///
229    /// ```
230    /// use train_station::Tensor;
231    ///
232    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
233    /// let b = a.mul_scalar(-2.0);
234    /// assert_eq!(b.shape().dims, vec![3]);
235    /// assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
236    /// assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
237    /// assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
238    /// ```
239    #[inline]
240    #[track_caller]
241    pub fn mul_scalar(&self, scalar: f32) -> Tensor {
242        let mut result = self.mul_scalar_optimized(scalar);
243
244        if self.requires_grad() && is_grad_enabled() {
245            result.set_requires_grad_internal(true);
246            let grad_fn = GradFn::Mul {
247                is_tensor_mul: false,
248                scalar: Some(scalar),
249                operands: None,
250                original_shapes: None, // Scalar case
251            };
252            result.set_grad_fn(grad_fn.clone());
253            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
254        }
255
256        result
257    }
258    /// Optimized tensor multiplication using SIMD when available
259    ///
260    /// Performs element-wise multiplication using SIMD optimization when available
261    /// and falling back to optimized scalar computation. This is the core implementation
262    /// used by `mul_tensor()`.
263    ///
264    /// # Arguments
265    ///
266    /// * `other` - The tensor to multiply with
267    ///
268    /// # Returns
269    ///
270    /// A new tensor with the result of the multiplication
271    ///
272    /// # Performance Characteristics
273    ///
274    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
275    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
276    /// - **Cache-friendly**: Linear memory access patterns
277    /// - **Memory Safety**: Ensures contiguous memory layout for correctness
278    /// - **Zero-sized Handling**: Fast return for empty tensors
279    ///
280    /// # Safety
281    ///
282    /// This operation assumes the tensors have the same shape. The method ensures
283    /// contiguous memory layout for both input tensors to guarantee correctness
284    /// with view tensors.
285    ///
286    /// # Implementation Details
287    ///
288    /// Automatically selects between SIMD and scalar implementations based on hardware
289    /// capabilities. Ensures both input tensors are contiguous for memory safety.
290    /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
291    #[inline]
292    pub(crate) fn mul_tensor_optimized(&self, other: &Tensor) -> Tensor {
293        assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
294        // Ensure contiguous sources for correctness with view tensors
295        let a_src = if self.is_contiguous() {
296            self.clone()
297        } else {
298            self.contiguous()
299        };
300        let b_src = if other.is_contiguous() {
301            other.clone()
302        } else {
303            other.contiguous()
304        };
305
306        let mut output = Tensor::new(self.shape().dims.clone());
307
308        unsafe {
309            let a = a_src.as_ptr();
310            let b = b_src.as_ptr();
311            let dst = output.as_mut_ptr();
312
313            #[cfg(target_arch = "x86_64")]
314            {
315                // Use SIMD for better performance when available
316                if is_x86_feature_detected!("avx2") {
317                    self.mul_tensors_simd_avx2_optimized(a, b, dst);
318                    return output;
319                }
320            }
321
322            // Fallback to scalar operations with better cache usage
323            self.mul_tensors_scalar_optimized(a, b, dst);
324        }
325
326        output
327    }
328
329    /// AVX2-optimized tensor multiplication implementation
330    ///
331    /// Performs element-wise multiplication using AVX2 SIMD instructions for maximum
332    /// performance on x86_64 architectures with AVX2 support.
333    ///
334    /// # Arguments
335    ///
336    /// * `a` - Pointer to first tensor data
337    /// * `b` - Pointer to second tensor data
338    /// * `dst` - Pointer to output tensor data
339    ///
340    /// # Safety
341    ///
342    /// Requires valid pointers with sufficient memory for the tensor size.
343    /// All pointers must point to valid tensor data. Requires AVX2 support.
344    ///
345    /// # Performance Characteristics
346    ///
347    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
348    /// - **Memory Access**: Linear access patterns for cache efficiency
349    /// - **Vector Operations**: Uses AVX2 multiplication instructions
350    /// - **Fallback**: Handles remaining elements with scalar operations
351    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
352    ///
353    /// # Implementation Details
354    ///
355    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
356    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
357    /// Processes remaining elements with scalar operations for complete coverage.
358    #[cfg(target_arch = "x86_64")]
359    #[inline]
360    #[target_feature(enable = "avx2")]
361    unsafe fn mul_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
362        let size = self.size();
363        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
364        let mut offset = 0;
365
366        // Unrolled SIMD loop for throughput
367        for _ in 0..simd_count {
368            // Process 4 AVX2 vectors (32 elements) per iteration
369            let a_vec1 = _mm256_loadu_ps(a.add(offset));
370            let b_vec1 = _mm256_loadu_ps(b.add(offset));
371            let mul_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
372            _mm256_storeu_ps(dst.add(offset), mul_vec1);
373
374            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
375            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
376            let mul_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
377            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);
378
379            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
380            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
381            let mul_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
382            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);
383
384            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
385            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
386            let mul_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
387            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);
388
389            offset += 32;
390        }
391
392        // Handle remaining 8-element blocks, then tail
393        let remaining_full_blocks = (size - offset) / 8;
394        for _ in 0..remaining_full_blocks {
395            let a_vec = _mm256_loadu_ps(a.add(offset));
396            let b_vec = _mm256_loadu_ps(b.add(offset));
397            let mul_vec = _mm256_mul_ps(a_vec, b_vec);
398            _mm256_storeu_ps(dst.add(offset), mul_vec);
399            offset += 8;
400        }
401        while offset + 4 <= size {
402            *dst.add(offset) = *a.add(offset) * *b.add(offset);
403            *dst.add(offset + 1) = *a.add(offset + 1) * *b.add(offset + 1);
404            *dst.add(offset + 2) = *a.add(offset + 2) * *b.add(offset + 2);
405            *dst.add(offset + 3) = *a.add(offset + 3) * *b.add(offset + 3);
406            offset += 4;
407        }
408        for i in offset..size {
409            *dst.add(i) = *a.add(i) * *b.add(i);
410        }
411    }
412
413    /// Optimized scalar tensor multiplication fallback
414    ///
415    /// Performs element-wise multiplication using optimized scalar operations with
416    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
417    ///
418    /// # Arguments
419    ///
420    /// * `a` - Pointer to first tensor data
421    /// * `b` - Pointer to second tensor data
422    /// * `dst` - Pointer to output tensor data
423    ///
424    /// # Safety
425    ///
426    /// Requires valid pointers with sufficient memory for the tensor size.
427    /// All pointers must point to valid tensor data.
428    ///
429    /// # Performance Characteristics
430    ///
431    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
432    /// - **Memory Access**: Linear access patterns for cache efficiency
433    /// - **Fallback**: Handles remaining elements with scalar operations
434    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
435    /// - **Mathematical Accuracy**: High-precision scalar multiplication
436    ///
437    /// # Implementation Details
438    ///
439    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
440    /// Processes elements in groups of 4 to improve instruction-level parallelism
441    /// and reduce loop overhead.
442    #[inline]
443    unsafe fn mul_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
444        let size = self.size();
445
446        // Use unrolled loops for better instruction throughput
447        let unroll_count = size / 4;
448        let mut i = 0;
449
450        // Process 4 elements at a time for better cache utilization
451        while i < unroll_count {
452            let idx = i * 4;
453            dst.add(idx).write(a.add(idx).read() * b.add(idx).read());
454            dst.add(idx + 1)
455                .write(a.add(idx + 1).read() * b.add(idx + 1).read());
456            dst.add(idx + 2)
457                .write(a.add(idx + 2).read() * b.add(idx + 2).read());
458            dst.add(idx + 3)
459                .write(a.add(idx + 3).read() * b.add(idx + 3).read());
460            i += 1;
461        }
462
463        // Handle remaining elements
464        for j in (unroll_count * 4)..size {
465            dst.add(j).write(a.add(j).read() * b.add(j).read());
466        }
467    }
468
469    /// Optimized scalar multiplication using SIMD when available
470    ///
471    /// Performs element-wise scalar multiplication using SIMD optimization when available
472    /// and falling back to optimized scalar computation. This is the core implementation
473    /// used by `mul_scalar()`.
474    ///
475    /// # Arguments
476    ///
477    /// * `scalar` - The scalar value to multiply by
478    ///
479    /// # Returns
480    ///
481    /// A new tensor with the result of the multiplication
482    ///
483    /// # Performance Characteristics
484    ///
485    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
486    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
487    /// - **Cache-friendly**: Linear memory access patterns
488    /// - **Memory Safety**: Ensures contiguous memory layout for correctness
489    /// - **Zero-sized Handling**: Fast return for empty tensors
490    ///
491    /// # Implementation Details
492    ///
493    /// Automatically selects between SIMD and scalar implementations based on hardware
494    /// capabilities. Ensures the input tensor is contiguous for memory safety.
495    /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
496    #[inline]
497    pub(crate) fn mul_scalar_optimized(&self, scalar: f32) -> Tensor {
498        // Ensure contiguous source for correctness with view tensors
499        let src_self = if self.is_contiguous() {
500            self.clone()
501        } else {
502            self.contiguous()
503        };
504        let mut output = Tensor::new(self.shape().dims.clone());
505
506        unsafe {
507            let src = src_self.as_ptr();
508            let dst = output.as_mut_ptr();
509
510            #[cfg(target_arch = "x86_64")]
511            {
512                // Use SIMD for better performance when available
513                if is_x86_feature_detected!("avx2") {
514                    self.mul_scalar_simd_avx2_optimized(src, dst, scalar);
515                    return output;
516                }
517            }
518
519            // Fallback to scalar operations with better cache usage
520            self.mul_scalar_fallback_optimized(src, dst, scalar);
521        }
522
523        output
524    }
525
526    /// AVX2-optimized scalar multiplication implementation
527    ///
528    /// Performs element-wise scalar multiplication using AVX2 SIMD instructions for maximum
529    /// performance on x86_64 architectures with AVX2 support.
530    ///
531    /// # Arguments
532    ///
533    /// * `src` - Pointer to source tensor data
534    /// * `dst` - Pointer to output tensor data
535    /// * `scalar` - Scalar value to multiply by
536    ///
537    /// # Safety
538    ///
539    /// Requires valid pointers with sufficient memory for the tensor size.
540    /// All pointers must point to valid tensor data. Requires AVX2 support.
541    ///
542    /// # Performance Characteristics
543    ///
544    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
545    /// - **Memory Access**: Linear access patterns for cache efficiency
546    /// - **Vector Operations**: Uses AVX2 multiplication instructions
547    /// - **Fallback**: Handles remaining elements with scalar operations
548    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
549    ///
550    /// # Implementation Details
551    ///
552    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
553    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
554    /// Processes remaining elements with scalar operations for complete coverage.
555    #[cfg(target_arch = "x86_64")]
556    #[inline]
557    #[target_feature(enable = "avx2")]
558    unsafe fn mul_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
559        let size = self.size();
560        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
561        let scalar_vec = _mm256_set1_ps(scalar);
562        let mut offset = 0;
563
564        // Unrolled SIMD loop for throughput
565        for _ in 0..simd_count {
566            // Process 4 AVX2 vectors (32 elements) per iteration
567            let src_vec1 = _mm256_loadu_ps(src.add(offset));
568            let mul_vec1 = _mm256_mul_ps(src_vec1, scalar_vec);
569            _mm256_storeu_ps(dst.add(offset), mul_vec1);
570
571            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
572            let mul_vec2 = _mm256_mul_ps(src_vec2, scalar_vec);
573            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);
574
575            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
576            let mul_vec3 = _mm256_mul_ps(src_vec3, scalar_vec);
577            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);
578
579            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
580            let mul_vec4 = _mm256_mul_ps(src_vec4, scalar_vec);
581            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);
582
583            offset += 32;
584        }
585
586        // Handle remaining 8-element blocks, then tail
587        let remaining_full_blocks = (size - offset) / 8;
588        for _ in 0..remaining_full_blocks {
589            let src_vec = _mm256_loadu_ps(src.add(offset));
590            let mul_vec = _mm256_mul_ps(src_vec, scalar_vec);
591            _mm256_storeu_ps(dst.add(offset), mul_vec);
592            offset += 8;
593        }
594        while offset + 4 <= size {
595            *dst.add(offset) = *src.add(offset) * scalar;
596            *dst.add(offset + 1) = *src.add(offset + 1) * scalar;
597            *dst.add(offset + 2) = *src.add(offset + 2) * scalar;
598            *dst.add(offset + 3) = *src.add(offset + 3) * scalar;
599            offset += 4;
600        }
601        for i in offset..size {
602            *dst.add(i) = *src.add(i) * scalar;
603        }
604    }
605
606    /// Optimized scalar multiplication fallback
607    ///
608    /// Performs element-wise scalar multiplication using optimized scalar operations with
609    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
610    ///
611    /// # Arguments
612    ///
613    /// * `src` - Pointer to source tensor data
614    /// * `dst` - Pointer to output tensor data
615    /// * `scalar` - Scalar value to multiply by
616    ///
617    /// # Safety
618    ///
619    /// Requires valid pointers with sufficient memory for the tensor size.
620    /// All pointers must point to valid tensor data.
621    ///
622    /// # Performance Characteristics
623    ///
624    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
625    /// - **Memory Access**: Linear access patterns for cache efficiency
626    /// - **Fallback**: Handles remaining elements with scalar operations
627    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
628    /// - **Mathematical Accuracy**: High-precision scalar multiplication
629    ///
630    /// # Implementation Details
631    ///
632    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
633    /// Processes elements in groups of 4 to improve instruction-level parallelism
634    /// and reduce loop overhead.
635    #[inline]
636    unsafe fn mul_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
637        let size = self.size();
638
639        // Use unrolled loops for better instruction throughput
640        let unroll_count = size / 4;
641        let mut i = 0;
642
643        // Process 4 elements at a time for better cache utilization
644        while i < unroll_count {
645            let idx = i * 4;
646            dst.add(idx).write(src.add(idx).read() * scalar);
647            dst.add(idx + 1).write(src.add(idx + 1).read() * scalar);
648            dst.add(idx + 2).write(src.add(idx + 2).read() * scalar);
649            dst.add(idx + 3).write(src.add(idx + 3).read() * scalar);
650            i += 1;
651        }
652
653        // Handle remaining elements
654        for j in (unroll_count * 4)..size {
655            dst.add(j).write(src.add(j).read() * scalar);
656        }
657    }
658}
659
660#[cfg(test)]
661mod tests {
662    use super::*;
663
664    #[test]
665    fn test_tensor_multiplication() {
666        let a = Tensor::ones(vec![2, 3]);
667        let mut b = Tensor::ones(vec![2, 3]);
668        b.fill(2.0);
669        let result = a.mul_tensor_optimized(&b);
670
671        assert_eq!(result.shape().dims, vec![2, 3]);
672        assert_eq!(result.size(), 6);
673
674        // Check that all values are 2.0 (1.0 * 2.0)
675        unsafe {
676            for i in 0..result.size() {
677                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
678            }
679        }
680    }
681
682    #[test]
683    fn test_scalar_multiplication() {
684        let tensor = Tensor::ones(vec![2, 2]);
685        let result = tensor.mul_scalar_optimized(3.0);
686
687        assert_eq!(result.shape().dims, vec![2, 2]);
688        assert_eq!(result.size(), 4);
689
690        // Check that all values are 3.0 (1.0 * 3.0)
691        unsafe {
692            for i in 0..result.size() {
693                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
694            }
695        }
696    }
697
698    #[test]
699    fn test_zero_multiplication() {
700        let tensor = Tensor::ones(vec![2, 3]);
701        let result = tensor.mul_scalar_optimized(0.0);
702
703        assert_eq!(result.shape().dims, vec![2, 3]);
704        assert_eq!(result.size(), 6);
705
706        // Check that all values are 0.0 (1.0 * 0.0)
707        unsafe {
708            for i in 0..result.size() {
709                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
710            }
711        }
712    }
713
714    #[test]
715    fn test_negative_multiplication() {
716        let tensor = Tensor::ones(vec![2, 3]);
717        let result = tensor.mul_scalar_optimized(-2.0);
718
719        assert_eq!(result.shape().dims, vec![2, 3]);
720        assert_eq!(result.size(), 6);
721
722        // Check that all values are -2.0 (1.0 * -2.0)
723        unsafe {
724            for i in 0..result.size() {
725                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
726            }
727        }
728    }
729
730    #[test]
731    #[should_panic(expected = "Tensor shapes must match")]
732    fn test_mismatched_shapes() {
733        let a = Tensor::ones(vec![2, 3]);
734        let b = Tensor::ones(vec![3, 2]);
735        a.mul_tensor_optimized(&b);
736    }
737
738    #[test]
739    fn test_edge_cases() {
740        // Test with zero tensor
741        let zero_tensor = Tensor::zeros(vec![2, 3]);
742        let other = Tensor::ones(vec![2, 3]);
743        let result = zero_tensor.mul_tensor_optimized(&other);
744
745        assert_eq!(result.shape().dims, vec![2, 3]);
746        assert_eq!(result.size(), 6);
747
748        // Check that all values are 0.0 (0.0 * 1.0)
749        unsafe {
750            for i in 0..result.size() {
751                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
752            }
753        }
754
755        // Test with negative values
756        let mut neg_tensor = Tensor::ones(vec![2, 3]);
757        neg_tensor.fill(-1.0);
758        let result = neg_tensor.mul_scalar_optimized(2.0);
759
760        assert_eq!(result.shape().dims, vec![2, 3]);
761        assert_eq!(result.size(), 6);
762
763        // Check that all values are -2.0 (-1.0 * 2.0)
764        unsafe {
765            for i in 0..result.size() {
766                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
767            }
768        }
769    }
770
771    #[test]
772    fn test_large_tensor_multiplication() {
773        let a = Tensor::ones(vec![100, 100]);
774        let mut b = Tensor::ones(vec![100, 100]);
775        b.fill(1.5);
776        let result = a.mul_tensor_optimized(&b);
777
778        assert_eq!(result.shape().dims, vec![100, 100]);
779        assert_eq!(result.size(), 10000);
780
781        // Check that all values are 1.5 (1.0 * 1.5)
782        unsafe {
783            for i in 0..result.size() {
784                assert!((result.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
785            }
786        }
787    }
788
789    #[test]
790    fn test_multiplication_with_gradtrack() {
791        // Test scalar multiplication with gradtrack
792        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
793        let mut result = a.mul_scalar(3.0);
794
795        // Check result values: 1.0 * 3.0 = 3.0
796        unsafe {
797            for i in 0..result.size() {
798                let val = result.as_ptr().add(i).read();
799                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
800            }
801        }
802
803        result.backward(None);
804
805        // Check gradient: d/dx(3x) = 3
806        if let Some(grad) = a.grad_by_value() {
807            unsafe {
808                for i in 0..grad.size() {
809                    let val = grad.as_ptr().add(i).read();
810                    assert!(
811                        (val - 3.0).abs() < 1e-6,
812                        "Expected gradient 3.0, got {}",
813                        val
814                    );
815                }
816            }
817        } else {
818            panic!("No gradient computed for scalar multiplication!");
819        }
820
821        // Test tensor multiplication with gradtrack
822        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
823        let mut b = Tensor::ones(vec![2, 2]);
824        b.fill(3.0);
825        let b = b.with_requires_grad();
826
827        let mut result = a.mul_tensor(&b);
828
829        // Check result values: 1.0 * 3.0 = 3.0
830        unsafe {
831            for i in 0..result.size() {
832                let val = result.as_ptr().add(i).read();
833                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
834            }
835        }
836
837        result.backward(None);
838
839        // Check gradients: ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
840        // For a = 1.0, b = 3.0: ∂(a*b)/∂a = 3.0, ∂(a*b)/∂b = 1.0
841        if let Some(grad_a) = a.grad_by_value() {
842            unsafe {
843                for i in 0..grad_a.size() {
844                    let val = grad_a.as_ptr().add(i).read();
845                    assert!(
846                        (val - 3.0).abs() < 1e-6,
847                        "Expected gradient A = 3.0 (∂(a*b)/∂a = b), got {}",
848                        val
849                    );
850                }
851            }
852        } else {
853            panic!("No gradient A computed for tensor multiplication!");
854        }
855
856        if let Some(grad_b) = b.grad_by_value() {
857            unsafe {
858                for i in 0..grad_b.size() {
859                    let val = grad_b.as_ptr().add(i).read();
860                    assert!(
861                        (val - 1.0).abs() < 1e-6,
862                        "Expected gradient B = 1.0 (∂(a*b)/∂b = a), got {}",
863                        val
864                    );
865                }
866            }
867        } else {
868            panic!("No gradient B computed for tensor multiplication!");
869        }
870    }
871
872    #[test]
873    fn test_mixed_mul_add_operations_with_gradtrack() {
874        // Test complex computation: (a * 2) + (b * 3) - 1
875        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
876        let mut b = Tensor::ones(vec![2, 2]);
877        b.fill(3.0);
878        let b = b.with_requires_grad();
879
880        let scalar1 = 2.0;
881        let scalar2 = 3.0;
882
883        // Compute: (a * scalar1) + (b * scalar2) - 1
884        let mul_a = a.mul_scalar(scalar1); // a * 2
885        let mul_b = b.mul_scalar(scalar2); // b * 3
886        let add_result = mul_a.add_tensor(&mul_b); // (a * 2) + (b * 3)
887        let mut final_result = add_result.sub_scalar(1.0); // (a * 2) + (b * 3) - 1
888
889        // Check result values: (1*2) + (3*3) - 1 = 2 + 9 - 1 = 10
890        unsafe {
891            for i in 0..final_result.size() {
892                let val = final_result.as_ptr().add(i).read();
893                assert!((val - 10.0).abs() < 1e-6, "Expected 10.0, got {}", val);
894            }
895        }
896
897        final_result.backward(None);
898
899        // Check gradients: d/dx((2x + 3y - 1)) = 2, d/dy((2x + 3y - 1)) = 3
900        if let Some(grad_a) = a.grad_by_value() {
901            unsafe {
902                for i in 0..grad_a.size() {
903                    let val = grad_a.as_ptr().add(i).read();
904                    assert!(
905                        (val - 2.0).abs() < 1e-6,
906                        "Expected gradient A = 2.0, got {}",
907                        val
908                    );
909                }
910            }
911        } else {
912            panic!("No gradient A computed for mixed operations!");
913        }
914
915        if let Some(grad_b) = b.grad_by_value() {
916            unsafe {
917                for i in 0..grad_b.size() {
918                    let val = grad_b.as_ptr().add(i).read();
919                    assert!(
920                        (val - 3.0).abs() < 1e-6,
921                        "Expected gradient B = 3.0, got {}",
922                        val
923                    );
924                }
925            }
926        } else {
927            panic!("No gradient B computed for mixed operations!");
928        }
929    }
930
931    #[test]
932    fn test_mul_broadcasting_gradients() {
933        use crate::gradtrack::clear_gradients;
934        clear_gradients();
935
936        // Test case: [2, 3] * [1, 3] -> [2, 3]
937        // For multiplication: d/da (a * b) = b, d/db (a * b) = a
938        // So grad_a = grad_output * b, grad_b = grad_output * a (then reduced)
939
940        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
941            .unwrap()
942            .with_requires_grad();
943        let b = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![1, 3])
944            .unwrap()
945            .with_requires_grad();
946
947        let mut result = a.mul_tensor(&b);
948        assert_eq!(result.shape().dims, vec![2, 3]);
949
950        // Set upstream gradient as ones
951        result.backward(None);
952
953        // Check gradients
954        let grad_a = a.grad_by_value().expect("grad_a should exist");
955        let grad_b = b.grad_by_value().expect("grad_b should exist");
956
957        println!(
958            "Original shapes: a={:?}, b={:?}",
959            a.shape().dims,
960            b.shape().dims
961        );
962        println!(
963            "Gradient shapes: grad_a={:?}, grad_b={:?}",
964            grad_a.shape().dims,
965            grad_b.shape().dims
966        );
967
968        // grad_a should have same shape as a: [2, 3]
969        assert_eq!(
970            grad_a.shape().dims,
971            vec![2, 3],
972            "grad_a should match original shape of a"
973        );
974
975        // grad_b should have same shape as b: [1, 3]
976        assert_eq!(
977            grad_b.shape().dims,
978            vec![1, 3],
979            "grad_b should match original shape of b"
980        );
981
982        // For multiplication gradients:
983        // grad_a = grad_output * b = 1.0 * [2.0, 3.0, 4.0] broadcasted to [2, 3]
984        let expected_grad_a = [2.0, 3.0, 4.0, 2.0, 3.0, 4.0];
985        for (i, val) in expected_grad_a.iter().enumerate().take(grad_a.size()) {
986            let actual = unsafe { *grad_a.as_ptr().add(i) };
987            assert!(
988                (actual - val).abs() < 1e-6,
989                "grad_a[{}] = {} should be {}",
990                i,
991                actual,
992                val
993            );
994        }
995
996        // grad_b = grad_output * a summed over broadcasted dimension
997        // a = [1,2,3,4,5,6] -> grad_b should be [1+4, 2+5, 3+6] = [5, 7, 9]
998        let expected_grad_b = [5.0, 7.0, 9.0];
999        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
1000            let actual = unsafe { *grad_b.as_ptr().add(i) };
1001            assert!(
1002                (actual - val).abs() < 1e-6,
1003                "grad_b[{}] = {} should be {}",
1004                i,
1005                actual,
1006                val
1007            );
1008        }
1009    }
1010}