1use std::hash::{Hash, Hasher as _};
2use std::sync::atomic::Ordering;
3use std::sync::Arc;
4
5use rustc_hash::FxHasher;
6use vyre_spec::bin_op::OpIntensity;
7
8use crate::ir::{Expr, Node};
9use crate::ir_inner::model::expr::Ident;
10use crate::ir_inner::model::types::BufferAccess;
11use crate::transform::visit::{walk_nodes_and_exprs, ExprVisitor, NodeVisitor};
12
13use super::Program;
14
15fn mix_wire_fallback_hashable<T: Hash>(hasher: &mut blake3::Hasher, value: &T) {
16 let mut state = FxHasher::default();
17 value.hash(&mut state);
18 hasher.update(&state.finish().to_le_bytes());
19}
20
/// Visitor adapter that streams visited nodes and expressions into a borrowed BLAKE3
/// hasher. In this file it is driven by `Program::structural_fingerprint_fallback`
/// through `walk_nodes_and_exprs`.
struct FallbackWireHasher<'a>(&'a mut blake3::Hasher);
23
24impl NodeVisitor for FallbackWireHasher<'_> {
25 fn visit_node(&mut self, node: &Node) {
26 let h = &mut *self.0;
27 match node {
28 Node::Let { name, .. } => {
29 h.update(b"n:Let\0");
30 h.update(name.as_bytes());
31 }
32 Node::Assign { name, .. } => {
33 h.update(b"n:Assign\0");
34 h.update(name.as_bytes());
35 }
36 Node::Store { buffer, .. } => {
37 h.update(b"n:Store\0");
38 h.update(buffer.as_bytes());
39 }
40 Node::If { .. } => {
41 h.update(b"n:If\0");
42 }
43 Node::Loop { var, .. } => {
44 h.update(b"n:Loop\0");
45 h.update(var.as_bytes());
46 }
47 Node::IndirectDispatch {
48 count_buffer,
49 count_offset,
50 } => {
51 h.update(b"n:IndirectDispatch\0");
52 h.update(count_buffer.as_bytes());
53 h.update(&count_offset.to_le_bytes());
54 }
55 Node::AsyncLoad {
56 source,
57 destination,
58 tag,
59 ..
60 } => {
61 h.update(b"n:AsyncLoad\0");
62 h.update(source.as_bytes());
63 h.update(destination.as_bytes());
64 h.update(tag.as_bytes());
65 }
66 Node::AsyncStore {
67 source,
68 destination,
69 tag,
70 ..
71 } => {
72 h.update(b"n:AsyncStore\0");
73 h.update(source.as_bytes());
74 h.update(destination.as_bytes());
75 h.update(tag.as_bytes());
76 }
77 Node::AsyncWait { tag } => {
78 h.update(b"n:AsyncWait\0");
79 h.update(tag.as_bytes());
80 }
81 Node::Trap { tag, .. } => {
82 h.update(b"n:Trap\0");
83 h.update(tag.as_bytes());
84 }
85 Node::Resume { tag } => {
86 h.update(b"n:Resume\0");
87 h.update(tag.as_bytes());
88 }
89 Node::AllReduce { buffer, op, group } => {
90 h.update(b"n:AllReduce\0");
91 h.update(buffer.as_bytes());
92 h.update(&op.builtin_wire_tag().to_le_bytes());
93 h.update(&group.as_u32().to_le_bytes());
94 }
95 Node::AllGather {
96 input,
97 output,
98 group,
99 } => {
100 h.update(b"n:AllGather\0");
101 h.update(input.as_bytes());
102 h.update(output.as_bytes());
103 h.update(&group.as_u32().to_le_bytes());
104 }
105 Node::ReduceScatter {
106 input,
107 output,
108 op,
109 group,
110 } => {
111 h.update(b"n:ReduceScatter\0");
112 h.update(input.as_bytes());
113 h.update(output.as_bytes());
114 h.update(&op.builtin_wire_tag().to_le_bytes());
115 h.update(&group.as_u32().to_le_bytes());
116 }
117 Node::Broadcast {
118 buffer,
119 root,
120 group,
121 } => {
122 h.update(b"n:Broadcast\0");
123 h.update(buffer.as_bytes());
124 h.update(&root.to_le_bytes());
125 h.update(&group.as_u32().to_le_bytes());
126 }
127 Node::Return => {
128 h.update(b"n:Return\0");
129 }
130 Node::Barrier { ordering } => {
131 h.update(b"n:Barrier\0");
132 mix_wire_fallback_hashable(h, ordering);
133 }
134 Node::Block(_) => {
135 h.update(b"n:Block\0");
136 }
137 Node::Region {
138 generator,
139 source_region,
140 ..
141 } => {
142 h.update(b"n:Region\0");
143 h.update(generator.as_bytes());
144 if let Some(source_gen) = source_region {
145 h.update(source_gen.name.as_bytes());
146 }
147 }
148 Node::Opaque(ext) => {
149 h.update(b"n:Opaque\0");
150 h.update(ext.extension_kind().as_bytes());
151 }
152 }
153 }
154}
155
156impl ExprVisitor for FallbackWireHasher<'_> {
157 fn visit_expr(&mut self, expr: &Expr) {
158 let h = &mut *self.0;
159 match expr {
160 Expr::LitU32(v) => {
161 h.update(b"e:LitU32\0");
162 h.update(&v.to_le_bytes());
163 }
164 Expr::LitI32(v) => {
165 h.update(b"e:LitI32\0");
166 h.update(&v.to_le_bytes());
167 }
168 Expr::LitF32(v) => {
169 h.update(b"e:LitF32\0");
170 h.update(&v.to_le_bytes());
171 }
172 Expr::LitBool(v) => {
173 h.update(b"e:LitBool\0");
174 h.update(&[u8::from(*v)]);
175 }
176 Expr::Var(name) => {
177 h.update(b"e:Var\0");
178 h.update(name.as_bytes());
179 }
180 Expr::Load { buffer, .. } => {
181 h.update(b"e:Load\0");
182 h.update(buffer.as_bytes());
183 }
184 Expr::BufLen { buffer } => {
185 h.update(b"e:BufLen\0");
186 h.update(buffer.as_bytes());
187 }
188 Expr::InvocationId { axis } => {
189 h.update(b"e:InvocationId\0");
190 h.update(&[*axis]);
191 }
192 Expr::WorkgroupId { axis } => {
193 h.update(b"e:WorkgroupId\0");
194 h.update(&[*axis]);
195 }
196 Expr::LocalId { axis } => {
197 h.update(b"e:LocalId\0");
198 h.update(&[*axis]);
199 }
200 Expr::BinOp { op, .. } => {
201 h.update(b"e:BinOp\0");
202 mix_wire_fallback_hashable(h, op);
203 }
204 Expr::UnOp { op, .. } => {
205 h.update(b"e:UnOp\0");
206 mix_wire_fallback_hashable(h, op);
207 }
208 Expr::Call { op_id, .. } => {
209 h.update(b"e:Call\0");
210 h.update(op_id.as_bytes());
211 }
212 Expr::Select { .. } => {
213 h.update(b"e:Select\0");
214 }
215 Expr::Cast { target, .. } => {
216 h.update(b"e:Cast\0");
217 mix_wire_fallback_hashable(h, target);
218 }
219 Expr::Fma { .. } => {
220 h.update(b"e:Fma\0");
221 }
222 Expr::Atomic {
223 op,
224 buffer,
225 ordering,
226 ..
227 } => {
228 h.update(b"e:Atomic\0");
229 mix_wire_fallback_hashable(h, op);
230 h.update(buffer.as_bytes());
231 mix_wire_fallback_hashable(h, ordering);
232 }
233 Expr::SubgroupBallot { .. } => {
234 h.update(b"e:SubgroupBallot\0");
235 }
236 Expr::SubgroupShuffle { .. } => {
237 h.update(b"e:SubgroupShuffle\0");
238 }
239 Expr::SubgroupAdd { .. } => {
240 h.update(b"e:SubgroupAdd\0");
241 }
242 Expr::SubgroupLocalId => {
243 h.update(b"e:SubgroupLocalId\0");
244 }
245 Expr::SubgroupSize => {
246 h.update(b"e:SubgroupSize\0");
247 }
248 Expr::Opaque(ext) => {
249 h.update(b"e:Opaque\0");
250 h.update(ext.extension_kind().as_bytes());
251 }
252 }
253 }
254}
255
256impl Program {
257 #[must_use]
267 pub fn reconcile_runnable_top_level(self) -> Self {
268 if self.is_top_level_region_wrapped() {
269 return self;
270 }
271 self.map_entry(Self::wrap_entry)
274 }
275
276 #[must_use]
278 #[inline]
279 pub fn buffer(&self, name: &str) -> Option<&super::BufferDecl> {
280 self.buffer_index
281 .get(name)
282 .and_then(|&index| self.buffers.get(index))
283 }
284
285 #[must_use]
287 #[inline]
288 pub fn buffers(&self) -> &[super::BufferDecl] {
289 self.buffers.as_ref()
290 }
291
292 #[must_use]
294 #[inline]
295 #[cfg(test)]
296 pub(crate) fn buffers_arc(&self) -> &Arc<[super::BufferDecl]> {
297 &self.buffers
298 }
299
300 #[must_use]
307 #[inline]
308 pub fn structural_eq(&self, other: &Self) -> bool {
309 if std::ptr::eq(self, other)
314 || (Arc::ptr_eq(&self.buffers, &other.buffers)
315 && Arc::ptr_eq(&self.entry, &other.entry)
316 && self.entry_op_id == other.entry_op_id
317 && self.non_composable_with_self == other.non_composable_with_self
318 && self.workgroup_size == other.workgroup_size)
319 {
320 return true;
321 }
322 self.entry_op_id == other.entry_op_id
323 && self.non_composable_with_self == other.non_composable_with_self
324 && buffers_equal_ignoring_declaration_order(&self.buffers, &other.buffers)
325 && self.workgroup_size == other.workgroup_size
326 && self.entry == other.entry
327 }
328
329 #[must_use]
331 #[inline]
332 pub fn workgroup_size(&self) -> [u32; 3] {
333 self.workgroup_size
334 }
335
336 #[must_use]
341 #[inline]
342 pub fn parallel_region_size(&self) -> [u32; 3] {
343 self.workgroup_size
344 }
345
346 #[must_use]
349 #[inline]
350 pub fn is_non_composable_with_self(&self) -> bool {
351 self.non_composable_with_self
352 }
353
354 #[must_use]
356 #[inline]
357 pub fn with_non_composable_with_self(mut self, flag: bool) -> Self {
358 self.non_composable_with_self = flag;
359 self.invalidate_caches();
360 self
361 }
362
363 #[inline]
368 pub fn set_workgroup_size(&mut self, workgroup_size: [u32; 3]) {
369 self.workgroup_size = workgroup_size;
370 self.invalidate_caches();
371 }
372
373 #[inline]
375 pub fn set_parallel_region_size(&mut self, parallel_region_size: [u32; 3]) {
376 self.workgroup_size = parallel_region_size;
377 self.invalidate_caches();
378 }
379
380 #[must_use]
382 #[inline]
383 pub fn entry(&self) -> &[Node] {
384 self.entry.as_ref().as_slice()
385 }
386
387 #[must_use]
389 #[inline]
390 pub fn entry_arc(&self) -> &Arc<Vec<Node>> {
391 &self.entry
392 }
393
394 #[must_use]
397 #[inline]
398 pub fn is_explicit_noop(&self) -> bool {
399 self.buffers().is_empty()
400 && matches!(self.entry(), [Node::Region { body, .. }] if body.is_empty())
401 }
402
403 #[must_use]
407 #[inline]
408 pub fn is_top_level_region_wrapped(&self) -> bool {
409 !self.entry.is_empty()
410 && self
411 .entry()
412 .iter()
413 .all(|node| matches!(node, Node::Region { .. }))
414 }
415
416 #[must_use]
419 pub fn top_level_region_violation(&self) -> Option<String> {
420 if self.entry().is_empty() {
421 return Some(
422 "program entry has no top-level Region. Fix: construct runnable programs with Program::wrapped(...) or wrap the body in Node::Region before validation, interpretation, or dispatch."
423 .to_string(),
424 );
425 }
426
427 self.entry()
428 .iter()
429 .enumerate()
430 .find(|(_, node)| !matches!(node, Node::Region { .. }))
431 .map(|(index, node)| {
432 format!(
433 "program entry node {index} is `{}` instead of `Node::Region`. Fix: construct runnable programs with Program::wrapped(...) or wrap the top-level body in Node::Region; raw Program::new is reserved for wire decode and negative tests.",
434 Self::top_level_node_name(node)
435 )
436 })
437 }
438
439 #[must_use]
441 #[inline]
442 pub fn entry_mut(&mut self) -> &mut Vec<Node> {
443 self.invalidate_caches();
444 Arc::make_mut(&mut self.entry)
445 }
446
447 #[must_use]
449 #[inline]
450 pub fn fingerprint(&self) -> [u8; 32] {
451
452 *self.fingerprint.get_or_init(|| {
453 let hash = self.compute_wire_hash();
454 let _ = self.hash.set(hash);
455 *hash.as_bytes()
456 })
457 }
458
459 #[must_use]
472 pub fn vsa_fingerprint(&self) -> Vec<u32> {
473 self.fingerprint()
474 .chunks_exact(core::mem::size_of::<u32>())
475 .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
476 .collect()
477 }
478
479 #[must_use]
481 #[inline]
482 pub fn output_buffer_indices(&self) -> &[u32] {
483 self.output_buffer_index
484 .get_or_init(|| {
485 Arc::new(
486 self.buffers()
487 .iter()
488 .enumerate()
489 .filter_map(|(index, buffer)| {
490 (buffer.access() == BufferAccess::ReadWrite)
491 .then(|| u32::try_from(index).ok())
492 .flatten()
493 })
494 .collect(),
495 )
496 })
497 .as_slice()
498 }
499
500 #[must_use]
502 #[inline]
503 pub fn has_indirect_dispatch(&self) -> bool {
504 *self.has_indirect_dispatch.get_or_init(|| {
505 if !self
511 .stats()
512 .has_any_node_kind(super::stats::NODE_KIND_INDIRECT_DISPATCH)
513 {
514 return false;
515 }
516 let mut stack: smallvec::SmallVec<[&Node; 32]> = self.entry().iter().rev().collect();
517 while let Some(node) = stack.pop() {
518 match node {
519 Node::IndirectDispatch { .. } => return true,
520 Node::If {
521 then, otherwise, ..
522 } => {
523 stack.extend(otherwise.iter().rev());
524 stack.extend(then.iter().rev());
525 }
526 Node::Loop { body, .. } | Node::Block(body) => {
527 stack.extend(body.iter().rev());
528 }
529 Node::Region { body, .. } => {
530 stack.extend(body.iter().rev());
531 }
532 Node::Let { .. }
533 | Node::Assign { .. }
534 | Node::Store { .. }
535 | Node::AllReduce { .. }
536 | Node::AllGather { .. }
537 | Node::ReduceScatter { .. }
538 | Node::Broadcast { .. }
539 | Node::Return
540 | Node::Barrier { .. }
541 | Node::AsyncLoad { .. }
542 | Node::AsyncStore { .. }
543 | Node::AsyncWait { .. }
544 | Node::Trap { .. }
545 | Node::Resume { .. }
546 | Node::Opaque(_) => {}
547 }
548 }
549 false
550 })
551 }
552
553 #[must_use]
555 #[inline]
556 pub fn has_buffer(&self, name: &str) -> bool {
557 self.buffer_index.contains_key(name)
558 }
559
560 #[must_use]
562 #[inline]
563 pub fn buffer_count(&self) -> usize {
564 self.buffers.len()
565 }
566
567 #[inline]
568 pub(super) fn build_buffer_index(
569 buffers: &[super::BufferDecl],
570 ) -> rustc_hash::FxHashMap<Arc<str>, usize> {
571 let mut index = rustc_hash::FxHashMap::default();
572 index.reserve(buffers.len());
573 for (buffer_index, buffer) in buffers.iter().enumerate() {
574 index
575 .entry(Arc::clone(&buffer.name))
576 .or_insert(buffer_index);
577 }
578 index
579 }
580
581 #[inline]
583 pub fn mark_structurally_validated(&self) {
584 self.structural_validated.store(true, Ordering::Release);
585 }
586
587 #[must_use]
589 #[inline]
590 pub fn is_structurally_validated(&self) -> bool {
591 self.structural_validated.load(Ordering::Acquire)
592 }
593
594 #[inline]
596 pub fn mark_validated_on(&self, backend_id: &str) {
597 self.validation_set
598 .get_or_init(|| Arc::new(dashmap::DashSet::new()))
599 .insert(Arc::from(self.validation_cache_key(backend_id)));
600 }
601
602 #[must_use]
604 #[inline]
605 pub fn is_validated_on(&self, backend_id: &str) -> bool {
606 self.validation_set
607 .get()
608 .is_some_and(|set| set.contains(self.validation_cache_key(backend_id).as_str()))
609 }
610
611 #[deprecated(note = "use is_structurally_validated or is_validated_on")]
613 #[must_use]
614 #[inline]
615 pub fn is_validated(&self) -> bool {
616 self.is_structurally_validated()
617 }
618
619 #[deprecated(note = "use mark_structurally_validated or mark_validated_on")]
621 #[inline]
622 pub fn mark_validated(&self) {
623 self.mark_structurally_validated();
624 }
625
626 pub fn validate(&self) -> crate::error::Result<()> {
633 if self.is_structurally_validated() {
634 return Ok(());
635 }
636 let errors = crate::validate::validate(self);
637 if errors.is_empty() {
638 self.mark_structurally_validated();
639 return Ok(());
640 }
641 let mut message = String::new();
642 for (index, error) in errors.into_iter().enumerate() {
643 if index > 0 {
644 message.push_str("; ");
645 }
646 message.push_str(error.message());
647 }
648 Err(crate::error::Error::WireFormatValidation { message })
649 }
650
651 #[inline]
652 #[must_use]
660 pub fn estimate_peak_vram_bytes(&self) -> u64 {
661 self.buffers
662 .iter()
663 .map(|buffer| {
664 let Some(element_size) = buffer.element.size_bytes() else {
665 return u64::MAX;
666 };
667 u64::from(buffer.count)
668 .saturating_mul(u64::try_from(element_size).unwrap_or(u64::MAX))
669 })
670 .fold(0u64, u64::saturating_add)
671 }
672
673 #[must_use]
675 pub fn peak_intensity(&self) -> OpIntensity {
676 let mut peak = OpIntensity::Free;
677 for node in self.entry() {
678 peak = peak.max(Self::node_intensity(node));
679 }
680 peak
681 }
682
683 fn node_intensity(node: &crate::ir::Node) -> OpIntensity {
684 use crate::ir::Node;
685 match node {
686 Node::Let { value, .. } | Node::Assign { value, .. } => Self::expr_intensity(value),
687 Node::Store { index, value, .. } => {
688 Self::expr_intensity(index).max(Self::expr_intensity(value))
689 }
690 Node::If {
691 cond,
692 then,
693 otherwise,
694 } => {
695 let mut p = Self::expr_intensity(cond);
696 for n in then {
697 p = p.max(Self::node_intensity(n));
698 }
699 for n in otherwise {
700 p = p.max(Self::node_intensity(n));
701 }
702 p
703 }
704 Node::Loop { from, to, body, .. } => {
705 let mut p = Self::expr_intensity(from).max(Self::expr_intensity(to));
706 for n in body {
707 p = p.max(Self::node_intensity(n));
708 }
709 p
710 }
711 Node::Block(nodes) => {
712 let mut p = OpIntensity::Free;
713 for n in nodes {
714 p = p.max(Self::node_intensity(n));
715 }
716 p
717 }
718 Node::Region { body, .. } => {
719 let mut p = OpIntensity::Free;
720 for n in body.iter() {
721 p = p.max(Self::node_intensity(n));
722 }
723 p
724 }
725 _ => OpIntensity::Free,
726 }
727 }
728
729 fn expr_intensity(expr: &crate::ir::Expr) -> OpIntensity {
730 use crate::ir::Expr;
731 match expr {
732 Expr::BinOp { op, left, right } => op
733 .intensity()
734 .max(Self::expr_intensity(left))
735 .max(Self::expr_intensity(right)),
736 Expr::UnOp { operand, .. } => Self::expr_intensity(operand),
737 Expr::Load { index, .. } => Self::expr_intensity(index),
738 Expr::Select {
739 cond,
740 true_val,
741 false_val,
742 } => Self::expr_intensity(cond)
743 .max(Self::expr_intensity(true_val))
744 .max(Self::expr_intensity(false_val)),
745 Expr::Cast { value, .. } => Self::expr_intensity(value),
746 Expr::Fma { a, b, c } => Self::expr_intensity(a)
747 .max(Self::expr_intensity(b))
748 .max(Self::expr_intensity(c)),
749 Expr::Atomic {
750 index,
751 value,
752 expected,
753 ..
754 } => {
755 let mut p = Self::expr_intensity(index).max(Self::expr_intensity(value));
756 if let Some(e) = expected {
757 p = p.max(Self::expr_intensity(e));
758 }
759 p.max(OpIntensity::Heavy)
760 }
761 Expr::SubgroupBallot { cond } => Self::expr_intensity(cond).max(OpIntensity::Heavy),
762 Expr::SubgroupShuffle { value, lane } => Self::expr_intensity(value)
763 .max(Self::expr_intensity(lane))
764 .max(OpIntensity::Heavy),
765 Expr::SubgroupAdd { value } => Self::expr_intensity(value).max(OpIntensity::Heavy),
766 _ => OpIntensity::Free,
767 }
768 }
769
770 fn compute_wire_hash(&self) -> blake3::Hash {
771 match self.canonical_wire_hash() {
772 Ok(hash) => hash,
773 Err(error) => {
774 let structural = self.structural_fingerprint_fallback();
775 let err_msg = error.to_string();
776 let mut fallback = Vec::with_capacity(96 + err_msg.len() + structural.len());
777 fallback.extend_from_slice(b"VYRE-PROGRAM-CANONICAL-WIRE-HASH-ERROR\0");
778 fallback.extend_from_slice(err_msg.as_bytes());
779 fallback.push(0);
780 fallback.extend_from_slice(structural.as_bytes());
781 blake3::hash(&fallback)
782 }
783 }
784 }
785
786 fn structural_fingerprint_fallback(&self) -> String {
787 let mut hasher = blake3::Hasher::new();
788 hasher.update(b"VYRE-WIRE-FALLBACK-V4\0");
789 if let Some(id) = self.entry_op_id.as_deref() {
790 hasher.update(id.as_bytes());
791 }
792 hasher.update(b"\0");
793 for axis in &self.workgroup_size {
794 hasher.update(&axis.to_le_bytes());
795 }
796 hasher.update(&[u8::from(self.non_composable_with_self)]);
797 let mut keys: Vec<Vec<u8>> = self
798 .buffers()
799 .iter()
800 .map(buffer_decl_canonical_key)
801 .collect();
802 keys.sort_unstable();
803 for key in keys {
804 hasher.update(&key);
805 }
806 let mut visitor = FallbackWireHasher(&mut hasher);
807 walk_nodes_and_exprs(self, &mut visitor);
808 hasher.finalize().to_hex().to_string()
809 }
810
811 fn validation_cache_key(&self, backend_id: &str) -> String {
812 const HEX: &[u8; 16] = b"0123456789abcdef";
813 let fingerprint = self.fingerprint();
814 let mut key = String::with_capacity(backend_id.len() + 1 + 64);
815 key.push_str(backend_id);
816 key.push(':');
817 for &byte in &fingerprint {
818 key.push(HEX[(byte >> 4) as usize] as char);
819 key.push(HEX[(byte & 0x0f) as usize] as char);
820 }
821 key
822 }
823
824 #[inline]
825 pub(super) fn invalidate_caches(&mut self) {
826 self.structural_validated.store(false, Ordering::Release);
827 if let Some(set) = self.validation_set.get() {
828 set.clear();
829 }
830 let _ = self.hash.take();
831 let _ = self.fingerprint.take();
832 drop(self.output_buffer_index.take());
833 let _ = self.has_indirect_dispatch.take();
834 drop(self.stats.take());
835 }
836
837 #[inline]
838 pub(super) fn wrap_entry(entry: Vec<Node>) -> Vec<Node> {
839 if !Self::entry_needs_root_region(&entry) {
840 return entry;
841 }
842 vec![Node::Region {
843 generator: Ident::from(Self::ROOT_REGION_GENERATOR),
844 source_region: None,
845 body: Arc::new(entry),
846 }]
847 }
848
849 #[inline]
850 fn entry_needs_root_region(entry: &[Node]) -> bool {
851 entry.is_empty()
852 || entry
853 .iter()
854 .any(|node| !matches!(node, Node::Region { .. }))
855 }
856
857 #[inline]
858 fn top_level_node_name(node: &Node) -> &'static str {
859 match node {
860 Node::Let { .. } => "Let",
861 Node::Assign { .. } => "Assign",
862 Node::Store { .. } => "Store",
863 Node::If { .. } => "If",
864 Node::Loop { .. } => "Loop",
865 Node::Return => "Return",
866 Node::Block(_) => "Block",
867 Node::Barrier { .. } => "Barrier",
868 Node::Region { .. } => "Region",
869 Node::IndirectDispatch { .. } => "IndirectDispatch",
870 Node::AsyncLoad { .. } => "AsyncLoad",
871 Node::AsyncStore { .. } => "AsyncStore",
872 Node::AsyncWait { .. } => "AsyncWait",
873 Node::Trap { .. } => "Trap",
874 Node::Resume { .. } => "Resume",
875 Node::AllReduce { .. } => "AllReduce",
876 Node::AllGather { .. } => "AllGather",
877 Node::ReduceScatter { .. } => "ReduceScatter",
878 Node::Broadcast { .. } => "Broadcast",
879 Node::Opaque(_) => "Opaque",
880 }
881 }
882}
883
884pub(crate) fn buffers_equal_ignoring_declaration_order(
885 left: &[super::BufferDecl],
886 right: &[super::BufferDecl],
887) -> bool {
888 if left.len() != right.len() {
889 return false;
890 }
891
892 if left == right {
899 return true;
900 }
901
902
903 let mut left_keys = Vec::with_capacity(left.len());
904 left_keys.extend(left.iter().map(buffer_decl_canonical_key));
905 let mut right_keys = Vec::with_capacity(right.len());
906 right_keys.extend(right.iter().map(buffer_decl_canonical_key));
907 left_keys.sort_unstable();
908 right_keys.sort_unstable();
909 left_keys == right_keys
910}
911
/// Serializes one buffer declaration into a byte key.
///
/// In this file the key is used for order-insensitive buffer comparison
/// (`buffers_equal_ignoring_declaration_order`) and for the fallback fingerprint
/// (`Program::structural_fingerprint_fallback`).
///
/// Fields are written in this fixed order, using the wire-format framing helpers:
/// - length-prefixed name;
/// - binding;
/// - access tag;
/// - memory kind;
/// - element data type;
/// - element count;
/// - `is_output` and `pipeline_live_out` flags;
/// - optional output byte range;
/// - optional coalesce axis;
/// - preferred alignment;
/// - cache locality;
/// - the bytes-extraction flag.
///
/// The function never fails. When a field cannot be encoded, a marker or sentinel plus
/// the error text is appended instead.
///
/// NOTE(review): on those error paths the error text is appended without a length
/// prefix, so keys produced there are not guaranteed to be unambiguous.
pub(super) fn buffer_decl_canonical_key(buffer: &super::BufferDecl) -> Vec<u8> {
    use crate::serial::wire::framing::{put_len_u32, put_u32, put_u8};
    use crate::serial::wire::tags::put_data_type;

    // Initial capacity guess; the Vec grows if a key is longer.
    let mut key = Vec::with_capacity(96);
    // Name: u32 length prefix, then the raw name bytes.
    if let Err(error) = put_len_u32(&mut key, buffer.name.len(), "buffer name length") {
        key.extend_from_slice(b"\0name-length-error\0");
        key.extend_from_slice(error.as_bytes());
    }
    key.extend_from_slice(buffer.name.as_bytes());
    put_u32(&mut key, buffer.binding);
    // Access mode via its wire tag; u8::MAX plus the error text if there is no tag.
    match crate::serial::wire::tags::access_tag::access_tag(&buffer.access) {
        Ok(tag) => put_u8(&mut key, tag),
        Err(error) => {
            put_u8(&mut key, u8::MAX);
            key.extend_from_slice(error.as_bytes());
        }
    }
    // Fixed numeric code per memory kind. NOTE(review): presumably these codes must stay
    // stable because they feed fingerprints — do not renumber without a version bump.
    put_u8(
        &mut key,
        match buffer.kind {
            super::MemoryKind::Global => 0,
            super::MemoryKind::Shared => 1,
            super::MemoryKind::Uniform => 2,
            super::MemoryKind::Local => 3,
            super::MemoryKind::Readonly => 4,
            super::MemoryKind::Persistent => 5,
            super::MemoryKind::Push => 6,
        },
    );
    if let Err(error) = put_data_type(&mut key, &buffer.element) {
        key.extend_from_slice(b"\0dtype-error\0");
        key.extend_from_slice(error.as_bytes());
    }
    put_u32(&mut key, buffer.count);
    put_u8(&mut key, u8::from(buffer.is_output));
    put_u8(&mut key, u8::from(buffer.pipeline_live_out));
    // Optional byte range: presence byte, then start/end as u32. A bound that does not
    // fit in u32 is written as u32::MAX followed by the conversion error text.
    match &buffer.output_byte_range {
        Some(range) => {
            put_u8(&mut key, 1);
            match u32::try_from(range.start) {
                Ok(start) => put_u32(&mut key, start),
                Err(error) => {
                    put_u32(&mut key, u32::MAX);
                    key.extend_from_slice(error.to_string().as_bytes());
                }
            }
            match u32::try_from(range.end) {
                Ok(end) => put_u32(&mut key, end),
                Err(error) => {
                    put_u32(&mut key, u32::MAX);
                    key.extend_from_slice(error.to_string().as_bytes());
                }
            }
        }
        None => put_u8(&mut key, 0),
    }
    // Optional coalesce axis: presence byte, then the axis.
    match buffer.hints.coalesce_axis {
        Some(axis) => {
            put_u8(&mut key, 1);
            put_u8(&mut key, axis);
        }
        None => put_u8(&mut key, 0),
    }
    put_u32(&mut key, buffer.hints.preferred_alignment);
    put_u8(
        &mut key,
        match buffer.hints.cache_locality {
            super::CacheLocality::Streaming => 0,
            super::CacheLocality::Temporal => 1,
            super::CacheLocality::Random => 2,
        },
    );
    put_u8(&mut key, u8::from(buffer.bytes_extraction));
    key
}
988