shader_sense/
shader.rs

//! Shader stages, shading languages and the parameters used to compile them.
2use std::{
3    collections::HashMap,
4    ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not},
5    path::PathBuf,
6    str::FromStr,
7};
8
9use serde::{Deserialize, Serialize};
10
11/// All shading language supported
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
13pub enum ShadingLanguage {
14    Wgsl,
15    Hlsl,
16    Glsl,
17}
18
19#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
20pub struct ShaderStageMask(u32);
21
22impl ShaderStageMask {
23    pub const VERTEX: Self = Self(1 << 0);
24    pub const FRAGMENT: Self = Self(1 << 1);
25    pub const COMPUTE: Self = Self(1 << 2);
26    pub const TESSELATION_CONTROL: Self = Self(1 << 3);
27    pub const TESSELATION_EVALUATION: Self = Self(1 << 4);
28    pub const MESH: Self = Self(1 << 5);
29    pub const TASK: Self = Self(1 << 6);
30    pub const GEOMETRY: Self = Self(1 << 7);
31    pub const RAY_GENERATION: Self = Self(1 << 8);
32    pub const CLOSEST_HIT: Self = Self(1 << 9);
33    pub const ANY_HIT: Self = Self(1 << 10);
34    pub const CALLABLE: Self = Self(1 << 11);
35    pub const MISS: Self = Self(1 << 12);
36    pub const INTERSECT: Self = Self(1 << 13);
37}
38
39impl Default for ShaderStageMask {
40    fn default() -> Self {
41        Self(0)
42    }
43}
44impl ShaderStageMask {
45    pub const fn from_u32(x: u32) -> Self {
46        Self(x)
47    }
48    pub const fn as_u32(self) -> u32 {
49        self.0
50    }
51    pub const fn is_empty(self) -> bool {
52        self.0 == 0
53    }
54    pub const fn contains(self, other: &ShaderStage) -> bool {
55        let mask = other.as_mask();
56        self.0 & mask.0 == mask.0
57    }
58}
59impl BitOr for ShaderStageMask {
60    type Output = Self;
61    #[inline]
62    fn bitor(self, rhs: Self) -> Self {
63        Self(self.0 | rhs.0)
64    }
65}
66impl BitOrAssign for ShaderStageMask {
67    #[inline]
68    fn bitor_assign(&mut self, rhs: Self) {
69        *self = *self | rhs
70    }
71}
72impl BitAnd for ShaderStageMask {
73    type Output = Self;
74    #[inline]
75    fn bitand(self, rhs: Self) -> Self {
76        Self(self.0 & rhs.0)
77    }
78}
79impl BitAndAssign for ShaderStageMask {
80    #[inline]
81    fn bitand_assign(&mut self, rhs: Self) {
82        *self = *self & rhs
83    }
84}
85impl BitXor for ShaderStageMask {
86    type Output = Self;
87    #[inline]
88    fn bitxor(self, rhs: Self) -> Self {
89        Self(self.0 ^ rhs.0)
90    }
91}
92impl BitXorAssign for ShaderStageMask {
93    #[inline]
94    fn bitxor_assign(&mut self, rhs: Self) {
95        *self = *self ^ rhs
96    }
97}
98impl Not for ShaderStageMask {
99    type Output = Self;
100    #[inline]
101    fn not(self) -> Self {
102        Self(!self.0)
103    }
104}
105
106/// All shader stage supported
107#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
108#[serde(rename_all = "camelCase")]
109pub enum ShaderStage {
110    Vertex,
111    Fragment, // aka pixel shader
112    Compute,
113    TesselationControl,    // aka hull shader
114    TesselationEvaluation, // aka domain shader
115    Mesh,
116    Task, // aka amplification shader
117    Geometry,
118    RayGeneration,
119    ClosestHit,
120    AnyHit,
121    Callable,
122    Miss,
123    Intersect,
124}
125
126impl ShaderStage {
127    /// Get a stage from its filename. Mostly follow glslang guideline
128    pub fn from_file_name(file_name: &String) -> Option<ShaderStage> {
129        // TODO: add control for these
130        let paths = HashMap::from([
131            ("vert", ShaderStage::Vertex),
132            ("frag", ShaderStage::Fragment),
133            ("comp", ShaderStage::Compute),
134            ("task", ShaderStage::Task),
135            ("mesh", ShaderStage::Mesh),
136            ("tesc", ShaderStage::TesselationControl),
137            ("tese", ShaderStage::TesselationEvaluation),
138            ("geom", ShaderStage::Geometry),
139            ("rgen", ShaderStage::RayGeneration),
140            ("rchit", ShaderStage::ClosestHit),
141            ("rahit", ShaderStage::AnyHit),
142            ("rcall", ShaderStage::Callable),
143            ("rmiss", ShaderStage::Miss),
144            ("rint", ShaderStage::Intersect),
145        ]);
146        let extension_list = file_name.rsplit(".");
147        for extension in extension_list {
148            if let Some(stage) = paths.get(extension) {
149                return Some(stage.clone());
150            } else {
151                continue;
152            }
153        }
154        // For header files & undefined, will output issue with missing version...
155        None
156    }
157    pub const fn as_mask(&self) -> ShaderStageMask {
158        match self {
159            ShaderStage::Vertex => ShaderStageMask::VERTEX,
160            ShaderStage::Fragment => ShaderStageMask::FRAGMENT,
161            ShaderStage::Compute => ShaderStageMask::COMPUTE,
162            ShaderStage::TesselationControl => ShaderStageMask::TESSELATION_CONTROL,
163            ShaderStage::TesselationEvaluation => ShaderStageMask::TESSELATION_EVALUATION,
164            ShaderStage::Mesh => ShaderStageMask::MESH,
165            ShaderStage::Task => ShaderStageMask::TASK,
166            ShaderStage::Geometry => ShaderStageMask::GEOMETRY,
167            ShaderStage::RayGeneration => ShaderStageMask::RAY_GENERATION,
168            ShaderStage::ClosestHit => ShaderStageMask::CLOSEST_HIT,
169            ShaderStage::AnyHit => ShaderStageMask::ANY_HIT,
170            ShaderStage::Callable => ShaderStageMask::CALLABLE,
171            ShaderStage::Miss => ShaderStageMask::MISS,
172            ShaderStage::Intersect => ShaderStageMask::INTERSECT,
173        }
174    }
175    /// All graphics pipeline stages.
176    pub fn graphics() -> ShaderStageMask {
177        ShaderStageMask::VERTEX
178            | ShaderStageMask::FRAGMENT
179            | ShaderStageMask::GEOMETRY
180            | ShaderStageMask::TESSELATION_CONTROL
181            | ShaderStageMask::TESSELATION_EVALUATION
182            | ShaderStageMask::TASK
183            | ShaderStageMask::MESH
184    }
185    /// All compute pipeline stages.
186    pub fn compute() -> ShaderStageMask {
187        ShaderStageMask::COMPUTE
188    }
189    /// All raytracing pipeline stages.
190    pub fn raytracing() -> ShaderStageMask {
191        ShaderStageMask::RAY_GENERATION
192            | ShaderStageMask::INTERSECT
193            | ShaderStageMask::CLOSEST_HIT
194            | ShaderStageMask::ANY_HIT
195            | ShaderStageMask::MISS
196            | ShaderStageMask::INTERSECT
197    }
198}
199
200impl FromStr for ShaderStage {
201    type Err = ();
202
203    fn from_str(input: &str) -> Result<ShaderStage, Self::Err> {
204        // Be case insensitive for parsing.
205        let lower_input = input.to_lowercase();
206        match lower_input.as_str() {
207            "vertex" => Ok(ShaderStage::Vertex),
208            "fragment" | "pixel" => Ok(ShaderStage::Fragment),
209            "compute" => Ok(ShaderStage::Compute),
210            "tesselationcontrol" | "hull" => Ok(ShaderStage::TesselationControl),
211            "tesselationevaluation" | "domain" => Ok(ShaderStage::TesselationEvaluation),
212            "mesh" => Ok(ShaderStage::Mesh),
213            "task" | "amplification" => Ok(ShaderStage::Task),
214            "geometry" => Ok(ShaderStage::Geometry),
215            "raygeneration" => Ok(ShaderStage::RayGeneration),
216            "closesthit" => Ok(ShaderStage::ClosestHit),
217            "anyhit" => Ok(ShaderStage::AnyHit),
218            "callable" => Ok(ShaderStage::Callable),
219            "miss" => Ok(ShaderStage::Miss),
220            "intersect" => Ok(ShaderStage::Intersect),
221            _ => Err(()),
222        }
223    }
224}
225impl ToString for ShaderStage {
226    fn to_string(&self) -> String {
227        match self {
228            ShaderStage::Vertex => "vertex".to_string(),
229            ShaderStage::Fragment => "fragment".to_string(),
230            ShaderStage::Compute => "compute".to_string(),
231            ShaderStage::TesselationControl => "tesselationcontrol".to_string(),
232            ShaderStage::TesselationEvaluation => "tesselationevaluation".to_string(),
233            ShaderStage::Mesh => "mesh".to_string(),
234            ShaderStage::Task => "task".to_string(),
235            ShaderStage::Geometry => "geometry".to_string(),
236            ShaderStage::RayGeneration => "raygeneration".to_string(),
237            ShaderStage::ClosestHit => "closesthit".to_string(),
238            ShaderStage::AnyHit => "anyhit".to_string(),
239            ShaderStage::Callable => "callable".to_string(),
240            ShaderStage::Miss => "miss".to_string(),
241            ShaderStage::Intersect => "intersect".to_string(),
242        }
243    }
244}
245
246impl FromStr for ShadingLanguage {
247    type Err = ();
248
249    fn from_str(input: &str) -> Result<ShadingLanguage, Self::Err> {
250        match input {
251            "wgsl" => Ok(ShadingLanguage::Wgsl),
252            "hlsl" => Ok(ShadingLanguage::Hlsl),
253            "glsl" => Ok(ShadingLanguage::Glsl),
254            _ => Err(()),
255        }
256    }
257}
258impl ToString for ShadingLanguage {
259    fn to_string(&self) -> String {
260        String::from(match &self {
261            ShadingLanguage::Wgsl => "wgsl",
262            ShadingLanguage::Hlsl => "hlsl",
263            ShadingLanguage::Glsl => "glsl",
264        })
265    }
266}
267
268/// Generic tag to define a language to be used in template situations
269pub trait ShadingLanguageTag {
270    /// Get the language of the tag.
271    fn get_language() -> ShadingLanguage;
272}
273
274/// Hlsl tag
275pub struct HlslShadingLanguageTag {}
276impl ShadingLanguageTag for HlslShadingLanguageTag {
277    fn get_language() -> ShadingLanguage {
278        ShadingLanguage::Hlsl
279    }
280}
281/// Glsl tag
282pub struct GlslShadingLanguageTag {}
283impl ShadingLanguageTag for GlslShadingLanguageTag {
284    fn get_language() -> ShadingLanguage {
285        ShadingLanguage::Glsl
286    }
287}
288/// Wgsl tag
289pub struct WgslShadingLanguageTag {}
290impl ShadingLanguageTag for WgslShadingLanguageTag {
291    fn get_language() -> ShadingLanguage {
292        ShadingLanguage::Wgsl
293    }
294}
295
296/// All HLSL shader model existing.
297///
298/// Note that DXC only support shader model up to 6.0, and FXC is not supported.
299/// So shader model below 6 are only present for documentation purpose.
300#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
301pub enum HlslShaderModel {
302    ShaderModel1,
303    ShaderModel1_1,
304    ShaderModel1_2,
305    ShaderModel1_3,
306    ShaderModel1_4,
307    ShaderModel2,
308    ShaderModel3,
309    ShaderModel4,
310    ShaderModel4_1,
311    ShaderModel5,
312    ShaderModel5_1,
313    ShaderModel6,
314    ShaderModel6_1,
315    ShaderModel6_2,
316    ShaderModel6_3,
317    ShaderModel6_4,
318    ShaderModel6_5,
319    ShaderModel6_6,
320    ShaderModel6_7,
321    #[default]
322    ShaderModel6_8,
323}
324
325impl HlslShaderModel {
326    /// Get first shader model version
327    pub fn earliest() -> HlslShaderModel {
328        HlslShaderModel::ShaderModel1
329    }
330    /// Get last shader model version
331    pub fn latest() -> HlslShaderModel {
332        HlslShaderModel::ShaderModel6_8
333    }
334}
335
336/// All HLSL version supported
337#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
338pub enum HlslVersion {
339    V2016,
340    V2017,
341    V2018,
342    #[default]
343    V2021,
344}
345
346/// Hlsl compilation parameters for DXC.
347#[derive(Default, Debug, Clone, PartialEq, Eq)]
348pub struct HlslCompilationParams {
349    pub shader_model: HlslShaderModel,
350    pub version: HlslVersion,
351    pub enable16bit_types: bool,
352    pub spirv: bool,
353}
354
355/// Glsl target client
356#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
357pub enum GlslTargetClient {
358    Vulkan1_0,
359    Vulkan1_1,
360    Vulkan1_2,
361    #[default]
362    Vulkan1_3,
363    OpenGL450,
364}
365
366impl GlslTargetClient {
367    /// Check if glsl is for OpenGL or Vulkan
368    pub fn is_opengl(&self) -> bool {
369        match *self {
370            GlslTargetClient::OpenGL450 => true,
371            _ => false,
372        }
373    }
374}
375
376/// All SPIRV version supported for glsl
377#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
378pub enum GlslSpirvVersion {
379    SPIRV1_0,
380    SPIRV1_1,
381    SPIRV1_2,
382    SPIRV1_3,
383    SPIRV1_4,
384    SPIRV1_5,
385    #[default]
386    SPIRV1_6,
387}
388/// Glsl compilation parameters for glslang.
389#[derive(Default, Debug, Clone, PartialEq, Eq)]
390pub struct GlslCompilationParams {
391    pub client: GlslTargetClient,
392    pub spirv: GlslSpirvVersion,
393}
394
/// Options used when compiling WGSL with naga; there are none at the moment.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct WgslCompilationParams {}
398
/// Preprocessing context: macro definitions, include directories and path
/// remapping used when resolving includes.
#[derive(Default, Debug, Clone)]
pub struct ShaderContextParams {
    /// Preprocessor defines (presumably macro name → value — confirm with callers).
    pub defines: HashMap<String, String>,
    /// Additional include directories.
    pub includes: Vec<PathBuf>,
    /// Path remapping table (presumably maps a virtual path to a real one — confirm).
    pub path_remapping: HashMap<PathBuf, PathBuf>,
}
406
407/// Parameters for compilation
408#[derive(Default, Debug, Clone)]
409pub struct ShaderCompilationParams {
410    pub entry_point: Option<String>,
411    pub shader_stage: Option<ShaderStage>,
412    pub hlsl: HlslCompilationParams,
413    pub glsl: GlslCompilationParams,
414    pub wgsl: WgslCompilationParams,
415}
416
417/// Generic parameters passed to validation and inspection.
418#[derive(Default, Debug, Clone)]
419pub struct ShaderParams {
420    pub context: ShaderContextParams,
421    pub compilation: ShaderCompilationParams,
422}