wgsl_types/builtin/
ops_ty.rs

1//! Type-checking of operators.
2
3use crate::{
4    Error,
5    conv::{Convert, convert_ty},
6    syntax::{AddressSpace, BinaryOperator, UnaryOperator},
7    ty::{Ty, Type},
8};
9
10type E = Error;
11
12/// Compute the return type of a unary operator expression.
13pub fn type_unary_op(operator: UnaryOperator, operand: &Type) -> Result<Type, E> {
14    match operator {
15        UnaryOperator::LogicalNegation => operand.op_not(),
16        UnaryOperator::Negation => operand.op_neg(),
17        UnaryOperator::BitwiseComplement => operand.op_bitnot(),
18        UnaryOperator::AddressOf => operand.op_ref(),
19        UnaryOperator::Indirection => operand.op_deref(),
20    }
21}
22
23/// Compute the return type of a binary operator expression.
24pub fn type_binary_op(op: BinaryOperator, lhs: &Type, rhs: &Type) -> Result<Type, E> {
25    match op {
26        BinaryOperator::ShortCircuitOr => lhs.op_or(rhs),
27        BinaryOperator::ShortCircuitAnd => lhs.op_and(rhs),
28        BinaryOperator::Addition => lhs.op_add(rhs),
29        BinaryOperator::Subtraction => lhs.op_sub(rhs),
30        BinaryOperator::Multiplication => lhs.op_mul(rhs),
31        BinaryOperator::Division => lhs.op_div(rhs),
32        BinaryOperator::Remainder => lhs.op_rem(rhs),
33        BinaryOperator::Equality => lhs.op_eq(rhs),
34        BinaryOperator::Inequality => lhs.op_ne(rhs),
35        BinaryOperator::LessThan => lhs.op_lt(rhs),
36        BinaryOperator::LessThanEqual => lhs.op_le(rhs),
37        BinaryOperator::GreaterThan => lhs.op_gt(rhs),
38        BinaryOperator::GreaterThanEqual => lhs.op_ge(rhs),
39        BinaryOperator::BitwiseOr => lhs.op_bitor(rhs),
40        BinaryOperator::BitwiseAnd => lhs.op_bitand(rhs),
41        BinaryOperator::BitwiseXor => lhs.op_bitxor(rhs),
42        BinaryOperator::ShiftLeft => lhs.op_shl(rhs),
43        BinaryOperator::ShiftRight => lhs.op_shr(rhs),
44    }
45}
46
47// -------------------
48// LOGICAL EXPRESSIONS
49// -------------------
50// reference: https://www.w3.org/TR/WGSL/#logical-expr
51
52impl Type {
53    pub fn op_not(&self) -> Result<Self, E> {
54        match self {
55            Self::Bool | Self::Vec(_, _) => Ok(self.clone()),
56            _ => Err(E::Unary(UnaryOperator::LogicalNegation, self.clone())),
57        }
58    }
59    pub fn op_or(&self, rhs: &Type) -> Result<Self, E> {
60        match (self, rhs) {
61            (Type::Bool, Type::Bool) => Ok(Type::Bool),
62            _ => Err(E::Binary(
63                BinaryOperator::ShortCircuitOr,
64                self.ty(),
65                rhs.ty(),
66            )),
67        }
68    }
69    pub fn op_and(&self, rhs: &Type) -> Result<Self, E> {
70        match (self, rhs) {
71            (Type::Bool, Type::Bool) => Ok(Type::Bool),
72            _ => Err(E::Binary(
73                BinaryOperator::ShortCircuitAnd,
74                self.ty(),
75                rhs.ty(),
76            )),
77        }
78    }
79}
80
81// ----------------------
82// ARITHMETIC EXPRESSIONS
83// ----------------------
84// reference: https://www.w3.org/TR/WGSL/#arithmetic-expr
85
86impl Type {
87    /// Valid operands:
88    /// * `-S`, S: scalar
89    pub fn op_neg(&self) -> Result<Self, E> {
90        if self.is_scalar() {
91            Ok(self.clone())
92        } else {
93            Err(E::Unary(UnaryOperator::Negation, self.ty()))
94        }
95    }
96
97    /// Valid operands:
98    /// * `T + T`, T: scalar or vec
99    /// * `S + V` or `V + S`, S: scalar, V: `vec<S>`
100    /// * `M + M`, M: mat
101    pub fn op_add(&self, rhs: &Type) -> Result<Self, E> {
102        let err = || E::Binary(BinaryOperator::Addition, self.ty(), rhs.ty());
103        match (self, rhs) {
104            (lhs, rhs) if lhs.is_scalar() && rhs.is_scalar() || lhs.is_vec() && rhs.is_vec() => {
105                let ty = convert_ty(self, rhs).ok_or_else(err)?;
106                Ok(ty.clone())
107            }
108            (scalar_ty, Type::Vec(n, vec_ty)) | (Type::Vec(n, vec_ty), scalar_ty)
109                if scalar_ty.is_scalar() =>
110            {
111                let inner_ty = convert_ty(scalar_ty, vec_ty).ok_or_else(err)?;
112                Ok(Type::Vec(*n, Box::new(inner_ty.clone())))
113            }
114            (Type::Mat(c1, r1, lhs), Type::Mat(c2, r2, rhs)) if c1 == c2 && r1 == r2 => {
115                let inner_ty = convert_ty(lhs, rhs).ok_or_else(err)?;
116                Ok(Type::Mat(*c1, *c2, Box::new(inner_ty.clone())))
117            }
118            _ => Err(err()),
119        }
120    }
121
122    /// Valid operands:
123    /// * `T - T`, T: scalar or vec
124    /// * `S - V` or `V - S`, S: scalar, V: `vec<S>`
125    /// * `M - M`, M: mat
126    pub fn op_sub(&self, rhs: &Type) -> Result<Self, E> {
127        let err = || E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty());
128        match (self, rhs) {
129            (lhs, rhs) if lhs.is_scalar() && rhs.is_scalar() || lhs.is_vec() && rhs.is_vec() => {
130                let ty = convert_ty(self, rhs).ok_or_else(err)?;
131                Ok(ty.clone())
132            }
133            (scalar_ty, Type::Vec(n, vec_ty)) | (Type::Vec(n, vec_ty), scalar_ty)
134                if scalar_ty.is_scalar() =>
135            {
136                let inner_ty = convert_ty(scalar_ty, vec_ty).ok_or_else(err)?;
137                Ok(Type::Vec(*n, Box::new(inner_ty.clone())))
138            }
139            (Type::Mat(c1, r1, lhs), Type::Mat(c2, r2, rhs)) if c1 == c2 && r1 == r2 => {
140                let inner_ty = convert_ty(lhs, rhs).ok_or_else(err)?;
141                Ok(Type::Mat(*c1, *c2, Box::new(inner_ty.clone())))
142            }
143            _ => Err(err()),
144        }
145    }
146
147    /// Valid operands:
148    /// * `T * T`, T: scalar or vec
149    /// * `S * V` or `V * S`, S: scalar, V: `vec<S>`
150    /// * `S * M` or `M * S`, S: float, M: `mat<S>`
151    /// * `V * M` or `M * V`, S: float, V: `vec<S>`, M: `mat<S>`
152    /// * `M1 * M1`, M1: `matKxR`, M2: `matCxK`
153    pub fn op_mul(&self, rhs: &Type) -> Result<Self, E> {
154        let err = || E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty());
155        match (self, rhs) {
156            (lhs, rhs) if lhs.is_scalar() && rhs.is_scalar() || lhs.is_vec() && rhs.is_vec() => {
157                let ty = convert_ty(self, rhs).ok_or_else(err)?;
158                Ok(ty.clone())
159            }
160            (scalar_ty, Type::Vec(n, vec_ty)) | (Type::Vec(n, vec_ty), scalar_ty)
161                if scalar_ty.is_scalar() =>
162            {
163                let inner_ty = convert_ty(scalar_ty, vec_ty).ok_or_else(err)?;
164                Ok(Type::Vec(*n, Box::new(inner_ty.clone())))
165            }
166            (scalar_ty, Type::Mat(c, r, mat_ty)) | (Type::Mat(c, r, mat_ty), scalar_ty)
167                if scalar_ty.is_scalar() =>
168            {
169                let inner_ty = convert_ty(scalar_ty, mat_ty).ok_or_else(err)?;
170                Ok(Type::Mat(*c, *r, Box::new(inner_ty.clone())))
171            }
172            (Type::Vec(n1, vec_ty), Type::Mat(n2, n, mat_ty))
173            | (Type::Mat(n, n1, mat_ty), Type::Vec(n2, vec_ty))
174                if n1 == n2 =>
175            {
176                let inner_ty = convert_ty(vec_ty, mat_ty).ok_or_else(err)?;
177                Ok(Type::Vec(*n, Box::new(inner_ty.clone())))
178            }
179            (Type::Mat(k1, r, lhs), Type::Mat(c, k2, rhs)) if k1 == k2 => {
180                let inner_ty = convert_ty(lhs, rhs).ok_or_else(err)?;
181                Ok(Type::Mat(*c, *r, Box::new(inner_ty.clone())))
182            }
183            _ => Err(err()),
184        }
185    }
186
187    /// Valid operands:
188    /// * `T / T`, T: scalar or vec
189    /// * `S / V` or `V / S`, S: scalar, V: `vec<S>`
190    pub fn op_div(&self, rhs: &Type) -> Result<Self, E> {
191        let err = || E::Binary(BinaryOperator::Division, self.ty(), rhs.ty());
192        match (self, rhs) {
193            (lhs, rhs) if lhs.is_scalar() && rhs.is_scalar() || lhs.is_vec() && rhs.is_vec() => {
194                let ty = convert_ty(self, rhs).ok_or_else(err)?;
195                Ok(ty.clone())
196            }
197            (scalar_ty, Type::Vec(n, vec_ty)) | (Type::Vec(n, vec_ty), scalar_ty)
198                if scalar_ty.is_scalar() =>
199            {
200                let inner_ty = convert_ty(scalar_ty, vec_ty).ok_or_else(err)?;
201                Ok(Type::Vec(*n, Box::new(inner_ty.clone())))
202            }
203            _ => Err(err()),
204        }
205    }
206
207    /// Valid operands:
208    /// * `T % T`, T: scalar or vec
209    /// * `S % V` or `V % S`, S: scalar, V: `vec<S>`
210    pub fn op_rem(&self, rhs: &Type) -> Result<Self, E> {
211        let err = || E::Binary(BinaryOperator::Remainder, self.ty(), rhs.ty());
212        match (self, rhs) {
213            (lhs, rhs) if lhs.is_scalar() && rhs.is_scalar() || lhs.is_vec() && rhs.is_vec() => {
214                let ty = convert_ty(self, rhs).ok_or_else(err)?;
215                Ok(ty.clone())
216            }
217            (scalar_ty, Type::Vec(n, vec_ty)) | (Type::Vec(n, vec_ty), scalar_ty)
218                if scalar_ty.is_scalar() =>
219            {
220                let inner_ty = convert_ty(scalar_ty, vec_ty).ok_or_else(err)?;
221                Ok(Type::Vec(*n, Box::new(inner_ty.clone())))
222            }
223            _ => Err(err()),
224        }
225    }
226}
227
228// ----------------------
229// COMPARISON EXPRESSIONS
230// ----------------------
231// reference: https://www.w3.org/TR/WGSL/#comparison-expr
232
233impl Type {
234    /// Valid operands:
235    /// * `T == T`, T: scalar or vec
236    pub fn op_eq(&self, rhs: &Type) -> Result<Type, E> {
237        let err = || E::Binary(BinaryOperator::Equality, self.ty(), rhs.ty());
238        match convert_ty(self, rhs).ok_or_else(err)? {
239            ty if ty.is_scalar() => Ok(Type::Bool),
240            Type::Vec(n, _) => Ok(Type::Vec(*n, Box::new(Type::Bool))),
241            _ => Err(err()),
242        }
243    }
244    /// Valid operands:
245    /// * `T != T`, T: scalar or vec
246    pub fn op_ne(&self, rhs: &Type) -> Result<Type, E> {
247        let err = || E::Binary(BinaryOperator::Inequality, self.ty(), rhs.ty());
248        match convert_ty(self, rhs).ok_or_else(err)? {
249            ty if ty.is_scalar() => Ok(Type::Bool),
250            Type::Vec(n, _) => Ok(Type::Vec(*n, Box::new(Type::Bool))),
251            _ => Err(err()),
252        }
253    }
254    /// Valid operands:
255    /// * `T < T`, T: scalar or vec
256    pub fn op_lt(&self, rhs: &Type) -> Result<Type, E> {
257        let err = || E::Binary(BinaryOperator::LessThan, self.ty(), rhs.ty());
258        match convert_ty(self, rhs).ok_or_else(err)? {
259            ty if ty.is_scalar() => Ok(Type::Bool),
260            Type::Vec(n, _) => Ok(Type::Vec(*n, Box::new(Type::Bool))),
261            _ => Err(err()),
262        }
263    }
264    /// Valid operands:
265    /// * `T <= T`, T: scalar or vec
266    pub fn op_le(&self, rhs: &Type) -> Result<Type, E> {
267        let err = || E::Binary(BinaryOperator::LessThanEqual, self.ty(), rhs.ty());
268        match convert_ty(self, rhs).ok_or_else(err)? {
269            ty if ty.is_scalar() => Ok(Type::Bool),
270            Type::Vec(n, _) => Ok(Type::Vec(*n, Box::new(Type::Bool))),
271            _ => Err(err()),
272        }
273    }
274    /// Valid operands:
275    /// * `T > T`, T: scalar or vec
276    pub fn op_gt(&self, rhs: &Type) -> Result<Type, E> {
277        let err = || E::Binary(BinaryOperator::GreaterThan, self.ty(), rhs.ty());
278        match convert_ty(self, rhs).ok_or_else(err)? {
279            ty if ty.is_scalar() => Ok(Type::Bool),
280            Type::Vec(n, _) => Ok(Type::Vec(*n, Box::new(Type::Bool))),
281            _ => Err(err()),
282        }
283    }
284    /// Valid operands:
285    /// * `T >= T`, T: scalar or vec
286    pub fn op_ge(&self, rhs: &Type) -> Result<Type, E> {
287        let err = || E::Binary(BinaryOperator::GreaterThanEqual, self.ty(), rhs.ty());
288        match convert_ty(self, rhs).ok_or_else(err)? {
289            ty if ty.is_scalar() => Ok(Type::Bool),
290            Type::Vec(n, _) => Ok(Type::Vec(*n, Box::new(Type::Bool))),
291            _ => Err(err()),
292        }
293    }
294}
295
296// ---------------
297// BIT EXPRESSIONS
298// ---------------
299// reference: https://www.w3.org/TR/WGSL/#bit-expr
300
301impl Type {
302    /// Valid operands:
303    /// * `~T`, I: integer, T: I or `vec<I>`
304    pub fn op_bitnot(&self) -> Result<Type, E> {
305        match self {
306            ty if ty.is_integer() => Ok(self.clone()),
307            Type::Vec(_, ty) if ty.is_integer() => Ok(self.clone()),
308            _ => Err(E::Unary(UnaryOperator::BitwiseComplement, self.ty())),
309        }
310    }
311
312    /// Note: this is both the "bitwise OR" and "logical OR" operator.
313    ///
314    /// Valid operands:
315    /// * `T | T`, T: integer, or `vec<integer>` (bitwise OR)
316    /// * `V | V`, V: `vec<bool>` (logical OR)
317    pub fn op_bitor(&self, rhs: &Type) -> Result<Type, E> {
318        let err = || E::Binary(BinaryOperator::BitwiseOr, self.ty(), rhs.ty());
319        match convert_ty(self, rhs).ok_or_else(err)? {
320            ty if ty.is_integer() => Ok(self.clone()),
321            Type::Vec(_, ty) if ty.is_integer() => Ok(self.clone()),
322            Type::Vec(_, ty) if ty.is_bool() => Ok(self.clone()),
323            _ => Err(err()),
324        }
325    }
326
327    /// Note: this is both the "bitwise AND" and "logical AND" operator.
328    ///
329    /// Valid operands:
330    /// * `T & T`, T: integer or `vec<integer>` (bitwise AND)
331    /// * `V & V`, V: `vec<bool>` (logical AND)
332    pub fn op_bitand(&self, rhs: &Type) -> Result<Type, E> {
333        let err = || E::Binary(BinaryOperator::BitwiseAnd, self.ty(), rhs.ty());
334        match convert_ty(self, rhs).ok_or_else(err)? {
335            ty if ty.is_integer() => Ok(self.clone()),
336            Type::Vec(_, ty) if ty.is_integer() => Ok(self.clone()),
337            Type::Vec(_, ty) if ty.is_bool() => Ok(self.clone()),
338            _ => Err(err()),
339        }
340    }
341
342    /// Valid operands:
343    /// * `T ^ T`, T: integer or `vec<integer>`
344    pub fn op_bitxor(&self, rhs: &Type) -> Result<Type, E> {
345        let err = || E::Binary(BinaryOperator::BitwiseXor, self.ty(), rhs.ty());
346        match convert_ty(self, rhs).ok_or_else(err)? {
347            ty if ty.is_integer() => Ok(self.clone()),
348            Type::Vec(_, ty) if ty.is_integer() => Ok(self.clone()),
349            _ => Err(err()),
350        }
351    }
352
353    /// Valid operands:
354    /// * `integer << u32`
355    /// * `vec<integer> << vec<u32>`
356    pub fn op_shl(&self, rhs: &Type) -> Result<Type, E> {
357        let err = || E::Binary(BinaryOperator::ShiftLeft, self.ty(), rhs.ty());
358        let rhs = rhs.convert_inner_to(&Type::U32).ok_or_else(err)?;
359        match (self, rhs) {
360            (lhs, Type::U32) if lhs.is_integer() => Ok(lhs.clone()),
361            (lhs, Type::Vec(_, _)) => Ok(lhs.clone()),
362            _ => Err(err()),
363        }
364    }
365
366    /// Valid operands:
367    /// * `integer >> u32`
368    /// * `vec<integer> >> vec<u32>`
369    pub fn op_shr(&self, rhs: &Type) -> Result<Type, E> {
370        let err = || E::Binary(BinaryOperator::ShiftRight, self.ty(), rhs.ty());
371        let rhs = rhs.convert_inner_to(&Type::U32).ok_or_else(err)?;
372        match (self, rhs) {
373            (lhs, Type::U32) if lhs.is_integer() => Ok(lhs.clone()),
374            (lhs, Type::Vec(_, _)) => Ok(lhs.clone()),
375            _ => Err(err()),
376        }
377    }
378}
379
380// -------------------
381// POINTER EXPRESSIONS
382// -------------------
383// reference: https://www.w3.org/TR/WGSL/#address-of-expr
384// reference: https://www.w3.org/TR/WGSL/#indirection-expr
385
386impl Type {
387    pub fn op_ref(&self) -> Result<Type, E> {
388        match self {
389            Type::Ref(a_s, ty, a_m) => {
390                if *a_s == AddressSpace::Handle {
391                    // "It is a shader-creation error if AS is the handle address space."
392                    Err(E::PtrHandle)
393                } else if false {
394                    // TODO: We do not yet have enough information to check this:
395                    // "It is a shader-creation error if r is a reference to a vector component."
396                    Err(E::PtrVecComp)
397                } else {
398                    Ok(Type::Ptr(*a_s, ty.clone(), *a_m))
399                }
400            }
401            _ => Err(E::Unary(UnaryOperator::AddressOf, self.ty())),
402        }
403    }
404
405    pub fn op_deref(&self) -> Result<Type, E> {
406        match self {
407            Type::Ptr(a_s, ty, a_m) => Ok(Type::Ref(*a_s, ty.clone(), *a_m)),
408            _ => Err(E::Unary(UnaryOperator::Indirection, self.ty())),
409        }
410    }
411}