// traffic_sim/math/cubic.rs

//! Mathematical functions: cubic polynomials and fitting them between two points.

3use serde::{Deserialize, Serialize};
4
/// A cubic function, which is a polynomial of the form ax^3 + bx^2 + cx + d.
///
/// The polynomial is evaluated at the shifted input `x + offset` rather than at `x`
/// directly, i.e. the function computed is
/// `a(x + offset)^3 + b(x + offset)^2 + c(x + offset) + d`.
/// Keeping the shift separate from the coefficients lets [`CubicFn::translate_x`]
/// move the curve along the x-axis without recomputing the coefficients.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub struct CubicFn {
    /// Coefficients `[a, b, c, d]`, ordered from the cubic term down to the constant term.
    coeffs: [f64; 4],
    /// Amount added to the input `x` before the polynomial is evaluated.
    offset: f64,
}

12impl CubicFn {
13    /// Creates a cubic function which outputs the fixed value `y` for any input.
14    pub const fn constant(y: f64) -> Self {
15        Self {
16            coeffs: [0.0, 0.0, 0.0, y],
17            offset: 0.0,
18        }
19    }
20
21    /// Creates a cubic function which passes through the two specified points,
22    /// and which has the specified derivates at those two points.
23    pub fn fit(x1: f64, y1: f64, dydx1: f64, x2: f64, y2: f64, dydx2: f64) -> Self {
24        let w = x2 - x1;
25        let a = 2. * y1 - 2. * y2 + w * dydx1 + w * dydx2;
26        let b = -3. * y1 + 3. * y2 - 2. * w * dydx1 - w * dydx2;
27        let c = w * dydx1;
28        let d = y1;
29        Self {
30            coeffs: [a * w.powi(-3), b * w.powi(-2), c * w.powi(-1), d],
31            offset: -x1,
32        }
33    }
34
35    /// Returns a copy of this `CubicFn` translated to the left by `amount` units.
36    pub fn translate_x(&self, amount: f64) -> CubicFn {
37        Self {
38            coeffs: self.coeffs,
39            offset: self.offset + amount,
40        }
41    }
42
43    /// Evaluates the function at `x`.
44    pub fn y(&self, x: f64) -> f64 {
45        self.y_and_dy(x).0
46    }
47
48    /// Evaluates the derivative of the function at `x`.
49    pub fn dy(&self, x: f64) -> f64 {
50        self.y_and_dy(x).1
51    }
52
53    /// Evaluates the value of the function, and the derivative of the function, at `x`.
54    pub fn y_and_dy(&self, x: f64) -> (f64, f64) {
55        let c = &self.coeffs;
56        let x = x + self.offset;
57
58        let y = c[0] * x * x * x + c[1] * x * x + c[2] * x + c[3];
59        let dy = c[0] * 3. * x * x + c[1] * 2. * x + c[2];
60
61        (y, dy)
62    }
63}
64
#[cfg(test)]
mod test {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use rand::{Rng, SeedableRng};

    /// Absolute tolerance used by every comparison in this module.
    const TOLERANCE: f64 = 0.01;

    /// Builds the deterministic generator shared by the randomised tests, so every
    /// run checks the same sequence of inputs.
    fn seeded_rng() -> rand::rngs::StdRng {
        rand::rngs::StdRng::from_seed(*b"Vegemite sandwhich is not fun...")
    }

    /// Draws a coordinate uniformly from `[-100, 100)`.
    fn coordinate(rng: &mut rand::rngs::StdRng) -> f64 {
        rng.gen_range(-100.0..100.0)
    }

    /// Draws a slope uniformly from `[-10, 10)`.
    fn slope(rng: &mut rand::rngs::StdRng) -> f64 {
        rng.gen_range(-10.0..10.0)
    }

    #[test]
    fn fit_horizontal() {
        let cubic = CubicFn::fit(10., 20., 0.0, 45.0, 5.0, 0.0);
        let (y_start, dy_start) = cubic.y_and_dy(10.);
        let (y_end, dy_end) = cubic.y_and_dy(45.);

        assert_approx_eq!(y_start, 20., TOLERANCE);
        assert_approx_eq!(dy_start, 0., TOLERANCE);
        assert_approx_eq!(y_end, 5., TOLERANCE);
        assert_approx_eq!(dy_end, 0., TOLERANCE);
        // With flat ends the curve is symmetric, so the midpoint is the average height.
        assert_approx_eq!(cubic.y(27.5), 12.5, TOLERANCE);
    }

    #[test]
    fn fit() {
        let mut rng = seeded_rng();
        for _ in 0..100 {
            // The draw order (x1, x2, y1, y2, dydx1, dydx2) fixes which inputs are tested.
            let x1 = coordinate(&mut rng);
            let x2 = coordinate(&mut rng);
            let y1 = coordinate(&mut rng);
            let y2 = coordinate(&mut rng);
            let dydx1 = slope(&mut rng);
            let dydx2 = slope(&mut rng);
            let cubic = CubicFn::fit(x1, y1, dydx1, x2, y2, dydx2);

            let (y_start, dy_start) = cubic.y_and_dy(x1);
            assert_approx_eq!(y_start, y1, TOLERANCE);
            assert_approx_eq!(dy_start, dydx1, TOLERANCE);

            let (y_end, dy_end) = cubic.y_and_dy(x2);
            assert_approx_eq!(y_end, y2, TOLERANCE);
            assert_approx_eq!(dy_end, dydx2, TOLERANCE);
        }
    }

    #[test]
    fn straight_lines() {
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let x1 = coordinate(&mut rng);
            let x2 = coordinate(&mut rng);
            let y1 = coordinate(&mut rng);
            let y2 = coordinate(&mut rng);
            // Using the chord's slope at both ends should reproduce the straight line.
            let dydx = (y2 - y1) / (x2 - x1);
            let cubic = CubicFn::fit(x1, y1, dydx, x2, y2, dydx);

            let points = [(x1, y1), (0.5 * (x1 + x2), 0.5 * (y1 + y2)), (x2, y2)];
            for &(x, expected) in &points {
                assert_approx_eq!(cubic.y(x), expected, TOLERANCE);
            }
            for &(x, _) in &points {
                assert_approx_eq!(cubic.dy(x), dydx, TOLERANCE);
            }
        }
    }
}