scirs2_integrate/pde/
amr.rs

1//! Adaptive Mesh Refinement (AMR) for PDEs
2//!
3//! This module implements adaptive mesh refinement techniques for partial
4//! differential equations. AMR automatically refines the computational mesh
5//! in regions where high accuracy is needed while keeping coarse meshes
6//! elsewhere to maintain computational efficiency.
7//!
8//! # AMR Concepts
9//!
10//! - **Error Estimation**: Identify regions needing refinement
11//! - **Refinement Criteria**: Decide when and where to refine
12//! - **Data Transfer**: Interpolate solution between mesh levels
13//! - **Load Balancing**: Distribute refined regions across processors
14//!
15//! # Examples
16//!
17//! ```
18//! use scirs2_integrate::pde::amr::{AMRGrid, RefinementCriteria, AMRSolver};
19//!
20//! // Create AMR grid with initial coarse mesh
21//! let mut grid = AMRGrid::new(64, 64, [0.0, 1.0], [0.0, 1.0]);
22//!
23//! // Set refinement criteria based on solution gradients
24//! let criteria = RefinementCriteria::gradient_based(0.1);
25//!
26//! // Solve with adaptive refinement
27//! let solver = AMRSolver::new(grid, criteria);
28//! ```
29
30use crate::common::IntegrateFloat;
31use crate::error::{IntegrateError, IntegrateResult};
32use crate::pde::PDEResult;
33use scirs2_core::ndarray::{Array2, ArrayView2};
34use std::collections::HashMap;
35
36/// Adaptive mesh refinement grid with hierarchical structure
37#[derive(Debug, Clone)]
38pub struct AMRGrid<F: IntegrateFloat> {
39    /// Grid hierarchy levels (level 0 is coarsest)
40    levels: Vec<GridLevel<F>>,
41    /// Maximum allowed refinement level
42    max_level: usize,
43    /// Minimum allowed refinement level
44    min_level: usize,
45    /// Domain bounds
46    #[allow(dead_code)]
47    domain: ([F; 2], [F; 2]),
48    /// Current solution on all levels
49    solution: HashMap<(usize, usize, usize), F>, // (level, i, j) -> value
50}
51
52/// Single level in the AMR hierarchy
53#[derive(Debug, Clone)]
54pub struct GridLevel<F: IntegrateFloat> {
55    /// Grid level (0 = coarsest)
56    #[allow(dead_code)]
57    level: usize,
58    /// Number of cells in x direction
59    nx: usize,
60    /// Number of cells in y direction
61    ny: usize,
62    /// Grid spacing in x direction
63    dx: F,
64    /// Grid spacing in y direction
65    dy: F,
66    /// Refinement map (true = refined, false = not refined)
67    refined: Array2<bool>,
68    /// Child level information for refined cells
69    children: HashMap<(usize, usize), ChildInfo>,
70}
71
/// Location of the child patch that replaces a refined cell.
///
/// Both tuples are expressed in cell indices of the next-finer level,
/// ordered as `(x index, y index)`.
#[derive(Debug, Clone)]
pub struct ChildInfo {
    /// Index of the patch's first child cell on the finer level.
    child_start: (usize, usize),
    /// Patch extent as `(cells along x, cells along y)`.
    child_size: (usize, usize),
}
80
81/// Refinement criteria for deciding when to refine/coarsen
82#[derive(Clone)]
83pub enum RefinementCriteria<F: IntegrateFloat> {
84    /// Refine based on solution gradient magnitude
85    GradientBased { threshold: F, coarsen_threshold: F },
86    /// Refine based on second derivative (curvature)
87    CurvatureBased { threshold: F, coarsen_threshold: F },
88    /// Refine based on estimated truncation error
89    ErrorBased { threshold: F, coarsen_threshold: F },
90    // Custom refinement function (Note: Clone not supported for function pointers)
91    // Custom {
92    //     refine_fn: Box<dyn Fn(ArrayView2<F>, usize, usize) -> bool + Send + Sync>,
93    //     coarsen_fn: Box<dyn Fn(ArrayView2<F>, usize, usize) -> bool + Send + Sync>,
94    // },
95}
96
97impl<F: IntegrateFloat + std::fmt::Debug> std::fmt::Debug for RefinementCriteria<F> {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        match self {
100            Self::GradientBased {
101                threshold,
102                coarsen_threshold,
103            } => f
104                .debug_struct("GradientBased")
105                .field("threshold", &format!("{threshold:?}"))
106                .field("coarsen_threshold", &format!("{coarsen_threshold:?}"))
107                .finish(),
108            Self::CurvatureBased {
109                threshold,
110                coarsen_threshold,
111            } => f
112                .debug_struct("CurvatureBased")
113                .field("threshold", &format!("{threshold:?}"))
114                .field("coarsen_threshold", &format!("{coarsen_threshold:?}"))
115                .finish(),
116            Self::ErrorBased {
117                threshold,
118                coarsen_threshold,
119            } => f
120                .debug_struct("ErrorBased")
121                .field("threshold", &format!("{threshold:?}"))
122                .field("coarsen_threshold", &format!("{coarsen_threshold:?}"))
123                .finish(),
124        }
125    }
126}
127
128impl<F: IntegrateFloat> RefinementCriteria<F> {
129    /// Create gradient-based refinement criteria
130    pub fn gradient_based(threshold: F) -> Self {
131        Self::GradientBased {
132            threshold,
133            coarsen_threshold: threshold / F::from(4.0).unwrap(),
134        }
135    }
136
137    /// Create curvature-based refinement criteria
138    pub fn curvature_based(threshold: F) -> Self {
139        Self::CurvatureBased {
140            threshold,
141            coarsen_threshold: threshold / F::from(4.0).unwrap(),
142        }
143    }
144
145    /// Create error-based refinement criteria
146    pub fn error_based(threshold: F) -> Self {
147        Self::ErrorBased {
148            threshold,
149            coarsen_threshold: threshold / F::from(16.0).unwrap(),
150        }
151    }
152
153    /// Check if cell should be refined
154    pub fn should_refine(&self, solution: ArrayView2<F>, i: usize, j: usize) -> bool {
155        match self {
156            Self::GradientBased { threshold, .. } => {
157                Self::compute_gradient_magnitude(solution, i, j) > *threshold
158            }
159            Self::CurvatureBased { threshold, .. } => {
160                Self::compute_curvature(solution, i, j) > *threshold
161            }
162            Self::ErrorBased { threshold, .. } => {
163                self.estimate_truncation_error(solution, i, j) > *threshold
164            } // Self::Custom { refine_fn, .. } => refine_fn(solution, i, j),
165        }
166    }
167
168    /// Check if cell should be coarsened
169    pub fn should_coarsen(&self, solution: ArrayView2<F>, i: usize, j: usize) -> bool {
170        match self {
171            Self::GradientBased {
172                coarsen_threshold, ..
173            } => Self::compute_gradient_magnitude(solution, i, j) < *coarsen_threshold,
174            Self::CurvatureBased {
175                coarsen_threshold, ..
176            } => Self::compute_curvature(solution, i, j) < *coarsen_threshold,
177            Self::ErrorBased {
178                coarsen_threshold, ..
179            } => self.estimate_truncation_error(solution, i, j) < *coarsen_threshold,
180            // Self::Custom { coarsen_fn, .. } => coarsen_fn(solution, i, j),
181        }
182    }
183
184    /// Compute gradient magnitude at cell (i, j)
185    fn compute_gradient_magnitude(solution: ArrayView2<F>, i: usize, j: usize) -> F {
186        let (nx, ny) = solution.dim();
187
188        // Compute gradients using centered differences where possible
189        let grad_x = if i == 0 {
190            solution[[1, j]] - solution[[0, j]]
191        } else if i == nx - 1 {
192            solution[[nx - 1, j]] - solution[[nx - 2, j]]
193        } else {
194            (solution[[i + 1, j]] - solution[[i - 1, j]]) / F::from(2.0).unwrap()
195        };
196
197        let grad_y = if j == 0 {
198            solution[[i, 1]] - solution[[i, 0]]
199        } else if j == ny - 1 {
200            solution[[i, ny - 1]] - solution[[i, ny - 2]]
201        } else {
202            (solution[[i, j + 1]] - solution[[i, j - 1]]) / F::from(2.0).unwrap()
203        };
204
205        (grad_x * grad_x + grad_y * grad_y).sqrt()
206    }
207
208    /// Compute curvature (second derivative magnitude) at cell (i, j)
209    fn compute_curvature(solution: ArrayView2<F>, i: usize, j: usize) -> F {
210        let (nx, ny) = solution.dim();
211
212        if i == 0 || i == nx - 1 || j == 0 || j == ny - 1 {
213            return F::zero(); // Can't compute curvature at boundaries
214        }
215
216        // Second derivatives using centered differences
217        let d2_dx2 =
218            solution[[i + 1, j]] - F::from(2.0).unwrap() * solution[[i, j]] + solution[[i - 1, j]];
219        let d2_dy2 =
220            solution[[i, j + 1]] - F::from(2.0).unwrap() * solution[[i, j]] + solution[[i, j - 1]];
221        let d2_dxdy =
222            (solution[[i + 1, j + 1]] - solution[[i + 1, j - 1]] - solution[[i - 1, j + 1]]
223                + solution[[i - 1, j - 1]])
224                / F::from(4.0).unwrap();
225
226        // Frobenius norm of Hessian matrix
227        (d2_dx2 * d2_dx2 + d2_dy2 * d2_dy2 + F::from(2.0).unwrap() * d2_dxdy * d2_dxdy).sqrt()
228    }
229
230    /// Estimate local truncation error
231    fn estimate_truncation_error(&self, solution: ArrayView2<F>, i: usize, j: usize) -> F {
232        // Simple Richardson extrapolation-based error estimate
233        // Compare solution at current resolution vs estimated higher-order solution
234        Self::compute_curvature(solution, i, j) / F::from(12.0).unwrap() // h² error estimate
235    }
236}
237
238impl<F: IntegrateFloat> AMRGrid<F> {
239    /// Create new AMR grid with initial coarse level
240    pub fn new(_nx: usize, ny: usize, domain_x: [F; 2], domainy: [F; 2]) -> Self {
241        let dx = (domain_x[1] - domain_x[0]) / F::from(_nx).unwrap();
242        let dy = (domainy[1] - domainy[0]) / F::from(ny).unwrap();
243
244        let coarse_level = GridLevel {
245            level: 0,
246            nx: _nx,
247            ny,
248            dx,
249            dy,
250            refined: Array2::from_elem((_nx, ny), false),
251            children: HashMap::new(),
252        };
253
254        Self {
255            levels: vec![coarse_level],
256            max_level: 5, // Default maximum 5 levels
257            min_level: 0,
258            domain: (domain_x, domainy),
259            solution: HashMap::new(),
260        }
261    }
262
263    /// Set maximum refinement level
264    pub fn set_max_level(&mut self, maxlevel: usize) {
265        self.max_level = maxlevel;
266    }
267
268    /// Refine grid based on criteria
269    pub fn refine(&mut self, criteria: &RefinementCriteria<F>) -> IntegrateResult<usize> {
270        let mut total_refined = 0;
271
272        // Process each level from coarse to fine
273        for level in 0..self.levels.len() {
274            if level >= self.max_level {
275                break;
276            }
277
278            let current_level = &self.levels[level];
279            let solution_level = self.get_solution_array(level)?;
280
281            let mut cells_to_refine = Vec::new();
282
283            // Find cells that need refinement
284            for i in 0..current_level.nx {
285                for j in 0..current_level.ny {
286                    if !current_level.refined[[i, j]]
287                        && criteria.should_refine(solution_level.view(), i, j)
288                    {
289                        cells_to_refine.push((i, j));
290                    }
291                }
292            }
293
294            // Refine selected cells
295            for (i, j) in cells_to_refine {
296                self.refine_cell(level, i, j)?;
297                total_refined += 1;
298            }
299        }
300
301        Ok(total_refined)
302    }
303
304    /// Coarsen grid based on criteria
305    pub fn coarsen(&mut self, criteria: &RefinementCriteria<F>) -> IntegrateResult<usize> {
306        let mut total_coarsened = 0;
307
308        // Process from fine to coarse levels
309        for level in (self.min_level + 1..self.levels.len()).rev() {
310            let current_level = &self.levels[level];
311            let solution_level = self.get_solution_array(level)?;
312
313            let mut cells_to_coarsen = Vec::new();
314
315            // Find cells that can be coarsened
316            for i in 0..current_level.nx {
317                for j in 0..current_level.ny {
318                    if criteria.should_coarsen(solution_level.view(), i, j) {
319                        cells_to_coarsen.push((i, j));
320                    }
321                }
322            }
323
324            // Coarsen selected cells (group into parent cells)
325            let coarsened = self.coarsen_cells(level, cells_to_coarsen)?;
326            total_coarsened += coarsened;
327        }
328
329        Ok(total_coarsened)
330    }
331
332    /// Refine a single cell
333    fn refine_cell(&mut self, level: usize, i: usize, j: usize) -> IntegrateResult<()> {
334        if level >= self.max_level {
335            return Err(IntegrateError::ValueError(
336                "Cannot refine beyond maximum level".to_string(),
337            ));
338        }
339
340        // Ensure next level exists
341        while self.levels.len() <= level + 1 {
342            let parent_level = &self.levels[level];
343            let child_nx = parent_level.nx * 2;
344            let child_ny = parent_level.ny * 2;
345            let child_dx = parent_level.dx / F::from(2.0).unwrap();
346            let child_dy = parent_level.dy / F::from(2.0).unwrap();
347
348            let child_level = GridLevel {
349                level: level + 1,
350                nx: child_nx,
351                ny: child_ny,
352                dx: child_dx,
353                dy: child_dy,
354                refined: Array2::from_elem((child_nx, child_ny), false),
355                children: HashMap::new(),
356            };
357
358            self.levels.push(child_level);
359        }
360
361        // Mark cell as refined
362        self.levels[level].refined[[i, j]] = true;
363
364        // Create child information
365        let child_start = (i * 2, j * 2);
366        let child_size = (2, 2);
367
368        self.levels[level].children.insert(
369            (i, j),
370            ChildInfo {
371                child_start,
372                child_size,
373            },
374        );
375
376        // Interpolate solution to child cells
377        self.interpolate_to_children(level, i, j)?;
378
379        Ok(())
380    }
381
382    /// Coarsen cells by grouping child cells back to parent cells
383    fn coarsen_cells(
384        &mut self,
385        level: usize,
386        cells: Vec<(usize, usize)>,
387    ) -> IntegrateResult<usize> {
388        if level == 0 {
389            return Ok(0); // Can't coarsen the coarsest level
390        }
391
392        let mut coarsened_count = 0;
393        let parent_level = level - 1;
394
395        // Group cells into parent cell candidates
396        let mut parent_candidates = HashMap::new();
397
398        for (i, j) in cells {
399            // Find parent cell coordinates
400            let parent_i = i / 2;
401            let parent_j = j / 2;
402            let child_offset = (i % 2, j % 2);
403
404            parent_candidates
405                .entry((parent_i, parent_j))
406                .or_insert_with(Vec::new)
407                .push((i, j, child_offset));
408        }
409
410        // Process each parent cell candidate
411        for ((parent_i, parent_j), children) in parent_candidates {
412            // Check if all 4 children are ready for coarsening
413            if children.len() == 4 {
414                // Verify the parent cell is actually refined
415                if parent_level < self.levels.len()
416                    && parent_i < self.levels[parent_level].nx
417                    && parent_j < self.levels[parent_level].ny
418                    && self.levels[parent_level].refined[[parent_i, parent_j]]
419                {
420                    // Average/restrict solution values from children to parent
421                    let mut averaged_value = F::zero();
422                    let mut valid_children = 0;
423
424                    for (child_i, child_j_, _) in &children {
425                        if let Some(&child_value) = self.solution.get(&(level, *child_i, *child_j_))
426                        {
427                            averaged_value += child_value;
428                            valid_children += 1;
429                        }
430                    }
431
432                    if valid_children > 0 {
433                        averaged_value /= F::from(valid_children).unwrap();
434
435                        // Store averaged value in parent cell
436                        self.solution
437                            .insert((parent_level, parent_i, parent_j), averaged_value);
438
439                        // Remove child values
440                        for (child_i, child_j_, _) in &children {
441                            self.solution.remove(&(level, *child_i, *child_j_));
442                        }
443
444                        // Mark parent as not refined
445                        self.levels[parent_level].refined[[parent_i, parent_j]] = false;
446
447                        // Remove child information
448                        self.levels[parent_level]
449                            .children
450                            .remove(&(parent_i, parent_j));
451
452                        coarsened_count += 1;
453                    }
454                }
455            }
456        }
457
458        // Clean up empty levels if they exist
459        self.cleanup_empty_levels();
460
461        Ok(coarsened_count)
462    }
463
464    /// Remove empty refinement levels from the hierarchy
465    fn cleanup_empty_levels(&mut self) {
466        // Find the highest level with any refined cells or solution data
467        let mut max_active_level = 0;
468
469        for level in 0..self.levels.len() {
470            let has_refined_cells = self.levels[level].refined.iter().any(|&x| x);
471            let has_solution_data = self.solution.keys().any(|(l__, _, _)| *l__ == level);
472
473            if has_refined_cells || has_solution_data {
474                max_active_level = level;
475            }
476        }
477
478        // Keep only active levels plus one extra for potential future refinement
479        let keep_levels = (max_active_level + 2).min(self.levels.len());
480        self.levels.truncate(keep_levels);
481    }
482
483    /// Interpolate solution from parent cell to child cells
484    fn interpolate_to_children(&mut self, level: usize, i: usize, j: usize) -> IntegrateResult<()> {
485        let parent_value = self
486            .solution
487            .get(&(level, i, j))
488            .copied()
489            .unwrap_or(F::zero());
490
491        if let Some(child_info) = self.levels[level].children.get(&(i, j)) {
492            let (start_i, start_j) = child_info.child_start;
493            let (size_i, size_j) = child_info.child_size;
494
495            // Simple constant interpolation (could be improved with higher-order)
496            for ci in 0..size_i {
497                for cj in 0..size_j {
498                    self.solution
499                        .insert((level + 1, start_i + ci, start_j + cj), parent_value);
500                }
501            }
502        }
503
504        Ok(())
505    }
506
507    /// Get solution as Array2 for a specific level
508    fn get_solution_array(&self, level: usize) -> IntegrateResult<Array2<F>> {
509        if level >= self.levels.len() {
510            return Err(IntegrateError::ValueError(
511                "Level does not exist".to_string(),
512            ));
513        }
514
515        let grid_level = &self.levels[level];
516        let mut solution = Array2::zeros((grid_level.nx, grid_level.ny));
517
518        for i in 0..grid_level.nx {
519            for j in 0..grid_level.ny {
520                if let Some(&value) = self.solution.get(&(level, i, j)) {
521                    solution[[i, j]] = value;
522                }
523            }
524        }
525
526        Ok(solution)
527    }
528
529    /// Get grid information for a specific level
530    pub fn get_level_info(&self, level: usize) -> Option<&GridLevel<F>> {
531        self.levels.get(level)
532    }
533
534    /// Get total number of cells across all levels
535    pub fn total_cells(&self) -> usize {
536        self.levels.iter().map(|level| level.nx * level.ny).sum()
537    }
538
539    /// Get refinement efficiency (fraction of cells refined)
540    pub fn refinement_efficiency(&self) -> f64 {
541        let total_refined = self
542            .levels
543            .iter()
544            .map(|level| level.refined.iter().filter(|&&x| x).count())
545            .sum::<usize>();
546
547        let total_possible = self.levels[0].nx * self.levels[0].ny;
548
549        total_refined as f64 / total_possible as f64
550    }
551}
552
553/// AMR-enhanced PDE solver
554pub struct AMRSolver<F: IntegrateFloat> {
555    /// AMR grid
556    grid: AMRGrid<F>,
557    /// Refinement criteria
558    criteria: RefinementCriteria<F>,
559    /// Number of AMR cycles performed
560    amr_cycles: usize,
561    /// Maximum AMR cycles per solve
562    max_amr_cycles: usize,
563}
564
565impl<F: IntegrateFloat> AMRSolver<F> {
566    /// Create new AMR solver
567    pub fn new(grid: AMRGrid<F>, criteria: RefinementCriteria<F>) -> Self {
568        Self {
569            grid,
570            criteria,
571            amr_cycles: 0,
572            max_amr_cycles: 5,
573        }
574    }
575
576    /// Set maximum AMR cycles
577    pub fn set_max_amr_cycles(&mut self, maxcycles: usize) {
578        self.max_amr_cycles = maxcycles;
579    }
580
581    /// Solve PDE with adaptive mesh refinement
582    pub fn solve_adaptive<ProblemFn>(
583        &mut self,
584        problem: ProblemFn,
585        initial_solution: Array2<F>,
586    ) -> PDEResult<Array2<F>>
587    where
588        ProblemFn: Fn(&AMRGrid<F>, ArrayView2<F>) -> PDEResult<Array2<F>>,
589    {
590        // Initialize _solution on coarse grid
591        let mut current_solution = initial_solution;
592
593        for cycle in 0..self.max_amr_cycles {
594            self.amr_cycles = cycle;
595
596            // Store current _solution in grid
597            self.store_solution_in_grid(&current_solution)?;
598
599            // Refine grid based on current _solution
600            let refined_cells = self.grid.refine(&self.criteria)?;
601
602            // If no refinement occurred, we're done
603            if refined_cells == 0 {
604                break;
605            }
606
607            // Solve on refined grid
608            current_solution = problem(&self.grid, current_solution.view())?;
609
610            // Optional: coarsen grid where possible
611            let _coarsened_cells = self.grid.coarsen(&self.criteria)?;
612        }
613
614        Ok(current_solution)
615    }
616
617    /// Store solution array in the AMR grid structure
618    fn store_solution_in_grid(&mut self, solution: &Array2<F>) -> IntegrateResult<()> {
619        let (nx, ny) = solution.dim();
620
621        // Clear existing solution
622        self.grid.solution.clear();
623
624        // Store solution for level 0
625        for i in 0..nx {
626            for j in 0..ny {
627                self.grid.solution.insert((0, i, j), solution[[i, j]]);
628            }
629        }
630
631        Ok(())
632    }
633
634    /// Get grid statistics
635    pub fn grid_statistics(&self) -> AMRStatistics {
636        AMRStatistics {
637            num_levels: self.grid.levels.len(),
638            total_cells: self.grid.total_cells(),
639            refinement_efficiency: self.grid.refinement_efficiency(),
640            amr_cycles: self.amr_cycles,
641        }
642    }
643}
644
/// Summary of an AMR grid's state, as returned by `AMRSolver::grid_statistics`.
#[derive(Debug, Clone, PartialEq)]
pub struct AMRStatistics {
    /// Number of levels currently in the hierarchy.
    pub num_levels: usize,
    /// Total number of cells summed over all levels.
    pub total_cells: usize,
    /// Refined cells (summed over all levels) divided by the number of
    /// coarse level-0 cells.
    ///
    /// This is 0.0 when nothing is refined. It is not bounded by 1.0 and can
    /// exceed it once refinement is nested over several levels.
    pub refinement_efficiency: f64,
    /// Number of AMR cycles performed.
    pub amr_cycles: usize,
}
657
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// u(i, j) = 0.1 * (i² + j²) on a 5×5 grid.
    fn quadratic_field() -> Array2<f64> {
        Array2::from_shape_fn((5, 5), |(i, j)| (i * i + j * j) as f64 * 0.1)
    }

    #[test]
    fn test_amr_grid_creation() {
        let grid: AMRGrid<f64> = AMRGrid::new(32, 32, [0.0, 1.0], [0.0, 1.0]);

        assert_eq!(grid.levels.len(), 1);
        assert_eq!(grid.levels[0].nx, 32);
        assert_eq!(grid.levels[0].ny, 32);
        assert_abs_diff_eq!(grid.levels[0].dx, 1.0 / 32.0);
        assert_abs_diff_eq!(grid.levels[0].dy, 1.0 / 32.0);
    }

    #[test]
    fn test_refinement_criteria() {
        let solution = quadratic_field();
        let criteria = RefinementCriteria::gradient_based(0.5);

        // Interior cell (2, 2): centred differences give grad = (0.4, 0.4),
        // so |grad| ≈ 0.566, which is above the refinement threshold 0.5.
        assert!(criteria.should_refine(solution.view(), 2, 2));

        // Corner (0, 0): one-sided differences give grad = (0.1, 0.1), so
        // |grad| ≈ 0.141. That is below 0.5, so no refinement, and above the
        // coarsening threshold 0.5 / 4 = 0.125, so no coarsening either.
        assert!(!criteria.should_refine(solution.view(), 0, 0));
        assert!(!criteria.should_coarsen(solution.view(), 0, 0));
    }

    #[test]
    fn test_curvature_criteria() {
        let solution = quadratic_field();
        let criteria = RefinementCriteria::curvature_based(0.25);

        // Interior Hessian is diag(0.2, 0.2), so its Frobenius norm ≈ 0.283 > 0.25.
        assert!(criteria.should_refine(solution.view(), 2, 2));

        // Curvature is reported as zero on the boundary.
        assert!(!criteria.should_refine(solution.view(), 0, 2));
        assert!(criteria.should_coarsen(solution.view(), 0, 2));
    }

    #[test]
    fn test_cell_refinement() {
        let mut grid: AMRGrid<f64> = AMRGrid::new(4, 4, [0.0, 1.0], [0.0, 1.0]);

        grid.solution.insert((0, 1, 1), 1.0);

        assert!(grid.refine_cell(0, 1, 1).is_ok());

        assert!(grid.levels[0].refined[[1, 1]]);
        assert_eq!(grid.levels.len(), 2); // The next level is created on demand.

        assert!(grid.levels[0].children.contains_key(&(1, 1)));
    }

    #[test]
    fn test_amr_solver() {
        let grid: AMRGrid<f64> = AMRGrid::new(8, 8, [0.0, 1.0], [0.0, 1.0]);
        let criteria = RefinementCriteria::gradient_based(0.1);
        let mut solver = AMRSolver::new(grid, criteria);

        let initial = Array2::from_elem((8, 8), 0.5);

        // Identity "solver": returns its input unchanged.
        let problem = |_grid: &AMRGrid<f64>, solution: ArrayView2<f64>| -> PDEResult<Array2<f64>> {
            Ok(solution.to_owned())
        };

        let result = solver.solve_adaptive(problem, initial);
        assert!(result.is_ok());

        let stats = solver.grid_statistics();
        assert!(stats.num_levels >= 1);
        assert!(stats.total_cells >= 64); // At least the initial 8×8 grid.
    }

    #[test]
    fn test_amr_coarsening() {
        let mut grid: AMRGrid<f64> = AMRGrid::new(4, 4, [0.0, 1.0], [0.0, 1.0]);

        grid.solution.insert((0, 1, 1), 1.0);

        assert!(grid.refine_cell(0, 1, 1).is_ok());
        assert_eq!(grid.levels.len(), 2);

        // Give the four children distinct values.
        grid.solution.insert((1, 2, 2), 0.8);
        grid.solution.insert((1, 2, 3), 0.9);
        grid.solution.insert((1, 3, 2), 1.1);
        grid.solution.insert((1, 3, 3), 1.2);

        let cells_to_coarsen = vec![(2, 2), (2, 3), (3, 2), (3, 3)];
        let coarsened = grid.coarsen_cells(1, cells_to_coarsen).unwrap();

        // One parent cell (containing the four children) was coarsened.
        assert_eq!(coarsened, 1);
        assert!(!grid.levels[0].refined[[1, 1]]);

        // Parent holds the mean (0.8 + 0.9 + 1.1 + 1.2) / 4 = 1.0, up to rounding.
        let parent_value = *grid
            .solution
            .get(&(0, 1, 1))
            .expect("coarsening should store the parent value");
        assert_abs_diff_eq!(parent_value, 1.0, epsilon = 1e-12);

        // Child values are removed.
        assert_eq!(grid.solution.get(&(1, 2, 2)), None);
        assert_eq!(grid.solution.get(&(1, 2, 3)), None);
        assert_eq!(grid.solution.get(&(1, 3, 2)), None);
        assert_eq!(grid.solution.get(&(1, 3, 3)), None);
    }
}