1use std::{collections::HashMap, path::PathBuf, str::FromStr};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
6pub enum ShadingLanguage {
7 Wgsl,
8 Hlsl,
9 Glsl,
10}
11
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
13pub enum ShaderStage {
14 Vertex,
15 Fragment, Compute,
17 TesselationControl, TesselationEvaluation, Mesh,
20 Task, Geometry,
22 RayGeneration,
23 ClosestHit,
24 AnyHit,
25 Callable,
26 Miss,
27 Intersect,
28}
29
30impl ShaderStage {
31 pub fn from_file_name(file_name: &String) -> Option<ShaderStage> {
32 let paths = HashMap::from([
34 ("vert", ShaderStage::Vertex),
35 ("frag", ShaderStage::Fragment),
36 ("comp", ShaderStage::Compute),
37 ("task", ShaderStage::Task),
38 ("mesh", ShaderStage::Mesh),
39 ("tesc", ShaderStage::TesselationControl),
40 ("tese", ShaderStage::TesselationEvaluation),
41 ("geom", ShaderStage::Geometry),
42 ("rgen", ShaderStage::RayGeneration),
43 ("rchit", ShaderStage::ClosestHit),
44 ("rahit", ShaderStage::AnyHit),
45 ("rcall", ShaderStage::Callable),
46 ("rmiss", ShaderStage::Miss),
47 ("rint", ShaderStage::Intersect),
48 ]);
49 let extension_list = file_name.rsplit(".");
50 for extension in extension_list {
51 if let Some(stage) = paths.get(extension) {
52 return Some(stage.clone());
53 } else {
54 continue;
55 }
56 }
57 None
59 }
60}
61
62impl ToString for ShaderStage {
63 fn to_string(&self) -> String {
64 match self {
65 ShaderStage::Vertex => "vertex".to_string(),
66 ShaderStage::Fragment => "fragment".to_string(),
67 ShaderStage::Compute => "compute".to_string(),
68 ShaderStage::TesselationControl => "tesselationcontrol".to_string(),
69 ShaderStage::TesselationEvaluation => "tesselationevaluation".to_string(),
70 ShaderStage::Mesh => "mesh".to_string(),
71 ShaderStage::Task => "task".to_string(),
72 ShaderStage::Geometry => "geometry".to_string(),
73 ShaderStage::RayGeneration => "raygeneration".to_string(),
74 ShaderStage::ClosestHit => "closesthit".to_string(),
75 ShaderStage::AnyHit => "anyhit".to_string(),
76 ShaderStage::Callable => "callable".to_string(),
77 ShaderStage::Miss => "miss".to_string(),
78 ShaderStage::Intersect => "intersect".to_string(),
79 }
80 }
81}
82
83impl FromStr for ShadingLanguage {
84 type Err = ();
85
86 fn from_str(input: &str) -> Result<ShadingLanguage, Self::Err> {
87 match input {
88 "wgsl" => Ok(ShadingLanguage::Wgsl),
89 "hlsl" => Ok(ShadingLanguage::Hlsl),
90 "glsl" => Ok(ShadingLanguage::Glsl),
91 _ => Err(()),
92 }
93 }
94}
95impl ToString for ShadingLanguage {
96 fn to_string(&self) -> String {
97 String::from(match &self {
98 ShadingLanguage::Wgsl => "wgsl",
99 ShadingLanguage::Hlsl => "hlsl",
100 ShadingLanguage::Glsl => "glsl",
101 })
102 }
103}
104
105pub trait ShadingLanguageTag {
106 fn get_language() -> ShadingLanguage;
107}
108pub struct HlslShadingLanguageTag {}
109impl ShadingLanguageTag for HlslShadingLanguageTag {
110 fn get_language() -> ShadingLanguage {
111 ShadingLanguage::Hlsl
112 }
113}
114pub struct GlslShadingLanguageTag {}
115impl ShadingLanguageTag for GlslShadingLanguageTag {
116 fn get_language() -> ShadingLanguage {
117 ShadingLanguage::Glsl
118 }
119}
120pub struct WgslShadingLanguageTag {}
121impl ShadingLanguageTag for WgslShadingLanguageTag {
122 fn get_language() -> ShadingLanguage {
123 ShadingLanguage::Wgsl
124 }
125}
126
127#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
129pub enum HlslShaderModel {
130 ShaderModel1,
131 ShaderModel1_1,
132 ShaderModel1_2,
133 ShaderModel1_3,
134 ShaderModel1_4,
135 ShaderModel2,
136 ShaderModel3,
137 ShaderModel4,
138 ShaderModel4_1,
139 ShaderModel5,
140 ShaderModel5_1,
141 ShaderModel6,
142 ShaderModel6_1,
143 ShaderModel6_2,
144 ShaderModel6_3,
145 ShaderModel6_4,
146 ShaderModel6_5,
147 ShaderModel6_6,
148 ShaderModel6_7,
149 #[default]
150 ShaderModel6_8,
151}
152
153impl HlslShaderModel {
154 pub fn earliest() -> HlslShaderModel {
155 HlslShaderModel::ShaderModel1
156 }
157 pub fn latest() -> HlslShaderModel {
158 HlslShaderModel::ShaderModel6_8
159 }
160}
161
162#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
163pub enum HlslVersion {
164 V2016,
165 V2017,
166 V2018,
167 #[default]
168 V2021,
169}
170
171#[derive(Default, Debug, Clone)]
172pub struct HlslCompilationParams {
173 pub shader_model: HlslShaderModel,
174 pub version: HlslVersion,
175 pub enable16bit_types: bool,
176 pub spirv: bool,
177}
178
179#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
180pub enum GlslTargetClient {
181 Vulkan1_0,
182 Vulkan1_1,
183 Vulkan1_2,
184 #[default]
185 Vulkan1_3,
186 OpenGL450,
187}
188
189impl GlslTargetClient {
190 pub fn is_opengl(&self) -> bool {
191 match *self {
192 GlslTargetClient::OpenGL450 => true,
193 _ => false,
194 }
195 }
196}
197
198#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
199pub enum GlslSpirvVersion {
200 SPIRV1_0,
201 SPIRV1_1,
202 SPIRV1_2,
203 SPIRV1_3,
204 SPIRV1_4,
205 SPIRV1_5,
206 #[default]
207 SPIRV1_6,
208}
209#[derive(Default, Debug, Clone)]
210pub struct GlslCompilationParams {
211 pub client: GlslTargetClient,
212 pub spirv: GlslSpirvVersion,
213}
214
/// Options that apply when compiling WGSL sources.
///
/// This has no fields yet. It sits alongside the HLSL and GLSL parameter types inside
/// `ShaderCompilationParams`.
#[derive(Default, Debug, Clone)]
pub struct WgslCompilationParams {}
217
/// Environment a shader is processed in: preprocessor defines, include paths and path
/// remapping. All collections are empty by default.
#[derive(Default, Debug, Clone)]
pub struct ShaderContextParams {
    /// Macro definitions as name → value pairs.
    // NOTE(review): presumably preprocessor defines; confirm how empty values are treated.
    pub defines: HashMap<String, String>,
    /// Include paths.
    // NOTE(review): presumably include search directories; confirm against callers.
    pub includes: Vec<String>,
    /// Path substitutions.
    // NOTE(review): presumably maps an original path (key) to its replacement (value). Confirm.
    pub path_remapping: HashMap<PathBuf, PathBuf>,
}
224
225#[derive(Default, Debug, Clone)]
226pub struct ShaderCompilationParams {
227 pub entry_point: Option<String>,
228 pub shader_stage: Option<ShaderStage>,
229 pub hlsl: HlslCompilationParams,
230 pub glsl: GlslCompilationParams,
231 pub wgsl: WgslCompilationParams,
232}
233
234#[derive(Default, Debug, Clone)]
235pub struct ShaderParams {
236 pub context: ShaderContextParams,
237 pub compilation: ShaderCompilationParams,
238}