1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
mod bindings;
mod types;

use std::path::PathBuf;

pub use bindings::*;
use derive_builder::Builder;
use derive_more::IsVariant;
use enumflags2::{bitflags, BitFlags};
use proc_macro2::TokenStream;
use regex::Regex;
pub use types::*;

use crate::{
  FastIndexMap, WGSLBindgen, WgslBindgenError, WgslType, WgslTypeSerializeStrategy,
};

/// Naga IR validation capabilities to apply to the shader module.
///
/// See [wgpu::naga::valid::Capabilities](https://docs.rs/wgpu/latest/wgpu/naga/valid/struct.Capabilities.html)
/// for the meaning of the individual capability flags.
#[derive(Debug, Clone, Copy)]
pub struct WgslShaderIrCapabilities {
  /// The set of naga validation capabilities to enable.
  pub capabilities: naga::valid::Capabilities,
  /// Shader stages associated with subgroup support.
  /// NOTE(review): presumably forwarded to naga's validator as the stages that
  /// may use subgroup operations — confirm at the use site.
  pub subgroup_stages: naga::valid::ShaderStages,
}

/// The kind(s) of shader source generated in the output.
///
/// This enum is used as bitflags (`BitFlags<WgslShaderSourceType>`), so several
/// variants can be requested at once. `UseEmbed` is the default.
#[bitflags(default = UseEmbed)]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, IsVariant)]
pub enum WgslShaderSourceType {
  /// Preparse the shader modules and embed the final shader string in the output.
  /// This option skips the naga_oil dependency in the output, but doesn't allow shader defines.
  UseEmbed = 0b0001,

  /// Use the Composer with embedded strings for each shader module.
  /// This option allows shader defines but doesn't allow hot-reloading.
  UseComposerEmbed = 0b0010,

  /// Use the Composer with absolute paths to the shaders.
  /// This option allows shader defines and is useful for hot-reloading.
  UseComposerWithPath = 0b0100,
}

/// An extra directory that is scanned for additional source files when
/// generating Rust bindings for WGSL shaders.
#[derive(Default, Clone, Debug)]
pub struct AdditionalScanDirectory {
  /// Optional root prefix or namespace to apply to all shaders given as the
  /// entry points.
  pub module_import_root: Option<String>,
  /// The directory to scan for additional source files.
  pub directory: String,
}

/// Builds an [`AdditionalScanDirectory`] from an `(import root, directory)`
/// pair, copying both string slices into owned `String`s.
impl From<(Option<&str>, &str)> for AdditionalScanDirectory {
  fn from(pair: (Option<&str>, &str)) -> Self {
    let (root, dir) = pair;
    Self {
      directory: String::from(dir),
      module_import_root: root.map(String::from),
    }
  }
}

/// Lookup table from a built-in [`WgslType`] to the Rust tokens emitted for it.
pub type WgslTypeMap = FastIndexMap<WgslType, TokenStream>;

/// Something that can produce a [`WgslTypeMap`].
///
/// The map translates built-in WGSL types (notably vectors and matrices) into
/// the Rust types written to the generated code. Because the chosen Rust
/// representation for one WGSL type can differ in size or alignment, the map
/// is produced per [`WgslTypeSerializeStrategy`].
pub trait WgslTypeMapBuild {
  /// Returns the `WgslTypeMap` to use for the given serialization strategy.
  fn build(&self, strategy: WgslTypeSerializeStrategy) -> WgslTypeMap;
}

/// A ready-made map acts as its own builder: the strategy is ignored and a
/// copy of the map is returned.
impl WgslTypeMapBuild for WgslTypeMap {
  fn build(&self, _strategy: WgslTypeSerializeStrategy) -> WgslTypeMap {
    Clone::clone(self)
  }
}

/// Maps a WGSL struct onto an existing Rust type instead of generating one.
///
/// The generated struct is skipped and the custom Rust type is used in its
/// place. As a consequence, the alignment and size checks normally emitted
/// when using bytemuck are skipped for that struct too. This is handy for core
/// primitive types you want to model on the Rust side yourself.
#[derive(Clone, Debug)]
pub struct OverrideStruct {
  /// Fully qualified name of the struct on the WGSL side, e.g. `lib::fp64::Fp64`.
  pub from: String,
  /// Fully qualified path of the replacement type in your crate, e.g. `crate::fp64::Fp64`.
  pub to: TokenStream,
}

/// Builds an [`OverrideStruct`] from a `(wgsl name, rust type)` pair.
impl From<(&str, TokenStream)> for OverrideStruct {
  fn from(pair: (&str, TokenStream)) -> Self {
    let (wgsl_name, rust_path) = pair;
    Self {
      from: String::from(wgsl_name),
      to: rust_path,
    }
  }
}

/// Overrides the generated Rust type of matching fields in matching structs.
///
/// NOTE(review): the regexes are presumably matched against the struct's
/// (possibly fully qualified) name and the field name — confirm in the generator.
#[derive(Clone, Debug)]
pub struct OverrideStructFieldType {
  /// Selects the structs whose fields are overridden.
  pub struct_regex: Regex,
  /// Selects, within a matching struct, the fields whose type is overridden.
  pub field_regex: Regex,
  /// The Rust type emitted for every matching field.
  pub override_type: TokenStream,
}

/// Builds an override from already-compiled regexes.
impl From<(Regex, Regex, TokenStream)> for OverrideStructFieldType {
  fn from(
    (struct_regex, field_regex, override_type): (Regex, Regex, TokenStream),
  ) -> Self {
    Self {
      struct_regex,
      field_regex,
      override_type,
    }
  }
}

/// Builds an override by compiling both patterns with [`Regex::new`].
///
/// # Panics
///
/// Panics if either pattern is not a valid regular expression. The panic
/// message names the offending pattern and the parse error.
impl From<(&str, &str, TokenStream)> for OverrideStructFieldType {
  fn from((struct_regex, field_regex, override_type): (&str, &str, TokenStream)) -> Self {
    Self {
      struct_regex: Regex::new(struct_regex).unwrap_or_else(|err| {
        panic!("Failed to create struct regex from {:?}: {}", struct_regex, err)
      }),
      field_regex: Regex::new(field_regex).unwrap_or_else(|err| {
        panic!("Failed to create field regex from {:?}: {}", field_regex, err)
      }),
      override_type,
    }
  }
}

/// Overrides the generated alignment of structs matched by a regular expression.
///
/// Useful when a specific minimum alignment is required for a uniform buffer.
#[derive(Clone, Debug)]
pub struct OverrideStructAlignment {
  /// Selects the structs whose alignment is overridden.
  pub struct_regex: Regex,
  /// The alignment to apply to matching structs
  /// (NOTE(review): presumably in bytes — confirm in the generator).
  pub alignment: u16,
}

/// Builds an alignment override from an already-compiled regex.
impl From<(Regex, u16)> for OverrideStructAlignment {
  fn from((struct_regex, alignment): (Regex, u16)) -> Self {
    Self {
      struct_regex,
      alignment,
    }
  }
}

/// Builds an alignment override by compiling the pattern with [`Regex::new`].
///
/// # Panics
///
/// Panics if the pattern is not a valid regular expression. The panic message
/// names the offending pattern and the parse error.
impl From<(&str, u16)> for OverrideStructAlignment {
  fn from((struct_regex, alignment): (&str, u16)) -> Self {
    Self {
      struct_regex: Regex::new(struct_regex).unwrap_or_else(|err| {
        panic!("Failed to create struct regex from {:?}: {}", struct_regex, err)
      }),
      alignment,
    }
  }
}

/// The visibility applied to the types generated in the output.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
pub enum WgslTypeVisibility {
  /// All exported types get `pub` visibility. This is the default.
  #[default]
  Public,

  /// All exported types get `pub(crate)` visibility.
  RestrictedCrate,

  /// All exported types get `pub(super)` visibility.
  RestrictedSuper,
}

/// Options for [`WGSLBindgen`], the WGSL-to-Rust bindings generator.
///
/// Construct it through the derived `WgslBindgenOptionBuilder`, whose `build`
/// method checks the collected options and creates the [`WGSLBindgen`].
#[derive(Debug, Default, Builder)]
#[builder(
  setter(into),
  field(private),
  build_fn(private, name = "fallible_build")
)]
pub struct WgslBindgenOption {
  /// The entry points to generate bindings for, each represented as a `String`.
  /// Add them one at a time with `add_entry_point`.
  #[builder(setter(each(name = "add_entry_point", into)))]
  pub entry_points: Vec<String>,

  /// The root prefix/namespace, if any, applied to all shaders given as the entry points.
  /// Defaults to `None`.
  #[builder(default, setter(strip_option, into))]
  pub module_import_root: Option<String>,

  /// The root shader workspace directory where all imports will be tested for resolution.
  #[builder(setter(into))]
  pub workspace_root: PathBuf,

  /// A boolean flag indicating whether to emit a rerun-if-changed directive to Cargo. Defaults to `true`.
  #[builder(default = "true")]
  pub emit_rerun_if_change: bool,

  /// A boolean flag indicating whether to skip the header comments in the output. Defaults to `false`.
  /// Keeping the headers enabled allows bindings generation to be skipped when the contents did not change.
  #[builder(default = "false")]
  pub skip_header_comments: bool,

  /// A boolean flag indicating whether to skip the hash check. Defaults to `false`.
  /// The hash check avoids rerunning bindings generation when the entry shaders,
  /// including their imports, have not changed.
  #[builder(default = "false")]
  pub skip_hash_check: bool,

  /// Derive [encase::ShaderType](https://docs.rs/encase/latest/encase/trait.ShaderType.html#)
  /// for user defined WGSL structs when `WgslTypeSerializeStrategy::Encase`,
  /// otherwise derive bytemuck. Uses `WgslTypeSerializeStrategy::default()` when unset.
  #[builder(default)]
  pub serialization_strategy: WgslTypeSerializeStrategy,

  /// Derive [serde::Serialize](https://docs.rs/serde/1.0.159/serde/trait.Serialize.html)
  /// and [serde::Deserialize](https://docs.rs/serde/1.0.159/serde/trait.Deserialize.html)
  /// for user defined WGSL structs when `true`. Defaults to `false`.
  #[builder(default = "false")]
  pub derive_serde: bool,

  /// The shader source type(s) to generate, as bitflags. Defaults to `WgslShaderSourceType::UseEmbed`.
  #[builder(default)]
  pub shader_source_type: BitFlags<WgslShaderSourceType>,

  /// The output file path for the generated Rust bindings. Defaults to `None`.
  #[builder(default, setter(strip_option, into))]
  pub output: Option<PathBuf>,

  /// The additional set of directories to scan for source files.
  /// Add them one at a time with `additional_scan_dir`.
  #[builder(default, setter(into, each(name = "additional_scan_dir", into)))]
  pub additional_scan_dirs: Vec<AdditionalScanDirectory>,

  /// The [wgpu::naga::valid::Capabilities](https://docs.rs/wgpu/latest/wgpu/naga/valid/struct.Capabilities.html) to support. Defaults to `None`.
  #[builder(default, setter(strip_option))]
  pub ir_capabilities: Option<WgslShaderIrCapabilities>,

  /// When set, generate a short constructor (similar to an enum variant constructor)
  /// instead of `new` for structs whose number of parameters is below this threshold.
  /// Defaults to `None`.
  #[builder(default, setter(strip_option, into))]
  pub short_constructor: Option<i32>,

  /// Which visibility to use for the exported types. Defaults to `WgslTypeVisibility::Public`.
  #[builder(default)]
  pub type_visibility: WgslTypeVisibility,

  /// A mapping for WGSL built-in types to their corresponding Rust representations.
  /// Set it with the custom `type_map` setter, which requires `serialization_strategy`
  /// to be set first.
  #[builder(setter(custom))]
  pub type_map: WgslTypeMap,

  /// Custom struct mappings that replace the generated structs.
  /// These are merged into the type map. Add them one at a time with `add_override_struct_mapping`.
  #[builder(default, setter(each(name = "add_override_struct_mapping", into)))]
  pub override_struct: Vec<OverrideStruct>,

  /// A vector of `OverrideStructFieldType` to override the generated types for struct fields in matching structs.
  #[builder(default, setter(into))]
  pub override_struct_field_type: Vec<OverrideStructFieldType>,

  /// A vector of regular expressions and alignments that override the generated alignment for matching structs.
  /// This can be used in scenarios where a specific minimum alignment is required for a uniform buffer.
  /// Refer to the [WebGPU specs](https://www.w3.org/TR/webgpu/#dom-supported-limits-minuniformbufferoffsetalignment) for more information.
  #[builder(default, setter(into))]
  pub override_struct_alignment: Vec<OverrideStructAlignment>,

  /// Regular expressions matching the padding fields used in the shader struct types.
  /// These fields are omitted from the generated *Init structs and are automatically assigned their default values.
  #[builder(default, setter(each(name = "add_custom_padding_field_regexp", into)))]
  pub custom_padding_field_regexps: Vec<Regex>,

  /// Whether to always generate the init struct in the output. This only applies in bytemuck mode.
  /// Defaults to `false`.
  #[builder(default = "false")]
  pub always_generate_init_struct: bool,

  /// An optional custom generator for extra bindings that the default generator does not cover.
  #[builder(default, setter(custom))]
  pub extra_binding_generator: Option<BindingGenerator>,

  /// The generator used for WGPU bindings. Defaults to `BindingGenerator::default()`.
  #[builder(default, setter(custom))]
  pub wgpu_binding_generator: BindingGenerator,
}

impl WgslBindgenOptionBuilder {
  pub fn build(&mut self) -> Result<WGSLBindgen, WgslBindgenError> {
    self.merge_struct_type_overrides();

    let options = self.fallible_build()?;
    WGSLBindgen::new(options)
  }

  pub fn type_map(&mut self, map_build: impl WgslTypeMapBuild) -> &mut Self {
    let serialization_strategy = self
      .serialization_strategy
      .expect("Serialization strategy must be set before `wgs_type_map`");

    let map = map_build.build(serialization_strategy);

    match self.type_map.as_mut() {
      Some(m) => m.extend(map),
      None => self.type_map = Some(map),
    }

    self
  }

  fn merge_struct_type_overrides(&mut self) {
    let struct_mappings = self
      .override_struct
      .iter()
      .flatten()
      .map(|mapping| {
        let wgsl_type = WgslType::Struct {
          fully_qualified_name: mapping.from.clone(),
        };
        (wgsl_type, mapping.to.clone())
      })
      .collect::<FastIndexMap<_, _>>();

    self.type_map(struct_mappings);
  }

  pub fn extra_binding_generator(
    &mut self,
    config: impl GetBindingsGeneratorConfig,
  ) -> &mut Self {
    let generator = Some(config.get_generator_config());
    self.extra_binding_generator = Some(generator);
    self
  }
}