1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
11use crate::{
12 command::CommandBufferMutable,
13 device::queue::TempResource,
14 global::Global,
15 hub::Hub,
16 id::CommandEncoderId,
17 init_tracker::MemoryInitKind,
18 ray_tracing::{
19 BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
20 TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
21 TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
22 },
23 resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
24 scratch::ScratchBuffer,
25 snatch::SnatchGuard,
26 track::PendingTransition,
27};
28use crate::{command::EncoderStateError, device::resource::CommandIndices};
29use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
30
31use crate::id::{BlasId, TlasId};
32
/// Per-geometry staging record produced by `iter_blas` and consumed by
/// `iter_buffers`.
///
/// Holds the resolved buffers for one triangle geometry together with the
/// pending (not yet emitted) tracker transitions for each of them.
struct TriangleBufferStore<'a> {
    // Vertex buffer backing this geometry.
    vertex_buffer: Arc<Buffer>,
    // Pending usage transition for `vertex_buffer`; `take()`n when the
    // input barrier is generated.
    vertex_transition: Option<PendingTransition<BufferUses>>,
    // Optional index buffer plus its pending transition.
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    // Optional transform buffer plus its pending transition.
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    // The geometry description this record was built from.
    geometry: BlasTriangleGeometry<'a>,
    // Set (to the owning BLAS) only on the *last* geometry of that BLAS;
    // `iter_buffers` uses it to know when to flush the accumulated
    // triangle entries into a `BlasStore`.
    ending_blas: Option<Arc<Blas>>,
}
41
/// A BLAS whose build inputs have been validated, paired with the hal
/// geometry entries and its byte offset into the shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
47
/// A TLAS scheduled for building, with its hal instance entries and its
/// byte offset into the shared scratch buffer. ("Unsafe" here mirrors the
/// fact that `entries` borrows the TLAS's own instance buffer.)
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
53
/// `UnsafeTlasStore` plus the byte range of this TLAS's serialized instance
/// data inside the shared instance staging source.
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    // Range into `instance_buffer_staging_source` covering this TLAS's
    // instance bytes; used for the staging → instance-buffer copy.
    range: Range<usize>,
}
58
59impl Global {
60 pub fn command_encoder_mark_acceleration_structures_built(
61 &self,
62 command_encoder_id: CommandEncoderId,
63 blas_ids: &[BlasId],
64 tlas_ids: &[TlasId],
65 ) -> Result<(), EncoderStateError> {
66 profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
67
68 let hub = &self.hub;
69
70 let cmd_buf = hub
71 .command_buffers
72 .get(command_encoder_id.into_command_buffer_id());
73
74 let mut cmd_buf_data = cmd_buf.data.lock();
75 cmd_buf_data.record_with(
76 |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77 let device = &cmd_buf.device;
78 device.check_is_valid()?;
79 device
80 .require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
81
82 let mut build_command = AsBuild::default();
83
84 for blas in blas_ids {
85 let blas = hub.blas_s.get(*blas).get()?;
86 build_command.blas_s_built.push(blas);
87 }
88
89 for tlas in tlas_ids {
90 let tlas = hub.tlas_s.get(*tlas).get()?;
91 build_command.tlas_s_built.push(TlasBuild {
92 tlas,
93 dependencies: Vec::new(),
94 });
95 }
96
97 cmd_buf_data.as_actions.push(AsAction::Build(build_command));
98 Ok(())
99 },
100 )
101 }
102
    /// Validates and encodes the building of the given bottom- and top-level
    /// acceleration structures into the command encoder.
    ///
    /// BLAS inputs are validated and built first, then TLAS instance data is
    /// serialized into a staging buffer, copied into each TLAS's instance
    /// buffer, and the TLAS builds are encoded. A single scratch buffer,
    /// sized to the larger of the BLAS and TLAS requirements, is shared by
    /// both phases (they are separated by a scratch-buffer barrier in
    /// `build_blas`).
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let mut build_command = AsBuild::default();

        // Drain the caller's iterators into owned "trace" snapshots. These
        // serve two purposes: they are pushed to the trace command list when
        // the "trace" feature is enabled, and they own the data that the
        // re-borrowing `blas_iter`/`tlas_iter` below iterate over.
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        // Re-borrowing views over the trace snapshots, shaped like the
        // original caller-provided iterators.
        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
            let instances = tlas_package.instances.iter().map(|instance| {
                instance.as_ref().map(|instance| TlasInstance {
                    blas_id: instance.blas_id,
                    transform: &instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                })
            });
            TlasPackage {
                tlas_id: tlas_package.tlas_id,
                instances: Box::new(instances),
                lowest_unmodified: tlas_package.lowest_unmodified,
            }
        });

        let mut cmd_buf_data = cmd_buf.data.lock();
        cmd_buf_data.record_with(|cmd_buf_data| {
            #[cfg(feature = "trace")]
            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::BuildAccelerationStructures {
                    blas: trace_blas.clone(),
                    tlas: trace_tlas.clone(),
                });
            }

            let device = &cmd_buf.device;
            device.check_is_valid()?;
            device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

            // Phase 1: validate BLAS entries and collect per-geometry buffer
            // records (with pending tracker transitions).
            let mut buf_storage = Vec::new();
            iter_blas(
                blas_iter,
                cmd_buf_data,
                &mut build_command,
                &mut buf_storage,
                hub,
            )?;

            // Phase 2: snatch raw buffers, compute input barriers, BLAS
            // scratch requirements, and hal geometry entries.
            let snatch_guard = device.snatchable_lock.read();
            let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
            let mut scratch_buffer_blas_size = 0;
            let mut blas_storage = Vec::new();
            iter_buffers(
                &mut buf_storage,
                &snatch_guard,
                &mut input_barriers,
                cmd_buf_data,
                &mut scratch_buffer_blas_size,
                &mut blas_storage,
                hub,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            )?;
            let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

            for package in tlas_iter {
                let tlas = hub.tlas_s.get(package.tlas_id).get()?;

                cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

                tlas_lock_store.push((Some(package), tlas))
            }

            // Phase 3: validate TLAS instances, serialize them into one
            // staging byte vector, and accumulate TLAS scratch requirements.
            let mut scratch_buffer_tlas_size = 0;
            let mut tlas_storage = Vec::<TlasStore>::new();
            let mut instance_buffer_staging_source = Vec::<u8>::new();

            for (package, tlas) in &mut tlas_lock_store {
                let package = package.take().unwrap();

                let scratch_buffer_offset = scratch_buffer_tlas_size;
                scratch_buffer_tlas_size += align_to(
                    tlas.size_info.build_scratch_size as u32,
                    device.alignments.ray_tracing_scratch_buffer_alignment,
                ) as u64;

                let first_byte_index = instance_buffer_staging_source.len();

                let mut dependencies = Vec::new();

                let mut instance_count = 0;
                for instance in package.instances.flatten() {
                    // Custom index must fit in 24 bits.
                    if instance.custom_data >= (1u32 << 24u32) {
                        return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                            tlas.error_ident(),
                        ));
                    }
                    let blas = hub.blas_s.get(instance.blas_id).get()?;

                    cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

                    // Serialize the instance into the backend's on-GPU layout.
                    instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                        hal::TlasInstance {
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                            blas_address: blas.handle,
                        },
                    ));

                    // A vertex-return TLAS may only reference vertex-return BLASes.
                    if tlas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    ) && !blas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    ) {
                        return Err(
                            BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ),
                        );
                    }

                    instance_count += 1;

                    dependencies.push(blas.clone());
                }

                build_command.tlas_s_built.push(TlasBuild {
                    tlas: tlas.clone(),
                    dependencies,
                });

                if instance_count > tlas.max_instance_count {
                    return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                        tlas.error_ident(),
                        instance_count,
                        tlas.max_instance_count,
                    ));
                }

                tlas_storage.push(TlasStore {
                    internal: UnsafeTlasStore {
                        tlas: tlas.clone(),
                        entries: hal::AccelerationStructureEntries::Instances(
                            hal::AccelerationStructureInstances {
                                buffer: Some(tlas.instance_buffer.as_ref()),
                                offset: 0,
                                count: instance_count,
                            },
                        ),
                        scratch_buffer_offset,
                    },
                    range: first_byte_index..instance_buffer_staging_source.len(),
                });
            }

            // Nothing needs scratch space => nothing to build; bail out early.
            // NOTE(review): this early return also skips pushing
            // `AsAction::Build(build_command)` below, so any BLAS/TLAS
            // recorded into `build_command` in this degenerate case is never
            // marked built — confirm this is intended.
            let Some(scratch_size) =
                wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
            else {
                return Ok(());
            };

            let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

            // Self-transition used purely as an execution barrier between the
            // BLAS and TLAS uses of the shared scratch buffer.
            let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: scratch_buffer.raw(),
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                    to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                },
            };

            let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

            for &TlasStore {
                internal:
                    UnsafeTlasStore {
                        ref tlas,
                        ref entries,
                        ref scratch_buffer_offset,
                    },
                ..
            } in &tlas_storage
            {
                // PreferUpdate is accepted but downgraded to a full rebuild.
                if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                    log::info!("only rebuild implemented")
                }
                tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                    entries,
                    mode: hal::AccelerationStructureBuildMode::Build,
                    flags: tlas.flags,
                    source_acceleration_structure: None,
                    destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                    scratch_buffer: scratch_buffer.raw(),
                    scratch_buffer_offset: *scratch_buffer_offset,
                })
            }

            let blas_present = !blas_storage.is_empty();
            let tlas_present = !tlas_storage.is_empty();

            let cmd_buf_raw = cmd_buf_data.encoder.open()?;

            let mut blas_s_compactable = Vec::new();
            let mut descriptors = Vec::new();

            for storage in &blas_storage {
                descriptors.push(map_blas(
                    storage,
                    scratch_buffer.raw(),
                    &snatch_guard,
                    &mut blas_s_compactable,
                )?);
            }

            // Encode input barriers, BLAS builds, compaction-size readbacks,
            // and the BLAS→TLAS synchronization.
            build_blas(
                cmd_buf_raw,
                blas_present,
                tlas_present,
                input_barriers,
                &descriptors,
                scratch_buffer_barrier,
                blas_s_compactable,
            );

            if tlas_present {
                // Upload the serialized instance data (if any) via a staging
                // buffer.
                let staging_buffer = if !instance_buffer_staging_source.is_empty() {
                    let mut staging_buffer = StagingBuffer::new(
                        device,
                        wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
                    )?;
                    staging_buffer.write(&instance_buffer_staging_source);
                    let flushed = staging_buffer.flush();
                    Some(flushed)
                } else {
                    None
                };

                unsafe {
                    if let Some(ref staging_buffer) = staging_buffer {
                        cmd_buf_raw.transition_buffers(&[
                            hal::BufferBarrier::<dyn hal::DynBuffer> {
                                buffer: staging_buffer.raw(),
                                usage: hal::StateTransition {
                                    from: BufferUses::MAP_WRITE,
                                    to: BufferUses::COPY_SRC,
                                },
                            },
                        ]);
                    }
                }

                // For each TLAS with instance data: transition its instance
                // buffer to COPY_DST, copy its slice of the staging data in,
                // and queue the reverse transition (applied in one batch
                // below, before the TLAS builds).
                let mut instance_buffer_barriers = Vec::new();
                for &TlasStore {
                    internal: UnsafeTlasStore { ref tlas, .. },
                    ref range,
                } in &tlas_storage
                {
                    // Zero-length range => no instances; skip the copy.
                    let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                        None => continue,
                        Some(size) => size,
                    };
                    instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: tlas.instance_buffer.as_ref(),
                        usage: hal::StateTransition {
                            from: BufferUses::COPY_DST,
                            to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        },
                    });
                    unsafe {
                        cmd_buf_raw.transition_buffers(&[
                            hal::BufferBarrier::<dyn hal::DynBuffer> {
                                buffer: tlas.instance_buffer.as_ref(),
                                usage: hal::StateTransition {
                                    from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                                    to: BufferUses::COPY_DST,
                                },
                            },
                        ]);
                        let temp = hal::BufferCopy {
                            src_offset: range.start as u64,
                            dst_offset: 0,
                            size,
                        };
                        // `staging_buffer` is Some here: a non-empty range
                        // implies `instance_buffer_staging_source` was
                        // non-empty.
                        cmd_buf_raw.copy_buffer_to_buffer(
                            staging_buffer.as_ref().unwrap().raw(),
                            tlas.instance_buffer.as_ref(),
                            &[temp],
                        );
                    }
                }

                unsafe {
                    cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

                    cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                    cmd_buf_raw.place_acceleration_structure_barrier(
                        hal::AccelerationStructureBarrier {
                            usage: hal::StateTransition {
                                from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                                to: hal::AccelerationStructureUses::SHADER_INPUT,
                            },
                        },
                    );
                }

                // Keep the staging buffer alive until the submission retires.
                if let Some(staging_buffer) = staging_buffer {
                    cmd_buf_data
                        .temp_resources
                        .push(TempResource::StagingBuffer(staging_buffer));
                }
            }

            // Keep the scratch buffer alive until the submission retires.
            cmd_buf_data
                .temp_resources
                .push(TempResource::ScratchBuffer(scratch_buffer));

            cmd_buf_data.as_actions.push(AsAction::Build(build_command));

            Ok(())
        })
    }
506}
507
impl CommandBufferMutable {
    /// At submission time, replays the recorded acceleration-structure
    /// actions: assigns a monotonically increasing build index to every
    /// built BLAS/TLAS, and verifies that any used TLAS (and all of its BLAS
    /// dependencies) was built, with no BLAS rebuilt after its TLAS.
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // Allocate the next build index; the counter starts
                    // non-zero, so the NonZeroU64 conversion cannot fail.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        // Rebuilding a compacted BLAS is rejected at encode
                        // time (see `map_blas`), hence unreachable here.
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                // NOTE(review): here `UsedUnbuiltBlas` is
                                // constructed as (blas, tlas), while the
                                // `UseTlas` arm below passes (tlas, blas) —
                                // one of the two orders is likely swapped;
                                // confirm against the error enum definition.
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        // Record the dependency snapshot for later `UseTlas`
                        // validation.
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS built after the TLAS that references it
                        // invalidates the TLAS.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        // Ensure the BLAS still has a live raw resource.
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}
583
584fn iter_blas<'a>(
586 blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
587 cmd_buf_data: &mut CommandBufferMutable,
588 build_command: &mut AsBuild,
589 buf_storage: &mut Vec<TriangleBufferStore<'a>>,
590 hub: &Hub,
591) -> Result<(), BuildAccelerationStructureError> {
592 let mut temp_buffer = Vec::new();
593 for entry in blas_iter {
594 let blas = hub.blas_s.get(entry.blas_id).get()?;
595 cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
596
597 build_command.blas_s_built.push(blas.clone());
598
599 match entry.geometries {
600 BlasGeometries::TriangleGeometries(triangle_geometries) => {
601 for (i, mesh) in triangle_geometries.enumerate() {
602 let size_desc = match &blas.sizes {
603 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
604 };
605 if i >= size_desc.len() {
606 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
607 blas.error_ident(),
608 ));
609 }
610 let size_desc = &size_desc[i];
611
612 if size_desc.flags != mesh.size.flags {
613 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
614 blas.error_ident(),
615 size_desc.flags,
616 mesh.size.flags,
617 ));
618 }
619
620 if size_desc.vertex_count < mesh.size.vertex_count {
621 return Err(
622 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
623 blas.error_ident(),
624 size_desc.vertex_count,
625 mesh.size.vertex_count,
626 ),
627 );
628 }
629
630 if size_desc.vertex_format != mesh.size.vertex_format {
631 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
632 blas.error_ident(),
633 size_desc.vertex_format,
634 mesh.size.vertex_format,
635 ));
636 }
637
638 if size_desc
639 .vertex_format
640 .min_acceleration_structure_vertex_stride()
641 > mesh.vertex_stride
642 {
643 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
644 blas.error_ident(),
645 size_desc
646 .vertex_format
647 .min_acceleration_structure_vertex_stride(),
648 mesh.vertex_stride,
649 ));
650 }
651
652 if mesh.vertex_stride
653 % size_desc
654 .vertex_format
655 .acceleration_structure_stride_alignment()
656 != 0
657 {
658 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
659 blas.error_ident(),
660 size_desc
661 .vertex_format
662 .acceleration_structure_stride_alignment(),
663 mesh.vertex_stride,
664 ));
665 }
666
667 match (size_desc.index_count, mesh.size.index_count) {
668 (Some(_), None) | (None, Some(_)) => {
669 return Err(
670 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
671 blas.error_ident(),
672 ),
673 )
674 }
675 (Some(create), Some(build)) if create < build => {
676 return Err(
677 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
678 blas.error_ident(),
679 create,
680 build,
681 ),
682 )
683 }
684 _ => {}
685 }
686
687 if size_desc.index_format != mesh.size.index_format {
688 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
689 blas.error_ident(),
690 size_desc.index_format,
691 mesh.size.index_format,
692 ));
693 }
694
695 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
696 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
697 blas.error_ident(),
698 ));
699 }
700 let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
701 let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
702 &vertex_buffer,
703 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
704 );
705 let index_data = if let Some(index_id) = mesh.index_buffer {
706 let index_buffer = hub.buffers.get(index_id).get()?;
707 if mesh.first_index.is_none()
708 || mesh.size.index_count.is_none()
709 || mesh.size.index_count.is_none()
710 {
711 return Err(BuildAccelerationStructureError::MissingAssociatedData(
712 index_buffer.error_ident(),
713 ));
714 }
715 let data = cmd_buf_data.trackers.buffers.set_single(
716 &index_buffer,
717 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
718 );
719 Some((index_buffer, data))
720 } else {
721 None
722 };
723 let transform_data = if let Some(transform_id) = mesh.transform_buffer {
724 if !blas
725 .flags
726 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
727 {
728 return Err(BuildAccelerationStructureError::UseTransformMissing(
729 blas.error_ident(),
730 ));
731 }
732 let transform_buffer = hub.buffers.get(transform_id).get()?;
733 if mesh.transform_buffer_offset.is_none() {
734 return Err(BuildAccelerationStructureError::MissingAssociatedData(
735 transform_buffer.error_ident(),
736 ));
737 }
738 let data = cmd_buf_data.trackers.buffers.set_single(
739 &transform_buffer,
740 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
741 );
742 Some((transform_buffer, data))
743 } else {
744 if blas
745 .flags
746 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
747 {
748 return Err(BuildAccelerationStructureError::TransformMissing(
749 blas.error_ident(),
750 ));
751 }
752 None
753 };
754 temp_buffer.push(TriangleBufferStore {
755 vertex_buffer,
756 vertex_transition: vertex_pending,
757 index_buffer_transition: index_data,
758 transform_buffer_transition: transform_data,
759 geometry: mesh,
760 ending_blas: None,
761 });
762 }
763
764 if let Some(last) = temp_buffer.last_mut() {
765 last.ending_blas = Some(blas);
766 buf_storage.append(&mut temp_buffer);
767 }
768 }
769 }
770 }
771 Ok(())
772}
773
/// Second BLAS pass: snatches raw buffer handles, validates buffer usages
/// and sizes, emits input barriers and memory-init actions, and converts
/// each run of `TriangleBufferStore`s (ended by `ending_blas`) into one
/// `BlasStore` with its scratch-buffer offset.
///
/// `scratch_buffer_blas_size` accumulates the aligned scratch requirement of
/// every BLAS.
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    // Geometry entries accumulated for the BLAS currently being assembled.
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            // Emit the pending tracker transition (if any) as an input barrier.
            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // The buffer must cover (first_vertex + vertex_count) strides.
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            // NOTE(review): this re-fetches the buffer through the hub even
            // though `buf.vertex_buffer` already holds the Arc — an extra
            // lookup and error path; confirm whether `&buf.vertex_buffer`
            // was intended.
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        // The unwraps on first_index/index_count/index_format below assume
        // `iter_blas` validated their presence — TODO confirm index_format
        // is always Some when an index buffer is supplied.
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            // Triangle lists require a multiple of 3 indices.
            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            // 48 bytes — presumably a 3x4 f32 transform matrix; TODO confirm.
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    // u32 arithmetic: first_index * index_stride.
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    // NOTE(review): `as u32` silently truncates a u64 offset
                    // > u32::MAX — confirm an upstream limit makes this safe.
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        // End of a BLAS's geometry run: flush the accumulated entries into a
        // `BlasStore` and reserve its aligned slice of the scratch buffer.
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}
950
/// Converts a validated `BlasStore` into a hal build descriptor targeting the
/// shared scratch buffer.
///
/// Rejects building into a BLAS that has already been compacted, and — when
/// the BLAS allows compaction — registers its (readback buffer, raw AS) pair
/// in `blases_compactable` so `build_blas` can query the compacted size.
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
    blases_compactable: &mut Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    // PreferUpdate is accepted but downgraded to a full rebuild.
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::info!("only rebuild implemented")
    }
    let raw = blas.try_raw(snatch_guard)?;

    let state_lock = blas.compacted_state.lock();
    if let BlasCompactState::Compacted = *state_lock {
        return Err(BuildAccelerationStructureError::CompactedBlas(
            blas.error_ident(),
        ));
    }

    if blas
        .flags
        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
    {
        // assumes `compaction_buffer` is always allocated for BLASes created
        // with ALLOW_COMPACTION — TODO confirm at BLAS creation.
        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: raw,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}
1000
/// Encodes the BLAS phase: input-buffer barriers, the BLAS builds themselves,
/// the scratch-buffer barrier separating BLAS from TLAS work, compaction-size
/// readbacks, and a final acceleration-structure barrier whose usages are
/// accumulated from what was actually encoded.
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    // Transition all vertex/index/transform buffers to BLAS-input usage.
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    // The scratch buffer is reused for the TLAS builds; serialize the two
    // phases with a self-transition barrier.
    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    // Accumulate the usages for the final barrier based on what was encoded.
    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        // Self-transition on the readback buffer before querying the
        // compacted size into it.
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}
1072}