use crate::Tensor;

#[cfg(target_arch = "x86_64")]
pub mod avx2_kernels;
#[cfg(target_arch = "x86_64")]
pub mod avx512_kernels;
pub mod classification_and_validation;
pub mod dispatch;
#[cfg(target_arch = "x86_64")]
pub mod pack_n_cache;
pub mod scalar_kernels;
#[cfg(target_arch = "x86_64")]
pub mod sse_kernels;
use dispatch::*;

impl Tensor {
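    /// Matrix multiplication of `self` with `other`.
    ///
    /// The operation is classified by operand rank (1-D·1-D dot product,
    /// 1-D·2-D vector–matrix, 2-D·1-D matrix–vector, 2-D·2-D matrix–matrix,
    /// or batched N-D matmul) and dispatched to SIMD or scalar kernels, with
    /// strided fallbacks for non-contiguous operands. When gradient tracking
    /// is enabled and either operand requires gradients, a `MatMul` grad_fn
    /// is registered on the result.
    ///
    /// A minimal usage sketch (illustrative only; it mirrors the unit tests
    /// below rather than documenting additional guarantees):
    ///
    /// ```ignore
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
    /// let c = a.matmul(&b); // [[19, 22], [43, 50]]
    /// assert_eq!(c.shape().dims(), vec![2, 2]);
    /// ```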
    #[track_caller]
    pub fn matmul(&self, other: &Tensor) -> Tensor {
        let kernels = MatMulKernels::get_cached_kernels();

        let left_shape = self.shape().dims();
        let right_shape = other.shape().dims();
        let op_type = MatMulKernels::classify_operation(left_shape, right_shape);

        let result_shape = Self::validate_and_compute_matmul_shape(left_shape, right_shape);

        let left_ptr = unsafe { self.as_ptr() };
        let right_ptr = unsafe { other.as_ptr() };

        let mut result = Tensor::new(result_shape.clone());

        unsafe {
            let result_ptr = result.as_mut_ptr();

            match op_type {
                // 1-D · 1-D: scalar dot product.
                MatMulOpType::Dot1D1D => {
                    let size = left_shape[0];
                    if self.is_contiguous() && other.is_contiguous() {
                        let dot_result = kernels.dispatch_dot_1d(left_ptr, right_ptr, size);
                        *result_ptr = dot_result;
                    } else {
                        let left_stride = self.strides()[0];
                        let right_stride = other.strides()[0];
                        let dot_result = kernels.dispatch_dot_1d_strided(
                            left_ptr,
                            right_ptr,
                            size,
                            left_stride,
                            right_stride,
                        );
                        *result_ptr = dot_result;
                    }
                }
                // 1-D · 2-D: vector–matrix product.
                MatMulOpType::Vec1D2D => {
                    let k = left_shape[0];
                    let n = right_shape[1];
                    if self.is_contiguous() && other.is_contiguous() {
                        kernels.dispatch_vec_mat(left_ptr, right_ptr, result_ptr, k, n);
                    } else {
                        let left_stride = self.strides()[0];
                        let right_strides = other.strides();
                        let right_row_stride = right_strides[0];
                        let right_col_stride = right_strides[1];
                        let result_stride = 1;
                        kernels.dispatch_vec_mat_strided(
                            left_ptr,
                            right_ptr,
                            result_ptr,
                            k,
                            n,
                            left_stride,
                            right_row_stride,
                            right_col_stride,
                            result_stride,
                        );
                    }
                }
                // 2-D · 1-D: matrix–vector product.
                MatMulOpType::Mat2D1D => {
                    let m = left_shape[0];
                    let k = left_shape[1];
                    if self.is_contiguous() && other.is_contiguous() {
                        kernels.dispatch_mat_vec(left_ptr, right_ptr, result_ptr, m, k);
                    } else {
                        let left_strides = self.strides();
                        let left_row_stride = left_strides[0];
                        let left_col_stride = left_strides[1];
                        let right_stride = other.strides()[0];
                        let result_stride = 1;
                        kernels.dispatch_mat_vec_strided(
                            left_ptr,
                            right_ptr,
                            result_ptr,
                            m,
                            k,
                            left_row_stride,
                            left_col_stride,
                            right_stride,
                            result_stride,
                        );
                    }
                }
                // 2-D · 2-D: matrix–matrix product.
                MatMulOpType::Mat2D2D => {
                    let m = left_shape[0];
                    let k = left_shape[1];
                    let n = right_shape[1];
                    if self.is_contiguous() && other.is_contiguous() {
                        kernels.dispatch_mat_mat(left_ptr, right_ptr, result_ptr, m, k, n);
                    } else {
                        let left_strides = self.strides();
                        let left_row_stride = left_strides[0];
                        let left_col_stride = left_strides[1];
                        let right_strides = other.strides();
                        let right_row_stride = right_strides[0];
                        let right_col_stride = right_strides[1];
                        let result_strides = result.strides();
                        let result_row_stride = result_strides[0];
                        let result_col_stride = result_strides[1];
                        kernels.dispatch_mat_mat_strided(
                            left_ptr,
                            right_ptr,
                            result_ptr,
                            m,
                            k,
                            n,
                            left_row_stride,
                            left_col_stride,
                            right_row_stride,
                            right_col_stride,
                            result_row_stride,
                            result_col_stride,
                        );
                    }
                }
                // N-D: batched matmul with broadcasting over leading dimensions.
                MatMulOpType::BatchedND => {
                    Self::dispatch_batched_matmul_with_ptrs_strided(
                        left_ptr,
                        right_ptr,
                        self,
                        other,
                        &mut result,
                        kernels,
                    );
                }
            }
        }

        // Register the backward pass when gradient tracking applies.
        if (self.requires_grad() || other.requires_grad()) && crate::gradtrack::is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = crate::gradtrack::grad_fn::GradFn::MatMul {
                left_operand: Box::new(self.clone()),
                right_operand: Box::new(other.clone()),
                requires_grad: (self.requires_grad(), other.requires_grad()),
            };
            result.set_grad_fn(grad_fn.clone());

            let input_ids = vec![self.id(), other.id()];
            crate::gradtrack::engine::GradEngine::register_operation(
                result.id(),
                input_ids,
                grad_fn,
            );
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_2d_basic() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = a.matmul(&b);

        assert_eq!(result.shape().dims(), vec![2, 2]);

        unsafe {
            let ptr = result.as_ptr();
            assert_eq!(*ptr.add(0), 19.0);
            assert_eq!(*ptr.add(1), 22.0);
            assert_eq!(*ptr.add(2), 43.0);
            assert_eq!(*ptr.add(3), 50.0);
        }
    }

    #[test]
    fn test_tall_skinny_vs_wide_correctness() {
        let mut a_ts = Tensor::new(vec![128, 8]);
        let mut b_ts = Tensor::new(vec![8, 16]);
        for (i, v) in a_ts.data_mut().iter_mut().enumerate() {
            *v = (i as f32 * 0.01).sin();
        }
        for (i, v) in b_ts.data_mut().iter_mut().enumerate() {
            *v = (i as f32 * 0.02).cos();
        }
        let r_ts = a_ts.matmul(&b_ts);
        assert_eq!(r_ts.shape().dims(), vec![128, 16]);

        let mut a_w = Tensor::new(vec![16, 8]);
        let mut b_w = Tensor::new(vec![8, 256]);
        for (i, v) in a_w.data_mut().iter_mut().enumerate() {
            *v = (i as f32 * 0.03).sin();
        }
        for (i, v) in b_w.data_mut().iter_mut().enumerate() {
            *v = (i as f32 * 0.04).cos();
        }
        let r_w = a_w.matmul(&b_w);
        assert_eq!(r_w.shape().dims(), vec![16, 256]);
    }

    #[test]
    fn test_matmul_row_vector_times_matrix() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let b = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
            ],
            vec![4, 2],
        )
        .unwrap();
        let result = a.matmul(&b);
        assert_eq!(result.shape().dims(), vec![1, 2]);
        unsafe {
            let p = result.as_ptr();
            assert_eq!(*p.add(0), 50.0);
            assert_eq!(*p.add(1), 60.0);
        }
    }

    #[test]
    fn test_matmul_1d_rowvec_times_matrix() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
        let r = a.matmul(&b);
        assert_eq!(r.shape().dims(), vec![2]);
        unsafe {
            let p = r.as_ptr();
            assert_eq!(*p.add(0), 22.0);
            assert_eq!(*p.add(1), 28.0);
        }
    }

    #[test]
    fn test_matmul_2d_2d_gradients() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        let mut result = a.matmul(&b);
        assert_eq!(result.shape().dims(), vec![2, 2]);

        let expected = [19.0, 22.0, 43.0, 50.0];
        unsafe {
            let ptr = result.as_ptr();
            for (i, val) in expected.iter().enumerate().take(4) {
                assert_eq!(*ptr.add(i), *val);
            }
        }

        let grad_output = Tensor::from_slice(&[1.0, 1.0, 1.0, 1.0], vec![2, 2]).unwrap();
        result.backward(Some(grad_output));

        let grad_a = a.grad_owned().unwrap();
        let grad_b = b.grad_owned().unwrap();

        assert_eq!(grad_a.shape().dims(), vec![2, 2]);
        assert_eq!(grad_b.shape().dims(), vec![2, 2]);

        // grad_A = dL/dC · B^T = ones · [[5, 7], [6, 8]] = [[11, 15], [11, 15]]
        unsafe {
            let grad_a_ptr = grad_a.as_ptr();
            assert_eq!(*grad_a_ptr.add(0), 11.0);
            assert_eq!(*grad_a_ptr.add(1), 15.0);
            assert_eq!(*grad_a_ptr.add(2), 11.0);
            assert_eq!(*grad_a_ptr.add(3), 15.0);
        }

        // grad_B = A^T · dL/dC = [[1, 3], [2, 4]] · ones = [[4, 4], [6, 6]]
        unsafe {
            let grad_b_ptr = grad_b.as_ptr();
            assert_eq!(*grad_b_ptr.add(0), 4.0);
            assert_eq!(*grad_b_ptr.add(1), 4.0);
            assert_eq!(*grad_b_ptr.add(2), 6.0);
            assert_eq!(*grad_b_ptr.add(3), 6.0);
        }
    }

    #[test]
    fn test_matmul_partial_requires_grad() {
        // `a` does not require gradients; only `b` should receive one.
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad();
        let mut result = a.matmul(&b);
        assert_eq!(result.shape().dims(), vec![2]);

        result.backward(None);

        assert!(a.grad_owned().is_none());
        let grad_b = b.grad_owned().unwrap();

        assert_eq!(grad_b.shape().dims(), vec![3, 2]);

        // grad_B = a^T · ones = [[1, 1], [2, 2], [3, 3]]
        unsafe {
            let grad_b_ptr = grad_b.as_ptr();
            assert_eq!(*grad_b_ptr.add(0), 1.0);
            assert_eq!(*grad_b_ptr.add(1), 1.0);
            assert_eq!(*grad_b_ptr.add(2), 2.0);
            assert_eq!(*grad_b_ptr.add(3), 2.0);
            assert_eq!(*grad_b_ptr.add(4), 3.0);
            assert_eq!(*grad_b_ptr.add(5), 3.0);
        }
    }

    #[test]
    fn test_debug_gradient_values() {
        println!("=== Debugging matmul gradient issue ===");

        let left_shape = vec![1, 3, 4];
        let right_shape = vec![2, 4, 5];

        let mut left = Tensor::zeros(left_shape.clone()).with_requires_grad();
        let mut right = Tensor::zeros(right_shape.clone()).with_requires_grad();

        let left_size = left_shape.iter().product::<usize>();
        let right_size = right_shape.iter().product::<usize>();

        unsafe {
            for i in 0..left_size {
                *left.as_mut_ptr().add(i) = (i as f32) * 0.1 + 1.0;
            }
            for i in 0..right_size {
                *right.as_mut_ptr().add(i) = (i as f32) * 0.2 + 0.5;
            }
        }

        println!(
            "Left shape: {:?}, data: {:?}",
            left.shape().dims(),
            left.data()
        );
        println!(
            "Right shape: {:?}, data: {:?}",
            right.shape().dims(),
            right.data()
        );

        let mut result = left.matmul(&right);
        println!(
            "Result shape: {:?}, data: {:?}",
            result.shape().dims(),
            result.data()
        );

        let grad_ones = Tensor::ones(result.shape().dims().to_vec());
        println!(
            "Grad ones shape: {:?}, data: {:?}",
            grad_ones.shape().dims(),
            grad_ones.data()
        );

        result.backward(Some(grad_ones));

        let grad_left = left.grad_owned().unwrap();
        let grad_right = right.grad_owned().unwrap();

        println!(
            "Left gradient shape: {:?}, data: {:?}",
            grad_left.shape().dims(),
            grad_left.data()
        );
        println!(
            "Right gradient shape: {:?}, data: {:?}",
            grad_right.shape().dims(),
            grad_right.data()
        );

        println!(
            "Left gradient[0] = {} (expected ~29, but we're getting ~41)",
            grad_left.data()[0]
        );
    }

    #[test]
    fn test_simple_batched_gradient() {
        println!("=== Testing simple batched gradient ===");

        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], vec![2, 2, 2])
            .unwrap()
            .with_requires_grad();

        println!("Left: {:?}", left.data());
        println!("Right: {:?}", right.data());

        let right_t = right.transpose(1, 2);
        println!("Right transposed: {:?}", right_t.data());
        println!("Right transposed contiguous: {:?}", right_t.is_contiguous());
        println!("Right transposed strides: {:?}", right_t.strides());

        let mut result = left.matmul(&right);
        println!("Result: {:?}", result.data());

        let grad_ones = Tensor::ones(result.shape().dims().to_vec());
        result.backward(Some(grad_ones));

        let grad_left = left.grad_owned().unwrap();
        let grad_right = right.grad_owned().unwrap();

        println!("Left gradient: {:?}", grad_left.data());
        println!("Right gradient: {:?}", grad_right.data());

        println!("\n=== Manual verification ===");
        println!("Expected left grad batch 0: [0.5+1.0, 1.5+2.0] = [1.5, 3.5]");
        println!("Expected left grad batch 1: [2.5+3.0, 3.5+4.0] = [5.5, 7.5]");
    }

    #[test]
    fn test_alignment_checks_prevent_misaligned_access() {
        use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes};

        let a = Tensor::new(vec![4]);
        let b = Tensor::new(vec![4]);

        let kernels = MatMulKernels::get_cached_kernels();
        let simd_alignment = simd_alignment_bytes(detect_runtime_simd());

        unsafe {
            let a_ptr = a.as_ptr();
            let b_ptr = b.as_ptr();

            let a_aligned = (a_ptr as usize).is_multiple_of(simd_alignment);
            let b_aligned = (b_ptr as usize).is_multiple_of(simd_alignment);

            let alignment_check = kernels.check_alignment_for_simd(
                a_ptr,
                b_ptr,
                std::ptr::null_mut(),
                simd_alignment,
            );

            assert_eq!(alignment_check, a_aligned && b_aligned);
        }
    }

    #[test]
    fn test_avx512_specific_alignment_validation() {
        use crate::tensor::core::memory::SimdLevel;

        let kernels = MatMulKernels::get_cached_kernels();

        let test_ptr = 0x1004 as *const f32; // 4-byte aligned, not 64-byte aligned
        let avx512_aligned_ptr = 0x1040 as *const f32; // 64-byte aligned

        let not_avx512_aligned =
            kernels.check_alignment_for_simd(test_ptr, test_ptr, std::ptr::null_mut(), 64);
        assert!(
            !not_avx512_aligned,
            "4-byte aligned pointer should not be considered 64-byte aligned"
        );

        let is_avx512_aligned = kernels.check_alignment_for_simd(
            avx512_aligned_ptr,
            avx512_aligned_ptr,
            std::ptr::null_mut(),
            64,
        );
        assert!(
            is_avx512_aligned,
            "64-byte aligned pointer should pass AVX512 alignment check"
        );

        match kernels.simd_level {
            #[cfg(target_arch = "x86_64")]
            SimdLevel::Avx512 => {
                assert_eq!(kernels.alignment, 64, "AVX512 should use 64-byte alignment");
            }
            #[cfg(target_arch = "x86_64")]
            SimdLevel::Avx2 => {
                assert_eq!(kernels.alignment, 32, "AVX2 should use 32-byte alignment");
            }
            #[cfg(target_arch = "x86_64")]
            SimdLevel::Sse2 => {
                assert_eq!(kernels.alignment, 16, "SSE2 should use 16-byte alignment");
            }
            SimdLevel::Scalar => {
                assert!(
                    kernels.alignment >= 4,
                    "Scalar should use at least 4-byte alignment"
                );
            }
        }
    }

    #[test]
    fn test_comprehensive_alignment_management() {
        use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes};

        let kernels = MatMulKernels::get_cached_kernels();
        let simd_alignment = simd_alignment_bytes(detect_runtime_simd());

        let a = Tensor::new(vec![16]);
        let b = Tensor::new(vec![16]);

        unsafe {
            let a_ptr = a.as_ptr();
            let b_ptr = b.as_ptr();

            let a_aligned = (a_ptr as usize).is_multiple_of(simd_alignment);
            let b_aligned = (b_ptr as usize).is_multiple_of(simd_alignment);

            assert!(a_aligned, "New tensor 'a' should be SIMD-aligned");
            assert!(b_aligned, "New tensor 'b' should be SIMD-aligned");

            // Reuse `a`'s (aligned) buffer as a stand-in output pointer.
            let dummy_c_ptr = a_ptr as *mut f32;
            let alignment_check = kernels.check_actual_alignment(a_ptr, b_ptr, dummy_c_ptr);
            assert!(
                alignment_check,
                "Alignment check should pass for aligned pointers"
            );
        }

        // A no-op transpose followed by `contiguous()` should preserve alignment.
        let a_transposed = a.transpose(0, 0);
        let a_contiguous = a_transposed.contiguous();

        unsafe {
            let a_cont_ptr = a_contiguous.as_ptr();
            let a_cont_aligned = (a_cont_ptr as usize).is_multiple_of(simd_alignment);
            assert!(
                a_cont_aligned,
                "Contiguous tensor should maintain SIMD alignment"
            );
        }

        let result = a.matmul(&b);
        // 1-D · 1-D dot product yields a zero-dimensional (scalar) tensor.
        assert_eq!(result.shape().dims(), vec![]);

        let a_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b_2d = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result_2d = a_2d.matmul(&b_2d);

        assert_eq!(result_2d.shape().dims(), vec![2, 2]);
        assert!((result_2d.get(&[0, 0]) - 19.0).abs() < 1e-6);
        assert!((result_2d.get(&[0, 1]) - 22.0).abs() < 1e-6);
        assert!((result_2d.get(&[1, 0]) - 43.0).abs() < 1e-6);
        assert!((result_2d.get(&[1, 1]) - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_linear_layer_pattern() {
        // y = x @ W + b, with x held constant and W, b trainable.
        let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let weight = Tensor::from_slice(&[0.1, 0.5, 0.3, 0.1, 0.5, 0.3], vec![3, 2])
            .unwrap()
            .with_requires_grad();
        let bias = Tensor::from_slice(&[0.0, 0.1], vec![2])
            .unwrap()
            .with_requires_grad();

        let weighted = x_data.matmul(&weight);
        let y_pred = weighted.add_tensor(&bias);
        let y_true = Tensor::from_slice(&[3.0, 5.0], vec![2]).unwrap();
        let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();

        loss.backward(None);

        let grad_bias = bias.grad_owned().unwrap();
        let grad_weight = weight.grad_owned().unwrap();

        assert_eq!(grad_weight.shape().dims(), vec![3, 2]);
        assert_eq!(grad_bias.shape().dims(), vec![2]);
        assert_eq!(grad_weight.size(), 6);
        assert_eq!(grad_bias.size(), 2);

        assert!(x_data.grad_owned().is_none());
    }

    #[test]
    fn test_debug_large_matmul_gradient() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Debug Large MatMul Gradient Issue ===");

        for &size in &[4, 8, 16, 32, 64] {
            println!("\n--- Testing size {}x{} ---", size, size);

            let left = Tensor::from_slice(
                &(0..size * size)
                    .map(|i| (i as f32) * 0.1 + 1.0)
                    .collect::<Vec<_>>(),
                vec![size, size],
            )
            .unwrap()
            .with_requires_grad();

            let right = Tensor::from_slice(
                &(0..size * size)
                    .map(|i| (i as f32) * 0.2 + 0.5)
                    .collect::<Vec<_>>(),
                vec![size, size],
            )
            .unwrap()
            .with_requires_grad();

            let mut result = left.matmul(&right);

            result.backward(None);

            let grad_left = left.grad_owned().unwrap();
            let _grad_right = right.grad_owned().unwrap();

            let right_t = right.transpose(0, 1);
            let grad_ones = Tensor::ones(vec![size, size]);
            let expected_grad_left = grad_ones.matmul(&right_t);

            let mut max_diff = 0.0f32;
            let mut max_diff_idx = 0;
            for i in 0..grad_left.size() {
                let our_val = unsafe { *grad_left.as_ptr().add(i) };
                let expected_val = unsafe { *expected_grad_left.as_ptr().add(i) };
                let diff = (our_val - expected_val).abs();
                if diff > max_diff {
                    max_diff = diff;
                    max_diff_idx = i;
                }
            }

            println!(
                "Max gradient diff: {} at index {} (size {}x{})",
                max_diff, max_diff_idx, size, size
            );

            if max_diff > 1e-4 {
                println!("PROBLEM DETECTED at size {}x{}", size, size);
                let our_val = unsafe { *grad_left.as_ptr().add(max_diff_idx) };
                let expected_val = unsafe { *expected_grad_left.as_ptr().add(max_diff_idx) };
                println!(
                    " our_val={}, expected_val={}, diff={}",
                    our_val, expected_val, max_diff
                );

                println!(" right.is_contiguous(): {}", right.is_contiguous());
                println!(" right_t.is_contiguous(): {}", right_t.is_contiguous());

                break;
            }

            clear_gradients();
        }
    }

    #[test]
    fn test_debug_transpose_contiguous_issue() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Debug Transpose/Contiguous Issue ===");

        let size = 8;
        let right = Tensor::from_slice(
            &(0..size * size)
                .map(|i| (i as f32) * 0.2 + 0.5)
                .collect::<Vec<_>>(),
            vec![size, size],
        )
        .unwrap();

        println!("Original right tensor:");
        println!(" is_contiguous: {}", right.is_contiguous());
        println!(" strides: {:?}", right.strides());

        let right_t = right.transpose(0, 1);
        println!("Transposed right tensor:");
        println!(" is_contiguous: {}", right_t.is_contiguous());
        println!(" strides: {:?}", right_t.strides());

        let right_t_contiguous = if right_t.is_contiguous() {
            right_t.clone()
        } else {
            right_t.contiguous()
        };
        println!("Contiguous transposed right tensor:");
        println!(" is_contiguous: {}", right_t_contiguous.is_contiguous());
        println!(" strides: {:?}", right_t_contiguous.strides());

        println!("Original data (first 8): {:?}", &right.data()[0..8]);
        println!(
            "Transposed data (first 8): {:?}",
            &right_t_contiguous.data()[0..8]
        );

        for i in 0..4 {
            for j in 0..4 {
                let orig_val = right.get(&[i, j]);
                let trans_val = right_t_contiguous.get(&[j, i]);
                if (orig_val - trans_val).abs() > 1e-6 {
                    println!(
                        "Transpose error at ({},{}): orig={}, trans={}",
                        i, j, orig_val, trans_val
                    );
                }
            }
        }
    }

    #[test]
    fn test_debug_scalar_kernel_accuracy() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Debug Scalar Kernel Accuracy ===");

        let size = 3;
        let a_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b_data = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]; // 3x3 identity
        let mut c_data = vec![0.0; 9];

        unsafe {
            // Nine numeric args: m, k, n, then the row/column strides of A, B,
            // and C for row-major layout (cf. dispatch_mat_mat_strided above).
            crate::tensor::ops::matmul::scalar_kernels::matmul_scalar_2d_2d(
                a_data.as_ptr(),
                b_data.as_ptr(),
                c_data.as_mut_ptr(),
                size,
                size,
                size,
                size,
                1,
                size,
                1,
                size,
                1,
            );
        }

        println!("A @ I = {:?}", c_data);
        println!("Expected: {:?}", a_data);

        for (i, (&expected, &actual)) in a_data.iter().zip(c_data.iter()).enumerate() {
            if (expected - actual).abs() > 1e-6 {
                println!(
                    "Scalar kernel error at index {}: expected={}, actual={}",
                    i, expected, actual
                );
            }
        }

        let a_data2 = [1.0, 2.0, 3.0, 4.0];
        let b_data2 = [2.0, 0.0, 0.0, 2.0];
        let mut c_data2 = vec![0.0; 4];

        unsafe {
            crate::tensor::ops::matmul::scalar_kernels::matmul_scalar_2d_2d(
                a_data2.as_ptr(),
                b_data2.as_ptr(),
                c_data2.as_mut_ptr(),
                2,
                2,
                2,
                2,
                1,
                2,
                1,
                2,
                1,
            );
        }

        println!("[[1,2],[3,4]] @ [[2,0],[0,2]] = {:?}", c_data2);
        println!("Expected: [2.0, 4.0, 6.0, 8.0]");
    }

    #[test]
    fn test_minimal_noncontiguous_gradient() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Minimal Non-contiguous Gradient Test ===");

        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        let right = Tensor::from_slice(&[1.0, 0.0, 0.0, 1.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        println!("Left requires_grad: {}", left.requires_grad());
        println!("Right requires_grad: {}", right.requires_grad());

        let left_nc = left.transpose(0, 1).transpose(0, 1);
        let right_nc = right.transpose(0, 1).transpose(0, 1);

        println!("Left NC requires_grad: {}", left_nc.requires_grad());
        println!("Right NC requires_grad: {}", right_nc.requires_grad());
        println!("Left NC is_contiguous: {}", left_nc.is_contiguous());
        println!("Right NC is_contiguous: {}", right_nc.is_contiguous());

        let mut result = left_nc.matmul(&right_nc);
        println!("Result requires_grad: {}", result.requires_grad());

        result.backward(None);

        println!(
            "Left NC gradient exists: {}",
            left_nc.grad_owned().is_some()
        );
        println!(
            "Right NC gradient exists: {}",
            right_nc.grad_owned().is_some()
        );

        println!(
            "Original left gradient exists: {}",
            left.grad_owned().is_some()
        );
        println!(
            "Original right gradient exists: {}",
            right.grad_owned().is_some()
        );
    }

    #[test]
    fn test_debug_transpose_gradient_tracking() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Debug Transpose Gradient Tracking ===");

        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();

        println!("Original requires_grad: {}", original.requires_grad());

        let t1 = original.transpose(0, 1);
        println!(
            "After first transpose requires_grad: {}",
            t1.requires_grad()
        );

        let t2 = t1.transpose(0, 1);
        println!(
            "After second transpose requires_grad: {}",
            t2.requires_grad()
        );

        let left = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![2, 2, 3],
        )
        .unwrap()
        .with_requires_grad();

        println!("Left original requires_grad: {}", left.requires_grad());

        let left_nc = left.transpose(1, 2).transpose(1, 2);
        println!(
            "Left non-contiguous requires_grad: {}",
            left_nc.requires_grad()
        );
        println!(
            "Left non-contiguous is_contiguous: {}",
            left_nc.is_contiguous()
        );

        let right = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![2, 3, 2],
        )
        .unwrap()
        .with_requires_grad();

        let right_nc = right.transpose(1, 2).transpose(1, 2);
        println!(
            "Right non-contiguous requires_grad: {}",
            right_nc.requires_grad()
        );
        println!(
            "Right non-contiguous is_contiguous: {}",
            right_nc.is_contiguous()
        );

        let mut result = left_nc.matmul(&right_nc);
        println!("Result requires_grad: {}", result.requires_grad());

        result.backward(None);

        println!("Left gradient exists: {}", left_nc.grad_owned().is_some());
        println!("Right gradient exists: {}", right_nc.grad_owned().is_some());

        if let Some(grad_left) = left_nc.grad_owned() {
            println!("Left gradient shape: {:?}", grad_left.shape().dims());
        }
        if let Some(grad_right) = right_nc.grad_owned() {
            println!("Right gradient shape: {:?}", grad_right.shape().dims());
        }
    }

    #[test]
    fn test_debug_4d_gradient_issue() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Debug 4D Gradient Issue ===");

        let left_shape = vec![2, 3, 4, 5];
        let right_shape = vec![2, 3, 5, 6];

        let left_size: usize = left_shape.iter().product();
        let right_size: usize = right_shape.iter().product();

        let left_data: Vec<f32> = (0..left_size).map(|i| (i as f32) * 0.1 + 1.0).collect();
        let right_data: Vec<f32> = (0..right_size).map(|i| (i as f32) * 0.2 + 0.5).collect();

        let left = Tensor::from_slice(&left_data, left_shape.clone())
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&right_data, right_shape.clone())
            .unwrap()
            .with_requires_grad();

        println!("Left shape: {:?}", left.shape().dims());
        println!("Right shape: {:?}", right.shape().dims());

        let result_no_grad = crate::gradtrack::with_no_grad(|| left.matmul(&right));
        println!("Forward result shape: {:?}", result_no_grad.shape().dims());
        println!("Forward result[0] = {}", result_no_grad.data()[0]);

        let mut result = left.matmul(&right);
        println!("Result shape: {:?}", result.shape().dims());

        let forward_diff = (result.data()[0] - result_no_grad.data()[0]).abs();
        println!("Forward pass difference: {}", forward_diff);

        result.backward(None);

        let grad_left = left.grad_owned().unwrap();

        println!("Left gradient shape: {:?}", grad_left.shape().dims());
        println!("Left gradient[40] = {}", grad_left.data()[40]);
        println!("Expected: ~78, Got: {}", grad_left.data()[40]);

        let grad_sum: f32 = grad_left.data().iter().sum();
        println!("Gradient sum: {} (should be non-zero)", grad_sum);

        if grad_sum.abs() < 1e-6 {
            println!("ERROR: Gradient is essentially zero - major computation issue!");
        }

        println!("\n=== Manual Verification ===");

        let grad_ones = Tensor::ones(result.shape().dims().to_vec());
        println!("Grad ones shape: {:?}", grad_ones.shape().dims());

        let right_rank = right.shape().dims().len();
        let right_t = right.transpose(right_rank - 2, right_rank - 1);
        println!("Right transposed shape: {:?}", right_t.shape().dims());

        let manual_grad_left =
            crate::gradtrack::with_no_grad(|| grad_ones.matmul(&right_t.contiguous()));
        println!(
            "Manual grad left shape: {:?}",
            manual_grad_left.shape().dims()
        );
        println!("Manual grad left[40] = {}", manual_grad_left.data()[40]);

        let diff = (grad_left.data()[40] - manual_grad_left.data()[40]).abs();
        println!("Difference at element 40: {}", diff);

        if diff > 1e-3 {
            println!("MAJOR GRADIENT ERROR DETECTED!");
            println!(
                "Expected (manual): {}, Got (automatic): {}",
                manual_grad_left.data()[40],
                grad_left.data()[40]
            );

            println!("\n=== Investigating Root Cause ===");
            println!(
                "Manual gradient sum: {}",
                manual_grad_left.data().iter().sum::<f32>()
            );
            println!(
                "Automatic gradient sum: {}",
                grad_left.data().iter().sum::<f32>()
            );
        } else {
            println!("Gradients match!");
        }
    }

    #[test]
    fn test_debug_matmul_with_known_values() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Debug MatMul with Known Values ===");

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[1.0, 0.0, 0.0, 1.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        println!("A: {:?}", a.data());
        println!("B (identity): {:?}", b.data());

        let mut result = a.matmul(&b);
        println!("A @ B: {:?}", result.data());
        println!("Expected: {:?}", a.data());

        result.backward(None);

        let grad_a = a.grad_owned().unwrap();
        let grad_b = b.grad_owned().unwrap();

        println!("grad_A: {:?}", grad_a.data());
        println!("grad_B: {:?}", grad_b.data());

        // grad_A = dL/dC · B^T = ones · I = ones
        let expected_grad_a = vec![1.0, 1.0, 1.0, 1.0];
        // grad_B = A^T · dL/dC = [[1, 3], [2, 4]] · ones = [[4, 4], [6, 6]] (row-major)
        let expected_grad_b = vec![4.0, 4.0, 6.0, 6.0];
        println!("Expected grad_A: {:?}", expected_grad_a);
        println!("Expected grad_B: {:?}", expected_grad_b);

        for (i, (&expected, &actual)) in
            expected_grad_a.iter().zip(grad_a.data().iter()).enumerate()
        {
            if (expected - actual).abs() > 1e-5 {
                println!(
                    "grad_A error at index {}: expected={}, actual={}, diff={}",
                    i,
                    expected,
                    actual,
                    (expected - actual).abs()
                );
            }
        }

        for (i, (&expected, &actual)) in
            expected_grad_b.iter().zip(grad_b.data().iter()).enumerate()
        {
            if (expected - actual).abs() > 1e-5 {
                println!(
                    "grad_B error at index {}: expected={}, actual={}, diff={}",
                    i,
                    expected,
                    actual,
                    (expected - actual).abs()
                );
            }
        }
    }

    #[test]
    fn test_matmul_1d_nd_broadcast_forward_backward() {
        let k = 4;
        let b = 2;
        let n = 3;
        let left_data: Vec<f32> = (0..k).map(|i| i as f32 + 1.0).collect();
        let right_data: Vec<f32> = (0..b * k * n).map(|i| (i as f32) * 0.1 + 0.5).collect();

        let left = Tensor::from_slice(&left_data, vec![k])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&right_data, vec![b, k, n])
            .unwrap()
            .with_requires_grad();

        let mut out = left.matmul(&right);
        assert_eq!(out.shape().dims(), vec![b, n]);

        out.backward(None);
        let grad_left = left.grad_owned().unwrap();
        let grad_right = right.grad_owned().unwrap();

        assert_eq!(grad_left.shape().dims(), vec![k]);
        assert_eq!(grad_right.shape().dims(), vec![b, k, n]);
    }

    #[test]
    fn test_matmul_nd_1d_broadcast_forward_backward() {
        let b = 3;
        let m = 2;
        let k = 5;
        let left_data: Vec<f32> = (0..b * m * k).map(|i| (i as f32) * 0.05 + 1.0).collect();
        let right_data: Vec<f32> = (0..k).map(|i| (i as f32) * 0.2 + 0.5).collect();

        let left = Tensor::from_slice(&left_data, vec![b, m, k])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&right_data, vec![k])
            .unwrap()
            .with_requires_grad();

        let mut out = left.matmul(&right);
        assert_eq!(out.shape().dims(), vec![b, m]);

        out.backward(None);
        let grad_left = left.grad_owned().unwrap();
        let grad_right = right.grad_owned().unwrap();

        assert_eq!(grad_left.shape().dims(), vec![b, m, k]);
        assert_eq!(grad_right.shape().dims(), vec![k]);
    }

    #[test]
    fn test_batched_mat_vec_forward_and_shapes() {
        let b = 3usize;
        let m = 5usize;
        let k = 7usize;
        let left_data: Vec<f32> = (0..b * m * k).map(|i| (i as f32) * 0.01 + 1.0).collect();
        let right_data: Vec<f32> = (0..b * k).map(|i| (i as f32) * 0.02 + 0.5).collect();
        let left = Tensor::from_slice(&left_data, vec![b, m, k]).unwrap();
        let right = Tensor::from_slice(&right_data, vec![b, k]).unwrap();
        let out = left.matmul(&right);
        assert_eq!(out.shape().dims(), vec![b, m]);
    }

    #[test]
    fn test_batched_mat_vec_numeric_small() {
        let left = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // batch 0: [[1, 2, 3], [4, 5, 6]]
                1.0, 1.0, 1.0, 2.0, 2.0, 2.0, // batch 1: [[1, 1, 1], [2, 2, 2]]
            ],
            vec![2, 2, 3],
        )
        .unwrap();
        let right = Tensor::from_slice(&[1.0, 1.0, 1.0, 2.0, 3.0, 4.0], vec![2, 3]).unwrap();
        let out = left.matmul(&right);
        assert_eq!(out.shape().dims(), vec![2, 2]);
        unsafe {
            assert!((*out.as_ptr() - 6.0).abs() < 1e-6);
            assert!((*out.as_ptr().add(1) - 15.0).abs() < 1e-6);
            assert!((*out.as_ptr().add(2) - 9.0).abs() < 1e-6);
            assert!((*out.as_ptr().add(3) - 18.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_batched_vector_cases_grad_shapes() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2])
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&[1.0, 1.0], vec![1, 2])
            .unwrap()
            .with_requires_grad();
        let mut out = left.matmul(&right);
        assert_eq!(out.shape().dims(), vec![1, 2]);
        out.backward(None);
        let gl = left.grad_owned().unwrap();
        let gr = right.grad_owned().unwrap();
        assert_eq!(gl.shape().dims(), vec![1, 2, 2]);
        assert_eq!(gr.shape().dims(), vec![1, 2]);

        clear_gradients();
        let left2 = Tensor::from_slice(&[1.0, 1.0], vec![2])
            .unwrap()
            .with_requires_grad();
        let right2 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut out2 = left2.matmul(&right2);
        assert_eq!(out2.shape().dims(), vec![2]);
        out2.backward(None);
        let gl2 = left2.grad_owned().unwrap();
        let gr2 = right2.grad_owned().unwrap();
        assert_eq!(gl2.shape().dims(), vec![2]);
        assert_eq!(gr2.shape().dims(), vec![2, 2]);
    }

    #[test]
    fn test_broadcast_matmul_gradient_complex_case() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        println!("=== Testing complex broadcast case: [1,2,2,1,4,5] @ [2,1,5,6] ===");

        let left_shape = vec![1, 2, 2, 1, 4, 5];
        let right_shape = vec![2, 1, 5, 6];
        let expected_result_shape = vec![1, 2, 2, 1, 4, 6];

        println!("Left shape: {:?}", left_shape);
        println!("Right shape: {:?}", right_shape);
        println!("Expected result shape: {:?}", expected_result_shape);

        let left_size: usize = left_shape.iter().product();
        let right_size: usize = right_shape.iter().product();

        let left_data: Vec<f32> = (0..left_size).map(|i| (i as f32) * 0.01 + 1.0).collect();
        let right_data: Vec<f32> = (0..right_size).map(|i| (i as f32) * 0.02 + 0.5).collect();

        let left = Tensor::from_slice(&left_data, left_shape.clone())
            .unwrap()
            .with_requires_grad();
        let right = Tensor::from_slice(&right_data, right_shape.clone())
            .unwrap()
            .with_requires_grad();

        let mut result = left.matmul(&right);
        println!("Actual result shape: {:?}", result.shape().dims());

        assert_eq!(
            result.shape().dims(),
            expected_result_shape,
            "Forward result shape mismatch"
        );

        result.backward(None);

        let grad_left = left.grad_owned().unwrap();
        let grad_right = right.grad_owned().unwrap();

        println!("Left gradient shape: {:?}", grad_left.shape().dims());
        println!("Right gradient shape: {:?}", grad_right.shape().dims());

        assert_eq!(
            grad_left.shape().dims(),
            left_shape,
            "Left gradient shape mismatch"
        );
        assert_eq!(
            grad_right.shape().dims(),
            right_shape,
            "Right gradient shape mismatch"
        );

        let left_grad_sum: f32 = grad_left.data().iter().sum();
        let right_grad_sum: f32 = grad_right.data().iter().sum();

        println!("Left gradient sum: {}", left_grad_sum);
        println!("Right gradient sum: {}", right_grad_sum);

        assert!(
            left_grad_sum.abs() > 1e-6,
            "Left gradient should not be zero"
        );
        assert!(
            right_grad_sum.abs() > 1e-6,
            "Right gradient should not be zero"
        );

        println!("✓ Complex broadcast matmul gradient test passed!");
    }

    #[test]
    fn test_grad_left_vec_at_bkn_manual() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();
        let k = 8usize;
        let b = 3usize;
        let n = 5usize;

        let left = Tensor::from_slice(
            &(0..k).map(|i| i as f32 * 0.1 + 1.0).collect::<Vec<_>>(),
            vec![k],
        )
        .unwrap()
        .with_requires_grad();

        let right = Tensor::from_slice(
            &(0..b * k * n)
                .map(|i| i as f32 * 0.2 + 0.5)
                .collect::<Vec<_>>(),
            vec![b, k, n],
        )
        .unwrap()
        .with_requires_grad();

        let mut out = left.matmul(&right);
        out.backward(None);

        let grad_left = left.grad_owned().unwrap();
        assert_eq!(grad_left.shape().dims(), vec![k]);
        let rp = unsafe { right.as_ptr() };
        for kk in 0..k {
            let mut sum = 0.0f32;
            for bb in 0..b {
                for nn in 0..n {
                    let idx = bb * (k * n) + kk * n + nn;
                    unsafe { sum += *rp.add(idx) };
                }
            }
            let got = unsafe { *grad_left.as_ptr().add(kk) };
            assert!(
                (got - sum).abs() < 1e-4,
                "k={} got={} expected={}",
                kk,
                got,
                sum
            );
        }
    }
}