wgsl_types/
tplt.rs

1//! Built-in type-generator and function templates.
2
3use crate::{
4    Error,
5    inst::{Instance, LiteralInstance},
6    syntax::{AccessMode, AddressSpace, Enumerant, SampledType, TexelFormat},
7    ty::{TextureType, Ty, Type},
8};
9
10/// A single template parameter.
11#[derive(Clone, Debug, PartialEq)]
12pub enum TpltParam {
13    Type(Type),
14    Instance(Instance),
15    Enumerant(Enumerant),
16}
17
18type E = Error;
19
20// ------------------------
21// TYPE-GENERATOR TEMPLATES
22// ------------------------
23
24pub struct ArrayTemplate {
25    n: Option<usize>,
26    ty: Type,
27}
28
29impl ArrayTemplate {
30    pub fn new(ty: Type, n: Option<usize>) -> Self {
31        Self { n, ty }
32    }
33    pub fn parse(tplt: &[TpltParam]) -> Result<ArrayTemplate, E> {
34        let (ty, n) = match tplt {
35            [TpltParam::Type(ty)] => Ok((ty.clone(), None)),
36            [TpltParam::Type(ty), TpltParam::Instance(n)] => Ok((ty.clone(), Some(n.clone()))),
37            _ => Err(E::TemplateArgs("array")),
38        }?;
39        if let Some(n) = n {
40            let n = match n {
41                Instance::Literal(LiteralInstance::AbstractInt(n)) => (n > 0).then_some(n as usize),
42                Instance::Literal(LiteralInstance::I32(n)) => (n > 0).then_some(n as usize),
43                Instance::Literal(LiteralInstance::U32(n)) => (n > 0).then_some(n as usize),
44                #[cfg(feature = "naga-ext")]
45                Instance::Literal(LiteralInstance::I64(n)) => (n > 0).then_some(n as usize),
46                #[cfg(feature = "naga-ext")]
47                Instance::Literal(LiteralInstance::U64(n)) => (n > 0).then_some(n as usize),
48                _ => None,
49            }
50            .ok_or(E::Builtin(
51                "the array element count must evaluate to a `u32` or a `i32` greater than `0`",
52            ))?;
53            Ok(ArrayTemplate { n: Some(n), ty })
54        } else {
55            Ok(ArrayTemplate { n: None, ty })
56        }
57    }
58    pub fn ty(&self) -> Type {
59        Type::Array(Box::new(self.ty.clone()), self.n)
60    }
61    pub fn inner_ty(&self) -> Type {
62        self.ty.clone()
63    }
64    pub fn n(&self) -> Option<usize> {
65        self.n
66    }
67}
68
/// Type-generator template for the naga extension `binding_array<T>` and
/// `binding_array<T, N>`.
///
/// `n` is `None` when no element count was given, and `Some(count)` otherwise.
#[cfg(feature = "naga-ext")]
pub struct BindingArrayTemplate {
    n: Option<usize>,
    ty: Type,
}

#[cfg(feature = "naga-ext")]
impl BindingArrayTemplate {
    /// Parses the template parameters of `binding_array<...>`.
    ///
    /// # Errors
    ///
    /// - [`Error::TemplateArgs`] if the parameters are not `[T]` or `[T, N]`.
    /// - [`Error::Builtin`] if `N` is not an integer literal greater than zero
    ///   that fits in a `usize`.
    pub fn parse(tplt: &[TpltParam]) -> Result<BindingArrayTemplate, E> {
        match tplt {
            [TpltParam::Type(ty)] => Ok(BindingArrayTemplate {
                n: None,
                ty: ty.clone(),
            }),
            [TpltParam::Type(ty), TpltParam::Instance(n)] => {
                let n = Self::element_count(n).ok_or(E::Builtin(
                    "the binding_array element count must evaluate to a `u32` or a `i32` greater than `0`",
                ))?;
                Ok(BindingArrayTemplate {
                    n: Some(n),
                    ty: ty.clone(),
                })
            }
            _ => Err(E::TemplateArgs("binding_array")),
        }
    }

    /// Converts an element-count literal into a `usize`.
    ///
    /// Returns `None` for non-integer instances, for counts that are not
    /// greater than zero, and for counts that do not fit in a `usize`.
    fn element_count(n: &Instance) -> Option<usize> {
        // `try_from` rather than `as`: an `as` cast of a 64-bit count would be
        // silently truncated on targets where `usize` is narrower.
        let count = match n {
            Instance::Literal(LiteralInstance::AbstractInt(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::I32(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::U32(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::I64(n)) => usize::try_from(*n).ok(),
            Instance::Literal(LiteralInstance::U64(n)) => usize::try_from(*n).ok(),
            _ => None,
        }?;
        (count > 0).then_some(count)
    }

    /// Returns the generated `binding_array` type.
    pub fn ty(&self) -> Type {
        Type::BindingArray(Box::new(self.ty.clone()), self.n)
    }

    /// Returns the element type.
    pub fn inner_ty(&self) -> Type {
        self.ty.clone()
    }

    /// Returns the element count, or `None` if none was given.
    pub fn n(&self) -> Option<usize> {
        self.n
    }
}
110
111pub struct VecTemplate {
112    ty: Type,
113}
114
115impl VecTemplate {
116    pub fn parse(tplt: &[TpltParam]) -> Result<VecTemplate, E> {
117        let ty = match tplt {
118            [TpltParam::Type(ty)] => Ok(ty.clone()),
119            _ => Err(E::TemplateArgs("vector")),
120        }?;
121        if ty.is_scalar() && ty.is_concrete() {
122            Ok(VecTemplate { ty })
123        } else {
124            Err(Error::Builtin("vector template type must be a scalar"))
125        }
126    }
127    pub fn ty(&self, n: u8) -> Type {
128        Type::Vec(n, self.ty.clone().into())
129    }
130    pub fn inner_ty(&self) -> &Type {
131        &self.ty
132    }
133}
134
135pub struct MatTemplate {
136    ty: Type,
137}
138
139impl MatTemplate {
140    pub fn parse(tplt: &[TpltParam]) -> Result<MatTemplate, E> {
141        let ty = match tplt {
142            [TpltParam::Type(ty)] => Ok(ty.clone()),
143            _ => Err(E::TemplateArgs("matrix")),
144        }?;
145        if ty.is_float() {
146            Ok(MatTemplate { ty })
147        } else {
148            Err(Error::Builtin("matrix template type must be f32 or f16"))
149        }
150    }
151    pub fn ty(&self, c: u8, r: u8) -> Type {
152        Type::Mat(c, r, self.ty.clone().into())
153    }
154
155    pub fn inner_ty(&self) -> &Type {
156        &self.ty
157    }
158}
159
160pub struct PtrTemplate {
161    pub space: AddressSpace,
162    pub ty: Type,
163    pub access: AccessMode,
164}
165
166impl PtrTemplate {
167    pub fn parse(tplt: &[TpltParam]) -> Result<PtrTemplate, E> {
168        let mut it = tplt.iter();
169        match (
170            it.next().cloned(),
171            it.next().cloned(),
172            it.next().cloned(),
173            it.next(),
174        ) {
175            (
176                Some(TpltParam::Enumerant(Enumerant::AddressSpace(space))),
177                Some(TpltParam::Type(ty)),
178                access,
179                None,
180            ) => {
181                if !ty.is_storable() {
182                    return Err(Error::Builtin("pointer type must be storable"));
183                }
184                let access = match access {
185                    Some(TpltParam::Enumerant(Enumerant::AccessMode(access))) => Some(access),
186                    _ => None,
187                };
188                // selecting the default access mode per address space.
189                // reference: <https://www.w3.org/TR/WGSL/#address-space>
190                let access = match (space, access) {
191                    (AddressSpace::Function, Some(access))
192                    | (AddressSpace::Private, Some(access))
193                    | (AddressSpace::Workgroup, Some(access))
194                    | (AddressSpace::Storage, Some(access)) => access,
195                    (AddressSpace::Function, None)
196                    | (AddressSpace::Private, None)
197                    | (AddressSpace::Workgroup, None) => AccessMode::ReadWrite,
198                    (AddressSpace::Uniform, Some(AccessMode::Read) | None) => AccessMode::Read,
199                    (AddressSpace::Uniform, _) => {
200                        return Err(Error::Builtin(
201                            "pointer in uniform address space must have a `read` access mode",
202                        ));
203                    }
204                    (AddressSpace::Storage, None) => AccessMode::Read,
205                    (AddressSpace::Handle, _) => {
206                        unreachable!("handle address space cannot be spelled")
207                    }
208                    #[cfg(feature = "naga-ext")]
209                    (AddressSpace::PushConstant, _) => {
210                        todo!("push_constant")
211                    }
212                };
213                Ok(PtrTemplate { space, ty, access })
214            }
215            _ => Err(E::TemplateArgs("pointer")),
216        }
217    }
218
219    pub fn ty(&self) -> Type {
220        Type::Ptr(self.space, self.ty.clone().into(), self.access)
221    }
222}
223
224pub struct AtomicTemplate {
225    pub ty: Type,
226}
227
228impl AtomicTemplate {
229    pub fn parse(tplt: &[TpltParam]) -> Result<AtomicTemplate, E> {
230        let ty = match tplt {
231            [TpltParam::Type(ty)] => Ok(ty.clone()),
232            _ => Err(E::TemplateArgs("atomic")),
233        }?;
234        #[cfg(feature = "naga-ext")]
235        if ty.is_f32() || ty.is_i64() || ty.is_u64() {
236            return Ok(AtomicTemplate { ty });
237        }
238        if ty.is_i32() || ty.is_u32() {
239            Ok(AtomicTemplate { ty })
240        } else {
241            Err(Error::Builtin("atomic template type must be an integer"))
242        }
243    }
244    pub fn ty(&self) -> Type {
245        Type::Atomic(self.ty.clone().into())
246    }
247    pub fn inner_ty(&self) -> Type {
248        self.ty.clone()
249    }
250}
251
252pub struct TextureTemplate {
253    ty: TextureType,
254}
255
256impl TextureTemplate {
257    pub fn parse(name: &str, tplt: &[TpltParam]) -> Result<TextureTemplate, E> {
258        let ty = match name {
259            "texture_1d" => TextureType::Sampled1D(Self::sampled_type(tplt)?),
260            "texture_2d" => TextureType::Sampled2D(Self::sampled_type(tplt)?),
261            "texture_2d_array" => TextureType::Sampled2DArray(Self::sampled_type(tplt)?),
262            "texture_3d" => TextureType::Sampled3D(Self::sampled_type(tplt)?),
263            "texture_cube" => TextureType::SampledCube(Self::sampled_type(tplt)?),
264            "texture_cube_array" => TextureType::SampledCubeArray(Self::sampled_type(tplt)?),
265            "texture_multisampled_2d" => TextureType::Multisampled2D(Self::sampled_type(tplt)?),
266            "texture_storage_1d" => {
267                let (tex, acc) = Self::texel_access(tplt)?;
268                TextureType::Storage1D(tex, acc)
269            }
270            "texture_storage_2d" => {
271                let (tex, acc) = Self::texel_access(tplt)?;
272                TextureType::Storage2D(tex, acc)
273            }
274            "texture_storage_2d_array" => {
275                let (tex, acc) = Self::texel_access(tplt)?;
276                TextureType::Storage2DArray(tex, acc)
277            }
278            "texture_storage_3d" => {
279                let (tex, acc) = Self::texel_access(tplt)?;
280                TextureType::Storage3D(tex, acc)
281            }
282            #[cfg(feature = "naga-ext")]
283            "texture_1d_array" => TextureType::Sampled1DArray(Self::sampled_type(tplt)?),
284            #[cfg(feature = "naga-ext")]
285            "texture_storage_1d_array" => {
286                let (tex, acc) = Self::texel_access(tplt)?;
287                TextureType::Storage1DArray(tex, acc)
288            }
289            #[cfg(feature = "naga-ext")]
290            "texture_multisampled_2d_array" => {
291                TextureType::Multisampled2DArray(Self::sampled_type(tplt)?)
292            }
293            _ => return Err(E::Builtin("not a templated texture type")),
294        };
295        Ok(Self { ty })
296    }
297    fn sampled_type(tplt: &[TpltParam]) -> Result<SampledType, E> {
298        match tplt {
299            [TpltParam::Type(ty)] => ty.try_into(),
300            [_] => Err(Error::Builtin(
301                "texture sampled type must be `i32`, `u32` or `f32`",
302            )),
303            _ => Err(Error::Builtin(
304                "sampled texture types take a single template parameter",
305            )),
306        }
307    }
308    fn texel_access(tplt: &[TpltParam]) -> Result<(TexelFormat, AccessMode), E> {
309        match tplt {
310            [
311                TpltParam::Enumerant(Enumerant::TexelFormat(texel)),
312                TpltParam::Enumerant(Enumerant::AccessMode(access)),
313            ] => Ok((*texel, *access)),
314            _ => Err(Error::Builtin(
315                "storage texture types take two template parameters",
316            )),
317        }
318    }
319    pub fn ty(&self) -> TextureType {
320        self.ty.clone()
321    }
322}
323
324pub struct BitcastTemplate {
325    ty: Type,
326}
327
328impl BitcastTemplate {
329    pub fn parse(tplt: &[TpltParam]) -> Result<BitcastTemplate, E> {
330        let ty = match tplt {
331            [TpltParam::Type(ty)] => Ok(ty.clone()),
332            _ => Err(E::TemplateArgs("bitcast")),
333        }?;
334        if ty.is_numeric() || ty.is_vec() && ty.inner_ty().is_numeric() {
335            Ok(BitcastTemplate { ty })
336        } else {
337            Err(Error::Builtin(
338                "bitcast template type must be a numeric scalar or numeric vector",
339            ))
340        }
341    }
342    pub fn ty(&self) -> &Type {
343        &self.ty
344    }
345    pub fn inner_ty(&self) -> Type {
346        self.ty.inner_ty()
347    }
348}