stack_algebra/
lib.rs

1#![no_std]
2
3mod algebra;
4mod fmt;
5mod index;
6mod iter;
7mod new;
8mod num;
9mod ops;
10mod util;
11mod view;
12
13use core::{
14    mem::MaybeUninit,
15    ops::{Add, Div, Mul, Sub},
16    slice,
17};
18
19pub use index::MatrixIndex;
20pub use num::{Abs, Sqrt, Zero};
21pub use view::{Column, Row};
22
23#[doc(hidden)]
24pub use vectrix_macro as proc_macro;
25
/// A dense matrix with `M` rows and `N` columns, both fixed at compile time.
///
/// Elements are stored inline in a nested array holding one inner array per
/// column, so the layout is always column-major. The element type defaults to
/// `f32`.
///
/// The derived comparison traits compare the storage arrays directly, i.e.
/// elements are compared lexicographically in column-major order.
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Matrix<const M: usize, const N: usize, T = f32> {
    // `data[c][r]` is the element at row `r`, column `c`.
    data: [[T; M]; N],
}
37
38impl<const M: usize, const N: usize, T> Matrix<M, N, T> {
39    /// Returns a raw pointer to the underlying data.
40    #[inline]
41    fn as_ptr(&self) -> *const T {
42        self.data.as_ptr() as *const T
43    }
44
45    /// Returns an unsafe mutable pointer to the underlying data.
46    #[inline]
47    fn as_mut_ptr(&mut self) -> *mut T {
48        self.data.as_mut_ptr() as *mut T
49    }
50
51    /// Views the underlying data as a contiguous slice.
52    #[inline]
53    pub fn as_slice(&self) -> &[T] {
54        unsafe { slice::from_raw_parts(self.as_ptr(), M * N) }
55    }
56
57    /// Views the underlying data as a contiguous mutable slice.
58    #[inline]
59    pub fn as_mut_slice(&mut self) -> &mut [T] {
60        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), M * N) }
61    }
62
63    /// Returns a reference to the `i`-th row of this matrix.
64    #[inline]
65    pub fn row(&self, i: usize) -> &Row<M, N, T> {
66        Row::new(&self.as_slice()[i..])
67    }
68
69    /// Returns a mutable reference to the `i`-th row of this matrix.
70    #[inline]
71    pub fn row_mut(&mut self, i: usize) -> &mut Row<M, N, T> {
72        Row::new_mut(&mut self.as_mut_slice()[i..])
73    }
74
75    /// Returns a reference to the `i`-th column of this matrix.
76    #[inline]
77    pub fn column(&self, i: usize) -> &Column<M, N, T> {
78        Column::new(&self.data[i])
79    }
80
81    /// Returns a mutable reference to the `i`-th column of this matrix.
82    #[inline]
83    pub fn column_mut(&mut self, i: usize) -> &mut Column<M, N, T> {
84        Column::new_mut(&mut self.data[i])
85    }
86
87    /// Returns a reference to an element in the matrix or `None` if out of
88    /// bounds.
89    #[inline]
90    pub fn get<I>(&self, i: I) -> Option<&I::Output>
91    where
92        I: MatrixIndex<Self>,
93    {
94        i.get(self)
95    }
96
97    /// Returns a mutable reference to an element in the matrix or `None` if out
98    /// of bounds.
99    #[inline]
100    pub fn get_mut<I>(&mut self, i: I) -> Option<&mut I::Output>
101    where
102        I: MatrixIndex<Self>,
103    {
104        i.get_mut(self)
105    }
106
107    /// Returns a reference to an element in the matrix without doing any bounds
108    /// checking.
109    ///
110    /// # Safety
111    ///
112    /// Calling this method with an out-of-bounds index is
113    /// *[undefined behavior]* even if the resulting reference is not used.
114    ///
115    /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
116    #[inline]
117    pub unsafe fn get_unchecked<I>(&self, i: I) -> &I::Output
118    where
119        I: MatrixIndex<Self>,
120    {
121        unsafe { &*i.get_unchecked(self) }
122    }
123
124    /// Returns a mutable reference to an element in the matrix without doing
125    /// any bounds checking.
126    ///
127    /// # Safety
128    ///
129    /// Calling this method with an out-of-bounds index is
130    /// *[undefined behavior]* even if the resulting reference is not used.
131    ///
132    /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
133    #[inline]
134    pub unsafe fn get_unchecked_mut<I>(&mut self, i: I) -> &mut I::Output
135    where
136        I: MatrixIndex<Self>,
137    {
138        unsafe { &mut *i.get_unchecked_mut(self) }
139    }
140
141    /// Returns an iterator over the underlying data.
142    #[inline]
143    pub fn iter(&self) -> slice::Iter<'_, T> {
144        self.as_slice().iter()
145    }
146
147    /// Returns a mutable iterator over the underlying data.
148    #[inline]
149    pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
150        self.as_mut_slice().iter_mut()
151    }
152
153    /// Swap the two given rows of this matrix
154    #[inline]
155    pub fn swap_rows(&mut self, r1: usize, r2: usize)
156    where
157        T: Copy,
158    {
159        if r1 < M && r2 < M {
160            for i in 0..N {
161                let tmp = self[(r1, i)];
162                self[(r1, i)] = self[(r2, i)];
163                self[(r2, i)] = tmp;
164            }
165        }
166    }
167
168    /// Swap the two given columns of this matrix
169    #[inline]
170    pub fn swap_columns(&mut self, c1: usize, c2: usize)
171    where
172        T: Copy,
173    {
174        if c1 < N && c2 < N {
175            for i in 0..M {
176                let tmp = self[(i, c1)];
177                self[(i, c1)] = self[(i, c2)];
178                self[(i, c2)] = tmp;
179            }
180        }
181    }
182
183    // /// Clone the current matrix.
184    // #[inline]
185    // pub fn clone(&self) -> Matrix<M, N, T>
186    // where
187    //     T: Copy,
188    // {
189    //     // let mut clone = zeros!(M, N, T);
190    //     let mut clone = unsafe { Matrix::<M, N, MaybeUninit<T>>::uninit().assume_init() };
191    //     for c in 0..N {
192    //         for r in 0..M {
193    //             clone[(r, c)] = self[(r, c)];
194    //         }
195    //     }
196    //     clone
197    // }
198
199    /// Transpose of the current matrix.
200    #[inline]
201    pub fn transpose(&self) -> Matrix<N, M, T>
202    where
203        T: Clone,
204    {
205        // let mut transpose = zeros!(N, M, T);
206        let mut transpose = unsafe { Matrix::<N, M, MaybeUninit<T>>::uninit().assume_init() };
207        for c in 0..N {
208            for r in 0..M {
209                transpose[(c, r)] = self[(r, c)].clone();
210            }
211        }
212        transpose
213    }
214
215    /// Transpose of the current matrix.
216    #[allow(non_snake_case)]
217    #[inline]
218    pub fn T(&self) -> Matrix<N, M, T>
219    where
220        T: Clone,
221    {
222        self.transpose()
223    }
224
225    /// Compute the Frobenius norm
226    pub fn norm(&self) -> T
227    where
228        T: Copy + Zero + Abs + Sqrt + Add<Output = T> + Mul<Output = T>,
229    {
230        let mut tmp = T::zero();
231        for c in 0..N {
232            for r in 0..M {
233                let v = self[(r, c)].abs();
234                tmp = tmp + v * v;
235            }
236        }
237        tmp.sqrt()
238    }
239
240    /// Compute the Frobenius norm
241    pub fn normalize(self) -> Self
242    where
243        T: Copy + Zero + Abs + Sqrt + Add<Output = T> + Mul<Output = T> + Div<Output = T>,
244    {
245        self / self.norm()
246    }
247
248    // /// Returns an iterator over the rows in this matrix.
249    // #[inline]
250    // pub fn iter_rows(&self) -> IterRows<'_, T, M, N> {
251    //     IterRows::new(self)
252    // }
253
254    // /// Returns a mutable iterator over the rows in this matrix.
255    // #[inline]
256    // pub fn iter_rows_mut(&mut self) -> IterRowsMut<'_, T, M, N> {
257    //     IterRowsMut::new(self)
258    // }
259
260    // /// Returns an iterator over the columns in this matrix.
261    // #[inline]
262    // pub fn iter_columns(&self) -> IterColumns<'_, T, M, N> {
263    //     IterColumns::new(self)
264    // }
265
266    // /// Returns a mutable iterator over the columns in this matrix.
267    // #[inline]
268    // pub fn iter_columns_mut(&mut self) -> IterColumnsMut<'_, T, M, N> {
269    //     IterColumnsMut::new(self)
270    // }
271
272    // /// Returns a matrix of the same size as self, with function `f` applied to
273    // /// each element in column-major order.
274    // #[inline]
275    // pub fn map<F, U>(self, f: F) -> Matrix<M, N, U>
276    // where
277    //     F: FnMut(T) -> U,
278    // {
279    //     // SAFETY: the iterator has the exact number of elements required.
280    //     unsafe { new::collect_unchecked(self.into_iter().map(f)) }
281    // }
282
283    // /// Returns the L1 norm of the matrix.
284    // ///
285    // /// Also known as *Manhattan Distance* or *Taxicab norm*. L1 Norm is the sum
286    // /// of the magnitudes of the vectors in a space.
287    // pub fn l1_norm(&self) -> T
288    // where
289    //     T: Copy + Ord + Abs + Zero + Sum<T>,
290    // {
291    //     (0..N)
292    //         .map(|i| self.data[i].iter().copied().map(Abs::abs).sum())
293    //         .max()
294    //         .unwrap_or_else(Zero::zero)
295    // }
296}
297
298// impl<const M: usize, const N: usize, T> Clone for Matrix<M, N, T>
299// where
300//     for<'a> &'a T,
301// {
302//     fn clone(&self) -> Self {
303//         // let mut clone = zeros!(M, N, T);
304//         let mut clone = unsafe { Matrix::<M, N, MaybeUninit<T>>::uninit().assume_init() };
305//         for c in 0..N {
306//             for r in 0..M {
307//                 clone[(r, c)] = &self[(r, c)];
308//             }
309//         }
310//         clone
311//     }
312// }
313
314////////////////////////////////////////////////////////////////////////////////
315// Square matrix functions
316////////////////////////////////////////////////////////////////////////////////
317impl<const N: usize, T> Matrix<N, N, T> {
318    /// Compute the sum of diagonal elements
319    pub fn trace(&self) -> T
320    where
321        T: Zero,
322        for<'a> &'a T: Add<&'a T, Output = T>,
323    {
324        let mut t = T::zero();
325        for i in 0..N {
326            t = &t + &self[(i, i)];
327        }
328        t
329    }
330}
331
332impl<T> Matrix<3, 1, T> {
333    pub fn cross(&self, other: &Self) -> Self
334    where
335        for<'a> &'a T: Mul<&'a T, Output = T> + Sub<&'a T, Output = T>,
336    {
337        let mut res = unsafe { Matrix::<3, 1, MaybeUninit<T>>::uninit().assume_init() };
338        res[0] = &(&self[1] * &other[2]) - &(&self[2] * &other[1]);
339        res[1] = &(&self[2] * &other[0]) - &(&self[0] * &other[2]);
340        res[2] = &(&self[0] * &other[1]) - &(&self[1] * &other[0]);
341        res
342    }
343}
344
345pub fn cross<T>(a: &Matrix<3, 1, T>, b: &Matrix<3, 1, T>) -> Matrix<3, 1, T>
346where
347    for<'a> &'a T: Mul<&'a T, Output = T> + Sub<&'a T, Output = T>,
348{
349    a.cross(b)
350}
351
352////////////////////////////////////////////////////////////////////////////////
353// 3D/4D Vector Type Conversion to Tuple
354////////////////////////////////////////////////////////////////////////////////
355
356impl<T: Copy> From<(T, T, T)> for Matrix<3, 1, T> {
357    fn from(src: (T, T, T)) -> Self {
358        matrix![src.0; src.1; src.2]
359    }
360}
361
362impl<T: Copy> From<(T, T, T)> for Matrix<1, 3, T> {
363    fn from(src: (T, T, T)) -> Self {
364        matrix![src.0, src.1, src.2]
365    }
366}
367
368impl<T: Copy> From<Matrix<3, 1, T>> for (T, T, T) {
369    fn from(src: Matrix<3, 1, T>) -> Self {
370        (src[0], src[1], src[2])
371    }
372}
373
374impl<T: Copy> From<Matrix<1, 3, T>> for (T, T, T) {
375    fn from(src: Matrix<1, 3, T>) -> Self {
376        (src[0], src[1], src[2])
377    }
378}
379
380impl<T: Copy> From<(T, T, T, T)> for Matrix<4, 1, T> {
381    fn from(src: (T, T, T, T)) -> Self {
382        matrix![src.0; src.1; src.2; src.3]
383    }
384}
385
386impl<T: Copy> From<(T, T, T, T)> for Matrix<1, 4, T> {
387    fn from(src: (T, T, T, T)) -> Self {
388        matrix![src.0, src.1, src.2, src.3]
389    }
390}
391
392impl<T: Copy> From<Matrix<4, 1, T>> for (T, T, T, T) {
393    fn from(src: Matrix<4, 1, T>) -> Self {
394        (src[0], src[1], src[2], src[3])
395    }
396}
397
398impl<T: Copy> From<Matrix<1, 4, T>> for (T, T, T, T) {
399    fn from(src: Matrix<1, 4, T>) -> Self {
400        (src[0], src[1], src[2], src[3])
401    }
402}
403
404// #[cfg(test)]
405impl<const M: usize, const N: usize, T: approx::AbsDiffEq> approx::AbsDiffEq for Matrix<M, N, T>
406where
407    T::Epsilon: Copy,
408{
409    type Epsilon = T::Epsilon;
410    fn default_epsilon() -> Self::Epsilon {
411        T::default_epsilon()
412    }
413
414    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
415        let mut eq = true;
416        for j in 0..N {
417            for i in 0..M {
418                eq = eq && T::abs_diff_eq(&self[(i, j)], &other[(i, j)], epsilon);
419                if !eq {
420                    return false;
421                }
422            }
423        }
424        true
425    }
426}
427
428// #[cfg(test)]
429impl<const M: usize, const N: usize, T: approx::RelativeEq> approx::RelativeEq for Matrix<M, N, T>
430where
431    T::Epsilon: Copy,
432{
433    fn default_max_relative() -> Self::Epsilon {
434        T::default_max_relative()
435    }
436
437    fn relative_eq(
438        &self,
439        other: &Self,
440        epsilon: Self::Epsilon,
441        max_relative: Self::Epsilon,
442    ) -> bool {
443        let mut eq = true;
444        for j in 0..N {
445            for i in 0..M {
446                eq = eq && T::relative_eq(&self[(i, j)], &other[(i, j)], epsilon, max_relative);
447                if !eq {
448                    return false;
449                }
450            }
451        }
452        true
453    }
454}
455
456/// A matrix with one row and `N` columns.
457pub type RowVector<const N: usize, T> = Matrix<1, N, T>;
458
459/// A matrix with one column and `M` rows.
460pub type Vector<const M: usize, T> = Matrix<M, 1, T>;
461
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn create() {
        let m = matrix![
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0;
        ];
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 2)], 6.0);

        let v = vector![1.0, 2.0, 3.0];
        assert_eq!(v[0], 1.0);
        assert_eq!(v[2], 3.0);

        let z = zeros!(2, 3);
        assert_eq!(z[(0, 0)], 0.0);
        assert_eq!(z[(1, 2)], 0.0);

        let z = zeros!(3);
        assert_eq!(z[(2, 2)], 0.0);

        let o = ones!(2, 3);
        assert_eq!(o[(0, 0)], 1.0);
        assert_eq!(o[(1, 2)], 1.0);

        let o = ones!(3);
        assert_eq!(o[(2, 2)], 1.0);
    }

    #[test]
    fn index() {
        let m = matrix![
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0;
        ];
        assert_eq!(m[1], 4.0);
        assert_eq!(m[(1, 2)], 6.0);

        // Storage is column-major.
        let mut s = m.as_slice().iter();
        assert_eq!(s.next(), Some(&1.0));
        assert_eq!(s.next(), Some(&4.0));
        assert_eq!(s.next(), Some(&2.0));
        assert_eq!(s.next(), Some(&5.0));
        assert_eq!(s.next(), Some(&3.0));
        assert_eq!(s.next(), Some(&6.0));
        assert_eq!(s.next(), None);
    }

    #[test]
    fn swap() {
        let mut m = matrix![
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0;
            7.0, 8.0, 9.0;
        ];
        m.swap_rows(0, 2);
        let exp = matrix![
            7.0, 8.0, 9.0;
            4.0, 5.0, 6.0;
            1.0, 2.0, 3.0;
        ];
        assert_eq!(m, exp);
        m.swap_columns(0, 2);
        let exp = matrix![
            9.0, 8.0, 7.0;
            6.0, 5.0, 4.0;
            3.0, 2.0, 1.0;
        ];
        assert_eq!(m, exp);
    }

    #[test]
    fn swap_out_of_bounds_is_noop() {
        let mut m = matrix![
            1.0, 2.0;
            3.0, 4.0;
        ];
        let before = m;
        m.swap_rows(0, 2);
        m.swap_columns(2, 0);
        assert_eq!(m, before);
    }

    #[test]
    fn transpose() {
        let m = matrix![
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0;
        ];
        let t = matrix![
            1.0, 4.0;
            2.0, 5.0;
            3.0, 6.0;
        ];
        assert_eq!(m.transpose(), t);
        assert_eq!(m.T(), t);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn clone() {
        let a = matrix![
            1.0, 2.0, 3.0;
            5.0, 6.0, 4.0;
        ];
        assert_eq!(a.clone(), a);
    }

    #[test]
    fn norm() {
        let m = matrix![
             1.0, -2.0;
            -3.0,  6.0;
        ];
        assert_relative_eq!(m.norm(), 7.0710678, max_relative = 1e-6);
    }

    #[test]
    fn normalize() {
        let m = matrix![3.0, 4.0];
        assert_relative_eq!(m.normalize(), matrix![0.6, 0.8], max_relative = 1e-12);
        assert_relative_eq!(m.normalize().norm(), 1.0, max_relative = 1e-12);
    }

    #[test]
    fn cross() {
        let a = vector![3.0; -3.0; 1.0];
        let b = vector![4.0; 9.0; 2.0];
        let exp = vector![-15.0; -2.0; 39.0];
        assert_relative_eq!(a.cross(&b), exp, max_relative = 1e-6);
    }

    #[test]
    fn trace() {
        let m = matrix![
            9.0, 8.0, 7.0;
            6.0, 5.0, 4.0;
            3.0, 2.0, 1.0;
        ];
        assert_eq!(m.trace(), 15.0);
    }

    #[test]
    fn tuple_conversions() {
        let col: Matrix<3, 1, f64> = (1.0, 2.0, 3.0).into();
        assert_eq!(col, matrix![1.0; 2.0; 3.0]);
        let row: Matrix<1, 3, f64> = (1.0, 2.0, 3.0).into();
        assert_eq!(row, matrix![1.0, 2.0, 3.0]);
        assert_eq!(<(f64, f64, f64)>::from(col), (1.0, 2.0, 3.0));
        assert_eq!(<(f64, f64, f64)>::from(row), (1.0, 2.0, 3.0));

        let col4: Matrix<4, 1, f64> = (1.0, 2.0, 3.0, 4.0).into();
        assert_eq!(col4, matrix![1.0; 2.0; 3.0; 4.0]);
        let row4: Matrix<1, 4, f64> = (1.0, 2.0, 3.0, 4.0).into();
        assert_eq!(row4, matrix![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(<(f64, f64, f64, f64)>::from(col4), (1.0, 2.0, 3.0, 4.0));
        assert_eq!(<(f64, f64, f64, f64)>::from(row4), (1.0, 2.0, 3.0, 4.0));
    }
}