Skip to main content

vibe_graph_layout_gpu/
quadtree.rs

//! Barnes-Hut quadtree for O(n log n) force approximation.
//!
//! The quadtree recursively subdivides space and computes the center of mass
//! of each cell. Distant cells can be approximated as single point masses,
//! reducing the O(n²) pairwise force calculation to O(n log n).
6
7use crate::{Position, QuadTreeNode};
8
9/// A Barnes-Hut quadtree for 2D spatial partitioning.
10#[derive(Debug)]
11pub struct QuadTree {
12    /// Flattened tree nodes for GPU upload
13    nodes: Vec<QuadTreeNode>,
14    /// Bounding box min
15    bounds_min: Position,
16    /// Bounding box max
17    bounds_max: Position,
18}
19
20impl QuadTree {
21    /// Build a quadtree from node positions.
22    ///
23    /// # Arguments
24    /// * `positions` - Slice of node positions
25    /// * `max_depth` - Maximum tree depth (typically 10-15)
26    pub fn build(positions: &[Position], max_depth: usize) -> Self {
27        if positions.is_empty() {
28            return Self {
29                nodes: vec![QuadTreeNode::default()],
30                bounds_min: Position::default(),
31                bounds_max: Position::default(),
32            };
33        }
34
35        // Find bounding box with some padding
36        let mut min_x = f32::MAX;
37        let mut min_y = f32::MAX;
38        let mut max_x = f32::MIN;
39        let mut max_y = f32::MIN;
40
41        for pos in positions {
42            min_x = min_x.min(pos.x);
43            min_y = min_y.min(pos.y);
44            max_x = max_x.max(pos.x);
45            max_y = max_y.max(pos.y);
46        }
47
48        // Add padding
49        let padding = ((max_x - min_x).max(max_y - min_y) * 0.1).max(1.0);
50        min_x -= padding;
51        min_y -= padding;
52        max_x += padding;
53        max_y += padding;
54
55        // Make it square
56        let width = (max_x - min_x).max(max_y - min_y);
57        let center_x = (min_x + max_x) / 2.0;
58        let center_y = (min_y + max_y) / 2.0;
59
60        let bounds_min = Position::new(center_x - width / 2.0, center_y - width / 2.0);
61        let bounds_max = Position::new(center_x + width / 2.0, center_y + width / 2.0);
62
63        // Build tree recursively
64        let mut nodes = Vec::with_capacity(positions.len() * 2);
65        let mut builder = TreeBuilder {
66            positions,
67            nodes: &mut nodes,
68            max_depth,
69        };
70
71        let indices: Vec<usize> = (0..positions.len()).collect();
72        builder.build_node(&indices, bounds_min.x, bounds_min.y, width, 0);
73
74        Self {
75            nodes,
76            bounds_min,
77            bounds_max,
78        }
79    }
80
81    /// Get the flattened tree nodes for GPU upload.
82    pub fn nodes(&self) -> &[QuadTreeNode] {
83        &self.nodes
84    }
85
86    /// Get the bounding box.
87    pub fn bounds(&self) -> (Position, Position) {
88        (self.bounds_min, self.bounds_max)
89    }
90}
91
92struct TreeBuilder<'a> {
93    positions: &'a [Position],
94    nodes: &'a mut Vec<QuadTreeNode>,
95    max_depth: usize,
96}
97
98impl<'a> TreeBuilder<'a> {
99    fn build_node(&mut self, indices: &[usize], x: f32, y: f32, width: f32, depth: usize) -> i32 {
100        if indices.is_empty() {
101            return -1;
102        }
103
104        let node_idx = self.nodes.len() as i32;
105        self.nodes.push(QuadTreeNode::default());
106
107        // Compute center of mass
108        let mut com_x = 0.0;
109        let mut com_y = 0.0;
110        let mass = indices.len() as f32;
111
112        for &i in indices {
113            com_x += self.positions[i].x;
114            com_y += self.positions[i].y;
115        }
116        com_x /= mass;
117        com_y /= mass;
118
119        // If leaf (single node or max depth), store as leaf
120        if indices.len() == 1 || depth >= self.max_depth {
121            self.nodes[node_idx as usize] = QuadTreeNode {
122                center_x: com_x,
123                center_y: com_y,
124                mass,
125                width,
126                child_nw: -1,
127                child_ne: -1,
128                child_sw: -1,
129                child_se: -1,
130            };
131            return node_idx;
132        }
133
134        // Subdivide into quadrants
135        let half_width = width / 2.0;
136        let mid_x = x + half_width;
137        let mid_y = y + half_width;
138
139        let mut nw_indices = Vec::new();
140        let mut ne_indices = Vec::new();
141        let mut sw_indices = Vec::new();
142        let mut se_indices = Vec::new();
143
144        for &i in indices {
145            let pos = &self.positions[i];
146            if pos.x < mid_x {
147                if pos.y < mid_y {
148                    sw_indices.push(i);
149                } else {
150                    nw_indices.push(i);
151                }
152            } else if pos.y < mid_y {
153                se_indices.push(i);
154            } else {
155                ne_indices.push(i);
156            }
157        }
158
159        // Recursively build children
160        let child_nw = self.build_node(&nw_indices, x, mid_y, half_width, depth + 1);
161        let child_ne = self.build_node(&ne_indices, mid_x, mid_y, half_width, depth + 1);
162        let child_sw = self.build_node(&sw_indices, x, y, half_width, depth + 1);
163        let child_se = self.build_node(&se_indices, mid_x, y, half_width, depth + 1);
164
165        // Update node
166        self.nodes[node_idx as usize] = QuadTreeNode {
167            center_x: com_x,
168            center_y: com_y,
169            mass,
170            width,
171            child_nw,
172            child_ne,
173            child_sw,
174            child_se,
175        };
176
177        node_idx
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Child links of `node` in `(nw, ne, sw, se)` order.
    fn children(node: &QuadTreeNode) -> (i32, i32, i32, i32) {
        (node.child_nw, node.child_ne, node.child_sw, node.child_se)
    }

    /// Four points at the corners of a 100x100 square.
    fn corners() -> Vec<Position> {
        vec![
            Position::new(0.0, 0.0),
            Position::new(100.0, 0.0),
            Position::new(0.0, 100.0),
            Position::new(100.0, 100.0),
        ]
    }

    #[test]
    fn test_empty_tree() {
        let tree = QuadTree::build(&[], 10);
        assert_eq!(tree.nodes().len(), 1);
    }

    #[test]
    fn test_single_node() {
        let positions = vec![Position::new(0.0, 0.0)];
        let tree = QuadTree::build(&positions, 10);
        // A single position becomes a childless root leaf.
        assert_eq!(tree.nodes().len(), 1);
        let root = &tree.nodes()[0];
        assert_eq!(root.mass, 1.0);
        assert_eq!((root.center_x, root.center_y), (0.0, 0.0));
        assert_eq!(children(root), (-1, -1, -1, -1));
    }

    #[test]
    fn test_multiple_nodes() {
        let positions = corners();
        let tree = QuadTree::build(&positions, 10);
        let nodes = tree.nodes();
        // Root plus one leaf per quadrant, laid out depth-first as NW, NE, SW, SE.
        assert_eq!(nodes.len(), 5);
        let root = &nodes[0];
        assert_eq!(root.mass, 4.0);
        assert_eq!((root.center_x, root.center_y), (50.0, 50.0));
        assert_eq!(children(root), (1, 2, 3, 4));

        let expected: [(f32, f32); 4] = [(0.0, 100.0), (100.0, 100.0), (0.0, 0.0), (100.0, 0.0)];
        for (leaf, &(cx, cy)) in nodes[1..].iter().zip(expected.iter()) {
            assert_eq!(leaf.mass, 1.0);
            assert_eq!((leaf.center_x, leaf.center_y), (cx, cy));
            assert_eq!(leaf.width, root.width / 2.0);
            assert_eq!(children(leaf), (-1, -1, -1, -1));
        }
    }

    #[test]
    fn test_bounds_are_padded_square() {
        let positions = corners();
        let tree = QuadTree::build(&positions, 10);
        let (lo, hi) = tree.bounds();
        // 10% of the 100-unit extent is added on every side.
        assert_eq!((lo.x, lo.y, hi.x, hi.y), (-10.0, -10.0, 110.0, 110.0));
        assert_eq!(tree.nodes()[0].width, 120.0);
    }

    #[test]
    fn test_coincident_positions_respect_max_depth() {
        let positions = vec![Position::new(5.0, 5.0); 3];
        let tree = QuadTree::build(&positions, 3);
        let nodes = tree.nodes();
        // Identical points never separate, so the split chain stops at max_depth:
        // one node per depth 0..=3.
        assert_eq!(nodes.len(), 4);
        let leaf = &nodes[3];
        assert_eq!(leaf.mass, 3.0);
        assert_eq!((leaf.center_x, leaf.center_y), (5.0, 5.0));
        assert_eq!(children(leaf), (-1, -1, -1, -1));
    }
}