1use thiserror::Error;
3
4pub mod common;
5pub mod bdf;
6pub mod dense_output;
7
8
9use common::*;
10use bdf::*;
11use dense_output::*;
12
13#[derive(Debug, Error, Clone)]
14pub enum IntegratorError {
15 #[error("Step size too small: required step size is less than spacing between numbers")]
16 StepTooSmall,
17
18 #[error("LU decomposition failed: matrix is singular")]
19 SingularMatrix,
20
21 #[error("Newton iteration failed to converge")]
22 NewtonFailed,
23
24 #[error("Invalid function evaluation: non-finite values detected")]
25 InvalidFunction,
26
27 #[error("Integration bounds error: {message}")]
28 BoundsError { message: String },
29
30 #[error("Tolerance validation error: {message}")]
31 ToleranceError { message: String },
32
33 #[error("Initial step validation error: {message}")]
34 InitialStepError { message: String },
35
36 #[error("General integration error: {message}")]
37 IntegrationError { message: String },
38}
39
40#[derive(Debug)]
41pub enum SolveResult {
42 Success {
43 t: Vec<f64>,
44 y: Vec<Vec<f64>>,
45 dense_output: Option<Vec<BdfDenseOutput>>,
46 },
47 Failed {
48 error: anyhow::Error,
49 t: Vec<f64>,
50 y: Vec<Vec<f64>>,
51 },
52 EventTerminated {
53 t: Vec<f64>,
54 y: Vec<Vec<f64>>,
55 event_index: usize,
56 },
57}
58
/// Boxed event predicate, called as `f(t, y)` with the current time and state.
/// It returns `true` while the event condition holds.
pub type EventFunction = Box<dyn Fn(f64, &[f64]) -> bool + 'static>;
60
61pub struct Event {
62 pub function: EventFunction,
63 pub terminal: bool,
64}
65
66pub struct SolveIvpOptions {
67 pub rtol: f64,
68 pub atol: f64,
69 pub max_step: f64,
70 pub first_step: Option<f64>,
71 pub events: Vec<Event>,
72 pub dense_output: bool,
73}
74
75impl Default for SolveIvpOptions {
76 fn default() -> Self {
77 Self {
78 rtol: 1e-3,
79 atol: 1e-6,
80 max_step: f64::INFINITY,
81 first_step: None,
82 events: Vec::new(),
83 dense_output: false,
84 }
85 }
86}
87
88pub fn solve_ivp_bdf(
89 fun: impl Fn(f64, &[f64]) -> Vec<f64>,
90 t0: f64,
91 t1: f64,
92 y0: Vec<f64>,
93 options: Option<SolveIvpOptions>,
94) -> SolveResult {
95 let opts = options.unwrap_or_default();
96
97 let mut solver = BdfSolver::new(
98 fun,
99 t0,
100 y0,
101 t1,
102 opts.rtol,
103 opts.atol,
104 opts.max_step,
105 opts.first_step,
106 );
107
108 let mut t_values = vec![t0];
109 let mut y_values = vec![solver.y.clone()];
110 let mut dense_outputs = if opts.dense_output {
111 Some(Vec::<BdfDenseOutput>::new())
112 } else {
113 None
114 };
115
116 while solver.status == SolverStatus::Running {
117 match solver.step() {
118 Ok(()) => {
119 }
121 Err(e) => {
122 return SolveResult::Failed {
123 error: e,
124 t: t_values,
125 y: y_values,
126 };
127 }
128 }
129
130 for (i, event) in opts.events.iter().enumerate() {
132 let event_value = (event.function)(solver.t, &solver.y);
133 if event_value {
135 if event.terminal {
136 t_values.push(solver.t);
137 y_values.push(solver.y.clone());
138 return SolveResult::EventTerminated {
139 t: t_values,
140 y: y_values,
141 event_index: i,
142 };
143 }
144 }
145 }
146
147 t_values.push(solver.t);
148 y_values.push(solver.y.clone());
149
150 if let Some(ref mut dense_vec) = dense_outputs {
151 if solver.t_old.is_some() {
152 dense_vec.push(solver.dense_output());
153 }
154 }
155 }
156
157 SolveResult::Success {
158 t: t_values,
159 y: y_values,
160 dense_output: dense_outputs,
161 }
162}
163
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps `SolveResult::Success` into `(t, y, has_dense_output)`; panics on any other
    /// outcome.
    fn expect_success(result: SolveResult) -> (Vec<f64>, Vec<Vec<f64>>, bool) {
        match result {
            SolveResult::Success { t, y, dense_output } => (t, y, dense_output.is_some()),
            SolveResult::Failed { error, .. } => panic!("Integration failed: {}", error),
            SolveResult::EventTerminated { .. } => panic!("Unexpected event termination"),
        }
    }

    /// Unwraps `SolveResult::EventTerminated` into `(t, y, event_index)`; panics on any other
    /// outcome.
    fn expect_event(result: SolveResult) -> (Vec<f64>, Vec<Vec<f64>>, usize) {
        match result {
            SolveResult::Success { .. } => panic!("Expected event termination, but got success"),
            SolveResult::Failed { error, .. } => panic!("Integration failed: {}", error),
            SolveResult::EventTerminated { t, y, event_index } => (t, y, event_index),
        }
    }

    #[test]
    fn test_simple_harmonic_oscillator_success() {
        // y'' = -y written as a first-order system; the exact solution has period 2π.
        let fun = |_t: f64, y: &[f64]| -> Vec<f64> { vec![y[1], -y[0]] };

        let t0 = 0.0;
        let t1 = 2.0 * std::f64::consts::PI;
        let y0 = vec![1.0, 0.0];
        let options = SolveIvpOptions {
            rtol: 1e-8,
            atol: 1e-10,
            dense_output: true,
            ..Default::default()
        };

        let (t, y, has_dense) = expect_success(solve_ivp_bdf(fun, t0, t1, y0, Some(options)));

        assert!((t.last().unwrap() - t1).abs() < 1e-10);

        // After one full period the state should be back at the initial condition.
        let y_final = y.last().unwrap();
        let position_error = (y_final[0] - 1.0).abs();
        let velocity_error = y_final[1].abs();
        assert!(position_error < 1e-5, "y position error: {}", position_error);
        assert!(velocity_error < 1e-5, "y velocity error: {}", velocity_error);

        assert!(has_dense);
    }

    #[test]
    fn test_exponential_decay_success() {
        // y' = -y with y(0) = 1, so y(t) = exp(-t).
        let fun = |_t: f64, y: &[f64]| -> Vec<f64> { vec![-y[0]] };

        let t0 = 0.0;
        let t1 = 5.0;
        let y0 = vec![1.0];
        let options = SolveIvpOptions {
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };

        let (t, y, _) = expect_success(solve_ivp_bdf(fun, t0, t1, y0, Some(options)));

        assert!((t.last().unwrap() - t1).abs() < 1e-10);

        let y_final = y.last().unwrap()[0];
        let y_analytical = (-t1).exp();
        let error = (y_final - y_analytical).abs();
        assert!(error < 1e-7, "Solution error too large: {}", error);
    }

    #[test]
    fn test_event_termination() {
        // y' = 1 with y(0) = 0, so y(t) = t and the event y > 2.5 first holds just after t = 2.5.
        let fun = |_t: f64, _y: &[f64]| -> Vec<f64> { vec![1.0] };

        let event = Event {
            function: Box::new(|_t: f64, y: &[f64]| -> bool { y[0] - 2.5 > 0.0 }),
            terminal: true,
        };

        let t0 = 0.0;
        let t1 = 10.0;
        let y0 = vec![0.0];

        // Events are only checked at accepted step endpoints (no root-finding). The detected
        // event can therefore overshoot the true crossing by up to one step. Cap the step size
        // so the overshoot stays well inside the tolerance asserted below.
        let options = SolveIvpOptions {
            rtol: 1e-8,
            atol: 1e-10,
            max_step: 0.05,
            events: vec![event],
            ..Default::default()
        };

        let (t, y, event_index) = expect_event(solve_ivp_bdf(fun, t0, t1, y0, Some(options)));
        assert_eq!(event_index, 0, "Wrong event index");

        let y_final = y.last().unwrap()[0];
        let t_final = *t.last().unwrap();

        assert!((y_final - 2.5).abs() < 0.1, "Event y value error: {}", (y_final - 2.5).abs());
        assert!((t_final - 2.5).abs() < 0.1, "Event time error: {}", (t_final - 2.5).abs());
        assert!(t_final < t1, "Event should occur before final time");
    }

    #[test]
    fn test_pendulum_with_stopping_event() {
        // Undamped pendulum: theta' = omega, omega' = -sin(theta).
        let fun = |_t: f64, y: &[f64]| -> Vec<f64> {
            let (theta, omega) = (y[0], y[1]);
            vec![omega, -theta.sin()]
        };

        // Released from rest at +π/4, the pendulum first crosses theta = 0 moving with omega < 0.
        let event = Event {
            function: Box::new(|_t: f64, y: &[f64]| -> bool { y[0] < 0.0 }),
            terminal: true,
        };

        let t0 = 0.0;
        let t1 = 10.0;
        let y0 = vec![std::f64::consts::PI / 4.0, 0.0];

        // As above: bound the step size because the event is only detected at step endpoints.
        let options = SolveIvpOptions {
            rtol: 1e-8,
            atol: 1e-10,
            max_step: 0.05,
            events: vec![event],
            dense_output: true,
            ..Default::default()
        };

        let (t, y, event_index) = expect_event(solve_ivp_bdf(fun, t0, t1, y0, Some(options)));
        assert_eq!(event_index, 0, "Wrong event index");

        let t_final = *t.last().unwrap();
        let y_final = y.last().unwrap();
        let (theta_final, omega_final) = (y_final[0], y_final[1]);

        assert!(theta_final.abs() < 0.1, "Theta should be close to 0, got {}", theta_final);
        assert!(omega_final < 0.0, "Omega should be negative, got {}", omega_final);
        assert!(t_final > 0.5 && t_final < 5.0, "Event time seems unreasonable: {}", t_final);
    }
}