Skip to main content

scenix_math/
mat.rs

1use core::ops::{Index, IndexMut, Mul};
2
3use crate::{EPSILON, Quat, Transform, Vec3, Vec4, tan};
4
5/// A 3x3 column-major matrix.
6#[derive(Clone, Copy, Debug, PartialEq)]
7#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
8pub struct Mat3 {
9    /// Matrix columns.
10    pub cols: [Vec3; 3],
11}
12
13/// A 4x4 column-major matrix.
14#[derive(Clone, Copy, Debug, PartialEq)]
15#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
16pub struct Mat4 {
17    /// Matrix columns.
18    pub cols: [Vec4; 4],
19}
20
21impl Mat3 {
22    /// Identity matrix.
23    pub const IDENTITY: Self = Self::from_cols(Vec3::X, Vec3::Y, Vec3::Z);
24
25    /// Creates a matrix from column vectors.
26    #[inline]
27    pub const fn from_cols(x: Vec3, y: Vec3, z: Vec3) -> Self {
28        Self { cols: [x, y, z] }
29    }
30
31    /// Extracts the upper-left 3x3 matrix from a `Mat4`.
32    #[inline]
33    pub fn from_mat4(matrix: Mat4) -> Self {
34        Self::from_cols(
35            matrix.cols[0].truncate(),
36            matrix.cols[1].truncate(),
37            matrix.cols[2].truncate(),
38        )
39    }
40
41    /// Returns the element at row and column.
42    #[inline]
43    pub fn get(self, row: usize, col: usize) -> f32 {
44        self.cols[col][row]
45    }
46
47    /// Returns the determinant.
48    #[inline]
49    pub fn determinant(self) -> f32 {
50        let a = self.get(0, 0);
51        let b = self.get(0, 1);
52        let c = self.get(0, 2);
53        let d = self.get(1, 0);
54        let e = self.get(1, 1);
55        let f = self.get(1, 2);
56        let g = self.get(2, 0);
57        let h = self.get(2, 1);
58        let i = self.get(2, 2);
59
60        a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
61    }
62
63    /// Returns the inverse matrix, if the matrix is invertible.
64    #[allow(clippy::needless_range_loop)]
65    pub fn inverse(self) -> Option<Self> {
66        let det = self.determinant();
67        if det.abs() <= EPSILON {
68            return None;
69        }
70        let inv_det = 1.0 / det;
71        let m = |r, c| self.get(r, c);
72        Some(Self::from_cols(
73            Vec3::new(
74                (m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1)) * inv_det,
75                (m(1, 2) * m(2, 0) - m(1, 0) * m(2, 2)) * inv_det,
76                (m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0)) * inv_det,
77            ),
78            Vec3::new(
79                (m(0, 2) * m(2, 1) - m(0, 1) * m(2, 2)) * inv_det,
80                (m(0, 0) * m(2, 2) - m(0, 2) * m(2, 0)) * inv_det,
81                (m(0, 1) * m(2, 0) - m(0, 0) * m(2, 1)) * inv_det,
82            ),
83            Vec3::new(
84                (m(0, 1) * m(1, 2) - m(0, 2) * m(1, 1)) * inv_det,
85                (m(0, 2) * m(1, 0) - m(0, 0) * m(1, 2)) * inv_det,
86                (m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0)) * inv_det,
87            ),
88        ))
89    }
90
91    /// Returns the transpose matrix.
92    #[inline]
93    pub fn transpose(self) -> Self {
94        Self::from_cols(
95            Vec3::new(self.get(0, 0), self.get(0, 1), self.get(0, 2)),
96            Vec3::new(self.get(1, 0), self.get(1, 1), self.get(1, 2)),
97            Vec3::new(self.get(2, 0), self.get(2, 1), self.get(2, 2)),
98        )
99    }
100
101    /// Multiplies this matrix by another matrix.
102    #[inline]
103    pub fn mul_mat3(self, rhs: Self) -> Self {
104        Self::from_cols(
105            self.mul_vec3(rhs.cols[0]),
106            self.mul_vec3(rhs.cols[1]),
107            self.mul_vec3(rhs.cols[2]),
108        )
109    }
110
111    /// Multiplies this matrix by a vector.
112    #[inline]
113    pub fn mul_vec3(self, rhs: Vec3) -> Vec3 {
114        self.cols[0] * rhs.x + self.cols[1] * rhs.y + self.cols[2] * rhs.z
115    }
116
117    /// Returns the matrix as a column-major array.
118    #[inline]
119    pub fn to_cols_array(self) -> [f32; 9] {
120        [
121            self.cols[0].x,
122            self.cols[0].y,
123            self.cols[0].z,
124            self.cols[1].x,
125            self.cols[1].y,
126            self.cols[1].z,
127            self.cols[2].x,
128            self.cols[2].y,
129            self.cols[2].z,
130        ]
131    }
132}
133
134impl Mat4 {
135    /// Identity matrix.
136    pub const IDENTITY: Self = Self::from_cols(Vec4::X, Vec4::Y, Vec4::Z, Vec4::W);
137
138    /// Creates a matrix from column vectors.
139    #[inline]
140    pub const fn from_cols(x: Vec4, y: Vec4, z: Vec4, w: Vec4) -> Self {
141        Self { cols: [x, y, z, w] }
142    }
143
144    /// Creates a matrix from a column-major array.
145    #[inline]
146    pub const fn from_cols_array(values: [f32; 16]) -> Self {
147        Self::from_cols(
148            Vec4::new(values[0], values[1], values[2], values[3]),
149            Vec4::new(values[4], values[5], values[6], values[7]),
150            Vec4::new(values[8], values[9], values[10], values[11]),
151            Vec4::new(values[12], values[13], values[14], values[15]),
152        )
153    }
154
155    /// Returns the element at row and column.
156    #[inline]
157    pub fn get(self, row: usize, col: usize) -> f32 {
158        self.cols[col][row]
159    }
160
161    /// Returns a right-handed perspective projection matrix with WebGPU depth.
162    pub fn perspective(fov_y_rad: f32, aspect: f32, near: f32, far: f32) -> Self {
163        if aspect.abs() <= EPSILON || near <= 0.0 || far <= near {
164            return Self::IDENTITY;
165        }
166
167        let f = 1.0 / tan(fov_y_rad * 0.5);
168        Self::from_cols(
169            Vec4::new(f / aspect, 0.0, 0.0, 0.0),
170            Vec4::new(0.0, f, 0.0, 0.0),
171            Vec4::new(0.0, 0.0, far / (near - far), -1.0),
172            Vec4::new(0.0, 0.0, (near * far) / (near - far), 0.0),
173        )
174    }
175
176    /// Returns a right-handed orthographic projection matrix with WebGPU depth.
177    pub fn orthographic(left: f32, right: f32, bottom: f32, top: f32, near: f32, far: f32) -> Self {
178        let width = right - left;
179        let height = top - bottom;
180        let depth = near - far;
181        if width.abs() <= EPSILON || height.abs() <= EPSILON || depth.abs() <= EPSILON {
182            return Self::IDENTITY;
183        }
184
185        Self::from_cols(
186            Vec4::new(2.0 / width, 0.0, 0.0, 0.0),
187            Vec4::new(0.0, 2.0 / height, 0.0, 0.0),
188            Vec4::new(0.0, 0.0, 1.0 / depth, 0.0),
189            Vec4::new(
190                -(right + left) / width,
191                -(top + bottom) / height,
192                near / depth,
193                1.0,
194            ),
195        )
196    }
197
198    /// Returns a right-handed view matrix looking from `eye` to `target`.
199    pub fn look_at(eye: Vec3, target: Vec3, up: Vec3) -> Self {
200        let forward = (target - eye).normalize();
201        if forward.length_squared() <= EPSILON {
202            return Self::from_translation(-eye);
203        }
204
205        let right = forward.cross(up).normalize();
206        let up = right.cross(forward).normalize();
207
208        Self::from_cols(
209            Vec4::new(right.x, up.x, -forward.x, 0.0),
210            Vec4::new(right.y, up.y, -forward.y, 0.0),
211            Vec4::new(right.z, up.z, -forward.z, 0.0),
212            Vec4::new(-right.dot(eye), -up.dot(eye), forward.dot(eye), 1.0),
213        )
214    }
215
216    /// Creates a translation matrix.
217    #[inline]
218    pub fn from_translation(value: Vec3) -> Self {
219        Self::from_cols(
220            Vec4::X,
221            Vec4::Y,
222            Vec4::Z,
223            Vec4::new(value.x, value.y, value.z, 1.0),
224        )
225    }
226
227    /// Creates a rotation matrix from a quaternion.
228    #[inline]
229    pub fn from_rotation(rotation: Quat) -> Self {
230        rotation.to_mat4()
231    }
232
233    /// Creates a scale matrix.
234    #[inline]
235    pub fn from_scale(value: Vec3) -> Self {
236        Self::from_cols(
237            Vec4::new(value.x, 0.0, 0.0, 0.0),
238            Vec4::new(0.0, value.y, 0.0, 0.0),
239            Vec4::new(0.0, 0.0, value.z, 0.0),
240            Vec4::W,
241        )
242    }
243
244    /// Creates a matrix from translation, rotation, and scale.
245    #[inline]
246    pub fn from_trs(translation: Vec3, rotation: Quat, scale: Vec3) -> Self {
247        Self::from_translation(translation)
248            .mul_mat4(Self::from_rotation(rotation))
249            .mul_mat4(Self::from_scale(scale))
250    }
251
252    /// Multiplies this matrix by another matrix.
253    #[inline]
254    pub fn mul_mat4(self, rhs: Self) -> Self {
255        Self::from_cols(
256            self.mul_vec4(rhs.cols[0]),
257            self.mul_vec4(rhs.cols[1]),
258            self.mul_vec4(rhs.cols[2]),
259            self.mul_vec4(rhs.cols[3]),
260        )
261    }
262
263    /// Multiplies this matrix by a vector.
264    #[inline]
265    pub fn mul_vec4(self, rhs: Vec4) -> Vec4 {
266        self.cols[0] * rhs.x + self.cols[1] * rhs.y + self.cols[2] * rhs.z + self.cols[3] * rhs.w
267    }
268
269    /// Transforms a point and applies homogeneous divide when possible.
270    #[inline]
271    pub fn mul_vec3(self, rhs: Vec3) -> Vec3 {
272        let out = self.mul_vec4(Vec4::new(rhs.x, rhs.y, rhs.z, 1.0));
273        if out.w.abs() <= EPSILON {
274            out.truncate()
275        } else {
276            out.truncate() / out.w
277        }
278    }
279
280    /// Returns the inverse matrix, if the matrix is invertible.
281    #[allow(clippy::needless_range_loop)]
282    pub fn inverse(self) -> Option<Self> {
283        let mut aug = [[0.0_f32; 8]; 4];
284        for row in 0..4 {
285            for col in 0..4 {
286                aug[row][col] = self.get(row, col);
287            }
288            aug[row][row + 4] = 1.0;
289        }
290
291        for col in 0..4 {
292            let mut pivot = col;
293            let mut pivot_abs = aug[pivot][col].abs();
294            for (row, values) in aug.iter().enumerate().skip(col + 1) {
295                let value_abs = values[col].abs();
296                if value_abs > pivot_abs {
297                    pivot = row;
298                    pivot_abs = value_abs;
299                }
300            }
301            if pivot_abs <= EPSILON {
302                return None;
303            }
304            if pivot != col {
305                aug.swap(pivot, col);
306            }
307
308            let inv_pivot = 1.0 / aug[col][col];
309            for value in &mut aug[col] {
310                *value *= inv_pivot;
311            }
312
313            for row in 0..4 {
314                if row == col {
315                    continue;
316                }
317                let factor = aug[row][col];
318                if factor.abs() <= EPSILON {
319                    continue;
320                }
321                for i in 0..8 {
322                    aug[row][i] -= factor * aug[col][i];
323                }
324            }
325        }
326
327        Some(Self::from_cols(
328            Vec4::new(aug[0][4], aug[1][4], aug[2][4], aug[3][4]),
329            Vec4::new(aug[0][5], aug[1][5], aug[2][5], aug[3][5]),
330            Vec4::new(aug[0][6], aug[1][6], aug[2][6], aug[3][6]),
331            Vec4::new(aug[0][7], aug[1][7], aug[2][7], aug[3][7]),
332        ))
333    }
334
335    /// Returns the transpose matrix.
336    #[inline]
337    pub fn transpose(self) -> Self {
338        Self::from_cols(
339            Vec4::new(
340                self.get(0, 0),
341                self.get(0, 1),
342                self.get(0, 2),
343                self.get(0, 3),
344            ),
345            Vec4::new(
346                self.get(1, 0),
347                self.get(1, 1),
348                self.get(1, 2),
349                self.get(1, 3),
350            ),
351            Vec4::new(
352                self.get(2, 0),
353                self.get(2, 1),
354                self.get(2, 2),
355                self.get(2, 3),
356            ),
357            Vec4::new(
358                self.get(3, 0),
359                self.get(3, 1),
360                self.get(3, 2),
361                self.get(3, 3),
362            ),
363        )
364    }
365
366    /// Decomposes a TRS matrix into translation, rotation, and scale.
367    pub fn decompose(self) -> Option<Transform> {
368        let translation = self.cols[3].truncate();
369        let scale = Vec3::new(
370            self.cols[0].truncate().length(),
371            self.cols[1].truncate().length(),
372            self.cols[2].truncate().length(),
373        );
374        if scale.x <= EPSILON || scale.y <= EPSILON || scale.z <= EPSILON {
375            return None;
376        }
377
378        let inv_scale = Vec3::new(1.0 / scale.x, 1.0 / scale.y, 1.0 / scale.z);
379        let rotation_matrix = Self::from_cols(
380            Vec4::new(
381                self.cols[0].x * inv_scale.x,
382                self.cols[0].y * inv_scale.x,
383                self.cols[0].z * inv_scale.x,
384                0.0,
385            ),
386            Vec4::new(
387                self.cols[1].x * inv_scale.y,
388                self.cols[1].y * inv_scale.y,
389                self.cols[1].z * inv_scale.y,
390                0.0,
391            ),
392            Vec4::new(
393                self.cols[2].x * inv_scale.z,
394                self.cols[2].y * inv_scale.z,
395                self.cols[2].z * inv_scale.z,
396                0.0,
397            ),
398            Vec4::W,
399        );
400
401        Some(Transform::new(
402            translation,
403            Quat::from_mat4(rotation_matrix),
404            scale,
405        ))
406    }
407
408    /// Returns the matrix as a column-major array.
409    #[inline]
410    pub fn to_cols_array(self) -> [f32; 16] {
411        [
412            self.cols[0].x,
413            self.cols[0].y,
414            self.cols[0].z,
415            self.cols[0].w,
416            self.cols[1].x,
417            self.cols[1].y,
418            self.cols[1].z,
419            self.cols[1].w,
420            self.cols[2].x,
421            self.cols[2].y,
422            self.cols[2].z,
423            self.cols[2].w,
424            self.cols[3].x,
425            self.cols[3].y,
426            self.cols[3].z,
427            self.cols[3].w,
428        ]
429    }
430}
431
432impl Default for Mat3 {
433    #[inline]
434    fn default() -> Self {
435        Self::IDENTITY
436    }
437}
438
439impl Default for Mat4 {
440    #[inline]
441    fn default() -> Self {
442        Self::IDENTITY
443    }
444}
445
446impl Index<usize> for Mat3 {
447    type Output = Vec3;
448
449    #[inline]
450    fn index(&self, index: usize) -> &Self::Output {
451        &self.cols[index]
452    }
453}
454
455impl IndexMut<usize> for Mat3 {
456    #[inline]
457    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
458        &mut self.cols[index]
459    }
460}
461
462impl Index<usize> for Mat4 {
463    type Output = Vec4;
464
465    #[inline]
466    fn index(&self, index: usize) -> &Self::Output {
467        &self.cols[index]
468    }
469}
470
471impl IndexMut<usize> for Mat4 {
472    #[inline]
473    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
474        &mut self.cols[index]
475    }
476}
477
478impl Mul for Mat3 {
479    type Output = Self;
480
481    #[inline]
482    fn mul(self, rhs: Self) -> Self::Output {
483        self.mul_mat3(rhs)
484    }
485}
486
487impl Mul<Vec3> for Mat3 {
488    type Output = Vec3;
489
490    #[inline]
491    fn mul(self, rhs: Vec3) -> Self::Output {
492        self.mul_vec3(rhs)
493    }
494}
495
496impl Mul for Mat4 {
497    type Output = Self;
498
499    #[inline]
500    fn mul(self, rhs: Self) -> Self::Output {
501        self.mul_mat4(rhs)
502    }
503}
504
505impl Mul<Vec4> for Mat4 {
506    type Output = Vec4;
507
508    #[inline]
509    fn mul(self, rhs: Vec4) -> Self::Output {
510        self.mul_vec4(rhs)
511    }
512}
513
514impl Mul<Vec3> for Mat4 {
515    type Output = Vec3;
516
517    #[inline]
518    fn mul(self, rhs: Vec3) -> Self::Output {
519        self.mul_vec3(rhs)
520    }
521}
522
// Implements the `approx` comparison traits for a matrix type by comparing its `cols`
// pairwise with the column type's own `approx` implementations. The column type is assumed
// to implement the same traits with `Epsilon = f32`.
#[cfg(feature = "approx")]
macro_rules! impl_matrix_approx {
    ($type:ident) => {
        impl approx::AbsDiffEq for $type {
            type Epsilon = f32;

            #[inline]
            fn default_epsilon() -> Self::Epsilon {
                // Fully qualified because the `approx` traits are not imported in this
                // module; a bare `f32::default_epsilon()` would not resolve.
                <f32 as approx::AbsDiffEq>::default_epsilon()
            }

            #[inline]
            fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
                self.cols
                    .iter()
                    .zip(other.cols.iter())
                    .all(|(a, b)| approx::AbsDiffEq::abs_diff_eq(a, b, epsilon))
            }
        }

        impl approx::RelativeEq for $type {
            #[inline]
            fn default_max_relative() -> Self::Epsilon {
                <f32 as approx::RelativeEq>::default_max_relative()
            }

            #[inline]
            fn relative_eq(
                &self,
                other: &Self,
                epsilon: Self::Epsilon,
                max_relative: Self::Epsilon,
            ) -> bool {
                self.cols
                    .iter()
                    .zip(other.cols.iter())
                    .all(|(a, b)| approx::RelativeEq::relative_eq(a, b, epsilon, max_relative))
            }
        }

        impl approx::UlpsEq for $type {
            #[inline]
            fn default_max_ulps() -> u32 {
                <f32 as approx::UlpsEq>::default_max_ulps()
            }

            #[inline]
            fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
                self.cols
                    .iter()
                    .zip(other.cols.iter())
                    .all(|(a, b)| approx::UlpsEq::ulps_eq(a, b, epsilon, max_ulps))
            }
        }
    };
}

#[cfg(feature = "approx")]
impl_matrix_approx!(Mat3);
#[cfg(feature = "approx")]
impl_matrix_approx!(Mat4);
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_close;

    #[test]
    fn perspective_maps_near_and_far_to_webgpu_depth() {
        let projection = Mat4::perspective(core::f32::consts::FRAC_PI_2, 1.0, 0.1, 10.0);
        let near = projection.mul_vec4(Vec4::new(0.0, 0.0, -0.1, 1.0));
        let far = projection.mul_vec4(Vec4::new(0.0, 0.0, -10.0, 1.0));

        assert_close(near.z / near.w, 0.0);
        assert_close(far.z / far.w, 1.0);
    }

    #[test]
    fn perspective_rejects_degenerate_planes() {
        assert_eq!(Mat4::perspective(1.0, 1.0, 0.0, 10.0), Mat4::IDENTITY);
        assert_eq!(Mat4::perspective(1.0, 1.0, 10.0, 1.0), Mat4::IDENTITY);
        assert_eq!(Mat4::perspective(1.0, 0.0, 0.1, 10.0), Mat4::IDENTITY);
    }

    #[test]
    fn orthographic_maps_center_and_depth() {
        let projection = Mat4::orthographic(-1.0, 1.0, -1.0, 1.0, 0.1, 10.0);
        let center = projection.mul_vec4(Vec4::new(0.0, 0.0, -0.1, 1.0));
        assert_close(center.x, 0.0);
        assert_close(center.y, 0.0);
        assert_close(center.z, 0.0);

        let far = projection.mul_vec4(Vec4::new(0.0, 0.0, -10.0, 1.0));
        assert_close(far.z, 1.0);
    }

    #[test]
    fn look_at_moves_target_onto_negative_z_axis() {
        let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::new(0.0, 0.0, 0.0), Vec3::Y);
        let target = view.mul_vec3(Vec3::new(0.0, 0.0, 0.0));
        assert_close(target.x, 0.0);
        assert_close(target.y, 0.0);
        assert_close(target.z, -5.0);
    }

    #[test]
    fn inverse_multiplies_to_identity() {
        let matrix = Mat4::from_trs(
            Vec3::new(2.0, 3.0, 4.0),
            Quat::from_axis_angle(Vec3::Y, 0.7),
            Vec3::new(2.0, 3.0, 4.0),
        );
        let inverse = matrix.inverse().unwrap();
        let identity = matrix * inverse;
        let values = identity.to_cols_array();
        let expected = Mat4::IDENTITY.to_cols_array();
        for (a, b) in values.into_iter().zip(expected) {
            assert_close(a, b);
        }
    }

    #[test]
    fn singular_mat4_has_no_inverse() {
        assert!(Mat4::from_scale(Vec3::new(1.0, 0.0, 1.0)).inverse().is_none());
    }

    #[test]
    fn mat3_inverse_multiplies_to_identity() {
        let matrix = Mat3::from_mat4(Mat4::from_trs(
            Vec3::new(1.0, -2.0, 0.5),
            Quat::from_axis_angle(Vec3::Y, 0.7),
            Vec3::new(2.0, 3.0, 4.0),
        ));
        let identity = matrix * matrix.inverse().unwrap();
        let values = identity.to_cols_array();
        let expected = Mat3::IDENTITY.to_cols_array();
        for (a, b) in values.into_iter().zip(expected) {
            assert_close(a, b);
        }
    }

    #[test]
    fn singular_mat3_has_no_inverse() {
        let matrix = Mat3::from_cols(Vec3::X, Vec3::X, Vec3::Z);
        assert_eq!(matrix.determinant(), 0.0);
        assert!(matrix.inverse().is_none());
    }

    #[test]
    fn decompose_rejects_zero_scale() {
        let matrix = Mat4::from_scale(Vec3::new(0.0, 1.0, 1.0));
        assert!(matrix.decompose().is_none());
    }

    #[test]
    fn transpose_and_column_major_array_work() {
        let matrix = Mat4::from_cols_array([
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ]);
        assert_eq!(matrix.to_cols_array()[1], 2.0);
        assert_eq!(matrix.transpose().get(0, 1), 2.0);
    }
}