sp1_jit/lib.rs
1#![cfg_attr(not(sp1_native_executor_available), allow(unused))]
2
3#[cfg(not(target_endian = "little"))]
4compile_error!("This crate is only supported on little endian targets.");
5
6pub mod backends;
7pub mod context;
8pub mod debug;
9pub mod instructions;
10mod macros;
11pub mod memory;
12pub mod risc;
13pub mod shm;
14
15use dynasmrt::ExecutableBuffer;
16use hashbrown::HashMap;
17use std::{
18 collections::VecDeque,
19 io,
20 ops::{Deref, DerefMut},
21 os::fd::AsRawFd,
22 ptr::NonNull,
23 sync::{mpsc, Arc},
24};
25
26pub use backends::*;
27pub use context::*;
28pub use instructions::*;
29pub use risc::*;
30
31/// A function that accepts the memory pointer.
32pub type ExternFn = extern "C" fn(*mut JitContext);
33
34pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;
35
36/// A debugging utility to inspect registers
37pub type DebugFn = extern "C" fn(u64);
38
39/// A transpiler for risc32 instructions.
40///
41/// This trait is implemented for each target architecture supported by the JIT transpiler.
42///
43/// The transpiler is responsible for translating the risc32 instructions into the target
44/// architecture's instruction set.
45///
46/// This transpiler should generate an entrypoint of the form: [`fn(*mut JitContext)`]
47///
48/// For each instruction, you will typically want to call [`SP1RiscvTranspiler::start_instr`]
49/// before transpiling the instruction. This maps a "riscv instruction index" to some physical
50/// native address, as there are multiple native instructions per riscv instruction.
51///
52/// You will also likely want to call [`SP1RiscvTranspiler::bump_clk`] to increment the clock
53/// counter, and [`SP1RiscvTranspiler::set_pc`] to set the PC.
54///
55/// # Note
56/// Some instructions will directly modify the PC, such as [`SP1RiscvTranspiler::jal`] and
57/// [`SP1RiscvTranspiler::jalr`], and all the branch instructions, for these instructions, you would
58/// not want to call [`SP1RiscvTranspiler::set_pc`] as it will be called for you.
59///
60///
61/// ```rust,no_run,ignore
62/// pub fn add_program() {
63/// let mut transpiler = SP1RiscvTranspiler::new(program_size, memory_size, trace_buf_size, 100, 100).unwrap();
64///
65/// // Transpile the first instruction.
66/// transpiler.start_instr();
67/// transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
68/// transpiler.end_instr();
69///
70/// // Transpile the second instruction.
71/// transpiler.start_instr();
72///
73/// transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
74/// transpiler.end_instr();
75///
76/// let mut func = transpiler.finalize();
77///
78/// // Call the function.
79/// let traces = func.call();
80///
81/// // do stuff with the traces.
82/// }
83/// ```
84pub trait RiscvTranspiler:
85 TraceCollector
86 + ComputeInstructions
87 + ControlFlowInstructions
88 + MemoryInstructions
89 + SystemInstructions
90 + Sized
91{
92 /// Create a new transpiler.
93 ///
94 /// The program is used for the jump-table and is not a hard limit on the size of the program.
95 /// The memory size is the exact amount that will be allocated for the program.
96 fn new(
97 program_size: usize,
98 memory_size: usize,
99 max_trace_size: u64,
100 pc_start: u64,
101 pc_base: u64,
102 clk_bump: u64,
103 ) -> Result<Self, std::io::Error>;
104
105 /// Register a rust function of the form [`EcallHandler`] that will be used as the ECALL.
106 fn register_ecall_handler(&mut self, handler: EcallHandler);
107
108 /// Populates a jump table entry for the current instruction being transpiled.
109 ///
110 /// Effectively should create a mapping from RISCV PC -> absolute address of the instruction.
111 ///
112 /// This method should be called for "each pc" in the program.
113 fn start_instr(&mut self);
114
115 /// This method should be called for "each pc" in the program.
116 /// Handle logics when finishing execution of an instruction such as bumping clk and jump to
117 /// branch destination.
118 fn end_instr(&mut self);
119
120 /// Inspcet a [RiscRegister] using a function pointer.
121 ///
122 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
123 fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);
124
125 /// Print an immediate value.
126 ///
127 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
128 fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);
129
130 /// Call an [ExternFn] from the outputted assembly.
131 ///
132 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
133 fn call_extern_fn(&mut self, handler: ExternFn);
134
135 /// Returns the function pointer to the generated code.
136 ///
137 /// This function is expected to be of the form: `fn(*mut JitContext)`.
138 fn finalize<M: JitMemory>(self) -> io::Result<JitFunction<M>>;
139}
140
141/// A trait the collects traces, in the form [TraceChunk].
142///
143/// This type is expected to follow the conventions as described in the [TraceChunk] documentation.
144pub trait TraceCollector {
145 /// Write the current state of the registers into the trace buf.
146 ///
147 /// For SP1 this is only called once in the beginning of a "chunk".
148 fn trace_registers(&mut self);
149
150 /// Write the value located at rs1 + imm into the trace buf.
151 fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);
152
153 /// Write the start pc of the trace chunk.
154 fn trace_pc_start(&mut self);
155
156 /// Write the start clk of the trace chunk.
157 fn trace_clk_start(&mut self);
158
159 /// Write the end clk of the trace chunk.
160 fn trace_clk_end(&mut self);
161}
162
/// Types that can emit a debug dump of the JIT execution context.
///
/// Every [`RiscvTranspiler`] implements this through a blanket impl. Intended for debugging only.
pub trait Debuggable {
    /// Emit (or perform) a dump of the current execution context.
    ///
    /// For transpilers, this emits a call that prints the PC, clock and registers to stderr when
    /// the generated code runs.
    fn print_ctx(&mut self);
}
166
/// A memory region that generated code can execute against.
///
/// The region is exposed as a byte slice (via `Deref`/`DerefMut`). It also exposes a raw file
/// descriptor, which [`JitFunction::call`] hands to the generated code through the context's
/// `memory_fd` field.
pub trait JitMemory: Sized + Deref<Target = [u8]> + DerefMut + AsRawFd {
    /// Create a new memory region.
    ///
    /// `memory_size` is the size requested by the transpiler (presumably in bytes — confirm per
    /// implementation).
    fn new(memory_size: usize) -> Self;
}

/// A [`JitMemory`] that can also be reset ("resettable").
pub trait JitResetableMemory: JitMemory {
    /// Restore the memory to its initial state.
    ///
    /// `JitFunction::reset` calls this and then re-applies the initial memory image.
    fn reset(&mut self);
}
176
177impl<T: RiscvTranspiler> Debuggable for T {
178 // Useful only for debugging.
179 fn print_ctx(&mut self) {
180 extern "C" fn print_ctx(ctx: *mut JitContext) {
181 let ctx = unsafe { &mut *ctx };
182 eprintln!("pc: {:x}", ctx.pc);
183 eprintln!("clk: {}", ctx.clk);
184 eprintln!("{:?}", *ctx.registers());
185 }
186
187 self.call_extern_fn(print_ctx);
188 }
189}
190
/// Placeholder used when the native executor is not available (`sp1_native_executor_available`
/// is unset), so that code naming `JitFunction` still compiles. It carries no state.
#[cfg(not(sp1_native_executor_available))]
pub struct JitFunction<M> {
    _marker: std::marker::PhantomData<M>,
}
196
197/// A type representing a JIT compiled function.
198///
199/// The underlying function should be of the form [`fn(*mut JitContext)`].
200#[cfg(sp1_native_executor_available)]
201pub struct JitFunction<M> {
202 jump_table: Vec<*const u8>,
203 code: ExecutableBuffer,
204
205 /// The initial memory image.
206 initial_memory_image: Arc<HashMap<u64, u64>>,
207 pc_start: u64,
208 input_buffer: VecDeque<Vec<u8>>,
209
210 /// A stream of public values from the program (global to entire program).
211 pub public_values_stream: Vec<u8>,
212
213 /// Memory structure,
214 pub memory: M,
215
216 /// During execution, the hints are read by the program, and we store them here.
217 /// This is effectively a mapping from start address to the value of the hint.
218 pub hints: Vec<(u64, Vec<u8>)>,
219
220 pub pc: u64,
221 pub registers: [u64; 32],
222 pub clk: u64,
223 pub global_clk: u64,
224 pub exit_code: u32,
225
226 pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
227}
228
229unsafe impl<M: Send> Send for JitFunction<M> {}
230
#[cfg(sp1_native_executor_available)]
impl<M: JitMemory> JitFunction<M> {
    /// Wrap freshly generated code into a runnable [`JitFunction`].
    ///
    /// `jump_table` holds byte offsets into `code`; they are rebased here into absolute
    /// addresses. A fresh memory region is created with `M::new(memory_size)`, and execution
    /// starts at `pc_start` with `clk = 1`.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    pub(crate) fn new(
        code: ExecutableBuffer,
        jump_table: Vec<usize>,
        memory_size: usize,
        pc_start: u64,
    ) -> std::io::Result<Self> {
        // Rebase the jump table from offsets into `code` to absolute addresses. `wrapping_add`
        // computes the same address as `add` but without an in-bounds proof obligation here;
        // the offsets are only dereferenced by the generated code.
        let buf_ptr = code.as_ptr();
        let jump_table = jump_table.into_iter().map(|offset| buf_ptr.wrapping_add(offset)).collect();

        Ok(Self {
            jump_table,
            code,
            memory: M::new(memory_size),
            pc: pc_start,
            clk: 1,
            global_clk: 0,
            registers: [0; 32],
            initial_memory_image: Arc::new(HashMap::new()),
            pc_start,
            input_buffer: VecDeque::new(),
            hints: Vec::new(),
            public_values_stream: Vec::new(),
            debug_sender: None,
            exit_code: 0,
        })
    }

    /// Write the initial memory image to the JIT memory.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC, or if an image entry lies outside the memory
    /// region.
    pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
        assert!(
            self.pc == self.pc_start,
            "The initial memory should only be supplied before using the JIT function."
        );

        self.initial_memory_image = memory;
        self.insert_memory_image();
    }

    /// Push an input to the input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC.
    pub fn push_input(&mut self, input: Vec<u8>) {
        assert!(
            self.pc == self.pc_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        self.input_buffer.push_back(input);

        // Reserve one hint slot per input, mirroring `set_input_buffer`.
        self.hints.reserve(1);
    }

    /// Set the entire input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC.
    pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
        assert!(
            self.pc == self.pc_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        // Reserve the space for the hints.
        self.hints.reserve(input.len());
        self.input_buffer = input;
    }

    /// Run the compiled code, resuming from the current PC.
    ///
    /// The PC, registers, clocks and exit code are loaded from `self` before the call and
    /// written back afterwards, so repeated calls continue where the previous one stopped. If
    /// the PC equals `1`, the program is treated as halted and this returns immediately.
    ///
    /// Pass a null `trace_buf_ptr` to run without tracing.
    ///
    /// # Safety
    ///
    /// - Relies on the builder to have emitted valid code for the `extern "C"` entrypoint.
    /// - `trace_buf_ptr` must be null, or valid (and large enough for whatever the generated
    ///   code writes) for the duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if the memory region is too small to contain its own `u64`-aligned start.
    pub unsafe fn call(&mut self, trace_buf_ptr: *mut u8) {
        // A PC of 1 is the halt sentinel: there is nothing left to execute.
        if self.pc == 1 {
            return;
        }

        // The entrypoint is hand-emitted machine code, so call it through the stable C ABI
        // rather than the unspecified Rust ABI.
        let entry = std::mem::transmute::<*const u8, extern "C" fn(*mut JitContext)>(
            self.code.as_ptr(),
        );

        let is_tracing = !trace_buf_ptr.is_null();
        let memory_fd = self.memory.as_raw_fd();
        // The generated code sees memory starting at its first `u64`-aligned byte; the initial
        // memory image is written relative to the same base.
        let memory_offset = self.aligned_memory_offset();

        // Pointer validity for the duration of the call:
        // - The jump table, memory, input buffer, hints and public values stream are owned by
        //   `self`, which stays mutably borrowed until the call returns.
        // - The trace buffer is null or valid, per this function's safety contract.
        let mut ctx = JitContext {
            jump_table: NonNull::from(self.jump_table.as_mut_slice()).cast::<*const u8>(),
            memory: NonNull::from(&mut self.memory[memory_offset..]).cast::<u8>(),
            trace_buf: trace_buf_ptr,
            input_buffer: NonNull::from(&mut self.input_buffer),
            hints: NonNull::from(&mut self.hints),
            maybe_unconstrained: None,
            public_values_stream: NonNull::from(&mut self.public_values_stream),
            memory_fd,
            registers: self.registers,
            pc: self.pc,
            clk: self.clk,
            global_clk: self.global_clk,
            is_unconstrained: 0,
            tracing: is_tracing,
            debug_sender: self.debug_sender.clone(),
            exit_code: self.exit_code,
        };

        tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
            entry(&mut ctx);
        });

        // Persist the execution state so the next call resumes from here.
        self.pc = ctx.pc;
        self.registers = ctx.registers;
        self.clk = ctx.clk;
        self.global_clk = ctx.global_clk;
        self.exit_code = ctx.exit_code;
    }

    /// Byte offset from the start of `self.memory` to its first `u64`-aligned address.
    ///
    /// Both `call` (the pointer handed to generated code) and `insert_memory_image` address
    /// memory relative to this base, so they must agree on it.
    fn aligned_memory_offset(&self) -> usize {
        self.memory.as_ptr().align_offset(std::mem::align_of::<u64>())
    }

    /// Copy every `(addr, value)` pair of the initial memory image into JIT memory.
    ///
    /// Guest address `addr` (expected to be 8-byte aligned) maps to byte offset `2 * addr + 8`
    /// from the aligned memory base. Each 8-byte guest word therefore occupies a 16-byte slot,
    /// with the value in its second 8 bytes (NOTE(review): presumably the first 8 bytes hold
    /// per-word metadata — confirm against the backend). Values are written little-endian.
    ///
    /// # Panics
    ///
    /// Panics if an entry falls outside the memory region, or (in debug builds) if an address
    /// is not 8-byte aligned.
    fn insert_memory_image(&mut self) {
        let base = self.aligned_memory_offset();
        for (&addr, &value) in self.initial_memory_image.iter() {
            debug_assert!(addr % 8 == 0, "Address {addr} is not aligned to 8");

            // Bounds-checked copy: an image larger than the memory region panics instead of
            // writing out of bounds.
            let start = base + (2 * addr + 8) as usize;
            self.memory[start..start + 8].copy_from_slice(&value.to_le_bytes());
        }
    }
}
388
#[cfg(sp1_native_executor_available)]
impl<M: JitResetableMemory> JitFunction<M> {
    /// Reset the JIT function to its initial state.
    ///
    /// This does the following:
    /// - rewinds the PC to the starting PC;
    /// - clears the registers, the clocks and the exit code;
    /// - discards the inputs, hints and public values;
    /// - resets the memory and re-applies the initial memory image.
    ///
    /// The compiled code, the initial memory image and the debug sender are kept.
    pub fn reset(&mut self) {
        self.pc = self.pc_start;
        self.registers = [0; 32];
        self.clk = 1;
        self.global_clk = 0;
        // A previous run may have recorded a non-zero exit code; start fresh as `new` does.
        self.exit_code = 0;
        self.input_buffer.clear();
        self.hints.clear();
        self.public_values_stream.clear();
        self.memory.reset();

        self.insert_memory_image();
    }
}