simple_bezier_easing/
lib.rs

//! This module provides functions and utilities for evaluating cubic Bézier easing curves
//! using De Casteljau's algorithm. It validates the control points and uses binary
//! subdivision to determine the parameter `t` corresponding to a given x-coordinate on the
//! curve.
//!
//! # Key Features
//! - **Cubic Bézier Curve Calculation**: Computes the y-coordinate for any given x-coordinate
//!   based on the specified control points using De Casteljau's algorithm.
//! - **Error Handling**: Returns descriptive errors if the control points are outside the valid
//!   range `[0, 1]` or if the search for the parameter `t` fails within the allowed iterations.
//! - **Binary Subdivision**: Uses binary subdivision to find the parameter `t` for a given
//!   x-coordinate.
//! - **Fixed Precision**: The search runs to a fixed internal precision with a bounded number of
//!   iterations, and guards against very small slopes and out-of-range control point values.
15
16use std::{
17    error::Error,
18    fmt::{Display, Formatter, Result as FmtResult},
19};
20
/// Target precision of the search for the parameter `t` matching a given x-coordinate.
///
/// The search stops as soon as the curve's x-coordinate at the current `t` is within this
/// distance of the target, or once `t` itself is known to within this distance. Smaller values
/// give more accurate results but need more iterations.
const SUBDIVISION_PRECISION: f32 = 1e-4;
27
/// Smallest magnitude of the curve's x-slope (`dx/dt`) the parameter search treats as usable.
///
/// Slopes closer to zero than this are considered too flat to divide by: doing so could divide
/// by zero or produce enormous, unstable steps.
const MIN_SLOPE: f32 = 1e-3;
34
/// Upper bound on the number of binary-subdivision (bisection) steps in the parameter search.
///
/// Each step halves the interval known to contain `t`, so the bound must leave enough room for
/// that interval to shrink below `SUBDIVISION_PRECISION` (1e-4). Starting from the whole of
/// `[0, 1]` that takes 14 halvings (2^-14 ≈ 6.1e-5), so a cap of 10 is too small. 20 halvings
/// narrow even the full unit interval to under 1e-6 while keeping the search bounded. If the
/// precision is still not met when the cap is reached, the search reports an error.
const SUBDIVISION_MAX_ITERATIONS: u32 = 20;
42
/// Errors produced while building or evaluating a cubic Bézier easing curve.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BezierError {
    /// A control-point coordinate was outside the valid range `[0, 1]`.
    /// Carries all four inner control-point coordinates.
    InvalidControlPoint { x1: f32, y1: f32, x2: f32, y2: f32 },

    /// No curve parameter `t` could be determined for the carried x-coordinate.
    ParameterCalculationError(f32),
}

impl Display for BezierError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            BezierError::InvalidControlPoint { x1, y1, x2, y2 } => write!(
                f,
                "Control points must be in the range [0, 1], but got: ({}, {}), ({}, {})",
                x1, y1, x2, y2
            ),
            BezierError::ParameterCalculationError(x) => {
                write!(f, "Failed to find parameter t for x = {}", x)
            }
        }
    }
}

// No variant wraps an underlying cause, so the trait's default `source()` (which returns
// `None`) is exactly right and no methods need to be overridden.
impl Error for BezierError {}
76
/// A 2D point `(x, y)`: field `.0` is the x-coordinate and field `.1` the y-coordinate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point(pub f32, pub f32);

impl Point {
    /// Linearly interpolates between `a` (at `t = 0`) and `b` (at `t = 1`).
    ///
    /// Both coordinates are interpolated independently as `a + (b - a) * t`; values of `t`
    /// outside `[0, 1]` extrapolate along the line through `a` and `b`.
    fn lerp(a: Point, b: Point, t: f32) -> Point {
        Point(a.0 + (b.0 - a.0) * t, a.1 + (b.1 - a.1) * t)
    }
}

impl Display for Point {
    /// Formats the point as `(x, y)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "({}, {})", self.0, self.1)
    }
}
93
94/// Computes the cubic Bézier curve using De Casteljau's algorithm.
95///
96/// De Casteljau's algorithm is a recursive method for evaluating Bézier curves
97/// at a specific parameter `t`. It works by linearly interpolating between
98/// control points at each level until a single point is obtained.
99///
100/// Parameters:
101/// - `t`: The parameter (0 ≤ t ≤ 1) at which to evaluate the curve.
102/// - `p0`: The first control point (start of the curve).
103/// - `p1`: The second control point (first "pull" control).
104/// - `p2`: The third control point (second "pull" control).
105/// - `p3`: The fourth control point (end of the curve).
106///
107/// Returns:
108/// - The y-coordinate of the Bézier curve at parameter `t`.
109fn de_casteljau(t: f32, p0: Point, p1: Point, p2: Point, p3: Point) -> Point {
110    // First level: Linearly interpolate between the control points
111    // Compute intermediate points `q0`, `q1`, and `q2` at the first level.
112    let q0 = Point::lerp(p0, p1, t);
113    let q1 = Point::lerp(p1, p2, t);
114    let q2 = Point::lerp(p2, p3, t);
115
116    // Second level: Interpolate between the intermediate points from the first level
117    // Compute `r0` and `r1` as the second-level intermediate points.
118    let r0 = Point::lerp(q0, q1, t);
119    let r1 = Point::lerp(q1, q2, t);
120
121    // Final level: Interpolate between the second-level points to get the final result
122    // Compute the final point on the curve corresponding to `t`.
123    Point::lerp(r0, r1, t) // Interpolates between r0 and r1 to get the curve's value
124}
125
126/// Uses binary subdivision to find the parameter `t` for a given x-coordinate on the Bézier curve.
127///
128/// This function attempts to compute the parameter `t` such that when the Bézier curve is evaluated
129/// at `t`, the resulting x-coordinate matches the given target `x`. The method uses binary subdivision
130/// within the range [0, 1] to iteratively narrow down the value of `t` until the x-coordinate of the
131/// Bézier curve at `t` is sufficiently close to the target `x` or the maximum number of iterations is reached.
132///
133/// # Parameters
134/// - `x`: The target x-coordinate for which to find the corresponding parameter `t`.
135/// - `p0`, `p1`, `p2`, `p3`: The control points of the cubic Bézier curve, where `p0` is the start point,
136///   `p3` is the end point, and `p1` and `p2` are the intermediate control points that define the curve's shape.
137///
138/// # Returns
139/// - `Ok(t)`: The parameter `t` corresponding to the given x-coordinate, where `0 <= t <= 1`.
140/// - `Err(BezierError::ParameterCalculationError(x))`: If the x-coordinate cannot be matched within the
141///   specified precision or the maximum number of iterations is reached, an error is returned, indicating
142///   that the calculation for the given x-coordinate failed.
143///
144/// # Errors
145/// - If the binary subdivision fails to find a suitable parameter `t` within the maximum iterations or
146///   the precision requirements, it returns a `ParameterCalculationError` with the target x-coordinate.
147///
148fn get_t_for_x(x: f32, p0: Point, p1: Point, p2: Point, p3: Point) -> Result<f32, BezierError> {
149    let mut t0 = 0.0;
150    let mut t1 = 1.0;
151    let mut t = (t0 + t1) / 2.0; // Start with a midpoint guess rather than x
152    let mut last_t = t;
153
154    for _ in 0..SUBDIVISION_MAX_ITERATIONS {
155        // Evaluate the Bézier curve at `t` to find its x-coordinate.
156        let x_val = de_casteljau(t, p0, p1, p2, p3);
157        let error = x - x_val.0;
158
159        // Adjust the range based on the error.
160        if error.abs() < SUBDIVISION_PRECISION {
161            break;
162        }
163        if error > 0.0 {
164            t0 = t;
165        } else {
166            t1 = t;
167        }
168        t = (t0 + t1) / 2.0;
169
170        if (t - last_t).abs() < SUBDIVISION_PRECISION {
171            break;
172        }
173
174        last_t = t;
175    }
176
177    let final_x_val = de_casteljau(t, p0, p1, p2, p3);
178    if (x - final_x_val.0).abs() < MIN_SLOPE && (t - last_t).abs() < MIN_SLOPE {
179        Ok(t) // Return the result if it's sufficiently accurate
180    } else {
181        Err(BezierError::ParameterCalculationError(x)) // Otherwise, return an error
182    }
183}
184
185/// Creates a cubic Bézier curve function based on the given control points.
186///
187/// This function returns a closure that can be used to compute the y-coordinate of the Bézier curve for
188/// any given x-coordinate within the range [0, 1]. The closure uses De Casteljau's algorithm to compute
189/// the point on the curve corresponding to the x-coordinate. It also supports binary subdivision to find
190/// the parameter `t` corresponding to the given x and then evaluate the curve at that parameter.
191///
192/// # Parameters
193/// - `x1`, `y1`: Coordinates of the first control point (p1), which influences the curve's direction and shape.
194/// - `x2`, `y2`: Coordinates of the second control point (p2), which also influences the curve's direction and shape.
195///
196/// # Returns
197/// - `Ok`: A closure that accepts an x-coordinate and returns the corresponding y-coordinate on the Bézier curve.
198///   The closure will return a `Result<f32, BezierError>`, where the y-coordinate is computed for the given x.
199/// - `Err`: If the control points are outside the valid range [0, 1], an error is returned indicating the invalid
200///   control points. The error type is `BezierError::InvalidControlPoint`.
201///
202/// # Errors
203/// - `BezierError::InvalidControlPoint`: If any of the control points are outside the range [0, 1], an error is returned.
204/// - `BezierError::ParameterCalculationError`: If the binary subdivision fails to calculate a valid parameter `t`
205///   for a given x-coordinate when calling the returned closure.
206///
207/// # Example
208///
209/// ```rust
210/// use simple_bezier_easing::bezier;
211/// let bez = bezier(0.2, 0.4, 0.6, 0.8).unwrap();
212/// let y_at_0_5 = bez(0.5).unwrap();  // Compute the y-coordinate at x = 0.5
213/// let rounded_value = (y_at_0_5 * 10.0).round() / 10.0; // Round to 1 decimal place
214/// assert_eq!(rounded_value, 0.6); // expected y-value
215///
216/// // Error example:
217/// let invalid_bez = bezier(1.2, 0.4, 0.6, 0.8); // This will return an error due to control points out of bounds
218/// ```
219pub fn bezier(
220    x1: f32,
221    y1: f32,
222    x2: f32,
223    y2: f32,
224) -> Result<impl Fn(f32) -> Result<f32, BezierError>, BezierError> {
225    // Ensure control points are within bounds (for x-coordinates of p1 and p2).
226    if !(0.0..=1.0).contains(&x1)
227        || !(0.0..=1.0).contains(&x2)
228        || !(0.0..=1.0).contains(&y1)
229        || !(0.0..=1.0).contains(&y2)
230    {
231        return Err(BezierError::InvalidControlPoint { x1, y1, x2, y2 });
232    }
233
234    Ok(move |x: f32| {
235        // Shortcut for linear curves (control points are on the line y = x).
236        if x1 == y1 && x2 == y2 {
237            return Ok(x); // Return the same x for a linear curve (y = x).
238        }
239
240        if !(0.0 - f32::EPSILON..=1.0 + f32::EPSILON).contains(&x) {
241            return Err(BezierError::ParameterCalculationError(x));
242        }
243
244        if x == 0.0 || x == 1.0 {
245            return Ok(x);
246        }
247
248        let p0 = Point(0.0, 0.0);
249        let p1 = Point(x1, y1);
250        let p2 = Point(x2, y2);
251        let p3 = Point(1.0, 1.0);
252
253        // Find the parameter `t` corresponding to the x-coordinate.
254        let t = get_t_for_x(x, p0, p1, p2, p3)?;
255        // Once `t` is found, evaluate the Bézier curve for the y-coordinate.
256        // Return the y-coordinate from the Point.
257        Ok(de_casteljau(t, p0, p1, p2, p3).1)
258    })
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates the curve `bezier(0.2, 0.4, 0.6, 0.8)` at x = 0.5 and checks that the
    /// resulting y-coordinate, rounded to one decimal place, is 0.6.
    #[test]
    fn test_bezier_curve() {
        let ease = bezier(0.2, 0.4, 0.6, 0.8).unwrap();

        let rounded_y = {
            let y = ease(0.5).unwrap();
            (y * 10.0).round() / 10.0
        };

        assert_eq!(
            rounded_y, 0.6,
            "Expected y value at x = 0.5 to be 0.6, but got {}",
            rounded_y
        );
    }

    /// A first control point with x1 = 1.2 lies outside `[0, 1]`, so `bezier` must refuse to
    /// build the curve.
    #[test]
    fn test_invalid_control_points() {
        let outcome = bezier(1.2, 0.4, 0.6, 0.8);
        assert!(outcome.is_err(), "Expected error for invalid control points");
    }

    /// Checks the curve endpoints x = 0.0 and x = 1.0, where the y-coordinate must equal x.
    /// The control points used here, (0, 0) and (1, 1), lie on the diagonal y = x.
    #[test]
    fn test_edge_cases() {
        let ease = bezier(0.0, 0.0, 1.0, 1.0).unwrap();

        let endpoints = [
            (0.0, "Expected y value at x = 0.0 to be 0.0"),
            (1.0, "Expected y value at x = 1.0 to be 1.0"),
        ];
        for (x, message) in endpoints {
            assert_eq!(ease(x).unwrap(), x, "{}", message);
        }
    }
}