scirs2_integrate/pde/method_of_lines/hyperbolic.rs
1//! Method of Lines for hyperbolic PDEs
2//!
3//! This module implements the Method of Lines (MOL) approach for solving
4//! hyperbolic PDEs, such as the wave equation.
5
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1};
7use std::sync::Arc;
8use std::time::Instant;
9
10use crate::ode::{solve_ivp, ODEOptions};
11use crate::pde::finite_difference::FiniteDifferenceScheme;
12use crate::pde::{
13 BoundaryCondition, BoundaryConditionType, BoundaryLocation, Domain, PDEError, PDEResult,
14 PDESolution, PDESolverInfo,
15};
16
/// Type alias for a 1D coefficient function taking `(x, t, u)` and returning a value.
///
/// The function is stored behind an `Arc` and must be `Send + Sync`; it is used
/// in this module for both the squared wave speed and the optional source term.
type CoeffFn1D = Arc<dyn Fn(f64, f64, f64) -> f64 + Send + Sync>;
19
20/// Result of hyperbolic PDE solution
21pub struct MOLHyperbolicResult {
22 /// Time points
23 pub t: Array1<f64>,
24
25 /// Solution values, indexed as [time, space]
26 pub u: Array2<f64>,
27
28 /// First-order time derivative values (∂u/∂t)
29 pub u_t: Array2<f64>,
30
31 /// ODE solver information
32 pub ode_info: Option<String>,
33
34 /// Computation time
35 pub computation_time: f64,
36}
37
38/// Method of Lines solver for 1D Wave Equation
39///
40/// Solves the equation: ∂²u/∂t² = c² ∂²u/∂x² + f(x,t,u)
41#[derive(Clone)]
42pub struct MOLWaveEquation1D {
43 /// Spatial domain
44 domain: Domain,
45
46 /// Time range [t_start, t_end]
47 time_range: [f64; 2],
48
49 /// Wave speed (squared) coefficient c²(x, t, u)
50 wave_speed_squared: CoeffFn1D,
51
52 /// Source term function f(x, t, u)
53 source_term: Option<CoeffFn1D>,
54
55 /// Initial condition function u(x, 0)
56 initial_condition: Arc<dyn Fn(f64) -> f64 + Send + Sync>,
57
58 /// Initial velocity function ∂u/∂t(x, 0)
59 initial_velocity: Arc<dyn Fn(f64) -> f64 + Send + Sync>,
60
61 /// Boundary conditions
62 boundary_conditions: Vec<BoundaryCondition<f64>>,
63
64 /// Finite difference scheme for spatial discretization
65 fd_scheme: FiniteDifferenceScheme,
66
67 /// Solver options
68 options: super::MOLOptions,
69}
70
71impl MOLWaveEquation1D {
72 /// Create a new Method of Lines solver for the 1D wave equation
73 pub fn new(
74 domain: Domain,
75 time_range: [f64; 2],
76 wave_speed_squared: impl Fn(f64, f64, f64) -> f64 + Send + Sync + 'static,
77 initial_condition: impl Fn(f64) -> f64 + Send + Sync + 'static,
78 initial_velocity: impl Fn(f64) -> f64 + Send + Sync + 'static,
79 boundary_conditions: Vec<BoundaryCondition<f64>>,
80 options: Option<super::MOLOptions>,
81 ) -> PDEResult<Self> {
82 // Validate the domain
83 if domain.dimensions() != 1 {
84 return Err(PDEError::DomainError(
85 "Domain must be 1-dimensional for 1D wave equation solver".to_string(),
86 ));
87 }
88
89 // Validate time _range
90 if time_range[0] >= time_range[1] {
91 return Err(PDEError::DomainError(
92 "Invalid time _range: start must be less than end".to_string(),
93 ));
94 }
95
96 // Validate boundary _conditions
97 if boundary_conditions.len() != 2 {
98 return Err(PDEError::BoundaryConditions(
99 "1D wave equation requires exactly 2 boundary _conditions".to_string(),
100 ));
101 }
102
103 // Ensure we have both lower and upper boundary _conditions
104 let has_lower = boundary_conditions
105 .iter()
106 .any(|bc| bc.location == BoundaryLocation::Lower);
107 let has_upper = boundary_conditions
108 .iter()
109 .any(|bc| bc.location == BoundaryLocation::Upper);
110
111 if !has_lower || !has_upper {
112 return Err(PDEError::BoundaryConditions(
113 "1D wave equation requires both lower and upper boundary _conditions".to_string(),
114 ));
115 }
116
117 Ok(MOLWaveEquation1D {
118 domain,
119 time_range,
120 wave_speed_squared: Arc::new(wave_speed_squared),
121 source_term: None,
122 initial_condition: Arc::new(initial_condition),
123 initial_velocity: Arc::new(initial_velocity),
124 boundary_conditions,
125 fd_scheme: FiniteDifferenceScheme::CentralDifference,
126 options: options.unwrap_or_default(),
127 })
128 }
129
130 /// Add a source term to the wave equation
131 pub fn with_source(
132 mut self,
133 source_term: impl Fn(f64, f64, f64) -> f64 + Send + Sync + 'static,
134 ) -> Self {
135 self.source_term = Some(Arc::new(source_term));
136 self
137 }
138
139 /// Set the finite difference scheme for spatial discretization
140 pub fn with_fd_scheme(mut self, scheme: FiniteDifferenceScheme) -> Self {
141 self.fd_scheme = scheme;
142 self
143 }
144
145 /// Solve the wave equation
146 pub fn solve(&self) -> PDEResult<MOLHyperbolicResult> {
147 let start_time = Instant::now();
148
149 // Generate spatial grid
150 let x_grid = self.domain.grid(0)?;
151 let nx = x_grid.len();
152 let dx = self.domain.grid_spacing(0)?;
153
154 // Create initial condition and velocity vectors
155 let mut u0 = Array1::zeros(nx);
156 let mut v0 = Array1::zeros(nx);
157
158 for (i, &x) in x_grid.iter().enumerate() {
159 u0[i] = (self.initial_condition)(x);
160 v0[i] = (self.initial_velocity)(x);
161 }
162
163 // The wave equation is a second-order in time PDE, so we convert it
164 // to a first-order system by introducing v = ∂u/∂t
165 // This gives us:
166 // ∂u/∂t = v
167 // ∂v/∂t = c² ∂²u/∂x² + f
168
169 // Combine u and v into a single state vector for the ODE solver
170 let mut y0 = Array1::zeros(2 * nx);
171 for i in 0..nx {
172 y0[i] = u0[i]; // First nx elements are u
173 y0[i + nx] = v0[i]; // Next nx elements are v = ∂u/∂t
174 }
175
176 // Extract data before moving self
177 let x_grid = x_grid.clone();
178 let time_range = self.time_range;
179 let boundary_conditions = self.boundary_conditions.clone();
180 let boundary_conditions_copy = boundary_conditions.clone();
181 let options = self.options.clone();
182
183 // Move self into closure
184 let solver = self;
185
186 // Construct the ODE function for the first-order system
187 let ode_func = move |t: f64, y: ArrayView1<f64>| -> Array1<f64> {
188 // Extract u and v from the combined state vector
189 let u = y.slice(s![0..nx]);
190 let v = y.slice(s![nx..2 * nx]);
191
192 let mut dydt = Array1::zeros(2 * nx);
193
194 // First part: ∂u/∂t = v
195 for i in 0..nx {
196 dydt[i] = v[i];
197 }
198
199 // Second part: ∂v/∂t = c² ∂²u/∂x² + f
200
201 // Apply finite difference approximations for interior points
202 for i in 1..nx - 1 {
203 let x = x_grid[i];
204 let u_i = u[i];
205
206 // Second derivative term
207 let d2u_dx2 = (u[i + 1] - 2.0 * u[i] + u[i - 1]) / (dx * dx);
208 let c_squared = (solver.wave_speed_squared)(x, t, u_i);
209 let wave_term = c_squared * d2u_dx2;
210
211 // Source term
212 let source_term = if let Some(source) = &solver.source_term {
213 source(x, t, u_i)
214 } else {
215 0.0
216 };
217
218 dydt[i + nx] = wave_term + source_term;
219 }
220
221 // Apply boundary conditions
222 for bc in &boundary_conditions_copy {
223 match bc.location {
224 BoundaryLocation::Lower => {
225 // Apply boundary condition at x[0]
226 match bc.bc_type {
227 BoundaryConditionType::Dirichlet => {
228 // Fixed value: u(x_0, t) = bc.value
229 // For Dirichlet, we set v[0] = 0 to maintain the fixed value
230 // and to ensure u[0] doesn't change
231 dydt[0] = 0.0; // ∂u/∂t = 0
232 dydt[nx] = 0.0; // ∂v/∂t = 0
233 }
234 BoundaryConditionType::Neumann => {
235 // Fixed gradient: ∂u/∂x|_{x_0} = bc.value
236
237 // Calculate the ghost point value based on the Neumann condition
238 let du_dx = bc.value;
239 let u_ghost = u[0] - dx * du_dx; // Ghost point value
240
241 // Use central difference for the second derivative
242 let d2u_dx2 = (u[1] - 2.0 * u[0] + u_ghost) / (dx * dx);
243 let c_squared = (solver.wave_speed_squared)(x_grid[0], t, u[0]);
244 let wave_term = c_squared * d2u_dx2;
245
246 // Source term
247 let source_term = if let Some(source) = &solver.source_term {
248 source(x_grid[0], t, u[0])
249 } else {
250 0.0
251 };
252
253 dydt[0] = v[0]; // ∂u/∂t = v
254 dydt[nx] = wave_term + source_term; // ∂v/∂t
255 }
256 BoundaryConditionType::Robin => {
257 // Robin boundary condition: a*u + b*du/dx = c
258 if let Some([a, b, c]) = bc.coefficients {
259 // Solve for ghost point value using Robin condition
260 let du_dx = (c - a * u[0]) / b;
261 let u_ghost = u[0] - dx * du_dx;
262
263 // Use central difference for the second derivative
264 let d2u_dx2 = (u[1] - 2.0 * u[0] + u_ghost) / (dx * dx);
265 let c_squared = (solver.wave_speed_squared)(x_grid[0], t, u[0]);
266 let wave_term = c_squared * d2u_dx2;
267
268 // Source term
269 let source_term = if let Some(source) = &solver.source_term {
270 source(x_grid[0], t, u[0])
271 } else {
272 0.0
273 };
274
275 dydt[0] = v[0]; // ∂u/∂t = v
276 dydt[nx] = wave_term + source_term; // ∂v/∂t
277 }
278 }
279 BoundaryConditionType::Periodic => {
280 // Periodic boundary: u(x_0, t) = u(x_n, t)
281
282 // Use values from the other end of the domain
283 let d2u_dx2 = (u[1] - 2.0 * u[0] + u[nx - 1]) / (dx * dx);
284 let c_squared = (solver.wave_speed_squared)(x_grid[0], t, u[0]);
285 let wave_term = c_squared * d2u_dx2;
286
287 // Source term
288 let source_term = if let Some(source) = &solver.source_term {
289 source(x_grid[0], t, u[0])
290 } else {
291 0.0
292 };
293
294 dydt[0] = v[0]; // ∂u/∂t = v
295 dydt[nx] = wave_term + source_term; // ∂v/∂t
296 }
297 }
298 }
299 BoundaryLocation::Upper => {
300 // Apply boundary condition at x[nx-1]
301 match bc.bc_type {
302 BoundaryConditionType::Dirichlet => {
303 // Fixed value: u(x_n, t) = bc.value
304 dydt[nx - 1] = 0.0; // ∂u/∂t = 0
305 dydt[nx - 1 + nx] = 0.0; // ∂v/∂t = 0
306 }
307 BoundaryConditionType::Neumann => {
308 // Fixed gradient: ∂u/∂x|_{x_n} = bc.value
309
310 // Calculate the ghost point value based on the Neumann condition
311 let du_dx = bc.value;
312 let u_ghost = u[nx - 1] + dx * du_dx; // Ghost point value
313
314 // Use central difference for the second derivative
315 let d2u_dx2 = (u_ghost - 2.0 * u[nx - 1] + u[nx - 2]) / (dx * dx);
316 let c_squared =
317 (solver.wave_speed_squared)(x_grid[nx - 1], t, u[nx - 1]);
318 let wave_term = c_squared * d2u_dx2;
319
320 // Source term
321 let source_term = if let Some(source) = &solver.source_term {
322 source(x_grid[nx - 1], t, u[nx - 1])
323 } else {
324 0.0
325 };
326
327 dydt[nx - 1] = v[nx - 1]; // ∂u/∂t = v
328 dydt[nx - 1 + nx] = wave_term + source_term; // ∂v/∂t
329 }
330 BoundaryConditionType::Robin => {
331 // Robin boundary condition: a*u + b*du/dx = c
332 if let Some([a, b, c]) = bc.coefficients {
333 // Solve for ghost point value using Robin condition
334 let du_dx = (c - a * u[nx - 1]) / b;
335 let u_ghost = u[nx - 1] + dx * du_dx;
336
337 // Use central difference for the second derivative
338 let d2u_dx2 =
339 (u_ghost - 2.0 * u[nx - 1] + u[nx - 2]) / (dx * dx);
340 let c_squared =
341 (solver.wave_speed_squared)(x_grid[nx - 1], t, u[nx - 1]);
342 let wave_term = c_squared * d2u_dx2;
343
344 // Source term
345 let source_term = if let Some(source) = &solver.source_term {
346 source(x_grid[nx - 1], t, u[nx - 1])
347 } else {
348 0.0
349 };
350
351 dydt[nx - 1] = v[nx - 1]; // ∂u/∂t = v
352 dydt[nx - 1 + nx] = wave_term + source_term;
353 // ∂v/∂t
354 }
355 }
356 BoundaryConditionType::Periodic => {
357 // Periodic boundary: u(x_n, t) = u(x_0, t)
358
359 // Use values from the other end of the domain
360 let d2u_dx2 = (u[0] - 2.0 * u[nx - 1] + u[nx - 2]) / (dx * dx);
361 let c_squared =
362 (solver.wave_speed_squared)(x_grid[nx - 1], t, u[nx - 1]);
363 let wave_term = c_squared * d2u_dx2;
364
365 // Source term
366 let source_term = if let Some(source) = &solver.source_term {
367 source(x_grid[nx - 1], t, u[nx - 1])
368 } else {
369 0.0
370 };
371
372 dydt[nx - 1] = v[nx - 1]; // ∂u/∂t = v
373 dydt[nx - 1 + nx] = wave_term + source_term; // ∂v/∂t
374 }
375 }
376 }
377 }
378 }
379
380 dydt
381 };
382
383 // Set up ODE solver options
384 let ode_options = ODEOptions {
385 method: options.ode_method,
386 rtol: options.rtol,
387 atol: options.atol,
388 h0: None,
389 max_steps: options.max_steps.unwrap_or(500),
390 max_step: None,
391 min_step: None,
392 dense_output: true,
393 max_order: None,
394 jac: None,
395 use_banded_jacobian: false,
396 ml: None,
397 mu: None,
398 mass_matrix: None,
399 jacobian_strategy: None,
400 };
401
402 // Apply Dirichlet boundary conditions to initial condition
403 for bc in &boundary_conditions {
404 if bc.bc_type == BoundaryConditionType::Dirichlet {
405 match bc.location {
406 BoundaryLocation::Lower => {
407 y0[0] = bc.value; // u(x_0, 0) = bc.value
408 y0[nx] = 0.0; // v(x_0, 0) = 0
409 }
410 BoundaryLocation::Upper => {
411 y0[nx - 1] = bc.value; // u(x_n, 0) = bc.value
412 y0[nx - 1 + nx] = 0.0; // v(x_n, 0) = 0
413 }
414 }
415 }
416 }
417
418 // Solve the ODE system
419 let ode_result = solve_ivp(ode_func, time_range, y0, Some(ode_options))?;
420
421 // Extract results
422 let computation_time = start_time.elapsed().as_secs_f64();
423
424 // Reshape the ODE result to separate u and v
425 let t = ode_result.t;
426 let nt = t.len();
427
428 let mut u = Array2::zeros((nt, nx));
429 let mut u_t = Array2::zeros((nt, nx));
430
431 for (i, y) in ode_result.y.iter().enumerate() {
432 // Split the state vector into u and v
433 for j in 0..nx {
434 u[[i, j]] = y[j]; // u values
435 u_t[[i, j]] = y[j + nx]; // v = ∂u/∂t values
436 }
437 }
438
439 let ode_info = Some(format!(
440 "ODE steps: {}, function evaluations: {}, successful steps: {}",
441 ode_result.n_steps, ode_result.n_eval, ode_result.n_accepted,
442 ));
443
444 Ok(MOLHyperbolicResult {
445 t: t.into(),
446 u,
447 u_t,
448 ode_info,
449 computation_time,
450 })
451 }
452}
453
454/// Convert a MOLHyperbolicResult to a PDESolution
455impl From<MOLHyperbolicResult> for PDESolution<f64> {
456 fn from(result: MOLHyperbolicResult) -> Self {
457 let mut grids = Vec::new();
458
459 // Add time grid
460 grids.push(result.t.clone());
461
462 // Extract spatial grid from solution
463 let nx = result.u.shape()[1];
464
465 // Note: For a proper implementation, the spatial grid should be provided
466 let spatial_grid = Array1::linspace(0.0, 1.0, nx);
467 grids.push(spatial_grid);
468
469 // Create solver info
470 let info = PDESolverInfo {
471 num_iterations: 0, // This information is not available directly
472 computation_time: result.computation_time,
473 residual_norm: None,
474 convergence_history: None,
475 method: "Method of Lines (Hyperbolic)".to_string(),
476 };
477
478 // For hyperbolic PDEs, we return both u and u_t as values
479 let values = vec![result.u, result.u_t];
480
481 PDESolution {
482 grids,
483 values,
484 error_estimate: None,
485 info,
486 }
487 }
488}