// wgpu_core/command/ray_tracing.rs

1use alloc::{sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11    command::CommandBufferMutable,
12    device::queue::TempResource,
13    global::Global,
14    id::CommandEncoderId,
15    init_tracker::MemoryInitKind,
16    ray_tracing::{
17        BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError, TlasPackage,
18        TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
19        TraceTlasPackage,
20    },
21    resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
22    scratch::ScratchBuffer,
23    snatch::SnatchGuard,
24    track::PendingTransition,
25};
26use crate::{command::EncoderStateError, device::resource::CommandIndices};
27use crate::{
28    command::{encoder::EncodingState, ArcCommand},
29    ray_tracing::{
30        ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
31        ArcTlasPackage, AsAction, AsBuild, BlasTriangleGeometryInfo, TlasBuild,
32        ValidateAsActionsError,
33    },
34    resource::InvalidResourceError,
35    track::Tracker,
36};
37use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
38
39use crate::id::{BlasId, TlasId};
40
41struct TriangleBufferStore {
42    vertex_buffer: Arc<Buffer>,
43    vertex_transition: Option<PendingTransition<BufferUses>>,
44    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
45    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
46    geometry: BlasTriangleGeometryInfo,
47    ending_blas: Option<Arc<Blas>>,
48}
49
50struct BlasStore<'a> {
51    blas: Arc<Blas>,
52    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
53    scratch_buffer_offset: u64,
54}
55
56struct UnsafeTlasStore<'a> {
57    tlas: Arc<Tlas>,
58    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
59    scratch_buffer_offset: u64,
60}
61
62struct TlasStore<'a> {
63    internal: UnsafeTlasStore<'a>,
64    range: Range<usize>,
65}
66
67impl Global {
68    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
69        self.hub.blas_s.get(blas_id).get()
70    }
71
72    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
73        self.hub.tlas_s.get(tlas_id).get()
74    }
75
76    pub fn command_encoder_mark_acceleration_structures_built(
77        &self,
78        command_encoder_id: CommandEncoderId,
79        blas_ids: &[BlasId],
80        tlas_ids: &[TlasId],
81    ) -> Result<(), EncoderStateError> {
82        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
83
84        let hub = &self.hub;
85
86        let cmd_enc = hub.command_encoders.get(command_encoder_id);
87
88        let mut cmd_buf_data = cmd_enc.data.lock();
89        cmd_buf_data.with_buffer(
90            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
91                let device = &cmd_enc.device;
92                device.check_is_valid()?;
93                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
94
95                let mut build_command = AsBuild::default();
96
97                for blas in blas_ids {
98                    let blas = hub.blas_s.get(*blas).get()?;
99                    build_command.blas_s_built.push(blas);
100                }
101
102                for tlas in tlas_ids {
103                    let tlas = hub.tlas_s.get(*tlas).get()?;
104                    build_command.tlas_s_built.push(TlasBuild {
105                        tlas,
106                        dependencies: Vec::new(),
107                    });
108                }
109
110                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
111                Ok(())
112            },
113        )
114    }
115
116    pub fn command_encoder_build_acceleration_structures<'a>(
117        &self,
118        command_encoder_id: CommandEncoderId,
119        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
120        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
121    ) -> Result<(), EncoderStateError> {
122        profiling::scope!("CommandEncoder::build_acceleration_structures");
123
124        let hub = &self.hub;
125
126        let cmd_enc = hub.command_encoders.get(command_encoder_id);
127
128        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
129            .map(|blas_entry| {
130                let geometries = match blas_entry.geometries {
131                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
132                        TraceBlasGeometries::TriangleGeometries(
133                            triangle_geometries
134                                .map(|tg| TraceBlasTriangleGeometry {
135                                    size: tg.size.clone(),
136                                    vertex_buffer: tg.vertex_buffer,
137                                    index_buffer: tg.index_buffer,
138                                    transform_buffer: tg.transform_buffer,
139                                    first_vertex: tg.first_vertex,
140                                    vertex_stride: tg.vertex_stride,
141                                    first_index: tg.first_index,
142                                    transform_buffer_offset: tg.transform_buffer_offset,
143                                })
144                                .collect(),
145                        )
146                    }
147                };
148                TraceBlasBuildEntry {
149                    blas_id: blas_entry.blas_id,
150                    geometries,
151                }
152            })
153            .collect();
154
155        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
156            .map(|package: TlasPackage| {
157                let instances = package
158                    .instances
159                    .map(|instance| {
160                        instance.map(|instance| TraceTlasInstance {
161                            blas_id: instance.blas_id,
162                            transform: *instance.transform,
163                            custom_data: instance.custom_data,
164                            mask: instance.mask,
165                        })
166                    })
167                    .collect();
168                TraceTlasPackage {
169                    tlas_id: package.tlas_id,
170                    instances,
171                    lowest_unmodified: package.lowest_unmodified,
172                }
173            })
174            .collect();
175
176        let mut cmd_buf_data = cmd_enc.data.lock();
177
178        #[cfg(feature = "trace")]
179        if let Some(ref mut list) = cmd_buf_data.trace() {
180            list.push(crate::command::Command::BuildAccelerationStructures {
181                blas: trace_blas.clone(),
182                tlas: trace_tlas.clone(),
183            });
184        }
185
186        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
187            let blas = trace_blas
188                .iter()
189                .map(|blas_entry| {
190                    let geometries = match &blas_entry.geometries {
191                        TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
192                            let tri_geo = triangle_geometries
193                                .iter()
194                                .map(|tg| {
195                                    Ok(ArcBlasTriangleGeometry {
196                                        size: tg.size.clone(),
197                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
198                                        index_buffer: tg
199                                            .index_buffer
200                                            .map(|id| self.resolve_buffer_id(id))
201                                            .transpose()?,
202                                        transform_buffer: tg
203                                            .transform_buffer
204                                            .map(|id| self.resolve_buffer_id(id))
205                                            .transpose()?,
206                                        first_vertex: tg.first_vertex,
207                                        vertex_stride: tg.vertex_stride,
208                                        first_index: tg.first_index,
209                                        transform_buffer_offset: tg.transform_buffer_offset,
210                                    })
211                                })
212                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
213                            ArcBlasGeometries::TriangleGeometries(tri_geo)
214                        }
215                    };
216                    Ok(ArcBlasBuildEntry {
217                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
218                        geometries,
219                    })
220                })
221                .collect::<Result<_, BuildAccelerationStructureError>>()?;
222
223            let tlas = trace_tlas
224                .iter()
225                .map(|tlas_package| {
226                    let instances = tlas_package
227                        .instances
228                        .iter()
229                        .map(|instance| {
230                            instance
231                                .as_ref()
232                                .map(|instance| {
233                                    Ok(ArcTlasInstance {
234                                        blas: self.resolve_blas_id(instance.blas_id)?,
235                                        transform: instance.transform,
236                                        custom_data: instance.custom_data,
237                                        mask: instance.mask,
238                                    })
239                                })
240                                .transpose()
241                        })
242                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
243                    Ok(ArcTlasPackage {
244                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
245                        instances,
246                        lowest_unmodified: tlas_package.lowest_unmodified,
247                    })
248                })
249                .collect::<Result<_, BuildAccelerationStructureError>>()?;
250
251            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
252        })
253    }
254}
255
256pub(crate) fn build_acceleration_structures(
257    state: &mut EncodingState,
258    blas: Vec<ArcBlasBuildEntry>,
259    tlas: Vec<ArcTlasPackage>,
260) -> Result<(), BuildAccelerationStructureError> {
261    state
262        .device
263        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
264
265    let mut build_command = AsBuild::default();
266    let mut buf_storage = Vec::new();
267    iter_blas(
268        blas.into_iter(),
269        state.tracker,
270        &mut build_command,
271        &mut buf_storage,
272    )?;
273
274    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
275    let mut scratch_buffer_blas_size = 0;
276    let mut blas_storage = Vec::new();
277    iter_buffers(
278        state,
279        &mut buf_storage,
280        &mut input_barriers,
281        &mut scratch_buffer_blas_size,
282        &mut blas_storage,
283    )?;
284    let mut tlas_lock_store = Vec::<(Option<ArcTlasPackage>, Arc<Tlas>)>::new();
285
286    for package in tlas.into_iter() {
287        let tlas = package.tlas.clone();
288        state.tracker.tlas_s.insert_single(tlas.clone());
289        tlas_lock_store.push((Some(package), tlas))
290    }
291
292    let mut scratch_buffer_tlas_size = 0;
293    let mut tlas_storage = Vec::<TlasStore>::new();
294    let mut instance_buffer_staging_source = Vec::<u8>::new();
295
296    for (package, tlas) in &mut tlas_lock_store {
297        let package = package.take().unwrap();
298
299        let scratch_buffer_offset = scratch_buffer_tlas_size;
300        scratch_buffer_tlas_size += align_to(
301            tlas.size_info.build_scratch_size as u32,
302            state.device.alignments.ray_tracing_scratch_buffer_alignment,
303        ) as u64;
304
305        let first_byte_index = instance_buffer_staging_source.len();
306
307        let mut dependencies = Vec::new();
308
309        let mut instance_count = 0;
310        for instance in package.instances.into_iter().flatten() {
311            if instance.custom_data >= (1u32 << 24u32) {
312                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
313                    tlas.error_ident(),
314                ));
315            }
316            let blas = &instance.blas;
317            state.tracker.blas_s.insert_single(blas.clone());
318
319            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
320                hal::TlasInstance {
321                    transform: instance.transform,
322                    custom_data: instance.custom_data,
323                    mask: instance.mask,
324                    blas_address: blas.handle,
325                },
326            ));
327
328            if tlas
329                .flags
330                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
331                && !blas
332                    .flags
333                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
334            {
335                return Err(
336                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
337                        tlas.error_ident(),
338                        blas.error_ident(),
339                    ),
340                );
341            }
342
343            instance_count += 1;
344
345            dependencies.push(blas.clone());
346        }
347
348        build_command.tlas_s_built.push(TlasBuild {
349            tlas: tlas.clone(),
350            dependencies,
351        });
352
353        if instance_count > tlas.max_instance_count {
354            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
355                tlas.error_ident(),
356                instance_count,
357                tlas.max_instance_count,
358            ));
359        }
360
361        tlas_storage.push(TlasStore {
362            internal: UnsafeTlasStore {
363                tlas: tlas.clone(),
364                entries: hal::AccelerationStructureEntries::Instances(
365                    hal::AccelerationStructureInstances {
366                        buffer: Some(tlas.instance_buffer.as_ref()),
367                        offset: 0,
368                        count: instance_count,
369                    },
370                ),
371                scratch_buffer_offset,
372            },
373            range: first_byte_index..instance_buffer_staging_source.len(),
374        });
375    }
376
377    let Some(scratch_size) =
378        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
379    else {
380        // if the size is zero there is nothing to build
381        return Ok(());
382    };
383
384    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;
385
386    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
387        buffer: scratch_buffer.raw(),
388        usage: hal::StateTransition {
389            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
390            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
391        },
392    };
393
394    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());
395
396    for &TlasStore {
397        internal:
398            UnsafeTlasStore {
399                ref tlas,
400                ref entries,
401                ref scratch_buffer_offset,
402            },
403        ..
404    } in &tlas_storage
405    {
406        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
407            log::info!("only rebuild implemented")
408        }
409        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
410            entries,
411            mode: hal::AccelerationStructureBuildMode::Build,
412            flags: tlas.flags,
413            source_acceleration_structure: None,
414            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
415            scratch_buffer: scratch_buffer.raw(),
416            scratch_buffer_offset: *scratch_buffer_offset,
417        })
418    }
419
420    let blas_present = !blas_storage.is_empty();
421    let tlas_present = !tlas_storage.is_empty();
422
423    let raw_encoder = &mut state.raw_encoder;
424
425    let mut blas_s_compactable = Vec::new();
426    let mut descriptors = Vec::new();
427
428    for storage in &blas_storage {
429        descriptors.push(map_blas(
430            storage,
431            scratch_buffer.raw(),
432            state.snatch_guard,
433            &mut blas_s_compactable,
434        )?);
435    }
436
437    build_blas(
438        *raw_encoder,
439        blas_present,
440        tlas_present,
441        input_barriers,
442        &descriptors,
443        scratch_buffer_barrier,
444        blas_s_compactable,
445    );
446
447    if tlas_present {
448        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
449            let mut staging_buffer = StagingBuffer::new(
450                state.device,
451                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
452            )?;
453            staging_buffer.write(&instance_buffer_staging_source);
454            let flushed = staging_buffer.flush();
455            Some(flushed)
456        } else {
457            None
458        };
459
460        unsafe {
461            if let Some(ref staging_buffer) = staging_buffer {
462                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
463                    buffer: staging_buffer.raw(),
464                    usage: hal::StateTransition {
465                        from: BufferUses::MAP_WRITE,
466                        to: BufferUses::COPY_SRC,
467                    },
468                }]);
469            }
470        }
471
472        let mut instance_buffer_barriers = Vec::new();
473        for &TlasStore {
474            internal: UnsafeTlasStore { ref tlas, .. },
475            ref range,
476        } in &tlas_storage
477        {
478            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
479                None => continue,
480                Some(size) => size,
481            };
482            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
483                buffer: tlas.instance_buffer.as_ref(),
484                usage: hal::StateTransition {
485                    from: BufferUses::COPY_DST,
486                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
487                },
488            });
489            unsafe {
490                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
491                    buffer: tlas.instance_buffer.as_ref(),
492                    usage: hal::StateTransition {
493                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
494                        to: BufferUses::COPY_DST,
495                    },
496                }]);
497                let temp = hal::BufferCopy {
498                    src_offset: range.start as u64,
499                    dst_offset: 0,
500                    size,
501                };
502                raw_encoder.copy_buffer_to_buffer(
503                    staging_buffer.as_ref().unwrap().raw(),
504                    tlas.instance_buffer.as_ref(),
505                    &[temp],
506                );
507            }
508        }
509
510        unsafe {
511            raw_encoder.transition_buffers(&instance_buffer_barriers);
512
513            raw_encoder.build_acceleration_structures(&tlas_descriptors);
514
515            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
516                usage: hal::StateTransition {
517                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
518                    to: hal::AccelerationStructureUses::SHADER_INPUT,
519                },
520            });
521        }
522
523        if let Some(staging_buffer) = staging_buffer {
524            state
525                .temp_resources
526                .push(TempResource::StagingBuffer(staging_buffer));
527        }
528    }
529
530    state
531        .temp_resources
532        .push(TempResource::ScratchBuffer(scratch_buffer));
533
534    state.as_actions.push(AsAction::Build(build_command));
535
536    Ok(())
537}
538
539impl CommandBufferMutable {
540    pub(crate) fn validate_acceleration_structure_actions(
541        &self,
542        snatch_guard: &SnatchGuard,
543        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
544    ) -> Result<(), ValidateAsActionsError> {
545        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
546        for action in &self.as_actions {
547            match action {
548                AsAction::Build(build) => {
549                    let build_command_index = NonZeroU64::new(
550                        command_index_guard.next_acceleration_structure_build_command_index,
551                    )
552                    .unwrap();
553
554                    command_index_guard.next_acceleration_structure_build_command_index += 1;
555                    for blas in build.blas_s_built.iter() {
556                        let mut state_lock = blas.compacted_state.lock();
557                        *state_lock = match *state_lock {
558                            BlasCompactState::Compacted => {
559                                unreachable!("Should be validated out in build.")
560                            }
561                            // Reset the compacted state to idle. This means any prepares, before mapping their
562                            // internal buffer, will terminate.
563                            _ => BlasCompactState::Idle,
564                        };
565                        *blas.built_index.write() = Some(build_command_index);
566                    }
567
568                    for tlas_build in build.tlas_s_built.iter() {
569                        for blas in &tlas_build.dependencies {
570                            if blas.built_index.read().is_none() {
571                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
572                                    blas.error_ident(),
573                                    tlas_build.tlas.error_ident(),
574                                ));
575                            }
576                        }
577                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
578                        tlas_build
579                            .tlas
580                            .dependencies
581                            .write()
582                            .clone_from(&tlas_build.dependencies)
583                    }
584                }
585                AsAction::UseTlas(tlas) => {
586                    let tlas_build_index = tlas.built_index.read();
587                    let dependencies = tlas.dependencies.read();
588
589                    if (*tlas_build_index).is_none() {
590                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
591                    }
592                    for blas in dependencies.deref() {
593                        let blas_build_index = *blas.built_index.read();
594                        if blas_build_index.is_none() {
595                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
596                                tlas.error_ident(),
597                                blas.error_ident(),
598                            ));
599                        }
600                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
601                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
602                                blas.error_ident(),
603                                tlas.error_ident(),
604                            ));
605                        }
606                        blas.try_raw(snatch_guard)?;
607                    }
608                }
609            }
610        }
611        Ok(())
612    }
613}
614
/// Iterates over the BLAS build entries and their geometries, validating them and pushing their buffers into a storage vector.
616fn iter_blas(
617    blas_iter: impl Iterator<Item = ArcBlasBuildEntry>,
618    tracker: &mut Tracker,
619    build_command: &mut AsBuild,
620    buf_storage: &mut Vec<TriangleBufferStore>,
621) -> Result<(), BuildAccelerationStructureError> {
622    let mut temp_buffer = Vec::new();
623    for entry in blas_iter {
624        let blas = &entry.blas;
625        tracker.blas_s.insert_single(blas.clone());
626
627        build_command.blas_s_built.push(blas.clone());
628
629        match entry.geometries {
630            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
631                for (i, mesh) in triangle_geometries.into_iter().enumerate() {
632                    let size_desc = match &blas.sizes {
633                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
634                    };
635                    if i >= size_desc.len() {
636                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
637                            blas.error_ident(),
638                        ));
639                    }
640                    let size_desc = &size_desc[i];
641
642                    if size_desc.flags != mesh.size.flags {
643                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
644                            blas.error_ident(),
645                            size_desc.flags,
646                            mesh.size.flags,
647                        ));
648                    }
649
650                    if size_desc.vertex_count < mesh.size.vertex_count {
651                        return Err(
652                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
653                                blas.error_ident(),
654                                size_desc.vertex_count,
655                                mesh.size.vertex_count,
656                            ),
657                        );
658                    }
659
660                    if size_desc.vertex_format != mesh.size.vertex_format {
661                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
662                            blas.error_ident(),
663                            size_desc.vertex_format,
664                            mesh.size.vertex_format,
665                        ));
666                    }
667
668                    if size_desc
669                        .vertex_format
670                        .min_acceleration_structure_vertex_stride()
671                        > mesh.vertex_stride
672                    {
673                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
674                            blas.error_ident(),
675                            size_desc
676                                .vertex_format
677                                .min_acceleration_structure_vertex_stride(),
678                            mesh.vertex_stride,
679                        ));
680                    }
681
682                    if mesh.vertex_stride
683                        % size_desc
684                            .vertex_format
685                            .acceleration_structure_stride_alignment()
686                        != 0
687                    {
688                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
689                            blas.error_ident(),
690                            size_desc
691                                .vertex_format
692                                .acceleration_structure_stride_alignment(),
693                            mesh.vertex_stride,
694                        ));
695                    }
696
697                    match (size_desc.index_count, mesh.size.index_count) {
698                        (Some(_), None) | (None, Some(_)) => {
699                            return Err(
700                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
701                                    blas.error_ident(),
702                                ),
703                            )
704                        }
705                        (Some(create), Some(build)) if create < build => {
706                            return Err(
707                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
708                                    blas.error_ident(),
709                                    create,
710                                    build,
711                                ),
712                            )
713                        }
714                        _ => {}
715                    }
716
717                    if size_desc.index_format != mesh.size.index_format {
718                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
719                            blas.error_ident(),
720                            size_desc.index_format,
721                            mesh.size.index_format,
722                        ));
723                    }
724
725                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
726                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
727                            blas.error_ident(),
728                        ));
729                    }
730                    let vertex_buffer = mesh.vertex_buffer.clone();
731                    let vertex_pending = tracker.buffers.set_single(
732                        &vertex_buffer,
733                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
734                    );
735                    let index_data = if let Some(index_buffer) = mesh.index_buffer {
736                        if mesh.first_index.is_none()
737                            || mesh.size.index_count.is_none()
738                            || mesh.size.index_count.is_none()
739                        {
740                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
741                                index_buffer.error_ident(),
742                            ));
743                        }
744                        let data = tracker.buffers.set_single(
745                            &index_buffer,
746                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
747                        );
748                        Some((index_buffer, data))
749                    } else {
750                        None
751                    };
752                    let transform_data = if let Some(transform_buffer) = mesh.transform_buffer {
753                        if !blas
754                            .flags
755                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
756                        {
757                            return Err(BuildAccelerationStructureError::UseTransformMissing(
758                                blas.error_ident(),
759                            ));
760                        }
761                        if mesh.transform_buffer_offset.is_none() {
762                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
763                                transform_buffer.error_ident(),
764                            ));
765                        }
766                        let data = tracker.buffers.set_single(
767                            &transform_buffer,
768                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
769                        );
770                        Some((transform_buffer, data))
771                    } else {
772                        if blas
773                            .flags
774                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
775                        {
776                            return Err(BuildAccelerationStructureError::TransformMissing(
777                                blas.error_ident(),
778                            ));
779                        }
780                        None
781                    };
782                    temp_buffer.push(TriangleBufferStore {
783                        vertex_buffer,
784                        vertex_transition: vertex_pending,
785                        index_buffer_transition: index_data,
786                        transform_buffer_transition: transform_data,
787                        geometry: BlasTriangleGeometryInfo {
788                            size: mesh.size,
789                            first_vertex: mesh.first_vertex,
790                            vertex_stride: mesh.vertex_stride,
791                            first_index: mesh.first_index,
792                            transform_buffer_offset: mesh.transform_buffer_offset,
793                        },
794                        ending_blas: None,
795                    });
796                }
797
798                if let Some(last) = temp_buffer.last_mut() {
799                    last.ending_blas = Some(blas.clone());
800                    buf_storage.append(&mut temp_buffer);
801                }
802            }
803        }
804    }
805    Ok(())
806}
807
808/// Iterates over the buffers generated in [iter_blas], convert the barriers into hal barriers, and the triangles into [hal::AccelerationStructureEntries] (and also some validation).
809///
810/// `'buffers` is the lifetime of `&dyn hal::DynBuffer` in our working data,
811/// i.e., needs to span until `build_acceleration_structures` finishes encoding.
812/// `'snatch_guard` is the lifetime of the snatch lock acquisition.
813fn iter_buffers<'snatch_guard: 'buffers, 'buffers>(
814    state: &mut EncodingState<'snatch_guard, '_>,
815    buf_storage: &'buffers mut Vec<TriangleBufferStore>,
816    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
817    scratch_buffer_blas_size: &mut u64,
818    blas_storage: &mut Vec<BlasStore<'buffers>>,
819) -> Result<(), BuildAccelerationStructureError> {
820    let mut triangle_entries =
821        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
822    for buf in buf_storage {
823        let mesh = &buf.geometry;
824        let vertex_buffer = {
825            let vertex_raw = buf.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
826            let vertex_buffer = &buf.vertex_buffer;
827            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
828
829            if let Some(barrier) = buf
830                .vertex_transition
831                .take()
832                .map(|pending| pending.into_hal(buf.vertex_buffer.as_ref(), state.snatch_guard))
833            {
834                input_barriers.push(barrier);
835            }
836            if vertex_buffer.size
837                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
838            {
839                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
840                    vertex_buffer.error_ident(),
841                    vertex_buffer.size,
842                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
843                ));
844            }
845            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
846            state.buffer_memory_init_actions.extend(
847                vertex_buffer.initialization_status.read().create_action(
848                    vertex_buffer,
849                    vertex_buffer_offset
850                        ..(vertex_buffer_offset
851                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
852                    MemoryInitKind::NeedsInitializedMemory,
853                ),
854            );
855            vertex_raw
856        };
857        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
858            buf.index_buffer_transition
859        {
860            let index_raw = index_buffer.try_raw(state.snatch_guard)?;
861            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
862
863            if let Some(barrier) = index_pending
864                .take()
865                .map(|pending| pending.into_hal(index_buffer, state.snatch_guard))
866            {
867                input_barriers.push(barrier);
868            }
869            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
870            let offset = mesh.first_index.unwrap() as u64 * index_stride;
871            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;
872
873            if mesh.size.index_count.unwrap() % 3 != 0 {
874                return Err(BuildAccelerationStructureError::InvalidIndexCount(
875                    index_buffer.error_ident(),
876                    mesh.size.index_count.unwrap(),
877                ));
878            }
879            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
880                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
881                    index_buffer.error_ident(),
882                    index_buffer.size,
883                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
884                ));
885            }
886
887            state.buffer_memory_init_actions.extend(
888                index_buffer.initialization_status.read().create_action(
889                    index_buffer,
890                    offset..(offset + index_buffer_size),
891                    MemoryInitKind::NeedsInitializedMemory,
892                ),
893            );
894            Some(index_raw)
895        } else {
896            None
897        };
898        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
899            buf.transform_buffer_transition
900        {
901            if mesh.transform_buffer_offset.is_none() {
902                return Err(BuildAccelerationStructureError::MissingAssociatedData(
903                    transform_buffer.error_ident(),
904                ));
905            }
906            let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
907            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
908
909            if let Some(barrier) = transform_pending
910                .take()
911                .map(|pending| pending.into_hal(transform_buffer, state.snatch_guard))
912            {
913                input_barriers.push(barrier);
914            }
915
916            let offset = mesh.transform_buffer_offset.unwrap();
917
918            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
919                return Err(
920                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
921                        transform_buffer.error_ident(),
922                    ),
923                );
924            }
925            if transform_buffer.size < 48 + offset {
926                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
927                    transform_buffer.error_ident(),
928                    transform_buffer.size,
929                    48 + offset,
930                ));
931            }
932            state.buffer_memory_init_actions.extend(
933                transform_buffer.initialization_status.read().create_action(
934                    transform_buffer,
935                    offset..(offset + 48),
936                    MemoryInitKind::NeedsInitializedMemory,
937                ),
938            );
939            Some(transform_raw)
940        } else {
941            None
942        };
943
944        let triangles = hal::AccelerationStructureTriangles {
945            vertex_buffer: Some(vertex_buffer),
946            vertex_format: mesh.size.vertex_format,
947            first_vertex: mesh.first_vertex,
948            vertex_count: mesh.size.vertex_count,
949            vertex_stride: mesh.vertex_stride,
950            indices: index_buffer.map(|index_buffer| {
951                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
952                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
953                    format: mesh.size.index_format.unwrap(),
954                    buffer: Some(index_buffer),
955                    offset: mesh.first_index.unwrap() * index_stride,
956                    count: mesh.size.index_count.unwrap(),
957                }
958            }),
959            transform: transform_buffer.map(|transform_buffer| {
960                hal::AccelerationStructureTriangleTransform {
961                    buffer: transform_buffer,
962                    offset: mesh.transform_buffer_offset.unwrap() as u32,
963                }
964            }),
965            flags: mesh.size.flags,
966        };
967        triangle_entries.push(triangles);
968        if let Some(blas) = buf.ending_blas.take() {
969            let scratch_buffer_offset = *scratch_buffer_blas_size;
970            *scratch_buffer_blas_size += align_to(
971                blas.size_info.build_scratch_size as u32,
972                state.device.alignments.ray_tracing_scratch_buffer_alignment,
973            ) as u64;
974
975            blas_storage.push(BlasStore {
976                blas,
977                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
978                scratch_buffer_offset,
979            });
980            triangle_entries = Vec::new();
981        }
982    }
983    Ok(())
984}
985
986fn map_blas<'a>(
987    storage: &'a BlasStore<'_>,
988    scratch_buffer: &'a dyn hal::DynBuffer,
989    snatch_guard: &'a SnatchGuard,
990    blases_compactable: &mut Vec<(
991        &'a dyn hal::DynBuffer,
992        &'a dyn hal::DynAccelerationStructure,
993    )>,
994) -> Result<
995    hal::BuildAccelerationStructureDescriptor<
996        'a,
997        dyn hal::DynBuffer,
998        dyn hal::DynAccelerationStructure,
999    >,
1000    BuildAccelerationStructureError,
1001> {
1002    let BlasStore {
1003        blas,
1004        entries,
1005        scratch_buffer_offset,
1006    } = storage;
1007    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1008        log::info!("only rebuild implemented")
1009    }
1010    let raw = blas.try_raw(snatch_guard)?;
1011
1012    let state_lock = blas.compacted_state.lock();
1013    if let BlasCompactState::Compacted = *state_lock {
1014        return Err(BuildAccelerationStructureError::CompactedBlas(
1015            blas.error_ident(),
1016        ));
1017    }
1018
1019    if blas
1020        .flags
1021        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1022    {
1023        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1024    }
1025    Ok(hal::BuildAccelerationStructureDescriptor {
1026        entries,
1027        mode: hal::AccelerationStructureBuildMode::Build,
1028        flags: blas.flags,
1029        source_acceleration_structure: None,
1030        destination_acceleration_structure: raw,
1031        scratch_buffer,
1032        scratch_buffer_offset: *scratch_buffer_offset,
1033    })
1034}
1035
1036fn build_blas<'a>(
1037    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1038    blas_present: bool,
1039    tlas_present: bool,
1040    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1041    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1042        'a,
1043        dyn hal::DynBuffer,
1044        dyn hal::DynAccelerationStructure,
1045    >],
1046    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1047    blas_s_for_compaction: Vec<(
1048        &'a dyn hal::DynBuffer,
1049        &'a dyn hal::DynAccelerationStructure,
1050    )>,
1051) {
1052    unsafe {
1053        cmd_buf_raw.transition_buffers(&input_barriers);
1054    }
1055
1056    if blas_present {
1057        unsafe {
1058            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1059                usage: hal::StateTransition {
1060                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1061                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1062                },
1063            });
1064
1065            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1066        }
1067    }
1068
1069    if blas_present && tlas_present {
1070        unsafe {
1071            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1072        }
1073    }
1074
1075    let mut source_usage = hal::AccelerationStructureUses::empty();
1076    let mut destination_usage = hal::AccelerationStructureUses::empty();
1077    for &(buf, blas) in blas_s_for_compaction.iter() {
1078        unsafe {
1079            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1080                buffer: buf,
1081                usage: hal::StateTransition {
1082                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1083                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1084                },
1085            }])
1086        }
1087        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1088        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1089    }
1090
1091    if blas_present {
1092        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1093        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1094    }
1095    if tlas_present {
1096        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1097        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1098    }
1099    unsafe {
1100        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1101            usage: hal::StateTransition {
1102                from: source_usage,
1103                to: destination_usage,
1104            },
1105        });
1106    }
1107}