sp1_recursion_compiler/circuit/compiler.rs

1use chips::poseidon2_skinny::WIDTH;
2use core::fmt::Debug;
3use instruction::{
4    FieldEltType, HintAddCurveInstr, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr,
5};
6use itertools::Itertools;
7use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField64, TwoAdicField};
8use sp1_recursion_core::{
9    air::{Block, RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS},
10    BaseAluInstr, BaseAluOpcode,
11};
12use sp1_stark::septic_curve::SepticCurve;
13use std::{
14    borrow::{Borrow, Cow},
15    collections::HashMap,
16    mem::transmute,
17};
18use vec_map::VecMap;
19
20use sp1_recursion_core::*;
21
22use crate::prelude::*;
23
/// The backend for the circuit compiler.
///
/// Translates `DslIr` operations into recursion-VM instructions, assigning "physical"
/// memory addresses and tracking how many times each address is read (its "mult").
/// The mults are backfilled into the emitted instructions after compilation.
#[derive(Debug, Clone, Default)]
pub struct AsmCompiler<C: Config> {
    // The next unallocated physical address; advanced by `alloc`.
    pub next_addr: C::F,
    /// Map the frame pointers of the variables to the "physical" addresses.
    pub virtual_to_physical: VecMap<Address<C::F>>,
    /// Map base or extension field constants to "physical" addresses and mults.
    pub consts: HashMap<Imm<C::F, C::EF>, (Address<C::F>, C::F)>,
    /// Map each "physical" address to its read count.
    pub addr_to_mult: VecMap<C::F>,
}
35
36impl<C: Config> AsmCompiler<C>
37where
38    C::F: PrimeField64,
39{
40    /// Allocate a fresh address. Checks that the address space is not full.
41    pub fn alloc(next_addr: &mut C::F) -> Address<C::F> {
42        let id = Address(*next_addr);
43        *next_addr += C::F::one();
44        if next_addr.is_zero() {
45            panic!("out of address space");
46        }
47        id
48    }
49
    /// Map `fp` to its existing address without changing its mult.
    ///
    /// A "ghost" read: the address is looked up but its read count stays the same.
    ///
    /// Ensures that `fp` has already been assigned an address.
    pub fn read_ghost_vaddr(&mut self, vaddr: usize) -> Address<C::F> {
        self.read_vaddr_internal(vaddr, false)
    }
56
    /// Map `fp` to its existing address and increment its mult.
    ///
    /// Counts as a real read: the address's read count is bumped by one.
    ///
    /// Ensures that `fp` has already been assigned an address.
    pub fn read_vaddr(&mut self, vaddr: usize) -> Address<C::F> {
        self.read_vaddr_internal(vaddr, true)
    }
63
64    pub fn read_vaddr_internal(&mut self, vaddr: usize, increment_mult: bool) -> Address<C::F> {
65        use vec_map::Entry;
66        match self.virtual_to_physical.entry(vaddr) {
67            Entry::Vacant(_) => panic!("expected entry: virtual_physical[{:?}]", vaddr),
68            Entry::Occupied(entry) => {
69                if increment_mult {
70                    // This is a read, so we increment the mult.
71                    match self.addr_to_mult.get_mut(entry.get().as_usize()) {
72                        Some(mult) => *mult += C::F::one(),
73                        None => panic!("expected entry: virtual_physical[{:?}]", vaddr),
74                    }
75                }
76                *entry.into_mut()
77            }
78        }
79    }
80
81    /// Map `fp` to a fresh address and initialize the mult to 0.
82    ///
83    /// Ensures that `fp` has not already been written to.
84    pub fn write_fp(&mut self, vaddr: usize) -> Address<C::F> {
85        use vec_map::Entry;
86        match self.virtual_to_physical.entry(vaddr) {
87            Entry::Vacant(entry) => {
88                let addr = Self::alloc(&mut self.next_addr);
89                // This is a write, so we set the mult to zero.
90                if let Some(x) = self.addr_to_mult.insert(addr.as_usize(), C::F::zero()) {
91                    panic!("unexpected entry in addr_to_mult: {x:?}");
92                }
93                *entry.insert(addr)
94            }
95            Entry::Occupied(entry) => {
96                panic!("unexpected entry: virtual_to_physical[{:?}] = {:?}", vaddr, entry.get())
97            }
98        }
99    }
100
    /// Increment the existing `mult` associated with `addr`.
    ///
    /// Returns a mutable reference to the (already incremented) mult.
    ///
    /// Ensures that `addr` has already been assigned a `mult`.
    pub fn read_addr(&mut self, addr: Address<C::F>) -> &mut C::F {
        self.read_addr_internal(addr, true)
    }
107
    /// Retrieves `mult` associated with `addr`.
    ///
    /// Ensures that `addr` has already been assigned a `mult`.
    ///
    /// NOTE(review): despite the "ghost" name and the doc above, this passes
    /// `increment_mult = true`, so it bumps the mult exactly like `read_addr` —
    /// unlike `read_ghost_vaddr`, which passes `false`. Confirm whether this is
    /// intentional before relying on non-incrementing ("ghost") semantics here.
    pub fn read_ghost_addr(&mut self, addr: Address<C::F>) -> &mut C::F {
        self.read_addr_internal(addr, true)
    }
114
115    fn read_addr_internal(&mut self, addr: Address<C::F>, increment_mult: bool) -> &mut C::F {
116        use vec_map::Entry;
117        match self.addr_to_mult.entry(addr.as_usize()) {
118            Entry::Vacant(_) => panic!("expected entry: addr_to_mult[{:?}]", addr.as_usize()),
119            Entry::Occupied(entry) => {
120                // This is a read, so we increment the mult.
121                let mult = entry.into_mut();
122                if increment_mult {
123                    *mult += C::F::one();
124                }
125                mult
126            }
127        }
128    }
129
130    /// Associate a `mult` of zero with `addr`.
131    ///
132    /// Ensures that `addr` has not already been written to.
133    pub fn write_addr(&mut self, addr: Address<C::F>) -> &mut C::F {
134        use vec_map::Entry;
135        match self.addr_to_mult.entry(addr.as_usize()) {
136            Entry::Vacant(entry) => entry.insert(C::F::zero()),
137            Entry::Occupied(entry) => {
138                panic!("unexpected entry: addr_to_mult[{:?}] = {:?}", addr.as_usize(), entry.get())
139            }
140        }
141    }
142
143    /// Read a constant (a.k.a. immediate).
144    ///
145    /// Increments the mult, first creating an entry if it does not yet exist.
146    pub fn read_const(&mut self, imm: Imm<C::F, C::EF>) -> Address<C::F> {
147        self.consts
148            .entry(imm)
149            .and_modify(|(_, x)| *x += C::F::one())
150            .or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::one()))
151            .0
152    }
153
154    /// Read a constant (a.k.a. immediate).
155    ///    
156    /// Does not increment the mult. Creates an entry if it does not yet exist.
157    pub fn read_ghost_const(&mut self, imm: Imm<C::F, C::EF>) -> Address<C::F> {
158        self.consts.entry(imm).or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::zero())).0
159    }
160
161    fn mem_write_const(&mut self, dst: impl Reg<C>, src: Imm<C::F, C::EF>) -> Instruction<C::F> {
162        Instruction::Mem(MemInstr {
163            addrs: MemIo { inner: dst.write(self) },
164            vals: MemIo { inner: src.as_block() },
165            mult: C::F::zero(),
166            kind: MemAccessKind::Write,
167        })
168    }
169
170    fn base_alu(
171        &mut self,
172        opcode: BaseAluOpcode,
173        dst: impl Reg<C>,
174        lhs: impl Reg<C>,
175        rhs: impl Reg<C>,
176    ) -> Instruction<C::F> {
177        Instruction::BaseAlu(BaseAluInstr {
178            opcode,
179            mult: C::F::zero(),
180            addrs: BaseAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) },
181        })
182    }
183
184    fn ext_alu(
185        &mut self,
186        opcode: ExtAluOpcode,
187        dst: impl Reg<C>,
188        lhs: impl Reg<C>,
189        rhs: impl Reg<C>,
190    ) -> Instruction<C::F> {
191        Instruction::ExtAlu(ExtAluInstr {
192            opcode,
193            mult: C::F::zero(),
194            addrs: ExtAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) },
195        })
196    }
197
198    fn base_assert_eq(
199        &mut self,
200        lhs: impl Reg<C>,
201        rhs: impl Reg<C>,
202        mut f: impl FnMut(Instruction<C::F>),
203    ) {
204        use BaseAluOpcode::*;
205        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
206        f(self.base_alu(SubF, diff, lhs, rhs));
207        f(self.base_alu(DivF, out, diff, Imm::F(C::F::zero())));
208    }
209
210    fn base_assert_ne(
211        &mut self,
212        lhs: impl Reg<C>,
213        rhs: impl Reg<C>,
214        mut f: impl FnMut(Instruction<C::F>),
215    ) {
216        use BaseAluOpcode::*;
217        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
218
219        f(self.base_alu(SubF, diff, lhs, rhs));
220        f(self.base_alu(DivF, out, Imm::F(C::F::one()), diff));
221    }
222
223    fn ext_assert_eq(
224        &mut self,
225        lhs: impl Reg<C>,
226        rhs: impl Reg<C>,
227        mut f: impl FnMut(Instruction<C::F>),
228    ) {
229        use ExtAluOpcode::*;
230        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
231
232        f(self.ext_alu(SubE, diff, lhs, rhs));
233        f(self.ext_alu(DivE, out, diff, Imm::EF(C::EF::zero())));
234    }
235
236    fn ext_assert_ne(
237        &mut self,
238        lhs: impl Reg<C>,
239        rhs: impl Reg<C>,
240        mut f: impl FnMut(Instruction<C::F>),
241    ) {
242        use ExtAluOpcode::*;
243        let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
244
245        f(self.ext_alu(SubE, diff, lhs, rhs));
246        f(self.ext_alu(DivE, out, Imm::EF(C::EF::one()), diff));
247    }
248
249    #[inline(always)]
250    fn poseidon2_permute(
251        &mut self,
252        dst: [impl Reg<C>; WIDTH],
253        src: [impl Reg<C>; WIDTH],
254    ) -> Instruction<C::F> {
255        Instruction::Poseidon2(Box::new(Poseidon2Instr {
256            addrs: Poseidon2Io {
257                input: src.map(|r| r.read(self)),
258                output: dst.map(|r| r.write(self)),
259            },
260            mults: [C::F::zero(); WIDTH],
261        }))
262    }
263
264    #[inline(always)]
265    fn select(
266        &mut self,
267        bit: impl Reg<C>,
268        dst1: impl Reg<C>,
269        dst2: impl Reg<C>,
270        lhs: impl Reg<C>,
271        rhs: impl Reg<C>,
272    ) -> Instruction<C::F> {
273        Instruction::Select(SelectInstr {
274            addrs: SelectIo {
275                bit: bit.read(self),
276                out1: dst1.write(self),
277                out2: dst2.write(self),
278                in1: lhs.read(self),
279                in2: rhs.read(self),
280            },
281            mult1: C::F::zero(),
282            mult2: C::F::zero(),
283        })
284    }
285
286    fn exp_reverse_bits(
287        &mut self,
288        dst: impl Reg<C>,
289        base: impl Reg<C>,
290        exp: impl IntoIterator<Item = impl Reg<C>>,
291    ) -> Instruction<C::F> {
292        Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
293            addrs: ExpReverseBitsIo {
294                result: dst.write(self),
295                base: base.read(self),
296                exp: exp.into_iter().map(|r| r.read(self)).collect(),
297            },
298            mult: C::F::zero(),
299        })
300    }
301
302    fn hint_bit_decomposition(
303        &mut self,
304        value: impl Reg<C>,
305        output: impl IntoIterator<Item = impl Reg<C>>,
306    ) -> Instruction<C::F> {
307        Instruction::HintBits(HintBitsInstr {
308            output_addrs_mults: output.into_iter().map(|r| (r.write(self), C::F::zero())).collect(),
309            input_addr: value.read_ghost(self),
310        })
311    }
312
313    fn add_curve(
314        &mut self,
315        output: SepticCurve<Felt<C::F>>,
316        input1: SepticCurve<Felt<C::F>>,
317        input2: SepticCurve<Felt<C::F>>,
318    ) -> Instruction<C::F> {
319        Instruction::HintAddCurve(Box::new(HintAddCurveInstr {
320            output_x_addrs_mults: output
321                .x
322                .0
323                .into_iter()
324                .map(|r| (r.write(self), C::F::zero()))
325                .collect(),
326            output_y_addrs_mults: output
327                .y
328                .0
329                .into_iter()
330                .map(|r| (r.write(self), C::F::zero()))
331                .collect(),
332            input1_x_addrs: input1.x.0.into_iter().map(|value| value.read_ghost(self)).collect(),
333            input1_y_addrs: input1.y.0.into_iter().map(|value| value.read_ghost(self)).collect(),
334            input2_x_addrs: input2.x.0.into_iter().map(|value| value.read_ghost(self)).collect(),
335            input2_y_addrs: input2.y.0.into_iter().map(|value| value.read_ghost(self)).collect(),
336        }))
337    }
338
339    fn fri_fold(
340        &mut self,
341        CircuitV2FriFoldOutput { alpha_pow_output, ro_output }: CircuitV2FriFoldOutput<C>,
342        CircuitV2FriFoldInput {
343            z,
344            alpha,
345            x,
346            mat_opening,
347            ps_at_z,
348            alpha_pow_input,
349            ro_input,
350        }: CircuitV2FriFoldInput<C>,
351    ) -> Instruction<C::F> {
352        Instruction::FriFold(Box::new(FriFoldInstr {
353            // Calculate before moving the vecs.
354            alpha_pow_mults: vec![C::F::zero(); alpha_pow_output.len()],
355            ro_mults: vec![C::F::zero(); ro_output.len()],
356
357            base_single_addrs: FriFoldBaseIo { x: x.read(self) },
358            ext_single_addrs: FriFoldExtSingleIo { z: z.read(self), alpha: alpha.read(self) },
359            ext_vec_addrs: FriFoldExtVecIo {
360                mat_opening: mat_opening.into_iter().map(|e| e.read(self)).collect(),
361                ps_at_z: ps_at_z.into_iter().map(|e| e.read(self)).collect(),
362                alpha_pow_input: alpha_pow_input.into_iter().map(|e| e.read(self)).collect(),
363                ro_input: ro_input.into_iter().map(|e| e.read(self)).collect(),
364                alpha_pow_output: alpha_pow_output.into_iter().map(|e| e.write(self)).collect(),
365                ro_output: ro_output.into_iter().map(|e| e.write(self)).collect(),
366            },
367        }))
368    }
369
370    fn batch_fri(
371        &mut self,
372        acc: Ext<C::F, C::EF>,
373        alpha_pows: Vec<Ext<C::F, C::EF>>,
374        p_at_zs: Vec<Ext<C::F, C::EF>>,
375        p_at_xs: Vec<Felt<C::F>>,
376    ) -> Instruction<C::F> {
377        Instruction::BatchFRI(Box::new(BatchFRIInstr {
378            base_vec_addrs: BatchFRIBaseVecIo {
379                p_at_x: p_at_xs.into_iter().map(|e| e.read(self)).collect(),
380            },
381            ext_single_addrs: BatchFRIExtSingleIo { acc: acc.write(self) },
382            ext_vec_addrs: BatchFRIExtVecIo {
383                p_at_z: p_at_zs.into_iter().map(|e| e.read(self)).collect(),
384                alpha_pow: alpha_pows.into_iter().map(|e| e.read(self)).collect(),
385            },
386            acc_mult: C::F::zero(),
387        }))
388    }
389
    /// Build a `CommitPublicValues` instruction from `public_values`.
    ///
    /// Bumps the mult of every digest element, then collects the address of every
    /// public-value element via ghost reads.
    fn commit_public_values(
        &mut self,
        public_values: &RecursionPublicValues<Felt<C::F>>,
    ) -> Instruction<C::F> {
        // Bump the mult of each digest element; the returned addresses are discarded.
        public_values.digest.iter().for_each(|x| {
            let _ = x.read(self);
        });
        // SAFETY(review): assumes `RecursionPublicValues<T>` is layout-compatible with
        // `[T; RECURSIVE_PROOF_NUM_PV_ELTS]` — confirm its definition guarantees this
        // (e.g. a repr making it exactly that many `T`s with no padding).
        let pv_addrs =
            unsafe {
                transmute::<
                    RecursionPublicValues<Felt<C::F>>,
                    [Felt<C::F>; RECURSIVE_PROOF_NUM_PV_ELTS],
                >(*public_values)
            }
            .map(|pv| pv.read_ghost(self));

        // Reinterpret the flat address array back as the structured public values.
        let public_values_a: &RecursionPublicValues<Address<C::F>> = pv_addrs.as_slice().borrow();
        Instruction::CommitPublicValues(Box::new(CommitPublicValuesInstr {
            pv_addrs: *public_values_a,
        }))
    }
411
412    fn print_f(&mut self, addr: impl Reg<C>) -> Instruction<C::F> {
413        Instruction::Print(PrintInstr {
414            field_elt_type: FieldEltType::Base,
415            addr: addr.read_ghost(self),
416        })
417    }
418
419    fn print_e(&mut self, addr: impl Reg<C>) -> Instruction<C::F> {
420        Instruction::Print(PrintInstr {
421            field_elt_type: FieldEltType::Extension,
422            addr: addr.read_ghost(self),
423        })
424    }
425
426    fn ext2felts(&mut self, felts: [impl Reg<C>; D], ext: impl Reg<C>) -> Instruction<C::F> {
427        Instruction::HintExt2Felts(HintExt2FeltsInstr {
428            output_addrs_mults: felts.map(|r| (r.write(self), C::F::zero())),
429            input_addr: ext.read_ghost(self),
430        })
431    }
432
433    fn hint(&mut self, output: impl Reg<C>, len: usize) -> Instruction<C::F> {
434        let zero = C::F::zero();
435        Instruction::Hint(HintInstr {
436            output_addrs_mults: output
437                .write_many(self, len)
438                .into_iter()
439                .map(|a| (a, zero))
440                .collect(),
441        })
442    }
443}
444
445impl<C> AsmCompiler<C>
446where
447    C: Config<N = <C as Config>::F> + Debug,
448    C::F: PrimeField64 + TwoAdicField,
449{
    /// Compiles one instruction, passing one or more instructions to `consumer`.
    ///
    /// We do not simply return a `Vec` for performance reasons --- results would be immediately fed
    /// to `flat_map`, so we employ fusion/deforestation to eliminate intermediate data structures.
    ///
    /// Non-instruction outcomes (cycle-tracker markers and unsupported operations)
    /// are passed to `consumer` as `Err(CompileOneErr::...)`.
    pub fn compile_one(
        &mut self,
        ir_instr: DslIr<C>,
        mut consumer: impl FnMut(Result<Instruction<C::F>, CompileOneErr<C>>),
    ) {
        // For readability. Avoids polluting outer scope.
        use BaseAluOpcode::*;
        use ExtAluOpcode::*;

        let mut f = |instr| consumer(Ok(instr));
        match ir_instr {
            // Immediates. V = variable over the base field, F = base field felt,
            // E = extension field element.
            DslIr::ImmV(dst, src) => f(self.mem_write_const(dst, Imm::F(src))),
            DslIr::ImmF(dst, src) => f(self.mem_write_const(dst, Imm::F(src))),
            DslIr::ImmE(dst, src) => f(self.mem_write_const(dst, Imm::EF(src))),

            // Addition. An `I` suffix marks an immediate operand.
            DslIr::AddV(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, rhs)),
            DslIr::AddVI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::F(rhs))),
            DslIr::AddF(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, rhs)),
            DslIr::AddFI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::F(rhs))),
            DslIr::AddE(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, rhs)),
            DslIr::AddEI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::EF(rhs))),
            DslIr::AddEF(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, rhs)),
            DslIr::AddEFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::F(rhs))),
            DslIr::AddEFFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::EF(rhs))),

            // Subtraction. An `IN` suffix marks an immediate on the left-hand side.
            DslIr::SubV(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, rhs)),
            DslIr::SubVI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::F(rhs))),
            DslIr::SubVIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::F(lhs), rhs)),
            DslIr::SubF(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, rhs)),
            DslIr::SubFI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::F(rhs))),
            DslIr::SubFIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::F(lhs), rhs)),
            DslIr::SubE(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, rhs)),
            DslIr::SubEI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::EF(rhs))),
            DslIr::SubEIN(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, Imm::EF(lhs), rhs)),
            DslIr::SubEFI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::F(rhs))),
            DslIr::SubEF(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, rhs)),

            // Multiplication.
            DslIr::MulV(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, rhs)),
            DslIr::MulVI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::F(rhs))),
            DslIr::MulF(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, rhs)),
            DslIr::MulFI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::F(rhs))),
            DslIr::MulE(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, rhs)),
            DslIr::MulEI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::EF(rhs))),
            DslIr::MulEFI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::F(rhs))),
            DslIr::MulEF(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, rhs)),

            // Division.
            DslIr::DivF(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, rhs)),
            DslIr::DivFI(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, Imm::F(rhs))),
            DslIr::DivFIN(dst, lhs, rhs) => f(self.base_alu(DivF, dst, Imm::F(lhs), rhs)),
            DslIr::DivE(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, rhs)),
            DslIr::DivEI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::EF(rhs))),
            DslIr::DivEIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::EF(lhs), rhs)),
            DslIr::DivEFI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::F(rhs))),
            DslIr::DivEFIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::F(lhs), rhs)),
            DslIr::DivEF(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, rhs)),

            // Negation and inversion, lowered as subtraction from 0 / division into 1.
            DslIr::NegV(dst, src) => f(self.base_alu(SubF, dst, Imm::F(C::F::zero()), src)),
            DslIr::NegF(dst, src) => f(self.base_alu(SubF, dst, Imm::F(C::F::zero()), src)),
            DslIr::NegE(dst, src) => f(self.ext_alu(SubE, dst, Imm::EF(C::EF::zero()), src)),
            DslIr::InvV(dst, src) => f(self.base_alu(DivF, dst, Imm::F(C::F::one()), src)),
            DslIr::InvF(dst, src) => f(self.base_alu(DivF, dst, Imm::F(C::F::one()), src)),
            DslIr::InvE(dst, src) => f(self.ext_alu(DivE, dst, Imm::F(C::F::one()), src)),

            DslIr::Select(bit, dst1, dst2, lhs, rhs) => f(self.select(bit, dst1, dst2, lhs, rhs)),

            // Equality assertions; see `base_assert_eq` / `ext_assert_eq`.
            DslIr::AssertEqV(lhs, rhs) => self.base_assert_eq(lhs, rhs, f),
            DslIr::AssertEqF(lhs, rhs) => self.base_assert_eq(lhs, rhs, f),
            DslIr::AssertEqE(lhs, rhs) => self.ext_assert_eq(lhs, rhs, f),
            DslIr::AssertEqVI(lhs, rhs) => self.base_assert_eq(lhs, Imm::F(rhs), f),
            DslIr::AssertEqFI(lhs, rhs) => self.base_assert_eq(lhs, Imm::F(rhs), f),
            DslIr::AssertEqEI(lhs, rhs) => self.ext_assert_eq(lhs, Imm::EF(rhs), f),

            // Inequality assertions; see `base_assert_ne` / `ext_assert_ne`.
            DslIr::AssertNeV(lhs, rhs) => self.base_assert_ne(lhs, rhs, f),
            DslIr::AssertNeF(lhs, rhs) => self.base_assert_ne(lhs, rhs, f),
            DslIr::AssertNeE(lhs, rhs) => self.ext_assert_ne(lhs, rhs, f),
            DslIr::AssertNeVI(lhs, rhs) => self.base_assert_ne(lhs, Imm::F(rhs), f),
            DslIr::AssertNeFI(lhs, rhs) => self.base_assert_ne(lhs, Imm::F(rhs), f),
            DslIr::AssertNeEI(lhs, rhs) => self.ext_assert_ne(lhs, Imm::EF(rhs), f),

            // Complex circuit operations, each lowered by its dedicated builder.
            DslIr::CircuitV2Poseidon2PermuteBabyBear(data) => {
                f(self.poseidon2_permute(data.0, data.1))
            }
            DslIr::CircuitV2ExpReverseBits(dst, base, exp) => {
                f(self.exp_reverse_bits(dst, base, exp))
            }
            DslIr::CircuitV2HintBitsF(output, value) => {
                f(self.hint_bit_decomposition(value, output))
            }
            DslIr::CircuitV2FriFold(data) => f(self.fri_fold(data.0, data.1)),
            DslIr::CircuitV2BatchFRI(data) => f(self.batch_fri(data.0, data.1, data.2, data.3)),
            DslIr::CircuitV2CommitPublicValues(public_values) => {
                f(self.commit_public_values(&public_values))
            }
            DslIr::CircuitV2HintAddCurve(data) => f(self.add_curve(data.0, data.1, data.2)),

            DslIr::Parallel(_) => {
                unreachable!("parallel case should have been handled by compile_raw_program")
            }

            // Debugging aids.
            DslIr::PrintV(dst) => f(self.print_f(dst)),
            DslIr::PrintF(dst) => f(self.print_f(dst)),
            DslIr::PrintE(dst) => f(self.print_e(dst)),
            #[cfg(feature = "debug")]
            DslIr::DebugBacktrace(trace) => f(Instruction::DebugBacktrace(trace)),
            DslIr::CircuitV2HintFelts(output, len) => f(self.hint(output, len)),
            DslIr::CircuitV2HintExts(output, len) => f(self.hint(output, len)),
            DslIr::CircuitExt2Felt(felts, ext) => f(self.ext2felts(felts, ext)),
            // Cycle-tracker markers are surfaced as `Err` so callers can decide how
            // to handle them; they produce no instructions.
            DslIr::CycleTrackerV2Enter(name) => {
                consumer(Err(CompileOneErr::CycleTrackerEnter(name)))
            }
            DslIr::CycleTrackerV2Exit => consumer(Err(CompileOneErr::CycleTrackerExit)),
            // `ReduceE` compiles to nothing in this backend.
            DslIr::ReduceE(_) => {}
            instr => consumer(Err(CompileOneErr::Unsupported(instr))),
        }
    }
569
570    /// A raw program (algebraic data type of instructions), not yet backfilled.
571    fn compile_raw_program(
572        &mut self,
573        block: DslIrBlock<C>,
574        instrs_prefix: Vec<SeqBlock<Instruction<C::F>>>,
575    ) -> RawProgram<Instruction<C::F>> {
576        // Consider refactoring the builder to use an AST instead of a list of operations.
577        // Possible to remove address translation at this step.
578        let mut seq_blocks = instrs_prefix;
579        let mut maybe_bb: Option<BasicBlock<Instruction<C::F>>> = None;
580
581        for op in block.ops {
582            match op {
583                DslIr::Parallel(par_blocks) => {
584                    seq_blocks.extend(maybe_bb.take().map(SeqBlock::Basic));
585                    seq_blocks.push(SeqBlock::Parallel(
586                        par_blocks
587                            .into_iter()
588                            .map(|b| self.compile_raw_program(b, vec![]))
589                            .collect(),
590                    ))
591                }
592                op => {
593                    let bb = maybe_bb.get_or_insert_with(Default::default);
594                    self.compile_one(op, |item| match item {
595                        Ok(instr) => bb.instrs.push(instr),
596                        Err(
597                            CompileOneErr::CycleTrackerEnter(_) | CompileOneErr::CycleTrackerExit,
598                        ) => (),
599                        Err(CompileOneErr::Unsupported(instr)) => {
600                            panic!("unsupported instruction: {instr:?}")
601                        }
602                    });
603                }
604            }
605        }
606
607        seq_blocks.extend(maybe_bb.map(SeqBlock::Basic));
608
609        RawProgram { seq_blocks }
610    }
611
    /// Backfill the `mult` field(s) of every instruction that writes memory.
    ///
    /// Each written address's accumulated read count lives in `addr_to_mult`; this
    /// moves the counts into the instructions, draining the map in the process.
    /// Each address must be written by exactly one instruction: `backfill` removes
    /// the entry, so a second visit (or a never-written address) panics on `unwrap`.
    fn backfill_all<'a>(
        &mut self,
        instrs: impl Iterator<Item = &'a mut Instruction<<C as Config>::F>>,
    ) {
        // Move the mult for `addr` out of the map into the instruction's field.
        let mut backfill = |(mult, addr): (&mut C::F, &Address<C::F>)| {
            *mult = self.addr_to_mult.remove(addr.as_usize()).unwrap()
        };

        for asm_instr in instrs {
            // Exhaustive match for refactoring purposes.
            match asm_instr {
                Instruction::BaseAlu(BaseAluInstr {
                    mult,
                    addrs: BaseAluIo { out: ref addr, .. },
                    ..
                }) => backfill((mult, addr)),
                Instruction::ExtAlu(ExtAluInstr {
                    mult,
                    addrs: ExtAluIo { out: ref addr, .. },
                    ..
                }) => backfill((mult, addr)),
                Instruction::Mem(MemInstr {
                    addrs: MemIo { inner: ref addr },
                    mult,
                    kind: MemAccessKind::Write,
                    ..
                }) => backfill((mult, addr)),
                Instruction::Poseidon2(instr) => {
                    let Poseidon2SkinnyInstr {
                        addrs: Poseidon2Io { output: ref addrs, .. },
                        mults,
                    } = instr.as_mut();
                    mults.iter_mut().zip(addrs).for_each(&mut backfill);
                }
                Instruction::Select(SelectInstr {
                    addrs: SelectIo { out1: ref addr1, out2: ref addr2, .. },
                    mult1,
                    mult2,
                }) => {
                    backfill((mult1, addr1));
                    backfill((mult2, addr2));
                }
                Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
                    addrs: ExpReverseBitsIo { result: ref addr, .. },
                    mult,
                }) => backfill((mult, addr)),
                Instruction::HintBits(HintBitsInstr { output_addrs_mults, .. }) |
                Instruction::Hint(HintInstr { output_addrs_mults, .. }) => {
                    output_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                }
                Instruction::FriFold(instr) => {
                    let FriFoldInstr {
                        ext_vec_addrs: FriFoldExtVecIo { ref alpha_pow_output, ref ro_output, .. },
                        alpha_pow_mults,
                        ro_mults,
                        ..
                    } = instr.as_mut();
                    // Using `.chain` seems to be less performant.
                    alpha_pow_mults.iter_mut().zip(alpha_pow_output).for_each(&mut backfill);
                    ro_mults.iter_mut().zip(ro_output).for_each(&mut backfill);
                }
                Instruction::BatchFRI(instr) => {
                    let BatchFRIInstr {
                        ext_single_addrs: BatchFRIExtSingleIo { ref acc },
                        acc_mult,
                        ..
                    } = instr.as_mut();
                    backfill((acc_mult, acc));
                }
                Instruction::HintExt2Felts(HintExt2FeltsInstr { output_addrs_mults, .. }) => {
                    output_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                }
                Instruction::HintAddCurve(instr) => {
                    let HintAddCurveInstr { output_x_addrs_mults, output_y_addrs_mults, .. } =
                        instr.as_mut();
                    output_x_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                    output_y_addrs_mults.iter_mut().for_each(|(addr, mult)| backfill((mult, addr)));
                }
                // Instructions that do not write to memory.
                Instruction::Mem(MemInstr { kind: MemAccessKind::Read, .. }) |
                Instruction::CommitPublicValues(_) |
                Instruction::Print(_) => (),
                #[cfg(feature = "debug")]
                Instruction::DebugBacktrace(_) => (),
            }
        }

        // Every allocated address must have been consumed by exactly one write above.
        debug_assert!(self.addr_to_mult.is_empty());
    }
701
    /// Compile a `DslIrProgram` that is definitionally assumed to be well-formed.
    ///
    /// Returns a well-formed program.
    ///
    /// See [`Self::compile_inner`] for the variant that accepts an unvalidated block.
    pub fn compile(&mut self, program: DslIrProgram<C>) -> RecursionProgram<C::F> {
        // SAFETY: The compiler produces well-formed programs given a well-formed DSL input.
        // This is also a cryptographic requirement.
        unsafe { RecursionProgram::new_unchecked(self.compile_inner(program.into_inner())) }
    }
710
    /// Compile a root `DslIrBlock` that has not necessarily been validated.
    ///
    /// Returns a program that may be ill-formed.
    pub fn compile_inner(&mut self, root_block: DslIrBlock<C>) -> RootProgram<C::F> {
        // Prefix an empty basic block to be later filled in by constants.
        let mut program = tracing::debug_span!("compile raw program").in_scope(|| {
            self.compile_raw_program(root_block, vec![SeqBlock::Basic(BasicBlock::default())])
        });
        // Snapshot the footprint now: `backfill_all` drains `addr_to_mult` and the
        // constants map is drained below.
        let total_memory = self.addr_to_mult.len() + self.consts.len();
        tracing::debug_span!("backfill mult").in_scope(|| self.backfill_all(program.iter_mut()));

        // Put in the constants.
        tracing::debug_span!("prepend constants").in_scope(|| {
            // The prefix block pushed above must still be the first block.
            let Some(SeqBlock::Basic(BasicBlock { instrs: instrs_consts })) =
                program.seq_blocks.first_mut()
            else {
                unreachable!()
            };
            // Sort by the address's field value so the emitted program is
            // deterministic despite `HashMap`'s arbitrary iteration order.
            instrs_consts.extend(self.consts.drain().sorted_by_key(|x| x.1 .0 .0).map(
                |(imm, (addr, mult))| {
                    Instruction::Mem(MemInstr {
                        addrs: MemIo { inner: addr },
                        vals: MemIo { inner: imm.as_block() },
                        mult,
                        kind: MemAccessKind::Write,
                    })
                },
            ));
        });

        RootProgram { inner: program, total_memory, shape: None }
    }
743}
744
/// Non-success outcome of compiling a single DSL instruction.
///
/// NOTE(review): the two `CycleTracker*` variants look like control-flow
/// signals for the caller rather than true errors — confirm against the
/// (not shown here) `compile_one` call site.
#[derive(Debug, Clone)]
pub enum CompileOneErr<C: Config> {
    /// The DSL instruction has no lowering in this backend.
    Unsupported(DslIr<C>),
    /// Entered a named cycle-tracker scope.
    CycleTrackerEnter(Cow<'static, str>),
    /// Exited the innermost cycle-tracker scope.
    CycleTrackerExit,
}
751
/// Immediate (i.e. constant) field element.
///
/// Required to distinguish a base and extension field element at the type level,
/// since the IR's instructions do not provide this information.
///
/// Derives `Eq + Hash` so it can serve as the key of the compiler's constant
/// pool (`AsmCompiler::consts`), deduplicating repeated constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Imm<F, EF> {
    /// Element of the base field `F`.
    F(F),
    /// Element of the extension field `EF`.
    EF(EF),
}
763
764impl<F, EF> Imm<F, EF>
765where
766    F: AbstractField + Copy,
767    EF: AbstractExtensionField<F>,
768{
769    // Get a `Block` of memory representing this immediate.
770    pub fn as_block(&self) -> Block<F> {
771        match self {
772            Imm::F(f) => Block::from(*f),
773            Imm::EF(ef) => ef.as_base_slice().into(),
774        }
775    }
776}
777
/// Utility functions for various register types.
///
/// A "register" here is anything that can stand for an operand in the IR:
/// virtual-address-backed values (`Var`, `Felt`, `Ext`), immediates ([`Imm`]),
/// and raw physical [`Address`]es. Implementations translate the operand into
/// a physical address, updating the compiler's bookkeeping (read
/// multiplicities, constant pool) as a side effect.
trait Reg<C: Config> {
    /// Mark the register as to be read from, returning the "physical" address.
    fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F>;

    /// Get the "physical" address of the register, assigning a new address if necessary.
    fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F>;

    /// Mark the register as to be written to, returning the "physical" address.
    fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F>;

    /// Mark `len` consecutive registers, starting at this one, as to be
    /// written to, returning their "physical" addresses.
    fn write_many(&self, compiler: &mut AsmCompiler<C>, len: usize) -> Vec<Address<C::F>>;
}
791
/// Implements [`Reg`] for a pointer-like wrapper type `$a` by delegating
/// every method to the pointee's implementation via `(**self)`.
macro_rules! impl_reg_borrowed {
    ($a:ty) => {
        impl<C, T> Reg<C> for $a
        where
            C: Config,
            T: Reg<C> + ?Sized,
        {
            fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                (**self).read(compiler)
            }

            fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                (**self).read_ghost(compiler)
            }

            fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                (**self).write(compiler)
            }

            fn write_many(&self, compiler: &mut AsmCompiler<C>, len: usize) -> Vec<Address<C::F>> {
                (**self).write_many(compiler, len)
            }
        }
    };
}
817
// Allow for more flexibility in arguments: `Reg` is implemented for common
// pointer-like wrappers so call sites can pass owned values, references, or
// boxes interchangeably.
impl_reg_borrowed!(&T);
impl_reg_borrowed!(&mut T);
impl_reg_borrowed!(Box<T>);
822
/// Implements [`Reg`] for an IR value type `$a` that carries a virtual
/// address in its `idx` field (a `u32`), by routing reads/writes through the
/// compiler's virtual-to-physical address translation.
macro_rules! impl_reg_vaddr {
    ($a:ty) => {
        impl<C: Config<F: PrimeField64>> Reg<C> for $a {
            fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                compiler.read_vaddr(self.idx as usize)
            }
            fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                compiler.read_ghost_vaddr(self.idx as usize)
            }
            fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
                compiler.write_fp(self.idx as usize)
            }

            // Consecutive virtual addresses `idx..idx + len` are each written.
            fn write_many(&self, compiler: &mut AsmCompiler<C>, len: usize) -> Vec<Address<C::F>> {
                (0..len).map(|i| compiler.write_fp((self.idx + i as u32) as usize)).collect()
            }
        }
    };
}
842
// These three types wrap a `u32` but they don't share a trait — hence the
// macro: each gets an identical `Reg` implementation keyed on its `idx` field.
impl_reg_vaddr!(Var<C::F>);
impl_reg_vaddr!(Felt<C::F>);
impl_reg_vaddr!(Ext<C::F, C::EF>);
847
848impl<C: Config<F: PrimeField64>> Reg<C> for Imm<C::F, C::EF> {
849    fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
850        compiler.read_const(*self)
851    }
852
853    fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
854        compiler.read_ghost_const(*self)
855    }
856
857    fn write(&self, _compiler: &mut AsmCompiler<C>) -> Address<C::F> {
858        panic!("cannot write to immediate in register: {self:?}")
859    }
860
861    fn write_many(&self, _compiler: &mut AsmCompiler<C>, _len: usize) -> Vec<Address<C::F>> {
862        panic!("cannot write to immediate in register: {self:?}")
863    }
864}
865
866impl<C: Config<F: PrimeField64>> Reg<C> for Address<C::F> {
867    fn read(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
868        compiler.read_addr(*self);
869        *self
870    }
871
872    fn read_ghost(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
873        compiler.read_ghost_addr(*self);
874        *self
875    }
876
877    fn write(&self, compiler: &mut AsmCompiler<C>) -> Address<C::F> {
878        compiler.write_addr(*self);
879        *self
880    }
881
882    fn write_many(&self, _compiler: &mut AsmCompiler<C>, _len: usize) -> Vec<Address<C::F>> {
883        todo!()
884    }
885}
886
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use std::{collections::VecDeque, io::BufRead, iter::zip, sync::Arc};

    use p3_baby_bear::DiffusionMatrixBabyBear;
    use p3_field::{Field, PrimeField32};
    use p3_symmetric::{CryptographicHasher, Permutation};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use sp1_core_machine::utils::{run_test_machine, setup_logger};
    use sp1_recursion_core::{machine::RecursionAir, Runtime};
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, inner_perm, BabyBearPoseidon2Inner, InnerHash,
        StarkGenericConfig,
    };

    use crate::circuit::{AsmBuilder, AsmConfig, CircuitV2Builder};

    use super::*;

    type SC = BabyBearPoseidon2;
    type F = <SC as StarkGenericConfig>::Val;
    type EF = <SC as StarkGenericConfig>::Challenge;

    /// Compile `block`, run it with the default runtime, and prove/verify the
    /// resulting record with both the wide and skinny poseidon2 machines.
    fn test_block(block: DslIrBlock<AsmConfig<F, EF>>) {
        test_block_with_runner(block, |program| {
            let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(
                program,
                BabyBearPoseidon2Inner::new().perm,
            );
            runtime.run().unwrap();
            runtime.record
        });
    }

    /// Like [`test_block`], but lets the caller supply the runner (e.g. to
    /// capture the runtime's debug output).
    fn test_block_with_runner(
        block: DslIrBlock<AsmConfig<F, EF>>,
        run: impl FnOnce(Arc<RecursionProgram<F>>) -> ExecutionRecord<F>,
    ) {
        let mut compiler = super::AsmCompiler::<AsmConfig<F, EF>>::default();
        let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
        let record = run(program.clone());

        // Run with the poseidon2 wide chip.
        let wide_machine =
            RecursionAir::<_, 3>::machine_wide_with_all_chips(BabyBearPoseidon2::default());
        let (pk, vk) = wide_machine.setup(&program);
        let result = run_test_machine(vec![record.clone()], wide_machine, pk, vk);
        if let Err(e) = result {
            panic!("Verification failed: {:?}", e);
        }

        // Run with the poseidon2 skinny chip.
        let skinny_machine = RecursionAir::<_, 9>::machine_skinny_with_all_chips(
            BabyBearPoseidon2::ultra_compressed(),
        );
        let (pk, vk) = skinny_machine.setup(&program);
        let result = run_test_machine(vec![record.clone()], skinny_machine, pk, vk);
        if let Err(e) = result {
            panic!("Verification failed: {:?}", e);
        }
    }

    /// Random poseidon2 permutations computed in-circuit must match the
    /// native `inner_perm` implementation.
    #[test]
    fn test_poseidon2() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng = StdRng::seed_from_u64(0xCAFEDA7E)
            .sample_iter::<[F; WIDTH], _>(rand::distributions::Standard);
        for _ in 0..100 {
            let input_1: [F; WIDTH] = rng.next().unwrap();
            let output_1 = inner_perm().permute(input_1);

            let input_1_felts = input_1.map(|x| builder.eval(x));
            let output_1_felts = builder.poseidon2_permute_v2(input_1_felts);
            let expected: [Felt<_>; WIDTH] = output_1.map(|x| builder.eval(x));
            for (lhs, rhs) in output_1_felts.into_iter().zip(expected) {
                builder.assert_felt_eq(lhs, rhs);
            }
        }

        test_block(builder.into_root_block());
    }

    /// An in-circuit poseidon2 sponge hash must match the native `InnerHash`.
    #[test]
    fn test_poseidon2_hash() {
        let perm = inner_perm();
        let hasher = InnerHash::new(perm.clone());

        let input: [F; 26] = [
            F::from_canonical_u32(0),
            F::from_canonical_u32(1),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(2),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
            F::from_canonical_u32(3),
        ];
        let expected = hasher.hash_iter(input);
        println!("{:?}", expected);

        let mut builder = AsmBuilder::<F, EF>::default();
        let input_felts: [Felt<_>; 26] = input.map(|x| builder.eval(x));
        let result = builder.poseidon2_hash_v2(&input_felts);

        for (actual_f, expected_f) in zip(result, expected) {
            builder.assert_felt_eq(actual_f, expected_f);
        }

        // Fix: previously the built block was never compiled or executed, so
        // the in-circuit assertions above were never checked. Run it like
        // every other test in this module.
        test_block(builder.into_root_block());
    }

    /// `exp_reverse_bits_v2` must agree with a native bit-reversed
    /// square-and-multiply exponentiation.
    #[test]
    fn test_exp_reverse_bits() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng =
            StdRng::seed_from_u64(0xEC0BEEF).sample_iter::<F, _>(rand::distributions::Standard);
        for _ in 0..100 {
            let power_f = rng.next().unwrap();
            let power = power_f.as_canonical_u32();
            let power_bits = (0..NUM_BITS).map(|i| (power >> i) & 1).collect::<Vec<_>>();

            let input_felt = builder.eval(power_f);
            let power_bits_felt = builder.num2bits_v2_f(input_felt, NUM_BITS);

            let base = rng.next().unwrap();
            let base_felt = builder.eval(base);
            let result_felt = builder.exp_reverse_bits_v2(base_felt, power_bits_felt);

            let expected = power_bits
                .into_iter()
                .rev()
                .zip(std::iter::successors(Some(base), |x| Some(x.square())))
                .map(|(bit, base_pow)| match bit {
                    0 => F::one(),
                    1 => base_pow,
                    _ => panic!("not a bit: {bit}"),
                })
                .product::<F>();
            let expected_felt: Felt<_> = builder.eval(expected);
            builder.assert_felt_eq(result_felt, expected_felt);
        }
        test_block(builder.into_root_block());
    }

    /// `fri_fold_v2` must reproduce the natively-computed alpha-power and
    /// reduced-opening outputs for a range of input sizes.
    #[test]
    fn test_fri_fold() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();

        let mut rng = StdRng::seed_from_u64(0xFEB29).sample_iter(rand::distributions::Standard);
        let mut random_felt = move || -> F { rng.next().unwrap() };
        let mut rng =
            StdRng::seed_from_u64(0x0451).sample_iter::<[F; 4], _>(rand::distributions::Standard);
        let mut random_ext = move || EF::from_base_slice(&rng.next().unwrap());

        for i in 2..17 {
            // Generate random values for the inputs.
            let x = random_felt();
            let z = random_ext();
            let alpha = random_ext();

            let alpha_pow_input = (0..i).map(|_| random_ext()).collect::<Vec<_>>();
            let ro_input = (0..i).map(|_| random_ext()).collect::<Vec<_>>();

            let ps_at_z = (0..i).map(|_| random_ext()).collect::<Vec<_>>();
            let mat_opening = (0..i).map(|_| random_ext()).collect::<Vec<_>>();

            // Compute the outputs from the inputs.
            let alpha_pow_output = (0..i).map(|i| alpha_pow_input[i] * alpha).collect::<Vec<EF>>();
            let ro_output = (0..i)
                .map(|i| {
                    ro_input[i] + alpha_pow_input[i] * (-ps_at_z[i] + mat_opening[i]) / (-z + x)
                })
                .collect::<Vec<EF>>();

            // Compute inputs and outputs through the builder.
            let input_vars = CircuitV2FriFoldInput {
                z: builder.eval(z.cons()),
                alpha: builder.eval(alpha.cons()),
                x: builder.eval(x),
                mat_opening: mat_opening.iter().map(|e| builder.eval(e.cons())).collect(),
                ps_at_z: ps_at_z.iter().map(|e| builder.eval(e.cons())).collect(),
                alpha_pow_input: alpha_pow_input.iter().map(|e| builder.eval(e.cons())).collect(),
                ro_input: ro_input.iter().map(|e| builder.eval(e.cons())).collect(),
            };

            let output_vars = builder.fri_fold_v2(input_vars);
            for (lhs, rhs) in std::iter::zip(output_vars.alpha_pow_output, alpha_pow_output) {
                builder.assert_ext_eq(lhs, rhs.cons());
            }
            for (lhs, rhs) in std::iter::zip(output_vars.ro_output, ro_output) {
                builder.assert_ext_eq(lhs, rhs.cons());
            }
        }

        test_block(builder.into_root_block());
    }

    /// The hinted bit decomposition of a field element must match the bits of
    /// its canonical `u32` representation.
    #[test]
    fn test_hint_bit_decomposition() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng =
            StdRng::seed_from_u64(0xC0FFEE7AB1E).sample_iter::<F, _>(rand::distributions::Standard);
        for _ in 0..100 {
            let input_f = rng.next().unwrap();
            let input = input_f.as_canonical_u32();
            let output = (0..NUM_BITS).map(|i| (input >> i) & 1).collect::<Vec<_>>();

            let input_felt = builder.eval(input_f);
            let output_felts = builder.num2bits_v2_f(input_felt, NUM_BITS);
            let expected: Vec<Felt<_>> =
                output.into_iter().map(|x| builder.eval(F::from_canonical_u32(x))).collect();
            for (lhs, rhs) in output_felts.into_iter().zip(expected) {
                builder.assert_felt_eq(lhs, rhs);
            }
        }
        test_block(builder.into_root_block());
    }

    /// Print instructions must emit the printed values on the runtime's debug
    /// stdout, and cycle-tracker enter/exit must compile and run cleanly.
    #[test]
    fn test_print_and_cycle_tracker() {
        const ITERS: usize = 5;

        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();

        let input_fs = StdRng::seed_from_u64(0xC0FFEE7AB1E)
            .sample_iter::<F, _>(rand::distributions::Standard)
            .take(ITERS)
            .collect::<Vec<_>>();

        let input_efs = StdRng::seed_from_u64(0x7EA7AB1E)
            .sample_iter::<[F; 4], _>(rand::distributions::Standard)
            .take(ITERS)
            .collect::<Vec<_>>();

        let mut buf = VecDeque::<u8>::new();

        builder.cycle_tracker_v2_enter("printing felts");
        for (i, &input_f) in input_fs.iter().enumerate() {
            builder.cycle_tracker_v2_enter(format!("printing felt {i}"));
            let input_felt = builder.eval(input_f);
            builder.print_f(input_felt);
            builder.cycle_tracker_v2_exit();
        }
        builder.cycle_tracker_v2_exit();

        builder.cycle_tracker_v2_enter("printing exts");
        for (i, input_block) in input_efs.iter().enumerate() {
            builder.cycle_tracker_v2_enter(format!("printing ext {i}"));
            let input_ext = builder.eval(EF::from_base_slice(input_block).cons());
            builder.print_e(input_ext);
            builder.cycle_tracker_v2_exit();
        }
        builder.cycle_tracker_v2_exit();

        test_block_with_runner(builder.into_root_block(), |program| {
            let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(
                program,
                BabyBearPoseidon2Inner::new().perm,
            );
            runtime.debug_stdout = Box::new(&mut buf);
            runtime.run().unwrap();
            runtime.record
        });

        // Each captured output line should contain the corresponding printed
        // value, in order.
        let input_str_fs = input_fs.into_iter().map(|elt| format!("{}", elt));
        let input_str_efs = input_efs.into_iter().map(|elt| format!("{:?}", elt));
        let input_strs = input_str_fs.chain(input_str_efs);

        for (input_str, line) in zip(input_strs, buf.lines()) {
            let line = line.unwrap();
            assert!(line.contains(&input_str));
        }
    }

    /// `ext2felt_v2` must decompose an extension element into the same felts
    /// as the native `as_base_slice`.
    #[test]
    fn test_ext2felts() {
        setup_logger();

        let mut builder = AsmBuilder::<F, EF>::default();
        let mut rng =
            StdRng::seed_from_u64(0x3264).sample_iter::<[F; 4], _>(rand::distributions::Standard);
        let mut random_ext = move || EF::from_base_slice(&rng.next().unwrap());
        for _ in 0..100 {
            let input = random_ext();
            let output: &[F] = input.as_base_slice();

            let input_ext = builder.eval(input.cons());
            let output_felts = builder.ext2felt_v2(input_ext);
            let expected: Vec<Felt<_>> = output.iter().map(|&x| builder.eval(x)).collect();
            for (lhs, rhs) in output_felts.into_iter().zip(expected) {
                builder.assert_felt_eq(lhs, rhs);
            }
        }
        test_block(builder.into_root_block());
    }

    /// Fixture for the assert-eq / assert-ne tests.
    ///
    /// The one-argument form builds a block exercising the given felt and ext
    /// assertions and runs it; the seven-argument form is the per-type body.
    /// `$should_offset` perturbs the expected value by a random nonzero
    /// element, so eq-assertions should fail and ne-assertions should pass.
    macro_rules! test_assert_fixture {
        ($assert_felt:ident, $assert_ext:ident, $should_offset:literal) => {
            {
                use std::convert::identity;
                let mut builder = AsmBuilder::<F, EF>::default();
                test_assert_fixture!(builder, identity, F, Felt<_>, 0xDEADBEEF, $assert_felt, $should_offset);
                test_assert_fixture!(builder, EF::cons, EF, Ext<_, _>, 0xABADCAFE, $assert_ext, $should_offset);
                test_block(builder.into_root_block());
            }
        };
        ($builder:ident, $wrap:path, $t:ty, $u:ty, $seed:expr, $assert:ident, $should_offset:expr) => {
            {
                let mut elts = StdRng::seed_from_u64($seed)
                    .sample_iter::<$t, _>(rand::distributions::Standard);
                for _ in 0..100 {
                    let a = elts.next().unwrap();
                    let b = elts.next().unwrap();
                    let c = a + b;
                    let ar: $u = $builder.eval($wrap(a));
                    let br: $u = $builder.eval($wrap(b));
                    let cr: $u = $builder.eval(ar + br);
                    let cm = if $should_offset {
                        c + elts.find(|x| !x.is_zero()).unwrap()
                    } else {
                        c
                    };
                    $builder.$assert(cr, $wrap(cm));
                }
            }
        };
    }

    #[test]
    fn test_assert_eq_noop() {
        test_assert_fixture!(assert_felt_eq, assert_ext_eq, false);
    }

    #[test]
    #[should_panic]
    fn test_assert_eq_panics() {
        test_assert_fixture!(assert_felt_eq, assert_ext_eq, true);
    }

    #[test]
    fn test_assert_ne_noop() {
        test_assert_fixture!(assert_felt_ne, assert_ext_ne, true);
    }

    #[test]
    #[should_panic]
    fn test_assert_ne_panics() {
        test_assert_fixture!(assert_felt_ne, assert_ext_ne, false);
    }
}