/// Shadertoy API key appended as the `key` query parameter by `Shader::from_api`.
///
/// Can be overridden at build time by setting the `SHADERTOY_API_KEY` environment
/// variable (read with `option_env!`, so it is baked into the binary). Otherwise it
/// falls back to the key that has always been embedded in this source file.
static API_KEY: &str = match option_env!("SHADERTOY_API_KEY") {
    Some(key) => key,
    None => "rdnjhn",
};
2
3use serde::Deserialize;
4
/// Metadata of a shader, deserialized as the `info` field of [`Shader`].
///
/// Only `id`, `date`, `name`, `username` and `description` are public. The other
/// fields are deserialized but never read in this file, which is why the struct
/// carries `#[allow(unused)]`.
///
/// NOTE(review): there is no `#[serde(default)]`, so every field must be present
/// in the JSON payload with a compatible type, or deserialization fails. Confirm
/// this against the current Shadertoy API.
#[derive(Debug, Deserialize, Default)]
#[allow(unused)]
pub struct ShaderInfo {
    pub id: String,
    pub date: String,
    viewed: i32,
    pub name: String,
    pub username: String,
    pub description: String,
    likes: i32,
    published: i32,
    flags: i32,

    // The JSON key is `usePreview`; renamed to a snake_case Rust field.
    #[serde(rename = "usePreview")]
    use_preview: i32,
    tags: Vec<String>,
    hasliked: i32,
}
23
/// Sampler settings of a [`ShaderInput`]. Every field is deserialized as a string.
///
/// NOTE(review): flags such as `vflip`/`srgb` are typed `String`, which presumably
/// means the API encodes them as strings (e.g. `"true"`). Confirm this.
/// None of these fields are read in this file, hence `#[allow(unused)]`.
#[derive(Debug, Deserialize, Default)]
#[allow(unused)]
pub struct Sampler {
    filter: String,
    wrap: String,
    vflip: String,
    srgb: String,
    internal: String,
}
33
/// One input channel of a [`RenderPass`], as deserialized from the API payload.
///
/// NOTE(review): deserialization fails if the payload uses different keys or types
/// for these fields (e.g. a non-integer `id`). Confirm against the current API.
/// The fields are not read in this file, hence `#[allow(unused)]`.
#[derive(Debug, Deserialize, Default)]
#[allow(unused)]
pub struct ShaderInput {
    id: i32,
    src: String,
    ctype: String,
    channel: i32,
    sampler: Sampler,
    published: i32,
}
44
/// One output of a [`RenderPass`], as deserialized from the API payload.
/// The fields are not read in this file, hence `#[allow(unused)]`.
#[derive(Debug, Deserialize, Default)]
#[allow(unused)]
pub struct ShaderOutput {
    id: i32,
    channel: i32,
}
51
/// One render pass of a [`Shader`].
///
/// `code` holds the pass's GLSL source; `Shader::fetch_code_from_last_pass`
/// concatenates it with the other passes before translation to WGSL.
#[derive(Debug, Deserialize, Default)]
#[allow(unused)]
pub struct RenderPass {
    inputs: Vec<ShaderInput>,
    outputs: Vec<ShaderOutput>,
    code: String,
    name: String,
    description: String,
    // Raw identifier because `type` is a Rust keyword; maps to the JSON key `type`.
    r#type: String,
}
/// A shader as carried by [`ShaderToyApiResponse::Shader`].
#[derive(Debug, Deserialize, Default)]
#[allow(unused)]
pub struct Shader {
    /// Format version string reported by the API.
    pub ver: String,
    /// Shader metadata (name, author, description, ...).
    pub info: ShaderInfo,
    /// Render passes in the order they appear in the payload.
    pub renderpass: Vec<RenderPass>,
}
69
70#[derive(Debug)]
71pub enum ShaderProcessingError {
72 RequestError(reqwest::Error),
73 ShaderError(String),
74 ParseErrors(naga::front::glsl::ParseErrors),
75 WgslError(naga::back::wgsl::Error),
76 ValidationError(naga::WithSpan<naga::valid::ValidationError>),
77
78 MissingSdf(String),
80}
81
82impl From<reqwest::Error> for ShaderProcessingError {
83 fn from(error: reqwest::Error) -> Self {
84 ShaderProcessingError::RequestError(error)
85 }
86}
87
88impl From<String> for ShaderProcessingError {
89 fn from(error: String) -> Self {
90 ShaderProcessingError::ShaderError(error)
91 }
92}
93
94impl From<naga::front::glsl::ParseErrors> for ShaderProcessingError {
95 fn from(error: naga::front::glsl::ParseErrors) -> Self {
96 ShaderProcessingError::ParseErrors(error)
97 }
98}
99
100impl From<naga::back::wgsl::Error> for ShaderProcessingError {
101 fn from(error: naga::back::wgsl::Error) -> Self {
102 ShaderProcessingError::WgslError(error)
103 }
104}
105
106impl From<naga::WithSpan<naga::valid::ValidationError>> for ShaderProcessingError {
107 fn from(error: naga::WithSpan<naga::valid::ValidationError>) -> Self {
108 ShaderProcessingError::ValidationError(error)
109 }
110}
111
/// Top-level JSON body decoded by `Shader::from_api`.
///
/// Uses serde's default externally tagged enum representation, so `{"Shader": {...}}`
/// decodes to `Shader` and `{"Error": "..."}` decodes to `Error`.
#[derive(Debug, Deserialize)]
pub enum ShaderToyApiResponse {
    /// The requested shader.
    Shader(Shader),
    /// Error text; `from_api` converts it into `ShaderProcessingError::ShaderError`.
    Error(String),
}
117
118impl Shader {
119 pub fn fetch_code_from_last_pass(&self) -> Option<String> {
120 let mut code = String::new();
121 for pass in &self.renderpass {
122 code += &pass.code;
123 }
124 Some(code)
125 }
126
127 pub async fn from_api(shader_id: &str) -> Result<Self, ShaderProcessingError> {
128 let response = reqwest::get(format!(
129 "https://www.shadertoy.com/api/v1/shaders/{shader_id}?key={API_KEY}"
130 ))
131 .await?;
132
133 let shader = response.json::<ShaderToyApiResponse>().await?;
134
135 match shader {
136 ShaderToyApiResponse::Shader(shader) => Ok(shader),
137 ShaderToyApiResponse::Error(error) => Err(error.into()),
138 }
139 }
140
141 pub fn default_uniform_block() -> &'static str {
142 r#"
143 layout(binding=0) uniform vec3 iResolution; // viewport resolution (in pixels)
144 layout(binding=0) uniform float iTime; // shader playback time (in seconds)
145 layout(binding=0) uniform float iTimeDelta; // render time (in seconds)
146 layout(binding=0) uniform int iFrame; // shader playback frame
147 layout(binding=0) uniform vec4 iChannelTime; // channel playback time (in seconds)
148 layout(binding=0) uniform vec4 iMouse; // mouse pixel coords. xy: current (if MLB down), zw: click
149 layout(binding=0) uniform vec4 iDate; // (year, month, day, time in seconds)
150 layout(binding=0) uniform float iSampleRate; // sound sample rate (i.e., 44100)
151 "#
152 }
153
154 pub fn generate_wgsl_shader_code(&self) -> Result<WgslShaderCode, ShaderProcessingError> {
155 let mut glsl = String::from("#version 450 core\n");
156
157 glsl += Shader::default_uniform_block();
158
159 let shader_code = &self.fetch_code_from_last_pass().unwrap();
160 glsl += shader_code;
161
162 glsl += r#" void main() {}"#;
164
165 WgslShaderCode::from_glsl(&glsl)
166 }
167}
168
169pub fn convert_glsl_to_wgsl(glsl: &str) -> Result<String, ShaderProcessingError> {
170 use naga::back::wgsl::WriterFlags;
171 use naga::front::glsl::{Frontend, Options};
172 use naga::ShaderStage;
173
174 let mut frontend = Frontend::default();
176 let options = Options::from(ShaderStage::Fragment);
177
178 let module = frontend.parse(&options, glsl)?;
179
180 let mut wgsl = String::new();
182 let mut wgsl_writer = naga::back::wgsl::Writer::new(&mut wgsl, WriterFlags::empty());
183
184 use naga::valid::Validator;
185 let module_info = Validator::new(
186 naga::valid::ValidationFlags::all(),
187 naga::valid::Capabilities::all(),
188 )
189 .validate(&module)?;
190
191 wgsl_writer.write(&module, &module_info)?;
192
193 Ok(wgsl)
194}
195
/// WGSL source text, typically produced from GLSL by `WgslShaderCode::from_glsl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslShaderCode(String);
197
198impl WgslShaderCode {
199 pub fn from_glsl(glsl: &str) -> Result<Self, ShaderProcessingError> {
200 convert_glsl_to_wgsl(glsl).map(Self)
201 }
202
203 pub fn remove_function(&mut self, function_name: &str) -> Result<(), ShaderProcessingError> {
204 self.0 = remove_function_from_wgsl(&self.0, function_name)?;
205 Ok(())
206 }
207
208 pub fn has_function(&self, function_name: &str) -> bool {
209 wgsl_has_function(&self.0, function_name).unwrap_or(false)
210 }
211
212 pub fn rename_function(
213 &mut self,
214 old_function_name: &str,
215 new_function_name: &str,
216 ) -> Result<(), ShaderProcessingError> {
217 self.0 = rename_function_in_wgsl(&self.0, old_function_name, new_function_name)?;
218 Ok(())
219 }
220
221 pub fn remove_line(&mut self, line_to_be_removed: &str) {
222 let mut s = String::new();
223 for line in self.0.lines() {
224 if line.trim() != line_to_be_removed.trim() {
225 s += line;
226 s += "\n";
227 }
228 }
229 self.0 = s;
230 }
231
232 pub fn add_line(&mut self, line: &str) {
233 self.0 += line;
234 self.0 += "\n";
235 }
236
237 pub fn write_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
238 let mut f = std::io::BufWriter::new(std::fs::File::create(path)?);
239 use std::io::Write;
240 f.write_all(self.0.as_bytes())?;
241 Ok(())
242 }
243}
244
245impl std::fmt::Display for WgslShaderCode {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 write!(f, "{}", self.0)
248 }
249}
250
251fn remove_function_from_wgsl(
252 wgsl: &str,
253 function_name: &str,
254) -> Result<String, ShaderProcessingError> {
255 let lines = wgsl.lines();
257 let mut new_wgsl = String::new();
258 let mut in_function = false;
259 let mut function_found = false;
260 let mut curly_braces = 0;
261
262 for line in lines {
263 let line = line.trim();
264 if line.starts_with(function_name) {
265 in_function = true;
266 function_found = true;
267 }
268
269 if in_function {
270 for c in line.chars() {
271 if c == '{' {
272 curly_braces += 1;
273 } else if c == '}' {
274 curly_braces -= 1;
275 }
276 }
277 }
278 if curly_braces == 0 {
279 if !in_function {
280 new_wgsl += format!("{}\n", line).as_str();
281 }
282 in_function = false;
283 }
284 }
285
286 if !function_found {
287 return Err(ShaderProcessingError::ShaderError(format!(
288 "Function {} not found in shader",
289 function_name
290 )));
291 }
292
293 Ok(new_wgsl)
294}
295
296fn wgsl_has_function(wgsl: &str, function_name: &str) -> Result<bool, ShaderProcessingError> {
297 let lines = wgsl.lines();
298 let mut function_found = false;
299 for line in lines {
300 let line = line.trim();
301 if line.starts_with(format!("fn {function_name}(").as_str()) {
302 function_found = true;
303 break;
304 }
305 }
306
307 if !function_found {
308 return Err(ShaderProcessingError::ShaderError(format!(
309 "Function {} not found in shader",
310 function_name
311 )));
312 }
313
314 Ok(true)
315}
316
317fn rename_function_in_wgsl(
318 wgsl: &str,
319 old_function_name: &str,
320 new_function_name: &str,
321) -> Result<String, ShaderProcessingError> {
322 let lines = wgsl.lines();
324 let mut new_wgsl = String::new();
325 let mut in_function = false;
326 let mut function_found = false;
327 for line in lines {
328 let line = line.trim();
329 if line.starts_with(format!("fn {old_function_name}(").as_str()) {
330 in_function = true;
331 function_found = true;
332 new_wgsl += line
333 .replacen(old_function_name, new_function_name, 1)
334 .as_str();
335 } else {
336 new_wgsl += format!("{}\n", line).as_str();
337 }
338
339 if in_function && line.starts_with('}') {
340 in_function = false;
341 }
342 }
343
344 if !function_found {
345 return Err(ShaderProcessingError::ShaderError(format!(
346 "Function `{}` not found in shader",
347 old_function_name
348 )));
349 }
350
351 Ok(new_wgsl)
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    /// Removes functions one at a time and checks that only the targeted ones disappear.
    #[test]
    fn remove_function() {
        let wgsl = r#"
fn mainImage(fragColor: ptr<function, vec4<f32>>, fragCoord: vec2<f32>) {
    var fragCoord_1: vec2<f32>;

    fragCoord_1 = fragCoord;
    return;
}

fn main_1() {
    return;
}

@fragment
fn main() {
    main_1();
    return;
}
"#;

        let new_wgsl = remove_function_from_wgsl(wgsl, "fn main_1()").unwrap();

        assert!(new_wgsl.contains("fn mainImage(fragColor"));
        assert!(!new_wgsl.contains("fn main_1()"));

        // The pattern includes `(` so that `fn mainImage(` is not matched as well.
        let new_wgsl = remove_function_from_wgsl(&new_wgsl, "fn main(").unwrap();
        assert!(new_wgsl.contains("fn mainImage(fragColor"));
        println!("{}", new_wgsl);

        assert!(!new_wgsl.contains("fn main()"));

        let new_wgsl = remove_function_from_wgsl(&new_wgsl, "fn mainImage(").unwrap();

        // `@fragment` sits on the line before `fn main(`, so it is not part of the
        // removed span and is all that remains.
        assert_eq!(new_wgsl.trim(), "@fragment");
    }

    /// Renaming rewrites the definition header to the new name.
    #[test]

    fn rename_function() {
        let in_wgsl = r#"fn normal(p_4: vec3<f32>, epsilon: f32) -> vec3<f32>"#;
        let out_wgsl = rename_function_in_wgsl(in_wgsl, "normal", "sdf3d_normal").unwrap();

        assert!(out_wgsl.contains("fn sdf3d_normal(p_4: vec3<f32>, epsilon: f32) -> vec3<f32>"));
    }

    /// Smoke test: an SDF-style GLSL snippet plus the default uniforms must convert
    /// to WGSL without error. The output is printed, not asserted.
    #[test]
    fn test_naga() {
        let mut glsl = String::from("#version 450 core\n");

        glsl += Shader::default_uniform_block();

        glsl += r#"
vec3 c = vec3(0.0, 0.0, 0.0);
const float r = 1.0;
float distance_from_sphere(vec3 p, vec3 c, float r)
{
    return distance(p, c) - r;
}

float sdf3d(vec3 p)
{
    float sphere_0 = distance_from_sphere(p, c, r);

    // set displacement
    float displacement = sin(5.0 * p.x) * sin(5.0 * p.y) * sin(5.0 * p.z) * 0.25 * sin(2.f * iTime);

    return sphere_0 + displacement;
}

vec3 sdf3d_normal(in vec3 p, in float epsilon)
{
    const vec3 small_step = vec3(epsilon, 0.0, 0.0);

    float gradient_x = sdf3d(p + small_step.xyy) - sdf3d(p - small_step.xyy);
    float gradient_y = sdf3d(p + small_step.yxy) - sdf3d(p - small_step.yxy);
    float gradient_z = sdf3d(p + small_step.yyx) - sdf3d(p - small_step.yyx);

    vec3 normal = vec3(gradient_x, gradient_y, gradient_z);

    return normalize(normal);
}

void mainImage( out vec4 fragColor, in vec2 fragCoord ) {}

"#;
        // Mirrors `Shader::generate_wgsl_shader_code`: an empty `main` is appended,
        // presumably so naga has an entry point (confirm).
        glsl += r#" void main() {}"#;

        let wgsl = convert_glsl_to_wgsl(&glsl).unwrap();

        println!("{}", wgsl);
    }
}