1use crate::address_bound::AddressBound;
2use crate::matrix_address::MatrixAddress;
3use crate::tensor::Tensor;
4use std::fmt::{Display, Formatter};
5use std::io::ErrorKind;
6use std::io::ErrorKind::InvalidInput;
7
8#[derive(Clone, Debug, PartialEq)]
10pub struct Matrix<T> {
11 data: Vec<T>,
12 bounds: AddressBound<MatrixAddress>,
13}
14
15impl<T> Tensor<T, MatrixAddress> for Matrix<T> {
16 fn new<F>(bounds: AddressBound<MatrixAddress>, address_value_converter: F) -> Matrix<T>
17 where
18 F: Fn(MatrixAddress) -> T,
19 {
20 let data: Vec<T> = bounds.iter().map(address_value_converter).collect();
21 Matrix { data, bounds }
22 }
23
24 fn get(&self, address: &MatrixAddress) -> Option<&T> {
25 if !self.bounds.contains_address(address) {
26 return None;
27 }
28 self.data.get(self.bounds.index_address(address).unwrap())
29 }
30
31 fn get_mut(&mut self, address: &MatrixAddress) -> Option<&mut T> {
32 if !self.bounds.contains_address(address) {
33 return None;
34 }
35 self.data
36 .get_mut(self.bounds.index_address(address).unwrap())
37 }
38
39 fn set(&mut self, address: &MatrixAddress, value: T) -> Result<(), std::io::Error> {
40 if !self.bounds.contains_address(address) {
41 return Err(std::io::Error::new(
42 InvalidInput,
43 format!("The following address is out of bounds: {address}"),
44 ));
45 }
46 self.data[self.bounds.index_address(address).unwrap()] = value;
47 Ok(())
48 }
49
50 fn bounds(&self) -> &AddressBound<MatrixAddress> {
51 &self.bounds
52 }
53}
54
55impl<T> Matrix<T> {
56 pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
57 &self,
58 display_func: F,
59 row_delimiter: &str,
60 column_delimiter: &str,
61 ) -> String {
62 self.address_iterator()
63 .enumerate()
64 .map(|(i, address)| {
65 format!(
66 "{}{}",
67 display_func(self.get(&address).unwrap()),
68 if (i as i64 + 1) % (self.bounds.largest_possible_position.x + 1) == 0 {
69 column_delimiter
70 } else {
71 row_delimiter
72 }
73 )
74 })
75 .fold("".to_string(), |a: String, b: String| a + &b)
76 }
77
78 pub fn parse_matrix<F>(
79 data_str: &str,
80 column_delimiter: &str,
81 row_delimiter: &str,
82 str_to_t_converter: F,
83 ) -> Result<Matrix<T>, std::io::Error>
84 where
85 F: Fn(&str, MatrixAddress) -> T,
86 {
87 let values: Vec<Vec<&str>> = data_str
88 .split(row_delimiter)
89 .map(|row| {
90 row.split(column_delimiter)
91 .filter(|string| *string != "")
92 .collect()
93 })
94 .filter(|row: &Vec<&str>| row.len() != 0)
95 .collect();
96 if values
97 .iter()
98 .skip(1)
99 .any(|row| row.len() != values.get(0).unwrap().len())
100 {
101 return Err(std::io::Error::new(
102 ErrorKind::InvalidData,
103 "Row Lengths are not constant",
104 ));
105 }
106 let height = values.len();
107 let width = values.get(0).unwrap().len();
108 let matrix_bounds = AddressBound {
109 smallest_possible_position: MatrixAddress { x: 0, y: 0 },
110 largest_possible_position: MatrixAddress {
111 x: (width - 1) as i64,
112 y: (height - 1) as i64,
113 },
114 };
115
116 Ok(Matrix::new(matrix_bounds, |address| {
117 str_to_t_converter(
118 values
119 .get(address.y as usize)
120 .unwrap()
121 .get(address.x as usize)
122 .unwrap(),
123 address,
124 )
125 }))
126 }
127}
128
129impl<'a, T: Display + 'a> Display for Matrix<T> {
130 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
131 write!(
132 f,
133 "{}",
134 self.to_display_string(|x: &T| x.to_string(), " ", "\n")
135 )
136 }
137}
138
#[cfg(test)]
mod tests {
    use crate::address_bound::AddressBound;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;
    use proptest::num::usize;
    use std::str::FromStr;

    /// Bound from `(0, 0)` to `(max, max)`, built with the struct literal.
    fn square_bound(max: i64) -> AddressBound<MatrixAddress> {
        AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: max, y: max },
        }
    }

    #[test]
    fn display_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        let matrix = Matrix::new(bound.clone(), |address: MatrixAddress| {
            bound.index_address(&address).unwrap() % 7
        });
        let expected = concat!(
            "0 1 2 3 4 5 6 0 1 2 3\n",
            "4 5 6 0 1 2 3 4 5 6 0\n",
            "1 2 3 4 5 6 0 1 2 3 4\n",
            "5 6 0 1 2 3 4 5 6 0 1\n",
            "2 3 4 5 6 0 1 2 3 4 5\n",
            "6 0 1 2 3 4 5 6 0 1 2\n",
            "3 4 5 6 0 1 2 3 4 5 6\n",
            "0 1 2 3 4 5 6 0 1 2 3\n",
            "4 5 6 0 1 2 3 4 5 6 0\n",
            "1 2 3 4 5 6 0 1 2 3 4\n",
            "5 6 0 1 2 3 4 5 6 0 1\n",
        );
        assert_eq!(expected, matrix.to_string());
    }

    #[test]
    fn parse_test() {
        let data_str = concat!(
            "0,1,2,3,4,5,6,0,1,2,3|",
            "4,5,6,0,1,2,3,4,5,6,0|",
            "1,2,3,4,5,6,0,1,2,3,4|",
            "5,6,0,1,2,3,4,5,6,0,1|",
            "2,3,4,5,6,0,1,2,3,4,5|",
            "6,0,1,2,3,4,5,6,0,1,2|",
            "3,4,5,6,0,1,2,3,4,5,6|",
            "0,1,2,3,4,5,6,0,1,2,3|",
            "4,5,6,0,1,2,3,4,5,6,0|",
            "1,2,3,4,5,6,0,1,2,3,4|",
            "5,6,0,1,2,3,4,5,6,0,1",
        );
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        let expected = Matrix::new(bound.clone(), |address: MatrixAddress| {
            bound.index_address(&address).unwrap() % 7
        });
        let parsed = Matrix::parse_matrix(data_str, ",", "|", |cell, _| {
            usize::from_str(cell).expect("")
        })
        .expect("");
        assert_eq!(expected, parsed);
    }

    #[test]
    fn get_test() {
        let bound = square_bound(1000);
        let matrix = Matrix::new(bound.clone(), |address| bound.index_address(&address));
        for address in matrix.address_iterator() {
            assert_eq!(bound.index_address(&address), *matrix.get(&address).unwrap());
        }
    }

    #[test]
    fn set_test() {
        let bound = square_bound(1000);
        let mut matrix = Matrix::new(bound.clone(), |_address| 0usize);
        for address in matrix.address_iterator() {
            let index = bound.index_address(&address).unwrap();
            assert_eq!(matrix.get(&address).unwrap(), &0usize);
            matrix
                .set(&address, index)
                .expect("Index out of bounds error");
            assert_eq!(matrix.get(&address).unwrap(), &index);
        }
        for address in matrix.address_iterator() {
            assert_eq!(
                bound.index_address(&address).unwrap(),
                *matrix.get(&address).unwrap()
            );
        }
    }

    #[test]
    fn transform_test() {
        let bound = square_bound(1000);
        let matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap()
        });
        let transformed_matrix: Matrix<f64> = matrix.transform(|value| *value as f64);
        let transformed_by_address_matrix: Matrix<f64> =
            matrix.transform_by_address(|_address, value| *value as f64);
        for address in matrix.address_iterator() {
            let expected = *matrix.get(&address).unwrap() as f64;
            assert_eq!(expected, *transformed_matrix.get(&address).unwrap());
            assert_eq!(expected, *transformed_by_address_matrix.get(&address).unwrap());
        }
    }

    #[test]
    fn transform_in_place_test() {
        let bound = square_bound(1000);
        let original_matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i64
        });
        let mut working_matrix = original_matrix.clone();
        let mut working_by_address_matrix = original_matrix.clone();

        working_matrix.transform_in_place(|value| *value *= -2);
        working_by_address_matrix.transform_by_address_in_place(|_address, value| *value *= -2);

        let expected = original_matrix.transform::<_, i64, Matrix<i64>>(|value| *value * -2);
        assert_eq!(expected, working_matrix);
        assert_eq!(working_matrix, working_by_address_matrix);
    }

    #[test]
    fn equality_test() {
        let bound = square_bound(1000);
        // Two independent builds from the same bound must compare equal.
        let build = || {
            Matrix::new(bound.clone(), |address| {
                bound.index_address(&address).unwrap() as i32
            })
        };
        assert_eq!(build(), build());
    }
}