Skip to main content

symjit/
instruction.rs

1use num_complex::Complex;
2use serde::Deserialize;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
5pub struct BuiltinSymbol(pub u32);
6
7impl<'de> serde::Deserialize<'de> for BuiltinSymbol {
8    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
9        let id: u32 = u32::deserialize(deserializer)?;
10        Ok(BuiltinSymbol(id))
11    }
12}
13
/// A storage location named by an instruction, either as an operand or as a destination.
///
/// Every variant carries a `usize` index into the storage area it names.
///
/// NOTE(review): because `Deserialize` is derived, the variant names (and, for
/// index-based formats, the variant order) are part of the input format. Do not
/// rename or reorder variants without updating the producer of the data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
pub enum Slot {
    /// An entry in the list of parameters.
    Param(usize),
    /// An entry in the list of constants.
    Const(usize),
    /// An entry in the list of temporary storage.
    Temp(usize),
    /// An entry in the list of results.
    Out(usize),
    /// A slot in static-single-assignment (SSA) form.
    ///
    /// NOTE(review): presumably a value that is written exactly once; its storage
    /// area and lifetime are not visible in this file. TODO confirm.
    Static(usize),
    /// NOTE(review): undocumented upstream. The semantics of this index (and how it
    /// differs from `Param`) are not visible in this file. TODO document.
    Arg(usize),
}
28
/// A single instruction of the program stored in a `SymbolicaModel`.
///
/// Operands and destinations are `Slot`s. Control flow is expressed with `IfElse`,
/// `Goto` and `Label`, which identify jump targets by a `usize` label id.
#[derive(Debug, Clone, Deserialize)]
pub enum Instruction {
    /// `Add(o, [i0,...,i_n], k)` means `o = i0 + ... + i_n`.
    ///
    /// NOTE(review): the meaning of the trailing `usize` (`k`) is not established in
    /// this file. TODO confirm against the producer of the model.
    Add(Slot, Vec<Slot>, usize),
    /// `Mul(o, [i0,...,i_n], k)` means `o = i0 * ... * i_n`.
    ///
    /// NOTE(review): as with `Add`, the meaning of the trailing `usize` (`k`) is not
    /// established in this file. TODO confirm.
    Mul(Slot, Vec<Slot>, usize),
    /// `Pow(o, b, e, flag)` means `o = b^e` with a constant integer exponent `e`.
    ///
    /// NOTE(review): `flag` is undocumented. It is presumably an `is_real` hint like
    /// the one on `Fun`. Confirm.
    Pow(Slot, Slot, i64, bool),
    /// `Powf(o, b, e, flag)` means `o = b^e`, where the exponent `e` is read from a slot.
    ///
    /// NOTE(review): `flag` is undocumented. It is presumably an `is_real` hint like
    /// the one on `Fun`. Confirm.
    Powf(Slot, Slot, Slot, bool),
    /// A call to a function that has a known evaluator or is external.
    ///
    /// The unsimplified form of this instruction is
    /// `Fun(Slot, Box<(Symbol, Vec<String>, Vec<Slot>)>, bool)`. There,
    /// `Fun(o, (s, t, a), is_real)` means `o = s(t, a)` for a symbol `s`, tags `t`
    /// and arguments `a`. The `is_real` flag indicates whether the function is
    /// expected to yield a real number.
    ///
    /// Symjit uses the simplified form `Fun(o, s, a, is_real)`, i.e. `o = s(a)`.
    /// NOTE(review): the `String` is presumably the function name, with the tags
    /// dropped. Confirm.
    Fun(Slot, String, Vec<Slot>, bool),
    /// `ExternalFun(o, s, a)` means `o = s(a)`, where `s` names an external function
    /// and `a` is the list of argument slots.
    ExternalFun(Slot, String, Vec<Slot>),
    /// `Assign(o, v)` means `o = v`.
    Assign(Slot, Slot),
    /// `IfElse(cond, label)` means jump to `label` if `cond` is zero.
    IfElse(Slot, usize),
    /// Unconditional jump to `label`.
    Goto(usize),
    /// A position in the instruction list to jump to.
    ///
    /// NOTE(review): presumably matched against the label ids used by `IfElse` and
    /// `Goto`. Confirm.
    Label(usize),
    /// `Join(o, cond, t, f)` means `o = cond ? t : f`.
    Join(Slot, Slot, Slot, Slot),
}
59
60#[derive(Debug, Clone, Deserialize)]
61pub enum Value {
62    Single(f64),
63}
64
65impl Value {
66    fn value(&self) -> f64 {
67        let Value::Single(x) = self;
68        *x
69    }
70}
71
72#[derive(Debug, Clone, Deserialize)]
73pub struct Rational {
74    pub numerator: Value,
75    pub denominator: Value,
76}
77
78impl Rational {
79    fn value(&self) -> f64 {
80        self.numerator.value() / self.denominator.value()
81    }
82}
83
84#[derive(Debug, Clone, Deserialize)]
85pub struct ComplexRational {
86    pub re: Rational,
87    pub im: Rational,
88}
89
90impl ComplexRational {
91    fn value(&self) -> Complex<f64> {
92        Complex::new(self.re.value(), self.im.value())
93    }
94}
95
96#[derive(Debug, Clone, Deserialize)]
97#[serde(untagged)]
98pub enum ConstType {
99    Complex(ComplexRational),
100    Single(f64),
101}
102
103impl ConstType {
104    pub fn value(&self) -> Complex<f64> {
105        match self {
106            ConstType::Single(x) => Complex::new(*x, 0.0),
107            ConstType::Complex(x) => x.value(),
108        }
109    }
110}
111
/// A deserialized evaluation model, stored as a tuple struct of three fields.
///
/// Because `Deserialize` is derived on a tuple struct, the input is read as a
/// three-element sequence (e.g. a JSON array), in this field order:
///
/// - `.0`: the instruction list.
/// - `.1`: NOTE(review): its meaning is not visible in this file. It is presumably a
///   slot count (e.g. number of temporaries) needed to run the instructions. Confirm
///   against the producer of the model.
/// - `.2`: the constant table. NOTE(review): presumably indexed by `Slot::Const(i)`.
///   Confirm.
#[derive(Debug, Clone, Deserialize)]
pub struct SymbolicaModel(pub Vec<Instruction>, pub usize, pub Vec<ConstType>);