stack_algebra/
new.rs

1use crate::num::{One, Zero};
2use crate::Matrix;
3
4use core::mem;
5use core::mem::MaybeUninit;
6use core::ptr;
7
8////////////////////////////////////////////////////////////////////////////////
9// Matrix<M,N,T> methods
10////////////////////////////////////////////////////////////////////////////////
11impl<const M: usize, const N: usize, T> Matrix<M, N, T> {
12    /// Create a new matrix from an array of arrays in column-major order.
13    #[doc(hidden)]
14    #[inline]
15    pub const fn from_column_major_order(data: [[T; M]; N]) -> Self {
16        Self { data }
17    }
18}
19
20impl<const M: usize, const N: usize, T> Matrix<M, N, T>
21where
22    T: Zero + Copy,
23{
24    /// Create a new matrix from an array of arrays in column-major order.
25    #[doc(hidden)]
26    #[inline]
27    pub fn zeros() -> Self {
28        Self::from_column_major_order([[T::zero(); M]; N])
29    }
30}
31
32impl<const M: usize, const N: usize, T> Matrix<M, N, T>
33where
34    T: One + Copy,
35{
36    /// Create a new matrix from an array of arrays in column-major order.
37    #[doc(hidden)]
38    #[inline]
39    pub fn ones() -> Self {
40        Self::from_column_major_order([[T::one(); M]; N])
41    }
42}
43
44impl<const D: usize, T> Matrix<D, D, T>
45where
46    T: Zero + One + Copy,
47{
48    /// Create a new matrix from an array of arrays in column-major order.
49    #[doc(hidden)]
50    #[inline]
51    pub fn eye() -> Self {
52        let mut m = Self::from_column_major_order([[T::zero(); D]; D]);
53        for i in 0..D {
54            m[(i, i)] = T::one();
55        }
56        m
57    }
58}
59
/// A macro for creating a matrix from a literal list of elements.
///
/// The tokens are forwarded unchanged to the `proc_macro::matrix!` procedural
/// macro, and its output (an array of arrays in column-major order) is wrapped
/// with `Matrix::from_column_major_order`. Elements are separated by `,` and
/// groups by `;` (presumably one group per row — see the procedural macro for
/// the exact grammar).
#[macro_export]
macro_rules! matrix {
    ($($tokens:tt)*) => {
        $crate::Matrix::from_column_major_order(
            $crate::proc_macro::matrix!($($tokens)*)
        )
    };
}
67
/// A macro for composing vectors from a literal list of elements.
///
/// The expansion is the same as `matrix!`: the tokens are forwarded unchanged
/// to `proc_macro::matrix!` and the resulting column-major array of arrays is
/// wrapped with `Matrix::from_column_major_order`. The resulting shape is
/// decided by the procedural macro; this name exists for readability at call
/// sites that build vectors.
#[macro_export]
macro_rules! vector {
    ($($elems:tt)*) => {
        $crate::Matrix::from_column_major_order(
            $crate::proc_macro::matrix!($($elems)*)
        )
    };
}
75
/// Creates a matrix whose elements are all zero.
///
/// - `zeros!(d)` creates a square `d × d` matrix.
/// - `zeros!(r, c)` creates an `r × c` matrix.
/// - `zeros!(r, c, T)` creates an `r × c` matrix with element type `T`.
///
/// When no element type is given, `Matrix`'s default type parameter is used.
#[macro_export]
macro_rules! zeros {
    ($dim:expr) => {{
        $crate::Matrix::<$dim, $dim>::zeros()
    }};
    ($rows:expr, $cols:expr) => {{
        $crate::Matrix::<$rows, $cols>::zeros()
    }};
    ($rows:expr, $cols:expr, $ty:ty) => {{
        $crate::Matrix::<$rows, $cols, $ty>::zeros()
    }};
}
88
/// Creates a matrix whose elements are all one.
///
/// - `ones!(d)` creates a square `d × d` matrix.
/// - `ones!(r, c)` creates an `r × c` matrix.
/// - `ones!(r, c, T)` creates an `r × c` matrix with element type `T`.
///
/// When no element type is given, `Matrix`'s default type parameter is used.
#[macro_export]
macro_rules! ones {
    ($dim:expr) => {{
        $crate::Matrix::<$dim, $dim>::ones()
    }};
    ($rows:expr, $cols:expr) => {{
        $crate::Matrix::<$rows, $cols>::ones()
    }};
    ($rows:expr, $cols:expr, $ty:ty) => {{
        $crate::Matrix::<$rows, $cols, $ty>::ones()
    }};
}
101
/// Creates a square identity matrix.
///
/// - `eye!(d)` creates a `d × d` identity matrix.
/// - `eye!(d, T)` creates a `d × d` identity matrix with element type `T`.
///
/// When no element type is given, `Matrix`'s default type parameter is used.
#[macro_export]
macro_rules! eye {
    ($dim:expr) => {{
        $crate::Matrix::<$dim, $dim>::eye()
    }};
    ($dim:expr, $ty:ty) => {{
        $crate::Matrix::<$dim, $dim, $ty>::eye()
    }};
}
111
/// Creates a square matrix with the given values on its main diagonal and
/// zeros everywhere else.
///
/// The dimension equals the number of arguments. One or more values are
/// accepted, and a trailing comma is allowed. The matrix is created with
/// `Matrix::<D, D>::zeros()`, so the element type is `Matrix`'s default type
/// parameter. Each value is then assigned to its diagonal position, in
/// argument order.
#[macro_export]
macro_rules! diag {
    ($($d:expr),+ $(,)?) => {{
        // Count the arguments at compile time; `stringify!` means the value
        // expressions themselves are not evaluated here.
        const __DIAG_DIM: usize = [$(::core::stringify!($d)),+].len();
        let diagonal = [$($d),+];
        let mut m = $crate::Matrix::<__DIAG_DIM, __DIAG_DIM>::zeros();
        for (i, value) in diagonal.iter().enumerate() {
            m[(i, i)] = *value;
        }
        m
    }};
}
155
156////////////////////////////////////////////////////////////////////////////////
157// Uninit related methods
158////////////////////////////////////////////////////////////////////////////////
159
/// Size-heterogeneous transmutation.
///
/// Reinterprets the bytes of `a` as a value of type `B`, consuming `a`
/// without running its destructor.
///
/// This is required because the compiler doesn't yet know how to deal with the
/// size of const arrays. We should be able to use [`mem::transmute()`] but it
/// rejects types whose sizes depend on const generic parameters, even when the
/// sizes are in fact equal.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// - `size_of::<B>()` is not larger than `size_of::<A>()` (in practice the two
///   types are expected to have the same size), and
/// - the bytes of `a` form a valid value of type `B`.
///
/// Any resources owned by `a` become owned by the returned `B`; `a`'s
/// destructor is never run. No alignment relationship between `A` and `B` is
/// required.
#[inline]
pub unsafe fn transmute_unchecked<A, B>(a: A) -> B {
    debug_assert!(
        mem::size_of::<B>() <= mem::size_of::<A>(),
        "transmute_unchecked: destination type is larger than source type"
    );
    // Disable `a`'s destructor before reading it, because ownership moves into
    // the returned `B`.
    let a = mem::ManuallyDrop::new(a);
    // SAFETY: the caller guarantees `B` fits within `A` and that the bytes are a
    // valid `B`. `read_unaligned` removes any requirement that `B` be no more
    // strictly aligned than `A`.
    unsafe { ptr::read_unaligned((&*a as *const A).cast::<B>()) }
}
171
172impl<T, const M: usize, const N: usize> Matrix<M, N, MaybeUninit<T>> {
173    /// Create a new matrix with uninitialized contents.
174    #[inline]
175    pub(crate) fn uninit() -> Self {
176        // SAFETY: The `assume_init` is safe because the type we are claiming to
177        // have initialized here is a bunch of `MaybeUninit`s, which do not
178        // require initialization. Additionally, `Matrix` is `repr(transparent)`
179        // with an array of arrays.
180        //
181        // Note: this is not the most ideal way of doing this. In the future
182        // when Rust allows inline const expressions we might be able to use
183        // `Self { data: [const { MaybeUninit::<T>::uninit() }; M] ; N] }`
184        //
185        // See https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#initializing-an-array-element-by-element
186        let matrix = MaybeUninit::uninit();
187        unsafe { matrix.assume_init() }
188    }
189
190    /// Assumes the data is initialized and extracts each element as `T`.
191    ///
192    /// # Safety
193    ///
194    /// As with [`MaybeUninit::assume_init`], it is up to the caller to
195    /// guarantee that the matrix is really is in an initialized state. Calling
196    /// this when the contents are not yet fully initialized causes immediate
197    /// undefined behavior.
198    #[inline]
199    pub(crate) unsafe fn assume_init(self) -> Matrix<M, N, T> {
200        // SAFETY: The caller is responsible for all the elements being
201        // initialized. Additionally, we know that `T` is the same size as
202        // `MaybeUninit<T>`.
203        unsafe { transmute_unchecked(self) }
204    }
205}
206
207////////////////////////////////////////////////////////////////////////////////
208// FromIterator
209////////////////////////////////////////////////////////////////////////////////
210
211/// Pulls `M * N` items from `iter` and fills a matrix. If the iterator yields
212/// fewer than `M * N` items, `Err(_)` is returned and all already yielded items
213/// are dropped.
214///
215/// If `iter.next()` panics, all items already yielded by the iterator are
216/// dropped.
217pub fn collect<I, T, const M: usize, const N: usize>(mut iter: I) -> Result<Matrix<M, N, T>, usize>
218where
219    I: Iterator<Item = T>,
220{
221    struct Guard<'a, T, const M: usize, const N: usize> {
222        matrix: &'a mut Matrix<M, N, MaybeUninit<T>>,
223        init: usize,
224    }
225
226    impl<T, const M: usize, const N: usize> Drop for Guard<'_, T, M, N> {
227        fn drop(&mut self) {
228            for elem in &mut self.matrix.as_mut_slice()[..self.init] {
229                // SAFETY: this raw slice up to `self.len` will only contain
230                // the initialized objects.
231                unsafe { ptr::drop_in_place(elem.as_mut_ptr()) };
232            }
233        }
234    }
235
236    let mut matrix: Matrix<M, N, MaybeUninit<T>> = Matrix::uninit();
237    let mut guard = Guard {
238        matrix: &mut matrix,
239        init: 0,
240    };
241
242    for _ in 0..(M * N) {
243        match iter.next() {
244            Some(item) => {
245                // SAFETY: `guard.init` starts at zero, is increased by 1 each
246                // iteration of the loop, and the loop is aborted once M * N
247                // is reached, which is the length of the matrix.
248                unsafe { guard.matrix.get_unchecked_mut(guard.init).write(item) };
249                guard.init += 1;
250            }
251            None => {
252                return Err(guard.init);
253                // <-- guard is dropped here with already initialized elements
254            }
255        }
256    }
257
258    mem::forget(guard);
259    // SAFETY: the loop above loops exactly M * N times which is the size of the
260    // matrix, so all elements in the matrix are initialized.
261    Ok(unsafe { matrix.assume_init() })
262}
263
264// /// Like [`collect()`] except the caller must guarantee that the iterator will
265// /// yield enough elements to fill the matrix.
266// pub unsafe fn collect_unchecked<I, T, const M: usize, const N: usize>(iter: I) -> Matrix<M, N, T>
267// where
268//     I: IntoIterator<Item = T>,
269// {
270//     match collect(iter.into_iter()) {
271//         Ok(matrix) => matrix,
272//         Err(_) => {
273//             // SAFETY: the caller guarantees the iterator will yield enough
274//             // elements, so this error case can never be reached.
275//             unsafe { hint::unreachable_unchecked() }
276//         }
277//     }
278// }
279
280impl<T, const M: usize, const N: usize> FromIterator<T> for Matrix<M, N, T> {
281    /// Create a new matrix from an iterator.
282    ///
283    /// Elements will be filled in column-major order.
284    ///
285    /// # Panics
286    ///
287    /// If the iterator doesn't yield enough elements to fill the matrix.
288    #[inline]
289    fn from_iter<I>(iter: I) -> Self
290    where
291        I: IntoIterator<Item = T>,
292    {
293        collect(iter.into_iter()).unwrap_or_else(|len| collect_panic::<M, N>(len))
294    }
295}
296
/// Panics because an iterator yielded only `len` items, too few to fill an
/// `M × N` matrix.
///
/// The message names the shape as `Vector` when `N == 1` (checked first), as
/// `RowVector` when `M == 1`, and as `Matrix` otherwise.
#[cold]
fn collect_panic<const M: usize, const N: usize>(len: usize) -> ! {
    match (M, N) {
        (_, 1) => panic!("collect iterator of length {} into `Vector<_, {}>`", len, M),
        (1, _) => panic!(
            "collect iterator of length {} into `RowVector<_, {}>`",
            len, N
        ),
        _ => panic!(
            "collect iterator of length {} into `Matrix<_, {}, {}>`",
            len, M, N
        ),
    }
}
313
#[cfg(test)]
mod new_test {
    use approx::assert_relative_eq;

    /// `diag!` must agree with the equivalent dense `matrix!` literal for
    /// dimensions 2 through 6.
    #[test]
    fn diag() {
        {
            let actual = diag!(0.1, 0.2);
            let expected = matrix![
                0.1, 0.0;
                0.0, 0.2;
            ];
            assert_relative_eq!(actual, expected, max_relative = 1e-6);
        }
        {
            let actual = diag!(0.1, 0.2, 0.3);
            let expected = matrix![
                0.1, 0.0, 0.0;
                0.0, 0.2, 0.0;
                0.0, 0.0, 0.3;
            ];
            assert_relative_eq!(actual, expected, max_relative = 1e-6);
        }
        {
            let actual = diag!(0.1, 0.2, 0.3, 0.4);
            let expected = matrix![
                0.1, 0.0, 0.0, 0.0;
                0.0, 0.2, 0.0, 0.0;
                0.0, 0.0, 0.3, 0.0;
                0.0, 0.0, 0.0, 0.4;
            ];
            assert_relative_eq!(actual, expected, max_relative = 1e-6);
        }
        {
            let actual = diag!(0.1, 0.2, 0.3, 0.4, 0.5);
            let expected = matrix![
                0.1, 0.0, 0.0, 0.0, 0.0;
                0.0, 0.2, 0.0, 0.0, 0.0;
                0.0, 0.0, 0.3, 0.0, 0.0;
                0.0, 0.0, 0.0, 0.4, 0.0;
                0.0, 0.0, 0.0, 0.0, 0.5;
            ];
            assert_relative_eq!(actual, expected, max_relative = 1e-6);
        }
        {
            let actual = diag!(0.1, 0.2, 0.3, 0.4, 0.5, 0.6);
            let expected = matrix![
                0.1, 0.0, 0.0, 0.0, 0.0, 0.0;
                0.0, 0.2, 0.0, 0.0, 0.0, 0.0;
                0.0, 0.0, 0.3, 0.0, 0.0, 0.0;
                0.0, 0.0, 0.0, 0.4, 0.0, 0.0;
                0.0, 0.0, 0.0, 0.0, 0.5, 0.0;
                0.0, 0.0, 0.0, 0.0, 0.0, 0.6;
            ];
            assert_relative_eq!(actual, expected, max_relative = 1e-6);
        }
    }
}