Skip to main content

trident/ir/tir/
encode.rs

//! TIR block encoding for neural optimizer input.
//!
//! Encodes TIR basic blocks as fixed-size tensors. Each node is encoded as
//! 4 u64 words, with at most 32 nodes per block, plus a 16-element stack
//! context vector.
//! Total: 32 * 4 + 16 = 144 u64 values per block.
6
7use super::TIROp;
8
/// Upper bound on the number of nodes a single encoded block may hold.
pub const MAX_NODES: usize = 32;
/// Number of `u64` words used to encode one node.
pub const WORDS_PER_NODE: usize = 4;
/// Length of the stack context vector attached to each block.
pub const CONTEXT_SIZE: usize = 16;
/// Length of the flattened model input: all node words
/// (`MAX_NODES * WORDS_PER_NODE`) plus the `CONTEXT_SIZE` context elements.
pub const INPUT_DIM: usize = CONTEXT_SIZE + WORDS_PER_NODE * MAX_NODES;
17
18/// Encoded TIR basic block for neural optimizer input.
19#[derive(Clone, Debug)]
20pub struct TIRBlock {
21    /// 32 nodes * 4 words = 128 u64 values, zero-padded.
22    pub nodes: [u64; MAX_NODES * WORDS_PER_NODE],
23    /// Stack state context at block entry (16 elements).
24    pub context: [u64; CONTEXT_SIZE],
25    /// Number of actual nodes (before padding).
26    pub node_count: usize,
27    /// Source location: function name.
28    pub fn_name: String,
29    /// Source location: start index in the original TIR op sequence.
30    pub start_idx: usize,
31    /// Source location: end index (exclusive).
32    pub end_idx: usize,
33}
34
35impl TIRBlock {
36    /// Flattened input tensor (144 elements) for the neural model.
37    pub fn as_input(&self) -> Vec<u64> {
38        let mut v = Vec::with_capacity(INPUT_DIM);
39        v.extend_from_slice(&self.nodes);
40        v.extend_from_slice(&self.context);
41        v
42    }
43
44    /// Block identifier for display (e.g., "main:0..14").
45    pub fn block_id(&self) -> String {
46        format!("{}:{}..{}", self.fn_name, self.start_idx, self.end_idx)
47    }
48}
49
50/// Opcode mapping: TIROp variant -> 0..53 (6 bits).
51fn opcode(op: &TIROp) -> u8 {
52    match op {
53        // Tier 0 — Structure (0..10)
54        TIROp::Call(_) => 0,
55        TIROp::Return => 1,
56        TIROp::Halt => 2,
57        TIROp::IfElse { .. } => 3,
58        TIROp::IfOnly { .. } => 4,
59        TIROp::Loop { .. } => 5,
60        TIROp::FnStart(_) => 6,
61        TIROp::FnEnd => 7,
62        TIROp::Entry(_) => 8,
63        TIROp::Comment(_) => 9,
64        TIROp::Asm { .. } => 10,
65        // Tier 1 — Universal (11..41)
66        TIROp::Push(_) => 11,
67        TIROp::Pop(_) => 12,
68        TIROp::Dup(_) => 13,
69        TIROp::Swap(_) => 14,
70        TIROp::Add => 15,
71        TIROp::Sub => 16,
72        TIROp::Mul => 17,
73        TIROp::Neg => 18,
74        TIROp::Invert => 19,
75        TIROp::Eq => 20,
76        TIROp::Lt => 21,
77        TIROp::And => 22,
78        TIROp::Or => 23,
79        TIROp::Xor => 24,
80        TIROp::PopCount => 25,
81        TIROp::Split => 26,
82        TIROp::DivMod => 27,
83        TIROp::Shl => 28,
84        TIROp::Shr => 29,
85        TIROp::Log2 => 30,
86        TIROp::Pow => 31,
87        TIROp::ReadIo(_) => 32,
88        TIROp::WriteIo(_) => 33,
89        TIROp::ReadMem(_) => 34,
90        TIROp::WriteMem(_) => 35,
91        TIROp::Assert(_) => 36,
92        TIROp::Hash { .. } => 37,
93        TIROp::Reveal { .. } => 38,
94        TIROp::Seal { .. } => 39,
95        TIROp::RamRead { .. } => 40,
96        TIROp::RamWrite { .. } => 41,
97        // Tier 2 — Provable (42..48)
98        TIROp::Hint(_) => 42,
99        TIROp::SpongeInit => 43,
100        TIROp::SpongeAbsorb => 44,
101        TIROp::SpongeSqueeze => 45,
102        TIROp::SpongeLoad => 46,
103        TIROp::MerkleStep => 47,
104        TIROp::MerkleLoad => 48,
105        // Tier 3 — Recursion (49..53)
106        TIROp::ExtMul => 49,
107        TIROp::ExtInvert => 50,
108        TIROp::FoldExt => 51,
109        TIROp::FoldBase => 52,
110        TIROp::ProofBlock { .. } => 53,
111    }
112}
113
114/// Extract the immediate argument from a TIROp (if any).
115fn immediate(op: &TIROp) -> u64 {
116    match op {
117        TIROp::Push(v) => *v,
118        TIROp::Pop(n) | TIROp::Dup(n) | TIROp::Swap(n) => *n as u64,
119        TIROp::ReadIo(n) | TIROp::WriteIo(n) => *n as u64,
120        TIROp::ReadMem(n) | TIROp::WriteMem(n) => *n as u64,
121        TIROp::Assert(n) => *n as u64,
122        TIROp::Hint(n) => *n as u64,
123        TIROp::Hash { width } => *width as u64,
124        TIROp::RamRead { width } | TIROp::RamWrite { width } => *width as u64,
125        TIROp::Asm { effect, .. } => *effect as u64,
126        _ => 0,
127    }
128}
129
130/// Whether a TIROp is a control flow boundary (block terminator).
131fn is_block_boundary(op: &TIROp) -> bool {
132    matches!(
133        op,
134        TIROp::Call(_)
135            | TIROp::Return
136            | TIROp::Halt
137            | TIROp::IfElse { .. }
138            | TIROp::IfOnly { .. }
139            | TIROp::Loop { .. }
140            | TIROp::FnStart(_)
141            | TIROp::FnEnd
142            | TIROp::Entry(_)
143    )
144}
145
146/// Encode a single node as 4 u64 words.
147///
148/// Word 0: opcode (6 bits) | immediate (58 bits packed)
149/// Word 1: node index (position in block)
150/// Word 2: immediate value (full 64 bits for Push)
151/// Word 3: reserved (0)
152fn encode_node(op: &TIROp, index: usize) -> [u64; WORDS_PER_NODE] {
153    let opc = opcode(op) as u64;
154    let imm = immediate(op);
155    [
156        opc,          // word 0: opcode
157        index as u64, // word 1: position
158        imm,          // word 2: immediate
159        0,            // word 3: reserved
160    ]
161}
162
163/// Split a TIR op sequence into basic blocks at control flow boundaries.
164///
165/// Each block is a maximal straight-line segment of <= MAX_NODES ops.
166/// Structural ops (FnStart, FnEnd, Entry) start new blocks but are
167/// not included in the block content.
168pub fn encode_blocks(ops: &[TIROp]) -> Vec<TIRBlock> {
169    let mut blocks = Vec::new();
170    let mut current_fn = String::new();
171    let mut block_ops: Vec<(usize, &TIROp)> = Vec::new();
172    let mut block_start = 0;
173
174    for (i, op) in ops.iter().enumerate() {
175        // Track current function name
176        if let TIROp::FnStart(name) = op {
177            // Flush pending block
178            if !block_ops.is_empty() {
179                blocks.push(build_block(&block_ops, &current_fn, block_start));
180                block_ops.clear();
181            }
182            current_fn = name.clone();
183            block_start = i + 1;
184            continue;
185        }
186
187        // Skip structural markers
188        if matches!(op, TIROp::FnEnd | TIROp::Entry(_) | TIROp::Comment(_)) {
189            continue;
190        }
191
192        // Control flow boundaries flush the current block
193        if is_block_boundary(op) {
194            if !block_ops.is_empty() {
195                blocks.push(build_block(&block_ops, &current_fn, block_start));
196                block_ops.clear();
197            }
198            block_start = i + 1;
199            continue;
200        }
201
202        block_ops.push((i, op));
203
204        // Split at MAX_NODES
205        if block_ops.len() >= MAX_NODES {
206            blocks.push(build_block(&block_ops, &current_fn, block_start));
207            block_start = i + 1;
208            block_ops.clear();
209        }
210    }
211
212    // Flush remaining
213    if !block_ops.is_empty() {
214        blocks.push(build_block(&block_ops, &current_fn, block_start));
215    }
216
217    blocks
218}
219
220fn build_block(ops: &[(usize, &TIROp)], fn_name: &str, start_idx: usize) -> TIRBlock {
221    let mut nodes = [0u64; MAX_NODES * WORDS_PER_NODE];
222    let node_count = ops.len().min(MAX_NODES);
223    let end_idx = ops.last().map(|(i, _)| i + 1).unwrap_or(start_idx);
224
225    for (local_idx, (_global_idx, op)) in ops.iter().enumerate().take(MAX_NODES) {
226        let encoded = encode_node(op, local_idx);
227        let base = local_idx * WORDS_PER_NODE;
228        nodes[base..base + WORDS_PER_NODE].copy_from_slice(&encoded);
229    }
230
231    TIRBlock {
232        nodes,
233        context: [0; CONTEXT_SIZE],
234        node_count,
235        fn_name: fn_name.to_string(),
236        start_idx,
237        end_idx,
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// One instance of each of the 54 `TIROp` variants, listed in opcode order.
    fn every_variant() -> Vec<TIROp> {
        vec![
            // Tier 0 — Structure
            TIROp::Call("f".into()),
            TIROp::Return,
            TIROp::Halt,
            TIROp::IfElse {
                then_body: vec![],
                else_body: vec![],
            },
            TIROp::IfOnly { then_body: vec![] },
            TIROp::Loop {
                label: "l".into(),
                body: vec![],
            },
            TIROp::FnStart("f".into()),
            TIROp::FnEnd,
            TIROp::Entry("m".into()),
            TIROp::Comment("c".into()),
            TIROp::Asm {
                lines: vec![],
                effect: 0,
            },
            // Tier 1 — Universal
            TIROp::Push(0),
            TIROp::Pop(1),
            TIROp::Dup(0),
            TIROp::Swap(1),
            TIROp::Add,
            TIROp::Sub,
            TIROp::Mul,
            TIROp::Neg,
            TIROp::Invert,
            TIROp::Eq,
            TIROp::Lt,
            TIROp::And,
            TIROp::Or,
            TIROp::Xor,
            TIROp::PopCount,
            TIROp::Split,
            TIROp::DivMod,
            TIROp::Shl,
            TIROp::Shr,
            TIROp::Log2,
            TIROp::Pow,
            TIROp::ReadIo(1),
            TIROp::WriteIo(1),
            TIROp::ReadMem(1),
            TIROp::WriteMem(1),
            TIROp::Assert(1),
            TIROp::Hash { width: 0 },
            TIROp::Reveal {
                name: "e".into(),
                tag: 0,
                field_count: 1,
            },
            TIROp::Seal {
                name: "e".into(),
                tag: 0,
                field_count: 1,
            },
            TIROp::RamRead { width: 1 },
            TIROp::RamWrite { width: 1 },
            // Tier 2 — Provable
            TIROp::Hint(1),
            TIROp::SpongeInit,
            TIROp::SpongeAbsorb,
            TIROp::SpongeSqueeze,
            TIROp::SpongeLoad,
            TIROp::MerkleStep,
            TIROp::MerkleLoad,
            // Tier 3 — Recursion
            TIROp::ExtMul,
            TIROp::ExtInvert,
            TIROp::FoldExt,
            TIROp::FoldBase,
            TIROp::ProofBlock {
                program_hash: "h".into(),
                body: vec![],
            },
        ]
    }

    /// Every variant maps to its own opcode, and all opcodes stay in 0..=53.
    #[test]
    fn opcode_coverage() {
        let variants = every_variant();
        let mut codes = std::collections::HashSet::new();
        for op in variants.iter() {
            let code = opcode(op);
            assert!(code <= 53, "opcode {} out of range for {:?}", code, op);
            codes.insert(code);
        }
        let distinct = codes.len();
        assert_eq!(
            distinct, 54,
            "expected 54 distinct opcodes, got {}",
            distinct
        );
    }

    /// A single straight-line function body becomes exactly one block.
    #[test]
    fn encode_simple_block() {
        let blocks = encode_blocks(&[
            TIROp::FnStart("main".into()),
            TIROp::Push(42),
            TIROp::Push(10),
            TIROp::Add,
            TIROp::WriteIo(1),
            TIROp::Return,
        ]);
        assert_eq!(blocks.len(), 1);

        let block = &blocks[0];
        assert_eq!(block.node_count, 4); // Push, Push, Add, WriteIo
        assert_eq!(block.fn_name, "main");
        // The first node is Push(42): opcode in word 0, immediate in word 2.
        assert_eq!(block.nodes[0], 11);
        assert_eq!(block.nodes[2], 42);
    }

    /// Call and Return each close the block that precedes them.
    #[test]
    fn block_split_at_control_flow() {
        let blocks = encode_blocks(&[
            TIROp::FnStart("main".into()),
            TIROp::Push(1),
            TIROp::Push(2),
            TIROp::Call("helper".into()), // boundary
            TIROp::Push(3),
            TIROp::Add,
            TIROp::Return, // boundary
        ]);
        let counts: Vec<usize> = blocks.iter().map(|b| b.node_count).collect();
        assert_eq!(blocks.len(), 2);
        assert_eq!(counts[0], 2); // Push(1), Push(2)
        assert_eq!(counts[1], 2); // Push(3), Add
    }

    /// Forty straight-line ops are cut into a full block plus the remainder.
    #[test]
    fn block_split_at_max_nodes() {
        let ops: Vec<TIROp> = std::iter::once(TIROp::FnStart("big".into()))
            .chain((0..40).map(TIROp::Push))
            .collect();
        let blocks = encode_blocks(&ops);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].node_count, 32);
        assert_eq!(blocks[1].node_count, 8);
    }

    /// No ops in, no blocks out.
    #[test]
    fn empty_ops() {
        assert!(encode_blocks(&[]).is_empty());
    }

    /// The flattened input always has `INPUT_DIM` elements.
    #[test]
    fn input_dimension() {
        let blocks = encode_blocks(&[
            TIROp::FnStart("f".into()),
            TIROp::Push(1),
            TIROp::Push(2),
            TIROp::Add,
        ]);
        assert_eq!(blocks[0].as_input().len(), INPUT_DIM);
    }

    /// Block identifiers start with the function name and a colon.
    #[test]
    fn block_id_format() {
        let blocks = encode_blocks(&[
            TIROp::FnStart("main".into()),
            TIROp::Push(1),
            TIROp::Add,
        ]);
        let id = blocks[0].block_id();
        assert!(id.starts_with("main:"));
    }
}