scirs2_integrate/ode/utils/
events.rs1use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use crate::ode::utils::dense_output::DenseSolution;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12
/// Direction of the zero crossing that makes an event fire.
///
/// The handler compares the event-function value from the previous check
/// with the current one; this enum selects which sign change counts.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum EventDirection {
    /// Fire when the value goes from negative to zero or positive.
    Rising,
    /// Fire when the value goes from positive to zero or negative.
    Falling,
    /// Fire on a crossing in either direction (the default).
    #[default]
    Both,
}
24
/// What the caller should do after an event has fired.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum EventAction {
    /// Keep going after the event is recorded (the default).
    #[default]
    Continue,
    /// Request that the caller stop.
    Stop,
}
34
35#[derive(Debug, Clone)]
37pub struct EventSpec<F: IntegrateFloat> {
38 pub id: String,
40 pub direction: EventDirection,
42 pub action: EventAction,
44 pub threshold: F,
46 pub max_count: Option<usize>,
48 pub precise_time: bool,
50}
51
52impl<F: IntegrateFloat> EventSpec<F> {
53 pub fn max_count_reached(&self, currentcount: Option<usize>) -> bool {
55 if let Some(max) = self.max_count {
56 if let Some(_count) = currentcount {
57 return _count >= max;
58 }
59 }
60 false
61 }
62}
63
64impl<F: IntegrateFloat> Default for EventSpec<F> {
65 fn default() -> Self {
66 EventSpec {
67 id: "default".to_string(),
68 direction: EventDirection::default(),
69 action: EventAction::default(),
70 threshold: F::from_f64(1e-6).unwrap(),
71 max_count: None,
72 precise_time: true,
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct Event<F: IntegrateFloat> {
80 pub id: String,
82 pub time: F,
84 pub state: Array1<F>,
86 pub value: F,
88 pub direction: i8,
90}
91
92#[derive(Debug, Clone)]
94pub struct EventRecord<F: IntegrateFloat> {
95 pub events: Vec<Event<F>>,
97 pub counts: std::collections::HashMap<String, usize>,
99}
100
101impl<F: IntegrateFloat> Default for EventRecord<F> {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl<F: IntegrateFloat> EventRecord<F> {
108 pub fn new() -> Self {
110 EventRecord {
111 events: Vec::new(),
112 counts: std::collections::HashMap::new(),
113 }
114 }
115
116 pub fn add_event(&mut self, event: Event<F>) {
118 *self.counts.entry(event.id.clone()).or_insert(0) += 1;
120
121 self.events.push(event);
123 }
124
125 pub fn get_count(&self, id: &str) -> usize {
127 *self.counts.get(id).unwrap_or(&0)
128 }
129
130 pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
132 self.events.iter().filter(|e| e.id == id).collect()
133 }
134
135 pub fn max_count_reached(&self, _id: &str, maxcount: Option<usize>) -> bool {
137 if let Some(max) = maxcount {
138 self.get_count(_id) >= max
139 } else {
140 false
141 }
142 }
143}
144
145#[derive(Debug)]
147pub struct EventHandler<F: IntegrateFloat> {
148 pub specs: Vec<EventSpec<F>>,
150 pub record: EventRecord<F>,
152 last_values: Vec<Option<F>>,
154 last_state: Option<(F, Array1<F>)>,
156}
157
158impl<F: IntegrateFloat> EventHandler<F> {
159 pub fn new(specs: Vec<EventSpec<F>>) -> Self {
161 let last_values = vec![None; specs.len()];
162
163 EventHandler {
164 specs,
165 record: EventRecord::new(),
166 last_values,
167 last_state: None,
168 }
169 }
170
171 pub fn initialize<Func>(
173 &mut self,
174 t: F,
175 y: &Array1<F>,
176 event_funcs: &[Func],
177 ) -> IntegrateResult<()>
178 where
179 Func: Fn(F, ArrayView1<F>) -> F,
180 {
181 self.last_state = Some((t, y.clone()));
183
184 for (i, func) in event_funcs.iter().enumerate() {
186 let value = func(t, y.view());
187 self.last_values[i] = Some(value);
188 }
189
190 Ok(())
191 }
192
193 pub fn check_events<Func>(
195 &mut self,
196 t: F,
197 y: &Array1<F>,
198 dense_output: Option<&DenseSolution<F>>,
199 event_funcs: &[Func],
200 ) -> IntegrateResult<EventAction>
201 where
202 Func: Fn(F, ArrayView1<F>) -> F,
203 {
204 if event_funcs.len() != self.specs.len() {
205 return Err(IntegrateError::ValueError(
206 "Number of event functions does not match number of event specifications"
207 .to_string(),
208 ));
209 }
210
211 if self.last_state.is_none() {
212 self.initialize(t, y, event_funcs)?;
214 return Ok(EventAction::Continue);
215 }
216
217 let (t_prev, y_prev) = self.last_state.as_ref().unwrap();
218
219 let mut action = EventAction::Continue;
221
222 for (i, (func, spec)) in event_funcs.iter().zip(self.specs.iter()).enumerate() {
223 if spec.max_count_reached(self.record.counts.get(&spec.id).cloned()) {
225 continue;
226 }
227
228 let value = func(t, y.view());
230
231 if let Some(prev_value) = self.last_values[i] {
233 let rising = prev_value < F::zero() && value >= F::zero();
235 let falling = prev_value > F::zero() && value <= F::zero();
236
237 let triggered = match spec.direction {
238 EventDirection::Rising => rising,
239 EventDirection::Falling => falling,
240 EventDirection::Both => rising || falling,
241 };
242
243 if triggered {
244 let (event_t, event_y, event_val, dir) =
246 if spec.precise_time && dense_output.is_some() {
247 self.refine_event_time(
248 *t_prev,
249 y_prev,
250 t,
251 y,
252 prev_value,
253 value,
254 func,
255 dense_output.unwrap(),
256 )?
257 } else {
258 let dir = if rising { 1 } else { -1 };
260 (t, y.clone(), value, dir)
261 };
262
263 let event = Event {
265 id: spec.id.clone(),
266 time: event_t,
267 state: event_y,
268 value: event_val,
269 direction: dir,
270 };
271
272 self.record.add_event(event);
274
275 if spec.action == EventAction::Stop {
277 action = EventAction::Stop;
278 }
279 }
280 }
281
282 self.last_values[i] = Some(value);
284 }
285
286 self.last_state = Some((t, y.clone()));
288
289 Ok(action)
290 }
291
292 #[allow(clippy::too_many_arguments)]
294 fn refine_event_time<Func>(
295 &self,
296 t_prev: F,
297 y_prev: &Array1<F>,
298 t_curr: F,
299 y_curr: &Array1<F>,
300 value_prev: F,
301 value_curr: F,
302 event_func: &Func,
303 dense_output: &DenseSolution<F>,
304 ) -> IntegrateResult<(F, Array1<F>, F, i8)>
305 where
306 Func: Fn(F, ArrayView1<F>) -> F,
307 {
308 let direction: i8 = if value_prev < F::zero() && value_curr >= F::zero() {
310 1 } else {
312 -1 };
314
315 let tol = F::from_f64(1e-10).unwrap();
317 let max_iter = 50;
318
319 let mut t_left = t_prev;
321 let mut t_right = t_curr;
322 let mut f_left = value_prev;
323 let f_right = value_curr;
324
325 if f_left.abs() < tol {
327 return Ok((t_left, y_prev.clone(), f_left, direction));
328 }
329
330 if f_right.abs() < tol {
331 return Ok((t_right, y_curr.clone(), f_right, direction));
332 }
333
334 let mut t_mid = F::zero();
336 let mut y_mid = Array1::<F>::zeros(y_prev.len());
337 let mut f_mid = F::zero();
338
339 for _ in 0..max_iter {
340 t_mid = (t_left + t_right) / F::from_f64(2.0).unwrap();
342
343 y_mid = dense_output.evaluate(t_mid)?;
345
346 f_mid = event_func(t_mid, y_mid.view());
348
349 if f_mid.abs() < tol || (t_right - t_left).abs() < tol {
351 break;
352 }
353
354 if f_left * f_mid < F::zero() {
356 t_right = t_mid;
357 let _f_right = f_mid;
358 } else {
359 t_left = t_mid;
360 f_left = f_mid;
361 }
362 }
363
364 Ok((t_mid, y_mid, f_mid, direction))
365 }
366
367 pub fn get_record(&self) -> &EventRecord<F> {
369 &self.record
370 }
371
372 pub fn should_stop(&self) -> bool {
374 self.record.events.iter().any(|e| {
375 let spec = self.specs.iter().find(|s| s.id == e.id).unwrap();
376 spec.action == EventAction::Stop
377 })
378 }
379}
380
381#[allow(dead_code)]
383pub fn terminal_event<F: IntegrateFloat>(id: &str, direction: EventDirection) -> EventSpec<F> {
384 EventSpec {
385 id: id.to_string(),
386 direction,
387 action: EventAction::Stop,
388 threshold: F::from_f64(1e-6).unwrap(),
389 max_count: Some(1),
390 precise_time: true,
391 }
392}
393
394#[derive(Debug, Clone)]
396pub struct ODEOptionsWithEvents<F: IntegrateFloat> {
397 pub base_options: super::super::types::ODEOptions<F>,
399 pub event_specs: Vec<EventSpec<F>>,
401}
402
403impl<F: IntegrateFloat> ODEOptionsWithEvents<F> {
404 pub fn new(
406 base_options: super::super::types::ODEOptions<F>,
407 event_specs: Vec<EventSpec<F>>,
408 ) -> Self {
409 ODEOptionsWithEvents {
410 base_options,
411 event_specs,
412 }
413 }
414}
415
416#[derive(Debug)]
418pub struct ODEResultWithEvents<F: IntegrateFloat> {
419 pub base_result: super::super::types::ODEResult<F>,
421 pub events: EventRecord<F>,
423 pub dense_output: Option<DenseSolution<F>>,
425 pub event_termination: bool,
427}
428
429impl<F: IntegrateFloat> ODEResultWithEvents<F> {
430 pub fn new(
432 base_result: super::super::types::ODEResult<F>,
433 events: EventRecord<F>,
434 dense_output: Option<DenseSolution<F>>,
435 event_termination: bool,
436 ) -> Self {
437 ODEResultWithEvents {
438 base_result,
439 events,
440 dense_output,
441 event_termination,
442 }
443 }
444
445 pub fn at_time(&self, t: F) -> IntegrateResult<Option<Array1<F>>> {
447 if let Some(ref dense) = self.dense_output {
448 Ok(Some(dense.evaluate(t)?))
449 } else {
450 for (i, &ti) in self.base_result.t.iter().enumerate() {
452 if (ti - t).abs() < F::from_f64(1e-10).unwrap() {
453 return Ok(Some(self.base_result.y[i].clone()));
454 }
455 }
456 Ok(None)
457 }
458 }
459
460 pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
462 self.events.get_events(id)
463 }
464
465 pub fn first_event(&self, id: &str) -> Option<&Event<F>> {
467 self.events.get_events(id).first().copied()
468 }
469}