Skip to main content

vyre_foundation/ir_inner/model/program/
stats.rs

1use super::Program;
2use crate::ir::{DataType, Expr, Node};
3
// Capability bits accumulated into `ProgramStats::capability_bits`.
// Each constant is a distinct single bit; the walker ORs them together.

/// Subgroup builtins (`SubgroupAdd`/`Ballot`/`Shuffle`/`LocalId`/`Size`) or a
/// call whose `op_id` looks like a subgroup/wave/warp intrinsic.
const CAP_SUBGROUP_OPS: u32 = 0x001;
/// An `F16` buffer element type or cast target.
const CAP_F16: u32 = 0x002;
/// A `BF16` buffer element type or cast target.
const CAP_BF16: u32 = 0x004;
/// An `F64` buffer element type or cast target.
const CAP_F64: u32 = 0x008;
/// Any `AsyncLoad`, `AsyncStore`, or `AsyncWait` node.
const CAP_ASYNC_DISPATCH: u32 = 0x010;
/// Any `IndirectDispatch` node.
const CAP_INDIRECT_DISPATCH: u32 = 0x020;
/// A `Tensor` / `TensorShaped` buffer element type or cast target.
const CAP_TENSOR_OPS: u32 = 0x040;
/// Any `Trap` node.
const CAP_TRAP: u32 = 0x080;
/// Any collective node (`AllReduce`, `AllGather`, `ReduceScatter`, `Broadcast`).
const CAP_DISTRIBUTED_COLLECTIVES: u32 = 0x100;
13
// Bit positions for `ProgramStats::node_kinds_present`. Mirrors the
// variant declaration order in `ir_inner::model::generated::Node` and
// matches `optimizer::program_soa::NodeKind` so the optimizer can use
// either source of truth interchangeably. The
// `node_kinds_present_matches_program_soa_node_kind` test enforces
// the alignment.
/// `Node::Let`.
pub const NODE_KIND_LET: u32 = 1 << 0;
/// `Node::Assign`.
pub const NODE_KIND_ASSIGN: u32 = 1 << 1;
/// `Node::Store`.
pub const NODE_KIND_STORE: u32 = 1 << 2;
/// `Node::If`.
pub const NODE_KIND_IF: u32 = 1 << 3;
/// `Node::Loop`.
pub const NODE_KIND_LOOP: u32 = 1 << 4;
/// `Node::IndirectDispatch`.
pub const NODE_KIND_INDIRECT_DISPATCH: u32 = 1 << 5;
/// `Node::AsyncLoad`.
pub const NODE_KIND_ASYNC_LOAD: u32 = 1 << 6;
/// `Node::AsyncStore`.
pub const NODE_KIND_ASYNC_STORE: u32 = 1 << 7;
/// `Node::AsyncWait`.
pub const NODE_KIND_ASYNC_WAIT: u32 = 1 << 8;
/// `Node::Trap`.
pub const NODE_KIND_TRAP: u32 = 1 << 9;
/// `Node::Resume`.
pub const NODE_KIND_RESUME: u32 = 1 << 10;
/// `Node::Return`.
pub const NODE_KIND_RETURN: u32 = 1 << 11;
/// `Node::Barrier`.
pub const NODE_KIND_BARRIER: u32 = 1 << 12;
/// `Node::Block`.
pub const NODE_KIND_BLOCK: u32 = 1 << 13;
/// `Node::Region`.
pub const NODE_KIND_REGION: u32 = 1 << 14;
/// `Node::AllReduce`.
pub const NODE_KIND_ALL_REDUCE: u32 = 1 << 15;
/// `Node::AllGather`.
pub const NODE_KIND_ALL_GATHER: u32 = 1 << 16;
/// `Node::ReduceScatter`.
pub const NODE_KIND_REDUCE_SCATTER: u32 = 1 << 17;
/// `Node::Broadcast`.
pub const NODE_KIND_BROADCAST: u32 = 1 << 18;
/// `Node::Opaque`.
pub const NODE_KIND_OPAQUE: u32 = 1 << 19;

/// Mask covering every node kind that owns an `Expr` tree, i.e. every
/// kind a generic expression-rewriting pass (`canonicalize`, `const_fold`,
/// `strength_reduce`, ...) could possibly affect. A program whose
/// `node_kinds_present` and this mask AND to zero is structurally
/// expression-free and any such pass can SKIP without walking.
///
/// These are exactly the kinds whose expression operands the stats walker
/// descends into. Container kinds (`Block`, `Region`) are deliberately
/// excluded: they hold no expressions themselves, and every statement nested
/// inside them sets its own bit in `node_kinds_present`.
pub const NODE_KIND_EXPRESSION_BEARING_MASK: u32 = NODE_KIND_LET
    | NODE_KIND_ASSIGN
    | NODE_KIND_STORE
    | NODE_KIND_IF
    | NODE_KIND_LOOP
    | NODE_KIND_ASYNC_LOAD
    | NODE_KIND_ASYNC_STORE
    | NODE_KIND_TRAP;
74
/// Aggregated statistics computed from a single walk of a [`Program`].
///
/// This struct is cached inside [`Program`] via a [`std::sync::OnceLock`]
/// so that planning passes (execution plan, capability scan, provenance,
/// fusion) can read constant-time summaries instead of re-walking the IR.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ProgramStats {
    /// Total statement-node count (includes nested children). Expressions
    /// are not counted here.
    pub node_count: usize,
    /// Number of `Node::Region` nodes in the full tree, at any nesting depth.
    pub region_count: u32,
    /// Number of `Expr::Call` expressions.
    pub call_count: u32,
    /// Number of `Node::Opaque` nodes and `Expr::Opaque` expressions.
    pub opaque_count: u32,
    /// Number of top-level `Node::Region` wrappers in `program.entry()`.
    pub top_level_regions: u32,
    /// Sum of statically-known buffer byte sizes. Buffers with a zero element
    /// count, or whose element type has no known size, contribute nothing.
    pub static_storage_bytes: u64,
    /// Estimated scalar/vector IR instruction count. Every memory, atomic,
    /// and control-flow operation below is also counted here.
    pub instruction_count: u64,
    /// Number of explicit memory operations: loads, stores, async copies,
    /// atomics, and collective operations.
    pub memory_op_count: u64,
    /// Number of `Expr::Atomic` read-modify-write expressions.
    pub atomic_op_count: u64,
    /// Number of control-flow operations (branches, loops, returns, barriers,
    /// traps, resumes, async waits, and indirect dispatches).
    pub control_flow_count: u64,
    /// Coarse register pressure estimate: the peak number of `Let`-bound
    /// names simultaneously in scope during the walk.
    pub register_pressure_estimate: u32,
    /// Bitmask of capability requirements (see `CAP_*` constants).
    pub capability_bits: u32,
    /// Bitset of every `Node` variant observed during the stats walk
    /// (see the `NODE_KIND_*` constants in this module). Lets pass
    /// `analyze_impl` predicates do an O(1) bit test against the
    /// shared, OnceLock-cached `ProgramStats` instead of recursing
    /// the entry tree just to check 'does this program contain at
    /// least one Loop / If / Trap / etc.'.
    ///
    /// Only statement (`Node`) kinds are tracked. Expression-level features
    /// such as atomics have no bit here; use `atomic_op_count` instead.
    pub node_kinds_present: u32,
}
114
115mod methods;
116impl Program {
117    /// Return cached statistics for this program, computing them on first call.
118    #[must_use]
119    #[inline]
120    pub fn stats(&self) -> &ProgramStats {
121        self.stats
122            .get_or_init(|| std::sync::Arc::new(compute_stats(self)))
123            .as_ref()
124    }
125}
126
127/// Single-pass preorder walk that accumulates every field of [`ProgramStats`].
128pub(crate) fn compute_stats(program: &Program) -> ProgramStats {
129    let mut node_count = 0usize;
130    let mut region_count = 0u32;
131    let mut call_count = 0u32;
132    let mut opaque_count = 0u32;
133    let mut capability_bits = 0u32;
134    let mut node_kinds_present = 0u32;
135    let mut static_storage_bytes = 0u64;
136    let mut ir = IrCounters::default();
137
138    for decl in program.buffers.iter() {
139        let count = decl.count();
140        if count != 0 {
141            if let Some(elem) = decl.element().size_bytes() {
142                static_storage_bytes =
143                    static_storage_bytes.saturating_add(u64::from(count) * elem as u64);
144            }
145        }
146        mark_datatype_bits(&decl.element(), &mut capability_bits);
147    }
148
149    for node in program.entry.iter() {
150        walk_node(
151            node,
152            &mut node_count,
153            &mut region_count,
154            &mut call_count,
155            &mut opaque_count,
156            &mut capability_bits,
157            &mut node_kinds_present,
158            &mut ir,
159        );
160    }
161
162    let top_level_regions = program
163        .entry()
164        .iter()
165        .filter(|n| matches!(n, Node::Region { .. }))
166        .count()
167        .try_into()
168        .unwrap_or(u32::MAX);
169
170    ProgramStats {
171        node_count,
172        region_count,
173        call_count,
174        opaque_count,
175        top_level_regions,
176        static_storage_bytes,
177        instruction_count: ir.instruction_count,
178        memory_op_count: ir.memory_op_count,
179        atomic_op_count: ir.atomic_op_count,
180        control_flow_count: ir.control_flow_count,
181        register_pressure_estimate: ir.register_pressure_estimate(),
182        capability_bits,
183        node_kinds_present,
184    }
185}
186
/// Running totals for the cost-model fields of `ProgramStats`, plus a
/// scope-aware count of live names used for the register-pressure estimate.
///
/// All counters saturate instead of overflowing.
#[derive(Default)]
struct IrCounters {
    instruction_count: u64,
    memory_op_count: u64,
    atomic_op_count: u64,
    control_flow_count: u64,
    /// Names bound in the scope currently being walked.
    live_names: u32,
    /// High-water mark reached by `live_names`.
    max_live_names: u32,
}

impl IrCounters {
    /// Saturating `+= 1` shared by every operation counter.
    fn bump(counter: &mut u64) {
        *counter = counter.saturating_add(1);
    }

    /// Record one IR instruction.
    fn instruction(&mut self) {
        Self::bump(&mut self.instruction_count);
    }

    /// Record a memory operation; it also counts as an instruction.
    fn memory(&mut self) {
        self.instruction();
        Self::bump(&mut self.memory_op_count);
    }

    /// Record an atomic operation; it also counts as a memory op and an
    /// instruction.
    fn atomic(&mut self) {
        self.memory();
        Self::bump(&mut self.atomic_op_count);
    }

    /// Record a control-flow operation; it also counts as an instruction.
    fn control_flow(&mut self) {
        self.instruction();
        Self::bump(&mut self.control_flow_count);
    }

    /// Note a newly bound name and raise the high-water mark if needed.
    fn bind_name(&mut self) {
        let live = self.live_names.saturating_add(1);
        self.live_names = live;
        if live > self.max_live_names {
            self.max_live_names = live;
        }
    }

    /// Snapshot the live-name count on scope entry; hand the result back to
    /// `leave_scope` when the scope ends.
    fn enter_scope(&mut self) -> u32 {
        self.live_names
    }

    /// Forget every name bound since the matching `enter_scope` snapshot.
    fn leave_scope(&mut self, saved: u32) {
        self.live_names = saved;
    }

    /// Peak number of simultaneously live names observed so far.
    fn register_pressure_estimate(&self) -> u32 {
        self.max_live_names
    }
}
234
235#[inline]
236fn mark_datatype_bits(ty: &DataType, bits: &mut u32) {
237    match ty {
238        DataType::F16 => *bits |= CAP_F16,
239        DataType::BF16 => *bits |= CAP_BF16,
240        DataType::F64 => *bits |= CAP_F64,
241        DataType::Tensor | DataType::TensorShaped { .. } => *bits |= CAP_TENSOR_OPS,
242        _ => {}
243    }
244}
245
246#[allow(clippy::too_many_arguments)]
247#[expect(
248    clippy::too_many_lines,
249    reason = "single-pass ProgramStats walker keeps all counters hot and avoids repeated IR traversals"
250)]
251fn walk_node(
252    node: &Node,
253    nodes: &mut usize,
254    regions: &mut u32,
255    calls: &mut u32,
256    opaque: &mut u32,
257    bits: &mut u32,
258    kinds: &mut u32,
259    ir: &mut IrCounters,
260) {
261    *nodes = nodes.saturating_add(1);
262    match node {
263        Node::Let { value, .. } => {
264            *kinds |= NODE_KIND_LET;
265            ir.instruction();
266            ir.bind_name();
267            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
268        }
269        Node::Assign { value, .. } => {
270            *kinds |= NODE_KIND_ASSIGN;
271            ir.instruction();
272            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
273        }
274        Node::Store { index, value, .. } => {
275            *kinds |= NODE_KIND_STORE;
276            ir.memory();
277            walk_expr(index, nodes, regions, calls, opaque, bits, kinds, ir);
278            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
279        }
280        Node::If {
281            cond,
282            then,
283            otherwise,
284        } => {
285            *kinds |= NODE_KIND_IF;
286            ir.control_flow();
287            walk_expr(cond, nodes, regions, calls, opaque, bits, kinds, ir);
288            let saved = ir.enter_scope();
289            for child in then.iter().chain(otherwise.iter()) {
290                walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
291            }
292            ir.leave_scope(saved);
293        }
294        Node::Loop { from, to, body, .. } => {
295            *kinds |= NODE_KIND_LOOP;
296            ir.control_flow();
297            walk_expr(from, nodes, regions, calls, opaque, bits, kinds, ir);
298            walk_expr(to, nodes, regions, calls, opaque, bits, kinds, ir);
299            let saved = ir.enter_scope();
300            for child in body {
301                walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
302            }
303            ir.leave_scope(saved);
304        }
305        Node::Block(children) => {
306            *kinds |= NODE_KIND_BLOCK;
307            let saved = ir.enter_scope();
308            for child in children {
309                walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
310            }
311            ir.leave_scope(saved);
312        }
313        Node::Region { body, .. } => {
314            *kinds |= NODE_KIND_REGION;
315            *regions = regions.saturating_add(1);
316            let saved = ir.enter_scope();
317            for child in body.iter() {
318                walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
319            }
320            ir.leave_scope(saved);
321        }
322        Node::AsyncLoad { offset, size, .. } => {
323            *kinds |= NODE_KIND_ASYNC_LOAD;
324            *bits |= CAP_ASYNC_DISPATCH;
325            ir.memory();
326            walk_expr(offset, nodes, regions, calls, opaque, bits, kinds, ir);
327            walk_expr(size, nodes, regions, calls, opaque, bits, kinds, ir);
328        }
329        Node::AsyncStore { offset, size, .. } => {
330            *kinds |= NODE_KIND_ASYNC_STORE;
331            *bits |= CAP_ASYNC_DISPATCH;
332            ir.memory();
333            walk_expr(offset, nodes, regions, calls, opaque, bits, kinds, ir);
334            walk_expr(size, nodes, regions, calls, opaque, bits, kinds, ir);
335        }
336        Node::AsyncWait { .. } => {
337            *kinds |= NODE_KIND_ASYNC_WAIT;
338            *bits |= CAP_ASYNC_DISPATCH;
339            ir.control_flow();
340        }
341        Node::IndirectDispatch { .. } => {
342            *kinds |= NODE_KIND_INDIRECT_DISPATCH;
343            *bits |= CAP_INDIRECT_DISPATCH;
344            ir.control_flow();
345        }
346        Node::Trap { address, .. } => {
347            *kinds |= NODE_KIND_TRAP;
348            *bits |= CAP_TRAP;
349            ir.control_flow();
350            walk_expr(address, nodes, regions, calls, opaque, bits, kinds, ir);
351        }
352        Node::AllReduce { .. } => {
353            *kinds |= NODE_KIND_ALL_REDUCE;
354            *bits |= CAP_DISTRIBUTED_COLLECTIVES;
355            ir.memory();
356        }
357        Node::AllGather { .. } => {
358            *kinds |= NODE_KIND_ALL_GATHER;
359            *bits |= CAP_DISTRIBUTED_COLLECTIVES;
360            ir.memory();
361        }
362        Node::ReduceScatter { .. } => {
363            *kinds |= NODE_KIND_REDUCE_SCATTER;
364            *bits |= CAP_DISTRIBUTED_COLLECTIVES;
365            ir.memory();
366        }
367        Node::Broadcast { .. } => {
368            *kinds |= NODE_KIND_BROADCAST;
369            *bits |= CAP_DISTRIBUTED_COLLECTIVES;
370            ir.memory();
371        }
372        Node::Opaque(_) => {
373            *kinds |= NODE_KIND_OPAQUE;
374            *opaque = opaque.saturating_add(1);
375            ir.instruction();
376        }
377        Node::Return => {
378            *kinds |= NODE_KIND_RETURN;
379            ir.control_flow();
380        }
381        Node::Barrier { .. } => {
382            *kinds |= NODE_KIND_BARRIER;
383            ir.control_flow();
384        }
385        Node::Resume { .. } => {
386            *kinds |= NODE_KIND_RESUME;
387            ir.control_flow();
388        }
389    }
390}
391
392#[allow(clippy::only_used_in_recursion, clippy::too_many_arguments)]
393fn walk_expr(
394    expr: &Expr,
395    nodes: &mut usize,
396    regions: &mut u32,
397    calls: &mut u32,
398    opaque: &mut u32,
399    bits: &mut u32,
400    kinds: &mut u32,
401    ir: &mut IrCounters,
402) {
403    match expr {
404        Expr::SubgroupAdd { value } => {
405            *bits |= CAP_SUBGROUP_OPS;
406            ir.instruction();
407            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
408        }
409        Expr::SubgroupBallot { cond } => {
410            *bits |= CAP_SUBGROUP_OPS;
411            ir.instruction();
412            walk_expr(cond, nodes, regions, calls, opaque, bits, kinds, ir);
413        }
414        Expr::SubgroupShuffle { value, lane } => {
415            *bits |= CAP_SUBGROUP_OPS;
416            ir.instruction();
417            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
418            walk_expr(lane, nodes, regions, calls, opaque, bits, kinds, ir);
419        }
420        Expr::BinOp { left, right, .. } => {
421            ir.instruction();
422            walk_expr(left, nodes, regions, calls, opaque, bits, kinds, ir);
423            walk_expr(right, nodes, regions, calls, opaque, bits, kinds, ir);
424        }
425        Expr::UnOp { operand, .. } => {
426            ir.instruction();
427            walk_expr(operand, nodes, regions, calls, opaque, bits, kinds, ir);
428        }
429        Expr::Fma { a, b, c } => {
430            ir.instruction();
431            walk_expr(a, nodes, regions, calls, opaque, bits, kinds, ir);
432            walk_expr(b, nodes, regions, calls, opaque, bits, kinds, ir);
433            walk_expr(c, nodes, regions, calls, opaque, bits, kinds, ir);
434        }
435        Expr::Select {
436            cond,
437            true_val,
438            false_val,
439        } => {
440            ir.instruction();
441            walk_expr(cond, nodes, regions, calls, opaque, bits, kinds, ir);
442            walk_expr(true_val, nodes, regions, calls, opaque, bits, kinds, ir);
443            walk_expr(false_val, nodes, regions, calls, opaque, bits, kinds, ir);
444        }
445        Expr::Cast { target, value } => {
446            mark_datatype_bits(target, bits);
447            ir.instruction();
448            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
449        }
450        Expr::Load { index, .. } => {
451            ir.memory();
452            walk_expr(index, nodes, regions, calls, opaque, bits, kinds, ir);
453        }
454        Expr::Call { op_id, args } => {
455            if is_subgroup_intrinsic_id(op_id) {
456                *bits |= CAP_SUBGROUP_OPS;
457            }
458            *calls = calls.saturating_add(1);
459            ir.instruction();
460            for arg in args {
461                walk_expr(arg, nodes, regions, calls, opaque, bits, kinds, ir);
462            }
463        }
464        Expr::Atomic {
465            index,
466            expected,
467            value,
468            ..
469        } => {
470            ir.atomic();
471            walk_expr(index, nodes, regions, calls, opaque, bits, kinds, ir);
472            if let Some(expected) = expected.as_deref() {
473                walk_expr(expected, nodes, regions, calls, opaque, bits, kinds, ir);
474            }
475            walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
476        }
477        Expr::Opaque(_) => {
478            *opaque = opaque.saturating_add(1);
479            ir.instruction();
480        }
481        Expr::SubgroupLocalId | Expr::SubgroupSize => {
482            *bits |= CAP_SUBGROUP_OPS;
483            ir.instruction();
484        }
485        Expr::LitU32(_)
486        | Expr::LitI32(_)
487        | Expr::LitF32(_)
488        | Expr::LitBool(_)
489        | Expr::Var(_)
490        | Expr::BufLen { .. }
491        | Expr::InvocationId { .. }
492        | Expr::WorkgroupId { .. }
493        | Expr::LocalId { .. } => {}
494    }
495}
496
/// Heuristically decide whether a call's `op_id` names a subgroup (a.k.a.
/// wave / warp) intrinsic, by substring match against known naming patterns.
///
/// `"::subgroup"` also covers the `"::subgroup::"` path-segment form, so that
/// pattern needs no separate entry.
///
/// NOTE(review): `"warp_"` / `"wave_"` are unanchored substrings and will also
/// match unrelated names such as `warp_affine` — confirm this is acceptable.
fn is_subgroup_intrinsic_id(op_id: &str) -> bool {
    const NEEDLES: [&str; 6] = [
        "subgroup_",
        "::subgroup",
        "wave_",
        "::wave::",
        "warp_",
        "::warp::",
    ];
    NEEDLES.iter().copied().any(|needle| op_id.contains(needle))
}