1use half::f16;
9use itertools::Itertools;
10use num_traits::{FromPrimitive, ToPrimitive};
11
12use crate::{
13 Error, Instance,
14 inst::{ArrayInstance, LiteralInstance, MatInstance, StructInstance, VecInstance},
15 syntax::AccessMode,
16 ty::{Ty, Type},
17};
18
19pub trait Convert: Sized + Clone + Ty {
20 fn convert_to(&self, ty: &Type) -> Option<Self>;
26
27 fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
39 self.convert_to(ty)
40 }
41
42 fn concretize(&self) -> Option<Self> {
46 self.convert_to(&self.ty().concretize())
47 }
48}
49
50impl Convert for Type {
51 fn convert_to(&self, ty: &Type) -> Option<Self> {
52 self.is_convertible_to(ty).then_some(ty.clone())
53 }
54 fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
55 match self {
56 Type::Array(inner, n) => inner
57 .convert_to(ty)
58 .map(|inner| Type::Array(inner.into(), *n)),
59 Type::Vec(n, inner) => inner
60 .convert_to(ty)
61 .map(|inner| Type::Vec(*n, inner.into())),
62 Type::Mat(c, r, inner) => inner
63 .convert_to(ty)
64 .map(|inner| Type::Mat(*c, *r, inner.into())),
65 Type::Atomic(_) => (self == ty).then_some(ty.clone()),
66 Type::Ptr(_, _, _) => (self == ty).then_some(ty.clone()),
67 _ => self.convert_to(ty), }
69 }
70 fn concretize(&self) -> Option<Self> {
71 Some(self.concretize())
72 }
73}
74
75impl Type {
76 pub fn is_convertible_to(&self, ty: &Type) -> bool {
77 conversion_rank(self, ty).is_some()
78 }
79 pub fn concretize(&self) -> Self {
80 match self {
81 Self::AbstractInt => Type::I32,
82 Self::AbstractFloat => Type::F32,
83 Self::Array(ty, n) => Type::Array(ty.concretize().into(), *n),
84 Self::Vec(n, ty) => Type::Vec(*n, ty.concretize().into()),
85 Self::Mat(c, r, ty) if ty.is_abstract() => Type::Mat(*c, *r, Type::F32.into()),
86 _ => self.clone(),
87 }
88 }
89
90 pub fn loaded(self) -> Self {
94 if let Type::Ref(_, ty, _) = self {
95 ty.loaded()
96 } else {
97 self
98 }
99 }
100}
101
102impl Instance {
103 pub fn is_convertible_to(&self, ty: &Type) -> bool {
104 self.ty().is_convertible_to(ty)
105 }
106
107 pub fn loaded(mut self) -> Result<Self, Error> {
111 while let Instance::Ref(r) = self {
112 self = r.read()?.to_owned();
113 }
114 Ok(self)
115 }
116}
117
118impl LiteralInstance {
119 pub fn is_infinite(&self) -> bool {
120 match self {
121 LiteralInstance::Bool(_) => false,
122 LiteralInstance::AbstractInt(_) => false,
123 LiteralInstance::AbstractFloat(n) => n.is_infinite(),
124 LiteralInstance::I32(_) => false,
125 LiteralInstance::U32(_) => false,
126 LiteralInstance::F32(n) => n.is_infinite(),
127 LiteralInstance::F16(n) => n.is_infinite(),
128 #[cfg(feature = "naga-ext")]
129 LiteralInstance::I64(_) => false,
130 #[cfg(feature = "naga-ext")]
131 LiteralInstance::U64(_) => false,
132 #[cfg(feature = "naga-ext")]
133 LiteralInstance::F64(n) => n.is_infinite(),
134 }
135 }
136 pub fn is_finite(&self) -> bool {
137 !self.is_infinite()
138 }
139}
140
141impl Convert for LiteralInstance {
142 fn convert_to(&self, ty: &Type) -> Option<Self> {
143 if ty == &self.ty() {
144 return Some(*self);
145 }
146
147 match (self, ty) {
152 (Self::AbstractInt(n), Type::AbstractFloat) => n.to_f64().map(Self::AbstractFloat),
153 (Self::AbstractInt(n), Type::I32) => n.to_i32().map(Self::I32),
154 (Self::AbstractInt(n), Type::U32) => n.to_u32().map(Self::U32),
155 (Self::AbstractInt(n), Type::F32) => n.to_f32().map(Self::F32),
156 (Self::AbstractInt(n), Type::F16) => f16::from_i64(*n).map(Self::F16),
157 (Self::AbstractFloat(n), Type::F32) => n.to_f32().map(Self::F32),
158 (Self::AbstractFloat(n), Type::F16) => Some(Self::F16(f16::from_f64(*n))),
159 _ => None,
160 }
161 .and_then(|n| n.is_finite().then_some(n))
162 }
163}
164
165impl Convert for ArrayInstance {
166 fn convert_to(&self, ty: &Type) -> Option<Self> {
167 if let Type::Array(c_ty, Some(n)) = ty {
168 if *n == self.n() {
169 self.convert_inner_to(c_ty)
170 } else {
171 None
172 }
173 } else if let Type::Array(c_ty, None) = ty {
174 self.convert_inner_to(c_ty)
175 } else {
176 None
177 }
178 }
179 fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
180 let components = self
181 .iter()
182 .map(|c| c.convert_to(ty))
183 .collect::<Option<Vec<_>>>()?;
184 Some(ArrayInstance::new(components, self.runtime_sized))
185 }
186}
187
188impl Convert for VecInstance {
189 fn convert_to(&self, ty: &Type) -> Option<Self> {
190 if let Type::Vec(n, c_ty) = ty {
191 if *n as usize == self.n() {
192 self.convert_inner_to(c_ty)
193 } else {
194 None
195 }
196 } else {
197 None
198 }
199 }
200 fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
201 let components = self
202 .iter()
203 .map(|c| c.convert_to(ty))
204 .collect::<Option<Vec<_>>>()?;
205 Some(VecInstance::new(components))
206 }
207}
208
209impl Convert for MatInstance {
210 fn convert_to(&self, ty: &Type) -> Option<Self> {
211 if let Type::Mat(c, r, c_ty) = ty {
212 if *c as usize == self.c() && *r as usize == self.r() {
213 self.convert_inner_to(c_ty)
214 } else {
215 None
216 }
217 } else {
218 None
219 }
220 }
221 fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
222 let components = self
223 .iter_cols()
224 .map(|c| c.convert_inner_to(ty))
225 .collect::<Option<Vec<_>>>()?;
226 Some(MatInstance::from_cols(components))
227 }
228}
229
230impl Convert for StructInstance {
231 fn convert_to(&self, ty: &Type) -> Option<Self> {
232 if &self.ty() == ty {
233 Some(self.clone())
234 } else if let Type::Struct(s2) = ty {
235 let s1 = &self.ty;
236 if s1.name.starts_with("__") && s2.name.starts_with("__") {
237 if s2.name.ends_with("f32") {
242 let members = self
243 .members
244 .iter()
245 .map(|inst| inst.convert_inner_to(&Type::F32))
246 .collect::<Option<Vec<_>>>()?;
247 Some(StructInstance::new((**s2).clone(), members))
248 } else if s2.name.ends_with("f16") {
249 let members = self
250 .members
251 .iter()
252 .map(|inst| inst.convert_inner_to(&Type::F16))
253 .collect::<Option<Vec<_>>>()?;
254 Some(StructInstance::new((**s2).clone(), members))
255 } else {
256 None
257 }
258 } else {
259 None
260 }
261 } else {
262 None
263 }
264 }
265}
266
267impl Convert for Instance {
268 fn convert_to(&self, ty: &Type) -> Option<Self> {
269 if &self.ty() == ty {
270 return Some(self.clone());
271 }
272 match self {
273 Self::Literal(l) => l.convert_to(ty).map(Self::Literal),
274 Self::Struct(s) => s.convert_to(ty).map(Self::Struct),
275 Self::Array(a) => a.convert_to(ty).map(Self::Array),
276 Self::Vec(v) => v.convert_to(ty).map(Self::Vec),
277 Self::Mat(m) => m.convert_to(ty).map(Self::Mat),
278 Self::Ptr(_) => None,
279 Self::Ref(r) => r.read().ok().and_then(|r| r.convert_to(ty)), Self::Atomic(_) => None,
281 Self::Deferred(_) => None,
282 }
283 }
284
285 fn convert_inner_to(&self, ty: &Type) -> Option<Self> {
286 match self {
287 Self::Literal(l) => l.convert_inner_to(ty).map(Self::Literal),
288 Self::Struct(_) => None,
289 Self::Array(a) => a.convert_inner_to(ty).map(Self::Array),
290 Self::Vec(v) => v.convert_inner_to(ty).map(Self::Vec),
291 Self::Mat(m) => m.convert_inner_to(ty).map(Self::Mat),
292 Self::Ptr(_) => None,
293 Self::Ref(r) => r.read().ok().and_then(|r| r.convert_inner_to(ty)), Self::Atomic(_) => None,
295 Self::Deferred(_) => None,
296 }
297 }
298}
299
300pub fn conversion_rank(ty1: &Type, ty2: &Type) -> Option<u32> {
302 match (ty1, ty2) {
304 (_, _) if ty1 == ty2 => Some(0),
305 (Type::Ref(_, ty1, AccessMode::Read | AccessMode::ReadWrite), ty2) if &**ty1 == ty2 => {
306 Some(0)
307 }
308 (Type::AbstractInt, Type::AbstractFloat) => Some(5),
309 (Type::AbstractInt, Type::I32) => Some(3),
310 (Type::AbstractInt, Type::U32) => Some(4),
311 (Type::AbstractInt, Type::F32) => Some(6),
312 (Type::AbstractInt, Type::F16) => Some(7),
313 (Type::AbstractFloat, Type::F32) => Some(1),
314 (Type::AbstractFloat, Type::F16) => Some(2),
315 (Type::Struct(s1), Type::Struct(s2)) => {
317 if s1.name.starts_with("__") && s1.name.ends_with("abstract") {
318 if s2.name.ends_with("f32") {
319 Some(1)
320 } else if s2.name.ends_with("f16") {
321 Some(2)
322 } else {
323 None
324 }
325 } else {
326 None
327 }
328 }
329 (Type::Array(ty1, n1), Type::Array(ty2, n2)) if n1 == n2 => conversion_rank(ty1, ty2),
330 (Type::Vec(n1, ty1), Type::Vec(n2, ty2)) if n1 == n2 => conversion_rank(ty1, ty2),
331 (Type::Mat(c1, r1, ty1), Type::Mat(c2, r2, ty2)) if c1 == c2 && r1 == r2 => {
332 conversion_rank(ty1, ty2)
333 }
334 _ => None,
335 }
336}
337
338pub fn convert<T: Convert + Ty + Clone>(i1: &T, i2: &T) -> Option<(T, T)> {
341 let (ty1, ty2) = (i1.ty(), i2.ty());
342 let ty = convert_ty(&ty1, &ty2)?;
343 let i1 = i1.convert_to(ty)?;
344 let i2 = i2.convert_to(ty)?;
345 Some((i1, i2))
346}
347
348pub fn convert_inner<T1: Convert + Ty + Clone, T2: Convert + Ty + Clone>(
350 i1: &T1,
351 i2: &T2,
352) -> Option<(T1, T2)> {
353 let (ty1, ty2) = (i1.inner_ty(), i2.inner_ty());
354 let ty = convert_ty(&ty1, &ty2)?;
355 let i1 = i1.convert_inner_to(ty)?;
356 let i2 = i2.convert_inner_to(ty)?;
357 Some((i1, i2))
358}
359
360pub fn convert_all<'a, T: Convert + Ty + Clone + 'a>(insts: &[T]) -> Option<Vec<T>> {
362 let tys = insts.iter().map(|i| i.ty()).collect_vec();
363 let ty = convert_all_ty(&tys)?;
364 convert_all_to(insts, ty)
365}
366
367pub fn convert_all_to<'a, T: Convert + Ty + Clone + 'a>(insts: &[T], ty: &Type) -> Option<Vec<T>> {
369 insts
370 .iter()
371 .map(|inst| inst.convert_to(ty))
372 .collect::<Option<Vec<_>>>()
373}
374
375pub fn convert_all_inner_to<'a, T: Convert + Ty + Clone + 'a>(
377 insts: &[T],
378 ty: &Type,
379) -> Option<Vec<T>> {
380 insts
381 .iter()
382 .map(|inst| inst.convert_inner_to(ty))
383 .collect::<Option<Vec<_>>>()
384}
385
386pub fn convert_ty<'a>(ty1: &'a Type, ty2: &'a Type) -> Option<&'a Type> {
389 conversion_rank(ty1, ty2)
390 .map(|_rank| ty2)
391 .or_else(|| conversion_rank(ty2, ty1).map(|_rank| ty1))
392}
393
394pub fn convert_all_ty<'a>(tys: impl IntoIterator<Item = &'a Type> + 'a) -> Option<&'a Type> {
396 tys.into_iter()
397 .map(Option::Some)
398 .reduce(|ty1, ty2| convert_ty(ty1?, ty2?))
399 .flatten()
400}