1use super::*;
2
/// The pipeline stage a shader is written for.
///
/// Each variant maps to a Vulkan stage via `ShaderType::shader_stage` and to a shaderc
/// shader kind when GLSL is compiled.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ShaderType {
    /// Task shader (`VK_EXT_mesh_shader`).
    Task,
    /// Mesh shader (`VK_EXT_mesh_shader`).
    Mesh,
    /// Fragment shader.
    Fragment,
    /// Compute shader.
    Compute,
    /// Ray generation shader (`VK_KHR_ray_tracing_pipeline`).
    Raygen,
    /// Miss shader (`VK_KHR_ray_tracing_pipeline`).
    Miss,
    /// Closest-hit shader (`VK_KHR_ray_tracing_pipeline`).
    ClosestHit,
}
13
14impl ShaderType {
15 #[must_use]
16 pub fn shader_stage(self) -> vk::ShaderStageFlagBits {
17 match self {
18 ShaderType::Task => vk::ShaderStageFlagBits::TaskEXT,
19 ShaderType::Mesh => vk::ShaderStageFlagBits::MeshEXT,
20 ShaderType::Fragment => vk::ShaderStageFlagBits::Fragment,
21 ShaderType::Compute => vk::ShaderStageFlagBits::Compute,
22 ShaderType::Raygen => vk::ShaderStageFlagBits::RaygenKHR,
23 ShaderType::Miss => vk::ShaderStageFlagBits::MissKHR,
24 ShaderType::ClosestHit => vk::ShaderStageFlagBits::ClosestHitKHR,
25 }
26 }
27
28 #[must_use]
29 pub fn next_shader_stage(self) -> Option<vk::ShaderStageFlagBits> {
30 match self {
31 ShaderType::Task => Some(vk::ShaderStageFlagBits::MeshEXT),
32 ShaderType::Mesh => Some(vk::ShaderStageFlagBits::Fragment),
33 ShaderType::Fragment
34 | ShaderType::Compute
35 | ShaderType::Raygen
36 | ShaderType::Miss
37 | ShaderType::ClosestHit => None,
38 }
39 }
40}
41
42#[derive(Clone, Debug)]
43pub struct ShaderBinary {
44 ty: ShaderType,
45 code: Vec<u8>,
46 entry_point: CString,
47}
48
49impl ShaderBinary {
50 #[must_use]
51 pub fn shader_type(&self) -> ShaderType {
52 self.ty
53 }
54
55 #[must_use]
56 pub fn code_size(&self) -> usize {
57 self.code.len()
58 }
59
60 #[must_use]
61 pub fn p_code(&self) -> *const std::ffi::c_void {
62 self.code.as_ptr().cast()
63 }
64
65 #[must_use]
66 pub fn entry_point(&self) -> *const std::ffi::c_char {
67 self.entry_point.as_ptr()
68 }
69
70 #[must_use]
71 pub fn shader_module_create_info(&self) -> vk::ShaderModuleCreateInfo {
72 vk::ShaderModuleCreateInfo {
73 s_type: vk::StructureType::ShaderModuleCreateInfo,
74 p_next: null(),
75 flags: vk::ShaderModuleCreateFlags::empty(),
76 code_size: self.code_size(),
77 p_code: self.p_code().cast(),
78 }
79 }
80}
81
82pub struct ShaderCompiler {
83 compiler: shaderc::Compiler,
84 includes: HashMap<String, String>,
85}
86
87impl ShaderCompiler {
88 pub fn new() -> Result<Self> {
89 use shaderc::Compiler;
90 let compiler = Compiler::new().context("Creating shader compiler")?;
91 Ok(Self {
92 compiler,
93 includes: HashMap::new(),
94 })
95 }
96
97 pub fn include(&mut self, source_name: impl AsRef<str>, content: impl AsRef<str>) {
98 self.includes.insert(
99 source_name.as_ref().to_string(),
100 content.as_ref().to_string(),
101 );
102 }
103
104 pub fn compile(
105 &self,
106 shader_type: ShaderType,
107 input_file_name: impl AsRef<str>,
108 entry_point_name: impl AsRef<str>,
109 code: impl AsRef<str>,
110 ) -> Result<ShaderBinary> {
111 use shaderc::CompileOptions;
112 use shaderc::OptimizationLevel;
113 use shaderc::ResolvedInclude;
114 use shaderc::SourceLanguage;
115 use shaderc::SpirvVersion;
116 use shaderc::TargetEnv;
117
118 let shader_kind = match shader_type {
119 ShaderType::Task => shaderc::ShaderKind::Task,
120 ShaderType::Mesh => shaderc::ShaderKind::Mesh,
121 ShaderType::Fragment => shaderc::ShaderKind::Fragment,
122 ShaderType::Compute => shaderc::ShaderKind::Compute,
123 ShaderType::Raygen => shaderc::ShaderKind::RayGeneration,
124 ShaderType::Miss => shaderc::ShaderKind::Miss,
125 ShaderType::ClosestHit => shaderc::ShaderKind::ClosestHit,
126 };
127 let mut options = CompileOptions::new().context("Creating shader compiler options")?;
128 options.set_target_env(TargetEnv::Vulkan, vulk::REQUIRED_VULKAN_VERSION);
129 options.set_optimization_level(OptimizationLevel::Performance);
130 options.set_target_spirv(SpirvVersion::V1_6);
131 options.set_source_language(SourceLanguage::GLSL);
132 options.set_warnings_as_errors();
133 options.set_include_callback(|source_name, _, _, _| -> shaderc::IncludeCallbackResult {
134 let content = self.includes.get(source_name);
135 if let Some(content) = content {
136 Ok(ResolvedInclude {
137 resolved_name: source_name.to_owned(),
138 content: (*content).clone(),
139 })
140 } else {
141 Err(format!("Unable to include source file {source_name}."))
142 }
143 });
144 let shader = self.compiler.compile_into_spirv(
145 code.as_ref(),
146 shader_kind,
147 input_file_name.as_ref(),
148 entry_point_name.as_ref(),
149 Some(&options),
150 )?;
151 if shader.get_num_warnings() > 0 {
152 error!("{}", shader.get_warning_messages());
153 }
154 Ok(ShaderBinary {
155 ty: shader_type,
156 code: shader.as_binary_u8().to_owned(),
157 entry_point: CString::new(entry_point_name.as_ref())?,
158 })
159 }
160}
161
162#[derive(Debug)]
163pub struct ShaderCreateInfo<'a> {
164 pub shader_binaries: &'a [ShaderBinary],
165 pub set_layouts: &'a [vk::DescriptorSetLayout],
166 pub push_constant_ranges: &'a [vk::PushConstantRange],
167 pub specialization_info: Option<&'a vk::SpecializationInfo>,
168}
169
170#[derive(Debug)]
171pub struct Shader {
172 pub(super) stages: Vec<vk::ShaderStageFlagBits>,
173 pub(super) shaders: Vec<vk::ShaderEXT>,
174}
175
176impl Shader {
177 pub unsafe fn create(device: &Device, create_info: &ShaderCreateInfo<'_>) -> Result<Self> {
178 let create_infos = create_info
179 .shader_binaries
180 .iter()
181 .map(|binary| vk::ShaderCreateInfoEXT {
182 s_type: vk::StructureType::ShaderCreateInfoEXT,
183 p_next: null(),
184 flags: vk::ShaderCreateFlagBitsEXT::LinkStageEXT.into(),
185 stage: binary.ty.shader_stage(),
186 next_stage: if let Some(next_stage) = binary.ty.next_shader_stage() {
187 next_stage.into()
188 } else {
189 zeroed()
190 },
191 code_type: vk::ShaderCodeTypeEXT::SpirvEXT,
192 code_size: binary.code_size(),
193 p_code: binary.p_code(),
194 p_name: binary.entry_point(),
195 set_layout_count: create_info.set_layouts.len() as _,
196 p_set_layouts: create_info.set_layouts.as_ptr(),
197 push_constant_range_count: create_info.push_constant_ranges.len() as _,
198 p_push_constant_ranges: create_info.push_constant_ranges.as_ptr(),
199 p_specialization_info: if let Some(specialization_info) =
200 create_info.specialization_info
201 {
202 specialization_info as *const _
203 } else {
204 null()
205 },
206 })
207 .collect::<Vec<_>>();
208 let mut shaders = Vec::with_capacity(create_info.shader_binaries.len());
209 device.create_shaders_ext(
210 create_infos.len() as _,
211 create_infos.as_ptr(),
212 shaders.as_mut_ptr(),
213 )?;
214 shaders.set_len(create_info.shader_binaries.len());
215
216 let stages = create_info
217 .shader_binaries
218 .iter()
219 .map(|spirv| spirv.ty.shader_stage())
220 .collect();
221
222 Ok(Self { stages, shaders })
223 }
224
225 pub unsafe fn destroy(self, device: &Device) {
226 for &shader in &self.shaders {
227 device.destroy_shader_ext(shader);
228 }
229 }
230}
231
232impl Shader {
233 pub unsafe fn bind(&self, device: &Device, cmd: vk::CommandBuffer) {
234 device.cmd_bind_shaders_ext(
235 cmd,
236 self.stages.len() as _,
237 self.stages.as_ptr(),
238 self.shaders.as_ptr(),
239 );
240 }
241}