srs2dge_core/shader/
layout.rs

1use crate::{label, prelude::ShaderModule, target::Target};
2use naga::{
3    valid::{Capabilities, ValidationFlags, Validator},
4    ImageClass, ImageDimension, Module, ScalarKind, StorageClass, TypeInner,
5};
6use std::{collections::BTreeMap, num::NonZeroU64};
7use wgpu::{
8    BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
9    BufferBindingType, PipelineLayoutDescriptor, SamplerBindingType, ShaderSource, ShaderStages,
10    TextureSampleType, TextureViewDimension,
11};
12
13//
14
15pub struct AutoLayout {
16    group: BindGroupLayout,
17}
18
19pub struct AutoLayoutGetter<'a> {
20    groups: [&'a BindGroupLayout; 1],
21}
22
23//
24
25impl<'a> AutoLayoutGetter<'a> {
26    pub fn get(&self) -> PipelineLayoutDescriptor {
27        PipelineLayoutDescriptor {
28            label: label!(),
29            bind_group_layouts: &self.groups,
30            push_constant_ranges: &[],
31        }
32    }
33}
34
35impl AutoLayout {
36    pub fn new(
37        target: &Target,
38        (vs, vs_main): (&ShaderModule, &str),
39        (fs, fs_main): (&ShaderModule, &str),
40    ) -> Self {
41        let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
42        let vs = Self::module(vs, vs_main, ShaderStages::VERTEX, &mut validator);
43        let fs = Self::module(fs, fs_main, ShaderStages::FRAGMENT, &mut validator);
44
45        let entries: Vec<BindGroupLayoutEntry> = Self::merge(vs, fs)
46            .into_iter()
47            .map(|(_, entry)| entry)
48            .collect();
49
50        let group = target
51            .device
52            .create_bind_group_layout(&BindGroupLayoutDescriptor {
53                label: label!(),
54                entries: &entries,
55            });
56
57        Self { group }
58    }
59
60    pub fn get(&self) -> AutoLayoutGetter {
61        AutoLayoutGetter {
62            groups: [&self.group],
63        }
64    }
65
66    fn module(
67        module: &ShaderModule,
68        entry: &str,
69        visibility: ShaderStages,
70        validator: &mut Validator,
71    ) -> Vec<BindGroupLayoutEntry> {
72        let module = parse(&module.source);
73
74        let i = module
75            .entry_points
76            .iter()
77            .enumerate()
78            .find(|(_, ep)| ep.name == entry)
79            .unwrap()
80            .0;
81
82        let module_info = validator.validate(&module).unwrap();
83
84        let entry_function = module_info.get_entry_point(i);
85
86        let mut layouter = naga::proc::Layouter::default();
87        layouter.update(&module.types, &module.constants).unwrap();
88
89        module
90            .global_variables
91            .iter()
92            .filter(|(handle, _)| !entry_function[*handle].is_empty())
93            .filter_map(|(_, var)| Some((var.class, var.binding.clone()?, var.ty)))
94            .filter_map(|(space, bind, ty)| {
95                let size = layouter[ty];
96                let ty = module.types.get_handle(ty).unwrap();
97
98                match (&ty.inner, space) {
99                    (TypeInner::Sampler { .. }, _) => Some(BindGroupLayoutEntry {
100                        binding: bind.binding,
101                        visibility,
102                        ty: BindingType::Sampler(SamplerBindingType::NonFiltering),
103                        count: None,
104                    }),
105                    (
106                        TypeInner::Image {
107                            dim,
108                            arrayed,
109                            class,
110                        },
111                        _,
112                    ) => Some(BindGroupLayoutEntry {
113                        binding: bind.binding,
114                        visibility,
115                        ty: BindingType::Texture {
116                            sample_type: match class {
117                                ImageClass::Sampled {
118                                    kind: ScalarKind::Float,
119                                    ..
120                                } => TextureSampleType::Float { filterable: false },
121                                ImageClass::Sampled {
122                                    kind: ScalarKind::Sint,
123                                    ..
124                                } => TextureSampleType::Sint,
125                                ImageClass::Sampled {
126                                    kind: ScalarKind::Uint,
127                                    ..
128                                } => TextureSampleType::Uint,
129                                ImageClass::Depth { .. } => TextureSampleType::Depth,
130                                ImageClass::Storage { .. } => todo!(),
131                                _ => todo!(),
132                            },
133                            view_dimension: match (dim, arrayed) {
134                                (ImageDimension::D1, false) => TextureViewDimension::D1,
135                                (ImageDimension::D2, false) => TextureViewDimension::D2,
136                                (ImageDimension::D2, true) => TextureViewDimension::D2Array,
137                                (ImageDimension::D3, false) => TextureViewDimension::D3,
138                                (ImageDimension::Cube, false) => TextureViewDimension::Cube,
139                                (ImageDimension::Cube, true) => TextureViewDimension::CubeArray,
140                                _ => unimplemented!(),
141                            },
142                            multisampled: false,
143                        },
144                        count: None,
145                    }),
146                    (_, StorageClass::Uniform) => Some(BindGroupLayoutEntry {
147                        binding: bind.binding,
148                        visibility,
149                        ty: BindingType::Buffer {
150                            ty: BufferBindingType::Uniform,
151                            has_dynamic_offset: false,
152                            min_binding_size: Some(NonZeroU64::new(size.size as _).unwrap()),
153                        },
154                        count: None,
155                    }),
156                    other => unimplemented!("Unimplemented: {other:?}"),
157                }
158            })
159            .collect()
160    }
161
162    fn merge(
163        vs: Vec<BindGroupLayoutEntry>,
164        fs: Vec<BindGroupLayoutEntry>,
165    ) -> BTreeMap<u32, BindGroupLayoutEntry> {
166        let mut first: BTreeMap<u32, BindGroupLayoutEntry> =
167            vs.into_iter().map(|entry| (entry.binding, entry)).collect();
168
169        for mut entry in fs.into_iter() {
170            if let Some(existing_entry) = first.get(&entry.binding) {
171                entry.visibility |= existing_entry.visibility;
172                first.insert(entry.binding, entry);
173            } else {
174                first.insert(entry.binding, entry);
175            }
176        }
177
178        first
179    }
180}
181
182//
183
184fn parse(source: &ShaderSource) -> Module {
185    match source {
186        #[cfg(feature = "spirv")]
187        ShaderSource::SpirV(spv) => {
188            // source from wgpu repo to keep it somewhat similar:
189            // Parse the given shader code and store its representation.
190            let options = naga::front::spv::Options {
191                adjust_coordinate_space: false, // we require NDC_Y_UP feature
192                strict_capabilities: true,
193                block_ctx_dump_prefix: None,
194            };
195            let parser = naga::front::spv::Parser::new(spv.iter().cloned(), &options);
196            parser.parse().unwrap()
197        }
198
199        #[cfg(feature = "glsl")]
200        ShaderSource::Glsl {
201            shader,
202            stage,
203            defines,
204        } => {
205            // source from wgpu repo to keep it somewhat similar:
206            // Parse the given shader code and store its representation.
207            let options = naga::front::glsl::Options {
208                stage: *stage,
209                defines: defines.clone(),
210            };
211            let mut parser = naga::front::glsl::Parser::default();
212            parser.parse(&options, shader).unwrap()
213        }
214
215        ShaderSource::Wgsl(source) => naga::front::wgsl::parse_str(source).unwrap(),
216    }
217}