use std::array;
use std::iter::once;
use itertools::Itertools;
use p3_air::{AirBuilder, FilteredAirBuilder};
use p3_air::{AirBuilderWithPublicValues, PermutationAirBuilder};
use p3_field::{AbstractField, Field};
use p3_uni_stark::StarkGenericConfig;
use p3_uni_stark::{ProverConstraintFolder, SymbolicAirBuilder, VerifierConstraintFolder};
use super::interaction::AirInteraction;
use super::word::Word;
use super::{BinomialExtension, WORD_SIZE};
use crate::cpu::columns::InstructionCols;
use crate::cpu::columns::OpcodeSelectorCols;
use crate::lookup::InteractionKind;
use crate::memory::MemoryAccessCols;
use crate::{bytes::ByteOpcode, memory::MemoryCols};
/// A sink for interaction messages of type `M` emitted while evaluating an AIR.
///
/// Builders that do not track interactions implement [`EmptyMessageBuilder`] and get a
/// no-op implementation of this trait for every `M` via the blanket impl in this module.
pub trait MessageBuilder<M> {
    /// Records `message` as sent by the current chip.
    fn send(&mut self, message: M);
    /// Records `message` as received by the current chip.
    fn receive(&mut self, message: M);
}
/// Every [`EmptyMessageBuilder`] accepts messages of any type and silently drops them.
impl<AB, M> MessageBuilder<M> for AB
where
    AB: EmptyMessageBuilder,
{
    fn send(&mut self, _: M) {}

    fn receive(&mut self, _: M) {}
}
/// Marker for [`AirBuilder`]s that ignore interactions: implementors get a `MessageBuilder<M>`
/// whose `send`/`receive` do nothing, for every message type `M`.
pub trait EmptyMessageBuilder: AirBuilder {}
/// Core constraint helpers shared by all SP1 AIR builders.
///
/// Any [`AirBuilder`] that can send and receive [`AirInteraction`]s over its own expression
/// type gets these helpers through a blanket impl in this module.
pub trait BaseAirBuilder: AirBuilder + MessageBuilder<AirInteraction<Self::Expr>> {
    /// Returns a builder whose constraints are filtered by `condition != 1`
    /// (delegates to `when_ne(condition, 1)`).
    ///
    /// For a boolean `condition` this reads as "when `condition` is false".
    fn when_not<I: Into<Self::Expr>>(&mut self, condition: I) -> FilteredAirBuilder<Self> {
        let one = Self::F::one();
        self.when_ne(condition, one)
    }

    /// Asserts `left[i] == right[i]` for every position `i`.
    ///
    /// # Panics
    /// Panics (via `Itertools::zip_eq`) if the two iterators yield different numbers of items.
    fn assert_all_eq<I1: Into<Self::Expr>, I2: Into<Self::Expr>>(
        &mut self,
        left: impl IntoIterator<Item = I1>,
        right: impl IntoIterator<Item = I2>,
    ) {
        left.into_iter()
            .zip_eq(right)
            .for_each(|(lhs, rhs)| self.assert_eq(lhs, rhs));
    }

    /// Asserts that every expression yielded by `iter` is zero.
    fn assert_all_zero<I: Into<Self::Expr>>(&mut self, iter: impl IntoIterator<Item = I>) {
        for expr in iter {
            self.assert_zero(expr);
        }
    }

    /// Arithmetic select: returns `condition * a + (1 - condition) * b`.
    ///
    /// This yields `a` when `condition == 1` and `b` when `condition == 0`. Constraining
    /// `condition` to be boolean is the caller's responsibility.
    #[inline]
    fn if_else(
        &mut self,
        condition: impl Into<Self::Expr> + Clone,
        a: impl Into<Self::Expr> + Clone,
        b: impl Into<Self::Expr> + Clone,
    ) -> Self::Expr {
        let selector: Self::Expr = condition.into();
        let inverse = Self::Expr::one() - selector.clone();
        selector * a.into() + inverse * b.into()
    }

    /// Returns `sum_i array[i] * index_bitmap[i]`.
    ///
    /// With a one-hot `index_bitmap`, this selects the element at the hot position.
    ///
    /// # Panics
    /// Panics (via `Itertools::zip_eq`) if the two slices differ in length.
    fn index_array(
        &mut self,
        array: &[impl Into<Self::Expr> + Clone],
        index_bitmap: &[impl Into<Self::Expr> + Clone],
    ) -> Self::Expr {
        let mut acc = Self::Expr::zero();
        array.iter().zip_eq(index_bitmap).for_each(|(value, bit)| {
            acc += value.clone().into() * bit.clone().into();
        });
        acc
    }
}
/// Helpers for interactions on the byte lookup ([`InteractionKind::Byte`]).
///
/// Every message has the layout `[opcode, a1, a2, b, c, shard, channel]`. The single-byte
/// variants fill `a2` with zero.
pub trait ByteAirBuilder: BaseAirBuilder {
    /// Sends the byte lookup `[opcode, a, 0, b, c, shard, channel]` with `multiplicity`.
    // One argument per lookup-message field, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    fn send_byte(
        &mut self,
        opcode: impl Into<Self::Expr>,
        a: impl Into<Self::Expr>,
        b: impl Into<Self::Expr>,
        c: impl Into<Self::Expr>,
        shard: impl Into<Self::Expr>,
        channel: impl Into<Self::Expr>,
        multiplicity: impl Into<Self::Expr>,
    ) {
        let zero_a2 = Self::Expr::zero();
        self.send_byte_pair(opcode, a, zero_a2, b, c, shard, channel, multiplicity);
    }

    /// Sends the byte lookup `[opcode, a1, a2, b, c, shard, channel]` with `multiplicity`.
    // One argument per lookup-message field, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    fn send_byte_pair(
        &mut self,
        opcode: impl Into<Self::Expr>,
        a1: impl Into<Self::Expr>,
        a2: impl Into<Self::Expr>,
        b: impl Into<Self::Expr>,
        c: impl Into<Self::Expr>,
        shard: impl Into<Self::Expr>,
        channel: impl Into<Self::Expr>,
        multiplicity: impl Into<Self::Expr>,
    ) {
        let fields: Vec<Self::Expr> = vec![
            opcode.into(),
            a1.into(),
            a2.into(),
            b.into(),
            c.into(),
            shard.into(),
            channel.into(),
        ];
        let interaction = AirInteraction::new(fields, multiplicity.into(), InteractionKind::Byte);
        self.send(interaction);
    }

    /// Receives the byte lookup `[opcode, a, 0, b, c, shard, channel]` with `multiplicity`.
    // One argument per lookup-message field, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    fn receive_byte(
        &mut self,
        opcode: impl Into<Self::Expr>,
        a: impl Into<Self::Expr>,
        b: impl Into<Self::Expr>,
        c: impl Into<Self::Expr>,
        shard: impl Into<Self::Expr>,
        channel: impl Into<Self::Expr>,
        multiplicity: impl Into<Self::Expr>,
    ) {
        let zero_a2 = Self::Expr::zero();
        self.receive_byte_pair(opcode, a, zero_a2, b, c, shard, channel, multiplicity);
    }

    /// Receives the byte lookup `[opcode, a1, a2, b, c, shard, channel]` with `multiplicity`.
    // One argument per lookup-message field, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    fn receive_byte_pair(
        &mut self,
        opcode: impl Into<Self::Expr>,
        a1: impl Into<Self::Expr>,
        a2: impl Into<Self::Expr>,
        b: impl Into<Self::Expr>,
        c: impl Into<Self::Expr>,
        shard: impl Into<Self::Expr>,
        channel: impl Into<Self::Expr>,
        multiplicity: impl Into<Self::Expr>,
    ) {
        let fields: Vec<Self::Expr> = vec![
            opcode.into(),
            a1.into(),
            a2.into(),
            b.into(),
            c.into(),
            shard.into(),
            channel.into(),
        ];
        let interaction = AirInteraction::new(fields, multiplicity.into(), InteractionKind::Byte);
        self.receive(interaction);
    }
}
/// Constraint helpers for [`Word`]s (arrays of `WORD_SIZE` limbs), plus range checks
/// that are routed through the byte lookup.
pub trait WordAirBuilder: ByteAirBuilder {
    /// Asserts limb-wise equality of two words.
    fn assert_word_eq(
        &mut self,
        left: Word<impl Into<Self::Expr>>,
        right: Word<impl Into<Self::Expr>>,
    ) {
        left.0
            .into_iter()
            .zip(right.0)
            .for_each(|(lhs, rhs)| self.assert_eq(lhs, rhs));
    }

    /// Asserts that every limb of `word` is zero.
    fn assert_word_zero(&mut self, word: Word<impl Into<Self::Expr>>) {
        word.0.into_iter().for_each(|limb| self.assert_zero(limb));
    }

    /// Selects a word from `array` limb by limb with [`BaseAirBuilder::index_array`].
    ///
    /// Limb `i` of the result is `sum_j array[j][i] * index_bitmap[j]`.
    fn index_word_array(
        &mut self,
        array: &[Word<impl Into<Self::Expr> + Clone>],
        index_bitmap: &[impl Into<Self::Expr> + Clone],
    ) -> Word<Self::Expr> {
        Word(std::array::from_fn(|limb| {
            // Gather limb `limb` of every candidate word, then select among them.
            let column = array.iter().map(|word| word.0[limb].clone()).collect_vec();
            self.index_array(column.as_slice(), index_bitmap)
        }))
    }

    /// Limb-wise [`BaseAirBuilder::if_else`]: returns `a` when `condition == 1` and `b`
    /// when `condition == 0`.
    fn select_word(
        &mut self,
        condition: impl Into<Self::Expr> + Clone,
        a: Word<impl Into<Self::Expr> + Clone>,
        b: Word<impl Into<Self::Expr> + Clone>,
    ) -> Word<Self::Expr> {
        let (on_true, on_false) = (&a.0, &b.0);
        Word(array::from_fn(|limb| {
            self.if_else(condition.clone(), on_true[limb].clone(), on_false[limb].clone())
        }))
    }

    /// Sends `U8Range` byte lookups covering every element of `input`.
    ///
    /// Elements are packed two per lookup into the `b` and `c` slots. An odd trailing
    /// element is paired with zero. The actual bound is presumably enforced by whichever
    /// chip receives the byte lookups.
    fn slice_range_check_u8(
        &mut self,
        input: &[impl Into<Self::Expr> + Clone],
        shard: impl Into<Self::Expr> + Clone,
        channel: impl Into<Self::Expr> + Clone,
        mult: impl Into<Self::Expr> + Clone,
    ) {
        for pair in input.chunks(2) {
            let high: Self::Expr = match pair.get(1) {
                Some(limb) => limb.clone().into(),
                None => Self::Expr::zero(),
            };
            self.send_byte(
                Self::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
                Self::Expr::zero(),
                pair[0].clone(),
                high,
                shard.clone(),
                channel.clone(),
                mult.clone(),
            );
        }
    }

    /// Sends one `U16Range` byte lookup per element of `input`, with the element in the `a` slot.
    fn slice_range_check_u16(
        &mut self,
        input: &[impl Into<Self::Expr> + Copy],
        shard: impl Into<Self::Expr> + Clone,
        channel: impl Into<Self::Expr> + Clone,
        mult: impl Into<Self::Expr> + Clone,
    ) {
        for &limb in input {
            self.send_byte(
                Self::Expr::from_canonical_u8(ByteOpcode::U16Range as u8),
                limb,
                Self::Expr::zero(),
                Self::Expr::zero(),
                shard.clone(),
                channel.clone(),
                mult.clone(),
            );
        }
    }
}
pub trait AluAirBuilder: BaseAirBuilder {
#[allow(clippy::too_many_arguments)]
fn send_alu(
&mut self,
opcode: impl Into<Self::Expr>,
a: Word<impl Into<Self::Expr>>,
b: Word<impl Into<Self::Expr>>,
c: Word<impl Into<Self::Expr>>,
shard: impl Into<Self::Expr>,
channel: impl Into<Self::Expr>,
nonce: impl Into<Self::Expr>,
multiplicity: impl Into<Self::Expr>,
) {
let values = once(opcode.into())
.chain(a.0.into_iter().map(Into::into))
.chain(b.0.into_iter().map(Into::into))
.chain(c.0.into_iter().map(Into::into))
.chain(once(shard.into()))
.chain(once(channel.into()))
.chain(once(nonce.into()))
.collect();
self.send(AirInteraction::new(
values,
multiplicity.into(),
InteractionKind::Alu,
));
}
#[allow(clippy::too_many_arguments)]
fn receive_alu(
&mut self,
opcode: impl Into<Self::Expr>,
a: Word<impl Into<Self::Expr>>,
b: Word<impl Into<Self::Expr>>,
c: Word<impl Into<Self::Expr>>,
shard: impl Into<Self::Expr>,
channel: impl Into<Self::Expr>,
nonce: impl Into<Self::Expr>,
multiplicity: impl Into<Self::Expr>,
) {
let values = once(opcode.into())
.chain(a.0.into_iter().map(Into::into))
.chain(b.0.into_iter().map(Into::into))
.chain(c.0.into_iter().map(Into::into))
.chain(once(shard.into()))
.chain(once(channel.into()))
.chain(once(nonce.into()))
.collect();
self.receive(AirInteraction::new(
values,
multiplicity.into(),
InteractionKind::Alu,
));
}
#[allow(clippy::too_many_arguments)]
fn send_syscall(
&mut self,
shard: impl Into<Self::Expr> + Clone,
channel: impl Into<Self::Expr> + Clone,
clk: impl Into<Self::Expr> + Clone,
nonce: impl Into<Self::Expr> + Clone,
syscall_id: impl Into<Self::Expr> + Clone,
arg1: impl Into<Self::Expr> + Clone,
arg2: impl Into<Self::Expr> + Clone,
multiplicity: impl Into<Self::Expr>,
) {
self.send(AirInteraction::new(
vec![
shard.clone().into(),
channel.clone().into(),
clk.clone().into(),
nonce.clone().into(),
syscall_id.clone().into(),
arg1.clone().into(),
arg2.clone().into(),
],
multiplicity.into(),
InteractionKind::Syscall,
));
}
#[allow(clippy::too_many_arguments)]
fn receive_syscall(
&mut self,
shard: impl Into<Self::Expr> + Clone,
channel: impl Into<Self::Expr> + Clone,
clk: impl Into<Self::Expr> + Clone,
nonce: impl Into<Self::Expr> + Clone,
syscall_id: impl Into<Self::Expr> + Clone,
arg1: impl Into<Self::Expr> + Clone,
arg2: impl Into<Self::Expr> + Clone,
multiplicity: impl Into<Self::Expr>,
) {
self.receive(AirInteraction::new(
vec![
shard.clone().into(),
channel.clone().into(),
clk.clone().into(),
nonce.clone().into(),
syscall_id.clone().into(),
arg1.clone().into(),
arg2.clone().into(),
],
multiplicity.into(),
InteractionKind::Syscall,
));
}
}
/// Constraint helpers for memory-access columns: timestamp ordering, the 24-bit range
/// check it relies on, and the send/receive interactions of the memory argument.
pub trait MemoryAirBuilder: BaseAirBuilder {
    /// Constrains one memory access at `addr` and registers it with the memory argument.
    ///
    /// Emits, in order:
    /// - `assert_bool(do_check)` (not filtered),
    /// - the timestamp-ordering constraints of [`MemoryAirBuilder::eval_memory_access_timestamp`],
    /// - a send of `(prev_shard, prev_clk, addr, prev_value limbs...)` and a receive of
    ///   `(shard, clk, addr, value limbs...)` on [`InteractionKind::Memory`], both with
    ///   multiplicity `do_check`.
    ///
    /// `channel` is only forwarded to the timestamp range check. It is not part of the
    /// memory interaction tuples.
    fn eval_memory_access<E: Into<Self::Expr> + Clone>(
        &mut self,
        shard: impl Into<Self::Expr>,
        channel: impl Into<Self::Expr>,
        clk: impl Into<Self::Expr>,
        addr: impl Into<Self::Expr>,
        memory_access: &impl MemoryCols<E>,
        do_check: impl Into<Self::Expr>,
    ) {
        let do_check: Self::Expr = do_check.into();
        let shard: Self::Expr = shard.into();
        let channel: Self::Expr = channel.into();
        let clk: Self::Expr = clk.into();
        let mem_access = memory_access.access();
        // `do_check` gates every check below and is the multiplicity of both interactions.
        self.assert_bool(do_check.clone());
        // The current (shard, clk) timestamp must come strictly after the previous one.
        self.eval_memory_access_timestamp(
            mem_access,
            do_check.clone(),
            shard.clone(),
            channel,
            clk.clone(),
        );
        // Memory-argument tuples share the layout (shard, clk, addr, value limbs...).
        let addr = addr.into();
        let prev_shard = mem_access.prev_shard.clone().into();
        let prev_clk = mem_access.prev_clk.clone().into();
        let prev_values = once(prev_shard)
            .chain(once(prev_clk))
            .chain(once(addr.clone()))
            .chain(memory_access.prev_value().clone().map(Into::into))
            .collect();
        let current_values = once(shard)
            .chain(once(clk))
            .chain(once(addr.clone()))
            .chain(memory_access.value().clone().map(Into::into))
            .collect();
        // The previous state is sent and the current state is received.
        self.send(AirInteraction::new(
            prev_values,
            do_check.clone(),
            InteractionKind::Memory,
        ));
        self.receive(AirInteraction::new(
            current_values,
            do_check.clone(),
            InteractionKind::Memory,
        ));
    }

    /// Applies [`MemoryAirBuilder::eval_memory_access`] to each entry of `memory_access_slice`.
    ///
    /// Access `i` is placed at address `initial_addr + 4 * i`. All accesses share `shard`,
    /// `channel`, `clk` and the `verify_memory_access` flag.
    fn eval_memory_access_slice<E: Into<Self::Expr> + Copy>(
        &mut self,
        shard: impl Into<Self::Expr> + Copy,
        channel: impl Into<Self::Expr> + Clone,
        clk: impl Into<Self::Expr> + Clone,
        initial_addr: impl Into<Self::Expr> + Clone,
        memory_access_slice: &[impl MemoryCols<E>],
        verify_memory_access: impl Into<Self::Expr> + Copy,
    ) {
        for (i, access_slice) in memory_access_slice.iter().enumerate() {
            self.eval_memory_access(
                shard,
                channel.clone(),
                clk.clone(),
                initial_addr.clone().into() + Self::Expr::from_canonical_usize(i * 4),
                access_slice,
                verify_memory_access,
            );
        }
    }

    /// Constrains the current access timestamp to be strictly after the previous one.
    ///
    /// `mem_access.compare_clk` selects what is compared. When it is 1, `shard` must equal
    /// `prev_shard` and `clk` is compared against `prev_clk`. When it is 0, `shard` is
    /// compared against `prev_shard`. The value `current - prev - 1` is then decomposed
    /// into `diff_16bit_limb + diff_8bit_limb * 2^16` and range checked by
    /// [`MemoryAirBuilder::eval_range_check_24bits`]. All constraints are gated by `do_check`.
    ///
    /// NOTE(review): reading "`current - prev - 1` is in `[0, 2^24)`" as "`current > prev`"
    /// assumes the compared timestamps are small relative to the field modulus, so that an
    /// underflow cannot wrap back into range. That bound is not established in this function.
    fn eval_memory_access_timestamp(
        &mut self,
        mem_access: &MemoryAccessCols<impl Into<Self::Expr> + Clone>,
        do_check: impl Into<Self::Expr>,
        shard: impl Into<Self::Expr> + Clone,
        channel: impl Into<Self::Expr> + Clone,
        clk: impl Into<Self::Expr>,
    ) {
        let do_check: Self::Expr = do_check.into();
        let compare_clk: Self::Expr = mem_access.compare_clk.clone().into();
        let shard: Self::Expr = shard.clone().into();
        let prev_shard: Self::Expr = mem_access.prev_shard.clone().into();
        // `compare_clk` is a boolean selector; when set, both accesses are in the same shard.
        self.when(do_check.clone()).assert_bool(compare_clk.clone());
        self.when(do_check.clone())
            .when(compare_clk.clone())
            .assert_eq(shard.clone(), prev_shard);
        // Pick the timestamp component to compare: clk when compare_clk == 1, else shard.
        let prev_comp_value = self.if_else(
            mem_access.compare_clk.clone(),
            mem_access.prev_clk.clone(),
            mem_access.prev_shard.clone(),
        );
        let current_comp_val = self.if_else(compare_clk.clone(), clk.into(), shard.clone());
        // A range check on `current - prev - 1` stands in for the strict `current > prev`.
        let diff_minus_one = current_comp_val - prev_comp_value - Self::Expr::one();
        self.eval_range_check_24bits(
            diff_minus_one,
            mem_access.diff_16bit_limb.clone(),
            mem_access.diff_8bit_limb.clone(),
            shard.clone(),
            channel.clone(),
            do_check,
        );
    }

    /// Range checks `value` as a 24-bit quantity via a 16-bit limb and an 8-bit limb.
    ///
    /// Under the `do_check` filter, asserts `value == limb_16 + limb_8 * 2^16`. It then sends
    /// a `U16Range` byte lookup for `limb_16` (in the `a` slot) and a `U8Range` byte lookup
    /// for `limb_8` (in the `c` slot), both with multiplicity `do_check`. The limb bounds
    /// themselves (presumably `< 2^16` and `< 2^8`) are enforced by whichever chip receives
    /// the byte lookups, which is not visible here.
    fn eval_range_check_24bits(
        &mut self,
        value: impl Into<Self::Expr>,
        limb_16: impl Into<Self::Expr> + Clone,
        limb_8: impl Into<Self::Expr> + Clone,
        shard: impl Into<Self::Expr> + Clone,
        channel: impl Into<Self::Expr> + Clone,
        do_check: impl Into<Self::Expr> + Clone,
    ) {
        self.when(do_check.clone()).assert_eq(
            value,
            limb_16.clone().into()
                + limb_8.clone().into() * Self::Expr::from_canonical_u32(1 << 16),
        );
        self.send_byte(
            Self::Expr::from_canonical_u8(ByteOpcode::U16Range as u8),
            limb_16,
            Self::Expr::zero(),
            Self::Expr::zero(),
            shard.clone(),
            channel.clone(),
            do_check.clone(),
        );
        self.send_byte(
            Self::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
            Self::Expr::zero(),
            Self::Expr::zero(),
            limb_8,
            shard.clone(),
            channel.clone(),
            do_check,
        )
    }
}
/// Helpers for program-table interactions ([`InteractionKind::Program`]).
///
/// Message layout: `[pc, instruction.opcode, instruction columns..., selector columns..., shard]`.
///
/// NOTE(review): `instruction.opcode` is pushed explicitly, and the instruction iterator
/// may also yield it. If so, the opcode appears twice in the message. That is harmless
/// while send and receive (both below) agree, but keep them in sync if either changes.
pub trait ProgramAirBuilder: BaseAirBuilder {
    /// Sends the program lookup for the instruction at `pc` with `multiplicity`.
    fn send_program(
        &mut self,
        pc: impl Into<Self::Expr>,
        instruction: InstructionCols<impl Into<Self::Expr> + Copy>,
        selectors: OpcodeSelectorCols<impl Into<Self::Expr> + Copy>,
        shard: impl Into<Self::Expr> + Copy,
        multiplicity: impl Into<Self::Expr>,
    ) {
        let mut values: Vec<Self::Expr> = vec![pc.into(), instruction.opcode.into()];
        for column in instruction.into_iter() {
            values.push(column.into());
        }
        for column in selectors.into_iter() {
            values.push(column.into());
        }
        values.push(shard.into());
        self.send(AirInteraction::new(values, multiplicity.into(), InteractionKind::Program));
    }

    /// Receives the program lookup for the instruction at `pc` with `multiplicity`.
    fn receive_program(
        &mut self,
        pc: impl Into<Self::Expr>,
        instruction: InstructionCols<impl Into<Self::Expr> + Copy>,
        selectors: OpcodeSelectorCols<impl Into<Self::Expr> + Copy>,
        shard: impl Into<Self::Expr> + Copy,
        multiplicity: impl Into<Self::Expr>,
    ) {
        let mut values: Vec<Self::Expr> = vec![pc.into(), instruction.opcode.into()];
        for column in instruction.into_iter() {
            values.push(column.into());
        }
        for column in selectors.into_iter() {
            values.push(column.into());
        }
        values.push(shard.into());
        self.receive(AirInteraction::new(values, multiplicity.into(), InteractionKind::Program));
    }
}
/// Constraint helpers for [`BinomialExtension`] values, handled coefficient by coefficient.
pub trait ExtensionAirBuilder: BaseAirBuilder {
    /// Asserts coefficient-wise equality of two extension elements.
    fn assert_ext_eq<I: Into<Self::Expr>>(
        &mut self,
        left: BinomialExtension<I>,
        right: BinomialExtension<I>,
    ) {
        left.0
            .into_iter()
            .zip(right.0)
            .for_each(|(lhs, rhs)| self.assert_eq(lhs, rhs));
    }

    /// Asserts that every coefficient after the first is zero, i.e. the element lies in
    /// the base field.
    ///
    /// # Panics
    /// Panics if the base slice is empty.
    fn assert_is_base_element<I: Into<Self::Expr> + Clone>(
        &mut self,
        element: BinomialExtension<I>,
    ) {
        for coeff in &element.as_base_slice()[1..] {
            self.assert_zero(coeff.clone().into());
        }
    }

    /// Coefficient-wise [`BaseAirBuilder::if_else`]: returns `a` when `condition == 1`
    /// and `b` when `condition == 0`.
    fn if_else_ext(
        &mut self,
        condition: impl Into<Self::Expr> + Clone,
        a: BinomialExtension<impl Into<Self::Expr> + Clone>,
        b: BinomialExtension<impl Into<Self::Expr> + Clone>,
    ) -> BinomialExtension<Self::Expr> {
        let (on_true, on_false) = (&a.0, &b.0);
        BinomialExtension(array::from_fn(|coeff| {
            self.if_else(condition.clone(), on_true[coeff].clone(), on_false[coeff].clone())
        }))
    }
}
/// A [`PermutationAirBuilder`] that also exposes a cumulative sum.
///
/// NOTE(review): presumably this is the final value of the table's permutation/lookup
/// running sum, which is checked across tables elsewhere. Confirm against implementors.
pub trait MultiTableAirBuilder: PermutationAirBuilder {
    /// Type of the cumulative sum; must convert into the extension-field expression type.
    type Sum: Into<Self::ExprEF>;
    /// Returns this builder's cumulative sum.
    fn cumulative_sum(&self) -> Self::Sum;
}
/// A builder with the base and extension-field helpers plus access to public values.
///
/// Blanket-implemented in this module for every `BaseAirBuilder + AirBuilderWithPublicValues`.
pub trait MachineAirBuilder:
    BaseAirBuilder + ExtensionAirBuilder + AirBuilderWithPublicValues
{
}
/// The full set of helper traits used by SP1 chips, bundled into one bound.
///
/// Blanket-implemented in this module for every `BaseAirBuilder + AirBuilderWithPublicValues`.
pub trait SP1AirBuilder:
    MachineAirBuilder
    + ByteAirBuilder
    + WordAirBuilder
    + AluAirBuilder
    + MemoryAirBuilder
    + ProgramAirBuilder
{
}
/// Forwards messages issued through a [`FilteredAirBuilder`] to the builder it wraps.
///
/// NOTE(review): the message is passed through unchanged. The filter condition is NOT
/// applied to it, so `builder.when(cond).send_*(..)` still sends unconditionally, and any
/// gating has to be encoded in the message itself (e.g. its multiplicity). Confirm callers
/// rely on this.
impl<'a, AB, M> MessageBuilder<M> for FilteredAirBuilder<'a, AB>
where
    AB: AirBuilder + MessageBuilder<M>,
{
    fn send(&mut self, message: M) {
        <AB as MessageBuilder<M>>::send(&mut *self.inner, message);
    }

    fn receive(&mut self, message: M) {
        <AB as MessageBuilder<M>>::receive(&mut *self.inner, message);
    }
}
// Any builder that can exchange `AirInteraction`s over its own expression type gets every
// helper trait. The machine-level traits additionally require public-value support.
impl<AB> BaseAirBuilder for AB where AB: AirBuilder + MessageBuilder<AirInteraction<AB::Expr>> {}
impl<AB> ByteAirBuilder for AB where AB: BaseAirBuilder {}
impl<AB> WordAirBuilder for AB where AB: BaseAirBuilder {}
impl<AB> AluAirBuilder for AB where AB: BaseAirBuilder {}
impl<AB> MemoryAirBuilder for AB where AB: BaseAirBuilder {}
impl<AB> ProgramAirBuilder for AB where AB: BaseAirBuilder {}
impl<AB> ExtensionAirBuilder for AB where AB: BaseAirBuilder {}
impl<AB> MachineAirBuilder for AB where AB: BaseAirBuilder + AirBuilderWithPublicValues {}
impl<AB> SP1AirBuilder for AB where AB: BaseAirBuilder + AirBuilderWithPublicValues {}
// These Plonky3 builders only evaluate constraints; interaction messages sent to them are
// dropped. Presumably interactions are accounted for by a separate builder; not visible here.
impl<SC> EmptyMessageBuilder for ProverConstraintFolder<'_, SC>
where
    SC: StarkGenericConfig,
{
}
impl<SC> EmptyMessageBuilder for VerifierConstraintFolder<'_, SC>
where
    SC: StarkGenericConfig,
{
}
impl<F> EmptyMessageBuilder for SymbolicAirBuilder<F> where F: Field {}
#[cfg(debug_assertions)]
#[cfg(not(doctest))]
impl<F> EmptyMessageBuilder for p3_uni_stark::DebugConstraintBuilder<'_, F> where F: Field {}