1use super::Program;
2use crate::ir::{DataType, Expr, Node};
3
/// Capability bit: a subgroup/wave/warp-level operation is used.
const CAP_SUBGROUP_OPS: u32 = 1 << 0;
/// Capability bit: the `f16` datatype appears.
const CAP_F16: u32 = 1 << 1;
/// Capability bit: the `bf16` datatype appears.
const CAP_BF16: u32 = 1 << 2;
/// Capability bit: the `f64` datatype appears.
const CAP_F64: u32 = 1 << 3;
/// Capability bit: async load/store/wait nodes are present.
const CAP_ASYNC_DISPATCH: u32 = 1 << 4;
/// Capability bit: indirect dispatch nodes are present.
const CAP_INDIRECT_DISPATCH: u32 = 1 << 5;
/// Capability bit: tensor-typed data is used.
const CAP_TENSOR_OPS: u32 = 1 << 6;
/// Capability bit: trap nodes are present.
const CAP_TRAP: u32 = 1 << 7;

/// Aggregate metrics gathered in a single traversal of a program.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ProgramStats {
    /// Total IR nodes visited (at any nesting depth).
    pub node_count: usize,
    /// Region nodes at any nesting depth.
    pub region_count: u32,
    /// Call expressions encountered.
    pub call_count: u32,
    /// Opaque nodes plus opaque expressions.
    pub opaque_count: u32,
    /// Region nodes sitting directly in the entry block.
    pub top_level_regions: u32,
    /// Sum of `count * element_size` over all buffers (saturating).
    pub static_storage_bytes: u64,
    /// Estimated instruction count; every counted op contributes here.
    pub instruction_count: u64,
    /// Loads, stores, async transfers, and atomics.
    pub memory_op_count: u64,
    /// Atomic operations (also counted as memory ops).
    pub atomic_op_count: u64,
    /// Branch-like constructs (if/loop/wait/dispatch/trap/return/barrier).
    pub control_flow_count: u64,
    /// Peak number of simultaneously live `let` bindings.
    pub register_pressure_estimate: u32,
    /// Bitwise OR of the `CAP_*` flags observed.
    pub capability_bits: u32,
}

impl ProgramStats {
    /// Returns `true` when `flag` is set in [`Self::capability_bits`].
    #[inline]
    fn has(&self, flag: u32) -> bool {
        self.capability_bits & flag != 0
    }

    /// Whether any subgroup-level operation was observed.
    #[inline]
    #[must_use]
    pub fn subgroup_ops(&self) -> bool {
        self.has(CAP_SUBGROUP_OPS)
    }

    /// Whether the `f16` datatype was observed.
    #[inline]
    #[must_use]
    pub fn f16(&self) -> bool {
        self.has(CAP_F16)
    }

    /// Whether the `bf16` datatype was observed.
    #[inline]
    #[must_use]
    pub fn bf16(&self) -> bool {
        self.has(CAP_BF16)
    }

    /// Whether the `f64` datatype was observed.
    #[inline]
    #[must_use]
    pub fn f64(&self) -> bool {
        self.has(CAP_F64)
    }

    /// Whether async load/store/wait nodes were observed.
    #[inline]
    #[must_use]
    pub fn async_dispatch(&self) -> bool {
        self.has(CAP_ASYNC_DISPATCH)
    }

    /// Whether indirect dispatch nodes were observed.
    #[inline]
    #[must_use]
    pub fn indirect_dispatch(&self) -> bool {
        self.has(CAP_INDIRECT_DISPATCH)
    }

    /// Whether tensor-typed data was observed.
    #[inline]
    #[must_use]
    pub fn tensor_ops(&self) -> bool {
        self.has(CAP_TENSOR_OPS)
    }

    /// Whether trap nodes were observed.
    #[inline]
    #[must_use]
    pub fn trap(&self) -> bool {
        self.has(CAP_TRAP)
    }
}
103
104impl Program {
105 #[must_use]
107 #[inline]
108 pub fn stats(&self) -> &ProgramStats {
109 self.stats
110 .get_or_init(|| std::sync::Arc::new(compute_stats(self)))
111 .as_ref()
112 }
113}
114
115pub(crate) fn compute_stats(program: &Program) -> ProgramStats {
117 let mut node_count = 0usize;
118 let mut region_count = 0u32;
119 let mut call_count = 0u32;
120 let mut opaque_count = 0u32;
121 let mut capability_bits = 0u32;
122 let mut static_storage_bytes = 0u64;
123 let mut ir = IrCounters::default();
124
125 for decl in program.buffers.iter() {
126 let count = decl.count();
127 if count != 0 {
128 if let Some(elem) = decl.element().size_bytes() {
129 static_storage_bytes =
130 static_storage_bytes.saturating_add(u64::from(count) * elem as u64);
131 }
132 }
133 mark_datatype_bits(&decl.element(), &mut capability_bits);
134 }
135
136 for node in program.entry.iter() {
137 walk_node(
138 node,
139 &mut node_count,
140 &mut region_count,
141 &mut call_count,
142 &mut opaque_count,
143 &mut capability_bits,
144 &mut ir,
145 );
146 }
147
148 let top_level_regions = program
149 .entry()
150 .iter()
151 .filter(|n| matches!(n, Node::Region { .. }))
152 .count() as u32;
153
154 ProgramStats {
155 node_count,
156 region_count,
157 call_count,
158 opaque_count,
159 top_level_regions,
160 static_storage_bytes,
161 instruction_count: ir.instruction_count,
162 memory_op_count: ir.memory_op_count,
163 atomic_op_count: ir.atomic_op_count,
164 control_flow_count: ir.control_flow_count,
165 register_pressure_estimate: ir.register_pressure_estimate(),
166 capability_bits,
167 }
168}
169
/// Running counters updated while walking the IR.
#[derive(Default)]
struct IrCounters {
    instruction_count: u64,
    memory_op_count: u64,
    atomic_op_count: u64,
    control_flow_count: u64,
    // `let` bindings live in the current scope chain.
    live_names: u32,
    // High-water mark of `live_names`; reported as register pressure.
    max_live_names: u32,
}

impl IrCounters {
    /// Counts one generic instruction; every other counter funnels here.
    fn instruction(&mut self) {
        self.instruction_count = self.instruction_count.saturating_add(1);
    }

    /// Counts a memory operation, which is also an instruction.
    fn memory(&mut self) {
        self.memory_op_count = self.memory_op_count.saturating_add(1);
        self.instruction();
    }

    /// Counts an atomic operation, which is also a memory operation.
    fn atomic(&mut self) {
        self.atomic_op_count = self.atomic_op_count.saturating_add(1);
        self.memory();
    }

    /// Counts a control-flow construct, which is also an instruction.
    fn control_flow(&mut self) {
        self.control_flow_count = self.control_flow_count.saturating_add(1);
        self.instruction();
    }

    /// Records a newly bound name and bumps the high-water mark.
    fn bind_name(&mut self) {
        self.live_names = self.live_names.saturating_add(1);
        if self.live_names > self.max_live_names {
            self.max_live_names = self.live_names;
        }
    }

    /// Snapshots the live-name count on entry to a lexical scope.
    fn enter_scope(&mut self) -> u32 {
        self.live_names
    }

    /// Restores the live-name count captured by [`Self::enter_scope`].
    fn leave_scope(&mut self, mark: u32) {
        self.live_names = mark;
    }

    /// Peak number of simultaneously live names seen during the walk.
    fn register_pressure_estimate(&self) -> u32 {
        self.max_live_names
    }
}
217
218#[inline]
219fn mark_datatype_bits(ty: &DataType, bits: &mut u32) {
220 match ty {
221 DataType::F16 => *bits |= CAP_F16,
222 DataType::BF16 => *bits |= CAP_BF16,
223 DataType::F64 => *bits |= CAP_F64,
224 DataType::Tensor | DataType::TensorShaped { .. } => *bits |= CAP_TENSOR_OPS,
225 _ => {}
226 }
227}
228
229fn walk_node(
230 node: &Node,
231 nodes: &mut usize,
232 regions: &mut u32,
233 calls: &mut u32,
234 opaque: &mut u32,
235 bits: &mut u32,
236 ir: &mut IrCounters,
237) {
238 *nodes = nodes.saturating_add(1);
239 match node {
240 Node::Let { value, .. } | Node::Assign { value, .. } => {
241 ir.instruction();
242 if matches!(node, Node::Let { .. }) {
243 ir.bind_name();
244 }
245 walk_expr(value, nodes, regions, calls, opaque, bits, ir);
246 }
247 Node::Store { index, value, .. } => {
248 ir.memory();
249 walk_expr(index, nodes, regions, calls, opaque, bits, ir);
250 walk_expr(value, nodes, regions, calls, opaque, bits, ir);
251 }
252 Node::If {
253 cond,
254 then,
255 otherwise,
256 } => {
257 ir.control_flow();
258 walk_expr(cond, nodes, regions, calls, opaque, bits, ir);
259 let saved = ir.enter_scope();
260 for child in then.iter().chain(otherwise.iter()) {
261 walk_node(child, nodes, regions, calls, opaque, bits, ir);
262 }
263 ir.leave_scope(saved);
264 }
265 Node::Loop { from, to, body, .. } => {
266 ir.control_flow();
267 walk_expr(from, nodes, regions, calls, opaque, bits, ir);
268 walk_expr(to, nodes, regions, calls, opaque, bits, ir);
269 let saved = ir.enter_scope();
270 for child in body.iter() {
271 walk_node(child, nodes, regions, calls, opaque, bits, ir);
272 }
273 ir.leave_scope(saved);
274 }
275 Node::Block(children) => {
276 let saved = ir.enter_scope();
277 for child in children.iter() {
278 walk_node(child, nodes, regions, calls, opaque, bits, ir);
279 }
280 ir.leave_scope(saved);
281 }
282 Node::Region { body, .. } => {
283 *regions = regions.saturating_add(1);
284 let saved = ir.enter_scope();
285 for child in body.iter() {
286 walk_node(child, nodes, regions, calls, opaque, bits, ir);
287 }
288 ir.leave_scope(saved);
289 }
290 Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
291 *bits |= CAP_ASYNC_DISPATCH;
292 ir.memory();
293 walk_expr(offset, nodes, regions, calls, opaque, bits, ir);
294 walk_expr(size, nodes, regions, calls, opaque, bits, ir);
295 }
296 Node::AsyncWait { .. } => {
297 *bits |= CAP_ASYNC_DISPATCH;
298 ir.control_flow();
299 }
300 Node::IndirectDispatch { .. } => {
301 *bits |= CAP_INDIRECT_DISPATCH;
302 ir.control_flow();
303 }
304 Node::Trap { address, .. } => {
305 *bits |= CAP_TRAP;
306 ir.control_flow();
307 walk_expr(address, nodes, regions, calls, opaque, bits, ir);
308 }
309 Node::Opaque(_) => {
310 *opaque = opaque.saturating_add(1);
311 ir.instruction();
312 }
313 Node::Return | Node::Barrier { .. } | Node::Resume { .. } => {
314 ir.control_flow();
315 }
316 }
317}
318
319#[allow(clippy::only_used_in_recursion)]
320fn walk_expr(
321 expr: &Expr,
322 nodes: &mut usize,
323 regions: &mut u32,
324 calls: &mut u32,
325 opaque: &mut u32,
326 bits: &mut u32,
327 ir: &mut IrCounters,
328) {
329 match expr {
330 Expr::SubgroupAdd { value } => {
331 *bits |= CAP_SUBGROUP_OPS;
332 ir.instruction();
333 walk_expr(value, nodes, regions, calls, opaque, bits, ir);
334 }
335 Expr::SubgroupBallot { cond } => {
336 *bits |= CAP_SUBGROUP_OPS;
337 ir.instruction();
338 walk_expr(cond, nodes, regions, calls, opaque, bits, ir);
339 }
340 Expr::SubgroupShuffle { value, lane } => {
341 *bits |= CAP_SUBGROUP_OPS;
342 ir.instruction();
343 walk_expr(value, nodes, regions, calls, opaque, bits, ir);
344 walk_expr(lane, nodes, regions, calls, opaque, bits, ir);
345 }
346 Expr::BinOp { left, right, .. } => {
347 ir.instruction();
348 walk_expr(left, nodes, regions, calls, opaque, bits, ir);
349 walk_expr(right, nodes, regions, calls, opaque, bits, ir);
350 }
351 Expr::UnOp { operand, .. } => {
352 ir.instruction();
353 walk_expr(operand, nodes, regions, calls, opaque, bits, ir);
354 }
355 Expr::Fma { a, b, c } => {
356 ir.instruction();
357 walk_expr(a, nodes, regions, calls, opaque, bits, ir);
358 walk_expr(b, nodes, regions, calls, opaque, bits, ir);
359 walk_expr(c, nodes, regions, calls, opaque, bits, ir);
360 }
361 Expr::Select {
362 cond,
363 true_val,
364 false_val,
365 } => {
366 ir.instruction();
367 walk_expr(cond, nodes, regions, calls, opaque, bits, ir);
368 walk_expr(true_val, nodes, regions, calls, opaque, bits, ir);
369 walk_expr(false_val, nodes, regions, calls, opaque, bits, ir);
370 }
371 Expr::Cast { target, value } => {
372 mark_datatype_bits(target, bits);
373 ir.instruction();
374 walk_expr(value, nodes, regions, calls, opaque, bits, ir);
375 }
376 Expr::Load { index, .. } => {
377 ir.memory();
378 walk_expr(index, nodes, regions, calls, opaque, bits, ir);
379 }
380 Expr::Call { op_id, args } => {
381 if is_subgroup_intrinsic_id(op_id) {
382 *bits |= CAP_SUBGROUP_OPS;
383 }
384 *calls = calls.saturating_add(1);
385 ir.instruction();
386 for arg in args.iter() {
387 walk_expr(arg, nodes, regions, calls, opaque, bits, ir);
388 }
389 }
390 Expr::Atomic {
391 index,
392 expected,
393 value,
394 ..
395 } => {
396 ir.atomic();
397 walk_expr(index, nodes, regions, calls, opaque, bits, ir);
398 if let Some(expected) = expected.as_deref() {
399 walk_expr(expected, nodes, regions, calls, opaque, bits, ir);
400 }
401 walk_expr(value, nodes, regions, calls, opaque, bits, ir);
402 }
403 Expr::Opaque(_) => {
404 *opaque = opaque.saturating_add(1);
405 ir.instruction();
406 }
407 Expr::SubgroupLocalId | Expr::SubgroupSize => {
408 *bits |= CAP_SUBGROUP_OPS;
409 ir.instruction();
410 }
411 Expr::LitU32(_)
412 | Expr::LitI32(_)
413 | Expr::LitF32(_)
414 | Expr::LitBool(_)
415 | Expr::Var(_)
416 | Expr::BufLen { .. }
417 | Expr::InvocationId { .. }
418 | Expr::WorkgroupId { .. }
419 | Expr::LocalId { .. } => {}
420 }
421}
422
/// Returns `true` when `op_id` names a subgroup-level intrinsic.
///
/// Matches the common naming schemes for subgroup primitives: `subgroup`
/// (Vulkan/WGSL-style), `wave` (HLSL-style), and `warp` (CUDA-style),
/// either as an underscore-prefixed token (`subgroup_…`) or as a path
/// segment (`…::subgroup…`).
fn is_subgroup_intrinsic_id(op_id: &str) -> bool {
    // NOTE: "::subgroup" subsumes the former "::subgroup::" marker (any
    // string containing the latter also contains the former), so the dead
    // longer marker was dropped; matching behavior is unchanged.
    const MARKERS: &[&str] = &[
        "subgroup_",
        "::subgroup",
        "wave_",
        "::wave::",
        "warp_",
        "::warp::",
    ];
    MARKERS.iter().any(|marker| op_id.contains(marker))
}