wgsl_types/
inst.rs

1//! [`Instance`]s of WGSL [`Type`]s.
2
3use std::{
4    cell::{Ref, RefCell, RefMut},
5    ops::Index,
6    rc::Rc,
7};
8
9use half::f16;
10use itertools::Itertools;
11
12use crate::{
13    Error,
14    syntax::{AccessMode, AddressSpace},
15    ty::{StructType, Ty, Type},
16};
17
18type E = Error;
19
/// A path from the root of an instance down to one of its sub-values.
///
/// The path is a singly-linked chain ending in [`MemView::Whole`]; each link
/// selects one step into the value. See [Instance::view].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MemView {
    /// End of the path: the value reached so far, in its entirety.
    Whole,
    /// Step into a named member, then continue with the nested path.
    /// (`Instance::view` resolves member names against `struct` instances.)
    Member(String, Box<MemView>),
    /// Step into the element at an index of an `array`, `vec` or `mat`, then
    /// continue with the nested path.
    Index(usize, Box<MemView>),
}

impl MemView {
    /// Replace the terminal `Whole` of this path with `leaf`, extending the path by one step.
    fn push_leaf(&mut self, leaf: MemView) {
        match self {
            MemView::Member(_, rest) | MemView::Index(_, rest) => rest.push_leaf(leaf),
            MemView::Whole => *self = leaf,
        }
    }

    /// Extend the path with a member access named `comp`.
    pub fn append_member(&mut self, comp: String) {
        self.push_leaf(MemView::Member(comp, Box::new(MemView::Whole)));
    }

    /// Extend the path with an index access at `index`.
    pub fn append_index(&mut self, index: usize) {
        self.push_leaf(MemView::Index(index, Box::new(MemView::Whole)));
    }
}
47
48/// Instance of a plain type.
49///
50/// Reference: <https://www.w3.org/TR/WGSL/#plain-types-section>
51#[derive(Clone, Debug, PartialEq)]
52pub enum Instance {
53    Literal(LiteralInstance),
54    Struct(StructInstance),
55    Array(ArrayInstance),
56    Vec(VecInstance),
57    Mat(MatInstance),
58    Ptr(PtrInstance),
59    Ref(RefInstance),
60    Atomic(AtomicInstance),
61    /// For instances that cannot be computed currently, we store the type.
62    /// TODO: remove this
63    Deferred(Type),
64}
65
66impl Instance {
67    pub fn unwrap_literal(self) -> LiteralInstance {
68        match self {
69            Instance::Literal(field_0) => field_0,
70            val => panic!("called `Instance::unwrap_literal()` on a `{val}` value"),
71        }
72    }
73    pub fn unwrap_literal_ref(&self) -> &LiteralInstance {
74        match self {
75            Instance::Literal(field_0) => field_0,
76            val => panic!("called `Instance::unwrap_literal_ref()` on a `{val}` value"),
77        }
78    }
79    pub fn unwrap_vec(self) -> VecInstance {
80        match self {
81            Instance::Vec(field_0) => field_0,
82            val => panic!("called `Instance::unwrap_vec()` on a `{val}` value"),
83        }
84    }
85    pub fn unwrap_vec_ref(&self) -> &VecInstance {
86        match self {
87            Instance::Vec(field_0) => field_0,
88            val => panic!("called `Instance::unwrap_vec_ref()` on a `{val}` value"),
89        }
90    }
91    pub fn unwrap_vec_mut(&mut self) -> &mut VecInstance {
92        match self {
93            Instance::Vec(field_0) => field_0,
94            val => panic!("called `Instance::unwrap_vec_mut()` on a `{val}` value"),
95        }
96    }
97}
98
/// Generate `impl From<Inner> for Enum` that wraps the value in `Enum::Variant`.
///
/// Usage: `from_enum!(Enum::Variant(Inner));`
macro_rules! from_enum {
    ($enum_ty:ident :: $variant:ident ( $inner:ident )) => {
        impl From<$inner> for $enum_ty {
            fn from(inner: $inner) -> Self {
                Self::$variant(inner)
            }
        }
    };
}
108
109from_enum!(Instance::Literal(LiteralInstance));
110from_enum!(Instance::Struct(StructInstance));
111from_enum!(Instance::Array(ArrayInstance));
112from_enum!(Instance::Vec(VecInstance));
113from_enum!(Instance::Mat(MatInstance));
114from_enum!(Instance::Ptr(PtrInstance));
115from_enum!(Instance::Ref(RefInstance));
116from_enum!(Instance::Atomic(AtomicInstance));
117from_enum!(Instance::Deferred(Type));
118
// Transitive `From` conversions (`$from -> $middle -> $into`).
// Rust never chains `From` impls on its own, so each pair is generated explicitly.

/// Generate `impl From<$from> for $into` by converting through `$middle`.
macro_rules! impl_transitive_from {
    ($from:ident => $middle:ident => $into:ident) => {
        impl From<$from> for $into {
            fn from(value: $from) -> Self {
                let middle = <$middle as From<$from>>::from(value);
                <$into as From<$middle>>::from(middle)
            }
        }
    };
}
131
132impl_transitive_from!(bool => LiteralInstance => Instance);
133impl_transitive_from!(i64 => LiteralInstance => Instance);
134impl_transitive_from!(f64 => LiteralInstance => Instance);
135impl_transitive_from!(i32 => LiteralInstance => Instance);
136impl_transitive_from!(u32 => LiteralInstance => Instance);
137impl_transitive_from!(f32 => LiteralInstance => Instance);
138
139impl Instance {
140    /// Get an instance representing a memory view.
141    ///
142    /// There are two ways to create a memory view:
143    /// * Accessing a `struct` component (`struct.member`)
144    /// * Indexing an `array`, `vec`, or `mat` (`arr[n]`)
145    ///
146    /// Reference: <https://www.w3.org/TR/WGSL/#memory-views>
147    pub fn view(&self, view: &MemView) -> Result<&Instance, E> {
148        match view {
149            MemView::Whole => Ok(self),
150            MemView::Member(m, v) => match self {
151                Instance::Struct(s) => {
152                    let inst = s.member(m).ok_or_else(|| E::Component(s.ty(), m.clone()))?;
153                    inst.view(v)
154                }
155                _ => Err(E::Component(self.ty(), m.clone())),
156            },
157            MemView::Index(i, view) => match self {
158                Instance::Array(a) => {
159                    let inst = a
160                        .components
161                        .get(*i)
162                        .ok_or(E::OutOfBounds(*i, a.ty(), a.n()))?;
163                    inst.view(view)
164                }
165                Instance::Vec(v) => {
166                    let inst = v
167                        .components
168                        .get(*i)
169                        .ok_or(E::OutOfBounds(*i, v.ty(), v.n()))?;
170                    inst.view(view)
171                }
172                Instance::Mat(m) => {
173                    let inst = m
174                        .components
175                        .get(*i)
176                        .ok_or(E::OutOfBounds(*i, m.ty(), m.c()))?;
177                    inst.view(view)
178                }
179                _ => Err(E::NotIndexable(self.ty())),
180            },
181        }
182    }
183
184    /// Get an instance representing a memory view.
185    ///
186    /// See [Self::view]
187    pub fn view_mut(&mut self, view: &MemView) -> Result<&mut Instance, E> {
188        let ty = self.ty();
189        match view {
190            MemView::Whole => Ok(self),
191            MemView::Member(m, v) => match self {
192                Instance::Struct(s) => {
193                    let inst = s.member_mut(m).ok_or_else(|| E::Component(ty, m.clone()))?;
194                    inst.view_mut(v)
195                }
196                _ => Err(E::Component(ty, m.clone())),
197            },
198            MemView::Index(i, view) => match self {
199                Instance::Array(a) => {
200                    let n = a.n();
201                    let inst = a.components.get_mut(*i).ok_or(E::OutOfBounds(*i, ty, n))?;
202                    inst.view_mut(view)
203                }
204                Instance::Vec(v) => {
205                    let n = v.n();
206                    let inst = v.components.get_mut(*i).ok_or(E::OutOfBounds(*i, ty, n))?;
207                    inst.view_mut(view)
208                }
209                Instance::Mat(m) => {
210                    let c = m.c();
211                    let inst = m.components.get_mut(*i).ok_or(E::OutOfBounds(*i, ty, c))?;
212                    inst.view_mut(view)
213                }
214                _ => Err(E::NotIndexable(ty)),
215            },
216        }
217    }
218
219    /// Mutate the instance.
220    ///
221    /// This is the operation performed by the assignment operator.
222    pub fn write(&mut self, value: Instance) -> Result<Instance, E> {
223        if value.ty() != self.ty() {
224            return Err(E::WriteRefType(value.ty(), self.ty()));
225        }
226        let old = std::mem::replace(self, value);
227        Ok(old)
228    }
229}
230
231/// Instance of a numeric literal type.
232#[derive(Clone, Copy, Debug, PartialEq)]
233pub enum LiteralInstance {
234    Bool(bool),
235    AbstractInt(i64),
236    AbstractFloat(f64),
237    I32(i32),
238    U32(u32),
239    F32(f32),
240    F16(f16),
241    #[cfg(feature = "naga-ext")]
242    I64(i64), // identity if representable
243    #[cfg(feature = "naga-ext")]
244    U64(u64), // reinterpretation of bits
245    #[cfg(feature = "naga-ext")]
246    F64(f64),
247}
248
249from_enum!(LiteralInstance::Bool(bool));
250from_enum!(LiteralInstance::AbstractInt(i64));
251from_enum!(LiteralInstance::AbstractFloat(f64));
252from_enum!(LiteralInstance::I32(i32));
253from_enum!(LiteralInstance::U32(u32));
254from_enum!(LiteralInstance::F32(f32));
255from_enum!(LiteralInstance::F16(f16));
256
257impl LiteralInstance {
258    pub fn unwrap_bool(self) -> bool {
259        match self {
260            LiteralInstance::Bool(field_0) => field_0,
261            val => panic!("called `LiteralInstance::unwrap_bool()` on a `{val}` value"),
262        }
263    }
264
265    pub fn unwrap_abstract_int(self) -> i64 {
266        match self {
267            LiteralInstance::AbstractInt(field_0) => field_0,
268            val => panic!("called `LiteralInstance::unwrap_abstract_int()` on a `{val}` value"),
269        }
270    }
271    pub fn unwrap_abstract_float(self) -> f64 {
272        match self {
273            LiteralInstance::AbstractFloat(field_0) => field_0,
274            val => panic!("called `LiteralInstance::unwrap_abstract_float()` on a `{val}` value"),
275        }
276    }
277    pub fn unwrap_i32(self) -> i32 {
278        match self {
279            LiteralInstance::I32(field_0) => field_0,
280            val => panic!("called `LiteralInstance::unwrap_i32()` on a `{val}` value"),
281        }
282    }
283    pub fn unwrap_u32(self) -> u32 {
284        match self {
285            LiteralInstance::U32(field_0) => field_0,
286            val => panic!("called `LiteralInstance::unwrap_u32()` on a `{val}` value"),
287        }
288    }
289    pub fn unwrap_f32(self) -> f32 {
290        match self {
291            LiteralInstance::F32(field_0) => field_0,
292            val => panic!("called `LiteralInstance::unwrap_f32()` on a `{val}` value"),
293        }
294    }
295    pub fn unwrap_f16(self) -> f16 {
296        match self {
297            LiteralInstance::F16(field_0) => field_0,
298            val => panic!("called `LiteralInstance::unwrap_f16()` on a `{val}` value"),
299        }
300    }
301    #[cfg(feature = "naga-ext")]
302    pub fn unwrap_i64(self) -> i64 {
303        match self {
304            LiteralInstance::I64(field_0) => field_0,
305            val => panic!("called `LiteralInstance::unwrap_i64()` on a `{val}` value"),
306        }
307    }
308    #[cfg(feature = "naga-ext")]
309    pub fn unwrap_u64(self) -> u64 {
310        match self {
311            LiteralInstance::U64(field_0) => field_0,
312            val => panic!("called `LiteralInstance::unwrap_u64()` on a `{val}` value"),
313        }
314    }
315    #[cfg(feature = "naga-ext")]
316    pub fn unwrap_f64(self) -> f64 {
317        match self {
318            LiteralInstance::F64(field_0) => field_0,
319            val => panic!("called `LiteralInstance::unwrap_f64()` on a `{val}` value"),
320        }
321    }
322}
323
324/// Instance of a `struct` type.
325///
326/// Reference: <https://www.w3.org/TR/WGSL/#struct-types>
327#[derive(Clone, Debug, PartialEq)]
328pub struct StructInstance {
329    pub ty: StructType,
330    pub members: Vec<Instance>,
331}
332
333impl StructInstance {
334    /// Create a `struct` instance.
335    ///
336    /// # Panics
337    /// * if there is not the right number of members
338    /// * if the members are not of the right type
339    pub fn new(ty: StructType, members: Vec<Instance>) -> Self {
340        assert_eq!(ty.members.len(), members.len());
341        for (m, m_ty) in members.iter().zip(&ty.members) {
342            assert_eq!(m_ty.ty, m.ty());
343        }
344
345        Self { ty, members }
346    }
347    /// Get a `struct` member value by name.
348    pub fn member(&self, name: &str) -> Option<&Instance> {
349        self.members
350            .iter()
351            .zip(&self.ty.members)
352            .find_map(|(inst, m_ty)| (m_ty.name == name).then_some(inst))
353    }
354    /// Get a `struct` member value by name.
355    pub fn member_mut(&mut self, name: &str) -> Option<&mut Instance> {
356        self.members
357            .iter_mut()
358            .zip(&self.ty.members)
359            .find_map(|(inst, m_ty)| (m_ty.name == name).then_some(inst))
360    }
361    // pub fn iter_members(&self) -> impl Iterator<Item = &(String, Instance)> {
362    //     self.members.iter()
363    // }
364}
365
366/// Instance of an `array<T, N>` type.
367///
368/// Reference: <https://www.w3.org/TR/WGSL/#array-types>
369#[derive(Clone, Debug, PartialEq, Default)]
370pub struct ArrayInstance {
371    components: Vec<Instance>,
372    pub runtime_sized: bool,
373}
374
375impl ArrayInstance {
376    /// Construct an `array`.
377    ///
378    /// # Panics
379    /// * if the components is empty
380    /// * if the components are not all the same type
381    pub fn new(components: Vec<Instance>, runtime_sized: bool) -> Self {
382        assert!(!components.is_empty());
383        assert!(components.iter().map(|c| c.ty()).all_equal());
384        Self {
385            components,
386            runtime_sized,
387        }
388    }
389    /// The element count.
390    pub fn n(&self) -> usize {
391        self.components.len()
392    }
393    /// Get an element by index.
394    pub fn get(&self, i: usize) -> Option<&Instance> {
395        self.components.get(i)
396    }
397    /// Get an element by index.
398    pub fn get_mut(&mut self, i: usize) -> Option<&mut Instance> {
399        self.components.get_mut(i)
400    }
401    pub fn iter(&self) -> impl Iterator<Item = &Instance> {
402        self.components.iter()
403    }
404    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
405        self.components.iter_mut()
406    }
407    pub fn as_slice(&self) -> &[Instance] {
408        self.components.as_slice()
409    }
410}
411
412impl IntoIterator for ArrayInstance {
413    type Item = Instance;
414    type IntoIter = <Vec<Instance> as IntoIterator>::IntoIter;
415
416    fn into_iter(self) -> Self::IntoIter {
417        self.components.into_iter()
418    }
419}
420
421/// Instance of a `vecN<T>` type.
422///
423/// Reference: <https://www.w3.org/TR/WGSL/#vector-types>
424#[derive(Clone, Debug, PartialEq)]
425pub struct VecInstance {
426    components: ArrayInstance,
427}
428
429impl VecInstance {
430    /// Construct a `vec`.
431    ///
432    /// # Panics
433    /// * if the components length is not [2, 3, 4]
434    /// * if the components are not all the same type
435    /// * if the type is not a scalar
436    pub fn new(components: Vec<Instance>) -> Self {
437        assert!((2..=4).contains(&components.len()));
438        let components = ArrayInstance::new(components, false);
439        assert!(components.inner_ty().is_scalar());
440        Self { components }
441    }
442    /// The component count.
443    pub fn n(&self) -> usize {
444        self.components.n()
445    }
446    /// Get a component by index.
447    pub fn get(&self, i: usize) -> Option<&Instance> {
448        self.components.get(i)
449    }
450    /// Get a component by index.
451    pub fn get_mut(&mut self, i: usize) -> Option<&mut Instance> {
452        self.components.get_mut(i)
453    }
454    pub fn iter(&self) -> impl Iterator<Item = &Instance> {
455        self.components.iter()
456    }
457    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
458        self.components.iter_mut()
459    }
460    pub fn as_slice(&self) -> &[Instance] {
461        self.components.as_slice()
462    }
463}
464
465impl IntoIterator for VecInstance {
466    type Item = Instance;
467    type IntoIter = <ArrayInstance as IntoIterator>::IntoIter;
468
469    fn into_iter(self) -> Self::IntoIter {
470        self.components.into_iter()
471    }
472}
473
474impl Index<usize> for VecInstance {
475    type Output = Instance;
476
477    fn index(&self, index: usize) -> &Self::Output {
478        self.get(index).unwrap()
479    }
480}
481
482impl<T: Into<Instance>> From<[T; 2]> for VecInstance {
483    fn from(components: [T; 2]) -> Self {
484        Self::new(components.map(Into::into).to_vec())
485    }
486}
487impl<T: Into<Instance>> From<[T; 3]> for VecInstance {
488    fn from(components: [T; 3]) -> Self {
489        Self::new(components.map(Into::into).to_vec())
490    }
491}
492impl<T: Into<Instance>> From<[T; 4]> for VecInstance {
493    fn from(components: [T; 4]) -> Self {
494        Self::new(components.map(Into::into).to_vec())
495    }
496}
497
498/// Instance of a `matCxR<T>` type.
499///
500/// Reference: <https://www.w3.org/TR/WGSL/#matrix-types>
501#[derive(Clone, Debug, PartialEq)]
502pub struct MatInstance {
503    /// Column vectors of the matrix
504    components: Vec<Instance>,
505}
506
507impl MatInstance {
508    /// Construct a `mat` from column vectors.
509    ///
510    /// # Panics
511    /// * if the number of columns is not [2, 3, 4]
512    /// * if the columns don't have the same number of rows
513    /// * if the number of rows is not [2, 3, 4]
514    /// * if the elements don't have the same type
515    /// * if the type is not a scalar
516    pub fn from_cols(components: Vec<Instance>) -> Self {
517        assert!((2..=4).contains(&components.len()));
518        assert!(
519            components
520                .iter()
521                .map(|c| c.unwrap_vec_ref().n())
522                .all_equal(),
523            "MatInstance columns must have the same number for rows"
524        );
525        assert!(
526            components.iter().map(|c| c.ty()).all_equal(),
527            "MatInstance columns must have the same type"
528        );
529        Self { components }
530    }
531
532    /// The row count.
533    pub fn r(&self) -> usize {
534        self.components.first().unwrap().unwrap_vec_ref().n()
535    }
536    /// The column count.
537    pub fn c(&self) -> usize {
538        self.components.len()
539    }
540    /// Get a column vector.
541    pub fn col(&self, i: usize) -> Option<&Instance> {
542        self.components.get(i)
543    }
544    /// Get a column vector.
545    pub fn col_mut(&mut self, i: usize) -> Option<&mut Instance> {
546        self.components.get_mut(i)
547    }
548    /// Get a component.
549    pub fn get(&self, col: usize, row: usize) -> Option<&Instance> {
550        self.col(col).and_then(|v| v.unwrap_vec_ref().get(row))
551    }
552    /// Get a component.
553    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut Instance> {
554        self.col_mut(i).and_then(|v| v.unwrap_vec_mut().get_mut(j))
555    }
556    pub fn iter_cols(&self) -> impl Iterator<Item = &Instance> {
557        self.components.iter()
558    }
559    pub fn iter_cols_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
560        self.components.iter_mut()
561    }
562    pub fn iter(&self) -> impl Iterator<Item = &Instance> {
563        self.components
564            .iter()
565            .flat_map(|v| v.unwrap_vec_ref().iter())
566    }
567    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
568        self.components
569            .iter_mut()
570            .flat_map(|v| v.unwrap_vec_mut().iter_mut())
571    }
572}
573impl IntoIterator for MatInstance {
574    type Item = Instance;
575    type IntoIter = <Vec<Instance> as IntoIterator>::IntoIter;
576
577    fn into_iter(self) -> Self::IntoIter {
578        self.components.into_iter()
579    }
580}
581
582/// Instance of a `ptr<AS,T,AM>` type.
583///
584/// Reference: <https://www.w3.org/TR/WGSL/#ref-ptr-types>
585#[derive(Clone, Debug, PartialEq)]
586pub struct PtrInstance {
587    pub ptr: RefInstance,
588}
589
590impl From<RefInstance> for PtrInstance {
591    fn from(r: RefInstance) -> Self {
592        Self { ptr: r }
593    }
594}
595
596/// Instance of a `ref<AS,T,AM>` type.
597///
598/// Reference: <https://www.w3.org/TR/WGSL/#ref-ptr-types>
599#[derive(Clone, Debug, PartialEq)]
600pub struct RefInstance {
601    /// Inner type
602    pub ty: Type,
603    pub space: AddressSpace,
604    pub access: AccessMode,
605    pub view: MemView,
606    pub ptr: Rc<RefCell<Instance>>,
607}
608
609impl RefInstance {
610    pub fn new(inst: Instance, space: AddressSpace, access: AccessMode) -> Self {
611        let ty = inst.ty();
612        Self {
613            ty,
614            space,
615            access,
616            view: MemView::Whole,
617            ptr: Rc::new(RefCell::new(inst)),
618        }
619    }
620}
621
622impl From<PtrInstance> for RefInstance {
623    fn from(p: PtrInstance) -> Self {
624        p.ptr
625    }
626}
627
628impl RefInstance {
629    /// Get a reference to a `struct` or `vec` member.
630    pub fn view_member(&self, comp: String) -> Result<Self, E> {
631        if !self.access.is_read() {
632            return Err(E::NotRead);
633        }
634        let mut view = self.view.clone();
635        view.append_member(comp);
636        let ty = self.ptr.borrow().view(&view)?.ty();
637        Ok(Self {
638            ty,
639            space: self.space,
640            access: self.access,
641            view,
642            ptr: self.ptr.clone(),
643        })
644    }
645    /// Get a reference to an `array`, `vec` or `mat` component.
646    pub fn view_index(&self, index: usize) -> Result<Self, E> {
647        if !self.access.is_read() {
648            return Err(E::NotRead);
649        }
650        let mut view = self.view.clone();
651        view.append_index(index);
652        let ty = self.ptr.borrow().view(&view)?.ty();
653        Ok(Self {
654            ty,
655            space: self.space,
656            access: self.access,
657            view,
658            ptr: self.ptr.clone(),
659        })
660    }
661
662    pub fn read<'a>(&'a self) -> Result<Ref<'a, Instance>, E> {
663        if !self.access.is_read() {
664            return Err(E::NotRead);
665        }
666        Ok(Ref::<'a, Instance>::map(self.ptr.borrow(), |r| {
667            r.view(&self.view).expect("invalid reference")
668        }))
669    }
670
671    pub fn write(&self, value: Instance) -> Result<(), E> {
672        if !self.access.is_write() {
673            return Err(E::NotWrite);
674        }
675        if value.ty() != self.ty {
676            return Err(E::WriteRefType(value.ty(), self.ty.clone()));
677        }
678        let mut r = self.ptr.borrow_mut();
679        let view = r.view_mut(&self.view).expect("invalid reference");
680        assert!(view.ty() == value.ty());
681        let _ = std::mem::replace(view, value);
682        Ok(())
683    }
684
685    pub fn read_write<'a>(&'a self) -> Result<RefMut<'a, Instance>, E> {
686        if !self.access.is_write() {
687            return Err(E::NotReadWrite);
688        }
689        Ok(RefMut::<'a, Instance>::map(self.ptr.borrow_mut(), |r| {
690            r.view_mut(&self.view).expect("invalid reference")
691        }))
692    }
693}
694
695/// `atomic<T>` Instance.
696///
697/// Reference: <https://www.w3.org/TR/WGSL/#atomic-types>
698#[derive(Clone, Debug, PartialEq)]
699pub struct AtomicInstance {
700    content: Box<Instance>,
701}
702
703impl AtomicInstance {
704    /// # Panics
705    /// * if the instance is not an i32 or u32
706    pub fn new(inst: Instance) -> Self {
707        assert!(matches!(inst.ty(), Type::I32 | Type::U32));
708        Self {
709            content: inst.into(),
710        }
711    }
712
713    pub fn inner(&self) -> &Instance {
714        &self.content
715    }
716
717    pub fn inner_mut(&mut self) -> &mut Instance {
718        &mut self.content
719    }
720}