1use threecrate_core::{PointCloud, Result, Point3f, NormalPoint3f};
4use crate::GpuContext;
/// WGSL compute shader that estimates one surface normal per point.
///
/// Bindings (group 0):
/// * 0 — `input_points`: point positions as `vec4<f32>` (only `xyz` is read).
/// * 1 — `output_normals`: one `vec4<f32>` per point (`xyz` = normal, `w` = 0).
/// * 2 — `neighbors`: 64 neighbour indices per point; only the first
///   `params.k_neighbors` are read, and indices `>= num_points` are skipped.
/// * 3 — `params`: uniform block mirroring the host-side `NormalParams` struct.
///
/// Per point, the shader builds the covariance matrix of the neighbour
/// positions and runs a fixed number of power-iteration steps on
/// `trace(C) * I - C` to approximate the eigenvector of the smallest eigenvalue
/// of `C`. Points with fewer than three valid neighbours get `(0, 0, 1)`.
///
/// Zero-length vectors are never passed to `normalize()`: duplicate points and
/// a viewpoint that coincides with a point would otherwise produce
/// indeterminate values.
const NORMALS_SHADER: &str = r#"
struct NormalParams {
    num_points: u32,
    k_neighbors: u32,
    consistent_orientation: u32,
    _pad: u32,
    viewpoint: vec4<f32>,
}

@group(0) @binding(0) var<storage, read> input_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> output_normals: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> neighbors: array<array<u32, 64>>;
@group(0) @binding(3) var<uniform> params: NormalParams;

// Row-major 3x3 matrix times column vector.
fn mat3_mul_vec3(m00: f32, m01: f32, m02: f32,
                 m10: f32, m11: f32, m12: f32,
                 m20: f32, m21: f32, m22: f32,
                 v: vec3<f32>) -> vec3<f32> {
    return vec3<f32>(
        m00 * v.x + m01 * v.y + m02 * v.z,
        m10 * v.x + m11 * v.y + m12 * v.z,
        m20 * v.x + m21 * v.y + m22 * v.z
    );
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_points) {
        return;
    }

    let center_point = input_points[index].xyz;

    // Centroid of the valid neighbours (the centre point itself is not included).
    var centroid = vec3<f32>(0.0);
    var count = 0u;
    for (var i = 0u; i < params.k_neighbors; i++) {
        let neighbor_idx = neighbors[index][i];
        if (neighbor_idx < params.num_points) {
            centroid += input_points[neighbor_idx].xyz;
            count++;
        }
    }

    // Too few neighbours to define a plane: emit a fixed fallback normal.
    if (count < 3u) {
        output_normals[index] = vec4<f32>(0.0, 0.0, 1.0, 0.0);
        return;
    }

    centroid /= f32(count);

    // Upper triangle of the symmetric covariance matrix.
    var c00 = 0.0; var c01 = 0.0; var c02 = 0.0;
    var c11 = 0.0; var c12 = 0.0; var c22 = 0.0;
    for (var i = 0u; i < params.k_neighbors; i++) {
        let neighbor_idx = neighbors[index][i];
        if (neighbor_idx < params.num_points) {
            let d = input_points[neighbor_idx].xyz - centroid;
            c00 += d.x * d.x;
            c01 += d.x * d.y;
            c02 += d.x * d.z;
            c11 += d.y * d.y;
            c12 += d.y * d.z;
            c22 += d.z * d.z;
        }
    }
    let inv_count = 1.0 / f32(count);
    c00 *= inv_count; c01 *= inv_count; c02 *= inv_count;
    c11 *= inv_count; c12 *= inv_count; c22 *= inv_count;

    // Power iteration on the shifted matrix D = trace(C) * I - C converges to the
    // eigenvector of the smallest eigenvalue of C.
    let trace_c = c00 + c11 + c22;

    // Seed with the cross product of the first two neighbour directions.
    // Lengths are checked explicitly so a neighbour that coincides with the
    // centre point never reaches a division by zero.
    let n0 = neighbors[index][0u];
    let n1 = neighbors[index][1u];
    var v = vec3<f32>(0.0, 0.0, 1.0);
    if (n0 < params.num_points && n1 < params.num_points) {
        let e1 = input_points[n0].xyz - center_point;
        let e2 = input_points[n1].xyz - center_point;
        let len1 = length(e1);
        let len2 = length(e2);
        if (len1 > 0.0 && len2 > 0.0) {
            let cp = cross(e1 / len1, e2 / len2);
            if (length(cp) > 1e-6) {
                v = normalize(cp);
            }
        }
    }

    // Fixed number of iterations keeps the workload uniform across invocations.
    for (var it = 0u; it < 8u; it++) {
        let Cv = mat3_mul_vec3(c00, c01, c02, c01, c11, c12, c02, c12, c22, v);
        let Dv = vec3<f32>(trace_c * v.x, trace_c * v.y, trace_c * v.z) - Cv;
        let lenDv = length(Dv);
        if (lenDv > 1e-8) {
            v = Dv / lenDv;
        }
    }
    var normal = v;

    // Flip the normal towards the viewpoint. Only the sign of the dot product
    // matters, so the direction is deliberately left unnormalised; a viewpoint
    // located exactly on the point then simply leaves the normal untouched.
    if (params.consistent_orientation == 1u) {
        let to_view = params.viewpoint.xyz - center_point;
        if (dot(normal, to_view) < 0.0) {
            normal = -normal;
        }
    }

    output_normals[index] = vec4<f32>(normal, 0.0);
}
"#;
117
118impl GpuContext {
119 pub async fn compute_normals_with_options(
121 &self,
122 points: &[Point3f],
123 k_neighbors: usize,
124 consistent_orientation: bool,
125 viewpoint: Option<[f32; 3]>,
126 ) -> Result<Vec<nalgebra::Vector3<f32>>> {
127 if points.is_empty() {
128 return Ok(Vec::new());
129 }
130
131 let point_data: Vec<[f32; 4]> = points
133 .iter()
134 .map(|p| [p.x, p.y, p.z, 0.0])
135 .collect();
136
137 let input_buffer = self.create_buffer_init(
139 "Input Points",
140 &point_data,
141 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
142 );
143
144 let output_buffer = self.create_buffer(
145 "Output Normals",
146 (point_data.len() * std::mem::size_of::<[f32; 4]>()) as u64,
147 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
148 );
149
150 let k_neighbors = k_neighbors.max(3).min(64);
152 let neighbors = self.compute_neighbors_simple_points3(&points.iter().map(|p| [p.x, p.y, p.z]).collect::<Vec<[f32;3]>>(), k_neighbors);
153 let neighbors_buffer = self.create_buffer_init(
154 "Neighbors",
155 &neighbors,
156 wgpu::BufferUsages::STORAGE,
157 );
158
159 #[repr(C)]
160 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
161 struct NormalParams {
162 num_points: u32,
163 k_neighbors: u32,
164 consistent_orientation: u32,
165 _pad: u32,
166 viewpoint: [f32; 4],
167 }
168
169 let vp = if let Some(vp) = viewpoint {
171 [vp[0], vp[1], vp[2], 0.0]
172 } else {
173 let mut min_x = point_data[0][0];
174 let mut min_y = point_data[0][1];
175 let mut min_z = point_data[0][2];
176 let mut max_x = point_data[0][0];
177 let mut max_y = point_data[0][1];
178 let mut max_z = point_data[0][2];
179 for p in &point_data {
180 min_x = min_x.min(p[0]);
181 min_y = min_y.min(p[1]);
182 min_z = min_z.min(p[2]);
183 max_x = max_x.max(p[0]);
184 max_y = max_y.max(p[1]);
185 max_z = max_z.max(p[2]);
186 }
187 let cx = (min_x + max_x) * 0.5;
188 let cy = (min_y + max_y) * 0.5;
189 let cz = (min_z + max_z) * 0.5;
190 let dx = max_x - min_x;
191 let dy = max_y - min_y;
192 let dz = max_z - min_z;
193 let extent = (dx * dx + dy * dy + dz * dz).sqrt();
194 [cx, cy, cz + extent, 0.0]
195 };
196
197 let params = NormalParams {
198 num_points: points.len() as u32,
199 k_neighbors: k_neighbors as u32,
200 consistent_orientation: if consistent_orientation { 1 } else { 0 },
201 _pad: 0,
202 viewpoint: vp,
203 };
204
205 let params_buffer = self.create_buffer_init(
206 "Params",
207 &[params],
208 wgpu::BufferUsages::UNIFORM,
209 );
210
211 let shader = self.create_shader_module("Normals Compute", NORMALS_SHADER);
213
214 let bind_group_layout = self.create_bind_group_layout(
216 "Normal Computation",
217 &[
218 wgpu::BindGroupLayoutEntry {
219 binding: 0,
220 visibility: wgpu::ShaderStages::COMPUTE,
221 ty: wgpu::BindingType::Buffer {
222 ty: wgpu::BufferBindingType::Storage { read_only: true },
223 has_dynamic_offset: false,
224 min_binding_size: None,
225 },
226 count: None,
227 },
228 wgpu::BindGroupLayoutEntry {
229 binding: 1,
230 visibility: wgpu::ShaderStages::COMPUTE,
231 ty: wgpu::BindingType::Buffer {
232 ty: wgpu::BufferBindingType::Storage { read_only: false },
233 has_dynamic_offset: false,
234 min_binding_size: None,
235 },
236 count: None,
237 },
238 wgpu::BindGroupLayoutEntry {
239 binding: 2,
240 visibility: wgpu::ShaderStages::COMPUTE,
241 ty: wgpu::BindingType::Buffer {
242 ty: wgpu::BufferBindingType::Storage { read_only: true },
243 has_dynamic_offset: false,
244 min_binding_size: None,
245 },
246 count: None,
247 },
248 wgpu::BindGroupLayoutEntry {
249 binding: 3,
250 visibility: wgpu::ShaderStages::COMPUTE,
251 ty: wgpu::BindingType::Buffer {
252 ty: wgpu::BufferBindingType::Uniform,
253 has_dynamic_offset: false,
254 min_binding_size: None,
255 },
256 count: None,
257 },
258 ],
259 );
260
261 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
263 label: Some("Normal Computation Pipeline"),
264 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
265 label: Some("Normal Pipeline Layout"),
266 bind_group_layouts: &[&bind_group_layout],
267 push_constant_ranges: &[],
268 })),
269 module: &shader,
270 entry_point: "main",
271 compilation_options: wgpu::PipelineCompilationOptions::default(),
272 });
273
274 let bind_group = self.create_bind_group(
276 "Normal Computation",
277 &bind_group_layout,
278 &[
279 wgpu::BindGroupEntry {
280 binding: 0,
281 resource: input_buffer.as_entire_binding(),
282 },
283 wgpu::BindGroupEntry {
284 binding: 1,
285 resource: output_buffer.as_entire_binding(),
286 },
287 wgpu::BindGroupEntry {
288 binding: 2,
289 resource: neighbors_buffer.as_entire_binding(),
290 },
291 wgpu::BindGroupEntry {
292 binding: 3,
293 resource: params_buffer.as_entire_binding(),
294 },
295 ],
296 );
297
298 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
300 label: Some("Normal Computation"),
301 });
302
303 {
304 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
305 label: Some("Normal Computation Pass"),
306 timestamp_writes: None,
307 });
308 compute_pass.set_pipeline(&pipeline);
309 compute_pass.set_bind_group(0, &bind_group, &[]);
310 let workgroup_count = (points.len() + 63) / 64; compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
312 }
313
314 let staging_buffer = self.create_buffer(
316 "Staging Buffer",
317 (point_data.len() * std::mem::size_of::<[f32; 4]>()) as u64,
318 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
319 );
320
321 encoder.copy_buffer_to_buffer(
322 &output_buffer,
323 0,
324 &staging_buffer,
325 0,
326 (point_data.len() * std::mem::size_of::<[f32; 4]>()) as u64,
327 );
328
329 self.queue.submit(std::iter::once(encoder.finish()));
330
331 let buffer_slice = staging_buffer.slice(..);
333 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
334 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
335
336 self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
337
338 if let Some(Ok(())) = receiver.receive().await {
339 let data = buffer_slice.get_mapped_range();
340 let normals4: Vec<[f32; 4]> = bytemuck::cast_slice(&data).to_vec();
341
342 let result = normals4
343 .into_iter()
344 .map(|n| nalgebra::Vector3::new(n[0], n[1], n[2]))
345 .collect();
346
347 drop(data);
348 staging_buffer.unmap();
349
350 Ok(result)
351 } else {
352 Err(threecrate_core::Error::Gpu("Failed to read GPU results".to_string()))
353 }
354 }
355
356 pub async fn compute_normals(&self, points: &[Point3f], k_neighbors: usize) -> Result<Vec<nalgebra::Vector3<f32>>> {
358 self.compute_normals_with_options(points, k_neighbors, true, None).await
359 }
360
361 pub fn compute_neighbors_simple(&self, points: &[[f32; 3]], k: usize) -> Vec<[u32; 64]> {
363 let mut neighbors = vec![[0u32; 64]; points.len()];
364 let k = k.min(64).min(points.len()); for (i, point) in points.iter().enumerate() {
367 let mut distances: Vec<(f32, usize)> = points
368 .iter()
369 .enumerate()
370 .filter(|(j, _)| *j != i)
371 .map(|(j, other)| {
372 let dx = point[0] - other[0];
373 let dy = point[1] - other[1];
374 let dz = point[2] - other[2];
375 (dx * dx + dy * dy + dz * dz, j)
376 })
377 .collect();
378
379 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
380
381 for (idx, &(_, neighbor_idx)) in distances.iter().take(k).enumerate() {
382 neighbors[i][idx] = neighbor_idx as u32;
383 }
384
385 for idx in k..64 {
387 neighbors[i][idx] = if k > 0 { neighbors[i][k - 1] } else { i as u32 };
388 }
389 }
390
391 neighbors
392 }
393
394 pub fn compute_neighbors_simple_points3(&self, points: &[[f32; 3]], k: usize) -> Vec<[u32; 64]> {
396 self.compute_neighbors_simple(points, k)
397 }
398}
399
400pub async fn gpu_estimate_normals(
402 gpu_context: &GpuContext,
403 cloud: &mut PointCloud<Point3f>,
404 k: usize,
405) -> Result<PointCloud<NormalPoint3f>> {
406 let normals = gpu_context.compute_normals(&cloud.points, k).await?;
407
408 let normal_points: Vec<NormalPoint3f> = cloud
409 .points
410 .iter()
411 .zip(normals.iter())
412 .map(|(point, normal)| NormalPoint3f {
413 position: *point,
414 normal: *normal,
415 })
416 .collect();
417
418 Ok(PointCloud::from_points(normal_points))
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::{Point3f, PointCloud};

    /// Creates a GPU context, or returns `None` (after logging) when no GPU is
    /// available so that GPU-dependent tests can bail out early.
    async fn try_create_gpu_context() -> Option<crate::GpuContext> {
        let context = crate::GpuContext::new().await;
        if context.is_err() {
            println!("⚠️ GPU not available, skipping GPU-dependent test");
        }
        context.ok()
    }

    /// Builds a `side` x `side` grid of points in the z = 0 plane, spaced 0.1 apart.
    fn planar_grid(side: usize) -> PointCloud<Point3f> {
        let mut grid = PointCloud::new();
        for row in 0..side {
            for col in 0..side {
                grid.push(Point3f::new(row as f32 * 0.1, col as f32 * 0.1, 0.0));
            }
        }
        grid
    }

    /// Normals of a flat grid should predominantly point along the Z axis.
    #[tokio::test]
    async fn test_gpu_normals_plane() {
        let Some(gpu) = try_create_gpu_context().await else {
            return;
        };

        let mut plane = planar_grid(15);
        let estimated = gpu_estimate_normals(&gpu, &mut plane, 8).await.unwrap();
        assert_eq!(estimated.len(), 225);

        let upright = estimated
            .iter()
            .filter(|p| p.normal.z.abs() > 0.8)
            .count();
        let pct = (upright as f32 / estimated.len() as f32) * 100.0;
        assert!(pct > 80.0, "Only {:.1}% normals in Z direction", pct);
    }

    /// GPU normals should agree with the CPU implementation up to sign.
    #[tokio::test]
    async fn test_gpu_normals_compare_cpu_plane() {
        use threecrate_algorithms::estimate_normals as cpu_estimate_normals;
        let Some(gpu) = try_create_gpu_context().await else {
            return;
        };

        let plane = planar_grid(10);
        let gpu_cloud = gpu_estimate_normals(&gpu, &mut plane.clone(), 8)
            .await
            .unwrap();
        let cpu_cloud = cpu_estimate_normals(&plane, 8).unwrap();

        // Compare absolute dot products: opposite orientation still counts as agreement.
        let agree = gpu_cloud
            .iter()
            .zip(cpu_cloud.iter())
            .filter(|(g, c)| g.normal.dot(&c.normal).abs() > 0.7)
            .count();
        let pct = (agree as f32 / gpu_cloud.len() as f32) * 100.0;
        assert!(pct > 70.0, "GPU-CPU normals agree only {:.1}%", pct);
    }
}