1use std::hash::{Hash, Hasher as _};
2use std::sync::atomic::Ordering;
3use std::sync::Arc;
4
5use rustc_hash::FxHasher;
6use vyre_spec::bin_op::OpIntensity;
7
8use crate::ir::{Expr, Node};
9use crate::ir_inner::model::expr::Ident;
10use crate::ir_inner::model::types::BufferAccess;
11use crate::transform::visit::{walk_nodes_and_exprs, ExprVisitor, NodeVisitor};
12
13use super::Program;
14
/// Records which kind of mutation most recently invalidated a program's
/// cached validation state.
///
/// The discriminants are explicit because the value is persisted as a raw
/// `u8` and decoded again through [`ProgramMutationProvenance::from_code`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum ProgramMutationProvenance {
    /// No tracked mutation is pending.
    Clean = 0,
    /// The non-composable-with-self flag was replaced.
    NonComposableFlag = 1,
    /// The workgroup size was replaced.
    WorkgroupSize = 2,
    /// The parallel region size was replaced.
    ParallelRegionSize = 3,
    /// Mutable access to the entry node list was handed out.
    EntryMutation = 4,
    /// An internal, unnamed shape mutation invalidated the caches.
    InternalShapeMutation = 5,
    /// The mutation source is unknown; also the decode result for any
    /// unrecognised code.
    Unknown = 255,
}

impl ProgramMutationProvenance {
    /// Decodes a stored provenance code.
    ///
    /// Each known variant is matched against its own discriminant; every
    /// other code decodes to [`ProgramMutationProvenance::Unknown`].
    #[inline]
    const fn from_code(code: u8) -> Self {
        if code == Self::Clean as u8 {
            Self::Clean
        } else if code == Self::NonComposableFlag as u8 {
            Self::NonComposableFlag
        } else if code == Self::WorkgroupSize as u8 {
            Self::WorkgroupSize
        } else if code == Self::ParallelRegionSize as u8 {
            Self::ParallelRegionSize
        } else if code == Self::EntryMutation as u8 {
            Self::EntryMutation
        } else if code == Self::InternalShapeMutation as u8 {
            Self::InternalShapeMutation
        } else {
            Self::Unknown
        }
    }
}
49
50fn mix_wire_fallback_hashable<T: Hash>(hasher: &mut blake3::Hasher, value: &T) {
51 let mut state = FxHasher::default();
52 value.hash(&mut state);
53 hasher.update(&state.finish().to_le_bytes());
54}
55
/// Visitor that folds a shallow summary of each visited node and expression
/// into a borrowed BLAKE3 hasher.
struct FallbackWireHasher<'a>(&'a mut blake3::Hasher);

impl NodeVisitor for FallbackWireHasher<'_> {
    /// Feeds a NUL-terminated `n:<Variant>` tag for `node`, followed by the
    /// identifiers and scalar fields bound in that variant's pattern.
    ///
    /// Fields skipped with `..` (child bodies, operand expressions) are not
    /// hashed here.
    /// NOTE(review): presumably the walker visits those children separately —
    /// confirm against `walk_nodes_and_exprs`.
    ///
    /// Variable-length fields are fed back to back without a length prefix or
    /// separator, so adjacent names are not delimited in the stream.
    fn visit_node(&mut self, node: &Node) {
        // Reborrow the wrapped hasher once so every arm can use a short name.
        let h = &mut *self.0;
        match node {
            Node::Let { name, .. } => {
                h.update(b"n:Let\0");
                h.update(name.as_bytes());
            }
            Node::Assign { name, .. } => {
                h.update(b"n:Assign\0");
                h.update(name.as_bytes());
            }
            Node::Store { buffer, .. } => {
                h.update(b"n:Store\0");
                h.update(buffer.as_bytes());
            }
            Node::If { .. } => {
                h.update(b"n:If\0");
            }
            Node::Loop { var, .. } => {
                h.update(b"n:Loop\0");
                h.update(var.as_bytes());
            }
            Node::IndirectDispatch {
                count_buffer,
                count_offset,
            } => {
                h.update(b"n:IndirectDispatch\0");
                h.update(count_buffer.as_bytes());
                h.update(&count_offset.to_le_bytes());
            }
            Node::AsyncLoad {
                source,
                destination,
                tag,
                ..
            } => {
                h.update(b"n:AsyncLoad\0");
                h.update(source.as_bytes());
                h.update(destination.as_bytes());
                h.update(tag.as_bytes());
            }
            Node::AsyncStore {
                source,
                destination,
                tag,
                ..
            } => {
                h.update(b"n:AsyncStore\0");
                h.update(source.as_bytes());
                h.update(destination.as_bytes());
                h.update(tag.as_bytes());
            }
            Node::AsyncWait { tag } => {
                h.update(b"n:AsyncWait\0");
                h.update(tag.as_bytes());
            }
            Node::Trap { tag, .. } => {
                h.update(b"n:Trap\0");
                h.update(tag.as_bytes());
            }
            Node::Resume { tag } => {
                h.update(b"n:Resume\0");
                h.update(tag.as_bytes());
            }
            Node::AllReduce { buffer, op, group } => {
                h.update(b"n:AllReduce\0");
                h.update(buffer.as_bytes());
                h.update(&op.builtin_wire_tag().to_le_bytes());
                h.update(&group.as_u32().to_le_bytes());
            }
            Node::AllGather {
                input,
                output,
                group,
            } => {
                h.update(b"n:AllGather\0");
                h.update(input.as_bytes());
                h.update(output.as_bytes());
                h.update(&group.as_u32().to_le_bytes());
            }
            Node::ReduceScatter {
                input,
                output,
                op,
                group,
            } => {
                h.update(b"n:ReduceScatter\0");
                h.update(input.as_bytes());
                h.update(output.as_bytes());
                h.update(&op.builtin_wire_tag().to_le_bytes());
                h.update(&group.as_u32().to_le_bytes());
            }
            Node::Broadcast {
                buffer,
                root,
                group,
            } => {
                h.update(b"n:Broadcast\0");
                h.update(buffer.as_bytes());
                h.update(&root.to_le_bytes());
                h.update(&group.as_u32().to_le_bytes());
            }
            Node::Return => {
                h.update(b"n:Return\0");
            }
            Node::Barrier { ordering } => {
                h.update(b"n:Barrier\0");
                // The ordering is mixed through its `Hash` impl rather than a wire tag.
                mix_wire_fallback_hashable(h, ordering);
            }
            Node::Block(_) => {
                h.update(b"n:Block\0");
            }
            Node::Region {
                generator,
                source_region,
                ..
            } => {
                h.update(b"n:Region\0");
                h.update(generator.as_bytes());
                // Only the source region's name is mixed; an absent source
                // region feeds no bytes at all.
                if let Some(source_gen) = source_region {
                    h.update(source_gen.name.as_bytes());
                }
            }
            Node::Opaque(ext) => {
                h.update(b"n:Opaque\0");
                h.update(ext.extension_kind().as_bytes());
            }
        }
    }
}
190
impl ExprVisitor for FallbackWireHasher<'_> {
    /// Feeds a NUL-terminated `e:<Variant>` tag for `expr`, followed by the
    /// literal value, identifier, axis or operator bound in that variant's
    /// pattern.
    ///
    /// Operand sub-expressions (fields skipped with `..`) are not hashed here.
    /// NOTE(review): presumably the walker visits those operands separately —
    /// confirm against `walk_nodes_and_exprs`.
    fn visit_expr(&mut self, expr: &Expr) {
        // Reborrow the wrapped hasher once so every arm can use a short name.
        let h = &mut *self.0;
        match expr {
            Expr::LitU32(v) => {
                h.update(b"e:LitU32\0");
                h.update(&v.to_le_bytes());
            }
            Expr::LitI32(v) => {
                h.update(b"e:LitI32\0");
                h.update(&v.to_le_bytes());
            }
            Expr::LitF32(v) => {
                h.update(b"e:LitF32\0");
                h.update(&v.to_le_bytes());
            }
            Expr::LitBool(v) => {
                h.update(b"e:LitBool\0");
                h.update(&[u8::from(*v)]);
            }
            Expr::Var(name) => {
                h.update(b"e:Var\0");
                h.update(name.as_bytes());
            }
            Expr::Load { buffer, .. } => {
                h.update(b"e:Load\0");
                h.update(buffer.as_bytes());
            }
            Expr::BufLen { buffer } => {
                h.update(b"e:BufLen\0");
                h.update(buffer.as_bytes());
            }
            Expr::InvocationId { axis } => {
                h.update(b"e:InvocationId\0");
                h.update(&[*axis]);
            }
            Expr::WorkgroupId { axis } => {
                h.update(b"e:WorkgroupId\0");
                h.update(&[*axis]);
            }
            Expr::LocalId { axis } => {
                h.update(b"e:LocalId\0");
                h.update(&[*axis]);
            }
            // Operators, cast targets and orderings are mixed through their
            // `Hash` impls rather than through a wire tag.
            Expr::BinOp { op, .. } => {
                h.update(b"e:BinOp\0");
                mix_wire_fallback_hashable(h, op);
            }
            Expr::UnOp { op, .. } => {
                h.update(b"e:UnOp\0");
                mix_wire_fallback_hashable(h, op);
            }
            Expr::Call { op_id, .. } => {
                h.update(b"e:Call\0");
                h.update(op_id.as_bytes());
            }
            Expr::Select { .. } => {
                h.update(b"e:Select\0");
            }
            Expr::Cast { target, .. } => {
                h.update(b"e:Cast\0");
                mix_wire_fallback_hashable(h, target);
            }
            Expr::Fma { .. } => {
                h.update(b"e:Fma\0");
            }
            Expr::Atomic {
                op,
                buffer,
                ordering,
                ..
            } => {
                h.update(b"e:Atomic\0");
                mix_wire_fallback_hashable(h, op);
                h.update(buffer.as_bytes());
                mix_wire_fallback_hashable(h, ordering);
            }
            Expr::SubgroupBallot { .. } => {
                h.update(b"e:SubgroupBallot\0");
            }
            Expr::SubgroupShuffle { .. } => {
                h.update(b"e:SubgroupShuffle\0");
            }
            Expr::SubgroupAdd { .. } => {
                h.update(b"e:SubgroupAdd\0");
            }
            Expr::SubgroupLocalId => {
                h.update(b"e:SubgroupLocalId\0");
            }
            Expr::SubgroupSize => {
                h.update(b"e:SubgroupSize\0");
            }
            Expr::Opaque(ext) => {
                h.update(b"e:Opaque\0");
                h.update(ext.extension_kind().as_bytes());
            }
        }
    }
}
290
291impl Program {
292 #[must_use]
302 pub fn reconcile_runnable_top_level(self) -> Self {
303 if self.is_top_level_region_wrapped() {
304 return self;
305 }
306 self.map_entry(Self::wrap_entry)
309 }
310
311 #[must_use]
313 #[inline]
314 pub fn buffer(&self, name: &str) -> Option<&super::BufferDecl> {
315 self.buffer_index
316 .get(name)
317 .and_then(|&index| self.buffers.get(index))
318 }
319
320 #[must_use]
322 #[inline]
323 pub fn buffers(&self) -> &[super::BufferDecl] {
324 self.buffers.as_ref()
325 }
326
    /// Test-only accessor exposing the shared buffer allocation itself rather
    /// than a slice view of it.
    #[must_use]
    #[inline]
    #[cfg(test)]
    pub(crate) fn buffers_arc(&self) -> &Arc<[super::BufferDecl]> {
        &self.buffers
    }
334
335 #[must_use]
342 #[inline]
343 pub fn structural_eq(&self, other: &Self) -> bool {
344 if std::ptr::eq(self, other)
349 || (Arc::ptr_eq(&self.buffers, &other.buffers)
350 && Arc::ptr_eq(&self.entry, &other.entry)
351 && self.entry_op_id == other.entry_op_id
352 && self.non_composable_with_self == other.non_composable_with_self
353 && self.workgroup_size == other.workgroup_size)
354 {
355 return true;
356 }
357 self.entry_op_id == other.entry_op_id
358 && self.non_composable_with_self == other.non_composable_with_self
359 && buffers_equal_ignoring_declaration_order(&self.buffers, &other.buffers)
360 && self.workgroup_size == other.workgroup_size
361 && self.entry == other.entry
362 }
363
    /// Returns the stored three-component workgroup size.
    #[must_use]
    #[inline]
    pub fn workgroup_size(&self) -> [u32; 3] {
        self.workgroup_size
    }
370
    /// Returns the parallel region size.
    ///
    /// This reads the same `workgroup_size` field that `workgroup_size()`
    /// returns; the two accessors differ only in name.
    #[must_use]
    #[inline]
    pub fn parallel_region_size(&self) -> [u32; 3] {
        self.workgroup_size
    }
380
    /// Returns the stored `non_composable_with_self` flag.
    #[must_use]
    #[inline]
    pub fn is_non_composable_with_self(&self) -> bool {
        self.non_composable_with_self
    }
388
    /// Builder-style setter for the `non_composable_with_self` flag.
    ///
    /// Consumes the program, stores `flag`, invalidates the cached state with
    /// `ProgramMutationProvenance::NonComposableFlag`, and returns the program.
    #[must_use]
    #[inline]
    pub fn with_non_composable_with_self(mut self, flag: bool) -> Self {
        self.non_composable_with_self = flag;
        self.invalidate_caches_for(ProgramMutationProvenance::NonComposableFlag);
        self
    }
397
    /// Replaces the workgroup size and invalidates the cached state with
    /// `ProgramMutationProvenance::WorkgroupSize`.
    #[inline]
    pub fn set_workgroup_size(&mut self, workgroup_size: [u32; 3]) {
        self.workgroup_size = workgroup_size;
        self.invalidate_caches_for(ProgramMutationProvenance::WorkgroupSize);
    }
407
    /// Replaces the parallel region size and invalidates the cached state with
    /// `ProgramMutationProvenance::ParallelRegionSize`.
    ///
    /// The value is written to the same `workgroup_size` field that
    /// `set_workgroup_size` writes; only the recorded provenance differs.
    #[inline]
    pub fn set_parallel_region_size(&mut self, parallel_region_size: [u32; 3]) {
        self.workgroup_size = parallel_region_size;
        self.invalidate_caches_for(ProgramMutationProvenance::ParallelRegionSize);
    }
414
415 #[must_use]
417 #[inline]
418 pub fn entry(&self) -> &[Node] {
419 self.entry.as_ref().as_slice()
420 }
421
    /// Returns the shared pointer holding the entry node list, so callers can
    /// clone or compare the allocation instead of the nodes.
    #[must_use]
    #[inline]
    pub fn entry_arc(&self) -> &Arc<Vec<Node>> {
        &self.entry
    }
428
429 #[must_use]
432 #[inline]
433 pub fn is_explicit_noop(&self) -> bool {
434 self.buffers().is_empty()
435 && matches!(self.entry(), [Node::Region { body, .. }] if body.is_empty())
436 }
437
438 #[must_use]
442 #[inline]
443 pub fn is_top_level_region_wrapped(&self) -> bool {
444 !self.entry.is_empty()
445 && self
446 .entry()
447 .iter()
448 .all(|node| matches!(node, Node::Region { .. }))
449 }
450
451 #[must_use]
454 pub fn top_level_region_violation(&self) -> Option<String> {
455 if self.entry().is_empty() {
456 return Some(
457 "program entry has no top-level Region. Fix: construct runnable programs with Program::wrapped(...) or wrap the body in Node::Region before validation, interpretation, or dispatch."
458 .to_string(),
459 );
460 }
461
462 self.entry()
463 .iter()
464 .enumerate()
465 .find(|(_, node)| !matches!(node, Node::Region { .. }))
466 .map(|(index, node)| {
467 format!(
468 "program entry node {index} is `{}` instead of `Node::Region`. Fix: construct runnable programs with Program::wrapped(...) or wrap the top-level body in Node::Region; raw Program::new is reserved for wire decode and negative tests.",
469 Self::top_level_node_name(node)
470 )
471 })
472 }
473
    /// Grants mutable access to the entry node list.
    ///
    /// Cached state is invalidated with
    /// `ProgramMutationProvenance::EntryMutation` before the borrow is handed
    /// out, whether or not the caller goes on to change anything.
    #[must_use]
    #[inline]
    pub fn entry_mut(&mut self) -> &mut Vec<Node> {
        self.invalidate_caches_for(ProgramMutationProvenance::EntryMutation);
        // `Arc::make_mut` clones the vector first if the allocation is shared.
        Arc::make_mut(&mut self.entry)
    }
481
482 #[must_use]
484 #[inline]
485 pub fn fingerprint(&self) -> [u8; 32] {
486 *self.fingerprint.get_or_init(|| {
487 let hash = self.compute_wire_hash();
488 let _ = self.hash.set(hash);
489 *hash.as_bytes()
490 })
491 }
492
493 #[must_use]
506 pub fn vsa_fingerprint(&self) -> Vec<u32> {
507 self.fingerprint()
508 .chunks_exact(core::mem::size_of::<u32>())
509 .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
510 .collect()
511 }
512
513 #[must_use]
515 #[inline]
516 pub fn output_buffer_indices(&self) -> &[u32] {
517 self.output_buffer_index
518 .get_or_init(|| {
519 Arc::new(
520 self.buffers()
521 .iter()
522 .enumerate()
523 .filter_map(|(index, buffer)| {
524 matches!(
525 buffer.access(),
526 BufferAccess::ReadWrite | BufferAccess::WriteOnly
527 )
528 .then(|| u32::try_from(index).ok())
529 .flatten()
530 })
531 .collect(),
532 )
533 })
534 .as_slice()
535 }
536
537 #[must_use]
539 #[inline]
540 pub fn has_indirect_dispatch(&self) -> bool {
541 *self.has_indirect_dispatch.get_or_init(|| {
542 if !self
548 .stats()
549 .has_any_node_kind(super::stats::NODE_KIND_INDIRECT_DISPATCH)
550 {
551 return false;
552 }
553 let mut stack: smallvec::SmallVec<[&Node; 32]> = self.entry().iter().rev().collect();
554 while let Some(node) = stack.pop() {
555 match node {
556 Node::IndirectDispatch { .. } => return true,
557 Node::If {
558 then, otherwise, ..
559 } => {
560 stack.extend(otherwise.iter().rev());
561 stack.extend(then.iter().rev());
562 }
563 Node::Loop { body, .. } | Node::Block(body) => {
564 stack.extend(body.iter().rev());
565 }
566 Node::Region { body, .. } => {
567 stack.extend(body.iter().rev());
568 }
569 Node::Let { .. }
570 | Node::Assign { .. }
571 | Node::Store { .. }
572 | Node::AllReduce { .. }
573 | Node::AllGather { .. }
574 | Node::ReduceScatter { .. }
575 | Node::Broadcast { .. }
576 | Node::Return
577 | Node::Barrier { .. }
578 | Node::AsyncLoad { .. }
579 | Node::AsyncStore { .. }
580 | Node::AsyncWait { .. }
581 | Node::Trap { .. }
582 | Node::Resume { .. }
583 | Node::Opaque(_) => {}
584 }
585 }
586 false
587 })
588 }
589
    /// Reports whether the buffer name index contains `name`.
    #[must_use]
    #[inline]
    pub fn has_buffer(&self, name: &str) -> bool {
        self.buffer_index.contains_key(name)
    }
596
    /// Returns the number of declared buffers.
    #[must_use]
    #[inline]
    pub fn buffer_count(&self) -> usize {
        self.buffers.len()
    }
603
604 #[inline]
605 pub(super) fn build_buffer_index(
606 buffers: &[super::BufferDecl],
607 ) -> rustc_hash::FxHashMap<Arc<str>, usize> {
608 let mut index = rustc_hash::FxHashMap::default();
609 index.reserve(buffers.len());
610 for (buffer_index, buffer) in buffers.iter().enumerate() {
611 index
612 .entry(Arc::clone(&buffer.name))
613 .or_insert(buffer_index);
614 }
615 index
616 }
617
618 #[inline]
620 pub fn mark_structurally_validated(&self) {
621 self.structural_validation_fingerprint.store(
622 self.current_validation_fingerprint_token(),
623 Ordering::Release,
624 );
625 self.mutation_provenance
626 .store(ProgramMutationProvenance::Clean as u8, Ordering::Release);
627 self.structural_validated.store(true, Ordering::Release);
628 }
629
630 #[must_use]
632 #[inline]
633 pub fn is_structurally_validated(&self) -> bool {
634 if !self.structural_validated.load(Ordering::Acquire) {
635 return false;
636 }
637 if self.validation_mutation_provenance() == ProgramMutationProvenance::Unknown {
638 self.structural_validated.store(false, Ordering::Release);
639 return false;
640 }
641 let recorded = self
642 .structural_validation_fingerprint
643 .load(Ordering::Acquire);
644 if recorded == 0 || recorded != self.current_validation_fingerprint_token() {
645 self.structural_validated.store(false, Ordering::Release);
646 return false;
647 }
648 true
649 }
650
651 #[must_use]
653 #[inline]
654 pub fn validation_mutation_provenance(&self) -> ProgramMutationProvenance {
655 ProgramMutationProvenance::from_code(self.mutation_provenance.load(Ordering::Acquire))
656 }
657
    /// Invalidates all cached state and records the mutation provenance as
    /// `ProgramMutationProvenance::Unknown`.
    ///
    /// While the provenance is `Unknown`, `validate` returns an error and
    /// `mark_validated_on` records nothing.
    #[inline]
    pub fn mark_unknown_mutation_provenance(&mut self) {
        self.invalidate_caches_for(ProgramMutationProvenance::Unknown);
    }
665
666 #[inline]
668 pub fn mark_validated_on(&self, backend_id: &str) {
669 if self.validation_mutation_provenance() == ProgramMutationProvenance::Unknown {
670 return;
671 }
672 self.validation_set
673 .get_or_init(|| Arc::new(dashmap::DashSet::new()))
674 .insert(Arc::from(self.validation_cache_key(backend_id)));
675 }
676
677 #[must_use]
679 #[inline]
680 pub fn is_validated_on(&self, backend_id: &str) -> bool {
681 self.validation_set
682 .get()
683 .is_some_and(|set| set.contains(self.validation_cache_key(backend_id).as_str()))
684 }
685
    /// Deprecated alias that forwards to `is_structurally_validated`.
    #[deprecated(note = "use is_structurally_validated or is_validated_on")]
    #[must_use]
    #[inline]
    pub fn is_validated(&self) -> bool {
        self.is_structurally_validated()
    }
693
    /// Deprecated alias that forwards to `mark_structurally_validated`.
    #[deprecated(note = "use mark_structurally_validated or mark_validated_on")]
    #[inline]
    pub fn mark_validated(&self) {
        self.mark_structurally_validated();
    }
700
701 pub fn validate(&self) -> crate::error::Result<()> {
708 if self.validation_mutation_provenance() == ProgramMutationProvenance::Unknown {
709 return Err(crate::error::Error::WireFormatValidation {
710 message: "program validation cache was invalidated by unknown mutation provenance. Fix: rebuild the Program through Program::wrapped/from_wire or use a named Program mutation API before validating.".into(),
711 });
712 }
713 if self.is_structurally_validated() {
714 return Ok(());
715 }
716 let errors = crate::validate::validate(self);
717 if errors.is_empty() {
718 self.mark_structurally_validated();
719 return Ok(());
720 }
721 let mut message = String::new();
722 for (index, error) in errors.into_iter().enumerate() {
723 if index > 0 {
724 message.push_str("; ");
725 }
726 message.push_str(error.message());
727 }
728 Err(crate::error::Error::WireFormatValidation { message })
729 }
730
731 #[inline]
732 #[must_use]
740 pub fn estimate_peak_vram_bytes(&self) -> u64 {
741 self.buffers
742 .iter()
743 .map(|buffer| {
744 let Some(element_size) = buffer.element.size_bytes() else {
745 return u64::MAX;
746 };
747 u64::from(buffer.count)
748 .saturating_mul(u64::try_from(element_size).unwrap_or(u64::MAX))
749 })
750 .fold(0u64, u64::saturating_add)
751 }
752
753 #[must_use]
755 pub fn peak_intensity(&self) -> OpIntensity {
756 let mut peak = OpIntensity::Free;
757 for node in self.entry() {
758 peak = peak.max(Self::node_intensity(node));
759 }
760 peak
761 }
762
763 fn node_intensity(node: &crate::ir::Node) -> OpIntensity {
764 use crate::ir::Node;
765 match node {
766 Node::Let { value, .. } | Node::Assign { value, .. } => Self::expr_intensity(value),
767 Node::Store { index, value, .. } => {
768 Self::expr_intensity(index).max(Self::expr_intensity(value))
769 }
770 Node::If {
771 cond,
772 then,
773 otherwise,
774 } => {
775 let mut p = Self::expr_intensity(cond);
776 for n in then {
777 p = p.max(Self::node_intensity(n));
778 }
779 for n in otherwise {
780 p = p.max(Self::node_intensity(n));
781 }
782 p
783 }
784 Node::Loop { from, to, body, .. } => {
785 let mut p = Self::expr_intensity(from).max(Self::expr_intensity(to));
786 for n in body {
787 p = p.max(Self::node_intensity(n));
788 }
789 p
790 }
791 Node::Block(nodes) => {
792 let mut p = OpIntensity::Free;
793 for n in nodes {
794 p = p.max(Self::node_intensity(n));
795 }
796 p
797 }
798 Node::Region { body, .. } => {
799 let mut p = OpIntensity::Free;
800 for n in body.iter() {
801 p = p.max(Self::node_intensity(n));
802 }
803 p
804 }
805 _ => OpIntensity::Free,
806 }
807 }
808
809 fn expr_intensity(expr: &crate::ir::Expr) -> OpIntensity {
810 use crate::ir::Expr;
811 match expr {
812 Expr::BinOp { op, left, right } => op
813 .intensity()
814 .max(Self::expr_intensity(left))
815 .max(Self::expr_intensity(right)),
816 Expr::UnOp { operand, .. } => Self::expr_intensity(operand),
817 Expr::Load { index, .. } => Self::expr_intensity(index),
818 Expr::Select {
819 cond,
820 true_val,
821 false_val,
822 } => Self::expr_intensity(cond)
823 .max(Self::expr_intensity(true_val))
824 .max(Self::expr_intensity(false_val)),
825 Expr::Cast { value, .. } => Self::expr_intensity(value),
826 Expr::Fma { a, b, c } => Self::expr_intensity(a)
827 .max(Self::expr_intensity(b))
828 .max(Self::expr_intensity(c)),
829 Expr::Atomic {
830 index,
831 value,
832 expected,
833 ..
834 } => {
835 let mut p = Self::expr_intensity(index).max(Self::expr_intensity(value));
836 if let Some(e) = expected {
837 p = p.max(Self::expr_intensity(e));
838 }
839 p.max(OpIntensity::Heavy)
840 }
841 Expr::SubgroupBallot { cond } => Self::expr_intensity(cond).max(OpIntensity::Heavy),
842 Expr::SubgroupShuffle { value, lane } => Self::expr_intensity(value)
843 .max(Self::expr_intensity(lane))
844 .max(OpIntensity::Heavy),
845 Expr::SubgroupAdd { value } => Self::expr_intensity(value).max(OpIntensity::Heavy),
846 _ => OpIntensity::Free,
847 }
848 }
849
850 fn compute_wire_hash(&self) -> blake3::Hash {
851 match self.canonical_wire_hash() {
852 Ok(hash) => hash,
853 Err(error) => {
854 let structural = self.structural_fingerprint_fallback();
855 let err_msg = error.to_string();
856 let mut fallback = Vec::with_capacity(96 + err_msg.len() + structural.len());
857 fallback.extend_from_slice(b"VYRE-PROGRAM-CANONICAL-WIRE-HASH-ERROR\0");
858 fallback.extend_from_slice(err_msg.as_bytes());
859 fallback.push(0);
860 fallback.extend_from_slice(structural.as_bytes());
861 blake3::hash(&fallback)
862 }
863 }
864 }
865
866 fn structural_fingerprint_fallback(&self) -> String {
867 let mut hasher = blake3::Hasher::new();
868 hasher.update(b"VYRE-WIRE-FALLBACK-V4\0");
869 if let Some(id) = self.entry_op_id.as_deref() {
870 hasher.update(id.as_bytes());
871 }
872 hasher.update(b"\0");
873 for axis in &self.workgroup_size {
874 hasher.update(&axis.to_le_bytes());
875 }
876 hasher.update(&[u8::from(self.non_composable_with_self)]);
877 let mut keys: Vec<Vec<u8>> = self
878 .buffers()
879 .iter()
880 .map(buffer_decl_canonical_key)
881 .collect();
882 keys.sort_unstable();
883 for key in keys {
884 hasher.update(&key);
885 }
886 let mut visitor = FallbackWireHasher(&mut hasher);
887 walk_nodes_and_exprs(self, &mut visitor);
888 hasher.finalize().to_hex().to_string()
889 }
890
891 fn validation_cache_key(&self, backend_id: &str) -> String {
892 const HEX: &[u8; 16] = b"0123456789abcdef";
893 let fingerprint = self.current_validation_fingerprint();
894 let mut key = String::with_capacity(backend_id.len() + 1 + 64);
895 key.push_str(backend_id);
896 key.push(':');
897 for &byte in &fingerprint {
898 key.push(HEX[(byte >> 4) as usize] as char);
899 key.push(HEX[(byte & 0x0f) as usize] as char);
900 }
901 key
902 }
903
    /// Invalidates all cached state, recording
    /// `ProgramMutationProvenance::InternalShapeMutation` as the cause.
    #[inline]
    pub(super) fn invalidate_caches(&mut self) {
        self.invalidate_caches_for(ProgramMutationProvenance::InternalShapeMutation);
    }
908
    /// Drops every cached derivation of the program and records why.
    ///
    /// Clears the structural validation flag and its fingerprint token, stores
    /// `provenance` as the mutation cause, empties the per-backend validation
    /// set if one exists, and resets the cached hash, fingerprint, output
    /// buffer index, indirect-dispatch answer and statistics.
    #[inline]
    pub(super) fn invalidate_caches_for(&mut self, provenance: ProgramMutationProvenance) {
        self.structural_validated.store(false, Ordering::Release);
        // Zero is the "nothing recorded" token; real tokens are never zero.
        self.structural_validation_fingerprint
            .store(0, Ordering::Release);
        self.mutation_provenance
            .store(provenance as u8, Ordering::Release);
        // The set itself is kept; only its entries are removed.
        if let Some(set) = self.validation_set.get() {
            set.clear();
        }
        let _ = self.hash.take();
        let _ = self.fingerprint.take();
        drop(self.output_buffer_index.take());
        let _ = self.has_indirect_dispatch.take();
        drop(self.stats.take());
    }
925
    /// Returns the hash bytes of the program as it is right now.
    ///
    /// This calls `compute_wire_hash` directly on every invocation instead of
    /// reading the cached `fingerprint`.
    /// NOTE(review): presumably so that changes which bypassed cache
    /// invalidation are still detected — confirm intent.
    fn current_validation_fingerprint(&self) -> [u8; 32] {
        *self.compute_wire_hash().as_bytes()
    }
929
930 fn current_validation_fingerprint_token(&self) -> u64 {
931 let bytes = self.current_validation_fingerprint();
932 let token = u64::from_le_bytes([
933 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
934 ]);
935 token.max(1)
936 }
937
938 #[inline]
939 pub(super) fn wrap_entry(entry: Vec<Node>) -> Vec<Node> {
940 if !Self::entry_needs_root_region(&entry) {
941 return entry;
942 }
943 vec![Node::Region {
944 generator: Ident::from(Self::ROOT_REGION_GENERATOR),
945 source_region: None,
946 body: Arc::new(entry),
947 }]
948 }
949
950 #[inline]
951 fn entry_needs_root_region(entry: &[Node]) -> bool {
952 entry.is_empty()
953 || entry
954 .iter()
955 .any(|node| !matches!(node, Node::Region { .. }))
956 }
957
958 #[inline]
959 fn top_level_node_name(node: &Node) -> &'static str {
960 match node {
961 Node::Let { .. } => "Let",
962 Node::Assign { .. } => "Assign",
963 Node::Store { .. } => "Store",
964 Node::If { .. } => "If",
965 Node::Loop { .. } => "Loop",
966 Node::Return => "Return",
967 Node::Block(_) => "Block",
968 Node::Barrier { .. } => "Barrier",
969 Node::Region { .. } => "Region",
970 Node::IndirectDispatch { .. } => "IndirectDispatch",
971 Node::AsyncLoad { .. } => "AsyncLoad",
972 Node::AsyncStore { .. } => "AsyncStore",
973 Node::AsyncWait { .. } => "AsyncWait",
974 Node::Trap { .. } => "Trap",
975 Node::Resume { .. } => "Resume",
976 Node::AllReduce { .. } => "AllReduce",
977 Node::AllGather { .. } => "AllGather",
978 Node::ReduceScatter { .. } => "ReduceScatter",
979 Node::Broadcast { .. } => "Broadcast",
980 Node::Opaque(_) => "Opaque",
981 }
982 }
983}
984
985pub(crate) fn buffers_equal_ignoring_declaration_order(
986 left: &[super::BufferDecl],
987 right: &[super::BufferDecl],
988) -> bool {
989 if left.len() != right.len() {
990 return false;
991 }
992
993 if left == right {
1000 return true;
1001 }
1002
1003 let mut left_keys = Vec::with_capacity(left.len());
1004 left_keys.extend(left.iter().map(buffer_decl_canonical_key));
1005 let mut right_keys = Vec::with_capacity(right.len());
1006 right_keys.extend(right.iter().map(buffer_decl_canonical_key));
1007 left_keys.sort_unstable();
1008 right_keys.sort_unstable();
1009 left_keys == right_keys
1010}
1011
/// Serialises one buffer declaration into a byte key that can be sorted and
/// compared.
///
/// Fields are appended in a fixed order: length-prefixed name, binding,
/// access tag, memory kind, element data type, count, the `is_output` and
/// `pipeline_live_out` flags, the optional output byte range, the layout
/// hints (coalesce axis, preferred alignment, cache locality), and the
/// `bytes_extraction` flag.
///
/// Encoding failures do not abort: a marker and the error text are appended
/// in place of the field, so the function always returns a key.
pub(super) fn buffer_decl_canonical_key(buffer: &super::BufferDecl) -> Vec<u8> {
    use crate::serial::wire::framing::{put_len_u32, put_u32, put_u8};
    use crate::serial::wire::tags::put_data_type;

    let mut key = Vec::with_capacity(96);
    // Name: length prefix, or an error marker if the length cannot be encoded.
    if let Err(error) = put_len_u32(&mut key, buffer.name.len(), "buffer name length") {
        key.extend_from_slice(b"\0name-length-error\0");
        key.extend_from_slice(error.as_bytes());
    }
    key.extend_from_slice(buffer.name.as_bytes());
    put_u32(&mut key, buffer.binding);
    // Access: wire tag, or `u8::MAX` followed by the error text.
    match crate::serial::wire::tags::access_tag::access_tag(&buffer.access) {
        Ok(tag) => put_u8(&mut key, tag),
        Err(error) => {
            put_u8(&mut key, u8::MAX);
            key.extend_from_slice(error.as_bytes());
        }
    }
    // Memory kind: one byte per variant, assigned locally by this match.
    put_u8(
        &mut key,
        match buffer.kind {
            super::MemoryKind::Global => 0,
            super::MemoryKind::Shared => 1,
            super::MemoryKind::Uniform => 2,
            super::MemoryKind::Local => 3,
            super::MemoryKind::Readonly => 4,
            super::MemoryKind::Persistent => 5,
            super::MemoryKind::Push => 6,
        },
    );
    if let Err(error) = put_data_type(&mut key, &buffer.element) {
        key.extend_from_slice(b"\0dtype-error\0");
        key.extend_from_slice(error.as_bytes());
    }
    put_u32(&mut key, buffer.count);
    put_u8(&mut key, u8::from(buffer.is_output));
    put_u8(&mut key, u8::from(buffer.pipeline_live_out));
    // Output byte range: presence byte, then start and end as `u32`. A bound
    // that does not fit in `u32` is written as `u32::MAX` plus the error text.
    match &buffer.output_byte_range {
        Some(range) => {
            put_u8(&mut key, 1);
            match u32::try_from(range.start) {
                Ok(start) => put_u32(&mut key, start),
                Err(error) => {
                    put_u32(&mut key, u32::MAX);
                    key.extend_from_slice(error.to_string().as_bytes());
                }
            }
            match u32::try_from(range.end) {
                Ok(end) => put_u32(&mut key, end),
                Err(error) => {
                    put_u32(&mut key, u32::MAX);
                    key.extend_from_slice(error.to_string().as_bytes());
                }
            }
        }
        None => put_u8(&mut key, 0),
    }
    // Coalesce axis hint: presence byte, then the axis when present.
    match buffer.hints.coalesce_axis {
        Some(axis) => {
            put_u8(&mut key, 1);
            put_u8(&mut key, axis);
        }
        None => put_u8(&mut key, 0),
    }
    put_u32(&mut key, buffer.hints.preferred_alignment);
    // Cache locality hint: one byte per variant, assigned locally by this match.
    put_u8(
        &mut key,
        match buffer.hints.cache_locality {
            super::CacheLocality::Streaming => 0,
            super::CacheLocality::Temporal => 1,
            super::CacheLocality::Random => 2,
        },
    );
    put_u8(&mut key, u8::from(buffer.bytes_extraction));
    key
}