train_station/tensor/ops/mul.rs
1//! Multiplication operations for tensors
2//!
3//! Provides element-wise multiplication functions following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Multiplication**: `mul_tensor()` - Element-wise multiplication with another tensor (PyTorch `mul()` equivalent)
9//! - **Scalar Multiplication**: `mul_scalar()` - Broadcast multiplication with a scalar value
10//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Mathematical Accuracy**: High-precision multiplication computation
15//!
16//! # Mathematical Properties
17//!
18//! The multiplication operations have the following properties:
19//! - **Commutative**: a * b = b * a
20//! - **Associative**: (a * b) * c = a * (b * c)
21//! - **Distributive**: a * (b + c) = a * b + a * c
22//! - **Identity**: a * 1 = a
23//! - **Zero**: a * 0 = 0
24//! - **Gradient**: d/dx(a * b) = b, d/dy(a * b) = a
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44 /// Element-wise multiplication with another tensor with broadcasting support.
45 ///
46 /// Performs element-wise multiplication with automatic broadcasting: `output[i] = self[i] * other[i]`
47 ///
48 /// Broadcasting enables multiplication between tensors of different but compatible shapes.
49 /// Compatible shapes follow NumPy broadcasting rules:
50 /// - Dimensions are aligned from the rightmost dimension
51 /// - Dimensions are compatible if they are equal, or one of them is 1
52 /// - Missing dimensions are treated as 1
53 ///
54 /// # Arguments
55 /// * `other` - Tensor to multiply. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 /// A new tensor containing the element-wise product with broadcast result shape
59 ///
60 /// # Examples
61 ///
62 /// ## Same Shape Multiplication
63 ///
64 /// ```
65 /// use train_station::Tensor;
66 ///
67 /// let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
68 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
69 /// let c = a.mul_tensor(&b);
70 /// assert_eq!(c.shape().dims, vec![3]);
71 /// assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
72 /// assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
73 /// assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
74 /// ```
75 ///
76 /// ## Broadcasting Multiplication
77 ///
78 /// ```
79 /// use train_station::Tensor;
80 ///
81 /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
82 /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
83 /// let c = a.mul_tensor(&b);
84 /// assert_eq!(c.shape().dims, vec![2, 3]);
85 /// // Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
86 /// assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
87 /// assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
88 /// assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
89 /// ```
90 ///
91 /// # Panics
92 /// Panics if tensor shapes are not broadcast-compatible
93 #[inline]
94 #[track_caller]
95 pub fn mul_tensor(&self, other: &Tensor) -> Tensor {
96 // Check if shapes are identical for fast path
97 if self.shape().dims == other.shape().dims {
98 return self.mul_tensor_same_shape(other);
99 }
100
101 // Use broadcasting for different shapes
102 let (broadcast_self, broadcast_other, _result_shape) =
103 self.broadcast_with(other).unwrap_or_else(|e| {
104 panic!(
105 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
106 self.shape().dims,
107 other.shape().dims,
108 e
109 );
110 });
111
112 // Perform element-wise multiplication on broadcasted tensors
113 let mut result = broadcast_self.mul_tensor_optimized(&broadcast_other);
114
115 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
116 result.set_requires_grad_internal(true);
117 let operands = vec![self.clone(), other.clone()];
118 let grad_fn = GradFn::Mul {
119 is_tensor_mul: true,
120 scalar: None,
121 operands: Some(operands),
122 original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
123 };
124 result.set_grad_fn(grad_fn.clone());
125
126 let mut input_ids = Vec::with_capacity(2);
127 if self.requires_grad() {
128 input_ids.push(self.id());
129 }
130 if other.requires_grad() {
131 input_ids.push(other.id());
132 }
133 GradEngine::register_operation(result.id(), input_ids, grad_fn);
134 }
135
136 result
137 }
138
139 /// Element-wise multiplication for tensors with identical shapes (fast path)
140 ///
141 /// This is an optimized path for tensors that already have the same shape,
142 /// avoiding the overhead of broadcasting computation. This method provides
143 /// better performance when tensors have matching dimensions.
144 ///
145 /// # Arguments
146 ///
147 /// * `other` - Tensor to multiply, must have the same shape as self
148 ///
149 /// # Returns
150 ///
151 /// A new tensor containing the element-wise product
152 ///
153 /// # Performance Characteristics
154 ///
155 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
156 /// - **SIMD Optimization**: Uses optimized multiplication when available
157 /// - **Memory Efficiency**: Direct element-wise computation without shape conversion
158 /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
159 ///
160 /// # Panics
161 ///
162 /// Panics if tensor shapes do not match
163 ///
164 /// # Implementation Details
165 ///
166 /// This method is called internally by `mul_tensor()` when shapes are identical.
167 /// It provides a performance optimization by skipping the broadcasting logic
168 /// and directly calling the optimized multiplication implementation.
169 #[inline]
170 fn mul_tensor_same_shape(&self, other: &Tensor) -> Tensor {
171 assert_eq!(
172 self.shape(),
173 other.shape(),
174 "Tensor shapes must match for same-shape multiplication"
175 );
176 let mut result = self.mul_tensor_optimized(other);
177
178 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
179 result.set_requires_grad_internal(true);
180 let operands = vec![self.clone(), other.clone()];
181 let grad_fn = GradFn::Mul {
182 is_tensor_mul: true,
183 scalar: None,
184 operands: Some(operands),
185 original_shapes: None, // Same shape case
186 };
187 result.set_grad_fn(grad_fn.clone());
188
189 let mut input_ids = Vec::with_capacity(2);
190 if self.requires_grad() {
191 input_ids.push(self.id());
192 }
193 if other.requires_grad() {
194 input_ids.push(other.id());
195 }
196 GradEngine::register_operation(result.id(), input_ids, grad_fn);
197 }
198
199 result
200 }
201
202 /// Broadcast multiplication with a scalar value.
203 ///
204 /// Multiplies every element by the scalar: `output[i] = self[i] * scalar`
205 ///
206 /// # Arguments
207 /// * `scalar` - Value to multiply with each element
208 ///
209 /// # Returns
210 /// A new tensor with each element multiplied by the scalar
211 ///
212 /// # Examples
213 ///
214 /// ## Basic Scalar Multiplication
215 ///
216 /// ```
217 /// use train_station::Tensor;
218 ///
219 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
220 /// let b = a.mul_scalar(10.0);
221 /// assert_eq!(b.shape().dims, vec![3]);
222 /// assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
223 /// assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
224 /// assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
225 /// ```
226 ///
227 /// ## Negative Scalar Multiplication
228 ///
229 /// ```
230 /// use train_station::Tensor;
231 ///
232 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
233 /// let b = a.mul_scalar(-2.0);
234 /// assert_eq!(b.shape().dims, vec![3]);
235 /// assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
236 /// assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
237 /// assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
238 /// ```
239 #[inline]
240 #[track_caller]
241 pub fn mul_scalar(&self, scalar: f32) -> Tensor {
242 let mut result = self.mul_scalar_optimized(scalar);
243
244 if self.requires_grad() && is_grad_enabled() {
245 result.set_requires_grad_internal(true);
246 let grad_fn = GradFn::Mul {
247 is_tensor_mul: false,
248 scalar: Some(scalar),
249 operands: None,
250 original_shapes: None, // Scalar case
251 };
252 result.set_grad_fn(grad_fn.clone());
253 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
254 }
255
256 result
257 }
258 /// Optimized tensor multiplication using SIMD when available
259 ///
260 /// Performs element-wise multiplication using SIMD optimization when available
261 /// and falling back to optimized scalar computation. This is the core implementation
262 /// used by `mul_tensor()`.
263 ///
264 /// # Arguments
265 ///
266 /// * `other` - The tensor to multiply with
267 ///
268 /// # Returns
269 ///
270 /// A new tensor with the result of the multiplication
271 ///
272 /// # Performance Characteristics
273 ///
274 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
275 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
276 /// - **Cache-friendly**: Linear memory access patterns
277 /// - **Memory Safety**: Ensures contiguous memory layout for correctness
278 /// - **Zero-sized Handling**: Fast return for empty tensors
279 ///
280 /// # Safety
281 ///
282 /// This operation assumes the tensors have the same shape. The method ensures
283 /// contiguous memory layout for both input tensors to guarantee correctness
284 /// with view tensors.
285 ///
286 /// # Implementation Details
287 ///
288 /// Automatically selects between SIMD and scalar implementations based on hardware
289 /// capabilities. Ensures both input tensors are contiguous for memory safety.
290 /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
291 #[inline]
292 pub(crate) fn mul_tensor_optimized(&self, other: &Tensor) -> Tensor {
293 assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
294 // Ensure contiguous sources for correctness with view tensors
295 let a_src = if self.is_contiguous() {
296 self.clone()
297 } else {
298 self.contiguous()
299 };
300 let b_src = if other.is_contiguous() {
301 other.clone()
302 } else {
303 other.contiguous()
304 };
305
306 let mut output = Tensor::new(self.shape().dims.clone());
307
308 unsafe {
309 let a = a_src.as_ptr();
310 let b = b_src.as_ptr();
311 let dst = output.as_mut_ptr();
312
313 #[cfg(target_arch = "x86_64")]
314 {
315 // Use SIMD for better performance when available
316 if is_x86_feature_detected!("avx2") {
317 self.mul_tensors_simd_avx2_optimized(a, b, dst);
318 return output;
319 }
320 }
321
322 // Fallback to scalar operations with better cache usage
323 self.mul_tensors_scalar_optimized(a, b, dst);
324 }
325
326 output
327 }
328
    /// AVX2-optimized tensor multiplication implementation
    ///
    /// Performs element-wise multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration.
            // Unaligned loads/stores (`loadu`/`storeu`) are used because tensor
            // buffers are not guaranteed to be 32-byte aligned.
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let mul_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        // (after the 32-wide loop, fewer than 32 elements remain).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let mul_vec = _mm256_mul_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // 4-wide scalar step: at most one iteration runs here (fewer than 8 remain).
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) * *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) * *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) * *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) * *b.add(offset + 3);
            offset += 4;
        }
        // Final scalar tail covers the last 0..=3 elements.
        for i in offset..size {
            *dst.add(i) = *a.add(i) * *b.add(i);
        }
    }
412
413 /// Optimized scalar tensor multiplication fallback
414 ///
415 /// Performs element-wise multiplication using optimized scalar operations with
416 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
417 ///
418 /// # Arguments
419 ///
420 /// * `a` - Pointer to first tensor data
421 /// * `b` - Pointer to second tensor data
422 /// * `dst` - Pointer to output tensor data
423 ///
424 /// # Safety
425 ///
426 /// Requires valid pointers with sufficient memory for the tensor size.
427 /// All pointers must point to valid tensor data.
428 ///
429 /// # Performance Characteristics
430 ///
431 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
432 /// - **Memory Access**: Linear access patterns for cache efficiency
433 /// - **Fallback**: Handles remaining elements with scalar operations
434 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
435 /// - **Mathematical Accuracy**: High-precision scalar multiplication
436 ///
437 /// # Implementation Details
438 ///
439 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
440 /// Processes elements in groups of 4 to improve instruction-level parallelism
441 /// and reduce loop overhead.
442 #[inline]
443 unsafe fn mul_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
444 let size = self.size();
445
446 // Use unrolled loops for better instruction throughput
447 let unroll_count = size / 4;
448 let mut i = 0;
449
450 // Process 4 elements at a time for better cache utilization
451 while i < unroll_count {
452 let idx = i * 4;
453 dst.add(idx).write(a.add(idx).read() * b.add(idx).read());
454 dst.add(idx + 1)
455 .write(a.add(idx + 1).read() * b.add(idx + 1).read());
456 dst.add(idx + 2)
457 .write(a.add(idx + 2).read() * b.add(idx + 2).read());
458 dst.add(idx + 3)
459 .write(a.add(idx + 3).read() * b.add(idx + 3).read());
460 i += 1;
461 }
462
463 // Handle remaining elements
464 for j in (unroll_count * 4)..size {
465 dst.add(j).write(a.add(j).read() * b.add(j).read());
466 }
467 }
468
469 /// Optimized scalar multiplication using SIMD when available
470 ///
471 /// Performs element-wise scalar multiplication using SIMD optimization when available
472 /// and falling back to optimized scalar computation. This is the core implementation
473 /// used by `mul_scalar()`.
474 ///
475 /// # Arguments
476 ///
477 /// * `scalar` - The scalar value to multiply by
478 ///
479 /// # Returns
480 ///
481 /// A new tensor with the result of the multiplication
482 ///
483 /// # Performance Characteristics
484 ///
485 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
486 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
487 /// - **Cache-friendly**: Linear memory access patterns
488 /// - **Memory Safety**: Ensures contiguous memory layout for correctness
489 /// - **Zero-sized Handling**: Fast return for empty tensors
490 ///
491 /// # Implementation Details
492 ///
493 /// Automatically selects between SIMD and scalar implementations based on hardware
494 /// capabilities. Ensures the input tensor is contiguous for memory safety.
495 /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
496 #[inline]
497 pub(crate) fn mul_scalar_optimized(&self, scalar: f32) -> Tensor {
498 // Ensure contiguous source for correctness with view tensors
499 let src_self = if self.is_contiguous() {
500 self.clone()
501 } else {
502 self.contiguous()
503 };
504 let mut output = Tensor::new(self.shape().dims.clone());
505
506 unsafe {
507 let src = src_self.as_ptr();
508 let dst = output.as_mut_ptr();
509
510 #[cfg(target_arch = "x86_64")]
511 {
512 // Use SIMD for better performance when available
513 if is_x86_feature_detected!("avx2") {
514 self.mul_scalar_simd_avx2_optimized(src, dst, scalar);
515 return output;
516 }
517 }
518
519 // Fallback to scalar operations with better cache usage
520 self.mul_scalar_fallback_optimized(src, dst, scalar);
521 }
522
523 output
524 }
525
    /// AVX2-optimized scalar multiplication implementation
    ///
    /// Performs element-wise scalar multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to multiply by
    ///
    /// # Safety
    ///
    /// Requires valid pointers with sufficient memory for the tensor size.
    /// All pointers must point to valid tensor data. Requires AVX2 support.
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        // Broadcast the scalar once into all 8 lanes; reused by every iteration.
        let scalar_vec = _mm256_set1_ps(scalar);
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration.
            // Unaligned loads/stores (`loadu`/`storeu`) are used because tensor
            // buffers are not guaranteed to be 32-byte aligned.
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let mul_vec1 = _mm256_mul_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        // (after the 32-wide loop, fewer than 32 elements remain).
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let mul_vec = _mm256_mul_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // 4-wide scalar step: at most one iteration runs here (fewer than 8 remain).
        while offset + 4 <= size {
            *dst.add(offset) = *src.add(offset) * scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) * scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) * scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) * scalar;
            offset += 4;
        }
        // Final scalar tail covers the last 0..=3 elements.
        for i in offset..size {
            *dst.add(i) = *src.add(i) * scalar;
        }
    }
605
606 /// Optimized scalar multiplication fallback
607 ///
608 /// Performs element-wise scalar multiplication using optimized scalar operations with
609 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
610 ///
611 /// # Arguments
612 ///
613 /// * `src` - Pointer to source tensor data
614 /// * `dst` - Pointer to output tensor data
615 /// * `scalar` - Scalar value to multiply by
616 ///
617 /// # Safety
618 ///
619 /// Requires valid pointers with sufficient memory for the tensor size.
620 /// All pointers must point to valid tensor data.
621 ///
622 /// # Performance Characteristics
623 ///
624 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
625 /// - **Memory Access**: Linear access patterns for cache efficiency
626 /// - **Fallback**: Handles remaining elements with scalar operations
627 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
628 /// - **Mathematical Accuracy**: High-precision scalar multiplication
629 ///
630 /// # Implementation Details
631 ///
632 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
633 /// Processes elements in groups of 4 to improve instruction-level parallelism
634 /// and reduce loop overhead.
635 #[inline]
636 unsafe fn mul_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
637 let size = self.size();
638
639 // Use unrolled loops for better instruction throughput
640 let unroll_count = size / 4;
641 let mut i = 0;
642
643 // Process 4 elements at a time for better cache utilization
644 while i < unroll_count {
645 let idx = i * 4;
646 dst.add(idx).write(src.add(idx).read() * scalar);
647 dst.add(idx + 1).write(src.add(idx + 1).read() * scalar);
648 dst.add(idx + 2).write(src.add(idx + 2).read() * scalar);
649 dst.add(idx + 3).write(src.add(idx + 3).read() * scalar);
650 i += 1;
651 }
652
653 // Handle remaining elements
654 for j in (unroll_count * 4)..size {
655 dst.add(j).write(src.add(j).read() * scalar);
656 }
657 }
658}
659
// Unit tests covering the raw kernels (`*_optimized`), gradtrack integration,
// and broadcasting gradient reduction for multiplication.
#[cfg(test)]
mod tests {
    use super::*;

    // Element-wise kernel on two same-shape tensors (bypasses gradtrack).
    #[test]
    fn test_tensor_multiplication() {
        let a = Tensor::ones(vec![2, 3]);
        let mut b = Tensor::ones(vec![2, 3]);
        b.fill(2.0);
        let result = a.mul_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 * 2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

    // Scalar-broadcast kernel (bypasses gradtrack).
    #[test]
    fn test_scalar_multiplication() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.mul_scalar_optimized(3.0);

        assert_eq!(result.shape().dims, vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 3.0 (1.0 * 3.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
            }
        }
    }

    // Multiplying by zero annihilates every element.
    #[test]
    fn test_zero_multiplication() {
        let tensor = Tensor::ones(vec![2, 3]);
        let result = tensor.mul_scalar_optimized(0.0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 0.0 (1.0 * 0.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
            }
        }
    }

    // Negative scalars flip sign as expected.
    #[test]
    fn test_negative_multiplication() {
        let tensor = Tensor::ones(vec![2, 3]);
        let result = tensor.mul_scalar_optimized(-2.0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are -2.0 (1.0 * -2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
            }
        }
    }

    // The optimized kernel asserts on shape mismatch; message must match
    // the assert in `mul_tensor_optimized`.
    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.mul_tensor_optimized(&b);
    }

    // Zero operand and negative-value source tensors.
    #[test]
    fn test_edge_cases() {
        // Test with zero tensor
        let zero_tensor = Tensor::zeros(vec![2, 3]);
        let other = Tensor::ones(vec![2, 3]);
        let result = zero_tensor.mul_tensor_optimized(&other);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 0.0 (0.0 * 1.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
            }
        }

        // Test with negative values
        let mut neg_tensor = Tensor::ones(vec![2, 3]);
        neg_tensor.fill(-1.0);
        let result = neg_tensor.mul_scalar_optimized(2.0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are -2.0 (-1.0 * 2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
            }
        }
    }

    // 10k elements exercises the SIMD main loop plus remainder handling.
    #[test]
    fn test_large_tensor_multiplication() {
        let a = Tensor::ones(vec![100, 100]);
        let mut b = Tensor::ones(vec![100, 100]);
        b.fill(1.5);
        let result = a.mul_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![100, 100]);
        assert_eq!(result.size(), 10000);

        // Check that all values are 1.5 (1.0 * 1.5)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
            }
        }
    }

    // Forward values plus backward gradients for both mul_scalar and mul_tensor.
    #[test]
    fn test_multiplication_with_gradtrack() {
        // Test scalar multiplication with gradtrack
        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut result = a.mul_scalar(3.0);

        // Check result values: 1.0 * 3.0 = 3.0
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
            }
        }

        result.backward(None);

        // Check gradient: d/dx(3x) = 3
        if let Some(grad) = a.grad_by_value() {
            unsafe {
                for i in 0..grad.size() {
                    let val = grad.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient 3.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient computed for scalar multiplication!");
        }

        // Test tensor multiplication with gradtrack
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(3.0);
        let b = b.with_requires_grad();

        let mut result = a.mul_tensor(&b);

        // Check result values: 1.0 * 3.0 = 3.0
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
            }
        }

        result.backward(None);

        // Check gradients: ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
        // For a = 1.0, b = 3.0: ∂(a*b)/∂a = 3.0, ∂(a*b)/∂b = 1.0
        if let Some(grad_a) = a.grad_by_value() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient A = 3.0 (∂(a*b)/∂a = b), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for tensor multiplication!");
        }

        if let Some(grad_b) = b.grad_by_value() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - 1.0).abs() < 1e-6,
                        "Expected gradient B = 1.0 (∂(a*b)/∂b = a), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for tensor multiplication!");
        }
    }

    // Gradient flow through a chain of mul/add/sub ops.
    #[test]
    fn test_mixed_mul_add_operations_with_gradtrack() {
        // Test complex computation: (a * 2) + (b * 3) - 1
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(3.0);
        let b = b.with_requires_grad();

        let scalar1 = 2.0;
        let scalar2 = 3.0;

        // Compute: (a * scalar1) + (b * scalar2) - 1
        let mul_a = a.mul_scalar(scalar1); // a * 2
        let mul_b = b.mul_scalar(scalar2); // b * 3
        let add_result = mul_a.add_tensor(&mul_b); // (a * 2) + (b * 3)
        let mut final_result = add_result.sub_scalar(1.0); // (a * 2) + (b * 3) - 1

        // Check result values: (1*2) + (3*3) - 1 = 2 + 9 - 1 = 10
        unsafe {
            for i in 0..final_result.size() {
                let val = final_result.as_ptr().add(i).read();
                assert!((val - 10.0).abs() < 1e-6, "Expected 10.0, got {}", val);
            }
        }

        final_result.backward(None);

        // Check gradients: d/dx((2x + 3y - 1)) = 2, d/dy((2x + 3y - 1)) = 3
        if let Some(grad_a) = a.grad_by_value() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 2.0).abs() < 1e-6,
                        "Expected gradient A = 2.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for mixed operations!");
        }

        if let Some(grad_b) = b.grad_by_value() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient B = 3.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for mixed operations!");
        }
    }

    // Broadcast mul backward: gradients must be reduced to operands' shapes.
    #[test]
    fn test_mul_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] * [1, 3] -> [2, 3]
        // For multiplication: d/da (a * b) = b, d/db (a * b) = a
        // So grad_a = grad_output * b, grad_b = grad_output * a (then reduced)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.mul_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims,
            b.shape().dims
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims,
            grad_b.shape().dims
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims,
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        assert_eq!(
            grad_b.shape().dims,
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // For multiplication gradients:
        // grad_a = grad_output * b = 1.0 * [2.0, 3.0, 4.0] broadcasted to [2, 3]
        let expected_grad_a = [2.0, 3.0, 4.0, 2.0, 3.0, 4.0];
        for (i, val) in expected_grad_a.iter().enumerate().take(grad_a.size()) {
            let actual = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_a[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }

        // grad_b = grad_output * a summed over broadcasted dimension
        // a = [1,2,3,4,5,6] -> grad_b should be [1+4, 2+5, 3+6] = [5, 7, 9]
        let expected_grad_b = [5.0, 7.0, 9.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }
}