1use crate::astro::events::root::{sign_change_bracketed, try_bisect_crossing_until, RootError};
9use crate::validate;
10use rayon::prelude::*;
11
/// Golden-section reduction ratio `(3 - sqrt(5)) / 2 = 1 - 1/phi`: the fraction of the bracket
/// trimmed from each side per golden-section step in extremum refinement.
const GOLDEN_RESPHI: f64 = 0.381_966_011_250_105_1;

/// Hard cap on golden-section iterations, independent of the time tolerance.
const MAX_GOLDEN_ITERATIONS: usize = 128;

/// Largest number of coarse grid steps a single search is allowed to take over its window.
const MAX_EVENT_COARSE_SAMPLES: usize = 1_000_000;
15
/// A real-valued function of time that [`EventFinder`] samples when searching for threshold
/// crossings and local extrema.
pub trait ScalarEventPredicate {
    /// Evaluates the function at `time_seconds`.
    ///
    /// The finder rejects non-finite results with an `InvalidInput` error on field
    /// `"predicate"`.
    fn value_at(&self, time_seconds: f64) -> f64;
}

/// Any `Fn(f64) -> f64` (closures, function items, function pointers) is usable directly as a
/// scalar predicate.
impl<F: Fn(f64) -> f64> ScalarEventPredicate for F {
    fn value_at(&self, time_seconds: f64) -> f64 {
        (self)(time_seconds)
    }
}
30
/// A boolean function of time whose transitions [`EventFinder`] locates.
pub trait DiscreteEventPredicate {
    /// Reports the predicate's state at `time_seconds`.
    fn state_at(&self, time_seconds: f64) -> bool;
}

/// Any `Fn(f64) -> bool` is usable directly as a discrete predicate.
impl<F: Fn(f64) -> bool> DiscreteEventPredicate for F {
    fn state_at(&self, time_seconds: f64) -> bool {
        (self)(time_seconds)
    }
}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
48pub enum EventFinderError {
49 #[error("invalid event-finder input {field}: {reason}")]
51 InvalidInput {
52 field: &'static str,
54 reason: &'static str,
56 },
57}
58
/// Direction in which a scalar predicate passes through a threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossingDirection {
    /// The value moves from below the threshold to above it.
    Rising,
    /// The value moves from above the threshold to below it.
    Falling,
}
67
/// Whether a located extremum is a local maximum or a local minimum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtremumKind {
    /// The value is larger than at the neighbouring samples.
    Maximum,
    /// The value is smaller than at the neighbouring samples.
    Minimum,
}
76
77#[derive(Debug, Clone, Copy, PartialEq)]
79pub struct CrossingEvent {
80 pub time_seconds: f64,
82 pub value: f64,
84 pub threshold: f64,
86 pub direction: CrossingDirection,
88}
89
90#[derive(Debug, Clone, Copy, PartialEq)]
92pub struct ExtremumEvent {
93 pub time_seconds: f64,
95 pub value: f64,
97 pub kind: ExtremumKind,
99}
100
/// A transition of a boolean predicate located by [`EventFinder::find_state_changes`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StateChangeEvent {
    /// Refined time of the transition, in seconds.
    pub time_seconds: f64,
    /// State on the earlier side of the transition.
    pub previous_state: bool,
    /// State on the later side of the transition.
    pub next_state: bool,
}
111
/// Configuration for sampling a time window and refining events inside it.
///
/// Construct with [`EventFinder::new`], which validates every field.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EventFinder {
    // Window start, in seconds.
    start_seconds: f64,
    // Window end, in seconds.
    end_seconds: f64,
    // Spacing of the coarse sampling grid, in seconds.
    step_seconds: f64,
    // Bracket width at which refinement stops, in seconds.
    time_tolerance_seconds: f64,
}
120
121impl EventFinder {
122 pub fn new(
124 start_seconds: f64,
125 end_seconds: f64,
126 step_seconds: f64,
127 time_tolerance_seconds: f64,
128 ) -> Result<Self, EventFinderError> {
129 let start_seconds =
130 validate::finite(start_seconds, "start_seconds").map_err(map_event_input)?;
131 let end_seconds = validate::finite(end_seconds, "end_seconds").map_err(map_event_input)?;
132 validate::range_order(start_seconds, end_seconds, "end_seconds")
133 .map_err(map_event_input)?;
134 let step_seconds =
135 validate::positive_step(step_seconds, "step_seconds").map_err(map_event_input)?;
136 let time_tolerance_seconds =
137 validate::positive_step(time_tolerance_seconds, "time_tolerance_seconds")
138 .map_err(map_event_input)?;
139
140 Ok(Self {
141 start_seconds,
142 end_seconds,
143 step_seconds,
144 time_tolerance_seconds,
145 })
146 }
147
148 pub fn start_seconds(self) -> f64 {
150 self.start_seconds
151 }
152
153 pub fn end_seconds(self) -> f64 {
155 self.end_seconds
156 }
157
158 pub fn step_seconds(self) -> f64 {
160 self.step_seconds
161 }
162
163 pub fn time_tolerance_seconds(self) -> f64 {
165 self.time_tolerance_seconds
166 }
167
168 pub fn find_crossings<P>(
170 self,
171 predicate: P,
172 threshold: f64,
173 ) -> Result<Vec<CrossingEvent>, EventFinderError>
174 where
175 P: ScalarEventPredicate,
176 {
177 self.find_crossings_ref(&predicate, threshold)
178 }
179
180 pub fn find_crossings_batch_serial<P>(
184 self,
185 predicates: &[P],
186 threshold: f64,
187 ) -> Vec<Result<Vec<CrossingEvent>, EventFinderError>>
188 where
189 P: ScalarEventPredicate,
190 {
191 predicates
192 .iter()
193 .map(|predicate| self.find_crossings_ref(predicate, threshold))
194 .collect()
195 }
196
197 pub fn find_crossings_batch_parallel<P>(
203 self,
204 predicates: &[P],
205 threshold: f64,
206 ) -> Vec<Result<Vec<CrossingEvent>, EventFinderError>>
207 where
208 P: ScalarEventPredicate + Sync,
209 {
210 predicates
211 .par_iter()
212 .map(|predicate| self.find_crossings_ref(predicate, threshold))
213 .collect()
214 }
215
216 fn find_crossings_ref<P>(
217 self,
218 predicate: &P,
219 threshold: f64,
220 ) -> Result<Vec<CrossingEvent>, EventFinderError>
221 where
222 P: ScalarEventPredicate + ?Sized,
223 {
224 let threshold = validate::finite(threshold, "threshold").map_err(map_event_input)?;
225 let samples = self.scalar_samples(predicate)?;
226 let mut events = Vec::new();
227
228 for (left_index, pair) in samples.windows(2).enumerate() {
229 let a = pair[0];
230 let b = pair[1];
231 let value_a = a.value - threshold;
232 let value_b = b.value - threshold;
233 let Some(direction) =
234 crossing_direction_for_sample_pair(&samples, left_index, threshold)
235 else {
236 continue;
237 };
238 let time_seconds = if value_a == 0.0 {
239 let zero_run_start = zero_run_start(&samples, left_index, threshold);
240 samples[zero_run_start].time_seconds
241 } else if value_b == 0.0 {
242 b.time_seconds
243 } else {
244 try_bisect_crossing_until(
245 a.time_seconds,
246 b.time_seconds,
247 |time| finite_predicate_value(predicate.value_at(time) - threshold),
248 midpoint_seconds,
249 |lo, hi| (hi - lo).abs() <= self.time_tolerance_seconds,
250 )
251 .map_err(map_root_error)?
252 };
253 if events.last().is_some_and(|event: &CrossingEvent| {
254 event.time_seconds.to_bits() == time_seconds.to_bits()
255 }) {
256 continue;
257 }
258 let value = finite_predicate_value(predicate.value_at(time_seconds))?;
259
260 events.push(CrossingEvent {
261 time_seconds,
262 value,
263 threshold,
264 direction,
265 });
266 }
267
268 Ok(events)
269 }
270
271 pub fn find_extrema<P>(self, predicate: P) -> Result<Vec<ExtremumEvent>, EventFinderError>
273 where
274 P: ScalarEventPredicate,
275 {
276 self.find_extrema_ref(&predicate)
277 }
278
279 pub fn find_extrema_batch_serial<P>(
283 self,
284 predicates: &[P],
285 ) -> Vec<Result<Vec<ExtremumEvent>, EventFinderError>>
286 where
287 P: ScalarEventPredicate,
288 {
289 predicates
290 .iter()
291 .map(|predicate| self.find_extrema_ref(predicate))
292 .collect()
293 }
294
295 pub fn find_extrema_batch_parallel<P>(
301 self,
302 predicates: &[P],
303 ) -> Vec<Result<Vec<ExtremumEvent>, EventFinderError>>
304 where
305 P: ScalarEventPredicate + Sync,
306 {
307 predicates
308 .par_iter()
309 .map(|predicate| self.find_extrema_ref(predicate))
310 .collect()
311 }
312
313 fn find_extrema_ref<P>(self, predicate: &P) -> Result<Vec<ExtremumEvent>, EventFinderError>
314 where
315 P: ScalarEventPredicate + ?Sized,
316 {
317 let samples = self.extrema_samples(predicate)?;
318 let mut events = Vec::new();
319
320 let mut index = 1;
321 while index + 1 < samples.len() {
322 let run_start = index;
323 let run_value = samples[run_start].value;
324 let mut run_end = run_start;
325 while run_end + 1 < samples.len() && samples[run_end + 1].value == run_value {
326 run_end += 1;
327 }
328
329 if run_end + 1 >= samples.len() {
330 break;
331 }
332
333 let prev = samples[run_start - 1];
334 let next = samples[run_end + 1];
335 let kind = if run_value > prev.value && run_value > next.value {
336 Some(ExtremumKind::Maximum)
337 } else if run_value < prev.value && run_value < next.value {
338 Some(ExtremumKind::Minimum)
339 } else {
340 None
341 };
342
343 if let Some(kind) = kind {
344 events.push(self.refine_extremum(
345 predicate,
346 kind,
347 prev.time_seconds,
348 next.time_seconds,
349 )?);
350 }
351
352 index = run_end + 1;
353 }
354
355 Ok(events)
356 }
357
358 pub fn find_state_changes<P>(
360 self,
361 predicate: P,
362 ) -> Result<Vec<StateChangeEvent>, EventFinderError>
363 where
364 P: DiscreteEventPredicate,
365 {
366 self.find_state_changes_ref(&predicate)
367 }
368
369 pub fn find_state_changes_batch_serial<P>(
373 self,
374 predicates: &[P],
375 ) -> Vec<Result<Vec<StateChangeEvent>, EventFinderError>>
376 where
377 P: DiscreteEventPredicate,
378 {
379 predicates
380 .iter()
381 .map(|predicate| self.find_state_changes_ref(predicate))
382 .collect()
383 }
384
385 pub fn find_state_changes_batch_parallel<P>(
391 self,
392 predicates: &[P],
393 ) -> Vec<Result<Vec<StateChangeEvent>, EventFinderError>>
394 where
395 P: DiscreteEventPredicate + Sync,
396 {
397 predicates
398 .par_iter()
399 .map(|predicate| self.find_state_changes_ref(predicate))
400 .collect()
401 }
402
403 fn find_state_changes_ref<P>(
404 self,
405 predicate: &P,
406 ) -> Result<Vec<StateChangeEvent>, EventFinderError>
407 where
408 P: DiscreteEventPredicate + ?Sized,
409 {
410 let samples = self.state_samples(predicate)?;
411 let mut events = Vec::new();
412
413 for pair in samples.windows(2) {
414 let a = pair[0];
415 let b = pair[1];
416 if a.state == b.state {
417 continue;
418 }
419
420 let time_seconds =
421 self.refine_state_change(predicate, a.time_seconds, b.time_seconds, a.state);
422 if events.last().is_some_and(|event: &StateChangeEvent| {
423 event.time_seconds.to_bits() == time_seconds.to_bits()
424 }) {
425 continue;
426 }
427 events.push(StateChangeEvent {
428 time_seconds,
429 previous_state: a.state,
430 next_state: b.state,
431 });
432 }
433
434 Ok(events)
435 }
436
437 fn scalar_samples<P>(self, predicate: &P) -> Result<Vec<ScalarSample>, EventFinderError>
438 where
439 P: ScalarEventPredicate + ?Sized,
440 {
441 let duration_seconds = self.end_seconds - self.start_seconds;
442 let sample_iterations = self.coarse_sample_iterations()?;
443 let mut samples = Vec::with_capacity(sample_iterations.saturating_add(1));
444 let mut offset_seconds = 0.0;
445
446 for _ in 0..sample_iterations {
447 if offset_seconds >= duration_seconds {
448 break;
449 }
450 let time_seconds = self.start_seconds + offset_seconds;
451 if time_seconds >= self.end_seconds {
452 break;
453 }
454 samples.push(ScalarSample {
455 time_seconds,
456 value: finite_predicate_value(predicate.value_at(time_seconds))?,
457 });
458 let next_offset_seconds = offset_seconds + self.step_seconds;
459 if next_offset_seconds <= offset_seconds {
460 return Err(non_advancing_sample_step_error());
461 }
462 offset_seconds = next_offset_seconds;
463 }
464 if offset_seconds < duration_seconds
465 && self.start_seconds + offset_seconds < self.end_seconds
466 {
467 return Err(too_many_event_samples_error());
468 }
469
470 samples.push(ScalarSample {
471 time_seconds: self.end_seconds,
472 value: finite_predicate_value(predicate.value_at(self.end_seconds))?,
473 });
474 Ok(samples)
475 }
476
477 fn extrema_samples<P>(self, predicate: &P) -> Result<Vec<ScalarSample>, EventFinderError>
478 where
479 P: ScalarEventPredicate + ?Sized,
480 {
481 let mut samples = self.scalar_samples(predicate)?;
482 if samples.len() == 2 {
483 let midpoint = midpoint_seconds(samples[0].time_seconds, samples[1].time_seconds);
484 if midpoint != samples[0].time_seconds && midpoint != samples[1].time_seconds {
485 samples.insert(
486 1,
487 ScalarSample {
488 time_seconds: midpoint,
489 value: finite_predicate_value(predicate.value_at(midpoint))?,
490 },
491 );
492 }
493 }
494 Ok(samples)
495 }
496
497 fn state_samples<P>(self, predicate: &P) -> Result<Vec<StateSample>, EventFinderError>
498 where
499 P: DiscreteEventPredicate + ?Sized,
500 {
501 let duration_seconds = self.end_seconds - self.start_seconds;
502 let sample_iterations = self.coarse_sample_iterations()?;
503 let mut samples = Vec::with_capacity(sample_iterations.saturating_add(1));
504 let mut offset_seconds = 0.0;
505
506 for _ in 0..sample_iterations {
507 if offset_seconds >= duration_seconds {
508 break;
509 }
510 let time_seconds = self.start_seconds + offset_seconds;
511 if time_seconds >= self.end_seconds {
512 break;
513 }
514 samples.push(StateSample {
515 time_seconds,
516 state: predicate.state_at(time_seconds),
517 });
518 let next_offset_seconds = offset_seconds + self.step_seconds;
519 if next_offset_seconds <= offset_seconds {
520 return Err(non_advancing_sample_step_error());
521 }
522 offset_seconds = next_offset_seconds;
523 }
524 if offset_seconds < duration_seconds
525 && self.start_seconds + offset_seconds < self.end_seconds
526 {
527 return Err(too_many_event_samples_error());
528 }
529
530 samples.push(StateSample {
531 time_seconds: self.end_seconds,
532 state: predicate.state_at(self.end_seconds),
533 });
534 Ok(samples)
535 }
536
537 fn coarse_sample_iterations(self) -> Result<usize, EventFinderError> {
538 let duration_seconds = self.end_seconds - self.start_seconds;
539 if duration_seconds <= 0.0 {
540 return Ok(0);
541 }
542 if !duration_seconds.is_finite() {
543 return Err(too_many_event_samples_error());
544 }
545
546 let coarse_samples = (duration_seconds / self.step_seconds).ceil();
547 if !(coarse_samples.is_finite() && coarse_samples >= 1.0) {
548 return Err(too_many_event_samples_error());
549 }
550 if coarse_samples > MAX_EVENT_COARSE_SAMPLES as f64 {
551 return Err(too_many_event_samples_error());
552 }
553
554 Ok((coarse_samples as usize).saturating_add(1))
555 }
556
557 fn refine_extremum<P>(
558 self,
559 predicate: &P,
560 kind: ExtremumKind,
561 low: f64,
562 high: f64,
563 ) -> Result<ExtremumEvent, EventFinderError>
564 where
565 P: ScalarEventPredicate + ?Sized,
566 {
567 let mut lo = low;
568 let mut hi = high;
569
570 for _ in 0..MAX_GOLDEN_ITERATIONS {
571 if (hi - lo).abs() <= self.time_tolerance_seconds {
572 break;
573 }
574 let span = hi - lo;
575 let left = lo + GOLDEN_RESPHI * span;
576 let right = hi - GOLDEN_RESPHI * span;
577 if !(left > lo && right < hi) {
578 break;
579 }
580
581 let score_left =
582 extremum_score(kind, finite_predicate_value(predicate.value_at(left))?);
583 let score_right =
584 extremum_score(kind, finite_predicate_value(predicate.value_at(right))?);
585
586 let score_delta = (score_left - score_right).abs();
587 let score_scale = score_left.abs().max(score_right.abs()).max(1.0);
588 if score_delta <= 16.0 * f64::EPSILON * score_scale {
589 lo = left;
590 hi = right;
591 } else if score_left > score_right {
592 hi = right;
593 } else {
594 lo = left;
595 }
596 }
597
598 let time_seconds = midpoint_seconds(lo, hi);
599 let value = finite_predicate_value(predicate.value_at(time_seconds))?;
600 Ok(ExtremumEvent {
601 time_seconds,
602 value,
603 kind,
604 })
605 }
606
607 fn refine_state_change<P>(self, predicate: &P, low: f64, high: f64, low_state: bool) -> f64
608 where
609 P: DiscreteEventPredicate + ?Sized,
610 {
611 let mut lo = low;
612 let mut hi = high;
613
614 while (hi - lo).abs() > self.time_tolerance_seconds {
615 let mid = midpoint_seconds(lo, hi);
616 if mid == lo || mid == hi {
617 return mid;
618 }
619 let mid_state = predicate.state_at(mid);
620 if exact_state_transition_midpoint(predicate, lo, hi, mid, low_state, mid_state) {
621 return mid;
622 }
623 if mid_state == low_state {
624 lo = mid;
625 } else {
626 hi = mid;
627 }
628 }
629
630 midpoint_seconds(lo, hi)
631 }
632}
633
/// One coarse-grid evaluation of a scalar predicate.
#[derive(Debug, Clone, Copy)]
struct ScalarSample {
    // Sample time, in seconds.
    time_seconds: f64,
    // Finite predicate value at `time_seconds`.
    value: f64,
}
639
/// One coarse-grid evaluation of a discrete predicate.
#[derive(Debug, Clone, Copy)]
struct StateSample {
    // Sample time, in seconds.
    time_seconds: f64,
    // Predicate state at `time_seconds`.
    state: bool,
}
645
/// Returns the midpoint of `a` and `b`.
///
/// Whenever `a + b` is finite the result is exactly `(a + b) * 0.5`, so existing rounding
/// behaviour is preserved bit-for-bit. When the sum overflows (both operands near
/// `±f64::MAX`), each operand is halved first so the midpoint stays finite instead of becoming
/// `±inf`.
fn midpoint_seconds(a: f64, b: f64) -> f64 {
    let sum = a + b;
    if sum.is_finite() {
        sum * 0.5
    } else {
        a * 0.5 + b * 0.5
    }
}
649
650fn map_root_error(error: RootError<EventFinderError>) -> EventFinderError {
651 match error {
652 RootError::InvalidInput { field, reason } => {
653 EventFinderError::InvalidInput { field, reason }
654 }
655 RootError::Predicate(error) => error,
656 }
657}
658
659fn crossing_direction_for_sample_pair(
660 samples: &[ScalarSample],
661 left_index: usize,
662 threshold: f64,
663) -> Option<CrossingDirection> {
664 let value_a = samples[left_index].value - threshold;
665 let value_b = samples[left_index + 1].value - threshold;
666
667 if value_a == 0.0 || value_b == 0.0 {
668 return exact_sample_crossing_direction(samples, left_index, threshold, value_a, value_b);
669 }
670 if !sign_change_bracketed(value_a, value_b).unwrap_or(false) {
671 return None;
672 }
673 Some(crossing_direction_from_sides(value_a, value_b))
674}
675
676fn exact_sample_crossing_direction(
677 samples: &[ScalarSample],
678 left_index: usize,
679 threshold: f64,
680 value_a: f64,
681 value_b: f64,
682) -> Option<CrossingDirection> {
683 if value_a == 0.0 && value_b == 0.0 {
684 return None;
685 }
686
687 if value_b == 0.0 {
688 let right_value = first_nonzero_value_from(samples, left_index + 2, threshold);
689 return match right_value {
690 Some(right) => crossing_direction_from_opposite_sides(value_a, right),
691 None => Some(crossing_direction_from_sides(value_a, 0.0)),
692 };
693 }
694
695 let zero_run_start = zero_run_start(samples, left_index, threshold);
696 match last_nonzero_value_before(samples, zero_run_start, threshold) {
697 Some(_) => None,
698 None => Some(crossing_direction_from_sides(0.0, value_b)),
699 }
700}
701
702fn zero_run_start(samples: &[ScalarSample], zero_index: usize, threshold: f64) -> usize {
703 let mut index = zero_index;
704 while index > 0 && samples[index - 1].value - threshold == 0.0 {
705 index -= 1;
706 }
707 index
708}
709
710fn last_nonzero_value_before(
711 samples: &[ScalarSample],
712 end_index: usize,
713 threshold: f64,
714) -> Option<f64> {
715 samples[..end_index]
716 .iter()
717 .rev()
718 .map(|sample| sample.value - threshold)
719 .find(|value| *value != 0.0)
720}
721
722fn first_nonzero_value_from(
723 samples: &[ScalarSample],
724 start_index: usize,
725 threshold: f64,
726) -> Option<f64> {
727 samples
728 .iter()
729 .skip(start_index)
730 .map(|sample| sample.value - threshold)
731 .find(|value| *value != 0.0)
732}
733
734fn crossing_direction_from_opposite_sides(left: f64, right: f64) -> Option<CrossingDirection> {
735 if left < 0.0 && right > 0.0 {
736 Some(CrossingDirection::Rising)
737 } else if left > 0.0 && right < 0.0 {
738 Some(CrossingDirection::Falling)
739 } else {
740 None
741 }
742}
743
744fn crossing_direction_from_sides(left: f64, right: f64) -> CrossingDirection {
745 if left < 0.0 || (left == 0.0 && right > 0.0) {
746 CrossingDirection::Rising
747 } else {
748 CrossingDirection::Falling
749 }
750}
751
752fn exact_state_transition_midpoint<P>(
753 predicate: &P,
754 lo: f64,
755 hi: f64,
756 mid: f64,
757 low_state: bool,
758 mid_state: bool,
759) -> bool
760where
761 P: DiscreteEventPredicate + ?Sized,
762{
763 if mid_state == low_state {
764 predicate.state_at(adjacent_float_toward(mid, hi)) != low_state
765 } else {
766 predicate.state_at(adjacent_float_toward(mid, lo)) == low_state
767 }
768}
769
770fn adjacent_float_toward(value: f64, target: f64) -> f64 {
771 if target > value {
772 next_float_up(value)
773 } else {
774 next_float_down(value)
775 }
776}
777
/// Returns the smallest `f64` strictly greater than `value` (IEEE 754 `nextUp`).
///
/// Special cases:
/// - `+inf` and NaN are returned unchanged.
/// - Both `-0.0` and `+0.0` step to the smallest positive subnormal.
/// - `-inf` steps to `f64::MIN`.
fn next_float_up(value: f64) -> f64 {
    // NaN has no neighbour. Bit arithmetic on a NaN payload could produce an infinity
    // (e.g. bits 0x7FF0_0000_0000_0001 - 1 == +inf).
    if value.is_nan() || value == f64::INFINITY {
        return value;
    }
    let bits = value.to_bits();
    if bits == 0x8000_0000_0000_0000 {
        // -0.0: the next value up is the smallest positive subnormal.
        f64::from_bits(1)
    } else if value >= 0.0 {
        // Non-negative: a larger magnitude means a larger bit pattern.
        f64::from_bits(bits + 1)
    } else {
        // Negative: moving toward zero shrinks the magnitude bits.
        f64::from_bits(bits - 1)
    }
}
791
/// Returns the largest `f64` strictly less than `value` (IEEE 754 `nextDown`).
///
/// Special cases:
/// - `-inf` and NaN are returned unchanged.
/// - Both `+0.0` and `-0.0` step to the smallest negative subnormal.
/// - `+inf` steps to `f64::MAX`.
fn next_float_down(value: f64) -> f64 {
    // NaN has no neighbour. Bit arithmetic on a NaN could overflow (bits == u64::MAX panics in
    // debug builds) or wrap into -0.0 (0x7FFF_FFFF_FFFF_FFFF + 1).
    if value.is_nan() || value == f64::NEG_INFINITY {
        return value;
    }
    let bits = value.to_bits();
    if bits == 0 {
        // +0.0: the next value down is the smallest negative subnormal.
        f64::from_bits(0x8000_0000_0000_0001)
    } else if value > 0.0 {
        // Positive: a smaller magnitude means a smaller bit pattern.
        f64::from_bits(bits - 1)
    } else {
        // Negative (including -0.0): moving away from zero grows the magnitude bits.
        f64::from_bits(bits + 1)
    }
}
805
806fn extremum_score(kind: ExtremumKind, value: f64) -> f64 {
807 match kind {
808 ExtremumKind::Maximum => value,
809 ExtremumKind::Minimum => -value,
810 }
811}
812
813fn finite_predicate_value(value: f64) -> Result<f64, EventFinderError> {
814 validate::finite(value, "predicate").map_err(map_event_input)
815}
816
817fn too_many_event_samples_error() -> EventFinderError {
818 EventFinderError::InvalidInput {
819 field: "step_seconds",
820 reason: "too many samples",
821 }
822}
823
824fn non_advancing_sample_step_error() -> EventFinderError {
825 EventFinderError::InvalidInput {
826 field: "step_seconds",
827 reason: "does not advance samples",
828 }
829}
830
831fn map_event_input(error: validate::FieldError) -> EventFinderError {
832 EventFinderError::InvalidInput {
833 field: error.field(),
834 reason: error.reason(),
835 }
836}
837
838#[cfg(test)]
839mod tests {
840 use super::*;
841 use std::cell::Cell;
842 use std::f64::consts::{FRAC_PI_2, PI, TAU};
843
844 #[derive(Debug, Clone, Copy)]
845 struct ShiftedSine {
846 phase_seconds: f64,
847 }
848
849 impl ScalarEventPredicate for ShiftedSine {
850 fn value_at(&self, time_seconds: f64) -> f64 {
851 (time_seconds + self.phase_seconds).sin()
852 }
853 }
854
855 #[derive(Debug, Clone, Copy)]
856 struct StepState {
857 transition_seconds: f64,
858 }
859
860 impl DiscreteEventPredicate for StepState {
861 fn state_at(&self, time_seconds: f64) -> bool {
862 time_seconds >= self.transition_seconds
863 }
864 }
865
866 fn finder(start: f64, end: f64) -> EventFinder {
867 EventFinder::new(start, end, 0.2, 1.0e-12).expect("valid finder")
868 }
869
870 #[test]
871 fn scalar_samples_step_from_relative_offset_after_nonzero_start() {
872 let start = 1_000_000_000.0;
873 let step = 0.1;
874 let end = start + 0.5;
875 let samples = EventFinder::new(start, end, step, 1.0e-12)
876 .expect("valid finder")
877 .scalar_samples(&|time| time)
878 .expect("finite samples");
879 let expected_times = [
880 start,
881 start + step,
882 start + 2.0 * step,
883 start + 3.0 * step,
884 start + 4.0 * step,
885 end,
886 ];
887
888 assert_eq!(samples.len(), expected_times.len());
889 for (index, (sample, expected_time)) in samples.iter().zip(expected_times).enumerate() {
890 assert_eq!(
891 sample.time_seconds.to_bits(),
892 expected_time.to_bits(),
893 "sample {index} time"
894 );
895 assert_eq!(
896 sample.value.to_bits(),
897 expected_time.to_bits(),
898 "sample {index} value"
899 );
900 }
901 }
902
903 #[test]
904 fn scalar_samples_preserve_repeated_addition_near_endpoint() {
905 let samples = EventFinder::new(0.0, 1.0, 0.1, 1.0e-12)
906 .expect("valid finder")
907 .scalar_samples(&|time| time)
908 .expect("finite samples");
909
910 assert_eq!(samples.len(), 12);
911 assert_eq!(
912 samples[samples.len() - 2].time_seconds.to_bits(),
913 0.999_999_999_999_999_9_f64.to_bits()
914 );
915 assert_eq!(
916 samples
917 .last()
918 .expect("endpoint sample")
919 .time_seconds
920 .to_bits(),
921 1.0_f64.to_bits()
922 );
923 }
924
925 #[test]
926 fn scalar_samples_reject_infeasible_grid_without_sampling() {
927 let finder = EventFinder::new(0.0, 1.0, f64::MIN_POSITIVE, 1.0e-12).expect("valid finder");
928
929 assert_invalid_field(
930 finder.find_crossings(|_| 1.0, 0.0).unwrap_err(),
931 "step_seconds",
932 "too many samples",
933 );
934 assert_invalid_field(
935 finder.find_extrema(|time| time).unwrap_err(),
936 "step_seconds",
937 "too many samples",
938 );
939 }
940
941 #[test]
942 fn state_changes_reject_infeasible_grid_without_sampling() {
943 let finder = EventFinder::new(0.0, 1.0, f64::MIN_POSITIVE, 1.0e-12).expect("valid finder");
944 let state_calls = Cell::new(0);
945
946 assert_invalid_field(
947 finder
948 .find_state_changes(|time| {
949 state_calls.set(state_calls.get() + 1);
950 time >= 0.5
951 })
952 .unwrap_err(),
953 "step_seconds",
954 "too many samples",
955 );
956 assert_eq!(state_calls.get(), 0);
957
958 let predicates: [fn(f64) -> bool; 3] =
959 [|time| time >= 0.25, |time| time >= 0.5, |time| time >= 0.75];
960 let serial = finder.find_state_changes_batch_serial(&predicates);
961 let parallel = finder.find_state_changes_batch_parallel(&predicates);
962 assert_eq!(serial, parallel);
963 assert!(serial.iter().all(|result| {
964 matches!(
965 result,
966 Err(EventFinderError::InvalidInput {
967 field: "step_seconds",
968 reason: "too many samples"
969 })
970 )
971 }));
972 }
973
974 #[test]
975 fn crossings_find_sine_zeroes_with_direction() {
976 let events = finder(-0.4, TAU + 0.4)
977 .find_crossings(f64::sin, 0.0)
978 .expect("finite sine samples");
979
980 assert_eq!(events.len(), 3);
981 assert_close(events[0].time_seconds, 0.0, 1.0e-12);
982 assert_eq!(events[0].direction, CrossingDirection::Rising);
983 assert_close(events[0].value, 0.0, 1.0e-12);
984
985 assert_close(events[1].time_seconds, PI, 1.0e-12);
986 assert_eq!(events[1].direction, CrossingDirection::Falling);
987 assert_close(events[1].value, 0.0, 1.0e-12);
988
989 assert_close(events[2].time_seconds, TAU, 1.0e-12);
990 assert_eq!(events[2].direction, CrossingDirection::Rising);
991 assert_close(events[2].value, 0.0, 1.0e-12);
992 }
993
994 #[test]
995 fn crossings_suppress_tangential_threshold_touch() {
996 let tangent_from_above_events = EventFinder::new(0.0, 2.0, 1.0, 1.0e-12)
997 .expect("valid finder")
998 .find_crossings(|time: f64| (time - 1.0) * (time - 1.0), 0.0)
999 .expect("finite tangent samples");
1000
1001 assert!(tangent_from_above_events.is_empty());
1002
1003 let tangent_from_below_events = EventFinder::new(0.0, 2.0, 1.0, 1.0e-12)
1004 .expect("valid finder")
1005 .find_crossings(|time: f64| -(time - 1.0) * (time - 1.0), 0.0)
1006 .expect("finite tangent samples");
1007
1008 assert!(tangent_from_below_events.is_empty());
1009
1010 let crossing_events = EventFinder::new(0.0, 2.0, 0.25, 1.0e-12)
1011 .expect("valid finder")
1012 .find_crossings(|time: f64| 0.25 - (time - 1.0) * (time - 1.0), 0.0)
1013 .expect("finite crossing samples");
1014
1015 assert_eq!(crossing_events.len(), 2);
1016 assert_eq!(crossing_events[0].direction, CrossingDirection::Rising);
1017 assert_eq!(crossing_events[0].time_seconds.to_bits(), 0.5_f64.to_bits());
1018 assert_eq!(crossing_events[1].direction, CrossingDirection::Falling);
1019 assert_eq!(crossing_events[1].time_seconds.to_bits(), 1.5_f64.to_bits());
1020 }
1021
1022 #[test]
1023 fn crossings_detect_opposite_side_threshold_plateaus() {
1024 let rising_events = plateau_crossings([-1.0, 0.0, 0.0, 1.0]);
1025 assert_eq!(rising_events.len(), 1);
1026 assert_eq!(rising_events[0].direction, CrossingDirection::Rising);
1027 assert_eq!(rising_events[0].time_seconds.to_bits(), 1.0_f64.to_bits());
1028
1029 let falling_events = plateau_crossings([1.0, 0.0, 0.0, -1.0]);
1030 assert_eq!(falling_events.len(), 1);
1031 assert_eq!(falling_events[0].direction, CrossingDirection::Falling);
1032 assert_eq!(falling_events[0].time_seconds.to_bits(), 1.0_f64.to_bits());
1033 }
1034
1035 #[test]
1036 fn crossings_emit_boundary_threshold_plateaus_at_start() {
1037 let rising_events = plateau_crossings([0.0, 0.0, 1.0]);
1038 assert_eq!(rising_events.len(), 1);
1039 assert_eq!(rising_events[0].direction, CrossingDirection::Rising);
1040 assert_eq!(rising_events[0].time_seconds.to_bits(), 0.0_f64.to_bits());
1041
1042 let falling_events = plateau_crossings([0.0, 0.0, -1.0]);
1043 assert_eq!(falling_events.len(), 1);
1044 assert_eq!(falling_events[0].direction, CrossingDirection::Falling);
1045 assert_eq!(falling_events[0].time_seconds.to_bits(), 0.0_f64.to_bits());
1046 }
1047
1048 #[test]
1049 fn crossings_suppress_same_side_threshold_plateaus() {
1050 assert!(plateau_crossings([-1.0, 0.0, 0.0, -1.0]).is_empty());
1051 assert!(plateau_crossings([1.0, 0.0, 0.0, 1.0]).is_empty());
1052 }
1053
1054 fn plateau_crossings<const N: usize>(values: [f64; N]) -> Vec<CrossingEvent> {
1055 EventFinder::new(0.0, (N - 1) as f64, 1.0, 1.0e-12)
1056 .expect("valid finder")
1057 .find_crossings(
1058 move |time: f64| {
1059 let index = time.round() as usize;
1060 assert!(index < N);
1061 assert_eq!(time.to_bits(), (index as f64).to_bits());
1062 values[index]
1063 },
1064 0.0,
1065 )
1066 .expect("finite plateau samples")
1067 }
1068
1069 #[test]
1070 fn crossings_detect_exact_right_endpoint_once() {
1071 let final_endpoint_events = EventFinder::new(0.0, 1.0, 1.0, 1.0e-12)
1072 .expect("valid finder")
1073 .find_crossings(|time| 1.0 - time, 0.0)
1074 .expect("finite endpoint samples");
1075
1076 assert_eq!(final_endpoint_events.len(), 1);
1077 assert_eq!(
1078 final_endpoint_events[0].time_seconds.to_bits(),
1079 1.0_f64.to_bits()
1080 );
1081 assert_eq!(
1082 final_endpoint_events[0].direction,
1083 CrossingDirection::Falling
1084 );
1085
1086 let shared_endpoint_events = EventFinder::new(0.0, 2.0, 1.0, 1.0e-12)
1087 .expect("valid finder")
1088 .find_crossings(|time| 1.0 - time, 0.0)
1089 .expect("finite endpoint samples");
1090
1091 assert_eq!(shared_endpoint_events.len(), 1);
1092 assert_eq!(
1093 shared_endpoint_events[0].time_seconds.to_bits(),
1094 1.0_f64.to_bits()
1095 );
1096 assert_eq!(
1097 shared_endpoint_events[0].direction,
1098 CrossingDirection::Falling
1099 );
1100
1101 let interior_events = EventFinder::new(0.0, 1.0, 1.0, 1.0e-12)
1102 .expect("valid finder")
1103 .find_crossings(|time| 0.5 - time, 0.0)
1104 .expect("finite interior samples");
1105
1106 assert_eq!(interior_events.len(), 1);
1107 assert_close(interior_events[0].time_seconds, 0.5, 1.0e-12);
1108 assert_eq!(interior_events[0].direction, CrossingDirection::Falling);
1109 }
1110
1111 #[test]
1112 fn crossings_detect_exact_start_endpoint_once() {
1113 let start = 12.0;
1114 let start_endpoint_events = EventFinder::new(start, start + 1.0, 0.25, 1.0e-12)
1115 .expect("valid finder")
1116 .find_crossings(|time| time - start, 0.0)
1117 .expect("finite endpoint samples");
1118
1119 assert_eq!(start_endpoint_events.len(), 1);
1120 assert_eq!(
1121 start_endpoint_events[0].time_seconds.to_bits(),
1122 start.to_bits()
1123 );
1124 assert_eq!(
1125 start_endpoint_events[0].direction,
1126 CrossingDirection::Rising
1127 );
1128
1129 let interior_events = EventFinder::new(start, start + 1.0, 0.5, 1.0e-12)
1130 .expect("valid finder")
1131 .find_crossings(|time| time - (start + 0.5), 0.0)
1132 .expect("finite endpoint samples");
1133
1134 assert_eq!(interior_events.len(), 1);
1135 assert_eq!(
1136 interior_events[0].time_seconds.to_bits(),
1137 (start + 0.5_f64).to_bits()
1138 );
1139 assert_eq!(interior_events[0].direction, CrossingDirection::Rising);
1140 }
1141
1142 #[test]
1143 fn extrema_find_sine_maximum_and_minimum() {
1144 let events = EventFinder::new(0.0, TAU, 0.2, 1.0e-8)
1145 .expect("valid finder")
1146 .find_extrema(f64::sin)
1147 .expect("finite sine samples");
1148
1149 assert_eq!(events.len(), 2);
1150 assert_eq!(events[0].kind, ExtremumKind::Maximum);
1151 assert_close(events[0].time_seconds, FRAC_PI_2, 5.0e-8);
1152 assert_close(events[0].value, 1.0, 1.0e-12);
1153
1154 assert_eq!(events[1].kind, ExtremumKind::Minimum);
1155 assert_close(events[1].time_seconds, 3.0 * FRAC_PI_2, 5.0e-8);
1156 assert_close(events[1].value, -1.0, 1.0e-12);
1157 }
1158
1159 #[test]
1160 fn extrema_detect_short_window_inside_single_coarse_step() {
1161 let events = EventFinder::new(0.0, 1.0, 10.0, 1.0e-12)
1162 .expect("valid finder")
1163 .find_extrema(|time: f64| -(time - 0.3) * (time - 0.3))
1164 .expect("finite parabola samples");
1165
1166 assert_eq!(events.len(), 1);
1167 assert_eq!(events[0].kind, ExtremumKind::Maximum);
1168 assert_close(events[0].time_seconds, 0.3, 1.0e-8);
1169 assert_close(events[0].value, 0.0, 1.0e-12);
1170 }
1171
1172 #[test]
1173 fn extrema_deduplicate_flat_minimum_and_maximum() {
1174 let minima = sampled_extrema([2.0, 1.0, 1.0, 2.0]);
1175 assert_eq!(minima.len(), 1);
1176 assert_eq!(minima[0].kind, ExtremumKind::Minimum);
1177 assert!((1.0..=2.0).contains(&minima[0].time_seconds));
1178 assert_close(minima[0].value, 1.0, 1.0e-12);
1179
1180 let maxima = sampled_extrema([1.0, 2.0, 2.0, 1.0]);
1181 assert_eq!(maxima.len(), 1);
1182 assert_eq!(maxima[0].kind, ExtremumKind::Maximum);
1183 assert!((1.0..=2.0).contains(&maxima[0].time_seconds));
1184 assert_close(maxima[0].value, 2.0, 1.0e-12);
1185 }
1186
1187 #[test]
1188 fn extrema_keep_distinct_adjacent_minima() {
1189 let events = sampled_extrema([2.0, 1.0, 2.0, 1.0, 2.0]);
1190 let minima: Vec<_> = events
1191 .iter()
1192 .filter(|event| event.kind == ExtremumKind::Minimum)
1193 .collect();
1194
1195 assert_eq!(minima.len(), 2);
1196 assert_close(minima[0].time_seconds, 1.0, 1.0e-8);
1197 assert_close(minima[1].time_seconds, 3.0, 1.0e-8);
1198 }
1199
1200 #[test]
1201 fn state_changes_find_step_transition() {
1202 let events = EventFinder::new(0.0, 5.0, 1.0, 1.0e-9)
1203 .expect("valid finder")
1204 .find_state_changes(|time| time >= 2.5)
1205 .expect("state changes");
1206
1207 assert_eq!(events.len(), 1);
1208 assert_close(events[0].time_seconds, 2.5, 1.0e-9);
1209 assert!(!events[0].previous_state);
1210 assert!(events[0].next_state);
1211 }
1212
1213 #[test]
1214 fn state_change_refinement_returns_exact_midpoint_transition() {
1215 let events = EventFinder::new(0.0, 2.0, 2.0, 1.0e-12)
1216 .expect("valid finder")
1217 .find_state_changes(|time| time >= 1.0)
1218 .expect("state changes");
1219
1220 assert_eq!(events.len(), 1);
1221 assert_eq!(events[0].time_seconds.to_bits(), 1.0_f64.to_bits());
1222 assert!(!events[0].previous_state);
1223 assert!(events[0].next_state);
1224 }
1225
1226 #[test]
1227 fn state_changes_keep_sampling_inside_window() {
1228 let start = 12.0;
1229 let end = start + 1.0;
1230 let min_seen = Cell::new(f64::INFINITY);
1231 let max_seen = Cell::new(f64::NEG_INFINITY);
1232 let transition_seconds = start + 0.65;
1233
1234 let events = EventFinder::new(start, end, 0.4, 1.0e-12)
1235 .expect("valid finder")
1236 .find_state_changes(|time| {
1237 min_seen.set(min_seen.get().min(time));
1238 max_seen.set(max_seen.get().max(time));
1239 time >= transition_seconds
1240 })
1241 .expect("state changes");
1242
1243 assert_eq!(events.len(), 1);
1244 assert!((start..=end).contains(&events[0].time_seconds));
1245 assert_close(events[0].time_seconds, transition_seconds, 1.0e-12);
1246 assert!(!events[0].previous_state);
1247 assert!(events[0].next_state);
1248 assert!(min_seen.get() >= start);
1249 assert!(max_seen.get() <= end);
1250 }
1251
1252 #[test]
1253 fn state_change_refinement_stops_when_midpoint_cannot_shrink_bracket() {
1254 let high = 1.0_f64;
1255 let low = f64::from_bits(high.to_bits() - 1);
1256 let finder = EventFinder::new(low, high, high - low, f64::MIN_POSITIVE)
1257 .expect("valid adjacent-float finder");
1258 let state_calls = Cell::new(0);
1259
1260 let transition = finder.refine_state_change(
1261 &|time| {
1262 state_calls.set(state_calls.get() + 1);
1263 time >= high
1264 },
1265 low,
1266 high,
1267 false,
1268 );
1269
1270 assert_eq!(transition.to_bits(), high.to_bits());
1271 assert_eq!(state_calls.get(), 0);
1272 }
1273
1274 #[test]
1275 fn batch_parallel_matches_serial_in_input_order() {
1276 let wave_finder =
1277 EventFinder::new(-0.8, TAU + 0.8, 0.2, 1.0e-10).expect("valid wave finder");
1278 let waves = [
1279 ShiftedSine { phase_seconds: 0.0 },
1280 ShiftedSine {
1281 phase_seconds: 0.35,
1282 },
1283 ShiftedSine {
1284 phase_seconds: -0.45,
1285 },
1286 ShiftedSine { phase_seconds: 0.7 },
1287 ];
1288
1289 let crossing_serial = wave_finder.find_crossings_batch_serial(&waves, 0.0);
1290 let crossing_parallel = wave_finder.find_crossings_batch_parallel(&waves, 0.0);
1291 assert_eq!(crossing_serial, crossing_parallel);
1292 assert_eq!(crossing_serial.len(), waves.len());
1293 assert!(crossing_serial
1294 .iter()
1295 .all(|events| events.as_ref().is_ok_and(|events| !events.is_empty())));
1296
1297 let extrema_serial = wave_finder.find_extrema_batch_serial(&waves);
1298 let extrema_parallel = wave_finder.find_extrema_batch_parallel(&waves);
1299 assert_eq!(extrema_serial, extrema_parallel);
1300 assert_eq!(extrema_serial.len(), waves.len());
1301 assert!(extrema_serial
1302 .iter()
1303 .all(|events| events.as_ref().is_ok_and(|events| events.len() >= 2)));
1304
1305 let state_finder = EventFinder::new(0.0, 5.0, 0.25, 1.0e-10).expect("valid state finder");
1306 let states = [
1307 StepState {
1308 transition_seconds: 0.6,
1309 },
1310 StepState {
1311 transition_seconds: 1.9,
1312 },
1313 StepState {
1314 transition_seconds: 3.4,
1315 },
1316 StepState {
1317 transition_seconds: 4.75,
1318 },
1319 ];
1320 let state_serial = state_finder.find_state_changes_batch_serial(&states);
1321 let state_parallel = state_finder.find_state_changes_batch_parallel(&states);
1322 assert_eq!(state_serial, state_parallel);
1323 assert_eq!(state_serial.len(), states.len());
1324 for (result, predicate) in state_serial.iter().zip(states.iter()) {
1325 let events = result.as_ref().expect("state changes");
1326 assert_eq!(events.len(), 1);
1327 assert_close(
1328 events[0].time_seconds,
1329 predicate.transition_seconds,
1330 1.0e-10,
1331 );
1332 }
1333 }
1334
1335 #[test]
1336 fn invalid_window_and_steps_are_rejected() {
1337 assert_invalid_field(
1338 EventFinder::new(1.0, 0.0, 1.0, 1.0).unwrap_err(),
1339 "end_seconds",
1340 "out of range",
1341 );
1342 assert_invalid_field(
1343 EventFinder::new(0.0, 1.0, 0.0, 1.0).unwrap_err(),
1344 "step_seconds",
1345 "not positive",
1346 );
1347 assert_invalid_field(
1348 EventFinder::new(0.0, 1.0, 1.0, 0.0).unwrap_err(),
1349 "time_tolerance_seconds",
1350 "not positive",
1351 );
1352 }
1353
1354 #[test]
1355 fn non_finite_scalar_inputs_are_rejected() {
1356 let finder = EventFinder::new(0.0, 1.0, 0.5, 1.0e-9).expect("valid finder");
1357 assert_invalid_field(
1358 finder.find_crossings(|time| time, f64::NAN).unwrap_err(),
1359 "threshold",
1360 "not finite",
1361 );
1362 assert_invalid_field(
1363 finder
1364 .find_extrema(|time| if time < 0.5 { time } else { f64::NAN })
1365 .unwrap_err(),
1366 "predicate",
1367 "not finite",
1368 );
1369 }
1370
1371 #[test]
1372 fn crossings_reject_non_finite_midpoint_values() {
1373 let finder = EventFinder::new(0.0, 2.0, 2.0, 0.25).expect("valid finder");
1374 assert_invalid_field(
1375 finder
1376 .find_crossings(
1377 |time| {
1378 if time == 1.0 {
1379 f64::NAN
1380 } else {
1381 time - 1.0
1382 }
1383 },
1384 0.0,
1385 )
1386 .unwrap_err(),
1387 "predicate",
1388 "not finite",
1389 );
1390
1391 let crossing = finder
1392 .find_crossings(|time| time - 1.0, 0.0)
1393 .expect("finite midpoint predicate should resolve normally");
1394 assert_eq!(crossing.len(), 1);
1395 assert_close(crossing[0].time_seconds, 1.0, 0.25);
1396 }
1397
1398 fn assert_invalid_field(
1399 error: EventFinderError,
1400 expected_field: &'static str,
1401 expected_reason: &'static str,
1402 ) {
1403 let EventFinderError::InvalidInput { field, reason } = error;
1404 assert_eq!(field, expected_field);
1405 assert_eq!(reason, expected_reason);
1406 }
1407
1408 fn sampled_extrema<const N: usize>(values: [f64; N]) -> Vec<ExtremumEvent> {
1409 assert!(N >= 2);
1410 EventFinder::new(0.0, (N - 1) as f64, 1.0, 1.0e-12)
1411 .expect("valid finder")
1412 .find_extrema(move |time| piecewise_linear_sample(&values, time))
1413 .expect("finite sample extrema")
1414 }
1415
1416 fn piecewise_linear_sample<const N: usize>(values: &[f64; N], time: f64) -> f64 {
1417 if time <= 0.0 {
1418 return values[0];
1419 }
1420
1421 let last_index = N - 1;
1422 let last_time = last_index as f64;
1423 if time >= last_time {
1424 return values[last_index];
1425 }
1426
1427 let left_index = time.floor() as usize;
1428 let fraction = time - left_index as f64;
1429 values[left_index] + (values[left_index + 1] - values[left_index]) * fraction
1430 }
1431
1432 fn assert_close(actual: f64, expected: f64, tolerance: f64) {
1433 assert!(
1434 (actual - expected).abs() <= tolerance,
1435 "{actual} differs from {expected} by more than {tolerance}"
1436 );
1437 }
1438}