Skip to main content

viewport_lib/geometry/
bvh.rs

1//! BVH-accelerated picking with TriMesh caching.
2//!
3//! Provides `PickAccelerator` — a binary bounding volume hierarchy built from
4//! scene objects' world-space AABBs. Ray queries traverse the BVH to quickly
5//! reject non-intersecting subtrees, then test leaf objects with cached
6//! `parry3d::TriMesh` instances.
7
8use std::collections::HashMap;
9
10use crate::interaction::selection::NodeId;
11use crate::resources::mesh_store::MeshId;
12use crate::scene::aabb::Aabb;
13use crate::scene::scene::Scene;
14
15use parry3d::math::Vector;
16use parry3d::query::{Ray, RayCast};
17use parry3d::shape::FeatureId;
18
/// An entry in the BVH representing a single scene object.
///
/// One entry is created per visible scene node that references a mesh for
/// which a local-space AABB is available (see `PickAccelerator::build_from_scene`).
#[derive(Debug, Clone)]
struct BvhEntry {
    /// World-space bounding box: the mesh's local AABB transformed by the
    /// node's world transform.
    aabb: Aabb,
    /// Id of the scene node this entry was built from; written into the
    /// resulting `PickHit` when this entry is the nearest hit.
    node_id: NodeId,
    /// Value of `MeshId::index()` for the node's mesh. Used as the key into
    /// the TriMesh cache and (cast to `u64`) into the caller's mesh lookup.
    mesh_index: usize,
    /// The node's world transform captured at build time.
    world_transform: glam::Mat4,
}
27
/// BVH tree node.
///
/// The tree is binary: a node is either a leaf holding a small set of entries
/// or an interior node with exactly two children.
enum BvhNode {
    /// Terminal node referencing entries directly.
    Leaf {
        /// Indices into `PickAccelerator::entries`.
        entry_indices: Vec<usize>,
        /// Union of the world-space AABBs of the referenced entries.
        aabb: Aabb,
    },
    /// Internal node with two subtrees.
    Interior {
        /// Union of the world-space AABBs of every entry below this node.
        aabb: Aabb,
        /// First child subtree.
        left: Box<BvhNode>,
        /// Second child subtree.
        right: Box<BvhNode>,
    },
}
40
41impl BvhNode {
42    fn aabb(&self) -> &Aabb {
43        match self {
44            BvhNode::Leaf { aabb, .. } => aabb,
45            BvhNode::Interior { aabb, .. } => aabb,
46        }
47    }
48}
49
/// BVH-accelerated picking structure with TriMesh cache.
///
/// Built once from a scene snapshot via `build_from_scene`; the BVH itself is
/// not updated afterwards, only the TriMesh cache is filled lazily by `pick`.
pub struct PickAccelerator {
    /// One entry per pickable scene node; BVH leaves index into this vector.
    entries: Vec<BvhEntry>,
    /// Root of the BVH, or `None` when there are no entries.
    root: Option<BvhNode>,
    /// Lazily built collision meshes, keyed by `BvhEntry::mesh_index`.
    trimesh_cache: HashMap<usize, parry3d::shape::TriMesh>,
}
56
57impl PickAccelerator {
58    /// Build a BVH from the current scene state.
59    ///
60    /// `mesh_aabb_fn` provides the local-space AABB for each mesh.
61    pub fn build_from_scene(scene: &Scene, mesh_aabb_fn: impl Fn(MeshId) -> Option<Aabb>) -> Self {
62        let mut entries = Vec::new();
63        for node in scene.nodes() {
64            if !node.is_visible() {
65                continue;
66            }
67            let Some(mesh_id) = node.mesh_id() else {
68                continue;
69            };
70            if let Some(local_aabb) = mesh_aabb_fn(mesh_id) {
71                let world_aabb = local_aabb.transformed(&node.world_transform());
72                entries.push(BvhEntry {
73                    aabb: world_aabb,
74                    node_id: node.id(),
75                    mesh_index: mesh_id.index(),
76                    world_transform: node.world_transform(),
77                });
78            }
79        }
80
81        let indices: Vec<usize> = (0..entries.len()).collect();
82        let root = if entries.is_empty() {
83            None
84        } else {
85            Some(build_bvh_node(&entries, indices))
86        };
87
88        Self {
89            entries,
90            root,
91            trimesh_cache: HashMap::new(),
92        }
93    }
94
95    /// Pick the nearest object hit by the ray.
96    ///
97    /// `mesh_lookup` maps mesh_index to (positions, indices) for TriMesh construction.
98    pub fn pick(
99        &mut self,
100        ray_origin: glam::Vec3,
101        ray_dir: glam::Vec3,
102        mesh_lookup: &HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)>,
103    ) -> Option<crate::interaction::picking::PickHit> {
104        let root = self.root.as_ref()?;
105        let mut best: Option<(NodeId, f32, crate::interaction::picking::PickHit)> = None;
106
107        // Collect candidate entry indices via iterative BVH traversal (read-only).
108        let mut candidates = Vec::new();
109        let mut stack: Vec<&BvhNode> = vec![root];
110        while let Some(node) = stack.pop() {
111            if !ray_aabb_test(ray_origin, ray_dir, node.aabb()) {
112                continue;
113            }
114            match node {
115                BvhNode::Leaf { entry_indices, .. } => {
116                    candidates.extend_from_slice(entry_indices);
117                }
118                BvhNode::Interior { left, right, .. } => {
119                    stack.push(left);
120                    stack.push(right);
121                }
122            }
123        }
124
125        // Test each candidate (may mutate trimesh_cache).
126        for idx in candidates {
127            let node_id = self.entries[idx].node_id;
128            let mesh_index = self.entries[idx].mesh_index;
129            let world_transform = self.entries[idx].world_transform;
130
131            if let Some((toi, mut hit)) = self.test_entry_by_parts(
132                mesh_index,
133                &world_transform,
134                ray_origin,
135                ray_dir,
136                mesh_lookup,
137            ) {
138                if best.is_none() || toi < best.as_ref().unwrap().1 {
139                    hit.id = node_id;
140                    best = Some((node_id, toi, hit));
141                }
142            }
143        }
144
145        best.map(|(_, _, hit)| hit)
146    }
147
148    fn test_entry_by_parts(
149        &mut self,
150        mesh_index: usize,
151        world_transform: &glam::Mat4,
152        ray_origin: glam::Vec3,
153        ray_dir: glam::Vec3,
154        mesh_lookup: &HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)>,
155    ) -> Option<(f32, crate::interaction::picking::PickHit)> {
156        let (positions, indices) = mesh_lookup.get(&(mesh_index as u64))?;
157
158        // Lazily build and cache TriMesh.
159        if let std::collections::hash_map::Entry::Vacant(e) = self.trimesh_cache.entry(mesh_index) {
160            let verts: Vec<Vector> = positions
161                .iter()
162                .map(|p| Vector::new(p[0], p[1], p[2]))
163                .collect();
164            let tri_indices: Vec<[u32; 3]> = indices
165                .chunks(3)
166                .filter(|c| c.len() == 3)
167                .map(|c| [c[0], c[1], c[2]])
168                .collect();
169            if tri_indices.is_empty() {
170                return None;
171            }
172            match parry3d::shape::TriMesh::new(verts, tri_indices) {
173                Ok(tm) => {
174                    e.insert(tm);
175                }
176                Err(_) => return None,
177            }
178        }
179
180        let trimesh = self.trimesh_cache.get(&mesh_index)?;
181
182        // Extract scale, rotation, translation from world transform.
183        let (scale, rotation, translation) = world_transform.to_scale_rotation_translation();
184
185        // Transform ray into object's local (scaled) space.
186        let inv_rot = rotation.inverse();
187        let local_origin = inv_rot * (ray_origin - translation);
188        let local_dir = inv_rot * ray_dir;
189        let inv_scale = glam::Vec3::new(1.0 / scale.x, 1.0 / scale.y, 1.0 / scale.z);
190        let scaled_origin = local_origin * inv_scale;
191        let scaled_dir = (local_dir * inv_scale).normalize();
192
193        let ray = Ray::new(
194            Vector::new(scaled_origin.x, scaled_origin.y, scaled_origin.z),
195            Vector::new(scaled_dir.x, scaled_dir.y, scaled_dir.z),
196        );
197
198        trimesh
199            .cast_local_ray_and_get_normal(&ray, f32::MAX, true)
200            .map(|intersection| {
201                // Scale TOI back to world space.
202                let avg_scale = (scale.x + scale.y + scale.z) / 3.0;
203                let toi = intersection.time_of_impact * avg_scale;
204
205                let triangle_index = match intersection.feature {
206                    FeatureId::Face(idx) => idx,
207                    _ => u32::MAX,
208                };
209
210                // Transform hit point to world space.
211                // scaled_origin and scaled_dir are in inv-scaled local space, so:
212                // local_hit = scaled_origin + scaled_dir * intersection.time_of_impact
213                // undo inv_scale: multiply by scale to get unscaled local coords
214                // then apply rotation and translation.
215                let local_hit_scaled = scaled_origin + scaled_dir * intersection.time_of_impact;
216                let local_hit = local_hit_scaled * scale;
217                let world_pos = rotation * local_hit + translation;
218
219                // Transform normal to world space.
220                // The normal from cast_local_ray_and_get_normal is in scaled-local space.
221                // Use inverse-transpose (scale the normal by inv_scale) then normalize.
222                let world_normal = (rotation * (intersection.normal * inv_scale)).normalize();
223
224                (
225                    toi,
226                    crate::interaction::picking::PickHit {
227                        id: 0, // placeholder — caller fills in actual node_id
228                        triangle_index,
229                        world_pos,
230                        normal: world_normal,
231                        point_index: None,
232                        scalar_value: None,
233                    },
234                )
235            })
236    }
237
238    /// Invalidate the TriMesh cache for a specific mesh (e.g. after re-tessellation).
239    pub fn invalidate_mesh(&mut self, mesh_index: usize) {
240        self.trimesh_cache.remove(&mesh_index);
241    }
242
243    /// Clear all cached data. A full rebuild is needed.
244    pub fn invalidate_all(&mut self) {
245        self.trimesh_cache.clear();
246        self.entries.clear();
247        self.root = None;
248    }
249
250    /// Number of cached TriMesh instances.
251    pub fn trimesh_cache_len(&self) -> usize {
252        self.trimesh_cache.len()
253    }
254}
255
256// ---------------------------------------------------------------------------
// BVH construction (longest-axis midpoint binary split)
258// ---------------------------------------------------------------------------
259
260fn build_bvh_node(entries: &[BvhEntry], indices: Vec<usize>) -> BvhNode {
261    // Compute combined AABB.
262    let combined = combined_aabb(entries, &indices);
263
264    // Leaf threshold: 4 or fewer entries.
265    if indices.len() <= 4 {
266        return BvhNode::Leaf {
267            entry_indices: indices,
268            aabb: combined,
269        };
270    }
271
272    // Find best split axis and position using SAH.
273    let (best_axis, best_split) = find_best_split(entries, &indices, &combined);
274
275    let mut left_indices = Vec::new();
276    let mut right_indices = Vec::new();
277    for &idx in &indices {
278        let center = entries[idx].aabb.center();
279        let val = match best_axis {
280            0 => center.x,
281            1 => center.y,
282            _ => center.z,
283        };
284        if val <= best_split {
285            left_indices.push(idx);
286        } else {
287            right_indices.push(idx);
288        }
289    }
290
291    // Fallback: if all entries ended up on one side, split in half.
292    if left_indices.is_empty() || right_indices.is_empty() {
293        let mid = indices.len() / 2;
294        left_indices = indices[..mid].to_vec();
295        right_indices = indices[mid..].to_vec();
296    }
297
298    BvhNode::Interior {
299        aabb: combined,
300        left: Box::new(build_bvh_node(entries, left_indices)),
301        right: Box::new(build_bvh_node(entries, right_indices)),
302    }
303}
304
305fn combined_aabb(entries: &[BvhEntry], indices: &[usize]) -> Aabb {
306    let mut min = glam::Vec3::splat(f32::INFINITY);
307    let mut max = glam::Vec3::splat(f32::NEG_INFINITY);
308    for &idx in indices {
309        min = min.min(entries[idx].aabb.min);
310        max = max.max(entries[idx].aabb.max);
311    }
312    Aabb { min, max }
313}
314
315fn find_best_split(_entries: &[BvhEntry], _indices: &[usize], parent_aabb: &Aabb) -> (usize, f32) {
316    let extents = parent_aabb.max - parent_aabb.min;
317    // Choose the longest axis.
318    let axis = if extents.x >= extents.y && extents.x >= extents.z {
319        0
320    } else if extents.y >= extents.z {
321        1
322    } else {
323        2
324    };
325
326    // Use midpoint of the longest axis as the split.
327    let split = match axis {
328        0 => (parent_aabb.min.x + parent_aabb.max.x) * 0.5,
329        1 => (parent_aabb.min.y + parent_aabb.max.y) * 0.5,
330        _ => (parent_aabb.min.z + parent_aabb.max.z) * 0.5,
331    };
332
333    (axis, split)
334}
335
336// ---------------------------------------------------------------------------
337// Ray-AABB intersection test
338// ---------------------------------------------------------------------------
339
340fn ray_aabb_test(origin: glam::Vec3, dir: glam::Vec3, aabb: &Aabb) -> bool {
341    let inv_dir = glam::Vec3::new(
342        if dir.x.abs() > 1e-10 {
343            1.0 / dir.x
344        } else {
345            f32::INFINITY * dir.x.signum()
346        },
347        if dir.y.abs() > 1e-10 {
348            1.0 / dir.y
349        } else {
350            f32::INFINITY * dir.y.signum()
351        },
352        if dir.z.abs() > 1e-10 {
353            1.0 / dir.z
354        } else {
355            f32::INFINITY * dir.z.signum()
356        },
357    );
358
359    let t1 = (aabb.min - origin) * inv_dir;
360    let t2 = (aabb.max - origin) * inv_dir;
361
362    let tmin = t1.min(t2);
363    let tmax = t1.max(t2);
364
365    let tenter = tmin.x.max(tmin.y).max(tmin.z);
366    let texit = tmax.x.min(tmax.y).min(tmax.z);
367
368    texit >= tenter.max(0.0)
369}
370
371// ---------------------------------------------------------------------------
372// Public API wrapper
373// ---------------------------------------------------------------------------
374
375/// Pick the nearest scene node using a BVH accelerator.
376///
377/// Thin wrapper around `PickAccelerator::pick`.
378pub fn pick_scene_accelerated(
379    ray_origin: glam::Vec3,
380    ray_dir: glam::Vec3,
381    accelerator: &mut PickAccelerator,
382    mesh_lookup: &HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)>,
383) -> Option<crate::interaction::picking::PickHit> {
384    accelerator.pick(ray_origin, ray_dir, mesh_lookup)
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resources::mesh_store::MeshId;
    use crate::scene::material::Material;

    /// Unit cube centered at the origin: 8 corners, 12 triangles.
    fn unit_cube_mesh() -> (Vec<[f32; 3]>, Vec<u32>) {
        let positions = vec![
            [-0.5, -0.5, -0.5],
            [0.5, -0.5, -0.5],
            [0.5, 0.5, -0.5],
            [-0.5, 0.5, -0.5],
            [-0.5, -0.5, 0.5],
            [0.5, -0.5, 0.5],
            [0.5, 0.5, 0.5],
            [-0.5, 0.5, 0.5],
        ];
        let indices = vec![
            0, 1, 2, 2, 3, 0, 4, 6, 5, 6, 4, 7, 0, 3, 7, 7, 4, 0, 1, 5, 6, 6, 2, 1, 3, 2, 6, 6, 7,
            3, 0, 4, 5, 5, 1, 0,
        ];
        (positions, indices)
    }

    /// Local-space bounds matching `unit_cube_mesh`.
    fn unit_aabb() -> Aabb {
        Aabb {
            min: glam::Vec3::splat(-0.5),
            max: glam::Vec3::splat(0.5),
        }
    }

    #[test]
    fn test_bvh_build_single() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        assert_eq!(accel.entries.len(), 1);
        assert!(accel.root.is_some());
    }

    #[test]
    fn test_bvh_pick_hit() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));

        let (positions, indices) = unit_cube_mesh();
        let mut mesh_lookup = HashMap::new();
        mesh_lookup.insert(0u64, (positions, indices));

        let result = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        assert!(result.is_some(), "should hit the cube");
    }

    #[test]
    fn test_bvh_pick_miss() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));

        let (positions, indices) = unit_cube_mesh();
        let mut mesh_lookup = HashMap::new();
        mesh_lookup.insert(0u64, (positions, indices));

        let result = accel.pick(
            glam::Vec3::new(100.0, 100.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        assert!(result.is_none(), "should miss");
    }

    #[test]
    fn test_bvh_pick_nearest() {
        let mut scene = Scene::new();
        scene.add(
            Some(MeshId(0)),
            glam::Mat4::from_translation(glam::Vec3::new(0.0, 0.0, 2.0)),
            Material::default(),
        );
        scene.add(
            Some(MeshId(1)),
            glam::Mat4::from_translation(glam::Vec3::new(0.0, 0.0, -2.0)),
            Material::default(),
        );
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));

        let (positions, indices) = unit_cube_mesh();
        let mut mesh_lookup = HashMap::new();
        mesh_lookup.insert(0u64, (positions.clone(), indices.clone()));
        mesh_lookup.insert(1u64, (positions, indices));

        // Ray from z=10 toward -Z: should hit the nearer object at z=2.
        // The x/y offset keeps the ray off the cube faces' shared diagonal
        // edge so the result does not hinge on exact edge handling.
        let result = accel.pick(
            glam::Vec3::new(0.1, 0.2, 10.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        let hit = result.expect("should hit something");

        // The near cube spans z in [1.5, 2.5]; the far one spans [-2.5, -1.5].
        // The nearest hit must therefore lie on the near cube's front face.
        assert!(
            (hit.world_pos.z - 2.5).abs() < 1e-4,
            "expected hit on the nearer cube at z=2.5, got z={}",
            hit.world_pos.z
        );
    }

    #[test]
    fn test_trimesh_cache_reuse() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));

        let (positions, indices) = unit_cube_mesh();
        let mut mesh_lookup = HashMap::new();
        mesh_lookup.insert(0u64, (positions, indices));

        // First pick — builds TriMesh.
        let _ = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        assert_eq!(accel.trimesh_cache_len(), 1);

        // Second pick — should reuse cached TriMesh (cache len stays 1).
        let _ = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        assert_eq!(accel.trimesh_cache_len(), 1);
    }

    #[test]
    fn test_ray_aabb_hit() {
        let aabb = unit_aabb();
        assert!(ray_aabb_test(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &aabb,
        ));
    }

    #[test]
    fn test_ray_aabb_miss() {
        let aabb = unit_aabb();
        assert!(!ray_aabb_test(
            glam::Vec3::new(100.0, 100.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &aabb,
        ));
    }
}