1use crate::ir::tir::TIROp;
9
/// Relationship encoded by a directed edge of a [`TirGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EdgeKind {
    /// A stack value produced by the source node is read by the target node.
    DataDep,
    /// Execution can move from the source node to the target node.
    ControlFlow,
    /// Both nodes access memory and the source comes first in node order.
    MemOrder,
}
22
/// Coarse classification of the value an op leaves on the stack.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldType {
    /// Base-field element (assigned to arithmetic, comparison, bitwise, push and hash results).
    BFE,
    /// Extension-field element (assigned to extension-field multiply/invert/fold results).
    XFE,
    /// No type could be assigned from the op alone.
    Unknown,
}
33
/// Discrete category of a [`TirNode`], one per `TIROp` variant.
///
/// Discriminants are explicit and contiguous from `0` to `NUM_OP_KINDS - 1`. They are used
/// directly as one-hot indices by `TirNode::feature_vector`, so existing values must never be
/// renumbered; new kinds go after `ProofBlock`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum OpKind {
    // Calls, control flow and program structure.
    Call = 0,
    Return = 1,
    Halt = 2,
    IfElse = 3,
    IfOnly = 4,
    Loop = 5,
    FnStart = 6,
    FnEnd = 7,
    Entry = 8,
    Comment = 9,
    Asm = 10,
    // Stack manipulation.
    Push = 11,
    Pop = 12,
    Dup = 13,
    Swap = 14,
    // Arithmetic, comparison and bitwise ops.
    Add = 15,
    Sub = 16,
    Mul = 17,
    Neg = 18,
    Invert = 19,
    Eq = 20,
    Lt = 21,
    And = 22,
    Or = 23,
    Xor = 24,
    PopCount = 25,
    Split = 26,
    DivMod = 27,
    Shl = 28,
    Shr = 29,
    Log2 = 30,
    Pow = 31,
    // I/O, memory and assertions.
    ReadIo = 32,
    WriteIo = 33,
    ReadMem = 34,
    WriteMem = 35,
    Assert = 36,
    Hash = 37,
    Reveal = 38,
    Seal = 39,
    RamRead = 40,
    RamWrite = 41,
    Hint = 42,
    // Sponge and Merkle ops.
    SpongeInit = 43,
    SpongeAbsorb = 44,
    SpongeSqueeze = 45,
    SpongeLoad = 46,
    MerkleStep = 47,
    MerkleLoad = 48,
    // Extension-field ops.
    ExtMul = 49,
    ExtInvert = 50,
    FoldExt = 51,
    FoldBase = 52,
    ProofBlock = 53,
}

/// Number of [`OpKind`] variants, i.e. the width of the op one-hot encoding.
///
/// Derived from the discriminant of the last variant instead of a bare literal, so it follows
/// the enum automatically; keep `ProofBlock` (or whichever variant is appended after it) as the
/// variant referenced here.
pub const NUM_OP_KINDS: usize = OpKind::ProofBlock as usize + 1;
100
101impl OpKind {
102 pub fn from_tir_op(op: &TIROp) -> Self {
103 match op {
104 TIROp::Call(_) => OpKind::Call,
105 TIROp::Return => OpKind::Return,
106 TIROp::Halt => OpKind::Halt,
107 TIROp::IfElse { .. } => OpKind::IfElse,
108 TIROp::IfOnly { .. } => OpKind::IfOnly,
109 TIROp::Loop { .. } => OpKind::Loop,
110 TIROp::FnStart(_) => OpKind::FnStart,
111 TIROp::FnEnd => OpKind::FnEnd,
112 TIROp::Entry(_) => OpKind::Entry,
113 TIROp::Comment(_) => OpKind::Comment,
114 TIROp::Asm { .. } => OpKind::Asm,
115 TIROp::Push(_) => OpKind::Push,
116 TIROp::Pop(_) => OpKind::Pop,
117 TIROp::Dup(_) => OpKind::Dup,
118 TIROp::Swap(_) => OpKind::Swap,
119 TIROp::Add => OpKind::Add,
120 TIROp::Sub => OpKind::Sub,
121 TIROp::Mul => OpKind::Mul,
122 TIROp::Neg => OpKind::Neg,
123 TIROp::Invert => OpKind::Invert,
124 TIROp::Eq => OpKind::Eq,
125 TIROp::Lt => OpKind::Lt,
126 TIROp::And => OpKind::And,
127 TIROp::Or => OpKind::Or,
128 TIROp::Xor => OpKind::Xor,
129 TIROp::PopCount => OpKind::PopCount,
130 TIROp::Split => OpKind::Split,
131 TIROp::DivMod => OpKind::DivMod,
132 TIROp::Shl => OpKind::Shl,
133 TIROp::Shr => OpKind::Shr,
134 TIROp::Log2 => OpKind::Log2,
135 TIROp::Pow => OpKind::Pow,
136 TIROp::ReadIo(_) => OpKind::ReadIo,
137 TIROp::WriteIo(_) => OpKind::WriteIo,
138 TIROp::ReadMem(_) => OpKind::ReadMem,
139 TIROp::WriteMem(_) => OpKind::WriteMem,
140 TIROp::Assert(_) => OpKind::Assert,
141 TIROp::Hash { .. } => OpKind::Hash,
142 TIROp::Reveal { .. } => OpKind::Reveal,
143 TIROp::Seal { .. } => OpKind::Seal,
144 TIROp::RamRead { .. } => OpKind::RamRead,
145 TIROp::RamWrite { .. } => OpKind::RamWrite,
146 TIROp::Hint(_) => OpKind::Hint,
147 TIROp::SpongeInit => OpKind::SpongeInit,
148 TIROp::SpongeAbsorb => OpKind::SpongeAbsorb,
149 TIROp::SpongeSqueeze => OpKind::SpongeSqueeze,
150 TIROp::SpongeLoad => OpKind::SpongeLoad,
151 TIROp::MerkleStep => OpKind::MerkleStep,
152 TIROp::MerkleLoad => OpKind::MerkleLoad,
153 TIROp::ExtMul => OpKind::ExtMul,
154 TIROp::ExtInvert => OpKind::ExtInvert,
155 TIROp::FoldExt => OpKind::FoldExt,
156 TIROp::FoldBase => OpKind::FoldBase,
157 TIROp::ProofBlock { .. } => OpKind::ProofBlock,
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct TirNode {
165 pub op: OpKind,
166 pub field_type: FieldType,
167 pub immediate: Option<u64>,
168}
169
170#[derive(Debug, Clone)]
172pub struct TirGraph {
173 pub nodes: Vec<TirNode>,
174 pub edges: Vec<(usize, usize, EdgeKind)>,
175}
176
177pub const NODE_FEATURE_DIM: usize = NUM_OP_KINDS + 3 + 1 + 1;
186
187impl TirNode {
188 pub fn feature_vector(&self) -> [f32; NODE_FEATURE_DIM] {
190 let mut v = [0.0f32; NODE_FEATURE_DIM];
191
192 v[self.op as usize] = 1.0;
194
195 let ft_offset = NUM_OP_KINDS;
197 match self.field_type {
198 FieldType::BFE => v[ft_offset] = 1.0,
199 FieldType::XFE => v[ft_offset + 1] = 1.0,
200 FieldType::Unknown => v[ft_offset + 2] = 1.0,
201 }
202
203 if self.immediate.is_some() {
205 v[ft_offset + 3] = 1.0;
206 }
207
208 if let Some(imm) = self.immediate {
211 v[ft_offset + 4] = (imm as f64 + 1.0).ln() as f32 / 44.4; }
213
214 v
215 }
216}
217
218fn output_field_type(op: &TIROp) -> FieldType {
222 match op {
223 TIROp::ExtMul | TIROp::ExtInvert => FieldType::XFE,
224 TIROp::FoldExt => FieldType::XFE,
225 TIROp::SpongeSqueeze => FieldType::BFE,
226 TIROp::Hash { .. } => FieldType::BFE,
227 TIROp::Add
228 | TIROp::Sub
229 | TIROp::Mul
230 | TIROp::Neg
231 | TIROp::Invert
232 | TIROp::Eq
233 | TIROp::Lt
234 | TIROp::And
235 | TIROp::Or
236 | TIROp::Xor
237 | TIROp::DivMod
238 | TIROp::Split
239 | TIROp::Shl
240 | TIROp::Shr
241 | TIROp::Log2
242 | TIROp::Pow
243 | TIROp::PopCount
244 | TIROp::Push(_) => FieldType::BFE,
245 _ => FieldType::Unknown,
246 }
247}
248
249impl TirGraph {
252 pub fn from_tir_ops(ops: &[TIROp]) -> Self {
257 let mut nodes = Vec::new();
258 let mut edges = Vec::new();
259
260 flatten_ops(ops, &mut nodes, &mut edges);
262
263 extract_data_deps(&nodes, &mut edges);
265
266 extract_mem_order(&nodes, &mut edges);
268
269 TirGraph { nodes, edges }
270 }
271
272 pub fn num_nodes(&self) -> usize {
274 self.nodes.len()
275 }
276
277 pub fn num_edges(&self) -> usize {
279 self.edges.len()
280 }
281
282 pub fn count_edges(&self, kind: EdgeKind) -> usize {
284 self.edges.iter().filter(|(_, _, k)| *k == kind).count()
285 }
286}
287
288fn flatten_ops(ops: &[TIROp], nodes: &mut Vec<TirNode>, edges: &mut Vec<(usize, usize, EdgeKind)>) {
291 let mut prev_idx: Option<usize> = None;
292
293 for op in ops {
294 let idx = nodes.len();
295
296 let immediate = match op {
298 TIROp::Push(v) => Some(*v),
299 TIROp::Pop(n) | TIROp::Dup(n) | TIROp::Swap(n) => Some(*n as u64),
300 TIROp::ReadIo(n)
301 | TIROp::WriteIo(n)
302 | TIROp::ReadMem(n)
303 | TIROp::WriteMem(n)
304 | TIROp::Assert(n)
305 | TIROp::Hint(n) => Some(*n as u64),
306 TIROp::Hash { width } | TIROp::RamRead { width } | TIROp::RamWrite { width } => {
307 Some(*width as u64)
308 }
309 TIROp::Reveal { field_count, .. } | TIROp::Seal { field_count, .. } => {
310 Some(*field_count as u64)
311 }
312 TIROp::Asm { effect, .. } => Some(*effect as u64),
313 TIROp::ExtMul | TIROp::ExtInvert => None,
315 _ => None,
316 };
317
318 let node = TirNode {
319 op: OpKind::from_tir_op(op),
320 field_type: output_field_type(op),
321 immediate,
322 };
323 nodes.push(node);
324
325 if let Some(p) = prev_idx {
327 edges.push((p, idx, EdgeKind::ControlFlow));
328 }
329
330 match op {
332 TIROp::IfElse {
333 then_body,
334 else_body,
335 } => {
336 if !then_body.is_empty() {
337 let then_start = nodes.len();
338 flatten_ops(then_body, nodes, edges);
339 edges.push((idx, then_start, EdgeKind::ControlFlow));
340 }
341 if !else_body.is_empty() {
342 let else_start = nodes.len();
343 flatten_ops(else_body, nodes, edges);
344 edges.push((idx, else_start, EdgeKind::ControlFlow));
345 }
346 }
347 TIROp::IfOnly { then_body } => {
348 if !then_body.is_empty() {
349 let then_start = nodes.len();
350 flatten_ops(then_body, nodes, edges);
351 edges.push((idx, then_start, EdgeKind::ControlFlow));
352 }
353 }
354 TIROp::Loop { body, .. } => {
355 if !body.is_empty() {
356 let body_start = nodes.len();
357 flatten_ops(body, nodes, edges);
358 let body_end = nodes.len() - 1;
359 edges.push((idx, body_start, EdgeKind::ControlFlow));
360 edges.push((body_end, idx, EdgeKind::ControlFlow));
362 }
363 }
364 TIROp::ProofBlock { body, .. } => {
365 if !body.is_empty() {
366 let body_start = nodes.len();
367 flatten_ops(body, nodes, edges);
368 edges.push((idx, body_start, EdgeKind::ControlFlow));
369 }
370 }
371 _ => {}
372 }
373
374 prev_idx = Some(idx);
375 }
376}
377
378#[derive(Clone, Copy)]
380struct StackEntry {
381 producer: usize,
382}
383
384fn extract_data_deps(nodes: &[TirNode], edges: &mut Vec<(usize, usize, EdgeKind)>) {
387 let mut stack: Vec<StackEntry> = Vec::new();
388
389 for (idx, node) in nodes.iter().enumerate() {
390 let (pops, pushes) = stack_effect_from_kind(node);
391
392 let actual_pops = pops.min(stack.len());
394 for _ in 0..actual_pops {
395 if let Some(entry) = stack.pop() {
396 edges.push((entry.producer, idx, EdgeKind::DataDep));
397 }
398 }
399
400 if node.op == OpKind::Dup {
402 let depth = node.immediate.unwrap_or(0) as usize;
403 if depth < stack.len() {
404 let producer = stack[stack.len() - 1 - depth].producer;
405 edges.push((producer, idx, EdgeKind::DataDep));
406 }
407 }
408
409 if node.op == OpKind::Swap {
411 let depth = node.immediate.unwrap_or(1) as usize;
412 if depth < stack.len() && !stack.is_empty() {
413 let top = stack.len() - 1;
414 let other = stack.len() - 1 - depth;
415 stack.swap(top, other);
416 }
417 }
418
419 for _ in 0..pushes {
421 stack.push(StackEntry { producer: idx });
422 }
423 }
424}
425
426fn stack_effect_from_kind(node: &TirNode) -> (usize, usize) {
428 let n = node.immediate.unwrap_or(0) as usize;
429 match node.op {
430 OpKind::Push => (0, 1),
431 OpKind::Pop => (n, 0),
432 OpKind::Dup => (0, 1),
433 OpKind::Swap => (0, 0),
434 OpKind::Add | OpKind::Sub | OpKind::Mul => (2, 1),
435 OpKind::Neg | OpKind::Invert => (1, 1),
436 OpKind::Eq | OpKind::Lt => (2, 1),
437 OpKind::And | OpKind::Or | OpKind::Xor => (2, 1),
438 OpKind::PopCount | OpKind::Log2 => (1, 1),
439 OpKind::Split => (1, 2),
440 OpKind::DivMod => (2, 2),
441 OpKind::Shl | OpKind::Shr | OpKind::Pow => (2, 1),
442 OpKind::ReadIo => (0, n),
443 OpKind::WriteIo => (n, 0),
444 OpKind::ReadMem => (1, n + 1),
445 OpKind::WriteMem => (n + 1, 1),
446 OpKind::Assert => (n, 0),
447 OpKind::Hash => (10, 5),
448 OpKind::Reveal | OpKind::Seal => (n, 0),
449 OpKind::RamRead => (1, n),
450 OpKind::RamWrite => (n + 1, 0),
451 OpKind::Hint => (0, n),
452 OpKind::SpongeInit => (0, 0),
453 OpKind::SpongeAbsorb => (10, 0),
454 OpKind::SpongeSqueeze => (0, 10),
455 OpKind::SpongeLoad => (1, 1),
456 OpKind::MerkleStep | OpKind::MerkleLoad => (0, 0),
457 OpKind::ExtMul => (6, 3),
458 OpKind::ExtInvert => (3, 3),
459 OpKind::FoldExt | OpKind::FoldBase => (0, 0),
460 OpKind::IfElse | OpKind::IfOnly | OpKind::Loop => (1, 0),
461 _ => (0, 0),
462 }
463}
464
465fn extract_mem_order(nodes: &[TirNode], edges: &mut Vec<(usize, usize, EdgeKind)>) {
468 let mem_indices: Vec<usize> = nodes
469 .iter()
470 .enumerate()
471 .filter(|(_, n)| {
472 matches!(
473 n.op,
474 OpKind::ReadMem
475 | OpKind::WriteMem
476 | OpKind::RamRead
477 | OpKind::RamWrite
478 | OpKind::SpongeLoad
479 | OpKind::MerkleLoad
480 )
481 })
482 .map(|(i, _)| i)
483 .collect();
484
485 for window in mem_indices.windows(2) {
487 edges.push((window[0], window[1], EdgeKind::MemOrder));
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// First field-type slot in `TirNode::feature_vector`.
    const FIELD_TYPE_OFFSET: usize = NUM_OP_KINDS;

    /// True if `graph` contains the exact directed edge `(from, to, kind)`.
    fn has_edge(graph: &TirGraph, from: usize, to: usize, kind: EdgeKind) -> bool {
        graph
            .edges
            .iter()
            .any(|&(f, t, k)| f == from && t == to && k == kind)
    }

    #[test]
    fn simple_arithmetic_graph() {
        let ops = vec![TIROp::Push(3), TIROp::Push(4), TIROp::Add];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 3);
        // Push -> Push -> Add.
        assert_eq!(graph.count_edges(EdgeKind::ControlFlow), 2);
        // Both pushed values feed the Add.
        assert_eq!(graph.count_edges(EdgeKind::DataDep), 2);
        assert!(has_edge(&graph, 0, 2, EdgeKind::DataDep));
        assert!(has_edge(&graph, 1, 2, EdgeKind::DataDep));
        assert_eq!(graph.count_edges(EdgeKind::MemOrder), 0);
    }

    #[test]
    fn memory_ops_get_mem_order_edges() {
        let ops = vec![
            TIROp::Push(100),
            TIROp::ReadMem(1),
            TIROp::Push(200),
            TIROp::WriteMem(1),
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 4);
        // Exactly one ordering edge: ReadMem (1) -> WriteMem (3).
        assert_eq!(graph.count_edges(EdgeKind::MemOrder), 1);
        assert!(has_edge(&graph, 1, 3, EdgeKind::MemOrder));
        // ReadMem consumes the pointer; WriteMem consumes the pushed value and ReadMem's output.
        assert_eq!(graph.count_edges(EdgeKind::DataDep), 3);
        assert!(has_edge(&graph, 0, 1, EdgeKind::DataDep));
        assert!(has_edge(&graph, 2, 3, EdgeKind::DataDep));
        assert!(has_edge(&graph, 1, 3, EdgeKind::DataDep));
    }

    #[test]
    fn if_else_creates_branch_edges() {
        let ops = vec![
            TIROp::Push(1),
            TIROp::IfElse {
                then_body: vec![TIROp::Push(10)],
                else_body: vec![TIROp::Push(20)],
            },
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 4);
        // Push -> IfElse, IfElse -> then branch, IfElse -> else branch.
        let cf = graph.count_edges(EdgeKind::ControlFlow);
        assert_eq!(cf, 3, "expected 3 CF edges, got {}", cf);
        assert!(has_edge(&graph, 0, 1, EdgeKind::ControlFlow));
        assert!(has_edge(&graph, 1, 2, EdgeKind::ControlFlow));
        assert!(has_edge(&graph, 1, 3, EdgeKind::ControlFlow));
        // The condition is consumed by the IfElse.
        assert!(has_edge(&graph, 0, 1, EdgeKind::DataDep));
    }

    #[test]
    fn loop_creates_back_edge() {
        let ops = vec![
            TIROp::Push(5),
            TIROp::Loop {
                label: "l".into(),
                body: vec![TIROp::Push(1), TIROp::Sub],
            },
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 4);
        // Push -> Loop, Loop -> body start, Push(1) -> Sub, Sub -> Loop (back edge).
        assert_eq!(graph.count_edges(EdgeKind::ControlFlow), 4);
        assert!(has_edge(&graph, 1, 2, EdgeKind::ControlFlow));
        assert!(
            has_edge(&graph, 3, 1, EdgeKind::ControlFlow),
            "missing loop back edge"
        );
    }

    #[test]
    fn feature_vector_dimensions() {
        let node = TirNode {
            op: OpKind::Add,
            field_type: FieldType::BFE,
            immediate: None,
        };
        let fv = node.feature_vector();
        assert_eq!(fv.len(), NODE_FEATURE_DIM);
        assert_eq!(fv.len(), 59);
        // Op one-hot.
        assert_eq!(fv[OpKind::Add as usize], 1.0);
        assert_eq!(fv[15], 1.0);
        // Field-type one-hot: BFE set, XFE / Unknown clear.
        assert_eq!(fv[FIELD_TYPE_OFFSET], 1.0);
        assert_eq!(fv[FIELD_TYPE_OFFSET + 1], 0.0);
        assert_eq!(fv[FIELD_TYPE_OFFSET + 2], 0.0);
        // No immediate: flag and log value both zero.
        assert_eq!(fv[57], 0.0);
        assert_eq!(fv[58], 0.0);
        // Exactly two hot entries overall.
        assert_eq!(fv.iter().sum::<f32>(), 2.0);
    }

    #[test]
    fn feature_vector_with_immediate() {
        let node = TirNode {
            op: OpKind::Push,
            field_type: FieldType::BFE,
            immediate: Some(42),
        };
        let fv = node.feature_vector();
        assert_eq!(fv[11], 1.0);
        assert_eq!(fv[57], 1.0);
        let expected = (42f64 + 1.0).ln() as f32 / 44.4;
        assert!((fv[58] - expected).abs() < 1e-6, "got {}", fv[58]);
    }

    #[test]
    fn feature_vector_log_immediate_stays_below_one() {
        let node = TirNode {
            op: OpKind::Push,
            field_type: FieldType::BFE,
            immediate: Some(u64::MAX),
        };
        let fv = node.feature_vector();
        assert!(fv[58] > 0.0 && fv[58] < 1.0, "got {}", fv[58]);
    }

    #[test]
    fn empty_ops_produces_empty_graph() {
        let graph = TirGraph::from_tir_ops(&[]);
        assert_eq!(graph.num_nodes(), 0);
        assert_eq!(graph.num_edges(), 0);
    }

    #[test]
    fn all_54_op_kinds_are_numbered() {
        assert_eq!(OpKind::Call as u8, 0);
        assert_eq!(OpKind::ProofBlock as u8, 53);
        assert_eq!(NUM_OP_KINDS, 54);
    }

    #[test]
    fn dup_creates_data_dep_without_consuming() {
        let ops = vec![TIROp::Push(7), TIROp::Dup(0)];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 2);
        assert_eq!(graph.count_edges(EdgeKind::DataDep), 1);
        assert!(has_edge(&graph, 0, 1, EdgeKind::DataDep));
    }

    #[test]
    fn swap_reorders_producers() {
        // After Swap(1) the value pushed first is on top, so Pop(1) consumes node 0.
        let ops = vec![TIROp::Push(1), TIROp::Push(2), TIROp::Swap(1), TIROp::Pop(1)];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.count_edges(EdgeKind::DataDep), 1);
        assert!(has_edge(&graph, 0, 3, EdgeKind::DataDep));
    }
}