vulk_ext/vkx/
descriptor.rs

1use super::*;
2
3const DESCRIPTOR_MAX_SIZE: usize = 16;
4
5pub(crate) fn validate_descriptor_sizes(
6    p: &vk::PhysicalDeviceDescriptorBufferPropertiesEXT,
7) -> Result<()> {
8    ensure!(DESCRIPTOR_MAX_SIZE >= p.buffer_capture_replay_descriptor_data_size);
9    ensure!(DESCRIPTOR_MAX_SIZE >= p.image_capture_replay_descriptor_data_size);
10    ensure!(DESCRIPTOR_MAX_SIZE >= p.image_view_capture_replay_descriptor_data_size);
11    ensure!(DESCRIPTOR_MAX_SIZE >= p.sampler_capture_replay_descriptor_data_size);
12    ensure!(DESCRIPTOR_MAX_SIZE >= p.acceleration_structure_capture_replay_descriptor_data_size);
13    ensure!(DESCRIPTOR_MAX_SIZE >= p.sampler_descriptor_size);
14    ensure!(DESCRIPTOR_MAX_SIZE >= p.combined_image_sampler_descriptor_size);
15    ensure!(DESCRIPTOR_MAX_SIZE >= p.sampled_image_descriptor_size);
16    ensure!(DESCRIPTOR_MAX_SIZE >= p.storage_image_descriptor_size);
17    ensure!(DESCRIPTOR_MAX_SIZE >= p.uniform_texel_buffer_descriptor_size);
18    ensure!(DESCRIPTOR_MAX_SIZE >= p.robust_uniform_texel_buffer_descriptor_size);
19    ensure!(DESCRIPTOR_MAX_SIZE >= p.storage_texel_buffer_descriptor_size);
20    ensure!(DESCRIPTOR_MAX_SIZE >= p.robust_storage_texel_buffer_descriptor_size);
21    ensure!(DESCRIPTOR_MAX_SIZE >= p.uniform_buffer_descriptor_size);
22    ensure!(DESCRIPTOR_MAX_SIZE >= p.robust_uniform_buffer_descriptor_size);
23    ensure!(DESCRIPTOR_MAX_SIZE >= p.storage_buffer_descriptor_size);
24    ensure!(DESCRIPTOR_MAX_SIZE >= p.robust_storage_buffer_descriptor_size);
25    ensure!(DESCRIPTOR_MAX_SIZE >= p.input_attachment_descriptor_size);
26    ensure!(DESCRIPTOR_MAX_SIZE >= p.acceleration_structure_descriptor_size);
27    Ok(())
28}
29
30#[derive(Clone, Copy, Debug)]
31pub enum DescriptorCreateInfo {
32    UniformBuffer {
33        address: vk::DeviceAddress,
34        range: vk::DeviceSize,
35    },
36    StorageBuffer {
37        address: vk::DeviceAddress,
38        range: vk::DeviceSize,
39    },
40    SampledImage {
41        image_view: vk::ImageView,
42        image_layout: vk::ImageLayout,
43    },
44    StorageImage {
45        image_view: vk::ImageView,
46        image_layout: vk::ImageLayout,
47    },
48    InputAttachment {
49        image_view: vk::ImageView,
50        image_layout: vk::ImageLayout,
51    },
52    Sampler(vk::Sampler),
53    AccelerationStructure(vk::DeviceAddress),
54}
55
56impl DescriptorCreateInfo {
57    fn size(&self, props: &vk::PhysicalDeviceDescriptorBufferPropertiesEXT) -> usize {
58        match self {
59            DescriptorCreateInfo::UniformBuffer { .. } => props.uniform_buffer_descriptor_size,
60            DescriptorCreateInfo::StorageBuffer { .. } => props.storage_buffer_descriptor_size,
61            DescriptorCreateInfo::SampledImage { .. } => props.sampled_image_descriptor_size,
62            DescriptorCreateInfo::StorageImage { .. } => props.storage_image_descriptor_size,
63            DescriptorCreateInfo::InputAttachment { .. } => props.input_attachment_descriptor_size,
64            DescriptorCreateInfo::Sampler(_) => props.sampler_descriptor_size,
65            DescriptorCreateInfo::AccelerationStructure(_) => {
66                props.acceleration_structure_descriptor_size
67            }
68        }
69    }
70
71    fn ty(&self) -> vk::DescriptorType {
72        match self {
73            DescriptorCreateInfo::UniformBuffer { .. } => vk::DescriptorType::UniformBuffer,
74            DescriptorCreateInfo::StorageBuffer { .. } => vk::DescriptorType::StorageBuffer,
75            DescriptorCreateInfo::SampledImage { .. } => vk::DescriptorType::SampledImage,
76            DescriptorCreateInfo::StorageImage { .. } => vk::DescriptorType::StorageImage,
77            DescriptorCreateInfo::InputAttachment { .. } => vk::DescriptorType::InputAttachment,
78            DescriptorCreateInfo::Sampler(_) => vk::DescriptorType::Sampler,
79            DescriptorCreateInfo::AccelerationStructure(_) => {
80                vk::DescriptorType::AccelerationStructureKHR
81            }
82        }
83    }
84}
85
86type DescriptorData = [u8; DESCRIPTOR_MAX_SIZE];
87
88#[derive(Clone, Copy)]
89pub struct Descriptor {
90    ty: vk::DescriptorType,
91    size: usize,
92    data: DescriptorData,
93}
94
95impl Descriptor {
96    #[must_use]
97    pub unsafe fn create(
98        physical_device: &PhysicalDevice,
99        device: &Device,
100        create_info: DescriptorCreateInfo,
101    ) -> Self {
102        // Descriptor info.
103        let props = physical_device.descriptor_buffer_properties_ext;
104        let size = create_info.size(&props);
105        let ty = create_info.ty();
106
107        // Get descriptor data.
108        let data = match create_info {
109            DescriptorCreateInfo::UniformBuffer { address, range } => Self::get_descriptor_data(
110                device,
111                ty,
112                size,
113                vk::DescriptorDataEXT {
114                    p_uniform_buffer: &vk::DescriptorAddressInfoEXT {
115                        s_type: vk::StructureType::DescriptorAddressInfoEXT,
116                        p_next: null_mut(),
117                        address,
118                        range,
119                        format: vk::Format::Undefined,
120                    },
121                },
122            ),
123            DescriptorCreateInfo::StorageBuffer { address, range } => Self::get_descriptor_data(
124                device,
125                ty,
126                size,
127                vk::DescriptorDataEXT {
128                    p_storage_buffer: &vk::DescriptorAddressInfoEXT {
129                        s_type: vk::StructureType::DescriptorAddressInfoEXT,
130                        p_next: null_mut(),
131                        address,
132                        range,
133                        format: vk::Format::Undefined,
134                    },
135                },
136            ),
137            DescriptorCreateInfo::SampledImage {
138                image_view,
139                image_layout,
140            } => Self::get_descriptor_data(
141                device,
142                ty,
143                size,
144                vk::DescriptorDataEXT {
145                    p_sampled_image: &vk::DescriptorImageInfo {
146                        sampler: vk::Sampler::null(),
147                        image_view,
148                        image_layout,
149                    },
150                },
151            ),
152            DescriptorCreateInfo::StorageImage {
153                image_view,
154                image_layout,
155            } => Self::get_descriptor_data(
156                device,
157                ty,
158                size,
159                vk::DescriptorDataEXT {
160                    p_storage_image: &vk::DescriptorImageInfo {
161                        sampler: vk::Sampler::null(),
162                        image_view,
163                        image_layout,
164                    },
165                },
166            ),
167            DescriptorCreateInfo::InputAttachment {
168                image_view,
169                image_layout,
170            } => Self::get_descriptor_data(
171                device,
172                ty,
173                size,
174                vk::DescriptorDataEXT {
175                    p_input_attachment_image: &vk::DescriptorImageInfo {
176                        sampler: vk::Sampler::null(),
177                        image_view,
178                        image_layout,
179                    },
180                },
181            ),
182            DescriptorCreateInfo::Sampler(sampler) => Self::get_descriptor_data(
183                device,
184                ty,
185                size,
186                vk::DescriptorDataEXT {
187                    p_sampler: &sampler,
188                },
189            ),
190            DescriptorCreateInfo::AccelerationStructure(acceleration_structure) => {
191                Self::get_descriptor_data(
192                    device,
193                    ty,
194                    size,
195                    vk::DescriptorDataEXT {
196                        acceleration_structure,
197                    },
198                )
199            }
200        };
201
202        Self { ty, size, data }
203    }
204
205    unsafe fn get_descriptor_data(
206        device: &Device,
207        ty: vk::DescriptorType,
208        size: usize,
209        data: vk::DescriptorDataEXT,
210    ) -> DescriptorData {
211        let mut descriptor = MaybeUninit::<DescriptorData>::zeroed();
212        device.get_descriptor_ext(
213            &vk::DescriptorGetInfoEXT {
214                s_type: vk::StructureType::DescriptorGetInfoEXT,
215                p_next: null(),
216                ty,
217                data,
218            },
219            size,
220            descriptor.as_mut_ptr().cast(),
221        );
222        descriptor.assume_init()
223    }
224}
225
226impl std::fmt::Debug for Descriptor {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        let slice = &self.data[0..self.size];
229        f.debug_struct("Descriptor")
230            .field("ty", &self.ty)
231            .field("size", &self.size)
232            .field("data", &slice)
233            .finish()
234    }
235}
236
237#[derive(Debug)]
238pub struct DescriptorBinding<'a> {
239    pub ty: vk::DescriptorType,
240    pub stages: vk::ShaderStageFlags,
241    pub descriptors: &'a [Descriptor],
242}
243
244pub struct DescriptorStorage {
245    buffer: vk::Buffer,
246    allocations: BufferAllocations,
247    pub(super) allocation: BufferAllocation,
248    set_layout: vk::DescriptorSetLayout,
249    pub(super) set_count: u32,
250    pub(super) buffer_indices: Vec<u32>,
251    pub(super) offsets: Vec<vk::DeviceSize>,
252    pub(super) push_constant_range: Option<vk::PushConstantRange>,
253    pub(super) pipeline_layout: vk::PipelineLayout,
254    pub(super) usage: vk::BufferUsageFlags,
255}
256
257impl DescriptorStorage {
258    pub unsafe fn create(
259        physical_device: &PhysicalDevice,
260        device: &Device,
261        bindings: &[DescriptorBinding],
262        push_constant_range: Option<vk::PushConstantRange>,
263    ) -> Result<Self> {
264        // Validation.
265        ensure!(!bindings.is_empty(), "Expected 1 or more bindings");
266        for (binding_index, binding) in bindings.iter().enumerate() {
267            ensure!(
268                !binding.descriptors.is_empty(),
269                "Binding {} expected 1 or more descriptors",
270                binding_index
271            );
272            for descriptor in binding.descriptors {
273                ensure!(
274                    binding.ty == descriptor.ty,
275                    "Binding {} expected descriptor type to be equal to {:?}, got {:?} instead",
276                    binding_index,
277                    binding.ty,
278                    descriptor.ty
279                );
280            }
281        }
282
283        // Descriptor set layout.
284        let set_layout_bindings = bindings
285            .iter()
286            .enumerate()
287            .map(|(binding_index, binding)| vk::DescriptorSetLayoutBinding {
288                binding: binding_index as _,
289                descriptor_type: binding.ty,
290                descriptor_count: binding.descriptors.len() as _,
291                stage_flags: binding.stages,
292                p_immutable_samplers: null(),
293            })
294            .collect::<Vec<_>>();
295        let set_layout =
296            device.create_descriptor_set_layout(&vk::DescriptorSetLayoutCreateInfo {
297                s_type: vk::StructureType::DescriptorSetLayoutCreateInfo,
298                p_next: null(),
299                flags: vk::DescriptorSetLayoutCreateFlagBits::DescriptorBufferEXT.into(),
300                binding_count: set_layout_bindings.len() as _,
301                p_bindings: set_layout_bindings.as_ptr(),
302            })?;
303        let set_count = 1;
304        let buffer_indices = vec![0];
305        let offsets = vec![0];
306        let size = device.get_descriptor_set_layout_size_ext(set_layout);
307
308        // Buffer usage.
309        let usage = vk::BufferUsageFlagBits::ResourceDescriptorBufferEXT
310            | vk::BufferUsageFlagBits::SamplerDescriptorBufferEXT;
311
312        // Buffer.
313        let (buffer, buffer_create_info) = BufferCreator::new(size, usage)
314            .create(device)
315            .context("Creating buffer object")?;
316
317        // Allocate.
318        let allocations = BufferAllocations::allocate(
319            physical_device,
320            device,
321            &[buffer],
322            &[buffer_create_info],
323            vk::MemoryPropertyFlagBits::HostVisible | vk::MemoryPropertyFlagBits::HostCoherent,
324        )?;
325        let allocation = allocations.allocations()[0];
326
327        // Write descriptors.
328        for (binding_index, binding) in bindings.iter().enumerate() {
329            let binding_index = binding_index as u32;
330            let descriptor_offset =
331                device.get_descriptor_set_layout_binding_offset_ext(set_layout, binding_index);
332            let descriptor_offset = descriptor_offset as usize;
333            for (array_index, descriptor) in binding.descriptors.iter().enumerate() {
334                let dst_offset = descriptor_offset + array_index * descriptor.size;
335                std::ptr::copy_nonoverlapping(
336                    descriptor.data.as_ptr(),
337                    allocation.as_mut_ptr::<u8>().add(dst_offset),
338                    descriptor.size,
339                );
340            }
341        }
342
343        // Pipeline layout.
344        let pipeline_layout = {
345            let mut create_info = vk::PipelineLayoutCreateInfo {
346                s_type: vk::StructureType::PipelineLayoutCreateInfo,
347                p_next: null(),
348                flags: vk::PipelineLayoutCreateFlags::empty(),
349                set_layout_count: 1,
350                p_set_layouts: &set_layout,
351                push_constant_range_count: 0,
352                p_push_constant_ranges: null(),
353            };
354            let mut pcr: vk::PushConstantRange = zeroed();
355            if let Some(push_constant_range) = &push_constant_range {
356                pcr.stage_flags = push_constant_range.stage_flags;
357                pcr.size = push_constant_range.size;
358                pcr.offset = push_constant_range.offset;
359                create_info.push_constant_range_count = 1;
360                create_info.p_push_constant_ranges = &pcr;
361            }
362            device.create_pipeline_layout(&create_info)?
363        };
364
365        Ok(Self {
366            buffer,
367            allocations,
368            allocation,
369            set_layout,
370            set_count,
371            buffer_indices,
372            offsets,
373            push_constant_range,
374            pipeline_layout,
375            usage,
376        })
377    }
378
379    pub unsafe fn destroy(self, device: &Device) {
380        device.destroy_pipeline_layout(self.pipeline_layout);
381        device.destroy_descriptor_set_layout(self.set_layout);
382        device.destroy_buffer(self.buffer);
383        self.allocations.free(device);
384    }
385
386    pub unsafe fn bind(&self, device: &Device, cmd: vk::CommandBuffer) {
387        device.cmd_bind_descriptor_buffers_ext(
388            cmd,
389            1,
390            &vk::DescriptorBufferBindingInfoEXT {
391                s_type: vk::StructureType::DescriptorBufferBindingInfoEXT,
392                p_next: null_mut(),
393                address: self.allocation.device_address(),
394                usage: self.usage,
395            },
396        );
397    }
398
399    pub unsafe fn set_offsets(
400        &self,
401        device: &Device,
402        cmd: vk::CommandBuffer,
403        pipeline_bind_point: vk::PipelineBindPoint,
404    ) {
405        device.cmd_set_descriptor_buffer_offsets_ext(
406            cmd,
407            pipeline_bind_point,
408            self.pipeline_layout,
409            0,
410            self.set_count,
411            self.buffer_indices.as_ptr(),
412            self.offsets.as_ptr(),
413        );
414    }
415
416    pub unsafe fn push_constants<T>(
417        &self,
418        device: &Device,
419        cmd: vk::CommandBuffer,
420        data: &T,
421    ) -> Result<()> {
422        let Some(pcr) = self.push_constant_range else {
423            bail!("Missing push constant range");
424        };
425        ensure!(pcr.size as usize == size_of::<T>());
426        device.cmd_push_constants(
427            cmd,
428            self.pipeline_layout,
429            pcr.stage_flags,
430            pcr.offset,
431            pcr.size,
432            (data as *const T).cast(),
433        );
434        Ok(())
435    }
436
437    #[must_use]
438    pub fn pipeline_layout(&self) -> vk::PipelineLayout {
439        self.pipeline_layout
440    }
441
442    #[must_use]
443    pub fn set_layouts(&self) -> &[vk::DescriptorSetLayout] {
444        std::slice::from_ref(&self.set_layout)
445    }
446
447    #[must_use]
448    pub fn push_constant_ranges(&self) -> &[vk::PushConstantRange] {
449        if let Some(pcr) = &self.push_constant_range {
450            std::slice::from_ref(pcr)
451        } else {
452            &[]
453        }
454    }
455}