1use super::{Octree, OctreeType, Statistics};
4use ndarray::{Array1, ArrayView2, ArrayViewMut1, Axis};
5use rusty_kernel_tools::RealType;
6use std::collections::{HashMap, HashSet};
7use std::time::Instant;
8
/// Selects whether an adaptive octree is post-processed by the balancing pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BalanceMode {
    /// Keep the tree exactly as produced by adaptive refinement.
    Unbalanced,
    /// After refinement, run `balance_tree`, which may insert additional boxes
    /// and move particles into deeper leaves.
    Balanced,
}
15
16fn refine_tree<T: RealType>(
17 key: usize,
18 refine_indices: &HashSet<usize>,
19 mut particle_keys: ArrayViewMut1<usize>,
20 particles: ArrayView2<T>,
21 max_particles: usize,
22 origin: &[f64; 3],
23 diameter: &[f64; 3],
24) {
25 use crate::morton::{encode_point, find_level};
26
27 let level = find_level(key);
28
29 if (level == 16) | (refine_indices.len() < max_particles) {
30 return;
33 }
34 let mut new_keys = HashSet::<usize>::new();
35
36 for &particle_index in refine_indices {
37 let particle = [
38 particles[[0, particle_index]].to_f64().unwrap(),
39 particles[[1, particle_index]].to_f64().unwrap(),
40 particles[[2, particle_index]].to_f64().unwrap(),
41 ];
42
43 let particle_key = encode_point(&particle, 1 + level, origin, diameter);
44 particle_keys[particle_index] = particle_key;
45 new_keys.insert(particle_key);
46 }
47
48 for new_key in new_keys {
49 let associated_indices: HashSet<usize> = refine_indices
50 .iter()
51 .copied()
52 .filter(|&item| particle_keys[item] == new_key)
53 .collect();
54 refine_tree(
55 new_key,
56 &associated_indices,
57 particle_keys.view_mut(),
58 particles,
59 max_particles,
60 origin,
61 diameter,
62 );
63 }
64}
65
66pub fn adaptive_octree<T: RealType>(
75 particles: ArrayView2<T>,
76 max_particles: usize,
77 balance_mode: BalanceMode,
78) -> Octree<'_, T> {
79 use crate::helpers::compute_bounds;
80
81 const TOL: f64 = 1E-5;
82
83 let bounds = compute_bounds(particles);
84 let diameter = [
85 (bounds[0][1] - bounds[0][0]).to_f64().unwrap() * (1.0 + TOL),
86 (bounds[1][1] - bounds[1][0]).to_f64().unwrap() * (1.0 + TOL),
87 (bounds[2][1] - bounds[2][0]).to_f64().unwrap() * (1.0 + TOL),
88 ];
89
90 let origin = [
91 bounds[0][0].to_f64().unwrap(),
92 bounds[1][0].to_f64().unwrap(),
93 bounds[2][0].to_f64().unwrap(),
94 ];
95
96 adaptive_octree_with_bounding_box(particles, max_particles, origin, diameter, balance_mode)
97}
98
99pub fn adaptive_octree_with_bounding_box<T: RealType>(
110 particles: ArrayView2<T>,
111 max_particles: usize,
112 origin: [f64; 3],
113 diameter: [f64; 3],
114 balance_mode: BalanceMode,
115) -> Octree<'_, T> {
116 use super::{
117 compute_interaction_list_map, compute_leaf_map, compute_level_information,
118 compute_near_field_map,
119 };
120
121 let number_of_particles = particles.len_of(Axis(1));
122
123 let now = Instant::now();
124
125 let mut particle_keys = Array1::<usize>::zeros(number_of_particles);
128 let refine_indices: HashSet<usize> = (0..number_of_particles).collect();
129
130 refine_tree(
131 0,
132 &refine_indices,
133 particle_keys.view_mut(),
134 particles,
135 max_particles,
136 &origin,
137 &diameter,
138 );
139
140 let (max_level, mut all_keys, mut level_keys) = compute_level_information(particle_keys.view());
141
142 match &balance_mode {
143 BalanceMode::Balanced => balance_tree(
144 &mut level_keys,
145 particle_keys.view_mut(),
146 particles,
147 &mut all_keys,
148 &origin,
149 &diameter,
150 ),
151 _ => (),
152 }
153
154 let leaf_key_to_particles = compute_leaf_map(particle_keys.view());
155
156 let near_field = compute_near_field_map(&all_keys);
157 let interaction_list = compute_interaction_list_map(&all_keys);
158
159 let duration = now.elapsed();
160
161 let statistics = Statistics {
162 number_of_particles: particles.len_of(Axis(1)),
163 max_level,
164 number_of_leafs: leaf_key_to_particles.keys().len(),
165 number_of_keys: all_keys.len(),
166 creation_time: duration,
167 minimum_number_of_particles_in_leaf: leaf_key_to_particles
168 .values()
169 .map(|item| item.len())
170 .reduce(std::cmp::min)
171 .unwrap(),
172 maximum_number_of_particles_in_leaf: leaf_key_to_particles
173 .values()
174 .map(|item| item.len())
175 .reduce(std::cmp::max)
176 .unwrap(),
177 average_number_of_particles_in_leaf: (leaf_key_to_particles
178 .values()
179 .map(|item| item.len())
180 .sum::<usize>() as f64)
181 / (leaf_key_to_particles.keys().len() as f64),
182 };
183
184 Octree {
185 particles,
186 particle_keys,
187 max_level,
188 origin,
189 diameter,
190 leaf_key_to_particles,
191 level_keys,
192 interaction_list,
193 near_field,
194 all_keys,
195 octree_type: match &balance_mode {
196 BalanceMode::Balanced => OctreeType::BalancedAdaptive,
197 BalanceMode::Unbalanced => OctreeType::UnbalancedAdaptive,
198 },
199 statistics,
200 }
201}
202
203fn find_completion(
205 mut key: usize,
206 level_keys: &mut HashMap<usize, HashSet<usize>>,
207 all_keys: &mut HashSet<usize>,
208) {
209 use crate::morton::{find_level, find_parent};
210
211 let mut intermediate_keys = HashSet::<usize>::new();
212 let mut level = find_level(key);
213 while !all_keys.contains(&key) {
214 intermediate_keys.insert(key);
215 level_keys.get_mut(&level).unwrap().insert(key);
216 level = level - 1;
217 key = find_parent(key);
218 }
219
220 all_keys.extend(intermediate_keys);
221}
222
223fn balance_tree<T: RealType>(
224 level_keys: &mut HashMap<usize, HashSet<usize>>,
225 mut particle_keys: ArrayViewMut1<usize>,
226 particles: ArrayView2<T>,
227 all_keys: &mut HashSet<usize>,
228 origin: &[f64; 3],
229 diameter: &[f64; 3],
230) {
231 use super::compute_complete_regular_tree;
232 use crate::morton::{compute_near_field, encode_point, find_level, find_parent};
233
234 let max_level = level_keys.keys().max().unwrap().clone();
235 let nlevels = 1 + max_level;
236
237 let regular_tree = compute_complete_regular_tree(particles, max_level, origin, diameter);
238
239 for level in (1..nlevels).rev() {
240 let current_keys: HashSet<usize> =
241 level_keys.get(&level).unwrap().iter().copied().collect();
242 for key in current_keys {
243 let near_field = compute_near_field(key);
244 for near_field_key in near_field {
245 let parent = find_parent(near_field_key);
246 if regular_tree.contains(&parent) {
249 find_completion(parent, level_keys, all_keys);
250 }
251 }
252 }
253 }
254
255 for (particle_index, key) in particle_keys.iter_mut().enumerate() {
257 let particle = [
258 particles[[0, particle_index]].to_f64().unwrap(),
259 particles[[1, particle_index]].to_f64().unwrap(),
260 particles[[2, particle_index]].to_f64().unwrap(),
261 ];
262
263 let mut current_level = find_level(*key);
264
265 while current_level < max_level {
266 let descendent_key = encode_point(&particle, current_level + 1, origin, diameter);
267
268 if all_keys.contains(&descendent_key) {
269 *key = descendent_key;
270 current_level += 1;
271 } else {
272 break;
273 }
274 }
275 }
276}