// sp1_core_machine/air/memory.rs
use std::iter::once;
use p3_air::AirBuilder;
use p3_field::AbstractField;
use sp1_core_executor::ByteOpcode;
use sp1_stark::{
air::{AirInteraction, BaseAirBuilder, ByteAirBuilder, InteractionScope},
InteractionKind,
};
use crate::memory::{MemoryAccessCols, MemoryCols};
/// Builder extension that provides the constraints for memory accesses: the timestamp ordering
/// check, the 24-bit range check it relies on, and the interactions added to the memory
/// argument.
pub trait MemoryAirBuilder: BaseAirBuilder {
    /// Constrain a memory read or write.
    ///
    /// This method verifies that the timestamp `(shard, clk)` of the current access is greater
    /// than the timestamp of the previous access, and adds the access to the memory argument.
    ///
    /// # Parameters
    ///
    /// - `shard`, `clk`: timestamp of the current access.
    /// - `addr`: address that is being accessed.
    /// - `memory_access`: columns providing the previous timestamp, the previous value and the
    ///   current value of the accessed location.
    /// - `do_check`: selector for the access. It is constrained to be boolean here and is used
    ///   both to gate the timestamp check and as the multiplicity of the two interactions.
    fn eval_memory_access<E: Into<Self::Expr> + Clone>(
        &mut self,
        shard: impl Into<Self::Expr>,
        clk: impl Into<Self::Expr>,
        addr: impl Into<Self::Expr>,
        memory_access: &impl MemoryCols<E>,
        do_check: impl Into<Self::Expr>,
    ) {
        let do_check: Self::Expr = do_check.into();
        let shard: Self::Expr = shard.into();
        let clk: Self::Expr = clk.into();
        let addr: Self::Expr = addr.into();
        let mem_access = memory_access.access();

        self.assert_bool(do_check.clone());

        // Verify that the current memory access time is greater than the previous one's.
        self.eval_memory_access_timestamp(mem_access, do_check.clone(), shard.clone(), clk.clone());

        // Build the two tuples for the memory argument. Both have the layout
        // `(shard, clk, addr, value...)`.
        let prev_shard: Self::Expr = mem_access.prev_shard.clone().into();
        let prev_clk: Self::Expr = mem_access.prev_clk.clone().into();
        let prev_values = once(prev_shard)
            .chain(once(prev_clk))
            .chain(once(addr.clone()))
            .chain(memory_access.prev_value().clone().map(Into::into))
            .collect();
        let current_values = once(shard)
            .chain(once(clk))
            .chain(once(addr))
            .chain(memory_access.value().clone().map(Into::into))
            .collect();

        // The previous state is sent with multiplicity `do_check` ...
        self.send(
            AirInteraction::new(prev_values, do_check.clone(), InteractionKind::Memory),
            InteractionScope::Local,
        );

        // ... and the current state is received with the same multiplicity.
        self.receive(
            AirInteraction::new(current_values, do_check, InteractionKind::Memory),
            InteractionScope::Local,
        );
    }

    /// Constrains a memory read or write for every element of a slice of memory access columns.
    ///
    /// Element `i` of `memory_access_slice` is constrained with [`Self::eval_memory_access`] at
    /// address `initial_addr + 4 * i`; `shard`, `clk` and `verify_memory_access` (the `do_check`
    /// selector) are shared by all elements.
    fn eval_memory_access_slice<E: Into<Self::Expr> + Copy>(
        &mut self,
        shard: impl Into<Self::Expr> + Copy,
        clk: impl Into<Self::Expr> + Clone,
        initial_addr: impl Into<Self::Expr> + Clone,
        memory_access_slice: &[impl MemoryCols<E>],
        verify_memory_access: impl Into<Self::Expr> + Copy,
    ) {
        for (index, access) in memory_access_slice.iter().enumerate() {
            // Consecutive accesses are 4 address units apart (presumably one 32-bit word each;
            // confirm against the `MemoryCols` value width).
            let base: Self::Expr = initial_addr.clone().into();
            let addr = base + Self::Expr::from_canonical_usize(index * 4);
            self.eval_memory_access(shard, clk.clone(), addr, access, verify_memory_access);
        }
    }

    /// Verifies the memory access timestamp.
    ///
    /// This method verifies that the current memory access happened after the previous one.
    /// If the current and previous access are in the same shard (`compare_clk == 1`), the
    /// current `clk` must be greater than the previous `clk`. Otherwise the current `shard` must
    /// be greater than the previous `shard`.
    ///
    /// # Parameters
    ///
    /// - `mem_access`: columns holding `prev_shard`, `prev_clk`, the `compare_clk` flag and the
    ///   two limbs (`diff_16bit_limb`, `diff_8bit_limb`) of the timestamp difference.
    /// - `do_check`: selector; all constraints added here are gated by it.
    /// - `shard`, `clk`: timestamp of the current access.
    fn eval_memory_access_timestamp(
        &mut self,
        mem_access: &MemoryAccessCols<impl Into<Self::Expr> + Clone>,
        do_check: impl Into<Self::Expr>,
        shard: impl Into<Self::Expr> + Clone,
        clk: impl Into<Self::Expr>,
    ) {
        let do_check: Self::Expr = do_check.into();
        let compare_clk: Self::Expr = mem_access.compare_clk.clone().into();
        let shard: Self::Expr = shard.into();
        let clk: Self::Expr = clk.into();
        let prev_shard: Self::Expr = mem_access.prev_shard.clone().into();

        // First verify that compare_clk's value is correct: it must be a boolean, and it may
        // only be set when both accesses are in the same shard.
        self.when(do_check.clone()).assert_bool(compare_clk.clone());
        self.when(do_check.clone()).when(compare_clk.clone()).assert_eq(shard.clone(), prev_shard);

        // Get the comparison timestamp values for the current and previous memory access:
        // the clk values when `compare_clk` is set, the shard values otherwise.
        let prev_comp_value = self.if_else(
            mem_access.compare_clk.clone(),
            mem_access.prev_clk.clone(),
            mem_access.prev_shard.clone(),
        );
        let current_comp_val = self.if_else(compare_clk, clk, shard);

        // Assert `current_comp_val > prev_comp_value`. We check this by asserting that
        // `0 <= current_comp_val - prev_comp_value - 1 < 2^24`.
        //
        // The equivalence of these statements comes from the fact that if
        // `current_comp_val <= prev_comp_value`, then `current_comp_val - prev_comp_value - 1 < 0`
        // and will underflow in the prime field, resulting in a value that is `>= 2^24` as long
        // as both `current_comp_val, prev_comp_value` are range-checked to be `< 2^24` and as
        // long as we're working in a field larger than `2 * 2^24` (which is true of the BabyBear
        // and Mersenne31 prime).
        let diff_minus_one = current_comp_val - prev_comp_value - Self::Expr::one();

        // Verify that `diff_minus_one = mem_access.diff_16bit_limb
        // + mem_access.diff_8bit_limb * 2^16` and that both limbs are in range.
        self.eval_range_check_24bits(
            diff_minus_one,
            mem_access.diff_16bit_limb.clone(),
            mem_access.diff_8bit_limb.clone(),
            do_check,
        );
    }

    /// Verifies that the input value is within 24 bits.
    ///
    /// This method verifies that `value` is less than 2^24 by constraining
    /// `value = limb_16 + limb_8 * 2^16` and sending a 16-bit range check for `limb_16` and an
    /// 8-bit range check for `limb_8`. This method is needed since the memory access timestamp
    /// check (see [`Self::eval_memory_access_timestamp`]) needs to assume the clk is within 24
    /// bits.
    ///
    /// `do_check` gates the decomposition constraint and is the multiplicity of both range-check
    /// lookups.
    fn eval_range_check_24bits(
        &mut self,
        value: impl Into<Self::Expr>,
        limb_16: impl Into<Self::Expr> + Clone,
        limb_8: impl Into<Self::Expr> + Clone,
        do_check: impl Into<Self::Expr> + Clone,
    ) {
        // Verify that value = limb_16 + limb_8 * 2^16.
        let low: Self::Expr = limb_16.clone().into();
        let high: Self::Expr = limb_8.clone().into();
        let recomposed = low + high * Self::Expr::from_canonical_u32(1 << 16);
        self.when(do_check.clone()).assert_eq(value, recomposed);

        // Send the range checks for the limbs. The operand slot each limb occupies (second
        // argument for `U16Range`, fourth for `U8Range`) is kept exactly as the byte lookup
        // expects it; see `ByteAirBuilder::send_byte` for the slot semantics.
        self.send_byte(
            Self::Expr::from_canonical_u8(ByteOpcode::U16Range as u8),
            limb_16,
            Self::Expr::zero(),
            Self::Expr::zero(),
            do_check.clone(),
        );
        self.send_byte(
            Self::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
            Self::Expr::zero(),
            Self::Expr::zero(),
            limb_8,
            do_check,
        );
    }
}