threecrate_gpu/
tsdf.rs

1use crate::device::GpuContext;
2use threecrate_core::{PointCloud, ColoredPoint3f, Error, Result};
3use nalgebra::{Matrix4, Point3};
4use bytemuck::{Pod, Zeroable};
5use wgpu::util::DeviceExt;
6
7/// TSDF voxel data for GPU processing
8#[repr(C)]
9#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
10#[repr(align(16))]  // Ensure 16-byte alignment for GPU
11pub struct TsdfVoxel {
12    pub tsdf_value: f32,
13    pub weight: f32,
14    pub color_r: u32,
15    pub color_g: u32,
16    pub color_b: u32,
17    pub _padding1: u32,
18    pub _padding2: u32,
19    pub _padding3: u32,
20}
21
22/// TSDF volume parameters
23#[derive(Debug, Clone)]
24pub struct TsdfVolume {
25    pub voxel_size: f32,
26    pub truncation_distance: f32,
27    pub resolution: [u32; 3], // [width, height, depth]
28    pub origin: Point3<f32>,
29}
30
31/// Represents a TSDF volume stored on the GPU.
32pub struct TsdfVolumeGpu {
33    pub volume: TsdfVolume,
34    pub voxel_buffer: wgpu::Buffer,
35}
36
37/// Camera intrinsic parameters
38#[repr(C)]
39#[derive(Copy, Clone, Pod, Zeroable)]
40#[repr(align(16))]  // Ensure 16-byte alignment for GPU
41pub struct CameraIntrinsics {
42    pub fx: f32,
43    pub fy: f32,
44    pub cx: f32,
45    pub cy: f32,
46    pub width: u32,
47    pub height: u32,
48    pub depth_scale: f32,
49    pub _padding: f32,
50}
51
52/// TSDF integration parameters
53#[repr(C)]
54#[derive(Copy, Clone, Pod, Zeroable)]
55#[repr(align(16))]
56pub struct TsdfParams {
57    pub voxel_size: f32,
58    pub truncation_distance: f32,
59    pub max_weight: f32,
60    pub iso_value: f32,
61    pub resolution: [u32; 3],
62    pub _padding2: u32,
63    pub origin: [f32; 3],
64    pub _padding3: f32,
65}
66
67#[repr(C)]
68#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
69#[repr(align(16))]  // Ensure 16-byte alignment for GPU
70pub struct GpuPoint3f {
71    pub x: f32,
72    pub y: f32,
73    pub z: f32,
74    pub r: u32,
75    pub g: u32,
76    pub b: u32,
77    pub _padding1: u32,
78    pub _padding2: u32,
79}
80
81impl GpuContext {
82    /// Integrate depth image into TSDF volume
83    pub async fn tsdf_integrate(
84        &self,
85        volume: &mut TsdfVolume,
86        depth_image: &[f32],
87        color_image: Option<&[u8]>, // RGB color data
88        camera_pose: &Matrix4<f32>,
89        intrinsics: &CameraIntrinsics,
90    ) -> Result<Vec<TsdfVoxel>> {
91        let total_voxels = (volume.resolution[0] * volume.resolution[1] * volume.resolution[2]) as usize;
92        
93        // Create buffers
94        let depth_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
95            label: Some("TSDF Depth Buffer"),
96            contents: bytemuck::cast_slice(depth_image),
97            usage: wgpu::BufferUsages::STORAGE,
98        });
99
100        let color_buffer = if let Some(color_data) = color_image {
101            // Convert RGB u8 data to packed u32 RGB values
102            let mut packed_colors = Vec::with_capacity(color_data.len() / 3);
103            for chunk in color_data.chunks_exact(3) {
104                let r = chunk[0] as u32;
105                let g = chunk[1] as u32;
106                let b = chunk[2] as u32;
107                let packed = (r << 16) | (g << 8) | b;
108                packed_colors.push(packed);
109            }
110            
111            self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
112                label: Some("TSDF Color Buffer"),
113                contents: bytemuck::cast_slice(&packed_colors),
114                usage: wgpu::BufferUsages::STORAGE,
115            })
116        } else {
117            // Create empty buffer if no color data
118            self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
119                label: Some("TSDF Empty Color Buffer"),
120                contents: bytemuck::cast_slice(&[0u32; 4]), // Small dummy buffer
121                usage: wgpu::BufferUsages::STORAGE,
122            })
123        };
124
125        // Initialize TSDF volume if needed
126        let initial_voxels = vec![TsdfVoxel {
127            tsdf_value: 1.0,
128            weight: 0.0,
129            color_r: 0,
130            color_g: 0,
131            color_b: 0,
132            _padding1: 0,
133            _padding2: 0,
134            _padding3: 0,
135        }; total_voxels];
136
137        let voxel_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
138            label: Some("TSDF Voxel Buffer"),
139            contents: bytemuck::cast_slice(&initial_voxels),
140            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
141        });
142
143        // Convert camera transform to world-to-camera matrix (inverse of camera pose)
144        let world_to_camera = camera_pose.try_inverse()
145            .ok_or_else(|| Error::Gpu("Failed to invert camera pose matrix".into()))?;
146        
147        let mut camera_transform = [[0.0f32; 4]; 4];
148        for i in 0..4 {
149            for j in 0..4 {
150                camera_transform[i][j] = world_to_camera[(i, j)];
151            }
152        }
153
154        let transform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
155            label: Some("TSDF Transform Buffer"),
156            contents: bytemuck::cast_slice(&[camera_transform]),
157            usage: wgpu::BufferUsages::UNIFORM,
158        });
159
160        let intrinsics_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
161            label: Some("TSDF Intrinsics Buffer"),
162            contents: bytemuck::bytes_of(intrinsics),
163            usage: wgpu::BufferUsages::UNIFORM,
164        });
165
166        let params = TsdfParams {
167            voxel_size: volume.voxel_size,
168            truncation_distance: volume.truncation_distance,
169            max_weight: 100.0,
170            iso_value: 0.0,
171            resolution: volume.resolution,
172            _padding2: 0,
173            origin: [volume.origin.x, volume.origin.y, volume.origin.z],
174            _padding3: 0.0,
175        };
176
177        let params_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
178            label: Some("TSDF Params Buffer"),
179            contents: bytemuck::bytes_of(&params),
180            usage: wgpu::BufferUsages::UNIFORM,
181        });
182
183        // Create compute pipeline
184        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
185            label: Some("TSDF Integration Shader"),
186            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/tsdf_integration.wgsl").into()),
187        });
188
189        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
190            label: Some("TSDF Integration Pipeline"),
191            layout: None,
192            module: &shader,
193            entry_point: Some("main"),
194            compilation_options: wgpu::PipelineCompilationOptions::default(),
195            cache: None,
196        });
197
198        // Create bind group
199        let bind_group_entries = vec![
200            wgpu::BindGroupEntry {
201                binding: 0,
202                resource: voxel_buffer.as_entire_binding(),
203            },
204            wgpu::BindGroupEntry {
205                binding: 1,
206                resource: depth_buffer.as_entire_binding(),
207            },
208            wgpu::BindGroupEntry {
209                binding: 2,
210                resource: transform_buffer.as_entire_binding(),
211            },
212            wgpu::BindGroupEntry {
213                binding: 3,
214                resource: intrinsics_buffer.as_entire_binding(),
215            },
216            wgpu::BindGroupEntry {
217                binding: 4,
218                resource: params_buffer.as_entire_binding(),
219            },
220            wgpu::BindGroupEntry {
221                binding: 5,
222                resource: color_buffer.as_entire_binding(),
223            },
224        ];
225
226        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
227            label: Some("TSDF Integration Bind Group"),
228            layout: &pipeline.get_bind_group_layout(0),
229            entries: &bind_group_entries,
230        });
231
232        // Dispatch compute shader
233        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
234            label: Some("TSDF Integration Encoder"),
235        });
236
237        {
238            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
239                label: Some("TSDF Integration Pass"),
240                timestamp_writes: None,
241            });
242
243            compute_pass.set_pipeline(&pipeline);
244            compute_pass.set_bind_group(0, &bind_group, &[]);
245            
246            // Dispatch with 4x4x4 workgroups
247            let workgroup_size = 4;
248            let dispatch_x = (volume.resolution[0] + workgroup_size - 1) / workgroup_size;
249            let dispatch_y = (volume.resolution[1] + workgroup_size - 1) / workgroup_size;
250            let dispatch_z = (volume.resolution[2] + workgroup_size - 1) / workgroup_size;
251            
252            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, dispatch_z);
253        }
254
255        // Read back results
256        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
257            label: Some("TSDF Staging Buffer"),
258            size: (total_voxels * std::mem::size_of::<TsdfVoxel>()) as u64,
259            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
260            mapped_at_creation: false,
261        });
262
263        encoder.copy_buffer_to_buffer(
264            &voxel_buffer,
265            0,
266            &staging_buffer,
267            0,
268            staging_buffer.size(),
269        );
270
271        self.queue.submit(std::iter::once(encoder.finish()));
272
273        let buffer_slice = staging_buffer.slice(..);
274        let (sender, receiver) = flume::unbounded();
275        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
276            sender.send(result).unwrap();
277        });
278
279        let _ = self.device.poll(wgpu::PollType::Wait);
280        receiver.recv_async().await.map_err(|_| Error::Gpu("Failed to receive mapping result".into()))?
281            .map_err(|e| Error::Gpu(format!("Buffer mapping failed: {:?}", e)))?;
282
283        let data = buffer_slice.get_mapped_range();
284        let result: Vec<TsdfVoxel> = bytemuck::cast_slice(&data).to_vec();
285        
286        drop(data);
287        staging_buffer.unmap();
288
289        Ok(result)
290    }
291
292    /// Extract point cloud from TSDF volume using marching cubes
293    pub async fn tsdf_extract_surface(
294        &self,
295        volume: &TsdfVolume,
296        voxels: &[TsdfVoxel],
297        iso_value: f32,
298    ) -> Result<PointCloud<ColoredPoint3f>> {
299        let total_voxels = (volume.resolution[0] * volume.resolution[1] * volume.resolution[2]) as usize;
300        let max_points = std::cmp::min(total_voxels, 1_000_000); // Limit to reasonable size
301        
302        // Create voxel buffer
303        let voxel_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
304            label: Some("TSDF Voxel Buffer"),
305            contents: bytemuck::cast_slice(voxels),
306            usage: wgpu::BufferUsages::STORAGE,
307        });
308
309        // Create output buffers
310        let points_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
311            label: Some("Surface Points Buffer"),
312            size: (max_points * std::mem::size_of::<GpuPoint3f>()) as u64,
313            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
314            mapped_at_creation: false,
315        });
316
317        let point_count_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
318            label: Some("Point Count Buffer"),
319            contents: bytemuck::bytes_of(&0u32),
320            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
321        });
322
323        let params = TsdfParams {
324            voxel_size: volume.voxel_size,
325            truncation_distance: volume.truncation_distance,
326            max_weight: 100.0,
327            iso_value,
328            resolution: volume.resolution,
329            _padding2: 0,
330            origin: [volume.origin.x, volume.origin.y, volume.origin.z],
331            _padding3: 0.0,
332        };
333
334        let params_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
335            label: Some("Surface Extraction Params Buffer"),
336            contents: bytemuck::bytes_of(&params),
337            usage: wgpu::BufferUsages::UNIFORM,
338        });
339
340        // Create compute pipeline
341        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
342            label: Some("Surface Extraction Shader"),
343            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/surface_extraction.wgsl").into()),
344        });
345
346        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
347            label: Some("Surface Extraction Pipeline"),
348            layout: None,
349            module: &shader,
350            entry_point: Some("main"),
351            compilation_options: wgpu::PipelineCompilationOptions::default(),
352            cache: None,
353        });
354
355        // Create bind group
356        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
357            label: Some("Surface Extraction Bind Group"),
358            layout: &pipeline.get_bind_group_layout(0),
359            entries: &[
360                wgpu::BindGroupEntry {
361                    binding: 0,
362                    resource: voxel_buffer.as_entire_binding(),
363                },
364                wgpu::BindGroupEntry {
365                    binding: 1,
366                    resource: points_buffer.as_entire_binding(),
367                },
368                wgpu::BindGroupEntry {
369                    binding: 2,
370                    resource: params_buffer.as_entire_binding(),
371                },
372                wgpu::BindGroupEntry {
373                    binding: 3,
374                    resource: point_count_buffer.as_entire_binding(),
375                },
376            ],
377        });
378
379        // Create staging buffer for reading back results
380        let point_count_staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
381            label: Some("Point Count Staging Buffer"),
382            size: std::mem::size_of::<u32>() as u64,
383            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
384            mapped_at_creation: false,
385        });
386
387        // Dispatch compute shader
388        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
389            label: Some("Surface Extraction Encoder"),
390        });
391
392        {
393            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
394                label: Some("Surface Extraction Pass"),
395                timestamp_writes: None,
396            });
397
398            compute_pass.set_pipeline(&pipeline);
399            compute_pass.set_bind_group(0, &bind_group, &[]);
400            compute_pass.dispatch_workgroups(
401                (volume.resolution[0] + 3) / 4,
402                (volume.resolution[1] + 3) / 4,
403                (volume.resolution[2] + 3) / 4,
404            );
405        }
406
407        // Copy point count to staging buffer
408        encoder.copy_buffer_to_buffer(
409            &point_count_buffer,
410            0,
411            &point_count_staging_buffer,
412            0,
413            std::mem::size_of::<u32>() as u64,
414        );
415
416        self.queue.submit(Some(encoder.finish()));
417
418        // Read point count
419        let point_count_slice = point_count_staging_buffer.slice(..);
420        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
421        point_count_slice.map_async(wgpu::MapMode::Read, move |result| {
422            tx.send(result).unwrap();
423        });
424        let _ = self.device.poll(wgpu::PollType::Wait);
425        rx.receive().await.unwrap()?;
426
427        let mapped_range = point_count_slice.get_mapped_range();
428        let point_count = bytemuck::cast_slice::<u8, u32>(mapped_range.as_ref())[0] as usize;
429        drop(mapped_range);
430        point_count_staging_buffer.unmap();
431
432        if point_count == 0 {
433            return Ok(PointCloud { points: Vec::new() });
434        }
435
436        // Create staging buffer for points
437        let points_staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
438            label: Some("Points Staging Buffer"),
439            size: (point_count * std::mem::size_of::<GpuPoint3f>()) as u64,
440            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
441            mapped_at_creation: false,
442        });
443
444        // Copy points to staging buffer
445        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
446            label: Some("Points Copy Encoder"),
447        });
448
449        encoder.copy_buffer_to_buffer(
450            &points_buffer,
451            0,
452            &points_staging_buffer,
453            0,
454            (point_count * std::mem::size_of::<GpuPoint3f>()) as u64,
455        );
456
457        self.queue.submit(Some(encoder.finish()));
458
459        // Read points
460        let points_slice = points_staging_buffer.slice(..);
461        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
462        points_slice.map_async(wgpu::MapMode::Read, move |result| {
463            tx.send(result).unwrap();
464        });
465        let _ = self.device.poll(wgpu::PollType::Wait);
466        rx.receive().await.unwrap()?;
467
468        let mapped_range = points_slice.get_mapped_range();
469        let gpu_points = bytemuck::cast_slice::<u8, GpuPoint3f>(mapped_range.as_ref());
470        let mut points = Vec::with_capacity(point_count);
471
472        for gpu_point in gpu_points.iter().take(point_count) {
473            points.push(ColoredPoint3f {
474                position: Point3::new(gpu_point.x, gpu_point.y, gpu_point.z),
475                color: [gpu_point.r as u8, gpu_point.g as u8, gpu_point.b as u8],
476            });
477        }
478
479        drop(mapped_range);  // Explicitly drop the mapped range before unmapping
480        points_staging_buffer.unmap();
481
482        Ok(PointCloud { points })
483    }
484}
485
486impl TsdfVolumeGpu {
487    /// Creates a new TSDF volume on the GPU.
488    pub fn new(gpu: &GpuContext, volume_params: TsdfVolume) -> Self {
489        let total_voxels = (volume_params.resolution[0] * volume_params.resolution[1] * volume_params.resolution[2]) as usize;
490        
491        // Initialize voxels with default values
492        let initial_voxels = vec![TsdfVoxel {
493            tsdf_value: 1.0,
494            weight: 0.0,
495            color_r: 0,
496            color_g: 0,
497            color_b: 0,
498            _padding1: 0,
499            _padding2: 0,
500            _padding3: 0,
501        }; total_voxels];
502
503        let voxel_buffer = gpu.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
504            label: Some("TSDF Voxel Buffer"),
505            contents: bytemuck::cast_slice(&initial_voxels),
506            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
507        });
508
509        Self {
510            volume: volume_params,
511            voxel_buffer,
512        }
513    }
514
515    /// Integrates a depth image into the TSDF volume.
516    pub async fn integrate(
517        &self,
518        gpu: &GpuContext,
519        depth_image: &[f32],
520        color_image: Option<&[u8]>, // RGB color data
521        camera_pose: &Matrix4<f32>,
522        intrinsics: &CameraIntrinsics,
523    ) -> Result<()> {
524        // Create buffers for depth, color, transform, and parameters
525        let depth_buffer = gpu.create_buffer_init("TSDF Depth Buffer", depth_image, wgpu::BufferUsages::STORAGE);
526
527        let color_buffer = if let Some(data) = color_image {
528            gpu.create_buffer_init("TSDF Color Buffer", data, wgpu::BufferUsages::STORAGE)
529        } else {
530            // Create a dummy buffer if no color image is provided
531            gpu.create_buffer_init("TSDF Dummy Color Buffer", &[0u32; 4], wgpu::BufferUsages::STORAGE)
532        };
533
534        // Convert camera transform to world-to-camera matrix (inverse of camera pose)
535        let world_to_camera = camera_pose.try_inverse()
536            .ok_or_else(|| Error::Gpu("Failed to invert camera pose matrix".into()))?;
537        
538        let mut camera_transform = [[0.0f32; 4]; 4];
539        for i in 0..4 {
540            for j in 0..4 {
541                camera_transform[i][j] = world_to_camera[(i, j)];
542            }
543        }
544
545        let transform_buffer = gpu.create_buffer_init(
546            "TSDF Transform Buffer",
547            &[camera_transform],
548            wgpu::BufferUsages::UNIFORM,
549        );
550
551        let intrinsics_buffer = gpu.create_buffer_init(
552            "TSDF Intrinsics Buffer",
553            &[*intrinsics],
554            wgpu::BufferUsages::UNIFORM,
555        );
556
557        let params = TsdfParams {
558            voxel_size: self.volume.voxel_size,
559            truncation_distance: self.volume.truncation_distance,
560            max_weight: 100.0,
561            iso_value: 0.0,
562            resolution: self.volume.resolution,
563            _padding2: 0,
564            origin: [self.volume.origin.x, self.volume.origin.y, self.volume.origin.z],
565            _padding3: 0.0,
566        };
567        let params_buffer = gpu.create_buffer_init("TSDF Params Buffer", &[params], wgpu::BufferUsages::UNIFORM);
568
569        // Create compute pipeline
570        let shader = gpu.create_shader_module("TSDF Integration Shader", include_str!("shaders/tsdf_integration.wgsl"));
571        let pipeline = gpu.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
572            label: Some("TSDF Integration Pipeline"),
573            layout: None,
574            module: &shader,
575            entry_point: Some("main"),
576            compilation_options: wgpu::PipelineCompilationOptions::default(),
577            cache: None,
578        });
579
580        // Create bind group
581        let bind_group_entries = vec![
582            wgpu::BindGroupEntry {
583                binding: 0,
584                resource: self.voxel_buffer.as_entire_binding(),
585            },
586            wgpu::BindGroupEntry {
587                binding: 1,
588                resource: depth_buffer.as_entire_binding(),
589            },
590            wgpu::BindGroupEntry {
591                binding: 2,
592                resource: transform_buffer.as_entire_binding(),
593            },
594            wgpu::BindGroupEntry {
595                binding: 3,
596                resource: intrinsics_buffer.as_entire_binding(),
597            },
598            wgpu::BindGroupEntry {
599                binding: 4,
600                resource: params_buffer.as_entire_binding(),
601            },
602            wgpu::BindGroupEntry {
603                binding: 5,
604                resource: color_buffer.as_entire_binding(),
605            },
606        ];
607
608        let bind_group = gpu.create_bind_group("TSDF Integration Bind Group", &pipeline.get_bind_group_layout(0), &bind_group_entries);
609
610        // Dispatch compute shader
611        let mut encoder = gpu.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
612            label: Some("TSDF Integration Encoder"),
613        });
614
615        {
616            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
617                label: Some("TSDF Integration Pass"),
618                timestamp_writes: None,
619            });
620
621            compute_pass.set_pipeline(&pipeline);
622            compute_pass.set_bind_group(0, &bind_group, &[]);
623            
624            // Dispatch with 4x4x4 workgroups
625            let workgroup_size = 4;
626            let dispatch_x = (self.volume.resolution[0] + workgroup_size - 1) / workgroup_size;
627            let dispatch_y = (self.volume.resolution[1] + workgroup_size - 1) / workgroup_size;
628            let dispatch_z = (self.volume.resolution[2] + workgroup_size - 1) / workgroup_size;
629            
630            println!("Dispatching compute shader with {} x {} x {} workgroups", dispatch_x, dispatch_y, dispatch_z);
631            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, dispatch_z);
632        }
633
634        gpu.queue.submit(std::iter::once(encoder.finish()));
635        Ok(())
636    }
637
638    /// Downloads the TSDF voxel data from the GPU.
639    pub async fn download_voxels(&self, gpu: &GpuContext) -> Result<Vec<TsdfVoxel>> {
640        let total_voxels = (self.volume.resolution[0] * self.volume.resolution[1] * self.volume.resolution[2]) as usize;
641        let buffer_size = (total_voxels * std::mem::size_of::<TsdfVoxel>()) as u64;
642
643        let staging_buffer = gpu.device.create_buffer(&wgpu::BufferDescriptor {
644            label: Some("TSDF Staging Buffer"),
645            size: buffer_size,
646            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
647            mapped_at_creation: false,
648        });
649
650        let mut encoder = gpu.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
651            label: Some("TSDF Download Encoder"),
652        });
653
654        encoder.copy_buffer_to_buffer(
655            &self.voxel_buffer,
656            0,
657            &staging_buffer,
658            0,
659            buffer_size,
660        );
661
662        gpu.queue.submit(std::iter::once(encoder.finish()));
663
664        let buffer_slice = staging_buffer.slice(..);
665        let (sender, receiver) = flume::unbounded();
666        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
667            sender.send(result).unwrap();
668        });
669
670        let _ = gpu.device.poll(wgpu::PollType::Wait);
671        receiver.recv_async().await.map_err(|_| Error::Gpu("Failed to receive mapping result".into()))??;
672
673        let data = buffer_slice.get_mapped_range();
674        let result: Vec<TsdfVoxel> = bytemuck::cast_slice(&data).to_vec();
675        
676        drop(data);
677        staging_buffer.unmap();
678
679        Ok(result)
680    }
681
682    /// Extract point cloud from TSDF volume using marching cubes
683    pub async fn extract_surface(&self, gpu: &GpuContext, iso_value: f32) -> Result<PointCloud<ColoredPoint3f>> {
684        let voxels = self.download_voxels(gpu).await?;
685        gpu.tsdf_extract_surface(&self.volume, &voxels, iso_value).await
686    }
687}
688
689/// Create a new TSDF volume with specified parameters
690pub fn create_tsdf_volume(
691    voxel_size: f32,
692    truncation_distance: f32,
693    resolution: [u32; 3],
694    origin: Point3<f32>,
695) -> TsdfVolume {
696    TsdfVolume {
697        voxel_size,
698        truncation_distance,
699        resolution,
700        origin,
701    }
702}
703
704/// GPU-accelerated TSDF integration from depth image
705pub async fn gpu_tsdf_integrate(
706    gpu_context: &GpuContext,
707    volume: &mut TsdfVolume,
708    depth_image: &[f32],
709    color_image: Option<&[u8]>,
710    camera_pose: &Matrix4<f32>,
711    intrinsics: &CameraIntrinsics,
712) -> Result<Vec<TsdfVoxel>> {
713    gpu_context.tsdf_integrate(volume, depth_image, color_image, camera_pose, intrinsics).await
714}
715
716/// GPU-accelerated surface extraction from TSDF volume
717pub async fn gpu_tsdf_extract_surface(
718    gpu_context: &GpuContext,
719    volume: &TsdfVolume,
720    voxels: &[TsdfVoxel],
721    iso_value: f32,
722) -> Result<PointCloud<ColoredPoint3f>> {
723    gpu_context.tsdf_extract_surface(volume, voxels, iso_value).await
724}
725
726#[cfg(test)]
727mod tests {
728    use super::*;
729    use crate::device::GpuContext;
730    use nalgebra::{Matrix4, Point3};
731    use approx::assert_relative_eq;
732
733    /// Try to create a GPU context, return None if not available
734    async fn try_create_gpu_context() -> Option<GpuContext> {
735        match GpuContext::new().await {
736            Ok(gpu) => Some(gpu),
737            Err(_) => {
738                println!("⚠️  GPU not available, skipping GPU-dependent test");
739                None
740            }
741        }
742    }
743
744    /// Create simple depth image for basic testing
745    fn create_simple_depth_image(width: u32, height: u32, depth: f32) -> Vec<f32> {
746        vec![depth; (width * height) as usize]
747    }
748
749    /// Create a basic camera setup for testing
750    fn create_test_camera() -> CameraIntrinsics {
751        CameraIntrinsics {
752            fx: 525.0,
753            fy: 525.0,
754            cx: 319.5,
755            cy: 239.5,
756            width: 640,
757            height: 480,
758            depth_scale: 1.0,
759            _padding: 0.0,
760        }
761    }
762
763    /// Create identity camera pose
764    fn create_identity_pose() -> Matrix4<f32> {
765        Matrix4::new(
766            1.0, 0.0, 0.0, 0.0,
767            0.0, 1.0, 0.0, 0.0,
768            0.0, 0.0, 1.0, 0.0,
769            0.0, 0.0, 0.0, 1.0,
770        )
771    }
772
773    #[test]
774    fn test_tsdf_basic_integration() {
775        pollster::block_on(async {
776            let Some(gpu) = try_create_gpu_context().await else {
777                return;
778            };
779
780            // Create a simple TSDF volume
781            let voxel_size = 0.02; // 2cm voxels for faster processing
782            let truncation_distance = 0.1; 
783            let resolution = [32, 32, 32]; // Smaller resolution for speed
784            let origin = Point3::new(-0.32, -0.32, 0.0);
785
786            let volume_params = create_tsdf_volume(
787                voxel_size,
788                truncation_distance,
789                resolution,
790                origin,
791            );
792            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
793
794            // Create simple depth image with constant depth
795            let intrinsics = create_test_camera();
796            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.5);
797            let camera_pose = create_identity_pose();
798
799            // Test integration
800            let result = tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics).await;
801            assert!(result.is_ok(), "TSDF integration should succeed");
802
803            // Test voxel download
804            let voxels = tsdf_volume_gpu.download_voxels(&gpu).await.unwrap();
805            assert_eq!(voxels.len(), (32 * 32 * 32) as usize, "Should have correct number of voxels");
806
807            // Check that some voxels have been updated
808            let updated_voxels = voxels.iter().filter(|v| v.weight > 0.0).count();
809            assert!(updated_voxels > 0, "Some voxels should have been updated");
810
811            println!("✓ Basic integration test passed: {} voxels updated", updated_voxels);
812        });
813    }
814
815    #[test]
816    fn test_tsdf_surface_extraction() {
817        pollster::block_on(async {
818            let Some(gpu) = try_create_gpu_context().await else {
819                return;
820            };
821
822            // Create TSDF volume 
823            let voxel_size = 0.02;
824            let truncation_distance = 0.1;
825            let resolution = [32, 32, 32];
826            let origin = Point3::new(-0.32, -0.32, 0.0);
827
828            let volume_params = create_tsdf_volume(
829                voxel_size,
830                truncation_distance,
831                resolution,
832                origin,
833            );
834            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
835
836            // Integrate a simple depth image
837            let intrinsics = create_test_camera();
838            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.3);
839            let camera_pose = create_identity_pose();
840
841            tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics)
842                .await
843                .unwrap();
844
845            // Extract surface
846            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
847            
848            // Should extract some points
849            assert!(!point_cloud.points.is_empty(), "Should extract surface points");
850            
851            // Points should be in reasonable Z range around the depth value
852            let avg_z = point_cloud.points.iter()
853                .map(|p| p.position.z)
854                .sum::<f32>() / point_cloud.points.len() as f32;
855            
856            assert!(avg_z > 0.2 && avg_z < 0.4, "Average Z should be near depth value of 0.3");
857            
858            println!("✓ Surface extraction test passed: {} points extracted, avg Z: {:.3}", 
859                     point_cloud.points.len(), avg_z);
860        });
861    }
862
863    #[test]
864    fn test_tsdf_multiple_integrations() {
865        pollster::block_on(async {
866            let Some(gpu) = try_create_gpu_context().await else {
867                return;
868            };
869
870            // Create TSDF volume
871            let voxel_size = 0.02;
872            let truncation_distance = 0.1;
873            let resolution = [32, 32, 32];
874            let origin = Point3::new(-0.32, -0.32, 0.0);
875
876            let volume_params = create_tsdf_volume(
877                voxel_size,
878                truncation_distance,
879                resolution,
880                origin,
881            );
882            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
883
884            let intrinsics = create_test_camera();
885            let camera_pose = create_identity_pose();
886
887            // Integrate multiple depth images
888            let depths = [0.25, 0.3, 0.35];
889            for &depth in &depths {
890                let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, depth);
891                tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics)
892                    .await
893                    .unwrap();
894            }
895
896            // Check voxel weights have increased
897            let voxels = tsdf_volume_gpu.download_voxels(&gpu).await.unwrap();
898            let max_weight = voxels.iter().map(|v| v.weight).fold(0.0, f32::max);
899            assert!(max_weight > 1.0, "Multiple integrations should increase voxel weights");
900
901            // Extract surface
902            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
903            assert!(!point_cloud.points.is_empty(), "Should extract surface after multiple integrations");
904
905            println!("✓ Multiple integration test passed: max weight {:.1}, {} points extracted", 
906                     max_weight, point_cloud.points.len());
907        });
908    }
909
910    #[test]
911    fn test_tsdf_coordinate_system() {
912        pollster::block_on(async {
913            let Some(_gpu) = try_create_gpu_context().await else {
914                return;
915            };
916
917            // Test basic coordinate system consistency
918            let voxel_size = 0.02;
919            let resolution = [32, 32, 32];
920            let origin = Point3::new(-0.32, -0.32, 0.0);
921
922            // Check volume bounds
923            let max_coord = Point3::new(
924                origin.x + (resolution[0] as f32) * voxel_size,
925                origin.y + (resolution[1] as f32) * voxel_size,
926                origin.z + (resolution[2] as f32) * voxel_size,
927            );
928
929            assert_relative_eq!(max_coord.x, 0.32, epsilon = 0.01);
930            assert_relative_eq!(max_coord.y, 0.32, epsilon = 0.01);
931            assert_relative_eq!(max_coord.z, 0.64, epsilon = 0.01);
932
933            // Test camera transforms
934            let camera_pose = create_identity_pose();
935            let world_to_camera = camera_pose.try_inverse().unwrap();
936            
937            let test_point = Point3::new(0.1, 0.2, 0.3);
938            let camera_point = world_to_camera.transform_point(&test_point);
939            
940            // For identity transform, should be the same
941            assert_relative_eq!(test_point.x, camera_point.x, epsilon = 0.001);
942            assert_relative_eq!(test_point.y, camera_point.y, epsilon = 0.001);
943            assert_relative_eq!(test_point.z, camera_point.z, epsilon = 0.001);
944
945            println!("✓ Coordinate system test passed");
946        });
947    }
948
949    #[test]
950    fn test_tsdf_color_integration() {
951        pollster::block_on(async {
952            let Some(gpu) = try_create_gpu_context().await else {
953                return;
954            };
955
956            // Create TSDF volume
957            let voxel_size = 0.02;
958            let truncation_distance = 0.1;
959            let resolution = [32, 32, 32];
960            let origin = Point3::new(-0.32, -0.32, 0.0);
961
962            let volume_params = create_tsdf_volume(
963                voxel_size,
964                truncation_distance,
965                resolution,
966                origin,
967            );
968            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
969
970            // Create depth and color images
971            let intrinsics = create_test_camera();
972            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.3);
973            
974            // Simple red color image
975            let pixel_count = (intrinsics.width * intrinsics.height) as usize;
976            let mut color_image = Vec::with_capacity(pixel_count * 3);
977            for _ in 0..pixel_count {
978                color_image.extend_from_slice(&[255u8, 0u8, 0u8]); // RGB: red
979            }
980            
981            let camera_pose = create_identity_pose();
982
983            // Integrate with color
984            tsdf_volume_gpu.integrate(&gpu, &depth_image, Some(&color_image), &camera_pose, &intrinsics)
985                .await
986                .unwrap();
987
988            // Extract surface
989            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
990            
991            assert!(!point_cloud.points.is_empty(), "Should extract colored surface points");
992            
993            // Check that some points have red color
994            let red_points = point_cloud.points.iter()
995                .filter(|p| p.color[0] > 200)
996                .count();
997            
998            assert!(red_points > 0, "Some points should have red color");
999
1000            println!("✓ Color integration test passed: {} red points", red_points);
1001        });
1002    }
1003}