1#![allow(clippy::needless_range_loop)]
2
3use crate::{
4 builder::SP1RecursionAirBuilder, runtime::ExecutionRecord, ExpReverseBitsEvent,
5 ExpReverseBitsInstr, Instruction,
6};
7use core::borrow::Borrow;
8use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
9use p3_baby_bear::BabyBear;
10use p3_field::{AbstractField, PrimeField32};
11use p3_matrix::{dense::RowMajorMatrix, Matrix};
12use sp1_core_machine::utils::pad_rows_fixed;
13use sp1_derive::AlignedBorrow;
14use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder, MachineAir, SP1AirBuilder};
15use std::borrow::BorrowMut;
16use tracing::instrument;
17
18use super::mem::{MemoryAccessCols, MemoryAccessColsChips};
19
20pub const NUM_EXP_REVERSE_BITS_LEN_COLS: usize = core::mem::size_of::<ExpReverseBitsLenCols<u8>>();
21pub const NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS: usize =
22 core::mem::size_of::<ExpReverseBitsLenPreprocessedCols<u8>>();
23
/// Zero-sized marker type identifying the exp-reverse-bits-len chip.
///
/// `DEGREE` is a compile-time parameter; within this file it is only consulted by
/// `eval_exp_reverse_bits_len`, which adds one extra identity constraint when it exceeds 3.
#[derive(Debug, Default, Clone, Copy)]
pub struct ExpReverseBitsLenChip<const DEGREE: usize>;
26
/// One row of the preprocessed trace: per-row data derived from the program alone.
///
/// `generate_preprocessed_trace` emits one such row per exponent bit of every
/// `ExpReverseBitsLen` instruction; padding rows are all zero. The struct is `#[repr(C)]`
/// and is reinterpreted from a flat slice of cells, so the field order defines the column
/// order and must not be changed.
#[derive(AlignedBorrow, Clone, Copy, Debug)]
#[repr(C)]
pub struct ExpReverseBitsLenPreprocessedCols<T: Copy> {
    /// Access to the base `x`: address `addrs.base`; multiplicity is `-1` on the first row
    /// of an instruction and `0` on every other row.
    pub x_mem: MemoryAccessColsChips<T>,
    /// Access to this row's exponent bit: address `addrs.exp[i]`, multiplicity `-1`.
    pub exponent_mem: MemoryAccessColsChips<T>,
    /// Access to the result: address `addrs.result`; multiplicity is the instruction's
    /// `mult` on the last row of an instruction and `0` on every other row.
    pub result_mem: MemoryAccessColsChips<T>,
    /// Index `i` of this row within its instruction (position of the exponent bit).
    pub iteration_num: T,
    /// `1` on the first row of an instruction, `0` otherwise.
    pub is_first: T,
    /// `1` on the last row of an instruction, `0` otherwise.
    pub is_last: T,
    /// `1` on rows generated from an instruction, `0` on padding rows.
    pub is_real: T,
}
38
/// One row of the main trace: the running state of one exponentiation step.
///
/// The struct is `#[repr(C)]` and is reinterpreted from a flat slice of cells, so the
/// field order defines the column order and must not be changed.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExpReverseBitsLenCols<T: Copy> {
    /// The base being exponentiated; constrained to stay equal between consecutive rows of
    /// the same instruction.
    pub x: T,

    /// The exponent bit handled by this row.
    pub current_bit: T,

    /// Square of the accumulator before this row; constrained to equal the previous row's
    /// `accum_squared` within an instruction.
    pub prev_accum_squared: T,

    /// `prev_accum_squared * multiplier`.
    pub prev_accum_squared_times_multiplier: T,

    /// Accumulator after this row. On the first row it is constrained to equal
    /// `multiplier`; on later rows it equals `prev_accum_squared_times_multiplier`.
    pub accum: T,

    /// `accum * accum`.
    pub accum_squared: T,

    /// `x` when `current_bit` is set, `1` otherwise.
    pub multiplier: T,
}
63
impl<F, const DEGREE: usize> BaseAir<F> for ExpReverseBitsLenChip<DEGREE> {
    /// Number of columns in the main trace.
    fn width(&self) -> usize {
        NUM_EXP_REVERSE_BITS_LEN_COLS
    }
}
69
70impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for ExpReverseBitsLenChip<DEGREE> {
71 type Record = ExecutionRecord<F>;
72
73 type Program = crate::RecursionProgram<F>;
74
75 fn name(&self) -> String {
76 "ExpReverseBitsLen".to_string()
77 }
78
79 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
80 }
82
83 fn preprocessed_width(&self) -> usize {
84 NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS
85 }
86
87 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
88 assert!(
89 std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
90 "generate_preprocessed_trace only supports BabyBear field"
91 );
92
93 let mut rows: Vec<[BabyBear; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
94 program
95 .inner
96 .iter()
97 .filter_map(|instruction| match instruction {
98 Instruction::ExpReverseBitsLen(x) => Some(unsafe {
99 std::mem::transmute::<&ExpReverseBitsInstr<F>, &ExpReverseBitsInstr<BabyBear>>(
100 x,
101 )
102 }),
103 _ => None,
104 })
105 .for_each(|instruction: &ExpReverseBitsInstr<BabyBear>| {
106 let ExpReverseBitsInstr { addrs, mult } = instruction;
107 let mut row_add = vec![
108 [BabyBear::zero();
109 NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS];
110 addrs.exp.len()
111 ];
112 row_add.iter_mut().enumerate().for_each(|(i, row)| {
113 let row: &mut ExpReverseBitsLenPreprocessedCols<BabyBear> =
114 row.as_mut_slice().borrow_mut();
115 row.iteration_num = BabyBear::from_canonical_u32(i as u32);
116 row.is_first = BabyBear::from_bool(i == 0);
117 row.is_last = BabyBear::from_bool(i == addrs.exp.len() - 1);
118 row.is_real = BabyBear::one();
119 row.x_mem =
120 MemoryAccessCols { addr: addrs.base, mult: -BabyBear::from_bool(i == 0) };
121 row.exponent_mem =
122 MemoryAccessCols { addr: addrs.exp[i], mult: BabyBear::neg_one() };
123 row.result_mem = MemoryAccessCols {
124 addr: addrs.result,
125 mult: *mult * BabyBear::from_bool(i == addrs.exp.len() - 1),
126 };
127 });
128 rows.extend(row_add);
129 });
130
131 pad_rows_fixed(
133 &mut rows,
134 || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
135 program.fixed_log2_rows(self),
136 );
137
138 let trace = RowMajorMatrix::new(
139 unsafe {
140 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
141 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
142 )
143 },
144 NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
145 );
146 Some(trace)
147 }
148
149 #[instrument(name = "generate exp reverse bits len trace", level = "debug", skip_all, fields(rows = input.exp_reverse_bits_len_events.len()))]
150 fn generate_trace(
151 &self,
152 input: &ExecutionRecord<F>,
153 _: &mut ExecutionRecord<F>,
154 ) -> RowMajorMatrix<F> {
155 assert!(
156 std::any::TypeId::of::<F>() == std::any::TypeId::of::<BabyBear>(),
157 "generate_trace only supports BabyBear field"
158 );
159
160 let events = unsafe {
161 std::mem::transmute::<&Vec<ExpReverseBitsEvent<F>>, &Vec<ExpReverseBitsEvent<BabyBear>>>(
162 &input.exp_reverse_bits_len_events,
163 )
164 };
165 let mut overall_rows = Vec::new();
166
167 events.iter().for_each(|event| {
168 let mut rows =
169 vec![vec![BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];
170 let mut accum = BabyBear::one();
171
172 rows.iter_mut().enumerate().for_each(|(i, row)| {
173 let cols: &mut ExpReverseBitsLenCols<BabyBear> = row.as_mut_slice().borrow_mut();
174 unsafe {
175 crate::sys::exp_reverse_bits_event_to_row_babybear(&event.into(), i, cols);
176 }
177
178 let prev_accum = accum;
179 accum = prev_accum * prev_accum * cols.multiplier;
180
181 cols.accum = accum;
182 cols.accum_squared = accum * accum;
183 cols.prev_accum_squared = prev_accum * prev_accum;
184 cols.prev_accum_squared_times_multiplier =
185 cols.prev_accum_squared * cols.multiplier;
186 });
187 overall_rows.extend(rows);
188 });
189
190 pad_rows_fixed(
192 &mut overall_rows,
193 || [BabyBear::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
194 input.fixed_log2_rows(self),
195 );
196
197 let trace = RowMajorMatrix::new(
199 unsafe {
200 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
201 overall_rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
202 )
203 },
204 NUM_EXP_REVERSE_BITS_LEN_COLS,
205 );
206
207 #[cfg(debug_assertions)]
208 eprintln!(
209 "exp reverse bits len trace dims is width: {:?}, height: {:?}",
210 trace.width(),
211 trace.height()
212 );
213
214 trace
215 }
216
217 fn included(&self, _record: &Self::Record) -> bool {
218 true
219 }
220}
221
impl<const DEGREE: usize> ExpReverseBitsLenChip<DEGREE> {
    /// Emits the chip's constraints and interactions for one pair of adjacent rows.
    ///
    /// `local` / `next` are the current and following main-trace rows; `local_prepr` /
    /// `next_prepr` are the matching preprocessed rows. The statements are kept in their
    /// original order because they define the AIR.
    pub fn eval_exp_reverse_bits_len<
        AB: BaseAirBuilder + ExtensionAirBuilder + SP1RecursionAirBuilder + SP1AirBuilder,
    >(
        &self,
        builder: &mut AB,
        local: &ExpReverseBitsLenCols<AB::Var>,
        local_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
        next: &ExpReverseBitsLenCols<AB::Var>,
        next_prepr: &ExpReverseBitsLenPreprocessedCols<AB::Var>,
    ) {
        // For DEGREE > 3, assert `is_real^DEGREE == is_real^DEGREE`. Both sides are the same
        // product, so the only visible effect is an expression of degree `DEGREE` —
        // presumably to pin the chip's constraint degree; confirm with the machine config.
        if DEGREE > 3 {
            let lhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            let rhs = (0..DEGREE).map(|_| local_prepr.is_real.into()).product::<AB::Expr>();
            builder.assert_eq(lhs, rhs);
        }

        // Interaction for the base `x` at its preprocessed address and multiplicity.
        builder.send_single(local_prepr.x_mem.addr, local.x, local_prepr.x_mem.mult);

        // Within an instruction (not on its last row, and only if the next row is real),
        // `x` must not change from one row to the next.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(local.x, next.x);

        // Interaction for this row's exponent bit.
        builder.send_single(
            local_prepr.exponent_mem.addr,
            local.current_bit,
            local_prepr.exponent_mem.mult,
        );

        // On the first row of an instruction the accumulator equals the multiplier.
        builder.when(local_prepr.is_first).assert_eq(local.accum, local.multiplier);

        // On real rows: multiplier is `x` when the bit is set, and one when it is not.
        builder
            .when(local_prepr.is_real)
            .when(local.current_bit)
            .assert_eq(local.multiplier, local.x);
        builder
            .when(local_prepr.is_real)
            .when_not(local.current_bit)
            .assert_eq(local.multiplier, AB::Expr::one());

        // The product column is consistent with its two factors.
        builder.when(local_prepr.is_real).assert_eq(
            local.prev_accum_squared_times_multiplier,
            local.prev_accum_squared * local.multiplier,
        );

        // On every real row except the first, accum = prev_accum_squared * multiplier.
        builder
            .when(local_prepr.is_real)
            .when_not(local_prepr.is_first)
            .assert_eq(local.accum, local.prev_accum_squared_times_multiplier);

        // accum_squared = accum * accum.
        builder.when(local_prepr.is_real).assert_eq(local.accum_squared, local.accum * local.accum);

        // Within an instruction, the next row's `prev_accum_squared` is this row's
        // `accum_squared`.
        builder
            .when_transition()
            .when(next_prepr.is_real)
            .when_not(local_prepr.is_last)
            .assert_eq(next.prev_accum_squared, local.accum_squared);

        // Interaction for the result: the accumulator at the preprocessed result address
        // and multiplicity.
        builder.send_single(local_prepr.result_mem.addr, local.accum, local_prepr.result_mem.mult);
    }

    /// Returns the row's `is_real` flag.
    pub const fn do_exp_bit_memory_access<T: Copy>(
        local: &ExpReverseBitsLenPreprocessedCols<T>,
    ) -> T {
        local.is_real
    }
}
302
303impl<AB, const DEGREE: usize> Air<AB> for ExpReverseBitsLenChip<DEGREE>
304where
305 AB: SP1RecursionAirBuilder + PairBuilder,
306{
307 fn eval(&self, builder: &mut AB) {
308 let main = builder.main();
309 let (local, next) = (main.row_slice(0), main.row_slice(1));
310 let local: &ExpReverseBitsLenCols<AB::Var> = (*local).borrow();
311 let next: &ExpReverseBitsLenCols<AB::Var> = (*next).borrow();
312 let prep = builder.preprocessed();
313 let (prep_local, prep_next) = (prep.row_slice(0), prep.row_slice(1));
314 let prep_local: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_local).borrow();
315 let prep_next: &ExpReverseBitsLenPreprocessedCols<_> = (*prep_next).borrow();
316 self.eval_exp_reverse_bits_len::<AB>(builder, local, prep_local, next, prep_next);
317 }
318}
319
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        chips::{exp_reverse_bits::ExpReverseBitsLenChip, test_fixtures},
        linear_program,
        machine::tests::test_recursion_linear_program,
        runtime::{instruction as instr, ExecutionRecord},
        stark::BabyBearPoseidon2Outer,
        Address, ExpReverseBitsEvent, ExpReverseBitsIo, Instruction, MemAccessKind,
        RecursionProgram,
    };
    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::{AbstractField, PrimeField32};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_util::reverse_bits_len;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::setup_logger;
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::iter::once;

    use super::*;

    const DEGREE: usize = 3;

    /// End-to-end test: for exponent lengths 1..15, writes a base and its exponent bits to
    /// memory, runs `exp_reverse_bits_len`, and reads back the expected result.
    #[test]
    fn prove_babybear_circuit_erbl() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_bit = move || rng.gen_range(0..2);
        let mut addr = 0;

        let instructions = (1..15)
            .flat_map(|i| {
                let base = random_felt();
                // Draw every bit independently. `vec![random_bit(); i]` would evaluate
                // `random_bit()` once and clone it, so the exponent would always be all
                // zeros or all ones.
                let exponent_bits = (0..i).map(|_| random_bit()).collect::<Vec<_>>();
                let exponent = F::from_canonical_u32(
                    exponent_bits.iter().enumerate().fold(0, |acc, (i, x)| acc + x * (1 << i)),
                );
                let result =
                    base.exp_u64(reverse_bits_len(exponent.as_canonical_u32() as usize, i) as u64);

                // Layout per iteration: [base, bit_0 .. bit_{i-1}, result].
                let alloc_size = i + 2;
                let exp_a = (0..i).map(|x| x + addr + 1).collect::<Vec<_>>();
                let exp_a_clone = exp_a.clone();
                let x_a = addr;
                let result_a = addr + alloc_size - 1;
                addr += alloc_size;
                let exp_bit_instructions = (0..i).map(move |j| {
                    instr::mem_single(
                        MemAccessKind::Write,
                        1,
                        exp_a_clone[j] as u32,
                        F::from_canonical_u32(exponent_bits[j]),
                    )
                });
                once(instr::mem_single(MemAccessKind::Write, 1, x_a as u32, base))
                    .chain(exp_bit_instructions)
                    .chain(once(instr::exp_reverse_bits_len(
                        1,
                        F::from_canonical_u32(x_a as u32),
                        exp_a
                            .into_iter()
                            .map(|bit| F::from_canonical_u32(bit as u32))
                            .collect_vec(),
                        F::from_canonical_u32(result_a as u32),
                    )))
                    .chain(once(instr::mem_single(MemAccessKind::Read, 1, result_a as u32, result)))
            })
            .collect::<Vec<Instruction<F>>>();

        test_recursion_linear_program(instructions);
    }

    /// Smoke test: generates the main trace for a single hand-written event and prints it.
    #[test]
    fn generate_trace() {
        type F = BabyBear;

        let shard = ExecutionRecord {
            exp_reverse_bits_len_events: vec![ExpReverseBitsEvent {
                base: F::two(),
                exp: vec![F::zero(), F::one(), F::one()],
                // Bits are consumed first-to-last with a square-and-multiply step, so
                // [0, 1, 1] yields 2^0b011 (the bit-reversal of 0b110).
                result: F::two().exp_u64(0b011),
            }],
            ..Default::default()
        };
        let chip = ExpReverseBitsLenChip::<3>;
        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }

    /// Smoke test: generates the preprocessed trace for a small program and prints it.
    #[test]
    fn generate_erbl_preprocessed_trace() {
        type F = BabyBear;

        let program = linear_program(vec![
            instr::mem(MemAccessKind::Write, 2, 0, 0),
            instr::mem(MemAccessKind::Write, 2, 1, 0),
            Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
                addrs: ExpReverseBitsIo {
                    base: Address(F::zero()),
                    exp: vec![Address(F::one()), Address(F::zero()), Address(F::one())],
                    result: Address(F::from_canonical_u32(4)),
                },
                mult: F::one(),
            }),
            instr::mem(MemAccessKind::Read, 1, 4, 0),
        ])
        .unwrap();

        let chip = ExpReverseBitsLenChip::<3>;
        let trace = chip.generate_preprocessed_trace(&program).unwrap();
        println!("{:?}", trace.values);
    }

    /// Pure-Rust reference for `generate_trace`: fills every column directly from the
    /// event data, without going through `crate::sys`.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut overall_rows = Vec::new();
        input.exp_reverse_bits_len_events.iter().for_each(|event| {
            let mut rows = vec![vec![F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; event.exp.len()];

            let mut accum = F::one();

            rows.iter_mut().enumerate().for_each(|(i, row)| {
                let cols: &mut ExpReverseBitsLenCols<F> = row.as_mut_slice().borrow_mut();

                let prev_accum = accum;
                accum = prev_accum *
                    prev_accum *
                    if event.exp[i] == F::one() { event.base } else { F::one() };

                cols.x = event.base;
                cols.current_bit = event.exp[i];
                cols.accum = accum;
                cols.accum_squared = accum * accum;
                cols.prev_accum_squared = prev_accum * prev_accum;
                cols.multiplier = if event.exp[i] == F::one() { event.base } else { F::one() };
                cols.prev_accum_squared_times_multiplier =
                    cols.prev_accum_squared * cols.multiplier;
                // NOTE(review): this guard can never be true, because `i` only ranges over
                // `0..event.exp.len()`. The intended check is probably
                // `i + 1 == event.exp.len()`. Left as-is: enabling it requires every fixture
                // event's `result` to match the accumulator — confirm against
                // `test_fixtures` first.
                if i == event.exp.len() {
                    assert_eq!(event.result, accum);
                }
            });

            overall_rows.extend(rows);
        });

        pad_rows_fixed(
            &mut overall_rows,
            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS].to_vec(),
            input.fixed_log2_rows(&ExpReverseBitsLenChip::<DEGREE>),
        );

        RowMajorMatrix::new(
            overall_rows.into_iter().flatten().collect(),
            NUM_EXP_REVERSE_BITS_LEN_COLS,
        )
    }

    /// The chip's main trace must match the pure-Rust reference on the shared fixtures.
    #[test]
    fn test_generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
    }

    /// Reference for `generate_preprocessed_trace` that works on `BabyBear` directly,
    /// with no transmutes.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]> = Vec::new();
        program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::ExpReverseBitsLen(x) => Some(x),
                _ => None,
            })
            .for_each(|instruction| {
                let ExpReverseBitsInstr { addrs, mult } = instruction;
                let mut row_add =
                    vec![[F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS]; addrs.exp.len()];
                row_add.iter_mut().enumerate().for_each(|(i, row)| {
                    let row: &mut ExpReverseBitsLenPreprocessedCols<F> =
                        row.as_mut_slice().borrow_mut();
                    row.iteration_num = F::from_canonical_u32(i as u32);
                    row.is_first = F::from_bool(i == 0);
                    row.is_last = F::from_bool(i == addrs.exp.len() - 1);
                    row.is_real = F::one();
                    row.x_mem = MemoryAccessCols { addr: addrs.base, mult: -F::from_bool(i == 0) };
                    row.exponent_mem = MemoryAccessCols { addr: addrs.exp[i], mult: F::neg_one() };
                    row.result_mem = MemoryAccessCols {
                        addr: addrs.result,
                        mult: *mult * F::from_bool(i == addrs.exp.len() - 1),
                    };
                });
                rows.extend(row_add);
            });

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS],
            program.fixed_log2_rows(&ExpReverseBitsLenChip::<3>),
        );

        RowMajorMatrix::new(
            rows.into_iter().flatten().collect(),
            NUM_EXP_REVERSE_BITS_LEN_PREPROCESSED_COLS,
        )
    }

    /// The chip's preprocessed trace must match the reference on the shared fixture
    /// program. Currently ignored; see the attribute's reason string.
    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = ExpReverseBitsLenChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }
}