toolbox_rs/
r_tree.rs

1use log::debug;
2use num::integer::Roots;
3use std::{cmp::Ordering, collections::BinaryHeap};
4use thiserror::Error;
5
6const BRANCHING_FACTOR: usize = 30;
7const LEAF_PACK_FACTOR: usize = 30;
8
9use crate::{
10    bounding_box::BoundingBox, geometry::FPCoordinate, partition_id::PartitionID,
11    space_filling_curve::zorder_cmp,
12};
13
/// Error values declared for R-tree operations.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
///
/// NOTE(review): none of the code visible in this part of the file constructs
/// these variants — confirm where (or whether) they are raised.
#[derive(Error, Debug)]
pub enum RTreeError {
    /// Displayed as `Empty tree`.
    #[error("Empty tree")]
    EmptyTree,
    /// Displayed as `Invalid coordinate`.
    #[error("Invalid coordinate")]
    InvalidCoordinate,
    /// Carries a `usize` that is interpolated into the message
    /// `Node index out of bounds: {0}`.
    #[error("Node index out of bounds: {0}")]
    InvalidNodeIndex(usize),
}
23
/// A bucket of elements stored at the bottom of the tree.
///
/// `RTree::from_elements` creates one `Leaf` per chunk of at most
/// `LEAF_PACK_FACTOR` consecutive (z-order sorted) elements.
#[derive(Clone, Debug)]
pub struct Leaf<T> {
    // Box accumulated by extending an initially invalid box with the `bbox()`
    // of every element in `elements` (see `RTree::from_elements`).
    bbox: BoundingBox,
    // The elements held by this leaf.
    elements: Vec<T>,
}
29
30impl<T: RTreeElement> Leaf<T> {
31    pub fn new(bbox: BoundingBox, elements: Vec<T>) -> Self {
32        Self { bbox, elements }
33    }
34
35    #[must_use]
36    pub fn bbox(&self) -> &BoundingBox {
37        &self.bbox
38    }
39
40    #[must_use]
41    pub fn elements(&self) -> &[T] {
42        &self.elements
43    }
44}
45
/// Bottom-level search node: summarises a run of consecutive leaves.
///
/// `RTree::from_elements` creates one `LeafNode` per chunk of at most
/// `BRANCHING_FACTOR` entries of the leaf array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LeafNode {
    // Box extended with the bounding box of every leaf in the chunk.
    bbox: BoundingBox,
    // Index into the tree's leaf array of the first leaf of the chunk
    // (`BRANCHING_FACTOR * chunk_number` at construction time).
    index: usize,
}
51
52impl LeafNode {
53    pub fn new(bbox: BoundingBox, index: usize) -> Self {
54        Self { bbox, index }
55    }
56}
57
/// Interior search node: summarises a run of consecutive search nodes of the
/// level below it.
#[derive(Clone, Copy, Debug)]
pub struct TreeNode {
    // Box extended with the bounding box of every child in the run.
    bbox: BoundingBox,
    // Position in the tree's search-node array of the first child
    // (`level_start + BRANCHING_FACTOR * chunk_number` at construction time).
    index: usize,
}
63
/// Entry of the tree's search-node array.
#[derive(Clone, Copy, Debug)]
enum SearchNode {
    /// Node whose `index` refers to the leaf array.
    LeafNode(LeafNode),
    /// Node whose `index` refers back into the search-node array.
    TreeNode(TreeNode),
}
69
/// Kind of entry waiting in the nearest-neighbour search queue.
#[derive(Debug, PartialEq)]
enum QueueNodeType {
    /// Interior node; the entry's index points into the search-node array.
    TreeNode,
    /// Leaf-level node; the entry's index points into the leaf array.
    LeafNode,
    /// A single element; the payload is its offset inside its leaf's element list.
    Candidate(usize),
}

/// Entry of the priority queue that drives the nearest-neighbour search.
///
/// Entries compare by `distance` only, and in reverse, so that
/// `std::collections::BinaryHeap` (a max-heap) pops the entry with the
/// smallest distance first.
#[derive(Debug)]
struct QueueElement {
    distance: f64,
    child_start_index: usize,
    node_type: QueueNodeType,
}

impl QueueElement {
    /// Creates a new queue element for the R-tree search.
    ///
    /// # Arguments
    ///
    /// * `distance` - Minimum possible distance to search point
    /// * `child_start_index` - Starting index of children in the node array;
    ///   for `Candidate` entries, the index of the leaf that is read when the
    ///   candidate is returned
    /// * `node_type` - Type of the node (Tree, Leaf or Candidate)
    pub fn new(distance: f64, child_start_index: usize, node_type: QueueNodeType) -> Self {
        Self {
            distance,
            child_start_index,
            node_type,
        }
    }
}

impl PartialEq for QueueElement {
    /// Two entries are equal when `cmp` considers them equal. Deriving equality
    /// from the total order keeps it reflexive (as `Eq` requires) even for NaN.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl PartialOrd for QueueElement {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for QueueElement {}

impl Ord for QueueElement {
    /// Reverse comparison of the distances: the smaller distance is the
    /// "greater" entry.
    ///
    /// Uses `f64::total_cmp`, which is defined for every pair of floats, so a
    /// NaN distance can no longer panic inside the heap. Note that the total
    /// order distinguishes `-0.0` from `0.0`.
    fn cmp(&self, other: &Self) -> Ordering {
        other.distance.total_cmp(&self.distance)
    }
}
120
/// Trait for elements that can be stored in an RTree
///
/// How the tree uses each method (as visible in this file):
/// * `center` is the sort key for the z-order sort done before packing,
/// * `bbox` is merged into the bounding box of the leaf holding the element,
/// * `distance_to` ranks individual elements during nearest-neighbour iteration.
pub trait RTreeElement {
    /// Returns the bounding box of this element
    fn bbox(&self) -> BoundingBox;

    /// Returns the distance from this element to the given coordinate
    fn distance_to(&self, coordinate: &FPCoordinate) -> f64;

    /// Returns the center coordinate of this element
    fn center(&self) -> &FPCoordinate;
}
132
/// Packed R-tree over elements of type `T`, built once by
/// `RTree::from_elements`.
#[derive(Debug)]
pub struct RTree<T: RTreeElement> {
    // Element buckets, in the z-order produced by the construction sort.
    leaf_nodes: Vec<Leaf<T>>,
    // Search structure stored level by level, bottom level first: `LeafNode`
    // entries, then each level of `TreeNode` entries built on top of the
    // previous one. The last entry is the root.
    search_nodes: Vec<SearchNode>,
}
138
139impl<T: RTreeElement + std::clone::Clone> RTree<T> {
140    /// Creates a new R-tree from an iterator of elements.
141    ///
142    /// # Arguments
143    ///
144    /// * `elements` - Iterator of elements to build the tree from
145    ///
146    /// # Returns
147    ///
148    /// A new R-tree instance with the provided data organized in a hierarchical structure
149    ///
150    /// # Implementation Details
151    ///
152    /// 1. Creates leaf nodes with up to LEAF_PACK_FACTOR elements
153    /// 2. Builds interior nodes with up to BRANCHING_FACTOR children
154    /// 3. Constructs tree bottom-up level by level
155    #[must_use]
156    pub fn from_elements<I>(elements: I) -> Self
157    where
158        I: IntoIterator<Item = T>,
159    {
160        let mut elements: Vec<_> = elements.into_iter().collect();
161        debug!("Creating R-tree from {} elements", elements.len());
162        debug!("sorting by z-order");
163        elements.sort_by(|a, b| zorder_cmp(a.center(), b.center()));
164
165        let estimated_leaf_nodes = elements.len().div_ceil(LEAF_PACK_FACTOR);
166        let estimated_search_nodes = estimated_leaf_nodes * 2; // Rough estimate for tree structure
167
168        let mut search_nodes = Vec::with_capacity(estimated_search_nodes);
169        let mut next = Vec::with_capacity(estimated_leaf_nodes);
170
171        // Create leaf nodes
172        let leaf_nodes = elements
173            .chunks(LEAF_PACK_FACTOR)
174            .map(|chunk| {
175                let bbox = chunk.iter().fold(BoundingBox::invalid(), |mut acc, elem| {
176                    acc.extend_with(&elem.bbox());
177                    acc
178                });
179                Leaf::new(bbox, chunk.to_vec())
180            })
181            .collect::<Vec<_>>();
182
183        search_nodes.extend(leaf_nodes.chunks(BRANCHING_FACTOR).enumerate().map(
184            |(index, chunk)| {
185                let bbox = chunk.iter().fold(BoundingBox::invalid(), |acc, leaf| {
186                    let mut bbox = acc;
187                    bbox.extend_with(leaf.bbox());
188                    bbox
189                });
190                SearchNode::LeafNode(LeafNode::new(bbox, BRANCHING_FACTOR * index))
191            },
192        ));
193
194        debug!("Created {} search nodes", search_nodes.len());
195
196        let mut start = 0;
197        let mut end = search_nodes.len();
198
199        let mut level = 0;
200        debug!("Creating tree nodes, start {start}, end {end}");
201        while start < end - 1 {
202            debug!(
203                "level: {}, packing {} nodes [{}]",
204                level,
205                search_nodes.len(),
206                (end - start)
207            );
208            level += 1;
209            search_nodes[start..end]
210                .chunks(BRANCHING_FACTOR)
211                .enumerate()
212                .for_each(|(index, node)| {
213                    let bbox = node.iter().fold(BoundingBox::invalid(), |acc, node| {
214                        let mut bbox = acc;
215                        match node {
216                            SearchNode::LeafNode(leaf) => {
217                                bbox.extend_with(&leaf.bbox);
218                            }
219                            SearchNode::TreeNode(tree) => {
220                                bbox.extend_with(&tree.bbox);
221                            }
222                        }
223                        bbox
224                    });
225                    next.push(SearchNode::TreeNode(TreeNode {
226                        bbox,
227                        index: start + (BRANCHING_FACTOR * index),
228                    }));
229                });
230            start = end;
231            end += next.len();
232            search_nodes.append(&mut next);
233            next.clear();
234        }
235
236        debug!("Created {} search nodes", search_nodes.len());
237        debug!(
238            "Created tree with {} levels, {} leaf nodes, {} total nodes",
239            level,
240            leaf_nodes.len(),
241            search_nodes.len()
242        );
243
244        RTree {
245            leaf_nodes,
246            search_nodes,
247        }
248    }
249
250    /// Returns an iterator over elements in ascending order of distance from the given coordinate
251    pub fn nearest_iter<'a>(&'a self, coordinate: &'a FPCoordinate) -> RTreeNearestIterator<'a, T> {
252        RTreeNearestIterator::new(self, coordinate)
253    }
254}
255
256// Implement RTreeElement for the original (FPCoordinate, PartitionID) tuple
257impl RTreeElement for (FPCoordinate, PartitionID) {
258    fn bbox(&self) -> BoundingBox {
259        BoundingBox::from_coordinate(&self.0)
260    }
261
262    fn distance_to(&self, coordinate: &FPCoordinate) -> f64 {
263        self.0.distance_to(coordinate)
264    }
265
266    fn center(&self) -> &FPCoordinate {
267        &self.0
268    }
269}
270
/// Iterator state for a nearest-neighbour traversal of an `RTree`.
///
/// Created by `RTree::nearest_iter`.
#[derive(Debug)]
pub struct RTreeNearestIterator<'a, T: RTreeElement> {
    // Tree being searched.
    tree: &'a RTree<T>,
    // Query coordinate that all distances are measured against.
    input_coordinate: &'a FPCoordinate,
    // Pending nodes and candidate elements; `QueueElement`'s reversed ordering
    // makes the heap pop the smallest distance first.
    queue: BinaryHeap<QueueElement>,
}
277
278impl<'a, T: RTreeElement> RTreeNearestIterator<'a, T> {
279    fn new(tree: &'a RTree<T>, input_coordinate: &'a FPCoordinate) -> Self {
280        let capacity = (tree.leaf_nodes.len() * LEAF_PACK_FACTOR).sqrt();
281        let mut queue = BinaryHeap::with_capacity(capacity);
282
283        // Initialize with root node if tree is not empty
284        if let Some(SearchNode::TreeNode(root)) = tree.search_nodes.last() {
285            queue.push(QueueElement::new(
286                root.bbox.min_distance(input_coordinate),
287                root.index,
288                QueueNodeType::TreeNode,
289            ));
290        }
291
292        Self {
293            tree,
294            input_coordinate,
295            queue,
296        }
297    }
298}
299
300impl<T: RTreeElement + Clone> Iterator for RTreeNearestIterator<'_, T> {
301    /// Returns the next nearest element and its distance from the query point.
302    /// Elements are returned in ascending order of distance.
303    ///
304    /// # Returns
305    /// * `Some((element, distance))` - The next nearest element and its distance
306    /// * `None` - When all elements have been visited
307    type Item = (T, f64);
308
309    fn next(&mut self) -> Option<Self::Item> {
310        while let Some(QueueElement {
311            distance,
312            child_start_index,
313            node_type,
314        }) = self.queue.pop()
315        {
316            match node_type {
317                QueueNodeType::TreeNode => {
318                    let children_count =
319                        BRANCHING_FACTOR.min(self.tree.search_nodes.len() - 1 - child_start_index);
320                    for i in 0..children_count {
321                        match &self.tree.search_nodes[child_start_index + i] {
322                            SearchNode::LeafNode(node) => self.queue.push(QueueElement::new(
323                                node.bbox.min_distance(self.input_coordinate),
324                                node.index,
325                                QueueNodeType::LeafNode,
326                            )),
327                            SearchNode::TreeNode(node) => self.queue.push(QueueElement::new(
328                                node.bbox.min_distance(self.input_coordinate),
329                                node.index,
330                                QueueNodeType::TreeNode,
331                            )),
332                        }
333                    }
334                }
335                QueueNodeType::LeafNode => {
336                    for leaf_idx in 0..LEAF_PACK_FACTOR {
337                        let leaf = &self.tree.leaf_nodes[child_start_index + leaf_idx];
338                        for (elem_idx, elem) in leaf.elements().iter().enumerate() {
339                            let dist = elem.distance_to(self.input_coordinate);
340                            self.queue.push(QueueElement::new(
341                                dist,
342                                child_start_index,
343                                QueueNodeType::Candidate(elem_idx),
344                            ));
345                        }
346                    }
347                }
348                QueueNodeType::Candidate(offset) => {
349                    let element =
350                        self.tree.leaf_nodes[child_start_index].elements()[offset].clone();
351                    return Some((element, distance));
352                }
353            }
354        }
355        None
356    }
357}