// rustlearn/array/traits.rs

//! Basic traits applying to all types of matrices.
2
3use std::ops::{Range, RangeFrom, RangeFull, RangeTo};
4
5#[derive(Serialize, Deserialize, Clone, Debug)]
6pub enum MatrixOrder {
7    RowMajor,
8    ColumnMajor,
9}
10
/// Trait representing a shaped matrix whose entries can be accessed
/// at will using their row and column position.
///
/// Implementors supply the shape (`rows`, `cols`) and the two unchecked
/// accessors. The bounds-checked `get`, `get_mut` and `set`, as well as
/// `set_unchecked`, have default implementations built on top of them.
pub trait IndexableMatrix {
    /// Return the number of rows of the matrix.
    fn rows(&self) -> usize;

    /// Return the number of columns of the matrix.
    fn cols(&self) -> usize;

    /// Get the value of the entry at (`row`, `column`) without bounds checking.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `row < self.rows()` and
    /// `column < self.cols()`. Implementations are allowed to skip bounds
    /// checking entirely, so violating this contract may be undefined
    /// behaviour.
    unsafe fn get_unchecked(&self, row: usize, column: usize) -> f32;

    /// Get a mutable reference to the value of the entry at (`row`, `column`)
    /// without bounds checking.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `row < self.rows()` and
    /// `column < self.cols()`. Implementations are allowed to skip bounds
    /// checking entirely, so violating this contract may be undefined
    /// behaviour.
    unsafe fn get_unchecked_mut(&mut self, row: usize, column: usize) -> &mut f32;

    /// Get the value of the entry at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()` or `column >= self.cols()`.
    fn get(&self, row: usize, column: usize) -> f32 {
        assert!(
            row < self.rows(),
            "row index {} out of bounds (rows: {})",
            row,
            self.rows()
        );
        assert!(
            column < self.cols(),
            "column index {} out of bounds (cols: {})",
            column,
            self.cols()
        );

        // SAFETY: both indices were bounds-checked just above.
        unsafe { self.get_unchecked(row, column) }
    }

    /// Get a mutable reference to the value of the entry at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()` or `column >= self.cols()`.
    fn get_mut(&mut self, row: usize, column: usize) -> &mut f32 {
        assert!(
            row < self.rows(),
            "row index {} out of bounds (rows: {})",
            row,
            self.rows()
        );
        assert!(
            column < self.cols(),
            "column index {} out of bounds (cols: {})",
            column,
            self.cols()
        );

        // SAFETY: both indices were bounds-checked just above.
        unsafe { self.get_unchecked_mut(row, column) }
    }

    /// Set the value of the entry at (`row`, `column`) to `value`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()` or `column >= self.cols()`.
    fn set(&mut self, row: usize, column: usize, value: f32) {
        assert!(
            row < self.rows(),
            "row index {} out of bounds (rows: {})",
            row,
            self.rows()
        );
        assert!(
            column < self.cols(),
            "column index {} out of bounds (cols: {})",
            column,
            self.cols()
        );

        // SAFETY: both indices were bounds-checked just above.
        unsafe {
            self.set_unchecked(row, column, value);
        }
    }

    /// Set the value of the entry at (`row`, `column`) to `value` without
    /// bounds checking.
    ///
    /// The default implementation writes through `get_unchecked_mut`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `row < self.rows()` and
    /// `column < self.cols()`.
    unsafe fn set_unchecked(&mut self, row: usize, column: usize, value: f32) {
        // SAFETY: the caller upholds the bounds contract that
        // `get_unchecked_mut` requires (see `# Safety` above).
        unsafe {
            *self.get_unchecked_mut(row, column) = value;
        }
    }
}
67
68/// Trait representing a matrix that can be iterated over in
69/// a row-wise fashion.
70pub trait RowIterable {
71    type Item: NonzeroIterable;
72    type Output: Iterator<Item = Self::Item>;
73    /// Iterate over rows of the matrix.
74    fn iter_rows(self) -> Self::Output;
75    /// Iterate over a subset of rows of the matrix.
76    fn iter_rows_range(self, range: Range<usize>) -> Self::Output;
77    /// View a row of the matrix.
78    fn view_row(self, idx: usize) -> Self::Item;
79}
80
81/// Trait representing a matrix that can be iterated over in
82/// a column-wise fashion.
83pub trait ColumnIterable {
84    type Item: NonzeroIterable;
85    type Output: Iterator<Item = Self::Item>;
86    /// Iterate over columns of a the matrix.
87    fn iter_columns(self) -> Self::Output;
88    /// Iterate over a subset of columns of the matrix.
89    fn iter_columns_range(self, range: Range<usize>) -> Self::Output;
90    /// View a column of the matrix.
91    fn view_column(self, idx: usize) -> Self::Item;
92}
93
/// Trait representing an object whose non-zero entries can be
/// iterated over.
pub trait NonzeroIterable {
    /// Iterator yielding one `(usize, f32)` pair per non-zero entry
    /// (presumably the entry's position and its value).
    type Output: Iterator<Item = (usize, f32)>;

    /// Return an iterator over the non-zero entries.
    fn iter_nonzero(&self) -> Self::Output;
}
100
/// Trait representing a matrix whose rows can be selected to create a
/// new matrix containing those rows.
///
/// `Rhs` is the type describing which rows to select.
pub trait RowIndex<Rhs> {
    /// Type produced by a row selection.
    type Output;

    /// Select the rows described by `index`.
    fn get_rows(&self, index: &Rhs) -> Self::Output;
}
107
108impl<T> RowIndex<usize> for T
109where
110    T: RowIndex<Vec<usize>>,
111{
112    type Output = T::Output;
113    fn get_rows(&self, index: &usize) -> Self::Output {
114        self.get_rows(&vec![*index])
115    }
116}
117
118impl<T> RowIndex<Range<usize>> for T
119where
120    T: RowIndex<Vec<usize>>,
121{
122    type Output = T::Output;
123    fn get_rows(&self, index: &Range<usize>) -> Self::Output {
124        self.get_rows(&(index.start..index.end).collect::<Vec<usize>>())
125    }
126}
127
128impl<T> RowIndex<RangeFrom<usize>> for T
129where
130    T: RowIndex<Range<usize>> + IndexableMatrix,
131{
132    type Output = T::Output;
133    fn get_rows(&self, index: &RangeFrom<usize>) -> Self::Output {
134        self.get_rows(&(index.start..self.rows()))
135    }
136}
137
138impl<T> RowIndex<RangeTo<usize>> for T
139where
140    T: RowIndex<Range<usize>> + IndexableMatrix,
141{
142    type Output = T::Output;
143    fn get_rows(&self, index: &RangeTo<usize>) -> Self::Output {
144        self.get_rows(&(0..index.end))
145    }
146}
147
148impl<T> RowIndex<RangeFull> for T
149where
150    T: RowIndex<Range<usize>> + IndexableMatrix,
151{
152    type Output = T::Output;
153    fn get_rows(&self, _: &RangeFull) -> Self::Output {
154        self.get_rows(&(0..self.rows()))
155    }
156}
157
/// Elementwise array operations trait.
///
/// Each operation comes in two flavours: one that takes `&self` and
/// returns the result as `Self::Output`, and an `_inplace` variant that
/// takes `&mut self`. The meaning of `Rhs` (e.g. a scalar or another
/// array) is left to each implementation.
pub trait ElementwiseArrayOps<Rhs> {
    /// Type returned by the non-in-place operations.
    type Output;

    /// Elementwise addition of `rhs` to `self`, returning the result.
    fn add(&self, rhs: Rhs) -> Self::Output;
    /// Elementwise addition of `rhs` to `self`, modifying `self`.
    fn add_inplace(&mut self, rhs: Rhs);

    /// Elementwise subtraction of `rhs` from `self`, returning the result.
    fn sub(&self, rhs: Rhs) -> Self::Output;
    /// Elementwise subtraction of `rhs` from `self`, modifying `self`.
    fn sub_inplace(&mut self, rhs: Rhs);

    /// Elementwise multiplication of `self` by `rhs`, returning the result.
    fn times(&self, rhs: Rhs) -> Self::Output;
    /// Elementwise multiplication of `self` by `rhs`, modifying `self`.
    fn times_inplace(&mut self, rhs: Rhs);

    /// Elementwise division of `self` by `rhs`, returning the result.
    fn div(&self, rhs: Rhs) -> Self::Output;
    /// Elementwise division of `self` by `rhs`, modifying `self`.
    fn div_inplace(&mut self, rhs: Rhs);
}
170
/// A matrix multiplication trait.
///
/// `Rhs` is the right-hand operand; the product type is `Self::Output`.
pub trait Dot<Rhs> {
    /// Type of the product.
    type Output;

    /// Multiply `self` by `rhs`.
    fn dot(&self, rhs: Rhs) -> Self::Output;
}