rust_tensors/
matrix.rs

1use crate::address_bound::AddressBound;
2use crate::matrix_address::MatrixAddress;
3use crate::tensor::Tensor;
4use std::fmt::{Display, Formatter};
5use std::io::ErrorKind;
6use std::io::ErrorKind::InvalidInput;
7
8/// A tensor of two dimensions accessed using MatrixAddress.
9#[derive(Clone, Debug, PartialEq)]
10pub struct Matrix<T> {
11    data: Vec<T>,
12    bounds: AddressBound<MatrixAddress>,
13}
14
15impl<T> Tensor<T, MatrixAddress> for Matrix<T> {
16    fn new<F>(bounds: AddressBound<MatrixAddress>, address_value_converter: F) -> Matrix<T>
17    where
18        F: Fn(MatrixAddress) -> T,
19    {
20        let data: Vec<T> = bounds.iter().map(address_value_converter).collect();
21        Matrix { data, bounds }
22    }
23
24    fn get(&self, address: &MatrixAddress) -> Option<&T> {
25        if !self.bounds.contains_address(address) {
26            return None;
27        }
28        self.data.get(self.bounds.index_address(address).unwrap())
29    }
30
31    fn get_mut(&mut self, address: &MatrixAddress) -> Option<&mut T> {
32        if !self.bounds.contains_address(address) {
33            return None;
34        }
35        self.data
36            .get_mut(self.bounds.index_address(address).unwrap())
37    }
38
39    fn set(&mut self, address: &MatrixAddress, value: T) -> Result<(), std::io::Error> {
40        if !self.bounds.contains_address(address) {
41            return Err(std::io::Error::new(
42                InvalidInput,
43                format!("The following address is out of bounds: {address}"),
44            ));
45        }
46        self.data[self.bounds.index_address(address).unwrap()] = value;
47        Ok(())
48    }
49
50    fn bounds(&self) -> &AddressBound<MatrixAddress> {
51        &self.bounds
52    }
53}
54
55impl<T> Matrix<T> {
56    pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
57        &self,
58        display_func: F,
59        row_delimiter: &str,
60        column_delimiter: &str,
61    ) -> String {
62        self.address_iterator()
63            .enumerate()
64            .map(|(i, address)| {
65                format!(
66                    "{}{}",
67                    display_func(self.get(&address).unwrap()),
68                    if (i as i64 + 1) % (self.bounds.largest_possible_position.x + 1) == 0 {
69                        column_delimiter
70                    } else {
71                        row_delimiter
72                    }
73                )
74            })
75            .fold("".to_string(), |a: String, b: String| a + &b)
76    }
77
78    pub fn parse_matrix<F>(
79        data_str: &str,
80        column_delimiter: &str,
81        row_delimiter: &str,
82        str_to_t_converter: F,
83    ) -> Result<Matrix<T>, std::io::Error>
84    where
85        F: Fn(&str, MatrixAddress) -> T,
86    {
87        let values: Vec<Vec<&str>> = data_str
88            .split(row_delimiter)
89            .map(|row| {
90                row.split(column_delimiter)
91                    .filter(|string| *string != "")
92                    .collect()
93            })
94            .filter(|row: &Vec<&str>| row.len() != 0)
95            .collect();
96        if values
97            .iter()
98            .skip(1)
99            .any(|row| row.len() != values.get(0).unwrap().len())
100        {
101            return Err(std::io::Error::new(
102                ErrorKind::InvalidData,
103                "Row Lengths are not constant",
104            ));
105        }
106        let height = values.len();
107        let width = values.get(0).unwrap().len();
108        let matrix_bounds = AddressBound {
109            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
110            largest_possible_position: MatrixAddress {
111                x: (width - 1) as i64,
112                y: (height - 1) as i64,
113            },
114        };
115
116        Ok(Matrix::new(matrix_bounds, |address| {
117            str_to_t_converter(
118                values
119                    .get(address.y as usize)
120                    .unwrap()
121                    .get(address.x as usize)
122                    .unwrap(),
123                address,
124            )
125        }))
126    }
127}
128
129impl<'a, T: Display + 'a> Display for Matrix<T> {
130    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
131        write!(
132            f,
133            "{}",
134            self.to_display_string(|x: &T| x.to_string(), " ", "\n")
135        )
136    }
137}
138
#[cfg(test)]
mod tests {
    use crate::address_bound::AddressBound;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;
    use std::io::ErrorKind;
    use std::str::FromStr;

    #[test]
    fn display_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        assert_eq!(
            "0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n2 3 4 5 6 0 1 2 3 4 5\n6 0 1 2 3 4 5 6 0 1 2\n3 4 5 6 0 1 2 3 4 5 6\n0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n",
            format!(
                "{}",
                Matrix::new(bound.clone(), |address: MatrixAddress| bound
                    .index_address(&address)
                    .unwrap()
                    % 7)
            )
        )
    }

    #[test]
    fn display_string_custom_delimiters_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 1, y: 1 });
        let matrix = Matrix::new(bound.clone(), |address: MatrixAddress| {
            bound.index_address(&address).unwrap()
        });
        // `row_delimiter` separates elements of one row; `column_delimiter` ends every row.
        assert_eq!(
            "0,10;\n20,30;\n",
            matrix.to_display_string(|value: &usize| *value * 10, ",", ";\n")
        );
    }

    #[test]
    fn parse_test() {
        let data_str = "0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1|2,3,4,5,6,0,1,2,3,4,5|6,0,1,2,3,4,5,6,0,1,2|3,4,5,6,0,1,2,3,4,5,6|0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1";
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        assert_eq!(
            Matrix::new(bound.clone(), |address: MatrixAddress| bound
                .index_address(&address)
                .unwrap()
                % 7),
            Matrix::parse_matrix(data_str, ",", "|", |string, _| usize::from_str(string)
                .expect("test data contains only unsigned integers"))
            .expect("test data is a well-formed 11x11 matrix")
        );
    }

    #[test]
    fn parse_rejects_ragged_rows_test() {
        let error = Matrix::parse_matrix("1,2,3|4,5", ",", "|", |string, _| {
            usize::from_str(string).expect("test data contains only unsigned integers")
        })
        .expect_err("rows of different lengths must be rejected");
        assert_eq!(error.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn parse_skips_empty_cells_and_rows_test() {
        let matrix = Matrix::parse_matrix("1,2,|3,4||", ",", "|", |string, _| {
            usize::from_str(string).expect("test data contains only unsigned integers")
        })
        .expect("trailing delimiters must be ignored");
        assert_eq!(matrix.get(&MatrixAddress { x: 0, y: 0 }), Some(&1));
        assert_eq!(matrix.get(&MatrixAddress { x: 1, y: 0 }), Some(&2));
        assert_eq!(matrix.get(&MatrixAddress { x: 0, y: 1 }), Some(&3));
        assert_eq!(matrix.get(&MatrixAddress { x: 1, y: 1 }), Some(&4));
        assert_eq!(matrix.get(&MatrixAddress { x: 2, y: 0 }), None);
    }

    #[test]
    fn parse_passes_cell_address_test() {
        // x is the column index and y the row index of the cell being converted.
        let matrix = Matrix::parse_matrix("a,b,c|d,e,f", ",", "|", |_, address| {
            address.x * 10 + address.y
        })
        .expect("test data is a well-formed 3x2 matrix");
        assert_eq!(matrix.get(&MatrixAddress { x: 2, y: 0 }), Some(&20));
        assert_eq!(matrix.get(&MatrixAddress { x: 0, y: 1 }), Some(&1));
    }

    #[test]
    fn get_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let matrix = Matrix::new(bound.clone(), |address| bound.index_address(&address));
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                bound.index_address(&address),
                *matrix.get(&address).unwrap()
            )
        })
    }

    #[test]
    fn set_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let mut matrix = Matrix::new(bound.clone(), |_address| 0usize);
        matrix.address_iterator().for_each(|address| {
            assert_eq!(matrix.get(&address).unwrap(), &0usize);
            matrix
                .set(&address, bound.index_address(&address).unwrap())
                .expect("Index out of bounds error");
            assert_eq!(
                matrix.get(&address).unwrap(),
                &(bound.index_address(&address).unwrap())
            );
        });
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                bound.index_address(&address).unwrap(),
                *(matrix.get(&address).unwrap())
            )
        })
    }

    #[test]
    fn out_of_bounds_access_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 2, y: 2 });
        let mut matrix = Matrix::new(bound, |_address| 0i32);
        let outside = MatrixAddress { x: 3, y: 0 };
        assert!(matrix.get(&outside).is_none());
        assert!(matrix.get_mut(&outside).is_none());
        let error = matrix
            .set(&outside, 1)
            .expect_err("writing outside the bounds must fail");
        assert_eq!(error.kind(), ErrorKind::InvalidInput);
        assert!(matrix.get(&MatrixAddress { x: -1, y: 0 }).is_none());
    }

    #[test]
    fn transform_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap()
        });
        let transformed_matrix: Matrix<f64> = matrix.transform(|value| *value as f64);
        let transformed_by_address_matrix: Matrix<f64> =
            matrix.transform_by_address(|_address, value| *value as f64);
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                *matrix.get(&address).unwrap() as f64,
                *transformed_matrix.get(&address).unwrap(),
            );
            assert_eq!(
                *matrix.get(&address).unwrap() as f64,
                *transformed_by_address_matrix.get(&address).unwrap()
            );
        })
    }

    #[test]
    fn transform_in_place_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let original_matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i64
        });
        let mut working_matrix = original_matrix.clone();
        let mut working_by_address_matrix = original_matrix.clone();

        working_matrix.transform_in_place(|value| *value *= -2);
        working_by_address_matrix.transform_by_address_in_place(|_address, value| *value *= -2);
        assert_eq!(
            original_matrix.transform::<_, i64, Matrix<i64>>(|value| *value * -2),
            working_matrix
        );
        assert_eq!(working_matrix, working_by_address_matrix)
    }

    #[test]
    fn equality_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let m1 = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        let m2 = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        assert_eq!(m1, m2);
    }
}