//! Program statistics — vyre_foundation/ir_inner/model/program/stats.rs
1use super::Program;
2use crate::ir::{DataType, Expr, Node};
3
// Capability flags packed into `ProgramStats::capability_bits`.
// Each flag occupies its own bit so requirements discovered during the IR
// walk can be OR-ed together and queried later with a single mask test.
const CAP_SUBGROUP_OPS: u32 = 0b0000_0001; // subgroup / wave / warp ops
const CAP_F16: u32 = 0b0000_0010; // IEEE-754 binary16
const CAP_BF16: u32 = 0b0000_0100; // bfloat16
const CAP_F64: u32 = 0b0000_1000; // IEEE-754 binary64
const CAP_ASYNC_DISPATCH: u32 = 0b0001_0000; // async load/store/wait
const CAP_INDIRECT_DISPATCH: u32 = 0b0010_0000; // indirect dispatch
const CAP_TENSOR_OPS: u32 = 0b0100_0000; // tensor operand types
const CAP_TRAP: u32 = 0b1000_0000; // `Node::Trap`
12
/// Aggregated statistics computed from a single walk of a [`Program`].
///
/// This struct is cached inside [`Program`] via a [`std::sync::OnceLock`]
/// so that planning passes (execution plan, capability scan, provenance,
/// fusion) can read constant-time summaries instead of re-walking the IR.
///
/// All counters are accumulated with saturating arithmetic, so pathological
/// inputs clamp at the type's maximum instead of wrapping.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ProgramStats {
    /// Total statement-node count (includes nested children).
    pub node_count: usize,
    /// Number of `Node::Region` nodes in the full tree.
    pub region_count: u32,
    /// Number of `Expr::Call` expressions.
    pub call_count: u32,
    /// Number of `Node::Opaque` nodes and `Expr::Opaque` expressions.
    pub opaque_count: u32,
    /// Number of top-level `Node::Region` wrappers in `program.entry()`.
    pub top_level_regions: u32,
    /// Sum of statically-known buffer byte sizes.
    /// Buffers whose element size is unknown contribute nothing.
    pub static_storage_bytes: u64,
    /// Estimated scalar/vector IR instruction count.
    pub instruction_count: u64,
    /// Number of explicit memory operations (loads, stores, async copies).
    pub memory_op_count: u64,
    /// Number of atomic read-modify-write expressions.
    pub atomic_op_count: u64,
    /// Number of control-flow operations.
    pub control_flow_count: u64,
    /// Coarse register pressure estimate from simultaneously named SSA-ish values.
    pub register_pressure_estimate: u32,
    /// Bitmask of capability requirements (see `CAP_*` constants).
    /// Prefer the boolean accessors (`subgroup_ops()`, `f16()`, …) over
    /// inspecting this field directly.
    pub capability_bits: u32,
}
45
46impl ProgramStats {
47    /// True when the program uses subgroup operations.
48    #[inline]
49    #[must_use]
50    pub fn subgroup_ops(&self) -> bool {
51        self.capability_bits & CAP_SUBGROUP_OPS != 0
52    }
53
54    /// True when the program uses IEEE-754 binary16 values.
55    #[inline]
56    #[must_use]
57    pub fn f16(&self) -> bool {
58        self.capability_bits & CAP_F16 != 0
59    }
60
61    /// True when the program uses bfloat16 values.
62    #[inline]
63    #[must_use]
64    pub fn bf16(&self) -> bool {
65        self.capability_bits & CAP_BF16 != 0
66    }
67
68    /// True when the program uses IEEE-754 binary64 values.
69    #[inline]
70    #[must_use]
71    pub fn f64(&self) -> bool {
72        self.capability_bits & CAP_F64 != 0
73    }
74
75    /// True when the program requires async dispatch semantics.
76    #[inline]
77    #[must_use]
78    pub fn async_dispatch(&self) -> bool {
79        self.capability_bits & CAP_ASYNC_DISPATCH != 0
80    }
81
82    /// True when the program requires indirect dispatch support.
83    #[inline]
84    #[must_use]
85    pub fn indirect_dispatch(&self) -> bool {
86        self.capability_bits & CAP_INDIRECT_DISPATCH != 0
87    }
88
89    /// True when the program uses tensor / tensor-core operand types.
90    #[inline]
91    #[must_use]
92    pub fn tensor_ops(&self) -> bool {
93        self.capability_bits & CAP_TENSOR_OPS != 0
94    }
95
96    /// True when the program uses `Node::Trap`.
97    #[inline]
98    #[must_use]
99    pub fn trap(&self) -> bool {
100        self.capability_bits & CAP_TRAP != 0
101    }
102}
103
104impl Program {
105    /// Return cached statistics for this program, computing them on first call.
106    #[must_use]
107    #[inline]
108    pub fn stats(&self) -> &ProgramStats {
109        self.stats
110            .get_or_init(|| std::sync::Arc::new(compute_stats(self)))
111            .as_ref()
112    }
113}
114
115/// Single-pass preorder walk that accumulates every field of [`ProgramStats`].
116pub(crate) fn compute_stats(program: &Program) -> ProgramStats {
117    let mut node_count = 0usize;
118    let mut region_count = 0u32;
119    let mut call_count = 0u32;
120    let mut opaque_count = 0u32;
121    let mut capability_bits = 0u32;
122    let mut static_storage_bytes = 0u64;
123    let mut ir = IrCounters::default();
124
125    for decl in program.buffers.iter() {
126        let count = decl.count();
127        if count != 0 {
128            if let Some(elem) = decl.element().size_bytes() {
129                static_storage_bytes =
130                    static_storage_bytes.saturating_add(u64::from(count) * elem as u64);
131            }
132        }
133        mark_datatype_bits(&decl.element(), &mut capability_bits);
134    }
135
136    for node in program.entry.iter() {
137        walk_node(
138            node,
139            &mut node_count,
140            &mut region_count,
141            &mut call_count,
142            &mut opaque_count,
143            &mut capability_bits,
144            &mut ir,
145        );
146    }
147
148    let top_level_regions = program
149        .entry()
150        .iter()
151        .filter(|n| matches!(n, Node::Region { .. }))
152        .count() as u32;
153
154    ProgramStats {
155        node_count,
156        region_count,
157        call_count,
158        opaque_count,
159        top_level_regions,
160        static_storage_bytes,
161        instruction_count: ir.instruction_count,
162        memory_op_count: ir.memory_op_count,
163        atomic_op_count: ir.atomic_op_count,
164        control_flow_count: ir.control_flow_count,
165        register_pressure_estimate: ir.register_pressure_estimate(),
166        capability_bits,
167    }
168}
169
/// Scratch accumulators for instruction-level metrics gathered while walking
/// the IR. All counts saturate instead of wrapping.
#[derive(Default)]
struct IrCounters {
    // Estimated instruction count; every other op kind also bumps this.
    instruction_count: u64,
    // Explicit memory operations (each also counts as an instruction).
    memory_op_count: u64,
    // Atomic RMW ops (each also counts as a memory op and an instruction).
    atomic_op_count: u64,
    // Control-flow ops (each also counts as an instruction).
    control_flow_count: u64,
    // Names currently live in the scope being walked.
    live_names: u32,
    // High-water mark of `live_names`; reported as register pressure.
    max_live_names: u32,
}

impl IrCounters {
    /// Record one plain instruction.
    fn instruction(&mut self) {
        self.instruction_count = self.instruction_count.saturating_add(1);
    }

    /// Record a memory operation; memory ops are instructions too.
    fn memory(&mut self) {
        self.memory_op_count = self.memory_op_count.saturating_add(1);
        self.instruction();
    }

    /// Record an atomic RMW; atomics are memory ops (and instructions).
    fn atomic(&mut self) {
        self.atomic_op_count = self.atomic_op_count.saturating_add(1);
        self.memory();
    }

    /// Record a control-flow operation; also counts as an instruction.
    fn control_flow(&mut self) {
        self.control_flow_count = self.control_flow_count.saturating_add(1);
        self.instruction();
    }

    /// Record a newly bound name and update the live-name high-water mark.
    fn bind_name(&mut self) {
        self.live_names = self.live_names.saturating_add(1);
        self.max_live_names = self.max_live_names.max(self.live_names);
    }

    /// Snapshot the live-name count on scope entry; pass the returned mark
    /// back to [`Self::leave_scope`] when the scope closes.
    fn enter_scope(&mut self) -> u32 {
        self.live_names
    }

    /// Restore the live-name count to the mark taken at scope entry, killing
    /// every name bound inside the scope.
    fn leave_scope(&mut self, mark: u32) {
        self.live_names = mark;
    }

    /// Coarse register-pressure estimate: the most names ever simultaneously
    /// live during the walk.
    fn register_pressure_estimate(&self) -> u32 {
        self.max_live_names
    }
}
217
218#[inline]
219fn mark_datatype_bits(ty: &DataType, bits: &mut u32) {
220    match ty {
221        DataType::F16 => *bits |= CAP_F16,
222        DataType::BF16 => *bits |= CAP_BF16,
223        DataType::F64 => *bits |= CAP_F64,
224        DataType::Tensor | DataType::TensorShaped { .. } => *bits |= CAP_TENSOR_OPS,
225        _ => {}
226    }
227}
228
/// Recursively accumulate statistics for one statement [`Node`] and all of
/// its children (preorder).
///
/// Each `&mut` scalar mirrors a [`ProgramStats`] field; `ir` gathers the
/// instruction / memory / control-flow / register-pressure estimates.
fn walk_node(
    node: &Node,
    nodes: &mut usize,
    regions: &mut u32,
    calls: &mut u32,
    opaque: &mut u32,
    bits: &mut u32,
    ir: &mut IrCounters,
) {
    // Every statement node counts toward `node_count`, nested ones included.
    *nodes = nodes.saturating_add(1);
    match node {
        // `Let` binds a fresh name (affects register pressure); `Assign`
        // only rewrites an existing one.
        Node::Let { value, .. } | Node::Assign { value, .. } => {
            ir.instruction();
            if matches!(node, Node::Let { .. }) {
                ir.bind_name();
            }
            walk_expr(value, nodes, regions, calls, opaque, bits, ir);
        }
        Node::Store { index, value, .. } => {
            ir.memory();
            walk_expr(index, nodes, regions, calls, opaque, bits, ir);
            walk_expr(value, nodes, regions, calls, opaque, bits, ir);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            ir.control_flow();
            walk_expr(cond, nodes, regions, calls, opaque, bits, ir);
            // Names bound in either branch die when the `If` ends.
            let saved = ir.enter_scope();
            for child in then.iter().chain(otherwise.iter()) {
                walk_node(child, nodes, regions, calls, opaque, bits, ir);
            }
            ir.leave_scope(saved);
        }
        Node::Loop { from, to, body, .. } => {
            ir.control_flow();
            walk_expr(from, nodes, regions, calls, opaque, bits, ir);
            walk_expr(to, nodes, regions, calls, opaque, bits, ir);
            // Loop-body bindings are scoped to the body.
            let saved = ir.enter_scope();
            for child in body.iter() {
                walk_node(child, nodes, regions, calls, opaque, bits, ir);
            }
            ir.leave_scope(saved);
        }
        Node::Block(children) => {
            let saved = ir.enter_scope();
            for child in children.iter() {
                walk_node(child, nodes, regions, calls, opaque, bits, ir);
            }
            ir.leave_scope(saved);
        }
        Node::Region { body, .. } => {
            // Counts every region in the tree; `compute_stats` separately
            // tallies the top-level ones.
            *regions = regions.saturating_add(1);
            let saved = ir.enter_scope();
            for child in body.iter() {
                walk_node(child, nodes, regions, calls, opaque, bits, ir);
            }
            ir.leave_scope(saved);
        }
        // Async copies require async-dispatch support and count as memory ops.
        Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
            *bits |= CAP_ASYNC_DISPATCH;
            ir.memory();
            walk_expr(offset, nodes, regions, calls, opaque, bits, ir);
            walk_expr(size, nodes, regions, calls, opaque, bits, ir);
        }
        Node::AsyncWait { .. } => {
            *bits |= CAP_ASYNC_DISPATCH;
            ir.control_flow();
        }
        Node::IndirectDispatch { .. } => {
            *bits |= CAP_INDIRECT_DISPATCH;
            ir.control_flow();
        }
        Node::Trap { address, .. } => {
            *bits |= CAP_TRAP;
            ir.control_flow();
            walk_expr(address, nodes, regions, calls, opaque, bits, ir);
        }
        // Opaque payloads are uninspectable; estimate them as one instruction.
        Node::Opaque(_) => {
            *opaque = opaque.saturating_add(1);
            ir.instruction();
        }
        Node::Return | Node::Barrier { .. } | Node::Resume { .. } => {
            ir.control_flow();
        }
    }
}
318
/// Recursively accumulate statistics for one [`Expr`] and its operands.
///
/// Counter parameters mirror `walk_node` so the two can recurse into each
/// other with the same argument list.
// `only_used_in_recursion`: `nodes` and `regions` are only forwarded to
// recursive calls here, but the signature is kept identical to `walk_node`'s
// on purpose.
#[allow(clippy::only_used_in_recursion)]
fn walk_expr(
    expr: &Expr,
    nodes: &mut usize,
    regions: &mut u32,
    calls: &mut u32,
    opaque: &mut u32,
    bits: &mut u32,
    ir: &mut IrCounters,
) {
    match expr {
        // Subgroup intrinsics set the subgroup capability bit.
        Expr::SubgroupAdd { value } => {
            *bits |= CAP_SUBGROUP_OPS;
            ir.instruction();
            walk_expr(value, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::SubgroupBallot { cond } => {
            *bits |= CAP_SUBGROUP_OPS;
            ir.instruction();
            walk_expr(cond, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::SubgroupShuffle { value, lane } => {
            *bits |= CAP_SUBGROUP_OPS;
            ir.instruction();
            walk_expr(value, nodes, regions, calls, opaque, bits, ir);
            walk_expr(lane, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::BinOp { left, right, .. } => {
            ir.instruction();
            walk_expr(left, nodes, regions, calls, opaque, bits, ir);
            walk_expr(right, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::UnOp { operand, .. } => {
            ir.instruction();
            walk_expr(operand, nodes, regions, calls, opaque, bits, ir);
        }
        // Fused multiply-add counts as a single instruction.
        Expr::Fma { a, b, c } => {
            ir.instruction();
            walk_expr(a, nodes, regions, calls, opaque, bits, ir);
            walk_expr(b, nodes, regions, calls, opaque, bits, ir);
            walk_expr(c, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            ir.instruction();
            walk_expr(cond, nodes, regions, calls, opaque, bits, ir);
            walk_expr(true_val, nodes, regions, calls, opaque, bits, ir);
            walk_expr(false_val, nodes, regions, calls, opaque, bits, ir);
        }
        // Casting to f16/bf16/f64/tensor types implies the matching capability.
        Expr::Cast { target, value } => {
            mark_datatype_bits(target, bits);
            ir.instruction();
            walk_expr(value, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::Load { index, .. } => {
            ir.memory();
            walk_expr(index, nodes, regions, calls, opaque, bits, ir);
        }
        Expr::Call { op_id, args } => {
            // Calls into subgroup/wave/warp intrinsics also need the
            // subgroup capability even though they arrive as generic calls.
            if is_subgroup_intrinsic_id(op_id) {
                *bits |= CAP_SUBGROUP_OPS;
            }
            *calls = calls.saturating_add(1);
            ir.instruction();
            for arg in args.iter() {
                walk_expr(arg, nodes, regions, calls, opaque, bits, ir);
            }
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            ir.atomic();
            walk_expr(index, nodes, regions, calls, opaque, bits, ir);
            // `expected` is only present for compare-exchange-style atomics.
            if let Some(expected) = expected.as_deref() {
                walk_expr(expected, nodes, regions, calls, opaque, bits, ir);
            }
            walk_expr(value, nodes, regions, calls, opaque, bits, ir);
        }
        // Opaque payloads are uninspectable; estimate them as one instruction.
        Expr::Opaque(_) => {
            *opaque = opaque.saturating_add(1);
            ir.instruction();
        }
        Expr::SubgroupLocalId | Expr::SubgroupSize => {
            *bits |= CAP_SUBGROUP_OPS;
            ir.instruction();
        }
        // Literals, variable reads, and builtin id queries cost nothing here.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => {}
    }
}
422
/// Heuristically decide whether a call's `op_id` names a subgroup-level
/// intrinsic (called "wave" in HLSL/AMD and "warp" in CUDA terminology).
///
/// Matching is substring-based, so ids such as `subgroup_ballot`,
/// `gpu::subgroup::add`, `wave_reduce`, and `cuda::warp::shfl` all match.
fn is_subgroup_intrinsic_id(op_id: &str) -> bool {
    // NOTE: `"::subgroup"` already matches every id that contains
    // `"::subgroup::"`, so the previous separate `"::subgroup::"` marker
    // was dead and has been dropped; behavior is unchanged.
    const MARKERS: &[&str] = &[
        "subgroup_",
        "::subgroup",
        "wave_",
        "::wave::",
        "warp_",
        "::warp::",
    ];
    MARKERS.iter().any(|marker| op_id.contains(marker))
}