1use std::collections::HashMap;
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum VMError {
6 #[error("Stack underflow")]
7 StackUnderflow,
8 #[error("Stack overflow")]
9 StackOverflow,
10 #[error("Invalid opcode: {0}")]
11 InvalidOpcode(u8),
12 #[error("Out of memory at address: {0}")]
13 OutOfMemory(usize),
14 #[error("Division by zero")]
15 DivisionByZero,
16}
17
/// Bytecode instruction set; each variant's discriminant is its encoded byte.
///
/// `#[repr(u8)]` makes the one-byte encoding explicit, so `op as u8` is the
/// exact byte written into a program. Only `Push` carries an operand: 8
/// little-endian bytes following the opcode. Every other instruction is a
/// single byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Opcode {
    /// Push the following 8-byte little-endian `i64` operand.
    Push = 0x01,
    /// Discard the top of the stack.
    Pop = 0x02,
    /// Pop `b`, then `a`; push `a + b`.
    Add = 0x03,
    /// Pop `b`, then `a`; push `a - b`.
    Sub = 0x04,
    /// Pop `b`, then `a`; push `a * b`.
    Mul = 0x05,
    /// Pop `b`, then `a`; push `a / b` (error if `b == 0`).
    Div = 0x06,
    /// Pop an address; push the value stored there (`0` if never stored).
    Load = 0x07,
    /// Pop a value, then an address; store the value at that address.
    Store = 0x08,
    /// Pop an address and continue execution there.
    Jump = 0x09,
    /// Pop an address, then a condition; jump if the condition is non-zero.
    JumpIf = 0x0A,
    /// Pop `b`, then `a`; push `1` if `a == b`, else `0`.
    Equal = 0x0B,
    /// Pop `b`, then `a`; push `1` if `a < b`, else `0`.
    Less = 0x0C,
    /// Pop a value and print it to stdout.
    Print = 0x0D,
    /// Stop execution.
    Halt = 0xFF,
}
35
36impl TryFrom<u8> for Opcode {
37 type Error = VMError;
38
39 fn try_from(value: u8) -> Result<Self, Self::Error> {
40 match value {
41 0x01 => Ok(Opcode::Push),
42 0x02 => Ok(Opcode::Pop),
43 0x03 => Ok(Opcode::Add),
44 0x04 => Ok(Opcode::Sub),
45 0x05 => Ok(Opcode::Mul),
46 0x06 => Ok(Opcode::Div),
47 0x07 => Ok(Opcode::Load),
48 0x08 => Ok(Opcode::Store),
49 0x09 => Ok(Opcode::Jump),
50 0x0A => Ok(Opcode::JumpIf),
51 0x0B => Ok(Opcode::Equal),
52 0x0C => Ok(Opcode::Less),
53 0x0D => Ok(Opcode::Print),
54 0xFF => Ok(Opcode::Halt),
55 _ => Err(VMError::InvalidOpcode(value)),
56 }
57 }
58}
59
/// A stack-based bytecode interpreter.
///
/// Bytecode is a flat byte vector. `Push` is followed by an 8-byte
/// little-endian `i64` operand, and every other instruction is a single byte.
#[derive(Debug)]
pub struct VM {
    /// Byte offset of the next instruction to fetch.
    pc: usize,
    /// Operand stack; pushes beyond `stack_limit` entries are rejected.
    stack: Vec<i64>,
    /// The bytecode being executed.
    program: Vec<u8>,
    /// Sparse data memory; addresses never stored to read as `0`.
    memory: HashMap<usize, i64>,
    /// Maximum number of values the operand stack may hold.
    stack_limit: usize,
    /// Set by `run` and cleared by `Halt`.
    running: bool,
}
74
75impl VM {
76 pub fn new(program: Vec<u8>, stack_limit: usize) -> Self {
77 VM {
78 pc: 0,
79 stack: Vec::with_capacity(stack_limit),
80 program,
81 memory: HashMap::new(),
82 stack_limit,
83 running: false,
84 }
85 }
86
87 fn push(&mut self, value: i64) -> Result<(), VMError> {
88 if self.stack.len() >= self.stack_limit {
89 return Err(VMError::StackOverflow);
90 }
91 self.stack.push(value);
92 Ok(())
93 }
94
95 fn pop(&mut self) -> Result<i64, VMError> {
96 self.stack.pop().ok_or(VMError::StackUnderflow)
97 }
98
99 fn fetch(&mut self) -> Option<u8> {
100 if self.pc < self.program.len() {
101 let opcode = self.program[self.pc];
102 self.pc += 1;
103 Some(opcode)
104 } else {
105 None
106 }
107 }
108
109 fn fetch_i64(&mut self) -> Option<i64> {
110 if self.pc + 8 <= self.program.len() {
111 let bytes = &self.program[self.pc..self.pc + 8];
112 self.pc += 8;
113 Some(i64::from_le_bytes(bytes.try_into().unwrap()))
114 } else {
115 None
116 }
117 }
118
119 pub fn execute_next(&mut self) -> Result<bool, VMError> {
120 let opcode = self.fetch().ok_or(VMError::InvalidOpcode(0))?;
121 match Opcode::try_from(opcode)? {
122 Opcode::Push => {
123 let value = self.fetch_i64().ok_or(VMError::InvalidOpcode(opcode))?;
124 self.push(value)?;
125 }
126 Opcode::Pop => {
127 self.pop()?;
128 }
129 Opcode::Add => {
130 let b = self.pop()?;
131 let a = self.pop()?;
132 self.push(a + b)?;
133 }
134 Opcode::Sub => {
135 let b = self.pop()?;
136 let a = self.pop()?;
137 self.push(a - b)?;
138 }
139 Opcode::Mul => {
140 let b = self.pop()?;
141 let a = self.pop()?;
142 self.push(a * b)?;
143 }
144 Opcode::Div => {
145 let b = self.pop()?;
146 let a = self.pop()?;
147 if b == 0 {
148 return Err(VMError::DivisionByZero);
149 }
150 self.push(a / b)?;
151 }
152 Opcode::Load => {
153 let addr = self.pop()? as usize;
154 let value = *self.memory.get(&addr).unwrap_or(&0);
155 self.push(value)?;
156 }
157 Opcode::Store => {
158 let value = self.pop()?;
159 let addr = self.pop()? as usize;
160 self.memory.insert(addr, value);
161 }
162 Opcode::Jump => {
163 let addr = self.pop()? as usize;
164 if addr >= self.program.len() {
165 return Err(VMError::OutOfMemory(addr));
166 }
167 self.pc = addr;
168 }
169 Opcode::JumpIf => {
170 let addr = self.pop()? as usize;
171 let condition = self.pop()?;
172 if condition != 0 {
173 if addr >= self.program.len() {
174 return Err(VMError::OutOfMemory(addr));
175 }
176 self.pc = addr;
177 }
178 }
179 Opcode::Equal => {
180 let b = self.pop()?;
181 let a = self.pop()?;
182 self.push(if a == b { 1 } else { 0 })?;
183 }
184 Opcode::Less => {
185 let b = self.pop()?;
186 let a = self.pop()?;
187 self.push(if a < b { 1 } else { 0 })?;
188 }
189 Opcode::Print => {
190 let value = self.pop()?;
191 println!("Output: {}", value);
192 }
193 Opcode::Halt => {
194 self.running = false;
195 return Ok(false);
196 }
197 }
198 Ok(true)
199 }
200
201 pub fn run(&mut self) -> Result<(), VMError> {
202 self.running = true;
203 while self.running {
204 if !self.execute_next()? {
205 break;
206 }
207 }
208 Ok(())
209 }
210
211 pub fn get_stack(&self) -> &[i64] {
212 &self.stack
213 }
214
215 pub fn get_memory(&self) -> &HashMap<usize, i64> {
216 &self.memory
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `Push value`: the opcode byte followed by the 8-byte
    /// little-endian operand the VM decodes.
    fn push(value: i64) -> Vec<u8> {
        let mut bytes = vec![Opcode::Push as u8];
        bytes.extend_from_slice(&value.to_le_bytes());
        bytes
    }

    /// Encodes an operand-less instruction.
    fn op(opcode: Opcode) -> Vec<u8> {
        vec![opcode as u8]
    }

    /// Runs `program` with a generous stack limit; returns the VM and the result.
    fn run_program(program: Vec<u8>) -> (VM, Result<(), VMError>) {
        let mut vm = VM::new(program, 100);
        let result = vm.run();
        (vm, result)
    }

    /// Builds `push cond; push target; jumpif; push 111; halt; target: push 222; halt`.
    fn branch_program(condition: i64) -> Vec<u8> {
        let fallthrough = [push(111), op(Opcode::Halt)].concat();
        let taken = [push(222), op(Opcode::Halt)].concat();
        // The header is two Push instructions plus the one-byte JumpIf.
        let header_len = 2 * push(0).len() + 1;
        let target = (header_len + fallthrough.len()) as i64;
        [push(condition), push(target), op(Opcode::JumpIf), fallthrough, taken].concat()
    }

    #[test]
    fn test_push_pop() {
        let program = [push(42), push(123), op(Opcode::Pop), op(Opcode::Halt)].concat();
        let (vm, result) = run_program(program);
        result.unwrap();
        assert_eq!(vm.get_stack(), &[42]);
    }

    #[test]
    fn test_arithmetic() {
        let program = [
            push(10),
            push(5),
            op(Opcode::Add),
            push(2),
            op(Opcode::Mul),
            op(Opcode::Halt),
        ]
        .concat();
        let (vm, result) = run_program(program);
        result.unwrap();
        assert_eq!(vm.get_stack(), &[30]);
    }

    #[test]
    fn test_sub_and_div_operand_order() {
        // (3 - 10) / 2 == -7 / 2 == -3 (truncating toward zero).
        let program = [
            push(3),
            push(10),
            op(Opcode::Sub),
            push(2),
            op(Opcode::Div),
            op(Opcode::Halt),
        ]
        .concat();
        let (vm, result) = run_program(program);
        result.unwrap();
        assert_eq!(vm.get_stack(), &[-3]);
    }

    #[test]
    fn test_division_by_zero() {
        let program = [push(1), push(0), op(Opcode::Div), op(Opcode::Halt)].concat();
        let (_, result) = run_program(program);
        assert!(matches!(result, Err(VMError::DivisionByZero)));
    }

    #[test]
    fn test_stack_underflow() {
        let (_, result) = run_program(op(Opcode::Add));
        assert!(matches!(result, Err(VMError::StackUnderflow)));
    }

    #[test]
    fn test_stack_overflow() {
        let program = [push(1), push(2), op(Opcode::Halt)].concat();
        let mut vm = VM::new(program, 1);
        assert!(matches!(vm.run(), Err(VMError::StackOverflow)));
        assert_eq!(vm.get_stack(), &[1]);
    }

    #[test]
    fn test_invalid_opcode() {
        let (_, result) = run_program(vec![0x42]);
        assert!(matches!(result, Err(VMError::InvalidOpcode(0x42))));
    }

    #[test]
    fn test_store_and_load() {
        let program = [
            push(7),
            push(99),
            op(Opcode::Store),
            push(7),
            op(Opcode::Load),
            push(8),
            op(Opcode::Load),
            op(Opcode::Halt),
        ]
        .concat();
        let (vm, result) = run_program(program);
        result.unwrap();
        // Address 8 was never stored to, so it loads as 0.
        assert_eq!(vm.get_stack(), &[99, 0]);
        assert_eq!(vm.get_memory().get(&7), Some(&99));
    }

    #[test]
    fn test_comparisons() {
        let program = [
            push(3),
            push(3),
            op(Opcode::Equal),
            push(5),
            push(2),
            op(Opcode::Less),
            push(2),
            push(5),
            op(Opcode::Less),
            op(Opcode::Halt),
        ]
        .concat();
        let (vm, result) = run_program(program);
        result.unwrap();
        assert_eq!(vm.get_stack(), &[1, 0, 1]);
    }

    #[test]
    fn test_conditional_jump() {
        let (taken, result) = run_program(branch_program(1));
        result.unwrap();
        assert_eq!(taken.get_stack(), &[222]);

        let (not_taken, result) = run_program(branch_program(0));
        result.unwrap();
        assert_eq!(not_taken.get_stack(), &[111]);
    }

    #[test]
    fn test_jump_out_of_bounds() {
        let program = [push(1000), op(Opcode::Jump)].concat();
        let (_, result) = run_program(program);
        assert!(matches!(result, Err(VMError::OutOfMemory(1000))));
    }
}