// yui_matrix/sparse/sp_vec.rs

1use std::ops::{Add, AddAssign, Neg, Sub, SubAssign, Mul, Range};
2use std::fmt::{Display, Debug};
3use nalgebra_sparse::CscMatrix;
4use nalgebra_sparse::na::{Scalar, ClosedAddAssign, ClosedSubAssign, ClosedMulAssign};
5use num_traits::{Zero, One};
6use sprs::PermView;
7use auto_impl_ops::auto_ops;
8use yui_core::{Ring, RingOps, AddGrpOps,  AddGrp};
9use super::sp_mat::SpMat;
10
11#[derive(Clone, Debug, PartialEq, Eq)]
12pub struct SpVec<R> { 
13    inner: CscMatrix<R> // ncols == 1
14}
15
16impl<R> SpVec<R> { 
17    fn new(inner: CscMatrix<R>) -> Self { 
18        assert_eq!(inner.ncols(), 1);
19        Self { inner }
20    }
21
22    #[allow(unused)]
23    pub(crate) fn inner(&self) -> &CscMatrix<R> { 
24        &self.inner
25    }
26
27    pub(crate) fn into_inner(self) -> CscMatrix<R> { 
28        self.inner
29    }
30
31    pub fn data(&self) -> (&[usize], &[R]) { 
32        let (_, indices, values) = self.inner.csc_data();
33        (indices, values)
34    }
35
36    pub fn zero(dim: usize) -> Self {
37        let inner = CscMatrix::zeros(dim, 1);
38        Self::new(inner)
39    }
40
41    pub fn is_zero(&self) -> bool
42    where R: Zero {
43        self.inner.values().iter().all(|a| a.is_zero())
44    }
45
46    pub fn unit(n: usize, i: usize) -> Self
47    where R: One {
48        let inner = CscMatrix::try_from_csc_data(
49            n, 1, 
50            vec![0, 1], 
51            vec![i], 
52            vec![R::one()]
53        ).unwrap();
54
55        Self::new(inner)
56    }
57
58    pub fn dim(&self) -> usize { 
59        self.inner.nrows()
60    }
61
62    pub fn iter(&self) -> impl Iterator<Item = (usize, &R)> { 
63        self.inner.triplet_iter().map(|(i, _, a)| (i, a))
64    }
65
66    pub fn iter_nz(&self) -> impl Iterator<Item = (usize, &R)>
67    where R: Zero { 
68        self.iter().filter(|(_, a)| !a.is_zero())
69    }
70
71    pub fn into_vec(self) -> Vec<R>
72    where R: Clone + Zero { 
73        self.into()
74    }
75
76    pub fn into_mat(self) -> SpMat<R> { 
77        self.into()
78    }
79}
80
81impl<R> From<Vec<R>> for SpVec<R>
82where R: Scalar + Zero + ClosedAddAssign {
83    fn from(vec: Vec<R>) -> Self {
84        Self::from_entries(vec.len(), vec.into_iter().enumerate())
85    }
86}
87
88impl<R> From<SpVec<R>> for Vec<R>
89where R: Clone + Zero {
90    fn from(value: SpVec<R>) -> Self {
91        let mut res = vec![R::zero(); value.dim()];
92        for (i, a) in value.iter_nz() { 
93            res[i] = a.clone();
94        }
95        res
96    }
97}
98
99// SpVec(n) as SpMat(n, 1)
100impl<R> From<SpVec<R>> for SpMat<R> { 
101    fn from(vec: SpVec<R>) -> Self {
102        SpMat::from(vec.into_inner())
103    }
104}
105
106impl<R> SpMat<R> {
107    fn into_spvec(self) -> SpVec<R> { 
108        assert_eq!(self.inner().ncols(), 1);
109        SpVec::new(self.into_inner())
110    }
111}
112
113impl<R> SpVec<R> 
114where R: Scalar + Zero + ClosedAddAssign { 
115    pub fn from_entries<T>(dim: usize, entries: T) -> Self
116    where T: IntoIterator<Item = (usize, R)> {
117        SpMat::from_entries(
118            (dim, 1), 
119            entries.into_iter().map(|(i, a)| (i, 0, a))
120        ).into_spvec()
121    }
122
123    pub fn from_sorted_entries<T>(dim: usize, entries: T) -> Self
124    where T: IntoIterator<Item = (usize, R)> {
125        let init = (vec![], vec![]);
126        let (row_indices, values) = entries.into_iter().fold(init, |mut res, (i, a)| { 
127            assert!(i < dim);
128            res.0.push(i);
129            res.1.push(a);
130            res
131        });
132        Self::from_raw_data(dim, row_indices, values)
133    }
134
135    fn from_raw_data(dim: usize, row_indices: Vec<usize>, values: Vec<R>) -> SpVec<R> { 
136        let col_offsets = vec![0, row_indices.len()];
137        let csc = CscMatrix::try_from_csc_data(dim, 1, col_offsets, row_indices, values).unwrap();
138        SpMat::from(csc).into_spvec()
139    }
140    
141    pub fn stack_vecs<I>(vecs: I) -> Self 
142    where I: IntoIterator<Item = SpVec<R>> { 
143        let init = (0, vec![], vec![]);
144        let (dim, row_indices, values) = vecs.into_iter().fold(init, |mut res, v| { 
145            let n1 = res.0;
146            let n2 = v.dim();
147            
148            let (_, mut rows, mut vals) = v.inner.disassemble();
149            rows.iter_mut().for_each(|i| *i += n1);
150
151            res.0 += n2;
152            res.1.append(&mut rows);
153            res.2.append(&mut vals);
154            res
155        });
156        Self::from_raw_data(dim, row_indices, values)
157    }
158
159    pub fn extract<F>(&self, dim: usize, f: F) -> SpVec<R>
160    where F: Fn(usize) -> Option<usize> { 
161        SpVec::from_entries(dim, self.iter().filter_map(|(i, a)|
162            f(i).map(|i| (i, a.clone()))
163        ))
164    }
165
166    pub fn permute(&self, p: PermView<'_>) -> SpVec<R> { 
167        self.extract(self.dim(), |i| Some(p.at(i)))
168    }
169
170    pub fn subvec(&self, range: Range<usize>) -> SpVec<R> { 
171        self.extract(
172            range.end - range.start, 
173            |i| range.contains(&i).then(|| i - range.start)
174        )
175    }
176
177    pub fn stack(&self, other: &SpVec<R>) -> SpVec<R> {
178        let (n1, n2) = (self.dim(), other.dim());
179        Self::from_entries(n1 + n2, Iterator::chain(
180            self.iter_nz().map(|(i, a)| (i, a.clone())),
181            other.iter_nz().map(|(i, a)| (n1 + i, a.clone()))
182        ))
183    }
184
185    pub fn split(&self, at: usize) -> (SpVec<R>, SpVec<R>) { 
186        let n = self.dim();
187        let k = at;
188        assert!(k <= n);
189
190        let mut e1 = vec![];
191        let mut e2 = vec![];
192
193        for (i, a) in self.iter() { 
194            if i < k { 
195                e1.push((i, a.clone()));
196            } else { 
197                e2.push((i - k, a.clone()));
198            }
199        }
200
201        (SpVec::from_entries(k, e1), SpVec::from_entries(n - k, e2))
202    }
203
204    pub fn to_dense(&self) -> Vec<R> { 
205        let mut vec = vec![R::zero(); self.dim()];
206        for (i, a) in self.iter_nz() { 
207            vec[i] = a.clone();
208        }
209        vec
210    }
211}
212
213impl<R> Default for SpVec<R> {
214    fn default() -> Self {
215        Self::zero(0)
216    }
217}
218
219impl<R> Neg for SpVec<R>
220where R: AddGrp, for<'a> &'a R: AddGrpOps<R> {
221    type Output = Self;
222    fn neg(self) -> Self::Output {
223        SpVec { inner: -self.inner }
224    }
225}
226
227impl<R> Neg for &SpVec<R>
228where R: Scalar + Neg<Output = R> {
229    type Output = SpVec<R>;
230    fn neg(self) -> Self::Output {
231        SpVec { inner: -&self.inner }
232    }
233}
234
235macro_rules! impl_binop {
236    ($trait:ident, $method:ident) => {
237        #[auto_ops]
238        impl<'a, 'b, R> $trait<&'b SpVec<R>> for &'a SpVec<R>
239        where R: Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg<Output = R> {
240            type Output = SpVec<R>;
241            fn $method(self, rhs: &'b SpVec<R>) -> Self::Output {
242                let res = (&self.inner).$method(&rhs.inner);
243                SpVec::new(res)
244            }
245        }
246    };
247}
248
249impl_binop!(Add, add);
250impl_binop!(Sub, sub);
251
252// SpMat * SpVec
253#[auto_ops(val_val, val_ref, ref_val)]
254impl<'a, 'b, R> Mul<&'b SpVec<R>> for &'a SpMat<R>
255where R: Ring, for<'x> &'x R: RingOps<R> {
256    type Output = SpVec<R>;
257    fn mul(self, rhs: &'b SpVec<R>) -> Self::Output {
258        let res = self.inner() * &rhs.inner;
259        SpVec::new(res)
260    }
261}
262
263impl<R> Display for SpVec<R>
264where R: Ring, for<'a> &'a R: RingOps<R> {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        self.inner.fmt(f)
267    }
268}
269
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use sprs::PermOwned;
    use super::*;

    #[test]
    fn from_vec() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        assert_eq!(v.inner.disassemble(), (vec![0, 3], vec![0, 2, 3], vec![1, 3, 5]));
    }

    #[test]
    fn from_entries() {
        let v = SpVec::from_entries(5, vec![(0, 1), (4, 5), (2, 3)]);
        assert_eq!(v.inner.disassemble(), (vec![0, 3], vec![0, 2, 4], vec![1, 3, 5]));
    }

    #[test]
    fn from_sorted_entries() {
        let v = SpVec::from_sorted_entries(5, vec![(1, 2), (3, 4)]);
        assert_eq!(v, SpVec::from(vec![0, 2, 0, 4, 0]));
    }

    #[test]
    fn unit() {
        let e = SpVec::<i32>::unit(5, 2);
        assert_eq!(e, SpVec::from(vec![0, 0, 1, 0, 0]));
    }

    #[test]
    fn zero_is_zero() {
        let z = SpVec::<i32>::zero(3);
        assert_eq!(z.dim(), 3);
        assert!(z.is_zero());
        assert!(!SpVec::from(vec![0, 1, 0]).is_zero());
    }

    #[test]
    fn iter_nz() {
        let v = SpVec::from(vec![1, 0, 3]);
        let entries = v.iter_nz().map(|(i, &a)| (i, a)).collect_vec();
        assert_eq!(entries, vec![(0, 1), (2, 3)]);
    }

    #[test]
    fn to_dense() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        assert_eq!(v.to_dense(), vec![1,0,3,5,0]);
    }

    #[test]
    fn into_vec() {
        let v = SpVec::from(vec![1, 0, 3, 5, 0]);
        assert_eq!(v.into_vec(), vec![1, 0, 3, 5, 0]);
    }

    #[test]
    fn add() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        let w = SpVec::from(vec![2,1,-1,3,2]);
        assert_eq!(v + w, SpVec::from(vec![3,1,2,8,2]));
    }

    #[test]
    fn sub() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        let w = SpVec::from(vec![2,1,-1,3,2]);
        assert_eq!(v - w, SpVec::from(vec![-1,-1,4,2,-2]));
    }

    #[test]
    fn neg() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        assert_eq!(-v, SpVec::from(vec![-1,0,-3,-5,0]));
    }

    #[test]
    fn subvec() {
        let v = SpVec::from((0..10).collect_vec());
        let w = v.subvec(3..7);
        assert_eq!(w, SpVec::from(vec![3,4,5,6]));
    }

    #[test]
    fn subvec2() {
        let v = SpVec::from((0..10).collect_vec());
        let w = v.subvec(1..9);
        let w = w.subvec(1..4);
        assert_eq!(w, SpVec::from(vec![2,3,4]));
    }

    #[test]
    fn permute() {
        let p = PermOwned::new(vec![1,3,0,2]);
        let v = SpVec::from(vec![0,1,2,3]);
        let w = v.permute(p.view());
        assert_eq!(w, SpVec::from(vec![2,0,3,1]));
    }

    #[test]
    fn stack() {
        let v1 = SpVec::from((0..3).collect_vec());
        let v2 = SpVec::from((5..8).collect_vec());
        let w = v1.stack(&v2);
        assert_eq!(w, SpVec::from(vec![0,1,2,5,6,7]));
    }

    #[test]
    fn stack_vecs() {
        let v1 = SpVec::from((0..3).collect_vec());
        let v2 = SpVec::from((5..8).collect_vec());
        let w = SpVec::stack_vecs(vec![v1, v2]);
        assert_eq!(w, SpVec::from(vec![0, 1, 2, 5, 6, 7]));
    }

    #[test]
    fn split() {
        let v = SpVec::from((0..10).collect_vec());
        let (x, y) = v.split(4);
        assert_eq!(x, SpVec::from((0..4).collect_vec()));
        assert_eq!(y, SpVec::from((4..10).collect_vec()));
    }

    #[test]
    fn split_at_ends() {
        let v = SpVec::from((0..5).collect_vec());

        let (x, y) = v.split(0);
        assert_eq!(x.dim(), 0);
        assert_eq!(y, v);

        let (x, y) = v.split(5);
        assert_eq!(x, v);
        assert_eq!(y.dim(), 0);
    }
}
348        let (x, y) = v.split(4);
349        assert_eq!(x, SpVec::from((0..4).collect_vec()));
350        assert_eq!(y, SpVec::from((4..10).collect_vec()));
351    }
352}