use std::{cell::RefCell, collections::HashMap};
use triton_opcodes::instruction::LabelledInstruction;
use triton_vm::BFieldElement;
use crate::snippet::{DataType, Snippet};
/// Boxed, interior-mutable closure that mutates a stack of `BFieldElement`s in place.
///
/// Named once here so the struct field below reads clearly and does not need a
/// `clippy::type_complexity` allowance.
type RustShadowingFn = Box<RefCell<dyn FnMut(&mut Vec<BFieldElement>)>>;

/// A hand-written function given directly as a list of labelled instructions.
pub struct RawCode {
    /// The function body. The constructors require it to start with a label and to end with
    /// `return` or `recurse`.
    pub function: Vec<LabelledInstruction>,
    /// Declared input types of the function.
    pub input_types: Vec<DataType>,
    /// Declared output types of the function.
    pub output_types: Vec<DataType>,
    /// Optional Rust closure that mimics the function's effect on the stack.
    /// `None` when the value was built without shadowing.
    rust_shadowing: Option<RustShadowingFn>,
}
impl RawCode {
    /// Creates a `RawCode` without a Rust shadowing closure.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `function` has fewer than two lines;
    /// - its first line is not a label;
    /// - its last line is neither `return` nor `recurse`.
    pub fn new(
        function: Vec<LabelledInstruction>,
        input_types: Vec<DataType>,
        output_types: Vec<DataType>,
    ) -> Self {
        Self::assert_well_formed(&function);
        Self {
            function,
            input_types,
            output_types,
            rust_shadowing: None,
        }
    }

    /// Creates a `RawCode` together with a Rust closure that mutates the stack.
    ///
    /// The closure is stored as-is and used as the Rust shadowing of this function.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`RawCode::new`].
    // The boxed closure type is spelled out in full here to match the struct field.
    #[allow(clippy::type_complexity)]
    pub fn new_with_shadowing(
        function: Vec<LabelledInstruction>,
        input_types: Vec<DataType>,
        output_types: Vec<DataType>,
        rust_shadowing: Box<RefCell<dyn FnMut(&mut Vec<BFieldElement>)>>,
    ) -> Self {
        Self::assert_well_formed(&function);
        Self {
            function,
            input_types,
            output_types,
            rust_shadowing: Some(rust_shadowing),
        }
    }

    /// Panics unless `function` has at least two lines, starts with a label, and ends with
    /// `return` or `recurse`.
    ///
    /// Both constructors share this helper so their validation cannot drift apart.
    fn assert_well_formed(function: &[LabelledInstruction]) {
        assert!(
            function.len() >= 2,
            "Inner function must have at least two lines: a label and a return or recurse"
        );
        assert!(
            matches!(function[0], LabelledInstruction::Label(_)),
            "First line of inner function must be label. Got: {}",
            function[0]
        );
        let last_line = function
            .last()
            .expect("length was checked to be at least two above");
        assert!(
            matches!(
                last_line,
                LabelledInstruction::Instruction(
                    triton_opcodes::instruction::AnInstruction::Return
                        | triton_opcodes::instruction::AnInstruction::Recurse
                )
            ),
            "Last line of inner function must be either return or recurse. Got: {}",
            last_line
        );
    }
}
impl RawCode {
    /// Returns an owned copy of the label on the first line of the function body.
    ///
    /// # Panics
    ///
    /// Panics if the first line is an instruction rather than a label. It also panics on an
    /// empty body because of the indexing.
    pub fn entrypoint(&self) -> String {
        let first_line = &self.function[0];
        match first_line {
            LabelledInstruction::Label(label) => label.to_owned(),
            LabelledInstruction::Instruction(inst) => {
                panic!("First line of inner function must be a label. Got: {inst}")
            }
        }
    }
}
/// The different ways a function can be supplied.
///
/// The methods on this type dispatch over three forms: hand-written instructions, a
/// `Snippet` trait object, or a bare label with a type signature. Presumably it is passed
/// as the "inner function" to higher-order code — TODO confirm against callers.
pub enum InnerFunction {
    /// Hand-written instructions, with optional Rust shadowing.
    RawCode(RawCode),
    /// A function implemented by a `Snippet`.
    Snippet(Box<dyn Snippet>),
    /// Only a label and a signature; no instructions are available here.
    NoFunctionBody(NoFunctionBody),
}
/// A function known only by its label and type signature.
///
/// There are no instructions and no Rust implementation attached. Rust shadowing of an
/// `InnerFunction::NoFunctionBody` therefore panics.
pub struct NoFunctionBody {
    /// Label used as the function's entrypoint.
    pub label_name: String,
    /// Declared input types.
    pub input_types: Vec<DataType>,
    /// Declared output types.
    pub output_types: Vec<DataType>,
}
impl InnerFunction {
    /// Returns the declared input types of the wrapped function, in declaration order.
    pub fn get_input_types(&self) -> Vec<DataType> {
        match self {
            InnerFunction::RawCode(raw) => raw.input_types.clone(),
            InnerFunction::Snippet(snippet) => snippet.input_types(),
            InnerFunction::NoFunctionBody(no_body) => no_body.input_types.clone(),
        }
    }

    /// Returns the last declared input type.
    ///
    /// The method name treats this as the element type of a list being processed — TODO
    /// confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if the function declares no input types.
    pub fn input_list_element_type(&self) -> DataType {
        // `pop` moves the last element out of the freshly built Vec instead of cloning it.
        self.get_input_types()
            .pop()
            .expect("Inner function must declare at least one input type (the list element)")
    }

    /// Returns every declared input type except the last one, in declaration order.
    ///
    /// # Panics
    ///
    /// Panics if the function declares no input types.
    pub fn additional_inputs(&self) -> Vec<DataType> {
        let mut input_types = self.get_input_types();
        input_types
            .pop()
            .expect("Inner function must declare at least one input type (the list element)");
        input_types
    }

    /// Sums `get_size()` over the types returned by [`InnerFunction::additional_inputs`].
    ///
    /// Presumably this is the number of stack words the additional inputs occupy — TODO
    /// confirm `DataType::get_size` units.
    ///
    /// # Panics
    ///
    /// Panics if the function declares no input types.
    pub fn size_of_additional_inputs(&self) -> usize {
        self.additional_inputs().iter().map(|x| x.get_size()).sum()
    }

    /// Returns the declared output types of the wrapped function.
    pub fn get_output_types(&self) -> Vec<DataType> {
        match self {
            InnerFunction::RawCode(raw) => raw.output_types.clone(),
            InnerFunction::Snippet(snippet) => snippet.output_types(),
            InnerFunction::NoFunctionBody(no_body) => no_body.output_types.clone(),
        }
    }

    /// Returns the label at which the wrapped function starts.
    pub fn entrypoint(&self) -> String {
        match self {
            InnerFunction::RawCode(raw) => raw.entrypoint(),
            InnerFunction::Snippet(snippet) => snippet.entrypoint(),
            InnerFunction::NoFunctionBody(no_body) => no_body.label_name.to_owned(),
        }
    }

    /// Runs the Rust implementation of the wrapped function against the given state.
    ///
    /// - `RawCode`: calls its stored closure on `stack` only; `std_in`, `secret_in` and
    ///   `memory` are ignored.
    /// - `Snippet`: forwards to `Snippet::rust_shadowing`, passing copies of both input
    ///   streams.
    ///
    /// # Panics
    ///
    /// - Panics for a `RawCode` built without a shadowing closure.
    /// - Panics for `NoFunctionBody`.
    /// - Panics if the `RawCode` closure is already mutably borrowed (`RefCell::borrow_mut`).
    pub fn rust_shadowing(
        &self,
        std_in: &[BFieldElement],
        secret_in: &[BFieldElement],
        stack: &mut Vec<BFieldElement>,
        memory: &mut HashMap<BFieldElement, BFieldElement>,
    ) {
        match self {
            InnerFunction::RawCode(raw) => {
                if let Some(func) = &raw.rust_shadowing {
                    let mut func = func.borrow_mut();
                    (*func)(stack)
                } else {
                    panic!("Raw code must have rust shadowing for equivalence testing")
                }
            }
            InnerFunction::Snippet(snippet) => {
                snippet.rust_shadowing(stack, std_in.to_vec(), secret_in.to_vec(), memory)
            }
            InnerFunction::NoFunctionBody(_no_body) => {
                panic!("Cannot rust shadow inner function without function body")
            }
        }
    }
}