1pub mod euler_maruyama;
52pub mod examples;
53pub mod fractional_brownian;
54pub mod jump_diffusion;
55pub mod levy_area;
56pub mod milstein;
57pub mod particle_filter;
58pub mod processes;
59pub mod rough_sde;
60pub mod runge_kutta_sde;
61pub mod srk;
62pub mod streaming_particle_filter;
63pub mod weak_order2;
64pub mod weak_schemes;
65
66pub use levy_area::{iterated_integral, levy_area_wiktorsson};
67pub use streaming_particle_filter::{
68 FilterEstimate, SimpleRng, StreamingParticleFilter, StreamingParticleFilterBuilder,
69};
70
71use crate::error::{IntegrateError, IntegrateResult};
72use scirs2_core::ndarray::{Array1, Array2};
73
/// A stochastic differential equation with drift `f_drift` and diffusion `g_diffusion`,
/// i.e. `dX = f(t, X) dt + g(t, X) dW`, to be integrated over `t_span` from `x0`.
///
/// The stochastic-integral convention (Itô or Stratonovich) is determined by the solver
/// that consumes the problem, not by this type.
pub struct SdeProblem<F, G>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    /// Initial state at `t_span[0]`; its length is the state dimension.
    pub x0: Array1<f64>,
    /// Integration interval `[t0, t1]`; `validate` requires `t0 < t1`.
    pub t_span: [f64; 2],
    /// Number of independent Brownian motions driving the system (must be at least 1).
    pub n_brownian: usize,
    /// Drift function `f(t, x)`, returning a vector.
    pub f_drift: F,
    /// Diffusion function `g(t, x)`, returning a matrix.
    /// NOTE(review): presumably of shape `(dim, n_brownian)`; this is not checked here.
    pub g_diffusion: G,
}
103
104impl<F, G> SdeProblem<F, G>
105where
106 F: Fn(f64, &Array1<f64>) -> Array1<f64>,
107 G: Fn(f64, &Array1<f64>) -> Array2<f64>,
108{
109 pub fn new(
123 x0: Array1<f64>,
124 t_span: [f64; 2],
125 n_brownian: usize,
126 f_drift: F,
127 g_diffusion: G,
128 ) -> Self {
129 Self {
130 x0,
131 t_span,
132 n_brownian,
133 f_drift,
134 g_diffusion,
135 }
136 }
137
138 pub fn dim(&self) -> usize {
140 self.x0.len()
141 }
142
143 pub fn validate(&self) -> IntegrateResult<()> {
145 if self.t_span[0] >= self.t_span[1] {
146 return Err(IntegrateError::InvalidInput(format!(
147 "t_span must satisfy t0 < t1, got [{}, {}]",
148 self.t_span[0], self.t_span[1]
149 )));
150 }
151 if self.n_brownian == 0 {
152 return Err(IntegrateError::InvalidInput(
153 "n_brownian must be at least 1".to_string(),
154 ));
155 }
156 if self.x0.is_empty() {
157 return Err(IntegrateError::InvalidInput(
158 "Initial state x0 must be non-empty".to_string(),
159 ));
160 }
161 Ok(())
162 }
163}
164
/// A discretely sampled trajectory: parallel vectors of sample times and states.
///
/// The fields are public, so the pairing `x[i]` <-> `t[i]` is only guaranteed when the
/// solution is filled through `push`.
#[derive(Debug, Clone)]
pub struct SdeSolution {
    /// Sample times, in the order they were recorded.
    pub t: Vec<f64>,
    /// State at each sample time; `x[i]` corresponds to `t[i]`.
    pub x: Vec<Array1<f64>>,
}
173
174impl SdeSolution {
175 pub fn with_capacity(n: usize) -> Self {
177 Self {
178 t: Vec::with_capacity(n),
179 x: Vec::with_capacity(n),
180 }
181 }
182
183 pub fn push(&mut self, t: f64, x: Array1<f64>) {
185 self.t.push(t);
186 self.x.push(x);
187 }
188
189 pub fn len(&self) -> usize {
191 self.t.len()
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.t.is_empty()
197 }
198
199 pub fn t_final(&self) -> Option<f64> {
201 self.t.last().copied()
202 }
203
204 pub fn x_final(&self) -> Option<&Array1<f64>> {
206 self.x.last()
207 }
208
209 pub fn ensemble_mean(solutions: &[SdeSolution]) -> IntegrateResult<SdeSolution> {
213 if solutions.is_empty() {
214 return Err(IntegrateError::InvalidInput(
215 "Cannot compute mean of empty ensemble".to_string(),
216 ));
217 }
218 let n_steps = solutions[0].len();
219 let n_ensemble = solutions.len();
220 let mut result = SdeSolution::with_capacity(n_steps);
221
222 for step in 0..n_steps {
223 let t = solutions[0].t[step];
224 let dim = solutions[0].x[step].len();
225 let mut mean_x = Array1::zeros(dim);
226 for sol in solutions {
227 if sol.len() != n_steps {
228 return Err(IntegrateError::DimensionMismatch(
229 "All solutions in ensemble must have the same number of steps".to_string(),
230 ));
231 }
232 mean_x += &sol.x[step];
233 }
234 mean_x /= n_ensemble as f64;
235 result.push(t, mean_x);
236 }
237 Ok(result)
238 }
239
240 pub fn ensemble_variance(solutions: &[SdeSolution]) -> IntegrateResult<SdeSolution> {
242 if solutions.is_empty() {
243 return Err(IntegrateError::InvalidInput(
244 "Cannot compute variance of empty ensemble".to_string(),
245 ));
246 }
247 let n_steps = solutions[0].len();
248 let n_ensemble = solutions.len();
249 if n_ensemble < 2 {
250 return Err(IntegrateError::InvalidInput(
251 "Need at least 2 solutions to compute variance".to_string(),
252 ));
253 }
254 let mean_sol = Self::ensemble_mean(solutions)?;
255 let mut result = SdeSolution::with_capacity(n_steps);
256
257 for step in 0..n_steps {
258 let t = solutions[0].t[step];
259 let dim = solutions[0].x[step].len();
260 let mut var_x = Array1::zeros(dim);
261 for sol in solutions {
262 let diff = &sol.x[step] - &mean_sol.x[step];
263 var_x += &diff.mapv(|v| v * v);
264 }
265 var_x /= (n_ensemble - 1) as f64;
266 result.push(t, var_x);
267 }
268 Ok(result)
269 }
270}
271
/// Tuning knobs shared by the SDE integrators.
#[derive(Debug, Clone)]
pub struct SdeOptions {
    /// Whether every integration step should be recorded.
    /// NOTE(review): the exact recording behavior when `false` is decided by the solvers.
    pub save_all_steps: bool,
    /// Upper bound on the number of steps an integration may take.
    pub max_steps: usize,
}

impl Default for SdeOptions {
    /// Records every step and allows at most ten million steps.
    fn default() -> Self {
        const DEFAULT_MAX_STEPS: usize = 10_000_000;
        SdeOptions {
            max_steps: DEFAULT_MAX_STEPS,
            save_all_steps: true,
        }
    }
}
289
290pub(crate) fn compute_n_steps(
293 t0: f64,
294 t1: f64,
295 dt: f64,
296 max_steps: usize,
297) -> IntegrateResult<usize> {
298 if dt <= 0.0 {
299 return Err(IntegrateError::InvalidInput(format!(
300 "Step size dt must be positive, got {}",
301 dt
302 )));
303 }
304 let n = ((t1 - t0) / dt).ceil() as usize;
305 if n > max_steps {
306 return Err(IntegrateError::InvalidInput(format!(
307 "Required steps {} exceeds maximum {}",
308 n, max_steps
309 )));
310 }
311 Ok(n.max(1))
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Builds a solution with one-dimensional states from `(t, x)` pairs.
    fn scalar_path(samples: &[(f64, f64)]) -> SdeSolution {
        let mut sol = SdeSolution::with_capacity(samples.len());
        for &(t, x) in samples {
            sol.push(t, array![x]);
        }
        sol
    }

    #[test]
    fn test_sde_problem_creation() {
        let x0 = array![1.0_f64];
        let prob = SdeProblem::new(
            x0,
            [0.0, 1.0],
            1,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert_eq!(prob.dim(), 1);
        assert_eq!(prob.n_brownian, 1);
        prob.validate().expect("Validation should pass");
    }

    #[test]
    fn test_sde_problem_invalid_tspan() {
        let x0 = array![1.0_f64];
        let prob = SdeProblem::new(
            x0,
            [1.0, 0.0],
            1,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert!(prob.validate().is_err());
    }

    #[test]
    fn test_sde_problem_rejects_zero_brownian_and_empty_state() {
        let no_noise = SdeProblem::new(
            array![1.0_f64],
            [0.0, 1.0],
            0,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert!(no_noise.validate().is_err());

        let empty_state = SdeProblem::new(
            Array1::<f64>::zeros(0),
            [0.0, 1.0],
            1,
            |_t, x| x.clone(),
            |_t, _x| Array2::eye(1),
        );
        assert!(empty_state.validate().is_err());
    }

    #[test]
    fn test_sde_solution_push_and_query() {
        let mut sol = SdeSolution::with_capacity(3);
        sol.push(0.0, array![1.0_f64]);
        sol.push(0.5, array![1.1_f64]);
        sol.push(1.0, array![1.2_f64]);
        assert_eq!(sol.len(), 3);
        assert!(!sol.is_empty());
        assert!((sol.t_final().expect("solution has time steps") - 1.0).abs() < 1e-12);
        assert!((sol.x_final().expect("solution has state")[0] - 1.2).abs() < 1e-12);
    }

    #[test]
    fn test_sde_solution_empty_queries() {
        let sol = SdeSolution::with_capacity(4);
        assert!(sol.is_empty());
        assert_eq!(sol.len(), 0);
        assert!(sol.t_final().is_none());
        assert!(sol.x_final().is_none());
    }

    #[test]
    fn test_ensemble_mean() {
        let sol1 = scalar_path(&[(0.0, 1.0), (1.0, 2.0)]);
        let sol2 = scalar_path(&[(0.0, 1.0), (1.0, 4.0)]);
        let mean =
            SdeSolution::ensemble_mean(&[sol1, sol2]).expect("ensemble_mean should succeed");
        assert_eq!(mean.len(), 2);
        assert!((mean.t[1] - 1.0).abs() < 1e-12);
        assert!((mean.x[0][0] - 1.0).abs() < 1e-12);
        assert!((mean.x[1][0] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_ensemble_mean_rejects_empty_and_ragged() {
        assert!(SdeSolution::ensemble_mean(&[]).is_err());
        let long = scalar_path(&[(0.0, 1.0), (1.0, 2.0)]);
        let short = scalar_path(&[(0.0, 1.0)]);
        assert!(SdeSolution::ensemble_mean(&[long.clone(), short.clone()]).is_err());
        assert!(SdeSolution::ensemble_mean(&[short, long]).is_err());
    }

    #[test]
    fn test_ensemble_variance() {
        let sol1 = scalar_path(&[(0.0, 1.0), (1.0, 2.0)]);
        let sol2 = scalar_path(&[(0.0, 1.0), (1.0, 4.0)]);
        let var = SdeSolution::ensemble_variance(&[sol1, sol2])
            .expect("ensemble_variance should succeed");
        assert_eq!(var.len(), 2);
        assert!(var.x[0][0].abs() < 1e-12);
        // Sample variance of {2, 4} with the N - 1 denominator.
        assert!((var.x[1][0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_ensemble_variance_requires_two_members() {
        let only = scalar_path(&[(0.0, 1.0)]);
        assert!(SdeSolution::ensemble_variance(&[only]).is_err());
        assert!(SdeSolution::ensemble_variance(&[]).is_err());
    }

    #[test]
    fn test_compute_n_steps() {
        let n = compute_n_steps(0.0, 1.0, 0.1, 1000).expect("compute_n_steps should succeed");
        assert_eq!(n, 10);
    }

    #[test]
    fn test_compute_n_steps_invalid_dt() {
        assert!(compute_n_steps(0.0, 1.0, -0.1, 1000).is_err());
        assert!(compute_n_steps(0.0, 1.0, 0.0, 1000).is_err());
    }

    #[test]
    fn test_compute_n_steps_limits() {
        let n = compute_n_steps(0.5, 0.5, 0.1, 1000).expect("zero-length span is allowed");
        assert_eq!(n, 1);
        assert!(compute_n_steps(0.0, 1.0, 0.001, 10).is_err());
    }
}