Skip to main content

scivex_optim/pde/
finite_diff.rs

1//! Finite difference PDE solvers for the heat equation, wave equation, and
2//! Laplace equation.
3
4use scivex_core::Float;
5
6use crate::error::{OptimError, Result};
7
8// ---------------------------------------------------------------------------
9// Types
10// ---------------------------------------------------------------------------
11
12/// Boundary condition specification.
13///
14/// # Examples
15///
16/// ```
17/// # use scivex_optim::pde::BoundaryCondition;
18/// let bc = BoundaryCondition::Dirichlet(0.0_f64);
19/// assert_eq!(bc, BoundaryCondition::Dirichlet(0.0));
20/// ```
21#[derive(Debug, Clone, Copy, PartialEq)]
22pub enum BoundaryCondition<T: Float> {
23    /// Fixed value at boundary (Dirichlet).
24    Dirichlet(T),
25    /// Fixed derivative at boundary (Neumann).
26    Neumann(T),
27}
28
29/// Result of a PDE solve.
30///
31/// # Examples
32///
33/// ```
34/// # use scivex_optim::pde::{heat_equation_1d, BoundaryCondition};
35/// let result = heat_equation_1d(
36///     (0.0_f64, 1.0), 50, 0.1, 100, 0.01,
37///     &|x| (std::f64::consts::PI * x).sin(),
38///     BoundaryCondition::Dirichlet(0.0),
39///     BoundaryCondition::Dirichlet(0.0),
40/// ).unwrap();
41/// assert!(!result.u.is_empty());
42/// ```
43#[derive(Debug, Clone)]
44pub struct PdeResult<T: Float> {
45    /// Solution values: for 1-D time-dependent problems the shape is
46    /// `[n_time][n_space]`.  For 2-D steady-state problems the shape is
47    /// `[ny][nx]`.
48    pub u: Vec<Vec<T>>,
49    /// Spatial grid points (x-axis).
50    pub x: Vec<T>,
51    /// Time points (for time-dependent problems) or y-axis grid (for 2-D
52    /// steady-state).
53    pub t_or_y: Vec<T>,
54    /// Number of time / iteration steps taken.
55    pub steps: usize,
56    /// Whether the solution converged (meaningful for iterative methods such
57    /// as Gauss-Seidel).
58    pub converged: bool,
59}
60
61// ---------------------------------------------------------------------------
62// Helpers
63// ---------------------------------------------------------------------------
64
65/// Build a uniform grid of `n` points spanning `[a, b]`.
66fn linspace<T: Float>(a: T, b: T, n: usize) -> Vec<T> {
67    if n < 2 {
68        return vec![a];
69    }
70    let n_intervals = T::from_usize(n - 1);
71    let dx = (b - a) / n_intervals;
72    (0..n).map(|i| a + T::from_usize(i) * dx).collect()
73}
74
75/// Apply a boundary condition at one end of a 1-D solution row.
76///
77/// * `is_left`  — `true` for the left boundary (index 0), `false` for the
78///   right boundary (last index).
79/// * `row`      — mutable slice of the current solution row.
80/// * `dx`       — spatial step size.
81fn apply_bc_1d<T: Float>(bc: &BoundaryCondition<T>, is_left: bool, row: &mut [T], dx: T) {
82    let n = row.len();
83    match *bc {
84        BoundaryCondition::Dirichlet(val) => {
85            if is_left {
86                row[0] = val;
87            } else {
88                row[n - 1] = val;
89            }
90        }
91        BoundaryCondition::Neumann(deriv) => {
92            // Ghost-node approach: u[-1] = u[1] - 2*dx*deriv  (left)
93            //                      u[n]  = u[n-2] + 2*dx*deriv (right)
94            if is_left {
95                row[0] = row[1] - dx * deriv;
96            } else {
97                row[n - 1] = row[n - 2] + dx * deriv;
98            }
99        }
100    }
101}
102
103// ---------------------------------------------------------------------------
104// 1-D Heat equation  (FTCS explicit scheme)
105// ---------------------------------------------------------------------------
106
107/// Solve the 1-D heat equation
108///
109/// ```text
110///   ∂u/∂t = α ∂²u/∂x²
111/// ```
112///
113/// using the explicit forward-time, centred-space (FTCS) scheme.
114///
115/// # Parameters
116///
117/// * `x_range`  — spatial domain `[x0, x1]`.
118/// * `n_x`      — number of spatial grid points (must be >= 3).
119/// * `t_final`  — simulate until this time (must be > 0).
120/// * `n_t`      — number of time steps (must be >= 1).
121/// * `alpha`    — thermal diffusivity (must be > 0).
122/// * `initial`  — initial condition `u(x, 0)`.
123/// * `left_bc`  — boundary condition at `x = x0`.
124/// * `right_bc` — boundary condition at `x = x1`.
125///
126/// # Errors
127///
128/// Returns [`OptimError::InvalidParameter`] when the grid is too small,
129/// parameters are non-positive, or the CFL stability condition
130/// `r = α dt / dx² <= 0.5` is violated.
131///
132/// # Examples
133///
134/// ```
135/// # use scivex_optim::pde::{heat_equation_1d, BoundaryCondition};
136/// let result = heat_equation_1d(
137///     (0.0_f64, 1.0), 50, 0.01, 500, 1.0,
138///     &|x| (std::f64::consts::PI * x).sin(),
139///     BoundaryCondition::Dirichlet(0.0),
140///     BoundaryCondition::Dirichlet(0.0),
141/// ).unwrap();
142/// assert!(result.converged);
143/// ```
144#[allow(clippy::too_many_arguments)]
145pub fn heat_equation_1d<T: Float>(
146    x_range: (T, T),
147    n_x: usize,
148    t_final: T,
149    n_t: usize,
150    alpha: T,
151    initial: &dyn Fn(T) -> T,
152    left_bc: BoundaryCondition<T>,
153    right_bc: BoundaryCondition<T>,
154) -> Result<PdeResult<T>> {
155    // --- Validate inputs ---------------------------------------------------
156    if n_x < 3 {
157        return Err(OptimError::InvalidParameter {
158            name: "n_x",
159            reason: "need at least 3 spatial points",
160        });
161    }
162    if n_t < 1 {
163        return Err(OptimError::InvalidParameter {
164            name: "n_t",
165            reason: "need at least 1 time step",
166        });
167    }
168    let zero = T::zero();
169    if t_final <= zero {
170        return Err(OptimError::InvalidParameter {
171            name: "t_final",
172            reason: "must be positive",
173        });
174    }
175    if alpha <= zero {
176        return Err(OptimError::InvalidParameter {
177            name: "alpha",
178            reason: "must be positive",
179        });
180    }
181
182    let x = linspace(x_range.0, x_range.1, n_x);
183    let dx = x[1] - x[0];
184    let dt = t_final / T::from_usize(n_t);
185    let r = alpha * dt / (dx * dx);
186
187    let half = T::from_f64(0.5);
188    if r > half {
189        return Err(OptimError::InvalidParameter {
190            name: "n_t",
191            reason: "stability condition violated: r = alpha*dt/dx^2 must be <= 0.5",
192        });
193    }
194
195    // --- Initial condition --------------------------------------------------
196    let mut u_prev: Vec<T> = x.iter().map(|&xi| initial(xi)).collect();
197    apply_bc_1d(&left_bc, true, &mut u_prev, dx);
198    apply_bc_1d(&right_bc, false, &mut u_prev, dx);
199
200    let mut all_u: Vec<Vec<T>> = Vec::with_capacity(n_t + 1);
201    all_u.push(u_prev.clone());
202
203    let mut t_vals: Vec<T> = Vec::with_capacity(n_t + 1);
204    t_vals.push(zero);
205
206    // --- Time stepping (FTCS) ----------------------------------------------
207    let two = T::from_f64(2.0);
208    for step in 0..n_t {
209        let mut u_next = u_prev.clone();
210        for i in 1..(n_x - 1) {
211            u_next[i] = u_prev[i] + r * (u_prev[i + 1] - two * u_prev[i] + u_prev[i - 1]);
212        }
213        apply_bc_1d(&left_bc, true, &mut u_next, dx);
214        apply_bc_1d(&right_bc, false, &mut u_next, dx);
215
216        t_vals.push(T::from_usize(step + 1) * dt);
217        all_u.push(u_next.clone());
218        u_prev = u_next;
219    }
220
221    Ok(PdeResult {
222        u: all_u,
223        x,
224        t_or_y: t_vals,
225        steps: n_t,
226        converged: true,
227    })
228}
229
230// ---------------------------------------------------------------------------
231// 1-D Wave equation (explicit three-level scheme)
232// ---------------------------------------------------------------------------
233
234/// Solve the 1-D wave equation
235///
236/// ```text
237///   ∂²u/∂t² = c² ∂²u/∂x²
238/// ```
239///
240/// using the explicit three-level centred-difference scheme.
241///
242/// # Parameters
243///
244/// * `x_range`    — spatial domain `[x0, x1]`.
245/// * `n_x`        — number of spatial grid points (must be >= 3).
246/// * `t_final`    — simulate until this time (must be > 0).
247/// * `n_t`        — number of time steps (must be >= 1).
248/// * `c`          — wave speed (must be > 0).
249/// * `initial_u`  — initial displacement `u(x, 0)`.
250/// * `initial_ut` — initial velocity `∂u/∂t(x, 0)`.
251/// * `left_bc`    — boundary condition at `x = x0`.
252/// * `right_bc`   — boundary condition at `x = x1`.
253///
254/// # Errors
255///
256/// Returns [`OptimError::InvalidParameter`] when the grid is too small,
257/// parameters are non-positive, or the CFL condition `c dt / dx <= 1` is
258/// violated.
259#[allow(clippy::too_many_arguments)]
260pub fn wave_equation_1d<T: Float>(
261    x_range: (T, T),
262    n_x: usize,
263    t_final: T,
264    n_t: usize,
265    c: T,
266    initial_u: &dyn Fn(T) -> T,
267    initial_ut: &dyn Fn(T) -> T,
268    left_bc: BoundaryCondition<T>,
269    right_bc: BoundaryCondition<T>,
270) -> Result<PdeResult<T>> {
271    // --- Validate inputs ---------------------------------------------------
272    if n_x < 3 {
273        return Err(OptimError::InvalidParameter {
274            name: "n_x",
275            reason: "need at least 3 spatial points",
276        });
277    }
278    if n_t < 1 {
279        return Err(OptimError::InvalidParameter {
280            name: "n_t",
281            reason: "need at least 1 time step",
282        });
283    }
284    let zero = T::zero();
285    if t_final <= zero {
286        return Err(OptimError::InvalidParameter {
287            name: "t_final",
288            reason: "must be positive",
289        });
290    }
291    if c <= zero {
292        return Err(OptimError::InvalidParameter {
293            name: "c",
294            reason: "must be positive",
295        });
296    }
297
298    let x = linspace(x_range.0, x_range.1, n_x);
299    let dx = x[1] - x[0];
300    let dt = t_final / T::from_usize(n_t);
301    let r = c * dt / dx; // Courant number
302
303    if r > T::one() {
304        return Err(OptimError::InvalidParameter {
305            name: "n_t",
306            reason: "CFL condition violated: c*dt/dx must be <= 1",
307        });
308    }
309
310    let r2 = r * r;
311    let two = T::from_f64(2.0);
312
313    // --- Level 0: u(x, 0) --------------------------------------------------
314    let mut u_prev: Vec<T> = x.iter().map(|&xi| initial_u(xi)).collect();
315    apply_bc_1d(&left_bc, true, &mut u_prev, dx);
316    apply_bc_1d(&right_bc, false, &mut u_prev, dx);
317
318    let mut all_u: Vec<Vec<T>> = Vec::with_capacity(n_t + 1);
319    all_u.push(u_prev.clone());
320
321    let mut t_vals: Vec<T> = Vec::with_capacity(n_t + 1);
322    t_vals.push(zero);
323
324    // --- Level 1: special first step using initial velocity -----------------
325    // u^1_i = u^0_i + dt * ut(x_i) + 0.5*r²*(u^0_{i+1} - 2 u^0_i + u^0_{i-1})
326    let half = T::from_f64(0.5);
327    let mut u_curr: Vec<T> = vec![zero; n_x];
328    for i in 1..(n_x - 1) {
329        let laplacian = u_prev[i + 1] - two * u_prev[i] + u_prev[i - 1];
330        u_curr[i] = u_prev[i] + dt * initial_ut(x[i]) + half * r2 * laplacian;
331    }
332    apply_bc_1d(&left_bc, true, &mut u_curr, dx);
333    apply_bc_1d(&right_bc, false, &mut u_curr, dx);
334
335    t_vals.push(dt);
336    all_u.push(u_curr.clone());
337
338    // --- Remaining steps (three-level scheme) -------------------------------
339    for step in 1..n_t {
340        let mut u_next = vec![zero; n_x];
341        for i in 1..(n_x - 1) {
342            let laplacian = u_curr[i + 1] - two * u_curr[i] + u_curr[i - 1];
343            u_next[i] = two * u_curr[i] - u_prev[i] + r2 * laplacian;
344        }
345        apply_bc_1d(&left_bc, true, &mut u_next, dx);
346        apply_bc_1d(&right_bc, false, &mut u_next, dx);
347
348        t_vals.push(T::from_usize(step + 1) * dt);
349        all_u.push(u_next.clone());
350        u_prev = u_curr;
351        u_curr = u_next;
352    }
353
354    Ok(PdeResult {
355        u: all_u,
356        x,
357        t_or_y: t_vals,
358        steps: n_t,
359        converged: true,
360    })
361}
362
363// ---------------------------------------------------------------------------
364// 2-D Laplace equation (Gauss-Seidel iteration)
365// ---------------------------------------------------------------------------
366
367/// Solve the 2-D Laplace equation
368///
369/// ```text
370///   ∂²u/∂x² + ∂²u/∂y² = 0
371/// ```
372///
373/// on a rectangular domain using Gauss-Seidel relaxation.
374///
375/// The `boundary` closure receives `(x, y)` and must return `Some(value)` for
376/// every point on the boundary of the domain (i.e.\ the first/last row/column
377/// of the grid).  Interior points should return `None`.
378///
379/// # Parameters
380///
381/// * `x_range`  — `[x0, x1]`.
382/// * `y_range`  — `[y0, y1]`.
383/// * `n_x`      — number of grid points in x (must be >= 3).
384/// * `n_y`      — number of grid points in y (must be >= 3).
385/// * `boundary` — closure returning `Some(T)` for boundary points.
386/// * `max_iter` — maximum number of Gauss-Seidel sweeps.
387/// * `tol`      — convergence tolerance on the max absolute update.
388///
389/// # Errors
390///
391/// Returns [`OptimError::InvalidParameter`] for grids that are too small or
392/// non-positive `max_iter` / `tol`.
393pub fn laplace_2d<T: Float>(
394    x_range: (T, T),
395    y_range: (T, T),
396    n_x: usize,
397    n_y: usize,
398    boundary: &dyn Fn(T, T) -> Option<T>,
399    max_iter: usize,
400    tol: T,
401) -> Result<PdeResult<T>> {
402    // --- Validate inputs ---------------------------------------------------
403    if n_x < 3 {
404        return Err(OptimError::InvalidParameter {
405            name: "n_x",
406            reason: "need at least 3 grid points in x",
407        });
408    }
409    if n_y < 3 {
410        return Err(OptimError::InvalidParameter {
411            name: "n_y",
412            reason: "need at least 3 grid points in y",
413        });
414    }
415    if max_iter == 0 {
416        return Err(OptimError::InvalidParameter {
417            name: "max_iter",
418            reason: "must be at least 1",
419        });
420    }
421    if tol <= T::zero() {
422        return Err(OptimError::InvalidParameter {
423            name: "tol",
424            reason: "must be positive",
425        });
426    }
427
428    let x = linspace(x_range.0, x_range.1, n_x);
429    let y = linspace(y_range.0, y_range.1, n_y);
430
431    // --- Initialise grid with boundary values; interior = 0 ----------------
432    let mut u: Vec<Vec<T>> = Vec::with_capacity(n_y);
433    let mut is_boundary: Vec<Vec<bool>> = Vec::with_capacity(n_y);
434
435    for yj in &y {
436        let mut row = vec![T::zero(); n_x];
437        let mut brow = vec![false; n_x];
438        for (i, xi) in x.iter().enumerate() {
439            if let Some(val) = boundary(*xi, *yj) {
440                row[i] = val;
441                brow[i] = true;
442            }
443        }
444        u.push(row);
445        is_boundary.push(brow);
446    }
447
448    // --- Gauss-Seidel iteration --------------------------------------------
449    let quarter = T::from_f64(0.25);
450    let mut converged = false;
451    let mut steps: usize = 0;
452
453    for _iter in 0..max_iter {
454        let mut max_diff = T::zero();
455        for j in 1..(n_y - 1) {
456            for i in 1..(n_x - 1) {
457                if is_boundary[j][i] {
458                    continue;
459                }
460                let new_val = quarter * (u[j][i + 1] + u[j][i - 1] + u[j + 1][i] + u[j - 1][i]);
461                let diff = (new_val - u[j][i]).abs();
462                if diff > max_diff {
463                    max_diff = diff;
464                }
465                u[j][i] = new_val;
466            }
467        }
468        steps += 1;
469        if max_diff < tol {
470            converged = true;
471            break;
472        }
473    }
474
475    Ok(PdeResult {
476        u,
477        x,
478        t_or_y: y,
479        steps,
480        converged,
481    })
482}
483
484// ===========================================================================
485// Tests
486// ===========================================================================
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Heat equation with fixed BCs (0 on left, 1 on right) should converge
    /// toward a linear profile u(x) = x / L in steady state.
    #[test]
    fn test_heat_steady_state() {
        let n_x = 21;
        let n_t = 50_000;
        let result = heat_equation_1d(
            (0.0, 1.0),
            n_x,
            50.0, // long enough to approach steady state (L²/α = 10)
            n_t,
            0.1,       // alpha
            &|_x| 0.0, // initial = 0 everywhere
            BoundaryCondition::Dirichlet(0.0),
            BoundaryCondition::Dirichlet(1.0),
        )
        .unwrap();

        // Last row should be approximately linear: u(x) ≈ x
        let last = result.u.last().unwrap();
        for (i, &xi) in result.x.iter().enumerate() {
            let err = (last[i] - xi).abs();
            assert!(
                err < 0.05,
                "steady-state error too large at x={xi}: u={}, expected={xi}, err={err}",
                last[i],
            );
        }
    }

    /// A Gaussian pulse should diffuse (spread out) under the heat equation:
    /// its peak amplitude should decrease over time.
    #[test]
    fn test_heat_gaussian_diffusion() {
        let n_x = 101;
        let n_t = 5000;
        let result = heat_equation_1d(
            (0.0, 1.0),
            n_x,
            0.05,
            n_t,
            0.01,
            &|x: f64| (-(x - 0.5).powi(2) / 0.01).exp(),
            BoundaryCondition::Dirichlet(0.0),
            BoundaryCondition::Dirichlet(0.0),
        )
        .unwrap();

        // Peak of initial condition
        let initial_max = result.u[0]
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);

        // Peak at final time
        let final_max = result
            .u
            .last()
            .unwrap()
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);

        assert!(
            final_max < initial_max,
            "Gaussian peak should decrease: initial={initial_max}, final={final_max}",
        );
    }

    /// A uniform profile with zero-flux Neumann ends is already steady: every
    /// stored time level must equal the initial value.  Also checks the
    /// documented result shape.
    #[test]
    fn test_heat_neumann_uniform_is_steady() {
        let n_t = 100;
        let result = heat_equation_1d(
            (0.0, 1.0),
            11,
            0.1,
            n_t,
            0.1,
            &|_x: f64| 1.0,
            BoundaryCondition::Neumann(0.0),
            BoundaryCondition::Neumann(0.0),
        )
        .unwrap();

        assert_eq!(result.u.len(), n_t + 1);
        assert_eq!(result.t_or_y.len(), n_t + 1);
        assert_eq!(result.steps, n_t);
        for row in &result.u {
            for &v in row {
                assert!((v - 1.0).abs() < 1e-12, "uniform state drifted: got {v}");
            }
        }
    }

    /// A sine standing-wave should oscillate: u(x,0) = sin(pi*x), ut=0.
    /// After half a period the solution should be approximately -sin(pi*x).
    #[test]
    fn test_wave_standing_wave() {
        let n_x = 101;
        let c = 1.0_f64;
        // Period = 2*L/c = 2.0 for L=1, c=1.  Half-period = 1.0.
        let t_final = 1.0;
        let n_t = 200;
        let result = wave_equation_1d(
            (0.0, 1.0),
            n_x,
            t_final,
            n_t,
            c,
            &|x: f64| (std::f64::consts::PI * x).sin(),
            &|_x: f64| 0.0,
            BoundaryCondition::Dirichlet(0.0),
            BoundaryCondition::Dirichlet(0.0),
        )
        .unwrap();

        // At t = half-period the displacement should be ≈ -sin(pi*x).
        let last = result.u.last().unwrap();
        let mid = n_x / 2; // x = 0.5
        // sin(pi*0.5) = 1.0, so we expect ≈ -1.0.
        assert!(
            last[mid] < -0.8,
            "standing wave mid-point should be near -1 at half-period, got {}",
            last[mid],
        );
    }

    /// Invalid arguments are rejected with `InvalidParameter` naming the
    /// offending argument.
    #[test]
    fn test_invalid_parameters_rejected() {
        let zero_ic = |_x: f64| 0.0;
        let bc = BoundaryCondition::Dirichlet(0.0);

        // Too few spatial points.
        let res = heat_equation_1d((0.0, 1.0), 2, 1.0, 10, 1.0, &zero_ic, bc, bc);
        assert!(matches!(res, Err(OptimError::InvalidParameter { name: "n_x", .. })));

        // Non-positive diffusivity.
        let res = heat_equation_1d((0.0, 1.0), 11, 1.0, 10, -1.0, &zero_ic, bc, bc);
        assert!(matches!(res, Err(OptimError::InvalidParameter { name: "alpha", .. })));

        // FTCS stability violated: r = 1 * 0.1 / 0.01 = 10 > 0.5.
        let res = heat_equation_1d((0.0, 1.0), 11, 1.0, 10, 1.0, &zero_ic, bc, bc);
        assert!(matches!(res, Err(OptimError::InvalidParameter { name: "n_t", .. })));

        // CFL violated for the wave equation: c*dt/dx = 0.2 / 0.1 = 2 > 1.
        let res = wave_equation_1d((0.0, 1.0), 11, 1.0, 5, 1.0, &zero_ic, &zero_ic, bc, bc);
        assert!(matches!(res, Err(OptimError::InvalidParameter { name: "n_t", .. })));

        // Zero Gauss-Seidel sweeps.
        let res = laplace_2d(
            (0.0, 1.0),
            (0.0, 1.0),
            3,
            3,
            &|_x: f64, _y: f64| Some(0.0),
            0,
            1e-6,
        );
        assert!(matches!(res, Err(OptimError::InvalidParameter { name: "max_iter", .. })));
    }

    /// Laplace equation with boundary u = x (linear) should yield the exact
    /// linear interior solution u(x, y) = x.
    #[test]
    fn test_laplace_linear_boundary() {
        let n_x = 21;
        let n_y = 21;
        let result = laplace_2d(
            (0.0, 1.0),
            (0.0, 1.0),
            n_x,
            n_y,
            &|x: f64, y: f64| {
                // Mark every edge point as boundary with value = x.
                if x < 1e-12 || (x - 1.0).abs() < 1e-12 || y < 1e-12 || (y - 1.0).abs() < 1e-12 {
                    Some(x)
                } else {
                    None
                }
            },
            10_000,
            1e-10,
        )
        .unwrap();

        assert!(result.converged, "Laplace solver should converge");

        // Interior should be ≈ x.
        for j in 1..(n_y - 1) {
            for i in 1..(n_x - 1) {
                let err = (result.u[j][i] - result.x[i]).abs();
                assert!(
                    err < 1e-6,
                    "Laplace linear solution error at ({}, {}): u={}, expected={}, err={err}",
                    result.x[i],
                    result.t_or_y[j],
                    result.u[j][i],
                    result.x[i],
                );
            }
        }
    }

    /// Verify the `converged` flag is actually set when the solver meets the
    /// tolerance.
    #[test]
    fn test_laplace_convergence() {
        let result = laplace_2d(
            (0.0, 1.0),
            (0.0, 1.0),
            11,
            11,
            &|x: f64, y: f64| {
                if x < 1e-12 || (x - 1.0).abs() < 1e-12 || y < 1e-12 || (y - 1.0).abs() < 1e-12 {
                    Some(x * y)
                } else {
                    None
                }
            },
            50_000,
            1e-8,
        )
        .unwrap();

        assert!(
            result.converged,
            "Laplace solver should converge within 50 000 iterations",
        );
        assert!(
            result.steps < 50_000,
            "should converge before hitting max_iter (took {} steps)",
            result.steps,
        );
    }
}