1use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};
2
/// Row-major 2-D storage: an outer `Vec` of rows, each row a `Vec<T>`.
/// Nothing in this type enforces that all rows have the same length.
type Vec2D<T> = Vec<Vec<T>>;
4
/// A dense matrix stored row by row, so element (`row`, `col`) is `self.0[row][col]`.
///
/// The wrapped `Vec<Vec<T>>` is public, so callers may build or inspect the
/// storage directly; nothing checks that all rows have the same length.
///
/// Arithmetic trait bounds are deliberately not placed on the struct itself:
/// they belong on the `impl` blocks that actually need them. That way a
/// `Matrix<T>` can be named, stored, cloned, printed and compared for any `T`
/// that supports the derived traits.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix<T>(pub Vec<Vec<T>>);
13
/// Builds a `Matrix` from a compact literal syntax.
///
/// Three forms are accepted:
///
/// * `mx!(rows, cols)`: a `rows × cols` matrix filled with the integer
///   literal `0`, so the element type must be an integer type.
/// * `mx!(rows, cols; value)`: a `rows × cols` matrix with every element set to
///   `value`. `value` is evaluated once and cloned into each cell. It is used
///   as-is, so any element type works (e.g. `1.5`).
/// * `mx![a, b, c; d, e, f;]`: an explicit matrix. Rows are separated by `;`
///   and elements by `,`; a trailing `;` is optional.
///
/// Arms are tried in that order. A single two-element row written without a
/// trailing `;` (`mx![1, 2]`) is therefore read as `mx!(1, 2)`, a 1×2 zero
/// matrix. Write `mx![1, 2;]` to get the row `[1, 2]` instead.
///
/// `Matrix` is referenced unqualified, so it must be in scope at the call site.
///
/// ```ignore
/// let zeros = mx!(2, 3);          // 2×3 of 0
/// let ones = mx!(2, 3; 1);        // 2×3 of 1
/// let halves = mx!(2, 2; 0.5);    // 2×2 of 0.5
/// let m = mx![1, 2, 3; 4, 5, 6;]; // explicit 2×3
/// ```
#[macro_export]
macro_rules! mx {
    ($r:expr, $c:expr) => {
        Matrix::from(vec![vec![0; $c]; $r])
    };
    // The fill value is used directly rather than as `0 + value`: the latter
    // only type-checked for integer values (`0 + 1.5` is an int/float mismatch).
    ($r:expr, $c:expr; $v:expr) => {
        Matrix::from(vec![vec![$v; $c]; $r])
    };
    [$($($v:expr),+);+ $(;)?] => {
        Matrix::from(vec![$(vec![$($v),+]),+])
    };
}
34
35
36impl<T> From<Vec2D<T>> for Matrix<T>
37where
38 T: Clone + Copy
39 + Add<Output = T> + Sub<Output = T> + Mul<Output = T>
40 + AddAssign + SubAssign
41 + From<i32>,
42{
43 fn from(v: Vec2D<T>) -> Self {
44 Matrix(v)
45 }
46}
47
48impl<T> Matrix<T>
49where
50 T: Clone + Copy
51 + Add<Output = T> + Sub<Output = T> + Mul<Output = T>
52 + AddAssign + SubAssign
53 + From<i32>,
54{
55 pub fn rows(&self) -> usize {
57 self.0.len()
58 }
59
60 pub fn cols(&self) -> usize {
62 self.0[0].len()
63 }
64
65 pub fn get(&self, row: usize, col: usize) -> T {
67 self.0[row][col]
68 }
69
70 pub fn set(&mut self, row: usize, col: usize, value: T) {
72 self.0[row][col] = value;
73 }
74}
75
76mod ops;
77
#[cfg(test)]
mod tests {
    use super::*;

    /// The fill form `mx!(rows, cols; value)` puts `value` in every cell.
    #[test]
    fn test_macro_1() {
        let filled = mx!(2, 3; 1);
        let expected = Matrix(vec![vec![1, 1, 1], vec![1, 1, 1]]);
        assert_eq!(filled, expected);
    }

    /// Cells produced by the fill form are independent: writing one cell
    /// must leave the neighbouring cell in the same row untouched.
    #[test]
    fn test_macro_1_2() {
        let mut filled = mx!(2, 3; 0);
        filled.0[0][0] = 1;
        let (written, neighbour) = (filled.0[0][0], filled.0[0][1]);
        assert_ne!(written, neighbour);
    }

    /// The literal form separates rows with `;` and accepts a trailing `;`.
    #[test]
    fn test_macro_2() {
        let literal = mx![
            0, 0, 0;
            0, 0, 0;
        ];
        assert_eq!(literal, Matrix(vec![vec![0; 3]; 2]));
    }

    /// Transposing a 2×3 matrix yields the 3×2 matrix with rows and columns swapped.
    #[test]
    fn test_transpose() {
        let original = mx![
            1, 2, 3;
            4, 5, 6;
        ];
        let expected = mx![
            1, 4;
            2, 5;
            3, 6;
        ];
        assert_eq!(original.transpose(), expected);
    }

    /// Element-wise addition of integer matrices.
    #[test]
    fn test_add_i32() {
        let ones = mx![
            1, 1;
            1, 1;
        ];
        let counting = mx![
            1, 2;
            3, 4;
        ];
        let expected_sum = mx![
            2, 3;
            4, 5;
        ];
        assert_eq!(ones + counting, expected_sum);
    }

    /// Element-wise addition of floating-point matrices.
    #[test]
    fn test_add_f64() {
        let ones = mx![
            1.0, 1.0;
            1.0, 1.0;
        ];
        let counting = mx![
            1.0, 2.0;
            3.0, 4.0;
        ];
        let expected_sum = mx![
            2.0, 3.0;
            4.0, 5.0;
        ];
        assert_eq!(ones + counting, expected_sum);
    }

    /// Element-wise subtraction on a single-row matrix.
    #[test]
    fn test_sub() {
        let minuend = mx![
            3, 2, 1;
        ];
        let subtrahend = mx![
            1, 1, 1;
        ];
        let expected_difference = mx![
            2, 1, 0;
        ];
        assert_eq!(minuend - subtrahend, expected_difference);
    }

    /// (2×3) · (3×3) integer product: each entry is 1·2 summed three times.
    #[test]
    fn test_mul_i32() {
        let left = mx![
            1, 1, 1;
            1, 1, 1;
        ];
        let right = mx![
            2, 2, 2;
            2, 2, 2;
            2, 2, 2;
        ];
        let product = mx![
            6, 6, 6;
            6, 6, 6;
        ];
        assert_eq!(left * right, product);
    }

    /// Same product as `test_mul_i32`, but over `f64`.
    #[test]
    fn test_mul_f64() {
        let left = mx![
            1.0, 1.0, 1.0;
            1.0, 1.0, 1.0;
        ];
        let right = mx![
            2.0, 2.0, 2.0;
            2.0, 2.0, 2.0;
            2.0, 2.0, 2.0;
        ];
        let product = mx![
            6.0, 6.0, 6.0;
            6.0, 6.0, 6.0;
        ];
        assert_eq!(left * right, product);
    }

    /// Scaling by -1 negates every element.
    #[test]
    fn test_mul_scalar() {
        let positive = mx![
            1, 2, 3;
            4, 5, 6;
        ];
        let negated = mx![
            -1, -2, -3;
            -4, -5, -6;
        ];
        assert_eq!(positive.mul_scalar(-1), negated);
    }
}