stak_vm/
vm.rs

1#[cfg(feature = "profile")]
2use crate::profiler::Profiler;
3use crate::{
4    Error, Exception, StackSlot,
5    code::{INTEGER_BASE, NUMBER_BASE, SHARE_BASE, TAG_BASE},
6    cons::{Cons, NEVER},
7    instruction::Instruction,
8    memory::Memory,
9    number::Number,
10    primitive_set::PrimitiveSet,
11    r#type::Type,
12    value::{TypedValue, Value},
13};
14#[cfg(feature = "profile")]
15use core::cell::RefCell;
16use core::fmt::{self, Display, Formatter, Write};
17use stak_util::block_on;
18use winter_maybe_async::{maybe_async, maybe_await};
19
/// Prints `label: value` to standard error when the `trace_instruction`
/// feature is enabled. Without the feature the whole statement is compiled
/// out, so `value` is not evaluated.
macro_rules! trace {
    ($label:literal, $value:expr) => {
        #[cfg(feature = "trace_instruction")]
        std::eprintln!("{}: {}", $label, $value);
    };
}
26
/// Dumps a displayable value (the whole VM at call sites) to standard error
/// when the `trace_memory` feature is enabled. Without the feature the
/// statement is compiled out entirely.
macro_rules! trace_memory {
    ($subject:expr) => {
        #[cfg(feature = "trace_memory")]
        std::eprintln!("{}", $subject);
    };
}
33
/// Reports a named event to the VM's profiler and propagates any error with
/// `?`. Only active with the `profile` feature; otherwise the statement is
/// compiled out.
macro_rules! profile_event {
    ($vm:expr, $event:literal) => {
        #[cfg(feature = "profile")]
        (&$vm).profile_event($event)?;
    };
}
40
/// Argument or parameter shape of a call site or procedure.
///
/// Packed into a single integer by `Vm::build_arity` and unpacked by
/// `Vm::parse_arity`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Arity {
    /// Number of fixed arguments; a variadic rest argument is not counted.
    count: usize,
    /// Whether a variadic rest argument (passed as a list) follows.
    variadic: bool,
}
47
48/// A virtual machine.
49pub struct Vm<'a, T: PrimitiveSet> {
50    primitive_set: T,
51    memory: Memory<'a>,
52    #[cfg(feature = "profile")]
53    profiler: Option<RefCell<&'a mut dyn Profiler>>,
54}
55
56// Note that some routines look unnecessarily complicated as we need to mark all
57// volatile variables live across garbage collections.
58impl<'a, T: PrimitiveSet> Vm<'a, T> {
59    /// Creates a virtual machine.
60    pub fn new(heap: &'a mut [Value], primitive_set: T) -> Result<Self, Error> {
61        Ok(Self {
62            primitive_set,
63            memory: Memory::new(heap)?,
64            #[cfg(feature = "profile")]
65            profiler: None,
66        })
67    }
68
69    /// Sets a profiler.
70    #[cfg(feature = "profile")]
71    pub fn with_profiler(self, profiler: &'a mut dyn Profiler) -> Self {
72        Self {
73            profiler: Some(profiler.into()),
74            ..self
75        }
76    }
77
78    /// Returns a reference to a primitive set.
79    pub const fn primitive_set(&self) -> &T {
80        &self.primitive_set
81    }
82
83    /// Returns a mutable reference to a primitive set.
84    pub const fn primitive_set_mut(&mut self) -> &mut T {
85        &mut self.primitive_set
86    }
87
88    /// Runs bytecode on a virtual machine synchronously.
89    ///
90    /// # Panics
91    ///
92    /// Panics if asynchronous operations occur during the run.
93    pub fn run(&mut self) -> Result<(), T::Error> {
94        block_on!(self.run_async())
95    }
96
97    /// Runs bytecode on a virtual machine.
98    #[cfg_attr(not(feature = "async"), doc(hidden))]
99    #[maybe_async]
100    pub fn run_async(&mut self) -> Result<(), T::Error> {
101        while let Err(error) = maybe_await!(self.run_with_continuation()) {
102            if error.is_critical() {
103                return Err(error);
104            }
105
106            let Some(continuation) = self.memory.cdr(self.memory.null()?)?.to_cons() else {
107                return Err(error);
108            };
109
110            if self.memory.cdr(continuation)?.tag() != Type::Procedure as _ {
111                return Err(error);
112            }
113
114            self.memory.set_register(continuation);
115            let string = self.memory.build_string("")?;
116            let symbol = self.memory.allocate(
117                self.memory.register().into(),
118                string.set_tag(Type::Symbol as _).into(),
119            )?;
120            let code = self.memory.allocate(
121                symbol.into(),
122                self.memory
123                    .code()
124                    .set_tag(
125                        Instruction::Call as u16
126                            + Self::build_arity(Arity {
127                                count: 1,
128                                variadic: false,
129                            }) as u16,
130                    )
131                    .into(),
132            )?;
133            self.memory.set_code(code);
134
135            self.memory.set_register(self.memory.null()?);
136            write!(&mut self.memory, "{error}").map_err(Error::from)?;
137            let code = self.memory.allocate(
138                self.memory.register().into(),
139                self.memory
140                    .code()
141                    .set_tag(Instruction::Constant as _)
142                    .into(),
143            )?;
144            self.memory.set_code(code);
145        }
146
147        Ok(())
148    }
149
150    #[maybe_async]
151    fn run_with_continuation(&mut self) -> Result<(), T::Error> {
152        while self.memory.code() != self.memory.null()? {
153            let instruction = self.memory.cdr(self.memory.code())?.assume_cons();
154
155            trace!("instruction", instruction.tag());
156
157            match instruction.tag() {
158                Instruction::CONSTANT => self.constant()?,
159                Instruction::GET => self.get()?,
160                Instruction::SET => self.set()?,
161                Instruction::IF => self.r#if()?,
162                code => maybe_await!(
163                    self.call(instruction, code as usize - Instruction::CALL as usize)
164                )?,
165            }
166
167            self.advance_code()?;
168
169            trace_memory!(self);
170        }
171
172        Ok(())
173    }
174
175    #[inline]
176    fn constant(&mut self) -> Result<(), Error> {
177        let constant = self.operand()?;
178
179        trace!("constant", constant);
180
181        self.memory.push(constant)?;
182
183        Ok(())
184    }
185
186    #[inline]
187    fn get(&mut self) -> Result<(), Error> {
188        let operand = self.operand_cons()?;
189        let value = self.memory.car(operand)?;
190
191        trace!("operand", operand);
192        trace!("value", value);
193
194        self.memory.push(value)?;
195
196        Ok(())
197    }
198
199    #[inline]
200    fn set(&mut self) -> Result<(), Error> {
201        let operand = self.operand_cons()?;
202        let value = self.memory.pop()?;
203
204        trace!("operand", operand);
205        trace!("value", value);
206
207        self.memory.set_car(operand, value)?;
208
209        Ok(())
210    }
211
212    #[inline]
213    fn r#if(&mut self) -> Result<(), Error> {
214        let cons = self.memory.stack();
215
216        if self.memory.pop()? != self.memory.boolean(false)?.into() {
217            self.memory.set_cdr(cons, self.operand()?)?;
218            self.memory.set_code(cons);
219        }
220
221        Ok(())
222    }
223
224    #[inline(always)]
225    #[maybe_async]
226    fn call(&mut self, instruction: Cons, arity: usize) -> Result<(), T::Error> {
227        let procedure = self.procedure()?;
228
229        trace!("procedure", procedure);
230
231        if self.environment(procedure)?.tag() != Type::Procedure as _ {
232            return Err(Error::ProcedureExpected.into());
233        }
234
235        let arguments = Self::parse_arity(arity);
236        let r#return = instruction == self.memory.null()?;
237
238        trace!("return", r#return);
239
240        match self.code(procedure)?.to_typed() {
241            TypedValue::Cons(code) => {
242                #[cfg(feature = "profile")]
243                self.profile_call(self.memory.code(), r#return)?;
244
245                let parameters =
246                    Self::parse_arity(self.memory.car(code)?.assume_number().to_i64() as usize);
247
248                trace!("argument count", arguments.count);
249                trace!("argument variadic", arguments.variadic);
250                trace!("parameter count", parameters.count);
251                trace!("parameter variadic", parameters.variadic);
252
253                self.memory.set_register(procedure);
254
255                let mut list = if arguments.variadic {
256                    self.memory.pop()?.assume_cons()
257                } else {
258                    self.memory.null()?
259                };
260
261                for _ in 0..arguments.count {
262                    let value = self.memory.pop()?;
263                    list = self.memory.cons(value, list)?;
264                }
265
266                // Use a `code` field as an escape cell for a procedure.
267                let code = self.memory.code();
268                self.memory.set_code(self.memory.register());
269                self.memory.set_register(list);
270
271                let continuation = if r#return {
272                    self.continuation()?
273                } else {
274                    self.memory
275                        .allocate(code.into(), self.memory.stack().into())?
276                };
277                let stack = self.memory.allocate(
278                    continuation.into(),
279                    self.environment(self.memory.code())?
280                        .set_tag(StackSlot::Frame as _)
281                        .into(),
282                )?;
283                self.memory.set_stack(stack);
284                self.memory
285                    .set_code(self.code(self.memory.code())?.assume_cons());
286
287                for _ in 0..parameters.count {
288                    if self.memory.register() == self.memory.null()? {
289                        return Err(Error::ArgumentCount.into());
290                    }
291
292                    self.memory.push(self.memory.car(self.memory.register())?)?;
293                    self.memory
294                        .set_register(self.memory.cdr(self.memory.register())?.assume_cons());
295                }
296
297                if parameters.variadic {
298                    self.memory.push(self.memory.register().into())?;
299                } else if self.memory.register() != self.memory.null()? {
300                    return Err(Error::ArgumentCount.into());
301                }
302            }
303            TypedValue::Number(primitive) => {
304                if arguments.variadic {
305                    let list = self.memory.pop()?.assume_cons();
306                    self.memory.set_register(list);
307
308                    while self.memory.register() != self.memory.null()? {
309                        self.memory.push(self.memory.car(self.memory.register())?)?;
310                        self.memory
311                            .set_register(self.memory.cdr(self.memory.register())?.assume_cons());
312                    }
313                }
314
315                maybe_await!(
316                    self.primitive_set
317                        .operate(&mut self.memory, primitive.to_i64() as _)
318                )?;
319            }
320        }
321
322        Ok(())
323    }
324
325    #[inline]
326    const fn parse_arity(info: usize) -> Arity {
327        Arity {
328            count: info / 2,
329            variadic: info % 2 == 1,
330        }
331    }
332
333    #[inline]
334    const fn build_arity(arity: Arity) -> usize {
335        2 * arity.count + arity.variadic as usize
336    }
337
338    #[inline]
339    fn advance_code(&mut self) -> Result<(), Error> {
340        let mut code = self.memory.cdr(self.memory.code())?.assume_cons();
341
342        if code == self.memory.null()? {
343            #[cfg(feature = "profile")]
344            self.profile_return()?;
345
346            let continuation = self.continuation()?;
347            // Keep a value at the top of a stack.
348            self.memory
349                .set_cdr(self.memory.stack(), self.memory.cdr(continuation)?)?;
350
351            code = self
352                .memory
353                .cdr(self.memory.car(continuation)?.assume_cons())?
354                .assume_cons();
355        }
356
357        self.memory.set_code(code);
358
359        Ok(())
360    }
361
362    fn operand(&self) -> Result<Value, Error> {
363        self.memory.car(self.memory.code())
364    }
365
366    fn operand_cons(&self) -> Result<Cons, Error> {
367        Ok(match self.operand()?.to_typed() {
368            TypedValue::Cons(cons) => cons,
369            TypedValue::Number(index) => {
370                self.memory.tail(self.memory.stack(), index.to_i64() as _)?
371            }
372        })
373    }
374
375    // (code . environment)
376    fn procedure(&self) -> Result<Cons, Error> {
377        Ok(self.memory.car(self.operand_cons()?)?.assume_cons())
378    }
379
380    // (parameter-count . instruction-list) | primitive-id
381    fn code(&self, procedure: Cons) -> Result<Value, Error> {
382        self.memory.car(procedure)
383    }
384
385    fn environment(&self, procedure: Cons) -> Result<Cons, Error> {
386        Ok(self.memory.cdr(procedure)?.assume_cons())
387    }
388
389    // (code . stack)
390    fn continuation(&self) -> Result<Cons, Error> {
391        let mut stack = self.memory.stack();
392
393        while self.memory.cdr(stack)?.assume_cons().tag() != StackSlot::Frame as _ {
394            stack = self.memory.cdr(stack)?.assume_cons();
395        }
396
397        Ok(self.memory.car(stack)?.assume_cons())
398    }
399
400    // Profiling
401
402    #[cfg(feature = "profile")]
403    fn profile_call(&self, call_code: Cons, r#return: bool) -> Result<(), Error> {
404        if let Some(profiler) = &self.profiler {
405            profiler
406                .borrow_mut()
407                .profile_call(&self.memory, call_code, r#return)?;
408        }
409
410        Ok(())
411    }
412
413    #[cfg(feature = "profile")]
414    fn profile_return(&self) -> Result<(), Error> {
415        if let Some(profiler) = &self.profiler {
416            profiler.borrow_mut().profile_return(&self.memory)?;
417        }
418
419        Ok(())
420    }
421
422    #[cfg(feature = "profile")]
423    fn profile_event(&self, name: &str) -> Result<(), Error> {
424        if let Some(profiler) = &self.profiler {
425            profiler.borrow_mut().profile_event(name)?;
426        }
427
428        Ok(())
429    }
430
431    /// Initializes a virtual machine with bytecode of a program.
432    pub fn initialize(&mut self, input: impl IntoIterator<Item = u8>) -> Result<(), super::Error> {
433        profile_event!(self, "initialization_start");
434        profile_event!(self, "decode_start");
435
436        let program = self.decode_ribs(&mut input.into_iter())?;
437        self.memory
438            .set_false(self.memory.car(program)?.assume_cons());
439        self.memory
440            .set_code(self.memory.cdr(program)?.assume_cons());
441
442        profile_event!(self, "decode_end");
443
444        // Initialize an implicit top-level frame.
445        let codes = self
446            .memory
447            .cons(Number::default().into(), self.memory.null()?)?
448            .into();
449        let continuation = self.memory.cons(codes, self.memory.null()?)?.into();
450        let stack = self.memory.allocate(
451            continuation,
452            self.memory.null()?.set_tag(StackSlot::Frame as _).into(),
453        )?;
454        self.memory.set_stack(stack);
455        self.memory.set_register(NEVER);
456
457        profile_event!(self, "initialization_end");
458
459        Ok(())
460    }
461
462    fn decode_ribs(&mut self, input: &mut impl Iterator<Item = u8>) -> Result<Cons, Error> {
463        while let Some(head) = input.next() {
464            if head & 1 == 0 {
465                let cdr = self.memory.top()?;
466                let cons = self
467                    .memory
468                    .allocate(Number::from_i64((head >> 1) as _).into(), cdr)?;
469                self.memory.set_top(cons.into())?;
470            } else if head & 0b10 == 0 {
471                let head = head >> 2;
472
473                if head == 0 {
474                    let value = self.memory.top()?;
475                    let cons = self.memory.cons(value, self.memory.code())?;
476                    self.memory.set_code(cons);
477                } else {
478                    let integer = Self::decode_integer_tail(input, head - 1, SHARE_BASE)?;
479                    let index = integer >> 1;
480
481                    if index > 0 {
482                        let cons = self.memory.tail(self.memory.code(), index as usize - 1)?;
483                        let head = self.memory.cdr(cons)?.assume_cons();
484                        let tail = self.memory.cdr(head)?;
485                        self.memory.set_cdr(head, self.memory.code().into())?;
486                        self.memory.set_cdr(cons, tail)?;
487                        self.memory.set_code(head);
488                    }
489
490                    let value = self.memory.car(self.memory.code())?;
491
492                    if integer & 1 == 0 {
493                        self.memory
494                            .set_code(self.memory.cdr(self.memory.code())?.assume_cons());
495                    }
496
497                    self.memory.push(value)?;
498                }
499            } else if head & 0b100 == 0 {
500                let cons = self.memory.stack();
501                let cdr = self.memory.pop()?;
502                let car = self.memory.top()?;
503                let tag = Self::decode_integer_tail(input, head >> 3, TAG_BASE)?;
504                self.memory.set_car(cons, car)?;
505                self.memory.set_raw_cdr(cons, cdr.set_tag(tag as _))?;
506                self.memory.set_top(cons.into())?;
507            } else {
508                self.memory.push(
509                    Self::decode_number(Self::decode_integer_tail(input, head >> 3, NUMBER_BASE)?)
510                        .into(),
511                )?;
512            }
513        }
514
515        self.memory.pop()?.to_cons().ok_or(Error::BytecodeEnd)
516    }
517
518    fn decode_number(integer: u128) -> Number {
519        if integer & 1 == 0 {
520            Number::from_i64((integer >> 1) as _)
521        } else if integer & 0b10 == 0 {
522            Number::from_i64(-((integer >> 2) as i64))
523        } else {
524            let integer = integer >> 2;
525            let mantissa =
526                if integer.is_multiple_of(2) { 1.0 } else { -1.0 } * (integer >> 12) as f64;
527            let exponent = ((integer >> 1) % (1 << 11)) as isize - 1023;
528
529            Number::from_f64(if exponent < 0 {
530                mantissa / (1u64 << exponent.abs()) as f64
531            } else {
532                mantissa * (1u64 << exponent) as f64
533            })
534        }
535    }
536
537    fn decode_integer_tail(
538        input: &mut impl Iterator<Item = u8>,
539        mut x: u8,
540        mut base: u128,
541    ) -> Result<u128, Error> {
542        let mut y = (x >> 1) as u128;
543
544        while x & 1 != 0 {
545            x = input.next().ok_or(Error::BytecodeEnd)?;
546            y += (x as u128 >> 1) * base;
547            base *= INTEGER_BASE;
548        }
549
550        Ok(y)
551    }
552}
553
554impl<T: PrimitiveSet> Display for Vm<'_, T> {
555    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
556        write!(formatter, "{}", &self.memory)
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Primitive set whose every primitive is a successful no-op.
    struct FakePrimitiveSet {}

    impl PrimitiveSet for FakePrimitiveSet {
        type Error = Error;

        #[maybe_async]
        fn operate(
            &mut self,
            _memory: &mut Memory<'_>,
            _primitive: usize,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    type VoidVm = Vm<'static, FakePrimitiveSet>;

    /// Encoding then decoding an arity yields the original for small counts,
    /// with and without a variadic argument.
    #[test]
    fn arity() {
        let cases = [false, true]
            .into_iter()
            .flat_map(|variadic| (0..3).map(move |count| Arity { count, variadic }));

        for case in cases {
            let encoded = VoidVm::build_arity(case);
            assert_eq!(VoidVm::parse_arity(encoded), case);
        }
    }
}