1use core::ops::{Index, IndexMut, Mul};
2
3use crate::{EPSILON, Quat, Transform, Vec3, Vec4, tan};
4
5#[derive(Clone, Copy, Debug, PartialEq)]
7#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
8pub struct Mat3 {
9 pub cols: [Vec3; 3],
11}
12
13#[derive(Clone, Copy, Debug, PartialEq)]
15#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
16pub struct Mat4 {
17 pub cols: [Vec4; 4],
19}
20
21impl Mat3 {
22 pub const IDENTITY: Self = Self::from_cols(Vec3::X, Vec3::Y, Vec3::Z);
24
25 #[inline]
27 pub const fn from_cols(x: Vec3, y: Vec3, z: Vec3) -> Self {
28 Self { cols: [x, y, z] }
29 }
30
31 #[inline]
33 pub fn from_mat4(matrix: Mat4) -> Self {
34 Self::from_cols(
35 matrix.cols[0].truncate(),
36 matrix.cols[1].truncate(),
37 matrix.cols[2].truncate(),
38 )
39 }
40
41 #[inline]
43 pub fn get(self, row: usize, col: usize) -> f32 {
44 self.cols[col][row]
45 }
46
47 #[inline]
49 pub fn determinant(self) -> f32 {
50 let a = self.get(0, 0);
51 let b = self.get(0, 1);
52 let c = self.get(0, 2);
53 let d = self.get(1, 0);
54 let e = self.get(1, 1);
55 let f = self.get(1, 2);
56 let g = self.get(2, 0);
57 let h = self.get(2, 1);
58 let i = self.get(2, 2);
59
60 a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
61 }
62
63 #[allow(clippy::needless_range_loop)]
65 pub fn inverse(self) -> Option<Self> {
66 let det = self.determinant();
67 if det.abs() <= EPSILON {
68 return None;
69 }
70 let inv_det = 1.0 / det;
71 let m = |r, c| self.get(r, c);
72 Some(Self::from_cols(
73 Vec3::new(
74 (m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1)) * inv_det,
75 (m(1, 2) * m(2, 0) - m(1, 0) * m(2, 2)) * inv_det,
76 (m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0)) * inv_det,
77 ),
78 Vec3::new(
79 (m(0, 2) * m(2, 1) - m(0, 1) * m(2, 2)) * inv_det,
80 (m(0, 0) * m(2, 2) - m(0, 2) * m(2, 0)) * inv_det,
81 (m(0, 1) * m(2, 0) - m(0, 0) * m(2, 1)) * inv_det,
82 ),
83 Vec3::new(
84 (m(0, 1) * m(1, 2) - m(0, 2) * m(1, 1)) * inv_det,
85 (m(0, 2) * m(1, 0) - m(0, 0) * m(1, 2)) * inv_det,
86 (m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0)) * inv_det,
87 ),
88 ))
89 }
90
91 #[inline]
93 pub fn transpose(self) -> Self {
94 Self::from_cols(
95 Vec3::new(self.get(0, 0), self.get(0, 1), self.get(0, 2)),
96 Vec3::new(self.get(1, 0), self.get(1, 1), self.get(1, 2)),
97 Vec3::new(self.get(2, 0), self.get(2, 1), self.get(2, 2)),
98 )
99 }
100
101 #[inline]
103 pub fn mul_mat3(self, rhs: Self) -> Self {
104 Self::from_cols(
105 self.mul_vec3(rhs.cols[0]),
106 self.mul_vec3(rhs.cols[1]),
107 self.mul_vec3(rhs.cols[2]),
108 )
109 }
110
111 #[inline]
113 pub fn mul_vec3(self, rhs: Vec3) -> Vec3 {
114 self.cols[0] * rhs.x + self.cols[1] * rhs.y + self.cols[2] * rhs.z
115 }
116
117 #[inline]
119 pub fn to_cols_array(self) -> [f32; 9] {
120 [
121 self.cols[0].x,
122 self.cols[0].y,
123 self.cols[0].z,
124 self.cols[1].x,
125 self.cols[1].y,
126 self.cols[1].z,
127 self.cols[2].x,
128 self.cols[2].y,
129 self.cols[2].z,
130 ]
131 }
132}
133
134impl Mat4 {
135 pub const IDENTITY: Self = Self::from_cols(Vec4::X, Vec4::Y, Vec4::Z, Vec4::W);
137
138 #[inline]
140 pub const fn from_cols(x: Vec4, y: Vec4, z: Vec4, w: Vec4) -> Self {
141 Self { cols: [x, y, z, w] }
142 }
143
144 #[inline]
146 pub const fn from_cols_array(values: [f32; 16]) -> Self {
147 Self::from_cols(
148 Vec4::new(values[0], values[1], values[2], values[3]),
149 Vec4::new(values[4], values[5], values[6], values[7]),
150 Vec4::new(values[8], values[9], values[10], values[11]),
151 Vec4::new(values[12], values[13], values[14], values[15]),
152 )
153 }
154
155 #[inline]
157 pub fn get(self, row: usize, col: usize) -> f32 {
158 self.cols[col][row]
159 }
160
161 pub fn perspective(fov_y_rad: f32, aspect: f32, near: f32, far: f32) -> Self {
163 if aspect.abs() <= EPSILON || near <= 0.0 || far <= near {
164 return Self::IDENTITY;
165 }
166
167 let f = 1.0 / tan(fov_y_rad * 0.5);
168 Self::from_cols(
169 Vec4::new(f / aspect, 0.0, 0.0, 0.0),
170 Vec4::new(0.0, f, 0.0, 0.0),
171 Vec4::new(0.0, 0.0, far / (near - far), -1.0),
172 Vec4::new(0.0, 0.0, (near * far) / (near - far), 0.0),
173 )
174 }
175
176 pub fn orthographic(left: f32, right: f32, bottom: f32, top: f32, near: f32, far: f32) -> Self {
178 let width = right - left;
179 let height = top - bottom;
180 let depth = near - far;
181 if width.abs() <= EPSILON || height.abs() <= EPSILON || depth.abs() <= EPSILON {
182 return Self::IDENTITY;
183 }
184
185 Self::from_cols(
186 Vec4::new(2.0 / width, 0.0, 0.0, 0.0),
187 Vec4::new(0.0, 2.0 / height, 0.0, 0.0),
188 Vec4::new(0.0, 0.0, 1.0 / depth, 0.0),
189 Vec4::new(
190 -(right + left) / width,
191 -(top + bottom) / height,
192 near / depth,
193 1.0,
194 ),
195 )
196 }
197
198 pub fn look_at(eye: Vec3, target: Vec3, up: Vec3) -> Self {
200 let forward = (target - eye).normalize();
201 if forward.length_squared() <= EPSILON {
202 return Self::from_translation(-eye);
203 }
204
205 let right = forward.cross(up).normalize();
206 let up = right.cross(forward).normalize();
207
208 Self::from_cols(
209 Vec4::new(right.x, up.x, -forward.x, 0.0),
210 Vec4::new(right.y, up.y, -forward.y, 0.0),
211 Vec4::new(right.z, up.z, -forward.z, 0.0),
212 Vec4::new(-right.dot(eye), -up.dot(eye), forward.dot(eye), 1.0),
213 )
214 }
215
216 #[inline]
218 pub fn from_translation(value: Vec3) -> Self {
219 Self::from_cols(
220 Vec4::X,
221 Vec4::Y,
222 Vec4::Z,
223 Vec4::new(value.x, value.y, value.z, 1.0),
224 )
225 }
226
227 #[inline]
229 pub fn from_rotation(rotation: Quat) -> Self {
230 rotation.to_mat4()
231 }
232
233 #[inline]
235 pub fn from_scale(value: Vec3) -> Self {
236 Self::from_cols(
237 Vec4::new(value.x, 0.0, 0.0, 0.0),
238 Vec4::new(0.0, value.y, 0.0, 0.0),
239 Vec4::new(0.0, 0.0, value.z, 0.0),
240 Vec4::W,
241 )
242 }
243
244 #[inline]
246 pub fn from_trs(translation: Vec3, rotation: Quat, scale: Vec3) -> Self {
247 Self::from_translation(translation)
248 .mul_mat4(Self::from_rotation(rotation))
249 .mul_mat4(Self::from_scale(scale))
250 }
251
252 #[inline]
254 pub fn mul_mat4(self, rhs: Self) -> Self {
255 Self::from_cols(
256 self.mul_vec4(rhs.cols[0]),
257 self.mul_vec4(rhs.cols[1]),
258 self.mul_vec4(rhs.cols[2]),
259 self.mul_vec4(rhs.cols[3]),
260 )
261 }
262
263 #[inline]
265 pub fn mul_vec4(self, rhs: Vec4) -> Vec4 {
266 self.cols[0] * rhs.x + self.cols[1] * rhs.y + self.cols[2] * rhs.z + self.cols[3] * rhs.w
267 }
268
269 #[inline]
271 pub fn mul_vec3(self, rhs: Vec3) -> Vec3 {
272 let out = self.mul_vec4(Vec4::new(rhs.x, rhs.y, rhs.z, 1.0));
273 if out.w.abs() <= EPSILON {
274 out.truncate()
275 } else {
276 out.truncate() / out.w
277 }
278 }
279
280 #[allow(clippy::needless_range_loop)]
282 pub fn inverse(self) -> Option<Self> {
283 let mut aug = [[0.0_f32; 8]; 4];
284 for row in 0..4 {
285 for col in 0..4 {
286 aug[row][col] = self.get(row, col);
287 }
288 aug[row][row + 4] = 1.0;
289 }
290
291 for col in 0..4 {
292 let mut pivot = col;
293 let mut pivot_abs = aug[pivot][col].abs();
294 for (row, values) in aug.iter().enumerate().skip(col + 1) {
295 let value_abs = values[col].abs();
296 if value_abs > pivot_abs {
297 pivot = row;
298 pivot_abs = value_abs;
299 }
300 }
301 if pivot_abs <= EPSILON {
302 return None;
303 }
304 if pivot != col {
305 aug.swap(pivot, col);
306 }
307
308 let inv_pivot = 1.0 / aug[col][col];
309 for value in &mut aug[col] {
310 *value *= inv_pivot;
311 }
312
313 for row in 0..4 {
314 if row == col {
315 continue;
316 }
317 let factor = aug[row][col];
318 if factor.abs() <= EPSILON {
319 continue;
320 }
321 for i in 0..8 {
322 aug[row][i] -= factor * aug[col][i];
323 }
324 }
325 }
326
327 Some(Self::from_cols(
328 Vec4::new(aug[0][4], aug[1][4], aug[2][4], aug[3][4]),
329 Vec4::new(aug[0][5], aug[1][5], aug[2][5], aug[3][5]),
330 Vec4::new(aug[0][6], aug[1][6], aug[2][6], aug[3][6]),
331 Vec4::new(aug[0][7], aug[1][7], aug[2][7], aug[3][7]),
332 ))
333 }
334
335 #[inline]
337 pub fn transpose(self) -> Self {
338 Self::from_cols(
339 Vec4::new(
340 self.get(0, 0),
341 self.get(0, 1),
342 self.get(0, 2),
343 self.get(0, 3),
344 ),
345 Vec4::new(
346 self.get(1, 0),
347 self.get(1, 1),
348 self.get(1, 2),
349 self.get(1, 3),
350 ),
351 Vec4::new(
352 self.get(2, 0),
353 self.get(2, 1),
354 self.get(2, 2),
355 self.get(2, 3),
356 ),
357 Vec4::new(
358 self.get(3, 0),
359 self.get(3, 1),
360 self.get(3, 2),
361 self.get(3, 3),
362 ),
363 )
364 }
365
366 pub fn decompose(self) -> Option<Transform> {
368 let translation = self.cols[3].truncate();
369 let scale = Vec3::new(
370 self.cols[0].truncate().length(),
371 self.cols[1].truncate().length(),
372 self.cols[2].truncate().length(),
373 );
374 if scale.x <= EPSILON || scale.y <= EPSILON || scale.z <= EPSILON {
375 return None;
376 }
377
378 let inv_scale = Vec3::new(1.0 / scale.x, 1.0 / scale.y, 1.0 / scale.z);
379 let rotation_matrix = Self::from_cols(
380 Vec4::new(
381 self.cols[0].x * inv_scale.x,
382 self.cols[0].y * inv_scale.x,
383 self.cols[0].z * inv_scale.x,
384 0.0,
385 ),
386 Vec4::new(
387 self.cols[1].x * inv_scale.y,
388 self.cols[1].y * inv_scale.y,
389 self.cols[1].z * inv_scale.y,
390 0.0,
391 ),
392 Vec4::new(
393 self.cols[2].x * inv_scale.z,
394 self.cols[2].y * inv_scale.z,
395 self.cols[2].z * inv_scale.z,
396 0.0,
397 ),
398 Vec4::W,
399 );
400
401 Some(Transform::new(
402 translation,
403 Quat::from_mat4(rotation_matrix),
404 scale,
405 ))
406 }
407
408 #[inline]
410 pub fn to_cols_array(self) -> [f32; 16] {
411 [
412 self.cols[0].x,
413 self.cols[0].y,
414 self.cols[0].z,
415 self.cols[0].w,
416 self.cols[1].x,
417 self.cols[1].y,
418 self.cols[1].z,
419 self.cols[1].w,
420 self.cols[2].x,
421 self.cols[2].y,
422 self.cols[2].z,
423 self.cols[2].w,
424 self.cols[3].x,
425 self.cols[3].y,
426 self.cols[3].z,
427 self.cols[3].w,
428 ]
429 }
430}
431
432impl Default for Mat3 {
433 #[inline]
434 fn default() -> Self {
435 Self::IDENTITY
436 }
437}
438
439impl Default for Mat4 {
440 #[inline]
441 fn default() -> Self {
442 Self::IDENTITY
443 }
444}
445
446impl Index<usize> for Mat3 {
447 type Output = Vec3;
448
449 #[inline]
450 fn index(&self, index: usize) -> &Self::Output {
451 &self.cols[index]
452 }
453}
454
455impl IndexMut<usize> for Mat3 {
456 #[inline]
457 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
458 &mut self.cols[index]
459 }
460}
461
462impl Index<usize> for Mat4 {
463 type Output = Vec4;
464
465 #[inline]
466 fn index(&self, index: usize) -> &Self::Output {
467 &self.cols[index]
468 }
469}
470
471impl IndexMut<usize> for Mat4 {
472 #[inline]
473 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
474 &mut self.cols[index]
475 }
476}
477
478impl Mul for Mat3 {
479 type Output = Self;
480
481 #[inline]
482 fn mul(self, rhs: Self) -> Self::Output {
483 self.mul_mat3(rhs)
484 }
485}
486
487impl Mul<Vec3> for Mat3 {
488 type Output = Vec3;
489
490 #[inline]
491 fn mul(self, rhs: Vec3) -> Self::Output {
492 self.mul_vec3(rhs)
493 }
494}
495
496impl Mul for Mat4 {
497 type Output = Self;
498
499 #[inline]
500 fn mul(self, rhs: Self) -> Self::Output {
501 self.mul_mat4(rhs)
502 }
503}
504
505impl Mul<Vec4> for Mat4 {
506 type Output = Vec4;
507
508 #[inline]
509 fn mul(self, rhs: Vec4) -> Self::Output {
510 self.mul_vec4(rhs)
511 }
512}
513
514impl Mul<Vec3> for Mat4 {
515 type Output = Vec3;
516
517 #[inline]
518 fn mul(self, rhs: Vec3) -> Self::Output {
519 self.mul_vec3(rhs)
520 }
521}
522
/// Implements the `approx` comparison traits for a matrix type by comparing
/// its `cols` column-by-column with the column type's own `approx` impls.
///
/// The `f32` defaults are reached through fully qualified
/// `<f32 as approx::Trait>::…` paths. The `approx` traits are not `use`d in
/// this module, and a bare `f32::default_epsilon()` only resolves when the
/// providing trait is in scope.
#[cfg(feature = "approx")]
macro_rules! impl_matrix_approx {
    ($type:ident) => {
        impl approx::AbsDiffEq for $type {
            type Epsilon = f32;

            #[inline]
            fn default_epsilon() -> Self::Epsilon {
                <f32 as approx::AbsDiffEq>::default_epsilon()
            }

            #[inline]
            fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
                self.cols
                    .iter()
                    .zip(other.cols.iter())
                    .all(|(a, b)| approx::AbsDiffEq::abs_diff_eq(a, b, epsilon))
            }
        }

        impl approx::RelativeEq for $type {
            #[inline]
            fn default_max_relative() -> Self::Epsilon {
                <f32 as approx::RelativeEq>::default_max_relative()
            }

            #[inline]
            fn relative_eq(
                &self,
                other: &Self,
                epsilon: Self::Epsilon,
                max_relative: Self::Epsilon,
            ) -> bool {
                self.cols
                    .iter()
                    .zip(other.cols.iter())
                    .all(|(a, b)| approx::RelativeEq::relative_eq(a, b, epsilon, max_relative))
            }
        }

        impl approx::UlpsEq for $type {
            #[inline]
            fn default_max_ulps() -> u32 {
                <f32 as approx::UlpsEq>::default_max_ulps()
            }

            #[inline]
            fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
                self.cols
                    .iter()
                    .zip(other.cols.iter())
                    .all(|(a, b)| approx::UlpsEq::ulps_eq(a, b, epsilon, max_ulps))
            }
        }
    };
}

#[cfg(feature = "approx")]
impl_matrix_approx!(Mat3);
#[cfg(feature = "approx")]
impl_matrix_approx!(Mat4);
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_close;

    #[test]
    fn perspective_maps_near_and_far_to_webgpu_depth() {
        let projection = Mat4::perspective(core::f32::consts::FRAC_PI_2, 1.0, 0.1, 10.0);
        // Depth after the perspective divide for a point on the view axis.
        let ndc_depth = |view_z: f32| {
            let clip = projection.mul_vec4(Vec4::new(0.0, 0.0, view_z, 1.0));
            clip.z / clip.w
        };

        assert_close(ndc_depth(-0.1), 0.0);
        assert_close(ndc_depth(-10.0), 1.0);
    }

    #[test]
    fn orthographic_maps_center_and_depth() {
        let projection = Mat4::orthographic(-1.0, 1.0, -1.0, 1.0, 0.1, 10.0);
        let clip = projection.mul_vec4(Vec4::new(0.0, 0.0, -0.1, 1.0));
        for component in [clip.x, clip.y, clip.z] {
            assert_close(component, 0.0);
        }
    }

    #[test]
    fn inverse_multiplies_to_identity() {
        let translation = Vec3::new(2.0, 3.0, 4.0);
        let rotation = Quat::from_axis_angle(Vec3::Y, 0.7);
        let scale = Vec3::new(2.0, 3.0, 4.0);
        let matrix = Mat4::from_trs(translation, rotation, scale);

        let product = matrix * matrix.inverse().unwrap();
        let actual = product.to_cols_array();
        let expected = Mat4::IDENTITY.to_cols_array();
        actual
            .iter()
            .zip(expected.iter())
            .for_each(|(&a, &b)| assert_close(a, b));
    }

    #[test]
    fn transpose_and_column_major_array_work() {
        let values: [f32; 16] = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ];
        let matrix = Mat4::from_cols_array(values);

        assert_eq!(matrix.to_cols_array()[1], 2.0);
        assert_eq!(matrix.transpose().get(0, 1), 2.0);
    }
}