1#![allow(non_snake_case)]
2
3use core::{
4 iter::Sum,
6 ops::{Add, Div, Mul, Neg, Sub},
7};
8
9use crate::{
10 eye,
11 num::{Abs, One, Zero},
12 Matrix,
13};
14
15impl<const D: usize, T> Matrix<D, D, T>
16where
17 T: Abs
18 + PartialOrd
19 + Copy
20 + Zero
21 + One
22 + Sum
24 + Add<Output = T>
25 + Neg<Output = T>
26 + Mul<Output = T>
27 + Sub<Output = T>
28 + Div<Output = T>,
29{
30 pub fn inv(&self) -> Option<Self> {
31 let (mut L, mut U, P) = self.lu();
32 if let (Some(L_inv), Some(U_inv)) = (
33 Self::invert_lower_triangular(&mut L),
34 Self::invert_upper_triangular(&mut U),
35 ) {
36 Some(U_inv * L_inv * P)
37 } else {
38 None
39 }
40 }
41
42 pub fn det(&self) -> T {
43 let (L, U, _) = self.lu();
44 let mut det = T::one();
45 for i in 0..D {
46 det = det * L[(i, i)] * U[(i, i)];
47 }
48 if D % 2 != 0 {
49 det = -det;
50 }
51 det
52 }
53
54 pub fn lu(&self) -> (Matrix<D, D, T>, Matrix<D, D, T>, Matrix<D, D, T>) {
55 let mut P = eye!(D, T);
56 let mut L = eye!(D, T);
57 let mut U = *self;
58
59 for d in 0..D {
60 let max_row = Self::find_max_row(&U, d);
62 Self::partial_pivot(&mut P, &mut L, &mut U, d, max_row);
64 Self::gauss_eliminate(&mut L, &mut U, d);
66 }
67 (L, U, P)
68 }
69
70 fn invert_upper_triangular(U: &mut Matrix<D, D, T>) -> Option<Matrix<D, D, T>> {
71 let mut I = eye!(D, T);
72 for i in (0..D).rev() {
73 let diag = U[(i, i)];
74 if diag == T::zero() {
75 return None;
76 }
77 let coeff = T::one() / diag;
78
79 U[(i, i)] = U[(i, i)] * coeff;
81 for c in i..D {
82 I[(i, c)] = I[(i, c)] * coeff;
83 }
84
85 for r in 0..i {
87 let coeff = -U[(r, i)];
88 U[(r, i)] = T::zero();
89 for c in i..D {
90 I[(r, c)] = I[(r, c)] + coeff * I[(i, c)];
91 }
92 }
93 }
94 Some(I)
95 }
96
97 fn invert_lower_triangular(L: &mut Matrix<D, D, T>) -> Option<Matrix<D, D, T>> {
98 let mut I = eye!(D, T);
99 for i in 0..D {
100 let diag = L[(i, i)];
101 if diag != T::one() {
102 return None;
103 }
104 for r in (i + 1)..D {
105 let coeff = -L[(r, i)];
106 L[(r, i)] = T::zero();
107 for c in 0..(i + 1) {
108 I[(r, c)] = I[(r, c)] + coeff * I[(i, c)];
109 }
110 }
111 }
112 Some(I)
115 }
116
117 fn gauss_eliminate(L: &mut Matrix<D, D, T>, U: &mut Matrix<D, D, T>, diag: usize) {
118 let d = diag;
119 for r in (d + 1)..D {
120 L[(r, d)] = U[(r, d)] / U[(d, d)];
121 for c in 0..D {
122 U[(r, c)] = U[(r, c)] - L[(r, d)] * U[(d, c)];
123 }
124 }
125 }
126
127 fn partial_pivot(
128 P: &mut Matrix<D, D, T>,
129 L: &mut Matrix<D, D, T>,
130 U: &mut Matrix<D, D, T>,
131 diag: usize,
132 max_row: usize,
133 ) {
134 P.swap_rows(diag, max_row);
135 U.swap_rows(diag, max_row);
136 for c in 0..diag {
138 let temp = L[(max_row, c)];
139 L[(max_row, c)] = L[(diag, c)];
140 L[(diag, c)] = temp;
141 }
142 }
143
144 fn find_max_row(U: &Matrix<D, D, T>, diag: usize) -> usize {
145 let mut max_row = diag;
146 for r in diag..D {
147 if U[(max_row, diag)].abs() < U[(r, diag)].abs() {
148 max_row = r;
149 }
150 }
151 max_row
152 }
153}
154
155#[cfg(test)]
156mod tests {
157 use approx::{assert_abs_diff_eq, assert_relative_eq};
158
159 use super::*;
160 use crate::matrix;
161
162 #[test]
163 fn LU_decomp() {
164 let A = matrix![
165 1.0, 3.0, 5.0;
166 2.0, 4.0, 7.0;
167 1.0, 1.0, 0.0;
168 ];
169 let (L, U, P) = A.lu();
170 let L_exp = matrix![
171 1.0, 0.0, 0.0;
172 0.5, 1.0, 0.0;
173 0.5, -1.0, 1.0;
174 ];
175 let U_exp = matrix![
176 2.0, 4.0, 7.0;
177 0.0, 1.0, 1.5;
178 0.0, 0.0, -2.0;
179 ];
180 let P_exp = matrix![
181 0.0, 1.0, 0.0;
182 1.0, 0.0, 0.0;
183 0.0, 0.0, 1.0;
184 ];
185 assert_eq!(L, L_exp);
186 assert_eq!(U, U_exp);
187 assert_eq!(P, P_exp);
188 }
189
190 #[test]
191 fn determinant() {
192 let A = matrix![
193 6.0, 2.0, 3.0;
194 1.0, 1.0, 1.0;
195 0.0, 4.0, 9.0;
196 ];
197 assert_abs_diff_eq!(A.det(), 24.0, epsilon = 1e-10);
198
199 let A = matrix![
200 3.0, 7.0;
201 1.0, -4.0;
202 ];
203 assert_abs_diff_eq!(A.det(), -19.0, epsilon = 1e-10);
204
205 let A = matrix![
206 1.0, 2.0;
207 4.0, 8.0;
208 ];
209 assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);
210
211 let A = matrix![
212 1.0, 1.0;
213 2.0, 2.0;
214 ];
215 assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);
216
217 let A = matrix![
218 1.0, 2.0, 3.0, 4.0;
219 2.0, 5.0, 7.0, 3.0;
220 4.0, 10.0, 14.0, 6.0;
221 3.0, 4.0, 2.0, 7.0;
222 ];
223 assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);
224
225 let A = matrix![
226 1.0, 2.0, 3.0, 4.0;
227 2.0, 5.0, 7.0, 3.0;
228 4.0, 10.0, 14.0, 6.0;
229 3.0, 4.0, 2.0, 7.0;
230 ];
231 assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);
232
233 let A = matrix![
234 11.0, 9.0, 24.0, 2.0;
235 1.0, 5.0, 2.0, 6.0;
236 3.0, 17.0, 18.0, 1.0;
237 2.0, 5.0, 7.0, 1.0;
238 ];
239 assert_abs_diff_eq!(A.det(), -284.0, epsilon = 1e-10);
240
241 let A = matrix![
242 2.0, 3.0, 0.0, 9.0, 0.0, 1.0, 0.0, 1.0, 1.0, 2.0, 1.0;
243 1.0, 1.0, 0.0, 3.0, 0.0, 0.0, 0.0, 9.0, 2.0, 3.0, 1.0;
244 1.0, 4.0, 0.0, 2.0, 8.0, 5.0, 0.0, 3.0, 6.0, 1.0, 9.0;
245 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0;
246 2.0, 2.0, 4.0, 1.0, 1.0, 2.0, 1.0, 6.0, 9.0, 0.0, 7.0;
247 0.0, 0.0, 0.0, 6.0, 0.0, 7.0, 0.0, 1.0, 0.0, 0.0, 0.0;
248 2.0, 5.0, 0.0, 7.0, 0.0, 4.0, 6.0, 8.0, 5.0, 1.0, 3.0;
249 0.0, 0.0, 0.0, 1.0, 0.0, 4.0, 0.0, 1.0, 0.0, 0.0, 0.0;
250 0.0, 0.0, 0.0, 8.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0;
251 2.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0;
252 2.0, 6.0, 0.0, 1.0, 0.0,30.0, 0.0, 2.0, 3.0, 2.0, 1.0;
253 ];
254 assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);
255 }
256
257 #[test]
258 fn upper_inverse() {
259 let mut A = matrix![
260 2.0, 4.0, 6.0;
261 0.0,-1.0,-8.0;
262 0.0, 0.0,96.0;
263 ];
264 let I = Matrix::invert_upper_triangular(&mut A).unwrap();
265 let E = matrix![
266 0.5, 2.0, 0.13541667;
267 0.0,-1.0,-0.08333333;
268 0.0, 0.0, 0.01041667;
269 ];
270 assert_relative_eq!(I, E, max_relative = 1e-6);
271 }
272
273 #[test]
274 fn lower_inverse() {
275 let mut A = matrix![
276 1.0, 0.0, 0.0;
277 8.0, 1.0, 0.0;
278 4.0, 9.0, 1.0;
279 ];
280 let I = Matrix::invert_lower_triangular(&mut A).unwrap();
281 let E = matrix![
282 1.0, 0.0, 0.0;
283 -8.0, 1.0, 0.0;
284 68.0, -9.0, 1.0;
285 ];
286 assert_relative_eq!(I, E, max_relative = 1e-6);
287 }
288
289 #[test]
290 fn inverse() {
291 let A = matrix![
292 6.0, 2.0, 3.0;
293 1.0, 1.0, 1.0;
294 0.0, 4.0, 9.0;
295 ];
296 let exp = matrix![
297 0.20833333, -0.25, -0.04166667;
298 -0.375, 2.25, -0.125;
299 0.16666667, -1.0, 0.16666667;
300 ];
301 assert_relative_eq!(A.inv().unwrap(), exp, max_relative = 1e-6);
302
303 let A = matrix![
304 11.0, 9.0, 24.0, 2.0;
305 1.0, 5.0, 2.0, 6.0;
306 3.0, 17.0, 18.0, 1.0;
307 2.0, 5.0, 7.0, 1.0;
308 ];
309 let exp = matrix![
310 0.72183099, 0.46126761, 1.02112676, -5.23239437;
311 0.28521127, 0.23591549, 0.59859155, -2.58450704;
312 -0.37676056, -0.29929577, -0.65492958, 3.20422535;
313 -0.23239437, -0.00704225, -0.45070423, 1.95774648;
314 ];
315 assert_relative_eq!(A.inv().unwrap(), exp, max_relative = 1e-6);
316 }
317}