// Source: threecrate_gpu/segmentation.rs

//! GPU-accelerated point cloud segmentation.
2
use std::cmp::Reverse;
use std::collections::HashMap;

use nalgebra::{Vector3, Vector4};
use threecrate_core::{Error, Point3f, PointCloud, Result};

use crate::GpuContext;
9
/// WGSL compute shader that scores RANSAC plane candidates.
///
/// Each invocation handles one entry of `samples` (three point indices in `x`, `y`, `z`):
/// it fits a plane through the three points, counts every point whose distance to that
/// plane is at most `params.threshold`, and writes the unit-normal coefficients
/// `(normal, d)` together with the count into `candidates`. Degenerate triples (cross
/// product shorter than `1e-8`) get zero coefficients and a zero inlier count.
///
/// `PlaneCandidate` and `RansacParams` must stay layout-compatible with the Rust structs of
/// the same names that are uploaded to and read back from these bindings.
const RANSAC_SCORE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> samples: array<vec4<u32>>;
@group(0) @binding(2) var<storage, read_write> candidates: array<PlaneCandidate>;
@group(0) @binding(3) var<uniform> params: RansacParams;

struct PlaneCandidate {
    coefficients: vec4<f32>,
    inlier_count: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

struct RansacParams {
    num_points: u32,
    num_samples: u32,
    threshold: f32,
    _pad: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let sample_idx = global_id.x;
    if (sample_idx >= params.num_samples) {
        return;
    }

    let sample = samples[sample_idx];
    let p1 = points[sample.x].xyz;
    let p2 = points[sample.y].xyz;
    let p3 = points[sample.z].xyz;

    let normal_raw = cross(p2 - p1, p3 - p1);
    let normal_len = length(normal_raw);
    if (normal_len < 1e-8) {
        candidates[sample_idx].coefficients = vec4<f32>(0.0);
        candidates[sample_idx].inlier_count = 0u;
        return;
    }

    let normal = normal_raw / normal_len;
    let d = -dot(normal, p1);
    var count = 0u;

    for (var i = 0u; i < params.num_points; i++) {
        let distance = abs(dot(normal, points[i].xyz) + d);
        if (distance <= params.threshold) {
            count++;
        }
    }

    candidates[sample_idx].coefficients = vec4<f32>(normal, d);
    candidates[sample_idx].inlier_count = count;
}
"#;
66
/// WGSL compute shader that flags the inliers of one plane.
///
/// Each invocation handles one point: it writes `1` into `flags[index]` when the point's
/// distance to `params.plane` is at most `params.threshold`, and `0` otherwise. The distance
/// is divided by the length of the plane normal, so the plane does not have to be
/// normalized; a normal shorter than `1e-8` flags every point as an outlier.
///
/// `InlierParams` must stay layout-compatible with the Rust struct of the same name that is
/// uploaded as the uniform buffer.
const RANSAC_INLIER_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> flags: array<u32>;
@group(0) @binding(2) var<uniform> params: InlierParams;

struct InlierParams {
    num_points: u32,
    threshold: f32,
    _pad0: u32,
    _pad1: u32,
    plane: vec4<f32>,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_points) {
        return;
    }

    let normal_len = length(params.plane.xyz);
    if (normal_len < 1e-8) {
        flags[index] = 0u;
        return;
    }

    let distance = abs(dot(params.plane.xyz, points[index].xyz) + params.plane.w) / normal_len;
    flags[index] = select(0u, 1u, distance <= params.threshold);
}
"#;
97
/// WGSL compute shader template that filters candidate neighbors by radius.
///
/// This is a template, not a finished shader: the tokens `MAX_CANDIDATES` and
/// `MAX_NEIGHBORS` are array lengths that must be textually replaced with integer literals
/// before the module is compiled.
///
/// Each invocation handles one point. It walks that point's candidate row (at most
/// `params.max_candidates` entries), skips out-of-range indices and the point itself, and
/// stores every candidate within `params.tolerance` into the point's neighbor row until
/// `params.max_neighbors` entries are stored; further matches are dropped. The number of
/// stored neighbors is written to `neighbor_counts`.
///
/// `ClusterParams` must stay layout-compatible with the Rust struct of the same name that is
/// uploaded as the uniform buffer.
const RADIUS_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> candidate_counts: array<u32>;
@group(0) @binding(2) var<storage, read> candidates: array<array<u32, MAX_CANDIDATES>>;
@group(0) @binding(3) var<storage, read_write> neighbor_counts: array<u32>;
@group(0) @binding(4) var<storage, read_write> neighbors: array<array<u32, MAX_NEIGHBORS>>;
@group(0) @binding(5) var<uniform> params: ClusterParams;

struct ClusterParams {
    num_points: u32,
    max_neighbors: u32,
    max_candidates: u32,
    _pad0: u32,
    tolerance: f32,
    _pad1: u32,
    _pad2: u32,
    _pad3: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_points) {
        return;
    }

    let center = points[index].xyz;
    var stored = 0u;
    let candidate_count = min(candidate_counts[index], params.max_candidates);

    for (var slot = 0u; slot < candidate_count; slot++) {
        let candidate_index = candidates[index][slot];
        if (candidate_index >= params.num_points || candidate_index == index) {
            continue;
        }

        let distance = length(points[candidate_index].xyz - center);
        if (distance <= params.tolerance) {
            if (stored < params.max_neighbors) {
                neighbors[index][stored] = candidate_index;
                stored++;
            }
        }
    }

    neighbor_counts[index] = stored;
}
"#;
146
147/// A 3D plane model defined by `ax + by + cz + d = 0`.
148#[derive(Debug, Clone, PartialEq)]
149pub struct GpuPlaneModel {
150    /// Plane coefficients `[a, b, c, d]`.
151    pub coefficients: Vector4<f32>,
152}
153
154impl GpuPlaneModel {
155    /// Create a model from normalized coefficients.
156    pub fn new(coefficients: Vector4<f32>) -> Self {
157        Self { coefficients }
158    }
159
160    /// Return the plane normal.
161    pub fn normal(&self) -> Vector3<f32> {
162        Vector3::new(
163            self.coefficients.x,
164            self.coefficients.y,
165            self.coefficients.z,
166        )
167    }
168
169    /// Distance from `point` to this plane.
170    pub fn distance_to_point(&self, point: &Point3f) -> f32 {
171        let normal = self.normal();
172        let normal_len = normal.magnitude();
173        if normal_len < 1e-8 {
174            return f32::INFINITY;
175        }
176
177        (normal.dot(&point.coords) + self.coefficients.w).abs() / normal_len
178    }
179}
180
181/// Result of GPU RANSAC plane segmentation.
182#[derive(Debug, Clone)]
183pub struct GpuPlaneSegmentationResult {
184    /// Best-scoring plane model.
185    pub plane: GpuPlaneModel,
186    /// Best-scoring plane model.
187    pub model: GpuPlaneModel,
188    /// Indices of points within the distance threshold.
189    pub inliers: Vec<u32>,
190    /// Number of RANSAC candidates evaluated.
191    pub iterations: usize,
192}
193
/// Configuration for GPU RANSAC plane segmentation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GpuPlaneSegmentationConfig {
    /// Maximum RANSAC candidates to evaluate. Must be at least 1.
    pub max_iterations: usize,
    /// Maximum point-to-plane distance for an inlier. Must be positive.
    pub distance_threshold: f32,
    /// Minimum inliers required for a valid result. Must be at least 1.
    pub min_inliers: usize,
}

impl Default for GpuPlaneSegmentationConfig {
    /// 1000 candidates, a distance threshold of `0.02`, and at least one inlier.
    fn default() -> Self {
        Self {
            max_iterations: 1_000,
            distance_threshold: 0.02,
            min_inliers: 1,
        }
    }
}
214
/// Configuration for GPU Euclidean cluster extraction.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuEuclideanClusterConfig {
    /// Maximum distance between neighboring points in the same cluster. Must be positive.
    pub tolerance: f32,
    /// Minimum number of points for a valid cluster. Must be at least 1.
    pub min_cluster_size: usize,
    /// Maximum number of points allowed in a valid cluster.
    pub max_cluster_size: usize,
    /// Maximum radius neighbors retained per point from the GPU adjacency pass.
    ///
    /// Must be at least 1; cluster extraction clamps the value to the range `1..=256`.
    pub max_neighbors: usize,
}

impl Default for GpuEuclideanClusterConfig {
    /// Tolerance `0.02`, cluster sizes from 100 to 25 000 points, 64 neighbors per point.
    fn default() -> Self {
        Self {
            tolerance: 0.02,
            min_cluster_size: 100,
            max_cluster_size: 25_000,
            max_neighbors: 64,
        }
    }
}

impl GpuEuclideanClusterConfig {
    /// Create a config using the default `max_neighbors` cap.
    pub fn new(tolerance: f32, min_cluster_size: usize, max_cluster_size: usize) -> Self {
        Self {
            tolerance,
            min_cluster_size,
            max_cluster_size,
            ..Self::default()
        }
    }

    /// Create a config with an explicit GPU neighbor cap.
    pub fn with_max_neighbors(
        tolerance: f32,
        min_cluster_size: usize,
        max_cluster_size: usize,
        max_neighbors: usize,
    ) -> Self {
        Self {
            tolerance,
            min_cluster_size,
            max_cluster_size,
            max_neighbors,
        }
    }
}

/// Issue-compatible alias for GPU Euclidean cluster extraction config.
pub type GpuClusterConfig = GpuEuclideanClusterConfig;
268
269/// Result of GPU-accelerated Euclidean cluster extraction.
270#[derive(Debug, Clone)]
271pub struct GpuClusterExtractionResult {
272    /// Each inner vector contains point indices for one cluster, largest first.
273    pub clusters: Vec<Vec<usize>>,
274}
275
276impl GpuClusterExtractionResult {
277    /// Number of clusters found.
278    pub fn num_clusters(&self) -> usize {
279        self.clusters.len()
280    }
281
282    /// Extract a sub-cloud for the cluster at `index`.
283    pub fn get_cluster_cloud(
284        &self,
285        cloud: &PointCloud<Point3f>,
286        index: usize,
287    ) -> Option<PointCloud<Point3f>> {
288        self.clusters.get(index).map(|indices| {
289            PointCloud::from_points(indices.iter().map(|&i| cloud.points[i]).collect())
290        })
291    }
292}
293
294#[repr(C)]
295#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
296struct PlaneCandidate {
297    coefficients: [f32; 4],
298    inlier_count: u32,
299    _padding: [u32; 3],
300}
301
302impl GpuContext {
303    /// Segment the dominant plane using GPU-scored RANSAC candidates.
304    pub async fn segment_plane(
305        &self,
306        points: &[Point3f],
307        config: GpuPlaneSegmentationConfig,
308    ) -> Result<GpuPlaneSegmentationResult> {
309        validate_ransac_config(points, config)?;
310
311        let mut result = self
312            .segment_plane_ransac(points, config.distance_threshold, config.max_iterations)
313            .await?;
314        if result.inliers.len() < config.min_inliers {
315            return Err(Error::Algorithm(format!(
316                "Plane model has {} inliers, below required minimum {}",
317                result.inliers.len(),
318                config.min_inliers
319            )));
320        }
321
322        result.plane = result.model.clone();
323        Ok(result)
324    }
325
326    /// Segment the dominant plane using GPU-scored RANSAC candidates.
327    pub async fn segment_plane_ransac(
328        &self,
329        points: &[Point3f],
330        threshold: f32,
331        max_iters: usize,
332    ) -> Result<GpuPlaneSegmentationResult> {
333        validate_ransac_inputs(points, threshold, max_iters)?;
334
335        let point_data = points_to_vec4(points);
336        let sample_count = max_iters.min(u32::MAX as usize);
337        let samples = generate_ransac_samples(points.len(), sample_count);
338
339        let input_buffer =
340            self.create_buffer_init("RANSAC Points", &point_data, wgpu::BufferUsages::STORAGE);
341        let samples_buffer =
342            self.create_buffer_init("RANSAC Samples", &samples, wgpu::BufferUsages::STORAGE);
343        let candidates_buffer = self.create_buffer(
344            "RANSAC Candidates",
345            (sample_count * std::mem::size_of::<PlaneCandidate>()) as u64,
346            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
347        );
348
349        #[repr(C)]
350        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
351        struct RansacParams {
352            num_points: u32,
353            num_samples: u32,
354            threshold: f32,
355            _padding: u32,
356        }
357
358        let params = RansacParams {
359            num_points: points.len() as u32,
360            num_samples: sample_count as u32,
361            threshold,
362            _padding: 0,
363        };
364        let params_buffer =
365            self.create_buffer_init("RANSAC Params", &[params], wgpu::BufferUsages::UNIFORM);
366
367        let shader = self.create_shader_module("RANSAC Score Shader", RANSAC_SCORE_SHADER);
368        let layout = self.create_bind_group_layout(
369            "RANSAC Score Layout",
370            &[
371                storage_entry(0, true),
372                storage_entry(1, true),
373                storage_entry(2, false),
374                uniform_entry(3),
375            ],
376        );
377        let pipeline = self.create_pipeline_with_layout("RANSAC Score Pipeline", &shader, &layout);
378        let bind_group = self.create_bind_group(
379            "RANSAC Score Bind Group",
380            &layout,
381            &[
382                wgpu::BindGroupEntry {
383                    binding: 0,
384                    resource: input_buffer.as_entire_binding(),
385                },
386                wgpu::BindGroupEntry {
387                    binding: 1,
388                    resource: samples_buffer.as_entire_binding(),
389                },
390                wgpu::BindGroupEntry {
391                    binding: 2,
392                    resource: candidates_buffer.as_entire_binding(),
393                },
394                wgpu::BindGroupEntry {
395                    binding: 3,
396                    resource: params_buffer.as_entire_binding(),
397                },
398            ],
399        );
400
401        let mut encoder = self
402            .device
403            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
404                label: Some("RANSAC Score Encoder"),
405            });
406        {
407            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
408                label: Some("RANSAC Score Pass"),
409                timestamp_writes: None,
410            });
411            pass.set_pipeline(&pipeline);
412            pass.set_bind_group(0, &bind_group, &[]);
413            pass.dispatch_workgroups(div_ceil(sample_count, 64) as u32, 1, 1);
414        }
415
416        let candidates = self
417            .read_storage_buffer::<PlaneCandidate>(
418                encoder,
419                &candidates_buffer,
420                sample_count,
421                "RANSAC Candidate Staging",
422            )
423            .await?;
424
425        let best = candidates
426            .iter()
427            .max_by_key(|candidate| candidate.inlier_count)
428            .ok_or_else(|| Error::Algorithm("Failed to evaluate RANSAC candidates".to_string()))?;
429
430        if best.inlier_count == 0 {
431            return Err(Error::Algorithm(
432                "Failed to find valid plane model".to_string(),
433            ));
434        }
435
436        let coefficients = Vector4::new(
437            best.coefficients[0],
438            best.coefficients[1],
439            best.coefficients[2],
440            best.coefficients[3],
441        );
442        let inliers = self
443            .plane_inlier_indices(points, coefficients, threshold)
444            .await?;
445
446        Ok(GpuPlaneSegmentationResult {
447            plane: GpuPlaneModel::new(coefficients),
448            model: GpuPlaneModel::new(coefficients),
449            inliers,
450            iterations: sample_count,
451        })
452    }
453
454    /// Extract Euclidean clusters as point-cloud values.
455    pub async fn extract_clusters(
456        &self,
457        cloud: &PointCloud<Point3f>,
458        config: GpuClusterConfig,
459    ) -> Result<Vec<PointCloud<Point3f>>> {
460        let result = self
461            .extract_euclidean_clusters(&cloud.points, &config)
462            .await?;
463        Ok(result
464            .clusters
465            .iter()
466            .map(|indices| {
467                PointCloud::from_points(indices.iter().map(|&i| cloud.points[i]).collect())
468            })
469            .collect())
470    }
471
472    /// Extract Euclidean clusters using GPU-computed radius adjacency.
473    pub async fn extract_euclidean_clusters(
474        &self,
475        points: &[Point3f],
476        config: &GpuEuclideanClusterConfig,
477    ) -> Result<GpuClusterExtractionResult> {
478        validate_cluster_inputs(points, config)?;
479
480        let max_neighbors = config.max_neighbors.min(256).max(1);
481        let max_candidates = max_neighbors.saturating_mul(8).clamp(max_neighbors, 1024);
482        let point_data = points_to_vec4(points);
483        let (candidate_counts, candidate_indices) =
484            build_voxel_candidate_neighbors(points, config.tolerance, max_candidates)?;
485        let input_buffer =
486            self.create_buffer_init("Cluster Points", &point_data, wgpu::BufferUsages::STORAGE);
487        let candidate_counts_buffer = self.create_buffer_init(
488            "Cluster Candidate Counts",
489            &candidate_counts,
490            wgpu::BufferUsages::STORAGE,
491        );
492        let candidates_buffer = self.create_buffer_init(
493            "Cluster Candidates",
494            &candidate_indices,
495            wgpu::BufferUsages::STORAGE,
496        );
497        let counts_buffer = self.create_buffer(
498            "Cluster Neighbor Counts",
499            (points.len() * std::mem::size_of::<u32>()) as u64,
500            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
501        );
502        let neighbors_buffer = self.create_buffer(
503            "Cluster Neighbors",
504            (points.len() * max_neighbors * std::mem::size_of::<u32>()) as u64,
505            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
506        );
507
508        #[repr(C)]
509        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
510        struct ClusterParams {
511            num_points: u32,
512            max_neighbors: u32,
513            max_candidates: u32,
514            _padding0: u32,
515            tolerance: f32,
516            _padding1: [u32; 3],
517        }
518
519        let params = ClusterParams {
520            num_points: points.len() as u32,
521            max_neighbors: max_neighbors as u32,
522            max_candidates: max_candidates as u32,
523            _padding0: 0,
524            tolerance: config.tolerance,
525            _padding1: [0; 3],
526        };
527        let params_buffer =
528            self.create_buffer_init("Cluster Params", &[params], wgpu::BufferUsages::UNIFORM);
529
530        let shader_source = RADIUS_NEIGHBOR_SHADER
531            .replace("MAX_NEIGHBORS", &max_neighbors.to_string())
532            .replace("MAX_CANDIDATES", &max_candidates.to_string());
533        let shader = self.create_shader_module("Cluster Radius Neighbor Shader", &shader_source);
534        let layout = self.create_bind_group_layout(
535            "Cluster Radius Neighbor Layout",
536            &[
537                storage_entry(0, true),
538                storage_entry(1, true),
539                storage_entry(2, true),
540                storage_entry(3, false),
541                storage_entry(4, false),
542                uniform_entry(5),
543            ],
544        );
545        let pipeline =
546            self.create_pipeline_with_layout("Cluster Radius Neighbor Pipeline", &shader, &layout);
547        let bind_group = self.create_bind_group(
548            "Cluster Radius Neighbor Bind Group",
549            &layout,
550            &[
551                wgpu::BindGroupEntry {
552                    binding: 0,
553                    resource: input_buffer.as_entire_binding(),
554                },
555                wgpu::BindGroupEntry {
556                    binding: 1,
557                    resource: candidate_counts_buffer.as_entire_binding(),
558                },
559                wgpu::BindGroupEntry {
560                    binding: 2,
561                    resource: candidates_buffer.as_entire_binding(),
562                },
563                wgpu::BindGroupEntry {
564                    binding: 3,
565                    resource: counts_buffer.as_entire_binding(),
566                },
567                wgpu::BindGroupEntry {
568                    binding: 4,
569                    resource: neighbors_buffer.as_entire_binding(),
570                },
571                wgpu::BindGroupEntry {
572                    binding: 5,
573                    resource: params_buffer.as_entire_binding(),
574                },
575            ],
576        );
577
578        let mut encoder = self
579            .device
580            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
581                label: Some("Cluster Radius Neighbor Encoder"),
582            });
583        {
584            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
585                label: Some("Cluster Radius Neighbor Pass"),
586                timestamp_writes: None,
587            });
588            pass.set_pipeline(&pipeline);
589            pass.set_bind_group(0, &bind_group, &[]);
590            pass.dispatch_workgroups(div_ceil(points.len(), 64) as u32, 1, 1);
591        }
592
593        let counts = self
594            .read_storage_buffer::<u32>(
595                encoder,
596                &counts_buffer,
597                points.len(),
598                "Cluster Count Staging",
599            )
600            .await?;
601
602        let encoder = self
603            .device
604            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
605                label: Some("Cluster Neighbor Read Encoder"),
606            });
607        let neighbors = self
608            .read_storage_buffer::<u32>(
609                encoder,
610                &neighbors_buffer,
611                points.len() * max_neighbors,
612                "Cluster Neighbor Staging",
613            )
614            .await?;
615
616        let mut disjoint_set = DisjointSet::new(points.len());
617        for point_idx in 0..points.len() {
618            let count = counts[point_idx].min(max_neighbors as u32) as usize;
619            let base = point_idx * max_neighbors;
620
621            for &neighbor in &neighbors[base..base + count] {
622                let neighbor = neighbor as usize;
623                if neighbor < points.len() {
624                    disjoint_set.union(point_idx, neighbor);
625                }
626            }
627        }
628
629        let mut by_root: HashMap<usize, usize> = HashMap::new();
630        let mut clusters = Vec::new();
631        for point_idx in 0..points.len() {
632            let root = disjoint_set.find(point_idx);
633            let cluster_idx = *by_root.entry(root).or_insert_with(|| {
634                clusters.push(Vec::new());
635                clusters.len() - 1
636            });
637            clusters[cluster_idx].push(point_idx);
638        }
639
640        clusters.retain(|cluster| {
641            cluster.len() >= config.min_cluster_size && cluster.len() <= config.max_cluster_size
642        });
643        clusters.sort_by(|a, b| b.len().cmp(&a.len()));
644        Ok(GpuClusterExtractionResult { clusters })
645    }
646
647    async fn plane_inlier_indices(
648        &self,
649        points: &[Point3f],
650        coefficients: Vector4<f32>,
651        threshold: f32,
652    ) -> Result<Vec<u32>> {
653        let point_data = points_to_vec4(points);
654        let input_buffer = self.create_buffer_init(
655            "Plane Inlier Points",
656            &point_data,
657            wgpu::BufferUsages::STORAGE,
658        );
659        let flags_buffer = self.create_buffer(
660            "Plane Inlier Flags",
661            (points.len() * std::mem::size_of::<u32>()) as u64,
662            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
663        );
664
665        #[repr(C)]
666        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
667        struct InlierParams {
668            num_points: u32,
669            threshold: f32,
670            _padding: [u32; 2],
671            plane: [f32; 4],
672        }
673
674        let params = InlierParams {
675            num_points: points.len() as u32,
676            threshold,
677            _padding: [0; 2],
678            plane: [
679                coefficients.x,
680                coefficients.y,
681                coefficients.z,
682                coefficients.w,
683            ],
684        };
685        let params_buffer = self.create_buffer_init(
686            "Plane Inlier Params",
687            &[params],
688            wgpu::BufferUsages::UNIFORM,
689        );
690        let shader = self.create_shader_module("Plane Inlier Shader", RANSAC_INLIER_SHADER);
691        let layout = self.create_bind_group_layout(
692            "Plane Inlier Layout",
693            &[
694                storage_entry(0, true),
695                storage_entry(1, false),
696                uniform_entry(2),
697            ],
698        );
699        let pipeline = self.create_pipeline_with_layout("Plane Inlier Pipeline", &shader, &layout);
700        let bind_group = self.create_bind_group(
701            "Plane Inlier Bind Group",
702            &layout,
703            &[
704                wgpu::BindGroupEntry {
705                    binding: 0,
706                    resource: input_buffer.as_entire_binding(),
707                },
708                wgpu::BindGroupEntry {
709                    binding: 1,
710                    resource: flags_buffer.as_entire_binding(),
711                },
712                wgpu::BindGroupEntry {
713                    binding: 2,
714                    resource: params_buffer.as_entire_binding(),
715                },
716            ],
717        );
718
719        let mut encoder = self
720            .device
721            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
722                label: Some("Plane Inlier Encoder"),
723            });
724        {
725            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
726                label: Some("Plane Inlier Pass"),
727                timestamp_writes: None,
728            });
729            pass.set_pipeline(&pipeline);
730            pass.set_bind_group(0, &bind_group, &[]);
731            pass.dispatch_workgroups(div_ceil(points.len(), 64) as u32, 1, 1);
732        }
733
734        let flags = self
735            .read_storage_buffer::<u32>(
736                encoder,
737                &flags_buffer,
738                points.len(),
739                "Plane Inlier Staging",
740            )
741            .await?;
742        Ok(flags
743            .iter()
744            .enumerate()
745            .filter_map(|(idx, flag)| (*flag == 1).then_some(idx as u32))
746            .collect())
747    }
748
749    fn create_pipeline_with_layout(
750        &self,
751        label: &str,
752        shader: &wgpu::ShaderModule,
753        layout: &wgpu::BindGroupLayout,
754    ) -> wgpu::ComputePipeline {
755        self.device
756            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
757                label: Some(label),
758                layout: Some(&self.device.create_pipeline_layout(
759                    &wgpu::PipelineLayoutDescriptor {
760                        label: Some(label),
761                        bind_group_layouts: &[Some(layout)],
762                        immediate_size: 0,
763                    },
764                )),
765                module: shader,
766                entry_point: Some("main"),
767                compilation_options: wgpu::PipelineCompilationOptions::default(),
768                cache: None,
769            })
770    }
771
772    async fn read_storage_buffer<T: bytemuck::Pod>(
773        &self,
774        mut encoder: wgpu::CommandEncoder,
775        source: &wgpu::Buffer,
776        len: usize,
777        label: &str,
778    ) -> Result<Vec<T>> {
779        let size = (len * std::mem::size_of::<T>()) as u64;
780        let staging = self.create_buffer(
781            label,
782            size,
783            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
784        );
785        encoder.copy_buffer_to_buffer(source, 0, &staging, 0, size);
786        self.queue.submit(std::iter::once(encoder.finish()));
787
788        let slice = staging.slice(..);
789        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
790        slice.map_async(wgpu::MapMode::Read, move |result| {
791            sender.send(result).unwrap()
792        });
793        let _ = self.device.poll(wgpu::PollType::Wait {
794            submission_index: None,
795            timeout: None,
796        });
797
798        if let Some(Ok(())) = receiver.receive().await {
799            let data = slice.get_mapped_range();
800            let result = bytemuck::cast_slice(&data).to_vec();
801            drop(data);
802            staging.unmap();
803            Ok(result)
804        } else {
805            Err(Error::Gpu(
806                "Failed to read GPU segmentation results".to_string(),
807            ))
808        }
809    }
810}
811
812/// GPU-accelerated RANSAC plane segmentation for a point cloud.
813pub async fn gpu_segment_plane(
814    gpu_context: &GpuContext,
815    cloud: &PointCloud<Point3f>,
816    config: GpuPlaneSegmentationConfig,
817) -> Result<GpuPlaneSegmentationResult> {
818    gpu_context.segment_plane(&cloud.points, config).await
819}
820
821/// GPU-accelerated RANSAC plane segmentation for a point cloud.
822pub async fn gpu_segment_plane_ransac(
823    gpu_context: &GpuContext,
824    cloud: &PointCloud<Point3f>,
825    threshold: f32,
826    max_iters: usize,
827) -> Result<GpuPlaneSegmentationResult> {
828    gpu_context
829        .segment_plane_ransac(&cloud.points, threshold, max_iters)
830        .await
831}
832
833/// GPU-accelerated Euclidean cluster extraction for a point cloud.
834pub async fn gpu_extract_clusters(
835    gpu_context: &GpuContext,
836    cloud: &PointCloud<Point3f>,
837    config: GpuClusterConfig,
838) -> Result<Vec<PointCloud<Point3f>>> {
839    gpu_context.extract_clusters(cloud, config).await
840}
841
842/// GPU-accelerated Euclidean cluster extraction for a point cloud.
843pub async fn gpu_extract_euclidean_clusters(
844    gpu_context: &GpuContext,
845    cloud: &PointCloud<Point3f>,
846    config: &GpuEuclideanClusterConfig,
847) -> Result<GpuClusterExtractionResult> {
848    gpu_context
849        .extract_euclidean_clusters(&cloud.points, config)
850        .await
851}
852
853fn validate_ransac_config(points: &[Point3f], config: GpuPlaneSegmentationConfig) -> Result<()> {
854    validate_ransac_inputs(points, config.distance_threshold, config.max_iterations)?;
855    if config.min_inliers == 0 {
856        return Err(Error::InvalidData(
857            "min_inliers must be at least 1".to_string(),
858        ));
859    }
860    Ok(())
861}
862
863fn validate_ransac_inputs(points: &[Point3f], threshold: f32, max_iters: usize) -> Result<()> {
864    if points.len() < 3 {
865        return Err(Error::InvalidData(
866            "Need at least 3 points for plane segmentation".to_string(),
867        ));
868    }
869    if threshold <= 0.0 {
870        return Err(Error::InvalidData("Threshold must be positive".to_string()));
871    }
872    if max_iters == 0 {
873        return Err(Error::InvalidData(
874            "Max iterations must be positive".to_string(),
875        ));
876    }
877    if points.len() > u32::MAX as usize {
878        return Err(Error::InvalidData(
879            "Point cloud is too large for GPU segmentation".to_string(),
880        ));
881    }
882    Ok(())
883}
884
885fn validate_cluster_inputs(points: &[Point3f], config: &GpuEuclideanClusterConfig) -> Result<()> {
886    if points.is_empty() {
887        return Err(Error::InvalidData("Point cloud is empty".to_string()));
888    }
889    if config.tolerance <= 0.0 {
890        return Err(Error::InvalidData("Tolerance must be positive".to_string()));
891    }
892    if config.min_cluster_size == 0 {
893        return Err(Error::InvalidData(
894            "min_cluster_size must be at least 1".to_string(),
895        ));
896    }
897    if config.min_cluster_size > config.max_cluster_size {
898        return Err(Error::InvalidData(
899            "min_cluster_size must not exceed max_cluster_size".to_string(),
900        ));
901    }
902    if config.max_neighbors == 0 {
903        return Err(Error::InvalidData(
904            "max_neighbors must be at least 1".to_string(),
905        ));
906    }
907    if points.len() > u32::MAX as usize {
908        return Err(Error::InvalidData(
909            "Point cloud is too large for GPU clustering".to_string(),
910        ));
911    }
912    Ok(())
913}
914
915fn points_to_vec4(points: &[Point3f]) -> Vec<[f32; 4]> {
916    points.iter().map(|p| [p.x, p.y, p.z, 0.0]).collect()
917}
918
919fn build_voxel_candidate_neighbors(
920    points: &[Point3f],
921    tolerance: f32,
922    max_candidates: usize,
923) -> Result<(Vec<u32>, Vec<u32>)> {
924    let total_candidates = points
925        .len()
926        .checked_mul(max_candidates)
927        .ok_or_else(|| Error::InvalidData("Cluster candidate buffer is too large".to_string()))?;
928    let mut bins: HashMap<(i32, i32, i32), Vec<u32>> = HashMap::new();
929
930    for (idx, point) in points.iter().enumerate() {
931        bins.entry(voxel_key(point, tolerance))
932            .or_default()
933            .push(idx as u32);
934    }
935
936    let mut counts = vec![0u32; points.len()];
937    let mut candidates = vec![u32::MAX; total_candidates];
938    for (idx, point) in points.iter().enumerate() {
939        let (vx, vy, vz) = voxel_key(point, tolerance);
940        let base = idx * max_candidates;
941        let mut stored = 0usize;
942
943        for dx in -1..=1 {
944            for dy in -1..=1 {
945                for dz in -1..=1 {
946                    let key = (vx + dx, vy + dy, vz + dz);
947                    let Some(bucket) = bins.get(&key) else {
948                        continue;
949                    };
950
951                    for &candidate in bucket {
952                        if candidate as usize == idx {
953                            continue;
954                        }
955                        if stored == max_candidates {
956                            break;
957                        }
958                        candidates[base + stored] = candidate;
959                        stored += 1;
960                    }
961                }
962            }
963        }
964
965        counts[idx] = stored as u32;
966    }
967
968    Ok((counts, candidates))
969}
970
971fn voxel_key(point: &Point3f, tolerance: f32) -> (i32, i32, i32) {
972    (
973        (point.x / tolerance).floor() as i32,
974        (point.y / tolerance).floor() as i32,
975        (point.z / tolerance).floor() as i32,
976    )
977}
978
979fn generate_ransac_samples(num_points: usize, sample_count: usize) -> Vec<[u32; 4]> {
980    let mut state = ((num_points as u64) << 32) ^ sample_count as u64 ^ 0x9E37_79B9_7F4A_7C15;
981    let mut samples = Vec::with_capacity(sample_count);
982
983    for iteration in 0..sample_count {
984        let mut a = next_index(&mut state, num_points);
985        let mut b = next_index(&mut state, num_points);
986        let mut c = next_index(&mut state, num_points);
987
988        if a == b || a == c || b == c {
989            a = iteration % num_points;
990            b = (iteration.wrapping_mul(37) + 1) % num_points;
991            c = (iteration.wrapping_mul(101) + 2) % num_points;
992            while b == a {
993                b = (b + 1) % num_points;
994            }
995            while c == a || c == b {
996                c = (c + 1) % num_points;
997            }
998        }
999
1000        samples.push([a as u32, b as u32, c as u32, 0]);
1001    }
1002
1003    samples
1004}
1005
/// Advances the generator `state` by one step and returns an index in `0..len`.
///
/// The state is updated as `state * MULTIPLIER + INCREMENT` with wrapping arithmetic, and the
/// index is taken from the upper 32 bits of the new state, reduced modulo `len`.
///
/// # Panics
///
/// Panics if `len` is zero (remainder by zero).
fn next_index(state: &mut u64, len: usize) -> usize {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let advanced = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = advanced;

    let high_bits = (advanced >> 32) as usize;
    high_bits % len
}
1012
/// Union-find (disjoint-set forest) over the indices `0..len`.
struct DisjointSet {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each element, consulted only for roots when two sets are merged.
    rank: Vec<u8>,
}

impl DisjointSet {
    /// Creates `len` singleton sets, one per index.
    fn new(len: usize) -> Self {
        let parent = (0..len).collect();
        let rank = vec![0; len];
        Self { parent, rank }
    }

    /// Returns the representative of the set containing `value`.
    ///
    /// Every element visited on the way to the root is re-pointed directly at the root.
    ///
    /// # Panics
    ///
    /// Panics if `value` is out of bounds.
    fn find(&mut self, value: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = value;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: re-point every element on the path at the root.
        let mut cursor = value;
        while cursor != root {
            cursor = std::mem::replace(&mut self.parent[cursor], root);
        }

        root
    }

    /// Merges the sets containing `a` and `b`; does nothing if they already share a root.
    ///
    /// The root with the lower rank is attached beneath the other. On a tie, `b`'s root is
    /// attached beneath `a`'s root, whose rank is then incremented.
    fn union(&mut self, a: usize, b: usize) {
        let (root_a, root_b) = (self.find(a), self.find(b));
        if root_a == root_b {
            return;
        }

        let (rank_a, rank_b) = (self.rank[root_a], self.rank[root_b]);
        if rank_a < rank_b {
            self.parent[root_a] = root_b;
        } else {
            self.parent[root_b] = root_a;
            if rank_a == rank_b {
                self.rank[root_a] += 1;
            }
        }
    }
}
1050
1051fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
1052    wgpu::BindGroupLayoutEntry {
1053        binding,
1054        visibility: wgpu::ShaderStages::COMPUTE,
1055        ty: wgpu::BindingType::Buffer {
1056            ty: wgpu::BufferBindingType::Storage { read_only },
1057            has_dynamic_offset: false,
1058            min_binding_size: None,
1059        },
1060        count: None,
1061    }
1062}
1063
1064fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
1065    wgpu::BindGroupLayoutEntry {
1066        binding,
1067        visibility: wgpu::ShaderStages::COMPUTE,
1068        ty: wgpu::BindingType::Buffer {
1069            ty: wgpu::BufferBindingType::Uniform,
1070            has_dynamic_offset: false,
1071            min_binding_size: None,
1072        },
1073        count: None,
1074    }
1075}
1076
/// Returns `value / divisor` rounded up to the next whole number.
///
/// # Panics
///
/// Panics if `divisor` is zero.
fn div_ceil(value: usize, divisor: usize) -> usize {
    usize::div_ceil(value, divisor)
}
1080
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a GPU context, or logs a notice and returns `None` when creation fails so
    /// GPU-dependent tests can bail out early.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        let context = GpuContext::new().await.ok();
        if context.is_none() {
            println!("GPU not available, skipping GPU-dependent test");
        }
        context
    }

    /// A 12x12 grid of points on the plane `z = 0` (spacing 0.1) plus two off-plane outliers.
    fn planar_cloud() -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for col in 0..12 {
            for row in 0..12 {
                cloud.push(Point3f::new(col as f32 * 0.1, row as f32 * 0.1, 0.0));
            }
        }
        for outlier in [Point3f::new(0.0, 0.0, 3.0), Point3f::new(1.0, 1.0, -3.0)] {
            cloud.push(outlier);
        }
        cloud
    }

    /// Two well-separated grids on `z = 0`: 30 points near the origin, then 20 near `x = 5`.
    fn clustered_cloud() -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for (x_offset, count) in [(0.0_f32, 30), (5.0, 20)] {
            for i in 0..count {
                let x = x_offset + (i % 5) as f32 * 0.04;
                let y = (i / 5) as f32 * 0.04;
                cloud.push(Point3f::new(x, y, 0.0));
            }
        }
        cloud
    }

    #[test]
    fn test_gpu_ransac_plane() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let config = GpuPlaneSegmentationConfig {
                max_iterations: 256,
                distance_threshold: 0.01,
                min_inliers: 140,
            };
            let cloud = planar_cloud();
            let segmentation = gpu_segment_plane(&gpu, &cloud, config).await.unwrap();

            // The dominant plane is z = 0, so the fitted normal should be close to +/-Z.
            assert!(segmentation.inliers.len() >= 140);
            assert!(segmentation.plane.normal().z.abs() > 0.9);
            assert_eq!(segmentation.model, segmentation.plane);
            assert_eq!(segmentation.iterations, 256);
        });
    }

    #[test]
    fn test_gpu_clusters() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = clustered_cloud();
            let config = GpuEuclideanClusterConfig::with_max_neighbors(0.09, 5, 100, 16);
            let extraction = gpu_extract_euclidean_clusters(&gpu, &cloud, &config)
                .await
                .unwrap();
            let clouds = gpu_extract_clusters(&gpu, &cloud, config).await.unwrap();

            // Index-based result.
            assert_eq!(extraction.num_clusters(), 2);
            assert_eq!(extraction.clusters[0].len(), 30);
            assert_eq!(extraction.clusters[1].len(), 20);

            // Point-cloud result.
            assert_eq!(clouds.len(), 2);
            assert_eq!(clouds[0].len(), 30);
            assert_eq!(clouds[1].len(), 20);
        });
    }

    #[test]
    fn test_invalid_inputs() {
        // Too few points to define a plane.
        let single_point = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(validate_ransac_inputs(&single_point.points, 0.1, 10).is_err());

        // Negative threshold, then zero iterations.
        assert!(validate_ransac_inputs(&planar_cloud().points, -0.1, 10).is_err());
        assert!(validate_ransac_inputs(&planar_cloud().points, 0.1, 0).is_err());

        // Zero inliers requested.
        let zero_inliers = GpuPlaneSegmentationConfig {
            min_inliers: 0,
            ..GpuPlaneSegmentationConfig::default()
        };
        assert!(validate_ransac_config(&planar_cloud().points, zero_inliers).is_err());

        // Negative clustering tolerance.
        let negative_tolerance = GpuEuclideanClusterConfig::new(-1.0, 1, 10);
        assert!(validate_cluster_inputs(&clustered_cloud().points, &negative_tolerance).is_err());
    }
}