use core::borrow::Borrow;
use p3_air::AirBuilder;
use p3_baby_bear::BabyBear;
use p3_field::AbstractField;
use p3_field::Field;
use p3_field::PrimeField32;
use p3_field::TwoAdicField;
use p3_matrix::Matrix;
use crate::air::BaseAirBuilder;
use crate::air::SP1AirBuilder;
use crate::operations::IsZeroOperation;
use super::ShaExtendChip;
use super::ShaExtendCols;
impl<F: Field> ShaExtendCols<F> {
    /// Fills in the row-position flag columns for trace row `i`.
    ///
    /// * `cycle_16` is set to `g^(i + 1)`. Here `g` is BabyBear's two-adic generator for a
    ///   subgroup of size `2^4 = 16`, carried into `F` via its canonical `u32` value.
    /// * `cycle_16_start` / `cycle_16_end` are populated (through `IsZeroOperation`) from
    ///   `cycle_16 - g` and `cycle_16 - 1`. These vanish exactly when `i % 16 == 0` and
    ///   `i % 16 == 15`, respectively.
    /// * `i` (the column) is set to the counter `16 + i % 48`, which lies in `16..64`.
    /// * `cycle_48` is one-hot over the three 16-wide windows `16..32`, `32..48`, `48..64`
    ///   that the counter can fall in.
    /// * `cycle_48_start` / `cycle_48_end` are the products of the first/last window flag
    ///   with the 16-cycle start/end result and `is_real`.
    ///
    /// `is_real` is read here, so it should already be assigned for this row.
    pub fn populate_flags(&mut self, i: usize) {
        let generator =
            F::from_canonical_u32(BabyBear::two_adic_generator(4).as_canonical_u32());

        // Position inside the current 16-row cycle, encoded as a power of the generator.
        let power = generator.exp_u64((i + 1) as u64);
        self.cycle_16 = power;
        self.cycle_16_start
            .populate_from_field_element(power - generator);
        self.cycle_16_end
            .populate_from_field_element(power - F::one());

        // Counter in 16..64 and the index (0, 1 or 2) of the 16-wide window holding it.
        let offset = i % 48;
        let step = 16 + offset;
        let window = offset / 16;
        self.i = F::from_canonical_usize(step);
        self.cycle_48[0] = F::from_bool(window == 0);
        self.cycle_48[1] = F::from_bool(window == 1);
        self.cycle_48[2] = F::from_bool(window == 2);

        // Start/end of the full 48-row cycle, only on real rows.
        let real = self.is_real;
        self.cycle_48_start = self.cycle_48[0] * self.cycle_16_start.result * real;
        self.cycle_48_end = self.cycle_48[2] * self.cycle_16_end.result * real;
    }
}
impl ShaExtendChip {
/// Constrains the row-position columns of the trace: `cycle_16`, `cycle_16_start`,
/// `cycle_16_end`, the one-hot `cycle_48`, `cycle_48_start`, `cycle_48_end`, and the
/// step counter `i`.
///
/// Taken together, the constraints below force the following layout:
/// * row `r` holds `cycle_16 = g^(r + 1)`, where `g` generates the two-adic subgroup of
///   size `2^4 = 16`;
/// * `cycle_48` starts at `[1, 0, 0]` and rotates one slot forward at the end of every
///   16-row cycle;
/// * `i` counts `16, 17, ..., 63` and then restarts at 16.
///
/// This mirrors the values assigned by `ShaExtendCols::populate_flags`.
///
/// NOTE(review): `is_real` is multiplied into `cycle_48_start` / `cycle_48_end` but is
/// not constrained to be boolean in this function; presumably that is done elsewhere —
/// confirm.
pub fn eval_flags<AB: SP1AirBuilder>(&self, builder: &mut AB) {
let main = builder.main();
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local: &ShaExtendCols<AB::Var> = (*local).borrow();
let next: &ShaExtendCols<AB::Var> = (*next).borrow();
let one = AB::Expr::from(AB::F::one());
// Generator of BabyBear's multiplicative subgroup of size 2^4 = 16.
let g = AB::F::from_canonical_u32(BabyBear::two_adic_generator(4).as_canonical_u32());
// First row: cycle_16 = g^1 and the counter starts at 16.
builder.when_first_row().assert_eq(local.cycle_16, g);
builder
.when_first_row()
.assert_eq(local.i, AB::F::from_canonical_u32(16));
// Each row's cycle_16 is the previous one times g, so row r holds g^(r + 1).
// Because g has order 16, this value repeats every 16 rows.
builder
.when_transition()
.assert_eq(local.cycle_16 * g, next.cycle_16);
// Start-of-16-cycle flag: derived from cycle_16 - g, which is zero when r % 16 == 0.
// Presumably IsZeroOperation constrains `.result` to be 1 exactly when its input is
// zero (its definition is not visible here).
IsZeroOperation::<AB::F>::eval(
builder,
local.cycle_16 - AB::Expr::from(g),
local.cycle_16_start,
one.clone(),
);
// End-of-16-cycle flag: derived from cycle_16 - 1, which is zero when r % 16 == 15
// (g^16 == 1).
IsZeroOperation::<AB::F>::eval(
builder,
local.cycle_16 - AB::Expr::one(),
local.cycle_16_end,
one.clone(),
);
// cycle_48 starts as [1, 0, 0]: the first 16-row window of the 48-row cycle.
builder
.when_first_row()
.assert_eq(local.cycle_48[0], AB::F::one());
builder
.when_first_row()
.assert_eq(local.cycle_48[1], AB::F::zero());
builder
.when_first_row()
.assert_eq(local.cycle_48[2], AB::F::zero());
for i in 0..3 {
// At the end of a 16-row cycle, the set flag moves from slot i to slot (i + 1) % 3.
builder
.when_transition()
.when(local.cycle_16_end.result)
.assert_eq(local.cycle_48[i], next.cycle_48[(i + 1) % 3]);
// Otherwise the flags carry over unchanged to the next row.
builder
.when_transition()
.when(one.clone() - local.cycle_16_end.result)
.assert_eq(local.cycle_48[i], next.cycle_48[i]);
// Every flag is 0 or 1.
builder.assert_bool(local.cycle_48[i]);
}
// cycle_48_start = (start of a 16-cycle) AND (first window) AND is_real.
builder.assert_eq(
local.cycle_16_start.result * local.cycle_48[0] * local.is_real,
local.cycle_48_start,
);
// cycle_48_end = (end of a 16-cycle) AND (last window) AND is_real.
builder.assert_eq(
local.cycle_16_end.result * local.cycle_48[2] * local.is_real,
local.cycle_48_end,
);
// After the last row of the third window, the counter wraps back to 16.
// These counter transitions are not gated by is_real.
builder
.when_transition()
.when(local.cycle_16_end.result * local.cycle_48[2])
.assert_eq(next.i, AB::F::from_canonical_u32(16));
// On every other transition, the counter increments by one.
builder
.when_transition()
.when_not(local.cycle_16_end.result * local.cycle_48[2])
.assert_eq(local.i + one.clone(), next.i);
}
}