use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_matrix::Matrix;
use super::{ShaExtendChip, ShaExtendCols, NUM_SHA_EXTEND_COLS};
use crate::air::{BaseAirBuilder, SP1AirBuilder};
use crate::memory::MemoryCols;
use crate::operations::{
Add4Operation, FixedRotateRightOperation, FixedShiftRightOperation, XorOperation,
};
use crate::runtime::SyscallCode;
use core::borrow::Borrow;
// Declares the trace shape of the SHA extend chip to the AIR framework.
impl<F> BaseAir<F> for ShaExtendChip {
/// Returns the number of columns in one row of this chip's trace.
///
/// The value is the `NUM_SHA_EXTEND_COLS` constant imported from the parent
/// module (presumably the field count of `ShaExtendCols` — confirm against
/// its definition).
fn width(&self) -> usize {
NUM_SHA_EXTEND_COLS
}
}
// Constraint system for the SHA extend chip.
impl<AB> Air<AB> for ShaExtendChip
where
    AB: SP1AirBuilder,
{
    /// Emits every constraint that ties one trace row (`local`) to its successor (`next`).
    ///
    /// Per row, for the index held in the `i` column, the constraints enforce:
    ///
    /// * `s0 = (w[i-15] rotr 7) ^ (w[i-15] rotr 18) ^ (w[i-15] >> 3)`
    /// * `s1 = (w[i-2] rotr 17) ^ (w[i-2] rotr 19) ^ (w[i-2] >> 10)`
    /// * `s2 = w[i-16] + s0 + w[i-7] + s1`, and the word accessed at `w[i]` equals `s2`
    ///
    /// together with bookkeeping on the nonce, the per-invocation inputs, the syscall
    /// interaction and the `is_real` selector.
    ///
    /// NOTE(review): the order of builder calls is kept exactly as before on purpose;
    /// constraint and interaction ordering may be observable by the proving system.
    fn eval(&self, builder: &mut AB) {
        // View the current and the following row as typed column structs.
        let main = builder.main();
        let local_slice = main.row_slice(0);
        let next_slice = main.row_slice(1);
        let local: &ShaExtendCols<AB::Var> = (*local_slice).borrow();
        let next: &ShaExtendCols<AB::Var> = (*next_slice).borrow();

        // The nonce is zero on the first row and grows by one on every transition.
        builder.when_first_row().assert_zero(local.nonce);
        let nonce_successor = local.nonce + AB::Expr::one();
        builder
            .when_transition()
            .assert_eq(nonce_successor, next.nonce);

        // Offset subtracted from `i` when deriving the memory-access timestamp.
        let first_index = AB::F::from_canonical_u32(16);
        // Multiplier that turns a word index into an address offset.
        let word_size = AB::F::from_canonical_u32(4);

        // Flag constraints are defined elsewhere on the chip.
        self.eval_flags(builder);

        // Columns that are passed to nearly every gadget below.
        let (shard, channel, is_real) = (local.shard, local.channel, local.is_real);

        // Product of the 16-cycle end flag and the third 48-cycle flag; presumably one
        // exactly on the final row of a 48-row invocation — confirm against `eval_flags`.
        let cycle_end = local.cycle_16_end.result * local.cycle_48[2];

        // Unless the flag above is set, the invocation inputs carry over to the next row.
        builder
            .when_transition()
            .when_not(cycle_end.clone())
            .assert_eq(shard, next.shard);
        builder
            .when_transition()
            .when_not(cycle_end.clone())
            .assert_eq(local.clk, next.clk);
        builder
            .when_transition()
            .when_not(cycle_end.clone())
            .assert_eq(channel, next.channel);
        builder
            .when_transition()
            .when_not(cycle_end)
            .assert_eq(local.w_ptr, next.w_ptr);

        // All memory accesses of this row share one timestamp: `clk + (i - 16)`.
        let access_clk = local.clk + (local.i - first_index);
        // Address of the word `words_back` entries before `w[i]`.
        let word_addr = |words_back: u32| {
            local.w_ptr + (local.i - AB::F::from_canonical_u32(words_back)) * word_size
        };

        // Memory access for w[i-15].
        builder.eval_memory_access(
            shard,
            channel,
            access_clk.clone(),
            word_addr(15),
            &local.w_i_minus_15,
            is_real,
        );
        // Memory access for w[i-2].
        builder.eval_memory_access(
            shard,
            channel,
            access_clk.clone(),
            word_addr(2),
            &local.w_i_minus_2,
            is_real,
        );
        // Memory access for w[i-16].
        builder.eval_memory_access(
            shard,
            channel,
            access_clk.clone(),
            word_addr(16),
            &local.w_i_minus_16,
            is_real,
        );
        // Memory access for w[i-7].
        builder.eval_memory_access(
            shard,
            channel,
            access_clk.clone(),
            word_addr(7),
            &local.w_i_minus_7,
            is_real,
        );

        // Word values taken from the four accesses above.
        let w_15 = *local.w_i_minus_15.value();
        let w_2 = *local.w_i_minus_2.value();
        let w_16 = *local.w_i_minus_16.value();
        let w_7 = *local.w_i_minus_7.value();

        // s0: w[i-15] rotated right by 7 ...
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            w_15,
            7,
            local.w_i_minus_15_rr_7,
            shard,
            channel,
            is_real,
        );
        // ... rotated right by 18 ...
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            w_15,
            18,
            local.w_i_minus_15_rr_18,
            shard,
            channel,
            is_real,
        );
        // ... and shifted right by 3.
        FixedShiftRightOperation::<AB::F>::eval(
            builder,
            w_15,
            3,
            local.w_i_minus_15_rs_3,
            shard,
            channel,
            is_real,
        );
        // s0_intermediate = rr_7 ^ rr_18.
        XorOperation::<AB::F>::eval(
            builder,
            local.w_i_minus_15_rr_7.value,
            local.w_i_minus_15_rr_18.value,
            local.s0_intermediate,
            shard,
            channel,
            is_real,
        );
        // s0 = s0_intermediate ^ rs_3.
        XorOperation::<AB::F>::eval(
            builder,
            local.s0_intermediate.value,
            local.w_i_minus_15_rs_3.value,
            local.s0,
            shard,
            channel,
            is_real,
        );

        // s1: w[i-2] rotated right by 17 ...
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            w_2,
            17,
            local.w_i_minus_2_rr_17,
            shard,
            channel,
            is_real,
        );
        // ... rotated right by 19 ...
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            w_2,
            19,
            local.w_i_minus_2_rr_19,
            shard,
            channel,
            is_real,
        );
        // ... and shifted right by 10.
        FixedShiftRightOperation::<AB::F>::eval(
            builder,
            w_2,
            10,
            local.w_i_minus_2_rs_10,
            shard,
            channel,
            is_real,
        );
        // s1_intermediate = rr_17 ^ rr_19.
        XorOperation::<AB::F>::eval(
            builder,
            local.w_i_minus_2_rr_17.value,
            local.w_i_minus_2_rr_19.value,
            local.s1_intermediate,
            shard,
            channel,
            is_real,
        );
        // s1 = s1_intermediate ^ rs_10.
        XorOperation::<AB::F>::eval(
            builder,
            local.s1_intermediate.value,
            local.w_i_minus_2_rs_10.value,
            local.s1,
            shard,
            channel,
            is_real,
        );

        // s2 = w[i-16] + s0 + w[i-7] + s1.
        Add4Operation::<AB::F>::eval(
            builder,
            w_16,
            local.s0.value,
            w_7,
            local.s1.value,
            shard,
            channel,
            is_real,
            local.s2,
        );

        // Memory access for w[i]; its value must be the freshly computed `s2`.
        builder.eval_memory_access(
            shard,
            channel,
            access_clk,
            local.w_ptr + local.i * word_size,
            &local.w_i,
            is_real,
        );
        builder.assert_word_eq(*local.w_i.value(), local.s2.value);

        // Syscall interaction for SHA_EXTEND; the final argument is `cycle_48_start`
        // (presumably the multiplicity, so it only fires on the first row of a cycle
        // — confirm against `receive_syscall`).
        builder.receive_syscall(
            shard,
            channel,
            local.clk,
            local.nonce,
            AB::F::from_canonical_u32(SyscallCode::SHA_EXTEND.syscall_id()),
            local.w_ptr,
            AB::Expr::zero(),
            local.cycle_48_start,
        );

        // `is_real` is boolean and stays constant until `cycle_48_end` is set.
        builder.assert_bool(is_real);
        builder
            .when_transition()
            .when_not(local.cycle_48_end)
            .assert_eq(is_real, next.is_real);

        // The last row of the table must not be a real row.
        builder.when_last_row().assert_zero(is_real);
    }
}