Skip to main content

sp1_jit/
lib.rs

1#![cfg_attr(not(sp1_native_executor_available), allow(unused))]
2
3#[cfg(not(target_endian = "little"))]
4compile_error!("This crate is only supported on little endian targets.");
5
6pub mod backends;
7pub mod context;
8pub mod debug;
9pub mod instructions;
10mod macros;
11pub mod memory;
12pub mod risc;
13pub mod shm;
14
15use dynasmrt::ExecutableBuffer;
16use hashbrown::HashMap;
17use std::{
18    collections::VecDeque,
19    io,
20    ops::{Deref, DerefMut},
21    os::fd::AsRawFd,
22    ptr::NonNull,
23    sync::{mpsc, Arc},
24};
25
26pub use backends::*;
27pub use context::*;
28pub use instructions::*;
29pub use risc::*;
30
/// A function, callable from the generated code, that accepts a pointer to the [`JitContext`].
pub type ExternFn = extern "C" fn(*mut JitContext);

/// A handler used to service ECALLs, see [`RiscvTranspiler::register_ecall_handler`].
///
/// It accepts a pointer to the [`JitContext`] and returns a `u64`.
// NOTE(review): the meaning of the returned `u64` is not visible in this file — confirm against
// the backends before relying on it.
pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;

/// A debugging utility to inspect registers and immediates.
///
/// Used by [`RiscvTranspiler::inspect_register`] and [`RiscvTranspiler::inspect_immediate`];
/// presumably the `u64` argument is the inspected value.
pub type DebugFn = extern "C" fn(u64);
38
/// A transpiler for RISC-V instructions.
///
/// This trait is implemented for each target architecture supported by the JIT transpiler.
///
/// The transpiler is responsible for translating the RISC-V instructions into the target
/// architecture's instruction set.
///
/// This transpiler should generate an entrypoint of the form: `fn(*mut JitContext)`.
///
/// For each instruction, you will typically want to call [`RiscvTranspiler::start_instr`]
/// before transpiling the instruction. This maps a "riscv instruction index" to some physical
/// native address, as there are multiple native instructions per riscv instruction.
///
/// You will also likely want to call `bump_clk` to increment the clock counter, and `set_pc` to
/// set the PC.
///
/// # Note
/// Some instructions will directly modify the PC, such as `jal` and `jalr`, and all the branch
/// instructions. For these instructions, you would not want to call `set_pc` as it will be called
/// for you.
///
/// # Examples
///
/// ```rust,no_run,ignore
/// pub fn add_program<T: RiscvTranspiler, M: JitMemory>() {
///     let mut transpiler =
///         T::new(program_size, memory_size, max_trace_size, pc_start, pc_base, clk_bump).unwrap();
///
///     // Transpile the first instruction.
///     transpiler.start_instr();
///     transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
///     transpiler.end_instr();
///
///     // Transpile the second instruction.
///     transpiler.start_instr();
///     transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
///     transpiler.end_instr();
///
///     let mut func = transpiler.finalize::<M>().unwrap();
///
///     // Call the function. A null trace buffer pointer disables tracing.
///     unsafe { func.call(std::ptr::null_mut()) };
///
///     // Inspect `func.pc`, `func.registers`, `func.exit_code`, ...
/// }
/// ```
pub trait RiscvTranspiler:
    TraceCollector
    + ComputeInstructions
    + ControlFlowInstructions
    + MemoryInstructions
    + SystemInstructions
    + Sized
{
    /// Create a new transpiler.
    ///
    /// The program size is used for the jump-table and is not a hard limit on the size of the
    /// program. The memory size is the exact amount that will be allocated for the program.
    ///
    /// NOTE(review): `max_trace_size`, `pc_start`, `pc_base` and `clk_bump` are not described in
    /// this file; presumably they are the trace buffer capacity, the entry PC, the PC of the first
    /// instruction and the per-instruction clock increment — confirm against the backends.
    ///
    /// # Errors
    ///
    /// Returns a [`std::io::Error`] if the transpiler could not be created.
    fn new(
        program_size: usize,
        memory_size: usize,
        max_trace_size: u64,
        pc_start: u64,
        pc_base: u64,
        clk_bump: u64,
    ) -> Result<Self, std::io::Error>;

    /// Register a rust function of the form [`EcallHandler`] that will be used as the ECALL.
    fn register_ecall_handler(&mut self, handler: EcallHandler);

    /// Populates a jump table entry for the current instruction being transpiled.
    ///
    /// Effectively should create a mapping from RISCV PC -> absolute address of the instruction.
    ///
    /// This method should be called for "each pc" in the program.
    fn start_instr(&mut self);

    /// Handles the logic needed when finishing an instruction, such as bumping the clk and
    /// jumping to the branch destination.
    ///
    /// This method should be called for "each pc" in the program.
    fn end_instr(&mut self);

    /// Inspect a [`RiscRegister`] using a function pointer.
    ///
    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
    fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);

    /// Print an immediate value.
    ///
    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
    fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);

    /// Call an [`ExternFn`] from the outputted assembly.
    ///
    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
    fn call_extern_fn(&mut self, handler: ExternFn);

    /// Consumes the transpiler and returns the [`JitFunction`] wrapping the generated code.
    ///
    /// The generated entrypoint is expected to be of the form: `fn(*mut JitContext)`.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if the function could not be produced.
    fn finalize<M: JitMemory>(self) -> io::Result<JitFunction<M>>;
}
140
/// A trait that collects traces, in the form of a [`TraceChunk`].
///
/// Implementors are expected to follow the conventions described in the [`TraceChunk`]
/// documentation.
pub trait TraceCollector {
    /// Write the current state of the registers into the trace buf.
    ///
    /// For SP1 this is only called once in the beginning of a "chunk".
    fn trace_registers(&mut self);

    /// Write the value located at rs1 + imm into the trace buf.
    fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);

    /// Write the start pc of the trace chunk.
    fn trace_pc_start(&mut self);

    /// Write the start clk of the trace chunk.
    fn trace_clk_start(&mut self);

    /// Write the end clk of the trace chunk.
    fn trace_clk_end(&mut self);
}
162
/// Debugging helpers for transpilers.
///
/// A blanket implementation is provided for every [`RiscvTranspiler`].
pub trait Debuggable {
    /// Emit a call that prints the pc, clk and registers of the [`JitContext`] to stderr.
    fn print_ctx(&mut self);
}
166
/// A trait representing JIT memory.
///
/// The memory dereferences to a byte slice and exposes a raw file descriptor via [`AsRawFd`].
pub trait JitMemory: Sized + Deref<Target = [u8]> + DerefMut + AsRawFd {
    /// Create a new memory from the requested `memory_size`.
    // NOTE(review): presumably `memory_size` is in bytes — confirm against the implementors.
    fn new(memory_size: usize) -> Self;
}
171
/// A JIT memory that is also resettable.
pub trait JitResetableMemory: JitMemory {
    /// Reset the memory so it can be reused for a fresh run.
    // NOTE(review): presumably this clears the contents — confirm against the implementors.
    fn reset(&mut self);
}
176
177impl<T: RiscvTranspiler> Debuggable for T {
178    // Useful only for debugging.
179    fn print_ctx(&mut self) {
180        extern "C" fn print_ctx(ctx: *mut JitContext) {
181            let ctx = unsafe { &mut *ctx };
182            eprintln!("pc: {:x}", ctx.pc);
183            eprintln!("clk: {}", ctx.clk);
184            eprintln!("{:?}", *ctx.registers());
185        }
186
187        self.call_extern_fn(print_ctx);
188    }
189}
190
#[cfg(not(sp1_native_executor_available))]
/// Stub implementation that lets the crate compile when the `sp1_native_executor_available`
/// cfg is not set (e.g. on non-linux targets).
pub struct JitFunction<M> {
    // Ties the otherwise unused type parameter `M` to the struct.
    _marker: std::marker::PhantomData<M>,
}
196
/// A type representing a JIT compiled function.
///
/// The underlying function should be of the form `fn(*mut JitContext)`.
///
/// The struct owns the generated code and the state that is carried across calls: the public
/// `pc`, `registers`, `clk`, `global_clk`, `exit_code` and `public_value_digest` fields are
/// copied into the [`JitContext`] before each call and copied back afterwards.
#[cfg(sp1_native_executor_available)]
pub struct JitFunction<M> {
    /// Absolute addresses into `code`, one per jump table entry.
    jump_table: Vec<*const u8>,
    /// The buffer holding the generated native code.
    code: ExecutableBuffer,

    /// The initial memory image.
    initial_memory_image: Arc<HashMap<u64, u64>>,
    /// The PC the program starts at; also what `pc` is restored to on reset.
    pc_start: u64,
    /// The inputs made available to the program.
    input_buffer: VecDeque<Vec<u8>>,

    /// A stream of public values from the program (global to entire program).
    pub public_values_stream: Vec<u8>,

    /// The memory backing the program.
    pub memory: M,

    /// During execution, the hints are read by the program, and we store them here.
    /// This is effectively a mapping from start address to the value of the hint.
    pub hints: Vec<(u64, Vec<u8>)>,

    /// The current program counter.
    pub pc: u64,
    /// The current register values.
    pub registers: [u64; 32],
    /// The current clock; starts at 1.
    pub clk: u64,
    /// The current global clock; starts at 0.
    pub global_clk: u64,
    /// The exit code as last written back from the [`JitContext`].
    pub exit_code: u32,

    /// The public value digest words emitted by `COMMIT` syscalls.
    pub public_value_digest: [u32; context::PUBLIC_VALUE_DIGEST_WORDS],

    /// Optional channel for debug states; a clone is handed to the [`JitContext`] on each call.
    pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
}
231
// SAFETY (NOTE(review)): `JitFunction` is not automatically `Send` because `jump_table` holds raw
// pointers. Those pointers are derived from `code`, which the same struct owns, so they travel
// together with it. This assumes the executable buffer's address does not change when the struct
// is moved between threads — confirm against `dynasmrt::ExecutableBuffer`.
unsafe impl<M: Send> Send for JitFunction<M> {}
233
#[cfg(sp1_native_executor_available)]
impl<M: JitMemory> JitFunction<M> {
    /// Wrap the generated `code` into a callable function.
    ///
    /// `jump_table` holds offsets relative to the start of `code`; they are converted to absolute
    /// addresses here. A fresh memory of `memory_size` is created and execution state is
    /// initialized with `pc = pc_start`, `clk = 1` and everything else zeroed or empty.
    pub(crate) fn new(
        code: ExecutableBuffer,
        jump_table: Vec<usize>,
        memory_size: usize,
        pc_start: u64,
    ) -> std::io::Result<Self> {
        // Adjust the jump table to be absolute addresses.
        let buf_ptr = code.as_ptr();
        // SAFETY: NOTE(review) assumes every offset produced by the backend lies within `code` —
        // confirm against the backends.
        let jump_table =
            jump_table.into_iter().map(|offset| unsafe { buf_ptr.add(offset) }).collect();

        let memory = M::new(memory_size);

        Ok(Self {
            jump_table,
            code,
            memory,
            pc: pc_start,
            clk: 1,
            global_clk: 0,
            registers: [0; 32],
            initial_memory_image: Arc::new(HashMap::new()),
            pc_start,
            input_buffer: VecDeque::new(),
            hints: Vec::new(),
            public_values_stream: Vec::new(),
            debug_sender: None,
            exit_code: 0,
            public_value_digest: [0; context::PUBLIC_VALUE_DIGEST_WORDS],
        })
    }

    /// Write the initial memory image to the JIT memory.
    ///
    /// The image is also retained so that it can be written again later.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC, or if an address of the image does not fit in the
    /// JIT memory.
    pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
        assert!(
            self.pc == self.pc_start,
            "The initial memory should only be supplied before using the JIT function."
        );

        self.initial_memory_image = memory;
        self.insert_memory_image();
    }

    /// Push an input to the input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC.
    pub fn push_input(&mut self, input: Vec<u8>) {
        assert!(
            self.pc == self.pc_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        self.input_buffer.push_back(input);

        // Reserve the space for the hints, one per queued input, matching `set_input_buffer`.
        // `reserve` is relative to `hints.len()`, so `reserve(1)` would never grow with the
        // number of pushed inputs.
        self.hints.reserve(self.input_buffer.len());
    }

    /// Set the entire input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC.
    pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
        assert!(
            self.pc == self.pc_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        // Reserve the space for the hints.
        self.hints.reserve(input.len());
        self.input_buffer = input;
    }

    /// Call the function, resuming from the current PC (initially the starting PC of the program).
    ///
    /// Traces are written into `trace_buf_ptr`; passing a null pointer disables tracing. When the
    /// call returns, `pc`, `registers`, `clk`, `global_clk`, `exit_code` and
    /// `public_value_digest` are updated from the context the generated code ran with.
    ///
    /// If the PC is 1 the function returns immediately without running anything.
    ///
    /// # Safety
    ///
    /// Relies on the builder having emitted valid assembly, and on `trace_buf_ptr` being either
    /// null or valid for the duration of the function call.
    pub unsafe fn call(&mut self, trace_buf_ptr: *mut u8) {
        // NOTE(review): presumably a PC of 1 is the "program has completed" sentinel; earlier docs
        // mentioned 0 — confirm against the backends.
        if self.pc == 1 {
            return;
        }

        // NOTE(review): this transmutes to a Rust-ABI `fn`, whose calling convention is
        // unspecified; confirm the emitted entrypoint matches, or consider `extern "C"`.
        let as_fn = std::mem::transmute::<*const u8, fn(*mut JitContext)>(self.code.as_ptr());

        // Ensure the memory pointer is aligned to the alignment of the u64.
        let align_offset = self.memory.as_ptr().align_offset(std::mem::align_of::<u64>());
        let mem_ptr = self.memory.as_mut_ptr().add(align_offset);
        let is_tracing = !trace_buf_ptr.is_null();

        // SAFETY:
        // - The jump table is valid for the duration of the function call, its owned by self.
        // - The memory is valid for the duration of the function call, its owned by self.
        // - The trace buf is supplied by the caller, who guarantees it is null or valid for the
        //   duration of the function call.
        // - The input buffer, hints and public values stream are valid for the duration of the
        //   function call, they are owned by self.
        let mut ctx = JitContext {
            jump_table: NonNull::new_unchecked(self.jump_table.as_mut_ptr()),
            memory: NonNull::new_unchecked(mem_ptr),
            trace_buf: trace_buf_ptr,
            input_buffer: NonNull::new_unchecked(&mut self.input_buffer),
            hints: NonNull::new_unchecked(&mut self.hints),
            maybe_unconstrained: None,
            public_values_stream: NonNull::new_unchecked(&mut self.public_values_stream),
            memory_fd: self.memory.as_raw_fd(),
            registers: self.registers,
            pc: self.pc,
            clk: self.clk,
            global_clk: self.global_clk,
            is_unconstrained: 0,
            tracing: is_tracing,
            debug_sender: self.debug_sender.clone(),
            exit_code: self.exit_code,
            public_value_digest: self.public_value_digest,
        };

        tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
            as_fn(&mut ctx);
        });

        // Update the values we want to preserve.
        self.pc = ctx.pc;
        self.registers = ctx.registers;
        self.clk = ctx.clk;
        self.global_clk = ctx.global_clk;
        self.exit_code = ctx.exit_code;
        self.public_value_digest = ctx.public_value_digest;
    }

    /// Copy every word of the initial memory image into the JIT memory.
    ///
    /// A word with address `addr` is stored as 8 little endian bytes at byte offset
    /// `2 * addr + 8` of the memory.
    ///
    /// # Panics
    ///
    /// Panics if an address does not fit in the JIT memory and, in debug builds, if an address is
    /// not aligned to 8.
    fn insert_memory_image(&mut self) {
        // `usize` -> `u64` is lossless on every supported target, so the bounds check below can be
        // done entirely in `u64` without truncating casts.
        let mem_len = self.memory.len() as u64;

        for (addr, val) in self.initial_memory_image.iter() {
            debug_assert!(addr % 8 == 0, "Address {addr} is not aligned to 8");

            // Technically, this crate is only used on little endian targets, but just to be sure.
            let bytes = val.to_le_bytes();

            // Compute `[2 * addr + 8, 2 * addr + 8 + 8)` with overflow checks and make sure it
            // lies inside the memory, instead of writing through an unchecked raw pointer.
            let span = addr
                .checked_mul(2)
                .and_then(|doubled| doubled.checked_add(8))
                .and_then(|start| start.checked_add(bytes.len() as u64).map(|end| (start, end)))
                .filter(|&(_, end)| end <= mem_len);

            let (start, end) = match span {
                Some(span) => span,
                None => panic!("Address {addr} does not fit in the JIT memory"),
            };

            // The casts cannot truncate: `end <= mem_len`, which itself came from a `usize`.
            self.memory[start as usize..end as usize].copy_from_slice(&bytes);
        }
    }
}
394
#[cfg(sp1_native_executor_available)]
impl<M: JitResetableMemory> JitFunction<M> {
    /// Reset the JIT function to the initial state.
    ///
    /// This will clear the registers, the program counter, the clocks, the exit code, the input
    /// buffer, the hints, the public values (stream and digest) and the memory, restoring the
    /// initial memory image. The generated code and the debug sender are left untouched.
    pub fn reset(&mut self) {
        self.pc = self.pc_start;
        self.registers = [0; 32];
        self.clk = 1;
        self.global_clk = 0;
        // Without this a stale exit code from the previous run would be copied into the next
        // `JitContext`.
        self.exit_code = 0;
        self.input_buffer = VecDeque::new();
        self.hints = Vec::new();
        self.public_values_stream = Vec::new();
        self.public_value_digest = [0; context::PUBLIC_VALUE_DIGEST_WORDS];
        self.memory.reset();

        self.insert_memory_image();
    }
}