// rust_physics_engine/linalg.rs

1use std::ops::{Add, Mul, Sub};
2
3use crate::math::Vec3;
4
/// Magnitude below which a determinant or radius is treated as zero.
///
/// This is an absolute cutoff (it is not scaled by the size of the inputs).
/// It is consulted by `Mat3::inverse` and `cartesian_to_spherical`.
const SINGULARITY_THRESHOLD: f64 = 1e-12;
6
/// A 3x3 matrix of `f64` values stored row by row.
///
/// `data[i][j]` is the element in row `i`, column `j`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Mat3 {
    /// The three rows of the matrix; each inner array is one row.
    pub data: [[f64; 3]; 3],
}
11
12impl Mat3 {
13    /// Returns the 3x3 zero matrix.
14    #[must_use]
15    pub fn zero() -> Self {
16        Self {
17            data: [[0.0; 3]; 3],
18        }
19    }
20
21    /// Returns the 3x3 identity matrix.
22    #[must_use]
23    pub fn identity() -> Self {
24        Self {
25            data: [
26                [1.0, 0.0, 0.0],
27                [0.0, 1.0, 0.0],
28                [0.0, 0.0, 1.0],
29            ],
30        }
31    }
32
33    /// Constructs a 3x3 matrix from three row arrays.
34    #[must_use]
35    pub fn from_rows(r0: [f64; 3], r1: [f64; 3], r2: [f64; 3]) -> Self {
36        Self { data: [r0, r1, r2] }
37    }
38
39    /// Computes the determinant using the Sarrus rule (cofactor expansion along the first row).
40    #[must_use]
41    pub fn determinant(&self) -> f64 {
42        let d = &self.data;
43        d[0][0] * (d[1][1] * d[2][2] - d[1][2] * d[2][1])
44            - d[0][1] * (d[1][0] * d[2][2] - d[1][2] * d[2][0])
45            + d[0][2] * (d[1][0] * d[2][1] - d[1][1] * d[2][0])
46    }
47
48    /// Returns the transpose of this matrix: A^T[i][j] = A[j][i].
49    #[must_use]
50    pub fn transpose(&self) -> Self {
51        let d = &self.data;
52        Self {
53            data: [
54                [d[0][0], d[1][0], d[2][0]],
55                [d[0][1], d[1][1], d[2][1]],
56                [d[0][2], d[1][2], d[2][2]],
57            ],
58        }
59    }
60
61    /// Computes the matrix inverse via the adjugate method: A⁻¹ = adj(A) / det(A).
62    #[must_use]
63    pub fn inverse(&self) -> Option<Self> {
64        let det = self.determinant();
65        if det.abs() < SINGULARITY_THRESHOLD {
66            return None;
67        }
68        let d = &self.data;
69        let inv_det = 1.0 / det;
70
71        // Cofactor matrix, transposed (adjugate), scaled by 1/det
72        Some(Self {
73            data: [
74                [
75                    (d[1][1] * d[2][2] - d[1][2] * d[2][1]) * inv_det,
76                    (d[0][2] * d[2][1] - d[0][1] * d[2][2]) * inv_det,
77                    (d[0][1] * d[1][2] - d[0][2] * d[1][1]) * inv_det,
78                ],
79                [
80                    (d[1][2] * d[2][0] - d[1][0] * d[2][2]) * inv_det,
81                    (d[0][0] * d[2][2] - d[0][2] * d[2][0]) * inv_det,
82                    (d[0][2] * d[1][0] - d[0][0] * d[1][2]) * inv_det,
83                ],
84                [
85                    (d[1][0] * d[2][1] - d[1][1] * d[2][0]) * inv_det,
86                    (d[0][1] * d[2][0] - d[0][0] * d[2][1]) * inv_det,
87                    (d[0][0] * d[1][1] - d[0][1] * d[1][0]) * inv_det,
88                ],
89            ],
90        })
91    }
92
93    /// Returns the trace (sum of diagonal elements) of the matrix.
94    #[must_use]
95    pub fn trace(&self) -> f64 {
96        self.data[0][0] + self.data[1][1] + self.data[2][2]
97    }
98
99    /// Multiplies this matrix by a column vector: result = A × v.
100    #[must_use]
101    pub fn mul_vec(&self, v: Vec3) -> Vec3 {
102        let d = &self.data;
103        Vec3::new(
104            d[0][0] * v.x + d[0][1] * v.y + d[0][2] * v.z,
105            d[1][0] * v.x + d[1][1] * v.y + d[1][2] * v.z,
106            d[2][0] * v.x + d[2][1] * v.y + d[2][2] * v.z,
107        )
108    }
109
110    /// Multiplies two 3x3 matrices: result = A × B.
111    #[must_use]
112    pub fn mul_mat(&self, other: &Mat3) -> Mat3 {
113        let a = &self.data;
114        let b = &other.data;
115        let mut result = [[0.0; 3]; 3];
116        for i in 0..3 {
117            for j in 0..3 {
118                result[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j] + a[i][2] * b[2][j];
119            }
120        }
121        Mat3 { data: result }
122    }
123
124    /// Returns a uniform scaling matrix: diag(s, s, s).
125    #[must_use]
126    pub fn scale(s: f64) -> Self {
127        Self {
128            data: [
129                [s, 0.0, 0.0],
130                [0.0, s, 0.0],
131                [0.0, 0.0, s],
132            ],
133        }
134    }
135
136    /// Multiplies every element of the matrix by a scalar.
137    #[must_use]
138    pub fn mul_scalar(&self, s: f64) -> Self {
139        let d = &self.data;
140        Self {
141            data: [
142                [d[0][0] * s, d[0][1] * s, d[0][2] * s],
143                [d[1][0] * s, d[1][1] * s, d[1][2] * s],
144                [d[2][0] * s, d[2][1] * s, d[2][2] * s],
145            ],
146        }
147    }
148}
149
150impl Mul<Mat3> for Mat3 {
151    type Output = Mat3;
152    fn mul(self, rhs: Mat3) -> Mat3 {
153        self.mul_mat(&rhs)
154    }
155}
156
157impl Mul<Vec3> for Mat3 {
158    type Output = Vec3;
159    fn mul(self, rhs: Vec3) -> Vec3 {
160        self.mul_vec(rhs)
161    }
162}
163
164impl Add<Mat3> for Mat3 {
165    type Output = Mat3;
166    fn add(self, rhs: Mat3) -> Mat3 {
167        let mut result = [[0.0; 3]; 3];
168        for i in 0..3 {
169            for j in 0..3 {
170                result[i][j] = self.data[i][j] + rhs.data[i][j];
171            }
172        }
173        Mat3 { data: result }
174    }
175}
176
177impl Sub<Mat3> for Mat3 {
178    type Output = Mat3;
179    fn sub(self, rhs: Mat3) -> Mat3 {
180        let mut result = [[0.0; 3]; 3];
181        for i in 0..3 {
182            for j in 0..3 {
183                result[i][j] = self.data[i][j] - rhs.data[i][j];
184            }
185        }
186        Mat3 { data: result }
187    }
188}
189
// --- Rotation matrices ---
191
192/// Rotation matrix about the x-axis by the given angle in radians.
193#[must_use]
194pub fn rotation_x(angle: f64) -> Mat3 {
195    let (s, c) = angle.sin_cos();
196    Mat3::from_rows(
197        [1.0, 0.0, 0.0],
198        [0.0, c, -s],
199        [0.0, s, c],
200    )
201}
202
203/// Rotation matrix about the y-axis by the given angle in radians.
204#[must_use]
205pub fn rotation_y(angle: f64) -> Mat3 {
206    let (s, c) = angle.sin_cos();
207    Mat3::from_rows(
208        [c, 0.0, s],
209        [0.0, 1.0, 0.0],
210        [-s, 0.0, c],
211    )
212}
213
214/// Rotation matrix about the z-axis by the given angle in radians.
215#[must_use]
216pub fn rotation_z(angle: f64) -> Mat3 {
217    let (s, c) = angle.sin_cos();
218    Mat3::from_rows(
219        [c, -s, 0.0],
220        [s, c, 0.0],
221        [0.0, 0.0, 1.0],
222    )
223}
224
225/// Rodrigues' rotation formula: rotate by `angle` radians about `axis`.
226/// The axis is normalized internally.
227#[must_use]
228pub fn rotation_axis_angle(axis: Vec3, angle: f64) -> Mat3 {
229    let n = axis.normalized();
230    let (s, c) = angle.sin_cos();
231    let t = 1.0 - c;
232
233    Mat3::from_rows(
234        [
235            t * n.x * n.x + c,
236            t * n.x * n.y - s * n.z,
237            t * n.x * n.z + s * n.y,
238        ],
239        [
240            t * n.y * n.x + s * n.z,
241            t * n.y * n.y + c,
242            t * n.y * n.z - s * n.x,
243        ],
244        [
245            t * n.z * n.x - s * n.y,
246            t * n.z * n.y + s * n.x,
247            t * n.z * n.z + c,
248        ],
249    )
250}
251
// --- Coordinate transformations ---
253
254/// Returns (r, theta, phi) where theta is the polar angle from +z and phi is the azimuthal angle from +x.
255#[must_use]
256pub fn cartesian_to_spherical(x: f64, y: f64, z: f64) -> (f64, f64, f64) {
257    let r = (x * x + y * y + z * z).sqrt();
258    if r < SINGULARITY_THRESHOLD {
259        return (0.0, 0.0, 0.0);
260    }
261    let theta = (z / r).clamp(-1.0, 1.0).acos();
262    let phi = y.atan2(x);
263    (r, theta, phi)
264}
265
/// Maps spherical coordinates (r, theta, phi) to Cartesian (x, y, z).
///
/// `theta` is used as the polar angle from +z and `phi` as the azimuthal
/// angle from +x, both in radians.
#[must_use]
pub fn spherical_to_cartesian(r: f64, theta: f64, phi: f64) -> (f64, f64, f64) {
    let (polar_sin, polar_cos) = theta.sin_cos();
    let (azimuth_sin, azimuth_cos) = phi.sin_cos();

    // Length of the projection onto the xy-plane.
    let planar = r * polar_sin;

    let x = planar * azimuth_cos;
    let y = planar * azimuth_sin;
    let z = r * polar_cos;
    (x, y, z)
}
277
/// Converts Cartesian coordinates (x, y, z) to cylindrical (rho, phi, z).
///
/// * `rho` is the radial distance from the z-axis (measured in the xy-plane).
/// * `phi` is the azimuthal angle from +x, in `[-π, π]`.
/// * `z` is passed through unchanged.
#[must_use]
pub fn cartesian_to_cylindrical(x: f64, y: f64, z: f64) -> (f64, f64, f64) {
    // `hypot` computes sqrt(x² + y²) without overflowing or underflowing on
    // the intermediate squares.
    let rho = x.hypot(y);
    let phi = y.atan2(x);
    (rho, phi, z)
}
285
/// Maps cylindrical coordinates (rho, phi, z) to Cartesian (x, y, z).
///
/// `phi` is the azimuthal angle from +x in radians; `z` is passed through unchanged.
#[must_use]
pub fn cylindrical_to_cartesian(rho: f64, phi: f64, z: f64) -> (f64, f64, f64) {
    let (sine, cosine) = phi.sin_cos();
    let x = rho * cosine;
    let y = rho * sine;
    (x, y, z)
}
292
/// Maps 2D polar coordinates (r, theta) to Cartesian (x, y).
///
/// `theta` is the angle from +x in radians.
#[must_use]
pub fn polar_to_cartesian(r: f64, theta: f64) -> (f64, f64) {
    let (sine, cosine) = theta.sin_cos();
    let x = r * cosine;
    let y = r * sine;
    (x, y)
}
299
/// Converts 2D Cartesian coordinates (x, y) to polar (r, theta).
///
/// * `r` is the distance from the origin.
/// * `theta` is the angle from +x, in `[-π, π]`.
#[must_use]
pub fn cartesian_to_polar(x: f64, y: f64) -> (f64, f64) {
    // `hypot` computes sqrt(x² + y²) without overflowing or underflowing on
    // the intermediate squares.
    let r = x.hypot(y);
    let theta = y.atan2(x);
    (r, theta)
}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::constants::PI;

    /// Absolute tolerance used by every floating-point comparison below.
    const APPROX_EPSILON: f64 = 1e-9;

    /// True when `a` and `b` differ by less than `APPROX_EPSILON`.
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < APPROX_EPSILON
    }

    /// Element-wise `approx` comparison of two matrices.
    fn mat3_approx_eq(a: &Mat3, b: &Mat3) -> bool {
        a.data
            .iter()
            .flatten()
            .zip(b.data.iter().flatten())
            .all(|(&lhs, &rhs)| approx(lhs, rhs))
    }

    #[test]
    fn test_identity_determinant() {
        assert!(approx(Mat3::identity().determinant(), 1.0));
    }

    #[test]
    fn test_zero_determinant() {
        assert!(approx(Mat3::zero().determinant(), 0.0));
    }

    #[test]
    fn test_determinant_known_values() {
        // 1*(0 - 24) - 2*(0 - 20) + 3*(0 - 5) = 1
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]);
        assert!(approx(m.determinant(), 1.0), "got {}", m.determinant());

        let diag = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]);
        assert!(approx(diag.determinant(), 30.0), "got {}", diag.determinant());
    }

    #[test]
    fn test_transpose_identity() {
        assert_eq!(Mat3::identity().transpose(), Mat3::identity());
    }

    #[test]
    fn test_transpose_non_symmetric() {
        // The identity is symmetric, so a non-symmetric matrix is needed to
        // actually observe rows and columns being swapped.
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        let expected = Mat3::from_rows([1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]);
        assert_eq!(m.transpose(), expected);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn test_trace() {
        let m = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]);
        assert!(approx(m.trace(), 10.0));
    }

    #[test]
    fn test_inverse_times_original_is_identity() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]);
        let inv = m.inverse().expect("matrix should be invertible");
        let product = m * inv;
        assert!(
            mat3_approx_eq(&product, &Mat3::identity()),
            "M * M^-1 should equal I, got {:?}",
            product
        );
    }

    #[test]
    fn test_inverse_known_values() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]);
        let expected =
            Mat3::from_rows([-24.0, 18.0, 5.0], [20.0, -15.0, -4.0], [-5.0, 4.0, 1.0]);
        let inv = m.inverse().expect("matrix should be invertible");
        assert!(mat3_approx_eq(&inv, &expected), "got {:?}", inv);
    }

    #[test]
    fn test_singular_matrix_has_no_inverse() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        assert!(m.inverse().is_none());
    }

    #[test]
    fn test_mul_vec() {
        let m = Mat3::identity();
        let v = Vec3::new(1.0, 2.0, 3.0);
        let result = m * v;
        assert!(approx(result.x, 1.0) && approx(result.y, 2.0) && approx(result.z, 3.0));
    }

    #[test]
    fn test_mul_mat_identity() {
        let m = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        let result = m * Mat3::identity();
        assert!(mat3_approx_eq(&result, &m));
    }

    #[test]
    fn test_add_sub() {
        let a = Mat3::identity();
        let b = Mat3::identity();
        let sum = a + b;
        assert!(approx(sum.data[0][0], 2.0));
        let diff = sum - a;
        assert!(mat3_approx_eq(&diff, &Mat3::identity()));
    }

    #[test]
    fn test_scale() {
        let s = Mat3::scale(3.0);
        let v = Vec3::new(1.0, 2.0, 3.0);
        let result = s * v;
        assert!(approx(result.x, 3.0) && approx(result.y, 6.0) && approx(result.z, 9.0));
    }

    #[test]
    fn test_mul_scalar() {
        let m = Mat3::identity();
        let scaled = m.mul_scalar(5.0);
        assert!(approx(scaled.data[0][0], 5.0));
        assert!(approx(scaled.data[0][1], 0.0));
    }

    // --- Rotation tests ---

    #[test]
    fn test_rotation_z_90_maps_x_to_y() {
        let r = rotation_z(PI / 2.0);
        let x_hat = Vec3::new(1.0, 0.0, 0.0);
        let result = r * x_hat;
        assert!(
            approx(result.x, 0.0) && approx(result.y, 1.0) && approx(result.z, 0.0),
            "90-deg rotation about z should map x-hat to y-hat, got {:?}",
            result
        );
    }

    #[test]
    fn test_rotation_x_90_maps_y_to_z() {
        let r = rotation_x(PI / 2.0);
        let y_hat = Vec3::new(0.0, 1.0, 0.0);
        let result = r * y_hat;
        assert!(
            approx(result.x, 0.0) && approx(result.y, 0.0) && approx(result.z, 1.0),
            "90-deg rotation about x should map y-hat to z-hat, got {:?}",
            result
        );
    }

    #[test]
    fn test_rotation_y_90_maps_z_to_x() {
        let r = rotation_y(PI / 2.0);
        let z_hat = Vec3::new(0.0, 0.0, 1.0);
        let result = r * z_hat;
        assert!(
            approx(result.x, 1.0) && approx(result.y, 0.0) && approx(result.z, 0.0),
            "90-deg rotation about y should map z-hat to x-hat, got {:?}",
            result
        );
    }

    #[test]
    fn test_rotation_matrix_is_orthogonal() {
        let r = rotation_axis_angle(Vec3::new(1.0, 1.0, 1.0), 1.23);
        let rt_r = r.transpose() * r;
        assert!(
            mat3_approx_eq(&rt_r, &Mat3::identity()),
            "R^T * R should equal I for rotation matrices, got {:?}",
            rt_r
        );
    }

    #[test]
    fn test_rotation_matrix_determinant_is_one() {
        let r = rotation_axis_angle(Vec3::new(0.0, 1.0, 0.0), 0.75);
        let det = r.determinant();
        assert!(
            approx(det, 1.0),
            "Rotation matrix determinant should be 1, got {det}",
        );
    }

    #[test]
    fn test_axis_angle_matches_rotation_z() {
        let angle = 1.2;
        let rz = rotation_z(angle);
        let raa = rotation_axis_angle(Vec3::new(0.0, 0.0, 1.0), angle);
        assert!(
            mat3_approx_eq(&rz, &raa),
            "Axis-angle about z should match rotation_z"
        );
    }

    #[test]
    fn test_rotation_inverse_is_transpose() {
        let r = rotation_axis_angle(Vec3::new(1.0, -2.0, 0.5), 0.9);
        let inv = r.inverse().expect("rotation matrices are invertible");
        assert!(
            mat3_approx_eq(&inv, &r.transpose()),
            "R^-1 should equal R^T, got {:?}",
            inv
        );
    }

    // --- Coordinate transformation roundtrip tests ---

    #[test]
    fn test_cartesian_spherical_roundtrip() {
        let (x, y, z) = (3.0, 4.0, 5.0);
        let (r, theta, phi) = cartesian_to_spherical(x, y, z);
        let (x2, y2, z2) = spherical_to_cartesian(r, theta, phi);
        assert!(
            approx(x, x2) && approx(y, y2) && approx(z, z2),
            "Spherical roundtrip failed: ({x}, {y}, {z}) -> ({x2}, {y2}, {z2})"
        );
    }

    #[test]
    fn test_cartesian_cylindrical_roundtrip() {
        let (x, y, z) = (-2.0, 7.0, 3.5);
        let (rho, phi, z_cyl) = cartesian_to_cylindrical(x, y, z);
        let (x2, y2, z2) = cylindrical_to_cartesian(rho, phi, z_cyl);
        assert!(
            approx(x, x2) && approx(y, y2) && approx(z, z2),
            "Cylindrical roundtrip failed: ({x}, {y}, {z}) -> ({x2}, {y2}, {z2})"
        );
    }

    #[test]
    fn test_polar_roundtrip() {
        let (x, y) = (3.0, -4.0);
        let (r, theta) = cartesian_to_polar(x, y);
        let (x2, y2) = polar_to_cartesian(r, theta);
        assert!(
            approx(x, x2) && approx(y, y2),
            "Polar roundtrip failed: ({x}, {y}) -> ({x2}, {y2})"
        );
    }

    #[test]
    fn test_spherical_known_values() {
        // Point on +z axis
        let (r, theta, _phi) = cartesian_to_spherical(0.0, 0.0, 5.0);
        assert!(approx(r, 5.0));
        assert!(approx(theta, 0.0));

        // Point on +x axis
        let (r, theta, phi) = cartesian_to_spherical(3.0, 0.0, 0.0);
        assert!(approx(r, 3.0));
        assert!(approx(theta, PI / 2.0));
        assert!(approx(phi, 0.0));

        // Point on -z axis
        let (r, theta, _phi) = cartesian_to_spherical(0.0, 0.0, -2.0);
        assert!(approx(r, 2.0));
        assert!(approx(theta, PI));
    }

    #[test]
    fn test_cylindrical_known_values() {
        // Point on +y axis, lifted along z
        let (rho, phi, z) = cartesian_to_cylindrical(0.0, 2.0, 1.0);
        assert!(approx(rho, 2.0));
        assert!(approx(phi, PI / 2.0));
        assert!(approx(z, 1.0));
    }

    #[test]
    fn test_origin_spherical() {
        let (r, theta, phi) = cartesian_to_spherical(0.0, 0.0, 0.0);
        assert!(approx(r, 0.0) && approx(theta, 0.0) && approx(phi, 0.0));
    }

    #[test]
    fn test_mul_vec_non_identity() {
        let m = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]);
        let v = Vec3::new(1.0, 2.0, 3.0);
        let result = m.mul_vec(v);
        assert!(approx(result.x, 2.0) && approx(result.y, 6.0) && approx(result.z, 12.0));
    }

    #[test]
    fn test_mul_mat_non_trivial() {
        let a = Mat3::from_rows([1.0, 2.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]);
        let b = Mat3::from_rows([1.0, 0.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]);
        let c = a.mul_mat(&b);
        // c[0][0] = 1*1+2*3+0*0 = 7, c[0][1] = 1*0+2*1+0*0 = 2, c[1][0] = 0*1+1*3+0*0 = 3
        assert!(approx(c.data[0][0], 7.0), "got {}", c.data[0][0]);
        assert!(approx(c.data[0][1], 2.0), "got {}", c.data[0][1]);
        assert!(approx(c.data[1][0], 3.0), "got {}", c.data[1][0]);
    }

    #[test]
    fn test_mat3_approx_eq_different() {
        let a = Mat3::identity();
        let b = Mat3::from_rows([2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]);
        assert!(!mat3_approx_eq(&a, &b));
    }
}