vector_victor/
decompose.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5use crate::util::checked_inv;
6use crate::{Matrix, Vector};
7use num_traits::real::Real;
8use num_traits::Signed;
9use std::iter::{Product, Sum};
10use std::ops::{Mul, Neg, Not};
11
/// The parity of an [LU decomposition](LUDecomposition): whether the source matrix had an
/// even or an odd number of row swaps applied before the decomposition took place.
///
/// Every row swap flips the sign of a determinant, so multiplying a value by a `Parity`
/// leaves it unchanged for [`Even`](Parity::Even) and negates it for [`Odd`](Parity::Odd).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Parity {
    /// An even number of row swaps (including none at all)
    Even,
    /// An odd number of row swaps
    Odd,
}

impl<T> Mul<T> for Parity
where
    T: Neg<Output = T>,
{
    type Output = T;

    /// Returns `rhs` unchanged for [`Parity::Even`] and `-rhs` for [`Parity::Odd`].
    fn mul(self, rhs: T) -> Self::Output {
        match self {
            Parity::Even => rhs,
            Parity::Odd => -rhs,
        }
    }
}

impl Not for Parity {
    type Output = Parity;

    /// Flip the parity, as happens when one more row swap is performed.
    fn not(self) -> Self::Output {
        match self {
            Parity::Even => Parity::Odd,
            Parity::Odd => Parity::Even,
        }
    }
}
44
45/// The result of the [LU decomposition](LUDecompose::lu) of a matrix.
46///
47/// This struct provides a convenient way to reuse one LU decomposition to solve multiple
48/// matrix equations. You likely do not need to worry about its contents.
49///
50/// See [LU decomposition](https://en.wikipedia.org/wiki/LU_decomposition)
51/// on wikipedia for more information
52#[derive(Copy, Clone, Debug, PartialEq)]
53pub struct LUDecomposition<T: Copy, const N: usize> {
54    /// The $bbL$ and $bbU$ matrices combined into one
55    ///
56    /// for example if
57    ///
58    /// $ bbU = [[u_{11}, u_{12}, cdots,  u_{1n} ],
59    ///          [0,      u_{22}, cdots,  u_{2n} ],
60    ///          [vdots,  vdots,  ddots,  vdots  ],
61    ///          [0,      0,      cdots,  u_{mn} ]] $
62    /// and
63    /// $ bbL = [[1,      0,      cdots,  0      ],
64    ///          [l_{21}, 1,      cdots,  0      ],
65    ///          [vdots,  vdots,  ddots,  vdots  ],
66    ///          [l_{m1}, l_{m2}, cdots,  1      ]] $,
67    /// then
68    /// $ bb{LU} = [[u_{11}, u_{12}, cdots,  u_{1n} ],
69    ///             [l_{21}, u_{22}, cdots,  u_{2n} ],
70    ///             [vdots,  vdots,  ddots,  vdots  ],
71    ///             [l_{m1}, l_{m2}, cdots,  u_{mn} ]] $
72    ///
73    /// note that the diagonals of the $bbL$ matrix are always 1, so no information is lost
74    pub lu: Matrix<T, N, N>,
75
76    /// The indices of the permutation matrix $bbP$, such that $bbP xx bbA$ = $bbL xx bbU$
77    ///
78    /// The permutation matrix rearranges the rows of the original matrix in order to produce
79    /// the LU decomposition. This makes calculation simpler, but makes the result
80    /// (known as an LUP decomposition) no longer unique
81    pub idx: Vector<usize, N>,
82
83    /// The parity of the decomposition.
84    pub parity: Parity,
85}
86
87impl<T: Copy + Default + Real, const N: usize> LUDecomposition<T, N> {
88    /// Solve for $x$ in $bbM xx x = b$, where $bbM$ is the original matrix this is a decomposition of.
89    ///
90    /// This is equivalent to [`LUDecompose::solve`] while allowing the LU decomposition
91    /// to be reused
92    #[must_use]
93    pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
94        let b_permuted = b.permute_rows(&self.idx);
95
96        Matrix::from_cols(b_permuted.cols().map(|mut x| {
97            // Implementation from Numerical Recipes §2.3
98            // When ii is set to a positive value,
99            // it will become the index of the first non-vanishing element of b
100            let mut ii = 0usize;
101            for i in 0..N {
102                // forward substitution using L
103                let mut sum = x[i];
104                if ii != 0 {
105                    for j in (ii - 1)..i {
106                        sum = sum - (self.lu[(i, j)] * x[j]);
107                    }
108                } else if sum.abs() > T::epsilon() {
109                    ii = i + 1;
110                }
111                x[i] = sum;
112            }
113            for i in (0..N).rev() {
114                // back substitution using U
115                let mut sum = x[i];
116                for j in (i + 1)..N {
117                    sum = sum - (self.lu[(i, j)] * x[j]);
118                }
119                x[i] = sum / self.lu[(i, i)]
120            }
121            x
122        }))
123    }
124
125    /// Calculate the determinant $|M|$ of the matrix $M$.
126    /// If the matrix is singular, the determinant is 0.
127    ///
128    /// This is equivalent to [`LUDecompose::det`] while allowing the LU decomposition
129    /// to be reused
130    pub fn det(&self) -> T {
131        self.parity * self.lu.diagonals().fold(T::one(), |l, &r| l * r)
132    }
133
134    /// Calculate the inverse of the original matrix, such that $bbM xx bbM^{-1} = bbI$
135    ///
136    /// This is equivalent to [`Matrix::inv`] while allowing the LU decomposition to be reused
137    #[must_use]
138    pub fn inv(&self) -> Matrix<T, N, N> {
139        return self.solve(&Matrix::<T, N, N>::identity());
140    }
141
142    /// Separate the $L$ and $U$ sides of the $LU$ matrix.
143    /// See [the `lu` field](LUDecomposition::lu) for more information
144    pub fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>) {
145        let mut l = Matrix::<T, N, N>::identity();
146        let mut u = self.lu; // lu
147
148        for m in 1..N {
149            for n in 0..m {
150                // iterate over lower diagonal
151                l[(m, n)] = u[(m, n)];
152                u[(m, n)] = T::zero();
153            }
154        }
155
156        (l, u)
157    }
158}
159
160/// A Matrix that can be decomposed into an upper and lower diagonal matrix,
161/// known as an [LU Decomposition](LUDecomposition).
162///
163/// See [LU decomposition](https://en.wikipedia.org/wiki/LU_decomposition)
164/// on wikipedia for more information
165pub trait LUDecompose<T: Copy, const N: usize> {
166    /// return this matrix's [`LUDecomposition`], or [`None`] if the matrix is singular.
167    /// This can be used to solve for multiple results
168    ///
169    /// ```
170    /// # use vector_victor::decompose::LUDecompose;
171    /// # use vector_victor::{Matrix, Vector};
172    /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
173    /// let lu = m.lu().expect("Cannot decompose a signular matrix");
174    ///
175    /// let b = Vector::vec([7.0,10.0]);
176    /// assert_eq!(lu.solve(&b), Vector::vec([1.0,2.0]));
177    ///
178    /// let c = Vector::vec([10.0, 14.0]);
179    /// assert_eq!(lu.solve(&c), Vector::vec([1.0,3.0]));
180    ///
181    /// ```
182    #[must_use]
183    fn lu(&self) -> Option<LUDecomposition<T, N>>;
184
185    /// Calculate the inverse of the matrix, such that $bbMxxbbM^{-1} = bbI$,
186    /// or [`None`] if the matrix is singular.
187    ///
188    /// ```
189    /// # use vector_victor::decompose::LUDecompose;
190    /// # use vector_victor::Matrix;
191    /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
192    /// let mi = m.inv().expect("Cannot invert a singular matrix");
193    ///
194    /// assert_eq!(mi, Matrix::mat([[-2.0, 1.5],[1.0, -0.5]]), "unexpected inverse matrix");
195    ///
196    /// // multiplying a matrix by its inverse yields the identity matrix
197    /// assert_eq!(m.mmul(&mi), Matrix::identity())
198    /// ```
199    #[must_use]
200    fn inv(&self) -> Option<Matrix<T, N, N>>;
201
202    /// Calculate the determinant $|M|$ of the matrix $M$.
203    /// If the matrix is singular, the determinant is 0
204    #[must_use]
205    fn det(&self) -> T;
206
207    /// Solve for $x$ in $bbM xx x = b$
208    ///
209    /// ```
210    /// # use vector_victor::decompose::LUDecompose;
211    /// # use vector_victor::{Matrix, Vector};
212    ///
213    /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
214    /// let b = Vector::vec([7.0,10.0]);
215    /// let x = m.solve(&b).expect("Cannot solve a singular matrix");
216    ///
217    /// assert_eq!(x, Vector::vec([1.0,2.0]), "x = [1,2]");
218    /// assert_eq!(m.mmul(&x), b, "Mx = b");
219    /// ```
220    ///
221    /// $x$ does not need to be a column-vector, it can also be a 2D matrix. For example,
222    /// the following is another way to calculate the [inverse](LUDecompose::inv()) by solving for the identity matrix $I$.
223    ///
224    /// ```
225    /// # use vector_victor::decompose::LUDecompose;
226    /// # use vector_victor::{Matrix, Vector};
227    ///
228    /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
229    /// let i = Matrix::<f64,2,2>::identity();
230    /// let mi = m.solve(&i).expect("Cannot solve a singular matrix");
231    ///
232    /// assert_eq!(mi, Matrix::mat([[-2.0, 1.5],[1.0, -0.5]]));
233    /// assert_eq!(m.mmul(&mi), i, "M x M^-1 = I");
234    /// ```
235    #[must_use]
236    fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>;
237}
238
239impl<T, const N: usize> LUDecompose<T, N> for Matrix<T, N, N>
240where
241    T: Copy + Default + Real + Sum + Product + Signed,
242{
243    fn lu(&self) -> Option<LUDecomposition<T, N>> {
244        // Implementation from Numerical Recipes §2.3
245        let mut lu = self.clone();
246        let mut idx: Vector<usize, N> = (0..N).collect();
247        let mut parity = Parity::Even;
248
249        let mut vv: Vector<T, N> = self
250            .rows()
251            .map(|row| {
252                let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?;
253                checked_inv(m)
254            })
255            .collect::<Option<_>>()?; // get the inverse max abs value in each row
256
257        // for each column in the matrix...
258        for k in 0..N {
259            // search for the pivot element and its index
260            let (ipivot, _) = (lu.col(k) * vv)
261                .abs()
262                .elements()
263                .enumerate()
264                .skip(k) // below the diagonal
265                .reduce(|(imax, xmax), (i, x)| match x > xmax {
266                    // Is the figure of merit for the pivot better than the best so far?
267                    true => (i, x),
268                    false => (imax, xmax),
269                })?;
270
271            // do we need to interchange rows?
272            if k != ipivot {
273                lu.pivot_row(ipivot, k); // yes, we do
274                idx.pivot_row(ipivot, k);
275                parity = !parity; // swap parity
276                vv[ipivot] = vv[k] // interchange scale factor
277            }
278
279            // select our pivot, which is now on the diagonal
280            let pivot = lu[(k, k)];
281            if pivot.abs() < T::epsilon() {
282                // if the pivot is zero, the matrix is singular
283                return None;
284            };
285
286            // for each element in the column k below the diagonal...
287            // this is called outer product Gaussian elimination
288            for i in (k + 1)..N {
289                // divide by the pivot element
290                lu[(i, k)] = lu[(i, k)] / pivot;
291
292                // for each element in the column k below the diagonal...
293                for j in (k + 1)..N {
294                    // reduce remaining submatrix
295                    lu[(i, j)] = lu[(i, j)] - (lu[(i, k)] * lu[(k, j)]);
296                }
297            }
298        }
299
300        return Some(LUDecomposition { lu, idx, parity });
301    }
302
303    fn inv(&self) -> Option<Matrix<T, N, N>> {
304        match N {
305            1 => Some(Self::fill(checked_inv(self[0])?)),
306            2 => {
307                let mut result = Self::default();
308                result[(0, 0)] = self[(1, 1)];
309                result[(1, 1)] = self[(0, 0)];
310                result[(1, 0)] = -self[(1, 0)];
311                result[(0, 1)] = -self[(0, 1)];
312                Some(result * checked_inv(self.det())?)
313            }
314            _ => Some(self.lu()?.inv()),
315        }
316    }
317
318    fn det(&self) -> T {
319        match N {
320            1 => self[0],
321            2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
322            3 => {
323                // use rule of Sarrus
324                (0..N) // starting column
325                    .map(|i| {
326                        let dn = (0..N)
327                            .map(|j| -> T { self[(j, (j + i) % N)] })
328                            .product::<T>();
329                        let up = (0..N)
330                            .map(|j| -> T { self[(N - j - 1, (j + i) % N)] })
331                            .product::<T>();
332                        dn - up
333                    })
334                    .sum::<T>()
335            }
336            _ => {
337                // use LU decomposition
338                self.lu().map_or(T::zero(), |lu| lu.det())
339            }
340        }
341    }
342
343    fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>> {
344        Some(self.lu()?.solve(b))
345    }
346}