train_station/tensor/ops/matmul/mod.rs
1//! Matrix multiplication operations with optimized kernels
2//!
3//! This module provides a comprehensive matrix multiplication implementation optimized
4//! for single-threaded performance with SIMD acceleration. The implementation supports
5//! all NumPy-style matrix multiplication patterns including 1D/2D/ND tensor operations
6//! with automatic differentiation support.
7//!
8//! # Key Features
9//!
10//! - **SIMD Optimization**: AVX2 implementations for x86_64 architectures
11//! - **Intelligent Dispatch**: Dynamic kernel selection based on matrix dimensions
12//! - **Cache Optimization**: Blocked algorithms for L1/L2 cache efficiency
13//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
14//! - **GradTrack Integration**: Automatic gradient computation for all operations
15//! - **Thread Safety**: All operations are thread-safe and Send + Sync
16//! - **Mathematical Validation**: High-precision equivalence to LibTorch reference
17//!
18//! # Performance Characteristics
19//!
20//! The implementation uses intelligent dispatch to select optimal kernels based on matrix size:
21//! - **Small matrices (16-64 elements)**: Direct computation with minimal overhead
22//! - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
23//! - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
24//! - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
25//! - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
26//! - **Memory Safety**: Safe memory management with `Tensor::new_uninitialized`
27//!
28//! # Organization
29//!
30//! The matmul module is organized into focused submodules:
31//! - **`config`**: Dynamic configuration and kernel selection based on matrix dimensions
32//! - **`kernels`**: SIMD-optimized computational kernels with ML-specific optimizations
33//!
34//! # Supported Operations
35//!
36//! - **1D @ 1D**: Dot product returning scalar tensor
37//! - **1D @ 2D**: Vector-matrix multiplication (v^T * M)
38//! - **2D @ 1D**: Matrix-vector multiplication (M * v)
39//! - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
40//! - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
41//!
42//! # Examples
43//!
44//! ## Basic Matrix Multiplication
45//!
46//! ```
47//! use train_station::Tensor;
48//!
49//! // 2D matrix multiplication
50//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
51//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
52//! let result = a.matmul(&b); // Uses optimized SIMD kernels
53//!
54//! assert_eq!(result.shape().dims, vec![2, 2]);
55//! assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
56//! ```
57//!
58//! ## Vector-Matrix Multiplication
59//!
60//! ```
61//! use train_station::Tensor;
62//!
63//! // Vector-matrix multiplication
64//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
65//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
66//! let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
67//!
68//! assert_eq!(result.shape().dims, vec![2]);
69//! assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
70//! ```
71//!
72//! ## Matrix-Vector Multiplication
73//!
74//! ```
75//! use train_station::Tensor;
76//!
77//! // Matrix-vector multiplication
78//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
79//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
80//! let result = m.matmul(&v); // [2, 2] @ [2] -> [2]
81//!
82//! assert_eq!(result.shape().dims, vec![2]);
83//! assert_eq!(result.data(), &[5.0, 11.0]); // 1*1+2*2, 3*1+4*2
84//! ```
85//!
86//! ## Dot Product
87//!
88//! ```
89//! use train_station::Tensor;
90//!
91//! // 1D dot product
92//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
93//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
94//! let result = a.matmul(&b); // [3] @ [3] -> scalar
95//!
96//! assert_eq!(result.shape().dims, vec![]); // Scalar tensor
97//! assert_eq!(result.data(), &[32.0]); // 1*4 + 2*5 + 3*6
98//! ```
99//!
100//! ## Batched Matrix Multiplication
101//!
102//! ```
103//! use train_station::Tensor;
104//!
105//! // Batched matrix multiplication
106//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
107//! let b = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0], vec![2, 2]).unwrap();
108//! let result = a.matmul(&b); // [2, 2, 2] @ [2, 2] -> [2, 2, 2]
109//!
110//! assert_eq!(result.shape().dims, vec![2, 2, 2]);
111//! ```
112//!
113//! ## Gradient Tracking
114//!
115//! ```
116//! use train_station::Tensor;
117//!
118//! // Matrix multiplication with gradient tracking
119//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
120//! .unwrap()
121//! .with_requires_grad();
122//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
123//! .unwrap()
124//! .with_requires_grad();
125//!
126//! let result = a.matmul(&b);
127//! assert!(result.requires_grad());
128//! assert_eq!(result.shape().dims, vec![2, 2]);
129//! ```
130//!
131//! # Automatic Differentiation
132//!
133//! All operations support automatic differentiation when either operand requires gradients.
134//! Gradient computation follows PyTorch semantics with proper accumulation and chain rule
135//! application through the gradtrack engine.
136//!
137//! # Thread Safety
138//!
139//! All operations are thread-safe and can be used concurrently across multiple threads.
140//! The implementation uses immutable tensor references and thread-local gradtrack state.
141//!
142//! # Mathematical Validation
143//!
144//! All operations are validated against LibTorch reference implementation with high-precision
145//! numerical equivalence (target: 0.00e0 error tolerance, practical: 1e-6 tolerance for
146//! floating-point precision differences).
147
148use crate::tensor::core::Tensor;
149
150pub mod config;
151pub mod kernels;
152
153// Re-export public types
154pub use config::MatmulConfig;
155
156// SIMD optimizations for performance-critical operations
157#[cfg(target_arch = "x86_64")]
158use std::arch::x86_64::*;
159
160impl Tensor {
161 /// Matrix multiplication operation following NumPy semantics
162 ///
163 /// Performs matrix multiplication between this tensor and another tensor with intelligent
164 /// kernel selection based on matrix dimensions and hardware capabilities. The operation
165 /// follows broadcasting rules and supports all common matrix multiplication patterns
166 /// found in machine learning workloads.
167 ///
168 /// # Supported Operations
169 ///
170 /// - **1D @ 1D**: Dot product returning scalar tensor
171 /// - **1D @ 2D**: Vector-matrix multiplication (v^T * M) returning 1D tensor
172 /// - **2D @ 1D**: Matrix-vector multiplication (M * v) returning 1D tensor
173 /// - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
174 /// - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
175 ///
176 /// # Performance Characteristics
177 ///
178 /// The implementation automatically selects optimal kernels based on matrix dimensions:
179 /// - **Small matrices (<64 elements)**: Direct computation with minimal overhead
180 /// - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
181 /// - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
182 /// - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
183 /// - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
184 ///
185 /// # Automatic Differentiation
186 ///
187 /// This operation supports automatic differentiation when either operand requires gradients.
188 /// Gradient computation follows PyTorch semantics with proper accumulation and chain rule
189 /// application through the gradtrack engine. Gradients are computed for both operands when
190 /// `requires_grad` is set.
191 ///
192 /// # Arguments
193 ///
194 /// * `other` - The tensor to multiply with (must have compatible dimensions)
195 ///
196 /// # Returns
197 ///
198 /// A new tensor containing the matrix multiplication result with appropriate shape
199 /// determined by broadcasting rules and matrix multiplication semantics
200 ///
201 /// # Panics
202 ///
203 /// Panics if the inner dimensions don't match for matrix multiplication:
204 /// - For 2D @ 2D: `self.shape()[1] != other.shape()[0]`
205 /// - For 1D @ 2D: `self.shape()[0] != other.shape()[0]`
206 /// - For 2D @ 1D: `self.shape()[1] != other.shape()[0]`
207 /// - For ND @ ND: Last two dimensions must be compatible for matrix multiplication
208 ///
209 /// # Examples
210 ///
211 /// ```
212 /// use train_station::Tensor;
213 ///
214 /// // 2D matrix multiplication
215 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
216 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
217 /// let result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
218 ///
219 /// assert_eq!(result.shape().dims, vec![2, 2]);
220 /// assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
221 /// ```
222 ///
223 /// ## Vector-Matrix Multiplication
224 ///
225 /// ```
226 /// use train_station::Tensor;
227 ///
228 /// let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
229 /// let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
230 /// let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
231 ///
232 /// assert_eq!(result.shape().dims, vec![2]);
233 /// assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
234 /// ```
235 ///
236 /// ## Gradient Tracking
237 ///
238 /// ```
239 /// use train_station::Tensor;
240 ///
241 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
242 /// .unwrap()
243 /// .with_requires_grad();
244 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
245 /// .unwrap()
246 /// .with_requires_grad();
247 ///
248 /// let result = a.matmul(&b);
249 /// assert!(result.requires_grad());
250 /// assert_eq!(result.shape().dims, vec![2, 2]);
251 /// ```
252 ///
253 /// # Thread Safety
254 ///
255 /// This operation is thread-safe and can be used concurrently across multiple threads.
256 /// The implementation uses immutable tensor references and thread-local gradtrack state.
257 ///
258 /// # Memory Safety
259 ///
260 /// The implementation uses `Tensor::new_uninitialized` for performance-critical allocations
261 /// and handles memory initialization safely through the kernel system. All unsafe operations
262 /// are validated through comprehensive FFI testing against LibTorch reference implementation.
263 #[track_caller]
264 pub fn matmul(&self, other: &Tensor) -> Tensor {
265 let self_shape = self.shape();
266 let other_shape = other.shape();
267
268 let mut result = match (self_shape.rank(), other_shape.rank()) {
269 (1, 1) => {
270 // 1D @ 1D: dot product -> scalar
271 self.dot_product_1d(other)
272 }
273 (1, 2) => {
274 // 1D @ 2D: vector-matrix multiplication -> 1D
275 self.vector_matrix_mult(other)
276 }
277 (2, 1) => {
278 // 2D @ 1D: matrix-vector multiplication -> 1D
279 self.matrix_vector_mult(other)
280 }
281 (2, 2) => {
282 // 2D @ 2D: standard matrix multiplication -> 2D
283 self.matrix_matrix_mult(other)
284 }
285 _ => {
286 // ND @ ND: batched matrix multiplication
287 self.batched_matmul(other)
288 }
289 };
290
291 // Set up gradtrack if either operand requires gradients
292 if (self.requires_grad() || other.requires_grad()) && crate::gradtrack::is_grad_enabled() {
293 use crate::gradtrack::{GradEngine, GradFn};
294
295 result.set_requires_grad(true);
296 let grad_fn = GradFn::MatMul {
297 left_operand: Box::new(self.clone()),
298 right_operand: Box::new(other.clone()),
299 requires_grad: (self.requires_grad(), other.requires_grad()),
300 };
301 result.set_grad_fn(grad_fn.clone());
302
303 // Register with gradtrack engine for gradient computation
304 // Always register both operands to maintain consistent indexing
305 let input_ids = vec![self.id(), other.id()];
306
307 GradEngine::register_operation(result.id(), input_ids, grad_fn);
308 }
309
310 result
311 }
312
313 /// Dot product of two 1D tensors (returns scalar)
314 ///
315 /// Computes the dot product between two 1D tensors using SIMD-optimized kernels
316 /// when available. The implementation uses AVX2 instructions for 8x vectorization
317 /// with scalar fallbacks for non-SIMD hardware.
318 ///
319 /// # Arguments
320 ///
321 /// * `other` - The other 1D tensor (must have same length as self)
322 ///
323 /// # Returns
324 ///
325 /// A scalar tensor containing the dot product result
326 ///
327 /// # Implementation Details
328 ///
329 /// - Uses `Tensor::new_uninitialized` for performance-critical allocation
330 /// - SIMD path processes 8 elements at a time with horizontal reduction
331 /// - Scalar path uses 4x unrolled loops for instruction-level parallelism
332 /// - Memory is fully written to avoid uninitialized access
333 fn dot_product_1d(&self, other: &Tensor) -> Tensor {
334 assert_eq!(self.shape().rank(), 1, "First tensor must be 1D");
335 assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D");
336 assert_eq!(
337 self.shape().dims[0],
338 other.shape().dims[0],
339 "Tensors must have same length for dot product"
340 );
341
342 let n = self.shape().dims[0];
343
344 // Ensure both tensors are contiguous for kernel compatibility
345 let self_contiguous = if self.is_contiguous() {
346 self.clone()
347 } else {
348 self.contiguous()
349 };
350 let other_contiguous = if other.is_contiguous() {
351 other.clone()
352 } else {
353 other.contiguous()
354 };
355
356 // Use uninitialized allocation for scalar result - memory will be fully written
357 let mut result = Tensor::new_uninitialized(vec![]); // Scalar tensor
358
359 unsafe {
360 let a_ptr = self_contiguous.as_ptr();
361 let b_ptr = other_contiguous.as_ptr();
362 let result_ptr = result.as_mut_ptr();
363
364 #[cfg(target_arch = "x86_64")]
365 {
366 if is_x86_feature_detected!("avx2") {
367 let dot_product = self.dot_product_simd_avx2(a_ptr, b_ptr, n);
368 *result_ptr = dot_product;
369 } else {
370 let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
371 *result_ptr = dot_product;
372 }
373 }
374
375 #[cfg(not(target_arch = "x86_64"))]
376 {
377 let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
378 *result_ptr = dot_product;
379 }
380 }
381
382 result
383 }
384
385 /// Vector-matrix multiplication: v^T * M
386 ///
387 /// Computes the product of a 1D vector with a 2D matrix, treating the vector
388 /// as a row vector. The implementation uses SIMD-optimized column-wise dot products
389 /// for maximum performance on compatible hardware.
390 ///
391 /// # Arguments
392 ///
393 /// * `other` - The 2D matrix tensor (vector length must match matrix rows)
394 ///
395 /// # Returns
396 ///
397 /// A 1D tensor containing the vector-matrix multiplication result
398 ///
399 /// # Implementation Details
400 ///
401 /// - Computes dot product between vector and each matrix column
402 /// - Uses SIMD kernels for each column when AVX2 is available
403 /// - Scalar fallback processes each column individually
404 /// - Memory layout optimized for column-wise access patterns
405 fn vector_matrix_mult(&self, other: &Tensor) -> Tensor {
406 assert_eq!(self.shape().rank(), 1, "First tensor must be 1D (vector)");
407 assert_eq!(other.shape().rank(), 2, "Second tensor must be 2D (matrix)");
408 assert_eq!(
409 self.shape().dims[0],
410 other.shape().dims[0],
411 "Vector length must match matrix rows"
412 );
413
414 let v_len = self.shape().dims[0];
415 let m_cols = other.shape().dims[1];
416
417 // Ensure both tensors are contiguous for kernel compatibility
418 let self_contiguous = if self.is_contiguous() {
419 self.clone()
420 } else {
421 self.contiguous()
422 };
423 let other_contiguous = if other.is_contiguous() {
424 other.clone()
425 } else {
426 other.contiguous()
427 };
428
429 // Use uninitialized allocation for performance - result will be fully written
430 let mut result = Tensor::new_uninitialized(vec![m_cols]);
431
432 unsafe {
433 let v_ptr = self_contiguous.as_ptr();
434 let m_ptr = other_contiguous.as_ptr();
435 let result_ptr = result.as_mut_ptr();
436
437 #[cfg(target_arch = "x86_64")]
438 {
439 if is_x86_feature_detected!("avx2") {
440 // Use SIMD for each column
441 for col in 0..m_cols {
442 let dot_product =
443 self.vector_matrix_column_simd_avx2(v_ptr, m_ptr, v_len, m_cols, col);
444 *result_ptr.add(col) = dot_product;
445 }
446 } else {
447 // Use scalar for each column
448 for col in 0..m_cols {
449 let dot_product =
450 self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
451 *result_ptr.add(col) = dot_product;
452 }
453 }
454 }
455
456 #[cfg(not(target_arch = "x86_64"))]
457 {
458 // Use scalar for each column
459 for col in 0..m_cols {
460 let dot_product =
461 self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
462 *result_ptr.add(col) = dot_product;
463 }
464 }
465 }
466
467 result
468 }
469
470 /// Matrix-vector multiplication: M * v
471 ///
472 /// Computes the product of a 2D matrix with a 1D vector, treating the vector
473 /// as a column vector. The implementation uses SIMD-optimized row-wise dot products
474 /// for maximum performance on compatible hardware.
475 ///
476 /// # Arguments
477 ///
478 /// * `other` - The 1D vector tensor (matrix columns must match vector length)
479 ///
480 /// # Returns
481 ///
482 /// A 1D tensor containing the matrix-vector multiplication result
483 ///
484 /// # Implementation Details
485 ///
486 /// - Computes dot product between each matrix row and the vector
487 /// - Uses SIMD kernels for each row when AVX2 is available
488 /// - Scalar fallback processes each row individually
489 /// - Memory layout optimized for row-wise access patterns
490 fn matrix_vector_mult(&self, other: &Tensor) -> Tensor {
491 assert_eq!(self.shape().rank(), 2, "First tensor must be 2D (matrix)");
492 assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D (vector)");
493 assert_eq!(
494 self.shape().dims[1],
495 other.shape().dims[0],
496 "Matrix columns must match vector length"
497 );
498
499 let m_rows = self.shape().dims[0];
500 let m_cols = self.shape().dims[1];
501
502 // Ensure both tensors are contiguous for kernel compatibility
503 let self_contiguous = if self.is_contiguous() {
504 self.clone()
505 } else {
506 self.contiguous()
507 };
508 let other_contiguous = if other.is_contiguous() {
509 other.clone()
510 } else {
511 other.contiguous()
512 };
513
514 // Use uninitialized allocation for performance - result will be fully written
515 let mut result = Tensor::new_uninitialized(vec![m_rows]);
516
517 unsafe {
518 let m_ptr = self_contiguous.as_ptr();
519 let v_ptr = other_contiguous.as_ptr();
520 let result_ptr = result.as_mut_ptr();
521
522 #[cfg(target_arch = "x86_64")]
523 {
524 if is_x86_feature_detected!("avx2") {
525 // Use SIMD for each row
526 for row in 0..m_rows {
527 let dot_product =
528 self.matrix_vector_row_simd_avx2(m_ptr, v_ptr, m_cols, row);
529 *result_ptr.add(row) = dot_product;
530 }
531 } else {
532 // Use scalar for each row
533 for row in 0..m_rows {
534 let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
535 *result_ptr.add(row) = dot_product;
536 }
537 }
538 }
539
540 #[cfg(not(target_arch = "x86_64"))]
541 {
542 // Use scalar for each row
543 for row in 0..m_rows {
544 let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
545 *result_ptr.add(row) = dot_product;
546 }
547 }
548 }
549
550 result
551 }
552
553 /// Standard matrix-matrix multiplication (2D @ 2D)
554 ///
555 /// Computes the product of two 2D matrices using intelligent kernel selection
556 /// based on matrix dimensions. The implementation uses cache-friendly blocked
557 /// algorithms for large matrices and direct computation for small matrices.
558 ///
559 /// # Arguments
560 ///
561 /// * `other` - The right matrix (2D tensor with compatible inner dimensions)
562 ///
563 /// # Returns
564 ///
565 /// A 2D tensor containing the matrix multiplication result
566 ///
567 /// # Implementation Details
568 ///
569 /// - Uses `MatmulConfig::for_dimensions` for optimal kernel selection
570 /// - Dispatches to `kernels::matrix_multiply_blocked` for computation
571 /// - Supports both SIMD and scalar execution paths
572 /// - Memory layout optimized for cache efficiency and SIMD alignment
573 fn matrix_matrix_mult(&self, other: &Tensor) -> Tensor {
574 let m = self.shape().dims[0]; // Result rows
575 let k = self.shape().dims[1]; // Inner dimension
576 let n = other.shape().dims[1]; // Result columns
577
578 assert_eq!(
579 k,
580 other.shape().dims[0],
581 "Inner dimensions must match: {} vs {}",
582 k,
583 other.shape().dims[0]
584 );
585
586 // Ensure both tensors are contiguous for kernel compatibility
587 let self_contiguous = if self.is_contiguous() {
588 self.clone()
589 } else {
590 self.contiguous()
591 };
592 let other_contiguous = if other.is_contiguous() {
593 other.clone()
594 } else {
595 other.contiguous()
596 };
597
598 // Use uninitialized allocation for performance - will be initialized properly
599 let mut result = Tensor::new_uninitialized(vec![m, n]);
600
601 unsafe {
602 let a_ptr = self_contiguous.as_ptr();
603 let b_ptr = other_contiguous.as_ptr();
604 let c_ptr = result.as_mut_ptr();
605
606 // Determine optimal configuration and dispatch
607 let config = MatmulConfig::for_dimensions(m, n, k);
608 kernels::matrix_multiply_blocked(a_ptr, b_ptr, c_ptr, m, n, k, &config);
609 }
610
611 result
612 }
613
614 /// Batched matrix multiplication for higher-dimensional tensors
615 ///
616 /// Performs matrix multiplication on the last two dimensions while broadcasting
617 /// the leading dimensions. This operation supports arbitrary tensor shapes
618 /// with at least 2 dimensions, following NumPy broadcasting rules.
619 ///
620 /// # Arguments
621 ///
622 /// * `other` - The other tensor for batched multiplication (must have at least 2D)
623 ///
624 /// # Returns
625 ///
626 /// A tensor with batched matrix multiplication results, with shape determined
627 /// by broadcasting the batch dimensions and matrix multiplication on the last two
628 ///
629 /// # Implementation Details
630 ///
631 /// - Broadcasts batch dimensions following NumPy right-aligned rules
632 /// - Performs individual matrix multiplications for each batch element
633 /// - Uses `calculate_batch_offset_with_broadcast` for memory offset computation
634 /// - Supports broadcasting of singleton dimensions (size 1) to any size
635 fn batched_matmul(&self, other: &Tensor) -> Tensor {
636 let self_shape = self.shape();
637 let other_shape = other.shape();
638
639 // Ensure both tensors have at least 2 dimensions
640 assert!(
641 self_shape.rank() >= 2,
642 "Batched matmul requires at least 2D tensors"
643 );
644 assert!(
645 other_shape.rank() >= 2,
646 "Batched matmul requires at least 2D tensors"
647 );
648
649 // Get matrix dimensions (last two dimensions)
650 let self_m = self_shape.dims[self_shape.rank() - 2];
651 let self_k = self_shape.dims[self_shape.rank() - 1];
652 let other_k = other_shape.dims[other_shape.rank() - 2];
653 let other_n = other_shape.dims[other_shape.rank() - 1];
654
655 assert_eq!(
656 self_k, other_k,
657 "Inner dimensions must match for batched matmul: {} vs {}",
658 self_k, other_k
659 );
660
661 // Calculate output shape by broadcasting batch dimensions
662 let mut output_dims = Vec::new();
663 let max_rank = self_shape.rank().max(other_shape.rank());
664
665 // Broadcast batch dimensions (right-aligned)
666 for i in 0..(max_rank - 2) {
667 let self_batch_rank = self_shape.rank() - 2;
668 let other_batch_rank = other_shape.rank() - 2;
669
670 let self_dim = if i < self_batch_rank {
671 self_shape.dims[self_batch_rank - 1 - i]
672 } else {
673 1
674 };
675 let other_dim = if i < other_batch_rank {
676 other_shape.dims[other_batch_rank - 1 - i]
677 } else {
678 1
679 };
680
681 if self_dim == 1 {
682 output_dims.push(other_dim);
683 } else if other_dim == 1 || self_dim == other_dim {
684 output_dims.push(self_dim);
685 } else {
686 panic!("Cannot broadcast dimensions {} and {}", self_dim, other_dim);
687 }
688 }
689
690 // Reverse to get correct order (we built from right to left)
691 output_dims.reverse();
692
693 // Add matrix dimensions
694 output_dims.push(self_m);
695 output_dims.push(other_n);
696
697 // Use uninitialized allocation for performance - result will be fully written
698 let mut result = Tensor::new_uninitialized(output_dims.clone());
699
700 // Calculate total number of matrix multiplications
701 let batch_size: usize = output_dims[..output_dims.len() - 2].iter().product();
702
703 unsafe {
704 // Perform batched matrix multiplication
705 let batch_dims = &output_dims[..output_dims.len() - 2];
706 for batch_idx in 0..batch_size {
707 // Calculate offsets for this batch with broadcasting support
708 let self_offset = self.calculate_batch_offset_with_broadcast(
709 batch_idx,
710 self_m * self_k,
711 batch_dims,
712 );
713 let other_offset = other.calculate_batch_offset_with_broadcast(
714 batch_idx,
715 other_k * other_n,
716 batch_dims,
717 );
718 let result_offset = batch_idx * self_m * other_n;
719
720 let a_ptr = self.as_ptr().add(self_offset);
721 let b_ptr = other.as_ptr().add(other_offset);
722 let c_ptr = result.as_mut_ptr().add(result_offset);
723
724 // Perform single matrix multiplication with dynamic configuration
725 let config = MatmulConfig::for_dimensions(self_m, other_n, self_k);
726 kernels::matrix_multiply_blocked(
727 a_ptr, b_ptr, c_ptr, self_m, other_n, self_k, &config,
728 );
729 }
730 }
731
732 // Handle PyTorch-compatible shape squeezing
733 // If one operand was 2D, squeeze out the batch dimension from the result
734 let should_squeeze_batch = self_shape.rank() == 2 || other_shape.rank() == 2;
735 if should_squeeze_batch && output_dims.len() > 2 && output_dims[0] == 1 {
736 // Squeeze out the leading dimension of size 1
737 result = result.squeeze(Some(0));
738 }
739
740 result
741 }
742
743 /// Calculate memory offset for batched operations with broadcasting support
744 ///
745 /// Computes the memory offset for a specific batch element, taking into account
746 /// broadcasting rules where singleton dimensions (size 1) are repeated across
747 /// the batch. This enables efficient batched operations with broadcasting.
748 ///
749 /// # Arguments
750 ///
751 /// * `batch_idx` - Linear batch index (0-based)
752 /// * `matrix_size` - Size of each matrix in elements (product of last two dimensions)
753 /// * `output_batch_dims` - Output batch dimensions for reference (leading dimensions)
754 ///
755 /// # Returns
756 ///
757 /// Memory offset in elements for the specified batch index
758 ///
759 /// # Implementation Details
760 ///
761 /// - Converts linear batch index to multi-dimensional coordinates
762 /// - Handles broadcasting by mapping coordinates to actual tensor dimensions
763 /// - Uses stride-based offset calculation for memory efficiency
764 /// - Supports right-aligned broadcasting following NumPy conventions
fn calculate_batch_offset_with_broadcast(
    &self,
    batch_idx: usize,
    matrix_size: usize,
    output_batch_dims: &[usize],
) -> usize {
    // No batch axes at all: the single matrix lives at offset 0.
    if output_batch_dims.is_empty() {
        return 0;
    }

    // Unravel the linear batch index into per-axis coordinates over the
    // OUTPUT batch shape. Division walks axes innermost-first, so the list is
    // reversed afterwards into natural (outermost-first) order.
    let mut coords = Vec::new();
    let mut temp_idx = batch_idx;

    for &dim_size in output_batch_dims.iter().rev() {
        coords.push(temp_idx % dim_size);
        temp_idx /= dim_size;
    }
    coords.reverse();

    // This tensor's own batch axes: everything except the trailing two
    // (matrix) axes.
    let self_batch_dims = &self.shape().dims[..self.shape().rank() - 2];
    let mut offset = 0;

    // Broadcasting is right-aligned: if the output has more batch axes than
    // this tensor, the extra LEADING output coordinates have no counterpart
    // here and are skipped via this offset into `coords`.
    let coord_offset = if output_batch_dims.len() >= self_batch_dims.len() {
        output_batch_dims.len() - self_batch_dims.len()
    } else {
        0
    };

    // Accumulate the element offset using row-major strides over the batch
    // axes, with one whole matrix (`matrix_size` elements) as the innermost
    // unit. NOTE(review): assumes this tensor's storage is contiguous
    // row-major — confirm at call sites.
    for (i, &self_dim) in self_batch_dims.iter().enumerate() {
        let coord_idx = coord_offset + i;
        if coord_idx < coords.len() {
            let coord = coords[coord_idx];
            // A size-1 axis is broadcast: every output coordinate maps to
            // position 0 of this axis.
            let actual_coord = if self_dim == 1 { 0 } else { coord % self_dim };

            // Row-major stride for axis i: matrix_size times the product of
            // all batch axes to its right.
            let mut stride = matrix_size;
            for &later_dim in self_batch_dims.iter().skip(i + 1) {
                stride *= later_dim;
            }

            offset += actual_coord * stride;
        }
    }

    offset
}
816
817 // ===== SIMD Optimized Implementations =====
818
819 /// AVX2-optimized dot product implementation
820 ///
821 /// Computes dot product using AVX2 SIMD instructions for 8x vectorization.
822 /// Processes 8 elements at a time with horizontal reduction for final sum.
823 ///
824 /// # Safety
825 ///
826 /// Requires AVX2 support and valid pointers with sufficient memory for n elements.
827 /// Memory must be properly aligned for optimal performance.
828 ///
829 /// # Arguments
830 ///
831 /// * `a_ptr` - Pointer to first vector data
832 /// * `b_ptr` - Pointer to second vector data
833 /// * `n` - Number of elements to process
834 ///
835 /// # Returns
836 ///
837 /// Dot product result as f32
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_simd_avx2(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
    // Largest multiple of 8 that is <= n; the vector loop covers [0, simd_end).
    let simd_end = n & !7; // Process 8 elements at a time
    // Eight parallel partial sums, one per SIMD lane.
    let mut sum_vec = _mm256_setzero_ps();

    // Main loop: multiply 8 pairs per iteration and accumulate lane-wise.
    // Unaligned loads (`loadu`) are used, so callers need no particular alignment.
    for i in (0..simd_end).step_by(8) {
        let a_vec = _mm256_loadu_ps(a_ptr.add(i));
        let b_vec = _mm256_loadu_ps(b_ptr.add(i));
        let prod = _mm256_mul_ps(a_vec, b_vec);
        sum_vec = _mm256_add_ps(sum_vec, prod);
    }

    // Horizontal reduction of the 256-bit accumulator to a single f32:
    // add high 128 bits to low 128 bits, then two pairwise horizontal adds.
    let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
    let sum_lo = _mm256_castps256_ps128(sum_vec);
    let sum_quad = _mm_add_ps(sum_hi, sum_lo);
    let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
    let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
    let mut result = _mm_cvtss_f32(sum_single);

    // Scalar tail for the final n % 8 elements.
    for i in simd_end..n {
        result += *a_ptr.add(i) * *b_ptr.add(i);
    }

    result
}
868
869 /// Scalar-optimized dot product implementation
870 ///
871 /// Computes dot product using scalar operations with 4x loop unrolling for
872 /// better instruction-level parallelism. Provides fallback for non-SIMD hardware.
873 ///
874 /// # Safety
875 ///
876 /// Requires valid pointers with sufficient memory for n elements.
877 ///
878 /// # Arguments
879 ///
880 /// * `a_ptr` - Pointer to first vector data
881 /// * `b_ptr` - Pointer to second vector data
882 /// * `n` - Number of elements to process
883 ///
884 /// # Returns
885 ///
886 /// Dot product result as f32
887 #[inline]
888 unsafe fn dot_product_scalar(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
889 let mut sum = 0.0f32;
890 let unroll_end = n & !3; // Process 4 elements at a time
891
892 // Unrolled loop for better instruction-level parallelism
893 for i in (0..unroll_end).step_by(4) {
894 sum += *a_ptr.add(i) * *b_ptr.add(i);
895 sum += *a_ptr.add(i + 1) * *b_ptr.add(i + 1);
896 sum += *a_ptr.add(i + 2) * *b_ptr.add(i + 2);
897 sum += *a_ptr.add(i + 3) * *b_ptr.add(i + 3);
898 }
899
900 // Handle remaining elements
901 for i in unroll_end..n {
902 sum += *a_ptr.add(i) * *b_ptr.add(i);
903 }
904
905 sum
906 }
907
    /// AVX2-optimized vector-matrix column dot product
    ///
    /// Computes dot product between a vector and a specific matrix column using
    /// AVX2 SIMD instructions. Optimized for column-wise access patterns.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns.
    ///
    /// # Arguments
    ///
    /// * `v_ptr` - Pointer to vector data
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_len` - Length of vector (must match matrix rows)
    /// * `m_cols` - Number of columns in matrix
    /// * `col` - Column index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn vector_matrix_column_simd_avx2(
        &self,
        v_ptr: *const f32,
        m_ptr: *const f32,
        v_len: usize,
        m_cols: usize,
        col: usize,
    ) -> f32 {
        // Largest multiple of 8 not exceeding v_len (& !7 clears the low 3 bits).
        let simd_end = v_len & !7;
        let mut sum_vec = _mm256_setzero_ps();

        // Process 8 elements at a time with optimized gather
        for i in (0..simd_end).step_by(8) {
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));

            // Optimized gather for matrix column elements: consecutive column
            // entries are m_cols floats apart in row-major storage, so they
            // cannot be fetched with one contiguous load.
            let m0 = *m_ptr.add(i * m_cols + col);
            let m1 = *m_ptr.add((i + 1) * m_cols + col);
            let m2 = *m_ptr.add((i + 2) * m_cols + col);
            let m3 = *m_ptr.add((i + 3) * m_cols + col);
            let m4 = *m_ptr.add((i + 4) * m_cols + col);
            let m5 = *m_ptr.add((i + 5) * m_cols + col);
            let m6 = *m_ptr.add((i + 6) * m_cols + col);
            let m7 = *m_ptr.add((i + 7) * m_cols + col);

            // _mm256_set_ps orders its arguments high-lane-first, so passing
            // m7..m0 places m0 in lane 0, matching v_vec's element order.
            let m_vec = _mm256_set_ps(m7, m6, m5, m4, m3, m2, m1, m0);

            let prod = _mm256_mul_ps(v_vec, m_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: fold the upper 128-bit lane onto the lower one, then
        // two hadd steps reduce 4 partial sums -> 2 -> 1.
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Handle remaining elements (tail of fewer than 8) with scalar math
        for i in simd_end..v_len {
            result += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
        }

        result
    }
978
979 /// Scalar vector-matrix column dot product
980 ///
981 /// Computes dot product between a vector and a specific matrix column using
982 /// scalar operations. Provides fallback for non-SIMD hardware.
983 ///
984 /// # Safety
985 ///
986 /// Requires valid pointers with sufficient memory.
987 /// Matrix must be in row-major layout with m_cols columns.
988 ///
989 /// # Arguments
990 ///
991 /// * `v_ptr` - Pointer to vector data
992 /// * `m_ptr` - Pointer to matrix data (row-major layout)
993 /// * `v_len` - Length of vector (must match matrix rows)
994 /// * `m_cols` - Number of columns in matrix
995 /// * `col` - Column index to compute dot product with
996 ///
997 /// # Returns
998 ///
999 /// Dot product result as f32
1000 #[inline]
1001 unsafe fn vector_matrix_column_scalar(
1002 &self,
1003 v_ptr: *const f32,
1004 m_ptr: *const f32,
1005 v_len: usize,
1006 m_cols: usize,
1007 col: usize,
1008 ) -> f32 {
1009 let mut sum = 0.0f32;
1010 for i in 0..v_len {
1011 sum += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
1012 }
1013 sum
1014 }
1015
    /// AVX2-optimized matrix-vector row dot product
    ///
    /// Computes dot product between a specific matrix row and a vector using
    /// AVX2 SIMD instructions. Optimized for row-wise access patterns.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns.
    ///
    /// # Arguments
    ///
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_ptr` - Pointer to vector data
    /// * `m_cols` - Number of columns in matrix (must match vector length)
    /// * `row` - Row index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn matrix_vector_row_simd_avx2(
        &self,
        m_ptr: *const f32,
        v_ptr: *const f32,
        m_cols: usize,
        row: usize,
    ) -> f32 {
        // Largest multiple of 8 not exceeding m_cols.
        let simd_end = m_cols & !7;
        let mut sum_vec = _mm256_setzero_ps();
        // A row is contiguous in row-major storage, so one base pointer allows
        // plain unaligned vector loads (no gather needed, unlike columns).
        let row_ptr = m_ptr.add(row * m_cols);

        // Multiply-accumulate 8 row/vector element pairs per iteration.
        for i in (0..simd_end).step_by(8) {
            let m_vec = _mm256_loadu_ps(row_ptr.add(i));
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));
            let prod = _mm256_mul_ps(m_vec, v_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: fold the upper 128-bit lane onto the lower one, then
        // two hadd steps reduce 4 partial sums -> 2 -> 1.
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Handle remaining elements (tail of fewer than 8) with scalar math
        for i in simd_end..m_cols {
            result += *row_ptr.add(i) * *v_ptr.add(i);
        }

        result
    }
1072
1073 /// Scalar matrix-vector row dot product
1074 ///
1075 /// Computes dot product between a specific matrix row and a vector using
1076 /// scalar operations. Provides fallback for non-SIMD hardware.
1077 ///
1078 /// # Safety
1079 ///
1080 /// Requires valid pointers with sufficient memory.
1081 /// Matrix must be in row-major layout with m_cols columns.
1082 ///
1083 /// # Arguments
1084 ///
1085 /// * `m_ptr` - Pointer to matrix data (row-major layout)
1086 /// * `v_ptr` - Pointer to vector data
1087 /// * `m_cols` - Number of columns in matrix (must match vector length)
1088 /// * `row` - Row index to compute dot product with
1089 ///
1090 /// # Returns
1091 ///
1092 /// Dot product result as f32
1093 #[inline]
1094 unsafe fn matrix_vector_row_scalar(
1095 &self,
1096 m_ptr: *const f32,
1097 v_ptr: *const f32,
1098 m_cols: usize,
1099 row: usize,
1100 ) -> f32 {
1101 let mut sum = 0.0f32;
1102 let row_ptr = m_ptr.add(row * m_cols);
1103 for i in 0..m_cols {
1104 sum += *row_ptr.add(i) * *v_ptr.add(i);
1105 }
1106 sum
1107 }
1108}
1109
1110#[cfg(test)]
1111mod tests {
1112 //! Matrix multiplication operation tests
1113 //!
1114 //! This module contains comprehensive tests for matrix multiplication operations,
1115 //! including basic functionality, kernel selection, and large matrix handling.
1116 //! Tests cover all supported operation types and edge cases.
1117
1118 use super::*;
1119
1120 /// Test basic 2x2 matrix multiplication functionality
1121 ///
1122 /// Verifies that the matmul operation correctly computes the product of two 2x2 matrices
1123 /// and produces the expected numerical results. This test validates the core matrix
1124 /// multiplication algorithm and result shape computation.
1125 #[test]
1126 fn test_matmul_2d_basic() {
1127 // Test basic 2x2 matrix multiplication
1128 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1129 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
1130 let result = a.matmul(&b);
1131
1132 assert_eq!(result.shape().dims, vec![2, 2]);
1133
1134 // Expected result: [[19, 22], [43, 50]]
1135 unsafe {
1136 let ptr = result.as_ptr();
1137 assert_eq!(*ptr.add(0), 19.0); // (0,0)
1138 assert_eq!(*ptr.add(1), 22.0); // (0,1)
1139 assert_eq!(*ptr.add(2), 43.0); // (1,0)
1140 assert_eq!(*ptr.add(3), 50.0); // (1,1)
1141 }
1142 }
1143
1144 /// Test 2D @ 2D matmul gradient computation (matrix @ matrix)
1145 #[test]
1146 fn test_matmul_2d_2d_gradients() {
1147 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
1148 .unwrap()
1149 .with_requires_grad();
1150 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
1151 .unwrap()
1152 .with_requires_grad();
1153
1154 let mut result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
1155 assert_eq!(result.shape().dims, vec![2, 2]);
1156
1157 // Expected result: [[19, 22], [43, 50]]
1158 let expected = [19.0, 22.0, 43.0, 50.0];
1159 unsafe {
1160 let ptr = result.as_ptr();
1161 for (i, val) in expected.iter().enumerate().take(4) {
1162 assert_eq!(*ptr.add(i), *val);
1163 }
1164 }
1165
1166 // Set up gradient for backward pass
1167 let grad_output = Tensor::from_slice(&[1.0, 1.0, 1.0, 1.0], vec![2, 2]).unwrap();
1168 result.backward(Some(grad_output));
1169
1170 let grad_a = a.grad_by_value().unwrap();
1171 let grad_b = b.grad_by_value().unwrap();
1172
1173 assert_eq!(grad_a.shape().dims, vec![2, 2]);
1174 assert_eq!(grad_b.shape().dims, vec![2, 2]);
1175
1176 // grad_a = grad_output @ b^T = [[1, 1], [1, 1]] @ [[5, 7], [6, 8]] = [[11, 15], [11, 15]]
1177
1178 unsafe {
1179 let grad_a_ptr = grad_a.as_ptr();
1180 assert_eq!(*grad_a_ptr.add(0), 11.0); // 1*5 + 1*6
1181 assert_eq!(*grad_a_ptr.add(1), 15.0); // 1*7 + 1*8
1182 assert_eq!(*grad_a_ptr.add(2), 11.0); // 1*5 + 1*6
1183 assert_eq!(*grad_a_ptr.add(3), 15.0); // 1*7 + 1*8
1184 }
1185
1186 // grad_b = a^T @ grad_output = [[1, 3], [2, 4]] @ [[1, 1], [1, 1]] = [[4, 4], [6, 6]]
1187 unsafe {
1188 let grad_b_ptr = grad_b.as_ptr();
1189 assert_eq!(*grad_b_ptr.add(0), 4.0); // 1*1 + 3*1
1190 assert_eq!(*grad_b_ptr.add(1), 4.0); // 1*1 + 3*1
1191 assert_eq!(*grad_b_ptr.add(2), 6.0); // 2*1 + 4*1
1192 assert_eq!(*grad_b_ptr.add(3), 6.0); // 2*1 + 4*1
1193 }
1194 }
1195
1196 /// Test matmul gradient computation with partial requires_grad
1197 #[test]
1198 fn test_matmul_partial_requires_grad() {
1199 // Test case where only one operand requires gradients (like the linear layer case)
1200 let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // No requires_grad
1201 let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
1202 .unwrap()
1203 .with_requires_grad(); // Only b requires gradients
1204
1205 let mut result = a.matmul(&b); // [3] @ [3, 2] -> [2]
1206 assert_eq!(result.shape().dims, vec![2]);
1207
1208 result.backward(None);
1209
1210 // Only b should have gradients
1211 assert!(a.grad_by_value().is_none());
1212 let grad_b = b.grad_by_value().unwrap();
1213
1214 assert_eq!(grad_b.shape().dims, vec![3, 2]);
1215
1216 // grad_b = outer_product(a, grad_output)
1217 // Since grad_output defaults to ones([2]), grad_b[i,j] = a[i] * 1.0 = a[i]
1218 unsafe {
1219 let grad_b_ptr = grad_b.as_ptr();
1220 assert_eq!(*grad_b_ptr.add(0), 1.0); // a[0] * grad_output[0]
1221 assert_eq!(*grad_b_ptr.add(1), 1.0); // a[0] * grad_output[1]
1222 assert_eq!(*grad_b_ptr.add(2), 2.0); // a[1] * grad_output[0]
1223 assert_eq!(*grad_b_ptr.add(3), 2.0); // a[1] * grad_output[1]
1224 assert_eq!(*grad_b_ptr.add(4), 3.0); // a[2] * grad_output[0]
1225 assert_eq!(*grad_b_ptr.add(5), 3.0); // a[2] * grad_output[1]
1226 }
1227 }
1228
1229 #[test]
1230 fn test_debug_gradient_values() {
1231 println!("=== Debugging matmul gradient issue ===");
1232
1233 // Test case: [1, 3, 4] @ [2, 4, 5] which should fail with our=41, torch=29
1234 let left_shape = vec![1, 3, 4];
1235 let right_shape = vec![2, 4, 5];
1236
1237 let mut left = Tensor::zeros(left_shape.clone()).with_requires_grad();
1238 let mut right = Tensor::zeros(right_shape.clone()).with_requires_grad();
1239
1240 let left_size = left_shape.iter().product::<usize>();
1241 let right_size = right_shape.iter().product::<usize>();
1242
1243 // Fill with exactly the same data as the validation test
1244 unsafe {
1245 for i in 0..left_size {
1246 *left.as_mut_ptr().add(i) = (i as f32) * 0.1 + 1.0;
1247 }
1248 for i in 0..right_size {
1249 *right.as_mut_ptr().add(i) = (i as f32) * 0.2 + 0.5;
1250 }
1251 }
1252
1253 println!(
1254 "Left shape: {:?}, data: {:?}",
1255 left.shape().dims,
1256 left.data()
1257 );
1258 println!(
1259 "Right shape: {:?}, data: {:?}",
1260 right.shape().dims,
1261 right.data()
1262 );
1263
1264 // Forward pass
1265 let mut result = left.matmul(&right);
1266 println!(
1267 "Result shape: {:?}, data: {:?}",
1268 result.shape().dims,
1269 result.data()
1270 );
1271
1272 // Backward pass with ones
1273 let grad_ones = Tensor::ones(result.shape().dims.clone());
1274 println!(
1275 "Grad ones shape: {:?}, data: {:?}",
1276 grad_ones.shape().dims,
1277 grad_ones.data()
1278 );
1279
1280 result.backward(Some(grad_ones));
1281
1282 let grad_left = left.grad_by_value().unwrap();
1283 let grad_right = right.grad_by_value().unwrap();
1284
1285 println!(
1286 "Left gradient shape: {:?}, data: {:?}",
1287 grad_left.shape().dims,
1288 grad_left.data()
1289 );
1290 println!(
1291 "Right gradient shape: {:?}, data: {:?}",
1292 grad_right.shape().dims,
1293 grad_right.data()
1294 );
1295
1296 println!(
1297 "Left gradient[0] = {} (expected ~29, but we're getting ~41)",
1298 grad_left.data()[0]
1299 );
1300 }
1301
1302 #[test]
1303 fn test_simple_batched_gradient() {
1304 println!("=== Testing simple batched gradient ===");
1305
1306 // Simple case: [2, 2, 2] @ [2, 2, 2]
1307 let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2])
1308 .unwrap()
1309 .with_requires_grad();
1310 let right = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], vec![2, 2, 2])
1311 .unwrap()
1312 .with_requires_grad();
1313
1314 println!("Left: {:?}", left.data());
1315 println!("Right: {:?}", right.data());
1316
1317 // Test transpose function first
1318 let right_t = right.transpose(1, 2);
1319 println!("Right transposed: {:?}", right_t.data());
1320 println!("Right transposed contiguous: {:?}", right_t.is_contiguous());
1321 println!("Right transposed strides: {:?}", right_t.strides());
1322
1323 let mut result = left.matmul(&right);
1324 println!("Result: {:?}", result.data());
1325
1326 let grad_ones = Tensor::ones(result.shape().dims.clone());
1327 result.backward(Some(grad_ones));
1328
1329 let grad_left = left.grad_by_value().unwrap();
1330 let grad_right = right.grad_by_value().unwrap();
1331
1332 println!("Left gradient: {:?}", grad_left.data());
1333 println!("Right gradient: {:?}", grad_right.data());
1334
1335 // Manual calculation for verification
1336 println!("\n=== Manual verification ===");
1337 println!("Expected left grad batch 0: [0.5+1.0, 1.5+2.0] = [1.5, 3.5]");
1338 println!("Expected left grad batch 1: [2.5+3.0, 3.5+4.0] = [5.5, 7.5]");
1339 }
1340
1341 #[test]
1342 fn test_linear_layer_pattern() {
1343 // Simulate the exact pattern from the training loop
1344 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // Input (no grad)
1345 let weight = Tensor::from_slice(&[0.1, 0.5, 0.3, 0.1, 0.5, 0.3], vec![3, 2])
1346 .unwrap()
1347 .with_requires_grad(); // Weight (requires grad)
1348 let bias = Tensor::from_slice(&[0.0, 0.1], vec![2])
1349 .unwrap()
1350 .with_requires_grad(); // Bias (requires grad)
1351
1352 // Forward pass
1353 let weighted = x_data.matmul(&weight); // [3] @ [3, 2] -> [2]
1354 let y_pred = weighted.add_tensor(&bias); // [2] + [2] -> [2]
1355
1356 // Create a simple loss (sum of squared differences with some target)
1357 let y_true = Tensor::from_slice(&[3.0, 5.0], vec![2]).unwrap();
1358 let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();
1359
1360 // Backward pass
1361 loss.backward(None);
1362
1363 // Check that gradients are computed correctly
1364 let grad_weight = weight.grad_by_value().unwrap();
1365 let grad_bias = bias.grad_by_value().unwrap();
1366
1367 assert_eq!(grad_weight.shape().dims, vec![3, 2]); // Same shape as weight
1368 assert_eq!(grad_bias.shape().dims, vec![2]); // Same shape as bias
1369
1370 // The exact gradient values depend on the computation graph, but shapes should be correct
1371 assert_eq!(grad_weight.size(), 6);
1372 assert_eq!(grad_bias.size(), 2);
1373
1374 // Verify that no gradient is computed for x_data (doesn't require grad)
1375 assert!(x_data.grad_by_value().is_none());
1376 }
1377}