stack_algebra/
algebra.rs

1#![allow(non_snake_case)]
2
3use core::{
4    // fmt,
5    iter::Sum,
6    ops::{Add, Div, Mul, Neg, Sub},
7};
8
9use crate::{
10    eye,
11    num::{Abs, One, Zero},
12    Matrix,
13};
14
15impl<const D: usize, T> Matrix<D, D, T>
16where
17    T: Abs
18        + PartialOrd
19        + Copy
20        + Zero
21        + One
22        // + fmt::Display
23        + Sum
24        + Add<Output = T>
25        + Neg<Output = T>
26        + Mul<Output = T>
27        + Sub<Output = T>
28        + Div<Output = T>,
29{
30    pub fn inv(&self) -> Option<Self> {
31        let (mut L, mut U, P) = self.lu();
32        if let (Some(L_inv), Some(U_inv)) = (
33            Self::invert_lower_triangular(&mut L),
34            Self::invert_upper_triangular(&mut U),
35        ) {
36            Some(U_inv * L_inv * P)
37        } else {
38            None
39        }
40    }
41
42    pub fn det(&self) -> T {
43        let (L, U, _) = self.lu();
44        let mut det = T::one();
45        for i in 0..D {
46            det = det * L[(i, i)] * U[(i, i)];
47        }
48        if D % 2 != 0 {
49            det = -det;
50        }
51        det
52    }
53
54    pub fn lu(&self) -> (Matrix<D, D, T>, Matrix<D, D, T>, Matrix<D, D, T>) {
55        let mut P = eye!(D, T);
56        let mut L = eye!(D, T);
57        let mut U = *self;
58
59        for d in 0..D {
60            // Find row index of maximum absolute value equal to or below given diagonal
61            let max_row = Self::find_max_row(&U, d);
62            // Swap rows if a non-diagonal row is larger (i.e. partial pivot)
63            Self::partial_pivot(&mut P, &mut L, &mut U, d, max_row);
64            // Perform single step of gaussian-elimination
65            Self::gauss_eliminate(&mut L, &mut U, d);
66        }
67        (L, U, P)
68    }
69
70    fn invert_upper_triangular(U: &mut Matrix<D, D, T>) -> Option<Matrix<D, D, T>> {
71        let mut I = eye!(D, T);
72        for i in (0..D).rev() {
73            let diag = U[(i, i)];
74            if diag == T::zero() {
75                return None;
76            }
77            let coeff = T::one() / diag;
78
79            // Make current diagonal identity and scale by same in the row of `I`
80            U[(i, i)] = U[(i, i)] * coeff;
81            for c in i..D {
82                I[(i, c)] = I[(i, c)] * coeff;
83            }
84
85            // Perform gaussian elimination on upper rows of current diagonal
86            for r in 0..i {
87                let coeff = -U[(r, i)];
88                U[(r, i)] = T::zero();
89                for c in i..D {
90                    I[(r, c)] = I[(r, c)] + coeff * I[(i, c)];
91                }
92            }
93        }
94        Some(I)
95    }
96
97    fn invert_lower_triangular(L: &mut Matrix<D, D, T>) -> Option<Matrix<D, D, T>> {
98        let mut I = eye!(D, T);
99        for i in 0..D {
100            let diag = L[(i, i)];
101            if diag != T::one() {
102                return None;
103            }
104            for r in (i + 1)..D {
105                let coeff = -L[(r, i)];
106                L[(r, i)] = T::zero();
107                for c in 0..(i + 1) {
108                    I[(r, c)] = I[(r, c)] + coeff * I[(i, c)];
109                }
110            }
111        }
112        // println!("{}", L);
113        // println!("{}", I);
114        Some(I)
115    }
116
117    fn gauss_eliminate(L: &mut Matrix<D, D, T>, U: &mut Matrix<D, D, T>, diag: usize) {
118        let d = diag;
119        for r in (d + 1)..D {
120            L[(r, d)] = U[(r, d)] / U[(d, d)];
121            for c in 0..D {
122                U[(r, c)] = U[(r, c)] - L[(r, d)] * U[(d, c)];
123            }
124        }
125    }
126
127    fn partial_pivot(
128        P: &mut Matrix<D, D, T>,
129        L: &mut Matrix<D, D, T>,
130        U: &mut Matrix<D, D, T>,
131        diag: usize,
132        max_row: usize,
133    ) {
134        P.swap_rows(diag, max_row);
135        U.swap_rows(diag, max_row);
136        // Swap partial rows of L
137        for c in 0..diag {
138            let temp = L[(max_row, c)];
139            L[(max_row, c)] = L[(diag, c)];
140            L[(diag, c)] = temp;
141        }
142    }
143
144    fn find_max_row(U: &Matrix<D, D, T>, diag: usize) -> usize {
145        let mut max_row = diag;
146        for r in diag..D {
147            if U[(max_row, diag)].abs() < U[(r, diag)].abs() {
148                max_row = r;
149            }
150        }
151        max_row
152    }
153}
154
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};

    use super::*;
    use crate::matrix;

    /// LU decomposition with partial pivoting of a small 3x3 matrix (one row swap).
    #[test]
    fn LU_decomp() {
        let A = matrix![
            1.0, 3.0, 5.0;
            2.0, 4.0, 7.0;
            1.0, 1.0, 0.0;
        ];
        let (L, U, P) = A.lu();
        let L_exp = matrix![
            1.0,  0.0, 0.0;
            0.5,  1.0, 0.0;
            0.5, -1.0, 1.0;
        ];
        let U_exp = matrix![
            2.0, 4.0,  7.0;
            0.0, 1.0,  1.5;
            0.0, 0.0, -2.0;
        ];
        let P_exp = matrix![
            0.0, 1.0, 0.0;
            1.0, 0.0, 0.0;
            0.0, 0.0, 1.0;
        ];
        assert_eq!(L, L_exp);
        assert_eq!(U, U_exp);
        assert_eq!(P, P_exp);
    }

    /// Determinants of regular and singular matrices of various sizes.
    #[test]
    fn determinant() {
        let A = matrix![
            6.0, 2.0, 3.0;
            1.0, 1.0, 1.0;
            0.0, 4.0, 9.0;
        ];
        assert_abs_diff_eq!(A.det(), 24.0, epsilon = 1e-10);

        let A = matrix![
            3.0, 7.0;
            1.0, -4.0;
        ];
        assert_abs_diff_eq!(A.det(), -19.0, epsilon = 1e-10);

        let A = matrix![
            1.0, 2.0;
            4.0, 8.0;
        ];
        assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);

        let A = matrix![
            1.0, 1.0;
            2.0, 2.0;
        ];
        assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);

        // Singular: the third row is twice the second.
        let A = matrix![
            1.0, 2.0, 3.0, 4.0;
            2.0, 5.0, 7.0, 3.0;
            4.0, 10.0, 14.0, 6.0;
            3.0, 4.0, 2.0, 7.0;
        ];
        assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);

        let A = matrix![
            11.0, 9.0, 24.0, 2.0;
            1.0, 5.0, 2.0, 6.0;
            3.0, 17.0, 18.0, 1.0;
            2.0, 5.0, 7.0, 1.0;
        ];
        assert_abs_diff_eq!(A.det(), -284.0, epsilon = 1e-10);

        let A = matrix![
              2.0, 3.0, 0.0, 9.0, 0.0, 1.0, 0.0, 1.0, 1.0, 2.0, 1.0;
              1.0, 1.0, 0.0, 3.0, 0.0, 0.0, 0.0, 9.0, 2.0, 3.0, 1.0;
              1.0, 4.0, 0.0, 2.0, 8.0, 5.0, 0.0, 3.0, 6.0, 1.0, 9.0;
              0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0;
              2.0, 2.0, 4.0, 1.0, 1.0, 2.0, 1.0, 6.0, 9.0, 0.0, 7.0;
              0.0, 0.0, 0.0, 6.0, 0.0, 7.0, 0.0, 1.0, 0.0, 0.0, 0.0;
              2.0, 5.0, 0.0, 7.0, 0.0, 4.0, 6.0, 8.0, 5.0, 1.0, 3.0;
              0.0, 0.0, 0.0, 1.0, 0.0, 4.0, 0.0, 1.0, 0.0, 0.0, 0.0;
              0.0, 0.0, 0.0, 8.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0;
              2.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0;
              2.0, 6.0, 0.0, 1.0, 0.0,30.0, 0.0, 2.0, 3.0, 2.0, 1.0;
        ];
        assert_abs_diff_eq!(A.det(), 0.0, epsilon = 1e-10);
    }

    /// Back-substitution inverse of an upper-triangular matrix.
    #[test]
    fn upper_inverse() {
        let mut A = matrix![
            2.0, 4.0, 6.0;
            0.0,-1.0,-8.0;
            0.0, 0.0,96.0;
        ];
        let I = Matrix::invert_upper_triangular(&mut A).unwrap();
        let E = matrix![
            0.5, 2.0, 0.13541667;
            0.0,-1.0,-0.08333333;
            0.0, 0.0, 0.01041667;
        ];
        assert_relative_eq!(I, E, max_relative = 1e-6);
    }

    /// Forward-substitution inverse of a unit lower-triangular matrix.
    #[test]
    fn lower_inverse() {
        let mut A = matrix![
            1.0, 0.0, 0.0;
            8.0, 1.0, 0.0;
            4.0, 9.0, 1.0;
        ];
        let I = Matrix::invert_lower_triangular(&mut A).unwrap();
        let E = matrix![
            1.0, 0.0, 0.0;
            -8.0, 1.0, 0.0;
            68.0, -9.0, 1.0;
        ];
        assert_relative_eq!(I, E, max_relative = 1e-6);
    }

    /// Full inverse of regular 3x3 and 4x4 matrices.
    #[test]
    fn inverse() {
        let A = matrix![
            6.0, 2.0, 3.0;
            1.0, 1.0, 1.0;
            0.0, 4.0, 9.0;
        ];
        let exp = matrix![
            0.20833333, -0.25, -0.04166667;
                -0.375,  2.25, -0.125;
            0.16666667,  -1.0,  0.16666667;
        ];
        assert_relative_eq!(A.inv().unwrap(), exp, max_relative = 1e-6);

        let A = matrix![
            11.0, 9.0, 24.0, 2.0;
            1.0, 5.0, 2.0, 6.0;
            3.0, 17.0, 18.0, 1.0;
            2.0, 5.0, 7.0, 1.0;
        ];
        let exp = matrix![
         0.72183099,  0.46126761,  1.02112676, -5.23239437;
         0.28521127,  0.23591549,  0.59859155, -2.58450704;
         -0.37676056, -0.29929577, -0.65492958,  3.20422535;
         -0.23239437, -0.00704225, -0.45070423,  1.95774648;
        ];
        assert_relative_eq!(A.inv().unwrap(), exp, max_relative = 1e-6);
    }

    /// A singular matrix (second row is twice the first) has no inverse.
    #[test]
    fn inverse_of_singular_matrix_is_none() {
        let A = matrix![
            1.0, 2.0;
            2.0, 4.0;
        ];
        assert!(A.inv().is_none());
    }
}