sp1_core_executor/syscalls/
context.rs

1use hashbrown::HashMap;
2
3use crate::{
4    events::{
5        MemoryLocalEvent, MemoryReadRecord, MemoryWriteRecord, PrecompileEvent, SyscallEvent,
6    },
7    record::ExecutionRecord,
8    Executor, ExecutorMode, Register,
9};
10
11use super::SyscallCode;
12
13/// A runtime for syscalls that is protected so that developers cannot arbitrarily modify the
14/// runtime.
15#[allow(dead_code)]
16pub struct SyscallContext<'a, 'b: 'a> {
17    /// The current shard.
18    pub current_shard: u32,
19    /// The clock cycle.
20    pub clk: u32,
21    /// The next program counter.
22    pub next_pc: u32,
23    /// The exit code.
24    pub exit_code: u32,
25    /// The runtime.
26    pub rt: &'a mut Executor<'b>,
27    /// The local memory access events for the syscall.
28    pub local_memory_access: HashMap<u32, MemoryLocalEvent>,
29}
30
31impl<'a, 'b> SyscallContext<'a, 'b> {
32    /// Create a new [`SyscallContext`].
33    pub fn new(runtime: &'a mut Executor<'b>) -> Self {
34        let current_shard = runtime.shard();
35        let clk = runtime.state.clk;
36        Self {
37            current_shard,
38            clk,
39            next_pc: runtime.state.pc.wrapping_add(4),
40            exit_code: 0,
41            rt: runtime,
42            local_memory_access: HashMap::new(),
43        }
44    }
45
46    /// Get a mutable reference to the execution record.
47    pub fn record_mut(&mut self) -> &mut ExecutionRecord {
48        &mut self.rt.record
49    }
50
51    #[inline]
52    /// Add a precompile event to the execution record.
53    pub fn add_precompile_event(
54        &mut self,
55        syscall_code: SyscallCode,
56        syscall_event: SyscallEvent,
57        event: PrecompileEvent,
58    ) {
59        if self.rt.executor_mode == ExecutorMode::Trace {
60            self.record_mut().precompile_events.add_event(syscall_code, syscall_event, event);
61        }
62    }
63
64    /// Get the current shard.
65    #[must_use]
66    pub fn current_shard(&self) -> u32 {
67        self.rt.state.current_shard
68    }
69
70    /// Read a word from memory.
71    ///
72    /// `addr` must be a pointer to main memory, not a register.
73    pub fn mr(&mut self, addr: u32) -> (MemoryReadRecord, u32) {
74        let record =
75            self.rt.mr(addr, self.current_shard, self.clk, Some(&mut self.local_memory_access));
76        (record, record.value)
77    }
78
79    /// Read a slice of words from memory.
80    ///
81    /// `addr` must be a pointer to main memory, not a register.
82    pub fn mr_slice(&mut self, addr: u32, len: usize) -> (Vec<MemoryReadRecord>, Vec<u32>) {
83        let mut records = Vec::with_capacity(len);
84        let mut values = Vec::with_capacity(len);
85        for i in 0..len {
86            let (record, value) = self.mr(addr + i as u32 * 4);
87            records.push(record);
88            values.push(value);
89        }
90        (records, values)
91    }
92
93    /// Write a word to memory.
94    ///
95    /// `addr` must be a pointer to main memory, not a register.
96    pub fn mw(&mut self, addr: u32, value: u32) -> MemoryWriteRecord {
97        self.rt.mw(addr, value, self.current_shard, self.clk, Some(&mut self.local_memory_access))
98    }
99
100    /// Write a slice of words to memory.
101    pub fn mw_slice(&mut self, addr: u32, values: &[u32]) -> Vec<MemoryWriteRecord> {
102        let mut records = Vec::with_capacity(values.len());
103        for i in 0..values.len() {
104            let record = self.mw(addr + i as u32 * 4, values[i]);
105            records.push(record);
106        }
107        records
108    }
109
110    /// Read a register and record the memory access.
111    pub fn rr_traced(&mut self, register: Register) -> (MemoryReadRecord, u32) {
112        let record = self.rt.rr_traced(
113            register,
114            self.current_shard,
115            self.clk,
116            Some(&mut self.local_memory_access),
117        );
118        (record, record.value)
119    }
120
121    /// Write a register and record the memory access.
122    pub fn rw_traced(&mut self, register: Register, value: u32) -> (MemoryWriteRecord, u32) {
123        let record = self.rt.rw_traced(
124            register,
125            value,
126            self.current_shard,
127            self.clk,
128            Some(&mut self.local_memory_access),
129        );
130        (record, record.value)
131    }
132
133    /// Postprocess the syscall.  Specifically will process the syscall's memory local events.
134    pub fn postprocess(&mut self) -> Vec<MemoryLocalEvent> {
135        let mut syscall_local_mem_events = Vec::new();
136
137        if !self.rt.unconstrained {
138            if self.rt.executor_mode == ExecutorMode::Trace {
139                // Will need to transfer the existing memory local events in the executor to it's
140                // record, and return all the syscall memory local events.  This is similar
141                // to what `bump_record` does.
142                for (addr, event) in self.local_memory_access.drain() {
143                    let local_mem_access = self.rt.local_memory_access.remove(&addr);
144
145                    if let Some(local_mem_access) = local_mem_access {
146                        self.rt.record.cpu_local_memory_access.push(local_mem_access);
147                    }
148
149                    syscall_local_mem_events.push(event);
150                }
151            }
152            if let Some(estimator) = &mut self.rt.record_estimator {
153                let original_len = estimator.current_touched_compressed_addresses.len();
154                // Remove addresses from the main set that were touched in the precompile.
155                estimator.current_touched_compressed_addresses =
156                    core::mem::take(&mut estimator.current_touched_compressed_addresses) -
157                        &estimator.current_precompile_touched_compressed_addresses;
158                // Add the number of addresses that were removed from the main set.
159                // Since the type of `RangeSetBlaze::len(...)` depends on the target pointer width,
160                // we need to cast to a `usize` and suppress the clippy lint for when this is a
161                // no-op.
162                #[allow(clippy::unnecessary_cast)]
163                {
164                    estimator.current_local_mem += (original_len -
165                        estimator.current_touched_compressed_addresses.len())
166                        as usize;
167                }
168            }
169        }
170
171        syscall_local_mem_events
172    }
173
174    /// Get the current value of a register, but doesn't use a memory record.
175    /// This is generally unconstrained, so you must be careful using it.
176    #[must_use]
177    pub fn register_unsafe(&mut self, register: Register) -> u32 {
178        self.rt.register(register)
179    }
180
181    /// Get the current value of a byte, but doesn't use a memory record.
182    #[must_use]
183    pub fn byte_unsafe(&mut self, addr: u32) -> u8 {
184        self.rt.byte(addr)
185    }
186
187    /// Get the current value of a word, but doesn't use a memory record.
188    #[must_use]
189    pub fn word_unsafe(&mut self, addr: u32) -> u32 {
190        self.rt.word(addr)
191    }
192
193    /// Get a slice of words, but doesn't use a memory record.
194    #[must_use]
195    pub fn slice_unsafe(&mut self, addr: u32, len: usize) -> Vec<u32> {
196        let mut values = Vec::new();
197        for i in 0..len {
198            values.push(self.rt.word(addr + i as u32 * 4));
199        }
200        values
201    }
202
203    /// Set the next program counter.
204    pub fn set_next_pc(&mut self, next_pc: u32) {
205        self.next_pc = next_pc;
206    }
207
208    /// Set the exit code.
209    pub fn set_exit_code(&mut self, exit_code: u32) {
210        self.exit_code = exit_code;
211    }
212}