1use crate::{debug, ElfInfo, Interrupt, MemValue, PageProtValue, RiscRegister, TraceChunkHeader};
2use memmap2::{MmapMut, MmapOptions};
3use sp1_primitives::consts::{PROT_READ, PROT_WRITE};
4use std::{collections::VecDeque, io, os::fd::RawFd, ptr::NonNull, sync::mpsc};
5
6pub trait SyscallContext {
7 fn rr(&self, reg: RiscRegister) -> u64;
9 fn rw(&mut self, reg: RiscRegister, value: u64);
11 fn set_next_pc(&mut self, pc: u64);
13 fn mr_without_prot(&mut self, addr: u64) -> u64;
15 fn mw_without_prot(&mut self, addr: u64, val: u64);
17 fn mr_slice(
19 &mut self,
20 addr: u64,
21 len: usize,
22 ) -> Result<impl IntoIterator<Item = &u64>, Interrupt> {
23 self.prot_slice_check(addr, len, PROT_READ)?;
24 Ok(self.mr_slice_without_prot(addr, len))
25 }
26 fn mr_slice_without_prot(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
28 fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
31 fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
33 fn mw_slice(&mut self, addr: u64, vals: &[u64]) -> Result<(), Interrupt> {
35 self.prot_slice_check(addr, vals.len(), PROT_WRITE)?;
36 self.mw_slice_without_prot(addr, vals);
37 Ok(())
38 }
39 fn mw_slice_without_prot(&mut self, addr: u64, vals: &[u64]);
41 #[inline]
43 fn read_slice_check(&mut self, addr: u64, len: usize) -> Result<(), Interrupt> {
44 self.prot_slice_check(addr, len, PROT_READ)
45 }
46 #[inline]
48 fn write_slice_check(&mut self, addr: u64, len: usize) -> Result<(), Interrupt> {
49 self.prot_slice_check(addr, len, PROT_WRITE)
50 }
51 #[inline]
53 fn read_write_slice_check(&mut self, addr: u64, len: usize) -> Result<(), Interrupt> {
54 self.prot_slice_check(addr, len, PROT_READ | PROT_WRITE)
55 }
56 fn prot_slice_check(&mut self, addr: u64, len: usize, prot_bitmap: u8)
58 -> Result<(), Interrupt>;
59 fn page_prot_write(&mut self, addr: u64, val: u8);
61 fn page_prot_flush(&mut self) {}
63 fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>>;
65 fn public_values_stream(&mut self) -> &mut Vec<u8>;
67 fn enter_unconstrained(&mut self) -> io::Result<()>;
69 fn exit_unconstrained(&mut self);
71 fn trace_hint(&mut self, addr: u64, value: Vec<u8>);
73 fn trace_value(&mut self, value: u64);
75 fn mw_hint(&mut self, addr: u64, val: u64);
78 fn bump_memory_clk(&mut self);
82 fn get_current_clk(&self) -> u64;
84 fn set_clk(&mut self, clk: u64);
86 fn set_exit_code(&mut self, exit_code: u32);
88 fn is_unconstrained(&self) -> bool;
90 fn global_clk(&self) -> u64;
92
93 #[cfg(feature = "profiling")]
97 fn cycle_tracker_start(&mut self, name: &str) -> u32;
98
99 #[cfg(feature = "profiling")]
102 fn cycle_tracker_end(&mut self, name: &str) -> Option<(u64, u32)>;
103
104 #[cfg(feature = "profiling")]
108 fn cycle_tracker_report_end(&mut self, name: &str) -> Option<(u64, u32)>;
109
110 fn elf_info(&self) -> ElfInfo;
112 fn init_addr_iter(&self) -> impl IntoIterator<Item = u64>;
114 fn page_prot_iter(&self) -> impl IntoIterator<Item = (&u64, &PageProtValue)>;
116 fn maybe_dump_profiler_data(&self) -> (Vec<(String, u64, u64)>, Vec<u64>);
118 fn maybe_insert_profiler_symbols<I: Iterator<Item = (String, u64, u64)>>(&mut self, iter: I);
120 fn maybe_delete_profiler_symbols<I: Iterator<Item = u64>>(&mut self, iter: I);
122}
123
124impl SyscallContext for JitContext {
125 #[inline]
126 fn bump_memory_clk(&mut self) {
127 self.clk += 1;
128 }
129
130 #[inline]
131 fn get_current_clk(&self) -> u64 {
132 self.clk
133 }
134
135 #[inline]
136 fn set_clk(&mut self, clk: u64) {
137 self.clk = clk;
138 }
139
140 fn rr(&self, reg: RiscRegister) -> u64 {
141 self.registers[reg as usize]
142 }
143
144 fn rw(&mut self, _reg: RiscRegister, _value: u64) {
145 unimplemented!()
146 }
147
148 fn set_next_pc(&mut self, _pc: u64) {
149 unimplemented!()
150 }
151
152 fn mr_without_prot(&mut self, addr: u64) -> u64 {
153 unsafe { ContextMemory::new(self).mr(addr) }
154 }
155
156 fn mw_without_prot(&mut self, addr: u64, val: u64) {
157 unsafe { ContextMemory::new(self).mw(addr, val) };
158 }
159
160 fn mr_slice_without_prot(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
161 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
162
163 let word_address = addr / 8;
165
166 let ptr = self.memory.as_ptr() as *mut MemValue;
167 let ptr = unsafe { ptr.add(word_address as usize) };
168
169 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
172
173 if self.tracing() {
174 unsafe {
175 self.trace_mem_access(slice);
176
177 for (i, entry) in slice.iter().enumerate() {
179 let new_entry = MemValue { value: entry.value, clk: self.clk };
180 std::ptr::write(ptr.add(i), new_entry)
181 }
182 }
183 }
184
185 slice.iter().map(|val| &val.value)
186 }
187
188 fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
189 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
190
191 let word_address = addr / 8;
193
194 let ptr = self.memory.as_ptr() as *mut MemValue;
195 let ptr = unsafe { ptr.add(word_address as usize) };
196
197 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
200
201 slice.iter().map(|val| &val.value)
202 }
203
204 fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
205 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
206
207 let word_address = addr / 8;
209
210 let ptr = self.memory.as_ptr() as *mut MemValue;
211 let ptr = unsafe { ptr.add(word_address as usize) };
212
213 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
216
217 if self.tracing() {
218 unsafe {
219 self.trace_mem_access(slice);
220 }
221 }
222
223 slice.iter().map(|val| &val.value)
224 }
225
226 fn mw_slice_without_prot(&mut self, addr: u64, vals: &[u64]) {
227 unsafe { ContextMemory::new(self).mw_slice(addr, vals) };
228 }
229
230 fn prot_slice_check(
231 &mut self,
232 _addr: u64,
233 _len: usize,
234 _prot_bitmap: u8,
235 ) -> Result<(), Interrupt> {
236 Ok(())
239 }
240
241 fn page_prot_write(&mut self, _addr: u64, _val: u8) {
242 unimplemented!("page_prot_write not implemented for JitContext")
243 }
244
245 fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
246 unsafe { self.input_buffer() }
247 }
248
249 fn public_values_stream(&mut self) -> &mut Vec<u8> {
250 unsafe { self.public_values_stream() }
251 }
252
253 fn enter_unconstrained(&mut self) -> io::Result<()> {
254 self.enter_unconstrained()
255 }
256
257 fn exit_unconstrained(&mut self) {
258 self.exit_unconstrained()
259 }
260
261 fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
262 if self.tracing {
263 unsafe { self.trace_hint(addr, value) };
264 }
265 }
266
267 fn trace_value(&mut self, value: u64) {
268 if self.tracing {
269 unsafe {
270 self.trace_mem_access(&[MemValue { clk: u64::MAX, value }]);
273 }
274 }
275 }
276
277 fn mw_hint(&mut self, addr: u64, val: u64) {
278 unsafe { ContextMemory::new(self).mw_hint(addr, val) };
279 }
280
281 fn set_exit_code(&mut self, exit_code: u32) {
282 self.exit_code = exit_code;
283 }
284
285 fn is_unconstrained(&self) -> bool {
286 self.is_unconstrained == 1
287 }
288
289 fn global_clk(&self) -> u64 {
290 self.global_clk
291 }
292
293 #[cfg(feature = "profiling")]
294 fn cycle_tracker_start(&mut self, _name: &str) -> u32 {
295 0
298 }
299
300 #[cfg(feature = "profiling")]
301 fn cycle_tracker_end(&mut self, _name: &str) -> Option<(u64, u32)> {
302 None
305 }
306
307 #[cfg(feature = "profiling")]
308 fn cycle_tracker_report_end(&mut self, _name: &str) -> Option<(u64, u32)> {
309 None
312 }
313
314 fn elf_info(&self) -> ElfInfo {
315 unimplemented!()
316 }
317
318 fn init_addr_iter(&self) -> impl IntoIterator<Item = u64> {
319 Vec::new()
320 }
321
322 fn page_prot_iter(&self) -> impl IntoIterator<Item = (&u64, &PageProtValue)> {
323 Vec::new()
324 }
325
326 fn maybe_dump_profiler_data(&self) -> (Vec<(String, u64, u64)>, Vec<u64>) {
327 unimplemented!()
328 }
329
330 fn maybe_insert_profiler_symbols<I: Iterator<Item = (String, u64, u64)>>(&mut self, _iter: I) {
331 unimplemented!()
332 }
333
334 fn maybe_delete_profiler_symbols<I: Iterator<Item = u64>>(&mut self, _iter: I) {
335 unimplemented!()
336 }
337}
338
339#[repr(C)]
340#[derive(Debug)]
341pub struct JitContext {
342 pub pc: u64,
344 pub clk: u64,
346 pub global_clk: u64,
348 pub is_unconstrained: u64,
351 pub(crate) jump_table: NonNull<*const u8>,
353 pub(crate) memory: NonNull<u8>,
355 pub(crate) trace_buf: *mut u8,
357 pub(crate) registers: [u64; 32],
360 pub(crate) input_buffer: NonNull<VecDeque<Vec<u8>>>,
362 pub(crate) public_values_stream: NonNull<Vec<u8>>,
364 pub(crate) hints: NonNull<Vec<(u64, Vec<u8>)>>,
366 pub(crate) memory_fd: RawFd,
368 pub(crate) maybe_unconstrained: Option<UnconstrainedCtx>,
370 pub(crate) tracing: bool,
372 pub(crate) debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
374 pub(crate) exit_code: u32,
376}
377
378impl JitContext {
379 pub unsafe fn trace_mem_access(&self, reads: &[MemValue]) {
382 let raw = self.trace_buf;
387 let num_reads_offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
388 let num_reads_ptr = raw.add(num_reads_offset);
389 let num_reads = std::ptr::read_unaligned(num_reads_ptr as *mut u64);
390
391 let new_num_reads = num_reads + reads.len() as u64;
393 std::ptr::write_unaligned(num_reads_ptr as *mut u64, new_num_reads);
394
395 let reads_start = std::mem::size_of::<TraceChunkHeader>();
397 let tail_ptr = raw.add(reads_start) as *mut MemValue;
398 let tail_ptr = tail_ptr.add(num_reads as usize);
399
400 for (i, read) in reads.iter().enumerate() {
401 std::ptr::write(tail_ptr.add(i), *read);
402 }
403 }
404
405 pub fn enter_unconstrained(&mut self) -> io::Result<()> {
408 let mut cow_memory =
411 unsafe { MmapOptions::new().no_reserve_swap().map_copy(self.memory_fd)? };
412 let cow_memory_ptr = cow_memory.as_mut_ptr();
413
414 let align_offset = cow_memory_ptr.align_offset(std::mem::align_of::<u64>());
417 let cow_memory_ptr = unsafe { cow_memory_ptr.add(align_offset) };
418
419 self.maybe_unconstrained = Some(UnconstrainedCtx {
421 cow_memory,
422 actual_memory_ptr: self.memory,
423 pc: self.pc,
424 clk: self.clk,
425 global_clk: self.global_clk,
426 registers: self.registers,
427 });
428
429 self.pc = self.pc.wrapping_add(4);
431
432 self.memory = unsafe { NonNull::new_unchecked(cow_memory_ptr) };
436
437 self.is_unconstrained = 1;
439
440 Ok(())
441 }
442
443 pub fn exit_unconstrained(&mut self) {
445 let unconstrained = std::mem::take(&mut self.maybe_unconstrained)
446 .expect("Exit unconstrained called but no context is present, this is a bug.");
447
448 self.memory = unconstrained.actual_memory_ptr;
449 self.pc = unconstrained.pc;
450 self.registers = unconstrained.registers;
451 self.clk = unconstrained.clk;
452 self.is_unconstrained = 0;
453 }
454
455 pub unsafe fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
463 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
464 self.hints.as_mut().push((addr, value));
465 }
466
467 pub const fn memory(&mut self) -> ContextMemory<'_> {
469 unsafe { ContextMemory::new(self) }
470 }
471
472 pub const unsafe fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
475 self.input_buffer.as_mut()
476 }
477
478 pub const unsafe fn public_values_stream(&mut self) -> &mut Vec<u8> {
481 self.public_values_stream.as_mut()
482 }
483
484 pub const fn registers(&self) -> &[u64; 32] {
486 &self.registers
487 }
488
489 pub const fn rw(&mut self, reg: RiscRegister, val: u64) {
490 self.registers[reg as usize] = val;
491 }
492
493 pub const fn rr(&self, reg: RiscRegister) -> u64 {
494 self.registers[reg as usize]
495 }
496
497 #[inline]
498 pub const fn tracing(&self) -> bool {
499 self.tracing
500 }
501}
502
503#[derive(Debug)]
505pub struct UnconstrainedCtx {
506 pub cow_memory: MmapMut,
508 pub actual_memory_ptr: NonNull<u8>,
510 pub pc: u64,
512 pub clk: u64,
514 pub global_clk: u64,
516 pub registers: [u64; 32],
518}
519
520pub struct ContextMemory<'a> {
524 ctx: &'a mut JitContext,
525}
526
527impl<'a> ContextMemory<'a> {
528 const unsafe fn new(ctx: &'a mut JitContext) -> Self {
539 Self { ctx }
540 }
541
542 #[inline]
543 pub const fn tracing(&self) -> bool {
544 self.ctx.tracing()
545 }
546
547 pub fn mr(&self, addr: u64) -> u64 {
549 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
550 let word_address = addr / 8;
552
553 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
554 let ptr = unsafe { ptr.add(word_address as usize) };
555
556 let entry = unsafe { std::ptr::read(ptr) };
559
560 if self.tracing() {
561 unsafe {
562 self.ctx.trace_mem_access(&[entry]);
563
564 let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
566 std::ptr::write(ptr, new_entry);
567 }
568 }
569
570 entry.value
571 }
572
573 pub fn mw(&mut self, addr: u64, val: u64) {
575 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
576
577 let word_address = addr / 8;
579
580 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
581 let ptr = unsafe { ptr.add(word_address as usize) };
582
583 let value = MemValue { value: val, clk: self.ctx.clk };
585
586 if self.tracing() {
588 unsafe {
589 let current_entry = std::ptr::read(ptr);
591 self.ctx.trace_mem_access(&[current_entry, value]);
592 }
593 }
594
595 unsafe { std::ptr::write(ptr, value) };
598 }
599
600 pub fn mr_slice(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
602 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
603
604 let word_address = addr / 8;
606
607 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
608 let ptr = unsafe { ptr.add(word_address as usize) };
609
610 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
613
614 if self.tracing() {
615 unsafe {
616 self.ctx.trace_mem_access(slice);
617
618 for (i, entry) in slice.iter().enumerate() {
620 let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
621 std::ptr::write(ptr.add(i), new_entry)
622 }
623 }
624 }
625
626 slice.iter().map(|val| &val.value)
627 }
628
629 pub fn mr_slice_unsafe(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
631 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
632
633 let word_address = addr / 8;
635
636 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
637 let ptr = unsafe { ptr.add(word_address as usize) };
638
639 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
642
643 if self.tracing() {
644 unsafe {
645 self.ctx.trace_mem_access(slice);
646 }
647 }
648
649 slice.iter().map(|val| &val.value)
650 }
651
652 pub fn mw_slice(&mut self, addr: u64, vals: &[u64]) {
654 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
655
656 let word_address = addr / 8;
658
659 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
660 let ptr = unsafe { ptr.add(word_address as usize) };
661
662 let values = vals.iter().map(|val| MemValue { value: *val, clk: self.ctx.clk });
664
665 if self.tracing() {
668 unsafe {
669 let current_entries = std::slice::from_raw_parts(ptr, vals.len());
670
671 for (curr, new) in current_entries.iter().zip(values.clone()) {
672 self.ctx.trace_mem_access(&[*curr, new]);
673 }
674 }
675 }
676
677 for (i, val) in values.enumerate() {
678 unsafe { std::ptr::write(ptr.add(i), val) };
679 }
680 }
681
682 pub fn mr_slice_no_trace(
684 &self,
685 addr: u64,
686 len: usize,
687 ) -> impl IntoIterator<Item = &u64> + Clone {
688 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
689
690 let word_address = addr / 8;
692
693 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
694 let ptr = unsafe { ptr.add(word_address as usize) };
695
696 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
699
700 slice.iter().map(|val| &val.value)
701 }
702
703 pub fn mw_hint(&mut self, addr: u64, val: u64) {
705 let words = addr / 8;
706
707 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
708 let ptr = unsafe { ptr.add(words as usize) };
709
710 let new_entry = MemValue { value: val, clk: 0 };
711 unsafe { std::ptr::write(ptr, new_entry) };
712 }
713}