solve_ivp_bdf/
lib.rs

1// lib.rs
2use thiserror::Error;
3
4pub mod common;
5pub mod bdf;
6pub mod dense_output;
7
8
9use common::*;
10use bdf::*;
11use dense_output::*;
12
13#[derive(Debug, Error, Clone)]
14pub enum IntegratorError {
15    #[error("Step size too small: required step size is less than spacing between numbers")]
16    StepTooSmall,
17    
18    #[error("LU decomposition failed: matrix is singular")]
19    SingularMatrix,
20    
21    #[error("Newton iteration failed to converge")]
22    NewtonFailed,
23    
24    #[error("Invalid function evaluation: non-finite values detected")]
25    InvalidFunction,
26    
27    #[error("Integration bounds error: {message}")]
28    BoundsError { message: String },
29    
30    #[error("Tolerance validation error: {message}")]
31    ToleranceError { message: String },
32    
33    #[error("Initial step validation error: {message}")]
34    InitialStepError { message: String },
35    
36    #[error("General integration error: {message}")]
37    IntegrationError { message: String },
38}
39
40#[derive(Debug)]
41pub enum SolveResult {
42    Success {
43        t: Vec<f64>,
44        y: Vec<Vec<f64>>,
45        dense_output: Option<Vec<BdfDenseOutput>>,
46    },
47    Failed {
48        error: anyhow::Error,
49        t: Vec<f64>,
50        y: Vec<Vec<f64>>,
51    },
52    EventTerminated {
53        t: Vec<f64>,
54        y: Vec<Vec<f64>>,
55        event_index: usize,
56    },
57}
58
/// Boxed event predicate, called as `f(t, y)` on the solver state; `true` means "fired".
pub type EventFunction = Box<dyn Fn(f64, &[f64]) -> bool>;
60
61pub struct Event {
62    pub function: EventFunction,
63    pub terminal: bool,
64}
65
66pub struct SolveIvpOptions {
67    pub rtol: f64,
68    pub atol: f64,
69    pub max_step: f64,
70    pub first_step: Option<f64>,
71    pub events: Vec<Event>,
72    pub dense_output: bool,
73}
74
75impl Default for SolveIvpOptions {
76    fn default() -> Self {
77        Self {
78            rtol: 1e-3,
79            atol: 1e-6,
80            max_step: f64::INFINITY,
81            first_step: None,
82            events: Vec::new(),
83            dense_output: false,
84        }
85    }
86}
87
88pub fn solve_ivp_bdf(
89    fun: impl Fn(f64, &[f64]) -> Vec<f64>,
90    t0: f64,
91    t1: f64,
92    y0: Vec<f64>,
93    options: Option<SolveIvpOptions>,
94) -> SolveResult {
95    let opts = options.unwrap_or_default();
96    
97    let mut solver = BdfSolver::new(
98        fun,
99        t0,
100        y0,
101        t1,
102        opts.rtol,
103        opts.atol,
104        opts.max_step,
105        opts.first_step,
106    );
107    
108    let mut t_values = vec![t0];
109    let mut y_values = vec![solver.y.clone()];
110    let mut dense_outputs = if opts.dense_output {
111        Some(Vec::<BdfDenseOutput>::new())
112    } else { 
113        None 
114    };
115    
116    while solver.status == SolverStatus::Running {
117        match solver.step() {
118            Ok(()) => {
119                // Step was successful
120            }
121            Err(e) => {
122                return SolveResult::Failed {
123                    error: e,
124                    t: t_values,
125                    y: y_values,
126                };
127            }
128        }
129        
130        // Check events
131        for (i, event) in opts.events.iter().enumerate() {
132            let event_value = (event.function)(solver.t, &solver.y);
133            // Simple event detection (should be improved for production use)
134            if event_value {
135                if event.terminal {
136                    t_values.push(solver.t);
137                    y_values.push(solver.y.clone());
138                    return SolveResult::EventTerminated {
139                        t: t_values,
140                        y: y_values,
141                        event_index: i,
142                    };
143                }
144            }
145        }
146        
147        t_values.push(solver.t);
148        y_values.push(solver.y.clone());
149        
150        if let Some(ref mut dense_vec) = dense_outputs {
151            if solver.t_old.is_some() {
152                dense_vec.push(solver.dense_output());
153            }
154        }
155    }
156    
157    SolveResult::Success {
158        t: t_values,
159        y: y_values,
160        dense_output: dense_outputs,
161    }
162}
163
164
// Appended to the end of lib.rs: unit tests.
166
/// Integration tests for `solve_ivp_bdf`.
///
/// Events are detected only at step ends, with no root finding. The event tests therefore
/// cap `max_step`, so a one-step overshoot stays inside their tolerances.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_harmonic_oscillator_success() {
        // Simple harmonic oscillator: y'' + y = 0
        // Converted to system: y' = z, z' = -y
        // Analytical solution: y(t) = cos(t), z(t) = -sin(t)
        let fun = |_t: f64, y: &[f64]| -> Vec<f64> { vec![y[1], -y[0]] };

        let t0 = 0.0;
        let t1 = 2.0 * std::f64::consts::PI; // One full period
        let y0 = vec![1.0, 0.0]; // y(0) = 1, y'(0) = 0

        let options = SolveIvpOptions {
            rtol: 1e-8,
            atol: 1e-10,
            dense_output: true,
            ..Default::default()
        };

        let result = solve_ivp_bdf(fun, t0, t1, y0, Some(options));

        match result {
            SolveResult::Success { t, y, dense_output } => {
                // Check that we reached the final time
                assert!((t.last().unwrap() - t1).abs() < 1e-10);

                // Check that the final state is close to the initial state (periodic solution)
                let y_final = y.last().unwrap();
                assert!((y_final[0] - 1.0).abs() < 1e-5, "y position error: {}", (y_final[0] - 1.0).abs());
                assert!((y_final[1] - 0.0).abs() < 1e-5, "y velocity error: {}", (y_final[1] - 0.0).abs());

                // Check that dense output is available
                assert!(dense_output.is_some());

                println!("Harmonic oscillator test successful!");
                println!("Final error in position: {}", (y_final[0] - 1.0).abs());
                println!("Final error in velocity: {}", (y_final[1] - 0.0).abs());
            }
            SolveResult::Failed { error, .. } => {
                panic!("Integration failed: {}", error);
            }
            SolveResult::EventTerminated { .. } => {
                panic!("Unexpected event termination");
            }
        }
    }

    #[test]
    fn test_exponential_decay_success() {
        // Simple exponential decay: y' = -y
        // Analytical solution: y(t) = exp(-t)
        let fun = |_t: f64, y: &[f64]| -> Vec<f64> { vec![-y[0]] };

        let t0 = 0.0;
        let t1 = 5.0;
        let y0 = vec![1.0];

        let options = SolveIvpOptions {
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };

        let result = solve_ivp_bdf(fun, t0, t1, y0, Some(options));

        match result {
            SolveResult::Success { t, y, .. } => {
                // Check final time
                assert!((t.last().unwrap() - t1).abs() < 1e-10);

                // Check analytical solution: y(5) = exp(-5) ≈ 0.006737947
                let y_final = y.last().unwrap()[0];
                let y_analytical = (-t1).exp();
                let error = (y_final - y_analytical).abs();

                assert!(error < 1e-7, "Solution error too large: {}", error);

                println!("Exponential decay test successful!");
                println!("Analytical solution: {}", y_analytical);
                println!("Numerical solution: {}", y_final);
                println!("Absolute error: {}", error);
            }
            SolveResult::Failed { error, .. } => {
                panic!("Integration failed: {}", error);
            }
            SolveResult::EventTerminated { .. } => {
                panic!("Unexpected event termination");
            }
        }
    }

    #[test]
    fn test_event_termination() {
        // Simple linear system: y' = 1, starting from y(0) = 0
        // We'll set an event at y = 2.5
        let fun = |_t: f64, _y: &[f64]| -> Vec<f64> { vec![1.0] };

        // Event function: trigger when y reaches 2.5
        let event_fn = Box::new(|_t: f64, y: &[f64]| -> bool {
            y[0] - 2.5 > 0.0 // Trigger when y crosses 2.5 upward
        });

        let event = Event {
            function: event_fn,
            terminal: true,
        };

        let t0 = 0.0;
        let t1 = 10.0; // We should never reach this due to event
        let y0 = vec![0.0];

        let options = SolveIvpOptions {
            rtol: 1e-8,
            atol: 1e-10,
            // The event is only checked at step ends, so the reported point can overshoot the
            // crossing by up to one step. For y' = 1 the error estimate is ~0 and the step
            // size would otherwise grow without bound, possibly jumping straight to t1.
            max_step: 0.01,
            events: vec![event],
            ..Default::default()
        };

        let result = solve_ivp_bdf(fun, t0, t1, y0, Some(options));

        match result {
            SolveResult::Success { .. } => {
                panic!("Expected event termination, but got success");
            }
            SolveResult::Failed { error, .. } => {
                panic!("Integration failed: {}", error);
            }
            SolveResult::EventTerminated { t, y, event_index } => {
                assert_eq!(event_index, 0, "Wrong event index");

                // Check that the final y value is close to 2.5
                let y_final = y.last().unwrap()[0];
                let t_final = *t.last().unwrap();

                // Since y' = 1 and y(0) = 0, we have y(t) = t
                // So when y = 2.5, t should be 2.5
                assert!((y_final - 2.5).abs() < 0.1, "Event y value error: {}", (y_final - 2.5).abs());
                assert!((t_final - 2.5).abs() < 0.1, "Event time error: {}", (t_final - 2.5).abs());
                assert!(t_final < t1, "Event should occur before final time");

                println!("Event termination test successful!");
                println!("Event triggered at t = {}", t_final);
                println!("Event triggered at y = {}", y_final);
                println!("Expected values: t ≈ 2.5, y ≈ 2.5");
            }
        }
    }

    #[test]
    fn test_pendulum_with_stopping_event() {
        // Simple pendulum: θ'' + sin(θ) = 0
        // Convert to system: θ' = ω, ω' = -sin(θ)
        // We'll stop when the pendulum reaches θ = 0 after starting from θ = π/4
        let fun = |_t: f64, y: &[f64]| -> Vec<f64> {
            let theta = y[0];
            let omega = y[1];
            vec![omega, -theta.sin()]
        };

        // Event: stop when θ crosses zero going downward
        let event_fn = Box::new(|_t: f64, y: &[f64]| -> bool {
            y[0] < 0.0 // Trigger when θ crosses zero downward
        });

        let event = Event {
            function: event_fn,
            terminal: true,
        };

        let t0 = 0.0;
        let t1 = 10.0;
        let y0 = vec![std::f64::consts::PI / 4.0, 0.0]; // Start at 45 degrees, no initial velocity

        let options = SolveIvpOptions {
            rtol: 1e-8,
            atol: 1e-10,
            // Near θ = 0, |ω| ≈ sqrt(2 (1 - cos(π/4))) ≈ 0.77. Capping the step bounds the
            // step-end overshoot in θ to about 0.04, well inside the 0.1 tolerance below.
            max_step: 0.05,
            events: vec![event],
            dense_output: true,
            ..Default::default()
        };

        let result = solve_ivp_bdf(fun, t0, t1, y0, Some(options));

        match result {
            SolveResult::Success { .. } => {
                panic!("Expected event termination, but got success");
            }
            SolveResult::Failed { error, .. } => {
                panic!("Integration failed: {}", error);
            }
            SolveResult::EventTerminated { t, y, event_index } => {
                assert_eq!(event_index, 0, "Wrong event index");

                let t_final = *t.last().unwrap();
                let y_final = y.last().unwrap();
                let theta_final = y_final[0];
                let omega_final = y_final[1];

                // Check that θ is close to 0
                assert!(theta_final.abs() < 0.1, "Theta should be close to 0, got {}", theta_final);

                // Check that ω is negative (pendulum swinging downward)
                assert!(omega_final < 0.0, "Omega should be negative, got {}", omega_final);

                // Check that the event occurred at a reasonable time
                assert!(t_final > 0.5 && t_final < 5.0, "Event time seems unreasonable: {}", t_final);

                println!("Pendulum event test successful!");
                println!("Event triggered at t = {}", t_final);
                println!("Final theta = {}", theta_final);
                println!("Final omega = {}", omega_final);
            }
        }
    }
}