train_station/tensor/ops/
mul.rs

1//! Multiplication operations for tensors
2//!
3//! Provides element-wise multiplication functions following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Multiplication**: `mul_tensor()` - Element-wise multiplication with another tensor (PyTorch `mul()` equivalent)
9//! - **Scalar Multiplication**: `mul_scalar()` - Broadcast multiplication with a scalar value
10//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Mathematical Accuracy**: High-precision multiplication computation
15//!
16//! # Mathematical Properties
17//!
18//! The multiplication operations have the following properties:
19//! - **Commutative**: a * b = b * a
20//! - **Associative**: (a * b) * c = a * (b * c)
21//! - **Distributive**: a * (b + c) = a * b + a * c
22//! - **Identity**: a * 1 = a
23//! - **Zero**: a * 0 = 0
24//! - **Gradient**: d/dx(a * b) = b, d/dy(a * b) = a
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44    /// Element-wise multiplication with another tensor with broadcasting support.
45    ///
46    /// Performs element-wise multiplication with automatic broadcasting: `output[i] = self[i] * other[i]`
47    ///
48    /// Broadcasting enables multiplication between tensors of different but compatible shapes.
49    /// Compatible shapes follow NumPy broadcasting rules:
50    /// - Dimensions are aligned from the rightmost dimension
51    /// - Dimensions are compatible if they are equal, or one of them is 1
52    /// - Missing dimensions are treated as 1
53    ///
54    /// # Arguments
55    /// * `other` - Tensor to multiply. Shapes must be broadcast-compatible.
56    ///
57    /// # Returns
58    /// A new tensor containing the element-wise product with broadcast result shape
59    ///
60    /// # Examples
61    ///
62    /// ## Same Shape Multiplication
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
68    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
69    /// let c = a.mul_tensor(&b);
70    /// assert_eq!(c.shape().dims(), vec![3]);
71    /// assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
72    /// assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
73    /// assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
74    /// ```
75    ///
76    /// ## Broadcasting Multiplication
77    ///
78    /// ```
79    /// use train_station::Tensor;
80    ///
81    /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
82    /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
83    /// let c = a.mul_tensor(&b);
84    /// assert_eq!(c.shape().dims(), vec![2, 3]);
85    /// // Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
86    /// assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
87    /// assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
88    /// assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
89    /// ```
90    ///
91    /// # Panics
92    /// Panics if tensor shapes are not broadcast-compatible
93    #[inline]
94    #[track_caller]
95    pub fn mul_tensor(&self, other: &Tensor) -> Tensor {
96        // Check if shapes are identical for fast path
97        if self.shape().dims() == other.shape().dims() {
98            return self.mul_tensor_same_shape(other);
99        }
100
101        // Zero-copy broadcast views, then reuse same-shape optimized path
102        use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
103        let mut result = match broadcast_shapes_cow(self, other) {
104            Ok((a_b, b_b, _)) => a_b.as_ref().mul_tensor_optimized(b_b.as_ref()),
105            Err(BroadcastError::IncompatibleShapes { .. }) => {
106                panic!(
107                    "Cannot broadcast tensor shapes {:?} and {:?}: {}",
108                    self.shape().dims(),
109                    other.shape().dims(),
110                    "incompatible shapes"
111                );
112            }
113            Err(BroadcastError::AllocationFailed) => {
114                panic!("Memory allocation failed during broadcasting");
115            }
116        };
117
118        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
119            result.set_requires_grad_internal(true);
120            let operands = vec![self.clone(), other.clone()];
121            let grad_fn = GradFn::Mul {
122                is_tensor_mul: true,
123                scalar: None,
124                operands: Some(operands),
125                original_shapes: Some((
126                    self.shape().dims().to_vec(),
127                    other.shape().dims().to_vec(),
128                )),
129            };
130            result.set_grad_fn(grad_fn.clone());
131
132            let mut input_ids = Vec::with_capacity(2);
133            if self.requires_grad() {
134                input_ids.push(self.id());
135            }
136            if other.requires_grad() {
137                input_ids.push(other.id());
138            }
139            GradEngine::register_operation(result.id(), input_ids, grad_fn);
140        }
141
142        result
143    }
144
145    /// Element-wise multiplication for tensors with identical shapes (fast path)
146    ///
147    /// This is an optimized path for tensors that already have the same shape,
148    /// avoiding the overhead of broadcasting computation. This method provides
149    /// better performance when tensors have matching dimensions.
150    ///
151    /// # Arguments
152    ///
153    /// * `other` - Tensor to multiply, must have the same shape as self
154    ///
155    /// # Returns
156    ///
157    /// A new tensor containing the element-wise product
158    ///
159    /// # Performance Characteristics
160    ///
161    /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
162    /// - **SIMD Optimization**: Uses optimized multiplication when available
163    /// - **Memory Efficiency**: Direct element-wise computation without shape conversion
164    /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
165    ///
166    /// # Panics
167    ///
168    /// Panics if tensor shapes do not match
169    ///
170    /// # Implementation Details
171    ///
172    /// This method is called internally by `mul_tensor()` when shapes are identical.
173    /// It provides a performance optimization by skipping the broadcasting logic
174    /// and directly calling the optimized multiplication implementation.
175    #[inline]
176    fn mul_tensor_same_shape(&self, other: &Tensor) -> Tensor {
177        assert_eq!(
178            self.shape(),
179            other.shape(),
180            "Tensor shapes must match for same-shape multiplication"
181        );
182        let mut result = self.mul_tensor_optimized(other);
183
184        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
185            result.set_requires_grad_internal(true);
186            let operands = vec![self.clone(), other.clone()];
187            let grad_fn = GradFn::Mul {
188                is_tensor_mul: true,
189                scalar: None,
190                operands: Some(operands),
191                original_shapes: None, // Same shape case
192            };
193            result.set_grad_fn(grad_fn.clone());
194
195            let mut input_ids = Vec::with_capacity(2);
196            if self.requires_grad() {
197                input_ids.push(self.id());
198            }
199            if other.requires_grad() {
200                input_ids.push(other.id());
201            }
202            GradEngine::register_operation(result.id(), input_ids, grad_fn);
203        }
204
205        result
206    }
207
208    /// Broadcast multiplication with a scalar value.
209    ///
210    /// Multiplies every element by the scalar: `output[i] = self[i] * scalar`
211    ///
212    /// # Arguments
213    /// * `scalar` - Value to multiply with each element
214    ///
215    /// # Returns
216    /// A new tensor with each element multiplied by the scalar
217    ///
218    /// # Examples
219    ///
220    /// ## Basic Scalar Multiplication
221    ///
222    /// ```
223    /// use train_station::Tensor;
224    ///
225    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
226    /// let b = a.mul_scalar(10.0);
227    /// assert_eq!(b.shape().dims(), vec![3]);
228    /// assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
229    /// assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
230    /// assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
231    /// ```
232    ///
233    /// ## Negative Scalar Multiplication
234    ///
235    /// ```
236    /// use train_station::Tensor;
237    ///
238    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
239    /// let b = a.mul_scalar(-2.0);
240    /// assert_eq!(b.shape().dims(), vec![3]);
241    /// assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
242    /// assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
243    /// assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
244    /// ```
245    #[inline]
246    #[track_caller]
247    pub fn mul_scalar(&self, scalar: f32) -> Tensor {
248        let mut result = self.mul_scalar_optimized(scalar);
249
250        if self.requires_grad() && is_grad_enabled() {
251            result.set_requires_grad_internal(true);
252            let grad_fn = GradFn::Mul {
253                is_tensor_mul: false,
254                scalar: Some(scalar),
255                operands: None,
256                original_shapes: None, // Scalar case
257            };
258            result.set_grad_fn(grad_fn.clone());
259            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
260        }
261
262        result
263    }
264    /// Optimized tensor multiplication using SIMD when available
265    ///
266    /// Performs element-wise multiplication using SIMD optimization when available
267    /// and falling back to optimized scalar computation. This is the core implementation
268    /// used by `mul_tensor()`.
269    ///
270    /// # Arguments
271    ///
272    /// * `other` - The tensor to multiply with
273    ///
274    /// # Returns
275    ///
276    /// A new tensor with the result of the multiplication
277    ///
278    /// # Performance Characteristics
279    ///
280    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
281    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
282    /// - **Cache-friendly**: Linear memory access patterns
283    /// - **Memory Safety**: Ensures contiguous memory layout for correctness
284    /// - **Zero-sized Handling**: Fast return for empty tensors
285    ///
286    /// # Safety
287    ///
288    /// This operation assumes the tensors have the same shape. The method ensures
289    /// contiguous memory layout for both input tensors to guarantee correctness
290    /// with view tensors.
291    ///
292    /// # Implementation Details
293    ///
294    /// Automatically selects between SIMD and scalar implementations based on hardware
295    /// capabilities. Ensures both input tensors are contiguous for memory safety.
296    /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
297    #[inline]
298    pub(crate) fn mul_tensor_optimized(&self, other: &Tensor) -> Tensor {
299        assert_eq!(
300            self.shape().dims(),
301            other.shape().dims(),
302            "Tensor dims must match"
303        );
304        // Ensure contiguous sources for correctness with view tensors
305        let a_src = if self.is_contiguous() {
306            self.clone()
307        } else {
308            self.contiguous()
309        };
310        let b_src = if other.is_contiguous() {
311            other.clone()
312        } else {
313            other.contiguous()
314        };
315
316        let mut output = Tensor::new(self.shape().dims().to_vec());
317
318        unsafe {
319            let a = a_src.as_ptr();
320            let b = b_src.as_ptr();
321            let dst = output.as_mut_ptr();
322
323            #[cfg(target_arch = "x86_64")]
324            {
325                // Use SIMD for better performance when available
326                if is_x86_feature_detected!("avx2") {
327                    self.mul_tensors_simd_avx2_optimized(a, b, dst);
328                    return output;
329                }
330            }
331
332            // Fallback to scalar operations with better cache usage
333            self.mul_tensors_scalar_optimized(a, b, dst);
334        }
335
336        output
337    }
338
    /// AVX2-optimized tensor multiplication implementation
    ///
    /// Performs element-wise multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data (at least `self.size()`
    /// readable/writable `f32` elements each). Requires AVX2 support; callers
    /// must gate on `is_x86_feature_detected!("avx2")`.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    /// Unaligned loads/stores (`loadu`/`storeu`) are used so no alignment
    /// guarantee is needed from the allocator.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let mul_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        // (covers sizes not divisible by 32 with single AVX2 vectors).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let mul_vec = _mm256_mul_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // 4-wide scalar unroll for the next up-to-7 elements.
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) * *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) * *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) * *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) * *b.add(offset + 3);
            offset += 4;
        }
        // Final tail: fewer than 4 elements left.
        for i in offset..size {
            *dst.add(i) = *a.add(i) * *b.add(i);
        }
    }
422
423    /// Optimized scalar tensor multiplication fallback
424    ///
425    /// Performs element-wise multiplication using optimized scalar operations with
426    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
427    ///
428    /// # Arguments
429    ///
430    /// * `a` - Pointer to first tensor data
431    /// * `b` - Pointer to second tensor data
432    /// * `dst` - Pointer to output tensor data
433    ///
434    /// # Safety
435    ///
436    /// Requires valid pointers with sufficient memory for the tensor size.
437    /// All pointers must point to valid tensor data.
438    ///
439    /// # Performance Characteristics
440    ///
441    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
442    /// - **Memory Access**: Linear access patterns for cache efficiency
443    /// - **Fallback**: Handles remaining elements with scalar operations
444    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
445    /// - **Mathematical Accuracy**: High-precision scalar multiplication
446    ///
447    /// # Implementation Details
448    ///
449    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
450    /// Processes elements in groups of 4 to improve instruction-level parallelism
451    /// and reduce loop overhead.
452    #[inline]
453    unsafe fn mul_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
454        let size = self.size();
455
456        // Use unrolled loops for better instruction throughput
457        let unroll_count = size / 4;
458        let mut i = 0;
459
460        // Process 4 elements at a time for better cache utilization
461        while i < unroll_count {
462            let idx = i * 4;
463            dst.add(idx).write(a.add(idx).read() * b.add(idx).read());
464            dst.add(idx + 1)
465                .write(a.add(idx + 1).read() * b.add(idx + 1).read());
466            dst.add(idx + 2)
467                .write(a.add(idx + 2).read() * b.add(idx + 2).read());
468            dst.add(idx + 3)
469                .write(a.add(idx + 3).read() * b.add(idx + 3).read());
470            i += 1;
471        }
472
473        // Handle remaining elements
474        for j in (unroll_count * 4)..size {
475            dst.add(j).write(a.add(j).read() * b.add(j).read());
476        }
477    }
478
479    /// Optimized scalar multiplication using SIMD when available
480    ///
481    /// Performs element-wise scalar multiplication using SIMD optimization when available
482    /// and falling back to optimized scalar computation. This is the core implementation
483    /// used by `mul_scalar()`.
484    ///
485    /// # Arguments
486    ///
487    /// * `scalar` - The scalar value to multiply by
488    ///
489    /// # Returns
490    ///
491    /// A new tensor with the result of the multiplication
492    ///
493    /// # Performance Characteristics
494    ///
495    /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
496    /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
497    /// - **Cache-friendly**: Linear memory access patterns
498    /// - **Memory Safety**: Ensures contiguous memory layout for correctness
499    /// - **Zero-sized Handling**: Fast return for empty tensors
500    ///
501    /// # Implementation Details
502    ///
503    /// Automatically selects between SIMD and scalar implementations based on hardware
504    /// capabilities. Ensures the input tensor is contiguous for memory safety.
505    /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
506    #[inline]
507    pub(crate) fn mul_scalar_optimized(&self, scalar: f32) -> Tensor {
508        // Ensure contiguous source for correctness with view tensors
509        let src_self = if self.is_contiguous() {
510            self.clone()
511        } else {
512            self.contiguous()
513        };
514        let mut output = Tensor::new(self.shape().dims().to_vec());
515
516        unsafe {
517            let src = src_self.as_ptr();
518            let dst = output.as_mut_ptr();
519
520            #[cfg(target_arch = "x86_64")]
521            {
522                // Use SIMD for better performance when available
523                if is_x86_feature_detected!("avx2") {
524                    self.mul_scalar_simd_avx2_optimized(src, dst, scalar);
525                    return output;
526                }
527            }
528
529            // Fallback to scalar operations with better cache usage
530            self.mul_scalar_fallback_optimized(src, dst, scalar);
531        }
532
533        output
534    }
535
    /// AVX2-optimized scalar multiplication implementation
    ///
    /// Performs element-wise scalar multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to multiply by
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data (at least `self.size()`
    /// readable/writable `f32` elements each). Requires AVX2 support; callers
    /// must gate on `is_x86_feature_detected!("avx2")`.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    /// The scalar is splatted once into a vector register and reused across all
    /// iterations; unaligned loads/stores avoid any allocator alignment requirement.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let scalar_vec = _mm256_set1_ps(scalar);
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let mul_vec1 = _mm256_mul_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        // (covers sizes not divisible by 32 with single AVX2 vectors).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let mul_vec = _mm256_mul_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // 4-wide scalar unroll for the next up-to-7 elements.
        while offset + 4 <= size {
            *dst.add(offset) = *src.add(offset) * scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) * scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) * scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) * scalar;
            offset += 4;
        }
        // Final tail: fewer than 4 elements left.
        for i in offset..size {
            *dst.add(i) = *src.add(i) * scalar;
        }
    }
615
616    /// Optimized scalar multiplication fallback
617    ///
618    /// Performs element-wise scalar multiplication using optimized scalar operations with
619    /// 4x unrolling for better instruction-level parallelism and cache efficiency.
620    ///
621    /// # Arguments
622    ///
623    /// * `src` - Pointer to source tensor data
624    /// * `dst` - Pointer to output tensor data
625    /// * `scalar` - Scalar value to multiply by
626    ///
627    /// # Safety
628    ///
629    /// Requires valid pointers with sufficient memory for the tensor size.
630    /// All pointers must point to valid tensor data.
631    ///
632    /// # Performance Characteristics
633    ///
634    /// - **Unrolling**: 4x unrolling for instruction-level parallelism
635    /// - **Memory Access**: Linear access patterns for cache efficiency
636    /// - **Fallback**: Handles remaining elements with scalar operations
637    /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
638    /// - **Mathematical Accuracy**: High-precision scalar multiplication
639    ///
640    /// # Implementation Details
641    ///
642    /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
643    /// Processes elements in groups of 4 to improve instruction-level parallelism
644    /// and reduce loop overhead.
645    #[inline]
646    unsafe fn mul_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
647        let size = self.size();
648
649        // Use unrolled loops for better instruction throughput
650        let unroll_count = size / 4;
651        let mut i = 0;
652
653        // Process 4 elements at a time for better cache utilization
654        while i < unroll_count {
655            let idx = i * 4;
656            dst.add(idx).write(src.add(idx).read() * scalar);
657            dst.add(idx + 1).write(src.add(idx + 1).read() * scalar);
658            dst.add(idx + 2).write(src.add(idx + 2).read() * scalar);
659            dst.add(idx + 3).write(src.add(idx + 3).read() * scalar);
660            i += 1;
661        }
662
663        // Handle remaining elements
664        for j in (unroll_count * 4)..size {
665            dst.add(j).write(src.add(j).read() * scalar);
666        }
667    }
668}
669
670#[cfg(test)]
671mod tests {
672    use super::*;
673
674    #[test]
675    fn test_tensor_multiplication() {
676        let a = Tensor::ones(vec![2, 3]);
677        let mut b = Tensor::ones(vec![2, 3]);
678        b.fill(2.0);
679        let result = a.mul_tensor_optimized(&b);
680
681        assert_eq!(result.shape().dims(), vec![2, 3]);
682        assert_eq!(result.size(), 6);
683
684        // Check that all values are 2.0 (1.0 * 2.0)
685        unsafe {
686            for i in 0..result.size() {
687                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
688            }
689        }
690    }
691
692    #[test]
693    fn test_scalar_multiplication() {
694        let tensor = Tensor::ones(vec![2, 2]);
695        let result = tensor.mul_scalar_optimized(3.0);
696
697        assert_eq!(result.shape().dims(), vec![2, 2]);
698        assert_eq!(result.size(), 4);
699
700        // Check that all values are 3.0 (1.0 * 3.0)
701        unsafe {
702            for i in 0..result.size() {
703                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
704            }
705        }
706    }
707
708    #[test]
709    fn test_zero_multiplication() {
710        let tensor = Tensor::ones(vec![2, 3]);
711        let result = tensor.mul_scalar_optimized(0.0);
712
713        assert_eq!(result.shape().dims(), vec![2, 3]);
714        assert_eq!(result.size(), 6);
715
716        // Check that all values are 0.0 (1.0 * 0.0)
717        unsafe {
718            for i in 0..result.size() {
719                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
720            }
721        }
722    }
723
724    #[test]
725    fn test_negative_multiplication() {
726        let tensor = Tensor::ones(vec![2, 3]);
727        let result = tensor.mul_scalar_optimized(-2.0);
728
729        assert_eq!(result.shape().dims(), vec![2, 3]);
730        assert_eq!(result.size(), 6);
731
732        // Check that all values are -2.0 (1.0 * -2.0)
733        unsafe {
734            for i in 0..result.size() {
735                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
736            }
737        }
738    }
739
740    #[test]
741    #[should_panic(expected = "must match")]
742    fn test_mismatched_shapes() {
743        let a = Tensor::ones(vec![2, 3]);
744        let b = Tensor::ones(vec![3, 2]);
745        a.mul_tensor_optimized(&b);
746    }
747
748    #[test]
749    fn test_edge_cases() {
750        // Test with zero tensor
751        let zero_tensor = Tensor::zeros(vec![2, 3]);
752        let other = Tensor::ones(vec![2, 3]);
753        let result = zero_tensor.mul_tensor_optimized(&other);
754
755        assert_eq!(result.shape().dims(), vec![2, 3]);
756        assert_eq!(result.size(), 6);
757
758        // Check that all values are 0.0 (0.0 * 1.0)
759        unsafe {
760            for i in 0..result.size() {
761                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
762            }
763        }
764
765        // Test with negative values
766        let mut neg_tensor = Tensor::ones(vec![2, 3]);
767        neg_tensor.fill(-1.0);
768        let result = neg_tensor.mul_scalar_optimized(2.0);
769
770        assert_eq!(result.shape().dims(), vec![2, 3]);
771        assert_eq!(result.size(), 6);
772
773        // Check that all values are -2.0 (-1.0 * 2.0)
774        unsafe {
775            for i in 0..result.size() {
776                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
777            }
778        }
779    }
780
781    #[test]
782    fn test_large_tensor_multiplication() {
783        let a = Tensor::ones(vec![100, 100]);
784        let mut b = Tensor::ones(vec![100, 100]);
785        b.fill(1.5);
786        let result = a.mul_tensor_optimized(&b);
787
788        assert_eq!(result.shape().dims(), vec![100, 100]);
789        assert_eq!(result.size(), 10000);
790
791        // Check that all values are 1.5 (1.0 * 1.5)
792        unsafe {
793            for i in 0..result.size() {
794                assert!((result.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
795            }
796        }
797    }
798
799    #[test]
800    fn test_multiplication_with_gradtrack() {
801        // Test scalar multiplication with gradtrack
802        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
803        let mut result = a.mul_scalar(3.0);
804
805        // Check result values: 1.0 * 3.0 = 3.0
806        unsafe {
807            for i in 0..result.size() {
808                let val = result.as_ptr().add(i).read();
809                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
810            }
811        }
812
813        result.backward(None);
814
815        // Check gradient: d/dx(3x) = 3
816        if let Some(grad) = a.grad_owned() {
817            unsafe {
818                for i in 0..grad.size() {
819                    let val = grad.as_ptr().add(i).read();
820                    assert!(
821                        (val - 3.0).abs() < 1e-6,
822                        "Expected gradient 3.0, got {}",
823                        val
824                    );
825                }
826            }
827        } else {
828            panic!("No gradient computed for scalar multiplication!");
829        }
830
831        // Test tensor multiplication with gradtrack
832        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
833        let mut b = Tensor::ones(vec![2, 2]);
834        b.fill(3.0);
835        let b = b.with_requires_grad();
836
837        let mut result = a.mul_tensor(&b);
838
839        // Check result values: 1.0 * 3.0 = 3.0
840        unsafe {
841            for i in 0..result.size() {
842                let val = result.as_ptr().add(i).read();
843                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
844            }
845        }
846
847        result.backward(None);
848
849        // Check gradients: ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
850        // For a = 1.0, b = 3.0: ∂(a*b)/∂a = 3.0, ∂(a*b)/∂b = 1.0
851        if let Some(grad_a) = a.grad_owned() {
852            unsafe {
853                for i in 0..grad_a.size() {
854                    let val = grad_a.as_ptr().add(i).read();
855                    assert!(
856                        (val - 3.0).abs() < 1e-6,
857                        "Expected gradient A = 3.0 (∂(a*b)/∂a = b), got {}",
858                        val
859                    );
860                }
861            }
862        } else {
863            panic!("No gradient A computed for tensor multiplication!");
864        }
865
866        if let Some(grad_b) = b.grad_owned() {
867            unsafe {
868                for i in 0..grad_b.size() {
869                    let val = grad_b.as_ptr().add(i).read();
870                    assert!(
871                        (val - 1.0).abs() < 1e-6,
872                        "Expected gradient B = 1.0 (∂(a*b)/∂b = a), got {}",
873                        val
874                    );
875                }
876            }
877        } else {
878            panic!("No gradient B computed for tensor multiplication!");
879        }
880    }
881
882    #[test]
883    fn test_mixed_mul_add_operations_with_gradtrack() {
884        // Test complex computation: (a * 2) + (b * 3) - 1
885        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
886        let mut b = Tensor::ones(vec![2, 2]);
887        b.fill(3.0);
888        let b = b.with_requires_grad();
889
890        let scalar1 = 2.0;
891        let scalar2 = 3.0;
892
893        // Compute: (a * scalar1) + (b * scalar2) - 1
894        let mul_a = a.mul_scalar(scalar1); // a * 2
895        let mul_b = b.mul_scalar(scalar2); // b * 3
896        let add_result = mul_a.add_tensor(&mul_b); // (a * 2) + (b * 3)
897        let mut final_result = add_result.sub_scalar(1.0); // (a * 2) + (b * 3) - 1
898
899        // Check result values: (1*2) + (3*3) - 1 = 2 + 9 - 1 = 10
900        unsafe {
901            for i in 0..final_result.size() {
902                let val = final_result.as_ptr().add(i).read();
903                assert!((val - 10.0).abs() < 1e-6, "Expected 10.0, got {}", val);
904            }
905        }
906
907        final_result.backward(None);
908
909        // Check gradients: d/dx((2x + 3y - 1)) = 2, d/dy((2x + 3y - 1)) = 3
910        if let Some(grad_a) = a.grad_owned() {
911            unsafe {
912                for i in 0..grad_a.size() {
913                    let val = grad_a.as_ptr().add(i).read();
914                    assert!(
915                        (val - 2.0).abs() < 1e-6,
916                        "Expected gradient A = 2.0, got {}",
917                        val
918                    );
919                }
920            }
921        } else {
922            panic!("No gradient A computed for mixed operations!");
923        }
924
925        if let Some(grad_b) = b.grad_owned() {
926            unsafe {
927                for i in 0..grad_b.size() {
928                    let val = grad_b.as_ptr().add(i).read();
929                    assert!(
930                        (val - 3.0).abs() < 1e-6,
931                        "Expected gradient B = 3.0, got {}",
932                        val
933                    );
934                }
935            }
936        } else {
937            panic!("No gradient B computed for mixed operations!");
938        }
939    }
940
941    #[test]
942    fn test_mul_broadcasting_gradients() {
943        use crate::gradtrack::clear_gradients;
944        clear_gradients();
945
946        // Test case: [2, 3] * [1, 3] -> [2, 3]
947        // For multiplication: d/da (a * b) = b, d/db (a * b) = a
948        // So grad_a = grad_output * b, grad_b = grad_output * a (then reduced)
949
950        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
951            .unwrap()
952            .with_requires_grad();
953        let b = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![1, 3])
954            .unwrap()
955            .with_requires_grad();
956
957        let mut result = a.mul_tensor(&b);
958        assert_eq!(result.shape().dims(), vec![2, 3]);
959
960        // Set upstream gradient as ones
961        result.backward(None);
962
963        // Check gradients
964        let grad_a = a.grad_owned().expect("grad_a should exist");
965        let grad_b = b.grad_owned().expect("grad_b should exist");
966
967        println!(
968            "Original shapes: a={:?}, b={:?}",
969            a.shape().dims(),
970            b.shape().dims()
971        );
972        println!(
973            "Gradient shapes: grad_a={:?}, grad_b={:?}",
974            grad_a.shape().dims(),
975            grad_b.shape().dims()
976        );
977
978        // grad_a should have same shape as a: [2, 3]
979        assert_eq!(
980            grad_a.shape().dims(),
981            vec![2, 3],
982            "grad_a should match original shape of a"
983        );
984
985        // grad_b should have same shape as b: [1, 3]
986        assert_eq!(
987            grad_b.shape().dims(),
988            vec![1, 3],
989            "grad_b should match original shape of b"
990        );
991
992        // For multiplication gradients:
993        // grad_a = grad_output * b = 1.0 * [2.0, 3.0, 4.0] broadcasted to [2, 3]
994        let expected_grad_a = [2.0, 3.0, 4.0, 2.0, 3.0, 4.0];
995        for (i, val) in expected_grad_a.iter().enumerate().take(grad_a.size()) {
996            let actual = unsafe { *grad_a.as_ptr().add(i) };
997            assert!(
998                (actual - val).abs() < 1e-6,
999                "grad_a[{}] = {} should be {}",
1000                i,
1001                actual,
1002                val
1003            );
1004        }
1005
1006        // grad_b = grad_output * a summed over broadcasted dimension
1007        // a = [1,2,3,4,5,6] -> grad_b should be [1+4, 2+5, 3+6] = [5, 7, 9]
1008        let expected_grad_b = [5.0, 7.0, 9.0];
1009        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
1010            let actual = unsafe { *grad_b.as_ptr().add(i) };
1011            assert!(
1012                (actual - val).abs() < 1e-6,
1013                "grad_b[{}] = {} should be {}",
1014                i,
1015                actual,
1016                val
1017            );
1018        }
1019    }
1020}