1use half::prelude::*;
15use itertools::Itertools;
16use num_traits::{FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19 CallSignature, Error, ShaderStage,
20 conv::{Convert, convert_all, convert_all_inner_to, convert_all_to, convert_all_ty},
21 inst::{ArrayInstance, Instance, LiteralInstance, MatInstance, StructInstance, VecInstance},
22 tplt::{ArrayTemplate, MatTemplate, TpltParam, VecTemplate},
23 ty::{StructType, Ty, Type},
24};
25
/// Local shorthand for the crate-level [`Error`] type, used as the error type throughout this module.
type E = Error;
27
/// Returns `true` when `name` is a predeclared type that can be called as a value
/// constructor: `array`, the scalar types, `matCxR` and `vecN`.
///
/// With the `naga-ext` feature enabled, `i64`, `u64` and `f64` are recognized too.
pub fn is_ctor(name: &str) -> bool {
    let is_core = matches!(
        name,
        "array"
            | "bool"
            | "i32"
            | "u32"
            | "f32"
            | "f16"
            | "mat2x2"
            | "mat2x3"
            | "mat2x4"
            | "mat3x2"
            | "mat3x3"
            | "mat3x4"
            | "mat4x2"
            | "mat4x3"
            | "mat4x4"
            | "vec2"
            | "vec3"
            | "vec4"
    );
    is_core || (cfg!(feature = "naga-ext") && matches!(name, "i64" | "u64" | "f64"))
}
42
43pub fn array_t(tplt_ty: &Type, tplt_n: usize, args: &[Instance]) -> Result<Instance, E> {
52 let args = args
53 .iter()
54 .map(|a| {
55 a.convert_to(tplt_ty).ok_or_else(|| {
56 E::ParamType(Type::Array(Box::new(tplt_ty.clone()), Some(tplt_n)), a.ty())
57 })
58 })
59 .collect::<Result<Vec<_>, _>>()?;
60
61 if args.len() != tplt_n {
62 return Err(E::ParamCount("array".to_string(), tplt_n, args.len()));
63 }
64
65 Ok(ArrayInstance::new(args, false).into())
66}
67
68pub fn array(args: &[Instance]) -> Result<Instance, E> {
72 let args = convert_all(args).ok_or(E::Builtin("array elements are incompatible"))?;
73
74 if args.is_empty() {
75 return Err(E::Builtin("array constructor expects at least 1 argument"));
76 }
77
78 Ok(ArrayInstance::new(args, false).into())
79}
80
81pub fn bool(a1: &Instance) -> Result<Instance, E> {
85 match a1 {
86 Instance::Literal(l) => {
87 let zero = LiteralInstance::zero_value(&l.ty())?;
88 Ok(LiteralInstance::Bool(*l != zero).into())
89 }
90 _ => Err(E::Builtin("bool constructor expects a scalar argument")),
91 }
92}
93
94pub fn i32(a1: &Instance) -> Result<Instance, E> {
99 match a1 {
100 Instance::Literal(l) => {
101 let val = match l {
102 LiteralInstance::Bool(n) => Some(n.then_some(1).unwrap_or(0)),
103 LiteralInstance::AbstractInt(n) => n.to_i32(), LiteralInstance::AbstractFloat(n) => Some(*n as i32), LiteralInstance::I32(n) => Some(*n), LiteralInstance::U32(n) => Some(*n as i32), LiteralInstance::F32(n) => Some((*n as i32).min(2147483520)), LiteralInstance::F16(n) => Some((f16::to_f32(*n) as i32).min(65504)), #[cfg(feature = "naga-ext")]
110 LiteralInstance::I64(n) => n.to_i32(), #[cfg(feature = "naga-ext")]
112 LiteralInstance::U64(n) => n.to_i32(), #[cfg(feature = "naga-ext")]
114 LiteralInstance::F64(n) => Some(*n as i32), }
116 .ok_or(E::ConvOverflow(*l, Type::I32))?;
117 Ok(LiteralInstance::I32(val).into())
118 }
119 _ => Err(E::Builtin("i32 constructor expects a scalar argument")),
120 }
121}
122
123pub fn u32(a1: &Instance) -> Result<Instance, E> {
127 match a1 {
128 Instance::Literal(l) => {
129 let val = match l {
130 LiteralInstance::Bool(n) => Some(n.then_some(1).unwrap_or(0)),
131 LiteralInstance::AbstractInt(n) => n.to_u32(), LiteralInstance::AbstractFloat(n) => Some(*n as u32), LiteralInstance::I32(n) => Some(*n as u32), LiteralInstance::U32(n) => Some(*n), LiteralInstance::F32(n) => Some((*n as u32).min(4294967040)), LiteralInstance::F16(n) => Some((f16::to_f32(*n) as u32).min(65504)), #[cfg(feature = "naga-ext")]
138 LiteralInstance::I64(n) => n.to_u32(), #[cfg(feature = "naga-ext")]
140 LiteralInstance::U64(n) => n.to_u32(), #[cfg(feature = "naga-ext")]
142 LiteralInstance::F64(n) => Some(*n as u32), }
144 .ok_or(E::ConvOverflow(*l, Type::U32))?;
145 Ok(LiteralInstance::U32(val).into())
146 }
147 _ => Err(E::Builtin("u32 constructor expects a scalar argument")),
148 }
149}
150
151pub fn f32(a1: &Instance, _stage: ShaderStage) -> Result<Instance, E> {
155 match a1 {
156 Instance::Literal(l) => {
157 let val = match l {
158 LiteralInstance::Bool(n) => Some(n.then_some(f32::one()).unwrap_or(f32::zero())),
159 LiteralInstance::AbstractInt(n) => n.to_f32(), LiteralInstance::AbstractFloat(n) => n.to_f32(), LiteralInstance::I32(n) => Some(*n as f32), LiteralInstance::U32(n) => Some(*n as f32), LiteralInstance::F32(n) => Some(*n), LiteralInstance::F16(n) => Some(f16::to_f32(*n)), #[cfg(feature = "naga-ext")]
166 LiteralInstance::I64(n) => n.to_f32(), #[cfg(feature = "naga-ext")]
168 LiteralInstance::U64(n) => n.to_f32(), #[cfg(feature = "naga-ext")]
170 LiteralInstance::F64(n) => n.to_f32(), }
172 .ok_or(E::ConvOverflow(*l, Type::F32))?;
173 Ok(LiteralInstance::F32(val).into())
174 }
175 _ => Err(E::Builtin("f32 constructor expects a scalar argument")),
176 }
177}
178
179pub fn f16(a1: &Instance, stage: ShaderStage) -> Result<Instance, E> {
183 match a1 {
184 Instance::Literal(l) => {
185 let val = match l {
186 LiteralInstance::Bool(n) => Some(n.then_some(f16::one()).unwrap_or(f16::zero())),
187 LiteralInstance::AbstractInt(n) => {
188 if stage == ShaderStage::Const {
190 let range = -65504..=65504;
191 range.contains(n).then_some(f16::from_f32(*n as f32))
192 } else {
193 Some(f16::from_f32(*n as f32))
194 }
195 }
196 LiteralInstance::AbstractFloat(n) => {
197 if stage == ShaderStage::Const {
199 let range = -65504.0..=65504.0;
200 range.contains(n).then_some(f16::from_f32(*n as f32))
201 } else {
202 Some(f16::from_f32(*n as f32))
203 }
204 }
205 LiteralInstance::I32(n) => {
206 if stage == ShaderStage::Const {
208 f16::from_i32(*n)
209 } else {
210 Some(f16::from_f32(*n as f32))
211 }
212 }
213 LiteralInstance::U32(n) => {
214 if stage == ShaderStage::Const {
216 f16::from_u32(*n)
217 } else {
218 Some(f16::from_f32(*n as f32))
219 }
220 }
221 LiteralInstance::F32(n) => {
222 if stage == ShaderStage::Const {
224 let range = -65504.0..=65504.0;
225 range.contains(n).then_some(f16::from_f32(*n))
226 } else {
227 Some(f16::from_f32(*n))
228 }
229 }
230 LiteralInstance::F16(n) => Some(*n), #[cfg(feature = "naga-ext")]
232 LiteralInstance::I64(n) => {
233 if stage == ShaderStage::Const {
235 let range = -65504..=65504;
236 range.contains(n).then_some(f16::from_f32(*n as f32))
237 } else {
238 Some(f16::from_f32(*n as f32))
239 }
240 }
241 #[cfg(feature = "naga-ext")]
242 LiteralInstance::U64(n) => {
243 if stage == ShaderStage::Const {
245 f16::from_u64(*n)
246 } else {
247 Some(f16::from_f32(*n as f32))
248 }
249 }
250 #[cfg(feature = "naga-ext")]
251 LiteralInstance::F64(n) => {
252 if stage == ShaderStage::Const {
254 let range = -65504.0..=65504.0;
255 range.contains(n).then_some(f16::from_f32(*n as f32))
256 } else {
257 Some(f16::from_f32(*n as f32))
258 }
259 }
260 }
261 .ok_or(E::ConvOverflow(*l, Type::F16))?;
262 Ok(LiteralInstance::F16(val).into())
263 }
264 _ => Err(E::Builtin("f16 constructor expects a scalar argument")),
265 }
266}
267
268pub fn i64(_a1: &Instance) -> Result<Instance, E> {
272 Err(E::Todo("i64".to_string()))
273}
274
275pub fn u64(_a1: &Instance) -> Result<Instance, E> {
279 Err(E::Todo("u64".to_string()))
280}
281
282pub fn f64(_a1: &Instance, _stage: ShaderStage) -> Result<Instance, E> {
286 Err(E::Todo("f64".to_string()))
287}
288
289pub fn mat_t(
293 c: usize,
294 r: usize,
295 tplt_ty: &Type,
296 args: &[Instance],
297 stage: ShaderStage,
298) -> Result<Instance, E> {
299 if let [Instance::Mat(m)] = args {
301 if m.c() != c || m.r() != r {
302 return Err(E::Conversion(
303 m.ty(),
304 Type::Mat(c as u8, r as u8, Box::new(tplt_ty.clone())),
305 ));
306 }
307
308 let conv_fn = match tplt_ty {
309 Type::F32 => f32,
310 Type::F16 => f16,
311 _ => return Err(E::Builtin("matrix type must be a f32 or f16")),
312 };
313
314 let comps = m
315 .iter_cols()
316 .map(|v| {
317 v.unwrap_vec_ref()
318 .iter()
319 .map(|n| conv_fn(n, stage))
320 .collect::<Result<Vec<_>, _>>()
321 .map(|s| Instance::Vec(VecInstance::new(s)))
322 })
323 .collect::<Result<Vec<_>, _>>()?;
324
325 Ok(MatInstance::from_cols(comps).into())
326 } else {
327 let ty = args
328 .first()
329 .ok_or(E::Builtin("matrix constructor expects arguments"))?
330 .ty();
331 let ty = ty
332 .convert_inner_to(tplt_ty)
333 .ok_or(E::Conversion(ty.inner_ty(), tplt_ty.clone()))?;
334 let args =
335 convert_all_to(args, &ty).ok_or(E::Builtin("matrix components are incompatible"))?;
336
337 if ty.is_vec() {
339 if args.len() != c {
340 return Err(E::ParamCount(format!("mat{c}x{r}"), c, args.len()));
341 }
342
343 Ok(MatInstance::from_cols(args).into())
344 }
345 else if ty.is_float() {
347 if args.len() != c * r {
348 return Err(E::ParamCount(format!("mat{c}x{r}"), c * r, args.len()));
349 }
350
351 let args = args
352 .chunks(r)
353 .map(|v| Instance::Vec(VecInstance::new(v.to_vec())))
354 .collect_vec();
355
356 Ok(MatInstance::from_cols(args).into())
357 } else {
358 Err(E::Builtin(
359 "matrix constructor expects float or vector of float arguments",
360 ))
361 }
362 }
363}
364
365pub fn mat(c: usize, r: usize, args: &[Instance]) -> Result<Instance, E> {
369 if let [Instance::Mat(m)] = args {
371 if m.c() != c || m.r() != r {
372 let ty2 = Type::Mat(c as u8, r as u8, m.inner_ty().into());
373 return Err(E::Conversion(m.ty(), ty2));
374 }
375 Ok(m.clone().into())
377 } else {
378 let tys = args.iter().map(|a| a.ty()).collect_vec();
379 let ty = convert_all_ty(&tys).ok_or(E::Builtin("matrix components are incompatible"))?;
380 let mut inner_ty = ty.inner_ty();
381
382 if inner_ty.is_abstract_int() {
383 inner_ty = Type::F32;
385 } else if !inner_ty.is_float() {
386 return Err(E::Builtin(
387 "matrix constructor expects float or vector of float arguments",
388 ));
389 }
390
391 let args = convert_all_inner_to(args, &inner_ty)
392 .ok_or(E::Builtin("matrix components are incompatible"))?;
393
394 if ty.is_vec() {
396 if args.len() != c {
397 return Err(E::ParamCount(format!("mat{c}x{r}"), c, args.len()));
398 }
399
400 Ok(MatInstance::from_cols(args).into())
401 }
402 else if ty.is_float() || ty.is_abstract_int() {
404 if args.len() != c * r {
405 return Err(E::ParamCount(format!("mat{c}x{r}"), c * r, args.len()));
406 }
407 let args = args
408 .chunks(r)
409 .map(|v| Instance::Vec(VecInstance::new(v.to_vec())))
410 .collect_vec();
411
412 Ok(MatInstance::from_cols(args).into())
413 } else {
414 Err(E::Builtin(
415 "matrix constructor expects float or vector of float arguments",
416 ))
417 }
418 }
419}
420
421pub fn vec_t(
425 n: usize,
426 tplt_ty: &Type,
427 args: &[Instance],
428 stage: ShaderStage,
429) -> Result<Instance, E> {
430 if let [Instance::Literal(l)] = args {
432 let val = l
433 .convert_to(tplt_ty)
434 .map(Instance::Literal)
435 .ok_or_else(|| E::ParamType(tplt_ty.clone(), l.ty()))?;
436 let comps = (0..n).map(|_| val.clone()).collect_vec();
437 Ok(VecInstance::new(comps).into())
438 }
439 else if let [Instance::Vec(v)] = args {
441 let ty = Type::Vec(n as u8, Box::new(tplt_ty.clone()));
442 if v.n() != n {
443 return Err(E::Conversion(v.ty(), ty));
444 }
445
446 let conv_fn = match ty.inner_ty() {
447 Type::Bool => |n, _| bool(n),
448 Type::I32 => |n, _| i32(n),
449 Type::U32 => |n, _| u32(n),
450 Type::F32 => |n, stage| f32(n, stage),
451 Type::F16 => |n, stage| f16(n, stage),
452 _ => return Err(E::Builtin("vector type must be a scalar")),
453 };
454
455 let comps = v
456 .iter()
457 .map(|n| conv_fn(n, stage))
458 .collect::<Result<Vec<_>, _>>()?;
459
460 Ok(VecInstance::new(comps).into())
461 }
462 else {
464 let args = args
466 .iter()
467 .flat_map(|a| -> Box<dyn Iterator<Item = &Instance>> {
468 match a {
469 Instance::Vec(v) => Box::new(v.iter()),
470 _ => Box::new(std::iter::once(a)),
471 }
472 })
473 .collect_vec();
474 if args.len() != n {
475 return Err(E::ParamCount(format!("vec{n}"), n, args.len()));
476 }
477
478 let comps = args
479 .iter()
480 .map(|a| {
481 a.convert_inner_to(tplt_ty)
482 .ok_or_else(|| E::ParamType(tplt_ty.clone(), a.ty()))
483 })
484 .collect::<Result<Vec<_>, _>>()?;
485
486 Ok(VecInstance::new(comps).into())
487 }
488}
489
490pub fn vec(n: usize, args: &[Instance]) -> Result<Instance, E> {
494 if let [Instance::Literal(l)] = args {
496 let ty = l.ty();
497 if !ty.is_scalar() {
498 return Err(E::Builtin("vec constructor expects scalar arguments"));
499 }
500 let val = Instance::Literal(*l);
501 let comps = (0..n).map(|_| val.clone()).collect_vec();
502 Ok(VecInstance::new(comps).into())
503 }
504 else if let [Instance::Vec(v)] = args {
506 if v.n() != n {
507 let ty = v.ty();
508 let ty2 = Type::Vec(n as u8, ty.inner_ty().into());
509 return Err(E::Conversion(ty, ty2));
510 }
511 Ok(v.clone().into())
513 }
514 else if !args.is_empty() {
516 let args = args
518 .iter()
519 .flat_map(|a| -> Box<dyn Iterator<Item = &Instance>> {
520 match a {
521 Instance::Vec(v) => Box::new(v.iter()),
522 _ => Box::new(std::iter::once(a)),
523 }
524 })
525 .cloned()
526 .collect_vec();
527 if args.len() != n {
528 return Err(E::ParamCount(format!("vec{n}"), n, args.len()));
529 }
530
531 let comps = convert_all(&args).ok_or(E::Builtin("vector components are incompatible"))?;
532
533 if !comps.first().unwrap().ty().is_scalar() {
534 return Err(E::Builtin("vec constructor expects scalar arguments"));
535 }
536 Ok(VecInstance::new(comps).into())
537 }
538 else {
540 VecInstance::zero_value(n as u8, &Type::AbstractInt).map(Into::into)
541 }
542}
543
544pub fn struct_ctor(struct_ty: &StructType, args: &[Instance]) -> Result<StructInstance, E> {
546 if args.is_empty() {
547 return StructInstance::zero_value(struct_ty);
548 }
549
550 if args.len() != struct_ty.members.len() {
551 return Err(E::ParamCount(
552 struct_ty.name.clone(),
553 struct_ty.members.len(),
554 args.len(),
555 ));
556 }
557
558 let members = struct_ty
559 .members
560 .iter()
561 .zip(args)
562 .map(|(m_ty, inst)| {
563 let inst = inst
564 .convert_to(&m_ty.ty)
565 .ok_or_else(|| E::ParamType(m_ty.ty.clone(), inst.ty()))?;
566 Ok(inst)
567 })
568 .collect::<Result<Vec<_>, E>>()?;
569
570 Ok(StructInstance::new(struct_ty.clone(), members))
571}
572
573pub fn typecheck_struct_ctor(struct_ty: &StructType, args: &[Type]) -> Result<(), E> {
577 if args.is_empty() {
578 return Ok(());
580 }
581
582 if args.len() != struct_ty.members.len() {
583 return Err(E::ParamCount(
584 struct_ty.name.clone(),
585 struct_ty.members.len(),
586 args.len(),
587 ));
588 }
589
590 for (m_ty, a_ty) in struct_ty.members.iter().zip(args) {
591 if !a_ty.is_convertible_to(&m_ty.ty) {
592 return Err(E::ParamType(m_ty.ty.clone(), a_ty.ty()));
593 }
594 }
595
596 Ok(())
597}
598
599fn array_ctor_ty_t(tplt_ty: &Type, tplt_n: usize, args: &[Type]) -> Result<Type, E> {
607 if let Some(arg) = args.iter().find(|arg| !arg.is_convertible_to(tplt_ty)) {
608 Err(E::Conversion(arg.clone(), tplt_ty.clone()))
609 } else {
610 Ok(Type::Array(Box::new(tplt_ty.clone()), Some(tplt_n)))
611 }
612}
613
614fn array_ctor_ty(args: &[Type]) -> Result<Type, E> {
618 let ty = convert_all_ty(args).ok_or(E::Builtin("array elements are incompatible"))?;
619 Ok(Type::Array(Box::new(ty.clone()), Some(args.len())))
620}
621
622fn mat_ctor_ty_t(c: u8, r: u8, tplt_ty: &Type, args: &[Type]) -> Result<Type, E> {
626 if let [ty @ Type::Mat(c2, r2, _)] = args {
628 if *c2 != c || *r2 != r {
630 return Err(E::Conversion(
631 ty.clone(),
632 Type::Mat(c, r, Box::new(tplt_ty.clone())),
633 ));
634 }
635 } else {
636 if args.is_empty() {
637 return Err(E::Builtin("matrix constructor expects arguments"));
638 }
639 let ty = convert_all_ty(args).ok_or(E::Builtin("matrix components are incompatible"))?;
640 let ty = ty
641 .convert_inner_to(tplt_ty)
642 .ok_or(E::Conversion(ty.inner_ty(), tplt_ty.clone()))?;
643
644 if ty.is_vec() {
646 if args.len() != c as usize {
647 return Err(E::ParamCount(format!("mat{c}x{r}"), c as usize, args.len()));
648 }
649 }
650 else if ty.is_float() {
652 let n = c as usize * r as usize;
653 if args.len() != n {
654 return Err(E::ParamCount(format!("mat{c}x{r}"), n, args.len()));
655 }
656 } else {
657 return Err(E::Builtin(
658 "matrix constructor expects float or vector of float arguments",
659 ));
660 }
661 }
662
663 Ok(Type::Mat(c, r, Box::new(tplt_ty.clone())))
664}
665
666fn mat_ctor_ty(c: u8, r: u8, args: &[Type]) -> Result<Type, E> {
670 if let [ty @ Type::Mat(c2, r2, ty2)] = args {
672 if *c2 != c || *r2 != r {
674 return Err(E::Conversion(ty.clone(), Type::Mat(c, r, ty2.clone())));
675 }
676 Ok(ty.clone())
677 } else {
678 let ty = convert_all_ty(args).ok_or(E::Builtin("matrix components are incompatible"))?;
679 let mut inner_ty = ty.inner_ty();
680
681 if inner_ty.is_abstract_int() {
682 inner_ty = Type::F32;
684 } else if !inner_ty.is_float() {
685 return Err(E::Builtin(
686 "matrix constructor expects float or vector of float arguments",
687 ));
688 }
689
690 if ty.is_vec() {
692 if args.len() != c as usize {
693 return Err(E::ParamCount(format!("mat{c}x{r}"), c as usize, args.len()));
694 }
695 }
696 else if ty.is_float() || ty.is_abstract_int() {
698 let n = c as usize * r as usize;
699 if args.len() != n {
700 return Err(E::ParamCount(format!("mat{c}x{r}"), n, args.len()));
701 }
702 } else {
703 return Err(E::Builtin(
704 "matrix constructor expects float or vector of float arguments",
705 ));
706 }
707
708 Ok(Type::Mat(c, r, inner_ty.into()))
709 }
710}
711
712fn vec_ctor_ty_t(n: u8, tplt_ty: &Type, args: &[Type]) -> Result<Type, E> {
716 if let [arg] = args {
717 if arg.is_scalar() {
719 if !arg.is_convertible_to(tplt_ty) {
720 return Err(E::Conversion(arg.clone(), tplt_ty.clone()));
721 }
722 }
723 else if arg.is_vec() {
725 } else {
727 return Err(E::Conversion(arg.clone(), tplt_ty.clone()));
728 }
729 }
730 else {
732 let n2 = args
734 .iter()
735 .try_fold(0, |acc, arg| match arg {
736 ty if ty.is_scalar() => ty.is_convertible_to(tplt_ty).then_some(acc + 1),
737 Type::Vec(n, ty) => ty.is_convertible_to(tplt_ty).then_some(acc + n),
738 _ => None,
739 })
740 .ok_or(E::Builtin(
741 "vector constructor expects scalar or vector arguments",
742 ))?;
743 if n2 != n {
744 return Err(E::ParamCount(format!("vec{n}"), n as usize, args.len()));
745 }
746 }
747
748 Ok(Type::Vec(n, Box::new(tplt_ty.clone())))
749}
750
751fn vec_ctor_ty(n: u8, args: &[Type]) -> Result<Type, E> {
755 if let [arg] = args {
756 if arg.is_scalar() {
758 }
759 else if arg.is_vec() {
761 } else {
763 return Err(E::Builtin(
764 "vector constructor expects scalar or vector arguments",
765 ));
766 }
767 Ok(Type::Vec(n, arg.inner_ty().into()))
768 }
769 else if !args.is_empty() {
771 let n2 = args
773 .iter()
774 .try_fold(0, |acc, arg| match arg {
775 ty if ty.is_scalar() => Some(acc + 1),
776 Type::Vec(n, _) => Some(acc + n),
777 _ => None,
778 })
779 .ok_or(E::Builtin(
780 "vector constructor expects scalar or vector arguments",
781 ))?;
782 if n2 != n {
783 return Err(E::ParamCount(format!("vec{n}"), n as usize, args.len()));
784 }
785
786 let tys = args.iter().map(|arg| arg.inner_ty()).collect_vec();
787 let ty = convert_all_ty(&tys).ok_or(E::Builtin("vector components are incompatible"))?;
788
789 Ok(Type::Vec(n, ty.clone().into()))
790 }
791 else {
793 Ok(Type::Vec(n, Type::AbstractInt.into()))
794 }
795}
796
797pub fn type_ctor(name: &str, tplt: Option<&[TpltParam]>, args: &[Type]) -> Result<Type, E> {
805 match (name, tplt, args) {
806 ("array", Some(t), []) => Ok(ArrayTemplate::parse(t)?.ty()),
807 ("array", Some(t), a) => {
808 let tplt = ArrayTemplate::parse(t)?;
809 array_ctor_ty_t(
810 &tplt.inner_ty(),
811 tplt.n().ok_or(E::TemplateArgs("array"))?,
812 a,
813 )
814 }
815 ("array", None, _) => array_ctor_ty(args),
816 ("bool", None, []) => Ok(Type::Bool),
817 ("bool", None, [a]) if a.is_scalar() => Ok(Type::Bool),
818 ("i32", None, []) => Ok(Type::I32),
819 ("i32", None, [a]) if a.is_scalar() => Ok(Type::I32),
820 ("u32", None, []) => Ok(Type::U32),
821 ("u32", None, [a]) if a.is_scalar() => Ok(Type::U32),
822 ("f32", None, []) => Ok(Type::F32),
823 ("f32", None, [a]) if a.is_scalar() => Ok(Type::F32),
824 ("f16", None, []) => Ok(Type::F16),
825 ("f16", None, [a]) if a.is_scalar() => Ok(Type::F16),
826 ("mat2x2", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(2, 2)),
827 ("mat2x2", Some(t), _) => mat_ctor_ty_t(2, 2, MatTemplate::parse(t)?.inner_ty(), args),
828 ("mat2x2", None, _) => mat_ctor_ty(2, 2, args),
829 ("mat2x3", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(2, 3)),
830 ("mat2x3", Some(t), _) => mat_ctor_ty_t(2, 3, MatTemplate::parse(t)?.inner_ty(), args),
831 ("mat2x3", None, _) => mat_ctor_ty(2, 3, args),
832 ("mat2x4", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(2, 4)),
833 ("mat2x4", Some(t), _) => mat_ctor_ty_t(2, 4, MatTemplate::parse(t)?.inner_ty(), args),
834 ("mat2x4", None, _) => mat_ctor_ty(2, 4, args),
835 ("mat3x2", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(3, 2)),
836 ("mat3x2", Some(t), _) => mat_ctor_ty_t(3, 2, MatTemplate::parse(t)?.inner_ty(), args),
837 ("mat3x2", None, _) => mat_ctor_ty(3, 2, args),
838 ("mat3x3", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(3, 3)),
839 ("mat3x3", Some(t), _) => mat_ctor_ty_t(3, 3, MatTemplate::parse(t)?.inner_ty(), args),
840 ("mat3x3", None, _) => mat_ctor_ty(3, 3, args),
841 ("mat3x4", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(3, 4)),
842 ("mat3x4", Some(t), _) => mat_ctor_ty_t(3, 4, MatTemplate::parse(t)?.inner_ty(), args),
843 ("mat3x4", None, _) => mat_ctor_ty(3, 4, args),
844 ("mat4x2", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(4, 2)),
845 ("mat4x2", Some(t), _) => mat_ctor_ty_t(4, 2, MatTemplate::parse(t)?.inner_ty(), args),
846 ("mat4x2", None, _) => mat_ctor_ty(4, 2, args),
847 ("mat4x3", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(4, 3)),
848 ("mat4x3", Some(t), _) => mat_ctor_ty_t(4, 3, MatTemplate::parse(t)?.inner_ty(), args),
849 ("mat4x3", None, _) => mat_ctor_ty(4, 3, args),
850 ("mat4x4", Some(t), []) => Ok(MatTemplate::parse(t)?.ty(4, 4)),
851 ("mat4x4", Some(t), _) => mat_ctor_ty_t(4, 4, MatTemplate::parse(t)?.inner_ty(), args),
852 ("mat4x4", None, _) => mat_ctor_ty(4, 4, args),
853 ("vec2", Some(t), []) => Ok(VecTemplate::parse(t)?.ty(2)),
854 ("vec2", Some(t), _) => vec_ctor_ty_t(2, VecTemplate::parse(t)?.inner_ty(), args),
855 ("vec2", None, _) => vec_ctor_ty(2, args),
856 ("vec3", Some(t), []) => Ok(VecTemplate::parse(t)?.ty(3)),
857 ("vec3", Some(t), _) => vec_ctor_ty_t(3, VecTemplate::parse(t)?.inner_ty(), args),
858 ("vec3", None, _) => vec_ctor_ty(3, args),
859 ("vec4", Some(t), []) => Ok(VecTemplate::parse(t)?.ty(4)),
860 ("vec4", Some(t), _) => vec_ctor_ty_t(4, VecTemplate::parse(t)?.inner_ty(), args),
861 ("vec4", None, _) => vec_ctor_ty(4, args),
862 #[cfg(feature = "naga-ext")]
863 ("i64", None, []) => Ok(Type::I64),
864 #[cfg(feature = "naga-ext")]
865 ("i64", None, [a]) if a.is_scalar() => Ok(Type::I64),
866 #[cfg(feature = "naga-ext")]
867 ("u64", None, []) => Ok(Type::U64),
868 #[cfg(feature = "naga-ext")]
869 ("u64", None, [a]) if a.is_scalar() => Ok(Type::U64),
870 #[cfg(feature = "naga-ext")]
871 ("f64", None, []) => Ok(Type::F64),
872 #[cfg(feature = "naga-ext")]
873 ("f64", None, [a]) if a.is_scalar() => Ok(Type::F64),
874 _ => Err(E::Signature(CallSignature {
875 name: name.to_string(),
876 tplt: tplt.map(|t| t.to_vec()),
877 args: args.to_vec(),
878 })),
879 }
880}
881
882impl Instance {
888 pub fn zero_value(ty: &Type) -> Result<Self, E> {
890 match ty {
891 Type::Bool => Ok(LiteralInstance::Bool(false).into()),
892 Type::AbstractInt => Ok(LiteralInstance::AbstractInt(0).into()),
893 Type::AbstractFloat => Ok(LiteralInstance::AbstractFloat(0.0).into()),
894 Type::I32 => Ok(LiteralInstance::I32(0).into()),
895 Type::U32 => Ok(LiteralInstance::U32(0).into()),
896 Type::F32 => Ok(LiteralInstance::F32(0.0).into()),
897 Type::F16 => Ok(LiteralInstance::F16(f16::zero()).into()),
898 Type::Struct(s) => StructInstance::zero_value(s).map(Into::into),
899 Type::Array(a_ty, Some(n)) => ArrayInstance::zero_value(*n, a_ty).map(Into::into),
900 Type::Array(_, None) => Err(E::NotConstructible(ty.clone())),
901 Type::Vec(n, v_ty) => VecInstance::zero_value(*n, v_ty).map(Into::into),
902 Type::Mat(c, r, m_ty) => MatInstance::zero_value(*c, *r, m_ty).map(Into::into),
903 Type::Atomic(_)
904 | Type::Ptr(_, _, _)
905 | Type::Ref(_, _, _)
906 | Type::Texture(_)
907 | Type::Sampler(_) => Err(E::NotConstructible(ty.clone())),
908 #[cfg(feature = "naga-ext")]
909 Type::I64 => Ok(LiteralInstance::I64(0).into()),
910 #[cfg(feature = "naga-ext")]
911 Type::U64 => Ok(LiteralInstance::U64(0).into()),
912 #[cfg(feature = "naga-ext")]
913 Type::F64 => Ok(LiteralInstance::F64(0.0).into()),
914 #[cfg(feature = "naga-ext")]
915 Type::BindingArray(_, _) => Err(E::NotConstructible(ty.clone())),
916 #[cfg(feature = "naga-ext")]
917 Type::RayQuery(_) => Err(E::NotConstructible(ty.clone())),
918 #[cfg(feature = "naga-ext")]
919 Type::AccelerationStructure(_) => Err(E::NotConstructible(ty.clone())),
920 }
921 }
922}
923
924impl LiteralInstance {
925 pub fn zero_value(ty: &Type) -> Result<Self, E> {
927 match ty {
928 Type::Bool => Ok(LiteralInstance::Bool(false)),
929 Type::AbstractInt => Ok(LiteralInstance::AbstractInt(0)),
930 Type::AbstractFloat => Ok(LiteralInstance::AbstractFloat(0.0)),
931 Type::I32 => Ok(LiteralInstance::I32(0)),
932 Type::U32 => Ok(LiteralInstance::U32(0)),
933 Type::F32 => Ok(LiteralInstance::F32(0.0)),
934 Type::F16 => Ok(LiteralInstance::F16(f16::zero())),
935 #[cfg(feature = "naga-ext")]
936 Type::I64 => Ok(LiteralInstance::I64(0)),
937 #[cfg(feature = "naga-ext")]
938 Type::U64 => Ok(LiteralInstance::U64(0)),
939 #[cfg(feature = "naga-ext")]
940 Type::F64 => Ok(LiteralInstance::F64(0.0)),
941 _ => Err(E::NotScalar(ty.clone())),
942 }
943 }
944}
945
946impl StructInstance {
947 pub fn zero_value(s: &StructType) -> Result<Self, E> {
949 let members = s
950 .members
951 .iter()
952 .map(|mem| {
953 let val = Instance::zero_value(&mem.ty)?;
954 Ok(val)
955 })
956 .collect::<Result<Vec<_>, _>>()?;
957
958 Ok(StructInstance::new(s.clone(), members))
959 }
960}
961
962impl ArrayInstance {
963 pub fn zero_value(n: usize, ty: &Type) -> Result<Self, E> {
965 let zero = Instance::zero_value(ty)?;
966 let comps = (0..n).map(|_| zero.clone()).collect_vec();
967 Ok(ArrayInstance::new(comps, false))
968 }
969}
970
971impl VecInstance {
972 pub fn zero_value(n: u8, ty: &Type) -> Result<Self, E> {
974 let zero = Instance::Literal(LiteralInstance::zero_value(ty)?);
975 let comps = (0..n).map(|_| zero.clone()).collect_vec();
976 Ok(VecInstance::new(comps))
977 }
978}
979
980impl MatInstance {
981 pub fn zero_value(c: u8, r: u8, ty: &Type) -> Result<Self, E> {
983 let zero = Instance::Literal(LiteralInstance::zero_value(ty)?);
984 let zero_col = Instance::Vec(VecInstance::new((0..r).map(|_| zero.clone()).collect_vec()));
985 let comps = (0..c).map(|_| zero_col.clone()).collect_vec();
986 Ok(MatInstance::from_cols(comps))
987 }
988}