1use std::ops::{Add, Mul, Sub};
2
3use crate::math::Vec3;
4
/// Tolerance below which a quantity is treated as numerically zero.
///
/// Used by the singularity test in [`Mat3::inverse`] and by the origin test in
/// [`cartesian_to_spherical`].
const SINGULARITY_THRESHOLD: f64 = 1.0e-12;
6
7#[derive(Debug, Clone, Copy, PartialEq)]
8pub struct Mat3 {
9 pub data: [[f64; 3]; 3],
10}
11
12impl Mat3 {
13 #[must_use]
15 pub fn zero() -> Self {
16 Self {
17 data: [[0.0; 3]; 3],
18 }
19 }
20
21 #[must_use]
23 pub fn identity() -> Self {
24 Self {
25 data: [
26 [1.0, 0.0, 0.0],
27 [0.0, 1.0, 0.0],
28 [0.0, 0.0, 1.0],
29 ],
30 }
31 }
32
33 #[must_use]
35 pub fn from_rows(r0: [f64; 3], r1: [f64; 3], r2: [f64; 3]) -> Self {
36 Self { data: [r0, r1, r2] }
37 }
38
39 #[must_use]
41 pub fn determinant(&self) -> f64 {
42 let d = &self.data;
43 d[0][0] * (d[1][1] * d[2][2] - d[1][2] * d[2][1])
44 - d[0][1] * (d[1][0] * d[2][2] - d[1][2] * d[2][0])
45 + d[0][2] * (d[1][0] * d[2][1] - d[1][1] * d[2][0])
46 }
47
48 #[must_use]
50 pub fn transpose(&self) -> Self {
51 let d = &self.data;
52 Self {
53 data: [
54 [d[0][0], d[1][0], d[2][0]],
55 [d[0][1], d[1][1], d[2][1]],
56 [d[0][2], d[1][2], d[2][2]],
57 ],
58 }
59 }
60
61 #[must_use]
63 pub fn inverse(&self) -> Option<Self> {
64 let det = self.determinant();
65 if det.abs() < SINGULARITY_THRESHOLD {
66 return None;
67 }
68 let d = &self.data;
69 let inv_det = 1.0 / det;
70
71 Some(Self {
73 data: [
74 [
75 (d[1][1] * d[2][2] - d[1][2] * d[2][1]) * inv_det,
76 (d[0][2] * d[2][1] - d[0][1] * d[2][2]) * inv_det,
77 (d[0][1] * d[1][2] - d[0][2] * d[1][1]) * inv_det,
78 ],
79 [
80 (d[1][2] * d[2][0] - d[1][0] * d[2][2]) * inv_det,
81 (d[0][0] * d[2][2] - d[0][2] * d[2][0]) * inv_det,
82 (d[0][2] * d[1][0] - d[0][0] * d[1][2]) * inv_det,
83 ],
84 [
85 (d[1][0] * d[2][1] - d[1][1] * d[2][0]) * inv_det,
86 (d[0][1] * d[2][0] - d[0][0] * d[2][1]) * inv_det,
87 (d[0][0] * d[1][1] - d[0][1] * d[1][0]) * inv_det,
88 ],
89 ],
90 })
91 }
92
93 #[must_use]
95 pub fn trace(&self) -> f64 {
96 self.data[0][0] + self.data[1][1] + self.data[2][2]
97 }
98
99 #[must_use]
101 pub fn mul_vec(&self, v: Vec3) -> Vec3 {
102 let d = &self.data;
103 Vec3::new(
104 d[0][0] * v.x + d[0][1] * v.y + d[0][2] * v.z,
105 d[1][0] * v.x + d[1][1] * v.y + d[1][2] * v.z,
106 d[2][0] * v.x + d[2][1] * v.y + d[2][2] * v.z,
107 )
108 }
109
110 #[must_use]
112 pub fn mul_mat(&self, other: &Mat3) -> Mat3 {
113 let a = &self.data;
114 let b = &other.data;
115 let mut result = [[0.0; 3]; 3];
116 for i in 0..3 {
117 for j in 0..3 {
118 result[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j] + a[i][2] * b[2][j];
119 }
120 }
121 Mat3 { data: result }
122 }
123
124 #[must_use]
126 pub fn scale(s: f64) -> Self {
127 Self {
128 data: [
129 [s, 0.0, 0.0],
130 [0.0, s, 0.0],
131 [0.0, 0.0, s],
132 ],
133 }
134 }
135
136 #[must_use]
138 pub fn mul_scalar(&self, s: f64) -> Self {
139 let d = &self.data;
140 Self {
141 data: [
142 [d[0][0] * s, d[0][1] * s, d[0][2] * s],
143 [d[1][0] * s, d[1][1] * s, d[1][2] * s],
144 [d[2][0] * s, d[2][1] * s, d[2][2] * s],
145 ],
146 }
147 }
148}
149
150impl Mul<Mat3> for Mat3 {
151 type Output = Mat3;
152 fn mul(self, rhs: Mat3) -> Mat3 {
153 self.mul_mat(&rhs)
154 }
155}
156
157impl Mul<Vec3> for Mat3 {
158 type Output = Vec3;
159 fn mul(self, rhs: Vec3) -> Vec3 {
160 self.mul_vec(rhs)
161 }
162}
163
164impl Add<Mat3> for Mat3 {
165 type Output = Mat3;
166 fn add(self, rhs: Mat3) -> Mat3 {
167 let mut result = [[0.0; 3]; 3];
168 for i in 0..3 {
169 for j in 0..3 {
170 result[i][j] = self.data[i][j] + rhs.data[i][j];
171 }
172 }
173 Mat3 { data: result }
174 }
175}
176
177impl Sub<Mat3> for Mat3 {
178 type Output = Mat3;
179 fn sub(self, rhs: Mat3) -> Mat3 {
180 let mut result = [[0.0; 3]; 3];
181 for i in 0..3 {
182 for j in 0..3 {
183 result[i][j] = self.data[i][j] - rhs.data[i][j];
184 }
185 }
186 Mat3 { data: result }
187 }
188}
189
190#[must_use]
194pub fn rotation_x(angle: f64) -> Mat3 {
195 let (s, c) = angle.sin_cos();
196 Mat3::from_rows(
197 [1.0, 0.0, 0.0],
198 [0.0, c, -s],
199 [0.0, s, c],
200 )
201}
202
203#[must_use]
205pub fn rotation_y(angle: f64) -> Mat3 {
206 let (s, c) = angle.sin_cos();
207 Mat3::from_rows(
208 [c, 0.0, s],
209 [0.0, 1.0, 0.0],
210 [-s, 0.0, c],
211 )
212}
213
214#[must_use]
216pub fn rotation_z(angle: f64) -> Mat3 {
217 let (s, c) = angle.sin_cos();
218 Mat3::from_rows(
219 [c, -s, 0.0],
220 [s, c, 0.0],
221 [0.0, 0.0, 1.0],
222 )
223}
224
225#[must_use]
228pub fn rotation_axis_angle(axis: Vec3, angle: f64) -> Mat3 {
229 let n = axis.normalized();
230 let (s, c) = angle.sin_cos();
231 let t = 1.0 - c;
232
233 Mat3::from_rows(
234 [
235 t * n.x * n.x + c,
236 t * n.x * n.y - s * n.z,
237 t * n.x * n.z + s * n.y,
238 ],
239 [
240 t * n.y * n.x + s * n.z,
241 t * n.y * n.y + c,
242 t * n.y * n.z - s * n.x,
243 ],
244 [
245 t * n.z * n.x - s * n.y,
246 t * n.z * n.y + s * n.x,
247 t * n.z * n.z + c,
248 ],
249 )
250}
251
252#[must_use]
256pub fn cartesian_to_spherical(x: f64, y: f64, z: f64) -> (f64, f64, f64) {
257 let r = (x * x + y * y + z * z).sqrt();
258 if r < SINGULARITY_THRESHOLD {
259 return (0.0, 0.0, 0.0);
260 }
261 let theta = (z / r).clamp(-1.0, 1.0).acos();
262 let phi = y.atan2(x);
263 (r, theta, phi)
264}
265
/// Converts spherical `(r, theta, phi)` to Cartesian `(x, y, z)`.
///
/// `theta` is the polar angle from the +z axis and `phi` is the azimuth from
/// the +x axis.
#[must_use]
pub fn spherical_to_cartesian(r: f64, theta: f64, phi: f64) -> (f64, f64, f64) {
    let (sin_theta, cos_theta) = theta.sin_cos();
    let (sin_phi, cos_phi) = phi.sin_cos();
    // Length of the point's projection onto the xy-plane.
    let planar = r * sin_theta;
    (planar * cos_phi, planar * sin_phi, r * cos_theta)
}
277
/// Converts Cartesian `(x, y, z)` to cylindrical `(rho, phi, z)`.
///
/// `rho` is the distance from the z-axis, `phi` is the azimuth from the +x
/// axis (as returned by `atan2`), and `z` passes through unchanged.
#[must_use]
pub fn cartesian_to_cylindrical(x: f64, y: f64, z: f64) -> (f64, f64, f64) {
    // `hypot` avoids the overflow (|x| > ~1e154) and underflow (|x| < ~1e-154)
    // that `(x * x + y * y).sqrt()` suffers when squaring.
    let rho = x.hypot(y);
    let phi = y.atan2(x);
    (rho, phi, z)
}
285
/// Converts cylindrical `(rho, phi, z)` to Cartesian `(x, y, z)`.
///
/// `z` passes through unchanged.
#[must_use]
pub fn cylindrical_to_cartesian(rho: f64, phi: f64, z: f64) -> (f64, f64, f64) {
    let (s, c) = phi.sin_cos();
    let x = rho * c;
    let y = rho * s;
    (x, y, z)
}
292
/// Converts 2-D polar `(r, theta)` to Cartesian `(x, y)`.
///
/// `theta` is measured from the +x axis.
#[must_use]
pub fn polar_to_cartesian(r: f64, theta: f64) -> (f64, f64) {
    let (sine, cosine) = theta.sin_cos();
    let x = r * cosine;
    let y = r * sine;
    (x, y)
}
299
/// Converts 2-D Cartesian `(x, y)` to polar `(r, theta)`.
///
/// `r` is the distance from the origin and `theta` is the angle from the +x
/// axis, as returned by `atan2`.
#[must_use]
pub fn cartesian_to_polar(x: f64, y: f64) -> (f64, f64) {
    // `hypot` avoids the overflow/underflow of `(x * x + y * y).sqrt()`
    // for very large or very small coordinates.
    let r = x.hypot(y);
    let theta = y.atan2(x);
    (r, theta)
}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::constants::PI;

    /// Absolute tolerance for the floating-point comparisons in this module.
    const APPROX_EPSILON: f64 = 1e-9;

    /// Returns `true` when `a` and `b` differ by less than `APPROX_EPSILON`.
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < APPROX_EPSILON
    }

    /// Entry-wise `approx` comparison of two matrices.
    fn mat3_approx_eq(a: &Mat3, b: &Mat3) -> bool {
        for i in 0..3 {
            for j in 0..3 {
                if !approx(a.data[i][j], b.data[i][j]) {
                    return false;
                }
            }
        }
        true
    }

    // ---------------------------------------------------------------------
    // Basic matrix properties
    // ---------------------------------------------------------------------

    #[test]
    fn test_identity_determinant() {
        assert!(approx(Mat3::identity().determinant(), 1.0));
    }

    #[test]
    fn test_zero_determinant() {
        assert!(approx(Mat3::zero().determinant(), 0.0));
    }

    #[test]
    fn test_transpose_identity() {
        assert_eq!(Mat3::identity().transpose(), Mat3::identity());
    }

    // The identity is symmetric, so `test_transpose_identity` alone cannot
    // distinguish a correct transpose from one that returns its input.
    #[test]
    fn test_transpose_non_symmetric() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        let expected = Mat3::from_rows([1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]);
        assert_eq!(m.transpose(), expected);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn test_trace() {
        let m = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]);
        assert!(approx(m.trace(), 10.0));
    }

    #[test]
    fn test_determinant_of_product_is_product_of_determinants() {
        let a = Mat3::from_rows([1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]);
        let b = Mat3::from_rows([2.0, 0.0, 1.0], [1.0, 3.0, 0.0], [0.0, 1.0, 4.0]);
        let det_ab = (a * b).determinant();
        assert!(
            approx(det_ab, a.determinant() * b.determinant()),
            "det(AB) should equal det(A)det(B), got {det_ab}"
        );
    }

    // ---------------------------------------------------------------------
    // Inverse
    // ---------------------------------------------------------------------

    #[test]
    fn test_inverse_times_original_is_identity() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]);
        let inv = m.inverse().expect("matrix should be invertible");
        let product = m * inv;
        assert!(
            mat3_approx_eq(&product, &Mat3::identity()),
            "M * M^-1 should equal I, got {:?}",
            product
        );
    }

    #[test]
    fn test_inverse_of_diagonal() {
        let m = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 8.0]);
        let inv = m
            .inverse()
            .expect("diagonal matrix with non-zero entries is invertible");
        let expected = Mat3::from_rows([0.5, 0.0, 0.0], [0.0, 0.25, 0.0], [0.0, 0.0, 0.125]);
        assert!(mat3_approx_eq(&inv, &expected), "got {:?}", inv);
    }

    #[test]
    fn test_inverse_of_rotation_is_transpose() {
        let r = rotation_axis_angle(Vec3::new(1.0, -2.0, 0.5), 0.9);
        let inv = r.inverse().expect("rotation matrices are invertible");
        assert!(
            mat3_approx_eq(&inv, &r.transpose()),
            "R^-1 should equal R^T, got {:?}",
            inv
        );
    }

    #[test]
    fn test_singular_matrix_has_no_inverse() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        assert!(m.inverse().is_none());
    }

    // ---------------------------------------------------------------------
    // Products and arithmetic
    // ---------------------------------------------------------------------

    #[test]
    fn test_mul_vec() {
        let m = Mat3::identity();
        let v = Vec3::new(1.0, 2.0, 3.0);
        let result = m * v;
        assert!(approx(result.x, 1.0) && approx(result.y, 2.0) && approx(result.z, 3.0));
    }

    #[test]
    fn test_mul_mat_identity() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        let result = m * Mat3::identity();
        assert!(mat3_approx_eq(&result, &m));
    }

    #[test]
    fn test_add_sub() {
        let a = Mat3::identity();
        let b = Mat3::identity();
        let sum = a + b;
        assert!(approx(sum.data[0][0], 2.0));
        let diff = sum - a;
        assert!(mat3_approx_eq(&diff, &Mat3::identity()));
    }

    #[test]
    fn test_scale() {
        let s = Mat3::scale(3.0);
        let v = Vec3::new(1.0, 2.0, 3.0);
        let result = s * v;
        assert!(approx(result.x, 3.0) && approx(result.y, 6.0) && approx(result.z, 9.0));
    }

    #[test]
    fn test_mul_scalar() {
        let m = Mat3::identity();
        let scaled = m.mul_scalar(5.0);
        assert!(approx(scaled.data[0][0], 5.0));
        assert!(approx(scaled.data[0][1], 0.0));
    }

    // `test_mul_scalar` only inspects two entries of a sparse matrix; this
    // checks every entry of a dense one.
    #[test]
    fn test_mul_scalar_scales_every_entry() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        let expected = Mat3::from_rows([2.0, 4.0, 6.0], [8.0, 10.0, 12.0], [14.0, 16.0, 18.0]);
        assert_eq!(m.mul_scalar(2.0), expected);
    }

    // ---------------------------------------------------------------------
    // Rotations
    // ---------------------------------------------------------------------

    #[test]
    fn test_rotation_z_90_maps_x_to_y() {
        let r = rotation_z(PI / 2.0);
        let x_hat = Vec3::new(1.0, 0.0, 0.0);
        let result = r * x_hat;
        assert!(
            approx(result.x, 0.0) && approx(result.y, 1.0) && approx(result.z, 0.0),
            "90-deg rotation about z should map x-hat to y-hat, got {:?}",
            result
        );
    }

    #[test]
    fn test_rotation_x_90_maps_y_to_z() {
        let r = rotation_x(PI / 2.0);
        let y_hat = Vec3::new(0.0, 1.0, 0.0);
        let result = r * y_hat;
        assert!(
            approx(result.x, 0.0) && approx(result.y, 0.0) && approx(result.z, 1.0),
            "90-deg rotation about x should map y-hat to z-hat, got {:?}",
            result
        );
    }

    #[test]
    fn test_rotation_y_90_maps_z_to_x() {
        let r = rotation_y(PI / 2.0);
        let z_hat = Vec3::new(0.0, 0.0, 1.0);
        let result = r * z_hat;
        assert!(
            approx(result.x, 1.0) && approx(result.y, 0.0) && approx(result.z, 0.0),
            "90-deg rotation about y should map z-hat to x-hat, got {:?}",
            result
        );
    }

    #[test]
    fn test_rotation_matrix_is_orthogonal() {
        let r = rotation_axis_angle(Vec3::new(1.0, 1.0, 1.0), 1.23);
        let rt_r = r.transpose() * r;
        assert!(
            mat3_approx_eq(&rt_r, &Mat3::identity()),
            "R^T * R should equal I for rotation matrices, got {:?}",
            rt_r
        );
    }

    #[test]
    fn test_rotation_matrix_determinant_is_one() {
        let r = rotation_axis_angle(Vec3::new(0.0, 1.0, 0.0), 0.75);
        let det = r.determinant();
        assert!(
            approx(det, 1.0),
            "Rotation matrix determinant should be 1, got {det}",
        );
    }

    #[test]
    fn test_axis_angle_matches_rotation_z() {
        let angle = 1.2;
        let rz = rotation_z(angle);
        let raa = rotation_axis_angle(Vec3::new(0.0, 0.0, 1.0), angle);
        assert!(
            mat3_approx_eq(&rz, &raa),
            "Axis-angle about z should match rotation_z"
        );
    }

    #[test]
    fn test_axis_angle_matches_rotation_x_and_y() {
        let angle = -0.4;
        let raa_x = rotation_axis_angle(Vec3::new(1.0, 0.0, 0.0), angle);
        assert!(
            mat3_approx_eq(&rotation_x(angle), &raa_x),
            "Axis-angle about x should match rotation_x"
        );
        let raa_y = rotation_axis_angle(Vec3::new(0.0, 1.0, 0.0), angle);
        assert!(
            mat3_approx_eq(&rotation_y(angle), &raa_y),
            "Axis-angle about y should match rotation_y"
        );
    }

    // ---------------------------------------------------------------------
    // Coordinate conversions
    // ---------------------------------------------------------------------

    #[test]
    fn test_cartesian_spherical_roundtrip() {
        let (x, y, z) = (3.0, 4.0, 5.0);
        let (r, theta, phi) = cartesian_to_spherical(x, y, z);
        let (x2, y2, z2) = spherical_to_cartesian(r, theta, phi);
        assert!(
            approx(x, x2) && approx(y, y2) && approx(z, z2),
            "Spherical roundtrip failed: ({x}, {y}, {z}) -> ({x2}, {y2}, {z2})"
        );
    }

    #[test]
    fn test_cartesian_cylindrical_roundtrip() {
        let (x, y, z) = (-2.0, 7.0, 3.5);
        let (rho, phi, z_cyl) = cartesian_to_cylindrical(x, y, z);
        let (x2, y2, z2) = cylindrical_to_cartesian(rho, phi, z_cyl);
        assert!(
            approx(x, x2) && approx(y, y2) && approx(z, z2),
            "Cylindrical roundtrip failed: ({x}, {y}, {z}) -> ({x2}, {y2}, {z2})"
        );
    }

    #[test]
    fn test_polar_roundtrip() {
        let (x, y) = (3.0, -4.0);
        let (r, theta) = cartesian_to_polar(x, y);
        let (x2, y2) = polar_to_cartesian(r, theta);
        assert!(
            approx(x, x2) && approx(y, y2),
            "Polar roundtrip failed: ({x}, {y}) -> ({x2}, {y2})"
        );
    }

    #[test]
    fn test_spherical_known_values() {
        // Point on the +z axis: polar angle 0.
        let (r, theta, _phi) = cartesian_to_spherical(0.0, 0.0, 5.0);
        assert!(approx(r, 5.0));
        assert!(approx(theta, 0.0));

        // Point on the +x axis: polar angle π/2, azimuth 0.
        let (r, theta, phi) = cartesian_to_spherical(3.0, 0.0, 0.0);
        assert!(approx(r, 3.0));
        assert!(approx(theta, PI / 2.0));
        assert!(approx(phi, 0.0));
    }

    #[test]
    fn test_spherical_south_pole() {
        let (r, theta, _phi) = cartesian_to_spherical(0.0, 0.0, -2.0);
        assert!(approx(r, 2.0));
        assert!(approx(theta, PI), "polar angle of -z should be π, got {theta}");
    }

    #[test]
    fn test_origin_spherical() {
        let (r, theta, phi) = cartesian_to_spherical(0.0, 0.0, 0.0);
        assert!(approx(r, 0.0) && approx(theta, 0.0) && approx(phi, 0.0));
    }

    #[test]
    fn test_cylindrical_known_values() {
        let (rho, phi, z) = cartesian_to_cylindrical(0.0, 2.0, -1.0);
        assert!(approx(rho, 2.0) && approx(phi, PI / 2.0) && approx(z, -1.0));
    }

    #[test]
    fn test_polar_known_values() {
        let (r, theta) = cartesian_to_polar(-1.0, 0.0);
        assert!(approx(r, 1.0) && approx(theta, PI));
        let (r, theta) = cartesian_to_polar(0.0, -2.0);
        assert!(approx(r, 2.0) && approx(theta, -PI / 2.0));
    }

    // ---------------------------------------------------------------------
    // Non-trivial products and helper sanity checks
    // ---------------------------------------------------------------------

    #[test]
    fn test_mul_vec_non_identity() {
        let m = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]);
        let v = Vec3::new(1.0, 2.0, 3.0);
        let result = m.mul_vec(v);
        assert!(approx(result.x, 2.0) && approx(result.y, 6.0) && approx(result.z, 12.0));
    }

    #[test]
    fn test_mul_mat_non_trivial() {
        let a = Mat3::from_rows([1.0, 2.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]);
        let b = Mat3::from_rows([1.0, 0.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]);
        let c = a.mul_mat(&b);
        assert!(approx(c.data[0][0], 7.0), "got {}", c.data[0][0]);
        assert!(approx(c.data[0][1], 2.0), "got {}", c.data[0][1]);
        assert!(approx(c.data[1][0], 3.0), "got {}", c.data[1][0]);
    }

    #[test]
    fn test_mat3_approx_eq_different() {
        let a = Mat3::identity();
        let b = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]);
        assert!(!mat3_approx_eq(&a, &b));
    }
}
560}