1use crate::{label, prelude::ShaderModule, target::Target};
2use naga::{
3 valid::{Capabilities, ValidationFlags, Validator},
4 ImageClass, ImageDimension, Module, ScalarKind, StorageClass, TypeInner,
5};
6use std::{collections::BTreeMap, num::NonZeroU64};
7use wgpu::{
8 BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
9 BufferBindingType, PipelineLayoutDescriptor, SamplerBindingType, ShaderSource, ShaderStages,
10 TextureSampleType, TextureViewDimension,
11};
12
/// Bind group layout derived automatically from shader reflection.
///
/// Built by [`AutoLayout::new`] from the resources used by a vertex and a
/// fragment entry point; [`AutoLayout::get`] exposes it for building a
/// pipeline layout.
pub struct AutoLayout {
    /// The reflected layout; `get` places it at index 0 of the pipeline's
    /// bind group layouts.
    group: BindGroupLayout,
}
18
/// Borrowed view of an [`AutoLayout`] that produces a
/// [`PipelineLayoutDescriptor`] via [`AutoLayoutGetter::get`].
///
/// It holds the array that the descriptor's `bind_group_layouts` slice
/// borrows from.
pub struct AutoLayoutGetter<'a> {
    /// Bind group layouts in group-index order; only group 0 is present.
    groups: [&'a BindGroupLayout; 1],
}
22
23impl<'a> AutoLayoutGetter<'a> {
26 pub fn get(&self) -> PipelineLayoutDescriptor {
27 PipelineLayoutDescriptor {
28 label: label!(),
29 bind_group_layouts: &self.groups,
30 push_constant_ranges: &[],
31 }
32 }
33}
34
35impl AutoLayout {
36 pub fn new(
37 target: &Target,
38 (vs, vs_main): (&ShaderModule, &str),
39 (fs, fs_main): (&ShaderModule, &str),
40 ) -> Self {
41 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
42 let vs = Self::module(vs, vs_main, ShaderStages::VERTEX, &mut validator);
43 let fs = Self::module(fs, fs_main, ShaderStages::FRAGMENT, &mut validator);
44
45 let entries: Vec<BindGroupLayoutEntry> = Self::merge(vs, fs)
46 .into_iter()
47 .map(|(_, entry)| entry)
48 .collect();
49
50 let group = target
51 .device
52 .create_bind_group_layout(&BindGroupLayoutDescriptor {
53 label: label!(),
54 entries: &entries,
55 });
56
57 Self { group }
58 }
59
60 pub fn get(&self) -> AutoLayoutGetter {
61 AutoLayoutGetter {
62 groups: [&self.group],
63 }
64 }
65
66 fn module(
67 module: &ShaderModule,
68 entry: &str,
69 visibility: ShaderStages,
70 validator: &mut Validator,
71 ) -> Vec<BindGroupLayoutEntry> {
72 let module = parse(&module.source);
73
74 let i = module
75 .entry_points
76 .iter()
77 .enumerate()
78 .find(|(_, ep)| ep.name == entry)
79 .unwrap()
80 .0;
81
82 let module_info = validator.validate(&module).unwrap();
83
84 let entry_function = module_info.get_entry_point(i);
85
86 let mut layouter = naga::proc::Layouter::default();
87 layouter.update(&module.types, &module.constants).unwrap();
88
89 module
90 .global_variables
91 .iter()
92 .filter(|(handle, _)| !entry_function[*handle].is_empty())
93 .filter_map(|(_, var)| Some((var.class, var.binding.clone()?, var.ty)))
94 .filter_map(|(space, bind, ty)| {
95 let size = layouter[ty];
96 let ty = module.types.get_handle(ty).unwrap();
97
98 match (&ty.inner, space) {
99 (TypeInner::Sampler { .. }, _) => Some(BindGroupLayoutEntry {
100 binding: bind.binding,
101 visibility,
102 ty: BindingType::Sampler(SamplerBindingType::NonFiltering),
103 count: None,
104 }),
105 (
106 TypeInner::Image {
107 dim,
108 arrayed,
109 class,
110 },
111 _,
112 ) => Some(BindGroupLayoutEntry {
113 binding: bind.binding,
114 visibility,
115 ty: BindingType::Texture {
116 sample_type: match class {
117 ImageClass::Sampled {
118 kind: ScalarKind::Float,
119 ..
120 } => TextureSampleType::Float { filterable: false },
121 ImageClass::Sampled {
122 kind: ScalarKind::Sint,
123 ..
124 } => TextureSampleType::Sint,
125 ImageClass::Sampled {
126 kind: ScalarKind::Uint,
127 ..
128 } => TextureSampleType::Uint,
129 ImageClass::Depth { .. } => TextureSampleType::Depth,
130 ImageClass::Storage { .. } => todo!(),
131 _ => todo!(),
132 },
133 view_dimension: match (dim, arrayed) {
134 (ImageDimension::D1, false) => TextureViewDimension::D1,
135 (ImageDimension::D2, false) => TextureViewDimension::D2,
136 (ImageDimension::D2, true) => TextureViewDimension::D2Array,
137 (ImageDimension::D3, false) => TextureViewDimension::D3,
138 (ImageDimension::Cube, false) => TextureViewDimension::Cube,
139 (ImageDimension::Cube, true) => TextureViewDimension::CubeArray,
140 _ => unimplemented!(),
141 },
142 multisampled: false,
143 },
144 count: None,
145 }),
146 (_, StorageClass::Uniform) => Some(BindGroupLayoutEntry {
147 binding: bind.binding,
148 visibility,
149 ty: BindingType::Buffer {
150 ty: BufferBindingType::Uniform,
151 has_dynamic_offset: false,
152 min_binding_size: Some(NonZeroU64::new(size.size as _).unwrap()),
153 },
154 count: None,
155 }),
156 other => unimplemented!("Unimplemented: {other:?}"),
157 }
158 })
159 .collect()
160 }
161
162 fn merge(
163 vs: Vec<BindGroupLayoutEntry>,
164 fs: Vec<BindGroupLayoutEntry>,
165 ) -> BTreeMap<u32, BindGroupLayoutEntry> {
166 let mut first: BTreeMap<u32, BindGroupLayoutEntry> =
167 vs.into_iter().map(|entry| (entry.binding, entry)).collect();
168
169 for mut entry in fs.into_iter() {
170 if let Some(existing_entry) = first.get(&entry.binding) {
171 entry.visibility |= existing_entry.visibility;
172 first.insert(entry.binding, entry);
173 } else {
174 first.insert(entry.binding, entry);
175 }
176 }
177
178 first
179 }
180}
181
182fn parse(source: &ShaderSource) -> Module {
185 match source {
186 #[cfg(feature = "spirv")]
187 ShaderSource::SpirV(spv) => {
188 let options = naga::front::spv::Options {
191 adjust_coordinate_space: false, strict_capabilities: true,
193 block_ctx_dump_prefix: None,
194 };
195 let parser = naga::front::spv::Parser::new(spv.iter().cloned(), &options);
196 parser.parse().unwrap()
197 }
198
199 #[cfg(feature = "glsl")]
200 ShaderSource::Glsl {
201 shader,
202 stage,
203 defines,
204 } => {
205 let options = naga::front::glsl::Options {
208 stage: *stage,
209 defines: defines.clone(),
210 };
211 let mut parser = naga::front::glsl::Parser::default();
212 parser.parse(&options, shader).unwrap()
213 }
214
215 ShaderSource::Wgsl(source) => naga::front::wgsl::parse_str(source).unwrap(),
216 }
217}