scirs2_integrate/ode/utils/
events.rs

1//! Event detection and handling for ODE solvers
2//!
3//! This module provides functionality for detecting events during ODE integration
4//! and handling them appropriately. Events are defined as conditions where a given
//! function crosses zero. Every detected event is recorded (time, state, and
//! event-function value), and an event can optionally stop the integration.
7
8use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use crate::ode::utils::dense_output::DenseSolution;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12
/// Which sign changes of an event function count as an event.
///
/// The default, [`EventDirection::Both`], reacts to any zero-crossing.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EventDirection {
    /// Only crossings from negative to non-negative values
    Rising,
    /// Only crossings from positive to non-positive values
    Falling,
    /// Crossings in either direction
    #[default]
    Both,
}
24
/// What the integrator should do after an event fires.
///
/// The default, [`EventAction::Continue`], keeps integrating.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EventAction {
    /// Record the event and keep integrating
    #[default]
    Continue,
    /// Record the event and stop the integration
    Stop,
}
34
/// Definition of an event to detect during integration
///
/// An event is a zero-crossing of a user-supplied event function; the spec
/// controls which crossings count, what happens when one is found, and how
/// often it may fire.
#[derive(Debug, Clone)]
pub struct EventSpec<F: IntegrateFloat> {
    /// Unique identifier for this event type (also the key under which
    /// detections are counted in `EventRecord`)
    pub id: String,
    /// Direction of zero-crossing to detect
    pub direction: EventDirection,
    /// Action to take when event is detected
    pub action: EventAction,
    /// Tolerance for considering event to be triggered (intended to prevent
    /// chattering).
    ///
    /// NOTE(review): `EventHandler` in this file does not read this field;
    /// crossings are detected against exactly zero. Confirm whether any
    /// other code uses it.
    pub threshold: F,
    /// Maximum number of detections (None for unlimited); once reached,
    /// `EventHandler` skips evaluating this event function entirely
    pub max_count: Option<usize>,
    /// Whether to refine the event time by bisection on the dense output
    /// (only possible when dense output is passed to `check_events`)
    pub precise_time: bool,
}
51
52impl<F: IntegrateFloat> EventSpec<F> {
53    /// Check if the maximum count has been reached for this event
54    pub fn max_count_reached(&self, currentcount: Option<usize>) -> bool {
55        if let Some(max) = self.max_count {
56            if let Some(_count) = currentcount {
57                return _count >= max;
58            }
59        }
60        false
61    }
62}
63
64impl<F: IntegrateFloat> Default for EventSpec<F> {
65    fn default() -> Self {
66        EventSpec {
67            id: "default".to_string(),
68            direction: EventDirection::default(),
69            action: EventAction::default(),
70            threshold: F::from_f64(1e-6).unwrap(),
71            max_count: None,
72            precise_time: true,
73        }
74    }
75}
76
/// Represents a detected event during integration
#[derive(Debug, Clone)]
pub struct Event<F: IntegrateFloat> {
    /// ID of the event spec that was triggered (copied from `EventSpec::id`)
    pub id: String,
    /// Time at which the event occurred: the refined crossing time when
    /// refinement was performed, otherwise the end time of the step in which
    /// the crossing was detected
    pub time: F,
    /// Values of the state at the event time
    pub state: Array1<F>,
    /// Value of the event function at the event time
    pub value: F,
    /// Direction of zero-crossing (1 for rising, -1 for falling)
    pub direction: i8,
}
91
/// Record of all events detected during integration
#[derive(Debug, Clone)]
pub struct EventRecord<F: IntegrateFloat> {
    /// List of all detected events, in the order they were added.
    ///
    /// Events detected within a single step are added in spec order, so their
    /// (refined) times are not guaranteed to be sorted.
    pub events: Vec<Event<F>>,
    /// Count of events by ID (incremented by `EventRecord::add_event`)
    pub counts: std::collections::HashMap<String, usize>,
}
100
101impl<F: IntegrateFloat> Default for EventRecord<F> {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl<F: IntegrateFloat> EventRecord<F> {
108    /// Create a new empty event record
109    pub fn new() -> Self {
110        EventRecord {
111            events: Vec::new(),
112            counts: std::collections::HashMap::new(),
113        }
114    }
115
116    /// Add a detected event to the record
117    pub fn add_event(&mut self, event: Event<F>) {
118        // Update count for this _event type
119        *self.counts.entry(event.id.clone()).or_insert(0) += 1;
120
121        // Add to the list of events
122        self.events.push(event);
123    }
124
125    /// Get the count of a specific event type
126    pub fn get_count(&self, id: &str) -> usize {
127        *self.counts.get(id).unwrap_or(&0)
128    }
129
130    /// Get all events of a specific type
131    pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
132        self.events.iter().filter(|e| e.id == id).collect()
133    }
134
135    /// Check if the maximum event count has been reached for a specific event type
136    pub fn max_count_reached(&self, _id: &str, maxcount: Option<usize>) -> bool {
137        if let Some(max) = maxcount {
138            self.get_count(_id) >= max
139        } else {
140            false
141        }
142    }
143}
144
/// Event detection and handling during ODE integration
///
/// Holds the event specifications, the record of detected events, and the
/// event-function values and state from the previous step that are needed to
/// detect a sign change on the next step.
#[derive(Debug)]
pub struct EventHandler<F: IntegrateFloat> {
    /// List of event specifications to detect
    pub specs: Vec<EventSpec<F>>,
    /// Record of detected events
    pub record: EventRecord<F>,
    /// Last values of event functions, one slot per entry in `specs`
    /// (`None` until a value has been computed for that slot)
    last_values: Vec<Option<F>>,
    /// `(t, y)` at the last step; `None` until the handler is initialized
    last_state: Option<(F, Array1<F>)>,
}
157
158impl<F: IntegrateFloat> EventHandler<F> {
159    /// Create a new event handler with the given event specifications
160    pub fn new(specs: Vec<EventSpec<F>>) -> Self {
161        let last_values = vec![None; specs.len()];
162
163        EventHandler {
164            specs,
165            record: EventRecord::new(),
166            last_values,
167            last_state: None,
168        }
169    }
170
171    /// Initialize the event handler with the initial state
172    pub fn initialize<Func>(
173        &mut self,
174        t: F,
175        y: &Array1<F>,
176        event_funcs: &[Func],
177    ) -> IntegrateResult<()>
178    where
179        Func: Fn(F, ArrayView1<F>) -> F,
180    {
181        // Store initial state
182        self.last_state = Some((t, y.clone()));
183
184        // Initialize last event function values
185        for (i, func) in event_funcs.iter().enumerate() {
186            let value = func(t, y.view());
187            self.last_values[i] = Some(value);
188        }
189
190        Ok(())
191    }
192
193    /// Check for events between the last state and the current state
194    pub fn check_events<Func>(
195        &mut self,
196        t: F,
197        y: &Array1<F>,
198        dense_output: Option<&DenseSolution<F>>,
199        event_funcs: &[Func],
200    ) -> IntegrateResult<EventAction>
201    where
202        Func: Fn(F, ArrayView1<F>) -> F,
203    {
204        if event_funcs.len() != self.specs.len() {
205            return Err(IntegrateError::ValueError(
206                "Number of event functions does not match number of event specifications"
207                    .to_string(),
208            ));
209        }
210
211        if self.last_state.is_none() {
212            // Initialize if not done already
213            self.initialize(t, y, event_funcs)?;
214            return Ok(EventAction::Continue);
215        }
216
217        let (t_prev, y_prev) = self.last_state.as_ref().unwrap();
218
219        // Check each event
220        let mut action = EventAction::Continue;
221
222        for (i, (func, spec)) in event_funcs.iter().zip(self.specs.iter()).enumerate() {
223            // Skip if we've already reached the maximum count for this event
224            if spec.max_count_reached(self.record.counts.get(&spec.id).cloned()) {
225                continue;
226            }
227
228            // Compute current value
229            let value = func(t, y.view());
230
231            // Check if we have a previous value
232            if let Some(prev_value) = self.last_values[i] {
233                // Check if event occurred (zero-crossing)
234                let rising = prev_value < F::zero() && value >= F::zero();
235                let falling = prev_value > F::zero() && value <= F::zero();
236
237                let triggered = match spec.direction {
238                    EventDirection::Rising => rising,
239                    EventDirection::Falling => falling,
240                    EventDirection::Both => rising || falling,
241                };
242
243                if triggered {
244                    // Refine the event time if requested and dense _output is available
245                    let (event_t, event_y, event_val, dir) =
246                        if spec.precise_time && dense_output.is_some() {
247                            self.refine_event_time(
248                                *t_prev,
249                                y_prev,
250                                t,
251                                y,
252                                prev_value,
253                                value,
254                                func,
255                                dense_output.unwrap(),
256                            )?
257                        } else {
258                            // Use current time as event time (less accurate)
259                            let dir = if rising { 1 } else { -1 };
260                            (t, y.clone(), value, dir)
261                        };
262
263                    // Create event record
264                    let event = Event {
265                        id: spec.id.clone(),
266                        time: event_t,
267                        state: event_y,
268                        value: event_val,
269                        direction: dir,
270                    };
271
272                    // Add to record
273                    self.record.add_event(event);
274
275                    // If this event requires stopping, set the action
276                    if spec.action == EventAction::Stop {
277                        action = EventAction::Stop;
278                    }
279                }
280            }
281
282            // Update last value
283            self.last_values[i] = Some(value);
284        }
285
286        // Update last state
287        self.last_state = Some((t, y.clone()));
288
289        Ok(action)
290    }
291
292    /// Refine the exact time of an event using bisection on the dense output
293    #[allow(clippy::too_many_arguments)]
294    fn refine_event_time<Func>(
295        &self,
296        t_prev: F,
297        y_prev: &Array1<F>,
298        t_curr: F,
299        y_curr: &Array1<F>,
300        value_prev: F,
301        value_curr: F,
302        event_func: &Func,
303        dense_output: &DenseSolution<F>,
304    ) -> IntegrateResult<(F, Array1<F>, F, i8)>
305    where
306        Func: Fn(F, ArrayView1<F>) -> F,
307    {
308        // Determine event direction
309        let direction: i8 = if value_prev < F::zero() && value_curr >= F::zero() {
310            1 // Rising
311        } else {
312            -1 // Falling
313        };
314
315        // Root-finding tolerance
316        let tol = F::from_f64(1e-10).unwrap();
317        let max_iter = 50;
318
319        // Bisection search for zero-crossing
320        let mut t_left = t_prev;
321        let mut t_right = t_curr;
322        let mut f_left = value_prev;
323        let f_right = value_curr;
324
325        // Handle the case where one endpoint is exactly zero
326        if f_left.abs() < tol {
327            return Ok((t_left, y_prev.clone(), f_left, direction));
328        }
329
330        if f_right.abs() < tol {
331            return Ok((t_right, y_curr.clone(), f_right, direction));
332        }
333
334        // Bisection loop
335        let mut t_mid = F::zero();
336        let mut y_mid = Array1::<F>::zeros(y_prev.len());
337        let mut f_mid = F::zero();
338
339        for _ in 0..max_iter {
340            // Compute midpoint time
341            t_mid = (t_left + t_right) / F::from_f64(2.0).unwrap();
342
343            // Get state at midpoint using dense _output
344            y_mid = dense_output.evaluate(t_mid)?;
345
346            // Evaluate event function at midpoint
347            f_mid = event_func(t_mid, y_mid.view());
348
349            // Check convergence
350            if f_mid.abs() < tol || (t_right - t_left).abs() < tol {
351                break;
352            }
353
354            // Update interval
355            if f_left * f_mid < F::zero() {
356                t_right = t_mid;
357                let _f_right = f_mid;
358            } else {
359                t_left = t_mid;
360                f_left = f_mid;
361            }
362        }
363
364        Ok((t_mid, y_mid, f_mid, direction))
365    }
366
367    /// Get the record of all detected events
368    pub fn get_record(&self) -> &EventRecord<F> {
369        &self.record
370    }
371
372    /// Check if an event occurred that requires stopping the integration
373    pub fn should_stop(&self) -> bool {
374        self.record.events.iter().any(|e| {
375            let spec = self.specs.iter().find(|s| s.id == e.id).unwrap();
376            spec.action == EventAction::Stop
377        })
378    }
379}
380
381/// Function to create a terminal event (one that stops integration when triggered)
382#[allow(dead_code)]
383pub fn terminal_event<F: IntegrateFloat>(id: &str, direction: EventDirection) -> EventSpec<F> {
384    EventSpec {
385        id: id.to_string(),
386        direction,
387        action: EventAction::Stop,
388        threshold: F::from_f64(1e-6).unwrap(),
389        max_count: Some(1),
390        precise_time: true,
391    }
392}
393
/// Extension to ODEOptions to include event handling
///
/// Pairs the base solver options with the event specifications to detect.
#[derive(Debug, Clone)]
pub struct ODEOptionsWithEvents<F: IntegrateFloat> {
    /// Base ODE options
    pub base_options: super::super::types::ODEOptions<F>,
    /// Event specifications
    pub event_specs: Vec<EventSpec<F>>,
}
402
403impl<F: IntegrateFloat> ODEOptionsWithEvents<F> {
404    /// Create options with events from base options
405    pub fn new(
406        base_options: super::super::types::ODEOptions<F>,
407        event_specs: Vec<EventSpec<F>>,
408    ) -> Self {
409        ODEOptionsWithEvents {
410            base_options,
411            event_specs,
412        }
413    }
414}
415
/// Extended ODE result that includes event information
#[derive(Debug)]
pub struct ODEResultWithEvents<F: IntegrateFloat> {
    /// Base ODE result
    pub base_result: super::super::types::ODEResult<F>,
    /// Record of detected events
    pub events: EventRecord<F>,
    /// Dense output for the solution (if available); when present,
    /// `at_time` interpolates with it instead of looking up stored time points
    pub dense_output: Option<DenseSolution<F>>,
    /// Whether integration terminated due to an event
    pub event_termination: bool,
}
428
429impl<F: IntegrateFloat> ODEResultWithEvents<F> {
430    /// Create a new result with events
431    pub fn new(
432        base_result: super::super::types::ODEResult<F>,
433        events: EventRecord<F>,
434        dense_output: Option<DenseSolution<F>>,
435        event_termination: bool,
436    ) -> Self {
437        ODEResultWithEvents {
438            base_result,
439            events,
440            dense_output,
441            event_termination,
442        }
443    }
444
445    /// Get the solution at a specific time using dense output
446    pub fn at_time(&self, t: F) -> IntegrateResult<Option<Array1<F>>> {
447        if let Some(ref dense) = self.dense_output {
448            Ok(Some(dense.evaluate(t)?))
449        } else {
450            // If no dense output available, check if we have exact time points
451            for (i, &ti) in self.base_result.t.iter().enumerate() {
452                if (ti - t).abs() < F::from_f64(1e-10).unwrap() {
453                    return Ok(Some(self.base_result.y[i].clone()));
454                }
455            }
456            Ok(None)
457        }
458    }
459
460    /// Get events of a specific type
461    pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
462        self.events.get_events(id)
463    }
464
465    /// Get the first occurrence of a specific event
466    pub fn first_event(&self, id: &str) -> Option<&Event<F>> {
467        self.events.get_events(id).first().copied()
468    }
469}