use chips::poseidon2_skinny::WIDTH;
use core::fmt::Debug;
use instruction::{
    FieldEltType, HintAddCurveInstr, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr,
};
use itertools::Itertools;
use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField64, TwoAdicField};
use sp1_recursion_core::{
    air::{Block, RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS},
    BaseAluInstr, BaseAluOpcode,
};
use sp1_stark::septic_curve::SepticCurve;
use std::{
    borrow::{Borrow, Cow},
    collections::HashMap,
    mem::transmute,
};
use vec_map::VecMap;

use sp1_recursion_core::*;

use crate::prelude::*;

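/// Compiles the recursion DSL's IR into a flat recursion program, mapping each virtual
/// variable to a "physical" memory address and tracking the read count ("mult") of every
/// address so it can be backfilled into the emitted instructions.
///
/// A compile-and-validate sketch, mirroring the tests at the bottom of this file (the
/// `AsmConfig` alias and the builder come from `crate::circuit`):
/// ```ignore
/// let mut compiler = AsmCompiler::<AsmConfig<F, EF>>::default();
/// let program = compiler.compile_inner(builder.into_root_block()).validate().unwrap();
/// ```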
#[derive(Debug, Clone, Default)]
pub struct AsmCompiler<C: Config> {
    /// The address of the next memory cell to be allocated.
    pub next_addr: C::F,
    /// Map the virtual address of a variable to its physical address.
    pub virtual_to_physical: VecMap<Address<C::F>>,
    /// Map a constant to its physical address and the number of times it is read.
    pub consts: HashMap<Imm<C::F, C::EF>, (Address<C::F>, C::F)>,
    /// Map a physical address to its multiplicity (number of reads).
    pub addr_to_mult: VecMap<C::F>,
}

impl<C: Config> AsmCompiler<C>
where
    C::F: PrimeField64,
{
    /// Allocate a fresh address. Panics if the increment wraps the address space back
    /// around to zero.
    pub fn alloc(next_addr: &mut C::F) -> Address<C::F> {
        let id = Address(*next_addr);
        *next_addr += C::F::one();
        if next_addr.is_zero() {
            panic!("out of address space");
        }
        id
    }

    /// Map `vaddr` to its physical address without counting the read toward the address's
    /// multiplicity. Panics if `vaddr` has not been assigned an address.
    pub fn read_ghost_vaddr(&mut self, vaddr: usize) -> Address<C::F> {
        self.read_vaddr_internal(vaddr, false)
    }

    /// Map `vaddr` to its physical address, incrementing the address's multiplicity.
    /// Panics if `vaddr` has not been assigned an address.
    pub fn read_vaddr(&mut self, vaddr: usize) -> Address<C::F> {
        self.read_vaddr_internal(vaddr, true)
    }

    pub fn read_vaddr_internal(&mut self, vaddr: usize, increment_mult: bool) -> Address<C::F> {
        use vec_map::Entry;
        match self.virtual_to_physical.entry(vaddr) {
            Entry::Vacant(_) => panic!("expected entry: virtual_physical[{:?}]", vaddr),
            Entry::Occupied(entry) => {
                if increment_mult {
                    // This is a real read, so we increment the mult.
                    match self.addr_to_mult.get_mut(entry.get().as_usize()) {
                        Some(mult) => *mult += C::F::one(),
                        None => panic!("expected entry: addr_to_mult[{:?}]", entry.get().as_usize()),
                    }
                }
                *entry.into_mut()
            }
        }
    }

    /// Associate a fresh physical address with `vaddr`, initializing its multiplicity to
    /// zero. Panics if `vaddr` has already been assigned an address.
    pub fn write_fp(&mut self, vaddr: usize) -> Address<C::F> {
        use vec_map::Entry;
        match self.virtual_to_physical.entry(vaddr) {
            Entry::Vacant(entry) => {
                let addr = Self::alloc(&mut self.next_addr);
                // This is a write, so we initialize the mult to zero.
                if let Some(x) = self.addr_to_mult.insert(addr.as_usize(), C::F::zero()) {
                    panic!("unexpected entry in addr_to_mult: {x:?}");
                }
                *entry.insert(addr)
            }
            Entry::Occupied(entry) => {
                panic!("unexpected entry: virtual_to_physical[{:?}] = {:?}", vaddr, entry.get())
            }
        }
    }

    /// Increment the multiplicity of `addr`, returning a reference to the count.
    /// Panics if `addr` has not been written to.
    pub fn read_addr(&mut self, addr: Address<C::F>) -> &mut C::F {
        self.read_addr_internal(addr, true)
    }

    /// Return a reference to the multiplicity of `addr` without incrementing it.
    /// Panics if `addr` has not been written to.
    pub fn read_ghost_addr(&mut self, addr: Address<C::F>) -> &mut C::F {
        self.read_addr_internal(addr, false)
    }

    fn read_addr_internal(&mut self, addr: Address<C::F>, increment_mult: bool) -> &mut C::F {
        use vec_map::Entry;
        match self.addr_to_mult.entry(addr.as_usize()) {
            Entry::Vacant(_) => panic!("expected entry: addr_to_mult[{:?}]", addr.as_usize()),
            Entry::Occupied(entry) => {
                let mult = entry.into_mut();
                // Increment the mult only for a real (non-ghost) read.
                if increment_mult {
                    *mult += C::F::one();
                }
                mult
            }
        }
    }

    /// Initialize the multiplicity of `addr` to zero, returning a reference to the count.
    /// Panics if `addr` has already been written to.
    pub fn write_addr(&mut self, addr: Address<C::F>) -> &mut C::F {
        use vec_map::Entry;
        match self.addr_to_mult.entry(addr.as_usize()) {
            Entry::Vacant(entry) => entry.insert(C::F::zero()),
            Entry::Occupied(entry) => {
                panic!("unexpected entry: addr_to_mult[{:?}] = {:?}", addr.as_usize(), entry.get())
            }
        }
    }

    /// Read the constant `imm`, incrementing its multiplicity and allocating an address
    /// for it on first use.
    pub fn read_const(&mut self, imm: Imm<C::F, C::EF>) -> Address<C::F> {
        self.consts
            .entry(imm)
            .and_modify(|(_, x)| *x += C::F::one())
            .or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::one()))
            .0
    }

    /// Read the constant `imm` without incrementing its multiplicity, allocating an
    /// address for it on first use.
    pub fn read_ghost_const(&mut self, imm: Imm<C::F, C::EF>) -> Address<C::F> {
        self.consts.entry(imm).or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::zero())).0
    }

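    /// Write the immediate `src` into the address of `dst`. The `mult` recorded here is a
    /// placeholder zero; the true read count is backfilled later (see `backfill_all`).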
    fn mem_write_const(&mut self, dst: impl Reg<C>, src: Imm<C::F, C::EF>) -> Instruction<C::F> {
        Instruction::Mem(MemInstr {
            addrs: MemIo { inner: dst.write(self) },
            vals: MemIo { inner: src.as_block() },
            mult: C::F::zero(),
            kind: MemAccessKind::Write,
        })
    }

    fn base_alu(
        &mut self,
        opcode: BaseAluOpcode,
        dst: impl Reg<C>,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
    ) -> Instruction<C::F> {
        Instruction::BaseAlu(BaseAluInstr {
            opcode,
            mult: C::F::zero(),
            addrs: BaseAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) },
        })
    }

    fn ext_alu(
        &mut self,
        opcode: ExtAluOpcode,
        dst: impl Reg<C>,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
    ) -> Instruction<C::F> {
        Instruction::ExtAlu(ExtAluInstr {
            opcode,
            mult: C::F::zero(),
            addrs: ExtAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) },
        })
    }

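    /// Assert that `lhs == rhs` in the base field.
    ///
    /// Emits `diff = lhs - rhs` followed by `out = diff / 0`. Assuming the VM constrains
    /// division multiplicatively (`out * in2 == in1`), the second instruction is
    /// satisfiable only when `diff == 0`, i.e. when `lhs == rhs`.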
    fn base_assert_eq(
        &mut self,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
        mut f: impl FnMut(Instruction<C::F>),
    ) {
        use BaseAluOpcode::*;
        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
        f(self.base_alu(SubF, diff, lhs, rhs));
        f(self.base_alu(DivF, out, diff, Imm::F(C::F::zero())));
    }

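    /// Assert that `lhs != rhs` in the base field.
    ///
    /// Emits `diff = lhs - rhs` followed by `out = 1 / diff`; the quotient exists
    /// (i.e. `out * diff == 1` is satisfiable) only when `diff` is nonzero.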
    fn base_assert_ne(
        &mut self,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
        mut f: impl FnMut(Instruction<C::F>),
    ) {
        use BaseAluOpcode::*;
        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));

        f(self.base_alu(SubF, diff, lhs, rhs));
        f(self.base_alu(DivF, out, Imm::F(C::F::one()), diff));
    }

    /// Extension-field analogue of `base_assert_eq`.
    fn ext_assert_eq(
        &mut self,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
        mut f: impl FnMut(Instruction<C::F>),
    ) {
        use ExtAluOpcode::*;
        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));

        f(self.ext_alu(SubE, diff, lhs, rhs));
        f(self.ext_alu(DivE, out, diff, Imm::EF(C::EF::zero())));
    }

    /// Extension-field analogue of `base_assert_ne`.
    fn ext_assert_ne(
        &mut self,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
        mut f: impl FnMut(Instruction<C::F>),
    ) {
        use ExtAluOpcode::*;
        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));

        f(self.ext_alu(SubE, diff, lhs, rhs));
        f(self.ext_alu(DivE, out, Imm::EF(C::EF::one()), diff));
    }

    #[inline(always)]
    fn poseidon2_permute(
        &mut self,
        dst: [impl Reg<C>; WIDTH],
        src: [impl Reg<C>; WIDTH],
    ) -> Instruction<C::F> {
        Instruction::Poseidon2(Box::new(Poseidon2Instr {
            addrs: Poseidon2Io {
                input: src.map(|r| r.read(self)),
                output: dst.map(|r| r.write(self)),
            },
            mults: [C::F::zero(); WIDTH],
        }))
    }

    #[inline(always)]
    fn select(
        &mut self,
        bit: impl Reg<C>,
        dst1: impl Reg<C>,
        dst2: impl Reg<C>,
        lhs: impl Reg<C>,
        rhs: impl Reg<C>,
    ) -> Instruction<C::F> {
        Instruction::Select(SelectInstr {
            addrs: SelectIo {
                bit: bit.read(self),
                out1: dst1.write(self),
                out2: dst2.write(self),
                in1: lhs.read(self),
                in2: rhs.read(self),
            },
            mult1: C::F::zero(),
            mult2: C::F::zero(),
        })
    }

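    /// Compute `base` raised to the bit-reversal of the exponent whose bits are given by
    /// `exp`, least significant bit first. For example, for bits `[b0, b1, b2]` the result
    /// is `base^(4*b0 + 2*b1 + b2)`, matching the reference product in
    /// `test_exp_reverse_bits` below.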
    fn exp_reverse_bits(
        &mut self,
        dst: impl Reg<C>,
        base: impl Reg<C>,
        exp: impl IntoIterator<Item = impl Reg<C>>,
    ) -> Instruction<C::F> {
        Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
            addrs: ExpReverseBitsIo {
                result: dst.write(self),
                base: base.read(self),
                exp: exp.into_iter().map(|r| r.read(self)).collect(),
            },
            mult: C::F::zero(),
        })
    }

    fn hint_bit_decomposition(
        &mut self,
        value: impl Reg<C>,
        output: impl IntoIterator<Item = impl Reg<C>>,
    ) -> Instruction<C::F> {
        Instruction::HintBits(HintBitsInstr {
            output_addrs_mults: output.into_iter().map(|r| (r.write(self), C::F::zero())).collect(),
            input_addr: value.read_ghost(self),
        })
    }

    fn add_curve(
        &mut self,
        output: SepticCurve<Felt<C::F>>,
        input1: SepticCurve<Felt<C::F>>,
        input2: SepticCurve<Felt<C::F>>,
    ) -> Instruction<C::F> {
        Instruction::HintAddCurve(Box::new(HintAddCurveInstr {
            output_x_addrs_mults: output
                .x
                .0
                .into_iter()
                .map(|r| (r.write(self), C::F::zero()))
                .collect(),
            output_y_addrs_mults: output
                .y
                .0
                .into_iter()
                .map(|r| (r.write(self), C::F::zero()))
                .collect(),
            input1_x_addrs: input1.x.0.into_iter().map(|value| value.read_ghost(self)).collect(),
            input1_y_addrs: input1.y.0.into_iter().map(|value| value.read_ghost(self)).collect(),
            input2_x_addrs: input2.x.0.into_iter().map(|value| value.read_ghost(self)).collect(),
            input2_y_addrs: input2.y.0.into_iter().map(|value| value.read_ghost(self)).collect(),
        }))
    }

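    /// Emit a FRI fold instruction. Per the reference computation in `test_fri_fold`
    /// below, for each `i` the machine computes
    /// `alpha_pow_output[i] = alpha_pow_input[i] * alpha` and
    /// `ro_output[i] = ro_input[i] + alpha_pow_input[i] * (mat_opening[i] - ps_at_z[i]) / (x - z)`.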
    fn fri_fold(
        &mut self,
        CircuitV2FriFoldOutput { alpha_pow_output, ro_output }: CircuitV2FriFoldOutput<C>,
        CircuitV2FriFoldInput {
            z,
            alpha,
            x,
            mat_opening,
            ps_at_z,
            alpha_pow_input,
            ro_input,
        }: CircuitV2FriFoldInput<C>,
    ) -> Instruction<C::F> {
        Instruction::FriFold(Box::new(FriFoldInstr {
            // Output mults are placeholders here; they are backfilled after compilation.
            alpha_pow_mults: vec![C::F::zero(); alpha_pow_output.len()],
            ro_mults: vec![C::F::zero(); ro_output.len()],

            base_single_addrs: FriFoldBaseIo { x: x.read(self) },
            ext_single_addrs: FriFoldExtSingleIo { z: z.read(self), alpha: alpha.read(self) },
            ext_vec_addrs: FriFoldExtVecIo {
                mat_opening: mat_opening.into_iter().map(|e| e.read(self)).collect(),
                ps_at_z: ps_at_z.into_iter().map(|e| e.read(self)).collect(),
                alpha_pow_input: alpha_pow_input.into_iter().map(|e| e.read(self)).collect(),
                ro_input: ro_input.into_iter().map(|e| e.read(self)).collect(),
                alpha_pow_output: alpha_pow_output.into_iter().map(|e| e.write(self)).collect(),
                ro_output: ro_output.into_iter().map(|e| e.write(self)).collect(),
            },
        }))
    }

    fn batch_fri(
        &mut self,
        acc: Ext<C::F, C::EF>,
        alpha_pows: Vec<Ext<C::F, C::EF>>,
        p_at_zs: Vec<Ext<C::F, C::EF>>,
        p_at_xs: Vec<Felt<C::F>>,
    ) -> Instruction<C::F> {
        Instruction::BatchFRI(Box::new(BatchFRIInstr {
            base_vec_addrs: BatchFRIBaseVecIo {
                p_at_x: p_at_xs.into_iter().map(|e| e.read(self)).collect(),
            },
            ext_single_addrs: BatchFRIExtSingleIo { acc: acc.write(self) },
            ext_vec_addrs: BatchFRIExtVecIo {
                p_at_z: p_at_zs.into_iter().map(|e| e.read(self)).collect(),
                alpha_pow: alpha_pows.into_iter().map(|e| e.read(self)).collect(),
            },
            acc_mult: C::F::zero(),
        }))
    }

    fn commit_public_values(
        &mut self,
        public_values: &RecursionPublicValues<Felt<C::F>>,
    ) -> Instruction<C::F> {
        public_values.digest.iter().for_each(|x| {
            let _ = x.read(self);
        });
        // Reinterpret the public values struct as its flat representation: an array of
        // exactly `RECURSIVE_PROOF_NUM_PV_ELTS` felts.
        let pv_addrs =
            unsafe {
                transmute::<
                    RecursionPublicValues<Felt<C::F>>,
                    [Felt<C::F>; RECURSIVE_PROOF_NUM_PV_ELTS],
                >(*public_values)
            }
            .map(|pv| pv.read_ghost(self));

        let public_values_a: &RecursionPublicValues<Address<C::F>> = pv_addrs.as_slice().borrow();
        Instruction::CommitPublicValues(Box::new(CommitPublicValuesInstr {
            pv_addrs: *public_values_a,
        }))
    }

    fn print_f(&mut self, addr: impl Reg<C>) -> Instruction<C::F> {
        Instruction::Print(PrintInstr {
            field_elt_type: FieldEltType::Base,
            addr: addr.read_ghost(self),
        })
    }

    fn print_e(&mut self, addr: impl Reg<C>) -> Instruction<C::F> {
        Instruction::Print(PrintInstr {
            field_elt_type: FieldEltType::Extension,
            addr: addr.read_ghost(self),
        })
    }

    fn ext2felts(&mut self, felts: [impl Reg<C>; D], ext: impl Reg<C>) -> Instruction<C::F> {
        Instruction::HintExt2Felts(HintExt2FeltsInstr {
            output_addrs_mults: felts.map(|r| (r.write(self), C::F::zero())),
            input_addr: ext.read_ghost(self),
        })
    }

    fn hint(&mut self, output: impl Reg<C>, len: usize) -> Instruction<C::F> {
        let zero = C::F::zero();
        Instruction::Hint(HintInstr {
            output_addrs_mults: output
                .write_many(self, len)
                .into_iter()
                .map(|a| (a, zero))
                .collect(),
        })
    }
}

impl<C> AsmCompiler<C>
where
    C: Config<N = <C as Config>::F> + Debug,
    C::F: PrimeField64 + TwoAdicField,
{
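    /// Compile a single IR operation, passing each emitted machine instruction to
    /// `consumer` as `Ok(..)` and reporting non-instruction events (cycle tracker
    /// markers, unsupported IR) as `Err(CompileOneErr)`.
    ///
    /// A sketch of the consumer used by `compile_raw_program` below:
    /// ```ignore
    /// compiler.compile_one(op, |item| match item {
    ///     Ok(instr) => instrs.push(instr),
    ///     Err(CompileOneErr::CycleTrackerEnter(_) | CompileOneErr::CycleTrackerExit) => (),
    ///     Err(CompileOneErr::Unsupported(instr)) => panic!("unsupported instruction: {instr:?}"),
    /// });
    /// ```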
    pub fn compile_one(
        &mut self,
        ir_instr: DslIr<C>,
        mut consumer: impl FnMut(Result<Instruction<C::F>, CompileOneErr<C>>),
    ) {
        use BaseAluOpcode::*;
        use ExtAluOpcode::*;

        let mut f = |instr| consumer(Ok(instr));
        match ir_instr {
            DslIr::ImmV(dst, src) => f(self.mem_write_const(dst, Imm::F(src))),
            DslIr::ImmF(dst, src) => f(self.mem_write_const(dst, Imm::F(src))),
            DslIr::ImmE(dst, src) => f(self.mem_write_const(dst, Imm::EF(src))),

            DslIr::AddV(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, rhs)),
            DslIr::AddVI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::F(rhs))),
            DslIr::AddF(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, rhs)),
            DslIr::AddFI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::F(rhs))),
            DslIr::AddE(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, rhs)),
            DslIr::AddEI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::EF(rhs))),
            DslIr::AddEF(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, rhs)),
            DslIr::AddEFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::F(rhs))),
            DslIr::AddEFFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::EF(rhs))),

            DslIr::SubV(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, rhs)),
            DslIr::SubVI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::F(rhs))),
            DslIr::SubVIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::F(lhs), rhs)),
            DslIr::SubF(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, rhs)),
            DslIr::SubFI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::F(rhs))),
            DslIr::SubFIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::F(lhs), rhs)),
            DslIr::SubE(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, rhs)),
            DslIr::SubEI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::EF(rhs))),
            DslIr::SubEIN(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, Imm::EF(lhs), rhs)),
            DslIr::SubEFI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::F(rhs))),
            DslIr::SubEF(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, rhs)),

            DslIr::MulV(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, rhs)),
            DslIr::MulVI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::F(rhs))),
            DslIr::MulF(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, rhs)),
            DslIr::MulFI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::F(rhs))),
            DslIr::MulE(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, rhs)),
            DslIr::MulEI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::EF(rhs))),
            DslIr::MulEFI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::F(rhs))),
            DslIr::MulEF(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, rhs)),

            DslIr::DivF(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, rhs)),
            DslIr::DivFI(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, Imm::F(rhs))),
            DslIr::DivFIN(dst, lhs, rhs) => f(self.base_alu(DivF, dst, Imm::F(lhs), rhs)),
            DslIr::DivE(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, rhs)),
            DslIr::DivEI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::EF(rhs))),
            DslIr::DivEIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::EF(lhs), rhs)),
            DslIr::DivEFI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::F(rhs))),
            DslIr::DivEFIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::F(lhs), rhs)),
            DslIr::DivEF(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, rhs)),

            DslIr::NegV(dst, src) => f(self.base_alu(SubF, dst, Imm::F(C::F::zero()), src)),
            DslIr::NegF(dst, src) => f(self.base_alu(SubF, dst, Imm::F(C::F::zero()), src)),
            DslIr::NegE(dst, src) => f(self.ext_alu(SubE, dst, Imm::EF(C::EF::zero()), src)),
            DslIr::InvV(dst, src) => f(self.base_alu(DivF, dst, Imm::F(C::F::one()), src)),
            DslIr::InvF(dst, src) => f(self.base_alu(DivF, dst, Imm::F(C::F::one()), src)),
            DslIr::InvE(dst, src) => f(self.ext_alu(DivE, dst, Imm::F(C::F::one()), src)),

            DslIr::Select(bit, dst1, dst2, lhs, rhs) => f(self.select(bit, dst1, dst2, lhs, rhs)),

            DslIr::AssertEqV(lhs, rhs) => self.base_assert_eq(lhs, rhs, f),
            DslIr::AssertEqF(lhs, rhs) => self.base_assert_eq(lhs, rhs, f),
            DslIr::AssertEqE(lhs, rhs) => self.ext_assert_eq(lhs, rhs, f),
            DslIr::AssertEqVI(lhs, rhs) => self.base_assert_eq(lhs, Imm::F(rhs), f),
            DslIr::AssertEqFI(lhs, rhs) => self.base_assert_eq(lhs, Imm::F(rhs), f),
            DslIr::AssertEqEI(lhs, rhs) => self.ext_assert_eq(lhs, Imm::EF(rhs), f),

            DslIr::AssertNeV(lhs, rhs) => self.base_assert_ne(lhs, rhs, f),
            DslIr::AssertNeF(lhs, rhs) => self.base_assert_ne(lhs, rhs, f),
            DslIr::AssertNeE(lhs, rhs) => self.ext_assert_ne(lhs, rhs, f),
            DslIr::AssertNeVI(lhs, rhs) => self.base_assert_ne(lhs, Imm::F(rhs), f),
            DslIr::AssertNeFI(lhs, rhs) => self.base_assert_ne(lhs, Imm::F(rhs), f),
            DslIr::AssertNeEI(lhs, rhs) => self.ext_assert_ne(lhs, Imm::EF(rhs), f),

            DslIr::CircuitV2Poseidon2PermuteBabyBear(data) => {
                f(self.poseidon2_permute(data.0, data.1))
            }
            DslIr::CircuitV2ExpReverseBits(dst, base, exp) => {
                f(self.exp_reverse_bits(dst, base, exp))
            }
            DslIr::CircuitV2HintBitsF(output, value) => {
                f(self.hint_bit_decomposition(value, output))
            }
            DslIr::CircuitV2FriFold(data) => f(self.fri_fold(data.0, data.1)),
            DslIr::CircuitV2BatchFRI(data) => f(self.batch_fri(data.0, data.1, data.2, data.3)),
            DslIr::CircuitV2CommitPublicValues(public_values) => {
                f(self.commit_public_values(&public_values))
            }
            DslIr::CircuitV2HintAddCurve(data) => f(self.add_curve(data.0, data.1, data.2)),

            DslIr::Parallel(_) => {
                unreachable!("parallel case should have been handled by compile_raw_program")
            }

            DslIr::PrintV(dst) => f(self.print_f(dst)),
            DslIr::PrintF(dst) => f(self.print_f(dst)),
            DslIr::PrintE(dst) => f(self.print_e(dst)),
            #[cfg(feature = "debug")]
            DslIr::DebugBacktrace(trace) => f(Instruction::DebugBacktrace(trace)),
            DslIr::CircuitV2HintFelts(output, len) => f(self.hint(output, len)),
            DslIr::CircuitV2HintExts(output, len) => f(self.hint(output, len)),
            DslIr::CircuitExt2Felt(felts, ext) => f(self.ext2felts(felts, ext)),
            DslIr::CycleTrackerV2Enter(name) => {
                consumer(Err(CompileOneErr::CycleTrackerEnter(name)))
            }
            DslIr::CycleTrackerV2Exit => consumer(Err(CompileOneErr::CycleTrackerExit)),
            DslIr::ReduceE(_) => {}
            instr => consumer(Err(CompileOneErr::Unsupported(instr))),
        }
    }

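    /// Compile a block of IR into a structured program, recursing on `DslIr::Parallel`
    /// to build parallel sub-programs and gathering all other operations into sequential
    /// basic blocks. `instrs_prefix` seeds the output with preallocated blocks, e.g. the
    /// constant-initialization block reserved by `compile_inner`.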
    fn compile_raw_program(
        &mut self,
        block: DslIrBlock<C>,
        instrs_prefix: Vec<SeqBlock<Instruction<C::F>>>,
    ) -> RawProgram<Instruction<C::F>> {
        let mut seq_blocks = instrs_prefix;
        // Accumulate instructions into `maybe_bb` until a `Parallel` op forces a split.
        let mut maybe_bb: Option<BasicBlock<Instruction<C::F>>> = None;

        for op in block.ops {
            match op {
                DslIr::Parallel(par_blocks) => {
                    seq_blocks.extend(maybe_bb.take().map(SeqBlock::Basic));
                    seq_blocks.push(SeqBlock::Parallel(
                        par_blocks
                            .into_iter()
                            .map(|b| self.compile_raw_program(b, vec![]))
                            .collect(),
                    ))
                }
                op => {
                    let bb = maybe_bb.get_or_insert_with(Default::default);
                    self.compile_one(op, |item| match item {
                        Ok(instr) => bb.instrs.push(instr),
                        Err(
                            CompileOneErr::CycleTrackerEnter(_) | CompileOneErr::CycleTrackerExit,
                        ) => (),
                        Err(CompileOneErr::Unsupported(instr)) => {
                            panic!("unsupported instruction: {instr:?}")
                        }
                    });
                }
            }
        }

        seq_blocks.extend(maybe_bb.map(SeqBlock::Basic));

        RawProgram { seq_blocks }
    }

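    /// Backfill the final multiplicity of every written address into the instruction
    /// that writes it. Writes record a placeholder `mult` of zero during compilation
    /// while reads increment `addr_to_mult`; this pass moves each final count into the
    /// writing instruction, draining `addr_to_mult` in the process.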
    fn backfill_all<'a>(
        &mut self,
        instrs: impl Iterator<Item = &'a mut Instruction<<C as Config>::F>>,
    ) {
        let mut backfill = |(mult, addr): (&mut C::F, &Address<C::F>)| {
            *mult = self.addr_to_mult.remove(addr.as_usize()).unwrap()
        };

        for asm_instr in instrs {
            // Match exhaustively, so every new instruction variant must decide how its
            // output mults are backfilled.
            match asm_instr {
                Instruction::BaseAlu(BaseAluInstr {
                    mult,
                    addrs: BaseAluIo { out: ref addr, .. },
                    ..
                }) => backfill((mult, addr)),
                Instruction::ExtAlu(ExtAluInstr {
                    mult,
                    addrs: ExtAluIo { out: ref addr, .. },
                    ..
                }) => backfill((mult, addr)),
                Instruction::Mem(MemInstr {
                    addrs: MemIo { inner: ref addr },
                    mult,
                    kind: MemAccessKind::Write,
                    ..
                }) => backfill((mult, addr)),
                Instruction::Poseidon2(instr) => {
                    let Poseidon2SkinnyInstr {
                        addrs: Poseidon2Io { output: ref addrs, .. },
                        mults,
                    } = instr.as_mut();
                    mults.iter_mut().zip(addrs).for_each(&mut backfill);
                }
                Instruction::Select(SelectInstr {
                    addrs: SelectIo { out1: ref addr1, out2: ref addr2, .. },
                    mult1,
                    mult2,
                }) => {
                    backfill((mult1, addr1));
                    backfill((mult2, addr2));
                }
                Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
                    addrs: ExpReverseBitsIo { result: ref addr, .. },
                    mult,
                }) => backfill((mult, addr)),
                Instruction::HintBits(HintBitsInstr { output_addrs_mults, .. }) |
                Instruction::Hint(HintInstr { output_addrs_mults, .. }) => {
                    output_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                }
                Instruction::FriFold(instr) => {
                    let FriFoldInstr {
                        ext_vec_addrs: FriFoldExtVecIo { ref alpha_pow_output, ref ro_output, .. },
                        alpha_pow_mults,
                        ro_mults,
                        ..
                    } = instr.as_mut();
                    alpha_pow_mults.iter_mut().zip(alpha_pow_output).for_each(&mut backfill);
                    ro_mults.iter_mut().zip(ro_output).for_each(&mut backfill);
                }
                Instruction::BatchFRI(instr) => {
                    let BatchFRIInstr {
                        ext_single_addrs: BatchFRIExtSingleIo { ref acc },
                        acc_mult,
                        ..
                    } = instr.as_mut();
                    backfill((acc_mult, acc));
                }
                Instruction::HintExt2Felts(HintExt2FeltsInstr { output_addrs_mults, .. }) => {
                    output_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                }
                Instruction::HintAddCurve(instr) => {
                    let HintAddCurveInstr { output_x_addrs_mults, output_y_addrs_mults, .. } =
                        instr.as_mut();
                    output_x_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                    output_y_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                }
                Instruction::Mem(MemInstr { kind: MemAccessKind::Read, .. }) |
                Instruction::CommitPublicValues(_) |
                Instruction::Print(_) => (),
                #[cfg(feature = "debug")]
                Instruction::DebugBacktrace(_) => (),
            }
        }

        debug_assert!(self.addr_to_mult.is_empty());
    }

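    /// Compile a whole DSL program to a `RecursionProgram`, delegating to `compile_inner`.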
    pub fn compile(&mut self, program: DslIrProgram<C>) -> RecursionProgram<C::F> {
        // SAFETY (assumed): a `DslIrProgram` is well-formed by construction, so the
        // compiled output is expected to uphold `RecursionProgram`'s invariants without
        // re-validation.
        unsafe { RecursionProgram::new_unchecked(self.compile_inner(program.into_inner())) }
    }

    /// Compile a raw IR block into a `RootProgram`: lower the IR, backfill mults, and
    /// prepend the constant-initialization instructions.
    pub fn compile_inner(&mut self, root_block: DslIrBlock<C>) -> RootProgram<C::F> {
        // Reserve a basic block at the front for the constants, which are known only
        // after the rest of the program has been compiled.
        let mut program = tracing::debug_span!("compile raw program").in_scope(|| {
            self.compile_raw_program(root_block, vec![SeqBlock::Basic(BasicBlock::default())])
        });
        // Take the memory size before backfilling drains `addr_to_mult`.
        let total_memory = self.addr_to_mult.len() + self.consts.len();
        tracing::debug_span!("backfill mult").in_scope(|| self.backfill_all(program.iter_mut()));

        tracing::debug_span!("prepend constants").in_scope(|| {
            // Fill the reserved block with a write for each constant, in address order.
            let Some(SeqBlock::Basic(BasicBlock { instrs: instrs_consts })) =
                program.seq_blocks.first_mut()
            else {
                unreachable!()
            };
            instrs_consts.extend(self.consts.drain().sorted_by_key(|x| x.1 .0 .0).map(
                |(imm, (addr, mult))| {
                    Instruction::Mem(MemInstr {
                        addrs: MemIo { inner: addr },
                        vals: MemIo { inner: imm.as_block() },
                        mult,
                        kind: MemAccessKind::Write,
                    })
                },
            ));
        });

        RootProgram { inner: program, total_memory, shape: None }
    }
}

/// The non-instruction outcomes of compiling one IR operation.
#[derive(Debug, Clone)]
pub enum CompileOneErr<C: Config> {
    Unsupported(DslIr<C>),
    CycleTrackerEnter(Cow<'static, str>),
    CycleTrackerExit,
}

/// Immediate (i.e. constant) field element.
///
/// Required to distinguish a base and an extension field element at the type level,
/// since the IR's instructions do not always provide this information.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Imm<F, EF> {
    /// Base field element.
    F(F),
    /// Extension field element.
    EF(EF),
}

impl<F, EF> Imm<F, EF>
where
    F: AbstractField + Copy,
    EF: AbstractExtensionField<F>,
{
    /// Convert the immediate to a block of base field elements.
    pub fn as_block(&self) -> Block<F> {
        match self {
            Imm::F(f) => Block::from(*f),
            Imm::EF(ef) => ef.as_base_slice().into(),
        }
    }
}

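/// A location the compiler can read or write: a virtual register (`Var`, `Felt`, `Ext`),
/// an immediate (`Imm`), or a raw `Address`. All access goes through the compiler so that
/// address assignment and mult bookkeeping stay consistent.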
trait Reg<C: Config> {
    /// Mark the register as to be read from, returning the "physical" address.
    fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F>;

    /// Get the "physical" address of the register without affecting its mult.
    fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F>;

    /// Mark the register as to be written to, returning the "physical" address.
    fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F>;

    /// Mark `len` consecutive registers starting at `self` as to be written to,
    /// returning their "physical" addresses.
    fn write_many(&self, compiler: &mut AsmCompiler<C>, len: usize) -> Vec<Address<C::F>>;
}

macro_rules! impl_reg_borrowed {
    ($a:ty) => {
        impl<C, T> Reg<C> for $a
        where
            C: Config,
            T: Reg<C> + ?Sized,
        {
            fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                (**self).read(compiler)
            }

            fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                (**self).read_ghost(compiler)
            }

            fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                (**self).write(compiler)
            }

            fn write_many(&self, compiler: &mut AsmCompiler<C>, len: usize) -> Vec<Address<C::F>> {
                (**self).write_many(compiler, len)
            }
        }
    };
}

// Implement `Reg` for references and boxes by delegating to the underlying type.
impl_reg_borrowed!(&T);
impl_reg_borrowed!(&mut T);
impl_reg_borrowed!(Box<T>);

macro_rules! impl_reg_vaddr {
    ($a:ty) => {
        impl<C: Config<F: PrimeField64>> Reg<C> for $a {
            fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                compiler.read_vaddr(self.idx as usize)
            }
            fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                compiler.read_ghost_vaddr(self.idx as usize)
            }
            fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                compiler.write_fp(self.idx as usize)
            }

            fn write_many(&self, compiler: &mut AsmCompiler<C>, len: usize) -> Vec<Address<C::F>> {
                (0..len).map(|i| compiler.write_fp((self.idx + i as u32) as usize)).collect()
            }
        }
    };
}

// Implement `Reg` for the virtual-register types, using `idx` as the virtual address.
impl_reg_vaddr!(Var<C::F>);
impl_reg_vaddr!(Felt<C::F>);
impl_reg_vaddr!(Ext<C::F, C::EF>);

impl<C: Config<F: PrimeField64>> Reg<C> for Imm<C::F, C::EF> {
    fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
        compiler.read_const(*self)
    }

    fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
        compiler.read_ghost_const(*self)
    }

    fn write(&self, _compiler: &mut AsmCompiler<C>) -> Address<C::F> {
        panic!("cannot write to immediate in register: {self:?}")
    }

    fn write_many(&self, _compiler: &mut AsmCompiler<C>, _len: usize) -> Vec<Address<C::F>> {
        panic!("cannot write to immediate in register: {self:?}")
    }
}

impl<C: Config<F: PrimeField64>> Reg<C> for Address<C::F> {
    fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
        compiler.read_addr(*self);
        *self
    }

    fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
        compiler.read_ghost_addr(*self);
        *self
    }

    fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
        compiler.write_addr(*self);
        *self
    }

    fn write_many(&self, _compiler: &mut AsmCompiler<C>, _len: usize) -> Vec<Address<C::F>> {
        todo!()
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use std::{collections::VecDeque, io::BufRead, iter::zip, sync::Arc};

    use p3_baby_bear::DiffusionMatrixBabyBear;
    use p3_field::{Field, PrimeField32};
    use p3_symmetric::{CryptographicHasher, Permutation};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use sp1_core_machine::utils::{run_test_machine, setup_logger};
    use sp1_recursion_core::{machine::RecursionAir, Runtime};
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, inner_perm, BabyBearPoseidon2Inner, InnerHash,
        StarkGenericConfig,
    };

    use crate::circuit::{AsmBuilder, AsmConfig, CircuitV2Builder};

    use super::*;

    type SC = BabyBearPoseidon2;
    type F = <SC as StarkGenericConfig>::Val;
    type EF = <SC as StarkGenericConfig>::Challenge;
    fn test_block(block: DslIrBlock<AsmConfig<F, EF>>) {
        test_block_with_runner(block, |program| {
            let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(
                program,
                BabyBearPoseidon2Inner::new().perm,
            );
            runtime.run().unwrap();
            runtime.record
        });
    }

    fn test_block_with_runner(
        block: DslIrBlock<AsmConfig<F, EF>>,
        run: impl FnOnce(Arc<RecursionProgram<F>>) -> ExecutionRecord<F>,
    ) {
        let mut compiler = super::AsmCompiler::<AsmConfig<F, EF>>::default();
        let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
        let record = run(program.clone());

        // Run with the machine that uses the wide Poseidon2 chip.
        let wide_machine =
            RecursionAir::<_, 3>::machine_wide_with_all_chips(BabyBearPoseidon2::default());
        let (pk, vk) = wide_machine.setup(&program);
        let result = run_test_machine(vec![record.clone()], wide_machine, pk, vk);
        if let Err(e) = result {
            panic!("Verification failed: {:?}", e);
        }

        // Run with the machine that uses the skinny Poseidon2 chip.
        let skinny_machine = RecursionAir::<_, 9>::machine_skinny_with_all_chips(
            BabyBearPoseidon2::ultra_compressed(),
        );
        let (pk, vk) = skinny_machine.setup(&program);
        let result = run_test_machine(vec![record.clone()], skinny_machine, pk, vk);
        if let Err(e) = result {
            panic!("Verification failed: {:?}", e);
        }
    }

    #[test]
    fn test_poseidon2() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng = StdRng::seed_from_u64(0xCAFEDA7E)
            .sample_iter::<[F; WIDTH], _>(rand::distributions::Standard);
        for _ in 0..100 {
            let input_1: [F; WIDTH] = rng.next().unwrap();
            let output_1 = inner_perm().permute(input_1);

            let input_1_felts = input_1.map(|x| builder.eval(x));
            let output_1_felts = builder.poseidon2_permute_v2(input_1_felts);
            let expected: [Felt<_>; WIDTH] = output_1.map(|x| builder.eval(x));
            for (lhs, rhs) in output_1_felts.into_iter().zip(expected) {
                builder.assert_felt_eq(lhs, rhs);
            }
        }

        test_block(builder.into_root_block());
    }

    #[test]
    fn test_poseidon2_hash() {
        let perm = inner_perm();
        let hasher = InnerHash::new(perm.clone());

        let input: [F; 26] = [
            F::from_canonical_u32(0),
            F::from_canonical_u32(1),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
        ];
        let expected = hasher.hash_iter(input);
        println!("{:?}", expected);

        let mut builder = AsmBuilder::<F, EF>::default();
        let input_felts: [Felt<_>; 26] = input.map(|x| builder.eval(x));
        let result = builder.poseidon2_hash_v2(&input_felts);

        for (actual_f, expected_f) in zip(result, expected) {
            builder.assert_felt_eq(actual_f, expected_f);
        }

        test_block(builder.into_root_block());
    }

    #[test]
    fn test_exp_reverse_bits() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng =
            StdRng::seed_from_u64(0xEC0BEEF).sample_iter::<F, _>(rand::distributions::Standard);
        for _ in 0..100 {
            let power_f = rng.next().unwrap();
            let power = power_f.as_canonical_u32();
            let power_bits = (0..NUM_BITS).map(|i| (power >> i) & 1).collect::<Vec<_>>();

            let input_felt = builder.eval(power_f);
            let power_bits_felt = builder.num2bits_v2_f(input_felt, NUM_BITS);

            let base = rng.next().unwrap();
            let base_felt = builder.eval(base);
            let result_felt = builder.exp_reverse_bits_v2(base_felt, power_bits_felt);

            let expected = power_bits
                .into_iter()
                .rev()
                .zip(std::iter::successors(Some(base), |x| Some(x.square())))
                .map(|(bit, base_pow)| match bit {
                    0 => F::one(),
                    1 => base_pow,
                    _ => panic!("not a bit: {bit}"),
                })
                .product::<F>();
            let expected_felt: Felt<_> = builder.eval(expected);
            builder.assert_felt_eq(result_felt, expected_felt);
        }
        test_block(builder.into_root_block());
    }

    #[test]
    fn test_fri_fold() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();

        let mut rng = StdRng::seed_from_u64(0xFEB29).sample_iter(rand::distributions::Standard);
        let mut random_felt = move || -> F { rng.next().unwrap() };
        let mut rng =
            StdRng::seed_from_u64(0x0451).sample_iter::<[F; 4], _>(rand::distributions::Standard);
        let mut random_ext = move || EF::from_base_slice(&rng.next().unwrap());

        for i in 2..17 {
            let x = random_felt();
            let z = random_ext();
            let alpha = random_ext();

            let alpha_pow_input = (0..i).map(|_| random_ext()).collect::<Vec<_>>();
            let ro_input = (0..i).map(|_| random_ext()).collect::<Vec<_>>();

            let ps_at_z = (0..i).map(|_| random_ext()).collect::<Vec<_>>();
            let mat_opening = (0..i).map(|_| random_ext()).collect::<Vec<_>>();

            let alpha_pow_output = (0..i).map(|i| alpha_pow_input[i] * alpha).collect::<Vec<EF>>();
            let ro_output = (0..i)
                .map(|i| {
                    ro_input[i] + alpha_pow_input[i] * (-ps_at_z[i] + mat_opening[i]) / (-z + x)
                })
                .collect::<Vec<EF>>();

            let input_vars = CircuitV2FriFoldInput {
                z: builder.eval(z.cons()),
                alpha: builder.eval(alpha.cons()),
                x: builder.eval(x),
                mat_opening: mat_opening.iter().map(|e| builder.eval(e.cons())).collect(),
                ps_at_z: ps_at_z.iter().map(|e| builder.eval(e.cons())).collect(),
                alpha_pow_input: alpha_pow_input.iter().map(|e| builder.eval(e.cons())).collect(),
                ro_input: ro_input.iter().map(|e| builder.eval(e.cons())).collect(),
            };

            let output_vars = builder.fri_fold_v2(input_vars);
            for (lhs, rhs) in std::iter::zip(output_vars.alpha_pow_output, alpha_pow_output) {
                builder.assert_ext_eq(lhs, rhs.cons());
            }
            for (lhs, rhs) in std::iter::zip(output_vars.ro_output, ro_output) {
                builder.assert_ext_eq(lhs, rhs.cons());
            }
        }

        test_block(builder.into_root_block());
    }

    #[test]
    fn test_hint_bit_decomposition() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng =
            StdRng::seed_from_u64(0xC0FFEE7AB1E).sample_iter::<F, _>(rand::distributions::Standard);
        for _ in 0..100 {
            let input_f = rng.next().unwrap();
            let input = input_f.as_canonical_u32();
            let output = (0..NUM_BITS).map(|i| (input >> i) & 1).collect::<Vec<_>>();

            let input_felt = builder.eval(input_f);
            let output_felts = builder.num2bits_v2_f(input_felt, NUM_BITS);
            let expected: Vec<Felt<_>> =
                output.into_iter().map(|x| builder.eval(F::from_canonical_u32(x))).collect();
            for (lhs, rhs) in output_felts.into_iter().zip(expected) {
                builder.assert_felt_eq(lhs, rhs);
            }
        }
        test_block(builder.into_root_block());
    }

    #[test]
    fn test_print_and_cycle_tracker() {
        const ITERS: usize = 5;

        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();

        let input_fs = StdRng::seed_from_u64(0xC0FFEE7AB1E)
            .sample_iter::<F, _>(rand::distributions::Standard)
            .take(ITERS)
            .collect::<Vec<_>>();

        let input_efs = StdRng::seed_from_u64(0x7EA7AB1E)
            .sample_iter::<[F; 4], _>(rand::distributions::Standard)
            .take(ITERS)
            .collect::<Vec<_>>();

        let mut buf = VecDeque::<u8>::new();

        builder.cycle_tracker_v2_enter("printing felts");
        for (i, &input_f) in input_fs.iter().enumerate() {
            builder.cycle_tracker_v2_enter(format!("printing felt {i}"));
            let input_felt = builder.eval(input_f);
            builder.print_f(input_felt);
            builder.cycle_tracker_v2_exit();
        }
        builder.cycle_tracker_v2_exit();

        builder.cycle_tracker_v2_enter("printing exts");
        for (i, input_block) in input_efs.iter().enumerate() {
            builder.cycle_tracker_v2_enter(format!("printing ext {i}"));
            let input_ext = builder.eval(EF::from_base_slice(input_block).cons());
            builder.print_e(input_ext);
            builder.cycle_tracker_v2_exit();
        }
        builder.cycle_tracker_v2_exit();

        test_block_with_runner(builder.into_root_block(), |program| {
            let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(
                program,
                BabyBearPoseidon2Inner::new().perm,
            );
            runtime.debug_stdout = Box::new(&mut buf);
            runtime.run().unwrap();
            runtime.record
        });

        let input_str_fs = input_fs.into_iter().map(|elt| format!("{}", elt));
        let input_str_efs = input_efs.into_iter().map(|elt| format!("{:?}", elt));
        let input_strs = input_str_fs.chain(input_str_efs);

        for (input_str, line) in zip(input_strs, buf.lines()) {
            let line = line.unwrap();
            assert!(line.contains(&input_str));
        }
    }

    #[test]
    fn test_ext2felts() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng =
            StdRng::seed_from_u64(0x3264).sample_iter::<[F; 4], _>(rand::distributions::Standard);
        let mut random_ext = move || EF::from_base_slice(&rng.next().unwrap());
        for _ in 0..100 {
            let input = random_ext();
            let output: &[F] = input.as_base_slice();

            let input_ext = builder.eval(input.cons());
            let output_felts = builder.ext2felt_v2(input_ext);
            let expected: Vec<Felt<_>> = output.iter().map(|&x| builder.eval(x)).collect();
            for (lhs, rhs) in output_felts.into_iter().zip(expected) {
                builder.assert_felt_eq(lhs, rhs);
            }
        }
        test_block(builder.into_root_block());
    }

    // Each invocation builds 100 random sums `c = a + b` over both the base and extension
    // fields and applies the given assertion; `$should_offset` perturbs `c` by a nonzero
    // element so the assertion's failure behavior can also be exercised.
    macro_rules! test_assert_fixture {
        ($assert_felt:ident, $assert_ext:ident, $should_offset:literal) => {
            {
                use std::convert::identity;
                let mut builder = AsmBuilder::<F, EF>::default();
                test_assert_fixture!(builder, identity, F, Felt<_>, 0xDEADBEEF, $assert_felt, $should_offset);
                test_assert_fixture!(builder, EF::cons, EF, Ext<_, _>, 0xABADCAFE, $assert_ext, $should_offset);
                test_block(builder.into_root_block());
            }
        };
        ($builder:ident, $wrap:path, $t:ty, $u:ty, $seed:expr, $assert:ident, $should_offset:expr) => {
            {
                let mut elts = StdRng::seed_from_u64($seed)
                    .sample_iter::<$t, _>(rand::distributions::Standard);
                for _ in 0..100 {
                    let a = elts.next().unwrap();
                    let b = elts.next().unwrap();
                    let c = a + b;
                    let ar: $u = $builder.eval($wrap(a));
                    let br: $u = $builder.eval($wrap(b));
                    let cr: $u = $builder.eval(ar + br);
                    let cm = if $should_offset {
                        c + elts.find(|x| !x.is_zero()).unwrap()
                    } else {
                        c
                    };
                    $builder.$assert(cr, $wrap(cm));
                }
            }
        };
    }

    #[test]
    fn test_assert_eq_noop() {
        test_assert_fixture!(assert_felt_eq, assert_ext_eq, false);
    }

    #[test]
    #[should_panic]
    fn test_assert_eq_panics() {
        test_assert_fixture!(assert_felt_eq, assert_ext_eq, true);
    }

    #[test]
    fn test_assert_ne_noop() {
        test_assert_fixture!(assert_felt_ne, assert_ext_ne, true);
    }

    #[test]
    #[should_panic]
    fn test_assert_ne_panics() {
        test_assert_fixture!(assert_felt_ne, assert_ext_ne, false);
    }
}