train_station/tensor/ops/mul.rs
1//! Multiplication operations for tensors
2//!
3//! Provides element-wise multiplication functions following PyTorch conventions with
4//! comprehensive automatic differentiation support and SIMD-optimized computation.
5//!
6//! # Key Features
7//!
8//! - **Element-wise Multiplication**: `mul_tensor()` - Element-wise multiplication with another tensor (PyTorch `mul()` equivalent)
9//! - **Scalar Multiplication**: `mul_scalar()` - Broadcast multiplication with a scalar value
10//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
11//! - **SIMD Optimization**: AVX2-optimized implementation for maximum performance
12//! - **Broadcasting Support**: NumPy-style broadcasting for compatible shapes
13//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
14//! - **Mathematical Accuracy**: High-precision multiplication computation
15//!
16//! # Mathematical Properties
17//!
18//! The multiplication operations have the following properties:
19//! - **Commutative**: a * b = b * a
20//! - **Associative**: (a * b) * c = a * (b * c)
21//! - **Distributive**: a * (b + c) = a * b + a * c
22//! - **Identity**: a * 1 = a
23//! - **Zero**: a * 0 = 0
//! - **Gradient**: d/da(a * b) = b, d/db(a * b) = a
25//!
26//! # Performance Characteristics
27//!
28//! - **SIMD Optimization**: AVX2-optimized with 32-element blocks and 4x unrolling
29//! - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Branch Prediction**: Optimized conditional logic for modern CPUs
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37// SIMD optimizations for performance-critical operations
38#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41// Note: removed manual prefetching; linear access + hardware prefetch is sufficient
42
43impl Tensor {
44 /// Element-wise multiplication with another tensor with broadcasting support.
45 ///
46 /// Performs element-wise multiplication with automatic broadcasting: `output[i] = self[i] * other[i]`
47 ///
48 /// Broadcasting enables multiplication between tensors of different but compatible shapes.
49 /// Compatible shapes follow NumPy broadcasting rules:
50 /// - Dimensions are aligned from the rightmost dimension
51 /// - Dimensions are compatible if they are equal, or one of them is 1
52 /// - Missing dimensions are treated as 1
53 ///
54 /// # Arguments
55 /// * `other` - Tensor to multiply. Shapes must be broadcast-compatible.
56 ///
57 /// # Returns
58 /// A new tensor containing the element-wise product with broadcast result shape
59 ///
60 /// # Examples
61 ///
62 /// ## Same Shape Multiplication
63 ///
64 /// ```
65 /// use train_station::Tensor;
66 ///
67 /// let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
68 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
69 /// let c = a.mul_tensor(&b);
70 /// assert_eq!(c.shape().dims, vec![3]);
71 /// assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
72 /// assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
73 /// assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
74 /// ```
75 ///
76 /// ## Broadcasting Multiplication
77 ///
78 /// ```
79 /// use train_station::Tensor;
80 ///
81 /// let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
82 /// let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
83 /// let c = a.mul_tensor(&b);
84 /// assert_eq!(c.shape().dims, vec![2, 3]);
85 /// // Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
86 /// assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
87 /// assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
88 /// assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
89 /// ```
90 ///
91 /// # Panics
92 /// Panics if tensor shapes are not broadcast-compatible
93 #[inline]
94 pub fn mul_tensor(&self, other: &Tensor) -> Tensor {
95 // Check if shapes are identical for fast path
96 if self.shape().dims == other.shape().dims {
97 return self.mul_tensor_same_shape(other);
98 }
99
100 // Use broadcasting for different shapes
101 let (broadcast_self, broadcast_other, _result_shape) =
102 self.broadcast_with(other).unwrap_or_else(|e| {
103 panic!(
104 "Cannot broadcast tensor shapes {:?} and {:?}: {}",
105 self.shape().dims,
106 other.shape().dims,
107 e
108 );
109 });
110
111 // Perform element-wise multiplication on broadcasted tensors
112 let mut result = broadcast_self.mul_tensor_optimized(&broadcast_other);
113
114 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
115 result.set_requires_grad_internal(true);
116 let operands = vec![self.clone(), other.clone()];
117 let grad_fn = GradFn::Mul {
118 is_tensor_mul: true,
119 scalar: None,
120 operands: Some(operands),
121 original_shapes: Some((self.shape().dims.clone(), other.shape().dims.clone())),
122 };
123 result.set_grad_fn(grad_fn.clone());
124
125 let mut input_ids = Vec::with_capacity(2);
126 if self.requires_grad() {
127 input_ids.push(self.id());
128 }
129 if other.requires_grad() {
130 input_ids.push(other.id());
131 }
132 GradEngine::register_operation(result.id(), input_ids, grad_fn);
133 }
134
135 result
136 }
137
138 /// Element-wise multiplication for tensors with identical shapes (fast path)
139 ///
140 /// This is an optimized path for tensors that already have the same shape,
141 /// avoiding the overhead of broadcasting computation. This method provides
142 /// better performance when tensors have matching dimensions.
143 ///
144 /// # Arguments
145 ///
146 /// * `other` - Tensor to multiply, must have the same shape as self
147 ///
148 /// # Returns
149 ///
150 /// A new tensor containing the element-wise product
151 ///
152 /// # Performance Characteristics
153 ///
154 /// - **Fast Path**: Avoids broadcasting overhead for identical shapes
155 /// - **SIMD Optimization**: Uses optimized multiplication when available
156 /// - **Memory Efficiency**: Direct element-wise computation without shape conversion
157 /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
158 ///
159 /// # Panics
160 ///
161 /// Panics if tensor shapes do not match
162 ///
163 /// # Implementation Details
164 ///
165 /// This method is called internally by `mul_tensor()` when shapes are identical.
166 /// It provides a performance optimization by skipping the broadcasting logic
167 /// and directly calling the optimized multiplication implementation.
168 #[inline]
169 fn mul_tensor_same_shape(&self, other: &Tensor) -> Tensor {
170 assert_eq!(
171 self.shape(),
172 other.shape(),
173 "Tensor shapes must match for same-shape multiplication"
174 );
175 let mut result = self.mul_tensor_optimized(other);
176
177 if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
178 result.set_requires_grad_internal(true);
179 let operands = vec![self.clone(), other.clone()];
180 let grad_fn = GradFn::Mul {
181 is_tensor_mul: true,
182 scalar: None,
183 operands: Some(operands),
184 original_shapes: None, // Same shape case
185 };
186 result.set_grad_fn(grad_fn.clone());
187
188 let mut input_ids = Vec::with_capacity(2);
189 if self.requires_grad() {
190 input_ids.push(self.id());
191 }
192 if other.requires_grad() {
193 input_ids.push(other.id());
194 }
195 GradEngine::register_operation(result.id(), input_ids, grad_fn);
196 }
197
198 result
199 }
200
201 /// Broadcast multiplication with a scalar value.
202 ///
203 /// Multiplies every element by the scalar: `output[i] = self[i] * scalar`
204 ///
205 /// # Arguments
206 /// * `scalar` - Value to multiply with each element
207 ///
208 /// # Returns
209 /// A new tensor with each element multiplied by the scalar
210 ///
211 /// # Examples
212 ///
213 /// ## Basic Scalar Multiplication
214 ///
215 /// ```
216 /// use train_station::Tensor;
217 ///
218 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
219 /// let b = a.mul_scalar(10.0);
220 /// assert_eq!(b.shape().dims, vec![3]);
221 /// assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
222 /// assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
223 /// assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
224 /// ```
225 ///
226 /// ## Negative Scalar Multiplication
227 ///
228 /// ```
229 /// use train_station::Tensor;
230 ///
231 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
232 /// let b = a.mul_scalar(-2.0);
233 /// assert_eq!(b.shape().dims, vec![3]);
234 /// assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
235 /// assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
236 /// assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
237 /// ```
238 #[inline]
239 pub fn mul_scalar(&self, scalar: f32) -> Tensor {
240 let mut result = self.mul_scalar_optimized(scalar);
241
242 if self.requires_grad() && is_grad_enabled() {
243 result.set_requires_grad_internal(true);
244 let grad_fn = GradFn::Mul {
245 is_tensor_mul: false,
246 scalar: Some(scalar),
247 operands: None,
248 original_shapes: None, // Scalar case
249 };
250 result.set_grad_fn(grad_fn.clone());
251 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
252 }
253
254 result
255 }
256 /// Optimized tensor multiplication using SIMD when available
257 ///
258 /// Performs element-wise multiplication using SIMD optimization when available
259 /// and falling back to optimized scalar computation. This is the core implementation
260 /// used by `mul_tensor()`.
261 ///
262 /// # Arguments
263 ///
264 /// * `other` - The tensor to multiply with
265 ///
266 /// # Returns
267 ///
268 /// A new tensor with the result of the multiplication
269 ///
270 /// # Performance Characteristics
271 ///
272 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
273 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
274 /// - **Cache-friendly**: Linear memory access patterns
275 /// - **Memory Safety**: Ensures contiguous memory layout for correctness
276 /// - **Zero-sized Handling**: Fast return for empty tensors
277 ///
278 /// # Safety
279 ///
280 /// This operation assumes the tensors have the same shape. The method ensures
281 /// contiguous memory layout for both input tensors to guarantee correctness
282 /// with view tensors.
283 ///
284 /// # Implementation Details
285 ///
286 /// Automatically selects between SIMD and scalar implementations based on hardware
287 /// capabilities. Ensures both input tensors are contiguous for memory safety.
288 /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
289 #[inline]
290 pub(crate) fn mul_tensor_optimized(&self, other: &Tensor) -> Tensor {
291 assert_eq!(self.shape(), other.shape(), "Tensor shapes must match");
292 // Ensure contiguous sources for correctness with view tensors
293 let a_src = if self.is_contiguous() {
294 self.clone()
295 } else {
296 self.contiguous()
297 };
298 let b_src = if other.is_contiguous() {
299 other.clone()
300 } else {
301 other.contiguous()
302 };
303
304 let mut output = Tensor::new(self.shape().dims.clone());
305
306 unsafe {
307 let a = a_src.as_ptr();
308 let b = b_src.as_ptr();
309 let dst = output.as_mut_ptr();
310
311 #[cfg(target_arch = "x86_64")]
312 {
313 // Use SIMD for better performance when available
314 if is_x86_feature_detected!("avx2") {
315 self.mul_tensors_simd_avx2_optimized(a, b, dst);
316 return output;
317 }
318 }
319
320 // Fallback to scalar operations with better cache usage
321 self.mul_tensors_scalar_optimized(a, b, dst);
322 }
323
324 output
325 }
326
    /// AVX2-optimized tensor multiplication implementation
    ///
    /// Performs element-wise multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `a` - Pointer to first tensor data
    /// * `b` - Pointer to second tensor data
    /// * `dst` - Pointer to output tensor data
    ///
    /// # Safety
    ///
    /// All three pointers must be valid for at least `self.size()` f32 elements,
    /// and the CPU must support AVX2 (the caller checks via `is_x86_feature_detected!`).
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// Uses AVX2 vector instructions to process 8 elements simultaneously.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_tensors_simd_avx2_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration.
            // Unaligned loads/stores are used since buffers are not guaranteed
            // to be 32-byte aligned.
            let a_vec1 = _mm256_loadu_ps(a.add(offset));
            let b_vec1 = _mm256_loadu_ps(b.add(offset));
            let mul_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let a_vec2 = _mm256_loadu_ps(a.add(offset + 8));
            let b_vec2 = _mm256_loadu_ps(b.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let a_vec3 = _mm256_loadu_ps(a.add(offset + 16));
            let b_vec3 = _mm256_loadu_ps(b.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let a_vec4 = _mm256_loadu_ps(a.add(offset + 24));
            let b_vec4 = _mm256_loadu_ps(b.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let a_vec = _mm256_loadu_ps(a.add(offset));
            let b_vec = _mm256_loadu_ps(b.add(offset));
            let mul_vec = _mm256_mul_ps(a_vec, b_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // Scalar 4x-unrolled tail for the last 4-7 elements
        while offset + 4 <= size {
            *dst.add(offset) = *a.add(offset) * *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) * *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) * *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) * *b.add(offset + 3);
            offset += 4;
        }
        // Final 0-3 leftover elements
        for i in offset..size {
            *dst.add(i) = *a.add(i) * *b.add(i);
        }
    }
410
411 /// Optimized scalar tensor multiplication fallback
412 ///
413 /// Performs element-wise multiplication using optimized scalar operations with
414 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
415 ///
416 /// # Arguments
417 ///
418 /// * `a` - Pointer to first tensor data
419 /// * `b` - Pointer to second tensor data
420 /// * `dst` - Pointer to output tensor data
421 ///
422 /// # Safety
423 ///
424 /// Requires valid pointers with sufficient memory for the tensor size.
425 /// All pointers must point to valid tensor data.
426 ///
427 /// # Performance Characteristics
428 ///
429 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
430 /// - **Memory Access**: Linear access patterns for cache efficiency
431 /// - **Fallback**: Handles remaining elements with scalar operations
432 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
433 /// - **Mathematical Accuracy**: High-precision scalar multiplication
434 ///
435 /// # Implementation Details
436 ///
437 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
438 /// Processes elements in groups of 4 to improve instruction-level parallelism
439 /// and reduce loop overhead.
440 #[inline]
441 unsafe fn mul_tensors_scalar_optimized(&self, a: *const f32, b: *const f32, dst: *mut f32) {
442 let size = self.size();
443
444 // Use unrolled loops for better instruction throughput
445 let unroll_count = size / 4;
446 let mut i = 0;
447
448 // Process 4 elements at a time for better cache utilization
449 while i < unroll_count {
450 let idx = i * 4;
451 dst.add(idx).write(a.add(idx).read() * b.add(idx).read());
452 dst.add(idx + 1)
453 .write(a.add(idx + 1).read() * b.add(idx + 1).read());
454 dst.add(idx + 2)
455 .write(a.add(idx + 2).read() * b.add(idx + 2).read());
456 dst.add(idx + 3)
457 .write(a.add(idx + 3).read() * b.add(idx + 3).read());
458 i += 1;
459 }
460
461 // Handle remaining elements
462 for j in (unroll_count * 4)..size {
463 dst.add(j).write(a.add(j).read() * b.add(j).read());
464 }
465 }
466
467 /// Optimized scalar multiplication using SIMD when available
468 ///
469 /// Performs element-wise scalar multiplication using SIMD optimization when available
470 /// and falling back to optimized scalar computation. This is the core implementation
471 /// used by `mul_scalar()`.
472 ///
473 /// # Arguments
474 ///
475 /// * `scalar` - The scalar value to multiply by
476 ///
477 /// # Returns
478 ///
479 /// A new tensor with the result of the multiplication
480 ///
481 /// # Performance Characteristics
482 ///
483 /// - **SIMD Optimization**: AVX2-optimized with 32-element blocks when available
484 /// - **Scalar Fallback**: 4x unrolled scalar implementation for non-SIMD hardware
485 /// - **Cache-friendly**: Linear memory access patterns
486 /// - **Memory Safety**: Ensures contiguous memory layout for correctness
487 /// - **Zero-sized Handling**: Fast return for empty tensors
488 ///
489 /// # Implementation Details
490 ///
491 /// Automatically selects between SIMD and scalar implementations based on hardware
492 /// capabilities. Ensures the input tensor is contiguous for memory safety.
493 /// SIMD implementation processes 32 elements per iteration with 4x unrolling.
494 #[inline]
495 pub(crate) fn mul_scalar_optimized(&self, scalar: f32) -> Tensor {
496 // Ensure contiguous source for correctness with view tensors
497 let src_self = if self.is_contiguous() {
498 self.clone()
499 } else {
500 self.contiguous()
501 };
502 let mut output = Tensor::new(self.shape().dims.clone());
503
504 unsafe {
505 let src = src_self.as_ptr();
506 let dst = output.as_mut_ptr();
507
508 #[cfg(target_arch = "x86_64")]
509 {
510 // Use SIMD for better performance when available
511 if is_x86_feature_detected!("avx2") {
512 self.mul_scalar_simd_avx2_optimized(src, dst, scalar);
513 return output;
514 }
515 }
516
517 // Fallback to scalar operations with better cache usage
518 self.mul_scalar_fallback_optimized(src, dst, scalar);
519 }
520
521 output
522 }
523
    /// AVX2-optimized scalar multiplication implementation
    ///
    /// Performs element-wise scalar multiplication using AVX2 SIMD instructions for maximum
    /// performance on x86_64 architectures with AVX2 support.
    ///
    /// # Arguments
    ///
    /// * `src` - Pointer to source tensor data
    /// * `dst` - Pointer to output tensor data
    /// * `scalar` - Scalar value to multiply by
    ///
    /// # Safety
    ///
    /// Both pointers must be valid for at least `self.size()` f32 elements,
    /// and the CPU must support AVX2 (the caller checks via `is_x86_feature_detected!`).
    ///
    /// # Performance Characteristics
    ///
    /// - **SIMD Processing**: 32 elements per iteration with 4x unrolling
    /// - **Memory Access**: Linear access patterns for cache efficiency
    /// - **Vector Operations**: Uses AVX2 multiplication instructions
    /// - **Fallback**: Handles remaining elements with scalar operations
    /// - **Hardware Requirements**: Requires x86_64 with AVX2 support
    ///
    /// # Implementation Details
    ///
    /// The scalar is splatted into one AVX2 register and reused across all lanes.
    /// Implements 4x unrolling for optimal instruction throughput and cache utilization.
    /// Processes remaining elements with scalar operations for complete coverage.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul_scalar_simd_avx2_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
        let size = self.size();
        let simd_count = size / 32; // Process 32 elements per iteration (4x unroll)
        // Broadcast the scalar into all 8 lanes once, outside the loop.
        let scalar_vec = _mm256_set1_ps(scalar);
        let mut offset = 0;

        // Unrolled SIMD loop for throughput
        for _ in 0..simd_count {
            // Process 4 AVX2 vectors (32 elements) per iteration.
            // Unaligned loads/stores are used since buffers are not guaranteed
            // to be 32-byte aligned.
            let src_vec1 = _mm256_loadu_ps(src.add(offset));
            let mul_vec1 = _mm256_mul_ps(src_vec1, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec1);

            let src_vec2 = _mm256_loadu_ps(src.add(offset + 8));
            let mul_vec2 = _mm256_mul_ps(src_vec2, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 8), mul_vec2);

            let src_vec3 = _mm256_loadu_ps(src.add(offset + 16));
            let mul_vec3 = _mm256_mul_ps(src_vec3, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 16), mul_vec3);

            let src_vec4 = _mm256_loadu_ps(src.add(offset + 24));
            let mul_vec4 = _mm256_mul_ps(src_vec4, scalar_vec);
            _mm256_storeu_ps(dst.add(offset + 24), mul_vec4);

            offset += 32;
        }

        // Handle remaining 8-element blocks, then tail
        let remaining_full_blocks = (size - offset) / 8;
        for _ in 0..remaining_full_blocks {
            let src_vec = _mm256_loadu_ps(src.add(offset));
            let mul_vec = _mm256_mul_ps(src_vec, scalar_vec);
            _mm256_storeu_ps(dst.add(offset), mul_vec);
            offset += 8;
        }
        // Scalar 4x-unrolled tail for the last 4-7 elements
        while offset + 4 <= size {
            *dst.add(offset) = *src.add(offset) * scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) * scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) * scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) * scalar;
            offset += 4;
        }
        // Final 0-3 leftover elements
        for i in offset..size {
            *dst.add(i) = *src.add(i) * scalar;
        }
    }
603
604 /// Optimized scalar multiplication fallback
605 ///
606 /// Performs element-wise scalar multiplication using optimized scalar operations with
607 /// 4x unrolling for better instruction-level parallelism and cache efficiency.
608 ///
609 /// # Arguments
610 ///
611 /// * `src` - Pointer to source tensor data
612 /// * `dst` - Pointer to output tensor data
613 /// * `scalar` - Scalar value to multiply by
614 ///
615 /// # Safety
616 ///
617 /// Requires valid pointers with sufficient memory for the tensor size.
618 /// All pointers must point to valid tensor data.
619 ///
620 /// # Performance Characteristics
621 ///
622 /// - **Unrolling**: 4x unrolling for instruction-level parallelism
623 /// - **Memory Access**: Linear access patterns for cache efficiency
624 /// - **Fallback**: Handles remaining elements with scalar operations
625 /// - **Cache Optimization**: Optimized for modern CPU cache hierarchies
626 /// - **Mathematical Accuracy**: High-precision scalar multiplication
627 ///
628 /// # Implementation Details
629 ///
630 /// Uses 4x unrolled scalar operations for optimal performance on non-SIMD hardware.
631 /// Processes elements in groups of 4 to improve instruction-level parallelism
632 /// and reduce loop overhead.
633 #[inline]
634 unsafe fn mul_scalar_fallback_optimized(&self, src: *const f32, dst: *mut f32, scalar: f32) {
635 let size = self.size();
636
637 // Use unrolled loops for better instruction throughput
638 let unroll_count = size / 4;
639 let mut i = 0;
640
641 // Process 4 elements at a time for better cache utilization
642 while i < unroll_count {
643 let idx = i * 4;
644 dst.add(idx).write(src.add(idx).read() * scalar);
645 dst.add(idx + 1).write(src.add(idx + 1).read() * scalar);
646 dst.add(idx + 2).write(src.add(idx + 2).read() * scalar);
647 dst.add(idx + 3).write(src.add(idx + 3).read() * scalar);
648 i += 1;
649 }
650
651 // Handle remaining elements
652 for j in (unroll_count * 4)..size {
653 dst.add(j).write(src.add(j).read() * scalar);
654 }
655 }
656}
657
// Unit tests for the multiplication kernels (`*_optimized`, no gradtrack) and the
// public gradtrack-aware entry points (`mul_tensor`, `mul_scalar`), including
// broadcasting and gradient-shape reduction.
#[cfg(test)]
mod tests {
    use super::*;

    // Raw kernel: element-wise product of same-shape tensors.
    #[test]
    fn test_tensor_multiplication() {
        let a = Tensor::ones(vec![2, 3]);
        let mut b = Tensor::ones(vec![2, 3]);
        b.fill(2.0);
        let result = a.mul_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 2.0 (1.0 * 2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

    // Raw kernel: scalar broadcast multiplication.
    #[test]
    fn test_scalar_multiplication() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.mul_scalar_optimized(3.0);

        assert_eq!(result.shape().dims, vec![2, 2]);
        assert_eq!(result.size(), 4);

        // Check that all values are 3.0 (1.0 * 3.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 3.0).abs() < 1e-6);
            }
        }
    }

    // Multiplying by zero annihilates every element.
    #[test]
    fn test_zero_multiplication() {
        let tensor = Tensor::ones(vec![2, 3]);
        let result = tensor.mul_scalar_optimized(0.0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 0.0 (1.0 * 0.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
            }
        }
    }

    // Negative scalars flip the sign of every element.
    #[test]
    fn test_negative_multiplication() {
        let tensor = Tensor::ones(vec![2, 3]);
        let result = tensor.mul_scalar_optimized(-2.0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are -2.0 (1.0 * -2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
            }
        }
    }

    // The raw kernel rejects mismatched shapes (no broadcasting at this layer).
    #[test]
    #[should_panic(expected = "Tensor shapes must match")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.mul_tensor_optimized(&b);
    }

    // Zero operands and negative values through both kernels.
    #[test]
    fn test_edge_cases() {
        // Test with zero tensor
        let zero_tensor = Tensor::zeros(vec![2, 3]);
        let other = Tensor::ones(vec![2, 3]);
        let result = zero_tensor.mul_tensor_optimized(&other);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are 0.0 (0.0 * 1.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 0.0).abs() < 1e-6);
            }
        }

        // Test with negative values
        let mut neg_tensor = Tensor::ones(vec![2, 3]);
        neg_tensor.fill(-1.0);
        let result = neg_tensor.mul_scalar_optimized(2.0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.size(), 6);

        // Check that all values are -2.0 (-1.0 * 2.0)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - (-2.0)).abs() < 1e-6);
            }
        }
    }

    // 10k elements exercises the unrolled SIMD path and all tail cases.
    #[test]
    fn test_large_tensor_multiplication() {
        let a = Tensor::ones(vec![100, 100]);
        let mut b = Tensor::ones(vec![100, 100]);
        b.fill(1.5);
        let result = a.mul_tensor_optimized(&b);

        assert_eq!(result.shape().dims, vec![100, 100]);
        assert_eq!(result.size(), 10000);

        // Check that all values are 1.5 (1.0 * 1.5)
        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 1.5).abs() < 1e-6);
            }
        }
    }

    // Forward values and backward gradients for both scalar and tensor paths.
    #[test]
    fn test_multiplication_with_gradtrack() {
        // Test scalar multiplication with gradtrack
        let a = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut result = a.mul_scalar(3.0);

        // Check result values: 1.0 * 3.0 = 3.0
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
            }
        }

        result.backward(None);

        // Check gradient: d/dx(3x) = 3
        if let Some(grad) = a.grad_by_value() {
            unsafe {
                for i in 0..grad.size() {
                    let val = grad.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient 3.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient computed for scalar multiplication!");
        }

        // Test tensor multiplication with gradtrack
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(3.0);
        let b = b.with_requires_grad();

        let mut result = a.mul_tensor(&b);

        // Check result values: 1.0 * 3.0 = 3.0
        unsafe {
            for i in 0..result.size() {
                let val = result.as_ptr().add(i).read();
                assert!((val - 3.0).abs() < 1e-6, "Expected 3.0, got {}", val);
            }
        }

        result.backward(None);

        // Check gradients: ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
        // For a = 1.0, b = 3.0: ∂(a*b)/∂a = 3.0, ∂(a*b)/∂b = 1.0
        if let Some(grad_a) = a.grad_by_value() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient A = 3.0 (∂(a*b)/∂a = b), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for tensor multiplication!");
        }

        if let Some(grad_b) = b.grad_by_value() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - 1.0).abs() < 1e-6,
                        "Expected gradient B = 1.0 (∂(a*b)/∂b = a), got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for tensor multiplication!");
        }
    }

    // Gradients flow correctly through a chain of mul/add/sub operations.
    #[test]
    fn test_mixed_mul_add_operations_with_gradtrack() {
        // Test complex computation: (a * 2) + (b * 3) - 1
        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut b = Tensor::ones(vec![2, 2]);
        b.fill(3.0);
        let b = b.with_requires_grad();

        let scalar1 = 2.0;
        let scalar2 = 3.0;

        // Compute: (a * scalar1) + (b * scalar2) - 1
        let mul_a = a.mul_scalar(scalar1); // a * 2
        let mul_b = b.mul_scalar(scalar2); // b * 3
        let add_result = mul_a.add_tensor(&mul_b); // (a * 2) + (b * 3)
        let mut final_result = add_result.sub_scalar(1.0); // (a * 2) + (b * 3) - 1

        // Check result values: (1*2) + (3*3) - 1 = 2 + 9 - 1 = 10
        unsafe {
            for i in 0..final_result.size() {
                let val = final_result.as_ptr().add(i).read();
                assert!((val - 10.0).abs() < 1e-6, "Expected 10.0, got {}", val);
            }
        }

        final_result.backward(None);

        // Check gradients: d/dx((2x + 3y - 1)) = 2, d/dy((2x + 3y - 1)) = 3
        if let Some(grad_a) = a.grad_by_value() {
            unsafe {
                for i in 0..grad_a.size() {
                    let val = grad_a.as_ptr().add(i).read();
                    assert!(
                        (val - 2.0).abs() < 1e-6,
                        "Expected gradient A = 2.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient A computed for mixed operations!");
        }

        if let Some(grad_b) = b.grad_by_value() {
            unsafe {
                for i in 0..grad_b.size() {
                    let val = grad_b.as_ptr().add(i).read();
                    assert!(
                        (val - 3.0).abs() < 1e-6,
                        "Expected gradient B = 3.0, got {}",
                        val
                    );
                }
            }
        } else {
            panic!("No gradient B computed for mixed operations!");
        }
    }

    // Broadcast mul: gradients must be reduced back to each operand's original shape.
    #[test]
    fn test_mul_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        // Test case: [2, 3] * [1, 3] -> [2, 3]
        // For multiplication: d/da (a * b) = b, d/db (a * b) = a
        // So grad_a = grad_output * b, grad_b = grad_output * a (then reduced)

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.mul_tensor(&b);
        assert_eq!(result.shape().dims, vec![2, 3]);

        // Set upstream gradient as ones
        result.backward(None);

        // Check gradients
        let grad_a = a.grad_by_value().expect("grad_a should exist");
        let grad_b = b.grad_by_value().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims,
            b.shape().dims
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims,
            grad_b.shape().dims
        );

        // grad_a should have same shape as a: [2, 3]
        assert_eq!(
            grad_a.shape().dims,
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        // grad_b should have same shape as b: [1, 3]
        assert_eq!(
            grad_b.shape().dims,
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        // For multiplication gradients:
        // grad_a = grad_output * b = 1.0 * [2.0, 3.0, 4.0] broadcasted to [2, 3]
        let expected_grad_a = [2.0, 3.0, 4.0, 2.0, 3.0, 4.0];
        for (i, val) in expected_grad_a.iter().enumerate().take(grad_a.size()) {
            let actual = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_a[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }

        // grad_b = grad_output * a summed over broadcasted dimension
        // a = [1,2,3,4,5,6] -> grad_b should be [1+4, 2+5, 3+6] = [5, 7, 9]
        let expected_grad_b = [5.0, 7.0, 9.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }
}
1008}