use errors::SprsError;
use indexing::SpIndex;
use num_traits::Num;
use sparse::vec;
use sparse::CsMatViewI;
use sparse::CsVecViewI;
use stack::{self, DStack, StackVal};
use std::ops::IndexMut;
/// Validates the shapes of a triangular system before solving it.
///
/// The matrix must be square, and its order must equal the dimension of
/// the right-hand side `rhs`.
///
/// # Panics
///
/// Panics with "Non square matrix passed to solver" if the matrix is not
/// square, and with "Dimension mismatch" if `rhs.dim()` differs from the
/// matrix order.
fn check_solver_dimensions<N, I, V: ?Sized>(
    lower_tri_mat: &CsMatViewI<N, I>,
    rhs: &V,
) where
    N: Copy + Num,
    V: vec::VecDim<N>,
    I: SpIndex,
{
    let n_cols = lower_tri_mat.cols();
    let n_rows = lower_tri_mat.rows();
    assert!(n_cols == n_rows, "Non square matrix passed to solver");
    assert!(n_cols == rhs.dim(), "Dimension mismatch");
}
/// Solves `L x = b` in place for a lower triangular matrix `L` stored in
/// CSR format and a dense right-hand side `b`.
///
/// Rows are processed top to bottom (forward substitution). On success,
/// `rhs` holds the solution `x`. Stored entries strictly above the
/// diagonal are ignored.
///
/// # Panics
///
/// Panics if `lower_tri_mat` is not square, if its order does not match
/// `rhs.dim()`, or if it is not in CSR storage.
///
/// # Errors
///
/// Returns `SprsError::SingularMatrix` if a diagonal entry is absent or
/// zero. Rows solved before that point have already been written to `rhs`.
pub fn lsolve_csr_dense_rhs<N, I, V: ?Sized>(
    lower_tri_mat: CsMatViewI<N, I>,
    rhs: &mut V,
) -> Result<(), SprsError>
where
    N: Copy + Num,
    V: IndexMut<usize, Output = N> + vec::VecDim<N>,
    I: SpIndex,
{
    check_solver_dimensions(&lower_tri_mat, rhs);
    assert!(lower_tri_mat.is_csr(), "Storage mismatch");
    for (i, row) in lower_tri_mat.outer_iterator().enumerate() {
        let mut pivot = N::zero();
        let mut acc = rhs[i];
        for (j, &coeff) in row.iter() {
            if j < i {
                // Subtract the contribution of an already-solved component.
                acc = acc - coeff * rhs[j];
            } else if j == i {
                pivot = coeff;
            }
            // j > i: entries above the diagonal take no part in the solve.
        }
        if pivot == N::zero() {
            return Err(SprsError::SingularMatrix);
        }
        rhs[i] = acc / pivot;
    }
    Ok(())
}
pub fn lsolve_csc_dense_rhs<N, I, V: ?Sized>(
lower_tri_mat: CsMatViewI<N, I>,
rhs: &mut V,
) -> Result<(), SprsError>
where
N: Copy + Num,
V: IndexMut<usize, Output = N> + vec::VecDim<N>,
I: SpIndex,
{
check_solver_dimensions(&lower_tri_mat, rhs);
if !lower_tri_mat.is_csc() {
panic!("Storage mismatch");
}
for (col_ind, col) in lower_tri_mat.outer_iterator().enumerate() {
try!(lspsolve_csc_process_col(col, col_ind, rhs));
}
Ok(())
}
/// Performs one step of column-oriented forward substitution.
///
/// Divides `rhs[col_ind]` by the diagonal entry of `col` to obtain the
/// solution component for this column. It then subtracts that component's
/// contribution from every `rhs` entry whose row index lies strictly below
/// the diagonal. Stored entries on or above the diagonal are not used for
/// the update.
///
/// Returns `SprsError::SingularMatrix` if the diagonal entry is missing or
/// equal to zero; `rhs` is not modified in that case.
fn lspsolve_csc_process_col<N: Copy + Num, I, V: ?Sized>(
    col: CsVecViewI<N, I>,
    col_ind: usize,
    rhs: &mut V,
) -> Result<(), SprsError>
where
    V: vec::VecDim<N> + IndexMut<usize, Output = N>,
    I: SpIndex,
{
    let pivot = match col.get(col_ind) {
        Some(&d) if d != N::zero() => d,
        _ => return Err(SprsError::SingularMatrix),
    };
    let solved = rhs[col_ind] / pivot;
    rhs[col_ind] = solved;
    let below_diag = col.iter().filter(|&(r, _)| r > col_ind);
    for (r, &coeff) in below_diag {
        let current = rhs[r];
        rhs[r] = current - coeff * solved;
    }
    Ok(())
}
/// Solves `U x = b` in place for an upper triangular matrix `U` stored in
/// CSC format and a dense right-hand side `b`.
///
/// Columns are processed right to left (column-oriented back substitution).
/// On success, `rhs` holds the solution `x`. Stored entries strictly below
/// the diagonal are ignored.
///
/// # Panics
///
/// Panics if `upper_tri_mat` is not square, if its order does not match
/// `rhs.dim()`, or if it is not in CSC storage.
///
/// # Errors
///
/// Returns `SprsError::SingularMatrix` if a diagonal entry is absent or
/// zero. Columns processed before that point have already updated `rhs`.
pub fn usolve_csc_dense_rhs<N, I, V: ?Sized>(
    upper_tri_mat: CsMatViewI<N, I>,
    rhs: &mut V,
) -> Result<(), SprsError>
where
    N: Copy + Num,
    V: IndexMut<usize, Output = N> + vec::VecDim<N>,
    I: SpIndex,
{
    check_solver_dimensions(&upper_tri_mat, rhs);
    assert!(upper_tri_mat.is_csc(), "Storage mismatch");
    let columns = upper_tri_mat.outer_iterator().enumerate().rev();
    for (j, col) in columns {
        let pivot = match col.get(j) {
            Some(&d) if d != N::zero() => d,
            _ => return Err(SprsError::SingularMatrix),
        };
        let x_j = rhs[j] / pivot;
        rhs[j] = x_j;
        // Only rows above the diagonal still need this column's contribution.
        for (r, &coeff) in col.iter().filter(|&(r, _)| r < j) {
            let current = rhs[r];
            rhs[r] = current - coeff * x_j;
        }
    }
    Ok(())
}
/// Solves `U x = b` in place for an upper triangular matrix `U` stored in
/// CSR format and a dense right-hand side `b`.
///
/// Rows are processed bottom to top (back substitution). On success, `rhs`
/// holds the solution `x`. Stored entries strictly below the diagonal are
/// ignored.
///
/// # Panics
///
/// Panics if `upper_tri_mat` is not square, if its order does not match
/// `rhs.dim()`, or if it is not in CSR storage.
///
/// # Errors
///
/// Returns `SprsError::SingularMatrix` if a diagonal entry is absent or
/// zero. Rows solved before that point have already been written to `rhs`.
pub fn usolve_csr_dense_rhs<N, I, V: ?Sized>(
    upper_tri_mat: CsMatViewI<N, I>,
    rhs: &mut V,
) -> Result<(), SprsError>
where
    N: Copy + Num,
    V: IndexMut<usize, Output = N> + vec::VecDim<N>,
    I: SpIndex,
{
    check_solver_dimensions(&upper_tri_mat, rhs);
    assert!(upper_tri_mat.is_csr(), "Storage mismatch");
    let rows = upper_tri_mat.outer_iterator().enumerate().rev();
    for (i, row) in rows {
        let mut pivot = N::zero();
        let mut acc = rhs[i];
        for (j, &coeff) in row.iter() {
            if j > i {
                // Components to the right of the diagonal are already solved.
                acc = acc - coeff * rhs[j];
            } else if j == i {
                pivot = coeff;
            }
        }
        if pivot == N::zero() {
            return Err(SprsError::SingularMatrix);
        }
        rhs[i] = acc / pivot;
    }
    Ok(())
}
pub fn lsolve_csc_sparse_rhs<N, I>(
lower_tri_mat: CsMatViewI<N, I>,
rhs: CsVecViewI<N, I>,
dstack: &mut DStack<StackVal<usize>>,
x_workspace: &mut [N],
visited: &mut [bool],
) -> Result<(), SprsError>
where
N: Copy + Num,
I: SpIndex,
{
if !lower_tri_mat.is_csc() {
panic!("Storage mismatch");
}
let n = lower_tri_mat.rows();
assert!(dstack.capacity() >= 2 * n, "dstack cap should be 2*n");
assert!(
dstack.is_left_empty() && dstack.is_right_empty(),
"dstack should be empty"
);
assert!(x_workspace.len() == n, "x should be of len n");
for (root_ind, _) in rhs.iter() {
if visited[root_ind] {
continue;
}
dstack.push_left(StackVal::Enter(root_ind));
while let Some(stack_val) = dstack.pop_left() {
match stack_val {
StackVal::Enter(ind) => {
if visited[ind] {
continue;
}
visited[ind] = true;
dstack.push_left(StackVal::Exit(ind));
if let Some(column) = lower_tri_mat.outer_view(ind) {
for (child_ind, _) in column.iter() {
dstack.push_left(StackVal::Enter(child_ind));
}
} else {
unreachable!();
}
}
StackVal::Exit(ind) => {
dstack.push_right(StackVal::Enter(ind));
}
}
}
}
rhs.scatter(x_workspace);
for &ind in dstack.iter_right().map(stack::extract_stack_val) {
println!("ind: {}", ind);
let col = lower_tri_mat.outer_view(ind).expect("ind not in bounds");
try!(lspsolve_csc_process_col(col, ind, x_workspace));
}
Ok(())
}
#[cfg(test)]
mod test {
// Unit tests for the triangular solvers. Every system below is built so
// that its exact integer solution is easy to check by hand.
use sparse::{CsMat, CsVec};
use stack::{self, DStack};
use std::collections::HashSet;
// Forward substitution on a CSR lower triangular matrix:
// L = [[1, 0, 0], [0, 2, 0], [1, 0, 1]], b = [3, 2, 4], x = [3, 1, 1].
#[test]
fn lsolve_csr_dense_rhs() {
let l = CsMat::new(
(3, 3),
vec![0, 1, 2, 4],
vec![0, 1, 0, 2],
vec![1, 2, 1, 1],
);
let b = vec![3, 2, 4];
let mut x = b.clone();
super::lsolve_csr_dense_rhs(l.view(), &mut x).unwrap();
assert_eq!(x, vec![3, 1, 1]);
}
// Forward substitution on a CSC lower triangular matrix:
// L = [[1, 0, 0], [1, 2, 0], [0, 0, 3]], b = [3, 5, 3], x = [3, 1, 1].
#[test]
fn lsolve_csc_dense_rhs() {
let l = CsMat::new_csc(
(3, 3),
vec![0, 2, 3, 4],
vec![0, 1, 1, 2],
vec![1, 1, 2, 3],
);
let b = vec![3, 5, 3];
let mut x = b.clone();
super::lsolve_csc_dense_rhs(l.view(), &mut x).unwrap();
assert_eq!(x, vec![3, 1, 1]);
}
// Back substitution on a CSC upper triangular matrix:
// U = [[1, 0, 1], [0, 2, 0], [0, 0, 3]], b = [4, 2, 3], x = [3, 1, 1].
#[test]
fn usolve_csc_dense_rhs() {
let u = CsMat::new_csc(
(3, 3),
vec![0, 1, 2, 4],
vec![0, 1, 0, 2],
vec![1, 2, 1, 3],
);
let b = vec![4, 2, 3];
let mut x = b.clone();
super::usolve_csc_dense_rhs(u.view(), &mut x).unwrap();
assert_eq!(x, vec![3, 1, 1]);
}
// Back substitution on a CSR upper triangular matrix:
// U = [[1, 1, 0], [0, 5, 3], [0, 0, 1]], b = [4, 8, 1], x = [3, 1, 1].
#[test]
fn usolve_csr_dense_rhs() {
let u = CsMat::new(
(3, 3),
vec![0, 2, 4, 5],
vec![0, 1, 1, 2, 2],
vec![1, 1, 5, 3, 1],
);
let b = vec![4, 8, 1];
let mut x = b.clone();
super::usolve_csr_dense_rhs(u.view(), &mut x).unwrap();
assert_eq!(x, vec![3, 1, 1]);
}
// Sparse right-hand side solves. The workspace is deliberately pre-filled
// with 1s rather than zeros. Only the indices left on the right side of
// dstack (the reach) are compared with the expected solution.
#[test]
fn lspsolve_csc() {
// 5x5 case, columns: 0 -> rows {0, 1}, 1 -> {1, 2, 4}, 2 -> {2},
// 3 -> {3, 4}, 4 -> {4}. b has entries at {1: 4, 2: 9, 4: 9}; the
// reach is {1, 2, 4} and x there is {1: 2, 2: 1, 4: 1}.
let l = CsMat::new_csc(
(5, 5),
vec![0, 2, 5, 6, 8, 9],
vec![0, 1, 1, 2, 4, 2, 3, 4, 4],
vec![1, 1, 2, 3, 2, 3, 7, 3, 5],
);
let b = CsVec::new(5, vec![1, 2, 4], vec![4, 9, 9]);
let mut xw = vec![1; 5];
let mut visited = vec![false; 5];
let mut dstack = DStack::with_capacity(2 * 5);
super::lsolve_csc_sparse_rhs(
l.view(),
b.view(),
&mut dstack,
&mut xw,
&mut visited,
)
.unwrap();
let x: HashSet<_> = dstack
.iter_right()
.map(stack::extract_stack_val)
.map(|&i| (i, xw[i]))
.collect();
let expected_output = CsVec::new(5, vec![1, 2, 4], vec![2, 1, 1]);
let expected_output = expected_output.to_set();
assert_eq!(x, expected_output);
// 7x7 case, columns: 0 -> rows {0, 2}, 1 -> {1, 6}, 2 -> {2, 5},
// 3 -> {3}, 4 -> {4, 6}, 5 -> {5}, 6 -> {6}. b has entries at
// {0: 1, 2: 7, 3: 7, 5: 3}; the reach is {0, 2, 3, 5} and x there is
// {0: 1, 2: 2, 3: 1, 5: 1}.
let l = CsMat::new_csc(
(7, 7),
vec![0, 2, 4, 6, 7, 9, 10, 11],
vec![0, 2, 1, 6, 2, 5, 3, 4, 6, 5, 6],
vec![1, 1, 2, 3, 3, 1, 7, 5, 2, 1, 2],
);
let b = CsVec::new(7, vec![0, 2, 3, 5], vec![1, 7, 7, 3]);
let mut dstack = DStack::with_capacity(2 * 7);
let mut xw = vec![1; 7];
let mut visited = vec![false; 7];
super::lsolve_csc_sparse_rhs(
l.view(),
b.view(),
&mut dstack,
&mut xw,
&mut visited,
)
.unwrap();
let x: HashSet<_> = dstack
.iter_right()
.map(stack::extract_stack_val)
.map(|&i| (i, xw[i]))
.collect();
let expected_output =
CsVec::new(7, vec![0, 2, 3, 5], vec![1, 2, 1, 1]).to_set();
assert_eq!(x, expected_output);
}
}