train_station/tensor/ops/mul.rs
1//! Multiplication operations for tensors
2//!
3//! Provides element-wise multiplication functions following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Multiplication**: `mul_tensor()` - Element-wise multiplication with another tensor (PyTorch `mul()` equivalent)
9//! - **Scalar Multiplication**: `mul_scalar()` - Broadcast multiplication with a scalar value
10//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Mathematical Accuracy**: High-precision multiplication computation
15//!
16//! # Mathematical Properties
17//!
18//! The multiplication operations have the following properties:
19//! - **Commutative**: a * b = b * a
20//! - **Associative**: (a * b) * c = a * (b * c)
21//! - **Distributive**: a * (b + c) = a * b + a * c
22//! - **Identity**: a * 1 = a
23//! - **Zero**: a * 0 = 0
//! - **Gradient**: ∂(a * b)/∂a = b, ∂(a * b)/∂b = a
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44 /// Element-wise multiplication with another tensor with broadcasting support.
45 ///
46 /// Performs element-wise multiplication with automatic broadcasting: `output[i] = self[i] * other[i]`
47 ///
48 /// Broadcasting enables multiplication between tensors of different but compatible shapes.
49 /// Compatible shapes follow NumPy broadcasting rules:
50 /// - Dimensions are aligned from the rightmost dimension
51 /// - Dimensions are compatible if they are equal, or one of them is 1
52 /// - Missing dimensions are treated as 1
53 ///
54 /// # Arguments
55 /// * `other` - Tensor to multiply. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 /// A new tensor containing the element-wise product with broadcast result shape
59 ///
60 /// # Examples
61 ///
62 /// ## Same Shape Multiplication
63 ///
64 /// ```
65 /// use train_station::Tensor;
66 ///
67 /// let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
68 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
69 /// let c = a.mul_tensor(&b);
70 /// assert_eq!(c.shape().dims(), vec![3]);
71 /// assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
72 /// assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
73 /// assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
74 /// ```
75 ///
76 /// ## Broadcasting Multiplication
77 ///
78 /// ```
79 /// use train_station::Tensor;
80 ///
81 /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
82 /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
83 /// let c = a.mul_tensor(&b);
84 /// assert_eq!(c.shape().dims(), vec![2, 3]);
85 /// // Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
86 /// assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
87 /// assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
88 /// assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
89 /// ```
90 ///
91 /// # Panics
92 /// Panics if tensor shapes are not broadcast-compatible
93 #[inline]
94 #[track_caller]
95 pub fn mul_tensor(&self, other: &Tensor) -> Tensor {
96 // Check if shapes are identical for fast path
97 if self.shape().dims() == other.shape().dims() {
98 return self.mul_tensor_same_shape(other);
99 }
100
101 // Zero-copy broadcast views, then reuse same-shape optimized path
102 use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};
103 let mut result = match broadcast_shapes_cow(self, other) {
104 Ok((a_b, b_b, _)) => a_b.as_ref().mul_tensor_optimized(b_b.as_ref()),
105 Err(BroadcastError::IncompatibleShapes { .. }) => {
106 panic!(
107 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
108 self.shape().dims(),
109 other.shape().dims(),
110 "incompatible shapes"
111 );
112 }
113 Err(BroadcastError::AllocationFailed) => {
114 panic!("Memory allocation failed during broadcasting");
115 }
116 };
117
118 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
119 result.set_requires_grad_internal(true);
120 let operands = vec![self.clone(), other.clone()];
121 let grad_fn = GradFn::Mul {
122 is_tensor_mul: true,
123 scalar: None,
124 operands: Some(operands),
125 original_shapes: Some((
126 self.shape().dims().to_vec(),
127 other.shape().dims().to_vec(),
128 )),
129 };
130 result.set_grad_fn(grad_fn.clone());
131
132 let mut input_ids = Vec::with_capacity(2);
133 if self.requires_grad() {
134 input_ids.push(self.id());
135 }
136 if other.requires_grad() {
137 input_ids.push(other.id());
138 }
139 GradEngine::register_operation(result.id(), input_ids, grad_fn);
140 }
141
142 result
143 }
144
145 /// Element-wise multiplication for tensors with identical shapes (fast path)
146 ///
147 /// This is an optimized path for tensors that already have the same shape,
148 /// avoiding the overhead of broadcasting computation. This method provides
149 /// better performance when tensors have matching dimensions.
150 ///
151 /// # Arguments
152 ///
153 /// * `other` - Tensor to multiply, must have the same shape as self
154 ///
155 /// # Returns
156 ///
157 /// A new tensor containing the element-wise product
158 ///
159 /// # Performance Characteristics
160 ///
161 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
162 /// - **SIMD Optimization**: Uses optimized multiplication when available
163 /// - **Memory Efficiency**: Direct element-wise computation without shape conversion
164 /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
165 ///
166 /// # Panics
167 ///
168 /// Panics if tensor shapes do not match
169 ///
170 /// # Implementation Details
171 ///
172 /// This method is called internally by `mul_tensor()` when shapes are identical.
173 /// It provides a performance optimization by skipping the broadcasting logic
174 /// and directly calling the optimized multiplication implementation.
175 #[inline]
176 fn mul_tensor_same_shape(&self, other: &Tensor) -> Tensor {
177 assert_eq!(
178 self.shape(),
179 other.shape(),
180 "Tensor shapes must match for same-shape multiplication"
181 );
182 let mut result = self.mul_tensor_optimized(other);
183
184 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
185 result.set_requires_grad_internal(true);
186 let operands = vec![self.clone(), other.clone()];
187 let grad_fn = GradFn::Mul {
188 is_tensor_mul: true,
189 scalar: None,
190 operands: Some(operands),
191 original_shapes: None, // Same shape case
192 };
193 result.set_grad_fn(grad_fn.clone());
194
195 let mut input_ids = Vec::with_capacity(2);
196 if self.requires_grad() {
197 input_ids.push(self.id());
198 }
199 if other.requires_grad() {
200 input_ids.push(other.id());
201 }
202 GradEngine::register_operation(result.id(), input_ids, grad_fn);
203 }
204
205 result
206 }
207
208 /// Broadcast multiplication with a scalar value.
209 ///
210 /// Multiplies every element by the scalar: `output[i] = self[i] * scalar`
211 ///
212 /// # Arguments
213 /// * `scalar` - Value to multiply with each element
214 ///
215 /// # Returns
216 /// A new tensor with each element multiplied by the scalar
217 ///
218 /// # Examples
219 ///
220 /// ## Basic Scalar Multiplication
221 ///
222 /// ```
223 /// use train_station::Tensor;
224 ///
225 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
226 /// let b = a.mul_scalar(10.0);
227 /// assert_eq!(b.shape().dims(), vec![3]);
228 /// assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
229 /// assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
230 /// assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
231 /// ```
232 ///
233 /// ## Negative Scalar Multiplication
234 ///
235 /// ```
236 /// use train_station::Tensor;
237 ///
238 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
239 /// let b = a.mul_scalar(-2.0);
240 /// assert_eq!(b.shape().dims(), vec![3]);
241 /// assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
242 /// assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
243 /// assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
244 /// ```
245 #[inline]
246 #[track_caller]
247 pub fn mul_scalar(&self, scalar: f32) -> Tensor {
248 let mut result = self.mul_scalar_optimized(scalar);
249
250 if self.requires_grad() && is_grad_enabled() {
251 result.set_requires_grad_internal(true);
252 let grad_fn = GradFn::Mul {
253 is_tensor_mul: false,
254 scalar: Some(scalar),
255 operands: None,
256 original_shapes: None, // Scalar case
257 };
258 result.set_grad_fn(grad_fn.clone());
259 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
260 }
261
262 result
263 }
264 /// Optimized tensor multiplication using SIMD when available
265 ///
266 /// Performs element-wise multiplication using SIMD optimization when available
267 /// and falling back to optimized scalar computation. This is the core implementation
268 /// used by `mul_tensor()`.
269 ///
270 /// # Arguments
271 ///
272 /// * `other` - The tensor to multiply with
273 ///
274 /// # Returns
275 ///
276 /// A new tensor with the result of the multiplication
277 ///
278 /// # Performance Characteristics
279 ///
280 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
281 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
282 /// - **Cache-friendly**: Linear memory access patterns
283 /// - **Memory Safety**: Ensures contiguous memory layout for correctness
284 /// - **Zero-sized Handling**: Fast return for empty tensors
285 ///
286 /// # Safety
287 ///
288 /// This operation assumes the tensors have the same shape. The method ensures
289 /// contiguous memory layout for both input tensors to guarantee correctness
290 /// with view tensors.
291 ///
292 /// # Implementation Details
293 ///
294 /// Automatically selects between SIMD and scalar implementations based on hardware
295 /// capabilities. Ensures both input tensors are contiguous for memory safety.
296 /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
297 #[inline]
298 pub(crate) fn mul_tensor_optimized(&self, other: &Tensor) -> Tensor {
299 assert_eq!(
300 self.shape().dims(),
301 other.shape().dims(),
302 "Tensor dims must match"
303 );
304 // Ensure contiguous sources for correctness with view tensors
305 let a_src = if self.is_contiguous() {
306 self.clone()
307 } else {
308 self.contiguous()
309 };
310 let b_src = if other.is_contiguous() {
311 other.clone()
312 } else {
313 other.contiguous()
314 };
315
316 let mut output = Tensor::new(self.shape().dims().to_vec());
317
318 unsafe {
319 let a = a_src.as_ptr();
320 let b = b_src.as_ptr();
321 let dst = output.as_mut_ptr();
322
323 #[cfg(target_arch = "x86_64")]
324 {
325 // Use SIMD for better performance when available
326 if is_x86_feature_detected!("avx2") {
327 self.mul_tensors_simd_avx2_optimized(a, b, dst);
328 return output;
329 }
330 }
331
332 // Fallback to scalar operations with better cache usage
333 self.mul_tensors_scalar_optimized(a, b, dst);
334 }
335
336 output
337 }
338
    /// AVX2-optimized tensor multiplication implementation
    ///
    /// Performs element-wise multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data (read `self.size()` f32 elements)
    /// * `b` - Pointer to second tensor data (read `self.size()` f32 elements)
    /// * `dst` - Pointer to output tensor data (written `self.size()` f32 elements)
    ///
    /// # Safety
    ///
    /// All three pointers must be valid for `self.size()` f32 reads/writes.
    /// `dst` is expected not to overlap `a` or `b` (callers pass a freshly
    /// allocated output buffer). Requires AVX2 support; callers gate on
    /// `is_x86_feature_detected!("avx2")`.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses unaligned AVX2 loads/stores, so no
    ///   alignment requirement is placed on the buffers
    /// - **Fallback**: Handles remaining elements with scalar operations
    ///
    /// # Implementation Details
    ///
    /// Processes the bulk in 32-element chunks (four 8-lane AVX2 vectors per
    /// iteration), then drains leftover 8-element vectors, then a 4x-unrolled
    /// scalar loop, then single elements — covering every size exactly once.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let mul_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let mul_vec = _mm256_mul_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // 4x-unrolled scalar drain for 4..=7 leftover elements.
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) * *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) * *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) * *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) * *b.add(offset + 3);
            offset += 4;
        }
        // Final tail: at most 3 elements remain.
        for i in offset..size {
            *dst.add(i) = *a.add(i) * *b.add(i);
        }
    }
422
423 /// Optimized scalar tensor multiplication fallback
424 ///
425 /// Performs element-wise multiplication using optimized scalar operations with
426 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
427 ///
428 /// # Arguments
429 ///
430 /// * `a` - Pointer to first tensor data
431 /// * `b` - Pointer to second tensor data
432 /// * `dst` - Pointer to output tensor data
433 ///
434 /// # Safety
435 ///
436 /// Requires valid pointers with sufficient memory for the tensor size.
437 /// All pointers must point to valid tensor data.
438 ///
439 /// # Performance Characteristics
440 ///
441 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
442 /// - **Memory Access**: Linear access patterns for cache efficiency
443 /// - **Fallback**: Handles remaining elements with scalar operations
444 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
445 /// - **Mathematical Accuracy**: High-precision scalar multiplication
446 ///
447 /// # Implementation Details
448 ///
449 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
450 /// Processes elements in groups of 4 to improve instruction-level parallelism
451 /// and reduce loop overhead.
452 #[inline]
453 unsafe fn mul_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
454 let size = self.size();
455
456 // Use unrolled loops for better instruction throughput
457 let unroll_count = size / 4;
458 let mut i = 0;
459
460 // Process 4 elements at a time for better cache utilization
461 while i < unroll_count {
462 let idx = i * 4;
463 dst.add(idx).write(a.add(idx).read() * b.add(idx).read());
464 dst.add(idx + 1)
465 .write(a.add(idx + 1).read() * b.add(idx + 1).read());
466 dst.add(idx + 2)
467 .write(a.add(idx + 2).read() * b.add(idx + 2).read());
468 dst.add(idx + 3)
469 .write(a.add(idx + 3).read() * b.add(idx + 3).read());
470 i += 1;
471 }
472
473 // Handle remaining elements
474 for j in (unroll_count * 4)..size {
475 dst.add(j).write(a.add(j).read() * b.add(j).read());
476 }
477 }
478
479 /// Optimized scalar multiplication using SIMD when available
480 ///
481 /// Performs element-wise scalar multiplication using SIMD optimization when available
482 /// and falling back to optimized scalar computation. This is the core implementation
483 /// used by `mul_scalar()`.
484 ///
485 /// # Arguments
486 ///
487 /// * `scalar` - The scalar value to multiply by
488 ///
489 /// # Returns
490 ///
491 /// A new tensor with the result of the multiplication
492 ///
493 /// # Performance Characteristics
494 ///
495 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
496 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
497 /// - **Cache-friendly**: Linear memory access patterns
498 /// - **Memory Safety**: Ensures contiguous memory layout for correctness
499 /// - **Zero-sized Handling**: Fast return for empty tensors
500 ///
501 /// # Implementation Details
502 ///
503 /// Automatically selects between SIMD and scalar implementations based on hardware
504 /// capabilities. Ensures the input tensor is contiguous for memory safety.
505 /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
506 #[inline]
507 pub(crate) fn mul_scalar_optimized(&self, scalar: f32) -> Tensor {
508 // Ensure contiguous source for correctness with view tensors
509 let src_self = if self.is_contiguous() {
510 self.clone()
511 } else {
512 self.contiguous()
513 };
514 let mut output = Tensor::new(self.shape().dims().to_vec());
515
516 unsafe {
517 let src = src_self.as_ptr();
518 let dst = output.as_mut_ptr();
519
520 #[cfg(target_arch = "x86_64")]
521 {
522 // Use SIMD for better performance when available
523 if is_x86_feature_detected!("avx2") {
524 self.mul_scalar_simd_avx2_optimized(src, dst, scalar);
525 return output;
526 }
527 }
528
529 // Fallback to scalar operations with better cache usage
530 self.mul_scalar_fallback_optimized(src, dst, scalar);
531 }
532
533 output
534 }
535
    /// AVX2-optimized scalar multiplication implementation
    ///
    /// Performs element-wise scalar multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data (read `self.size()` f32 elements)
    /// * `dst` - Pointer to output tensor data (written `self.size()` f32 elements)
    /// * `scalar` - Scalar value to multiply by (broadcast across all 8 lanes)
    ///
    /// # Safety
    ///
    /// Both pointers must be valid for `self.size()` f32 reads/writes.
    /// `dst` is expected not to overlap `src` (callers pass a freshly
    /// allocated output buffer). Requires AVX2 support; callers gate on
    /// `is_x86_feature_detected!("avx2")`.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses unaligned AVX2 loads/stores, so no
    ///   alignment requirement is placed on the buffers
    /// - **Fallback**: Handles remaining elements with scalar operations
    ///
    /// # Implementation Details
    ///
    /// The scalar is splatted into one AVX2 register once, then the bulk is
    /// processed in 32-element chunks (four 8-lane vectors per iteration),
    /// followed by leftover 8-element vectors, a 4x-unrolled scalar loop,
    /// and finally single elements — covering every size exactly once.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let scalar_vec = _mm256_set1_ps(scalar);
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let mul_vec1 = _mm256_mul_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let mul_vec = _mm256_mul_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // 4x-unrolled scalar drain for 4..=7 leftover elements.
        while offset + 4 <= size {
            *dst.add(offset) = *src.add(offset) * scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) * scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) * scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) * scalar;
            offset += 4;
        }
        // Final tail: at most 3 elements remain.
        for i in offset..size {
            *dst.add(i) = *src.add(i) * scalar;
        }
    }
615
616 /// Optimized scalar multiplication fallback
617 ///
618 /// Performs element-wise scalar multiplication using optimized scalar operations with
619 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
620 ///
621 /// # Arguments
622 ///
623 /// * `src` - Pointer to source tensor data
624 /// * `dst` - Pointer to output tensor data
625 /// * `scalar` - Scalar value to multiply by
626 ///
627 /// # Safety
628 ///
629 /// Requires valid pointers with sufficient memory for the tensor size.
630 /// All pointers must point to valid tensor data.
631 ///
632 /// # Performance Characteristics
633 ///
634 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
635 /// - **Memory Access**: Linear access patterns for cache efficiency
636 /// - **Fallback**: Handles remaining elements with scalar operations
637 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
638 /// - **Mathematical Accuracy**: High-precision scalar multiplication
639 ///
640 /// # Implementation Details
641 ///
642 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
643 /// Processes elements in groups of 4 to improve instruction-level parallelism
644 /// and reduce loop overhead.
645 #[inline]
646 unsafe fn mul_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
647 let size = self.size();
648
649 // Use unrolled loops for better instruction throughput
650 let unroll_count = size / 4;
651 let mut i = 0;
652
653 // Process 4 elements at a time for better cache utilization
654 while i < unroll_count {
655 let idx = i * 4;
656 dst.add(idx).write(src.add(idx).read() * scalar);
657 dst.add(idx + 1).write(src.add(idx + 1).read() * scalar);
658 dst.add(idx + 2).write(src.add(idx + 2).read() * scalar);
659 dst.add(idx + 3).write(src.add(idx + 3).read() * scalar);
660 i += 1;
661 }
662
663 // Handle remaining elements
664 for j in (unroll_count * 4)..size {
665 dst.add(j).write(src.add(j).read() * scalar);
666 }
667 }
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    // --- Forward-pass kernel tests (call the *_optimized internals directly) ---

    #[test]
    fn test_tensor_multiplication() {
        let a = Tensor::ones(vec![2, 3]);
        let mut b = Tensor::ones(vec![2, 3]);
        b.fill(2.0);
        let result = a.mul_tensor_optimized(&b);

        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 * 2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_scalar_multiplication() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.mul_scalar_optimized(3.0);

        assert_eq!(result.shape().dims(), vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 3.0 (1.0 * 3.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_zero_multiplication() {
        let tensor = Tensor::ones(vec![2, 3]);
        let result = tensor.mul_scalar_optimized(0.0);

        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 0.0 (1.0 * 0.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_negative_multiplication() {
        let tensor = Tensor::ones(vec![2, 3]);
        let result = tensor.mul_scalar_optimized(-2.0);

        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are -2.0 (1.0 * -2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
            }
        }
    }

    // The optimized kernel asserts equal dims; the "must match" fragment
    // matches the assert_eq! message "Tensor dims must match".
    #[test]
    #[should_panic(expected = "must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.mul_tensor_optimized(&b);
    }

    #[test]
    fn test_edge_cases() {
        // Test with zero tensor
        let zero_tensor = Tensor::zeros(vec![2, 3]);
        let other = Tensor::ones(vec![2, 3]);
        let result = zero_tensor.mul_tensor_optimized(&other);

        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 0.0 (0.0 * 1.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
            }
        }

        // Test with negative values
        let mut neg_tensor = Tensor::ones(vec![2, 3]);
        neg_tensor.fill(-1.0);
        let result = neg_tensor.mul_scalar_optimized(2.0);

        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are -2.0 (-1.0 * 2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
            }
        }
    }

    // 10,000 elements exercises the 32-element SIMD blocks, not just the tails.
    #[test]
    fn test_large_tensor_multiplication() {
        let a = Tensor::ones(vec![100, 100]);
        let mut b = Tensor::ones(vec![100, 100]);
        b.fill(1.5);
        let result = a.mul_tensor_optimized(&b);

        assert_eq!(result.shape().dims(), vec![100, 100]);
        assert_eq!(result.size(), 10000);

        // Check that all values are 1.5 (1.0 * 1.5)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
            }
        }
    }

    // --- Gradient (gradtrack) tests through the public mul_scalar/mul_tensor API ---

    #[test]
    fn test_multiplication_with_gradtrack() {
        // Test scalar multiplication with gradtrack
        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut result = a.mul_scalar(3.0);

        // Check result values: 1.0 * 3.0 = 3.0
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
            }
        }

        result.backward(None);

        // Check gradient: d/dx(3x) = 3
        if let Some(grad) = a.grad_owned() {
            unsafe {
                for i in 0..grad.size() {
                    let val = grad.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient 3.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient computed for scalar multiplication!");
        }

        // Test tensor multiplication with gradtrack
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(3.0);
        let b = b.with_requires_grad();

        let mut result = a.mul_tensor(&b);

        // Check result values: 1.0 * 3.0 = 3.0
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
            }
        }

        result.backward(None);

        // Check gradients: ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
        // For a = 1.0, b = 3.0: ∂(a*b)/∂a = 3.0, ∂(a*b)/∂b = 1.0
        if let Some(grad_a) = a.grad_owned() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient A = 3.0 (∂(a*b)/∂a = b), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for tensor multiplication!");
        }

        if let Some(grad_b) = b.grad_owned() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - 1.0).abs() < 1e-6,
                        "Expected gradient B = 1.0 (∂(a*b)/∂b = a), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for tensor multiplication!");
        }
    }

    #[test]
    fn test_mixed_mul_add_operations_with_gradtrack() {
        // Test complex computation: (a * 2) + (b * 3) - 1
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(3.0);
        let b = b.with_requires_grad();

        let scalar1 = 2.0;
        let scalar2 = 3.0;

        // Compute: (a * scalar1) + (b * scalar2) - 1
        let mul_a = a.mul_scalar(scalar1); // a * 2
        let mul_b = b.mul_scalar(scalar2); // b * 3
        let add_result = mul_a.add_tensor(&mul_b); // (a * 2) + (b * 3)
        let mut final_result = add_result.sub_scalar(1.0); // (a * 2) + (b * 3) - 1

        // Check result values: (1*2) + (3*3) - 1 = 2 + 9 - 1 = 10
        unsafe {
            for i in 0..final_result.size() {
                let val = final_result.as_ptr().add(i).read();
                assert!((val - 10.0).abs() < 1e-6, "Expected 10.0, got {}", val);
            }
        }

        final_result.backward(None);

        // Check gradients: d/dx((2x + 3y - 1)) = 2, d/dy((2x + 3y - 1)) = 3
        if let Some(grad_a) = a.grad_owned() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 2.0).abs() < 1e-6,
                        "Expected gradient A = 2.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for mixed operations!");
        }

        if let Some(grad_b) = b.grad_owned() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient B = 3.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for mixed operations!");
        }
    }

    #[test]
    fn test_mul_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] * [1, 3] -> [2, 3]
        // For multiplication: d/da (a * b) = b, d/db (a * b) = a
        // So grad_a = grad_output * b, grad_b = grad_output * a (then reduced)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.mul_tensor(&b);
        assert_eq!(result.shape().dims(), vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_owned().expect("grad_a should exist");
        let grad_b = b.grad_owned().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims(),
            b.shape().dims()
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims(),
            grad_b.shape().dims()
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims(),
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        assert_eq!(
            grad_b.shape().dims(),
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // For multiplication gradients:
        // grad_a = grad_output * b = 1.0 * [2.0, 3.0, 4.0] broadcasted to [2, 3]
        let expected_grad_a = [2.0, 3.0, 4.0, 2.0, 3.0, 4.0];
        for (i, val) in expected_grad_a.iter().enumerate().take(grad_a.size()) {
            let actual = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_a[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }

        // grad_b = grad_output * a summed over broadcasted dimension
        // a = [1,2,3,4,5,6] -> grad_b should be [1+4, 2+5, 3+6] = [5, 7, 9]
        let expected_grad_b = [5.0, 7.0, 9.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }
}