Skip to main content

scirs2_integrate/pde/
fd_solvers.rs

1//! Enhanced Finite Difference PDE Solvers
2//!
3//! Provides solvers for parabolic, hyperbolic, and elliptic PDEs using
4//! finite difference discretization with configurable boundary conditions,
5//! explicit and implicit time-stepping, and stability analysis.
6//!
7//! ## Equation Types
8//! - **Heat equation** (parabolic): du/dt = alpha * d2u/dx2
9//! - **Wave equation** (hyperbolic): d2u/dt2 = c^2 * d2u/dx2
10//! - **Poisson equation** (elliptic): d2u/dx2 + d2u/dy2 = f(x,y)
11//!
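//! ## Boundary Conditions
//! - Dirichlet (fixed value)
//! - Neumann (fixed derivative du/dx, imposed with a first-order one-sided difference)
//! - Periodic (wrap-around; must be set on both ends)
//!
//! The 2D heat, wave and Poisson solvers support Dirichlet boundaries only.
//!
//! ## Time-Stepping Methods
//! - Explicit (forward Euler for the heat equation, leapfrog for the wave equation;
//!   limited by the CFL condition)
//! - Implicit Crank-Nicolson (1D heat equation; unconditionally stable, second-order in time)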
20
21use scirs2_core::ndarray::{Array1, Array2};
22
23use crate::pde::{PDEError, PDEResult};
24
25// ---------------------------------------------------------------------------
26// Boundary condition types for FD solvers
27// ---------------------------------------------------------------------------
28
/// Boundary condition for the 1D finite-difference PDE solvers.
#[derive(Debug, Clone)]
pub enum FDBoundaryCondition {
    /// Fixed value at the boundary: `u(boundary) = value`.
    Dirichlet(f64),
    /// Fixed spatial derivative at the boundary: `du/dx(boundary) = value`.
    ///
    /// The value is the derivative with respect to `x` on BOTH ends (not the
    /// outward-normal derivative). A positive value therefore means `u` increases
    /// towards larger `x` at either boundary. The condition is imposed with a
    /// first-order one-sided difference.
    Neumann(f64),
    /// Periodic boundary: the first and last grid points are identified.
    ///
    /// Must be specified on both ends of the domain.
    Periodic,
}
39
/// Time-integration scheme used by the time-dependent PDE solvers.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TimeSteppingMethod {
    /// Forward Euler: cheap per step, but only stable while the CFL bound holds.
    Explicit,
    /// Crank-Nicolson: implicit, unconditionally stable and second-order in time.
    CrankNicolson,
}
48
/// Stationary iterative scheme used by the elliptic (Poisson) solver.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum EllipticIterativeMethod {
    /// Jacobi sweeps: every update reads only values from the previous sweep.
    Jacobi,
    /// Gauss-Seidel sweeps: updates read the newest available neighbour values.
    GaussSeidel,
    /// Successive Over-Relaxation with relaxation factor `omega`, blending the
    /// Gauss-Seidel value with the current one.
    SOR(f64),
}
59
60// ---------------------------------------------------------------------------
61// CFL stability analysis
62// ---------------------------------------------------------------------------
63
/// Outcome of a Courant-Friedrichs-Lewy (CFL) stability check for an explicit scheme.
#[derive(Clone, Debug)]
pub struct CFLAnalysis {
    /// Dimensionless CFL / mesh-ratio number for the supplied grid and time step.
    pub cfl_number: f64,
    /// `true` when `cfl_number` does not exceed the scheme's stability limit.
    pub is_stable: bool,
    /// Largest time step that keeps the explicit scheme within its stability limit.
    pub max_stable_dt: f64,
    /// Human-readable summary of the condition and the computed values.
    pub description: String,
}
76
77/// Check CFL condition for the heat equation (explicit forward Euler)
78///
79/// For the 1D heat equation du/dt = alpha * d2u/dx2,
80/// the CFL condition is: alpha * dt / dx^2 <= 0.5
81pub fn cfl_heat_1d(alpha: f64, dx: f64, dt: f64) -> CFLAnalysis {
82    let cfl = alpha * dt / (dx * dx);
83    let max_stable_dt = 0.5 * dx * dx / alpha;
84    CFLAnalysis {
85        cfl_number: cfl,
86        is_stable: cfl <= 0.5,
87        max_stable_dt,
88        description: format!(
89            "Heat 1D: CFL = {cfl:.4e} (must be <= 0.5). Max stable dt = {max_stable_dt:.4e}"
90        ),
91    }
92}
93
94/// Check CFL condition for the 2D heat equation (explicit forward Euler)
95///
96/// CFL condition: alpha * dt * (1/dx^2 + 1/dy^2) <= 0.5
97pub fn cfl_heat_2d(alpha: f64, dx: f64, dy: f64, dt: f64) -> CFLAnalysis {
98    let cfl = alpha * dt * (1.0 / (dx * dx) + 1.0 / (dy * dy));
99    let max_stable_dt = 0.5 / (alpha * (1.0 / (dx * dx) + 1.0 / (dy * dy)));
100    CFLAnalysis {
101        cfl_number: cfl,
102        is_stable: cfl <= 0.5,
103        max_stable_dt,
104        description: format!(
105            "Heat 2D: CFL = {cfl:.4e} (must be <= 0.5). Max stable dt = {max_stable_dt:.4e}"
106        ),
107    }
108}
109
110/// Check CFL condition for the 1D wave equation (explicit)
111///
112/// CFL condition: c * dt / dx <= 1.0
113pub fn cfl_wave_1d(c: f64, dx: f64, dt: f64) -> CFLAnalysis {
114    let cfl = c * dt / dx;
115    let max_stable_dt = dx / c;
116    CFLAnalysis {
117        cfl_number: cfl,
118        is_stable: cfl <= 1.0,
119        max_stable_dt,
120        description: format!(
121            "Wave 1D: CFL = {cfl:.4e} (must be <= 1.0). Max stable dt = {max_stable_dt:.4e}"
122        ),
123    }
124}
125
126/// Check CFL condition for the 2D wave equation (explicit)
127///
128/// CFL condition: c * dt * sqrt(1/dx^2 + 1/dy^2) <= 1.0
129pub fn cfl_wave_2d(c: f64, dx: f64, dy: f64, dt: f64) -> CFLAnalysis {
130    let factor = (1.0 / (dx * dx) + 1.0 / (dy * dy)).sqrt();
131    let cfl = c * dt * factor;
132    let max_stable_dt = 1.0 / (c * factor);
133    CFLAnalysis {
134        cfl_number: cfl,
135        is_stable: cfl <= 1.0,
136        max_stable_dt,
137        description: format!(
138            "Wave 2D: CFL = {cfl:.4e} (must be <= 1.0). Max stable dt = {max_stable_dt:.4e}"
139        ),
140    }
141}
142
143// ---------------------------------------------------------------------------
144// 1D Heat Equation solver result
145// ---------------------------------------------------------------------------
146
147/// Result from a heat equation solve
148#[derive(Debug, Clone)]
149pub struct HeatResult {
150    /// Spatial grid x values
151    pub x: Array1<f64>,
152    /// Time grid t values
153    pub t: Array1<f64>,
154    /// Solution u[time_step, spatial_index]
155    pub u: Array2<f64>,
156    /// CFL analysis (if explicit method used)
157    pub cfl: Option<CFLAnalysis>,
158}
159
160// ---------------------------------------------------------------------------
161// 1D Heat Equation
162// ---------------------------------------------------------------------------
163
164/// Solve 1D heat equation: du/dt = alpha * d2u/dx2
165///
166/// # Arguments
167/// * `alpha` - thermal diffusivity (> 0)
168/// * `x_range` - spatial domain [x_min, x_max]
169/// * `t_range` - time domain [t_min, t_max]
170/// * `nx` - number of spatial grid points
171/// * `nt` - number of time steps
172/// * `initial_condition` - function u(x, 0)
173/// * `left_bc` - boundary condition at x_min
174/// * `right_bc` - boundary condition at x_max
175/// * `method` - time-stepping method
176pub fn solve_heat_1d(
177    alpha: f64,
178    x_range: [f64; 2],
179    t_range: [f64; 2],
180    nx: usize,
181    nt: usize,
182    initial_condition: &dyn Fn(f64) -> f64,
183    left_bc: &FDBoundaryCondition,
184    right_bc: &FDBoundaryCondition,
185    method: TimeSteppingMethod,
186) -> PDEResult<HeatResult> {
187    if alpha <= 0.0 {
188        return Err(PDEError::InvalidParameter(
189            "Thermal diffusivity alpha must be positive".to_string(),
190        ));
191    }
192    if nx < 3 {
193        return Err(PDEError::InvalidGrid(
194            "Need at least 3 spatial grid points".to_string(),
195        ));
196    }
197    if nt < 1 {
198        return Err(PDEError::InvalidParameter(
199            "Need at least 1 time step".to_string(),
200        ));
201    }
202
203    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
204    let dt = (t_range[1] - t_range[0]) / nt as f64;
205
206    // Build spatial grid
207    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
208    // Build time grid
209    let t = Array1::from_shape_fn(nt + 1, |i| t_range[0] + i as f64 * dt);
210
211    // Initialize solution array
212    let mut u = Array2::zeros((nt + 1, nx));
213    for i in 0..nx {
214        u[[0, i]] = initial_condition(x[i]);
215    }
216    // Apply initial BCs
217    apply_bc_1d(&mut u, 0, left_bc, right_bc, dx);
218
219    let cfl = cfl_heat_1d(alpha, dx, dt);
220
221    match method {
222        TimeSteppingMethod::Explicit => {
223            if !cfl.is_stable {
224                return Err(PDEError::ComputationError(format!(
225                    "Explicit scheme unstable: {}",
226                    cfl.description
227                )));
228            }
229            let r = alpha * dt / (dx * dx);
230            for n in 0..nt {
231                // Check for periodic BC pair
232                let is_periodic = matches!(
233                    (left_bc, right_bc),
234                    (FDBoundaryCondition::Periodic, FDBoundaryCondition::Periodic)
235                );
236                for i in 1..nx - 1 {
237                    u[[n + 1, i]] =
238                        u[[n, i]] + r * (u[[n, i + 1]] - 2.0 * u[[n, i]] + u[[n, i - 1]]);
239                }
240                if is_periodic {
241                    // Periodic: wrap around
242                    u[[n + 1, 0]] = u[[n, 0]] + r * (u[[n, 1]] - 2.0 * u[[n, 0]] + u[[n, nx - 2]]);
243                    u[[n + 1, nx - 1]] = u[[n + 1, 0]];
244                } else {
245                    apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
246                }
247            }
248        }
249        TimeSteppingMethod::CrankNicolson => {
250            let r = alpha * dt / (2.0 * dx * dx);
251            let is_periodic = matches!(
252                (left_bc, right_bc),
253                (FDBoundaryCondition::Periodic, FDBoundaryCondition::Periodic)
254            );
255            for n in 0..nt {
256                // Build RHS
257                let mut rhs = Array1::zeros(nx);
258                for i in 1..nx - 1 {
259                    rhs[i] = u[[n, i]] + r * (u[[n, i + 1]] - 2.0 * u[[n, i]] + u[[n, i - 1]]);
260                }
261                if is_periodic {
262                    rhs[0] = u[[n, 0]] + r * (u[[n, 1]] - 2.0 * u[[n, 0]] + u[[n, nx - 2]]);
263                    rhs[nx - 1] = rhs[0];
264                }
265
266                // Solve tridiagonal system (1+2r) u_new[i] - r u_new[i-1] - r u_new[i+1] = rhs[i]
267                if is_periodic {
268                    let solved = solve_periodic_tridiag(nx - 1, -r, 1.0 + 2.0 * r, -r, &rhs)?;
269                    for i in 0..nx - 1 {
270                        u[[n + 1, i]] = solved[i];
271                    }
272                    u[[n + 1, nx - 1]] = u[[n + 1, 0]];
273                } else {
274                    let interior_size = nx - 2;
275                    if interior_size == 0 {
276                        apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
277                        continue;
278                    }
279                    let mut rhs_interior = Array1::zeros(interior_size);
280                    for i in 0..interior_size {
281                        rhs_interior[i] = rhs[i + 1];
282                    }
283                    // Adjust RHS for boundary conditions
284                    apply_cn_bc_adjustment(
285                        &mut rhs_interior,
286                        left_bc,
287                        right_bc,
288                        r,
289                        &u,
290                        n + 1,
291                        nx,
292                        dx,
293                    );
294                    let solved =
295                        solve_tridiag(interior_size, -r, 1.0 + 2.0 * r, -r, &rhs_interior)?;
296                    for i in 0..interior_size {
297                        u[[n + 1, i + 1]] = solved[i];
298                    }
299                    apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
300                }
301            }
302        }
303    }
304
305    Ok(HeatResult {
306        x,
307        t,
308        u,
309        cfl: Some(cfl),
310    })
311}
312
313// ---------------------------------------------------------------------------
314// 2D Heat Equation
315// ---------------------------------------------------------------------------
316
317/// Result from a 2D heat equation solve
318#[derive(Debug, Clone)]
319pub struct Heat2DResult {
320    /// Spatial grid x values
321    pub x: Array1<f64>,
322    /// Spatial grid y values
323    pub y: Array1<f64>,
324    /// Time grid t values
325    pub t: Array1<f64>,
326    /// Solution snapshots, `u[time_step]` is a 2D array `[ny, nx]`
327    pub u: Vec<Array2<f64>>,
328    /// CFL analysis
329    pub cfl: Option<CFLAnalysis>,
330}
331
332/// Solve 2D heat equation: du/dt = alpha * (d2u/dx2 + d2u/dy2)
333///
334/// Only Dirichlet BCs are supported for the 2D version to keep the interface simple.
335/// Explicit forward Euler time stepping.
336pub fn solve_heat_2d(
337    alpha: f64,
338    x_range: [f64; 2],
339    y_range: [f64; 2],
340    t_range: [f64; 2],
341    nx: usize,
342    ny: usize,
343    nt: usize,
344    initial_condition: &dyn Fn(f64, f64) -> f64,
345    bc_values: [f64; 4], // [left, right, bottom, top] Dirichlet values
346    save_every: usize,
347) -> PDEResult<Heat2DResult> {
348    if alpha <= 0.0 {
349        return Err(PDEError::InvalidParameter(
350            "Thermal diffusivity alpha must be positive".to_string(),
351        ));
352    }
353    if nx < 3 || ny < 3 {
354        return Err(PDEError::InvalidGrid(
355            "Need at least 3 grid points in each dimension".to_string(),
356        ));
357    }
358
359    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
360    let dy = (y_range[1] - y_range[0]) / (ny as f64 - 1.0);
361    let dt = (t_range[1] - t_range[0]) / nt as f64;
362
363    let cfl = cfl_heat_2d(alpha, dx, dy, dt);
364    if !cfl.is_stable {
365        return Err(PDEError::ComputationError(format!(
366            "Explicit scheme unstable: {}",
367            cfl.description
368        )));
369    }
370
371    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
372    let y = Array1::from_shape_fn(ny, |j| y_range[0] + j as f64 * dy);
373    let mut t_save = vec![t_range[0]];
374
375    // Initialize
376    let mut u_curr = Array2::zeros((ny, nx));
377    for j in 0..ny {
378        for i in 0..nx {
379            u_curr[[j, i]] = initial_condition(x[i], y[j]);
380        }
381    }
382    apply_dirichlet_2d(&mut u_curr, bc_values, nx, ny);
383
384    let save_every = if save_every == 0 { 1 } else { save_every };
385    let mut snapshots = vec![u_curr.clone()];
386
387    let rx = alpha * dt / (dx * dx);
388    let ry = alpha * dt / (dy * dy);
389
390    for n in 0..nt {
391        let mut u_next = u_curr.clone();
392        for j in 1..ny - 1 {
393            for i in 1..nx - 1 {
394                u_next[[j, i]] = u_curr[[j, i]]
395                    + rx * (u_curr[[j, i + 1]] - 2.0 * u_curr[[j, i]] + u_curr[[j, i - 1]])
396                    + ry * (u_curr[[j + 1, i]] - 2.0 * u_curr[[j, i]] + u_curr[[j - 1, i]]);
397            }
398        }
399        apply_dirichlet_2d(&mut u_next, bc_values, nx, ny);
400        u_curr = u_next;
401
402        if (n + 1) % save_every == 0 || n + 1 == nt {
403            snapshots.push(u_curr.clone());
404            t_save.push(t_range[0] + (n + 1) as f64 * dt);
405        }
406    }
407
408    Ok(Heat2DResult {
409        x,
410        y,
411        t: Array1::from_vec(t_save),
412        u: snapshots,
413        cfl: Some(cfl),
414    })
415}
416
417// ---------------------------------------------------------------------------
418// 1D Wave Equation
419// ---------------------------------------------------------------------------
420
421/// Result from a wave equation solve
422#[derive(Debug, Clone)]
423pub struct WaveResult {
424    /// Spatial grid x values
425    pub x: Array1<f64>,
426    /// Time grid t values
427    pub t: Array1<f64>,
428    /// Solution u[time_step, spatial_index]
429    pub u: Array2<f64>,
430    /// CFL analysis
431    pub cfl: Option<CFLAnalysis>,
432}
433
434/// Solve 1D wave equation: d2u/dt2 = c^2 * d2u/dx2
435///
436/// Uses the explicit leapfrog scheme, which requires CFL number c*dt/dx <= 1.
437pub fn solve_wave_1d(
438    c: f64,
439    x_range: [f64; 2],
440    t_range: [f64; 2],
441    nx: usize,
442    nt: usize,
443    initial_displacement: &dyn Fn(f64) -> f64,
444    initial_velocity: &dyn Fn(f64) -> f64,
445    left_bc: &FDBoundaryCondition,
446    right_bc: &FDBoundaryCondition,
447) -> PDEResult<WaveResult> {
448    if c <= 0.0 {
449        return Err(PDEError::InvalidParameter(
450            "Wave speed c must be positive".to_string(),
451        ));
452    }
453    if nx < 3 {
454        return Err(PDEError::InvalidGrid(
455            "Need at least 3 spatial grid points".to_string(),
456        ));
457    }
458
459    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
460    let dt = (t_range[1] - t_range[0]) / nt as f64;
461
462    let cfl = cfl_wave_1d(c, dx, dt);
463    if !cfl.is_stable {
464        return Err(PDEError::ComputationError(format!(
465            "Explicit wave scheme unstable: {}",
466            cfl.description
467        )));
468    }
469
470    let r2 = (c * dt / dx) * (c * dt / dx);
471    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
472    let t = Array1::from_shape_fn(nt + 1, |i| t_range[0] + i as f64 * dt);
473
474    let mut u = Array2::zeros((nt + 1, nx));
475
476    // Time step 0: initial displacement
477    for i in 0..nx {
478        u[[0, i]] = initial_displacement(x[i]);
479    }
480    apply_bc_1d(&mut u, 0, left_bc, right_bc, dx);
481
482    // Time step 1: use Taylor expansion with initial velocity
483    // u(x, dt) ~ u(x, 0) + dt * v(x, 0) + 0.5 * dt^2 * c^2 * d2u/dx2
484    let is_periodic = matches!(
485        (left_bc, right_bc),
486        (FDBoundaryCondition::Periodic, FDBoundaryCondition::Periodic)
487    );
488    for i in 1..nx - 1 {
489        let d2u = u[[0, i + 1]] - 2.0 * u[[0, i]] + u[[0, i - 1]];
490        u[[1, i]] = u[[0, i]] + dt * initial_velocity(x[i]) + 0.5 * r2 * d2u;
491    }
492    if is_periodic {
493        let d2u = u[[0, 1]] - 2.0 * u[[0, 0]] + u[[0, nx - 2]];
494        u[[1, 0]] = u[[0, 0]] + dt * initial_velocity(x[0]) + 0.5 * r2 * d2u;
495        u[[1, nx - 1]] = u[[1, 0]];
496    } else {
497        apply_bc_1d(&mut u, 1, left_bc, right_bc, dx);
498    }
499
500    // Leapfrog time stepping for n >= 2
501    for n in 1..nt {
502        for i in 1..nx - 1 {
503            u[[n + 1, i]] = 2.0 * u[[n, i]] - u[[n - 1, i]]
504                + r2 * (u[[n, i + 1]] - 2.0 * u[[n, i]] + u[[n, i - 1]]);
505        }
506        if is_periodic {
507            u[[n + 1, 0]] = 2.0 * u[[n, 0]] - u[[n - 1, 0]]
508                + r2 * (u[[n, 1]] - 2.0 * u[[n, 0]] + u[[n, nx - 2]]);
509            u[[n + 1, nx - 1]] = u[[n + 1, 0]];
510        } else {
511            apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
512        }
513    }
514
515    Ok(WaveResult {
516        x,
517        t,
518        u,
519        cfl: Some(cfl),
520    })
521}
522
523/// Solve 2D wave equation: d2u/dt2 = c^2 * (d2u/dx2 + d2u/dy2)
524///
525/// Explicit leapfrog on a rectangular grid with Dirichlet BCs.
526pub fn solve_wave_2d(
527    c: f64,
528    x_range: [f64; 2],
529    y_range: [f64; 2],
530    t_range: [f64; 2],
531    nx: usize,
532    ny: usize,
533    nt: usize,
534    initial_displacement: &dyn Fn(f64, f64) -> f64,
535    initial_velocity: &dyn Fn(f64, f64) -> f64,
536    bc_value: f64,
537    save_every: usize,
538) -> PDEResult<Wave2DResult> {
539    if c <= 0.0 {
540        return Err(PDEError::InvalidParameter(
541            "Wave speed c must be positive".to_string(),
542        ));
543    }
544    if nx < 3 || ny < 3 {
545        return Err(PDEError::InvalidGrid(
546            "Need at least 3 grid points in each dimension".to_string(),
547        ));
548    }
549
550    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
551    let dy = (y_range[1] - y_range[0]) / (ny as f64 - 1.0);
552    let dt = (t_range[1] - t_range[0]) / nt as f64;
553
554    let cfl = cfl_wave_2d(c, dx, dy, dt);
555    if !cfl.is_stable {
556        return Err(PDEError::ComputationError(format!(
557            "Explicit 2D wave scheme unstable: {}",
558            cfl.description
559        )));
560    }
561
562    let rx2 = (c * dt / dx) * (c * dt / dx);
563    let ry2 = (c * dt / dy) * (c * dt / dy);
564
565    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
566    let y = Array1::from_shape_fn(ny, |j| y_range[0] + j as f64 * dy);
567
568    let save_every = if save_every == 0 { 1 } else { save_every };
569    let bc_vals = [bc_value; 4];
570
571    // u at time n-1, n, n+1
572    let mut u_prev = Array2::zeros((ny, nx));
573    let mut u_curr = Array2::zeros((ny, nx));
574
575    // Step 0
576    for j in 0..ny {
577        for i in 0..nx {
578            u_curr[[j, i]] = initial_displacement(x[i], y[j]);
579        }
580    }
581    apply_dirichlet_2d(&mut u_curr, bc_vals, nx, ny);
582
583    let mut snapshots = vec![u_curr.clone()];
584    let mut t_save = vec![t_range[0]];
585
586    // Step 1 via Taylor expansion
587    for j in 1..ny - 1 {
588        for i in 1..nx - 1 {
589            let d2x = u_curr[[j, i + 1]] - 2.0 * u_curr[[j, i]] + u_curr[[j, i - 1]];
590            let d2y = u_curr[[j + 1, i]] - 2.0 * u_curr[[j, i]] + u_curr[[j - 1, i]];
591            u_prev[[j, i]] =
592                u_curr[[j, i]] + dt * initial_velocity(x[i], y[j]) + 0.5 * (rx2 * d2x + ry2 * d2y);
593        }
594    }
595    apply_dirichlet_2d(&mut u_prev, bc_vals, nx, ny);
596    // swap: prev = step0, curr = step1
597    std::mem::swap(&mut u_prev, &mut u_curr);
598
599    if save_every == 1 {
600        snapshots.push(u_curr.clone());
601        t_save.push(t_range[0] + dt);
602    }
603
604    // Steps 2..nt via leapfrog
605    for n in 1..nt {
606        let mut u_next = Array2::zeros((ny, nx));
607        for j in 1..ny - 1 {
608            for i in 1..nx - 1 {
609                let d2x = u_curr[[j, i + 1]] - 2.0 * u_curr[[j, i]] + u_curr[[j, i - 1]];
610                let d2y = u_curr[[j + 1, i]] - 2.0 * u_curr[[j, i]] + u_curr[[j - 1, i]];
611                u_next[[j, i]] = 2.0 * u_curr[[j, i]] - u_prev[[j, i]] + rx2 * d2x + ry2 * d2y;
612            }
613        }
614        apply_dirichlet_2d(&mut u_next, bc_vals, nx, ny);
615        u_prev = u_curr;
616        u_curr = u_next;
617
618        if (n + 1) % save_every == 0 || n + 1 == nt {
619            snapshots.push(u_curr.clone());
620            t_save.push(t_range[0] + (n + 1) as f64 * dt);
621        }
622    }
623
624    Ok(Wave2DResult {
625        x,
626        y,
627        t: Array1::from_vec(t_save),
628        u: snapshots,
629        cfl: Some(cfl),
630    })
631}
632
633/// Result from a 2D wave equation solve
634#[derive(Debug, Clone)]
635pub struct Wave2DResult {
636    /// Spatial grid x values
637    pub x: Array1<f64>,
638    /// Spatial grid y values
639    pub y: Array1<f64>,
640    /// Time grid t values
641    pub t: Array1<f64>,
642    /// Solution snapshots, `u[time_index]` is `[ny, nx]`
643    pub u: Vec<Array2<f64>>,
644    /// CFL analysis
645    pub cfl: Option<CFLAnalysis>,
646}
647
648// ---------------------------------------------------------------------------
649// Poisson Equation (elliptic) via iterative methods
650// ---------------------------------------------------------------------------
651
652/// Result from a Poisson equation solve
653#[derive(Debug, Clone)]
654pub struct PoissonResult {
655    /// Spatial grid x values
656    pub x: Array1<f64>,
657    /// Spatial grid y values
658    pub y: Array1<f64>,
659    /// Solution u[ny, nx]
660    pub u: Array2<f64>,
661    /// Number of iterations
662    pub iterations: usize,
663    /// Final residual norm
664    pub residual: f64,
665    /// Convergence history (residual per iteration)
666    pub convergence_history: Vec<f64>,
667}
668
669/// Solve Poisson equation d2u/dx2 + d2u/dy2 = f(x,y) with Dirichlet BCs
670///
671/// Uses the specified iterative method (Jacobi, Gauss-Seidel, or SOR).
672pub fn solve_poisson_2d(
673    source: &dyn Fn(f64, f64) -> f64,
674    x_range: [f64; 2],
675    y_range: [f64; 2],
676    nx: usize,
677    ny: usize,
678    bc_values: [f64; 4], // [left, right, bottom, top]
679    method: EllipticIterativeMethod,
680    tol: f64,
681    max_iter: usize,
682) -> PDEResult<PoissonResult> {
683    if nx < 3 || ny < 3 {
684        return Err(PDEError::InvalidGrid(
685            "Need at least 3 grid points in each dimension".to_string(),
686        ));
687    }
688
689    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
690    let dy = (y_range[1] - y_range[0]) / (ny as f64 - 1.0);
691
692    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
693    let y = Array1::from_shape_fn(ny, |j| y_range[0] + j as f64 * dy);
694
695    let mut u = Array2::zeros((ny, nx));
696    apply_dirichlet_2d(&mut u, bc_values, nx, ny);
697
698    let dx2 = dx * dx;
699    let dy2 = dy * dy;
700    let denom = 2.0 * (1.0 / dx2 + 1.0 / dy2);
701
702    let mut convergence_history = Vec::with_capacity(max_iter);
703    let mut iterations = 0;
704    let mut residual = f64::MAX;
705
706    for iter in 0..max_iter {
707        match method {
708            EllipticIterativeMethod::Jacobi => {
709                let u_old = u.clone();
710                for j in 1..ny - 1 {
711                    for i in 1..nx - 1 {
712                        u[[j, i]] = ((u_old[[j, i + 1]] + u_old[[j, i - 1]]) / dx2
713                            + (u_old[[j + 1, i]] + u_old[[j - 1, i]]) / dy2
714                            - source(x[i], y[j]))
715                            / denom;
716                    }
717                }
718            }
719            EllipticIterativeMethod::GaussSeidel => {
720                for j in 1..ny - 1 {
721                    for i in 1..nx - 1 {
722                        u[[j, i]] = ((u[[j, i + 1]] + u[[j, i - 1]]) / dx2
723                            + (u[[j + 1, i]] + u[[j - 1, i]]) / dy2
724                            - source(x[i], y[j]))
725                            / denom;
726                    }
727                }
728            }
729            EllipticIterativeMethod::SOR(omega) => {
730                for j in 1..ny - 1 {
731                    for i in 1..nx - 1 {
732                        let gs_val = ((u[[j, i + 1]] + u[[j, i - 1]]) / dx2
733                            + (u[[j + 1, i]] + u[[j - 1, i]]) / dy2
734                            - source(x[i], y[j]))
735                            / denom;
736                        u[[j, i]] = (1.0 - omega) * u[[j, i]] + omega * gs_val;
737                    }
738                }
739            }
740        }
741
742        // Compute residual: r = f - Laplacian(u)
743        let mut res_sum = 0.0;
744        for j in 1..ny - 1 {
745            for i in 1..nx - 1 {
746                let lap = (u[[j, i + 1]] - 2.0 * u[[j, i]] + u[[j, i - 1]]) / dx2
747                    + (u[[j + 1, i]] - 2.0 * u[[j, i]] + u[[j - 1, i]]) / dy2;
748                let r = source(x[i], y[j]) - lap;
749                res_sum += r * r;
750            }
751        }
752        residual = (res_sum / ((nx - 2) * (ny - 2)) as f64).sqrt();
753        convergence_history.push(residual);
754        iterations = iter + 1;
755
756        if residual < tol {
757            break;
758        }
759    }
760
761    Ok(PoissonResult {
762        x,
763        y,
764        u,
765        iterations,
766        residual,
767        convergence_history,
768    })
769}
770
771// ---------------------------------------------------------------------------
772// Helper functions
773// ---------------------------------------------------------------------------
774
775/// Apply 1D boundary conditions at a given time step
776fn apply_bc_1d(
777    u: &mut Array2<f64>,
778    time_idx: usize,
779    left_bc: &FDBoundaryCondition,
780    right_bc: &FDBoundaryCondition,
781    dx: f64,
782) {
783    let nx = u.shape()[1];
784    match left_bc {
785        FDBoundaryCondition::Dirichlet(val) => {
786            u[[time_idx, 0]] = *val;
787        }
788        FDBoundaryCondition::Neumann(val) => {
789            // du/dx = val at left boundary => u[0] = u[1] - dx*val
790            u[[time_idx, 0]] = u[[time_idx, 1]] - dx * val;
791        }
792        FDBoundaryCondition::Periodic => {
793            // Handled in the main loop
794        }
795    }
796    match right_bc {
797        FDBoundaryCondition::Dirichlet(val) => {
798            u[[time_idx, nx - 1]] = *val;
799        }
800        FDBoundaryCondition::Neumann(val) => {
801            // du/dx = val at right boundary => u[nx-1] = u[nx-2] + dx*val
802            u[[time_idx, nx - 1]] = u[[time_idx, nx - 2]] + dx * val;
803        }
804        FDBoundaryCondition::Periodic => {
805            // Handled in the main loop
806        }
807    }
808}
809
810/// Apply Dirichlet BCs on a 2D array: [left, right, bottom, top]
811fn apply_dirichlet_2d(u: &mut Array2<f64>, bc: [f64; 4], nx: usize, ny: usize) {
812    for j in 0..ny {
813        u[[j, 0]] = bc[0]; // left
814        u[[j, nx - 1]] = bc[1]; // right
815    }
816    for i in 0..nx {
817        u[[0, i]] = bc[2]; // bottom
818        u[[ny - 1, i]] = bc[3]; // top
819    }
820}
821
822/// Adjust RHS for Crank-Nicolson boundary conditions
823#[allow(clippy::too_many_arguments)]
824fn apply_cn_bc_adjustment(
825    rhs: &mut Array1<f64>,
826    left_bc: &FDBoundaryCondition,
827    right_bc: &FDBoundaryCondition,
828    r: f64,
829    u: &Array2<f64>,
830    _time_idx: usize,
831    nx: usize,
832    dx: f64,
833) {
834    let interior_size = rhs.len();
835    if interior_size == 0 {
836        return;
837    }
838    // Left BC contribution to first interior point
839    match left_bc {
840        FDBoundaryCondition::Dirichlet(val) => {
841            rhs[0] += r * val;
842        }
843        FDBoundaryCondition::Neumann(val) => {
844            // Ghost: u[0] = u[1] - dx*val, so contribution is r*(u[1]-dx*val)
845            // The u[1] part is absorbed into the matrix diagonal modification
846            rhs[0] -= r * dx * val;
847        }
848        FDBoundaryCondition::Periodic => {}
849    }
850    // Right BC contribution to last interior point
851    match right_bc {
852        FDBoundaryCondition::Dirichlet(val) => {
853            rhs[interior_size - 1] += r * val;
854        }
855        FDBoundaryCondition::Neumann(val) => {
856            rhs[interior_size - 1] += r * dx * val;
857        }
858        FDBoundaryCondition::Periodic => {}
859    }
860    let _ = u; // used for potential future Neumann ghost adjustments
861}
862
863/// Solve a tridiagonal system with constant bands:
864/// sub * x[i-1] + diag * x[i] + sup * x[i+1] = rhs[i]
865fn solve_tridiag(
866    n: usize,
867    sub: f64,
868    diag: f64,
869    sup: f64,
870    rhs: &Array1<f64>,
871) -> PDEResult<Array1<f64>> {
872    if n == 0 {
873        return Ok(Array1::zeros(0));
874    }
875    let mut c_prime = vec![0.0; n];
876    let mut d_prime = vec![0.0; n];
877
878    // Forward sweep
879    c_prime[0] = sup / diag;
880    d_prime[0] = rhs[0] / diag;
881    for i in 1..n {
882        let m = diag - sub * c_prime[i - 1];
883        if m.abs() < 1e-15 {
884            return Err(PDEError::ComputationError(
885                "Zero pivot in tridiagonal solve".to_string(),
886            ));
887        }
888        c_prime[i] = if i < n - 1 { sup / m } else { 0.0 };
889        d_prime[i] = (rhs[i] - sub * d_prime[i - 1]) / m;
890    }
891
892    // Back substitution
893    let mut x = Array1::zeros(n);
894    x[n - 1] = d_prime[n - 1];
895    for i in (0..n - 1).rev() {
896        x[i] = d_prime[i] - c_prime[i] * x[i + 1];
897    }
898    Ok(x)
899}
900
901/// Solve a periodic tridiagonal system using the Sherman-Morrison formula
902fn solve_periodic_tridiag(
903    n: usize,
904    sub: f64,
905    diag: f64,
906    sup: f64,
907    rhs: &Array1<f64>,
908) -> PDEResult<Array1<f64>> {
909    if n < 3 {
910        return Err(PDEError::ComputationError(
911            "Periodic tridiagonal system needs at least 3 unknowns".to_string(),
912        ));
913    }
914
915    // Sherman-Morrison trick: perturb first and last diagonal entries
916    let gamma = -diag;
917    let d_mod = diag - gamma; // first diagonal becomes diag + gamma effectively
918    let d_last = diag - sub * sup / gamma; // last diagonal modified
919
920    // Build modified RHS for standard tridiagonal solve
921    let mut rhs_mod = rhs.clone();
922    // Create vector u_sm = [gamma, 0, ..., 0, sup]
923    // Create vector v_sm = [1, 0, ..., 0, sub/gamma]
924
925    // Solve A_mod * y = rhs_mod
926    // Solve A_mod * z = u_sm
927    // where A_mod is the tridiagonal with modified corners
928
929    // For simplicity, assemble the modified system as arrays and solve twice
930    let mut diag_arr = vec![diag; n];
931    diag_arr[0] = d_mod;
932    diag_arr[n - 1] = d_last;
933
934    let mut sub_arr = vec![sub; n];
935    sub_arr[0] = 0.0; // not used
936    let mut sup_arr = vec![sup; n];
937    sup_arr[n - 1] = 0.0; // not used
938
939    // Solve with general tridiagonal
940    let y = solve_general_tridiag(&sub_arr, &diag_arr, &sup_arr, &rhs_mod)?;
941
942    // u_sm vector
943    let mut u_sm = Array1::zeros(n);
944    u_sm[0] = gamma;
945    u_sm[n - 1] = sup;
946    let z = solve_general_tridiag(&sub_arr, &diag_arr, &sup_arr, &u_sm)?;
947
948    // v_sm = [1, 0, ..., 0, sub/gamma]
949    let v0 = 1.0;
950    let vn = sub / gamma;
951
952    let numer = v0 * y[0] + vn * y[n - 1];
953    let denom_val = 1.0 + v0 * z[0] + vn * z[n - 1];
954
955    if denom_val.abs() < 1e-15 {
956        return Err(PDEError::ComputationError(
957            "Singular periodic tridiagonal system".to_string(),
958        ));
959    }
960
961    let factor = numer / denom_val;
962    let mut x = Array1::zeros(n);
963    for i in 0..n {
964        x[i] = y[i] - factor * z[i];
965    }
966
967    Ok(x)
968}
969
970/// General tridiagonal solver (varying bands)
971fn solve_general_tridiag(
972    sub: &[f64],
973    diag: &[f64],
974    sup: &[f64],
975    rhs: &Array1<f64>,
976) -> PDEResult<Array1<f64>> {
977    let n = rhs.len();
978    if n == 0 {
979        return Ok(Array1::zeros(0));
980    }
981
982    let mut c_prime = vec![0.0; n];
983    let mut d_prime = vec![0.0; n];
984
985    if diag[0].abs() < 1e-15 {
986        return Err(PDEError::ComputationError(
987            "Zero pivot in general tridiagonal solve".to_string(),
988        ));
989    }
990    c_prime[0] = sup[0] / diag[0];
991    d_prime[0] = rhs[0] / diag[0];
992
993    for i in 1..n {
994        let m = diag[i] - sub[i] * c_prime[i - 1];
995        if m.abs() < 1e-15 {
996            return Err(PDEError::ComputationError(
997                "Zero pivot in general tridiagonal solve".to_string(),
998            ));
999        }
1000        c_prime[i] = if i < n - 1 { sup[i] / m } else { 0.0 };
1001        d_prime[i] = (rhs[i] - sub[i] * d_prime[i - 1]) / m;
1002    }
1003
1004    let mut x = Array1::zeros(n);
1005    x[n - 1] = d_prime[n - 1];
1006    for i in (0..n - 1).rev() {
1007        x[i] = d_prime[i] - c_prime[i] * x[i + 1];
1008    }
1009    Ok(x)
1010}
1011
1012// ---------------------------------------------------------------------------
1013// Tests
1014// ---------------------------------------------------------------------------
1015
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    #[test]
    fn test_cfl_heat_1d_stable() {
        let cfl = cfl_heat_1d(0.01, 0.1, 0.1);
        // r = 0.01 * 0.1 / 0.01 = 0.1 <= 0.5 => stable
        assert!(cfl.is_stable);
        assert!(cfl.cfl_number < 0.5 + 1e-10);
    }

    #[test]
    fn test_cfl_heat_1d_unstable() {
        let cfl = cfl_heat_1d(1.0, 0.01, 0.01);
        // r = 1.0 * 0.01 / 0.0001 = 100 >> 0.5 => unstable
        assert!(!cfl.is_stable);
    }

    #[test]
    fn test_cfl_wave_1d_stable() {
        let cfl = cfl_wave_1d(1.0, 0.1, 0.05);
        // CFL = 1.0 * 0.05 / 0.1 = 0.5 <= 1.0
        assert!(cfl.is_stable);
    }

    #[test]
    fn test_heat_1d_explicit_constant_ic() {
        // u(x,0) = 1.0 with Dirichlet u(0)=1, u(1)=1
        // Steady-state is u=1 everywhere
        let result = solve_heat_1d(
            0.01,
            [0.0, 1.0],
            [0.0, 0.1],
            21,
            100,
            &|_x| 1.0,
            &FDBoundaryCondition::Dirichlet(1.0),
            &FDBoundaryCondition::Dirichlet(1.0),
            TimeSteppingMethod::Explicit,
        );
        let res = result.expect("Should succeed");
        // All values should remain 1.0
        let last = res.u.row(res.u.shape()[0] - 1);
        for &v in last.iter() {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_heat_1d_explicit_decay() {
        // u(x,0) = sin(pi*x) with Dirichlet u(0)=0, u(1)=0
        // Exact: u(x,t) = sin(pi*x) * exp(-pi^2 * alpha * t)
        let alpha = 0.01;
        let nx = 51;
        let nt = 5000;
        let result = solve_heat_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 1.0],
            nx,
            nt,
            &|x| (PI * x).sin(),
            &FDBoundaryCondition::Dirichlet(0.0),
            &FDBoundaryCondition::Dirichlet(0.0),
            TimeSteppingMethod::Explicit,
        );
        let res = result.expect("Should succeed");
        let last = res.u.row(res.u.shape()[0] - 1);
        // Check midpoint: exact ~ sin(pi*0.5) * exp(-pi^2*0.01*1.0) ~ exp(-0.0987..) ~ 0.906
        let mid = nx / 2;
        let exact = (PI * 0.5).sin() * (-PI * PI * alpha * 1.0).exp();
        assert!(
            (last[mid] - exact).abs() < 0.02,
            "Got {}, expected {} (tol=0.02)",
            last[mid],
            exact
        );
    }

    #[test]
    fn test_heat_1d_crank_nicolson() {
        let alpha = 0.1;
        let nx = 21;
        let nt = 50;
        let result = solve_heat_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 1.0],
            nx,
            nt,
            &|x| (PI * x).sin(),
            &FDBoundaryCondition::Dirichlet(0.0),
            &FDBoundaryCondition::Dirichlet(0.0),
            TimeSteppingMethod::CrankNicolson,
        );
        let res = result.expect("Should succeed");
        let last = res.u.row(res.u.shape()[0] - 1);
        let mid = nx / 2;
        let exact = (PI * 0.5).sin() * (-PI * PI * alpha * 1.0).exp();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "CN got {}, expected {} (tol=0.05)",
            last[mid],
            exact
        );
    }

    #[test]
    fn test_heat_1d_neumann() {
        // Insulated boundaries: du/dx=0 at both ends
        // u(x,0) = 1.0, should remain 1.0
        let result = solve_heat_1d(
            0.01,
            [0.0, 1.0],
            [0.0, 0.5],
            21,
            200,
            &|_| 1.0,
            &FDBoundaryCondition::Neumann(0.0),
            &FDBoundaryCondition::Neumann(0.0),
            TimeSteppingMethod::Explicit,
        );
        let res = result.expect("Should succeed");
        let last = res.u.row(res.u.shape()[0] - 1);
        for &v in last.iter() {
            assert!(
                (v - 1.0).abs() < 0.01,
                "Neumann with constant IC should stay ~1.0, got {v}"
            );
        }
    }

    #[test]
    fn test_heat_1d_periodic() {
        // Periodic heat equation: u(x,0) = sin(2*pi*x)
        let alpha = 0.01;
        let nx = 41;
        let nt = 500;
        let result = solve_heat_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 0.5],
            nx,
            nt,
            &|x| (2.0 * PI * x).sin(),
            &FDBoundaryCondition::Periodic,
            &FDBoundaryCondition::Periodic,
            TimeSteppingMethod::Explicit,
        );
        let res = result.expect("Should succeed");
        let last = res.u.row(res.u.shape()[0] - 1);
        // Exact: exp(-4*pi^2*alpha*t)*sin(2*pi*x)
        // At t=0.5: decay factor = exp(-4*pi^2*0.01*0.5) ~ exp(-0.197) ~ 0.821
        let decay = (-4.0 * PI * PI * alpha * 0.5).exp();
        let mid = nx / 4; // x=0.25 => sin(pi/2)=1.0
        let exact = decay * (2.0 * PI * 0.25).sin();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "Periodic got {}, expected {exact} (tol=0.05)",
            last[mid]
        );
    }

    #[test]
    fn test_heat_2d_constant() {
        // Constant IC and matching BCs: should stay constant
        let result = solve_heat_2d(
            0.01,
            [0.0, 1.0],
            [0.0, 1.0],
            [0.0, 0.1],
            11,
            11,
            50,
            &|_, _| 1.0,
            [1.0, 1.0, 1.0, 1.0],
            50,
        );
        let res = result.expect("Should succeed");
        let last = &res.u[res.u.len() - 1];
        for j in 0..11 {
            for i in 0..11 {
                assert!(
                    (last[[j, i]] - 1.0).abs() < 1e-10,
                    "2D heat constant: [{j},{i}] = {}",
                    last[[j, i]]
                );
            }
        }
    }

    #[test]
    fn test_wave_1d_standing() {
        // Standing wave: u(x,0) = sin(pi*x), v(x,0) = 0
        // Exact: u(x,t) = sin(pi*x) * cos(pi*c*t)
        let c = 1.0;
        let nx = 101;
        let nt = 200;
        let result = solve_wave_1d(
            c,
            [0.0, 1.0],
            [0.0, 0.5],
            nx,
            nt,
            &|x| (PI * x).sin(),
            &|_x| 0.0,
            &FDBoundaryCondition::Dirichlet(0.0),
            &FDBoundaryCondition::Dirichlet(0.0),
        );
        let res = result.expect("Should succeed");
        let last = res.u.row(res.u.shape()[0] - 1);
        let mid = nx / 2;
        let exact = (PI * 0.5).sin() * (PI * c * 0.5).cos();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "Wave got {}, expected {exact}",
            last[mid]
        );
    }

    #[test]
    fn test_wave_1d_periodic() {
        let c = 1.0;
        let nx = 101;
        let nt = 100;
        let result = solve_wave_1d(
            c,
            [0.0, 1.0],
            [0.0, 0.5],
            nx,
            nt,
            &|x| (2.0 * PI * x).sin(),
            &|_x| 0.0,
            &FDBoundaryCondition::Periodic,
            &FDBoundaryCondition::Periodic,
        );
        assert!(result.is_ok(), "Periodic wave should succeed");
    }

    #[test]
    fn test_wave_2d_basic() {
        let result = solve_wave_2d(
            1.0,
            [0.0, 1.0],
            [0.0, 1.0],
            [0.0, 0.1],
            21,
            21,
            50,
            &|x, y| (PI * x).sin() * (PI * y).sin(),
            &|_, _| 0.0,
            0.0,
            50,
        );
        assert!(result.is_ok(), "2D wave should succeed");
    }

    #[test]
    fn test_poisson_zero_source() {
        // Laplace equation with constant Dirichlet BCs => u = constant
        let result = solve_poisson_2d(
            &|_, _| 0.0,
            [0.0, 1.0],
            [0.0, 1.0],
            11,
            11,
            [1.0, 1.0, 1.0, 1.0],
            EllipticIterativeMethod::GaussSeidel,
            1e-8,
            5000,
        );
        let res = result.expect("Should succeed");
        for j in 0..11 {
            for i in 0..11 {
                assert!(
                    (res.u[[j, i]] - 1.0).abs() < 1e-4,
                    "Laplace [{j},{i}] = {} (expected 1.0)",
                    res.u[[j, i]]
                );
            }
        }
    }

    #[test]
    fn test_poisson_jacobi() {
        let result = solve_poisson_2d(
            &|_, _| -2.0,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            21,
            [0.0, 0.0, 0.0, 0.0],
            EllipticIterativeMethod::Jacobi,
            1e-6,
            10000,
        );
        let res = result.expect("Should succeed");
        // With f=-2 and zero BCs, the solution is a parabolic bowl
        // At center (0.5, 0.5): approximate value
        let mid = 10;
        assert!(
            res.u[[mid, mid]] > 0.0,
            "Center should be positive for negative source"
        );
    }

    #[test]
    fn test_poisson_sor() {
        // SOR with omega=1.5 should converge faster than Gauss-Seidel
        let result_gs = solve_poisson_2d(
            &|_, _| -2.0,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            21,
            [0.0, 0.0, 0.0, 0.0],
            EllipticIterativeMethod::GaussSeidel,
            1e-6,
            10000,
        )
        .expect("GS should succeed");

        let result_sor = solve_poisson_2d(
            &|_, _| -2.0,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            21,
            [0.0, 0.0, 0.0, 0.0],
            EllipticIterativeMethod::SOR(1.5),
            1e-6,
            10000,
        )
        .expect("SOR should succeed");

        // SOR should converge in fewer iterations
        assert!(
            result_sor.iterations <= result_gs.iterations,
            "SOR ({}) should converge <= GS ({})",
            result_sor.iterations,
            result_gs.iterations
        );
    }

    #[test]
    fn test_heat_explicit_unstable_rejected() {
        // Very large dt should be rejected by CFL check
        let result = solve_heat_1d(
            1.0,
            [0.0, 1.0],
            [0.0, 1.0],
            11,
            2,
            &|_| 0.0,
            &FDBoundaryCondition::Dirichlet(0.0),
            &FDBoundaryCondition::Dirichlet(0.0),
            TimeSteppingMethod::Explicit,
        );
        assert!(result.is_err(), "Should reject unstable explicit scheme");
    }

    #[test]
    fn test_cfl_heat_2d() {
        let cfl = cfl_heat_2d(0.01, 0.1, 0.1, 0.1);
        // r = 0.01*0.1*(100+100) = 0.2 <= 0.5
        assert!(cfl.is_stable);
    }

    #[test]
    fn test_cfl_wave_2d() {
        let cfl = cfl_wave_2d(1.0, 0.1, 0.1, 0.05);
        // CFL = 1.0 * 0.05 * sqrt(200) ~ 0.05*14.14 ~ 0.707 <= 1.0
        assert!(cfl.is_stable);
    }

    // -----------------------------------------------------------------------
    // Direct tests of the tridiagonal solvers
    // -----------------------------------------------------------------------

    /// Multiply `x` by the constant-band tridiagonal matrix (sub, diag, sup).
    /// When `periodic` is true, row 0 also couples to `x[n-1]` through `sub`
    /// and row `n-1` couples to `x[0]` through `sup`.
    fn tridiag_matvec(sub: f64, diag: f64, sup: f64, x: &[f64], periodic: bool) -> Vec<f64> {
        let n = x.len();
        (0..n)
            .map(|i| {
                let mut acc = diag * x[i];
                if i > 0 {
                    acc += sub * x[i - 1];
                } else if periodic {
                    acc += sub * x[n - 1];
                }
                if i + 1 < n {
                    acc += sup * x[i + 1];
                } else if periodic {
                    acc += sup * x[0];
                }
                acc
            })
            .collect()
    }

    #[test]
    fn test_solve_tridiag_known_solution() {
        // [2 -1 0; -1 2 -1; 0 -1 2] * [1, 1, 1] = [1, 0, 1]
        let rhs = Array1::from(vec![1.0, 0.0, 1.0]);
        let x = solve_tridiag(3, -1.0, 2.0, -1.0, &rhs).expect("Should succeed");
        assert_eq!(x.len(), 3);
        for &v in x.iter() {
            assert!((v - 1.0).abs() < 1e-12, "got {v}, expected 1.0");
        }
    }

    #[test]
    fn test_solve_tridiag_roundtrip() {
        let (sub, diag, sup) = (1.0, 4.0, -1.5);
        let expected = [0.5, -2.0, 3.0, 1.0, -0.25];
        let rhs = Array1::from(tridiag_matvec(sub, diag, sup, &expected, false));
        let x = solve_tridiag(expected.len(), sub, diag, sup, &rhs).expect("Should succeed");
        for (got, want) in x.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_solve_tridiag_empty() {
        let rhs: Array1<f64> = Array1::zeros(0);
        let x = solve_tridiag(0, 1.0, 2.0, 1.0, &rhs).expect("Should succeed");
        assert_eq!(x.len(), 0);
    }

    #[test]
    fn test_solve_periodic_tridiag_roundtrip() {
        // Asymmetric bands (sub != sup) pin down the orientation of the corner terms.
        let (sub, diag, sup) = (1.0, 5.0, 2.0);
        let expected = [1.0, 2.0, 3.0, 4.0, 5.0];
        let rhs = Array1::from(tridiag_matvec(sub, diag, sup, &expected, true));
        let x = solve_periodic_tridiag(expected.len(), sub, diag, sup, &rhs)
            .expect("Should succeed");
        for (got, want) in x.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_solve_periodic_tridiag_too_small() {
        let rhs = Array1::from(vec![1.0, 1.0]);
        assert!(
            solve_periodic_tridiag(2, 1.0, 4.0, 1.0, &rhs).is_err(),
            "Periodic solve with fewer than 3 unknowns should be rejected"
        );
    }

    #[test]
    fn test_solve_general_tridiag_known_solution() {
        let sub = [0.0, 1.0, -2.0, 0.5];
        let diag = [4.0, 5.0, 6.0, 3.0];
        let sup = [1.0, 2.0, -1.0, 0.0];
        // rhs = A * [1, -1, 2, 0.5], computed by hand
        let rhs = Array1::from(vec![3.0, 0.0, 13.5, 2.5]);
        let x = solve_general_tridiag(&sub, &diag, &sup, &rhs).expect("Should succeed");
        let expected = [1.0, -1.0, 2.0, 0.5];
        for (got, want) in x.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_solve_general_tridiag_zero_pivot() {
        let rhs = Array1::from(vec![1.0, 1.0]);
        let result = solve_general_tridiag(&[0.0, 1.0], &[0.0, 1.0], &[1.0, 0.0], &rhs);
        assert!(result.is_err(), "Zero leading pivot should be reported");
    }
}
1388        let cfl = cfl_wave_2d(1.0, 0.1, 0.1, 0.05);
1389        // CFL = 1.0 * 0.05 * sqrt(200) ~ 0.05*14.14 ~ 0.707 <= 1.0
1390        assert!(cfl.is_stable);
1391    }
1392}