wgsl_types/
ty.rs

1//! WGSL [`Type`]s.
2
3use std::str::FromStr;
4
5use crate::{Error, Instance, inst::*, syntax::*};
6
7#[derive(Clone, Debug, PartialEq, Eq)]
8pub struct StructMemberType {
9    pub name: String,
10    pub ty: Type,
11    pub size: Option<u32>,
12    pub align: Option<u32>,
13}
14
15impl StructMemberType {
16    pub fn new(name: String, ty: Type) -> Self {
17        Self {
18            name,
19            ty,
20            size: None,
21            align: None,
22        }
23    }
24}
25
26#[derive(Clone, Debug, PartialEq, Eq)]
27pub struct StructType {
28    pub name: String,
29    pub members: Vec<StructMemberType>,
30}
31
32impl From<StructType> for Type {
33    fn from(value: StructType) -> Self {
34        Self::Struct(Box::new(value))
35    }
36}
37
38#[derive(Clone, Debug, PartialEq, Eq, Hash)]
39pub enum TextureType {
40    // sampled
41    Sampled1D(SampledType),
42    Sampled2D(SampledType),
43    Sampled2DArray(SampledType),
44    Sampled3D(SampledType),
45    SampledCube(SampledType),
46    SampledCubeArray(SampledType),
47    // multisampled
48    Multisampled2D(SampledType),
49    DepthMultisampled2D,
50    // external
51    External,
52    // storage
53    Storage1D(TexelFormat, AccessMode),
54    Storage2D(TexelFormat, AccessMode),
55    Storage2DArray(TexelFormat, AccessMode),
56    Storage3D(TexelFormat, AccessMode),
57    // depth
58    Depth2D,
59    Depth2DArray,
60    DepthCube,
61    DepthCubeArray,
62    #[cfg(feature = "naga-ext")]
63    Sampled1DArray(SampledType),
64    #[cfg(feature = "naga-ext")]
65    Storage1DArray(TexelFormat, AccessMode),
66    #[cfg(feature = "naga-ext")]
67    Multisampled2DArray(SampledType),
68}
69
/// Dimensionality reported by [`TextureType::dimensions`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TextureDimensions {
    /// One-dimensional.
    D1,
    /// Two-dimensional (cube textures are also reported as this).
    D2,
    /// Three-dimensional.
    D3,
}
76
77impl TextureType {
78    pub fn dimensions(&self) -> TextureDimensions {
79        match self {
80            Self::Sampled1D(_) | Self::Storage1D(_, _) => TextureDimensions::D1,
81            Self::Sampled2D(_)
82            | Self::Sampled2DArray(_)
83            | Self::SampledCube(_)
84            | Self::SampledCubeArray(_)
85            | Self::Multisampled2D(_)
86            | Self::Depth2D
87            | Self::Depth2DArray
88            | Self::DepthCube
89            | Self::DepthCubeArray
90            | Self::DepthMultisampled2D
91            | Self::Storage2D(_, _)
92            | Self::Storage2DArray(_, _)
93            | Self::External => TextureDimensions::D2,
94            Self::Sampled3D(_) | Self::Storage3D(_, _) => TextureDimensions::D3,
95            #[cfg(feature = "naga-ext")]
96            Self::Sampled1DArray(_) | Self::Storage1DArray(_, _) => TextureDimensions::D1,
97            #[cfg(feature = "naga-ext")]
98            Self::Multisampled2DArray(_) => TextureDimensions::D2,
99        }
100    }
101    pub fn sampled_type(&self) -> Option<SampledType> {
102        match self {
103            TextureType::Sampled1D(st) => Some(*st),
104            TextureType::Sampled2D(st) => Some(*st),
105            TextureType::Sampled2DArray(st) => Some(*st),
106            TextureType::Sampled3D(st) => Some(*st),
107            TextureType::SampledCube(st) => Some(*st),
108            TextureType::SampledCubeArray(st) => Some(*st),
109            TextureType::Multisampled2D(_) => None,
110            TextureType::DepthMultisampled2D => None,
111            TextureType::External => None,
112            TextureType::Storage1D(_, _) => None,
113            TextureType::Storage2D(_, _) => None,
114            TextureType::Storage2DArray(_, _) => None,
115            TextureType::Storage3D(_, _) => None,
116            TextureType::Depth2D => None,
117            TextureType::Depth2DArray => None,
118            TextureType::DepthCube => None,
119            TextureType::DepthCubeArray => None,
120            #[cfg(feature = "naga-ext")]
121            TextureType::Sampled1DArray(st) => Some(*st),
122            #[cfg(feature = "naga-ext")]
123            TextureType::Storage1DArray(_, _) => None,
124            #[cfg(feature = "naga-ext")]
125            TextureType::Multisampled2DArray(st) => Some(*st),
126        }
127    }
128    pub fn channel_type(&self) -> SampledType {
129        match self {
130            TextureType::Sampled1D(st) => *st,
131            TextureType::Sampled2D(st) => *st,
132            TextureType::Sampled2DArray(st) => *st,
133            TextureType::Sampled3D(st) => *st,
134            TextureType::SampledCube(st) => *st,
135            TextureType::SampledCubeArray(st) => *st,
136            TextureType::Multisampled2D(st) => *st,
137            TextureType::DepthMultisampled2D => SampledType::F32,
138            TextureType::External => SampledType::F32,
139            TextureType::Storage1D(f, _) => f.channel_type(),
140            TextureType::Storage2D(f, _) => f.channel_type(),
141            TextureType::Storage2DArray(f, _) => f.channel_type(),
142            TextureType::Storage3D(f, _) => f.channel_type(),
143            TextureType::Depth2D => SampledType::F32,
144            TextureType::Depth2DArray => SampledType::F32,
145            TextureType::DepthCube => SampledType::F32,
146            TextureType::DepthCubeArray => SampledType::F32,
147            #[cfg(feature = "naga-ext")]
148            TextureType::Sampled1DArray(st) => *st,
149            #[cfg(feature = "naga-ext")]
150            TextureType::Storage1DArray(f, _) => f.channel_type(),
151            #[cfg(feature = "naga-ext")]
152            TextureType::Multisampled2DArray(st) => *st,
153        }
154    }
155    pub fn is_depth(&self) -> bool {
156        matches!(
157            self,
158            TextureType::Depth2D
159                | TextureType::Depth2DArray
160                | TextureType::DepthCube
161                | TextureType::DepthCubeArray
162        )
163    }
164    pub fn is_storage(&self) -> bool {
165        match self {
166            TextureType::Storage1D(_, _)
167            | TextureType::Storage2D(_, _)
168            | TextureType::Storage2DArray(_, _)
169            | TextureType::Storage3D(_, _) => true,
170            #[cfg(feature = "naga-ext")]
171            TextureType::Storage1DArray(_, _) => true,
172            _ => false,
173        }
174    }
175    pub fn is_sampled(&self) -> bool {
176        match self {
177            TextureType::Sampled1D(_)
178            | TextureType::Sampled2D(_)
179            | TextureType::Sampled2DArray(_)
180            | TextureType::Sampled3D(_)
181            | TextureType::SampledCube(_)
182            | TextureType::SampledCubeArray(_) => true,
183            #[cfg(feature = "naga-ext")]
184            TextureType::Sampled1DArray(_) => true,
185            _ => false,
186        }
187    }
188    pub fn is_arrayed(&self) -> bool {
189        match self {
190            TextureType::Sampled2DArray(_)
191            | TextureType::SampledCubeArray(_)
192            | TextureType::Storage2DArray(_, _)
193            | TextureType::Depth2DArray
194            | TextureType::DepthCubeArray => true,
195            #[cfg(feature = "naga-ext")]
196            TextureType::Sampled1DArray(_)
197            | TextureType::Storage1DArray(_, _)
198            | TextureType::Multisampled2DArray(_) => true,
199            _ => false,
200        }
201    }
202    pub fn is_multisampled(&self) -> bool {
203        match self {
204            TextureType::Multisampled2D(_) | TextureType::DepthMultisampled2D => true,
205            #[cfg(feature = "naga-ext")]
206            TextureType::Multisampled2DArray(_) => true,
207            _ => false,
208        }
209    }
210}
211
212impl TryFrom<&Type> for SampledType {
213    type Error = Error;
214
215    fn try_from(value: &Type) -> Result<Self, Self::Error> {
216        match value {
217            Type::I32 => Ok(SampledType::I32),
218            Type::U32 => Ok(SampledType::U32),
219            Type::F32 => Ok(SampledType::F32),
220            _ => Err(Error::SampledType(value.clone())),
221        }
222    }
223}
224
225impl From<SampledType> for Type {
226    fn from(value: SampledType) -> Self {
227        match value {
228            SampledType::I32 => Type::I32,
229            SampledType::U32 => Type::U32,
230            SampledType::F32 => Type::F32,
231        }
232    }
233}
234
/// The WGSL sampler types.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SamplerType {
    /// `sampler`
    Sampler,
    /// `sampler_comparison`
    SamplerComparison,
}

impl FromStr for SamplerType {
    type Err = ();

    /// Parses the exact WGSL spelling (`sampler` or `sampler_comparison`).
    /// Any other string yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s {
            "sampler" => Self::Sampler,
            "sampler_comparison" => Self::SamplerComparison,
            _ => return Err(()),
        })
    }
}
252
253/// WGSL type.
254#[derive(Clone, Debug, PartialEq, Eq)]
255pub enum Type {
256    Bool,
257    AbstractInt,
258    AbstractFloat,
259    I32,
260    U32,
261    F32,
262    F16,
263    Struct(Box<StructType>),
264    Array(Box<Type>, Option<usize>),
265    Vec(u8, Box<Type>),
266    Mat(u8, u8, Box<Type>),
267    Atomic(Box<Type>),
268    Ptr(AddressSpace, Box<Type>, AccessMode),
269    Ref(AddressSpace, Box<Type>, AccessMode),
270    Texture(TextureType),
271    Sampler(SamplerType),
272    #[cfg(feature = "naga-ext")]
273    I64,
274    #[cfg(feature = "naga-ext")]
275    U64,
276    #[cfg(feature = "naga-ext")]
277    F64,
278    #[cfg(feature = "naga-ext")]
279    BindingArray(Box<Type>, Option<usize>),
280    #[cfg(feature = "naga-ext")]
281    RayQuery(Option<AccelerationStructureFlags>),
282    #[cfg(feature = "naga-ext")]
283    AccelerationStructure(Option<AccelerationStructureFlags>),
284}
285
286impl Type {
287    /// Reference: <https://www.w3.org/TR/WGSL/#scalar>
288    pub fn is_scalar(&self) -> bool {
289        match self {
290            Type::Bool
291            | Type::AbstractInt
292            | Type::AbstractFloat
293            | Type::I32
294            | Type::U32
295            | Type::F32
296            | Type::F16 => true,
297            #[cfg(feature = "naga-ext")]
298            Type::I64 | Type::U64 | Type::F64 => true,
299            _ => false,
300        }
301    }
302
303    /// Reference: <https://www.w3.org/TR/WGSL/#numeric-scalar>
304    pub fn is_numeric(&self) -> bool {
305        match self {
306            Type::AbstractInt
307            | Type::AbstractFloat
308            | Type::I32
309            | Type::U32
310            | Type::F32
311            | Type::F16 => true,
312            #[cfg(feature = "naga-ext")]
313            Type::I64 | Type::U64 | Type::F64 => true,
314            _ => false,
315        }
316    }
317
318    /// Reference: <https://www.w3.org/TR/WGSL/#integer-scalar>
319    pub fn is_integer(&self) -> bool {
320        match self {
321            Type::AbstractInt | Type::I32 | Type::U32 => true,
322            #[cfg(feature = "naga-ext")]
323            Type::I64 | Type::U64 => true,
324            _ => false,
325        }
326    }
327
328    /// Reference: <https://www.w3.org/TR/WGSL/#floating-point-types>
329    pub fn is_float(&self) -> bool {
330        match self {
331            Type::AbstractFloat | Type::F32 | Type::F16 => true,
332            #[cfg(feature = "naga-ext")]
333            Type::F64 => true,
334            _ => false,
335        }
336    }
337
338    /// Reference: <https://www.w3.org/TR/WGSL/#abstract-types>
339    pub fn is_abstract(&self) -> bool {
340        match self {
341            Type::AbstractInt => true,
342            Type::AbstractFloat => true,
343            Type::Array(ty, _) | Type::Vec(_, ty) | Type::Mat(_, _, ty) => ty.is_abstract(),
344            _ => false,
345        }
346    }
347
348    pub fn is_concrete(&self) -> bool {
349        !self.is_abstract()
350    }
351
352    /// Reference: <https://www.w3.org/TR/WGSL/#storable-types>
353    pub fn is_storable(&self) -> bool {
354        self.is_concrete()
355            && match self {
356                Type::Bool
357                | Type::I32
358                | Type::U32
359                | Type::F32
360                | Type::F16
361                | Type::Struct(_)
362                | Type::Array(_, _)
363                | Type::Vec(_, _)
364                | Type::Mat(_, _, _)
365                | Type::Atomic(_) => true,
366                #[cfg(feature = "naga-ext")]
367                Type::I64 | Type::U64 | Type::F64 => true,
368                _ => false,
369            }
370    }
371
372    pub fn is_array(&self) -> bool {
373        matches!(self, Type::Array(_, _))
374    }
375    pub fn is_vec(&self) -> bool {
376        matches!(self, Type::Vec(_, _))
377    }
378    pub fn is_i32(&self) -> bool {
379        matches!(self, Type::I32)
380    }
381    pub fn is_u32(&self) -> bool {
382        matches!(self, Type::U32)
383    }
384    pub fn is_f32(&self) -> bool {
385        matches!(self, Type::F32)
386    }
387    #[cfg(feature = "naga-ext")]
388    pub fn is_i64(&self) -> bool {
389        matches!(self, Type::I64)
390    }
391    #[cfg(feature = "naga-ext")]
392    pub fn is_u64(&self) -> bool {
393        matches!(self, Type::U64)
394    }
395    #[cfg(feature = "naga-ext")]
396    pub fn is_f64(&self) -> bool {
397        matches!(self, Type::F64)
398    }
399    pub fn is_bool(&self) -> bool {
400        matches!(self, Type::Bool)
401    }
402    pub fn is_mat(&self) -> bool {
403        matches!(self, Type::Mat(_, _, _))
404    }
405    pub fn is_abstract_int(&self) -> bool {
406        matches!(self, Type::AbstractInt)
407    }
408
409    pub fn unwrap_atomic(self) -> Box<Type> {
410        match self {
411            Type::Atomic(ty) => ty,
412            val => panic!("called `Type::unwrap_atomic()` on a `{val}` value"),
413        }
414    }
415
416    pub fn unwrap_struct(self) -> Box<StructType> {
417        match self {
418            Type::Struct(ty) => ty,
419            val => panic!("called `Type::unwrap_struct()` on a `{val}` value"),
420        }
421    }
422
423    pub fn unwrap_vec(self) -> (u8, Box<Type>) {
424        match self {
425            Type::Vec(size, ty) => (size, ty),
426            val => panic!("called `Type::unwrap_vec()` on a `{val}` value"),
427        }
428    }
429}
430
431pub trait Ty {
432    /// get the type of an instance.
433    fn ty(&self) -> Type;
434
435    /// get the inner type of an instance (not recursive).
436    ///
437    /// e.g. the inner type of `array<vec3<u32>>` is `vec3<u32>`.
438    fn inner_ty(&self) -> Type {
439        self.ty()
440    }
441}
442
443impl Ty for Type {
444    fn ty(&self) -> Type {
445        self.clone()
446    }
447
448    fn inner_ty(&self) -> Type {
449        match self {
450            Type::Bool => self.clone(),
451            Type::AbstractInt => self.clone(),
452            Type::AbstractFloat => self.clone(),
453            Type::I32 => self.clone(),
454            Type::U32 => self.clone(),
455            Type::F32 => self.clone(),
456            Type::F16 => self.clone(),
457            Type::Struct(_) => self.clone(),
458            Type::Array(ty, _) => ty.ty(),
459            Type::Vec(_, ty) => ty.ty(),
460            Type::Mat(_, _, ty) => ty.ty(),
461            Type::Atomic(ty) => ty.ty(),
462            Type::Ptr(_, ty, _) => ty.ty(),
463            Type::Ref(_, ty, _) => ty.ty(),
464            Type::Texture(_) => self.clone(),
465            Type::Sampler(_) => self.clone(),
466            #[cfg(feature = "naga-ext")]
467            Type::I64 => self.clone(),
468            #[cfg(feature = "naga-ext")]
469            Type::U64 => self.clone(),
470            #[cfg(feature = "naga-ext")]
471            Type::F64 => self.clone(),
472            #[cfg(feature = "naga-ext")]
473            Type::BindingArray(ty, _) => ty.ty(),
474            #[cfg(feature = "naga-ext")]
475            Type::RayQuery(_) => self.clone(),
476            #[cfg(feature = "naga-ext")]
477            Type::AccelerationStructure(_) => self.clone(),
478        }
479    }
480}
481
482impl Ty for Instance {
483    fn ty(&self) -> Type {
484        match self {
485            Instance::Literal(l) => l.ty(),
486            Instance::Struct(s) => s.ty(),
487            Instance::Array(a) => a.ty(),
488            Instance::Vec(v) => v.ty(),
489            Instance::Mat(m) => m.ty(),
490            Instance::Ptr(p) => p.ty(),
491            Instance::Ref(r) => r.ty(),
492            Instance::Atomic(a) => a.ty(),
493            Instance::Deferred(t) => t.ty(),
494        }
495    }
496    fn inner_ty(&self) -> Type {
497        match self {
498            Instance::Literal(l) => l.inner_ty(),
499            Instance::Struct(s) => s.inner_ty(),
500            Instance::Array(a) => a.inner_ty(),
501            Instance::Vec(v) => v.inner_ty(),
502            Instance::Mat(m) => m.inner_ty(),
503            Instance::Ptr(p) => p.inner_ty(),
504            Instance::Ref(r) => r.inner_ty(),
505            Instance::Atomic(a) => a.inner_ty(),
506            Instance::Deferred(t) => t.inner_ty(),
507        }
508    }
509}
510
511impl Ty for LiteralInstance {
512    fn ty(&self) -> Type {
513        match self {
514            LiteralInstance::Bool(_) => Type::Bool,
515            LiteralInstance::AbstractInt(_) => Type::AbstractInt,
516            LiteralInstance::AbstractFloat(_) => Type::AbstractFloat,
517            LiteralInstance::I32(_) => Type::I32,
518            LiteralInstance::U32(_) => Type::U32,
519            LiteralInstance::F32(_) => Type::F32,
520            LiteralInstance::F16(_) => Type::F16,
521            #[cfg(feature = "naga-ext")]
522            LiteralInstance::I64(_) => Type::I64,
523            #[cfg(feature = "naga-ext")]
524            LiteralInstance::U64(_) => Type::U64,
525            #[cfg(feature = "naga-ext")]
526            LiteralInstance::F64(_) => Type::F64,
527        }
528    }
529}
530
531impl Ty for StructInstance {
532    fn ty(&self) -> Type {
533        self.ty.clone().into()
534    }
535}
536
537impl Ty for ArrayInstance {
538    fn ty(&self) -> Type {
539        Type::Array(
540            Box::new(self.inner_ty().clone()),
541            (!self.runtime_sized).then_some(self.n()),
542        )
543    }
544    fn inner_ty(&self) -> Type {
545        self.get(0).unwrap().ty()
546    }
547}
548
549impl Ty for VecInstance {
550    fn ty(&self) -> Type {
551        Type::Vec(self.n() as u8, Box::new(self.inner_ty()))
552    }
553    fn inner_ty(&self) -> Type {
554        self.get(0).unwrap().ty()
555    }
556}
557
558impl Ty for MatInstance {
559    fn ty(&self) -> Type {
560        Type::Mat(self.c() as u8, self.r() as u8, Box::new(self.inner_ty()))
561    }
562    fn inner_ty(&self) -> Type {
563        self.get(0, 0).unwrap().ty()
564    }
565}
566
567impl Ty for PtrInstance {
568    fn ty(&self) -> Type {
569        Type::Ptr(
570            self.ptr.space,
571            Box::new(self.ptr.ty.clone()),
572            self.ptr.access,
573        )
574    }
575}
576
577impl Ty for RefInstance {
578    fn ty(&self) -> Type {
579        Type::Ref(self.space, Box::new(self.ty.clone()), self.access)
580    }
581}
582
583impl Ty for AtomicInstance {
584    fn ty(&self) -> Type {
585        Type::Atomic(self.inner_ty().into())
586    }
587    fn inner_ty(&self) -> Type {
588        self.inner().ty()
589    }
590}