1use crate::address_bound::AddressBound;
2use crate::matrix_address::MatrixAddress;
3use crate::tensor::Tensor;
4use std::fmt::{Display, Formatter};
5use std::io::Error;
6use std::io::ErrorKind::InvalidInput;
7
8#[derive(Clone, Debug, PartialEq)]
10pub struct Matrix<T> {
11 data: Vec<T>,
12 bounds: AddressBound<MatrixAddress>,
13}
14
15impl<T> Tensor<T, MatrixAddress> for Matrix<T> {
16 fn new<F>(bounds: AddressBound<MatrixAddress>, address_value_converter: F) -> Matrix<T>
17 where
18 F: Fn(MatrixAddress) -> T,
19 {
20 let data: Vec<T> = bounds.iter().map(address_value_converter).collect();
21 Matrix { data, bounds }
22 }
23
24 fn get(&self, address: &MatrixAddress) -> Option<&T> {
25 if !self.bounds.contains_address(address) {
26 return None;
27 }
28 self.data.get(self.bounds.index_address(address).unwrap())
29 }
30
31 fn get_mut(&mut self, address: &MatrixAddress) -> Option<&mut T> {
32 if !self.bounds.contains_address(address) {
33 return None;
34 }
35 self.data
36 .get_mut(self.bounds.index_address(address).unwrap())
37 }
38
39 fn set(&mut self, address: &MatrixAddress, value: T) -> Result<(), Error> {
40 if !self.bounds.contains_address(address) {
41 return Err(Error::new(
42 InvalidInput,
43 format!("The following address is out of bounds: {address}"),
44 ));
45 }
46 self.data[self.bounds.index_address(address).unwrap()] = value;
47 Ok(())
48 }
49
50 fn bounds(&self) -> &AddressBound<MatrixAddress> {
51 &self.bounds
52 }
53}
54
55impl<T> Matrix<T> {
56 pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
57 &self,
58 display_func: F,
59 row_delimiter: &str,
60 column_delimiter: &str,
61 ) -> String {
62 self.address_iterator()
63 .enumerate()
64 .map(|(i, address)| {
65 format!(
66 "{}{}",
67 display_func(self.get(&address).unwrap()),
68 if (i as i64 + 1) % (self.bounds.largest_possible_position.x + 1) == 0 {
69 column_delimiter
70 } else {
71 row_delimiter
72 }
73 )
74 })
75 .fold("".to_string(), |a: String, b: String| a + &b)
76 }
77}
78
79impl<'a, T: Display + 'a> Display for Matrix<T> {
80 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
81 write!(
82 f,
83 "{}",
84 self.to_display_string(|x: &T| x.to_string(), " ", "\n")
85 )
86 }
87}
88
#[cfg(test)]
mod tests {
    use crate::address_bound::AddressBound;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;

    /// Bound spanning (0, 0) to (1000, 1000), shared by most tests below.
    fn large_bound() -> AddressBound<MatrixAddress> {
        AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        }
    }

    /// `Display` output puts a space between elements and a newline after
    /// every row.
    #[test]
    fn display_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        let matrix = Matrix::new(bound.clone(), |address: MatrixAddress| {
            bound.index_address(&address).unwrap() % 7
        });
        let expected = "0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n2 3 4 5 6 0 1 2 3 4 5\n6 0 1 2 3 4 5 6 0 1 2\n3 4 5 6 0 1 2 3 4 5 6\n0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n";
        assert_eq!(expected, format!("{}", matrix));
    }

    /// Every address reads back the value it was initialised with.
    #[test]
    fn get_test() {
        let bound = large_bound();
        // Elements hold the raw `index_address` result, not the unwrapped index.
        let matrix = Matrix::new(bound.clone(), |address| bound.index_address(&address));
        for address in matrix.address_iterator() {
            assert_eq!(
                bound.index_address(&address),
                *matrix.get(&address).unwrap()
            );
        }
    }

    /// `set` replaces the stored value, and the new value is visible both
    /// immediately and on a later full pass.
    #[test]
    fn set_test() {
        let bound = large_bound();
        let mut matrix = Matrix::new(bound.clone(), |_address| 0usize);
        for address in matrix.address_iterator() {
            assert_eq!(matrix.get(&address).unwrap(), &0usize);
            let index = bound.index_address(&address).unwrap();
            matrix
                .set(&address, index)
                .expect("Index out of bounds error");
            assert_eq!(matrix.get(&address).unwrap(), &index);
        }
        for address in matrix.address_iterator() {
            assert_eq!(
                bound.index_address(&address).unwrap(),
                *(matrix.get(&address).unwrap())
            );
        }
    }

    /// `transform` and `transform_by_address` both yield the element-wise
    /// converted matrix.
    #[test]
    fn transform_test() {
        let bound = large_bound();
        let matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap()
        });
        let transformed_matrix: Matrix<f64> = matrix.transform(|value| *value as f64);
        let transformed_by_address_matrix: Matrix<f64> =
            matrix.transform_by_address(|_address, value| *value as f64);
        for address in matrix.address_iterator() {
            let expected = *matrix.get(&address).unwrap() as f64;
            assert_eq!(expected, *transformed_matrix.get(&address).unwrap());
            assert_eq!(
                expected,
                *transformed_by_address_matrix.get(&address).unwrap()
            );
        }
    }

    /// The in-place transforms agree with each other and with the copying
    /// `transform`.
    #[test]
    fn transform_in_place_test() {
        let bound = large_bound();
        let original_matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i64
        });
        let mut working_matrix = original_matrix.clone();
        let mut working_by_address_matrix = original_matrix.clone();

        working_matrix.transform_in_place(|value| *value *= -2);
        working_by_address_matrix.transform_by_address_in_place(|_address, value| *value *= -2);

        let expected = original_matrix.transform::<_, i64, Matrix<i64>>(|value| *value * -2);
        assert_eq!(expected, working_matrix);
        assert_eq!(working_matrix, working_by_address_matrix)
    }

    /// Two matrices built from the same bound and generator compare equal.
    #[test]
    fn equality_test() {
        let bound = large_bound();
        let left = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        let right = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        assert_eq!(left, right);
    }
}