Skip to main content

scirs2_integrate/pde/implicit/
adi.rs

1//! Alternating Direction Implicit (ADI) methods for 2D and 3D PDEs
2//!
3//! This module provides implementations of ADI methods for solving
4//! two-dimensional and three-dimensional partial differential equations.
5//! ADI methods split multi-dimensional problems into sequences of one-dimensional
6//! problems, making them computationally efficient.
7
8use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1};
9use std::time::Instant;
10
11use super::ImplicitOptions;
12use crate::pde::finite_difference::FiniteDifferenceScheme;
13use crate::pde::{
14    BoundaryCondition, BoundaryConditionType, BoundaryLocation, Domain, PDEError, PDEResult,
15    PDESolution, PDESolverInfo,
16};
17
/// Boxed, thread-safe (`Send + Sync`) coefficient callback for 2D problems.
///
/// The four `f64` arguments are, in order, `(x, y, t, u)`; the returned `f64`
/// is the coefficient evaluated at that point.
type CoeffFn2D = Box<dyn Fn(f64, f64, f64, f64) -> f64 + Send + Sync>;
20
21/// Result of ADI method solution
22pub struct ADIResult {
23    /// Time points
24    pub t: Array1<f64>,
25
26    /// Solution values, indexed as [time, x, y]
27    pub u: Vec<Array3<f64>>,
28
29    /// Solver information
30    pub info: Option<String>,
31
32    /// Computation time
33    pub computation_time: f64,
34
35    /// Number of time steps
36    pub num_steps: usize,
37
38    /// Number of linear system solves
39    pub num_linear_solves: usize,
40}
41
/// ADI solver for 2D parabolic PDEs
///
/// This solver implements the Peaceman-Rachford ADI method for solving
/// two-dimensional parabolic PDEs of the form:
/// ∂u/∂t = Dx*∂²u/∂x² + Dy*∂²u/∂y² + f(x,y,t,u)
///
/// Build an instance with [`ADI2D::new`], optionally attach advection, reaction
/// and a finite-difference scheme through the `with_*` builder methods, then call
/// [`ADI2D::solve`].
///
/// NOTE(review): the directional sweeps assembled in this file contain no explicit
/// term from the other direction, so the scheme looks closer to a locally
/// one-dimensional splitting than to textbook Peaceman-Rachford — confirm the
/// intended method.
pub struct ADI2D {
    /// Spatial domain; `new` requires it to be 2-dimensional
    domain: Domain,

    /// Time range [t_start, t_end]; `new` requires `t_start < t_end`
    time_range: [f64; 2],

    /// Diffusion coefficient function in x-direction: Dx(x, y, t, u)
    diffusion_x: CoeffFn2D,

    /// Diffusion coefficient function in y-direction: Dy(x, y, t, u)
    diffusion_y: CoeffFn2D,

    /// Advection coefficient function in x-direction: vx(x, y, t, u);
    /// `None` until `with_advection` is called
    advection_x: Option<CoeffFn2D>,

    /// Advection coefficient function in y-direction: vy(x, y, t, u);
    /// `None` until `with_advection` is called
    advection_y: Option<CoeffFn2D>,

    /// Reaction term function: f(x, y, t, u); `None` until `with_reaction` is called
    reaction_term: Option<CoeffFn2D>,

    /// Initial condition function: u(x, y, 0)
    initial_condition: Box<dyn Fn(f64, f64) -> f64 + Send + Sync>,

    /// Boundary conditions; `new` requires exactly four, covering the lower and
    /// upper boundary of dimension 0 (x) and dimension 1 (y)
    boundary_conditions: Vec<BoundaryCondition<f64>>,

    /// Finite difference scheme for spatial discretization; `new` initialises it
    /// to `CentralDifference` and `with_fd_scheme` replaces it
    fd_scheme: FiniteDifferenceScheme,

    /// Solver options; `solve` reads `dt`, `save_every` and `verbose`
    options: ImplicitOptions,
}
81
82impl ADI2D {
83    /// Create a new ADI solver for 2D parabolic PDEs
84    pub fn new(
85        domain: Domain,
86        time_range: [f64; 2],
87        diffusion_x: impl Fn(f64, f64, f64, f64) -> f64 + Send + Sync + 'static,
88        diffusion_y: impl Fn(f64, f64, f64, f64) -> f64 + Send + Sync + 'static,
89        initial_condition: impl Fn(f64, f64) -> f64 + Send + Sync + 'static,
90        boundary_conditions: Vec<BoundaryCondition<f64>>,
91        options: Option<ImplicitOptions>,
92    ) -> PDEResult<Self> {
93        // Validate the domain
94        if domain.dimensions() != 2 {
95            return Err(PDEError::DomainError(
96                "Domain must be 2-dimensional for 2D ADI solver".to_string(),
97            ));
98        }
99
100        // Validate time _range
101        if time_range[0] >= time_range[1] {
102            return Err(PDEError::DomainError(
103                "Invalid time _range: start must be less than end".to_string(),
104            ));
105        }
106
107        // Validate boundary _conditions
108        if boundary_conditions.len() != 4 {
109            return Err(PDEError::BoundaryConditions(
110                "2D parabolic PDE requires exactly 4 boundary _conditions".to_string(),
111            ));
112        }
113
114        // Ensure we have boundary _conditions for all four boundaries
115        let has_lower_x = boundary_conditions
116            .iter()
117            .any(|bc| bc.location == BoundaryLocation::Lower && bc.dimension == 0);
118        let has_upper_x = boundary_conditions
119            .iter()
120            .any(|bc| bc.location == BoundaryLocation::Upper && bc.dimension == 0);
121        let has_lower_y = boundary_conditions
122            .iter()
123            .any(|bc| bc.location == BoundaryLocation::Lower && bc.dimension == 1);
124        let has_upper_y = boundary_conditions
125            .iter()
126            .any(|bc| bc.location == BoundaryLocation::Upper && bc.dimension == 1);
127
128        if !has_lower_x || !has_upper_x || !has_lower_y || !has_upper_y {
129            return Err(PDEError::BoundaryConditions(
130                "2D parabolic PDE requires boundary _conditions for all four sides".to_string(),
131            ));
132        }
133
134        // Use default options if none provided
135        let options = options.unwrap_or_default();
136
137        Ok(ADI2D {
138            domain,
139            time_range,
140            diffusion_x: Box::new(diffusion_x),
141            diffusion_y: Box::new(diffusion_y),
142            advection_x: None,
143            advection_y: None,
144            reaction_term: None,
145            initial_condition: Box::new(initial_condition),
146            boundary_conditions,
147            fd_scheme: FiniteDifferenceScheme::CentralDifference,
148            options,
149        })
150    }
151
152    /// Add advection terms to the PDE
153    pub fn with_advection(
154        mut self,
155        advection_x: impl Fn(f64, f64, f64, f64) -> f64 + Send + Sync + 'static,
156        advection_y: impl Fn(f64, f64, f64, f64) -> f64 + Send + Sync + 'static,
157    ) -> Self {
158        self.advection_x = Some(Box::new(advection_x));
159        self.advection_y = Some(Box::new(advection_y));
160        self
161    }
162
163    /// Add a reaction term to the PDE
164    pub fn with_reaction(
165        mut self,
166        reaction_term: impl Fn(f64, f64, f64, f64) -> f64 + Send + Sync + 'static,
167    ) -> Self {
168        self.reaction_term = Some(Box::new(reaction_term));
169        self
170    }
171
172    /// Set the finite difference scheme for spatial discretization
173    pub fn with_fd_scheme(mut self, scheme: FiniteDifferenceScheme) -> Self {
174        self.fd_scheme = scheme;
175        self
176    }
177
178    /// Solve the PDE using the Peaceman-Rachford ADI method
179    pub fn solve(&self) -> PDEResult<ADIResult> {
180        let start_time = Instant::now();
181
182        // Generate spatial grids
183        let x_grid = self.domain.grid(0)?;
184        let y_grid = self.domain.grid(1)?;
185        let nx = x_grid.len();
186        let ny = y_grid.len();
187
188        // Grid spacing
189        let dx = self.domain.grid_spacing(0)?;
190        let dy = self.domain.grid_spacing(1)?;
191
192        // Time step
193        let dt = self.options.dt.unwrap_or(0.01);
194
195        // Calculate number of time steps
196        let t_start = self.time_range[0];
197        let t_end = self.time_range[1];
198        let num_steps = ((t_end - t_start) / dt).ceil() as usize;
199
200        // Initialize time array
201        let mut t_values = Vec::with_capacity(num_steps + 1);
202        t_values.push(t_start);
203
204        // Initialize solution array with initial condition
205        let mut u_current = Array2::zeros((nx, ny));
206
207        // Apply initial condition
208        for (i, &x) in x_grid.iter().enumerate() {
209            for (j, &y) in y_grid.iter().enumerate() {
210                u_current[[i, j]] = (self.initial_condition)(x, y);
211            }
212        }
213
214        // Apply boundary conditions to initial state
215        self.apply_boundary_conditions(&mut u_current, &x_grid, &y_grid, t_start);
216
217        // Store solutions
218        let save_every = self.options.save_every.unwrap_or(1);
219        let mut solutions = Vec::with_capacity((num_steps + 1) / save_every + 1);
220
221        // Add initial condition to solutions
222        let mut u3d = Array3::zeros((nx, ny, 1));
223        for i in 0..nx {
224            for j in 0..ny {
225                u3d[[i, j, 0]] = u_current[[i, j]];
226            }
227        }
228        solutions.push(u3d);
229
230        // Initialize coefficient matrices for both directions
231        let mut a_x = Array2::zeros((nx, nx)); // For x-direction sweep
232        let mut b_x = Array2::zeros((nx, nx));
233        let mut a_y = Array2::zeros((ny, ny)); // For y-direction sweep
234        let mut b_y = Array2::zeros((ny, ny));
235
236        // Track solver statistics
237        let mut num_linear_solves = 0;
238
239        // Time-stepping loop
240        for step in 0..num_steps {
241            let t_current = t_start + step as f64 * dt;
242            let t_mid = t_current + 0.5 * dt;
243            let t_next = t_current + dt;
244
245            // Intermediate solution after x-sweep
246            let mut u_intermediate = Array2::zeros((nx, ny));
247
248            // 1. First half-step: Implicit in x-direction, explicit in y-direction
249            for j in 0..ny {
250                // 1.1 Set up coefficient matrices for x-direction
251                self.setup_coefficient_matrices_x(
252                    &mut a_x,
253                    &mut b_x,
254                    &x_grid,
255                    y_grid[j],
256                    dx,
257                    0.5 * dt,
258                    t_current,
259                    &u_current.slice(s![.., j]),
260                );
261
262                // 1.2 Extract the current solution row
263                let mut u_row = Array1::zeros(nx);
264                for i in 0..nx {
265                    u_row[i] = u_current[[i, j]];
266                }
267
268                // 1.3 Right-hand side vector for x-direction
269                let rhs_x = b_x.dot(&u_row);
270
271                // 1.4 Solve the linear system for x-direction
272                let u_x_next = ADI2D::solve_tridiagonal(&a_x, &rhs_x)?;
273                num_linear_solves += 1;
274
275                // 1.5 Update intermediate solution row
276                for i in 0..nx {
277                    u_intermediate[[i, j]] = u_x_next[i];
278                }
279            }
280
281            // Apply boundary conditions to intermediate solution
282            self.apply_boundary_conditions(&mut u_intermediate, &x_grid, &y_grid, t_mid);
283
284            // 2. Second half-step: Implicit in y-direction, explicit in x-direction
285            for i in 0..nx {
286                // 2.1 Set up coefficient matrices for y-direction
287                self.setup_coefficient_matrices_y(
288                    &mut a_y,
289                    &mut b_y,
290                    x_grid[i],
291                    &y_grid,
292                    dy,
293                    0.5 * dt,
294                    t_mid,
295                    &u_intermediate.slice(s![i, ..]),
296                );
297
298                // 2.2 Extract the intermediate solution column
299                let mut u_col = Array1::zeros(ny);
300                for j in 0..ny {
301                    u_col[j] = u_intermediate[[i, j]];
302                }
303
304                // 2.3 Right-hand side vector for y-direction
305                let rhs_y = b_y.dot(&u_col);
306
307                // 2.4 Solve the linear system for y-direction
308                let u_y_next = ADI2D::solve_tridiagonal(&a_y, &rhs_y)?;
309                num_linear_solves += 1;
310
311                // 2.5 Update solution column
312                for j in 0..ny {
313                    u_current[[i, j]] = u_y_next[j];
314                }
315            }
316
317            // Apply boundary conditions to final solution for this time step
318            self.apply_boundary_conditions(&mut u_current, &x_grid, &y_grid, t_next);
319
320            // Update time
321            t_values.push(t_next);
322
323            // Save solution if needed
324            if (step + 1) % save_every == 0 || step == num_steps - 1 {
325                let mut u3d = Array3::zeros((nx, ny, 1));
326                for i in 0..nx {
327                    for j in 0..ny {
328                        u3d[[i, j, 0]] = u_current[[i, j]];
329                    }
330                }
331                solutions.push(u3d);
332            }
333
334            // Print progress if verbose
335            if self.options.verbose && (step + 1) % 10 == 0 {
336                println!(
337                    "Step {}/{} completed, t = {:.4}",
338                    step + 1,
339                    num_steps,
340                    t_next
341                );
342            }
343        }
344
345        // Convert time values to Array1
346        let t_array = Array1::from_vec(t_values);
347
348        // Compute solution time
349        let computation_time = start_time.elapsed().as_secs_f64();
350
351        // Create result
352        let info = Some(format!(
353            "Time steps: {num_steps}, Linear system solves: {num_linear_solves}"
354        ));
355
356        Ok(ADIResult {
357            t: t_array,
358            u: solutions,
359            info,
360            computation_time,
361            num_steps,
362            num_linear_solves,
363        })
364    }
365
366    /// Set up coefficient matrices for the x-direction sweep
367    #[allow(clippy::too_many_arguments)]
368    fn setup_coefficient_matrices_x(
369        &self,
370        a_matrix: &mut Array2<f64>,
371        b_matrix: &mut Array2<f64>,
372        x_grid: &Array1<f64>,
373        y: f64,
374        dx: f64,
375        half_dt: f64,
376        t: f64,
377        u_row: &ArrayView1<f64>,
378    ) {
379        let nx = x_grid.len();
380
381        // Clear matrices
382        a_matrix.fill(0.0);
383        b_matrix.fill(0.0);
384
385        // Set up matrices for interior points
386        for i in 1..nx - 1 {
387            let x = x_grid[i];
388            let u_val = u_row[i];
389
390            // Diffusion coefficient at the current point
391            let d = (self.diffusion_x)(x, y, t, u_val);
392
393            // Coefficient for diffusion term
394            let r = 0.5 * d * half_dt / (dx * dx);
395
396            // Implicit part (left-hand side)
397            a_matrix[[i, i - 1]] = -r; // Coefficient for u_{i-1,j}^{n+1/2}
398            a_matrix[[i, i]] = 1.0 + 2.0 * r; // Coefficient for u_{i,j}^{n+1/2}
399            a_matrix[[i, i + 1]] = -r; // Coefficient for u_{i+1,j}^{n+1/2}
400
401            // Explicit part (right-hand side)
402            b_matrix[[i, i - 1]] = r; // Coefficient for u_{i-1,j}^{n}
403            b_matrix[[i, i]] = 1.0 - 2.0 * r; // Coefficient for u_{i,j}^{n}
404            b_matrix[[i, i + 1]] = r; // Coefficient for u_{i+1,j}^{n}
405
406            // Add advection term in x-direction if present
407            if let Some(advection_x) = &self.advection_x {
408                let vx = advection_x(x, y, t, u_val);
409
410                // Coefficient for advection term
411                let c = 0.25 * vx * half_dt / dx;
412
413                // Implicit part
414                a_matrix[[i, i - 1]] -= c; // Additional term for u_{i-1,j}^{n+1/2}
415                a_matrix[[i, i + 1]] += c; // Additional term for u_{i+1,j}^{n+1/2}
416
417                // Explicit part
418                b_matrix[[i, i - 1]] -= c; // Additional term for u_{i-1,j}^{n}
419                b_matrix[[i, i + 1]] += c; // Additional term for u_{i+1,j}^{n}
420            }
421
422            // Add half of the reaction term if present
423            if let Some(reaction) = &self.reaction_term {
424                let k = reaction(x, y, t, u_val);
425
426                // Coefficient for reaction term (half applied to each direction)
427                let s = 0.25 * k * half_dt;
428
429                // Implicit part
430                a_matrix[[i, i]] += s; // Additional term for u_{i,j}^{n+1/2}
431
432                // Explicit part
433                b_matrix[[i, i]] += s; // Additional term for u_{i,j}^{n}
434            }
435        }
436
437        // Apply boundary conditions in x-direction
438        for bc in &self.boundary_conditions {
439            if bc.dimension == 0 {
440                // x-direction
441                match bc.location {
442                    BoundaryLocation::Lower => {
443                        // Apply boundary condition at x[0]
444                        match bc.bc_type {
445                            BoundaryConditionType::Dirichlet => {
446                                // u(a, y, t) = bc.value
447                                for j in 0..nx {
448                                    a_matrix[[0, j]] = 0.0;
449                                    b_matrix[[0, j]] = 0.0;
450                                }
451                                a_matrix[[0, 0]] = 1.0;
452                                b_matrix[[0, 0]] = bc.value;
453                            }
454                            BoundaryConditionType::Neumann => {
455                                // du/dx(a, y, t) = bc.value
456                                // Use second-order one-sided difference
457                                for j in 0..nx {
458                                    a_matrix[[0, j]] = 0.0;
459                                    b_matrix[[0, j]] = 0.0;
460                                }
461                                a_matrix[[0, 0]] = -3.0;
462                                a_matrix[[0, 1]] = 4.0;
463                                a_matrix[[0, 2]] = -1.0;
464                                b_matrix[[0, 0]] = 2.0 * dx * bc.value;
465                            }
466                            BoundaryConditionType::Robin => {
467                                // a*u + b*du/dx = c
468                                if let Some([a_val, b_val, c_val]) = bc.coefficients {
469                                    for j in 0..nx {
470                                        a_matrix[[0, j]] = 0.0;
471                                        b_matrix[[0, j]] = 0.0;
472                                    }
473                                    a_matrix[[0, 0]] = a_val - 3.0 * b_val / (2.0 * dx);
474                                    a_matrix[[0, 1]] = 4.0 * b_val / (2.0 * dx);
475                                    a_matrix[[0, 2]] = -b_val / (2.0 * dx);
476                                    b_matrix[[0, 0]] = c_val;
477                                }
478                            }
479                            BoundaryConditionType::Periodic => {
480                                // For periodic BCs, handled together with upper boundary
481                            }
482                        }
483                    }
484                    BoundaryLocation::Upper => {
485                        // Apply boundary condition at x[nx-1]
486                        let i = nx - 1;
487
488                        match bc.bc_type {
489                            BoundaryConditionType::Dirichlet => {
490                                // u(b, y, t) = bc.value
491                                for j in 0..nx {
492                                    a_matrix[[i, j]] = 0.0;
493                                    b_matrix[[i, j]] = 0.0;
494                                }
495                                a_matrix[[i, i]] = 1.0;
496                                b_matrix[[i, i]] = bc.value;
497                            }
498                            BoundaryConditionType::Neumann => {
499                                // du/dx(b, y, t) = bc.value
500                                // Use second-order one-sided difference
501                                for j in 0..nx {
502                                    a_matrix[[i, j]] = 0.0;
503                                    b_matrix[[i, j]] = 0.0;
504                                }
505                                a_matrix[[i, i]] = 3.0;
506                                a_matrix[[i, i - 1]] = -4.0;
507                                a_matrix[[i, i - 2]] = 1.0;
508                                b_matrix[[i, i]] = 2.0 * dx * bc.value;
509                            }
510                            BoundaryConditionType::Robin => {
511                                // a*u + b*du/dx = c
512                                if let Some([a_val, b_val, c_val]) = bc.coefficients {
513                                    for j in 0..nx {
514                                        a_matrix[[i, j]] = 0.0;
515                                        b_matrix[[i, j]] = 0.0;
516                                    }
517                                    a_matrix[[i, i]] = a_val + 3.0 * b_val / (2.0 * dx);
518                                    a_matrix[[i, i - 1]] = -4.0 * b_val / (2.0 * dx);
519                                    a_matrix[[i, i - 2]] = b_val / (2.0 * dx);
520                                    b_matrix[[i, i]] = c_val;
521                                }
522                            }
523                            BoundaryConditionType::Periodic => {
524                                // Handle periodic boundary conditions in x-direction
525
526                                // First, clear boundary rows
527                                for j in 0..nx {
528                                    a_matrix[[0, j]] = 0.0;
529                                    a_matrix[[i, j]] = 0.0;
530                                    b_matrix[[0, j]] = 0.0;
531                                    b_matrix[[i, j]] = 0.0;
532                                }
533
534                                // Extract diffusion coefficient
535                                let x_lower = x_grid[0];
536                                let x_upper = x_grid[i];
537                                let u_lower = u_row[0];
538                                let u_upper = u_row[i];
539
540                                let d_lower = (self.diffusion_x)(x_lower, y, t, u_lower);
541                                let d_upper = (self.diffusion_x)(x_upper, y, t, u_upper);
542
543                                let r_lower = 0.5 * d_lower * half_dt / (dx * dx);
544                                let r_upper = 0.5 * d_upper * half_dt / (dx * dx);
545
546                                // Lower boundary (connects to upper)
547                                a_matrix[[0, i]] = -r_lower;
548                                a_matrix[[0, 0]] = 1.0 + 2.0 * r_lower;
549                                a_matrix[[0, 1]] = -r_lower;
550
551                                b_matrix[[0, i]] = r_lower;
552                                b_matrix[[0, 0]] = 1.0 - 2.0 * r_lower;
553                                b_matrix[[0, 1]] = r_lower;
554
555                                // Upper boundary (connects to lower)
556                                a_matrix[[i, i - 1]] = -r_upper;
557                                a_matrix[[i, i]] = 1.0 + 2.0 * r_upper;
558                                a_matrix[[i, 0]] = -r_upper;
559
560                                b_matrix[[i, i - 1]] = r_upper;
561                                b_matrix[[i, i]] = 1.0 - 2.0 * r_upper;
562                                b_matrix[[i, 0]] = r_upper;
563                            }
564                        }
565                    }
566                }
567            }
568        }
569    }
570
571    /// Set up coefficient matrices for the y-direction sweep
572    #[allow(clippy::too_many_arguments)]
573    fn setup_coefficient_matrices_y(
574        &self,
575        a_matrix: &mut Array2<f64>,
576        b_matrix: &mut Array2<f64>,
577        x: f64,
578        y_grid: &Array1<f64>,
579        dy: f64,
580        half_dt: f64,
581        t: f64,
582        u_col: &ArrayView1<f64>,
583    ) {
584        let ny = y_grid.len();
585
586        // Clear matrices
587        a_matrix.fill(0.0);
588        b_matrix.fill(0.0);
589
590        // Set up matrices for interior points
591        for j in 1..ny - 1 {
592            let y = y_grid[j];
593            let u_val = u_col[j];
594
595            // Diffusion coefficient at the current point
596            let d = (self.diffusion_y)(x, y, t, u_val);
597
598            // Coefficient for diffusion term
599            let r = 0.5 * d * half_dt / (dy * dy);
600
601            // Implicit part (left-hand side)
602            a_matrix[[j, j - 1]] = -r; // Coefficient for u_{i,j-1}^{n+1}
603            a_matrix[[j, j]] = 1.0 + 2.0 * r; // Coefficient for u_{i,j}^{n+1}
604            a_matrix[[j, j + 1]] = -r; // Coefficient for u_{i,j+1}^{n+1}
605
606            // Explicit part (right-hand side)
607            b_matrix[[j, j - 1]] = r; // Coefficient for u_{i,j-1}^{n+1/2}
608            b_matrix[[j, j]] = 1.0 - 2.0 * r; // Coefficient for u_{i,j}^{n+1/2}
609            b_matrix[[j, j + 1]] = r; // Coefficient for u_{i,j+1}^{n+1/2}
610
611            // Add advection term in y-direction if present
612            if let Some(advection_y) = &self.advection_y {
613                let vy = advection_y(x, y, t, u_val);
614
615                // Coefficient for advection term
616                let c = 0.25 * vy * half_dt / dy;
617
618                // Implicit part
619                a_matrix[[j, j - 1]] -= c; // Additional term for u_{i,j-1}^{n+1}
620                a_matrix[[j, j + 1]] += c; // Additional term for u_{i,j+1}^{n+1}
621
622                // Explicit part
623                b_matrix[[j, j - 1]] -= c; // Additional term for u_{i,j-1}^{n+1/2}
624                b_matrix[[j, j + 1]] += c; // Additional term for u_{i,j+1}^{n+1/2}
625            }
626
627            // Add half of the reaction term if present
628            if let Some(reaction) = &self.reaction_term {
629                let k = reaction(x, y, t, u_val);
630
631                // Coefficient for reaction term (half applied to each direction)
632                let s = 0.25 * k * half_dt;
633
634                // Implicit part
635                a_matrix[[j, j]] += s; // Additional term for u_{i,j}^{n+1}
636
637                // Explicit part
638                b_matrix[[j, j]] += s; // Additional term for u_{i,j}^{n+1/2}
639            }
640        }
641
642        // Apply boundary conditions in y-direction
643        for bc in &self.boundary_conditions {
644            if bc.dimension == 1 {
645                // y-direction
646                match bc.location {
647                    BoundaryLocation::Lower => {
648                        // Apply boundary condition at y[0]
649                        match bc.bc_type {
650                            BoundaryConditionType::Dirichlet => {
651                                // u(x, c, t) = bc.value
652                                for j in 0..ny {
653                                    a_matrix[[0, j]] = 0.0;
654                                    b_matrix[[0, j]] = 0.0;
655                                }
656                                a_matrix[[0, 0]] = 1.0;
657                                b_matrix[[0, 0]] = bc.value;
658                            }
659                            BoundaryConditionType::Neumann => {
660                                // du/dy(x, c, t) = bc.value
661                                // Use second-order one-sided difference
662                                for j in 0..ny {
663                                    a_matrix[[0, j]] = 0.0;
664                                    b_matrix[[0, j]] = 0.0;
665                                }
666                                a_matrix[[0, 0]] = -3.0;
667                                a_matrix[[0, 1]] = 4.0;
668                                a_matrix[[0, 2]] = -1.0;
669                                b_matrix[[0, 0]] = 2.0 * dy * bc.value;
670                            }
671                            BoundaryConditionType::Robin => {
672                                // a*u + b*du/dy = c
673                                if let Some([a_val, b_val, c_val]) = bc.coefficients {
674                                    for j in 0..ny {
675                                        a_matrix[[0, j]] = 0.0;
676                                        b_matrix[[0, j]] = 0.0;
677                                    }
678                                    a_matrix[[0, 0]] = a_val - 3.0 * b_val / (2.0 * dy);
679                                    a_matrix[[0, 1]] = 4.0 * b_val / (2.0 * dy);
680                                    a_matrix[[0, 2]] = -b_val / (2.0 * dy);
681                                    b_matrix[[0, 0]] = c_val;
682                                }
683                            }
684                            BoundaryConditionType::Periodic => {
685                                // For periodic BCs, handled together with upper boundary
686                            }
687                        }
688                    }
689                    BoundaryLocation::Upper => {
690                        // Apply boundary condition at y[ny-1]
691                        let j = ny - 1;
692
693                        match bc.bc_type {
694                            BoundaryConditionType::Dirichlet => {
695                                // u(x, d, t) = bc.value
696                                for i in 0..ny {
697                                    a_matrix[[j, i]] = 0.0;
698                                    b_matrix[[j, i]] = 0.0;
699                                }
700                                a_matrix[[j, j]] = 1.0;
701                                b_matrix[[j, j]] = bc.value;
702                            }
703                            BoundaryConditionType::Neumann => {
704                                // du/dy(x, d, t) = bc.value
705                                // Use second-order one-sided difference
706                                for i in 0..ny {
707                                    a_matrix[[j, i]] = 0.0;
708                                    b_matrix[[j, i]] = 0.0;
709                                }
710                                a_matrix[[j, j]] = 3.0;
711                                a_matrix[[j, j - 1]] = -4.0;
712                                a_matrix[[j, j - 2]] = 1.0;
713                                b_matrix[[j, j]] = 2.0 * dy * bc.value;
714                            }
715                            BoundaryConditionType::Robin => {
716                                // a*u + b*du/dy = c
717                                if let Some([a_val, b_val, c_val]) = bc.coefficients {
718                                    for i in 0..ny {
719                                        a_matrix[[j, i]] = 0.0;
720                                        b_matrix[[j, i]] = 0.0;
721                                    }
722                                    a_matrix[[j, j]] = a_val + 3.0 * b_val / (2.0 * dy);
723                                    a_matrix[[j, j - 1]] = -4.0 * b_val / (2.0 * dy);
724                                    a_matrix[[j, j - 2]] = b_val / (2.0 * dy);
725                                    b_matrix[[j, j]] = c_val;
726                                }
727                            }
728                            BoundaryConditionType::Periodic => {
729                                // Handle periodic boundary conditions in y-direction
730
731                                // First, clear boundary rows
732                                for i in 0..ny {
733                                    a_matrix[[0, i]] = 0.0;
734                                    a_matrix[[j, i]] = 0.0;
735                                    b_matrix[[0, i]] = 0.0;
736                                    b_matrix[[j, i]] = 0.0;
737                                }
738
739                                // Extract diffusion coefficient
740                                let y_lower = y_grid[0];
741                                let y_upper = y_grid[j];
742                                let u_lower = u_col[0];
743                                let u_upper = u_col[j];
744
745                                let d_lower = (self.diffusion_y)(x, y_lower, t, u_lower);
746                                let d_upper = (self.diffusion_y)(x, y_upper, t, u_upper);
747
748                                let r_lower = 0.5 * d_lower * half_dt / (dy * dy);
749                                let r_upper = 0.5 * d_upper * half_dt / (dy * dy);
750
751                                // Lower boundary (connects to upper)
752                                a_matrix[[0, j]] = -r_lower;
753                                a_matrix[[0, 0]] = 1.0 + 2.0 * r_lower;
754                                a_matrix[[0, 1]] = -r_lower;
755
756                                b_matrix[[0, j]] = r_lower;
757                                b_matrix[[0, 0]] = 1.0 - 2.0 * r_lower;
758                                b_matrix[[0, 1]] = r_lower;
759
760                                // Upper boundary (connects to lower)
761                                a_matrix[[j, j - 1]] = -r_upper;
762                                a_matrix[[j, j]] = 1.0 + 2.0 * r_upper;
763                                a_matrix[[j, 0]] = -r_upper;
764
765                                b_matrix[[j, j - 1]] = r_upper;
766                                b_matrix[[j, j]] = 1.0 - 2.0 * r_upper;
767                                b_matrix[[j, 0]] = r_upper;
768                            }
769                        }
770                    }
771                }
772            }
773        }
774    }
775
776    /// Apply boundary conditions to the solution
777    fn apply_boundary_conditions(
778        &self,
779        u: &mut Array2<f64>,
780        x_grid: &Array1<f64>,
781        y_grid: &Array1<f64>,
782        _t: f64,
783    ) {
784        let nx = x_grid.len();
785        let ny = y_grid.len();
786
787        for bc in &self.boundary_conditions {
788            match (bc.dimension, bc.location, bc.bc_type) {
789                // x-direction Dirichlet boundary conditions
790                (0, BoundaryLocation::Lower, BoundaryConditionType::Dirichlet) => {
791                    for j in 0..ny {
792                        u[[0, j]] = bc.value;
793                    }
794                }
795                (0, BoundaryLocation::Upper, BoundaryConditionType::Dirichlet) => {
796                    for j in 0..ny {
797                        u[[nx - 1, j]] = bc.value;
798                    }
799                }
800
801                // y-direction Dirichlet boundary conditions
802                (1, BoundaryLocation::Lower, BoundaryConditionType::Dirichlet) => {
803                    for i in 0..nx {
804                        u[[i, 0]] = bc.value;
805                    }
806                }
807                (1, BoundaryLocation::Upper, BoundaryConditionType::Dirichlet) => {
808                    for i in 0..nx {
809                        u[[i, ny - 1]] = bc.value;
810                    }
811                }
812
813                // Other boundary conditions are handled in the coefficient matrices
814                _ => {}
815            }
816        }
817    }
818
819    /// Solve a tridiagonal linear system using the Thomas algorithm
820    fn solve_tridiagonal(a: &Array2<f64>, b: &Array1<f64>) -> PDEResult<Array1<f64>> {
821        let n = b.len();
822
823        // Extract the tridiagonal elements
824        let mut lower = Array1::zeros(n - 1);
825        let mut diagonal = Array1::zeros(n);
826        let mut upper = Array1::zeros(n - 1);
827
828        for i in 0..n {
829            diagonal[i] = a[[i, i]];
830            if i < n - 1 {
831                upper[i] = a[[i, i + 1]];
832            }
833            if i > 0 {
834                lower[i - 1] = a[[i, i - 1]];
835            }
836        }
837
838        // Solve tridiagonal system using Thomas algorithm
839        let mut solution = Array1::zeros(n);
840        let mut temp_diag = diagonal.clone();
841        let mut temp_rhs = b.to_owned();
842
843        // Forward sweep
844        for i in 1..n {
845            let w = lower[i - 1] / temp_diag[i - 1];
846            temp_diag[i] -= w * upper[i - 1];
847            temp_rhs[i] -= w * temp_rhs[i - 1];
848        }
849
850        // Backward sweep
851        solution[n - 1] = temp_rhs[n - 1] / temp_diag[n - 1];
852        for i in (0..n - 1).rev() {
853            solution[i] = (temp_rhs[i] - upper[i] * solution[i + 1]) / temp_diag[i];
854        }
855
856        Ok(solution)
857    }
858}
859
860/// Convert an ADIResult to a PDESolution
861impl From<ADIResult> for PDESolution<f64> {
862    fn from(result: ADIResult) -> Self {
863        let mut grids = Vec::new();
864
865        // Add time grid
866        grids.push(result.t.clone());
867
868        // Extract spatial grids from solution shape
869        let nx = result.u[0].shape()[0];
870        let ny = result.u[0].shape()[1];
871
872        // Create spatial grids (we don't have the actual grid values, so use linspace)
873        let x_grid = Array1::linspace(0.0, 1.0, nx);
874        let y_grid = Array1::linspace(0.0, 1.0, ny);
875
876        grids.push(x_grid);
877        grids.push(y_grid);
878
879        // Convert 3D arrays to 2D arrays for PDESolution format
880        let mut values = Vec::new();
881        for u3d in result.u {
882            let mut u2d = Array2::zeros((nx, ny));
883            for i in 0..nx {
884                for j in 0..ny {
885                    u2d[[i, j]] = u3d[[i, j, 0]];
886                }
887            }
888            values.push(u2d);
889        }
890
891        // Create solver info
892        let info = PDESolverInfo {
893            num_iterations: result.num_linear_solves,
894            computation_time: result.computation_time,
895            residual_norm: None,
896            convergence_history: None,
897            method: "ADI Method".to_string(),
898        };
899
900        PDESolution {
901            grids,
902            values,
903            error_estimate: None,
904            info,
905        }
906    }
907}