sparse_bin_mat/vector/
mod.rs

1use crate::error::{validate_positions, IncompatibleDimensions, InvalidPositions};
2use crate::BinNum;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::ops::{Add, Deref, Mul};
7
8mod bitwise_operations;
9use bitwise_operations::BitwiseZipIter;
10
11/// A sparse binary vector.
12///
13/// There are two main variants of a vector,
14/// the owned one, [`SparseBinVec`](crate::SparseBinVec), and the borrowed one,
15/// [`SparseBinSlice`](crate::SparseBinSlice).
16/// Most of the time, you want to create a owned version.
17/// However, some iterators, such as those defined on [`SparseBinMat`](crate::SparseBinMat)
18/// returns the borrowed version.
19#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
20pub struct SparseBinVecBase<T> {
21    positions: T,
22    length: usize,
23}
24
25pub type SparseBinVec = SparseBinVecBase<Vec<usize>>;
26pub type SparseBinSlice<'a> = SparseBinVecBase<&'a [usize]>;
27
28impl SparseBinVec {
29    /// Creates a vector fill with zeros of the given length.
30    ///
31    /// # Example
32    ///
33    /// ```
34    /// # use sparse_bin_mat::SparseBinVec;
35    /// let vector = SparseBinVec::zeros(3);
36    ///
37    /// assert_eq!(vector.len(), 3);
38    /// assert_eq!(vector.weight(), 0);
39    /// ```
40    pub fn zeros(length: usize) -> Self {
41        Self {
42            length,
43            positions: Vec::new(),
44        }
45    }
46
47    /// Creates an empty vector.
48    ///
49    /// This allocate minimally, so it is a good placeholder.
50    ///
51    /// # Example
52    ///
53    /// ```
54    /// # use sparse_bin_mat::SparseBinVec;
55    /// let vector = SparseBinVec::empty();
56    ///
57    /// assert_eq!(vector.len(), 0);
58    /// assert_eq!(vector.weight(), 0);
59    /// ```
60    pub fn empty() -> Self {
61        Self {
62            length: 0,
63            positions: Vec::new(),
64        }
65    }
66
67    /// Converts the sparse binary vector to a `Vec` of
68    /// the non trivial positions.
69    ///
70    /// # Example
71    ///
72    /// ```
73    /// # use sparse_bin_mat::SparseBinVec;
74    /// let vector = SparseBinVec::new(3, vec![0, 2]);
75    ///
76    /// assert_eq!(vector.to_positions_vec(), vec![0, 2]);
77    /// ```
78    pub fn to_positions_vec(self) -> Vec<usize> {
79        self.positions
80    }
81}
82
83impl<T: Deref<Target = [usize]>> SparseBinVecBase<T> {
84    /// Creates a new vector with the given length and list of non trivial positions.
85    ///
86    /// # Example
87    ///
88    /// ```
89    /// # use sparse_bin_mat::SparseBinVec;
90    /// use sparse_bin_mat::error::InvalidPositions;
91    ///
92    /// let vector = SparseBinVec::new(5, vec![0, 2]);
93    ///
94    /// assert_eq!(vector.len(), 5);
95    /// assert_eq!(vector.weight(), 2);
96    /// ```
97    pub fn new(length: usize, positions: T) -> Self {
98        Self::try_new(length, positions).unwrap()
99    }
100
101    /// Creates a new vector with the given length and list of non trivial positions
102    /// or returns as error if the positions are unsorted, greater or equal to length
103    /// or contain duplicates.
104    ///
105    ///
106    /// # Example
107    ///
108    /// ```
109    /// # use sparse_bin_mat::SparseBinVec;
110    /// use sparse_bin_mat::error::InvalidPositions;
111    ///
112    /// let vector = SparseBinVec::try_new(5, vec![0, 2]);
113    /// assert_eq!(vector, Ok(SparseBinVec::new(5, vec![0, 2])));
114    ///
115    /// let vector = SparseBinVec::try_new(5, vec![2, 0]);
116    /// assert_eq!(vector, Err(InvalidPositions::Unsorted));
117    ///
118    /// let vector = SparseBinVec::try_new(5, vec![0, 10]);
119    /// assert_eq!(vector, Err(InvalidPositions::OutOfBound));
120    ///
121    /// let vector = SparseBinVec::try_new(5, vec![0, 0]);
122    /// assert_eq!(vector, Err(InvalidPositions::Duplicated));
123    /// ```
124    pub fn try_new(length: usize, positions: T) -> Result<Self, InvalidPositions> {
125        validate_positions(length, &positions)?;
126        Ok(Self { positions, length })
127    }
128
129    // Positions should be sorted, in bound and all unique.
130    pub(crate) fn new_unchecked(length: usize, positions: T) -> Self {
131        Self { positions, length }
132    }
133
134    /// Returns the length (number of elements) of the vector.
135    pub fn len(&self) -> usize {
136        self.length
137    }
138
139    /// Returns the number of elements with value 1 in the vector.
140    pub fn weight(&self) -> usize {
141        self.positions.len()
142    }
143
144    /// Returns true if the length of the vector is 0.
145    pub fn is_empty(&self) -> bool {
146        self.len() == 0
147    }
148
149    /// Returns true if all the element in the vector are 0.
150    pub fn is_zero(&self) -> bool {
151        self.weight() == 0
152    }
153
154    /// Returns the value at the given position
155    /// or None if the position is out of bound.
156    ///
157    /// # Example
158    ///
159    /// ```
160    /// # use sparse_bin_mat::SparseBinVec;
161    /// let vector = SparseBinVec::new(3, vec![0, 2]);
162    ///
163    /// assert!(vector.get(0).unwrap().is_one());
164    /// assert!(vector.get(1).unwrap().is_zero());
165    /// assert!(vector.get(2).unwrap().is_one());
166    /// assert!(vector.get(3).is_none());
167    /// ```
168    pub fn get(&self, position: usize) -> Option<BinNum> {
169        if position < self.len() {
170            if self.positions.contains(&position) {
171                Some(1.into())
172            } else {
173                Some(0.into())
174            }
175        } else {
176            None
177        }
178    }
179
180    /// Returns true if the value at the given position is 0
181    /// or None if the position is out of bound.
182    ///
183    /// # Example
184    ///
185    /// ```
186    /// # use sparse_bin_mat::SparseBinVec;
187    /// let vector = SparseBinVec::new(3, vec![0, 2]);
188    ///
189    /// assert_eq!(vector.is_zero_at(0), Some(false));
190    /// assert_eq!(vector.is_zero_at(1), Some(true));
191    /// assert_eq!(vector.is_zero_at(2), Some(false));
192    /// assert_eq!(vector.is_zero_at(3), None);
193    /// ```
194    pub fn is_zero_at(&self, position: usize) -> Option<bool> {
195        self.get(position).map(|value| value == 0.into())
196    }
197
198    /// Returns true if the value at the given position is 1
199    /// or None if the position is out of bound.
200    ///
201    /// # Example
202    ///
203    /// ```
204    /// # use sparse_bin_mat::SparseBinVec;
205    /// let vector = SparseBinVec::new(3, vec![0, 2]);
206    ///
207    /// assert_eq!(vector.is_one_at(0), Some(true));
208    /// assert_eq!(vector.is_one_at(1), Some(false));
209    /// assert_eq!(vector.is_one_at(2), Some(true));
210    /// assert_eq!(vector.is_one_at(3), None);
211    /// ```
212    pub fn is_one_at(&self, position: usize) -> Option<bool> {
213        self.get(position).map(|value| value == 1.into())
214    }
215
216    /// Returns an iterator over all positions where the value is 1.
217    ///
218    /// # Example
219    ///
220    /// ```
221    /// # use sparse_bin_mat::SparseBinVec;
222    /// let vector = SparseBinVec::new(5, vec![0, 1, 3]);
223    /// let mut iter = vector.non_trivial_positions();
224    ///
225    /// assert_eq!(iter.next(), Some(0));
226    /// assert_eq!(iter.next(), Some(1));
227    /// assert_eq!(iter.next(), Some(3));
228    /// assert_eq!(iter.next(), None);
229    /// ```
230    pub fn non_trivial_positions<'a>(&'a self) -> NonTrivialPositions<'a> {
231        NonTrivialPositions {
232            positions: &self.positions,
233            index: 0,
234        }
235    }
236
237    /// Returns an iterator over all value in the vector.
238    ///
239    /// # Example
240    ///
241    /// ```
242    /// # use sparse_bin_mat::{SparseBinVec, BinNum};
243    /// let vector = SparseBinVec::new(4, vec![0, 2]);
244    /// let mut iter = vector.iter_dense();
245    ///
246    /// assert_eq!(iter.next(), Some(BinNum::one()));
247    /// assert_eq!(iter.next(), Some(BinNum::zero()));
248    /// assert_eq!(iter.next(), Some(BinNum::one()));
249    /// assert_eq!(iter.next(), Some(BinNum::zero()));
250    /// assert_eq!(iter.next(), None);
251    /// ```
252    pub fn iter_dense<'a>(&'a self) -> IterDense<'a, T> {
253        IterDense {
254            vec: self,
255            index: 0,
256        }
257    }
258
259    /// Returns the concatenation of two vectors.
260    ///
261    /// # Example
262    ///
263    /// ```
264    /// # use sparse_bin_mat::SparseBinVec;
265    /// let left_vector = SparseBinVec::new(3, vec![0, 1]);
266    /// let right_vector = SparseBinVec::new(4, vec![2, 3]);
267    ///
268    /// let concatened = left_vector.concat(&right_vector);
269    ///
270    /// let expected = SparseBinVec::new(7, vec![0, 1, 5, 6]);
271    ///
272    /// assert_eq!(concatened, expected);
273    /// ```
274    pub fn concat(&self, other: &Self) -> SparseBinVec {
275        let positions = self
276            .non_trivial_positions()
277            .chain(other.non_trivial_positions().map(|pos| pos + self.len()))
278            .collect();
279        SparseBinVec::new_unchecked(self.len() + other.len(), positions)
280    }
281
282    /// Returns a new vector keeping only the given positions or an error
283    /// if the positions are unsorted, out of bound or contain deplicate.
284    ///
285    /// Positions are relabeled to the fit new number of positions.
286    ///
287    /// # Example
288    ///
289    /// ```
290    /// use sparse_bin_mat::SparseBinVec;
291    /// let vector = SparseBinVec::new(5, vec![0, 2, 4]);
292    /// let truncated = SparseBinVec::new(3, vec![0, 2]);
293    ///
294    /// assert_eq!(vector.keep_only_positions(&[0, 1, 2]), Ok(truncated));
295    /// assert_eq!(vector.keep_only_positions(&[1, 2]).map(|vec| vec.len()), Ok(2));
296    /// ```
297    pub fn keep_only_positions(
298        &self,
299        positions: &[usize],
300    ) -> Result<SparseBinVec, InvalidPositions> {
301        validate_positions(self.length, positions)?;
302        let old_to_new_positions_map = positions
303            .iter()
304            .enumerate()
305            .map(|(new, old)| (old, new))
306            .collect::<HashMap<_, _>>();
307        let new_positions = self
308            .non_trivial_positions()
309            .filter_map(|position| old_to_new_positions_map.get(&position).cloned())
310            .collect();
311        Ok(SparseBinVec::new_unchecked(positions.len(), new_positions))
312    }
313
314    /// Returns a truncated vector where the given positions are remove or an error
315    /// if the positions are unsorted or out of bound.
316    ///
317    /// Positions are relabeled to fit the new number of positions.
318    ///
319    /// # Example
320    ///
321    /// ```
322    /// # use sparse_bin_mat::SparseBinVec;
323    /// let vector = SparseBinVec::new(5, vec![0, 2, 4]);
324    /// let truncated = SparseBinVec::new(3, vec![0, 2]);
325    ///
326    /// assert_eq!(vector.without_positions(&[3, 4]), Ok(truncated));
327    /// assert_eq!(vector.without_positions(&[1, 2]).map(|vec| vec.len()), Ok(3));
328    /// ```
329    pub fn without_positions(&self, positions: &[usize]) -> Result<SparseBinVec, InvalidPositions> {
330        let to_keep: Vec<usize> = (0..self.len()).filter(|x| !positions.contains(x)).collect();
331        self.keep_only_positions(&to_keep)
332    }
333
334    /// Returns a view over the vector.
335    pub fn as_view(&self) -> SparseBinSlice {
336        SparseBinSlice {
337            length: self.length,
338            positions: &self.positions,
339        }
340    }
341
342    /// Returns a slice of the non trivial positions.
343    pub fn as_slice(&self) -> &[usize] {
344        self.positions.as_ref()
345    }
346
347    /// Returns an owned version of the vector.
348    pub fn to_vec(self) -> SparseBinVec {
349        SparseBinVec {
350            length: self.length,
351            positions: self.positions.to_owned(),
352        }
353    }
354
355    /// Returns the dot product of two vectors or an
356    /// error if the vectors have different length.
357    ///
358    /// # Example
359    ///
360    /// ```
361    /// # use sparse_bin_mat::SparseBinVec;
362    /// let first = SparseBinVec::new(4, vec![0, 1, 2]);
363    /// let second = SparseBinVec::new(4, vec![1, 2, 3]);
364    /// let third = SparseBinVec::new(4, vec![0, 3]);
365    ///
366    /// assert_eq!(first.dot_with(&second), Ok(0.into()));
367    /// assert_eq!(first.dot_with(&third), Ok((1.into())));
368    /// ```
369    pub fn dot_with<S: Deref<Target = [usize]>>(
370        &self,
371        other: &SparseBinVecBase<S>,
372    ) -> Result<BinNum, IncompatibleDimensions<usize, usize>> {
373        if self.len() != other.len() {
374            return Err(IncompatibleDimensions::new(self.len(), other.len()));
375        }
376        Ok(
377            BitwiseZipIter::new(self.as_view(), other.as_view()).fold(0.into(), |sum, x| {
378                sum + x.first_row_value * x.second_row_value
379            }),
380        )
381    }
382
383    /// Returns the bitwise xor of two vectors or an
384    /// error if the vectors have different length.
385    ///
386    /// Use the Add (+) operator if you want a version
387    /// that panics instead or returning an error.
388    /// # Example
389    ///
390    /// ```
391    /// # use sparse_bin_mat::SparseBinVec;
392    /// let first = SparseBinVec::new(4, vec![0, 1, 2]);
393    /// let second = SparseBinVec::new(4, vec![1, 2, 3]);
394    /// let third = SparseBinVec::new(4, vec![0, 3]);
395    ///
396    /// assert_eq!(first.bitwise_xor_with(&second), Ok(third));
397    /// ```
398    pub fn bitwise_xor_with<S: Deref<Target = [usize]>>(
399        &self,
400        other: &SparseBinVecBase<S>,
401    ) -> Result<SparseBinVec, IncompatibleDimensions<usize, usize>> {
402        if self.len() != other.len() {
403            return Err(IncompatibleDimensions::new(self.len(), other.len()));
404        }
405        let positions = BitwiseZipIter::new(self.as_view(), other.as_view())
406            .filter_map(|x| {
407                if x.first_row_value + x.second_row_value == 1.into() {
408                    Some(x.position)
409                } else {
410                    None
411                }
412            })
413            .collect();
414        Ok(SparseBinVec::new_unchecked(self.len(), positions))
415    }
416
417    /// Returns a json string for the vector.
418    pub fn as_json(&self) -> Result<String, serde_json::Error>
419    where
420        T: Serialize,
421    {
422        serde_json::to_string(self)
423    }
424}
425
426impl<S, T> Add<&SparseBinVecBase<S>> for &SparseBinVecBase<T>
427where
428    S: Deref<Target = [usize]>,
429    T: Deref<Target = [usize]>,
430{
431    type Output = SparseBinVec;
432
433    fn add(self, other: &SparseBinVecBase<S>) -> Self::Output {
434        self.bitwise_xor_with(other).expect(&format!(
435            "vector of length {} can't be added to vector of length {}",
436            self.len(),
437            other.len()
438        ))
439    }
440}
441
442impl<S, T> Mul<&SparseBinVecBase<S>> for &SparseBinVecBase<T>
443where
444    S: Deref<Target = [usize]>,
445    T: Deref<Target = [usize]>,
446{
447    type Output = BinNum;
448
449    fn mul(self, other: &SparseBinVecBase<S>) -> Self::Output {
450        self.dot_with(other).expect(&format!(
451            "vector of length {} can't be dotted to vector of length {}",
452            self.len(),
453            other.len()
454        ))
455    }
456}
457
458impl<T: Deref<Target = [usize]>> fmt::Display for SparseBinVecBase<T> {
459    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
460        write!(f, "{:?}", self.positions.deref())
461    }
462}
463
/// An iterator over all non trivial positions of
/// a sparse binary vector.
///
/// Positions are yielded in the order they are stored in the vector.
#[derive(Debug, Clone)]
pub struct NonTrivialPositions<'vec> {
    // Positions left to visit start at `index`.
    positions: &'vec [usize],
    index: usize,
}

impl<'vec> Iterator for NonTrivialPositions<'vec> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let position = self.positions.get(self.index).copied()?;
        // Only advance on success so that `index` never goes past the end.
        self.index += 1;
        Some(position)
    }

    // The number of remaining positions is known exactly, which lets
    // consumers such as `collect` allocate once.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.positions.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}
482
483/// An iterator over all positions of
484/// a sparse binary vector.
485#[derive(Debug, Clone)]
486pub struct IterDense<'vec, T> {
487    vec: &'vec SparseBinVecBase<T>,
488    index: usize,
489}
490
491impl<'vec, T> Iterator for IterDense<'vec, T>
492where
493    T: Deref<Target = [usize]>,
494{
495    type Item = BinNum;
496
497    fn next(&mut self) -> Option<Self::Item> {
498        let value = self.vec.get(self.index);
499        self.index += 1;
500        value
501    }
502}
503
#[cfg(test)]
mod test {
    use super::*;

    // Adding two vectors keeps the positions set in exactly one of them.
    #[test]
    fn addition() {
        let left = SparseBinVec::new(6, vec![0, 2, 4]);
        let right = SparseBinVec::new(6, vec![0, 1, 2]);
        let expected = SparseBinVec::new(6, vec![1, 4]);
        assert_eq!(&left + &right, expected);
    }

    // The + operator panics when the lengths differ.
    #[test]
    fn panics_on_addition_if_different_length() {
        let long = SparseBinVec::new(6, vec![0, 2, 4]);
        let short = SparseBinVec::new(2, vec![0]);

        let outcome = std::panic::catch_unwind(|| &long + &short);
        assert!(outcome.is_err());
    }

    // Pins the json layout: positions first, then length.
    #[test]
    fn ser_de() {
        let vector = SparseBinVec::new(10, vec![0, 5, 7, 8]);
        let json = serde_json::to_string(&vector).unwrap();
        assert_eq!(
            json,
            String::from("{\"positions\":[0,5,7,8],\"length\":10}")
        );
    }
}