1use super::TIROp;
8
/// Maximum number of TIR ops packed into a single [`TIRBlock`].
pub const MAX_NODES: usize = 32;
/// Number of `u64` words used to encode one op: opcode, local index,
/// immediate, and one word that is left zero.
pub const WORDS_PER_NODE: usize = 4;
/// Number of context words appended after the node words in the model input.
pub const CONTEXT_SIZE: usize = 16;
/// Length of the flat vector produced by `TIRBlock::as_input`
/// (all node words followed by all context words).
pub const INPUT_DIM: usize = CONTEXT_SIZE + WORDS_PER_NODE * MAX_NODES;
17
18#[derive(Clone, Debug)]
20pub struct TIRBlock {
21 pub nodes: [u64; MAX_NODES * WORDS_PER_NODE],
23 pub context: [u64; CONTEXT_SIZE],
25 pub node_count: usize,
27 pub fn_name: String,
29 pub start_idx: usize,
31 pub end_idx: usize,
33}
34
35impl TIRBlock {
36 pub fn as_input(&self) -> Vec<u64> {
38 let mut v = Vec::with_capacity(INPUT_DIM);
39 v.extend_from_slice(&self.nodes);
40 v.extend_from_slice(&self.context);
41 v
42 }
43
44 pub fn block_id(&self) -> String {
46 format!("{}:{}..{}", self.fn_name, self.start_idx, self.end_idx)
47 }
48}
49
50fn opcode(op: &TIROp) -> u8 {
52 match op {
53 TIROp::Call(_) => 0,
55 TIROp::Return => 1,
56 TIROp::Halt => 2,
57 TIROp::IfElse { .. } => 3,
58 TIROp::IfOnly { .. } => 4,
59 TIROp::Loop { .. } => 5,
60 TIROp::FnStart(_) => 6,
61 TIROp::FnEnd => 7,
62 TIROp::Entry(_) => 8,
63 TIROp::Comment(_) => 9,
64 TIROp::Asm { .. } => 10,
65 TIROp::Push(_) => 11,
67 TIROp::Pop(_) => 12,
68 TIROp::Dup(_) => 13,
69 TIROp::Swap(_) => 14,
70 TIROp::Add => 15,
71 TIROp::Sub => 16,
72 TIROp::Mul => 17,
73 TIROp::Neg => 18,
74 TIROp::Invert => 19,
75 TIROp::Eq => 20,
76 TIROp::Lt => 21,
77 TIROp::And => 22,
78 TIROp::Or => 23,
79 TIROp::Xor => 24,
80 TIROp::PopCount => 25,
81 TIROp::Split => 26,
82 TIROp::DivMod => 27,
83 TIROp::Shl => 28,
84 TIROp::Shr => 29,
85 TIROp::Log2 => 30,
86 TIROp::Pow => 31,
87 TIROp::ReadIo(_) => 32,
88 TIROp::WriteIo(_) => 33,
89 TIROp::ReadMem(_) => 34,
90 TIROp::WriteMem(_) => 35,
91 TIROp::Assert(_) => 36,
92 TIROp::Hash { .. } => 37,
93 TIROp::Reveal { .. } => 38,
94 TIROp::Seal { .. } => 39,
95 TIROp::RamRead { .. } => 40,
96 TIROp::RamWrite { .. } => 41,
97 TIROp::Hint(_) => 42,
99 TIROp::SpongeInit => 43,
100 TIROp::SpongeAbsorb => 44,
101 TIROp::SpongeSqueeze => 45,
102 TIROp::SpongeLoad => 46,
103 TIROp::MerkleStep => 47,
104 TIROp::MerkleLoad => 48,
105 TIROp::ExtMul => 49,
107 TIROp::ExtInvert => 50,
108 TIROp::FoldExt => 51,
109 TIROp::FoldBase => 52,
110 TIROp::ProofBlock { .. } => 53,
111 }
112}
113
114fn immediate(op: &TIROp) -> u64 {
116 match op {
117 TIROp::Push(v) => *v,
118 TIROp::Pop(n) | TIROp::Dup(n) | TIROp::Swap(n) => *n as u64,
119 TIROp::ReadIo(n) | TIROp::WriteIo(n) => *n as u64,
120 TIROp::ReadMem(n) | TIROp::WriteMem(n) => *n as u64,
121 TIROp::Assert(n) => *n as u64,
122 TIROp::Hint(n) => *n as u64,
123 TIROp::Hash { width } => *width as u64,
124 TIROp::RamRead { width } | TIROp::RamWrite { width } => *width as u64,
125 TIROp::Asm { effect, .. } => *effect as u64,
126 _ => 0,
127 }
128}
129
130fn is_block_boundary(op: &TIROp) -> bool {
132 matches!(
133 op,
134 TIROp::Call(_)
135 | TIROp::Return
136 | TIROp::Halt
137 | TIROp::IfElse { .. }
138 | TIROp::IfOnly { .. }
139 | TIROp::Loop { .. }
140 | TIROp::FnStart(_)
141 | TIROp::FnEnd
142 | TIROp::Entry(_)
143 )
144}
145
146fn encode_node(op: &TIROp, index: usize) -> [u64; WORDS_PER_NODE] {
153 let opc = opcode(op) as u64;
154 let imm = immediate(op);
155 [
156 opc, index as u64, imm, 0, ]
161}
162
163pub fn encode_blocks(ops: &[TIROp]) -> Vec<TIRBlock> {
169 let mut blocks = Vec::new();
170 let mut current_fn = String::new();
171 let mut block_ops: Vec<(usize, &TIROp)> = Vec::new();
172 let mut block_start = 0;
173
174 for (i, op) in ops.iter().enumerate() {
175 if let TIROp::FnStart(name) = op {
177 if !block_ops.is_empty() {
179 blocks.push(build_block(&block_ops, ¤t_fn, block_start));
180 block_ops.clear();
181 }
182 current_fn = name.clone();
183 block_start = i + 1;
184 continue;
185 }
186
187 if matches!(op, TIROp::FnEnd | TIROp::Entry(_) | TIROp::Comment(_)) {
189 continue;
190 }
191
192 if is_block_boundary(op) {
194 if !block_ops.is_empty() {
195 blocks.push(build_block(&block_ops, ¤t_fn, block_start));
196 block_ops.clear();
197 }
198 block_start = i + 1;
199 continue;
200 }
201
202 block_ops.push((i, op));
203
204 if block_ops.len() >= MAX_NODES {
206 blocks.push(build_block(&block_ops, ¤t_fn, block_start));
207 block_start = i + 1;
208 block_ops.clear();
209 }
210 }
211
212 if !block_ops.is_empty() {
214 blocks.push(build_block(&block_ops, ¤t_fn, block_start));
215 }
216
217 blocks
218}
219
220fn build_block(ops: &[(usize, &TIROp)], fn_name: &str, start_idx: usize) -> TIRBlock {
221 let mut nodes = [0u64; MAX_NODES * WORDS_PER_NODE];
222 let node_count = ops.len().min(MAX_NODES);
223 let end_idx = ops.last().map(|(i, _)| i + 1).unwrap_or(start_idx);
224
225 for (local_idx, (_global_idx, op)) in ops.iter().enumerate().take(MAX_NODES) {
226 let encoded = encode_node(op, local_idx);
227 let base = local_idx * WORDS_PER_NODE;
228 nodes[base..base + WORDS_PER_NODE].copy_from_slice(&encoded);
229 }
230
231 TIRBlock {
232 nodes,
233 context: [0; CONTEXT_SIZE],
234 node_count,
235 fn_name: fn_name.to_string(),
236 start_idx,
237 end_idx,
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `TIROp` variant maps to its own opcode in `0..=53`.
    #[test]
    fn opcode_coverage() {
        let ops = vec![
            TIROp::Call("f".into()),
            TIROp::Return,
            TIROp::Halt,
            TIROp::IfElse {
                then_body: vec![],
                else_body: vec![],
            },
            TIROp::IfOnly { then_body: vec![] },
            TIROp::Loop {
                label: "l".into(),
                body: vec![],
            },
            TIROp::FnStart("f".into()),
            TIROp::FnEnd,
            TIROp::Entry("m".into()),
            TIROp::Comment("c".into()),
            TIROp::Asm {
                lines: vec![],
                effect: 0,
            },
            TIROp::Push(0),
            TIROp::Pop(1),
            TIROp::Dup(0),
            TIROp::Swap(1),
            TIROp::Add,
            TIROp::Sub,
            TIROp::Mul,
            TIROp::Neg,
            TIROp::Invert,
            TIROp::Eq,
            TIROp::Lt,
            TIROp::And,
            TIROp::Or,
            TIROp::Xor,
            TIROp::PopCount,
            TIROp::Split,
            TIROp::DivMod,
            TIROp::Shl,
            TIROp::Shr,
            TIROp::Log2,
            TIROp::Pow,
            TIROp::ReadIo(1),
            TIROp::WriteIo(1),
            TIROp::ReadMem(1),
            TIROp::WriteMem(1),
            TIROp::Assert(1),
            TIROp::Hash { width: 0 },
            TIROp::Reveal {
                name: "e".into(),
                tag: 0,
                field_count: 1,
            },
            TIROp::Seal {
                name: "e".into(),
                tag: 0,
                field_count: 1,
            },
            TIROp::RamRead { width: 1 },
            TIROp::RamWrite { width: 1 },
            TIROp::Hint(1),
            TIROp::SpongeInit,
            TIROp::SpongeAbsorb,
            TIROp::SpongeSqueeze,
            TIROp::SpongeLoad,
            TIROp::MerkleStep,
            TIROp::MerkleLoad,
            TIROp::ExtMul,
            TIROp::ExtInvert,
            TIROp::FoldExt,
            TIROp::FoldBase,
            TIROp::ProofBlock {
                program_hash: "h".into(),
                body: vec![],
            },
        ];
        let mut seen = std::collections::HashSet::new();
        for op in &ops {
            let code = opcode(op);
            assert!(code <= 53, "opcode {} out of range for {:?}", code, op);
            seen.insert(code);
        }
        assert_eq!(
            seen.len(),
            54,
            "expected 54 distinct opcodes, got {}",
            seen.len()
        );
    }

    /// A straight-line function becomes one block. The trailing `Return` closes
    /// it and is not encoded.
    #[test]
    fn encode_simple_block() {
        let ops = vec![
            TIROp::FnStart("main".into()),
            TIROp::Push(42),
            TIROp::Push(10),
            TIROp::Add,
            TIROp::WriteIo(1),
            TIROp::Return,
        ];
        let blocks = encode_blocks(&ops);
        assert_eq!(blocks.len(), 1);
        let block = &blocks[0];
        assert_eq!(block.node_count, 4);
        assert_eq!(block.fn_name, "main");
        assert_eq!((block.start_idx, block.end_idx), (1, 5));
        // Each node is [opcode, local index, immediate, 0].
        assert_eq!(
            &block.nodes[..16],
            &[11u64, 0, 42, 0, 11, 1, 10, 0, 15, 2, 0, 0, 33, 3, 1, 0][..]
        );
        assert!(block.nodes[16..].iter().all(|&w| w == 0));
    }

    /// A `Call` splits the surrounding ops into two blocks and is itself dropped.
    #[test]
    fn block_split_at_control_flow() {
        let ops = vec![
            TIROp::FnStart("main".into()),
            TIROp::Push(1),
            TIROp::Push(2),
            TIROp::Call("helper".into()),
            TIROp::Push(3),
            TIROp::Add,
            TIROp::Return,
        ];
        let blocks = encode_blocks(&ops);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].node_count, 2);
        assert_eq!(blocks[1].node_count, 2);
        assert_eq!(blocks[0].block_id(), "main:1..3");
        assert_eq!(blocks[1].block_id(), "main:4..6");
        // Local indices restart at 0 in every block.
        assert_eq!(blocks[1].nodes[1], 0);
        assert_eq!(blocks[1].nodes[5], 1);
    }

    /// Long straight-line runs are cut every `MAX_NODES` ops.
    #[test]
    fn block_split_at_max_nodes() {
        let mut ops = vec![TIROp::FnStart("big".into())];
        for i in 0..40 {
            ops.push(TIROp::Push(i));
        }
        let blocks = encode_blocks(&ops);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].node_count, 32);
        assert_eq!(blocks[1].node_count, 8);
        assert_eq!(blocks[0].end_idx, 33);
        assert_eq!(blocks[1].start_idx, 33);
        assert_eq!(blocks[1].end_idx, 41);
        // First node of the second block is Push(32), at local index 0.
        assert_eq!(&blocks[1].nodes[..4], &[11u64, 0, 32, 0][..]);
    }

    #[test]
    fn empty_ops() {
        let blocks = encode_blocks(&[]);
        assert!(blocks.is_empty());
    }

    /// `Comment` ops are neither encoded nor block boundaries.
    #[test]
    fn comments_are_skipped() {
        let ops = vec![
            TIROp::FnStart("f".into()),
            TIROp::Comment("note".into()),
            TIROp::Push(7),
            TIROp::Comment("tail".into()),
        ];
        let blocks = encode_blocks(&ops);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].node_count, 1);
        assert_eq!(&blocks[0].nodes[..8], &[11u64, 0, 7, 0, 0, 0, 0, 0][..]);
    }

    /// Each function's ops are attributed to that function's name.
    #[test]
    fn blocks_are_attributed_to_their_function() {
        let ops = vec![
            TIROp::FnStart("a".into()),
            TIROp::Push(1),
            TIROp::FnEnd,
            TIROp::FnStart("b".into()),
            TIROp::Push(2),
            TIROp::FnEnd,
        ];
        let blocks = encode_blocks(&ops);
        let ids: Vec<String> = blocks.iter().map(TIRBlock::block_id).collect();
        assert_eq!(ids, vec!["a:1..2".to_string(), "b:4..5".to_string()]);
    }

    /// `as_input` is the node words followed by the (zero) context words.
    #[test]
    fn input_dimension() {
        let ops = vec![
            TIROp::FnStart("f".into()),
            TIROp::Push(1),
            TIROp::Push(2),
            TIROp::Add,
        ];
        let blocks = encode_blocks(&ops);
        let input = blocks[0].as_input();
        assert_eq!(input.len(), INPUT_DIM);
        assert_eq!(
            &input[..12],
            &[11u64, 0, 1, 0, 11, 1, 2, 0, 15, 2, 0, 0][..]
        );
        assert!(input[12..].iter().all(|&w| w == 0));
    }

    #[test]
    fn block_id_format() {
        let ops = vec![TIROp::FnStart("main".into()), TIROp::Push(1), TIROp::Add];
        let blocks = encode_blocks(&ops);
        assert_eq!(blocks[0].block_id(), "main:1..3");
    }
}