scirs2_integrate/ode/utils/
events.rs1use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use crate::ode::utils::dense_output::DenseSolution;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12
/// Sign change of an event function that counts as an event.
///
/// The handler compares the event-function value at the previous sample
/// with the value at the current sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EventDirection {
    /// Previous value strictly negative, current value zero or positive.
    Rising,
    /// Previous value strictly positive, current value zero or negative.
    Falling,
    /// Either a rising or a falling crossing (the default).
    #[default]
    Both,
}
24
/// What the integrator is asked to do once an event has fired.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EventAction {
    /// Record the event and keep integrating (the default).
    #[default]
    Continue,
    /// Record the event and request that integration stop.
    Stop,
}
34
35#[derive(Debug, Clone)]
37pub struct EventSpec<F: IntegrateFloat> {
38 pub id: String,
40 pub direction: EventDirection,
42 pub action: EventAction,
44 pub threshold: F,
46 pub max_count: Option<usize>,
48 pub precise_time: bool,
50}
51
52impl<F: IntegrateFloat> EventSpec<F> {
53 pub fn max_count_reached(&self, currentcount: Option<usize>) -> bool {
55 if let Some(max) = self.max_count {
56 if let Some(_count) = currentcount {
57 return _count >= max;
58 }
59 }
60 false
61 }
62}
63
64impl<F: IntegrateFloat> Default for EventSpec<F> {
65 fn default() -> Self {
66 EventSpec {
67 id: "default".to_string(),
68 direction: EventDirection::default(),
69 action: EventAction::default(),
70 threshold: F::from_f64(1e-6).expect("Operation failed"),
71 max_count: None,
72 precise_time: true,
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct Event<F: IntegrateFloat> {
80 pub id: String,
82 pub time: F,
84 pub state: Array1<F>,
86 pub value: F,
88 pub direction: i8,
90}
91
92#[derive(Debug, Clone)]
94pub struct EventRecord<F: IntegrateFloat> {
95 pub events: Vec<Event<F>>,
97 pub counts: std::collections::HashMap<String, usize>,
99}
100
101impl<F: IntegrateFloat> Default for EventRecord<F> {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl<F: IntegrateFloat> EventRecord<F> {
108 pub fn new() -> Self {
110 EventRecord {
111 events: Vec::new(),
112 counts: std::collections::HashMap::new(),
113 }
114 }
115
116 pub fn add_event(&mut self, event: Event<F>) {
118 *self.counts.entry(event.id.clone()).or_insert(0) += 1;
120
121 self.events.push(event);
123 }
124
125 pub fn get_count(&self, id: &str) -> usize {
127 *self.counts.get(id).unwrap_or(&0)
128 }
129
130 pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
132 self.events.iter().filter(|e| e.id == id).collect()
133 }
134
135 pub fn max_count_reached(&self, _id: &str, maxcount: Option<usize>) -> bool {
137 if let Some(max) = maxcount {
138 self.get_count(_id) >= max
139 } else {
140 false
141 }
142 }
143}
144
145#[derive(Debug)]
147pub struct EventHandler<F: IntegrateFloat> {
148 pub specs: Vec<EventSpec<F>>,
150 pub record: EventRecord<F>,
152 last_values: Vec<Option<F>>,
154 last_state: Option<(F, Array1<F>)>,
156}
157
158impl<F: IntegrateFloat> EventHandler<F> {
159 pub fn new(specs: Vec<EventSpec<F>>) -> Self {
161 let last_values = vec![None; specs.len()];
162
163 EventHandler {
164 specs,
165 record: EventRecord::new(),
166 last_values,
167 last_state: None,
168 }
169 }
170
171 pub fn initialize<Func>(
173 &mut self,
174 t: F,
175 y: &Array1<F>,
176 event_funcs: &[Func],
177 ) -> IntegrateResult<()>
178 where
179 Func: Fn(F, ArrayView1<F>) -> F,
180 {
181 self.last_state = Some((t, y.clone()));
183
184 for (i, func) in event_funcs.iter().enumerate() {
186 let value = func(t, y.view());
187 self.last_values[i] = Some(value);
188 }
189
190 Ok(())
191 }
192
193 pub fn check_events<Func>(
195 &mut self,
196 t: F,
197 y: &Array1<F>,
198 dense_output: Option<&DenseSolution<F>>,
199 event_funcs: &[Func],
200 ) -> IntegrateResult<EventAction>
201 where
202 Func: Fn(F, ArrayView1<F>) -> F,
203 {
204 if event_funcs.len() != self.specs.len() {
205 return Err(IntegrateError::ValueError(
206 "Number of event functions does not match number of event specifications"
207 .to_string(),
208 ));
209 }
210
211 if self.last_state.is_none() {
212 self.initialize(t, y, event_funcs)?;
214 return Ok(EventAction::Continue);
215 }
216
217 let (t_prev, y_prev) = self.last_state.as_ref().expect("Operation failed");
218
219 let mut action = EventAction::Continue;
221
222 for (i, (func, spec)) in event_funcs.iter().zip(self.specs.iter()).enumerate() {
223 if spec.max_count_reached(self.record.counts.get(&spec.id).cloned()) {
225 continue;
226 }
227
228 let value = func(t, y.view());
230
231 if let Some(prev_value) = self.last_values[i] {
233 let rising = prev_value < F::zero() && value >= F::zero();
235 let falling = prev_value > F::zero() && value <= F::zero();
236
237 let triggered = match spec.direction {
238 EventDirection::Rising => rising,
239 EventDirection::Falling => falling,
240 EventDirection::Both => rising || falling,
241 };
242
243 if triggered {
244 let (event_t, event_y, event_val, dir) =
246 if let (true, Some(dense)) = (spec.precise_time, dense_output) {
247 self.refine_event_time(
248 *t_prev, y_prev, t, y, prev_value, value, func, dense,
249 )?
250 } else {
251 let dir = if rising { 1 } else { -1 };
253 (t, y.clone(), value, dir)
254 };
255
256 let event = Event {
258 id: spec.id.clone(),
259 time: event_t,
260 state: event_y,
261 value: event_val,
262 direction: dir,
263 };
264
265 self.record.add_event(event);
267
268 if spec.action == EventAction::Stop {
270 action = EventAction::Stop;
271 }
272 }
273 }
274
275 self.last_values[i] = Some(value);
277 }
278
279 self.last_state = Some((t, y.clone()));
281
282 Ok(action)
283 }
284
285 #[allow(clippy::too_many_arguments)]
287 fn refine_event_time<Func>(
288 &self,
289 t_prev: F,
290 y_prev: &Array1<F>,
291 t_curr: F,
292 y_curr: &Array1<F>,
293 value_prev: F,
294 value_curr: F,
295 event_func: &Func,
296 dense_output: &DenseSolution<F>,
297 ) -> IntegrateResult<(F, Array1<F>, F, i8)>
298 where
299 Func: Fn(F, ArrayView1<F>) -> F,
300 {
301 let direction: i8 = if value_prev < F::zero() && value_curr >= F::zero() {
303 1 } else {
305 -1 };
307
308 let tol = F::from_f64(1e-10).expect("Operation failed");
310 let max_iter = 50;
311
312 let mut t_left = t_prev;
314 let mut t_right = t_curr;
315 let mut f_left = value_prev;
316 let f_right = value_curr;
317
318 if f_left.abs() < tol {
320 return Ok((t_left, y_prev.clone(), f_left, direction));
321 }
322
323 if f_right.abs() < tol {
324 return Ok((t_right, y_curr.clone(), f_right, direction));
325 }
326
327 let mut t_mid = F::zero();
329 let mut y_mid = Array1::<F>::zeros(y_prev.len());
330 let mut f_mid = F::zero();
331
332 for _ in 0..max_iter {
333 t_mid = (t_left + t_right) / F::from_f64(2.0).expect("Operation failed");
335
336 y_mid = dense_output.evaluate(t_mid)?;
338
339 f_mid = event_func(t_mid, y_mid.view());
341
342 if f_mid.abs() < tol || (t_right - t_left).abs() < tol {
344 break;
345 }
346
347 if f_left * f_mid < F::zero() {
349 t_right = t_mid;
350 let _f_right = f_mid;
351 } else {
352 t_left = t_mid;
353 f_left = f_mid;
354 }
355 }
356
357 Ok((t_mid, y_mid, f_mid, direction))
358 }
359
360 pub fn get_record(&self) -> &EventRecord<F> {
362 &self.record
363 }
364
365 pub fn should_stop(&self) -> bool {
367 self.record.events.iter().any(|e| {
368 let spec = self
369 .specs
370 .iter()
371 .find(|s| s.id == e.id)
372 .expect("Operation failed");
373 spec.action == EventAction::Stop
374 })
375 }
376}
377
378#[allow(dead_code)]
380pub fn terminal_event<F: IntegrateFloat>(id: &str, direction: EventDirection) -> EventSpec<F> {
381 EventSpec {
382 id: id.to_string(),
383 direction,
384 action: EventAction::Stop,
385 threshold: F::from_f64(1e-6).expect("Operation failed"),
386 max_count: Some(1),
387 precise_time: true,
388 }
389}
390
391#[derive(Debug, Clone)]
393pub struct ODEOptionsWithEvents<F: IntegrateFloat> {
394 pub base_options: super::super::types::ODEOptions<F>,
396 pub event_specs: Vec<EventSpec<F>>,
398}
399
400impl<F: IntegrateFloat> ODEOptionsWithEvents<F> {
401 pub fn new(
403 base_options: super::super::types::ODEOptions<F>,
404 event_specs: Vec<EventSpec<F>>,
405 ) -> Self {
406 ODEOptionsWithEvents {
407 base_options,
408 event_specs,
409 }
410 }
411}
412
413#[derive(Debug)]
415pub struct ODEResultWithEvents<F: IntegrateFloat> {
416 pub base_result: super::super::types::ODEResult<F>,
418 pub events: EventRecord<F>,
420 pub dense_output: Option<DenseSolution<F>>,
422 pub event_termination: bool,
424}
425
426impl<F: IntegrateFloat> ODEResultWithEvents<F> {
427 pub fn new(
429 base_result: super::super::types::ODEResult<F>,
430 events: EventRecord<F>,
431 dense_output: Option<DenseSolution<F>>,
432 event_termination: bool,
433 ) -> Self {
434 ODEResultWithEvents {
435 base_result,
436 events,
437 dense_output,
438 event_termination,
439 }
440 }
441
442 pub fn at_time(&self, t: F) -> IntegrateResult<Option<Array1<F>>> {
444 if let Some(ref dense) = self.dense_output {
445 Ok(Some(dense.evaluate(t)?))
446 } else {
447 for (i, &ti) in self.base_result.t.iter().enumerate() {
449 if (ti - t).abs() < F::from_f64(1e-10).expect("Operation failed") {
450 return Ok(Some(self.base_result.y[i].clone()));
451 }
452 }
453 Ok(None)
454 }
455 }
456
457 pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
459 self.events.get_events(id)
460 }
461
462 pub fn first_event(&self, id: &str) -> Option<&Event<F>> {
464 self.events.get_events(id).first().copied()
465 }
466}