1use std::{
4 cell::{Ref, RefCell, RefMut},
5 ops::Index,
6 rc::Rc,
7};
8
9use half::f16;
10use itertools::Itertools;
11
12use crate::{
13 Error,
14 syntax::{AccessMode, AddressSpace},
15 ty::{StructType, Ty, Type},
16};
17
/// Short local alias for the crate-level [`Error`] type returned by the fallible functions below.
type E = Error;
19
/// A path selecting a sub-component of an instance.
///
/// A path is a chain of member and index accesses that always ends in [`MemView::Whole`];
/// every non-terminal link owns the remainder of the path in its box.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MemView {
    /// End of the path: the instance reached so far, taken as a whole.
    Whole,
    /// Access to the struct member with this name, followed by the rest of the path.
    Member(String, Box<MemView>),
    /// Access to the component at this index, followed by the rest of the path.
    Index(usize, Box<MemView>),
}

impl MemView {
    /// Extends the path with an access to member `comp` at its innermost end.
    pub fn append_member(&mut self, comp: String) {
        self.append_leaf(MemView::Member(comp, Box::new(MemView::Whole)));
    }

    /// Extends the path with an access to component `index` at its innermost end.
    pub fn append_index(&mut self, index: usize) {
        self.append_leaf(MemView::Index(index, Box::new(MemView::Whole)));
    }

    /// Walks down to the terminating `Whole` and replaces it with `leaf`.
    fn append_leaf(&mut self, leaf: MemView) {
        match self {
            MemView::Member(_, rest) | MemView::Index(_, rest) => rest.append_leaf(leaf),
            MemView::Whole => *self = leaf,
        }
    }
}
47
/// A runtime value: a literal, a composite (struct, array, vector, matrix), a pointer or
/// reference, an atomic, or a bare type.
#[derive(Clone, Debug, PartialEq)]
pub enum Instance {
    /// A literal value (bool, integer or float).
    Literal(LiteralInstance),
    /// A struct value.
    Struct(StructInstance),
    /// An array value.
    Array(ArrayInstance),
    /// A vector of 2 to 4 scalar components.
    Vec(VecInstance),
    /// A matrix made of 2 to 4 vector columns.
    Mat(MatInstance),
    /// A pointer, wrapping a reference.
    Ptr(PtrInstance),
    /// A reference into shared, interior-mutable storage.
    Ref(RefInstance),
    /// An atomic holding an `i32` or `u32` value.
    Atomic(AtomicInstance),
    /// Only a type, with no concrete contents.
    /// NOTE(review): presumably a value whose contents are not known yet — confirm with callers.
    Deferred(Type),
}
65
66impl Instance {
67 pub fn unwrap_literal(self) -> LiteralInstance {
68 match self {
69 Instance::Literal(field_0) => field_0,
70 val => panic!("called `Instance::unwrap_literal()` on a `{val}` value"),
71 }
72 }
73 pub fn unwrap_literal_ref(&self) -> &LiteralInstance {
74 match self {
75 Instance::Literal(field_0) => field_0,
76 val => panic!("called `Instance::unwrap_literal_ref()` on a `{val}` value"),
77 }
78 }
79 pub fn unwrap_vec(self) -> VecInstance {
80 match self {
81 Instance::Vec(field_0) => field_0,
82 val => panic!("called `Instance::unwrap_vec()` on a `{val}` value"),
83 }
84 }
85 pub fn unwrap_vec_ref(&self) -> &VecInstance {
86 match self {
87 Instance::Vec(field_0) => field_0,
88 val => panic!("called `Instance::unwrap_vec_ref()` on a `{val}` value"),
89 }
90 }
91 pub fn unwrap_vec_mut(&mut self) -> &mut VecInstance {
92 match self {
93 Instance::Vec(field_0) => field_0,
94 val => panic!("called `Instance::unwrap_vec_mut()` on a `{val}` value"),
95 }
96 }
97}
98
/// Implements `From<$inner> for $enum_name` by wrapping the value in `$enum_name::$variant`.
///
/// Usage: `from_enum!(MyEnum::Variant(InnerType));`
macro_rules! from_enum {
    ($enum_name:ident :: $variant:ident ( $inner:ident )) => {
        impl From<$inner> for $enum_name {
            fn from(inner: $inner) -> Self {
                Self::$variant(inner)
            }
        }
    };
}
108
// `From` conversions into `Instance` from each wrapped instance type, and from `Type`
// (which becomes `Instance::Deferred`).
from_enum!(Instance::Literal(LiteralInstance));
from_enum!(Instance::Struct(StructInstance));
from_enum!(Instance::Array(ArrayInstance));
from_enum!(Instance::Vec(VecInstance));
from_enum!(Instance::Mat(MatInstance));
from_enum!(Instance::Ptr(PtrInstance));
from_enum!(Instance::Ref(RefInstance));
from_enum!(Instance::Atomic(AtomicInstance));
from_enum!(Instance::Deferred(Type));
118
/// Implements `From<$source> for $target` by converting through `$via`.
///
/// Both `From<$source> for $via` and `From<$via> for $target` must already exist.
macro_rules! impl_transitive_from {
    ($source:ident => $via:ident => $target:ident) => {
        impl From<$source> for $target {
            fn from(value: $source) -> Self {
                let intermediate: $via = From::from(value);
                From::from(intermediate)
            }
        }
    };
}
131
132impl_transitive_from!(bool => LiteralInstance => Instance);
133impl_transitive_from!(i64 => LiteralInstance => Instance);
134impl_transitive_from!(f64 => LiteralInstance => Instance);
135impl_transitive_from!(i32 => LiteralInstance => Instance);
136impl_transitive_from!(u32 => LiteralInstance => Instance);
137impl_transitive_from!(f32 => LiteralInstance => Instance);
138
139impl Instance {
140 pub fn view(&self, view: &MemView) -> Result<&Instance, E> {
148 match view {
149 MemView::Whole => Ok(self),
150 MemView::Member(m, v) => match self {
151 Instance::Struct(s) => {
152 let inst = s.member(m).ok_or_else(|| E::Component(s.ty(), m.clone()))?;
153 inst.view(v)
154 }
155 _ => Err(E::Component(self.ty(), m.clone())),
156 },
157 MemView::Index(i, view) => match self {
158 Instance::Array(a) => {
159 let inst = a
160 .components
161 .get(*i)
162 .ok_or(E::OutOfBounds(*i, a.ty(), a.n()))?;
163 inst.view(view)
164 }
165 Instance::Vec(v) => {
166 let inst = v
167 .components
168 .get(*i)
169 .ok_or(E::OutOfBounds(*i, v.ty(), v.n()))?;
170 inst.view(view)
171 }
172 Instance::Mat(m) => {
173 let inst = m
174 .components
175 .get(*i)
176 .ok_or(E::OutOfBounds(*i, m.ty(), m.c()))?;
177 inst.view(view)
178 }
179 _ => Err(E::NotIndexable(self.ty())),
180 },
181 }
182 }
183
184 pub fn view_mut(&mut self, view: &MemView) -> Result<&mut Instance, E> {
188 let ty = self.ty();
189 match view {
190 MemView::Whole => Ok(self),
191 MemView::Member(m, v) => match self {
192 Instance::Struct(s) => {
193 let inst = s.member_mut(m).ok_or_else(|| E::Component(ty, m.clone()))?;
194 inst.view_mut(v)
195 }
196 _ => Err(E::Component(ty, m.clone())),
197 },
198 MemView::Index(i, view) => match self {
199 Instance::Array(a) => {
200 let n = a.n();
201 let inst = a.components.get_mut(*i).ok_or(E::OutOfBounds(*i, ty, n))?;
202 inst.view_mut(view)
203 }
204 Instance::Vec(v) => {
205 let n = v.n();
206 let inst = v.components.get_mut(*i).ok_or(E::OutOfBounds(*i, ty, n))?;
207 inst.view_mut(view)
208 }
209 Instance::Mat(m) => {
210 let c = m.c();
211 let inst = m.components.get_mut(*i).ok_or(E::OutOfBounds(*i, ty, c))?;
212 inst.view_mut(view)
213 }
214 _ => Err(E::NotIndexable(ty)),
215 },
216 }
217 }
218
219 pub fn write(&mut self, value: Instance) -> Result<Instance, E> {
223 if value.ty() != self.ty() {
224 return Err(E::WriteRefType(value.ty(), self.ty()));
225 }
226 let old = std::mem::replace(self, value);
227 Ok(old)
228 }
229}
230
/// A literal value: a bool, an integer or a float.
///
/// `AbstractInt` and `AbstractFloat` are stored with 64-bit precision (`i64` / `f64`).
/// The `I64`, `U64` and `F64` variants only exist with the `naga-ext` feature enabled.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LiteralInstance {
    Bool(bool),
    AbstractInt(i64),
    AbstractFloat(f64),
    I32(i32),
    U32(u32),
    F32(f32),
    F16(f16),
    #[cfg(feature = "naga-ext")]
    I64(i64),
    #[cfg(feature = "naga-ext")]
    U64(u64),
    #[cfg(feature = "naga-ext")]
    F64(f64),
}
248
// `From` conversions from Rust primitives into the non-extension literal variants.
// Note that `i64` and `f64` map to the abstract variants, not to `I64`/`F64`.
from_enum!(LiteralInstance::Bool(bool));
from_enum!(LiteralInstance::AbstractInt(i64));
from_enum!(LiteralInstance::AbstractFloat(f64));
from_enum!(LiteralInstance::I32(i32));
from_enum!(LiteralInstance::U32(u32));
from_enum!(LiteralInstance::F32(f32));
from_enum!(LiteralInstance::F16(f16));
256
257impl LiteralInstance {
258 pub fn unwrap_bool(self) -> bool {
259 match self {
260 LiteralInstance::Bool(field_0) => field_0,
261 val => panic!("called `LiteralInstance::unwrap_bool()` on a `{val}` value"),
262 }
263 }
264
265 pub fn unwrap_abstract_int(self) -> i64 {
266 match self {
267 LiteralInstance::AbstractInt(field_0) => field_0,
268 val => panic!("called `LiteralInstance::unwrap_abstract_int()` on a `{val}` value"),
269 }
270 }
271 pub fn unwrap_abstract_float(self) -> f64 {
272 match self {
273 LiteralInstance::AbstractFloat(field_0) => field_0,
274 val => panic!("called `LiteralInstance::unwrap_abstract_float()` on a `{val}` value"),
275 }
276 }
277 pub fn unwrap_i32(self) -> i32 {
278 match self {
279 LiteralInstance::I32(field_0) => field_0,
280 val => panic!("called `LiteralInstance::unwrap_i32()` on a `{val}` value"),
281 }
282 }
283 pub fn unwrap_u32(self) -> u32 {
284 match self {
285 LiteralInstance::U32(field_0) => field_0,
286 val => panic!("called `LiteralInstance::unwrap_u32()` on a `{val}` value"),
287 }
288 }
289 pub fn unwrap_f32(self) -> f32 {
290 match self {
291 LiteralInstance::F32(field_0) => field_0,
292 val => panic!("called `LiteralInstance::unwrap_f32()` on a `{val}` value"),
293 }
294 }
295 pub fn unwrap_f16(self) -> f16 {
296 match self {
297 LiteralInstance::F16(field_0) => field_0,
298 val => panic!("called `LiteralInstance::unwrap_f16()` on a `{val}` value"),
299 }
300 }
301 #[cfg(feature = "naga-ext")]
302 pub fn unwrap_i64(self) -> i64 {
303 match self {
304 LiteralInstance::I64(field_0) => field_0,
305 val => panic!("called `LiteralInstance::unwrap_i64()` on a `{val}` value"),
306 }
307 }
308 #[cfg(feature = "naga-ext")]
309 pub fn unwrap_u64(self) -> u64 {
310 match self {
311 LiteralInstance::U64(field_0) => field_0,
312 val => panic!("called `LiteralInstance::unwrap_u64()` on a `{val}` value"),
313 }
314 }
315 #[cfg(feature = "naga-ext")]
316 pub fn unwrap_f64(self) -> f64 {
317 match self {
318 LiteralInstance::F64(field_0) => field_0,
319 val => panic!("called `LiteralInstance::unwrap_f64()` on a `{val}` value"),
320 }
321 }
322}
323
324#[derive(Clone, Debug, PartialEq)]
328pub struct StructInstance {
329 pub ty: StructType,
330 pub members: Vec<Instance>,
331}
332
333impl StructInstance {
334 pub fn new(ty: StructType, members: Vec<Instance>) -> Self {
340 assert_eq!(ty.members.len(), members.len());
341 for (m, m_ty) in members.iter().zip(&ty.members) {
342 assert_eq!(m_ty.ty, m.ty());
343 }
344
345 Self { ty, members }
346 }
347 pub fn member(&self, name: &str) -> Option<&Instance> {
349 self.members
350 .iter()
351 .zip(&self.ty.members)
352 .find_map(|(inst, m_ty)| (m_ty.name == name).then_some(inst))
353 }
354 pub fn member_mut(&mut self, name: &str) -> Option<&mut Instance> {
356 self.members
357 .iter_mut()
358 .zip(&self.ty.members)
359 .find_map(|(inst, m_ty)| (m_ty.name == name).then_some(inst))
360 }
361 }
365
366#[derive(Clone, Debug, PartialEq, Default)]
370pub struct ArrayInstance {
371 components: Vec<Instance>,
372 pub runtime_sized: bool,
373}
374
375impl ArrayInstance {
376 pub fn new(components: Vec<Instance>, runtime_sized: bool) -> Self {
382 assert!(!components.is_empty());
383 assert!(components.iter().map(|c| c.ty()).all_equal());
384 Self {
385 components,
386 runtime_sized,
387 }
388 }
389 pub fn n(&self) -> usize {
391 self.components.len()
392 }
393 pub fn get(&self, i: usize) -> Option<&Instance> {
395 self.components.get(i)
396 }
397 pub fn get_mut(&mut self, i: usize) -> Option<&mut Instance> {
399 self.components.get_mut(i)
400 }
401 pub fn iter(&self) -> impl Iterator<Item = &Instance> {
402 self.components.iter()
403 }
404 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
405 self.components.iter_mut()
406 }
407 pub fn as_slice(&self) -> &[Instance] {
408 self.components.as_slice()
409 }
410}
411
412impl IntoIterator for ArrayInstance {
413 type Item = Instance;
414 type IntoIter = <Vec<Instance> as IntoIterator>::IntoIter;
415
416 fn into_iter(self) -> Self::IntoIter {
417 self.components.into_iter()
418 }
419}
420
421#[derive(Clone, Debug, PartialEq)]
425pub struct VecInstance {
426 components: ArrayInstance,
427}
428
429impl VecInstance {
430 pub fn new(components: Vec<Instance>) -> Self {
437 assert!((2..=4).contains(&components.len()));
438 let components = ArrayInstance::new(components, false);
439 assert!(components.inner_ty().is_scalar());
440 Self { components }
441 }
442 pub fn n(&self) -> usize {
444 self.components.n()
445 }
446 pub fn get(&self, i: usize) -> Option<&Instance> {
448 self.components.get(i)
449 }
450 pub fn get_mut(&mut self, i: usize) -> Option<&mut Instance> {
452 self.components.get_mut(i)
453 }
454 pub fn iter(&self) -> impl Iterator<Item = &Instance> {
455 self.components.iter()
456 }
457 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
458 self.components.iter_mut()
459 }
460 pub fn as_slice(&self) -> &[Instance] {
461 self.components.as_slice()
462 }
463}
464
465impl IntoIterator for VecInstance {
466 type Item = Instance;
467 type IntoIter = <ArrayInstance as IntoIterator>::IntoIter;
468
469 fn into_iter(self) -> Self::IntoIter {
470 self.components.into_iter()
471 }
472}
473
474impl Index<usize> for VecInstance {
475 type Output = Instance;
476
477 fn index(&self, index: usize) -> &Self::Output {
478 self.get(index).unwrap()
479 }
480}
481
482impl<T: Into<Instance>> From<[T; 2]> for VecInstance {
483 fn from(components: [T; 2]) -> Self {
484 Self::new(components.map(Into::into).to_vec())
485 }
486}
487impl<T: Into<Instance>> From<[T; 3]> for VecInstance {
488 fn from(components: [T; 3]) -> Self {
489 Self::new(components.map(Into::into).to_vec())
490 }
491}
492impl<T: Into<Instance>> From<[T; 4]> for VecInstance {
493 fn from(components: [T; 4]) -> Self {
494 Self::new(components.map(Into::into).to_vec())
495 }
496}
497
498#[derive(Clone, Debug, PartialEq)]
502pub struct MatInstance {
503 components: Vec<Instance>,
505}
506
507impl MatInstance {
508 pub fn from_cols(components: Vec<Instance>) -> Self {
517 assert!((2..=4).contains(&components.len()));
518 assert!(
519 components
520 .iter()
521 .map(|c| c.unwrap_vec_ref().n())
522 .all_equal(),
523 "MatInstance columns must have the same number for rows"
524 );
525 assert!(
526 components.iter().map(|c| c.ty()).all_equal(),
527 "MatInstance columns must have the same type"
528 );
529 Self { components }
530 }
531
532 pub fn r(&self) -> usize {
534 self.components.first().unwrap().unwrap_vec_ref().n()
535 }
536 pub fn c(&self) -> usize {
538 self.components.len()
539 }
540 pub fn col(&self, i: usize) -> Option<&Instance> {
542 self.components.get(i)
543 }
544 pub fn col_mut(&mut self, i: usize) -> Option<&mut Instance> {
546 self.components.get_mut(i)
547 }
548 pub fn get(&self, col: usize, row: usize) -> Option<&Instance> {
550 self.col(col).and_then(|v| v.unwrap_vec_ref().get(row))
551 }
552 pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut Instance> {
554 self.col_mut(i).and_then(|v| v.unwrap_vec_mut().get_mut(j))
555 }
556 pub fn iter_cols(&self) -> impl Iterator<Item = &Instance> {
557 self.components.iter()
558 }
559 pub fn iter_cols_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
560 self.components.iter_mut()
561 }
562 pub fn iter(&self) -> impl Iterator<Item = &Instance> {
563 self.components
564 .iter()
565 .flat_map(|v| v.unwrap_vec_ref().iter())
566 }
567 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Instance> {
568 self.components
569 .iter_mut()
570 .flat_map(|v| v.unwrap_vec_mut().iter_mut())
571 }
572}
573impl IntoIterator for MatInstance {
574 type Item = Instance;
575 type IntoIter = <Vec<Instance> as IntoIterator>::IntoIter;
576
577 fn into_iter(self) -> Self::IntoIter {
578 self.components.into_iter()
579 }
580}
581
582#[derive(Clone, Debug, PartialEq)]
586pub struct PtrInstance {
587 pub ptr: RefInstance,
588}
589
590impl From<RefInstance> for PtrInstance {
591 fn from(r: RefInstance) -> Self {
592 Self { ptr: r }
593 }
594}
595
596#[derive(Clone, Debug, PartialEq)]
600pub struct RefInstance {
601 pub ty: Type,
603 pub space: AddressSpace,
604 pub access: AccessMode,
605 pub view: MemView,
606 pub ptr: Rc<RefCell<Instance>>,
607}
608
609impl RefInstance {
610 pub fn new(inst: Instance, space: AddressSpace, access: AccessMode) -> Self {
611 let ty = inst.ty();
612 Self {
613 ty,
614 space,
615 access,
616 view: MemView::Whole,
617 ptr: Rc::new(RefCell::new(inst)),
618 }
619 }
620}
621
622impl From<PtrInstance> for RefInstance {
623 fn from(p: PtrInstance) -> Self {
624 p.ptr
625 }
626}
627
628impl RefInstance {
629 pub fn view_member(&self, comp: String) -> Result<Self, E> {
631 if !self.access.is_read() {
632 return Err(E::NotRead);
633 }
634 let mut view = self.view.clone();
635 view.append_member(comp);
636 let ty = self.ptr.borrow().view(&view)?.ty();
637 Ok(Self {
638 ty,
639 space: self.space,
640 access: self.access,
641 view,
642 ptr: self.ptr.clone(),
643 })
644 }
645 pub fn view_index(&self, index: usize) -> Result<Self, E> {
647 if !self.access.is_read() {
648 return Err(E::NotRead);
649 }
650 let mut view = self.view.clone();
651 view.append_index(index);
652 let ty = self.ptr.borrow().view(&view)?.ty();
653 Ok(Self {
654 ty,
655 space: self.space,
656 access: self.access,
657 view,
658 ptr: self.ptr.clone(),
659 })
660 }
661
662 pub fn read<'a>(&'a self) -> Result<Ref<'a, Instance>, E> {
663 if !self.access.is_read() {
664 return Err(E::NotRead);
665 }
666 Ok(Ref::<'a, Instance>::map(self.ptr.borrow(), |r| {
667 r.view(&self.view).expect("invalid reference")
668 }))
669 }
670
671 pub fn write(&self, value: Instance) -> Result<(), E> {
672 if !self.access.is_write() {
673 return Err(E::NotWrite);
674 }
675 if value.ty() != self.ty {
676 return Err(E::WriteRefType(value.ty(), self.ty.clone()));
677 }
678 let mut r = self.ptr.borrow_mut();
679 let view = r.view_mut(&self.view).expect("invalid reference");
680 assert!(view.ty() == value.ty());
681 let _ = std::mem::replace(view, value);
682 Ok(())
683 }
684
685 pub fn read_write<'a>(&'a self) -> Result<RefMut<'a, Instance>, E> {
686 if !self.access.is_write() {
687 return Err(E::NotReadWrite);
688 }
689 Ok(RefMut::<'a, Instance>::map(self.ptr.borrow_mut(), |r| {
690 r.view_mut(&self.view).expect("invalid reference")
691 }))
692 }
693}
694
695#[derive(Clone, Debug, PartialEq)]
699pub struct AtomicInstance {
700 content: Box<Instance>,
701}
702
703impl AtomicInstance {
704 pub fn new(inst: Instance) -> Self {
707 assert!(matches!(inst.ty(), Type::I32 | Type::U32));
708 Self {
709 content: inst.into(),
710 }
711 }
712
713 pub fn inner(&self) -> &Instance {
714 &self.content
715 }
716
717 pub fn inner_mut(&mut self) -> &mut Instance {
718 &mut self.content
719 }
720}