1use super::Octree;
2use len_trait::{Clear, Empty, Len};
3use num::One;
4use std::{
5 borrow::{Borrow, BorrowMut},
6 collections::HashMap,
7 hash::Hash,
8 mem,
9 ops::{Add, Div, Sub},
10};
11
/// An [`Octree`] whose nodes carry [`ManagedOctreeData`]: a collection `D` of
/// items plus the node's bounds, using coordinates of type `S`.
pub type ManagedOctree<D, S> = Octree<ManagedOctreeData<D, S>>;
/// A managed octree whose nodes store their items in a `Vec<T>`.
pub type ManagedVecOctree<T, S> = ManagedOctree<Vec<T>, S>;
/// A managed octree whose nodes store `(key, value)` pairs in a `HashMap`;
/// a node rejects a pair whose key it already stores.
pub type ManagedHashMapOctree<K, V, S> = ManagedOctree<HashMap<K, V>, S>;
15
/// A collection that a managed octree node stores its items in.
pub trait OctreeCollection<I> {
    /// Attempts to store `item`.
    ///
    /// Returns `Some(())` if the item was stored and `None` if the collection
    /// rejected it (the rejected item is dropped).
    fn add(&mut self, item: I) -> Option<()>;
}
20
/// An item with a 3D position, used to decide which octant of a node it
/// belongs to.
pub trait CentredItem<S> {
    /// Returns the item's `(x, y, z)` centre.
    fn centre(&self) -> (S, S, S);
}
24
25impl<S> CentredItem<S> for (S, S, S)
26where
27 S: Copy,
28{
29 fn centre(&self) -> (S, S, S) { *self }
30}
31
32impl<S, K> CentredItem<S> for (K, (S, S, S))
33where
34 S: Copy,
35{
36 fn centre(&self) -> (S, S, S) { self.1 }
37}
38
39impl<I> OctreeCollection<I> for Vec<I> {
40 fn add(&mut self, item: I) -> Option<()> {
41 self.push(item);
42 Some(())
43 }
44}
45
46impl<K, V> OctreeCollection<(K, V)> for HashMap<K, V>
47where
48 K: Eq + Hash,
49{
50 fn add(&mut self, (key, val): (K, V)) -> Option<()> {
51 if self.contains_key(&key) {
52 return None;
53 }
54 self.insert(key, val);
55 Some(())
56 }
57}
58
/// Per-node payload of a managed octree: the node's cubic bounds, its size
/// thresholds, an item counter and the items stored directly at the node.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct, so the type can be named without repeating them everywhere.
pub struct ManagedOctreeData<D, S> {
    /// Centre of this node's cube.
    centre: (S, S, S),
    /// Half the side length of this node's cube; a child's half-length is
    /// half of this value.
    half_length: S,
    /// Maximum number of items this node keeps directly before `rebalance`
    /// creates children to hold the excess.
    max_size: usize,
    /// Threshold set via `with_drop_below_size` (which rejects 0) and copied to
    /// new children; presumably meant for merging sparse children back
    /// (TODO confirm: not read by any visible logic).
    drop_below_size: usize,
    /// Item counter maintained by `add`, `clear_data` and rebalancing, and
    /// reported by the `Len` impls.
    len: usize,
    /// Items stored directly at this node.
    data: D,
}
71
72impl<D, S> Default for ManagedOctreeData<D, S>
73where
74 D: Default + Empty + Len,
75 S: Default
76 + Copy
77 + One
78 + Add<S, Output = S>
79 + Sub<S, Output = S>
80 + Div<S, Output = S>,
81{
82 fn default() -> Self {
83 Self {
84 centre: (S::default(), S::default(), S::default()),
85 half_length: S::one(),
86 max_size: 1,
87 drop_below_size: 1,
88 len: 0,
89 data: D::default(),
90 }
91 }
92}
93
94impl<D, S> ManagedOctreeData<D, S>
95where
96 D: Default + Empty + Len,
97 S: Default
98 + Copy
99 + One
100 + Add<S, Output = S>
101 + Sub<S, Output = S>
102 + Div<S, Output = S>,
103{
104 #[must_use]
106 pub fn get_data(&self) -> &D { self.data.borrow() }
107
108 #[must_use]
110 pub fn get_data_mut(&mut self) -> &mut D { self.data.borrow_mut() }
111}
112
113impl<D, S, T> ManagedOctree<D, S>
114where
115 D: Default
116 + Empty
117 + Len
118 + Clear
119 + IntoIterator<Item = T>
120 + OctreeCollection<T>,
121 T: CentredItem<S>,
122 S: Default
123 + Copy
124 + One
125 + PartialOrd
126 + Add<S, Output = S>
127 + Sub<S, Output = S>
128 + Div<S, Output = S>,
129{
130 #[must_use]
131 pub fn new_managed(centre: (S, S, S), half_length: S) -> Self {
132 Self::new_with_data(ManagedOctreeData {
133 centre,
134 half_length,
135 ..ManagedOctreeData::default()
136 })
137 }
138
139 #[must_use]
141 pub fn with_max_size(mut self, max_size: usize) -> Self {
142 self.data.max_size = max_size;
143 self
144 }
145
146 #[must_use]
150 pub fn with_drop_below_size(mut self, drop_below_size: usize) -> Self {
151 if drop_below_size == 0 {
152 panic!("drop_below_size must be greater than 0");
153 }
154
155 self.data.drop_below_size = drop_below_size;
156 self
157 }
158
159 pub fn add(&mut self, item: T) {
161 self.data.data.add(item);
162 self.data.len += 1;
163 }
164
165 pub fn clear_data(&mut self) {
167 self.data.len -= self.data.data.len();
168 self.data.data.clear()
169 }
170
171 pub fn rebalance(&mut self) {
172 let bucket_counts = self.move_to_existing_children();
173 if self.data.data.len() <= self.data.max_size {
174 return;
175 }
176 let bucket_sizes = Self::sort_bucket_sizes(bucket_counts);
177 let mut new_size = self.data.data.len();
178 for (max_idx, max_val) in bucket_sizes {
179 let (px, py, pz) = Self::get_child_pos_at_idx(max_idx);
180 let (centre, half_length) =
181 self.get_child_centre_and_half_length_at_pos(px, py, pz);
182 self.add_child(
183 max_idx,
184 Self::new_managed(centre, half_length)
185 .with_max_size(self.data.max_size)
186 .with_drop_below_size(self.data.drop_below_size),
187 )
188 .unwrap();
189 new_size -= max_val;
190 if new_size <= self.data.max_size {
191 break;
192 }
193 }
194 self.move_to_existing_children();
195 }
196
197 fn sort_bucket_sizes(sizes: [usize; 8]) -> Vec<(usize, usize)> {
198 let mut bucket_sizes: Vec<(usize, usize)> =
199 sizes.iter().enumerate().map(|(i, &v)| (i, v)).collect();
200 bucket_sizes.sort_unstable_by(|(_ai, am), (_bi, bm)| {
201 bm.partial_cmp(am).unwrap()
202 });
203 bucket_sizes
204 }
205
206 fn move_to_existing_children(&mut self) -> [usize; 8] {
209 let (cx, cy, cz) = self.data.centre;
210
211 let mut result = [0; 8];
212 let mut old_d = D::default();
213 mem::swap(&mut old_d, &mut self.data.data);
214 for item in old_d {
215 let (ix, iy, iz) = item.centre();
216 let idx = Self::get_child_idx_at_pos(ix > cx, iy > cy, iz > cz);
217 if let Some(child) = &mut self.children[idx] {
218 child.add(item);
219 } else {
220 self.add(item);
221 result[idx] += 1;
222 }
223 }
224
225 result
226 }
227
228 fn get_child_centre_and_half_length_at_pos(
229 &self,
230 pos_x: bool,
231 pos_y: bool,
232 pos_z: bool,
233 ) -> ((S, S, S), S) {
234 let (cx, cy, cz) = self.data.centre;
235 let hhl = self.data.half_length / (S::one() + S::one());
236 match (pos_x, pos_y, pos_z) {
237 (false, false, false) => ((cx - hhl, cy - hhl, cz - hhl), (hhl)),
238 (false, false, true) => ((cx - hhl, cy - hhl, cz + hhl), (hhl)),
239 (false, true, false) => ((cx - hhl, cy + hhl, cz - hhl), (hhl)),
240 (false, true, true) => ((cx - hhl, cy + hhl, cz + hhl), (hhl)),
241 (true, false, false) => ((cx + hhl, cy - hhl, cz - hhl), (hhl)),
242 (true, false, true) => ((cx + hhl, cy - hhl, cz + hhl), (hhl)),
243 (true, true, false) => ((cx + hhl, cy + hhl, cz - hhl), (hhl)),
244 (true, true, true) => ((cx + hhl, cy + hhl, cz + hhl), (hhl)),
245 }
246 }
247}
248
249impl<T, S> Empty for ManagedVecOctree<T, S>
250where
251 S: Default
252 + Copy
253 + One
254 + Add<S, Output = S>
255 + Sub<S, Output = S>
256 + Div<S, Output = S>,
257{
258 fn is_empty(&self) -> bool { self.data.len == 0 }
259}
260
261impl<T, S> Len for ManagedVecOctree<T, S>
262where
263 S: Default
264 + Copy
265 + One
266 + Add<S, Output = S>
267 + Sub<S, Output = S>
268 + Div<S, Output = S>,
269{
270 fn len(&self) -> usize { self.data.len }
271}
272
273impl<K, V, S> Empty for ManagedHashMapOctree<K, V, S>
274where
275 K: Eq + Hash,
276 S: Default
277 + Copy
278 + One
279 + Add<S, Output = S>
280 + Sub<S, Output = S>
281 + Div<S, Output = S>,
282{
283 fn is_empty(&self) -> bool { self.data.len == 0 }
284}
285
286impl<K, V, S> Len for ManagedHashMapOctree<K, V, S>
287where
288 K: Eq + Hash,
289 S: Default
290 + Copy
291 + One
292 + Add<S, Output = S>
293 + Sub<S, Output = S>
294 + Div<S, Output = S>,
295{
296 fn len(&self) -> usize { self.data.len }
297}
298
299#[cfg(test)]
300mod tests {
301 use super::{ManagedHashMapOctree, ManagedVecOctree};
302 use len_trait::Len;
303
304 #[test]
305 fn test_with_drop_below_size() {
306 let o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
307 (0.0, 0.0, 0.0),
308 1000.0,
309 )
310 .with_drop_below_size(3);
311 assert_eq!(o.data.drop_below_size, 3);
312 }
313
314 #[test]
315 #[should_panic]
316 fn test_with_drop_below_size_0_panics() {
317 let _ = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
318 (0.0, 0.0, 0.0),
319 1000.0,
320 )
321 .with_drop_below_size(0);
322 }
323
324 #[test]
325 fn test_with_max_size() {
326 let o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
327 (0.0, 0.0, 0.0),
328 1000.0,
329 )
330 .with_max_size(3);
331 assert_eq!(o.data.max_size, 3);
332 }
333
334 #[test]
335 fn test_get_child_centre_and_half_length_neg() {
336 let o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
337 (0.0, 0.0, 0.0),
338 1000.0,
339 );
340 let ((cx, cy, cz), half_length) =
341 o.get_child_centre_and_half_length_at_pos(false, false, false);
342 assert_relative_eq!(cx, -500.0);
343 assert_relative_eq!(cy, -500.0);
344 assert_relative_eq!(cz, -500.0);
345 assert_relative_eq!(half_length, 500.0);
346 }
347
348 #[test]
349 fn test_get_child_centre_and_half_length_pos() {
350 let o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
351 (0.0, 0.0, 0.0),
352 1000.0,
353 );
354 let ((cx, cy, cz), half_length) =
355 o.get_child_centre_and_half_length_at_pos(true, true, true);
356 assert_relative_eq!(cx, 500.0);
357 assert_relative_eq!(cy, 500.0);
358 assert_relative_eq!(cz, 500.0);
359 assert_relative_eq!(half_length, 500.0);
360 }
361
362 #[test]
363 fn test_get_child_centre_and_half_length_partial_pos_off_centre() {
364 let o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
365 (100.0, 200.0, 300.0),
366 1000.0,
367 );
368 let ((cx, cy, cz), half_length) =
369 o.get_child_centre_and_half_length_at_pos(true, false, true);
370 assert_relative_eq!(cx, 600.0);
371 assert_relative_eq!(cy, -300.0);
372 assert_relative_eq!(cz, 800.0);
373 assert_relative_eq!(half_length, 500.0);
374 }
375
376 #[test]
377 fn test_vec_add() {
378 let mut o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
379 (0.0, 0.0, 0.0),
380 1000.0,
381 );
382 assert_eq!(o.len(), 0);
383 o.add((123.45, 234.567, 345.678));
384 assert_eq!(o.len(), 1);
385 }
386
387 #[test]
388 fn test_hash_add() {
389 let mut o =
390 ManagedHashMapOctree::<u32, (f32, f32, f32), f32>::new_managed(
391 (0.0, 0.0, 0.0),
392 1000.0,
393 );
394 assert_eq!(o.len(), 0);
395 o.add((123, (123.45, 234.567, 345.678)));
396 assert_eq!(o.len(), 1);
397 }
398
399 #[test]
400 fn test_rebalance_max_2() {
401 let mut o = ManagedVecOctree::<(f32, f32, f32), f32>::new_managed(
402 (0.0, 0.0, 0.0),
403 1000.0,
404 )
405 .with_max_size(2);
406 o.add((1.0, 1.0, 1.0));
407 o.add((2.0, 2.0, 1.0));
408 o.add((-1.0, -1.0, -1.0));
409 o.rebalance();
410 assert_eq!(o.data.data.len(), 1);
411 assert!(o.get_child_at_pos(true, true, true).is_some());
412 assert!(o.get_child_at_pos(false, false, false).is_none());
413 assert_eq!(
414 o.get_child_at_pos(true, true, true)
415 .unwrap()
416 .data
417 .data
418 .len(),
419 2
420 );
421 }
422}