train_station/tensor/ops/matmul/
mod.rs

1//! Matrix multiplication operations with optimized kernels
2//!
3//! This module provides a comprehensive matrix multiplication implementation optimized
4//! for single-threaded performance with SIMD acceleration. The implementation supports
5//! all NumPy-style matrix multiplication patterns including 1D/2D/ND tensor operations
6//! with automatic differentiation support.
7//!
8//! # Key Features
9//!
10//! - **SIMD Optimization**: AVX2 implementations for x86_64 architectures
11//! - **Intelligent Dispatch**: Dynamic kernel selection based on matrix dimensions
12//! - **Cache Optimization**: Blocked algorithms for L1/L2 cache efficiency
13//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
14//! - **GradTrack Integration**: Automatic gradient computation for all operations
15//! - **Thread Safety**: All operations are thread-safe and Send + Sync
16//! - **Mathematical Validation**: High-precision equivalence to LibTorch reference
17//!
18//! # Performance Characteristics
19//!
20//! The implementation uses intelligent dispatch to select optimal kernels based on matrix size:
21//! - **Small matrices (16-64 elements)**: Direct computation with minimal overhead
22//! - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
23//! - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
24//! - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
25//! - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
26//! - **Memory Safety**: Safe memory management with `Tensor::new_uninitialized`
27//!
28//! # Organization
29//!
30//! The matmul module is organized into focused submodules:
31//! - **`config`**: Dynamic configuration and kernel selection based on matrix dimensions
32//! - **`kernels`**: SIMD-optimized computational kernels with ML-specific optimizations
33//!
34//! # Supported Operations
35//!
36//! - **1D @ 1D**: Dot product returning scalar tensor
37//! - **1D @ 2D**: Vector-matrix multiplication (v^T * M)
38//! - **2D @ 1D**: Matrix-vector multiplication (M * v)
39//! - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
40//! - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
41//!
42//! # Examples
43//!
44//! ## Basic Matrix Multiplication
45//!
46//! ```
47//! use train_station::Tensor;
48//!
49//! // 2D matrix multiplication
50//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
51//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
52//! let result = a.matmul(&b); // Uses optimized SIMD kernels
53//!
54//! assert_eq!(result.shape().dims, vec![2, 2]);
55//! assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
56//! ```
57//!
58//! ## Vector-Matrix Multiplication
59//!
60//! ```
61//! use train_station::Tensor;
62//!
63//! // Vector-matrix multiplication
64//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
65//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
66//! let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
67//!
68//! assert_eq!(result.shape().dims, vec![2]);
69//! assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
70//! ```
71//!
72//! ## Matrix-Vector Multiplication
73//!
74//! ```
75//! use train_station::Tensor;
76//!
77//! // Matrix-vector multiplication
78//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
79//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
80//! let result = m.matmul(&v); // [2, 2] @ [2] -> [2]
81//!
82//! assert_eq!(result.shape().dims, vec![2]);
83//! assert_eq!(result.data(), &[5.0, 11.0]); // 1*1+2*2, 3*1+4*2
84//! ```
85//!
86//! ## Dot Product
87//!
88//! ```
89//! use train_station::Tensor;
90//!
91//! // 1D dot product
92//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
93//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
94//! let result = a.matmul(&b); // [3] @ [3] -> scalar
95//!
96//! assert_eq!(result.shape().dims, vec![]); // Scalar tensor
97//! assert_eq!(result.data(), &[32.0]); // 1*4 + 2*5 + 3*6
98//! ```
99//!
100//! ## Batched Matrix Multiplication
101//!
102//! ```
103//! use train_station::Tensor;
104//!
105//! // Batched matrix multiplication
106//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
107//! let b = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0], vec![2, 2]).unwrap();
108//! let result = a.matmul(&b); // [2, 2, 2] @ [2, 2] -> [2, 2, 2]
109//!
110//! assert_eq!(result.shape().dims, vec![2, 2, 2]);
111//! ```
112//!
113//! ## Gradient Tracking
114//!
115//! ```
116//! use train_station::Tensor;
117//!
118//! // Matrix multiplication with gradient tracking
119//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
120//!     .unwrap()
121//!     .with_requires_grad();
122//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
123//!     .unwrap()
124//!     .with_requires_grad();
125//!
126//! let result = a.matmul(&b);
127//! assert!(result.requires_grad());
128//! assert_eq!(result.shape().dims, vec![2, 2]);
129//! ```
130//!
131//! # Automatic Differentiation
132//!
133//! All operations support automatic differentiation when either operand requires gradients.
134//! Gradient computation follows PyTorch semantics with proper accumulation and chain rule
135//! application through the gradtrack engine.
136//!
137//! # Thread Safety
138//!
139//! All operations are thread-safe and can be used concurrently across multiple threads.
140//! The implementation uses immutable tensor references and thread-local gradtrack state.
141//!
142//! # Mathematical Validation
143//!
144//! All operations are validated against LibTorch reference implementation with high-precision
145//! numerical equivalence (target: 0.00e0 error tolerance, practical: 1e-6 tolerance for
146//! floating-point precision differences).
147
148use crate::tensor::core::Tensor;
149
150pub mod config;
151pub mod kernels;
152
153// Re-export public types
154pub use config::MatmulConfig;
155
156// SIMD optimizations for performance-critical operations
157#[cfg(target_arch = "x86_64")]
158use std::arch::x86_64::*;
159
160impl Tensor {
161    /// Matrix multiplication operation following NumPy semantics
162    ///
163    /// Performs matrix multiplication between this tensor and another tensor with intelligent
164    /// kernel selection based on matrix dimensions and hardware capabilities. The operation
165    /// follows broadcasting rules and supports all common matrix multiplication patterns
166    /// found in machine learning workloads.
167    ///
168    /// # Supported Operations
169    ///
170    /// - **1D @ 1D**: Dot product returning scalar tensor
171    /// - **1D @ 2D**: Vector-matrix multiplication (v^T * M) returning 1D tensor
172    /// - **2D @ 1D**: Matrix-vector multiplication (M * v) returning 1D tensor
173    /// - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
174    /// - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
175    ///
176    /// # Performance Characteristics
177    ///
178    /// The implementation automatically selects optimal kernels based on matrix dimensions:
179    /// - **Small matrices (<64 elements)**: Direct computation with minimal overhead
180    /// - **Medium matrices (64-256 elements)**: Cache-optimized blocking for L1/L2 cache
181    /// - **Large matrices (256+ elements)**: Memory bandwidth optimized with hierarchical blocking
182    /// - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
183    /// - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
184    ///
185    /// # Automatic Differentiation
186    ///
187    /// This operation supports automatic differentiation when either operand requires gradients.
188    /// Gradient computation follows PyTorch semantics with proper accumulation and chain rule
189    /// application through the gradtrack engine. Gradients are computed for both operands when
190    /// `requires_grad` is set.
191    ///
192    /// # Arguments
193    ///
194    /// * `other` - The tensor to multiply with (must have compatible dimensions)
195    ///
196    /// # Returns
197    ///
198    /// A new tensor containing the matrix multiplication result with appropriate shape
199    /// determined by broadcasting rules and matrix multiplication semantics
200    ///
201    /// # Panics
202    ///
203    /// Panics if the inner dimensions don't match for matrix multiplication:
204    /// - For 2D @ 2D: `self.shape()[1] != other.shape()[0]`
205    /// - For 1D @ 2D: `self.shape()[0] != other.shape()[0]`
206    /// - For 2D @ 1D: `self.shape()[1] != other.shape()[0]`
207    /// - For ND @ ND: Last two dimensions must be compatible for matrix multiplication
208    ///
209    /// # Examples
210    ///
211    /// ```
212    /// use train_station::Tensor;
213    ///
214    /// // 2D matrix multiplication
215    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
216    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
217    /// let result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
218    ///
219    /// assert_eq!(result.shape().dims, vec![2, 2]);
220    /// assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
221    /// ```
222    ///
223    /// ## Vector-Matrix Multiplication
224    ///
225    /// ```
226    /// use train_station::Tensor;
227    ///
228    /// let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
229    /// let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
230    /// let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
231    ///
232    /// assert_eq!(result.shape().dims, vec![2]);
233    /// assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
234    /// ```
235    ///
236    /// ## Gradient Tracking
237    ///
238    /// ```
239    /// use train_station::Tensor;
240    ///
241    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
242    ///     .unwrap()
243    ///     .with_requires_grad();
244    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
245    ///     .unwrap()
246    ///     .with_requires_grad();
247    ///
248    /// let result = a.matmul(&b);
249    /// assert!(result.requires_grad());
250    /// assert_eq!(result.shape().dims, vec![2, 2]);
251    /// ```
252    ///
253    /// # Thread Safety
254    ///
255    /// This operation is thread-safe and can be used concurrently across multiple threads.
256    /// The implementation uses immutable tensor references and thread-local gradtrack state.
257    ///
258    /// # Memory Safety
259    ///
260    /// The implementation uses `Tensor::new_uninitialized` for performance-critical allocations
261    /// and handles memory initialization safely through the kernel system. All unsafe operations
262    /// are validated through comprehensive FFI testing against LibTorch reference implementation.
263    pub fn matmul(&self, other: &Tensor) -> Tensor {
264        let self_shape = self.shape();
265        let other_shape = other.shape();
266
267        let mut result = match (self_shape.rank(), other_shape.rank()) {
268            (1, 1) => {
269                // 1D @ 1D: dot product -> scalar
270                self.dot_product_1d(other)
271            }
272            (1, 2) => {
273                // 1D @ 2D: vector-matrix multiplication -> 1D
274                self.vector_matrix_mult(other)
275            }
276            (2, 1) => {
277                // 2D @ 1D: matrix-vector multiplication -> 1D
278                self.matrix_vector_mult(other)
279            }
280            (2, 2) => {
281                // 2D @ 2D: standard matrix multiplication -> 2D
282                self.matrix_matrix_mult(other)
283            }
284            _ => {
285                // ND @ ND: batched matrix multiplication
286                self.batched_matmul(other)
287            }
288        };
289
290        // Set up gradtrack if either operand requires gradients
291        if (self.requires_grad() || other.requires_grad()) && crate::gradtrack::is_grad_enabled() {
292            use crate::gradtrack::{GradEngine, GradFn};
293
294            result.set_requires_grad(true);
295            let grad_fn = GradFn::MatMul {
296                left_operand: Box::new(self.clone()),
297                right_operand: Box::new(other.clone()),
298                requires_grad: (self.requires_grad(), other.requires_grad()),
299            };
300            result.set_grad_fn(grad_fn.clone());
301
302            // Register with gradtrack engine for gradient computation
303            // Always register both operands to maintain consistent indexing
304            let input_ids = vec![self.id(), other.id()];
305
306            GradEngine::register_operation(result.id(), input_ids, grad_fn);
307        }
308
309        result
310    }
311
312    /// Dot product of two 1D tensors (returns scalar)
313    ///
314    /// Computes the dot product between two 1D tensors using SIMD-optimized kernels
315    /// when available. The implementation uses AVX2 instructions for 8x vectorization
316    /// with scalar fallbacks for non-SIMD hardware.
317    ///
318    /// # Arguments
319    ///
320    /// * `other` - The other 1D tensor (must have same length as self)
321    ///
322    /// # Returns
323    ///
324    /// A scalar tensor containing the dot product result
325    ///
326    /// # Implementation Details
327    ///
328    /// - Uses `Tensor::new_uninitialized` for performance-critical allocation
329    /// - SIMD path processes 8 elements at a time with horizontal reduction
330    /// - Scalar path uses 4x unrolled loops for instruction-level parallelism
331    /// - Memory is fully written to avoid uninitialized access
332    fn dot_product_1d(&self, other: &Tensor) -> Tensor {
333        assert_eq!(self.shape().rank(), 1, "First tensor must be 1D");
334        assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D");
335        assert_eq!(
336            self.shape().dims[0],
337            other.shape().dims[0],
338            "Tensors must have same length for dot product"
339        );
340
341        let n = self.shape().dims[0];
342
343        // Ensure both tensors are contiguous for kernel compatibility
344        let self_contiguous = if self.is_contiguous() {
345            self.clone()
346        } else {
347            self.contiguous()
348        };
349        let other_contiguous = if other.is_contiguous() {
350            other.clone()
351        } else {
352            other.contiguous()
353        };
354
355        // Use uninitialized allocation for scalar result - memory will be fully written
356        let mut result = Tensor::new_uninitialized(vec![]); // Scalar tensor
357
358        unsafe {
359            let a_ptr = self_contiguous.as_ptr();
360            let b_ptr = other_contiguous.as_ptr();
361            let result_ptr = result.as_mut_ptr();
362
363            #[cfg(target_arch = "x86_64")]
364            {
365                if is_x86_feature_detected!("avx2") {
366                    let dot_product = self.dot_product_simd_avx2(a_ptr, b_ptr, n);
367                    *result_ptr = dot_product;
368                } else {
369                    let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
370                    *result_ptr = dot_product;
371                }
372            }
373
374            #[cfg(not(target_arch = "x86_64"))]
375            {
376                let dot_product = self.dot_product_scalar(a_ptr, b_ptr, n);
377                *result_ptr = dot_product;
378            }
379        }
380
381        result
382    }
383
384    /// Vector-matrix multiplication: v^T * M
385    ///
386    /// Computes the product of a 1D vector with a 2D matrix, treating the vector
387    /// as a row vector. The implementation uses SIMD-optimized column-wise dot products
388    /// for maximum performance on compatible hardware.
389    ///
390    /// # Arguments
391    ///
392    /// * `other` - The 2D matrix tensor (vector length must match matrix rows)
393    ///
394    /// # Returns
395    ///
396    /// A 1D tensor containing the vector-matrix multiplication result
397    ///
398    /// # Implementation Details
399    ///
400    /// - Computes dot product between vector and each matrix column
401    /// - Uses SIMD kernels for each column when AVX2 is available
402    /// - Scalar fallback processes each column individually
403    /// - Memory layout optimized for column-wise access patterns
404    fn vector_matrix_mult(&self, other: &Tensor) -> Tensor {
405        assert_eq!(self.shape().rank(), 1, "First tensor must be 1D (vector)");
406        assert_eq!(other.shape().rank(), 2, "Second tensor must be 2D (matrix)");
407        assert_eq!(
408            self.shape().dims[0],
409            other.shape().dims[0],
410            "Vector length must match matrix rows"
411        );
412
413        let v_len = self.shape().dims[0];
414        let m_cols = other.shape().dims[1];
415
416        // Ensure both tensors are contiguous for kernel compatibility
417        let self_contiguous = if self.is_contiguous() {
418            self.clone()
419        } else {
420            self.contiguous()
421        };
422        let other_contiguous = if other.is_contiguous() {
423            other.clone()
424        } else {
425            other.contiguous()
426        };
427
428        // Use uninitialized allocation for performance - result will be fully written
429        let mut result = Tensor::new_uninitialized(vec![m_cols]);
430
431        unsafe {
432            let v_ptr = self_contiguous.as_ptr();
433            let m_ptr = other_contiguous.as_ptr();
434            let result_ptr = result.as_mut_ptr();
435
436            #[cfg(target_arch = "x86_64")]
437            {
438                if is_x86_feature_detected!("avx2") {
439                    // Use SIMD for each column
440                    for col in 0..m_cols {
441                        let dot_product =
442                            self.vector_matrix_column_simd_avx2(v_ptr, m_ptr, v_len, m_cols, col);
443                        *result_ptr.add(col) = dot_product;
444                    }
445                } else {
446                    // Use scalar for each column
447                    for col in 0..m_cols {
448                        let dot_product =
449                            self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
450                        *result_ptr.add(col) = dot_product;
451                    }
452                }
453            }
454
455            #[cfg(not(target_arch = "x86_64"))]
456            {
457                // Use scalar for each column
458                for col in 0..m_cols {
459                    let dot_product =
460                        self.vector_matrix_column_scalar(v_ptr, m_ptr, v_len, m_cols, col);
461                    *result_ptr.add(col) = dot_product;
462                }
463            }
464        }
465
466        result
467    }
468
469    /// Matrix-vector multiplication: M * v
470    ///
471    /// Computes the product of a 2D matrix with a 1D vector, treating the vector
472    /// as a column vector. The implementation uses SIMD-optimized row-wise dot products
473    /// for maximum performance on compatible hardware.
474    ///
475    /// # Arguments
476    ///
477    /// * `other` - The 1D vector tensor (matrix columns must match vector length)
478    ///
479    /// # Returns
480    ///
481    /// A 1D tensor containing the matrix-vector multiplication result
482    ///
483    /// # Implementation Details
484    ///
485    /// - Computes dot product between each matrix row and the vector
486    /// - Uses SIMD kernels for each row when AVX2 is available
487    /// - Scalar fallback processes each row individually
488    /// - Memory layout optimized for row-wise access patterns
489    fn matrix_vector_mult(&self, other: &Tensor) -> Tensor {
490        assert_eq!(self.shape().rank(), 2, "First tensor must be 2D (matrix)");
491        assert_eq!(other.shape().rank(), 1, "Second tensor must be 1D (vector)");
492        assert_eq!(
493            self.shape().dims[1],
494            other.shape().dims[0],
495            "Matrix columns must match vector length"
496        );
497
498        let m_rows = self.shape().dims[0];
499        let m_cols = self.shape().dims[1];
500
501        // Ensure both tensors are contiguous for kernel compatibility
502        let self_contiguous = if self.is_contiguous() {
503            self.clone()
504        } else {
505            self.contiguous()
506        };
507        let other_contiguous = if other.is_contiguous() {
508            other.clone()
509        } else {
510            other.contiguous()
511        };
512
513        // Use uninitialized allocation for performance - result will be fully written
514        let mut result = Tensor::new_uninitialized(vec![m_rows]);
515
516        unsafe {
517            let m_ptr = self_contiguous.as_ptr();
518            let v_ptr = other_contiguous.as_ptr();
519            let result_ptr = result.as_mut_ptr();
520
521            #[cfg(target_arch = "x86_64")]
522            {
523                if is_x86_feature_detected!("avx2") {
524                    // Use SIMD for each row
525                    for row in 0..m_rows {
526                        let dot_product =
527                            self.matrix_vector_row_simd_avx2(m_ptr, v_ptr, m_cols, row);
528                        *result_ptr.add(row) = dot_product;
529                    }
530                } else {
531                    // Use scalar for each row
532                    for row in 0..m_rows {
533                        let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
534                        *result_ptr.add(row) = dot_product;
535                    }
536                }
537            }
538
539            #[cfg(not(target_arch = "x86_64"))]
540            {
541                // Use scalar for each row
542                for row in 0..m_rows {
543                    let dot_product = self.matrix_vector_row_scalar(m_ptr, v_ptr, m_cols, row);
544                    *result_ptr.add(row) = dot_product;
545                }
546            }
547        }
548
549        result
550    }
551
552    /// Standard matrix-matrix multiplication (2D @ 2D)
553    ///
554    /// Computes the product of two 2D matrices using intelligent kernel selection
555    /// based on matrix dimensions. The implementation uses cache-friendly blocked
556    /// algorithms for large matrices and direct computation for small matrices.
557    ///
558    /// # Arguments
559    ///
560    /// * `other` - The right matrix (2D tensor with compatible inner dimensions)
561    ///
562    /// # Returns
563    ///
564    /// A 2D tensor containing the matrix multiplication result
565    ///
566    /// # Implementation Details
567    ///
568    /// - Uses `MatmulConfig::for_dimensions` for optimal kernel selection
569    /// - Dispatches to `kernels::matrix_multiply_blocked` for computation
570    /// - Supports both SIMD and scalar execution paths
571    /// - Memory layout optimized for cache efficiency and SIMD alignment
572    fn matrix_matrix_mult(&self, other: &Tensor) -> Tensor {
573        let m = self.shape().dims[0]; // Result rows
574        let k = self.shape().dims[1]; // Inner dimension
575        let n = other.shape().dims[1]; // Result columns
576
577        assert_eq!(
578            k,
579            other.shape().dims[0],
580            "Inner dimensions must match: {} vs {}",
581            k,
582            other.shape().dims[0]
583        );
584
585        // Ensure both tensors are contiguous for kernel compatibility
586        let self_contiguous = if self.is_contiguous() {
587            self.clone()
588        } else {
589            self.contiguous()
590        };
591        let other_contiguous = if other.is_contiguous() {
592            other.clone()
593        } else {
594            other.contiguous()
595        };
596
597        // Use uninitialized allocation for performance - will be initialized properly
598        let mut result = Tensor::new_uninitialized(vec![m, n]);
599
600        unsafe {
601            let a_ptr = self_contiguous.as_ptr();
602            let b_ptr = other_contiguous.as_ptr();
603            let c_ptr = result.as_mut_ptr();
604
605            // Determine optimal configuration and dispatch
606            let config = MatmulConfig::for_dimensions(m, n, k);
607            kernels::matrix_multiply_blocked(a_ptr, b_ptr, c_ptr, m, n, k, &config);
608        }
609
610        result
611    }
612
613    /// Batched matrix multiplication for higher-dimensional tensors
614    ///
615    /// Performs matrix multiplication on the last two dimensions while broadcasting
616    /// the leading dimensions. This operation supports arbitrary tensor shapes
617    /// with at least 2 dimensions, following NumPy broadcasting rules.
618    ///
619    /// # Arguments
620    ///
621    /// * `other` - The other tensor for batched multiplication (must have at least 2D)
622    ///
623    /// # Returns
624    ///
625    /// A tensor with batched matrix multiplication results, with shape determined
626    /// by broadcasting the batch dimensions and matrix multiplication on the last two
627    ///
628    /// # Implementation Details
629    ///
630    /// - Broadcasts batch dimensions following NumPy right-aligned rules
631    /// - Performs individual matrix multiplications for each batch element
632    /// - Uses `calculate_batch_offset_with_broadcast` for memory offset computation
633    /// - Supports broadcasting of singleton dimensions (size 1) to any size
634    fn batched_matmul(&self, other: &Tensor) -> Tensor {
635        let self_shape = self.shape();
636        let other_shape = other.shape();
637
638        // Ensure both tensors have at least 2 dimensions
639        assert!(
640            self_shape.rank() >= 2,
641            "Batched matmul requires at least 2D tensors"
642        );
643        assert!(
644            other_shape.rank() >= 2,
645            "Batched matmul requires at least 2D tensors"
646        );
647
648        // Get matrix dimensions (last two dimensions)
649        let self_m = self_shape.dims[self_shape.rank() - 2];
650        let self_k = self_shape.dims[self_shape.rank() - 1];
651        let other_k = other_shape.dims[other_shape.rank() - 2];
652        let other_n = other_shape.dims[other_shape.rank() - 1];
653
654        assert_eq!(
655            self_k, other_k,
656            "Inner dimensions must match for batched matmul: {} vs {}",
657            self_k, other_k
658        );
659
660        // Calculate output shape by broadcasting batch dimensions
661        let mut output_dims = Vec::new();
662        let max_rank = self_shape.rank().max(other_shape.rank());
663
664        // Broadcast batch dimensions (right-aligned)
665        for i in 0..(max_rank - 2) {
666            let self_batch_rank = self_shape.rank() - 2;
667            let other_batch_rank = other_shape.rank() - 2;
668
669            let self_dim = if i < self_batch_rank {
670                self_shape.dims[self_batch_rank - 1 - i]
671            } else {
672                1
673            };
674            let other_dim = if i < other_batch_rank {
675                other_shape.dims[other_batch_rank - 1 - i]
676            } else {
677                1
678            };
679
680            if self_dim == 1 {
681                output_dims.push(other_dim);
682            } else if other_dim == 1 || self_dim == other_dim {
683                output_dims.push(self_dim);
684            } else {
685                panic!("Cannot broadcast dimensions {} and {}", self_dim, other_dim);
686            }
687        }
688
689        // Reverse to get correct order (we built from right to left)
690        output_dims.reverse();
691
692        // Add matrix dimensions
693        output_dims.push(self_m);
694        output_dims.push(other_n);
695
696        // Use uninitialized allocation for performance - result will be fully written
697        let mut result = Tensor::new_uninitialized(output_dims.clone());
698
699        // Calculate total number of matrix multiplications
700        let batch_size: usize = output_dims[..output_dims.len() - 2].iter().product();
701
702        unsafe {
703            // Perform batched matrix multiplication
704            let batch_dims = &output_dims[..output_dims.len() - 2];
705            for batch_idx in 0..batch_size {
706                // Calculate offsets for this batch with broadcasting support
707                let self_offset = self.calculate_batch_offset_with_broadcast(
708                    batch_idx,
709                    self_m * self_k,
710                    batch_dims,
711                );
712                let other_offset = other.calculate_batch_offset_with_broadcast(
713                    batch_idx,
714                    other_k * other_n,
715                    batch_dims,
716                );
717                let result_offset = batch_idx * self_m * other_n;
718
719                let a_ptr = self.as_ptr().add(self_offset);
720                let b_ptr = other.as_ptr().add(other_offset);
721                let c_ptr = result.as_mut_ptr().add(result_offset);
722
723                // Perform single matrix multiplication with dynamic configuration
724                let config = MatmulConfig::for_dimensions(self_m, other_n, self_k);
725                kernels::matrix_multiply_blocked(
726                    a_ptr, b_ptr, c_ptr, self_m, other_n, self_k, &config,
727                );
728            }
729        }
730
731        // Handle PyTorch-compatible shape squeezing
732        // If one operand was 2D, squeeze out the batch dimension from the result
733        let should_squeeze_batch = self_shape.rank() == 2 || other_shape.rank() == 2;
734        if should_squeeze_batch && output_dims.len() > 2 && output_dims[0] == 1 {
735            // Squeeze out the leading dimension of size 1
736            result = result.squeeze(Some(0));
737        }
738
739        result
740    }
741
742    /// Calculate memory offset for batched operations with broadcasting support
743    ///
744    /// Computes the memory offset for a specific batch element, taking into account
745    /// broadcasting rules where singleton dimensions (size 1) are repeated across
746    /// the batch. This enables efficient batched operations with broadcasting.
747    ///
748    /// # Arguments
749    ///
750    /// * `batch_idx` - Linear batch index (0-based)
751    /// * `matrix_size` - Size of each matrix in elements (product of last two dimensions)
752    /// * `output_batch_dims` - Output batch dimensions for reference (leading dimensions)
753    ///
754    /// # Returns
755    ///
756    /// Memory offset in elements for the specified batch index
757    ///
758    /// # Implementation Details
759    ///
760    /// - Converts linear batch index to multi-dimensional coordinates
761    /// - Handles broadcasting by mapping coordinates to actual tensor dimensions
762    /// - Uses stride-based offset calculation for memory efficiency
763    /// - Supports right-aligned broadcasting following NumPy conventions
764    fn calculate_batch_offset_with_broadcast(
765        &self,
766        batch_idx: usize,
767        matrix_size: usize,
768        output_batch_dims: &[usize],
769    ) -> usize {
770        if output_batch_dims.is_empty() {
771            return 0;
772        }
773
774        // Convert linear batch index to multi-dimensional coordinates
775        let mut coords = Vec::new();
776        let mut temp_idx = batch_idx;
777
778        for &dim_size in output_batch_dims.iter().rev() {
779            coords.push(temp_idx % dim_size);
780            temp_idx /= dim_size;
781        }
782        coords.reverse();
783
784        // Calculate actual offset based on this tensor's batch dimensions
785        let self_batch_dims = &self.shape().dims[..self.shape().rank() - 2];
786        let mut offset = 0;
787
788        // Align coordinates from the right (broadcasting is right-aligned)
789        let coord_offset = if output_batch_dims.len() >= self_batch_dims.len() {
790            output_batch_dims.len() - self_batch_dims.len()
791        } else {
792            0
793        };
794
795        // Calculate offset using strides
796        for (i, &self_dim) in self_batch_dims.iter().enumerate() {
797            let coord_idx = coord_offset + i;
798            if coord_idx < coords.len() {
799                let coord = coords[coord_idx];
800                // If this tensor's dimension is 1, we stay at the same position (broadcasting)
801                let actual_coord = if self_dim == 1 { 0 } else { coord % self_dim };
802
803                // Calculate stride for this dimension
804                let mut stride = matrix_size;
805                for &later_dim in self_batch_dims.iter().skip(i + 1) {
806                    stride *= later_dim;
807                }
808
809                offset += actual_coord * stride;
810            }
811        }
812
813        offset
814    }
815
816    // ===== SIMD Optimized Implementations =====
817
    /// AVX2-optimized dot product implementation
    ///
    /// Computes dot product using AVX2 SIMD instructions for 8x vectorization.
    /// Processes 8 elements at a time with horizontal reduction for final sum.
    ///
    /// Note: the SIMD loop maintains 8 interleaved partial sums, so the result
    /// may differ from a sequential scalar sum in the last floating-point bits.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory for n elements.
    /// Unaligned loads (`_mm256_loadu_ps`) are used, so alignment affects only
    /// performance, not correctness.
    ///
    /// # Arguments
    ///
    /// * `a_ptr` - Pointer to first vector data
    /// * `b_ptr` - Pointer to second vector data
    /// * `n` - Number of elements to process
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn dot_product_simd_avx2(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
        let simd_end = n & !7; // Round n down to a multiple of 8 (AVX2 f32 lane count)
        let mut sum_vec = _mm256_setzero_ps();

        // Main SIMD loop: one 8-wide multiply + accumulate per iteration
        for i in (0..simd_end).step_by(8) {
            let a_vec = _mm256_loadu_ps(a_ptr.add(i));
            let b_vec = _mm256_loadu_ps(b_ptr.add(i));
            let prod = _mm256_mul_ps(a_vec, b_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: fold the 256-bit register to 128 bits (high + low
        // halves), then two hadd steps collapse 4 lanes -> 2 -> 1
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Scalar tail for the final n % 8 elements
        for i in simd_end..n {
            result += *a_ptr.add(i) * *b_ptr.add(i);
        }

        result
    }
867
868    /// Scalar-optimized dot product implementation
869    ///
870    /// Computes dot product using scalar operations with 4x loop unrolling for
871    /// better instruction-level parallelism. Provides fallback for non-SIMD hardware.
872    ///
873    /// # Safety
874    ///
875    /// Requires valid pointers with sufficient memory for n elements.
876    ///
877    /// # Arguments
878    ///
879    /// * `a_ptr` - Pointer to first vector data
880    /// * `b_ptr` - Pointer to second vector data
881    /// * `n` - Number of elements to process
882    ///
883    /// # Returns
884    ///
885    /// Dot product result as f32
886    #[inline]
887    unsafe fn dot_product_scalar(&self, a_ptr: *const f32, b_ptr: *const f32, n: usize) -> f32 {
888        let mut sum = 0.0f32;
889        let unroll_end = n & !3; // Process 4 elements at a time
890
891        // Unrolled loop for better instruction-level parallelism
892        for i in (0..unroll_end).step_by(4) {
893            sum += *a_ptr.add(i) * *b_ptr.add(i);
894            sum += *a_ptr.add(i + 1) * *b_ptr.add(i + 1);
895            sum += *a_ptr.add(i + 2) * *b_ptr.add(i + 2);
896            sum += *a_ptr.add(i + 3) * *b_ptr.add(i + 3);
897        }
898
899        // Handle remaining elements
900        for i in unroll_end..n {
901            sum += *a_ptr.add(i) * *b_ptr.add(i);
902        }
903
904        sum
905    }
906
    /// AVX2-optimized vector-matrix column dot product
    ///
    /// Computes dot product between a vector and a specific matrix column using
    /// AVX2 SIMD instructions. Column elements are strided by `m_cols` in a
    /// row-major layout, so they cannot be loaded contiguously; they are
    /// gathered with eight scalar loads and packed into one register.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns, at least v_len
    /// rows, and `col` must be a valid column index.
    ///
    /// # Arguments
    ///
    /// * `v_ptr` - Pointer to vector data
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_len` - Length of vector (must match matrix rows)
    /// * `m_cols` - Number of columns in matrix
    /// * `col` - Column index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn vector_matrix_column_simd_avx2(
        &self,
        v_ptr: *const f32,
        m_ptr: *const f32,
        v_len: usize,
        m_cols: usize,
        col: usize,
    ) -> f32 {
        let simd_end = v_len & !7; // Round v_len down to a multiple of 8
        let mut sum_vec = _mm256_setzero_ps();

        // Process 8 elements at a time with optimized gather
        for i in (0..simd_end).step_by(8) {
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));

            // Scalar gather: column entries are m_cols apart in memory
            let m0 = *m_ptr.add(i * m_cols + col);
            let m1 = *m_ptr.add((i + 1) * m_cols + col);
            let m2 = *m_ptr.add((i + 2) * m_cols + col);
            let m3 = *m_ptr.add((i + 3) * m_cols + col);
            let m4 = *m_ptr.add((i + 4) * m_cols + col);
            let m5 = *m_ptr.add((i + 5) * m_cols + col);
            let m6 = *m_ptr.add((i + 6) * m_cols + col);
            let m7 = *m_ptr.add((i + 7) * m_cols + col);

            // _mm256_set_ps takes arguments high-lane-first, hence m7..m0
            let m_vec = _mm256_set_ps(m7, m6, m5, m4, m3, m2, m1, m0);

            let prod = _mm256_mul_ps(v_vec, m_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: 256-bit -> 128-bit halves, then hadd 4 -> 2 -> 1 lanes
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Scalar tail for the final v_len % 8 elements
        for i in simd_end..v_len {
            result += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
        }

        result
    }
977
978    /// Scalar vector-matrix column dot product
979    ///
980    /// Computes dot product between a vector and a specific matrix column using
981    /// scalar operations. Provides fallback for non-SIMD hardware.
982    ///
983    /// # Safety
984    ///
985    /// Requires valid pointers with sufficient memory.
986    /// Matrix must be in row-major layout with m_cols columns.
987    ///
988    /// # Arguments
989    ///
990    /// * `v_ptr` - Pointer to vector data
991    /// * `m_ptr` - Pointer to matrix data (row-major layout)
992    /// * `v_len` - Length of vector (must match matrix rows)
993    /// * `m_cols` - Number of columns in matrix
994    /// * `col` - Column index to compute dot product with
995    ///
996    /// # Returns
997    ///
998    /// Dot product result as f32
999    #[inline]
1000    unsafe fn vector_matrix_column_scalar(
1001        &self,
1002        v_ptr: *const f32,
1003        m_ptr: *const f32,
1004        v_len: usize,
1005        m_cols: usize,
1006        col: usize,
1007    ) -> f32 {
1008        let mut sum = 0.0f32;
1009        for i in 0..v_len {
1010            sum += *v_ptr.add(i) * *m_ptr.add(i * m_cols + col);
1011        }
1012        sum
1013    }
1014
    /// AVX2-optimized matrix-vector row dot product
    ///
    /// Computes dot product between a specific matrix row and a vector using
    /// AVX2 SIMD instructions. A row is contiguous in row-major layout, so
    /// both operands use straight (unaligned) 8-wide vector loads.
    ///
    /// # Safety
    ///
    /// Requires AVX2 support and valid pointers with sufficient memory.
    /// Matrix must be in row-major layout with m_cols columns, and `row` must
    /// be a valid row index; `v_ptr` must be valid for `m_cols` reads.
    ///
    /// # Arguments
    ///
    /// * `m_ptr` - Pointer to matrix data (row-major layout)
    /// * `v_ptr` - Pointer to vector data
    /// * `m_cols` - Number of columns in matrix (must match vector length)
    /// * `row` - Row index to compute dot product with
    ///
    /// # Returns
    ///
    /// Dot product result as f32
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn matrix_vector_row_simd_avx2(
        &self,
        m_ptr: *const f32,
        v_ptr: *const f32,
        m_cols: usize,
        row: usize,
    ) -> f32 {
        let simd_end = m_cols & !7; // Round m_cols down to a multiple of 8
        let mut sum_vec = _mm256_setzero_ps();
        // Base pointer of the target row (rows are contiguous, m_cols apart)
        let row_ptr = m_ptr.add(row * m_cols);

        // Main SIMD loop: one 8-wide multiply + accumulate per iteration
        for i in (0..simd_end).step_by(8) {
            let m_vec = _mm256_loadu_ps(row_ptr.add(i));
            let v_vec = _mm256_loadu_ps(v_ptr.add(i));
            let prod = _mm256_mul_ps(m_vec, v_vec);
            sum_vec = _mm256_add_ps(sum_vec, prod);
        }

        // Horizontal sum: 256-bit -> 128-bit halves, then hadd 4 -> 2 -> 1 lanes
        let sum_hi = _mm256_extractf128_ps(sum_vec, 1);
        let sum_lo = _mm256_castps256_ps128(sum_vec);
        let sum_quad = _mm_add_ps(sum_hi, sum_lo);
        let sum_dual = _mm_hadd_ps(sum_quad, sum_quad);
        let sum_single = _mm_hadd_ps(sum_dual, sum_dual);
        let mut result = _mm_cvtss_f32(sum_single);

        // Scalar tail for the final m_cols % 8 elements
        for i in simd_end..m_cols {
            result += *row_ptr.add(i) * *v_ptr.add(i);
        }

        result
    }
1071
1072    /// Scalar matrix-vector row dot product
1073    ///
1074    /// Computes dot product between a specific matrix row and a vector using
1075    /// scalar operations. Provides fallback for non-SIMD hardware.
1076    ///
1077    /// # Safety
1078    ///
1079    /// Requires valid pointers with sufficient memory.
1080    /// Matrix must be in row-major layout with m_cols columns.
1081    ///
1082    /// # Arguments
1083    ///
1084    /// * `m_ptr` - Pointer to matrix data (row-major layout)
1085    /// * `v_ptr` - Pointer to vector data
1086    /// * `m_cols` - Number of columns in matrix (must match vector length)
1087    /// * `row` - Row index to compute dot product with
1088    ///
1089    /// # Returns
1090    ///
1091    /// Dot product result as f32
1092    #[inline]
1093    unsafe fn matrix_vector_row_scalar(
1094        &self,
1095        m_ptr: *const f32,
1096        v_ptr: *const f32,
1097        m_cols: usize,
1098        row: usize,
1099    ) -> f32 {
1100        let mut sum = 0.0f32;
1101        let row_ptr = m_ptr.add(row * m_cols);
1102        for i in 0..m_cols {
1103            sum += *row_ptr.add(i) * *v_ptr.add(i);
1104        }
1105        sum
1106    }
1107}
1108
1109#[cfg(test)]
1110mod tests {
1111    //! Matrix multiplication operation tests
1112    //!
1113    //! This module contains comprehensive tests for matrix multiplication operations,
1114    //! including basic functionality, kernel selection, and large matrix handling.
1115    //! Tests cover all supported operation types and edge cases.
1116
1117    use super::*;
1118
1119    /// Test basic 2x2 matrix multiplication functionality
1120    ///
1121    /// Verifies that the matmul operation correctly computes the product of two 2x2 matrices
1122    /// and produces the expected numerical results. This test validates the core matrix
1123    /// multiplication algorithm and result shape computation.
1124    #[test]
1125    fn test_matmul_2d_basic() {
1126        // Test basic 2x2 matrix multiplication
1127        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1128        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
1129        let result = a.matmul(&b);
1130
1131        assert_eq!(result.shape().dims, vec![2, 2]);
1132
1133        // Expected result: [[19, 22], [43, 50]]
1134        unsafe {
1135            let ptr = result.as_ptr();
1136            assert_eq!(*ptr.add(0), 19.0); // (0,0)
1137            assert_eq!(*ptr.add(1), 22.0); // (0,1)
1138            assert_eq!(*ptr.add(2), 43.0); // (1,0)
1139            assert_eq!(*ptr.add(3), 50.0); // (1,1)
1140        }
1141    }
1142
1143    /// Test 2D @ 2D matmul gradient computation (matrix @ matrix)
1144    #[test]
1145    fn test_matmul_2d_2d_gradients() {
1146        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
1147            .unwrap()
1148            .with_requires_grad();
1149        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
1150            .unwrap()
1151            .with_requires_grad();
1152
1153        let mut result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
1154        assert_eq!(result.shape().dims, vec![2, 2]);
1155
1156        // Expected result: [[19, 22], [43, 50]]
1157        let expected = [19.0, 22.0, 43.0, 50.0];
1158        unsafe {
1159            let ptr = result.as_ptr();
1160            for (i, val) in expected.iter().enumerate().take(4) {
1161                assert_eq!(*ptr.add(i), *val);
1162            }
1163        }
1164
1165        // Set up gradient for backward pass
1166        let grad_output = Tensor::from_slice(&[1.0, 1.0, 1.0, 1.0], vec![2, 2]).unwrap();
1167        result.backward(Some(grad_output));
1168
1169        let grad_a = a.grad_by_value().unwrap();
1170        let grad_b = b.grad_by_value().unwrap();
1171
1172        assert_eq!(grad_a.shape().dims, vec![2, 2]);
1173        assert_eq!(grad_b.shape().dims, vec![2, 2]);
1174
1175        // grad_a = grad_output @ b^T = [[1, 1], [1, 1]] @ [[5, 7], [6, 8]] = [[11, 15], [11, 15]]
1176
1177        unsafe {
1178            let grad_a_ptr = grad_a.as_ptr();
1179            assert_eq!(*grad_a_ptr.add(0), 11.0); // 1*5 + 1*6
1180            assert_eq!(*grad_a_ptr.add(1), 15.0); // 1*7 + 1*8
1181            assert_eq!(*grad_a_ptr.add(2), 11.0); // 1*5 + 1*6
1182            assert_eq!(*grad_a_ptr.add(3), 15.0); // 1*7 + 1*8
1183        }
1184
1185        // grad_b = a^T @ grad_output = [[1, 3], [2, 4]] @ [[1, 1], [1, 1]] = [[4, 4], [6, 6]]
1186        unsafe {
1187            let grad_b_ptr = grad_b.as_ptr();
1188            assert_eq!(*grad_b_ptr.add(0), 4.0); // 1*1 + 3*1
1189            assert_eq!(*grad_b_ptr.add(1), 4.0); // 1*1 + 3*1
1190            assert_eq!(*grad_b_ptr.add(2), 6.0); // 2*1 + 4*1
1191            assert_eq!(*grad_b_ptr.add(3), 6.0); // 2*1 + 4*1
1192        }
1193    }
1194
1195    /// Test matmul gradient computation with partial requires_grad
1196    #[test]
1197    fn test_matmul_partial_requires_grad() {
1198        // Test case where only one operand requires gradients (like the linear layer case)
1199        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // No requires_grad
1200        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
1201            .unwrap()
1202            .with_requires_grad(); // Only b requires gradients
1203
1204        let mut result = a.matmul(&b); // [3] @ [3, 2] -> [2]
1205        assert_eq!(result.shape().dims, vec![2]);
1206
1207        result.backward(None);
1208
1209        // Only b should have gradients
1210        assert!(a.grad_by_value().is_none());
1211        let grad_b = b.grad_by_value().unwrap();
1212
1213        assert_eq!(grad_b.shape().dims, vec![3, 2]);
1214
1215        // grad_b = outer_product(a, grad_output)
1216        // Since grad_output defaults to ones([2]), grad_b[i,j] = a[i] * 1.0 = a[i]
1217        unsafe {
1218            let grad_b_ptr = grad_b.as_ptr();
1219            assert_eq!(*grad_b_ptr.add(0), 1.0); // a[0] * grad_output[0]
1220            assert_eq!(*grad_b_ptr.add(1), 1.0); // a[0] * grad_output[1]
1221            assert_eq!(*grad_b_ptr.add(2), 2.0); // a[1] * grad_output[0]
1222            assert_eq!(*grad_b_ptr.add(3), 2.0); // a[1] * grad_output[1]
1223            assert_eq!(*grad_b_ptr.add(4), 3.0); // a[2] * grad_output[0]
1224            assert_eq!(*grad_b_ptr.add(5), 3.0); // a[2] * grad_output[1]
1225        }
1226    }
1227
1228    #[test]
1229    fn test_debug_gradient_values() {
1230        println!("=== Debugging matmul gradient issue ===");
1231
1232        // Test case: [1, 3, 4] @ [2, 4, 5] which should fail with our=41, torch=29
1233        let left_shape = vec![1, 3, 4];
1234        let right_shape = vec![2, 4, 5];
1235
1236        let mut left = Tensor::zeros(left_shape.clone()).with_requires_grad();
1237        let mut right = Tensor::zeros(right_shape.clone()).with_requires_grad();
1238
1239        let left_size = left_shape.iter().product::<usize>();
1240        let right_size = right_shape.iter().product::<usize>();
1241
1242        // Fill with exactly the same data as the validation test
1243        unsafe {
1244            for i in 0..left_size {
1245                *left.as_mut_ptr().add(i) = (i as f32) * 0.1 + 1.0;
1246            }
1247            for i in 0..right_size {
1248                *right.as_mut_ptr().add(i) = (i as f32) * 0.2 + 0.5;
1249            }
1250        }
1251
1252        println!(
1253            "Left shape: {:?}, data: {:?}",
1254            left.shape().dims,
1255            left.data()
1256        );
1257        println!(
1258            "Right shape: {:?}, data: {:?}",
1259            right.shape().dims,
1260            right.data()
1261        );
1262
1263        // Forward pass
1264        let mut result = left.matmul(&right);
1265        println!(
1266            "Result shape: {:?}, data: {:?}",
1267            result.shape().dims,
1268            result.data()
1269        );
1270
1271        // Backward pass with ones
1272        let grad_ones = Tensor::ones(result.shape().dims.clone());
1273        println!(
1274            "Grad ones shape: {:?}, data: {:?}",
1275            grad_ones.shape().dims,
1276            grad_ones.data()
1277        );
1278
1279        result.backward(Some(grad_ones));
1280
1281        let grad_left = left.grad_by_value().unwrap();
1282        let grad_right = right.grad_by_value().unwrap();
1283
1284        println!(
1285            "Left gradient shape: {:?}, data: {:?}",
1286            grad_left.shape().dims,
1287            grad_left.data()
1288        );
1289        println!(
1290            "Right gradient shape: {:?}, data: {:?}",
1291            grad_right.shape().dims,
1292            grad_right.data()
1293        );
1294
1295        println!(
1296            "Left gradient[0] = {} (expected ~29, but we're getting ~41)",
1297            grad_left.data()[0]
1298        );
1299    }
1300
1301    #[test]
1302    fn test_simple_batched_gradient() {
1303        println!("=== Testing simple batched gradient ===");
1304
1305        // Simple case: [2, 2, 2] @ [2, 2, 2]
1306        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2])
1307            .unwrap()
1308            .with_requires_grad();
1309        let right = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], vec![2, 2, 2])
1310            .unwrap()
1311            .with_requires_grad();
1312
1313        println!("Left: {:?}", left.data());
1314        println!("Right: {:?}", right.data());
1315
1316        // Test transpose function first
1317        let right_t = right.transpose(1, 2);
1318        println!("Right transposed: {:?}", right_t.data());
1319        println!("Right transposed contiguous: {:?}", right_t.is_contiguous());
1320        println!("Right transposed strides: {:?}", right_t.strides());
1321
1322        let mut result = left.matmul(&right);
1323        println!("Result: {:?}", result.data());
1324
1325        let grad_ones = Tensor::ones(result.shape().dims.clone());
1326        result.backward(Some(grad_ones));
1327
1328        let grad_left = left.grad_by_value().unwrap();
1329        let grad_right = right.grad_by_value().unwrap();
1330
1331        println!("Left gradient: {:?}", grad_left.data());
1332        println!("Right gradient: {:?}", grad_right.data());
1333
1334        // Manual calculation for verification
1335        println!("\n=== Manual verification ===");
1336        println!("Expected left grad batch 0: [0.5+1.0, 1.5+2.0] = [1.5, 3.5]");
1337        println!("Expected left grad batch 1: [2.5+3.0, 3.5+4.0] = [5.5, 7.5]");
1338    }
1339
1340    #[test]
1341    fn test_linear_layer_pattern() {
1342        // Simulate the exact pattern from the training loop
1343        let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // Input (no grad)
1344        let weight = Tensor::from_slice(&[0.1, 0.5, 0.3, 0.1, 0.5, 0.3], vec![3, 2])
1345            .unwrap()
1346            .with_requires_grad(); // Weight (requires grad)
1347        let bias = Tensor::from_slice(&[0.0, 0.1], vec![2])
1348            .unwrap()
1349            .with_requires_grad(); // Bias (requires grad)
1350
1351        // Forward pass
1352        let weighted = x_data.matmul(&weight); // [3] @ [3, 2] -> [2]
1353        let y_pred = weighted.add_tensor(&bias); // [2] + [2] -> [2]
1354
1355        // Create a simple loss (sum of squared differences with some target)
1356        let y_true = Tensor::from_slice(&[3.0, 5.0], vec![2]).unwrap();
1357        let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();
1358
1359        // Backward pass
1360        loss.backward(None);
1361
1362        // Check that gradients are computed correctly
1363        let grad_weight = weight.grad_by_value().unwrap();
1364        let grad_bias = bias.grad_by_value().unwrap();
1365
1366        assert_eq!(grad_weight.shape().dims, vec![3, 2]); // Same shape as weight
1367        assert_eq!(grad_bias.shape().dims, vec![2]); // Same shape as bias
1368
1369        // The exact gradient values depend on the computation graph, but shapes should be correct
1370        assert_eq!(grad_weight.size(), 6);
1371        assert_eq!(grad_bias.size(), 2);
1372
1373        // Verify that no gradient is computed for x_data (doesn't require grad)
1374        assert!(x_data.grad_by_value().is_none());
1375    }
1376}