1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3
4use crate::{iter::twig_len_pad, Node, Object, Point, RTree, TWIG_LEN};
5
/// Default value for the `node_len` argument of [`RTree::new`].
///
/// NOTE(review): the tests in this file compare the built tree against
/// `rstar::RTree::bulk_load` using this value, so it presumably mirrors rstar's default
/// maximum node size; confirm before changing it.
pub const DEF_NODE_LEN: usize = 6;
8
9impl<O> RTree<O>
10where
11 O: Object,
12{
13 pub fn new(node_len: usize, objects: Vec<O>) -> Self {
19 assert!(node_len > 1);
20 assert!(!objects.is_empty());
21
22 let mut nodes = Vec::new();
23
24 let root_idx = build(node_len, objects, &mut nodes, &mut Vec::new());
25 debug_assert_eq!(root_idx, nodes.len() - 1);
26
27 nodes.reverse();
29
30 for node in &mut nodes {
31 if let Node::Twig(twig) = node {
32 for idx in twig {
33 *idx = root_idx - *idx;
34 }
35 }
36 }
37
38 Self {
39 nodes: nodes.into_boxed_slice(),
40 _marker: PhantomData,
41 }
42 }
43}
44
/// Builds the subtree containing `objects` and returns the index of its branch node.
///
/// All nodes are appended to `nodes`; the children of a branch are always appended before
/// the branch itself. `next_nodes` is a scratch stack shared by all recursion levels: each
/// call only touches the part above the length it found on entry and truncates back to
/// that length before returning.
///
/// # Panics
///
/// Panics if two coordinates cannot be ordered (`partial_cmp` returns `None`), and, via
/// `add_branch`, if `objects` is empty.
fn build<O>(
    node_len: usize,
    objects: Vec<O>,
    nodes: &mut Vec<Node<O>>,
    next_nodes: &mut Vec<usize>,
) -> usize
where
    O: Object,
{
    // Everything pushed above this mark belongs to the branch created by this call.
    let next_nodes_len = next_nodes.len();

    if objects.len() > node_len {
        // Too many objects for a single branch: split them into clusters along every axis
        // and build one child subtree per cluster. At least two clusters per axis are
        // used so that every cluster is strictly smaller than its input.
        let num_clusters = num_clusters(node_len, O::Point::DIM, objects.len()).max(2);

        // Work item: a set of objects and the number of axes still left to split along.
        struct State<O> {
            objects: Vec<O>,
            axis: usize,
        }

        let mut state = vec![State {
            objects,
            axis: O::Point::DIM,
        }];

        // Explicit stack instead of recursion over the axes; the pop order determines the
        // order in which child subtrees are appended to `nodes`.
        while let Some(State {
            mut objects,
            mut axis,
        }) = state.pop()
        {
            if axis != 0 {
                axis -= 1;

                // Ceiling division: the size of each slab along this axis.
                let cluster_len = (objects.len() + num_clusters - 1) / num_clusters;

                // Repeatedly split off the `cluster_len` objects with the largest
                // coordinate along `axis` until at most one slab is left.
                while objects.len() > cluster_len {
                    let split_off = objects.len() - cluster_len;

                    // Partition by the `axis` coordinate of the first AABB corner (the
                    // one `merge_aabb` combines with `min`).
                    objects.select_nth_unstable_by(split_off, |lhs, rhs| {
                        let lhs = lhs.aabb().0.coord(axis);
                        let rhs = rhs.aabb().0.coord(axis);
                        lhs.partial_cmp(&rhs).unwrap()
                    });

                    state.push(State {
                        objects: objects.drain(split_off..).collect(),
                        axis,
                    });
                }

                // The remaining slab is pushed last and therefore popped first.
                if !objects.is_empty() {
                    // The vector kept the capacity of the whole input; release it.
                    objects.shrink_to_fit();

                    state.push(State { objects, axis });
                }
            } else {
                // All axes were split: this cluster becomes a child subtree.
                let node = build(node_len, objects, nodes, next_nodes);
                next_nodes.push(node);
            }
        }
    } else {
        // Few enough objects: store them as leaves directly below this branch.
        next_nodes.extend(nodes.len()..nodes.len() + objects.len());
        nodes.extend(objects.into_iter().map(Node::Leaf));
    }

    // Emit the branch over the children collected by this call, then restore the shared
    // scratch stack for the caller.
    let node = add_branch(nodes, &next_nodes[next_nodes_len..]);
    next_nodes.truncate(next_nodes_len);
    node
}
116
/// Computes how many clusters to create along each axis when splitting `num_objects`
/// objects for a tree whose branches hold up to `node_len` children.
///
/// The tree depth is estimated as `ceil(log_node_len(num_objects))`, the number of
/// subtrees needed below the current node is derived from that, and the result is the
/// `point_dim`-th root of that number, rounded down. All arithmetic is done in `f32`.
/// The result may be smaller than two; the caller is expected to clamp it.
fn num_clusters(node_len: usize, point_dim: usize, num_objects: usize) -> usize {
    let fanout = node_len as f32;
    let total = num_objects as f32;

    // Depth of a tree whose nodes are all filled up to `node_len`.
    let levels = total.log(fanout).ceil() as i32;

    // Capacity of one subtree directly below the current node.
    let per_subtree = fanout.powi(levels - 1);
    let subtrees = (total / per_subtree).ceil();

    // Spread the subtrees evenly over all axes.
    let per_axis = subtrees.powf((point_dim as f32).recip());
    per_axis.floor() as usize
}
129
130fn add_branch<O>(nodes: &mut Vec<Node<O>>, next_nodes: &[usize]) -> usize
131where
132 O: Object,
133{
134 let len = NonZeroUsize::new(next_nodes.len()).unwrap();
135
136 let aabb = merge_aabb(nodes, next_nodes);
137
138 {
139 let (len, pad) = twig_len_pad(&len);
141
142 nodes.reserve(len + 1);
143
144 let mut twig = [0; TWIG_LEN];
145 let mut pos = TWIG_LEN;
146
147 for next_node in next_nodes.iter().rev() {
149 pos -= 1;
150 twig[pos] = *next_node;
151
152 if pos == 0 {
153 nodes.push(Node::Twig(twig));
154 pos = TWIG_LEN;
155 }
156 }
157
158 if pos != TWIG_LEN {
159 debug_assert_eq!(pos, pad);
160 nodes.push(Node::Twig(twig));
161 }
162 }
163
164 let node = nodes.len();
165 nodes.push(Node::Branch { len, aabb });
166 node
167}
168
169fn merge_aabb<O>(nodes: &[Node<O>], next_nodes: &[usize]) -> (O::Point, O::Point)
170where
171 O: Object,
172{
173 next_nodes
174 .iter()
175 .map(|idx| match &nodes[*idx] {
176 Node::Branch { aabb, .. } => aabb.clone(),
177 Node::Twig(_) => unreachable!(),
178 Node::Leaf(obj) => obj.aabb(),
179 })
180 .reduce(|mut res, aabb| {
181 res.0 = res.0.min(&aabb.0);
182 res.1 = res.1.max(&aabb.1);
183
184 res
185 })
186 .unwrap()
187}
188
#[cfg(test)]
mod tests {
    //! Compares the layout produced by `RTree::new` with the one produced by
    //! `rstar::RTree::bulk_load` for the same objects.

    use super::*;

    use std::ops::ControlFlow;

    use proptest::test_runner::TestRunner;

    use crate::{
        iter::branch_for_each,
        tests::{random_objects, RandomObject},
    };

    // Lets rstar index the same objects; the envelope is built from the two fields of
    // `RandomObject`, used as corners.
    impl rstar::RTreeObject for RandomObject {
        type Envelope = rstar::AABB<[f32; 3]>;

        fn envelope(&self) -> Self::Envelope {
            rstar::AABB::from_corners(self.0, self.1)
        }
    }

    /// Walks the subtree whose branch is stored at `nodes[idx]` depth-first.
    ///
    /// Pushes the child count of every branch and `0` for every leaf onto `branches`,
    /// and every visited object onto `leaves`.
    fn collect_index<'a>(
        nodes: &'a [Node<RandomObject>],
        idx: usize,
        branches: &mut Vec<usize>,
        leaves: &mut Vec<&'a RandomObject>,
    ) {
        // `rest` is everything stored after the branch node.
        let [node, rest @ ..] = &nodes[idx..] else {
            unreachable!()
        };
        let len = match node {
            Node::Branch { len, .. } => len,
            Node::Twig(_) | Node::Leaf(_) => unreachable!(),
        };
        branches.push(len.get());
        // Presumably invokes the closure with the index of each child of the branch;
        // see `crate::iter::branch_for_each` to confirm.
        branch_for_each(len, rest, |idx| {
            match &nodes[idx] {
                Node::Branch { .. } => collect_index(nodes, idx, branches, leaves),
                Node::Twig(_) => unreachable!(),
                Node::Leaf(obj) => {
                    branches.push(0);
                    leaves.push(obj);
                }
            }
            ControlFlow::<()>::Continue(())
        })
        .continue_value()
        .unwrap();
    }

    /// Counterpart of `collect_index` for an rstar tree: records the child count of every
    /// parent node, `0` for every leaf, and the objects in visiting order.
    fn collect_rstar_index<'a>(
        node: &'a rstar::ParentNode<RandomObject>,
        branches: &mut Vec<usize>,
        leaves: &mut Vec<&'a RandomObject>,
    ) {
        let children = node.children();
        branches.push(children.len());
        for child in children {
            match child {
                rstar::RTreeNode::Parent(node) => collect_rstar_index(node, branches, leaves),
                rstar::RTreeNode::Leaf(obj) => {
                    branches.push(0);
                    leaves.push(obj);
                }
            }
        }
    }

    /// For generated object sets, the tree built here must have the same sequence of
    /// child counts and the same leaf order as the tree bulk-loaded by rstar.
    #[test]
    fn random_trees() {
        TestRunner::default()
            .run(&random_objects(100), |objects| {
                let index = RTree::new(DEF_NODE_LEN, objects.clone());

                let mut branches = Vec::new();
                let mut leaves = Vec::new();

                // The root is stored at index zero.
                collect_index(&index, 0, &mut branches, &mut leaves);

                let rstar_index = rstar::RTree::bulk_load(objects);

                let mut rstar_branches = Vec::new();
                let mut rstar_leaves = Vec::new();

                collect_rstar_index(rstar_index.root(), &mut rstar_branches, &mut rstar_leaves);

                assert_eq!(branches, rstar_branches);
                assert_eq!(leaves, rstar_leaves);

                Ok(())
            })
            .unwrap();
    }
}