scirs2_integrate/
amr_advanced.rs

1//! Advanced Adaptive Mesh Refinement (AMR) with sophisticated error indicators
2//!
3//! This module provides state-of-the-art adaptive mesh refinement techniques
4//! including gradient-based refinement, feature detection, load balancing,
5//! and hierarchical mesh management for complex PDE solutions.
6
7use crate::common::IntegrateFloat;
8use crate::error::{IntegrateError, IntegrateResult};
9use scirs2_core::ndarray::{Array1, Array2};
10use std::collections::{HashMap, HashSet};
11
/// Advanced AMR manager with multiple refinement strategies.
///
/// Drives the adaptation cycle in [`AdvancedAMRManager::adapt_mesh`]. The stages are:
/// solution transfer, error estimation through `refinement_criteria`, flagging,
/// refinement, coarsening, optional load balancing and ghost-cell update.
pub struct AdvancedAMRManager<F: IntegrateFloat> {
    /// Current mesh hierarchy (level 0 is the initial mesh passed to `new`).
    pub mesh_hierarchy: MeshHierarchy<F>,
    /// Refinement criteria; a cell's error estimate is their weighted mean.
    pub refinement_criteria: Vec<Box<dyn RefinementCriterion<F>>>,
    /// Optional load balancing strategy, run after refinement and coarsening.
    pub load_balancer: Option<Box<dyn LoadBalancer<F>>>,
    /// Refinement limit: a cell is only refined while its level is below this
    /// value, so the finest level that can be created is `max_levels`.
    pub max_levels: usize,
    /// Minimum cell size: cells whose size is not above this are never refined.
    pub min_cell_size: F,
    /// Active cells above level 0 whose error estimate is below this value are
    /// flagged for coarsening.
    pub coarsening_tolerance: F,
    /// Active cells whose error estimate exceeds this value are flagged for refinement.
    pub refinement_tolerance: F,
    /// Adaptation only runs on steps that are multiples of this value.
    pub adaptation_frequency: usize,
    /// Number of `adapt_mesh` calls made so far.
    current_step: usize,
}
33
/// Hierarchical mesh structure supporting multiple levels.
#[derive(Debug, Clone)]
pub struct MeshHierarchy<F: IntegrateFloat> {
    /// Mesh levels, indexed by level number (0 = coarsest).
    pub levels: Vec<AdaptiveMeshLevel<F>>,
    /// Parent-child relationships: maps each refined cell to its children's IDs.
    pub hierarchy_map: HashMap<CellId, Vec<CellId>>,
    /// Ghost cell information for parallel processing, keyed by level index.
    pub ghost_cells: HashMap<usize, Vec<CellId>>, // level -> ghost cells
}
44
/// Single level in the adaptive mesh.
#[derive(Debug, Clone)]
pub struct AdaptiveMeshLevel<F: IntegrateFloat> {
    /// Level number (0 = coarsest). The manager indexes `MeshHierarchy::levels`
    /// by this value, so it is expected to equal the level's position there.
    pub level: usize,
    /// All cells at this level, keyed by ID. Refined parents stay in the map
    /// as inactive cells.
    pub cells: HashMap<CellId, AdaptiveCell<F>>,
    /// Grid spacing at this level. Levels created by refinement get half the
    /// spacing of the previous finest level.
    pub grid_spacing: F,
    /// Cells explicitly marked as boundary cells (always treated as boundary
    /// cells by ghost-cell detection).
    pub boundary_cells: HashSet<CellId>,
}
57
/// Individual adaptive cell.
#[derive(Debug, Clone)]
pub struct AdaptiveCell<F: IntegrateFloat> {
    /// Unique cell identifier.
    pub id: CellId,
    /// Cell center coordinates, one entry per spatial dimension.
    pub center: Array1<F>,
    /// Cell size; children created by refinement get half of it.
    pub size: F,
    /// Solution value(s) in the cell.
    pub solution: Array1<F>,
    /// Error estimate for the cell (weighted mean of the refinement criteria).
    pub error_estimate: F,
    /// Pending adaptation action for this cell.
    pub refinement_flag: RefinementFlag,
    /// `false` once the cell has been refined into children.
    pub is_active: bool,
    /// Neighboring cell IDs, possibly on adjacent levels.
    pub neighbors: Vec<CellId>,
    /// Parent cell ID (set when this cell was created by refining its parent).
    pub parent: Option<CellId>,
    /// Children cell IDs (non-empty while this cell is refined; cleared when
    /// the children are coarsened back into it).
    pub children: Vec<CellId>,
}
82
/// Cell identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CellId {
    /// Refinement level of the cell (index into `MeshHierarchy::levels`).
    pub level: usize,
    /// Index of the cell within its level.
    pub index: usize,
}
89
/// Refinement action flags.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`. That lets
/// flags be compared exhaustively and used as keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RefinementFlag {
    /// No action needed.
    None,
    /// Cell should be refined.
    Refine,
    /// Cell should be coarsened (only applied when all siblings agree).
    Coarsen,
    /// Cell marked for potential refinement. The refine/coarsen passes do not
    /// act on it; they only act on `Refine` and `Coarsen`.
    Tagged,
}
102
/// Trait for refinement criteria.
///
/// Each implementation scores a cell. The manager combines the scores of all
/// registered criteria into the cell's error estimate. It then compares that
/// estimate against its refinement and coarsening tolerances.
pub trait RefinementCriterion<F: IntegrateFloat>: Send + Sync {
    /// Evaluate the refinement criterion for `cell`.
    ///
    /// `neighbors` holds the cells that could be resolved from `cell.neighbors`.
    /// Larger return values push the cell towards refinement.
    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F;

    /// Human-readable criterion name.
    fn name(&self) -> &'static str;

    /// Weight of this criterion in the combined (weighted-mean) error estimate.
    /// Defaults to one.
    fn weight(&self) -> F {
        F::one()
    }
}
116
/// Gradient-based refinement criterion.
pub struct GradientRefinementCriterion<F: IntegrateFloat> {
    /// Solution component to analyze (`None` = all components).
    pub component: Option<usize>,
    /// Gradient threshold.
    pub threshold: F,
    /// Relative tolerance.
    pub relative_tolerance: F,
}

/// Feature detection refinement criterion.
pub struct FeatureDetectionCriterion<F: IntegrateFloat> {
    /// Feature detection threshold.
    pub threshold: F,
    /// Feature types to detect.
    pub feature_types: Vec<FeatureType>,
    /// Window size for feature detection. This is presumably a number of cells;
    /// TODO confirm against the `RefinementCriterion` impl.
    pub window_size: usize,
}

/// Curvature-based refinement criterion.
pub struct CurvatureRefinementCriterion<F: IntegrateFloat> {
    /// Curvature threshold.
    pub threshold: F,
    /// Approximation order for the curvature calculation.
    pub approximation_order: usize,
}
144
/// Load balancing strategy trait.
pub trait LoadBalancer<F: IntegrateFloat>: Send + Sync {
    /// Balance computational load across processors/threads by mutating the
    /// hierarchy in place.
    ///
    /// `AdvancedAMRManager::adapt_mesh` calls this after refinement and
    /// coarsening. A returned error aborts that adaptation step.
    fn balance(&self, hierarchy: &mut MeshHierarchy<F>) -> IntegrateResult<()>;
}

/// Zoltan-style geometric load balancer.
pub struct GeometricLoadBalancer<F: IntegrateFloat> {
    /// Number of partitions.
    pub num_partitions: usize,
    /// Load imbalance tolerance.
    pub imbalance_tolerance: F,
    /// Partitioning method.
    pub method: PartitioningMethod,
}
160
/// Types of features to detect.
///
/// Derives `Eq` and `Hash` so feature selections can be deduplicated or
/// stored in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Sharp gradients.
    SharpGradient,
    /// Discontinuities.
    Discontinuity,
    /// Local extrema.
    LocalExtremum,
    /// Oscillatory behavior.
    Oscillation,
    /// Boundary layers.
    BoundaryLayer,
}
175
/// Partitioning methods.
///
/// Derives `PartialEq`/`Eq`/`Hash` so a configured method can be compared
/// (e.g. in tests or when dispatching), which the bare `Copy` derive did not allow.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PartitioningMethod {
    /// Recursive coordinate bisection.
    RCB,
    /// Space-filling curves.
    SFC,
    /// Graph partitioning.
    Graph,
}
186
187/// AMR adaptation result
188pub struct AMRAdaptationResult<F: IntegrateFloat> {
189    /// Number of cells refined
190    pub cells_refined: usize,
191    /// Number of cells coarsened
192    pub cells_coarsened: usize,
193    /// Total active cells after adaptation
194    pub total_active_cells: usize,
195    /// Load balance quality metric
196    pub load_balance_quality: F,
197    /// Memory usage change
198    pub memory_change: i64,
199    /// Adaptation time
200    pub adaptation_time: std::time::Duration,
201}
202
203impl<F: IntegrateFloat> AdvancedAMRManager<F> {
204    /// Create new advanced AMR manager
205    pub fn new(_initial_mesh: AdaptiveMeshLevel<F>, max_levels: usize, min_cellsize: F) -> Self {
206        let mesh_hierarchy = MeshHierarchy {
207            levels: vec![_initial_mesh],
208            hierarchy_map: HashMap::new(),
209            ghost_cells: HashMap::new(),
210        };
211
212        Self {
213            mesh_hierarchy,
214            refinement_criteria: Vec::new(),
215            load_balancer: None,
216            max_levels,
217            min_cell_size: min_cellsize,
218            coarsening_tolerance: F::from(0.1).unwrap(),
219            refinement_tolerance: F::from(1.0).unwrap(),
220            adaptation_frequency: 1,
221            current_step: 0,
222        }
223    }
224
225    /// Add refinement criterion
226    pub fn add_criterion(&mut self, criterion: Box<dyn RefinementCriterion<F>>) {
227        self.refinement_criteria.push(criterion);
228    }
229
230    /// Set load balancer
231    pub fn set_load_balancer(&mut self, balancer: Box<dyn LoadBalancer<F>>) {
232        self.load_balancer = Some(balancer);
233    }
234
235    /// Perform adaptive mesh refinement
236    pub fn adapt_mesh(&mut self, solution: &Array2<F>) -> IntegrateResult<AMRAdaptationResult<F>> {
237        let start_time = std::time::Instant::now();
238        let initial_cells = self.count_active_cells();
239
240        self.current_step += 1;
241
242        // Skip adaptation if not at adaptation frequency
243        if !self.current_step.is_multiple_of(self.adaptation_frequency) {
244            return Ok(AMRAdaptationResult {
245                cells_refined: 0,
246                cells_coarsened: 0,
247                total_active_cells: initial_cells,
248                load_balance_quality: F::one(),
249                memory_change: 0,
250                adaptation_time: start_time.elapsed(),
251            });
252        }
253
254        // Step 1: Update solution values in cells
255        self.update_cell_solutions(solution)?;
256
257        // Step 2: Evaluate refinement criteria
258        self.evaluate_refinement_criteria()?;
259
260        // Step 3: Flag cells for refinement/coarsening
261        let _refine_count_coarsen_count = self.flag_cells_for_adaptation()?;
262
263        // Step 4: Perform refinement
264        let cells_refined = self.refine_flagged_cells()?;
265
266        // Step 5: Perform coarsening
267        let cells_coarsened = self.coarsen_flagged_cells()?;
268
269        // Step 6: Load balancing
270        let load_balance_quality = if let Some(ref balancer) = self.load_balancer {
271            balancer.balance(&mut self.mesh_hierarchy)?;
272            self.assess_load_balance()
273        } else {
274            F::one()
275        };
276
277        // Step 7: Update ghost cells
278        self.update_ghost_cells()?;
279
280        let final_cells = self.count_active_cells();
281        let memory_change = (final_cells as i64 - initial_cells as i64) * 8; // Rough estimate
282
283        Ok(AMRAdaptationResult {
284            cells_refined,
285            cells_coarsened,
286            total_active_cells: final_cells,
287            load_balance_quality,
288            memory_change,
289            adaptation_time: start_time.elapsed(),
290        })
291    }
292
293    /// Update solution values in mesh cells
294    fn update_cell_solutions(&mut self, solution: &Array2<F>) -> IntegrateResult<()> {
295        // Map solution array to mesh cells
296        // This is a simplified mapping - in practice would need sophisticated interpolation
297        for level in &mut self.mesh_hierarchy.levels {
298            for cell in level.cells.values_mut() {
299                if cell.is_active {
300                    // Simple mapping - in practice would use proper interpolation
301                    let i = (cell.center[0] * F::from(solution.nrows()).unwrap())
302                        .to_usize()
303                        .unwrap_or(0)
304                        .min(solution.nrows() - 1);
305                    let j = if solution.ncols() > 1 && cell.center.len() > 1 {
306                        (cell.center[1] * F::from(solution.ncols()).unwrap())
307                            .to_usize()
308                            .unwrap_or(0)
309                            .min(solution.ncols() - 1)
310                    } else {
311                        0
312                    };
313
314                    // Update cell solution (simplified)
315                    if cell.solution.len() == 1 {
316                        cell.solution[0] = solution[[i, j]];
317                    }
318                }
319            }
320        }
321        Ok(())
322    }
323
324    /// Evaluate all refinement criteria for all cells
325    fn evaluate_refinement_criteria(&mut self) -> IntegrateResult<()> {
326        for level in &mut self.mesh_hierarchy.levels {
327            let cellids: Vec<CellId> = level.cells.keys().cloned().collect();
328
329            for cellid in cellids {
330                if let Some(cell) = level.cells.get(&cellid) {
331                    if !cell.is_active {
332                        continue;
333                    }
334
335                    // Get neighboring cells
336                    let neighbor_cells: Vec<&AdaptiveCell<F>> = cell
337                        .neighbors
338                        .iter()
339                        .filter_map(|&neighbor_id| level.cells.get(&neighbor_id))
340                        .collect();
341
342                    // Evaluate all criteria
343                    let mut total_error = F::zero();
344                    let mut total_weight = F::zero();
345
346                    for criterion in &self.refinement_criteria {
347                        let error = criterion.evaluate(cell, &neighbor_cells);
348                        let weight = criterion.weight();
349                        total_error += error * weight;
350                        total_weight += weight;
351                    }
352
353                    // Normalize error estimate
354                    let error_estimate = if total_weight > F::zero() {
355                        total_error / total_weight
356                    } else {
357                        F::zero()
358                    };
359
360                    // Update cell error estimate
361                    if let Some(cell) = level.cells.get_mut(&cellid) {
362                        cell.error_estimate = error_estimate;
363                    }
364                }
365            }
366        }
367        Ok(())
368    }
369
370    /// Flag cells for refinement or coarsening
371    fn flag_cells_for_adaptation(&mut self) -> IntegrateResult<(usize, usize)> {
372        let mut refine_count = 0;
373        let mut coarsen_count = 0;
374
375        // Collect cells that can be coarsened first to avoid borrowing issues
376        let mut cells_to_check: Vec<(usize, CellId, F, usize, F)> = Vec::new();
377
378        for level in &self.mesh_hierarchy.levels {
379            for cell in level.cells.values() {
380                if cell.is_active {
381                    cells_to_check.push((
382                        level.level,
383                        cell.id,
384                        cell.error_estimate,
385                        level.level,
386                        cell.size,
387                    ));
388                }
389            }
390        }
391
392        // Now flag cells based on collected information
393        for (level_idx, cellid, error_estimate, level_num, cell_size) in cells_to_check {
394            if let Some(level) = self.mesh_hierarchy.levels.get_mut(level_idx) {
395                if let Some(cell) = level.cells.get_mut(&cellid) {
396                    // Refinement criterion
397                    if error_estimate > self.refinement_tolerance
398                        && level_num < self.max_levels
399                        && cell_size > self.min_cell_size
400                    {
401                        cell.refinement_flag = RefinementFlag::Refine;
402                        refine_count += 1;
403                    }
404                    // Coarsening criterion (simplified check)
405                    else if error_estimate < self.coarsening_tolerance && level_num > 0 {
406                        cell.refinement_flag = RefinementFlag::Coarsen;
407                        coarsen_count += 1;
408                    } else {
409                        cell.refinement_flag = RefinementFlag::None;
410                    }
411                }
412            }
413        }
414
415        Ok((refine_count, coarsen_count))
416    }
417
418    /// Check if cell can be coarsened (all siblings must be flagged)
419    fn can_coarsen_cell(&self, cell: &AdaptiveCell<F>) -> bool {
420        if let Some(parent_id) = cell.parent {
421            // Check if all sibling cells are also flagged for coarsening
422            if let Some(parent_children) = self.mesh_hierarchy.hierarchy_map.get(&parent_id) {
423                for &child_id in parent_children {
424                    if let Some(level) = self.mesh_hierarchy.levels.get(child_id.level) {
425                        if let Some(sibling) = level.cells.get(&child_id) {
426                            if sibling.refinement_flag != RefinementFlag::Coarsen {
427                                return false;
428                            }
429                        }
430                    }
431                }
432                return true;
433            }
434        }
435        false
436    }
437
438    /// Refine flagged cells
439    fn refine_flagged_cells(&mut self) -> IntegrateResult<usize> {
440        let mut cells_refined = 0;
441
442        // Process each level separately to avoid borrowing issues
443        for level_idx in 0..self.mesh_hierarchy.levels.len() {
444            let cells_to_refine: Vec<CellId> = self.mesh_hierarchy.levels[level_idx]
445                .cells
446                .values()
447                .filter(|cell| cell.refinement_flag == RefinementFlag::Refine)
448                .map(|cell| cell.id)
449                .collect();
450
451            for cellid in cells_to_refine {
452                self.refine_cell(cellid)?;
453                cells_refined += 1;
454            }
455        }
456
457        Ok(cells_refined)
458    }
459
460    /// Refine a single cell
461    fn refine_cell(&mut self, cellid: CellId) -> IntegrateResult<()> {
462        let parent_cell = if let Some(level) = self.mesh_hierarchy.levels.get(cellid.level) {
463            level.cells.get(&cellid).cloned()
464        } else {
465            return Err(IntegrateError::ValueError("Invalid cell level".to_string()));
466        };
467
468        let parent_cell =
469            parent_cell.ok_or_else(|| IntegrateError::ValueError("Cell not found".to_string()))?;
470
471        // Create child level if it doesn't exist
472        let child_level = cellid.level + 1;
473        while self.mesh_hierarchy.levels.len() <= child_level {
474            let new_level = AdaptiveMeshLevel {
475                level: self.mesh_hierarchy.levels.len(),
476                cells: HashMap::new(),
477                grid_spacing: if let Some(last_level) = self.mesh_hierarchy.levels.last() {
478                    last_level.grid_spacing / F::from(2.0).unwrap()
479                } else {
480                    F::one()
481                },
482                boundary_cells: HashSet::new(),
483            };
484            self.mesh_hierarchy.levels.push(new_level);
485        }
486
487        // Create child cells (2D refinement = 4 children, 3D = 8 children)
488        let num_children = 2_usize.pow(parent_cell.center.len() as u32);
489        let mut child_ids = Vec::new();
490        let child_size = parent_cell.size / F::from(2.0).unwrap();
491
492        for child_idx in 0..num_children {
493            let child_id = CellId {
494                level: child_level,
495                index: self.mesh_hierarchy.levels[child_level].cells.len(),
496            };
497
498            // Compute child center
499            let mut child_center = parent_cell.center.clone();
500            let offset = child_size / F::from(2.0).unwrap();
501
502            // Binary representation determines position
503            for dim in 0..parent_cell.center.len() {
504                if (child_idx >> dim) & 1 == 1 {
505                    child_center[dim] += offset;
506                } else {
507                    child_center[dim] -= offset;
508                }
509            }
510
511            let child_cell = AdaptiveCell {
512                id: child_id,
513                center: child_center,
514                size: child_size,
515                solution: parent_cell.solution.clone(), // Inherit parent solution
516                error_estimate: F::zero(),
517                refinement_flag: RefinementFlag::None,
518                is_active: true,
519                neighbors: Vec::new(),
520                parent: Some(cellid),
521                children: Vec::new(),
522            };
523
524            self.mesh_hierarchy.levels[child_level]
525                .cells
526                .insert(child_id, child_cell);
527            child_ids.push(child_id);
528        }
529
530        // Update hierarchy map
531        self.mesh_hierarchy
532            .hierarchy_map
533            .insert(cellid, child_ids.clone());
534
535        // Deactivate parent cell
536        if let Some(parent) = self.mesh_hierarchy.levels[cellid.level]
537            .cells
538            .get_mut(&cellid)
539        {
540            parent.is_active = false;
541            parent.children = child_ids;
542        }
543
544        // Update neighbor relationships
545        self.update_neighbor_relationships(child_level)?;
546
547        Ok(())
548    }
549
550    /// Coarsen flagged cells
551    fn coarsen_flagged_cells(&mut self) -> IntegrateResult<usize> {
552        let mut cells_coarsened = 0;
553
554        // Process from finest to coarsest level
555        for level_idx in (1..self.mesh_hierarchy.levels.len()).rev() {
556            let parent_cells_to_activate: Vec<CellId> = self.mesh_hierarchy.levels[level_idx]
557                .cells
558                .values()
559                .filter(|cell| cell.refinement_flag == RefinementFlag::Coarsen)
560                .filter_map(|cell| cell.parent)
561                .collect::<HashSet<_>>()
562                .into_iter()
563                .collect();
564
565            for parent_id in parent_cells_to_activate {
566                if self.coarsen_to_parent(parent_id)? {
567                    cells_coarsened += 1;
568                }
569            }
570        }
571
572        Ok(cells_coarsened)
573    }
574
575    /// Coarsen children back to parent cell
576    fn coarsen_to_parent(&mut self, parentid: CellId) -> IntegrateResult<bool> {
577        let child_ids = if let Some(children) = self.mesh_hierarchy.hierarchy_map.get(&parentid) {
578            children.clone()
579        } else {
580            return Ok(false);
581        };
582
583        // Verify all children are flagged for coarsening
584        for &child_id in &child_ids {
585            if let Some(level) = self.mesh_hierarchy.levels.get(child_id.level) {
586                if let Some(child) = level.cells.get(&child_id) {
587                    if child.refinement_flag != RefinementFlag::Coarsen {
588                        return Ok(false);
589                    }
590                }
591            }
592        }
593
594        // Average child solutions for parent
595        let mut avg_solution = Array1::zeros(child_ids.len());
596        if !child_ids.is_empty() {
597            if let Some(first_child_level) = self.mesh_hierarchy.levels.get(child_ids[0].level) {
598                if let Some(first_child) = first_child_level.cells.get(&child_ids[0]) {
599                    avg_solution = Array1::zeros(first_child.solution.len());
600
601                    for &child_id in &child_ids {
602                        if let Some(child_level) = self.mesh_hierarchy.levels.get(child_id.level) {
603                            if let Some(child) = child_level.cells.get(&child_id) {
604                                avg_solution = &avg_solution + &child.solution;
605                            }
606                        }
607                    }
608                    avg_solution /= F::from(child_ids.len()).unwrap();
609                }
610            }
611        }
612
613        // Reactivate parent cell
614        if let Some(parent_level) = self.mesh_hierarchy.levels.get_mut(parentid.level) {
615            if let Some(parent) = parent_level.cells.get_mut(&parentid) {
616                parent.is_active = true;
617                parent.solution = avg_solution;
618                parent.children.clear();
619                parent.refinement_flag = RefinementFlag::None;
620            }
621        }
622
623        // Remove children from hierarchy
624        for &child_id in &child_ids {
625            if let Some(child_level) = self.mesh_hierarchy.levels.get_mut(child_id.level) {
626                child_level.cells.remove(&child_id);
627            }
628        }
629
630        // Remove from hierarchy map
631        self.mesh_hierarchy.hierarchy_map.remove(&parentid);
632
633        Ok(true)
634    }
635
636    /// Update neighbor relationships after refinement
637    fn update_neighbor_relationships(&mut self, level: usize) -> IntegrateResult<()> {
638        // Collect neighbor relationships first to avoid borrowing conflicts
639        let mut all_neighbor_relationships: Vec<(CellId, Vec<CellId>)> = Vec::new();
640
641        if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level) {
642            let cellids: Vec<CellId> = mesh_level.cells.keys().cloned().collect();
643
644            // Build spatial hash map for efficient neighbor searching
645            let mut spatial_hash: HashMap<(i32, i32, i32), Vec<CellId>> = HashMap::new();
646            let grid_spacing = mesh_level.grid_spacing;
647
648            // Hash all cells based on their spatial location
649            for cellid in &cellids {
650                if let Some(cell) = mesh_level.cells.get(cellid) {
651                    if cell.center.len() >= 3 {
652                        let hash_x = (cell.center[0] / grid_spacing)
653                            .floor()
654                            .to_i32()
655                            .unwrap_or(0);
656                        let hash_y = (cell.center[1] / grid_spacing)
657                            .floor()
658                            .to_i32()
659                            .unwrap_or(0);
660                        let hash_z = (cell.center[2] / grid_spacing)
661                            .floor()
662                            .to_i32()
663                            .unwrap_or(0);
664
665                        spatial_hash
666                            .entry((hash_x, hash_y, hash_z))
667                            .or_default()
668                            .push(*cellid);
669                    }
670                }
671            }
672
673            for cellid in &cellids {
674                if let Some(cell) = mesh_level.cells.get(cellid) {
675                    let mut neighbors = Vec::new();
676
677                    if cell.center.len() >= 3 {
678                        let hash_x = (cell.center[0] / grid_spacing)
679                            .floor()
680                            .to_i32()
681                            .unwrap_or(0);
682                        let hash_y = (cell.center[1] / grid_spacing)
683                            .floor()
684                            .to_i32()
685                            .unwrap_or(0);
686                        let hash_z = (cell.center[2] / grid_spacing)
687                            .floor()
688                            .to_i32()
689                            .unwrap_or(0);
690
691                        // Search in 27 neighboring hash buckets (3x3x3)
692                        for dx in -1..=1 {
693                            for dy in -1..=1 {
694                                for dz in -1..=1 {
695                                    let hash_key = (hash_x + dx, hash_y + dy, hash_z + dz);
696
697                                    if let Some(potential_neighbors) = spatial_hash.get(&hash_key) {
698                                        for &neighbor_id in potential_neighbors {
699                                            if neighbor_id != *cellid {
700                                                if let Some(neighbor_cell) =
701                                                    mesh_level.cells.get(&neighbor_id)
702                                                {
703                                                    // Check if cells are actually neighbors (face/edge/vertex sharing)
704                                                    if self.are_cells_neighbors(cell, neighbor_cell)
705                                                    {
706                                                        neighbors.push(neighbor_id);
707                                                    }
708                                                }
709                                            }
710                                        }
711                                    }
712                                }
713                            }
714                        }
715                    }
716                    all_neighbor_relationships.push((*cellid, neighbors));
717                }
718            }
719        }
720
721        // Now apply all neighbor relationships with mutable access
722        if let Some(mesh_level) = self.mesh_hierarchy.levels.get_mut(level) {
723            for (cellid, neighbors) in all_neighbor_relationships {
724                if let Some(cell) = mesh_level.cells.get_mut(&cellid) {
725                    cell.neighbors = neighbors;
726                }
727            }
728        }
729
730        // Now update inter-level neighbors separately to avoid borrowing conflicts
731        let cellids: Vec<CellId> = if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level) {
732            mesh_level.cells.keys().cloned().collect()
733        } else {
734            Vec::new()
735        };
736
737        for cellid in cellids {
738            self.update_interlevel_neighbors(cellid, level)?;
739        }
740
741        Ok(())
742    }
743
744    /// Check if two cells are geometric neighbors
745    fn are_cells_neighbors(&self, cell1: &AdaptiveCell<F>, cell2: &AdaptiveCell<F>) -> bool {
746        if cell1.center.len() != cell2.center.len() || cell1.center.len() < 3 {
747            return false;
748        }
749
750        let max_size = cell1.size.max(cell2.size);
751        let tolerance = max_size * F::from(1.1).unwrap(); // 10% tolerance
752
753        // Calculate distance between cell centers
754        let mut distance_squared = F::zero();
755        for i in 0..cell1.center.len() {
756            let diff = cell1.center[i] - cell2.center[i];
757            distance_squared += diff * diff;
758        }
759
760        let distance = distance_squared.sqrt();
761
762        // Cells are neighbors if distance is approximately equal to sum of half-sizes
763        let expected_distance = (cell1.size + cell2.size) / F::from(2.0).unwrap();
764
765        distance <= tolerance && distance >= expected_distance * F::from(0.7).unwrap()
766    }
767
768    /// Update neighbor relationships across different mesh levels
769    fn update_interlevel_neighbors(&mut self, cellid: CellId, level: usize) -> IntegrateResult<()> {
770        // Collect neighbor relationships first to avoid borrowing conflicts
771        let mut coarser_neighbors = Vec::new();
772        let mut finer_neighbors = Vec::new();
773
774        // Check neighbors at level-1 (coarser level)
775        if level > 0 {
776            if let (Some(current_level), Some(coarser_level)) = (
777                self.mesh_hierarchy.levels.get(level),
778                self.mesh_hierarchy.levels.get(level - 1),
779            ) {
780                if let Some(current_cell) = current_level.cells.get(&cellid) {
781                    for (coarser_cellid, coarser_cell) in &coarser_level.cells {
782                        if self.are_cells_neighbors(current_cell, coarser_cell) {
783                            coarser_neighbors.push(*coarser_cellid);
784                        }
785                    }
786                }
787            }
788        }
789
790        // Check neighbors at level+1 (finer level)
791        if level + 1 < self.mesh_hierarchy.levels.len() {
792            if let (Some(current_level), Some(finer_level)) = (
793                self.mesh_hierarchy.levels.get(level),
794                self.mesh_hierarchy.levels.get(level + 1),
795            ) {
796                if let Some(current_cell) = current_level.cells.get(&cellid) {
797                    for (finer_cellid, finer_cell) in &finer_level.cells {
798                        if self.are_cells_neighbors(current_cell, finer_cell) {
799                            finer_neighbors.push(*finer_cellid);
800                        }
801                    }
802                }
803            }
804        }
805
806        // Now apply the neighbor relationships with mutable access
807        if let Some(current_level) = self.mesh_hierarchy.levels.get_mut(level) {
808            if let Some(current_cell) = current_level.cells.get_mut(&cellid) {
809                for coarser_id in coarser_neighbors {
810                    if !current_cell.neighbors.contains(&coarser_id) {
811                        current_cell.neighbors.push(coarser_id);
812                    }
813                }
814
815                for finer_id in finer_neighbors {
816                    if !current_cell.neighbors.contains(&finer_id) {
817                        current_cell.neighbors.push(finer_id);
818                    }
819                }
820            }
821        }
822
823        Ok(())
824    }
825
826    /// Update ghost cells for parallel processing
827    fn update_ghost_cells(&mut self) -> IntegrateResult<()> {
828        // Clear existing ghost cells
829        self.mesh_hierarchy.ghost_cells.clear();
830
831        // Identify boundary cells and their external neighbors for each level
832        for level_idx in 0..self.mesh_hierarchy.levels.len() {
833            let mut ghost_cells_for_level = Vec::new();
834            let mut boundary_cells = HashSet::new();
835
836            // First pass: identify boundary cells (cells with fewer neighbors than expected)
837            let expected_neighbors = self.calculate_expected_neighbors();
838
839            if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level_idx) {
840                for (cellid, cell) in &mesh_level.cells {
841                    // A cell is on the boundary if it has fewer neighbors than expected
842                    // or if it's marked as a boundary cell
843                    if cell.neighbors.len() < expected_neighbors
844                        || mesh_level.boundary_cells.contains(cellid)
845                    {
846                        boundary_cells.insert(*cellid);
847                    }
848                }
849
850                // Second pass: create ghost cells for parallel processing
851                for boundary_cellid in &boundary_cells {
852                    if let Some(boundary_cell) = mesh_level.cells.get(boundary_cellid) {
853                        // Create ghost cells in the expected neighbor positions
854                        let ghost_cells =
855                            self.create_ghost_cells_for_boundary(boundary_cell, level_idx)?;
856                        ghost_cells_for_level.extend(ghost_cells);
857                    }
858                }
859
860                // Third pass: handle inter-level ghost cells
861                self.create_interlevel_ghost_cells(level_idx, &mut ghost_cells_for_level)?;
862            }
863
864            self.mesh_hierarchy
865                .ghost_cells
866                .insert(level_idx, ghost_cells_for_level);
867        }
868
869        Ok(())
870    }
871
872    /// Calculate expected number of neighbors for a regular cell
873    fn calculate_expected_neighbors(&self) -> usize {
874        // For a 3D structured grid, a regular internal cell should have 6 face neighbors
875        // For 2D: 4 neighbors, for 1D: 2 neighbors
876        // This is a simplification - actual count depends on mesh structure
877        6
878    }
879
880    /// Create ghost cells for a boundary cell
881    fn create_ghost_cells_for_boundary(
882        &self,
883        boundary_cell: &AdaptiveCell<F>,
884        level: usize,
885    ) -> IntegrateResult<Vec<CellId>> {
886        let mut ghost_cells = Vec::new();
887
888        if boundary_cell.center.len() >= 3 {
889            let cell_size = boundary_cell.size;
890
891            // Create ghost cells in the 6 cardinal directions (±x, ±y, ±z)
892            let directions = [
893                [F::one(), F::zero(), F::zero()],  // +x
894                [-F::one(), F::zero(), F::zero()], // -x
895                [F::zero(), F::one(), F::zero()],  // +y
896                [F::zero(), -F::one(), F::zero()], // -y
897                [F::zero(), F::zero(), F::one()],  // +z
898                [F::zero(), F::zero(), -F::one()], // -z
899            ];
900
901            for (dir_idx, direction) in directions.iter().enumerate() {
902                // Calculate ghost _cell position
903                let mut ghost_center = boundary_cell.center.clone();
904                for i in 0..3 {
905                    ghost_center[i] += direction[i] * cell_size;
906                }
907
908                // Check if a real _cell exists at this position
909                if !self.cell_exists_at_position(&ghost_center, level) {
910                    // Create ghost _cell ID (using high indices to avoid conflicts)
911                    let ghost_id = CellId {
912                        level,
913                        index: 1_000_000 + boundary_cell.id.index * 10 + dir_idx,
914                    };
915
916                    ghost_cells.push(ghost_id);
917                }
918            }
919        }
920
921        Ok(ghost_cells)
922    }
923
924    /// Create ghost cells for inter-level communication
925    fn create_interlevel_ghost_cells(
926        &self,
927        level: usize,
928        ghost_cells: &mut Vec<CellId>,
929    ) -> IntegrateResult<()> {
930        // Handle ghost _cells needed for communication between mesh levels
931
932        // Check if we need ghost _cells from coarser level
933        if level > 0 {
934            if let Some(current_level) = self.mesh_hierarchy.levels.get(level) {
935                for (cellid, cell) in &current_level.cells {
936                    // If this fine cell doesn't have a parent at the coarser level,
937                    // it might need ghost cell communication
938                    if cell.parent.is_none() {
939                        let ghost_id = CellId {
940                            level: level - 1,
941                            index: 2_000_000 + cellid.index,
942                        };
943                        ghost_cells.push(ghost_id);
944                    }
945                }
946            }
947        }
948
949        // Check if we need ghost _cells from finer level
950        if level + 1 < self.mesh_hierarchy.levels.len() {
951            if let Some(current_level) = self.mesh_hierarchy.levels.get(level) {
952                for (cellid, cell) in &current_level.cells {
953                    // If this coarse cell has children at the finer level,
954                    // it might need ghost cell communication
955                    if !cell.children.is_empty() {
956                        let ghost_id = CellId {
957                            level: level + 1,
958                            index: 3_000_000 + cellid.index,
959                        };
960                        ghost_cells.push(ghost_id);
961                    }
962                }
963            }
964        }
965
966        Ok(())
967    }
968
969    /// Check if a cell exists at the given position and level
970    fn cell_exists_at_position(&self, position: &Array1<F>, level: usize) -> bool {
971        if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level) {
972            let tolerance = mesh_level.grid_spacing * F::from(0.1).unwrap();
973
974            for cell in mesh_level.cells.values() {
975                if position.len() == cell.center.len() {
976                    let mut distance_squared = F::zero();
977                    for i in 0..position.len() {
978                        let diff = position[i] - cell.center[i];
979                        distance_squared += diff * diff;
980                    }
981
982                    if distance_squared.sqrt() < tolerance {
983                        return true;
984                    }
985                }
986            }
987        }
988        false
989    }
990
991    /// Count total active cells across all levels
992    fn count_active_cells(&self) -> usize {
993        self.mesh_hierarchy
994            .levels
995            .iter()
996            .map(|level| level.cells.values().filter(|cell| cell.is_active).count())
997            .sum()
998    }
999
1000    /// Assess load balance quality
1001    fn assess_load_balance(&self) -> F {
1002        let total_cells = self.count_active_cells();
1003        if total_cells == 0 {
1004            return F::one(); // Empty mesh is perfectly balanced
1005        }
1006
1007        // Calculate multiple load balance metrics
1008        let cell_distribution_balance = self.calculate_cell_distribution_balance();
1009        let computational_load_balance = self.calculate_computational_load_balance();
1010        let communication_overhead_balance = self.calculate_communication_balance();
1011        let memory_distribution_balance = self.calculate_memory_balance();
1012
1013        // Weighted combination of different balance metrics
1014        let weight_cell = F::from(0.3).unwrap();
1015        let weight_compute = F::from(0.4).unwrap();
1016        let weight_comm = F::from(0.2).unwrap();
1017        let weight_memory = F::from(0.1).unwrap();
1018
1019        let overall_balance = weight_cell * cell_distribution_balance
1020            + weight_compute * computational_load_balance
1021            + weight_comm * communication_overhead_balance
1022            + weight_memory * memory_distribution_balance;
1023
1024        // Clamp to [0, 1] range where 1.0 = perfect balance
1025        overall_balance.min(F::one()).max(F::zero())
1026    }
1027
1028    /// Calculate cell count distribution balance across levels
1029    fn calculate_cell_distribution_balance(&self) -> F {
1030        if self.mesh_hierarchy.levels.is_empty() {
1031            return F::one();
1032        }
1033
1034        // Calculate cells per level
1035        let mut cells_per_level: Vec<usize> = Vec::new();
1036        let mut total_cells = 0;
1037
1038        for level in &self.mesh_hierarchy.levels {
1039            let active_cells = level.cells.values().filter(|c| c.is_active).count();
1040            cells_per_level.push(active_cells);
1041            total_cells += active_cells;
1042        }
1043
1044        if total_cells == 0 {
1045            return F::one();
1046        }
1047
1048        // Calculate variance in cell distribution
1049        let mean_cells = total_cells as f64 / cells_per_level.len() as f64;
1050        let variance: f64 = cells_per_level
1051            .iter()
1052            .map(|&count| {
1053                let diff = count as f64 - mean_cells;
1054                diff * diff
1055            })
1056            .sum::<f64>()
1057            / cells_per_level.len() as f64;
1058
1059        let std_dev = variance.sqrt();
1060        let coefficient_of_variation = if mean_cells > 0.0 {
1061            std_dev / mean_cells
1062        } else {
1063            0.0
1064        };
1065
1066        // Convert to balance score (lower variation = better balance)
1067        let balance = (1.0 - coefficient_of_variation.min(1.0)).max(0.0);
1068        F::from(balance).unwrap_or(F::zero())
1069    }
1070
1071    /// Calculate computational load balance based on cell error estimates
1072    fn calculate_computational_load_balance(&self) -> F {
1073        let mut level_computational_loads: Vec<F> = Vec::new();
1074        let mut total_load = F::zero();
1075
1076        for level in &self.mesh_hierarchy.levels {
1077            let mut level_load = F::zero();
1078
1079            for cell in level.cells.values() {
1080                if cell.is_active {
1081                    // Computational cost is proportional to error estimate and refinement complexity
1082                    let cell_cost = cell.error_estimate * cell.size * cell.size; // O(h^2) scaling
1083                    level_load += cell_cost;
1084                }
1085            }
1086
1087            level_computational_loads.push(level_load);
1088            total_load += level_load;
1089        }
1090
1091        if total_load <= F::zero() {
1092            return F::one();
1093        }
1094
1095        // Calculate coefficient of variation for computational loads
1096        let mean_load = total_load / F::from(level_computational_loads.len()).unwrap();
1097        let mut variance = F::zero();
1098
1099        for &load in &level_computational_loads {
1100            let diff = load - mean_load;
1101            variance += diff * diff;
1102        }
1103
1104        variance /= F::from(level_computational_loads.len()).unwrap();
1105        let std_dev = variance.sqrt();
1106
1107        let coeff_var = if mean_load > F::zero() {
1108            std_dev / mean_load
1109        } else {
1110            F::zero()
1111        };
1112
1113        // Convert to balance score
1114        let balance = F::one() - coeff_var.min(F::one());
1115        balance.max(F::zero())
1116    }
1117
1118    /// Calculate communication balance based on ghost cell overhead
1119    fn calculate_communication_balance(&self) -> F {
1120        let mut level_comm_costs: Vec<F> = Vec::new();
1121        let mut total_comm_cost = F::zero();
1122
1123        for (level_idx, level) in self.mesh_hierarchy.levels.iter().enumerate() {
1124            let active_cells = level.cells.values().filter(|c| c.is_active).count();
1125            let ghost_cells = self
1126                .mesh_hierarchy
1127                .ghost_cells
1128                .get(&level_idx)
1129                .map(|ghosts| ghosts.len())
1130                .unwrap_or(0);
1131
1132            // Communication cost is proportional to ghost cells per active cell
1133            let comm_cost = if active_cells > 0 {
1134                F::from(ghost_cells as f64 / active_cells as f64).unwrap_or(F::zero())
1135            } else {
1136                F::zero()
1137            };
1138
1139            level_comm_costs.push(comm_cost);
1140            total_comm_cost += comm_cost;
1141        }
1142
1143        if level_comm_costs.is_empty() || total_comm_cost <= F::zero() {
1144            return F::one();
1145        }
1146
1147        // Calculate variance in communication costs
1148        let mean_comm = total_comm_cost / F::from(level_comm_costs.len()).unwrap();
1149        let mut variance = F::zero();
1150
1151        for &cost in &level_comm_costs {
1152            let diff = cost - mean_comm;
1153            variance += diff * diff;
1154        }
1155
1156        variance /= F::from(level_comm_costs.len()).unwrap();
1157        let std_dev = variance.sqrt();
1158
1159        let coeff_var = if mean_comm > F::zero() {
1160            std_dev / mean_comm
1161        } else {
1162            F::zero()
1163        };
1164
1165        // Convert to balance score
1166        let balance = F::one() - coeff_var.min(F::one());
1167        balance.max(F::zero())
1168    }
1169
1170    /// Calculate memory distribution balance
1171    fn calculate_memory_balance(&self) -> F {
1172        let mut level_memory_usage: Vec<F> = Vec::new();
1173        let mut total_memory = F::zero();
1174
1175        for level in &self.mesh_hierarchy.levels {
1176            // Estimate memory usage: cells + solution data + neighbor lists
1177            let cell_count = level.cells.len();
1178            let total_neighbors: usize = level.cells.values().map(|c| c.neighbors.len()).sum();
1179
1180            let solution_size: usize = level.cells.values().map(|c| c.solution.len()).sum();
1181
1182            // Memory estimate (simplified)
1183            let memory_estimate = F::from(cell_count + total_neighbors + solution_size).unwrap();
1184            level_memory_usage.push(memory_estimate);
1185            total_memory += memory_estimate;
1186        }
1187
1188        if level_memory_usage.is_empty() || total_memory <= F::zero() {
1189            return F::one();
1190        }
1191
1192        // Calculate memory distribution balance
1193        let mean_memory = total_memory / F::from(level_memory_usage.len()).unwrap();
1194        let mut variance = F::zero();
1195
1196        for &memory in &level_memory_usage {
1197            let diff = memory - mean_memory;
1198            variance += diff * diff;
1199        }
1200
1201        variance /= F::from(level_memory_usage.len()).unwrap();
1202        let std_dev = variance.sqrt();
1203
1204        let coeff_var = if mean_memory > F::zero() {
1205            std_dev / mean_memory
1206        } else {
1207            F::zero()
1208        };
1209
1210        // Convert to balance score
1211        let balance = F::one() - coeff_var.min(F::one());
1212        balance.max(F::zero())
1213    }
1214}
1215
1216// Refinement criterion implementations
1217impl<F: IntegrateFloat + Send + Sync> RefinementCriterion<F> for GradientRefinementCriterion<F> {
1218    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F {
1219        if neighbors.is_empty() {
1220            return F::zero();
1221        }
1222
1223        let mut max_gradient = F::zero();
1224
1225        for neighbor in neighbors {
1226            let gradient = if let Some(comp) = self.component {
1227                if comp < cell.solution.len() && comp < neighbor.solution.len() {
1228                    (cell.solution[comp] - neighbor.solution[comp]).abs() / cell.size
1229                } else {
1230                    F::zero()
1231                }
1232            } else {
1233                // Use L2 norm of solution difference
1234                let diff = &cell.solution - &neighbor.solution;
1235                diff.mapv(|x| x.powi(2)).sum().sqrt() / cell.size
1236            };
1237
1238            max_gradient = max_gradient.max(gradient);
1239        }
1240
1241        // Relative criterion
1242        let solution_magnitude = if let Some(comp) = self.component {
1243            cell.solution
1244                .get(comp)
1245                .map(|&x| x.abs())
1246                .unwrap_or(F::zero())
1247        } else {
1248            cell.solution.mapv(|x| x.abs()).sum()
1249        };
1250
1251        if solution_magnitude > F::zero() {
1252            max_gradient / solution_magnitude
1253        } else {
1254            max_gradient
1255        }
1256    }
1257
1258    fn name(&self) -> &'static str {
1259        "Gradient"
1260    }
1261}
1262
1263impl<F: IntegrateFloat + Send + Sync> RefinementCriterion<F> for FeatureDetectionCriterion<F> {
1264    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F {
1265        let mut feature_score = F::zero();
1266
1267        for &feature_type in &self.feature_types {
1268            match feature_type {
1269                FeatureType::SharpGradient => {
1270                    // Detect sharp gradients
1271                    if neighbors.len() >= 2 {
1272                        let gradients: Vec<F> = neighbors
1273                            .iter()
1274                            .map(|n| (&cell.solution - &n.solution).mapv(|x| x.abs()).sum())
1275                            .collect();
1276
1277                        let max_grad = gradients.iter().fold(F::zero(), |acc, &x| acc.max(x));
1278                        let avg_grad = gradients.iter().fold(F::zero(), |acc, &x| acc + x)
1279                            / F::from(gradients.len()).unwrap();
1280
1281                        if avg_grad > F::zero() {
1282                            feature_score += max_grad / avg_grad;
1283                        }
1284                    }
1285                }
1286                FeatureType::LocalExtremum => {
1287                    // Detect local extrema
1288                    let cell_value = cell.solution.mapv(|x| x.abs()).sum();
1289                    let mut is_extremum = true;
1290
1291                    for neighbor in neighbors {
1292                        let neighbor_value = neighbor.solution.mapv(|x| x.abs()).sum();
1293                        if (neighbor_value - cell_value).abs() < self.threshold {
1294                            is_extremum = false;
1295                            break;
1296                        }
1297                    }
1298
1299                    if is_extremum {
1300                        feature_score += F::one();
1301                    }
1302                }
1303                _ => {
1304                    // Other feature types would be implemented here
1305                }
1306            }
1307        }
1308
1309        feature_score
1310    }
1311
1312    fn name(&self) -> &'static str {
1313        "FeatureDetection"
1314    }
1315}
1316
1317impl<F: IntegrateFloat + Send + Sync> RefinementCriterion<F> for CurvatureRefinementCriterion<F> {
1318    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F {
1319        if neighbors.len() < 2 {
1320            return F::zero();
1321        }
1322
1323        // Estimate curvature using second differences
1324        let mut curvature = F::zero();
1325
1326        for component in 0..cell.solution.len() {
1327            let center_value = cell.solution[component];
1328            let neighbor_values: Vec<F> = neighbors
1329                .iter()
1330                .filter_map(|n| n.solution.get(component).copied())
1331                .collect();
1332
1333            if neighbor_values.len() >= 2 {
1334                // Simple second difference approximation
1335                let avg_neighbor = neighbor_values.iter().fold(F::zero(), |acc, &x| acc + x)
1336                    / F::from(neighbor_values.len()).unwrap();
1337
1338                let second_diff = (avg_neighbor - center_value).abs() / (cell.size * cell.size);
1339                curvature += second_diff;
1340            }
1341        }
1342
1343        curvature
1344    }
1345
1346    fn name(&self) -> &'static str {
1347        "Curvature"
1348    }
1349}
1350
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an active, parentless level-0 test cell of size 0.1 centered at
    /// `(x, 0.5)` with a single-component solution `value`.
    fn leaf_cell(index: usize, x: f64, value: f64) -> AdaptiveCell<f64> {
        AdaptiveCell {
            id: CellId { level: 0, index },
            center: Array1::from_vec(vec![x, 0.5]),
            size: 0.1,
            solution: Array1::from_vec(vec![value]),
            error_estimate: 0.0,
            refinement_flag: RefinementFlag::None,
            is_active: true,
            neighbors: Vec::new(),
            parent: None,
            children: Vec::new(),
        }
    }

    #[test]
    fn test_amr_manager_creation() {
        let coarsest = AdaptiveMeshLevel {
            level: 0,
            cells: HashMap::new(),
            grid_spacing: 1.0,
            boundary_cells: HashSet::new(),
        };

        let manager = AdvancedAMRManager::new(coarsest, 5, 0.01);

        assert_eq!(manager.mesh_hierarchy.levels.len(), 1);
        assert_eq!(manager.max_levels, 5);
    }

    #[test]
    fn test_gradient_criterion() {
        let cell = leaf_cell(0, 0.5, 1.0);
        let neighbor = leaf_cell(1, 0.6, 2.0);

        let criterion = GradientRefinementCriterion {
            component: Some(0),
            threshold: 1.0,
            relative_tolerance: 0.1,
        };

        let score = criterion.evaluate(&cell, &[&neighbor]);
        assert!(score > 0.0);
    }
}