train_station/tensor/ops/matmul/
mod.rs

1//! Matrix multiplication operations with optimized kernels
2//!
3//! This module provides a comprehensive matrix multiplication implementation optimized
4//! for single-threaded performance with SIMD acceleration. The implementation supports
5//! all NumPy-style matrix multiplication patterns including 1D/2D/ND tensor operations
6//! with automatic differentiation support.
7//!
8//! # Key Features
9//!
10//! - **SIMD Optimization**: AVX512/AVX2/SSE2 implementations with runtime dispatch
11//! - **Intelligent Dispatch**: Cached kernel selection based on matrix dimensions and alignment
12//! - **Cache Optimization**: Blocked algorithms with panel packing for L1/L2/L3 cache efficiency
13//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
14//! - **GradTrack Integration**: Automatic gradient computation for all operations
15//! - **Thread Safety**: All operations are thread-safe and Send + Sync
16//! - **Mathematical Validation**: High-precision equivalence to PyTorch reference
17//!
18//! # Performance Characteristics
19//!
20//! The implementation uses intelligent dispatch to select optimal kernels based on matrix size (see the sketch after this list):
21//! - **Small matrices (≤1K elements)**: Direct computation with minimal overhead
22//! - **Medium matrices (1K-64K elements)**: Cache-optimized blocking for L1/L2 cache
23//! - **Large matrices (≥64K elements)**: Memory bandwidth optimized with hierarchical blocking
24//! - **AVX512 acceleration**: 16x SIMD operations for compatible hardware
25//! - **AVX2 acceleration**: 8x SIMD operations for compatible hardware
26//! - **SSE2 acceleration**: 4x SIMD operations for compatible hardware
27//! - **Scalar fallbacks**: Optimized scalar implementations for non-SIMD platforms
28//! - **Memory Safety**: SIMD-aligned allocations with unsafe kernels encapsulated behind safe APIs
29//!
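//! The size classes above can be pictured with a small, purely illustrative helper; the
//! function name `classify` and the string labels below are hypothetical and not part of
//! the crate's API:
//!
//! ```ignore
//! // Illustrative only: mirrors the element-count thresholds listed above.
//! fn classify(elements: usize) -> &'static str {
//!     match elements {
//!         0..=1_024 => "small: direct kernel, minimal overhead",
//!         1_025..=65_535 => "medium: L1/L2 cache-blocked kernel",
//!         _ => "large: hierarchical blocking, bandwidth-optimized",
//!     }
//! }
//! ```
//!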
30//! # Architecture
31//!
32//! The matmul system uses a cached kernel dispatch architecture (sketched after this list):
33//! - **`MatMulKernels`**: Cached function pointers for optimal performance
34//! - **Static Dispatch**: Runtime SIMD detection layered over compile-time `target_arch` gating
35//! - **Operation Types**: Specialized kernels for each matmul operation pattern
36//! - **Size Thresholds**: Intelligent kernel selection based on matrix dimensions
37//! - **Alignment Detection**: Optimized paths for aligned vs unaligned data
38//!
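//! A simplified sketch of the dispatch flow (the calls shown are the ones used inside
//! `matmul`; the snippet is illustrative rather than a runnable doctest):
//!
//! ```ignore
//! // Fetch cached function pointers (SIMD level probed once, then reused).
//! let kernels = MatMulKernels::get_cached_kernels();
//! // Classify the operation from the operand shapes:
//! // Dot1D1D, Vec1D2D, Mat2D1D, Mat2D2D, or BatchedND.
//! let op_type = MatMulKernels::classify_operation(self.shape().dims(), other.shape().dims());
//! // The matching kernel (e.g. `dispatch_mat_mat`) is invoked with raw pointers; strided
//! // variants are chosen when either operand is non-contiguous.
//! ```
//!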
39//! # Supported Operations
40//!
41//! - **1D @ 1D**: Dot product returning scalar tensor
42//! - **1D @ 2D**: Vector-matrix multiplication (v^T * M)
43//! - **2D @ 1D**: Matrix-vector multiplication (M * v)
44//! - **2D @ 2D**: Standard matrix multiplication with cache-optimized blocking
45//! - **ND @ ND**: Batched matrix multiplication on last two dimensions with broadcasting
46//!
47//! # Examples
48//!
49//! ## Basic Matrix Multiplication
50//!
51//! ```
52//! use train_station::Tensor;
53//!
54//! // 2D matrix multiplication
55//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
56//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
57//! let result = a.matmul(&b); // Uses optimized SIMD kernels
58//!
59//! assert_eq!(result.shape().dims(), vec![2, 2]);
60//! assert_eq!(result.data(), &[19.0, 22.0, 43.0, 50.0]);
61//! ```
62//!
63//! ## Vector-Matrix Multiplication
64//!
65//! ```
66//! use train_station::Tensor;
67//!
68//! // Vector-matrix multiplication
69//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
70//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
71//! let result = v.matmul(&m); // [2] @ [2, 2] -> [2]
72//!
73//! assert_eq!(result.shape().dims(), vec![2]);
74//! assert_eq!(result.data(), &[7.0, 10.0]); // 1*1+2*3, 1*2+2*4
75//! ```
76//!
77//! ## Matrix-Vector Multiplication
78//!
79//! ```
80//! use train_station::Tensor;
81//!
82//! // Matrix-vector multiplication
83//! let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
84//! let v = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
85//! let result = m.matmul(&v); // [2, 2] @ [2] -> [2]
86//!
87//! assert_eq!(result.shape().dims(), vec![2]);
88//! assert_eq!(result.data(), &[5.0, 11.0]); // 1*1+2*2, 3*1+4*2
89//! ```
90//!
91//! ## Dot Product
92//!
93//! ```
94//! use train_station::Tensor;
95//!
96//! // 1D dot product
97//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
98//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
99//! let result = a.matmul(&b); // [3] @ [3] -> scalar
100//!
101//! assert_eq!(result.shape().dims(), vec![]); // Scalar tensor
102//! assert_eq!(result.data(), &[32.0]); // 1*4 + 2*5 + 3*6
103//! ```
104//!
105//! ## Batched Matrix Multiplication
106//!
107//! ```
108//! use train_station::Tensor;
109//!
110//! // Batched matrix multiplication
111//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
112//! let b = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0], vec![2, 2]).unwrap();
113//! let result = a.matmul(&b); // [2, 2, 2] @ [2, 2] -> [2, 2, 2]
114//!
115//! assert_eq!(result.shape().dims(), vec![2, 2, 2]);
116//! ```
117//!
118//! ## Gradient Tracking
119//!
120//! ```
121//! use train_station::Tensor;
122//!
123//! // Matrix multiplication with gradient tracking
124//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
125//!     .unwrap()
126//!     .with_requires_grad();
127//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
128//!     .unwrap()
129//!     .with_requires_grad();
130//!
131//! let result = a.matmul(&b);
132//! assert!(result.requires_grad());
133//! assert_eq!(result.shape().dims(), vec![2, 2]);
134//! ```
135//!
136//! # Automatic Differentiation
137//!
138//! All operations support automatic differentiation when either operand requires gradients.
139//! Gradient computation follows PyTorch semantics with proper accumulation and chain rule
140//! application through the gradtrack engine.
141//!
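//! For `C = A @ B` with upstream gradient `G`, the gradients are `dA = G @ B^T` and
//! `dB = A^T @ G`. The example below checks this for a small 2x2 case, with `G`
//! defaulting to ones; the same values are asserted in this module's tests:
//!
//! ```
//! use train_station::Tensor;
//!
//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
//!     .unwrap()
//!     .with_requires_grad();
//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
//!     .unwrap()
//!     .with_requires_grad();
//!
//! let mut c = a.matmul(&b);
//! c.backward(None); // upstream gradient defaults to ones
//!
//! // dA = G @ B^T = [[11, 15], [11, 15]]; dB = A^T @ G = [[4, 4], [6, 6]]
//! assert_eq!(a.grad_owned().unwrap().data(), &[11.0, 15.0, 11.0, 15.0]);
//! assert_eq!(b.grad_owned().unwrap().data(), &[4.0, 4.0, 6.0, 6.0]);
//! ```
//!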
142//! # Thread Safety
143//!
144//! All operations are thread-safe and can be used concurrently across multiple threads.
145//! The implementation uses immutable tensor references and thread-local gradtrack state.
146//!
147//! # Mathematical Validation
148//!
149//! All operations are validated against the LibTorch reference implementation for numerical
150//! equivalence (target: exact, 0.0 error; in practice a 1e-6 tolerance absorbs floating-point
151//! rounding differences).
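//!
//! The tolerance idiom used by the tests in this module looks like the following
//! (comparing against hand-computed reference values rather than LibTorch):
//!
//! ```
//! use train_station::Tensor;
//!
//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let c = a.matmul(&b); // [[19, 22], [43, 50]]
//!
//! assert!((c.data()[0] - 19.0).abs() < 1e-6);
//! assert!((c.data()[3] - 50.0).abs() < 1e-6);
//! ```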
152
153use crate::Tensor;
154#[cfg(target_arch = "x86_64")]
155pub mod avx2_kernels;
156#[cfg(target_arch = "x86_64")]
157pub mod avx512_kernels;
158pub mod classification_and_validation;
159pub mod dispatch;
160#[cfg(target_arch = "x86_64")]
161pub mod pack_n_cache;
162pub mod scalar_kernels;
163#[cfg(target_arch = "x86_64")]
164pub mod sse_kernels;
165use dispatch::*;
166
167impl Tensor {
168    /// Matrix multiplication with intelligent kernel dispatch
169    ///
170    /// Performs matrix multiplication using optimized SIMD kernels selected based on:
171    /// - Runtime SIMD capability (AVX512/AVX2/SSE2/Scalar)
172    /// - Matrix operation type (1D@1D, 1D@2D, 2D@1D, 2D@2D, ND@ND)
173    /// - Matrix size classification (Small/Medium/Large)
174    /// - Memory alignment characteristics
175    ///
176    /// # Arguments
177    /// * `other` - Right-hand side tensor for multiplication
178    ///
179    /// # Returns
180    /// Result tensor whose shape follows the operation type: scalar for 1D @ 1D, `[N]` for 1D @ 2D, `[M]` for 2D @ 1D, `[M, N]` for 2D @ 2D, and broadcast batch dimensions plus `[M, N]` for ND @ ND
181    ///
182    /// # Panics
183    /// Panics if tensor shapes are incompatible for matrix multiplication
184    ///
185    /// # Examples
186    /// ```
187    /// use train_station::Tensor;
188    ///
189    /// // 2D matrix multiplication
190    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
191    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
192    /// let result = a.matmul(&b);
193    /// assert_eq!(result.shape().dims(), vec![2, 2]);
194    ///
195    /// // 1D dot product
196    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
197    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
198    /// let result = a.matmul(&b);
199    /// assert_eq!(result.shape().dims(), vec![]); // Scalar
200    /// ```
201    #[track_caller]
202    pub fn matmul(&self, other: &Tensor) -> Tensor {
203        // Get cached kernels for dispatch
204        let kernels = MatMulKernels::get_cached_kernels();
205
206        // Classify operation type based on shapes
207        let left_shape = self.shape().dims();
208        let right_shape = other.shape().dims();
209        let op_type = MatMulKernels::classify_operation(left_shape, right_shape);
210
211        // Validate shapes and compute result shape
212        let result_shape = Self::validate_and_compute_matmul_shape(left_shape, right_shape);
213
214        // Get tensor pointers directly - scalar kernels handle strides
215        let left_ptr = unsafe { self.as_ptr() };
216        let right_ptr = unsafe { other.as_ptr() };
217
218        // Create result tensor
219        let mut result = Tensor::new(result_shape.clone());
220
221        unsafe {
222            let result_ptr = result.as_mut_ptr();
223
224            // Dispatch to appropriate kernel based on operation type
225            match op_type {
226                MatMulOpType::Dot1D1D => {
227                    // 1D @ 1D: Dot product returning scalar
228                    let size = left_shape[0];
229                    if self.is_contiguous() && other.is_contiguous() {
230                        let dot_result = kernels.dispatch_dot_1d(left_ptr, right_ptr, size);
231                        *result_ptr = dot_result;
232                    } else {
233                        let left_stride = self.strides()[0];
234                        let right_stride = other.strides()[0];
235                        let dot_result = kernels.dispatch_dot_1d_strided(
236                            left_ptr,
237                            right_ptr,
238                            size,
239                            left_stride,
240                            right_stride,
241                        );
242                        *result_ptr = dot_result;
243                    }
244                }
245                MatMulOpType::Vec1D2D => {
246                    // 1D @ 2D: Vector-matrix multiplication (v^T * M)
247                    let k = left_shape[0];
248                    let n = right_shape[1];
249                    if self.is_contiguous() && other.is_contiguous() {
250                        kernels.dispatch_vec_mat(left_ptr, right_ptr, result_ptr, k, n);
251                    } else {
252                        let left_stride = self.strides()[0];
253                        let right_strides = other.strides();
254                        let right_row_stride = right_strides[0];
255                        let right_col_stride = right_strides[1];
256                        let result_stride = 1; // Result is always contiguous
257                        kernels.dispatch_vec_mat_strided(
258                            left_ptr,
259                            right_ptr,
260                            result_ptr,
261                            k,
262                            n,
263                            left_stride,
264                            right_row_stride,
265                            right_col_stride,
266                            result_stride,
267                        );
268                    }
269                }
270                MatMulOpType::Mat2D1D => {
271                    // 2D @ 1D: Matrix-vector multiplication (M * v)
272                    let m = left_shape[0];
273                    let k = left_shape[1];
274                    if self.is_contiguous() && other.is_contiguous() {
275                        kernels.dispatch_mat_vec(left_ptr, right_ptr, result_ptr, m, k);
276                    } else {
277                        let left_strides = self.strides();
278                        let left_row_stride = left_strides[0];
279                        let left_col_stride = left_strides[1];
280                        let right_stride = other.strides()[0];
281                        let result_stride = 1; // Result is always contiguous
282                        kernels.dispatch_mat_vec_strided(
283                            left_ptr,
284                            right_ptr,
285                            result_ptr,
286                            m,
287                            k,
288                            left_row_stride,
289                            left_col_stride,
290                            right_stride,
291                            result_stride,
292                        );
293                    }
294                }
295                MatMulOpType::Mat2D2D => {
296                    // 2D @ 2D: Standard matrix multiplication
297                    let m = left_shape[0];
298                    let k = left_shape[1];
299                    let n = right_shape[1];
300                    if self.is_contiguous() && other.is_contiguous() {
301                        kernels.dispatch_mat_mat(left_ptr, right_ptr, result_ptr, m, k, n);
302                    } else {
303                        let left_strides = self.strides();
304                        let left_row_stride = left_strides[0];
305                        let left_col_stride = left_strides[1];
306                        let right_strides = other.strides();
307                        let right_row_stride = right_strides[0];
308                        let right_col_stride = right_strides[1];
309                        let result_strides = result.strides();
310                        let result_row_stride = result_strides[0];
311                        let result_col_stride = result_strides[1];
312                        kernels.dispatch_mat_mat_strided(
313                            left_ptr,
314                            right_ptr,
315                            result_ptr,
316                            m,
317                            k,
318                            n,
319                            left_row_stride,
320                            left_col_stride,
321                            right_row_stride,
322                            right_col_stride,
323                            result_row_stride,
324                            result_col_stride,
325                        );
326                    }
327                }
328                MatMulOpType::BatchedND => {
329                    // ND @ ND: Batched matrix multiplication on last two dimensions
330                    Self::dispatch_batched_matmul_with_ptrs_strided(
331                        left_ptr,
332                        right_ptr,
333                        self,
334                        other,
335                        &mut result,
336                        kernels,
337                    );
338                }
339            }
340        }
341
342        // Set up gradient tracking if needed
343        if (self.requires_grad() || other.requires_grad()) && crate::gradtrack::is_grad_enabled() {
344            result.set_requires_grad_internal(true);
345            let grad_fn = crate::gradtrack::grad_fn::GradFn::MatMul {
346                left_operand: Box::new(self.clone()),
347                right_operand: Box::new(other.clone()),
348                requires_grad: (self.requires_grad(), other.requires_grad()),
349            };
350            result.set_grad_fn(grad_fn.clone());
351
352            // Register operation with gradtrack engine
353            // Always register both operand IDs - the gradient function will handle which ones need gradients
354            let input_ids = vec![self.id(), other.id()];
355            crate::gradtrack::engine::GradEngine::register_operation(
356                result.id(),
357                input_ids,
358                grad_fn,
359            );
360        }
361
362        result
363    }
364}
365
366#[cfg(test)]
367mod tests {
368    //! Matrix multiplication operation tests
369    //!
370    //! This module contains comprehensive tests for matrix multiplication operations,
371    //! including basic functionality, kernel selection, and large matrix handling.
372    //! Tests cover all supported operation types and edge cases.
373
374    use super::*;
375
376    /// Test basic 2x2 matrix multiplication functionality
377    ///
378    /// Verifies that the matmul operation correctly computes the product of two 2x2 matrices
379    /// and produces the expected numerical results. This test validates the core matrix
380    /// multiplication algorithm and result shape computation.
381    #[test]
382    fn test_matmul_2d_basic() {
383        // Test basic 2x2 matrix multiplication
384        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
385        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
386        let result = a.matmul(&b);
387
388        assert_eq!(result.shape().dims(), vec![2, 2]);
389
390        // Expected result: [[19, 22], [43, 50]]
391        unsafe {
392            let ptr = result.as_ptr();
393            assert_eq!(*ptr.add(0), 19.0); // (0,0)
394            assert_eq!(*ptr.add(1), 22.0); // (0,1)
395            assert_eq!(*ptr.add(2), 43.0); // (1,0)
396            assert_eq!(*ptr.add(3), 50.0); // (1,1)
397        }
398    }
399
400    #[test]
401    fn test_tall_skinny_vs_wide_correctness() {
402        // Tall-skinny: [128, 8] @ [8, 16]
403        let mut a_ts = Tensor::new(vec![128, 8]);
404        let mut b_ts = Tensor::new(vec![8, 16]);
405        for (i, v) in a_ts.data_mut().iter_mut().enumerate() {
406            *v = (i as f32 * 0.01).sin();
407        }
408        for (i, v) in b_ts.data_mut().iter_mut().enumerate() {
409            *v = (i as f32 * 0.02).cos();
410        }
411        let r_ts = a_ts.matmul(&b_ts);
412        assert_eq!(r_ts.shape().dims(), vec![128, 16]);
413
414        // Very-wide: [16, 8] @ [8, 256]
415        let mut a_w = Tensor::new(vec![16, 8]);
416        let mut b_w = Tensor::new(vec![8, 256]);
417        for (i, v) in a_w.data_mut().iter_mut().enumerate() {
418            *v = (i as f32 * 0.03).sin();
419        }
420        for (i, v) in b_w.data_mut().iter_mut().enumerate() {
421            *v = (i as f32 * 0.04).cos();
422        }
423        let r_w = a_w.matmul(&b_w);
424        assert_eq!(r_w.shape().dims(), vec![16, 256]);
425    }
426
427    /// Fast path: [1, K] @ [K, N] → [1, N]
428    #[test]
429    fn test_matmul_row_vector_times_matrix() {
430        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
431        let b = Tensor::from_slice(
432            &[
433                1.0, 2.0, 3.0, 4.0, // row-major [K, N]: rows are [1, 2], [3, 4], [5, 6], [7, 8]
434                5.0, 6.0, 7.0, 8.0,
435            ],
436            vec![4, 2],
437        )
438        .unwrap();
439        let result = a.matmul(&b);
440        assert_eq!(result.shape().dims(), vec![1, 2]);
441        unsafe {
442            let p = result.as_ptr();
443            // [1,4] * [4,2]
444            // col0 = 1*1 + 2*3 + 3*5 + 4*7 = 1 + 6 + 15 + 28 = 50
445            // col1 = 1*2 + 2*4 + 3*6 + 4*8 = 2 + 8 + 18 + 32 = 60
446            assert_eq!(*p.add(0), 50.0);
447            assert_eq!(*p.add(1), 60.0);
448        }
449    }
450
451    /// Fast path: [K] @ [K, N] → [N]
452    #[test]
453    fn test_matmul_1d_rowvec_times_matrix() {
454        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
455        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
456        let r = a.matmul(&b);
457        assert_eq!(r.shape().dims(), vec![2]);
458        unsafe {
459            let p = r.as_ptr();
460            // [1,2,3] @ [[1,2],[3,4],[5,6]]
461            // = [1*1+2*3+3*5, 1*2+2*4+3*6] = [22, 28]
462            assert_eq!(*p.add(0), 22.0);
463            assert_eq!(*p.add(1), 28.0);
464        }
465    }
466
467    /// Test 2D @ 2D matmul gradient computation (matrix @ matrix)
468    #[test]
469    fn test_matmul_2d_2d_gradients() {
470        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
471            .unwrap()
472            .with_requires_grad();
473        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
474            .unwrap()
475            .with_requires_grad();
476
477        let mut result = a.matmul(&b); // [2, 2] @ [2, 2] -> [2, 2]
478        assert_eq!(result.shape().dims(), vec![2, 2]);
479
480        // Expected result: [[19, 22], [43, 50]]
481        let expected = [19.0, 22.0, 43.0, 50.0];
482        unsafe {
483            let ptr = result.as_ptr();
484            for (i, val) in expected.iter().enumerate().take(4) {
485                assert_eq!(*ptr.add(i), *val);
486            }
487        }
488
489        // Set up gradient for backward pass
490        let grad_output = Tensor::from_slice(&[1.0, 1.0, 1.0, 1.0], vec![2, 2]).unwrap();
491        result.backward(Some(grad_output));
492
493        let grad_a = a.grad_owned().unwrap();
494        let grad_b = b.grad_owned().unwrap();
495
496        assert_eq!(grad_a.shape().dims(), vec![2, 2]);
497        assert_eq!(grad_b.shape().dims(), vec![2, 2]);
498
499        // grad_a = grad_output @ b^T = [[1, 1], [1, 1]] @ [[5, 7], [6, 8]] = [[11, 15], [11, 15]]
500
501        unsafe {
502            let grad_a_ptr = grad_a.as_ptr();
503            assert_eq!(*grad_a_ptr.add(0), 11.0); // 1*5 + 1*6
504            assert_eq!(*grad_a_ptr.add(1), 15.0); // 1*7 + 1*8
505            assert_eq!(*grad_a_ptr.add(2), 11.0); // 1*5 + 1*6
506            assert_eq!(*grad_a_ptr.add(3), 15.0); // 1*7 + 1*8
507        }
508
509        // grad_b = a^T @ grad_output = [[1, 3], [2, 4]] @ [[1, 1], [1, 1]] = [[4, 4], [6, 6]]
510        unsafe {
511            let grad_b_ptr = grad_b.as_ptr();
512            assert_eq!(*grad_b_ptr.add(0), 4.0); // 1*1 + 3*1
513            assert_eq!(*grad_b_ptr.add(1), 4.0); // 1*1 + 3*1
514            assert_eq!(*grad_b_ptr.add(2), 6.0); // 2*1 + 4*1
515            assert_eq!(*grad_b_ptr.add(3), 6.0); // 2*1 + 4*1
516        }
517    }
518
519    /// Test matmul gradient computation with partial requires_grad
520    #[test]
521    fn test_matmul_partial_requires_grad() {
522        // Test case where only one operand requires gradients (like the linear layer case)
523        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // No requires_grad
524        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
525            .unwrap()
526            .with_requires_grad(); // Only b requires gradients
527
528        let mut result = a.matmul(&b); // [3] @ [3, 2] -> [2]
529        assert_eq!(result.shape().dims(), vec![2]);
530
531        result.backward(None);
532
533        // Only b should have gradients
534        assert!(a.grad_owned().is_none());
535        let grad_b = b.grad_owned().unwrap();
536
537        assert_eq!(grad_b.shape().dims(), vec![3, 2]);
538
539        // grad_b = outer_product(a, grad_output)
540        // Since grad_output defaults to ones([2]), grad_b[i,j] = a[i] * 1.0 = a[i]
541        unsafe {
542            let grad_b_ptr = grad_b.as_ptr();
543            assert_eq!(*grad_b_ptr.add(0), 1.0); // a[0] * grad_output[0]
544            assert_eq!(*grad_b_ptr.add(1), 1.0); // a[0] * grad_output[1]
545            assert_eq!(*grad_b_ptr.add(2), 2.0); // a[1] * grad_output[0]
546            assert_eq!(*grad_b_ptr.add(3), 2.0); // a[1] * grad_output[1]
547            assert_eq!(*grad_b_ptr.add(4), 3.0); // a[2] * grad_output[0]
548            assert_eq!(*grad_b_ptr.add(5), 3.0); // a[2] * grad_output[1]
549        }
550    }
551
552    #[test]
553    fn test_debug_gradient_values() {
554        println!("=== Debugging matmul gradient issue ===");
555
556        // Test case: [1, 3, 4] @ [2, 4, 5] which should fail with our=41, torch=29
557        let left_shape = vec![1, 3, 4];
558        let right_shape = vec![2, 4, 5];
559
560        let mut left = Tensor::zeros(left_shape.clone()).with_requires_grad();
561        let mut right = Tensor::zeros(right_shape.clone()).with_requires_grad();
562
563        let left_size = left_shape.iter().product::<usize>();
564        let right_size = right_shape.iter().product::<usize>();
565
566        // Fill with exactly the same data as the validation test
567        unsafe {
568            for i in 0..left_size {
569                *left.as_mut_ptr().add(i) = (i as f32) * 0.1 + 1.0;
570            }
571            for i in 0..right_size {
572                *right.as_mut_ptr().add(i) = (i as f32) * 0.2 + 0.5;
573            }
574        }
575
576        println!(
577            "Left shape: {:?}, data: {:?}",
578            left.shape().dims(),
579            left.data()
580        );
581        println!(
582            "Right shape: {:?}, data: {:?}",
583            right.shape().dims(),
584            right.data()
585        );
586
587        // Forward pass
588        let mut result = left.matmul(&right);
589        println!(
590            "Result shape: {:?}, data: {:?}",
591            result.shape().dims(),
592            result.data()
593        );
594
595        // Backward pass with ones
596        let grad_ones = Tensor::ones(result.shape().dims().to_vec());
597        println!(
598            "Grad ones shape: {:?}, data: {:?}",
599            grad_ones.shape().dims(),
600            grad_ones.data()
601        );
602
603        result.backward(Some(grad_ones));
604
605        let grad_left = left.grad_owned().unwrap();
606        let grad_right = right.grad_owned().unwrap();
607
608        println!(
609            "Left gradient shape: {:?}, data: {:?}",
610            grad_left.shape().dims(),
611            grad_left.data()
612        );
613        println!(
614            "Right gradient shape: {:?}, data: {:?}",
615            grad_right.shape().dims(),
616            grad_right.data()
617        );
618
619        println!(
620            "Left gradient[0] = {} (expected ~29, but we're getting ~41)",
621            grad_left.data()[0]
622        );
623    }
624
625    #[test]
626    fn test_simple_batched_gradient() {
627        println!("=== Testing simple batched gradient ===");
628
629        // Simple case: [2, 2, 2] @ [2, 2, 2]
630        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2])
631            .unwrap()
632            .with_requires_grad();
633        let right = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], vec![2, 2, 2])
634            .unwrap()
635            .with_requires_grad();
636
637        println!("Left: {:?}", left.data());
638        println!("Right: {:?}", right.data());
639
640        // Test transpose function first
641        let right_t = right.transpose(1, 2);
642        println!("Right transposed: {:?}", right_t.data());
643        println!("Right transposed contiguous: {:?}", right_t.is_contiguous());
644        println!("Right transposed strides: {:?}", right_t.strides());
645
646        let mut result = left.matmul(&right);
647        println!("Result: {:?}", result.data());
648
649        let grad_ones = Tensor::ones(result.shape().dims().to_vec());
650        result.backward(Some(grad_ones));
651
652        let grad_left = left.grad_owned().unwrap();
653        let grad_right = right.grad_owned().unwrap();
654
655        println!("Left gradient: {:?}", grad_left.data());
656        println!("Right gradient: {:?}", grad_right.data());
657
658        // Manual calculation for verification
659        println!("\n=== Manual verification ===");
660        println!("Expected left grad batch 0: [0.5+1.0, 1.5+2.0] = [1.5, 3.5]");
661        println!("Expected left grad batch 1: [2.5+3.0, 3.5+4.0] = [5.5, 7.5]");
662    }
663
664    #[test]
665    fn test_alignment_checks_prevent_misaligned_access() {
666        use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes};
667
668        // Create tensors that may not be aligned
669        let a = Tensor::new(vec![4]);
670        let b = Tensor::new(vec![4]);
671
672        let kernels = MatMulKernels::get_cached_kernels();
673        let simd_alignment = simd_alignment_bytes(detect_runtime_simd());
674
675        unsafe {
676            let a_ptr = a.as_ptr();
677            let b_ptr = b.as_ptr();
678
679            // Check if pointers are aligned
680            let a_aligned = (a_ptr as usize).is_multiple_of(simd_alignment);
681            let b_aligned = (b_ptr as usize).is_multiple_of(simd_alignment);
682
683            // The alignment check should correctly identify alignment status
684            let alignment_check = kernels.check_alignment_for_simd(
685                a_ptr,
686                b_ptr,
687                std::ptr::null_mut(),
688                simd_alignment,
689            );
690
691            // The check should match our manual calculation
692            assert_eq!(alignment_check, a_aligned && b_aligned);
693        }
694    }
695
696    #[test]
697    fn test_avx512_specific_alignment_validation() {
698        use crate::tensor::core::memory::SimdLevel;
699
700        let kernels = MatMulKernels::get_cached_kernels();
701
702        // Test AVX512-specific 64-byte alignment requirements
703        let test_ptr = 0x1004 as *const f32; // 4-byte aligned but not 64-byte aligned
704        let avx512_aligned_ptr = 0x1040 as *const f32; // 64-byte aligned
705
706        // 64-byte alignment check should fail for 4-byte aligned pointer
707        let not_avx512_aligned =
708            kernels.check_alignment_for_simd(test_ptr, test_ptr, std::ptr::null_mut(), 64);
709        assert!(
710            !not_avx512_aligned,
711            "4-byte aligned pointer should not be considered 64-byte aligned"
712        );
713
714        // 64-byte alignment check should pass for properly aligned pointer
715        let is_avx512_aligned = kernels.check_alignment_for_simd(
716            avx512_aligned_ptr,
717            avx512_aligned_ptr,
718            std::ptr::null_mut(),
719            64,
720        );
721        assert!(
722            is_avx512_aligned,
723            "64-byte aligned pointer should pass AVX512 alignment check"
724        );
725
726        // Verify that the SIMD level correctly maps to alignment requirements
727        match kernels.simd_level {
728            #[cfg(target_arch = "x86_64")]
729            SimdLevel::Avx512 => {
730                // Should require 64-byte alignment
731                assert_eq!(kernels.alignment, 64, "AVX512 should use 64-byte alignment");
732            }
733            #[cfg(target_arch = "x86_64")]
734            SimdLevel::Avx2 => {
735                // Should require 32-byte alignment
736                assert_eq!(kernels.alignment, 32, "AVX2 should use 32-byte alignment");
737            }
738            #[cfg(target_arch = "x86_64")]
739            SimdLevel::Sse2 => {
740                // Should require 16-byte alignment
741                assert_eq!(kernels.alignment, 16, "SSE2 should use 16-byte alignment");
742            }
743            SimdLevel::Scalar => {
744                // Should require at least 4-byte alignment
745                assert!(
746                    kernels.alignment >= 4,
747                    "Scalar should use at least 4-byte alignment"
748                );
749            }
750        }
751    }
752
753    #[test]
754    fn test_comprehensive_alignment_management() {
755        use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes};
756
757        // Test that alignment checks work correctly throughout the entire pipeline
758        let kernels = MatMulKernels::get_cached_kernels();
759        let simd_alignment = simd_alignment_bytes(detect_runtime_simd());
760
761        // Create tensors - new tensors should be properly aligned
762        let a = Tensor::new(vec![16]);
763        let b = Tensor::new(vec![16]);
764
765        // Verify that new tensors are aligned
766        unsafe {
767            let a_ptr = a.as_ptr();
768            let b_ptr = b.as_ptr();
769
770            let a_aligned = (a_ptr as usize).is_multiple_of(simd_alignment);
771            let b_aligned = (b_ptr as usize).is_multiple_of(simd_alignment);
772
773            // New tensors should be aligned
774            assert!(a_aligned, "New tensor 'a' should be SIMD-aligned");
775            assert!(b_aligned, "New tensor 'b' should be SIMD-aligned");
776
777            // The helper function should correctly identify alignment (using a dummy c_ptr for testing)
778            let dummy_c_ptr = a_ptr as *mut f32; // Use a_ptr as dummy since we're just testing alignment
779            let alignment_check = kernels.check_actual_alignment(a_ptr, b_ptr, dummy_c_ptr);
780            // Now that aligned kernels are enabled, the alignment check should be true here
781            assert!(
782                alignment_check,
783                "Alignment check should pass for aligned pointers"
784            );
785        }
786
787        // Test that contiguous() preserves alignment
788        let a_transposed = a.transpose(0, 0); // Identity transpose, should still be contiguous
789        let a_contiguous = a_transposed.contiguous();
790
791        unsafe {
792            let a_cont_ptr = a_contiguous.as_ptr();
793            let a_cont_aligned = (a_cont_ptr as usize).is_multiple_of(simd_alignment);
794            assert!(
795                a_cont_aligned,
796                "Contiguous tensor should maintain SIMD alignment"
797            );
798        }
799
800        // Test actual matmul operations use correct kernels
801        let result = a.matmul(&b);
802        assert_eq!(result.shape().dims(), vec![]); // Dot product result is scalar
803
804        // Test that the system correctly handles both aligned and unaligned cases
805        // by ensuring no crashes occur and results are computed correctly
806        let a_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
807        let b_2d = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
808        let result_2d = a_2d.matmul(&b_2d);
809
810        assert_eq!(result_2d.shape().dims(), vec![2, 2]);
811        // Verify computation correctness: [1,2; 3,4] @ [5,6; 7,8] = [19,22; 43,50]
812        assert!((result_2d.get(&[0, 0]) - 19.0).abs() < 1e-6);
813        assert!((result_2d.get(&[0, 1]) - 22.0).abs() < 1e-6);
814        assert!((result_2d.get(&[1, 0]) - 43.0).abs() < 1e-6);
815        assert!((result_2d.get(&[1, 1]) - 50.0).abs() < 1e-6);
816    }
817
818    #[test]
819    fn test_linear_layer_pattern() {
820        // Simulate the exact pattern from the training loop
821        let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(); // Input (no grad)
822        let weight = Tensor::from_slice(&[0.1, 0.5, 0.3, 0.1, 0.5, 0.3], vec![3, 2])
823            .unwrap()
824            .with_requires_grad(); // Weight (requires grad)
825        let bias = Tensor::from_slice(&[0.0, 0.1], vec![2])
826            .unwrap()
827            .with_requires_grad(); // Bias (requires grad)
828
829        // Forward pass
830        let weighted = x_data.matmul(&weight); // [3] @ [3, 2] -> [2]
831        let y_pred = weighted.add_tensor(&bias); // [2] + [2] -> [2]
832
833        // Create a simple loss (sum of squared differences with some target)
834        let y_true = Tensor::from_slice(&[3.0, 5.0], vec![2]).unwrap();
835        let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();
836
837        // Backward pass
838        loss.backward(None);
839
840        // Check that gradients are computed correctly
841        let grad_bias = bias.grad_owned().unwrap();
842        let grad_weight = weight.grad_owned().unwrap();
843
844        assert_eq!(grad_weight.shape().dims(), vec![3, 2]); // Same shape as weight
845        assert_eq!(grad_bias.shape().dims(), vec![2]); // Same shape as bias
846
847        // The exact gradient values depend on the computation graph, but shapes should be correct
848        assert_eq!(grad_weight.size(), 6);
849        assert_eq!(grad_bias.size(), 2);
850
851        // Verify that no gradient is computed for x_data (doesn't require grad)
852        assert!(x_data.grad_owned().is_none());
853    }
854
855    #[test]
856    fn test_debug_large_matmul_gradient() {
857        use crate::gradtrack::clear_gradients;
858        clear_gradients();
859
860        println!("=== Debug Large MatMul Gradient Issue ===");
861
862        // Test progressively larger sizes to find where the issue starts
863        for &size in &[4, 8, 16, 32, 64] {
864            println!("\n--- Testing size {}x{} ---", size, size);
865
866            let left = Tensor::from_slice(
867                &(0..size * size)
868                    .map(|i| (i as f32) * 0.1 + 1.0)
869                    .collect::<Vec<_>>(),
870                vec![size, size],
871            )
872            .unwrap()
873            .with_requires_grad();
874
875            let right = Tensor::from_slice(
876                &(0..size * size)
877                    .map(|i| (i as f32) * 0.2 + 0.5)
878                    .collect::<Vec<_>>(),
879                vec![size, size],
880            )
881            .unwrap()
882            .with_requires_grad();
883
884            let mut result = left.matmul(&right);
885
886            // Backward with ones
887            result.backward(None);
888
889            let grad_left = left.grad_owned().unwrap();
890            let _grad_right = right.grad_owned().unwrap();
891
892            // Manual verification
893            // For C = A @ B, dC/dA = grad_output @ B^T, dC/dB = A^T @ grad_output
894            let right_t = right.transpose(0, 1);
895            let grad_ones = Tensor::ones(vec![size, size]);
896            let expected_grad_left = grad_ones.matmul(&right_t);
897
898            // Check if our gradient matches expected
899            let mut max_diff = 0.0f32;
900            let mut max_diff_idx = 0;
901            for i in 0..grad_left.size() {
902                let our_val = unsafe { *grad_left.as_ptr().add(i) };
903                let expected_val = unsafe { *expected_grad_left.as_ptr().add(i) };
904                let diff = (our_val - expected_val).abs();
905                if diff > max_diff {
906                    max_diff = diff;
907                    max_diff_idx = i;
908                }
909            }
910
911            println!(
912                "Max gradient diff: {} at index {} (size {}x{})",
913                max_diff, max_diff_idx, size, size
914            );
915
916            if max_diff > 1e-4 {
917                println!("PROBLEM DETECTED at size {}x{}", size, size);
918                let our_val = unsafe { *grad_left.as_ptr().add(max_diff_idx) };
919                let expected_val = unsafe { *expected_grad_left.as_ptr().add(max_diff_idx) };
920                println!(
921                    "  our_val={}, expected_val={}, diff={}",
922                    our_val, expected_val, max_diff
923                );
924
925                // Check if the issue is in transpose or contiguous
926                println!("  right.is_contiguous(): {}", right.is_contiguous());
927                println!("  right_t.is_contiguous(): {}", right_t.is_contiguous());
928
929                break;
930            }
931
932            clear_gradients();
933        }
934    }
935
936    #[test]
937    fn test_debug_transpose_contiguous_issue() {
938        use crate::gradtrack::clear_gradients;
939        clear_gradients();
940
941        println!("=== Debug Transpose/Contiguous Issue ===");
942
943        let size = 8;
944        let right = Tensor::from_slice(
945            &(0..size * size)
946                .map(|i| (i as f32) * 0.2 + 0.5)
947                .collect::<Vec<_>>(),
948            vec![size, size],
949        )
950        .unwrap();
951
952        println!("Original right tensor:");
953        println!("  is_contiguous: {}", right.is_contiguous());
954        println!("  strides: {:?}", right.strides());
955
956        let right_t = right.transpose(0, 1);
957        println!("Transposed right tensor:");
958        println!("  is_contiguous: {}", right_t.is_contiguous());
959        println!("  strides: {:?}", right_t.strides());
960
961        let right_t_contiguous = if right_t.is_contiguous() {
962            right_t.clone()
963        } else {
964            right_t.contiguous()
965        };
966        println!("Contiguous transposed right tensor:");
967        println!("  is_contiguous: {}", right_t_contiguous.is_contiguous());
968        println!("  strides: {:?}", right_t_contiguous.strides());
969
970        // Compare the data to make sure contiguous() works correctly
971        println!("Original data (first 8): {:?}", &right.data()[0..8]);
972        println!(
973            "Transposed data (first 8): {:?}",
974            &right_t_contiguous.data()[0..8]
975        );
976
977        // Manual transpose verification
978        for i in 0..4 {
979            for j in 0..4 {
980                let orig_val = right.get(&[i, j]);
981                let trans_val = right_t_contiguous.get(&[j, i]);
982                if (orig_val - trans_val).abs() > 1e-6 {
983                    println!(
984                        "Transpose error at ({},{}): orig={}, trans={}",
985                        i, j, orig_val, trans_val
986                    );
987                }
988            }
989        }
990    }
991
992    #[test]
993    fn test_debug_scalar_kernel_accuracy() {
994        use crate::gradtrack::clear_gradients;
995        clear_gradients();
996
997        println!("=== Debug Scalar Kernel Accuracy ===");
998
999        // Test the scalar kernel directly with known values
1000        let size = 3;
1001        let a_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
1002        let b_data = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]; // Identity matrix
1003        let mut c_data = vec![0.0; 9];
1004
1005        unsafe {
1006            crate::tensor::ops::matmul::scalar_kernels::matmul_scalar_2d_2d(
1007                a_data.as_ptr(),
1008                b_data.as_ptr(),
1009                c_data.as_mut_ptr(),
1010                size,
1011                size,
1012                size,
1013                size,
1014                1, // a strides
1015                size,
1016                1, // b strides
1017                size,
1018                1, // c strides
1019            );
1020        }
1021
1022        println!("A @ I = {:?}", c_data);
1023        println!("Expected: {:?}", a_data);
1024
1025        // Should be identical since A @ I = A
1026        for (i, (&expected, &actual)) in a_data.iter().zip(c_data.iter()).enumerate() {
1027            if (expected - actual).abs() > 1e-6 {
1028                println!(
1029                    "Scalar kernel error at index {}: expected={}, actual={}",
1030                    i, expected, actual
1031                );
1032            }
1033        }
1034
1035        // Now test a more complex case
1036        let a_data2 = [1.0, 2.0, 3.0, 4.0];
1037        let b_data2 = [2.0, 0.0, 0.0, 2.0];
1038        let mut c_data2 = vec![0.0; 4];
1039
1040        unsafe {
1041            crate::tensor::ops::matmul::scalar_kernels::matmul_scalar_2d_2d(
1042                a_data2.as_ptr(),
1043                b_data2.as_ptr(),
1044                c_data2.as_mut_ptr(),
1045                2,
1046                2,
1047                2,
1048                2,
1049                1, // a strides
1050                2,
1051                1, // b strides
1052                2,
1053                1, // c strides
1054            );
1055        }
1056
1057        println!("[[1,2],[3,4]] @ [[2,0],[0,2]] = {:?}", c_data2);
1058        println!("Expected: [2.0, 4.0, 6.0, 8.0]");
1059    }
1060
1061    #[test]
1062    fn test_minimal_noncontiguous_gradient() {
1063        use crate::gradtrack::clear_gradients;
1064        clear_gradients();
1065
1066        println!("=== Minimal Non-contiguous Gradient Test ===");
1067
1068        // Create simple tensors
1069        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
1070            .unwrap()
1071            .with_requires_grad();
1072
1073        let right = Tensor::from_slice(&[1.0, 0.0, 0.0, 1.0], vec![2, 2])
1074            .unwrap()
1075            .with_requires_grad();
1076
1077        println!("Left requires_grad: {}", left.requires_grad());
1078        println!("Right requires_grad: {}", right.requires_grad());
1079
1080        // Make non-contiguous
1081        let left_nc = left.transpose(0, 1).transpose(0, 1);
1082        let right_nc = right.transpose(0, 1).transpose(0, 1);
1083
1084        println!("Left NC requires_grad: {}", left_nc.requires_grad());
1085        println!("Right NC requires_grad: {}", right_nc.requires_grad());
1086        println!("Left NC is_contiguous: {}", left_nc.is_contiguous());
1087        println!("Right NC is_contiguous: {}", right_nc.is_contiguous());
1088
1089        // Matmul
1090        let mut result = left_nc.matmul(&right_nc);
1091        println!("Result requires_grad: {}", result.requires_grad());
1092
1093        // Backward
1094        result.backward(None);
1095
1096        // Check gradients
1097        println!(
1098            "Left NC gradient exists: {}",
1099            left_nc.grad_owned().is_some()
1100        );
1101        println!(
1102            "Right NC gradient exists: {}",
1103            right_nc.grad_owned().is_some()
1104        );
1105
1106        // Also check original tensors
1107        println!(
1108            "Original left gradient exists: {}",
1109            left.grad_owned().is_some()
1110        );
1111        println!(
1112            "Original right gradient exists: {}",
1113            right.grad_owned().is_some()
1114        );
1115    }
1116
1117    #[test]
1118    fn test_debug_transpose_gradient_tracking() {
1119        use crate::gradtrack::clear_gradients;
1120        clear_gradients();
1121
1122        println!("=== Debug Transpose Gradient Tracking ===");
1123
1124        // Create a tensor with requires_grad
1125        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
1126            .unwrap()
1127            .with_requires_grad();
1128
1129        println!("Original requires_grad: {}", original.requires_grad());
1130
1131        // Single transpose
1132        let t1 = original.transpose(0, 1);
1133        println!(
1134            "After first transpose requires_grad: {}",
1135            t1.requires_grad()
1136        );
1137
1138        // Double transpose (should be identity)
1139        let t2 = t1.transpose(0, 1);
1140        println!(
1141            "After second transpose requires_grad: {}",
1142            t2.requires_grad()
1143        );
1144
1145        // Test the pattern from the failing test
1146        let left = Tensor::from_slice(
1147            &[
1148                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1149            ],
1150            vec![2, 2, 3],
1151        )
1152        .unwrap()
1153        .with_requires_grad();
1154
1155        println!("Left original requires_grad: {}", left.requires_grad());
1156
1157        // Make non-contiguous like in the test
1158        let left_nc = left.transpose(1, 2).transpose(1, 2);
1159        println!(
1160            "Left non-contiguous requires_grad: {}",
1161            left_nc.requires_grad()
1162        );
1163        println!(
1164            "Left non-contiguous is_contiguous: {}",
1165            left_nc.is_contiguous()
1166        );
1167
1168        // Test matmul
1169        let right = Tensor::from_slice(
1170            &[
1171                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1172            ],
1173            vec![2, 3, 2],
1174        )
1175        .unwrap()
1176        .with_requires_grad();
1177
1178        let right_nc = right.transpose(1, 2).transpose(1, 2);
1179        println!(
1180            "Right non-contiguous requires_grad: {}",
1181            right_nc.requires_grad()
1182        );
1183        println!(
1184            "Right non-contiguous is_contiguous: {}",
1185            right_nc.is_contiguous()
1186        );
1187
1188        let mut result = left_nc.matmul(&right_nc);
1189        println!("Result requires_grad: {}", result.requires_grad());
1190
1191        // Backward
1192        result.backward(None);
1193
1194        // Check gradients
1195        println!("Left gradient exists: {}", left_nc.grad_owned().is_some());
1196        println!("Right gradient exists: {}", right_nc.grad_owned().is_some());
1197
1198        if let Some(grad_left) = left_nc.grad_owned() {
1199            println!("Left gradient shape: {:?}", grad_left.shape().dims());
1200        }
1201        if let Some(grad_right) = right_nc.grad_owned() {
1202            println!("Right gradient shape: {:?}", grad_right.shape().dims());
1203        }
1204    }
1205
1206    #[test]
1207    fn test_debug_4d_gradient_issue() {
1208        use crate::gradtrack::clear_gradients;
1209        clear_gradients();
1210
1211        println!("=== Debug 4D Gradient Issue ===");
1212
1213        // Test case: 4D: 2x3x4x5 @ 2x3x5x6 gradients (failing case)
1214        let left_shape = vec![2, 3, 4, 5];
1215        let right_shape = vec![2, 3, 5, 6];
1216
1217        let left_size: usize = left_shape.iter().product();
1218        let right_size: usize = right_shape.iter().product();
1219
1220        let left_data: Vec<f32> = (0..left_size).map(|i| (i as f32) * 0.1 + 1.0).collect();
1221        let right_data: Vec<f32> = (0..right_size).map(|i| (i as f32) * 0.2 + 0.5).collect();
1222
1223        let left = Tensor::from_slice(&left_data, left_shape.clone())
1224            .unwrap()
1225            .with_requires_grad();
1226        let right = Tensor::from_slice(&right_data, right_shape.clone())
1227            .unwrap()
1228            .with_requires_grad();
1229
1230        println!("Left shape: {:?}", left.shape().dims());
1231        println!("Right shape: {:?}", right.shape().dims());
1232
1233        // Check if forward pass is correct first
1234        let result_no_grad = crate::gradtrack::with_no_grad(|| left.matmul(&right));
1235        println!("Forward result shape: {:?}", result_no_grad.shape().dims());
1236        println!("Forward result[0] = {}", result_no_grad.data()[0]);
1237
1238        // Forward pass with gradients
1239        let mut result = left.matmul(&right);
1240        println!("Result shape: {:?}", result.shape().dims());
1241
1242        // Check if forward results match
1243        let forward_diff = (result.data()[0] - result_no_grad.data()[0]).abs();
1244        println!("Forward pass difference: {}", forward_diff);
1245
1246        // Backward pass
1247        result.backward(None);
1248
1249        let grad_left = left.grad_owned().unwrap();
1250
1251        println!("Left gradient shape: {:?}", grad_left.shape().dims());
1252        println!("Left gradient[40] = {}", grad_left.data()[40]);
1253        println!("Expected: ~78, Got: {}", grad_left.data()[40]);
1254
1255        // Check if gradient is all zeros (which would indicate a major issue)
1256        let grad_sum: f32 = grad_left.data().iter().sum();
1257        println!("Gradient sum: {} (should be non-zero)", grad_sum);
1258
1259        if grad_sum.abs() < 1e-6 {
1260            println!("ERROR: Gradient is essentially zero - major computation issue!");
1261        }
1262
1263        // Manual gradient computation for verification
1264        println!("\n=== Manual Verification ===");
1265
1266        // For C = A @ B, grad_A = grad_output @ B.T
1267        let grad_ones = Tensor::ones(result.shape().dims().to_vec());
1268        println!("Grad ones shape: {:?}", grad_ones.shape().dims());
1269
1270        // Transpose the last two dimensions of right
1271        let right_rank = right.shape().dims().len();
1272        let right_t = right.transpose(right_rank - 2, right_rank - 1);
1273        println!("Right transposed shape: {:?}", right_t.shape().dims());
1274
1275        let manual_grad_left =
1276            crate::gradtrack::with_no_grad(|| grad_ones.matmul(&right_t.contiguous()));
1277        println!(
1278            "Manual grad left shape: {:?}",
1279            manual_grad_left.shape().dims()
1280        );
1281        println!("Manual grad left[40] = {}", manual_grad_left.data()[40]);
1282
1283        // Check if they match
1284        let diff = (grad_left.data()[40] - manual_grad_left.data()[40]).abs();
1285        println!("Difference at element 40: {}", diff);
1286
1287        if diff > 1e-3 {
1288            println!("MAJOR GRADIENT ERROR DETECTED!");
1289            println!(
1290                "Expected (manual): {}, Got (automatic): {}",
1291                manual_grad_left.data()[40],
1292                grad_left.data()[40]
1293            );
1294
1295            // Check if the issue is in the gradient computation or the reduce function
1296            println!("\n=== Investigating Root Cause ===");
1297            println!(
1298                "Manual gradient sum: {}",
1299                manual_grad_left.data().iter().sum::<f32>()
1300            );
1301            println!(
1302                "Automatic gradient sum: {}",
1303                grad_left.data().iter().sum::<f32>()
1304            );
1305        } else {
1306            println!("Gradients match!");
1307        }
1308    }
1309
1310    #[test]
1311    fn test_debug_matmul_with_known_values() {
1312        use crate::gradtrack::clear_gradients;
1313        clear_gradients();
1314
1315        println!("=== Debug MatMul with Known Values ===");
1316
1317        // Use simple values that should produce exact results
1318        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
1319            .unwrap()
1320            .with_requires_grad();
1321        let b = Tensor::from_slice(&[1.0, 0.0, 0.0, 1.0], vec![2, 2]) // Identity
1322            .unwrap()
1323            .with_requires_grad();
1324
1325        println!("A: {:?}", a.data());
1326        println!("B (identity): {:?}", b.data());
1327
1328        let mut result = a.matmul(&b);
1329        println!("A @ B: {:?}", result.data());
1330        println!("Expected: {:?}", a.data()); // Should be same as A since B is identity
1331
1332        // Backward pass
1333        result.backward(None);
1334
1335        let grad_a = a.grad_owned().unwrap();
1336        let grad_b = b.grad_owned().unwrap();
1337
1338        println!("grad_A: {:?}", grad_a.data());
1339        println!("grad_B: {:?}", grad_b.data());
1340
1341        // For A @ B where B is identity:
1342        // grad_A = grad_output @ B^T = grad_output @ I = grad_output = ones([2,2])
1343        // grad_B = A^T @ grad_output = A^T @ ones([2,2])
1344
1345        let expected_grad_a = vec![1.0, 1.0, 1.0, 1.0];
1346        let expected_grad_b = vec![4.0, 4.0, 6.0, 6.0]; // A^T @ ones = [[1+3, 1+3], [2+4, 2+4]] = [4, 4, 6, 6]
1347
1348        println!("Expected grad_A: {:?}", expected_grad_a);
1349        println!("Expected grad_B: {:?}", expected_grad_b);
1350
1351        for (i, (&expected, &actual)) in
1352            expected_grad_a.iter().zip(grad_a.data().iter()).enumerate()
1353        {
1354            if (expected - actual).abs() > 1e-5 {
1355                println!(
1356                    "grad_A error at index {}: expected={}, actual={}, diff={}",
1357                    i,
1358                    expected,
1359                    actual,
1360                    (expected - actual).abs()
1361                );
1362            }
1363        }
1364
1365        for (i, (&expected, &actual)) in
1366            expected_grad_b.iter().zip(grad_b.data().iter()).enumerate()
1367        {
1368            if (expected - actual).abs() > 1e-5 {
1369                println!(
1370                    "grad_B error at index {}: expected={}, actual={}, diff={}",
1371                    i,
1372                    expected,
1373                    actual,
1374                    (expected - actual).abs()
1375                );
1376            }
1377        }
1378    }
1379
1380    /// Test 1D @ ND batched vector broadcasting forward and backward
1381    #[test]
1382    fn test_matmul_1d_nd_broadcast_forward_backward() {
1383        // [K] @ [B, K, N] -> [B, N]
1384        let k = 4;
1385        let b = 2;
1386        let n = 3;
1387        let left_data: Vec<f32> = (0..k).map(|i| i as f32 + 1.0).collect(); // [1,2,3,4]
1388        let right_data: Vec<f32> = (0..b * k * n).map(|i| (i as f32) * 0.1 + 0.5).collect();
1389
1390        let left = Tensor::from_slice(&left_data, vec![k])
1391            .unwrap()
1392            .with_requires_grad();
1393        let right = Tensor::from_slice(&right_data, vec![b, k, n])
1394            .unwrap()
1395            .with_requires_grad();
1396
1397        let mut out = left.matmul(&right);
1398        assert_eq!(out.shape().dims(), vec![b, n]);
1399
1400        // Backward with ones
1401        out.backward(None);
1402        let grad_left = left.grad_owned().unwrap();
1403        let grad_right = right.grad_owned().unwrap();
1404
1405        assert_eq!(grad_left.shape().dims(), vec![k]);
1406        assert_eq!(grad_right.shape().dims(), vec![b, k, n]);
1407    }
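
    /// Numeric sketch for the [K] @ [B, K, N] backward: with an all-ones upstream
    /// gradient, grad_right[b, k, n] = left[k] (grad_left is checked manually in
    /// test_grad_left_vec_at_bkn_manual further below). Values and the test name are
    /// illustrative; only APIs already used in this module appear.
    #[test]
    fn test_matmul_1d_nd_grad_right_numeric_sketch() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let (b, k, n) = (2usize, 3usize, 4usize);
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![k])
            .unwrap()
            .with_requires_grad();
        let right_data: Vec<f32> = (0..b * k * n).map(|i| i as f32 * 0.1).collect();
        let right = Tensor::from_slice(&right_data, vec![b, k, n])
            .unwrap()
            .with_requires_grad();

        let mut out = left.matmul(&right); // [B, N]
        out.backward(None);

        let grad_right = right.grad_owned().unwrap();
        assert_eq!(grad_right.shape().dims(), vec![b, k, n]);

        // Row k of every [K, N] slice should equal left[k] broadcast across N.
        let gp = unsafe { grad_right.as_ptr() };
        for bb in 0..b {
            for kk in 0..k {
                for nn in 0..n {
                    let idx = bb * (k * n) + kk * n + nn;
                    let got = unsafe { *gp.add(idx) };
                    let expected = kk as f32 + 1.0; // left = [1, 2, 3]
                    assert!(
                        (got - expected).abs() < 1e-5,
                        "b={} k={} n={} got={} expected={}",
                        bb,
                        kk,
                        nn,
                        got,
                        expected
                    );
                }
            }
        }
    }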
1408
1409    /// ND @ 1D batched matrix-vector multiplication: forward and backward
1410    #[test]
1411    fn test_matmul_nd_1d_broadcast_forward_backward() {
1412        // [B, M, K] @ [K] -> [B, M]
1413        let b = 3;
1414        let m = 2;
1415        let k = 5;
1416        let left_data: Vec<f32> = (0..b * m * k).map(|i| (i as f32) * 0.05 + 1.0).collect();
1417        let right_data: Vec<f32> = (0..k).map(|i| (i as f32) * 0.2 + 0.5).collect();
1418
1419        let left = Tensor::from_slice(&left_data, vec![b, m, k])
1420            .unwrap()
1421            .with_requires_grad();
1422        let right = Tensor::from_slice(&right_data, vec![k])
1423            .unwrap()
1424            .with_requires_grad();
1425
1426        let mut out = left.matmul(&right);
1427        assert_eq!(out.shape().dims(), vec![b, m]);
1428
1429        // Backward with ones
1430        out.backward(None);
1431        let grad_left = left.grad_owned().unwrap();
1432        let grad_right = right.grad_owned().unwrap();
1433
1434        assert_eq!(grad_left.shape().dims(), vec![b, m, k]);
1435        assert_eq!(grad_right.shape().dims(), vec![k]);
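        // Closed forms for reference: out[b, m] = sum_k left[b, m, k] * right[k], so with an
        // all-ones upstream gradient grad_left[b, m, k] = right[k] and
        // grad_right[k] = sum over (b, m) of left[b, m, k]; only shapes are asserted here.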
1436    }
1437
1438    #[test]
1439    fn test_batched_mat_vec_forward_and_shapes() {
1440        // [B, M, K] @ [B, K] -> [B, M]
1441        let b = 3usize;
1442        let m = 5usize;
1443        let k = 7usize;
1444        let left_data: Vec<f32> = (0..b * m * k).map(|i| (i as f32) * 0.01 + 1.0).collect();
1445        let right_data: Vec<f32> = (0..b * k).map(|i| (i as f32) * 0.02 + 0.5).collect();
1446        let left = Tensor::from_slice(&left_data, vec![b, m, k]).unwrap();
1447        let right = Tensor::from_slice(&right_data, vec![b, k]).unwrap();
1448        let out = left.matmul(&right);
1449        assert_eq!(out.shape().dims(), vec![b, m]);
1450    }
1451
1452    #[test]
1453    fn test_batched_mat_vec_numeric_small() {
1454        // Small numeric check: [2, 2, 3] @ [2, 3]
1455        let left = Tensor::from_slice(
1456            &[
1457                1.0, 2.0, 3.0, // b0 r0
1458                4.0, 5.0, 6.0, // b0 r1
1459                1.0, 1.0, 1.0, // b1 r0
1460                2.0, 2.0, 2.0, // b1 r1
1461            ],
1462            vec![2, 2, 3],
1463        )
1464        .unwrap();
1465        let right = Tensor::from_slice(&[1.0, 1.0, 1.0, 2.0, 3.0, 4.0], vec![2, 3]).unwrap();
1466        let out = left.matmul(&right); // [2,2]
1467        assert_eq!(out.shape().dims(), vec![2, 2]);
1468        unsafe {
1469            // batch 0: [2,3] @ [3] = [6, 15]
1470            assert!((*out.as_ptr() - 6.0).abs() < 1e-6);
1471            assert!((*out.as_ptr().add(1) - 15.0).abs() < 1e-6);
1472            // batch 1: [2,3] @ [3] = [9, 18]
1473            assert!((*out.as_ptr().add(2) - 9.0).abs() < 1e-6);
1474            assert!((*out.as_ptr().add(3) - 18.0).abs() < 1e-6);
1475        }
1476    }
1477
1478    #[test]
1479    fn test_batched_vector_cases_grad_shapes() {
1480        use crate::gradtrack::clear_gradients;
1481        clear_gradients();
1482
1483        // ND @ vec
1484        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2])
1485            .unwrap()
1486            .with_requires_grad();
1487        let right = Tensor::from_slice(&[1.0, 1.0], vec![1, 2])
1488            .unwrap()
1489            .with_requires_grad();
1490        let mut out = left.matmul(&right);
1491        assert_eq!(out.shape().dims(), vec![1, 2]);
1492        out.backward(None);
1493        let gl = left.grad_owned().unwrap();
1494        let gr = right.grad_owned().unwrap();
1495        assert_eq!(gl.shape().dims(), vec![1, 2, 2]);
1496        assert_eq!(gr.shape().dims(), vec![1, 2]);
1497
1498        clear_gradients();
1499        // vec @ ND
1500        let left2 = Tensor::from_slice(&[1.0, 1.0], vec![2])
1501            .unwrap()
1502            .with_requires_grad();
1503        let right2 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
1504            .unwrap()
1505            .with_requires_grad();
1506        let mut out2 = left2.matmul(&right2);
1507        assert_eq!(out2.shape().dims(), vec![2]);
1508        out2.backward(None);
1509        let gl2 = left2.grad_owned().unwrap();
1510        let gr2 = right2.grad_owned().unwrap();
1511        assert_eq!(gl2.shape().dims(), vec![2]);
1512        assert_eq!(gr2.shape().dims(), vec![2, 2]);
1513    }
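
    /// Numeric follow-up sketch for the tiny ND @ vec case above ([1, 2, 2] @ [1, 2]):
    /// with an all-ones upstream gradient, grad_left[b, m, k] = right[b, k] (all ones here)
    /// and grad_right[b, k] = sum_m left[b, m, k] = [4, 6]. Name and tolerance are
    /// illustrative; only APIs already used in this module appear.
    #[test]
    fn test_batched_mat_vec_grad_numeric_sketch() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&[1.0, 1.0], vec![1, 2])
            .unwrap()
            .with_requires_grad();

        let mut out = left.matmul(&right); // [[3, 7]]
        out.backward(None);

        let gl = left.grad_owned().unwrap();
        let gr = right.grad_owned().unwrap();

        for (&expected, &actual) in [1.0f32; 4].iter().zip(gl.data().iter()) {
            assert!((expected - actual).abs() < 1e-5);
        }
        for (&expected, &actual) in [4.0f32, 6.0].iter().zip(gr.data().iter()) {
            assert!((expected - actual).abs() < 1e-5);
        }
    }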
1514
1515    #[test]
1516    fn test_broadcast_matmul_gradient_complex_case() {
1517        use crate::gradtrack::clear_gradients;
1518        clear_gradients();
1519
1520        println!("=== Testing complex broadcast case: [1,2,2,1,4,5] @ [2,1,5,6] ===");
1521
1522        // The exact failing case from the issue
1523        let left_shape = vec![1, 2, 2, 1, 4, 5];
1524        let right_shape = vec![2, 1, 5, 6];
1525        let expected_result_shape = vec![1, 2, 2, 1, 4, 6];
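        // Why [1, 2, 2, 1, 4, 6]: the batch dims broadcast NumPy-style, [1, 2, 2, 1] vs the
        // right-aligned [1, 1, 2, 1] -> [1, 2, 2, 1], while the matrix dims give
        // [4, 5] @ [5, 6] -> [4, 6].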
1526
1527        println!("Left shape: {:?}", left_shape);
1528        println!("Right shape: {:?}", right_shape);
1529        println!("Expected result shape: {:?}", expected_result_shape);
1530
1531        // Create tensors with some data
1532        let left_size: usize = left_shape.iter().product();
1533        let right_size: usize = right_shape.iter().product();
1534
1535        let left_data: Vec<f32> = (0..left_size).map(|i| (i as f32) * 0.01 + 1.0).collect();
1536        let right_data: Vec<f32> = (0..right_size).map(|i| (i as f32) * 0.02 + 0.5).collect();
1537
1538        let left = Tensor::from_slice(&left_data, left_shape.clone())
1539            .unwrap()
1540            .with_requires_grad();
1541        let right = Tensor::from_slice(&right_data, right_shape.clone())
1542            .unwrap()
1543            .with_requires_grad();
1544
1545        // Forward pass
1546        let mut result = left.matmul(&right);
1547        println!("Actual result shape: {:?}", result.shape().dims());
1548
1549        // Verify forward result shape
1550        assert_eq!(
1551            result.shape().dims(),
1552            expected_result_shape,
1553            "Forward result shape mismatch"
1554        );
1555
1556        // Backward pass
1557        result.backward(None);
1558
1559        // Check gradient shapes
1560        let grad_left = left.grad_owned().unwrap();
1561        let grad_right = right.grad_owned().unwrap();
1562
1563        println!("Left gradient shape: {:?}", grad_left.shape().dims());
1564        println!("Right gradient shape: {:?}", grad_right.shape().dims());
1565
1566        // Gradients should have the same shape as the original tensors
1567        assert_eq!(
1568            grad_left.shape().dims(),
1569            left_shape,
1570            "Left gradient shape mismatch"
1571        );
1572        assert_eq!(
1573            grad_right.shape().dims(),
1574            right_shape,
1575            "Right gradient shape mismatch"
1576        );
1577
1578        // Check that gradients are not all zeros (sanity check)
1579        let left_grad_sum: f32 = grad_left.data().iter().sum();
1580        let right_grad_sum: f32 = grad_right.data().iter().sum();
1581
1582        println!("Left gradient sum: {}", left_grad_sum);
1583        println!("Right gradient sum: {}", right_grad_sum);
1584
1585        assert!(
1586            left_grad_sum.abs() > 1e-6,
1587            "Left gradient should not be zero"
1588        );
1589        assert!(
1590            right_grad_sum.abs() > 1e-6,
1591            "Right gradient should not be zero"
1592        );
1593
1594        println!("✓ Complex broadcast matmul gradient test passed!");
1595    }
1596
1597    #[test]
1598    fn test_grad_left_vec_at_bkn_manual() {
1599        use crate::gradtrack::clear_gradients;
1600        clear_gradients();
1601        let k = 8usize;
1602        let b = 3usize;
1603        let n = 5usize;
1604
1605        // left vector [K]
1606        let left = Tensor::from_slice(
1607            &(0..k).map(|i| i as f32 * 0.1 + 1.0).collect::<Vec<_>>(),
1608            vec![k],
1609        )
1610        .unwrap()
1611        .with_requires_grad();
1612
1613        // right [B,K,N]
1614        let right = Tensor::from_slice(
1615            &(0..b * k * n)
1616                .map(|i| i as f32 * 0.2 + 0.5)
1617                .collect::<Vec<_>>(),
1618            vec![b, k, n],
1619        )
1620        .unwrap()
1621        .with_requires_grad();
1622
1623        let mut out = left.matmul(&right); // [B,N]
1624        // Backward with ones
1625        out.backward(None);
1626
1627        // Compute manual expected grad for left: sum_{b,n} right[b,k,n]
1628        let grad_left = left.grad_owned().unwrap();
1629        assert_eq!(grad_left.shape().dims(), vec![k]);
1630        let rp = unsafe { right.as_ptr() };
1631        for kk in 0..k {
1632            let mut sum = 0.0f32;
1633            for bb in 0..b {
1634                for nn in 0..n {
1635                    let idx = bb * (k * n) + kk * n + nn;
1636                    unsafe { sum += *rp.add(idx) };
1637                }
1638            }
1639            let got = unsafe { *grad_left.as_ptr().add(kk) };
1640            assert!(
1641                (got - sum).abs() < 1e-4,
1642                "k={} got={} expected={}",
1643                kk,
1644                got,
1645                sum
1646            );
1647        }
1648    }
1649}