sp1_core_executor/
context.rs1use crate::hook::{hookify, BoxedHook, HookEnv, HookRegistry};
2use core::mem::take;
3use hashbrown::HashMap;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use std::io::Write;
6
7use sp1_primitives::consts::fd::LOWEST_ALLOWED_FD;
8
/// Exit status of a program run.
///
/// Only two concrete codes exist: [`StatusCode::SUCCESS`] (`0`) and [`StatusCode::PANIC`]
/// (`1`). [`StatusCode::ANY`] is a wildcard that accepts either of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatusCode(u32);

impl StatusCode {
    /// Exit code `0`.
    pub const SUCCESS: Self = Self(0);
    /// Exit code `1`.
    pub const PANIC: Self = Self(1);
    /// Wildcard (`u32::MAX`) that accepts any valid exit code.
    pub const ANY: Self = Self(u32::MAX);

    /// Converts a raw exit code into a [`StatusCode`].
    ///
    /// Returns `None` for anything other than `0` or `1`; in particular the wildcard
    /// value cannot be produced through this constructor.
    #[must_use]
    pub const fn new(code: u32) -> Option<Self> {
        if code == Self::SUCCESS.0 {
            Some(Self::SUCCESS)
        } else if code == Self::PANIC.0 {
            Some(Self::PANIC)
        } else {
            None
        }
    }

    /// Returns the raw numeric value of this status code.
    #[must_use]
    pub const fn as_u32(&self) -> u32 {
        let Self(raw) = *self;
        raw
    }

    /// Returns `true` if the raw exit `code` satisfies this expectation.
    ///
    /// `code` must be a valid exit code (`0` or `1`), and must either equal this status
    /// or this status must be the [`StatusCode::ANY`] wildcard.
    #[must_use]
    pub const fn is_accepted_code(&self, code: u32) -> bool {
        let is_valid_code = matches!(code, 0 | 1);
        let is_wildcard = self.0 == Self::ANY.0;
        is_valid_code && (is_wildcard || code == self.0)
    }
}
52
53#[derive(Clone)]
55pub struct SP1Context<'a> {
56 pub hook_registry: Option<HookRegistry<'a>>,
60
61 pub max_cycles: Option<u64>,
63
64 pub deferred_proof_verification: bool,
66
67 pub expected_exit_code: StatusCode,
69
70 pub calculate_gas: bool,
75
76 pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
79
80 pub io_options: IoOptions<'a>,
82}
83
84impl Default for SP1Context<'_> {
85 fn default() -> Self {
86 Self::builder().build()
87 }
88}
89
90pub struct SP1ContextBuilder<'a> {
92 no_default_hooks: bool,
93 hook_registry_entries: Vec<(u32, BoxedHook<'a>)>,
94 max_cycles: Option<u64>,
95 deferred_proof_verification: bool,
96 calculate_gas: bool,
97 expected_exit_code: Option<StatusCode>,
98 proof_nonce: [u32; 4],
99 io_options: IoOptions<'a>,
101}
102
103impl Default for SP1ContextBuilder<'_> {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl<'a> SP1Context<'a> {
110 #[must_use]
112 pub fn builder() -> SP1ContextBuilder<'a> {
113 SP1ContextBuilder::new()
114 }
115}
116
117impl<'a> SP1ContextBuilder<'a> {
118 #[must_use]
122 pub const fn new() -> Self {
123 Self {
124 no_default_hooks: false,
125 hook_registry_entries: Vec::new(),
126 max_cycles: None,
127 deferred_proof_verification: true,
129 calculate_gas: true,
130 expected_exit_code: None,
131 proof_nonce: [0, 0, 0, 0], io_options: IoOptions::new(),
133 }
134 }
135
136 pub fn build(&mut self) -> SP1Context<'a> {
140 let hook_registry =
146 (!self.hook_registry_entries.is_empty() || self.no_default_hooks).then(|| {
147 let mut table = if take(&mut self.no_default_hooks) {
148 HashMap::default()
149 } else {
150 HookRegistry::default().table
151 };
152
153 self.hook_registry_entries
154 .iter()
155 .map(|(fd, _)| fd)
156 .filter(|fd| table.contains_key(*fd))
157 .for_each(|fd| {
158 tracing::warn!("Overriding default hook with file descriptor {}", fd);
159 });
160
161 table.extend(take(&mut self.hook_registry_entries));
163 HookRegistry { table }
164 });
165
166 let cycle_limit = take(&mut self.max_cycles);
167 let deferred_proof_verification = take(&mut self.deferred_proof_verification);
168 let calculate_gas = take(&mut self.calculate_gas);
169 let proof_nonce = take(&mut self.proof_nonce);
170 SP1Context {
171 hook_registry,
172 max_cycles: cycle_limit,
173 deferred_proof_verification,
174 calculate_gas,
175 proof_nonce,
176 io_options: take(&mut self.io_options),
177 expected_exit_code: self.expected_exit_code.unwrap_or(StatusCode::SUCCESS),
178 }
179 }
180
181 pub fn hook(
190 &mut self,
191 fd: u32,
192 f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
193 ) -> &mut Self {
194 assert!(fd > LOWEST_ALLOWED_FD, "Hook file descriptors must be greater than 10.");
195
196 self.hook_registry_entries.push((fd, hookify(f)));
197 self
198 }
199
200 pub fn without_default_hooks(&mut self) -> &mut Self {
205 self.no_default_hooks = true;
206 self
207 }
208
209 pub fn calculate_gas(&mut self, value: bool) -> &mut Self {
216 self.calculate_gas = value;
217 self
218 }
219
220 pub fn max_cycles(&mut self, max_cycles: u64) -> &mut Self {
223 self.max_cycles = Some(max_cycles);
224 self
225 }
226
227 pub fn set_deferred_proof_verification(&mut self, value: bool) -> &mut Self {
229 self.deferred_proof_verification = value;
230 self
231 }
232
233 pub fn stdout<W: IoWriter>(&mut self, writer: &'a mut W) -> &mut Self {
235 self.io_options.stdout = Some(writer);
236 self
237 }
238
239 pub fn stderr<W: IoWriter>(&mut self, writer: &'a mut W) -> &mut Self {
241 self.io_options.stderr = Some(writer);
242 self
243 }
244
245 pub fn expected_exit_code(&mut self, code: StatusCode) -> &mut Self {
247 self.expected_exit_code = Some(code);
248 self
249 }
250
251 pub fn proof_nonce(&mut self, nonce: [u32; 4]) -> &mut Self {
254 self.proof_nonce = nonce;
255 self
256 }
257}
258
/// Optional stdout/stderr writers attached to a context.
///
/// Because the writers are exclusive `&mut` borrows, they cannot be duplicated: cloning an
/// `IoOptions` yields one with no writers at all.
#[derive(Default)]
pub struct IoOptions<'a> {
    /// Writer for stdout output, if one was provided.
    pub stdout: Option<&'a mut dyn IoWriter>,
    /// Writer for stderr output, if one was provided.
    pub stderr: Option<&'a mut dyn IoWriter>,
}

impl IoOptions<'_> {
    /// Options with neither writer set; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        IoOptions { stdout: None, stderr: None }
    }
}

impl Clone for IoOptions<'_> {
    /// Returns options with **no** writers; the borrowed writers are not shared.
    fn clone(&self) -> Self {
        Self::new()
    }
}

/// A thread-safe [`Write`] sink that can be used as a stdout/stderr writer.
pub trait IoWriter: Send + Sync + Write {}

/// Every `Write + Send + Sync` type automatically qualifies as an [`IoWriter`].
impl<T> IoWriter for T where T: Write + Send + Sync {}
291
#[cfg(test)]
mod tests {
    use super::{SP1Context, StatusCode, LOWEST_ALLOWED_FD};

    /// A freshly built context carries every builder default.
    #[test]
    fn defaults() {
        let ctx = SP1Context::builder().build();
        assert!(ctx.hook_registry.is_none());
        assert!(ctx.max_cycles.is_none());
        assert!(ctx.deferred_proof_verification);
        assert!(ctx.calculate_gas);
        assert_eq!(ctx.expected_exit_code, StatusCode::SUCCESS);
        assert!(ctx.proof_nonce.iter().all(|&word| word == 0));
        assert!(ctx.io_options.stdout.is_none());
        assert!(ctx.io_options.stderr.is_none());
    }

    #[test]
    fn without_default_hooks() {
        let SP1Context { hook_registry, .. } =
            SP1Context::builder().without_default_hooks().build();
        assert!(hook_registry.unwrap().table.is_empty());
    }

    #[test]
    fn with_custom_hook() {
        let SP1Context { hook_registry, .. } =
            SP1Context::builder().hook(30, |_, _| vec![]).build();
        assert!(hook_registry.unwrap().table.contains_key(&30));
    }

    #[test]
    fn without_default_hooks_with_custom_hook() {
        let SP1Context { hook_registry, .. } =
            SP1Context::builder().without_default_hooks().hook(30, |_, _| vec![]).build();
        assert_eq!(&hook_registry.unwrap().table.into_keys().collect::<Vec<_>>(), &[30]);
    }

    /// Every setter is reflected in the built context.
    #[test]
    fn setters_are_applied() {
        let ctx = SP1Context::builder()
            .max_cycles(1_000)
            .calculate_gas(false)
            .set_deferred_proof_verification(false)
            .expected_exit_code(StatusCode::PANIC)
            .proof_nonce([1, 2, 3, 4])
            .build();
        assert_eq!(ctx.max_cycles, Some(1_000));
        assert!(!ctx.calculate_gas);
        assert!(!ctx.deferred_proof_verification);
        assert_eq!(ctx.expected_exit_code, StatusCode::PANIC);
        assert_eq!(ctx.proof_nonce, [1, 2, 3, 4]);
    }

    /// File descriptors at or below the reserved range are rejected.
    #[test]
    #[should_panic(expected = "Hook file descriptors must be greater than")]
    fn hook_rejects_reserved_fd() {
        SP1Context::builder().hook(LOWEST_ALLOWED_FD, |_, _| vec![]);
    }

    #[test]
    fn status_code_acceptance() {
        assert!(StatusCode::ANY.is_accepted_code(0));
        assert!(StatusCode::ANY.is_accepted_code(1));
        assert!(!StatusCode::ANY.is_accepted_code(2));
        assert!(StatusCode::SUCCESS.is_accepted_code(0));
        assert!(!StatusCode::SUCCESS.is_accepted_code(1));
        assert!(StatusCode::PANIC.is_accepted_code(1));
        assert_eq!(StatusCode::new(7), None);
    }
}