train_station/tensor/ops/matmul/mod.rs
1//! Matrix multiplication operations with optimized kernels
2//!
3//! This module provides a comprehensive matrix multiplication implementation optimized
4//! for single-threaded performance with SIMD acceleration. The implementation supports
5//! all NumPy-style matrix multiplication patterns including 1D/2D/ND tensor operations
6//! with automatic differentiation support.
7//!
8//! # Key Features
9//!
10//! - **SIMD Optimization**: AVX2 implementations for x86_64 architectures
11//! - **Intelligent Dispatch**: Dynamic kernel selection based on matrix dimensions
12//! - **Cache Optimization**: Blocked algorithms for L1/L2 cache efficiency
13//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
14//! - **GradTrack Integration**: Automatic gradient computation for all operations
15//! - **Thread Safety**: All operations are thread-safe and Send + Sync
16//! - **Mathematical Validation**: High-precision equivalence to LibTorch reference
17//!
18//! # Performance Characteristics
19//!
20//! The implementation uses intelligent dispatch to select optimal kernels based on matrix size:
21//! - **Small matrices (16-64 elements)**: Direct computation with minimal overhead
22//! - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
23//! - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
24//! - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
25//! - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
26//! - **Memory Safety**: Safe memory management with `Tensor::new_uninitialized`
27//!
28//! # Organization
29//!
30//! The matmul module is organized into focused submodules:
31//! - **`config`**: Dynamic configuration and kernel selection based on matrix dimensions
32//! - **`kernels`**: SIMD-optimized computational kernels with ML-specific optimizations
33//!
34//! # Supported Operations
35//!
36//! - **1D @ 1D**: Dot product returning scalar tensor
37//! - **1D @ 2D**: Vector-matrix multiplication (v^T * M)
38//! - **2D @ 1D**: Matrix-vector multiplication (M * v)
39//! - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
40//! - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
41//!
42//! # Examples
43//!
44//! ## Basic Matrix Multiplication
45//!
46//! ```
47//! use train_station::Tensor;
48//!
49//! // 2D matrix multiplication
50//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
51//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
52//! let result = a.matmul(&b); // Uses optimized SIMD kernels
53//!
54//! assert_eq!(result.shape().dims, vec![2, 2]);
55//! assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
56//! ```
57//!
58//! ## Vector-Matrix Multiplication
59//!
60//! ```
61//! use train_station::Tensor;
62//!
63//! // Vector-matrix multiplication
64//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
65//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
66//! let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
67//!
68//! assert_eq!(result.shape().dims, vec![2]);
69//! assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
70//! ```
71//!
72//! ## Matrix-Vector Multiplication
73//!
74//! ```
75//! use train_station::Tensor;
76//!
77//! // Matrix-vector multiplication
78//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
79//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
80//! let result = m.matmul(&v); // [2, 2] @ [2] -> [2]
81//!
82//! assert_eq!(result.shape().dims, vec![2]);
83//! assert_eq!(result.data(), &[5.0, 11.0]); // 1*1+2*2, 3*1+4*2
84//! ```
85//!
86//! ## Dot Product
87//!
88//! ```
89//! use train_station::Tensor;
90//!
91//! // 1D dot product
92//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
93//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
94//! let result = a.matmul(&b); // [3] @ [3] -> scalar
95//!
96//! assert_eq!(result.shape().dims, vec![]); // Scalar tensor
97//! assert_eq!(result.data(), &[32.0]); // 1*4 + 2*5 + 3*6
98//! ```
99//!
100//! ## Batched Matrix Multiplication
101//!
102//! ```
103//! use train_station::Tensor;
104//!
105//! // Batched matrix multiplication
106//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
107//! let b = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0], vec![2, 2]).unwrap();
108//! let result = a.matmul(&b); // [2, 2, 2] @ [2, 2] -> [2, 2, 2]
109//!
110//! assert_eq!(result.shape().dims, vec![2, 2, 2]);
111//! ```
112//!
113//! ## Gradient Tracking
114//!
115//! ```
116//! use train_station::Tensor;
117//!
118//! // Matrix multiplication with gradient tracking
119//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
120//! .unwrap()
121//! .with_requires_grad();
122//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
123//! .unwrap()
124//! .with_requires_grad();
125//!
126//! let result = a.matmul(&b);
127//! assert!(result.requires_grad());
128//! assert_eq!(result.shape().dims, vec![2, 2]);
129//! ```
130//!
131//! # Automatic Differentiation
132//!
133//! All operations support automatic differentiation when either operand requires gradients.
134//! Gradient computation follows PyTorch semantics with proper accumulation and chain rule
135//! application through the gradtrack engine.
136//!
137//! # Thread Safety
138//!
139//! All operations are thread-safe and can be used concurrently across multiple threads.
140//! The implementation uses immutable tensor references and thread-local gradtrack state.
141//!
142//! # Mathematical Validation
143//!
144//! All operations are validated against LibTorch reference implementation with high-precision
145//! numerical equivalence (target: 0.00e0 error tolerance, practical: 1e-6 tolerance for
146//! floating-point precision differences).
147
148use crate::tensor::core::Tensor;
149
150pub mod config;
151pub mod kernels;
152
153// Re-export public types
154pub use config::MatmulConfig;
155
156// SIMD optimizations for performance-critical operations
157#[cfg(target_arch = "x86_64")]
158use std::arch::x86_64::*;
159
160impl Tensor {
161 /// Matrix multiplication operation following NumPy semantics
162 ///
163 /// Performs matrix multiplication between this tensor and another tensor with intelligent
164 /// kernel selection based on matrix dimensions and hardware capabilities. The operation
165 /// follows broadcasting rules and supports all common matrix multiplication patterns
166 /// found in machine learning workloads.
167 ///
168 /// # Supported Operations
169 ///
170 /// - **1D @ 1D**: Dot product returning scalar tensor
171 /// - **1D @ 2D**: Vector-matrix multiplication (v^T * M) returning 1D tensor
172 /// - **2D @ 1D**: Matrix-vector multiplication (M * v) returning 1D tensor
173 /// - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
174 /// - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
175 ///
176 /// # Performance Characteristics
177 ///
178 /// The implementation automatically selects optimal kernels based on matrix dimensions:
179 /// - **Small matrices (<64 elements)**: Direct computation with minimal overhead
180 /// - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
181 /// - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
182 /// - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
183 /// - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
184 ///
185 /// # Automatic Differentiation
186 ///
187 /// This operation supports automatic differentiation when either operand requires gradients.
188 /// Gradient computation follows PyTorch semantics with proper accumulation and chain rule
189 /// application through the gradtrack engine. Gradients are computed for both operands when
190 /// `requires_grad` is set.
191 ///
192 /// # Arguments
193 ///
194 /// * `other` - The tensor to multiply with (must have compatible dimensions)
195 ///
196 /// # Returns
197 ///
198 /// A new tensor containing the matrix multiplication result with appropriate shape
199 /// determined by broadcasting rules and matrix multiplication semantics
200 ///
201 /// # Panics
202 ///
203 /// Panics if the inner dimensions don't match for matrix multiplication:
204 /// - For 2D @ 2D: `self.shape()[1] != other.shape()[0]`
205 /// - For 1D @ 2D: `self.shape()[0] != other.shape()[0]`
206 /// - For 2D @ 1D: `self.shape()[1] != other.shape()[0]`
207 /// - For ND @ ND: Last two dimensions must be compatible for matrix multiplication
208 ///
209 /// # Examples
210 ///
211 /// ```
212 /// use train_station::Tensor;
213 ///
214 /// // 2D matrix multiplication
215 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
216 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
217 /// let result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
218 ///
219 /// assert_eq!(result.shape().dims, vec![2, 2]);
220 /// assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
221 /// ```
222 ///
223 /// ## Vector-Matrix Multiplication
224 ///
225 /// ```
226 /// use train_station::Tensor;
227 ///
228 /// let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
229 /// let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
230 /// let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
231 ///
232 /// assert_eq!(result.shape().dims, vec![2]);
233 /// assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
234 /// ```
235 ///
236 /// ## Gradient Tracking
237 ///
238 /// ```
239 /// use train_station::Tensor;
240 ///
241 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
242 /// .unwrap()
243 /// .with_requires_grad();
244 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
245 /// .unwrap()
246 /// .with_requires_grad();
247 ///
248 /// let result = a.matmul(&b);
249 /// assert!(result.requires_grad());
250 /// assert_eq!(result.shape().dims, vec![2, 2]);
251 /// ```
252 ///
253 /// # Thread Safety
254 ///
255 /// This operation is thread-safe and can be used concurrently across multiple threads.
256 /// The implementation uses immutable tensor references and thread-local gradtrack state.
257 ///
258 /// # Memory Safety
259 ///
260 /// The implementation uses `Tensor::new_uninitialized` for performance-critical allocations
261 /// and handles memory initialization safely through the kernel system. All unsafe operations
262 /// are validated through comprehensive FFI testing against LibTorch reference implementation.
263 pub fn matmul(&self, other: &Tensor) -> Tensor {
264 let self_shape = self.shape();
265 let other_shape = other.shape();
266
267 let mut result = match (self_shape.rank(), other_shape.rank()) {
268 (1, 1) => {
269 // 1D @ 1D: dot product -> scalar
270 self.dot_product_1d(other)
271 }
272 (1, 2) => {
273 // 1D @ 2D: vector-matrix multiplication -> 1D
274 self.vector_matrix_mult(other)
275 }
276 (2, 1) => {
277 // 2D @ 1D: matrix-vector multiplication -> 1D
278 self.matrix_vector_mult(other)
279 }
280 (2, 2) => {
281 // 2D @ 2D: standard matrix multiplication -> 2D
282 self.matrix_matrix_mult(other)
283 }
284 _ => {
285 // ND @ ND: batched matrix multiplication
286 self.batched_matmul(other)
287 }
288 };
289
290 // Set up gradtrack if either operand requires gradients
291 if (self.requires_grad() || other.requires_grad()) && crate::gradtrack::is_grad_enabled() {
292 use crate::gradtrack::{GradEngine, GradFn};
293
294 result.set_requires_grad(true);
295 let grad_fn = GradFn::MatMul {
296 left_operand: Box::new(self.clone()),
297 right_operand: Box::new(other.clone()),
298 requires_grad: (self.requires_grad(), other.requires_grad()),
299 };
300 result.set_grad_fn(grad_fn.clone());
301
302 // Register with gradtrack engine for gradient computation
303 // Always register both operands to maintain consistent indexing
304 let input_ids = vec![self.id(), other.id()];
305
306 GradEngine::register_operation(result.id(), input_ids, grad_fn);
307 }
308
309 result
310 }
311
312 /// Dot product of two 1D tensors (returns scalar)
313 ///
314 /// Computes the dot product between two 1D tensors using SIMD-optimized kernels
315 /// when available. The implementation uses AVX2 instructions for 8x vectorization
316 /// with scalar fallbacks for non-SIMD hardware.
317 ///
318 /// # Arguments
319 ///
320 /// * `other` - The other 1D tensor (must have same length as self)
321 ///
322 /// # Returns
323 ///
324 /// A scalar tensor containing the dot product result
325 ///
326 /// # Implementation Details
327 ///
328 /// - Uses `Tensor::new_uninitialized` for performance-critical allocation
329 /// - SIMD path processes 8 elements at a time with horizontal reduction
330 /// - Scalar path uses 4x unrolled loops for instruction-level parallelism
331 /// - Memory is fully written to avoid uninitialized access
332 fn dot_product_1d(&self, other: &Tensor) -> Tensor {
333 assert_eq!(self.shape().rank(), 1, "First tensor must be 1D");
334 assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D");
335 assert_eq!(
336 self.shape().dims[0],
337 other.shape().dims[0],
338 "Tensors must have same length for dot product"
339 );
340
341 let n = self.shape().dims[0];
342
343 // Ensure both tensors are contiguous for kernel compatibility
344 let self_contiguous = if self.is_contiguous() {
345 self.clone()
346 } else {
347 self.contiguous()
348 };
349 let other_contiguous = if other.is_contiguous() {
350 other.clone()
351 } else {
352 other.contiguous()
353 };
354
355 // Use uninitialized allocation for scalar result - memory will be fully written
356 let mut result = Tensor::new_uninitialized(vec![]); // Scalar tensor
357
358 unsafe {
359 let a_ptr = self_contiguous.as_ptr();
360 let b_ptr = other_contiguous.as_ptr();
361 let result_ptr = result.as_mut_ptr();
362
363 #[cfg(target_arch = "x86_64")]
364 {
365 if is_x86_feature_detected!("avx2") {
366 let dot_product = self.dot_product_simd_avx2(a_ptr, b_ptr, n);
367 *result_ptr = dot_product;
368 } else {
369 let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
370 *result_ptr = dot_product;
371 }
372 }
373
374 #[cfg(not(target_arch = "x86_64"))]
375 {
376 let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
377 *result_ptr = dot_product;
378 }
379 }
380
381 result
382 }
383
384 /// Vector-matrix multiplication: v^T * M
385 ///
386 /// Computes the product of a 1D vector with a 2D matrix, treating the vector
387 /// as a row vector. The implementation uses SIMD-optimized column-wise dot products
388 /// for maximum performance on compatible hardware.
389 ///
390 /// # Arguments
391 ///
392 /// * `other` - The 2D matrix tensor (vector length must match matrix rows)
393 ///
394 /// # Returns
395 ///
396 /// A 1D tensor containing the vector-matrix multiplication result
397 ///
398 /// # Implementation Details
399 ///
400 /// - Computes dot product between vector and each matrix column
401 /// - Uses SIMD kernels for each column when AVX2 is available
402 /// - Scalar fallback processes each column individually
403 /// - Memory layout optimized for column-wise access patterns
404 fn vector_matrix_mult(&self, other: &Tensor) -> Tensor {
405 assert_eq!(self.shape().rank(), 1, "First tensor must be 1D (vector)");
406 assert_eq!(other.shape().rank(), 2, "Second tensor must be 2D (matrix)");
407 assert_eq!(
408 self.shape().dims[0],
409 other.shape().dims[0],
410 "Vector length must match matrix rows"
411 );
412
413 let v_len = self.shape().dims[0];
414 let m_cols = other.shape().dims[1];
415
416 // Ensure both tensors are contiguous for kernel compatibility
417 let self_contiguous = if self.is_contiguous() {
418 self.clone()
419 } else {
420 self.contiguous()
421 };
422 let other_contiguous = if other.is_contiguous() {
423 other.clone()
424 } else {
425 other.contiguous()
426 };
427
428 // Use uninitialized allocation for performance - result will be fully written
429 let mut result = Tensor::new_uninitialized(vec![m_cols]);
430
431 unsafe {
432 let v_ptr = self_contiguous.as_ptr();
433 let m_ptr = other_contiguous.as_ptr();
434 let result_ptr = result.as_mut_ptr();
435
436 #[cfg(target_arch = "x86_64")]
437 {
438 if is_x86_feature_detected!("avx2") {
439 // Use SIMD for each column
440 for col in 0..m_cols {
441 let dot_product =
442 self.vector_matrix_column_simd_avx2(v_ptr, m_ptr, v_len, m_cols, col);
443 *result_ptr.add(col) = dot_product;
444 }
445 } else {
446 // Use scalar for each column
447 for col in 0..m_cols {
448 let dot_product =
449 self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
450 *result_ptr.add(col) = dot_product;
451 }
452 }
453 }
454
455 #[cfg(not(target_arch = "x86_64"))]
456 {
457 // Use scalar for each column
458 for col in 0..m_cols {
459 let dot_product =
460 self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
461 *result_ptr.add(col) = dot_product;
462 }
463 }
464 }
465
466 result
467 }
468
469 /// Matrix-vector multiplication: M * v
470 ///
471 /// Computes the product of a 2D matrix with a 1D vector, treating the vector
472 /// as a column vector. The implementation uses SIMD-optimized row-wise dot products
473 /// for maximum performance on compatible hardware.
474 ///
475 /// # Arguments
476 ///
477 /// * `other` - The 1D vector tensor (matrix columns must match vector length)
478 ///
479 /// # Returns
480 ///
481 /// A 1D tensor containing the matrix-vector multiplication result
482 ///
483 /// # Implementation Details
484 ///
485 /// - Computes dot product between each matrix row and the vector
486 /// - Uses SIMD kernels for each row when AVX2 is available
487 /// - Scalar fallback processes each row individually
488 /// - Memory layout optimized for row-wise access patterns
489 fn matrix_vector_mult(&self, other: &Tensor) -> Tensor {
490 assert_eq!(self.shape().rank(), 2, "First tensor must be 2D (matrix)");
491 assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D (vector)");
492 assert_eq!(
493 self.shape().dims[1],
494 other.shape().dims[0],
495 "Matrix columns must match vector length"
496 );
497
498 let m_rows = self.shape().dims[0];
499 let m_cols = self.shape().dims[1];
500
501 // Ensure both tensors are contiguous for kernel compatibility
502 let self_contiguous = if self.is_contiguous() {
503 self.clone()
504 } else {
505 self.contiguous()
506 };
507 let other_contiguous = if other.is_contiguous() {
508 other.clone()
509 } else {
510 other.contiguous()
511 };
512
513 // Use uninitialized allocation for performance - result will be fully written
514 let mut result = Tensor::new_uninitialized(vec![m_rows]);
515
516 unsafe {
517 let m_ptr = self_contiguous.as_ptr();
518 let v_ptr = other_contiguous.as_ptr();
519 let result_ptr = result.as_mut_ptr();
520
521 #[cfg(target_arch = "x86_64")]
522 {
523 if is_x86_feature_detected!("avx2") {
524 // Use SIMD for each row
525 for row in 0..m_rows {
526 let dot_product =
527 self.matrix_vector_row_simd_avx2(m_ptr, v_ptr, m_cols, row);
528 *result_ptr.add(row) = dot_product;
529 }
530 } else {
531 // Use scalar for each row
532 for row in 0..m_rows {
533 let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
534 *result_ptr.add(row) = dot_product;
535 }
536 }
537 }
538
539 #[cfg(not(target_arch = "x86_64"))]
540 {
541 // Use scalar for each row
542 for row in 0..m_rows {
543 let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
544 *result_ptr.add(row) = dot_product;
545 }
546 }
547 }
548
549 result
550 }
551
552 /// Standard matrix-matrix multiplication (2D @ 2D)
553 ///
554 /// Computes the product of two 2D matrices using intelligent kernel selection
555 /// based on matrix dimensions. The implementation uses cache-friendly blocked
556 /// algorithms for large matrices and direct computation for small matrices.
557 ///
558 /// # Arguments
559 ///
560 /// * `other` - The right matrix (2D tensor with compatible inner dimensions)
561 ///
562 /// # Returns
563 ///
564 /// A 2D tensor containing the matrix multiplication result
565 ///
566 /// # Implementation Details
567 ///
568 /// - Uses `MatmulConfig::for_dimensions` for optimal kernel selection
569 /// - Dispatches to `kernels::matrix_multiply_blocked` for computation
570 /// - Supports both SIMD and scalar execution paths
571 /// - Memory layout optimized for cache efficiency and SIMD alignment
572 fn matrix_matrix_mult(&self, other: &Tensor) -> Tensor {
573 let m = self.shape().dims[0]; // Result rows
574 let k = self.shape().dims[1]; // Inner dimension
575 let n = other.shape().dims[1]; // Result columns
576
577 assert_eq!(
578 k,
579 other.shape().dims[0],
580 "Inner dimensions must match: {} vs {}",
581 k,
582 other.shape().dims[0]
583 );
584
585 // Ensure both tensors are contiguous for kernel compatibility
586 let self_contiguous = if self.is_contiguous() {
587 self.clone()
588 } else {
589 self.contiguous()
590 };
591 let other_contiguous = if other.is_contiguous() {
592 other.clone()
593 } else {
594 other.contiguous()
595 };
596
597 // Use uninitialized allocation for performance - will be initialized properly
598 let mut result = Tensor::new_uninitialized(vec![m, n]);
599
600 unsafe {
601 let a_ptr = self_contiguous.as_ptr();
602 let b_ptr = other_contiguous.as_ptr();
603 let c_ptr = result.as_mut_ptr();
604
605 // Determine optimal configuration and dispatch
606 let config = MatmulConfig::for_dimensions(m, n, k);
607 kernels::matrix_multiply_blocked(a_ptr, b_ptr, c_ptr, m, n, k, &config);
608 }
609
610 result
611 }
612
613 /// Batched matrix multiplication for higher-dimensional tensors
614 ///
615 /// Performs matrix multiplication on the last two dimensions while broadcasting
616 /// the leading dimensions. This operation supports arbitrary tensor shapes
617 /// with at least 2 dimensions, following NumPy broadcasting rules.
618 ///
619 /// # Arguments
620 ///
621 /// * `other` - The other tensor for batched multiplication (must have at least 2D)
622 ///
623 /// # Returns
624 ///
625 /// A tensor with batched matrix multiplication results, with shape determined
626 /// by broadcasting the batch dimensions and matrix multiplication on the last two
627 ///
628 /// # Implementation Details
629 ///
630 /// - Broadcasts batch dimensions following NumPy right-aligned rules
631 /// - Performs individual matrix multiplications for each batch element
632 /// - Uses `calculate_batch_offset_with_broadcast` for memory offset computation
633 /// - Supports broadcasting of singleton dimensions (size 1) to any size
634 fn batched_matmul(&self, other: &Tensor) -> Tensor {
635 let self_shape = self.shape();
636 let other_shape = other.shape();
637
638 // Ensure both tensors have at least 2 dimensions
639 assert!(
640 self_shape.rank() >= 2,
641 "Batched matmul requires at least 2D tensors"
642 );
643 assert!(
644 other_shape.rank() >= 2,
645 "Batched matmul requires at least 2D tensors"
646 );
647
648 // Get matrix dimensions (last two dimensions)
649 let self_m = self_shape.dims[self_shape.rank() - 2];
650 let self_k = self_shape.dims[self_shape.rank() - 1];
651 let other_k = other_shape.dims[other_shape.rank() - 2];
652 let other_n = other_shape.dims[other_shape.rank() - 1];
653
654 assert_eq!(
655 self_k, other_k,
656 "Inner dimensions must match for batched matmul: {} vs {}",
657 self_k, other_k
658 );
659
660 // Calculate output shape by broadcasting batch dimensions
661 let mut output_dims = Vec::new();
662 let max_rank = self_shape.rank().max(other_shape.rank());
663
664 // Broadcast batch dimensions (right-aligned)
665 for i in 0..(max_rank - 2) {
666 let self_batch_rank = self_shape.rank() - 2;
667 let other_batch_rank = other_shape.rank() - 2;
668
669 let self_dim = if i < self_batch_rank {
670 self_shape.dims[self_batch_rank - 1 - i]
671 } else {
672 1
673 };
674 let other_dim = if i < other_batch_rank {
675 other_shape.dims[other_batch_rank - 1 - i]
676 } else {
677 1
678 };
679
680 if self_dim == 1 {
681 output_dims.push(other_dim);
682 } else if other_dim == 1 || self_dim == other_dim {
683 output_dims.push(self_dim);
684 } else {
685 panic!("Cannot broadcast dimensions {} and {}", self_dim, other_dim);
686 }
687 }
688
689 // Reverse to get correct order (we built from right to left)
690 output_dims.reverse();
691
692 // Add matrix dimensions
693 output_dims.push(self_m);
694 output_dims.push(other_n);
695
696 // Use uninitialized allocation for performance - result will be fully written
697 let mut result = Tensor::new_uninitialized(output_dims.clone());
698
699 // Calculate total number of matrix multiplications
700 let batch_size: usize = output_dims[..output_dims.len() - 2].iter().product();
701
702 unsafe {
703 // Perform batched matrix multiplication
704 let batch_dims = &output_dims[..output_dims.len() - 2];
705 for batch_idx in 0..batch_size {
706 // Calculate offsets for this batch with broadcasting support
707 let self_offset = self.calculate_batch_offset_with_broadcast(
708 batch_idx,
709 self_m * self_k,
710 batch_dims,
711 );
712 let other_offset = other.calculate_batch_offset_with_broadcast(
713 batch_idx,
714 other_k * other_n,
715 batch_dims,
716 );
717 let result_offset = batch_idx * self_m * other_n;
718
719 let a_ptr = self.as_ptr().add(self_offset);
720 let b_ptr = other.as_ptr().add(other_offset);
721 let c_ptr = result.as_mut_ptr().add(result_offset);
722
723 // Perform single matrix multiplication with dynamic configuration
724 let config = MatmulConfig::for_dimensions(self_m, other_n, self_k);
725 kernels::matrix_multiply_blocked(
726 a_ptr, b_ptr, c_ptr, self_m, other_n, self_k, &config,
727 );
728 }
729 }
730
731 // Handle PyTorch-compatible shape squeezing
732 // If one operand was 2D, squeeze out the batch dimension from the result
733 let should_squeeze_batch = self_shape.rank() == 2 || other_shape.rank() == 2;
734 if should_squeeze_batch && output_dims.len() > 2 && output_dims[0] == 1 {
735 // Squeeze out the leading dimension of size 1
736 result = result.squeeze(Some(0));
737 }
738
739 result
740 }
741
742 /// Calculate memory offset for batched operations with broadcasting support
743 ///
744 /// Computes the memory offset for a specific batch element, taking into account
745 /// broadcasting rules where singleton dimensions (size 1) are repeated across
746 /// the batch. This enables efficient batched operations with broadcasting.
747 ///
748 /// # Arguments
749 ///
750 /// * `batch_idx` - Linear batch index (0-based)
751 /// * `matrix_size` - Size of each matrix in elements (product of last two dimensions)
752 /// * `output_batch_dims` - Output batch dimensions for reference (leading dimensions)
753 ///
754 /// # Returns
755 ///
756 /// Memory offset in elements for the specified batch index
757 ///
758 /// # Implementation Details
759 ///
760 /// - Converts linear batch index to multi-dimensional coordinates
761 /// - Handles broadcasting by mapping coordinates to actual tensor dimensions
762 /// - Uses stride-based offset calculation for memory efficiency
763 /// - Supports right-aligned broadcasting following NumPy conventions
764 fn calculate_batch_offset_with_broadcast(
765 &self,
766 batch_idx: usize,
767 matrix_size: usize,
768 output_batch_dims: &[usize],
769 ) -> usize {
770 if output_batch_dims.is_empty() {
771 return 0;
772 }
773
774 // Convert linear batch index to multi-dimensional coordinates
775 let mut coords = Vec::new();
776 let mut temp_idx = batch_idx;
777
778 for &dim_size in output_batch_dims.iter().rev() {
779 coords.push(temp_idx % dim_size);
780 temp_idx /= dim_size;
781 }
782 coords.reverse();
783
784 // Calculate actual offset based on this tensor's batch dimensions
785 let self_batch_dims = &self.shape().dims[..self.shape().rank() - 2];
786 let mut offset = 0;
787
788 // Align coordinates from the right (broadcasting is right-aligned)
789 let coord_offset = if output_batch_dims.len() >= self_batch_dims.len() {
790 output_batch_dims.len() - self_batch_dims.len()
791 } else {
792 0
793 };
794
795 // Calculate offset using strides
796 for (i, &self_dim) in self_batch_dims.iter().enumerate() {
797 let coord_idx = coord_offset + i;
798 if coord_idx < coords.len() {
799 let coord = coords[coord_idx];
800 // If this tensor's dimension is 1, we stay at the same position (broadcasting)
801 let actual_coord = if self_dim == 1 { 0 } else { coord % self_dim };
802
803 // Calculate stride for this dimension
804 let mut stride = matrix_size;
805 for &later_dim in self_batch_dims.iter().skip(i + 1) {
806 stride *= later_dim;
807 }
808
809 offset += actual_coord * stride;
810 }
811 }
812
813 offset
814 }
815
816 // ===== SIMD Optimized Implementations =====
817
818 /// AVX2-optimized dot product implementation
819 ///
820 /// Computes dot product using AVX2 SIMD instructions for 8x vectorization.
821 /// Processes 8 elements at a time with horizontal reduction for final sum.
822 ///
823 /// # Safety
824 ///
825 /// Requires AVX2 support and valid pointers with sufficient memory for n elements.
826 /// Memory must be properly aligned for optimal performance.
827 ///
828 /// # Arguments
829 ///
830 /// * `a_ptr` - Pointer to first vector data
831 /// * `b_ptr` - Pointer to second vector data
832 /// * `n` - Number of elements to process
833 ///
834 /// # Returns
835 ///
836 /// Dot product result as f32
837 #[cfg(target_arch = "x86_64")]
838 #[inline]
839 #[target_feature(enable = "avx2")]
840 unsafe fn dot_product_simd_avx2(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
841 let simd_end = n & !7; // Process 8 elements at a time
842 let mut sum_vec = _mm256_setzero_ps();
843
844 // SIMD loop
845 for i in (0..simd_end).step_by(8) {
846 let a_vec = _mm256_loadu_ps(a_ptr.add(i));
847 let b_vec = _mm256_loadu_ps(b_ptr.add(i));
848 let prod = _mm256_mul_ps(a_vec, b_vec);
849 sum_vec = _mm256_add_ps(sum_vec, prod);
850 }
851
852 // Horizontal sum of SIMD register
853 let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
854 let sum_lo = _mm256_castps256_ps128(sum_vec);
855 let sum_quad = _mm_add_ps(sum_hi, sum_lo);
856 let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
857 let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
858 let mut result = _mm_cvtss_f32(sum_single);
859
860 // Handle remaining elements
861 for i in simd_end..n {
862 result += *a_ptr.add(i) * *b_ptr.add(i);
863 }
864
865 result
866 }
867
868 /// Scalar-optimized dot product implementation
869 ///
870 /// Computes dot product using scalar operations with 4x loop unrolling for
871 /// better instruction-level parallelism. Provides fallback for non-SIMD hardware.
872 ///
873 /// # Safety
874 ///
875 /// Requires valid pointers with sufficient memory for n elements.
876 ///
877 /// # Arguments
878 ///
879 /// * `a_ptr` - Pointer to first vector data
880 /// * `b_ptr` - Pointer to second vector data
881 /// * `n` - Number of elements to process
882 ///
883 /// # Returns
884 ///
885 /// Dot product result as f32
886 #[inline]
887 unsafe fn dot_product_scalar(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
888 let mut sum = 0.0f32;
889 let unroll_end = n & !3; // Process 4 elements at a time
890
891 // Unrolled loop for better instruction-level parallelism
892 for i in (0..unroll_end).step_by(4) {
893 sum += *a_ptr.add(i) * *b_ptr.add(i);
894 sum += *a_ptr.add(i + 1) * *b_ptr.add(i + 1);
895 sum += *a_ptr.add(i + 2) * *b_ptr.add(i + 2);
896 sum += *a_ptr.add(i + 3) * *b_ptr.add(i + 3);
897 }
898
899 // Handle remaining elements
900 for i in unroll_end..n {
901 sum += *a_ptr.add(i) * *b_ptr.add(i);
902 }
903
904 sum
905 }
906
    /// AVX2-optimized vector-matrix column dot product
    ///
    /// Computes dot product between a vector and a specific matrix column using
    /// AVX2 SIMD instructions. Optimized for column-wise access patterns.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns.
    ///
    /// # Arguments
    ///
    /// * `v_ptr` - Pointer to vector data
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_len` - Length of vector (must match matrix rows)
    /// * `m_cols` - Number of columns in matrix
    /// * `col` - Column index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn vector_matrix_column_simd_avx2(
        &self,
        v_ptr: *const f32,
        m_ptr: *const f32,
        v_len: usize,
        m_cols: usize,
        col: usize,
    ) -> f32 {
        // Round v_len down to a multiple of 8 (AVX2 holds 8 f32 lanes).
        let simd_end = v_len & !7;
        let mut sum_vec = _mm256_setzero_ps();

        // Process 8 elements at a time with optimized gather
        for i in (0..simd_end).step_by(8) {
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));

            // Column elements are strided m_cols apart in row-major memory,
            // so a contiguous vector load is impossible; gather them with
            // eight scalar loads instead.
            let m0 = *m_ptr.add(i * m_cols + col);
            let m1 = *m_ptr.add((i + 1) * m_cols + col);
            let m2 = *m_ptr.add((i + 2) * m_cols + col);
            let m3 = *m_ptr.add((i + 3) * m_cols + col);
            let m4 = *m_ptr.add((i + 4) * m_cols + col);
            let m5 = *m_ptr.add((i + 5) * m_cols + col);
            let m6 = *m_ptr.add((i + 6) * m_cols + col);
            let m7 = *m_ptr.add((i + 7) * m_cols + col);

            // _mm256_set_ps takes arguments highest-lane first, so pass
            // m7..m0 to line the lanes up with the contiguous v_vec load.
            let m_vec = _mm256_set_ps(m7, m6, m5, m4, m3, m2, m1, m0);

            let prod = _mm256_mul_ps(v_vec, m_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: add high and low 128-bit halves, then two hadd
        // steps collapse the remaining four lanes 4 -> 2 -> 1.
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Handle remaining elements (tail shorter than one SIMD register)
        for i in simd_end..v_len {
            result += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
        }

        result
    }
977
978 /// Scalar vector-matrix column dot product
979 ///
980 /// Computes dot product between a vector and a specific matrix column using
981 /// scalar operations. Provides fallback for non-SIMD hardware.
982 ///
983 /// # Safety
984 ///
985 /// Requires valid pointers with sufficient memory.
986 /// Matrix must be in row-major layout with m_cols columns.
987 ///
988 /// # Arguments
989 ///
990 /// * `v_ptr` - Pointer to vector data
991 /// * `m_ptr` - Pointer to matrix data (row-major layout)
992 /// * `v_len` - Length of vector (must match matrix rows)
993 /// * `m_cols` - Number of columns in matrix
994 /// * `col` - Column index to compute dot product with
995 ///
996 /// # Returns
997 ///
998 /// Dot product result as f32
999 #[inline]
1000 unsafe fn vector_matrix_column_scalar(
1001 &self,
1002 v_ptr: *const f32,
1003 m_ptr: *const f32,
1004 v_len: usize,
1005 m_cols: usize,
1006 col: usize,
1007 ) -> f32 {
1008 let mut sum = 0.0f32;
1009 for i in 0..v_len {
1010 sum += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
1011 }
1012 sum
1013 }
1014
    /// AVX2-optimized matrix-vector row dot product
    ///
    /// Computes dot product between a specific matrix row and a vector using
    /// AVX2 SIMD instructions. Optimized for row-wise access patterns.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns.
    ///
    /// # Arguments
    ///
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_ptr` - Pointer to vector data
    /// * `m_cols` - Number of columns in matrix (must match vector length)
    /// * `row` - Row index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn matrix_vector_row_simd_avx2(
        &self,
        m_ptr: *const f32,
        v_ptr: *const f32,
        m_cols: usize,
        row: usize,
    ) -> f32 {
        // Round m_cols down to a multiple of 8 (AVX2 holds 8 f32 lanes).
        let simd_end = m_cols & !7;
        let mut sum_vec = _mm256_setzero_ps();
        // Rows are contiguous in row-major layout, so the whole row can be
        // streamed with unaligned vector loads from this base pointer.
        let row_ptr = m_ptr.add(row * m_cols);

        // Multiply-accumulate 8 elements per iteration.
        for i in (0..simd_end).step_by(8) {
            let m_vec = _mm256_loadu_ps(row_ptr.add(i));
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));
            let prod = _mm256_mul_ps(m_vec, v_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: add high and low 128-bit halves, then two hadd
        // steps collapse the remaining four lanes 4 -> 2 -> 1.
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Handle remaining elements (tail shorter than one SIMD register)
        for i in simd_end..m_cols {
            result += *row_ptr.add(i) * *v_ptr.add(i);
        }

        result
    }
1071
1072 /// Scalar matrix-vector row dot product
1073 ///
1074 /// Computes dot product between a specific matrix row and a vector using
1075 /// scalar operations. Provides fallback for non-SIMD hardware.
1076 ///
1077 /// # Safety
1078 ///
1079 /// Requires valid pointers with sufficient memory.
1080 /// Matrix must be in row-major layout with m_cols columns.
1081 ///
1082 /// # Arguments
1083 ///
1084 /// * `m_ptr` - Pointer to matrix data (row-major layout)
1085 /// * `v_ptr` - Pointer to vector data
1086 /// * `m_cols` - Number of columns in matrix (must match vector length)
1087 /// * `row` - Row index to compute dot product with
1088 ///
1089 /// # Returns
1090 ///
1091 /// Dot product result as f32
1092 #[inline]
1093 unsafe fn matrix_vector_row_scalar(
1094 &self,
1095 m_ptr: *const f32,
1096 v_ptr: *const f32,
1097 m_cols: usize,
1098 row: usize,
1099 ) -> f32 {
1100 let mut sum = 0.0f32;
1101 let row_ptr = m_ptr.add(row * m_cols);
1102 for i in 0..m_cols {
1103 sum += *row_ptr.add(i) * *v_ptr.add(i);
1104 }
1105 sum
1106 }
1107}
1108
#[cfg(test)]
mod tests {
    //! Matrix multiplication operation tests
    //!
    //! This module contains comprehensive tests for matrix multiplication operations,
    //! including basic functionality, gradient computation for 2D and broadcast
    //! batched operations, and the linear-layer usage pattern. Expected gradient
    //! values are derived analytically from the fixed input data.

    use super::*;

    /// Test basic 2x2 matrix multiplication functionality
    ///
    /// Verifies that the matmul operation correctly computes the product of two 2x2 matrices
    /// and produces the expected numerical results. This test validates the core matrix
    /// multiplication algorithm and result shape computation.
    #[test]
    fn test_matmul_2d_basic() {
        // Test basic 2x2 matrix multiplication
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = a.matmul(&b);

        assert_eq!(result.shape().dims, vec![2, 2]);

        // Expected result: [[19, 22], [43, 50]]
        unsafe {
            let ptr = result.as_ptr();
            assert_eq!(*ptr.add(0), 19.0); // (0,0)
            assert_eq!(*ptr.add(1), 22.0); // (0,1)
            assert_eq!(*ptr.add(2), 43.0); // (1,0)
            assert_eq!(*ptr.add(3), 50.0); // (1,1)
        }
    }

    /// Test 2D @ 2D matmul gradient computation (matrix @ matrix)
    #[test]
    fn test_matmul_2d_2d_gradients() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        let mut result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
        assert_eq!(result.shape().dims, vec![2, 2]);

        // Expected result: [[19, 22], [43, 50]]
        let expected = [19.0, 22.0, 43.0, 50.0];
        unsafe {
            let ptr = result.as_ptr();
            for (i, val) in expected.iter().enumerate().take(4) {
                assert_eq!(*ptr.add(i), *val);
            }
        }

        // Set up gradient for backward pass
        let grad_output = Tensor::from_slice(&[1.0, 1.0, 1.0, 1.0], vec![2, 2]).unwrap();
        result.backward(Some(grad_output));

        let grad_a = a.grad_by_value().unwrap();
        let grad_b = b.grad_by_value().unwrap();

        assert_eq!(grad_a.shape().dims, vec![2, 2]);
        assert_eq!(grad_b.shape().dims, vec![2, 2]);

        // grad_a = grad_output @ b^T = [[1, 1], [1, 1]] @ [[5, 7], [6, 8]] = [[11, 15], [11, 15]]
        unsafe {
            let grad_a_ptr = grad_a.as_ptr();
            assert_eq!(*grad_a_ptr.add(0), 11.0); // 1*5 + 1*6
            assert_eq!(*grad_a_ptr.add(1), 15.0); // 1*7 + 1*8
            assert_eq!(*grad_a_ptr.add(2), 11.0); // 1*5 + 1*6
            assert_eq!(*grad_a_ptr.add(3), 15.0); // 1*7 + 1*8
        }

        // grad_b = a^T @ grad_output = [[1, 3], [2, 4]] @ [[1, 1], [1, 1]] = [[4, 4], [6, 6]]
        unsafe {
            let grad_b_ptr = grad_b.as_ptr();
            assert_eq!(*grad_b_ptr.add(0), 4.0); // 1*1 + 3*1
            assert_eq!(*grad_b_ptr.add(1), 4.0); // 1*1 + 3*1
            assert_eq!(*grad_b_ptr.add(2), 6.0); // 2*1 + 4*1
            assert_eq!(*grad_b_ptr.add(3), 6.0); // 2*1 + 4*1
        }
    }

    /// Test matmul gradient computation with partial requires_grad
    #[test]
    fn test_matmul_partial_requires_grad() {
        // Test case where only one operand requires gradients (like the linear layer case)
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // No requires_grad
        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad(); // Only b requires gradients

        let mut result = a.matmul(&b); // [3] @ [3, 2] -> [2]
        assert_eq!(result.shape().dims, vec![2]);

        result.backward(None);

        // Only b should have gradients
        assert!(a.grad_by_value().is_none());
        let grad_b = b.grad_by_value().unwrap();

        assert_eq!(grad_b.shape().dims, vec![3, 2]);

        // grad_b = outer_product(a, grad_output)
        // Since grad_output defaults to ones([2]), grad_b[i,j] = a[i] * 1.0 = a[i]
        unsafe {
            let grad_b_ptr = grad_b.as_ptr();
            assert_eq!(*grad_b_ptr.add(0), 1.0); // a[0] * grad_output[0]
            assert_eq!(*grad_b_ptr.add(1), 1.0); // a[0] * grad_output[1]
            assert_eq!(*grad_b_ptr.add(2), 2.0); // a[1] * grad_output[0]
            assert_eq!(*grad_b_ptr.add(3), 2.0); // a[1] * grad_output[1]
            assert_eq!(*grad_b_ptr.add(4), 3.0); // a[2] * grad_output[0]
            assert_eq!(*grad_b_ptr.add(5), 3.0); // a[2] * grad_output[1]
        }
    }

    /// Test gradients for broadcast batched matmul: [1, 3, 4] @ [2, 4, 5]
    ///
    /// The left operand is broadcast across the right operand's batch dimension,
    /// so its gradient must accumulate contributions from *both* batches. With a
    /// grad_output of all ones:
    ///   grad_left[0][i][k]  = sum over batches b and columns j of right[b][k][j]
    ///   grad_right[b][k][j] = sum over rows i of left[0][i][k]
    /// The expected values below are computed analytically from the fill
    /// formulas (left[i] = 0.1*i + 1.0, right[i] = 0.2*i + 0.5); grad_left[0]
    /// = 29.0 also matches the LibTorch reference for this case.
    #[test]
    fn test_debug_gradient_values() {
        let left_shape = vec![1, 3, 4];
        let right_shape = vec![2, 4, 5];

        let mut left = Tensor::zeros(left_shape.clone()).with_requires_grad();
        let mut right = Tensor::zeros(right_shape.clone()).with_requires_grad();

        let left_size = left_shape.iter().product::<usize>();
        let right_size = right_shape.iter().product::<usize>();

        // Deterministic fill matching the validation harness.
        unsafe {
            for i in 0..left_size {
                *left.as_mut_ptr().add(i) = (i as f32) * 0.1 + 1.0;
            }
            for i in 0..right_size {
                *right.as_mut_ptr().add(i) = (i as f32) * 0.2 + 0.5;
            }
        }

        // Forward pass: broadcast batched matmul -> [2, 3, 5]
        let mut result = left.matmul(&right);
        assert_eq!(result.shape().dims, vec![2, 3, 5]);

        // Backward pass with ones
        let grad_ones = Tensor::ones(result.shape().dims.clone());
        result.backward(Some(grad_ones));

        let grad_left = left.grad_by_value().unwrap();
        let grad_right = right.grad_by_value().unwrap();

        assert_eq!(grad_left.shape().dims, vec![1, 3, 4]);
        assert_eq!(grad_right.shape().dims, vec![2, 4, 5]);

        // grad_left[0][i][k] = sum_{b,j} right[b][k][j]; independent of i.
        // k = 0: 4.5 + 24.5 = 29.0; k = 1: 39.0; k = 2: 49.0; k = 3: 59.0.
        let expected_left_row = [29.0f32, 39.0, 49.0, 59.0];
        for i in 0..3 {
            for (k, expected) in expected_left_row.iter().enumerate() {
                let got = grad_left.data()[i * 4 + k];
                assert!(
                    (got - expected).abs() < 1e-3,
                    "grad_left[0][{}][{}] = {}, expected {}",
                    i,
                    k,
                    got,
                    expected
                );
            }
        }

        // grad_right[b][k][j] = sum_i left[0][i][k]; independent of b and j.
        // k = 0: 1.0 + 1.4 + 1.8 = 4.2; k = 1: 4.5; k = 2: 4.8; k = 3: 5.1.
        let expected_right_col = [4.2f32, 4.5, 4.8, 5.1];
        for b in 0..2 {
            for (k, expected) in expected_right_col.iter().enumerate() {
                for j in 0..5 {
                    let got = grad_right.data()[b * 20 + k * 5 + j];
                    assert!(
                        (got - expected).abs() < 1e-3,
                        "grad_right[{}][{}][{}] = {}, expected {}",
                        b,
                        k,
                        j,
                        got,
                        expected
                    );
                }
            }
        }
    }

    /// Test forward values and gradients for simple batched matmul [2, 2, 2] @ [2, 2, 2]
    ///
    /// Expected values are computed by hand:
    ///   forward batch 0 = [[3.5, 5], [7.5, 11]], batch 1 = [[33.5, 39], [45.5, 53]]
    ///   grad_left  = grad_output @ right^T per batch
    ///   grad_right = left^T @ grad_output per batch
    #[test]
    fn test_simple_batched_gradient() {
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], vec![2, 2, 2])
            .unwrap()
            .with_requires_grad();

        let mut result = left.matmul(&right);
        assert_eq!(result.shape().dims, vec![2, 2, 2]);

        // Forward values (all exactly representable in f32).
        let expected_result = [3.5f32, 5.0, 7.5, 11.0, 33.5, 39.0, 45.5, 53.0];
        for (i, expected) in expected_result.iter().enumerate() {
            let got = result.data()[i];
            assert!(
                (got - expected).abs() < 1e-5,
                "result[{}] = {}, expected {}",
                i,
                got,
                expected
            );
        }

        let grad_ones = Tensor::ones(result.shape().dims.clone());
        result.backward(Some(grad_ones));

        let grad_left = left.grad_by_value().unwrap();
        let grad_right = right.grad_by_value().unwrap();

        // grad_left rows: batch 0 = [0.5+1.0, 1.5+2.0] = [1.5, 3.5],
        //                 batch 1 = [2.5+3.0, 3.5+4.0] = [5.5, 7.5]
        let expected_grad_left = [1.5f32, 3.5, 1.5, 3.5, 5.5, 7.5, 5.5, 7.5];
        // grad_right: batch 0 = [[4, 4], [6, 6]], batch 1 = [[12, 12], [14, 14]]
        let expected_grad_right = [4.0f32, 4.0, 6.0, 6.0, 12.0, 12.0, 14.0, 14.0];
        for i in 0..8 {
            assert!(
                (grad_left.data()[i] - expected_grad_left[i]).abs() < 1e-5,
                "grad_left[{}] = {}, expected {}",
                i,
                grad_left.data()[i],
                expected_grad_left[i]
            );
            assert!(
                (grad_right.data()[i] - expected_grad_right[i]).abs() < 1e-5,
                "grad_right[{}] = {}, expected {}",
                i,
                grad_right.data()[i],
                expected_grad_right[i]
            );
        }
    }

    #[test]
    fn test_linear_layer_pattern() {
        // Simulate the exact pattern from the training loop
        let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // Input (no grad)
        let weight = Tensor::from_slice(&[0.1, 0.5, 0.3, 0.1, 0.5, 0.3], vec![3, 2])
            .unwrap()
            .with_requires_grad(); // Weight (requires grad)
        let bias = Tensor::from_slice(&[0.0, 0.1], vec![2])
            .unwrap()
            .with_requires_grad(); // Bias (requires grad)

        // Forward pass
        let weighted = x_data.matmul(&weight); // [3] @ [3, 2] -> [2]
        let y_pred = weighted.add_tensor(&bias); // [2] + [2] -> [2]

        // Create a simple loss (sum of squared differences with some target)
        let y_true = Tensor::from_slice(&[3.0, 5.0], vec![2]).unwrap();
        let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Check that gradients are computed correctly
        let grad_weight = weight.grad_by_value().unwrap();
        let grad_bias = bias.grad_by_value().unwrap();

        assert_eq!(grad_weight.shape().dims, vec![3, 2]); // Same shape as weight
        assert_eq!(grad_bias.shape().dims, vec![2]); // Same shape as bias

        // The exact gradient values depend on the computation graph, but shapes should be correct
        assert_eq!(grad_weight.size(), 6);
        assert_eq!(grad_bias.size(), 2);

        // Verify that no gradient is computed for x_data (doesn't require grad)
        assert!(x_data.grad_by_value().is_none());
    }
}