1use super::MemoryChipType;
2use crate::{
3 operations::{AssertLtColsBits, BabyBearBitDecomposition, IsZeroOperation},
4 utils::next_power_of_two,
5};
6use core::{
7 borrow::{Borrow, BorrowMut},
8 mem::size_of,
9};
10
11use p3_air::{Air, AirBuilder, BaseAir};
12use p3_field::{AbstractField, PrimeField32};
13use p3_matrix::{dense::RowMajorMatrix, Matrix};
14use p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
15use sp1_core_executor::{
16 events::{GlobalInteractionEvent, MemoryInitializeFinalizeEvent},
17 ExecutionRecord, Program,
18};
19use sp1_derive::AlignedBorrow;
20use sp1_stark::{
21 air::{
22 AirInteraction, BaseAirBuilder, InteractionScope, MachineAir, PublicValues, SP1AirBuilder,
23 SP1_PROOF_NUM_PV_ELTS,
24 },
25 InteractionKind, Word,
26};
27use std::array;
28
/// Chip for the global memory table.
///
/// One instance handles either the initialize events or the finalize events of an execution
/// record; `kind` selects which.
pub struct MemoryGlobalChip {
    /// Selects whether this chip processes `global_memory_initialize_events` or
    /// `global_memory_finalize_events`.
    pub kind: MemoryChipType,
}
33
34impl MemoryGlobalChip {
35 pub const fn new(kind: MemoryChipType) -> Self {
37 Self { kind }
38 }
39}
40
impl<F> BaseAir<F> for MemoryGlobalChip {
    /// Returns the number of columns in one trace row, i.e. `NUM_MEMORY_INIT_COLS`.
    fn width(&self) -> usize {
        NUM_MEMORY_INIT_COLS
    }
}
46
47impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
48 type Record = ExecutionRecord;
49
50 type Program = Program;
51
52 fn name(&self) -> String {
53 match self.kind {
54 MemoryChipType::Initialize => "MemoryGlobalInit".to_string(),
55 MemoryChipType::Finalize => "MemoryGlobalFinalize".to_string(),
56 }
57 }
58
59 fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
60 let mut memory_events = match self.kind {
61 MemoryChipType::Initialize => input.global_memory_initialize_events.clone(),
62 MemoryChipType::Finalize => input.global_memory_finalize_events.clone(),
63 };
64
65 let is_receive = match self.kind {
66 MemoryChipType::Initialize => false,
67 MemoryChipType::Finalize => true,
68 };
69
70 memory_events.sort_by_key(|event| event.addr);
71
72 let events = memory_events.into_iter().map(|event| {
73 let interaction_shard = if is_receive { event.shard } else { 0 };
74 let interaction_clk = if is_receive { event.timestamp } else { 0 };
75 GlobalInteractionEvent {
76 message: [
77 interaction_shard,
78 interaction_clk,
79 event.addr,
80 (event.value & 255) as u32,
81 ((event.value >> 8) & 255) as u32,
82 ((event.value >> 16) & 255) as u32,
83 ((event.value >> 24) & 255) as u32,
84 ],
85 is_receive,
86 kind: InteractionKind::Memory as u8,
87 }
88 });
89 output.global_interaction_events.extend(events);
90 }
91
92 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
93 let events = match self.kind {
94 MemoryChipType::Initialize => &input.global_memory_initialize_events,
95 MemoryChipType::Finalize => &input.global_memory_finalize_events,
96 };
97 let nb_rows = events.len();
98 let size_log2 = input.fixed_log2_rows::<F, Self>(self);
99 let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
100 Some(padded_nb_rows)
101 }
102
103 fn generate_trace(
104 &self,
105 input: &ExecutionRecord,
106 _output: &mut ExecutionRecord,
107 ) -> RowMajorMatrix<F> {
108 let mut memory_events = match self.kind {
109 MemoryChipType::Initialize => input.global_memory_initialize_events.clone(),
110 MemoryChipType::Finalize => input.global_memory_finalize_events.clone(),
111 };
112
113 let previous_addr_bits = match self.kind {
114 MemoryChipType::Initialize => input.public_values.previous_init_addr_bits,
115 MemoryChipType::Finalize => input.public_values.previous_finalize_addr_bits,
116 };
117
118 memory_events.sort_by_key(|event| event.addr);
119 let mut rows: Vec<[F; NUM_MEMORY_INIT_COLS]> = memory_events
120 .par_iter()
121 .map(|event| {
122 let MemoryInitializeFinalizeEvent { addr, value, shard, timestamp } =
123 event.to_owned();
124
125 let mut row = [F::zero(); NUM_MEMORY_INIT_COLS];
126 let cols: &mut MemoryInitCols<F> = row.as_mut_slice().borrow_mut();
127 cols.addr = F::from_canonical_u32(addr);
128 cols.addr_bits.populate(addr);
129 cols.shard = F::from_canonical_u32(shard);
130 cols.timestamp = F::from_canonical_u32(timestamp);
131 cols.value = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1));
132 cols.is_real = F::one();
133
134 row
135 })
136 .collect::<Vec<_>>();
137
138 for i in 0..memory_events.len() {
139 let addr = memory_events[i].addr;
140 let cols: &mut MemoryInitCols<F> = rows[i].as_mut_slice().borrow_mut();
141 if i == 0 {
142 let prev_addr = previous_addr_bits
143 .iter()
144 .enumerate()
145 .map(|(j, bit)| bit * (1 << j))
146 .sum::<u32>();
147 cols.is_prev_addr_zero.populate(prev_addr);
148 cols.is_first_comp = F::from_bool(prev_addr != 0);
149 if prev_addr != 0 {
150 debug_assert!(prev_addr < addr, "prev_addr {prev_addr} < addr {addr}");
151 let addr_bits: [_; 32] = array::from_fn(|i| (addr >> i) & 1);
152 cols.lt_cols.populate(&previous_addr_bits, &addr_bits);
153 }
154 }
155 if i != 0 {
156 cols.is_next_comp = F::one();
157 let previous_addr = memory_events[i - 1].addr;
158 assert_ne!(previous_addr, addr);
159
160 let addr_bits: [_; 32] = array::from_fn(|i| (addr >> i) & 1);
161 let prev_addr_bits: [_; 32] = array::from_fn(|i| (previous_addr >> i) & 1);
162 cols.lt_cols.populate(&prev_addr_bits, &addr_bits);
163 }
164
165 if i == memory_events.len() - 1 {
166 cols.is_last_addr = F::one();
167 }
168 }
169
170 rows.resize(
172 <MemoryGlobalChip as MachineAir<F>>::num_rows(self, input).unwrap(),
173 [F::zero(); NUM_MEMORY_INIT_COLS],
174 );
175
176 RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEMORY_INIT_COLS)
177 }
178
179 fn included(&self, shard: &Self::Record) -> bool {
180 if let Some(shape) = shard.shape.as_ref() {
181 shape.included::<F, _>(self)
182 } else {
183 match self.kind {
184 MemoryChipType::Initialize => !shard.global_memory_initialize_events.is_empty(),
185 MemoryChipType::Finalize => !shard.global_memory_finalize_events.is_empty(),
186 }
187 }
188 }
189
190 fn commit_scope(&self) -> InteractionScope {
191 InteractionScope::Local
192 }
193}
194
/// Column layout of one row of the global memory initialize/finalize trace.
///
/// `#[repr(C)]` plus `AlignedBorrow` let a flat row slice be reinterpreted as this struct.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct MemoryInitCols<T: Copy> {
    /// Shard of the memory event (copied from the event's `shard`).
    pub shard: T,

    /// Timestamp of the memory event (copied from the event's `timestamp`).
    pub timestamp: T,

    /// Address of the memory event.
    pub addr: T,

    /// Witness for the bitwise less-than assertion between the preceding address (the previous
    /// row's address, or the public-value address on the first row) and this row's address.
    pub lt_cols: AssertLtColsBits<T, 32>,

    /// Bit decomposition of `addr`.
    pub addr_bits: BabyBearBitDecomposition<T>,

    /// The 32 bits of the event's value, least-significant bit first.
    pub value: [T; 32],

    /// One for rows generated from an event, zero for padding rows.
    pub is_real: T,

    /// One on every real row except the first; set when this row's address is compared against
    /// the previous row's address.
    pub is_next_comp: T,

    /// Is-zero witness for the previous address taken from the public values; populated on the
    /// first row only.
    pub is_prev_addr_zero: IsZeroOperation<T>,

    /// On the first row, one iff the previous address from the public values is non-zero, i.e.
    /// the first row's address is compared against it.
    pub is_first_comp: T,

    /// One on the last real row of the trace.
    pub is_last_addr: T,
}
231
/// Number of columns in a `MemoryInitCols` row, computed as the byte size of the struct when
/// instantiated with `u8` cells.
// NOTE(review): this equals the column count only if every nested operation struct is itself
// built purely from `T` cells with no padding — confirm against those definitions.
pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();
233
impl<AB> Air<AB> for MemoryGlobalChip
where
    AB: SP1AirBuilder,
{
    /// Emits the constraints of the global memory table over the current row (`local`) and the
    /// following row (`next`).
    ///
    /// In order: booleanity of `is_real` and the value bits, the global interaction for the row,
    /// the address range check, the row-to-row address comparison, the first-row comparison
    /// against the previous address from the public values, and the match between the last real
    /// address and the last-address public values.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &MemoryInitCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next: &MemoryInitCols<AB::Var> = (*next).borrow();

        // `is_real` and each of the 32 value bits must be boolean.
        builder.assert_bool(local.is_real);
        for i in 0..32 {
            builder.assert_bool(local.value[i]);
        }

        // Recompose the value bits into four bytes; bit `i` of each byte has weight `2^i`.
        let mut byte1 = AB::Expr::zero();
        let mut byte2 = AB::Expr::zero();
        let mut byte3 = AB::Expr::zero();
        let mut byte4 = AB::Expr::zero();
        for i in 0..8 {
            byte1 = byte1.clone() + local.value[i].into() * AB::F::from_canonical_u8(1 << i);
            byte2 = byte2.clone() + local.value[i + 8].into() * AB::F::from_canonical_u8(1 << i);
            byte3 = byte3.clone() + local.value[i + 16].into() * AB::F::from_canonical_u8(1 << i);
            byte4 = byte4.clone() + local.value[i + 24].into() * AB::F::from_canonical_u8(1 << i);
        }
        let value = [byte1, byte2, byte3, byte4];

        // Each row emits one `Global` interaction with multiplicity `is_real`. The message is
        // [shard, timestamp, addr, value bytes 0..4, flag, flag, interaction kind].
        // NOTE(review): the two flag slots are (1, 0) for Initialize and (0, 1) for Finalize;
        // `generate_dependencies` marks Initialize as `is_receive = false` and Finalize as
        // `is_receive = true`, so these are presumably the send/receive indicators — confirm
        // against the consumer of `InteractionKind::Global`.
        if self.kind == MemoryChipType::Initialize {
            // Initialize rows use a zero shard and zero timestamp in the message.
            builder.send(
                AirInteraction::new(
                    vec![
                        AB::Expr::zero(),
                        AB::Expr::zero(),
                        local.addr.into(),
                        value[0].clone(),
                        value[1].clone(),
                        value[2].clone(),
                        value[3].clone(),
                        AB::Expr::one(),
                        AB::Expr::zero(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        } else {
            // Finalize rows carry the row's own shard and timestamp in the message.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.shard.into(),
                        local.timestamp.into(),
                        local.addr.into(),
                        value[0].clone(),
                        value[1].clone(),
                        value[2].clone(),
                        value[3].clone(),
                        AB::Expr::zero(),
                        AB::Expr::one(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        }

        // Range-check `addr` through its bit decomposition on real rows.
        BabyBearBitDecomposition::<AB::F>::range_check(
            builder,
            local.addr,
            local.addr_bits,
            local.is_real.into(),
        );

        // On transitions, the next row performs the address comparison exactly when it is real;
        // the comparison gadget is evaluated on (local address bits, next address bits).
        // NOTE(review): presumably this enforces `local.addr < next.addr` — confirm against
        // `AssertLtColsBits::eval`.
        builder.when_transition().assert_eq(next.is_next_comp, next.is_real);
        next.lt_cols.eval(builder, &local.addr_bits.bits, &next.addr_bits.bits, next.is_next_comp);

        // A non-real row may not be followed by a real row, so real rows form a prefix.
        builder.when_transition().when_not(local.is_real).assert_zero(next.is_real);

        let local_addr_bits = local.addr_bits.bits;

        // Reinterpret the flat public values as the structured `PublicValues` type.
        let public_values_array: [AB::Expr; SP1_PROOF_NUM_PV_ELTS] =
            array::from_fn(|i| builder.public_values()[i].into());
        let public_values: &PublicValues<Word<AB::Expr>, AB::Expr> =
            public_values_array.as_slice().borrow();

        // Previous-address bits for this chip kind, taken from the public values.
        let prev_addr_bits = match self.kind {
            MemoryChipType::Initialize => &public_values.previous_init_addr_bits,
            MemoryChipType::Finalize => &public_values.previous_finalize_addr_bits,
        };

        // Recompose the previous address from its bits (bit `i` has weight `2^i`).
        let prev_addr = prev_addr_bits
            .iter()
            .enumerate()
            .map(|(i, bit)| bit.clone() * AB::F::from_wrapped_u32(1 << i))
            .sum::<AB::Expr>();

        // The is-zero check on the previous address is only active on the first row.
        let is_first_row = builder.is_first_row();
        IsZeroOperation::<AB::F>::eval(builder, prev_addr, local.is_prev_addr_zero, is_first_row);

        // `is_first_comp` is boolean and, on the first row, equals "previous address is non-zero".
        builder.assert_bool(local.is_first_comp);
        builder
            .when_first_row()
            .assert_eq(local.is_first_comp, AB::Expr::one() - local.is_prev_addr_zero.result);

        // The first row must be real, so the table has at least one real row.
        builder.when_first_row().assert_one(local.is_real);

        // Comparison gadget on (previous address bits, first-row address bits), active when
        // `is_first_comp` is set.
        local.lt_cols.eval(builder, prev_addr_bits, &local_addr_bits, local.is_first_comp);

        // When the previous address is zero on the first row: the first address must itself be
        // zero, the second row must be real, and the second row must perform the comparison.
        builder.when_first_row().when(local.is_prev_addr_zero.result).assert_zero(local.addr);
        builder.when_first_row().when(local.is_prev_addr_zero.result).assert_one(next.is_real);
        builder.when_first_row().when(local.is_prev_addr_zero.result).assert_one(next.is_next_comp);

        // On the first row with `is_first_comp == 0` (the case where `addr` is forced to zero
        // above), every value bit must be zero.
        // NOTE(review): presumably this pins the entry at address 0 to the value 0 — confirm the
        // intended meaning of address 0 with the memory model.
        for i in 0..32 {
            builder.when_first_row().when_not(local.is_first_comp).assert_zero(local.value[i]);
        }

        // Last-address bits for this chip kind, taken from the public values.
        let last_addr_bits = match self.kind {
            MemoryChipType::Initialize => &public_values.last_init_addr_bits,
            MemoryChipType::Finalize => &public_values.last_finalize_addr_bits,
        };
        // On transitions, `is_last_addr` marks a real row whose successor is not real.
        builder
            .when_transition()
            .assert_eq(local.is_last_addr, local.is_real * (AB::Expr::one() - next.is_real));

        // The last real address must match the public last-address bits. Two cases are covered:
        // the final row of the trace being real, and a transition row flagged `is_last_addr`.
        for (local_bit, pub_bit) in local.addr_bits.bits.iter().zip(last_addr_bits.iter()) {
            builder.when_last_row().when(local.is_real).assert_eq(*local_bit, pub_bit.clone());
            builder
                .when_transition()
                .when(local.is_last_addr)
                .assert_eq(*local_bit, pub_bit.clone());
        }
    }
}
428
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;
    use crate::{
        programs::tests::*, riscv::RiscvAir,
        syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
    };
    use p3_baby_bear::BabyBear;
    use sp1_core_executor::Executor;
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, debug_interactions_with_all_chips, InteractionKind,
        SP1CoreOpts, StarkMachine,
    };

    /// Smoke test: runs the simple program, generates the initialize and finalize traces, and
    /// prints them together with the finalize events.
    #[test]
    fn test_memory_generate_trace() {
        let mut runtime = Executor::new(simple_program(), SP1CoreOpts::default());
        runtime.run().unwrap();
        let record = runtime.record.clone();

        let init_chip: MemoryGlobalChip = MemoryGlobalChip::new(MemoryChipType::Initialize);
        let init_trace: RowMajorMatrix<BabyBear> =
            init_chip.generate_trace(&record, &mut ExecutionRecord::default());
        println!("{:?}", init_trace.values);

        let finalize_chip: MemoryGlobalChip = MemoryGlobalChip::new(MemoryChipType::Finalize);
        let finalize_trace: RowMajorMatrix<BabyBear> =
            finalize_chip.generate_trace(&record, &mut ExecutionRecord::default());
        println!("{:?}", finalize_trace.values);

        for mem_event in record.global_memory_finalize_events {
            println!("{mem_event:?}");
        }
    }

    /// Runs the SHA-extend program and debugs the `Memory` interactions, first per shard in the
    /// local scope and then across all shards in the global scope.
    #[test]
    fn test_memory_lookup_interactions() {
        setup_logger();
        let program = sha_extend_program();
        let mut runtime = Executor::new(program.clone(), SP1CoreOpts::default());
        runtime.run().unwrap();

        let machine: StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
            RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program);
        let opts = SP1CoreOpts::default();

        // Dependencies are generated on an unboxed copy of the records.
        let mut dependency_records =
            runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>();
        machine.generate_dependencies(&mut dependency_records, &opts, None);

        let shards = runtime.records;
        for shard in shards.clone() {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[*shard],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }

        let all_shards = shards.into_iter().map(|r| *r).collect::<Vec<_>>();
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &all_shards,
            vec![InteractionKind::Memory],
            InteractionScope::Global,
        );
    }

    /// Runs the SHA-extend program and debugs the `Byte` interactions across all shards in the
    /// global scope.
    #[test]
    fn test_byte_lookup_interactions() {
        setup_logger();
        let program = sha_extend_program();
        let mut runtime = Executor::new(program.clone(), SP1CoreOpts::default());
        runtime.run().unwrap();

        let machine = RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program);
        let opts = SP1CoreOpts::default();

        // Dependencies are generated on an unboxed copy of the records.
        let mut dependency_records =
            runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>();
        machine.generate_dependencies(&mut dependency_records, &opts, None);

        let all_shards = runtime.records.into_iter().map(|r| *r).collect::<Vec<_>>();
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &all_shards,
            vec![InteractionKind::Byte],
            InteractionScope::Global,
        );
    }
}
529}