// sdf2mesh/shadertoy.rs

/// Shadertoy API key sent as the `key` query parameter by `Shader::from_api`.
///
/// Setting the `SHADERTOY_API_KEY` environment variable at build time (read via
/// `option_env!`) overrides it, so a personal key does not have to be committed.
/// When the variable is unset, the bundled default key is used.
static API_KEY: &str = match option_env!("SHADERTOY_API_KEY") {
    Some(key) => key,
    None => "rdnjhn",
};

3use serde::Deserialize;
4
5#[derive(Debug, Deserialize, Default)]
6#[allow(unused)]
7pub struct ShaderInfo {
8    pub id: String,
9    pub date: String,
10    viewed: i32,
11    pub name: String,
12    pub username: String,
13    pub description: String,
14    likes: i32,
15    published: i32,
16    flags: i32,
17
18    #[serde(rename = "usePreview")]
19    use_preview: i32,
20    tags: Vec<String>,
21    hasliked: i32,
22}
23
24#[derive(Debug, Deserialize, Default)]
25#[allow(unused)]
26pub struct Sampler {
27    filter: String,
28    wrap: String,
29    vflip: String,
30    srgb: String,
31    internal: String,
32}
33
34#[derive(Debug, Deserialize, Default)]
35#[allow(unused)]
36pub struct ShaderInput {
37    id: i32,
38    src: String,
39    ctype: String,
40    channel: i32,
41    sampler: Sampler,
42    published: i32,
43}
44
45#[derive(Debug, Deserialize, Default)]
46#[allow(unused)]
47pub struct ShaderOutput {
48    id: i32,
49    channel: i32,
50}
51
52#[derive(Debug, Deserialize, Default)]
53#[allow(unused)]
54pub struct RenderPass {
55    inputs: Vec<ShaderInput>,
56    outputs: Vec<ShaderOutput>,
57    code: String,
58    name: String,
59    description: String,
60    r#type: String,
61}
62#[derive(Debug, Deserialize, Default)]
63#[allow(unused)]
64pub struct Shader {
65    pub ver: String,
66    pub info: ShaderInfo,
67    pub renderpass: Vec<RenderPass>,
68}
69
70#[derive(Debug)]
71pub enum ShaderProcessingError {
72    RequestError(reqwest::Error),
73    ShaderError(String),
74    ParseErrors(naga::front::glsl::ParseErrors),
75    WgslError(naga::back::wgsl::Error),
76    ValidationError(naga::WithSpan<naga::valid::ValidationError>),
77
78    /// Error when the SDF is missing in the shader
79    MissingSdf(String),
80}
81
82impl From<reqwest::Error> for ShaderProcessingError {
83    fn from(error: reqwest::Error) -> Self {
84        ShaderProcessingError::RequestError(error)
85    }
86}
87
88impl From<String> for ShaderProcessingError {
89    fn from(error: String) -> Self {
90        ShaderProcessingError::ShaderError(error)
91    }
92}
93
94impl From<naga::front::glsl::ParseErrors> for ShaderProcessingError {
95    fn from(error: naga::front::glsl::ParseErrors) -> Self {
96        ShaderProcessingError::ParseErrors(error)
97    }
98}
99
100impl From<naga::back::wgsl::Error> for ShaderProcessingError {
101    fn from(error: naga::back::wgsl::Error) -> Self {
102        ShaderProcessingError::WgslError(error)
103    }
104}
105
106impl From<naga::WithSpan<naga::valid::ValidationError>> for ShaderProcessingError {
107    fn from(error: naga::WithSpan<naga::valid::ValidationError>) -> Self {
108        ShaderProcessingError::ValidationError(error)
109    }
110}
111
112#[derive(Debug, Deserialize)]
113pub enum ShaderToyApiResponse {
114    Shader(Shader),
115    Error(String),
116}
117
118impl Shader {
119    pub fn fetch_code_from_last_pass(&self) -> Option<String> {
120        let mut code = String::new();
121        for pass in &self.renderpass {
122            code += &pass.code;
123        }
124        Some(code)
125    }
126
127    pub async fn from_api(shader_id: &str) -> Result<Self, ShaderProcessingError> {
128        let response = reqwest::get(format!(
129            "https://www.shadertoy.com/api/v1/shaders/{shader_id}?key={API_KEY}"
130        ))
131        .await?;
132
133        let shader = response.json::<ShaderToyApiResponse>().await?;
134
135        match shader {
136            ShaderToyApiResponse::Shader(shader) => Ok(shader),
137            ShaderToyApiResponse::Error(error) => Err(error.into()),
138        }
139    }
140
141    pub fn default_uniform_block() -> &'static str {
142        r#"
143        layout(binding=0) uniform vec3      iResolution;           // viewport resolution (in pixels)
144		layout(binding=0) uniform float     iTime;                 // shader playback time (in seconds)
145		layout(binding=0) uniform float     iTimeDelta;            // render time (in seconds)
146		layout(binding=0) uniform int       iFrame;                // shader playback frame
147		layout(binding=0) uniform vec4      iChannelTime;          // channel playback time (in seconds)
148		layout(binding=0) uniform vec4      iMouse;                // mouse pixel coords. xy: current (if MLB down), zw: click
149		layout(binding=0) uniform vec4      iDate;                 // (year, month, day, time in seconds)
150		layout(binding=0) uniform float     iSampleRate;           // sound sample rate (i.e., 44100)        
151        "#
152    }
153
154    pub fn generate_wgsl_shader_code(&self) -> Result<WgslShaderCode, ShaderProcessingError> {
155        let mut glsl = String::from("#version 450 core\n");
156
157        glsl += Shader::default_uniform_block();
158
159        let shader_code = &self.fetch_code_from_last_pass().unwrap();
160        glsl += shader_code;
161
162        // We add an empty main function to the shader so that naga can compile it to valid WGSL
163        glsl += r#" void main() {}"#;
164
165        WgslShaderCode::from_glsl(&glsl)
166    }
167}
168
169pub fn convert_glsl_to_wgsl(glsl: &str) -> Result<String, ShaderProcessingError> {
170    use naga::back::wgsl::WriterFlags;
171    use naga::front::glsl::{Frontend, Options};
172    use naga::ShaderStage;
173
174    // Setup and parse GLSL fragment shader
175    let mut frontend = Frontend::default();
176    let options = Options::from(ShaderStage::Fragment);
177
178    let module = frontend.parse(&options, glsl)?;
179
180    // Write to WGSL
181    let mut wgsl = String::new();
182    let mut wgsl_writer = naga::back::wgsl::Writer::new(&mut wgsl, WriterFlags::empty());
183
184    use naga::valid::Validator;
185    let module_info = Validator::new(
186        naga::valid::ValidationFlags::all(),
187        naga::valid::Capabilities::all(),
188    )
189    .validate(&module)?;
190
191    wgsl_writer.write(&module, &module_info)?;
192
193    Ok(wgsl)
194}
195
/// WGSL source text, edited in place by its line-oriented helper methods.
///
/// Derives the usual value traits so callers can inspect (`Debug`), copy (`Clone`),
/// compare (`PartialEq`/`Eq`) and start from an empty shader (`Default`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WgslShaderCode(String);

198impl WgslShaderCode {
199    pub fn from_glsl(glsl: &str) -> Result<Self, ShaderProcessingError> {
200        convert_glsl_to_wgsl(glsl).map(Self)
201    }
202
203    pub fn remove_function(&mut self, function_name: &str) -> Result<(), ShaderProcessingError> {
204        self.0 = remove_function_from_wgsl(&self.0, function_name)?;
205        Ok(())
206    }
207
208    pub fn has_function(&self, function_name: &str) -> bool {
209        wgsl_has_function(&self.0, function_name).unwrap_or(false)
210    }
211
212    pub fn rename_function(
213        &mut self,
214        old_function_name: &str,
215        new_function_name: &str,
216    ) -> Result<(), ShaderProcessingError> {
217        self.0 = rename_function_in_wgsl(&self.0, old_function_name, new_function_name)?;
218        Ok(())
219    }
220
221    pub fn remove_line(&mut self, line_to_be_removed: &str) {
222        let mut s = String::new();
223        for line in self.0.lines() {
224            if line.trim() != line_to_be_removed.trim() {
225                s += line;
226                s += "\n";
227            }
228        }
229        self.0 = s;
230    }
231
232    pub fn add_line(&mut self, line: &str) {
233        self.0 += line;
234        self.0 += "\n";
235    }
236
237    pub fn write_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
238        let mut f = std::io::BufWriter::new(std::fs::File::create(path)?);
239        use std::io::Write;
240        f.write_all(self.0.as_bytes())?;
241        Ok(())
242    }
243}
244
245impl std::fmt::Display for WgslShaderCode {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        write!(f, "{}", self.0)
248    }
249}
250
251fn remove_function_from_wgsl(
252    wgsl: &str,
253    function_name: &str,
254) -> Result<String, ShaderProcessingError> {
255    // find function name in wgsl
256    let lines = wgsl.lines();
257    let mut new_wgsl = String::new();
258    let mut in_function = false;
259    let mut function_found = false;
260    let mut curly_braces = 0;
261
262    for line in lines {
263        let line = line.trim();
264        if line.starts_with(function_name) {
265            in_function = true;
266            function_found = true;
267        }
268
269        if in_function {
270            for c in line.chars() {
271                if c == '{' {
272                    curly_braces += 1;
273                } else if c == '}' {
274                    curly_braces -= 1;
275                }
276            }
277        }
278        if curly_braces == 0 {
279            if !in_function {
280                new_wgsl += format!("{}\n", line).as_str();
281            }
282            in_function = false;
283        }
284    }
285
286    if !function_found {
287        return Err(ShaderProcessingError::ShaderError(format!(
288            "Function {} not found in shader",
289            function_name
290        )));
291    }
292
293    Ok(new_wgsl)
294}
295
296fn wgsl_has_function(wgsl: &str, function_name: &str) -> Result<bool, ShaderProcessingError> {
297    let lines = wgsl.lines();
298    let mut function_found = false;
299    for line in lines {
300        let line = line.trim();
301        if line.starts_with(format!("fn {function_name}(").as_str()) {
302            function_found = true;
303            break;
304        }
305    }
306
307    if !function_found {
308        return Err(ShaderProcessingError::ShaderError(format!(
309            "Function {} not found in shader",
310            function_name
311        )));
312    }
313
314    Ok(true)
315}
316
317fn rename_function_in_wgsl(
318    wgsl: &str,
319    old_function_name: &str,
320    new_function_name: &str,
321) -> Result<String, ShaderProcessingError> {
322    // find function name in wgsl
323    let lines = wgsl.lines();
324    let mut new_wgsl = String::new();
325    let mut in_function = false;
326    let mut function_found = false;
327    for line in lines {
328        let line = line.trim();
329        if line.starts_with(format!("fn {old_function_name}(").as_str()) {
330            in_function = true;
331            function_found = true;
332            new_wgsl += line
333                .replacen(old_function_name, new_function_name, 1)
334                .as_str();
335        } else {
336            new_wgsl += format!("{}\n", line).as_str();
337        }
338
339        if in_function && line.starts_with('}') {
340            in_function = false;
341        }
342    }
343
344    if !function_found {
345        return Err(ShaderProcessingError::ShaderError(format!(
346            "Function `{}` not found in shader",
347            old_function_name
348        )));
349    }
350
351    Ok(new_wgsl)
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a user function `mainImage`, a helper `main_1`, and an `@fragment`
    /// entry point `main` that calls the helper.
    const THREE_FUNCTION_WGSL: &str = r#"        
fn mainImage(fragColor: ptr<function, vec4<f32>>, fragCoord: vec2<f32>) {
    var fragCoord_1: vec2<f32>;

    fragCoord_1 = fragCoord;
    return;
}

fn main_1() {
    return;
}

@fragment
fn main() {
    main_1();
    return;
}
"#;

    #[test]
    fn remove_function() {
        let without_helper = remove_function_from_wgsl(THREE_FUNCTION_WGSL, "fn main_1()").unwrap();
        assert!(without_helper.contains("fn mainImage(fragColor"));
        assert!(!without_helper.contains("fn main_1()"));

        let without_entry = remove_function_from_wgsl(&without_helper, "fn main(").unwrap();
        assert!(without_entry.contains("fn mainImage(fragColor"));
        println!("{}", without_entry);
        assert!(!without_entry.contains("fn main()"));

        // With every function gone, only the entry point's attribute line is left.
        let leftover = remove_function_from_wgsl(&without_entry, "fn mainImage(").unwrap();
        assert_eq!(leftover.trim(), "@fragment");
    }

    #[test]
    fn rename_function() {
        let declaration = r#"fn normal(p_4: vec3<f32>, epsilon: f32) -> vec3<f32>"#;
        let renamed = rename_function_in_wgsl(declaration, "normal", "sdf3d_normal").unwrap();

        assert!(renamed.contains("fn sdf3d_normal(p_4: vec3<f32>, epsilon: f32) -> vec3<f32>"));
    }

    /// GLSL body for `test_naga`: a displaced sphere SDF, a central-difference
    /// normal built from it, and an empty `mainImage`.
    const SDF_GLSL: &str = r#"
vec3 c = vec3(0.0, 0.0, 0.0);
const float r = 1.0;
float distance_from_sphere(vec3 p, vec3 c, float r)
{
    return distance(p, c) - r;
}

float sdf3d(vec3 p)
{
    float sphere_0 = distance_from_sphere(p, c, r);
    
    // set displacement
    float displacement = sin(5.0 * p.x) * sin(5.0 * p.y) * sin(5.0 * p.z) * 0.25 * sin(2.f * iTime);
    
    return sphere_0 + displacement;
}

vec3 sdf3d_normal(in vec3 p, in float epsilon)
{
    const vec3 small_step = vec3(epsilon, 0.0, 0.0);

    float gradient_x = sdf3d(p + small_step.xyy) - sdf3d(p - small_step.xyy);
    float gradient_y = sdf3d(p + small_step.yxy) - sdf3d(p - small_step.yxy);
    float gradient_z = sdf3d(p + small_step.yyx) - sdf3d(p - small_step.yyx);

    vec3 normal = vec3(gradient_x, gradient_y, gradient_z);

    return normalize(normal);
}

void mainImage( out vec4 fragColor, in vec2 fragCoord ) {}

"#;

    #[test]
    fn test_naga() {
        // The shader only parses if it has a `main`, so an empty one is appended.
        // The actual entry point is added later via the dualcontour.wgsl shader.
        let glsl = [
            "#version 450 core\n",
            Shader::default_uniform_block(),
            SDF_GLSL,
            " void main() {}",
        ]
        .concat();

        let wgsl = convert_glsl_to_wgsl(&glsl).unwrap();

        println!("{}", wgsl);
    }
}