wgpu_core/command/ray_tracing.rs

use alloc::{boxed::Box, sync::Arc, vec::Vec};
use core::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
};

use wgt::{math::align_to, BufferUsages, BufferUses, Features};

use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
use crate::{
    command::CommandBufferMutable,
    device::queue::TempResource,
    global::Global,
    hub::Hub,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
        TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
        TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
    },
    resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
    track::PendingTransition,
};
use crate::{command::EncoderStateError, device::resource::CommandIndices};
use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};

use crate::id::{BlasId, TlasId};

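/// Per-geometry data gathered while iterating a BLAS build entry: the referenced buffers,
/// any pending state transitions for them, the triangle geometry itself, and, for the last
/// geometry of an entry, the BLAS it completes.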
struct TriangleBufferStore<'a> {
    vertex_buffer: Arc<Buffer>,
    vertex_transition: Option<PendingTransition<BufferUses>>,
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    geometry: BlasTriangleGeometry<'a>,
    ending_blas: Option<Arc<Blas>>,
}

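/// A BLAS together with its HAL build entries and its offset into the shared scratch buffer.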
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

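/// A TLAS together with its HAL build entries and its offset into the shared scratch buffer.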
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

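/// An [`UnsafeTlasStore`] plus the byte range of this TLAS's instance data within the
/// CPU-side instance staging vector.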
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

impl Global {
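    /// Records that the given acceleration structures should be treated as already built,
    /// without encoding any build commands. TLAS entries are recorded with no BLAS
    /// dependencies.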
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let mut cmd_buf_data = cmd_buf.data.lock();
        cmd_buf_data.record_with(
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_buf.device;
                device.check_is_valid()?;
                device
                    .require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

                let mut build_command = AsBuild::default();

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

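    /// Validates and encodes the BLAS and TLAS builds described by `blas_iter` and
    /// `tlas_iter`: input buffers are transitioned, TLAS instance data is uploaded through a
    /// staging buffer, a shared scratch buffer is allocated, and an [`AsAction::Build`] is
    /// recorded for submission-time validation.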
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let mut build_command = AsBuild::default();

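        // Collect both iterators into owned `Trace*` structures first; these can be stored in
        // the trace (when the `trace` feature is enabled) and then re-borrowed below to drive
        // the actual build.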
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

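        // Re-borrow the owned trace data as the borrowed entry/package types used by the
        // validation and encoding code below.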
        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
            let instances = tlas_package.instances.iter().map(|instance| {
                instance.as_ref().map(|instance| TlasInstance {
                    blas_id: instance.blas_id,
                    transform: &instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                })
            });
            TlasPackage {
                tlas_id: tlas_package.tlas_id,
                instances: Box::new(instances),
                lowest_unmodified: tlas_package.lowest_unmodified,
            }
        });

        let mut cmd_buf_data = cmd_buf.data.lock();
        cmd_buf_data.record_with(|cmd_buf_data| {
            #[cfg(feature = "trace")]
            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::BuildAccelerationStructures {
                    blas: trace_blas.clone(),
                    tlas: trace_tlas.clone(),
                });
            }

            let device = &cmd_buf.device;
            device.check_is_valid()?;
            device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

            let mut buf_storage = Vec::new();
            iter_blas(
                blas_iter,
                cmd_buf_data,
                &mut build_command,
                &mut buf_storage,
                hub,
            )?;

            let snatch_guard = device.snatchable_lock.read();
            let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
            let mut scratch_buffer_blas_size = 0;
            let mut blas_storage = Vec::new();
            iter_buffers(
                &mut buf_storage,
                &snatch_guard,
                &mut input_barriers,
                cmd_buf_data,
                &mut scratch_buffer_blas_size,
                &mut blas_storage,
                hub,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            )?;
            let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

            for package in tlas_iter {
                let tlas = hub.tlas_s.get(package.tlas_id).get()?;

                cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

                tlas_lock_store.push((Some(package), tlas))
            }

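            // For each TLAS package: reserve scratch space, convert every instance to HAL
            // instance bytes in a CPU-side staging vector, validate the custom index, the
            // vertex-return flag compatibility with each referenced BLAS, and the instance
            // count, and record the BLAS dependencies for submission-time validation.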
            let mut scratch_buffer_tlas_size = 0;
            let mut tlas_storage = Vec::<TlasStore>::new();
            let mut instance_buffer_staging_source = Vec::<u8>::new();

            for (package, tlas) in &mut tlas_lock_store {
                let package = package.take().unwrap();

                let scratch_buffer_offset = scratch_buffer_tlas_size;
                scratch_buffer_tlas_size += align_to(
                    tlas.size_info.build_scratch_size as u32,
                    device.alignments.ray_tracing_scratch_buffer_alignment,
                ) as u64;

                let first_byte_index = instance_buffer_staging_source.len();

                let mut dependencies = Vec::new();

                let mut instance_count = 0;
                for instance in package.instances.flatten() {
                    if instance.custom_data >= (1u32 << 24u32) {
                        return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                            tlas.error_ident(),
                        ));
                    }
                    let blas = hub.blas_s.get(instance.blas_id).get()?;

                    cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

                    instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                        hal::TlasInstance {
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                            blas_address: blas.handle,
                        },
                    ));

                    if tlas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    ) && !blas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    ) {
                        return Err(
                            BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ),
                        );
                    }

                    instance_count += 1;

                    dependencies.push(blas.clone());
                }

                build_command.tlas_s_built.push(TlasBuild {
                    tlas: tlas.clone(),
                    dependencies,
                });

                if instance_count > tlas.max_instance_count {
                    return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                        tlas.error_ident(),
                        instance_count,
                        tlas.max_instance_count,
                    ));
                }

                tlas_storage.push(TlasStore {
                    internal: UnsafeTlasStore {
                        tlas: tlas.clone(),
                        entries: hal::AccelerationStructureEntries::Instances(
                            hal::AccelerationStructureInstances {
                                buffer: Some(tlas.instance_buffer.as_ref()),
                                offset: 0,
                                count: instance_count,
                            },
                        ),
                        scratch_buffer_offset,
                    },
                    range: first_byte_index..instance_buffer_staging_source.len(),
                });
            }

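            // The BLAS and TLAS builds share a single scratch buffer, so it only needs to be
            // as large as the bigger of the two totals: a scratch buffer barrier separates the
            // BLAS builds from the TLAS builds before the buffer is reused.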
            let Some(scratch_size) =
                wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
            else {
                // if the size is zero there is nothing to build
                return Ok(());
            };

            let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

            let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: scratch_buffer.raw(),
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                    to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                },
            };

            let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

            for &TlasStore {
                internal:
                    UnsafeTlasStore {
                        ref tlas,
                        ref entries,
                        ref scratch_buffer_offset,
                    },
                ..
            } in &tlas_storage
            {
                if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                    log::info!("only rebuild implemented")
                }
                tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                    entries,
                    mode: hal::AccelerationStructureBuildMode::Build,
                    flags: tlas.flags,
                    source_acceleration_structure: None,
                    destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                    scratch_buffer: scratch_buffer.raw(),
                    scratch_buffer_offset: *scratch_buffer_offset,
                })
            }

            let blas_present = !blas_storage.is_empty();
            let tlas_present = !tlas_storage.is_empty();

            let cmd_buf_raw = cmd_buf_data.encoder.open()?;

            let mut blas_s_compactable = Vec::new();
            let mut descriptors = Vec::new();

            for storage in &blas_storage {
                descriptors.push(map_blas(
                    storage,
                    scratch_buffer.raw(),
                    &snatch_guard,
                    &mut blas_s_compactable,
                )?);
            }

            build_blas(
                cmd_buf_raw,
                blas_present,
                tlas_present,
                input_barriers,
                &descriptors,
                scratch_buffer_barrier,
                blas_s_compactable,
            );

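            // TLAS path: upload the CPU-side instance bytes through a staging buffer, copy
            // each TLAS's byte range into its instance buffer, then record the TLAS builds and
            // the barrier that makes the results visible to shaders.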
            if tlas_present {
                let staging_buffer = if !instance_buffer_staging_source.is_empty() {
                    let mut staging_buffer = StagingBuffer::new(
                        device,
                        wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
                    )?;
                    staging_buffer.write(&instance_buffer_staging_source);
                    let flushed = staging_buffer.flush();
                    Some(flushed)
                } else {
                    None
                };

                unsafe {
                    if let Some(ref staging_buffer) = staging_buffer {
                        cmd_buf_raw.transition_buffers(&[
                            hal::BufferBarrier::<dyn hal::DynBuffer> {
                                buffer: staging_buffer.raw(),
                                usage: hal::StateTransition {
                                    from: BufferUses::MAP_WRITE,
                                    to: BufferUses::COPY_SRC,
                                },
                            },
                        ]);
                    }
                }

                let mut instance_buffer_barriers = Vec::new();
                for &TlasStore {
                    internal: UnsafeTlasStore { ref tlas, .. },
                    ref range,
                } in &tlas_storage
                {
                    let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                        None => continue,
                        Some(size) => size,
                    };
                    instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: tlas.instance_buffer.as_ref(),
                        usage: hal::StateTransition {
                            from: BufferUses::COPY_DST,
                            to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        },
                    });
                    unsafe {
                        cmd_buf_raw.transition_buffers(&[
                            hal::BufferBarrier::<dyn hal::DynBuffer> {
                                buffer: tlas.instance_buffer.as_ref(),
                                usage: hal::StateTransition {
                                    from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                                    to: BufferUses::COPY_DST,
                                },
                            },
                        ]);
                        let temp = hal::BufferCopy {
                            src_offset: range.start as u64,
                            dst_offset: 0,
                            size,
                        };
                        cmd_buf_raw.copy_buffer_to_buffer(
                            // The range's size was just checked to be non-zero, and its end was
                            // instance_buffer_staging_source.len() when the range was recorded;
                            // since the staging source never shrinks it is non-empty, so the
                            // staging buffer exists and this unwrap cannot panic.
                            staging_buffer.as_ref().unwrap().raw(),
                            tlas.instance_buffer.as_ref(),
                            &[temp],
                        );
                    }
                }

                unsafe {
                    cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

                    cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                    cmd_buf_raw.place_acceleration_structure_barrier(
                        hal::AccelerationStructureBarrier {
                            usage: hal::StateTransition {
                                from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                                to: hal::AccelerationStructureUses::SHADER_INPUT,
                            },
                        },
                    );
                }

                if let Some(staging_buffer) = staging_buffer {
                    cmd_buf_data
                        .temp_resources
                        .push(TempResource::StagingBuffer(staging_buffer));
                }
            }

            cmd_buf_data
                .temp_resources
                .push(TempResource::ScratchBuffer(scratch_buffer));

            cmd_buf_data.as_actions.push(AsAction::Build(build_command));

            Ok(())
        })
    }
}

impl CommandBufferMutable {
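    /// Validates recorded acceleration-structure actions at submission time: assigns build
    /// indices, resets BLAS compaction state, and checks that every TLAS (and every BLAS it
    /// depends on) has been built, with no BLAS built after the TLAS that uses it.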
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // Reset the compacted state to idle. This causes any pending
                            // compaction prepares to terminate before mapping their internal
                            // buffer.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}

/// Iterates over the BLAS build entries and their geometry, pushing the referenced buffers into a storage vector (and performing some validation).
fn iter_blas<'a>(
    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
    cmd_buf_data: &mut CommandBufferMutable,
    build_command: &mut AsBuild,
    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
    hub: &Hub,
) -> Result<(), BuildAccelerationStructureError> {
    let mut temp_buffer = Vec::new();
    for entry in blas_iter {
        let blas = hub.blas_s.get(entry.blas_id).get()?;
        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

        build_command.blas_s_built.push(blas.clone());

        match entry.geometries {
            BlasGeometries::TriangleGeometries(triangle_geometries) => {
                for (i, mesh) in triangle_geometries.enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    if size_desc
                        .vertex_format
                        .min_acceleration_structure_vertex_stride()
                        > mesh.vertex_stride
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .min_acceleration_structure_vertex_stride(),
                            mesh.vertex_stride,
                        ));
                    }

                    if mesh.vertex_stride
                        % size_desc
                            .vertex_format
                            .acceleration_structure_stride_alignment()
                        != 0
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .acceleration_structure_stride_alignment(),
                            mesh.vertex_stride,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }
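                    // Size and format validation passed; fetch the referenced buffers and
                    // record the state transitions needed to use them as BLAS build inputs.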
                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let index_data = if let Some(index_id) = mesh.index_buffer {
                        let index_buffer = hub.buffers.get(index_id).get()?;
                        if mesh.first_index.is_none()
                            || mesh.size.index_count.is_none()
                            || mesh.size.index_count.is_none()
                        {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((index_buffer, data))
                    } else {
                        None
                    };
                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
                        if !blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::UseTransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        let transform_buffer = hub.buffers.get(transform_id).get()?;
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((transform_buffer, data))
                    } else {
                        if blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::TransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        None
                    };
                    temp_buffer.push(TriangleBufferStore {
                        vertex_buffer,
                        vertex_transition: vertex_pending,
                        index_buffer_transition: index_data,
                        transform_buffer_transition: transform_data,
                        geometry: mesh,
                        ending_blas: None,
                    });
                }

                if let Some(last) = temp_buffer.last_mut() {
                    last.ending_blas = Some(blas);
                    buf_storage.append(&mut temp_buffer);
                }
            }
        }
    }
    Ok(())
}

/// Iterates over the buffers generated in [iter_blas], converting the pending barriers into hal barriers and the triangles into [hal::AccelerationStructureEntries] (and performing some validation).
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

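        // Assemble the HAL triangle geometry for this entry; entries accumulate until the
        // geometry that completes a BLAS is reached, at which point they are moved into a
        // `BlasStore` along with the BLAS's scratch buffer offset.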
        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}

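/// Turns a [`BlasStore`] into a HAL build descriptor, rejecting BLASes that have already been
/// compacted and collecting compaction-size readback pairs for BLASes built with
/// `ALLOW_COMPACTION`.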
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
    blases_compactable: &mut Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::info!("only rebuild implemented")
    }
    let raw = blas.try_raw(snatch_guard)?;

    let state_lock = blas.compacted_state.lock();
    if let BlasCompactState::Compacted = *state_lock {
        return Err(BuildAccelerationStructureError::CompactedBlas(
            blas.error_ident(),
        ));
    }

    if blas
        .flags
        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
    {
        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: raw,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}

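/// Encodes the BLAS builds: applies the input buffer barriers, builds the BLASes, inserts the
/// scratch buffer barrier when TLAS builds follow, reads back compacted sizes for compactable
/// BLASes, and places the final acceleration-structure barrier.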
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}
1072}