wgsl_types/
conv.rs

1//! Type conversion algorithms.
2//!
3//! Implementation of the [conversion_rank] algorithm and utilities to convert
4//! [`Type`]s and [`Instance`]s using [automatic conversions].
5//!
6//! [automatic conversions]: https://www.w3.org/TR/WGSL/#feasible-automatic-conversion
7
8use half::f16;
9use itertools::Itertools;
10use num_traits::{FromPrimitive, ToPrimitive};
11
12use crate::{
13    Error, Instance,
14    inst::{ArrayInstance, LiteralInstance, MatInstance, StructInstance, VecInstance},
15    syntax::AccessMode,
16    ty::{Ty, Type},
17};
18
19pub trait Convert: Sized + Clone + Ty {
20    /// Convert an instance to another type, if a feasible conversion exists.
21    ///
22    /// E.g. `array<u32>.convert_inner_to(array<f32>)` becomes `array<f32>`.
23    ///
24    /// Reference: <https://www.w3.org/TR/WGSL/#conversion-rank>
25    fn convert_to(&self, ty: &Type) -> Option<Self>;
26
27    /// Convert an instance by changing its inner type to another.
28    ///
29    /// E.g. `array<u32>.convert_inner_to(f32)` becomes `array<f32>`.
30    ///
31    /// Identical to [`Convert::convert_to`] if the type has no inner type.
32    /// Does not check that the resulting type exists, e.g. `i32` is not a valid inner type for
33    /// matrix types.
34    ///
35    ///
36    /// See [`Ty::inner_ty`]
37    /// See [`Convert::convert_to`]
38    fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
39        self.convert_to(ty)
40    }
41
42    /// Convert an abstract instance to a concrete type.
43    ///
44    /// E.g. `array<vec<AbstractInt>>` becomes `array<vec<i32>>`.
45    fn concretize(&self) -> Option<Self> {
46        self.convert_to(&self.ty().concretize())
47    }
48}
49
50impl Convert for Type {
51    fn convert_to(&self, ty: &Type) -> Option<Self> {
52        self.is_convertible_to(ty).then_some(ty.clone())
53    }
54    fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
55        match self {
56            Type::Array(inner, n) => inner
57                .convert_to(ty)
58                .map(|inner| Type::Array(inner.into(), *n)),
59            Type::Vec(n, inner) => inner
60                .convert_to(ty)
61                .map(|inner| Type::Vec(*n, inner.into())),
62            Type::Mat(c, r, inner) => inner
63                .convert_to(ty)
64                .map(|inner| Type::Mat(*c, *r, inner.into())),
65            Type::Atomic(_) => (self == ty).then_some(ty.clone()),
66            Type::Ptr(_, _, _) => (self == ty).then_some(ty.clone()),
67            _ => self.convert_to(ty), // for types that don't have an inner ty
68        }
69    }
70    fn concretize(&self) -> Option<Self> {
71        Some(self.concretize())
72    }
73}
74
75impl Type {
76    pub fn is_convertible_to(&self, ty: &Type) -> bool {
77        conversion_rank(self, ty).is_some()
78    }
79    pub fn concretize(&self) -> Self {
80        match self {
81            Self::AbstractInt => Type::I32,
82            Self::AbstractFloat => Type::F32,
83            Self::Array(ty, n) => Type::Array(ty.concretize().into(), *n),
84            Self::Vec(n, ty) => Type::Vec(*n, ty.concretize().into()),
85            Self::Mat(c, r, ty) if ty.is_abstract() => Type::Mat(*c, *r, Type::F32.into()),
86            _ => self.clone(),
87        }
88    }
89
90    /// Apply the load rule.
91    ///
92    /// Reference: <https://www.w3.org/TR/WGSL/#load-rule>
93    pub fn loaded(self) -> Self {
94        if let Type::Ref(_, ty, _) = self {
95            ty.loaded()
96        } else {
97            self
98        }
99    }
100}
101
102impl Instance {
103    pub fn is_convertible_to(&self, ty: &Type) -> bool {
104        self.ty().is_convertible_to(ty)
105    }
106
107    /// Apply the load rule.
108    ///
109    /// Reference: <https://www.w3.org/TR/WGSL/#load-rule>
110    pub fn loaded(mut self) -> Result<Self, Error> {
111        while let Instance::Ref(r) = self {
112            self = r.read()?.to_owned();
113        }
114        Ok(self)
115    }
116}
117
118impl LiteralInstance {
119    pub fn is_infinite(&self) -> bool {
120        match self {
121            LiteralInstance::Bool(_) => false,
122            LiteralInstance::AbstractInt(_) => false,
123            LiteralInstance::AbstractFloat(n) => n.is_infinite(),
124            LiteralInstance::I32(_) => false,
125            LiteralInstance::U32(_) => false,
126            LiteralInstance::F32(n) => n.is_infinite(),
127            LiteralInstance::F16(n) => n.is_infinite(),
128            #[cfg(feature = "naga-ext")]
129            LiteralInstance::I64(_) => false,
130            #[cfg(feature = "naga-ext")]
131            LiteralInstance::U64(_) => false,
132            #[cfg(feature = "naga-ext")]
133            LiteralInstance::F64(n) => n.is_infinite(),
134        }
135    }
136    pub fn is_finite(&self) -> bool {
137        !self.is_infinite()
138    }
139}
140
141impl Convert for LiteralInstance {
142    fn convert_to(&self, ty: &Type) -> Option<Self> {
143        if ty == &self.ty() {
144            return Some(*self);
145        }
146
147        // TODO: check that these conversions are correctly implemented.
148        // I think they are incorrect. the to_xyz() functions do not perform rounding.
149        // reference: <https://www.w3.org/TR/WGSL/#floating-point-conversion>
150        // ... except that hex literals must be *exactly* representable in the target type.
151        match (self, ty) {
152            (Self::AbstractInt(n), Type::AbstractFloat) => n.to_f64().map(Self::AbstractFloat),
153            (Self::AbstractInt(n), Type::I32) => n.to_i32().map(Self::I32),
154            (Self::AbstractInt(n), Type::U32) => n.to_u32().map(Self::U32),
155            (Self::AbstractInt(n), Type::F32) => n.to_f32().map(Self::F32),
156            (Self::AbstractInt(n), Type::F16) => f16::from_i64(*n).map(Self::F16),
157            (Self::AbstractFloat(n), Type::F32) => n.to_f32().map(Self::F32),
158            (Self::AbstractFloat(n), Type::F16) => Some(Self::F16(f16::from_f64(*n))),
159            _ => None,
160        }
161        .and_then(|n| n.is_finite().then_some(n))
162    }
163}
164
165impl Convert for ArrayInstance {
166    fn convert_to(&self, ty: &Type) -> Option<Self> {
167        if let Type::Array(c_ty, Some(n)) = ty {
168            if *n == self.n() {
169                self.convert_inner_to(c_ty)
170            } else {
171                None
172            }
173        } else if let Type::Array(c_ty, None) = ty {
174            self.convert_inner_to(c_ty)
175        } else {
176            None
177        }
178    }
179    fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
180        let components = self
181            .iter()
182            .map(|c| c.convert_to(ty))
183            .collect::<Option<Vec<_>>>()?;
184        Some(ArrayInstance::new(components, self.runtime_sized))
185    }
186}
187
188impl Convert for VecInstance {
189    fn convert_to(&self, ty: &Type) -> Option<Self> {
190        if let Type::Vec(n, c_ty) = ty {
191            if *n as usize == self.n() {
192                self.convert_inner_to(c_ty)
193            } else {
194                None
195            }
196        } else {
197            None
198        }
199    }
200    fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
201        let components = self
202            .iter()
203            .map(|c| c.convert_to(ty))
204            .collect::<Option<Vec<_>>>()?;
205        Some(VecInstance::new(components))
206    }
207}
208
209impl Convert for MatInstance {
210    fn convert_to(&self, ty: &Type) -> Option<Self> {
211        if let Type::Mat(c, r, c_ty) = ty {
212            if *c as usize == self.c() && *r as usize == self.r() {
213                self.convert_inner_to(c_ty)
214            } else {
215                None
216            }
217        } else {
218            None
219        }
220    }
221    fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
222        let components = self
223            .iter_cols()
224            .map(|c| c.convert_inner_to(ty))
225            .collect::<Option<Vec<_>>>()?;
226        Some(MatInstance::from_cols(components))
227    }
228}
229
230impl Convert for StructInstance {
231    fn convert_to(&self, ty: &Type) -> Option<Self> {
232        if &self.ty() == ty {
233            Some(self.clone())
234        } else if let Type::Struct(s2) = ty {
235            let s1 = &self.ty;
236            if s1.name.starts_with("__") && s2.name.starts_with("__") {
237                // this is a struct type conversion of built-in types.
238                // __frexp_result_* or __modf_result_*
239                // TODO: here we just assume that s2 is a variant of s1. We should
240                // check.
241                if s2.name.ends_with("f32") {
242                    let members = self
243                        .members
244                        .iter()
245                        .map(|inst| inst.convert_inner_to(&Type::F32))
246                        .collect::<Option<Vec<_>>>()?;
247                    Some(StructInstance::new((**s2).clone(), members))
248                } else if s2.name.ends_with("f16") {
249                    let members = self
250                        .members
251                        .iter()
252                        .map(|inst| inst.convert_inner_to(&Type::F16))
253                        .collect::<Option<Vec<_>>>()?;
254                    Some(StructInstance::new((**s2).clone(), members))
255                } else {
256                    None
257                }
258            } else {
259                None
260            }
261        } else {
262            None
263        }
264    }
265}
266
267impl Convert for Instance {
268    fn convert_to(&self, ty: &Type) -> Option<Self> {
269        if &self.ty() == ty {
270            return Some(self.clone());
271        }
272        match self {
273            Self::Literal(l) => l.convert_to(ty).map(Self::Literal),
274            Self::Struct(s) => s.convert_to(ty).map(Self::Struct),
275            Self::Array(a) => a.convert_to(ty).map(Self::Array),
276            Self::Vec(v) => v.convert_to(ty).map(Self::Vec),
277            Self::Mat(m) => m.convert_to(ty).map(Self::Mat),
278            Self::Ptr(_) => None,
279            Self::Ref(r) => r.read().ok().and_then(|r| r.convert_to(ty)), // this is the "load rule". Also performed by `eval_value`.
280            Self::Atomic(_) => None,
281            Self::Deferred(_) => None,
282        }
283    }
284
285    fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
286        match self {
287            Self::Literal(l) => l.convert_inner_to(ty).map(Self::Literal),
288            Self::Struct(_) => None,
289            Self::Array(a) => a.convert_inner_to(ty).map(Self::Array),
290            Self::Vec(v) => v.convert_inner_to(ty).map(Self::Vec),
291            Self::Mat(m) => m.convert_inner_to(ty).map(Self::Mat),
292            Self::Ptr(_) => None,
293            Self::Ref(r) => r.read().ok().and_then(|r| r.convert_inner_to(ty)), // this is the "load rule". Also performed by `eval_value`.
294            Self::Atomic(_) => None,
295            Self::Deferred(_) => None,
296        }
297    }
298}
299
300/// Implements the [conversion rank algorithm](https://www.w3.org/TR/WGSL/#conversion-rank)
301pub fn conversion_rank(ty1: &Type, ty2: &Type) -> Option<u32> {
302    // reference: <https://www.w3.org/TR/WGSL/#conversion-rank>
303    match (ty1, ty2) {
304        (_, _) if ty1 == ty2 => Some(0),
305        (Type::Ref(_, ty1, AccessMode::Read | AccessMode::ReadWrite), ty2) if &**ty1 == ty2 => {
306            Some(0)
307        }
308        (Type::AbstractInt, Type::AbstractFloat) => Some(5),
309        (Type::AbstractInt, Type::I32) => Some(3),
310        (Type::AbstractInt, Type::U32) => Some(4),
311        (Type::AbstractInt, Type::F32) => Some(6),
312        (Type::AbstractInt, Type::F16) => Some(7),
313        (Type::AbstractFloat, Type::F32) => Some(1),
314        (Type::AbstractFloat, Type::F16) => Some(2),
315        // frexp and modf
316        (Type::Struct(s1), Type::Struct(s2)) => {
317            if s1.name.starts_with("__") && s1.name.ends_with("abstract") {
318                if s2.name.ends_with("f32") {
319                    Some(1)
320                } else if s2.name.ends_with("f16") {
321                    Some(2)
322                } else {
323                    None
324                }
325            } else {
326                None
327            }
328        }
329        (Type::Array(ty1, n1), Type::Array(ty2, n2)) if n1 == n2 => conversion_rank(ty1, ty2),
330        (Type::Vec(n1, ty1), Type::Vec(n2, ty2)) if n1 == n2 => conversion_rank(ty1, ty2),
331        (Type::Mat(c1, r1, ty1), Type::Mat(c2, r2, ty2)) if c1 == c2 && r1 == r2 => {
332            conversion_rank(ty1, ty2)
333        }
334        _ => None,
335    }
336}
337
338/// Performs overload resolution when two instances of T are involved (which is the most common).
339/// it just makes sure that the two instance types are the same. This is sufficient in most cases.
340pub fn convert<T: Convert + Ty + Clone>(i1: &T, i2: &T) -> Option<(T, T)> {
341    let (ty1, ty2) = (i1.ty(), i2.ty());
342    let ty = convert_ty(&ty1, &ty2)?;
343    let i1 = i1.convert_to(ty)?;
344    let i2 = i2.convert_to(ty)?;
345    Some((i1, i2))
346}
347
348/// See [`convert`]
349pub fn convert_inner<T1: Convert + Ty + Clone, T2: Convert + Ty + Clone>(
350    i1: &T1,
351    i2: &T2,
352) -> Option<(T1, T2)> {
353    let (ty1, ty2) = (i1.inner_ty(), i2.inner_ty());
354    let ty = convert_ty(&ty1, &ty2)?;
355    let i1 = i1.convert_inner_to(ty)?;
356    let i2 = i2.convert_inner_to(ty)?;
357    Some((i1, i2))
358}
359
360/// See [`convert`]
361pub fn convert_all<'a, T: Convert + Ty + Clone + 'a>(insts: &[T]) -> Option<Vec<T>> {
362    let tys = insts.iter().map(|i| i.ty()).collect_vec();
363    let ty = convert_all_ty(&tys)?;
364    convert_all_to(insts, ty)
365}
366
367/// See [`convert`]
368pub fn convert_all_to<'a, T: Convert + Ty + Clone + 'a>(insts: &[T], ty: &Type) -> Option<Vec<T>> {
369    insts
370        .iter()
371        .map(|inst| inst.convert_to(ty))
372        .collect::<Option<Vec<_>>>()
373}
374
375/// See [`convert`]
376pub fn convert_all_inner_to<'a, T: Convert + Ty + Clone + 'a>(
377    insts: &[T],
378    ty: &Type,
379) -> Option<Vec<T>> {
380    insts
381        .iter()
382        .map(|inst| inst.convert_inner_to(ty))
383        .collect::<Option<Vec<_>>>()
384}
385
386/// Performs overload resolution when two instances of T are involved (which is the most common).
387/// it just makes sure that the two types are the same. This is sufficient in most cases.
388pub fn convert_ty<'a>(ty1: &'a Type, ty2: &'a Type) -> Option<&'a Type> {
389    conversion_rank(ty1, ty2)
390        .map(|_rank| ty2)
391        .or_else(|| conversion_rank(ty2, ty1).map(|_rank| ty1))
392}
393
394/// Performs overload resolution (find the type that all others can be automatically converted to)
395pub fn convert_all_ty<'a>(tys: impl IntoIterator<Item = &'a Type> + 'a) -> Option<&'a Type> {
396    tys.into_iter()
397        .map(Option::Some)
398        .reduce(|ty1, ty2| convert_ty(ty1?, ty2?))
399        .flatten()
400}