sp1_jit/lib.rs
// Off Linux only a stub `JitFunction` is compiled, so most of the crate is unused there;
// silence those warnings instead of sprinkling per-item cfgs.
#![cfg_attr(not(target_os = "linux"), allow(unused))]

// Memory values are written as little-endian bytes (`insert_memory_image`) and read back as
// native integers (`MemoryView::get`), so only little-endian targets are supported.
#[cfg(not(target_endian = "little"))]
compile_error!("This crate is only supported on little endian targets.");
5
6pub mod backends;
7pub mod context;
8pub mod debug;
9pub mod instructions;
10mod macros;
11pub mod risc;
12
13use dynasmrt::ExecutableBuffer;
14use hashbrown::HashMap;
15use memmap2::{MmapMut, MmapOptions};
16use std::{
17 collections::VecDeque,
18 io,
19 os::fd::AsRawFd,
20 ptr::NonNull,
21 sync::{mpsc, Arc},
22};
23
24pub use backends::*;
25pub use context::*;
26pub use instructions::*;
27pub use risc::*;
28
29/// A function that accepts the memory pointer.
30pub type ExternFn = extern "C" fn(*mut JitContext);
31
32pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;
33
34/// A debugging utility to inspect registers
35pub type DebugFn = extern "C" fn(u64);
36
37/// A transpiler for risc32 instructions.
38///
39/// This trait is implemented for each target architecture supported by the JIT transpiler.
40///
41/// The transpiler is responsible for translating the risc32 instructions into the target
42/// architecture's instruction set.
43///
44/// This transpiler should generate an entrypoint of the form: [`fn(*mut JitContext)`]
45///
46/// For each instruction, you will typically want to call [`SP1RiscvTranspiler::start_instr`]
47/// before transpiling the instruction. This maps a "riscv instruction index" to some physical
48/// native address, as there are multiple native instructions per riscv instruction.
49///
50/// You will also likely want to call [`SP1RiscvTranspiler::bump_clk`] to increment the clock
51/// counter, and [`SP1RiscvTranspiler::set_pc`] to set the PC.
52///
53/// # Note
54/// Some instructions will directly modify the PC, such as [`SP1RiscvTranspiler::jal`] and
55/// [`SP1RiscvTranspiler::jalr`], and all the branch instructions, for these instructions, you would
56/// not want to call [`SP1RiscvTranspiler::set_pc`] as it will be called for you.
57///
58///
59/// ```rust,no_run,ignore
60/// pub fn add_program() {
61/// let mut transpiler = SP1RiscvTranspiler::new(program_size, memory_size, trace_buf_size, 100, 100).unwrap();
62///
63/// // Transpile the first instruction.
64/// transpiler.start_instr();
65/// transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
66/// transpiler.end_instr();
67///
68/// // Transpile the second instruction.
69/// transpiler.start_instr();
70///
71/// transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
72/// transpiler.end_instr();
73///
74/// let mut func = transpiler.finalize();
75///
76/// // Call the function.
77/// let traces = func.call();
78///
79/// // do stuff with the traces.
80/// }
81/// ```
82pub trait RiscvTranspiler:
83 TraceCollector
84 + ComputeInstructions
85 + ControlFlowInstructions
86 + MemoryInstructions
87 + SystemInstructions
88 + Sized
89{
90 /// Create a new transpiler.
91 ///
92 /// The program is used for the jump-table and is not a hard limit on the size of the program.
93 /// The memory size is the exact amount that will be allocated for the program.
94 fn new(
95 program_size: usize,
96 memory_size: usize,
97 max_trace_size: u64,
98 pc_start: u64,
99 pc_base: u64,
100 clk_bump: u64,
101 ) -> Result<Self, std::io::Error>;
102
103 /// Register a rust function of the form [`EcallHandler`] that will be used as the ECALL.
104 fn register_ecall_handler(&mut self, handler: EcallHandler);
105
106 /// Populates a jump table entry for the current instruction being transpiled.
107 ///
108 /// Effectively should create a mapping from RISCV PC -> absolute address of the instruction.
109 ///
110 /// This method should be called for "each pc" in the program.
111 fn start_instr(&mut self);
112
113 /// This method should be called for "each pc" in the program.
114 /// Handle logics when finishing execution of an instruction such as bumping clk and jump to
115 /// branch destination.
116 fn end_instr(&mut self);
117
118 /// Inspcet a [RiscRegister] using a function pointer.
119 ///
120 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
121 fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);
122
123 /// Print an immediate value.
124 ///
125 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
126 fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);
127
128 /// Call an [ExternFn] from the outputted assembly.
129 ///
130 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
131 fn call_extern_fn(&mut self, handler: ExternFn);
132
133 /// Returns the function pointer to the generated code.
134 ///
135 /// This function is expected to be of the form: `fn(*mut JitContext)`.
136 fn finalize(self) -> io::Result<JitFunction>;
137}
138
139/// A trait the collects traces, in the form [TraceChunk].
140///
141/// This type is expected to follow the conventions as described in the [TraceChunk] documentation.
142pub trait TraceCollector {
143 /// Write the current state of the registers into the trace buf.
144 ///
145 /// For SP1 this is only called once in the beginning of a "chunk".
146 fn trace_registers(&mut self);
147
148 /// Write the value located at rs1 + imm into the trace buf.
149 fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);
150
151 /// Write the start pc of the trace chunk.
152 fn trace_pc_start(&mut self);
153
154 /// Write the start clk of the trace chunk.
155 fn trace_clk_start(&mut self);
156
157 /// Write the end clk of the trace chunk.
158 fn trace_clk_end(&mut self);
159}
160
/// Debugging helpers for transpilers.
///
/// Every [`RiscvTranspiler`] gets this trait through a blanket implementation.
pub trait Debuggable {
    /// Emit a call that prints the current execution context (PC, clock and registers) when
    /// the generated code reaches this point. Useful only for debugging.
    fn print_ctx(&mut self);
}
164
165impl<T: RiscvTranspiler> Debuggable for T {
166 // Useful only for debugging.
167 fn print_ctx(&mut self) {
168 extern "C" fn print_ctx(ctx: *mut JitContext) {
169 let ctx = unsafe { &mut *ctx };
170 eprintln!("pc: {:x}", ctx.pc);
171 eprintln!("clk: {}", ctx.clk);
172 eprintln!("{:?}", *ctx.registers());
173 }
174
175 self.call_extern_fn(print_ctx);
176 }
177}
178
/// Placeholder `JitFunction` so the crate still compiles on non-Linux targets.
///
/// It holds no state; JIT execution is only implemented for Linux.
#[cfg(not(target_os = "linux"))]
pub struct JitFunction {}
182
183/// A type representing a JIT compiled function.
184///
185/// The underlying function should be of the form [`fn(*mut JitContext)`].
186#[cfg(target_os = "linux")]
187pub struct JitFunction {
188 jump_table: Vec<*const u8>,
189 trace_buf_size: usize,
190 code: ExecutableBuffer,
191
192 /// The initial memory image.
193 initial_memory_image: Arc<HashMap<u64, u64>>,
194 pc_start: u64,
195 input_buffer: VecDeque<Vec<u8>>,
196
197 /// A stream of public values from the program (global to entire program).
198 pub public_values_stream: Vec<u8>,
199
200 /// Keep around the memfd, and pass it to the JIT context,
201 /// we can use this to create the COW memory at runtime.
202 mem_fd: memfd::Memfd,
203
204 /// During execution, the hints are read by the program, and we store them here.
205 /// This is effectively a mapping from start address to the value of the hint.
206 pub hints: Vec<(u64, Vec<u8>)>,
207
208 /// The JIT function may stop "in the middle" of an program,
209 /// we want to be able to resume it, so this is the information needed to do so.
210 pub memory: MmapMut,
211 pub pc: u64,
212 pub registers: [u64; 32],
213 pub clk: u64,
214 pub global_clk: u64,
215 pub exit_code: u32,
216
217 pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
218}
219
// SAFETY: `jump_table` holds raw pointers (which are not `Send`) into `code`, a buffer owned by
// this same struct, so moving the struct to another thread moves their target with them.
// NOTE(review): presumably those raw pointers are the only non-`Send` fields; this also relies on
// `ExecutableBuffer`, `MmapMut` and `memfd::Memfd` being safe to move across threads — confirm.
unsafe impl Send for JitFunction {}
221
222#[cfg(target_os = "linux")]
223impl JitFunction {
224 pub(crate) fn new(
225 code: ExecutableBuffer,
226 jump_table: Vec<usize>,
227 memory_size: usize,
228 trace_buf_size: usize,
229 pc_start: u64,
230 ) -> std::io::Result<Self> {
231 // Adjust the jump table to be absolute addresses.
232 let buf_ptr = code.as_ptr();
233 let jump_table =
234 jump_table.into_iter().map(|offset| unsafe { buf_ptr.add(offset) }).collect();
235
236 let fd = memfd::MemfdOptions::default()
237 .create(uuid::Uuid::new_v4().to_string())
238 .expect("Failed to create jit memory");
239
240 fd.as_file().set_len((memory_size + std::mem::align_of::<u64>()) as u64)?;
241
242 Ok(Self {
243 jump_table,
244 code,
245 memory: unsafe { MmapOptions::new().no_reserve_swap().map_mut(fd.as_file())? },
246 mem_fd: fd,
247 trace_buf_size,
248 pc: pc_start,
249 clk: 1,
250 global_clk: 0,
251 registers: [0; 32],
252 initial_memory_image: Arc::new(HashMap::new()),
253 pc_start,
254 input_buffer: VecDeque::new(),
255 hints: Vec::new(),
256 public_values_stream: Vec::new(),
257 debug_sender: None,
258 exit_code: 0,
259 })
260 }
261
262 /// Write the initial memory image to the JIT memory.
263 ///
264 /// # Panics
265 ///
266 /// Panics if the PC is not the starting PC.
267 pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
268 assert!(
269 self.pc == self.pc_start,
270 "The initial memory should only be supplied before using the JIT function."
271 );
272
273 self.initial_memory_image = memory;
274 self.insert_memory_image();
275 }
276
277 /// Push an input to the input buffer.
278 ///
279 /// # Panics
280 ///
281 /// Panics if the PC is not the starting PC.
282 pub fn push_input(&mut self, input: Vec<u8>) {
283 assert!(
284 self.pc == self.pc_start,
285 "The input buffer should only be supplied before using the JIT function."
286 );
287
288 self.input_buffer.push_back(input);
289
290 self.hints.reserve(1);
291 }
292
293 /// Set the entire input buffer.
294 ///
295 /// # Panics
296 ///
297 /// Panics if the PC is not the starting PC.
298 pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
299 assert!(
300 self.pc == self.pc_start,
301 "The input buffer should only be supplied before using the JIT function."
302 );
303
304 // Reserve the space for the hints.
305 self.hints.reserve(input.len());
306 self.input_buffer = input;
307 }
308
309 /// Call the function, returning the trace buffer, starting at the starting PC of the program.
310 ///
311 /// If the PC is 0, then the program has completed and we return None.
312 ///
313 /// # SAFETY
314 /// Relies on the builder to emit valid assembly
315 /// and that the pointer is valid for the duration of the function call.
316 pub unsafe fn call(&mut self) -> Option<TraceChunkRaw> {
317 if self.pc == 1 {
318 return None;
319 }
320
321 let as_fn = std::mem::transmute::<*const u8, fn(*mut JitContext)>(self.code.as_ptr());
322
323 // Ensure the pointer is aligned to the alignment of the MemValue.
324 let mut trace_buf =
325 MmapMut::map_anon(self.trace_buf_size + std::mem::align_of::<MemValue>())
326 .expect("Failed to create trace buf mmap");
327 let trace_buf_offset = trace_buf.as_ptr().align_offset(std::mem::align_of::<MemValue>());
328 let trace_buf_ptr = trace_buf.as_mut_ptr().add(trace_buf_offset);
329
330 // Ensure the memory pointer is aligned to the alignment of the u64.
331 let align_offset = self.memory.as_ptr().align_offset(std::mem::align_of::<u64>());
332 let mem_ptr = self.memory.as_mut_ptr().add(align_offset);
333 let tracing = self.trace_buf_size > 0;
334
335 // SAFETY:
336 // - The jump table is valid for the duration of the function call, its owned by self.
337 // - The memory is valid for the duration of the function call, its owned by self.
338 // - The trace buf is valid for the duration of the function call, we just allocated it
339 // - The input buffer is valid for the duration of the function call, its owned by self.
340 let mut ctx = JitContext {
341 jump_table: NonNull::new_unchecked(self.jump_table.as_mut_ptr()),
342 memory: NonNull::new_unchecked(mem_ptr),
343 trace_buf: NonNull::new_unchecked(trace_buf_ptr),
344 input_buffer: NonNull::new_unchecked(&mut self.input_buffer),
345 hints: NonNull::new_unchecked(&mut self.hints),
346 maybe_unconstrained: None,
347 public_values_stream: NonNull::new_unchecked(&mut self.public_values_stream),
348 memory_fd: self.mem_fd.as_raw_fd(),
349 registers: self.registers,
350 pc: self.pc,
351 clk: self.clk,
352 global_clk: self.global_clk,
353 is_unconstrained: 0,
354 tracing,
355 debug_sender: self.debug_sender.clone(),
356 exit_code: self.exit_code,
357 };
358
359 tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
360 as_fn(&mut ctx);
361 });
362
363 // Update the values we want to preserve.
364 self.pc = ctx.pc;
365 self.registers = ctx.registers;
366 self.clk = ctx.clk;
367 self.global_clk = ctx.global_clk;
368 self.exit_code = ctx.exit_code;
369
370 tracing.then_some(TraceChunkRaw::new(
371 trace_buf.make_read_only().expect("Failed to make trace buf read only"),
372 ))
373 }
374
375 /// Reset the JIT function to the initial state.
376 ///
377 /// This will clear the registers, the program counter, the clock, and the memory, restoring the
378 /// initial memory image.
379 pub fn reset(&mut self) {
380 self.pc = self.pc_start;
381 self.registers = [0; 32];
382 self.clk = 1;
383 self.global_clk = 0;
384 self.input_buffer = VecDeque::new();
385 self.hints = Vec::new();
386 self.public_values_stream = Vec::new();
387
388 // Store the original size of the memory.
389 let memory_size = self.memory.len();
390
391 // Create a new memfd for the backing memory.
392 self.mem_fd = memfd::MemfdOptions::default()
393 .create(uuid::Uuid::new_v4().to_string())
394 .expect("Failed to create jit memory");
395
396 self.mem_fd
397 .as_file()
398 .set_len(memory_size as u64)
399 .expect("Failed to set memfd size for backing memory.");
400
401 self.memory = unsafe {
402 MmapOptions::new()
403 .no_reserve_swap()
404 .map_mut(self.mem_fd.as_file())
405 .expect("Failed to map memory")
406 };
407
408 self.insert_memory_image();
409 }
410
411 fn insert_memory_image(&mut self) {
412 for (addr, val) in self.initial_memory_image.iter() {
413 // Technically, this crate is probably only used on little endian targets, but just to
414 // sure.
415 let bytes = val.to_le_bytes();
416
417 #[cfg(debug_assertions)]
418 if addr % 8 > 0 {
419 panic!("Address {addr} is not aligned to 8");
420 }
421
422 let actual_addr = 2 * addr + 8;
423 unsafe {
424 std::ptr::copy_nonoverlapping(
425 bytes.as_ptr(),
426 self.memory.as_mut_ptr().add(actual_addr as usize),
427 bytes.len(),
428 )
429 };
430 }
431 }
432}
433
434pub struct MemoryView<'a> {
435 pub memory: &'a MmapMut,
436}
437
438impl<'a> MemoryView<'a> {
439 pub const fn new(memory: &'a MmapMut) -> Self {
440 Self { memory }
441 }
442
443 /// Read a word from the memory at the address.
444 ///
445 /// # Panics
446 ///
447 /// Panics if the address is not aligned to 8 bytes.
448 pub fn get(&self, addr: u64) -> MemValue {
449 assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
450
451 let word_address = addr / 8;
452 let entry_ptr = self.memory.as_ptr() as *mut MemValue;
453
454 unsafe { std::ptr::read(entry_ptr.add(word_address as usize)) }
455 }
456}