vulk_ext/vkx/
shader.rs

1use super::*;
2
3#[derive(Clone, Copy, Debug)]
4pub enum ShaderType {
5    Task,
6    Mesh,
7    Fragment,
8    Compute,
9    Raygen,
10    Miss,
11    ClosestHit,
12}
13
14impl ShaderType {
15    #[must_use]
16    pub fn shader_stage(self) -> vk::ShaderStageFlagBits {
17        match self {
18            ShaderType::Task => vk::ShaderStageFlagBits::TaskEXT,
19            ShaderType::Mesh => vk::ShaderStageFlagBits::MeshEXT,
20            ShaderType::Fragment => vk::ShaderStageFlagBits::Fragment,
21            ShaderType::Compute => vk::ShaderStageFlagBits::Compute,
22            ShaderType::Raygen => vk::ShaderStageFlagBits::RaygenKHR,
23            ShaderType::Miss => vk::ShaderStageFlagBits::MissKHR,
24            ShaderType::ClosestHit => vk::ShaderStageFlagBits::ClosestHitKHR,
25        }
26    }
27
28    #[must_use]
29    pub fn next_shader_stage(self) -> Option<vk::ShaderStageFlagBits> {
30        match self {
31            ShaderType::Task => Some(vk::ShaderStageFlagBits::MeshEXT),
32            ShaderType::Mesh => Some(vk::ShaderStageFlagBits::Fragment),
33            ShaderType::Fragment
34            | ShaderType::Compute
35            | ShaderType::Raygen
36            | ShaderType::Miss
37            | ShaderType::ClosestHit => None,
38        }
39    }
40}
41
42#[derive(Clone, Debug)]
43pub struct ShaderBinary {
44    ty: ShaderType,
45    code: Vec<u8>,
46    entry_point: CString,
47}
48
49impl ShaderBinary {
50    #[must_use]
51    pub fn shader_type(&self) -> ShaderType {
52        self.ty
53    }
54
55    #[must_use]
56    pub fn code_size(&self) -> usize {
57        self.code.len()
58    }
59
60    #[must_use]
61    pub fn p_code(&self) -> *const std::ffi::c_void {
62        self.code.as_ptr().cast()
63    }
64
65    #[must_use]
66    pub fn entry_point(&self) -> *const std::ffi::c_char {
67        self.entry_point.as_ptr()
68    }
69
70    #[must_use]
71    pub fn shader_module_create_info(&self) -> vk::ShaderModuleCreateInfo {
72        vk::ShaderModuleCreateInfo {
73            s_type: vk::StructureType::ShaderModuleCreateInfo,
74            p_next: null(),
75            flags: vk::ShaderModuleCreateFlags::empty(),
76            code_size: self.code_size(),
77            p_code: self.p_code().cast(),
78        }
79    }
80}
81
82pub struct ShaderCompiler {
83    compiler: shaderc::Compiler,
84    includes: HashMap<String, String>,
85}
86
87impl ShaderCompiler {
88    pub fn new() -> Result<Self> {
89        use shaderc::Compiler;
90        let compiler = Compiler::new().context("Creating shader compiler")?;
91        Ok(Self {
92            compiler,
93            includes: HashMap::new(),
94        })
95    }
96
97    pub fn include(&mut self, source_name: impl AsRef<str>, content: impl AsRef<str>) {
98        self.includes.insert(
99            source_name.as_ref().to_string(),
100            content.as_ref().to_string(),
101        );
102    }
103
104    pub fn compile(
105        &self,
106        shader_type: ShaderType,
107        input_file_name: impl AsRef<str>,
108        entry_point_name: impl AsRef<str>,
109        code: impl AsRef<str>,
110    ) -> Result<ShaderBinary> {
111        use shaderc::CompileOptions;
112        use shaderc::OptimizationLevel;
113        use shaderc::ResolvedInclude;
114        use shaderc::SourceLanguage;
115        use shaderc::SpirvVersion;
116        use shaderc::TargetEnv;
117
118        let shader_kind = match shader_type {
119            ShaderType::Task => shaderc::ShaderKind::Task,
120            ShaderType::Mesh => shaderc::ShaderKind::Mesh,
121            ShaderType::Fragment => shaderc::ShaderKind::Fragment,
122            ShaderType::Compute => shaderc::ShaderKind::Compute,
123            ShaderType::Raygen => shaderc::ShaderKind::RayGeneration,
124            ShaderType::Miss => shaderc::ShaderKind::Miss,
125            ShaderType::ClosestHit => shaderc::ShaderKind::ClosestHit,
126        };
127        let mut options = CompileOptions::new().context("Creating shader compiler options")?;
128        options.set_target_env(TargetEnv::Vulkan, vulk::REQUIRED_VULKAN_VERSION);
129        options.set_optimization_level(OptimizationLevel::Performance);
130        options.set_target_spirv(SpirvVersion::V1_6);
131        options.set_source_language(SourceLanguage::GLSL);
132        options.set_warnings_as_errors();
133        options.set_include_callback(|source_name, _, _, _| -> shaderc::IncludeCallbackResult {
134            let content = self.includes.get(source_name);
135            if let Some(content) = content {
136                Ok(ResolvedInclude {
137                    resolved_name: source_name.to_owned(),
138                    content: (*content).clone(),
139                })
140            } else {
141                Err(format!("Unable to include source file {source_name}."))
142            }
143        });
144        let shader = self.compiler.compile_into_spirv(
145            code.as_ref(),
146            shader_kind,
147            input_file_name.as_ref(),
148            entry_point_name.as_ref(),
149            Some(&options),
150        )?;
151        if shader.get_num_warnings() > 0 {
152            error!("{}", shader.get_warning_messages());
153        }
154        Ok(ShaderBinary {
155            ty: shader_type,
156            code: shader.as_binary_u8().to_owned(),
157            entry_point: CString::new(entry_point_name.as_ref())?,
158        })
159    }
160}
161
162#[derive(Debug)]
163pub struct ShaderCreateInfo<'a> {
164    pub shader_binaries: &'a [ShaderBinary],
165    pub set_layouts: &'a [vk::DescriptorSetLayout],
166    pub push_constant_ranges: &'a [vk::PushConstantRange],
167    pub specialization_info: Option<&'a vk::SpecializationInfo>,
168}
169
170#[derive(Debug)]
171pub struct Shader {
172    pub(super) stages: Vec<vk::ShaderStageFlagBits>,
173    pub(super) shaders: Vec<vk::ShaderEXT>,
174}
175
176impl Shader {
177    pub unsafe fn create(device: &Device, create_info: &ShaderCreateInfo<'_>) -> Result<Self> {
178        let create_infos = create_info
179            .shader_binaries
180            .iter()
181            .map(|binary| vk::ShaderCreateInfoEXT {
182                s_type: vk::StructureType::ShaderCreateInfoEXT,
183                p_next: null(),
184                flags: vk::ShaderCreateFlagBitsEXT::LinkStageEXT.into(),
185                stage: binary.ty.shader_stage(),
186                next_stage: if let Some(next_stage) = binary.ty.next_shader_stage() {
187                    next_stage.into()
188                } else {
189                    zeroed()
190                },
191                code_type: vk::ShaderCodeTypeEXT::SpirvEXT,
192                code_size: binary.code_size(),
193                p_code: binary.p_code(),
194                p_name: binary.entry_point(),
195                set_layout_count: create_info.set_layouts.len() as _,
196                p_set_layouts: create_info.set_layouts.as_ptr(),
197                push_constant_range_count: create_info.push_constant_ranges.len() as _,
198                p_push_constant_ranges: create_info.push_constant_ranges.as_ptr(),
199                p_specialization_info: if let Some(specialization_info) =
200                    create_info.specialization_info
201                {
202                    specialization_info as *const _
203                } else {
204                    null()
205                },
206            })
207            .collect::<Vec<_>>();
208        let mut shaders = Vec::with_capacity(create_info.shader_binaries.len());
209        device.create_shaders_ext(
210            create_infos.len() as _,
211            create_infos.as_ptr(),
212            shaders.as_mut_ptr(),
213        )?;
214        shaders.set_len(create_info.shader_binaries.len());
215
216        let stages = create_info
217            .shader_binaries
218            .iter()
219            .map(|spirv| spirv.ty.shader_stage())
220            .collect();
221
222        Ok(Self { stages, shaders })
223    }
224
225    pub unsafe fn destroy(self, device: &Device) {
226        for &shader in &self.shaders {
227            device.destroy_shader_ext(shader);
228        }
229    }
230}
231
232impl Shader {
233    pub unsafe fn bind(&self, device: &Device, cmd: vk::CommandBuffer) {
234        device.cmd_bind_shaders_ext(
235            cmd,
236            self.stages.len() as _,
237            self.stages.as_ptr(),
238            self.shaders.as_ptr(),
239        );
240    }
241}