simple_bezier_easing/lib.rs
1//! This module provides functions and utilities for calculating cubic Bézier curves
2//! using De Casteljau's algorithm. It includes robust error handling for invalid
3//! control points and supports binary subdivision to efficiently determine the parameter
4//! `t` corresponding to a given x-coordinate on the curve.
5//!
6//! # Key Features
7//! - **Cubic Bézier Curve Calculation**: Computes the y-coordinate for any given x-coordinate
8//! based on the specified control points using De Casteljau's algorithm.
9//! - **Error Handling**: Returns detailed errors if the control points are outside the valid range
10//! [0, 1] or if the binary subdivision fails to converge within the allowed iterations.
11//! - **Binary Subdivision**: Utilizes binary subdivision to efficiently find the parameter `t` for a
12//! given x-coordinate, ensuring accurate results even for complex curves.
//! - **Precision Control**: The tolerance and the iteration limit of the binary subdivision are fixed
//! by module-level constants; an x-coordinate that cannot be resolved within them yields an error.
15
16use std::{
17 error::Error,
18 fmt::{Display, Formatter, Result as FmtResult},
19};
20
/// Convergence tolerance for the binary subdivision in the Bézier curve evaluation.
///
/// The search stops as soon as either the x-error at the current guess for `t`, or the distance
/// between two successive guesses for `t`, drops below this value. A smaller value means more
/// iterations and higher precision, at the cost of performance.
const SUBDIVISION_PRECISION: f32 = 0.0001;

/// Acceptance tolerance applied to the outcome of the binary subdivision.
///
/// Despite its name, this constant is not compared against a slope. After the search loop has
/// finished, the candidate `t` is accepted only if the remaining x-error (and the size of the last
/// step in `t`) is below this value; otherwise a `ParameterCalculationError` is reported.
const MIN_SLOPE: f32 = 0.001;

/// The maximum number of iterations allowed for binary subdivision.
///
/// The search starts at `t = 0.5` and halves its step on every iteration, so after `n` iterations
/// two successive guesses are `0.5 / 2^n` apart. The cap must be large enough for that step to fall
/// below `SUBDIVISION_PRECISION` (13 iterations for `0.0001`); the value below leaves some headroom.
///
/// A cap of 10 is too small: the search is then cut off roughly `0.0005` away from the true `t`,
/// and on steep curve segments (where dx/dt approaches 3) the remaining x-error exceeds
/// `MIN_SLOPE`, so perfectly valid inputs are rejected with a `ParameterCalculationError`.
const SUBDIVISION_MAX_ITERATIONS: u32 = 16;
42
/// Errors produced while building or evaluating a cubic Bézier easing curve.
///
/// The type is `Clone` and `PartialEq` so that callers and tests can store errors and compare
/// them against an expected variant.
#[derive(Debug, Clone, PartialEq)]
pub enum BezierError {
    /// At least one control-point coordinate lies outside the valid range [0, 1].
    /// Carries all four coordinates exactly as they were supplied.
    InvalidControlPoint { x1: f32, y1: f32, x2: f32, y2: f32 },

    /// No parameter `t` could be determined for the contained x-coordinate.
    ParameterCalculationError(f32),
}

impl Display for BezierError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            Self::InvalidControlPoint { x1, y1, x2, y2 } => write!(
                f,
                "Control points must be in the range [0, 1], but got: ({}, {}), ({}, {})",
                x1, y1, x2, y2
            ),
            Self::ParameterCalculationError(x) => {
                write!(f, "Failed to find parameter t for x = {}", x)
            }
        }
    }
}

// These errors never wrap another error, so the default `source` (which returns `None`) is kept.
impl Error for BezierError {}
76
/// A simple 2D point represented as `(x, y)`.
///
/// Points are small `Copy` values and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point(pub f32, pub f32);

impl Point {
    /// Linearly interpolates between two points, component by component.
    ///
    /// Returns `a` for `t = 0` and `b` for `t = 1`. `t` is not clamped, so values outside
    /// [0, 1] extrapolate along the line through `a` and `b`.
    fn lerp(a: Point, b: Point, t: f32) -> Point {
        Point(a.0 + (b.0 - a.0) * t, a.1 + (b.1 - a.1) * t)
    }
}

impl Display for Point {
    /// Formats the point as `(x, y)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "({}, {})", self.0, self.1)
    }
}
93
94/// Computes the cubic Bézier curve using De Casteljau's algorithm.
95///
96/// De Casteljau's algorithm is a recursive method for evaluating Bézier curves
97/// at a specific parameter `t`. It works by linearly interpolating between
98/// control points at each level until a single point is obtained.
99///
100/// Parameters:
101/// - `t`: The parameter (0 ≤ t ≤ 1) at which to evaluate the curve.
102/// - `p0`: The first control point (start of the curve).
103/// - `p1`: The second control point (first "pull" control).
104/// - `p2`: The third control point (second "pull" control).
105/// - `p3`: The fourth control point (end of the curve).
106///
107/// Returns:
108/// - The y-coordinate of the Bézier curve at parameter `t`.
109fn de_casteljau(t: f32, p0: Point, p1: Point, p2: Point, p3: Point) -> Point {
110 // First level: Linearly interpolate between the control points
111 // Compute intermediate points `q0`, `q1`, and `q2` at the first level.
112 let q0 = Point::lerp(p0, p1, t);
113 let q1 = Point::lerp(p1, p2, t);
114 let q2 = Point::lerp(p2, p3, t);
115
116 // Second level: Interpolate between the intermediate points from the first level
117 // Compute `r0` and `r1` as the second-level intermediate points.
118 let r0 = Point::lerp(q0, q1, t);
119 let r1 = Point::lerp(q1, q2, t);
120
121 // Final level: Interpolate between the second-level points to get the final result
122 // Compute the final point on the curve corresponding to `t`.
123 Point::lerp(r0, r1, t) // Interpolates between r0 and r1 to get the curve's value
124}
125
126/// Uses binary subdivision to find the parameter `t` for a given x-coordinate on the Bézier curve.
127///
128/// This function attempts to compute the parameter `t` such that when the Bézier curve is evaluated
129/// at `t`, the resulting x-coordinate matches the given target `x`. The method uses binary subdivision
130/// within the range [0, 1] to iteratively narrow down the value of `t` until the x-coordinate of the
131/// Bézier curve at `t` is sufficiently close to the target `x` or the maximum number of iterations is reached.
132///
133/// # Parameters
134/// - `x`: The target x-coordinate for which to find the corresponding parameter `t`.
135/// - `p0`, `p1`, `p2`, `p3`: The control points of the cubic Bézier curve, where `p0` is the start point,
136/// `p3` is the end point, and `p1` and `p2` are the intermediate control points that define the curve's shape.
137///
138/// # Returns
139/// - `Ok(t)`: The parameter `t` corresponding to the given x-coordinate, where `0 <= t <= 1`.
140/// - `Err(BezierError::ParameterCalculationError(x))`: If the x-coordinate cannot be matched within the
141/// specified precision or the maximum number of iterations is reached, an error is returned, indicating
142/// that the calculation for the given x-coordinate failed.
143///
144/// # Errors
145/// - If the binary subdivision fails to find a suitable parameter `t` within the maximum iterations or
146/// the precision requirements, it returns a `ParameterCalculationError` with the target x-coordinate.
147///
148fn get_t_for_x(x: f32, p0: Point, p1: Point, p2: Point, p3: Point) -> Result<f32, BezierError> {
149 let mut t0 = 0.0;
150 let mut t1 = 1.0;
151 let mut t = (t0 + t1) / 2.0; // Start with a midpoint guess rather than x
152 let mut last_t = t;
153
154 for _ in 0..SUBDIVISION_MAX_ITERATIONS {
155 // Evaluate the Bézier curve at `t` to find its x-coordinate.
156 let x_val = de_casteljau(t, p0, p1, p2, p3);
157 let error = x - x_val.0;
158
159 // Adjust the range based on the error.
160 if error.abs() < SUBDIVISION_PRECISION {
161 break;
162 }
163 if error > 0.0 {
164 t0 = t;
165 } else {
166 t1 = t;
167 }
168 t = (t0 + t1) / 2.0;
169
170 if (t - last_t).abs() < SUBDIVISION_PRECISION {
171 break;
172 }
173
174 last_t = t;
175 }
176
177 let final_x_val = de_casteljau(t, p0, p1, p2, p3);
178 if (x - final_x_val.0).abs() < MIN_SLOPE && (t - last_t).abs() < MIN_SLOPE {
179 Ok(t) // Return the result if it's sufficiently accurate
180 } else {
181 Err(BezierError::ParameterCalculationError(x)) // Otherwise, return an error
182 }
183}
184
185/// Creates a cubic Bézier curve function based on the given control points.
186///
187/// This function returns a closure that can be used to compute the y-coordinate of the Bézier curve for
188/// any given x-coordinate within the range [0, 1]. The closure uses De Casteljau's algorithm to compute
189/// the point on the curve corresponding to the x-coordinate. It also supports binary subdivision to find
190/// the parameter `t` corresponding to the given x and then evaluate the curve at that parameter.
191///
192/// # Parameters
193/// - `x1`, `y1`: Coordinates of the first control point (p1), which influences the curve's direction and shape.
194/// - `x2`, `y2`: Coordinates of the second control point (p2), which also influences the curve's direction and shape.
195///
196/// # Returns
197/// - `Ok`: A closure that accepts an x-coordinate and returns the corresponding y-coordinate on the Bézier curve.
198/// The closure will return a `Result<f32, BezierError>`, where the y-coordinate is computed for the given x.
199/// - `Err`: If the control points are outside the valid range [0, 1], an error is returned indicating the invalid
200/// control points. The error type is `BezierError::InvalidControlPoint`.
201///
202/// # Errors
203/// - `BezierError::InvalidControlPoint`: If any of the control points are outside the range [0, 1], an error is returned.
204/// - `BezierError::ParameterCalculationError`: If the binary subdivision fails to calculate a valid parameter `t`
205/// for a given x-coordinate when calling the returned closure.
206///
207/// # Example
208///
209/// ```rust
210/// use simple_bezier_easing::bezier;
211/// let bez = bezier(0.2, 0.4, 0.6, 0.8).unwrap();
212/// let y_at_0_5 = bez(0.5).unwrap(); // Compute the y-coordinate at x = 0.5
213/// let rounded_value = (y_at_0_5 * 10.0).round() / 10.0; // Round to 1 decimal place
214/// assert_eq!(rounded_value, 0.6); // expected y-value
215///
216/// // Error example:
217/// let invalid_bez = bezier(1.2, 0.4, 0.6, 0.8); // This will return an error due to control points out of bounds
218/// ```
219pub fn bezier(
220 x1: f32,
221 y1: f32,
222 x2: f32,
223 y2: f32,
224) -> Result<impl Fn(f32) -> Result<f32, BezierError>, BezierError> {
225 // Ensure control points are within bounds (for x-coordinates of p1 and p2).
226 if !(0.0..=1.0).contains(&x1)
227 || !(0.0..=1.0).contains(&x2)
228 || !(0.0..=1.0).contains(&y1)
229 || !(0.0..=1.0).contains(&y2)
230 {
231 return Err(BezierError::InvalidControlPoint { x1, y1, x2, y2 });
232 }
233
234 Ok(move |x: f32| {
235 // Shortcut for linear curves (control points are on the line y = x).
236 if x1 == y1 && x2 == y2 {
237 return Ok(x); // Return the same x for a linear curve (y = x).
238 }
239
240 if !(0.0 - f32::EPSILON..=1.0 + f32::EPSILON).contains(&x) {
241 return Err(BezierError::ParameterCalculationError(x));
242 }
243
244 if x == 0.0 || x == 1.0 {
245 return Ok(x);
246 }
247
248 let p0 = Point(0.0, 0.0);
249 let p1 = Point(x1, y1);
250 let p2 = Point(x2, y2);
251 let p3 = Point(1.0, 1.0);
252
253 // Find the parameter `t` corresponding to the x-coordinate.
254 let t = get_t_for_x(x, p0, p1, p2, p3)?;
255 // Once `t` is found, evaluate the Bézier curve for the y-coordinate.
256 // Return the y-coordinate from the Point.
257 Ok(de_casteljau(t, p0, p1, p2, p3).1)
258 })
259}
260
#[cfg(test)]
mod tests {
    use super::*; // Pull the items under test in from the parent module.

    /// Checks that an interior point of a non-linear curve evaluates as expected.
    ///
    /// The curve with control points (0.2, 0.4) and (0.6, 0.8) is evaluated at x = 0.5 and the
    /// result, rounded to one decimal place, is compared against 0.6.
    #[test]
    fn test_bezier_curve() {
        let ease = bezier(0.2, 0.4, 0.6, 0.8).unwrap();

        // Evaluate at x = 0.5 and round to one decimal place before comparing.
        let rounded_y = ease(0.5).map(|y| (y * 10.0).round() / 10.0).unwrap();

        assert_eq!(
            rounded_y, 0.6,
            "Expected y value at x = 0.5 to be 0.6, but got {}",
            rounded_y
        );
    }

    /// Checks that control points outside the valid range [0, 1] are rejected.
    ///
    /// Here the first control point has x1 = 1.2, so `bezier` must return an error instead of a
    /// closure.
    #[test]
    fn test_invalid_control_points() {
        assert!(
            bezier(1.2, 0.4, 0.6, 0.8).is_err(),
            "Expected error for invalid control points"
        );
    }

    /// Checks the endpoints x = 0.0 and x = 1.0 on a linear curve.
    ///
    /// With control points (0.0, 0.0) and (1.0, 1.0) the curve is the line y = x, so both
    /// endpoints must map to themselves.
    #[test]
    fn test_edge_cases() {
        let linear = bezier(0.0, 0.0, 1.0, 1.0).unwrap();

        // (input x, message shown if the output differs from the input)
        let cases = [
            (0.0, "Expected y value at x = 0.0 to be 0.0"),
            (1.0, "Expected y value at x = 1.0 to be 1.0"),
        ];

        for (x, message) in cases.iter() {
            assert_eq!(linear(*x).unwrap(), *x, "{}", message);
        }
    }
}