sp1_core_executor/
memory.rs

1use serde::{de::DeserializeOwned, Deserialize, Serialize};
2use vec_map::VecMap;
3
4/// A memory.
5///
6/// Consists of registers, as well as a page table for main memory.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(bound(serialize = "T: Serialize"))]
9#[serde(bound(deserialize = "T: DeserializeOwned"))]
10pub struct Memory<T: Copy> {
11    /// The registers.
12    pub registers: Registers<T>,
13    /// The page table.
14    pub page_table: PagedMemory<T>,
15}
16
17impl<V: Copy + 'static> IntoIterator for Memory<V> {
18    type Item = (u32, V);
19
20    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
21
22    fn into_iter(self) -> Self::IntoIter {
23        Box::new(self.registers.into_iter().chain(self.page_table))
24    }
25}
26
27impl<T: Copy + Default> Default for Memory<T> {
28    fn default() -> Self {
29        Self { registers: Registers::default(), page_table: PagedMemory::default() }
30    }
31}
32
33impl<T: Copy> Memory<T> {
34    /// Initialize a new memory with preallocated page table.
35    pub fn new_preallocated() -> Self {
36        Self { registers: Registers::default(), page_table: PagedMemory::new_preallocated() }
37    }
38
39    /// Get an entry for the given address.
40    ///
41    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
42    /// This method often incurs unnecessary branching.
43    #[inline]
44    pub fn entry(&mut self, addr: u32) -> Entry<'_, T> {
45        if addr < 32 {
46            self.registers.entry(addr)
47        } else {
48            self.page_table.entry(addr)
49        }
50    }
51
52    /// Insert a value into the memory.
53    ///
54    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
55    /// This method often incurs unnecessary branching.   
56    #[inline]
57    pub fn insert(&mut self, addr: u32, value: T) -> Option<T> {
58        if addr < 32 {
59            self.registers.insert(addr, value)
60        } else {
61            self.page_table.insert(addr, value)
62        }
63    }
64
65    /// Get a value from the memory.
66    ///
67    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
68    /// This method often incurs unnecessary branching.
69    #[inline]
70    pub fn get(&self, addr: u32) -> Option<&T> {
71        if addr < 32 {
72            self.registers.get(addr)
73        } else {
74            self.page_table.get(addr)
75        }
76    }
77
78    /// Remove a value from the memory.
79    ///
80    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
81    /// This method often incurs unnecessary branching.
82    #[inline]
83    pub fn remove(&mut self, addr: u32) -> Option<T> {
84        if addr < 32 {
85            self.registers.remove(addr)
86        } else {
87            self.page_table.remove(addr)
88        }
89    }
90
91    /// Clear the memory.
92    #[inline]
93    pub fn clear(&mut self) {
94        self.registers.clear();
95        self.page_table.clear();
96    }
97}
98
99impl<V: Copy + Default> FromIterator<(u32, V)> for Memory<V> {
100    fn from_iter<T: IntoIterator<Item = (u32, V)>>(iter: T) -> Self {
101        let mut memory = Self::new_preallocated();
102        for (addr, value) in iter {
103            memory.insert(addr, value);
104        }
105        memory
106    }
107}
108
109/// An array of 32 registers.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111#[serde(bound(serialize = "T: Serialize"))]
112#[serde(bound(deserialize = "T: DeserializeOwned"))]
113pub struct Registers<T: Copy> {
114    pub registers: [Option<T>; 32],
115}
116
117impl<T: Copy> Default for Registers<T> {
118    fn default() -> Self {
119        Self { registers: [None; 32] }
120    }
121}
122
123impl<T: Copy> Registers<T> {
124    /// Get an entry for the given register.
125    #[inline]
126    pub fn entry(&mut self, addr: u32) -> Entry<'_, T> {
127        let entry = &mut self.registers[addr as usize];
128        match entry {
129            Some(v) => Entry::Occupied(OccupiedEntry { entry: v }),
130            None => Entry::Vacant(VacantEntry { entry }),
131        }
132    }
133
134    /// Insert a value into the registers.
135    ///
136    /// Assumes addr < 32.
137    #[inline]
138    pub fn insert(&mut self, addr: u32, value: T) -> Option<T> {
139        self.registers[addr as usize].replace(value)
140    }
141
142    /// Remove a value from the registers, and return it if it exists.
143    ///
144    /// Assumes addr < 32.
145    #[inline]
146    pub fn remove(&mut self, addr: u32) -> Option<T> {
147        self.registers[addr as usize].take()
148    }
149
150    /// Get a reference to the value at the given address, if it exists.
151    ///
152    /// Assumes addr < 32.
153    #[inline]
154    pub fn get(&self, addr: u32) -> Option<&T> {
155        self.registers[addr as usize].as_ref()
156    }
157
158    /// Clear the registers.
159    #[inline]
160    pub fn clear(&mut self) {
161        self.registers.fill(None);
162    }
163}
164
165impl<V: Copy> FromIterator<(u32, V)> for Registers<V> {
166    fn from_iter<T: IntoIterator<Item = (u32, V)>>(iter: T) -> Self {
167        let mut mmu = Self::default();
168        for (k, v) in iter {
169            mmu.insert(k, v);
170        }
171        mmu
172    }
173}
174
175impl<V: Copy + 'static> IntoIterator for Registers<V> {
176    type Item = (u32, V);
177
178    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
179
180    fn into_iter(self) -> Self::IntoIter {
181        Box::new(
182            self.registers
183                .into_iter()
184                .enumerate()
185                .filter_map(move |(i, v)| v.map(|v| (i as u32, v))),
186        )
187    }
188}
189
190/// A page of memory.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct Page<V>(VecMap<V>);
193
194impl<V> Default for Page<V> {
195    fn default() -> Self {
196        Self(VecMap::default())
197    }
198}
199
/// Log2 of the number of value slots in one page.
const LOG_PAGE_LEN: usize = 14;
/// Number of value slots in one page.
const PAGE_LEN: usize = 1 << LOG_PAGE_LEN;
/// Mask selecting a slot's position within its page from a compressed address.
const PAGE_MASK: usize = PAGE_LEN - 1;
/// Number of entries in `PagedMemory::index`, i.e. how many distinct pages can be addressed.
///
/// The expression counts the pages needed for the byte range `[0, (1 << 31) - (1 << 27))`
/// at 4 bytes per slot, plus one extra page.
const MAX_PAGE_COUNT: usize = ((1 << 31) - (1 << 27)) / 4 / PAGE_LEN + 1;
/// Sentinel stored in `PagedMemory::index` for a page that has not been allocated.
const NO_PAGE: u16 = u16::MAX;

// Page positions are stored as `u16`, and at most one page is allocated per index entry, so
// every stored position is below `MAX_PAGE_COUNT`. Fail the build if enlarging the address
// space would let a real position truncate or collide with the `NO_PAGE` sentinel.
const _: () = assert!(MAX_PAGE_COUNT <= NO_PAGE as usize);
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
207#[serde(bound(serialize = "V: Serialize"))]
208#[serde(bound(deserialize = "V: DeserializeOwned"))]
209pub struct NewPage<V>(Vec<Option<V>>);
210
211impl<V: Copy> NewPage<V> {
212    pub fn new() -> Self {
213        Self(vec![None; PAGE_LEN])
214    }
215}
216
217impl<V: Copy> Default for NewPage<V> {
218    fn default() -> Self {
219        Self(Vec::new())
220    }
221}
222
223/// Paged memory. Balances both memory locality and total memory usage.
224#[derive(Debug, Clone, Serialize, Deserialize)]
225#[serde(bound(serialize = "V: Serialize"))]
226#[serde(bound(deserialize = "V: DeserializeOwned"))]
227pub struct PagedMemory<V: Copy> {
228    /// The internal page table.
229    pub page_table: Vec<NewPage<V>>,
230    pub index: Vec<u16>,
231}
232
233impl<V: Copy> PagedMemory<V> {
234    /// The number of lower bits to ignore, since addresses (except registers) are a multiple of 4.
235    const NUM_IGNORED_LOWER_BITS: usize = 2;
236
237    /// Create a `PagedMemory` with capacity `MAX_PAGE_COUNT`.
238    pub fn new_preallocated() -> Self {
239        Self { page_table: Vec::new(), index: vec![NO_PAGE; MAX_PAGE_COUNT] }
240    }
241
242    /// Get a reference to the memory value at the given address, if it exists.
243    pub fn get(&self, addr: u32) -> Option<&V> {
244        let (upper, lower) = Self::indices(addr);
245        let index = self.index[upper];
246        if index == NO_PAGE {
247            None
248        } else {
249            self.page_table[index as usize].0[lower].as_ref()
250        }
251    }
252
253    /// Get a mutable reference to the memory value at the given address, if it exists.
254    pub fn get_mut(&mut self, addr: u32) -> Option<&mut V> {
255        let (upper, lower) = Self::indices(addr);
256        let index = self.index[upper];
257        if index == NO_PAGE {
258            None
259        } else {
260            self.page_table[index as usize].0[lower].as_mut()
261        }
262    }
263
264    /// Insert a value at the given address. Returns the previous value, if any.
265    pub fn insert(&mut self, addr: u32, value: V) -> Option<V> {
266        let (upper, lower) = Self::indices(addr);
267        let mut index = self.index[upper];
268        if index == NO_PAGE {
269            index = self.page_table.len() as u16;
270            self.index[upper] = index;
271            self.page_table.push(NewPage::new());
272        }
273        self.page_table[index as usize].0[lower].replace(value)
274    }
275
276    /// Remove the value at the given address if it exists, returning it.
277    pub fn remove(&mut self, addr: u32) -> Option<V> {
278        let (upper, lower) = Self::indices(addr);
279        let index = self.index[upper];
280        if index == NO_PAGE {
281            None
282        } else {
283            self.page_table[index as usize].0[lower].take()
284        }
285    }
286
287    /// Gets the memory entry for the given address.
288    pub fn entry(&mut self, addr: u32) -> Entry<'_, V> {
289        let (upper, lower) = Self::indices(addr);
290        let index = self.index[upper];
291        if index == NO_PAGE {
292            let index = self.page_table.len();
293            self.index[upper] = index as u16;
294            self.page_table.push(NewPage::new());
295            Entry::Vacant(VacantEntry { entry: &mut self.page_table[index].0[lower] })
296        } else {
297            let option = &mut self.page_table[index as usize].0[lower];
298            match option {
299                Some(v) => Entry::Occupied(OccupiedEntry { entry: v }),
300                None => Entry::Vacant(VacantEntry { entry: option }),
301            }
302        }
303    }
304
305    /// Returns an iterator over the occupied addresses.
306    pub fn keys(&self) -> impl Iterator<Item = u32> + '_ {
307        self.index.iter().enumerate().filter(|(_, &i)| i != NO_PAGE).flat_map(|(i, index)| {
308            let upper = i << LOG_PAGE_LEN;
309            self.page_table[*index as usize]
310                .0
311                .iter()
312                .enumerate()
313                .filter_map(move |(lower, v)| v.map(|_| Self::decompress_addr(upper + lower)))
314        })
315    }
316
317    /// Get the exact number of addresses in use. This function iterates through each page
318    /// and is therefore somewhat expensive.
319    pub fn exact_len(&self) -> usize {
320        self.index
321            .iter()
322            .filter(|&&i| i != NO_PAGE)
323            .map(|index| self.page_table[*index as usize].0.iter().filter(|v| v.is_some()).count())
324            .sum()
325    }
326
327    /// Estimate the number of addresses in use.
328    pub fn estimate_len(&self) -> usize {
329        self.index.iter().filter(|&i| *i != NO_PAGE).count() * PAGE_LEN
330    }
331
332    /// Clears the page table. Drops all `Page`s, but retains the memory used by the table itself.
333    pub fn clear(&mut self) {
334        self.page_table.clear();
335        self.index.fill(NO_PAGE);
336    }
337
338    /// Break apart an address into an upper and lower index.
339    #[inline]
340    const fn indices(addr: u32) -> (usize, usize) {
341        let index = Self::compress_addr(addr);
342        (index >> LOG_PAGE_LEN, index & PAGE_MASK)
343    }
344
345    /// Compress an address from the sparse address space to a contiguous space.
346    #[inline]
347    const fn compress_addr(addr: u32) -> usize {
348        addr as usize >> Self::NUM_IGNORED_LOWER_BITS
349    }
350
351    /// Decompress an address from a contiguous space to the sparse address space.
352    #[inline]
353    const fn decompress_addr(addr: usize) -> u32 {
354        (addr << Self::NUM_IGNORED_LOWER_BITS) as u32
355    }
356}
357
358impl<V: Copy> Default for PagedMemory<V> {
359    fn default() -> Self {
360        Self { page_table: Vec::new(), index: vec![NO_PAGE; MAX_PAGE_COUNT] }
361    }
362}
363
364/// An entry of `PagedMemory` or `Registers`, for in-place manipulation.
365pub enum Entry<'a, V: Copy> {
366    Vacant(VacantEntry<'a, V>),
367    Occupied(OccupiedEntry<'a, V>),
368}
369
370impl<'a, V: Copy> Entry<'a, V> {
371    /// Ensures a value is in the entry, inserting the provided value if necessary.
372    /// Returns a mutable reference to the value.
373    pub fn or_insert(self, default: V) -> &'a mut V {
374        match self {
375            Entry::Vacant(entry) => entry.insert(default),
376            Entry::Occupied(entry) => entry.into_mut(),
377        }
378    }
379
380    /// Ensures a value is in the entry, computing a value if necessary.
381    /// Returns a mutable reference to the value.
382    pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> &'a mut V {
383        match self {
384            Entry::Vacant(entry) => entry.insert(default()),
385            Entry::Occupied(entry) => entry.into_mut(),
386        }
387    }
388
389    /// Provides in-place mutable access to an occupied entry before any potential inserts into the
390    /// map.
391    pub fn and_modify<F: FnOnce(&mut V)>(mut self, f: F) -> Self {
392        match &mut self {
393            Entry::Vacant(_) => {}
394            Entry::Occupied(entry) => f(entry.get_mut()),
395        }
396        self
397    }
398}
399
/// A vacant entry, for in-place manipulation.
///
/// Wraps a mutable borrow of an empty slot in `PagedMemory` or `Registers`.
pub struct VacantEntry<'a, V: Copy> {
    entry: &'a mut Option<V>,
}

impl<'a, V: Copy> VacantEntry<'a, V> {
    /// Stores `value` in the slot and returns a mutable reference to it that lives as long as
    /// the original borrow.
    pub fn insert(self, value: V) -> &'a mut V {
        // Entries are only built around `None` slots, so this never overwrites a value.
        self.entry.insert(value)
    }
}
413
/// An occupied entry, for in-place manipulation.
///
/// Holds a mutable borrow of the value inside an occupied slot of `PagedMemory` or
/// `Registers` (the value itself, not the surrounding `Option`).
pub struct OccupiedEntry<'a, V> {
    entry: &'a mut V,
}

impl<'a, V: Copy> OccupiedEntry<'a, V> {
    /// Get a reference to the value in the `OccupiedEntry`.
    pub fn get(&self) -> &V {
        self.entry
    }

    /// Get a mutable reference to the value in the `OccupiedEntry`.
    pub fn get_mut(&mut self) -> &mut V {
        self.entry
    }

    /// Overwrite the value in the `OccupiedEntry` with `value`, returning the previous value.
    pub fn insert(&mut self, value: V) -> V {
        std::mem::replace(self.entry, value)
    }

    /// Converts the `OccupiedEntry` into a mutable reference to the associated value that
    /// lives as long as the original borrow.
    pub fn into_mut(self) -> &'a mut V {
        self.entry
    }

    /// Consumes the `OccupiedEntry` and returns a copy of its value.
    ///
    /// Unlike `std::collections::hash_map::OccupiedEntry::remove`, this does **not** clear
    /// the slot. The entry only borrows the value (`&mut V`), not the slot's `Option`, so
    /// the value stays stored afterwards. To actually delete a value, call `remove` on the
    /// owning `Memory`, `Registers`, or `PagedMemory` instead.
    pub fn remove(self) -> V {
        *self.entry
    }
}
445
446impl<V: Copy> FromIterator<(u32, V)> for PagedMemory<V> {
447    fn from_iter<T: IntoIterator<Item = (u32, V)>>(iter: T) -> Self {
448        let mut mmu = Self::new_preallocated();
449        for (k, v) in iter {
450            mmu.insert(k, v);
451        }
452        mmu
453    }
454}
455
456impl<V: Copy + 'static> IntoIterator for PagedMemory<V> {
457    type Item = (u32, V);
458
459    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
460
461    fn into_iter(mut self) -> Self::IntoIter {
462        Box::new(self.index.into_iter().enumerate().filter(|(_, i)| *i != NO_PAGE).flat_map(
463            move |(i, index)| {
464                let upper = i << LOG_PAGE_LEN;
465                std::mem::take(&mut self.page_table[index as usize])
466                    .0
467                    .into_iter()
468                    .enumerate()
469                    .filter_map(move |(lower, v)| {
470                        v.map(|v| (Self::decompress_addr(upper + lower), v))
471                    })
472            },
473        ))
474    }
475}