1use super::{Float, Vec3, Vec4};
4
5#[derive(Debug, Clone, Copy, PartialEq)]
10#[repr(C)]
11pub struct Mat4 {
12 pub cols: [Vec4; 4],
14}
15
16impl Default for Mat4 {
17 fn default() -> Self {
18 Self::IDENTITY
19 }
20}
21
22impl Mat4 {
23 pub const ZERO: Self = Self {
25 cols: [Vec4::ZERO, Vec4::ZERO, Vec4::ZERO, Vec4::ZERO],
26 };
27
28 pub const IDENTITY: Self = Self {
30 cols: [Vec4::X, Vec4::Y, Vec4::Z, Vec4::W],
31 };
32
33 #[inline]
35 #[must_use]
36 pub const fn from_cols(c0: Vec4, c1: Vec4, c2: Vec4, c3: Vec4) -> Self {
37 Self {
38 cols: [c0, c1, c2, c3],
39 }
40 }
41
42 #[inline]
44 #[must_use]
45 pub fn from_rows(rows: [[Float; 4]; 4]) -> Self {
46 Self {
47 cols: [
48 Vec4::new(rows[0][0], rows[1][0], rows[2][0], rows[3][0]),
49 Vec4::new(rows[0][1], rows[1][1], rows[2][1], rows[3][1]),
50 Vec4::new(rows[0][2], rows[1][2], rows[2][2], rows[3][2]),
51 Vec4::new(rows[0][3], rows[1][3], rows[2][3], rows[3][3]),
52 ],
53 }
54 }
55
56 #[inline]
58 #[must_use]
59 pub fn from_cols_array(arr: [Float; 16]) -> Self {
60 Self {
61 cols: [
62 Vec4::from_array([arr[0], arr[1], arr[2], arr[3]]),
63 Vec4::from_array([arr[4], arr[5], arr[6], arr[7]]),
64 Vec4::from_array([arr[8], arr[9], arr[10], arr[11]]),
65 Vec4::from_array([arr[12], arr[13], arr[14], arr[15]]),
66 ],
67 }
68 }
69
70 #[inline]
72 #[must_use]
73 pub fn to_cols_array(self) -> [Float; 16] {
74 let c = self.cols;
75 [
76 c[0].x, c[0].y, c[0].z, c[0].w, c[1].x, c[1].y, c[1].z, c[1].w, c[2].x, c[2].y, c[2].z,
77 c[2].w, c[3].x, c[3].y, c[3].z, c[3].w,
78 ]
79 }
80
81 #[inline]
83 #[must_use]
84 pub fn from_diagonal(diag: Vec4) -> Self {
85 Self {
86 cols: [
87 Vec4::new(diag.x, 0.0, 0.0, 0.0),
88 Vec4::new(0.0, diag.y, 0.0, 0.0),
89 Vec4::new(0.0, 0.0, diag.z, 0.0),
90 Vec4::new(0.0, 0.0, 0.0, diag.w),
91 ],
92 }
93 }
94
95 #[inline]
97 #[must_use]
98 pub fn from_scale(scale: Vec3) -> Self {
99 Self::from_diagonal(Vec4::new(scale.x, scale.y, scale.z, 1.0))
100 }
101
102 #[inline]
104 #[must_use]
105 pub fn from_translation(translation: Vec3) -> Self {
106 Self {
107 cols: [
108 Vec4::X,
109 Vec4::Y,
110 Vec4::Z,
111 Vec4::new(translation.x, translation.y, translation.z, 1.0),
112 ],
113 }
114 }
115
116 #[inline]
118 #[must_use]
119 pub fn from_rotation_x(angle: Float) -> Self {
120 let (s, c) = angle.sin_cos();
121 Self {
122 cols: [
123 Vec4::X,
124 Vec4::new(0.0, c, s, 0.0),
125 Vec4::new(0.0, -s, c, 0.0),
126 Vec4::W,
127 ],
128 }
129 }
130
131 #[inline]
133 #[must_use]
134 pub fn from_rotation_y(angle: Float) -> Self {
135 let (s, c) = angle.sin_cos();
136 Self {
137 cols: [
138 Vec4::new(c, 0.0, -s, 0.0),
139 Vec4::Y,
140 Vec4::new(s, 0.0, c, 0.0),
141 Vec4::W,
142 ],
143 }
144 }
145
146 #[inline]
148 #[must_use]
149 pub fn from_rotation_z(angle: Float) -> Self {
150 let (s, c) = angle.sin_cos();
151 Self {
152 cols: [
153 Vec4::new(c, s, 0.0, 0.0),
154 Vec4::new(-s, c, 0.0, 0.0),
155 Vec4::Z,
156 Vec4::W,
157 ],
158 }
159 }
160
161 #[inline]
163 #[must_use]
164 pub fn from_axis_angle(axis: Vec3, angle: Float) -> Self {
165 let axis = axis.normalize();
166 let (s, c) = angle.sin_cos();
167 let t = 1.0 - c;
168 let x = axis.x;
169 let y = axis.y;
170 let z = axis.z;
171
172 Self {
173 cols: [
174 Vec4::new(t * x * x + c, t * x * y + s * z, t * x * z - s * y, 0.0),
175 Vec4::new(t * x * y - s * z, t * y * y + c, t * y * z + s * x, 0.0),
176 Vec4::new(t * x * z + s * y, t * y * z - s * x, t * z * z + c, 0.0),
177 Vec4::W,
178 ],
179 }
180 }
181
182 #[inline]
184 #[must_use]
185 pub fn look_at_rh(eye: Vec3, target: Vec3, up: Vec3) -> Self {
186 let f = (target - eye).normalize();
187 let r = f.cross(up).normalize();
188 let u = r.cross(f);
189
190 Self {
191 cols: [
192 Vec4::new(r.x, u.x, -f.x, 0.0),
193 Vec4::new(r.y, u.y, -f.y, 0.0),
194 Vec4::new(r.z, u.z, -f.z, 0.0),
195 Vec4::new(-r.dot(eye), -u.dot(eye), f.dot(eye), 1.0),
196 ],
197 }
198 }
199
200 #[inline]
202 #[must_use]
203 pub fn look_at_lh(eye: Vec3, target: Vec3, up: Vec3) -> Self {
204 let f = (target - eye).normalize();
205 let r = up.cross(f).normalize();
206 let u = f.cross(r);
207
208 Self {
209 cols: [
210 Vec4::new(r.x, u.x, f.x, 0.0),
211 Vec4::new(r.y, u.y, f.y, 0.0),
212 Vec4::new(r.z, u.z, f.z, 0.0),
213 Vec4::new(-r.dot(eye), -u.dot(eye), -f.dot(eye), 1.0),
214 ],
215 }
216 }
217
218 #[inline]
225 #[must_use]
226 pub fn perspective_rh(fov_y: Float, aspect: Float, near: Float, far: Float) -> Self {
227 let f = 1.0 / (fov_y / 2.0).tan();
228 let nf = 1.0 / (near - far);
229
230 Self {
231 cols: [
232 Vec4::new(f / aspect, 0.0, 0.0, 0.0),
233 Vec4::new(0.0, f, 0.0, 0.0),
234 Vec4::new(0.0, 0.0, far * nf, -1.0),
235 Vec4::new(0.0, 0.0, near * far * nf, 0.0),
236 ],
237 }
238 }
239
240 #[inline]
242 #[must_use]
243 pub fn orthographic_rh(
244 left: Float,
245 right: Float,
246 bottom: Float,
247 top: Float,
248 near: Float,
249 far: Float,
250 ) -> Self {
251 let rml = right - left;
252 let tmb = top - bottom;
253 let fmn = far - near;
254
255 Self {
256 cols: [
257 Vec4::new(2.0 / rml, 0.0, 0.0, 0.0),
258 Vec4::new(0.0, 2.0 / tmb, 0.0, 0.0),
259 Vec4::new(0.0, 0.0, -1.0 / fmn, 0.0),
260 Vec4::new(
261 -(right + left) / rml,
262 -(top + bottom) / tmb,
263 -near / fmn,
264 1.0,
265 ),
266 ],
267 }
268 }
269
270 #[inline]
272 #[must_use]
273 pub fn transpose(self) -> Self {
274 let c = self.cols;
275 Self {
276 cols: [
277 Vec4::new(c[0].x, c[1].x, c[2].x, c[3].x),
278 Vec4::new(c[0].y, c[1].y, c[2].y, c[3].y),
279 Vec4::new(c[0].z, c[1].z, c[2].z, c[3].z),
280 Vec4::new(c[0].w, c[1].w, c[2].w, c[3].w),
281 ],
282 }
283 }
284
285 #[inline]
287 #[must_use]
288 pub fn determinant(self) -> Float {
289 let c = self.cols;
290
291 let a = c[2].z * c[3].w - c[3].z * c[2].w;
292 let b = c[2].y * c[3].w - c[3].y * c[2].w;
293 let cc = c[2].y * c[3].z - c[3].y * c[2].z;
294 let d = c[2].x * c[3].w - c[3].x * c[2].w;
295 let e = c[2].x * c[3].z - c[3].x * c[2].z;
296 let f = c[2].x * c[3].y - c[3].x * c[2].y;
297
298 c[0].x * (c[1].y * a - c[1].z * b + c[1].w * cc)
299 - c[0].y * (c[1].x * a - c[1].z * d + c[1].w * e)
300 + c[0].z * (c[1].x * b - c[1].y * d + c[1].w * f)
301 - c[0].w * (c[1].x * cc - c[1].y * e + c[1].z * f)
302 }
303
304 #[must_use]
306 pub fn try_inverse(self) -> Option<Self> {
307 let det = self.determinant();
308 if det.abs() < super::EPSILON {
309 return None;
310 }
311 Some(self.inverse_unchecked(det))
312 }
313
314 #[must_use]
320 pub fn inverse(self) -> Self {
321 self.try_inverse().expect("Matrix is not invertible")
322 }
323
324 fn inverse_unchecked(self, det: Float) -> Self {
325 let c = self.cols;
326 let inv_det = 1.0 / det;
327
328 let c00 = c[1].y * (c[2].z * c[3].w - c[3].z * c[2].w)
329 - c[2].y * (c[1].z * c[3].w - c[3].z * c[1].w)
330 + c[3].y * (c[1].z * c[2].w - c[2].z * c[1].w);
331 let c01 = -(c[0].y * (c[2].z * c[3].w - c[3].z * c[2].w)
332 - c[2].y * (c[0].z * c[3].w - c[3].z * c[0].w)
333 + c[3].y * (c[0].z * c[2].w - c[2].z * c[0].w));
334 let c02 = c[0].y * (c[1].z * c[3].w - c[3].z * c[1].w)
335 - c[1].y * (c[0].z * c[3].w - c[3].z * c[0].w)
336 + c[3].y * (c[0].z * c[1].w - c[1].z * c[0].w);
337 let c03 = -(c[0].y * (c[1].z * c[2].w - c[2].z * c[1].w)
338 - c[1].y * (c[0].z * c[2].w - c[2].z * c[0].w)
339 + c[2].y * (c[0].z * c[1].w - c[1].z * c[0].w));
340
341 let c10 = -(c[1].x * (c[2].z * c[3].w - c[3].z * c[2].w)
342 - c[2].x * (c[1].z * c[3].w - c[3].z * c[1].w)
343 + c[3].x * (c[1].z * c[2].w - c[2].z * c[1].w));
344 let c11 = c[0].x * (c[2].z * c[3].w - c[3].z * c[2].w)
345 - c[2].x * (c[0].z * c[3].w - c[3].z * c[0].w)
346 + c[3].x * (c[0].z * c[2].w - c[2].z * c[0].w);
347 let c12 = -(c[0].x * (c[1].z * c[3].w - c[3].z * c[1].w)
348 - c[1].x * (c[0].z * c[3].w - c[3].z * c[0].w)
349 + c[3].x * (c[0].z * c[1].w - c[1].z * c[0].w));
350 let c13 = c[0].x * (c[1].z * c[2].w - c[2].z * c[1].w)
351 - c[1].x * (c[0].z * c[2].w - c[2].z * c[0].w)
352 + c[2].x * (c[0].z * c[1].w - c[1].z * c[0].w);
353
354 let c20 = c[1].x * (c[2].y * c[3].w - c[3].y * c[2].w)
355 - c[2].x * (c[1].y * c[3].w - c[3].y * c[1].w)
356 + c[3].x * (c[1].y * c[2].w - c[2].y * c[1].w);
357 let c21 = -(c[0].x * (c[2].y * c[3].w - c[3].y * c[2].w)
358 - c[2].x * (c[0].y * c[3].w - c[3].y * c[0].w)
359 + c[3].x * (c[0].y * c[2].w - c[2].y * c[0].w));
360 let c22 = c[0].x * (c[1].y * c[3].w - c[3].y * c[1].w)
361 - c[1].x * (c[0].y * c[3].w - c[3].y * c[0].w)
362 + c[3].x * (c[0].y * c[1].w - c[1].y * c[0].w);
363 let c23 = -(c[0].x * (c[1].y * c[2].w - c[2].y * c[1].w)
364 - c[1].x * (c[0].y * c[2].w - c[2].y * c[0].w)
365 + c[2].x * (c[0].y * c[1].w - c[1].y * c[0].w));
366
367 let c30 = -(c[1].x * (c[2].y * c[3].z - c[3].y * c[2].z)
368 - c[2].x * (c[1].y * c[3].z - c[3].y * c[1].z)
369 + c[3].x * (c[1].y * c[2].z - c[2].y * c[1].z));
370 let c31 = c[0].x * (c[2].y * c[3].z - c[3].y * c[2].z)
371 - c[2].x * (c[0].y * c[3].z - c[3].y * c[0].z)
372 + c[3].x * (c[0].y * c[2].z - c[2].y * c[0].z);
373 let c32 = -(c[0].x * (c[1].y * c[3].z - c[3].y * c[1].z)
374 - c[1].x * (c[0].y * c[3].z - c[3].y * c[0].z)
375 + c[3].x * (c[0].y * c[1].z - c[1].y * c[0].z));
376 let c33 = c[0].x * (c[1].y * c[2].z - c[2].y * c[1].z)
377 - c[1].x * (c[0].y * c[2].z - c[2].y * c[0].z)
378 + c[2].x * (c[0].y * c[1].z - c[1].y * c[0].z);
379
380 Self {
381 cols: [
382 Vec4::new(c00 * inv_det, c01 * inv_det, c02 * inv_det, c03 * inv_det),
383 Vec4::new(c10 * inv_det, c11 * inv_det, c12 * inv_det, c13 * inv_det),
384 Vec4::new(c20 * inv_det, c21 * inv_det, c22 * inv_det, c23 * inv_det),
385 Vec4::new(c30 * inv_det, c31 * inv_det, c32 * inv_det, c33 * inv_det),
386 ],
387 }
388 }
389
390 #[inline]
392 #[must_use]
393 pub fn transform_vec4(self, v: Vec4) -> Vec4 {
394 let c = self.cols;
395 c[0] * v.x + c[1] * v.y + c[2] * v.z + c[3] * v.w
396 }
397
398 #[inline]
400 #[must_use]
401 pub fn transform_point(self, p: Vec3) -> Vec3 {
402 let v = self.transform_vec4(Vec4::from_vec3(p, 1.0));
403 Vec3::new(v.x, v.y, v.z)
404 }
405
406 #[inline]
408 #[must_use]
409 pub fn transform_vector(self, v: Vec3) -> Vec3 {
410 let r = self.transform_vec4(Vec4::from_vec3(v, 0.0));
411 Vec3::new(r.x, r.y, r.z)
412 }
413
414 #[inline]
416 #[must_use]
417 pub fn approx_eq(self, other: Self) -> bool {
418 self.cols[0].approx_eq(other.cols[0])
419 && self.cols[1].approx_eq(other.cols[1])
420 && self.cols[2].approx_eq(other.cols[2])
421 && self.cols[3].approx_eq(other.cols[3])
422 }
423}
424
425impl std::ops::Mul for Mat4 {
426 type Output = Self;
427
428 #[inline]
429 fn mul(self, other: Self) -> Self {
430 Self {
431 cols: [
432 self.transform_vec4(other.cols[0]),
433 self.transform_vec4(other.cols[1]),
434 self.transform_vec4(other.cols[2]),
435 self.transform_vec4(other.cols[3]),
436 ],
437 }
438 }
439}
440
441impl std::ops::Mul<Vec4> for Mat4 {
442 type Output = Vec4;
443
444 #[inline]
445 fn mul(self, v: Vec4) -> Vec4 {
446 self.transform_vec4(v)
447 }
448}
449
450impl std::ops::Mul<Float> for Mat4 {
451 type Output = Self;
452
453 #[inline]
454 fn mul(self, s: Float) -> Self {
455 Self {
456 cols: [
457 self.cols[0] * s,
458 self.cols[1] * s,
459 self.cols[2] * s,
460 self.cols[3] * s,
461 ],
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::super::{PI, approx_eq};
    use super::*;

    /// Integer matrix with determinant 1: the product of a lower and an upper
    /// unitriangular matrix. Its inverse is integer-valued as well, so every
    /// floating-point step in `determinant` and `inverse` is exact.
    fn unimodular() -> Mat4 {
        Mat4::from_rows([
            [1.0, 2.0, 0.0, 0.0],
            [1.0, 3.0, 3.0, 0.0],
            [0.0, 1.0, 4.0, 4.0],
            [0.0, 0.0, 1.0, 5.0],
        ])
    }

    #[test]
    fn test_identity() {
        let m = Mat4::IDENTITY;
        assert!(m.approx_eq(Mat4::default()));
        assert!(approx_eq(m.determinant(), 1.0));
    }

    #[test]
    fn test_from_rows_layout() {
        let m = Mat4::from_rows([
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ]);
        // Row 0 supplies the x component of every column...
        assert!(approx_eq(m.cols[0].x, 1.0));
        assert!(approx_eq(m.cols[3].x, 4.0));
        // ...and column 0 collects the first element of every row.
        assert!(approx_eq(m.cols[0].w, 13.0));
    }

    #[test]
    fn test_cols_array_roundtrip() {
        let arr: [Float; 16] = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ];
        let m = Mat4::from_cols_array(arr);
        assert!(approx_eq(m.cols[1].x, 5.0));
        assert!(approx_eq(m.cols[3].w, 16.0));
        assert_eq!(m.to_cols_array(), arr);
    }

    #[test]
    fn test_from_scale() {
        let m = Mat4::from_scale(Vec3::new(2.0, 3.0, 4.0));
        let p = m.transform_point(Vec3::ONE);
        assert!(p.approx_eq(Vec3::new(2.0, 3.0, 4.0)));
    }

    #[test]
    fn test_from_translation() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let p = m.transform_point(Vec3::ZERO);
        assert!(p.approx_eq(Vec3::new(1.0, 2.0, 3.0)));
    }

    #[test]
    fn test_from_rotation_x() {
        let m = Mat4::from_rotation_x(PI / 2.0);
        let v = m.transform_vector(Vec3::Y);
        assert!(v.approx_eq(Vec3::Z));
    }

    #[test]
    fn test_from_rotation_y() {
        let m = Mat4::from_rotation_y(PI / 2.0);
        let v = m.transform_vector(Vec3::Z);
        assert!(v.approx_eq(Vec3::X));
    }

    #[test]
    fn test_from_rotation_z() {
        let m = Mat4::from_rotation_z(PI / 2.0);
        let v = m.transform_vector(Vec3::X);
        assert!(v.approx_eq(Vec3::Y));
    }

    #[test]
    fn test_from_axis_angle() {
        let m = Mat4::from_axis_angle(Vec3::Z, PI / 2.0);
        let v = m.transform_vector(Vec3::X);
        assert!(v.approx_eq(Vec3::Y));
    }

    #[test]
    fn test_from_axis_angle_matches_axis_rotations() {
        let angle = 0.5;
        assert!(Mat4::from_axis_angle(Vec3::X, angle).approx_eq(Mat4::from_rotation_x(angle)));
        assert!(Mat4::from_axis_angle(Vec3::Y, angle).approx_eq(Mat4::from_rotation_y(angle)));
        assert!(Mat4::from_axis_angle(Vec3::Z, angle).approx_eq(Mat4::from_rotation_z(angle)));
        // The axis is normalized, so its length does not matter.
        let scaled_axis = Mat4::from_axis_angle(Vec3::new(0.0, 3.0, 0.0), angle);
        assert!(scaled_axis.approx_eq(Mat4::from_rotation_y(angle)));
    }

    #[test]
    fn test_transpose() {
        let m = Mat4::from_rows([
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ]);
        let t = m.transpose();
        assert!(approx_eq(t.cols[0].x, 1.0));
        assert!(approx_eq(t.cols[0].y, 2.0));
        assert!(approx_eq(t.cols[0].z, 3.0));
        assert!(approx_eq(t.cols[1].x, 5.0));

        // Check all 16 entries, not just a sample.
        let expected = Mat4::from_rows([
            [1.0, 5.0, 9.0, 13.0],
            [2.0, 6.0, 10.0, 14.0],
            [3.0, 7.0, 11.0, 15.0],
            [4.0, 8.0, 12.0, 16.0],
        ]);
        assert!(t.approx_eq(expected));
        assert!(t.transpose().approx_eq(m));
    }

    #[test]
    fn test_determinant() {
        let m = Mat4::from_scale(Vec3::new(2.0, 3.0, 4.0));
        assert!(approx_eq(m.determinant(), 24.0));
    }

    #[test]
    fn test_determinant_rotation_and_unimodular() {
        assert!(approx_eq(Mat4::from_rotation_z(0.7).determinant(), 1.0));
        assert!(approx_eq(unimodular().determinant(), 1.0));
    }

    #[test]
    fn test_inverse() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let inv = m.inverse();
        let result = m * inv;
        assert!(result.approx_eq(Mat4::IDENTITY));
    }

    #[test]
    fn test_inverse_scale() {
        let m = Mat4::from_scale(Vec3::new(2.0, 3.0, 4.0));
        let inv = m.inverse();
        let result = m * inv;
        assert!(result.approx_eq(Mat4::IDENTITY));
    }

    #[test]
    fn test_inverse_rotation() {
        let m = Mat4::from_rotation_y(0.5);
        let inv = m.inverse();
        let result = m * inv;
        assert!(result.approx_eq(Mat4::IDENTITY));
    }

    #[test]
    fn test_inverse_general() {
        // A dense matrix exercises every cofactor, unlike the sparse cases above.
        let m = unimodular();
        let inv = m.inverse();
        assert!((m * inv).approx_eq(Mat4::IDENTITY));
        assert!((inv * m).approx_eq(Mat4::IDENTITY));
    }

    #[test]
    fn test_try_inverse_singular() {
        let m = Mat4::ZERO;
        assert!(m.try_inverse().is_none());
    }

    #[test]
    fn test_mul_identity() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let result = m * Mat4::IDENTITY;
        assert!(result.approx_eq(m));
    }

    #[test]
    fn test_mul_vec4() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let point = m * Vec4::new(0.0, 0.0, 0.0, 1.0);
        assert!(point.approx_eq(Vec4::new(1.0, 2.0, 3.0, 1.0)));
        let direction = m * Vec4::new(1.0, 0.0, 0.0, 0.0);
        assert!(direction.approx_eq(Vec4::new(1.0, 0.0, 0.0, 0.0)));
    }

    #[test]
    fn test_transform_point() {
        let t = Mat4::from_translation(Vec3::new(1.0, 0.0, 0.0));
        let s = Mat4::from_scale(Vec3::new(2.0, 2.0, 2.0));
        // `t * s` scales first, then translates.
        let m = t * s;
        let p = m.transform_point(Vec3::ONE);
        assert!(p.approx_eq(Vec3::new(3.0, 2.0, 2.0)));
    }

    #[test]
    fn test_transform_vector() {
        let t = Mat4::from_translation(Vec3::new(100.0, 100.0, 100.0));
        let v = t.transform_vector(Vec3::X);
        assert!(v.approx_eq(Vec3::X));
    }

    #[test]
    fn test_look_at_rh() {
        let eye = Vec3::new(0.0, 0.0, 5.0);
        let target = Vec3::ZERO;
        let up = Vec3::Y;
        let m = Mat4::look_at_rh(eye, target, up);
        let p = m.transform_point(Vec3::new(0.0, 0.0, 5.0));
        assert!(p.approx_eq(Vec3::ZERO));
        // Right-handed view space looks down -Z.
        let q = m.transform_point(target);
        assert!(q.approx_eq(Vec3::new(0.0, 0.0, -5.0)));
    }

    #[test]
    fn test_look_at_lh() {
        let eye = Vec3::new(0.0, 0.0, -5.0);
        let m = Mat4::look_at_lh(eye, Vec3::ZERO, Vec3::Y);
        assert!(m.transform_point(eye).approx_eq(Vec3::ZERO));
        // Left-handed view space looks down +Z.
        assert!(m.transform_point(Vec3::ZERO).approx_eq(Vec3::new(0.0, 0.0, 5.0)));
        assert!(m.transform_vector(Vec3::Y).approx_eq(Vec3::Y));
    }

    #[test]
    fn test_perspective_rh() {
        let m = Mat4::perspective_rh(PI / 4.0, 1.0, 0.1, 100.0);
        // The near plane (view-space z = -near) maps to depth 0...
        let p = m.transform_vec4(Vec4::new(0.0, 0.0, -0.1, 1.0));
        assert!(approx_eq(p.z / p.w, 0.0));
        // ...and the far plane (z = -far) maps to depth 1.
        let far = m.transform_vec4(Vec4::new(0.0, 0.0, -100.0, 1.0));
        assert!(approx_eq(far.z / far.w, 1.0));
    }

    #[test]
    fn test_orthographic_rh() {
        let m = Mat4::orthographic_rh(-1.0, 1.0, -1.0, 1.0, 0.0, 1.0);
        let p = m.transform_vec4(Vec4::new(0.0, 0.0, 0.0, 1.0));
        assert!(p.approx_eq(Vec4::new(0.0, 0.0, 0.0, 1.0)));
        // The far plane (z = -far) maps to depth 1.
        let far = m.transform_vec4(Vec4::new(0.0, 0.0, -1.0, 1.0));
        assert!(far.approx_eq(Vec4::new(0.0, 0.0, 1.0, 1.0)));
        // Box corners map to the corners of the [-1, 1] square.
        let corner = m.transform_vec4(Vec4::new(1.0, 1.0, 0.0, 1.0));
        assert!(corner.approx_eq(Vec4::new(1.0, 1.0, 0.0, 1.0)));
    }

    #[test]
    fn test_mul_scalar() {
        let m = Mat4::IDENTITY * 2.0;
        assert!(approx_eq(m.cols[0].x, 2.0));
        assert!(approx_eq(m.cols[1].y, 2.0));
    }
}