// sable_core/math/mat3.rs

1//! 3x3 matrix type.
2
3use super::{EPSILON, Float, Vec3};
4
/// A 3x3 matrix stored as three column vectors (column-major).
///
/// `cols[j]` holds column `j`; the `x`, `y` and `z` components of a column are its
/// entries in rows 0, 1 and 2 respectively. `#[repr(C)]` gives the struct a
/// C-compatible field layout.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct Mat3 {
    /// The three columns of the matrix, indexed `0..3` from left to right.
    pub cols: [Vec3; 3],
}
12
13impl Default for Mat3 {
14    fn default() -> Self {
15        Self::IDENTITY
16    }
17}
18
19impl Mat3 {
20    /// Zero matrix.
21    pub const ZERO: Self = Self {
22        cols: [Vec3::ZERO, Vec3::ZERO, Vec3::ZERO],
23    };
24
25    /// Identity matrix.
26    pub const IDENTITY: Self = Self {
27        cols: [Vec3::X, Vec3::Y, Vec3::Z],
28    };
29
30    /// Creates a matrix from column vectors.
31    #[inline]
32    #[must_use]
33    pub const fn from_cols(c0: Vec3, c1: Vec3, c2: Vec3) -> Self {
34        Self { cols: [c0, c1, c2] }
35    }
36
37    /// Creates a diagonal matrix.
38    #[inline]
39    #[must_use]
40    pub fn from_diagonal(diag: Vec3) -> Self {
41        Self {
42            cols: [
43                Vec3::new(diag.x, 0.0, 0.0),
44                Vec3::new(0.0, diag.y, 0.0),
45                Vec3::new(0.0, 0.0, diag.z),
46            ],
47        }
48    }
49
50    /// Creates a 2D scaling matrix.
51    #[inline]
52    #[must_use]
53    pub fn from_scale(scale: super::Vec2) -> Self {
54        Self::from_diagonal(Vec3::new(scale.x, scale.y, 1.0))
55    }
56
57    /// Creates a 2D translation matrix.
58    #[inline]
59    #[must_use]
60    pub fn from_translation(translation: super::Vec2) -> Self {
61        Self {
62            cols: [
63                Vec3::X,
64                Vec3::Y,
65                Vec3::new(translation.x, translation.y, 1.0),
66            ],
67        }
68    }
69
70    /// Creates a 2D rotation matrix.
71    #[inline]
72    #[must_use]
73    pub fn from_rotation(angle: Float) -> Self {
74        let (s, c) = angle.sin_cos();
75        Self {
76            cols: [Vec3::new(c, s, 0.0), Vec3::new(-s, c, 0.0), Vec3::Z],
77        }
78    }
79
80    /// Returns the transpose of this matrix.
81    #[inline]
82    #[must_use]
83    pub fn transpose(self) -> Self {
84        let c = self.cols;
85        Self {
86            cols: [
87                Vec3::new(c[0].x, c[1].x, c[2].x),
88                Vec3::new(c[0].y, c[1].y, c[2].y),
89                Vec3::new(c[0].z, c[1].z, c[2].z),
90            ],
91        }
92    }
93
94    /// Computes the determinant.
95    #[inline]
96    #[must_use]
97    pub fn determinant(self) -> Float {
98        let c = self.cols;
99        c[0].x * (c[1].y * c[2].z - c[2].y * c[1].z) - c[1].x * (c[0].y * c[2].z - c[2].y * c[0].z)
100            + c[2].x * (c[0].y * c[1].z - c[1].y * c[0].z)
101    }
102
103    /// Returns the inverse, or `None` if not invertible.
104    #[must_use]
105    pub fn try_inverse(self) -> Option<Self> {
106        let det = self.determinant();
107        if det.abs() < EPSILON {
108            return None;
109        }
110
111        let c = self.cols;
112        let inv_det = 1.0 / det;
113
114        Some(Self {
115            cols: [
116                Vec3::new(
117                    (c[1].y * c[2].z - c[2].y * c[1].z) * inv_det,
118                    (c[2].y * c[0].z - c[0].y * c[2].z) * inv_det,
119                    (c[0].y * c[1].z - c[1].y * c[0].z) * inv_det,
120                ),
121                Vec3::new(
122                    (c[2].x * c[1].z - c[1].x * c[2].z) * inv_det,
123                    (c[0].x * c[2].z - c[2].x * c[0].z) * inv_det,
124                    (c[1].x * c[0].z - c[0].x * c[1].z) * inv_det,
125                ),
126                Vec3::new(
127                    (c[1].x * c[2].y - c[2].x * c[1].y) * inv_det,
128                    (c[2].x * c[0].y - c[0].x * c[2].y) * inv_det,
129                    (c[0].x * c[1].y - c[1].x * c[0].y) * inv_det,
130                ),
131            ],
132        })
133    }
134
135    /// Returns the inverse.
136    ///
137    /// # Panics
138    ///
139    /// Panics if the matrix is not invertible.
140    #[must_use]
141    pub fn inverse(self) -> Self {
142        self.try_inverse().expect("Matrix is not invertible")
143    }
144
145    /// Transforms a Vec3 by this matrix.
146    #[inline]
147    #[must_use]
148    pub fn transform(self, v: Vec3) -> Vec3 {
149        self.cols[0] * v.x + self.cols[1] * v.y + self.cols[2] * v.z
150    }
151
152    /// Transforms a 2D point (Vec2 with z=1).
153    #[inline]
154    #[must_use]
155    pub fn transform_point2(self, p: super::Vec2) -> super::Vec2 {
156        let v = self.transform(Vec3::from_vec2(p, 1.0));
157        super::Vec2::new(v.x, v.y)
158    }
159
160    /// Transforms a 2D vector (Vec2 with z=0).
161    #[inline]
162    #[must_use]
163    pub fn transform_vector2(self, v: super::Vec2) -> super::Vec2 {
164        let r = self.transform(Vec3::from_vec2(v, 0.0));
165        super::Vec2::new(r.x, r.y)
166    }
167
168    /// Checks if this matrix is approximately equal to another.
169    #[inline]
170    #[must_use]
171    pub fn approx_eq(self, other: Self) -> bool {
172        self.cols[0].approx_eq(other.cols[0])
173            && self.cols[1].approx_eq(other.cols[1])
174            && self.cols[2].approx_eq(other.cols[2])
175    }
176}
177
178impl std::ops::Mul for Mat3 {
179    type Output = Self;
180
181    #[inline]
182    fn mul(self, other: Self) -> Self {
183        Self {
184            cols: [
185                self.transform(other.cols[0]),
186                self.transform(other.cols[1]),
187                self.transform(other.cols[2]),
188            ],
189        }
190    }
191}
192
193impl std::ops::Mul<Vec3> for Mat3 {
194    type Output = Vec3;
195
196    #[inline]
197    fn mul(self, v: Vec3) -> Vec3 {
198        self.transform(v)
199    }
200}
201
/// Unit tests for [`Mat3`] constructors, algebra and 2D transform helpers.
#[cfg(test)]
mod tests {
    use super::super::{PI, Vec2, approx_eq};
    use super::*;

    #[test]
    fn test_identity() {
        let m = Mat3::IDENTITY;
        assert!(m.approx_eq(Mat3::default()));
        assert!(approx_eq(m.determinant(), 1.0));
    }

    #[test]
    fn test_from_scale() {
        let m = Mat3::from_scale(Vec2::new(2.0, 3.0));
        let p = m.transform_point2(Vec2::ONE);
        assert!(p.approx_eq(Vec2::new(2.0, 3.0)));
    }

    #[test]
    fn test_from_translation() {
        let m = Mat3::from_translation(Vec2::new(5.0, 10.0));
        let p = m.transform_point2(Vec2::ZERO);
        assert!(p.approx_eq(Vec2::new(5.0, 10.0)));
    }

    #[test]
    fn test_from_rotation() {
        let m = Mat3::from_rotation(PI / 2.0);
        let v = m.transform_vector2(Vec2::X);
        assert!(v.approx_eq(Vec2::Y));
    }

    #[test]
    fn test_transpose() {
        let m = Mat3::from_cols(
            Vec3::new(1.0, 4.0, 7.0),
            Vec3::new(2.0, 5.0, 8.0),
            Vec3::new(3.0, 6.0, 9.0),
        );
        let t = m.transpose();

        // Check every column, not just the first, so a mistake in any row is caught.
        let expected = Mat3::from_cols(
            Vec3::new(1.0, 2.0, 3.0),
            Vec3::new(4.0, 5.0, 6.0),
            Vec3::new(7.0, 8.0, 9.0),
        );
        assert!(t.approx_eq(expected));

        // Transposing twice must give back the original matrix.
        assert!(t.transpose().approx_eq(m));
    }

    #[test]
    fn test_determinant() {
        let m = Mat3::from_diagonal(Vec3::new(2.0, 3.0, 4.0));
        assert!(approx_eq(m.determinant(), 24.0));
    }

    #[test]
    fn test_determinant_general() {
        // A diagonal matrix never exercises the off-diagonal cofactor terms; this one does.
        // 1*(5*10 - 6*8) - 2*(4*10 - 6*7) + 3*(4*8 - 5*7) = 2 + 4 - 9 = -3.
        let m = Mat3::from_cols(
            Vec3::new(1.0, 4.0, 7.0),
            Vec3::new(2.0, 5.0, 8.0),
            Vec3::new(3.0, 6.0, 10.0),
        );
        assert!(approx_eq(m.determinant(), -3.0));
    }

    #[test]
    fn test_inverse() {
        let m = Mat3::from_translation(Vec2::new(1.0, 2.0));
        let inv = m.inverse();
        let result = m * inv;
        assert!(result.approx_eq(Mat3::IDENTITY));
    }

    #[test]
    fn test_inverse_rotation() {
        let m = Mat3::from_rotation(0.5);
        let inv = m.inverse();
        let result = m * inv;
        assert!(result.approx_eq(Mat3::IDENTITY));
    }

    #[test]
    fn test_try_inverse_singular() {
        let m = Mat3::ZERO;
        assert!(m.try_inverse().is_none());
    }

    #[test]
    #[should_panic(expected = "Matrix is not invertible")]
    fn test_inverse_singular_panics() {
        // `inverse` documents a panic for non-invertible input; pin that contract.
        let _ = Mat3::ZERO.inverse();
    }

    #[test]
    fn test_mul_identity() {
        let m = Mat3::from_translation(Vec2::new(1.0, 2.0));
        let result = m * Mat3::IDENTITY;
        assert!(result.approx_eq(m));
    }

    #[test]
    fn test_transform_vector_not_affected_by_translation() {
        let m = Mat3::from_translation(Vec2::new(100.0, 100.0));
        let v = m.transform_vector2(Vec2::X);
        assert!(v.approx_eq(Vec2::X));
    }
}