Skip to main content

wasm_pvm/pvm/
blob.rs

1use super::Instruction;
2
3pub struct ProgramBlob {
4    instructions: Vec<Instruction>,
5    jump_table: Vec<u32>,
6}
7
8impl ProgramBlob {
9    #[must_use]
10    pub fn new(instructions: Vec<Instruction>) -> Self {
11        Self {
12            instructions,
13            jump_table: Vec::new(),
14        }
15    }
16
17    #[must_use]
18    pub fn with_jump_table(mut self, jump_table: Vec<u32>) -> Self {
19        self.jump_table = jump_table;
20        self
21    }
22
23    #[must_use]
24    pub fn instructions(&self) -> &[Instruction] {
25        &self.instructions
26    }
27
28    #[must_use]
29    pub fn encode(&self) -> Vec<u8> {
30        let (code, mask) = self.encode_code_and_mask();
31        let code_len = code.len();
32
33        let mut blob = Vec::new();
34
35        blob.extend(encode_var_u32(self.jump_table.len() as u32));
36
37        let item_len: u8 = if self.jump_table.is_empty() { 0 } else { 4 };
38        blob.push(item_len);
39
40        blob.extend(encode_var_u32(code_len as u32));
41
42        for &addr in &self.jump_table {
43            blob.extend(addr.to_le_bytes());
44        }
45
46        blob.extend(code);
47        blob.extend(mask);
48
49        blob
50    }
51
52    fn encode_code_and_mask(&self) -> (Vec<u8>, Vec<u8>) {
53        let mut code = Vec::new();
54        let mut mask_bits = Vec::new();
55
56        for instr in &self.instructions {
57            let encoded = instr.encode();
58            let start_offset = code.len();
59
60            code.extend(&encoded);
61
62            for i in 0..encoded.len() {
63                mask_bits.push(i == 0);
64            }
65
66            if instr.is_terminating() && start_offset + encoded.len() < code.len() {}
67        }
68
69        let mask = pack_mask(&mask_bits);
70        (code, mask)
71    }
72}
73
/// Packs a bit sequence into bytes, 8 bits per byte, least-significant bit first.
///
/// The last byte is zero-padded when `bits.len()` is not a multiple of 8; an
/// empty input yields an empty vector.
fn pack_mask(bits: &[bool]) -> Vec<u8> {
    bits.chunks(8)
        .map(|octet| {
            octet
                .iter()
                .enumerate()
                .fold(0u8, |acc, (shift, &set)| acc | (u8::from(set) << shift))
        })
        .collect()
}
87
/// Encodes `value` as a variable-length natural number (used for the length
/// fields written by `ProgramBlob::encode`).
///
/// - `0` becomes the single byte `0x00`.
/// - Otherwise, let `l` be the largest integer in `0..=7` with `value >= 2^(7*l)`.
///   The output is `l + 1` bytes: a prefix byte `2^8 - 2^(8-l) + (value >> 8*l)`
///   followed by the low `l` bytes of `value` in little-endian order.
/// - Values `>= 2^56` would be written as `0xff` followed by all 8 little-endian
///   bytes; a `u32` can never reach that branch, but it is kept as written.
pub(crate) fn encode_var_u32(value: u32) -> Vec<u8> {
    const EIGHT_BYTE_THRESHOLD: u64 = 1 << 56;

    let x = u64::from(value);
    if x == 0 {
        return vec![0];
    }
    if x >= EIGHT_BYTE_THRESHOLD {
        let mut out = Vec::with_capacity(9);
        out.push(0xff);
        out.extend_from_slice(&x.to_le_bytes());
        return out;
    }

    // Number of trailing bytes: the largest l with x >= 2^(7l).
    // l = 0 always matches because x >= 1 here.
    let l = (0..=7usize)
        .rev()
        .find(|&l| x >= 1u64 << (7 * l))
        .unwrap_or(0);

    let prefix = 0x100 - (0x100 >> l) + (x >> (8 * l));
    let mut out = Vec::with_capacity(l + 1);
    out.push(prefix as u8);
    // The low l little-endian bytes of x are exactly x mod 2^(8l).
    out.extend_from_slice(&x.to_le_bytes()[..l]);
    out
}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns whether bit `pc` is set in a mask packed LSB-first, 8 bits per byte.
    fn mask_bit(mask: &[u8], pc: usize) -> bool {
        (mask[pc / 8] >> (pc % 8)) & 1 == 1
    }

    #[test]
    fn test_encode_var_u32() {
        assert_eq!(encode_var_u32(0), vec![0]);
        assert_eq!(encode_var_u32(1), vec![1]);
        assert_eq!(encode_var_u32(127), vec![127]);
        assert_eq!(encode_var_u32(128), vec![0x80, 0x80]);
        assert_eq!(encode_var_u32(145), vec![0x80, 0x91]);
        assert_eq!(encode_var_u32(300), vec![0x81, 0x2c]);
        assert_eq!(encode_var_u32(16383), vec![0xbf, 0xff]);
        assert_eq!(encode_var_u32(16384), vec![0xc0, 0x00, 0x40]);
        // Edges of the 3-, 4- and 5-byte forms, up to the largest u32.
        assert_eq!(encode_var_u32(0x001f_ffff), vec![0xdf, 0xff, 0xff]);
        assert_eq!(encode_var_u32(0x0020_0000), vec![0xe0, 0x00, 0x00, 0x20]);
        assert_eq!(encode_var_u32(0x0fff_ffff), vec![0xef, 0xff, 0xff, 0xff]);
        assert_eq!(
            encode_var_u32(0x1000_0000),
            vec![0xf0, 0x00, 0x00, 0x00, 0x10]
        );
        assert_eq!(encode_var_u32(u32::MAX), vec![0xf0, 0xff, 0xff, 0xff, 0xff]);
    }

    #[test]
    fn test_pack_mask() {
        assert!(pack_mask(&[]).is_empty());
        assert_eq!(pack_mask(&[true, false, false]), vec![0b0000_0001]);
        assert_eq!(
            pack_mask(&[true, false, false, false, false, false, false, false, true]),
            vec![0b0000_0001, 0b0000_0001]
        );
        assert_eq!(pack_mask(&[true; 8]), vec![0xff]);
    }

    #[test]
    fn test_load_imm64_mask() {
        // LoadImm64 encodes to 10 bytes, Trap to 1 byte
        let instructions = vec![
            Instruction::LoadImm64 {
                reg: 7,
                value: 0xFEFD_0000,
            },
            Instruction::Trap,
        ];

        let blob = ProgramBlob::new(instructions);
        let (code, mask) = blob.encode_code_and_mask();

        // LoadImm64 = 10 bytes, Trap = 1 byte → total 11 bytes
        assert_eq!(code.len(), 11, "code length should be 11");
        assert_eq!(mask.len(), 2, "11 mask bits should pack into 2 bytes");

        // Check that the opcode is correct
        assert_eq!(code[0], 20, "first byte should be LoadImm64 opcode (20)");
        assert_eq!(code[10], 0, "byte 10 should be Trap opcode (0)");

        // Check mask: bit 0 = 1 (LoadImm64 start), bits 1-9 = 0, bit 10 = 1 (Trap start)
        let mask_bits: Vec<bool> = (0..11).map(|pc| mask_bit(&mask, pc)).collect();

        assert!(mask_bits[0], "PC 0 should be instruction start");
        for (pc, &is_start) in mask_bits.iter().enumerate().take(10).skip(1) {
            assert!(!is_start, "PC {pc} should NOT be instruction start");
        }
        assert!(mask_bits[10], "PC 10 should be instruction start (Trap)");
    }

    #[test]
    fn test_encode_without_jump_table() {
        let blob = ProgramBlob::new(vec![Instruction::Trap]);
        // jump count 0, entry size 0, code len 1, code [Trap], mask [start bit]
        assert_eq!(blob.encode(), vec![0, 0, 1, 0, 0b0000_0001]);
    }

    #[test]
    fn test_encode_with_jump_table() {
        let blob = ProgramBlob::new(vec![Instruction::Trap]).with_jump_table(vec![0x0102_0304]);
        // jump count 1, entry size 4, code len 1, entry LE, code [Trap], mask [start bit]
        assert_eq!(
            blob.encode(),
            vec![1, 4, 1, 0x04, 0x03, 0x02, 0x01, 0, 0b0000_0001]
        );
    }
}