1use alloc::{sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11 command::CommandBufferMutable,
12 device::queue::TempResource,
13 global::Global,
14 id::CommandEncoderId,
15 init_tracker::MemoryInitKind,
16 ray_tracing::{
17 BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError, TlasPackage,
18 TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
19 TraceTlasPackage,
20 },
21 resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
22 scratch::ScratchBuffer,
23 snatch::SnatchGuard,
24 track::PendingTransition,
25};
26use crate::{command::EncoderStateError, device::resource::CommandIndices};
27use crate::{
28 command::{encoder::EncodingState, ArcCommand},
29 ray_tracing::{
30 ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
31 ArcTlasPackage, AsAction, AsBuild, BlasTriangleGeometryInfo, TlasBuild,
32 ValidateAsActionsError,
33 },
34 resource::InvalidResourceError,
35 track::Tracker,
36};
37use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
38
39use crate::id::{BlasId, TlasId};
40
/// Staging record for one triangle geometry, produced by `iter_blas` and
/// consumed by `iter_buffers` to emit barriers and HAL geometry entries.
struct TriangleBufferStore {
    // Vertex buffer backing this geometry.
    vertex_buffer: Arc<Buffer>,
    // Pending tracker transition for the vertex buffer (taken in `iter_buffers`).
    vertex_transition: Option<PendingTransition<BufferUses>>,
    // Optional index buffer paired with its pending transition.
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    // Optional transform buffer paired with its pending transition.
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    // Size/offset description of this triangle geometry.
    geometry: BlasTriangleGeometryInfo,
    // Set on the last geometry of a BLAS; marks where that BLAS's geometry
    // list ends so `iter_buffers` can flush accumulated entries.
    ending_blas: Option<Arc<Blas>>,
}
49
/// A BLAS ready to build: its HAL geometry entries plus its byte offset into
/// the shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
55
/// A TLAS ready to build: its HAL instance entries plus its byte offset into
/// the shared scratch buffer.
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
61
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    // Byte range of this TLAS's instance data within the shared staging source.
    range: Range<usize>,
}
66
impl Global {
    /// Resolves a [`BlasId`] to its [`Blas`], failing on an invalid id.
    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
        self.hub.blas_s.get(blas_id).get()
    }

    /// Resolves a [`TlasId`] to its [`Tlas`], failing on an invalid id.
    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
        self.hub.tlas_s.get(tlas_id).get()
    }

    /// Marks the given acceleration structures as built without recording any
    /// build commands on the encoder.
    ///
    /// The TLAS entries are recorded with empty dependency lists; the caller
    /// asserts the structures were built elsewhere.
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.with_buffer(
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::default();

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        // No dependency tracking for externally-built TLASes.
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

    /// Records BLAS/TLAS builds onto a command encoder.
    ///
    /// Converts the borrowed entries into owned "trace" form (also captured
    /// for the `trace` feature), resolves all ids to `Arc` resources, and
    /// pushes an [`ArcCommand::BuildAccelerationStructures`] onto the
    /// encoder's command list.
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        // Owned form of the BLAS entries; doubles as the trace record and as
        // the source for id resolution below.
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        // Owned form of the TLAS packages, same dual purpose.
        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        let mut cmd_buf_data = cmd_enc.data.lock();

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf_data.trace() {
            list.push(crate::command::Command::BuildAccelerationStructures {
                blas: trace_blas.clone(),
                tlas: trace_tlas.clone(),
            });
        }

        // Resolve every id; any invalid id fails the whole command.
        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
            let blas = trace_blas
                .iter()
                .map(|blas_entry| {
                    let geometries = match &blas_entry.geometries {
                        TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                            let tri_geo = triangle_geometries
                                .iter()
                                .map(|tg| {
                                    Ok(ArcBlasTriangleGeometry {
                                        size: tg.size.clone(),
                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
                                        index_buffer: tg
                                            .index_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        transform_buffer: tg
                                            .transform_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        first_vertex: tg.first_vertex,
                                        vertex_stride: tg.vertex_stride,
                                        first_index: tg.first_index,
                                        transform_buffer_offset: tg.transform_buffer_offset,
                                    })
                                })
                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
                            ArcBlasGeometries::TriangleGeometries(tri_geo)
                        }
                    };
                    Ok(ArcBlasBuildEntry {
                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
                        geometries,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            let tlas = trace_tlas
                .iter()
                .map(|tlas_package| {
                    let instances = tlas_package
                        .instances
                        .iter()
                        .map(|instance| {
                            instance
                                .as_ref()
                                .map(|instance| {
                                    Ok(ArcTlasInstance {
                                        blas: self.resolve_blas_id(instance.blas_id)?,
                                        transform: instance.transform,
                                        custom_data: instance.custom_data,
                                        mask: instance.mask,
                                    })
                                })
                                .transpose()
                        })
                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
                    Ok(ArcTlasPackage {
                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
                        instances,
                        lowest_unmodified: tlas_package.lowest_unmodified,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
        })
    }
}
255
/// Records acceleration-structure builds into the encoder.
///
/// Pipeline: validate BLAS geometries (`iter_blas`), derive input barriers
/// and HAL entries (`iter_buffers`), validate TLAS instances, allocate one
/// shared scratch buffer sized for the larger of the two passes, build the
/// BLASes (`build_blas`), upload TLAS instance data via a staging buffer,
/// then build the TLASes. Temp resources are kept alive until execution.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<ArcBlasBuildEntry>,
    tlas: Vec<ArcTlasPackage>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::default();
    // Pass 1: validate BLAS entries and collect buffers with their pending
    // tracker transitions.
    let mut buf_storage = Vec::new();
    iter_blas(
        blas.into_iter(),
        state.tracker,
        &mut build_command,
        &mut buf_storage,
    )?;

    // Pass 2: turn collected buffers into input barriers, HAL triangle
    // entries, and per-BLAS scratch offsets.
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::new();
    iter_buffers(
        state,
        &mut buf_storage,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
    )?;
    let mut tlas_lock_store = Vec::<(Option<ArcTlasPackage>, Arc<Tlas>)>::new();

    for package in tlas.into_iter() {
        let tlas = package.tlas.clone();
        state.tracker.tlas_s.insert_single(tlas.clone());
        tlas_lock_store.push((Some(package), tlas))
    }

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::new();
    // Instance data for all TLASes, concatenated; each TLAS records its byte
    // range into this buffer.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for (package, tlas) in &mut tlas_lock_store {
        let package = package.take().unwrap();

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        // NOTE(review): `build_scratch_size` is truncated u64 -> u32 before
        // aligning — presumably scratch sizes never exceed u32::MAX; confirm.
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            state.device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.into_iter().flatten() {
            // The custom index must fit in 24 bits.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // Vertex return on the TLAS requires it on every referenced BLAS.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // A zero scratch requirement means there is nothing at all to build.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Orders the BLAS pass against the TLAS pass over the shared scratch.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        // PreferUpdate currently falls back to a full rebuild.
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::info!("only rebuild implemented")
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::new();

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload all TLAS instance data through a single staging buffer.
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // A TLAS with zero instances has nothing to copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Collected now, applied in one batch after the copy loop:
            // transitions the instance buffer back to TLAS input.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Make the instance buffer a copy destination, then copy its
                // slice of the staging data into it.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the GPU has consumed it.
        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
538
impl CommandBufferMutable {
    /// Validates recorded acceleration-structure actions at submission time.
    ///
    /// Assigns a monotonically increasing build index to every built
    /// BLAS/TLAS and checks that each TLAS (built or used) only references
    /// BLASes that have been built, and built no later than the TLAS itself.
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // The counter starts non-zero, so NonZeroU64 can hold it.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            // Building a compacted BLAS is rejected at
                            // encoding time (see `map_blas`).
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // A rebuild resets any compaction progress.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS rebuilt after the TLAS would make it stale.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}
614
615fn iter_blas(
617 blas_iter: impl Iterator<Item = ArcBlasBuildEntry>,
618 tracker: &mut Tracker,
619 build_command: &mut AsBuild,
620 buf_storage: &mut Vec<TriangleBufferStore>,
621) -> Result<(), BuildAccelerationStructureError> {
622 let mut temp_buffer = Vec::new();
623 for entry in blas_iter {
624 let blas = &entry.blas;
625 tracker.blas_s.insert_single(blas.clone());
626
627 build_command.blas_s_built.push(blas.clone());
628
629 match entry.geometries {
630 ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
631 for (i, mesh) in triangle_geometries.into_iter().enumerate() {
632 let size_desc = match &blas.sizes {
633 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
634 };
635 if i >= size_desc.len() {
636 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
637 blas.error_ident(),
638 ));
639 }
640 let size_desc = &size_desc[i];
641
642 if size_desc.flags != mesh.size.flags {
643 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
644 blas.error_ident(),
645 size_desc.flags,
646 mesh.size.flags,
647 ));
648 }
649
650 if size_desc.vertex_count < mesh.size.vertex_count {
651 return Err(
652 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
653 blas.error_ident(),
654 size_desc.vertex_count,
655 mesh.size.vertex_count,
656 ),
657 );
658 }
659
660 if size_desc.vertex_format != mesh.size.vertex_format {
661 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
662 blas.error_ident(),
663 size_desc.vertex_format,
664 mesh.size.vertex_format,
665 ));
666 }
667
668 if size_desc
669 .vertex_format
670 .min_acceleration_structure_vertex_stride()
671 > mesh.vertex_stride
672 {
673 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
674 blas.error_ident(),
675 size_desc
676 .vertex_format
677 .min_acceleration_structure_vertex_stride(),
678 mesh.vertex_stride,
679 ));
680 }
681
682 if mesh.vertex_stride
683 % size_desc
684 .vertex_format
685 .acceleration_structure_stride_alignment()
686 != 0
687 {
688 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
689 blas.error_ident(),
690 size_desc
691 .vertex_format
692 .acceleration_structure_stride_alignment(),
693 mesh.vertex_stride,
694 ));
695 }
696
697 match (size_desc.index_count, mesh.size.index_count) {
698 (Some(_), None) | (None, Some(_)) => {
699 return Err(
700 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
701 blas.error_ident(),
702 ),
703 )
704 }
705 (Some(create), Some(build)) if create < build => {
706 return Err(
707 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
708 blas.error_ident(),
709 create,
710 build,
711 ),
712 )
713 }
714 _ => {}
715 }
716
717 if size_desc.index_format != mesh.size.index_format {
718 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
719 blas.error_ident(),
720 size_desc.index_format,
721 mesh.size.index_format,
722 ));
723 }
724
725 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
726 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
727 blas.error_ident(),
728 ));
729 }
730 let vertex_buffer = mesh.vertex_buffer.clone();
731 let vertex_pending = tracker.buffers.set_single(
732 &vertex_buffer,
733 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
734 );
735 let index_data = if let Some(index_buffer) = mesh.index_buffer {
736 if mesh.first_index.is_none()
737 || mesh.size.index_count.is_none()
738 || mesh.size.index_count.is_none()
739 {
740 return Err(BuildAccelerationStructureError::MissingAssociatedData(
741 index_buffer.error_ident(),
742 ));
743 }
744 let data = tracker.buffers.set_single(
745 &index_buffer,
746 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
747 );
748 Some((index_buffer, data))
749 } else {
750 None
751 };
752 let transform_data = if let Some(transform_buffer) = mesh.transform_buffer {
753 if !blas
754 .flags
755 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
756 {
757 return Err(BuildAccelerationStructureError::UseTransformMissing(
758 blas.error_ident(),
759 ));
760 }
761 if mesh.transform_buffer_offset.is_none() {
762 return Err(BuildAccelerationStructureError::MissingAssociatedData(
763 transform_buffer.error_ident(),
764 ));
765 }
766 let data = tracker.buffers.set_single(
767 &transform_buffer,
768 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
769 );
770 Some((transform_buffer, data))
771 } else {
772 if blas
773 .flags
774 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
775 {
776 return Err(BuildAccelerationStructureError::TransformMissing(
777 blas.error_ident(),
778 ));
779 }
780 None
781 };
782 temp_buffer.push(TriangleBufferStore {
783 vertex_buffer,
784 vertex_transition: vertex_pending,
785 index_buffer_transition: index_data,
786 transform_buffer_transition: transform_data,
787 geometry: BlasTriangleGeometryInfo {
788 size: mesh.size,
789 first_vertex: mesh.first_vertex,
790 vertex_stride: mesh.vertex_stride,
791 first_index: mesh.first_index,
792 transform_buffer_offset: mesh.transform_buffer_offset,
793 },
794 ending_blas: None,
795 });
796 }
797
798 if let Some(last) = temp_buffer.last_mut() {
799 last.ending_blas = Some(blas.clone());
800 buf_storage.append(&mut temp_buffer);
801 }
802 }
803 }
804 }
805 Ok(())
806}
807
/// Consumes the per-geometry records collected by `iter_blas`: checks buffer
/// usages and sizes, queues memory-init actions and input barriers, and
/// groups the resulting HAL triangle entries into one [`BlasStore`] per BLAS
/// while accumulating the total BLAS scratch size.
fn iter_buffers<'snatch_guard: 'buffers, 'buffers>(
    state: &mut EncodingState<'snatch_guard, '_>,
    buf_storage: &'buffers mut Vec<TriangleBufferStore>,
    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'buffers>>,
) -> Result<(), BuildAccelerationStructureError> {
    // Triangle entries accumulated for the BLAS currently being assembled.
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_raw = buf.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
            let vertex_buffer = &buf.vertex_buffer;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            // Convert the pending tracker transition into an input barrier.
            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(buf.vertex_buffer.as_ref(), state.snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // The buffer must cover first_vertex + vertex_count vertices.
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            // The range read by the build must be initialized memory.
            state.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    vertex_buffer,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(state.snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, state.snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // NOTE(review): format/first_index/index_count are unwrapped
            // here; `iter_blas` is expected to have ensured they are present
            // whenever an index buffer is — verify that invariant holds.
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            // Triangle lists require a multiple of three indices.
            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            state.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, state.snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            // A 3x4 transform matrix occupies 48 bytes.
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            state.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        // `ending_blas` marks the last geometry of a BLAS: flush everything
        // accumulated so far into one BlasStore and reserve its scratch slot.
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            // NOTE(review): `build_scratch_size` is truncated u64 -> u32
            // before aligning — assumes scratch sizes fit in u32; confirm.
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                state.device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}
985
986fn map_blas<'a>(
987 storage: &'a BlasStore<'_>,
988 scratch_buffer: &'a dyn hal::DynBuffer,
989 snatch_guard: &'a SnatchGuard,
990 blases_compactable: &mut Vec<(
991 &'a dyn hal::DynBuffer,
992 &'a dyn hal::DynAccelerationStructure,
993 )>,
994) -> Result<
995 hal::BuildAccelerationStructureDescriptor<
996 'a,
997 dyn hal::DynBuffer,
998 dyn hal::DynAccelerationStructure,
999 >,
1000 BuildAccelerationStructureError,
1001> {
1002 let BlasStore {
1003 blas,
1004 entries,
1005 scratch_buffer_offset,
1006 } = storage;
1007 if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1008 log::info!("only rebuild implemented")
1009 }
1010 let raw = blas.try_raw(snatch_guard)?;
1011
1012 let state_lock = blas.compacted_state.lock();
1013 if let BlasCompactState::Compacted = *state_lock {
1014 return Err(BuildAccelerationStructureError::CompactedBlas(
1015 blas.error_ident(),
1016 ));
1017 }
1018
1019 if blas
1020 .flags
1021 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1022 {
1023 blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1024 }
1025 Ok(hal::BuildAccelerationStructureDescriptor {
1026 entries,
1027 mode: hal::AccelerationStructureBuildMode::Build,
1028 flags: blas.flags,
1029 source_acceleration_structure: None,
1030 destination_acceleration_structure: raw,
1031 scratch_buffer,
1032 scratch_buffer_offset: *scratch_buffer_offset,
1033 })
1034}
1035
/// Records the BLAS builds plus the barriers ordering them against input
/// uploads, the shared scratch buffer, compaction-size readbacks, and the
/// following TLAS pass.
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    // Make all geometry input buffers visible to the build.
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    // The scratch buffer is shared between the BLAS and TLAS passes, so
    // serialize them when both are present.
    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    // Usages accumulated for the single trailing AS barrier below.
    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        // Read the freshly built BLAS's compacted size into `buf`.
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        // NOTE(review): only the destination side gains COPY_SRC here; the
        // matching source usage comes from the `blas_present` branch below —
        // confirm that is the intended pairing.
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}
1107}