1use super::Program;
2use crate::ir::{DataType, Expr, Node};
3
/// Capability bit 0: subgroup (wave/warp) expressions or subgroup intrinsic calls appear.
const CAP_SUBGROUP_OPS: u32 = 0x001;
/// Capability bit 1: `DataType::F16` is used as a buffer element type or cast target.
const CAP_F16: u32 = 0x002;
/// Capability bit 2: `DataType::BF16` is used as a buffer element type or cast target.
const CAP_BF16: u32 = 0x004;
/// Capability bit 3: `DataType::F64` is used as a buffer element type or cast target.
const CAP_F64: u32 = 0x008;
/// Capability bit 4: async load / async store / async wait nodes appear.
const CAP_ASYNC_DISPATCH: u32 = 0x010;
/// Capability bit 5: an indirect-dispatch node appears.
const CAP_INDIRECT_DISPATCH: u32 = 0x020;
/// Capability bit 6: a tensor data type (`Tensor` or `TensorShaped`) is used.
const CAP_TENSOR_OPS: u32 = 0x040;
/// Capability bit 7: a trap node appears.
const CAP_TRAP: u32 = 0x080;
/// Capability bit 8: a distributed collective (all-reduce, all-gather, reduce-scatter,
/// broadcast) node appears.
const CAP_DISTRIBUTED_COLLECTIVES: u32 = 0x100;
13
/// Node-kind bit 0: a `Node::Let` occurs somewhere in the program.
pub const NODE_KIND_LET: u32 = 0x0000_0001;
/// Node-kind bit 1: a `Node::Assign` occurs.
pub const NODE_KIND_ASSIGN: u32 = 0x0000_0002;
/// Node-kind bit 2: a `Node::Store` occurs.
pub const NODE_KIND_STORE: u32 = 0x0000_0004;
/// Node-kind bit 3: a `Node::If` occurs.
pub const NODE_KIND_IF: u32 = 0x0000_0008;
/// Node-kind bit 4: a `Node::Loop` occurs.
pub const NODE_KIND_LOOP: u32 = 0x0000_0010;
/// Node-kind bit 5: a `Node::IndirectDispatch` occurs.
pub const NODE_KIND_INDIRECT_DISPATCH: u32 = 0x0000_0020;
/// Node-kind bit 6: a `Node::AsyncLoad` occurs.
pub const NODE_KIND_ASYNC_LOAD: u32 = 0x0000_0040;
/// Node-kind bit 7: a `Node::AsyncStore` occurs.
pub const NODE_KIND_ASYNC_STORE: u32 = 0x0000_0080;
/// Node-kind bit 8: a `Node::AsyncWait` occurs.
pub const NODE_KIND_ASYNC_WAIT: u32 = 0x0000_0100;
/// Node-kind bit 9: a `Node::Trap` occurs.
pub const NODE_KIND_TRAP: u32 = 0x0000_0200;
/// Node-kind bit 10: a `Node::Resume` occurs.
pub const NODE_KIND_RESUME: u32 = 0x0000_0400;
/// Node-kind bit 11: a `Node::Return` occurs.
pub const NODE_KIND_RETURN: u32 = 0x0000_0800;
/// Node-kind bit 12: a `Node::Barrier` occurs.
pub const NODE_KIND_BARRIER: u32 = 0x0000_1000;
/// Node-kind bit 13: a `Node::Block` occurs.
pub const NODE_KIND_BLOCK: u32 = 0x0000_2000;
/// Node-kind bit 14: a `Node::Region` occurs.
pub const NODE_KIND_REGION: u32 = 0x0000_4000;
/// Node-kind bit 15: a `Node::AllReduce` occurs.
pub const NODE_KIND_ALL_REDUCE: u32 = 0x0000_8000;
/// Node-kind bit 16: a `Node::AllGather` occurs.
pub const NODE_KIND_ALL_GATHER: u32 = 0x0001_0000;
/// Node-kind bit 17: a `Node::ReduceScatter` occurs.
pub const NODE_KIND_REDUCE_SCATTER: u32 = 0x0002_0000;
/// Node-kind bit 18: a `Node::Broadcast` occurs.
pub const NODE_KIND_BROADCAST: u32 = 0x0004_0000;
/// Node-kind bit 19: a `Node::Opaque` occurs.
pub const NODE_KIND_OPAQUE: u32 = 0x0008_0000;

/// Union of the node kinds that directly carry expression operands, i.e. the kinds whose
/// expressions the stats walker descends into.
pub const NODE_KIND_EXPRESSION_BEARING_MASK: u32 = {
    // Bindings and writes.
    let bindings = NODE_KIND_LET | NODE_KIND_ASSIGN | NODE_KIND_STORE;
    // Structured control flow with condition / bound expressions.
    let control = NODE_KIND_IF | NODE_KIND_LOOP;
    // Async transfers (offset/size) and traps (address).
    let other = NODE_KIND_ASYNC_LOAD | NODE_KIND_ASYNC_STORE | NODE_KIND_TRAP;
    bindings | control | other
};
74
/// Aggregate statistics for a [`Program`], produced by one walk over its buffer
/// declarations and entry nodes (see `compute_stats`). Counters are accumulated with
/// saturating addition.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ProgramStats {
    /// Total number of IR nodes visited, at every nesting depth.
    pub node_count: usize,
    /// Number of `Node::Region` nodes at any depth.
    pub region_count: u32,
    /// Number of `Expr::Call` expressions.
    pub call_count: u32,
    /// Number of opaque nodes plus opaque expressions.
    pub opaque_count: u32,
    /// Number of `Node::Region` nodes that appear directly in the entry list (not nested).
    pub top_level_regions: u32,
    /// Sum of `count * element size` over buffer declarations whose count is non-zero and
    /// whose element type reports a size.
    pub static_storage_bytes: u64,
    /// Approximate instruction count. Every memory, atomic and control-flow op is included;
    /// literals, variable reads, buffer lengths and built-in ids contribute nothing.
    pub instruction_count: u64,
    /// Memory operations: stores, loads, async loads/stores, atomics and collective ops.
    /// Includes `atomic_op_count`.
    pub memory_op_count: u64,
    /// Number of atomic expressions.
    pub atomic_op_count: u64,
    /// Control-flow operations: if, loop, async wait, indirect dispatch, trap, return,
    /// barrier and resume.
    pub control_flow_count: u64,
    /// Coarse register-pressure estimate: the maximum number of `let` bindings live at
    /// once, where each nested body opens its own scope.
    pub register_pressure_estimate: u32,
    /// Bitset of the `CAP_*` capability flags set during the walk.
    pub capability_bits: u32,
    /// Bitset of `NODE_KIND_*` flags, one bit per node kind occurring anywhere in the program.
    pub node_kinds_present: u32,
}
114
115mod methods;
116impl Program {
117 #[must_use]
119 #[inline]
120 pub fn stats(&self) -> &ProgramStats {
121 self.stats
122 .get_or_init(|| std::sync::Arc::new(compute_stats(self)))
123 .as_ref()
124 }
125}
126
127pub(crate) fn compute_stats(program: &Program) -> ProgramStats {
129 let mut node_count = 0usize;
130 let mut region_count = 0u32;
131 let mut call_count = 0u32;
132 let mut opaque_count = 0u32;
133 let mut capability_bits = 0u32;
134 let mut node_kinds_present = 0u32;
135 let mut static_storage_bytes = 0u64;
136 let mut ir = IrCounters::default();
137
138 for decl in program.buffers.iter() {
139 let count = decl.count();
140 if count != 0 {
141 if let Some(elem) = decl.element().size_bytes() {
142 static_storage_bytes =
143 static_storage_bytes.saturating_add(u64::from(count) * elem as u64);
144 }
145 }
146 mark_datatype_bits(&decl.element(), &mut capability_bits);
147 }
148
149 for node in program.entry.iter() {
150 walk_node(
151 node,
152 &mut node_count,
153 &mut region_count,
154 &mut call_count,
155 &mut opaque_count,
156 &mut capability_bits,
157 &mut node_kinds_present,
158 &mut ir,
159 );
160 }
161
162 let top_level_regions = program
163 .entry()
164 .iter()
165 .filter(|n| matches!(n, Node::Region { .. }))
166 .count()
167 .try_into()
168 .unwrap_or(u32::MAX);
169
170 ProgramStats {
171 node_count,
172 region_count,
173 call_count,
174 opaque_count,
175 top_level_regions,
176 static_storage_bytes,
177 instruction_count: ir.instruction_count,
178 memory_op_count: ir.memory_op_count,
179 atomic_op_count: ir.atomic_op_count,
180 control_flow_count: ir.control_flow_count,
181 register_pressure_estimate: ir.register_pressure_estimate(),
182 capability_bits,
183 node_kinds_present,
184 }
185}
186
/// Running IR-level counters and a simple name-liveness tracker used while walking a program.
#[derive(Default)]
struct IrCounters {
    /// Every counted operation, including memory, atomic and control-flow ops.
    instruction_count: u64,
    /// Memory operations (atomics included).
    memory_op_count: u64,
    /// Atomic operations.
    atomic_op_count: u64,
    /// Control-flow operations.
    control_flow_count: u64,
    /// Names bound in the currently open scope chain.
    live_names: u32,
    /// High-water mark reached by `live_names`.
    max_live_names: u32,
}

impl IrCounters {
    /// Saturating `+= 1` shared by every `u64` counter.
    fn bump(counter: &mut u64) {
        *counter = counter.saturating_add(1);
    }

    /// Counts one generic instruction.
    fn instruction(&mut self) {
        Self::bump(&mut self.instruction_count);
    }

    /// Counts one memory operation (which is also an instruction).
    fn memory(&mut self) {
        Self::bump(&mut self.memory_op_count);
        self.instruction();
    }

    /// Counts one atomic operation (which is also a memory operation).
    fn atomic(&mut self) {
        Self::bump(&mut self.atomic_op_count);
        self.memory();
    }

    /// Counts one control-flow operation (which is also an instruction).
    fn control_flow(&mut self) {
        Self::bump(&mut self.control_flow_count);
        self.instruction();
    }

    /// Records a newly bound name and updates the high-water mark.
    fn bind_name(&mut self) {
        let live = self.live_names.saturating_add(1);
        self.live_names = live;
        if live > self.max_live_names {
            self.max_live_names = live;
        }
    }

    /// Opens a scope; returns the token to hand back to [`Self::leave_scope`].
    fn enter_scope(&self) -> u32 {
        self.live_names
    }

    /// Closes a scope, dropping every name bound since the matching `enter_scope`.
    fn leave_scope(&mut self, saved: u32) {
        self.live_names = saved;
    }

    /// Register-pressure estimate: the peak number of simultaneously live names.
    fn register_pressure_estimate(&self) -> u32 {
        self.max_live_names
    }
}
234
235#[inline]
236fn mark_datatype_bits(ty: &DataType, bits: &mut u32) {
237 match ty {
238 DataType::F16 => *bits |= CAP_F16,
239 DataType::BF16 => *bits |= CAP_BF16,
240 DataType::F64 => *bits |= CAP_F64,
241 DataType::Tensor | DataType::TensorShaped { .. } => *bits |= CAP_TENSOR_OPS,
242 _ => {}
243 }
244}
245
246#[allow(clippy::too_many_arguments)]
247#[expect(
248 clippy::too_many_lines,
249 reason = "single-pass ProgramStats walker keeps all counters hot and avoids repeated IR traversals"
250)]
251fn walk_node(
252 node: &Node,
253 nodes: &mut usize,
254 regions: &mut u32,
255 calls: &mut u32,
256 opaque: &mut u32,
257 bits: &mut u32,
258 kinds: &mut u32,
259 ir: &mut IrCounters,
260) {
261 *nodes = nodes.saturating_add(1);
262 match node {
263 Node::Let { value, .. } => {
264 *kinds |= NODE_KIND_LET;
265 ir.instruction();
266 ir.bind_name();
267 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
268 }
269 Node::Assign { value, .. } => {
270 *kinds |= NODE_KIND_ASSIGN;
271 ir.instruction();
272 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
273 }
274 Node::Store { index, value, .. } => {
275 *kinds |= NODE_KIND_STORE;
276 ir.memory();
277 walk_expr(index, nodes, regions, calls, opaque, bits, kinds, ir);
278 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
279 }
280 Node::If {
281 cond,
282 then,
283 otherwise,
284 } => {
285 *kinds |= NODE_KIND_IF;
286 ir.control_flow();
287 walk_expr(cond, nodes, regions, calls, opaque, bits, kinds, ir);
288 let saved = ir.enter_scope();
289 for child in then.iter().chain(otherwise.iter()) {
290 walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
291 }
292 ir.leave_scope(saved);
293 }
294 Node::Loop { from, to, body, .. } => {
295 *kinds |= NODE_KIND_LOOP;
296 ir.control_flow();
297 walk_expr(from, nodes, regions, calls, opaque, bits, kinds, ir);
298 walk_expr(to, nodes, regions, calls, opaque, bits, kinds, ir);
299 let saved = ir.enter_scope();
300 for child in body {
301 walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
302 }
303 ir.leave_scope(saved);
304 }
305 Node::Block(children) => {
306 *kinds |= NODE_KIND_BLOCK;
307 let saved = ir.enter_scope();
308 for child in children {
309 walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
310 }
311 ir.leave_scope(saved);
312 }
313 Node::Region { body, .. } => {
314 *kinds |= NODE_KIND_REGION;
315 *regions = regions.saturating_add(1);
316 let saved = ir.enter_scope();
317 for child in body.iter() {
318 walk_node(child, nodes, regions, calls, opaque, bits, kinds, ir);
319 }
320 ir.leave_scope(saved);
321 }
322 Node::AsyncLoad { offset, size, .. } => {
323 *kinds |= NODE_KIND_ASYNC_LOAD;
324 *bits |= CAP_ASYNC_DISPATCH;
325 ir.memory();
326 walk_expr(offset, nodes, regions, calls, opaque, bits, kinds, ir);
327 walk_expr(size, nodes, regions, calls, opaque, bits, kinds, ir);
328 }
329 Node::AsyncStore { offset, size, .. } => {
330 *kinds |= NODE_KIND_ASYNC_STORE;
331 *bits |= CAP_ASYNC_DISPATCH;
332 ir.memory();
333 walk_expr(offset, nodes, regions, calls, opaque, bits, kinds, ir);
334 walk_expr(size, nodes, regions, calls, opaque, bits, kinds, ir);
335 }
336 Node::AsyncWait { .. } => {
337 *kinds |= NODE_KIND_ASYNC_WAIT;
338 *bits |= CAP_ASYNC_DISPATCH;
339 ir.control_flow();
340 }
341 Node::IndirectDispatch { .. } => {
342 *kinds |= NODE_KIND_INDIRECT_DISPATCH;
343 *bits |= CAP_INDIRECT_DISPATCH;
344 ir.control_flow();
345 }
346 Node::Trap { address, .. } => {
347 *kinds |= NODE_KIND_TRAP;
348 *bits |= CAP_TRAP;
349 ir.control_flow();
350 walk_expr(address, nodes, regions, calls, opaque, bits, kinds, ir);
351 }
352 Node::AllReduce { .. } => {
353 *kinds |= NODE_KIND_ALL_REDUCE;
354 *bits |= CAP_DISTRIBUTED_COLLECTIVES;
355 ir.memory();
356 }
357 Node::AllGather { .. } => {
358 *kinds |= NODE_KIND_ALL_GATHER;
359 *bits |= CAP_DISTRIBUTED_COLLECTIVES;
360 ir.memory();
361 }
362 Node::ReduceScatter { .. } => {
363 *kinds |= NODE_KIND_REDUCE_SCATTER;
364 *bits |= CAP_DISTRIBUTED_COLLECTIVES;
365 ir.memory();
366 }
367 Node::Broadcast { .. } => {
368 *kinds |= NODE_KIND_BROADCAST;
369 *bits |= CAP_DISTRIBUTED_COLLECTIVES;
370 ir.memory();
371 }
372 Node::Opaque(_) => {
373 *kinds |= NODE_KIND_OPAQUE;
374 *opaque = opaque.saturating_add(1);
375 ir.instruction();
376 }
377 Node::Return => {
378 *kinds |= NODE_KIND_RETURN;
379 ir.control_flow();
380 }
381 Node::Barrier { .. } => {
382 *kinds |= NODE_KIND_BARRIER;
383 ir.control_flow();
384 }
385 Node::Resume { .. } => {
386 *kinds |= NODE_KIND_RESUME;
387 ir.control_flow();
388 }
389 }
390}
391
392#[allow(clippy::only_used_in_recursion, clippy::too_many_arguments)]
393fn walk_expr(
394 expr: &Expr,
395 nodes: &mut usize,
396 regions: &mut u32,
397 calls: &mut u32,
398 opaque: &mut u32,
399 bits: &mut u32,
400 kinds: &mut u32,
401 ir: &mut IrCounters,
402) {
403 match expr {
404 Expr::SubgroupAdd { value } => {
405 *bits |= CAP_SUBGROUP_OPS;
406 ir.instruction();
407 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
408 }
409 Expr::SubgroupBallot { cond } => {
410 *bits |= CAP_SUBGROUP_OPS;
411 ir.instruction();
412 walk_expr(cond, nodes, regions, calls, opaque, bits, kinds, ir);
413 }
414 Expr::SubgroupShuffle { value, lane } => {
415 *bits |= CAP_SUBGROUP_OPS;
416 ir.instruction();
417 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
418 walk_expr(lane, nodes, regions, calls, opaque, bits, kinds, ir);
419 }
420 Expr::BinOp { left, right, .. } => {
421 ir.instruction();
422 walk_expr(left, nodes, regions, calls, opaque, bits, kinds, ir);
423 walk_expr(right, nodes, regions, calls, opaque, bits, kinds, ir);
424 }
425 Expr::UnOp { operand, .. } => {
426 ir.instruction();
427 walk_expr(operand, nodes, regions, calls, opaque, bits, kinds, ir);
428 }
429 Expr::Fma { a, b, c } => {
430 ir.instruction();
431 walk_expr(a, nodes, regions, calls, opaque, bits, kinds, ir);
432 walk_expr(b, nodes, regions, calls, opaque, bits, kinds, ir);
433 walk_expr(c, nodes, regions, calls, opaque, bits, kinds, ir);
434 }
435 Expr::Select {
436 cond,
437 true_val,
438 false_val,
439 } => {
440 ir.instruction();
441 walk_expr(cond, nodes, regions, calls, opaque, bits, kinds, ir);
442 walk_expr(true_val, nodes, regions, calls, opaque, bits, kinds, ir);
443 walk_expr(false_val, nodes, regions, calls, opaque, bits, kinds, ir);
444 }
445 Expr::Cast { target, value } => {
446 mark_datatype_bits(target, bits);
447 ir.instruction();
448 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
449 }
450 Expr::Load { index, .. } => {
451 ir.memory();
452 walk_expr(index, nodes, regions, calls, opaque, bits, kinds, ir);
453 }
454 Expr::Call { op_id, args } => {
455 if is_subgroup_intrinsic_id(op_id) {
456 *bits |= CAP_SUBGROUP_OPS;
457 }
458 *calls = calls.saturating_add(1);
459 ir.instruction();
460 for arg in args {
461 walk_expr(arg, nodes, regions, calls, opaque, bits, kinds, ir);
462 }
463 }
464 Expr::Atomic {
465 index,
466 expected,
467 value,
468 ..
469 } => {
470 ir.atomic();
471 walk_expr(index, nodes, regions, calls, opaque, bits, kinds, ir);
472 if let Some(expected) = expected.as_deref() {
473 walk_expr(expected, nodes, regions, calls, opaque, bits, kinds, ir);
474 }
475 walk_expr(value, nodes, regions, calls, opaque, bits, kinds, ir);
476 }
477 Expr::Opaque(_) => {
478 *opaque = opaque.saturating_add(1);
479 ir.instruction();
480 }
481 Expr::SubgroupLocalId | Expr::SubgroupSize => {
482 *bits |= CAP_SUBGROUP_OPS;
483 ir.instruction();
484 }
485 Expr::LitU32(_)
486 | Expr::LitI32(_)
487 | Expr::LitF32(_)
488 | Expr::LitBool(_)
489 | Expr::Var(_)
490 | Expr::BufLen { .. }
491 | Expr::InvocationId { .. }
492 | Expr::WorkgroupId { .. }
493 | Expr::LocalId { .. } => {}
494 }
495}
496
497
/// Heuristically reports whether a call's `op_id` names a subgroup (wave/warp) intrinsic.
///
/// This is a plain substring test, so it deliberately errs toward over-reporting the
/// subgroup capability rather than missing it.
fn is_subgroup_intrinsic_id(op_id: &str) -> bool {
    // "::subgroup" already covers any id containing "::subgroup::", so that longer
    // marker is not listed separately.
    const PATTERNS: [&str; 6] = [
        "subgroup_",
        "::subgroup",
        "wave_",
        "::wave::",
        "warp_",
        "::warp::",
    ];
    for pattern in PATTERNS {
        if op_id.contains(pattern) {
            return true;
        }
    }
    false
}
510