1use crate::{debug, MemValue, RiscRegister, TraceChunkHeader};
2use memmap2::{MmapMut, MmapOptions};
3use std::{collections::VecDeque, io, os::fd::RawFd, ptr::NonNull, sync::mpsc};
4
/// Interface syscall handlers use to interact with the running guest:
/// register and memory access, host I/O streams, tracing, and
/// unconstrained-mode control.
pub trait SyscallContext {
    /// Reads the current value of register `reg`.
    fn rr(&self, reg: RiscRegister) -> u64;
    /// Reads the 64-bit word at byte address `addr`.
    fn mr(&mut self, addr: u64) -> u64;
    /// Writes `val` to the 64-bit word at byte address `addr`.
    fn mw(&mut self, addr: u64, val: u64);
    /// Reads `len` consecutive words starting at `addr`; when tracing is
    /// enabled the access is recorded and the words' stored clks refreshed.
    fn mr_slice(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Like [`SyscallContext::mr_slice`], but does not refresh the clks
    /// stored on the touched words (the access itself is still traced).
    fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Reads `len` consecutive words starting at `addr` without emitting any
    /// trace records.
    fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Writes `vals` to consecutive words starting at `addr`.
    fn mw_slice(&mut self, addr: u64, vals: &[u64]);
    /// Queue of pending input byte buffers supplied by the host.
    fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>>;
    /// Byte stream of public values emitted by the guest.
    fn public_values_stream(&mut self) -> &mut Vec<u8>;
    /// Enters unconstrained mode (state snapshot + copy-on-write memory).
    fn enter_unconstrained(&mut self) -> io::Result<()>;
    /// Leaves unconstrained mode, restoring the snapshot taken on entry.
    fn exit_unconstrained(&mut self);
    /// Records a byte hint associated with `addr` (implementations may only
    /// record it while tracing).
    fn trace_hint(&mut self, addr: u64, value: Vec<u8>);
    /// Records a single raw value into the trace.
    fn trace_value(&mut self, value: u64);
    /// Writes `val` at `addr` without producing a trace record.
    fn mw_hint(&mut self, addr: u64, val: u64);
    /// Advances the memory clock by one.
    fn bump_memory_clk(&mut self);
    /// Sets the guest's exit code.
    fn set_exit_code(&mut self, exit_code: u32);
    /// Whether execution is currently in unconstrained mode.
    fn is_unconstrained(&self) -> bool;
    /// Current value of the global clock.
    fn global_clk(&self) -> u64;

    /// Starts a cycle-tracker region named `name`.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_start(&mut self, name: &str) -> u32;

    /// Ends the cycle-tracker region named `name`.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_end(&mut self, name: &str) -> Option<(u64, u32)>;

    /// Ends the cycle-tracker report region named `name`.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_report_end(&mut self, name: &str) -> Option<(u64, u32)>;
}
64
// JIT-backed implementation of [`SyscallContext`]. Scalar accesses delegate
// to `ContextMemory`; the slice accessors are open-coded (mirroring
// `ContextMemory`'s slice methods) so the returned iterators can borrow from
// `self` rather than from a temporary accessor.
impl SyscallContext for JitContext {
    #[inline]
    fn bump_memory_clk(&mut self) {
        self.clk += 1;
    }

    fn rr(&self, reg: RiscRegister) -> u64 {
        self.registers[reg as usize]
    }

    fn mr(&mut self, addr: u64) -> u64 {
        // SAFETY: `self` is a live JitContext whose `memory` pointer the
        // accessor reads through; `addr` is assumed to lie in the mapping.
        unsafe { ContextMemory::new(self).mr(addr) }
    }

    fn mw(&mut self, addr: u64, val: u64) {
        // SAFETY: as in `mr` above.
        unsafe { ContextMemory::new(self).mw(addr, val) };
    }

    fn mr_slice(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        // Guest memory is laid out as an array of `MemValue` words.
        let word_address = addr / 8;

        let ptr = self.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr..addr + 8 * len` lies within the guest
        // memory mapping — callers are expected to guarantee this.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        if self.tracing() {
            unsafe {
                // Record the pre-access entries, then refresh each word's clk
                // to the current clock (the values are unchanged).
                self.trace_mem_access(slice);

                for (i, entry) in slice.iter().enumerate() {
                    let new_entry = MemValue { value: entry.value, clk: self.clk };
                    std::ptr::write(ptr.add(i), new_entry)
                }
            }
        }

        slice.iter().map(|val| &val.value)
    }

    fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.memory.as_ptr() as *mut MemValue;
        // SAFETY: see `mr_slice`; this variant never touches the trace.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        slice.iter().map(|val| &val.value)
    }

    fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.memory.as_ptr() as *mut MemValue;
        // SAFETY: see `mr_slice`.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        // Unlike `mr_slice`, the stored clks are NOT refreshed here; only
        // the access itself is recorded.
        if self.tracing() {
            unsafe {
                self.trace_mem_access(slice);
            }
        }

        slice.iter().map(|val| &val.value)
    }

    fn mw_slice(&mut self, addr: u64, vals: &[u64]) {
        // SAFETY: as in `mr` above.
        unsafe { ContextMemory::new(self).mw_slice(addr, vals) };
    }

    fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
        // SAFETY: the context owns a valid pointer to the input queue, and
        // `&mut self` guarantees exclusive access for the returned borrow.
        unsafe { self.input_buffer() }
    }

    fn public_values_stream(&mut self) -> &mut Vec<u8> {
        // SAFETY: same argument as `input_buffer`.
        unsafe { self.public_values_stream() }
    }

    fn enter_unconstrained(&mut self) -> io::Result<()> {
        self.enter_unconstrained()
    }

    fn exit_unconstrained(&mut self) {
        self.exit_unconstrained()
    }

    fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
        // Hints are only recorded while tracing.
        if self.tracing {
            // SAFETY: `hints` points to a live Vec owned by the runtime.
            unsafe { self.trace_hint(addr, value) };
        }
    }

    fn trace_value(&mut self, value: u64) {
        if self.tracing {
            unsafe {
                // NOTE(review): clk == u64::MAX presumably tags this entry as
                // a pure value record rather than a real memory access —
                // confirm with the trace consumer.
                self.trace_mem_access(&[MemValue { clk: u64::MAX, value }]);
            }
        }
    }

    fn mw_hint(&mut self, addr: u64, val: u64) {
        // SAFETY: as in `mr` above.
        unsafe { ContextMemory::new(self).mw_hint(addr, val) };
    }

    fn set_exit_code(&mut self, exit_code: u32) {
        self.exit_code = exit_code;
    }

    fn is_unconstrained(&self) -> bool {
        self.is_unconstrained == 1
    }

    fn global_clk(&self) -> u64 {
        self.global_clk
    }

    // Cycle tracking is not implemented for the JIT context; these are no-op
    // stubs kept so the `profiling` feature still compiles.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_start(&mut self, _name: &str) -> u32 {
        0
    }

    #[cfg(feature = "profiling")]
    fn cycle_tracker_end(&mut self, _name: &str) -> Option<(u64, u32)> {
        None
    }

    #[cfg(feature = "profiling")]
    fn cycle_tracker_report_end(&mut self, _name: &str) -> Option<(u64, u32)> {
        None
    }
}
222
/// Execution context shared between the JIT-compiled guest code and the Rust
/// runtime.
///
/// NOTE(review): `#[repr(C)]` suggests field offsets are relied upon by the
/// generated code — confirm offsets before reordering or resizing fields.
#[repr(C)]
#[derive(Debug)]
pub struct JitContext {
    /// Current program counter.
    pub pc: u64,
    /// Clock stamped onto memory entries as they are read or written.
    pub clk: u64,
    /// Global clock counter.
    pub global_clk: u64,
    /// 1 while executing in unconstrained mode, 0 otherwise (u64 for the
    /// benefit of the C-layout consumers).
    pub is_unconstrained: u64,
    /// Table of pointers into compiled code (presumably indexed by pc —
    /// confirm with the JIT emitter).
    pub(crate) jump_table: NonNull<*const u8>,
    /// Base of guest memory; accessed as an array of `MemValue` words.
    pub(crate) memory: NonNull<u8>,
    /// Current trace chunk, beginning with a `TraceChunkHeader`.
    pub(crate) trace_buf: NonNull<u8>,
    /// The 32 RISC-V general-purpose registers.
    pub(crate) registers: [u64; 32],
    /// Pending host-supplied input byte buffers.
    pub(crate) input_buffer: NonNull<VecDeque<Vec<u8>>>,
    /// Public values emitted by the guest.
    pub(crate) public_values_stream: NonNull<Vec<u8>>,
    /// Recorded `(address, bytes)` hints.
    pub(crate) hints: NonNull<Vec<(u64, Vec<u8>)>>,
    /// File descriptor backing guest memory; used to build the CoW mapping
    /// for unconstrained mode.
    pub(crate) memory_fd: RawFd,
    /// Saved state while inside an unconstrained section; `None` otherwise.
    pub(crate) maybe_unconstrained: Option<UnconstrainedCtx>,
    /// Whether memory accesses are recorded into `trace_buf`.
    pub(crate) tracing: bool,
    /// Channel for shipping debug states to an observer, when enabled.
    pub(crate) debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
    /// Exit code reported by the guest.
    pub(crate) exit_code: u32,
}
261
262impl JitContext {
263 pub unsafe fn trace_mem_access(&self, reads: &[MemValue]) {
266 let raw = self.trace_buf.as_ptr();
271 let num_reads_offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
272 let num_reads_ptr = raw.add(num_reads_offset);
273 let num_reads = std::ptr::read_unaligned(num_reads_ptr as *mut u64);
274
275 let new_num_reads = num_reads + reads.len() as u64;
277 std::ptr::write_unaligned(num_reads_ptr as *mut u64, new_num_reads);
278
279 let reads_start = std::mem::size_of::<TraceChunkHeader>();
281 let tail_ptr = raw.add(reads_start) as *mut MemValue;
282 let tail_ptr = tail_ptr.add(num_reads as usize);
283
284 for (i, read) in reads.iter().enumerate() {
285 std::ptr::write(tail_ptr.add(i), *read);
286 }
287 }
288
289 pub fn enter_unconstrained(&mut self) -> io::Result<()> {
292 let mut cow_memory =
295 unsafe { MmapOptions::new().no_reserve_swap().map_copy(self.memory_fd)? };
296 let cow_memory_ptr = cow_memory.as_mut_ptr();
297
298 let align_offset = cow_memory_ptr.align_offset(std::mem::align_of::<u64>());
301 let cow_memory_ptr = unsafe { cow_memory_ptr.add(align_offset) };
302
303 self.maybe_unconstrained = Some(UnconstrainedCtx {
305 cow_memory,
306 actual_memory_ptr: self.memory,
307 pc: self.pc,
308 clk: self.clk,
309 global_clk: self.global_clk,
310 registers: self.registers,
311 });
312
313 self.pc = self.pc.wrapping_add(4);
315
316 self.memory = unsafe { NonNull::new_unchecked(cow_memory_ptr) };
320
321 self.is_unconstrained = 1;
323
324 Ok(())
325 }
326
327 pub fn exit_unconstrained(&mut self) {
329 let unconstrained = std::mem::take(&mut self.maybe_unconstrained)
330 .expect("Exit unconstrained called but not context is present, this is a bug.");
331
332 self.memory = unconstrained.actual_memory_ptr;
333 self.pc = unconstrained.pc;
334 self.registers = unconstrained.registers;
335 self.clk = unconstrained.clk;
336 self.is_unconstrained = 0;
337 }
338
339 pub unsafe fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
347 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
348 self.hints.as_mut().push((addr, value));
349 }
350
351 pub const fn memory(&mut self) -> ContextMemory<'_> {
353 unsafe { ContextMemory::new(self) }
354 }
355
356 pub const unsafe fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
359 self.input_buffer.as_mut()
360 }
361
362 pub const unsafe fn public_values_stream(&mut self) -> &mut Vec<u8> {
365 self.public_values_stream.as_mut()
366 }
367
368 pub const fn registers(&self) -> &[u64; 32] {
370 &self.registers
371 }
372
373 pub const fn rw(&mut self, reg: RiscRegister, val: u64) {
374 self.registers[reg as usize] = val;
375 }
376
377 pub const fn rr(&self, reg: RiscRegister) -> u64 {
378 self.registers[reg as usize]
379 }
380
381 #[inline]
382 pub const fn tracing(&self) -> bool {
383 self.tracing
384 }
385}
386
/// Snapshot of machine state taken on entering unconstrained mode; dropped
/// on exit, which also discards the copy-on-write memory.
#[derive(Debug)]
pub struct UnconstrainedCtx {
    /// CoW mapping backing guest memory while unconstrained; kept alive here
    /// so the raw pointer installed in `JitContext::memory` stays valid.
    pub cow_memory: MmapMut,
    /// The real guest-memory pointer to restore on exit.
    pub actual_memory_ptr: NonNull<u8>,
    /// Saved program counter.
    pub pc: u64,
    /// Saved memory clock.
    pub clk: u64,
    /// Saved global clock. NOTE(review): saved but not restored by
    /// `exit_unconstrained` — confirm intent.
    pub global_clk: u64,
    /// Saved register file.
    pub registers: [u64; 32],
}
403
/// Borrowed view over a [`JitContext`] providing word-granular guest-memory
/// reads and writes, with optional tracing.
pub struct ContextMemory<'a> {
    // Exclusive borrow of the owning context; all accesses go through its
    // raw `memory` pointer.
    ctx: &'a mut JitContext,
}
410
impl<'a> ContextMemory<'a> {
    /// Wraps `ctx` for guest-memory access.
    ///
    /// # Safety
    ///
    /// `ctx.memory` must point to a live mapping laid out as an array of
    /// `MemValue` words covering every address later accessed through the
    /// returned handle.
    const unsafe fn new(ctx: &'a mut JitContext) -> Self {
        Self { ctx }
    }

    /// Whether accesses are currently recorded into the trace.
    #[inline]
    pub const fn tracing(&self) -> bool {
        self.ctx.tracing()
    }

    /// Reads the word at byte address `addr`. When tracing, the pre-access
    /// entry is recorded and the word's stored clk refreshed to the current
    /// clock. (Takes `&self`; the clk update happens through the raw mapping
    /// pointer.)
    pub fn mr(&self, addr: u64) -> u64 {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
        let word_address = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr` lies within the guest memory mapping.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let entry = unsafe { std::ptr::read(ptr) };

        if self.tracing() {
            unsafe {
                self.ctx.trace_mem_access(&[entry]);

                // Same value, refreshed clk.
                let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
                std::ptr::write(ptr, new_entry);
            }
        }

        entry.value
    }

    /// Writes `val` at byte address `addr`, stamped with the current clock.
    /// When tracing, both the previous entry and the new one are recorded.
    pub fn mw(&mut self, addr: u64, val: u64) {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr` lies within the guest memory mapping.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let value = MemValue { value: val, clk: self.ctx.clk };

        if self.tracing() {
            unsafe {
                // Record the (old, new) pair for this word.
                let current_entry = std::ptr::read(ptr);
                self.ctx.trace_mem_access(&[current_entry, value]);
            }
        }

        unsafe { std::ptr::write(ptr, value) };
    }

    /// Reads `len` words starting at `addr`. When tracing, the pre-access
    /// entries are recorded and each word's clk refreshed to the current
    /// clock.
    pub fn mr_slice(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr..addr + 8 * len` lies within the mapping.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        if self.tracing() {
            unsafe {
                self.ctx.trace_mem_access(slice);

                // Refresh clks in place; values are unchanged.
                for (i, entry) in slice.iter().enumerate() {
                    let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
                    std::ptr::write(ptr.add(i), new_entry)
                }
            }
        }

        slice.iter().map(|val| &val.value)
    }

    /// Like [`ContextMemory::mr_slice`], but the stored clks are NOT
    /// refreshed; only the access itself is traced.
    pub fn mr_slice_unsafe(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr..addr + 8 * len` lies within the mapping.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        if self.tracing() {
            unsafe {
                self.ctx.trace_mem_access(slice);
            }
        }

        slice.iter().map(|val| &val.value)
    }

    /// Writes `vals` to consecutive words starting at `addr`, each stamped
    /// with the current clock. When tracing, an (old, new) entry pair is
    /// recorded per word before the write.
    pub fn mw_slice(&mut self, addr: u64, vals: &[u64]) {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr..addr + 8 * vals.len()` lies within the
        // mapping.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let values = vals.iter().map(|val| MemValue { value: *val, clk: self.ctx.clk });

        if self.tracing() {
            unsafe {
                let current_entries = std::slice::from_raw_parts(ptr, vals.len());

                for (curr, new) in current_entries.iter().zip(values.clone()) {
                    self.ctx.trace_mem_access(&[*curr, new]);
                }
            }
        }

        for (i, val) in values.enumerate() {
            unsafe { std::ptr::write(ptr.add(i), val) };
        }
    }

    /// Reads `len` words starting at `addr` without touching the trace.
    pub fn mr_slice_no_trace(
        &self,
        addr: u64,
        len: usize,
    ) -> impl IntoIterator<Item = &u64> + Clone {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        let word_address = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr..addr + 8 * len` lies within the mapping.
        let ptr = unsafe { ptr.add(word_address as usize) };

        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        slice.iter().map(|val| &val.value)
    }

    /// Writes `val` at `addr` with clk 0 and no trace record (hint memory).
    ///
    /// NOTE(review): unlike the other accessors this does not debug-assert
    /// 8-byte alignment — confirm whether unaligned hint addresses are
    /// expected before adding the assert.
    pub fn mw_hint(&mut self, addr: u64, val: u64) {
        let words = addr / 8;

        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
        // SAFETY: assumes `addr` lies within the guest memory mapping.
        let ptr = unsafe { ptr.add(words as usize) };

        let new_entry = MemValue { value: val, clk: 0 };
        unsafe { std::ptr::write(ptr, new_entry) };
    }
}