threecrate_gpu/
tsdf.rs

1use crate::device::GpuContext;
2use threecrate_core::{PointCloud, ColoredPoint3f, Error, Result};
3use nalgebra::{Matrix4, Point3};
4use bytemuck::{Pod, Zeroable};
5use wgpu::util::DeviceExt;
6
7/// TSDF voxel data for GPU processing
8#[repr(C)]
9#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
10#[repr(align(16))]  // Ensure 16-byte alignment for GPU
11pub struct TsdfVoxel {
12    pub tsdf_value: f32,
13    pub weight: f32,
14    pub color_r: u32,
15    pub color_g: u32,
16    pub color_b: u32,
17    pub _padding1: u32,
18    pub _padding2: u32,
19    pub _padding3: u32,
20}
21
22/// TSDF volume parameters
23#[derive(Debug, Clone)]
24pub struct TsdfVolume {
25    pub voxel_size: f32,
26    pub truncation_distance: f32,
27    pub resolution: [u32; 3], // [width, height, depth]
28    pub origin: Point3<f32>,
29}
30
31/// Represents a TSDF volume stored on the GPU.
32pub struct TsdfVolumeGpu {
33    pub volume: TsdfVolume,
34    pub voxel_buffer: wgpu::Buffer,
35}
36
37/// Camera intrinsic parameters
38#[repr(C)]
39#[derive(Copy, Clone, Pod, Zeroable)]
40#[repr(align(16))]  // Ensure 16-byte alignment for GPU
41pub struct CameraIntrinsics {
42    pub fx: f32,
43    pub fy: f32,
44    pub cx: f32,
45    pub cy: f32,
46    pub width: u32,
47    pub height: u32,
48    pub depth_scale: f32,
49    pub _padding: f32,
50}
51
52/// TSDF integration parameters
53#[repr(C)]
54#[derive(Copy, Clone, Pod, Zeroable)]
55#[repr(align(16))]
56pub struct TsdfParams {
57    pub voxel_size: f32,
58    pub truncation_distance: f32,
59    pub max_weight: f32,
60    pub iso_value: f32,
61    pub resolution: [u32; 3],
62    pub _padding2: u32,
63    pub origin: [f32; 3],
64    pub _padding3: f32,
65}
66
67#[repr(C)]
68#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
69#[repr(align(16))]  // Ensure 16-byte alignment for GPU
70pub struct GpuPoint3f {
71    pub x: f32,
72    pub y: f32,
73    pub z: f32,
74    pub r: u32,
75    pub g: u32,
76    pub b: u32,
77    pub _padding1: u32,
78    pub _padding2: u32,
79}
80
81impl GpuContext {
82    /// Integrate depth image into TSDF volume
83    pub async fn tsdf_integrate(
84        &self,
85        volume: &mut TsdfVolume,
86        depth_image: &[f32],
87        color_image: Option<&[u8]>, // RGB color data
88        camera_pose: &Matrix4<f32>,
89        intrinsics: &CameraIntrinsics,
90    ) -> Result<Vec<TsdfVoxel>> {
91        let total_voxels = (volume.resolution[0] * volume.resolution[1] * volume.resolution[2]) as usize;
92        
93        // Create buffers
94        let depth_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
95            label: Some("TSDF Depth Buffer"),
96            contents: bytemuck::cast_slice(depth_image),
97            usage: wgpu::BufferUsages::STORAGE,
98        });
99
100        let color_buffer = if let Some(color_data) = color_image {
101            // Convert RGB u8 data to packed u32 RGB values
102            let mut packed_colors = Vec::with_capacity(color_data.len() / 3);
103            for chunk in color_data.chunks_exact(3) {
104                let r = chunk[0] as u32;
105                let g = chunk[1] as u32;
106                let b = chunk[2] as u32;
107                let packed = (r << 16) | (g << 8) | b;
108                packed_colors.push(packed);
109            }
110            
111            self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
112                label: Some("TSDF Color Buffer"),
113                contents: bytemuck::cast_slice(&packed_colors),
114                usage: wgpu::BufferUsages::STORAGE,
115            })
116        } else {
117            // Create empty buffer if no color data
118            self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
119                label: Some("TSDF Empty Color Buffer"),
120                contents: bytemuck::cast_slice(&[0u32; 4]), // Small dummy buffer
121                usage: wgpu::BufferUsages::STORAGE,
122            })
123        };
124
125        // Initialize TSDF volume if needed
126        let initial_voxels = vec![TsdfVoxel {
127            tsdf_value: 1.0,
128            weight: 0.0,
129            color_r: 0,
130            color_g: 0,
131            color_b: 0,
132            _padding1: 0,
133            _padding2: 0,
134            _padding3: 0,
135        }; total_voxels];
136
137        let voxel_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
138            label: Some("TSDF Voxel Buffer"),
139            contents: bytemuck::cast_slice(&initial_voxels),
140            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
141        });
142
143        // Convert camera transform to world-to-camera matrix (inverse of camera pose)
144        let world_to_camera = camera_pose.try_inverse()
145            .ok_or_else(|| Error::Gpu("Failed to invert camera pose matrix".into()))?;
146        
147        let mut camera_transform = [[0.0f32; 4]; 4];
148        for i in 0..4 {
149            for j in 0..4 {
150                camera_transform[i][j] = world_to_camera[(i, j)];
151            }
152        }
153
154        let transform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
155            label: Some("TSDF Transform Buffer"),
156            contents: bytemuck::cast_slice(&[camera_transform]),
157            usage: wgpu::BufferUsages::UNIFORM,
158        });
159
160        let intrinsics_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
161            label: Some("TSDF Intrinsics Buffer"),
162            contents: bytemuck::bytes_of(intrinsics),
163            usage: wgpu::BufferUsages::UNIFORM,
164        });
165
166        let params = TsdfParams {
167            voxel_size: volume.voxel_size,
168            truncation_distance: volume.truncation_distance,
169            max_weight: 100.0,
170            iso_value: 0.0,
171            resolution: volume.resolution,
172            _padding2: 0,
173            origin: [volume.origin.x, volume.origin.y, volume.origin.z],
174            _padding3: 0.0,
175        };
176
177        let params_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
178            label: Some("TSDF Params Buffer"),
179            contents: bytemuck::bytes_of(&params),
180            usage: wgpu::BufferUsages::UNIFORM,
181        });
182
183        // Create compute pipeline
184        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
185            label: Some("TSDF Integration Shader"),
186            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/tsdf_integration.wgsl").into()),
187        });
188
189        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
190            label: Some("TSDF Integration Pipeline"),
191            layout: None,
192            module: &shader,
193            entry_point: "main",
194            compilation_options: wgpu::PipelineCompilationOptions::default(),
195        });
196
197        // Create bind group
198        let bind_group_entries = vec![
199            wgpu::BindGroupEntry {
200                binding: 0,
201                resource: voxel_buffer.as_entire_binding(),
202            },
203            wgpu::BindGroupEntry {
204                binding: 1,
205                resource: depth_buffer.as_entire_binding(),
206            },
207            wgpu::BindGroupEntry {
208                binding: 2,
209                resource: transform_buffer.as_entire_binding(),
210            },
211            wgpu::BindGroupEntry {
212                binding: 3,
213                resource: intrinsics_buffer.as_entire_binding(),
214            },
215            wgpu::BindGroupEntry {
216                binding: 4,
217                resource: params_buffer.as_entire_binding(),
218            },
219            wgpu::BindGroupEntry {
220                binding: 5,
221                resource: color_buffer.as_entire_binding(),
222            },
223        ];
224
225        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
226            label: Some("TSDF Integration Bind Group"),
227            layout: &pipeline.get_bind_group_layout(0),
228            entries: &bind_group_entries,
229        });
230
231        // Dispatch compute shader
232        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
233            label: Some("TSDF Integration Encoder"),
234        });
235
236        {
237            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
238                label: Some("TSDF Integration Pass"),
239                timestamp_writes: None,
240            });
241
242            compute_pass.set_pipeline(&pipeline);
243            compute_pass.set_bind_group(0, &bind_group, &[]);
244            
245            // Dispatch with 4x4x4 workgroups
246            let workgroup_size = 4;
247            let dispatch_x = (volume.resolution[0] + workgroup_size - 1) / workgroup_size;
248            let dispatch_y = (volume.resolution[1] + workgroup_size - 1) / workgroup_size;
249            let dispatch_z = (volume.resolution[2] + workgroup_size - 1) / workgroup_size;
250            
251            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, dispatch_z);
252        }
253
254        // Read back results
255        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
256            label: Some("TSDF Staging Buffer"),
257            size: (total_voxels * std::mem::size_of::<TsdfVoxel>()) as u64,
258            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
259            mapped_at_creation: false,
260        });
261
262        encoder.copy_buffer_to_buffer(
263            &voxel_buffer,
264            0,
265            &staging_buffer,
266            0,
267            staging_buffer.size(),
268        );
269
270        self.queue.submit(std::iter::once(encoder.finish()));
271
272        let buffer_slice = staging_buffer.slice(..);
273        let (sender, receiver) = flume::unbounded();
274        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
275            sender.send(result).unwrap();
276        });
277
278        self.device.poll(wgpu::Maintain::Wait);
279        receiver.recv_async().await.map_err(|_| Error::Gpu("Failed to receive mapping result".into()))?
280            .map_err(|e| Error::Gpu(format!("Buffer mapping failed: {:?}", e)))?;
281
282        let data = buffer_slice.get_mapped_range();
283        let result: Vec<TsdfVoxel> = bytemuck::cast_slice(&data).to_vec();
284        
285        drop(data);
286        staging_buffer.unmap();
287
288        Ok(result)
289    }
290
291    /// Extract point cloud from TSDF volume using marching cubes
292    pub async fn tsdf_extract_surface(
293        &self,
294        volume: &TsdfVolume,
295        voxels: &[TsdfVoxel],
296        iso_value: f32,
297    ) -> Result<PointCloud<ColoredPoint3f>> {
298        let total_voxels = (volume.resolution[0] * volume.resolution[1] * volume.resolution[2]) as usize;
299        let max_points = std::cmp::min(total_voxels, 1_000_000); // Limit to reasonable size
300        
301        // Create voxel buffer
302        let voxel_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
303            label: Some("TSDF Voxel Buffer"),
304            contents: bytemuck::cast_slice(voxels),
305            usage: wgpu::BufferUsages::STORAGE,
306        });
307
308        // Create output buffers
309        let points_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
310            label: Some("Surface Points Buffer"),
311            size: (max_points * std::mem::size_of::<GpuPoint3f>()) as u64,
312            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
313            mapped_at_creation: false,
314        });
315
316        let point_count_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
317            label: Some("Point Count Buffer"),
318            contents: bytemuck::bytes_of(&0u32),
319            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
320        });
321
322        let params = TsdfParams {
323            voxel_size: volume.voxel_size,
324            truncation_distance: volume.truncation_distance,
325            max_weight: 100.0,
326            iso_value,
327            resolution: volume.resolution,
328            _padding2: 0,
329            origin: [volume.origin.x, volume.origin.y, volume.origin.z],
330            _padding3: 0.0,
331        };
332
333        let params_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
334            label: Some("Surface Extraction Params Buffer"),
335            contents: bytemuck::bytes_of(&params),
336            usage: wgpu::BufferUsages::UNIFORM,
337        });
338
339        // Create compute pipeline
340        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
341            label: Some("Surface Extraction Shader"),
342            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/surface_extraction.wgsl").into()),
343        });
344
345        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
346            label: Some("Surface Extraction Pipeline"),
347            layout: None,
348            module: &shader,
349            entry_point: "main",
350            compilation_options: wgpu::PipelineCompilationOptions::default(),
351        });
352
353        // Create bind group
354        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
355            label: Some("Surface Extraction Bind Group"),
356            layout: &pipeline.get_bind_group_layout(0),
357            entries: &[
358                wgpu::BindGroupEntry {
359                    binding: 0,
360                    resource: voxel_buffer.as_entire_binding(),
361                },
362                wgpu::BindGroupEntry {
363                    binding: 1,
364                    resource: points_buffer.as_entire_binding(),
365                },
366                wgpu::BindGroupEntry {
367                    binding: 2,
368                    resource: params_buffer.as_entire_binding(),
369                },
370                wgpu::BindGroupEntry {
371                    binding: 3,
372                    resource: point_count_buffer.as_entire_binding(),
373                },
374            ],
375        });
376
377        // Create staging buffer for reading back results
378        let point_count_staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
379            label: Some("Point Count Staging Buffer"),
380            size: std::mem::size_of::<u32>() as u64,
381            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
382            mapped_at_creation: false,
383        });
384
385        // Dispatch compute shader
386        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
387            label: Some("Surface Extraction Encoder"),
388        });
389
390        {
391            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
392                label: Some("Surface Extraction Pass"),
393                timestamp_writes: None,
394            });
395
396            compute_pass.set_pipeline(&pipeline);
397            compute_pass.set_bind_group(0, &bind_group, &[]);
398            compute_pass.dispatch_workgroups(
399                (volume.resolution[0] + 3) / 4,
400                (volume.resolution[1] + 3) / 4,
401                (volume.resolution[2] + 3) / 4,
402            );
403        }
404
405        // Copy point count to staging buffer
406        encoder.copy_buffer_to_buffer(
407            &point_count_buffer,
408            0,
409            &point_count_staging_buffer,
410            0,
411            std::mem::size_of::<u32>() as u64,
412        );
413
414        self.queue.submit(Some(encoder.finish()));
415
416        // Read point count
417        let point_count_slice = point_count_staging_buffer.slice(..);
418        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
419        point_count_slice.map_async(wgpu::MapMode::Read, move |result| {
420            tx.send(result).unwrap();
421        });
422        self.device.poll(wgpu::Maintain::Wait);
423        rx.receive().await.unwrap()?;
424
425        let mapped_range = point_count_slice.get_mapped_range();
426        let point_count = bytemuck::cast_slice::<u8, u32>(mapped_range.as_ref())[0] as usize;
427        drop(mapped_range);
428        point_count_staging_buffer.unmap();
429
430        if point_count == 0 {
431            return Ok(PointCloud { points: Vec::new() });
432        }
433
434        // Create staging buffer for points
435        let points_staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
436            label: Some("Points Staging Buffer"),
437            size: (point_count * std::mem::size_of::<GpuPoint3f>()) as u64,
438            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
439            mapped_at_creation: false,
440        });
441
442        // Copy points to staging buffer
443        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
444            label: Some("Points Copy Encoder"),
445        });
446
447        encoder.copy_buffer_to_buffer(
448            &points_buffer,
449            0,
450            &points_staging_buffer,
451            0,
452            (point_count * std::mem::size_of::<GpuPoint3f>()) as u64,
453        );
454
455        self.queue.submit(Some(encoder.finish()));
456
457        // Read points
458        let points_slice = points_staging_buffer.slice(..);
459        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
460        points_slice.map_async(wgpu::MapMode::Read, move |result| {
461            tx.send(result).unwrap();
462        });
463        self.device.poll(wgpu::Maintain::Wait);
464        rx.receive().await.unwrap()?;
465
466        let mapped_range = points_slice.get_mapped_range();
467        let gpu_points = bytemuck::cast_slice::<u8, GpuPoint3f>(mapped_range.as_ref());
468        let mut points = Vec::with_capacity(point_count);
469
470        for gpu_point in gpu_points.iter().take(point_count) {
471            points.push(ColoredPoint3f {
472                position: Point3::new(gpu_point.x, gpu_point.y, gpu_point.z),
473                color: [gpu_point.r as u8, gpu_point.g as u8, gpu_point.b as u8],
474            });
475        }
476
477        drop(mapped_range);  // Explicitly drop the mapped range before unmapping
478        points_staging_buffer.unmap();
479
480        Ok(PointCloud { points })
481    }
482}
483
484impl TsdfVolumeGpu {
485    /// Creates a new TSDF volume on the GPU.
486    pub fn new(gpu: &GpuContext, volume_params: TsdfVolume) -> Self {
487        let total_voxels = (volume_params.resolution[0] * volume_params.resolution[1] * volume_params.resolution[2]) as usize;
488        
489        // Initialize voxels with default values
490        let initial_voxels = vec![TsdfVoxel {
491            tsdf_value: 1.0,
492            weight: 0.0,
493            color_r: 0,
494            color_g: 0,
495            color_b: 0,
496            _padding1: 0,
497            _padding2: 0,
498            _padding3: 0,
499        }; total_voxels];
500
501        let voxel_buffer = gpu.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
502            label: Some("TSDF Voxel Buffer"),
503            contents: bytemuck::cast_slice(&initial_voxels),
504            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
505        });
506
507        Self {
508            volume: volume_params,
509            voxel_buffer,
510        }
511    }
512
513    /// Integrates a depth image into the TSDF volume.
514    pub async fn integrate(
515        &self,
516        gpu: &GpuContext,
517        depth_image: &[f32],
518        color_image: Option<&[u8]>, // RGB color data
519        camera_pose: &Matrix4<f32>,
520        intrinsics: &CameraIntrinsics,
521    ) -> Result<()> {
522        // Create buffers for depth, color, transform, and parameters
523        let depth_buffer = gpu.create_buffer_init("TSDF Depth Buffer", depth_image, wgpu::BufferUsages::STORAGE);
524
525        let color_buffer = if let Some(data) = color_image {
526            gpu.create_buffer_init("TSDF Color Buffer", data, wgpu::BufferUsages::STORAGE)
527        } else {
528            // Create a dummy buffer if no color image is provided
529            gpu.create_buffer_init("TSDF Dummy Color Buffer", &[0u32; 4], wgpu::BufferUsages::STORAGE)
530        };
531
532        // Convert camera transform to world-to-camera matrix (inverse of camera pose)
533        let world_to_camera = camera_pose.try_inverse()
534            .ok_or_else(|| Error::Gpu("Failed to invert camera pose matrix".into()))?;
535        
536        let mut camera_transform = [[0.0f32; 4]; 4];
537        for i in 0..4 {
538            for j in 0..4 {
539                camera_transform[i][j] = world_to_camera[(i, j)];
540            }
541        }
542
543        let transform_buffer = gpu.create_buffer_init(
544            "TSDF Transform Buffer",
545            &[camera_transform],
546            wgpu::BufferUsages::UNIFORM,
547        );
548
549        let intrinsics_buffer = gpu.create_buffer_init(
550            "TSDF Intrinsics Buffer",
551            &[*intrinsics],
552            wgpu::BufferUsages::UNIFORM,
553        );
554
555        let params = TsdfParams {
556            voxel_size: self.volume.voxel_size,
557            truncation_distance: self.volume.truncation_distance,
558            max_weight: 100.0,
559            iso_value: 0.0,
560            resolution: self.volume.resolution,
561            _padding2: 0,
562            origin: [self.volume.origin.x, self.volume.origin.y, self.volume.origin.z],
563            _padding3: 0.0,
564        };
565        let params_buffer = gpu.create_buffer_init("TSDF Params Buffer", &[params], wgpu::BufferUsages::UNIFORM);
566
567        // Create compute pipeline
568        let shader = gpu.create_shader_module("TSDF Integration Shader", include_str!("shaders/tsdf_integration.wgsl"));
569        let pipeline = gpu.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
570            label: Some("TSDF Integration Pipeline"),
571            layout: None,
572            module: &shader,
573            entry_point: "main",
574            compilation_options: wgpu::PipelineCompilationOptions::default(),
575        });
576
577        // Create bind group
578        let bind_group_entries = vec![
579            wgpu::BindGroupEntry {
580                binding: 0,
581                resource: self.voxel_buffer.as_entire_binding(),
582            },
583            wgpu::BindGroupEntry {
584                binding: 1,
585                resource: depth_buffer.as_entire_binding(),
586            },
587            wgpu::BindGroupEntry {
588                binding: 2,
589                resource: transform_buffer.as_entire_binding(),
590            },
591            wgpu::BindGroupEntry {
592                binding: 3,
593                resource: intrinsics_buffer.as_entire_binding(),
594            },
595            wgpu::BindGroupEntry {
596                binding: 4,
597                resource: params_buffer.as_entire_binding(),
598            },
599            wgpu::BindGroupEntry {
600                binding: 5,
601                resource: color_buffer.as_entire_binding(),
602            },
603        ];
604
605        let bind_group = gpu.create_bind_group("TSDF Integration Bind Group", &pipeline.get_bind_group_layout(0), &bind_group_entries);
606
607        // Dispatch compute shader
608        let mut encoder = gpu.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
609            label: Some("TSDF Integration Encoder"),
610        });
611
612        {
613            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
614                label: Some("TSDF Integration Pass"),
615                timestamp_writes: None,
616            });
617
618            compute_pass.set_pipeline(&pipeline);
619            compute_pass.set_bind_group(0, &bind_group, &[]);
620            
621            // Dispatch with 4x4x4 workgroups
622            let workgroup_size = 4;
623            let dispatch_x = (self.volume.resolution[0] + workgroup_size - 1) / workgroup_size;
624            let dispatch_y = (self.volume.resolution[1] + workgroup_size - 1) / workgroup_size;
625            let dispatch_z = (self.volume.resolution[2] + workgroup_size - 1) / workgroup_size;
626            
627            println!("Dispatching compute shader with {} x {} x {} workgroups", dispatch_x, dispatch_y, dispatch_z);
628            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, dispatch_z);
629        }
630
631        gpu.queue.submit(std::iter::once(encoder.finish()));
632        Ok(())
633    }
634
635    /// Downloads the TSDF voxel data from the GPU.
636    pub async fn download_voxels(&self, gpu: &GpuContext) -> Result<Vec<TsdfVoxel>> {
637        let total_voxels = (self.volume.resolution[0] * self.volume.resolution[1] * self.volume.resolution[2]) as usize;
638        let buffer_size = (total_voxels * std::mem::size_of::<TsdfVoxel>()) as u64;
639
640        let staging_buffer = gpu.device.create_buffer(&wgpu::BufferDescriptor {
641            label: Some("TSDF Staging Buffer"),
642            size: buffer_size,
643            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
644            mapped_at_creation: false,
645        });
646
647        let mut encoder = gpu.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
648            label: Some("TSDF Download Encoder"),
649        });
650
651        encoder.copy_buffer_to_buffer(
652            &self.voxel_buffer,
653            0,
654            &staging_buffer,
655            0,
656            buffer_size,
657        );
658
659        gpu.queue.submit(std::iter::once(encoder.finish()));
660
661        let buffer_slice = staging_buffer.slice(..);
662        let (sender, receiver) = flume::unbounded();
663        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
664            sender.send(result).unwrap();
665        });
666
667        gpu.device.poll(wgpu::Maintain::Wait);
668        receiver.recv_async().await.map_err(|_| Error::Gpu("Failed to receive mapping result".into()))??;
669
670        let data = buffer_slice.get_mapped_range();
671        let result: Vec<TsdfVoxel> = bytemuck::cast_slice(&data).to_vec();
672        
673        drop(data);
674        staging_buffer.unmap();
675
676        Ok(result)
677    }
678
679    /// Extract point cloud from TSDF volume using marching cubes
680    pub async fn extract_surface(&self, gpu: &GpuContext, iso_value: f32) -> Result<PointCloud<ColoredPoint3f>> {
681        let voxels = self.download_voxels(gpu).await?;
682        gpu.tsdf_extract_surface(&self.volume, &voxels, iso_value).await
683    }
684}
685
686/// Create a new TSDF volume with specified parameters
687pub fn create_tsdf_volume(
688    voxel_size: f32,
689    truncation_distance: f32,
690    resolution: [u32; 3],
691    origin: Point3<f32>,
692) -> TsdfVolume {
693    TsdfVolume {
694        voxel_size,
695        truncation_distance,
696        resolution,
697        origin,
698    }
699}
700
701/// GPU-accelerated TSDF integration from depth image
702pub async fn gpu_tsdf_integrate(
703    gpu_context: &GpuContext,
704    volume: &mut TsdfVolume,
705    depth_image: &[f32],
706    color_image: Option<&[u8]>,
707    camera_pose: &Matrix4<f32>,
708    intrinsics: &CameraIntrinsics,
709) -> Result<Vec<TsdfVoxel>> {
710    gpu_context.tsdf_integrate(volume, depth_image, color_image, camera_pose, intrinsics).await
711}
712
713/// GPU-accelerated surface extraction from TSDF volume
714pub async fn gpu_tsdf_extract_surface(
715    gpu_context: &GpuContext,
716    volume: &TsdfVolume,
717    voxels: &[TsdfVoxel],
718    iso_value: f32,
719) -> Result<PointCloud<ColoredPoint3f>> {
720    gpu_context.tsdf_extract_surface(volume, voxels, iso_value).await
721}
722
723#[cfg(test)]
724mod tests {
725    use super::*;
726    use crate::device::GpuContext;
727    use nalgebra::{Matrix4, Point3};
728    use approx::assert_relative_eq;
729
730    /// Try to create a GPU context, return None if not available
731    async fn try_create_gpu_context() -> Option<GpuContext> {
732        match GpuContext::new().await {
733            Ok(gpu) => Some(gpu),
734            Err(_) => {
735                println!("⚠️  GPU not available, skipping GPU-dependent test");
736                None
737            }
738        }
739    }
740
741    /// Create simple depth image for basic testing
742    fn create_simple_depth_image(width: u32, height: u32, depth: f32) -> Vec<f32> {
743        vec![depth; (width * height) as usize]
744    }
745
746    /// Create a basic camera setup for testing
747    fn create_test_camera() -> CameraIntrinsics {
748        CameraIntrinsics {
749            fx: 525.0,
750            fy: 525.0,
751            cx: 319.5,
752            cy: 239.5,
753            width: 640,
754            height: 480,
755            depth_scale: 1.0,
756            _padding: 0.0,
757        }
758    }
759
760    /// Create identity camera pose
761    fn create_identity_pose() -> Matrix4<f32> {
762        Matrix4::new(
763            1.0, 0.0, 0.0, 0.0,
764            0.0, 1.0, 0.0, 0.0,
765            0.0, 0.0, 1.0, 0.0,
766            0.0, 0.0, 0.0, 1.0,
767        )
768    }
769
770    #[test]
771    fn test_tsdf_basic_integration() {
772        pollster::block_on(async {
773            let Some(gpu) = try_create_gpu_context().await else {
774                return;
775            };
776
777            // Create a simple TSDF volume
778            let voxel_size = 0.02; // 2cm voxels for faster processing
779            let truncation_distance = 0.1; 
780            let resolution = [32, 32, 32]; // Smaller resolution for speed
781            let origin = Point3::new(-0.32, -0.32, 0.0);
782
783            let volume_params = create_tsdf_volume(
784                voxel_size,
785                truncation_distance,
786                resolution,
787                origin,
788            );
789            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
790
791            // Create simple depth image with constant depth
792            let intrinsics = create_test_camera();
793            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.5);
794            let camera_pose = create_identity_pose();
795
796            // Test integration
797            let result = tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics).await;
798            assert!(result.is_ok(), "TSDF integration should succeed");
799
800            // Test voxel download
801            let voxels = tsdf_volume_gpu.download_voxels(&gpu).await.unwrap();
802            assert_eq!(voxels.len(), (32 * 32 * 32) as usize, "Should have correct number of voxels");
803
804            // Check that some voxels have been updated
805            let updated_voxels = voxels.iter().filter(|v| v.weight > 0.0).count();
806            assert!(updated_voxels > 0, "Some voxels should have been updated");
807
808            println!("✓ Basic integration test passed: {} voxels updated", updated_voxels);
809        });
810    }
811
812    #[test]
813    fn test_tsdf_surface_extraction() {
814        pollster::block_on(async {
815            let Some(gpu) = try_create_gpu_context().await else {
816                return;
817            };
818
819            // Create TSDF volume 
820            let voxel_size = 0.02;
821            let truncation_distance = 0.1;
822            let resolution = [32, 32, 32];
823            let origin = Point3::new(-0.32, -0.32, 0.0);
824
825            let volume_params = create_tsdf_volume(
826                voxel_size,
827                truncation_distance,
828                resolution,
829                origin,
830            );
831            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
832
833            // Integrate a simple depth image
834            let intrinsics = create_test_camera();
835            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.3);
836            let camera_pose = create_identity_pose();
837
838            tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics)
839                .await
840                .unwrap();
841
842            // Extract surface
843            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
844            
845            // Should extract some points
846            assert!(!point_cloud.points.is_empty(), "Should extract surface points");
847            
848            // Points should be in reasonable Z range around the depth value
849            let avg_z = point_cloud.points.iter()
850                .map(|p| p.position.z)
851                .sum::<f32>() / point_cloud.points.len() as f32;
852            
853            assert!(avg_z > 0.2 && avg_z < 0.4, "Average Z should be near depth value of 0.3");
854            
855            println!("✓ Surface extraction test passed: {} points extracted, avg Z: {:.3}", 
856                     point_cloud.points.len(), avg_z);
857        });
858    }
859
860    #[test]
861    fn test_tsdf_multiple_integrations() {
862        pollster::block_on(async {
863            let Some(gpu) = try_create_gpu_context().await else {
864                return;
865            };
866
867            // Create TSDF volume
868            let voxel_size = 0.02;
869            let truncation_distance = 0.1;
870            let resolution = [32, 32, 32];
871            let origin = Point3::new(-0.32, -0.32, 0.0);
872
873            let volume_params = create_tsdf_volume(
874                voxel_size,
875                truncation_distance,
876                resolution,
877                origin,
878            );
879            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
880
881            let intrinsics = create_test_camera();
882            let camera_pose = create_identity_pose();
883
884            // Integrate multiple depth images
885            let depths = [0.25, 0.3, 0.35];
886            for &depth in &depths {
887                let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, depth);
888                tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics)
889                    .await
890                    .unwrap();
891            }
892
893            // Check voxel weights have increased
894            let voxels = tsdf_volume_gpu.download_voxels(&gpu).await.unwrap();
895            let max_weight = voxels.iter().map(|v| v.weight).fold(0.0, f32::max);
896            assert!(max_weight > 1.0, "Multiple integrations should increase voxel weights");
897
898            // Extract surface
899            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
900            assert!(!point_cloud.points.is_empty(), "Should extract surface after multiple integrations");
901
902            println!("✓ Multiple integration test passed: max weight {:.1}, {} points extracted", 
903                     max_weight, point_cloud.points.len());
904        });
905    }
906
907    #[test]
908    fn test_tsdf_coordinate_system() {
909        pollster::block_on(async {
910            let Some(_gpu) = try_create_gpu_context().await else {
911                return;
912            };
913
914            // Test basic coordinate system consistency
915            let voxel_size = 0.02;
916            let resolution = [32, 32, 32];
917            let origin = Point3::new(-0.32, -0.32, 0.0);
918
919            // Check volume bounds
920            let max_coord = Point3::new(
921                origin.x + (resolution[0] as f32) * voxel_size,
922                origin.y + (resolution[1] as f32) * voxel_size,
923                origin.z + (resolution[2] as f32) * voxel_size,
924            );
925
926            assert_relative_eq!(max_coord.x, 0.32, epsilon = 0.01);
927            assert_relative_eq!(max_coord.y, 0.32, epsilon = 0.01);
928            assert_relative_eq!(max_coord.z, 0.64, epsilon = 0.01);
929
930            // Test camera transforms
931            let camera_pose = create_identity_pose();
932            let world_to_camera = camera_pose.try_inverse().unwrap();
933            
934            let test_point = Point3::new(0.1, 0.2, 0.3);
935            let camera_point = world_to_camera.transform_point(&test_point);
936            
937            // For identity transform, should be the same
938            assert_relative_eq!(test_point.x, camera_point.x, epsilon = 0.001);
939            assert_relative_eq!(test_point.y, camera_point.y, epsilon = 0.001);
940            assert_relative_eq!(test_point.z, camera_point.z, epsilon = 0.001);
941
942            println!("✓ Coordinate system test passed");
943        });
944    }
945
946    #[test]
947    fn test_tsdf_color_integration() {
948        pollster::block_on(async {
949            let Some(gpu) = try_create_gpu_context().await else {
950                return;
951            };
952
953            // Create TSDF volume
954            let voxel_size = 0.02;
955            let truncation_distance = 0.1;
956            let resolution = [32, 32, 32];
957            let origin = Point3::new(-0.32, -0.32, 0.0);
958
959            let volume_params = create_tsdf_volume(
960                voxel_size,
961                truncation_distance,
962                resolution,
963                origin,
964            );
965            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
966
967            // Create depth and color images
968            let intrinsics = create_test_camera();
969            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.3);
970            
971            // Simple red color image
972            let pixel_count = (intrinsics.width * intrinsics.height) as usize;
973            let mut color_image = Vec::with_capacity(pixel_count * 3);
974            for _ in 0..pixel_count {
975                color_image.extend_from_slice(&[255u8, 0u8, 0u8]); // RGB: red
976            }
977            
978            let camera_pose = create_identity_pose();
979
980            // Integrate with color
981            tsdf_volume_gpu.integrate(&gpu, &depth_image, Some(&color_image), &camera_pose, &intrinsics)
982                .await
983                .unwrap();
984
985            // Extract surface
986            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
987            
988            assert!(!point_cloud.points.is_empty(), "Should extract colored surface points");
989            
990            // Check that some points have red color
991            let red_points = point_cloud.points.iter()
992                .filter(|p| p.color[0] > 200)
993                .count();
994            
995            assert!(red_points > 0, "Some points should have red color");
996
997            println!("✓ Color integration test passed: {} red points", red_points);
998        });
999    }
1000}