shader_sense/
shader.rs

1use std::{collections::HashMap, path::PathBuf, str::FromStr};
2
3use serde::{Deserialize, Serialize};
4
/// Shading languages supported by this crate.
///
/// The textual form is the lowercase language name (`"wgsl"`, `"hlsl"`, `"glsl"`);
/// see the `FromStr` implementation for the exact strings accepted.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ShadingLanguage {
    Wgsl,
    Hlsl,
    Glsl,
}
11
/// Pipeline stage a shader is written for.
///
/// The enum carries no serde `rename` attributes, so by default the variant names are
/// also the serialized representation. Renaming a variant (for instance correcting the
/// "Tesselation" spelling to "Tessellation") would therefore break existing serialized data.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ShaderStage {
    Vertex,
    Fragment, // aka pixel shader
    Compute,
    TesselationControl,    // aka hull shader
    TesselationEvaluation, // aka domain shader
    Mesh,
    Task, // aka amplification shader
    Geometry,
    RayGeneration,
    ClosestHit,
    AnyHit,
    Callable,
    Miss,
    Intersect,
}
29
30impl ShaderStage {
31    pub fn from_file_name(file_name: &String) -> Option<ShaderStage> {
32        // TODO: add control for these
33        let paths = HashMap::from([
34            ("vert", ShaderStage::Vertex),
35            ("frag", ShaderStage::Fragment),
36            ("comp", ShaderStage::Compute),
37            ("task", ShaderStage::Task),
38            ("mesh", ShaderStage::Mesh),
39            ("tesc", ShaderStage::TesselationControl),
40            ("tese", ShaderStage::TesselationEvaluation),
41            ("geom", ShaderStage::Geometry),
42            ("rgen", ShaderStage::RayGeneration),
43            ("rchit", ShaderStage::ClosestHit),
44            ("rahit", ShaderStage::AnyHit),
45            ("rcall", ShaderStage::Callable),
46            ("rmiss", ShaderStage::Miss),
47            ("rint", ShaderStage::Intersect),
48        ]);
49        let extension_list = file_name.rsplit(".");
50        for extension in extension_list {
51            if let Some(stage) = paths.get(extension) {
52                return Some(stage.clone());
53            } else {
54                continue;
55            }
56        }
57        // For header files & undefined, will output issue with missing version...
58        None
59    }
60}
61
62impl ToString for ShaderStage {
63    fn to_string(&self) -> String {
64        match self {
65            ShaderStage::Vertex => "vertex".to_string(),
66            ShaderStage::Fragment => "fragment".to_string(),
67            ShaderStage::Compute => "compute".to_string(),
68            ShaderStage::TesselationControl => "tesselationcontrol".to_string(),
69            ShaderStage::TesselationEvaluation => "tesselationevaluation".to_string(),
70            ShaderStage::Mesh => "mesh".to_string(),
71            ShaderStage::Task => "task".to_string(),
72            ShaderStage::Geometry => "geometry".to_string(),
73            ShaderStage::RayGeneration => "raygeneration".to_string(),
74            ShaderStage::ClosestHit => "closesthit".to_string(),
75            ShaderStage::AnyHit => "anyhit".to_string(),
76            ShaderStage::Callable => "callable".to_string(),
77            ShaderStage::Miss => "miss".to_string(),
78            ShaderStage::Intersect => "intersect".to_string(),
79        }
80    }
81}
82
83impl FromStr for ShadingLanguage {
84    type Err = ();
85
86    fn from_str(input: &str) -> Result<ShadingLanguage, Self::Err> {
87        match input {
88            "wgsl" => Ok(ShadingLanguage::Wgsl),
89            "hlsl" => Ok(ShadingLanguage::Hlsl),
90            "glsl" => Ok(ShadingLanguage::Glsl),
91            _ => Err(()),
92        }
93    }
94}
95impl ToString for ShadingLanguage {
96    fn to_string(&self) -> String {
97        String::from(match &self {
98            ShadingLanguage::Wgsl => "wgsl",
99            ShadingLanguage::Hlsl => "hlsl",
100            ShadingLanguage::Glsl => "glsl",
101        })
102    }
103}
104
105pub trait ShadingLanguageTag {
106    fn get_language() -> ShadingLanguage;
107}
108pub struct HlslShadingLanguageTag {}
109impl ShadingLanguageTag for HlslShadingLanguageTag {
110    fn get_language() -> ShadingLanguage {
111        ShadingLanguage::Hlsl
112    }
113}
114pub struct GlslShadingLanguageTag {}
115impl ShadingLanguageTag for GlslShadingLanguageTag {
116    fn get_language() -> ShadingLanguage {
117        ShadingLanguage::Glsl
118    }
119}
120pub struct WgslShadingLanguageTag {}
121impl ShadingLanguageTag for WgslShadingLanguageTag {
122    fn get_language() -> ShadingLanguage {
123        ShadingLanguage::Wgsl
124    }
125}
126
// NOTE(review): this note originally read "DXC only support shader model up to 6.0".
// The enum below goes up to 6.8, and DXC is generally the compiler for shader model 6.0
// and *newer*, while older models are handled by FXC. The note is therefore probably
// inverted; confirm which compiler limit was meant.
/// HLSL shader model to compile against.
///
/// Variants are declared from oldest to newest. The derived `PartialOrd`/`Ord` follow
/// declaration order, so a comparison such as `model >= HlslShaderModel::ShaderModel6`
/// reads as "at least shader model 6.0". Keep new variants in chronological order.
///
/// Defaults to the newest model, `ShaderModel6_8`.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum HlslShaderModel {
    ShaderModel1,
    ShaderModel1_1,
    ShaderModel1_2,
    ShaderModel1_3,
    ShaderModel1_4,
    ShaderModel2,
    ShaderModel3,
    ShaderModel4,
    ShaderModel4_1,
    ShaderModel5,
    ShaderModel5_1,
    ShaderModel6,
    ShaderModel6_1,
    ShaderModel6_2,
    ShaderModel6_3,
    ShaderModel6_4,
    ShaderModel6_5,
    ShaderModel6_6,
    ShaderModel6_7,
    #[default]
    ShaderModel6_8,
}
152
153impl HlslShaderModel {
154    pub fn earliest() -> HlslShaderModel {
155        HlslShaderModel::ShaderModel1
156    }
157    pub fn latest() -> HlslShaderModel {
158        HlslShaderModel::ShaderModel6_8
159    }
160}
161
/// HLSL language revision to compile with.
///
/// NOTE(review): these presumably correspond to the language versions accepted by DXC's
/// `-HV` option (2016, 2017, 2018, 2021). Confirm against the compiler wrapper.
///
/// The derived `Ord` follows declaration order (oldest first). Defaults to `V2021`.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum HlslVersion {
    V2016,
    V2017,
    V2018,
    #[default]
    V2021,
}
170
/// HLSL-specific compilation settings.
///
/// `Default` gives the newest shader model, the `V2021` language version, and both flags
/// set to `false`.
#[derive(Default, Debug, Clone)]
pub struct HlslCompilationParams {
    /// Shader model to target.
    pub shader_model: HlslShaderModel,
    /// HLSL language revision.
    pub version: HlslVersion,
    /// Whether 16-bit scalar types are enabled (presumably DXC's `-enable-16bit-types`; confirm).
    pub enable16bit_types: bool,
    /// NOTE(review): presumably requests SPIR-V output instead of DXIL. Confirm against the consumer.
    pub spirv: bool,
}
178
/// Client API that GLSL is compiled for: a Vulkan version or OpenGL 4.5.
///
/// Defaults to Vulkan 1.3. Use `is_opengl` to tell the OpenGL target apart from the
/// Vulkan ones.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum GlslTargetClient {
    Vulkan1_0,
    Vulkan1_1,
    Vulkan1_2,
    #[default]
    Vulkan1_3,
    OpenGL450,
}
188
189impl GlslTargetClient {
190    pub fn is_opengl(&self) -> bool {
191        match *self {
192            GlslTargetClient::OpenGL450 => true,
193            _ => false,
194        }
195    }
196}
197
/// SPIR-V version used for GLSL compilation (presumably the version of the emitted
/// module; confirm against the compiler wrapper).
///
/// Defaults to SPIR-V 1.6. Unlike the HLSL enums, this type does not derive `Ord`.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum GlslSpirvVersion {
    SPIRV1_0,
    SPIRV1_1,
    SPIRV1_2,
    SPIRV1_3,
    SPIRV1_4,
    SPIRV1_5,
    #[default]
    SPIRV1_6,
}
/// GLSL-specific compilation settings.
///
/// `Default` targets Vulkan 1.3 with SPIR-V 1.6.
#[derive(Default, Debug, Clone)]
pub struct GlslCompilationParams {
    /// Client API (Vulkan version or OpenGL) to compile for.
    pub client: GlslTargetClient,
    /// SPIR-V version setting.
    pub spirv: GlslSpirvVersion,
}
214
/// WGSL-specific compilation settings. There are currently none; the struct exists so that
/// `ShaderCompilationParams` has a slot per language.
#[derive(Default, Debug, Clone)]
pub struct WgslCompilationParams {}
217
/// Language-independent context in which a shader is processed.
#[derive(Default, Debug, Clone)]
pub struct ShaderContextParams {
    /// Macro definitions, presumably name -> value for the preprocessor (TODO confirm).
    pub defines: HashMap<String, String>,
    /// Include entries, presumably directories searched for `#include` (TODO confirm).
    pub includes: Vec<String>,
    /// Path-to-path remapping table. NOTE(review): presumably used to redirect include or
    /// virtual paths to on-disk locations; verify against the consumer.
    pub path_remapping: HashMap<PathBuf, PathBuf>,
}
224
/// Compilation settings for a shader: shared options plus one section per language.
#[derive(Default, Debug, Clone)]
pub struct ShaderCompilationParams {
    /// Entry point function name, or `None` when not specified.
    pub entry_point: Option<String>,
    /// Explicit shader stage, or `None` when not specified. NOTE(review): callers presumably
    /// fall back to something like `ShaderStage::from_file_name` in that case; confirm.
    pub shader_stage: Option<ShaderStage>,
    /// Settings used when the shader is HLSL.
    pub hlsl: HlslCompilationParams,
    /// Settings used when the shader is GLSL.
    pub glsl: GlslCompilationParams,
    /// Settings used when the shader is WGSL.
    pub wgsl: WgslCompilationParams,
}
233
/// All parameters for processing a shader: its context and its compilation settings.
#[derive(Default, Debug, Clone)]
pub struct ShaderParams {
    /// Defines, include entries and path remapping.
    pub context: ShaderContextParams,
    /// Entry point, stage and per-language compilation settings.
    pub compilation: ShaderCompilationParams,
}