wgsl_types/builtin/
ops.rs

1use std::iter::zip;
2
3use crate::{
4    Error, Instance, ShaderStage,
5    conv::{Convert, convert, convert_inner},
6    inst::{LiteralInstance, MatInstance, MemView, PtrInstance, RefInstance, VecInstance},
7    syntax::{AddressSpace, BinaryOperator, UnaryOperator},
8    ty::{Ty, Type},
9};
10
11use num_traits::{WrappingNeg, WrappingShl};
12
13type E = Error;
14
15pub trait Compwise: Clone + Sized {
16    fn compwise_unary_mut<F>(&mut self, f: F) -> Result<(), E>
17    where
18        F: Fn(&LiteralInstance) -> Result<LiteralInstance, E>;
19    fn compwise_binary_mut<F>(&mut self, rhs: &Self, f: F) -> Result<(), E>
20    where
21        F: Fn(&LiteralInstance, &LiteralInstance) -> Result<LiteralInstance, E>;
22    fn compwise_unary<F>(&self, f: F) -> Result<Self, E>
23    where
24        F: Fn(&LiteralInstance) -> Result<LiteralInstance, E>,
25    {
26        let mut res = self.clone();
27        res.compwise_unary_mut(f)?;
28        Ok(res)
29    }
30
31    fn compwise_binary<F>(&self, rhs: &Self, f: F) -> Result<Self, E>
32    where
33        F: Fn(&LiteralInstance, &LiteralInstance) -> Result<LiteralInstance, E>,
34    {
35        let mut res = self.clone();
36        res.compwise_binary_mut(rhs, f)?;
37        Ok(res)
38    }
39}
40
41impl Compwise for VecInstance {
42    fn compwise_unary_mut<F>(&mut self, f: F) -> Result<(), E>
43    where
44        F: Fn(&LiteralInstance) -> Result<LiteralInstance, E>,
45    {
46        self.iter_mut().try_for_each(|c| {
47            match c {
48                Instance::Literal(c) => *c = f(c)?,
49                _ => unreachable!("vec must contain literal instances"),
50            };
51            Ok(())
52        })
53    }
54
55    fn compwise_binary_mut<F>(&mut self, rhs: &Self, f: F) -> Result<(), E>
56    where
57        F: Fn(&LiteralInstance, &LiteralInstance) -> Result<LiteralInstance, E>,
58    {
59        if self.n() != rhs.n() {
60            return Err(E::CompwiseBinary(self.ty(), rhs.ty()));
61        }
62        zip(self.iter_mut(), rhs.iter()).try_for_each(|(a, b)| {
63            match (a, b) {
64                (Instance::Literal(a), Instance::Literal(b)) => *a = f(a, b)?,
65                _ => unreachable!("vec must contain literal instances"),
66            };
67            Ok(())
68        })
69    }
70}
71
72impl Compwise for MatInstance {
73    fn compwise_unary_mut<F>(&mut self, f: F) -> Result<(), E>
74    where
75        F: Fn(&LiteralInstance) -> Result<LiteralInstance, E>,
76    {
77        self.iter_cols_mut()
78            .flat_map(|col| col.unwrap_vec_mut().iter_mut())
79            .try_for_each(|c| {
80                match c {
81                    Instance::Literal(c) => *c = f(c)?,
82                    _ => unreachable!("mat must contain literal instances"),
83                };
84                Ok(())
85            })
86    }
87
88    fn compwise_binary_mut<F>(&mut self, rhs: &Self, f: F) -> Result<(), E>
89    where
90        F: Fn(&LiteralInstance, &LiteralInstance) -> Result<LiteralInstance, E>,
91    {
92        if self.c() != rhs.c() || self.r() != rhs.r() {
93            return Err(E::CompwiseBinary(self.ty(), rhs.ty()));
94        }
95        zip(
96            self.iter_cols_mut()
97                .flat_map(|col| col.unwrap_vec_mut().iter_mut()),
98            rhs.iter_cols().flat_map(|col| col.unwrap_vec_ref().iter()),
99        )
100        .try_for_each(|(a, b)| {
101            match (a, b) {
102                (Instance::Literal(a), Instance::Literal(b)) => *a = f(a, b)?,
103                _ => unreachable!("mat must contain literal instances"),
104            };
105            Ok(())
106        })
107    }
108}
109
/// Expands to a tuple pattern over a pair of enum values.
///
/// * `both!(Enum::Variant, a, b)` matches when both elements are
///   `Enum::Variant`, binding the payloads to `a` and `b`.
/// * `both!(E1::V1, E2::V2, a, b)` matches the two variants in either order;
///   `a` always binds the `E1::V1` payload and `b` the `E2::V2` payload.
macro_rules! both {
    ($kind:ident::$case:ident, $first:ident, $second:ident) => {
        ($kind::$case($first), $kind::$case($second))
    };
    ($kind_a:ident::$case_a:ident, $kind_b:ident::$case_b:ident, $first:ident, $second:ident) => {
        ($kind_a::$case_a($first), $kind_b::$case_b($second))
            | ($kind_b::$case_b($second), $kind_a::$case_a($first))
    };
}
118
119// -------------------
120// LOGICAL EXPRESSIONS
121// -------------------
122// reference: https://www.w3.org/TR/WGSL/#logical-expr
123
124// logical and/or are part of bitwise and/or.
125// short circuiting and/or is implemented in eval() because it needs context.
126
127impl LiteralInstance {
128    pub fn op_not(&self) -> Result<Self, E> {
129        match self {
130            Self::Bool(b) => Ok(Self::Bool(!b)),
131            _ => Err(E::Unary(UnaryOperator::LogicalNegation, self.ty())),
132        }
133    }
134}
135
136impl VecInstance {
137    pub fn op_not(&self) -> Result<Self, E> {
138        self.compwise_unary(|c| c.op_not())
139    }
140}
141
142impl Instance {
143    pub fn op_not(&self) -> Result<Self, E> {
144        match self {
145            Self::Literal(lit) => lit.op_not().map(Into::into),
146            Self::Vec(v) => v.op_not().map(Into::into),
147            _ => Err(E::Unary(UnaryOperator::LogicalNegation, self.ty())),
148        }
149    }
150}
151
152// ----------------------
153// ARITHMETIC EXPRESSIONS
154// ----------------------
155// reference: https://www.w3.org/TR/WGSL/#arithmetic-expr
156
157impl LiteralInstance {
158    pub fn op_neg(&self) -> Result<Self, E> {
159        match self {
160            Self::AbstractInt(lhs) => Ok(lhs.wrapping_neg().into()),
161            Self::AbstractFloat(lhs) => Ok((-lhs).into()),
162            Self::I32(lhs) => Ok(lhs.wrapping_neg().into()),
163            Self::F32(lhs) => Ok((-lhs).into()),
164            Self::F16(lhs) => Ok((-lhs).into()),
165            #[cfg(feature = "naga-ext")]
166            Self::I64(lhs) => Ok(LiteralInstance::I64(lhs.wrapping_neg())),
167            #[cfg(feature = "naga-ext")]
168            Self::F64(lhs) => Ok(LiteralInstance::F64(-lhs)),
169            _ => Err(E::Unary(UnaryOperator::Negation, self.ty())),
170        }
171    }
172    pub fn op_or(&self, rhs: &Self) -> Result<Self, E> {
173        let err = || E::Binary(BinaryOperator::ShortCircuitOr, self.ty(), rhs.ty());
174        match (self, rhs) {
175            (Self::Bool(b1), Self::Bool(b2)) => Ok(Self::Bool(*b1 || *b2)),
176            _ => Err(err()),
177        }
178    }
179    pub fn op_and(&self, rhs: &Self) -> Result<Self, E> {
180        let err = || E::Binary(BinaryOperator::ShortCircuitAnd, self.ty(), rhs.ty());
181        match (self, rhs) {
182            (Self::Bool(b1), Self::Bool(b2)) => Ok(Self::Bool(*b1 && *b2)),
183            _ => Err(err()),
184        }
185    }
186    pub fn op_add(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
187        let err = || E::Binary(BinaryOperator::Addition, self.ty(), rhs.ty());
188        match convert(self, rhs).ok_or_else(err)? {
189            both!(Self::AbstractInt, lhs, rhs) => {
190                lhs.checked_add(rhs).ok_or(E::AddOverflow).map(Into::into)
191            }
192            both!(Self::AbstractFloat, lhs, rhs) => {
193                let res = lhs + rhs;
194                res.is_finite()
195                    .then_some(res)
196                    .ok_or(E::AddOverflow)
197                    .map(Into::into)
198            }
199            both!(Self::I32, lhs, rhs) => Ok(lhs.wrapping_add(rhs).into()),
200            both!(Self::U32, lhs, rhs) => Ok(lhs.wrapping_add(rhs).into()),
201            both!(Self::F32, lhs, rhs) => {
202                let res = lhs + rhs;
203                if stage == ShaderStage::Const {
204                    res.is_finite()
205                        .then_some(res)
206                        .ok_or(E::AddOverflow)
207                        .map(Into::into)
208                } else {
209                    Ok(res.into())
210                }
211            }
212            both!(Self::F16, lhs, rhs) => {
213                let res = lhs + rhs;
214                if stage == ShaderStage::Const {
215                    res.is_finite()
216                        .then_some(res)
217                        .ok_or(E::AddOverflow)
218                        .map(Into::into)
219                } else {
220                    Ok(res.into())
221                }
222            }
223            #[cfg(feature = "naga-ext")]
224            both!(Self::I64, lhs, rhs) => Ok(Self::I64(lhs.wrapping_add(rhs))),
225            #[cfg(feature = "naga-ext")]
226            both!(Self::U64, lhs, rhs) => Ok(Self::U64(lhs.wrapping_add(rhs))),
227            #[cfg(feature = "naga-ext")]
228            both!(Self::F64, lhs, rhs) => {
229                let res = lhs + rhs;
230                if stage == ShaderStage::Const {
231                    res.is_finite()
232                        .then_some(res)
233                        .ok_or(E::AddOverflow)
234                        .map(LiteralInstance::F64)
235                } else {
236                    Ok(res.into())
237                }
238            }
239            _ => Err(err()),
240        }
241    }
242    pub fn op_sub(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
243        let err = || E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty());
244        match convert(self, rhs).ok_or_else(err)? {
245            both!(Self::AbstractInt, lhs, rhs) => {
246                lhs.checked_sub(rhs).ok_or(E::SubOverflow).map(Into::into)
247            }
248            both!(Self::AbstractFloat, lhs, rhs) => {
249                let res = lhs - rhs;
250                res.is_finite()
251                    .then_some(res)
252                    .ok_or(E::SubOverflow)
253                    .map(Into::into)
254            }
255            both!(Self::I32, lhs, rhs) => Ok(lhs.wrapping_sub(rhs).into()),
256            both!(Self::U32, lhs, rhs) => Ok(lhs.wrapping_sub(rhs).into()),
257            both!(Self::F32, lhs, rhs) => {
258                let res = lhs - rhs;
259                if stage == ShaderStage::Const {
260                    res.is_finite()
261                        .then_some(res)
262                        .ok_or(E::SubOverflow)
263                        .map(Into::into)
264                } else {
265                    Ok(res.into())
266                }
267            }
268            both!(Self::F16, lhs, rhs) => {
269                let res = lhs - rhs;
270                if stage == ShaderStage::Const {
271                    res.is_finite()
272                        .then_some(res)
273                        .ok_or(E::SubOverflow)
274                        .map(Into::into)
275                } else {
276                    Ok(res.into())
277                }
278            }
279            #[cfg(feature = "naga-ext")]
280            both!(Self::I64, lhs, rhs) => Ok(Self::I64(lhs.wrapping_sub(rhs))),
281            #[cfg(feature = "naga-ext")]
282            both!(Self::U64, lhs, rhs) => Ok(Self::U64(lhs.wrapping_sub(rhs))),
283            #[cfg(feature = "naga-ext")]
284            both!(Self::F64, lhs, rhs) => {
285                let res = lhs - rhs;
286                if stage == ShaderStage::Const {
287                    res.is_finite()
288                        .then_some(res)
289                        .ok_or(E::SubOverflow)
290                        .map(LiteralInstance::F64)
291                } else {
292                    Ok(res.into())
293                }
294            }
295            _ => Err(err()),
296        }
297    }
298    pub fn op_mul(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
299        let err = || E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty());
300        match convert(self, rhs).ok_or_else(err)? {
301            both!(Self::AbstractInt, lhs, rhs) => {
302                lhs.checked_mul(rhs).ok_or(E::MulOverflow).map(Into::into)
303            }
304            both!(Self::AbstractFloat, lhs, rhs) => {
305                let res = lhs * rhs;
306                res.is_finite()
307                    .then_some(res)
308                    .ok_or(E::MulOverflow)
309                    .map(Into::into)
310            }
311            both!(Self::I32, lhs, rhs) => Ok(lhs.wrapping_mul(rhs).into()),
312            both!(Self::U32, lhs, rhs) => Ok(lhs.wrapping_mul(rhs).into()),
313            both!(Self::F32, lhs, rhs) => {
314                let res = lhs * rhs;
315                if stage == ShaderStage::Const {
316                    res.is_finite()
317                        .then_some(res)
318                        .ok_or(E::MulOverflow)
319                        .map(Into::into)
320                } else {
321                    Ok(res.into())
322                }
323            }
324            both!(Self::F16, lhs, rhs) => {
325                let res = lhs * rhs;
326                if stage == ShaderStage::Const {
327                    res.is_finite()
328                        .then_some(res)
329                        .ok_or(E::MulOverflow)
330                        .map(Into::into)
331                } else {
332                    Ok(res.into())
333                }
334            }
335            #[cfg(feature = "naga-ext")]
336            both!(Self::I64, lhs, rhs) => Ok(Self::I64(lhs.wrapping_mul(rhs))),
337            #[cfg(feature = "naga-ext")]
338            both!(Self::U64, lhs, rhs) => Ok(Self::U64(lhs.wrapping_mul(rhs))),
339            #[cfg(feature = "naga-ext")]
340            both!(Self::F64, lhs, rhs) => {
341                let res = lhs * rhs;
342                if stage == ShaderStage::Const {
343                    res.is_finite()
344                        .then_some(res)
345                        .ok_or(E::MulOverflow)
346                        .map(LiteralInstance::F64)
347                } else {
348                    Ok(res.into())
349                }
350            }
351            _ => Err(err()),
352        }
353    }
354    pub fn op_div(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
355        let err = || E::Binary(BinaryOperator::Division, self.ty(), rhs.ty());
356        let res = match convert(self, rhs).ok_or_else(err)? {
357            both!(Self::AbstractInt, lhs, rhs) => {
358                lhs.checked_div(rhs).ok_or(E::DivByZero).map(Into::into)
359            }
360            both!(Self::AbstractFloat, lhs, rhs) => {
361                let res = lhs / rhs;
362                res.is_finite()
363                    .then_some(res)
364                    .ok_or(E::DivByZero)
365                    .map(Into::into)
366            }
367            both!(Self::I32, lhs, rhs) => lhs.checked_div(rhs).ok_or(E::DivByZero).map(Into::into),
368            both!(Self::U32, lhs, rhs) => lhs.checked_div(rhs).ok_or(E::DivByZero).map(Into::into),
369            both!(Self::F32, lhs, rhs) => {
370                let res = lhs / rhs;
371                res.is_finite()
372                    .then_some(res)
373                    .ok_or(E::DivByZero)
374                    .map(Into::into)
375            }
376            both!(Self::F16, lhs, rhs) => {
377                let res = lhs / rhs;
378                res.is_finite()
379                    .then_some(res)
380                    .ok_or(E::DivByZero)
381                    .map(Into::into)
382            }
383            #[cfg(feature = "naga-ext")]
384            both!(Self::I64, lhs, rhs) => Ok(Self::I64(lhs.wrapping_div(rhs))),
385            #[cfg(feature = "naga-ext")]
386            both!(Self::U64, lhs, rhs) => Ok(Self::U64(lhs.wrapping_div(rhs))),
387            #[cfg(feature = "naga-ext")]
388            both!(Self::F64, lhs, rhs) => {
389                let res = lhs / rhs;
390                res.is_finite()
391                    .then_some(res)
392                    .ok_or(E::DivByZero)
393                    .map(Self::F64)
394            }
395            _ => Err(err()),
396        };
397        if stage == ShaderStage::Exec {
398            // runtime expressions return lhs when operation fails
399            Ok(res.unwrap_or(*self))
400        } else {
401            res
402        }
403    }
404    pub fn op_rem(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
405        let err = || E::Binary(BinaryOperator::Remainder, self.ty(), rhs.ty());
406        match convert(self, rhs).ok_or_else(err)? {
407            both!(Self::AbstractInt, lhs, rhs) => {
408                if stage == ShaderStage::Const {
409                    lhs.checked_rem(rhs).ok_or(E::RemZeroDiv).map(Into::into)
410                } else {
411                    Ok(lhs.checked_rem(rhs).unwrap_or(0).into())
412                }
413            }
414            both!(Self::AbstractFloat, lhs, rhs) => {
415                let res = lhs % rhs;
416                res.is_finite()
417                    .then_some(res)
418                    .ok_or(E::RemZeroDiv)
419                    .map(Into::into)
420            }
421            both!(Self::I32, lhs, rhs) => {
422                if stage == ShaderStage::Const {
423                    lhs.checked_rem(rhs).ok_or(E::RemZeroDiv).map(Into::into)
424                } else {
425                    Ok(lhs.checked_rem(rhs).unwrap_or(0).into())
426                }
427            }
428            both!(Self::U32, lhs, rhs) => {
429                if stage == ShaderStage::Const {
430                    lhs.checked_rem(rhs).ok_or(E::RemZeroDiv).map(Into::into)
431                } else {
432                    Ok(lhs.checked_rem(rhs).unwrap_or(0).into())
433                }
434            }
435            both!(Self::F32, lhs, rhs) => {
436                let res = lhs % rhs;
437                if stage == ShaderStage::Const {
438                    res.is_finite()
439                        .then_some(res)
440                        .ok_or(E::RemZeroDiv)
441                        .map(Into::into)
442                } else {
443                    Ok(res.into())
444                }
445            }
446            both!(Self::F16, lhs, rhs) => {
447                let res = lhs % rhs;
448                if stage == ShaderStage::Const {
449                    res.is_finite()
450                        .then_some(res)
451                        .ok_or(E::RemZeroDiv)
452                        .map(Into::into)
453                } else {
454                    Ok(res.into())
455                }
456            }
457            #[cfg(feature = "naga-ext")]
458            both!(Self::I64, lhs, rhs) => {
459                if stage == ShaderStage::Const {
460                    lhs.checked_rem(rhs).ok_or(E::RemZeroDiv).map(Self::I64)
461                } else {
462                    Ok(Self::I64(lhs.checked_rem(rhs).unwrap_or(0)))
463                }
464            }
465            #[cfg(feature = "naga-ext")]
466            both!(Self::U64, lhs, rhs) => {
467                if stage == ShaderStage::Const {
468                    lhs.checked_rem(rhs).ok_or(E::RemZeroDiv).map(Self::U64)
469                } else {
470                    Ok(Self::U64(lhs.checked_rem(rhs).unwrap_or(0)))
471                }
472            }
473            #[cfg(feature = "naga-ext")]
474            both!(Self::F64, lhs, rhs) => {
475                let res = lhs % rhs;
476                if stage == ShaderStage::Const {
477                    res.is_finite()
478                        .then_some(res)
479                        .ok_or(E::RemZeroDiv)
480                        .map(Self::F64)
481                } else {
482                    Ok(Self::F64(res))
483                }
484            }
485            _ => Err(err()),
486        }
487    }
488    pub fn op_add_vec(&self, rhs: &VecInstance, stage: ShaderStage) -> Result<VecInstance, E> {
489        rhs.op_add_sca(self, stage)
490    }
491    pub fn op_sub_vec(&self, rhs: &VecInstance, stage: ShaderStage) -> Result<VecInstance, E> {
492        let (lhs, rhs) = convert_inner(self, rhs)
493            .ok_or_else(|| E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty()))?;
494        rhs.compwise_unary(|r| lhs.op_sub(r, stage))
495    }
496    pub fn op_mul_vec(&self, rhs: &VecInstance, stage: ShaderStage) -> Result<VecInstance, E> {
497        rhs.op_mul_sca(self, stage)
498    }
499    pub fn op_div_vec(&self, rhs: &VecInstance, stage: ShaderStage) -> Result<VecInstance, E> {
500        let (lhs, rhs) = convert_inner(self, rhs)
501            .ok_or_else(|| E::Binary(BinaryOperator::Division, self.ty(), rhs.ty()))?;
502        rhs.compwise_unary(|r| lhs.op_div(r, stage))
503    }
504    pub fn op_rem_vec(&self, rhs: &VecInstance, stage: ShaderStage) -> Result<VecInstance, E> {
505        let (lhs, rhs) = convert_inner(self, rhs)
506            .ok_or_else(|| E::Binary(BinaryOperator::Remainder, self.ty(), rhs.ty()))?;
507        rhs.compwise_unary(|r| lhs.op_rem(r, stage))
508    }
509    pub fn op_mul_mat(&self, rhs: &MatInstance, stage: ShaderStage) -> Result<MatInstance, E> {
510        rhs.op_mul_sca(self, stage)
511    }
512}
513
514impl VecInstance {
515    pub fn op_neg(&self) -> Result<Self, E> {
516        self.compwise_unary(|c| c.op_neg())
517    }
518    pub fn op_add(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
519        let (lhs, rhs) = convert(self, rhs)
520            .ok_or_else(|| E::Binary(BinaryOperator::Addition, self.ty(), rhs.ty()))?;
521        lhs.compwise_binary(&rhs, |l, r| l.op_add(r, stage))
522    }
523    pub fn op_sub(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
524        let (lhs, rhs) = convert(self, rhs)
525            .ok_or_else(|| E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty()))?;
526        lhs.compwise_binary(&rhs, |l, r| l.op_sub(r, stage))
527    }
528    pub fn op_mul(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
529        let (lhs, rhs) = convert(self, rhs)
530            .ok_or_else(|| E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty()))?;
531        lhs.compwise_binary(&rhs, |l, r| l.op_mul(r, stage))
532    }
533    pub fn op_div(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
534        let (lhs, rhs) = convert(self, rhs)
535            .ok_or_else(|| E::Binary(BinaryOperator::Division, self.ty(), rhs.ty()))?;
536        lhs.compwise_binary(&rhs, |l, r| l.op_div(r, stage))
537    }
538    pub fn op_rem(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
539        let (lhs, rhs) = convert(self, rhs)
540            .ok_or_else(|| E::Binary(BinaryOperator::Remainder, self.ty(), rhs.ty()))?;
541        lhs.compwise_binary(&rhs, |l, r| l.op_rem(r, stage))
542    }
543    pub fn op_add_sca(&self, rhs: &LiteralInstance, stage: ShaderStage) -> Result<Self, E> {
544        let (lhs, rhs) = convert_inner(self, rhs)
545            .ok_or_else(|| E::Binary(BinaryOperator::Addition, self.ty(), rhs.ty()))?;
546        lhs.compwise_unary(|l| l.op_add(&rhs, stage))
547    }
548    pub fn op_sub_sca(&self, rhs: &LiteralInstance, stage: ShaderStage) -> Result<Self, E> {
549        let (lhs, rhs) = convert_inner(self, rhs)
550            .ok_or_else(|| E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty()))?;
551        lhs.compwise_unary(|l| l.op_sub(&rhs, stage))
552    }
553    pub fn op_mul_sca(&self, rhs: &LiteralInstance, stage: ShaderStage) -> Result<Self, E> {
554        let (lhs, rhs) = convert_inner(self, rhs)
555            .ok_or_else(|| E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty()))?;
556        lhs.compwise_unary(|l| l.op_mul(&rhs, stage))
557    }
558    pub fn op_div_sca(&self, rhs: &LiteralInstance, stage: ShaderStage) -> Result<Self, E> {
559        let (lhs, rhs) = convert_inner(self, rhs)
560            .ok_or_else(|| E::Binary(BinaryOperator::Division, self.ty(), rhs.ty()))?;
561        lhs.compwise_unary(|l| l.op_div(&rhs, stage))
562    }
563    pub fn op_rem_sca(&self, rhs: &LiteralInstance, stage: ShaderStage) -> Result<Self, E> {
564        let (lhs, rhs) = convert_inner(self, rhs)
565            .ok_or_else(|| E::Binary(BinaryOperator::Remainder, self.ty(), rhs.ty()))?;
566        lhs.compwise_unary(|l| l.op_rem(&rhs, stage))
567    }
568    pub fn op_mul_mat(&self, rhs: &MatInstance, stage: ShaderStage) -> Result<Self, E> {
569        let (vec, mat) = convert_inner(self, rhs)
570            .ok_or_else(|| E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty()))?;
571        let mat = mat.transpose();
572
573        zip(vec.iter(), mat.iter_cols())
574            .map(|(s, v)| v.unwrap_vec_ref().op_mul_sca(s.unwrap_literal_ref(), stage))
575            .reduce(|a, b| a?.op_add(&b?, stage))
576            .unwrap()
577    }
578}
579
580impl MatInstance {
581    pub fn op_add(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
582        let (lhs, rhs) = convert(self, rhs)
583            .ok_or_else(|| E::Binary(BinaryOperator::Addition, self.ty(), rhs.ty()))?;
584        lhs.compwise_binary(&rhs, |l, r| l.op_add(r, stage))
585    }
586
587    pub fn op_sub(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
588        let (lhs, rhs) = convert(self, rhs)
589            .ok_or_else(|| E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty()))?;
590        lhs.compwise_binary(&rhs, |l, r| l.op_sub(r, stage))
591    }
592
593    pub fn op_mul_sca(&self, rhs: &LiteralInstance, stage: ShaderStage) -> Result<Self, E> {
594        let (lhs, rhs) = convert_inner(self, rhs)
595            .ok_or_else(|| E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty()))?;
596        lhs.compwise_unary(|l| l.op_mul(&rhs, stage))
597    }
598
599    pub fn op_mul_vec(&self, rhs: &VecInstance, stage: ShaderStage) -> Result<VecInstance, E> {
600        let (lhs, rhs) = convert_inner(self, rhs)
601            .ok_or_else(|| E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty()))?;
602
603        zip(lhs.iter_cols(), rhs.iter())
604            .map(|(l, r)| l.unwrap_vec_ref().op_mul_sca(r.unwrap_literal_ref(), stage))
605            .reduce(|l, r| l?.op_add(&r?, stage))
606            .unwrap()
607    }
608
609    pub fn op_mul(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
610        let (lhs, rhs) = convert_inner(self, rhs)
611            .ok_or_else(|| E::Binary(BinaryOperator::Multiplication, self.ty(), rhs.ty()))?;
612        let lhs = lhs.transpose();
613
614        Ok(MatInstance::from_cols(
615            rhs.iter_cols()
616                .map(|col| {
617                    Ok(VecInstance::new(
618                        lhs.iter_cols()
619                            .map(|r| {
620                                col.unwrap_vec_ref()
621                                    .dot(r.unwrap_vec_ref(), stage)
622                                    .map(Into::into)
623                            })
624                            .collect::<Result<_, _>>()?,
625                    )
626                    .into())
627                })
628                .collect::<Result<_, _>>()?,
629        ))
630    }
631}
632
633impl Instance {
634    pub fn op_neg(&self) -> Result<Self, E> {
635        match self {
636            Self::Literal(lhs) => lhs.op_neg().map(Into::into),
637            Self::Vec(lhs) => lhs.op_neg().map(Into::into),
638            _ => Err(E::Unary(UnaryOperator::Negation, self.ty())),
639        }
640    }
641    pub fn op_or(&self, rhs: &Self) -> Result<Self, E> {
642        let err = || E::Binary(BinaryOperator::ShortCircuitOr, self.ty(), rhs.ty());
643        match (self, rhs) {
644            both!(Self::Literal, lhs, rhs) => (lhs.op_or(rhs)).map(Into::into),
645            _ => Err(err()),
646        }
647    }
648    pub fn op_and(&self, rhs: &Self) -> Result<Self, E> {
649        let err = || E::Binary(BinaryOperator::ShortCircuitAnd, self.ty(), rhs.ty());
650        match (self, rhs) {
651            both!(Self::Literal, lhs, rhs) => (lhs.op_or(rhs)).map(Into::into),
652            _ => Err(err()),
653        }
654    }
655    pub fn op_add(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
656        match (self, rhs) {
657            both!(Self::Literal, lhs, rhs) => (lhs.op_add(rhs, stage)).map(Into::into),
658            (Self::Vec(lhs), Self::Literal(rhs)) => lhs.op_add_sca(rhs, stage).map(Into::into),
659            (Self::Literal(lhs), Self::Vec(rhs)) => lhs.op_add_vec(rhs, stage).map(Into::into),
660            both!(Self::Vec, lhs, rhs) => lhs.op_add(rhs, stage).map(Into::into),
661            both!(Self::Mat, lhs, rhs) => lhs.op_add(rhs, stage).map(Into::into),
662            _ => Err(E::Binary(BinaryOperator::Addition, self.ty(), rhs.ty())),
663        }
664    }
665    pub fn op_sub(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
666        match (self, rhs) {
667            both!(Self::Literal, lhs, rhs) => lhs.op_sub(rhs, stage).map(Into::into),
668            (Self::Vec(lhs), Self::Literal(rhs)) => lhs.op_sub_sca(rhs, stage).map(Into::into),
669            (Self::Literal(lhs), Self::Vec(rhs)) => lhs.op_sub_vec(rhs, stage).map(Into::into),
670            both!(Self::Vec, lhs, rhs) => lhs.op_sub(rhs, stage).map(Into::into),
671            both!(Self::Mat, lhs, rhs) => lhs.op_sub(rhs, stage).map(Into::into),
672            _ => Err(E::Binary(BinaryOperator::Subtraction, self.ty(), rhs.ty())),
673        }
674    }
675    pub fn op_mul(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
676        match (self, rhs) {
677            both!(Self::Literal, lhs, rhs) => lhs.op_mul(rhs, stage).map(Into::into),
678            (Self::Vec(lhs), Self::Literal(rhs)) => lhs.op_mul_sca(rhs, stage).map(Into::into),
679            (Self::Literal(lhs), Self::Vec(rhs)) => lhs.op_mul_vec(rhs, stage).map(Into::into),
680            both!(Self::Vec, lhs, rhs) => lhs.op_mul(rhs, stage).map(Into::into),
681            (Self::Mat(lhs), Self::Literal(rhs)) => lhs.op_mul_sca(rhs, stage).map(Into::into),
682            (Self::Literal(lhs), Self::Mat(rhs)) => lhs.op_mul_mat(rhs, stage).map(Into::into),
683            (Self::Mat(lhs), Self::Vec(rhs)) => lhs.op_mul_vec(rhs, stage).map(Into::into),
684            (Self::Vec(lhs), Self::Mat(rhs)) => lhs.op_mul_mat(rhs, stage).map(Into::into),
685            both!(Self::Mat, lhs, rhs) => lhs.op_mul(rhs, stage).map(Into::into),
686            _ => Err(E::Binary(
687                BinaryOperator::Multiplication,
688                self.ty(),
689                rhs.ty(),
690            )),
691        }
692    }
693    pub fn op_div(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
694        match (self, rhs) {
695            both!(Self::Literal, lhs, rhs) => lhs.op_div(rhs, stage).map(Into::into),
696            (Self::Literal(s), Self::Vec(v)) => {
697                v.compwise_unary(|k| s.op_div(k, stage)).map(Into::into)
698            }
699            (Self::Vec(v), Self::Literal(s)) => {
700                v.compwise_unary(|k| k.op_div(s, stage)).map(Into::into)
701            }
702            both!(Self::Vec, lhs, rhs) => lhs.op_div(rhs, stage).map(Into::into),
703            _ => Err(E::Binary(BinaryOperator::Division, self.ty(), rhs.ty())),
704        }
705    }
706    pub fn op_rem(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
707        match (self, rhs) {
708            both!(Self::Literal, lhs, rhs) => lhs.op_rem(rhs, stage).map(Into::into),
709            (Self::Literal(s), Self::Vec(v)) => {
710                v.compwise_unary(|k| s.op_rem(k, stage)).map(Into::into)
711            }
712            (Self::Vec(v), Self::Literal(s)) => {
713                v.compwise_unary(|k| k.op_rem(s, stage)).map(Into::into)
714            }
715            both!(Self::Vec, lhs, rhs) => lhs.op_rem(rhs, stage).map(Into::into),
716            _ => Err(E::Binary(BinaryOperator::Remainder, self.ty(), rhs.ty())),
717        }
718    }
719}
720
721// ----------------------
722// COMPARISON EXPRESSIONS
723// ----------------------
724// reference: https://www.w3.org/TR/WGSL/#comparison-expr
725
726impl LiteralInstance {
727    pub fn op_eq(&self, rhs: &Self) -> Result<bool, E> {
728        let err = || E::Binary(BinaryOperator::Equality, self.ty(), rhs.ty());
729        match convert(self, rhs).ok_or_else(err)? {
730            both!(Self::Bool, lhs, rhs) => Ok(lhs == rhs),
731            both!(Self::AbstractInt, lhs, rhs) => Ok(lhs == rhs),
732            both!(Self::AbstractFloat, lhs, rhs) => Ok(lhs == rhs),
733            both!(Self::I32, lhs, rhs) => Ok(lhs == rhs),
734            both!(Self::U32, lhs, rhs) => Ok(lhs == rhs),
735            both!(Self::F32, lhs, rhs) => Ok(lhs == rhs),
736            both!(Self::F16, lhs, rhs) => Ok(lhs == rhs),
737            #[cfg(feature = "naga-ext")]
738            both!(Self::I64, lhs, rhs) => Ok(lhs == rhs),
739            #[cfg(feature = "naga-ext")]
740            both!(Self::U64, lhs, rhs) => Ok(lhs == rhs),
741            #[cfg(feature = "naga-ext")]
742            both!(Self::F64, lhs, rhs) => Ok(lhs == rhs),
743            _ => Err(err()),
744        }
745    }
746    pub fn op_ne(&self, rhs: &Self) -> Result<bool, E> {
747        let err = || E::Binary(BinaryOperator::Inequality, self.ty(), rhs.ty());
748        match convert(self, rhs).ok_or_else(err)? {
749            both!(Self::Bool, lhs, rhs) => Ok(lhs != rhs),
750            both!(Self::AbstractInt, lhs, rhs) => Ok(lhs != rhs),
751            both!(Self::AbstractFloat, lhs, rhs) => Ok(lhs != rhs),
752            both!(Self::I32, lhs, rhs) => Ok(lhs != rhs),
753            both!(Self::U32, lhs, rhs) => Ok(lhs != rhs),
754            both!(Self::F32, lhs, rhs) => Ok(lhs != rhs),
755            both!(Self::F16, lhs, rhs) => Ok(lhs != rhs),
756            #[cfg(feature = "naga-ext")]
757            both!(Self::I64, lhs, rhs) => Ok(lhs != rhs),
758            #[cfg(feature = "naga-ext")]
759            both!(Self::U64, lhs, rhs) => Ok(lhs != rhs),
760            #[cfg(feature = "naga-ext")]
761            both!(Self::F64, lhs, rhs) => Ok(lhs != rhs),
762            _ => Err(err()),
763        }
764    }
765    pub fn op_lt(&self, rhs: &Self) -> Result<bool, E> {
766        let err = || E::Binary(BinaryOperator::LessThan, self.ty(), rhs.ty());
767        match convert(self, rhs).ok_or_else(err)? {
768            both!(Self::Bool, lhs, rhs) => Ok(!lhs & rhs),
769            both!(Self::AbstractInt, lhs, rhs) => Ok(lhs < rhs),
770            both!(Self::AbstractFloat, lhs, rhs) => Ok(lhs < rhs),
771            both!(Self::I32, lhs, rhs) => Ok(lhs < rhs),
772            both!(Self::U32, lhs, rhs) => Ok(lhs < rhs),
773            both!(Self::F32, lhs, rhs) => Ok(lhs < rhs),
774            both!(Self::F16, lhs, rhs) => Ok(lhs < rhs),
775            #[cfg(feature = "naga-ext")]
776            both!(Self::I64, lhs, rhs) => Ok(lhs < rhs),
777            #[cfg(feature = "naga-ext")]
778            both!(Self::U64, lhs, rhs) => Ok(lhs < rhs),
779            #[cfg(feature = "naga-ext")]
780            both!(Self::F64, lhs, rhs) => Ok(lhs < rhs),
781            _ => Err(err()),
782        }
783    }
784    pub fn op_le(&self, rhs: &Self) -> Result<bool, E> {
785        let err = || E::Binary(BinaryOperator::LessThanEqual, self.ty(), rhs.ty());
786        match convert(self, rhs).ok_or_else(err)? {
787            both!(Self::Bool, lhs, rhs) => Ok(lhs <= rhs),
788            both!(Self::AbstractInt, lhs, rhs) => Ok(lhs <= rhs),
789            both!(Self::AbstractFloat, lhs, rhs) => Ok(lhs <= rhs),
790            both!(Self::I32, lhs, rhs) => Ok(lhs <= rhs),
791            both!(Self::U32, lhs, rhs) => Ok(lhs <= rhs),
792            both!(Self::F32, lhs, rhs) => Ok(lhs <= rhs),
793            both!(Self::F16, lhs, rhs) => Ok(lhs <= rhs),
794            #[cfg(feature = "naga-ext")]
795            both!(Self::I64, lhs, rhs) => Ok(lhs <= rhs),
796            #[cfg(feature = "naga-ext")]
797            both!(Self::U64, lhs, rhs) => Ok(lhs <= rhs),
798            #[cfg(feature = "naga-ext")]
799            both!(Self::F64, lhs, rhs) => Ok(lhs <= rhs),
800            _ => Err(err()),
801        }
802    }
803    pub fn op_gt(&self, rhs: &Self) -> Result<bool, E> {
804        let err = || E::Binary(BinaryOperator::GreaterThan, self.ty(), rhs.ty());
805        match convert(self, rhs).ok_or_else(err)? {
806            both!(Self::Bool, lhs, rhs) => Ok(lhs & !rhs),
807            both!(Self::AbstractInt, lhs, rhs) => Ok(lhs > rhs),
808            both!(Self::AbstractFloat, lhs, rhs) => Ok(lhs > rhs),
809            both!(Self::I32, lhs, rhs) => Ok(lhs > rhs),
810            both!(Self::U32, lhs, rhs) => Ok(lhs > rhs),
811            both!(Self::F32, lhs, rhs) => Ok(lhs > rhs),
812            both!(Self::F16, lhs, rhs) => Ok(lhs > rhs),
813            #[cfg(feature = "naga-ext")]
814            both!(Self::I64, lhs, rhs) => Ok(lhs > rhs),
815            #[cfg(feature = "naga-ext")]
816            both!(Self::U64, lhs, rhs) => Ok(lhs > rhs),
817            #[cfg(feature = "naga-ext")]
818            both!(Self::F64, lhs, rhs) => Ok(lhs > rhs),
819            _ => Err(err()),
820        }
821    }
822    pub fn op_ge(&self, rhs: &Self) -> Result<bool, E> {
823        let err = || E::Binary(BinaryOperator::GreaterThanEqual, self.ty(), rhs.ty());
824        match convert(self, rhs).ok_or_else(err)? {
825            both!(Self::Bool, lhs, rhs) => Ok(lhs >= rhs),
826            both!(Self::AbstractInt, lhs, rhs) => Ok(lhs >= rhs),
827            both!(Self::AbstractFloat, lhs, rhs) => Ok(lhs >= rhs),
828            both!(Self::I32, lhs, rhs) => Ok(lhs >= rhs),
829            both!(Self::U32, lhs, rhs) => Ok(lhs >= rhs),
830            both!(Self::F32, lhs, rhs) => Ok(lhs >= rhs),
831            both!(Self::F16, lhs, rhs) => Ok(lhs >= rhs),
832            #[cfg(feature = "naga-ext")]
833            both!(Self::I64, lhs, rhs) => Ok(lhs >= rhs),
834            #[cfg(feature = "naga-ext")]
835            both!(Self::U64, lhs, rhs) => Ok(lhs >= rhs),
836            #[cfg(feature = "naga-ext")]
837            both!(Self::F64, lhs, rhs) => Ok(lhs >= rhs),
838            _ => Err(err()),
839        }
840    }
841}
842
843impl VecInstance {
844    pub fn op_eq(&self, rhs: &Self) -> Result<Self, E> {
845        let (lhs, rhs) = convert(self, rhs)
846            .ok_or_else(|| E::Binary(BinaryOperator::Equality, self.ty(), rhs.ty()))?;
847        lhs.compwise_binary(&rhs, |l, r| l.op_eq(r).map(Into::into))
848    }
849    pub fn op_ne(&self, rhs: &Self) -> Result<Self, E> {
850        let (lhs, rhs) = convert(self, rhs)
851            .ok_or_else(|| E::Binary(BinaryOperator::Inequality, self.ty(), rhs.ty()))?;
852        lhs.compwise_binary(&rhs, |l, r| l.op_ne(r).map(Into::into))
853    }
854    pub fn op_lt(&self, rhs: &Self) -> Result<Self, E> {
855        let (lhs, rhs) = convert(self, rhs)
856            .ok_or_else(|| E::Binary(BinaryOperator::LessThan, self.ty(), rhs.ty()))?;
857        lhs.compwise_binary(&rhs, |l, r| l.op_lt(r).map(Into::into))
858    }
859    pub fn op_le(&self, rhs: &Self) -> Result<Self, E> {
860        let (lhs, rhs) = convert(self, rhs)
861            .ok_or_else(|| E::Binary(BinaryOperator::LessThanEqual, self.ty(), rhs.ty()))?;
862        lhs.compwise_binary(&rhs, |l, r| l.op_le(r).map(Into::into))
863    }
864    pub fn op_gt(&self, rhs: &Self) -> Result<Self, E> {
865        let (lhs, rhs) = convert(self, rhs)
866            .ok_or_else(|| E::Binary(BinaryOperator::GreaterThan, self.ty(), rhs.ty()))?;
867        lhs.compwise_binary(&rhs, |l, r| l.op_gt(r).map(Into::into))
868    }
869    pub fn op_ge(&self, rhs: &Self) -> Result<Self, E> {
870        let (lhs, rhs) = convert(self, rhs)
871            .ok_or_else(|| E::Binary(BinaryOperator::GreaterThanEqual, self.ty(), rhs.ty()))?;
872        lhs.compwise_binary(&rhs, |l, r| l.op_ge(r).map(Into::into))
873    }
874}
875
876impl Instance {
877    pub fn op_eq(&self, rhs: &Self) -> Result<Self, E> {
878        match (self, rhs) {
879            both!(Self::Literal, lhs, rhs) => lhs
880                .op_eq(rhs)
881                .map(|b| Self::Literal(LiteralInstance::Bool(b))),
882            both!(Self::Vec, lhs, rhs) => lhs.op_eq(rhs).map(Into::into),
883            _ => Err(E::Binary(BinaryOperator::Equality, self.ty(), rhs.ty())),
884        }
885    }
886    pub fn op_ne(&self, rhs: &Self) -> Result<Self, E> {
887        match (self, rhs) {
888            both!(Self::Literal, lhs, rhs) => lhs
889                .op_ne(rhs)
890                .map(|b| Self::Literal(LiteralInstance::Bool(b))),
891            both!(Self::Vec, lhs, rhs) => lhs.op_ne(rhs).map(Into::into),
892            _ => Err(E::Binary(BinaryOperator::Inequality, self.ty(), rhs.ty())),
893        }
894    }
895    pub fn op_lt(&self, rhs: &Self) -> Result<Self, E> {
896        match (self, rhs) {
897            both!(Self::Literal, lhs, rhs) => lhs
898                .op_lt(rhs)
899                .map(|b| Self::Literal(LiteralInstance::Bool(b))),
900            both!(Self::Vec, lhs, rhs) => lhs.op_lt(rhs).map(Into::into),
901            _ => Err(E::Binary(BinaryOperator::LessThan, self.ty(), rhs.ty())),
902        }
903    }
904    pub fn op_le(&self, rhs: &Self) -> Result<Self, E> {
905        match (self, rhs) {
906            both!(Self::Literal, lhs, rhs) => lhs
907                .op_le(rhs)
908                .map(|b| Self::Literal(LiteralInstance::Bool(b))),
909            both!(Self::Vec, lhs, rhs) => lhs.op_le(rhs).map(Into::into),
910            _ => Err(E::Binary(
911                BinaryOperator::LessThanEqual,
912                self.ty(),
913                rhs.ty(),
914            )),
915        }
916    }
917    pub fn op_gt(&self, rhs: &Self) -> Result<Self, E> {
918        match (self, rhs) {
919            both!(Self::Literal, lhs, rhs) => lhs
920                .op_gt(rhs)
921                .map(|b| Self::Literal(LiteralInstance::Bool(b))),
922            both!(Self::Vec, lhs, rhs) => lhs.op_gt(rhs).map(Into::into),
923            _ => Err(E::Binary(BinaryOperator::GreaterThan, self.ty(), rhs.ty())),
924        }
925    }
926    pub fn op_ge(&self, rhs: &Self) -> Result<Self, E> {
927        match (self, rhs) {
928            both!(Self::Literal, lhs, rhs) => lhs
929                .op_ge(rhs)
930                .map(|b| Self::Literal(LiteralInstance::Bool(b))),
931            both!(Self::Vec, lhs, rhs) => lhs.op_ge(rhs).map(Into::into),
932            _ => Err(E::Binary(
933                BinaryOperator::GreaterThanEqual,
934                self.ty(),
935                rhs.ty(),
936            )),
937        }
938    }
939}
940
941// ---------------
942// BIT EXPRESSIONS
943// ---------------
944// reference: https://www.w3.org/TR/WGSL/#bit-expr
945
946impl LiteralInstance {
947    pub fn op_bitnot(&self) -> Result<Self, E> {
948        match self {
949            Self::AbstractInt(n) => Ok(Self::AbstractInt(!n)),
950            Self::I32(n) => Ok(Self::I32(!n)),
951            Self::U32(n) => Ok(Self::U32(!n)),
952            #[cfg(feature = "naga-ext")]
953            Self::I64(n) => Ok(Self::I64(!n)),
954            #[cfg(feature = "naga-ext")]
955            Self::U64(n) => Ok(Self::U64(!n)),
956            _ => Err(E::Unary(UnaryOperator::BitwiseComplement, self.ty())),
957        }
958    }
959    /// Note: this is both the "bitwise OR" and "logical OR" operator.
960    pub fn op_bitor(&self, rhs: &Self) -> Result<Self, E> {
961        let err = || E::Binary(BinaryOperator::BitwiseOr, self.ty(), rhs.ty());
962        match convert(self, rhs).ok_or_else(err)? {
963            both!(Self::Bool, rhs, lhs) => Ok(Self::Bool(lhs | rhs)),
964            both!(Self::AbstractInt, rhs, lhs) => Ok(Self::AbstractInt(lhs | rhs)),
965            both!(Self::I32, rhs, lhs) => Ok(Self::I32(lhs | rhs)),
966            both!(Self::U32, rhs, lhs) => Ok(Self::U32(lhs | rhs)),
967            #[cfg(feature = "naga-ext")]
968            both!(Self::I64, rhs, lhs) => Ok(Self::I64(lhs | rhs)),
969            #[cfg(feature = "naga-ext")]
970            both!(Self::U64, rhs, lhs) => Ok(Self::U64(lhs | rhs)),
971            _ => Err(err()),
972        }
973    }
974    /// Note: this is both the "bitwise AND" and "logical AND" operator.
975    pub fn op_bitand(&self, rhs: &Self) -> Result<Self, E> {
976        let err = || E::Binary(BinaryOperator::BitwiseAnd, self.ty(), rhs.ty());
977        match convert(self, rhs).ok_or_else(err)? {
978            both!(Self::Bool, rhs, lhs) => Ok(Self::Bool(lhs & rhs)),
979            both!(Self::AbstractInt, rhs, lhs) => Ok(Self::AbstractInt(lhs & rhs)),
980            both!(Self::I32, rhs, lhs) => Ok(Self::I32(lhs & rhs)),
981            both!(Self::U32, rhs, lhs) => Ok(Self::U32(lhs & rhs)),
982            #[cfg(feature = "naga-ext")]
983            both!(Self::I64, rhs, lhs) => Ok(Self::I64(lhs & rhs)),
984            #[cfg(feature = "naga-ext")]
985            both!(Self::U64, rhs, lhs) => Ok(Self::U64(lhs & rhs)),
986            _ => Err(err()),
987        }
988    }
989    pub fn op_bitxor(&self, rhs: &Self) -> Result<Self, E> {
990        let err = || E::Binary(BinaryOperator::BitwiseXor, self.ty(), rhs.ty());
991        match convert(self, rhs).ok_or_else(err)? {
992            both!(Self::AbstractInt, rhs, lhs) => Ok(Self::AbstractInt(lhs ^ rhs)),
993            both!(Self::I32, rhs, lhs) => Ok(Self::I32(lhs ^ rhs)),
994            both!(Self::U32, rhs, lhs) => Ok(Self::U32(lhs ^ rhs)),
995            #[cfg(feature = "naga-ext")]
996            both!(Self::I64, rhs, lhs) => Ok(Self::I64(lhs ^ rhs)),
997            #[cfg(feature = "naga-ext")]
998            both!(Self::U64, rhs, lhs) => Ok(Self::U64(lhs ^ rhs)),
999            _ => Err(err()),
1000        }
1001    }
1002    pub fn op_shl(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
1003        let err = || E::Binary(BinaryOperator::ShiftLeft, self.ty(), rhs.ty());
1004        let r = rhs.convert_to(&Type::U32).ok_or_else(err)?.unwrap_u32();
1005        let stage = stage == ShaderStage::Const || stage == ShaderStage::Override;
1006
1007        // in const and override expressions, shr operation must not overflow (all discarded bits
1008        // must be 0 in positive expressions and 1 in negative expressions).
1009        // only abstract types can be shifted by more than the bit width of the operand.
1010        match self {
1011            Self::AbstractInt(l) => {
1012                if r == 0 {
1013                    // shift by 0 is no-op
1014                    return Ok(*self);
1015                } else if r > 63 {
1016                    // shifting that much always returns 0
1017                    return Ok(0i64.into());
1018                }
1019                let msb_mask = (!0u64) << (63 - r);
1020                let msb_bits = *l as u64 & msb_mask;
1021                if stage && (*l >= 0 && msb_bits != 0 || *l < 0 && msb_bits != msb_mask) {
1022                    Err(E::ShlOverflow(r, *self))
1023                } else {
1024                    Ok(l.wrapping_shl(r).into())
1025                }
1026            }
1027            Self::I32(l) => {
1028                let r = r % 32; // "the number of bits to shift is the value of e2, modulo the bit width of e1"
1029                if r == 0 {
1030                    // shift by 0 is no-op
1031                    return Ok(*self);
1032                }
1033                let msb_mask = (!0u32) << (31 - r);
1034                let msb_bits = *l as u32 & msb_mask;
1035                if stage && (*l >= 0 && msb_bits != 0 || *l < 0 && msb_bits != msb_mask) {
1036                    Err(E::ShlOverflow(r, *self))
1037                } else if stage {
1038                    Ok(l.checked_shl(r).ok_or(E::ShlOverflow(r, *self))?.into())
1039                } else {
1040                    Ok(l.wrapping_shl(r).into())
1041                }
1042            }
1043            Self::U32(l) => {
1044                let r = r % 32; // "the number of bits to shift is the value of e2, modulo the bit width of e1"
1045                if r == 0 {
1046                    // shift by 0 is no-op
1047                    return Ok(*self);
1048                }
1049                let msb_mask = (!0u32) << (32 - r);
1050                let msb_bits = *l & msb_mask;
1051                if stage && msb_bits != 0 {
1052                    Err(E::ShlOverflow(r, *self))
1053                } else if stage {
1054                    Ok(l.checked_shl(r).ok_or(E::ShlOverflow(r, *self))?.into())
1055                } else {
1056                    Ok(l.wrapping_shl(r).into())
1057                }
1058            }
1059            #[cfg(feature = "naga-ext")]
1060            Self::I64(l) => {
1061                let r = r % 64; // "the number of bits to shift is the value of e2, modulo the bit width of e1"
1062                if r == 0 {
1063                    // shift by 0 is no-op
1064                    return Ok(*self);
1065                }
1066                let msb_mask = (!0u64) << (31 - r);
1067                let msb_bits = *l as u64 & msb_mask;
1068                if stage && (*l >= 0 && msb_bits != 0 || *l < 0 && msb_bits != msb_mask) {
1069                    Err(E::ShlOverflow(r, *self))
1070                } else if stage {
1071                    Ok(Self::I64(l.checked_shl(r).ok_or(E::ShlOverflow(r, *self))?))
1072                } else {
1073                    Ok(Self::I64(l.wrapping_shl(r)))
1074                }
1075            }
1076            #[cfg(feature = "naga-ext")]
1077            Self::U64(l) => {
1078                let r = r % 64; // "the number of bits to shift is the value of e2, modulo the bit width of e1"
1079                if r == 0 {
1080                    // shift by 0 is no-op
1081                    return Ok(*self);
1082                }
1083                let msb_mask = (!0u64) << (64 - r);
1084                let msb_bits = *l & msb_mask;
1085                if stage && msb_bits != 0 {
1086                    Err(E::ShlOverflow(r, *self))
1087                } else if stage {
1088                    Ok(Self::U64(l.checked_shl(r).ok_or(E::ShlOverflow(r, *self))?))
1089                } else {
1090                    Ok(Self::U64(l.wrapping_shl(r)))
1091                }
1092            }
1093            _ => Err(err()),
1094        }
1095    }
1096    pub fn op_shr(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
1097        let err = || E::Binary(BinaryOperator::ShiftRight, self.ty(), rhs.ty());
1098        let r = rhs.convert_to(&Type::U32).ok_or_else(err)?.unwrap_u32();
1099        let stage = stage == ShaderStage::Const || stage == ShaderStage::Override;
1100
1101        // shift by 0 is no-op
1102        if r == 0 {
1103            return Ok(*self);
1104        }
1105
1106        // contrary to shl, it is not an error to overflow (discard non-zero bits). But it is an
1107        // error to shift more than the bit width.
1108        match self {
1109            Self::I32(l) => Ok(if stage {
1110                l.checked_shr(r).ok_or(E::ShrOverflow(r, *self))?.into()
1111            } else {
1112                l.wrapping_shr(r).into()
1113            }),
1114            Self::U32(l) => Ok(if stage {
1115                l.checked_shr(r).ok_or(E::ShrOverflow(r, *self))?.into()
1116            } else {
1117                l.wrapping_shr(r).into()
1118            }),
1119            Self::AbstractInt(l) => {
1120                // we shr twice because x >> 64 is panic(overflow) and wrapping_shr only allow x >> 63.
1121                Ok((l >> 1).wrapping_shr(r - 1).into())
1122            }
1123            #[cfg(feature = "naga-ext")]
1124            Self::I64(l) => Ok(if stage {
1125                Self::I64(l.checked_shr(r).ok_or(E::ShrOverflow(r, *self))?)
1126            } else {
1127                Self::I64(l.wrapping_shr(r))
1128            }),
1129            #[cfg(feature = "naga-ext")]
1130            Self::U64(l) => Ok(if stage {
1131                Self::U64(l.checked_shr(r).ok_or(E::ShrOverflow(r, *self))?)
1132            } else {
1133                Self::U64(l.wrapping_shr(r))
1134            }),
1135            _ => Err(E::Binary(BinaryOperator::ShiftRight, self.ty(), rhs.ty())),
1136        }
1137    }
1138}
1139
1140impl VecInstance {
1141    pub fn op_bitnot(&self) -> Result<Self, E> {
1142        self.compwise_unary(LiteralInstance::op_bitnot)
1143    }
1144    /// Note: this is both the "bitwise OR" and "logical OR" operator.
1145    pub fn op_bitor(&self, rhs: &Self) -> Result<Self, E> {
1146        self.compwise_binary(rhs, |l, r| l.op_bitor(r))
1147    }
1148    /// Note: this is both the "bitwise AND" and "logical AND" operator.
1149    pub fn op_bitand(&self, rhs: &Self) -> Result<Self, E> {
1150        self.compwise_binary(rhs, |l, r| l.op_bitand(r))
1151    }
1152    pub fn op_bitxor(&self, rhs: &Self) -> Result<Self, E> {
1153        self.compwise_binary(rhs, |l, r| l.op_bitxor(r))
1154    }
1155    pub fn op_shl(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
1156        self.compwise_binary(rhs, |l, r| l.op_shl(r, stage))
1157    }
1158    pub fn op_shr(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
1159        self.compwise_binary(rhs, |l, r| l.op_shr(r, stage))
1160    }
1161}
1162
1163impl Instance {
1164    pub fn op_bitnot(&self) -> Result<Self, E> {
1165        match self {
1166            Instance::Literal(l) => l.op_bitnot().map(Into::into),
1167            Instance::Vec(v) => v.op_bitnot().map(Into::into),
1168            _ => Err(E::Unary(UnaryOperator::BitwiseComplement, self.ty())),
1169        }
1170    }
1171    /// Note: this is both the "bitwise OR" and "logical OR" operator.
1172    pub fn op_bitor(&self, rhs: &Self) -> Result<Self, E> {
1173        match (self, rhs) {
1174            both!(Self::Literal, lhs, rhs) => lhs.op_bitor(rhs).map(Into::into),
1175            both!(Self::Vec, lhs, rhs) => lhs.op_bitor(rhs).map(Into::into),
1176            _ => Err(E::Binary(BinaryOperator::BitwiseOr, self.ty(), rhs.ty())),
1177        }
1178    }
1179    /// Note: this is both the "bitwise AND" and "logical AND" operator.
1180    pub fn op_bitand(&self, rhs: &Self) -> Result<Self, E> {
1181        match (self, rhs) {
1182            both!(Self::Literal, lhs, rhs) => lhs.op_bitand(rhs).map(Into::into),
1183            both!(Self::Vec, lhs, rhs) => lhs.op_bitand(rhs).map(Into::into),
1184            _ => Err(E::Binary(BinaryOperator::BitwiseAnd, self.ty(), rhs.ty())),
1185        }
1186    }
1187    pub fn op_bitxor(&self, rhs: &Self) -> Result<Self, E> {
1188        match (self, rhs) {
1189            both!(Self::Literal, lhs, rhs) => lhs.op_bitxor(rhs).map(Into::into),
1190            both!(Self::Vec, lhs, rhs) => lhs.op_bitxor(rhs).map(Into::into),
1191            _ => Err(E::Binary(BinaryOperator::BitwiseXor, self.ty(), rhs.ty())),
1192        }
1193    }
1194    pub fn op_shl(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
1195        match (self, rhs) {
1196            both!(Self::Literal, lhs, rhs) => lhs.op_shl(rhs, stage).map(Into::into),
1197            both!(Self::Vec, lhs, rhs) => lhs.op_shl(rhs, stage).map(Into::into),
1198            _ => Err(E::Binary(BinaryOperator::ShiftLeft, self.ty(), rhs.ty())),
1199        }
1200    }
1201    pub fn op_shr(&self, rhs: &Self, stage: ShaderStage) -> Result<Self, E> {
1202        match (self, rhs) {
1203            both!(Self::Literal, lhs, rhs) => lhs.op_shr(rhs, stage).map(Into::into),
1204            both!(Self::Vec, lhs, rhs) => lhs.op_shr(rhs, stage).map(Into::into),
1205            _ => Err(E::Binary(BinaryOperator::ShiftRight, self.ty(), rhs.ty())),
1206        }
1207    }
1208}
1209
1210// -------------------
1211// POINTER EXPRESSIONS
1212// -------------------
1213// reference: https://www.w3.org/TR/WGSL/#address-of-expr
1214// reference: https://www.w3.org/TR/WGSL/#indirection-expr
1215
1216impl Instance {
1217    pub fn op_ref(&self) -> Result<Instance, E> {
1218        match self {
1219            Instance::Ref(r) => {
1220                if r.space == AddressSpace::Handle {
1221                    // "It is a shader-creation error if AS is the handle address space."
1222                    Err(E::PtrHandle)
1223                } else if r.ptr.borrow().ty().is_vec() && r.view != MemView::Whole {
1224                    // "It is a shader-creation error if r is a reference to a vector component."
1225                    Err(E::PtrVecComp)
1226                } else {
1227                    Ok(PtrInstance::from(r.clone()).into())
1228                }
1229            }
1230            _ => Err(E::Unary(UnaryOperator::AddressOf, self.ty())),
1231        }
1232    }
1233
1234    pub fn op_deref(&self) -> Result<Instance, E> {
1235        match self {
1236            Instance::Ptr(p) => Ok(RefInstance::from(p.clone()).into()),
1237            _ => Err(E::Unary(UnaryOperator::Indirection, self.ty())),
1238        }
1239    }
1240}