train_station/tensor/ops/matmul/
mod.rs

1//! Matrix multiplication operations with optimized kernels
2//!
3//! This module provides a comprehensive matrix multiplication implementation optimized
4//! for single-threaded performance with SIMD acceleration. The implementation supports
5//! all NumPy-style matrix multiplication patterns including 1D/2D/ND tensor operations
6//! with automatic differentiation support.
7//!
8//! # Key Features
9//!
10//! - **SIMD Optimization**: AVX2 implementations for x86_64 architectures
11//! - **Intelligent Dispatch**: Dynamic kernel selection based on matrix dimensions
12//! - **Cache Optimization**: Blocked algorithms for L1/L2 cache efficiency
13//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
14//! - **GradTrack Integration**: Automatic gradient computation for all operations
15//! - **Thread Safety**: All operations are thread-safe and Send + Sync
16//! - **Mathematical Validation**: High-precision equivalence to LibTorch reference
17//!
18//! # Performance Characteristics
19//!
20//! The implementation uses intelligent dispatch to select optimal kernels based on matrix size:
21//! - **Small matrices (16-64 elements)**: Direct computation with minimal overhead
22//! - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
23//! - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
24//! - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
25//! - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
26//! - **Memory Safety**: Safe memory management with `Tensor::new_uninitialized`
27//!
28//! # Organization
29//!
30//! The matmul module is organized into focused submodules:
31//! - **`config`**: Dynamic configuration and kernel selection based on matrix dimensions
32//! - **`kernels`**: SIMD-optimized computational kernels with ML-specific optimizations
33//!
34//! # Supported Operations
35//!
36//! - **1D @ 1D**: Dot product returning scalar tensor
37//! - **1D @ 2D**: Vector-matrix multiplication (v^T * M)
38//! - **2D @ 1D**: Matrix-vector multiplication (M * v)
39//! - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
40//! - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
41//!
42//! # Examples
43//!
44//! ## Basic Matrix Multiplication
45//!
46//! ```
47//! use train_station::Tensor;
48//!
49//! // 2D matrix multiplication
50//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
51//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
52//! let result = a.matmul(&b); // Uses optimized SIMD kernels
53//!
54//! assert_eq!(result.shape().dims, vec![2, 2]);
55//! assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
56//! ```
57//!
58//! ## Vector-Matrix Multiplication
59//!
60//! ```
61//! use train_station::Tensor;
62//!
63//! // Vector-matrix multiplication
64//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
65//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
66//! let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
67//!
68//! assert_eq!(result.shape().dims, vec![2]);
69//! assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
70//! ```
71//!
72//! ## Matrix-Vector Multiplication
73//!
74//! ```
75//! use train_station::Tensor;
76//!
77//! // Matrix-vector multiplication
78//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
79//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
80//! let result = m.matmul(&v); // [2, 2] @ [2] -> [2]
81//!
82//! assert_eq!(result.shape().dims, vec![2]);
83//! assert_eq!(result.data(), &[5.0, 11.0]); // 1*1+2*2, 3*1+4*2
84//! ```
85//!
86//! ## Dot Product
87//!
88//! ```
89//! use train_station::Tensor;
90//!
91//! // 1D dot product
92//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
93//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
94//! let result = a.matmul(&b); // [3] @ [3] -> scalar
95//!
96//! assert_eq!(result.shape().dims, vec![]); // Scalar tensor
97//! assert_eq!(result.data(), &[32.0]); // 1*4 + 2*5 + 3*6
98//! ```
99//!
100//! ## Batched Matrix Multiplication
101//!
102//! ```
103//! use train_station::Tensor;
104//!
105//! // Batched matrix multiplication
106//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
107//! let b = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0], vec![2, 2]).unwrap();
108//! let result = a.matmul(&b); // [2, 2, 2] @ [2, 2] -> [2, 2, 2]
109//!
110//! assert_eq!(result.shape().dims, vec![2, 2, 2]);
111//! ```
112//!
113//! ## Gradient Tracking
114//!
115//! ```
116//! use train_station::Tensor;
117//!
118//! // Matrix multiplication with gradient tracking
119//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
120//!     .unwrap()
121//!     .with_requires_grad();
122//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
123//!     .unwrap()
124//!     .with_requires_grad();
125//!
126//! let result = a.matmul(&b);
127//! assert!(result.requires_grad());
128//! assert_eq!(result.shape().dims, vec![2, 2]);
129//! ```
130//!
131//! # Automatic Differentiation
132//!
133//! All operations support automatic differentiation when either operand requires gradients.
134//! Gradient computation follows PyTorch semantics with proper accumulation and chain rule
135//! application through the gradtrack engine.
136//!
137//! # Thread Safety
138//!
139//! All operations are thread-safe and can be used concurrently across multiple threads.
140//! The implementation uses immutable tensor references and thread-local gradtrack state.
141//!
142//! # Mathematical Validation
143//!
144//! All operations are validated against LibTorch reference implementation with high-precision
145//! numerical equivalence (target: 0.00e0 error tolerance, practical: 1e-6 tolerance for
146//! floating-point precision differences).
147
148use crate::tensor::core::Tensor;
149
150pub mod config;
151pub mod kernels;
152
153// Re-export public types
154pub use config::MatmulConfig;
155
156// SIMD optimizations for performance-critical operations
157#[cfg(target_arch = "x86_64")]
158use std::arch::x86_64::*;
159
160impl Tensor {
161    /// Matrix multiplication operation following NumPy semantics
162    ///
163    /// Performs matrix multiplication between this tensor and another tensor with intelligent
164    /// kernel selection based on matrix dimensions and hardware capabilities. The operation
165    /// follows broadcasting rules and supports all common matrix multiplication patterns
166    /// found in machine learning workloads.
167    ///
168    /// # Supported Operations
169    ///
170    /// - **1D @ 1D**: Dot product returning scalar tensor
171    /// - **1D @ 2D**: Vector-matrix multiplication (v^T * M) returning 1D tensor
172    /// - **2D @ 1D**: Matrix-vector multiplication (M * v) returning 1D tensor
173    /// - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
174    /// - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
175    ///
176    /// # Performance Characteristics
177    ///
178    /// The implementation automatically selects optimal kernels based on matrix dimensions:
179    /// - **Small matrices (<64 elements)**: Direct computation with minimal overhead
180    /// - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
181    /// - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
182    /// - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
183    /// - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
184    ///
185    /// # Automatic Differentiation
186    ///
187    /// This operation supports automatic differentiation when either operand requires gradients.
188    /// Gradient computation follows PyTorch semantics with proper accumulation and chain rule
189    /// application through the gradtrack engine. Gradients are computed for both operands when
190    /// `requires_grad` is set.
191    ///
192    /// # Arguments
193    ///
194    /// * `other` - The tensor to multiply with (must have compatible dimensions)
195    ///
196    /// # Returns
197    ///
198    /// A new tensor containing the matrix multiplication result with appropriate shape
199    /// determined by broadcasting rules and matrix multiplication semantics
200    ///
201    /// # Panics
202    ///
203    /// Panics if the inner dimensions don't match for matrix multiplication:
204    /// - For 2D @ 2D: `self.shape()[1] != other.shape()[0]`
205    /// - For 1D @ 2D: `self.shape()[0] != other.shape()[0]`
206    /// - For 2D @ 1D: `self.shape()[1] != other.shape()[0]`
207    /// - For ND @ ND: Last two dimensions must be compatible for matrix multiplication
208    ///
209    /// # Examples
210    ///
211    /// ```
212    /// use train_station::Tensor;
213    ///
214    /// // 2D matrix multiplication
215    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
216    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
217    /// let result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
218    ///
219    /// assert_eq!(result.shape().dims, vec![2, 2]);
220    /// assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
221    /// ```
222    ///
223    /// ## Vector-Matrix Multiplication
224    ///
225    /// ```
226    /// use train_station::Tensor;
227    ///
228    /// let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
229    /// let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
230    /// let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
231    ///
232    /// assert_eq!(result.shape().dims, vec![2]);
233    /// assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
234    /// ```
235    ///
236    /// ## Gradient Tracking
237    ///
238    /// ```
239    /// use train_station::Tensor;
240    ///
241    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
242    ///     .unwrap()
243    ///     .with_requires_grad();
244    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
245    ///     .unwrap()
246    ///     .with_requires_grad();
247    ///
248    /// let result = a.matmul(&b);
249    /// assert!(result.requires_grad());
250    /// assert_eq!(result.shape().dims, vec![2, 2]);
251    /// ```
252    ///
253    /// # Thread Safety
254    ///
255    /// This operation is thread-safe and can be used concurrently across multiple threads.
256    /// The implementation uses immutable tensor references and thread-local gradtrack state.
257    ///
258    /// # Memory Safety
259    ///
260    /// The implementation uses `Tensor::new_uninitialized` for performance-critical allocations
261    /// and handles memory initialization safely through the kernel system. All unsafe operations
262    /// are validated through comprehensive FFI testing against LibTorch reference implementation.
263    #[track_caller]
264    pub fn matmul(&self, other: &Tensor) -> Tensor {
265        let self_shape = self.shape();
266        let other_shape = other.shape();
267
268        let mut result = match (self_shape.rank(), other_shape.rank()) {
269            (1, 1) => {
270                // 1D @ 1D: dot product -> scalar
271                self.dot_product_1d(other)
272            }
273            (1, 2) => {
274                // 1D @ 2D: vector-matrix multiplication -> 1D
275                self.vector_matrix_mult(other)
276            }
277            (2, 1) => {
278                // 2D @ 1D: matrix-vector multiplication -> 1D
279                self.matrix_vector_mult(other)
280            }
281            (2, 2) => {
282                // 2D @ 2D: standard matrix multiplication -> 2D
283                self.matrix_matrix_mult(other)
284            }
285            _ => {
286                // ND @ ND: batched matrix multiplication
287                self.batched_matmul(other)
288            }
289        };
290
291        // Set up gradtrack if either operand requires gradients
292        if (self.requires_grad() || other.requires_grad()) && crate::gradtrack::is_grad_enabled() {
293            use crate::gradtrack::{GradEngine, GradFn};
294
295            result.set_requires_grad(true);
296            let grad_fn = GradFn::MatMul {
297                left_operand: Box::new(self.clone()),
298                right_operand: Box::new(other.clone()),
299                requires_grad: (self.requires_grad(), other.requires_grad()),
300            };
301            result.set_grad_fn(grad_fn.clone());
302
303            // Register with gradtrack engine for gradient computation
304            // Always register both operands to maintain consistent indexing
305            let input_ids = vec![self.id(), other.id()];
306
307            GradEngine::register_operation(result.id(), input_ids, grad_fn);
308        }
309
310        result
311    }
312
313    /// Dot product of two 1D tensors (returns scalar)
314    ///
315    /// Computes the dot product between two 1D tensors using SIMD-optimized kernels
316    /// when available. The implementation uses AVX2 instructions for 8x vectorization
317    /// with scalar fallbacks for non-SIMD hardware.
318    ///
319    /// # Arguments
320    ///
321    /// * `other` - The other 1D tensor (must have same length as self)
322    ///
323    /// # Returns
324    ///
325    /// A scalar tensor containing the dot product result
326    ///
327    /// # Implementation Details
328    ///
329    /// - Uses `Tensor::new_uninitialized` for performance-critical allocation
330    /// - SIMD path processes 8 elements at a time with horizontal reduction
331    /// - Scalar path uses 4x unrolled loops for instruction-level parallelism
332    /// - Memory is fully written to avoid uninitialized access
333    fn dot_product_1d(&self, other: &Tensor) -> Tensor {
334        assert_eq!(self.shape().rank(), 1, "First tensor must be 1D");
335        assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D");
336        assert_eq!(
337            self.shape().dims[0],
338            other.shape().dims[0],
339            "Tensors must have same length for dot product"
340        );
341
342        let n = self.shape().dims[0];
343
344        // Ensure both tensors are contiguous for kernel compatibility
345        let self_contiguous = if self.is_contiguous() {
346            self.clone()
347        } else {
348            self.contiguous()
349        };
350        let other_contiguous = if other.is_contiguous() {
351            other.clone()
352        } else {
353            other.contiguous()
354        };
355
356        // Use uninitialized allocation for scalar result - memory will be fully written
357        let mut result = Tensor::new_uninitialized(vec![]); // Scalar tensor
358
359        unsafe {
360            let a_ptr = self_contiguous.as_ptr();
361            let b_ptr = other_contiguous.as_ptr();
362            let result_ptr = result.as_mut_ptr();
363
364            #[cfg(target_arch = "x86_64")]
365            {
366                if is_x86_feature_detected!("avx2") {
367                    let dot_product = self.dot_product_simd_avx2(a_ptr, b_ptr, n);
368                    *result_ptr = dot_product;
369                } else {
370                    let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
371                    *result_ptr = dot_product;
372                }
373            }
374
375            #[cfg(not(target_arch = "x86_64"))]
376            {
377                let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
378                *result_ptr = dot_product;
379            }
380        }
381
382        result
383    }
384
385    /// Vector-matrix multiplication: v^T * M
386    ///
387    /// Computes the product of a 1D vector with a 2D matrix, treating the vector
388    /// as a row vector. The implementation uses SIMD-optimized column-wise dot products
389    /// for maximum performance on compatible hardware.
390    ///
391    /// # Arguments
392    ///
393    /// * `other` - The 2D matrix tensor (vector length must match matrix rows)
394    ///
395    /// # Returns
396    ///
397    /// A 1D tensor containing the vector-matrix multiplication result
398    ///
399    /// # Implementation Details
400    ///
401    /// - Computes dot product between vector and each matrix column
402    /// - Uses SIMD kernels for each column when AVX2 is available
403    /// - Scalar fallback processes each column individually
404    /// - Memory layout optimized for column-wise access patterns
405    fn vector_matrix_mult(&self, other: &Tensor) -> Tensor {
406        assert_eq!(self.shape().rank(), 1, "First tensor must be 1D (vector)");
407        assert_eq!(other.shape().rank(), 2, "Second tensor must be 2D (matrix)");
408        assert_eq!(
409            self.shape().dims[0],
410            other.shape().dims[0],
411            "Vector length must match matrix rows"
412        );
413
414        let v_len = self.shape().dims[0];
415        let m_cols = other.shape().dims[1];
416
417        // Ensure both tensors are contiguous for kernel compatibility
418        let self_contiguous = if self.is_contiguous() {
419            self.clone()
420        } else {
421            self.contiguous()
422        };
423        let other_contiguous = if other.is_contiguous() {
424            other.clone()
425        } else {
426            other.contiguous()
427        };
428
429        // Use uninitialized allocation for performance - result will be fully written
430        let mut result = Tensor::new_uninitialized(vec![m_cols]);
431
432        unsafe {
433            let v_ptr = self_contiguous.as_ptr();
434            let m_ptr = other_contiguous.as_ptr();
435            let result_ptr = result.as_mut_ptr();
436
437            #[cfg(target_arch = "x86_64")]
438            {
439                if is_x86_feature_detected!("avx2") {
440                    // Use SIMD for each column
441                    for col in 0..m_cols {
442                        let dot_product =
443                            self.vector_matrix_column_simd_avx2(v_ptr, m_ptr, v_len, m_cols, col);
444                        *result_ptr.add(col) = dot_product;
445                    }
446                } else {
447                    // Use scalar for each column
448                    for col in 0..m_cols {
449                        let dot_product =
450                            self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
451                        *result_ptr.add(col) = dot_product;
452                    }
453                }
454            }
455
456            #[cfg(not(target_arch = "x86_64"))]
457            {
458                // Use scalar for each column
459                for col in 0..m_cols {
460                    let dot_product =
461                        self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
462                    *result_ptr.add(col) = dot_product;
463                }
464            }
465        }
466
467        result
468    }
469
470    /// Matrix-vector multiplication: M * v
471    ///
472    /// Computes the product of a 2D matrix with a 1D vector, treating the vector
473    /// as a column vector. The implementation uses SIMD-optimized row-wise dot products
474    /// for maximum performance on compatible hardware.
475    ///
476    /// # Arguments
477    ///
478    /// * `other` - The 1D vector tensor (matrix columns must match vector length)
479    ///
480    /// # Returns
481    ///
482    /// A 1D tensor containing the matrix-vector multiplication result
483    ///
484    /// # Implementation Details
485    ///
486    /// - Computes dot product between each matrix row and the vector
487    /// - Uses SIMD kernels for each row when AVX2 is available
488    /// - Scalar fallback processes each row individually
489    /// - Memory layout optimized for row-wise access patterns
490    fn matrix_vector_mult(&self, other: &Tensor) -> Tensor {
491        assert_eq!(self.shape().rank(), 2, "First tensor must be 2D (matrix)");
492        assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D (vector)");
493        assert_eq!(
494            self.shape().dims[1],
495            other.shape().dims[0],
496            "Matrix columns must match vector length"
497        );
498
499        let m_rows = self.shape().dims[0];
500        let m_cols = self.shape().dims[1];
501
502        // Ensure both tensors are contiguous for kernel compatibility
503        let self_contiguous = if self.is_contiguous() {
504            self.clone()
505        } else {
506            self.contiguous()
507        };
508        let other_contiguous = if other.is_contiguous() {
509            other.clone()
510        } else {
511            other.contiguous()
512        };
513
514        // Use uninitialized allocation for performance - result will be fully written
515        let mut result = Tensor::new_uninitialized(vec![m_rows]);
516
517        unsafe {
518            let m_ptr = self_contiguous.as_ptr();
519            let v_ptr = other_contiguous.as_ptr();
520            let result_ptr = result.as_mut_ptr();
521
522            #[cfg(target_arch = "x86_64")]
523            {
524                if is_x86_feature_detected!("avx2") {
525                    // Use SIMD for each row
526                    for row in 0..m_rows {
527                        let dot_product =
528                            self.matrix_vector_row_simd_avx2(m_ptr, v_ptr, m_cols, row);
529                        *result_ptr.add(row) = dot_product;
530                    }
531                } else {
532                    // Use scalar for each row
533                    for row in 0..m_rows {
534                        let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
535                        *result_ptr.add(row) = dot_product;
536                    }
537                }
538            }
539
540            #[cfg(not(target_arch = "x86_64"))]
541            {
542                // Use scalar for each row
543                for row in 0..m_rows {
544                    let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
545                    *result_ptr.add(row) = dot_product;
546                }
547            }
548        }
549
550        result
551    }
552
553    /// Standard matrix-matrix multiplication (2D @ 2D)
554    ///
555    /// Computes the product of two 2D matrices using intelligent kernel selection
556    /// based on matrix dimensions. The implementation uses cache-friendly blocked
557    /// algorithms for large matrices and direct computation for small matrices.
558    ///
559    /// # Arguments
560    ///
561    /// * `other` - The right matrix (2D tensor with compatible inner dimensions)
562    ///
563    /// # Returns
564    ///
565    /// A 2D tensor containing the matrix multiplication result
566    ///
567    /// # Implementation Details
568    ///
569    /// - Uses `MatmulConfig::for_dimensions` for optimal kernel selection
570    /// - Dispatches to `kernels::matrix_multiply_blocked` for computation
571    /// - Supports both SIMD and scalar execution paths
572    /// - Memory layout optimized for cache efficiency and SIMD alignment
573    fn matrix_matrix_mult(&self, other: &Tensor) -> Tensor {
574        let m = self.shape().dims[0]; // Result rows
575        let k = self.shape().dims[1]; // Inner dimension
576        let n = other.shape().dims[1]; // Result columns
577
578        assert_eq!(
579            k,
580            other.shape().dims[0],
581            "Inner dimensions must match: {} vs {}",
582            k,
583            other.shape().dims[0]
584        );
585
586        // Ensure both tensors are contiguous for kernel compatibility
587        let self_contiguous = if self.is_contiguous() {
588            self.clone()
589        } else {
590            self.contiguous()
591        };
592        let other_contiguous = if other.is_contiguous() {
593            other.clone()
594        } else {
595            other.contiguous()
596        };
597
598        // Use uninitialized allocation for performance - will be initialized properly
599        let mut result = Tensor::new_uninitialized(vec![m, n]);
600
601        unsafe {
602            let a_ptr = self_contiguous.as_ptr();
603            let b_ptr = other_contiguous.as_ptr();
604            let c_ptr = result.as_mut_ptr();
605
606            // Determine optimal configuration and dispatch
607            let config = MatmulConfig::for_dimensions(m, n, k);
608            kernels::matrix_multiply_blocked(a_ptr, b_ptr, c_ptr, m, n, k, &config);
609        }
610
611        result
612    }
613
614    /// Batched matrix multiplication for higher-dimensional tensors
615    ///
616    /// Performs matrix multiplication on the last two dimensions while broadcasting
617    /// the leading dimensions. This operation supports arbitrary tensor shapes
618    /// with at least 2 dimensions, following NumPy broadcasting rules.
619    ///
620    /// # Arguments
621    ///
622    /// * `other` - The other tensor for batched multiplication (must have at least 2D)
623    ///
624    /// # Returns
625    ///
626    /// A tensor with batched matrix multiplication results, with shape determined
627    /// by broadcasting the batch dimensions and matrix multiplication on the last two
628    ///
629    /// # Implementation Details
630    ///
631    /// - Broadcasts batch dimensions following NumPy right-aligned rules
632    /// - Performs individual matrix multiplications for each batch element
633    /// - Uses `calculate_batch_offset_with_broadcast` for memory offset computation
634    /// - Supports broadcasting of singleton dimensions (size 1) to any size
635    fn batched_matmul(&self, other: &Tensor) -> Tensor {
636        let self_shape = self.shape();
637        let other_shape = other.shape();
638
639        // Ensure both tensors have at least 2 dimensions
640        assert!(
641            self_shape.rank() >= 2,
642            "Batched matmul requires at least 2D tensors"
643        );
644        assert!(
645            other_shape.rank() >= 2,
646            "Batched matmul requires at least 2D tensors"
647        );
648
649        // Get matrix dimensions (last two dimensions)
650        let self_m = self_shape.dims[self_shape.rank() - 2];
651        let self_k = self_shape.dims[self_shape.rank() - 1];
652        let other_k = other_shape.dims[other_shape.rank() - 2];
653        let other_n = other_shape.dims[other_shape.rank() - 1];
654
655        assert_eq!(
656            self_k, other_k,
657            "Inner dimensions must match for batched matmul: {} vs {}",
658            self_k, other_k
659        );
660
661        // Calculate output shape by broadcasting batch dimensions
662        let mut output_dims = Vec::new();
663        let max_rank = self_shape.rank().max(other_shape.rank());
664
665        // Broadcast batch dimensions (right-aligned)
666        for i in 0..(max_rank - 2) {
667            let self_batch_rank = self_shape.rank() - 2;
668            let other_batch_rank = other_shape.rank() - 2;
669
670            let self_dim = if i < self_batch_rank {
671                self_shape.dims[self_batch_rank - 1 - i]
672            } else {
673                1
674            };
675            let other_dim = if i < other_batch_rank {
676                other_shape.dims[other_batch_rank - 1 - i]
677            } else {
678                1
679            };
680
681            if self_dim == 1 {
682                output_dims.push(other_dim);
683            } else if other_dim == 1 || self_dim == other_dim {
684                output_dims.push(self_dim);
685            } else {
686                panic!("Cannot broadcast dimensions {} and {}", self_dim, other_dim);
687            }
688        }
689
690        // Reverse to get correct order (we built from right to left)
691        output_dims.reverse();
692
693        // Add matrix dimensions
694        output_dims.push(self_m);
695        output_dims.push(other_n);
696
697        // Use uninitialized allocation for performance - result will be fully written
698        let mut result = Tensor::new_uninitialized(output_dims.clone());
699
700        // Calculate total number of matrix multiplications
701        let batch_size: usize = output_dims[..output_dims.len() - 2].iter().product();
702
703        unsafe {
704            // Perform batched matrix multiplication
705            let batch_dims = &output_dims[..output_dims.len() - 2];
706            for batch_idx in 0..batch_size {
707                // Calculate offsets for this batch with broadcasting support
708                let self_offset = self.calculate_batch_offset_with_broadcast(
709                    batch_idx,
710                    self_m * self_k,
711                    batch_dims,
712                );
713                let other_offset = other.calculate_batch_offset_with_broadcast(
714                    batch_idx,
715                    other_k * other_n,
716                    batch_dims,
717                );
718                let result_offset = batch_idx * self_m * other_n;
719
720                let a_ptr = self.as_ptr().add(self_offset);
721                let b_ptr = other.as_ptr().add(other_offset);
722                let c_ptr = result.as_mut_ptr().add(result_offset);
723
724                // Perform single matrix multiplication with dynamic configuration
725                let config = MatmulConfig::for_dimensions(self_m, other_n, self_k);
726                kernels::matrix_multiply_blocked(
727                    a_ptr, b_ptr, c_ptr, self_m, other_n, self_k, &config,
728                );
729            }
730        }
731
732        // Handle PyTorch-compatible shape squeezing
733        // If one operand was 2D, squeeze out the batch dimension from the result
734        let should_squeeze_batch = self_shape.rank() == 2 || other_shape.rank() == 2;
735        if should_squeeze_batch && output_dims.len() > 2 && output_dims[0] == 1 {
736            // Squeeze out the leading dimension of size 1
737            result = result.squeeze(Some(0));
738        }
739
740        result
741    }
742
743    /// Calculate memory offset for batched operations with broadcasting support
744    ///
745    /// Computes the memory offset for a specific batch element, taking into account
746    /// broadcasting rules where singleton dimensions (size 1) are repeated across
747    /// the batch. This enables efficient batched operations with broadcasting.
748    ///
749    /// # Arguments
750    ///
751    /// * `batch_idx` - Linear batch index (0-based)
752    /// * `matrix_size` - Size of each matrix in elements (product of last two dimensions)
753    /// * `output_batch_dims` - Output batch dimensions for reference (leading dimensions)
754    ///
755    /// # Returns
756    ///
757    /// Memory offset in elements for the specified batch index
758    ///
759    /// # Implementation Details
760    ///
761    /// - Converts linear batch index to multi-dimensional coordinates
762    /// - Handles broadcasting by mapping coordinates to actual tensor dimensions
763    /// - Uses stride-based offset calculation for memory efficiency
764    /// - Supports right-aligned broadcasting following NumPy conventions
765    fn calculate_batch_offset_with_broadcast(
766        &self,
767        batch_idx: usize,
768        matrix_size: usize,
769        output_batch_dims: &[usize],
770    ) -> usize {
771        if output_batch_dims.is_empty() {
772            return 0;
773        }
774
775        // Convert linear batch index to multi-dimensional coordinates
776        let mut coords = Vec::new();
777        let mut temp_idx = batch_idx;
778
779        for &dim_size in output_batch_dims.iter().rev() {
780            coords.push(temp_idx % dim_size);
781            temp_idx /= dim_size;
782        }
783        coords.reverse();
784
785        // Calculate actual offset based on this tensor's batch dimensions
786        let self_batch_dims = &self.shape().dims[..self.shape().rank() - 2];
787        let mut offset = 0;
788
789        // Align coordinates from the right (broadcasting is right-aligned)
790        let coord_offset = if output_batch_dims.len() >= self_batch_dims.len() {
791            output_batch_dims.len() - self_batch_dims.len()
792        } else {
793            0
794        };
795
796        // Calculate offset using strides
797        for (i, &self_dim) in self_batch_dims.iter().enumerate() {
798            let coord_idx = coord_offset + i;
799            if coord_idx < coords.len() {
800                let coord = coords[coord_idx];
801                // If this tensor's dimension is 1, we stay at the same position (broadcasting)
802                let actual_coord = if self_dim == 1 { 0 } else { coord % self_dim };
803
804                // Calculate stride for this dimension
805                let mut stride = matrix_size;
806                for &later_dim in self_batch_dims.iter().skip(i + 1) {
807                    stride *= later_dim;
808                }
809
810                offset += actual_coord * stride;
811            }
812        }
813
814        offset
815    }
816
817    // ===== SIMD Optimized Implementations =====
818
    /// AVX2-optimized dot product implementation
    ///
    /// Computes dot product using AVX2 SIMD instructions for 8x vectorization.
    /// Processes 8 elements at a time with horizontal reduction for final sum.
    ///
    /// NOTE: the SIMD lane-wise accumulation order differs from a sequential
    /// scalar loop, so results may differ from the scalar path by normal f32
    /// rounding error.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory for n elements.
    /// Unaligned loads are used, so alignment affects performance only, not correctness.
    ///
    /// # Arguments
    ///
    /// * `a_ptr` - Pointer to first vector data
    /// * `b_ptr` - Pointer to second vector data
    /// * `n` - Number of elements to process
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn dot_product_simd_avx2(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
        let simd_end = n & !7; // Largest multiple of 8 <= n (one AVX2 register holds 8 f32s)
        let mut sum_vec = _mm256_setzero_ps();

        // SIMD loop: lane-wise multiply, then accumulate 8 products per iteration
        for i in (0..simd_end).step_by(8) {
            let a_vec = _mm256_loadu_ps(a_ptr.add(i));
            let b_vec = _mm256_loadu_ps(b_ptr.add(i));
            let prod = _mm256_mul_ps(a_vec, b_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum of SIMD register: add the high and low 128-bit halves,
        // then reduce 4 -> 2 -> 1 lanes with pairwise horizontal adds
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Scalar tail: handle the 0..=7 elements not covered by the SIMD loop
        for i in simd_end..n {
            result += *a_ptr.add(i) * *b_ptr.add(i);
        }

        result
    }
868
869    /// Scalar-optimized dot product implementation
870    ///
871    /// Computes dot product using scalar operations with 4x loop unrolling for
872    /// better instruction-level parallelism. Provides fallback for non-SIMD hardware.
873    ///
874    /// # Safety
875    ///
876    /// Requires valid pointers with sufficient memory for n elements.
877    ///
878    /// # Arguments
879    ///
880    /// * `a_ptr` - Pointer to first vector data
881    /// * `b_ptr` - Pointer to second vector data
882    /// * `n` - Number of elements to process
883    ///
884    /// # Returns
885    ///
886    /// Dot product result as f32
887    #[inline]
888    unsafe fn dot_product_scalar(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
889        let mut sum = 0.0f32;
890        let unroll_end = n & !3; // Process 4 elements at a time
891
892        // Unrolled loop for better instruction-level parallelism
893        for i in (0..unroll_end).step_by(4) {
894            sum += *a_ptr.add(i) * *b_ptr.add(i);
895            sum += *a_ptr.add(i + 1) * *b_ptr.add(i + 1);
896            sum += *a_ptr.add(i + 2) * *b_ptr.add(i + 2);
897            sum += *a_ptr.add(i + 3) * *b_ptr.add(i + 3);
898        }
899
900        // Handle remaining elements
901        for i in unroll_end..n {
902            sum += *a_ptr.add(i) * *b_ptr.add(i);
903        }
904
905        sum
906    }
907
    /// AVX2-optimized vector-matrix column dot product
    ///
    /// Computes dot product between a vector and a specific matrix column using
    /// AVX2 SIMD instructions. Column elements are strided `m_cols` apart in
    /// row-major storage, so they are gathered with scalar loads and packed into
    /// a register, while the vector side uses contiguous loads.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns.
    ///
    /// # Arguments
    ///
    /// * `v_ptr` - Pointer to vector data
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_len` - Length of vector (must match matrix rows)
    /// * `m_cols` - Number of columns in matrix
    /// * `col` - Column index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn vector_matrix_column_simd_avx2(
        &self,
        v_ptr: *const f32,
        m_ptr: *const f32,
        v_len: usize,
        m_cols: usize,
        col: usize,
    ) -> f32 {
        let simd_end = v_len & !7; // Largest multiple of 8 <= v_len
        let mut sum_vec = _mm256_setzero_ps();

        // Process 8 rows of the column at a time
        for i in (0..simd_end).step_by(8) {
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));

            // Scalar gather: element (row, col) lives at row * m_cols + col
            let m0 = *m_ptr.add(i * m_cols + col);
            let m1 = *m_ptr.add((i + 1) * m_cols + col);
            let m2 = *m_ptr.add((i + 2) * m_cols + col);
            let m3 = *m_ptr.add((i + 3) * m_cols + col);
            let m4 = *m_ptr.add((i + 4) * m_cols + col);
            let m5 = *m_ptr.add((i + 5) * m_cols + col);
            let m6 = *m_ptr.add((i + 6) * m_cols + col);
            let m7 = *m_ptr.add((i + 7) * m_cols + col);

            // _mm256_set_ps takes arguments high-lane-first, so m0 lands in
            // lane 0, matching v_vec's element ordering
            let m_vec = _mm256_set_ps(m7, m6, m5, m4, m3, m2, m1, m0);

            let prod = _mm256_mul_ps(v_vec, m_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: fold 8 accumulator lanes down to one scalar
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Scalar tail for rows not covered by the SIMD loop
        for i in simd_end..v_len {
            result += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
        }

        result
    }
978
979    /// Scalar vector-matrix column dot product
980    ///
981    /// Computes dot product between a vector and a specific matrix column using
982    /// scalar operations. Provides fallback for non-SIMD hardware.
983    ///
984    /// # Safety
985    ///
986    /// Requires valid pointers with sufficient memory.
987    /// Matrix must be in row-major layout with m_cols columns.
988    ///
989    /// # Arguments
990    ///
991    /// * `v_ptr` - Pointer to vector data
992    /// * `m_ptr` - Pointer to matrix data (row-major layout)
993    /// * `v_len` - Length of vector (must match matrix rows)
994    /// * `m_cols` - Number of columns in matrix
995    /// * `col` - Column index to compute dot product with
996    ///
997    /// # Returns
998    ///
999    /// Dot product result as f32
1000    #[inline]
1001    unsafe fn vector_matrix_column_scalar(
1002        &self,
1003        v_ptr: *const f32,
1004        m_ptr: *const f32,
1005        v_len: usize,
1006        m_cols: usize,
1007        col: usize,
1008    ) -> f32 {
1009        let mut sum = 0.0f32;
1010        for i in 0..v_len {
1011            sum += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
1012        }
1013        sum
1014    }
1015
    /// AVX2-optimized matrix-vector row dot product
    ///
    /// Computes dot product between a specific matrix row and a vector using
    /// AVX2 SIMD instructions. Row elements are contiguous in row-major storage,
    /// so both operands use fast contiguous unaligned loads.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns.
    ///
    /// # Arguments
    ///
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_ptr` - Pointer to vector data
    /// * `m_cols` - Number of columns in matrix (must match vector length)
    /// * `row` - Row index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn matrix_vector_row_simd_avx2(
        &self,
        m_ptr: *const f32,
        v_ptr: *const f32,
        m_cols: usize,
        row: usize,
    ) -> f32 {
        let simd_end = m_cols & !7; // Largest multiple of 8 <= m_cols
        let mut sum_vec = _mm256_setzero_ps();
        // Base of the requested row (hoisted out of the loop)
        let row_ptr = m_ptr.add(row * m_cols);

        // Accumulate 8 products per iteration
        for i in (0..simd_end).step_by(8) {
            let m_vec = _mm256_loadu_ps(row_ptr.add(i));
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));
            let prod = _mm256_mul_ps(m_vec, v_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: add high/low 128-bit halves, then 4 -> 2 -> 1 lanes
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Scalar tail for columns not covered by the SIMD loop
        for i in simd_end..m_cols {
            result += *row_ptr.add(i) * *v_ptr.add(i);
        }

        result
    }
1072
1073    /// Scalar matrix-vector row dot product
1074    ///
1075    /// Computes dot product between a specific matrix row and a vector using
1076    /// scalar operations. Provides fallback for non-SIMD hardware.
1077    ///
1078    /// # Safety
1079    ///
1080    /// Requires valid pointers with sufficient memory.
1081    /// Matrix must be in row-major layout with m_cols columns.
1082    ///
1083    /// # Arguments
1084    ///
1085    /// * `m_ptr` - Pointer to matrix data (row-major layout)
1086    /// * `v_ptr` - Pointer to vector data
1087    /// * `m_cols` - Number of columns in matrix (must match vector length)
1088    /// * `row` - Row index to compute dot product with
1089    ///
1090    /// # Returns
1091    ///
1092    /// Dot product result as f32
1093    #[inline]
1094    unsafe fn matrix_vector_row_scalar(
1095        &self,
1096        m_ptr: *const f32,
1097        v_ptr: *const f32,
1098        m_cols: usize,
1099        row: usize,
1100    ) -> f32 {
1101        let mut sum = 0.0f32;
1102        let row_ptr = m_ptr.add(row * m_cols);
1103        for i in 0..m_cols {
1104            sum += *row_ptr.add(i) * *v_ptr.add(i);
1105        }
1106        sum
1107    }
1108}
1109
#[cfg(test)]
mod tests {
    //! Matrix multiplication operation tests
    //!
    //! This module contains comprehensive tests for matrix multiplication operations,
    //! including basic functionality, gradient computation, broadcasting, and the
    //! linear-layer usage pattern. Tests cover all supported operation types and
    //! edge cases.

    use super::*;

    /// Test basic 2x2 matrix multiplication functionality
    ///
    /// Verifies that the matmul operation correctly computes the product of two 2x2 matrices
    /// and produces the expected numerical results. This test validates the core matrix
    /// multiplication algorithm and result shape computation.
    #[test]
    fn test_matmul_2d_basic() {
        // Test basic 2x2 matrix multiplication
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = a.matmul(&b);

        assert_eq!(result.shape().dims, vec![2, 2]);

        // Expected result: [[19, 22], [43, 50]]
        unsafe {
            let ptr = result.as_ptr();
            assert_eq!(*ptr.add(0), 19.0); // (0,0)
            assert_eq!(*ptr.add(1), 22.0); // (0,1)
            assert_eq!(*ptr.add(2), 43.0); // (1,0)
            assert_eq!(*ptr.add(3), 50.0); // (1,1)
        }
    }

    /// Test 2D @ 2D matmul gradient computation (matrix @ matrix)
    #[test]
    fn test_matmul_2d_2d_gradients() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        let mut result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
        assert_eq!(result.shape().dims, vec![2, 2]);

        // Expected result: [[19, 22], [43, 50]]
        let expected = [19.0, 22.0, 43.0, 50.0];
        unsafe {
            let ptr = result.as_ptr();
            for (i, val) in expected.iter().enumerate().take(4) {
                assert_eq!(*ptr.add(i), *val);
            }
        }

        // Set up gradient for backward pass
        let grad_output = Tensor::from_slice(&[1.0, 1.0, 1.0, 1.0], vec![2, 2]).unwrap();
        result.backward(Some(grad_output));

        let grad_a = a.grad_by_value().unwrap();
        let grad_b = b.grad_by_value().unwrap();

        assert_eq!(grad_a.shape().dims, vec![2, 2]);
        assert_eq!(grad_b.shape().dims, vec![2, 2]);

        // grad_a = grad_output @ b^T = [[1, 1], [1, 1]] @ [[5, 7], [6, 8]] = [[11, 15], [11, 15]]
        unsafe {
            let grad_a_ptr = grad_a.as_ptr();
            assert_eq!(*grad_a_ptr.add(0), 11.0); // 1*5 + 1*6
            assert_eq!(*grad_a_ptr.add(1), 15.0); // 1*7 + 1*8
            assert_eq!(*grad_a_ptr.add(2), 11.0); // 1*5 + 1*6
            assert_eq!(*grad_a_ptr.add(3), 15.0); // 1*7 + 1*8
        }

        // grad_b = a^T @ grad_output = [[1, 3], [2, 4]] @ [[1, 1], [1, 1]] = [[4, 4], [6, 6]]
        unsafe {
            let grad_b_ptr = grad_b.as_ptr();
            assert_eq!(*grad_b_ptr.add(0), 4.0); // 1*1 + 3*1
            assert_eq!(*grad_b_ptr.add(1), 4.0); // 1*1 + 3*1
            assert_eq!(*grad_b_ptr.add(2), 6.0); // 2*1 + 4*1
            assert_eq!(*grad_b_ptr.add(3), 6.0); // 2*1 + 4*1
        }
    }

    /// Test matmul gradient computation with partial requires_grad
    #[test]
    fn test_matmul_partial_requires_grad() {
        // Test case where only one operand requires gradients (like the linear layer case)
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // No requires_grad
        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad(); // Only b requires gradients

        let mut result = a.matmul(&b); // [3] @ [3, 2] -> [2]
        assert_eq!(result.shape().dims, vec![2]);

        result.backward(None);

        // Only b should have gradients
        assert!(a.grad_by_value().is_none());
        let grad_b = b.grad_by_value().unwrap();

        assert_eq!(grad_b.shape().dims, vec![3, 2]);

        // grad_b = outer_product(a, grad_output)
        // Since grad_output defaults to ones([2]), grad_b[i,j] = a[i] * 1.0 = a[i]
        unsafe {
            let grad_b_ptr = grad_b.as_ptr();
            assert_eq!(*grad_b_ptr.add(0), 1.0); // a[0] * grad_output[0]
            assert_eq!(*grad_b_ptr.add(1), 1.0); // a[0] * grad_output[1]
            assert_eq!(*grad_b_ptr.add(2), 2.0); // a[1] * grad_output[0]
            assert_eq!(*grad_b_ptr.add(3), 2.0); // a[1] * grad_output[1]
            assert_eq!(*grad_b_ptr.add(4), 3.0); // a[2] * grad_output[0]
            assert_eq!(*grad_b_ptr.add(5), 3.0); // a[2] * grad_output[1]
        }
    }

    /// Regression test for broadcast batched gradients: [1, 3, 4] @ [2, 4, 5]
    ///
    /// The left operand is broadcast across the batch dimension (size 2), so its
    /// gradient must be summed over that dimension. The expected value for
    /// grad_left[0] is 29.0, matching the LibTorch reference cited by the
    /// original debug output (an earlier implementation produced ~41 by reducing
    /// the broadcast dimension incorrectly). Previously this was a println!-only
    /// debug test with no assertions.
    #[test]
    fn test_debug_gradient_values() {
        let left_shape = vec![1, 3, 4];
        let right_shape = vec![2, 4, 5];

        let mut left = Tensor::zeros(left_shape.clone()).with_requires_grad();
        let mut right = Tensor::zeros(right_shape.clone()).with_requires_grad();

        let left_size = left_shape.iter().product::<usize>();
        let right_size = right_shape.iter().product::<usize>();

        // Deterministic fill matching the external validation suite
        unsafe {
            for i in 0..left_size {
                *left.as_mut_ptr().add(i) = (i as f32) * 0.1 + 1.0;
            }
            for i in 0..right_size {
                *right.as_mut_ptr().add(i) = (i as f32) * 0.2 + 0.5;
            }
        }

        // Forward: [1, 3, 4] @ [2, 4, 5] -> [2, 3, 5]
        let mut result = left.matmul(&right);
        assert_eq!(result.shape().dims, vec![2, 3, 5]);

        // Backward with all-ones upstream gradient
        result.backward(Some(Tensor::ones(result.shape().dims.clone())));

        let grad_left = left.grad_by_value().unwrap();
        let grad_right = right.grad_by_value().unwrap();

        // Gradients always match their operand shapes
        assert_eq!(grad_left.shape().dims, vec![1, 3, 4]);
        assert_eq!(grad_right.shape().dims, vec![2, 4, 5]);

        // grad_left[0,0,0] = sum over batches b and columns j of right[b,0,j]:
        //   batch 0, row 0: 0.5 + 0.7 + 0.9 + 1.1 + 1.3 =  4.5
        //   batch 1, row 0: 4.5 + 4.7 + 4.9 + 5.1 + 5.3 = 24.5
        //   total                                        = 29.0
        assert!(
            (grad_left.data()[0] - 29.0).abs() < 1e-3,
            "grad_left[0] = {}, expected 29.0 (LibTorch reference)",
            grad_left.data()[0]
        );

        // grad_right[0,0,0] = sum over rows i of left[0,i,0] = 1.0 + 1.4 + 1.8 = 4.2
        assert!(
            (grad_right.data()[0] - 4.2).abs() < 1e-3,
            "grad_right[0] = {}, expected 4.2",
            grad_right.data()[0]
        );
    }

    /// Test gradients for a simple non-broadcast batched matmul: [2, 2, 2] @ [2, 2, 2]
    ///
    /// Expected values are computed by hand (all exactly representable in f32):
    /// - grad_left[b]  = grad_output[b] @ right[b]^T
    /// - grad_right[b] = left[b]^T @ grad_output[b]
    ///
    /// Previously this was a println!-only debug test with no assertions.
    #[test]
    fn test_simple_batched_gradient() {
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], vec![2, 2, 2])
            .unwrap()
            .with_requires_grad();

        // Smoke-check the transpose view used by the backward pass
        let right_t = right.transpose(1, 2);
        assert_eq!(right_t.shape().dims, vec![2, 2, 2]);

        let mut result = left.matmul(&right);
        assert_eq!(result.shape().dims, vec![2, 2, 2]);

        // Forward results per batch:
        //   batch 0: [[1,2],[3,4]] @ [[0.5,1.0],[1.5,2.0]] = [[3.5, 5.0], [7.5, 11.0]]
        //   batch 1: [[5,6],[7,8]] @ [[2.5,3.0],[3.5,4.0]] = [[33.5, 39.0], [45.5, 53.0]]
        let expected_fwd = [3.5, 5.0, 7.5, 11.0, 33.5, 39.0, 45.5, 53.0];
        for (actual, expected) in result.data().iter().zip(expected_fwd.iter()) {
            assert_eq!(actual, expected);
        }

        result.backward(Some(Tensor::ones(result.shape().dims.clone())));

        let grad_left = left.grad_by_value().unwrap();
        let grad_right = right.grad_by_value().unwrap();

        // grad_left[b] = ones @ right[b]^T: each row holds the row sums of right[b]
        let expected_grad_left = [1.5, 3.5, 1.5, 3.5, 5.5, 7.5, 5.5, 7.5];
        for (actual, expected) in grad_left.data().iter().zip(expected_grad_left.iter()) {
            assert_eq!(actual, expected);
        }

        // grad_right[b] = left[b]^T @ ones: each row holds the column sums of left[b]
        let expected_grad_right = [4.0, 4.0, 6.0, 6.0, 12.0, 12.0, 14.0, 14.0];
        for (actual, expected) in grad_right.data().iter().zip(expected_grad_right.iter()) {
            assert_eq!(actual, expected);
        }
    }

    /// Test the linear-layer pattern: x @ W + b with gradients on W and b only
    #[test]
    fn test_linear_layer_pattern() {
        // Simulate the exact pattern from the training loop
        let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // Input (no grad)
        let weight = Tensor::from_slice(&[0.1, 0.5, 0.3, 0.1, 0.5, 0.3], vec![3, 2])
            .unwrap()
            .with_requires_grad(); // Weight (requires grad)
        let bias = Tensor::from_slice(&[0.0, 0.1], vec![2])
            .unwrap()
            .with_requires_grad(); // Bias (requires grad)

        // Forward pass
        let weighted = x_data.matmul(&weight); // [3] @ [3, 2] -> [2]
        let y_pred = weighted.add_tensor(&bias); // [2] + [2] -> [2]

        // Create a simple loss (sum of squared differences with some target)
        let y_true = Tensor::from_slice(&[3.0, 5.0], vec![2]).unwrap();
        let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Check that gradients are computed correctly
        let grad_weight = weight.grad_by_value().unwrap();
        let grad_bias = bias.grad_by_value().unwrap();

        assert_eq!(grad_weight.shape().dims, vec![3, 2]); // Same shape as weight
        assert_eq!(grad_bias.shape().dims, vec![2]); // Same shape as bias

        // The exact gradient values depend on the computation graph, but shapes should be correct
        assert_eq!(grad_weight.size(), 6);
        assert_eq!(grad_bias.size(), 2);

        // Verify that no gradient is computed for x_data (doesn't require grad)
        assert!(x_data.grad_by_value().is_none());
    }
}