Skip to main content

scirs2_integrate/pde/
mol_enhanced.rs

1//! Enhanced Method of Lines (MOL) PDE Solver
2//!
3//! Semi-discretizes PDEs in space, converting them to ODE systems that are
4//! then integrated with the existing ODE solvers (RK45, BDF, etc.).
5//!
6//! ## Features
7//! - Configurable spatial stencils (2nd and 4th order)
8//! - Integration with existing ODE solvers (RK45, BDF)
9//! - Advection equation solver (upwind, Lax-Wendroff)
10//! - Reaction-diffusion system solver
11//! - Configurable boundary conditions (Dirichlet, Neumann, periodic)
12
13use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
14use std::sync::Arc;
15
16use crate::ode::{solve_ivp, ODEMethod, ODEOptions};
17use crate::pde::{PDEError, PDEResult};
18
19// ---------------------------------------------------------------------------
20// Types
21// ---------------------------------------------------------------------------
22
/// Accuracy order of the finite-difference stencil used for the spatial
/// second derivative.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum StencilOrder {
    /// Three-point central difference (second-order accurate).
    Second,
    /// Five-point central difference (fourth-order accurate) in the interior;
    /// the nodes next to the boundaries use the three-point formula.
    Fourth,
}
31
/// Boundary condition applied at one end of the 1D domain by the MOL solvers.
#[derive(Clone, Debug)]
pub enum MOLBoundaryCondition {
    /// Prescribed value: `u(boundary) = value`. The boundary node is held
    /// fixed during time integration.
    Dirichlet(f64),
    /// Prescribed derivative at the boundary.
    ///
    /// NOTE(review): the diffusion helpers in this file impose the value as
    /// `du/dx` at both ends (no outward-normal sign flip on the left end),
    /// while the advection helpers ignore the value and keep the boundary
    /// node fixed.
    Neumann(f64),
    /// Periodic domain: the solution wraps around from one end to the other.
    Periodic,
}
42
/// Choice of ODE method used to advance the semi-discretized system in time.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MOLTimeIntegrator {
    /// Dormand-Prince Runge-Kutta 4(5); suited to non-stiff problems.
    RK45,
    /// Backward differentiation formula; suited to stiff problems.
    BDF,
    /// Bogacki-Shampine Runge-Kutta 2(3).
    RK23,
}
53
54impl MOLTimeIntegrator {
55    fn to_ode_method(self) -> ODEMethod {
56        match self {
57            MOLTimeIntegrator::RK45 => ODEMethod::RK45,
58            MOLTimeIntegrator::BDF => ODEMethod::Bdf,
59            MOLTimeIntegrator::RK23 => ODEMethod::RK23,
60        }
61    }
62}
63
64/// Options for the enhanced MOL solver
65#[derive(Debug, Clone)]
66pub struct MOLEnhancedOptions {
67    /// Time integrator to use
68    pub integrator: MOLTimeIntegrator,
69    /// Spatial stencil order
70    pub stencil: StencilOrder,
71    /// Absolute tolerance for ODE solver
72    pub atol: f64,
73    /// Relative tolerance for ODE solver
74    pub rtol: f64,
75    /// Maximum ODE steps
76    pub max_steps: usize,
77}
78
79impl Default for MOLEnhancedOptions {
80    fn default() -> Self {
81        MOLEnhancedOptions {
82            integrator: MOLTimeIntegrator::RK45,
83            stencil: StencilOrder::Second,
84            atol: 1e-6,
85            rtol: 1e-3,
86            max_steps: 10000,
87        }
88    }
89}
90
91/// Result from MOL solve
92#[derive(Debug, Clone)]
93pub struct MOLEnhancedResult {
94    /// Spatial grid
95    pub x: Array1<f64>,
96    /// Time points
97    pub t: Vec<f64>,
98    /// Solution u[time_step, spatial_index]
99    pub u: Vec<Array1<f64>>,
100    /// Number of ODE function evaluations
101    pub n_eval: usize,
102    /// Number of ODE steps taken
103    pub n_steps: usize,
104}
105
106// ---------------------------------------------------------------------------
107// Diffusion equation solver
108// ---------------------------------------------------------------------------
109
110/// Solve 1D diffusion (heat) equation: du/dt = alpha * d2u/dx2 + source(x,t,u)
111///
112/// Semi-discretizes in space using configurable stencils, then solves the
113/// resulting ODE system with the chosen time integrator.
114pub fn mol_diffusion_1d(
115    alpha: f64,
116    x_range: [f64; 2],
117    t_range: [f64; 2],
118    nx: usize,
119    left_bc: MOLBoundaryCondition,
120    right_bc: MOLBoundaryCondition,
121    initial_condition: impl Fn(f64) -> f64 + Send + Sync + 'static,
122    source: Option<Arc<dyn Fn(f64, f64, f64) -> f64 + Send + Sync>>,
123    options: &MOLEnhancedOptions,
124) -> PDEResult<MOLEnhancedResult> {
125    if alpha <= 0.0 {
126        return Err(PDEError::InvalidParameter(
127            "Diffusion coefficient alpha must be positive".to_string(),
128        ));
129    }
130    if nx < 5 {
131        return Err(PDEError::InvalidGrid(
132            "Need at least 5 spatial points for MOL".to_string(),
133        ));
134    }
135
136    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
137    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
138
139    // Initial condition
140    let mut u0 = Array1::from_shape_fn(nx, |i| initial_condition(x[i]));
141    apply_mol_bc(&mut u0, &left_bc, &right_bc, dx);
142
143    let stencil = options.stencil;
144    let x_clone = x.clone();
145    let left_bc_c = left_bc.clone();
146    let right_bc_c = right_bc.clone();
147
148    // Build ODE RHS
149    let rhs = move |t: f64, u: ArrayView1<f64>| -> Array1<f64> {
150        let n = u.len();
151        let mut dudt = Array1::zeros(n);
152
153        // Diffusion operator
154        match stencil {
155            StencilOrder::Second => {
156                let r = alpha / (dx * dx);
157                for i in 1..n - 1 {
158                    dudt[i] = r * (u[i + 1] - 2.0 * u[i] + u[i - 1]);
159                }
160            }
161            StencilOrder::Fourth => {
162                let r = alpha / (12.0 * dx * dx);
163                for i in 2..n - 2 {
164                    dudt[i] = r
165                        * (-u[i + 2] + 16.0 * u[i + 1] - 30.0 * u[i] + 16.0 * u[i - 1] - u[i - 2]);
166                }
167                // Fallback to 2nd order near boundaries
168                let r2 = alpha / (dx * dx);
169                if n > 2 {
170                    dudt[1] = r2 * (u[2] - 2.0 * u[1] + u[0]);
171                }
172                if n > 3 {
173                    dudt[n - 2] = r2 * (u[n - 1] - 2.0 * u[n - 2] + u[n - 3]);
174                }
175            }
176        }
177
178        // Source term
179        if let Some(ref src) = source {
180            for i in 1..n - 1 {
181                dudt[i] += src(x_clone[i], t, u[i]);
182            }
183        }
184
185        // Boundary treatment
186        apply_mol_bc_rhs(
187            &mut dudt,
188            &u,
189            &left_bc_c,
190            &right_bc_c,
191            n,
192            dx,
193            alpha,
194            &x_clone,
195            t,
196            &source,
197        );
198
199        dudt
200    };
201
202    run_mol_ode(x, u0, t_range, rhs, options)
203}
204
205// ---------------------------------------------------------------------------
206// Advection equation solver
207// ---------------------------------------------------------------------------
208
/// Spatial discretization used for the advective term.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AdvectionScheme {
    /// First-order upwind difference (stable, diffusive).
    Upwind,
    /// Centered difference plus an artificial diffusion term of size
    /// `|velocity| * dx / 2`.
    ///
    /// NOTE(review): as implemented in `mol_advection_1d` this combination is
    /// algebraically the same stencil as `Upwind`; a true Lax-Wendroff
    /// correction needs the time step, which the MOL right-hand side does
    /// not know.
    LaxWendroff,
    /// Second-order centered difference (no inherent stability).
    Central,
}
219
220/// Solve 1D advection equation: du/dt + velocity * du/dx = source(x,t)
221///
222/// Spatial discretization depends on the chosen scheme.
223pub fn mol_advection_1d(
224    velocity: f64,
225    x_range: [f64; 2],
226    t_range: [f64; 2],
227    nx: usize,
228    left_bc: MOLBoundaryCondition,
229    right_bc: MOLBoundaryCondition,
230    initial_condition: impl Fn(f64) -> f64 + Send + Sync + 'static,
231    source: Option<Arc<dyn Fn(f64, f64) -> f64 + Send + Sync>>,
232    scheme: AdvectionScheme,
233    options: &MOLEnhancedOptions,
234) -> PDEResult<MOLEnhancedResult> {
235    if nx < 5 {
236        return Err(PDEError::InvalidGrid(
237            "Need at least 5 spatial points for advection".to_string(),
238        ));
239    }
240
241    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
242    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
243
244    let mut u0 = Array1::from_shape_fn(nx, |i| initial_condition(x[i]));
245    apply_mol_bc(&mut u0, &left_bc, &right_bc, dx);
246
247    let x_clone = x.clone();
248    let left_bc_c = left_bc.clone();
249    let right_bc_c = right_bc.clone();
250
251    let rhs = move |t: f64, u: ArrayView1<f64>| -> Array1<f64> {
252        let n = u.len();
253        let mut dudt = Array1::zeros(n);
254
255        match scheme {
256            AdvectionScheme::Upwind => {
257                if velocity >= 0.0 {
258                    // Backward difference
259                    for i in 1..n - 1 {
260                        dudt[i] = -velocity * (u[i] - u[i - 1]) / dx;
261                    }
262                } else {
263                    // Forward difference
264                    for i in 1..n - 1 {
265                        dudt[i] = -velocity * (u[i + 1] - u[i]) / dx;
266                    }
267                }
268            }
269            AdvectionScheme::LaxWendroff => {
270                // Lax-Wendroff: du/dt = -v * du/dx + 0.5*v^2*dt * d2u/dx2
271                // In MOL context, we just use the centered + diffusive correction stencil:
272                // du/dt = -v/(2dx)*(u[i+1]-u[i-1])
273                // The Lax-Wendroff correction is embedded in the time stepping
274                // For pure MOL, use centered advection (the ODE solver handles stability)
275                for i in 1..n - 1 {
276                    // Centered advection for MOL
277                    let advection = -velocity * (u[i + 1] - u[i - 1]) / (2.0 * dx);
278                    // Add numerical diffusion to stabilize: v^2/(2) * d2u/dx2 * approx_dt
279                    // In pure MOL, we skip this and let the ODE solver handle it
280                    // but add a small artificial diffusion proportional to dx
281                    let diffusion =
282                        velocity.abs() * dx / 2.0 * (u[i + 1] - 2.0 * u[i] + u[i - 1]) / (dx * dx);
283                    dudt[i] = advection + diffusion;
284                }
285            }
286            AdvectionScheme::Central => {
287                for i in 1..n - 1 {
288                    dudt[i] = -velocity * (u[i + 1] - u[i - 1]) / (2.0 * dx);
289                }
290            }
291        }
292
293        // Source term
294        if let Some(ref src) = source {
295            for i in 1..n - 1 {
296                dudt[i] += src(x_clone[i], t);
297            }
298        }
299
300        // Boundary
301        apply_advection_bc_rhs(&mut dudt, &u, &left_bc_c, &right_bc_c, n, dx, velocity);
302
303        dudt
304    };
305
306    run_mol_ode(x, u0, t_range, rhs, options)
307}
308
309// ---------------------------------------------------------------------------
310// Reaction-diffusion system solver
311// ---------------------------------------------------------------------------
312
/// Description of a coupled reaction-diffusion system
///
///   du_i/dt = D_i * d2u_i/dx2 + R_i(x, t, u_1, ..., u_m),   i = 1..m
///
/// with one diffusion coefficient `D_i` per species and a single reaction
/// callback that couples all species at a given point.
pub struct ReactionDiffusionSystem {
    /// Number of species `m`.
    pub n_species: usize,
    /// Diffusion coefficient `D_i` of each species, indexed by species.
    pub diffusion_coeffs: Vec<f64>,
    /// Reaction callback `(x, t, &[u_species]) -> Vec<f64>` returning one
    /// rate per species, in species order.
    pub reaction: Arc<dyn Fn(f64, f64, &[f64]) -> Vec<f64> + Send + Sync>,
}
326
327/// Result from reaction-diffusion solve
328#[derive(Debug, Clone)]
329pub struct ReactionDiffusionResult {
330    /// Spatial grid
331    pub x: Array1<f64>,
332    /// Time points
333    pub t: Vec<f64>,
334    /// Solution: `u[time_step]` is a 2D array `[n_species, nx]`
335    pub u: Vec<Array2<f64>>,
336    /// Number of ODE evaluations
337    pub n_eval: usize,
338    /// Number of ODE steps
339    pub n_steps: usize,
340}
341
342/// Solve a reaction-diffusion system on [x_left, x_right] with Dirichlet BCs.
343///
344/// Each species has its own diffusion coefficient and they are coupled
345/// through the reaction function.
346pub fn mol_reaction_diffusion(
347    system: &ReactionDiffusionSystem,
348    x_range: [f64; 2],
349    t_range: [f64; 2],
350    nx: usize,
351    initial_conditions: &[impl Fn(f64) -> f64],
352    bc_left: &[f64],  // Dirichlet values at left for each species
353    bc_right: &[f64], // Dirichlet values at right for each species
354    options: &MOLEnhancedOptions,
355) -> PDEResult<ReactionDiffusionResult> {
356    let m = system.n_species;
357    if initial_conditions.len() != m || bc_left.len() != m || bc_right.len() != m {
358        return Err(PDEError::InvalidParameter(format!(
359            "Expected {} initial conditions/BCs for {} species",
360            m, m
361        )));
362    }
363    if nx < 5 {
364        return Err(PDEError::InvalidGrid(
365            "Need at least 5 spatial points".to_string(),
366        ));
367    }
368
369    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
370    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
371    let total_dof = m * nx;
372
373    // Pack initial conditions into a single vector: [species0_node0, .., species0_nodeN, species1_node0, ..]
374    let mut u0 = Array1::zeros(total_dof);
375    for s in 0..m {
376        for i in 0..nx {
377            u0[s * nx + i] = initial_conditions[s](x[i]);
378        }
379        // Apply BCs
380        u0[s * nx] = bc_left[s];
381        u0[s * nx + nx - 1] = bc_right[s];
382    }
383
384    let diffusion_coeffs = system.diffusion_coeffs.clone();
385    let reaction = system.reaction.clone();
386    let x_clone = x.clone();
387    let bc_l = bc_left.to_vec();
388    let bc_r = bc_right.to_vec();
389
390    let rhs = move |t: f64, u: ArrayView1<f64>| -> Array1<f64> {
391        let mut dudt = Array1::zeros(total_dof);
392        let dx2 = dx * dx;
393
394        // Diffusion for each species
395        for s in 0..m {
396            let offset = s * nx;
397            let d = diffusion_coeffs[s];
398            for i in 1..nx - 1 {
399                dudt[offset + i] =
400                    d * (u[offset + i + 1] - 2.0 * u[offset + i] + u[offset + i - 1]) / dx2;
401            }
402            // Dirichlet BCs: du/dt = 0 at boundaries
403            dudt[offset] = 0.0;
404            dudt[offset + nx - 1] = 0.0;
405        }
406
407        // Reaction terms (coupled)
408        let mut species_vals = vec![0.0; m];
409        for i in 1..nx - 1 {
410            for s in 0..m {
411                species_vals[s] = u[s * nx + i];
412            }
413            let r = reaction(x_clone[i], t, &species_vals);
414            for s in 0..m {
415                if s < r.len() {
416                    dudt[s * nx + i] += r[s];
417                }
418            }
419        }
420
421        dudt
422    };
423
424    // ODE solve
425    let ode_opts = ODEOptions {
426        method: options.integrator.to_ode_method(),
427        rtol: options.rtol,
428        atol: options.atol,
429        max_steps: options.max_steps,
430        dense_output: false,
431        ..Default::default()
432    };
433
434    // Wrap closure in Arc for Clone bound
435    let rhs_arc = Arc::new(rhs);
436    let rhs_clone = move |t: f64, u: ArrayView1<f64>| -> Array1<f64> { rhs_arc(t, u) };
437
438    let result = solve_ivp(rhs_clone, t_range, u0, Some(ode_opts))?;
439
440    // Unpack results
441    let mut t_vec = Vec::new();
442    let mut u_vec = Vec::new();
443
444    for (step, y) in result.y.iter().enumerate() {
445        t_vec.push(result.t[step]);
446        let mut u_2d = Array2::zeros((m, nx));
447        for s in 0..m {
448            for i in 0..nx {
449                u_2d[[s, i]] = y[s * nx + i];
450            }
451        }
452        u_vec.push(u_2d);
453    }
454
455    Ok(ReactionDiffusionResult {
456        x,
457        t: t_vec,
458        u: u_vec,
459        n_eval: result.n_eval,
460        n_steps: result.n_steps,
461    })
462}
463
464// ---------------------------------------------------------------------------
465// Advection-diffusion equation
466// ---------------------------------------------------------------------------
467
468/// Solve 1D advection-diffusion: du/dt + v * du/dx = D * d2u/dx2 + source(x,t)
469pub fn mol_advection_diffusion_1d(
470    velocity: f64,
471    diffusion: f64,
472    x_range: [f64; 2],
473    t_range: [f64; 2],
474    nx: usize,
475    left_bc: MOLBoundaryCondition,
476    right_bc: MOLBoundaryCondition,
477    initial_condition: impl Fn(f64) -> f64 + Send + Sync + 'static,
478    source: Option<Arc<dyn Fn(f64, f64) -> f64 + Send + Sync>>,
479    options: &MOLEnhancedOptions,
480) -> PDEResult<MOLEnhancedResult> {
481    if diffusion < 0.0 {
482        return Err(PDEError::InvalidParameter(
483            "Diffusion coefficient must be non-negative".to_string(),
484        ));
485    }
486    if nx < 5 {
487        return Err(PDEError::InvalidGrid(
488            "Need at least 5 spatial points".to_string(),
489        ));
490    }
491
492    let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
493    let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
494
495    let mut u0 = Array1::from_shape_fn(nx, |i| initial_condition(x[i]));
496    apply_mol_bc(&mut u0, &left_bc, &right_bc, dx);
497
498    let x_clone = x.clone();
499    let left_bc_c = left_bc.clone();
500    let right_bc_c = right_bc.clone();
501
502    let rhs = move |t: f64, u: ArrayView1<f64>| -> Array1<f64> {
503        let n = u.len();
504        let mut dudt = Array1::zeros(n);
505        let dx2 = dx * dx;
506
507        for i in 1..n - 1 {
508            // Upwind advection
509            let advection = if velocity >= 0.0 {
510                -velocity * (u[i] - u[i - 1]) / dx
511            } else {
512                -velocity * (u[i + 1] - u[i]) / dx
513            };
514            // Central diffusion
515            let diff = diffusion * (u[i + 1] - 2.0 * u[i] + u[i - 1]) / dx2;
516            dudt[i] = advection + diff;
517        }
518
519        if let Some(ref src) = source {
520            for i in 1..n - 1 {
521                dudt[i] += src(x_clone[i], t);
522            }
523        }
524
525        apply_advection_bc_rhs(&mut dudt, &u, &left_bc_c, &right_bc_c, n, dx, velocity);
526
527        dudt
528    };
529
530    run_mol_ode(x, u0, t_range, rhs, options)
531}
532
533// ---------------------------------------------------------------------------
534// Internal helpers
535// ---------------------------------------------------------------------------
536
537/// Apply BCs to initial condition vector
538fn apply_mol_bc(
539    u: &mut Array1<f64>,
540    left: &MOLBoundaryCondition,
541    right: &MOLBoundaryCondition,
542    dx: f64,
543) {
544    let n = u.len();
545    match left {
546        MOLBoundaryCondition::Dirichlet(val) => u[0] = *val,
547        MOLBoundaryCondition::Neumann(val) => u[0] = u[1] - dx * val,
548        MOLBoundaryCondition::Periodic => {
549            if n > 1 {
550                u[0] = u[n - 2]; // or leave as-is
551            }
552        }
553    }
554    match right {
555        MOLBoundaryCondition::Dirichlet(val) => u[n - 1] = *val,
556        MOLBoundaryCondition::Neumann(val) => u[n - 1] = u[n - 2] + dx * val,
557        MOLBoundaryCondition::Periodic => {
558            if n > 1 {
559                u[n - 1] = u[1]; // or leave as-is
560            }
561        }
562    }
563}
564
565/// Apply boundary conditions in the ODE RHS for diffusion
566#[allow(clippy::too_many_arguments)]
567fn apply_mol_bc_rhs(
568    dudt: &mut Array1<f64>,
569    u: &ArrayView1<f64>,
570    left: &MOLBoundaryCondition,
571    right: &MOLBoundaryCondition,
572    n: usize,
573    dx: f64,
574    alpha: f64,
575    _x: &Array1<f64>,
576    _t: f64,
577    _source: &Option<Arc<dyn Fn(f64, f64, f64) -> f64 + Send + Sync>>,
578) {
579    let dx2 = dx * dx;
580    match left {
581        MOLBoundaryCondition::Dirichlet(_) => {
582            dudt[0] = 0.0;
583        }
584        MOLBoundaryCondition::Neumann(val) => {
585            // Ghost point: u[-1] = u[1] - 2*dx*val
586            let ghost = u[1] - 2.0 * dx * val;
587            dudt[0] = alpha * (u[1] - 2.0 * u[0] + ghost) / dx2;
588        }
589        MOLBoundaryCondition::Periodic => {
590            dudt[0] = alpha * (u[1] - 2.0 * u[0] + u[n - 2]) / dx2;
591        }
592    }
593    match right {
594        MOLBoundaryCondition::Dirichlet(_) => {
595            dudt[n - 1] = 0.0;
596        }
597        MOLBoundaryCondition::Neumann(val) => {
598            let ghost = u[n - 2] + 2.0 * dx * val;
599            dudt[n - 1] = alpha * (ghost - 2.0 * u[n - 1] + u[n - 2]) / dx2;
600        }
601        MOLBoundaryCondition::Periodic => {
602            dudt[n - 1] = alpha * (u[1] - 2.0 * u[n - 1] + u[n - 2]) / dx2;
603        }
604    }
605}
606
607/// Apply BCs in advection RHS
608fn apply_advection_bc_rhs(
609    dudt: &mut Array1<f64>,
610    u: &ArrayView1<f64>,
611    left: &MOLBoundaryCondition,
612    right: &MOLBoundaryCondition,
613    n: usize,
614    dx: f64,
615    velocity: f64,
616) {
617    match left {
618        MOLBoundaryCondition::Dirichlet(_) => {
619            dudt[0] = 0.0;
620        }
621        MOLBoundaryCondition::Neumann(_) => {
622            dudt[0] = 0.0; // simplified
623        }
624        MOLBoundaryCondition::Periodic => {
625            if velocity >= 0.0 {
626                dudt[0] = -velocity * (u[0] - u[n - 2]) / dx;
627            } else {
628                dudt[0] = -velocity * (u[1] - u[0]) / dx;
629            }
630        }
631    }
632    match right {
633        MOLBoundaryCondition::Dirichlet(_) => {
634            dudt[n - 1] = 0.0;
635        }
636        MOLBoundaryCondition::Neumann(_) => {
637            dudt[n - 1] = 0.0; // simplified
638        }
639        MOLBoundaryCondition::Periodic => {
640            if velocity >= 0.0 {
641                dudt[n - 1] = -velocity * (u[n - 1] - u[n - 2]) / dx;
642            } else {
643                dudt[n - 1] = -velocity * (u[1] - u[n - 1]) / dx;
644            }
645        }
646    }
647}
648
649/// Run the ODE solver on the semi-discretized system
650fn run_mol_ode(
651    x: Array1<f64>,
652    u0: Array1<f64>,
653    t_range: [f64; 2],
654    rhs: impl Fn(f64, ArrayView1<f64>) -> Array1<f64> + Send + Sync + 'static,
655    options: &MOLEnhancedOptions,
656) -> PDEResult<MOLEnhancedResult> {
657    let ode_opts = ODEOptions {
658        method: options.integrator.to_ode_method(),
659        rtol: options.rtol,
660        atol: options.atol,
661        max_steps: options.max_steps,
662        dense_output: false,
663        ..Default::default()
664    };
665
666    // Wrap in Arc to satisfy Clone bound required by solve_ivp
667    let rhs_arc = Arc::new(rhs);
668    let rhs_clone = move |t: f64, u: ArrayView1<f64>| -> Array1<f64> { rhs_arc(t, u) };
669
670    let result = solve_ivp(rhs_clone, t_range, u0, Some(ode_opts))?;
671
672    let t_vec: Vec<f64> = result.t.to_vec();
673    let u_vec: Vec<Array1<f64>> = result.y.to_vec();
674
675    Ok(MOLEnhancedResult {
676        x,
677        t: t_vec,
678        u: u_vec,
679        n_eval: result.n_eval,
680        n_steps: result.n_steps,
681    })
682}
683
684// ---------------------------------------------------------------------------
685// Tests
686// ---------------------------------------------------------------------------
687
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// A constant profile that matches its Dirichlet values must not change.
    #[test]
    fn test_mol_diffusion_constant() {
        // u(x,0) = 1 with u(0)=1, u(1)=1 => stays at 1
        let result = mol_diffusion_1d(
            0.1,
            [0.0, 1.0],
            [0.0, 0.5],
            21,
            MOLBoundaryCondition::Dirichlet(1.0),
            MOLBoundaryCondition::Dirichlet(1.0),
            |_| 1.0,
            None,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        // Check every node of the last stored state.
        let last = &result.u[result.u.len() - 1];
        for &v in last.iter() {
            assert!((v - 1.0).abs() < 0.01, "Should stay at 1.0, got {v}");
        }
    }

    /// The first sine mode with zero Dirichlet ends is compared against
    /// sin(pi*x) * exp(-pi^2 * alpha * t) at the mid node.
    #[test]
    fn test_mol_diffusion_decay() {
        // u(x,0) = sin(pi*x), u(0)=0, u(1)=0
        let alpha = 0.1;
        let result = mol_diffusion_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 0.5],
            41,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| (PI * x).sin(),
            None,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        let last = &result.u[result.u.len() - 1];
        // With 41 nodes on [0, 1] the middle index sits at x = 0.5.
        let mid = last.len() / 2;
        let exact = (PI * 0.5).sin() * (-PI * PI * alpha * 0.5).exp();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "MOL diffusion: got {}, expected {exact}",
            last[mid]
        );
    }

    /// Same decay check as above, but with the fourth-order stencil and a
    /// final time of 0.3.
    #[test]
    fn test_mol_diffusion_4th_order() {
        let alpha = 0.1;
        let opts = MOLEnhancedOptions {
            stencil: StencilOrder::Fourth,
            ..Default::default()
        };
        let result = mol_diffusion_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 0.3],
            41,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| (PI * x).sin(),
            None,
            &opts,
        )
        .expect("Should succeed");

        let last = &result.u[result.u.len() - 1];
        let mid = last.len() / 2;
        let exact = (PI * 0.5).sin() * (-PI * PI * alpha * 0.3).exp();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "4th order: got {}, expected {exact}",
            last[mid]
        );
    }

    /// A constant unit source on a zero initial state must raise the
    /// interior above zero.
    #[test]
    fn test_mol_diffusion_with_source() {
        // du/dt = alpha * d2u/dx2 + 1.0 with zero ICs and BCs
        let result = mol_diffusion_1d(
            0.1,
            [0.0, 1.0],
            [0.0, 0.5],
            21,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |_| 0.0,
            Some(Arc::new(|_, _, _| 1.0)),
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        // Interior values should be positive
        let last = &result.u[result.u.len() - 1];
        let mid = last.len() / 2;
        assert!(last[mid] > 0.0, "Source should make interior positive");
    }

    /// A constant profile with zero-flux ends must stay (approximately)
    /// constant.
    #[test]
    fn test_mol_diffusion_neumann() {
        // Insulated boundaries
        let result = mol_diffusion_1d(
            0.01,
            [0.0, 1.0],
            [0.0, 0.5],
            21,
            MOLBoundaryCondition::Neumann(0.0),
            MOLBoundaryCondition::Neumann(0.0),
            |_| 1.0,
            None,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        let last = &result.u[result.u.len() - 1];
        for &v in last.iter() {
            assert!((v - 1.0).abs() < 0.05, "Neumann: should stay ~1.0, got {v}");
        }
    }

    /// Smoke test: periodic diffusion runs and stores more than one state.
    /// No solution values are checked.
    #[test]
    fn test_mol_diffusion_periodic() {
        let result = mol_diffusion_1d(
            0.01,
            [0.0, 1.0],
            [0.0, 0.5],
            41,
            MOLBoundaryCondition::Periodic,
            MOLBoundaryCondition::Periodic,
            |x| (2.0 * PI * x).sin(),
            None,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        assert!(result.u.len() > 1, "Should have multiple time steps");
    }

    /// Smoke test: upwind advection of a square pulse runs and stores more
    /// than one state.
    #[test]
    fn test_mol_advection_upwind() {
        // Simple advection: du/dt + 1.0 * du/dx = 0
        let result = mol_advection_1d(
            1.0,
            [0.0, 2.0],
            [0.0, 0.5],
            41,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| if x > 0.3 && x < 0.7 { 1.0 } else { 0.0 },
            None,
            AdvectionScheme::Upwind,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        assert!(result.u.len() > 1);
    }

    /// Smoke test: the Lax-Wendroff scheme option runs and stores more than
    /// one state.
    #[test]
    fn test_mol_advection_lax_wendroff() {
        let result = mol_advection_1d(
            1.0,
            [0.0, 2.0],
            [0.0, 0.3],
            41,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| (PI * x).sin(),
            None,
            AdvectionScheme::LaxWendroff,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        assert!(result.u.len() > 1);
    }

    /// Smoke test: upwind advection with periodic ends runs and stores more
    /// than one state.
    #[test]
    fn test_mol_advection_periodic() {
        let result = mol_advection_1d(
            1.0,
            [0.0, 1.0],
            [0.0, 0.3],
            41,
            MOLBoundaryCondition::Periodic,
            MOLBoundaryCondition::Periodic,
            |x| (2.0 * PI * x).sin(),
            None,
            AdvectionScheme::Upwind,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        assert!(result.u.len() > 1);
    }

    /// Smoke test: the combined advection-diffusion solver runs and stores
    /// more than one state.
    #[test]
    fn test_mol_advection_diffusion() {
        let result = mol_advection_diffusion_1d(
            1.0,
            0.01,
            [0.0, 1.0],
            [0.0, 0.5],
            41,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| (PI * x).sin(),
            None,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        assert!(result.u.len() > 1);
    }

    /// Two-species run: checks that the solver completes and that each
    /// stored state has one row per species.
    #[test]
    fn test_mol_reaction_diffusion() {
        // Gray-Scott-like: u, v species
        // du/dt = D_u * d2u/dx2 - u*v^2 + F*(1-u)
        // dv/dt = D_v * d2v/dx2 + u*v^2 - (F+k)*v
        let system = ReactionDiffusionSystem {
            n_species: 2,
            diffusion_coeffs: vec![0.01, 0.005],
            reaction: Arc::new(|_x, _t, u| {
                let f = 0.04;
                let k = 0.06;
                let u_val = u[0];
                let v_val = u[1];
                vec![
                    -u_val * v_val * v_val + f * (1.0 - u_val),
                    u_val * v_val * v_val - (f + k) * v_val,
                ]
            }),
        };

        // Plain functions (not closures) so both initial conditions share
        // the single type `fn(f64) -> f64` and fit in one Vec.
        fn ic_u(_x: f64) -> f64 {
            1.0
        }
        fn ic_v(x: f64) -> f64 {
            if x > 0.4 && x < 0.6 {
                0.5
            } else {
                0.0
            }
        }
        let ics: Vec<fn(f64) -> f64> = vec![ic_u, ic_v];
        let result = mol_reaction_diffusion(
            &system,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            &ics,
            &[1.0, 0.0],
            &[1.0, 0.0],
            &MOLEnhancedOptions {
                integrator: MOLTimeIntegrator::RK45,
                ..Default::default()
            },
        )
        .expect("Should succeed");

        assert!(result.u.len() > 1);
        assert_eq!(result.u[0].shape()[0], 2); // 2 species
    }

    /// Smoke test: the BDF integrator option completes on a diffusion
    /// problem with a large coefficient.
    #[test]
    fn test_mol_bdf_integrator() {
        // Test BDF for potentially stiff diffusion
        let result = mol_diffusion_1d(
            1.0, // Large diffusion => stiff
            [0.0, 1.0],
            [0.0, 0.1],
            21,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| (PI * x).sin(),
            None,
            &MOLEnhancedOptions {
                integrator: MOLTimeIntegrator::BDF,
                ..Default::default()
            },
        )
        .expect("BDF should succeed");

        assert!(result.u.len() > 1);
    }

    /// Checks the bookkeeping fields of the result: grid length, non-zero
    /// counters, and one stored state per time point.
    #[test]
    fn test_mol_result_fields() {
        let result = mol_diffusion_1d(
            0.1,
            [0.0, 1.0],
            [0.0, 0.1],
            11,
            MOLBoundaryCondition::Dirichlet(0.0),
            MOLBoundaryCondition::Dirichlet(0.0),
            |x| (PI * x).sin(),
            None,
            &MOLEnhancedOptions::default(),
        )
        .expect("Should succeed");

        assert_eq!(result.x.len(), 11);
        assert!(result.n_eval > 0);
        assert!(result.n_steps > 0);
        assert!(result.t.len() == result.u.len());
    }
}
1003}