1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace;
6use crate::device::DeviceError;
7use crate::{
8 api_log,
9 device::Device,
10 global::Global,
11 id::{self, BlasId, TlasId},
12 lock::RwLock,
13 lock::{rank, Mutex},
14 ray_tracing::BlasPrepareCompactError,
15 ray_tracing::{CreateBlasError, CreateTlasError},
16 resource,
17 resource::{
18 BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
19 },
20 snatch::Snatchable,
21 LabelHelpers,
22};
23use hal::AccelerationStructureTriangleIndices;
24use wgt::Features;
25
26impl Device {
27 fn create_blas(
28 self: &Arc<Self>,
29 blas_desc: &resource::BlasDescriptor,
30 sizes: wgt::BlasGeometrySizeDescriptors,
31 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
32 self.check_is_valid()?;
33 self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
34
35 if blas_desc
36 .flags
37 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
38 {
39 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
40 }
41
42 let size_info = match &sizes {
43 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
44 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
45 return Err(CreateBlasError::TooManyGeometries(
46 self.limits.max_blas_geometry_count,
47 descriptors.len() as u32,
48 ));
49 }
50
51 let mut entries =
52 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
53 descriptors.len(),
54 );
55 for desc in descriptors {
56 if desc.index_count.is_some() != desc.index_format.is_some() {
57 return Err(CreateBlasError::MissingIndexData);
58 }
59 let indices =
60 desc.index_count
61 .map(|count| AccelerationStructureTriangleIndices::<
62 dyn hal::DynBuffer,
63 > {
64 format: desc.index_format.unwrap(),
65 buffer: None,
66 offset: 0,
67 count,
68 });
69 if !self
70 .features
71 .allowed_vertex_formats_for_blas()
72 .contains(&desc.vertex_format)
73 {
74 return Err(CreateBlasError::InvalidVertexFormat(
75 desc.vertex_format,
76 self.features.allowed_vertex_formats_for_blas(),
77 ));
78 }
79
80 let mut transform = None;
81
82 if blas_desc
83 .flags
84 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
85 {
86 transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
87 buffer: self.zero_buffer.as_ref(),
88 offset: 0,
89 })
90 }
91
92 if desc.vertex_count > self.limits.max_blas_primitive_count {
93 return Err(CreateBlasError::TooManyPrimitives(
94 self.limits.max_blas_primitive_count,
95 desc.vertex_count,
96 ));
97 }
98
99 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
100 vertex_buffer: None,
101 vertex_format: desc.vertex_format,
102 first_vertex: 0,
103 vertex_count: desc.vertex_count,
104 vertex_stride: 0,
105 indices,
106 transform,
107 flags: desc.flags,
108 });
109 }
110 unsafe {
111 self.raw().get_acceleration_structure_build_sizes(
112 &hal::GetAccelerationStructureBuildSizesDescriptor {
113 entries: &hal::AccelerationStructureEntries::Triangles(entries),
114 flags: blas_desc.flags,
115 },
116 )
117 }
118 }
119 };
120
121 let raw = unsafe {
122 self.raw()
123 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
124 label: blas_desc.label.as_deref(),
125 size: size_info.acceleration_structure_size,
126 format: hal::AccelerationStructureFormat::BottomLevel,
127 allow_compaction: blas_desc
128 .flags
129 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
130 })
131 }
132 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
133
134 let compaction_buffer = if blas_desc
135 .flags
136 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
137 {
138 Some(ManuallyDrop::new(unsafe {
139 self.raw()
140 .create_buffer(&hal::BufferDescriptor {
141 label: Some("(wgpu internal) compaction read-back buffer"),
142 size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
143 usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
144 | wgpu_types::BufferUses::MAP_READ,
145 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
146 })
147 .map_err(DeviceError::from_hal)?
148 }))
149 } else {
150 None
151 };
152
153 let handle = unsafe {
154 self.raw()
155 .get_acceleration_structure_device_address(raw.as_ref())
156 };
157
158 Ok(Arc::new(resource::Blas {
159 raw: Snatchable::new(raw),
160 device: self.clone(),
161 size_info,
162 sizes,
163 flags: blas_desc.flags,
164 update_mode: blas_desc.update_mode,
165 handle,
166 label: blas_desc.label.to_string(),
167 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
168 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
169 compaction_buffer,
170 compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
171 }))
172 }
173
174 fn create_tlas(
175 self: &Arc<Self>,
176 desc: &resource::TlasDescriptor,
177 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
178 self.check_is_valid()?;
179 self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
180
181 if desc.max_instances > self.limits.max_tlas_instance_count {
182 return Err(CreateTlasError::TooManyInstances(
183 self.limits.max_tlas_instance_count,
184 desc.max_instances,
185 ));
186 }
187
188 if desc
189 .flags
190 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
191 {
192 return Err(CreateTlasError::DisallowedFlag(
193 wgt::AccelerationStructureFlags::USE_TRANSFORM,
194 ));
195 }
196
197 if desc
198 .flags
199 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
200 {
201 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
202 }
203
204 let size_info = unsafe {
205 self.raw().get_acceleration_structure_build_sizes(
206 &hal::GetAccelerationStructureBuildSizesDescriptor {
207 entries: &hal::AccelerationStructureEntries::Instances(
208 hal::AccelerationStructureInstances {
209 buffer: None,
210 offset: 0,
211 count: desc.max_instances,
212 },
213 ),
214 flags: desc.flags,
215 },
216 )
217 };
218
219 let raw = unsafe {
220 self.raw()
221 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
222 label: desc.label.as_deref(),
223 size: size_info.acceleration_structure_size,
224 format: hal::AccelerationStructureFormat::TopLevel,
225 allow_compaction: false,
226 })
227 }
228 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
229
230 let instance_buffer_size =
231 self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
232 let instance_buffer = unsafe {
233 self.raw().create_buffer(&hal::BufferDescriptor {
234 label: Some("(wgpu-core) instances_buffer"),
235 size: instance_buffer_size as u64,
236 usage: wgt::BufferUses::COPY_DST
237 | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
238 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
239 })
240 }
241 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
242
243 Ok(Arc::new(resource::Tlas {
244 raw: Snatchable::new(raw),
245 device: self.clone(),
246 size_info,
247 flags: desc.flags,
248 update_mode: desc.update_mode,
249 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
250 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
251 instance_buffer: ManuallyDrop::new(instance_buffer),
252 label: desc.label.to_string(),
253 max_instance_count: desc.max_instances,
254 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
255 }))
256 }
257}
258
259impl Global {
260 pub fn device_create_blas(
261 &self,
262 device_id: id::DeviceId,
263 desc: &resource::BlasDescriptor,
264 sizes: wgt::BlasGeometrySizeDescriptors,
265 id_in: Option<BlasId>,
266 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
267 profiling::scope!("Device::create_blas");
268
269 let fid = self.hub.blas_s.prepare(id_in);
270
271 let error = 'error: {
272 let device = self.hub.devices.get(device_id);
273
274 #[cfg(feature = "trace")]
275 if let Some(trace) = device.trace.lock().as_mut() {
276 trace.add(trace::Action::CreateBlas {
277 id: fid.id(),
278 desc: desc.clone(),
279 sizes: sizes.clone(),
280 });
281 }
282
283 let blas = match device.create_blas(desc, sizes) {
284 Ok(blas) => blas,
285 Err(e) => break 'error e,
286 };
287 let handle = blas.handle;
288
289 let id = fid.assign(Fallible::Valid(blas));
290 api_log!("Device::create_blas -> {id:?}");
291
292 return (id, Some(handle), None);
293 };
294
295 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
296 (id, None, Some(error))
297 }
298
299 pub fn device_create_tlas(
300 &self,
301 device_id: id::DeviceId,
302 desc: &resource::TlasDescriptor,
303 id_in: Option<TlasId>,
304 ) -> (TlasId, Option<CreateTlasError>) {
305 profiling::scope!("Device::create_tlas");
306
307 let fid = self.hub.tlas_s.prepare(id_in);
308
309 let error = 'error: {
310 let device = self.hub.devices.get(device_id);
311
312 #[cfg(feature = "trace")]
313 if let Some(trace) = device.trace.lock().as_mut() {
314 trace.add(trace::Action::CreateTlas {
315 id: fid.id(),
316 desc: desc.clone(),
317 });
318 }
319
320 let tlas = match device.create_tlas(desc) {
321 Ok(tlas) => tlas,
322 Err(e) => break 'error e,
323 };
324
325 let id = fid.assign(Fallible::Valid(tlas));
326 api_log!("Device::create_tlas -> {id:?}");
327
328 return (id, None);
329 };
330
331 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
332 (id, Some(error))
333 }
334
335 pub fn blas_drop(&self, blas_id: BlasId) {
336 profiling::scope!("Blas::drop");
337 api_log!("Blas::drop {blas_id:?}");
338
339 let _blas = self.hub.blas_s.remove(blas_id);
340
341 #[cfg(feature = "trace")]
342 if let Ok(blas) = _blas.get() {
343 if let Some(t) = blas.device.trace.lock().as_mut() {
344 t.add(trace::Action::DestroyBlas(blas_id));
345 }
346 }
347 }
348
349 pub fn tlas_drop(&self, tlas_id: TlasId) {
350 profiling::scope!("Tlas::drop");
351 api_log!("Tlas::drop {tlas_id:?}");
352
353 let _tlas = self.hub.tlas_s.remove(tlas_id);
354
355 #[cfg(feature = "trace")]
356 if let Ok(tlas) = _tlas.get() {
357 if let Some(t) = tlas.device.trace.lock().as_mut() {
358 t.add(trace::Action::DestroyTlas(tlas_id));
359 }
360 }
361 }
362
363 pub fn blas_prepare_compact_async(
364 &self,
365 blas_id: BlasId,
366 callback: Option<BlasCompactCallback>,
367 ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
368 profiling::scope!("Blas::prepare_compact_async");
369 api_log!("Blas::prepare_compact_async {blas_id:?}");
370
371 let hub = &self.hub;
372
373 let compact_result = match hub.blas_s.get(blas_id).get() {
374 Ok(blas) => blas.prepare_compact_async(callback),
375 Err(e) => Err((callback, e.into())),
376 };
377
378 match compact_result {
379 Ok(submission_index) => Ok(submission_index),
380 Err((mut callback, err)) => {
381 if let Some(callback) = callback.take() {
382 callback(Err(err.clone()));
383 }
384 Err(err)
385 }
386 }
387 }
388
389 pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
390 profiling::scope!("Blas::prepare_compact_async");
391 api_log!("Blas::prepare_compact_async {blas_id:?}");
392
393 let hub = &self.hub;
394
395 let blas = hub.blas_s.get(blas_id).get()?;
396
397 let lock = blas.compacted_state.lock();
398
399 Ok(matches!(*lock, BlasCompactState::Ready { .. }))
400 }
401}