1use rand::Rng;
2
3use crate::prelude::*;
4use std::{
5 fmt::Display,
6 ops::{Index, IndexMut},
7};
8
9pub mod ops;
10
/// A dense two-dimensional matrix stored in row-major order.
///
/// Element `(row, col)` lives at `data[row * cols + col]`; the constructors in this
/// module keep `data.len() == rows * cols`.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix2<T> {
    /// Elements laid out row after row.
    data: Vec<T>,
    /// Dimensions as `(rows, cols)`.
    dim: (usize, usize),
}
16
17impl<T: Clone> Matrix2<T> {
18 pub fn clone_row_to_vec(&self, row: usize) -> Vec<T> {
19 (0..self.cols())
20 .map(|col| self[(row, col)].clone())
21 .collect()
22 }
23
24 pub fn clone_row(&self, row: usize) -> Matrix2<T> {
25 Matrix2::from_row(
26 (0..self.cols())
27 .map(|col| self[(row, col)].clone())
28 .collect(),
29 )
30 }
31}
32
33impl<T: Default + Clone> Matrix2<T> {
34 pub fn new(rows: usize, cols: usize) -> Self {
35 Self {
36 data: vec![T::default(); rows * cols],
37 dim: (rows, cols),
38 }
39 }
40
41 pub fn zero(&mut self) {
42 for row in &mut self.data {
43 *row = T::default();
44 }
45 }
46}
47
48impl<T> Matrix2<T> {
49 pub fn from_array<const R: usize, const C: usize>(arr: [[T; C]; R]) -> Self {
50 let mut data = Vec::with_capacity(R * C);
51
52 for row in arr {
53 for x in row {
54 data.push(x);
55 }
56 }
57
58 Self { data, dim: (R, C) }
59 }
60
61 pub fn concat_rows(&mut self, mut other: Matrix2<T>) -> Result<()> {
62 if self.cols() != other.cols() {
63 return Err(Error::DimensionErr);
64 }
65
66 self.data.append(&mut other.data);
67 self.dim.0 += other.rows();
68 Ok(())
69 }
70
71 pub fn dim(&self) -> (usize, usize) {
72 self.dim
73 }
74
75 pub fn rows(&self) -> usize {
76 self.dim.0
77 }
78
79 pub fn cols(&self) -> usize {
80 self.dim.1
81 }
82
83 pub fn row_as_vec(&self, row: usize) -> Vec<&T> {
84 (0..self.cols()).map(|col| &self[(row, col)]).collect()
85 }
86
87 pub fn from_row(row_vec: Vec<T>) -> Self {
88 Self {
89 dim: (1, row_vec.len()),
90 data: row_vec,
91 }
92 }
93 pub fn from_vec(vec: Vec<Vec<T>>) -> Result<Self> {
94 let rows = vec.len();
95 let cols = vec.get(0).map(|row| row.len()).unwrap_or(0);
96
97 let mut data = Vec::new();
98 for row in vec {
99 if cols != row.len() {
100 return Err(Error::DimensionErr);
101 }
102
103 for x in row {
104 data.push(x);
105 }
106 }
107
108 Ok(Self {
109 data,
110 dim: (rows, cols),
111 })
112 }
113 pub fn to_vec(mut self) -> Vec<Vec<T>> {
114 let mut res = Vec::with_capacity(self.rows());
115 for _ in 0..self.rows() {
116 let mut r = Vec::with_capacity(self.cols());
117 for _ in 0..self.cols() {
118 r.push(self.data.remove(0))
119 }
120 res.push(r);
121 }
122 res
123 }
124
125 pub fn as_row_major(&self) -> &Vec<T> {
126 &self.data
127 }
128
129 pub fn as_vec(&self) -> Vec<Vec<&T>> {
130 let mut res = Vec::with_capacity(self.rows());
131 for row in 0..self.rows() {
132 let mut r = Vec::with_capacity(self.cols());
133 for col in 0..self.cols() {
134 r.push(&self[(row, col)])
135 }
136 res.push(r)
137 }
138 res
139 }
140
141 pub fn shuffle_rows_synced(m1: &mut Matrix2<T>, m2: &mut Matrix2<T>) -> Result<()> {
144 if m1.rows() != m2.rows() {
145 return Err(Error::DimensionErr);
146 }
147
148 let mut rng = rand::thread_rng();
149 let rows = m1.rows();
150 let cols_m1 = m1.cols();
151 let cols_m2 = m2.cols();
152 for i in 0..rows {
153 let rand_row = rng.gen_range(i..rows);
154 for col in 0..cols_m1 {
155 m1.data.swap(i * cols_m1 + col, rand_row * cols_m1 + col);
156 }
157 for col in 0..cols_m2 {
158 m2.data.swap(i * cols_m2 + col, rand_row * cols_m2 + col);
159 }
160 }
161
162 Ok(())
163 }
164}
165
166impl<T: Display> Display for Matrix2<T> {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 for row in 0..self.rows() {
169 for col in 0..self.cols() {
170 write!(f, "{} ", self[(row, col)])?;
171 }
172 writeln!(f)?;
173 }
174
175 Ok(())
176 }
177}
178
179impl<T: Clone> Matrix2<&T> {
180 pub fn clone_inner(&self) -> Matrix2<T> {
181 let mut data_clone = Vec::with_capacity(self.rows() * self.cols());
182 for row in 0..self.rows() {
183 for col in 0..self.cols() {
184 data_clone.push(self[(row, col)].clone())
185 }
186 }
187 Matrix2 {
188 data: data_clone,
189 dim: (self.rows(), self.cols()),
190 }
191 }
192}
193
194impl<T: Copy> Matrix2<T> {
195 pub fn copy_rows(&self, from: usize, n: usize) -> Self {
196 let end_row = (from + n).min(self.rows());
197 let data = &self.data[from * self.cols()..end_row * self.cols()];
198 Self {
199 data: data.to_vec(),
200 dim: (end_row - from, self.cols()),
201 }
202 }
203}
204
205impl<T> Matrix2<T>
206where
207 T: Default,
208{
209 pub fn apply<F: Fn(T) -> T>(&mut self, f: F) {
211 for x in &mut self.data {
212 let old = std::mem::take(x);
213 let _ = std::mem::replace(x, f(old));
214 }
215 }
216}
217
218impl<T> Index<(usize, usize)> for Matrix2<T> {
219 type Output = T;
220 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
221 &self.data[i * self.cols() + j]
222 }
223}
224
225impl<T> IndexMut<(usize, usize)> for Matrix2<T> {
226 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
227 let idx = i * self.cols() + j;
228 &mut self.data[idx]
229 }
230}
231
232impl From<Matrix2<u32>> for Matrix2<f64> {
233 fn from(value: Matrix2<u32>) -> Self {
234 Self {
235 dim: value.dim(),
236 data: value.data.into_iter().map(|x| x as f64).collect(),
237 }
238 }
239}
240
241impl From<Matrix2<i32>> for Matrix2<f64> {
242 fn from(value: Matrix2<i32>) -> Self {
243 Self {
244 dim: value.dim(),
245 data: value.data.into_iter().map(|x| x as f64).collect(),
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn access_matrix2_from_array() {
        let matrix = Matrix2::from_array([[1, 2, 3], [4, 5, 6]]);
        assert_eq!(matrix[(0, 1)], 2);
        assert_eq!(matrix[(1, 2)], 6);
        assert_eq!(matrix[(0, 0)], 1);
        assert_eq!(matrix[(1, 1)], 5);
    }

    #[test]
    fn matrix2_from_vec() {
        let vec = vec![vec![1, 2, 3], vec![4, 5, 6]];
        let matrix = Matrix2::from_vec(vec).unwrap();

        assert_eq!(matrix[(0, 1)], 2);
        assert_eq!(matrix[(1, 2)], 6);
        assert_eq!(matrix[(0, 0)], 1);
        assert_eq!(matrix[(1, 1)], 5);
    }

    #[test]
    fn matrix2_from_vec_err() {
        let vec = vec![vec![1, 2, 3], vec![4, 5, 9], vec![1, 2]];
        let matrix = Matrix2::from_vec(vec);

        assert_eq!(matrix, Err(Error::DimensionErr));

        let vec = vec![vec![1, 2], vec![4, 5, 9], vec![1, 2, 2]];
        let matrix = Matrix2::from_vec(vec);

        assert_eq!(matrix, Err(Error::DimensionErr));
    }

    #[test]
    fn matrix2_apply() {
        let mut matrix = Matrix2::from_array([[1, 2], [2, 2], [4, 8]]);

        matrix.apply(|x| x / 2);

        assert_eq!(matrix.to_vec(), [[0, 1], [1, 1], [2, 4]]);
    }

    #[test]
    fn shuffle_rows() {
        let mut relation = HashMap::new();
        relation.insert([1, 2], [9]);
        relation.insert([2, 2], [7]);
        relation.insert([4, 8], [1]);
        let mut m1 = Matrix2::from_array([[1, 2], [2, 2], [4, 8]]);
        let mut m2 = Matrix2::from_array([[9], [7], [1]]);

        assert_eq!(Ok(()), Matrix2::shuffle_rows_synced(&mut m1, &mut m2));

        println!("m1 = {m1}\nm2 = {m2}");
        assert!(m1
            .to_vec()
            .into_iter()
            .zip(m2.to_vec())
            .all(|(v1, v2)| relation[v1.as_slice()] == v2.as_slice()));
    }

    #[test]
    fn shuffle_rows_mismatched_rows_is_err() {
        let mut m1 = Matrix2::from_array([[1, 2], [3, 4], [5, 6]]);
        let mut m2 = Matrix2::from_array([[1], [2]]);

        assert_eq!(
            Matrix2::shuffle_rows_synced(&mut m1, &mut m2),
            Err(Error::DimensionErr)
        );
        // The error is reported before any row is moved.
        assert_eq!(m1.to_vec(), [[1, 2], [3, 4], [5, 6]]);
    }

    #[test]
    fn concat_rows_appends_and_checks_width() {
        let mut m = Matrix2::from_array([[1, 2]]);

        assert_eq!(m.concat_rows(Matrix2::from_array([[3, 4], [5, 6]])), Ok(()));
        assert_eq!(m.dim(), (3, 2));

        assert_eq!(
            m.concat_rows(Matrix2::from_array([[7, 8, 9]])),
            Err(Error::DimensionErr)
        );
        assert_eq!(m.to_vec(), [[1, 2], [3, 4], [5, 6]]);
    }

    #[test]
    fn to_vec_round_trips_and_handles_zero_columns() {
        let rows = vec![vec![1, 2, 3], vec![4, 5, 6]];
        assert_eq!(Matrix2::from_vec(rows.clone()).unwrap().to_vec(), rows);

        assert_eq!(
            Matrix2::<i32>::new(2, 0).to_vec(),
            vec![Vec::<i32>::new(), Vec::new()]
        );
    }

    #[test]
    fn copy_rows_clamps_to_matrix() {
        let m = Matrix2::from_array([[1, 2], [3, 4], [5, 6]]);
        let window = m.copy_rows(1, 5);

        assert_eq!(window.dim(), (2, 2));
        assert_eq!(window.to_vec(), [[3, 4], [5, 6]]);
    }

    #[test]
    fn clone_row_matches_from_row() {
        let m = Matrix2::from_array([[1, 2], [3, 4]]);

        assert_eq!(m.clone_row(1), Matrix2::from_row(vec![3, 4]));
        assert_eq!(m.clone_row_to_vec(0), vec![1, 2]);
    }

    #[test]
    fn display_writes_one_line_per_row() {
        let m = Matrix2::from_array([[1, 2], [3, 4]]);

        assert_eq!(m.to_string(), "1 2 \n3 4 \n");
    }

    #[test]
    fn converts_integer_matrices_to_f64() {
        let signed: Matrix2<f64> = Matrix2::<i32>::from_array([[1, -2]]).into();
        assert_eq!(signed.to_vec(), [[1.0, -2.0]]);

        let unsigned: Matrix2<f64> = Matrix2::<u32>::from_array([[3], [4]]).into();
        assert_eq!(unsigned.to_vec(), [[3.0], [4.0]]);
    }
}