Skip to main content

sp1_core_executor/
utils.rs

1use std::{hash::Hash, str::FromStr};
2
3use hashbrown::HashMap;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::{Instruction, Opcode, Register, RiscvAirId, SyscallCode};
7
8/// Serialize a `HashMap<u32, V>` as a `Vec<(u32, V)>`.
9pub fn serialize_hashmap_as_vec<K: Eq + Hash + Serialize, V: Serialize, S: Serializer>(
10    map: &HashMap<K, V>,
11    serializer: S,
12) -> Result<S::Ok, S::Error> {
13    Serialize::serialize(&map.iter().collect::<Vec<_>>(), serializer)
14}
15
16/// Deserialize a `Vec<(u32, V)>` as a `HashMap<u32, V>`.
17pub fn deserialize_hashmap_as_vec<
18    'de,
19    K: Eq + Hash + Deserialize<'de>,
20    V: Deserialize<'de>,
21    D: Deserializer<'de>,
22>(
23    deserializer: D,
24) -> Result<HashMap<K, V>, D::Error> {
25    let seq: Vec<(K, V)> = Deserialize::deserialize(deserializer)?;
26    Ok(seq.into_iter().collect())
27}
28
29/// Returns `true` if the given `opcode` is a signed 64bit operation.
30#[must_use]
31pub fn is_signed_64bit_operation(opcode: Opcode) -> bool {
32    opcode == Opcode::DIV || opcode == Opcode::REM
33}
34
35/// Returns `true` if the given `opcode` is a unsigned 64bit operation.
36#[must_use]
37pub fn is_unsigned_64bit_operation(opcode: Opcode) -> bool {
38    opcode == Opcode::DIVU || opcode == Opcode::REMU
39}
40
41/// Returns `true` if the given `opcode` is a 64bit operation.
42#[must_use]
43pub fn is_64bit_operation(opcode: Opcode) -> bool {
44    opcode == Opcode::DIV
45        || opcode == Opcode::DIVU
46        || opcode == Opcode::REM
47        || opcode == Opcode::REMU
48}
49
50/// Returns `true` if the given `opcode` is a word operation.
51#[must_use]
52pub fn is_word_operation(opcode: Opcode) -> bool {
53    opcode == Opcode::DIVW
54        || opcode == Opcode::DIVUW
55        || opcode == Opcode::REMW
56        || opcode == Opcode::REMUW
57}
58
59/// Returns `true` if the given `opcode` is a signed word operation.
60#[must_use]
61pub fn is_signed_word_operation(opcode: Opcode) -> bool {
62    opcode == Opcode::DIVW || opcode == Opcode::REMW
63}
64
65/// Returns `true` if the given `opcode` is a unsigned word operation.
66#[must_use]
67pub fn is_unsigned_word_operation(opcode: Opcode) -> bool {
68    opcode == Opcode::DIVUW || opcode == Opcode::REMUW
69}
70
71/// Calculate the correct `quotient` and `remainder` for the given `b` and `c` per RISC-V spec.
72#[must_use]
73pub fn get_quotient_and_remainder(b: u64, c: u64, opcode: Opcode) -> (u64, u64) {
74    if c == 0 && is_64bit_operation(opcode) {
75        (u64::MAX, b)
76    } else if (c as i32 == 0) && is_word_operation(opcode) {
77        (u64::MAX, (b as i32) as u64)
78    } else if is_signed_64bit_operation(opcode) {
79        ((b as i64).wrapping_div(c as i64) as u64, (b as i64).wrapping_rem(c as i64) as u64)
80    } else if is_signed_word_operation(opcode) {
81        (
82            (b as i32).wrapping_div(c as i32) as i64 as u64,
83            (b as i32).wrapping_rem(c as i32) as i64 as u64,
84        )
85    } else if is_unsigned_word_operation(opcode) {
86        (
87            (b as u32).wrapping_div(c as u32) as i32 as i64 as u64,
88            (b as u32).wrapping_rem(c as u32) as i32 as i64 as u64,
89        )
90    } else {
91        (b.wrapping_div(c), b.wrapping_rem(c))
92    }
93}
94
/// Extract the most significant bit (bit 63) of the 64-bit integer `a`.
///
/// The result is `1` when the top bit is set and `0` otherwise.
#[must_use]
pub const fn get_msb(a: u64) -> u8 {
    // Shifting a `u64` right by 63 leaves only the top bit, so the value is already 0 or 1.
    (a >> (u64::BITS - 1)) as u8
}
100
101/// Load the cost of each air from the predefined JSON.
102#[must_use]
103pub fn rv64im_costs() -> HashMap<RiscvAirId, usize> {
104    let costs: HashMap<String, usize> =
105        serde_json::from_str(include_str!("./artifacts/rv64im_costs.json")).unwrap();
106    costs.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v)).collect()
107}
108
/// Round `n` down to the largest multiple of 32 that is less than or equal to it.
#[must_use]
pub fn trunc_32(n: usize) -> usize {
    // Dropping the remainder modulo 32 leaves the nearest multiple at or below `n`.
    n - (n % 32)
}
114
115/// The maximum trace area and maximum height increment for a single event of a syscall.
116#[must_use]
117pub fn cost_and_height_per_syscall(
118    syscall_code: SyscallCode,
119    costs: &HashMap<RiscvAirId, usize>,
120    page_protect: bool,
121) -> (usize, usize) {
122    assert!(!page_protect, "page protect turned off");
123
124    let air_id = syscall_code.as_air_id().unwrap();
125    let rows_per_event = air_id.rows_per_event();
126
127    let mut cost_per_syscall = 0;
128    let mut max_height_per_syscall = rows_per_event;
129
130    cost_per_syscall += rows_per_event * costs[&air_id];
131    if rows_per_event > 1 {
132        let control_air_id = air_id.control_air_id().unwrap();
133        cost_per_syscall += costs[&control_air_id];
134    }
135
136    let touched_addresses = syscall_code.touched_addresses();
137    cost_per_syscall += touched_addresses * costs[&RiscvAirId::MemoryLocal];
138    cost_per_syscall += 2 * touched_addresses * costs[&RiscvAirId::Global];
139    cost_per_syscall += costs[&RiscvAirId::SyscallPrecompile];
140    cost_per_syscall += costs[&RiscvAirId::Global];
141    max_height_per_syscall = max_height_per_syscall.max(2 * touched_addresses + 1);
142
143    (cost_per_syscall, max_height_per_syscall)
144}
145
146/// Add a halt syscall to the end of the instructions vec.
147pub fn add_halt(instructions: &mut Vec<Instruction>) {
148    instructions.push(Instruction::new(Opcode::ADD, Register::X5 as u8, 0, 0, false, false));
149    instructions.push(Instruction::new(Opcode::ADD, Register::X10 as u8, 0, 0, false, false));
150    instructions.push(Instruction::new(
151        Opcode::ECALL,
152        Register::X5 as u8,
153        Register::X10 as u64,
154        Register::X11 as u64,
155        false,
156        false,
157    ));
158}