1use std::collections::HashMap;
9
10use crate::interaction::selection::NodeId;
11use crate::resources::mesh_store::MeshId;
12use crate::scene::aabb::Aabb;
13use crate::scene::scene::Scene;
14
15use parry3d::math::Vector;
16use parry3d::query::{Ray, RayCast};
17use parry3d::shape::FeatureId;
18
/// One pickable scene node stored in the BVH.
#[derive(Debug, Clone)]
struct BvhEntry {
    /// World-space bounding box: the mesh's local AABB transformed by the
    /// node's world transform at build time.
    aabb: Aabb,
    /// Scene node this entry was built from; copied into `PickHit::id` when
    /// this entry produces the winning hit.
    node_id: NodeId,
    /// `MeshId::index()` of the node's mesh. Used as the trimesh-cache key and,
    /// cast to `u64`, as the key into the mesh lookup passed to `pick`.
    mesh_index: usize,
    /// The node's world transform captured when the accelerator was built.
    world_transform: glam::Mat4,
}
27
/// A node of the bounding-volume hierarchy produced by `build_bvh_node`.
enum BvhNode {
    /// Terminal node referencing a small group of entries (at most four, as
    /// built by `build_bvh_node`).
    Leaf {
        /// Indices into `PickAccelerator::entries`.
        entry_indices: Vec<usize>,
        /// Union of the referenced entries' world AABBs.
        aabb: Aabb,
    },
    /// Internal node with exactly two children.
    Interior {
        /// Union of every entry AABB stored below this node.
        aabb: Aabb,
        left: Box<BvhNode>,
        right: Box<BvhNode>,
    },
}
40
41impl BvhNode {
42 fn aabb(&self) -> &Aabb {
43 match self {
44 BvhNode::Leaf { aabb, .. } => aabb,
45 BvhNode::Interior { aabb, .. } => aabb,
46 }
47 }
48}
49
/// BVH-accelerated ray picking over the visible, mesh-bearing nodes of a [`Scene`].
///
/// Built with `PickAccelerator::build_from_scene`. Triangle meshes used for the
/// exact hit test are created lazily on first pick and cached by mesh index.
pub struct PickAccelerator {
    /// One entry per pickable node; BVH leaves store indices into this vector.
    entries: Vec<BvhEntry>,
    /// Root of the BVH, or `None` when there are no entries (or after
    /// `invalidate_all`).
    root: Option<BvhNode>,
    /// Parry triangle meshes keyed by mesh index, built on demand during picks.
    trimesh_cache: HashMap<usize, parry3d::shape::TriMesh>,
}
56
57impl PickAccelerator {
58 pub fn build_from_scene(scene: &Scene, mesh_aabb_fn: impl Fn(MeshId) -> Option<Aabb>) -> Self {
62 let mut entries = Vec::new();
63 for node in scene.nodes() {
64 if !node.is_visible() {
65 continue;
66 }
67 let Some(mesh_id) = node.mesh_id() else {
68 continue;
69 };
70 if let Some(local_aabb) = mesh_aabb_fn(mesh_id) {
71 let world_aabb = local_aabb.transformed(&node.world_transform());
72 entries.push(BvhEntry {
73 aabb: world_aabb,
74 node_id: node.id(),
75 mesh_index: mesh_id.index(),
76 world_transform: node.world_transform(),
77 });
78 }
79 }
80
81 let indices: Vec<usize> = (0..entries.len()).collect();
82 let root = if entries.is_empty() {
83 None
84 } else {
85 Some(build_bvh_node(&entries, indices))
86 };
87
88 Self {
89 entries,
90 root,
91 trimesh_cache: HashMap::new(),
92 }
93 }
94
95 pub fn pick(
99 &mut self,
100 ray_origin: glam::Vec3,
101 ray_dir: glam::Vec3,
102 mesh_lookup: &HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)>,
103 ) -> Option<crate::interaction::picking::PickHit> {
104 let root = self.root.as_ref()?;
105 let mut best: Option<(NodeId, f32, crate::interaction::picking::PickHit)> = None;
106
107 let mut candidates = Vec::new();
109 let mut stack: Vec<&BvhNode> = vec![root];
110 while let Some(node) = stack.pop() {
111 if !ray_aabb_test(ray_origin, ray_dir, node.aabb()) {
112 continue;
113 }
114 match node {
115 BvhNode::Leaf { entry_indices, .. } => {
116 candidates.extend_from_slice(entry_indices);
117 }
118 BvhNode::Interior { left, right, .. } => {
119 stack.push(left);
120 stack.push(right);
121 }
122 }
123 }
124
125 for idx in candidates {
127 let node_id = self.entries[idx].node_id;
128 let mesh_index = self.entries[idx].mesh_index;
129 let world_transform = self.entries[idx].world_transform;
130
131 if let Some((toi, mut hit)) = self.test_entry_by_parts(
132 mesh_index,
133 &world_transform,
134 ray_origin,
135 ray_dir,
136 mesh_lookup,
137 ) {
138 if best.is_none() || toi < best.as_ref().unwrap().1 {
139 hit.id = node_id;
140 best = Some((node_id, toi, hit));
141 }
142 }
143 }
144
145 best.map(|(_, _, hit)| hit)
146 }
147
148 fn test_entry_by_parts(
149 &mut self,
150 mesh_index: usize,
151 world_transform: &glam::Mat4,
152 ray_origin: glam::Vec3,
153 ray_dir: glam::Vec3,
154 mesh_lookup: &HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)>,
155 ) -> Option<(f32, crate::interaction::picking::PickHit)> {
156 let (positions, indices) = mesh_lookup.get(&(mesh_index as u64))?;
157
158 if let std::collections::hash_map::Entry::Vacant(e) = self.trimesh_cache.entry(mesh_index) {
160 let verts: Vec<Vector> = positions
161 .iter()
162 .map(|p| Vector::new(p[0], p[1], p[2]))
163 .collect();
164 let tri_indices: Vec<[u32; 3]> = indices
165 .chunks(3)
166 .filter(|c| c.len() == 3)
167 .map(|c| [c[0], c[1], c[2]])
168 .collect();
169 if tri_indices.is_empty() {
170 return None;
171 }
172 match parry3d::shape::TriMesh::new(verts, tri_indices) {
173 Ok(tm) => {
174 e.insert(tm);
175 }
176 Err(_) => return None,
177 }
178 }
179
180 let trimesh = self.trimesh_cache.get(&mesh_index)?;
181
182 let (scale, rotation, translation) = world_transform.to_scale_rotation_translation();
184
185 let inv_rot = rotation.inverse();
187 let local_origin = inv_rot * (ray_origin - translation);
188 let local_dir = inv_rot * ray_dir;
189 let inv_scale = glam::Vec3::new(1.0 / scale.x, 1.0 / scale.y, 1.0 / scale.z);
190 let scaled_origin = local_origin * inv_scale;
191 let scaled_dir = (local_dir * inv_scale).normalize();
192
193 let ray = Ray::new(
194 Vector::new(scaled_origin.x, scaled_origin.y, scaled_origin.z),
195 Vector::new(scaled_dir.x, scaled_dir.y, scaled_dir.z),
196 );
197
198 trimesh
199 .cast_local_ray_and_get_normal(&ray, f32::MAX, true)
200 .map(|intersection| {
201 let avg_scale = (scale.x + scale.y + scale.z) / 3.0;
203 let toi = intersection.time_of_impact * avg_scale;
204
205 let triangle_index = match intersection.feature {
206 FeatureId::Face(idx) => idx,
207 _ => u32::MAX,
208 };
209
210 let local_hit_scaled = scaled_origin + scaled_dir * intersection.time_of_impact;
216 let local_hit = local_hit_scaled * scale;
217 let world_pos = rotation * local_hit + translation;
218
219 let world_normal = (rotation * (intersection.normal * inv_scale)).normalize();
223
224 (
225 toi,
226 crate::interaction::picking::PickHit {
227 id: 0, triangle_index,
229 world_pos,
230 normal: world_normal,
231 point_index: None,
232 scalar_value: None,
233 },
234 )
235 })
236 }
237
238 pub fn invalidate_mesh(&mut self, mesh_index: usize) {
240 self.trimesh_cache.remove(&mesh_index);
241 }
242
243 pub fn invalidate_all(&mut self) {
245 self.trimesh_cache.clear();
246 self.entries.clear();
247 self.root = None;
248 }
249
250 pub fn trimesh_cache_len(&self) -> usize {
252 self.trimesh_cache.len()
253 }
254}
255
256fn build_bvh_node(entries: &[BvhEntry], indices: Vec<usize>) -> BvhNode {
261 let combined = combined_aabb(entries, &indices);
263
264 if indices.len() <= 4 {
266 return BvhNode::Leaf {
267 entry_indices: indices,
268 aabb: combined,
269 };
270 }
271
272 let (best_axis, best_split) = find_best_split(entries, &indices, &combined);
274
275 let mut left_indices = Vec::new();
276 let mut right_indices = Vec::new();
277 for &idx in &indices {
278 let center = entries[idx].aabb.center();
279 let val = match best_axis {
280 0 => center.x,
281 1 => center.y,
282 _ => center.z,
283 };
284 if val <= best_split {
285 left_indices.push(idx);
286 } else {
287 right_indices.push(idx);
288 }
289 }
290
291 if left_indices.is_empty() || right_indices.is_empty() {
293 let mid = indices.len() / 2;
294 left_indices = indices[..mid].to_vec();
295 right_indices = indices[mid..].to_vec();
296 }
297
298 BvhNode::Interior {
299 aabb: combined,
300 left: Box::new(build_bvh_node(entries, left_indices)),
301 right: Box::new(build_bvh_node(entries, right_indices)),
302 }
303}
304
305fn combined_aabb(entries: &[BvhEntry], indices: &[usize]) -> Aabb {
306 let mut min = glam::Vec3::splat(f32::INFINITY);
307 let mut max = glam::Vec3::splat(f32::NEG_INFINITY);
308 for &idx in indices {
309 min = min.min(entries[idx].aabb.min);
310 max = max.max(entries[idx].aabb.max);
311 }
312 Aabb { min, max }
313}
314
315fn find_best_split(_entries: &[BvhEntry], _indices: &[usize], parent_aabb: &Aabb) -> (usize, f32) {
316 let extents = parent_aabb.max - parent_aabb.min;
317 let axis = if extents.x >= extents.y && extents.x >= extents.z {
319 0
320 } else if extents.y >= extents.z {
321 1
322 } else {
323 2
324 };
325
326 let split = match axis {
328 0 => (parent_aabb.min.x + parent_aabb.max.x) * 0.5,
329 1 => (parent_aabb.min.y + parent_aabb.max.y) * 0.5,
330 _ => (parent_aabb.min.z + parent_aabb.max.z) * 0.5,
331 };
332
333 (axis, split)
334}
335
336fn ray_aabb_test(origin: glam::Vec3, dir: glam::Vec3, aabb: &Aabb) -> bool {
341 let inv_dir = glam::Vec3::new(
342 if dir.x.abs() > 1e-10 {
343 1.0 / dir.x
344 } else {
345 f32::INFINITY * dir.x.signum()
346 },
347 if dir.y.abs() > 1e-10 {
348 1.0 / dir.y
349 } else {
350 f32::INFINITY * dir.y.signum()
351 },
352 if dir.z.abs() > 1e-10 {
353 1.0 / dir.z
354 } else {
355 f32::INFINITY * dir.z.signum()
356 },
357 );
358
359 let t1 = (aabb.min - origin) * inv_dir;
360 let t2 = (aabb.max - origin) * inv_dir;
361
362 let tmin = t1.min(t2);
363 let tmax = t1.max(t2);
364
365 let tenter = tmin.x.max(tmin.y).max(tmin.z);
366 let texit = tmax.x.min(tmax.y).min(tmax.z);
367
368 texit >= tenter.max(0.0)
369}
370
371pub fn pick_scene_accelerated(
379 ray_origin: glam::Vec3,
380 ray_dir: glam::Vec3,
381 accelerator: &mut PickAccelerator,
382 mesh_lookup: &HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)>,
383) -> Option<crate::interaction::picking::PickHit> {
384 accelerator.pick(ray_origin, ray_dir, mesh_lookup)
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resources::mesh_store::MeshId;
    use crate::scene::material::Material;

    /// Tolerance for comparing hit coordinates.
    const EPS: f32 = 1e-4;

    /// A cube spanning [-0.5, 0.5] on every axis: 8 vertices, 12 triangles.
    fn unit_cube_mesh() -> (Vec<[f32; 3]>, Vec<u32>) {
        let positions = vec![
            [-0.5, -0.5, -0.5],
            [0.5, -0.5, -0.5],
            [0.5, 0.5, -0.5],
            [-0.5, 0.5, -0.5],
            [-0.5, -0.5, 0.5],
            [0.5, -0.5, 0.5],
            [0.5, 0.5, 0.5],
            [-0.5, 0.5, 0.5],
        ];
        let indices = vec![
            0, 1, 2, 2, 3, 0, 4, 6, 5, 6, 4, 7, 0, 3, 7, 7, 4, 0, 1, 5, 6, 6, 2, 1, 3, 2, 6, 6, 7,
            3, 0, 4, 5, 5, 1, 0,
        ];
        (positions, indices)
    }

    /// Local-space AABB of `unit_cube_mesh`.
    fn unit_aabb() -> Aabb {
        Aabb {
            min: glam::Vec3::splat(-0.5),
            max: glam::Vec3::splat(0.5),
        }
    }

    /// Mesh lookup mapping every id in `ids` to its own copy of the unit cube.
    fn cube_lookup(ids: &[u64]) -> HashMap<u64, (Vec<[f32; 3]>, Vec<u32>)> {
        ids.iter().map(|&id| (id, unit_cube_mesh())).collect()
    }

    #[test]
    fn test_bvh_build_single() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        assert_eq!(accel.entries.len(), 1);
        assert!(accel.root.is_some());
    }

    #[test]
    fn test_bvh_build_empty_scene() {
        let scene = Scene::new();
        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        assert!(accel.entries.is_empty());
        assert!(accel.root.is_none());

        let result = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &cube_lookup(&[0]),
        );
        assert!(result.is_none(), "an empty accelerator cannot hit anything");
    }

    #[test]
    fn test_bvh_pick_hit() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let result = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &cube_lookup(&[0]),
        );
        let hit = result.expect("should hit the cube");
        assert!(
            (hit.world_pos.z - 0.5).abs() < EPS,
            "expected the +Z face at z = 0.5, got {:?}",
            hit.world_pos
        );
    }

    #[test]
    fn test_bvh_pick_miss() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let result = accel.pick(
            glam::Vec3::new(100.0, 100.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &cube_lookup(&[0]),
        );
        assert!(result.is_none(), "should miss");
    }

    #[test]
    fn test_bvh_pick_behind_origin_misses() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        // The cube lies behind the ray origin, so pointing away must not hit it.
        let result = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, 1.0),
            &cube_lookup(&[0]),
        );
        assert!(result.is_none(), "objects behind the ray origin must not be hit");
    }

    #[test]
    fn test_bvh_pick_nearest() {
        let mut scene = Scene::new();
        scene.add(
            Some(MeshId(0)),
            glam::Mat4::from_translation(glam::Vec3::new(0.0, 0.0, 2.0)),
            Material::default(),
        );
        scene.add(
            Some(MeshId(1)),
            glam::Mat4::from_translation(glam::Vec3::new(0.0, 0.0, -2.0)),
            Material::default(),
        );
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let result = accel.pick(
            glam::Vec3::new(0.0, 0.0, 10.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &cube_lookup(&[0, 1]),
        );
        let hit = result.expect("should hit something");
        // The cube centred at z = +2 is nearer to the origin at z = 10; its +Z face is at z = 2.5.
        assert!(
            (hit.world_pos.z - 2.5).abs() < EPS,
            "expected the nearer cube's face at z = 2.5, got {:?}",
            hit.world_pos
        );
    }

    #[test]
    fn test_bvh_pick_scaled_node() {
        let mut scene = Scene::new();
        scene.add(
            Some(MeshId(0)),
            glam::Mat4::from_scale(glam::Vec3::new(1.0, 1.0, 4.0)),
            Material::default(),
        );
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let result = accel.pick(
            glam::Vec3::new(0.0, 0.0, 10.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &cube_lookup(&[0]),
        );
        let hit = result.expect("should hit the stretched cube");
        // Scaling z by 4 moves the +Z face from 0.5 to 2.0.
        assert!(
            (hit.world_pos.z - 2.0).abs() < EPS,
            "expected the scaled face at z = 2.0, got {:?}",
            hit.world_pos
        );
    }

    #[test]
    fn test_bvh_pick_many_entries() {
        let mut scene = Scene::new();
        for i in 0..10 {
            scene.add(
                Some(MeshId(0)),
                glam::Mat4::from_translation(glam::Vec3::new(3.0 * i as f32, 0.0, 0.0)),
                Material::default(),
            );
        }
        scene.update_transforms();

        // Ten entries force interior BVH nodes, not a single leaf.
        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        assert_eq!(accel.entries.len(), 10);

        let result = accel.pick(
            glam::Vec3::new(9.0, 0.0, 10.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &cube_lookup(&[0]),
        );
        let hit = result.expect("should hit the cube centred at x = 9");
        assert!((hit.world_pos.x - 9.0).abs() < EPS, "got {:?}", hit.world_pos);
        assert!((hit.world_pos.z - 0.5).abs() < EPS, "got {:?}", hit.world_pos);
    }

    #[test]
    fn test_trimesh_cache_reuse() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let mesh_lookup = cube_lookup(&[0]);

        let _ = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        assert_eq!(accel.trimesh_cache_len(), 1);

        let _ = accel.pick(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &mesh_lookup,
        );
        assert_eq!(accel.trimesh_cache_len(), 1);
    }

    #[test]
    fn test_invalidate_mesh_rebuilds_on_next_pick() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let mesh_lookup = cube_lookup(&[0]);
        let origin = glam::Vec3::new(0.0, 0.0, 5.0);
        let dir = glam::Vec3::new(0.0, 0.0, -1.0);

        assert!(accel.pick(origin, dir, &mesh_lookup).is_some());
        accel.invalidate_mesh(0);
        assert_eq!(accel.trimesh_cache_len(), 0);

        assert!(accel.pick(origin, dir, &mesh_lookup).is_some());
        assert_eq!(accel.trimesh_cache_len(), 1);
    }

    #[test]
    fn test_invalidate_all_clears_state() {
        let mut scene = Scene::new();
        scene.add(Some(MeshId(0)), glam::Mat4::IDENTITY, Material::default());
        scene.update_transforms();

        let mut accel = PickAccelerator::build_from_scene(&scene, |_| Some(unit_aabb()));
        let mesh_lookup = cube_lookup(&[0]);
        let origin = glam::Vec3::new(0.0, 0.0, 5.0);
        let dir = glam::Vec3::new(0.0, 0.0, -1.0);

        assert!(accel.pick(origin, dir, &mesh_lookup).is_some());
        accel.invalidate_all();
        assert_eq!(accel.trimesh_cache_len(), 0);
        assert!(accel.entries.is_empty());
        assert!(accel.root.is_none());
        assert!(accel.pick(origin, dir, &mesh_lookup).is_none());
    }

    #[test]
    fn test_ray_aabb_hit() {
        let aabb = unit_aabb();
        assert!(ray_aabb_test(
            glam::Vec3::new(0.0, 0.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &aabb,
        ));
    }

    #[test]
    fn test_ray_aabb_miss() {
        let aabb = unit_aabb();
        assert!(!ray_aabb_test(
            glam::Vec3::new(100.0, 100.0, 5.0),
            glam::Vec3::new(0.0, 0.0, -1.0),
            &aabb,
        ));
    }
}