wgsl_bindgen/generate/
shader_module.rs

//! Generates the per-entry shader-module items: the `create_shader_module_*`
//! functions (with their naga_oil loaders) and the `compute` pipeline submodule.
3
4use std::path::Path;
5
6use derive_more::Constructor;
7use enumflags2::BitFlags;
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote, TokenStreamExt};
10use syn::{Ident, Index};
11
12use crate::generate::quote_naga_capabilities;
13use crate::naga_util::module_to_source;
14use crate::quote_gen::create_shader_raw_string_literal;
15use crate::{
16  sanitize_and_pascal_case, WgslBindgenOption, WgslEntryResult, WgslShaderSourceType,
17};
18
19impl<'a> WgslEntryResult<'a> {
20  fn get_label(&self) -> TokenStream {
21    let get_label = || {
22      self
23        .source_including_deps
24        .source_file
25        .file_path
26        .file_name()?
27        .to_str()
28    };
29
30    match get_label() {
31      Some(label) => quote!(Some(#label)),
32      None => quote!(None),
33    }
34  }
35}
36
37impl WgslShaderSourceType {
38  pub(crate) fn create_shader_module_fn_name(&self) -> &'static str {
39    use WgslShaderSourceType::*;
40    match self {
41      EmbedSource => "create_shader_module_embed_source",
42      EmbedWithNagaOilComposer => "create_shader_module_embedded",
43      ComposerWithRelativePath => "create_shader_module_relative_path",
44    }
45  }
46
47  pub(crate) fn load_shader_module_fn_name(&self) -> Ident {
48    use WgslShaderSourceType::*;
49    match self {
50      ComposerWithRelativePath => format_ident!("load_naga_module_from_path"),
51      _ => format_ident!("load_shader_module_embedded"),
52    }
53  }
54
55  pub(crate) fn create_compute_pipeline_fn_name(&self, name: &str) -> Ident {
56    use WgslShaderSourceType::*;
57    match self {
58      EmbedSource => format_ident!("create_{}_pipeline_embed_source", name),
59      EmbedWithNagaOilComposer => {
60        format_ident!("create_{}_pipeline_embedded", name)
61      }
62      ComposerWithRelativePath => {
63        format_ident!("create_{}_pipeline_relative_path", name)
64      }
65    }
66  }
67
68  pub(crate) fn get_return_type(&self, type_to_return: TokenStream) -> TokenStream {
69    use WgslShaderSourceType::*;
70    match self {
71      EmbedSource => type_to_return,
72      EmbedWithNagaOilComposer | ComposerWithRelativePath => {
73        quote!(Result<#type_to_return, naga_oil::compose::ComposerError>)
74      }
75    }
76  }
77
78  pub(crate) fn wrap_return_stmt(&self, stm: TokenStream) -> TokenStream {
79    use WgslShaderSourceType::*;
80    match self {
81      EmbedWithNagaOilComposer | ComposerWithRelativePath => quote!(Ok(#stm)),
82      _ => stm,
83    }
84  }
85
86  pub(crate) fn get_propagate_operator(&self) -> TokenStream {
87    use WgslShaderSourceType::*;
88    match self {
89      EmbedWithNagaOilComposer | ComposerWithRelativePath => quote!(?),
90      _ => quote!(),
91    }
92  }
93
94  pub(crate) fn add_composable_naga_module_stmt(
95    &self,
96    source: TokenStream,
97    relative_file_path: String,
98    as_name_assignment: TokenStream,
99  ) -> TokenStream {
100    use WgslShaderSourceType::*;
101
102    match self {
103      EmbedWithNagaOilComposer | ComposerWithRelativePath => quote! {
104        composer.add_composable_module(
105          naga_oil::compose::ComposableModuleDescriptor {
106            source: #source,
107            file_path: #relative_file_path,
108            language: naga_oil::compose::ShaderLanguage::Wgsl,
109            shader_defs: shader_defs.clone(),
110            #as_name_assignment,
111            ..Default::default()
112          }
113        )?;
114      },
115      _ => panic!("Not supported"),
116    }
117  }
118
119  pub(crate) fn generate_make_naga_module_statement(
120    &self,
121    source: TokenStream,
122    relative_file_path: String,
123  ) -> TokenStream {
124    use WgslShaderSourceType::*;
125    match self {
126      EmbedWithNagaOilComposer | ComposerWithRelativePath => quote! {
127        composer.make_naga_module(naga_oil::compose::NagaModuleDescriptor {
128          source: #source,
129          file_path: #relative_file_path,
130          shader_defs,
131          ..Default::default()
132        })
133      },
134      _ => panic!("Not supported"),
135    }
136  }
137
138  pub(crate) fn shader_module_params_defs_and_params(
139    &self,
140  ) -> (TokenStream, TokenStream) {
141    use WgslShaderSourceType::*;
142    match self {
143      EmbedSource => {
144        let param_defs = quote!(device: &wgpu::Device);
145        let params = quote!(device);
146        (param_defs, params)
147      }
148      EmbedWithNagaOilComposer => {
149        let param_defs = quote! {
150          device: &wgpu::Device,
151          shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>
152        };
153        let params = quote!(device, shader_defs);
154        (param_defs, params)
155      }
156      ComposerWithRelativePath => {
157        let param_defs = quote! {
158          device: &wgpu::Device,
159          base_dir: &str,
160          shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
161          load_file: impl Fn(&str) -> Result<String, std::io::Error>
162        };
163        let params = quote!(device, base_dir, shader_defs, load_file);
164        (param_defs, params)
165      }
166    }
167  }
168}
169
/// Generates the `compute` submodule (workgroup-size constants and pipeline
/// constructors) for the compute-stage entry points of one naga module.
///
/// Field order matters: `#[derive(Constructor)]` generates `new(module, source_type_flags)`.
#[derive(Constructor)]
struct ComputeModuleBuilder<'a> {
  /// Module whose compute-stage entry points are emitted.
  module: &'a naga::Module,
  /// Source types to emit one `create_*_pipeline_*` function for, per entry point.
  source_type_flags: BitFlags<WgslShaderSourceType>,
}
175
176impl<'a> ComputeModuleBuilder<'a> {
177  fn build_compute_pipeline_fn(
178    e: &naga::EntryPoint,
179    source_type: WgslShaderSourceType,
180  ) -> TokenStream {
181    // Compute pipeline creation has few parameters and can be generated.
182
183    let pipeline_name = source_type.create_compute_pipeline_fn_name(&e.name);
184
185    let entry_point = &e.name;
186    // TODO: Include a user supplied module name in the label?
187    let label = format!("Compute Pipeline {}", e.name);
188
189    let create_shader_module_fn_name =
190      format_ident!("{}", source_type.create_shader_module_fn_name());
191
192    let (param_defs, params) = source_type.shader_module_params_defs_and_params();
193
194    let return_type = source_type.get_return_type(quote!(wgpu::ComputePipeline));
195    let propagate_operator = source_type.get_propagate_operator();
196
197    let module_creation = quote! {
198      let module = super::#create_shader_module_fn_name(#params)#propagate_operator;
199    };
200
201    let return_value = source_type.wrap_return_stmt(quote! {
202      device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
203          label: Some(#label),
204          layout: Some(&layout),
205          module: &module,
206          entry_point: Some(#entry_point),
207          compilation_options: Default::default(),
208          cache: None,
209      })
210    });
211
212    quote! {
213        pub fn #pipeline_name(#param_defs) -> #return_type {
214            #module_creation
215            let layout = super::create_pipeline_layout(device);
216            #return_value
217        }
218    }
219  }
220
221  fn workgroup_size(e: &naga::EntryPoint) -> TokenStream {
222    // Use Index to avoid specifying the type on literals.
223    let name = format_ident!("{}_WORKGROUP_SIZE", e.name.to_uppercase());
224    let [x, y, z] = e.workgroup_size.map(|s| Index::from(s as usize));
225    quote!(pub const #name: [u32; 3] = [#x, #y, #z];)
226  }
227
228  pub(crate) fn entry_points_iter(&self) -> impl Iterator<Item = &naga::EntryPoint> {
229    self
230      .module
231      .entry_points
232      .iter()
233      .filter(|e| e.stage == naga::ShaderStage::Compute)
234  }
235
236  fn build(&self) -> TokenStream {
237    let entry_points: Vec<_> = self
238      .entry_points_iter()
239      .map(|e| {
240        let workgroup_size_constant = Self::workgroup_size(e);
241
242        let create_pipeline_fns = self
243          .source_type_flags
244          .iter()
245          .map(|source_type| Self::build_compute_pipeline_fn(e, source_type))
246          .collect::<Vec<_>>();
247
248        quote! {
249            #workgroup_size_constant
250            #(#create_pipeline_fns)*
251        }
252      })
253      .collect();
254
255    if entry_points.is_empty() {
256      // Don't include empty modules.
257      quote!()
258    } else {
259      quote! {
260          pub mod compute {
261              use super::{_root, _root::*};
262              #(#entry_points)*
263          }
264      }
265    }
266  }
267}
268pub(crate) fn compute_module(
269  module: &naga::Module,
270  source_type_flags: BitFlags<WgslShaderSourceType>,
271) -> TokenStream {
272  ComputeModuleBuilder::new(module, source_type_flags).build()
273}
274
275fn generate_shader_module_embedded(entry: &WgslEntryResult) -> TokenStream {
276  let shader_content = module_to_source(&entry.naga_module).unwrap();
277  let create_shader_module_fn =
278    format_ident!("{}", WgslShaderSourceType::EmbedSource.create_shader_module_fn_name());
279  let shader_literal = create_shader_raw_string_literal(&shader_content);
280  let shader_label = entry.get_label();
281  let create_shader_module = quote! {
282      pub fn #create_shader_module_fn(device: &wgpu::Device) -> wgpu::ShaderModule {
283          let source = std::borrow::Cow::Borrowed(SHADER_STRING);
284          device.create_shader_module(wgpu::ShaderModuleDescriptor {
285              label: #shader_label,
286              source: wgpu::ShaderSource::Wgsl(source)
287          })
288      }
289  };
290  let shader_str_def = quote!(pub const SHADER_STRING: &str = #shader_literal;);
291
292  quote! {
293    #create_shader_module
294    #shader_str_def
295  }
296}
297
/// Generates the shader-module functions for one entry under a single source type.
/// In this file it is only constructed for the naga_oil composer variants.
struct ComposeShaderModuleBuilder<'a, 'b> {
  /// The shader entry being generated for.
  entry: &'a WgslEntryResult<'b>,
  /// Capabilities passed to `Composer::with_capabilities` in generated code, if any.
  capabilities: Option<naga::valid::Capabilities>,
  /// Path of the entry's WGSL source file.
  entry_source_path: &'a Path,
  /// Directory that generated `include_str!` paths are made relative to.
  output_dir: &'a Path,
  /// Directory that the generated `SHADER_ENTRY_PATH` constant is made relative to.
  workspace_root: &'a Path,
  /// Which source type this builder emits code for.
  source_type: WgslShaderSourceType,
  // NOTE(review): populated in `new` but not read by any method visible in this file;
  // the generated functions take `shader_defs` as a runtime parameter instead.
  shader_defs: crate::FastIndexMap<String, naga_oil::compose::ShaderDefValue>,
}
307
308impl<'a, 'b> ComposeShaderModuleBuilder<'a, 'b> {
309  fn new(
310    entry: &'a WgslEntryResult<'b>,
311    capabilities: Option<naga::valid::Capabilities>,
312    output_dir: &'a Path,
313    workspace_root: &'a Path,
314    source_type: WgslShaderSourceType,
315    shader_defs: &[(String, naga_oil::compose::ShaderDefValue)],
316  ) -> Self {
317    let entry_source_path = entry.source_including_deps.source_file.file_path.as_path();
318
319    // Convert Vec to FastIndexMap for consistent ordering
320    let shader_defs_map: crate::FastIndexMap<String, naga_oil::compose::ShaderDefValue> =
321      shader_defs.iter().cloned().collect();
322
323    Self {
324      entry,
325      capabilities,
326      output_dir,
327      workspace_root,
328      source_type,
329      entry_source_path,
330      shader_defs: shader_defs_map,
331    }
332  }
333
334  fn generate_constants_for_paths(&self) -> TokenStream {
335    use WgslShaderSourceType::*;
336
337    match self.source_type {
338      ComposerWithRelativePath => {
339        let shader_entry_path =
340          get_path_relative_to(self.workspace_root, self.entry_source_path);
341        quote! {
342          pub const SHADER_ENTRY_PATH: &str = #shader_entry_path;
343        }
344      }
345      _ => quote!(),
346    }
347  }
348
349  fn create_shader_module_fn_name(&self) -> Ident {
350    let name = self.source_type.create_shader_module_fn_name();
351    format_ident!("{}", name)
352  }
353
354  fn build_shader_dependency_modules_statements(&self) -> Vec<TokenStream> {
355    let dependency_modules = self
356      .entry
357      .source_including_deps
358      .full_dependencies
359      .iter()
360      .map(|dep| {
361        let as_name = dep
362          .module_name
363          .as_ref()
364          .map(|name| name.to_string())
365          .unwrap();
366        let as_name_assignment = quote! { as_name: Some(#as_name.into()) };
367
368        let relative_file_path = get_path_relative_to(self.output_dir, &dep.file_path);
369        let source = quote!(include_str!(#relative_file_path));
370
371        self.source_type.add_composable_naga_module_stmt(
372          source,
373          relative_file_path,
374          as_name_assignment,
375        )
376      })
377      .collect::<Vec<_>>();
378
379    dependency_modules
380  }
381
382  fn build_load_shader_module_fn(&self) -> TokenStream {
383    use WgslShaderSourceType::*;
384
385    let load_shader_module_fn_name = self.source_type.load_shader_module_fn_name();
386    let return_type = self.source_type.get_return_type(quote!(wgpu::naga::Module));
387
388    match self.source_type {
389      ComposerWithRelativePath => {
390        // For the new variant, we don't generate anything here - the global function handles it
391        quote!()
392      }
393      _ => {
394        // Keep existing implementation for other variants
395        let dependency_modules = self.build_shader_dependency_modules_statements();
396        let relative_file_path =
397          get_path_relative_to(self.output_dir, self.entry_source_path);
398
399        let source = quote!(include_str!(#relative_file_path));
400
401        let make_naga_module_stmt = self
402          .source_type
403          .generate_make_naga_module_statement(source, relative_file_path);
404
405        quote! {
406          pub fn #load_shader_module_fn_name(
407            composer: &mut naga_oil::compose::Composer,
408            shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>
409          ) -> #return_type {
410            #(#dependency_modules)*
411            #make_naga_module_stmt
412          }
413        }
414      }
415    }
416  }
417
418  fn create_shader_module_fn(&self) -> TokenStream {
419    use WgslShaderSourceType::*;
420
421    let create_shader_module_fn = self.create_shader_module_fn_name();
422    let load_shader_module_fn_name = self.source_type.load_shader_module_fn_name();
423    let shader_label = self.entry.get_label();
424
425    let shader_enum_variant = self.entry.get_shader_variant();
426    let return_type = self.source_type.get_return_type(quote!(wgpu::ShaderModule));
427    let propagate_operator = self.source_type.get_propagate_operator();
428    let return_stmt = self.source_type.wrap_return_stmt(quote! { shader_module });
429
430    let composer = quote!(naga_oil::compose::Composer::default());
431
432    let composer_with_capabilities = match self.capabilities {
433      Some(capabilities) => {
434        let capabilities_expr = quote_naga_capabilities(capabilities);
435        quote! {
436          #composer.with_capabilities(#capabilities_expr)
437        }
438      }
439      None => quote! {
440        #composer
441      },
442    };
443
444    match self.source_type {
445      ComposerWithRelativePath => {
446        quote! {
447          pub fn #create_shader_module_fn(
448            device: &wgpu::Device,
449            base_dir: &str,
450            shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
451            load_file: impl Fn(&str) -> Result<String, std::io::Error>,
452          ) -> #return_type
453          {
454            let mut composer = #composer_with_capabilities;
455            let module = ShaderEntry::#shader_enum_variant.load_naga_module_from_path(base_dir, &mut composer, shader_defs, load_file).map_err(|e| {
456              naga_oil::compose::ComposerError {
457                inner: naga_oil::compose::ComposerErrorInner::ImportNotFound(e, 0),
458                source: naga_oil::compose::ErrSource::Constructing {
459                  path: "load_naga_module_from_path".to_string(),
460                  source: "Generated code".to_string(),
461                  offset: 0,
462                },
463              }
464            })?;
465
466            // Use naga-ir feature to create shader module directly from naga module
467            let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
468              label: #shader_label,
469              source: wgpu::ShaderSource::Naga(std::borrow::Cow::Owned(module))
470            });
471
472            #return_stmt
473          }
474        }
475      }
476      _ => {
477        quote! {
478          pub fn #create_shader_module_fn(
479            device: &wgpu::Device,
480            shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>
481          ) -> #return_type {
482
483            let mut composer = #composer_with_capabilities;
484            let module = #load_shader_module_fn_name (&mut composer, shader_defs) #propagate_operator;
485
486            // Mini validation to get module info
487            let info = wgpu::naga::valid::Validator::new(
488              wgpu::naga::valid::ValidationFlags::empty(),
489              wgpu::naga::valid::Capabilities::all(),
490            )
491            .validate(&module)
492            .unwrap();
493
494            // Write to wgsl
495            let shader_string = wgpu::naga::back::wgsl::write_string(
496              &module,
497              &info,
498              wgpu::naga::back::wgsl::WriterFlags::empty(),
499            ).expect("failed to convert naga module to source");
500
501            let source = std::borrow::Cow::Owned(shader_string);
502            let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
503              label: #shader_label,
504              source: wgpu::ShaderSource::Wgsl(source)
505            });
506
507            #return_stmt
508          }
509        }
510      }
511    }
512  }
513
514  fn build(&self) -> TokenStream {
515    use WgslShaderSourceType::*;
516
517    let constants = self.generate_constants_for_paths();
518    let load_shader_module_fn = self.build_load_shader_module_fn();
519    let create_shader_module_fn = self.create_shader_module_fn();
520
521    quote! {
522      #constants
523      #load_shader_module_fn
524      #create_shader_module_fn
525    }
526  }
527}
528
/// Returns the tokens for three runtime methods: `visit_shader_files`,
/// `load_naga_module_from_path_contents` and `load_naga_module_from_path`.
///
/// The emitted methods take `&self` and call `self.relative_path()`, so they must be
/// spliced into an impl block for a type providing `relative_path()`. That type is
/// presumably the generated `ShaderEntry` enum, since `create_shader_module_fn` calls
/// `ShaderEntry::<variant>.load_naga_module_from_path`; confirm at the call site.
///
/// Only `//` comments are used inside the `quote!` below. They are discarded by the
/// lexer, so they do not change the generated tokens. `///` lines inside the
/// template are emitted as `#[doc]` attributes.
pub(crate) fn generate_global_load_naga_module_from_path() -> TokenStream {
  quote! {
    /// Visits and processes all shader files in a dependency tree.
    ///
    /// This function traverses the shader dependency tree and calls the visitor function
    /// for each file encountered. This allows for custom processing like hot reloading,
    /// caching, or debugging.
    ///
    /// # Arguments
    ///
    /// * `base_dir` - The base directory for resolving relative paths
    /// * `load_file` - Function to load file contents from a path
    /// * `visitor` - Function called for each file with (file_path, file_content)
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` if all files were processed successfully, or an error string.
    pub fn visit_shader_files(
      &self,
      base_dir: &str,
      load_file: impl Fn(&str) -> Result<String, std::io::Error>,
      mut visitor: impl FnMut(&str, &str),
    ) -> Result<(), String> {
        // Depth-first walk over the imports of `source`. Each import is visited after
        // its own imports (post-order), and `visited` ensures each resolved path is
        // loaded at most once.
        // NOTE(review): `current_path` is not read in this body, so generated code may
        // trigger an `unused_variables` warning.
        fn visit_dependencies_recursive(
          base_dir: &str,
          source: &str,
          current_path: &str,
          load_file: &impl Fn(&str) -> Result<String, std::io::Error>,
          visitor: &mut impl FnMut(&str, &str),
          visited: &mut std::collections::HashSet<String>,
        ) -> Result<(), String> {
          // Use naga_oil's preprocessor to get import information
          let (_, imports, _) = naga_oil::compose::get_preprocessor_data(source);

          for import in imports {
            let import_path = if import.import.starts_with('\"') {
              // Strip quotes from string literals
              import.import
                .chars()
                .skip(1)
                .take_while(|c| *c != '\"')
                .collect::<String>()
            } else {
              // Module imports such as `global_bindings::time` map to the file
              // `global_bindings/time.wgsl` (joined with the platform separator).
              let module_path = import.import.split("::").collect::<Vec<_>>().join(std::path::MAIN_SEPARATOR_STR);
              format!("{module_path}.wgsl")
            };

            // Resolve import path - simplified to always resolve from base directory
            // This works for both module imports (global_bindings::time) and relative imports
            // NOTE(review): joined paths use the platform separator (`\` on Windows),
            // while the entry path and the `{base_dir}/` prefix stripping use `/`. The
            // prefix may therefore fail to match on Windows; confirm.
            let full_import_path = if import_path.starts_with('/') || import_path.starts_with('\\') {
              format!("{base_dir}{import_path}")
            } else {
              // Use proper path joining for Windows compatibility
              std::path::Path::new(base_dir).join(import_path).display().to_string()
            };

            // Skip if already visited
            if visited.contains(&full_import_path) {
              continue;
            }

            visited.insert(full_import_path.clone());

            // Load the imported file
            // NOTE(review): imports whose file is not found are skipped without an
            // error; presumably they are resolved some other way; confirm.
            let import_source = match load_file(&full_import_path) {
              Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                continue;
              }
              Err(err) => {
                return Err(format!("Failed to load import file {full_import_path}: {err}"));
              }
              Ok(content) => content,
            };

            // Recursively visit its dependencies
            visit_dependencies_recursive(
              base_dir,
              &import_source,
              full_import_path.trim_start_matches(&format!("{base_dir}/")),
              load_file,
              visitor,
              visited,
            )?;

            // Call visitor for the inner most files first
            visitor(&full_import_path, &import_source);
          }

          Ok(())
        }

        // Load entry point source
        let entry_path = format!("{}/{}", base_dir, self.relative_path());
        let entry_source = load_file(&entry_path)
          .map_err(|e| format!("Failed to load entry point {entry_path}: {e}"))?;

        // The entry point is reported first, before any dependency. This is why
        // `load_naga_module_from_path` can treat the first collected file as the entry.
        visitor(&entry_path, &entry_source);

        // Visit all dependencies
        let mut visited = std::collections::HashSet::new();
        visit_dependencies_recursive(
          base_dir,
          &entry_source,
          self.relative_path(),
          &load_file,
          &mut visitor,
          &mut visited,
        )?;

        Ok(())
      }

      // Composes the entry module from already-loaded `(path, content)` pairs. Every
      // non-entry file is registered with `composer` as a composable module (in the
      // order given), then the entry is composed.
      // NOTE(review): the entry source is taken from `files[0]`, so this panics on
      // an empty `files` and assumes the entry comes first, as `visit_shader_files`
      // guarantees.
      pub fn load_naga_module_from_path_contents(
        &self,
        base_dir: &str,
        composer: &mut naga_oil::compose::Composer,
        shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
        files: Vec<(String, String)>,
      ) -> Result<wgpu::naga::Module, naga_oil::compose::ComposerError>
      {
        // Process dependency files first (all except entry point)
        let entry_path = format!("{}/{}", base_dir, self.relative_path());

        for (file_path, file_content) in &files {
          if *file_path == entry_path {
            continue; // Skip entry point, process it last
          }

          // Derive the module name from the path relative to `base_dir`.
          // `with_extension("")` is applied twice, so up to two extensions are stripped
          // (e.g. `a/b.c.wgsl` -> `a::b`).
          let relative_path = file_path.trim_start_matches(&format!("{base_dir}/"));
          let as_name = std::path::Path::new(relative_path)
            .with_extension("")
            .with_extension("")
            .iter()
            .flat_map(|s| s.to_str())
            .collect::<Vec<_>>()
            .join("::")
            .to_string();

          composer.add_composable_module(naga_oil::compose::ComposableModuleDescriptor {
            source: file_content,
            file_path: relative_path,
            language: naga_oil::compose::ShaderLanguage::Wgsl,
            shader_defs: shader_defs.clone(),
            as_name: Some(as_name),
            ..Default::default()
          })?;
        }

        // Get entry point content
        let (_, entry_source) = &files[0];

        // Create the final module
        composer.make_naga_module(naga_oil::compose::NagaModuleDescriptor {
          source: entry_source,
          file_path: self.relative_path(),
          shader_defs,
          ..Default::default()
        })
      }

      // Loads the entry and all of its imports through `load_file`, then composes them
      // with `load_naga_module_from_path_contents`. Composer errors are flattened to
      // their `Display` text.
      pub fn load_naga_module_from_path(
        &self,
        base_dir: &str,
        composer: &mut naga_oil::compose::Composer,
        shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
        load_file: impl Fn(&str) -> Result<String, std::io::Error>,
      ) -> Result<wgpu::naga::Module, String>
      {
        let mut files = Vec::<(String, String)>::new();
        self.visit_shader_files(base_dir, &load_file, |file_path, file_content| {
          files.push((file_path.to_string(), file_content.to_string()));
        })?;
        self.load_naga_module_from_path_contents(base_dir, composer, shader_defs, files)
          .map_err(|e| format!("{e}"))
      }
  }
}
709
710pub(crate) fn shader_module(
711  entry: &WgslEntryResult,
712  options: &WgslBindgenOption,
713) -> TokenStream {
714  use WgslShaderSourceType::*;
715  let source_type = options.shader_source_type;
716  let output_dir = options
717    .output
718    .as_ref()
719    .and_then(|output_file| output_file.parent().map(|p| p.to_path_buf()))
720    .unwrap_or_else(|| {
721      std::env::var("CARGO_MANIFEST_DIR")
722        .unwrap_or_else(|_| ".".into())
723        .into()
724    });
725
726  let mut token_stream = TokenStream::new();
727
728  if source_type.contains(EmbedSource) {
729    token_stream.append_all(generate_shader_module_embedded(entry));
730  }
731
732  let capabilities = options.ir_capabilities;
733
734  if source_type.contains(EmbedWithNagaOilComposer) {
735    let builder = ComposeShaderModuleBuilder::new(
736      entry,
737      capabilities,
738      &output_dir,
739      &output_dir,
740      EmbedWithNagaOilComposer,
741      &options.shader_defs,
742    );
743    token_stream.append_all(builder.build());
744  }
745
746  if source_type.contains(ComposerWithRelativePath) {
747    let builder = ComposeShaderModuleBuilder::new(
748      entry,
749      capabilities,
750      &output_dir,
751      &options.workspace_root,
752      ComposerWithRelativePath,
753      &options.shader_defs,
754    );
755    token_stream.append_all(builder.build());
756  }
757
758  token_stream
759}
760
761fn get_path_relative_to(relative_to: &std::path::Path, file: &std::path::Path) -> String {
762  pathdiff::diff_paths(file, relative_to)
763    .expect("failed to get relative path")
764    .to_str()
765    .unwrap()
766    .to_string()
767}
768
/// Turns `name` into an identifier-like string.
///
/// `::` and spaces become `_`, and any remaining character that is neither
/// alphanumeric nor `_` is dropped. The result is upper-cased when `is_const` is set
/// and lower-cased otherwise.
fn create_canonical_variable_name(name: &str, is_const: bool) -> String {
  let underscored = name.replace("::", "_").replace(' ', "_");
  let sanitized: String = underscored
    .chars()
    .filter(|&ch| ch == '_' || ch.is_alphanumeric())
    .collect();

  match is_const {
    true => sanitized.to_uppercase(),
    false => sanitized.to_lowercase(),
  }
}
783
#[cfg(test)]
mod tests {
  use indoc::indoc;

  use super::*;
  use crate::assert_tokens_snapshot;

  #[test]
  fn test_create_canonical_variable_name() {
    assert_eq!("foo", create_canonical_variable_name("Foo", false));
    assert_eq!("FOO", create_canonical_variable_name("Foo", true));
    assert_eq!("foo_bar", create_canonical_variable_name("Foo::Bar", false));
    assert_eq!("FOO_BAR", create_canonical_variable_name("Foo::Bar", true));
    assert_eq!("foo_bar", create_canonical_variable_name("Foo Bar", false));
    assert_eq!("FOO_BAR", create_canonical_variable_name("Foo Bar", true));
    // Characters that cannot appear in an identifier are dropped entirely.
    assert_eq!("foobarbaz", create_canonical_variable_name("foo-bar.baz", false));
    assert_eq!("A_B_C", create_canonical_variable_name("a::b c", true));
  }

  #[test]
  fn write_compute_module_empty() {
    let source = indoc! {r#"
            @vertex
            fn main() {}
        "#};

    let module = naga::front::wgsl::parse_str(source).unwrap();
    let actual = compute_module(&module, WgslShaderSourceType::EmbedSource.into());

    assert_tokens_snapshot!(actual);
  }

  #[test]
  fn write_compute_module_multiple_entries() {
    let source = indoc! {r#"
            @compute
            @workgroup_size(1,2,3)
            fn main1() {}

            @compute
            @workgroup_size(256)
            fn main2() {}
        "#
    };

    let module = naga::front::wgsl::parse_str(source).unwrap();
    let actual = compute_module(&module, WgslShaderSourceType::EmbedSource.into());

    assert_tokens_snapshot!(actual);
  }
}