Skip to main content

simple_octree/
managed_octree.rs

1use super::Octree;
2use len_trait::{Clear, Empty, Len};
3use num::One;
4use std::{
5    borrow::{Borrow, BorrowMut},
6    collections::HashMap,
7    hash::Hash,
8    mem,
9    ops::{Add, Div, Sub},
10};
11
12pub type ManagedOctree<D, S> = Octree<ManagedOctreeData<D, S>>;
13pub type ManagedVecOctree<T, S> = ManagedOctree<Vec<T>, S>;
14pub type ManagedHashMapOctree<K, V, S> = ManagedOctree<HashMap<K, V>, S>;
15
/// Abstraction over the collection stored in each octree node, so the tree
/// can treat `Vec`, `HashMap` and similar containers generically.
pub trait OctreeCollection<I> {
    /// Inserts `entry` into the collection.
    ///
    /// Returns `Some(())` when the entry was stored and `None` when the
    /// collection refused it (the `HashMap` impl in this module refuses a key
    /// it already contains).
    fn add(&mut self, entry: I) -> Option<()>;
}
20
/// An item with a position in 3D space; the managed octree uses this
/// position to decide which octant of a node the item belongs to.
pub trait CentredItem<S> {
    /// Returns the item's `(x, y, z)` position.
    fn centre(&self) -> (S, S, S);
}
24
25impl<S> CentredItem<S> for (S, S, S)
26where
27    S: Copy,
28{
29    fn centre(&self) -> (S, S, S) { *self }
30}
31
32impl<S, K> CentredItem<S> for (K, (S, S, S))
33where
34    S: Copy,
35{
36    fn centre(&self) -> (S, S, S) { self.1 }
37}
38
39impl<I> OctreeCollection<I> for Vec<I> {
40    fn add(&mut self, item: I) -> Option<()> {
41        self.push(item);
42        Some(())
43    }
44}
45
46impl<K, V> OctreeCollection<(K, V)> for HashMap<K, V>
47where
48    K: Eq + Hash,
49{
50    fn add(&mut self, (key, val): (K, V)) -> Option<()> {
51        if self.contains_key(&key) {
52            return None;
53        }
54        self.insert(key, val);
55        Some(())
56    }
57}
58
/// Per-node payload of a managed octree: the node's bounds, its capacity
/// settings, and the collection of items stored directly in the node.
///
/// Trait bounds are placed on the `impl` blocks that need them rather than
/// on the struct, so the type itself places no requirements on `D` or `S`.
pub struct ManagedOctreeData<D, S> {
    /// Centre of the node's cube as `(x, y, z)`.
    centre: (S, S, S),
    /// Half the edge length of the node's cube; child octants are offset from
    /// the centre by half of this value.
    half_length: S,
    /// Maximum number of items the node holds directly before `rebalance`
    /// creates children.
    max_size: usize,
    /// Threshold copied into child nodes on creation.
    /// NOTE(review): presumably meant for merging children back; not read
    /// anywhere else in this module.
    drop_below_size: usize,
    /// Number of items tracked for this node's subtree (items stored here
    /// plus items pushed down into children); reported by the `Len` impls.
    len: usize,
    /// Items stored directly in this node.
    data: D,
}
71
72impl<D, S> Default for ManagedOctreeData<D, S>
73where
74    D: Default + Empty + Len,
75    S: Default
76        + Copy
77        + One
78        + Add<S, Output = S>
79        + Sub<S, Output = S>
80        + Div<S, Output = S>,
81{
82    fn default() -> Self {
83        Self {
84            centre: (S::default(), S::default(), S::default()),
85            half_length: S::one(),
86            max_size: 1,
87            drop_below_size: 1,
88            len: 0,
89            data: D::default(),
90        }
91    }
92}
93
94impl<D, S> ManagedOctreeData<D, S>
95where
96    D: Default + Empty + Len,
97    S: Default
98        + Copy
99        + One
100        + Add<S, Output = S>
101        + Sub<S, Output = S>
102        + Div<S, Output = S>,
103{
104    /// Gets a reference to the underlying data in the node.
105    #[must_use]
106    pub fn get_data(&self) -> &D { self.data.borrow() }
107
108    /// Gets a mutable reference to the underlying data in the node.
109    #[must_use]
110    pub fn get_data_mut(&mut self) -> &mut D { self.data.borrow_mut() }
111}
112
113impl<D, S, T> ManagedOctree<D, S>
114where
115    D: Default
116        + Empty
117        + Len
118        + Clear
119        + IntoIterator<Item = T>
120        + OctreeCollection<T>,
121    T: CentredItem<S>,
122    S: Default
123        + Copy
124        + One
125        + PartialOrd
126        + Add<S, Output = S>
127        + Sub<S, Output = S>
128        + Div<S, Output = S>,
129{
130    #[must_use]
131    pub fn new_managed(centre: (S, S, S), half_length: S) -> Self {
132        Self::new_with_data(ManagedOctreeData {
133            centre,
134            half_length,
135            ..ManagedOctreeData::default()
136        })
137    }
138
139    /// Set `max_size`
140    #[must_use]
141    pub fn with_max_size(mut self, max_size: usize) -> Self {
142        self.data.max_size = max_size;
143        self
144    }
145
146    /// Set `drop_below_size`
147    ///
148    /// Panics when set to 0
149    #[must_use]
150    pub fn with_drop_below_size(mut self, drop_below_size: usize) -> Self {
151        if drop_below_size == 0 {
152            panic!("drop_below_size must be greater than 0");
153        }
154
155        self.data.drop_below_size = drop_below_size;
156        self
157    }
158
159    /// Adds data to the node without flushing/rebalancing the tree.
160    pub fn add(&mut self, item: T) {
161        self.data.data.add(item);
162        self.data.len += 1;
163    }
164
165    /// Clears data from the node (not the whole tree)
166    pub fn clear_data(&mut self) {
167        self.data.len -= self.data.data.len();
168        self.data.data.clear()
169    }
170
171    pub fn rebalance(&mut self) {
172        let bucket_counts = self.move_to_existing_children();
173        if self.data.data.len() <= self.data.max_size {
174            return;
175        }
176        let bucket_sizes = Self::sort_bucket_sizes(bucket_counts);
177        let mut new_size = self.data.data.len();
178        for (max_idx, max_val) in bucket_sizes {
179            let (px, py, pz) = Self::get_child_pos_at_idx(max_idx);
180            let (centre, half_length) =
181                self.get_child_centre_and_half_length_at_pos(px, py, pz);
182            self.add_child(
183                max_idx,
184                Self::new_managed(centre, half_length)
185                    .with_max_size(self.data.max_size)
186                    .with_drop_below_size(self.data.drop_below_size),
187            )
188            .unwrap();
189            new_size -= max_val;
190            if new_size <= self.data.max_size {
191                break;
192            }
193        }
194        self.move_to_existing_children();
195    }
196
197    fn sort_bucket_sizes(sizes: [usize; 8]) -> Vec<(usize, usize)> {
198        let mut bucket_sizes: Vec<(usize, usize)> =
199            sizes.iter().enumerate().map(|(i, &v)| (i, v)).collect();
200        bucket_sizes.sort_unstable_by(|(_ai, am), (_bi, bm)| {
201            bm.partial_cmp(am).unwrap()
202        });
203        bucket_sizes
204    }
205
206    /// Moves any objects that should belong to a child to that child if it
207    /// exists. Returns the bucket sizes of any remaining items.
208    fn move_to_existing_children(&mut self) -> [usize; 8] {
209        let (cx, cy, cz) = self.data.centre;
210
211        let mut result = [0; 8];
212        let mut old_d = D::default();
213        mem::swap(&mut old_d, &mut self.data.data);
214        for item in old_d {
215            let (ix, iy, iz) = item.centre();
216            let idx = Self::get_child_idx_at_pos(ix > cx, iy > cy, iz > cz);
217            if let Some(child) = &mut self.children[idx] {
218                child.add(item);
219            } else {
220                self.add(item);
221                result[idx] += 1;
222            }
223        }
224
225        result
226    }
227
228    fn get_child_centre_and_half_length_at_pos(
229        &self,
230        pos_x: bool,
231        pos_y: bool,
232        pos_z: bool,
233    ) -> ((S, S, S), S) {
234        let (cx, cy, cz) = self.data.centre;
235        let hhl = self.data.half_length / (S::one() + S::one());
236        match (pos_x, pos_y, pos_z) {
237            (false, false, false) => ((cx - hhl, cy - hhl, cz - hhl), (hhl)),
238            (false, false, true) => ((cx - hhl, cy - hhl, cz + hhl), (hhl)),
239            (false, true, false) => ((cx - hhl, cy + hhl, cz - hhl), (hhl)),
240            (false, true, true) => ((cx - hhl, cy + hhl, cz + hhl), (hhl)),
241            (true, false, false) => ((cx + hhl, cy - hhl, cz - hhl), (hhl)),
242            (true, false, true) => ((cx + hhl, cy - hhl, cz + hhl), (hhl)),
243            (true, true, false) => ((cx + hhl, cy + hhl, cz - hhl), (hhl)),
244            (true, true, true) => ((cx + hhl, cy + hhl, cz + hhl), (hhl)),
245        }
246    }
247}
248
249impl<T, S> Empty for ManagedVecOctree<T, S>
250where
251    S: Default
252        + Copy
253        + One
254        + Add<S, Output = S>
255        + Sub<S, Output = S>
256        + Div<S, Output = S>,
257{
258    fn is_empty(&self) -> bool { self.data.len == 0 }
259}
260
261impl<T, S> Len for ManagedVecOctree<T, S>
262where
263    S: Default
264        + Copy
265        + One
266        + Add<S, Output = S>
267        + Sub<S, Output = S>
268        + Div<S, Output = S>,
269{
270    fn len(&self) -> usize { self.data.len }
271}
272
273impl<K, V, S> Empty for ManagedHashMapOctree<K, V, S>
274where
275    K: Eq + Hash,
276    S: Default
277        + Copy
278        + One
279        + Add<S, Output = S>
280        + Sub<S, Output = S>
281        + Div<S, Output = S>,
282{
283    fn is_empty(&self) -> bool { self.data.len == 0 }
284}
285
286impl<K, V, S> Len for ManagedHashMapOctree<K, V, S>
287where
288    K: Eq + Hash,
289    S: Default
290        + Copy
291        + One
292        + Add<S, Output = S>
293        + Sub<S, Output = S>
294        + Div<S, Output = S>,
295{
296    fn len(&self) -> usize { self.data.len }
297}
298
#[cfg(test)]
mod tests {
    use super::{ManagedHashMapOctree, ManagedVecOctree};
    use len_trait::Len;

    type PointTree = ManagedVecOctree<(f32, f32, f32), f32>;

    /// An empty point tree centred on the origin with half length 1000.
    fn origin_tree() -> PointTree {
        PointTree::new_managed((0.0, 0.0, 0.0), 1000.0)
    }

    #[test]
    fn test_with_drop_below_size() {
        let o = origin_tree().with_drop_below_size(3);
        assert_eq!(o.data.drop_below_size, 3);
    }

    // Pin the message so an unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "drop_below_size must be greater than 0")]
    fn test_with_drop_below_size_0_panics() {
        let _ = origin_tree().with_drop_below_size(0);
    }

    #[test]
    fn test_with_max_size() {
        let o = origin_tree().with_max_size(3);
        assert_eq!(o.data.max_size, 3);
    }

    #[test]
    fn test_get_child_centre_and_half_length_neg() {
        let o = origin_tree();
        let ((cx, cy, cz), half_length) =
            o.get_child_centre_and_half_length_at_pos(false, false, false);
        assert_relative_eq!(cx, -500.0);
        assert_relative_eq!(cy, -500.0);
        assert_relative_eq!(cz, -500.0);
        assert_relative_eq!(half_length, 500.0);
    }

    #[test]
    fn test_get_child_centre_and_half_length_pos() {
        let o = origin_tree();
        let ((cx, cy, cz), half_length) =
            o.get_child_centre_and_half_length_at_pos(true, true, true);
        assert_relative_eq!(cx, 500.0);
        assert_relative_eq!(cy, 500.0);
        assert_relative_eq!(cz, 500.0);
        assert_relative_eq!(half_length, 500.0);
    }

    #[test]
    fn test_get_child_centre_and_half_length_partial_pos_off_centre() {
        let o = PointTree::new_managed((100.0, 200.0, 300.0), 1000.0);
        let ((cx, cy, cz), half_length) =
            o.get_child_centre_and_half_length_at_pos(true, false, true);
        assert_relative_eq!(cx, 600.0);
        assert_relative_eq!(cy, -300.0);
        assert_relative_eq!(cz, 800.0);
        assert_relative_eq!(half_length, 500.0);
    }

    #[test]
    fn test_vec_add() {
        let mut o = origin_tree();
        assert_eq!(o.len(), 0);
        o.add((123.45, 234.567, 345.678));
        assert_eq!(o.len(), 1);
    }

    #[test]
    fn test_hash_add() {
        let mut o =
            ManagedHashMapOctree::<u32, (f32, f32, f32), f32>::new_managed(
                (0.0, 0.0, 0.0),
                1000.0,
            );
        assert_eq!(o.len(), 0);
        o.add((123, (123.45, 234.567, 345.678)));
        assert_eq!(o.len(), 1);
    }

    #[test]
    fn test_rebalance_max_2() {
        let mut o = origin_tree().with_max_size(2);
        o.add((1.0, 1.0, 1.0));
        o.add((2.0, 2.0, 1.0));
        o.add((-1.0, -1.0, -1.0));
        o.rebalance();
        assert_eq!(o.data.data.len(), 1);
        assert!(o.get_child_at_pos(true, true, true).is_some());
        assert!(o.get_child_at_pos(false, false, false).is_none());
        let child = o.get_child_at_pos(true, true, true).unwrap();
        assert_eq!(child.data.data.len(), 2);
    }
}