scirs2_integrate/pde/method_of_lines/
hyperbolic.rs

1//! Method of Lines for hyperbolic PDEs
2//!
3//! This module implements the Method of Lines (MOL) approach for solving
4//! hyperbolic PDEs, such as the wave equation.
5
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1};
7use std::sync::Arc;
8use std::time::Instant;
9
10use crate::ode::{solve_ivp, ODEOptions};
11use crate::pde::finite_difference::FiniteDifferenceScheme;
12use crate::pde::{
13    BoundaryCondition, BoundaryConditionType, BoundaryLocation, Domain, PDEError, PDEResult,
14    PDESolution, PDESolverInfo,
15};
16
/// Shared, thread-safe coefficient callback for 1D problems.
///
/// The solver invokes it as `f(x, t, u)` (spatial coordinate, time, local
/// solution value) and uses the returned `f64` as the coefficient at that node.
type CoeffFn1D = Arc<dyn Fn(f64, f64, f64) -> f64 + Send + Sync>;
19
20/// Result of hyperbolic PDE solution
21pub struct MOLHyperbolicResult {
22    /// Time points
23    pub t: Array1<f64>,
24
25    /// Solution values, indexed as [time, space]
26    pub u: Array2<f64>,
27
28    /// First-order time derivative values (∂u/∂t)
29    pub u_t: Array2<f64>,
30
31    /// ODE solver information
32    pub ode_info: Option<String>,
33
34    /// Computation time
35    pub computation_time: f64,
36}
37
/// Method of Lines solver for 1D Wave Equation
///
/// Solves the equation: ∂²u/∂t² = c² ∂²u/∂x² + f(x,t,u)
///
/// The spatial derivative is discretised on the grid supplied by `domain`, and
/// the resulting first-order system in (u, ∂u/∂t) is handed to the ODE solver.
#[derive(Clone)]
pub struct MOLWaveEquation1D {
    /// Spatial domain (validated to be 1-dimensional by `new`)
    domain: Domain,

    /// Time range [t_start, t_end]
    time_range: [f64; 2],

    /// Wave speed (squared) coefficient c²(x, t, u); note that the callback
    /// returns c², not c
    wave_speed_squared: CoeffFn1D,

    /// Optional source term function f(x, t, u); treated as zero when `None`
    source_term: Option<CoeffFn1D>,

    /// Initial condition function, evaluated at every grid node to build the
    /// initial displacement u
    initial_condition: Arc<dyn Fn(f64) -> f64 + Send + Sync>,

    /// Initial velocity function, evaluated at every grid node to build the
    /// initial ∂u/∂t
    initial_velocity: Arc<dyn Fn(f64) -> f64 + Send + Sync>,

    /// Boundary conditions (exactly one lower and one upper, enforced by `new`)
    boundary_conditions: Vec<BoundaryCondition<f64>>,

    /// Finite difference scheme for spatial discretization.
    /// NOTE(review): `solve` never reads this field and always uses the
    /// 3-point central stencil, so the setting currently has no effect.
    fd_scheme: FiniteDifferenceScheme,

    /// Solver options (ODE method, tolerances, step limit)
    options: super::MOLOptions,
}
70
71impl MOLWaveEquation1D {
72    /// Create a new Method of Lines solver for the 1D wave equation
73    pub fn new(
74        domain: Domain,
75        time_range: [f64; 2],
76        wave_speed_squared: impl Fn(f64, f64, f64) -> f64 + Send + Sync + 'static,
77        initial_condition: impl Fn(f64) -> f64 + Send + Sync + 'static,
78        initial_velocity: impl Fn(f64) -> f64 + Send + Sync + 'static,
79        boundary_conditions: Vec<BoundaryCondition<f64>>,
80        options: Option<super::MOLOptions>,
81    ) -> PDEResult<Self> {
82        // Validate the domain
83        if domain.dimensions() != 1 {
84            return Err(PDEError::DomainError(
85                "Domain must be 1-dimensional for 1D wave equation solver".to_string(),
86            ));
87        }
88
89        // Validate time _range
90        if time_range[0] >= time_range[1] {
91            return Err(PDEError::DomainError(
92                "Invalid time _range: start must be less than end".to_string(),
93            ));
94        }
95
96        // Validate boundary _conditions
97        if boundary_conditions.len() != 2 {
98            return Err(PDEError::BoundaryConditions(
99                "1D wave equation requires exactly 2 boundary _conditions".to_string(),
100            ));
101        }
102
103        // Ensure we have both lower and upper boundary _conditions
104        let has_lower = boundary_conditions
105            .iter()
106            .any(|bc| bc.location == BoundaryLocation::Lower);
107        let has_upper = boundary_conditions
108            .iter()
109            .any(|bc| bc.location == BoundaryLocation::Upper);
110
111        if !has_lower || !has_upper {
112            return Err(PDEError::BoundaryConditions(
113                "1D wave equation requires both lower and upper boundary _conditions".to_string(),
114            ));
115        }
116
117        Ok(MOLWaveEquation1D {
118            domain,
119            time_range,
120            wave_speed_squared: Arc::new(wave_speed_squared),
121            source_term: None,
122            initial_condition: Arc::new(initial_condition),
123            initial_velocity: Arc::new(initial_velocity),
124            boundary_conditions,
125            fd_scheme: FiniteDifferenceScheme::CentralDifference,
126            options: options.unwrap_or_default(),
127        })
128    }
129
130    /// Add a source term to the wave equation
131    pub fn with_source(
132        mut self,
133        source_term: impl Fn(f64, f64, f64) -> f64 + Send + Sync + 'static,
134    ) -> Self {
135        self.source_term = Some(Arc::new(source_term));
136        self
137    }
138
139    /// Set the finite difference scheme for spatial discretization
140    pub fn with_fd_scheme(mut self, scheme: FiniteDifferenceScheme) -> Self {
141        self.fd_scheme = scheme;
142        self
143    }
144
145    /// Solve the wave equation
146    pub fn solve(&self) -> PDEResult<MOLHyperbolicResult> {
147        let start_time = Instant::now();
148
149        // Generate spatial grid
150        let x_grid = self.domain.grid(0)?;
151        let nx = x_grid.len();
152        let dx = self.domain.grid_spacing(0)?;
153
154        // Create initial condition and velocity vectors
155        let mut u0 = Array1::zeros(nx);
156        let mut v0 = Array1::zeros(nx);
157
158        for (i, &x) in x_grid.iter().enumerate() {
159            u0[i] = (self.initial_condition)(x);
160            v0[i] = (self.initial_velocity)(x);
161        }
162
163        // The wave equation is a second-order in time PDE, so we convert it
164        // to a first-order system by introducing v = ∂u/∂t
165        // This gives us:
166        // ∂u/∂t = v
167        // ∂v/∂t = c² ∂²u/∂x² + f
168
169        // Combine u and v into a single state vector for the ODE solver
170        let mut y0 = Array1::zeros(2 * nx);
171        for i in 0..nx {
172            y0[i] = u0[i]; // First nx elements are u
173            y0[i + nx] = v0[i]; // Next nx elements are v = ∂u/∂t
174        }
175
176        // Extract data before moving self
177        let x_grid = x_grid.clone();
178        let time_range = self.time_range;
179        let boundary_conditions = self.boundary_conditions.clone();
180        let boundary_conditions_copy = boundary_conditions.clone();
181        let options = self.options.clone();
182
183        // Move self into closure
184        let solver = self;
185
186        // Construct the ODE function for the first-order system
187        let ode_func = move |t: f64, y: ArrayView1<f64>| -> Array1<f64> {
188            // Extract u and v from the combined state vector
189            let u = y.slice(s![0..nx]);
190            let v = y.slice(s![nx..2 * nx]);
191
192            let mut dydt = Array1::zeros(2 * nx);
193
194            // First part: ∂u/∂t = v
195            for i in 0..nx {
196                dydt[i] = v[i];
197            }
198
199            // Second part: ∂v/∂t = c² ∂²u/∂x² + f
200
201            // Apply finite difference approximations for interior points
202            for i in 1..nx - 1 {
203                let x = x_grid[i];
204                let u_i = u[i];
205
206                // Second derivative term
207                let d2u_dx2 = (u[i + 1] - 2.0 * u[i] + u[i - 1]) / (dx * dx);
208                let c_squared = (solver.wave_speed_squared)(x, t, u_i);
209                let wave_term = c_squared * d2u_dx2;
210
211                // Source term
212                let source_term = if let Some(source) = &solver.source_term {
213                    source(x, t, u_i)
214                } else {
215                    0.0
216                };
217
218                dydt[i + nx] = wave_term + source_term;
219            }
220
221            // Apply boundary conditions
222            for bc in &boundary_conditions_copy {
223                match bc.location {
224                    BoundaryLocation::Lower => {
225                        // Apply boundary condition at x[0]
226                        match bc.bc_type {
227                            BoundaryConditionType::Dirichlet => {
228                                // Fixed value: u(x_0, t) = bc.value
229                                // For Dirichlet, we set v[0] = 0 to maintain the fixed value
230                                // and to ensure u[0] doesn't change
231                                dydt[0] = 0.0; // ∂u/∂t = 0
232                                dydt[nx] = 0.0; // ∂v/∂t = 0
233                            }
234                            BoundaryConditionType::Neumann => {
235                                // Fixed gradient: ∂u/∂x|_{x_0} = bc.value
236
237                                // Calculate the ghost point value based on the Neumann condition
238                                let du_dx = bc.value;
239                                let u_ghost = u[0] - dx * du_dx; // Ghost point value
240
241                                // Use central difference for the second derivative
242                                let d2u_dx2 = (u[1] - 2.0 * u[0] + u_ghost) / (dx * dx);
243                                let c_squared = (solver.wave_speed_squared)(x_grid[0], t, u[0]);
244                                let wave_term = c_squared * d2u_dx2;
245
246                                // Source term
247                                let source_term = if let Some(source) = &solver.source_term {
248                                    source(x_grid[0], t, u[0])
249                                } else {
250                                    0.0
251                                };
252
253                                dydt[0] = v[0]; // ∂u/∂t = v
254                                dydt[nx] = wave_term + source_term; // ∂v/∂t
255                            }
256                            BoundaryConditionType::Robin => {
257                                // Robin boundary condition: a*u + b*du/dx = c
258                                if let Some([a, b, c]) = bc.coefficients {
259                                    // Solve for ghost point value using Robin condition
260                                    let du_dx = (c - a * u[0]) / b;
261                                    let u_ghost = u[0] - dx * du_dx;
262
263                                    // Use central difference for the second derivative
264                                    let d2u_dx2 = (u[1] - 2.0 * u[0] + u_ghost) / (dx * dx);
265                                    let c_squared = (solver.wave_speed_squared)(x_grid[0], t, u[0]);
266                                    let wave_term = c_squared * d2u_dx2;
267
268                                    // Source term
269                                    let source_term = if let Some(source) = &solver.source_term {
270                                        source(x_grid[0], t, u[0])
271                                    } else {
272                                        0.0
273                                    };
274
275                                    dydt[0] = v[0]; // ∂u/∂t = v
276                                    dydt[nx] = wave_term + source_term; // ∂v/∂t
277                                }
278                            }
279                            BoundaryConditionType::Periodic => {
280                                // Periodic boundary: u(x_0, t) = u(x_n, t)
281
282                                // Use values from the other end of the domain
283                                let d2u_dx2 = (u[1] - 2.0 * u[0] + u[nx - 1]) / (dx * dx);
284                                let c_squared = (solver.wave_speed_squared)(x_grid[0], t, u[0]);
285                                let wave_term = c_squared * d2u_dx2;
286
287                                // Source term
288                                let source_term = if let Some(source) = &solver.source_term {
289                                    source(x_grid[0], t, u[0])
290                                } else {
291                                    0.0
292                                };
293
294                                dydt[0] = v[0]; // ∂u/∂t = v
295                                dydt[nx] = wave_term + source_term; // ∂v/∂t
296                            }
297                        }
298                    }
299                    BoundaryLocation::Upper => {
300                        // Apply boundary condition at x[nx-1]
301                        match bc.bc_type {
302                            BoundaryConditionType::Dirichlet => {
303                                // Fixed value: u(x_n, t) = bc.value
304                                dydt[nx - 1] = 0.0; // ∂u/∂t = 0
305                                dydt[nx - 1 + nx] = 0.0; // ∂v/∂t = 0
306                            }
307                            BoundaryConditionType::Neumann => {
308                                // Fixed gradient: ∂u/∂x|_{x_n} = bc.value
309
310                                // Calculate the ghost point value based on the Neumann condition
311                                let du_dx = bc.value;
312                                let u_ghost = u[nx - 1] + dx * du_dx; // Ghost point value
313
314                                // Use central difference for the second derivative
315                                let d2u_dx2 = (u_ghost - 2.0 * u[nx - 1] + u[nx - 2]) / (dx * dx);
316                                let c_squared =
317                                    (solver.wave_speed_squared)(x_grid[nx - 1], t, u[nx - 1]);
318                                let wave_term = c_squared * d2u_dx2;
319
320                                // Source term
321                                let source_term = if let Some(source) = &solver.source_term {
322                                    source(x_grid[nx - 1], t, u[nx - 1])
323                                } else {
324                                    0.0
325                                };
326
327                                dydt[nx - 1] = v[nx - 1]; // ∂u/∂t = v
328                                dydt[nx - 1 + nx] = wave_term + source_term; // ∂v/∂t
329                            }
330                            BoundaryConditionType::Robin => {
331                                // Robin boundary condition: a*u + b*du/dx = c
332                                if let Some([a, b, c]) = bc.coefficients {
333                                    // Solve for ghost point value using Robin condition
334                                    let du_dx = (c - a * u[nx - 1]) / b;
335                                    let u_ghost = u[nx - 1] + dx * du_dx;
336
337                                    // Use central difference for the second derivative
338                                    let d2u_dx2 =
339                                        (u_ghost - 2.0 * u[nx - 1] + u[nx - 2]) / (dx * dx);
340                                    let c_squared =
341                                        (solver.wave_speed_squared)(x_grid[nx - 1], t, u[nx - 1]);
342                                    let wave_term = c_squared * d2u_dx2;
343
344                                    // Source term
345                                    let source_term = if let Some(source) = &solver.source_term {
346                                        source(x_grid[nx - 1], t, u[nx - 1])
347                                    } else {
348                                        0.0
349                                    };
350
351                                    dydt[nx - 1] = v[nx - 1]; // ∂u/∂t = v
352                                    dydt[nx - 1 + nx] = wave_term + source_term;
353                                    // ∂v/∂t
354                                }
355                            }
356                            BoundaryConditionType::Periodic => {
357                                // Periodic boundary: u(x_n, t) = u(x_0, t)
358
359                                // Use values from the other end of the domain
360                                let d2u_dx2 = (u[0] - 2.0 * u[nx - 1] + u[nx - 2]) / (dx * dx);
361                                let c_squared =
362                                    (solver.wave_speed_squared)(x_grid[nx - 1], t, u[nx - 1]);
363                                let wave_term = c_squared * d2u_dx2;
364
365                                // Source term
366                                let source_term = if let Some(source) = &solver.source_term {
367                                    source(x_grid[nx - 1], t, u[nx - 1])
368                                } else {
369                                    0.0
370                                };
371
372                                dydt[nx - 1] = v[nx - 1]; // ∂u/∂t = v
373                                dydt[nx - 1 + nx] = wave_term + source_term; // ∂v/∂t
374                            }
375                        }
376                    }
377                }
378            }
379
380            dydt
381        };
382
383        // Set up ODE solver options
384        let ode_options = ODEOptions {
385            method: options.ode_method,
386            rtol: options.rtol,
387            atol: options.atol,
388            h0: None,
389            max_steps: options.max_steps.unwrap_or(500),
390            max_step: None,
391            min_step: None,
392            dense_output: true,
393            max_order: None,
394            jac: None,
395            use_banded_jacobian: false,
396            ml: None,
397            mu: None,
398            mass_matrix: None,
399            jacobian_strategy: None,
400        };
401
402        // Apply Dirichlet boundary conditions to initial condition
403        for bc in &boundary_conditions {
404            if bc.bc_type == BoundaryConditionType::Dirichlet {
405                match bc.location {
406                    BoundaryLocation::Lower => {
407                        y0[0] = bc.value; // u(x_0, 0) = bc.value
408                        y0[nx] = 0.0; // v(x_0, 0) = 0
409                    }
410                    BoundaryLocation::Upper => {
411                        y0[nx - 1] = bc.value; // u(x_n, 0) = bc.value
412                        y0[nx - 1 + nx] = 0.0; // v(x_n, 0) = 0
413                    }
414                }
415            }
416        }
417
418        // Solve the ODE system
419        let ode_result = solve_ivp(ode_func, time_range, y0, Some(ode_options))?;
420
421        // Extract results
422        let computation_time = start_time.elapsed().as_secs_f64();
423
424        // Reshape the ODE result to separate u and v
425        let t = ode_result.t;
426        let nt = t.len();
427
428        let mut u = Array2::zeros((nt, nx));
429        let mut u_t = Array2::zeros((nt, nx));
430
431        for (i, y) in ode_result.y.iter().enumerate() {
432            // Split the state vector into u and v
433            for j in 0..nx {
434                u[[i, j]] = y[j]; // u values
435                u_t[[i, j]] = y[j + nx]; // v = ∂u/∂t values
436            }
437        }
438
439        let ode_info = Some(format!(
440            "ODE steps: {}, function evaluations: {}, successful steps: {}",
441            ode_result.n_steps, ode_result.n_eval, ode_result.n_accepted,
442        ));
443
444        Ok(MOLHyperbolicResult {
445            t: t.into(),
446            u,
447            u_t,
448            ode_info,
449            computation_time,
450        })
451    }
452}
453
454/// Convert a MOLHyperbolicResult to a PDESolution
455impl From<MOLHyperbolicResult> for PDESolution<f64> {
456    fn from(result: MOLHyperbolicResult) -> Self {
457        let mut grids = Vec::new();
458
459        // Add time grid
460        grids.push(result.t.clone());
461
462        // Extract spatial grid from solution
463        let nx = result.u.shape()[1];
464
465        // Note: For a proper implementation, the spatial grid should be provided
466        let spatial_grid = Array1::linspace(0.0, 1.0, nx);
467        grids.push(spatial_grid);
468
469        // Create solver info
470        let info = PDESolverInfo {
471            num_iterations: 0, // This information is not available directly
472            computation_time: result.computation_time,
473            residual_norm: None,
474            convergence_history: None,
475            method: "Method of Lines (Hyperbolic)".to_string(),
476        };
477
478        // For hyperbolic PDEs, we return both u and u_t as values
479        let values = vec![result.u, result.u_t];
480
481        PDESolution {
482            grids,
483            values,
484            error_estimate: None,
485            info,
486        }
487    }
488}