rustrix/matrix/
mod.rs

1use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};
2
/// Nested vector used as the backing storage of a [`Matrix`]: the outer
/// vector holds the rows, each inner vector holds one row's cells.
type Vec2D<T> = Vec<Vec<T>>;

/// A matrix stored as a vector of rows.
///
/// The element type is intentionally unconstrained on the type itself; the
/// arithmetic bounds belong on the `impl` blocks that actually need them, so
/// the type can be named, cloned, printed and compared for any `T`.
///
/// The inner vector is public, so nothing enforces that it is non-empty or
/// that every row has the same length.
// TODO: How to ensure the vector is not empty?
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix<T>(pub Vec2D<T>);
13
/// Builds a `Matrix` either from dimensions or from literal rows.
///
/// Three forms are accepted:
///
/// * `mx!(rows, cols)` — a `rows` x `cols` matrix filled with the integer
///   literal `0`.
/// * `mx!(rows, cols; fill)` — a `rows` x `cols` matrix where every cell is
///   `fill`. The fill value is used as-is, so any element type works
///   (integers, floats, ...).
/// * `mx![a, b; c, d;]` — literal rows separated by `;`, cells separated by
///   `,`. The trailing `;` is optional.
///
/// The expansion calls `Matrix::from`, so `Matrix` must be in scope at the
/// call site.
///
/// Note that macro delimiters do not take part in matching: `mx![2, 3]`
/// matches the `rows, cols` form and yields a 2 x 3 zero matrix. To write the
/// single row `[2, 3]`, add the trailing semicolon: `mx![2, 3;]`.
///
/// ```
/// use rustrix::*;
///
/// // Both macro invocations build the same results.
/// let (rows, cols, init) = (2, 3, 1);
/// let m1 = mx!(rows, cols; init);
/// let m2 = mx![
///     1, 1, 1;
///     1, 1, 1;
/// ];
///
/// // The fill value may be of any element type, e.g. a float.
/// let m3 = mx!(2, 2; 0.5);
/// ```
#[macro_export]
macro_rules! mx {
    // Dimensions only: zero-filled.
    ($r:expr, $c:expr) => {
        Matrix::from(vec![vec![0; $c]; $r])
    };
    // Dimensions plus fill value. The value is used directly instead of
    // `0 + value`, which would only type-check for integer fills.
    ($r:expr, $c:expr; $v:expr) => {
        Matrix::from(vec![vec![$v; $c]; $r])
    };
    // Literal rows.
    [$($($v:expr),+);+ $(;)?] => {
        Matrix::from(vec![$(vec![$($v),+]),+])
    };
}
34
35
36impl<T> From<Vec2D<T>> for Matrix<T>
37where
38    T: Clone + Copy
39        + Add<Output = T> + Sub<Output = T> + Mul<Output = T>
40        + AddAssign + SubAssign
41        + From<i32>,
42{
43    fn from(v: Vec2D<T>) -> Self {
44        Matrix(v)
45    }
46}
47
48impl<T> Matrix<T>
49where
50    T: Clone + Copy
51        + Add<Output = T> + Sub<Output = T> + Mul<Output = T>
52        + AddAssign + SubAssign
53        + From<i32>,
54{
55    /// Returns the number of rows in the matrix.
56    pub fn rows(&self) -> usize {
57        self.0.len()
58    }
59
60    /// Returns the number of columns in the matrix.
61    pub fn cols(&self) -> usize {
62        self.0[0].len()
63    }
64
65    /// Returns the value at given row, column.
66    pub fn get(&self, row: usize, col: usize) -> T {
67        self.0[row][col]
68    }
69
70    /// Sets the value at given row, column.
71    pub fn set(&mut self, row: usize, col: usize, value: T) {
72        self.0[row][col] = value;
73    }
74}
75
76mod ops;
77
#[cfg(test)]
mod tests {
    use super::*;

    /// The `rows, cols; fill` form repeats the fill value in every cell.
    #[test]
    fn test_macro_1() {
        let filled = mx!(2, 3; 1);
        let expected = Matrix(vec![vec![1, 1, 1], vec![1, 1, 1]]);

        assert_eq!(filled, expected);
    }

    /// Writing one cell leaves its neighbour in the same row untouched.
    #[test]
    fn test_macro_1_2() {
        let mut grid = mx!(2, 3; 0);
        grid.0[0][0] = 1;

        let (written, neighbour) = (grid.0[0][0], grid.0[0][1]);
        assert_ne!(written, neighbour);
    }

    /// The literal-rows form produces one inner vector per `;`-terminated row.
    #[test]
    fn test_macro_2() {
        let literal = mx![
            0, 0, 0;
            0, 0, 0;
        ];
        let expected = Matrix(vec![vec![0, 0, 0], vec![0, 0, 0]]);

        assert_eq!(literal, expected);
    }

    /// Transposing a 2 x 3 matrix yields the matching 3 x 2 matrix.
    #[test]
    fn test_transpose() {
        let source = mx![
            1, 2, 3;
            4, 5, 6;
        ];
        let expected = mx![
            1, 4;
            2, 5;
            3, 6;
        ];

        let transposed = source.transpose();
        assert_eq!(transposed, expected);
    }

    /// Element-wise addition of integer matrices.
    #[test]
    fn test_add_i32() {
        let lhs = mx![
            1, 1;
            1, 1;
        ];
        let rhs = mx![
            1, 2;
            3, 4;
        ];
        let expected = mx![
            2, 3;
            4, 5;
        ];

        let sum = lhs + rhs;
        assert_eq!(sum, expected);
    }

    /// Element-wise addition of floating-point matrices.
    #[test]
    fn test_add_f64() {
        let lhs = mx![
            1.0, 1.0;
            1.0, 1.0;
        ];
        let rhs = mx![
            1.0, 2.0;
            3.0, 4.0;
        ];
        let expected = mx![
            2.0, 3.0;
            4.0, 5.0;
        ];

        let sum = lhs + rhs;
        assert_eq!(sum, expected);
    }

    /// Element-wise subtraction of single-row matrices.
    #[test]
    fn test_sub() {
        let minuend = mx![
            3, 2, 1;
        ];
        let subtrahend = mx![
            1, 1, 1;
        ];
        let expected = mx![
            2, 1, 0;
        ];

        let difference = minuend - subtrahend;
        assert_eq!(difference, expected);
    }

    /// Matrix product of a 2 x 3 and a 3 x 3 integer matrix.
    #[test]
    fn test_mul_i32() {
        let lhs = mx![
            1, 1, 1;
            1, 1, 1;
        ];
        let rhs = mx![
            2, 2, 2;
            2, 2, 2;
            2, 2, 2;
        ];
        let expected = mx![
            6, 6, 6;
            6, 6, 6;
        ];

        let product = lhs * rhs;
        assert_eq!(product, expected);
    }

    /// Matrix product of a 2 x 3 and a 3 x 3 floating-point matrix.
    #[test]
    fn test_mul_f64() {
        let lhs = mx![
            1.0, 1.0, 1.0;
            1.0, 1.0, 1.0;
        ];
        let rhs = mx![
            2.0, 2.0, 2.0;
            2.0, 2.0, 2.0;
            2.0, 2.0, 2.0;
        ];
        let expected = mx![
            6.0, 6.0, 6.0;
            6.0, 6.0, 6.0;
        ];

        let product = lhs * rhs;
        assert_eq!(product, expected);
    }

    /// Scaling by -1 negates every cell.
    #[test]
    fn test_mul_scalar() {
        let source = mx![
            1, 2, 3;
            4, 5, 6;
        ];
        let expected = mx![
            -1, -2, -3;
            -4, -5, -6;
        ];

        let scaled = source.mul_scalar(-1);
        assert_eq!(scaled, expected);
    }
}