1use crate::common::IntegrateFloat;
31use crate::error::{IntegrateError, IntegrateResult};
32use crate::pde::PDEResult;
33use scirs2_core::ndarray::{Array2, ArrayView2};
34use std::collections::HashMap;
35
36#[derive(Debug, Clone)]
38pub struct AMRGrid<F: IntegrateFloat> {
39 levels: Vec<GridLevel<F>>,
41 max_level: usize,
43 min_level: usize,
45 #[allow(dead_code)]
47 domain: ([F; 2], [F; 2]),
48 solution: HashMap<(usize, usize, usize), F>, }
51
52#[derive(Debug, Clone)]
54pub struct GridLevel<F: IntegrateFloat> {
55 #[allow(dead_code)]
57 level: usize,
58 nx: usize,
60 ny: usize,
62 dx: F,
64 dy: F,
66 refined: Array2<bool>,
68 children: HashMap<(usize, usize), ChildInfo>,
70}
71
/// Location and extent of the block of child cells that a refined cell owns
/// on the next finer level.
///
/// This is a small plain-data record, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChildInfo {
    /// Index `(i, j)` of the first child cell on the finer level.
    child_start: (usize, usize),
    /// Number of child cells `(ni, nj)` along each axis.
    child_size: (usize, usize),
}
80
81#[derive(Clone)]
83pub enum RefinementCriteria<F: IntegrateFloat> {
84 GradientBased { threshold: F, coarsen_threshold: F },
86 CurvatureBased { threshold: F, coarsen_threshold: F },
88 ErrorBased { threshold: F, coarsen_threshold: F },
90 }
96
97impl<F: IntegrateFloat + std::fmt::Debug> std::fmt::Debug for RefinementCriteria<F> {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 match self {
100 Self::GradientBased {
101 threshold,
102 coarsen_threshold,
103 } => f
104 .debug_struct("GradientBased")
105 .field("threshold", &format!("{threshold:?}"))
106 .field("coarsen_threshold", &format!("{coarsen_threshold:?}"))
107 .finish(),
108 Self::CurvatureBased {
109 threshold,
110 coarsen_threshold,
111 } => f
112 .debug_struct("CurvatureBased")
113 .field("threshold", &format!("{threshold:?}"))
114 .field("coarsen_threshold", &format!("{coarsen_threshold:?}"))
115 .finish(),
116 Self::ErrorBased {
117 threshold,
118 coarsen_threshold,
119 } => f
120 .debug_struct("ErrorBased")
121 .field("threshold", &format!("{threshold:?}"))
122 .field("coarsen_threshold", &format!("{coarsen_threshold:?}"))
123 .finish(),
124 }
125 }
126}
127
128impl<F: IntegrateFloat> RefinementCriteria<F> {
129 pub fn gradient_based(threshold: F) -> Self {
131 Self::GradientBased {
132 threshold,
133 coarsen_threshold: threshold / F::from(4.0).unwrap(),
134 }
135 }
136
137 pub fn curvature_based(threshold: F) -> Self {
139 Self::CurvatureBased {
140 threshold,
141 coarsen_threshold: threshold / F::from(4.0).unwrap(),
142 }
143 }
144
145 pub fn error_based(threshold: F) -> Self {
147 Self::ErrorBased {
148 threshold,
149 coarsen_threshold: threshold / F::from(16.0).unwrap(),
150 }
151 }
152
153 pub fn should_refine(&self, solution: ArrayView2<F>, i: usize, j: usize) -> bool {
155 match self {
156 Self::GradientBased { threshold, .. } => {
157 Self::compute_gradient_magnitude(solution, i, j) > *threshold
158 }
159 Self::CurvatureBased { threshold, .. } => {
160 Self::compute_curvature(solution, i, j) > *threshold
161 }
162 Self::ErrorBased { threshold, .. } => {
163 self.estimate_truncation_error(solution, i, j) > *threshold
164 } }
166 }
167
168 pub fn should_coarsen(&self, solution: ArrayView2<F>, i: usize, j: usize) -> bool {
170 match self {
171 Self::GradientBased {
172 coarsen_threshold, ..
173 } => Self::compute_gradient_magnitude(solution, i, j) < *coarsen_threshold,
174 Self::CurvatureBased {
175 coarsen_threshold, ..
176 } => Self::compute_curvature(solution, i, j) < *coarsen_threshold,
177 Self::ErrorBased {
178 coarsen_threshold, ..
179 } => self.estimate_truncation_error(solution, i, j) < *coarsen_threshold,
180 }
182 }
183
184 fn compute_gradient_magnitude(solution: ArrayView2<F>, i: usize, j: usize) -> F {
186 let (nx, ny) = solution.dim();
187
188 let grad_x = if i == 0 {
190 solution[[1, j]] - solution[[0, j]]
191 } else if i == nx - 1 {
192 solution[[nx - 1, j]] - solution[[nx - 2, j]]
193 } else {
194 (solution[[i + 1, j]] - solution[[i - 1, j]]) / F::from(2.0).unwrap()
195 };
196
197 let grad_y = if j == 0 {
198 solution[[i, 1]] - solution[[i, 0]]
199 } else if j == ny - 1 {
200 solution[[i, ny - 1]] - solution[[i, ny - 2]]
201 } else {
202 (solution[[i, j + 1]] - solution[[i, j - 1]]) / F::from(2.0).unwrap()
203 };
204
205 (grad_x * grad_x + grad_y * grad_y).sqrt()
206 }
207
208 fn compute_curvature(solution: ArrayView2<F>, i: usize, j: usize) -> F {
210 let (nx, ny) = solution.dim();
211
212 if i == 0 || i == nx - 1 || j == 0 || j == ny - 1 {
213 return F::zero(); }
215
216 let d2_dx2 =
218 solution[[i + 1, j]] - F::from(2.0).unwrap() * solution[[i, j]] + solution[[i - 1, j]];
219 let d2_dy2 =
220 solution[[i, j + 1]] - F::from(2.0).unwrap() * solution[[i, j]] + solution[[i, j - 1]];
221 let d2_dxdy =
222 (solution[[i + 1, j + 1]] - solution[[i + 1, j - 1]] - solution[[i - 1, j + 1]]
223 + solution[[i - 1, j - 1]])
224 / F::from(4.0).unwrap();
225
226 (d2_dx2 * d2_dx2 + d2_dy2 * d2_dy2 + F::from(2.0).unwrap() * d2_dxdy * d2_dxdy).sqrt()
228 }
229
230 fn estimate_truncation_error(&self, solution: ArrayView2<F>, i: usize, j: usize) -> F {
232 Self::compute_curvature(solution, i, j) / F::from(12.0).unwrap() }
236}
237
238impl<F: IntegrateFloat> AMRGrid<F> {
239 pub fn new(_nx: usize, ny: usize, domain_x: [F; 2], domainy: [F; 2]) -> Self {
241 let dx = (domain_x[1] - domain_x[0]) / F::from(_nx).unwrap();
242 let dy = (domainy[1] - domainy[0]) / F::from(ny).unwrap();
243
244 let coarse_level = GridLevel {
245 level: 0,
246 nx: _nx,
247 ny,
248 dx,
249 dy,
250 refined: Array2::from_elem((_nx, ny), false),
251 children: HashMap::new(),
252 };
253
254 Self {
255 levels: vec![coarse_level],
256 max_level: 5, min_level: 0,
258 domain: (domain_x, domainy),
259 solution: HashMap::new(),
260 }
261 }
262
263 pub fn set_max_level(&mut self, maxlevel: usize) {
265 self.max_level = maxlevel;
266 }
267
268 pub fn refine(&mut self, criteria: &RefinementCriteria<F>) -> IntegrateResult<usize> {
270 let mut total_refined = 0;
271
272 for level in 0..self.levels.len() {
274 if level >= self.max_level {
275 break;
276 }
277
278 let current_level = &self.levels[level];
279 let solution_level = self.get_solution_array(level)?;
280
281 let mut cells_to_refine = Vec::new();
282
283 for i in 0..current_level.nx {
285 for j in 0..current_level.ny {
286 if !current_level.refined[[i, j]]
287 && criteria.should_refine(solution_level.view(), i, j)
288 {
289 cells_to_refine.push((i, j));
290 }
291 }
292 }
293
294 for (i, j) in cells_to_refine {
296 self.refine_cell(level, i, j)?;
297 total_refined += 1;
298 }
299 }
300
301 Ok(total_refined)
302 }
303
304 pub fn coarsen(&mut self, criteria: &RefinementCriteria<F>) -> IntegrateResult<usize> {
306 let mut total_coarsened = 0;
307
308 for level in (self.min_level + 1..self.levels.len()).rev() {
310 let current_level = &self.levels[level];
311 let solution_level = self.get_solution_array(level)?;
312
313 let mut cells_to_coarsen = Vec::new();
314
315 for i in 0..current_level.nx {
317 for j in 0..current_level.ny {
318 if criteria.should_coarsen(solution_level.view(), i, j) {
319 cells_to_coarsen.push((i, j));
320 }
321 }
322 }
323
324 let coarsened = self.coarsen_cells(level, cells_to_coarsen)?;
326 total_coarsened += coarsened;
327 }
328
329 Ok(total_coarsened)
330 }
331
332 fn refine_cell(&mut self, level: usize, i: usize, j: usize) -> IntegrateResult<()> {
334 if level >= self.max_level {
335 return Err(IntegrateError::ValueError(
336 "Cannot refine beyond maximum level".to_string(),
337 ));
338 }
339
340 while self.levels.len() <= level + 1 {
342 let parent_level = &self.levels[level];
343 let child_nx = parent_level.nx * 2;
344 let child_ny = parent_level.ny * 2;
345 let child_dx = parent_level.dx / F::from(2.0).unwrap();
346 let child_dy = parent_level.dy / F::from(2.0).unwrap();
347
348 let child_level = GridLevel {
349 level: level + 1,
350 nx: child_nx,
351 ny: child_ny,
352 dx: child_dx,
353 dy: child_dy,
354 refined: Array2::from_elem((child_nx, child_ny), false),
355 children: HashMap::new(),
356 };
357
358 self.levels.push(child_level);
359 }
360
361 self.levels[level].refined[[i, j]] = true;
363
364 let child_start = (i * 2, j * 2);
366 let child_size = (2, 2);
367
368 self.levels[level].children.insert(
369 (i, j),
370 ChildInfo {
371 child_start,
372 child_size,
373 },
374 );
375
376 self.interpolate_to_children(level, i, j)?;
378
379 Ok(())
380 }
381
382 fn coarsen_cells(
384 &mut self,
385 level: usize,
386 cells: Vec<(usize, usize)>,
387 ) -> IntegrateResult<usize> {
388 if level == 0 {
389 return Ok(0); }
391
392 let mut coarsened_count = 0;
393 let parent_level = level - 1;
394
395 let mut parent_candidates = HashMap::new();
397
398 for (i, j) in cells {
399 let parent_i = i / 2;
401 let parent_j = j / 2;
402 let child_offset = (i % 2, j % 2);
403
404 parent_candidates
405 .entry((parent_i, parent_j))
406 .or_insert_with(Vec::new)
407 .push((i, j, child_offset));
408 }
409
410 for ((parent_i, parent_j), children) in parent_candidates {
412 if children.len() == 4 {
414 if parent_level < self.levels.len()
416 && parent_i < self.levels[parent_level].nx
417 && parent_j < self.levels[parent_level].ny
418 && self.levels[parent_level].refined[[parent_i, parent_j]]
419 {
420 let mut averaged_value = F::zero();
422 let mut valid_children = 0;
423
424 for (child_i, child_j_, _) in &children {
425 if let Some(&child_value) = self.solution.get(&(level, *child_i, *child_j_))
426 {
427 averaged_value += child_value;
428 valid_children += 1;
429 }
430 }
431
432 if valid_children > 0 {
433 averaged_value /= F::from(valid_children).unwrap();
434
435 self.solution
437 .insert((parent_level, parent_i, parent_j), averaged_value);
438
439 for (child_i, child_j_, _) in &children {
441 self.solution.remove(&(level, *child_i, *child_j_));
442 }
443
444 self.levels[parent_level].refined[[parent_i, parent_j]] = false;
446
447 self.levels[parent_level]
449 .children
450 .remove(&(parent_i, parent_j));
451
452 coarsened_count += 1;
453 }
454 }
455 }
456 }
457
458 self.cleanup_empty_levels();
460
461 Ok(coarsened_count)
462 }
463
464 fn cleanup_empty_levels(&mut self) {
466 let mut max_active_level = 0;
468
469 for level in 0..self.levels.len() {
470 let has_refined_cells = self.levels[level].refined.iter().any(|&x| x);
471 let has_solution_data = self.solution.keys().any(|(l__, _, _)| *l__ == level);
472
473 if has_refined_cells || has_solution_data {
474 max_active_level = level;
475 }
476 }
477
478 let keep_levels = (max_active_level + 2).min(self.levels.len());
480 self.levels.truncate(keep_levels);
481 }
482
483 fn interpolate_to_children(&mut self, level: usize, i: usize, j: usize) -> IntegrateResult<()> {
485 let parent_value = self
486 .solution
487 .get(&(level, i, j))
488 .copied()
489 .unwrap_or(F::zero());
490
491 if let Some(child_info) = self.levels[level].children.get(&(i, j)) {
492 let (start_i, start_j) = child_info.child_start;
493 let (size_i, size_j) = child_info.child_size;
494
495 for ci in 0..size_i {
497 for cj in 0..size_j {
498 self.solution
499 .insert((level + 1, start_i + ci, start_j + cj), parent_value);
500 }
501 }
502 }
503
504 Ok(())
505 }
506
507 fn get_solution_array(&self, level: usize) -> IntegrateResult<Array2<F>> {
509 if level >= self.levels.len() {
510 return Err(IntegrateError::ValueError(
511 "Level does not exist".to_string(),
512 ));
513 }
514
515 let grid_level = &self.levels[level];
516 let mut solution = Array2::zeros((grid_level.nx, grid_level.ny));
517
518 for i in 0..grid_level.nx {
519 for j in 0..grid_level.ny {
520 if let Some(&value) = self.solution.get(&(level, i, j)) {
521 solution[[i, j]] = value;
522 }
523 }
524 }
525
526 Ok(solution)
527 }
528
529 pub fn get_level_info(&self, level: usize) -> Option<&GridLevel<F>> {
531 self.levels.get(level)
532 }
533
534 pub fn total_cells(&self) -> usize {
536 self.levels.iter().map(|level| level.nx * level.ny).sum()
537 }
538
539 pub fn refinement_efficiency(&self) -> f64 {
541 let total_refined = self
542 .levels
543 .iter()
544 .map(|level| level.refined.iter().filter(|&&x| x).count())
545 .sum::<usize>();
546
547 let total_possible = self.levels[0].nx * self.levels[0].ny;
548
549 total_refined as f64 / total_possible as f64
550 }
551}
552
553pub struct AMRSolver<F: IntegrateFloat> {
555 grid: AMRGrid<F>,
557 criteria: RefinementCriteria<F>,
559 amr_cycles: usize,
561 max_amr_cycles: usize,
563}
564
565impl<F: IntegrateFloat> AMRSolver<F> {
566 pub fn new(grid: AMRGrid<F>, criteria: RefinementCriteria<F>) -> Self {
568 Self {
569 grid,
570 criteria,
571 amr_cycles: 0,
572 max_amr_cycles: 5,
573 }
574 }
575
576 pub fn set_max_amr_cycles(&mut self, maxcycles: usize) {
578 self.max_amr_cycles = maxcycles;
579 }
580
581 pub fn solve_adaptive<ProblemFn>(
583 &mut self,
584 problem: ProblemFn,
585 initial_solution: Array2<F>,
586 ) -> PDEResult<Array2<F>>
587 where
588 ProblemFn: Fn(&AMRGrid<F>, ArrayView2<F>) -> PDEResult<Array2<F>>,
589 {
590 let mut current_solution = initial_solution;
592
593 for cycle in 0..self.max_amr_cycles {
594 self.amr_cycles = cycle;
595
596 self.store_solution_in_grid(¤t_solution)?;
598
599 let refined_cells = self.grid.refine(&self.criteria)?;
601
602 if refined_cells == 0 {
604 break;
605 }
606
607 current_solution = problem(&self.grid, current_solution.view())?;
609
610 let _coarsened_cells = self.grid.coarsen(&self.criteria)?;
612 }
613
614 Ok(current_solution)
615 }
616
617 fn store_solution_in_grid(&mut self, solution: &Array2<F>) -> IntegrateResult<()> {
619 let (nx, ny) = solution.dim();
620
621 self.grid.solution.clear();
623
624 for i in 0..nx {
626 for j in 0..ny {
627 self.grid.solution.insert((0, i, j), solution[[i, j]]);
628 }
629 }
630
631 Ok(())
632 }
633
634 pub fn grid_statistics(&self) -> AMRStatistics {
636 AMRStatistics {
637 num_levels: self.grid.levels.len(),
638 total_cells: self.grid.total_cells(),
639 refinement_efficiency: self.grid.refinement_efficiency(),
640 amr_cycles: self.amr_cycles,
641 }
642 }
643}
644
/// Summary of an [`AMRSolver`]'s grid, as returned by `grid_statistics`.
///
/// The struct holds only scalars, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AMRStatistics {
    /// Number of grid levels currently allocated.
    pub num_levels: usize,
    /// Sum of `nx * ny` over all allocated levels.
    pub total_cells: usize,
    /// Refined cells on all levels divided by the base-level cell count.
    pub refinement_efficiency: f64,
    /// Zero-based index of the most recently started AMR cycle.
    pub amr_cycles: usize,
}
657
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 5×5 field `u(i, j) = 0.1 * (i² + j²)` shared by the criterion tests.
    fn quadratic_field() -> Array2<f64> {
        Array2::from_shape_fn((5, 5), |(i, j)| (i * i + j * j) as f64 * 0.1)
    }

    #[test]
    fn test_amr_grid_creation() {
        let grid: AMRGrid<f64> = AMRGrid::new(32, 32, [0.0, 1.0], [0.0, 1.0]);

        assert_eq!(grid.levels.len(), 1);
        assert_eq!(grid.levels[0].nx, 32);
        assert_eq!(grid.levels[0].ny, 32);
        assert_abs_diff_eq!(grid.levels[0].dx, 1.0 / 32.0);
        assert_abs_diff_eq!(grid.levels[0].dy, 1.0 / 32.0);
    }

    #[test]
    fn test_refinement_criteria() {
        let solution = quadratic_field();
        let criteria = RefinementCriteria::gradient_based(0.5);

        // Central differences at (2, 2) give a gradient of (0.4, 0.4): |∇u| ≈ 0.566.
        assert_abs_diff_eq!(
            RefinementCriteria::<f64>::compute_gradient_magnitude(solution.view(), 2, 2),
            0.32_f64.sqrt(),
            epsilon = 1e-12
        );
        assert!(criteria.should_refine(solution.view(), 2, 2));

        // One-sided differences at the corner give |∇u| ≈ 0.141, which is above
        // the coarsening threshold 0.5 / 4 = 0.125.
        assert!(!criteria.should_coarsen(solution.view(), 0, 0));
    }

    #[test]
    fn test_curvature_indicator() {
        let solution = quadratic_field();

        // Undivided u_xx = u_yy = 0.2 and the mixed difference vanishes.
        assert_abs_diff_eq!(
            RefinementCriteria::<f64>::compute_curvature(solution.view(), 2, 2),
            0.08_f64.sqrt(),
            epsilon = 1e-12
        );
        // Boundary cells have no full stencil and report zero.
        assert_eq!(
            RefinementCriteria::<f64>::compute_curvature(solution.view(), 0, 2),
            0.0
        );
    }

    #[test]
    fn test_flat_field_coarsens_under_every_criterion() {
        let solution = Array2::from_elem((5, 5), 1.0_f64);

        for criteria in &[
            RefinementCriteria::gradient_based(0.5),
            RefinementCriteria::curvature_based(0.5),
            RefinementCriteria::error_based(0.5),
        ] {
            assert!(!criteria.should_refine(solution.view(), 2, 2));
            assert!(criteria.should_coarsen(solution.view(), 2, 2));
        }
    }

    #[test]
    fn test_cell_refinement() {
        let mut grid: AMRGrid<f64> = AMRGrid::new(4, 4, [0.0, 1.0], [0.0, 1.0]);

        grid.solution.insert((0, 1, 1), 1.0);

        assert!(grid.refine_cell(0, 1, 1).is_ok());

        assert!(grid.levels[0].refined[[1, 1]]);
        assert_eq!(grid.levels.len(), 2);
        assert!(grid.levels[0].children.contains_key(&(1, 1)));
    }

    #[test]
    fn test_amr_solver() {
        let grid: AMRGrid<f64> = AMRGrid::new(8, 8, [0.0, 1.0], [0.0, 1.0]);
        let criteria = RefinementCriteria::gradient_based(0.1);
        let mut solver = AMRSolver::new(grid, criteria);

        let initial = Array2::from_elem((8, 8), 0.5);

        let problem = |_grid: &AMRGrid<f64>, solution: ArrayView2<f64>| -> PDEResult<Array2<f64>> {
            Ok(solution.to_owned())
        };

        let result = solver.solve_adaptive(problem, initial);
        assert!(result.is_ok());

        let stats = solver.grid_statistics();
        assert!(stats.num_levels >= 1);
        assert!(stats.total_cells >= 64);
    }

    #[test]
    fn test_amr_coarsening() {
        let mut grid: AMRGrid<f64> = AMRGrid::new(4, 4, [0.0, 1.0], [0.0, 1.0]);

        grid.solution.insert((0, 1, 1), 1.0);

        assert!(grid.refine_cell(0, 1, 1).is_ok());
        assert_eq!(grid.levels.len(), 2);

        grid.solution.insert((1, 2, 2), 0.8);
        grid.solution.insert((1, 2, 3), 0.9);
        grid.solution.insert((1, 3, 2), 1.1);
        grid.solution.insert((1, 3, 3), 1.2);

        let cells_to_coarsen = vec![(2, 2), (2, 3), (3, 2), (3, 3)];
        let coarsened = grid.coarsen_cells(1, cells_to_coarsen).unwrap();

        assert_eq!(coarsened, 1);
        assert!(!grid.levels[0].refined[[1, 1]]);

        // Parent holds the mean of the four children.
        assert_abs_diff_eq!(grid.solution[&(0, 1, 1)], 1.0, epsilon = 1e-12);

        assert_eq!(grid.solution.get(&(1, 2, 2)), None);
        assert_eq!(grid.solution.get(&(1, 2, 3)), None);
        assert_eq!(grid.solution.get(&(1, 3, 2)), None);
        assert_eq!(grid.solution.get(&(1, 3, 3)), None);
    }
}