Skip to main content

use_calculus/
derivative.rs

1use crate::CalculusError;
2
/// Configuration for finite-difference differentiation.
///
/// Wraps the sample step that the derivative methods on this type hand to
/// the central-difference routines.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Differentiator {
    /// Sample step used on either side of the evaluation point.
    step: f64,
}
8
9impl Differentiator {
10    /// Creates a differentiator with the given sample step.
11    #[must_use]
12    pub const fn new(step: f64) -> Self {
13        Self { step }
14    }
15
16    /// Creates a differentiator from a finite positive step.
17    ///
18    /// # Errors
19    ///
20    /// Returns [`CalculusError::NonFiniteStep`] when `step` is `NaN` or
21    /// infinite, and [`CalculusError::NonPositiveStep`] when `step <= 0.0`.
22    ///
23    /// # Examples
24    ///
25    /// ```
26    /// use use_calculus::{CalculusError, Differentiator};
27    ///
28    /// let differentiator = Differentiator::try_new(1.0e-4)?;
29    /// assert_eq!(differentiator, Differentiator::new(1.0e-4));
30    ///
31    /// assert!(matches!(
32    ///     Differentiator::try_new(0.0),
33    ///     Err(CalculusError::NonPositiveStep(0.0))
34    /// ));
35    /// # Ok::<(), CalculusError>(())
36    /// ```
37    pub fn try_new(step: f64) -> Result<Self, CalculusError> {
38        CalculusError::validate_step(step)?;
39        Ok(Self::new(step))
40    }
41
42    /// Validates that the stored step is finite and positive.
43    ///
44    /// # Errors
45    ///
46    /// Returns the same error variants as [`Self::try_new`].
47    pub fn validate(self) -> Result<Self, CalculusError> {
48        Self::try_new(self.step)
49    }
50
51    /// Returns the sample step.
52    #[must_use]
53    pub const fn step(&self) -> f64 {
54        self.step
55    }
56
57    /// Approximates the first derivative using the central-difference formula.
58    ///
59    /// # Errors
60    ///
61    /// Returns [`CalculusError`] when the stored step is invalid, `at` is not
62    /// finite, or sampled evaluations are not finite.
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use use_calculus::Differentiator;
68    ///
69    /// let differentiator = Differentiator::try_new(1.0e-5)?;
70    /// let slope = differentiator.derivative_at(|x| x.powi(2), 3.0)?;
71    ///
72    /// assert!((slope - 6.0).abs() < 1.0e-6);
73    /// # Ok::<(), use_calculus::CalculusError>(())
74    /// ```
75    pub fn derivative_at<F>(self, function: F, at: f64) -> Result<f64, CalculusError>
76    where
77        F: FnMut(f64) -> f64,
78    {
79        central_difference(function, at, self.step)
80    }
81
82    /// Approximates the second derivative using the central-difference formula.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`CalculusError`] when the stored step is invalid, `at` is not
87    /// finite, or sampled evaluations are not finite.
88    pub fn second_derivative_at<F>(self, function: F, at: f64) -> Result<f64, CalculusError>
89    where
90        F: FnMut(f64) -> f64,
91    {
92        second_central_difference(function, at, self.step)
93    }
94}
95
96/// Approximates the first derivative with a central difference.
97///
98/// # Errors
99///
100/// Returns [`CalculusError`] when `step` is invalid, `at` is not finite, or
101/// sampled evaluations are not finite.
102#[must_use = "derivative estimates should be used or handled"]
103pub fn central_difference<F>(mut function: F, at: f64, step: f64) -> Result<f64, CalculusError>
104where
105    F: FnMut(f64) -> f64,
106{
107    let at = CalculusError::validate_point("at", at)?;
108    let step = CalculusError::validate_step(step)?;
109    let left = evaluate(&mut function, at - step)?;
110    let right = evaluate(&mut function, at + step)?;
111
112    Ok((right - left) / (2.0 * step))
113}
114
115/// Approximates the second derivative with a central difference.
116///
117/// # Errors
118///
119/// Returns [`CalculusError`] when `step` is invalid, `at` is not finite, or
120/// sampled evaluations are not finite.
121#[must_use = "second-derivative estimates should be used or handled"]
122pub fn second_central_difference<F>(
123    mut function: F,
124    at: f64,
125    step: f64,
126) -> Result<f64, CalculusError>
127where
128    F: FnMut(f64) -> f64,
129{
130    let at = CalculusError::validate_point("at", at)?;
131    let step = CalculusError::validate_step(step)?;
132    let left = evaluate(&mut function, at - step)?;
133    let center = evaluate(&mut function, at)?;
134    let right = evaluate(&mut function, at + step)?;
135    let step_squared = step * step;
136    let numerator = (-2.0_f64).mul_add(center, left + right);
137
138    Ok(numerator / step_squared)
139}
140
141fn evaluate<F>(function: &mut F, input: f64) -> Result<f64, CalculusError>
142where
143    F: FnMut(f64) -> f64,
144{
145    let input = CalculusError::validate_point("sample", input)?;
146    let value = function(input);
147
148    CalculusError::validate_evaluation(input, value)
149}
150
#[cfg(test)]
mod tests {
    use super::{CalculusError, Differentiator, central_difference, second_central_difference};

    /// Fails the test unless `left` and `right` differ by at most `tolerance`.
    fn assert_close(left: f64, right: f64, tolerance: f64) {
        let gap = (left - right).abs();

        assert!(
            gap <= tolerance,
            "expected {left} to be within {tolerance} of {right}"
        );
    }

    /// `try_new` must reject an infinite step and a zero step.
    #[test]
    fn validates_differentiator_steps() {
        let infinite_step = Differentiator::try_new(f64::INFINITY);
        assert!(matches!(
            infinite_step,
            Err(CalculusError::NonFiniteStep(f64::INFINITY))
        ));

        let zero_step = Differentiator::try_new(0.0);
        assert!(matches!(
            zero_step,
            Err(CalculusError::NonPositiveStep(0.0))
        ));
    }

    /// d/dx x^2 at x = 3 is 6.
    #[test]
    fn computes_first_derivatives() -> Result<(), CalculusError> {
        central_difference(|x| x.powi(2), 3.0, 1.0e-5)
            .map(|slope| assert_close(slope, 6.0, 1.0e-6))
    }

    /// d^2/dx^2 x^2 is 2 everywhere.
    #[test]
    fn computes_second_derivatives() -> Result<(), CalculusError> {
        second_central_difference(|x| x.powi(2), 1.5, 1.0e-4)
            .map(|curvature| assert_close(curvature, 2.0, 1.0e-6))
    }

    /// A `NaN` evaluation point is reported under the name `"at"`.
    #[test]
    fn rejects_non_finite_points() {
        let outcome = central_difference(|x| x, f64::NAN, 1.0e-5);

        assert!(matches!(
            outcome,
            Err(CalculusError::NonFinitePoint { name: "at", .. })
        ));
    }

    /// A function that returns `NaN` yields an evaluation error.
    #[test]
    fn rejects_non_finite_evaluations() {
        let outcome = central_difference(|_| f64::NAN, 1.0, 1.0e-5);

        assert!(matches!(
            outcome,
            Err(CalculusError::NonFiniteEvaluation { .. })
        ));
    }
}