use crate::num::Num;
use crate::CooMat;
use crate::CscMat;
use crate::DokMat;
/// A sparse matrix stored in compressed sparse row (CSR) format.
///
/// The stored entries of row `r` occupy positions `rowptr[r]..rowptr[r + 1]`
/// of both `colind` (their column indices) and `values` (their values).
///
/// The struct itself carries no trait bound on `T`; the numeric requirements
/// (`T: Num`) are stated on the impl blocks that actually need them.
#[derive(Clone, Debug)]
pub struct CsrMat<T> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row pointers: `rows + 1` offsets into `colind` / `values`.
    rowptr: Vec<usize>,
    /// Column index of each stored entry.
    colind: Vec<usize>,
    /// Value of each stored entry, parallel to `colind`.
    values: Vec<T>,
}
impl<T> CsrMat<T>
where
    T: Num,
{
    /// Creates a CSR matrix from its raw parts, validating the structure.
    ///
    /// `rows` and `cols` are clamped to at least 1.
    ///
    /// # Panics
    ///
    /// Panics if any of the following holds:
    /// - `rowptr.len() != rows + 1`
    /// - `colind.len() != values.len()`
    /// - `rowptr[0] != 0`
    /// - `rowptr` decreases anywhere
    /// - `rowptr[rows] != colind.len()`
    /// - any column index is `>= cols`
    /// - the column indices within a row are not strictly increasing
    pub fn new(
        rows: usize,
        cols: usize,
        rowptr: Vec<usize>,
        colind: Vec<usize>,
        values: Vec<T>,
    ) -> Self {
        let rows = std::cmp::max(1, rows);
        let cols = std::cmp::max(1, cols);
        if rowptr.len() != (rows + 1) {
            panic!(
                "invalid rowptr length: rowptr length ({}) != rows + 1 ({})",
                rowptr.len(),
                rows + 1
            );
        }
        if colind.len() != values.len() {
            panic!(
                "invalid colind/values length: colind length ({}) != values length ({})",
                colind.len(),
                values.len()
            );
        }
        if rowptr[0] != 0 {
            panic!("invalid rowptr: rowptr[0] ({}) != 0", rowptr[0])
        }
        for idx in 0..rows {
            if rowptr[idx] > rowptr[idx + 1] {
                panic!(
                    "unsorted rowptr: rowptr[{}] ({}) > rowptr[{}] ({})",
                    idx,
                    rowptr[idx],
                    idx + 1,
                    rowptr[idx + 1]
                );
            }
        }
        // The last row pointer must cover every stored entry exactly; otherwise
        // `nnz()` and `colind()`/`values()` would disagree.
        if rowptr[rows] != colind.len() {
            panic!(
                "invalid rowptr: rowptr[{}] ({}) != colind length ({})",
                rows,
                rowptr[rows],
                colind.len()
            );
        }
        for row in 0..rows {
            let (start, end) = (rowptr[row], rowptr[row + 1]);
            for idx in start..end {
                if colind[idx] >= cols {
                    panic!(
                        "invalid colind: colind[{}] ({}) >= cols ({})",
                        idx, colind[idx], cols
                    );
                }
                // Compare with the next entry only while it is still in this row;
                // testing `idx + 1 < end` avoids computing `end - 1`, which would
                // underflow for an empty row with `end == 0`.
                if idx + 1 < end && colind[idx] >= colind[idx + 1] {
                    panic!(
                        "unsorted colind: colind[{}] ({}) >= colind[{}] ({})",
                        idx,
                        colind[idx],
                        idx + 1,
                        colind[idx + 1]
                    );
                }
            }
        }
        CsrMat {
            rows,
            cols,
            rowptr,
            colind,
            values,
        }
    }

    /// Creates a `rows x cols` matrix with ones on the main diagonal.
    ///
    /// Both dimensions are clamped to at least 1. For a non-square shape the
    /// diagonal holds `min(rows, cols)` entries, and any rows past it are empty.
    pub fn eye(rows: usize, cols: usize) -> Self {
        let rows = std::cmp::max(1, rows);
        let cols = std::cmp::max(1, cols);
        let nnz = std::cmp::min(rows, cols);
        // Row `r` holds the single entry `(r, r)` while `r < nnz`; later rows are empty.
        let rowptr: Vec<usize> = (0..=rows).map(|row| std::cmp::min(row, nnz)).collect();
        let colind: Vec<usize> = (0..nnz).collect();
        let values = vec![T::one(); nnz];
        CsrMat {
            rows,
            cols,
            rowptr,
            colind,
            values,
        }
    }

    /// Creates a `dim x dim` identity matrix (`dim` is clamped to at least 1).
    pub fn identity(dim: usize) -> Self {
        Self::eye(dim, dim)
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// `(rows, cols)`.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Number of stored entries, read from the last row pointer.
    pub fn nnz(&self) -> usize {
        self.rowptr[self.rows]
    }

    /// Row pointers (`rows + 1` offsets into `colind` / `values`).
    pub fn rowptr(&self) -> &[usize] {
        &self.rowptr
    }

    /// Column index of each stored entry.
    pub fn colind(&self) -> &[usize] {
        &self.colind
    }

    /// Value of each stored entry, parallel to `colind`.
    pub fn values(&self) -> &[T] {
        &self.values
    }
}
impl<T> From<DokMat<T>> for CsrMat<T>
where
    T: Num,
{
    /// Converts a dictionary-of-keys matrix to CSR.
    ///
    /// Entries whose value is zero are dropped. The entries are first bucketed
    /// by column (a CSC layout) and then transposed into CSR. Because the
    /// transpose visits columns in ascending order, the column indices within
    /// each output row come out sorted.
    fn from(dokmat: DokMat<T>) -> Self {
        let rows = dokmat.rows();
        let cols = dokmat.cols();
        // `&(_, _, x)` copies out the `&T` reference. It does not depend on the
        // pre-2024 match-ergonomics rule that the `(_, _, &x)` form relies on.
        let entries: Vec<_> = dokmat
            .entries()
            .filter(|&(_, _, x)| !x.is_zero())
            .collect();
        let nnz = entries.len();

        // Stage 1: bucket the entries by column (CSC layout).
        let mut colptr = vec![0; cols + 1];
        let mut rowind = vec![0; nnz];
        let mut cscval = vec![T::zero(); nnz];
        let mut work = vec![0; cols];
        for &(_, c, _) in entries.iter() {
            work[*c] += 1;
        }
        let mut sum = 0;
        for k in 0..cols {
            let cum = sum;
            sum += work[k];
            work[k] = cum;
            colptr[k] = cum;
        }
        colptr[cols] = sum;
        for &(r, c, v) in entries.iter() {
            let idx = work[*c];
            work[*c] += 1;
            rowind[idx] = *r;
            cscval[idx] = *v;
        }

        // Stage 2: transpose CSC -> CSR, with one counter per row.
        let mut rowptr = vec![0; rows + 1];
        let mut colind = vec![0; nnz];
        let mut csrval = vec![T::zero(); nnz];
        let mut work = vec![0; rows];
        for idx in 0..nnz {
            work[rowind[idx]] += 1;
        }
        let mut sum = 0;
        for k in 0..rows {
            let cum = sum;
            sum += work[k];
            work[k] = cum;
            rowptr[k] = cum;
        }
        // `rowptr` has `rows + 1` slots, so its terminator is at index `rows`.
        rowptr[rows] = sum;
        for col in 0..cols {
            for cscidx in colptr[col]..colptr[col + 1] {
                let row = rowind[cscidx];
                let idx = work[row];
                work[row] += 1;
                colind[idx] = col;
                csrval[idx] = cscval[cscidx];
            }
        }
        CsrMat {
            rows,
            cols,
            rowptr,
            colind,
            values: csrval,
        }
    }
}
impl<T> From<CooMat<T>> for CsrMat<T>
where
    T: Num,
{
    /// Converts a coordinate (triplet) matrix to CSR.
    ///
    /// Zero-valued triplets are discarded and duplicate `(row, col)` triplets
    /// are summed. Entries whose sums cancel to zero are then dropped. The
    /// column indices of each output row are strictly increasing.
    ///
    /// Pipeline: bucket by column (CSC) -> merge duplicates per column ->
    /// drop zeros -> transpose CSC to CSR.
    fn from(coomat: CooMat<T>) -> Self {
        let rows = coomat.rows();
        let cols = coomat.cols();
        let entries = coomat.entries();
        let entries: Vec<_> = entries.filter(|(_, _, x)| !x.is_zero()).collect();
        let nnz = entries.len();

        // Stage 1: bucket the triplets by column (unsorted CSC, duplicates kept).
        let mut colptr = vec![0; cols + 1];
        let mut rowind = vec![0; nnz];
        let mut cscval = vec![T::zero(); nnz];
        let mut work = vec![0; cols];
        for &(_, c, _) in entries.iter() {
            work[*c] += 1;
        }
        let mut sum = 0;
        for k in 0..cols {
            let cum = sum;
            sum += work[k];
            work[k] = cum;
            colptr[k] = cum;
        }
        colptr[cols] = sum;
        for &(r, c, v) in entries.iter() {
            let idx = work[*c];
            work[*c] += 1;
            rowind[idx] = *r;
            cscval[idx] = *v;
        }

        // Stage 2: merge duplicate rows within each column, compacting in place.
        // `work[row]` is the compacted slot of `row` in the most recent column
        // where it appeared. A slot `>= colidx` means "seen in this column".
        // `nnz` is used as the "never seen" sentinel.
        let mut work = vec![nnz; rows];
        let mut nz = 0;
        for col in 0..cols {
            let colidx = nz;
            for idx in colptr[col]..colptr[col + 1] {
                let row = rowind[idx];
                if work[row] >= colidx && work[row] < nnz {
                    let val = cscval[idx];
                    cscval[work[row]] += val;
                } else {
                    work[row] = nz;
                    rowind[nz] = row;
                    cscval[nz] = cscval[idx];
                    nz += 1;
                }
            }
            colptr[col] = colidx;
        }
        // `colptr` has `cols + 1` slots, so its terminator is at index `cols`.
        colptr[cols] = nz;
        rowind.truncate(nz);
        cscval.truncate(nz);

        // Stage 3: drop entries whose duplicates summed to zero.
        let mut nz = 0;
        for col in 0..cols {
            let ptr = colptr[col];
            colptr[col] = nz;
            for idx in ptr..colptr[col + 1] {
                if !cscval[idx].is_zero() {
                    rowind[nz] = rowind[idx];
                    cscval[nz] = cscval[idx];
                    nz += 1;
                }
            }
        }
        colptr[cols] = nz;
        let nnz = nz;
        rowind.truncate(nnz);
        cscval.truncate(nnz);

        // Stage 4: transpose CSC -> CSR. `work` counts entries per *row*,
        // because it is indexed by row indices taken from `rowind`.
        let mut rowptr = vec![0; rows + 1];
        let mut colind = vec![0; nnz];
        let mut csrval = vec![T::zero(); nnz];
        let mut work = vec![0; rows];
        for idx in 0..nnz {
            work[rowind[idx]] += 1;
        }
        let mut sum = 0;
        for k in 0..rows {
            let cum = sum;
            sum += work[k];
            work[k] = cum;
            rowptr[k] = cum;
        }
        rowptr[rows] = sum;
        // Columns are visited in ascending order, so each row's column indices
        // come out sorted.
        for col in 0..cols {
            for cscidx in colptr[col]..colptr[col + 1] {
                let row = rowind[cscidx];
                let idx = work[row];
                work[row] += 1;
                colind[idx] = col;
                csrval[idx] = cscval[cscidx];
            }
        }
        CsrMat {
            rows,
            cols,
            rowptr,
            colind,
            values: csrval,
        }
    }
}
impl<T> From<CscMat<T>> for CsrMat<T>
where
    T: Num,
{
    /// Converts a CSC matrix to CSR by transposing its storage layout.
    ///
    /// Columns are walked in ascending order and each entry is appended to its
    /// row. As a result, the column indices within every output row are in the
    /// same order as the column walk.
    fn from(cscmat: CscMat<T>) -> Self {
        let (rows, cols, nnz) = (cscmat.rows(), cscmat.cols(), cscmat.nnz());
        let (colptr, rowind, cscval) = (cscmat.colptr(), cscmat.rowind(), cscmat.values());

        // Histogram of entries per row, stored one slot to the right so that an
        // in-place prefix sum turns it directly into row pointers.
        let mut rowptr: Vec<usize> = vec![0; rows + 1];
        for &row in &rowind[..nnz] {
            rowptr[row + 1] += 1;
        }
        for row in 0..rows {
            rowptr[row + 1] += rowptr[row];
        }

        // `cursor[r]` is the next free output slot belonging to row `r`.
        let mut cursor: Vec<usize> = rowptr[..rows].to_vec();
        let mut colind: Vec<usize> = vec![0; nnz];
        let mut csrval = vec![T::zero(); nnz];
        for col in 0..cols {
            for k in colptr[col]..colptr[col + 1] {
                let slot = &mut cursor[rowind[k]];
                colind[*slot] = col;
                csrval[*slot] = cscval[k];
                *slot += 1;
            }
        }

        CsrMat {
            rows,
            cols,
            rowptr,
            colind,
            values: csrval,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw parts of a 1x1 matrix holding a single `1.0`.
    fn single_entry() -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        (vec![0, 1], vec![0], vec![1.0])
    }

    #[test]
    fn new() {
        // Every zero dimension is clamped up to one.
        for &(rows, cols) in [(0, 0), (0, 1), (1, 0), (1, 1)].iter() {
            let (rowptr, colind, values) = single_entry();
            let matrix = CsrMat::<f64>::new(rows, cols, rowptr, colind, values);
            assert_eq!(matrix.rows, 1);
            assert_eq!(matrix.cols, 1);
        }
    }

    #[test]
    #[should_panic(expected = "invalid rowptr length: rowptr length (1) != rows + 1 (2)")]
    fn new_invalid_rowptr_length() {
        CsrMat::<f64>::new(0, 0, vec![0], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic(
        expected = "invalid colind/values length: colind length (1) != values length (2)"
    )]
    fn new_invalid_colind_values_length() {
        CsrMat::<f64>::new(0, 0, vec![0, 1], vec![0], vec![1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "invalid rowptr: rowptr[0] (1) != 0")]
    fn new_invalid_rowptr() {
        CsrMat::<f64>::new(0, 0, vec![1, 2], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic(expected = "unsorted rowptr: rowptr[1] (2) > rowptr[2] (1)")]
    fn new_unsorted_rowptr() {
        CsrMat::<f64>::new(2, 2, vec![0, 2, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic(expected = "unsorted colind: colind[0] (1) >= colind[1] (0)")]
    fn new_unsorted_colind() {
        CsrMat::<f64>::new(2, 2, vec![0, 2, 2], vec![1, 0], vec![1.0, 2.0]);
    }

    #[test]
    fn shape() {
        let (rowptr, colind, values) = single_entry();
        let matrix = CsrMat::<f64>::new(0, 0, rowptr, colind, values);
        assert_eq!(matrix.rows(), 1);
        assert_eq!(matrix.cols(), 1);
        assert_eq!((1, 1), matrix.shape());
    }

    #[test]
    fn nnz() {
        let (rowptr, colind, values) = single_entry();
        let matrix = CsrMat::<f64>::new(0, 0, rowptr, colind, values);
        assert_eq!(matrix.nnz(), 1);
    }

    #[test]
    fn from_dokmat() {
        // Dense 3x3 matrix filled row-major with 1..=9.
        let mut dokmat = DokMat::new(3, 3);
        let mut value = 0;
        for row in 0..3 {
            for col in 0..3 {
                value += 1;
                dokmat.insert(row, col, value);
            }
        }
        let csrmat: CsrMat<_> = dokmat.into();
        assert_eq!(csrmat.rowptr, vec![0, 3, 6, 9]);
        assert_eq!(csrmat.colind, vec![0, 1, 2, 0, 1, 2, 0, 1, 2]);
        assert_eq!(csrmat.values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn from_coomat() {
        // Duplicates are summed; (1, 2) cancels to zero and must disappear.
        let triplets = [
            (0, 0, 1),
            (0, 1, 1),
            (0, 2, 3),
            (1, 0, 2),
            (1, 1, 5),
            (1, 2, 3),
            (2, 0, 7),
            (2, 1, 4),
            (2, 2, 9),
            (0, 1, 1),
            (1, 0, 2),
            (2, 1, 4),
            (1, 2, -3),
        ];
        let mut coomat = CooMat::new(3, 3);
        for &(row, col, value) in triplets.iter() {
            coomat.push(row, col, value);
        }
        let csrmat: CsrMat<_> = coomat.into();
        assert_eq!(csrmat.rowptr, vec![0, 3, 5, 8]);
        assert_eq!(csrmat.colind, vec![0, 1, 2, 0, 1, 0, 1, 2]);
        assert_eq!(csrmat.values, vec![1, 2, 3, 4, 5, 7, 8, 9]);
    }

    #[test]
    fn from_cscmat() {
        let cscmat = CscMat::new(
            3,
            3,
            vec![0, 3, 6, 9],
            vec![0, 1, 2, 0, 1, 2, 0, 1, 2],
            vec![1, 4, 7, 2, 5, 8, 3, 6, 9],
        );
        let csrmat: CsrMat<_> = cscmat.into();
        assert_eq!(csrmat.rowptr, vec![0, 3, 6, 9]);
        assert_eq!(csrmat.colind, vec![0, 1, 2, 0, 1, 2, 0, 1, 2]);
        assert_eq!(csrmat.values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}