Skip to main content

vyre_foundation/ir_inner/model/program/
meta.rs

1use std::hash::{Hash, Hasher as _};
2use std::sync::atomic::Ordering;
3use std::sync::Arc;
4
5use rustc_hash::FxHasher;
6use vyre_spec::bin_op::OpIntensity;
7
8use crate::ir::{Expr, Node};
9use crate::ir_inner::model::expr::Ident;
10use crate::ir_inner::model::types::BufferAccess;
11use crate::transform::visit::{walk_nodes_and_exprs, ExprVisitor, NodeVisitor};
12
13use super::Program;
14
15fn mix_wire_fallback_hashable<T: Hash>(hasher: &mut blake3::Hasher, value: &T) {
16    let mut state = FxHasher::default();
17    value.hash(&mut state);
18    hasher.update(&state.finish().to_le_bytes());
19}
20
21/// Bounded IR structure digest for wire-hash fallback (never formats full IR via `Debug`).
22struct FallbackWireHasher<'a>(&'a mut blake3::Hasher);
23
24impl NodeVisitor for FallbackWireHasher<'_> {
25    fn visit_node(&mut self, node: &Node) {
26        let h = &mut *self.0;
27        match node {
28            Node::Let { name, .. } => {
29                h.update(b"n:Let\0");
30                h.update(name.as_bytes());
31            }
32            Node::Assign { name, .. } => {
33                h.update(b"n:Assign\0");
34                h.update(name.as_bytes());
35            }
36            Node::Store { buffer, .. } => {
37                h.update(b"n:Store\0");
38                h.update(buffer.as_bytes());
39            }
40            Node::If { .. } => {
41                h.update(b"n:If\0");
42            }
43            Node::Loop { var, .. } => {
44                h.update(b"n:Loop\0");
45                h.update(var.as_bytes());
46            }
47            Node::IndirectDispatch {
48                count_buffer,
49                count_offset,
50            } => {
51                h.update(b"n:IndirectDispatch\0");
52                h.update(count_buffer.as_bytes());
53                h.update(&count_offset.to_le_bytes());
54            }
55            Node::AsyncLoad {
56                source,
57                destination,
58                tag,
59                ..
60            } => {
61                h.update(b"n:AsyncLoad\0");
62                h.update(source.as_bytes());
63                h.update(destination.as_bytes());
64                h.update(tag.as_bytes());
65            }
66            Node::AsyncStore {
67                source,
68                destination,
69                tag,
70                ..
71            } => {
72                h.update(b"n:AsyncStore\0");
73                h.update(source.as_bytes());
74                h.update(destination.as_bytes());
75                h.update(tag.as_bytes());
76            }
77            Node::AsyncWait { tag } => {
78                h.update(b"n:AsyncWait\0");
79                h.update(tag.as_bytes());
80            }
81            Node::Trap { tag, .. } => {
82                h.update(b"n:Trap\0");
83                h.update(tag.as_bytes());
84            }
85            Node::Resume { tag } => {
86                h.update(b"n:Resume\0");
87                h.update(tag.as_bytes());
88            }
89            Node::AllReduce { buffer, op, group } => {
90                h.update(b"n:AllReduce\0");
91                h.update(buffer.as_bytes());
92                h.update(&op.builtin_wire_tag().to_le_bytes());
93                h.update(&group.as_u32().to_le_bytes());
94            }
95            Node::AllGather {
96                input,
97                output,
98                group,
99            } => {
100                h.update(b"n:AllGather\0");
101                h.update(input.as_bytes());
102                h.update(output.as_bytes());
103                h.update(&group.as_u32().to_le_bytes());
104            }
105            Node::ReduceScatter {
106                input,
107                output,
108                op,
109                group,
110            } => {
111                h.update(b"n:ReduceScatter\0");
112                h.update(input.as_bytes());
113                h.update(output.as_bytes());
114                h.update(&op.builtin_wire_tag().to_le_bytes());
115                h.update(&group.as_u32().to_le_bytes());
116            }
117            Node::Broadcast {
118                buffer,
119                root,
120                group,
121            } => {
122                h.update(b"n:Broadcast\0");
123                h.update(buffer.as_bytes());
124                h.update(&root.to_le_bytes());
125                h.update(&group.as_u32().to_le_bytes());
126            }
127            Node::Return => {
128                h.update(b"n:Return\0");
129            }
130            Node::Barrier { ordering } => {
131                h.update(b"n:Barrier\0");
132                mix_wire_fallback_hashable(h, ordering);
133            }
134            Node::Block(_) => {
135                h.update(b"n:Block\0");
136            }
137            Node::Region {
138                generator,
139                source_region,
140                ..
141            } => {
142                h.update(b"n:Region\0");
143                h.update(generator.as_bytes());
144                if let Some(source_gen) = source_region {
145                    h.update(source_gen.name.as_bytes());
146                }
147            }
148            Node::Opaque(ext) => {
149                h.update(b"n:Opaque\0");
150                h.update(ext.extension_kind().as_bytes());
151            }
152        }
153    }
154}
155
156impl ExprVisitor for FallbackWireHasher<'_> {
157    fn visit_expr(&mut self, expr: &Expr) {
158        let h = &mut *self.0;
159        match expr {
160            Expr::LitU32(v) => {
161                h.update(b"e:LitU32\0");
162                h.update(&v.to_le_bytes());
163            }
164            Expr::LitI32(v) => {
165                h.update(b"e:LitI32\0");
166                h.update(&v.to_le_bytes());
167            }
168            Expr::LitF32(v) => {
169                h.update(b"e:LitF32\0");
170                h.update(&v.to_le_bytes());
171            }
172            Expr::LitBool(v) => {
173                h.update(b"e:LitBool\0");
174                h.update(&[u8::from(*v)]);
175            }
176            Expr::Var(name) => {
177                h.update(b"e:Var\0");
178                h.update(name.as_bytes());
179            }
180            Expr::Load { buffer, .. } => {
181                h.update(b"e:Load\0");
182                h.update(buffer.as_bytes());
183            }
184            Expr::BufLen { buffer } => {
185                h.update(b"e:BufLen\0");
186                h.update(buffer.as_bytes());
187            }
188            Expr::InvocationId { axis } => {
189                h.update(b"e:InvocationId\0");
190                h.update(&[*axis]);
191            }
192            Expr::WorkgroupId { axis } => {
193                h.update(b"e:WorkgroupId\0");
194                h.update(&[*axis]);
195            }
196            Expr::LocalId { axis } => {
197                h.update(b"e:LocalId\0");
198                h.update(&[*axis]);
199            }
200            Expr::BinOp { op, .. } => {
201                h.update(b"e:BinOp\0");
202                mix_wire_fallback_hashable(h, op);
203            }
204            Expr::UnOp { op, .. } => {
205                h.update(b"e:UnOp\0");
206                mix_wire_fallback_hashable(h, op);
207            }
208            Expr::Call { op_id, .. } => {
209                h.update(b"e:Call\0");
210                h.update(op_id.as_bytes());
211            }
212            Expr::Select { .. } => {
213                h.update(b"e:Select\0");
214            }
215            Expr::Cast { target, .. } => {
216                h.update(b"e:Cast\0");
217                mix_wire_fallback_hashable(h, target);
218            }
219            Expr::Fma { .. } => {
220                h.update(b"e:Fma\0");
221            }
222            Expr::Atomic {
223                op,
224                buffer,
225                ordering,
226                ..
227            } => {
228                h.update(b"e:Atomic\0");
229                mix_wire_fallback_hashable(h, op);
230                h.update(buffer.as_bytes());
231                mix_wire_fallback_hashable(h, ordering);
232            }
233            Expr::SubgroupBallot { .. } => {
234                h.update(b"e:SubgroupBallot\0");
235            }
236            Expr::SubgroupShuffle { .. } => {
237                h.update(b"e:SubgroupShuffle\0");
238            }
239            Expr::SubgroupAdd { .. } => {
240                h.update(b"e:SubgroupAdd\0");
241            }
242            Expr::SubgroupLocalId => {
243                h.update(b"e:SubgroupLocalId\0");
244            }
245            Expr::SubgroupSize => {
246                h.update(b"e:SubgroupSize\0");
247            }
248            Expr::Opaque(ext) => {
249                h.update(b"e:Opaque\0");
250                h.update(ext.extension_kind().as_bytes());
251            }
252        }
253    }
254}
255
256impl Program {
257    /// Re-apply the same top-level `Node::Region` contract as
258    /// [`Program::wrapped`].
259    ///
260    /// The [`region_inline_engine`](crate::optimizer::passes::cleanup::region_inline_engine)
261    /// pass flattens small Category-A regions so CSE/DCE can see a single
262    /// function-shaped body, which can leave a statement-shaped entry list. The
263    /// standard optimizer run ends with this helper so the program remains in
264    /// a runnable, validator/reference-interpreter–compatible form while
265    /// still benefiting from the inline pass.
266    #[must_use]
267    pub fn reconcile_runnable_top_level(self) -> Self {
268        if self.is_top_level_region_wrapped() {
269            return self;
270        }
271        // Move the entry Vec out via map_entry's Arc-aware path; one
272        // Program rebuild instead of two scaffold rebuilds.
273        self.map_entry(Self::wrap_entry)
274    }
275
276    /// Look up a buffer declaration by name.
277    #[must_use]
278    #[inline]
279    pub fn buffer(&self, name: &str) -> Option<&super::BufferDecl> {
280        self.buffer_index
281            .get(name)
282            .and_then(|&index| self.buffers.get(index))
283    }
284
285    /// Declared buffers.
286    #[must_use]
287    #[inline]
288    pub fn buffers(&self) -> &[super::BufferDecl] {
289        self.buffers.as_ref()
290    }
291
292    /// Access the buffer declaration Arc directly for identity checks.
293    #[must_use]
294    #[inline]
295    #[cfg(test)]
296    pub(crate) fn buffers_arc(&self) -> &Arc<[super::BufferDecl]> {
297        &self.buffers
298    }
299
300    /// Compare two programs by observable IR structure.
301    ///
302    /// This walk intentionally ignores buffer declaration order and never
303    /// consults arena-local allocation identity. Two programs are structurally
304    /// equal when they declare the same buffers, workgroup size, optional entry
305    /// op id, and entry body semantics.
306    #[must_use]
307    #[inline]
308    pub fn structural_eq(&self, other: &Self) -> bool {
309        // Identity short-circuit: Program::clone shares all the
310        // inner Arcs, so comparing a cloned program against its
311        // source (the common optimizer-pipeline pattern) is pure
312        // refcount comparison.
313        if std::ptr::eq(self, other)
314            || (Arc::ptr_eq(&self.buffers, &other.buffers)
315                && Arc::ptr_eq(&self.entry, &other.entry)
316                && self.entry_op_id == other.entry_op_id
317                && self.non_composable_with_self == other.non_composable_with_self
318                && self.workgroup_size == other.workgroup_size)
319        {
320            return true;
321        }
322        self.entry_op_id == other.entry_op_id
323            && self.non_composable_with_self == other.non_composable_with_self
324            && buffers_equal_ignoring_declaration_order(&self.buffers, &other.buffers)
325            && self.workgroup_size == other.workgroup_size
326            && self.entry == other.entry
327    }
328
329    /// Workgroup dimensions.
330    #[must_use]
331    #[inline]
332    pub fn workgroup_size(&self) -> [u32; 3] {
333        self.workgroup_size
334    }
335
336    /// Substrate-neutral alias for [`workgroup_size`](Self::workgroup_size).
337    ///
338    /// Naming: "parallel region" avoids picking a single target substrate's
339    /// word for one dispatch invocation grouping.
340    #[must_use]
341    #[inline]
342    pub fn parallel_region_size(&self) -> [u32; 3] {
343        self.workgroup_size
344    }
345
346    /// Return true when this program must not be fused with another copy
347    /// of itself in the same megakernel.
348    #[must_use]
349    #[inline]
350    pub fn is_non_composable_with_self(&self) -> bool {
351        self.non_composable_with_self
352    }
353
354    /// Mark this program as non-composable with itself.
355    #[must_use]
356    #[inline]
357    pub fn with_non_composable_with_self(mut self, flag: bool) -> Self {
358        self.non_composable_with_self = flag;
359        self.invalidate_caches();
360        self
361    }
362
363    /// Set the workgroup dimensions in place. Used by harnesses that
364    /// need to clone-and-rewrite a program's workgroup size for fallback
365    /// dispatch  -  the alternative was to reconstruct the entire Program,
366    /// which is unnecessarily expensive when only one field changes.
367    #[inline]
368    pub fn set_workgroup_size(&mut self, workgroup_size: [u32; 3]) {
369        self.workgroup_size = workgroup_size;
370        self.invalidate_caches();
371    }
372
373    /// Substrate-neutral alias for [`set_workgroup_size`](Self::set_workgroup_size).
374    #[inline]
375    pub fn set_parallel_region_size(&mut self, parallel_region_size: [u32; 3]) {
376        self.workgroup_size = parallel_region_size;
377        self.invalidate_caches();
378    }
379
380    /// Entry-point nodes.
381    #[must_use]
382    #[inline]
383    pub fn entry(&self) -> &[Node] {
384        self.entry.as_ref().as_slice()
385    }
386
387    /// Shared entry-point body Arc for identity checks.
388    #[must_use]
389    #[inline]
390    pub fn entry_arc(&self) -> &Arc<Vec<Node>> {
391        &self.entry
392    }
393
394    /// Return true when this Program is the canonical no-op shape produced by
395    /// [`Program::empty`]: no buffers and a single empty root Region.
396    #[must_use]
397    #[inline]
398    pub fn is_explicit_noop(&self) -> bool {
399        self.buffers().is_empty()
400            && matches!(self.entry(), [Node::Region { body, .. }] if body.is_empty())
401    }
402
403    /// Return true when the program satisfies the top-level region-chain
404    /// invariant: at least one top-level node, and every top-level node is a
405    /// `Node::Region`.
406    #[must_use]
407    #[inline]
408    pub fn is_top_level_region_wrapped(&self) -> bool {
409        !self.entry.is_empty()
410            && self
411                .entry()
412                .iter()
413                .all(|node| matches!(node, Node::Region { .. }))
414    }
415
416    /// Actionable error text describing why the top-level region invariant
417    /// failed, or `None` when the entry is valid.
418    #[must_use]
419    pub fn top_level_region_violation(&self) -> Option<String> {
420        if self.entry().is_empty() {
421            return Some(
422                "program entry has no top-level Region. Fix: construct runnable programs with Program::wrapped(...) or wrap the body in Node::Region before validation, interpretation, or dispatch."
423                    .to_string(),
424            );
425        }
426
427        self.entry()
428            .iter()
429            .enumerate()
430            .find(|(_, node)| !matches!(node, Node::Region { .. }))
431            .map(|(index, node)| {
432                format!(
433                    "program entry node {index} is `{}` instead of `Node::Region`. Fix: construct runnable programs with Program::wrapped(...) or wrap the top-level body in Node::Region; raw Program::new is reserved for wire decode and negative tests.",
434                    Self::top_level_node_name(node)
435                )
436            })
437    }
438
439    /// Mutable entry-point nodes for transformation passes.
440    #[must_use]
441    #[inline]
442    pub fn entry_mut(&mut self) -> &mut Vec<Node> {
443        self.invalidate_caches();
444        Arc::make_mut(&mut self.entry)
445    }
446
447    /// Stable BLAKE3 fingerprint of the canonical wire-format bytes.
448    #[must_use]
449    #[inline]
450    pub fn fingerprint(&self) -> [u8; 32] {
451
452        *self.fingerprint.get_or_init(|| {
453            let hash = self.compute_wire_hash();
454            let _ = self.hash.set(hash);
455            *hash.as_bytes()
456        })
457    }
458
459    /// VSA-style hypervector fingerprint of the canonical wire-format
460    /// bytes. Each `u32` lane is one segment of the program's blake3
461    /// hash; together they form an 8-lane hypervector suitable for
462    /// approximate similarity search via hamming distance.
463    ///
464    /// Use as the canonical cache key for approximate-match caches
465    /// (e.g. validation cache, AOT artifact dedup); use
466    /// [`Self::fingerprint`] for exact-match lookups.
467    ///
468    /// Wires the substrate's #29 hypervector primitive into Program
469    /// itself  -  every Program now carries its own VSA fingerprint
470    /// without callers having to reach into the substrate explicitly.
471    #[must_use]
472    pub fn vsa_fingerprint(&self) -> Vec<u32> {
473        self.fingerprint()
474            .chunks_exact(core::mem::size_of::<u32>())
475            .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
476            .collect()
477    }
478
479    /// Indices of read-write buffers in `buffers()` order.
480    #[must_use]
481    #[inline]
482    pub fn output_buffer_indices(&self) -> &[u32] {
483        self.output_buffer_index
484            .get_or_init(|| {
485                Arc::new(
486                    self.buffers()
487                        .iter()
488                        .enumerate()
489                        .filter_map(|(index, buffer)| {
490                            (buffer.access() == BufferAccess::ReadWrite)
491                                .then(|| u32::try_from(index).ok())
492                                .flatten()
493                        })
494                        .collect(),
495                )
496            })
497            .as_slice()
498    }
499
500    /// True when the entry walk discovers any indirect dispatch node.
501    #[must_use]
502    #[inline]
503    pub fn has_indirect_dispatch(&self) -> bool {
504        *self.has_indirect_dispatch.get_or_init(|| {
505            // Fast-path: ProgramStats records every node kind seen during
506            // its single-pass walk. If the IndirectDispatch bit is unset,
507            // the tree definitely contains no IndirectDispatch nodes and
508            // the explicit traversal below would redundantly visit every
509            // node only to return false. Reading the bit is O(1).
510            if !self
511                .stats()
512                .has_any_node_kind(super::stats::NODE_KIND_INDIRECT_DISPATCH)
513            {
514                return false;
515            }
516            let mut stack: smallvec::SmallVec<[&Node; 32]> = self.entry().iter().rev().collect();
517            while let Some(node) = stack.pop() {
518                match node {
519                    Node::IndirectDispatch { .. } => return true,
520                    Node::If {
521                        then, otherwise, ..
522                    } => {
523                        stack.extend(otherwise.iter().rev());
524                        stack.extend(then.iter().rev());
525                    }
526                    Node::Loop { body, .. } | Node::Block(body) => {
527                        stack.extend(body.iter().rev());
528                    }
529                    Node::Region { body, .. } => {
530                        stack.extend(body.iter().rev());
531                    }
532                    Node::Let { .. }
533                    | Node::Assign { .. }
534                    | Node::Store { .. }
535                    | Node::AllReduce { .. }
536                    | Node::AllGather { .. }
537                    | Node::ReduceScatter { .. }
538                    | Node::Broadcast { .. }
539                    | Node::Return
540                    | Node::Barrier { .. }
541                    | Node::AsyncLoad { .. }
542                    | Node::AsyncStore { .. }
543                    | Node::AsyncWait { .. }
544                    | Node::Trap { .. }
545                    | Node::Resume { .. }
546                    | Node::Opaque(_) => {}
547                }
548            }
549            false
550        })
551    }
552
553    /// Check whether a named buffer exists.
554    #[must_use]
555    #[inline]
556    pub fn has_buffer(&self, name: &str) -> bool {
557        self.buffer_index.contains_key(name)
558    }
559
560    /// Number of declared buffers.
561    #[must_use]
562    #[inline]
563    pub fn buffer_count(&self) -> usize {
564        self.buffers.len()
565    }
566
567    #[inline]
568    pub(super) fn build_buffer_index(
569        buffers: &[super::BufferDecl],
570    ) -> rustc_hash::FxHashMap<Arc<str>, usize> {
571        let mut index = rustc_hash::FxHashMap::default();
572        index.reserve(buffers.len());
573        for (buffer_index, buffer) in buffers.iter().enumerate() {
574            index
575                .entry(Arc::clone(&buffer.name))
576                .or_insert(buffer_index);
577        }
578        index
579    }
580
581    /// Mark this program as successfully validated structurally.
582    #[inline]
583    pub fn mark_structurally_validated(&self) {
584        self.structural_validated.store(true, Ordering::Release);
585    }
586
587    /// Return true once structural validation has succeeded for this program shape.
588    #[must_use]
589    #[inline]
590    pub fn is_structurally_validated(&self) -> bool {
591        self.structural_validated.load(Ordering::Acquire)
592    }
593
594    /// Mark this program as successfully validated for a specific backend.
595    #[inline]
596    pub fn mark_validated_on(&self, backend_id: &str) {
597        self.validation_set
598            .get_or_init(|| Arc::new(dashmap::DashSet::new()))
599            .insert(Arc::from(self.validation_cache_key(backend_id)));
600    }
601
602    /// Return true if this program has been validated for the given backend.
603    #[must_use]
604    #[inline]
605    pub fn is_validated_on(&self, backend_id: &str) -> bool {
606        self.validation_set
607            .get()
608            .is_some_and(|set| set.contains(self.validation_cache_key(backend_id).as_str()))
609    }
610
611    /// Deprecated: use `is_structurally_validated` or `is_validated_on`.
612    #[deprecated(note = "use is_structurally_validated or is_validated_on")]
613    #[must_use]
614    #[inline]
615    pub fn is_validated(&self) -> bool {
616        self.is_structurally_validated()
617    }
618
619    /// Deprecated: use `mark_structurally_validated` or `mark_validated_on`.
620    #[deprecated(note = "use mark_structurally_validated or mark_validated_on")]
621    #[inline]
622    pub fn mark_validated(&self) {
623        self.mark_structurally_validated();
624    }
625
626    /// Validate the program and cache the successful result on the program.
627    ///
628    /// # Errors
629    ///
630    /// Returns [`crate::Error::WireFormatValidation`] with every validation
631    /// message joined when the structural validator rejects the program.
632    pub fn validate(&self) -> crate::error::Result<()> {
633        if self.is_structurally_validated() {
634            return Ok(());
635        }
636        let errors = crate::validate::validate(self);
637        if errors.is_empty() {
638            self.mark_structurally_validated();
639            return Ok(());
640        }
641        let mut message = String::new();
642        for (index, error) in errors.into_iter().enumerate() {
643            if index > 0 {
644                message.push_str("; ");
645            }
646            message.push_str(error.message());
647        }
648        Err(crate::error::Error::WireFormatValidation { message })
649    }
650
651    #[inline]
652    /// Estimate the peak VRAM byte size of this Program.
653    ///
654    /// Innovation I.11: Static VRAM Pressure Analysis.
655    /// Returns the total bytes required by all storage and uniform buffers
656    /// declared in the Program. Optimizer passes use this to automatically
657    /// partition workloads if they would exceed a backend-specific safety
658    /// margin.
659    #[must_use]
660    pub fn estimate_peak_vram_bytes(&self) -> u64 {
661        self.buffers
662            .iter()
663            .map(|buffer| {
664                let Some(element_size) = buffer.element.size_bytes() else {
665                    return u64::MAX;
666                };
667                u64::from(buffer.count)
668                    .saturating_mul(u64::try_from(element_size).unwrap_or(u64::MAX))
669            })
670            .fold(0u64, u64::saturating_add)
671    }
672
673    /// Return the peak computational intensity found in any instruction.
674    #[must_use]
675    pub fn peak_intensity(&self) -> OpIntensity {
676        let mut peak = OpIntensity::Free;
677        for node in self.entry() {
678            peak = peak.max(Self::node_intensity(node));
679        }
680        peak
681    }
682
683    fn node_intensity(node: &crate::ir::Node) -> OpIntensity {
684        use crate::ir::Node;
685        match node {
686            Node::Let { value, .. } | Node::Assign { value, .. } => Self::expr_intensity(value),
687            Node::Store { index, value, .. } => {
688                Self::expr_intensity(index).max(Self::expr_intensity(value))
689            }
690            Node::If {
691                cond,
692                then,
693                otherwise,
694            } => {
695                let mut p = Self::expr_intensity(cond);
696                for n in then {
697                    p = p.max(Self::node_intensity(n));
698                }
699                for n in otherwise {
700                    p = p.max(Self::node_intensity(n));
701                }
702                p
703            }
704            Node::Loop { from, to, body, .. } => {
705                let mut p = Self::expr_intensity(from).max(Self::expr_intensity(to));
706                for n in body {
707                    p = p.max(Self::node_intensity(n));
708                }
709                p
710            }
711            Node::Block(nodes) => {
712                let mut p = OpIntensity::Free;
713                for n in nodes {
714                    p = p.max(Self::node_intensity(n));
715                }
716                p
717            }
718            Node::Region { body, .. } => {
719                let mut p = OpIntensity::Free;
720                for n in body.iter() {
721                    p = p.max(Self::node_intensity(n));
722                }
723                p
724            }
725            _ => OpIntensity::Free,
726        }
727    }
728
729    fn expr_intensity(expr: &crate::ir::Expr) -> OpIntensity {
730        use crate::ir::Expr;
731        match expr {
732            Expr::BinOp { op, left, right } => op
733                .intensity()
734                .max(Self::expr_intensity(left))
735                .max(Self::expr_intensity(right)),
736            Expr::UnOp { operand, .. } => Self::expr_intensity(operand),
737            Expr::Load { index, .. } => Self::expr_intensity(index),
738            Expr::Select {
739                cond,
740                true_val,
741                false_val,
742            } => Self::expr_intensity(cond)
743                .max(Self::expr_intensity(true_val))
744                .max(Self::expr_intensity(false_val)),
745            Expr::Cast { value, .. } => Self::expr_intensity(value),
746            Expr::Fma { a, b, c } => Self::expr_intensity(a)
747                .max(Self::expr_intensity(b))
748                .max(Self::expr_intensity(c)),
749            Expr::Atomic {
750                index,
751                value,
752                expected,
753                ..
754            } => {
755                let mut p = Self::expr_intensity(index).max(Self::expr_intensity(value));
756                if let Some(e) = expected {
757                    p = p.max(Self::expr_intensity(e));
758                }
759                p.max(OpIntensity::Heavy)
760            }
761            Expr::SubgroupBallot { cond } => Self::expr_intensity(cond).max(OpIntensity::Heavy),
762            Expr::SubgroupShuffle { value, lane } => Self::expr_intensity(value)
763                .max(Self::expr_intensity(lane))
764                .max(OpIntensity::Heavy),
765            Expr::SubgroupAdd { value } => Self::expr_intensity(value).max(OpIntensity::Heavy),
766            _ => OpIntensity::Free,
767        }
768    }
769
770    fn compute_wire_hash(&self) -> blake3::Hash {
771        match self.canonical_wire_hash() {
772            Ok(hash) => hash,
773            Err(error) => {
774                let structural = self.structural_fingerprint_fallback();
775                let err_msg = error.to_string();
776                let mut fallback = Vec::with_capacity(96 + err_msg.len() + structural.len());
777                fallback.extend_from_slice(b"VYRE-PROGRAM-CANONICAL-WIRE-HASH-ERROR\0");
778                fallback.extend_from_slice(err_msg.as_bytes());
779                fallback.push(0);
780                fallback.extend_from_slice(structural.as_bytes());
781                blake3::hash(&fallback)
782            }
783        }
784    }
785
786    fn structural_fingerprint_fallback(&self) -> String {
787        let mut hasher = blake3::Hasher::new();
788        hasher.update(b"VYRE-WIRE-FALLBACK-V4\0");
789        if let Some(id) = self.entry_op_id.as_deref() {
790            hasher.update(id.as_bytes());
791        }
792        hasher.update(b"\0");
793        for axis in &self.workgroup_size {
794            hasher.update(&axis.to_le_bytes());
795        }
796        hasher.update(&[u8::from(self.non_composable_with_self)]);
797        let mut keys: Vec<Vec<u8>> = self
798            .buffers()
799            .iter()
800            .map(buffer_decl_canonical_key)
801            .collect();
802        keys.sort_unstable();
803        for key in keys {
804            hasher.update(&key);
805        }
806        let mut visitor = FallbackWireHasher(&mut hasher);
807        walk_nodes_and_exprs(self, &mut visitor);
808        hasher.finalize().to_hex().to_string()
809    }
810
811    fn validation_cache_key(&self, backend_id: &str) -> String {
812        const HEX: &[u8; 16] = b"0123456789abcdef";
813        let fingerprint = self.fingerprint();
814        let mut key = String::with_capacity(backend_id.len() + 1 + 64);
815        key.push_str(backend_id);
816        key.push(':');
817        for &byte in &fingerprint {
818            key.push(HEX[(byte >> 4) as usize] as char);
819            key.push(HEX[(byte & 0x0f) as usize] as char);
820        }
821        key
822    }
823
    /// Resets every memoized value derived from the program so that it is recomputed on
    /// next use.
    ///
    /// This clears the structural-validation flag, empties the per-backend validation set
    /// (if it has been created), and discards the cached hash, fingerprint, output-buffer
    /// index, indirect-dispatch flag, and stats.
    #[inline]
    pub(super) fn invalidate_caches(&mut self) {
        // Mark the program as no longer structurally validated.
        // NOTE(review): Release ordering is used although `&mut self` already guarantees
        // exclusive access here; presumably it pairs with Acquire loads elsewhere — confirm.
        self.structural_validated.store(false, Ordering::Release);
        // The set is only cleared in place when it already exists; `get()` returning an
        // `Option` suggests a lazily-initialized cell (TODO confirm type).
        if let Some(set) = self.validation_set.get() {
            set.clear();
        }
        // NOTE(review): `let _ =` vs `drop(...)` is likely deliberate — the `let _` slots
        // presumably hold Copy values, for which `drop(...)` triggers the
        // `dropping_copy_types` lint. Confirm before unifying the style.
        let _ = self.hash.take();
        let _ = self.fingerprint.take();
        drop(self.output_buffer_index.take());
        let _ = self.has_indirect_dispatch.take();
        drop(self.stats.take());
    }
836
837    #[inline]
838    pub(super) fn wrap_entry(entry: Vec<Node>) -> Vec<Node> {
839        if !Self::entry_needs_root_region(&entry) {
840            return entry;
841        }
842        vec![Node::Region {
843            generator: Ident::from(Self::ROOT_REGION_GENERATOR),
844            source_region: None,
845            body: Arc::new(entry),
846        }]
847    }
848
849    #[inline]
850    fn entry_needs_root_region(entry: &[Node]) -> bool {
851        entry.is_empty()
852            || entry
853                .iter()
854                .any(|node| !matches!(node, Node::Region { .. }))
855    }
856
857    #[inline]
858    fn top_level_node_name(node: &Node) -> &'static str {
859        match node {
860            Node::Let { .. } => "Let",
861            Node::Assign { .. } => "Assign",
862            Node::Store { .. } => "Store",
863            Node::If { .. } => "If",
864            Node::Loop { .. } => "Loop",
865            Node::Return => "Return",
866            Node::Block(_) => "Block",
867            Node::Barrier { .. } => "Barrier",
868            Node::Region { .. } => "Region",
869            Node::IndirectDispatch { .. } => "IndirectDispatch",
870            Node::AsyncLoad { .. } => "AsyncLoad",
871            Node::AsyncStore { .. } => "AsyncStore",
872            Node::AsyncWait { .. } => "AsyncWait",
873            Node::Trap { .. } => "Trap",
874            Node::Resume { .. } => "Resume",
875            Node::AllReduce { .. } => "AllReduce",
876            Node::AllGather { .. } => "AllGather",
877            Node::ReduceScatter { .. } => "ReduceScatter",
878            Node::Broadcast { .. } => "Broadcast",
879            Node::Opaque(_) => "Opaque",
880        }
881    }
882}
883
884pub(crate) fn buffers_equal_ignoring_declaration_order(
885    left: &[super::BufferDecl],
886    right: &[super::BufferDecl],
887) -> bool {
888    if left.len() != right.len() {
889        return false;
890    }
891
892    // VYRE_IR_HOTSPOTS HIGH (meta.rs:360-379): previous impl allocated
893    // two Vec<Vec<u8>> then sorted on every equality call. Fast-path:
894    // if the slices compare equal in-place (declaration orders match)
895    // we skip the key-materialization entirely. This catches every
896    // Program::clone(prog) == prog and every `Arc::clone`-equivalent
897    // comparison, which dominate the call distribution.
898    if left == right {
899        return true;
900    }
901
902
903    let mut left_keys = Vec::with_capacity(left.len());
904    left_keys.extend(left.iter().map(buffer_decl_canonical_key));
905    let mut right_keys = Vec::with_capacity(right.len());
906    right_keys.extend(right.iter().map(buffer_decl_canonical_key));
907    left_keys.sort_unstable();
908    right_keys.sort_unstable();
909    left_keys == right_keys
910}
911
912pub(super) fn buffer_decl_canonical_key(buffer: &super::BufferDecl) -> Vec<u8> {
913    use crate::serial::wire::framing::{put_len_u32, put_u32, put_u8};
914    use crate::serial::wire::tags::put_data_type;
915
916    let mut key = Vec::with_capacity(96);
917    if let Err(error) = put_len_u32(&mut key, buffer.name.len(), "buffer name length") {
918        key.extend_from_slice(b"\0name-length-error\0");
919        key.extend_from_slice(error.as_bytes());
920    }
921    key.extend_from_slice(buffer.name.as_bytes());
922    put_u32(&mut key, buffer.binding);
923    match crate::serial::wire::tags::access_tag::access_tag(&buffer.access) {
924        Ok(tag) => put_u8(&mut key, tag),
925        Err(error) => {
926            put_u8(&mut key, u8::MAX);
927            key.extend_from_slice(error.as_bytes());
928        }
929    }
930    put_u8(
931        &mut key,
932        match buffer.kind {
933            super::MemoryKind::Global => 0,
934            super::MemoryKind::Shared => 1,
935            super::MemoryKind::Uniform => 2,
936            super::MemoryKind::Local => 3,
937            super::MemoryKind::Readonly => 4,
938            super::MemoryKind::Persistent => 5,
939            super::MemoryKind::Push => 6,
940        },
941    );
942    if let Err(error) = put_data_type(&mut key, &buffer.element) {
943        key.extend_from_slice(b"\0dtype-error\0");
944        key.extend_from_slice(error.as_bytes());
945    }
946    put_u32(&mut key, buffer.count);
947    put_u8(&mut key, u8::from(buffer.is_output));
948    put_u8(&mut key, u8::from(buffer.pipeline_live_out));
949    match &buffer.output_byte_range {
950        Some(range) => {
951            put_u8(&mut key, 1);
952            match u32::try_from(range.start) {
953                Ok(start) => put_u32(&mut key, start),
954                Err(error) => {
955                    put_u32(&mut key, u32::MAX);
956                    key.extend_from_slice(error.to_string().as_bytes());
957                }
958            }
959            match u32::try_from(range.end) {
960                Ok(end) => put_u32(&mut key, end),
961                Err(error) => {
962                    put_u32(&mut key, u32::MAX);
963                    key.extend_from_slice(error.to_string().as_bytes());
964                }
965            }
966        }
967        None => put_u8(&mut key, 0),
968    }
969    match buffer.hints.coalesce_axis {
970        Some(axis) => {
971            put_u8(&mut key, 1);
972            put_u8(&mut key, axis);
973        }
974        None => put_u8(&mut key, 0),
975    }
976    put_u32(&mut key, buffer.hints.preferred_alignment);
977    put_u8(
978        &mut key,
979        match buffer.hints.cache_locality {
980            super::CacheLocality::Streaming => 0,
981            super::CacheLocality::Temporal => 1,
982            super::CacheLocality::Random => 2,
983        },
984    );
985    put_u8(&mut key, u8::from(buffer.bytes_extraction));
986    key
987}
988