use indexing::SpIndex;
use ndarray::{
self, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, ShapeBuilder,
};
use num_traits::Num;
use sparse::compressed::SpMatView;
use sparse::csmat::CompressedStorage;
use sparse::prelude::*;
use sparse::vec::NnzEither::{Both, Left, Right};
use sparse::vec::SparseIterTools;
use Ix2;
use SpRes;
/// Element-wise sum of two sparse matrices with identical shape and storage order.
///
/// Entries whose sum is exactly zero are not stored in the result.
///
/// # Panics
///
/// Panics with `"Dimension mismatch"` or `"Storage mismatch"` when the
/// operands' shapes or storage orders differ (checked by `csmat_binop`).
pub fn add_mat_same_storage<N, I, Mat1, Mat2>(lhs: &Mat1, rhs: &Mat2) -> CsMatI<N, I>
where
    N: Num + Copy,
    I: SpIndex,
    Mat1: SpMatView<N, I>,
    Mat2: SpMatView<N, I>,
{
    let (left, right) = (lhs.view(), rhs.view());
    csmat_binop(left, right, |a: &N, b: &N| *a + *b)
}
/// Element-wise difference `lhs - rhs` of two sparse matrices with identical
/// shape and storage order.
///
/// Entries whose difference is exactly zero are not stored in the result.
///
/// # Panics
///
/// Panics with `"Dimension mismatch"` or `"Storage mismatch"` when the
/// operands' shapes or storage orders differ (checked by `csmat_binop`).
pub fn sub_mat_same_storage<N, I, Mat1, Mat2>(lhs: &Mat1, rhs: &Mat2) -> CsMatI<N, I>
where
    N: Num + Copy,
    I: SpIndex,
    Mat1: SpMatView<N, I>,
    Mat2: SpMatView<N, I>,
{
    let difference = |a: &N, b: &N| *a - *b;
    csmat_binop(lhs.view(), rhs.view(), difference)
}
/// Element-wise (Hadamard) product of two sparse matrices with identical
/// shape and storage order.
///
/// A position stored in only one operand is evaluated as `x * 0`; results
/// equal to zero are not stored, so typically only the overlap survives.
///
/// # Panics
///
/// Panics with `"Dimension mismatch"` or `"Storage mismatch"` when the
/// operands' shapes or storage orders differ (checked by `csmat_binop`).
pub fn mul_mat_same_storage<N, I, Mat1, Mat2>(lhs: &Mat1, rhs: &Mat2) -> CsMatI<N, I>
where
    N: Num + Copy,
    I: SpIndex,
    Mat1: SpMatView<N, I>,
    Mat2: SpMatView<N, I>,
{
    let product = |a: &N, b: &N| *a * *b;
    let (left, right) = (lhs.view(), rhs.view());
    csmat_binop(left, right, product)
}
/// Multiply every stored entry of `mat` by the scalar `val`, returning a new
/// owned matrix with the same sparsity structure.
pub fn scalar_mul_mat<N, I, Mat>(mat: &Mat, val: N) -> CsMatI<N, I>
where
    N: Num + Copy,
    I: SpIndex,
    Mat: SpMatView<N, I>,
{
    let scale = |x: &N| *x * val;
    mat.view().map(scale)
}
/// Apply `binop` element-wise to two sparse matrices, returning a new matrix.
///
/// `binop` is evaluated only at positions stored in at least one operand.
/// A value missing from one side is passed as zero. Results equal to zero are
/// dropped. Positions absent from both operands are never evaluated, so
/// `binop(0, 0)` is effectively assumed to be zero.
///
/// # Panics
///
/// - `"Dimension mismatch"` if the shapes differ.
/// - `"Storage mismatch"` if the storage orders differ.
pub fn csmat_binop<N, I, F>(
    lhs: CsMatViewI<N, I>,
    rhs: CsMatViewI<N, I>,
    binop: F,
) -> CsMatI<N, I>
where
    N: Num,
    I: SpIndex,
    F: Fn(&N, &N) -> N,
{
    let (nrows, ncols, storage) = (lhs.rows(), lhs.cols(), lhs.storage());
    let same_shape = nrows == rhs.rows() && ncols == rhs.cols();
    if !same_shape {
        panic!("Dimension mismatch");
    }
    if rhs.storage() != storage {
        panic!("Storage mismatch");
    }

    // Worst case: the sparsity patterns are disjoint and every entry survives.
    let capacity = lhs.nnz() + rhs.nnz();
    let mut indptr = vec![I::zero(); lhs.outer_dims() + 1];
    let mut indices = vec![I::zero(); capacity];
    // `N: Num` does not imply `Clone`, so `vec![N::zero(); n]` is unavailable.
    let mut data: Vec<N> = (0..capacity).map(|_| N::zero()).collect();

    let nnz = csmat_binop_same_storage_raw(
        lhs,
        rhs,
        binop,
        &mut indptr,
        &mut indices,
        &mut data,
    );
    // Trim the unused tail of the worst-case buffers.
    indices.truncate(nnz);
    data.truncate(nnz);

    CsMatI {
        storage,
        nrows,
        ncols,
        indptr,
        indices,
        data,
    }
}
/// Low-level kernel for `csmat_binop`: apply `binop` element-wise to two
/// sparse matrices with the same shape and storage order, writing into
/// caller-provided buffers.
///
/// `out_indptr` receives the outer pointer array. The first `nnz` slots of
/// `out_indices` and `out_data` receive the result. Results equal to zero are
/// skipped. Returns `nnz`, the number of entries written.
///
/// # Panics
///
/// Panics (via `assert!`/`assert_eq!`) in these cases:
/// - the shapes or storage orders differ;
/// - `out_indptr.len() != outer_dims + 1`;
/// - `out_indices` or `out_data` is shorter than `lhs.nnz() + rhs.nnz()`.
pub fn csmat_binop_same_storage_raw<N, I, F>(
    lhs: CsMatViewI<N, I>,
    rhs: CsMatViewI<N, I>,
    binop: F,
    out_indptr: &mut [I],
    out_indices: &mut [I],
    out_data: &mut [N],
) -> usize
where
    N: Num,
    I: SpIndex,
    F: Fn(&N, &N) -> N,
{
    assert_eq!(lhs.cols(), rhs.cols());
    assert_eq!(lhs.rows(), rhs.rows());
    assert_eq!(lhs.storage(), rhs.storage());
    assert_eq!(out_indptr.len(), rhs.outer_dims() + 1);
    let capacity = lhs.nnz() + rhs.nnz();
    assert!(out_data.len() >= capacity);
    assert!(out_indices.len() >= capacity);

    out_indptr[0] = I::zero();
    let mut written = 0;
    let outer_pairs = lhs.outer_iterator().zip(rhs.outer_iterator());
    for (outer, (lvec, rvec)) in outer_pairs.enumerate() {
        // Walk the union of both sparsity patterns for this outer slice.
        for entry in lvec.iter().nnz_or_zip(rvec.iter()) {
            let (inner, value) = match entry {
                Both((inner, l, r)) => (inner, binop(l, r)),
                Left((inner, l)) => (inner, binop(l, &N::zero())),
                Right((inner, r)) => (inner, binop(&N::zero(), r)),
            };
            // Explicit zeros are not stored.
            if value != N::zero() {
                out_indices[written] = I::from_usize(inner);
                out_data[written] = value;
                written += 1;
            }
        }
        out_indptr[outer + 1] = I::from_usize(written);
    }
    written
}
/// Combine a sparse matrix and a dense matrix element-wise as
/// `alpha * x + beta * y`, returning a new dense matrix.
///
/// `(x, y)` are the two values passed to the closure by
/// `csmat_binop_dense_raw`. The intent is `x` = sparse (`lhs`) value and
/// `y` = dense (`rhs`) value, i.e. `alpha * lhs + beta * rhs`. Confirm that
/// `csmat_binop_dense_raw` passes the sparse value first.
///
/// The result is in standard (row-major) layout when `rhs` is, and in
/// Fortran (column-major) layout otherwise.
///
/// # Panics
///
/// Panics (inside `csmat_binop_dense_raw`) with `"Dimension mismatch"` if
/// the shapes differ, or `"Storage mismatch"` if `lhs`'s storage order does
/// not match `rhs`'s layout.
pub fn add_dense_mat_same_ordering<N, I, Mat, D>(
    lhs: &Mat,
    rhs: &ArrayBase<D, Ix2>,
    alpha: N,
    beta: N,
) -> Array<N, Ix2>
where
    N: Num + Copy,
    I: SpIndex,
    Mat: SpMatView<N, I>,
    D: ndarray::Data<Elem = N>,
{
    let row_major = rhs.is_standard_layout();
    let dims = (rhs.shape()[0], rhs.shape()[1]);
    // Match the output layout to `rhs` so the raw kernel's layout check passes.
    let mut sum = if row_major {
        Array::zeros(dims)
    } else {
        Array::zeros(dims.f())
    };
    let combine = |x: &N, y: &N| alpha * *x + beta * *y;
    csmat_binop_dense_raw(lhs.view(), rhs.view(), combine, sum.view_mut());
    sum
}
/// Element-wise product of a sparse matrix and a dense matrix, scaled by
/// `alpha`, returned as a new dense matrix.
///
/// The result is in standard (row-major) layout when `rhs` is, and in
/// Fortran (column-major) layout otherwise.
///
/// # Panics
///
/// Panics (inside `csmat_binop_dense_raw`) with `"Dimension mismatch"` if
/// the shapes differ, or `"Storage mismatch"` if `lhs`'s storage order does
/// not match `rhs`'s layout.
pub fn mul_dense_mat_same_ordering<N, I, Mat, D>(
    lhs: &Mat,
    rhs: &ArrayBase<D, Ix2>,
    alpha: N,
) -> Array<N, Ix2>
where
    N: Num + Copy,
    I: SpIndex,
    Mat: SpMatView<N, I>,
    D: ndarray::Data<Elem = N>,
{
    let dims = (rhs.shape()[0], rhs.shape()[1]);
    let mut product = if !rhs.is_standard_layout() {
        Array::zeros(dims.f())
    } else {
        Array::zeros(dims)
    };
    // Symmetric in its operands, so argument order does not matter here.
    let scaled_product = |x: &N, y: &N| alpha * *x * *y;
    csmat_binop_dense_raw(lhs.view(), rhs.view(), scaled_product, product.view_mut());
    product
}
/// Apply `binop` element-wise to a sparse matrix `lhs` and a dense matrix
/// `rhs`, storing each result in the matching position of `out`.
///
/// `binop` is always called as `binop(sparse_value, dense_value)`, with a
/// value missing from the sparse matrix passed as zero. This matches the
/// `(lhs, rhs)` argument order used by `csmat_binop`.
///
/// # Panics
///
/// - `"Dimension mismatch"` if `rhs` or `out` does not have `lhs`'s shape.
/// - `"Storage mismatch"` unless one of the following holds:
///   - `lhs` is CSR and both `rhs` and `out` are in standard (row-major) layout;
///   - `lhs` is CSC and neither of them is.
pub fn csmat_binop_dense_raw<'a, N, I, F>(
    lhs: CsMatViewI<'a, N, I>,
    rhs: ArrayView<'a, N, Ix2>,
    binop: F,
    mut out: ArrayViewMut<'a, N, Ix2>,
) where
    N: 'a + Num,
    I: 'a + SpIndex,
    F: Fn(&N, &N) -> N,
{
    if lhs.cols() != rhs.shape()[1]
        || lhs.cols() != out.shape()[1]
        || lhs.rows() != rhs.shape()[0]
        || lhs.rows() != out.shape()[0]
    {
        panic!("Dimension mismatch");
    }
    // The sparse storage order decides which dense axis we walk: rows for
    // CSR, columns for CSC. The dense arrays must agree with it.
    let outer_axis = match (
        lhs.storage(),
        rhs.is_standard_layout(),
        out.is_standard_layout(),
    ) {
        (CompressedStorage::CSR, true, true) => Axis(0),
        (CompressedStorage::CSC, false, false) => Axis(1),
        (_, _, _) => panic!("Storage mismatch"),
    };
    for ((mut out_lane, sparse_lane), dense_lane) in out
        .axis_iter_mut(outer_axis)
        .zip(lhs.outer_iterator())
        .zip(rhs.axis_iter(outer_axis))
    {
        // The dense lane is the *left* iterator of `nnz_or_zip`. So `Left`
        // carries a dense-only value, `Right` a sparse-only value, and
        // `Both` yields `(index, dense, sparse)`. Reorder the values so
        // `binop` always receives the sparse (`lhs`) operand first.
        let entries = dense_lane
            .iter()
            .enumerate()
            .nnz_or_zip(sparse_lane.iter());
        for (oval, entry) in out_lane.iter_mut().zip(entries) {
            *oval = match entry {
                Left((_, dense_val)) => binop(&N::zero(), dense_val),
                Right((_, sparse_val)) => binop(sparse_val, &N::zero()),
                Both((_, dense_val, sparse_val)) => binop(sparse_val, dense_val),
            };
        }
    }
}
/// Apply `binop` element-wise to two sparse vectors, returning a new vector.
///
/// A 0-dimensional operand first takes on the other operand's dimension (see
/// `csvec_fix_zeros`), so it acts as an empty vector of matching size.
/// `binop` is evaluated at every position stored in either operand, with a
/// missing value passed as zero. Every result is appended, including zeros.
/// The visible code always returns `Ok`.
///
/// # Panics
///
/// Panics with `"Dimension mismatch"` if the (fixed-up) dimensions differ.
pub fn csvec_binop<N, I, F>(
    mut lhs: CsVecViewI<N, I>,
    mut rhs: CsVecViewI<N, I>,
    binop: F,
) -> SpRes<CsVecI<N, I>>
where
    N: Num,
    F: Fn(&N, &N) -> N,
    I: SpIndex,
{
    csvec_fix_zeros(&mut lhs, &mut rhs);
    let dim = lhs.dim();
    if dim != rhs.dim() {
        panic!("Dimension mismatch");
    }

    let mut out = CsVecI::empty(dim);
    // Upper bound on the result's nnz: disjoint sparsity patterns.
    out.reserve_exact(lhs.nnz() + rhs.nnz());
    let zero = N::zero();
    for entry in lhs.iter().nnz_or_zip(rhs.iter()) {
        let (index, value) = match entry {
            Both((index, l, r)) => (index, binop(l, r)),
            Left((index, l)) => (index, binop(l, &zero)),
            Right((index, r)) => (index, binop(&zero, r)),
        };
        out.append(index, value);
    }
    Ok(out)
}
/// If exactly one of the two vectors has dimension 0, give it the other
/// vector's dimension so that a zero-sized vector behaves like an empty
/// operand of matching size. When both are 0-dimensional, nothing changes.
fn csvec_fix_zeros<N, I: SpIndex>(
    lhs: &mut CsVecViewI<N, I>,
    rhs: &mut CsVecViewI<N, I>,
) {
    match (lhs.dim, rhs.dim) {
        (0, other) => lhs.dim = other,
        (other, 0) => rhs.dim = other,
        _ => {}
    }
}
// Unit tests for the operations in this file. The fixtures `mat1`, `mat2`,
// `mat1_times_2` and `mat_dense1` come from the crate's `test_data` module.
#[cfg(test)]
mod test {
use ndarray::{arr2, Array};
use sparse::CsMat;
use sparse::CsVec;
use test_data::{mat1, mat1_times_2, mat2, mat_dense1};
// Expected value of `mat1 + mat2`: a 5x5 matrix given as (indptr, indices, data).
fn mat1_plus_mat2() -> CsMat<f64> {
let indptr = vec![0, 5, 8, 9, 12, 15];
let indices = vec![0, 1, 2, 3, 4, 0, 3, 4, 2, 1, 2, 3, 1, 2, 3];
let data =
vec![6., 7., 6., 4., 3., 8., 11., 5., 5., 8., 2., 4., 4., 4., 7.];
CsMat::new((5, 5), indptr, indices, data)
}
// Expected value of `mat1 - mat2`: a 5x5 matrix given as (indptr, indices, data).
fn mat1_minus_mat2() -> CsMat<f64> {
let indptr = vec![0, 4, 7, 8, 11, 14];
let indices = vec![0, 1, 3, 4, 0, 3, 4, 2, 1, 2, 3, 1, 2, 3];
let data = vec![
-6., -7., 4., -3., -8., -7., 5., 5., 8., -2., -4., -4., -4., 7.,
];
CsMat::new((5, 5), indptr, indices, data)
}
// Expected element-wise product of `mat1` and `mat2`: only two stored
// entries survive.
fn mat1_times_mat2() -> CsMat<f64> {
let indptr = vec![0, 1, 2, 2, 2, 2];
let indices = vec![2, 3];
let data = vec![9., 18.];
CsMat::new((5, 5), indptr, indices, data)
}
// Sparse + sparse, via the free function and the `+` operator.
// Also covers outer slices that are empty in one operand but not the other.
#[test]
fn test_add1() {
let a = mat1();
let b = mat2();
let c = super::add_mat_same_storage(&a, &b);
let c_true = mat1_plus_mat2();
assert_eq!(c, c_true);
let c = &a + &b;
assert_eq!(c, c_true);
let a = CsMat::new((3, 3), vec![0, 1, 1, 2], vec![0, 2], vec![1., 1.]);
let b = CsMat::new((3, 3), vec![0, 1, 2, 2], vec![0, 1], vec![1., 1.]);
let c = CsMat::new(
(3, 3),
vec![0, 1, 2, 3],
vec![0, 1, 2],
vec![2., 1., 1.],
);
assert_eq!(c, &a + &b);
}
// Sparse - sparse, via the free function and the `-` operator.
#[test]
fn test_sub1() {
let a = mat1();
let b = mat2();
let c = super::sub_mat_same_storage(&a, &b);
let c_true = mat1_minus_mat2();
assert_eq!(c, c_true);
let c = &a - &b;
assert_eq!(c, c_true);
}
// Element-wise sparse * sparse; compares the raw indptr/indices/data arrays.
#[test]
fn test_mul1() {
let a = mat1();
let b = mat2();
let c = super::mul_mat_same_storage(&a, &b);
let c_true = mat1_times_mat2();
assert_eq!(c.indptr(), c_true.indptr());
assert_eq!(c.indices(), c_true.indices());
assert_eq!(c.data(), c_true.data());
}
// Scalar multiplication keeps the sparsity structure and scales the data.
#[test]
fn test_smul() {
let a = mat1();
let c = super::scalar_mul_mat(&a, 2.);
let c_true = mat1_times_2();
assert_eq!(c.indptr(), c_true.indptr());
assert_eq!(c.indices(), c_true.indices());
assert_eq!(c.data(), c_true.data());
}
// Sparse vector addition with disjoint (vec1 + vec2) and partially
// overlapping (vec1 + vec3) sparsity patterns.
#[test]
fn csvec_binops() {
let vec1 = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
let vec2 = CsVec::new(8, vec![1, 3, 5, 7], vec![2.; 4]);
let vec3 = CsVec::new(8, vec![1, 2, 5, 6], vec![3.; 4]);
let res = &vec1 + &vec2;
let expected_output = CsVec::new(
8,
vec![0, 1, 2, 3, 4, 5, 6, 7],
vec![1., 2., 1., 2., 1., 2., 1., 2.],
);
assert_eq!(expected_output, res);
let res = &vec1 + &vec3;
let expected_output =
CsVec::new(8, vec![0, 1, 2, 4, 5, 6], vec![1., 3., 4., 1., 3., 4.]);
assert_eq!(expected_output, res);
}
// Adding a 0-dimensional vector on the right leaves the other operand unchanged.
#[test]
fn zero_sized_vector_works_as_right_vector_operand() {
let vector = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
let zero = CsVec::<f64>::new(0, vec![], vec![]);
assert_eq!(&vector + zero, vector);
}
// Adding a 0-dimensional vector on the left leaves the other operand unchanged.
#[test]
fn zero_sized_vector_works_as_left_vector_operand() {
let vector = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
let zero = CsVec::<f64>::new(0, vec![], vec![]);
assert_eq!(zero + &vector, vector);
}
// Sparse + dense with row-major dense operands: the identity added to zeros,
// then `mat1 + mat_dense1` via the free function and the `+` operator.
// NOTE(review): alpha = beta = 1 here, so operand order inside the dense
// kernel is not exercised.
#[test]
fn csr_add_dense_rowmaj() {
let a = Array::zeros((3, 3));
let b = CsMat::eye(3);
let c = super::add_dense_mat_same_ordering(&b, &a, 1., 1.);
let mut expected_output = Array::zeros((3, 3));
expected_output[[0, 0]] = 1.;
expected_output[[1, 1]] = 1.;
expected_output[[2, 2]] = 1.;
assert_eq!(c, expected_output);
let a = mat1();
let b = mat_dense1();
let expected_output = arr2(&[
[0., 1., 5., 7., 4.],
[5., 6., 5., 6., 8.],
[4., 5., 9., 3., 2.],
[3., 12., 3., 2., 1.],
[1., 2., 1., 8., 0.],
]);
let c = super::add_dense_mat_same_ordering(&a, &b, 1., 1.);
assert_eq!(c, expected_output);
let c = &a + &b;
assert_eq!(c, expected_output);
}
// Element-wise identity * all-ones (row-major) yields the identity.
#[test]
fn csr_mul_dense_rowmaj() {
let a = Array::from_elem((3, 3), 1.);
let b = CsMat::eye(3);
let c = super::mul_dense_mat_same_ordering(&b, &a, 1.);
let expected_output = Array::eye(3);
assert_eq!(c, expected_output);
}
}