1use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3use crate::Scalar;
4use crate::tensor::Vector3;
5
6pub struct Matrix3<T> {
8 cols: [Vector3<T>; 3],
9}
10
11impl<T: ::core::marker::Copy> ::core::marker::Copy for Matrix3<T> {}
12
13impl<T: ::core::clone::Clone> ::core::clone::Clone for Matrix3<T> {
14 #[inline]
15 fn clone(&self) -> Self {
16 Self {
17 cols: [
18 self.cols[0].clone(),
19 self.cols[1].clone(),
20 self.cols[2].clone(),
21 ],
22 }
23 }
24}
25
26impl<T: ::core::fmt::Debug> ::core::fmt::Debug for Matrix3<T> {
27 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
28 f.debug_struct("Matrix3")
29 .field("x_col", &self.cols[0])
30 .field("y_col", &self.cols[1])
31 .field("z_col", &self.cols[2])
32 .finish()
33 }
34}
35
36impl<T: ::core::cmp::PartialEq> ::core::cmp::PartialEq for Matrix3<T> {
37 #[inline]
38 fn eq(&self, other: &Self) -> bool {
39 self.cols[0] == other.cols[0]
40 && self.cols[1] == other.cols[1]
41 && self.cols[2] == other.cols[2]
42 }
43}
44
45impl<T> Matrix3<T> {
46 #[inline]
48 pub const fn from_cols(x_col: Vector3<T>, y_col: Vector3<T>, z_col: Vector3<T>) -> Self {
49 Self {
50 cols: [x_col, y_col, z_col],
51 }
52 }
53
54 #[inline]
57 pub fn map<U, F: FnMut(T) -> U>(self, mut f: F) -> Matrix3<U> {
58 let [c0, c1, c2] = self.cols;
59 Matrix3::from_cols(c0.map(&mut f), c1.map(&mut f), c2.map(&mut f))
60 }
61
62 #[inline]
65 pub fn zip_map<U, R, F: FnMut(T, U) -> R>(self, rhs: Matrix3<U>, mut f: F) -> Matrix3<R> {
66 let [c0, c1, c2] = self.cols;
67 let [r0, r1, r2] = rhs.cols;
68 Matrix3::from_cols(
69 c0.zip_map(r0, &mut f),
70 c1.zip_map(r1, &mut f),
71 c2.zip_map(r2, &mut f),
72 )
73 }
74}
75
76impl<T: Copy> Matrix3<T> {
77 #[inline]
79 pub fn from_rows(x_row: Vector3<T>, y_row: Vector3<T>, z_row: Vector3<T>) -> Self {
80 Self::from_cols(
81 Vector3::new(x_row.x, y_row.x, z_row.x),
82 Vector3::new(x_row.y, y_row.y, z_row.y),
83 Vector3::new(x_row.z, y_row.z, z_row.z),
84 )
85 }
86
87 #[inline]
89 pub fn from_cols_array(m: &[T; 9]) -> Self {
90 Self::from_cols(
91 Vector3::new(m[0], m[1], m[2]),
92 Vector3::new(m[3], m[4], m[5]),
93 Vector3::new(m[6], m[7], m[8]),
94 )
95 }
96
97 #[inline]
99 pub fn to_cols_array(&self) -> [T; 9] {
100 [
101 self.cols[0].x,
102 self.cols[0].y,
103 self.cols[0].z,
104 self.cols[1].x,
105 self.cols[1].y,
106 self.cols[1].z,
107 self.cols[2].x,
108 self.cols[2].y,
109 self.cols[2].z,
110 ]
111 }
112
113 #[inline]
119 pub fn col(&self, index: usize) -> Vector3<T> {
120 self.cols[index]
121 }
122
123 #[inline]
129 pub fn row(&self, index: usize) -> Vector3<T> {
130 Vector3::new(
131 self.cols[0][index],
132 self.cols[1][index],
133 self.cols[2][index],
134 )
135 }
136
137 #[inline]
139 pub fn transpose(&self) -> Self {
140 Self::from_cols(self.row(0), self.row(1), self.row(2))
141 }
142}
143
144impl<T: Copy + Default> Matrix3<T> {
145 #[inline]
148 pub fn from_diagonal(diagonal: Vector3<T>) -> Self {
149 let zero = T::default();
150 Self::from_cols(
151 Vector3::new(diagonal.x, zero, zero),
152 Vector3::new(zero, diagonal.y, zero),
153 Vector3::new(zero, zero, diagonal.z),
154 )
155 }
156
157 #[inline]
159 pub fn diagonal(&self) -> Vector3<T> {
160 Vector3::new(self.cols[0].x, self.cols[1].y, self.cols[2].z)
161 }
162}
163
164impl<T: Default> Default for Matrix3<T> {
165 #[inline]
167 fn default() -> Self {
168 Self::from_cols(Vector3::default(), Vector3::default(), Vector3::default())
169 }
170}
171
172impl<T: Neg<Output = T>> Neg for Matrix3<T> {
173 type Output = Self;
174 #[inline]
176 fn neg(self) -> Self {
177 let [c0, c1, c2] = self.cols;
178 Self::from_cols(-c0, -c1, -c2)
179 }
180}
181
182impl<T: Add<Output = T>> Add for Matrix3<T> {
183 type Output = Self;
184 #[inline]
186 fn add(self, rhs: Self) -> Self {
187 let [a0, a1, a2] = self.cols;
188 let [b0, b1, b2] = rhs.cols;
189 Self::from_cols(a0 + b0, a1 + b1, a2 + b2)
190 }
191}
192
193impl<T: AddAssign> AddAssign for Matrix3<T> {
194 #[inline]
195 fn add_assign(&mut self, rhs: Self) {
196 let [b0, b1, b2] = rhs.cols;
197 self.cols[0] += b0;
198 self.cols[1] += b1;
199 self.cols[2] += b2;
200 }
201}
202
203impl<T: Sub<Output = T>> Sub for Matrix3<T> {
204 type Output = Self;
205 #[inline]
207 fn sub(self, rhs: Self) -> Self {
208 let [a0, a1, a2] = self.cols;
209 let [b0, b1, b2] = rhs.cols;
210 Self::from_cols(a0 - b0, a1 - b1, a2 - b2)
211 }
212}
213
214impl<T: SubAssign> SubAssign for Matrix3<T> {
215 #[inline]
216 fn sub_assign(&mut self, rhs: Self) {
217 let [b0, b1, b2] = rhs.cols;
218 self.cols[0] -= b0;
219 self.cols[1] -= b1;
220 self.cols[2] -= b2;
221 }
222}
223
224impl<V: Scalar> Matrix3<V> {
225 pub const ZERO: Self = Self::from_cols(Vector3::ZERO, Vector3::ZERO, Vector3::ZERO);
227
228 pub const IDENTITY: Self = Self::from_cols(Vector3::X, Vector3::Y, Vector3::Z);
230
231 #[inline]
234 pub fn from_scale(scale: Vector3<V>) -> Self {
235 Self::from_cols(
236 Vector3::new(scale.x, V::ZERO, V::ZERO),
237 Vector3::new(V::ZERO, scale.y, V::ZERO),
238 Vector3::new(V::ZERO, V::ZERO, scale.z),
239 )
240 }
241
242 #[inline]
245 pub fn outer_product(a: Vector3<V>, b: Vector3<V>) -> Self {
246 Self::from_cols(a * b.x, a * b.y, a * b.z)
247 }
248
249 #[inline]
252 pub fn from_rotation_x(angle: V) -> Self {
253 let (sin, cos) = angle.sin_cos();
254 Self::from_cols(
255 Vector3::new(V::ONE, V::ZERO, V::ZERO),
256 Vector3::new(V::ZERO, cos, sin),
257 Vector3::new(V::ZERO, -sin, cos),
258 )
259 }
260
261 #[inline]
264 pub fn from_rotation_y(angle: V) -> Self {
265 let (sin, cos) = angle.sin_cos();
266 Self::from_cols(
267 Vector3::new(cos, V::ZERO, -sin),
268 Vector3::new(V::ZERO, V::ONE, V::ZERO),
269 Vector3::new(sin, V::ZERO, cos),
270 )
271 }
272
273 #[inline]
276 pub fn from_rotation_z(angle: V) -> Self {
277 let (sin, cos) = angle.sin_cos();
278 Self::from_cols(
279 Vector3::new(cos, sin, V::ZERO),
280 Vector3::new(-sin, cos, V::ZERO),
281 Vector3::new(V::ZERO, V::ZERO, V::ONE),
282 )
283 }
284
285 #[inline]
291 pub fn from_axis_angle(axis: Vector3<V>, angle: V) -> Self {
292 let (sin, cos) = angle.sin_cos();
293 let t = V::ONE - cos;
294 let Vector3 { x, y, z } = axis;
295 Self::from_cols(
296 Vector3::new(t * x * x + cos, t * x * y + sin * z, t * x * z - sin * y),
297 Vector3::new(t * x * y - sin * z, t * y * y + cos, t * y * z + sin * x),
298 Vector3::new(t * x * z + sin * y, t * y * z - sin * x, t * z * z + cos),
299 )
300 }
301
302 #[inline]
304 pub fn trace(&self) -> V {
305 self.cols[0].x + self.cols[1].y + self.cols[2].z
306 }
307
308 #[inline]
310 pub fn determinant(&self) -> V {
311 self.cols[0].dot(self.cols[1].cross(self.cols[2]))
312 }
313
314 #[inline]
317 pub fn is_invertible(&self) -> bool {
318 let det = self.determinant();
319 det != V::ZERO && det.is_finite()
320 }
321
322 #[inline]
324 pub fn try_inverse(&self) -> Option<Self> {
325 let r0 = self.cols[1].cross(self.cols[2]);
326 let r1 = self.cols[2].cross(self.cols[0]);
327 let r2 = self.cols[0].cross(self.cols[1]);
328 let det = self.cols[0].dot(r0);
329 if det == V::ZERO || !det.is_finite() {
330 return None;
331 }
332 let inv_det = det.recip();
333 Some(Self::from_rows(r0 * inv_det, r1 * inv_det, r2 * inv_det))
334 }
335
336 #[inline]
343 pub fn inverse(&self) -> Self {
344 self.try_inverse().expect("matrix is not invertible")
345 }
346}
347
348impl<T: Mul<S, Output = T> + Copy, S: Scalar> Mul<S> for Matrix3<T> {
349 type Output = Self;
350 #[inline]
352 fn mul(self, rhs: S) -> Self {
353 let [c0, c1, c2] = self.cols;
354 Self::from_cols(c0 * rhs, c1 * rhs, c2 * rhs)
355 }
356}
357
358impl<T: MulAssign<S> + Copy, S: Scalar> MulAssign<S> for Matrix3<T> {
359 #[inline]
360 fn mul_assign(&mut self, rhs: S) {
361 self.cols[0] *= rhs;
362 self.cols[1] *= rhs;
363 self.cols[2] *= rhs;
364 }
365}
366
367impl<T: Div<S, Output = T> + Copy, S: Scalar> Div<S> for Matrix3<T> {
368 type Output = Self;
369 #[inline]
371 fn div(self, rhs: S) -> Self {
372 let [c0, c1, c2] = self.cols;
373 Self::from_cols(c0 / rhs, c1 / rhs, c2 / rhs)
374 }
375}
376
377impl<T: DivAssign<S> + Copy, S: Scalar> DivAssign<S> for Matrix3<T> {
378 #[inline]
379 fn div_assign(&mut self, rhs: S) {
380 self.cols[0] /= rhs;
381 self.cols[1] /= rhs;
382 self.cols[2] /= rhs;
383 }
384}
385
386impl<V: Scalar> Mul<Vector3<V>> for Matrix3<V> {
387 type Output = Vector3<V>;
388 #[inline]
390 fn mul(self, rhs: Vector3<V>) -> Vector3<V> {
391 self.cols[0] * rhs.x + self.cols[1] * rhs.y + self.cols[2] * rhs.z
392 }
393}
394
395impl<V: Scalar> Mul for Matrix3<V> {
396 type Output = Self;
397 #[inline]
399 fn mul(self, rhs: Self) -> Self {
400 Self::from_cols(self * rhs.cols[0], self * rhs.cols[1], self * rhs.cols[2])
401 }
402}
403
404impl<V: Scalar> MulAssign for Matrix3<V> {
405 #[inline]
406 fn mul_assign(&mut self, rhs: Self) {
407 *self = *self * rhs;
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use core::f64::consts::FRAC_PI_2;

    /// Absolute tolerance for results that go through trigonometry or division.
    const EPS: f64 = 1e-12;

    /// Fixture whose columns are `(1, 2, 3)`, `(4, 5, 6)` and `(7, 8, 9)`.
    ///
    /// It is deliberately non-symmetric (so row/column mix-ups show up) and
    /// singular (its determinant is exactly zero; see `singular_fixture`).
    fn columns() -> Matrix3<f64> {
        Matrix3::from_cols(
            Vector3::new(1.0, 2.0, 3.0),
            Vector3::new(4.0, 5.0, 6.0),
            Vector3::new(7.0, 8.0, 9.0),
        )
    }

    /// Asserts that every entry of `actual` is within `EPS` of `expected`.
    fn assert_mat_close(actual: Matrix3<f64>, expected: Matrix3<f64>) {
        let actual_entries = actual.to_cols_array();
        let expected_entries = expected.to_cols_array();
        for (i, (a, e)) in actual_entries
            .iter()
            .zip(expected_entries.iter())
            .enumerate()
        {
            assert!(
                (a - e).abs() < EPS,
                "entry {} differs: {} vs {}\n  actual:   {:?}\n  expected: {:?}",
                i,
                a,
                e,
                actual,
                expected
            );
        }
    }

    /// Asserts that `actual` is within `EPS` of `expected` in Euclidean norm.
    fn assert_vec_close(actual: Vector3<f64>, expected: Vector3<f64>) {
        assert!(
            (actual - expected).norm() < EPS,
            "{:?} is not close to {:?}",
            actual,
            expected
        );
    }

    #[test]
    fn from_cols() {
        let m = columns();
        assert_eq!(m.col(0), Vector3::new(1.0, 2.0, 3.0));
        assert_eq!(m.col(1), Vector3::new(4.0, 5.0, 6.0));
        assert_eq!(m.col(2), Vector3::new(7.0, 8.0, 9.0));
    }

    #[test]
    fn from_rows() {
        let m = Matrix3::from_rows(
            Vector3::new(1.0, 2.0, 3.0),
            Vector3::new(4.0, 5.0, 6.0),
            Vector3::new(7.0, 8.0, 9.0),
        );
        assert_eq!(m.row(0), Vector3::new(1.0, 2.0, 3.0));
        assert_eq!(m.row(1), Vector3::new(4.0, 5.0, 6.0));
        assert_eq!(m.row(2), Vector3::new(7.0, 8.0, 9.0));
    }

    #[test]
    fn from_rows_is_transposed_from_cols() {
        let (a, b, c) = (
            Vector3::new(1.0, 2.0, 3.0),
            Vector3::new(4.0, 5.0, 6.0),
            Vector3::new(7.0, 8.0, 9.0),
        );
        let by_rows: Matrix3<f64> = Matrix3::from_rows(a, b, c);
        assert_eq!(by_rows, Matrix3::from_cols(a, b, c).transpose());
    }

    #[test]
    fn from_cols_array() {
        let m = Matrix3::from_cols_array(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        assert_eq!(m, columns());
    }

    #[test]
    fn to_cols_array() {
        assert_eq!(
            columns().to_cols_array(),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
        );
    }

    #[test]
    fn from_diagonal() {
        let m = Matrix3::from_diagonal(Vector3::new(1.0, 2.0, 3.0));
        assert_eq!(
            m.to_cols_array(),
            [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );
    }

    #[test]
    fn diagonal() {
        assert_eq!(columns().diagonal(), Vector3::new(1.0, 5.0, 9.0));
    }

    #[test]
    fn map() {
        assert_eq!(
            columns().map(|e| e * 2.0).to_cols_array(),
            [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0]
        );
    }

    #[test]
    fn zip_map() {
        let sum = columns().zip_map(columns(), |a, b| a + b);
        assert_eq!(sum, columns().map(|e| e * 2.0));
    }

    #[test]
    fn col() {
        assert_eq!(columns().col(2), Vector3::new(7.0, 8.0, 9.0));
    }

    #[test]
    #[should_panic]
    fn col_panics_when_out_of_bounds() {
        let _ = columns().col(3);
    }

    #[test]
    fn row() {
        assert_eq!(columns().row(1), Vector3::new(2.0, 5.0, 8.0));
    }

    #[test]
    fn transpose() {
        assert_eq!(
            columns().transpose().to_cols_array(),
            [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0]
        );
        assert_eq!(columns().transpose().transpose(), columns());
    }

    #[test]
    fn default_is_zero() {
        assert_eq!(Matrix3::<f64>::default(), Matrix3::ZERO);
    }

    #[test]
    fn copy_and_clone() {
        let a = columns();
        let b = a;
        // Call `Clone::clone` explicitly so the test exercises `Clone` itself
        // (a method call on a `Copy` value trips clippy's `clone_on_copy`).
        let c = ::core::clone::Clone::clone(&a);
        assert_eq!(a, b);
        assert_eq!(a, c);
    }

    #[test]
    fn eq() {
        assert_eq!(columns(), columns());
        assert_ne!(columns(), Matrix3::<f64>::IDENTITY);
    }

    #[test]
    fn debug() {
        assert_eq!(
            format!("{:?}", Matrix3::<f64>::IDENTITY),
            concat!(
                "Matrix3 { ",
                "x_col: Vector3 { x: 1.0, y: 0.0, z: 0.0 }, ",
                "y_col: Vector3 { x: 0.0, y: 1.0, z: 0.0 }, ",
                "z_col: Vector3 { x: 0.0, y: 0.0, z: 1.0 } }"
            )
        );
    }

    #[test]
    fn zero_constant() {
        assert_eq!(Matrix3::<f64>::ZERO.to_cols_array(), [0.0; 9]);
    }

    #[test]
    fn identity_constant() {
        assert_eq!(
            Matrix3::<f64>::IDENTITY.to_cols_array(),
            [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn identity_and_zero_are_neutral() {
        assert_eq!(columns() * Matrix3::IDENTITY, columns());
        assert_eq!(columns() + Matrix3::ZERO, columns());
    }

    #[test]
    fn neg() {
        assert_eq!(-columns(), columns().map(|e| -e));
    }

    #[test]
    fn add() {
        assert_eq!(columns() + columns(), columns().map(|e| e * 2.0));
    }

    #[test]
    fn add_assign() {
        let mut m = columns();
        m += columns();
        assert_eq!(m, columns().map(|e| e * 2.0));
    }

    #[test]
    fn sub() {
        assert_eq!(columns() - columns(), Matrix3::ZERO);
    }

    #[test]
    fn sub_assign() {
        let mut m = columns();
        m -= columns();
        assert_eq!(m, Matrix3::ZERO);
    }

    #[test]
    fn mul_scalar() {
        assert_eq!(columns() * 2.0, columns().map(|e| e * 2.0));
    }

    #[test]
    fn mul_assign_scalar() {
        let mut m = columns();
        m *= 2.0;
        assert_eq!(m, columns().map(|e| e * 2.0));
    }

    #[test]
    fn div_scalar() {
        assert_eq!(columns().map(|e| e * 2.0) / 2.0, columns());
    }

    #[test]
    fn div_assign_scalar() {
        let mut m = columns().map(|e| e * 2.0);
        m /= 2.0;
        assert_eq!(m, columns());
    }

    #[test]
    fn mul_vector() {
        assert_eq!(
            Matrix3::<f64>::IDENTITY * Vector3::new(1.0, 2.0, 3.0),
            Vector3::new(1.0, 2.0, 3.0)
        );
        assert_eq!(
            Matrix3::from_scale(Vector3::new(2.0, 3.0, 4.0)) * Vector3::new(1.0, 1.0, 1.0),
            Vector3::new(2.0, 3.0, 4.0)
        );
    }

    #[test]
    fn mul_vector_combines_columns() {
        assert_eq!(
            columns() * Vector3::new(1.0, 1.0, 1.0),
            Vector3::new(12.0, 15.0, 18.0)
        );
        assert_eq!(
            columns() * Vector3::new(0.0, 1.0, 0.0),
            Vector3::new(4.0, 5.0, 6.0)
        );
    }

    #[test]
    fn mul_matrix() {
        assert_eq!(Matrix3::<f64>::IDENTITY * columns(), columns());
        assert_eq!(
            Matrix3::from_scale(Vector3::new(2.0, 2.0, 2.0))
                * Matrix3::from_scale(Vector3::new(3.0, 3.0, 3.0)),
            Matrix3::from_scale(Vector3::new(6.0, 6.0, 6.0))
        );
    }

    #[test]
    fn mul_matrix_is_not_commutative() {
        let scale: Matrix3<f64> = Matrix3::from_scale(Vector3::new(2.0, 3.0, 4.0));
        // A diagonal factor on the right scales columns; on the left it scales rows.
        assert_eq!(
            columns() * scale,
            Matrix3::from_cols(
                Vector3::new(2.0, 4.0, 6.0),
                Vector3::new(12.0, 15.0, 18.0),
                Vector3::new(28.0, 32.0, 36.0),
            )
        );
        assert_eq!(
            scale * columns(),
            Matrix3::from_cols(
                Vector3::new(2.0, 6.0, 12.0),
                Vector3::new(8.0, 15.0, 24.0),
                Vector3::new(14.0, 24.0, 36.0),
            )
        );
    }

    #[test]
    fn mul_assign_matrix() {
        let mut m = columns();
        m *= Matrix3::IDENTITY;
        assert_eq!(m, columns());
    }

    #[test]
    fn from_scale() {
        assert_eq!(
            Matrix3::from_scale(Vector3::new(2.0, 3.0, 4.0)).to_cols_array(),
            [2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0]
        );
    }

    #[test]
    fn outer_product() {
        assert_eq!(
            Matrix3::outer_product(Vector3::new(1.0, 2.0, 3.0), Vector3::new(4.0, 5.0, 6.0)),
            Matrix3::from_cols(
                Vector3::new(4.0, 8.0, 12.0),
                Vector3::new(5.0, 10.0, 15.0),
                Vector3::new(6.0, 12.0, 18.0),
            )
        );
    }

    #[test]
    fn from_rotation_x() {
        let m = Matrix3::from_rotation_x(FRAC_PI_2);
        assert_vec_close(m * Vector3::Y, Vector3::Z);
        assert_vec_close(m * Vector3::Z, -Vector3::<f64>::Y);
    }

    #[test]
    fn from_rotation_y() {
        let m = Matrix3::from_rotation_y(FRAC_PI_2);
        assert_vec_close(m * Vector3::Z, Vector3::X);
        assert_vec_close(m * Vector3::X, -Vector3::<f64>::Z);
    }

    #[test]
    fn from_rotation_z() {
        let m = Matrix3::from_rotation_z(FRAC_PI_2);
        assert_vec_close(m * Vector3::X, Vector3::Y);
        assert_vec_close(m * Vector3::Y, -Vector3::<f64>::X);
    }

    #[test]
    fn from_axis_angle() {
        assert_mat_close(
            Matrix3::from_axis_angle(Vector3::Z, 0.7),
            Matrix3::from_rotation_z(0.7),
        );
    }

    #[test]
    fn trace() {
        assert_eq!(columns().trace(), 15.0);
    }

    #[test]
    fn determinant() {
        assert_eq!(Matrix3::<f64>::IDENTITY.determinant(), 1.0);
        assert_eq!(
            Matrix3::from_diagonal(Vector3::new(2.0, 3.0, 4.0)).determinant(),
            24.0
        );
    }

    #[test]
    fn determinant_of_non_diagonal_matrix() {
        let m: Matrix3<f64> = Matrix3::from_cols(
            Vector3::new(2.0, 1.0, 0.0),
            Vector3::new(1.0, 2.0, 1.0),
            Vector3::new(0.0, 1.0, 2.0),
        );
        assert_eq!(m.determinant(), 4.0);
    }

    #[test]
    fn singular_fixture() {
        let m = columns();
        assert_eq!(m.determinant(), 0.0);
        assert!(!m.is_invertible());
        assert_eq!(m.try_inverse(), None);
    }

    #[test]
    fn is_invertible() {
        assert!(Matrix3::<f64>::IDENTITY.is_invertible());
        assert!(!Matrix3::from_scale(Vector3::new(1.0, 0.0, 1.0)).is_invertible());
    }

    #[test]
    fn try_inverse() {
        assert_eq!(
            Matrix3::from_scale(Vector3::new(2.0, 4.0, 8.0)).try_inverse(),
            Some(Matrix3::from_scale(Vector3::new(0.5, 0.25, 0.125)))
        );
    }

    #[test]
    fn try_inverse_singular_is_none() {
        assert_eq!(
            Matrix3::from_scale(Vector3::new(1.0, 0.0, 1.0)).try_inverse(),
            None
        );
    }

    #[test]
    fn inverse() {
        assert_eq!(
            Matrix3::from_scale(Vector3::new(2.0, 4.0, 8.0)).inverse(),
            Matrix3::from_scale(Vector3::new(0.5, 0.25, 0.125))
        );
    }

    #[test]
    #[should_panic]
    fn inverse_panics_when_singular() {
        let _ = Matrix3::from_scale(Vector3::new(1.0, 0.0, 1.0)).inverse();
    }

    #[test]
    fn inverse_roundtrip() {
        let m: Matrix3<f64> = Matrix3::from_cols(
            Vector3::new(2.0, 1.0, 0.0),
            Vector3::new(1.0, 2.0, 1.0),
            Vector3::new(0.0, 1.0, 2.0),
        );
        assert_mat_close(m * m.inverse(), Matrix3::IDENTITY);
    }

    #[test]
    fn rotation_inverse_is_transpose() {
        let r: Matrix3<f64> = Matrix3::from_rotation_x(0.3) * Matrix3::from_rotation_z(1.1);
        assert!((r.determinant() - 1.0).abs() < EPS);
        assert_mat_close(r.inverse(), r.transpose());
    }

    #[test]
    fn f32_mul_vector() {
        assert_eq!(
            Matrix3::from_scale(Vector3::<f32>::new(2.0, 3.0, 4.0)) * Vector3::new(1.0, 1.0, 1.0),
            Vector3::new(2.0, 3.0, 4.0)
        );
    }
}