Skip to main content

scirs2_core/
ops.rs

1//! # Ergonomic Matrix Operations
2//!
3//! This module provides shorthand functions for the most common matrix and
4//! vector operations in scientific computing, making it easy to write concise
5//! numerical code without importing many different traits.
6//!
//! All functions are generic over the element type `T` and spell out exactly
//! the trait bounds they need, so they work with floating-point and integer
//! element types alike and IDEs can provide accurate type hints.
9//!
10//! ## Operations
11//!
12//! | Function | Description |
13//! |----------|-------------|
14//! | [`dot`] | Matrix–matrix multiply (`C = A · B`) |
15//! | [`outer`] | Outer (tensor) product of two 1D arrays |
16//! | [`kron`] | Kronecker product of two 2D arrays |
17//! | [`vstack`] | Vertical concatenation of 2D arrays (add rows) |
18//! | [`hstack`] | Horizontal concatenation of 2D arrays (add columns) |
19//! | [`block_diag`] | Build a block-diagonal matrix from a sequence of 2D blocks |
20//!
21//! ## Design Notes
22//!
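//! - **No unwrap** — every fallible function returns [`CoreResult`]. Shape
//!   preconditions that cannot be reported through the return type (the inner
//!   dimensions of [`dot`]) are documented under that function's `# Panics`.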
24//! - **Generic** — all functions are parameterised over numeric type `T`.
25//! - The implementations deliberately avoid pulling in full BLAS; they use pure
26//!   ndarray loops. For production workloads that require maximum matrix-multiply
27//!   throughput, use `scirs2-linalg` which delegates to OxiBLAS.
28
29use crate::error::{CoreError, CoreResult, ErrorContext};
30use ::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31use num_traits::Zero;
32use std::ops::{Add, Mul};
33
34// ============================================================================
35// dot — Matrix–Matrix Multiplication
36// ============================================================================
37
38/// Compute the matrix product `C = A · B`.
39///
40/// This is a convenient shorthand for the ndarray `.dot()` method, useful when
41/// the full method syntax is overly verbose or when chaining with other ops
42/// from this module.
43///
44/// # Type Parameters
45///
46/// `T` must support addition (with identity `Zero`) and multiplication.
47///
48/// # Arguments
49///
50/// * `a` — Left operand, shape `(m, k)`
51/// * `b` — Right operand, shape `(k, n)`
52///
53/// # Returns
54///
55/// A new `Array2<T>` of shape `(m, n)`.
56///
57/// # Panics
58///
59/// The inner dimensions of `a` and `b` must match (`a.ncols() == b.nrows()`);
60/// ndarray will panic if they do not. This matches the standard contract for
61/// `.dot()`.
62///
63/// # Examples
64///
65/// ```rust
66/// use scirs2_core::ops::dot;
67/// use ndarray::array;
68///
69/// let a = array![[1.0_f64, 0.0], [0.0, 1.0]]; // Identity
70/// let b = array![[3.0_f64, 4.0], [5.0, 6.0]];
71/// let c = dot(&a.view(), &b.view());
72/// assert_eq!(c, b);
73/// ```
74pub fn dot<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> Array2<T>
75where
76    T: Clone + Zero + Add<Output = T> + Mul<Output = T>,
77{
78    let (m, k) = (a.nrows(), a.ncols());
79    let (k2, n) = (b.nrows(), b.ncols());
80    debug_assert_eq!(k, k2, "dot: inner dimensions must match");
81
82    let mut result = Array2::<T>::zeros((m, n));
83    for i in 0..m {
84        for j in 0..n {
85            let mut sum = T::zero();
86            for l in 0..k {
87                sum = sum + a[[i, l]].clone() * b[[l, j]].clone();
88            }
89            result[[i, j]] = sum;
90        }
91    }
92    result
93}
94
95// ============================================================================
96// outer — Outer Product
97// ============================================================================
98
99/// Compute the outer product of two 1D arrays.
100///
101/// Given vectors `u` of length `m` and `v` of length `n`, returns an `m × n`
102/// matrix `M` where `M[i, j] = u[i] * v[j]`.
103///
104/// # Examples
105///
106/// ```rust
107/// use scirs2_core::ops::outer;
108/// use ndarray::array;
109///
110/// let u = array![1.0_f64, 2.0, 3.0];
111/// let v = array![4.0_f64, 5.0];
112/// let m = outer(&u.view(), &v.view());
113/// assert_eq!(m.shape(), &[3, 2]);
114/// assert_eq!(m[[0, 0]], 4.0);
115/// assert_eq!(m[[2, 1]], 15.0);
116/// ```
117pub fn outer<T>(u: &ArrayView1<T>, v: &ArrayView1<T>) -> Array2<T>
118where
119    T: Clone + Zero + Mul<Output = T>,
120{
121    let m = u.len();
122    let n = v.len();
123    Array2::from_shape_fn((m, n), |(i, j)| u[i].clone() * v[j].clone())
124}
125
126// ============================================================================
127// kron — Kronecker Product
128// ============================================================================
129
130/// Compute the Kronecker product of two 2D arrays.
131///
132/// If `A` has shape `(p, q)` and `B` has shape `(r, s)`, the result has shape
133/// `(p·r, q·s)` where block `(i, j)` equals `A[i,j] * B`.
134///
135/// # Examples
136///
137/// ```rust
138/// use scirs2_core::ops::kron;
139/// use ndarray::array;
140///
141/// let a = array![[1_i32, 0], [0, 1]]; // 2×2 identity
142/// let b = array![[1_i32, 2], [3, 4]];
143/// let k = kron(&a.view(), &b.view());
144/// assert_eq!(k.shape(), &[4, 4]);
145/// // Top-left block should be `1 * b`
146/// assert_eq!(k[[0, 0]], 1);
147/// assert_eq!(k[[0, 1]], 2);
148/// ```
149pub fn kron<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> Array2<T>
150where
151    T: Clone + Zero + Mul<Output = T>,
152{
153    let (p, q) = (a.nrows(), a.ncols());
154    let (r, s) = (b.nrows(), b.ncols());
155
156    Array2::from_shape_fn((p * r, q * s), |(i, j)| {
157        let ai = i / r;
158        let bi = i % r;
159        let aj = j / s;
160        let bj = j % s;
161        a[[ai, aj]].clone() * b[[bi, bj]].clone()
162    })
163}
164
165// ============================================================================
166// vstack — Vertical Stack
167// ============================================================================
168
169/// Stack a sequence of 2D arrays vertically (concatenate rows).
170///
171/// All arrays must have the same number of columns. Returns an error if the
172/// slice is empty or if any column-count mismatches exist.
173///
174/// # Examples
175///
176/// ```rust
177/// use scirs2_core::ops::vstack;
178/// use ndarray::array;
179///
180/// let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
181/// let b = array![[5.0_f64, 6.0]];
182/// let s = vstack(&[a.view(), b.view()]).expect("same column count");
183/// assert_eq!(s.shape(), &[3, 2]);
184/// assert_eq!(s[[2, 0]], 5.0);
185/// ```
186pub fn vstack<T>(arrays: &[ArrayView2<T>]) -> CoreResult<Array2<T>>
187where
188    T: Clone + Zero,
189{
190    if arrays.is_empty() {
191        return Err(CoreError::InvalidInput(ErrorContext::new(
192            "vstack: cannot stack an empty slice of arrays",
193        )));
194    }
195
196    let ncols = arrays[0].ncols();
197    for (idx, arr) in arrays.iter().enumerate().skip(1) {
198        if arr.ncols() != ncols {
199            return Err(CoreError::InvalidInput(ErrorContext::new(format!(
200                "vstack: array at index {idx} has {cols} columns, expected {ncols}",
201                cols = arr.ncols()
202            ))));
203        }
204    }
205
206    let total_rows: usize = arrays.iter().map(|a| a.nrows()).sum();
207    let mut result = Array2::<T>::zeros((total_rows, ncols));
208
209    let mut row_offset = 0;
210    for arr in arrays {
211        let nrows = arr.nrows();
212        for r in 0..nrows {
213            for c in 0..ncols {
214                result[[row_offset + r, c]] = arr[[r, c]].clone();
215            }
216        }
217        row_offset += nrows;
218    }
219
220    Ok(result)
221}
222
223// ============================================================================
224// hstack — Horizontal Stack
225// ============================================================================
226
227/// Stack a sequence of 2D arrays horizontally (concatenate columns).
228///
229/// All arrays must have the same number of rows. Returns an error if the
230/// slice is empty or if any row-count mismatches exist.
231///
232/// # Examples
233///
234/// ```rust
235/// use scirs2_core::ops::hstack;
236/// use ndarray::array;
237///
238/// let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
239/// let b = array![[5.0_f64], [6.0]];
240/// let s = hstack(&[a.view(), b.view()]).expect("same row count");
241/// assert_eq!(s.shape(), &[2, 3]);
242/// assert_eq!(s[[0, 2]], 5.0);
243/// assert_eq!(s[[1, 2]], 6.0);
244/// ```
245pub fn hstack<T>(arrays: &[ArrayView2<T>]) -> CoreResult<Array2<T>>
246where
247    T: Clone + Zero,
248{
249    if arrays.is_empty() {
250        return Err(CoreError::InvalidInput(ErrorContext::new(
251            "hstack: cannot stack an empty slice of arrays",
252        )));
253    }
254
255    let nrows = arrays[0].nrows();
256    for (idx, arr) in arrays.iter().enumerate().skip(1) {
257        if arr.nrows() != nrows {
258            return Err(CoreError::InvalidInput(ErrorContext::new(format!(
259                "hstack: array at index {idx} has {r} rows, expected {nrows}",
260                r = arr.nrows()
261            ))));
262        }
263    }
264
265    let total_cols: usize = arrays.iter().map(|a| a.ncols()).sum();
266    let mut result = Array2::<T>::zeros((nrows, total_cols));
267
268    let mut col_offset = 0;
269    for arr in arrays {
270        let ncols = arr.ncols();
271        for r in 0..nrows {
272            for c in 0..ncols {
273                result[[r, col_offset + c]] = arr[[r, c]].clone();
274            }
275        }
276        col_offset += ncols;
277    }
278
279    Ok(result)
280}
281
282// ============================================================================
283// block_diag — Block Diagonal Matrix
284// ============================================================================
285
286/// Build a block-diagonal matrix from a sequence of 2D blocks.
287///
288/// Given blocks `B₀`, `B₁`, …, `Bₙ` with shapes `(r₀, c₀)`, `(r₁, c₁)`, …,
289/// the result is a `(Σrᵢ) × (Σcᵢ)` matrix with each block placed on the
290/// diagonal and zeros elsewhere.
291///
292/// Returns an empty `0×0` matrix when given an empty slice.
293///
294/// # Examples
295///
296/// ```rust
297/// use scirs2_core::ops::block_diag;
298/// use ndarray::array;
299///
300/// let a = array![[1_i32, 2], [3, 4]];
301/// let b = array![[5_i32]];
302/// let bd = block_diag(&[a.view(), b.view()]);
303/// assert_eq!(bd.shape(), &[3, 3]);
304/// assert_eq!(bd[[0, 0]], 1);
305/// assert_eq!(bd[[2, 2]], 5);
306/// assert_eq!(bd[[0, 2]], 0); // off-block element
307/// ```
308pub fn block_diag<T>(blocks: &[ArrayView2<T>]) -> Array2<T>
309where
310    T: Clone + Zero,
311{
312    if blocks.is_empty() {
313        return Array2::<T>::zeros((0, 0));
314    }
315
316    let total_rows: usize = blocks.iter().map(|b| b.nrows()).sum();
317    let total_cols: usize = blocks.iter().map(|b| b.ncols()).sum();
318
319    let mut result = Array2::<T>::zeros((total_rows, total_cols));
320
321    let mut row_off = 0;
322    let mut col_off = 0;
323    for block in blocks {
324        let (br, bc) = (block.nrows(), block.ncols());
325        for r in 0..br {
326            for c in 0..bc {
327                result[[row_off + r, col_off + c]] = block[[r, c]].clone();
328            }
329        }
330        row_off += br;
331        col_off += bc;
332    }
333
334    result
335}
336
337// ============================================================================
338// Tests
339// ============================================================================
340
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;
    use approx::assert_abs_diff_eq;

    // --- dot ---

    #[test]
    fn test_dot_identity() {
        let eye = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[3.0_f64, 4.0], [5.0, 6.0]];
        let c = dot(&eye.view(), &b.view());
        assert_abs_diff_eq!(c[[0, 0]], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 6.0, epsilon = 1e-12);
    }

    #[test]
    fn test_dot_rectangular() {
        // (2×3) · (3×2) → (2×2)
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0_f64, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = dot(&a.view(), &b.view());
        assert_eq!(c.shape(), &[2, 2]);
        // Row 0: [1·7+2·9+3·11, 1·8+2·10+3·12] = [58, 64]
        assert_abs_diff_eq!(c[[0, 0]], 58.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[0, 1]], 64.0, epsilon = 1e-12);
        // Row 1: [4·7+5·9+6·11, 4·8+5·10+6·12] = [139, 154]
        assert_abs_diff_eq!(c[[1, 0]], 139.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 154.0, epsilon = 1e-12);
    }

    #[test]
    fn test_dot_integers() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[5_i32, 6], [7, 8]];
        let c = dot(&a.view(), &b.view());
        assert_eq!(c[[0, 0]], 19); // 1·5+2·7
        assert_eq!(c[[1, 1]], 50); // 3·6+4·8
    }

    #[test]
    fn test_dot_empty_inner_dimension() {
        // (2×0) · (0×3) → (2×3) of zeros
        let a = Array2::<i32>::zeros((2, 0));
        let b = Array2::<i32>::zeros((0, 3));
        let c = dot(&a.view(), &b.view());
        assert_eq!(c.shape(), &[2, 3]);
        assert!(c.iter().all(|&x| x == 0));
    }

    #[test]
    #[should_panic]
    fn test_dot_inner_dimension_mismatch_panics() {
        // `a` has 3 columns but `b` only has 2 rows.
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let _ = dot(&a.view(), &b.view());
    }

    // --- outer ---

    #[test]
    fn test_outer_basic() {
        let u = array![1.0_f64, 2.0, 3.0];
        let v = array![4.0_f64, 5.0];
        let m = outer(&u.view(), &v.view());
        assert_eq!(m.shape(), &[3, 2]);
        assert_abs_diff_eq!(m[[0, 0]], 4.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[1, 1]], 10.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[2, 0]], 12.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[2, 1]], 15.0, epsilon = 1e-12);
    }

    #[test]
    fn test_outer_integers() {
        let u = array![1_i32, 2];
        let v = array![3_i32, 4, 5];
        let m = outer(&u.view(), &v.view());
        assert_eq!(m.shape(), &[2, 3]);
        assert_eq!(m[[0, 0]], 3);
        assert_eq!(m[[1, 2]], 10);
    }

    #[test]
    fn test_outer_empty_operand() {
        let u = Array1::<f64>::zeros(0);
        let v = array![1.0_f64, 2.0];
        let m = outer(&u.view(), &v.view());
        assert_eq!(m.shape(), &[0, 2]);
    }

    // --- kron ---

    #[test]
    fn test_kron_identity_identity() {
        let eye2 = array![[1_i32, 0], [0, 1]];
        let eye3 = array![[1_i32, 0, 0], [0, 1, 0], [0, 0, 1]];
        let k = kron(&eye2.view(), &eye3.view());
        assert_eq!(k.shape(), &[6, 6]);
        // Should also be the 6×6 identity
        for i in 0..6 {
            for j in 0..6 {
                assert_eq!(k[[i, j]], if i == j { 1 } else { 0 });
            }
        }
    }

    #[test]
    fn test_kron_scalar() {
        let two = array![[2_i32]];
        let b = array![[1_i32, 2], [3, 4]];
        let k = kron(&two.view(), &b.view());
        assert_eq!(k.shape(), &[2, 2]);
        assert_eq!(k[[0, 0]], 2);
        assert_eq!(k[[1, 1]], 8);
    }

    #[test]
    fn test_kron_matches_expected() {
        // From NumPy docs:
        // A = [[1, 2], [3, 4]]
        // B = [[0, 5], [6, 7]]
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[0_i32, 5], [6, 7]];
        let k = kron(&a.view(), &b.view());
        assert_eq!(k.shape(), &[4, 4]);
        // Row 0: 1*[0,5] ++ 2*[0,5] = [0,5,0,10]
        assert_eq!(k[[0, 0]], 0);
        assert_eq!(k[[0, 1]], 5);
        assert_eq!(k[[0, 2]], 0);
        assert_eq!(k[[0, 3]], 10);
        // Row 3: 3*[6,7] ++ 4*[6,7] = [18,21,24,28]
        assert_eq!(k[[3, 0]], 18);
        assert_eq!(k[[3, 1]], 21);
        assert_eq!(k[[3, 2]], 24);
        assert_eq!(k[[3, 3]], 28);
    }

    #[test]
    fn test_kron_empty_operand() {
        // A zero-row right operand collapses the row axis of the result.
        let a = array![[1_i32, 2], [3, 4]];
        let b = Array2::<i32>::zeros((0, 3));
        let k = kron(&a.view(), &b.view());
        assert_eq!(k.shape(), &[0, 6]);
    }

    // --- vstack ---

    #[test]
    fn test_vstack_basic() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0_f64, 6.0]];
        let s = vstack(&[a.view(), b.view()]).expect("same cols");
        assert_eq!(s.shape(), &[3, 2]);
        assert_abs_diff_eq!(s[[2, 0]], 5.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[2, 1]], 6.0, epsilon = 1e-12);
    }

    #[test]
    fn test_vstack_three_arrays() {
        let a = array![[1_i32, 2]];
        let b = array![[3_i32, 4]];
        let c = array![[5_i32, 6], [7, 8]];
        let s = vstack(&[a.view(), b.view(), c.view()]).expect("same cols");
        assert_eq!(s.shape(), &[4, 2]);
        assert_eq!(s[[0, 0]], 1);
        assert_eq!(s[[1, 1]], 4);
        assert_eq!(s[[2, 0]], 5);
        assert_eq!(s[[3, 1]], 8);
    }

    #[test]
    fn test_vstack_single_array_is_copy() {
        let a = array![[1_i32, 2], [3, 4]];
        let s = vstack(&[a.view()]).expect("single array");
        assert_eq!(s, a);
    }

    #[test]
    fn test_vstack_zero_width() {
        let a = Array2::<f64>::zeros((2, 0));
        let b = Array2::<f64>::zeros((1, 0));
        let s = vstack(&[a.view(), b.view()]).expect("same (zero) cols");
        assert_eq!(s.shape(), &[3, 0]);
    }

    #[test]
    fn test_vstack_mismatch_error() {
        let a = array![[1.0_f64, 2.0, 3.0]]; // 3 cols
        let b = array![[4.0_f64, 5.0]]; // 2 cols
        assert!(vstack(&[a.view(), b.view()]).is_err());
    }

    #[test]
    fn test_vstack_empty_error() {
        let empty: &[ArrayView2<f64>] = &[];
        assert!(vstack(empty).is_err());
    }

    // --- hstack ---

    #[test]
    fn test_hstack_basic() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0_f64], [6.0]];
        let s = hstack(&[a.view(), b.view()]).expect("same rows");
        assert_eq!(s.shape(), &[2, 3]);
        assert_abs_diff_eq!(s[[0, 2]], 5.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[1, 2]], 6.0, epsilon = 1e-12);
    }

    #[test]
    fn test_hstack_three_arrays() {
        let a = array![[1_i32], [2]];
        let b = array![[3_i32], [4]];
        let c = array![[5_i32, 6], [7, 8]];
        let s = hstack(&[a.view(), b.view(), c.view()]).expect("same rows");
        assert_eq!(s.shape(), &[2, 4]);
        assert_eq!(s[[0, 0]], 1);
        assert_eq!(s[[0, 1]], 3);
        assert_eq!(s[[1, 3]], 8);
    }

    #[test]
    fn test_hstack_zero_height() {
        let a = Array2::<f64>::zeros((0, 2));
        let b = Array2::<f64>::zeros((0, 3));
        let s = hstack(&[a.view(), b.view()]).expect("same (zero) rows");
        assert_eq!(s.shape(), &[0, 5]);
    }

    #[test]
    fn test_hstack_mismatch_error() {
        let a = array![[1.0_f64], [2.0], [3.0]]; // 3 rows
        let b = array![[4.0_f64], [5.0]]; // 2 rows
        assert!(hstack(&[a.view(), b.view()]).is_err());
    }

    #[test]
    fn test_hstack_empty_error() {
        let empty: &[ArrayView2<f64>] = &[];
        assert!(hstack(empty).is_err());
    }

    // --- block_diag ---

    #[test]
    fn test_block_diag_square_blocks() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[5_i32, 6], [7, 8]];
        let bd = block_diag(&[a.view(), b.view()]);
        assert_eq!(bd.shape(), &[4, 4]);
        assert_eq!(bd[[0, 0]], 1);
        assert_eq!(bd[[1, 1]], 4);
        assert_eq!(bd[[2, 2]], 5);
        assert_eq!(bd[[3, 3]], 8);
        // Off-diagonal blocks should be zero
        assert_eq!(bd[[0, 2]], 0);
        assert_eq!(bd[[3, 0]], 0);
    }

    #[test]
    fn test_block_diag_rectangular_blocks() {
        let a = array![[1_i32, 2, 3]]; // 1×3
        let b = array![[4_i32], [5]]; // 2×1
        let bd = block_diag(&[a.view(), b.view()]);
        assert_eq!(bd.shape(), &[3, 4]);
        // a block at rows 0, cols 0..3
        assert_eq!(bd[[0, 2]], 3);
        // b block at rows 1..3, col 3
        assert_eq!(bd[[1, 3]], 4);
        assert_eq!(bd[[2, 3]], 5);
        // zeros
        assert_eq!(bd[[1, 0]], 0);
    }

    #[test]
    fn test_block_diag_single() {
        let a = array![[9_i32]];
        let bd = block_diag(&[a.view()]);
        assert_eq!(bd.shape(), &[1, 1]);
        assert_eq!(bd[[0, 0]], 9);
    }

    #[test]
    fn test_block_diag_empty() {
        let empty: &[ArrayView2<i32>] = &[];
        let bd = block_diag(empty);
        assert_eq!(bd.shape(), &[0, 0]);
    }

    #[test]
    fn test_block_diag_zero_sized_block() {
        // A 0×2 block contributes two (all-zero) columns but no rows, so the
        // following block starts at column 2.
        let no_rows = Array2::<i32>::zeros((0, 2));
        let b = array![[7_i32]];
        let bd = block_diag(&[no_rows.view(), b.view()]);
        assert_eq!(bd.shape(), &[1, 3]);
        assert_eq!(bd[[0, 0]], 0);
        assert_eq!(bd[[0, 1]], 0);
        assert_eq!(bd[[0, 2]], 7);
    }

    #[test]
    fn test_block_diag_three_blocks() {
        let a = array![[1_i32, 2], [3, 4]];
        let b = array![[5_i32]];
        let c = array![[6_i32, 7, 8]];
        let bd = block_diag(&[a.view(), b.view(), c.view()]);
        assert_eq!(bd.shape(), &[4, 6]);
        // Check corners of each block
        assert_eq!(bd[[0, 0]], 1);
        assert_eq!(bd[[1, 1]], 4);
        assert_eq!(bd[[2, 2]], 5);
        assert_eq!(bd[[3, 3]], 6);
        assert_eq!(bd[[3, 5]], 8);
        // Verify zeros outside blocks
        assert_eq!(bd[[0, 3]], 0);
        assert_eq!(bd[[3, 0]], 0);
    }
}