rust_tensors/
matrix.rs

1use crate::address_bound::AddressBound;
2use crate::matrix_address::MatrixAddress;
3use crate::tensor::Tensor;
4use std::error::Error;
5use std::fmt::{Display, Formatter};
6use std::io::ErrorKind;
7use std::io::ErrorKind::InvalidInput;
8use std::str::FromStr;
9
10/// A tensor of two dimensions accessed using MatrixAddress.
11#[derive(Clone, Debug, PartialEq)]
12pub struct Matrix<T> {
13    data: Vec<T>,
14    bounds: AddressBound<MatrixAddress>,
15}
16
17impl<T> Tensor<T, MatrixAddress> for Matrix<T> {
18    fn new<F>(bounds: AddressBound<MatrixAddress>, address_value_converter: F) -> Matrix<T>
19    where
20        F: Fn(MatrixAddress) -> T,
21    {
22        let data: Vec<T> = bounds.iter().map(address_value_converter).collect();
23        Matrix { data, bounds }
24    }
25
26    fn get(&self, address: &MatrixAddress) -> Option<&T> {
27        if !self.bounds.contains_address(address) {
28            return None;
29        }
30        self.data.get(self.bounds.index_address(address).unwrap())
31    }
32
33    fn get_mut(&mut self, address: &MatrixAddress) -> Option<&mut T> {
34        if !self.bounds.contains_address(address) {
35            return None;
36        }
37        self.data
38            .get_mut(self.bounds.index_address(address).unwrap())
39    }
40
41    fn set(&mut self, address: &MatrixAddress, value: T) -> Result<(), std::io::Error> {
42        if !self.bounds.contains_address(address) {
43            return Err(std::io::Error::new(
44                InvalidInput,
45                format!("The following address is out of bounds: {address}"),
46            ));
47        }
48        self.data[self.bounds.index_address(address).unwrap()] = value;
49        Ok(())
50    }
51
52    fn bounds(&self) -> &AddressBound<MatrixAddress> {
53        &self.bounds
54    }
55}
56
57impl<T> Matrix<T> {
58    pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
59        &self,
60        display_func: F,
61        row_delimiter: &str,
62        column_delimiter: &str,
63    ) -> String {
64        self.address_iterator()
65            .enumerate()
66            .map(|(i, address)| {
67                format!(
68                    "{}{}",
69                    display_func(self.get(&address).unwrap()),
70                    if (i as i64 + 1) % (self.bounds.largest_possible_position.x + 1) == 0 {
71                        column_delimiter
72                    } else {
73                        row_delimiter
74                    }
75                )
76            })
77            .fold("".to_string(), |a: String, b: String| a + &b)
78    }
79
80    pub fn parse_matrix<F>(
81        data_str: &str,
82        column_delimiter: &str,
83        row_delimiter: &str,
84        str_to_t_converter: F,
85    ) -> Result<Matrix<T>, std::io::Error>
86    where
87        F: Fn(&str) -> T,
88    {
89        let values: Vec<Vec<&str>> = data_str
90            .split(row_delimiter)
91            .map(|row| row.split(column_delimiter).collect())
92            .collect();
93        if values
94            .iter()
95            .skip(1)
96            .any(|row| row.len() != values.get(0).unwrap().len())
97        {
98            return Err(std::io::Error::new(
99                ErrorKind::InvalidData,
100                "Row Lengths are not constant",
101            ));
102        }
103        let height = values.len();
104        let width = values.get(0).unwrap().len();
105        let matrix_bounds = AddressBound {
106            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
107            largest_possible_position: MatrixAddress {
108                x: (width - 1) as i64,
109                y: (height - 1) as i64,
110            },
111        };
112
113        Ok(Matrix::new(matrix_bounds, |address| {
114            str_to_t_converter(
115                values
116                    .get(address.y as usize)
117                    .unwrap()
118                    .get(address.x as usize)
119                    .unwrap(),
120            )
121        }))
122    }
123}
124
125impl<'a, T: Display + 'a> Display for Matrix<T> {
126    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
127        write!(
128            f,
129            "{}",
130            self.to_display_string(|x: &T| x.to_string(), " ", "\n")
131        )
132    }
133}
134
#[cfg(test)]
mod tests {
    use crate::address_bound::AddressBound;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;
    use std::io::ErrorKind;

    /// Converts one cell of the integer fixtures below, naming the bad cell on failure.
    fn parse_usize(cell: &str) -> usize {
        cell.parse()
            .unwrap_or_else(|error| panic!("fixture cell {cell:?} is not a usize: {error}"))
    }

    #[test]
    fn display_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        assert_eq!(
            "0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n2 3 4 5 6 0 1 2 3 4 5\n6 0 1 2 3 4 5 6 0 1 2\n3 4 5 6 0 1 2 3 4 5 6\n0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n",
            format!(
                "{}",
                Matrix::new(bound.clone(), |address: MatrixAddress| bound
                    .index_address(&address)
                    .unwrap()
                    % 7)
            )
        )
    }

    #[test]
    fn parse_test() {
        let data_str = "0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1|2,3,4,5,6,0,1,2,3,4,5|6,0,1,2,3,4,5,6,0,1,2|3,4,5,6,0,1,2,3,4,5,6|0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1";
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        assert_eq!(
            Matrix::new(bound.clone(), |address: MatrixAddress| bound
                .index_address(&address)
                .unwrap()
                % 7),
            Matrix::parse_matrix(data_str, ",", "|", parse_usize)
                .expect("fixture rows all have the same length")
        );
    }

    #[test]
    fn parse_ragged_rows_test() {
        let error = Matrix::parse_matrix("1,2|3", ",", "|", parse_usize)
            .expect_err("rows of different length must be rejected");
        assert_eq!(ErrorKind::InvalidData, error.kind());
    }

    #[test]
    fn parse_non_square_round_trip_test() {
        let data_str = "1,2,3|4,5,6";
        let matrix = Matrix::parse_matrix(data_str, ",", "|", parse_usize)
            .expect("fixture rows all have the same length");
        assert_eq!(Some(&3), matrix.get(&MatrixAddress { x: 2, y: 0 }));
        assert_eq!(Some(&4), matrix.get(&MatrixAddress { x: 0, y: 1 }));
        assert_eq!(None, matrix.get(&MatrixAddress { x: 0, y: 2 }));
        // `to_display_string` terminates every row, including the last, with the row delimiter.
        assert_eq!(
            format!("{data_str}|"),
            matrix.to_display_string(|value| *value, ",", "|")
        );
    }

    #[test]
    fn get_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let matrix = Matrix::new(bound.clone(), |address| bound.index_address(&address));
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                bound.index_address(&address),
                *matrix.get(&address).unwrap()
            )
        })
    }

    #[test]
    fn get_mut_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 2, y: 2 });
        let mut matrix = Matrix::new(bound, |_address| 0i32);
        let address = MatrixAddress { x: 1, y: 2 };
        *matrix.get_mut(&address).expect("address is inside the bounds") = 5;
        assert_eq!(Some(&5), matrix.get(&address));
    }

    #[test]
    fn out_of_bounds_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 2, y: 2 });
        let mut matrix = Matrix::new(bound, |_address| 0usize);
        for outside in [MatrixAddress { x: 3, y: 0 }, MatrixAddress { x: 0, y: -1 }] {
            assert_eq!(None, matrix.get(&outside));
            assert!(matrix.get_mut(&outside).is_none());
            let error = matrix
                .set(&outside, 1)
                .expect_err("writing outside the bounds must fail");
            assert_eq!(ErrorKind::InvalidInput, error.kind());
        }
    }

    #[test]
    fn set_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let mut matrix = Matrix::new(bound.clone(), |_address| 0usize);
        matrix.address_iterator().for_each(|address| {
            assert_eq!(matrix.get(&address).unwrap(), &0usize);
            matrix
                .set(&address, bound.index_address(&address).unwrap())
                .expect("Index out of bounds error");
            assert_eq!(
                matrix.get(&address).unwrap(),
                &(bound.index_address(&address).unwrap())
            );
        });
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                bound.index_address(&address).unwrap(),
                *(matrix.get(&address).unwrap())
            )
        })
    }

    #[test]
    fn transform_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap()
        });
        let transformed_matrix: Matrix<f64> = matrix.transform(|value| *value as f64);
        let transformed_by_address_matrix: Matrix<f64> =
            matrix.transform_by_address(|_address, value| *value as f64);
        matrix.address_iterator().for_each(|address| {
            assert_eq!(
                *matrix.get(&address).unwrap() as f64,
                *transformed_matrix.get(&address).unwrap(),
            );
            assert_eq!(
                *matrix.get(&address).unwrap() as f64,
                *transformed_by_address_matrix.get(&address).unwrap()
            );
        })
    }

    #[test]
    fn transform_in_place_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let original_matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i64
        });
        let mut working_matrix = original_matrix.clone();
        let mut working_by_address_matrix = original_matrix.clone();

        working_matrix.transform_in_place(|value| *value *= -2);
        working_by_address_matrix.transform_by_address_in_place(|_address, value| *value *= -2);
        assert_eq!(
            original_matrix.transform::<_, i64, Matrix<i64>>(|value| *value * -2),
            working_matrix
        );
        assert_eq!(working_matrix, working_by_address_matrix)
    }

    #[test]
    fn equality_test() {
        let bound = AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        };
        let m1 = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        let m2 = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i32
        });
        assert_eq!(m1, m2);
    }
}