Skip to main content

sp1_core_executor/
instruction.rs

1//! Instructions for the SP1 zkVM.
2
3use core::fmt::Debug;
4use rrs_lib::instruction_formats::{
5    OPCODE_AUIPC, OPCODE_BRANCH, OPCODE_JAL, OPCODE_JALR, OPCODE_LOAD, OPCODE_LUI, OPCODE_OP,
6    OPCODE_OP_32, OPCODE_OP_IMM, OPCODE_OP_IMM_32, OPCODE_STORE, OPCODE_SYSTEM,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::opcode::Opcode;
11
/// RV64 instruction types.
///
/// Every variant's discriminant is a distinct single bit, starting at bit 0 for
/// [`InstructionType::ITypeShamt32`] and ending at bit 8 for [`InstructionType::ECALL`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InstructionType {
    /// I-type instructions with shamt (32-bit).
    ITypeShamt32 = 0b0_0000_0001,
    /// I-type instructions with shamt.
    ITypeShamt = 0b0_0000_0010,
    /// I-type instructions.
    IType = 0b0_0000_0100,
    /// R-type instructions.
    RType = 0b0_0000_1000,
    /// J-type instructions.
    JType = 0b0_0001_0000,
    /// B-type instructions.
    BType = 0b0_0010_0000,
    /// S-type instructions.
    SType = 0b0_0100_0000,
    /// U-type instructions.
    UType = 0b0_1000_0000,
    /// ECALL instruction.
    ECALL = 0b1_0000_0000,
}
33
/// Validates that a u64 is properly sign-extended for a given bit width.
///
/// A value is properly sign-extended from `bit_width` bits when the sign bit (bit
/// `bit_width - 1`) and every bit above it are identical: all 0s for a non-negative value or
/// all 1s for a negative value.
///
/// Edge cases:
/// - `bit_width >= 64`: every `u64` is trivially valid, so this returns `true`.
/// - `bit_width == 0`: a zero-width field has no sign bit and can only hold `0`, so this returns
///   `value == 0`.
///
/// Returns true if the value is properly sign-extended, false otherwise.
#[must_use]
#[inline]
pub const fn validate_sign_extension(value: u64, bit_width: u32) -> bool {
    if bit_width >= u64::BITS {
        return true; // No sign extension needed
    }
    if bit_width == 0 {
        // Guard against `bit_width - 1` underflowing below.
        return value == 0;
    }

    // Discard the magnitude bits, keeping the sign bit and everything above it.
    let sign_shift = bit_width - 1;
    let sign_and_upper = value >> sign_shift;
    let all_ones = u64::MAX >> sign_shift;

    // Non-negative: everything is 0. Negative: everything is 1.
    sign_and_upper == 0 || sign_and_upper == all_ones
}
62
63/// RISC-V 64IM Instruction.
64///
65/// The structure of the instruction differs from the RISC-V ISA. We do not encode the instructions
66/// as 32-bit words, but instead use a custom encoding that is more friendly to decode in the
67/// SP1 zkVM.
68#[derive(Clone, Copy, Serialize, Deserialize, deepsize2::DeepSizeOf)]
69#[repr(C)]
70pub struct Instruction {
71    /// The operation to execute.
72    pub opcode: Opcode,
73    /// The first operand.
74    pub op_a: u8,
75    /// The second operand.
76    pub op_b: u64,
77    /// The third operand.
78    pub op_c: u64,
79    /// Whether the second operand is an immediate value.
80    pub imm_b: bool,
81    /// Whether the third operand is an immediate value.
82    pub imm_c: bool,
83}
84
85impl Instruction {
86    /// Create a new [`RiscvInstruction`].
87    #[must_use]
88    pub const fn new(
89        opcode: Opcode,
90        op_a: u8,
91        op_b: u64,
92        op_c: u64,
93        imm_b: bool,
94        imm_c: bool,
95    ) -> Self {
96        Self { opcode, op_a, op_b, op_c, imm_b, imm_c }
97    }
98
99    /// Returns if the instruction is an ALU instruction.
100    #[must_use]
101    #[inline]
102    pub const fn is_alu_instruction(&self) -> bool {
103        matches!(
104            self.opcode,
105            Opcode::ADD
106                | Opcode::ADDI
107                | Opcode::SUB
108                | Opcode::XOR
109                | Opcode::OR
110                | Opcode::AND
111                | Opcode::SLL
112                | Opcode::SRL
113                | Opcode::SRA
114                | Opcode::SLT
115                | Opcode::SLTU
116                | Opcode::MUL
117                | Opcode::MULH
118                | Opcode::MULHU
119                | Opcode::MULHSU
120                | Opcode::DIV
121                | Opcode::DIVU
122                | Opcode::REM
123                | Opcode::REMU
124                // RISCV-64
125                | Opcode::ADDW
126                | Opcode::SUBW
127                | Opcode::MULW
128                | Opcode::DIVW
129                | Opcode::DIVUW
130                | Opcode::REMW
131                | Opcode::REMUW
132                | Opcode::SLLW
133                | Opcode::SRLW
134                | Opcode::SRAW
135        )
136    }
137
138    /// Returns if the instruction is a ecall instruction.
139    #[must_use]
140    #[inline]
141    pub const fn is_ecall_instruction(&self) -> bool {
142        matches!(self.opcode, Opcode::ECALL)
143    }
144
145    /// Returns if the instruction is a memory load instruction.
146    #[must_use]
147    #[inline]
148    pub const fn is_memory_load_instruction(&self) -> bool {
149        matches!(
150            self.opcode,
151            Opcode::LB
152                | Opcode::LH
153                | Opcode::LW
154                | Opcode::LBU
155                | Opcode::LHU
156                // RISCV-64
157                | Opcode::LWU
158                | Opcode::LD
159        )
160    }
161
162    /// Returns if the instruction is a memory store instruction.
163    #[must_use]
164    #[inline]
165    pub const fn is_memory_store_instruction(&self) -> bool {
166        matches!(self.opcode, Opcode::SB | Opcode::SH | Opcode::SW | /* RISCV-64 */ Opcode::SD)
167    }
168
169    /// Returns if the instruction is a branch instruction.
170    #[must_use]
171    #[inline]
172    pub const fn is_branch_instruction(&self) -> bool {
173        matches!(
174            self.opcode,
175            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU
176        )
177    }
178
179    /// Returns if the instruction is a jump instruction.
180    #[must_use]
181    #[inline]
182    pub const fn is_jump_instruction(&self) -> bool {
183        matches!(self.opcode, Opcode::JAL | Opcode::JALR)
184    }
185
186    /// Returns if the instruction is a jal instruction.
187    #[must_use]
188    #[inline]
189    pub const fn is_jal_instruction(&self) -> bool {
190        matches!(self.opcode, Opcode::JAL)
191    }
192
193    /// Returns if the instruction is a jalr instruction.
194    #[must_use]
195    #[inline]
196    pub const fn is_jalr_instruction(&self) -> bool {
197        matches!(self.opcode, Opcode::JALR)
198    }
199
200    /// Returns if the instruction is a utype instruction.
201    #[must_use]
202    #[inline]
203    pub const fn is_utype_instruction(&self) -> bool {
204        matches!(self.opcode, Opcode::AUIPC | Opcode::LUI)
205    }
206
207    /// Returns if the instruction guarantees that the `next_pc` are with correct limbs.
208    #[must_use]
209    #[inline]
210    pub const fn is_with_correct_next_pc(&self) -> bool {
211        matches!(
212            self.opcode,
213            Opcode::BEQ
214                | Opcode::BNE
215                | Opcode::BLT
216                | Opcode::BGE
217                | Opcode::BLTU
218                | Opcode::BGEU
219                | Opcode::JAL
220                | Opcode::JALR
221        )
222    }
223
224    /// Returns if the instruction is a divrem instruction.
225    #[must_use]
226    #[inline]
227    pub const fn is_divrem_instruction(&self) -> bool {
228        matches!(self.opcode, Opcode::DIV | Opcode::DIVU | Opcode::REM | Opcode::REMU)
229    }
230
231    /// Returns if the instruction is an ebreak instruction.
232    #[must_use]
233    #[inline]
234    pub const fn is_ebreak_instruction(&self) -> bool {
235        matches!(self.opcode, Opcode::EBREAK)
236    }
237
238    /// Returns if the instruction is an unimplemented instruction.
239    #[must_use]
240    #[inline]
241    pub const fn is_unimp_instruction(&self) -> bool {
242        matches!(self.opcode, Opcode::UNIMP)
243    }
244
245    /// Returns the encoded RISC-V instruction.
246    #[must_use]
247    #[inline]
248    #[allow(clippy::too_many_lines)]
249    pub fn encode(&self) -> u32 {
250        if self.opcode == Opcode::ECALL {
251            0x00000073
252        } else {
253            let (mut base_opcode, imm_base_opcode) = self.opcode.base_opcode();
254
255            let is_imm = self.imm_c;
256            if is_imm {
257                base_opcode = imm_base_opcode.expect("Opcode should have imm base opcode");
258            }
259
260            let funct3 = self.opcode.funct3();
261            let funct7 = self.opcode.funct7();
262            let funct12 = self.opcode.funct12();
263
264            match base_opcode {
265                // R-type instructions
266                // Operands represent register indices, which must be 5 bits (0-31)
267                OPCODE_OP | OPCODE_OP_32 => {
268                    assert!(
269                        self.op_a <= 31,
270                        "Register index {} exceeds maximum value 31",
271                        self.op_a
272                    );
273                    assert!(
274                        self.op_b <= 31,
275                        "Register index {} exceeds maximum value 31",
276                        self.op_b
277                    );
278                    assert!(
279                        self.op_c <= 31,
280                        "Register index {} exceeds maximum value 31",
281                        self.op_c
282                    );
283
284                    let (rd, rs1, rs2) = (self.op_a as u32, self.op_b as u32, self.op_c as u32);
285                    let funct3: u32 = funct3.expect("Opcode should have funct3").into();
286                    let funct7: u32 = funct7.expect("Opcode should have funct7").into();
287
288                    (funct7 << 25)
289                        | (rs2 << 20)
290                        | (rs1 << 15)
291                        | (funct3 << 12)
292                        | (rd << 7)
293                        | base_opcode
294                }
295                // I-type instructions
296                // Operands a and b represent register indices, which must be 5 bits (0-31)
297                // Operand c represents an immediate value, which must be 12 bits (signed)
298                OPCODE_OP_IMM | OPCODE_OP_IMM_32 | OPCODE_LOAD | OPCODE_JALR | OPCODE_SYSTEM => {
299                    assert!(
300                        self.op_a <= 31,
301                        "Register index {} exceeds maximum value 31",
302                        self.op_a
303                    );
304                    assert!(
305                        self.op_b <= 31,
306                        "Register index {} exceeds maximum value 31",
307                        self.op_b
308                    );
309                    assert!(
310                        validate_sign_extension(self.op_c, 12),
311                        "Immediate value {} is not properly sign-extended for 12 bits",
312                        self.op_c
313                    );
314
315                    let (rd, rs1, imm) = (
316                        self.op_a as u32,
317                        self.op_b as u32,
318                        // Extract original 12-bit immediate from sign-extended u64
319                        (self.op_c & 0xFFF) as u32,
320                    );
321                    let funct3: u32 = funct3.expect("Opcode should have funct3").into();
322
323                    // Check if it should be a I-type shamt instruction.
324                    if (base_opcode == OPCODE_OP_IMM || base_opcode == OPCODE_OP_IMM_32)
325                        && matches!(funct3, 0b001 | 0b101)
326                    {
327                        let funct7: u32 = funct7.expect("Opcode should have funct7").into();
328                        (funct7 << 25)
329                            | (imm << 20)
330                            | (rs1 << 15)
331                            | (funct3 << 12)
332                            | (rd << 7)
333                            | base_opcode
334                    } else if base_opcode == OPCODE_SYSTEM && funct3 == 0 && rd == 0 && rs1 == 0 {
335                        let funct12: u32 = funct12.expect("Opcode should have funct12");
336                        (funct12 << 20) | (rs1 << 15) | (funct3 << 12) | (rd << 7) | base_opcode
337                    } else {
338                        (imm << 20) | (rs1 << 15) | (funct3 << 12) | (rd << 7) | base_opcode
339                    }
340                }
341
342                // S-type instructions
343                // Operands a and b represent register indices, which must be 5 bits (0-31)
344                // Operand c represents an immediate value, which must be 12 bits (signed) (split
345                // b/w [31:25] + [11:7])
346                OPCODE_STORE => {
347                    assert!(
348                        self.op_a <= 31,
349                        "Register index {} exceeds maximum value 31",
350                        self.op_a
351                    );
352                    assert!(
353                        self.op_b <= 31,
354                        "Register index {} exceeds maximum value 31",
355                        self.op_b
356                    );
357                    assert!(
358                        validate_sign_extension(self.op_c, 12),
359                        "Immediate value {} is not properly sign-extended for 12 bits",
360                        self.op_c
361                    );
362
363                    let funct3: u32 = funct3.expect("Opcode should have funct3").into();
364                    let (rd, rs1, imm) = (
365                        self.op_a as u32,
366                        self.op_b as u32,
367                        // Extract original 12-bit immediate from sign-extended u64
368                        (self.op_c & 0xFFF) as u32,
369                    );
370                    let imm_11_5 = (imm >> 5) & 0b1111111;
371                    let imm_4_0 = imm & 0b11111;
372
373                    (imm_11_5 << 25)
374                        | (rd << 20)
375                        | (rs1 << 15)
376                        | (funct3 << 12)
377                        | (imm_4_0 << 7)
378                        | base_opcode
379                }
380
381                // B-type instructions
382                // Operands a and b represent register indices, which must be 5 bits (0-31)
383                // Signed 13 bits for B-type instructions (bits [31:25] + [11:8] + [7])
384                OPCODE_BRANCH => {
385                    assert!(
386                        self.op_a <= 31,
387                        "Register index {} exceeds maximum value 31",
388                        self.op_a
389                    );
390                    assert!(
391                        self.op_b <= 31,
392                        "Register index {} exceeds maximum value 31",
393                        self.op_b
394                    );
395                    assert!(
396                        validate_sign_extension(self.op_c, 13),
397                        "Immediate value {} is not properly sign-extended for 13 bits",
398                        self.op_c
399                    );
400
401                    let funct3: u32 = funct3.expect("Opcode should have funct3").into();
402                    let (rd, rs1, imm) = (
403                        self.op_a as u32,
404                        self.op_b as u32,
405                        // Extract original 13-bit immediate from sign-extended u64
406                        (self.op_c & 0x1FFF) as u32,
407                    );
408                    assert!(imm & 0b1 == 0, "B-type immediate must be aligned (multiple of 2)");
409
410                    let imm_12 = (imm >> 12) & 0b1;
411                    let imm_10_5 = (imm >> 5) & 0b111111;
412                    let imm_4_1 = (imm >> 1) & 0b1111;
413                    let imm_11 = (imm >> 11) & 0b1;
414
415                    (imm_12 << 31)
416                        | (imm_10_5 << 25)
417                        | (rs1 << 20)
418                        | (rd << 15)
419                        | (funct3 << 12)
420                        | (imm_4_1 << 8)
421                        | (imm_11 << 7)
422                        | (base_opcode & 0b1111111)
423                }
424                // U-type instructions
425                // 20 bits for U-type instructions (bits [31:12])
426                OPCODE_AUIPC | OPCODE_LUI => {
427                    assert!(
428                        self.op_a <= 31,
429                        "Register index {} exceeds maximum value 31",
430                        self.op_a
431                    );
432                    let mut sign_extended_imm = self.op_b >> 12;
433                    if self.op_b >= (1 << 32) {
434                        sign_extended_imm += u64::MAX - (1u64 << 52) + 1;
435                    }
436                    assert!(
437                        validate_sign_extension(sign_extended_imm, 20),
438                        "Immediate value {} is not properly sign-extended for 20 bits",
439                        self.op_b
440                    );
441                    let (rd, imm_upper) = (
442                        self.op_a as u32,
443                        // Extract original 20-bit immediate from sign-extended u64
444                        self.op_b as u32,
445                    );
446                    imm_upper | (rd << 7) | base_opcode
447                }
448                // J-type instructions
449                // 21 bits for J-type instructions (bits [31:12] + [20] + [19:12] + [30:21])
450                OPCODE_JAL => {
451                    assert!(
452                        self.op_a <= 31,
453                        "Register index {} exceeds maximum value 31",
454                        self.op_a
455                    );
456                    assert!(
457                        validate_sign_extension(self.op_b, 21),
458                        "Immediate value {} is not properly sign-extended for 21 bits",
459                        self.op_b
460                    );
461                    assert!(self.op_b & 0b1 == 0, "J-type immediate must be 2-byte aligned");
462                    let (rd, imm) = (
463                        self.op_a as u32,
464                        // Extract original 21-bit immediate from sign-extended u64
465                        (self.op_b & 0x1FFFFF) as u32,
466                    );
467
468                    let imm_20 = (imm >> 20) & 0x1;
469                    let imm_10_1 = (imm >> 1) & 0x3FF;
470                    let imm_11 = (imm >> 11) & 0x1;
471                    let imm_19_12 = (imm >> 12) & 0xFF;
472
473                    (imm_20 << 31)
474                        | (imm_10_1 << 21)
475                        | (imm_11 << 20)
476                        | (imm_19_12 << 12)
477                        | (rd << 7)
478                        | base_opcode
479                }
480
481                _ => unreachable!(),
482            }
483        }
484    }
485}
486
487impl Debug for Instruction {
488    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489        let mnemonic = self.opcode.mnemonic();
490        let op_a_formatted = format!("%x{}", self.op_a);
491        let op_b_formatted = if self.imm_b || self.opcode == Opcode::AUIPC {
492            format!("{}", self.op_b as i32)
493        } else {
494            format!("%x{}", self.op_b)
495        };
496        let op_c_formatted =
497            if self.imm_c { format!("{}", self.op_c as i32) } else { format!("%x{}", self.op_c) };
498
499        let width = 10;
500        write!(
501            f,
502            "{mnemonic:<width$} {op_a_formatted:<width$} {op_b_formatted:<width$} {op_c_formatted:<width$}"
503        )
504    }
505}