rustlearn/array/
traits.rs1use std::ops::{Range, RangeFrom, RangeFull, RangeTo};
4
5#[derive(Serialize, Deserialize, Clone, Debug)]
6pub enum MatrixOrder {
7 RowMajor,
8 ColumnMajor,
9}
10
/// Element-level access to a two-dimensional matrix of `f32` values.
///
/// Implementors provide the dimensions ([`rows`](Self::rows), [`cols`](Self::cols))
/// and the two unchecked accessors. The bounds-checked accessors
/// ([`get`](Self::get), [`get_mut`](Self::get_mut), [`set`](Self::set)) and
/// [`set_unchecked`](Self::set_unchecked) have default implementations built on them.
pub trait IndexableMatrix {
    /// Returns the number of rows.
    fn rows(&self) -> usize;

    /// Returns the number of columns.
    fn cols(&self) -> usize;

    /// Returns the element at (`row`, `column`) without checking bounds.
    ///
    /// # Safety
    ///
    /// The caller must ensure `row < self.rows()` and `column < self.cols()`.
    /// Implementations are allowed to skip bounds checks, so violating this
    /// can be undefined behaviour.
    unsafe fn get_unchecked(&self, row: usize, column: usize) -> f32;

    /// Returns a mutable reference to the element at (`row`, `column`)
    /// without checking bounds.
    ///
    /// # Safety
    ///
    /// Same contract as [`get_unchecked`](Self::get_unchecked):
    /// `row < self.rows()` and `column < self.cols()` must hold.
    unsafe fn get_unchecked_mut(&mut self, row: usize, column: usize) -> &mut f32;

    /// Returns the element at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()` or `column >= self.cols()`.
    fn get(&self, row: usize, column: usize) -> f32 {
        assert!(row < self.rows());
        assert!(column < self.cols());

        // SAFETY: both indices were bounds-checked by the asserts above.
        unsafe { self.get_unchecked(row, column) }
    }

    /// Returns a mutable reference to the element at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()` or `column >= self.cols()`.
    fn get_mut(&mut self, row: usize, column: usize) -> &mut f32 {
        assert!(row < self.rows());
        assert!(column < self.cols());

        // SAFETY: both indices were bounds-checked by the asserts above.
        unsafe { self.get_unchecked_mut(row, column) }
    }

    /// Writes `value` to (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()` or `column >= self.cols()`.
    fn set(&mut self, row: usize, column: usize, value: f32) {
        assert!(row < self.rows());
        assert!(column < self.cols());

        // SAFETY: both indices were bounds-checked by the asserts above.
        unsafe {
            self.set_unchecked(row, column, value);
        }
    }

    /// Writes `value` to (`row`, `column`) without checking bounds.
    ///
    /// The default implementation stores through
    /// [`get_unchecked_mut`](Self::get_unchecked_mut).
    ///
    /// # Safety
    ///
    /// The caller must ensure `row < self.rows()` and `column < self.cols()`.
    unsafe fn set_unchecked(&mut self, row: usize, column: usize, value: f32) {
        // SAFETY: the caller upholds this function's bounds contract, which is
        // exactly the contract required by `get_unchecked_mut`.
        unsafe {
            *self.get_unchecked_mut(row, column) = value;
        }
    }
}
67
68pub trait RowIterable {
71 type Item: NonzeroIterable;
72 type Output: Iterator<Item = Self::Item>;
73 fn iter_rows(self) -> Self::Output;
75 fn iter_rows_range(self, range: Range<usize>) -> Self::Output;
77 fn view_row(self, idx: usize) -> Self::Item;
79}
80
81pub trait ColumnIterable {
84 type Item: NonzeroIterable;
85 type Output: Iterator<Item = Self::Item>;
86 fn iter_columns(self) -> Self::Output;
88 fn iter_columns_range(self, range: Range<usize>) -> Self::Output;
90 fn view_column(self, idx: usize) -> Self::Item;
92}
93
/// Iteration over the non-zero entries of a one-dimensional view.
///
/// The row and column views produced by `RowIterable` and `ColumnIterable`
/// are required to implement this trait.
pub trait NonzeroIterable {
    /// Returns an iterator over the non-zero entries.
    ///
    /// Each item is a `(usize, f32)` pair. By naming, these are the entry's
    /// position within the view and its value.
    fn iter_nonzero(&self) -> Self::Output;

    /// Iterator type returned by [`iter_nonzero`](Self::iter_nonzero).
    type Output: Iterator<Item = (usize, f32)>;
}
100
/// Selection of rows by an index of type `Rhs`.
///
/// Implementing `RowIndex<Vec<usize>>` yields the `usize` and `Range<usize>`
/// forms through the blanket implementations in this module. The `RangeFrom`,
/// `RangeTo` and `RangeFull` forms additionally require `IndexableMatrix`.
pub trait RowIndex<Rhs> {
    /// Type produced by a row selection.
    type Output;

    /// Returns the rows selected by `index`.
    fn get_rows(&self, index: &Rhs) -> Self::Output;
}
107
108impl<T> RowIndex<usize> for T
109where
110 T: RowIndex<Vec<usize>>,
111{
112 type Output = T::Output;
113 fn get_rows(&self, index: &usize) -> Self::Output {
114 self.get_rows(&vec![*index])
115 }
116}
117
118impl<T> RowIndex<Range<usize>> for T
119where
120 T: RowIndex<Vec<usize>>,
121{
122 type Output = T::Output;
123 fn get_rows(&self, index: &Range<usize>) -> Self::Output {
124 self.get_rows(&(index.start..index.end).collect::<Vec<usize>>())
125 }
126}
127
128impl<T> RowIndex<RangeFrom<usize>> for T
129where
130 T: RowIndex<Range<usize>> + IndexableMatrix,
131{
132 type Output = T::Output;
133 fn get_rows(&self, index: &RangeFrom<usize>) -> Self::Output {
134 self.get_rows(&(index.start..self.rows()))
135 }
136}
137
138impl<T> RowIndex<RangeTo<usize>> for T
139where
140 T: RowIndex<Range<usize>> + IndexableMatrix,
141{
142 type Output = T::Output;
143 fn get_rows(&self, index: &RangeTo<usize>) -> Self::Output {
144 self.get_rows(&(0..index.end))
145 }
146}
147
148impl<T> RowIndex<RangeFull> for T
149where
150 T: RowIndex<Range<usize>> + IndexableMatrix,
151{
152 type Output = T::Output;
153 fn get_rows(&self, _: &RangeFull) -> Self::Output {
154 self.get_rows(&(0..self.rows()))
155 }
156}
157
/// Element-wise arithmetic between `self` and a right-hand operand of type `Rhs`.
///
/// Each operation comes in two forms:
/// - one that borrows `self` and returns a `Self::Output`;
/// - an `_inplace` variant that takes `&mut self` and returns nothing
///   (by naming, it updates `self` in place).
pub trait ElementwiseArrayOps<Rhs> {
    /// Result type of the non-mutating operations.
    type Output;

    /// Element-wise addition, returning the result.
    fn add(&self, rhs: Rhs) -> Self::Output;
    /// Element-wise subtraction, returning the result.
    fn sub(&self, rhs: Rhs) -> Self::Output;
    /// Element-wise multiplication, returning the result.
    fn times(&self, rhs: Rhs) -> Self::Output;
    /// Element-wise division, returning the result.
    fn div(&self, rhs: Rhs) -> Self::Output;

    /// In-place counterpart of [`add`](Self::add).
    fn add_inplace(&mut self, rhs: Rhs);
    /// In-place counterpart of [`sub`](Self::sub).
    fn sub_inplace(&mut self, rhs: Rhs);
    /// In-place counterpart of [`times`](Self::times).
    fn times_inplace(&mut self, rhs: Rhs);
    /// In-place counterpart of [`div`](Self::div).
    fn div_inplace(&mut self, rhs: Rhs);
}
170
/// Dot product of `self` with a right-hand operand of type `Rhs`.
///
/// The operand is taken by value, so borrowed operands are expressed by
/// implementing the trait for a reference type (e.g. `Dot<&'a M>`).
pub trait Dot<Rhs> {
    /// Result type of the product.
    type Output;

    /// Computes the dot product of `self` and `rhs`.
    fn dot(&self, rhs: Rhs) -> Self::Output;
}