1use std::collections::HashMap;
4
5use nalgebra::{Vector3, Vector4};
6use threecrate_core::{Error, Point3f, PointCloud, Result};
7
8use crate::GpuContext;
9
/// WGSL compute shader that scores RANSAC plane hypotheses.
///
/// Each invocation handles one entry of `samples`: it builds a plane through the
/// three referenced points, counts how many entries of `points` lie within
/// `params.threshold` of that plane, and writes the unit-normal coefficients
/// `(nx, ny, nz, d)` together with the count into `candidates`. Triples whose
/// cross product is shorter than `1e-8` are written as an all-zero plane with
/// zero inliers.
const RANSAC_SCORE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> samples: array<vec4<u32>>;
@group(0) @binding(2) var<storage, read_write> candidates: array<PlaneCandidate>;
@group(0) @binding(3) var<uniform> params: RansacParams;

struct PlaneCandidate {
    coefficients: vec4<f32>,
    inlier_count: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

struct RansacParams {
    num_points: u32,
    num_samples: u32,
    threshold: f32,
    _pad: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let sample_idx = global_id.x;
    if (sample_idx >= params.num_samples) {
        return;
    }

    let sample = samples[sample_idx];
    let p1 = points[sample.x].xyz;
    let p2 = points[sample.y].xyz;
    let p3 = points[sample.z].xyz;

    let normal_raw = cross(p2 - p1, p3 - p1);
    let normal_len = length(normal_raw);
    if (normal_len < 1e-8) {
        candidates[sample_idx].coefficients = vec4<f32>(0.0);
        candidates[sample_idx].inlier_count = 0u;
        return;
    }

    let normal = normal_raw / normal_len;
    let d = -dot(normal, p1);
    var count = 0u;

    for (var i = 0u; i < params.num_points; i++) {
        let distance = abs(dot(normal, points[i].xyz) + d);
        if (distance <= params.threshold) {
            count++;
        }
    }

    candidates[sample_idx].coefficients = vec4<f32>(normal, d);
    candidates[sample_idx].inlier_count = count;
}
"#;
66
/// WGSL compute shader that flags the inliers of a single plane.
///
/// One invocation per point: `flags[index]` is set to `1` when the point's
/// distance to `params.plane` (normalised by the length of the plane normal) is
/// at most `params.threshold`, and to `0` otherwise. A plane whose normal is
/// shorter than `1e-8` flags every point as `0`.
const RANSAC_INLIER_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> flags: array<u32>;
@group(0) @binding(2) var<uniform> params: InlierParams;

struct InlierParams {
    num_points: u32,
    threshold: f32,
    _pad0: u32,
    _pad1: u32,
    plane: vec4<f32>,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_points) {
        return;
    }

    let normal_len = length(params.plane.xyz);
    if (normal_len < 1e-8) {
        flags[index] = 0u;
        return;
    }

    let distance = abs(dot(params.plane.xyz, points[index].xyz) + params.plane.w) / normal_len;
    flags[index] = select(0u, 1u, distance <= params.threshold);
}
"#;
97
/// WGSL compute shader template that filters candidate neighbours by radius.
///
/// `MAX_CANDIDATES` and `MAX_NEIGHBORS` are placeholders: the host substitutes
/// concrete array lengths into the source text before compiling it. One
/// invocation per point walks that point's candidate row, keeps the candidates
/// whose distance to the point is at most `params.tolerance` (skipping itself
/// and out-of-range indices), and stores up to `params.max_neighbors` of them
/// in `neighbors`, with the stored count in `neighbor_counts`.
const RADIUS_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> candidate_counts: array<u32>;
@group(0) @binding(2) var<storage, read> candidates: array<array<u32, MAX_CANDIDATES>>;
@group(0) @binding(3) var<storage, read_write> neighbor_counts: array<u32>;
@group(0) @binding(4) var<storage, read_write> neighbors: array<array<u32, MAX_NEIGHBORS>>;
@group(0) @binding(5) var<uniform> params: ClusterParams;

struct ClusterParams {
    num_points: u32,
    max_neighbors: u32,
    max_candidates: u32,
    _pad0: u32,
    tolerance: f32,
    _pad1: u32,
    _pad2: u32,
    _pad3: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_points) {
        return;
    }

    let center = points[index].xyz;
    var stored = 0u;
    let candidate_count = min(candidate_counts[index], params.max_candidates);

    for (var slot = 0u; slot < candidate_count; slot++) {
        let candidate_index = candidates[index][slot];
        if (candidate_index >= params.num_points || candidate_index == index) {
            continue;
        }

        let distance = length(points[candidate_index].xyz - center);
        if (distance <= params.tolerance) {
            if (stored < params.max_neighbors) {
                neighbors[index][stored] = candidate_index;
                stored++;
            }
        }
    }

    neighbor_counts[index] = stored;
}
"#;
146
147#[derive(Debug, Clone, PartialEq)]
149pub struct GpuPlaneModel {
150 pub coefficients: Vector4<f32>,
152}
153
154impl GpuPlaneModel {
155 pub fn new(coefficients: Vector4<f32>) -> Self {
157 Self { coefficients }
158 }
159
160 pub fn normal(&self) -> Vector3<f32> {
162 Vector3::new(
163 self.coefficients.x,
164 self.coefficients.y,
165 self.coefficients.z,
166 )
167 }
168
169 pub fn distance_to_point(&self, point: &Point3f) -> f32 {
171 let normal = self.normal();
172 let normal_len = normal.magnitude();
173 if normal_len < 1e-8 {
174 return f32::INFINITY;
175 }
176
177 (normal.dot(&point.coords) + self.coefficients.w).abs() / normal_len
178 }
179}
180
/// Result of a GPU RANSAC plane segmentation run.
#[derive(Debug, Clone)]
pub struct GpuPlaneSegmentationResult {
    /// The fitted plane; the segmentation routines in this file fill it with
    /// the same coefficients as `model`.
    pub plane: GpuPlaneModel,
    /// The fitted plane model (same coefficients as `plane`).
    pub model: GpuPlaneModel,
    /// Indices of the input points that were flagged as inliers of the plane.
    pub inliers: Vec<u32>,
    /// Number of plane hypotheses that were evaluated.
    pub iterations: usize,
}
193
/// Tuning parameters for GPU RANSAC plane segmentation.
#[derive(Debug, Clone, Copy)]
pub struct GpuPlaneSegmentationConfig {
    /// Number of plane hypotheses to evaluate.
    pub max_iterations: usize,
    /// Largest point-to-plane distance at which a point counts as an inlier.
    pub distance_threshold: f32,
    /// Smallest inlier count for which a plane is accepted.
    pub min_inliers: usize,
}

impl Default for GpuPlaneSegmentationConfig {
    /// 1000 iterations, a 0.02 distance threshold and at least one inlier.
    fn default() -> Self {
        let max_iterations = 1_000;
        let distance_threshold = 0.02;
        let min_inliers = 1;

        GpuPlaneSegmentationConfig {
            max_iterations,
            distance_threshold,
            min_inliers,
        }
    }
}
214
/// Tuning parameters for GPU-assisted Euclidean cluster extraction.
#[derive(Debug, Clone)]
pub struct GpuEuclideanClusterConfig {
    /// Largest distance at which two points are treated as neighbours.
    pub tolerance: f32,
    /// Clusters with fewer points than this are discarded.
    pub min_cluster_size: usize,
    /// Clusters with more points than this are discarded.
    pub max_cluster_size: usize,
    /// Upper bound on the neighbours recorded per point.
    pub max_neighbors: usize,
}

impl Default for GpuEuclideanClusterConfig {
    /// Tolerance 0.02, cluster sizes 100..=25000, 64 neighbours per point.
    fn default() -> Self {
        Self::with_max_neighbors(0.02, 100, 25_000, 64)
    }
}

impl GpuEuclideanClusterConfig {
    /// Builds a configuration with the default neighbour budget.
    pub fn new(tolerance: f32, min_cluster_size: usize, max_cluster_size: usize) -> Self {
        let default_neighbors = Self::default().max_neighbors;
        Self::with_max_neighbors(
            tolerance,
            min_cluster_size,
            max_cluster_size,
            default_neighbors,
        )
    }

    /// Builds a configuration with every field given explicitly.
    pub fn with_max_neighbors(
        tolerance: f32,
        min_cluster_size: usize,
        max_cluster_size: usize,
        max_neighbors: usize,
    ) -> Self {
        GpuEuclideanClusterConfig {
            max_neighbors,
            max_cluster_size,
            min_cluster_size,
            tolerance,
        }
    }
}

/// Shorter alias for [`GpuEuclideanClusterConfig`].
pub type GpuClusterConfig = GpuEuclideanClusterConfig;
268
269#[derive(Debug, Clone)]
271pub struct GpuClusterExtractionResult {
272 pub clusters: Vec<Vec<usize>>,
274}
275
276impl GpuClusterExtractionResult {
277 pub fn num_clusters(&self) -> usize {
279 self.clusters.len()
280 }
281
282 pub fn get_cluster_cloud(
284 &self,
285 cloud: &PointCloud<Point3f>,
286 index: usize,
287 ) -> Option<PointCloud<Point3f>> {
288 self.clusters.get(index).map(|indices| {
289 PointCloud::from_points(indices.iter().map(|&i| cloud.points[i]).collect())
290 })
291 }
292}
293
/// Host-side mirror of the `PlaneCandidate` struct declared in the RANSAC
/// scoring shader; `repr(C)` keeps the field order so candidates can be read
/// back from the GPU buffer as plain bytes.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct PlaneCandidate {
    /// Plane coefficients `(nx, ny, nz, d)` written by the shader.
    coefficients: [f32; 4],
    /// Number of points within the distance threshold of this plane.
    inlier_count: u32,
    /// Stands in for the shader struct's three `u32` padding fields.
    _padding: [u32; 3],
}
301
302impl GpuContext {
303 pub async fn segment_plane(
305 &self,
306 points: &[Point3f],
307 config: GpuPlaneSegmentationConfig,
308 ) -> Result<GpuPlaneSegmentationResult> {
309 validate_ransac_config(points, config)?;
310
311 let mut result = self
312 .segment_plane_ransac(points, config.distance_threshold, config.max_iterations)
313 .await?;
314 if result.inliers.len() < config.min_inliers {
315 return Err(Error::Algorithm(format!(
316 "Plane model has {} inliers, below required minimum {}",
317 result.inliers.len(),
318 config.min_inliers
319 )));
320 }
321
322 result.plane = result.model.clone();
323 Ok(result)
324 }
325
326 pub async fn segment_plane_ransac(
328 &self,
329 points: &[Point3f],
330 threshold: f32,
331 max_iters: usize,
332 ) -> Result<GpuPlaneSegmentationResult> {
333 validate_ransac_inputs(points, threshold, max_iters)?;
334
335 let point_data = points_to_vec4(points);
336 let sample_count = max_iters.min(u32::MAX as usize);
337 let samples = generate_ransac_samples(points.len(), sample_count);
338
339 let input_buffer =
340 self.create_buffer_init("RANSAC Points", &point_data, wgpu::BufferUsages::STORAGE);
341 let samples_buffer =
342 self.create_buffer_init("RANSAC Samples", &samples, wgpu::BufferUsages::STORAGE);
343 let candidates_buffer = self.create_buffer(
344 "RANSAC Candidates",
345 (sample_count * std::mem::size_of::<PlaneCandidate>()) as u64,
346 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
347 );
348
349 #[repr(C)]
350 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
351 struct RansacParams {
352 num_points: u32,
353 num_samples: u32,
354 threshold: f32,
355 _padding: u32,
356 }
357
358 let params = RansacParams {
359 num_points: points.len() as u32,
360 num_samples: sample_count as u32,
361 threshold,
362 _padding: 0,
363 };
364 let params_buffer =
365 self.create_buffer_init("RANSAC Params", &[params], wgpu::BufferUsages::UNIFORM);
366
367 let shader = self.create_shader_module("RANSAC Score Shader", RANSAC_SCORE_SHADER);
368 let layout = self.create_bind_group_layout(
369 "RANSAC Score Layout",
370 &[
371 storage_entry(0, true),
372 storage_entry(1, true),
373 storage_entry(2, false),
374 uniform_entry(3),
375 ],
376 );
377 let pipeline = self.create_pipeline_with_layout("RANSAC Score Pipeline", &shader, &layout);
378 let bind_group = self.create_bind_group(
379 "RANSAC Score Bind Group",
380 &layout,
381 &[
382 wgpu::BindGroupEntry {
383 binding: 0,
384 resource: input_buffer.as_entire_binding(),
385 },
386 wgpu::BindGroupEntry {
387 binding: 1,
388 resource: samples_buffer.as_entire_binding(),
389 },
390 wgpu::BindGroupEntry {
391 binding: 2,
392 resource: candidates_buffer.as_entire_binding(),
393 },
394 wgpu::BindGroupEntry {
395 binding: 3,
396 resource: params_buffer.as_entire_binding(),
397 },
398 ],
399 );
400
401 let mut encoder = self
402 .device
403 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
404 label: Some("RANSAC Score Encoder"),
405 });
406 {
407 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
408 label: Some("RANSAC Score Pass"),
409 timestamp_writes: None,
410 });
411 pass.set_pipeline(&pipeline);
412 pass.set_bind_group(0, &bind_group, &[]);
413 pass.dispatch_workgroups(div_ceil(sample_count, 64) as u32, 1, 1);
414 }
415
416 let candidates = self
417 .read_storage_buffer::<PlaneCandidate>(
418 encoder,
419 &candidates_buffer,
420 sample_count,
421 "RANSAC Candidate Staging",
422 )
423 .await?;
424
425 let best = candidates
426 .iter()
427 .max_by_key(|candidate| candidate.inlier_count)
428 .ok_or_else(|| Error::Algorithm("Failed to evaluate RANSAC candidates".to_string()))?;
429
430 if best.inlier_count == 0 {
431 return Err(Error::Algorithm(
432 "Failed to find valid plane model".to_string(),
433 ));
434 }
435
436 let coefficients = Vector4::new(
437 best.coefficients[0],
438 best.coefficients[1],
439 best.coefficients[2],
440 best.coefficients[3],
441 );
442 let inliers = self
443 .plane_inlier_indices(points, coefficients, threshold)
444 .await?;
445
446 Ok(GpuPlaneSegmentationResult {
447 plane: GpuPlaneModel::new(coefficients),
448 model: GpuPlaneModel::new(coefficients),
449 inliers,
450 iterations: sample_count,
451 })
452 }
453
454 pub async fn extract_clusters(
456 &self,
457 cloud: &PointCloud<Point3f>,
458 config: GpuClusterConfig,
459 ) -> Result<Vec<PointCloud<Point3f>>> {
460 let result = self
461 .extract_euclidean_clusters(&cloud.points, &config)
462 .await?;
463 Ok(result
464 .clusters
465 .iter()
466 .map(|indices| {
467 PointCloud::from_points(indices.iter().map(|&i| cloud.points[i]).collect())
468 })
469 .collect())
470 }
471
472 pub async fn extract_euclidean_clusters(
474 &self,
475 points: &[Point3f],
476 config: &GpuEuclideanClusterConfig,
477 ) -> Result<GpuClusterExtractionResult> {
478 validate_cluster_inputs(points, config)?;
479
480 let max_neighbors = config.max_neighbors.min(256).max(1);
481 let max_candidates = max_neighbors.saturating_mul(8).clamp(max_neighbors, 1024);
482 let point_data = points_to_vec4(points);
483 let (candidate_counts, candidate_indices) =
484 build_voxel_candidate_neighbors(points, config.tolerance, max_candidates)?;
485 let input_buffer =
486 self.create_buffer_init("Cluster Points", &point_data, wgpu::BufferUsages::STORAGE);
487 let candidate_counts_buffer = self.create_buffer_init(
488 "Cluster Candidate Counts",
489 &candidate_counts,
490 wgpu::BufferUsages::STORAGE,
491 );
492 let candidates_buffer = self.create_buffer_init(
493 "Cluster Candidates",
494 &candidate_indices,
495 wgpu::BufferUsages::STORAGE,
496 );
497 let counts_buffer = self.create_buffer(
498 "Cluster Neighbor Counts",
499 (points.len() * std::mem::size_of::<u32>()) as u64,
500 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
501 );
502 let neighbors_buffer = self.create_buffer(
503 "Cluster Neighbors",
504 (points.len() * max_neighbors * std::mem::size_of::<u32>()) as u64,
505 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
506 );
507
508 #[repr(C)]
509 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
510 struct ClusterParams {
511 num_points: u32,
512 max_neighbors: u32,
513 max_candidates: u32,
514 _padding0: u32,
515 tolerance: f32,
516 _padding1: [u32; 3],
517 }
518
519 let params = ClusterParams {
520 num_points: points.len() as u32,
521 max_neighbors: max_neighbors as u32,
522 max_candidates: max_candidates as u32,
523 _padding0: 0,
524 tolerance: config.tolerance,
525 _padding1: [0; 3],
526 };
527 let params_buffer =
528 self.create_buffer_init("Cluster Params", &[params], wgpu::BufferUsages::UNIFORM);
529
530 let shader_source = RADIUS_NEIGHBOR_SHADER
531 .replace("MAX_NEIGHBORS", &max_neighbors.to_string())
532 .replace("MAX_CANDIDATES", &max_candidates.to_string());
533 let shader = self.create_shader_module("Cluster Radius Neighbor Shader", &shader_source);
534 let layout = self.create_bind_group_layout(
535 "Cluster Radius Neighbor Layout",
536 &[
537 storage_entry(0, true),
538 storage_entry(1, true),
539 storage_entry(2, true),
540 storage_entry(3, false),
541 storage_entry(4, false),
542 uniform_entry(5),
543 ],
544 );
545 let pipeline =
546 self.create_pipeline_with_layout("Cluster Radius Neighbor Pipeline", &shader, &layout);
547 let bind_group = self.create_bind_group(
548 "Cluster Radius Neighbor Bind Group",
549 &layout,
550 &[
551 wgpu::BindGroupEntry {
552 binding: 0,
553 resource: input_buffer.as_entire_binding(),
554 },
555 wgpu::BindGroupEntry {
556 binding: 1,
557 resource: candidate_counts_buffer.as_entire_binding(),
558 },
559 wgpu::BindGroupEntry {
560 binding: 2,
561 resource: candidates_buffer.as_entire_binding(),
562 },
563 wgpu::BindGroupEntry {
564 binding: 3,
565 resource: counts_buffer.as_entire_binding(),
566 },
567 wgpu::BindGroupEntry {
568 binding: 4,
569 resource: neighbors_buffer.as_entire_binding(),
570 },
571 wgpu::BindGroupEntry {
572 binding: 5,
573 resource: params_buffer.as_entire_binding(),
574 },
575 ],
576 );
577
578 let mut encoder = self
579 .device
580 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
581 label: Some("Cluster Radius Neighbor Encoder"),
582 });
583 {
584 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
585 label: Some("Cluster Radius Neighbor Pass"),
586 timestamp_writes: None,
587 });
588 pass.set_pipeline(&pipeline);
589 pass.set_bind_group(0, &bind_group, &[]);
590 pass.dispatch_workgroups(div_ceil(points.len(), 64) as u32, 1, 1);
591 }
592
593 let counts = self
594 .read_storage_buffer::<u32>(
595 encoder,
596 &counts_buffer,
597 points.len(),
598 "Cluster Count Staging",
599 )
600 .await?;
601
602 let encoder = self
603 .device
604 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
605 label: Some("Cluster Neighbor Read Encoder"),
606 });
607 let neighbors = self
608 .read_storage_buffer::<u32>(
609 encoder,
610 &neighbors_buffer,
611 points.len() * max_neighbors,
612 "Cluster Neighbor Staging",
613 )
614 .await?;
615
616 let mut disjoint_set = DisjointSet::new(points.len());
617 for point_idx in 0..points.len() {
618 let count = counts[point_idx].min(max_neighbors as u32) as usize;
619 let base = point_idx * max_neighbors;
620
621 for &neighbor in &neighbors[base..base + count] {
622 let neighbor = neighbor as usize;
623 if neighbor < points.len() {
624 disjoint_set.union(point_idx, neighbor);
625 }
626 }
627 }
628
629 let mut by_root: HashMap<usize, usize> = HashMap::new();
630 let mut clusters = Vec::new();
631 for point_idx in 0..points.len() {
632 let root = disjoint_set.find(point_idx);
633 let cluster_idx = *by_root.entry(root).or_insert_with(|| {
634 clusters.push(Vec::new());
635 clusters.len() - 1
636 });
637 clusters[cluster_idx].push(point_idx);
638 }
639
640 clusters.retain(|cluster| {
641 cluster.len() >= config.min_cluster_size && cluster.len() <= config.max_cluster_size
642 });
643 clusters.sort_by(|a, b| b.len().cmp(&a.len()));
644 Ok(GpuClusterExtractionResult { clusters })
645 }
646
647 async fn plane_inlier_indices(
648 &self,
649 points: &[Point3f],
650 coefficients: Vector4<f32>,
651 threshold: f32,
652 ) -> Result<Vec<u32>> {
653 let point_data = points_to_vec4(points);
654 let input_buffer = self.create_buffer_init(
655 "Plane Inlier Points",
656 &point_data,
657 wgpu::BufferUsages::STORAGE,
658 );
659 let flags_buffer = self.create_buffer(
660 "Plane Inlier Flags",
661 (points.len() * std::mem::size_of::<u32>()) as u64,
662 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
663 );
664
665 #[repr(C)]
666 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
667 struct InlierParams {
668 num_points: u32,
669 threshold: f32,
670 _padding: [u32; 2],
671 plane: [f32; 4],
672 }
673
674 let params = InlierParams {
675 num_points: points.len() as u32,
676 threshold,
677 _padding: [0; 2],
678 plane: [
679 coefficients.x,
680 coefficients.y,
681 coefficients.z,
682 coefficients.w,
683 ],
684 };
685 let params_buffer = self.create_buffer_init(
686 "Plane Inlier Params",
687 &[params],
688 wgpu::BufferUsages::UNIFORM,
689 );
690 let shader = self.create_shader_module("Plane Inlier Shader", RANSAC_INLIER_SHADER);
691 let layout = self.create_bind_group_layout(
692 "Plane Inlier Layout",
693 &[
694 storage_entry(0, true),
695 storage_entry(1, false),
696 uniform_entry(2),
697 ],
698 );
699 let pipeline = self.create_pipeline_with_layout("Plane Inlier Pipeline", &shader, &layout);
700 let bind_group = self.create_bind_group(
701 "Plane Inlier Bind Group",
702 &layout,
703 &[
704 wgpu::BindGroupEntry {
705 binding: 0,
706 resource: input_buffer.as_entire_binding(),
707 },
708 wgpu::BindGroupEntry {
709 binding: 1,
710 resource: flags_buffer.as_entire_binding(),
711 },
712 wgpu::BindGroupEntry {
713 binding: 2,
714 resource: params_buffer.as_entire_binding(),
715 },
716 ],
717 );
718
719 let mut encoder = self
720 .device
721 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
722 label: Some("Plane Inlier Encoder"),
723 });
724 {
725 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
726 label: Some("Plane Inlier Pass"),
727 timestamp_writes: None,
728 });
729 pass.set_pipeline(&pipeline);
730 pass.set_bind_group(0, &bind_group, &[]);
731 pass.dispatch_workgroups(div_ceil(points.len(), 64) as u32, 1, 1);
732 }
733
734 let flags = self
735 .read_storage_buffer::<u32>(
736 encoder,
737 &flags_buffer,
738 points.len(),
739 "Plane Inlier Staging",
740 )
741 .await?;
742 Ok(flags
743 .iter()
744 .enumerate()
745 .filter_map(|(idx, flag)| (*flag == 1).then_some(idx as u32))
746 .collect())
747 }
748
749 fn create_pipeline_with_layout(
750 &self,
751 label: &str,
752 shader: &wgpu::ShaderModule,
753 layout: &wgpu::BindGroupLayout,
754 ) -> wgpu::ComputePipeline {
755 self.device
756 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
757 label: Some(label),
758 layout: Some(&self.device.create_pipeline_layout(
759 &wgpu::PipelineLayoutDescriptor {
760 label: Some(label),
761 bind_group_layouts: &[Some(layout)],
762 immediate_size: 0,
763 },
764 )),
765 module: shader,
766 entry_point: Some("main"),
767 compilation_options: wgpu::PipelineCompilationOptions::default(),
768 cache: None,
769 })
770 }
771
772 async fn read_storage_buffer<T: bytemuck::Pod>(
773 &self,
774 mut encoder: wgpu::CommandEncoder,
775 source: &wgpu::Buffer,
776 len: usize,
777 label: &str,
778 ) -> Result<Vec<T>> {
779 let size = (len * std::mem::size_of::<T>()) as u64;
780 let staging = self.create_buffer(
781 label,
782 size,
783 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
784 );
785 encoder.copy_buffer_to_buffer(source, 0, &staging, 0, size);
786 self.queue.submit(std::iter::once(encoder.finish()));
787
788 let slice = staging.slice(..);
789 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
790 slice.map_async(wgpu::MapMode::Read, move |result| {
791 sender.send(result).unwrap()
792 });
793 let _ = self.device.poll(wgpu::PollType::Wait {
794 submission_index: None,
795 timeout: None,
796 });
797
798 if let Some(Ok(())) = receiver.receive().await {
799 let data = slice.get_mapped_range();
800 let result = bytemuck::cast_slice(&data).to_vec();
801 drop(data);
802 staging.unmap();
803 Ok(result)
804 } else {
805 Err(Error::Gpu(
806 "Failed to read GPU segmentation results".to_string(),
807 ))
808 }
809 }
810}
811
812pub async fn gpu_segment_plane(
814 gpu_context: &GpuContext,
815 cloud: &PointCloud<Point3f>,
816 config: GpuPlaneSegmentationConfig,
817) -> Result<GpuPlaneSegmentationResult> {
818 gpu_context.segment_plane(&cloud.points, config).await
819}
820
821pub async fn gpu_segment_plane_ransac(
823 gpu_context: &GpuContext,
824 cloud: &PointCloud<Point3f>,
825 threshold: f32,
826 max_iters: usize,
827) -> Result<GpuPlaneSegmentationResult> {
828 gpu_context
829 .segment_plane_ransac(&cloud.points, threshold, max_iters)
830 .await
831}
832
833pub async fn gpu_extract_clusters(
835 gpu_context: &GpuContext,
836 cloud: &PointCloud<Point3f>,
837 config: GpuClusterConfig,
838) -> Result<Vec<PointCloud<Point3f>>> {
839 gpu_context.extract_clusters(cloud, config).await
840}
841
842pub async fn gpu_extract_euclidean_clusters(
844 gpu_context: &GpuContext,
845 cloud: &PointCloud<Point3f>,
846 config: &GpuEuclideanClusterConfig,
847) -> Result<GpuClusterExtractionResult> {
848 gpu_context
849 .extract_euclidean_clusters(&cloud.points, config)
850 .await
851}
852
853fn validate_ransac_config(points: &[Point3f], config: GpuPlaneSegmentationConfig) -> Result<()> {
854 validate_ransac_inputs(points, config.distance_threshold, config.max_iterations)?;
855 if config.min_inliers == 0 {
856 return Err(Error::InvalidData(
857 "min_inliers must be at least 1".to_string(),
858 ));
859 }
860 Ok(())
861}
862
863fn validate_ransac_inputs(points: &[Point3f], threshold: f32, max_iters: usize) -> Result<()> {
864 if points.len() < 3 {
865 return Err(Error::InvalidData(
866 "Need at least 3 points for plane segmentation".to_string(),
867 ));
868 }
869 if threshold <= 0.0 {
870 return Err(Error::InvalidData("Threshold must be positive".to_string()));
871 }
872 if max_iters == 0 {
873 return Err(Error::InvalidData(
874 "Max iterations must be positive".to_string(),
875 ));
876 }
877 if points.len() > u32::MAX as usize {
878 return Err(Error::InvalidData(
879 "Point cloud is too large for GPU segmentation".to_string(),
880 ));
881 }
882 Ok(())
883}
884
885fn validate_cluster_inputs(points: &[Point3f], config: &GpuEuclideanClusterConfig) -> Result<()> {
886 if points.is_empty() {
887 return Err(Error::InvalidData("Point cloud is empty".to_string()));
888 }
889 if config.tolerance <= 0.0 {
890 return Err(Error::InvalidData("Tolerance must be positive".to_string()));
891 }
892 if config.min_cluster_size == 0 {
893 return Err(Error::InvalidData(
894 "min_cluster_size must be at least 1".to_string(),
895 ));
896 }
897 if config.min_cluster_size > config.max_cluster_size {
898 return Err(Error::InvalidData(
899 "min_cluster_size must not exceed max_cluster_size".to_string(),
900 ));
901 }
902 if config.max_neighbors == 0 {
903 return Err(Error::InvalidData(
904 "max_neighbors must be at least 1".to_string(),
905 ));
906 }
907 if points.len() > u32::MAX as usize {
908 return Err(Error::InvalidData(
909 "Point cloud is too large for GPU clustering".to_string(),
910 ));
911 }
912 Ok(())
913}
914
915fn points_to_vec4(points: &[Point3f]) -> Vec<[f32; 4]> {
916 points.iter().map(|p| [p.x, p.y, p.z, 0.0]).collect()
917}
918
919fn build_voxel_candidate_neighbors(
920 points: &[Point3f],
921 tolerance: f32,
922 max_candidates: usize,
923) -> Result<(Vec<u32>, Vec<u32>)> {
924 let total_candidates = points
925 .len()
926 .checked_mul(max_candidates)
927 .ok_or_else(|| Error::InvalidData("Cluster candidate buffer is too large".to_string()))?;
928 let mut bins: HashMap<(i32, i32, i32), Vec<u32>> = HashMap::new();
929
930 for (idx, point) in points.iter().enumerate() {
931 bins.entry(voxel_key(point, tolerance))
932 .or_default()
933 .push(idx as u32);
934 }
935
936 let mut counts = vec![0u32; points.len()];
937 let mut candidates = vec![u32::MAX; total_candidates];
938 for (idx, point) in points.iter().enumerate() {
939 let (vx, vy, vz) = voxel_key(point, tolerance);
940 let base = idx * max_candidates;
941 let mut stored = 0usize;
942
943 for dx in -1..=1 {
944 for dy in -1..=1 {
945 for dz in -1..=1 {
946 let key = (vx + dx, vy + dy, vz + dz);
947 let Some(bucket) = bins.get(&key) else {
948 continue;
949 };
950
951 for &candidate in bucket {
952 if candidate as usize == idx {
953 continue;
954 }
955 if stored == max_candidates {
956 break;
957 }
958 candidates[base + stored] = candidate;
959 stored += 1;
960 }
961 }
962 }
963 }
964
965 counts[idx] = stored as u32;
966 }
967
968 Ok((counts, candidates))
969}
970
971fn voxel_key(point: &Point3f, tolerance: f32) -> (i32, i32, i32) {
972 (
973 (point.x / tolerance).floor() as i32,
974 (point.y / tolerance).floor() as i32,
975 (point.z / tolerance).floor() as i32,
976 )
977}
978
/// Generates `sample_count` triples of pairwise distinct point indices.
///
/// The triples come from a deterministic pseudo-random sequence seeded from
/// both arguments, so equal inputs always yield equal samples. Each entry is
/// `[a, b, c, 0]`; the fourth component only pads the entry to the `vec4<u32>`
/// the scoring shader reads.
///
/// Returns an empty list when `num_points < 3`, because no triple of distinct
/// indices exists in that case (the search for one would never terminate).
fn generate_ransac_samples(num_points: usize, sample_count: usize) -> Vec<[u32; 4]> {
    if num_points < 3 {
        return Vec::new();
    }

    let mut state = ((num_points as u64) << 32) ^ sample_count as u64 ^ 0x9E37_79B9_7F4A_7C15;
    let mut samples = Vec::with_capacity(sample_count);

    for iteration in 0..sample_count {
        let mut a = next_index(&mut state, num_points);
        let mut b = next_index(&mut state, num_points);
        let mut c = next_index(&mut state, num_points);

        if a == b || a == c || b == c {
            // The random draw collided: fall back to a deterministic triple and
            // nudge it until all three indices differ. With at least three
            // points both loops terminate.
            a = iteration % num_points;
            b = iteration.wrapping_mul(37).wrapping_add(1) % num_points;
            c = iteration.wrapping_mul(101).wrapping_add(2) % num_points;
            while b == a {
                b = (b + 1) % num_points;
            }
            while c == a || c == b {
                c = (c + 1) % num_points;
            }
        }

        samples.push([a as u32, b as u32, c as u32, 0]);
    }

    samples
}

/// Advances the 64-bit linear congruential generator in `state` and maps the
/// upper 32 bits of the new state onto `0..len`.
///
/// `len` must be non-zero; a zero `len` panics on the remainder operation.
fn next_index(state: &mut u64, len: usize) -> usize {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    ((*state >> 32) as usize) % len
}
1012
/// Union-find over the indices `0..len`, with union by rank and path
/// compression.
struct DisjointSet {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each root, used to decide which root survives a union.
    rank: Vec<u8>,
}

impl DisjointSet {
    /// Creates `len` singleton sets.
    fn new(len: usize) -> Self {
        let parent: Vec<usize> = (0..len).collect();
        let rank = vec![0u8; len];
        DisjointSet { parent, rank }
    }

    /// Returns the representative of the set containing `value` and re-parents
    /// every node visited on the way directly to that representative.
    fn find(&mut self, value: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = value;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: point every node on the path straight at the root.
        let mut node = value;
        while node != root {
            node = std::mem::replace(&mut self.parent[node], root);
        }

        root
    }

    /// Merges the sets containing `a` and `b`; the root of higher rank
    /// survives, and on a tie the root of `a` survives and gains one rank.
    fn union(&mut self, a: usize, b: usize) {
        let (root_a, root_b) = (self.find(a), self.find(b));
        if root_a == root_b {
            return;
        }

        let (rank_a, rank_b) = (self.rank[root_a], self.rank[root_b]);
        if rank_a < rank_b {
            self.parent[root_a] = root_b;
        } else {
            self.parent[root_b] = root_a;
            if rank_a == rank_b {
                self.rank[root_a] += 1;
            }
        }
    }
}
1050
1051fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
1052 wgpu::BindGroupLayoutEntry {
1053 binding,
1054 visibility: wgpu::ShaderStages::COMPUTE,
1055 ty: wgpu::BindingType::Buffer {
1056 ty: wgpu::BufferBindingType::Storage { read_only },
1057 has_dynamic_offset: false,
1058 min_binding_size: None,
1059 },
1060 count: None,
1061 }
1062}
1063
1064fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
1065 wgpu::BindGroupLayoutEntry {
1066 binding,
1067 visibility: wgpu::ShaderStages::COMPUTE,
1068 ty: wgpu::BindingType::Buffer {
1069 ty: wgpu::BufferBindingType::Uniform,
1070 has_dynamic_offset: false,
1071 min_binding_size: None,
1072 },
1073 count: None,
1074 }
1075}
1076
/// Divides `value` by `divisor`, rounding the quotient up.
///
/// Panics when `divisor` is zero.
fn div_ceil(value: usize, divisor: usize) -> usize {
    let quotient = value / divisor;
    if value % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
1080
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a GPU context, or `None` (after printing a note) when no GPU is
    /// available so that GPU-dependent tests can bail out early.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        let context = GpuContext::new().await.ok();
        if context.is_none() {
            println!("GPU not available, skipping GPU-dependent test");
        }
        context
    }

    /// A 12x12 grid of points on the plane `z = 0` (spacing 0.1) followed by
    /// two points far off the plane.
    fn planar_cloud() -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for cell in 0..144 {
            let (x, y) = (cell / 12, cell % 12);
            cloud.push(Point3f::new(x as f32 * 0.1, y as f32 * 0.1, 0.0));
        }
        for outlier in [Point3f::new(0.0, 0.0, 3.0), Point3f::new(1.0, 1.0, -3.0)] {
            cloud.push(outlier);
        }
        cloud
    }

    /// Two well-separated patches on `z = 0`: 30 points starting at `x = 0`,
    /// then 20 points starting at `x = 5`, both on a 0.04 grid five columns wide.
    fn clustered_cloud() -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for (patch_len, x_offset) in [(30, 0.0_f32), (20, 5.0_f32)] {
            for i in 0..patch_len {
                let x = x_offset + (i % 5) as f32 * 0.04;
                let y = (i / 5) as f32 * 0.04;
                cloud.push(Point3f::new(x, y, 0.0));
            }
        }
        cloud
    }

    #[test]
    fn test_gpu_ransac_plane() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let config = GpuPlaneSegmentationConfig {
                max_iterations: 256,
                distance_threshold: 0.01,
                min_inliers: 140,
            };
            let cloud = planar_cloud();
            let segmentation = gpu_segment_plane(&gpu, &cloud, config).await.unwrap();

            assert!(segmentation.inliers.len() >= 140);
            assert!(segmentation.plane.normal().z.abs() > 0.9);
            assert_eq!(segmentation.model, segmentation.plane);
            assert_eq!(segmentation.iterations, 256);
        });
    }

    #[test]
    fn test_gpu_clusters() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = clustered_cloud();
            let config = GpuEuclideanClusterConfig::with_max_neighbors(0.09, 5, 100, 16);
            let by_index = gpu_extract_euclidean_clusters(&gpu, &cloud, &config)
                .await
                .unwrap();
            let as_clouds = gpu_extract_clusters(&gpu, &cloud, config).await.unwrap();

            // Clusters are reported largest first.
            assert_eq!(by_index.num_clusters(), 2);
            assert_eq!(by_index.clusters[0].len(), 30);
            assert_eq!(by_index.clusters[1].len(), 20);
            assert_eq!(as_clouds.len(), 2);
            assert_eq!(as_clouds[0].len(), 30);
            assert_eq!(as_clouds[1].len(), 20);
        });
    }

    #[test]
    fn test_invalid_inputs() {
        // Too few points for a plane.
        let single_point = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(validate_ransac_inputs(&single_point.points, 0.1, 10).is_err());

        // Negative threshold, then zero iterations.
        assert!(validate_ransac_inputs(&planar_cloud().points, -0.1, 10).is_err());
        assert!(validate_ransac_inputs(&planar_cloud().points, 0.1, 0).is_err());

        // A minimum of zero inliers is rejected.
        let zero_inliers = GpuPlaneSegmentationConfig {
            min_inliers: 0,
            ..GpuPlaneSegmentationConfig::default()
        };
        assert!(validate_ransac_config(&planar_cloud().points, zero_inliers).is_err());

        // Negative clustering tolerance.
        let negative_tolerance = GpuEuclideanClusterConfig::new(-1.0, 1, 10);
        assert!(validate_cluster_inputs(&clustered_cloud().points, &negative_tolerance).is_err());
    }
}
1185}