// rusty_tree/octree/adaptive_octree.rs

//! Data structures and functions for adaptive octrees.
2
3use super::{Octree, OctreeType, Statistics};
4use ndarray::{Array1, ArrayView2, ArrayViewMut1, Axis};
5use rusty_kernel_tools::RealType;
6use std::collections::{HashMap, HashSet};
7use std::time::Instant;
8
/// Selects whether an adaptive octree is 2:1 balanced after it has been refined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BalanceMode {
    /// Use for unbalanced adaptive octree.
    Unbalanced,
    /// Use for balanced adaptive octree.
    Balanced,
}
15
16fn refine_tree<T: RealType>(
17    key: usize,
18    refine_indices: &HashSet<usize>,
19    mut particle_keys: ArrayViewMut1<usize>,
20    particles: ArrayView2<T>,
21    max_particles: usize,
22    origin: &[f64; 3],
23    diameter: &[f64; 3],
24) {
25    use crate::morton::{encode_point, find_level};
26
27    let level = find_level(key);
28
29    if (level == 16) | (refine_indices.len() < max_particles) {
30        // Do not refine if we have reached level cap or
31        // we are already below the particle limit.
32        return;
33    }
34    let mut new_keys = HashSet::<usize>::new();
35
36    for &particle_index in refine_indices {
37        let particle = [
38            particles[[0, particle_index]].to_f64().unwrap(),
39            particles[[1, particle_index]].to_f64().unwrap(),
40            particles[[2, particle_index]].to_f64().unwrap(),
41        ];
42
43        let particle_key = encode_point(&particle, 1 + level, origin, diameter);
44        particle_keys[particle_index] = particle_key;
45        new_keys.insert(particle_key);
46    }
47
48    for new_key in new_keys {
49        let associated_indices: HashSet<usize> = refine_indices
50            .iter()
51            .copied()
52            .filter(|&item| particle_keys[item] == new_key)
53            .collect();
54        refine_tree(
55            new_key,
56            &associated_indices,
57            particle_keys.view_mut(),
58            particles,
59            max_particles,
60            origin,
61            diameter,
62        );
63    }
64}
65
66/// Create a adaptive octree.
67///
68/// Returns a `AdaptiveOctree` struct describing an adaptive octree.
69///
70/// # Arguments
71/// * `particles` - A (3, N) array of particles of type f32 or f64.
72/// * `max_particles` - The maximum number of particles in each leaf.
73/// * `balance_mode` - Use `Balanced` for a 2:1 balanced octree, `Unbalanced` otherwise.
74pub fn adaptive_octree<T: RealType>(
75    particles: ArrayView2<T>,
76    max_particles: usize,
77    balance_mode: BalanceMode,
78) -> Octree<'_, T> {
79    use crate::helpers::compute_bounds;
80
81    const TOL: f64 = 1E-5;
82
83    let bounds = compute_bounds(particles);
84    let diameter = [
85        (bounds[0][1] - bounds[0][0]).to_f64().unwrap() * (1.0 + TOL),
86        (bounds[1][1] - bounds[1][0]).to_f64().unwrap() * (1.0 + TOL),
87        (bounds[2][1] - bounds[2][0]).to_f64().unwrap() * (1.0 + TOL),
88    ];
89
90    let origin = [
91        bounds[0][0].to_f64().unwrap(),
92        bounds[1][0].to_f64().unwrap(),
93        bounds[2][0].to_f64().unwrap(),
94    ];
95
96    adaptive_octree_with_bounding_box(particles, max_particles, origin, diameter, balance_mode)
97}
98
99/// Create an adaptive Octree with given bounding box.
100///
101/// Returns a `AdaptiveOctree` struct describing an adaptive octree.
102///
103/// # Arguments
104/// * `particles` - A (3, N) array of particles of type f32 or f64.
105/// * `max_particles` - Maximum number of particles.
106/// * `origin` - The origin of the bounding box.
107/// * `diameter` - The diameter of the bounding box in each dimension.
108/// * `balance_mode` - Use `Balanced` for a 2:1 balanced octree, `Unbalanced` otherwise.
109pub fn adaptive_octree_with_bounding_box<T: RealType>(
110    particles: ArrayView2<T>,
111    max_particles: usize,
112    origin: [f64; 3],
113    diameter: [f64; 3],
114    balance_mode: BalanceMode,
115) -> Octree<'_, T> {
116    use super::{
117        compute_interaction_list_map, compute_leaf_map, compute_level_information,
118        compute_near_field_map,
119    };
120
121    let number_of_particles = particles.len_of(Axis(1));
122
123    let now = Instant::now();
124
125    // First build up the non-adaptive tree by continuous refinement.
126
127    let mut particle_keys = Array1::<usize>::zeros(number_of_particles);
128    let refine_indices: HashSet<usize> = (0..number_of_particles).collect();
129
130    refine_tree(
131        0,
132        &refine_indices,
133        particle_keys.view_mut(),
134        particles,
135        max_particles,
136        &origin,
137        &diameter,
138    );
139
140    let (max_level, mut all_keys, mut level_keys) = compute_level_information(particle_keys.view());
141
142    match &balance_mode {
143        BalanceMode::Balanced => balance_tree(
144            &mut level_keys,
145            particle_keys.view_mut(),
146            particles,
147            &mut all_keys,
148            &origin,
149            &diameter,
150        ),
151        _ => (),
152    }
153
154    let leaf_key_to_particles = compute_leaf_map(particle_keys.view());
155
156    let near_field = compute_near_field_map(&all_keys);
157    let interaction_list = compute_interaction_list_map(&all_keys);
158
159    let duration = now.elapsed();
160
161    let statistics = Statistics {
162        number_of_particles: particles.len_of(Axis(1)),
163        max_level,
164        number_of_leafs: leaf_key_to_particles.keys().len(),
165        number_of_keys: all_keys.len(),
166        creation_time: duration,
167        minimum_number_of_particles_in_leaf: leaf_key_to_particles
168            .values()
169            .map(|item| item.len())
170            .reduce(std::cmp::min)
171            .unwrap(),
172        maximum_number_of_particles_in_leaf: leaf_key_to_particles
173            .values()
174            .map(|item| item.len())
175            .reduce(std::cmp::max)
176            .unwrap(),
177        average_number_of_particles_in_leaf: (leaf_key_to_particles
178            .values()
179            .map(|item| item.len())
180            .sum::<usize>() as f64)
181            / (leaf_key_to_particles.keys().len() as f64),
182    };
183
184    Octree {
185        particles,
186        particle_keys,
187        max_level,
188        origin,
189        diameter,
190        leaf_key_to_particles,
191        level_keys,
192        interaction_list,
193        near_field,
194        all_keys,
195        octree_type: match &balance_mode {
196            BalanceMode::Balanced => OctreeType::BalancedAdaptive,
197            BalanceMode::Unbalanced => OctreeType::UnbalancedAdaptive,
198        },
199        statistics,
200    }
201}
202
203/// Take a key and add the key and all its ancestors to the tree
204fn find_completion(
205    mut key: usize,
206    level_keys: &mut HashMap<usize, HashSet<usize>>,
207    all_keys: &mut HashSet<usize>,
208) {
209    use crate::morton::{find_level, find_parent};
210
211    let mut intermediate_keys = HashSet::<usize>::new();
212    let mut level = find_level(key);
213    while !all_keys.contains(&key) {
214        intermediate_keys.insert(key);
215        level_keys.get_mut(&level).unwrap().insert(key);
216        level = level - 1;
217        key = find_parent(key);
218    }
219
220    all_keys.extend(intermediate_keys);
221}
222
223fn balance_tree<T: RealType>(
224    level_keys: &mut HashMap<usize, HashSet<usize>>,
225    mut particle_keys: ArrayViewMut1<usize>,
226    particles: ArrayView2<T>,
227    all_keys: &mut HashSet<usize>,
228    origin: &[f64; 3],
229    diameter: &[f64; 3],
230) {
231    use super::compute_complete_regular_tree;
232    use crate::morton::{compute_near_field, encode_point, find_level, find_parent};
233
234    let max_level = level_keys.keys().max().unwrap().clone();
235    let nlevels = 1 + max_level;
236
237    let regular_tree = compute_complete_regular_tree(particles, max_level, origin, diameter);
238
239    for level in (1..nlevels).rev() {
240        let current_keys: HashSet<usize> =
241            level_keys.get(&level).unwrap().iter().copied().collect();
242        for key in current_keys {
243            let near_field = compute_near_field(key);
244            for near_field_key in near_field {
245                let parent = find_parent(near_field_key);
246                // Only fill up if there can actually be particles in the parent
247                // of the neighbour.
248                if regular_tree.contains(&parent) {
249                    find_completion(parent, level_keys, all_keys);
250                }
251            }
252        }
253    }
254
255    // Now adapt the particle keys.
256    for (particle_index, key) in particle_keys.iter_mut().enumerate() {
257        let particle = [
258            particles[[0, particle_index]].to_f64().unwrap(),
259            particles[[1, particle_index]].to_f64().unwrap(),
260            particles[[2, particle_index]].to_f64().unwrap(),
261        ];
262
263        let mut current_level = find_level(*key);
264
265        while current_level < max_level {
266            let descendent_key = encode_point(&particle, current_level + 1, origin, diameter);
267
268            if all_keys.contains(&descendent_key) {
269                *key = descendent_key;
270                current_level += 1;
271            } else {
272                break;
273            }
274        }
275    }
276}