Skip to main content

scivex_core/linalg/
blas.rs

1//! BLAS Level 1–3 operations on [`Tensor`].
2//!
3//! All functions operate on tensors and validate shapes, returning
4//! [`Result`] on dimension mismatches.
5
6use crate::error::{CoreError, Result};
7use crate::tensor::Tensor;
8use crate::{Float, Scalar};
9
10// ======================================================================
11// BLAS Level 1 — vector operations, O(n)
12// ======================================================================
13
14/// Inner (dot) product of two 1-D tensors: `sum(x_i * y_i)`.
15///
16/// Both tensors must be 1-D with the same length.
17///
18/// ```
19/// # use scivex_core::tensor::Tensor;
20/// # use scivex_core::linalg::dot;
21/// let x = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
22/// let y = Tensor::from_vec(vec![4.0_f64, 5.0, 6.0], vec![3]).unwrap();
23/// let d = dot(&x, &y).unwrap();
24/// assert!((d - 32.0).abs() < 1e-10);
25/// ```
26pub fn dot<T: Scalar>(x: &Tensor<T>, y: &Tensor<T>) -> Result<T> {
27    check_vectors(x, y, "dot")?;
28    Ok(dot_slice(x.as_slice(), y.as_slice()))
29}
30
31/// Inner dot product on raw slices, dispatching to SIMD when available.
32fn dot_slice<T: Scalar>(a: &[T], b: &[T]) -> T {
33    #[cfg(feature = "simd")]
34    {
35        use crate::simd;
36        use std::any::TypeId;
37        if TypeId::of::<T>() == TypeId::of::<f64>() {
38            // SAFETY: T is f64 confirmed by TypeId.
39            let result =
40                unsafe { simd::f64_ops::dot_f64(simd::slice_as_f64(a), simd::slice_as_f64(b)) };
41            return unsafe { simd::f64_to_t(result) };
42        }
43        if TypeId::of::<T>() == TypeId::of::<f32>() {
44            // SAFETY: T is f32 confirmed by TypeId.
45            let result =
46                unsafe { simd::f32_ops::dot_f32(simd::slice_as_f32(a), simd::slice_as_f32(b)) };
47            return unsafe { simd::f32_to_t(result) };
48        }
49    }
50    a.iter()
51        .zip(b.iter())
52        .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
53}
54
55/// `y = alpha * x + y` (in-place update of `y`).
56///
57/// Both tensors must be 1-D with the same length.
58///
59/// ```
60/// # use scivex_core::tensor::Tensor;
61/// # use scivex_core::linalg::axpy;
62/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
63/// let mut y = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
64/// axpy(2.0, &x, &mut y).unwrap();
65/// assert_eq!(y.as_slice(), &[12.0, 24.0, 36.0]);
66/// ```
67pub fn axpy<T: Scalar>(alpha: T, x: &Tensor<T>, y: &mut Tensor<T>) -> Result<()> {
68    check_vectors(x, y, "axpy")?;
69    axpy_slice(alpha, x.as_slice(), y.as_mut_slice());
70    Ok(())
71}
72
73/// In-place axpy on raw slices, dispatching to SIMD when available.
74fn axpy_slice<T: Scalar>(alpha: T, x: &[T], y: &mut [T]) {
75    #[cfg(feature = "simd")]
76    {
77        use crate::simd;
78        use std::any::TypeId;
79        if TypeId::of::<T>() == TypeId::of::<f64>() {
80            // SAFETY: T is f64 confirmed by TypeId.
81            unsafe {
82                simd::f64_ops::axpy_f64(
83                    simd::t_to_f64(alpha),
84                    simd::slice_as_f64(x),
85                    simd::slice_as_f64_mut(y),
86                );
87            }
88            return;
89        }
90        if TypeId::of::<T>() == TypeId::of::<f32>() {
91            // SAFETY: T is f32 confirmed by TypeId.
92            unsafe {
93                simd::f32_ops::axpy_f32(
94                    simd::t_to_f32(alpha),
95                    simd::slice_as_f32(x),
96                    simd::slice_as_f32_mut(y),
97                );
98            }
99            return;
100        }
101    }
102    for (yi, &xi) in y.iter_mut().zip(x.iter()) {
103        *yi += alpha * xi;
104    }
105}
106
107/// Euclidean norm (L2 norm) of a 1-D tensor: `sqrt(sum(x_i^2))`.
108///
109/// ```
110/// # use scivex_core::tensor::Tensor;
111/// # use scivex_core::linalg::nrm2;
112/// let x = Tensor::from_vec(vec![3.0_f64, 4.0], vec![2]).unwrap();
113/// let n = nrm2(&x).unwrap();
114/// assert!((n - 5.0).abs() < 1e-10);
115/// ```
116pub fn nrm2<T: Float>(x: &Tensor<T>) -> Result<T> {
117    check_vector(x, "nrm2")?;
118    #[cfg(feature = "simd")]
119    {
120        use crate::simd;
121        use std::any::TypeId;
122        if TypeId::of::<T>() == TypeId::of::<f64>() {
123            // SAFETY: T is f64 confirmed by TypeId.
124            let result =
125                unsafe { simd::f64_ops::sum_sq_f64(simd::slice_as_f64(x.as_slice())).sqrt() };
126            return Ok(unsafe { simd::f64_to_t(result) });
127        }
128        if TypeId::of::<T>() == TypeId::of::<f32>() {
129            // SAFETY: T is f32 confirmed by TypeId.
130            let result =
131                unsafe { simd::f32_ops::sum_sq_f32(simd::slice_as_f32(x.as_slice())).sqrt() };
132            return Ok(unsafe { simd::f32_to_t(result) });
133        }
134    }
135    let sum_sq = x.as_slice().iter().fold(T::zero(), |acc, &v| acc + v * v);
136    Ok(sum_sq.sqrt())
137}
138
139/// Sum of absolute values (L1 norm) of a 1-D tensor: `sum(|x_i|)`.
140///
141/// ```
142/// # use scivex_core::tensor::Tensor;
143/// # use scivex_core::linalg::asum;
144/// let x = Tensor::from_vec(vec![-1.0_f64, 2.0, -3.0], vec![3]).unwrap();
145/// let s = asum(&x).unwrap();
146/// assert!((s - 6.0).abs() < 1e-10);
147/// ```
148pub fn asum<T: Float>(x: &Tensor<T>) -> Result<T> {
149    check_vector(x, "asum")?;
150    #[cfg(feature = "simd")]
151    {
152        use crate::simd;
153        use std::any::TypeId;
154        if TypeId::of::<T>() == TypeId::of::<f64>() {
155            // SAFETY: T is f64 confirmed by TypeId.
156            let result = unsafe { simd::f64_ops::asum_f64(simd::slice_as_f64(x.as_slice())) };
157            return Ok(unsafe { simd::f64_to_t(result) });
158        }
159        if TypeId::of::<T>() == TypeId::of::<f32>() {
160            // SAFETY: T is f32 confirmed by TypeId.
161            let result = unsafe { simd::f32_ops::asum_f32(simd::slice_as_f32(x.as_slice())) };
162            return Ok(unsafe { simd::f32_to_t(result) });
163        }
164    }
165    let result = x.as_slice().iter().fold(T::zero(), |acc, &v| acc + v.abs());
166    Ok(result)
167}
168
169/// Scale a vector in place: `x = alpha * x`.
170///
171/// ```
172/// # use scivex_core::tensor::Tensor;
173/// # use scivex_core::linalg::scal;
174/// let mut x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
175/// scal(10.0, &mut x).unwrap();
176/// assert_eq!(x.as_slice(), &[10.0, 20.0, 30.0]);
177/// ```
178pub fn scal<T: Scalar>(alpha: T, x: &mut Tensor<T>) -> Result<()> {
179    check_vector(x, "scal")?;
180    #[cfg(feature = "simd")]
181    {
182        use crate::simd;
183        use std::any::TypeId;
184        if TypeId::of::<T>() == TypeId::of::<f64>() {
185            // SAFETY: T is f64 confirmed by TypeId.
186            unsafe {
187                simd::f64_ops::scal_f64(
188                    simd::t_to_f64(alpha),
189                    simd::slice_as_f64_mut(x.as_mut_slice()),
190                );
191            }
192            return Ok(());
193        }
194        if TypeId::of::<T>() == TypeId::of::<f32>() {
195            // SAFETY: T is f32 confirmed by TypeId.
196            unsafe {
197                simd::f32_ops::scal_f32(
198                    simd::t_to_f32(alpha),
199                    simd::slice_as_f32_mut(x.as_mut_slice()),
200                );
201            }
202            return Ok(());
203        }
204    }
205    for v in x.as_mut_slice() {
206        *v *= alpha;
207    }
208    Ok(())
209}
210
211/// Index of the element with the largest absolute value.
212///
213/// Returns `None` for empty tensors.
214///
215/// ```
216/// # use scivex_core::tensor::Tensor;
217/// # use scivex_core::linalg::iamax;
218/// let x = Tensor::from_vec(vec![1.0_f64, -5.0, 3.0], vec![3]).unwrap();
219/// assert_eq!(iamax(&x).unwrap(), Some(1));
220/// ```
221pub fn iamax<T: Float>(x: &Tensor<T>) -> Result<Option<usize>> {
222    check_vector(x, "iamax")?;
223    if x.is_empty() {
224        return Ok(None);
225    }
226    let mut max_idx = 0;
227    let mut max_val = x.as_slice()[0].abs();
228    for (i, &v) in x.as_slice().iter().enumerate().skip(1) {
229        let av = v.abs();
230        if av > max_val {
231            max_val = av;
232            max_idx = i;
233        }
234    }
235    Ok(Some(max_idx))
236}
237
238// ======================================================================
239// BLAS Level 2 — matrix-vector operations, O(n^2)
240// ======================================================================
241
242/// General matrix-vector multiply: `y = alpha * A * x + beta * y`.
243///
244/// - `a` must be 2-D with shape `[m, n]`.
245/// - `x` must be 1-D with length `n`.
246/// - `y` must be 1-D with length `m`.
247///
248/// If `beta` is zero, `y` is overwritten (not read).
249///
250/// ```
251/// # use scivex_core::tensor::Tensor;
252/// # use scivex_core::linalg::gemv;
253/// // A = [[1, 2], [3, 4]], x = [5, 6]
254/// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
255/// let x = Tensor::from_vec(vec![5.0, 6.0], vec![2]).unwrap();
256/// let mut y = Tensor::<f64>::zeros(vec![2]);
257/// gemv(1.0, &a, &x, 0.0, &mut y).unwrap();
258/// assert_eq!(y.as_slice(), &[17.0, 39.0]);
259/// ```
260#[allow(clippy::many_single_char_names)]
261pub fn gemv<T: Scalar>(
262    alpha: T,
263    a: &Tensor<T>,
264    x: &Tensor<T>,
265    beta: T,
266    y: &mut Tensor<T>,
267) -> Result<()> {
268    if a.ndim() != 2 {
269        return Err(CoreError::InvalidArgument {
270            reason: "gemv: `a` must be a 2-D tensor (matrix)",
271        });
272    }
273    if x.ndim() != 1 {
274        return Err(CoreError::InvalidArgument {
275            reason: "gemv: `x` must be a 1-D tensor (vector)",
276        });
277    }
278    if y.ndim() != 1 {
279        return Err(CoreError::InvalidArgument {
280            reason: "gemv: `y` must be a 1-D tensor (vector)",
281        });
282    }
283
284    let m = a.shape()[0];
285    let n = a.shape()[1];
286
287    if x.numel() != n {
288        return Err(CoreError::DimensionMismatch {
289            expected: vec![n],
290            got: x.shape().to_vec(),
291        });
292    }
293    if y.numel() != m {
294        return Err(CoreError::DimensionMismatch {
295            expected: vec![m],
296            got: y.shape().to_vec(),
297        });
298    }
299
300    let a_data = a.as_slice();
301    let x_data = x.as_slice();
302    let y_data = y.as_mut_slice();
303
304    for (i, yi) in y_data.iter_mut().enumerate().take(m) {
305        let row_offset = i * n;
306        let row = &a_data[row_offset..row_offset + n];
307        let sum = dot_slice(row, x_data);
308        *yi = alpha * sum + beta * *yi;
309    }
310
311    Ok(())
312}
313
314// ======================================================================
315// BLAS Level 3 — matrix-matrix operations, O(n^3)
316// ======================================================================
317
318/// General matrix-matrix multiply: `C = alpha * A * B + beta * C`.
319///
320/// - `a` must be 2-D with shape `[m, k]`.
321/// - `b` must be 2-D with shape `[k, n]`.
322/// - `c` must be 2-D with shape `[m, n]`.
323///
324/// If `beta` is zero, `c` is overwritten (not read).
325///
326/// ```
327/// # use scivex_core::tensor::Tensor;
328/// # use scivex_core::linalg::gemm;
329/// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
330/// let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
331/// let mut c = Tensor::<f64>::zeros(vec![2, 2]);
332/// gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
333/// assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
334/// ```
335#[allow(clippy::many_single_char_names, clippy::too_many_lines)]
336pub fn gemm<T: Scalar>(
337    alpha: T,
338    a: &Tensor<T>,
339    b: &Tensor<T>,
340    beta: T,
341    c: &mut Tensor<T>,
342) -> Result<()> {
343    // Tile sizes for blocked GEMM — keep working sets in L1/L2 cache.
344    const MC: usize = 64; // row block size for A/C
345    const KC: usize = 256; // reduction-dimension block size
346    const NC: usize = 256; // column block size for B/C
347
348    if a.ndim() != 2 || b.ndim() != 2 || c.ndim() != 2 {
349        return Err(CoreError::InvalidArgument {
350            reason: "gemm: all arguments must be 2-D tensors (matrices)",
351        });
352    }
353
354    let m = a.shape()[0];
355    let k = a.shape()[1];
356    let n = b.shape()[1];
357
358    if b.shape()[0] != k {
359        return Err(CoreError::DimensionMismatch {
360            expected: vec![k, n],
361            got: b.shape().to_vec(),
362        });
363    }
364    if c.shape()[0] != m || c.shape()[1] != n {
365        return Err(CoreError::DimensionMismatch {
366            expected: vec![m, n],
367            got: c.shape().to_vec(),
368        });
369    }
370
371    let a_data = a.as_slice();
372    let b_data = b.as_slice();
373    let c_data = c.as_mut_slice();
374
375    // Scale C by beta first (or zero it).
376    if beta == T::zero() {
377        for v in c_data.iter_mut() {
378            *v = T::zero();
379        }
380    } else if beta != T::one() {
381        for v in c_data.iter_mut() {
382            *v *= beta;
383        }
384    }
385
386    // Blocked GEMM with cache-aware tiling.
387    // Within each tile, the IKJ loop order is used so that the innermost
388    // j-loop is a contiguous AXPY (auto-vectorized / SIMD-accelerated).
389
390    // Loop over K-dimension blocks
391    for pk in (0..k).step_by(KC) {
392        let kb = KC.min(k - pk);
393
394        // Loop over row blocks of A / C
395        for pi in (0..m).step_by(MC) {
396            let mb = MC.min(m - pi);
397
398            // Loop over column blocks of B / C
399            for pj in (0..n).step_by(NC) {
400                let nb = NC.min(n - pj);
401
402                // On aarch64 with f64, use the NEON 4x4 micro-kernel for the
403                // bulk of the tile, then clean up remainder rows/cols with axpy.
404                #[cfg(all(target_arch = "aarch64", feature = "simd"))]
405                {
406                    use std::any::TypeId;
407                    if TypeId::of::<T>() == TypeId::of::<f64>() {
408                        unsafe {
409                            let a_f64 = a_data.as_ptr().cast::<f64>();
410                            let b_f64 = b_data.as_ptr().cast::<f64>();
411                            let c_f64 = c_data.as_mut_ptr().cast::<f64>();
412                            let alpha_f64 = crate::simd::t_to_f64(alpha);
413
414                            let j4 = nb / 4 * 4;
415
416                            // Process 8-row blocks with 8x4 micro-kernel
417                            let i8 = mb / 8 * 8;
418                            for i in (0..i8).step_by(8) {
419                                for j in (0..j4).step_by(4) {
420                                    let a_off = (pi + i) * k + pk;
421                                    let b_off = pk * n + (pj + j);
422                                    let c_off = (pi + i) * n + (pj + j);
423                                    crate::simd::neon_f64_ops::gemm_8x4_f64_neon(
424                                        a_f64.add(a_off),
425                                        b_f64.add(b_off),
426                                        c_f64.add(c_off),
427                                        alpha_f64,
428                                        kb, k, n, n,
429                                    );
430                                }
431                                // Remainder columns (j4..nb)
432                                if j4 < nb {
433                                    for ii in 0..8 {
434                                        let row_a = (pi + i + ii) * k + pk;
435                                        let row_c = (pi + i + ii) * n + pj + j4;
436                                        for p in 0..kb {
437                                            let scale_f64 = alpha_f64 * *a_f64.add(row_a + p);
438                                            for jj in 0..(nb - j4) {
439                                                let b_idx = (pk + p) * n + pj + j4 + jj;
440                                                *c_f64.add(row_c + jj) += scale_f64 * *b_f64.add(b_idx);
441                                            }
442                                        }
443                                    }
444                                }
445                            }
446                            // Remaining 4-row block with 4x4 micro-kernel
447                            let i4_start = i8;
448                            let i4_end = i4_start + (mb - i8) / 4 * 4;
449                            for i in (i4_start..i4_end).step_by(4) {
450                                for j in (0..j4).step_by(4) {
451                                    let a_off = (pi + i) * k + pk;
452                                    let b_off = pk * n + (pj + j);
453                                    let c_off = (pi + i) * n + (pj + j);
454                                    crate::simd::neon_f64_ops::gemm_4x4_f64_neon(
455                                        a_f64.add(a_off),
456                                        b_f64.add(b_off),
457                                        c_f64.add(c_off),
458                                        alpha_f64,
459                                        kb, k, n, n,
460                                    );
461                                }
462                                if j4 < nb {
463                                    for ii in 0..4 {
464                                        let row_a = (pi + i + ii) * k + pk;
465                                        let row_c = (pi + i + ii) * n + pj + j4;
466                                        for p in 0..kb {
467                                            let scale_f64 = alpha_f64 * *a_f64.add(row_a + p);
468                                            for jj in 0..(nb - j4) {
469                                                let b_idx = (pk + p) * n + pj + j4 + jj;
470                                                *c_f64.add(row_c + jj) += scale_f64 * *b_f64.add(b_idx);
471                                            }
472                                        }
473                                    }
474                                }
475                            }
476                            // Scalar remainder rows
477                            for i in i4_end..mb {
478                                let row_a = (pi + i) * k + pk;
479                                let row_c = (pi + i) * n + pj;
480                                for p in 0..kb {
481                                    let scale = alpha * a_data[row_a + p];
482                                    let b_off2 = (pk + p) * n + pj;
483                                    let b_row = &b_data[b_off2..b_off2 + nb];
484                                    let c_slice = &mut c_data[row_c..row_c + nb];
485                                    axpy_slice(scale, b_row, c_slice);
486                                }
487                            }
488                        }
489                        continue;
490                    }
491                }
492
493                // Generic fallback (non-f64 or non-aarch64)
494                for i in 0..mb {
495                    let row_a = (pi + i) * k + pk;
496                    let row_c = (pi + i) * n + pj;
497                    for p in 0..kb {
498                        let scale = alpha * a_data[row_a + p];
499                        let b_off = (pk + p) * n + pj;
500                        let b_row = &b_data[b_off..b_off + nb];
501                        let c_slice = &mut c_data[row_c..row_c + nb];
502                        axpy_slice(scale, b_row, c_slice);
503                    }
504                }
505            }
506        }
507    }
508
509    Ok(())
510}
511
512// ======================================================================
513// Convenience methods on Tensor
514// ======================================================================
515
516impl<T: Scalar> Tensor<T> {
517    /// Matrix-vector multiply: returns `A @ x` as a new 1-D tensor.
518    ///
519    /// `self` must be 2-D `[m, n]`, `x` must be 1-D `[n]`.
520    ///
521    /// # Examples
522    ///
523    /// ```
524    /// # use scivex_core::tensor::Tensor;
525    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
526    /// let x = Tensor::from_vec(vec![5.0, 6.0], vec![2]).unwrap();
527    /// let y = a.matvec(&x).unwrap();
528    /// assert_eq!(y.as_slice(), &[17.0, 39.0]);
529    /// ```
530    pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
531        let m = self.shape().first().copied().unwrap_or(0);
532        let mut y = Tensor::zeros(vec![m]);
533        gemv(T::one(), self, x, T::zero(), &mut y)?;
534        Ok(y)
535    }
536
537    /// Matrix-matrix multiply: returns `self @ other` as a new 2-D tensor.
538    ///
539    /// `self` must be 2-D `[m, k]`, `other` must be 2-D `[k, n]`.
540    ///
541    /// # Examples
542    ///
543    /// ```
544    /// # use scivex_core::tensor::Tensor;
545    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
546    /// let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
547    /// let c = a.matmul(&b).unwrap();
548    /// assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
549    /// ```
550    pub fn matmul(&self, other: &Tensor<T>) -> Result<Tensor<T>> {
551        let m = self.shape().first().copied().unwrap_or(0);
552        let n = other.shape().get(1).copied().unwrap_or(0);
553        let mut c = Tensor::zeros(vec![m, n]);
554        gemm(T::one(), self, other, T::zero(), &mut c)?;
555        Ok(c)
556    }
557
558    /// Dot product with another 1-D tensor.
559    ///
560    /// # Examples
561    ///
562    /// ```
563    /// # use scivex_core::tensor::Tensor;
564    /// let x = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
565    /// let y = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]).unwrap();
566    /// assert_eq!(x.dot(&y).unwrap(), 32.0);
567    /// ```
568    pub fn dot(&self, other: &Tensor<T>) -> Result<T> {
569        dot(self, other)
570    }
571}
572
573impl<T: Float> Tensor<T> {
574    /// Euclidean (L2) norm of a 1-D tensor.
575    ///
576    /// # Examples
577    ///
578    /// ```
579    /// # use scivex_core::tensor::Tensor;
580    /// let x = Tensor::from_vec(vec![3.0_f64, 4.0], vec![2]).unwrap();
581    /// assert!((x.norm().unwrap() - 5.0).abs() < 1e-10);
582    /// ```
583    pub fn norm(&self) -> Result<T> {
584        nrm2(self)
585    }
586
587    /// Solve the linear system `self * x = b` for a square matrix `self`.
588    ///
589    /// Uses LU decomposition with partial pivoting.
590    ///
591    /// # Examples
592    ///
593    /// ```
594    /// # use scivex_core::tensor::Tensor;
595    /// let a = Tensor::from_vec(vec![2.0_f64, 1.0, 1.0, 4.0], vec![2, 2]).unwrap();
596    /// let b = Tensor::from_vec(vec![5.0_f64, 6.0], vec![2]).unwrap();
597    /// let x = a.solve(&b).unwrap();
598    /// assert!((x.as_slice()[0] - 2.0).abs() < 1e-10);
599    /// assert!((x.as_slice()[1] - 1.0).abs() < 1e-10);
600    /// ```
601    pub fn solve(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
602        crate::linalg::solve(self, b)
603    }
604
605    /// Compute the inverse of a square matrix.
606    ///
607    /// Uses LU decomposition with partial pivoting.
608    ///
609    /// # Examples
610    ///
611    /// ```
612    /// # use scivex_core::tensor::Tensor;
613    /// let a = Tensor::from_vec(vec![2.0_f64, 1.0, 1.0, 4.0], vec![2, 2]).unwrap();
614    /// let inv = a.inv().unwrap();
615    /// let eye = a.matmul(&inv).unwrap();
616    /// assert!((eye.as_slice()[0] - 1.0).abs() < 1e-10);
617    /// ```
618    pub fn inv(&self) -> Result<Tensor<T>> {
619        crate::linalg::inv(self)
620    }
621
622    /// Compute the determinant of a square matrix.
623    ///
624    /// Uses LU decomposition with partial pivoting.
625    ///
626    /// # Examples
627    ///
628    /// ```
629    /// # use scivex_core::tensor::Tensor;
630    /// let a = Tensor::from_vec(vec![2.0_f64, 1.0, 1.0, 4.0], vec![2, 2]).unwrap();
631    /// assert!((a.det().unwrap() - 7.0).abs() < 1e-10);
632    /// ```
633    pub fn det(&self) -> Result<T> {
634        crate::linalg::det(self)
635    }
636
637    /// Solve the least-squares problem `min ||self * x - b||_2`.
638    ///
639    /// Uses QR decomposition with Householder reflections.
640    ///
641    /// # Examples
642    ///
643    /// ```
644    /// # use scivex_core::tensor::Tensor;
645    /// // Overdetermined system: 2x = [2, 4, 6] => x ≈ [1, 2, 3] / something
646    /// let a = Tensor::from_vec(vec![1.0_f64, 1.0, 1.0], vec![3, 1]).unwrap();
647    /// let b = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
648    /// let x = a.lstsq(&b).unwrap();
649    /// assert_eq!(x.shape(), &[1]);
650    /// ```
651    pub fn lstsq(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
652        crate::linalg::lstsq(self, b)
653    }
654}
655
656// ======================================================================
657// Internal helpers
658// ======================================================================
659
660fn check_vector<T: Scalar>(x: &Tensor<T>, name: &'static str) -> Result<()> {
661    if x.ndim() != 1 {
662        return Err(CoreError::InvalidArgument {
663            reason: match name {
664                "nrm2" => "nrm2: expected a 1-D tensor",
665                "asum" => "asum: expected a 1-D tensor",
666                "scal" => "scal: expected a 1-D tensor",
667                "iamax" => "iamax: expected a 1-D tensor",
668                _ => "expected a 1-D tensor",
669            },
670        });
671    }
672    Ok(())
673}
674
675fn check_vectors<T: Scalar>(x: &Tensor<T>, y: &Tensor<T>, name: &'static str) -> Result<()> {
676    if x.ndim() != 1 || y.ndim() != 1 {
677        return Err(CoreError::InvalidArgument {
678            reason: match name {
679                "dot" => "dot: both arguments must be 1-D tensors",
680                "axpy" => "axpy: both arguments must be 1-D tensors",
681                _ => "both arguments must be 1-D tensors",
682            },
683        });
684    }
685    if x.numel() != y.numel() {
686        return Err(CoreError::DimensionMismatch {
687            expected: x.shape().to_vec(),
688            got: y.shape().to_vec(),
689        });
690    }
691    Ok(())
692}
693
694#[cfg(test)]
695#[allow(clippy::float_cmp)]
696mod tests {
697    use super::*;
698
699    // ------------------------------------------------------------------
700    // Helpers
701    // ------------------------------------------------------------------
702
703    fn vec_f64(data: &[f64]) -> Tensor<f64> {
704        Tensor::from_vec(data.to_vec(), vec![data.len()]).unwrap()
705    }
706
707    fn mat_f64(data: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
708        Tensor::from_vec(data.to_vec(), vec![rows, cols]).unwrap()
709    }
710
711    // ------------------------------------------------------------------
712    // BLAS L1
713    // ------------------------------------------------------------------
714
715    #[test]
716    fn test_dot_basic() {
717        let x = vec_f64(&[1.0, 2.0, 3.0]);
718        let y = vec_f64(&[4.0, 5.0, 6.0]);
719        assert_eq!(dot(&x, &y).unwrap(), 32.0);
720    }
721
722    #[test]
723    fn test_dot_single() {
724        let x = vec_f64(&[3.0]);
725        let y = vec_f64(&[7.0]);
726        assert_eq!(dot(&x, &y).unwrap(), 21.0);
727    }
728
729    #[test]
730    fn test_dot_length_mismatch() {
731        let x = vec_f64(&[1.0, 2.0]);
732        let y = vec_f64(&[1.0, 2.0, 3.0]);
733        assert!(dot(&x, &y).is_err());
734    }
735
736    #[test]
737    fn test_dot_not_1d() {
738        let x = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
739        let y = vec_f64(&[1.0, 2.0]);
740        assert!(dot(&x, &y).is_err());
741    }
742
743    #[test]
744    fn test_axpy() {
745        let x = vec_f64(&[1.0, 2.0, 3.0]);
746        let mut y = vec_f64(&[10.0, 20.0, 30.0]);
747        axpy(2.0, &x, &mut y).unwrap();
748        assert_eq!(y.as_slice(), &[12.0, 24.0, 36.0]);
749    }
750
751    #[test]
752    fn test_axpy_zero_alpha() {
753        let x = vec_f64(&[1.0, 2.0, 3.0]);
754        let mut y = vec_f64(&[10.0, 20.0, 30.0]);
755        axpy(0.0, &x, &mut y).unwrap();
756        assert_eq!(y.as_slice(), &[10.0, 20.0, 30.0]);
757    }
758
759    #[test]
760    fn test_nrm2() {
761        let x = vec_f64(&[3.0, 4.0]);
762        assert!((nrm2(&x).unwrap() - 5.0).abs() < 1e-10);
763    }
764
765    #[test]
766    fn test_nrm2_single() {
767        let x = vec_f64(&[-7.0]);
768        assert!((nrm2(&x).unwrap() - 7.0).abs() < 1e-10);
769    }
770
771    #[test]
772    fn test_asum() {
773        let x = vec_f64(&[-1.0, 2.0, -3.0, 4.0]);
774        assert!((asum(&x).unwrap() - 10.0).abs() < 1e-10);
775    }
776
777    #[test]
778    fn test_scal() {
779        let mut x = vec_f64(&[1.0, 2.0, 3.0]);
780        scal(10.0, &mut x).unwrap();
781        assert_eq!(x.as_slice(), &[10.0, 20.0, 30.0]);
782    }
783
784    #[test]
785    fn test_scal_zero() {
786        let mut x = vec_f64(&[1.0, 2.0, 3.0]);
787        scal(0.0, &mut x).unwrap();
788        assert_eq!(x.as_slice(), &[0.0, 0.0, 0.0]);
789    }
790
791    #[test]
792    fn test_iamax() {
793        let x = vec_f64(&[1.0, -5.0, 3.0, -2.0]);
794        assert_eq!(iamax(&x).unwrap(), Some(1));
795    }
796
797    #[test]
798    fn test_iamax_first_is_max() {
799        let x = vec_f64(&[100.0, 1.0, 2.0]);
800        assert_eq!(iamax(&x).unwrap(), Some(0));
801    }
802
803    #[test]
804    fn test_iamax_empty() {
805        let x = Tensor::<f64>::zeros(vec![0]);
806        assert_eq!(iamax(&x).unwrap(), None);
807    }
808
809    // ------------------------------------------------------------------
810    // BLAS L2
811    // ------------------------------------------------------------------
812
813    #[test]
814    fn test_gemv_basic() {
815        // A = [[1, 2], [3, 4]], x = [5, 6]
816        // y = A @ x = [17, 39]
817        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
818        let x = vec_f64(&[5.0, 6.0]);
819        let mut y = Tensor::<f64>::zeros(vec![2]);
820        gemv(1.0, &a, &x, 0.0, &mut y).unwrap();
821        assert_eq!(y.as_slice(), &[17.0, 39.0]);
822    }
823
824    #[test]
825    fn test_gemv_with_alpha_beta() {
826        // y = 2 * A @ x + 3 * y
827        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
828        let x = vec_f64(&[1.0, 1.0]);
829        let mut y = vec_f64(&[10.0, 10.0]);
830        gemv(2.0, &a, &x, 3.0, &mut y).unwrap();
831        // A @ x = [3, 7], 2*[3,7] + 3*[10,10] = [6+30, 14+30] = [36, 44]
832        assert_eq!(y.as_slice(), &[36.0, 44.0]);
833    }
834
835    #[test]
836    fn test_gemv_rectangular() {
837        // A = [[1, 2, 3], [4, 5, 6]]  (2x3)
838        // x = [1, 0, 1]  (3)
839        // y = A @ x = [4, 10]
840        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
841        let x = vec_f64(&[1.0, 0.0, 1.0]);
842        let mut y = Tensor::<f64>::zeros(vec![2]);
843        gemv(1.0, &a, &x, 0.0, &mut y).unwrap();
844        assert_eq!(y.as_slice(), &[4.0, 10.0]);
845    }
846
847    #[test]
848    fn test_gemv_dimension_mismatch() {
849        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
850        let x = vec_f64(&[1.0, 2.0, 3.0]);
851        let mut y = Tensor::<f64>::zeros(vec![2]);
852        assert!(gemv(1.0, &a, &x, 0.0, &mut y).is_err());
853    }
854
855    #[test]
856    fn test_gemv_y_dimension_mismatch() {
857        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
858        let x = vec_f64(&[1.0, 2.0]);
859        let mut y = Tensor::<f64>::zeros(vec![3]);
860        assert!(gemv(1.0, &a, &x, 0.0, &mut y).is_err());
861    }
862
863    // ------------------------------------------------------------------
864    // BLAS L3
865    // ------------------------------------------------------------------
866
867    #[test]
868    fn test_gemm_square() {
869        // A = [[1, 2], [3, 4]]
870        // B = [[5, 6], [7, 8]]
871        // C = A @ B = [[19, 22], [43, 50]]
872        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
873        let b = mat_f64(&[5.0, 6.0, 7.0, 8.0], 2, 2);
874        let mut c = Tensor::<f64>::zeros(vec![2, 2]);
875        gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
876        assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
877    }
878
879    #[test]
880    fn test_gemm_rectangular() {
881        // A (2x3) @ B (3x2) = C (2x2)
882        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
883        let b = mat_f64(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
884        let mut c = Tensor::<f64>::zeros(vec![2, 2]);
885        gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
886        // Row 0: 1*7+2*9+3*11 = 7+18+33 = 58, 1*8+2*10+3*12 = 8+20+36 = 64
887        // Row 1: 4*7+5*9+6*11 = 28+45+66 = 139, 4*8+5*10+6*12 = 32+50+72 = 154
888        assert_eq!(c.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
889    }
890
891    #[test]
892    fn test_gemm_with_alpha_beta() {
893        // C = 2 * A @ B + 3 * C
894        let a = mat_f64(&[1.0, 0.0, 0.0, 1.0], 2, 2); // identity
895        let b = mat_f64(&[5.0, 6.0, 7.0, 8.0], 2, 2);
896        let mut c = mat_f64(&[1.0, 1.0, 1.0, 1.0], 2, 2);
897        gemm(2.0, &a, &b, 3.0, &mut c).unwrap();
898        // 2*B + 3*ones = [10+3, 12+3, 14+3, 16+3] = [13, 15, 17, 19]
899        assert_eq!(c.as_slice(), &[13.0, 15.0, 17.0, 19.0]);
900    }
901
902    #[test]
903    fn test_gemm_identity() {
904        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 3, 3);
905        let eye = Tensor::<f64>::eye(3);
906        let mut c = Tensor::<f64>::zeros(vec![3, 3]);
907        gemm(1.0, &a, &eye, 0.0, &mut c).unwrap();
908        assert_eq!(c.as_slice(), a.as_slice());
909    }
910
911    #[test]
912    fn test_gemm_single_element() {
913        let a = mat_f64(&[3.0], 1, 1);
914        let b = mat_f64(&[7.0], 1, 1);
915        let mut c = Tensor::<f64>::zeros(vec![1, 1]);
916        gemm(1.0, &a, &b, 0.0, &mut c).unwrap();
917        assert_eq!(c.as_slice(), &[21.0]);
918    }
919
920    #[test]
921    fn test_gemm_dimension_mismatch() {
922        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
923        let b = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
924        let mut c = Tensor::<f64>::zeros(vec![2, 2]);
925        assert!(gemm(1.0, &a, &b, 0.0, &mut c).is_err());
926    }
927
928    #[test]
929    fn test_gemm_c_shape_mismatch() {
930        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
931        let b = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
932        let mut c = Tensor::<f64>::zeros(vec![3, 3]);
933        assert!(gemm(1.0, &a, &b, 0.0, &mut c).is_err());
934    }
935
936    // ------------------------------------------------------------------
937    // Convenience methods
938    // ------------------------------------------------------------------
939
940    #[test]
941    fn test_matvec() {
942        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
943        let x = vec_f64(&[5.0, 6.0]);
944        let y = a.matvec(&x).unwrap();
945        assert_eq!(y.as_slice(), &[17.0, 39.0]);
946    }
947
948    #[test]
949    fn test_matmul() {
950        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0], 2, 2);
951        let b = mat_f64(&[5.0, 6.0, 7.0, 8.0], 2, 2);
952        let c = a.matmul(&b).unwrap();
953        assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
954    }
955
956    #[test]
957    fn test_tensor_dot() {
958        let x = vec_f64(&[1.0, 2.0, 3.0]);
959        let y = vec_f64(&[4.0, 5.0, 6.0]);
960        assert_eq!(x.dot(&y).unwrap(), 32.0);
961    }
962
963    #[test]
964    fn test_tensor_norm() {
965        let x = vec_f64(&[3.0, 4.0]);
966        assert!((x.norm().unwrap() - 5.0).abs() < 1e-10);
967    }
968
969    // ------------------------------------------------------------------
970    // NumPy reference values
971    // ------------------------------------------------------------------
972
973    #[test]
974    fn test_gemm_numpy_reference() {
975        // >>> import numpy as np
976        // >>> a = np.array([[1,2,3],[4,5,6]], dtype=np.float64)
977        // >>> b = np.array([[7,8],[9,10],[11,12]], dtype=np.float64)
978        // >>> a @ b
979        // array([[ 58.,  64.],
980        //        [139., 154.]])
981        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
982        let b = mat_f64(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
983        let c = a.matmul(&b).unwrap();
984        assert_eq!(c.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
985    }
986
987    #[test]
988    fn test_gemv_numpy_reference() {
989        // >>> a = np.array([[1,2,3],[4,5,6]], dtype=np.float64)
990        // >>> x = np.array([1,1,1], dtype=np.float64)
991        // >>> a @ x
992        // array([ 6., 15.])
993        let a = mat_f64(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
994        let x = vec_f64(&[1.0, 1.0, 1.0]);
995        let y = a.matvec(&x).unwrap();
996        assert_eq!(y.as_slice(), &[6.0, 15.0]);
997    }
998
999    #[test]
1000    fn test_dot_numpy_reference() {
1001        // >>> np.dot([1,2,3,4,5], [5,4,3,2,1])
1002        // 35
1003        let x = vec_f64(&[1.0, 2.0, 3.0, 4.0, 5.0]);
1004        let y = vec_f64(&[5.0, 4.0, 3.0, 2.0, 1.0]);
1005        assert_eq!(dot(&x, &y).unwrap(), 35.0);
1006    }
1007
1008    #[test]
1009    fn test_nrm2_numpy_reference() {
1010        // >>> np.linalg.norm([1, 2, 3, 4, 5])
1011        // 7.416198487095663
1012        let x = vec_f64(&[1.0, 2.0, 3.0, 4.0, 5.0]);
1013        let n = nrm2(&x).unwrap();
1014        assert!((n - 7.416_198_487_095_663).abs() < 1e-12);
1015    }
1016}