// sp1_jit/lib.rs

1#![cfg_attr(not(target_os = "linux"), allow(unused))]
2
3#[cfg(not(target_endian = "little"))]
4compile_error!("This crate is only supported on little endian targets.");
5
6pub mod backends;
7pub mod context;
8pub mod debug;
9pub mod instructions;
10mod macros;
11pub mod risc;
12
13use dynasmrt::ExecutableBuffer;
14use hashbrown::HashMap;
15use memmap2::{MmapMut, MmapOptions};
16use std::{
17    collections::VecDeque,
18    io,
19    os::fd::AsRawFd,
20    ptr::NonNull,
21    sync::{mpsc, Arc},
22};
23
24pub use backends::*;
25pub use context::*;
26pub use instructions::*;
27pub use risc::*;
28
29/// A function that accepts the memory pointer.
30pub type ExternFn = extern "C" fn(*mut JitContext);
31
32pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;
33
34/// A debugging utility to inspect registers
35pub type DebugFn = extern "C" fn(u64);
36
37/// A transpiler for risc32 instructions.
38///
39/// This trait is implemented for each target architecture supported by the JIT transpiler.
40///
41/// The transpiler is responsible for translating the risc32 instructions into the target
42/// architecture's instruction set.
43///
44/// This transpiler should generate an entrypoint of the form: [`fn(*mut JitContext)`]
45///
46/// For each instruction, you will typically want to call [`SP1RiscvTranspiler::start_instr`]
47/// before transpiling the instruction. This maps a "riscv instruction index" to some physical
48/// native address, as there are multiple native instructions per riscv instruction.
49///
50/// You will also likely want to call [`SP1RiscvTranspiler::bump_clk`] to increment the clock
51/// counter, and [`SP1RiscvTranspiler::set_pc`] to set the PC.
52///
53/// # Note
54/// Some instructions will directly modify the PC, such as [`SP1RiscvTranspiler::jal`] and
55/// [`SP1RiscvTranspiler::jalr`], and all the branch instructions, for these instructions, you would
56/// not want to call [`SP1RiscvTranspiler::set_pc`] as it will be called for you.
57///
58///
59/// ```rust,no_run,ignore
60/// pub fn add_program() {
61///     let mut transpiler = SP1RiscvTranspiler::new(program_size, memory_size, trace_buf_size, 100, 100).unwrap();
62///      
63///     // Transpile the first instruction.
64///     transpiler.start_instr();
65///     transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
66///     transpiler.end_instr();
67///     
68///     // Transpile the second instruction.
69///     transpiler.start_instr();
70///
71///     transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
72///     transpiler.end_instr();
73///     
74///     let mut func = transpiler.finalize();
75///
76///     // Call the function.
77///     let traces = func.call();
78///
79///     // do stuff with the traces.
80/// }
81/// ```
82pub trait RiscvTranspiler:
83    TraceCollector
84    + ComputeInstructions
85    + ControlFlowInstructions
86    + MemoryInstructions
87    + SystemInstructions
88    + Sized
89{
90    /// Create a new transpiler.
91    ///
92    /// The program is used for the jump-table and is not a hard limit on the size of the program.
93    /// The memory size is the exact amount that will be allocated for the program.
94    fn new(
95        program_size: usize,
96        memory_size: usize,
97        max_trace_size: u64,
98        pc_start: u64,
99        pc_base: u64,
100        clk_bump: u64,
101    ) -> Result<Self, std::io::Error>;
102
103    /// Register a rust function of the form [`EcallHandler`] that will be used as the ECALL.
104    fn register_ecall_handler(&mut self, handler: EcallHandler);
105
106    /// Populates a jump table entry for the current instruction being transpiled.
107    ///
108    /// Effectively should create a mapping from RISCV PC -> absolute address of the instruction.
109    ///
110    /// This method should be called for "each pc" in the program.
111    fn start_instr(&mut self);
112
113    /// This method should be called for "each pc" in the program.
114    /// Handle logics when finishing execution of an instruction such as bumping clk and jump to
115    /// branch destination.
116    fn end_instr(&mut self);
117
118    /// Inspcet a [RiscRegister] using a function pointer.
119    ///
120    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
121    fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);
122
123    /// Print an immediate value.
124    ///
125    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
126    fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);
127
128    /// Call an [ExternFn] from the outputted assembly.
129    ///
130    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
131    fn call_extern_fn(&mut self, handler: ExternFn);
132
133    /// Returns the function pointer to the generated code.
134    ///
135    /// This function is expected to be of the form: `fn(*mut JitContext)`.
136    fn finalize(self) -> io::Result<JitFunction>;
137}
138
139/// A trait the collects traces, in the form [TraceChunk].
140///
141/// This type is expected to follow the conventions as described in the [TraceChunk] documentation.
142pub trait TraceCollector {
143    /// Write the current state of the registers into the trace buf.
144    ///
145    /// For SP1 this is only called once in the beginning of a "chunk".
146    fn trace_registers(&mut self);
147
148    /// Write the value located at rs1 + imm into the trace buf.
149    fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);
150
151    /// Write the start pc of the trace chunk.
152    fn trace_pc_start(&mut self);
153
154    /// Write the start clk of the trace chunk.
155    fn trace_clk_start(&mut self);
156
157    /// Write the end clk of the trace chunk.
158    fn trace_clk_end(&mut self);
159}
160
/// Debugging helpers that can be injected into the code being transpiled.
pub trait Debuggable {
    /// Emit a call that prints the JIT context.
    ///
    /// The blanket implementation for transpilers prints the pc, the clk and the registers
    /// to stderr.
    fn print_ctx(&mut self);
}
164
165impl<T: RiscvTranspiler> Debuggable for T {
166    // Useful only for debugging.
167    fn print_ctx(&mut self) {
168        extern "C" fn print_ctx(ctx: *mut JitContext) {
169            let ctx = unsafe { &mut *ctx };
170            eprintln!("pc: {:x}", ctx.pc);
171            eprintln!("clk: {}", ctx.clk);
172            eprintln!("{:?}", *ctx.registers());
173        }
174
175        self.call_extern_fn(print_ctx);
176    }
177}
178
/// Placeholder so that the crate still compiles on non-linux targets.
#[cfg(not(target_os = "linux"))]
pub struct JitFunction {}
182
183/// A type representing a JIT compiled function.
184///
185/// The underlying function should be of the form [`fn(*mut JitContext)`].
186#[cfg(target_os = "linux")]
187pub struct JitFunction {
188    jump_table: Vec<*const u8>,
189    trace_buf_size: usize,
190    code: ExecutableBuffer,
191
192    /// The initial memory image.
193    initial_memory_image: Arc<HashMap<u64, u64>>,
194    pc_start: u64,
195    input_buffer: VecDeque<Vec<u8>>,
196
197    /// A stream of public values from the program (global to entire program).
198    pub public_values_stream: Vec<u8>,
199
200    /// Keep around the memfd, and pass it to the JIT context,
201    /// we can use this to create the COW memory at runtime.
202    mem_fd: memfd::Memfd,
203
204    /// During execution, the hints are read by the program, and we store them here.
205    /// This is effectively a mapping from start address to the value of the hint.
206    pub hints: Vec<(u64, Vec<u8>)>,
207
208    /// The JIT function may stop "in the middle" of an program,
209    /// we want to be able to resume it, so this is the information needed to do so.
210    pub memory: MmapMut,
211    pub pc: u64,
212    pub registers: [u64; 32],
213    pub clk: u64,
214    pub global_clk: u64,
215    pub exit_code: u32,
216
217    pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
218}
219
220unsafe impl Send for JitFunction {}
221
222#[cfg(target_os = "linux")]
223impl JitFunction {
224    pub(crate) fn new(
225        code: ExecutableBuffer,
226        jump_table: Vec<usize>,
227        memory_size: usize,
228        trace_buf_size: usize,
229        pc_start: u64,
230    ) -> std::io::Result<Self> {
231        // Adjust the jump table to be absolute addresses.
232        let buf_ptr = code.as_ptr();
233        let jump_table =
234            jump_table.into_iter().map(|offset| unsafe { buf_ptr.add(offset) }).collect();
235
236        let fd = memfd::MemfdOptions::default()
237            .create(uuid::Uuid::new_v4().to_string())
238            .expect("Failed to create jit memory");
239
240        fd.as_file().set_len((memory_size + std::mem::align_of::<u64>()) as u64)?;
241
242        Ok(Self {
243            jump_table,
244            code,
245            memory: unsafe { MmapOptions::new().no_reserve_swap().map_mut(fd.as_file())? },
246            mem_fd: fd,
247            trace_buf_size,
248            pc: pc_start,
249            clk: 1,
250            global_clk: 0,
251            registers: [0; 32],
252            initial_memory_image: Arc::new(HashMap::new()),
253            pc_start,
254            input_buffer: VecDeque::new(),
255            hints: Vec::new(),
256            public_values_stream: Vec::new(),
257            debug_sender: None,
258            exit_code: 0,
259        })
260    }
261
262    /// Write the initial memory image to the JIT memory.
263    ///
264    /// # Panics
265    ///
266    /// Panics if the PC is not the starting PC.
267    pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
268        assert!(
269            self.pc == self.pc_start,
270            "The initial memory should only be supplied before using the JIT function."
271        );
272
273        self.initial_memory_image = memory;
274        self.insert_memory_image();
275    }
276
277    /// Push an input to the input buffer.
278    ///
279    /// # Panics
280    ///
281    /// Panics if the PC is not the starting PC.
282    pub fn push_input(&mut self, input: Vec<u8>) {
283        assert!(
284            self.pc == self.pc_start,
285            "The input buffer should only be supplied before using the JIT function."
286        );
287
288        self.input_buffer.push_back(input);
289
290        self.hints.reserve(1);
291    }
292
293    /// Set the entire input buffer.
294    ///
295    /// # Panics
296    ///
297    /// Panics if the PC is not the starting PC.
298    pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
299        assert!(
300            self.pc == self.pc_start,
301            "The input buffer should only be supplied before using the JIT function."
302        );
303
304        // Reserve the space for the hints.
305        self.hints.reserve(input.len());
306        self.input_buffer = input;
307    }
308
309    /// Call the function, returning the trace buffer, starting at the starting PC of the program.
310    ///
311    /// If the PC is 0, then the program has completed and we return None.
312    ///
313    /// # SAFETY
314    /// Relies on the builder to emit valid assembly
315    /// and that the pointer is valid for the duration of the function call.
316    pub unsafe fn call(&mut self) -> Option<TraceChunkRaw> {
317        if self.pc == 1 {
318            return None;
319        }
320
321        let as_fn = std::mem::transmute::<*const u8, fn(*mut JitContext)>(self.code.as_ptr());
322
323        // Ensure the pointer is aligned to the alignment of the MemValue.
324        let mut trace_buf =
325            MmapMut::map_anon(self.trace_buf_size + std::mem::align_of::<MemValue>())
326                .expect("Failed to create trace buf mmap");
327        let trace_buf_offset = trace_buf.as_ptr().align_offset(std::mem::align_of::<MemValue>());
328        let trace_buf_ptr = trace_buf.as_mut_ptr().add(trace_buf_offset);
329
330        // Ensure the memory pointer is aligned to the alignment of the u64.
331        let align_offset = self.memory.as_ptr().align_offset(std::mem::align_of::<u64>());
332        let mem_ptr = self.memory.as_mut_ptr().add(align_offset);
333        let tracing = self.trace_buf_size > 0;
334
335        // SAFETY:
336        // - The jump table is valid for the duration of the function call, its owned by self.
337        // - The memory is valid for the duration of the function call, its owned by self.
338        // - The trace buf is valid for the duration of the function call, we just allocated it
339        // - The input buffer is valid for the duration of the function call, its owned by self.
340        let mut ctx = JitContext {
341            jump_table: NonNull::new_unchecked(self.jump_table.as_mut_ptr()),
342            memory: NonNull::new_unchecked(mem_ptr),
343            trace_buf: NonNull::new_unchecked(trace_buf_ptr),
344            input_buffer: NonNull::new_unchecked(&mut self.input_buffer),
345            hints: NonNull::new_unchecked(&mut self.hints),
346            maybe_unconstrained: None,
347            public_values_stream: NonNull::new_unchecked(&mut self.public_values_stream),
348            memory_fd: self.mem_fd.as_raw_fd(),
349            registers: self.registers,
350            pc: self.pc,
351            clk: self.clk,
352            global_clk: self.global_clk,
353            is_unconstrained: 0,
354            tracing,
355            debug_sender: self.debug_sender.clone(),
356            exit_code: self.exit_code,
357        };
358
359        tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
360            as_fn(&mut ctx);
361        });
362
363        // Update the values we want to preserve.
364        self.pc = ctx.pc;
365        self.registers = ctx.registers;
366        self.clk = ctx.clk;
367        self.global_clk = ctx.global_clk;
368        self.exit_code = ctx.exit_code;
369
370        tracing.then_some(TraceChunkRaw::new(
371            trace_buf.make_read_only().expect("Failed to make trace buf read only"),
372        ))
373    }
374
375    /// Reset the JIT function to the initial state.
376    ///
377    /// This will clear the registers, the program counter, the clock, and the memory, restoring the
378    /// initial memory image.
379    pub fn reset(&mut self) {
380        self.pc = self.pc_start;
381        self.registers = [0; 32];
382        self.clk = 1;
383        self.global_clk = 0;
384        self.input_buffer = VecDeque::new();
385        self.hints = Vec::new();
386        self.public_values_stream = Vec::new();
387
388        // Store the original size of the memory.
389        let memory_size = self.memory.len();
390
391        // Create a new memfd for the backing memory.
392        self.mem_fd = memfd::MemfdOptions::default()
393            .create(uuid::Uuid::new_v4().to_string())
394            .expect("Failed to create jit memory");
395
396        self.mem_fd
397            .as_file()
398            .set_len(memory_size as u64)
399            .expect("Failed to set memfd size for backing memory.");
400
401        self.memory = unsafe {
402            MmapOptions::new()
403                .no_reserve_swap()
404                .map_mut(self.mem_fd.as_file())
405                .expect("Failed to map memory")
406        };
407
408        self.insert_memory_image();
409    }
410
411    fn insert_memory_image(&mut self) {
412        for (addr, val) in self.initial_memory_image.iter() {
413            // Technically, this crate is probably only used on little endian targets, but just to
414            // sure.
415            let bytes = val.to_le_bytes();
416
417            #[cfg(debug_assertions)]
418            if addr % 8 > 0 {
419                panic!("Address {addr} is not aligned to 8");
420            }
421
422            let actual_addr = 2 * addr + 8;
423            unsafe {
424                std::ptr::copy_nonoverlapping(
425                    bytes.as_ptr(),
426                    self.memory.as_mut_ptr().add(actual_addr as usize),
427                    bytes.len(),
428                )
429            };
430        }
431    }
432}
433
434pub struct MemoryView<'a> {
435    pub memory: &'a MmapMut,
436}
437
438impl<'a> MemoryView<'a> {
439    pub const fn new(memory: &'a MmapMut) -> Self {
440        Self { memory }
441    }
442
443    /// Read a word from the memory at the address.
444    ///
445    /// # Panics
446    ///
447    /// Panics if the address is not aligned to 8 bytes.
448    pub fn get(&self, addr: u64) -> MemValue {
449        assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
450
451        let word_address = addr / 8;
452        let entry_ptr = self.memory.as_ptr() as *mut MemValue;
453
454        unsafe { std::ptr::read(entry_ptr.add(word_address as usize)) }
455    }
456}