wgpu_core/device/
ray_tracing.rs

1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace;
6use crate::device::DeviceError;
7use crate::{
8    api_log,
9    device::Device,
10    global::Global,
11    id::{self, BlasId, TlasId},
12    lock::RwLock,
13    lock::{rank, Mutex},
14    ray_tracing::BlasPrepareCompactError,
15    ray_tracing::{CreateBlasError, CreateTlasError},
16    resource,
17    resource::{
18        BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
19    },
20    snatch::Snatchable,
21    LabelHelpers,
22};
23use hal::AccelerationStructureTriangleIndices;
24use wgt::Features;
25
26impl Device {
27    fn create_blas(
28        self: &Arc<Self>,
29        blas_desc: &resource::BlasDescriptor,
30        sizes: wgt::BlasGeometrySizeDescriptors,
31    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
32        self.check_is_valid()?;
33        self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
34
35        if blas_desc
36            .flags
37            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
38        {
39            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
40        }
41
42        let size_info = match &sizes {
43            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
44                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
45                    return Err(CreateBlasError::TooManyGeometries(
46                        self.limits.max_blas_geometry_count,
47                        descriptors.len() as u32,
48                    ));
49                }
50
51                let mut entries =
52                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
53                        descriptors.len(),
54                    );
55                for desc in descriptors {
56                    if desc.index_count.is_some() != desc.index_format.is_some() {
57                        return Err(CreateBlasError::MissingIndexData);
58                    }
59                    let indices =
60                        desc.index_count
61                            .map(|count| AccelerationStructureTriangleIndices::<
62                                dyn hal::DynBuffer,
63                            > {
64                                format: desc.index_format.unwrap(),
65                                buffer: None,
66                                offset: 0,
67                                count,
68                            });
69                    if !self
70                        .features
71                        .allowed_vertex_formats_for_blas()
72                        .contains(&desc.vertex_format)
73                    {
74                        return Err(CreateBlasError::InvalidVertexFormat(
75                            desc.vertex_format,
76                            self.features.allowed_vertex_formats_for_blas(),
77                        ));
78                    }
79
80                    let mut transform = None;
81
82                    if blas_desc
83                        .flags
84                        .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
85                    {
86                        transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
87                            buffer: self.zero_buffer.as_ref(),
88                            offset: 0,
89                        })
90                    }
91
92                    if desc.vertex_count > self.limits.max_blas_primitive_count {
93                        return Err(CreateBlasError::TooManyPrimitives(
94                            self.limits.max_blas_primitive_count,
95                            desc.vertex_count,
96                        ));
97                    }
98
99                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
100                        vertex_buffer: None,
101                        vertex_format: desc.vertex_format,
102                        first_vertex: 0,
103                        vertex_count: desc.vertex_count,
104                        vertex_stride: 0,
105                        indices,
106                        transform,
107                        flags: desc.flags,
108                    });
109                }
110                unsafe {
111                    self.raw().get_acceleration_structure_build_sizes(
112                        &hal::GetAccelerationStructureBuildSizesDescriptor {
113                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
114                            flags: blas_desc.flags,
115                        },
116                    )
117                }
118            }
119        };
120
121        let raw = unsafe {
122            self.raw()
123                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
124                    label: blas_desc.label.as_deref(),
125                    size: size_info.acceleration_structure_size,
126                    format: hal::AccelerationStructureFormat::BottomLevel,
127                    allow_compaction: blas_desc
128                        .flags
129                        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
130                })
131        }
132        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
133
134        let compaction_buffer = if blas_desc
135            .flags
136            .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
137        {
138            Some(ManuallyDrop::new(unsafe {
139                self.raw()
140                    .create_buffer(&hal::BufferDescriptor {
141                        label: Some("(wgpu internal) compaction read-back buffer"),
142                        size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
143                        usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
144                            | wgpu_types::BufferUses::MAP_READ,
145                        memory_flags: hal::MemoryFlags::PREFER_COHERENT,
146                    })
147                    .map_err(DeviceError::from_hal)?
148            }))
149        } else {
150            None
151        };
152
153        let handle = unsafe {
154            self.raw()
155                .get_acceleration_structure_device_address(raw.as_ref())
156        };
157
158        Ok(Arc::new(resource::Blas {
159            raw: Snatchable::new(raw),
160            device: self.clone(),
161            size_info,
162            sizes,
163            flags: blas_desc.flags,
164            update_mode: blas_desc.update_mode,
165            handle,
166            label: blas_desc.label.to_string(),
167            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
168            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
169            compaction_buffer,
170            compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
171        }))
172    }
173
174    fn create_tlas(
175        self: &Arc<Self>,
176        desc: &resource::TlasDescriptor,
177    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
178        self.check_is_valid()?;
179        self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
180
181        if desc.max_instances > self.limits.max_tlas_instance_count {
182            return Err(CreateTlasError::TooManyInstances(
183                self.limits.max_tlas_instance_count,
184                desc.max_instances,
185            ));
186        }
187
188        if desc
189            .flags
190            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
191        {
192            return Err(CreateTlasError::DisallowedFlag(
193                wgt::AccelerationStructureFlags::USE_TRANSFORM,
194            ));
195        }
196
197        if desc
198            .flags
199            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
200        {
201            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
202        }
203
204        let size_info = unsafe {
205            self.raw().get_acceleration_structure_build_sizes(
206                &hal::GetAccelerationStructureBuildSizesDescriptor {
207                    entries: &hal::AccelerationStructureEntries::Instances(
208                        hal::AccelerationStructureInstances {
209                            buffer: None,
210                            offset: 0,
211                            count: desc.max_instances,
212                        },
213                    ),
214                    flags: desc.flags,
215                },
216            )
217        };
218
219        let raw = unsafe {
220            self.raw()
221                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
222                    label: desc.label.as_deref(),
223                    size: size_info.acceleration_structure_size,
224                    format: hal::AccelerationStructureFormat::TopLevel,
225                    allow_compaction: false,
226                })
227        }
228        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
229
230        let instance_buffer_size =
231            self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
232        let instance_buffer = unsafe {
233            self.raw().create_buffer(&hal::BufferDescriptor {
234                label: Some("(wgpu-core) instances_buffer"),
235                size: instance_buffer_size as u64,
236                usage: wgt::BufferUses::COPY_DST
237                    | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
238                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
239            })
240        }
241        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
242
243        Ok(Arc::new(resource::Tlas {
244            raw: Snatchable::new(raw),
245            device: self.clone(),
246            size_info,
247            flags: desc.flags,
248            update_mode: desc.update_mode,
249            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
250            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
251            instance_buffer: ManuallyDrop::new(instance_buffer),
252            label: desc.label.to_string(),
253            max_instance_count: desc.max_instances,
254            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
255        }))
256    }
257}
258
259impl Global {
260    pub fn device_create_blas(
261        &self,
262        device_id: id::DeviceId,
263        desc: &resource::BlasDescriptor,
264        sizes: wgt::BlasGeometrySizeDescriptors,
265        id_in: Option<BlasId>,
266    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
267        profiling::scope!("Device::create_blas");
268
269        let fid = self.hub.blas_s.prepare(id_in);
270
271        let error = 'error: {
272            let device = self.hub.devices.get(device_id);
273
274            #[cfg(feature = "trace")]
275            if let Some(trace) = device.trace.lock().as_mut() {
276                trace.add(trace::Action::CreateBlas {
277                    id: fid.id(),
278                    desc: desc.clone(),
279                    sizes: sizes.clone(),
280                });
281            }
282
283            let blas = match device.create_blas(desc, sizes) {
284                Ok(blas) => blas,
285                Err(e) => break 'error e,
286            };
287            let handle = blas.handle;
288
289            let id = fid.assign(Fallible::Valid(blas));
290            api_log!("Device::create_blas -> {id:?}");
291
292            return (id, Some(handle), None);
293        };
294
295        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
296        (id, None, Some(error))
297    }
298
299    pub fn device_create_tlas(
300        &self,
301        device_id: id::DeviceId,
302        desc: &resource::TlasDescriptor,
303        id_in: Option<TlasId>,
304    ) -> (TlasId, Option<CreateTlasError>) {
305        profiling::scope!("Device::create_tlas");
306
307        let fid = self.hub.tlas_s.prepare(id_in);
308
309        let error = 'error: {
310            let device = self.hub.devices.get(device_id);
311
312            #[cfg(feature = "trace")]
313            if let Some(trace) = device.trace.lock().as_mut() {
314                trace.add(trace::Action::CreateTlas {
315                    id: fid.id(),
316                    desc: desc.clone(),
317                });
318            }
319
320            let tlas = match device.create_tlas(desc) {
321                Ok(tlas) => tlas,
322                Err(e) => break 'error e,
323            };
324
325            let id = fid.assign(Fallible::Valid(tlas));
326            api_log!("Device::create_tlas -> {id:?}");
327
328            return (id, None);
329        };
330
331        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
332        (id, Some(error))
333    }
334
335    pub fn blas_drop(&self, blas_id: BlasId) {
336        profiling::scope!("Blas::drop");
337        api_log!("Blas::drop {blas_id:?}");
338
339        let _blas = self.hub.blas_s.remove(blas_id);
340
341        #[cfg(feature = "trace")]
342        if let Ok(blas) = _blas.get() {
343            if let Some(t) = blas.device.trace.lock().as_mut() {
344                t.add(trace::Action::DestroyBlas(blas_id));
345            }
346        }
347    }
348
349    pub fn tlas_drop(&self, tlas_id: TlasId) {
350        profiling::scope!("Tlas::drop");
351        api_log!("Tlas::drop {tlas_id:?}");
352
353        let _tlas = self.hub.tlas_s.remove(tlas_id);
354
355        #[cfg(feature = "trace")]
356        if let Ok(tlas) = _tlas.get() {
357            if let Some(t) = tlas.device.trace.lock().as_mut() {
358                t.add(trace::Action::DestroyTlas(tlas_id));
359            }
360        }
361    }
362
363    pub fn blas_prepare_compact_async(
364        &self,
365        blas_id: BlasId,
366        callback: Option<BlasCompactCallback>,
367    ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
368        profiling::scope!("Blas::prepare_compact_async");
369        api_log!("Blas::prepare_compact_async {blas_id:?}");
370
371        let hub = &self.hub;
372
373        let compact_result = match hub.blas_s.get(blas_id).get() {
374            Ok(blas) => blas.prepare_compact_async(callback),
375            Err(e) => Err((callback, e.into())),
376        };
377
378        match compact_result {
379            Ok(submission_index) => Ok(submission_index),
380            Err((mut callback, err)) => {
381                if let Some(callback) = callback.take() {
382                    callback(Err(err.clone()));
383                }
384                Err(err)
385            }
386        }
387    }
388
389    pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
390        profiling::scope!("Blas::prepare_compact_async");
391        api_log!("Blas::prepare_compact_async {blas_id:?}");
392
393        let hub = &self.hub;
394
395        let blas = hub.blas_s.get(blas_id).get()?;
396
397        let lock = blas.compacted_state.lock();
398
399        Ok(matches!(*lock, BlasCompactState::Ready { .. }))
400    }
401}