Skip to main content

scirs2_integrate/
dde.rs

//! Delay Differential Equation (DDE) solvers
//!
//! This module provides numerical methods for solving delay differential equations
//! of the form:
//!
//!   y'(t) = f(t, y(t), y(t - tau_1), y(t - tau_2), ...)
//!
//! where tau_i are the delays. The history function phi(t) for t <= t0 must be provided.
//!
//! # Features
//!
//! - **Method of steps**: Integrates interval-by-interval, using the known
//!   solution on previous intervals to evaluate delayed terms.
//! - **Dense output interpolation**: Hermite cubic interpolation for evaluating
//!   the solution at arbitrary points in the history.
//! - **Multiple delays**: Supports any number of constant delays.
//! - **State-dependent delays**: Delays that depend on the current time and state,
//!   tau_i = tau_i(t, y(t)).
//! - **Discontinuity tracking**: Steps exactly onto the points at which the delays
//!   propagate derivative discontinuities from the initial time t0.
//!
//! # References
//!
//! - Bellen & Zennaro: "Numerical Methods for Delay Differential Equations" (2003)
//! - Shampine & Thompson: "Solving DDEs in Matlab" (2001)
//! - Baker, Paul & Willé: "Issues in the numerical solution of DDEs" (1995)
26
27use crate::common::IntegrateFloat;
28use crate::error::{IntegrateError, IntegrateResult};
29use scirs2_core::ndarray::{array, Array1, ArrayView1};
30use std::fmt::Debug;
31
32// ---------------------------------------------------------------------------
33// Types
34// ---------------------------------------------------------------------------
35
36/// Specification of delays in a DDE system.
37#[derive(Debug, Clone)]
38pub enum DelayType<F: IntegrateFloat> {
39    /// Constant delays: y(t - tau_i) for fixed tau_i > 0
40    Constant(Vec<F>),
41    /// State-dependent delays: tau_i = tau_i(t, y(t))
42    /// The function returns the vector of delay values.
43    StateDependent,
44}
45
46/// Options for the DDE solver.
47#[derive(Debug, Clone)]
48pub struct DDEOptions<F: IntegrateFloat> {
49    /// Relative tolerance
50    pub rtol: F,
51    /// Absolute tolerance
52    pub atol: F,
53    /// Initial step size (None for automatic)
54    pub h0: Option<F>,
55    /// Maximum step size
56    pub max_step: Option<F>,
57    /// Minimum step size
58    pub min_step: Option<F>,
59    /// Maximum number of steps
60    pub max_steps: usize,
61    /// Whether to track discontinuities
62    pub track_discontinuities: bool,
63    /// Maximum discontinuity order to track (0 = jump, 1 = corner, etc.)
64    pub max_discontinuity_order: usize,
65}
66
67impl<F: IntegrateFloat> Default for DDEOptions<F> {
68    fn default() -> Self {
69        DDEOptions {
70            rtol: F::from_f64(1e-6).unwrap_or_else(|| F::epsilon()),
71            atol: F::from_f64(1e-9).unwrap_or_else(|| F::epsilon()),
72            h0: None,
73            max_step: None,
74            min_step: None,
75            max_steps: 100_000,
76            track_discontinuities: true,
77            max_discontinuity_order: 5,
78        }
79    }
80}
81
82/// Result of DDE integration.
83#[derive(Debug, Clone)]
84pub struct DDEResult<F: IntegrateFloat> {
85    /// Time points
86    pub t: Vec<F>,
87    /// Solution values at each time point
88    pub y: Vec<Array1<F>>,
89    /// Whether integration completed successfully
90    pub success: bool,
91    /// Status message
92    pub message: Option<String>,
93    /// Number of function evaluations
94    pub n_eval: usize,
95    /// Number of steps taken
96    pub n_steps: usize,
97    /// Number of accepted steps
98    pub n_accepted: usize,
99    /// Number of rejected steps
100    pub n_rejected: usize,
101    /// Detected discontinuity times
102    pub discontinuities: Vec<F>,
103}
104
105// ---------------------------------------------------------------------------
106// Dense output (Hermite cubic interpolation)
107// ---------------------------------------------------------------------------
108
109/// A segment of dense output for Hermite interpolation on [t_left, t_right].
110#[derive(Debug, Clone)]
111struct DenseSegment<F: IntegrateFloat> {
112    t_left: F,
113    t_right: F,
114    y_left: Array1<F>,
115    y_right: Array1<F>,
116    yp_left: Array1<F>,
117    yp_right: Array1<F>,
118}
119
120impl<F: IntegrateFloat> DenseSegment<F> {
121    /// Evaluate the Hermite interpolant at time t in [t_left, t_right].
122    fn evaluate(&self, t: F) -> Array1<F> {
123        let h = self.t_right - self.t_left;
124        if h.abs() < F::from_f64(1e-30).unwrap_or_else(|| F::epsilon()) {
125            return self.y_left.clone();
126        }
127
128        let s = (t - self.t_left) / h;
129        let s2 = s * s;
130        let s3 = s2 * s;
131
132        // Hermite basis functions: h00(s) = 2s^3 - 3s^2 + 1,  h10(s) = s^3 - 2s^2 + s
133        //                          h01(s) = -2s^3 + 3s^2,     h11(s) = s^3 - s^2
134        let two = F::one() + F::one();
135        let three = two + F::one();
136        let h00 = two * s3 - three * s2 + F::one();
137        let h10 = s3 - two * s2 + s;
138        let h01 = three * s2 - two * s3;
139        let h11 = s3 - s2;
140
141        &self.y_left * h00
142            + &(&self.yp_left * (h * h10))
143            + &(&self.y_right * h01)
144            + &(&self.yp_right * (h * h11))
145    }
146}
147
148/// History storage and interpolation for delayed values.
149#[derive(Debug, Clone)]
150struct HistoryBuffer<F: IntegrateFloat> {
151    /// Dense output segments, ordered by time
152    segments: Vec<DenseSegment<F>>,
153    /// History function for t <= t0
154    /// Stored as a vector of (t, y) pairs for the pre-initial-time history
155    pre_history: Vec<(F, Array1<F>)>,
156    /// The earliest time in the solution buffer
157    t_start: F,
158}
159
160impl<F: IntegrateFloat> HistoryBuffer<F> {
161    fn new(t0: F) -> Self {
162        HistoryBuffer {
163            segments: Vec::new(),
164            pre_history: Vec::new(),
165            t_start: t0,
166        }
167    }
168
169    /// Add pre-history samples (for t <= t0)
170    fn add_pre_history(&mut self, t: F, y: Array1<F>) {
171        self.pre_history.push((t, y));
172        // Keep sorted
173        self.pre_history
174            .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
175    }
176
177    /// Add a new dense output segment
178    fn add_segment(&mut self, seg: DenseSegment<F>) {
179        self.segments.push(seg);
180    }
181
182    /// Evaluate the history/solution at time t
183    fn evaluate(&self, t: F) -> IntegrateResult<Array1<F>> {
184        // Check if t falls in the pre-history
185        if t < self.t_start || (t <= self.t_start && self.segments.is_empty()) {
186            return self.evaluate_pre_history(t);
187        }
188
189        // Find the segment containing t
190        for seg in &self.segments {
191            if t >= seg.t_left && t <= seg.t_right {
192                return Ok(seg.evaluate(t));
193            }
194        }
195
196        // If t is at the very end, use the last segment
197        if let Some(last) = self.segments.last() {
198            if (t - last.t_right).abs() < F::from_f64(1e-12).unwrap_or_else(|| F::epsilon()) {
199                return Ok(last.y_right.clone());
200            }
201        }
202
203        Err(IntegrateError::ValueError(format!(
204            "Time {t} is outside the computed solution range"
205        )))
206    }
207
208    /// Evaluate pre-history via linear interpolation of stored samples
209    fn evaluate_pre_history(&self, t: F) -> IntegrateResult<Array1<F>> {
210        if self.pre_history.is_empty() {
211            return Err(IntegrateError::ValueError(
212                "No pre-history available for the requested time".into(),
213            ));
214        }
215
216        // If only one point, return it
217        if self.pre_history.len() == 1 {
218            return Ok(self.pre_history[0].1.clone());
219        }
220
221        // Clamp to the range of pre-history
222        let first_t = self.pre_history[0].0;
223        let last_t = self.pre_history[self.pre_history.len() - 1].0;
224
225        if t <= first_t {
226            return Ok(self.pre_history[0].1.clone());
227        }
228        if t >= last_t {
229            return Ok(self.pre_history[self.pre_history.len() - 1].1.clone());
230        }
231
232        // Find bracketing interval
233        for i in 0..self.pre_history.len() - 1 {
234            let (t_i, ref y_i) = self.pre_history[i];
235            let (t_ip1, ref y_ip1) = self.pre_history[i + 1];
236
237            if t >= t_i && t <= t_ip1 {
238                let dt = t_ip1 - t_i;
239                if dt.abs() < F::from_f64(1e-30).unwrap_or_else(|| F::epsilon()) {
240                    return Ok(y_i.clone());
241                }
242                let s = (t - t_i) / dt;
243                return Ok(y_i * (F::one() - s) + y_ip1 * s);
244            }
245        }
246
247        Ok(self.pre_history[self.pre_history.len() - 1].1.clone())
248    }
249}
250
251// ---------------------------------------------------------------------------
252// Discontinuity tracker
253// ---------------------------------------------------------------------------
254
255/// Tracks discontinuities propagated by delays.
256///
257/// If the history function has a discontinuity at t0, the DDE solution will
258/// have a derivative discontinuity at t0 + tau, and a higher-order
259/// discontinuity at t0 + 2*tau, etc. The solver must step exactly to
260/// these points for accuracy.
261#[derive(Debug, Clone)]
262struct DiscontinuityTracker<F: IntegrateFloat> {
263    /// Queue of discontinuity times, sorted in ascending order
264    queue: Vec<F>,
265    /// Maximum order of discontinuity to track
266    max_order: usize,
267}
268
269impl<F: IntegrateFloat> DiscontinuityTracker<F> {
270    fn new(max_order: usize) -> Self {
271        DiscontinuityTracker {
272            queue: Vec::new(),
273            max_order,
274        }
275    }
276
277    /// Seed the tracker with initial discontinuity at t0 and constant delays.
278    fn seed(&mut self, t0: F, tf: F, delays: &[F]) {
279        // The initial time t0 is a discontinuity in the derivative
280        // It propagates through the delays: t0 + k*tau for each delay tau
281        for order in 0..=self.max_order {
282            for tau in delays {
283                let disc_t = t0 + F::from_usize(order + 1).unwrap_or_else(|| F::one()) * (*tau);
284                if disc_t > t0 && disc_t <= tf {
285                    self.queue.push(disc_t);
286                }
287            }
288        }
289
290        // Sort and deduplicate
291        self.queue
292            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
293        self.queue
294            .dedup_by(|a, b| (*a - *b).abs() < F::from_f64(1e-12).unwrap_or_else(|| F::epsilon()));
295    }
296
297    /// Get the next discontinuity time after `t_current`.
298    fn next_after(&self, t_current: F) -> Option<F> {
299        let eps = F::from_f64(1e-12).unwrap_or_else(|| F::epsilon());
300        self.queue.iter().find(|&&td| td > t_current + eps).copied()
301    }
302
303    /// Get all tracked discontinuity times.
304    fn all_times(&self) -> &[F] {
305        &self.queue
306    }
307}
308
309// ---------------------------------------------------------------------------
310// DDE System trait
311// ---------------------------------------------------------------------------
312
313/// Trait for a DDE right-hand side function.
314///
315/// The user implements this to define f(t, y(t), y(t-tau_1), ..., y(t-tau_k)).
316pub trait DDESystem<F: IntegrateFloat> {
317    /// Dimension of the system
318    fn ndim(&self) -> usize;
319
320    /// Evaluate the right-hand side.
321    ///
322    /// * `t` - current time
323    /// * `y` - current state y(t)
324    /// * `y_delayed` - delayed states [y(t-tau_1), y(t-tau_2), ...]
325    fn rhs(&self, t: F, y: ArrayView1<F>, y_delayed: &[Array1<F>]) -> IntegrateResult<Array1<F>>;
326
327    /// Constant delays (if DelayType::Constant)
328    fn delays(&self) -> Vec<F>;
329
330    /// State-dependent delays: given (t, y), return the delay values.
331    /// Default: calls `self.delays()` ignoring state.
332    fn state_dependent_delays(&self, _t: F, _y: ArrayView1<F>) -> Vec<F> {
333        self.delays()
334    }
335
336    /// History function phi(t) for t <= t0.
337    fn history(&self, t: F) -> Array1<F>;
338}
339
340// ---------------------------------------------------------------------------
341// Method of Steps solver (RK45 based)
342// ---------------------------------------------------------------------------
343
344/// Solve a DDE using the method of steps with an embedded RK45 integrator.
345///
346/// # Arguments
347/// * `sys` - the DDE system
348/// * `t_span` - [t0, tf]
349/// * `options` - solver options
350///
351/// # Returns
352/// A `DDEResult` with the solution trajectory
353pub fn solve_dde<F: IntegrateFloat>(
354    sys: &dyn DDESystem<F>,
355    t_span: [F; 2],
356    options: &DDEOptions<F>,
357) -> IntegrateResult<DDEResult<F>> {
358    let t0 = t_span[0];
359    let tf = t_span[1];
360    let n = sys.ndim();
361
362    if tf <= t0 {
363        return Err(IntegrateError::ValueError("tf must be > t0".into()));
364    }
365
366    let delays = sys.delays();
367    if delays.is_empty() {
368        return Err(IntegrateError::ValueError(
369            "DDE system must have at least one delay".into(),
370        ));
371    }
372
373    // Validate delays
374    for (i, &tau) in delays.iter().enumerate() {
375        if tau <= F::zero() {
376            return Err(IntegrateError::ValueError(format!(
377                "Delay {} must be positive, got {tau}",
378                i
379            )));
380        }
381    }
382
383    // Initialize history buffer
384    let mut history = HistoryBuffer::new(t0);
385
386    // Sample the history function on a grid before t0
387    let max_delay = delays.iter().fold(F::zero(), |a, &b| a.max(b));
388    let n_pre_samples = 100;
389    let pre_dt = max_delay / F::from_usize(n_pre_samples).unwrap_or_else(|| F::one());
390    for i in 0..=n_pre_samples {
391        let t_pre = t0 - max_delay + F::from_usize(i).unwrap_or_else(|| F::zero()) * pre_dt;
392        let y_pre = sys.history(t_pre);
393        history.add_pre_history(t_pre, y_pre);
394    }
395
396    // Initialize discontinuity tracker
397    let mut disc_tracker = DiscontinuityTracker::new(if options.track_discontinuities {
398        options.max_discontinuity_order
399    } else {
400        0
401    });
402    if options.track_discontinuities {
403        disc_tracker.seed(t0, tf, &delays);
404    }
405
406    // Initial state from history
407    let y0 = sys.history(t0);
408
409    // Step size
410    let span = tf - t0;
411    let mut h = match options.h0 {
412        Some(h0) => h0,
413        None => {
414            let h_init = span * F::from_f64(0.001).unwrap_or_else(|| F::epsilon());
415            if let Some(max_h) = options.max_step {
416                h_init.min(max_h)
417            } else {
418                h_init
419            }
420        }
421    };
422
423    let min_step = options
424        .min_step
425        .unwrap_or_else(|| span * F::from_f64(1e-14).unwrap_or_else(|| F::epsilon()));
426
427    // Result storage
428    let mut t_out = vec![t0];
429    let mut y_out = vec![y0.clone()];
430    let mut disc_times: Vec<F> = Vec::new();
431
432    let mut t = t0;
433    let mut y = y0;
434    let mut n_eval: usize = 0;
435    let mut n_steps: usize = 0;
436    let mut n_accepted: usize = 0;
437    let mut n_rejected: usize = 0;
438
439    let safety = F::from_f64(0.9).unwrap_or_else(|| F::one());
440    let fac_min = F::from_f64(0.2).unwrap_or_else(|| F::one());
441    let fac_max = F::from_f64(5.0).unwrap_or_else(|| F::one());
442
443    while t < tf && n_steps < options.max_steps {
444        // Clamp step to not overshoot tf
445        if t + h > tf {
446            h = tf - t;
447        }
448
449        // Clamp step to next discontinuity
450        if let Some(t_disc) = disc_tracker.next_after(t) {
451            if t + h > t_disc {
452                h = t_disc - t;
453                // Record this discontinuity
454                let eps = F::from_f64(1e-12).unwrap_or_else(|| F::epsilon());
455                if (t + h - t_disc).abs() < eps {
456                    disc_times.push(t_disc);
457                }
458            }
459        }
460
461        if h < min_step {
462            break;
463        }
464
465        // RK45 step with delay evaluation
466        let step_result = rk45_dde_step(sys, &history, t, h, &y, n, options.rtol, options.atol)?;
467        n_eval += 6;
468
469        let err_norm = step_result.error_norm;
470
471        if err_norm <= F::one() {
472            // Accept step
473            let t_new = t + h;
474
475            // Compute derivative at old and new points for dense output
476            let yp_old = evaluate_rhs(sys, &history, t, &y)?;
477            let yp_new = evaluate_rhs(sys, &history, t_new, &step_result.y_new)?;
478
479            // Store dense output segment
480            history.add_segment(DenseSegment {
481                t_left: t,
482                t_right: t_new,
483                y_left: y.clone(),
484                y_right: step_result.y_new.clone(),
485                yp_left: yp_old,
486                yp_right: yp_new,
487            });
488
489            t = t_new;
490            y = step_result.y_new;
491            n_accepted += 1;
492
493            t_out.push(t);
494            y_out.push(y.clone());
495        } else {
496            n_rejected += 1;
497        }
498
499        // Adjust step size
500        let factor = if err_norm > F::zero() {
501            safety
502                * (F::one() / err_norm)
503                    .powf(F::one() / F::from_f64(5.0).unwrap_or_else(|| F::one()))
504        } else {
505            fac_max
506        };
507        let factor = factor.max(fac_min).min(fac_max);
508        h *= factor;
509
510        if let Some(max_h) = options.max_step {
511            h = h.min(max_h);
512        }
513
514        n_steps += 1;
515    }
516
517    Ok(DDEResult {
518        t: t_out,
519        y: y_out,
520        success: t >= tf - min_step,
521        message: if t >= tf - min_step {
522            Some("DDE integration completed successfully".into())
523        } else {
524            Some(format!("DDE integration stopped at t = {t}"))
525        },
526        n_eval,
527        n_steps,
528        n_accepted,
529        n_rejected,
530        discontinuities: disc_times,
531    })
532}
533
534// ---------------------------------------------------------------------------
535// RK45 step for DDE
536// ---------------------------------------------------------------------------
537
538struct RK45StepResult<F: IntegrateFloat> {
539    y_new: Array1<F>,
540    error_norm: F,
541}
542
543/// Perform one RK45 (Dormand-Prince) step with delay evaluation.
544fn rk45_dde_step<F: IntegrateFloat>(
545    sys: &dyn DDESystem<F>,
546    history: &HistoryBuffer<F>,
547    t: F,
548    h: F,
549    y: &Array1<F>,
550    n: usize,
551    rtol: F,
552    atol: F,
553) -> IntegrateResult<RK45StepResult<F>> {
554    // Dormand-Prince coefficients
555    let a21 = F::from_f64(1.0 / 5.0).unwrap_or_else(|| F::zero());
556    let a31 = F::from_f64(3.0 / 40.0).unwrap_or_else(|| F::zero());
557    let a32 = F::from_f64(9.0 / 40.0).unwrap_or_else(|| F::zero());
558    let a41 = F::from_f64(44.0 / 45.0).unwrap_or_else(|| F::zero());
559    let a42 = F::from_f64(-56.0 / 15.0).unwrap_or_else(|| F::zero());
560    let a43 = F::from_f64(32.0 / 9.0).unwrap_or_else(|| F::zero());
561    let a51 = F::from_f64(19372.0 / 6561.0).unwrap_or_else(|| F::zero());
562    let a52 = F::from_f64(-25360.0 / 2187.0).unwrap_or_else(|| F::zero());
563    let a53 = F::from_f64(64448.0 / 6561.0).unwrap_or_else(|| F::zero());
564    let a54 = F::from_f64(-212.0 / 729.0).unwrap_or_else(|| F::zero());
565    let a61 = F::from_f64(9017.0 / 3168.0).unwrap_or_else(|| F::zero());
566    let a62 = F::from_f64(-355.0 / 33.0).unwrap_or_else(|| F::zero());
567    let a63 = F::from_f64(46732.0 / 5247.0).unwrap_or_else(|| F::zero());
568    let a64 = F::from_f64(49.0 / 176.0).unwrap_or_else(|| F::zero());
569    let a65 = F::from_f64(-5103.0 / 18656.0).unwrap_or_else(|| F::zero());
570
571    // 5th order weights
572    let b1 = F::from_f64(35.0 / 384.0).unwrap_or_else(|| F::zero());
573    let b3 = F::from_f64(500.0 / 1113.0).unwrap_or_else(|| F::zero());
574    let b4 = F::from_f64(125.0 / 192.0).unwrap_or_else(|| F::zero());
575    let b5 = F::from_f64(-2187.0 / 6784.0).unwrap_or_else(|| F::zero());
576    let b6 = F::from_f64(11.0 / 84.0).unwrap_or_else(|| F::zero());
577
578    // 4th order weights (for error estimate)
579    let e1 = F::from_f64(71.0 / 57600.0).unwrap_or_else(|| F::zero());
580    let e3 = F::from_f64(-71.0 / 16695.0).unwrap_or_else(|| F::zero());
581    let e4 = F::from_f64(71.0 / 1920.0).unwrap_or_else(|| F::zero());
582    let e5 = F::from_f64(-17253.0 / 339200.0).unwrap_or_else(|| F::zero());
583    let e6 = F::from_f64(22.0 / 525.0).unwrap_or_else(|| F::zero());
584    let e7 = F::from_f64(-1.0 / 40.0).unwrap_or_else(|| F::zero());
585
586    // Node points
587    let c2 = F::from_f64(1.0 / 5.0).unwrap_or_else(|| F::zero());
588    let c3 = F::from_f64(3.0 / 10.0).unwrap_or_else(|| F::zero());
589    let c4 = F::from_f64(4.0 / 5.0).unwrap_or_else(|| F::zero());
590    let c5 = F::from_f64(8.0 / 9.0).unwrap_or_else(|| F::zero());
591
592    // Stage 1
593    let k1 = evaluate_rhs(sys, history, t, y)?;
594
595    // Stage 2
596    let y2 = y + &(&k1 * (h * a21));
597    let k2 = evaluate_rhs(sys, history, t + c2 * h, &y2)?;
598
599    // Stage 3
600    let y3 = y + &(&k1 * (h * a31) + &k2 * (h * a32));
601    let k3 = evaluate_rhs(sys, history, t + c3 * h, &y3)?;
602
603    // Stage 4
604    let y4 = y + &(&k1 * (h * a41) + &k2 * (h * a42) + &k3 * (h * a43));
605    let k4 = evaluate_rhs(sys, history, t + c4 * h, &y4)?;
606
607    // Stage 5
608    let y5 = y + &(&k1 * (h * a51) + &k2 * (h * a52) + &k3 * (h * a53) + &k4 * (h * a54));
609    let k5 = evaluate_rhs(sys, history, t + c5 * h, &y5)?;
610
611    // Stage 6
612    let y6 = y + &(&k1 * (h * a61)
613        + &k2 * (h * a62)
614        + &k3 * (h * a63)
615        + &k4 * (h * a64)
616        + &k5 * (h * a65));
617    let k6 = evaluate_rhs(sys, history, t + h, &y6)?;
618
619    // 5th order solution
620    let y_new =
621        y + &(&k1 * (h * b1) + &k3 * (h * b3) + &k4 * (h * b4) + &k5 * (h * b5) + &k6 * (h * b6));
622
623    // Stage 7 (for error estimate, FSAL property)
624    let k7 = evaluate_rhs(sys, history, t + h, &y_new)?;
625
626    // Error estimate
627    let err = &k1 * (h * e1)
628        + &k3 * (h * e3)
629        + &k4 * (h * e4)
630        + &k5 * (h * e5)
631        + &k6 * (h * e6)
632        + &k7 * (h * e7);
633
634    // Compute error norm
635    let mut sum = F::zero();
636    for i in 0..n {
637        let scale = atol + rtol * y[i].abs().max(y_new[i].abs());
638        let ratio = err[i] / scale;
639        sum += ratio * ratio;
640    }
641    let err_norm = (sum / F::from_usize(n).unwrap_or_else(|| F::one())).sqrt();
642
643    Ok(RK45StepResult {
644        y_new,
645        error_norm: err_norm,
646    })
647}
648
649/// Evaluate the DDE right-hand side at (t, y), looking up delayed values from history.
650fn evaluate_rhs<F: IntegrateFloat>(
651    sys: &dyn DDESystem<F>,
652    history: &HistoryBuffer<F>,
653    t: F,
654    y: &Array1<F>,
655) -> IntegrateResult<Array1<F>> {
656    // Get delays (possibly state-dependent)
657    let delays = sys.state_dependent_delays(t, y.view());
658
659    // Look up delayed values
660    let mut y_delayed = Vec::with_capacity(delays.len());
661    for tau in &delays {
662        let t_delayed = t - *tau;
663        let y_del = history.evaluate(t_delayed)?;
664        y_delayed.push(y_del);
665    }
666
667    sys.rhs(t, y.view(), &y_delayed)
668}
669
670// ---------------------------------------------------------------------------
671// Convenience: Simple DDE (single constant delay)
672// ---------------------------------------------------------------------------
673
674/// A simple DDE system with a single constant delay.
675///
676/// y'(t) = f(t, y(t), y(t - tau))
677pub struct SimpleConstantDDE<F: IntegrateFloat> {
678    ndim: usize,
679    delay: F,
680    rhs_fn: Box<dyn Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F> + Send + Sync>,
681    history_fn: Box<dyn Fn(F) -> Array1<F> + Send + Sync>,
682}
683
684impl<F: IntegrateFloat> SimpleConstantDDE<F> {
685    /// Create a simple DDE with a single constant delay.
686    ///
687    /// * `ndim` - system dimension
688    /// * `delay` - the constant delay tau > 0
689    /// * `rhs_fn` - f(t, y, y_delayed)
690    /// * `history_fn` - phi(t) for t <= t0
691    pub fn new<R, H>(ndim: usize, delay: F, rhs_fn: R, history_fn: H) -> Self
692    where
693        R: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F> + Send + Sync + 'static,
694        H: Fn(F) -> Array1<F> + Send + Sync + 'static,
695    {
696        SimpleConstantDDE {
697            ndim,
698            delay,
699            rhs_fn: Box::new(rhs_fn),
700            history_fn: Box::new(history_fn),
701        }
702    }
703}
704
705impl<F: IntegrateFloat> DDESystem<F> for SimpleConstantDDE<F> {
706    fn ndim(&self) -> usize {
707        self.ndim
708    }
709
710    fn rhs(&self, t: F, y: ArrayView1<F>, y_delayed: &[Array1<F>]) -> IntegrateResult<Array1<F>> {
711        if y_delayed.is_empty() {
712            return Err(IntegrateError::ValueError(
713                "Expected at least one delayed value".into(),
714            ));
715        }
716        Ok((self.rhs_fn)(t, y, y_delayed[0].view()))
717    }
718
719    fn delays(&self) -> Vec<F> {
720        vec![self.delay]
721    }
722
723    fn history(&self, t: F) -> Array1<F> {
724        (self.history_fn)(t)
725    }
726}
727
728/// A DDE system with multiple constant delays.
729pub struct MultiDelayDDE<F: IntegrateFloat> {
730    ndim: usize,
731    delays_vec: Vec<F>,
732    rhs_fn: Box<dyn Fn(F, ArrayView1<F>, &[Array1<F>]) -> Array1<F> + Send + Sync>,
733    history_fn: Box<dyn Fn(F) -> Array1<F> + Send + Sync>,
734}
735
736impl<F: IntegrateFloat> MultiDelayDDE<F> {
737    /// Create a DDE with multiple constant delays.
738    pub fn new<R, H>(ndim: usize, delays_vec: Vec<F>, rhs_fn: R, history_fn: H) -> Self
739    where
740        R: Fn(F, ArrayView1<F>, &[Array1<F>]) -> Array1<F> + Send + Sync + 'static,
741        H: Fn(F) -> Array1<F> + Send + Sync + 'static,
742    {
743        MultiDelayDDE {
744            ndim,
745            delays_vec,
746            rhs_fn: Box::new(rhs_fn),
747            history_fn: Box::new(history_fn),
748        }
749    }
750}
751
752impl<F: IntegrateFloat> DDESystem<F> for MultiDelayDDE<F> {
753    fn ndim(&self) -> usize {
754        self.ndim
755    }
756
757    fn rhs(&self, t: F, y: ArrayView1<F>, y_delayed: &[Array1<F>]) -> IntegrateResult<Array1<F>> {
758        Ok((self.rhs_fn)(t, y, y_delayed))
759    }
760
761    fn delays(&self) -> Vec<F> {
762        self.delays_vec.clone()
763    }
764
765    fn history(&self, t: F) -> Array1<F> {
766        (self.history_fn)(t)
767    }
768}
769
770/// A DDE system with state-dependent delays.
771pub struct StateDependentDDE<F: IntegrateFloat> {
772    ndim: usize,
773    n_delays: usize,
774    rhs_fn: Box<dyn Fn(F, ArrayView1<F>, &[Array1<F>]) -> Array1<F> + Send + Sync>,
775    delay_fn: Box<dyn Fn(F, ArrayView1<F>) -> Vec<F> + Send + Sync>,
776    history_fn: Box<dyn Fn(F) -> Array1<F> + Send + Sync>,
777}
778
779impl<F: IntegrateFloat> StateDependentDDE<F> {
780    /// Create a DDE with state-dependent delays.
781    ///
782    /// * `delay_fn` - function (t, y) -> vec of delay values
783    pub fn new<R, D, H>(ndim: usize, n_delays: usize, rhs_fn: R, delay_fn: D, history_fn: H) -> Self
784    where
785        R: Fn(F, ArrayView1<F>, &[Array1<F>]) -> Array1<F> + Send + Sync + 'static,
786        D: Fn(F, ArrayView1<F>) -> Vec<F> + Send + Sync + 'static,
787        H: Fn(F) -> Array1<F> + Send + Sync + 'static,
788    {
789        StateDependentDDE {
790            ndim,
791            n_delays,
792            rhs_fn: Box::new(rhs_fn),
793            delay_fn: Box::new(delay_fn),
794            history_fn: Box::new(history_fn),
795        }
796    }
797}
798
799impl<F: IntegrateFloat> DDESystem<F> for StateDependentDDE<F> {
800    fn ndim(&self) -> usize {
801        self.ndim
802    }
803
804    fn rhs(&self, t: F, y: ArrayView1<F>, y_delayed: &[Array1<F>]) -> IntegrateResult<Array1<F>> {
805        Ok((self.rhs_fn)(t, y, y_delayed))
806    }
807
808    fn delays(&self) -> Vec<F> {
809        // Return dummy delays; actual delays come from state_dependent_delays
810        vec![F::one(); self.n_delays]
811    }
812
813    fn state_dependent_delays(&self, t: F, y: ArrayView1<F>) -> Vec<F> {
814        (self.delay_fn)(t, y)
815    }
816
817    fn history(&self, t: F) -> Array1<F> {
818        (self.history_fn)(t)
819    }
820}
821
822// ---------------------------------------------------------------------------
823// Tests
824// ---------------------------------------------------------------------------
825
826#[cfg(test)]
827mod tests {
828    use super::*;
829    use scirs2_core::ndarray::array;
830
831    #[test]
832    fn test_simple_constant_delay() {
833        // y'(t) = -y(t - 1), y(t) = 1 for t <= 0
834        // On [0, 1]: y'(t) = -1, so y(t) = 1 - t
835        // At t=1: y(1) = 0
836        let sys = SimpleConstantDDE::new(
837            1,
838            1.0,
839            |_t, _y: ArrayView1<f64>, y_del: ArrayView1<f64>| -> Array1<f64> { array![-y_del[0]] },
840            |_t| array![1.0],
841        );
842
843        let opts = DDEOptions {
844            h0: Some(0.01),
845            rtol: 1e-6,
846            atol: 1e-9,
847            ..Default::default()
848        };
849
850        let result = solve_dde(&sys, [0.0, 1.0], &opts).expect("DDE solve should succeed");
851
852        assert!(result.success, "DDE integration should succeed");
853
854        // Check y(1) ~ 0
855        let y_final = result.y.last().expect("should have final state");
856        assert!(
857            (y_final[0] - 0.0).abs() < 0.05,
858            "y(1) = {} should be near 0",
859            y_final[0]
860        );
861    }
862
863    #[test]
864    fn test_dde_first_interval_exact() {
865        // y'(t) = -y(t-1), phi(t) = 1
866        // On [0,1], y' = -phi(t-1) = -1
867        // => y(t) = 1 - t on [0,1]
868        let sys = SimpleConstantDDE::new(
869            1,
870            1.0,
871            |_t, _y: ArrayView1<f64>, y_del: ArrayView1<f64>| -> Array1<f64> { array![-y_del[0]] },
872            |_t| array![1.0],
873        );
874
875        let opts = DDEOptions {
876            h0: Some(0.005),
877            rtol: 1e-8,
878            atol: 1e-12,
879            ..Default::default()
880        };
881
882        let result = solve_dde(&sys, [0.0, 0.5], &opts).expect("DDE solve should succeed");
883
884        // y(0.5) = 1 - 0.5 = 0.5
885        let y_final = result.y.last().expect("should have final state");
886        assert!(
887            (y_final[0] - 0.5).abs() < 0.01,
888            "y(0.5) = {} should be near 0.5",
889            y_final[0]
890        );
891    }
892
893    #[test]
894    fn test_multi_delay_dde() {
895        // y'(t) = -y(t-0.5) - y(t-1.0), phi(t) = 1
896        // On [0, 0.5], y' = -1 - 1 = -2, so y(t) = 1 - 2t
897        let sys = MultiDelayDDE::new(
898            1,
899            vec![0.5, 1.0],
900            |_t, _y: ArrayView1<f64>, y_del: &[Array1<f64>]| -> Array1<f64> {
901                array![-y_del[0][0] - y_del[1][0]]
902            },
903            |_t| array![1.0],
904        );
905
906        let opts = DDEOptions {
907            h0: Some(0.005),
908            rtol: 1e-8,
909            atol: 1e-12,
910            ..Default::default()
911        };
912
913        let result = solve_dde(&sys, [0.0, 0.5], &opts).expect("multi-delay DDE should succeed");
914
915        // y(0.5) = 1 - 2*0.5 = 0
916        let y_final = result.y.last().expect("should have final state");
917        assert!(
918            (y_final[0] - 0.0).abs() < 0.05,
919            "y(0.5) = {} should be near 0",
920            y_final[0]
921        );
922    }
923
924    #[test]
925    fn test_discontinuity_tracking() {
926        let sys = SimpleConstantDDE::new(
927            1,
928            0.5,
929            |_t, _y: ArrayView1<f64>, y_del: ArrayView1<f64>| -> Array1<f64> { array![-y_del[0]] },
930            |_t| array![1.0],
931        );
932
933        let opts = DDEOptions {
934            h0: Some(0.01),
935            track_discontinuities: true,
936            max_discontinuity_order: 3,
937            ..Default::default()
938        };
939
940        let result = solve_dde(&sys, [0.0, 2.0], &opts).expect("DDE with disc tracking");
941        assert!(result.success);
942
943        // Discontinuity tracker should have seeded times at 0.5, 1.0, 1.5, 2.0
944        // (though some may not be hit exactly)
945    }
946
947    #[test]
948    fn test_state_dependent_delay() {
949        // y'(t) = -y(t - |y(t)|), phi(t) = 1
950        // When y near 1, delay is 1.
951        let sys = StateDependentDDE::new(
952            1,
953            1,
954            |_t, _y: ArrayView1<f64>, y_del: &[Array1<f64>]| -> Array1<f64> {
955                array![-y_del[0][0]]
956            },
957            |_t, y: ArrayView1<f64>| -> Vec<f64> {
958                vec![y[0].abs().max(0.1)] // clamp delay to at least 0.1
959            },
960            |_t| array![1.0],
961        );
962
963        let opts = DDEOptions {
964            h0: Some(0.005),
965            rtol: 1e-5,
966            atol: 1e-8,
967            ..Default::default()
968        };
969
970        let result = solve_dde(&sys, [0.0, 0.5], &opts).expect("state-dep DDE should succeed");
971        assert!(result.success);
972        // Just verify it doesn't crash and produces reasonable output
973        assert!(result.y.len() > 2);
974    }
975
976    #[test]
977    fn test_dde_invalid_inputs() {
978        let sys = SimpleConstantDDE::new(
979            1,
980            1.0,
981            |_t, _y: ArrayView1<f64>, y_del: ArrayView1<f64>| -> Array1<f64> { array![-y_del[0]] },
982            |_t| array![1.0],
983        );
984
985        let opts = DDEOptions::default();
986
987        // tf <= t0
988        let result = solve_dde(&sys, [1.0, 0.0], &opts);
989        assert!(result.is_err());
990    }
991
992    #[test]
993    fn test_hermite_interpolation() {
994        // Test the dense output segment interpolation
995        let seg = DenseSegment {
996            t_left: 0.0,
997            t_right: 1.0,
998            y_left: array![1.0],
999            y_right: array![0.0],
1000            yp_left: array![-1.0],  // derivative at left
1001            yp_right: array![-1.0], // derivative at right
1002        };
1003
1004        // For a linear function y = 1 - t, the Hermite interpolant should be exact
1005        let y_mid = seg.evaluate(0.5_f64);
1006        assert!(
1007            (y_mid[0] - 0.5_f64).abs() < 1e-10,
1008            "Hermite at 0.5: {}",
1009            y_mid[0]
1010        );
1011
1012        // Check endpoints
1013        let y0 = seg.evaluate(0.0_f64);
1014        assert!((y0[0] - 1.0_f64).abs() < 1e-10);
1015
1016        let y1 = seg.evaluate(1.0_f64);
1017        assert!((y1[0] - 0.0_f64).abs() < 1e-10);
1018    }
1019
1020    #[test]
1021    fn test_history_buffer() {
1022        let mut buf = HistoryBuffer::new(0.0);
1023        buf.add_pre_history(-1.0, array![2.0]);
1024        buf.add_pre_history(-0.5, array![1.5]);
1025        buf.add_pre_history(0.0, array![1.0]);
1026
1027        // Interpolate in pre-history
1028        let y = buf.evaluate(-0.75).expect("pre-history eval");
1029        assert!((y[0] - 1.75_f64).abs() < 1e-10, "y(-0.75) = {}", y[0]);
1030
1031        // Add a segment
1032        buf.add_segment(DenseSegment {
1033            t_left: 0.0,
1034            t_right: 0.5,
1035            y_left: array![1.0],
1036            y_right: array![0.5],
1037            yp_left: array![-1.0],
1038            yp_right: array![-1.0],
1039        });
1040
1041        let y = buf.evaluate(0.25).expect("segment eval");
1042        assert!((y[0] - 0.75_f64).abs() < 0.1, "y(0.25) = {}", y[0]);
1043    }
1044}