sp1_recursion_core/chips/exp_reverse_bits.rs

1#![allow(clippy::needless_range_loop)]
2
3use crate::{builder::SP1RecursionAirBuilder, runtime::ExecutionRecord};
4use core::borrow::Borrow;
5use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
6use p3_field::{AbstractField, PrimeField32};
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder, MachineAir, SP1AirBuilder};
10
11use super::mem::MemoryAccessColsChips;
12
13#[cfg(feature = "sys")]
14use {
15    super::mem::MemoryAccessCols,
16    crate::{ExpReverseBitsEvent, ExpReverseBitsInstr, Instruction},
17    p3_baby_bear::BabyBear,
18    sp1_core_machine::utils::pad_rows_fixed,
19    std::borrow::BorrowMut,
20    tracing::instrument,
21};
22
23pub const NUM_EXP_REVERSE_BITS_LEN_COLS: usize = core::mem::size_of::<ExpReverseBitsLenCols<u8>>();
24pub const NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS: usize =
25    core::mem::size_of::<ExpReverseBitsLenPreprocessedCols<u8>>();
26
/// AIR chip for `ExpReverseBitsLen` instructions.
///
/// `DEGREE` is the constraint degree the chip targets: when it exceeds 3, a dummy
/// degree-`DEGREE` constraint is emitted so the chip's constraints reach that degree.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExpReverseBitsLenChip<const DEGREE: usize>;
29
30#[derive(AlignedBorrow, Clone, Copy, Debug)]
31#[repr(C)]
32pub struct ExpReverseBitsLenPreprocessedCols<T: Copy> {
33    pub x_mem: MemoryAccessColsChips<T>,
34    pub exponent_mem: MemoryAccessColsChips<T>,
35    pub result_mem: MemoryAccessColsChips<T>,
36    pub iteration_num: T,
37    pub is_first: T,
38    pub is_last: T,
39    pub is_real: T,
40}
41
42#[derive(AlignedBorrow, Debug, Clone, Copy)]
43#[repr(C)]
44pub struct ExpReverseBitsLenCols<T: Copy> {
45    /// The base of the exponentiation.
46    pub x: T,
47
48    /// The current bit of the exponent. This is read from memory.
49    pub current_bit: T,
50
51    /// The previous accumulator squared.
52    pub prev_accum_squared: T,
53
54    /// Is set to the value local.prev_accum_squared * local.multiplier.
55    pub prev_accum_squared_times_multiplier: T,
56
57    /// The accumulator of the current iteration.
58    pub accum: T,
59
60    /// The accumulator squared.
61    pub accum_squared: T,
62
63    /// A column which equals x if `current_bit` is on, and 1 otherwise.
64    pub multiplier: T,
65}
66
67impl<F, const DEGREE: usize> BaseAir<F> for ExpReverseBitsLenChip<DEGREE> {
68    fn width(&self) -> usize {
69        NUM_EXP_REVERSE_BITS_LEN_COLS
70    }
71}
72
73impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for ExpReverseBitsLenChip<DEGREE> {
74    type Record = ExecutionRecord<F>;
75
76    type Program = crate::RecursionProgram<F>;
77
78    fn name(&self) -> String {
79        "ExpReverseBitsLen".to_string()
80    }
81
82    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
83        // This is a no-op.
84    }
85
86    fn preprocessed_width(&self) -> usize {
87        NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS
88    }
89
90    #[cfg(not(feature = "sys"))]
91    fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
92        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
93    }
94
95    #[cfg(feature = "sys")]
96    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
97        assert!(
98            std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
99            "generate_preprocessed_trace only supports BabyBear field"
100        );
101
102        let mut rows: Vec<[BabyBear; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
103        program
104            .inner
105            .iter()
106            .filter_map(|instruction| match instruction {
107                Instruction::ExpReverseBitsLen(x) => Some(unsafe {
108                    std::mem::transmute::<&ExpReverseBitsInstr<F>, &ExpReverseBitsInstr<BabyBear>>(
109                        x,
110                    )
111                }),
112                _ => None,
113            })
114            .for_each(|instruction: &ExpReverseBitsInstr<BabyBear>| {
115                let ExpReverseBitsInstr { addrs, mult } = instruction;
116                let mut row_add = vec![
117                    [BabyBear::zero();
118                        NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS];
119                    addrs.exp.len()
120                ];
121                row_add.iter_mut().enumerate().for_each(|(i, row)| {
122                    let row: &mut ExpReverseBitsLenPreprocessedCols<BabyBear> =
123                        row.as_mut_slice().borrow_mut();
124                    row.iteration_num = BabyBear::from_canonical_u32(i as u32);
125                    row.is_first = BabyBear::from_bool(i == 0);
126                    row.is_last = BabyBear::from_bool(i == addrs.exp.len() - 1);
127                    row.is_real = BabyBear::one();
128                    row.x_mem =
129                        MemoryAccessCols { addr: addrs.base, mult: -BabyBear::from_bool(i == 0) };
130                    row.exponent_mem =
131                        MemoryAccessCols { addr: addrs.exp[i], mult: BabyBear::neg_one() };
132                    row.result_mem = MemoryAccessCols {
133                        addr: addrs.result,
134                        mult: *mult * BabyBear::from_bool(i == addrs.exp.len() - 1),
135                    };
136                });
137                rows.extend(row_add);
138            });
139
140        // Pad the trace to a power of two.
141        pad_rows_fixed(
142            &mut rows,
143            || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
144            program.fixed_log2_rows(self),
145        );
146
147        let trace = RowMajorMatrix::new(
148            unsafe {
149                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
150                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
151                )
152            },
153            NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
154        );
155        Some(trace)
156    }
157
158    #[cfg(not(feature = "sys"))]
159    fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
160        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
161    }
162
163    #[cfg(feature = "sys")]
164    #[instrument(name = "generate exp reverse bits len trace", level = "debug", skip_all, fields(rows = input.exp_reverse_bits_len_events.len()))]
165    fn generate_trace(
166        &self,
167        input: &ExecutionRecord<F>,
168        _: &mut ExecutionRecord<F>,
169    ) -> RowMajorMatrix<F> {
170        assert!(
171            std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
172            "generate_trace only supports BabyBear field"
173        );
174
175        let events = unsafe {
176            std::mem::transmute::<&Vec<ExpReverseBitsEvent<F>>, &Vec<ExpReverseBitsEvent<BabyBear>>>(
177                &input.exp_reverse_bits_len_events,
178            )
179        };
180        let mut overall_rows = Vec::new();
181
182        events.iter().for_each(|event| {
183            let mut rows =
184                vec![vec![BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];
185            let mut accum = BabyBear::one();
186
187            rows.iter_mut().enumerate().for_each(|(i, row)| {
188                let cols: &mut ExpReverseBitsLenCols<BabyBear> = row.as_mut_slice().borrow_mut();
189                unsafe {
190                    crate::sys::exp_reverse_bits_event_to_row_babybear(&event.into(), i, cols);
191                }
192
193                let prev_accum = accum;
194                accum = prev_accum * prev_accum * cols.multiplier;
195
196                cols.accum = accum;
197                cols.accum_squared = accum * accum;
198                cols.prev_accum_squared = prev_accum * prev_accum;
199                cols.prev_accum_squared_times_multiplier =
200                    cols.prev_accum_squared * cols.multiplier;
201            });
202            overall_rows.extend(rows);
203        });
204
205        // Pad the trace to a power of two.
206        pad_rows_fixed(
207            &mut overall_rows,
208            || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
209            input.fixed_log2_rows(self),
210        );
211
212        // Convert the trace to a row major matrix.
213        let trace = RowMajorMatrix::new(
214            unsafe {
215                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
216                    overall_rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
217                )
218            },
219            NUM_EXP_REVERSE_BITS_LEN_COLS,
220        );
221
222        #[cfg(debug_assertions)]
223        eprintln!(
224            "exp reverse bits len trace dims is width: {:?}, height: {:?}",
225            trace.width(),
226            trace.height()
227        );
228
229        trace
230    }
231
232    fn included(&self, _record: &Self::Record) -> bool {
233        true
234    }
235}
236
impl<const DEGREE: usize> ExpReverseBitsLenChip<DEGREE> {
    /// Emits the constraints and memory interactions for one pair of adjacent rows.
    ///
    /// `local`/`local_prepr` are the current main and preprocessed rows; `next`/`next_prepr`
    /// are the following ones.
    ///
    /// Each `ExpReverseBitsLen` instruction spans consecutive rows, one per exponent bit,
    /// delimited by the preprocessed `is_first`/`is_last` flags. On each row the
    /// accumulator is squared and multiplied by `multiplier` (`x` or 1, depending on the
    /// bit), so `exp[0]` ends up as the most significant bit of the effective exponent. The
    /// last row's `accum` is what gets sent to the result address.
    ///
    /// Statement order determines the order in which constraints and `send_single`
    /// interactions are registered with `builder`.
    pub fn eval_exp_reverse_bits_len<
        AB: BaseAirBuilder + ExtensionAirBuilder + SP1RecursionAirBuilder + SP1AirBuilder,
    >(
        &self,
        builder: &mut AB,
        local: &ExpReverseBitsLenCols<AB::Var>,
        local_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
        next: &ExpReverseBitsLenCols<AB::Var>,
        next_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
    ) {
        // Dummy constraints to normalize to DEGREE when DEGREE > 3. Both sides are the same
        // product `is_real^DEGREE`, so the constraint always holds; it only contributes a
        // term of degree DEGREE.
        if DEGREE > 3 {
            let lhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            let rhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            builder.assert_eq(lhs, rhs);
        }

        // Memory read of the base `x`. The generated preprocessed trace sets `x_mem.mult` to
        // -1 on an instruction's first row and 0 on all others, so `x` is read once per
        // instruction.
        builder.send_single(local_prepr.x_mem.addr, local.x, local_prepr.x_mem.mult);

        // Within an instruction (this row is not `is_last` and the next row is real), `x` is
        // carried forward unchanged, so every row uses the base read on the first row.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(local.x, next.x);

        // Memory read of this row's exponent bit. The generated preprocessed trace sets
        // `exponent_mem.mult` to -1 on every real row (padding rows are all zero).
        builder.send_single(
            local_prepr.exponent_mem.addr,
            local.current_bit,
            local_prepr.exponent_mem.mult,
        );

        // The accumulator needs to start with the multiplier for every `is_first` row.
        builder.when(local_prepr.is_first).assert_eq(local.accum, local.multiplier);

        // `multiplier` is x if the current bit is 1, and 1 if the current bit is 0.
        // NOTE(review): `current_bit` is not constrained to be boolean in this function;
        // confirm that the value read from memory is guaranteed to be a bit elsewhere.
        builder
            .when(local_prepr.is_real)
            .when(local.current_bit)
            .assert_eq(local.multiplier, local.x);
        builder
            .when(local_prepr.is_real)
            .when_not(local.current_bit)
            .assert_eq(local.multiplier, AB::Expr::one());

        // Materialize `prev_accum_squared * multiplier` in its own column (presumably to keep
        // the degree of the constraint below low).
        builder.when(local_prepr.is_real).assert_eq(
            local.prev_accum_squared_times_multiplier,
            local.prev_accum_squared * local.multiplier,
        );

        // On every real row except an instruction's first, the accumulator is the previous
        // accumulator squared times this row's multiplier.
        builder
            .when(local_prepr.is_real)
            .when_not(local_prepr.is_first)
            .assert_eq(local.accum, local.prev_accum_squared_times_multiplier);

        // Constrain the accum_squared column.
        builder.when(local_prepr.is_real).assert_eq(local.accum_squared, local.accum * local.accum);

        // Chain rows within an instruction: the next row's `prev_accum_squared` is this row's
        // `accum_squared`.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(next.prev_accum_squared, local.accum_squared);

        // Constrain mem write for the result. The generated preprocessed trace sets
        // `result_mem.mult` to the instruction's `mult` on its last row and 0 elsewhere, so
        // only the final accumulator is sent.
        builder.send_single(local_prepr.result_mem.addr, local.accum, local_prepr.result_mem.mult);
    }

    /// Returns the flag that gates the exponent-bit memory access of a row, which is simply
    /// the row's `is_real` flag.
    pub const fn do_exp_bit_memory_access<T: Copy>(
        local: &ExpReverseBitsLenPreprocessedCols<T>,
    ) -> T {
        local.is_real
    }
}
317
318impl<AB, const DEGREE: usize> Air<AB> for ExpReverseBitsLenChip<DEGREE>
319where
320    AB: SP1RecursionAirBuilder + PairBuilder,
321{
322    fn eval(&self, builder: &mut AB) {
323        let main = builder.main();
324        let (local, next) = (main.row_slice(0), main.row_slice(1));
325        let local: &ExpReverseBitsLenCols<AB::Var> = (*local).borrow();
326        let next: &ExpReverseBitsLenCols<AB::Var> = (*next).borrow();
327        let prep = builder.preprocessed();
328        let (prep_local, prep_next) = (prep.row_slice(0), prep.row_slice(1));
329        let prep_local: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_local).borrow();
330        let prep_next: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_next).borrow();
331        self.eval_exp_reverse_bits_len::<AB>(builder, local, prep_local, next, prep_next);
332    }
333}
334
#[cfg(all(test, feature = "sys"))]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        chips::{exp_reverse_bits::ExpReverseBitsLenChip, test_fixtures},
        linear_program,
        machine::tests::test_recursion_linear_program,
        runtime::{instruction as instr, ExecutionRecord},
        stark::BabyBearPoseidon2Outer,
        Address, ExpReverseBitsEvent, ExpReverseBitsIo, Instruction, MemAccessKind,
        RecursionProgram,
    };
    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::{AbstractField, PrimeField32};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_util::reverse_bits_len;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::setup_logger;
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::iter::once;

    use super::*;

    const DEGREE: usize = 3;

    #[test]
    fn prove_babybear_circuit_erbl() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_bit = move || rng.gen_range(0..2);
        let mut addr = 0;

        let instructions = (1..15)
            .flat_map(|i| {
                let base = random_felt();
                // Draw an independent bit per position. `vec![random_bit(); i]` would call
                // `random_bit()` once and repeat that single bit `i` times.
                let exponent_bits: Vec<u32> = (0..i).map(|_| random_bit()).collect();
                let exponent = F::from_canonical_u32(
                    exponent_bits
                        .iter()
                        .enumerate()
                        .fold(0, |acc, (pos, bit)| acc + bit * (1 << pos)),
                );
                // The chip treats `exp[0]` as the most significant bit, hence the reversal.
                let result =
                    base.exp_u64(reverse_bits_len(exponent.as_canonical_u32() as usize, i) as u64);

                let alloc_size = i + 2;
                let exp_a = (0..i).map(|x| x + addr + 1).collect::<Vec<_>>();
                let exp_a_clone = exp_a.clone();
                let x_a = addr;
                let result_a = addr + alloc_size - 1;
                addr += alloc_size;
                let exp_bit_instructions = (0..i).map(move |j| {
                    instr::mem_single(
                        MemAccessKind::Write,
                        1,
                        exp_a_clone[j] as u32,
                        F::from_canonical_u32(exponent_bits[j]),
                    )
                });
                once(instr::mem_single(MemAccessKind::Write, 1, x_a as u32, base))
                    .chain(exp_bit_instructions)
                    .chain(once(instr::exp_reverse_bits_len(
                        1,
                        F::from_canonical_u32(x_a as u32),
                        exp_a
                            .into_iter()
                            .map(|bit| F::from_canonical_u32(bit as u32))
                            .collect_vec(),
                        F::from_canonical_u32(result_a as u32),
                    )))
                    .chain(once(instr::mem_single(MemAccessKind::Read, 1, result_a as u32, result)))
            })
            .collect::<Vec<Instruction<F>>>();

        test_recursion_linear_program(instructions);
    }

    #[test]
    fn generate_trace() {
        type F = BabyBear;

        // The chip treats `exp[0]` as the most significant bit. The bits [0, 1, 1] read
        // LSB-first are 0b110, so the event computes base^reverse_bits_len(0b110, 3) = 2^3.
        let result = F::two().exp_u64(reverse_bits_len(0b110, 3) as u64);
        let shard = ExecutionRecord {
            exp_reverse_bits_len_events: vec![ExpReverseBitsEvent {
                base: F::two(),
                exp: vec![F::zero(), F::one(), F::one()],
                result,
            }],
            ..Default::default()
        };
        let chip = ExpReverseBitsLenChip::<3>;
        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        // The accumulator on the event's last row is the exponentiation result.
        let last_row = trace.row_slice(2);
        let last_cols: &ExpReverseBitsLenCols<F> = (*last_row).borrow();
        assert_eq!(last_cols.accum, result);
    }

    #[test]
    fn generate_erbl_preprocessed_trace() {
        type F = BabyBear;

        let program = linear_program(vec![
            instr::mem(MemAccessKind::Write, 2, 0, 0),
            instr::mem(MemAccessKind::Write, 2, 1, 0),
            Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
                addrs: ExpReverseBitsIo {
                    base: Address(F::zero()),
                    exp: vec![Address(F::one()), Address(F::zero()), Address(F::one())],
                    result: Address(F::from_canonical_u32(4)),
                },
                mult: F::one(),
            }),
            instr::mem(MemAccessKind::Read, 1, 4, 0),
        ])
        .unwrap();

        let chip = ExpReverseBitsLenChip::<3>;
        let trace = chip.generate_preprocessed_trace(&program).unwrap();
        println!("{:?}", trace.values);
    }

    /// Pure-Rust reference for the main trace, mirroring `generate_trace` without the FFI.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut overall_rows = Vec::new();
        input.exp_reverse_bits_len_events.iter().for_each(|event| {
            let mut rows = vec![vec![F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];

            let mut accum = F::one();

            rows.iter_mut().enumerate().for_each(|(i, row)| {
                let cols: &mut ExpReverseBitsLenCols<F> = row.as_mut_slice().borrow_mut();

                let multiplier = if event.exp[i] == F::one() { event.base } else { F::one() };
                let prev_accum = accum;
                accum = prev_accum * prev_accum * multiplier;

                cols.x = event.base;
                cols.current_bit = event.exp[i];
                cols.accum = accum;
                cols.accum_squared = accum * accum;
                cols.prev_accum_squared = prev_accum * prev_accum;
                cols.multiplier = multiplier;
                cols.prev_accum_squared_times_multiplier =
                    cols.prev_accum_squared * cols.multiplier;
                // `event.result` is intentionally not checked: the production generator does
                // not validate it either. (A previous `i == event.exp.len()` check could never
                // fire.)
            });

            overall_rows.extend(rows);
        });

        pad_rows_fixed(
            &mut overall_rows,
            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
            input.fixed_log2_rows(&ExpReverseBitsLenChip::<DEGREE>),
        );

        RowMajorMatrix::new(
            overall_rows.into_iter().flatten().collect(),
            NUM_EXP_REVERSE_BITS_LEN_COLS,
        )
    }

    #[test]
    fn test_generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
    }

    /// Pure-Rust reference for the preprocessed trace (no transmutes).
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
        program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::ExpReverseBitsLen(x) => Some(x),
                _ => None,
            })
            .for_each(|instruction| {
                let ExpReverseBitsInstr { addrs, mult } = instruction;
                let mut row_add =
                    vec![[F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]; addrs.exp.len()];
                row_add.iter_mut().enumerate().for_each(|(i, row)| {
                    let row: &mut ExpReverseBitsLenPreprocessedCols<F> =
                        row.as_mut_slice().borrow_mut();
                    row.iteration_num = F::from_canonical_u32(i as u32);
                    row.is_first = F::from_bool(i == 0);
                    row.is_last = F::from_bool(i == addrs.exp.len() - 1);
                    row.is_real = F::one();
                    row.x_mem = MemoryAccessCols { addr: addrs.base, mult: -F::from_bool(i == 0) };
                    row.exponent_mem = MemoryAccessCols { addr: addrs.exp[i], mult: F::neg_one() };
                    row.result_mem = MemoryAccessCols {
                        addr: addrs.result,
                        mult: *mult * F::from_bool(i == addrs.exp.len() - 1),
                    };
                });
                rows.extend(row_add);
            });

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
            program.fixed_log2_rows(&ExpReverseBitsLenChip::<DEGREE>),
        );

        RowMajorMatrix::new(
            rows.into_iter().flatten().collect(),
            NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
        )
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }
}
570}