1pub fn sign_change_bracketed(a: f64, b: f64) -> Result<bool, RootError> {
12 validate_finite("bracket.low_value", a)?;
13 validate_finite("bracket.high_value", b)?;
14 Ok(!same_sign(a, b))
15}
16
/// Error produced by the sign-change check and the bisection helpers.
///
/// `E` is the error type of a fallible predicate. It defaults to
/// [`core::convert::Infallible`], which the infallible entry points use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RootError<E = core::convert::Infallible> {
    /// A value was rejected by validation (in this module: a non-finite
    /// predicate output or bracket value).
    InvalidInput {
        /// Dotted name of the rejected value, e.g. `bracket.mid_value`.
        field: &'static str,
        /// Short reason the value was rejected, e.g. `not finite`.
        reason: &'static str,
    },
    /// The caller-supplied predicate returned this error.
    Predicate(E),
}

impl<E> core::fmt::Display for RootError<E>
where
    E: core::fmt::Display,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Predicate(inner) => write!(f, "root predicate failed: {}", inner),
            Self::InvalidInput { field, reason } => {
                write!(f, "invalid root input {}: {}", field, reason)
            }
        }
    }
}

/// Allows `RootError` to be used as a `dyn std::error::Error`.
///
/// No `source()` is reported because `E` is only required to be
/// `Debug + Display`, not itself an `Error`.
impl<E> std::error::Error for RootError<E> where E: core::fmt::Debug + core::fmt::Display {}
43
44fn invalid_input<E>(field: &'static str, reason: &'static str) -> RootError<E> {
45 RootError::InvalidInput { field, reason }
46}
47
48fn validate_finite<E>(field: &'static str, value: f64) -> Result<f64, RootError<E>> {
49 if value.is_finite() {
50 Ok(value)
51 } else {
52 Err(invalid_input(field, "not finite"))
53 }
54}
55
56pub fn bisect_crossing_by_iterations<T, F, M>(
61 low: T,
62 high: T,
63 iterations: usize,
64 value_at: F,
65 midpoint: M,
66) -> Result<T, RootError>
67where
68 T: Copy + PartialEq,
69 F: FnMut(T) -> f64,
70 M: FnMut(T, T) -> T,
71{
72 let mut remaining = iterations;
73 bisect_crossing_while(low, high, value_at, midpoint, |_, _| {
74 if remaining == 0 {
75 false
76 } else {
77 remaining -= 1;
78 true
79 }
80 })
81}
82
83pub fn bisect_crossing_until<T, F, M, W>(
88 low: T,
89 high: T,
90 value_at: F,
91 midpoint: M,
92 mut within_tolerance: W,
93) -> Result<T, RootError>
94where
95 T: Copy + PartialEq,
96 F: FnMut(T) -> f64,
97 M: FnMut(T, T) -> T,
98 W: FnMut(T, T) -> bool,
99{
100 bisect_crossing_while(low, high, value_at, midpoint, |lo, hi| {
101 !within_tolerance(lo, hi)
102 })
103}
104
105pub fn try_bisect_crossing_until<T, F, M, W, E>(
108 low: T,
109 high: T,
110 value_at: F,
111 midpoint: M,
112 mut within_tolerance: W,
113) -> Result<T, RootError<E>>
114where
115 T: Copy + PartialEq,
116 F: FnMut(T) -> Result<f64, E>,
117 M: FnMut(T, T) -> T,
118 W: FnMut(T, T) -> bool,
119{
120 try_bisect_crossing_while(low, high, value_at, midpoint, |lo, hi| {
121 !within_tolerance(lo, hi)
122 })
123}
124
125fn bisect_crossing_while<T, F, M, C>(
126 low: T,
127 high: T,
128 mut value_at: F,
129 mut midpoint: M,
130 mut keep_refining: C,
131) -> Result<T, RootError>
132where
133 T: Copy + PartialEq,
134 F: FnMut(T) -> f64,
135 M: FnMut(T, T) -> T,
136 C: FnMut(T, T) -> bool,
137{
138 let mut lo = low;
139 let mut hi = high;
140 let mut value_lo = validate_finite("bracket.low_value", value_at(lo))?;
141 validate_finite("bracket.high_value", value_at(hi))?;
142
143 while keep_refining(lo, hi) {
144 let mid = midpoint(lo, hi);
145 if mid == lo || mid == hi {
146 validate_finite("bracket.mid_value", value_at(mid))?;
147 return Ok(mid);
148 }
149 let value_mid = validate_finite("bracket.mid_value", value_at(mid))?;
150 if value_mid == 0.0 {
151 return Ok(mid);
152 }
153 if same_sign(value_lo, value_mid) {
154 lo = mid;
155 value_lo = value_mid;
156 } else {
157 hi = mid;
158 }
159 }
160
161 let mid = midpoint(lo, hi);
162 validate_finite("bracket.mid_value", value_at(mid))?;
163 Ok(mid)
164}
165
166fn try_bisect_crossing_while<T, F, M, C, E>(
167 low: T,
168 high: T,
169 mut value_at: F,
170 mut midpoint: M,
171 mut keep_refining: C,
172) -> Result<T, RootError<E>>
173where
174 T: Copy + PartialEq,
175 F: FnMut(T) -> Result<f64, E>,
176 M: FnMut(T, T) -> T,
177 C: FnMut(T, T) -> bool,
178{
179 let mut lo = low;
180 let mut hi = high;
181 let mut value_lo = validate_finite(
182 "bracket.low_value",
183 value_at(lo).map_err(RootError::Predicate)?,
184 )?;
185 validate_finite(
186 "bracket.high_value",
187 value_at(hi).map_err(RootError::Predicate)?,
188 )?;
189
190 while keep_refining(lo, hi) {
191 let mid = midpoint(lo, hi);
192 if mid == lo || mid == hi {
193 validate_finite(
194 "bracket.mid_value",
195 value_at(mid).map_err(RootError::Predicate)?,
196 )?;
197 return Ok(mid);
198 }
199 let value_mid = validate_finite(
200 "bracket.mid_value",
201 value_at(mid).map_err(RootError::Predicate)?,
202 )?;
203 if value_mid == 0.0 {
204 return Ok(mid);
205 }
206 if same_sign(value_lo, value_mid) {
207 lo = mid;
208 value_lo = value_mid;
209 } else {
210 hi = mid;
211 }
212 }
213
214 let mid = midpoint(lo, hi);
215 validate_finite(
216 "bracket.mid_value",
217 value_at(mid).map_err(RootError::Predicate)?,
218 )?;
219 Ok(mid)
220}
221
/// Reports whether `a` and `b` lie on the same side of zero.
///
/// Zero (including `-0.0`) counts as non-negative. A NaN argument lies on
/// neither side, so any comparison involving NaN yields `false`.
fn same_sign(a: f64, b: f64) -> bool {
    if a < 0.0 {
        b < 0.0
    } else {
        // `a` is either non-negative or NaN here; NaN fails `a >= 0.0`.
        a >= 0.0 && b >= 0.0
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain arithmetic midpoint used to split every test bracket.
    fn midpoint(a: f64, b: f64) -> f64 {
        (a + b) * 0.5
    }

    #[test]
    fn sign_change_bracket_uses_zero_as_non_negative_side() {
        assert!(sign_change_bracketed(-1.0, 1.0).expect("finite bracket"));
        assert!(sign_change_bracketed(-1.0, 0.0).expect("finite bracket"));
        assert!(sign_change_bracketed(0.0, -1.0).expect("finite bracket"));
        assert!(!sign_change_bracketed(0.0, 1.0).expect("finite bracket"));
        assert!(!sign_change_bracketed(1.0, 0.0).expect("finite bracket"));
    }

    #[test]
    fn sign_change_bracket_rejects_non_finite_values() {
        assert_eq!(
            sign_change_bracketed(f64::NAN, 1.0),
            Err(RootError::InvalidInput {
                field: "bracket.low_value",
                reason: "not finite",
            })
        );
        assert_eq!(
            sign_change_bracketed(-1.0, f64::INFINITY),
            Err(RootError::InvalidInput {
                field: "bracket.high_value",
                reason: "not finite",
            })
        );
    }

    #[test]
    fn fixed_iteration_bisection_refines_crossing() {
        // Brackets after each step: [0, 0.5], [0.25, 0.5], [0.25, 0.375],
        // [0.25, 0.3125]; the returned midpoint of the last one is 0.28125.
        let crossing = bisect_crossing_by_iterations(0.0, 1.0, 4, |x| x - 0.3, midpoint)
            .expect("finite bisection");

        assert_eq!(crossing.to_bits(), 0.28125_f64.to_bits());
    }

    #[test]
    fn zero_iterations_return_initial_midpoint() {
        let crossing = bisect_crossing_by_iterations(0.0, 1.0, 0, |x| x - 0.3, midpoint)
            .expect("finite bisection");

        assert_eq!(crossing.to_bits(), 0.5_f64.to_bits());
    }

    #[test]
    fn tolerance_bisection_refines_to_requested_bracket_width() {
        let crossing = bisect_crossing_until(
            1.0,
            2.0,
            |x| x * x - 2.0,
            midpoint,
            |lo, hi| (hi - lo).abs() <= 1.0e-12,
        )
        .expect("finite bisection");

        assert!((crossing - 2.0_f64.sqrt()).abs() <= 5.0e-13);
    }

    #[test]
    fn bisection_returns_exact_midpoint_root() {
        // The very first midpoint of [0, 2] is the exact root 1.0.
        let crossing = bisect_crossing_by_iterations(0.0, 2.0, 8, |x| x - 1.0, midpoint)
            .expect("finite bisection");

        assert_eq!(crossing.to_bits(), 1.0_f64.to_bits());

        let crossing = try_bisect_crossing_until(
            0.0,
            2.0,
            |x| Ok::<f64, ()>(x - 1.0),
            midpoint,
            |lo, hi| (hi - lo).abs() <= 1.0e-12,
        )
        .expect("exact midpoint root should resolve");

        assert_eq!(crossing.to_bits(), 1.0_f64.to_bits());
    }

    #[test]
    fn bisection_stops_when_midpoint_cannot_shrink_bracket() {
        // `low` is the largest f64 below 1.0. `low + 1.0` lies exactly halfway
        // between two floats and rounds (ties-to-even) up to 2.0, so the first
        // midpoint equals `high` and refinement must stop right away.
        let high = 1.0_f64;
        let low = f64::from_bits(high.to_bits() - 1);
        let max_iterations = 64;
        let mut value_calls = 0;

        let crossing = bisect_crossing_by_iterations(
            low,
            high,
            max_iterations,
            |x| {
                value_calls += 1;
                x - high
            },
            midpoint,
        )
        .expect("finite bisection");

        assert_eq!(crossing.to_bits(), high.to_bits());
        assert!(value_calls < max_iterations);
    }

    #[test]
    fn bisection_rejects_non_finite_endpoint_value() {
        // 1 / 0 is +infinity, so the low endpoint fails validation.
        let err = bisect_crossing_until(
            0.0,
            1.0,
            |x: f64| 1.0 / x,
            midpoint,
            |lo, hi| (hi - lo).abs() <= 0.25,
        )
        .expect_err("infinite endpoint value must be rejected");

        assert_eq!(
            err,
            RootError::InvalidInput {
                field: "bracket.low_value",
                reason: "not finite",
            }
        );
    }

    #[test]
    fn bisection_rejects_non_finite_midpoint_value() {
        // The first midpoint of [-1, 1] is 0.0, where the function is NaN.
        let err = bisect_crossing_by_iterations(
            -1.0,
            1.0,
            4,
            |x: f64| if x == 0.0 { f64::NAN } else { x },
            midpoint,
        )
        .expect_err("NaN at the midpoint must abort refinement");

        assert_eq!(
            err,
            RootError::InvalidInput {
                field: "bracket.mid_value",
                reason: "not finite",
            }
        );
    }

    #[test]
    fn fallible_bisection_returns_predicate_errors() {
        let err = try_bisect_crossing_until(
            0.0,
            2.0,
            |x| {
                if x == 1.0 {
                    Err("predicate")
                } else {
                    Ok(x - 1.0)
                }
            },
            midpoint,
            |lo, hi| (hi - lo).abs() <= 0.25,
        )
        .expect_err("midpoint error must abort refinement");

        assert_eq!(err, RootError::Predicate("predicate"));
    }

    #[test]
    fn fallible_bisection_returns_endpoint_predicate_errors() {
        let err = try_bisect_crossing_until(
            0.0,
            2.0,
            |x: f64| if x == 2.0 { Err("high") } else { Ok(x - 1.0) },
            midpoint,
            |lo, hi| (hi - lo).abs() <= 0.25,
        )
        .expect_err("high endpoint error must abort before refinement");

        assert_eq!(err, RootError::Predicate("high"));
    }

    #[test]
    fn errors_render_field_reason_and_predicate_message() {
        let invalid: RootError = RootError::InvalidInput {
            field: "bracket.mid_value",
            reason: "not finite",
        };
        assert_eq!(
            invalid.to_string(),
            "invalid root input bracket.mid_value: not finite"
        );
        assert_eq!(
            RootError::Predicate("boom").to_string(),
            "root predicate failed: boom"
        );
    }
}