use core::mem::take;
use std::sync::Arc;
use hashbrown::HashMap;
use crate::{
hook::{hookify, BoxedHook, HookEnv, HookRegistry},
subproof::SubproofVerifier,
};
/// Optional settings bundle produced by [`SP1ContextBuilder::build`].
///
/// Every field is `None` for a context obtained from `SP1Context::default()` or from
/// `SP1Context::builder().build()` with no setters called.
#[derive(Clone, Default)]
pub struct SP1Context<'a> {
/// Hook table assembled by the builder. `build` leaves this `None` unless custom hooks
/// were registered or default hooks were disabled.
/// NOTE(review): presumably consumers fall back to `HookRegistry::default()` on `None` —
/// confirm at the use site.
pub hook_registry: Option<HookRegistry<'a>>,
/// Verifier supplied via `SP1ContextBuilder::subproof_verifier`; `None` if none was set.
pub subproof_verifier: Option<Arc<dyn SubproofVerifier + 'a>>,
/// Cycle limit supplied via `SP1ContextBuilder::max_cycles`; `None` if no limit was set.
/// How the limit is enforced is up to the consumer of this context (not shown here).
pub max_cycles: Option<u64>,
}
/// Builder for [`SP1Context`].
///
/// Configure it with the setter methods, then call `build`. That call moves all accumulated
/// state into the returned context and leaves the builder in its default state.
#[derive(Clone, Default)]
pub struct SP1ContextBuilder<'a> {
/// Set by `without_default_hooks`. When true, `build` starts from an empty hook table
/// instead of `HookRegistry::default().table`.
no_default_hooks: bool,
/// `(fd, hook)` pairs added by `hook`, in registration order. `build` inserts them into the
/// table in that order, so a later entry overwrites an earlier or default one with the same fd.
hook_registry_entries: Vec<(u32, BoxedHook<'a>)>,
/// Verifier set by `subproof_verifier`; moved into the built context.
subproof_verifier: Option<Arc<dyn SubproofVerifier + 'a>>,
/// Limit set by `max_cycles`; moved into the built context.
max_cycles: Option<u64>,
}
impl<'a> SP1Context<'a> {
    /// Starts configuring a new [`SP1Context`].
    ///
    /// The returned builder is in its default state: default hooks enabled, no custom
    /// hooks, no subproof verifier and no cycle limit.
    #[must_use]
    pub fn builder() -> SP1ContextBuilder<'a> {
        SP1ContextBuilder::default()
    }
}
impl<'a> SP1ContextBuilder<'a> {
    /// Creates a builder in its default state.
    ///
    /// [`SP1Context::builder`] is the usual entry point and is equivalent.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Moves the accumulated configuration into a fresh [`SP1Context`].
    ///
    /// Every builder field is taken and reset to its default. The same builder can then be
    /// reused as if newly created.
    ///
    /// `hook_registry` is `None` unless at least one custom hook was registered or
    /// [`Self::without_default_hooks`] was called. In either of those cases the table
    /// starts out empty (defaults disabled) or as `HookRegistry::default().table`. The
    /// custom hooks are then inserted in registration order, so a custom hook replaces any
    /// default or earlier hook that uses the same fd.
    pub fn build(&mut self) -> SP1Context<'a> {
        let skip_defaults = take(&mut self.no_default_hooks);
        let custom_hooks = take(&mut self.hook_registry_entries);

        let hook_registry = if !skip_defaults && custom_hooks.is_empty() {
            // Nothing about the hooks was customized, so leave the registry unset.
            None
        } else {
            let mut table = if skip_defaults {
                HashMap::default()
            } else {
                HookRegistry::default().table
            };
            // Inserting after the defaults lets custom hooks overwrite them.
            table.extend(custom_hooks);
            Some(HookRegistry { table })
        };

        SP1Context {
            hook_registry,
            subproof_verifier: self.subproof_verifier.take(),
            max_cycles: self.max_cycles.take(),
        }
    }

    /// Registers `f` as the hook for `fd`.
    ///
    /// When the same `fd` is registered more than once, the last registration wins once
    /// [`Self::build`] assembles the table. Custom hooks also take precedence over defaults.
    pub fn hook(
        &mut self,
        fd: u32,
        f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
    ) -> &mut Self {
        let boxed = hookify(f);
        self.hook_registry_entries.push((fd, boxed));
        self
    }

    /// Makes [`Self::build`] start from an empty hook table, so only hooks registered via
    /// [`Self::hook`] end up in the context.
    pub fn without_default_hooks(&mut self) -> &mut Self {
        self.no_default_hooks = true;
        self
    }

    /// Sets the subproof verifier placed into the built context, replacing any previous one.
    pub fn subproof_verifier(
        &mut self,
        subproof_verifier: Arc<dyn SubproofVerifier + 'a>,
    ) -> &mut Self {
        self.subproof_verifier = Some(subproof_verifier);
        self
    }

    /// Sets the cycle limit placed into the built context, replacing any previous one.
    pub fn max_cycles(&mut self, max_cycles: u64) -> &mut Self {
        self.max_cycles = Some(max_cycles);
        self
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{subproof::DefaultSubproofVerifier, SP1Context};

    /// A builder with no setters called produces an all-`None` context.
    #[test]
    fn defaults() {
        let SP1Context { hook_registry, subproof_verifier, max_cycles: cycle_limit } =
            SP1Context::builder().build();
        assert!(hook_registry.is_none());
        assert!(subproof_verifier.is_none());
        assert!(cycle_limit.is_none());
    }

    /// Disabling default hooks without adding any yields an empty (but present) table.
    #[test]
    fn without_default_hooks() {
        let SP1Context { hook_registry, .. } =
            SP1Context::builder().without_default_hooks().build();
        assert!(hook_registry.unwrap().table.is_empty());
    }

    /// A custom hook is present in the table alongside whatever defaults exist.
    #[test]
    fn with_custom_hook() {
        let SP1Context { hook_registry, .. } =
            SP1Context::builder().hook(30, |_, _| vec![]).build();
        assert!(hook_registry.unwrap().table.contains_key(&30));
    }

    /// With defaults disabled, only the custom hook remains.
    #[test]
    fn without_default_hooks_with_custom_hook() {
        let SP1Context { hook_registry, .. } =
            SP1Context::builder().without_default_hooks().hook(30, |_, _| vec![]).build();
        assert_eq!(&hook_registry.unwrap().table.into_keys().collect::<Vec<_>>(), &[30]);
    }

    /// A configured subproof verifier is carried into the context.
    #[test]
    fn subproof_verifier() {
        let SP1Context { subproof_verifier, .. } = SP1Context::builder()
            .subproof_verifier(Arc::new(DefaultSubproofVerifier::new()))
            .build();
        assert!(subproof_verifier.is_some());
    }

    /// A configured cycle limit is carried into the context unchanged.
    #[test]
    fn with_max_cycles() {
        let SP1Context { max_cycles, .. } = SP1Context::builder().max_cycles(1_000).build();
        assert_eq!(max_cycles, Some(1_000));
    }

    /// `build` moves all state out of the builder, so a second `build` yields defaults.
    #[test]
    fn build_resets_builder() {
        let mut builder = SP1Context::builder();
        builder
            .without_default_hooks()
            .hook(30, |_, _| vec![])
            .subproof_verifier(Arc::new(DefaultSubproofVerifier::new()))
            .max_cycles(1_000);

        let first = builder.build();
        assert!(first.hook_registry.is_some());
        assert!(first.subproof_verifier.is_some());
        assert_eq!(first.max_cycles, Some(1_000));

        let second = builder.build();
        assert!(second.hook_registry.is_none());
        assert!(second.subproof_verifier.is_none());
        assert!(second.max_cycles.is_none());
    }
}