Skip to main content

scirs2_integrate/ode/utils/
events.rs

1//! Event detection and handling for ODE solvers
2//!
//! This module provides functionality for detecting events during ODE integration
//! and handling them appropriately. An event is a zero-crossing of a user-supplied
//! event function `g(t, y)`. Every detected event is recorded (time, state and
//! event-function value); depending on its [`EventAction`], integration then either
//! continues or stops.
7
8use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use crate::ode::utils::dense_output::DenseSolution;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12
/// Which sign changes of an event function are reported as events.
///
/// The default is [`EventDirection::Both`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EventDirection {
    /// The event function goes from negative to zero or positive.
    Rising,
    /// The event function goes from positive to zero or negative.
    Falling,
    /// Either a rising or a falling crossing.
    #[default]
    Both,
}
24
/// What the integrator should do once an event has fired.
///
/// The default, [`EventAction::Continue`], only records the event.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EventAction {
    /// Record the event and keep integrating.
    #[default]
    Continue,
    /// Record the event and end the integration.
    Stop,
}
34
/// Definition of an event to detect during integration.
///
/// An event fires when its event function changes sign between two consecutive
/// states passed to the event handler, in the direction selected by `direction`.
#[derive(Debug, Clone)]
pub struct EventSpec<F: IntegrateFloat> {
    /// Identifier for this event type. Detected events and the per-id counts
    /// kept in `EventRecord` are keyed by this string.
    pub id: String,
    /// Direction of zero-crossing to detect
    pub direction: EventDirection,
    /// Action to take when the event is detected
    pub action: EventAction,
    /// Tolerance intended to suppress chattering near the zero level.
    ///
    /// NOTE(review): `EventHandler` in this module never reads this field;
    /// crossing detection uses exact sign comparisons. Confirm whether any other
    /// code consults it.
    pub threshold: F,
    /// Maximum number of detections (`None` for unlimited). Once the limit is
    /// reached, the handler skips evaluating this event function.
    pub max_count: Option<usize>,
    /// Whether to refine the event time by bisection on the dense output
    /// (this only happens when dense output is supplied to the handler).
    pub precise_time: bool,
}
51
52impl<F: IntegrateFloat> EventSpec<F> {
53    /// Check if the maximum count has been reached for this event
54    pub fn max_count_reached(&self, currentcount: Option<usize>) -> bool {
55        if let Some(max) = self.max_count {
56            if let Some(_count) = currentcount {
57                return _count >= max;
58            }
59        }
60        false
61    }
62}
63
64impl<F: IntegrateFloat> Default for EventSpec<F> {
65    fn default() -> Self {
66        EventSpec {
67            id: "default".to_string(),
68            direction: EventDirection::default(),
69            action: EventAction::default(),
70            threshold: F::from_f64(1e-6).expect("Operation failed"),
71            max_count: None,
72            precise_time: true,
73        }
74    }
75}
76
/// A single event detected during integration.
#[derive(Debug, Clone)]
pub struct Event<F: IntegrateFloat> {
    /// `id` of the `EventSpec` that was triggered
    pub id: String,
    /// Time of the event. `EventHandler::check_events` stores a
    /// bisection-refined time when refinement is enabled and dense output is
    /// available; otherwise it stores the end time of the step in which the
    /// sign change was detected.
    pub time: F,
    /// State at `time`
    pub state: Array1<F>,
    /// Value of the event function at `time` (near zero only when refined)
    pub value: F,
    /// Direction of the zero-crossing: `1` for rising, `-1` for falling
    pub direction: i8,
}
91
/// Log of every event detected during integration.
#[derive(Debug, Clone)]
pub struct EventRecord<F: IntegrateFloat> {
    /// All detected events, in the order they were added
    pub events: Vec<Event<F>>,
    /// Number of recorded events per event id; ids that never fired have no entry
    pub counts: std::collections::HashMap<String, usize>,
}
100
101impl<F: IntegrateFloat> Default for EventRecord<F> {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl<F: IntegrateFloat> EventRecord<F> {
108    /// Create a new empty event record
109    pub fn new() -> Self {
110        EventRecord {
111            events: Vec::new(),
112            counts: std::collections::HashMap::new(),
113        }
114    }
115
116    /// Add a detected event to the record
117    pub fn add_event(&mut self, event: Event<F>) {
118        // Update count for this _event type
119        *self.counts.entry(event.id.clone()).or_insert(0) += 1;
120
121        // Add to the list of events
122        self.events.push(event);
123    }
124
125    /// Get the count of a specific event type
126    pub fn get_count(&self, id: &str) -> usize {
127        *self.counts.get(id).unwrap_or(&0)
128    }
129
130    /// Get all events of a specific type
131    pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
132        self.events.iter().filter(|e| e.id == id).collect()
133    }
134
135    /// Check if the maximum event count has been reached for a specific event type
136    pub fn max_count_reached(&self, _id: &str, maxcount: Option<usize>) -> bool {
137        if let Some(max) = maxcount {
138            self.get_count(_id) >= max
139        } else {
140            false
141        }
142    }
143}
144
/// Event detection and handling during ODE integration.
///
/// Holds the event specifications, the log of detected events, and the event
/// function values and state from the previous call, which are needed to spot
/// sign changes.
#[derive(Debug)]
pub struct EventHandler<F: IntegrateFloat> {
    /// Event specifications to detect (one event function is expected per spec)
    pub specs: Vec<EventSpec<F>>,
    /// Record of detected events
    pub record: EventRecord<F>,
    /// Event-function value at `last_state` for each spec; `None` until evaluated
    last_values: Vec<Option<F>>,
    /// `(t, y)` from the most recent `initialize`/`check_events` call
    last_state: Option<(F, Array1<F>)>,
}
157
158impl<F: IntegrateFloat> EventHandler<F> {
159    /// Create a new event handler with the given event specifications
160    pub fn new(specs: Vec<EventSpec<F>>) -> Self {
161        let last_values = vec![None; specs.len()];
162
163        EventHandler {
164            specs,
165            record: EventRecord::new(),
166            last_values,
167            last_state: None,
168        }
169    }
170
171    /// Initialize the event handler with the initial state
172    pub fn initialize<Func>(
173        &mut self,
174        t: F,
175        y: &Array1<F>,
176        event_funcs: &[Func],
177    ) -> IntegrateResult<()>
178    where
179        Func: Fn(F, ArrayView1<F>) -> F,
180    {
181        // Store initial state
182        self.last_state = Some((t, y.clone()));
183
184        // Initialize last event function values
185        for (i, func) in event_funcs.iter().enumerate() {
186            let value = func(t, y.view());
187            self.last_values[i] = Some(value);
188        }
189
190        Ok(())
191    }
192
193    /// Check for events between the last state and the current state
194    pub fn check_events<Func>(
195        &mut self,
196        t: F,
197        y: &Array1<F>,
198        dense_output: Option<&DenseSolution<F>>,
199        event_funcs: &[Func],
200    ) -> IntegrateResult<EventAction>
201    where
202        Func: Fn(F, ArrayView1<F>) -> F,
203    {
204        if event_funcs.len() != self.specs.len() {
205            return Err(IntegrateError::ValueError(
206                "Number of event functions does not match number of event specifications"
207                    .to_string(),
208            ));
209        }
210
211        if self.last_state.is_none() {
212            // Initialize if not done already
213            self.initialize(t, y, event_funcs)?;
214            return Ok(EventAction::Continue);
215        }
216
217        let (t_prev, y_prev) = self.last_state.as_ref().expect("Operation failed");
218
219        // Check each event
220        let mut action = EventAction::Continue;
221
222        for (i, (func, spec)) in event_funcs.iter().zip(self.specs.iter()).enumerate() {
223            // Skip if we've already reached the maximum count for this event
224            if spec.max_count_reached(self.record.counts.get(&spec.id).cloned()) {
225                continue;
226            }
227
228            // Compute current value
229            let value = func(t, y.view());
230
231            // Check if we have a previous value
232            if let Some(prev_value) = self.last_values[i] {
233                // Check if event occurred (zero-crossing)
234                let rising = prev_value < F::zero() && value >= F::zero();
235                let falling = prev_value > F::zero() && value <= F::zero();
236
237                let triggered = match spec.direction {
238                    EventDirection::Rising => rising,
239                    EventDirection::Falling => falling,
240                    EventDirection::Both => rising || falling,
241                };
242
243                if triggered {
244                    // Refine the event time if requested and dense_output is available
245                    let (event_t, event_y, event_val, dir) =
246                        if let (true, Some(dense)) = (spec.precise_time, dense_output) {
247                            self.refine_event_time(
248                                *t_prev, y_prev, t, y, prev_value, value, func, dense,
249                            )?
250                        } else {
251                            // Use current time as event time (less accurate)
252                            let dir = if rising { 1 } else { -1 };
253                            (t, y.clone(), value, dir)
254                        };
255
256                    // Create event record
257                    let event = Event {
258                        id: spec.id.clone(),
259                        time: event_t,
260                        state: event_y,
261                        value: event_val,
262                        direction: dir,
263                    };
264
265                    // Add to record
266                    self.record.add_event(event);
267
268                    // If this event requires stopping, set the action
269                    if spec.action == EventAction::Stop {
270                        action = EventAction::Stop;
271                    }
272                }
273            }
274
275            // Update last value
276            self.last_values[i] = Some(value);
277        }
278
279        // Update last state
280        self.last_state = Some((t, y.clone()));
281
282        Ok(action)
283    }
284
285    /// Refine the exact time of an event using bisection on the dense output
286    #[allow(clippy::too_many_arguments)]
287    fn refine_event_time<Func>(
288        &self,
289        t_prev: F,
290        y_prev: &Array1<F>,
291        t_curr: F,
292        y_curr: &Array1<F>,
293        value_prev: F,
294        value_curr: F,
295        event_func: &Func,
296        dense_output: &DenseSolution<F>,
297    ) -> IntegrateResult<(F, Array1<F>, F, i8)>
298    where
299        Func: Fn(F, ArrayView1<F>) -> F,
300    {
301        // Determine event direction
302        let direction: i8 = if value_prev < F::zero() && value_curr >= F::zero() {
303            1 // Rising
304        } else {
305            -1 // Falling
306        };
307
308        // Root-finding tolerance
309        let tol = F::from_f64(1e-10).expect("Operation failed");
310        let max_iter = 50;
311
312        // Bisection search for zero-crossing
313        let mut t_left = t_prev;
314        let mut t_right = t_curr;
315        let mut f_left = value_prev;
316        let f_right = value_curr;
317
318        // Handle the case where one endpoint is exactly zero
319        if f_left.abs() < tol {
320            return Ok((t_left, y_prev.clone(), f_left, direction));
321        }
322
323        if f_right.abs() < tol {
324            return Ok((t_right, y_curr.clone(), f_right, direction));
325        }
326
327        // Bisection loop
328        let mut t_mid = F::zero();
329        let mut y_mid = Array1::<F>::zeros(y_prev.len());
330        let mut f_mid = F::zero();
331
332        for _ in 0..max_iter {
333            // Compute midpoint time
334            t_mid = (t_left + t_right) / F::from_f64(2.0).expect("Operation failed");
335
336            // Get state at midpoint using dense _output
337            y_mid = dense_output.evaluate(t_mid)?;
338
339            // Evaluate event function at midpoint
340            f_mid = event_func(t_mid, y_mid.view());
341
342            // Check convergence
343            if f_mid.abs() < tol || (t_right - t_left).abs() < tol {
344                break;
345            }
346
347            // Update interval
348            if f_left * f_mid < F::zero() {
349                t_right = t_mid;
350                let _f_right = f_mid;
351            } else {
352                t_left = t_mid;
353                f_left = f_mid;
354            }
355        }
356
357        Ok((t_mid, y_mid, f_mid, direction))
358    }
359
360    /// Get the record of all detected events
361    pub fn get_record(&self) -> &EventRecord<F> {
362        &self.record
363    }
364
365    /// Check if an event occurred that requires stopping the integration
366    pub fn should_stop(&self) -> bool {
367        self.record.events.iter().any(|e| {
368            let spec = self
369                .specs
370                .iter()
371                .find(|s| s.id == e.id)
372                .expect("Operation failed");
373            spec.action == EventAction::Stop
374        })
375    }
376}
377
378/// Function to create a terminal event (one that stops integration when triggered)
379#[allow(dead_code)]
380pub fn terminal_event<F: IntegrateFloat>(id: &str, direction: EventDirection) -> EventSpec<F> {
381    EventSpec {
382        id: id.to_string(),
383        direction,
384        action: EventAction::Stop,
385        threshold: F::from_f64(1e-6).expect("Operation failed"),
386        max_count: Some(1),
387        precise_time: true,
388    }
389}
390
/// Base `ODEOptions` paired with the event specifications to monitor.
#[derive(Debug, Clone)]
pub struct ODEOptionsWithEvents<F: IntegrateFloat> {
    /// Base ODE options
    pub base_options: super::super::types::ODEOptions<F>,
    /// Event specifications
    pub event_specs: Vec<EventSpec<F>>,
}
399
400impl<F: IntegrateFloat> ODEOptionsWithEvents<F> {
401    /// Create options with events from base options
402    pub fn new(
403        base_options: super::super::types::ODEOptions<F>,
404        event_specs: Vec<EventSpec<F>>,
405    ) -> Self {
406        ODEOptionsWithEvents {
407            base_options,
408            event_specs,
409        }
410    }
411}
412
/// ODE result extended with the detected events and optional dense output.
#[derive(Debug)]
pub struct ODEResultWithEvents<F: IntegrateFloat> {
    /// Base ODE result
    pub base_result: super::super::types::ODEResult<F>,
    /// Record of detected events
    pub events: EventRecord<F>,
    /// Dense output for the solution (if available); used by `at_time`
    pub dense_output: Option<DenseSolution<F>>,
    /// Whether integration terminated due to an event
    pub event_termination: bool,
}
425
426impl<F: IntegrateFloat> ODEResultWithEvents<F> {
427    /// Create a new result with events
428    pub fn new(
429        base_result: super::super::types::ODEResult<F>,
430        events: EventRecord<F>,
431        dense_output: Option<DenseSolution<F>>,
432        event_termination: bool,
433    ) -> Self {
434        ODEResultWithEvents {
435            base_result,
436            events,
437            dense_output,
438            event_termination,
439        }
440    }
441
442    /// Get the solution at a specific time using dense output
443    pub fn at_time(&self, t: F) -> IntegrateResult<Option<Array1<F>>> {
444        if let Some(ref dense) = self.dense_output {
445            Ok(Some(dense.evaluate(t)?))
446        } else {
447            // If no dense output available, check if we have exact time points
448            for (i, &ti) in self.base_result.t.iter().enumerate() {
449                if (ti - t).abs() < F::from_f64(1e-10).expect("Operation failed") {
450                    return Ok(Some(self.base_result.y[i].clone()));
451                }
452            }
453            Ok(None)
454        }
455    }
456
457    /// Get events of a specific type
458    pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
459        self.events.get_events(id)
460    }
461
462    /// Get the first occurrence of a specific event
463    pub fn first_event(&self, id: &str) -> Option<&Event<F>> {
464        self.events.get_events(id).first().copied()
465    }
466}