1use std::path::Path;
5
6use derive_more::Constructor;
7use enumflags2::BitFlags;
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote, TokenStreamExt};
10use syn::{Ident, Index};
11
12use crate::generate::quote_naga_capabilities;
13use crate::naga_util::module_to_source;
14use crate::quote_gen::create_shader_raw_string_literal;
15use crate::{
16 sanitize_and_pascal_case, WgslBindgenOption, WgslEntryResult, WgslShaderSourceType,
17};
18
19impl<'a> WgslEntryResult<'a> {
20 fn get_label(&self) -> TokenStream {
21 let get_label = || {
22 self
23 .source_including_deps
24 .source_file
25 .file_path
26 .file_name()?
27 .to_str()
28 };
29
30 match get_label() {
31 Some(label) => quote!(Some(#label)),
32 None => quote!(None),
33 }
34 }
35}
36
37impl WgslShaderSourceType {
38 pub(crate) fn create_shader_module_fn_name(&self) -> &'static str {
39 use WgslShaderSourceType::*;
40 match self {
41 EmbedSource => "create_shader_module_embed_source",
42 EmbedWithNagaOilComposer => "create_shader_module_embedded",
43 ComposerWithRelativePath => "create_shader_module_relative_path",
44 }
45 }
46
47 pub(crate) fn load_shader_module_fn_name(&self) -> Ident {
48 use WgslShaderSourceType::*;
49 match self {
50 ComposerWithRelativePath => format_ident!("load_naga_module_from_path"),
51 _ => format_ident!("load_shader_module_embedded"),
52 }
53 }
54
55 pub(crate) fn create_compute_pipeline_fn_name(&self, name: &str) -> Ident {
56 use WgslShaderSourceType::*;
57 match self {
58 EmbedSource => format_ident!("create_{}_pipeline_embed_source", name),
59 EmbedWithNagaOilComposer => {
60 format_ident!("create_{}_pipeline_embedded", name)
61 }
62 ComposerWithRelativePath => {
63 format_ident!("create_{}_pipeline_relative_path", name)
64 }
65 }
66 }
67
68 pub(crate) fn get_return_type(&self, type_to_return: TokenStream) -> TokenStream {
69 use WgslShaderSourceType::*;
70 match self {
71 EmbedSource => type_to_return,
72 EmbedWithNagaOilComposer | ComposerWithRelativePath => {
73 quote!(Result<#type_to_return, naga_oil::compose::ComposerError>)
74 }
75 }
76 }
77
78 pub(crate) fn wrap_return_stmt(&self, stm: TokenStream) -> TokenStream {
79 use WgslShaderSourceType::*;
80 match self {
81 EmbedWithNagaOilComposer | ComposerWithRelativePath => quote!(Ok(#stm)),
82 _ => stm,
83 }
84 }
85
86 pub(crate) fn get_propagate_operator(&self) -> TokenStream {
87 use WgslShaderSourceType::*;
88 match self {
89 EmbedWithNagaOilComposer | ComposerWithRelativePath => quote!(?),
90 _ => quote!(),
91 }
92 }
93
94 pub(crate) fn add_composable_naga_module_stmt(
95 &self,
96 source: TokenStream,
97 relative_file_path: String,
98 as_name_assignment: TokenStream,
99 ) -> TokenStream {
100 use WgslShaderSourceType::*;
101
102 match self {
103 EmbedWithNagaOilComposer | ComposerWithRelativePath => quote! {
104 composer.add_composable_module(
105 naga_oil::compose::ComposableModuleDescriptor {
106 source: #source,
107 file_path: #relative_file_path,
108 language: naga_oil::compose::ShaderLanguage::Wgsl,
109 shader_defs: shader_defs.clone(),
110 #as_name_assignment,
111 ..Default::default()
112 }
113 )?;
114 },
115 _ => panic!("Not supported"),
116 }
117 }
118
119 pub(crate) fn generate_make_naga_module_statement(
120 &self,
121 source: TokenStream,
122 relative_file_path: String,
123 ) -> TokenStream {
124 use WgslShaderSourceType::*;
125 match self {
126 EmbedWithNagaOilComposer | ComposerWithRelativePath => quote! {
127 composer.make_naga_module(naga_oil::compose::NagaModuleDescriptor {
128 source: #source,
129 file_path: #relative_file_path,
130 shader_defs,
131 ..Default::default()
132 })
133 },
134 _ => panic!("Not supported"),
135 }
136 }
137
138 pub(crate) fn shader_module_params_defs_and_params(
139 &self,
140 ) -> (TokenStream, TokenStream) {
141 use WgslShaderSourceType::*;
142 match self {
143 EmbedSource => {
144 let param_defs = quote!(device: &wgpu::Device);
145 let params = quote!(device);
146 (param_defs, params)
147 }
148 EmbedWithNagaOilComposer => {
149 let param_defs = quote! {
150 device: &wgpu::Device,
151 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>
152 };
153 let params = quote!(device, shader_defs);
154 (param_defs, params)
155 }
156 ComposerWithRelativePath => {
157 let param_defs = quote! {
158 device: &wgpu::Device,
159 base_dir: &str,
160 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
161 load_file: impl Fn(&str) -> Result<String, std::io::Error>
162 };
163 let params = quote!(device, base_dir, shader_defs, load_file);
164 (param_defs, params)
165 }
166 }
167 }
168}
169
170#[derive(Constructor)]
171struct ComputeModuleBuilder<'a> {
172 module: &'a naga::Module,
173 source_type_flags: BitFlags<WgslShaderSourceType>,
174}
175
176impl<'a> ComputeModuleBuilder<'a> {
177 fn build_compute_pipeline_fn(
178 e: &naga::EntryPoint,
179 source_type: WgslShaderSourceType,
180 ) -> TokenStream {
181 let pipeline_name = source_type.create_compute_pipeline_fn_name(&e.name);
184
185 let entry_point = &e.name;
186 let label = format!("Compute Pipeline {}", e.name);
188
189 let create_shader_module_fn_name =
190 format_ident!("{}", source_type.create_shader_module_fn_name());
191
192 let (param_defs, params) = source_type.shader_module_params_defs_and_params();
193
194 let return_type = source_type.get_return_type(quote!(wgpu::ComputePipeline));
195 let propagate_operator = source_type.get_propagate_operator();
196
197 let module_creation = quote! {
198 let module = super::#create_shader_module_fn_name(#params)#propagate_operator;
199 };
200
201 let return_value = source_type.wrap_return_stmt(quote! {
202 device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
203 label: Some(#label),
204 layout: Some(&layout),
205 module: &module,
206 entry_point: Some(#entry_point),
207 compilation_options: Default::default(),
208 cache: None,
209 })
210 });
211
212 quote! {
213 pub fn #pipeline_name(#param_defs) -> #return_type {
214 #module_creation
215 let layout = super::create_pipeline_layout(device);
216 #return_value
217 }
218 }
219 }
220
221 fn workgroup_size(e: &naga::EntryPoint) -> TokenStream {
222 let name = format_ident!("{}_WORKGROUP_SIZE", e.name.to_uppercase());
224 let [x, y, z] = e.workgroup_size.map(|s| Index::from(s as usize));
225 quote!(pub const #name: [u32; 3] = [#x, #y, #z];)
226 }
227
228 pub(crate) fn entry_points_iter(&self) -> impl Iterator<Item = &naga::EntryPoint> {
229 self
230 .module
231 .entry_points
232 .iter()
233 .filter(|e| e.stage == naga::ShaderStage::Compute)
234 }
235
236 fn build(&self) -> TokenStream {
237 let entry_points: Vec<_> = self
238 .entry_points_iter()
239 .map(|e| {
240 let workgroup_size_constant = Self::workgroup_size(e);
241
242 let create_pipeline_fns = self
243 .source_type_flags
244 .iter()
245 .map(|source_type| Self::build_compute_pipeline_fn(e, source_type))
246 .collect::<Vec<_>>();
247
248 quote! {
249 #workgroup_size_constant
250 #(#create_pipeline_fns)*
251 }
252 })
253 .collect();
254
255 if entry_points.is_empty() {
256 quote!()
258 } else {
259 quote! {
260 pub mod compute {
261 use super::{_root, _root::*};
262 #(#entry_points)*
263 }
264 }
265 }
266 }
267}
268pub(crate) fn compute_module(
269 module: &naga::Module,
270 source_type_flags: BitFlags<WgslShaderSourceType>,
271) -> TokenStream {
272 ComputeModuleBuilder::new(module, source_type_flags).build()
273}
274
275fn generate_shader_module_embedded(entry: &WgslEntryResult) -> TokenStream {
276 let shader_content = module_to_source(&entry.naga_module).unwrap();
277 let create_shader_module_fn =
278 format_ident!("{}", WgslShaderSourceType::EmbedSource.create_shader_module_fn_name());
279 let shader_literal = create_shader_raw_string_literal(&shader_content);
280 let shader_label = entry.get_label();
281 let create_shader_module = quote! {
282 pub fn #create_shader_module_fn(device: &wgpu::Device) -> wgpu::ShaderModule {
283 let source = std::borrow::Cow::Borrowed(SHADER_STRING);
284 device.create_shader_module(wgpu::ShaderModuleDescriptor {
285 label: #shader_label,
286 source: wgpu::ShaderSource::Wgsl(source)
287 })
288 }
289 };
290 let shader_str_def = quote!(pub const SHADER_STRING: &str = #shader_literal;);
291
292 quote! {
293 #create_shader_module
294 #shader_str_def
295 }
296}
297
/// Generates the `naga_oil`-composer-based shader module items for one entry
/// point, in the flavour selected by `source_type`.
///
/// `shader_module` constructs it for `EmbedWithNagaOilComposer` and
/// `ComposerWithRelativePath`.
struct ComposeShaderModuleBuilder<'a, 'b> {
    /// The entry point whose module functions are generated.
    entry: &'a WgslEntryResult<'b>,
    /// When set, the generated code calls `Composer::with_capabilities` with these.
    capabilities: Option<naga::valid::Capabilities>,
    /// Path of the entry's own source file (taken from `entry` in `new`).
    entry_source_path: &'a Path,
    /// Directory that generated `include_str!` paths are made relative to.
    output_dir: &'a Path,
    /// Directory that `SHADER_ENTRY_PATH` is made relative to (relative-path flavour).
    workspace_root: &'a Path,
    /// Which generated-code flavour this builder emits.
    source_type: WgslShaderSourceType,
    // NOTE(review): populated by `new` but not read by any method in this
    // chunk; the generated functions take `shader_defs` as a runtime parameter
    // instead. Confirm whether it is still needed.
    shader_defs: crate::FastIndexMap<String, naga_oil::compose::ShaderDefValue>,
}
307
308impl<'a, 'b> ComposeShaderModuleBuilder<'a, 'b> {
309 fn new(
310 entry: &'a WgslEntryResult<'b>,
311 capabilities: Option<naga::valid::Capabilities>,
312 output_dir: &'a Path,
313 workspace_root: &'a Path,
314 source_type: WgslShaderSourceType,
315 shader_defs: &[(String, naga_oil::compose::ShaderDefValue)],
316 ) -> Self {
317 let entry_source_path = entry.source_including_deps.source_file.file_path.as_path();
318
319 let shader_defs_map: crate::FastIndexMap<String, naga_oil::compose::ShaderDefValue> =
321 shader_defs.iter().cloned().collect();
322
323 Self {
324 entry,
325 capabilities,
326 output_dir,
327 workspace_root,
328 source_type,
329 entry_source_path,
330 shader_defs: shader_defs_map,
331 }
332 }
333
334 fn generate_constants_for_paths(&self) -> TokenStream {
335 use WgslShaderSourceType::*;
336
337 match self.source_type {
338 ComposerWithRelativePath => {
339 let shader_entry_path =
340 get_path_relative_to(self.workspace_root, self.entry_source_path);
341 quote! {
342 pub const SHADER_ENTRY_PATH: &str = #shader_entry_path;
343 }
344 }
345 _ => quote!(),
346 }
347 }
348
349 fn create_shader_module_fn_name(&self) -> Ident {
350 let name = self.source_type.create_shader_module_fn_name();
351 format_ident!("{}", name)
352 }
353
354 fn build_shader_dependency_modules_statements(&self) -> Vec<TokenStream> {
355 let dependency_modules = self
356 .entry
357 .source_including_deps
358 .full_dependencies
359 .iter()
360 .map(|dep| {
361 let as_name = dep
362 .module_name
363 .as_ref()
364 .map(|name| name.to_string())
365 .unwrap();
366 let as_name_assignment = quote! { as_name: Some(#as_name.into()) };
367
368 let relative_file_path = get_path_relative_to(self.output_dir, &dep.file_path);
369 let source = quote!(include_str!(#relative_file_path));
370
371 self.source_type.add_composable_naga_module_stmt(
372 source,
373 relative_file_path,
374 as_name_assignment,
375 )
376 })
377 .collect::<Vec<_>>();
378
379 dependency_modules
380 }
381
382 fn build_load_shader_module_fn(&self) -> TokenStream {
383 use WgslShaderSourceType::*;
384
385 let load_shader_module_fn_name = self.source_type.load_shader_module_fn_name();
386 let return_type = self.source_type.get_return_type(quote!(wgpu::naga::Module));
387
388 match self.source_type {
389 ComposerWithRelativePath => {
390 quote!()
392 }
393 _ => {
394 let dependency_modules = self.build_shader_dependency_modules_statements();
396 let relative_file_path =
397 get_path_relative_to(self.output_dir, self.entry_source_path);
398
399 let source = quote!(include_str!(#relative_file_path));
400
401 let make_naga_module_stmt = self
402 .source_type
403 .generate_make_naga_module_statement(source, relative_file_path);
404
405 quote! {
406 pub fn #load_shader_module_fn_name(
407 composer: &mut naga_oil::compose::Composer,
408 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>
409 ) -> #return_type {
410 #(#dependency_modules)*
411 #make_naga_module_stmt
412 }
413 }
414 }
415 }
416 }
417
418 fn create_shader_module_fn(&self) -> TokenStream {
419 use WgslShaderSourceType::*;
420
421 let create_shader_module_fn = self.create_shader_module_fn_name();
422 let load_shader_module_fn_name = self.source_type.load_shader_module_fn_name();
423 let shader_label = self.entry.get_label();
424
425 let shader_enum_variant = self.entry.get_shader_variant();
426 let return_type = self.source_type.get_return_type(quote!(wgpu::ShaderModule));
427 let propagate_operator = self.source_type.get_propagate_operator();
428 let return_stmt = self.source_type.wrap_return_stmt(quote! { shader_module });
429
430 let composer = quote!(naga_oil::compose::Composer::default());
431
432 let composer_with_capabilities = match self.capabilities {
433 Some(capabilities) => {
434 let capabilities_expr = quote_naga_capabilities(capabilities);
435 quote! {
436 #composer.with_capabilities(#capabilities_expr)
437 }
438 }
439 None => quote! {
440 #composer
441 },
442 };
443
444 match self.source_type {
445 ComposerWithRelativePath => {
446 quote! {
447 pub fn #create_shader_module_fn(
448 device: &wgpu::Device,
449 base_dir: &str,
450 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
451 load_file: impl Fn(&str) -> Result<String, std::io::Error>,
452 ) -> #return_type
453 {
454 let mut composer = #composer_with_capabilities;
455 let module = ShaderEntry::#shader_enum_variant.load_naga_module_from_path(base_dir, &mut composer, shader_defs, load_file).map_err(|e| {
456 naga_oil::compose::ComposerError {
457 inner: naga_oil::compose::ComposerErrorInner::ImportNotFound(e, 0),
458 source: naga_oil::compose::ErrSource::Constructing {
459 path: "load_naga_module_from_path".to_string(),
460 source: "Generated code".to_string(),
461 offset: 0,
462 },
463 }
464 })?;
465
466 let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
468 label: #shader_label,
469 source: wgpu::ShaderSource::Naga(std::borrow::Cow::Owned(module))
470 });
471
472 #return_stmt
473 }
474 }
475 }
476 _ => {
477 quote! {
478 pub fn #create_shader_module_fn(
479 device: &wgpu::Device,
480 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>
481 ) -> #return_type {
482
483 let mut composer = #composer_with_capabilities;
484 let module = #load_shader_module_fn_name (&mut composer, shader_defs) #propagate_operator;
485
486 let info = wgpu::naga::valid::Validator::new(
488 wgpu::naga::valid::ValidationFlags::empty(),
489 wgpu::naga::valid::Capabilities::all(),
490 )
491 .validate(&module)
492 .unwrap();
493
494 let shader_string = wgpu::naga::back::wgsl::write_string(
496 &module,
497 &info,
498 wgpu::naga::back::wgsl::WriterFlags::empty(),
499 ).expect("failed to convert naga module to source");
500
501 let source = std::borrow::Cow::Owned(shader_string);
502 let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
503 label: #shader_label,
504 source: wgpu::ShaderSource::Wgsl(source)
505 });
506
507 #return_stmt
508 }
509 }
510 }
511 }
512 }
513
514 fn build(&self) -> TokenStream {
515 use WgslShaderSourceType::*;
516
517 let constants = self.generate_constants_for_paths();
518 let load_shader_module_fn = self.build_load_shader_module_fn();
519 let create_shader_module_fn = self.create_shader_module_fn();
520
521 quote! {
522 #constants
523 #load_shader_module_fn
524 #create_shader_module_fn
525 }
526 }
527}
528
529pub(crate) fn generate_global_load_naga_module_from_path() -> TokenStream {
530 quote! {
531 pub fn visit_shader_files(
547 &self,
548 base_dir: &str,
549 load_file: impl Fn(&str) -> Result<String, std::io::Error>,
550 mut visitor: impl FnMut(&str, &str),
551 ) -> Result<(), String> {
552 fn visit_dependencies_recursive(
553 base_dir: &str,
554 source: &str,
555 current_path: &str,
556 load_file: &impl Fn(&str) -> Result<String, std::io::Error>,
557 visitor: &mut impl FnMut(&str, &str),
558 visited: &mut std::collections::HashSet<String>,
559 ) -> Result<(), String> {
560 let (_, imports, _) = naga_oil::compose::get_preprocessor_data(source);
562
563 for import in imports {
564 let import_path = if import.import.starts_with('\"') {
565 import.import
567 .chars()
568 .skip(1)
569 .take_while(|c| *c != '\"')
570 .collect::<String>()
571 } else {
572 let module_path = import.import.split("::").collect::<Vec<_>>().join(std::path::MAIN_SEPARATOR_STR);
574 format!("{module_path}.wgsl")
575 };
576
577 let full_import_path = if import_path.starts_with('/') || import_path.starts_with('\\') {
580 format!("{base_dir}{import_path}")
581 } else {
582 std::path::Path::new(base_dir).join(import_path).display().to_string()
584 };
585
586 if visited.contains(&full_import_path) {
588 continue;
589 }
590
591 visited.insert(full_import_path.clone());
592
593 let import_source = match load_file(&full_import_path) {
595 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
596 continue;
597 }
598 Err(err) => {
599 return Err(format!("Failed to load import file {full_import_path}: {err}"));
600 }
601 Ok(content) => content,
602 };
603
604 visit_dependencies_recursive(
606 base_dir,
607 &import_source,
608 full_import_path.trim_start_matches(&format!("{base_dir}/")),
609 load_file,
610 visitor,
611 visited,
612 )?;
613
614 visitor(&full_import_path, &import_source);
616 }
617
618 Ok(())
619 }
620
621 let entry_path = format!("{}/{}", base_dir, self.relative_path());
623 let entry_source = load_file(&entry_path)
624 .map_err(|e| format!("Failed to load entry point {entry_path}: {e}"))?;
625
626 visitor(&entry_path, &entry_source);
628
629 let mut visited = std::collections::HashSet::new();
631 visit_dependencies_recursive(
632 base_dir,
633 &entry_source,
634 self.relative_path(),
635 &load_file,
636 &mut visitor,
637 &mut visited,
638 )?;
639
640 Ok(())
641 }
642
643 pub fn load_naga_module_from_path_contents(
644 &self,
645 base_dir: &str,
646 composer: &mut naga_oil::compose::Composer,
647 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
648 files: Vec<(String, String)>,
649 ) -> Result<wgpu::naga::Module, naga_oil::compose::ComposerError>
650 {
651 let entry_path = format!("{}/{}", base_dir, self.relative_path());
653
654 for (file_path, file_content) in &files {
655 if *file_path == entry_path {
656 continue; }
658
659 let relative_path = file_path.trim_start_matches(&format!("{base_dir}/"));
661 let as_name = std::path::Path::new(relative_path)
662 .with_extension("")
663 .with_extension("")
664 .iter()
665 .flat_map(|s| s.to_str())
666 .collect::<Vec<_>>()
667 .join("::")
668 .to_string();
669
670 composer.add_composable_module(naga_oil::compose::ComposableModuleDescriptor {
671 source: file_content,
672 file_path: relative_path,
673 language: naga_oil::compose::ShaderLanguage::Wgsl,
674 shader_defs: shader_defs.clone(),
675 as_name: Some(as_name),
676 ..Default::default()
677 })?;
678 }
679
680 let (_, entry_source) = &files[0];
682
683 composer.make_naga_module(naga_oil::compose::NagaModuleDescriptor {
685 source: entry_source,
686 file_path: self.relative_path(),
687 shader_defs,
688 ..Default::default()
689 })
690 }
691
692 pub fn load_naga_module_from_path(
693 &self,
694 base_dir: &str,
695 composer: &mut naga_oil::compose::Composer,
696 shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>,
697 load_file: impl Fn(&str) -> Result<String, std::io::Error>,
698 ) -> Result<wgpu::naga::Module, String>
699 {
700 let mut files = Vec::<(String, String)>::new();
701 self.visit_shader_files(base_dir, &load_file, |file_path, file_content| {
702 files.push((file_path.to_string(), file_content.to_string()));
703 })?;
704 self.load_naga_module_from_path_contents(base_dir, composer, shader_defs, files)
705 .map_err(|e| format!("{e}"))
706 }
707 }
708}
709
710pub(crate) fn shader_module(
711 entry: &WgslEntryResult,
712 options: &WgslBindgenOption,
713) -> TokenStream {
714 use WgslShaderSourceType::*;
715 let source_type = options.shader_source_type;
716 let output_dir = options
717 .output
718 .as_ref()
719 .and_then(|output_file| output_file.parent().map(|p| p.to_path_buf()))
720 .unwrap_or_else(|| {
721 std::env::var("CARGO_MANIFEST_DIR")
722 .unwrap_or_else(|_| ".".into())
723 .into()
724 });
725
726 let mut token_stream = TokenStream::new();
727
728 if source_type.contains(EmbedSource) {
729 token_stream.append_all(generate_shader_module_embedded(entry));
730 }
731
732 let capabilities = options.ir_capabilities;
733
734 if source_type.contains(EmbedWithNagaOilComposer) {
735 let builder = ComposeShaderModuleBuilder::new(
736 entry,
737 capabilities,
738 &output_dir,
739 &output_dir,
740 EmbedWithNagaOilComposer,
741 &options.shader_defs,
742 );
743 token_stream.append_all(builder.build());
744 }
745
746 if source_type.contains(ComposerWithRelativePath) {
747 let builder = ComposeShaderModuleBuilder::new(
748 entry,
749 capabilities,
750 &output_dir,
751 &options.workspace_root,
752 ComposerWithRelativePath,
753 &options.shader_defs,
754 );
755 token_stream.append_all(builder.build());
756 }
757
758 token_stream
759}
760
761fn get_path_relative_to(relative_to: &std::path::Path, file: &std::path::Path) -> String {
762 pathdiff::diff_paths(file, relative_to)
763 .expect("failed to get relative path")
764 .to_str()
765 .unwrap()
766 .to_string()
767}
768
/// Turns `name` into an identifier-friendly variable name.
///
/// Processing steps:
/// 1. `::` and spaces become `_`.
/// 2. Every remaining character that is neither alphanumeric nor `_` is
///    dropped.
/// 3. The result is upper-cased when `is_const` is true and lower-cased
///    otherwise.
fn create_canonical_variable_name(name: &str, is_const: bool) -> String {
    // "::" must be replaced before filtering, otherwise its colons would just vanish.
    let underscored = name.replace("::", "_").replace(' ', "_");

    let sanitized: String = underscored
        .chars()
        .filter(|&ch| ch == '_' || ch.is_alphanumeric())
        .collect();

    match is_const {
        true => sanitized.to_uppercase(),
        false => sanitized.to_lowercase(),
    }
}
783
#[cfg(test)]
mod tests {
    use indoc::indoc;

    use super::*;
    use crate::assert_tokens_snapshot;

    // Separator replacement (`::` and spaces become `_`) combined with
    // const (upper-case) vs non-const (lower-case) output.
    #[test]
    fn test_create_canonical_variable_name() {
        assert_eq!("foo", create_canonical_variable_name("Foo", false));
        assert_eq!("FOO", create_canonical_variable_name("Foo", true));
        assert_eq!("foo_bar", create_canonical_variable_name("Foo::Bar", false));
        assert_eq!("FOO_BAR", create_canonical_variable_name("Foo::Bar", true));
        assert_eq!("foo_bar", create_canonical_variable_name("Foo Bar", false));
        assert_eq!("FOO_BAR", create_canonical_variable_name("Foo Bar", true));
    }

    // A module with only a vertex entry point has no compute entry points,
    // so `compute_module` takes its empty-output branch.
    #[test]
    fn write_compute_module_empty() {
        let source = indoc! {r#"
            @vertex
            fn main() {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = compute_module(&module, WgslShaderSourceType::EmbedSource.into());

        assert_tokens_snapshot!(actual);
    }

    // Two compute entry points: one with a 3-component and one with a
    // 1-component `@workgroup_size`, each producing a constant and a
    // pipeline constructor.
    #[test]
    fn write_compute_module_multiple_entries() {
        let source = indoc! {r#"
            @compute
            @workgroup_size(1,2,3)
            fn main1() {}

            @compute
            @workgroup_size(256)
            fn main2() {}
        "#
        };

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = compute_module(&module, WgslShaderSourceType::EmbedSource.into());

        assert_tokens_snapshot!(actual);
    }
}
832}