threecrate_gpu/
normals.rs

1//! GPU-accelerated normal estimation
2
3use threecrate_core::{PointCloud, Result, Point3f, NormalPoint3f};
4use crate::GpuContext;
5// use wgpu::util::DeviceExt; // Used in device.rs
6
/// WGSL compute shader that estimates one surface normal per input point.
///
/// Bindings (group 0):
/// - `0`: input points as `vec4<f32>` (xyz used, w ignored)
/// - `1`: output normals as `vec4<f32>` (xyz = normal, w = 0)
/// - `2`: per-point neighbor indices, 64 slots per point; indices `>= num_points` are skipped
/// - `3`: `NormalParams` uniform (layout must match the `#[repr(C)]` struct on the Rust side)
///
/// For each point the shader builds the covariance matrix `C` of its valid neighbors and
/// approximates the eigenvector of `C`'s smallest eigenvalue with 8 steps of power
/// iteration on `trace(C) * I - C`. Points with fewer than three valid neighbors get
/// `(0, 0, 1)`. When `consistent_orientation == 1` the normal is flipped to face `viewpoint`.
const NORMALS_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> output_normals: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> neighbors: array<array<u32, 64>>;
@group(0) @binding(3) var<uniform> params: NormalParams;

struct NormalParams {
    num_points: u32,
    k_neighbors: u32,
    consistent_orientation: u32,
    _pad: u32,
    viewpoint: vec4<f32>,
}

// Row-major 3x3 matrix times vector.
fn mat3_mul_vec3(m00: f32, m01: f32, m02: f32,
                 m10: f32, m11: f32, m12: f32,
                 m20: f32, m21: f32, m22: f32,
                 v: vec3<f32>) -> vec3<f32> {
    return vec3<f32>(
        m00 * v.x + m01 * v.y + m02 * v.z,
        m10 * v.x + m11 * v.y + m12 * v.z,
        m20 * v.x + m21 * v.y + m22 * v.z
    );
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_points) {
        return;
    }

    let center_point = input_points[index].xyz;

    // Centroid of the valid neighbors (out-of-range indices are skipped).
    var centroid = vec3<f32>(0.0);
    var count = 0u;
    for (var i = 0u; i < params.k_neighbors; i++) {
        let neighbor_idx = neighbors[index][i];
        if (neighbor_idx < params.num_points) {
            centroid += input_points[neighbor_idx].xyz;
            count++;
        }
    }

    // Too few neighbors to fit a plane: fall back to +Z.
    if (count < 3u) {
        output_normals[index] = vec4<f32>(0.0, 0.0, 1.0, 0.0);
        return;
    }

    centroid /= f32(count);

    // Symmetric covariance matrix (upper triangle) of the neighbors around the centroid.
    var c00 = 0.0; var c01 = 0.0; var c02 = 0.0;
    var c11 = 0.0; var c12 = 0.0; var c22 = 0.0;
    for (var i = 0u; i < params.k_neighbors; i++) {
        let neighbor_idx = neighbors[index][i];
        if (neighbor_idx < params.num_points) {
            let d = input_points[neighbor_idx].xyz - centroid;
            c00 += d.x * d.x;
            c01 += d.x * d.y;
            c02 += d.x * d.z;
            c11 += d.y * d.y;
            c12 += d.y * d.z;
            c22 += d.z * d.z;
        }
    }
    let inv_count = 1.0 / f32(count);
    c00 *= inv_count; c01 *= inv_count; c02 *= inv_count;
    c11 *= inv_count; c12 *= inv_count; c22 *= inv_count;

    // Power iteration on D = trace(C) * I - C: D's dominant eigenvector is the
    // eigenvector of C's smallest eigenvalue, i.e. the local plane normal.
    let trace_c = c00 + c11 + c22;

    // Seed with the normal of the triangle (center, n0, n1). The test below is the
    // sine of the angle between the two edges (scale-free), written without
    // normalize(): a neighbor that duplicates the center gives a zero edge, and
    // normalize(0) is undefined in WGSL. A zero edge makes both sides 0, keeping +Z.
    let n0 = neighbors[index][0u];
    let n1 = neighbors[index][1u];
    var v = vec3<f32>(0.0, 0.0, 1.0);
    if (n0 < params.num_points && n1 < params.num_points) {
        let e1 = input_points[n0].xyz - center_point;
        let e2 = input_points[n1].xyz - center_point;
        let cp = cross(e1, e2);
        let len_cp = length(cp);
        if (len_cp > 1e-6 * length(e1) * length(e2)) {
            v = cp / len_cp;
        }
    }

    // D scales with the squared point spacing, so the degenerate-update threshold is
    // relative to trace(C); an absolute threshold skips every update for clouds with
    // very small coordinates. trace(C) == 0 leaves the seed unchanged.
    let min_len = 1e-7 * trace_c;
    for (var it = 0u; it < 8u; it++) {
        let Cv = mat3_mul_vec3(c00, c01, c02, c01, c11, c12, c02, c12, c22, v);
        let Dv = trace_c * v - Cv;
        let lenDv = length(Dv);
        if (lenDv > min_len) {
            v = Dv / lenDv;
        }
    }
    var normal = v;

    // Orientation consistency: only the sign of the dot product matters, so the
    // direction to the viewpoint is not normalized (it may be zero-length).
    if (params.consistent_orientation == 1u) {
        let to_view = params.viewpoint.xyz - center_point;
        if (dot(normal, to_view) < 0.0) {
            normal = -normal;
        }
    }

    output_normals[index] = vec4<f32>(normal, 0.0);
}
"#;
117
118impl GpuContext {
119    /// Compute normals for a point cloud using GPU acceleration with options
120    pub async fn compute_normals_with_options(
121        &self,
122        points: &[Point3f],
123        k_neighbors: usize,
124        consistent_orientation: bool,
125        viewpoint: Option<[f32; 3]>,
126    ) -> Result<Vec<nalgebra::Vector3<f32>>> {
127        if points.is_empty() {
128            return Ok(Vec::new());
129        }
130
131        // Convert points to GPU format (std430 alignment prefers vec4)
132        let point_data: Vec<[f32; 4]> = points
133            .iter()
134            .map(|p| [p.x, p.y, p.z, 0.0])
135            .collect();
136
137        // Create buffers
138        let input_buffer = self.create_buffer_init(
139            "Input Points",
140            &point_data,
141            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
142        );
143
144        let output_buffer = self.create_buffer(
145            "Output Normals",
146            (point_data.len() * std::mem::size_of::<[f32; 4]>()) as u64,
147            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
148        );
149
150        // For now, use a simple neighbor computation (could be replaced with KD-tree)
151        let k_neighbors = k_neighbors.max(3).min(64);
152        let neighbors = self.compute_neighbors_simple_points3(&points.iter().map(|p| [p.x, p.y, p.z]).collect::<Vec<[f32;3]>>(), k_neighbors);
153        let neighbors_buffer = self.create_buffer_init(
154            "Neighbors",
155            &neighbors,
156            wgpu::BufferUsages::STORAGE,
157        );
158
159        #[repr(C)]
160        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
161        struct NormalParams {
162            num_points: u32,
163            k_neighbors: u32,
164            consistent_orientation: u32,
165            _pad: u32,
166            viewpoint: [f32; 4],
167        }
168
169        // Determine default viewpoint (mimic CPU implementation): above bbox center along +Z by extent
170        let vp = if let Some(vp) = viewpoint {
171            [vp[0], vp[1], vp[2], 0.0]
172        } else {
173            let mut min_x = point_data[0][0];
174            let mut min_y = point_data[0][1];
175            let mut min_z = point_data[0][2];
176            let mut max_x = point_data[0][0];
177            let mut max_y = point_data[0][1];
178            let mut max_z = point_data[0][2];
179            for p in &point_data {
180                min_x = min_x.min(p[0]);
181                min_y = min_y.min(p[1]);
182                min_z = min_z.min(p[2]);
183                max_x = max_x.max(p[0]);
184                max_y = max_y.max(p[1]);
185                max_z = max_z.max(p[2]);
186            }
187            let cx = (min_x + max_x) * 0.5;
188            let cy = (min_y + max_y) * 0.5;
189            let cz = (min_z + max_z) * 0.5;
190            let dx = max_x - min_x;
191            let dy = max_y - min_y;
192            let dz = max_z - min_z;
193            let extent = (dx * dx + dy * dy + dz * dz).sqrt();
194            [cx, cy, cz + extent, 0.0]
195        };
196
197        let params = NormalParams {
198            num_points: points.len() as u32,
199            k_neighbors: k_neighbors as u32,
200            consistent_orientation: if consistent_orientation { 1 } else { 0 },
201            _pad: 0,
202            viewpoint: vp,
203        };
204
205        let params_buffer = self.create_buffer_init(
206            "Params",
207            &[params],
208            wgpu::BufferUsages::UNIFORM,
209        );
210
211        // Create shader
212        let shader = self.create_shader_module("Normals Compute", NORMALS_SHADER);
213
214        // Create bind group layout
215        let bind_group_layout = self.create_bind_group_layout(
216            "Normal Computation",
217            &[
218                wgpu::BindGroupLayoutEntry {
219                    binding: 0,
220                    visibility: wgpu::ShaderStages::COMPUTE,
221                    ty: wgpu::BindingType::Buffer {
222                        ty: wgpu::BufferBindingType::Storage { read_only: true },
223                        has_dynamic_offset: false,
224                        min_binding_size: None,
225                    },
226                    count: None,
227                },
228                wgpu::BindGroupLayoutEntry {
229                    binding: 1,
230                    visibility: wgpu::ShaderStages::COMPUTE,
231                    ty: wgpu::BindingType::Buffer {
232                        ty: wgpu::BufferBindingType::Storage { read_only: false },
233                        has_dynamic_offset: false,
234                        min_binding_size: None,
235                    },
236                    count: None,
237                },
238                wgpu::BindGroupLayoutEntry {
239                    binding: 2,
240                    visibility: wgpu::ShaderStages::COMPUTE,
241                    ty: wgpu::BindingType::Buffer {
242                        ty: wgpu::BufferBindingType::Storage { read_only: true },
243                        has_dynamic_offset: false,
244                        min_binding_size: None,
245                    },
246                    count: None,
247                },
248                wgpu::BindGroupLayoutEntry {
249                    binding: 3,
250                    visibility: wgpu::ShaderStages::COMPUTE,
251                    ty: wgpu::BindingType::Buffer {
252                        ty: wgpu::BufferBindingType::Uniform,
253                        has_dynamic_offset: false,
254                        min_binding_size: None,
255                    },
256                    count: None,
257                },
258            ],
259        );
260
261        // Create compute pipeline
262        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
263            label: Some("Normal Computation Pipeline"),
264            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
265                label: Some("Normal Pipeline Layout"),
266                bind_group_layouts: &[&bind_group_layout],
267                push_constant_ranges: &[],
268            })),
269            module: &shader,
270            entry_point: "main",
271            compilation_options: wgpu::PipelineCompilationOptions::default(),
272        });
273
274        // Create bind group
275        let bind_group = self.create_bind_group(
276            "Normal Computation",
277            &bind_group_layout,
278            &[
279                wgpu::BindGroupEntry {
280                    binding: 0,
281                    resource: input_buffer.as_entire_binding(),
282                },
283                wgpu::BindGroupEntry {
284                    binding: 1,
285                    resource: output_buffer.as_entire_binding(),
286                },
287                wgpu::BindGroupEntry {
288                    binding: 2,
289                    resource: neighbors_buffer.as_entire_binding(),
290                },
291                wgpu::BindGroupEntry {
292                    binding: 3,
293                    resource: params_buffer.as_entire_binding(),
294                },
295            ],
296        );
297
298        // Execute compute shader
299        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
300            label: Some("Normal Computation"),
301        });
302
303        {
304            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
305                label: Some("Normal Computation Pass"),
306                timestamp_writes: None,
307            });
308            compute_pass.set_pipeline(&pipeline);
309            compute_pass.set_bind_group(0, &bind_group, &[]);
310            let workgroup_count = (points.len() + 63) / 64; // 64 is workgroup size
311            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
312        }
313
314        // Read back results
315        let staging_buffer = self.create_buffer(
316            "Staging Buffer",
317            (point_data.len() * std::mem::size_of::<[f32; 4]>()) as u64,
318            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
319        );
320
321        encoder.copy_buffer_to_buffer(
322            &output_buffer,
323            0,
324            &staging_buffer,
325            0,
326            (point_data.len() * std::mem::size_of::<[f32; 4]>()) as u64,
327        );
328
329        self.queue.submit(std::iter::once(encoder.finish()));
330
331        // Map and read results
332        let buffer_slice = staging_buffer.slice(..);
333        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
334        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
335
336        self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
337
338        if let Some(Ok(())) = receiver.receive().await {
339            let data = buffer_slice.get_mapped_range();
340            let normals4: Vec<[f32; 4]> = bytemuck::cast_slice(&data).to_vec();
341            
342            let result = normals4
343                .into_iter()
344                .map(|n| nalgebra::Vector3::new(n[0], n[1], n[2]))
345                .collect();
346            
347            drop(data);
348            staging_buffer.unmap();
349            
350            Ok(result)
351        } else {
352            Err(threecrate_core::Error::Gpu("Failed to read GPU results".to_string()))
353        }
354    }
355
356    /// Compute normals for a point cloud using GPU acceleration with default options
357    pub async fn compute_normals(&self, points: &[Point3f], k_neighbors: usize) -> Result<Vec<nalgebra::Vector3<f32>>> {
358        self.compute_normals_with_options(points, k_neighbors, true, None).await
359    }
360
361    /// Simple neighbor computation (brute force - could be replaced with KD-tree)
362    pub fn compute_neighbors_simple(&self, points: &[[f32; 3]], k: usize) -> Vec<[u32; 64]> {
363        let mut neighbors = vec![[0u32; 64]; points.len()];
364        let k = k.min(64).min(points.len()); // Limit to 64 neighbors and available points
365        
366        for (i, point) in points.iter().enumerate() {
367            let mut distances: Vec<(f32, usize)> = points
368                .iter()
369                .enumerate()
370                .filter(|(j, _)| *j != i)
371                .map(|(j, other)| {
372                    let dx = point[0] - other[0];
373                    let dy = point[1] - other[1];
374                    let dz = point[2] - other[2];
375                    (dx * dx + dy * dy + dz * dz, j)
376                })
377                .collect();
378            
379            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
380            
381            for (idx, &(_, neighbor_idx)) in distances.iter().take(k).enumerate() {
382                neighbors[i][idx] = neighbor_idx as u32;
383            }
384            
385            // Fill remaining slots with the same neighbor to avoid issues
386            for idx in k..64 {
387                neighbors[i][idx] = if k > 0 { neighbors[i][k - 1] } else { i as u32 };
388            }
389        }
390        
391        neighbors
392    }
393
394    /// Helper to compute neighbors from Vec<[f32;3]> built from owned data
395    pub fn compute_neighbors_simple_points3(&self, points: &[[f32; 3]], k: usize) -> Vec<[u32; 64]> {
396        self.compute_neighbors_simple(points, k)
397    }
398}
399
400/// GPU-accelerated normal estimation for point clouds
401pub async fn gpu_estimate_normals(
402    gpu_context: &GpuContext,
403    cloud: &mut PointCloud<Point3f>,
404    k: usize,
405) -> Result<PointCloud<NormalPoint3f>> {
406    let normals = gpu_context.compute_normals(&cloud.points, k).await?;
407    
408    let normal_points: Vec<NormalPoint3f> = cloud
409        .points
410        .iter()
411        .zip(normals.iter())
412        .map(|(point, normal)| NormalPoint3f {
413            position: *point,
414            normal: *normal,
415        })
416        .collect();
417    
418    Ok(PointCloud::from_points(normal_points))
419} 
420
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::{Point3f, PointCloud};

    /// Create a GPU context, or print a notice and return `None` when none is available.
    async fn try_create_gpu_context() -> Option<crate::GpuContext> {
        let context = crate::GpuContext::new().await;
        if context.is_err() {
            println!("⚠️  GPU not available, skipping GPU-dependent test");
        }
        context.ok()
    }

    /// Regular `n x n` grid on the z = 0 plane (x outer, y inner) with the given spacing.
    fn xy_grid(n: usize, spacing: f32) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for i in 0..n {
            for j in 0..n {
                cloud.push(Point3f::new(i as f32 * spacing, j as f32 * spacing, 0.0));
            }
        }
        cloud
    }

    #[tokio::test]
    async fn test_gpu_normals_plane() {
        let Some(gpu) = try_create_gpu_context().await else {
            return;
        };

        let mut cloud = xy_grid(15, 0.1);
        let result = gpu_estimate_normals(&gpu, &mut cloud, 8).await.unwrap();
        assert_eq!(result.len(), 225);

        // A flat XY grid should produce normals along ±Z.
        let z_count = result.iter().filter(|p| p.normal.z.abs() > 0.8).count();
        let pct = (z_count as f32 / result.len() as f32) * 100.0;
        assert!(pct > 80.0, "Only {:.1}% normals in Z direction", pct);
    }

    #[tokio::test]
    async fn test_gpu_normals_compare_cpu_plane() {
        use threecrate_algorithms::estimate_normals as cpu_estimate_normals;
        let Some(gpu) = try_create_gpu_context().await else {
            return;
        };

        let cloud = xy_grid(10, 0.1);
        let gpu_cloud = gpu_estimate_normals(&gpu, &mut cloud.clone(), 8).await.unwrap();
        let cpu_cloud = cpu_estimate_normals(&cloud, 8).unwrap();

        // Orientation may differ, so compare alignment by |dot|.
        let agree = gpu_cloud
            .iter()
            .zip(cpu_cloud.iter())
            .filter(|(g, c)| g.normal.dot(&c.normal).abs() > 0.7)
            .count();
        let pct = (agree as f32 / gpu_cloud.len() as f32) * 100.0;
        assert!(pct > 70.0, "GPU-CPU normals agree only {:.1}%", pct);
    }
}