Skip to main content

yui_matrix/sparse/
triang.rs

1use either::Either;
2use log::debug;
3use num_traits::Zero;
4use yui_core::{Ring, RingOps};
5
6use super::*;
7
8cfg_if::cfg_if! {
9    if #[cfg(feature = "multithread")] {
10        use std::cell::RefCell;
11        use std::sync::Arc;
12        use thread_local::ThreadLocal;
13        use rayon::prelude::*;
14    }
15}
16
// Size scale for progress logging: `should_report` requires the smaller matrix
// dimension to exceed this value, and `solve_triangular_m` emits one progress
// line each time this many columns have been solved.
const LOG_THRESHOLD: usize = 10_000;
18
/// Selects which triangular shape a matrix is expected to have.
///
/// The solvers in this module use it to choose the substitution direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TriangularType {
    /// Upper-triangular: non-zero entries on or above the diagonal.
    Upper,
    /// Lower-triangular: non-zero entries on or below the diagonal.
    Lower,
}

impl TriangularType {
    /// Returns `true` for [`TriangularType::Upper`], `false` for [`TriangularType::Lower`].
    pub fn is_upper(&self) -> bool {
        matches!(self, Self::Upper)
    }

    /// Returns the shape obtained by transposing a matrix of this shape
    /// (`Upper` becomes `Lower` and vice versa).
    pub fn transpose(&self) -> Self {
        match self {
            Self::Upper => Self::Lower,
            Self::Lower => Self::Upper,
        }
    }

    /// Misspelled alias of [`TriangularType::transpose`], kept so that existing
    /// callers continue to compile. Prefer `transpose` in new code.
    pub fn tranpose(&self) -> Self {
        self.transpose()
    }
}
39
40pub fn inv_triangular<R>(t: TriangularType, a: &SpMat<R>) -> SpMat<R>
41where R: Ring, for<'x> &'x R: RingOps<R> {
42    let e = SpMat::id(a.nrows());
43    solve_triangular(t, a, &e)
44}
45
46// solve ax = y.
47pub fn solve_triangular<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
48where R: Ring, for<'x> &'x R: RingOps<R> {
49    assert_eq!(a.nrows(), y.nrows());
50    debug_assert!(a.is_triang(t));
51
52    cfg_if::cfg_if! { 
53        if #[cfg(feature = "multithread")] { 
54            solve_triangular_m(t, a, y)
55        } else { 
56            solve_triangular_s(t, a, y)
57        }
58    }
59}
60
61// solve xa = y.
62pub fn solve_triangular_left<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
63where R: Ring, for<'x> &'x R: RingOps<R> {
64    solve_triangular(t.tranpose(), &a.transpose(), &y.transpose()).transpose()
65}
66
67pub fn solve_triangular_vec<R>(t: TriangularType, a: &SpMat<R>, b: &SpVec<R>) -> SpVec<R>
68where R: Ring, for<'x> &'x R: RingOps<R> {
69    assert_eq!(a.nrows(), b.dim());
70    debug_assert!(a.is_triang(t));
71
72    let diag = collect_diag(a);
73    let mut b = b.to_dense();
74
75    _solve_triangular(t, a, &diag, &mut b)
76}
77
78#[allow(unused)]
79fn solve_triangular_s<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
80where R: Ring, for<'x> &'x R: RingOps<R> {
81    debug!("solve triangular, y: {:?}", y.shape());
82
83    let (n, k) = (a.nrows(), y.ncols());
84    let diag = collect_diag(a);
85    let mut b = vec![R::zero(); n];
86
87    let cols = (0..k).map(|j| { 
88        copy_into(y.col_vec(j), &mut b);
89        _solve_triangular(t, a, &diag, &mut b)
90    });
91
92    SpMat::from_col_vecs(n, cols)
93}
94
/// Parallel implementation of [`solve_triangular`]: the columns of `y` are
/// solved independently on rayon's thread pool.
///
/// When `should_report(y)` holds, a debug line is logged every
/// `LOG_THRESHOLD` solved columns and once more when all columns are done.
#[cfg(feature = "multithread")]
fn solve_triangular_m<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
where R: Ring, for<'x> &'x R: RingOps<R> {
    use yui_core::util::sync::SyncCounter;

    debug!("solve triangular, y: {:?}", y.shape());

    let n = a.nrows();
    let k = y.ncols();
    let diag = collect_diag(a);

    // Each worker thread lazily gets its own dense scratch buffer.
    let tl_b = Arc::new(ThreadLocal::new());

    let report = should_report(y);
    let counter = SyncCounter::new();

    let solve_col = |j: usize| {
        let cell = tl_b.get_or(|| RefCell::new(vec![R::zero(); n]));
        let mut b = cell.borrow_mut();

        copy_into(y.col_vec(j), &mut b);
        let col = _solve_triangular(t, a, &diag, &mut b);

        if report {
            let c = counter.incr();
            let at_milestone = c > 0 && c % LOG_THRESHOLD == 0;
            if at_milestone || c == k {
                debug!("  solved {c}/{k}.");
            }
        }

        col
    };

    let cols: Vec<_> = (0..k).into_par_iter().map(solve_col).collect();

    SpMat::from_col_vecs(n, cols)
}
129
130#[inline(never)] // for profilability
131fn _solve_triangular<R>(t: TriangularType, a: &SpMat<R>, diag: &[&R], b: &mut [R]) -> SpVec<R>
132where R: Ring, for<'x> &'x R: RingOps<R> {
133    let mut entries = vec![];
134
135    let itr = diag.iter().enumerate();
136    let itr = if t.is_upper() { 
137        Either::Left(itr.rev())
138    } else { 
139        Either::Right(itr)
140    };
141
142    for (j, u) in itr { // u = a_jj
143        if b[j].is_zero() { continue }
144
145        let uinv = u.inv().unwrap();
146        let x_j = &b[j] * &uinv; // non-zero
147
148        for (i, a_ij) in a.col_vec(j).iter() {
149            if a_ij.is_zero() { continue }
150            b[i] -= a_ij * &x_j;
151        }
152
153        entries.push((j, x_j));
154    }
155
156    debug_assert!(b.iter().all(|b_i| 
157        b_i.is_zero())
158    );
159
160    if t.is_upper() { 
161        entries.reverse()
162    };
163
164    SpVec::from_sorted_entries(a.ncols(), entries)
165}
166
167fn collect_diag<'a, R>(a: &'a SpMat<R>) -> Vec<&'a R>
168where R: Ring, for<'x> &'x R: RingOps<R> { 
169    a.iter().filter_map(|(i, j, a)| 
170        if i == j { Some(a) } else { None }
171    ).collect()
172}
173
174fn copy_into<R>(vec: SpVec<R>, x: &mut [R])
175where R: Clone + Zero { 
176    vec.iter().for_each(|(i, r)| x[i] = r.clone())
177}
178
179#[allow(unused)]
180fn should_report<R>(a: &SpMat<R>) -> bool { 
181    usize::min(a.nrows(), a.ncols()) > LOG_THRESHOLD && log::max_level() >= log::LevelFilter::Debug
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use super::TriangularType::{Upper, Lower};

    // Solving u * x = b for an upper-triangular u recovers the known x.
    #[test]
    fn solve_upper() {
        let u = SpMat::from_dense_data((5, 5), vec![
            1, -2, 1,  3, 5,
            0, -1, 4,  2, 1,
            0,  0, 1,  0, 3,
            0,  0, 0, -1, 5,
            0,  0, 0,  0, 1
        ]);
        let b = SpVec::from(vec![37,23,18,21,5]);
        let expected = SpVec::from(vec![1,2,3,4,5]);

        let solved = solve_triangular_vec(Upper, &u, &b);
        assert_eq!(solved, expected);
    }

    // The computed inverse of an upper-triangular u multiplies back to the identity.
    #[test]
    fn inv_upper() {
        let u = SpMat::from_dense_data((5, 5), [
            1, -2, 1,  3, 5,
            0, -1, 4,  2, 1,
            0,  0, 1,  0, 3,
            0,  0, 0, -1, 5,
            0,  0, 0,  0, 1
        ]);

        let uinv = inv_triangular(Upper, &u);
        let product = &u * &uinv;
        assert!(product.is_id());
    }

    // Solving l * x = b for a lower-triangular l recovers the known x.
    #[test]
    fn solve_lower() {
        let l = SpMat::from_dense_data((5, 5), [
            1,  0, 0,  0, 0,
           -2, -1, 0,  0, 0,
            1,  4, 1,  0, 0,
            3,  2, 0, -1, 0,
            5,  1, 3,  5, 1
        ]);
        let b = SpVec::from(vec![1,-4,12,3,41]);
        let expected = SpVec::from(vec![1,2,3,4,5]);

        let solved = solve_triangular_vec(Lower, &l, &b);
        assert_eq!(solved, expected);
    }

    // The computed inverse of a lower-triangular l multiplies back to the identity.
    #[test]
    fn inv_lower() {
        let l = SpMat::from_dense_data((5, 5), [
            1,  0, 0,  0, 0,
           -2, -1, 0,  0, 0,
            1,  4, 1,  0, 0,
            3,  2, 0, -1, 0,
            5,  1, 3,  5, 1
        ]);

        let linv = inv_triangular(Lower, &l);
        let product = &l * &linv;
        assert!(product.is_id());
    }
}