sp1_core_executor/
context.rs1use crate::hook::{hookify, BoxedHook, HookEnv, HookRegistry};
2use core::mem::take;
3use hashbrown::HashMap;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use std::io::Write;
6
7use sp1_primitives::consts::fd::LOWEST_ALLOWED_FD;
8
/// Exit status of a program execution.
///
/// Only the codes accepted by [`StatusCode::new`] are valid exit codes. [`StatusCode::ANY`] is a
/// wildcard meant to be used as an *expected* status: it accepts every valid exit code but is not
/// itself a valid exit code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatusCode(u32);

impl StatusCode {
    /// Successful termination (raw code `0`).
    pub const SUCCESS: Self = Self(0);
    /// Panic (raw code `1`).
    pub const PANIC: Self = Self(1);
    /// Invalid hint (raw code `3`).
    pub const INVALID_HINT: Self = Self(3);
    /// Wildcard that accepts any valid exit code (raw value `u32::MAX`).
    pub const ANY: Self = Self(u32::MAX);

    /// Converts a raw exit code into a [`StatusCode`].
    ///
    /// Returns `None` for any code that is not a valid exit code, including `u32::MAX`
    /// (the raw value of the [`StatusCode::ANY`] wildcard).
    #[must_use]
    pub const fn new(code: u32) -> Option<Self> {
        match code {
            0 => Some(Self::SUCCESS),
            1 => Some(Self::PANIC),
            3 => Some(Self::INVALID_HINT),
            _ => None,
        }
    }

    /// Returns the raw exit code.
    #[must_use]
    pub const fn as_u32(&self) -> u32 {
        self.0
    }

    /// Returns `true` if `code` is a valid exit code that satisfies this expected status.
    ///
    /// [`StatusCode::ANY`] accepts every valid code; any other status accepts only its own code.
    /// Invalid codes are never accepted.
    #[must_use]
    pub const fn is_accepted_code(&self, code: u32) -> bool {
        // Delegate validity to `new` so the set of valid codes is defined in exactly one place.
        Self::new(code).is_some() && (self.0 == Self::ANY.0 || self.0 == code)
    }
}
59
60#[derive(Clone)]
62pub struct SP1Context<'a> {
63 pub hook_registry: Option<HookRegistry<'a>>,
67
68 pub max_cycles: Option<u64>,
70
71 pub deferred_proof_verification: bool,
73
74 pub expected_exit_code: StatusCode,
76
77 pub calculate_gas: bool,
82
83 pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
86
87 pub io_options: IoOptions<'a>,
89}
90
91impl Default for SP1Context<'_> {
92 fn default() -> Self {
93 Self::builder().build()
94 }
95}
96
97pub struct SP1ContextBuilder<'a> {
99 no_default_hooks: bool,
100 hook_registry_entries: Vec<(u32, BoxedHook<'a>)>,
101 max_cycles: Option<u64>,
102 deferred_proof_verification: bool,
103 calculate_gas: bool,
104 expected_exit_code: Option<StatusCode>,
105 proof_nonce: [u32; 4],
106 io_options: IoOptions<'a>,
108}
109
110impl Default for SP1ContextBuilder<'_> {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl<'a> SP1Context<'a> {
117 #[must_use]
119 pub fn builder() -> SP1ContextBuilder<'a> {
120 SP1ContextBuilder::new()
121 }
122}
123
124impl<'a> SP1ContextBuilder<'a> {
125 #[must_use]
129 pub const fn new() -> Self {
130 Self {
131 no_default_hooks: false,
132 hook_registry_entries: Vec::new(),
133 max_cycles: None,
134 deferred_proof_verification: true,
136 calculate_gas: true,
137 expected_exit_code: None,
138 proof_nonce: [0, 0, 0, 0], io_options: IoOptions::new(),
140 }
141 }
142
143 pub fn build(&mut self) -> SP1Context<'a> {
147 let hook_registry =
153 (!self.hook_registry_entries.is_empty() || self.no_default_hooks).then(|| {
154 let mut table = if take(&mut self.no_default_hooks) {
155 HashMap::default()
156 } else {
157 HookRegistry::default().table
158 };
159
160 self.hook_registry_entries
161 .iter()
162 .map(|(fd, _)| fd)
163 .filter(|fd| table.contains_key(*fd))
164 .for_each(|fd| {
165 tracing::warn!("Overriding default hook with file descriptor {}", fd);
166 });
167
168 table.extend(take(&mut self.hook_registry_entries));
170 HookRegistry { table }
171 });
172
173 let cycle_limit = take(&mut self.max_cycles);
174 let deferred_proof_verification = take(&mut self.deferred_proof_verification);
175 let calculate_gas = take(&mut self.calculate_gas);
176 let proof_nonce = take(&mut self.proof_nonce);
177 SP1Context {
178 hook_registry,
179 max_cycles: cycle_limit,
180 deferred_proof_verification,
181 calculate_gas,
182 proof_nonce,
183 io_options: take(&mut self.io_options),
184 expected_exit_code: self.expected_exit_code.unwrap_or(StatusCode::SUCCESS),
185 }
186 }
187
188 pub fn hook(
197 &mut self,
198 fd: u32,
199 f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
200 ) -> &mut Self {
201 assert!(fd > LOWEST_ALLOWED_FD, "Hook file descriptors must be greater than 10.");
202
203 self.hook_registry_entries.push((fd, hookify(f)));
204 self
205 }
206
207 pub fn without_default_hooks(&mut self) -> &mut Self {
212 self.no_default_hooks = true;
213 self
214 }
215
216 pub fn calculate_gas(&mut self, value: bool) -> &mut Self {
223 self.calculate_gas = value;
224 self
225 }
226
227 pub fn max_cycles(&mut self, max_cycles: u64) -> &mut Self {
230 self.max_cycles = Some(max_cycles);
231 self
232 }
233
234 pub fn set_deferred_proof_verification(&mut self, value: bool) -> &mut Self {
236 self.deferred_proof_verification = value;
237 self
238 }
239
240 pub fn stdout<W: IoWriter>(&mut self, writer: &'a mut W) -> &mut Self {
242 self.io_options.stdout = Some(writer);
243 self
244 }
245
246 pub fn stderr<W: IoWriter>(&mut self, writer: &'a mut W) -> &mut Self {
248 self.io_options.stderr = Some(writer);
249 self
250 }
251
252 pub fn expected_exit_code(&mut self, code: StatusCode) -> &mut Self {
254 self.expected_exit_code = Some(code);
255 self
256 }
257
258 pub fn proof_nonce(&mut self, nonce: [u32; 4]) -> &mut Self {
261 self.proof_nonce = nonce;
262 self
263 }
264}
265
266#[derive(Default)]
272pub struct IoOptions<'a> {
273 pub stdout: Option<&'a mut dyn IoWriter>,
275 pub stderr: Option<&'a mut dyn IoWriter>,
277}
278
279impl IoOptions<'_> {
280 #[must_use]
282 pub const fn new() -> Self {
283 Self { stdout: None, stderr: None }
284 }
285}
286impl Clone for IoOptions<'_> {
287 fn clone(&self) -> Self {
288 IoOptions { stdout: None, stderr: None }
289 }
290}
291
/// A [`Write`] sink that is `Send + Sync`; the writer type stored in [`IoOptions`].
pub trait IoWriter: Write + Send + Sync {}

/// Every `Write + Send + Sync` type is automatically an [`IoWriter`].
impl<T> IoWriter for T where T: Write + Send + Sync {}
298
#[cfg(test)]
mod tests {
    use crate::SP1Context;

    #[test]
    fn defaults() {
        let ctx = SP1Context::builder().build();
        assert!(ctx.hook_registry.is_none());
        assert!(ctx.max_cycles.is_none());
    }

    #[test]
    fn without_default_hooks() {
        let ctx = SP1Context::builder().without_default_hooks().build();
        let registry = ctx.hook_registry.unwrap();
        assert!(registry.table.is_empty());
    }

    #[test]
    fn with_custom_hook() {
        let ctx = SP1Context::builder().hook(30, |_, _| vec![]).build();
        let registry = ctx.hook_registry.unwrap();
        assert!(registry.table.contains_key(&30));
    }

    #[test]
    fn without_default_hooks_with_custom_hook() {
        let ctx = SP1Context::builder().without_default_hooks().hook(30, |_, _| vec![]).build();
        let fds: Vec<_> = ctx.hook_registry.unwrap().table.into_keys().collect();
        assert_eq!(&fds, &[30]);
    }
}