wgpu_hal/vulkan/command.rs

use super::conv;
use arrayvec::ArrayVec;
use ash::vk;
use core::{mem, ops::Range};
use hashbrown::hash_map::Entry;

const ALLOCATION_GRANULARITY: u32 = 16;
const DST_IMAGE_LAYOUT: vk::ImageLayout = vk::ImageLayout::TRANSFER_DST_OPTIMAL;

impl super::Texture {
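    /// Maps `crate::BufferTextureCopy` regions to `vk::BufferImageCopy`
    /// structs, clamping each copy's extent to the texture's maximum copy
    /// size and converting the buffer layout's byte strides into the
    /// texel-based `buffer_row_length`/`buffer_image_height` that Vulkan
    /// expects, accounting for block-compressed formats.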
    fn map_buffer_copies<T>(&self, regions: T) -> impl Iterator<Item = vk::BufferImageCopy>
    where
        T: Iterator<Item = crate::BufferTextureCopy>,
    {
        let (block_width, block_height) = self.format.block_dimensions();
        let format = self.format;
        let copy_size = self.copy_size;
        regions.map(move |r| {
            let extent = r.texture_base.max_copy_size(&copy_size).min(&r.size);
            let (image_subresource, image_offset) = conv::map_subresource_layers(&r.texture_base);
            vk::BufferImageCopy {
                buffer_offset: r.buffer_layout.offset,
                buffer_row_length: r.buffer_layout.bytes_per_row.map_or(0, |bpr| {
                    let block_size = format
                        .block_copy_size(Some(r.texture_base.aspect.map()))
                        .unwrap();
                    block_width * (bpr / block_size)
                }),
                buffer_image_height: r
                    .buffer_layout
                    .rows_per_image
                    .map_or(0, |rpi| rpi * block_height),
                image_subresource,
                image_offset,
                image_extent: conv::map_copy_extent(&extent),
            }
        })
    }
}

impl super::CommandEncoder {
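    /// Writes the end-of-pass timestamp recorded by `begin_render_pass` or
    /// `begin_compute_pass`, if one was requested, and clears the pending
    /// query so it is only written once.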
    fn write_pass_end_timestamp_if_requested(&mut self) {
        if let Some((query_set, index)) = self.end_of_pass_timer_query.take() {
            unsafe {
                self.device.raw.cmd_write_timestamp(
                    self.active,
                    vk::PipelineStageFlags::BOTTOM_OF_PIPE,
                    query_set,
                    index,
                );
            }
        }
    }

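    /// Returns the framebuffer for `key`, creating and caching it on first
    /// use. Subsequent calls with an identical key reuse the cached
    /// `vk::Framebuffer`.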
    fn make_framebuffer(
        &mut self,
        key: super::FramebufferKey,
    ) -> Result<vk::Framebuffer, crate::DeviceError> {
        Ok(match self.framebuffers.entry(key) {
            Entry::Occupied(e) => *e.get(),
            Entry::Vacant(e) => {
                let super::FramebufferKey {
                    raw_pass,
                    ref attachment_views,
                    attachment_identities: _,
                    extent,
                } = *e.key();

                let vk_info = vk::FramebufferCreateInfo::default()
                    .render_pass(raw_pass)
                    .width(extent.width)
                    .height(extent.height)
                    .layers(extent.depth_or_array_layers)
                    .attachments(attachment_views);

                let raw = unsafe { self.device.raw.create_framebuffer(&vk_info, None) }
                    .map_err(super::map_host_device_oom_err)?;
                *e.insert(raw)
            }
        })
    }

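    /// Returns a temporary 2D image view for `key`, creating and caching it
    /// on first use. Used by `begin_render_pass` to render to a single depth
    /// slice of a 3D texture.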
    fn make_temp_texture_view(
        &mut self,
        key: super::TempTextureViewKey,
    ) -> Result<super::IdentifiedTextureView, crate::DeviceError> {
        Ok(match self.temp_texture_views.entry(key) {
            Entry::Occupied(e) => *e.get(),
            Entry::Vacant(e) => {
                let super::TempTextureViewKey {
                    texture,
                    texture_identity: _,
                    format,
                    mip_level,
                    depth_slice,
                } = *e.key();

                let vk_info = vk::ImageViewCreateInfo::default()
                    .image(texture)
                    .view_type(vk::ImageViewType::TYPE_2D)
                    .format(format)
                    .subresource_range(vk::ImageSubresourceRange {
                        aspect_mask: vk::ImageAspectFlags::COLOR,
                        base_mip_level: mip_level,
                        level_count: 1,
                        base_array_layer: depth_slice,
                        layer_count: 1,
                    });
                let raw = unsafe { self.device.raw.create_image_view(&vk_info, None) }
                    .map_err(super::map_host_device_oom_and_ioca_err)?;

                let identity = self.device.texture_view_identity_factory.next();

                *e.insert(super::IdentifiedTextureView { raw, identity })
            }
        })
    }
}

impl crate::CommandEncoder for super::CommandEncoder {
    type A = super::Api;

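    /// Takes a command buffer from the free list, allocating
    /// `ALLOCATION_GRANULARITY` new buffers from the pool if the list is
    /// empty, and begins recording it with `ONE_TIME_SUBMIT`.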
    unsafe fn begin_encoding(&mut self, label: crate::Label) -> Result<(), crate::DeviceError> {
        if self.free.is_empty() {
            let vk_info = vk::CommandBufferAllocateInfo::default()
                .command_pool(self.raw)
                .command_buffer_count(ALLOCATION_GRANULARITY);
            let cmd_buf_vec = unsafe {
                self.device
                    .raw
                    .allocate_command_buffers(&vk_info)
                    .map_err(super::map_host_device_oom_err)?
            };
            self.free.extend(cmd_buf_vec);
        }
        let raw = self.free.pop().unwrap();

        // Set the name unconditionally, since a previous name may still be
        // assigned to this command buffer.
        unsafe { self.device.set_object_name(raw, label.unwrap_or_default()) };

        // Reset this in case the last render pass was never ended.
        self.rpass_debug_marker_active = false;

        let vk_info = vk::CommandBufferBeginInfo::default()
            .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT);
        unsafe { self.device.raw.begin_command_buffer(raw, &vk_info) }
            .map_err(super::map_host_device_oom_err)?;
        self.active = raw;

        Ok(())
    }

    unsafe fn end_encoding(&mut self) -> Result<super::CommandBuffer, crate::DeviceError> {
        let raw = self.active;
        self.active = vk::CommandBuffer::null();
        unsafe { self.device.raw.end_command_buffer(raw) }.map_err(map_err)?;
        fn map_err(err: vk::Result) -> crate::DeviceError {
            // We don't use VK_KHR_video_encode_queue, so
            // VK_ERROR_INVALID_VIDEO_STD_PARAMETERS_KHR cannot happen here.
            super::map_host_device_oom_err(err)
        }
        Ok(super::CommandBuffer { raw })
    }

    unsafe fn discard_encoding(&mut self) {
        // Safe use requires this is not called in the "closed" state, so the
        // buffer shouldn't be null. Assert this to make sure we're not pushing
        // null buffers to the discard pile.
        assert_ne!(self.active, vk::CommandBuffer::null());

        self.discarded.push(self.active);
        self.active = vk::CommandBuffer::null();
    }

    unsafe fn reset_all<I>(&mut self, cmd_bufs: I)
    where
        I: Iterator<Item = super::CommandBuffer>,
    {
        self.temp.clear();
        self.free.extend(cmd_bufs.map(|cmd_buf| cmd_buf.raw));
        self.free.append(&mut self.discarded);
        let _ = unsafe {
            self.device
                .raw
                .reset_command_pool(self.raw, vk::CommandPoolResetFlags::default())
        };
    }

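    /// Records a `vkCmdPipelineBarrier` covering the given buffer barriers.
    /// The source and destination stage masks start at TOP_OF_PIPE and
    /// BOTTOM_OF_PIPE respectively so they can never end up empty.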
    unsafe fn transition_buffers<'a, T>(&mut self, barriers: T)
    where
        T: Iterator<Item = crate::BufferBarrier<'a, super::Buffer>>,
    {
        // Note: this is done so that we never end up with empty stage flags.
        let mut src_stages = vk::PipelineStageFlags::TOP_OF_PIPE;
        let mut dst_stages = vk::PipelineStageFlags::BOTTOM_OF_PIPE;
        let vk_barriers = &mut self.temp.buffer_barriers;
        vk_barriers.clear();

        for bar in barriers {
            let (src_stage, src_access) = conv::map_buffer_usage_to_barrier(bar.usage.from);
            src_stages |= src_stage;
            let (dst_stage, dst_access) = conv::map_buffer_usage_to_barrier(bar.usage.to);
            dst_stages |= dst_stage;

            vk_barriers.push(
                vk::BufferMemoryBarrier::default()
                    .buffer(bar.buffer.raw)
                    .size(vk::WHOLE_SIZE)
                    .src_access_mask(src_access)
                    .dst_access_mask(dst_access),
            )
        }

        if !vk_barriers.is_empty() {
            unsafe {
                self.device.raw.cmd_pipeline_barrier(
                    self.active,
                    src_stages,
                    dst_stages,
                    vk::DependencyFlags::empty(),
                    &[],
                    vk_barriers,
                    &[],
                )
            };
        }
    }

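    /// Records a `vkCmdPipelineBarrier` covering the given image barriers,
    /// deriving the old/new image layouts from the texture usage transition.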
    unsafe fn transition_textures<'a, T>(&mut self, barriers: T)
    where
        T: Iterator<Item = crate::TextureBarrier<'a, super::Texture>>,
    {
        let mut src_stages = vk::PipelineStageFlags::empty();
        let mut dst_stages = vk::PipelineStageFlags::empty();
        let vk_barriers = &mut self.temp.image_barriers;
        vk_barriers.clear();

        for bar in barriers {
            let range = conv::map_subresource_range_combined_aspect(
                &bar.range,
                bar.texture.format,
                &self.device.private_caps,
            );
            let (src_stage, src_access) = conv::map_texture_usage_to_barrier(bar.usage.from);
            let src_layout = conv::derive_image_layout(bar.usage.from, bar.texture.format);
            src_stages |= src_stage;
            let (dst_stage, dst_access) = conv::map_texture_usage_to_barrier(bar.usage.to);
            let dst_layout = conv::derive_image_layout(bar.usage.to, bar.texture.format);
            dst_stages |= dst_stage;

            vk_barriers.push(
                vk::ImageMemoryBarrier::default()
                    .image(bar.texture.raw)
                    .subresource_range(range)
                    .src_access_mask(src_access)
                    .dst_access_mask(dst_access)
                    .old_layout(src_layout)
                    .new_layout(dst_layout),
            );
        }

        if !vk_barriers.is_empty() {
            unsafe {
                self.device.raw.cmd_pipeline_barrier(
                    self.active,
                    src_stages,
                    dst_stages,
                    vk::DependencyFlags::empty(),
                    &[],
                    &[],
                    vk_barriers,
                )
            };
        }
    }

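    /// Fills `range` of `buffer` with zeroes. On devices with the
    /// FORCE_FILL_BUFFER_WITH_SIZE_GREATER_4096_ALIGNED_OFFSET_16 workaround,
    /// large fills starting at a non-16-aligned offset are split into two
    /// `vkCmdFillBuffer` calls at a 16-byte boundary.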
    unsafe fn clear_buffer(&mut self, buffer: &super::Buffer, range: crate::MemoryRange) {
        let range_size = range.end - range.start;
        if self.device.workarounds.contains(
            super::Workarounds::FORCE_FILL_BUFFER_WITH_SIZE_GREATER_4096_ALIGNED_OFFSET_16,
        ) && range_size >= 4096
            && range.start % 16 != 0
        {
            let rounded_start = wgt::math::align_to(range.start, 16);
            let prefix_size = rounded_start - range.start;

            unsafe {
                self.device.raw.cmd_fill_buffer(
                    self.active,
                    buffer.raw,
                    range.start,
                    prefix_size,
                    0,
                )
            };

            // This will never be zero, as rounding can only add up to 12 bytes,
            // and the total size is at least 4096.
            let suffix_size = range.end - rounded_start;

            unsafe {
                self.device.raw.cmd_fill_buffer(
                    self.active,
                    buffer.raw,
                    rounded_start,
                    suffix_size,
                    0,
                )
            };
        } else {
            unsafe {
                self.device
                    .raw
                    .cmd_fill_buffer(self.active, buffer.raw, range.start, range_size, 0)
            };
        }
    }

    unsafe fn copy_buffer_to_buffer<T>(
        &mut self,
        src: &super::Buffer,
        dst: &super::Buffer,
        regions: T,
    ) where
        T: Iterator<Item = crate::BufferCopy>,
    {
        let vk_regions_iter = regions.map(|r| vk::BufferCopy {
            src_offset: r.src_offset,
            dst_offset: r.dst_offset,
            size: r.size.get(),
        });

        unsafe {
            self.device.raw.cmd_copy_buffer(
                self.active,
                src.raw,
                dst.raw,
                &smallvec::SmallVec::<[vk::BufferCopy; 32]>::from_iter(vk_regions_iter),
            )
        };
    }

    unsafe fn copy_texture_to_texture<T>(
        &mut self,
        src: &super::Texture,
        src_usage: wgt::TextureUses,
        dst: &super::Texture,
        regions: T,
    ) where
        T: Iterator<Item = crate::TextureCopy>,
    {
        let src_layout = conv::derive_image_layout(src_usage, src.format);

        let vk_regions_iter = regions.map(|r| {
            let (src_subresource, src_offset) = conv::map_subresource_layers(&r.src_base);
            let (dst_subresource, dst_offset) = conv::map_subresource_layers(&r.dst_base);
            let extent = r
                .size
                .min(&r.src_base.max_copy_size(&src.copy_size))
                .min(&r.dst_base.max_copy_size(&dst.copy_size));
            vk::ImageCopy {
                src_subresource,
                src_offset,
                dst_subresource,
                dst_offset,
                extent: conv::map_copy_extent(&extent),
            }
        });

        unsafe {
            self.device.raw.cmd_copy_image(
                self.active,
                src.raw,
                src_layout,
                dst.raw,
                DST_IMAGE_LAYOUT,
                &smallvec::SmallVec::<[vk::ImageCopy; 32]>::from_iter(vk_regions_iter),
            )
        };
    }

    unsafe fn copy_buffer_to_texture<T>(
        &mut self,
        src: &super::Buffer,
        dst: &super::Texture,
        regions: T,
    ) where
        T: Iterator<Item = crate::BufferTextureCopy>,
    {
        let vk_regions_iter = dst.map_buffer_copies(regions);

        unsafe {
            self.device.raw.cmd_copy_buffer_to_image(
                self.active,
                src.raw,
                dst.raw,
                DST_IMAGE_LAYOUT,
                &smallvec::SmallVec::<[vk::BufferImageCopy; 32]>::from_iter(vk_regions_iter),
            )
        };
    }

    unsafe fn copy_texture_to_buffer<T>(
        &mut self,
        src: &super::Texture,
        src_usage: wgt::TextureUses,
        dst: &super::Buffer,
        regions: T,
    ) where
        T: Iterator<Item = crate::BufferTextureCopy>,
    {
        let src_layout = conv::derive_image_layout(src_usage, src.format);
        let vk_regions_iter = src.map_buffer_copies(regions);

        unsafe {
            self.device.raw.cmd_copy_image_to_buffer(
                self.active,
                src.raw,
                src_layout,
                dst.raw,
                &smallvec::SmallVec::<[vk::BufferImageCopy; 32]>::from_iter(vk_regions_iter),
            )
        };
    }

    unsafe fn begin_query(&mut self, set: &super::QuerySet, index: u32) {
        unsafe {
            self.device.raw.cmd_begin_query(
                self.active,
                set.raw,
                index,
                vk::QueryControlFlags::empty(),
            )
        };
    }
    unsafe fn end_query(&mut self, set: &super::QuerySet, index: u32) {
        unsafe { self.device.raw.cmd_end_query(self.active, set.raw, index) };
    }
    unsafe fn write_timestamp(&mut self, set: &super::QuerySet, index: u32) {
        unsafe {
            self.device.raw.cmd_write_timestamp(
                self.active,
                vk::PipelineStageFlags::BOTTOM_OF_PIPE,
                set.raw,
                index,
            )
        };
    }
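    /// Writes the compacted size of `acceleration_structure` into `buffer` by
    /// resetting its compacted-size query pool, writing an
    /// ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR query, and copying out the
    /// 64-bit result with WAIT.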
    unsafe fn read_acceleration_structure_compact_size(
        &mut self,
        acceleration_structure: &super::AccelerationStructure,
        buffer: &super::Buffer,
    ) {
        let ray_tracing_functions = self
            .device
            .extension_fns
            .ray_tracing
            .as_ref()
            .expect("Feature `RAY_TRACING` not enabled");
        let query_pool = acceleration_structure
            .compacted_size_query
            .as_ref()
            .unwrap();
        unsafe {
            self.device
                .raw
                .cmd_reset_query_pool(self.active, *query_pool, 0, 1);
            ray_tracing_functions
                .acceleration_structure
                .cmd_write_acceleration_structures_properties(
                    self.active,
                    &[acceleration_structure.raw],
                    vk::QueryType::ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR,
                    *query_pool,
                    0,
                );
            self.device.raw.cmd_copy_query_pool_results(
                self.active,
                *query_pool,
                0,
                1,
                buffer.raw,
                0,
                wgt::QUERY_SIZE as vk::DeviceSize,
                vk::QueryResultFlags::TYPE_64 | vk::QueryResultFlags::WAIT,
            )
        };
    }
    unsafe fn reset_queries(&mut self, set: &super::QuerySet, range: Range<u32>) {
        unsafe {
            self.device.raw.cmd_reset_query_pool(
                self.active,
                set.raw,
                range.start,
                range.end - range.start,
            )
        };
    }
    unsafe fn copy_query_results(
        &mut self,
        set: &super::QuerySet,
        range: Range<u32>,
        buffer: &super::Buffer,
        offset: wgt::BufferAddress,
        stride: wgt::BufferSize,
    ) {
        unsafe {
            self.device.raw.cmd_copy_query_pool_results(
                self.active,
                set.raw,
                range.start,
                range.end - range.start,
                buffer.raw,
                offset,
                stride.get(),
                vk::QueryResultFlags::TYPE_64 | vk::QueryResultFlags::WAIT,
            )
        };
    }

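    /// Builds or updates the given acceleration structures with a single
    /// `cmd_build_acceleration_structures` call. Geometry and range data are
    /// collected into `SmallVec` storage first, so that the
    /// `vk::AccelerationStructureBuildGeometryInfoKHR` entries can hold
    /// stable pointers into them.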
    unsafe fn build_acceleration_structures<'a, T>(&mut self, descriptor_count: u32, descriptors: T)
    where
        super::Api: 'a,
        T: IntoIterator<
            Item = crate::BuildAccelerationStructureDescriptor<
                'a,
                super::Buffer,
                super::AccelerationStructure,
            >,
        >,
    {
        const CAPACITY_OUTER: usize = 8;
        const CAPACITY_INNER: usize = 1;
        let descriptor_count = descriptor_count as usize;

        let ray_tracing_functions = self
            .device
            .extension_fns
            .ray_tracing
            .as_ref()
            .expect("Feature `RAY_TRACING` not enabled");

        let get_device_address = |buffer: Option<&super::Buffer>| unsafe {
            match buffer {
                Some(buffer) => ray_tracing_functions
                    .buffer_device_address
                    .get_buffer_device_address(
                        &vk::BufferDeviceAddressInfo::default().buffer(buffer.raw),
                    ),
                None => panic!("Buffers are required to build acceleration structures"),
            }
        };

        // Storage for all the data required for cmd_build_acceleration_structures.
        let mut ranges_storage = smallvec::SmallVec::<
            [smallvec::SmallVec<[vk::AccelerationStructureBuildRangeInfoKHR; CAPACITY_INNER]>;
                CAPACITY_OUTER],
        >::with_capacity(descriptor_count);
        let mut geometries_storage = smallvec::SmallVec::<
            [smallvec::SmallVec<[vk::AccelerationStructureGeometryKHR; CAPACITY_INNER]>;
                CAPACITY_OUTER],
        >::with_capacity(descriptor_count);

        // Pointers to all the data required for cmd_build_acceleration_structures.
        let mut geometry_infos = smallvec::SmallVec::<
            [vk::AccelerationStructureBuildGeometryInfoKHR; CAPACITY_OUTER],
        >::with_capacity(descriptor_count);
        let mut ranges_ptrs = smallvec::SmallVec::<
            [&[vk::AccelerationStructureBuildRangeInfoKHR]; CAPACITY_OUTER],
        >::with_capacity(descriptor_count);

        for desc in descriptors {
            let (geometries, ranges) = match *desc.entries {
                crate::AccelerationStructureEntries::Instances(ref instances) => {
                    let instance_data =
                        vk::AccelerationStructureGeometryInstancesDataKHR::default().data(
                            vk::DeviceOrHostAddressConstKHR {
                                device_address: get_device_address(instances.buffer),
                            },
                        );

                    let geometry = vk::AccelerationStructureGeometryKHR::default()
                        .geometry_type(vk::GeometryTypeKHR::INSTANCES)
                        .geometry(vk::AccelerationStructureGeometryDataKHR {
                            instances: instance_data,
                        });

                    let range = vk::AccelerationStructureBuildRangeInfoKHR::default()
                        .primitive_count(instances.count)
                        .primitive_offset(instances.offset);

                    (smallvec::smallvec![geometry], smallvec::smallvec![range])
                }
                crate::AccelerationStructureEntries::Triangles(ref in_geometries) => {
                    let mut ranges = smallvec::SmallVec::<
                        [vk::AccelerationStructureBuildRangeInfoKHR; CAPACITY_INNER],
                    >::with_capacity(in_geometries.len());
                    let mut geometries = smallvec::SmallVec::<
                        [vk::AccelerationStructureGeometryKHR; CAPACITY_INNER],
                    >::with_capacity(in_geometries.len());
                    for triangles in in_geometries {
                        let mut triangle_data =
                            vk::AccelerationStructureGeometryTrianglesDataKHR::default()
                                // IndexType::NONE_KHR is not set by default (due to being
                                // provided by VK_KHR_acceleration_structure), but unless there
                                // is an index buffer we need IndexType::NONE_KHR as our index
                                // type.
                                .index_type(vk::IndexType::NONE_KHR)
                                .vertex_data(vk::DeviceOrHostAddressConstKHR {
                                    device_address: get_device_address(triangles.vertex_buffer),
                                })
                                .vertex_format(conv::map_vertex_format(triangles.vertex_format))
                                .max_vertex(triangles.vertex_count)
                                .vertex_stride(triangles.vertex_stride);

                        let mut range = vk::AccelerationStructureBuildRangeInfoKHR::default();

                        if let Some(ref indices) = triangles.indices {
                            triangle_data = triangle_data
                                .index_data(vk::DeviceOrHostAddressConstKHR {
                                    device_address: get_device_address(indices.buffer),
                                })
                                .index_type(conv::map_index_format(indices.format));

                            range = range
                                .primitive_count(indices.count / 3)
                                .primitive_offset(indices.offset)
                                .first_vertex(triangles.first_vertex);
                        } else {
                            // Without an index buffer, the primitive count is the
                            // number of triangles, i.e. a third of the vertex count.
                            range = range
                                .primitive_count(triangles.vertex_count / 3)
                                .first_vertex(triangles.first_vertex);
                        }

                        if let Some(ref transform) = triangles.transform {
                            let transform_device_address = unsafe {
                                ray_tracing_functions
                                    .buffer_device_address
                                    .get_buffer_device_address(
                                        &vk::BufferDeviceAddressInfo::default()
                                            .buffer(transform.buffer.raw),
                                    )
                            };
                            triangle_data =
                                triangle_data.transform_data(vk::DeviceOrHostAddressConstKHR {
                                    device_address: transform_device_address,
                                });

                            range = range.transform_offset(transform.offset);
                        }

                        let geometry = vk::AccelerationStructureGeometryKHR::default()
                            .geometry_type(vk::GeometryTypeKHR::TRIANGLES)
                            .geometry(vk::AccelerationStructureGeometryDataKHR {
                                triangles: triangle_data,
                            })
                            .flags(conv::map_acceleration_structure_geometry_flags(
                                triangles.flags,
                            ));

                        geometries.push(geometry);
                        ranges.push(range);
                    }
                    (geometries, ranges)
                }
                crate::AccelerationStructureEntries::AABBs(ref in_geometries) => {
                    let mut ranges = smallvec::SmallVec::<
                        [vk::AccelerationStructureBuildRangeInfoKHR; CAPACITY_INNER],
                    >::with_capacity(in_geometries.len());
                    let mut geometries = smallvec::SmallVec::<
                        [vk::AccelerationStructureGeometryKHR; CAPACITY_INNER],
                    >::with_capacity(in_geometries.len());
                    for aabb in in_geometries {
                        let aabbs_data = vk::AccelerationStructureGeometryAabbsDataKHR::default()
                            .data(vk::DeviceOrHostAddressConstKHR {
                                device_address: get_device_address(aabb.buffer),
                            })
                            .stride(aabb.stride);

                        let range = vk::AccelerationStructureBuildRangeInfoKHR::default()
                            .primitive_count(aabb.count)
                            .primitive_offset(aabb.offset);

                        let geometry = vk::AccelerationStructureGeometryKHR::default()
                            .geometry_type(vk::GeometryTypeKHR::AABBS)
                            .geometry(vk::AccelerationStructureGeometryDataKHR {
                                aabbs: aabbs_data,
                            })
                            .flags(conv::map_acceleration_structure_geometry_flags(aabb.flags));

                        geometries.push(geometry);
                        ranges.push(range);
                    }
                    (geometries, ranges)
                }
            };

            ranges_storage.push(ranges);
            geometries_storage.push(geometries);

            let scratch_device_address = unsafe {
                ray_tracing_functions
                    .buffer_device_address
                    .get_buffer_device_address(
                        &vk::BufferDeviceAddressInfo::default().buffer(desc.scratch_buffer.raw),
                    )
            };
            let ty = match *desc.entries {
                crate::AccelerationStructureEntries::Instances(_) => {
                    vk::AccelerationStructureTypeKHR::TOP_LEVEL
                }
                _ => vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL,
            };
            let mut geometry_info = vk::AccelerationStructureBuildGeometryInfoKHR::default()
                .ty(ty)
                .mode(conv::map_acceleration_structure_build_mode(desc.mode))
                .flags(conv::map_acceleration_structure_flags(desc.flags))
                .dst_acceleration_structure(desc.destination_acceleration_structure.raw)
                .scratch_data(vk::DeviceOrHostAddressKHR {
                    device_address: scratch_device_address + desc.scratch_buffer_offset,
                });

            if desc.mode == crate::AccelerationStructureBuildMode::Update {
                geometry_info.src_acceleration_structure = desc
                    .source_acceleration_structure
                    .unwrap_or(desc.destination_acceleration_structure)
                    .raw;
            }

            geometry_infos.push(geometry_info);
        }

        for (i, geometry_info) in geometry_infos.iter_mut().enumerate() {
            geometry_info.geometry_count = geometries_storage[i].len() as u32;
            geometry_info.p_geometries = geometries_storage[i].as_ptr();
            ranges_ptrs.push(&ranges_storage[i]);
        }

        unsafe {
            ray_tracing_functions
                .acceleration_structure
                .cmd_build_acceleration_structures(self.active, &geometry_infos, &ranges_ptrs);
        }
    }

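    /// Records a global memory barrier for an acceleration structure usage
    /// transition. TOP_OF_PIPE/BOTTOM_OF_PIPE are OR'd into the stage masks
    /// so they are never empty.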
    unsafe fn place_acceleration_structure_barrier(
        &mut self,
        barrier: crate::AccelerationStructureBarrier,
    ) {
        let (src_stage, src_access) = conv::map_acceleration_structure_usage_to_barrier(
            barrier.usage.from,
            self.device.features,
        );
        let (dst_stage, dst_access) = conv::map_acceleration_structure_usage_to_barrier(
            barrier.usage.to,
            self.device.features,
        );

        unsafe {
            self.device.raw.cmd_pipeline_barrier(
                self.active,
                src_stage | vk::PipelineStageFlags::TOP_OF_PIPE,
                dst_stage | vk::PipelineStageFlags::BOTTOM_OF_PIPE,
                vk::DependencyFlags::empty(),
                &[vk::MemoryBarrier::default()
                    .src_access_mask(src_access)
                    .dst_access_mask(dst_access)],
                &[],
                &[],
            )
        };
    }
    // render

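    /// Builds the render-pass and framebuffer cache keys from the
    /// descriptor's attachments, creates or looks up both objects, then
    /// records `vkCmdBeginRenderPass` together with a full-extent,
    /// Y-flipped viewport and matching scissor.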
    unsafe fn begin_render_pass(
        &mut self,
        desc: &crate::RenderPassDescriptor<super::QuerySet, super::TextureView>,
    ) -> Result<(), crate::DeviceError> {
        let mut vk_clear_values =
            ArrayVec::<vk::ClearValue, { super::MAX_TOTAL_ATTACHMENTS }>::new();
        let mut rp_key = super::RenderPassKey {
            colors: ArrayVec::default(),
            depth_stencil: None,
            sample_count: desc.sample_count,
            multiview: desc.multiview,
        };
        let mut fb_key = super::FramebufferKey {
            raw_pass: vk::RenderPass::null(),
            attachment_views: ArrayVec::default(),
            attachment_identities: ArrayVec::default(),
            extent: desc.extent,
        };

        for cat in desc.color_attachments {
            if let Some(cat) = cat.as_ref() {
                let color_view = if cat.target.view.dimension == wgt::TextureViewDimension::D3 {
                    let key = super::TempTextureViewKey {
                        texture: cat.target.view.raw_texture,
                        texture_identity: cat.target.view.texture_identity,
                        format: cat.target.view.raw_format,
                        mip_level: cat.target.view.base_mip_level,
                        depth_slice: cat.depth_slice.unwrap(),
                    };
                    self.make_temp_texture_view(key)?
                } else {
                    cat.target.view.identified_raw_view()
                };

                vk_clear_values.push(vk::ClearValue {
                    color: unsafe { cat.make_vk_clear_color() },
                });
                let color = super::ColorAttachmentKey {
                    base: cat.target.make_attachment_key(cat.ops),
                    resolve: cat
                        .resolve_target
                        .as_ref()
                        .map(|target| target.make_attachment_key(crate::AttachmentOps::STORE)),
                };

                rp_key.colors.push(Some(color));
                fb_key.push_view(color_view);
                if let Some(ref at) = cat.resolve_target {
                    vk_clear_values.push(unsafe { mem::zeroed() });
                    fb_key.push_view(at.view.identified_raw_view());
                }

                // Assert that this attachment is valid for the detected multiview,
                // as a sanity check. The driver crash for this is really bad on
                // AMD, so the check is worth it.
                if let Some(multiview) = desc.multiview {
                    assert_eq!(cat.target.view.layers, multiview);
                    if let Some(ref resolve_target) = cat.resolve_target {
                        assert_eq!(resolve_target.view.layers, multiview);
                    }
                }
            } else {
                rp_key.colors.push(None);
            }
        }
        if let Some(ref ds) = desc.depth_stencil_attachment {
            vk_clear_values.push(vk::ClearValue {
                depth_stencil: vk::ClearDepthStencilValue {
                    depth: ds.clear_value.0,
                    stencil: ds.clear_value.1,
                },
            });
            rp_key.depth_stencil = Some(super::DepthStencilAttachmentKey {
                base: ds.target.make_attachment_key(ds.depth_ops),
                stencil_ops: ds.stencil_ops,
            });
            fb_key.push_view(ds.target.view.identified_raw_view());

            // Assert that this attachment is valid for the detected multiview,
            // as a sanity check. The driver crash for this is really bad on
            // AMD, so the check is worth it.
            if let Some(multiview) = desc.multiview {
                assert_eq!(ds.target.view.layers, multiview);
            }
        }

        let render_area = vk::Rect2D {
            offset: vk::Offset2D { x: 0, y: 0 },
            extent: vk::Extent2D {
                width: desc.extent.width,
                height: desc.extent.height,
            },
        };
        let vk_viewports = [vk::Viewport {
            x: 0.0,
            y: desc.extent.height as f32,
            width: desc.extent.width as f32,
            height: -(desc.extent.height as f32),
            min_depth: 0.0,
            max_depth: 1.0,
        }];

        let raw_pass = self.device.make_render_pass(rp_key)?;
        fb_key.raw_pass = raw_pass;
        let raw_framebuffer = self.make_framebuffer(fb_key)?;

        let vk_info = vk::RenderPassBeginInfo::default()
            .render_pass(raw_pass)
            .render_area(render_area)
            .clear_values(&vk_clear_values)
            .framebuffer(raw_framebuffer);

        if let Some(label) = desc.label {
            unsafe { self.begin_debug_marker(label) };
            self.rpass_debug_marker_active = true;
        }

        // Start timestamp if any (before all other commands but after the debug marker).
        if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
            if let Some(index) = timestamp_writes.beginning_of_pass_write_index {
                unsafe {
                    self.write_timestamp(timestamp_writes.query_set, index);
                }
            }
            self.end_of_pass_timer_query = timestamp_writes
                .end_of_pass_write_index
                .map(|index| (timestamp_writes.query_set.raw, index));
        }

        unsafe {
            self.device
                .raw
                .cmd_set_viewport(self.active, 0, &vk_viewports);
            self.device
                .raw
                .cmd_set_scissor(self.active, 0, &[render_area]);
            self.device.raw.cmd_begin_render_pass(
                self.active,
                &vk_info,
                vk::SubpassContents::INLINE,
            );
        };

        self.bind_point = vk::PipelineBindPoint::GRAPHICS;

        Ok(())
    }
    unsafe fn end_render_pass(&mut self) {
        unsafe {
            self.device.raw.cmd_end_render_pass(self.active);
        }

        // After all other commands but before debug marker, so this is still seen as part of this pass.
        self.write_pass_end_timestamp_if_requested();

        if self.rpass_debug_marker_active {
            unsafe {
                self.end_debug_marker();
            }
            self.rpass_debug_marker_active = false;
        }
    }

    unsafe fn set_bind_group(
        &mut self,
        layout: &super::PipelineLayout,
        index: u32,
        group: &super::BindGroup,
        dynamic_offsets: &[wgt::DynamicOffset],
    ) {
        let sets = [*group.set.raw()];
        unsafe {
            self.device.raw.cmd_bind_descriptor_sets(
                self.active,
                self.bind_point,
                layout.raw,
                index,
                &sets,
                dynamic_offsets,
            )
        };
    }
    unsafe fn set_push_constants(
        &mut self,
        layout: &super::PipelineLayout,
        stages: wgt::ShaderStages,
        offset_bytes: u32,
        data: &[u32],
    ) {
        unsafe {
            self.device.raw.cmd_push_constants(
                self.active,
                layout.raw,
                conv::map_shader_stage(stages),
                offset_bytes,
                bytemuck::cast_slice(data),
            )
        };
    }

    unsafe fn insert_debug_marker(&mut self, label: &str) {
        if let Some(ext) = self.device.extension_fns.debug_utils.as_ref() {
            let cstr = self.temp.make_c_str(label);
            let vk_label = vk::DebugUtilsLabelEXT::default().label_name(cstr);
            unsafe { ext.cmd_insert_debug_utils_label(self.active, &vk_label) };
        }
    }
    unsafe fn begin_debug_marker(&mut self, group_label: &str) {
        if let Some(ext) = self.device.extension_fns.debug_utils.as_ref() {
            let cstr = self.temp.make_c_str(group_label);
            let vk_label = vk::DebugUtilsLabelEXT::default().label_name(cstr);
            unsafe { ext.cmd_begin_debug_utils_label(self.active, &vk_label) };
        }
    }
    unsafe fn end_debug_marker(&mut self) {
        if let Some(ext) = self.device.extension_fns.debug_utils.as_ref() {
            unsafe { ext.cmd_end_debug_utils_label(self.active) };
        }
    }

    unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) {
        unsafe {
            self.device.raw.cmd_bind_pipeline(
                self.active,
                vk::PipelineBindPoint::GRAPHICS,
                pipeline.raw,
            )
        };
    }

    unsafe fn set_index_buffer<'a>(
        &mut self,
        binding: crate::BufferBinding<'a, super::Buffer>,
        format: wgt::IndexFormat,
    ) {
        unsafe {
            self.device.raw.cmd_bind_index_buffer(
                self.active,
                binding.buffer.raw,
                binding.offset,
                conv::map_index_format(format),
            )
        };
    }
    unsafe fn set_vertex_buffer<'a>(
        &mut self,
        index: u32,
        binding: crate::BufferBinding<'a, super::Buffer>,
    ) {
        let vk_buffers = [binding.buffer.raw];
        let vk_offsets = [binding.offset];
        unsafe {
            self.device
                .raw
                .cmd_bind_vertex_buffers(self.active, index, &vk_buffers, &vk_offsets)
        };
    }
    unsafe fn set_viewport(&mut self, rect: &crate::Rect<f32>, depth_range: Range<f32>) {
        let vk_viewports = [vk::Viewport {
            x: rect.x,
            y: rect.y + rect.h,
            width: rect.w,
            height: -rect.h, // flip Y
            min_depth: depth_range.start,
            max_depth: depth_range.end,
        }];
        unsafe {
            self.device
                .raw
                .cmd_set_viewport(self.active, 0, &vk_viewports)
        };
    }
    unsafe fn set_scissor_rect(&mut self, rect: &crate::Rect<u32>) {
        let vk_scissors = [vk::Rect2D {
            offset: vk::Offset2D {
                x: rect.x as i32,
                y: rect.y as i32,
            },
            extent: vk::Extent2D {
                width: rect.w,
                height: rect.h,
            },
        }];
        unsafe {
            self.device
                .raw
                .cmd_set_scissor(self.active, 0, &vk_scissors)
        };
    }
    unsafe fn set_stencil_reference(&mut self, value: u32) {
        unsafe {
            self.device.raw.cmd_set_stencil_reference(
                self.active,
                vk::StencilFaceFlags::FRONT_AND_BACK,
                value,
            )
        };
    }
    unsafe fn set_blend_constants(&mut self, color: &[f32; 4]) {
        unsafe { self.device.raw.cmd_set_blend_constants(self.active, color) };
    }

    unsafe fn draw(
        &mut self,
        first_vertex: u32,
        vertex_count: u32,
        first_instance: u32,
        instance_count: u32,
    ) {
        unsafe {
            self.device.raw.cmd_draw(
                self.active,
                vertex_count,
                instance_count,
                first_vertex,
                first_instance,
            )
        };
    }
    unsafe fn draw_indexed(
        &mut self,
        first_index: u32,
        index_count: u32,
        base_vertex: i32,
        first_instance: u32,
        instance_count: u32,
    ) {
        unsafe {
            self.device.raw.cmd_draw_indexed(
                self.active,
                index_count,
                instance_count,
                first_index,
                base_vertex,
                first_instance,
            )
        };
    }
    unsafe fn draw_mesh_tasks(
        &mut self,
        group_count_x: u32,
        group_count_y: u32,
        group_count_z: u32,
    ) {
        if let Some(ref t) = self.device.extension_fns.mesh_shading {
            unsafe {
                t.cmd_draw_mesh_tasks(self.active, group_count_x, group_count_y, group_count_z);
            };
        } else {
            panic!("Feature `MESH_SHADING` not enabled");
        }
    }
    unsafe fn draw_indirect(
        &mut self,
        buffer: &super::Buffer,
        offset: wgt::BufferAddress,
        draw_count: u32,
    ) {
        unsafe {
            self.device.raw.cmd_draw_indirect(
                self.active,
                buffer.raw,
                offset,
                draw_count,
                size_of::<wgt::DrawIndirectArgs>() as u32,
            )
        };
    }
    unsafe fn draw_indexed_indirect(
        &mut self,
        buffer: &super::Buffer,
        offset: wgt::BufferAddress,
        draw_count: u32,
    ) {
        unsafe {
            self.device.raw.cmd_draw_indexed_indirect(
                self.active,
                buffer.raw,
                offset,
                draw_count,
                size_of::<wgt::DrawIndexedIndirectArgs>() as u32,
            )
        };
    }
    unsafe fn draw_mesh_tasks_indirect(
        &mut self,
        buffer: &<Self::A as crate::Api>::Buffer,
        offset: wgt::BufferAddress,
        draw_count: u32,
    ) {
        if let Some(ref t) = self.device.extension_fns.mesh_shading {
            unsafe {
                t.cmd_draw_mesh_tasks_indirect(
                    self.active,
                    buffer.raw,
                    offset,
                    draw_count,
                    size_of::<wgt::DispatchIndirectArgs>() as u32,
                );
            };
        } else {
            panic!("Feature `MESH_SHADING` not enabled");
        }
    }
    unsafe fn draw_indirect_count(
        &mut self,
        buffer: &super::Buffer,
        offset: wgt::BufferAddress,
        count_buffer: &super::Buffer,
        count_offset: wgt::BufferAddress,
        max_count: u32,
    ) {
        let stride = size_of::<wgt::DrawIndirectArgs>() as u32;
        match self.device.extension_fns.draw_indirect_count {
            Some(ref t) => {
                unsafe {
                    t.cmd_draw_indirect_count(
                        self.active,
                        buffer.raw,
                        offset,
                        count_buffer.raw,
                        count_offset,
                        max_count,
                        stride,
                    )
                };
            }
            None => panic!("Feature `DRAW_INDIRECT_COUNT` not enabled"),
        }
    }
    unsafe fn draw_indexed_indirect_count(
        &mut self,
        buffer: &super::Buffer,
        offset: wgt::BufferAddress,
        count_buffer: &super::Buffer,
        count_offset: wgt::BufferAddress,
        max_count: u32,
    ) {
        let stride = size_of::<wgt::DrawIndexedIndirectArgs>() as u32;
        match self.device.extension_fns.draw_indirect_count {
            Some(ref t) => {
                unsafe {
                    t.cmd_draw_indexed_indirect_count(
                        self.active,
                        buffer.raw,
                        offset,
                        count_buffer.raw,
                        count_offset,
                        max_count,
                        stride,
                    )
                };
            }
            None => panic!("Feature `DRAW_INDIRECT_COUNT` not enabled"),
        }
    }
    unsafe fn draw_mesh_tasks_indirect_count(
        &mut self,
        buffer: &<Self::A as crate::Api>::Buffer,
        offset: wgt::BufferAddress,
        count_buffer: &super::Buffer,
        count_offset: wgt::BufferAddress,
        max_count: u32,
    ) {
        if self.device.extension_fns.draw_indirect_count.is_none() {
            panic!("Feature `DRAW_INDIRECT_COUNT` not enabled");
        }
        if let Some(ref t) = self.device.extension_fns.mesh_shading {
            unsafe {
                t.cmd_draw_mesh_tasks_indirect_count(
                    self.active,
                    buffer.raw,
                    offset,
                    count_buffer.raw,
                    count_offset,
                    max_count,
                    size_of::<wgt::DispatchIndirectArgs>() as u32,
                );
            };
        } else {
            panic!("Feature `MESH_SHADING` not enabled");
        }
    }

    // compute

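    /// Switches the bind point to COMPUTE, opens an optional debug marker
    /// scope, and writes the beginning-of-pass timestamp if requested,
    /// stashing the end-of-pass query for `end_compute_pass`.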
    unsafe fn begin_compute_pass(
        &mut self,
        desc: &crate::ComputePassDescriptor<'_, super::QuerySet>,
    ) {
        self.bind_point = vk::PipelineBindPoint::COMPUTE;
        if let Some(label) = desc.label {
            unsafe { self.begin_debug_marker(label) };
            self.rpass_debug_marker_active = true;
        }

        if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
            if let Some(index) = timestamp_writes.beginning_of_pass_write_index {
                unsafe {
                    self.write_timestamp(timestamp_writes.query_set, index);
                }
            }
            self.end_of_pass_timer_query = timestamp_writes
                .end_of_pass_write_index
                .map(|index| (timestamp_writes.query_set.raw, index));
        }
    }
    unsafe fn end_compute_pass(&mut self) {
        self.write_pass_end_timestamp_if_requested();

        if self.rpass_debug_marker_active {
            unsafe { self.end_debug_marker() };
            self.rpass_debug_marker_active = false;
        }
    }

    unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) {
        unsafe {
            self.device.raw.cmd_bind_pipeline(
                self.active,
                vk::PipelineBindPoint::COMPUTE,
                pipeline.raw,
            )
        };
    }

    unsafe fn dispatch(&mut self, count: [u32; 3]) {
        unsafe {
            self.device
                .raw
                .cmd_dispatch(self.active, count[0], count[1], count[2])
        };
    }
    unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
        unsafe {
            self.device
                .raw
                .cmd_dispatch_indirect(self.active, buffer.raw, offset)
        }
    }

    unsafe fn copy_acceleration_structure_to_acceleration_structure(
        &mut self,
        src: &super::AccelerationStructure,
        dst: &super::AccelerationStructure,
        copy: wgt::AccelerationStructureCopy,
    ) {
        let ray_tracing_functions = self
            .device
            .extension_fns
            .ray_tracing
            .as_ref()
            .expect("Feature `RAY_TRACING` not enabled");

        let mode = match copy {
            wgt::AccelerationStructureCopy::Clone => vk::CopyAccelerationStructureModeKHR::CLONE,
            wgt::AccelerationStructureCopy::Compact => {
                vk::CopyAccelerationStructureModeKHR::COMPACT
            }
        };

        unsafe {
            ray_tracing_functions
                .acceleration_structure
                .cmd_copy_acceleration_structure(
                    self.active,
                    &vk::CopyAccelerationStructureInfoKHR {
                        s_type: vk::StructureType::COPY_ACCELERATION_STRUCTURE_INFO_KHR,
                        p_next: core::ptr::null(),
                        src: src.raw,
                        dst: dst.raw,
                        mode,
                        _marker: Default::default(),
                    },
                );
        }
    }
}

#[test]
fn check_dst_image_layout() {
    assert_eq!(
        conv::derive_image_layout(wgt::TextureUses::COPY_DST, wgt::TextureFormat::Rgba8Unorm),
        DST_IMAGE_LAYOUT
    );
}