1use crate::address_bound::AddressBound;
2use crate::matrix_address::MatrixAddress;
3use crate::tensor::Tensor;
4use std::error::Error;
5use std::fmt::{Display, Formatter};
6use std::io::ErrorKind;
7use std::io::ErrorKind::InvalidInput;
8use std::str::FromStr;
9
/// A dense grid of values addressed by [`MatrixAddress`].
///
/// Values live in a single flat `Vec`; the value for an address is stored at the
/// position that `bounds.index_address` returns for that address.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix<T> {
    // One entry per address covered by `bounds`; when built through `Tensor::new`
    // the entries are in the order yielded by `bounds.iter()`.
    data: Vec<T>,
    // The span of addresses covered, from `smallest_possible_position` up to and
    // including `largest_possible_position`.
    bounds: AddressBound<MatrixAddress>,
}
16
17impl<T> Tensor<T, MatrixAddress> for Matrix<T> {
18 fn new<F>(bounds: AddressBound<MatrixAddress>, address_value_converter: F) -> Matrix<T>
19 where
20 F: Fn(MatrixAddress) -> T,
21 {
22 let data: Vec<T> = bounds.iter().map(address_value_converter).collect();
23 Matrix { data, bounds }
24 }
25
26 fn get(&self, address: &MatrixAddress) -> Option<&T> {
27 if !self.bounds.contains_address(address) {
28 return None;
29 }
30 self.data.get(self.bounds.index_address(address).unwrap())
31 }
32
33 fn get_mut(&mut self, address: &MatrixAddress) -> Option<&mut T> {
34 if !self.bounds.contains_address(address) {
35 return None;
36 }
37 self.data
38 .get_mut(self.bounds.index_address(address).unwrap())
39 }
40
41 fn set(&mut self, address: &MatrixAddress, value: T) -> Result<(), std::io::Error> {
42 if !self.bounds.contains_address(address) {
43 return Err(std::io::Error::new(
44 InvalidInput,
45 format!("The following address is out of bounds: {address}"),
46 ));
47 }
48 self.data[self.bounds.index_address(address).unwrap()] = value;
49 Ok(())
50 }
51
52 fn bounds(&self) -> &AddressBound<MatrixAddress> {
53 &self.bounds
54 }
55}
56
57impl<T> Matrix<T> {
58 pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
59 &self,
60 display_func: F,
61 row_delimiter: &str,
62 column_delimiter: &str,
63 ) -> String {
64 self.address_iterator()
65 .enumerate()
66 .map(|(i, address)| {
67 format!(
68 "{}{}",
69 display_func(self.get(&address).unwrap()),
70 if (i as i64 + 1) % (self.bounds.largest_possible_position.x + 1) == 0 {
71 column_delimiter
72 } else {
73 row_delimiter
74 }
75 )
76 })
77 .fold("".to_string(), |a: String, b: String| a + &b)
78 }
79
80 pub fn parse_matrix<F>(
81 data_str: &str,
82 column_delimiter: &str,
83 row_delimiter: &str,
84 str_to_t_converter: F,
85 ) -> Result<Matrix<T>, std::io::Error>
86 where
87 F: Fn(&str) -> T,
88 {
89 let values: Vec<Vec<&str>> = data_str
90 .split(row_delimiter)
91 .map(|row| row.split(column_delimiter).collect())
92 .collect();
93 if values
94 .iter()
95 .skip(1)
96 .any(|row| row.len() != values.get(0).unwrap().len())
97 {
98 return Err(std::io::Error::new(
99 ErrorKind::InvalidData,
100 "Row Lengths are not constant",
101 ));
102 }
103 let height = values.len();
104 let width = values.get(0).unwrap().len();
105 let matrix_bounds = AddressBound {
106 smallest_possible_position: MatrixAddress { x: 0, y: 0 },
107 largest_possible_position: MatrixAddress {
108 x: (width - 1) as i64,
109 y: (height - 1) as i64,
110 },
111 };
112
113 Ok(Matrix::new(matrix_bounds, |address| {
114 str_to_t_converter(
115 values
116 .get(address.y as usize)
117 .unwrap()
118 .get(address.x as usize)
119 .unwrap(),
120 )
121 }))
122 }
123}
124
125impl<'a, T: Display + 'a> Display for Matrix<T> {
126 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
127 write!(
128 f,
129 "{}",
130 self.to_display_string(|x: &T| x.to_string(), " ", "\n")
131 )
132 }
133}
134
#[cfg(test)]
mod tests {
    // Unit tests for `Matrix`: construction, parsing, display, element access,
    // transformation and equality.
    use crate::address_bound::AddressBound;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;
    use std::io::ErrorKind;
    use std::str::FromStr;

    /// Bounds shared by the bulk tests: addresses (0, 0) through (1000, 1000).
    fn large_bound() -> AddressBound<MatrixAddress> {
        AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        }
    }

    #[test]
    fn display_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        assert_eq!(
            "0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n2 3 4 5 6 0 1 2 3 4 5\n6 0 1 2 3 4 5 6 0 1 2\n3 4 5 6 0 1 2 3 4 5 6\n0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n",
            format!(
                "{}",
                Matrix::new(bound.clone(), |address: MatrixAddress| bound
                    .index_address(&address)
                    .unwrap()
                    % 7)
            )
        )
    }

    #[test]
    fn parse_test() {
        let data_str = "0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1|2,3,4,5,6,0,1,2,3,4,5|6,0,1,2,3,4,5,6,0,1,2|3,4,5,6,0,1,2,3,4,5,6|0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1";
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        assert_eq!(
            Matrix::new(bound.clone(), |address: MatrixAddress| bound
                .index_address(&address)
                .unwrap()
                % 7),
            Matrix::parse_matrix(data_str, ",", "|", |string| usize::from_str(string)
                .expect("test data contains only unsigned integers"))
            .expect("test data has rows of equal length")
        );
    }

    #[test]
    fn parse_rejects_ragged_rows_test() {
        let result = Matrix::parse_matrix("1,2|3", ",", "|", |string| {
            usize::from_str(string).expect("test data contains only unsigned integers")
        });
        let error = result.expect_err("rows of different lengths must be rejected");
        assert_eq!(ErrorKind::InvalidData, error.kind());
    }

    #[test]
    fn get_test() {
        let bound = large_bound();
        let matrix = Matrix::new(bound.clone(), |address| bound.index_address(&address));
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                bound.index_address(&address),
                *matrix.get(&address).unwrap()
            )
        })
    }

    #[test]
    fn set_test() {
        let bound = large_bound();
        let mut matrix = Matrix::new(bound.clone(), |_address| 0usize);
        matrix.address_iterator().for_each(|address| {
            assert_eq!(matrix.get(&address).unwrap(), &0usize);
            matrix
                .set(&address, bound.index_address(&address).unwrap())
                .expect("Index out of bounds error");
            assert_eq!(
                matrix.get(&address).unwrap(),
                &(bound.index_address(&address).unwrap())
            );
        });
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                bound.index_address(&address).unwrap(),
                *(matrix.get(&address).unwrap())
            )
        })
    }

    #[test]
    fn out_of_bounds_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 2, y: 2 },
        };
        let mut matrix = Matrix::new(bound, |_address| 0usize);
        let outside = MatrixAddress { x: 3, y: 0 };
        assert_eq!(None, matrix.get(&outside));
        assert_eq!(None, matrix.get_mut(&outside));
        let error = matrix
            .set(&outside, 1)
            .expect_err("setting an out-of-bounds address must fail");
        assert_eq!(ErrorKind::InvalidInput, error.kind());
    }

    #[test]
    fn transform_test() {
        let bound = large_bound();
        let matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap()
        });
        let transformed_matrix: Matrix<f64> = matrix.transform(|value| *value as f64);
        let transformed_by_address_matrix: Matrix<f64> =
            matrix.transform_by_address(|_address, value| *value as f64);
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                *matrix.get(&address).unwrap() as f64,
                *transformed_matrix.get(&address).unwrap(),
            );
            assert_eq!(
                *matrix.get(&address).unwrap() as f64,
                *transformed_by_address_matrix.get(&address).unwrap()
            );
        })
    }

    #[test]
    fn transform_in_place_test() {
        let bound = large_bound();
        let original_matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i64
        });
        let mut working_matrix = original_matrix.clone();
        let mut working_by_address_matrix = original_matrix.clone();

        working_matrix.transform_in_place(|value| *value *= -2);
        working_by_address_matrix.transform_by_address_in_place(|_address, value| *value *= -2);
        assert_eq!(
            original_matrix.transform::<_, i64, Matrix<i64>>(|value| *value * -2),
            working_matrix
        );
        assert_eq!(working_matrix, working_by_address_matrix)
    }

    #[test]
    fn equality_test() {
        let bound = large_bound();
        let m1 = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        let m2 = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        assert_eq!(m1, m2);
    }
}