wgsl_types/builtin/
ctor.rs

1//! Constructor implementations, including zero-value constructors.
2//!
3//! Functions bear the same name as the WGSL counterpart.
4//! Functions that take template parameters are suffixed with `_t` and take `tplt_*` arguments.
5//!
6//! ### Usage quirks
7//!
8//! * The arguments must be [loaded][Type::loaded].
9//! * The struct constructor is a bit special since it is the only user-defined type.
10//!   Use [`struct_ctor`] and [`typecheck_struct_ctor`] for structs.
11//! * User-defined functions can shadow WGSL built-in functions.
12//! * Type aliases must be resolved: WGSL allows calling functions with the name of the alias.
13
14use half::prelude::*;
15use itertools::Itertools;
16use num_traits::{FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19    CallSignature, Error, ShaderStage,
20    conv::{Convert, convert_all, convert_all_inner_to, convert_all_to, convert_all_ty},
21    inst::{ArrayInstance, Instance, LiteralInstance, MatInstance, StructInstance, VecInstance},
22    tplt::{ArrayTemplate, MatTemplate, TpltParam, VecTemplate},
23    ty::{StructType, Ty, Type},
24};
25
26type E = Error;
27
/// Returns `true` when `name` is the identifier of a WGSL built-in constructor:
/// `array`, the scalar types, `matCxR`, `vecN`, and (with the `naga-ext` feature)
/// `i64`, `u64` and `f64`.
///
/// Warning: WGSL allows shadowing built-in functions. A `true` result does not prove
/// that a call targets the built-in; check that no user-defined function of the same
/// name exists.
pub fn is_ctor(name: &str) -> bool {
    const CORE_CTORS: [&str; 18] = [
        "array", "bool", "i32", "u32", "f32", "f16", "mat2x2", "mat2x3", "mat2x4", "mat3x2",
        "mat3x3", "mat3x4", "mat4x2", "mat4x3", "mat4x4", "vec2", "vec3", "vec4",
    ];
    #[cfg(feature = "naga-ext")]
    const EXT_CTORS: &[&str] = &["i64", "u64", "f64"];
    #[cfg(not(feature = "naga-ext"))]
    const EXT_CTORS: &[&str] = &[];

    CORE_CTORS.contains(&name) || EXT_CTORS.contains(&name)
}
42
43// ------------
44// CONSTRUCTORS
45// ------------
46// reference: <https://www.w3.org/TR/WGSL/#constructor-builtin-function>
47
48/// `array<T,N>()` constructor.
49///
50/// Reference: <https://www.w3.org/TR/WGSL/#array-builtin>
51pub fn array_t(tplt_ty: &Type, tplt_n: usize, args: &[Instance]) -> Result<Instance, E> {
52    let args = args
53        .iter()
54        .map(|a| {
55            a.convert_to(tplt_ty).ok_or_else(|| {
56                E::ParamType(Type::Array(Box::new(tplt_ty.clone()), Some(tplt_n)), a.ty())
57            })
58        })
59        .collect::<Result<Vec<_>, _>>()?;
60
61    if args.len() != tplt_n {
62        return Err(E::ParamCount("array".to_string(), tplt_n, args.len()));
63    }
64
65    Ok(ArrayInstance::new(args, false).into())
66}
67
68/// `array()` constructor.
69///
70/// Reference: <https://www.w3.org/TR/WGSL/#array-builtin>
71pub fn array(args: &[Instance]) -> Result<Instance, E> {
72    let args = convert_all(args).ok_or(E::Builtin("array elements are incompatible"))?;
73
74    if args.is_empty() {
75        return Err(E::Builtin("array constructor expects at least 1 argument"));
76    }
77
78    Ok(ArrayInstance::new(args, false).into())
79}
80
81/// `bool()` constructor.
82///
83/// Reference: <https://www.w3.org/TR/WGSL/#bool-builtin>
84pub fn bool(a1: &Instance) -> Result<Instance, E> {
85    match a1 {
86        Instance::Literal(l) => {
87            let zero = LiteralInstance::zero_value(&l.ty())?;
88            Ok(LiteralInstance::Bool(*l != zero).into())
89        }
90        _ => Err(E::Builtin("bool constructor expects a scalar argument")),
91    }
92}
93
94/// `i32()` constructor.
95///
96/// Reference: <https://www.w3.org/TR/WGSL/#i32-builtin>
97// TODO: check that "If T is a floating point type, e is converted to i32, rounding towards zero."
98pub fn i32(a1: &Instance) -> Result<Instance, E> {
99    match a1 {
100        Instance::Literal(l) => {
101            let val = match l {
102                LiteralInstance::Bool(n) => Some(n.then_some(1).unwrap_or(0)),
103                LiteralInstance::AbstractInt(n) => n.to_i32(), // identity if representable
104                LiteralInstance::AbstractFloat(n) => Some(*n as i32), // rounding towards 0
105                LiteralInstance::I32(n) => Some(*n),           // identity operation
106                LiteralInstance::U32(n) => Some(*n as i32),    // reinterpretation of bits
107                LiteralInstance::F32(n) => Some((*n as i32).min(2147483520)), // rounding towards 0 AND representable in f32
108                LiteralInstance::F16(n) => Some((f16::to_f32(*n) as i32).min(65504)), // rounding towards 0 AND representable in f16
109                #[cfg(feature = "naga-ext")]
110                LiteralInstance::I64(n) => n.to_i32(), // identity if representable
111                #[cfg(feature = "naga-ext")]
112                LiteralInstance::U64(n) => n.to_i32(), // identity if representable
113                #[cfg(feature = "naga-ext")]
114                LiteralInstance::F64(n) => Some(*n as i32), // rounding towards 0
115            }
116            .ok_or(E::ConvOverflow(*l, Type::I32))?;
117            Ok(LiteralInstance::I32(val).into())
118        }
119        _ => Err(E::Builtin("i32 constructor expects a scalar argument")),
120    }
121}
122
123/// `u32()` constructor.
124///
125/// Reference: <https://www.w3.org/TR/WGSL/#u32-builtin>
126pub fn u32(a1: &Instance) -> Result<Instance, E> {
127    match a1 {
128        Instance::Literal(l) => {
129            let val = match l {
130                LiteralInstance::Bool(n) => Some(n.then_some(1).unwrap_or(0)),
131                LiteralInstance::AbstractInt(n) => n.to_u32(), // identity if representable
132                LiteralInstance::AbstractFloat(n) => Some(*n as u32), // rounding towards 0
133                LiteralInstance::I32(n) => Some(*n as u32),    // reinterpretation of bits
134                LiteralInstance::U32(n) => Some(*n),           // identity operation
135                LiteralInstance::F32(n) => Some((*n as u32).min(4294967040)), // rounding towards 0 AND representable in f32
136                LiteralInstance::F16(n) => Some((f16::to_f32(*n) as u32).min(65504)), // rounding towards 0 AND representable in f16
137                #[cfg(feature = "naga-ext")]
138                LiteralInstance::I64(n) => n.to_u32(), // identity if representable
139                #[cfg(feature = "naga-ext")]
140                LiteralInstance::U64(n) => n.to_u32(), // identity if representable
141                #[cfg(feature = "naga-ext")]
142                LiteralInstance::F64(n) => Some(*n as u32), // rounding towards 0
143            }
144            .ok_or(E::ConvOverflow(*l, Type::U32))?;
145            Ok(LiteralInstance::U32(val).into())
146        }
147        _ => Err(E::Builtin("u32 constructor expects a scalar argument")),
148    }
149}
150
151/// `f32()` constructor.
152///
153/// Reference: <https://www.w3.org/TR/WGSL/#f32-builtin>
154pub fn f32(a1: &Instance, _stage: ShaderStage) -> Result<Instance, E> {
155    match a1 {
156        Instance::Literal(l) => {
157            let val = match l {
158                LiteralInstance::Bool(n) => Some(n.then_some(f32::one()).unwrap_or(f32::zero())),
159                LiteralInstance::AbstractInt(n) => n.to_f32(), // implicit conversion
160                LiteralInstance::AbstractFloat(n) => n.to_f32(), // implicit conversion
161                LiteralInstance::I32(n) => Some(*n as f32),    // scalar to float (never overflows)
162                LiteralInstance::U32(n) => Some(*n as f32),    // scalar to float (never overflows)
163                LiteralInstance::F32(n) => Some(*n),           // identity operation
164                LiteralInstance::F16(n) => Some(f16::to_f32(*n)), // exactly representable
165                #[cfg(feature = "naga-ext")]
166                LiteralInstance::I64(n) => n.to_f32(), // implicit conversion
167                #[cfg(feature = "naga-ext")]
168                LiteralInstance::U64(n) => n.to_f32(), // implicit conversion
169                #[cfg(feature = "naga-ext")]
170                LiteralInstance::F64(n) => n.to_f32(), // implicit conversion
171            }
172            .ok_or(E::ConvOverflow(*l, Type::F32))?;
173            Ok(LiteralInstance::F32(val).into())
174        }
175        _ => Err(E::Builtin("f32 constructor expects a scalar argument")),
176    }
177}
178
179/// `f16()` constructor.
180///
181/// Reference: <https://www.w3.org/TR/WGSL/#f16-builtin>
182pub fn f16(a1: &Instance, stage: ShaderStage) -> Result<Instance, E> {
183    match a1 {
184        Instance::Literal(l) => {
185            let val = match l {
186                LiteralInstance::Bool(n) => Some(n.then_some(f16::one()).unwrap_or(f16::zero())),
187                LiteralInstance::AbstractInt(n) => {
188                    // scalar to float (can overflow)
189                    if stage == ShaderStage::Const {
190                        let range = -65504..=65504;
191                        range.contains(n).then_some(f16::from_f32(*n as f32))
192                    } else {
193                        Some(f16::from_f32(*n as f32))
194                    }
195                }
196                LiteralInstance::AbstractFloat(n) => {
197                    // scalar to float (can overflow)
198                    if stage == ShaderStage::Const {
199                        let range = -65504.0..=65504.0;
200                        range.contains(n).then_some(f16::from_f32(*n as f32))
201                    } else {
202                        Some(f16::from_f32(*n as f32))
203                    }
204                }
205                LiteralInstance::I32(n) => {
206                    // scalar to float (can overflow)
207                    if stage == ShaderStage::Const {
208                        f16::from_i32(*n)
209                    } else {
210                        Some(f16::from_f32(*n as f32))
211                    }
212                }
213                LiteralInstance::U32(n) => {
214                    // scalar to float (can overflow)
215                    if stage == ShaderStage::Const {
216                        f16::from_u32(*n)
217                    } else {
218                        Some(f16::from_f32(*n as f32))
219                    }
220                }
221                LiteralInstance::F32(n) => {
222                    // scalar to float (can overflow)
223                    if stage == ShaderStage::Const {
224                        let range = -65504.0..=65504.0;
225                        range.contains(n).then_some(f16::from_f32(*n))
226                    } else {
227                        Some(f16::from_f32(*n))
228                    }
229                }
230                LiteralInstance::F16(n) => Some(*n), // identity operation
231                #[cfg(feature = "naga-ext")]
232                LiteralInstance::I64(n) => {
233                    // scalar to float (can overflow)
234                    if stage == ShaderStage::Const {
235                        let range = -65504..=65504;
236                        range.contains(n).then_some(f16::from_f32(*n as f32))
237                    } else {
238                        Some(f16::from_f32(*n as f32))
239                    }
240                }
241                #[cfg(feature = "naga-ext")]
242                LiteralInstance::U64(n) => {
243                    // scalar to float (can overflow)
244                    if stage == ShaderStage::Const {
245                        f16::from_u64(*n)
246                    } else {
247                        Some(f16::from_f32(*n as f32))
248                    }
249                }
250                #[cfg(feature = "naga-ext")]
251                LiteralInstance::F64(n) => {
252                    // scalar to float (can overflow)
253                    if stage == ShaderStage::Const {
254                        let range = -65504.0..=65504.0;
255                        range.contains(n).then_some(f16::from_f32(*n as f32))
256                    } else {
257                        Some(f16::from_f32(*n as f32))
258                    }
259                }
260            }
261            .ok_or(E::ConvOverflow(*l, Type::F16))?;
262            Ok(LiteralInstance::F16(val).into())
263        }
264        _ => Err(E::Builtin("f16 constructor expects a scalar argument")),
265    }
266}
267
268/// `i64()` constructor (naga extension).
269///
270/// TODO: This built-in is not implemented!
271pub fn i64(_a1: &Instance) -> Result<Instance, E> {
272    Err(E::Todo("i64".to_string()))
273}
274
275/// `u64()` constructor (naga extension).
276///
277/// TODO: This built-in is not implemented!
278pub fn u64(_a1: &Instance) -> Result<Instance, E> {
279    Err(E::Todo("u64".to_string()))
280}
281
282/// `f64()` constructor (naga extension).
283///
284/// TODO: This built-in is not implemented!
285pub fn f64(_a1: &Instance, _stage: ShaderStage) -> Result<Instance, E> {
286    Err(E::Todo("f64".to_string()))
287}
288
289/// `matCxR<T>()` constructor.
290///
291/// Reference: <https://www.w3.org/TR/WGSL/#mat2x2-builtin>
292pub fn mat_t(
293    c: usize,
294    r: usize,
295    tplt_ty: &Type,
296    args: &[Instance],
297    stage: ShaderStage,
298) -> Result<Instance, E> {
299    // overload 1: mat conversion constructor
300    if let [Instance::Mat(m)] = args {
301        if m.c() != c || m.r() != r {
302            return Err(E::Conversion(
303                m.ty(),
304                Type::Mat(c as u8, r as u8, Box::new(tplt_ty.clone())),
305            ));
306        }
307
308        let conv_fn = match tplt_ty {
309            Type::F32 => f32,
310            Type::F16 => f16,
311            _ => return Err(E::Builtin("matrix type must be a f32 or f16")),
312        };
313
314        let comps = m
315            .iter_cols()
316            .map(|v| {
317                v.unwrap_vec_ref()
318                    .iter()
319                    .map(|n| conv_fn(n, stage))
320                    .collect::<Result<Vec<_>, _>>()
321                    .map(|s| Instance::Vec(VecInstance::new(s)))
322            })
323            .collect::<Result<Vec<_>, _>>()?;
324
325        Ok(MatInstance::from_cols(comps).into())
326    } else {
327        let ty = args
328            .first()
329            .ok_or(E::Builtin("matrix constructor expects arguments"))?
330            .ty();
331        let ty = ty
332            .convert_inner_to(tplt_ty)
333            .ok_or(E::Conversion(ty.inner_ty(), tplt_ty.clone()))?;
334        let args =
335            convert_all_to(args, &ty).ok_or(E::Builtin("matrix components are incompatible"))?;
336
337        // overload 2: mat from column vectors
338        if ty.is_vec() {
339            if args.len() != c {
340                return Err(E::ParamCount(format!("mat{c}x{r}"), c, args.len()));
341            }
342
343            Ok(MatInstance::from_cols(args).into())
344        }
345        // overload 3: mat from float values
346        else if ty.is_float() {
347            if args.len() != c * r {
348                return Err(E::ParamCount(format!("mat{c}x{r}"), c * r, args.len()));
349            }
350
351            let args = args
352                .chunks(r)
353                .map(|v| Instance::Vec(VecInstance::new(v.to_vec())))
354                .collect_vec();
355
356            Ok(MatInstance::from_cols(args).into())
357        } else {
358            Err(E::Builtin(
359                "matrix constructor expects float or vector of float arguments",
360            ))
361        }
362    }
363}
364
365/// `matCxR()` constructor.
366///
367/// Reference: <https://www.w3.org/TR/WGSL/#mat2x2-builtin>
368pub fn mat(c: usize, r: usize, args: &[Instance]) -> Result<Instance, E> {
369    // overload 1: mat conversion constructor
370    if let [Instance::Mat(m)] = args {
371        if m.c() != c || m.r() != r {
372            let ty2 = Type::Mat(c as u8, r as u8, m.inner_ty().into());
373            return Err(E::Conversion(m.ty(), ty2));
374        }
375        // note: `matCxR(e: matCxR<S>) -> matCxR<S>` is no-op
376        Ok(m.clone().into())
377    } else {
378        let tys = args.iter().map(|a| a.ty()).collect_vec();
379        let ty = convert_all_ty(&tys).ok_or(E::Builtin("matrix components are incompatible"))?;
380        let mut inner_ty = ty.inner_ty();
381
382        if inner_ty.is_abstract_int() {
383            // force conversion from AbstractInt to a float type
384            inner_ty = Type::F32;
385        } else if !inner_ty.is_float() {
386            return Err(E::Builtin(
387                "matrix constructor expects float or vector of float arguments",
388            ));
389        }
390
391        let args = convert_all_inner_to(args, &inner_ty)
392            .ok_or(E::Builtin("matrix components are incompatible"))?;
393
394        // overload 2: mat from column vectors
395        if ty.is_vec() {
396            if args.len() != c {
397                return Err(E::ParamCount(format!("mat{c}x{r}"), c, args.len()));
398            }
399
400            Ok(MatInstance::from_cols(args).into())
401        }
402        // overload 3: mat from float values
403        else if ty.is_float() || ty.is_abstract_int() {
404            if args.len() != c * r {
405                return Err(E::ParamCount(format!("mat{c}x{r}"), c * r, args.len()));
406            }
407            let args = args
408                .chunks(r)
409                .map(|v| Instance::Vec(VecInstance::new(v.to_vec())))
410                .collect_vec();
411
412            Ok(MatInstance::from_cols(args).into())
413        } else {
414            Err(E::Builtin(
415                "matrix constructor expects float or vector of float arguments",
416            ))
417        }
418    }
419}
420
421/// `vecN<T>()` constructor.
422///
423/// Reference: <https://www.w3.org/TR/WGSL/#vec2-builtin>
424pub fn vec_t(
425    n: usize,
426    tplt_ty: &Type,
427    args: &[Instance],
428    stage: ShaderStage,
429) -> Result<Instance, E> {
430    // overload 1: vec init from single scalar value
431    if let [Instance::Literal(l)] = args {
432        let val = l
433            .convert_to(tplt_ty)
434            .map(Instance::Literal)
435            .ok_or_else(|| E::ParamType(tplt_ty.clone(), l.ty()))?;
436        let comps = (0..n).map(|_| val.clone()).collect_vec();
437        Ok(VecInstance::new(comps).into())
438    }
439    // overload 2: vec conversion constructor
440    else if let [Instance::Vec(v)] = args {
441        let ty = Type::Vec(n as u8, Box::new(tplt_ty.clone()));
442        if v.n() != n {
443            return Err(E::Conversion(v.ty(), ty));
444        }
445
446        let conv_fn = match ty.inner_ty() {
447            Type::Bool => |n, _| bool(n),
448            Type::I32 => |n, _| i32(n),
449            Type::U32 => |n, _| u32(n),
450            Type::F32 => |n, stage| f32(n, stage),
451            Type::F16 => |n, stage| f16(n, stage),
452            _ => return Err(E::Builtin("vector type must be a scalar")),
453        };
454
455        let comps = v
456            .iter()
457            .map(|n| conv_fn(n, stage))
458            .collect::<Result<Vec<_>, _>>()?;
459
460        Ok(VecInstance::new(comps).into())
461    }
462    // overload 3: vec init from component values
463    else {
464        // flatten vecN args
465        let args = args
466            .iter()
467            .flat_map(|a| -> Box<dyn Iterator<Item = &Instance>> {
468                match a {
469                    Instance::Vec(v) => Box::new(v.iter()),
470                    _ => Box::new(std::iter::once(a)),
471                }
472            })
473            .collect_vec();
474        if args.len() != n {
475            return Err(E::ParamCount(format!("vec{n}"), n, args.len()));
476        }
477
478        let comps = args
479            .iter()
480            .map(|a| {
481                a.convert_inner_to(tplt_ty)
482                    .ok_or_else(|| E::ParamType(tplt_ty.clone(), a.ty()))
483            })
484            .collect::<Result<Vec<_>, _>>()?;
485
486        Ok(VecInstance::new(comps).into())
487    }
488}
489
490/// `vecN()` constructor.
491///
492/// Reference: <https://www.w3.org/TR/WGSL/#vec2-builtin>
493pub fn vec(n: usize, args: &[Instance]) -> Result<Instance, E> {
494    // overload 1: vec init from single scalar value
495    if let [Instance::Literal(l)] = args {
496        let ty = l.ty();
497        if !ty.is_scalar() {
498            return Err(E::Builtin("vec constructor expects scalar arguments"));
499        }
500        let val = Instance::Literal(*l);
501        let comps = (0..n).map(|_| val.clone()).collect_vec();
502        Ok(VecInstance::new(comps).into())
503    }
504    // overload 2: vec conversion constructor
505    else if let [Instance::Vec(v)] = args {
506        if v.n() != n {
507            let ty = v.ty();
508            let ty2 = Type::Vec(n as u8, ty.inner_ty().into());
509            return Err(E::Conversion(ty, ty2));
510        }
511        // note: `vecN(e: vecN<S>) -> vecN<S>` is no-op
512        Ok(v.clone().into())
513    }
514    // overload 3: vec init from component values
515    else if !args.is_empty() {
516        // flatten vecN args
517        let args = args
518            .iter()
519            .flat_map(|a| -> Box<dyn Iterator<Item = &Instance>> {
520                match a {
521                    Instance::Vec(v) => Box::new(v.iter()),
522                    _ => Box::new(std::iter::once(a)),
523                }
524            })
525            .cloned()
526            .collect_vec();
527        if args.len() != n {
528            return Err(E::ParamCount(format!("vec{n}"), n, args.len()));
529        }
530
531        let comps = convert_all(&args).ok_or(E::Builtin("vector components are incompatible"))?;
532
533        if !comps.first().unwrap(/* SAFETY: len() checked above */).ty().is_scalar() {
534            return Err(E::Builtin("vec constructor expects scalar arguments"));
535        }
536        Ok(VecInstance::new(comps).into())
537    }
538    // overload 3: zero-vec
539    else {
540        VecInstance::zero_value(n as u8, &Type::AbstractInt).map(Into::into)
541    }
542}
543
544/// User-defined struct constructor.
545pub fn struct_ctor(struct_ty: &StructType, args: &[Instance]) -> Result<StructInstance, E> {
546    if args.is_empty() {
547        return StructInstance::zero_value(struct_ty);
548    }
549
550    if args.len() != struct_ty.members.len() {
551        return Err(E::ParamCount(
552            struct_ty.name.clone(),
553            struct_ty.members.len(),
554            args.len(),
555        ));
556    }
557
558    let members = struct_ty
559        .members
560        .iter()
561        .zip(args)
562        .map(|(m_ty, inst)| {
563            let inst = inst
564                .convert_to(&m_ty.ty)
565                .ok_or_else(|| E::ParamType(m_ty.ty.clone(), inst.ty()))?;
566            Ok(inst)
567        })
568        .collect::<Result<Vec<_>, E>>()?;
569
570    Ok(StructInstance::new(struct_ty.clone(), members))
571}
572
573/// Check a struct constructor call signature.
574///
575/// Validates the type and number of arguments passed.
576pub fn typecheck_struct_ctor(struct_ty: &StructType, args: &[Type]) -> Result<(), E> {
577    if args.is_empty() {
578        // zero-value constructor
579        return Ok(());
580    }
581
582    if args.len() != struct_ty.members.len() {
583        return Err(E::ParamCount(
584            struct_ty.name.clone(),
585            struct_ty.members.len(),
586            args.len(),
587        ));
588    }
589
590    for (m_ty, a_ty) in struct_ty.members.iter().zip(args) {
591        if !a_ty.is_convertible_to(&m_ty.ty) {
592            return Err(E::ParamType(m_ty.ty.clone(), a_ty.ty()));
593        }
594    }
595
596    Ok(())
597}
598
599// -----------------
600// CONSTRUCTOR TYPES
601// -----------------
602
603/// Return type of `array<T,N>()` constructor.
604///
605/// Reference: <https://www.w3.org/TR/WGSL/#array-builtin>
606fn array_ctor_ty_t(tplt_ty: &Type, tplt_n: usize, args: &[Type]) -> Result<Type, E> {
607    if let Some(arg) = args.iter().find(|arg| !arg.is_convertible_to(tplt_ty)) {
608        Err(E::Conversion(arg.clone(), tplt_ty.clone()))
609    } else {
610        Ok(Type::Array(Box::new(tplt_ty.clone()), Some(tplt_n)))
611    }
612}
613
614/// Return type of `array()` constructor.
615///
616/// Reference: <https://www.w3.org/TR/WGSL/#array-builtin>
617fn array_ctor_ty(args: &[Type]) -> Result<Type, E> {
618    let ty = convert_all_ty(args).ok_or(E::Builtin("array elements are incompatible"))?;
619    Ok(Type::Array(Box::new(ty.clone()), Some(args.len())))
620}
621
622/// Return type of `matCxR<T>()` constructor.
623///
624/// Reference: <https://www.w3.org/TR/WGSL/#mat2x2-builtin>
625fn mat_ctor_ty_t(c: u8, r: u8, tplt_ty: &Type, args: &[Type]) -> Result<Type, E> {
626    // overload 1: mat conversion constructor
627    if let [ty @ Type::Mat(c2, r2, _)] = args {
628        // note: this is an explicit conversion, not automatic conversion
629        if *c2 != c || *r2 != r {
630            return Err(E::Conversion(
631                ty.clone(),
632                Type::Mat(c, r, Box::new(tplt_ty.clone())),
633            ));
634        }
635    } else {
636        if args.is_empty() {
637            return Err(E::Builtin("matrix constructor expects arguments"));
638        }
639        let ty = convert_all_ty(args).ok_or(E::Builtin("matrix components are incompatible"))?;
640        let ty = ty
641            .convert_inner_to(tplt_ty)
642            .ok_or(E::Conversion(ty.inner_ty(), tplt_ty.clone()))?;
643
644        // overload 2: mat from column vectors
645        if ty.is_vec() {
646            if args.len() != c as usize {
647                return Err(E::ParamCount(format!("mat{c}x{r}"), c as usize, args.len()));
648            }
649        }
650        // overload 3: mat from float values
651        else if ty.is_float() {
652            let n = c as usize * r as usize;
653            if args.len() != n {
654                return Err(E::ParamCount(format!("mat{c}x{r}"), n, args.len()));
655            }
656        } else {
657            return Err(E::Builtin(
658                "matrix constructor expects float or vector of float arguments",
659            ));
660        }
661    }
662
663    Ok(Type::Mat(c, r, Box::new(tplt_ty.clone())))
664}
665
666/// Return type of `matCxR()` constructor.
667///
668/// Reference: <https://www.w3.org/TR/WGSL/#mat2x2-builtin>
669fn mat_ctor_ty(c: u8, r: u8, args: &[Type]) -> Result<Type, E> {
670    // overload 1: mat conversion constructor
671    if let [ty @ Type::Mat(c2, r2, ty2)] = args {
672        // note: this is an explicit conversion, not automatic conversion
673        if *c2 != c || *r2 != r {
674            return Err(E::Conversion(ty.clone(), Type::Mat(c, r, ty2.clone())));
675        }
676        Ok(ty.clone())
677    } else {
678        let ty = convert_all_ty(args).ok_or(E::Builtin("matrix components are incompatible"))?;
679        let mut inner_ty = ty.inner_ty();
680
681        if inner_ty.is_abstract_int() {
682            // force conversion from AbstractInt to a float type
683            inner_ty = Type::F32;
684        } else if !inner_ty.is_float() {
685            return Err(E::Builtin(
686                "matrix constructor expects float or vector of float arguments",
687            ));
688        }
689
690        // overload 2: mat from column vectors
691        if ty.is_vec() {
692            if args.len() != c as usize {
693                return Err(E::ParamCount(format!("mat{c}x{r}"), c as usize, args.len()));
694            }
695        }
696        // overload 3: mat from float values
697        else if ty.is_float() || ty.is_abstract_int() {
698            let n = c as usize * r as usize;
699            if args.len() != n {
700                return Err(E::ParamCount(format!("mat{c}x{r}"), n, args.len()));
701            }
702        } else {
703            return Err(E::Builtin(
704                "matrix constructor expects float or vector of float arguments",
705            ));
706        }
707
708        Ok(Type::Mat(c, r, inner_ty.into()))
709    }
710}
711
712/// Return type of `vecN<T>()` constructor.
713///
714/// Reference: <https://www.w3.org/TR/WGSL/#vec2-builtin>
715fn vec_ctor_ty_t(n: u8, tplt_ty: &Type, args: &[Type]) -> Result<Type, E> {
716    if let [arg] = args {
717        // overload 1: vec init from single scalar value
718        if arg.is_scalar() {
719            if !arg.is_convertible_to(tplt_ty) {
720                return Err(E::Conversion(arg.clone(), tplt_ty.clone()));
721            }
722        }
723        // overload 2: vec conversion constructor
724        else if arg.is_vec() {
725            // note: this is an explicit conversion, not automatic conversion
726        } else {
727            return Err(E::Conversion(arg.clone(), tplt_ty.clone()));
728        }
729    }
730    // overload 3: vec init from component values
731    else {
732        // flatten vecN args
733        let n2 = args
734            .iter()
735            .try_fold(0, |acc, arg| match arg {
736                ty if ty.is_scalar() => ty.is_convertible_to(tplt_ty).then_some(acc + 1),
737                Type::Vec(n, ty) => ty.is_convertible_to(tplt_ty).then_some(acc + n),
738                _ => None,
739            })
740            .ok_or(E::Builtin(
741                "vector constructor expects scalar or vector arguments",
742            ))?;
743        if n2 != n {
744            return Err(E::ParamCount(format!("vec{n}"), n as usize, args.len()));
745        }
746    }
747
748    Ok(Type::Vec(n, Box::new(tplt_ty.clone())))
749}
750
751/// Return type of `vecN()` constructor.
752///
753/// Reference: <https://www.w3.org/TR/WGSL/#vec2-builtin>
754fn vec_ctor_ty(n: u8, args: &[Type]) -> Result<Type, E> {
755    if let [arg] = args {
756        // overload 1: vec init from single scalar value
757        if arg.is_scalar() {
758        }
759        // overload 2: vec conversion constructor
760        else if arg.is_vec() {
761            // note: `vecN(e: vecN<S>) -> vecN<S>` is no-op
762        } else {
763            return Err(E::Builtin(
764                "vector constructor expects scalar or vector arguments",
765            ));
766        }
767        Ok(Type::Vec(n, arg.inner_ty().into()))
768    }
769    // overload 3: vec init from component values
770    else if !args.is_empty() {
771        // flatten vecN args
772        let n2 = args
773            .iter()
774            .try_fold(0, |acc, arg| match arg {
775                ty if ty.is_scalar() => Some(acc + 1),
776                Type::Vec(n, _) => Some(acc + n),
777                _ => None,
778            })
779            .ok_or(E::Builtin(
780                "vector constructor expects scalar or vector arguments",
781            ))?;
782        if n2 != n {
783            return Err(E::ParamCount(format!("vec{n}"), n as usize, args.len()));
784        }
785
786        let tys = args.iter().map(|arg| arg.inner_ty()).collect_vec();
787        let ty = convert_all_ty(&tys).ok_or(E::Builtin("vector components are incompatible"))?;
788
789        Ok(Type::Vec(n, ty.clone().into()))
790    }
791    // overload 3: zero-vec
792    else {
793        Ok(Type::Vec(n, Type::AbstractInt.into()))
794    }
795}
796
797/// Compute the return type of calling a built-in constructor function.
798///
799/// The arguments must be [loaded][Type::loaded].
800///
801/// Includes built-in constructors and zero-value constructors, *but not* the struct
802/// constructors, since they require knowledge of the struct type.
803/// You can type-check a struct constructor call with [`typecheck_struct_ctor`].
804pub fn type_ctor(name: &str, tplt: Option<&[TpltParam]>, args: &[Type]) -> Result<Type, E> {
805    match (name, tplt, args) {
806        ("array", Some(t), []) => Ok(ArrayTemplate::parse(t)?.ty()),
807        ("array", Some(t), a) => {
808            let tplt = ArrayTemplate::parse(t)?;
809            array_ctor_ty_t(
810                &tplt.inner_ty(),
811                tplt.n().ok_or(E::TemplateArgs("array"))?,
812                a,
813            )
814        }
815        ("array", None, _) => array_ctor_ty(args),
816        ("bool", None, []) => Ok(Type::Bool),
817        ("bool", None, [a]) if a.is_scalar() => Ok(Type::Bool),
818        ("i32", None, []) => Ok(Type::I32),
819        ("i32", None, [a]) if a.is_scalar() => Ok(Type::I32),
820        ("u32", None, []) => Ok(Type::U32),
821        ("u32", None, [a]) if a.is_scalar() => Ok(Type::U32),
822        ("f32", None, []) => Ok(Type::F32),
823        ("f32", None, [a]) if a.is_scalar() => Ok(Type::F32),
824        ("f16", None, []) => Ok(Type::F16),
825        ("f16", None, [a]) if a.is_scalar() => Ok(Type::F16),
826        ("mat2x2", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(2, 2)),
827        ("mat2x2", Some(t), _) => mat_ctor_ty_t(2, 2, MatTemplate::parse(t)?.inner_ty(), args),
828        ("mat2x2", None, _) => mat_ctor_ty(2, 2, args),
829        ("mat2x3", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(2, 3)),
830        ("mat2x3", Some(t), _) => mat_ctor_ty_t(2, 3, MatTemplate::parse(t)?.inner_ty(), args),
831        ("mat2x3", None, _) => mat_ctor_ty(2, 3, args),
832        ("mat2x4", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(2, 4)),
833        ("mat2x4", Some(t), _) => mat_ctor_ty_t(2, 4, MatTemplate::parse(t)?.inner_ty(), args),
834        ("mat2x4", None, _) => mat_ctor_ty(2, 4, args),
835        ("mat3x2", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(3, 2)),
836        ("mat3x2", Some(t), _) => mat_ctor_ty_t(3, 2, MatTemplate::parse(t)?.inner_ty(), args),
837        ("mat3x2", None, _) => mat_ctor_ty(3, 2, args),
838        ("mat3x3", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(3, 3)),
839        ("mat3x3", Some(t), _) => mat_ctor_ty_t(3, 3, MatTemplate::parse(t)?.inner_ty(), args),
840        ("mat3x3", None, _) => mat_ctor_ty(3, 3, args),
841        ("mat3x4", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(3, 4)),
842        ("mat3x4", Some(t), _) => mat_ctor_ty_t(3, 4, MatTemplate::parse(t)?.inner_ty(), args),
843        ("mat3x4", None, _) => mat_ctor_ty(3, 4, args),
844        ("mat4x2", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(4, 2)),
845        ("mat4x2", Some(t), _) => mat_ctor_ty_t(4, 2, MatTemplate::parse(t)?.inner_ty(), args),
846        ("mat4x2", None, _) => mat_ctor_ty(4, 2, args),
847        ("mat4x3", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(4, 3)),
848        ("mat4x3", Some(t), _) => mat_ctor_ty_t(4, 3, MatTemplate::parse(t)?.inner_ty(), args),
849        ("mat4x3", None, _) => mat_ctor_ty(4, 3, args),
850        ("mat4x4", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(4, 4)),
851        ("mat4x4", Some(t), _) => mat_ctor_ty_t(4, 4, MatTemplate::parse(t)?.inner_ty(), args),
852        ("mat4x4", None, _) => mat_ctor_ty(4, 4, args),
853        ("vec2", Some(t), []) => Ok(VecTemplate::parse(t)?.ty(2)),
854        ("vec2", Some(t), _) => vec_ctor_ty_t(2, VecTemplate::parse(t)?.inner_ty(), args),
855        ("vec2", None, _) => vec_ctor_ty(2, args),
856        ("vec3", Some(t), []) => Ok(VecTemplate::parse(t)?.ty(3)),
857        ("vec3", Some(t), _) => vec_ctor_ty_t(3, VecTemplate::parse(t)?.inner_ty(), args),
858        ("vec3", None, _) => vec_ctor_ty(3, args),
859        ("vec4", Some(t), []) => Ok(VecTemplate::parse(t)?.ty(4)),
860        ("vec4", Some(t), _) => vec_ctor_ty_t(4, VecTemplate::parse(t)?.inner_ty(), args),
861        ("vec4", None, _) => vec_ctor_ty(4, args),
862        #[cfg(feature = "naga-ext")]
863        ("i64", None, []) => Ok(Type::I64),
864        #[cfg(feature = "naga-ext")]
865        ("i64", None, [a]) if a.is_scalar() => Ok(Type::I64),
866        #[cfg(feature = "naga-ext")]
867        ("u64", None, []) => Ok(Type::U64),
868        #[cfg(feature = "naga-ext")]
869        ("u64", None, [a]) if a.is_scalar() => Ok(Type::U64),
870        #[cfg(feature = "naga-ext")]
871        ("f64", None, []) => Ok(Type::F64),
872        #[cfg(feature = "naga-ext")]
873        ("f64", None, [a]) if a.is_scalar() => Ok(Type::F64),
874        _ => Err(E::Signature(CallSignature {
875            name: name.to_string(),
876            tplt: tplt.map(|t| t.to_vec()),
877            args: args.to_vec(),
878        })),
879    }
880}
881
882// -----------
883// ZERO VALUES
884// -----------
885// reference: <https://www.w3.org/TR/WGSL/#zero-value>
886
887impl Instance {
888    /// Zero-value initialize an instance of a given type.
889    pub fn zero_value(ty: &Type) -> Result<Self, E> {
890        match ty {
891            Type::Bool => Ok(LiteralInstance::Bool(false).into()),
892            Type::AbstractInt => Ok(LiteralInstance::AbstractInt(0).into()),
893            Type::AbstractFloat => Ok(LiteralInstance::AbstractFloat(0.0).into()),
894            Type::I32 => Ok(LiteralInstance::I32(0).into()),
895            Type::U32 => Ok(LiteralInstance::U32(0).into()),
896            Type::F32 => Ok(LiteralInstance::F32(0.0).into()),
897            Type::F16 => Ok(LiteralInstance::F16(f16::zero()).into()),
898            Type::Struct(s) => StructInstance::zero_value(s).map(Into::into),
899            Type::Array(a_ty, Some(n)) => ArrayInstance::zero_value(*n, a_ty).map(Into::into),
900            Type::Array(_, None) => Err(E::NotConstructible(ty.clone())),
901            Type::Vec(n, v_ty) => VecInstance::zero_value(*n, v_ty).map(Into::into),
902            Type::Mat(c, r, m_ty) => MatInstance::zero_value(*c, *r, m_ty).map(Into::into),
903            Type::Atomic(_)
904            | Type::Ptr(_, _, _)
905            | Type::Ref(_, _, _)
906            | Type::Texture(_)
907            | Type::Sampler(_) => Err(E::NotConstructible(ty.clone())),
908            #[cfg(feature = "naga-ext")]
909            Type::I64 => Ok(LiteralInstance::I64(0).into()),
910            #[cfg(feature = "naga-ext")]
911            Type::U64 => Ok(LiteralInstance::U64(0).into()),
912            #[cfg(feature = "naga-ext")]
913            Type::F64 => Ok(LiteralInstance::F64(0.0).into()),
914            #[cfg(feature = "naga-ext")]
915            Type::BindingArray(_, _) => Err(E::NotConstructible(ty.clone())),
916            #[cfg(feature = "naga-ext")]
917            Type::RayQuery(_) => Err(E::NotConstructible(ty.clone())),
918            #[cfg(feature = "naga-ext")]
919            Type::AccelerationStructure(_) => Err(E::NotConstructible(ty.clone())),
920        }
921    }
922}
923
924impl LiteralInstance {
925    /// The zero-value constructor.
926    pub fn zero_value(ty: &Type) -> Result<Self, E> {
927        match ty {
928            Type::Bool => Ok(LiteralInstance::Bool(false)),
929            Type::AbstractInt => Ok(LiteralInstance::AbstractInt(0)),
930            Type::AbstractFloat => Ok(LiteralInstance::AbstractFloat(0.0)),
931            Type::I32 => Ok(LiteralInstance::I32(0)),
932            Type::U32 => Ok(LiteralInstance::U32(0)),
933            Type::F32 => Ok(LiteralInstance::F32(0.0)),
934            Type::F16 => Ok(LiteralInstance::F16(f16::zero())),
935            #[cfg(feature = "naga-ext")]
936            Type::I64 => Ok(LiteralInstance::I64(0)),
937            #[cfg(feature = "naga-ext")]
938            Type::U64 => Ok(LiteralInstance::U64(0)),
939            #[cfg(feature = "naga-ext")]
940            Type::F64 => Ok(LiteralInstance::F64(0.0)),
941            _ => Err(E::NotScalar(ty.clone())),
942        }
943    }
944}
945
946impl StructInstance {
947    /// Zero-value initialize a `struct` instance.
948    pub fn zero_value(s: &StructType) -> Result<Self, E> {
949        let members = s
950            .members
951            .iter()
952            .map(|mem| {
953                let val = Instance::zero_value(&mem.ty)?;
954                Ok(val)
955            })
956            .collect::<Result<Vec<_>, _>>()?;
957
958        Ok(StructInstance::new(s.clone(), members))
959    }
960}
961
962impl ArrayInstance {
963    /// Zero-value initialize an `array` instance.
964    pub fn zero_value(n: usize, ty: &Type) -> Result<Self, E> {
965        let zero = Instance::zero_value(ty)?;
966        let comps = (0..n).map(|_| zero.clone()).collect_vec();
967        Ok(ArrayInstance::new(comps, false))
968    }
969}
970
971impl VecInstance {
972    /// Zero-value initialize a `vec` instance.
973    pub fn zero_value(n: u8, ty: &Type) -> Result<Self, E> {
974        let zero = Instance::Literal(LiteralInstance::zero_value(ty)?);
975        let comps = (0..n).map(|_| zero.clone()).collect_vec();
976        Ok(VecInstance::new(comps))
977    }
978}
979
980impl MatInstance {
981    /// Zero-value initialize a `mat` instance.
982    pub fn zero_value(c: u8, r: u8, ty: &Type) -> Result<Self, E> {
983        let zero = Instance::Literal(LiteralInstance::zero_value(ty)?);
984        let zero_col = Instance::Vec(VecInstance::new((0..r).map(|_| zero.clone()).collect_vec()));
985        let comps = (0..c).map(|_| zero_col.clone()).collect_vec();
986        Ok(MatInstance::from_cols(comps))
987    }
988}