1use crate::{
4 Error,
5 inst::{Instance, LiteralInstance},
6 syntax::{AccessMode, AddressSpace, Enumerant, SampledType, TexelFormat},
7 ty::{TextureType, Ty, Type},
8};
9
10#[derive(Clone, Debug, PartialEq)]
12pub enum TpltParam {
13 Type(Type),
14 Instance(Instance),
15 Enumerant(Enumerant),
16}
17
18type E = Error;
19
20pub struct ArrayTemplate {
25 n: Option<usize>,
26 ty: Type,
27}
28
29impl ArrayTemplate {
30 pub fn new(ty: Type, n: Option<usize>) -> Self {
31 Self { n, ty }
32 }
33 pub fn parse(tplt: &[TpltParam]) -> Result<ArrayTemplate, E> {
34 let (ty, n) = match tplt {
35 [TpltParam::Type(ty)] => Ok((ty.clone(), None)),
36 [TpltParam::Type(ty), TpltParam::Instance(n)] => Ok((ty.clone(), Some(n.clone()))),
37 _ => Err(E::TemplateArgs("array")),
38 }?;
39 if let Some(n) = n {
40 let n = match n {
41 Instance::Literal(LiteralInstance::AbstractInt(n)) => (n > 0).then_some(n as usize),
42 Instance::Literal(LiteralInstance::I32(n)) => (n > 0).then_some(n as usize),
43 Instance::Literal(LiteralInstance::U32(n)) => (n > 0).then_some(n as usize),
44 #[cfg(feature = "naga-ext")]
45 Instance::Literal(LiteralInstance::I64(n)) => (n > 0).then_some(n as usize),
46 #[cfg(feature = "naga-ext")]
47 Instance::Literal(LiteralInstance::U64(n)) => (n > 0).then_some(n as usize),
48 _ => None,
49 }
50 .ok_or(E::Builtin(
51 "the array element count must evaluate to a `u32` or a `i32` greater than `0`",
52 ))?;
53 Ok(ArrayTemplate { n: Some(n), ty })
54 } else {
55 Ok(ArrayTemplate { n: None, ty })
56 }
57 }
58 pub fn ty(&self) -> Type {
59 Type::Array(Box::new(self.ty.clone()), self.n)
60 }
61 pub fn inner_ty(&self) -> Type {
62 self.ty.clone()
63 }
64 pub fn n(&self) -> Option<usize> {
65 self.n
66 }
67}
68
/// Parsed template arguments of a `binding_array` type: an element type and an
/// optional element count.
#[cfg(feature = "naga-ext")]
pub struct BindingArrayTemplate {
    /// Element count, or `None` when the template list carries no count argument.
    n: Option<usize>,
    /// Element type.
    ty: Type,
}

#[cfg(feature = "naga-ext")]
impl BindingArrayTemplate {
    /// Parses a template list of the form `<ty>` or `<ty, n>`.
    ///
    /// # Errors
    ///
    /// * `E::TemplateArgs("binding_array")` when the list is neither a single type nor
    ///   a type followed by a value.
    /// * `E::Builtin(..)` when the count is not an integer literal strictly greater
    ///   than zero that fits in `usize`.
    pub fn parse(tplt: &[TpltParam]) -> Result<BindingArrayTemplate, E> {
        let (ty, count) = match tplt {
            [TpltParam::Type(ty)] => (ty, None),
            [TpltParam::Type(ty), TpltParam::Instance(count)] => (ty, Some(count)),
            _ => return Err(E::TemplateArgs("binding_array")),
        };
        let n = count.map(Self::element_count).transpose()?;
        Ok(BindingArrayTemplate { n, ty: ty.clone() })
    }

    /// Converts a count instance into a strictly positive `usize`.
    ///
    /// The conversion is checked: a bare `as usize` cast would silently truncate 64-bit
    /// values on targets where `usize` is narrower (e.g. wasm32).
    fn element_count(count: &Instance) -> Result<usize, E> {
        let n = match count {
            Instance::Literal(LiteralInstance::AbstractInt(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::I32(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::U32(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::I64(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::U64(n)) => usize::try_from(*n).ok(),
            _ => None,
        };
        // Negative values already failed the conversion; zero is rejected here.
        n.filter(|&n| n > 0).ok_or(E::Builtin(
            "the binding_array element count must evaluate to a `u32` or a `i32` greater than `0`",
        ))
    }

    /// Returns the binding-array type described by this template.
    pub fn ty(&self) -> Type {
        Type::BindingArray(Box::new(self.ty.clone()), self.n)
    }

    /// Returns a copy of the element type.
    pub fn inner_ty(&self) -> Type {
        self.ty.clone()
    }

    /// Returns the element count, if one was given.
    pub fn n(&self) -> Option<usize> {
        self.n
    }
}
110
111pub struct VecTemplate {
112 ty: Type,
113}
114
115impl VecTemplate {
116 pub fn parse(tplt: &[TpltParam]) -> Result<VecTemplate, E> {
117 let ty = match tplt {
118 [TpltParam::Type(ty)] => Ok(ty.clone()),
119 _ => Err(E::TemplateArgs("vector")),
120 }?;
121 if ty.is_scalar() && ty.is_concrete() {
122 Ok(VecTemplate { ty })
123 } else {
124 Err(Error::Builtin("vector template type must be a scalar"))
125 }
126 }
127 pub fn ty(&self, n: u8) -> Type {
128 Type::Vec(n, self.ty.clone().into())
129 }
130 pub fn inner_ty(&self) -> &Type {
131 &self.ty
132 }
133}
134
135pub struct MatTemplate {
136 ty: Type,
137}
138
139impl MatTemplate {
140 pub fn parse(tplt: &[TpltParam]) -> Result<MatTemplate, E> {
141 let ty = match tplt {
142 [TpltParam::Type(ty)] => Ok(ty.clone()),
143 _ => Err(E::TemplateArgs("matrix")),
144 }?;
145 if ty.is_float() {
146 Ok(MatTemplate { ty })
147 } else {
148 Err(Error::Builtin("matrix template type must be f32 or f16"))
149 }
150 }
151 pub fn ty(&self, c: u8, r: u8) -> Type {
152 Type::Mat(c, r, self.ty.clone().into())
153 }
154
155 pub fn inner_ty(&self) -> &Type {
156 &self.ty
157 }
158}
159
160pub struct PtrTemplate {
161 pub space: AddressSpace,
162 pub ty: Type,
163 pub access: AccessMode,
164}
165
166impl PtrTemplate {
167 pub fn parse(tplt: &[TpltParam]) -> Result<PtrTemplate, E> {
168 let mut it = tplt.iter();
169 match (
170 it.next().cloned(),
171 it.next().cloned(),
172 it.next().cloned(),
173 it.next(),
174 ) {
175 (
176 Some(TpltParam::Enumerant(Enumerant::AddressSpace(space))),
177 Some(TpltParam::Type(ty)),
178 access,
179 None,
180 ) => {
181 if !ty.is_storable() {
182 return Err(Error::Builtin("pointer type must be storable"));
183 }
184 let access = match access {
185 Some(TpltParam::Enumerant(Enumerant::AccessMode(access))) => Some(access),
186 _ => None,
187 };
188 let access = match (space, access) {
191 (AddressSpace::Function, Some(access))
192 | (AddressSpace::Private, Some(access))
193 | (AddressSpace::Workgroup, Some(access))
194 | (AddressSpace::Storage, Some(access)) => access,
195 (AddressSpace::Function, None)
196 | (AddressSpace::Private, None)
197 | (AddressSpace::Workgroup, None) => AccessMode::ReadWrite,
198 (AddressSpace::Uniform, Some(AccessMode::Read) | None) => AccessMode::Read,
199 (AddressSpace::Uniform, _) => {
200 return Err(Error::Builtin(
201 "pointer in uniform address space must have a `read` access mode",
202 ));
203 }
204 (AddressSpace::Storage, None) => AccessMode::Read,
205 (AddressSpace::Handle, _) => {
206 unreachable!("handle address space cannot be spelled")
207 }
208 #[cfg(feature = "naga-ext")]
209 (AddressSpace::PushConstant, _) => {
210 todo!("push_constant")
211 }
212 };
213 Ok(PtrTemplate { space, ty, access })
214 }
215 _ => Err(E::TemplateArgs("pointer")),
216 }
217 }
218
219 pub fn ty(&self) -> Type {
220 Type::Ptr(self.space, self.ty.clone().into(), self.access)
221 }
222}
223
224pub struct AtomicTemplate {
225 pub ty: Type,
226}
227
228impl AtomicTemplate {
229 pub fn parse(tplt: &[TpltParam]) -> Result<AtomicTemplate, E> {
230 let ty = match tplt {
231 [TpltParam::Type(ty)] => Ok(ty.clone()),
232 _ => Err(E::TemplateArgs("atomic")),
233 }?;
234 #[cfg(feature = "naga-ext")]
235 if ty.is_f32() || ty.is_i64() || ty.is_u64() {
236 return Ok(AtomicTemplate { ty });
237 }
238 if ty.is_i32() || ty.is_u32() {
239 Ok(AtomicTemplate { ty })
240 } else {
241 Err(Error::Builtin("atomic template type must be an integer"))
242 }
243 }
244 pub fn ty(&self) -> Type {
245 Type::Atomic(self.ty.clone().into())
246 }
247 pub fn inner_ty(&self) -> Type {
248 self.ty.clone()
249 }
250}
251
252pub struct TextureTemplate {
253 ty: TextureType,
254}
255
256impl TextureTemplate {
257 pub fn parse(name: &str, tplt: &[TpltParam]) -> Result<TextureTemplate, E> {
258 let ty = match name {
259 "texture_1d" => TextureType::Sampled1D(Self::sampled_type(tplt)?),
260 "texture_2d" => TextureType::Sampled2D(Self::sampled_type(tplt)?),
261 "texture_2d_array" => TextureType::Sampled2DArray(Self::sampled_type(tplt)?),
262 "texture_3d" => TextureType::Sampled3D(Self::sampled_type(tplt)?),
263 "texture_cube" => TextureType::SampledCube(Self::sampled_type(tplt)?),
264 "texture_cube_array" => TextureType::SampledCubeArray(Self::sampled_type(tplt)?),
265 "texture_multisampled_2d" => TextureType::Multisampled2D(Self::sampled_type(tplt)?),
266 "texture_storage_1d" => {
267 let (tex, acc) = Self::texel_access(tplt)?;
268 TextureType::Storage1D(tex, acc)
269 }
270 "texture_storage_2d" => {
271 let (tex, acc) = Self::texel_access(tplt)?;
272 TextureType::Storage2D(tex, acc)
273 }
274 "texture_storage_2d_array" => {
275 let (tex, acc) = Self::texel_access(tplt)?;
276 TextureType::Storage2DArray(tex, acc)
277 }
278 "texture_storage_3d" => {
279 let (tex, acc) = Self::texel_access(tplt)?;
280 TextureType::Storage3D(tex, acc)
281 }
282 #[cfg(feature = "naga-ext")]
283 "texture_1d_array" => TextureType::Sampled1DArray(Self::sampled_type(tplt)?),
284 #[cfg(feature = "naga-ext")]
285 "texture_storage_1d_array" => {
286 let (tex, acc) = Self::texel_access(tplt)?;
287 TextureType::Storage1DArray(tex, acc)
288 }
289 #[cfg(feature = "naga-ext")]
290 "texture_multisampled_2d_array" => {
291 TextureType::Multisampled2DArray(Self::sampled_type(tplt)?)
292 }
293 _ => return Err(E::Builtin("not a templated texture type")),
294 };
295 Ok(Self { ty })
296 }
297 fn sampled_type(tplt: &[TpltParam]) -> Result<SampledType, E> {
298 match tplt {
299 [TpltParam::Type(ty)] => ty.try_into(),
300 [_] => Err(Error::Builtin(
301 "texture sampled type must be `i32`, `u32` or `f32`",
302 )),
303 _ => Err(Error::Builtin(
304 "sampled texture types take a single template parameter",
305 )),
306 }
307 }
308 fn texel_access(tplt: &[TpltParam]) -> Result<(TexelFormat, AccessMode), E> {
309 match tplt {
310 [
311 TpltParam::Enumerant(Enumerant::TexelFormat(texel)),
312 TpltParam::Enumerant(Enumerant::AccessMode(access)),
313 ] => Ok((*texel, *access)),
314 _ => Err(Error::Builtin(
315 "storage texture types take two template parameters",
316 )),
317 }
318 }
319 pub fn ty(&self) -> TextureType {
320 self.ty.clone()
321 }
322}
323
324pub struct BitcastTemplate {
325 ty: Type,
326}
327
328impl BitcastTemplate {
329 pub fn parse(tplt: &[TpltParam]) -> Result<BitcastTemplate, E> {
330 let ty = match tplt {
331 [TpltParam::Type(ty)] => Ok(ty.clone()),
332 _ => Err(E::TemplateArgs("bitcast")),
333 }?;
334 if ty.is_numeric() || ty.is_vec() && ty.inner_ty().is_numeric() {
335 Ok(BitcastTemplate { ty })
336 } else {
337 Err(Error::Builtin(
338 "bitcast template type must be a numeric scalar or numeric vector",
339 ))
340 }
341 }
342 pub fn ty(&self) -> &Type {
343 &self.ty
344 }
345 pub fn inner_ty(&self) -> Type {
346 self.ty.inner_ty()
347 }
348}