batch_matvec

Function batch_matvec 

Source
pub fn batch_matvec<F>(
    batch_a: &ArrayView3<'_, F>,
    x: &ArrayView1<'_, F>,
) -> LinalgResult<Array2<F>>
where F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
Expand description

Batch matrix-vector multiplication

Multiplies every matrix in a batch by the same vector `x`, producing one result vector per matrix.

§Arguments

  • batch_a - 3D array of shape (batch_size, m, n) holding the batch of m × n matrices
  • x - Vector of length n (the column count of each matrix), multiplied by every matrix in the batch

§Returns

  • 2D array of shape (batch_size, m) whose i-th row is the product of the i-th matrix of batch_a with x

§Examples

use scirs2_core::ndarray::{array, Array3};
use scirs2_linalg::batch::batch_matvec;

// Create a batch of 2 matrices, each 2x2
let batch_a = Array3::from_shape_vec((2, 2, 2), vec![
    1.0, 2.0,  // First matrix: [[1.0, 2.0],
    3.0, 4.0,  //               [3.0, 4.0]]
    5.0, 6.0,  // Second matrix: [[5.0, 6.0],
    7.0, 8.0   //                [7.0, 8.0]]
]).unwrap();

// Create a vector to multiply with each batch matrix
let x = array![10.0, 20.0];

// Perform batch matrix-vector multiplication
let result = batch_matvec(&batch_a.view(), &x.view()).unwrap();

// Expected results:
// First result: [[1.0, 2.0], [3.0, 4.0]] × [10.0, 20.0] = [50.0, 110.0]
// Second result: [[5.0, 6.0], [7.0, 8.0]] × [10.0, 20.0] = [170.0, 230.0]
assert_eq!(result.shape(), &[2, 2]);
assert_eq!(result[[0, 0]], 50.0);
assert_eq!(result[[0, 1]], 110.0);
assert_eq!(result[[1, 0]], 170.0);
assert_eq!(result[[1, 1]], 230.0);