sp1_core_executor/
utils.rs

1use std::{hash::Hash, str::FromStr};
2
3use hashbrown::HashMap;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::{Opcode, RiscvAirId};
7
8/// Serialize a `HashMap<u32, V>` as a `Vec<(u32, V)>`.
9pub fn serialize_hashmap_as_vec<K: Eq + Hash + Serialize, V: Serialize, S: Serializer>(
10    map: &HashMap<K, V>,
11    serializer: S,
12) -> Result<S::Ok, S::Error> {
13    Serialize::serialize(&map.iter().collect::<Vec<_>>(), serializer)
14}
15
16/// Deserialize a `Vec<(u32, V)>` as a `HashMap<u32, V>`.
17pub fn deserialize_hashmap_as_vec<
18    'de,
19    K: Eq + Hash + Deserialize<'de>,
20    V: Deserialize<'de>,
21    D: Deserializer<'de>,
22>(
23    deserializer: D,
24) -> Result<HashMap<K, V>, D::Error> {
25    let seq: Vec<(K, V)> = Deserialize::deserialize(deserializer)?;
26    Ok(seq.into_iter().collect())
27}
28
29/// Returns `true` if the given `opcode` is a signed operation.
30#[must_use]
31pub fn is_signed_operation(opcode: Opcode) -> bool {
32    opcode == Opcode::DIV || opcode == Opcode::REM
33}
34
35/// Calculate the correct `quotient` and `remainder` for the given `b` and `c` per RISC-V spec.
36#[must_use]
37pub fn get_quotient_and_remainder(b: u32, c: u32, opcode: Opcode) -> (u32, u32) {
38    if c == 0 {
39        // When c is 0, the quotient is 2^32 - 1 and the remainder is b regardless of whether we
40        // perform signed or unsigned division.
41        (u32::MAX, b)
42    } else if is_signed_operation(opcode) {
43        ((b as i32).wrapping_div(c as i32) as u32, (b as i32).wrapping_rem(c as i32) as u32)
44    } else {
45        (b.wrapping_div(c), b.wrapping_rem(c))
46    }
47}
48
/// Returns the most significant bit (bit 31) of `a` as a `u8`: `1` if it is set, `0` otherwise.
///
/// When `a` is read as a two's-complement `i32`, this is its sign bit.
#[must_use]
pub const fn get_msb(a: u32) -> u8 {
    const TOP_BIT: u32 = 1 << 31;
    ((a & TOP_BIT) != 0) as u8
}
54
55/// Load the cost of each air from the predefined JSON.
56#[must_use]
57pub fn rv32im_costs() -> HashMap<RiscvAirId, usize> {
58    let costs: HashMap<String, usize> =
59        serde_json::from_str(include_str!("./artifacts/rv32im_costs.json")).unwrap();
60    costs.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v)).collect()
61}