Skip to main content

vyre_foundation/ir_inner/model/program/
meta.rs

1use std::hash::{Hash, Hasher as _};
2use std::sync::atomic::Ordering;
3use std::sync::Arc;
4
5use rustc_hash::FxHasher;
6use vyre_spec::bin_op::OpIntensity;
7
8use crate::ir::{Expr, Node};
9use crate::ir_inner::model::expr::Ident;
10use crate::ir_inner::model::types::BufferAccess;
11use crate::transform::visit::{walk_nodes_and_exprs, ExprVisitor, NodeVisitor};
12
13use super::Program;
14
/// Provenance for mutations that invalidate Program validation/cache state.
///
/// The value is stored as its `u8` discriminant (`as u8`) and decoded with
/// `ProgramMutationProvenance::from_code`, which maps any unrecognized byte
/// to [`ProgramMutationProvenance::Unknown`]. The explicit discriminants are
/// therefore part of the stored encoding and must not be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum ProgramMutationProvenance {
    /// Program has not been mutated since construction or successful validation.
    Clean = 0,
    /// The non-composable dispatch flag changed.
    NonComposableFlag = 1,
    /// The target workgroup dimensions changed.
    WorkgroupSize = 2,
    /// The substrate-neutral parallel-region dimensions changed.
    ParallelRegionSize = 3,
    /// A caller borrowed the mutable entry vector.
    EntryMutation = 4,
    /// Internal builder or decode path rewrote Program shape.
    InternalShapeMutation = 5,
    /// Mutation provenance is unknown, so validation must fail closed.
    /// Also the decode result for any byte outside `0..=5`.
    Unknown = 255,
}
34
35impl ProgramMutationProvenance {
36    #[inline]
37    const fn from_code(code: u8) -> Self {
38        match code {
39            0 => Self::Clean,
40            1 => Self::NonComposableFlag,
41            2 => Self::WorkgroupSize,
42            3 => Self::ParallelRegionSize,
43            4 => Self::EntryMutation,
44            5 => Self::InternalShapeMutation,
45            _ => Self::Unknown,
46        }
47    }
48}
49
50fn mix_wire_fallback_hashable<T: Hash>(hasher: &mut blake3::Hasher, value: &T) {
51    let mut state = FxHasher::default();
52    value.hash(&mut state);
53    hasher.update(&state.finish().to_le_bytes());
54}
55
/// Bounded IR structure digest for wire-hash fallback (never formats full IR via `Debug`).
///
/// Wraps a borrowed BLAKE3 hasher. The `NodeVisitor` / `ExprVisitor` impls for
/// this type feed it a short kind tag plus the names and scalar fields of each
/// visited IR item, so the cost per item stays small.
struct FallbackWireHasher<'a>(&'a mut blake3::Hasher);
58
59impl NodeVisitor for FallbackWireHasher<'_> {
60    fn visit_node(&mut self, node: &Node) {
61        let h = &mut *self.0;
62        match node {
63            Node::Let { name, .. } => {
64                h.update(b"n:Let\0");
65                h.update(name.as_bytes());
66            }
67            Node::Assign { name, .. } => {
68                h.update(b"n:Assign\0");
69                h.update(name.as_bytes());
70            }
71            Node::Store { buffer, .. } => {
72                h.update(b"n:Store\0");
73                h.update(buffer.as_bytes());
74            }
75            Node::If { .. } => {
76                h.update(b"n:If\0");
77            }
78            Node::Loop { var, .. } => {
79                h.update(b"n:Loop\0");
80                h.update(var.as_bytes());
81            }
82            Node::IndirectDispatch {
83                count_buffer,
84                count_offset,
85            } => {
86                h.update(b"n:IndirectDispatch\0");
87                h.update(count_buffer.as_bytes());
88                h.update(&count_offset.to_le_bytes());
89            }
90            Node::AsyncLoad {
91                source,
92                destination,
93                tag,
94                ..
95            } => {
96                h.update(b"n:AsyncLoad\0");
97                h.update(source.as_bytes());
98                h.update(destination.as_bytes());
99                h.update(tag.as_bytes());
100            }
101            Node::AsyncStore {
102                source,
103                destination,
104                tag,
105                ..
106            } => {
107                h.update(b"n:AsyncStore\0");
108                h.update(source.as_bytes());
109                h.update(destination.as_bytes());
110                h.update(tag.as_bytes());
111            }
112            Node::AsyncWait { tag } => {
113                h.update(b"n:AsyncWait\0");
114                h.update(tag.as_bytes());
115            }
116            Node::Trap { tag, .. } => {
117                h.update(b"n:Trap\0");
118                h.update(tag.as_bytes());
119            }
120            Node::Resume { tag } => {
121                h.update(b"n:Resume\0");
122                h.update(tag.as_bytes());
123            }
124            Node::AllReduce { buffer, op, group } => {
125                h.update(b"n:AllReduce\0");
126                h.update(buffer.as_bytes());
127                h.update(&op.builtin_wire_tag().to_le_bytes());
128                h.update(&group.as_u32().to_le_bytes());
129            }
130            Node::AllGather {
131                input,
132                output,
133                group,
134            } => {
135                h.update(b"n:AllGather\0");
136                h.update(input.as_bytes());
137                h.update(output.as_bytes());
138                h.update(&group.as_u32().to_le_bytes());
139            }
140            Node::ReduceScatter {
141                input,
142                output,
143                op,
144                group,
145            } => {
146                h.update(b"n:ReduceScatter\0");
147                h.update(input.as_bytes());
148                h.update(output.as_bytes());
149                h.update(&op.builtin_wire_tag().to_le_bytes());
150                h.update(&group.as_u32().to_le_bytes());
151            }
152            Node::Broadcast {
153                buffer,
154                root,
155                group,
156            } => {
157                h.update(b"n:Broadcast\0");
158                h.update(buffer.as_bytes());
159                h.update(&root.to_le_bytes());
160                h.update(&group.as_u32().to_le_bytes());
161            }
162            Node::Return => {
163                h.update(b"n:Return\0");
164            }
165            Node::Barrier { ordering } => {
166                h.update(b"n:Barrier\0");
167                mix_wire_fallback_hashable(h, ordering);
168            }
169            Node::Block(_) => {
170                h.update(b"n:Block\0");
171            }
172            Node::Region {
173                generator,
174                source_region,
175                ..
176            } => {
177                h.update(b"n:Region\0");
178                h.update(generator.as_bytes());
179                if let Some(source_gen) = source_region {
180                    h.update(source_gen.name.as_bytes());
181                }
182            }
183            Node::Opaque(ext) => {
184                h.update(b"n:Opaque\0");
185                h.update(ext.extension_kind().as_bytes());
186            }
187        }
188    }
189}
190
191impl ExprVisitor for FallbackWireHasher<'_> {
192    fn visit_expr(&mut self, expr: &Expr) {
193        let h = &mut *self.0;
194        match expr {
195            Expr::LitU32(v) => {
196                h.update(b"e:LitU32\0");
197                h.update(&v.to_le_bytes());
198            }
199            Expr::LitI32(v) => {
200                h.update(b"e:LitI32\0");
201                h.update(&v.to_le_bytes());
202            }
203            Expr::LitF32(v) => {
204                h.update(b"e:LitF32\0");
205                h.update(&v.to_le_bytes());
206            }
207            Expr::LitBool(v) => {
208                h.update(b"e:LitBool\0");
209                h.update(&[u8::from(*v)]);
210            }
211            Expr::Var(name) => {
212                h.update(b"e:Var\0");
213                h.update(name.as_bytes());
214            }
215            Expr::Load { buffer, .. } => {
216                h.update(b"e:Load\0");
217                h.update(buffer.as_bytes());
218            }
219            Expr::BufLen { buffer } => {
220                h.update(b"e:BufLen\0");
221                h.update(buffer.as_bytes());
222            }
223            Expr::InvocationId { axis } => {
224                h.update(b"e:InvocationId\0");
225                h.update(&[*axis]);
226            }
227            Expr::WorkgroupId { axis } => {
228                h.update(b"e:WorkgroupId\0");
229                h.update(&[*axis]);
230            }
231            Expr::LocalId { axis } => {
232                h.update(b"e:LocalId\0");
233                h.update(&[*axis]);
234            }
235            Expr::BinOp { op, .. } => {
236                h.update(b"e:BinOp\0");
237                mix_wire_fallback_hashable(h, op);
238            }
239            Expr::UnOp { op, .. } => {
240                h.update(b"e:UnOp\0");
241                mix_wire_fallback_hashable(h, op);
242            }
243            Expr::Call { op_id, .. } => {
244                h.update(b"e:Call\0");
245                h.update(op_id.as_bytes());
246            }
247            Expr::Select { .. } => {
248                h.update(b"e:Select\0");
249            }
250            Expr::Cast { target, .. } => {
251                h.update(b"e:Cast\0");
252                mix_wire_fallback_hashable(h, target);
253            }
254            Expr::Fma { .. } => {
255                h.update(b"e:Fma\0");
256            }
257            Expr::Atomic {
258                op,
259                buffer,
260                ordering,
261                ..
262            } => {
263                h.update(b"e:Atomic\0");
264                mix_wire_fallback_hashable(h, op);
265                h.update(buffer.as_bytes());
266                mix_wire_fallback_hashable(h, ordering);
267            }
268            Expr::SubgroupBallot { .. } => {
269                h.update(b"e:SubgroupBallot\0");
270            }
271            Expr::SubgroupShuffle { .. } => {
272                h.update(b"e:SubgroupShuffle\0");
273            }
274            Expr::SubgroupAdd { .. } => {
275                h.update(b"e:SubgroupAdd\0");
276            }
277            Expr::SubgroupLocalId => {
278                h.update(b"e:SubgroupLocalId\0");
279            }
280            Expr::SubgroupSize => {
281                h.update(b"e:SubgroupSize\0");
282            }
283            Expr::Opaque(ext) => {
284                h.update(b"e:Opaque\0");
285                h.update(ext.extension_kind().as_bytes());
286            }
287        }
288    }
289}
290
291impl Program {
292    /// Re-apply the same top-level `Node::Region` contract as
293    /// [`Program::wrapped`].
294    ///
295    /// The [`region_inline_engine`](crate::optimizer::passes::cleanup::region_inline_engine)
296    /// pass flattens small Category-A regions so CSE/DCE can see a single
297    /// function-shaped body, which can leave a statement-shaped entry list. The
298    /// standard optimizer run ends with this helper so the program remains in
299    /// a runnable, validator/reference-interpreter–compatible form while
300    /// still benefiting from the inline pass.
301    #[must_use]
302    pub fn reconcile_runnable_top_level(self) -> Self {
303        if self.is_top_level_region_wrapped() {
304            return self;
305        }
306        // Move the entry Vec out via map_entry's Arc-aware path; one
307        // Program rebuild instead of two scaffold rebuilds.
308        self.map_entry(Self::wrap_entry)
309    }
310
311    /// Look up a buffer declaration by name.
312    #[must_use]
313    #[inline]
314    pub fn buffer(&self, name: &str) -> Option<&super::BufferDecl> {
315        self.buffer_index
316            .get(name)
317            .and_then(|&index| self.buffers.get(index))
318    }
319
320    /// Declared buffers.
321    #[must_use]
322    #[inline]
323    pub fn buffers(&self) -> &[super::BufferDecl] {
324        self.buffers.as_ref()
325    }
326
327    /// Access the buffer declaration Arc directly for identity checks.
328    #[must_use]
329    #[inline]
330    #[cfg(test)]
331    pub(crate) fn buffers_arc(&self) -> &Arc<[super::BufferDecl]> {
332        &self.buffers
333    }
334
335    /// Compare two programs by observable IR structure.
336    ///
337    /// This walk intentionally ignores buffer declaration order and never
338    /// consults arena-local allocation identity. Two programs are structurally
339    /// equal when they declare the same buffers, workgroup size, optional entry
340    /// op id, and entry body semantics.
341    #[must_use]
342    #[inline]
343    pub fn structural_eq(&self, other: &Self) -> bool {
344        // Identity short-circuit: Program::clone shares all the
345        // inner Arcs, so comparing a cloned program against its
346        // source (the common optimizer-pipeline pattern) is pure
347        // refcount comparison.
348        if std::ptr::eq(self, other)
349            || (Arc::ptr_eq(&self.buffers, &other.buffers)
350                && Arc::ptr_eq(&self.entry, &other.entry)
351                && self.entry_op_id == other.entry_op_id
352                && self.non_composable_with_self == other.non_composable_with_self
353                && self.workgroup_size == other.workgroup_size)
354        {
355            return true;
356        }
357        self.entry_op_id == other.entry_op_id
358            && self.non_composable_with_self == other.non_composable_with_self
359            && buffers_equal_ignoring_declaration_order(&self.buffers, &other.buffers)
360            && self.workgroup_size == other.workgroup_size
361            && self.entry == other.entry
362    }
363
364    /// Workgroup dimensions.
365    #[must_use]
366    #[inline]
367    pub fn workgroup_size(&self) -> [u32; 3] {
368        self.workgroup_size
369    }
370
371    /// Substrate-neutral alias for [`workgroup_size`](Self::workgroup_size).
372    ///
373    /// Naming: "parallel region" avoids picking a single target substrate's
374    /// word for one dispatch invocation grouping.
375    #[must_use]
376    #[inline]
377    pub fn parallel_region_size(&self) -> [u32; 3] {
378        self.workgroup_size
379    }
380
381    /// Return true when this program must not be fused with another copy
382    /// of itself in the same megakernel.
383    #[must_use]
384    #[inline]
385    pub fn is_non_composable_with_self(&self) -> bool {
386        self.non_composable_with_self
387    }
388
389    /// Mark this program as non-composable with itself.
390    #[must_use]
391    #[inline]
392    pub fn with_non_composable_with_self(mut self, flag: bool) -> Self {
393        self.non_composable_with_self = flag;
394        self.invalidate_caches_for(ProgramMutationProvenance::NonComposableFlag);
395        self
396    }
397
398    /// Set the workgroup dimensions in place. Used by harnesses that
399    /// need to clone-and-rewrite a program's workgroup size for fallback
400    /// dispatch  -  the alternative was to reconstruct the entire Program,
401    /// which is unnecessarily expensive when only one field changes.
402    #[inline]
403    pub fn set_workgroup_size(&mut self, workgroup_size: [u32; 3]) {
404        self.workgroup_size = workgroup_size;
405        self.invalidate_caches_for(ProgramMutationProvenance::WorkgroupSize);
406    }
407
408    /// Substrate-neutral alias for [`set_workgroup_size`](Self::set_workgroup_size).
409    #[inline]
410    pub fn set_parallel_region_size(&mut self, parallel_region_size: [u32; 3]) {
411        self.workgroup_size = parallel_region_size;
412        self.invalidate_caches_for(ProgramMutationProvenance::ParallelRegionSize);
413    }
414
415    /// Entry-point nodes.
416    #[must_use]
417    #[inline]
418    pub fn entry(&self) -> &[Node] {
419        self.entry.as_ref().as_slice()
420    }
421
422    /// Shared entry-point body Arc for identity checks.
423    #[must_use]
424    #[inline]
425    pub fn entry_arc(&self) -> &Arc<Vec<Node>> {
426        &self.entry
427    }
428
429    /// Return true when this Program is the canonical no-op shape produced by
430    /// [`Program::empty`]: no buffers and a single empty root Region.
431    #[must_use]
432    #[inline]
433    pub fn is_explicit_noop(&self) -> bool {
434        self.buffers().is_empty()
435            && matches!(self.entry(), [Node::Region { body, .. }] if body.is_empty())
436    }
437
438    /// Return true when the program satisfies the top-level region-chain
439    /// invariant: at least one top-level node, and every top-level node is a
440    /// `Node::Region`.
441    #[must_use]
442    #[inline]
443    pub fn is_top_level_region_wrapped(&self) -> bool {
444        !self.entry.is_empty()
445            && self
446                .entry()
447                .iter()
448                .all(|node| matches!(node, Node::Region { .. }))
449    }
450
451    /// Actionable error text describing why the top-level region invariant
452    /// failed, or `None` when the entry is valid.
453    #[must_use]
454    pub fn top_level_region_violation(&self) -> Option<String> {
455        if self.entry().is_empty() {
456            return Some(
457                "program entry has no top-level Region. Fix: construct runnable programs with Program::wrapped(...) or wrap the body in Node::Region before validation, interpretation, or dispatch."
458                    .to_string(),
459            );
460        }
461
462        self.entry()
463            .iter()
464            .enumerate()
465            .find(|(_, node)| !matches!(node, Node::Region { .. }))
466            .map(|(index, node)| {
467                format!(
468                    "program entry node {index} is `{}` instead of `Node::Region`. Fix: construct runnable programs with Program::wrapped(...) or wrap the top-level body in Node::Region; raw Program::new is reserved for wire decode and negative tests.",
469                    Self::top_level_node_name(node)
470                )
471            })
472    }
473
474    /// Mutable entry-point nodes for transformation passes.
475    #[must_use]
476    #[inline]
477    pub fn entry_mut(&mut self) -> &mut Vec<Node> {
478        self.invalidate_caches_for(ProgramMutationProvenance::EntryMutation);
479        Arc::make_mut(&mut self.entry)
480    }
481
482    /// Stable BLAKE3 fingerprint of the canonical wire-format bytes.
483    #[must_use]
484    #[inline]
485    pub fn fingerprint(&self) -> [u8; 32] {
486        *self.fingerprint.get_or_init(|| {
487            let hash = self.compute_wire_hash();
488            let _ = self.hash.set(hash);
489            *hash.as_bytes()
490        })
491    }
492
493    /// VSA-style hypervector fingerprint of the canonical wire-format
494    /// bytes. Each `u32` lane is one segment of the program's blake3
495    /// hash; together they form an 8-lane hypervector suitable for
496    /// approximate similarity search via hamming distance.
497    ///
498    /// Use as the canonical cache key for approximate-match caches
499    /// (e.g. validation cache, AOT artifact dedup); use
500    /// [`Self::fingerprint`] for exact-match lookups.
501    ///
502    /// Wires the substrate's #29 hypervector primitive into Program
503    /// itself  -  every Program now carries its own VSA fingerprint
504    /// without callers having to reach into the substrate explicitly.
505    #[must_use]
506    pub fn vsa_fingerprint(&self) -> Vec<u32> {
507        self.fingerprint()
508            .chunks_exact(core::mem::size_of::<u32>())
509            .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
510            .collect()
511    }
512
513    /// Indices of writable storage outputs in `buffers()` order.
514    #[must_use]
515    #[inline]
516    pub fn output_buffer_indices(&self) -> &[u32] {
517        self.output_buffer_index
518            .get_or_init(|| {
519                Arc::new(
520                    self.buffers()
521                        .iter()
522                        .enumerate()
523                        .filter_map(|(index, buffer)| {
524                            matches!(
525                                buffer.access(),
526                                BufferAccess::ReadWrite | BufferAccess::WriteOnly
527                            )
528                            .then(|| u32::try_from(index).ok())
529                            .flatten()
530                        })
531                        .collect(),
532                )
533            })
534            .as_slice()
535    }
536
537    /// True when the entry walk discovers any indirect dispatch node.
538    #[must_use]
539    #[inline]
540    pub fn has_indirect_dispatch(&self) -> bool {
541        *self.has_indirect_dispatch.get_or_init(|| {
542            // Fast-path: ProgramStats records every node kind seen during
543            // its single-pass walk. If the IndirectDispatch bit is unset,
544            // the tree definitely contains no IndirectDispatch nodes and
545            // the explicit traversal below would redundantly visit every
546            // node only to return false. Reading the bit is O(1).
547            if !self
548                .stats()
549                .has_any_node_kind(super::stats::NODE_KIND_INDIRECT_DISPATCH)
550            {
551                return false;
552            }
553            let mut stack: smallvec::SmallVec<[&Node; 32]> = self.entry().iter().rev().collect();
554            while let Some(node) = stack.pop() {
555                match node {
556                    Node::IndirectDispatch { .. } => return true,
557                    Node::If {
558                        then, otherwise, ..
559                    } => {
560                        stack.extend(otherwise.iter().rev());
561                        stack.extend(then.iter().rev());
562                    }
563                    Node::Loop { body, .. } | Node::Block(body) => {
564                        stack.extend(body.iter().rev());
565                    }
566                    Node::Region { body, .. } => {
567                        stack.extend(body.iter().rev());
568                    }
569                    Node::Let { .. }
570                    | Node::Assign { .. }
571                    | Node::Store { .. }
572                    | Node::AllReduce { .. }
573                    | Node::AllGather { .. }
574                    | Node::ReduceScatter { .. }
575                    | Node::Broadcast { .. }
576                    | Node::Return
577                    | Node::Barrier { .. }
578                    | Node::AsyncLoad { .. }
579                    | Node::AsyncStore { .. }
580                    | Node::AsyncWait { .. }
581                    | Node::Trap { .. }
582                    | Node::Resume { .. }
583                    | Node::Opaque(_) => {}
584                }
585            }
586            false
587        })
588    }
589
590    /// Check whether a named buffer exists.
591    #[must_use]
592    #[inline]
593    pub fn has_buffer(&self, name: &str) -> bool {
594        self.buffer_index.contains_key(name)
595    }
596
597    /// Number of declared buffers.
598    #[must_use]
599    #[inline]
600    pub fn buffer_count(&self) -> usize {
601        self.buffers.len()
602    }
603
604    #[inline]
605    pub(super) fn build_buffer_index(
606        buffers: &[super::BufferDecl],
607    ) -> rustc_hash::FxHashMap<Arc<str>, usize> {
608        let mut index = rustc_hash::FxHashMap::default();
609        index.reserve(buffers.len());
610        for (buffer_index, buffer) in buffers.iter().enumerate() {
611            index
612                .entry(Arc::clone(&buffer.name))
613                .or_insert(buffer_index);
614        }
615        index
616    }
617
618    /// Mark this program as successfully validated structurally.
619    #[inline]
620    pub fn mark_structurally_validated(&self) {
621        self.structural_validation_fingerprint.store(
622            self.current_validation_fingerprint_token(),
623            Ordering::Release,
624        );
625        self.mutation_provenance
626            .store(ProgramMutationProvenance::Clean as u8, Ordering::Release);
627        self.structural_validated.store(true, Ordering::Release);
628    }
629
630    /// Return true once structural validation has succeeded for this program shape.
631    #[must_use]
632    #[inline]
633    pub fn is_structurally_validated(&self) -> bool {
634        if !self.structural_validated.load(Ordering::Acquire) {
635            return false;
636        }
637        if self.validation_mutation_provenance() == ProgramMutationProvenance::Unknown {
638            self.structural_validated.store(false, Ordering::Release);
639            return false;
640        }
641        let recorded = self
642            .structural_validation_fingerprint
643            .load(Ordering::Acquire);
644        if recorded == 0 || recorded != self.current_validation_fingerprint_token() {
645            self.structural_validated.store(false, Ordering::Release);
646            return false;
647        }
648        true
649    }
650
651    /// Last mutation provenance recorded for validation/cache invalidation.
652    #[must_use]
653    #[inline]
654    pub fn validation_mutation_provenance(&self) -> ProgramMutationProvenance {
655        ProgramMutationProvenance::from_code(self.mutation_provenance.load(Ordering::Acquire))
656    }
657
658    /// Mark the Program as having been mutated by a boundary that cannot name a
659    /// concrete provenance. Validation fails closed until the Program is rebuilt
660    /// through a known constructor or known mutation API.
661    #[inline]
662    pub fn mark_unknown_mutation_provenance(&mut self) {
663        self.invalidate_caches_for(ProgramMutationProvenance::Unknown);
664    }
665
666    /// Mark this program as successfully validated for a specific backend.
667    #[inline]
668    pub fn mark_validated_on(&self, backend_id: &str) {
669        if self.validation_mutation_provenance() == ProgramMutationProvenance::Unknown {
670            return;
671        }
672        self.validation_set
673            .get_or_init(|| Arc::new(dashmap::DashSet::new()))
674            .insert(Arc::from(self.validation_cache_key(backend_id)));
675    }
676
677    /// Return true if this program has been validated for the given backend.
678    #[must_use]
679    #[inline]
680    pub fn is_validated_on(&self, backend_id: &str) -> bool {
681        self.validation_set
682            .get()
683            .is_some_and(|set| set.contains(self.validation_cache_key(backend_id).as_str()))
684    }
685
686    /// Deprecated: use `is_structurally_validated` or `is_validated_on`.
687    #[deprecated(note = "use is_structurally_validated or is_validated_on")]
688    #[must_use]
689    #[inline]
690    pub fn is_validated(&self) -> bool {
691        self.is_structurally_validated()
692    }
693
694    /// Deprecated: use `mark_structurally_validated` or `mark_validated_on`.
695    #[deprecated(note = "use mark_structurally_validated or mark_validated_on")]
696    #[inline]
697    pub fn mark_validated(&self) {
698        self.mark_structurally_validated();
699    }
700
701    /// Validate the program and cache the successful result on the program.
702    ///
703    /// # Errors
704    ///
705    /// Returns [`crate::Error::WireFormatValidation`] with every validation
706    /// message joined when the structural validator rejects the program.
707    pub fn validate(&self) -> crate::error::Result<()> {
708        if self.validation_mutation_provenance() == ProgramMutationProvenance::Unknown {
709            return Err(crate::error::Error::WireFormatValidation {
710                message: "program validation cache was invalidated by unknown mutation provenance. Fix: rebuild the Program through Program::wrapped/from_wire or use a named Program mutation API before validating.".into(),
711            });
712        }
713        if self.is_structurally_validated() {
714            return Ok(());
715        }
716        let errors = crate::validate::validate(self);
717        if errors.is_empty() {
718            self.mark_structurally_validated();
719            return Ok(());
720        }
721        let mut message = String::new();
722        for (index, error) in errors.into_iter().enumerate() {
723            if index > 0 {
724                message.push_str("; ");
725            }
726            message.push_str(error.message());
727        }
728        Err(crate::error::Error::WireFormatValidation { message })
729    }
730
731    #[inline]
732    /// Estimate the peak VRAM byte size of this Program.
733    ///
734    /// Innovation I.11: Static VRAM Pressure Analysis.
735    /// Returns the total bytes required by all storage and uniform buffers
736    /// declared in the Program. Optimizer passes use this to automatically
737    /// partition workloads if they would exceed a backend-specific safety
738    /// margin.
739    #[must_use]
740    pub fn estimate_peak_vram_bytes(&self) -> u64 {
741        self.buffers
742            .iter()
743            .map(|buffer| {
744                let Some(element_size) = buffer.element.size_bytes() else {
745                    return u64::MAX;
746                };
747                u64::from(buffer.count)
748                    .saturating_mul(u64::try_from(element_size).unwrap_or(u64::MAX))
749            })
750            .fold(0u64, u64::saturating_add)
751    }
752
753    /// Return the peak computational intensity found in any instruction.
754    #[must_use]
755    pub fn peak_intensity(&self) -> OpIntensity {
756        let mut peak = OpIntensity::Free;
757        for node in self.entry() {
758            peak = peak.max(Self::node_intensity(node));
759        }
760        peak
761    }
762
763    fn node_intensity(node: &crate::ir::Node) -> OpIntensity {
764        use crate::ir::Node;
765        match node {
766            Node::Let { value, .. } | Node::Assign { value, .. } => Self::expr_intensity(value),
767            Node::Store { index, value, .. } => {
768                Self::expr_intensity(index).max(Self::expr_intensity(value))
769            }
770            Node::If {
771                cond,
772                then,
773                otherwise,
774            } => {
775                let mut p = Self::expr_intensity(cond);
776                for n in then {
777                    p = p.max(Self::node_intensity(n));
778                }
779                for n in otherwise {
780                    p = p.max(Self::node_intensity(n));
781                }
782                p
783            }
784            Node::Loop { from, to, body, .. } => {
785                let mut p = Self::expr_intensity(from).max(Self::expr_intensity(to));
786                for n in body {
787                    p = p.max(Self::node_intensity(n));
788                }
789                p
790            }
791            Node::Block(nodes) => {
792                let mut p = OpIntensity::Free;
793                for n in nodes {
794                    p = p.max(Self::node_intensity(n));
795                }
796                p
797            }
798            Node::Region { body, .. } => {
799                let mut p = OpIntensity::Free;
800                for n in body.iter() {
801                    p = p.max(Self::node_intensity(n));
802                }
803                p
804            }
805            _ => OpIntensity::Free,
806        }
807    }
808
809    fn expr_intensity(expr: &crate::ir::Expr) -> OpIntensity {
810        use crate::ir::Expr;
811        match expr {
812            Expr::BinOp { op, left, right } => op
813                .intensity()
814                .max(Self::expr_intensity(left))
815                .max(Self::expr_intensity(right)),
816            Expr::UnOp { operand, .. } => Self::expr_intensity(operand),
817            Expr::Load { index, .. } => Self::expr_intensity(index),
818            Expr::Select {
819                cond,
820                true_val,
821                false_val,
822            } => Self::expr_intensity(cond)
823                .max(Self::expr_intensity(true_val))
824                .max(Self::expr_intensity(false_val)),
825            Expr::Cast { value, .. } => Self::expr_intensity(value),
826            Expr::Fma { a, b, c } => Self::expr_intensity(a)
827                .max(Self::expr_intensity(b))
828                .max(Self::expr_intensity(c)),
829            Expr::Atomic {
830                index,
831                value,
832                expected,
833                ..
834            } => {
835                let mut p = Self::expr_intensity(index).max(Self::expr_intensity(value));
836                if let Some(e) = expected {
837                    p = p.max(Self::expr_intensity(e));
838                }
839                p.max(OpIntensity::Heavy)
840            }
841            Expr::SubgroupBallot { cond } => Self::expr_intensity(cond).max(OpIntensity::Heavy),
842            Expr::SubgroupShuffle { value, lane } => Self::expr_intensity(value)
843                .max(Self::expr_intensity(lane))
844                .max(OpIntensity::Heavy),
845            Expr::SubgroupAdd { value } => Self::expr_intensity(value).max(OpIntensity::Heavy),
846            _ => OpIntensity::Free,
847        }
848    }
849
850    fn compute_wire_hash(&self) -> blake3::Hash {
851        match self.canonical_wire_hash() {
852            Ok(hash) => hash,
853            Err(error) => {
854                let structural = self.structural_fingerprint_fallback();
855                let err_msg = error.to_string();
856                let mut fallback = Vec::with_capacity(96 + err_msg.len() + structural.len());
857                fallback.extend_from_slice(b"VYRE-PROGRAM-CANONICAL-WIRE-HASH-ERROR\0");
858                fallback.extend_from_slice(err_msg.as_bytes());
859                fallback.push(0);
860                fallback.extend_from_slice(structural.as_bytes());
861                blake3::hash(&fallback)
862            }
863        }
864    }
865
866    fn structural_fingerprint_fallback(&self) -> String {
867        let mut hasher = blake3::Hasher::new();
868        hasher.update(b"VYRE-WIRE-FALLBACK-V4\0");
869        if let Some(id) = self.entry_op_id.as_deref() {
870            hasher.update(id.as_bytes());
871        }
872        hasher.update(b"\0");
873        for axis in &self.workgroup_size {
874            hasher.update(&axis.to_le_bytes());
875        }
876        hasher.update(&[u8::from(self.non_composable_with_self)]);
877        let mut keys: Vec<Vec<u8>> = self
878            .buffers()
879            .iter()
880            .map(buffer_decl_canonical_key)
881            .collect();
882        keys.sort_unstable();
883        for key in keys {
884            hasher.update(&key);
885        }
886        let mut visitor = FallbackWireHasher(&mut hasher);
887        walk_nodes_and_exprs(self, &mut visitor);
888        hasher.finalize().to_hex().to_string()
889    }
890
891    fn validation_cache_key(&self, backend_id: &str) -> String {
892        const HEX: &[u8; 16] = b"0123456789abcdef";
893        let fingerprint = self.current_validation_fingerprint();
894        let mut key = String::with_capacity(backend_id.len() + 1 + 64);
895        key.push_str(backend_id);
896        key.push(':');
897        for &byte in &fingerprint {
898            key.push(HEX[(byte >> 4) as usize] as char);
899            key.push(HEX[(byte & 0x0f) as usize] as char);
900        }
901        key
902    }
903
904    #[inline]
905    pub(super) fn invalidate_caches(&mut self) {
906        self.invalidate_caches_for(ProgramMutationProvenance::InternalShapeMutation);
907    }
908
    /// Marks the program as mutated and discards all validation and derived caches.
    ///
    /// `provenance` is stored as its `u8` discriminant in `mutation_provenance`.
    /// `ProgramMutationProvenance::from_code` maps that byte back, treating
    /// unrecognized codes as `Unknown`.
    #[inline]
    pub(super) fn invalidate_caches_for(&mut self, provenance: ProgramMutationProvenance) {
        // Structural validation is no longer current; 0 is the reset value for
        // the recorded validation fingerprint.
        self.structural_validated.store(false, Ordering::Release);
        self.structural_validation_fingerprint
            .store(0, Ordering::Release);
        // Record why the program was invalidated.
        self.mutation_provenance
            .store(provenance as u8, Ordering::Release);
        // Empty the validation set only if it has already been initialized.
        if let Some(set) = self.validation_set.get() {
            set.clear();
        }
        // Discard cached derived values. NOTE(review): presumably each of these
        // is recomputed lazily on next access — confirm against the getters.
        let _ = self.hash.take();
        let _ = self.fingerprint.take();
        drop(self.output_buffer_index.take());
        let _ = self.has_indirect_dispatch.take();
        drop(self.stats.take());
    }
925
926    fn current_validation_fingerprint(&self) -> [u8; 32] {
927        *self.compute_wire_hash().as_bytes()
928    }
929
930    fn current_validation_fingerprint_token(&self) -> u64 {
931        let bytes = self.current_validation_fingerprint();
932        let token = u64::from_le_bytes([
933            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
934        ]);
935        token.max(1)
936    }
937
938    #[inline]
939    pub(super) fn wrap_entry(entry: Vec<Node>) -> Vec<Node> {
940        if !Self::entry_needs_root_region(&entry) {
941            return entry;
942        }
943        vec![Node::Region {
944            generator: Ident::from(Self::ROOT_REGION_GENERATOR),
945            source_region: None,
946            body: Arc::new(entry),
947        }]
948    }
949
950    #[inline]
951    fn entry_needs_root_region(entry: &[Node]) -> bool {
952        entry.is_empty()
953            || entry
954                .iter()
955                .any(|node| !matches!(node, Node::Region { .. }))
956    }
957
958    #[inline]
959    fn top_level_node_name(node: &Node) -> &'static str {
960        match node {
961            Node::Let { .. } => "Let",
962            Node::Assign { .. } => "Assign",
963            Node::Store { .. } => "Store",
964            Node::If { .. } => "If",
965            Node::Loop { .. } => "Loop",
966            Node::Return => "Return",
967            Node::Block(_) => "Block",
968            Node::Barrier { .. } => "Barrier",
969            Node::Region { .. } => "Region",
970            Node::IndirectDispatch { .. } => "IndirectDispatch",
971            Node::AsyncLoad { .. } => "AsyncLoad",
972            Node::AsyncStore { .. } => "AsyncStore",
973            Node::AsyncWait { .. } => "AsyncWait",
974            Node::Trap { .. } => "Trap",
975            Node::Resume { .. } => "Resume",
976            Node::AllReduce { .. } => "AllReduce",
977            Node::AllGather { .. } => "AllGather",
978            Node::ReduceScatter { .. } => "ReduceScatter",
979            Node::Broadcast { .. } => "Broadcast",
980            Node::Opaque(_) => "Opaque",
981        }
982    }
983}
984
985pub(crate) fn buffers_equal_ignoring_declaration_order(
986    left: &[super::BufferDecl],
987    right: &[super::BufferDecl],
988) -> bool {
989    if left.len() != right.len() {
990        return false;
991    }
992
993    // VYRE_IR_HOTSPOTS HIGH (meta.rs:360-379): previous impl allocated
994    // two Vec<Vec<u8>> then sorted on every equality call. Fast-path:
995    // if the slices compare equal in-place (declaration orders match)
996    // we skip the key-materialization entirely. This catches every
997    // Program::clone(prog) == prog and every `Arc::clone`-equivalent
998    // comparison, which dominate the call distribution.
999    if left == right {
1000        return true;
1001    }
1002
1003    let mut left_keys = Vec::with_capacity(left.len());
1004    left_keys.extend(left.iter().map(buffer_decl_canonical_key));
1005    let mut right_keys = Vec::with_capacity(right.len());
1006    right_keys.extend(right.iter().map(buffer_decl_canonical_key));
1007    left_keys.sort_unstable();
1008    right_keys.sort_unstable();
1009    left_keys == right_keys
1010}
1011
1012pub(super) fn buffer_decl_canonical_key(buffer: &super::BufferDecl) -> Vec<u8> {
1013    use crate::serial::wire::framing::{put_len_u32, put_u32, put_u8};
1014    use crate::serial::wire::tags::put_data_type;
1015
1016    let mut key = Vec::with_capacity(96);
1017    if let Err(error) = put_len_u32(&mut key, buffer.name.len(), "buffer name length") {
1018        key.extend_from_slice(b"\0name-length-error\0");
1019        key.extend_from_slice(error.as_bytes());
1020    }
1021    key.extend_from_slice(buffer.name.as_bytes());
1022    put_u32(&mut key, buffer.binding);
1023    match crate::serial::wire::tags::access_tag::access_tag(&buffer.access) {
1024        Ok(tag) => put_u8(&mut key, tag),
1025        Err(error) => {
1026            put_u8(&mut key, u8::MAX);
1027            key.extend_from_slice(error.as_bytes());
1028        }
1029    }
1030    put_u8(
1031        &mut key,
1032        match buffer.kind {
1033            super::MemoryKind::Global => 0,
1034            super::MemoryKind::Shared => 1,
1035            super::MemoryKind::Uniform => 2,
1036            super::MemoryKind::Local => 3,
1037            super::MemoryKind::Readonly => 4,
1038            super::MemoryKind::Persistent => 5,
1039            super::MemoryKind::Push => 6,
1040        },
1041    );
1042    if let Err(error) = put_data_type(&mut key, &buffer.element) {
1043        key.extend_from_slice(b"\0dtype-error\0");
1044        key.extend_from_slice(error.as_bytes());
1045    }
1046    put_u32(&mut key, buffer.count);
1047    put_u8(&mut key, u8::from(buffer.is_output));
1048    put_u8(&mut key, u8::from(buffer.pipeline_live_out));
1049    match &buffer.output_byte_range {
1050        Some(range) => {
1051            put_u8(&mut key, 1);
1052            match u32::try_from(range.start) {
1053                Ok(start) => put_u32(&mut key, start),
1054                Err(error) => {
1055                    put_u32(&mut key, u32::MAX);
1056                    key.extend_from_slice(error.to_string().as_bytes());
1057                }
1058            }
1059            match u32::try_from(range.end) {
1060                Ok(end) => put_u32(&mut key, end),
1061                Err(error) => {
1062                    put_u32(&mut key, u32::MAX);
1063                    key.extend_from_slice(error.to_string().as_bytes());
1064                }
1065            }
1066        }
1067        None => put_u8(&mut key, 0),
1068    }
1069    match buffer.hints.coalesce_axis {
1070        Some(axis) => {
1071            put_u8(&mut key, 1);
1072            put_u8(&mut key, axis);
1073        }
1074        None => put_u8(&mut key, 0),
1075    }
1076    put_u32(&mut key, buffer.hints.preferred_alignment);
1077    put_u8(
1078        &mut key,
1079        match buffer.hints.cache_locality {
1080            super::CacheLocality::Streaming => 0,
1081            super::CacheLocality::Temporal => 1,
1082            super::CacheLocality::Random => 2,
1083        },
1084    );
1085    put_u8(&mut key, u8::from(buffer.bytes_extraction));
1086    key
1087}