Skip to main content

trident/neural/data/
tir_graph.rs

1//! TirGraph — graph representation of TIR for GNN encoding.
2//!
3//! Converts a flat `Vec<TIROp>` into a graph with typed edges:
4//! - DataDep: producer→consumer via abstract stack simulation
5//! - ControlFlow: sequential and branch edges
6//! - MemOrder: conservative ordering between memory operations
7
8use crate::ir::tir::TIROp;
9
10// ─── Types ────────────────────────────────────────────────────────
11
/// Kind of relationship a directed edge expresses between two TIR graph nodes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EdgeKind {
    /// Value flow: the source node produced a stack value the target node reads.
    DataDep,
    /// Execution order: sequential successor, branch/body entry, or loop back edge.
    ControlFlow,
    /// Ordering constraint linking memory-touching operations in program order.
    MemOrder,
}
22
/// Field-element type annotated on a TIR node's output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FieldType {
    /// Base field element (Goldilocks).
    BFE,
    /// Extension field element (cubic extension).
    XFE,
    /// Not determined, or the op has no meaningful field-typed output.
    Unknown,
}
33
/// Opcode kind — mirrors TIROp variants without payloads.
///
/// Used for one-hot encoding in the GNN feature vector: the discriminant of
/// each variant is its slot index, so discriminants must stay dense and
/// start at 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum OpKind {
    // Tier 0 — Structure (11)
    Call = 0,
    Return = 1,
    Halt = 2,
    IfElse = 3,
    IfOnly = 4,
    Loop = 5,
    FnStart = 6,
    FnEnd = 7,
    Entry = 8,
    Comment = 9,
    Asm = 10,
    // Tier 1 — Universal (31)
    Push = 11,
    Pop = 12,
    Dup = 13,
    Swap = 14,
    Add = 15,
    Sub = 16,
    Mul = 17,
    Neg = 18,
    Invert = 19,
    Eq = 20,
    Lt = 21,
    And = 22,
    Or = 23,
    Xor = 24,
    PopCount = 25,
    Split = 26,
    DivMod = 27,
    Shl = 28,
    Shr = 29,
    Log2 = 30,
    Pow = 31,
    ReadIo = 32,
    WriteIo = 33,
    ReadMem = 34,
    WriteMem = 35,
    Assert = 36,
    Hash = 37,
    Reveal = 38,
    Seal = 39,
    RamRead = 40,
    RamWrite = 41,
    // Tier 2 — Provable (7)
    Hint = 42,
    SpongeInit = 43,
    SpongeAbsorb = 44,
    SpongeSqueeze = 45,
    SpongeLoad = 46,
    MerkleStep = 47,
    MerkleLoad = 48,
    // Tier 3 — Recursion (5)
    ExtMul = 49,
    ExtInvert = 50,
    FoldExt = 51,
    FoldBase = 52,
    ProofBlock = 53,
}

/// Number of [`OpKind`] variants, i.e. the width of the one-hot op encoding.
///
/// Derived from the highest discriminant instead of being a hand-written
/// literal, so renumbering the enum cannot leave the op one-hot overlapping
/// the field-type slots of the feature vector. `ProofBlock` must remain the
/// last (highest) variant; if a variant is appended after it, update this
/// reference.
pub const NUM_OP_KINDS: usize = OpKind::ProofBlock as usize + 1;
100
101impl OpKind {
102    pub fn from_tir_op(op: &TIROp) -> Self {
103        match op {
104            TIROp::Call(_) => OpKind::Call,
105            TIROp::Return => OpKind::Return,
106            TIROp::Halt => OpKind::Halt,
107            TIROp::IfElse { .. } => OpKind::IfElse,
108            TIROp::IfOnly { .. } => OpKind::IfOnly,
109            TIROp::Loop { .. } => OpKind::Loop,
110            TIROp::FnStart(_) => OpKind::FnStart,
111            TIROp::FnEnd => OpKind::FnEnd,
112            TIROp::Entry(_) => OpKind::Entry,
113            TIROp::Comment(_) => OpKind::Comment,
114            TIROp::Asm { .. } => OpKind::Asm,
115            TIROp::Push(_) => OpKind::Push,
116            TIROp::Pop(_) => OpKind::Pop,
117            TIROp::Dup(_) => OpKind::Dup,
118            TIROp::Swap(_) => OpKind::Swap,
119            TIROp::Add => OpKind::Add,
120            TIROp::Sub => OpKind::Sub,
121            TIROp::Mul => OpKind::Mul,
122            TIROp::Neg => OpKind::Neg,
123            TIROp::Invert => OpKind::Invert,
124            TIROp::Eq => OpKind::Eq,
125            TIROp::Lt => OpKind::Lt,
126            TIROp::And => OpKind::And,
127            TIROp::Or => OpKind::Or,
128            TIROp::Xor => OpKind::Xor,
129            TIROp::PopCount => OpKind::PopCount,
130            TIROp::Split => OpKind::Split,
131            TIROp::DivMod => OpKind::DivMod,
132            TIROp::Shl => OpKind::Shl,
133            TIROp::Shr => OpKind::Shr,
134            TIROp::Log2 => OpKind::Log2,
135            TIROp::Pow => OpKind::Pow,
136            TIROp::ReadIo(_) => OpKind::ReadIo,
137            TIROp::WriteIo(_) => OpKind::WriteIo,
138            TIROp::ReadMem(_) => OpKind::ReadMem,
139            TIROp::WriteMem(_) => OpKind::WriteMem,
140            TIROp::Assert(_) => OpKind::Assert,
141            TIROp::Hash { .. } => OpKind::Hash,
142            TIROp::Reveal { .. } => OpKind::Reveal,
143            TIROp::Seal { .. } => OpKind::Seal,
144            TIROp::RamRead { .. } => OpKind::RamRead,
145            TIROp::RamWrite { .. } => OpKind::RamWrite,
146            TIROp::Hint(_) => OpKind::Hint,
147            TIROp::SpongeInit => OpKind::SpongeInit,
148            TIROp::SpongeAbsorb => OpKind::SpongeAbsorb,
149            TIROp::SpongeSqueeze => OpKind::SpongeSqueeze,
150            TIROp::SpongeLoad => OpKind::SpongeLoad,
151            TIROp::MerkleStep => OpKind::MerkleStep,
152            TIROp::MerkleLoad => OpKind::MerkleLoad,
153            TIROp::ExtMul => OpKind::ExtMul,
154            TIROp::ExtInvert => OpKind::ExtInvert,
155            TIROp::FoldExt => OpKind::FoldExt,
156            TIROp::FoldBase => OpKind::FoldBase,
157            TIROp::ProofBlock { .. } => OpKind::ProofBlock,
158        }
159    }
160}
161
162/// A node in the TIR graph.
163#[derive(Debug, Clone)]
164pub struct TirNode {
165    pub op: OpKind,
166    pub field_type: FieldType,
167    pub immediate: Option<u64>,
168}
169
170/// Graph representation of TIR operations.
171#[derive(Debug, Clone)]
172pub struct TirGraph {
173    pub nodes: Vec<TirNode>,
174    pub edges: Vec<(usize, usize, EdgeKind)>,
175}
176
177// ─── Feature Vector ───────────────────────────────────────────────
178
179/// Node feature vector dimensions:
180/// - op_onehot: 54 (NUM_OP_KINDS)
181/// - field_type_onehot: 3 (BFE, XFE, Unknown)
182/// - has_immediate: 1
183/// - immediate_normalized: 1
184/// Total: 59
185pub const NODE_FEATURE_DIM: usize = NUM_OP_KINDS + 3 + 1 + 1;
186
187impl TirNode {
188    /// Encode this node as a 59-dimensional feature vector.
189    pub fn feature_vector(&self) -> [f32; NODE_FEATURE_DIM] {
190        let mut v = [0.0f32; NODE_FEATURE_DIM];
191
192        // One-hot op kind (54 dims)
193        v[self.op as usize] = 1.0;
194
195        // One-hot field type (3 dims, offset 54)
196        let ft_offset = NUM_OP_KINDS;
197        match self.field_type {
198            FieldType::BFE => v[ft_offset] = 1.0,
199            FieldType::XFE => v[ft_offset + 1] = 1.0,
200            FieldType::Unknown => v[ft_offset + 2] = 1.0,
201        }
202
203        // Has immediate (1 dim, offset 57)
204        if self.immediate.is_some() {
205            v[ft_offset + 3] = 1.0;
206        }
207
208        // Normalized immediate (1 dim, offset 58)
209        // Normalize to [0, 1] using log1p for large values
210        if let Some(imm) = self.immediate {
211            v[ft_offset + 4] = (imm as f64 + 1.0).ln() as f32 / 44.4; // ln(2^64) ≈ 44.4
212        }
213
214        v
215    }
216}
217
218// ─── Stack Effects ────────────────────────────────────────────────
219
220/// Determine the field type of a TIROp's output.
221fn output_field_type(op: &TIROp) -> FieldType {
222    match op {
223        TIROp::ExtMul | TIROp::ExtInvert => FieldType::XFE,
224        TIROp::FoldExt => FieldType::XFE,
225        TIROp::SpongeSqueeze => FieldType::BFE,
226        TIROp::Hash { .. } => FieldType::BFE,
227        TIROp::Add
228        | TIROp::Sub
229        | TIROp::Mul
230        | TIROp::Neg
231        | TIROp::Invert
232        | TIROp::Eq
233        | TIROp::Lt
234        | TIROp::And
235        | TIROp::Or
236        | TIROp::Xor
237        | TIROp::DivMod
238        | TIROp::Split
239        | TIROp::Shl
240        | TIROp::Shr
241        | TIROp::Log2
242        | TIROp::Pow
243        | TIROp::PopCount
244        | TIROp::Push(_) => FieldType::BFE,
245        _ => FieldType::Unknown,
246    }
247}
248
249// ─── Graph Construction ───────────────────────────────────────────
250
251impl TirGraph {
252    /// Build a TirGraph from a flat sequence of TIR operations.
253    ///
254    /// Flattens structural ops (IfElse bodies, Loop bodies) into
255    /// a single node list, adding appropriate control flow edges.
256    pub fn from_tir_ops(ops: &[TIROp]) -> Self {
257        let mut nodes = Vec::new();
258        let mut edges = Vec::new();
259
260        // Flatten ops into nodes, recursing into structural bodies
261        flatten_ops(ops, &mut nodes, &mut edges);
262
263        // Extract DataDep edges via abstract stack simulation
264        extract_data_deps(&nodes, &mut edges);
265
266        // Extract MemOrder edges (conservative pairwise ordering)
267        extract_mem_order(&nodes, &mut edges);
268
269        TirGraph { nodes, edges }
270    }
271
272    /// Number of nodes.
273    pub fn num_nodes(&self) -> usize {
274        self.nodes.len()
275    }
276
277    /// Number of edges.
278    pub fn num_edges(&self) -> usize {
279        self.edges.len()
280    }
281
282    /// Count edges of a specific kind.
283    pub fn count_edges(&self, kind: EdgeKind) -> usize {
284        self.edges.iter().filter(|(_, _, k)| *k == kind).count()
285    }
286}
287
288/// Flatten TIR ops into graph nodes, handling structural ops recursively.
289/// Adds ControlFlow edges between sequential ops and into/out-of bodies.
290fn flatten_ops(ops: &[TIROp], nodes: &mut Vec<TirNode>, edges: &mut Vec<(usize, usize, EdgeKind)>) {
291    let mut prev_idx: Option<usize> = None;
292
293    for op in ops {
294        let idx = nodes.len();
295
296        // Determine immediate value (Q5: BFE only, XFE ops get None)
297        let immediate = match op {
298            TIROp::Push(v) => Some(*v),
299            TIROp::Pop(n) | TIROp::Dup(n) | TIROp::Swap(n) => Some(*n as u64),
300            TIROp::ReadIo(n)
301            | TIROp::WriteIo(n)
302            | TIROp::ReadMem(n)
303            | TIROp::WriteMem(n)
304            | TIROp::Assert(n)
305            | TIROp::Hint(n) => Some(*n as u64),
306            TIROp::Hash { width } | TIROp::RamRead { width } | TIROp::RamWrite { width } => {
307                Some(*width as u64)
308            }
309            TIROp::Reveal { field_count, .. } | TIROp::Seal { field_count, .. } => {
310                Some(*field_count as u64)
311            }
312            TIROp::Asm { effect, .. } => Some(*effect as u64),
313            // XFE ops: has_immediate=0 per Q5 resolution
314            TIROp::ExtMul | TIROp::ExtInvert => None,
315            _ => None,
316        };
317
318        let node = TirNode {
319            op: OpKind::from_tir_op(op),
320            field_type: output_field_type(op),
321            immediate,
322        };
323        nodes.push(node);
324
325        // Sequential ControlFlow edge
326        if let Some(p) = prev_idx {
327            edges.push((p, idx, EdgeKind::ControlFlow));
328        }
329
330        // Recurse into structural bodies
331        match op {
332            TIROp::IfElse {
333                then_body,
334                else_body,
335            } => {
336                if !then_body.is_empty() {
337                    let then_start = nodes.len();
338                    flatten_ops(then_body, nodes, edges);
339                    edges.push((idx, then_start, EdgeKind::ControlFlow));
340                }
341                if !else_body.is_empty() {
342                    let else_start = nodes.len();
343                    flatten_ops(else_body, nodes, edges);
344                    edges.push((idx, else_start, EdgeKind::ControlFlow));
345                }
346            }
347            TIROp::IfOnly { then_body } => {
348                if !then_body.is_empty() {
349                    let then_start = nodes.len();
350                    flatten_ops(then_body, nodes, edges);
351                    edges.push((idx, then_start, EdgeKind::ControlFlow));
352                }
353            }
354            TIROp::Loop { body, .. } => {
355                if !body.is_empty() {
356                    let body_start = nodes.len();
357                    flatten_ops(body, nodes, edges);
358                    let body_end = nodes.len() - 1;
359                    edges.push((idx, body_start, EdgeKind::ControlFlow));
360                    // Back edge: loop body end → loop header
361                    edges.push((body_end, idx, EdgeKind::ControlFlow));
362                }
363            }
364            TIROp::ProofBlock { body, .. } => {
365                if !body.is_empty() {
366                    let body_start = nodes.len();
367                    flatten_ops(body, nodes, edges);
368                    edges.push((idx, body_start, EdgeKind::ControlFlow));
369                }
370            }
371            _ => {}
372        }
373
374        prev_idx = Some(idx);
375    }
376}
377
/// One slot of the abstract stack used for data-dependency extraction.
#[derive(Copy, Clone)]
struct StackEntry {
    /// Index (into the flattened node list) of the node that pushed this value.
    producer: usize,
}
383
384/// Extract DataDep edges by simulating an abstract stack.
385/// When op B pops a value produced by op A → edge (A→B, DataDep).
386fn extract_data_deps(nodes: &[TirNode], edges: &mut Vec<(usize, usize, EdgeKind)>) {
387    let mut stack: Vec<StackEntry> = Vec::new();
388
389    for (idx, node) in nodes.iter().enumerate() {
390        let (pops, pushes) = stack_effect_from_kind(node);
391
392        // Pop: create DataDep edges from producers to this consumer
393        let actual_pops = pops.min(stack.len());
394        for _ in 0..actual_pops {
395            if let Some(entry) = stack.pop() {
396                edges.push((entry.producer, idx, EdgeKind::DataDep));
397            }
398        }
399
400        // Handle Dup specially: reads from depth without consuming
401        if node.op == OpKind::Dup {
402            let depth = node.immediate.unwrap_or(0) as usize;
403            if depth < stack.len() {
404                let producer = stack[stack.len() - 1 - depth].producer;
405                edges.push((producer, idx, EdgeKind::DataDep));
406            }
407        }
408
409        // Handle Swap: creates read-dependencies on both swapped positions
410        if node.op == OpKind::Swap {
411            let depth = node.immediate.unwrap_or(1) as usize;
412            if depth < stack.len() && !stack.is_empty() {
413                let top = stack.len() - 1;
414                let other = stack.len() - 1 - depth;
415                stack.swap(top, other);
416            }
417        }
418
419        // Push: record this node as producer
420        for _ in 0..pushes {
421            stack.push(StackEntry { producer: idx });
422        }
423    }
424}
425
426/// Get stack effect from a TirNode (using OpKind + immediate).
427fn stack_effect_from_kind(node: &TirNode) -> (usize, usize) {
428    let n = node.immediate.unwrap_or(0) as usize;
429    match node.op {
430        OpKind::Push => (0, 1),
431        OpKind::Pop => (n, 0),
432        OpKind::Dup => (0, 1),
433        OpKind::Swap => (0, 0),
434        OpKind::Add | OpKind::Sub | OpKind::Mul => (2, 1),
435        OpKind::Neg | OpKind::Invert => (1, 1),
436        OpKind::Eq | OpKind::Lt => (2, 1),
437        OpKind::And | OpKind::Or | OpKind::Xor => (2, 1),
438        OpKind::PopCount | OpKind::Log2 => (1, 1),
439        OpKind::Split => (1, 2),
440        OpKind::DivMod => (2, 2),
441        OpKind::Shl | OpKind::Shr | OpKind::Pow => (2, 1),
442        OpKind::ReadIo => (0, n),
443        OpKind::WriteIo => (n, 0),
444        OpKind::ReadMem => (1, n + 1),
445        OpKind::WriteMem => (n + 1, 1),
446        OpKind::Assert => (n, 0),
447        OpKind::Hash => (10, 5),
448        OpKind::Reveal | OpKind::Seal => (n, 0),
449        OpKind::RamRead => (1, n),
450        OpKind::RamWrite => (n + 1, 0),
451        OpKind::Hint => (0, n),
452        OpKind::SpongeInit => (0, 0),
453        OpKind::SpongeAbsorb => (10, 0),
454        OpKind::SpongeSqueeze => (0, 10),
455        OpKind::SpongeLoad => (1, 1),
456        OpKind::MerkleStep | OpKind::MerkleLoad => (0, 0),
457        OpKind::ExtMul => (6, 3),
458        OpKind::ExtInvert => (3, 3),
459        OpKind::FoldExt | OpKind::FoldBase => (0, 0),
460        OpKind::IfElse | OpKind::IfOnly | OpKind::Loop => (1, 0),
461        _ => (0, 0),
462    }
463}
464
465/// Extract MemOrder edges: pairwise between all memory operations.
466/// Conservative — preserves all possible orderings.
467fn extract_mem_order(nodes: &[TirNode], edges: &mut Vec<(usize, usize, EdgeKind)>) {
468    let mem_indices: Vec<usize> = nodes
469        .iter()
470        .enumerate()
471        .filter(|(_, n)| {
472            matches!(
473                n.op,
474                OpKind::ReadMem
475                    | OpKind::WriteMem
476                    | OpKind::RamRead
477                    | OpKind::RamWrite
478                    | OpKind::SpongeLoad
479                    | OpKind::MerkleLoad
480            )
481        })
482        .map(|(i, _)| i)
483        .collect();
484
485    // Pairwise edges between consecutive memory ops (not O(n²) — sequential ordering)
486    for window in mem_indices.windows(2) {
487        edges.push((window[0], window[1], EdgeKind::MemOrder));
488    }
489}
490
491// ─── Tests ────────────────────────────────────────────────────────
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_arithmetic_graph() {
        // push 3; push 4; add → 3 nodes
        let ops = vec![TIROp::Push(3), TIROp::Push(4), TIROp::Add];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 3);
        // ControlFlow: push3→push4, push4→add = 2
        assert_eq!(graph.count_edges(EdgeKind::ControlFlow), 2);
        // DataDep: push3→add, push4→add = 2
        assert_eq!(graph.count_edges(EdgeKind::DataDep), 2);
        assert_eq!(graph.count_edges(EdgeKind::MemOrder), 0);
    }

    #[test]
    fn memory_ops_get_mem_order_edges() {
        let ops = vec![
            TIROp::Push(100),
            TIROp::ReadMem(1),
            TIROp::Push(200),
            TIROp::WriteMem(1),
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 4);
        // Exactly one edge: ReadMem (1) → WriteMem (3).
        assert_eq!(graph.count_edges(EdgeKind::MemOrder), 1);
        assert!(graph.edges.contains(&(1, 3, EdgeKind::MemOrder)));
    }

    #[test]
    fn mem_order_chains_consecutive_memory_ops() {
        // Memory ops at indices 1, 3, 5 → chain 1→3, 3→5 (not all pairs).
        let ops = vec![
            TIROp::Push(0),
            TIROp::ReadMem(1),
            TIROp::Push(0),
            TIROp::ReadMem(1),
            TIROp::Push(0),
            TIROp::ReadMem(1),
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.count_edges(EdgeKind::MemOrder), 2);
        assert!(graph.edges.contains(&(1, 3, EdgeKind::MemOrder)));
        assert!(graph.edges.contains(&(3, 5, EdgeKind::MemOrder)));
        assert!(!graph.edges.contains(&(1, 5, EdgeKind::MemOrder)));
    }

    #[test]
    fn if_else_creates_branch_edges() {
        let ops = vec![
            TIROp::Push(1), // condition
            TIROp::IfElse {
                then_body: vec![TIROp::Push(10)],
                else_body: vec![TIROp::Push(20)],
            },
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        // 4 nodes: Push(1), IfElse, Push(10), Push(20)
        assert_eq!(graph.num_nodes(), 4);
        // Sequential Push(1)→IfElse plus one entry edge into each branch.
        assert_eq!(graph.count_edges(EdgeKind::ControlFlow), 3);
        assert!(graph.edges.contains(&(0, 1, EdgeKind::ControlFlow)));
        assert!(graph.edges.contains(&(1, 2, EdgeKind::ControlFlow)));
        assert!(graph.edges.contains(&(1, 3, EdgeKind::ControlFlow)));
        // The IfElse consumes its condition.
        assert!(graph.edges.contains(&(0, 1, EdgeKind::DataDep)));
    }

    #[test]
    fn loop_creates_back_edge() {
        let ops = vec![
            TIROp::Push(5),
            TIROp::Loop {
                label: "l".into(),
                body: vec![TIROp::Push(1), TIROp::Sub],
            },
        ];
        let graph = TirGraph::from_tir_ops(&ops);

        // 4 nodes: Push(5), Loop, Push(1), Sub
        assert_eq!(graph.num_nodes(), 4);
        // Should have a back edge from Sub → Loop
        let has_back_edge = graph
            .edges
            .iter()
            .any(|(from, to, kind)| *kind == EdgeKind::ControlFlow && *from == 3 && *to == 1);
        assert!(has_back_edge, "missing loop back edge");
        // Push(5)→Loop, Push(1)→Sub, Loop→Push(1) (entry), Sub→Loop (back).
        assert_eq!(graph.count_edges(EdgeKind::ControlFlow), 4);
    }

    #[test]
    fn swap_permutes_stack_without_edges() {
        // push 1; push 2; swap 1; pop 1 → the pop consumes push 1's value.
        let ops = vec![TIROp::Push(1), TIROp::Push(2), TIROp::Swap(1), TIROp::Pop(1)];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 4);
        // Swap emits no DataDep edges itself; only the pop does.
        assert_eq!(graph.count_edges(EdgeKind::DataDep), 1);
        assert!(graph.edges.contains(&(0, 3, EdgeKind::DataDep)));
    }

    #[test]
    fn feature_vector_dimensions() {
        let node = TirNode {
            op: OpKind::Add,
            field_type: FieldType::BFE,
            immediate: None,
        };
        let fv = node.feature_vector();
        assert_eq!(fv.len(), NODE_FEATURE_DIM);
        assert_eq!(fv.len(), 59);
        // Add is index 15
        assert_eq!(fv[15], 1.0);
        // BFE is index 54
        assert_eq!(fv[54], 1.0);
        // No immediate
        assert_eq!(fv[57], 0.0);
        assert_eq!(fv[58], 0.0);
    }

    #[test]
    fn feature_vector_with_immediate() {
        let node = TirNode {
            op: OpKind::Push,
            field_type: FieldType::BFE,
            immediate: Some(42),
        };
        let fv = node.feature_vector();
        assert_eq!(fv[11], 1.0); // Push is index 11
        assert_eq!(fv[57], 1.0); // has_immediate = 1
        assert!(fv[58] > 0.0); // normalized immediate > 0
    }

    #[test]
    fn max_immediate_normalizes_below_one() {
        let node = TirNode {
            op: OpKind::Push,
            field_type: FieldType::BFE,
            immediate: Some(u64::MAX),
        };
        let fv = node.feature_vector();
        // ln(2^64) ≈ 44.36 < 44.4, so the largest immediate stays inside [0, 1).
        assert!(fv[58] > 0.99 && fv[58] < 1.0, "got {}", fv[58]);
    }

    #[test]
    fn empty_ops_produces_empty_graph() {
        let graph = TirGraph::from_tir_ops(&[]);
        assert_eq!(graph.num_nodes(), 0);
        assert_eq!(graph.num_edges(), 0);
    }

    #[test]
    fn all_54_op_kinds_are_numbered() {
        assert_eq!(OpKind::Call as u8, 0);
        assert_eq!(OpKind::ProofBlock as u8, 53);
        assert_eq!(NUM_OP_KINDS, 54);
    }

    #[test]
    fn dup_creates_data_dep_without_consuming() {
        // push 7; dup 0 → dup reads from push without consuming
        let ops = vec![TIROp::Push(7), TIROp::Dup(0)];
        let graph = TirGraph::from_tir_ops(&ops);

        assert_eq!(graph.num_nodes(), 2);
        // DataDep: push→dup (read dependency)
        assert_eq!(graph.count_edges(EdgeKind::DataDep), 1);
        assert!(graph.edges.contains(&(0, 1, EdgeKind::DataDep)));
    }
}