use p3_air::AirBuilder;
use p3_field::{AbstractField, Field};
use sp1_core::air::{BinomialExtension, ExtensionAirBuilder};
use crate::{
air::{
BinomialExtensionUtils, Block, BlockBuilder, IsExtZeroOperation, SP1RecursionAirBuilder,
},
cpu::{CpuChip, CpuCols},
memory::MemoryCols,
};
impl<F: Field, const L: usize> CpuChip<F, L> {
    /// Evaluates the constraints for the branch instructions (`BEQ`, `BNE`, `BNEINC`).
    ///
    /// Constraints emitted, in order:
    /// - on real `BNEINC` rows, the block in `a` must equal `a`'s previous value plus one,
    ///   compared limb-by-limb after lifting both into the binomial extension;
    /// - on branch rows, `comparison_diff_val` must equal `a - b` (as extension elements);
    /// - [`IsExtZeroOperation`] checks `comparison_diff` against `comparison_diff_val`;
    /// - on branch rows, `do_branch` must equal `is_beq * result + (is_bne + is_bneinc) * (1 - result)`;
    /// - on branch rows, the `next_pc` column must be `pc + c[0]` when `do_branch` is set and
    ///   `pc + 1` otherwise.
    ///
    /// # Arguments
    /// * `builder` - the AIR builder that receives the constraints.
    /// * `local` - the CPU columns of the current row.
    /// * `next_pc` - accumulator for the expected next program counter. This function adds
    ///   `is_branch_instruction * branch_cols.next_pc` to it. It does not overwrite it, so
    ///   contributions added by other evaluators before or after this call are preserved.
    pub fn eval_branch<AB>(
        &self,
        builder: &mut AB,
        local: &CpuCols<AB::Var>,
        next_pc: &mut AB::Expr,
    ) where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let branch_cols = local.opcode_specific.branch();
        // NOTE(review): presumably 1 on BEQ/BNE/BNEINC rows and 0 elsewhere; confirm against
        // `is_branch_instruction`.
        let is_branch_instruction = self.is_branch_instruction::<AB>(local);
        let one = AB::Expr::one();

        // Lift the operand blocks into the binomial extension so they can be added and
        // subtracted as extension elements.
        let a_prev_ext: BinomialExtension<AB::Expr> =
            BinomialExtensionUtils::from_block(local.a.prev_value().map(|x| x.into()));
        let a_ext: BinomialExtension<AB::Expr> =
            BinomialExtensionUtils::from_block(local.a.value().map(|x| x.into()));
        let b_ext: BinomialExtension<AB::Expr> =
            BinomialExtensionUtils::from_block(local.b.value().map(|x| x.into()));
        // NOTE(review): assumes `Block::from(scalar)` puts the scalar in limb 0 and zeros
        // elsewhere, making this the extension element 1. Confirm.
        let one_ext: BinomialExtension<AB::Expr> =
            BinomialExtensionUtils::from_block(Block::from(one.clone()));

        // BNEINC increments `a` as part of the instruction: require a == a_prev + 1.
        let expected_a_ext = a_prev_ext + one_ext;
        builder
            .when(local.is_real)
            .when(local.selectors.is_bneinc)
            .assert_block_eq(a_ext.as_block(), expected_a_ext.as_block());

        // The witnessed comparison_diff_val must hold a - b.
        let comparison_diff = a_ext - b_ext;
        builder.when(is_branch_instruction.clone()).assert_ext_eq(
            BinomialExtension::from(branch_cols.comparison_diff_val),
            comparison_diff,
        );

        // Zero-check on a - b. Judging by how `do_branch` uses it below,
        // `comparison_diff.result` is 1 when a == b.
        IsExtZeroOperation::<AB::F>::eval(
            builder,
            BinomialExtension::from(branch_cols.comparison_diff_val),
            branch_cols.comparison_diff,
            is_branch_instruction.clone(),
        );

        // BEQ branches on equality; BNE and BNEINC branch on inequality.
        let mut do_branch = local.selectors.is_beq * branch_cols.comparison_diff.result;
        do_branch += local.selectors.is_bne * (one.clone() - branch_cols.comparison_diff.result);
        do_branch += local.selectors.is_bneinc * (one.clone() - branch_cols.comparison_diff.result);
        builder
            .when(is_branch_instruction.clone())
            .assert_eq(branch_cols.do_branch, do_branch);

        // The branch target is pc + c[0] (first limb of operand c); otherwise fall through.
        let pc_offset = local.c.value().0[0];
        let expected_next_pc =
            builder.if_else(branch_cols.do_branch, local.pc + pc_offset, local.pc + one);
        builder
            .when(is_branch_instruction.clone())
            .assert_eq(branch_cols.next_pc, expected_next_pc);

        // Accumulate rather than overwrite. `next_pc` is a shared accumulator, and assigning
        // here would silently discard any contribution already added by the caller (e.g. from
        // another instruction family evaluated first). When the caller starts from zero,
        // this matches plain assignment.
        *next_pc += is_branch_instruction * branch_cols.next_pc;
    }
}