scivex_optim/pde/finite_diff.rs
1//! Finite difference PDE solvers for the heat equation, wave equation, and
2//! Laplace equation.
3
4use scivex_core::Float;
5
6use crate::error::{OptimError, Result};
7
8// ---------------------------------------------------------------------------
9// Types
10// ---------------------------------------------------------------------------
11
12/// Boundary condition specification.
13///
14/// # Examples
15///
16/// ```
17/// # use scivex_optim::pde::BoundaryCondition;
18/// let bc = BoundaryCondition::Dirichlet(0.0_f64);
19/// assert_eq!(bc, BoundaryCondition::Dirichlet(0.0));
20/// ```
21#[derive(Debug, Clone, Copy, PartialEq)]
22pub enum BoundaryCondition<T: Float> {
23 /// Fixed value at boundary (Dirichlet).
24 Dirichlet(T),
25 /// Fixed derivative at boundary (Neumann).
26 Neumann(T),
27}
28
29/// Result of a PDE solve.
30///
31/// # Examples
32///
33/// ```
34/// # use scivex_optim::pde::{heat_equation_1d, BoundaryCondition};
35/// let result = heat_equation_1d(
36/// (0.0_f64, 1.0), 50, 0.1, 100, 0.01,
37/// &|x| (std::f64::consts::PI * x).sin(),
38/// BoundaryCondition::Dirichlet(0.0),
39/// BoundaryCondition::Dirichlet(0.0),
40/// ).unwrap();
41/// assert!(!result.u.is_empty());
42/// ```
43#[derive(Debug, Clone)]
44pub struct PdeResult<T: Float> {
45 /// Solution values: for 1-D time-dependent problems the shape is
46 /// `[n_time][n_space]`. For 2-D steady-state problems the shape is
47 /// `[ny][nx]`.
48 pub u: Vec<Vec<T>>,
49 /// Spatial grid points (x-axis).
50 pub x: Vec<T>,
51 /// Time points (for time-dependent problems) or y-axis grid (for 2-D
52 /// steady-state).
53 pub t_or_y: Vec<T>,
54 /// Number of time / iteration steps taken.
55 pub steps: usize,
56 /// Whether the solution converged (meaningful for iterative methods such
57 /// as Gauss-Seidel).
58 pub converged: bool,
59}
60
61// ---------------------------------------------------------------------------
62// Helpers
63// ---------------------------------------------------------------------------
64
65/// Build a uniform grid of `n` points spanning `[a, b]`.
66fn linspace<T: Float>(a: T, b: T, n: usize) -> Vec<T> {
67 if n < 2 {
68 return vec![a];
69 }
70 let n_intervals = T::from_usize(n - 1);
71 let dx = (b - a) / n_intervals;
72 (0..n).map(|i| a + T::from_usize(i) * dx).collect()
73}
74
75/// Apply a boundary condition at one end of a 1-D solution row.
76///
77/// * `is_left` — `true` for the left boundary (index 0), `false` for the
78/// right boundary (last index).
79/// * `row` — mutable slice of the current solution row.
80/// * `dx` — spatial step size.
81fn apply_bc_1d<T: Float>(bc: &BoundaryCondition<T>, is_left: bool, row: &mut [T], dx: T) {
82 let n = row.len();
83 match *bc {
84 BoundaryCondition::Dirichlet(val) => {
85 if is_left {
86 row[0] = val;
87 } else {
88 row[n - 1] = val;
89 }
90 }
91 BoundaryCondition::Neumann(deriv) => {
92 // Ghost-node approach: u[-1] = u[1] - 2*dx*deriv (left)
93 // u[n] = u[n-2] + 2*dx*deriv (right)
94 if is_left {
95 row[0] = row[1] - dx * deriv;
96 } else {
97 row[n - 1] = row[n - 2] + dx * deriv;
98 }
99 }
100 }
101}
102
103// ---------------------------------------------------------------------------
104// 1-D Heat equation (FTCS explicit scheme)
105// ---------------------------------------------------------------------------
106
107/// Solve the 1-D heat equation
108///
109/// ```text
110/// ∂u/∂t = α ∂²u/∂x²
111/// ```
112///
113/// using the explicit forward-time, centred-space (FTCS) scheme.
114///
115/// # Parameters
116///
117/// * `x_range` — spatial domain `[x0, x1]`.
118/// * `n_x` — number of spatial grid points (must be >= 3).
119/// * `t_final` — simulate until this time (must be > 0).
120/// * `n_t` — number of time steps (must be >= 1).
121/// * `alpha` — thermal diffusivity (must be > 0).
122/// * `initial` — initial condition `u(x, 0)`.
123/// * `left_bc` — boundary condition at `x = x0`.
124/// * `right_bc` — boundary condition at `x = x1`.
125///
126/// # Errors
127///
128/// Returns [`OptimError::InvalidParameter`] when the grid is too small,
129/// parameters are non-positive, or the CFL stability condition
130/// `r = α dt / dx² <= 0.5` is violated.
131///
132/// # Examples
133///
134/// ```
135/// # use scivex_optim::pde::{heat_equation_1d, BoundaryCondition};
136/// let result = heat_equation_1d(
137/// (0.0_f64, 1.0), 50, 0.01, 500, 1.0,
138/// &|x| (std::f64::consts::PI * x).sin(),
139/// BoundaryCondition::Dirichlet(0.0),
140/// BoundaryCondition::Dirichlet(0.0),
141/// ).unwrap();
142/// assert!(result.converged);
143/// ```
144#[allow(clippy::too_many_arguments)]
145pub fn heat_equation_1d<T: Float>(
146 x_range: (T, T),
147 n_x: usize,
148 t_final: T,
149 n_t: usize,
150 alpha: T,
151 initial: &dyn Fn(T) -> T,
152 left_bc: BoundaryCondition<T>,
153 right_bc: BoundaryCondition<T>,
154) -> Result<PdeResult<T>> {
155 // --- Validate inputs ---------------------------------------------------
156 if n_x < 3 {
157 return Err(OptimError::InvalidParameter {
158 name: "n_x",
159 reason: "need at least 3 spatial points",
160 });
161 }
162 if n_t < 1 {
163 return Err(OptimError::InvalidParameter {
164 name: "n_t",
165 reason: "need at least 1 time step",
166 });
167 }
168 let zero = T::zero();
169 if t_final <= zero {
170 return Err(OptimError::InvalidParameter {
171 name: "t_final",
172 reason: "must be positive",
173 });
174 }
175 if alpha <= zero {
176 return Err(OptimError::InvalidParameter {
177 name: "alpha",
178 reason: "must be positive",
179 });
180 }
181
182 let x = linspace(x_range.0, x_range.1, n_x);
183 let dx = x[1] - x[0];
184 let dt = t_final / T::from_usize(n_t);
185 let r = alpha * dt / (dx * dx);
186
187 let half = T::from_f64(0.5);
188 if r > half {
189 return Err(OptimError::InvalidParameter {
190 name: "n_t",
191 reason: "stability condition violated: r = alpha*dt/dx^2 must be <= 0.5",
192 });
193 }
194
195 // --- Initial condition --------------------------------------------------
196 let mut u_prev: Vec<T> = x.iter().map(|&xi| initial(xi)).collect();
197 apply_bc_1d(&left_bc, true, &mut u_prev, dx);
198 apply_bc_1d(&right_bc, false, &mut u_prev, dx);
199
200 let mut all_u: Vec<Vec<T>> = Vec::with_capacity(n_t + 1);
201 all_u.push(u_prev.clone());
202
203 let mut t_vals: Vec<T> = Vec::with_capacity(n_t + 1);
204 t_vals.push(zero);
205
206 // --- Time stepping (FTCS) ----------------------------------------------
207 let two = T::from_f64(2.0);
208 for step in 0..n_t {
209 let mut u_next = u_prev.clone();
210 for i in 1..(n_x - 1) {
211 u_next[i] = u_prev[i] + r * (u_prev[i + 1] - two * u_prev[i] + u_prev[i - 1]);
212 }
213 apply_bc_1d(&left_bc, true, &mut u_next, dx);
214 apply_bc_1d(&right_bc, false, &mut u_next, dx);
215
216 t_vals.push(T::from_usize(step + 1) * dt);
217 all_u.push(u_next.clone());
218 u_prev = u_next;
219 }
220
221 Ok(PdeResult {
222 u: all_u,
223 x,
224 t_or_y: t_vals,
225 steps: n_t,
226 converged: true,
227 })
228}
229
230// ---------------------------------------------------------------------------
231// 1-D Wave equation (explicit three-level scheme)
232// ---------------------------------------------------------------------------
233
234/// Solve the 1-D wave equation
235///
236/// ```text
237/// ∂²u/∂t² = c² ∂²u/∂x²
238/// ```
239///
240/// using the explicit three-level centred-difference scheme.
241///
242/// # Parameters
243///
244/// * `x_range` — spatial domain `[x0, x1]`.
245/// * `n_x` — number of spatial grid points (must be >= 3).
246/// * `t_final` — simulate until this time (must be > 0).
247/// * `n_t` — number of time steps (must be >= 1).
248/// * `c` — wave speed (must be > 0).
249/// * `initial_u` — initial displacement `u(x, 0)`.
250/// * `initial_ut` — initial velocity `∂u/∂t(x, 0)`.
251/// * `left_bc` — boundary condition at `x = x0`.
252/// * `right_bc` — boundary condition at `x = x1`.
253///
254/// # Errors
255///
256/// Returns [`OptimError::InvalidParameter`] when the grid is too small,
257/// parameters are non-positive, or the CFL condition `c dt / dx <= 1` is
258/// violated.
259#[allow(clippy::too_many_arguments)]
260pub fn wave_equation_1d<T: Float>(
261 x_range: (T, T),
262 n_x: usize,
263 t_final: T,
264 n_t: usize,
265 c: T,
266 initial_u: &dyn Fn(T) -> T,
267 initial_ut: &dyn Fn(T) -> T,
268 left_bc: BoundaryCondition<T>,
269 right_bc: BoundaryCondition<T>,
270) -> Result<PdeResult<T>> {
271 // --- Validate inputs ---------------------------------------------------
272 if n_x < 3 {
273 return Err(OptimError::InvalidParameter {
274 name: "n_x",
275 reason: "need at least 3 spatial points",
276 });
277 }
278 if n_t < 1 {
279 return Err(OptimError::InvalidParameter {
280 name: "n_t",
281 reason: "need at least 1 time step",
282 });
283 }
284 let zero = T::zero();
285 if t_final <= zero {
286 return Err(OptimError::InvalidParameter {
287 name: "t_final",
288 reason: "must be positive",
289 });
290 }
291 if c <= zero {
292 return Err(OptimError::InvalidParameter {
293 name: "c",
294 reason: "must be positive",
295 });
296 }
297
298 let x = linspace(x_range.0, x_range.1, n_x);
299 let dx = x[1] - x[0];
300 let dt = t_final / T::from_usize(n_t);
301 let r = c * dt / dx; // Courant number
302
303 if r > T::one() {
304 return Err(OptimError::InvalidParameter {
305 name: "n_t",
306 reason: "CFL condition violated: c*dt/dx must be <= 1",
307 });
308 }
309
310 let r2 = r * r;
311 let two = T::from_f64(2.0);
312
313 // --- Level 0: u(x, 0) --------------------------------------------------
314 let mut u_prev: Vec<T> = x.iter().map(|&xi| initial_u(xi)).collect();
315 apply_bc_1d(&left_bc, true, &mut u_prev, dx);
316 apply_bc_1d(&right_bc, false, &mut u_prev, dx);
317
318 let mut all_u: Vec<Vec<T>> = Vec::with_capacity(n_t + 1);
319 all_u.push(u_prev.clone());
320
321 let mut t_vals: Vec<T> = Vec::with_capacity(n_t + 1);
322 t_vals.push(zero);
323
324 // --- Level 1: special first step using initial velocity -----------------
325 // u^1_i = u^0_i + dt * ut(x_i) + 0.5*r²*(u^0_{i+1} - 2 u^0_i + u^0_{i-1})
326 let half = T::from_f64(0.5);
327 let mut u_curr: Vec<T> = vec![zero; n_x];
328 for i in 1..(n_x - 1) {
329 let laplacian = u_prev[i + 1] - two * u_prev[i] + u_prev[i - 1];
330 u_curr[i] = u_prev[i] + dt * initial_ut(x[i]) + half * r2 * laplacian;
331 }
332 apply_bc_1d(&left_bc, true, &mut u_curr, dx);
333 apply_bc_1d(&right_bc, false, &mut u_curr, dx);
334
335 t_vals.push(dt);
336 all_u.push(u_curr.clone());
337
338 // --- Remaining steps (three-level scheme) -------------------------------
339 for step in 1..n_t {
340 let mut u_next = vec![zero; n_x];
341 for i in 1..(n_x - 1) {
342 let laplacian = u_curr[i + 1] - two * u_curr[i] + u_curr[i - 1];
343 u_next[i] = two * u_curr[i] - u_prev[i] + r2 * laplacian;
344 }
345 apply_bc_1d(&left_bc, true, &mut u_next, dx);
346 apply_bc_1d(&right_bc, false, &mut u_next, dx);
347
348 t_vals.push(T::from_usize(step + 1) * dt);
349 all_u.push(u_next.clone());
350 u_prev = u_curr;
351 u_curr = u_next;
352 }
353
354 Ok(PdeResult {
355 u: all_u,
356 x,
357 t_or_y: t_vals,
358 steps: n_t,
359 converged: true,
360 })
361}
362
363// ---------------------------------------------------------------------------
364// 2-D Laplace equation (Gauss-Seidel iteration)
365// ---------------------------------------------------------------------------
366
367/// Solve the 2-D Laplace equation
368///
369/// ```text
370/// ∂²u/∂x² + ∂²u/∂y² = 0
371/// ```
372///
373/// on a rectangular domain using Gauss-Seidel relaxation.
374///
375/// The `boundary` closure receives `(x, y)` and must return `Some(value)` for
376/// every point on the boundary of the domain (i.e.\ the first/last row/column
377/// of the grid). Interior points should return `None`.
378///
379/// # Parameters
380///
381/// * `x_range` — `[x0, x1]`.
382/// * `y_range` — `[y0, y1]`.
383/// * `n_x` — number of grid points in x (must be >= 3).
384/// * `n_y` — number of grid points in y (must be >= 3).
385/// * `boundary` — closure returning `Some(T)` for boundary points.
386/// * `max_iter` — maximum number of Gauss-Seidel sweeps.
387/// * `tol` — convergence tolerance on the max absolute update.
388///
389/// # Errors
390///
391/// Returns [`OptimError::InvalidParameter`] for grids that are too small or
392/// non-positive `max_iter` / `tol`.
393pub fn laplace_2d<T: Float>(
394 x_range: (T, T),
395 y_range: (T, T),
396 n_x: usize,
397 n_y: usize,
398 boundary: &dyn Fn(T, T) -> Option<T>,
399 max_iter: usize,
400 tol: T,
401) -> Result<PdeResult<T>> {
402 // --- Validate inputs ---------------------------------------------------
403 if n_x < 3 {
404 return Err(OptimError::InvalidParameter {
405 name: "n_x",
406 reason: "need at least 3 grid points in x",
407 });
408 }
409 if n_y < 3 {
410 return Err(OptimError::InvalidParameter {
411 name: "n_y",
412 reason: "need at least 3 grid points in y",
413 });
414 }
415 if max_iter == 0 {
416 return Err(OptimError::InvalidParameter {
417 name: "max_iter",
418 reason: "must be at least 1",
419 });
420 }
421 if tol <= T::zero() {
422 return Err(OptimError::InvalidParameter {
423 name: "tol",
424 reason: "must be positive",
425 });
426 }
427
428 let x = linspace(x_range.0, x_range.1, n_x);
429 let y = linspace(y_range.0, y_range.1, n_y);
430
431 // --- Initialise grid with boundary values; interior = 0 ----------------
432 let mut u: Vec<Vec<T>> = Vec::with_capacity(n_y);
433 let mut is_boundary: Vec<Vec<bool>> = Vec::with_capacity(n_y);
434
435 for yj in &y {
436 let mut row = vec![T::zero(); n_x];
437 let mut brow = vec![false; n_x];
438 for (i, xi) in x.iter().enumerate() {
439 if let Some(val) = boundary(*xi, *yj) {
440 row[i] = val;
441 brow[i] = true;
442 }
443 }
444 u.push(row);
445 is_boundary.push(brow);
446 }
447
448 // --- Gauss-Seidel iteration --------------------------------------------
449 let quarter = T::from_f64(0.25);
450 let mut converged = false;
451 let mut steps: usize = 0;
452
453 for _iter in 0..max_iter {
454 let mut max_diff = T::zero();
455 for j in 1..(n_y - 1) {
456 for i in 1..(n_x - 1) {
457 if is_boundary[j][i] {
458 continue;
459 }
460 let new_val = quarter * (u[j][i + 1] + u[j][i - 1] + u[j + 1][i] + u[j - 1][i]);
461 let diff = (new_val - u[j][i]).abs();
462 if diff > max_diff {
463 max_diff = diff;
464 }
465 u[j][i] = new_val;
466 }
467 }
468 steps += 1;
469 if max_diff < tol {
470 converged = true;
471 break;
472 }
473 }
474
475 Ok(PdeResult {
476 u,
477 x,
478 t_or_y: y,
479 steps,
480 converged,
481 })
482}
483
484// ===========================================================================
485// Tests
486// ===========================================================================
487
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `(x, y)` lies on the edge of the unit square `[0, 1]²`.
    fn on_unit_square_edge(x: f64, y: f64) -> bool {
        const EPS: f64 = 1e-12;
        x < EPS || (x - 1.0).abs() < EPS || y < EPS || (y - 1.0).abs() < EPS
    }

    /// Largest value in `row`, or `-inf` when `row` is empty.
    fn row_max(row: &[f64]) -> f64 {
        row.iter().copied().fold(f64::NEG_INFINITY, f64::max)
    }

    /// Holding `u = 0` on the left and `u = 1` on the right, the heat equation
    /// must relax to the linear steady state `u(x) = x` on `[0, 1]`.
    #[test]
    fn test_heat_steady_state() {
        // t_final = 50 spans five diffusion times (L² / α = 1 / 0.1 = 10).
        let result = heat_equation_1d(
            (0.0, 1.0),
            21,
            50.0,
            50_000,
            0.1,
            &|_x| 0.0,
            BoundaryCondition::Dirichlet(0.0),
            BoundaryCondition::Dirichlet(1.0),
        )
        .unwrap();

        let final_row = result.u.last().unwrap();
        for (&xi, &ui) in result.x.iter().zip(final_row.iter()) {
            let err = (ui - xi).abs();
            assert!(
                err < 0.05,
                "steady-state error too large at x={xi}: u={}, expected={xi}, err={err}",
                ui,
            );
        }
    }

    /// Diffusion must flatten a Gaussian bump: its peak can only go down.
    #[test]
    fn test_heat_gaussian_diffusion() {
        let result = heat_equation_1d(
            (0.0, 1.0),
            101,
            0.05,
            5000,
            0.01,
            &|x: f64| (-(x - 0.5).powi(2) / 0.01).exp(),
            BoundaryCondition::Dirichlet(0.0),
            BoundaryCondition::Dirichlet(0.0),
        )
        .unwrap();

        let initial_max = row_max(&result.u[0]);
        let final_max = row_max(result.u.last().unwrap());

        assert!(
            final_max < initial_max,
            "Gaussian peak should decrease: initial={initial_max}, final={final_max}",
        );
    }

    /// Starting from `u(x, 0) = sin(pi x)` at rest, the string should be
    /// inverted (`≈ -sin(pi x)`) after half a period.
    #[test]
    fn test_wave_standing_wave() {
        // With L = 1 and c = 1 the period is 2L / c = 2, so t = 1 is half of it.
        let n_x = 101;
        let c = 1.0_f64;
        let result = wave_equation_1d(
            (0.0, 1.0),
            n_x,
            1.0,
            200,
            c,
            &|x: f64| (std::f64::consts::PI * x).sin(),
            &|_x: f64| 0.0,
            BoundaryCondition::Dirichlet(0.0),
            BoundaryCondition::Dirichlet(0.0),
        )
        .unwrap();

        // Node n_x / 2 sits at x = 0.5, where sin(pi x) = 1, so expect ≈ -1.
        let midpoint = result.u.last().unwrap()[n_x / 2];
        assert!(
            midpoint < -0.8,
            "standing wave mid-point should be near -1 at half-period, got {}",
            midpoint,
        );
    }

    /// `u = x` is harmonic, so prescribing it on the edges must reproduce it
    /// exactly in the interior.
    #[test]
    fn test_laplace_linear_boundary() {
        let (n_x, n_y) = (21, 21);
        let result = laplace_2d(
            (0.0, 1.0),
            (0.0, 1.0),
            n_x,
            n_y,
            &|x: f64, y: f64| if on_unit_square_edge(x, y) { Some(x) } else { None },
            10_000,
            1e-10,
        )
        .unwrap();

        assert!(result.converged, "Laplace solver should converge");

        for (j, row) in result.u.iter().enumerate().take(n_y - 1).skip(1) {
            for (i, &value) in row.iter().enumerate().take(n_x - 1).skip(1) {
                let expected = result.x[i];
                let err = (value - expected).abs();
                assert!(
                    err < 1e-6,
                    "Laplace linear solution error at ({}, {}): u={}, expected={}, err={err}",
                    result.x[i],
                    result.t_or_y[j],
                    value,
                    expected,
                );
            }
        }
    }

    /// The `converged` flag must be raised once the tolerance is met, and
    /// that must happen before the sweep budget runs out.
    #[test]
    fn test_laplace_convergence() {
        let max_iter = 50_000;
        let result = laplace_2d(
            (0.0, 1.0),
            (0.0, 1.0),
            11,
            11,
            &|x: f64, y: f64| {
                if on_unit_square_edge(x, y) {
                    Some(x * y)
                } else {
                    None
                }
            },
            max_iter,
            1e-8,
        )
        .unwrap();

        assert!(
            result.converged,
            "Laplace solver should converge within 50 000 iterations",
        );
        assert!(
            result.steps < max_iter,
            "should converge before hitting max_iter (took {} steps)",
            result.steps,
        );
    }
}