rust_nn/matrix/
mod.rs

1use rand::Rng;
2
3use crate::prelude::*;
4use std::{
5    fmt::Display,
6    ops::{Index, IndexMut},
7};
8
9pub mod ops;
10
/// A dense two-dimensional matrix whose elements are stored in row-major order.
///
/// Element `(row, col)` lives at `data[row * cols + col]`.
#[derive(Debug, PartialEq, Clone)]
pub struct Matrix2<T> {
    /// Elements laid out row after row; constructors keep this at `rows * cols` entries.
    data: Vec<T>,
    /// Shape of the matrix as `(rows, cols)`.
    dim: (usize, usize),
}
16
17impl<T: Clone> Matrix2<T> {
18    pub fn clone_row_to_vec(&self, row: usize) -> Vec<T> {
19        (0..self.cols())
20            .map(|col| self[(row, col)].clone())
21            .collect()
22    }
23
24    pub fn clone_row(&self, row: usize) -> Matrix2<T> {
25        Matrix2::from_row(
26            (0..self.cols())
27                .map(|col| self[(row, col)].clone())
28                .collect(),
29        )
30    }
31}
32
33impl<T: Default + Clone> Matrix2<T> {
34    pub fn new(rows: usize, cols: usize) -> Self {
35        Self {
36            data: vec![T::default(); rows * cols],
37            dim: (rows, cols),
38        }
39    }
40
41    pub fn zero(&mut self) {
42        for row in &mut self.data {
43            *row = T::default();
44        }
45    }
46}
47
48impl<T> Matrix2<T> {
49    pub fn from_array<const R: usize, const C: usize>(arr: [[T; C]; R]) -> Self {
50        let mut data = Vec::with_capacity(R * C);
51
52        for row in arr {
53            for x in row {
54                data.push(x);
55            }
56        }
57
58        Self { data, dim: (R, C) }
59    }
60
61    pub fn concat_rows(&mut self, mut other: Matrix2<T>) -> Result<()> {
62        if self.cols() != other.cols() {
63            return Err(Error::DimensionErr);
64        }
65
66        self.data.append(&mut other.data);
67        self.dim.0 += other.rows();
68        Ok(())
69    }
70
71    pub fn dim(&self) -> (usize, usize) {
72        self.dim
73    }
74
75    pub fn rows(&self) -> usize {
76        self.dim.0
77    }
78
79    pub fn cols(&self) -> usize {
80        self.dim.1
81    }
82
83    pub fn row_as_vec(&self, row: usize) -> Vec<&T> {
84        (0..self.cols()).map(|col| &self[(row, col)]).collect()
85    }
86
87    pub fn from_row(row_vec: Vec<T>) -> Self {
88        Self {
89            dim: (1, row_vec.len()),
90            data: row_vec,
91        }
92    }
93    pub fn from_vec(vec: Vec<Vec<T>>) -> Result<Self> {
94        let rows = vec.len();
95        let cols = vec.get(0).map(|row| row.len()).unwrap_or(0);
96
97        let mut data = Vec::new();
98        for row in vec {
99            if cols != row.len() {
100                return Err(Error::DimensionErr);
101            }
102
103            for x in row {
104                data.push(x);
105            }
106        }
107
108        Ok(Self {
109            data,
110            dim: (rows, cols),
111        })
112    }
113    pub fn to_vec(mut self) -> Vec<Vec<T>> {
114        let mut res = Vec::with_capacity(self.rows());
115        for _ in 0..self.rows() {
116            let mut r = Vec::with_capacity(self.cols());
117            for _ in 0..self.cols() {
118                r.push(self.data.remove(0))
119            }
120            res.push(r);
121        }
122        res
123    }
124
125    pub fn as_row_major(&self) -> &Vec<T> {
126        &self.data
127    }
128
129    pub fn as_vec(&self) -> Vec<Vec<&T>> {
130        let mut res = Vec::with_capacity(self.rows());
131        for row in 0..self.rows() {
132            let mut r = Vec::with_capacity(self.cols());
133            for col in 0..self.cols() {
134                r.push(&self[(row, col)])
135            }
136            res.push(r)
137        }
138        res
139    }
140
141    /// Shuffles the rows of two matrices with the same amount of rows
142    /// as if their rows were concatenated
143    pub fn shuffle_rows_synced(m1: &mut Matrix2<T>, m2: &mut Matrix2<T>) -> Result<()> {
144        if m1.rows() != m2.rows() {
145            return Err(Error::DimensionErr);
146        }
147
148        let mut rng = rand::thread_rng();
149        let rows = m1.rows();
150        let cols_m1 = m1.cols();
151        let cols_m2 = m2.cols();
152        for i in 0..rows {
153            let rand_row = rng.gen_range(i..rows);
154            for col in 0..cols_m1 {
155                m1.data.swap(i * cols_m1 + col, rand_row * cols_m1 + col);
156            }
157            for col in 0..cols_m2 {
158                m2.data.swap(i * cols_m2 + col, rand_row * cols_m2 + col);
159            }
160        }
161
162        Ok(())
163    }
164}
165
166impl<T: Display> Display for Matrix2<T> {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        for row in 0..self.rows() {
169            for col in 0..self.cols() {
170                write!(f, "{} ", self[(row, col)])?;
171            }
172            writeln!(f)?;
173        }
174
175        Ok(())
176    }
177}
178
179impl<T: Clone> Matrix2<&T> {
180    pub fn clone_inner(&self) -> Matrix2<T> {
181        let mut data_clone = Vec::with_capacity(self.rows() * self.cols());
182        for row in 0..self.rows() {
183            for col in 0..self.cols() {
184                data_clone.push(self[(row, col)].clone())
185            }
186        }
187        Matrix2 {
188            data: data_clone,
189            dim: (self.rows(), self.cols()),
190        }
191    }
192}
193
194impl<T: Copy> Matrix2<T> {
195    pub fn copy_rows(&self, from: usize, n: usize) -> Self {
196        let end_row = (from + n).min(self.rows());
197        let data = &self.data[from * self.cols()..end_row * self.cols()];
198        Self {
199            data: data.to_vec(),
200            dim: (end_row - from, self.cols()),
201        }
202    }
203}
204
205impl<T> Matrix2<T>
206where
207    T: Default,
208{
209    /// Applies a function to every element of the matrix
210    pub fn apply<F: Fn(T) -> T>(&mut self, f: F) {
211        for x in &mut self.data {
212            let old = std::mem::take(x);
213            let _ = std::mem::replace(x, f(old));
214        }
215    }
216}
217
218impl<T> Index<(usize, usize)> for Matrix2<T> {
219    type Output = T;
220    fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
221        &self.data[i * self.cols() + j]
222    }
223}
224
225impl<T> IndexMut<(usize, usize)> for Matrix2<T> {
226    fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
227        let idx = i * self.cols() + j;
228        &mut self.data[idx]
229    }
230}
231
232impl From<Matrix2<u32>> for Matrix2<f64> {
233    fn from(value: Matrix2<u32>) -> Self {
234        Self {
235            dim: value.dim(),
236            data: value.data.into_iter().map(|x| x as f64).collect(),
237        }
238    }
239}
240
241impl From<Matrix2<i32>> for Matrix2<f64> {
242    fn from(value: Matrix2<i32>) -> Self {
243        Self {
244            dim: value.dim(),
245            data: value.data.into_iter().map(|x| x as f64).collect(),
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Reads back individual elements of a matrix built from a nested array.
    #[test]
    fn access_matrix2_from_array() {
        let m = Matrix2::from_array([[1, 2, 3], [4, 5, 6]]);
        let expected: [((usize, usize), i32); 4] =
            [((0, 1), 2), ((1, 2), 6), ((0, 0), 1), ((1, 1), 5)];
        for (pos, value) in expected {
            assert_eq!(m[pos], value);
        }
    }

    /// Reads back individual elements of a matrix built from nested vectors.
    #[test]
    fn matrix2_from_vec() {
        let m = Matrix2::from_vec(vec![vec![1, 2, 3], vec![4, 5, 6]]).unwrap();
        let expected: [((usize, usize), i32); 4] =
            [((0, 1), 2), ((1, 2), 6), ((0, 0), 1), ((1, 1), 5)];
        for (pos, value) in expected {
            assert_eq!(m[pos], value);
        }
    }

    /// Ragged input (a row shorter or longer than the first) is rejected.
    #[test]
    fn matrix2_from_vec_err() {
        let ragged_inputs = [
            vec![vec![1, 2, 3], vec![4, 5, 9], vec![1, 2]],
            vec![vec![1, 2], vec![4, 5, 9], vec![1, 2, 2]],
        ];
        for input in ragged_inputs {
            assert_eq!(Matrix2::from_vec(input), Err(Error::DimensionErr));
        }
    }

    /// `apply` transforms every element in place.
    #[test]
    fn matrix2_apply() {
        let mut m = Matrix2::from_array([[1, 2], [2, 2], [4, 8]]);

        m.apply(|x| x / 2);

        assert_eq!(m.to_vec(), [[0, 1], [1, 1], [2, 4]]);
    }

    /// After a synced shuffle, each row of `m1` is still paired with its original row
    /// of `m2`.
    #[test]
    fn shuffle_rows() {
        let relation: HashMap<[i32; 2], [i32; 1]> =
            HashMap::from([([1, 2], [9]), ([2, 2], [7]), ([4, 8], [1])]);
        let mut m1 = Matrix2::from_array([[1, 2], [2, 2], [4, 8]]);
        let mut m2 = Matrix2::from_array([[9], [7], [1]]);

        assert_eq!(Ok(()), Matrix2::shuffle_rows_synced(&mut m1, &mut m2));

        println!("m1 = {m1}\nm2 = {m2}");
        let pairs_intact = m1
            .to_vec()
            .into_iter()
            .zip(m2.to_vec())
            .all(|(lhs, rhs)| relation[lhs.as_slice()] == rhs.as_slice());
        assert!(pairs_intact);
    }
}