1use std::ops::{Add, AddAssign, Neg, Sub, SubAssign, Mul, Range};
2use std::fmt::{Display, Debug};
3use nalgebra_sparse::CscMatrix;
4use nalgebra_sparse::na::{Scalar, ClosedAddAssign, ClosedSubAssign, ClosedMulAssign};
5use num_traits::{Zero, One};
6use sprs::PermView;
7use auto_impl_ops::auto_ops;
8use yui_core::{Ring, RingOps, AddGrpOps, AddGrp};
9use super::sp_mat::SpMat;
10
11#[derive(Clone, Debug, PartialEq, Eq)]
12pub struct SpVec<R> {
13 inner: CscMatrix<R> }
15
16impl<R> SpVec<R> {
17 fn new(inner: CscMatrix<R>) -> Self {
18 assert_eq!(inner.ncols(), 1);
19 Self { inner }
20 }
21
22 #[allow(unused)]
23 pub(crate) fn inner(&self) -> &CscMatrix<R> {
24 &self.inner
25 }
26
27 pub(crate) fn into_inner(self) -> CscMatrix<R> {
28 self.inner
29 }
30
31 pub fn data(&self) -> (&[usize], &[R]) {
32 let (_, indices, values) = self.inner.csc_data();
33 (indices, values)
34 }
35
36 pub fn zero(dim: usize) -> Self {
37 let inner = CscMatrix::zeros(dim, 1);
38 Self::new(inner)
39 }
40
41 pub fn is_zero(&self) -> bool
42 where R: Zero {
43 self.inner.values().iter().all(|a| a.is_zero())
44 }
45
46 pub fn unit(n: usize, i: usize) -> Self
47 where R: One {
48 let inner = CscMatrix::try_from_csc_data(
49 n, 1,
50 vec![0, 1],
51 vec![i],
52 vec![R::one()]
53 ).unwrap();
54
55 Self::new(inner)
56 }
57
58 pub fn dim(&self) -> usize {
59 self.inner.nrows()
60 }
61
62 pub fn iter(&self) -> impl Iterator<Item = (usize, &R)> {
63 self.inner.triplet_iter().map(|(i, _, a)| (i, a))
64 }
65
66 pub fn iter_nz(&self) -> impl Iterator<Item = (usize, &R)>
67 where R: Zero {
68 self.iter().filter(|(_, a)| !a.is_zero())
69 }
70
71 pub fn into_vec(self) -> Vec<R>
72 where R: Clone + Zero {
73 self.into()
74 }
75
76 pub fn into_mat(self) -> SpMat<R> {
77 self.into()
78 }
79}
80
81impl<R> From<Vec<R>> for SpVec<R>
82where R: Scalar + Zero + ClosedAddAssign {
83 fn from(vec: Vec<R>) -> Self {
84 Self::from_entries(vec.len(), vec.into_iter().enumerate())
85 }
86}
87
88impl<R> From<SpVec<R>> for Vec<R>
89where R: Clone + Zero {
90 fn from(value: SpVec<R>) -> Self {
91 let mut res = vec![R::zero(); value.dim()];
92 for (i, a) in value.iter_nz() {
93 res[i] = a.clone();
94 }
95 res
96 }
97}
98
99impl<R> From<SpVec<R>> for SpMat<R> {
101 fn from(vec: SpVec<R>) -> Self {
102 SpMat::from(vec.into_inner())
103 }
104}
105
106impl<R> SpMat<R> {
107 fn into_spvec(self) -> SpVec<R> {
108 assert_eq!(self.inner().ncols(), 1);
109 SpVec::new(self.into_inner())
110 }
111}
112
113impl<R> SpVec<R>
114where R: Scalar + Zero + ClosedAddAssign {
115 pub fn from_entries<T>(dim: usize, entries: T) -> Self
116 where T: IntoIterator<Item = (usize, R)> {
117 SpMat::from_entries(
118 (dim, 1),
119 entries.into_iter().map(|(i, a)| (i, 0, a))
120 ).into_spvec()
121 }
122
123 pub fn from_sorted_entries<T>(dim: usize, entries: T) -> Self
124 where T: IntoIterator<Item = (usize, R)> {
125 let init = (vec![], vec![]);
126 let (row_indices, values) = entries.into_iter().fold(init, |mut res, (i, a)| {
127 assert!(i < dim);
128 res.0.push(i);
129 res.1.push(a);
130 res
131 });
132 Self::from_raw_data(dim, row_indices, values)
133 }
134
135 fn from_raw_data(dim: usize, row_indices: Vec<usize>, values: Vec<R>) -> SpVec<R> {
136 let col_offsets = vec![0, row_indices.len()];
137 let csc = CscMatrix::try_from_csc_data(dim, 1, col_offsets, row_indices, values).unwrap();
138 SpMat::from(csc).into_spvec()
139 }
140
141 pub fn stack_vecs<I>(vecs: I) -> Self
142 where I: IntoIterator<Item = SpVec<R>> {
143 let init = (0, vec![], vec![]);
144 let (dim, row_indices, values) = vecs.into_iter().fold(init, |mut res, v| {
145 let n1 = res.0;
146 let n2 = v.dim();
147
148 let (_, mut rows, mut vals) = v.inner.disassemble();
149 rows.iter_mut().for_each(|i| *i += n1);
150
151 res.0 += n2;
152 res.1.append(&mut rows);
153 res.2.append(&mut vals);
154 res
155 });
156 Self::from_raw_data(dim, row_indices, values)
157 }
158
159 pub fn extract<F>(&self, dim: usize, f: F) -> SpVec<R>
160 where F: Fn(usize) -> Option<usize> {
161 SpVec::from_entries(dim, self.iter().filter_map(|(i, a)|
162 f(i).map(|i| (i, a.clone()))
163 ))
164 }
165
166 pub fn permute(&self, p: PermView<'_>) -> SpVec<R> {
167 self.extract(self.dim(), |i| Some(p.at(i)))
168 }
169
170 pub fn subvec(&self, range: Range<usize>) -> SpVec<R> {
171 self.extract(
172 range.end - range.start,
173 |i| range.contains(&i).then(|| i - range.start)
174 )
175 }
176
177 pub fn stack(&self, other: &SpVec<R>) -> SpVec<R> {
178 let (n1, n2) = (self.dim(), other.dim());
179 Self::from_entries(n1 + n2, Iterator::chain(
180 self.iter_nz().map(|(i, a)| (i, a.clone())),
181 other.iter_nz().map(|(i, a)| (n1 + i, a.clone()))
182 ))
183 }
184
185 pub fn split(&self, at: usize) -> (SpVec<R>, SpVec<R>) {
186 let n = self.dim();
187 let k = at;
188 assert!(k <= n);
189
190 let mut e1 = vec![];
191 let mut e2 = vec![];
192
193 for (i, a) in self.iter() {
194 if i < k {
195 e1.push((i, a.clone()));
196 } else {
197 e2.push((i - k, a.clone()));
198 }
199 }
200
201 (SpVec::from_entries(k, e1), SpVec::from_entries(n - k, e2))
202 }
203
204 pub fn to_dense(&self) -> Vec<R> {
205 let mut vec = vec![R::zero(); self.dim()];
206 for (i, a) in self.iter_nz() {
207 vec[i] = a.clone();
208 }
209 vec
210 }
211}
212
213impl<R> Default for SpVec<R> {
214 fn default() -> Self {
215 Self::zero(0)
216 }
217}
218
219impl<R> Neg for SpVec<R>
220where R: AddGrp, for<'a> &'a R: AddGrpOps<R> {
221 type Output = Self;
222 fn neg(self) -> Self::Output {
223 SpVec { inner: -self.inner }
224 }
225}
226
227impl<R> Neg for &SpVec<R>
228where R: Scalar + Neg<Output = R> {
229 type Output = SpVec<R>;
230 fn neg(self) -> Self::Output {
231 SpVec { inner: -&self.inner }
232 }
233}
234
235macro_rules! impl_binop {
236 ($trait:ident, $method:ident) => {
237 #[auto_ops]
238 impl<'a, 'b, R> $trait<&'b SpVec<R>> for &'a SpVec<R>
239 where R: Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg<Output = R> {
240 type Output = SpVec<R>;
241 fn $method(self, rhs: &'b SpVec<R>) -> Self::Output {
242 let res = (&self.inner).$method(&rhs.inner);
243 SpVec::new(res)
244 }
245 }
246 };
247}
248
249impl_binop!(Add, add);
250impl_binop!(Sub, sub);
251
252#[auto_ops(val_val, val_ref, ref_val)]
254impl<'a, 'b, R> Mul<&'b SpVec<R>> for &'a SpMat<R>
255where R: Ring, for<'x> &'x R: RingOps<R> {
256 type Output = SpVec<R>;
257 fn mul(self, rhs: &'b SpVec<R>) -> Self::Output {
258 let res = self.inner() * &rhs.inner;
259 SpVec::new(res)
260 }
261}
262
// Formats the vector by delegating to the wrapped CSC matrix.
//
// NOTE(review): both `Display` and `Debug` are imported, so `self.inner.fmt(f)` resolves to
// whichever of the two `CscMatrix<R>` implements (it cannot be both, or this would be an
// ambiguity error). If, as it appears, `CscMatrix` only derives `Debug`, this `Display`
// prints the internal debug structure rather than a user-facing rendering — confirm against
// nalgebra_sparse. The `Ring` bounds look stricter than the formatting itself needs.
impl<R> Display for SpVec<R>
where R: Ring, for<'a> &'a R: RingOps<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.inner.fmt(f)
    }
}
269
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use sprs::PermOwned;
    use super::*;

    #[test]
    fn from_vec() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        assert_eq!(v.inner.disassemble(), (vec![0, 3], vec![0, 2, 3], vec![1, 3, 5]));
    }

    #[test]
    fn from_entries() {
        let v = SpVec::from_entries(5, vec![(0, 1), (4, 5), (2, 3)]);
        assert_eq!(v.inner.disassemble(), (vec![0, 3], vec![0, 2, 4], vec![1, 3, 5]));
    }

    // Entries already in index order are stored as given.
    #[test]
    fn from_sorted_entries() {
        let v = SpVec::from_sorted_entries(4, vec![(1, 5), (3, 7)]);
        assert_eq!(v, SpVec::from(vec![0, 5, 0, 7]));
    }

    #[test]
    fn unit_vec() {
        let e = SpVec::<i32>::unit(5, 2);
        assert_eq!(e, SpVec::from(vec![0, 0, 1, 0, 0]));
    }

    #[test]
    fn is_zero() {
        assert!(SpVec::<i32>::zero(3).is_zero());
        assert!(!SpVec::from(vec![0, 2, 0]).is_zero());
    }

    #[test]
    fn to_dense() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        assert_eq!(v.to_dense(), vec![1,0,3,5,0]);
    }

    #[test]
    fn into_vec() {
        let v = SpVec::from(vec![1, 0, 3, 5, 0]);
        assert_eq!(v.into_vec(), vec![1, 0, 3, 5, 0]);
    }

    #[test]
    fn add() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        let w = SpVec::from(vec![2,1,-1,3,2]);
        assert_eq!(v + w, SpVec::from(vec![3,1,2,8,2]));
    }

    #[test]
    fn sub() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        let w = SpVec::from(vec![2,1,-1,3,2]);
        assert_eq!(v - w, SpVec::from(vec![-1,-1,4,2,-2]));
    }

    #[test]
    fn neg() {
        let v = SpVec::from(vec![1,0,3,5,0]);
        assert_eq!(-v, SpVec::from(vec![-1,0,-3,-5,0]));
    }

    // Odd indices are kept and halved; even indices are dropped.
    #[test]
    fn extract() {
        let v = SpVec::from(vec![1, 2, 3, 4]);
        let w = v.extract(2, |i| (i % 2 == 1).then(|| i / 2));
        assert_eq!(w, SpVec::from(vec![2, 4]));
    }

    #[test]
    fn subvec() {
        let v = SpVec::from((0..10).collect_vec());
        let w = v.subvec(3..7);
        assert_eq!(w, SpVec::from(vec![3,4,5,6]))
    }

    #[test]
    fn subvec2() {
        let v = SpVec::from((0..10).collect_vec());
        let w = v.subvec(1..9);
        let w = w.subvec(1..4);
        assert_eq!(w, SpVec::from(vec![2,3,4]))
    }

    #[test]
    fn permute() {
        let p = PermOwned::new(vec![1,3,0,2]);
        let v = SpVec::from(vec![0,1,2,3]);
        let w = v.permute(p.view());
        assert_eq!(w, SpVec::from(vec![2,0,3,1]));
    }

    #[test]
    fn stack() {
        let v1 = SpVec::from((0..3).collect_vec());
        let v2 = SpVec::from((5..8).collect_vec());
        let w = v1.stack(&v2);
        assert_eq!(w, SpVec::from(vec![0,1,2,5,6,7]));
    }

    // Row indices of later vectors are shifted by the dimensions of earlier ones.
    #[test]
    fn stack_vecs() {
        let v1 = SpVec::from(vec![1, 0, 2]);
        let v2 = SpVec::from(vec![0, 3]);
        let w = SpVec::stack_vecs(vec![v1, v2]);
        assert_eq!(w, SpVec::from(vec![1, 0, 2, 0, 3]));
    }

    #[test]
    fn split() {
        let v = SpVec::from((0..10).collect_vec());
        let (x, y) = v.split(4);
        assert_eq!(x, SpVec::from((0..4).collect_vec()));
        assert_eq!(y, SpVec::from((4..10).collect_vec()));
    }
}