1use crate::error::{CoreError, Result};
7use crate::tensor::Tensor;
8use crate::{Float, Scalar};
9
10pub fn dot<T: Scalar>(x: &Tensor<T>, y: &Tensor<T>) -> Result<T> {
27 check_vectors(x, y, "dot")?;
28 Ok(dot_slice(x.as_slice(), y.as_slice()))
29}
30
31fn dot_slice<T: Scalar>(a: &[T], b: &[T]) -> T {
33 #[cfg(feature = "simd")]
34 {
35 use crate::simd;
36 use std::any::TypeId;
37 if TypeId::of::<T>() == TypeId::of::<f64>() {
38 let result =
40 unsafe { simd::f64_ops::dot_f64(simd::slice_as_f64(a), simd::slice_as_f64(b)) };
41 return unsafe { simd::f64_to_t(result) };
42 }
43 if TypeId::of::<T>() == TypeId::of::<f32>() {
44 let result =
46 unsafe { simd::f32_ops::dot_f32(simd::slice_as_f32(a), simd::slice_as_f32(b)) };
47 return unsafe { simd::f32_to_t(result) };
48 }
49 }
50 a.iter()
51 .zip(b.iter())
52 .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
53}
54
55pub fn axpy<T: Scalar>(alpha: T, x: &Tensor<T>, y: &mut Tensor<T>) -> Result<()> {
68 check_vectors(x, y, "axpy")?;
69 axpy_slice(alpha, x.as_slice(), y.as_mut_slice());
70 Ok(())
71}
72
73fn axpy_slice<T: Scalar>(alpha: T, x: &[T], y: &mut [T]) {
75 #[cfg(feature = "simd")]
76 {
77 use crate::simd;
78 use std::any::TypeId;
79 if TypeId::of::<T>() == TypeId::of::<f64>() {
80 unsafe {
82 simd::f64_ops::axpy_f64(
83 simd::t_to_f64(alpha),
84 simd::slice_as_f64(x),
85 simd::slice_as_f64_mut(y),
86 );
87 }
88 return;
89 }
90 if TypeId::of::<T>() == TypeId::of::<f32>() {
91 unsafe {
93 simd::f32_ops::axpy_f32(
94 simd::t_to_f32(alpha),
95 simd::slice_as_f32(x),
96 simd::slice_as_f32_mut(y),
97 );
98 }
99 return;
100 }
101 }
102 for (yi, &xi) in y.iter_mut().zip(x.iter()) {
103 *yi += alpha * xi;
104 }
105}
106
107pub fn nrm2<T: Float>(x: &Tensor<T>) -> Result<T> {
117 check_vector(x, "nrm2")?;
118 #[cfg(feature = "simd")]
119 {
120 use crate::simd;
121 use std::any::TypeId;
122 if TypeId::of::<T>() == TypeId::of::<f64>() {
123 let result =
125 unsafe { simd::f64_ops::sum_sq_f64(simd::slice_as_f64(x.as_slice())).sqrt() };
126 return Ok(unsafe { simd::f64_to_t(result) });
127 }
128 if TypeId::of::<T>() == TypeId::of::<f32>() {
129 let result =
131 unsafe { simd::f32_ops::sum_sq_f32(simd::slice_as_f32(x.as_slice())).sqrt() };
132 return Ok(unsafe { simd::f32_to_t(result) });
133 }
134 }
135 let sum_sq = x.as_slice().iter().fold(T::zero(), |acc, &v| acc + v * v);
136 Ok(sum_sq.sqrt())
137}
138
139pub fn asum<T: Float>(x: &Tensor<T>) -> Result<T> {
149 check_vector(x, "asum")?;
150 #[cfg(feature = "simd")]
151 {
152 use crate::simd;
153 use std::any::TypeId;
154 if TypeId::of::<T>() == TypeId::of::<f64>() {
155 let result = unsafe { simd::f64_ops::asum_f64(simd::slice_as_f64(x.as_slice())) };
157 return Ok(unsafe { simd::f64_to_t(result) });
158 }
159 if TypeId::of::<T>() == TypeId::of::<f32>() {
160 let result = unsafe { simd::f32_ops::asum_f32(simd::slice_as_f32(x.as_slice())) };
162 return Ok(unsafe { simd::f32_to_t(result) });
163 }
164 }
165 let result = x.as_slice().iter().fold(T::zero(), |acc, &v| acc + v.abs());
166 Ok(result)
167}
168
169pub fn scal<T: Scalar>(alpha: T, x: &mut Tensor<T>) -> Result<()> {
179 check_vector(x, "scal")?;
180 #[cfg(feature = "simd")]
181 {
182 use crate::simd;
183 use std::any::TypeId;
184 if TypeId::of::<T>() == TypeId::of::<f64>() {
185 unsafe {
187 simd::f64_ops::scal_f64(
188 simd::t_to_f64(alpha),
189 simd::slice_as_f64_mut(x.as_mut_slice()),
190 );
191 }
192 return Ok(());
193 }
194 if TypeId::of::<T>() == TypeId::of::<f32>() {
195 unsafe {
197 simd::f32_ops::scal_f32(
198 simd::t_to_f32(alpha),
199 simd::slice_as_f32_mut(x.as_mut_slice()),
200 );
201 }
202 return Ok(());
203 }
204 }
205 for v in x.as_mut_slice() {
206 *v *= alpha;
207 }
208 Ok(())
209}
210
211pub fn iamax<T: Float>(x: &Tensor<T>) -> Result<Option<usize>> {
222 check_vector(x, "iamax")?;
223 if x.is_empty() {
224 return Ok(None);
225 }
226 let mut max_idx = 0;
227 let mut max_val = x.as_slice()[0].abs();
228 for (i, &v) in x.as_slice().iter().enumerate().skip(1) {
229 let av = v.abs();
230 if av > max_val {
231 max_val = av;
232 max_idx = i;
233 }
234 }
235 Ok(Some(max_idx))
236}
237
238#[allow(clippy::many_single_char_names)]
261pub fn gemv<T: Scalar>(
262 alpha: T,
263 a: &Tensor<T>,
264 x: &Tensor<T>,
265 beta: T,
266 y: &mut Tensor<T>,
267) -> Result<()> {
268 if a.ndim() != 2 {
269 return Err(CoreError::InvalidArgument {
270 reason: "gemv: `a` must be a 2-D tensor (matrix)",
271 });
272 }
273 if x.ndim() != 1 {
274 return Err(CoreError::InvalidArgument {
275 reason: "gemv: `x` must be a 1-D tensor (vector)",
276 });
277 }
278 if y.ndim() != 1 {
279 return Err(CoreError::InvalidArgument {
280 reason: "gemv: `y` must be a 1-D tensor (vector)",
281 });
282 }
283
284 let m = a.shape()[0];
285 let n = a.shape()[1];
286
287 if x.numel() != n {
288 return Err(CoreError::DimensionMismatch {
289 expected: vec![n],
290 got: x.shape().to_vec(),
291 });
292 }
293 if y.numel() != m {
294 return Err(CoreError::DimensionMismatch {
295 expected: vec![m],
296 got: y.shape().to_vec(),
297 });
298 }
299
300 let a_data = a.as_slice();
301 let x_data = x.as_slice();
302 let y_data = y.as_mut_slice();
303
304 for (i, yi) in y_data.iter_mut().enumerate().take(m) {
305 let row_offset = i * n;
306 let row = &a_data[row_offset..row_offset + n];
307 let sum = dot_slice(row, x_data);
308 *yi = alpha * sum + beta * *yi;
309 }
310
311 Ok(())
312}
313
314#[allow(clippy::many_single_char_names, clippy::too_many_lines)]
336pub fn gemm<T: Scalar>(
337 alpha: T,
338 a: &Tensor<T>,
339 b: &Tensor<T>,
340 beta: T,
341 c: &mut Tensor<T>,
342) -> Result<()> {
343 const MC: usize = 64; const KC: usize = 256; const NC: usize = 256; if a.ndim() != 2 || b.ndim() != 2 || c.ndim() != 2 {
349 return Err(CoreError::InvalidArgument {
350 reason: "gemm: all arguments must be 2-D tensors (matrices)",
351 });
352 }
353
354 let m = a.shape()[0];
355 let k = a.shape()[1];
356 let n = b.shape()[1];
357
358 if b.shape()[0] != k {
359 return Err(CoreError::DimensionMismatch {
360 expected: vec![k, n],
361 got: b.shape().to_vec(),
362 });
363 }
364 if c.shape()[0] != m || c.shape()[1] != n {
365 return Err(CoreError::DimensionMismatch {
366 expected: vec![m, n],
367 got: c.shape().to_vec(),
368 });
369 }
370
371 let a_data = a.as_slice();
372 let b_data = b.as_slice();
373 let c_data = c.as_mut_slice();
374
375 if beta == T::zero() {
377 for v in c_data.iter_mut() {
378 *v = T::zero();
379 }
380 } else if beta != T::one() {
381 for v in c_data.iter_mut() {
382 *v *= beta;
383 }
384 }
385
386 for pk in (0..k).step_by(KC) {
392 let kb = KC.min(k - pk);
393
394 for pi in (0..m).step_by(MC) {
396 let mb = MC.min(m - pi);
397
398 for pj in (0..n).step_by(NC) {
400 let nb = NC.min(n - pj);
401
402 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
405 {
406 use std::any::TypeId;
407 if TypeId::of::<T>() == TypeId::of::<f64>() {
408 unsafe {
409 let a_f64 = a_data.as_ptr().cast::<f64>();
410 let b_f64 = b_data.as_ptr().cast::<f64>();
411 let c_f64 = c_data.as_mut_ptr().cast::<f64>();
412 let alpha_f64 = crate::simd::t_to_f64(alpha);
413
414 let j4 = nb / 4 * 4;
415
416 let i8 = mb / 8 * 8;
418 for i in (0..i8).step_by(8) {
419 for j in (0..j4).step_by(4) {
420 let a_off = (pi + i) * k + pk;
421 let b_off = pk * n + (pj + j);
422 let c_off = (pi + i) * n + (pj + j);
423 crate::simd::neon_f64_ops::gemm_8x4_f64_neon(
424 a_f64.add(a_off),
425 b_f64.add(b_off),
426 c_f64.add(c_off),
427 alpha_f64,
428 kb, k, n, n,
429 );
430 }
431 if j4 < nb {
433 for ii in 0..8 {
434 let row_a = (pi + i + ii) * k + pk;
435 let row_c = (pi + i + ii) * n + pj + j4;
436 for p in 0..kb {
437 let scale_f64 = alpha_f64 * *a_f64.add(row_a + p);
438 for jj in 0..(nb - j4) {
439 let b_idx = (pk + p) * n + pj + j4 + jj;
440 *c_f64.add(row_c + jj) += scale_f64 * *b_f64.add(b_idx);
441 }
442 }
443 }
444 }
445 }
446 let i4_start = i8;
448 let i4_end = i4_start + (mb - i8) / 4 * 4;
449 for i in (i4_start..i4_end).step_by(4) {
450 for j in (0..j4).step_by(4) {
451 let a_off = (pi + i) * k + pk;
452 let b_off = pk * n + (pj + j);
453 let c_off = (pi + i) * n + (pj + j);
454 crate::simd::neon_f64_ops::gemm_4x4_f64_neon(
455 a_f64.add(a_off),
456 b_f64.add(b_off),
457 c_f64.add(c_off),
458 alpha_f64,
459 kb, k, n, n,
460 );
461 }
462 if j4 < nb {
463 for ii in 0..4 {
464 let row_a = (pi + i + ii) * k + pk;
465 let row_c = (pi + i + ii) * n + pj + j4;
466 for p in 0..kb {
467 let scale_f64 = alpha_f64 * *a_f64.add(row_a + p);
468 for jj in 0..(nb - j4) {
469 let b_idx = (pk + p) * n + pj + j4 + jj;
470 *c_f64.add(row_c + jj) += scale_f64 * *b_f64.add(b_idx);
471 }
472 }
473 }
474 }
475 }
476 for i in i4_end..mb {
478 let row_a = (pi + i) * k + pk;
479 let row_c = (pi + i) * n + pj;
480 for p in 0..kb {
481 let scale = alpha * a_data[row_a + p];
482 let b_off2 = (pk + p) * n + pj;
483 let b_row = &b_data[b_off2..b_off2 + nb];
484 let c_slice = &mut c_data[row_c..row_c + nb];
485 axpy_slice(scale, b_row, c_slice);
486 }
487 }
488 }
489 continue;
490 }
491 }
492
493 for i in 0..mb {
495 let row_a = (pi + i) * k + pk;
496 let row_c = (pi + i) * n + pj;
497 for p in 0..kb {
498 let scale = alpha * a_data[row_a + p];
499 let b_off = (pk + p) * n + pj;
500 let b_row = &b_data[b_off..b_off + nb];
501 let c_slice = &mut c_data[row_c..row_c + nb];
502 axpy_slice(scale, b_row, c_slice);
503 }
504 }
505 }
506 }
507 }
508
509 Ok(())
510}
511
512impl<T: Scalar> Tensor<T> {
517 pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
531 let m = self.shape().first().copied().unwrap_or(0);
532 let mut y = Tensor::zeros(vec![m]);
533 gemv(T::one(), self, x, T::zero(), &mut y)?;
534 Ok(y)
535 }
536
537 pub fn matmul(&self, other: &Tensor<T>) -> Result<Tensor<T>> {
551 let m = self.shape().first().copied().unwrap_or(0);
552 let n = other.shape().get(1).copied().unwrap_or(0);
553 let mut c = Tensor::zeros(vec![m, n]);
554 gemm(T::one(), self, other, T::zero(), &mut c)?;
555 Ok(c)
556 }
557
558 pub fn dot(&self, other: &Tensor<T>) -> Result<T> {
569 dot(self, other)
570 }
571}
572
573impl<T: Float> Tensor<T> {
574 pub fn norm(&self) -> Result<T> {
584 nrm2(self)
585 }
586
587 pub fn solve(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
602 crate::linalg::solve(self, b)
603 }
604
605 pub fn inv(&self) -> Result<Tensor<T>> {
619 crate::linalg::inv(self)
620 }
621
622 pub fn det(&self) -> Result<T> {
634 crate::linalg::det(self)
635 }
636
637 pub fn lstsq(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
652 crate::linalg::lstsq(self, b)
653 }
654}
655
656fn check_vector<T: Scalar>(x: &Tensor<T>, name: &'static str) -> Result<()> {
661 if x.ndim() != 1 {
662 return Err(CoreError::InvalidArgument {
663 reason: match name {
664 "nrm2" => "nrm2: expected a 1-D tensor",
665 "asum" => "asum: expected a 1-D tensor",
666 "scal" => "scal: expected a 1-D tensor",
667 "iamax" => "iamax: expected a 1-D tensor",
668 _ => "expected a 1-D tensor",
669 },
670 });
671 }
672 Ok(())
673}
674
675fn check_vectors<T: Scalar>(x: &Tensor<T>, y: &Tensor<T>, name: &'static str) -> Result<()> {
676 if x.ndim() != 1 || y.ndim() != 1 {
677 return Err(CoreError::InvalidArgument {
678 reason: match name {
679 "dot" => "dot: both arguments must be 1-D tensors",
680 "axpy" => "axpy: both arguments must be 1-D tensors",
681 _ => "both arguments must be 1-D tensors",
682 },
683 });
684 }
685 if x.numel() != y.numel() {
686 return Err(CoreError::DimensionMismatch {
687 expected: x.shape().to_vec(),
688 got: y.shape().to_vec(),
689 });
690 }
691 Ok(())
692}
693
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Builds a 1-D `f64` tensor from `data`.
    fn vec_f64(data: &[f64]) -> Tensor<f64> {
        Tensor::from_vec(data.to_vec(), vec![data.len()]).unwrap()
    }

    /// Builds a `rows x cols` `f64` tensor from row-major `data`.
    fn mat_f64(data: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
        Tensor::from_vec(data.to_vec(), vec![rows, cols]).unwrap()
    }

    /// Deterministic fill with small integers in [-5, 5].
    ///
    /// Every product and partial sum in the tests below stays an exactly representable
    /// integer, so results do not depend on summation order or FMA use. That allows
    /// exact comparison between blocked/SIMD paths and the naive reference.
    fn int_fill(len: usize, seed: usize) -> Vec<f64> {
        (0..len)
            .map(|i| ((i * 7 + seed * 3) % 11) as f64 - 5.0)
            .collect()
    }

    /// Naive triple-loop reference for `alpha * A * B + beta * C` on row-major
    /// buffers, with `dims = (m, k, n)`.
    fn naive_gemm(
        alpha: f64,
        a: &[f64],
        b: &[f64],
        beta: f64,
        c: &[f64],
        dims: (usize, usize, usize),
    ) -> Vec<f64> {
        let (m, k, n) = dims;
        let mut out = vec![0.0; m * n];
        for i in 0..m {
            for j in 0..n {
                let s: f64 = (0..k).map(|p| a[i * k + p] * b[p * n + j]).sum();
                out[i * n + j] = alpha * s + beta * c[i * n + j];
            }
        }
        out
    }

    /// Runs `gemm` (alpha = 2, beta = 3) on an `m x k` by `k x n` problem and compares
    /// against `naive_gemm`.
    fn check_gemm_against_naive(m: usize, k: usize, n: usize) {
        let a_vals = int_fill(m * k, 1);
        let b_vals = int_fill(k * n, 2);
        let c_vals = int_fill(m * n, 3);
        let a = mat_f64(&a_vals, m, k);
        let b = mat_f64(&b_vals, k, n);
        let mut c = mat_f64(&c_vals, m, n);
        gemm(2.0, &a, &b, 3.0, &mut c).unwrap();
        let expected = naive_gemm(2.0, &a_vals, &b_vals, 3.0, &c_vals, (m, k, n));
        assert_eq!(c.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_dot_basic() {
        let x = vec_f64(&[1.0, 2.0, 3.0]);
        let y = vec_f64(&[4.0, 5.0, 6.0]);
        assert_eq!(dot(&x, &y).unwrap(), 32.0);
    }

    #[test]
    fn test_dot_single() {
        let x = vec_f64(&[3.0]);
        let y = vec_f64(&[7.0]);
        assert_eq!(dot(&x, &y).unwrap(), 21.0);
    }

    #[test]
    fn test_dot_length_mismatch() {
        let x = vec_f64(&[1.0, 2.0]);
        let y = vec_f64(&[1.0, 2.0, 3.0]);
        assert!(dot(&x, &y).is_err());
    }

    #[test]
    fn test_dot_not_1d() {
        let x = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let y = vec_f64(&[1.0, 2.0]);
        assert!(dot(&x, &y).is_err());
    }

    #[test]
    fn test_axpy() {
        let x = vec_f64(&[1.0, 2.0, 3.0]);
        let mut y = vec_f64(&[10.0, 20.0, 30.0]);
        axpy(2.0, &x, &mut y).unwrap();
        assert_eq!(y.as_slice(), &[12.0, 24.0, 36.0]);
    }

    #[test]
    fn test_axpy_zero_alpha() {
        let x = vec_f64(&[1.0, 2.0, 3.0]);
        let mut y = vec_f64(&[10.0, 20.0, 30.0]);
        axpy(0.0, &x, &mut y).unwrap();
        assert_eq!(y.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_nrm2() {
        let x = vec_f64(&[3.0, 4.0]);
        assert!((nrm2(&x).unwrap() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_nrm2_single() {
        let x = vec_f64(&[-7.0]);
        assert!((nrm2(&x).unwrap() - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_asum() {
        let x = vec_f64(&[-1.0, 2.0, -3.0, 4.0]);
        assert!((asum(&x).unwrap() - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_scal() {
        let mut x = vec_f64(&[1.0, 2.0, 3.0]);
        scal(10.0, &mut x).unwrap();
        assert_eq!(x.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_scal_zero() {
        let mut x = vec_f64(&[1.0, 2.0, 3.0]);
        scal(0.0, &mut x).unwrap();
        assert_eq!(x.as_slice(), &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_iamax() {
        let x = vec_f64(&[1.0, -5.0, 3.0, -2.0]);
        assert_eq!(iamax(&x).unwrap(), Some(1));
    }

    #[test]
    fn test_iamax_first_is_max() {
        let x = vec_f64(&[100.0, 1.0, 2.0]);
        assert_eq!(iamax(&x).unwrap(), Some(0));
    }

    #[test]
    fn test_iamax_tie_returns_first() {
        let x = vec_f64(&[1.0, -3.0, 3.0]);
        assert_eq!(iamax(&x).unwrap(), Some(1));
    }

    #[test]
    fn test_iamax_empty() {
        let x = Tensor::<f64>::zeros(vec![0]);
        assert_eq!(iamax(&x).unwrap(), None);
    }

    #[test]
    fn test_gemv_basic() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vec_f64(&[5.0, 6.0]);
        let mut y = Tensor::<f64>::zeros(vec![2]);
        gemv(1.0, &a, &x, 0.0, &mut y).unwrap();
        assert_eq!(y.as_slice(), &[17.0, 39.0]);
    }

    #[test]
    fn test_gemv_with_alpha_beta() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vec_f64(&[1.0, 1.0]);
        let mut y = vec_f64(&[10.0, 10.0]);
        gemv(2.0, &a, &x, 3.0, &mut y).unwrap();
        assert_eq!(y.as_slice(), &[36.0, 44.0]);
    }

    #[test]
    fn test_gemv_rectangular() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let x = vec_f64(&[1.0, 0.0, 1.0]);
        let mut y = Tensor::<f64>::zeros(vec![2]);
        gemv(1.0, &a, &x, 0.0, &mut y).unwrap();
        assert_eq!(y.as_slice(), &[4.0, 10.0]);
    }

    #[test]
    fn test_gemv_matches_naive() {
        let (m, n) = (5, 9);
        let a_vals = int_fill(m * n, 4);
        let x_vals = int_fill(n, 5);
        let y_vals = int_fill(m, 6);
        let a = mat_f64(&a_vals, m, n);
        let x = vec_f64(&x_vals);
        let mut y = vec_f64(&y_vals);
        gemv(2.0, &a, &x, 3.0, &mut y).unwrap();
        // gemv is gemm with B = x viewed as an n x 1 matrix.
        let expected = naive_gemm(2.0, &a_vals, &x_vals, 3.0, &y_vals, (m, n, 1));
        assert_eq!(y.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_gemv_dimension_mismatch() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vec_f64(&[1.0, 2.0, 3.0]);
        let mut y = Tensor::<f64>::zeros(vec![2]);
        assert!(gemv(1.0, &a, &x, 0.0, &mut y).is_err());
    }

    #[test]
    fn test_gemv_y_dimension_mismatch() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vec_f64(&[1.0, 2.0]);
        let mut y = Tensor::<f64>::zeros(vec![3]);
        assert!(gemv(1.0, &a, &x, 0.0, &mut y).is_err());
    }

    #[test]
    fn test_gemm_square() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat_f64(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let mut c = Tensor::<f64>::zeros(vec![2, 2]);
        gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
        assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_gemm_rectangular() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = mat_f64(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
        let mut c = Tensor::<f64>::zeros(vec![2, 2]);
        gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
        assert_eq!(c.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_gemm_with_alpha_beta() {
        let a = mat_f64(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b = mat_f64(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let mut c = mat_f64(&[1.0, 1.0, 1.0, 1.0], 2, 2);
        gemm(2.0, &a, &b, 3.0, &mut c).unwrap();
        assert_eq!(c.as_slice(), &[13.0, 15.0, 17.0, 19.0]);
    }

    #[test]
    fn test_gemm_identity() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 3, 3);
        let eye = Tensor::<f64>::eye(3);
        let mut c = Tensor::<f64>::zeros(vec![3, 3]);
        gemm(1.0, &a, &eye, 0.0, &mut c).unwrap();
        assert_eq!(c.as_slice(), a.as_slice());
    }

    #[test]
    fn test_gemm_single_element() {
        let a = mat_f64(&[3.0], 1, 1);
        let b = mat_f64(&[7.0], 1, 1);
        let mut c = Tensor::<f64>::zeros(vec![1, 1]);
        gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
        assert_eq!(c.as_slice(), &[21.0]);
    }

    #[test]
    fn test_gemm_matches_naive_small_tails() {
        // 13 rows = 8 + 4 + 1 and 7 columns = 4 + 3 exercise every micro-kernel tail.
        check_gemm_against_naive(13, 5, 7);
    }

    #[test]
    fn test_gemm_matches_naive_across_row_and_depth_blocks() {
        // 67 > MC (64) and 260 > KC (256): crosses both row and depth tile boundaries.
        check_gemm_against_naive(67, 260, 9);
    }

    #[test]
    fn test_gemm_matches_naive_across_column_block() {
        // 261 > NC (256): crosses a column tile boundary with a ragged last tile.
        check_gemm_against_naive(3, 4, 261);
    }

    #[test]
    fn test_gemm_dimension_mismatch() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let mut c = Tensor::<f64>::zeros(vec![2, 2]);
        assert!(gemm(1.0, &a, &b, 0.0, &mut c).is_err());
    }

    #[test]
    fn test_gemm_c_shape_mismatch() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let mut c = Tensor::<f64>::zeros(vec![3, 3]);
        assert!(gemm(1.0, &a, &b, 0.0, &mut c).is_err());
    }

    #[test]
    fn test_matvec() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vec_f64(&[5.0, 6.0]);
        let y = a.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[17.0, 39.0]);
    }

    #[test]
    fn test_matmul() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = mat_f64(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_tensor_dot() {
        let x = vec_f64(&[1.0, 2.0, 3.0]);
        let y = vec_f64(&[4.0, 5.0, 6.0]);
        assert_eq!(x.dot(&y).unwrap(), 32.0);
    }

    #[test]
    fn test_tensor_norm() {
        let x = vec_f64(&[3.0, 4.0]);
        assert!((x.norm().unwrap() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_gemm_numpy_reference() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = mat_f64(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_gemv_numpy_reference() {
        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let x = vec_f64(&[1.0, 1.0, 1.0]);
        let y = a.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[6.0, 15.0]);
    }

    #[test]
    fn test_dot_numpy_reference() {
        let x = vec_f64(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = vec_f64(&[5.0, 4.0, 3.0, 2.0, 1.0]);
        assert_eq!(dot(&x, &y).unwrap(), 35.0);
    }

    #[test]
    fn test_nrm2_numpy_reference() {
        let x = vec_f64(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let n = nrm2(&x).unwrap();
        assert!((n - 7.416_198_487_095_663).abs() < 1e-12);
    }
}