sp1_recursion_core/chips/
exp_reverse_bits.rs

1#![allow(clippy::needless_range_loop)]
2
3use crate::{
4    builder::SP1RecursionAirBuilder, runtime::ExecutionRecord, ExpReverseBitsEvent,
5    ExpReverseBitsInstr, Instruction,
6};
7use core::borrow::Borrow;
8use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
9use p3_baby_bear::BabyBear;
10use p3_field::{AbstractField, PrimeField32};
11use p3_matrix::{dense::RowMajorMatrix, Matrix};
12use sp1_core_machine::utils::pad_rows_fixed;
13use sp1_derive::AlignedBorrow;
14use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder, MachineAir, SP1AirBuilder};
15use std::borrow::BorrowMut;
16use tracing::instrument;
17
18use super::mem::{MemoryAccessCols, MemoryAccessColsChips};
19
/// Width of the main trace, taken as the byte size of [`ExpReverseBitsLenCols`] instantiated
/// with `u8` (one byte per column).
pub const NUM_EXP_REVERSE_BITS_LEN_COLS: usize = core::mem::size_of::<ExpReverseBitsLenCols<u8>>();
/// Width of the preprocessed trace, taken as the byte size of
/// [`ExpReverseBitsLenPreprocessedCols`] instantiated with `u8`.
pub const NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS: usize =
    core::mem::size_of::<ExpReverseBitsLenPreprocessedCols<u8>>();
23
/// Chip implementing the `ExpReverseBitsLen` recursion instruction.
///
/// The `DEGREE` parameter is only consulted during constraint evaluation, where a dummy
/// constraint of degree `DEGREE` is added when `DEGREE > 3`.
#[derive(Clone, Debug, Copy, Default)]
pub struct ExpReverseBitsLenChip<const DEGREE: usize>;
26
/// Preprocessed (program-dependent) columns of the chip.
///
/// `generate_preprocessed_trace` emits one row per exponent bit of every `ExpReverseBitsLen`
/// instruction; padding rows are all zero.
#[derive(AlignedBorrow, Clone, Copy, Debug)]
#[repr(C)]
pub struct ExpReverseBitsLenPreprocessedCols<T: Copy> {
    /// Address and multiplicity of the access to the base `x`. The multiplicity is `-1` on the
    /// first row of an instruction and `0` on the others.
    pub x_mem: MemoryAccessColsChips<T>,
    /// Address and multiplicity of the access to the current exponent bit. The multiplicity is
    /// `-1` on every real row.
    pub exponent_mem: MemoryAccessColsChips<T>,
    /// Address and multiplicity of the access to the result. The multiplicity is the
    /// instruction's `mult` on the last row of an instruction and `0` on the others.
    pub result_mem: MemoryAccessColsChips<T>,
    /// Zero-based index of this row within its instruction.
    pub iteration_num: T,
    /// One on the first row of an instruction, zero otherwise.
    pub is_first: T,
    /// One on the last row of an instruction, zero otherwise.
    pub is_last: T,
    /// One on rows generated from an instruction, zero on padding rows.
    pub is_real: T,
}
38
/// Main-trace columns of the chip: one row per exponent bit of every event.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExpReverseBitsLenCols<T: Copy> {
    /// The base of the exponentiation.
    pub x: T,

    /// The current bit of the exponent. This is read from memory.
    pub current_bit: T,

    /// The square of the accumulator of the previous iteration (the accumulator starts at one
    /// before the first iteration).
    pub prev_accum_squared: T,

    /// Is set to the value `prev_accum_squared * multiplier` of the same row.
    pub prev_accum_squared_times_multiplier: T,

    /// The accumulator of the current iteration.
    pub accum: T,

    /// The accumulator of the current iteration, squared.
    pub accum_squared: T,

    /// A column which equals x if `current_bit` is on, and 1 otherwise.
    pub multiplier: T,
}
63
64impl<F, const DEGREE: usize> BaseAir<F> for ExpReverseBitsLenChip<DEGREE> {
65    fn width(&self) -> usize {
66        NUM_EXP_REVERSE_BITS_LEN_COLS
67    }
68}
69
70impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for ExpReverseBitsLenChip<DEGREE> {
71    type Record = ExecutionRecord<F>;
72
73    type Program = crate::RecursionProgram<F>;
74
75    fn name(&self) -> String {
76        "ExpReverseBitsLen".to_string()
77    }
78
79    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
80        // This is a no-op.
81    }
82
83    fn preprocessed_width(&self) -> usize {
84        NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS
85    }
86
87    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
88        assert!(
89            std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
90            "generate_preprocessed_trace only supports BabyBear field"
91        );
92
93        let mut rows: Vec<[BabyBear; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
94        program
95            .inner
96            .iter()
97            .filter_map(|instruction| match instruction {
98                Instruction::ExpReverseBitsLen(x) => Some(unsafe {
99                    std::mem::transmute::<&ExpReverseBitsInstr<F>, &ExpReverseBitsInstr<BabyBear>>(
100                        x,
101                    )
102                }),
103                _ => None,
104            })
105            .for_each(|instruction: &ExpReverseBitsInstr<BabyBear>| {
106                let ExpReverseBitsInstr { addrs, mult } = instruction;
107                let mut row_add = vec![
108                    [BabyBear::zero();
109                        NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS];
110                    addrs.exp.len()
111                ];
112                row_add.iter_mut().enumerate().for_each(|(i, row)| {
113                    let row: &mut ExpReverseBitsLenPreprocessedCols<BabyBear> =
114                        row.as_mut_slice().borrow_mut();
115                    row.iteration_num = BabyBear::from_canonical_u32(i as u32);
116                    row.is_first = BabyBear::from_bool(i == 0);
117                    row.is_last = BabyBear::from_bool(i == addrs.exp.len() - 1);
118                    row.is_real = BabyBear::one();
119                    row.x_mem =
120                        MemoryAccessCols { addr: addrs.base, mult: -BabyBear::from_bool(i == 0) };
121                    row.exponent_mem =
122                        MemoryAccessCols { addr: addrs.exp[i], mult: BabyBear::neg_one() };
123                    row.result_mem = MemoryAccessCols {
124                        addr: addrs.result,
125                        mult: *mult * BabyBear::from_bool(i == addrs.exp.len() - 1),
126                    };
127                });
128                rows.extend(row_add);
129            });
130
131        // Pad the trace to a power of two.
132        pad_rows_fixed(
133            &mut rows,
134            || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
135            program.fixed_log2_rows(self),
136        );
137
138        let trace = RowMajorMatrix::new(
139            unsafe {
140                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
141                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
142                )
143            },
144            NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
145        );
146        Some(trace)
147    }
148
149    #[instrument(name = "generate exp reverse bits len trace", level = "debug", skip_all, fields(rows = input.exp_reverse_bits_len_events.len()))]
150    fn generate_trace(
151        &self,
152        input: &ExecutionRecord<F>,
153        _: &mut ExecutionRecord<F>,
154    ) -> RowMajorMatrix<F> {
155        assert!(
156            std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
157            "generate_trace only supports BabyBear field"
158        );
159
160        let events = unsafe {
161            std::mem::transmute::<&Vec<ExpReverseBitsEvent<F>>, &Vec<ExpReverseBitsEvent<BabyBear>>>(
162                &input.exp_reverse_bits_len_events,
163            )
164        };
165        let mut overall_rows = Vec::new();
166
167        events.iter().for_each(|event| {
168            let mut rows =
169                vec![vec![BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];
170            let mut accum = BabyBear::one();
171
172            rows.iter_mut().enumerate().for_each(|(i, row)| {
173                let cols: &mut ExpReverseBitsLenCols<BabyBear> = row.as_mut_slice().borrow_mut();
174                unsafe {
175                    crate::sys::exp_reverse_bits_event_to_row_babybear(&event.into(), i, cols);
176                }
177
178                let prev_accum = accum;
179                accum = prev_accum * prev_accum * cols.multiplier;
180
181                cols.accum = accum;
182                cols.accum_squared = accum * accum;
183                cols.prev_accum_squared = prev_accum * prev_accum;
184                cols.prev_accum_squared_times_multiplier =
185                    cols.prev_accum_squared * cols.multiplier;
186            });
187            overall_rows.extend(rows);
188        });
189
190        // Pad the trace to a power of two.
191        pad_rows_fixed(
192            &mut overall_rows,
193            || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
194            input.fixed_log2_rows(self),
195        );
196
197        // Convert the trace to a row major matrix.
198        let trace = RowMajorMatrix::new(
199            unsafe {
200                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
201                    overall_rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
202                )
203            },
204            NUM_EXP_REVERSE_BITS_LEN_COLS,
205        );
206
207        #[cfg(debug_assertions)]
208        eprintln!(
209            "exp reverse bits len trace dims is width: {:?}, height: {:?}",
210            trace.width(),
211            trace.height()
212        );
213
214        trace
215    }
216
217    fn included(&self, _record: &Self::Record) -> bool {
218        true
219    }
220}
221
impl<const DEGREE: usize> ExpReverseBitsLenChip<DEGREE> {
    /// Emits every constraint of the chip for one pair of adjacent rows.
    ///
    /// * `local` / `next` - main-trace columns of the current and the following row.
    /// * `local_prepr` / `next_prepr` - preprocessed columns of the same two rows.
    ///
    /// The constraints tie `multiplier` to `x` and `current_bit`, chain the accumulator from
    /// row to row within an instruction, and send the three memory interactions (base,
    /// exponent bit, result) with the multiplicities stored in the preprocessed columns.
    pub fn eval_exp_reverse_bits_len<
        AB: BaseAirBuilder + ExtensionAirBuilder + SP1RecursionAirBuilder + SP1AirBuilder,
    >(
        &self,
        builder: &mut AB,
        local: &ExpReverseBitsLenCols<AB::Var>,
        local_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
        next: &ExpReverseBitsLenCols<AB::Var>,
        next_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
    ) {
        // Dummy constraints to normalize to DEGREE when DEGREE > 3. Both sides are the same
        // product `is_real^DEGREE`, so the constraint always holds and only raises the degree.
        if DEGREE > 3 {
            let lhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            let rhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            builder.assert_eq(lhs, rhs);
        }

        // Constrain mem read for x.  In the preprocessed trace the multiplicity is nonzero
        // (`-1`) only on the first row of an instruction, and zero for all others.
        builder.send_single(local_prepr.x_mem.addr, local.x, local_prepr.x_mem.mult);

        // Ensure that the value at the x memory access is unchanged when not `is_last`, so the
        // value read on the first row is carried through the whole instruction.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(local.x, next.x);

        // Constrain mem read for exponent's bits.  In the preprocessed trace the multiplicity is
        // `-1` on every real row.
        builder.send_single(
            local_prepr.exponent_mem.addr,
            local.current_bit,
            local_prepr.exponent_mem.mult,
        );

        // The accumulator needs to start with the multiplier for every `is_first` row.
        builder.when(local_prepr.is_first).assert_eq(local.accum, local.multiplier);

        // `multiplier` is x if the current bit is 1, and 1 if the current bit is 0.
        builder
            .when(local_prepr.is_real)
            .when(local.current_bit)
            .assert_eq(local.multiplier, local.x);
        builder
            .when(local_prepr.is_real)
            .when_not(local.current_bit)
            .assert_eq(local.multiplier, AB::Expr::one());

        // `prev_accum_squared_times_multiplier` holds the product of `prev_accum_squared` and
        // `multiplier` on every real row.
        builder.when(local_prepr.is_real).assert_eq(
            local.prev_accum_squared_times_multiplier,
            local.prev_accum_squared * local.multiplier,
        );

        // On every real row except the first of an instruction, the accumulator equals that
        // product.
        builder
            .when(local_prepr.is_real)
            .when_not(local_prepr.is_first)
            .assert_eq(local.accum, local.prev_accum_squared_times_multiplier);

        // Constrain the accum_squared column.
        builder.when(local_prepr.is_real).assert_eq(local.accum_squared, local.accum * local.accum);

        // Within an instruction, the next row's `prev_accum_squared` is this row's
        // `accum_squared`.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(next.prev_accum_squared, local.accum_squared);

        // Constrain mem write for the result.  In the preprocessed trace the multiplicity is the
        // instruction's `mult` on the `is_last` row and zero on all others.
        builder.send_single(local_prepr.result_mem.addr, local.accum, local_prepr.result_mem.mult);
    }

    /// Returns the `is_real` flag of the given preprocessed row.
    pub const fn do_exp_bit_memory_access<T: Copy>(
        local: &ExpReverseBitsLenPreprocessedCols<T>,
    ) -> T {
        local.is_real
    }
}
302
303impl<AB, const DEGREE: usize> Air<AB> for ExpReverseBitsLenChip<DEGREE>
304where
305    AB: SP1RecursionAirBuilder + PairBuilder,
306{
307    fn eval(&self, builder: &mut AB) {
308        let main = builder.main();
309        let (local, next) = (main.row_slice(0), main.row_slice(1));
310        let local: &ExpReverseBitsLenCols<AB::Var> = (*local).borrow();
311        let next: &ExpReverseBitsLenCols<AB::Var> = (*next).borrow();
312        let prep = builder.preprocessed();
313        let (prep_local, prep_next) = (prep.row_slice(0), prep.row_slice(1));
314        let prep_local: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_local).borrow();
315        let prep_next: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_next).borrow();
316        self.eval_exp_reverse_bits_len::<AB>(builder, local, prep_local, next, prep_next);
317    }
318}
319
320#[cfg(test)]
321mod tests {
322    #![allow(clippy::print_stdout)]
323
324    use crate::{
325        chips::{exp_reverse_bits::ExpReverseBitsLenChip, test_fixtures},
326        linear_program,
327        machine::tests::test_recursion_linear_program,
328        runtime::{instruction as instr, ExecutionRecord},
329        stark::BabyBearPoseidon2Outer,
330        Address, ExpReverseBitsEvent, ExpReverseBitsIo, Instruction, MemAccessKind,
331        RecursionProgram,
332    };
333    use itertools::Itertools;
334    use p3_baby_bear::BabyBear;
335    use p3_field::{AbstractField, PrimeField32};
336    use p3_matrix::dense::RowMajorMatrix;
337    use p3_util::reverse_bits_len;
338    use rand::{rngs::StdRng, Rng, SeedableRng};
339    use sp1_core_machine::utils::setup_logger;
340    use sp1_stark::{air::MachineAir, StarkGenericConfig};
341    use std::iter::once;
342
343    use super::*;
344
345    const DEGREE: usize = 3;
346
347    #[test]
348    fn prove_babybear_circuit_erbl() {
349        setup_logger();
350        type SC = BabyBearPoseidon2Outer;
351        type F = <SC as StarkGenericConfig>::Val;
352
353        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
354        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
355        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
356        let mut random_bit = move || rng.gen_range(0..2);
357        let mut addr = 0;
358
359        let instructions = (1..15)
360            .flat_map(|i| {
361                let base = random_felt();
362                let exponent_bits = vec![random_bit(); i];
363                let exponent = F::from_canonical_u32(
364                    exponent_bits.iter().enumerate().fold(0, |acc, (i, x)| acc + x * (1 << i)),
365                );
366                let result =
367                    base.exp_u64(reverse_bits_len(exponent.as_canonical_u32() as usize, i) as u64);
368
369                let alloc_size = i + 2;
370                let exp_a = (0..i).map(|x| x + addr + 1).collect::<Vec<_>>();
371                let exp_a_clone = exp_a.clone();
372                let x_a = addr;
373                let result_a = addr + alloc_size - 1;
374                addr += alloc_size;
375                let exp_bit_instructions = (0..i).map(move |j| {
376                    instr::mem_single(
377                        MemAccessKind::Write,
378                        1,
379                        exp_a_clone[j] as u32,
380                        F::from_canonical_u32(exponent_bits[j]),
381                    )
382                });
383                once(instr::mem_single(MemAccessKind::Write, 1, x_a as u32, base))
384                    .chain(exp_bit_instructions)
385                    .chain(once(instr::exp_reverse_bits_len(
386                        1,
387                        F::from_canonical_u32(x_a as u32),
388                        exp_a
389                            .into_iter()
390                            .map(|bit| F::from_canonical_u32(bit as u32))
391                            .collect_vec(),
392                        F::from_canonical_u32(result_a as u32),
393                    )))
394                    .chain(once(instr::mem_single(MemAccessKind::Read, 1, result_a as u32, result)))
395            })
396            .collect::<Vec<Instruction<F>>>();
397
398        test_recursion_linear_program(instructions);
399    }
400
401    #[test]
402    fn generate_trace() {
403        type F = BabyBear;
404
405        let shard = ExecutionRecord {
406            exp_reverse_bits_len_events: vec![ExpReverseBitsEvent {
407                base: F::two(),
408                exp: vec![F::zero(), F::one(), F::one()],
409                result: F::two().exp_u64(0b110),
410            }],
411            ..Default::default()
412        };
413        let chip = ExpReverseBitsLenChip::<3>;
414        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());
415        println!("{:?}", trace.values)
416    }
417
418    #[test]
419    fn generate_erbl_preprocessed_trace() {
420        type F = BabyBear;
421
422        let program = linear_program(vec![
423            instr::mem(MemAccessKind::Write, 2, 0, 0),
424            instr::mem(MemAccessKind::Write, 2, 1, 0),
425            Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
426                addrs: ExpReverseBitsIo {
427                    base: Address(F::zero()),
428                    exp: vec![Address(F::one()), Address(F::zero()), Address(F::one())],
429                    result: Address(F::from_canonical_u32(4)),
430                },
431                mult: F::one(),
432            }),
433            instr::mem(MemAccessKind::Read, 1, 4, 0),
434        ])
435        .unwrap();
436
437        let chip = ExpReverseBitsLenChip::<3>;
438        let trace = chip.generate_preprocessed_trace(&program).unwrap();
439        println!("{:?}", trace.values);
440    }
441
442    fn generate_trace_reference<const DEGREE: usize>(
443        input: &ExecutionRecord<BabyBear>,
444        _: &mut ExecutionRecord<BabyBear>,
445    ) -> RowMajorMatrix<BabyBear> {
446        type F = BabyBear;
447
448        let mut overall_rows = Vec::new();
449        input.exp_reverse_bits_len_events.iter().for_each(|event| {
450            let mut rows = vec![vec![F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];
451
452            let mut accum = F::one();
453
454            rows.iter_mut().enumerate().for_each(|(i, row)| {
455                let cols: &mut ExpReverseBitsLenCols<F> = row.as_mut_slice().borrow_mut();
456
457                let prev_accum = accum;
458                accum = prev_accum *
459                    prev_accum *
460                    if event.exp[i] == F::one() { event.base } else { F::one() };
461
462                cols.x = event.base;
463                cols.current_bit = event.exp[i];
464                cols.accum = accum;
465                cols.accum_squared = accum * accum;
466                cols.prev_accum_squared = prev_accum * prev_accum;
467                cols.multiplier = if event.exp[i] == F::one() { event.base } else { F::one() };
468                cols.prev_accum_squared_times_multiplier =
469                    cols.prev_accum_squared * cols.multiplier;
470                if i == event.exp.len() {
471                    assert_eq!(event.result, accum);
472                }
473            });
474
475            overall_rows.extend(rows);
476        });
477
478        pad_rows_fixed(
479            &mut overall_rows,
480            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
481            input.fixed_log2_rows(&ExpReverseBitsLenChip::<DEGREE>),
482        );
483
484        RowMajorMatrix::new(
485            overall_rows.into_iter().flatten().collect(),
486            NUM_EXP_REVERSE_BITS_LEN_COLS,
487        )
488    }
489
490    #[test]
491    fn test_generate_trace() {
492        let shard = test_fixtures::shard();
493        let mut execution_record = test_fixtures::default_execution_record();
494        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
495        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);
496
497        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
498    }
499
500    fn generate_preprocessed_trace_reference(
501        program: &RecursionProgram<BabyBear>,
502    ) -> RowMajorMatrix<BabyBear> {
503        type F = BabyBear;
504
505        let mut rows: Vec<[F; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
506        program
507            .inner
508            .iter()
509            .filter_map(|instruction| match instruction {
510                Instruction::ExpReverseBitsLen(x) => Some(x),
511                _ => None,
512            })
513            .for_each(|instruction| {
514                let ExpReverseBitsInstr { addrs, mult } = instruction;
515                let mut row_add =
516                    vec![[F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]; addrs.exp.len()];
517                row_add.iter_mut().enumerate().for_each(|(i, row)| {
518                    let row: &mut ExpReverseBitsLenPreprocessedCols<F> =
519                        row.as_mut_slice().borrow_mut();
520                    row.iteration_num = F::from_canonical_u32(i as u32);
521                    row.is_first = F::from_bool(i == 0);
522                    row.is_last = F::from_bool(i == addrs.exp.len() - 1);
523                    row.is_real = F::one();
524                    row.x_mem = MemoryAccessCols { addr: addrs.base, mult: -F::from_bool(i == 0) };
525                    row.exponent_mem = MemoryAccessCols { addr: addrs.exp[i], mult: F::neg_one() };
526                    row.result_mem = MemoryAccessCols {
527                        addr: addrs.result,
528                        mult: *mult * F::from_bool(i == addrs.exp.len() - 1),
529                    };
530                });
531                rows.extend(row_add);
532            });
533
534        pad_rows_fixed(
535            &mut rows,
536            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
537            program.fixed_log2_rows(&ExpReverseBitsLenChip::<3>),
538        );
539
540        RowMajorMatrix::new(
541            rows.into_iter().flatten().collect(),
542            NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
543        )
544    }
545
546    #[test]
547    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
548    fn generate_preprocessed_trace() {
549        let program = test_fixtures::program();
550        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
551        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);
552
553        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
554    }
555}