Skip to main content

scirs2_integrate/sde/
mod.rs

//! Stochastic Differential Equation (SDE) solvers
//!
//! This module provides numerical methods for solving SDEs of the form:
//!
//! ```text
//! dx = f(t, x) dt + g(t, x) dW
//! ```
//!
//! where:
//! - `f(t, x)` is the drift coefficient (deterministic part)
//! - `g(t, x)` is the diffusion matrix (stochastic part)
//! - `dW` is an increment of a Wiener process (Brownian motion)
//!
//! ## Methods
//!
//! | Method | Module | Strong Order | Weak Order |
//! |--------|--------|-------------|------------|
//! | Euler-Maruyama | `euler_maruyama` | 0.5 | 1.0 |
//! | Milstein | `milstein` | 1.0 | 1.0 |
//! | Stochastic Runge-Kutta | `runge_kutta_sde` | 1.5 | 2.0 |
//! | Platen explicit | `runge_kutta_sde` | 1.5 | 2.0 |
//!
//! ## Quick Start
//!
//! ```rust
//! use scirs2_integrate::sde::{SdeProblem, SdeOptions};
//! use scirs2_integrate::sde::euler_maruyama::euler_maruyama;
//! use scirs2_core::ndarray::{array, Array1, Array2};
//! use scirs2_core::random::prelude::*;
//!
//! // Geometric Brownian Motion: dX = μ X dt + σ X dW
//! let mu = 0.05_f64;
//! let sigma = 0.2_f64;
//! let x0 = array![100.0_f64];
//!
//! let drift = move |_t: f64, x: &Array1<f64>| -> Array1<f64> {
//!     array![mu * x[0]]
//! };
//! let diffusion = move |_t: f64, x: &Array1<f64>| -> Array2<f64> {
//!     let mut g = Array2::zeros((1, 1));
//!     g[[0, 0]] = sigma * x[0];
//!     g
//! };
//!
//! let prob = SdeProblem::new(x0, [0.0, 1.0], 1, drift, diffusion);
//! let mut rng = seeded_rng(42);
//! let sol = euler_maruyama(&prob, 0.01, &mut rng).unwrap();
//! assert!(!sol.t.is_empty());
//! ```

51pub mod euler_maruyama;
52pub mod examples;
53pub mod fractional_brownian;
54pub mod jump_diffusion;
55pub mod levy_area;
56pub mod milstein;
57pub mod particle_filter;
58pub mod processes;
59pub mod rough_sde;
60pub mod runge_kutta_sde;
61pub mod srk;
62pub mod streaming_particle_filter;
63pub mod weak_order2;
64pub mod weak_schemes;
65
66pub use levy_area::{iterated_integral, levy_area_wiktorsson};
67pub use streaming_particle_filter::{
68    FilterEstimate, SimpleRng, StreamingParticleFilter, StreamingParticleFilterBuilder,
69};
70
71use crate::error::{IntegrateError, IntegrateResult};
72use scirs2_core::ndarray::{Array1, Array2};
73
74/// Defines a Stochastic Differential Equation problem of the form:
75///
76/// ```text
77/// dx = f(t, x) dt + g(t, x) dW,   x(t0) = x0
78/// ```
79///
80/// where `f` is the drift coefficient, `g` is the diffusion matrix,
81/// and `dW` is an m-dimensional Wiener process increment.
82///
83/// # Type Parameters
84///
85/// * `F` - drift function type: `Fn(f64, &Array1<f64>) -> Array1<f64>`
86/// * `G` - diffusion function type: `Fn(f64, &Array1<f64>) -> Array2<f64>`
87pub struct SdeProblem<F, G>
88where
89    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
90    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
91{
92    /// Initial state vector x(t0), dimension n
93    pub x0: Array1<f64>,
94    /// Time span [t0, t1]
95    pub t_span: [f64; 2],
96    /// Number of independent Brownian motions (Wiener processes), m
97    pub n_brownian: usize,
98    /// Drift coefficient f(t, x): R × R^n → R^n
99    pub f_drift: F,
100    /// Diffusion matrix g(t, x): R × R^n → R^{n×m}
101    pub g_diffusion: G,
102}
103
104impl<F, G> SdeProblem<F, G>
105where
106    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
107    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
108{
109    /// Create a new SDE problem.
110    ///
111    /// # Arguments
112    ///
113    /// * `x0` - Initial state vector (length n)
114    /// * `t_span` - Time interval [t0, t1]
115    /// * `n_brownian` - Number of independent Brownian motions
116    /// * `f_drift` - Drift function f(t, x) → R^n
117    /// * `g_diffusion` - Diffusion function g(t, x) → R^{n×m}
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if the time span is invalid (t0 >= t1).
122    pub fn new(
123        x0: Array1<f64>,
124        t_span: [f64; 2],
125        n_brownian: usize,
126        f_drift: F,
127        g_diffusion: G,
128    ) -> Self {
129        Self {
130            x0,
131            t_span,
132            n_brownian,
133            f_drift,
134            g_diffusion,
135        }
136    }
137
138    /// Dimension of the state space (n)
139    pub fn dim(&self) -> usize {
140        self.x0.len()
141    }
142
143    /// Validate the problem parameters
144    pub fn validate(&self) -> IntegrateResult<()> {
145        if self.t_span[0] >= self.t_span[1] {
146            return Err(IntegrateError::InvalidInput(format!(
147                "t_span must satisfy t0 < t1, got [{}, {}]",
148                self.t_span[0], self.t_span[1]
149            )));
150        }
151        if self.n_brownian == 0 {
152            return Err(IntegrateError::InvalidInput(
153                "n_brownian must be at least 1".to_string(),
154            ));
155        }
156        if self.x0.is_empty() {
157            return Err(IntegrateError::InvalidInput(
158                "Initial state x0 must be non-empty".to_string(),
159            ));
160        }
161        Ok(())
162    }
163}
164
165/// Solution to an SDE, containing the time points and state trajectories.
166#[derive(Debug, Clone)]
167pub struct SdeSolution {
168    /// Time points t_0, t_1, ..., t_N
169    pub t: Vec<f64>,
170    /// State trajectory x(t_0), x(t_1), ..., x(t_N)
171    pub x: Vec<Array1<f64>>,
172}
173
174impl SdeSolution {
175    /// Create a new empty solution with pre-allocated capacity.
176    pub fn with_capacity(n: usize) -> Self {
177        Self {
178            t: Vec::with_capacity(n),
179            x: Vec::with_capacity(n),
180        }
181    }
182
183    /// Push a new time-state pair.
184    pub fn push(&mut self, t: f64, x: Array1<f64>) {
185        self.t.push(t);
186        self.x.push(x);
187    }
188
189    /// Number of time points in the solution.
190    pub fn len(&self) -> usize {
191        self.t.len()
192    }
193
194    /// Returns true if the solution is empty.
195    pub fn is_empty(&self) -> bool {
196        self.t.is_empty()
197    }
198
199    /// Final time value.
200    pub fn t_final(&self) -> Option<f64> {
201        self.t.last().copied()
202    }
203
204    /// Final state.
205    pub fn x_final(&self) -> Option<&Array1<f64>> {
206        self.x.last()
207    }
208
209    /// Compute the mean trajectory across an ensemble of solutions.
210    ///
211    /// All solutions must have the same time points and state dimensions.
212    pub fn ensemble_mean(solutions: &[SdeSolution]) -> IntegrateResult<SdeSolution> {
213        if solutions.is_empty() {
214            return Err(IntegrateError::InvalidInput(
215                "Cannot compute mean of empty ensemble".to_string(),
216            ));
217        }
218        let n_steps = solutions[0].len();
219        let n_ensemble = solutions.len();
220        let mut result = SdeSolution::with_capacity(n_steps);
221
222        for step in 0..n_steps {
223            let t = solutions[0].t[step];
224            let dim = solutions[0].x[step].len();
225            let mut mean_x = Array1::zeros(dim);
226            for sol in solutions {
227                if sol.len() != n_steps {
228                    return Err(IntegrateError::DimensionMismatch(
229                        "All solutions in ensemble must have the same number of steps".to_string(),
230                    ));
231                }
232                mean_x += &sol.x[step];
233            }
234            mean_x /= n_ensemble as f64;
235            result.push(t, mean_x);
236        }
237        Ok(result)
238    }
239
240    /// Compute the variance trajectory across an ensemble of solutions.
241    pub fn ensemble_variance(solutions: &[SdeSolution]) -> IntegrateResult<SdeSolution> {
242        if solutions.is_empty() {
243            return Err(IntegrateError::InvalidInput(
244                "Cannot compute variance of empty ensemble".to_string(),
245            ));
246        }
247        let n_steps = solutions[0].len();
248        let n_ensemble = solutions.len();
249        if n_ensemble < 2 {
250            return Err(IntegrateError::InvalidInput(
251                "Need at least 2 solutions to compute variance".to_string(),
252            ));
253        }
254        let mean_sol = Self::ensemble_mean(solutions)?;
255        let mut result = SdeSolution::with_capacity(n_steps);
256
257        for step in 0..n_steps {
258            let t = solutions[0].t[step];
259            let dim = solutions[0].x[step].len();
260            let mut var_x = Array1::zeros(dim);
261            for sol in solutions {
262                let diff = &sol.x[step] - &mean_sol.x[step];
263                var_x += &diff.mapv(|v| v * v);
264            }
265            var_x /= (n_ensemble - 1) as f64;
266            result.push(t, var_x);
267        }
268        Ok(result)
269    }
270}
271
/// Configuration shared by the fixed-step SDE solvers.
///
/// Start from [`SdeOptions::default`] and override individual fields with
/// struct-update syntax, e.g. `SdeOptions { save_all_steps: false, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct SdeOptions {
    /// Keep the state at every step (`true`) or only the state at the final time (`false`).
    pub save_all_steps: bool,
    /// Upper bound on the number of steps a solve may take; a safety limit.
    pub max_steps: usize,
}

impl Default for SdeOptions {
    /// Save every step and allow at most ten million steps.
    fn default() -> Self {
        let max_steps = 10_000_000;
        let save_all_steps = true;
        Self {
            save_all_steps,
            max_steps,
        }
    }
}

290/// Compute the number of steps needed for a given time span and step size,
291/// clamping to `max_steps` as a safety limit.
292pub(crate) fn compute_n_steps(
293    t0: f64,
294    t1: f64,
295    dt: f64,
296    max_steps: usize,
297) -> IntegrateResult<usize> {
298    if dt <= 0.0 {
299        return Err(IntegrateError::InvalidInput(format!(
300            "Step size dt must be positive, got {}",
301            dt
302        )));
303    }
304    let n = ((t1 - t0) / dt).ceil() as usize;
305    if n > max_steps {
306        return Err(IntegrateError::InvalidInput(format!(
307            "Required steps {} exceeds maximum {}",
308            n, max_steps
309        )));
310    }
311    Ok(n.max(1))
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Build a solution with one-dimensional states from `(t, x)` pairs.
    fn scalar_solution(points: &[(f64, f64)]) -> SdeSolution {
        let mut sol = SdeSolution::with_capacity(points.len());
        for &(t, x) in points {
            sol.push(t, array![x]);
        }
        sol
    }

    #[test]
    fn test_sde_problem_creation() {
        let x0 = array![1.0_f64];
        let prob = SdeProblem::new(
            x0,
            [0.0, 1.0],
            1,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert_eq!(prob.dim(), 1);
        assert_eq!(prob.n_brownian, 1);
        prob.validate().expect("Validation should pass");
    }

    #[test]
    fn test_sde_problem_invalid_tspan() {
        let x0 = array![1.0_f64];
        let prob = SdeProblem::new(
            x0,
            [1.0, 0.0], // t0 > t1 is invalid
            1,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert!(prob.validate().is_err());
    }

    #[test]
    fn test_sde_problem_rejects_zero_brownian_motions() {
        let prob = SdeProblem::new(
            array![1.0_f64],
            [0.0, 1.0],
            0,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert!(prob.validate().is_err());
    }

    #[test]
    fn test_sde_problem_rejects_empty_state() {
        let prob = SdeProblem::new(
            Array1::<f64>::zeros(0),
            [0.0, 1.0],
            1,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert_eq!(prob.dim(), 0);
        assert!(prob.validate().is_err());
    }

    #[test]
    fn test_sde_solution_push_and_query() {
        let sol = scalar_solution(&[(0.0, 1.0), (0.5, 1.1), (1.0, 1.2)]);
        assert_eq!(sol.len(), 3);
        assert!(!sol.is_empty());
        assert!((sol.t_final().expect("solution has time steps") - 1.0).abs() < 1e-12);
        assert!((sol.x_final().expect("solution has state")[0] - 1.2).abs() < 1e-12);
    }

    #[test]
    fn test_ensemble_mean() {
        let sol1 = scalar_solution(&[(0.0, 1.0), (1.0, 2.0)]);
        let sol2 = scalar_solution(&[(0.0, 1.0), (1.0, 4.0)]);
        let mean =
            SdeSolution::ensemble_mean(&[sol1, sol2]).expect("ensemble_mean should succeed");
        assert!((mean.x[1][0] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_ensemble_mean_rejects_bad_input() {
        assert!(SdeSolution::ensemble_mean(&[]).is_err());
        let long = scalar_solution(&[(0.0, 1.0), (1.0, 2.0)]);
        let short = scalar_solution(&[(0.0, 1.0)]);
        assert!(SdeSolution::ensemble_mean(&[long, short]).is_err());
    }

    #[test]
    fn test_ensemble_variance() {
        let sol1 = scalar_solution(&[(0.0, 1.0), (1.0, 2.0)]);
        let sol2 = scalar_solution(&[(0.0, 1.0), (1.0, 4.0)]);
        assert!(SdeSolution::ensemble_variance(&[sol1.clone()]).is_err());
        let var = SdeSolution::ensemble_variance(&[sol1, sol2])
            .expect("ensemble_variance should succeed");
        assert_eq!(var.len(), 2);
        assert!(var.x[0][0].abs() < 1e-12);
        // Sample variance of {2, 4} with the n - 1 denominator is 2.
        assert!((var.x[1][0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_compute_n_steps() {
        let n = compute_n_steps(0.0, 1.0, 0.1, 1000).expect("compute_n_steps should succeed");
        assert_eq!(n, 10);
        // A zero-length span still takes one step.
        let n = compute_n_steps(0.0, 0.0, 0.1, 1000).expect("compute_n_steps should succeed");
        assert_eq!(n, 1);
    }

    #[test]
    fn test_compute_n_steps_invalid_dt() {
        assert!(compute_n_steps(0.0, 1.0, -0.1, 1000).is_err());
        assert!(compute_n_steps(0.0, 1.0, 0.0, 1000).is_err());
    }

    #[test]
    fn test_compute_n_steps_exceeds_max() {
        assert!(compute_n_steps(0.0, 1.0, 0.001, 10).is_err());
    }
}