1#![allow(clippy::needless_range_loop)]
2
3use crate::{builder::SP1RecursionAirBuilder, runtime::ExecutionRecord};
4use core::borrow::Borrow;
5use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
6use p3_field::{AbstractField, PrimeField32};
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder, MachineAir, SP1AirBuilder};
10
11use super::mem::MemoryAccessColsChips;
12
13#[cfg(feature = "sys")]
14use {
15 super::mem::MemoryAccessCols,
16 crate::{ExpReverseBitsEvent, ExpReverseBitsInstr, Instruction},
17 p3_baby_bear::BabyBear,
18 sp1_core_machine::utils::pad_rows_fixed,
19 std::borrow::BorrowMut,
20 tracing::instrument,
21};
22
23pub const NUM_EXP_REVERSE_BITS_LEN_COLS: usize = core::mem::size_of::<ExpReverseBitsLenCols<u8>>();
24pub const NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS: usize =
25 core::mem::size_of::<ExpReverseBitsLenPreprocessedCols<u8>>();
26
27#[derive(Clone, Debug, Copy, Default)]
28pub struct ExpReverseBitsLenChip<const DEGREE: usize>;
29
/// Preprocessed (program-determined) columns of [`ExpReverseBitsLenChip`].
///
/// The struct is `#[repr(C)]` and derives `AlignedBorrow`, so rows are borrowed directly from flat
/// field-element slices: the field order below *is* the column order.
#[derive(AlignedBorrow, Clone, Copy, Debug)]
#[repr(C)]
pub struct ExpReverseBitsLenPreprocessedCols<T: Copy> {
    /// Memory access for the base `x`. The preprocessed trace generator gives it multiplicity
    /// `-1` on an instruction's first row and `0` on every other row.
    pub x_mem: MemoryAccessColsChips<T>,
    /// Memory access for the exponent bit handled by this row (multiplicity `-1` on real rows).
    pub exponent_mem: MemoryAccessColsChips<T>,
    /// Memory access for the result. Its multiplicity is the instruction's `mult` on the last row
    /// of the instruction and `0` elsewhere.
    pub result_mem: MemoryAccessColsChips<T>,
    /// Index of this row (i.e. of the exponent bit) within its instruction.
    pub iteration_num: T,
    /// `1` on the first row of an instruction, else `0`.
    pub is_first: T,
    /// `1` on the last row of an instruction, else `0`.
    pub is_last: T,
    /// `1` on rows that belong to an instruction, `0` on padding rows.
    pub is_real: T,
}
41
/// Main-trace columns of [`ExpReverseBitsLenChip`]; one row per exponent bit.
///
/// `#[repr(C)]` + `AlignedBorrow`: the field order is the column order. The `sys` trace
/// generator also writes into this struct directly, so the layout must not change.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExpReverseBitsLenCols<T: Copy> {
    /// The base; constrained equal across all rows of one instruction.
    pub x: T,

    /// The exponent bit consumed on this row.
    pub current_bit: T,

    /// Square of the previous row's accumulator (the trace generator uses `1` on a first row).
    pub prev_accum_squared: T,

    /// `prev_accum_squared * multiplier`; kept as its own column to bound constraint degree.
    pub prev_accum_squared_times_multiplier: T,

    /// Running accumulator; on the last row of an instruction it is the result sent to memory.
    pub accum: T,

    /// `accum * accum`; fed to the next row's `prev_accum_squared`.
    pub accum_squared: T,

    /// `x` when `current_bit` is set, `1` otherwise.
    pub multiplier: T,
}
66
67impl<F, const DEGREE: usize> BaseAir<F> for ExpReverseBitsLenChip<DEGREE> {
68 fn width(&self) -> usize {
69 NUM_EXP_REVERSE_BITS_LEN_COLS
70 }
71}
72
73impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for ExpReverseBitsLenChip<DEGREE> {
74 type Record = ExecutionRecord<F>;
75
76 type Program = crate::RecursionProgram<F>;
77
78 fn name(&self) -> String {
79 "ExpReverseBitsLen".to_string()
80 }
81
82 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
83 }
85
86 fn preprocessed_width(&self) -> usize {
87 NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS
88 }
89
90 #[cfg(not(feature = "sys"))]
91 fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
92 unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
93 }
94
95 #[cfg(feature = "sys")]
96 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
97 assert!(
98 std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
99 "generate_preprocessed_trace only supports BabyBear field"
100 );
101
102 let mut rows: Vec<[BabyBear; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
103 program
104 .inner
105 .iter()
106 .filter_map(|instruction| match instruction {
107 Instruction::ExpReverseBitsLen(x) => Some(unsafe {
108 std::mem::transmute::<&ExpReverseBitsInstr<F>, &ExpReverseBitsInstr<BabyBear>>(
109 x,
110 )
111 }),
112 _ => None,
113 })
114 .for_each(|instruction: &ExpReverseBitsInstr<BabyBear>| {
115 let ExpReverseBitsInstr { addrs, mult } = instruction;
116 let mut row_add = vec![
117 [BabyBear::zero();
118 NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS];
119 addrs.exp.len()
120 ];
121 row_add.iter_mut().enumerate().for_each(|(i, row)| {
122 let row: &mut ExpReverseBitsLenPreprocessedCols<BabyBear> =
123 row.as_mut_slice().borrow_mut();
124 row.iteration_num = BabyBear::from_canonical_u32(i as u32);
125 row.is_first = BabyBear::from_bool(i == 0);
126 row.is_last = BabyBear::from_bool(i == addrs.exp.len() - 1);
127 row.is_real = BabyBear::one();
128 row.x_mem =
129 MemoryAccessCols { addr: addrs.base, mult: -BabyBear::from_bool(i == 0) };
130 row.exponent_mem =
131 MemoryAccessCols { addr: addrs.exp[i], mult: BabyBear::neg_one() };
132 row.result_mem = MemoryAccessCols {
133 addr: addrs.result,
134 mult: *mult * BabyBear::from_bool(i == addrs.exp.len() - 1),
135 };
136 });
137 rows.extend(row_add);
138 });
139
140 pad_rows_fixed(
142 &mut rows,
143 || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
144 program.fixed_log2_rows(self),
145 );
146
147 let trace = RowMajorMatrix::new(
148 unsafe {
149 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
150 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
151 )
152 },
153 NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
154 );
155 Some(trace)
156 }
157
158 #[cfg(not(feature = "sys"))]
159 fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
160 unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
161 }
162
163 #[cfg(feature = "sys")]
164 #[instrument(name = "generate exp reverse bits len trace", level = "debug", skip_all, fields(rows = input.exp_reverse_bits_len_events.len()))]
165 fn generate_trace(
166 &self,
167 input: &ExecutionRecord<F>,
168 _: &mut ExecutionRecord<F>,
169 ) -> RowMajorMatrix<F> {
170 assert!(
171 std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
172 "generate_trace only supports BabyBear field"
173 );
174
175 let events = unsafe {
176 std::mem::transmute::<&Vec<ExpReverseBitsEvent<F>>, &Vec<ExpReverseBitsEvent<BabyBear>>>(
177 &input.exp_reverse_bits_len_events,
178 )
179 };
180 let mut overall_rows = Vec::new();
181
182 events.iter().for_each(|event| {
183 let mut rows =
184 vec![vec![BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];
185 let mut accum = BabyBear::one();
186
187 rows.iter_mut().enumerate().for_each(|(i, row)| {
188 let cols: &mut ExpReverseBitsLenCols<BabyBear> = row.as_mut_slice().borrow_mut();
189 unsafe {
190 crate::sys::exp_reverse_bits_event_to_row_babybear(&event.into(), i, cols);
191 }
192
193 let prev_accum = accum;
194 accum = prev_accum * prev_accum * cols.multiplier;
195
196 cols.accum = accum;
197 cols.accum_squared = accum * accum;
198 cols.prev_accum_squared = prev_accum * prev_accum;
199 cols.prev_accum_squared_times_multiplier =
200 cols.prev_accum_squared * cols.multiplier;
201 });
202 overall_rows.extend(rows);
203 });
204
205 pad_rows_fixed(
207 &mut overall_rows,
208 || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
209 input.fixed_log2_rows(self),
210 );
211
212 let trace = RowMajorMatrix::new(
214 unsafe {
215 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
216 overall_rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
217 )
218 },
219 NUM_EXP_REVERSE_BITS_LEN_COLS,
220 );
221
222 #[cfg(debug_assertions)]
223 eprintln!(
224 "exp reverse bits len trace dims is width: {:?}, height: {:?}",
225 trace.width(),
226 trace.height()
227 );
228
229 trace
230 }
231
232 fn included(&self, _record: &Self::Record) -> bool {
233 true
234 }
235}
236
impl<const DEGREE: usize> ExpReverseBitsLenChip<DEGREE> {
    /// Emits the constraints and memory interactions for one pair of adjacent rows.
    ///
    /// Each `ExpReverseBitsLen` instruction occupies one row per exponent bit. Across those rows:
    /// - `multiplier` is `x` when `current_bit` is set and `1` otherwise;
    /// - `accum` equals `multiplier` on the first row, and `prev_accum_squared * multiplier`
    ///   afterwards;
    /// - `prev_accum_squared` is linked to the previous row's `accum_squared`.
    ///
    /// The first bit therefore ends up as the most-significant bit of the applied exponent. The
    /// last row's `accum` is sent to `result_mem`.
    ///
    /// NOTE(review): statements are intentionally left in their original order. Reordering them
    /// presumably changes the constraint/interaction layout — confirm before touching.
    pub fn eval_exp_reverse_bits_len<
        AB: BaseAirBuilder + ExtensionAirBuilder + SP1RecursionAirBuilder + SP1AirBuilder,
    >(
        &self,
        builder: &mut AB,
        local: &ExpReverseBitsLenCols<AB::Var>,
        local_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
        next: &ExpReverseBitsLenCols<AB::Var>,
        next_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
    ) {
        // Tautological constraint `is_real^DEGREE == is_real^DEGREE`, emitted only when
        // DEGREE > 3. NOTE(review): presumably raises the chip's constraint degree to DEGREE.
        if DEGREE > 3 {
            let lhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            let rhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            builder.assert_eq(lhs, rhs);
        }

        // Memory interaction for the base `x`. Its multiplicity comes from the preprocessed
        // `x_mem.mult`, which is nonzero only on an instruction's first row.
        builder.send_single(local_prepr.x_mem.addr, local.x, local_prepr.x_mem.mult);

        // `x` stays constant across the rows of one instruction.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(local.x, next.x);

        // Memory interaction for this row's exponent bit.
        builder.send_single(
            local_prepr.exponent_mem.addr,
            local.current_bit,
            local_prepr.exponent_mem.mult,
        );

        // On the first row the accumulator starts at the multiplier (i.e. from 1).
        builder.when(local_prepr.is_first).assert_eq(local.accum, local.multiplier);

        // `multiplier` is `x` if the bit is set, else 1. `current_bit` is not range-checked
        // here. For a non-boolean bit, the two constraints together force `multiplier == x == 1`.
        builder
            .when(local_prepr.is_real)
            .when(local.current_bit)
            .assert_eq(local.multiplier, local.x);
        builder
            .when(local_prepr.is_real)
            .when_not(local.current_bit)
            .assert_eq(local.multiplier, AB::Expr::one());

        // Materialise `prev_accum_squared * multiplier` in its own column.
        builder.when(local_prepr.is_real).assert_eq(
            local.prev_accum_squared_times_multiplier,
            local.prev_accum_squared * local.multiplier,
        );

        // After the first row: accum = prev_accum^2 * multiplier.
        builder
            .when(local_prepr.is_real)
            .when_not(local_prepr.is_first)
            .assert_eq(local.accum, local.prev_accum_squared_times_multiplier);

        // Materialise accum^2 for use by the next row.
        builder.when(local_prepr.is_real).assert_eq(local.accum_squared, local.accum * local.accum);

        // Link the next row's `prev_accum_squared` to this row's `accum_squared` within one
        // instruction.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(next.prev_accum_squared, local.accum_squared);

        // Memory interaction for the result. `result_mem.mult` is nonzero only on the last row.
        builder.send_single(local_prepr.result_mem.addr, local.accum, local_prepr.result_mem.mult);
    }

    /// Returns the `is_real` flag of a preprocessed row.
    ///
    /// NOTE(review): the name suggests it is meant to gate the exponent-bit memory access; it is
    /// not called within this file — confirm its callers.
    pub const fn do_exp_bit_memory_access<T: Copy>(
        local: &ExpReverseBitsLenPreprocessedCols<T>,
    ) -> T {
        local.is_real
    }
}
317
318impl<AB, const DEGREE: usize> Air<AB> for ExpReverseBitsLenChip<DEGREE>
319where
320 AB: SP1RecursionAirBuilder + PairBuilder,
321{
322 fn eval(&self, builder: &mut AB) {
323 let main = builder.main();
324 let (local, next) = (main.row_slice(0), main.row_slice(1));
325 let local: &ExpReverseBitsLenCols<AB::Var> = (*local).borrow();
326 let next: &ExpReverseBitsLenCols<AB::Var> = (*next).borrow();
327 let prep = builder.preprocessed();
328 let (prep_local, prep_next) = (prep.row_slice(0), prep.row_slice(1));
329 let prep_local: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_local).borrow();
330 let prep_next: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_next).borrow();
331 self.eval_exp_reverse_bits_len::<AB>(builder, local, prep_local, next, prep_next);
332 }
333}
334
#[cfg(all(test, feature = "sys"))]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        chips::{exp_reverse_bits::ExpReverseBitsLenChip, test_fixtures},
        linear_program,
        machine::tests::test_recursion_linear_program,
        runtime::{instruction as instr, ExecutionRecord},
        stark::BabyBearPoseidon2Outer,
        Address, ExpReverseBitsEvent, ExpReverseBitsIo, Instruction, MemAccessKind,
        RecursionProgram,
    };
    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::{AbstractField, PrimeField32};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_util::reverse_bits_len;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::setup_logger;
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::iter::once;

    use super::*;

    const DEGREE: usize = 3;

    /// End-to-end proof test for exponent lengths 1 through 14.
    ///
    /// For each length, the program writes a random base and random exponent bits to memory,
    /// runs `ExpReverseBitsLen`, and reads back `base^reverse_bits_len(exponent, len)`.
    #[test]
    fn prove_babybear_circuit_erbl() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_bit = move || rng.gen_range(0..2);
        let mut addr = 0;

        let instructions = (1..15)
            .flat_map(|i| {
                let base = random_felt();
                // Draw every bit independently. `vec![random_bit(); i]` would call `random_bit()`
                // once and repeat that single bit, so every exponent would be all-0 or all-1.
                let exponent_bits: Vec<u32> = (0..i).map(|_| random_bit()).collect();
                // Little-endian value of the bits: bit `j` has weight `2^j`.
                let exponent = F::from_canonical_u32(
                    exponent_bits.iter().enumerate().fold(0, |acc, (j, x)| acc + x * (1 << j)),
                );
                let result =
                    base.exp_u64(reverse_bits_len(exponent.as_canonical_u32() as usize, i) as u64);

                let alloc_size = i + 2;
                let exp_a = (0..i).map(|x| x + addr + 1).collect::<Vec<_>>();
                let exp_a_clone = exp_a.clone();
                let x_a = addr;
                let result_a = addr + alloc_size - 1;
                addr += alloc_size;
                let exp_bit_instructions = (0..i).map(move |j| {
                    instr::mem_single(
                        MemAccessKind::Write,
                        1,
                        exp_a_clone[j] as u32,
                        F::from_canonical_u32(exponent_bits[j]),
                    )
                });
                once(instr::mem_single(MemAccessKind::Write, 1, x_a as u32, base))
                    .chain(exp_bit_instructions)
                    .chain(once(instr::exp_reverse_bits_len(
                        1,
                        F::from_canonical_u32(x_a as u32),
                        exp_a
                            .into_iter()
                            .map(|bit| F::from_canonical_u32(bit as u32))
                            .collect_vec(),
                        F::from_canonical_u32(result_a as u32),
                    )))
                    .chain(once(instr::mem_single(MemAccessKind::Read, 1, result_a as u32, result)))
            })
            .collect::<Vec<Instruction<F>>>();

        test_recursion_linear_program(instructions);
    }

    /// Generates the main trace for one hand-written event and checks the final accumulator.
    #[test]
    fn generate_trace() {
        type F = BabyBear;

        // The chip consumes `exp[0]` as the most-significant bit. The bits [0, 1, 1] (value
        // 0b110, least-significant bit first) therefore apply the exponent 0b011, not 0b110.
        let shard = ExecutionRecord {
            exp_reverse_bits_len_events: vec![ExpReverseBitsEvent {
                base: F::two(),
                exp: vec![F::zero(), F::one(), F::one()],
                result: F::two().exp_u64(0b011),
            }],
            ..Default::default()
        };
        let chip = ExpReverseBitsLenChip::<3>;
        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        // Row 2 is the last real row of the event; its accumulator must equal the result.
        let last_row: &ExpReverseBitsLenCols<F> = trace.values
            [2 * NUM_EXP_REVERSE_BITS_LEN_COLS..3 * NUM_EXP_REVERSE_BITS_LEN_COLS]
            .borrow();
        assert_eq!(last_row.accum, shard.exp_reverse_bits_len_events[0].result);
    }

    #[test]
    fn generate_erbl_preprocessed_trace() {
        type F = BabyBear;

        let program = linear_program(vec![
            instr::mem(MemAccessKind::Write, 2, 0, 0),
            instr::mem(MemAccessKind::Write, 2, 1, 0),
            Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
                addrs: ExpReverseBitsIo {
                    base: Address(F::zero()),
                    exp: vec![Address(F::one()), Address(F::zero()), Address(F::one())],
                    result: Address(F::from_canonical_u32(4)),
                },
                mult: F::one(),
            }),
            instr::mem(MemAccessKind::Read, 1, 4, 0),
        ])
        .unwrap();

        let chip = ExpReverseBitsLenChip::<3>;
        let trace = chip.generate_preprocessed_trace(&program).unwrap();
        println!("{:?}", trace.values);
    }

    /// Pure-Rust reference for the main trace, used to cross-check the `sys`-backed generator.
    ///
    /// Only mirrors the column values; it does not validate `event.result`. (The previous check
    /// sat behind `i == event.exp.len()`, which never holds inside the row loop.) The final
    /// accumulator is checked against a known result in the `generate_trace` test instead.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut overall_rows = Vec::new();
        input.exp_reverse_bits_len_events.iter().for_each(|event| {
            let mut rows = vec![vec![F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];

            let mut accum = F::one();

            rows.iter_mut().enumerate().for_each(|(i, row)| {
                let cols: &mut ExpReverseBitsLenCols<F> = row.as_mut_slice().borrow_mut();

                // `x` when the current bit is set, `1` otherwise.
                let multiplier = if event.exp[i] == F::one() { event.base } else { F::one() };
                let prev_accum = accum;
                accum = prev_accum * prev_accum * multiplier;

                cols.x = event.base;
                cols.current_bit = event.exp[i];
                cols.accum = accum;
                cols.accum_squared = accum * accum;
                cols.prev_accum_squared = prev_accum * prev_accum;
                cols.multiplier = multiplier;
                cols.prev_accum_squared_times_multiplier =
                    cols.prev_accum_squared * cols.multiplier;
            });

            overall_rows.extend(rows);
        });

        pad_rows_fixed(
            &mut overall_rows,
            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
            input.fixed_log2_rows(&ExpReverseBitsLenChip::<DEGREE>),
        );

        RowMajorMatrix::new(
            overall_rows.into_iter().flatten().collect(),
            NUM_EXP_REVERSE_BITS_LEN_COLS,
        )
    }

    #[test]
    fn test_generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
    }

    /// Pure-Rust reference for the preprocessed trace.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
        program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::ExpReverseBitsLen(x) => Some(x),
                _ => None,
            })
            .for_each(|instruction| {
                let ExpReverseBitsInstr { addrs, mult } = instruction;
                let mut row_add =
                    vec![[F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]; addrs.exp.len()];
                row_add.iter_mut().enumerate().for_each(|(i, row)| {
                    let row: &mut ExpReverseBitsLenPreprocessedCols<F> =
                        row.as_mut_slice().borrow_mut();
                    row.iteration_num = F::from_canonical_u32(i as u32);
                    row.is_first = F::from_bool(i == 0);
                    row.is_last = F::from_bool(i == addrs.exp.len() - 1);
                    row.is_real = F::one();
                    row.x_mem = MemoryAccessCols { addr: addrs.base, mult: -F::from_bool(i == 0) };
                    row.exponent_mem = MemoryAccessCols { addr: addrs.exp[i], mult: F::neg_one() };
                    row.result_mem = MemoryAccessCols {
                        addr: addrs.result,
                        mult: *mult * F::from_bool(i == addrs.exp.len() - 1),
                    };
                });
                rows.extend(row_add);
            });

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
            program.fixed_log2_rows(&ExpReverseBitsLenChip::<DEGREE>),
        );

        RowMajorMatrix::new(
            rows.into_iter().flatten().collect(),
            NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
        )
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }
}