Skip to main content

rust_tensors/
matrix.rs

1use crate::matrix_address::MatrixAddress;
2use crate::tensor::Tensor;
3use std::fmt::{Display, Formatter};
4use std::ops::{Index, IndexMut};
5
/// A dense, two-dimensional grid of values.
///
/// Values live in one flat `Vec` in row-major order: the value at `(x, y)` is stored at
/// offset `y * width + x`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Matrix<T> {
    /// Number of columns.
    width: usize,
    /// Number of rows.
    height: usize,
    /// The `width * height` values, stored row by row.
    data: Vec<T>,
}
12
13impl<T> Matrix<T> {
14    /// Creates a new Matrix based on dimensions and a mapper function.
15    ///
16    /// # Arguments
17    ///
18    /// * `width`: The width, or number of columns in the matrix
19    /// * `height`: The height, or number of rows in the matrix
20    /// * `address_value_converter`: Converts a matrix address to a value.
21    ///
22    /// Returns: Matrix<T>
23    ///
24    /// # Examples
25    ///
26    /// ```
27    /// use rust_tensors::matrix::Matrix;
28    /// use rust_tensors::tensor::Tensor;
29    ///
30    /// // Creates a 1000x1000 zero matrix
31    /// let (width, height) = (1000, 1000);
32    /// let mut matrix = Matrix::new(width, height, |_address| 0usize);
33    /// matrix.address_iter()
34    ///     .for_each(|address| assert_eq!(matrix[address], 0));
35    ///
36    /// // Creates a 50x10 matrix where the value is the index of the array
37    /// let (width, height) = (1000, 1000);
38    /// let mut matrix = Matrix::new(width, height, |address| address.y * width as i32 + address.x);
39    /// matrix.address_iter()
40    ///     .for_each(|address| assert_eq!(matrix[address], address.y * width as i32 + address.x));
41    /// ```
42    pub fn new<F>(width: usize, height: usize, address_value_converter: F) -> Self
43    where
44        F: Fn(MatrixAddress) -> T,
45    {
46        let mut matrix = Matrix {
47            width,
48            height,
49            data: Vec::<T>::with_capacity(width * height),
50        };
51        matrix
52            .address_iter()
53            .for_each(|address| matrix.data.push(address_value_converter(address)));
54        matrix
55    }
56
57    /// Makes a string fit for displaying the contents of the matrix
58    ///
59    /// # Arguments
60    ///
61    /// * `display_func`: Converts a value to a string
62    /// * `row_delimiter`: Separates the rows in the matrix
63    /// * `column_delimiter`: Separates the columns in the matrix
64    ///
65    /// Returns: the formatted string
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use rust_tensors::matrix::Matrix;
71    /// let mut matrix =
72    /// Matrix::<i32>::parse_matrix("1 2 3|4 5 6|7 8 9", " ", "|", |s| s.parse().unwrap())
73    ///     .unwrap();
74    /// assert_eq!(
75    ///     matrix.to_display_string(|i| i.to_string(), "-", "|"),
76    ///     "1-2-3|4-5-6|7-8-9"
77    /// );
78    /// ```
79    pub fn to_display_string<T1: Display, F: Fn(&T) -> T1>(
80        &self,
81        display_func: F,
82        row_delimiter: &str,
83        column_delimiter: &str,
84    ) -> String {
85        self.address_iter()
86            .enumerate()
87            .map(|(i, address)| {
88                format!(
89                    "{}{}",
90                    display_func(&self[address]),
91                    if (i + 1) % (self.width) == 0 {
92                        if i != self.width * self.height - 1 {
93                            column_delimiter
94                        } else {
95                            ""
96                        }
97                    } else {
98                        row_delimiter
99                    }
100                )
101            })
102            .fold("".to_string(), |a: String, b: String| a + &b)
103    }
104
105    /// Parses a matrix from a string.
106    /// Fallible, and will return an Err if the matrix cannot be parsed,
107    /// or if the matrix does not have a uniform row length
108    ///
109    /// # Arguments
110    ///
111    /// * `data_str`: The string to be parsed
112    /// * `column_delimiter`: The string which separates the items in the columns
113    /// * `row_delimiter`: The string which separates the rows
114    /// * `str_to_t_converter`: The function which converts the item strings to a value
115    ///
116    /// Returns: Result<Matrix<T>, String>
117    ///
118    /// # Examples
119    ///
120    /// ```
121    /// use rust_tensors::matrix::Matrix;
122    ///
123    /// let mut matrix =
124    ///     Matrix::<i32>::parse_matrix("0 1 2|3 4 5|6 7 8", " ", "|", |s| s.parse().unwrap())
125    ///         .unwrap();
126    ///
127    /// assert_eq!(
128    ///     matrix, Matrix::new(3, 3, |address| address.x + 3 * address.y)
129    /// );
130    /// ```
131    pub fn parse_matrix<F>(
132        data_str: &str,
133        column_delimiter: &str,
134        row_delimiter: &str,
135        str_to_t_converter: F,
136    ) -> Result<Matrix<T>, String>
137    where
138        F: Fn(&str) -> T,
139    {
140        let values: Vec<Vec<&str>> = data_str
141            .split(row_delimiter)
142            .map(|row| {
143                row.split(column_delimiter)
144                    .filter(|string| !string.is_empty())
145                    .collect()
146            })
147            .filter(|row: &Vec<&str>| !row.is_empty())
148            .collect();
149        if values
150            .iter()
151            .skip(1)
152            .any(|row| row.len() != values.first().unwrap().len())
153        {
154            return Err("Row Lengths are not constant".into());
155        }
156        let height = values.len();
157        let width = values.first().unwrap().len();
158
159        Ok(Matrix::new(width, height, |address| {
160            str_to_t_converter(values[address.y as usize][address.x as usize])
161        }))
162    }
163    fn index_address(&self, address: MatrixAddress) -> usize {
164        address.y as usize * self.width + address.x as usize
165    }
166}
167
168impl<T> Tensor<T, i32, MatrixAddress, 2> for Matrix<T> {
169    fn smallest_contained_address(&self) -> MatrixAddress {
170        MatrixAddress { x: 0, y: 0 }
171    }
172
173    fn largest_contained_address(&self) -> MatrixAddress {
174        MatrixAddress {
175            x: (self.width - 1) as i32,
176            y: (self.height - 1) as i32,
177        }
178    }
179}
180
181impl<T: Display> Display for Matrix<T> {
182    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
183        write!(
184            f,
185            "{}",
186            self.to_display_string(|t| t.to_string(), " ", "\n")
187        )
188    }
189}
190
191impl<T> Index<MatrixAddress> for Matrix<T> {
192    type Output = T;
193
194    fn index(&self, index: MatrixAddress) -> &Self::Output {
195        &self.data[self.index_address(index)]
196    }
197}
198
199impl<T> Index<(i32, i32)> for Matrix<T> {
200    type Output = T;
201
202    fn index(&self, index: (i32, i32)) -> &Self::Output {
203        &self[MatrixAddress {
204            x: index.0,
205            y: index.1,
206        }]
207    }
208}
209
210impl<T> IndexMut<MatrixAddress> for Matrix<T> {
211    fn index_mut(&mut self, index: MatrixAddress) -> &mut Self::Output {
212        let index = self.index_address(index);
213        &mut self.data[index]
214    }
215}
216
217impl<T> IndexMut<(i32, i32)> for Matrix<T> {
218    fn index_mut(&mut self, index: (i32, i32)) -> &mut Self::Output {
219        &mut self[MatrixAddress {
220            x: index.0,
221            y: index.1,
222        }]
223    }
224}
225
#[cfg(test)]
mod tests {
    use crate::address_iterator::AddressIterator;
    use crate::matrix::Matrix;
    use crate::matrix_address::MatrixAddress;
    use crate::tensor::Tensor;
    use proptest::proptest;
    use std::str::FromStr;

    #[test]
    fn display_test() {
        let (width, height) = (11, 11);
        assert_eq!(
            "0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1\n2 3 4 5 6 0 1 2 3 4 5\n6 0 1 2 3 4 5 6 0 1 2\n3 4 5 6 0 1 2 3 4 5 6\n0 1 2 3 4 5 6 0 1 2 3\n4 5 6 0 1 2 3 4 5 6 0\n1 2 3 4 5 6 0 1 2 3 4\n5 6 0 1 2 3 4 5 6 0 1",
            format!(
                "{}",
                Matrix::new(width, height, |address: MatrixAddress| {
                    (address.x as usize + address.y as usize * width) % 7
                })
            )
        )
    }

    #[test]
    fn set_test() {
        let (width, height) = (1000, 1000);
        let mut matrix = Matrix::new(width, height, |_address| 0usize);
        matrix.address_iter().for_each(|address| {
            assert_eq!(matrix[address], 0usize);
            matrix[address] = matrix.index_address(address);
            assert_eq!(matrix[address], matrix.index_address(address));
        });
        matrix
            .address_iter()
            .for_each(|address| assert_eq!(matrix.index_address(address), matrix[address]))
    }

    #[test]
    fn get_test() {
        let (width, height) = (1000, 1000);
        let matrix = Matrix::new(width, height, |address| {
            address.x as usize + address.y as usize * width
        });
        assert_eq!(matrix.index_address(MatrixAddress { x: 999, y: 0 }), 999);
        assert_eq!(matrix.index_address(MatrixAddress { x: 0, y: 1 }), 1000);
        assert_eq!(matrix.index_address(MatrixAddress { x: 1, y: 1 }), 1001);
        matrix
            .address_iter()
            .for_each(|address| assert_eq!(matrix.index_address(address), matrix[address]))
    }

    #[test]
    fn parse_test() {
        let data_str = "0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1|2,3,4,5,6,0,1,2,3,4,5|6,0,1,2,3,4,5,6,0,1,2|3,4,5,6,0,1,2,3,4,5,6|0,1,2,3,4,5,6,0,1,2,3|4,5,6,0,1,2,3,4,5,6,0|1,2,3,4,5,6,0,1,2,3,4|5,6,0,1,2,3,4,5,6,0,1";
        let (width, height) = (11, 11);
        assert_eq!(
            Matrix::new(width, height, |address: MatrixAddress| (address.y
                * width as i32
                + address.x)
                % 7),
            Matrix::parse_matrix(data_str, ",", "|", |string| i32::from_str(string)
                .expect("every cell of the test data is a valid i32"))
            .expect("the test data has uniform row lengths")
        );
    }

    #[test]
    fn parse_skips_empty_cells_and_rows() {
        // A doubled column delimiter and a trailing row delimiter must not create cells/rows.
        let parsed = Matrix::parse_matrix("1  2|3 4|", " ", "|", |s| {
            s.parse::<i32>().expect("every cell is a valid i32")
        })
        .expect("both rows have two cells");
        assert_eq!(parsed, Matrix::new(2, 2, |address| address.y * 2 + address.x + 1));
    }

    #[test]
    fn parse_rejects_ragged_rows() {
        let result = Matrix::parse_matrix("1 2|3", " ", "|", |s| {
            s.parse::<i32>().expect("every cell is a valid i32")
        });
        assert_eq!(result, Err("Row Lengths are not constant".to_string()));
    }

    #[test]
    fn display_round_trips_through_parse() {
        let matrix = Matrix::new(4, 3, |address| address.y * 4 + address.x);
        let text = matrix.to_string();
        let parsed = Matrix::parse_matrix(&text, " ", "\n", |s| {
            s.parse::<i32>().expect("display output contains only integers")
        })
        .expect("display output has uniform row lengths");
        assert_eq!(matrix, parsed);
    }

    #[test]
    #[should_panic]
    fn negative_index_panics() {
        let matrix = Matrix::new(3, 3, |_address| 0u8);
        let _value: u8 = matrix[(-1, 0)];
    }

    #[test]
    fn equality_test() {
        let (width, height) = (100, 200);
        let mut m1 = Matrix::new(width, height, |address| {
            address.y * width as i32 + address.x
        });
        let m2 = Matrix::new(width, height, |address| {
            address.y * width as i32 + address.x
        });
        assert_eq!(m1, m2);
        for address in m1.address_iter() {
            assert_eq!(m1, m2);
            m1[address] += 1;
            assert_ne!(m1, m2);
            m1[address] -= 1;
        }
    }

    #[test]
    fn address_iterator_test() {
        let iter: AddressIterator<_, MatrixAddress, 2> = AddressIterator::new([0, 0], [2, 4]);
        let values = iter
            .map(|address| (address.x, address.y))
            .collect::<Vec<(i32, i32)>>();
        assert_eq!(
            values,
            vec![
                (0, 0),
                (1, 0),
                (2, 0),
                (0, 1),
                (1, 1),
                (2, 1),
                (0, 2),
                (1, 2),
                (2, 2),
                (0, 3),
                (1, 3),
                (2, 3),
                (0, 4),
                (1, 4),
                (2, 4),
            ]
        )
    }

    proptest! {
        #[test]
        fn address_sugar_test(x in 0..100, y in 0..200) {
            let matrix = Matrix::new(100, 200, |address| address.y * 100 + address.x);
            let mut mut_matrix = matrix.clone();
            let pos_tuple = (x, y);
            let pos_address = MatrixAddress{x, y};
            assert_eq!(matrix[pos_tuple], matrix[pos_address]);
            let temp = matrix[pos_tuple];
            mut_matrix[pos_address] = -1;
            mut_matrix[pos_tuple] = temp;
            assert_eq!(matrix, mut_matrix);
        }
    }
}