vibe_graph_layout_gpu/
quadtree.rs

//! Barnes-Hut quadtree for O(n log n) force approximation.
//!
//! The quadtree recursively subdivides space and records the total mass and
//! center of mass of every cell. During force evaluation, sufficiently distant
//! cells can be approximated as single point masses, reducing the O(n²)
//! pairwise force calculation to roughly O(n log n).
6
7use crate::{Position, QuadTreeNode};
8
9/// A Barnes-Hut quadtree for 2D spatial partitioning.
10#[derive(Debug)]
11pub struct QuadTree {
12    /// Flattened tree nodes for GPU upload
13    nodes: Vec<QuadTreeNode>,
14    /// Bounding box min
15    bounds_min: Position,
16    /// Bounding box max
17    bounds_max: Position,
18}
19
20impl QuadTree {
21    /// Build a quadtree from node positions.
22    ///
23    /// # Arguments
24    /// * `positions` - Slice of node positions
25    /// * `max_depth` - Maximum tree depth (typically 10-15)
26    pub fn build(positions: &[Position], max_depth: usize) -> Self {
27        if positions.is_empty() {
28            return Self {
29                nodes: vec![QuadTreeNode::default()],
30                bounds_min: Position::default(),
31                bounds_max: Position::default(),
32            };
33        }
34
35        // Find bounding box with some padding
36        let mut min_x = f32::MAX;
37        let mut min_y = f32::MAX;
38        let mut max_x = f32::MIN;
39        let mut max_y = f32::MIN;
40
41        for pos in positions {
42            min_x = min_x.min(pos.x);
43            min_y = min_y.min(pos.y);
44            max_x = max_x.max(pos.x);
45            max_y = max_y.max(pos.y);
46        }
47
48        // Add padding
49        let padding = ((max_x - min_x).max(max_y - min_y) * 0.1).max(1.0);
50        min_x -= padding;
51        min_y -= padding;
52        max_x += padding;
53        max_y += padding;
54
55        // Make it square
56        let width = (max_x - min_x).max(max_y - min_y);
57        let center_x = (min_x + max_x) / 2.0;
58        let center_y = (min_y + max_y) / 2.0;
59
60        let bounds_min = Position::new(center_x - width / 2.0, center_y - width / 2.0);
61        let bounds_max = Position::new(center_x + width / 2.0, center_y + width / 2.0);
62
63        // Build tree recursively
64        let mut nodes = Vec::with_capacity(positions.len() * 2);
65        let mut builder = TreeBuilder {
66            positions,
67            nodes: &mut nodes,
68            max_depth,
69        };
70
71        let indices: Vec<usize> = (0..positions.len()).collect();
72        builder.build_node(&indices, bounds_min.x, bounds_min.y, width, 0);
73
74        Self {
75            nodes,
76            bounds_min,
77            bounds_max,
78        }
79    }
80
81    /// Get the flattened tree nodes for GPU upload.
82    pub fn nodes(&self) -> &[QuadTreeNode] {
83        &self.nodes
84    }
85
86    /// Get the bounding box.
87    pub fn bounds(&self) -> (Position, Position) {
88        (self.bounds_min, self.bounds_max)
89    }
90}
91
92struct TreeBuilder<'a> {
93    positions: &'a [Position],
94    nodes: &'a mut Vec<QuadTreeNode>,
95    max_depth: usize,
96}
97
98impl<'a> TreeBuilder<'a> {
99    fn build_node(
100        &mut self,
101        indices: &[usize],
102        x: f32,
103        y: f32,
104        width: f32,
105        depth: usize,
106    ) -> i32 {
107        if indices.is_empty() {
108            return -1;
109        }
110
111        let node_idx = self.nodes.len() as i32;
112        self.nodes.push(QuadTreeNode::default());
113
114        // Compute center of mass
115        let mut com_x = 0.0;
116        let mut com_y = 0.0;
117        let mass = indices.len() as f32;
118
119        for &i in indices {
120            com_x += self.positions[i].x;
121            com_y += self.positions[i].y;
122        }
123        com_x /= mass;
124        com_y /= mass;
125
126        // If leaf (single node or max depth), store as leaf
127        if indices.len() == 1 || depth >= self.max_depth {
128            self.nodes[node_idx as usize] = QuadTreeNode {
129                center_x: com_x,
130                center_y: com_y,
131                mass,
132                width,
133                child_nw: -1,
134                child_ne: -1,
135                child_sw: -1,
136                child_se: -1,
137            };
138            return node_idx;
139        }
140
141        // Subdivide into quadrants
142        let half_width = width / 2.0;
143        let mid_x = x + half_width;
144        let mid_y = y + half_width;
145
146        let mut nw_indices = Vec::new();
147        let mut ne_indices = Vec::new();
148        let mut sw_indices = Vec::new();
149        let mut se_indices = Vec::new();
150
151        for &i in indices {
152            let pos = &self.positions[i];
153            if pos.x < mid_x {
154                if pos.y < mid_y {
155                    sw_indices.push(i);
156                } else {
157                    nw_indices.push(i);
158                }
159            } else if pos.y < mid_y {
160                se_indices.push(i);
161            } else {
162                ne_indices.push(i);
163            }
164        }
165
166        // Recursively build children
167        let child_nw = self.build_node(&nw_indices, x, mid_y, half_width, depth + 1);
168        let child_ne = self.build_node(&ne_indices, mid_x, mid_y, half_width, depth + 1);
169        let child_sw = self.build_node(&sw_indices, x, y, half_width, depth + 1);
170        let child_se = self.build_node(&se_indices, mid_x, y, half_width, depth + 1);
171
172        // Update node
173        self.nodes[node_idx as usize] = QuadTreeNode {
174            center_x: com_x,
175            center_y: com_y,
176            mass,
177            width,
178            child_nw,
179            child_ne,
180            child_sw,
181            child_se,
182        };
183
184        node_idx
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` if `node` has no children.
    fn is_leaf(node: &QuadTreeNode) -> bool {
        [node.child_nw, node.child_ne, node.child_sw, node.child_se]
            .iter()
            .all(|&c| c == -1)
    }

    #[test]
    fn test_empty_tree() {
        let tree = QuadTree::build(&[], 10);
        // Even an empty input yields one node, so the uploaded buffer is never empty.
        assert_eq!(tree.nodes().len(), 1);
    }

    #[test]
    fn test_single_node() {
        let positions = vec![Position::new(0.0, 0.0)];
        let tree = QuadTree::build(&positions, 10);
        assert_eq!(tree.nodes().len(), 1);
        let root = &tree.nodes()[0];
        assert_eq!(root.mass, 1.0);
        assert_eq!((root.center_x, root.center_y), (0.0, 0.0));
        assert!(is_leaf(root));
    }

    #[test]
    fn test_multiple_nodes() {
        let positions = vec![
            Position::new(0.0, 0.0),
            Position::new(100.0, 0.0),
            Position::new(0.0, 100.0),
            Position::new(100.0, 100.0),
        ];
        let tree = QuadTree::build(&positions, 10);
        let nodes = tree.nodes();
        // Root plus exactly one leaf per quadrant, built in NW, NE, SW, SE order.
        assert_eq!(nodes.len(), 5);
        let root = &nodes[0];
        assert_eq!(root.mass, 4.0);
        assert_eq!((root.center_x, root.center_y), (50.0, 50.0));
        assert_eq!(
            (root.child_nw, root.child_ne, root.child_sw, root.child_se),
            (1, 2, 3, 4)
        );
        assert!(nodes[1..].iter().all(|n| n.mass == 1.0 && is_leaf(n)));
    }

    #[test]
    fn test_bounds_are_square_and_contain_all_positions() {
        let positions = vec![
            Position::new(-5.0, 3.0),
            Position::new(10.0, 20.0),
            Position::new(7.0, -8.0),
        ];
        let tree = QuadTree::build(&positions, 10);
        let (min, max) = tree.bounds();
        for p in &positions {
            assert!(min.x < p.x && p.x < max.x);
            assert!(min.y < p.y && p.y < max.y);
        }
        let (w, h) = (max.x - min.x, max.y - min.y);
        assert!((w - h).abs() <= 1e-4 * w, "bounds not square: {} x {}", w, h);
    }

    #[test]
    fn test_mass_is_conserved_with_duplicates() {
        let mut positions: Vec<Position> = (0..25)
            .map(|i| Position::new((i % 5) as f32 * 10.0, (i / 5) as f32 * 10.0))
            .collect();
        // Coincident points can never be separated by subdivision.
        positions.extend(std::iter::repeat(Position::new(12.5, 12.5)).take(3));

        let tree = QuadTree::build(&positions, 8);
        let nodes = tree.nodes();
        assert_eq!(nodes[0].mass, positions.len() as f32);

        let leaf_mass: f32 = nodes.iter().filter(|n| is_leaf(n)).map(|n| n.mass).sum();
        assert_eq!(leaf_mass, positions.len() as f32);

        // Every child link points forward into the pre-order node array.
        for (i, n) in nodes.iter().enumerate() {
            for &c in &[n.child_nw, n.child_ne, n.child_sw, n.child_se] {
                assert!(c == -1 || ((c as usize) > i && (c as usize) < nodes.len()));
            }
        }
    }
}
219