Skip to main content

sidereon_core/astro/events/
root.rs

//! Shared scalar-crossing refinement primitives.
//!
//! These helpers own the bisection mechanics used by event-style code: keep a
//! bracket whose endpoint scalar values lie on opposite sides of zero, choose a
//! midpoint, and retain the half-bracket that contains the crossing.
6
7/// Return true when two scalar samples bracket a zero crossing.
8///
9/// Zero is treated as the non-negative side. This matches the legacy pass
10/// finder semantics, where a sample exactly on the mask is considered visible.
11pub fn sign_change_bracketed(a: f64, b: f64) -> Result<bool, RootError> {
12    validate_finite("bracket.low_value", a)?;
13    validate_finite("bracket.high_value", b)?;
14    Ok(!same_sign(a, b))
15}
16
/// Failure reported by the scalar root-bracketing helpers.
///
/// The type parameter `E` is the error type of a caller-supplied fallible
/// sampler. It defaults to [`core::convert::Infallible`] for the infallible
/// entry points, so [`RootError::Predicate`] cannot be constructed there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RootError<E = core::convert::Infallible> {
    /// An endpoint, sample, or midpoint value was NaN or infinite.
    InvalidInput {
        /// Stable name of the offending value, e.g. `bracket.mid_value`.
        field: &'static str,
        /// Stable, machine-matchable validation reason, e.g. `not finite`.
        reason: &'static str,
    },
    /// The caller-provided fallible sampler returned an error, carried verbatim.
    Predicate(E),
}

/// Renders `invalid root input <field>: <reason>` or
/// `root predicate failed: <error>`.
impl<E: core::fmt::Display> core::fmt::Display for RootError<E> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            RootError::InvalidInput { field, reason } => {
                write!(f, "invalid root input {}: {}", field, reason)
            }
            RootError::Predicate(ref inner) => write!(f, "root predicate failed: {}", inner),
        }
    }
}

/// Leaf error: `source()` is not overridden, so the wrapped predicate error is
/// exposed only through `Display`/`Debug` and by matching on the variant.
impl<E> std::error::Error for RootError<E> where E: core::fmt::Debug + core::fmt::Display {}
43
44fn invalid_input<E>(field: &'static str, reason: &'static str) -> RootError<E> {
45    RootError::InvalidInput { field, reason }
46}
47
48fn validate_finite<E>(field: &'static str, value: f64) -> Result<f64, RootError<E>> {
49    if value.is_finite() {
50        Ok(value)
51    } else {
52        Err(invalid_input(field, "not finite"))
53    }
54}
55
56/// Refine a zero crossing for a fixed number of bisection iterations.
57///
58/// `value_at` returns the signed scalar value at an epoch-like point, and
59/// `midpoint` supplies the midpoint arithmetic for that point type.
60pub fn bisect_crossing_by_iterations<T, F, M>(
61    low: T,
62    high: T,
63    iterations: usize,
64    value_at: F,
65    midpoint: M,
66) -> Result<T, RootError>
67where
68    T: Copy + PartialEq,
69    F: FnMut(T) -> f64,
70    M: FnMut(T, T) -> T,
71{
72    let mut remaining = iterations;
73    bisect_crossing_while(low, high, value_at, midpoint, |_, _| {
74        if remaining == 0 {
75            false
76        } else {
77            remaining -= 1;
78            true
79        }
80    })
81}
82
83/// Refine a zero crossing until `within_tolerance` accepts the active bracket.
84///
85/// The predicate receives `(low, high)` and should return true once no more
86/// refinement is required.
87pub fn bisect_crossing_until<T, F, M, W>(
88    low: T,
89    high: T,
90    value_at: F,
91    midpoint: M,
92    mut within_tolerance: W,
93) -> Result<T, RootError>
94where
95    T: Copy + PartialEq,
96    F: FnMut(T) -> f64,
97    M: FnMut(T, T) -> T,
98    W: FnMut(T, T) -> bool,
99{
100    bisect_crossing_while(low, high, value_at, midpoint, |lo, hi| {
101        !within_tolerance(lo, hi)
102    })
103}
104
105/// Refine a zero crossing until `within_tolerance` accepts the active bracket,
106/// allowing predicate evaluation to abort the refinement.
107pub fn try_bisect_crossing_until<T, F, M, W, E>(
108    low: T,
109    high: T,
110    value_at: F,
111    midpoint: M,
112    mut within_tolerance: W,
113) -> Result<T, RootError<E>>
114where
115    T: Copy + PartialEq,
116    F: FnMut(T) -> Result<f64, E>,
117    M: FnMut(T, T) -> T,
118    W: FnMut(T, T) -> bool,
119{
120    try_bisect_crossing_while(low, high, value_at, midpoint, |lo, hi| {
121        !within_tolerance(lo, hi)
122    })
123}
124
125fn bisect_crossing_while<T, F, M, C>(
126    low: T,
127    high: T,
128    mut value_at: F,
129    mut midpoint: M,
130    mut keep_refining: C,
131) -> Result<T, RootError>
132where
133    T: Copy + PartialEq,
134    F: FnMut(T) -> f64,
135    M: FnMut(T, T) -> T,
136    C: FnMut(T, T) -> bool,
137{
138    let mut lo = low;
139    let mut hi = high;
140    let mut value_lo = validate_finite("bracket.low_value", value_at(lo))?;
141    validate_finite("bracket.high_value", value_at(hi))?;
142
143    while keep_refining(lo, hi) {
144        let mid = midpoint(lo, hi);
145        if mid == lo || mid == hi {
146            validate_finite("bracket.mid_value", value_at(mid))?;
147            return Ok(mid);
148        }
149        let value_mid = validate_finite("bracket.mid_value", value_at(mid))?;
150        if value_mid == 0.0 {
151            return Ok(mid);
152        }
153        if same_sign(value_lo, value_mid) {
154            lo = mid;
155            value_lo = value_mid;
156        } else {
157            hi = mid;
158        }
159    }
160
161    let mid = midpoint(lo, hi);
162    validate_finite("bracket.mid_value", value_at(mid))?;
163    Ok(mid)
164}
165
166fn try_bisect_crossing_while<T, F, M, C, E>(
167    low: T,
168    high: T,
169    mut value_at: F,
170    mut midpoint: M,
171    mut keep_refining: C,
172) -> Result<T, RootError<E>>
173where
174    T: Copy + PartialEq,
175    F: FnMut(T) -> Result<f64, E>,
176    M: FnMut(T, T) -> T,
177    C: FnMut(T, T) -> bool,
178{
179    let mut lo = low;
180    let mut hi = high;
181    let mut value_lo = validate_finite(
182        "bracket.low_value",
183        value_at(lo).map_err(RootError::Predicate)?,
184    )?;
185    validate_finite(
186        "bracket.high_value",
187        value_at(hi).map_err(RootError::Predicate)?,
188    )?;
189
190    while keep_refining(lo, hi) {
191        let mid = midpoint(lo, hi);
192        if mid == lo || mid == hi {
193            validate_finite(
194                "bracket.mid_value",
195                value_at(mid).map_err(RootError::Predicate)?,
196            )?;
197            return Ok(mid);
198        }
199        let value_mid = validate_finite(
200            "bracket.mid_value",
201            value_at(mid).map_err(RootError::Predicate)?,
202        )?;
203        if value_mid == 0.0 {
204            return Ok(mid);
205        }
206        if same_sign(value_lo, value_mid) {
207            lo = mid;
208            value_lo = value_mid;
209        } else {
210            hi = mid;
211        }
212    }
213
214    let mid = midpoint(lo, hi);
215    validate_finite(
216        "bracket.mid_value",
217        value_at(mid).map_err(RootError::Predicate)?,
218    )?;
219    Ok(mid)
220}
221
/// True when `a` and `b` fall on the same side of zero.
///
/// Zero (including `-0.0`) counts as non-negative. A NaN on either side never
/// matches, because every comparison involving it is false.
fn same_sign(a: f64, b: f64) -> bool {
    let both_non_negative = a >= 0.0 && b >= 0.0;
    let both_negative = a < 0.0 && b < 0.0;
    both_non_negative || both_negative
}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain floating-point midpoint used as the bracket arithmetic in tests.
    fn midpoint(a: f64, b: f64) -> f64 {
        (a + b) * 0.5
    }

    #[test]
    fn sign_change_bracket_uses_zero_as_non_negative_side() {
        assert!(sign_change_bracketed(-1.0, 1.0).expect("finite bracket"));
        assert!(sign_change_bracketed(-1.0, 0.0).expect("finite bracket"));
        assert!(sign_change_bracketed(0.0, -1.0).expect("finite bracket"));
        assert!(!sign_change_bracketed(0.0, 1.0).expect("finite bracket"));
        assert!(!sign_change_bracketed(1.0, 0.0).expect("finite bracket"));
    }

    #[test]
    fn fixed_iteration_bisection_refines_crossing() {
        let crossing = bisect_crossing_by_iterations(0.0, 1.0, 4, |x| x - 0.3, midpoint)
            .expect("finite bisection");

        assert_eq!(crossing.to_bits(), 0.28125_f64.to_bits());
    }

    #[test]
    fn tolerance_bisection_refines_to_requested_bracket_width() {
        let crossing = bisect_crossing_until(
            1.0,
            2.0,
            |x| x * x - 2.0,
            midpoint,
            |lo, hi| (hi - lo).abs() <= 1.0e-12,
        )
        .expect("finite bisection");

        assert!((crossing - 2.0_f64.sqrt()).abs() <= 5.0e-13);
    }

    #[test]
    fn bisection_returns_exact_midpoint_root() {
        let crossing = bisect_crossing_by_iterations(0.0, 2.0, 8, |x| x - 1.0, midpoint)
            .expect("finite bisection");

        assert_eq!(crossing.to_bits(), 1.0_f64.to_bits());

        let crossing = try_bisect_crossing_until(
            0.0,
            2.0,
            |x| Ok::<f64, ()>(x - 1.0),
            midpoint,
            |lo, hi| (hi - lo).abs() <= 1.0e-12,
        )
        .expect("exact midpoint root should resolve");

        assert_eq!(crossing.to_bits(), 1.0_f64.to_bits());
    }

    #[test]
    fn bisection_stops_when_midpoint_cannot_shrink_bracket() {
        let high = 1.0_f64;
        let low = f64::from_bits(high.to_bits() - 1);
        let max_iterations = 64;
        let mut value_calls = 0;

        let crossing = bisect_crossing_by_iterations(
            low,
            high,
            max_iterations,
            |x| {
                value_calls += 1;
                x - high
            },
            midpoint,
        )
        .expect("finite bisection");

        assert_eq!(crossing.to_bits(), high.to_bits());
        assert!(value_calls < max_iterations);
    }

    #[test]
    fn fallible_bisection_returns_predicate_errors() {
        let err = try_bisect_crossing_until(
            0.0,
            2.0,
            |x| {
                if x == 1.0 {
                    Err("predicate")
                } else {
                    Ok(x - 1.0)
                }
            },
            midpoint,
            |lo, hi| (hi - lo).abs() <= 0.25,
        )
        .expect_err("midpoint error must abort refinement");

        assert_eq!(err, RootError::Predicate("predicate"));
    }

    #[test]
    fn sign_change_bracket_rejects_non_finite_samples() {
        assert_eq!(
            sign_change_bracketed(f64::NAN, 1.0),
            Err(RootError::InvalidInput {
                field: "bracket.low_value",
                reason: "not finite",
            })
        );
        assert_eq!(
            sign_change_bracketed(-1.0, f64::INFINITY),
            Err(RootError::InvalidInput {
                field: "bracket.high_value",
                reason: "not finite",
            })
        );
    }

    #[test]
    fn bisection_rejects_non_finite_midpoint_value() {
        // The first midpoint of [0, 2] is exactly 1.0, which yields NaN here.
        let err = bisect_crossing_by_iterations(
            0.0,
            2.0,
            4,
            |x| if x == 1.0 { f64::NAN } else { x - 1.0 },
            midpoint,
        )
        .expect_err("NaN midpoint value must abort refinement");

        assert_eq!(
            err,
            RootError::InvalidInput {
                field: "bracket.mid_value",
                reason: "not finite",
            }
        );
    }

    #[test]
    fn fallible_bisection_returns_endpoint_predicate_errors() {
        let err = try_bisect_crossing_until(
            0.0,
            2.0,
            |x| if x == 0.0 { Err("low endpoint") } else { Ok(x - 1.0) },
            midpoint,
            |lo, hi| (hi - lo).abs() <= 0.25,
        )
        .expect_err("endpoint error must abort before refinement");

        assert_eq!(err, RootError::Predicate("low endpoint"));
    }

    #[test]
    fn tolerance_bisection_handles_descending_crossing() {
        // Non-negative at `low`, negative at `high` (a "set"-style crossing).
        let crossing = bisect_crossing_until(
            0.0,
            1.0,
            |x| 0.3 - x,
            midpoint,
            |lo, hi| (hi - lo).abs() <= 1.0e-12,
        )
        .expect("finite bisection");

        assert!((crossing - 0.3).abs() <= 1.0e-12);
    }

    #[test]
    fn root_error_display_is_stable() {
        let invalid: RootError = RootError::InvalidInput {
            field: "bracket.low_value",
            reason: "not finite",
        };
        assert_eq!(
            invalid.to_string(),
            "invalid root input bracket.low_value: not finite"
        );
        assert_eq!(
            RootError::Predicate("boom").to_string(),
            "root predicate failed: boom"
        );
    }
}