rust_tensors/
matrix.rs

1use crate::address_bound::AddressBound;
2use crate::matrix_address::MatrixAddress;
3use crate::tensor::Tensor;
4use std::fmt::{Display, Formatter};
5use std::io::Error;
6use std::io::ErrorKind::InvalidInput;
7
8/// A tensor of two dimensions accessed using MatrixAddress.
9#[derive(Clone, Debug, PartialEq)]
10pub struct Matrix<T> {
11    data: Vec<T>,
12    bounds: AddressBound<MatrixAddress>,
13}
14
15impl<T> Tensor<T, MatrixAddress> for Matrix<T> {
16    fn new<F>(bounds: AddressBound<MatrixAddress>, address_value_converter: F) -> Matrix<T>
17    where
18        F: Fn(MatrixAddress) -> T,
19    {
20        let data: Vec<T> = bounds.iter().map(address_value_converter).collect();
21        Matrix { data, bounds }
22    }
23
24    fn get(&self, address: &MatrixAddress) -> Option<&T> {
25        if !self.bounds.contains_address(address) {
26            return None;
27        }
28        self.data.get(self.bounds.index_address(address).unwrap())
29    }
30
31    fn get_mut(&mut self, address: &MatrixAddress) -> Option<&mut T> {
32        if !self.bounds.contains_address(address) {
33            return None;
34        }
35        self.data
36            .get_mut(self.bounds.index_address(address).unwrap())
37    }
38
39    fn set(&mut self, address: &MatrixAddress, value: T) -> Result<(), Error> {
40        if !self.bounds.contains_address(address) {
41            return Err(Error::new(
42                InvalidInput,
43                format!("The following address is out of bounds: {address}"),
44            ));
45        }
46        self.data[self.bounds.index_address(address).unwrap()] = value;
47        Ok(())
48    }
49
50    fn bounds(&self) -> &AddressBound<MatrixAddress> {
51        &self.bounds
52    }
53}
54
55impl<T> Matrix<T> {
56    pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
57        &self,
58        display_func: F,
59        row_delimiter: &str,
60        column_delimiter: &str,
61    ) -> String {
62        self.address_iterator()
63            .enumerate()
64            .map(|(i, address)| {
65                format!(
66                    "{}{}",
67                    display_func(self.get(&address).unwrap()),
68                    if (i as i64 + 1) % (self.bounds.largest_possible_position.x + 1) == 0 {
69                        column_delimiter
70                    } else {
71                        row_delimiter
72                    }
73                )
74            })
75            .fold("".to_string(), |a: String, b: String| a + &b)
76    }
77}
78
79impl<'a, T: Display + 'a> Display for Matrix<T> {
80    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
81        write!(
82            f,
83            "{}",
84            self.to_display_string(|x: &T| x.to_string(), " ", "\n")
85        )
86    }
87}
88
#[cfg(test)]
mod tests {
    use crate::address_bound::AddressBound;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;

    /// Bound spanning (0, 0) to (1000, 1000), shared by the tests that need a large matrix.
    fn large_bound() -> AddressBound<MatrixAddress> {
        AddressBound {
            smallest_possible_position: MatrixAddress { x: 0, y: 0 },
            largest_possible_position: MatrixAddress { x: 1000, y: 1000 },
        }
    }

    /// The `Display` output puts a space between elements of a row and a newline after each row.
    #[test]
    fn display_test() {
        let bound = AddressBound::new(MatrixAddress { x: 0, y: 0 }, MatrixAddress { x: 10, y: 10 });
        let matrix = Matrix::new(bound.clone(), |address: MatrixAddress| {
            bound.index_address(&address).unwrap() % 7
        });
        let expected = "0 1 2 3 4 5 6 0 1 2 3\n\
                        4 5 6 0 1 2 3 4 5 6 0\n\
                        1 2 3 4 5 6 0 1 2 3 4\n\
                        5 6 0 1 2 3 4 5 6 0 1\n\
                        2 3 4 5 6 0 1 2 3 4 5\n\
                        6 0 1 2 3 4 5 6 0 1 2\n\
                        3 4 5 6 0 1 2 3 4 5 6\n\
                        0 1 2 3 4 5 6 0 1 2 3\n\
                        4 5 6 0 1 2 3 4 5 6 0\n\
                        1 2 3 4 5 6 0 1 2 3 4\n\
                        5 6 0 1 2 3 4 5 6 0 1\n";
        assert_eq!(expected, format!("{}", matrix));
    }

    /// Every address reads back the value that was generated for it at construction.
    #[test]
    fn get_test() {
        let bound = large_bound();
        let matrix = Matrix::new(bound.clone(), |address| bound.index_address(&address));
        for address in matrix.address_iterator() {
            let expected = bound.index_address(&address);
            assert_eq!(expected, *matrix.get(&address).unwrap());
        }
    }

    /// Values written with `set` are visible immediately and still present after all writes.
    #[test]
    fn set_test() {
        let bound = large_bound();
        let mut matrix = Matrix::new(bound.clone(), |_address| 0usize);

        for address in matrix.address_iterator() {
            assert_eq!(matrix.get(&address).unwrap(), &0usize);
            let index = bound.index_address(&address).unwrap();
            matrix
                .set(&address, index)
                .expect("Index out of bounds error");
            assert_eq!(matrix.get(&address).unwrap(), &index);
        }

        for address in matrix.address_iterator() {
            let expected = bound.index_address(&address).unwrap();
            assert_eq!(expected, *(matrix.get(&address).unwrap()));
        }
    }

    /// `transform` and `transform_by_address` both map every element to a new matrix.
    #[test]
    fn transform_test() {
        let bound = large_bound();
        let matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap()
        });

        let transformed_matrix: Matrix<f64> = matrix.transform(|value| *value as f64);
        let transformed_by_address_matrix: Matrix<f64> =
            matrix.transform_by_address(|_address, value| *value as f64);

        for address in matrix.address_iterator() {
            let expected = *matrix.get(&address).unwrap() as f64;
            assert_eq!(expected, *transformed_matrix.get(&address).unwrap());
            assert_eq!(
                expected,
                *transformed_by_address_matrix.get(&address).unwrap()
            );
        }
    }

    /// The in-place transforms agree with each other and with the copying `transform`.
    #[test]
    fn transform_in_place_test() {
        let bound = large_bound();
        let original_matrix = Matrix::new(bound.clone(), |address| {
            bound.index_address(&address).unwrap() as i64
        });
        let mut working_matrix = original_matrix.clone();
        let mut working_by_address_matrix = original_matrix.clone();

        working_matrix.transform_in_place(|value| *value *= -2);
        working_by_address_matrix.transform_by_address_in_place(|_address, value| *value *= -2);

        let expected = original_matrix.transform::<_, i64, Matrix<i64>>(|value| *value * -2);
        assert_eq!(expected, working_matrix);
        assert_eq!(working_matrix, working_by_address_matrix);
    }

    /// Two matrices built from the same bounds and generator compare equal.
    #[test]
    fn equality_test() {
        let bound = large_bound();
        let fill = |address: MatrixAddress| bound.index_address(&address).unwrap() as i32;
        let left = Matrix::new(bound.clone(), fill);
        let right = Matrix::new(bound.clone(), fill);
        assert_eq!(left, right);
    }
}