Skip to main content

symjit/
compiler.rs

1use std::collections::{HashMap, HashSet};
2
3use anyhow::{anyhow, Result};
4use num_complex::Complex;
5use rayon::prelude::*;
6
7use crate::code::VirtualTable;
8use crate::config::{Config, SLICE_CAP};
9use crate::defuns::Defuns;
10use crate::expr::Expr;
11pub use crate::instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
12use crate::model::{CellModel, Equation, Program, Variable};
13use crate::parser::Parser;
14use crate::symbol::Loc;
15use crate::types::{ElemType, Element};
16use crate::utils::{Compiled, CompiledFunc};
17use crate::Application;
18
19// #[derive(Debug)]
20pub struct Compiler {
21    config: Config,
22    df: Defuns,
23}
24
/// Portable stand-in for the AVX `__m256d` name on non-x86_64 targets:
/// four `f64` lanes stored as a plain array.
#[allow(non_camel_case_types)]
#[cfg(not(target_arch = "x86_64"))]
type __m256d = [f64; 4];
28
29/// The central hub of the Rust interface. It compiles a list of
30/// variables and expressions into a callable object (of type `Application`).
31///
32/// # Workflow
33///
34/// 1. Create terminals (variables and constants) and compose expressions using `Expr` methods:
35///    * Constructors: `var`, `from`, `unary`, `binary`, ...
36///    * Standard algebraic operations: `add`, `mul`, ...
37///    * Standard operators `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `!`.
38///    * Unary functions such as `sin`, `exp`, and other standard mathematical functions.
39///    * Binary functions such as `pow`, `min`, ...
40///    * IfElse operation `ifelse(cond, true_val, false_val)`.
41///    * Heavide function: `heaviside(x)`, which returns 1 if `x >= 0`; otherwise 0.
42///    * Comparison methods `eq`, `ne`, `lt`, `le`, `gt`, and `ge`.
43///    * Looping constructs `sum` and `prod`.
44/// 2. Create a new `Compiler` object (say, `comp`) using one of its constructors: `new()`
45///    or `with_compile_type(ty: CompilerType)`.
46/// 3. Fine-tune the optimization passes using `opt_level`, `simd`, `fastmath`,
47///    and `cse` methods (optional).
48/// 4. Define user-defined functions by called `comp.def_unary` and `comp.def_binary`
49///    (optional).
50/// 5. Compile by calling `comp.compile` or `comp.compile_params`. The result is of
51///    type `Application` (say, `app`).
52/// 6. Execute the compiled code using one of the `app`'s `call` functions:
53///    * `call(&[f64])`: scalar call.
54///    * `call_params(&[f64], &[f64])`: scalar call with parameters.
55///    * `call_simd(&[__m256d])`: simd call.
56///    * `call_simd_params(&[__m256d], &[f64])`: simd call with parameters.
57/// 7. Optionally, generate a standalone fast function to execute.
58///
59///
60/// # Examples
61///
62/// ```rust
63/// use anyhow::Result;
64/// use symjit::{Compiler, Expr};
65///
66/// pub fn main() -> Result<()> {
67///     let x = Expr::var("x");
68///     let y = Expr::var("y");
69///     let u = &x + &y;
70///     let v = &x * &y;
71///
72///     let mut config = Config::default();
73///     config.set_opt_level(2);
74///     let mut comp = Compiler::with_config(config);
75///     let mut app = comp.compile(&[x, y], &[u, v])?;
76///     let res = app.call(&[3.0, 5.0]);
77///     println!("{:?}", &res);
78///
79///     Ok(())
80/// }
81/// ```
82impl Compiler {
83    /// Creates a new `Compiler` object with default settings.
84    pub fn new() -> Compiler {
85        Compiler {
86            config: Config::default(),
87            df: Defuns::new(),
88        }
89    }
90
91    pub fn with_config(config: Config) -> Compiler {
92        Compiler {
93            config,
94            df: Defuns::new(),
95        }
96    }
97
98    /// Compiles a model.
99    ///
100    /// `states` is a list of variables, created by `Expr::var`.
101    /// `obs` is a list of expressions.
102    pub fn compile(&mut self, states: &[Expr], obs: &[Expr]) -> Result<Application> {
103        self.compile_params(states, obs, &[])
104    }
105
106    /// Compiles a model with parameters.
107    ///
108    /// `states` is a list of variables, created by `Expr::var`.
109    /// `obs` is a list of expressions.
110    /// `params` is a list of parameters, created by `Expr::var`.
111    ///
112    /// Note: for scalar functions, the difference between states and params
113    ///     is mostly by convenion. However, they are different in SIMD cases,
114    ///     as params are always f64.
115    pub fn compile_params(
116        &mut self,
117        states: &[Expr],
118        obs: &[Expr],
119        params: &[Expr],
120    ) -> Result<Application> {
121        let mut vars: Vec<Variable> = Vec::new();
122
123        for state in states.iter() {
124            let v = state.to_variable()?;
125            vars.push(v);
126        }
127
128        let mut ps: Vec<Variable> = Vec::new();
129
130        for p in params.iter() {
131            let v = p.to_variable()?;
132            ps.push(v);
133        }
134
135        let mut eqs: Vec<Equation> = Vec::new();
136
137        for (i, expr) in obs.iter().enumerate() {
138            let name = format!("${}", i);
139            let lhs = Expr::var(&name);
140            eqs.push(Expr::equation(&lhs, expr));
141        }
142
143        let ml = CellModel {
144            iv: Expr::var("$_").to_variable()?,
145            params: ps,
146            states: vars,
147            algs: Vec::new(),
148            odes: Vec::new(),
149            obs: eqs,
150        };
151
152        let prog = Program::new(&ml, self.config)?;
153        // let df = Defuns::new();
154        let mut app = Application::new(prog, HashSet::new(), std::mem::take(&mut self.df))?;
155        // app.prepare_simd();
156
157        // #[cfg(target_arch = "aarch64")]
158        // if let Ok(app) = &app {
159        //     // this is a hack to give enough delay to prevent a bus error
160        //     app.dump("dump.bin", "scalar");
161        //     std::fs::remove_file("dump.bin")?;
162        // };
163
164        Ok(app)
165    }
166
167    /// Registers a user-defined unary function.
168    pub fn def_unary(&mut self, op: &str, f: extern "C" fn(f64) -> f64) {
169        self.df.add_unary(op, f)
170    }
171
172    /// Registers a user-defined binary function.
173    pub fn def_binary(&mut self, op: &str, f: extern "C" fn(f64, f64) -> f64) {
174        self.df.add_binary(op, f)
175    }
176}
177
178pub enum FastFunc<'a> {
179    F1(extern "C" fn(f64) -> f64, &'a Application),
180    F2(extern "C" fn(f64, f64) -> f64, &'a Application),
181    F3(extern "C" fn(f64, f64, f64) -> f64, &'a Application),
182    F4(extern "C" fn(f64, f64, f64, f64) -> f64, &'a Application),
183    F5(
184        extern "C" fn(f64, f64, f64, f64, f64) -> f64,
185        &'a Application,
186    ),
187    F6(
188        extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64,
189        &'a Application,
190    ),
191    F7(
192        extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64,
193        &'a Application,
194    ),
195    F8(
196        extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64,
197        &'a Application,
198    ),
199}
200
201impl Application {
202    /// Calls the compiled function.
203    ///
204    /// `args` is a slice of f64 values, corresponding to the states.
205    ///
206    /// The output is a `Vec<f64>`, corresponding to the observables (the expressions passed
207    /// to `compile`).
208    pub fn call(&mut self, args: &[f64]) -> Vec<f64> {
209        if let Some(f) = &mut self.compiled {
210            {
211                let mem = f.mem_mut();
212                let states = &mut mem[self.first_state..self.first_state + self.count_states];
213                states.copy_from_slice(args);
214            }
215
216            f.exec(&self.params[..]);
217
218            let obs = {
219                let mem = f.mem();
220                &mem[self.first_obs..self.first_obs + self.count_obs]
221            };
222
223            obs.to_vec()
224        } else {
225            Vec::new()
226        }
227    }
228
229    /// Sets the params and calls the compiled function.
230    ///
231    /// `args` is a slice of f64 values, corresponding to the states.
232    /// `params` is a slice of f64 values, corresponding to the params.
233    ///
234    /// The output is a `Vec<f64>`, corresponding to the observables (the expressions passed
235    /// to `compile`).
236    pub fn call_params(&mut self, args: &[f64], params: &[f64]) -> Vec<f64> {
237        if let Some(f) = &mut self.compiled {
238            {
239                let mem = f.mem_mut();
240                let states = &mut mem[self.first_state..self.first_state + self.count_states];
241                states.copy_from_slice(args);
242            }
243
244            f.exec(params);
245
246            let obs = {
247                let mem = f.mem();
248                &mem[self.first_obs..self.first_obs + self.count_obs]
249            };
250
251            obs.to_vec()
252        } else {
253            Vec::new()
254        }
255    }
256
257    pub fn interpret<T>(&mut self, args: &[T], outs: &mut [T])
258    where
259        T: Element,
260    {
261        let args = recast_as_f64(args);
262        let outs = recast_as_f64_mut(outs);
263
264        let mut regs = [0.0; 32];
265        self.bytecode
266            .mir
267            .exec_instruction(outs, &mut self.bytecode.stack, &mut regs, args);
268    }
269
270    pub fn interpret_matrix(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
271        let count_params = self.count_params;
272        let count_obs = self.count_obs;
273
274        for i in 0..n {
275            self.interpret(
276                &args[i * count_params..(i + 1) * count_params],
277                &mut outs[i * count_obs..(i + 1) * count_obs],
278            );
279        }
280    }
281
282    /// Generic evaluate function for compiled Symbolica expressions
283    pub fn evaluate<T>(&self, args: &[T], outs: &mut [T])
284    where
285        T: Element,
286    {
287        let args = recast_as_f64(args);
288        let outs = recast_as_f64_mut(outs);
289
290        let simd = matches!(
291            T::get_type(T::default()),
292            ElemType::RealF64x2(_)
293                | ElemType::RealF64x4(_)
294                | ElemType::ComplexF64x2(_)
295                | ElemType::ComplexF64x4(_)
296        );
297
298        if let Some(f) = &self.compiled {
299            if !simd {
300                f.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
301            } else if let Some(g) = &self.compiled_simd {
302                g.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
303            }
304        }
305    }
306
307    /// Generic evaluate_single function for compiled Symbolica expressions
308    #[inline(always)]
309    pub fn evaluate_single<T>(&self, args: &[T]) -> T
310    where
311        T: Element + Copy,
312    {
313        let mut outs = [T::default(); 1];
314        self.evaluate(args, &mut outs);
315        outs[0]
316    }
317
318    /// Evaluates a single logical row. It could be a combinatino of multiple
319    /// physical rows because of implicit SIMD.
320    fn evaluate_row(
321        args: &[f64],
322        args_idx: usize,
323        outs: &[f64],
324        outs_idx: usize,
325        f: CompiledFunc<f64>,
326        transpose: bool,
327    ) -> i32 {
328        unsafe {
329            f(
330                outs.as_ptr().add(outs_idx),
331                std::ptr::null(),
332                if transpose { 1 } else { 0 },
333                args.as_ptr().add(args_idx),
334            )
335        }
336    }
337
338    fn evaluate_matrix_with_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
339        if let Some(f) = &self.compiled {
340            let count_params = self.count_params;
341            let count_obs = self.count_obs;
342            let f_scalar = f.func();
343
344            (0..n).into_par_iter().for_each(|t| {
345                Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
346            });
347        }
348    }
349
350    fn evaluate_matrix_without_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
351        if let Some(f) = &self.compiled {
352            let count_params = self.count_params;
353            let count_obs = self.count_obs;
354            let f_scalar = f.func();
355
356            for t in 0..n {
357                Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
358            }
359        }
360    }
361
362    fn evaluate_matrix_with_threads_simd(
363        &self,
364        args: &[f64],
365        outs: &mut [f64],
366        n: usize,
367        transpose: bool,
368    ) {
369        if let Some(f) = &self.compiled {
370            let count_params = self.count_params;
371            let count_obs = self.count_obs;
372
373            if let Some(compiled) = &self.compiled_simd {
374                let f_simd = compiled.func();
375                let f_scalar = f.func();
376                let lanes = compiled.count_lanes();
377                let step = if transpose { lanes } else { 1 };
378
379                (0..n / step).into_par_iter().for_each(|k| {
380                    let top = k * lanes;
381                    if Self::evaluate_row(
382                        args,
383                        top * count_params,
384                        outs,
385                        top * count_obs,
386                        f_simd,
387                        transpose,
388                    ) != 0
389                    {
390                        for i in 0..lanes {
391                            Self::evaluate_row(
392                                args,
393                                (top + i) * count_params,
394                                outs,
395                                (top + i) * count_obs,
396                                f_scalar,
397                                false,
398                            );
399                        }
400                    }
401                });
402
403                for t in step * (n / step)..n {
404                    Self::evaluate_row(
405                        args,
406                        t * count_params,
407                        outs,
408                        t * count_obs,
409                        f_scalar,
410                        false,
411                    );
412                }
413            }
414        }
415    }
416
417    fn evaluate_matrix_without_threads_simd(
418        &self,
419        args: &[f64],
420        outs: &mut [f64],
421        n: usize,
422        transpose: bool,
423    ) {
424        if let Some(f) = &self.compiled {
425            let count_params = self.count_params;
426            let count_obs = self.count_obs;
427
428            if let Some(compiled) = &self.compiled_simd {
429                let f_simd = compiled.func();
430                let f_scalar = f.func();
431                let lanes = compiled.count_lanes();
432                let step = if transpose { lanes } else { 1 };
433
434                for k in 0..n / step {
435                    let top = k * lanes;
436                    if Self::evaluate_row(
437                        args,
438                        top * count_params,
439                        outs,
440                        top * count_obs,
441                        f_simd,
442                        transpose,
443                    ) != 0
444                    {
445                        for i in 0..lanes {
446                            Self::evaluate_row(
447                                args,
448                                (top + i) * count_params,
449                                outs,
450                                (top + i) * count_obs,
451                                f_scalar,
452                                false,
453                            );
454                        }
455                    }
456                }
457
458                for t in step * (n / step)..n {
459                    Self::evaluate_row(
460                        args,
461                        t * count_params,
462                        outs,
463                        t * count_obs,
464                        f_scalar,
465                        false,
466                    );
467                }
468            }
469        }
470    }
471
472    /// Generic evaluate function for compiled Symbolica expressions
473    /// The main entry point to compute matrices.
474    /// The actual dispatched method depends on the configuration and the
475    /// type of the arguments.
476    pub fn evaluate_matrix<T>(&self, args: &[T], outs: &mut [T], n: usize)
477    where
478        T: Element,
479    {
480        let args = recast_as_f64(args);
481        let outs = recast_as_f64_mut(outs);
482
483        let transpose = !matches!(
484            T::get_type(T::default()),
485            ElemType::RealF64x2(_)
486                | ElemType::RealF64x4(_)
487                | ElemType::ComplexF64x2(_)
488                | ElemType::ComplexF64x4(_)
489        );
490
491        if self.use_threads && n > 1 {
492            if self.compiled_simd.is_some() {
493                self.evaluate_matrix_with_threads_simd(args, outs, n, transpose);
494            } else {
495                self.evaluate_matrix_with_threads(args, outs, n);
496            }
497        } else {
498            if self.compiled_simd.is_some() {
499                self.evaluate_matrix_without_threads_simd(args, outs, n, transpose);
500            } else {
501                self.evaluate_matrix_without_threads(args, outs, n);
502            }
503        }
504    }
505
506    /// Returns a fast function.
507    ///
508    /// `Application` call functions need to copy the input argument slice into
509    /// the function memory area and then copy the output to a `Vec`. This process
510    /// is acceptable for large and complex functions but incurs a penalty for
511    /// small functions. Therefore, for a certain subset of applications, Symjit
512    /// can compile a fast funcction and return a function pointer. Examples:
513    ///
514    /// ```rust
515    /// fn test_fast() -> Result<()> {
516    ///     let x = Expr::var("x");
517    ///     let y = Expr::var("y");
518    ///     let z = Expr::var("z");
519    ///     let u = &x * &(&y - &z).pow(&Expr::from(2));
520    ///
521    ///     let mut comp = Compiler::new();
522    ///     let mut app = comp.compile(&[x, y, z], &[u])?;
523    ///     let f = app.fast_func()?;
524    ///
525    ///     if let FastFunc::F3(f, _) = f {
526    ///         let res = f(3.0, 5.0, 9.0);
527    ///         println!("fast\t{:?}", &res);
528    ///     }
529    ///
530    ///     Ok(())
531    /// }
532    /// ```
533    ///
534    /// The conditions for a fast function are:
535    ///
536    /// * A fast function can have 1 to 8 arguments.
537    /// * No SIMD and no parameters.
538    /// * It returns only a single value.
539    ///
540    /// If these conditions are met, you can generate a fast functin by calling
541    /// `app.fast_func()`, with a return type of `Result<FastFunc>`. `FastFunc` is an
542    /// enum with eight variants `F1, `F2`, ..., `F8`, corresponding to
543    /// functions with 1 to 8 arguments.
544    ///
545    pub fn fast_func(&mut self) -> Result<FastFunc<'_>> {
546        let f = self.get_fast();
547
548        if let Some(f) = f {
549            match self.count_states {
550                1 => {
551                    let g: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
552                    Ok(FastFunc::F1(g, self))
553                }
554                2 => {
555                    let g: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
556                    Ok(FastFunc::F2(g, self))
557                }
558                3 => {
559                    let g: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
560                    Ok(FastFunc::F3(g, self))
561                }
562                4 => {
563                    let g: extern "C" fn(f64, f64, f64, f64) -> f64 =
564                        unsafe { std::mem::transmute(f) };
565                    Ok(FastFunc::F4(g, self))
566                }
567                5 => {
568                    let g: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
569                        unsafe { std::mem::transmute(f) };
570                    Ok(FastFunc::F5(g, self))
571                }
572                6 => {
573                    let g: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
574                        unsafe { std::mem::transmute(f) };
575                    Ok(FastFunc::F6(g, self))
576                }
577                7 => {
578                    let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
579                        unsafe { std::mem::transmute(f) };
580                    Ok(FastFunc::F7(g, self))
581                }
582                8 => {
583                    let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64 =
584                        unsafe { std::mem::transmute(f) };
585                    Ok(FastFunc::F8(g, self))
586                }
587                _ => Err(anyhow!("not a fast function")),
588            }
589        } else {
590            Err(anyhow!("not a fast function"))
591        }
592    }
593}
594
595pub fn recast_complex_vec(v: &[Complex<f64>]) -> &[f64] {
596    let n = v.len();
597    let p: *const f64 = unsafe { std::mem::transmute(v.as_ptr()) };
598    let q: &[f64] = unsafe { std::slice::from_raw_parts(p, 2 * n) };
599    q
600}
601
602pub fn recast_complex_vec_mut(v: &mut [Complex<f64>]) -> &mut [f64] {
603    let n = v.len();
604    let p: *mut f64 = unsafe { std::mem::transmute(v.as_mut_ptr()) };
605    let q: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(p, 2 * n) };
606    q
607}
608
/// Reinterprets a slice of `T` as a flat slice of `f64` without copying.
///
/// Each element of `v` is viewed as `size_of::<T>() / size_of::<f64>()`
/// consecutive `f64` values. `T` is therefore expected to be an `f64`-based
/// type such as `f64`, `Complex<f64>`, or a packed vector of `f64`s.
///
/// # Panics
///
/// Panics in either of these cases:
///
/// * the size of `T` is not a whole multiple of the size of `f64`, or
/// * `T` is less strictly aligned than `f64`.
///
/// Reinterpreting such a type would silently drop bytes or create a misaligned
/// `&[f64]` (undefined behavior).
///
/// NOTE(review): `T` must also contain no padding or non-`f64` data. That
/// cannot be checked here.
pub fn recast_as_f64<T>(v: &[T]) -> &[f64]
where
    T: Sized,
{
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0 && std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "recast_as_f64: element type must consist of whole, f64-aligned f64 values"
    );

    let len = (elem_size / f64_size) * v.len();
    // SAFETY: `T` is at least f64-aligned and spans exactly `elem_size / f64_size`
    // f64s (checked above). The `len` f64s starting at `v.as_ptr()` therefore lie
    // inside the memory borrowed by `v`, and the result shares its lifetime.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<f64>(), len) }
}
618
/// Reinterprets a mutable slice of `T` as a flat, writable slice of `f64`
/// without copying.
///
/// Each element of `v` is viewed as `size_of::<T>() / size_of::<f64>()`
/// consecutive `f64` values. `T` is therefore expected to be an `f64`-based
/// type such as `f64`, `Complex<f64>`, or a packed vector of `f64`s.
///
/// # Panics
///
/// Panics in either of these cases:
///
/// * the size of `T` is not a whole multiple of the size of `f64`, or
/// * `T` is less strictly aligned than `f64`.
///
/// Reinterpreting such a type would silently drop bytes or create a misaligned
/// `&mut [f64]` (undefined behavior).
///
/// NOTE(review): `T` must also contain no padding or non-`f64` data. That
/// cannot be checked here.
pub fn recast_as_f64_mut<T>(v: &mut [T]) -> &mut [f64]
where
    T: Sized,
{
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0 && std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "recast_as_f64_mut: element type must consist of whole, f64-aligned f64 values"
    );

    let len = (elem_size / f64_size) * v.len();
    // SAFETY: the layout checks above guarantee alignment and that `len` f64s fit
    // inside `v`. The pointer is taken with `as_mut_ptr` (not `as_ptr`), so it
    // carries the write permission of the unique borrow `v`.
    unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr().cast::<f64>(), len) }
}
628
629/************************* Symbolica *****************************/
630
631/// Translates Symbolica IR (generated by export_instructions) into a Symjit Model
632#[derive(Debug, Clone)]
633pub struct Translator {
634    config: Config,
635    df: Defuns,
636    ssa: Vec<Instruction>,
637    consts: Vec<Complex<f64>>, // constants
638    count_params: usize,
639    count_statics: usize,
640    eqs: Vec<Equation>,            // Symjit Equations (output)
641    temps: HashMap<usize, Slot>,   // Temp idx => Static idx
642    counts: HashMap<usize, usize>, // Static idx => number of usage on the RHS
643    cache: HashMap<usize, Expr>,   // cache of Static variables (Static idx => Expr)
644    outs: HashMap<usize, Expr>,    // cache of Outs (Out idx => Expr)
645    reals: HashSet<Loc>,
646    num_params: usize,
647    has_jump: bool,
648    last_label: usize,
649    depth: usize,
650    conds: Vec<Slot>,
651}
652
653impl Translator {
654    pub fn new(config: Config, df: Defuns) -> Translator {
655        Translator {
656            config,
657            df,
658            ssa: Vec::new(),
659            consts: Vec::new(),
660            count_params: 0,
661            count_statics: 0,
662            eqs: Vec::new(),
663            temps: HashMap::new(),
664            counts: HashMap::new(),
665            cache: HashMap::new(),
666            outs: HashMap::new(),
667            reals: HashSet::new(),
668            num_params: 0,
669            has_jump: false,
670            last_label: 0,
671            depth: 0,
672            conds: Vec::new(),
673        }
674    }
675
676    pub fn parse_model(&mut self, model: &SymbolicaModel) -> Result<()> {
677        for c in model.2.iter() {
678            let val = Complex::new(c.value().re, c.value().im);
679            self.consts.push(val);
680        }
681
682        self.convert(model)?;
683        Ok(())
684    }
685
686    /// The first pass by converting Symbolica IR into
687    /// Static-Single-Assingment (SSA) Form
688    fn convert(&mut self, model: &SymbolicaModel) -> Result<()> {
689        for line in model.0.iter() {
690            match line {
691                Instruction::Add(lhs, args, num_reals) => self.append_add(lhs, args, *num_reals)?,
692                Instruction::Mul(lhs, args, num_reals) => self.append_mul(lhs, args, *num_reals)?,
693                Instruction::Pow(lhs, arg, p, is_real) => {
694                    self.append_pow(lhs, arg, *p, *is_real)?
695                }
696                Instruction::Powf(lhs, arg, p, is_real) => {
697                    self.append_powf(lhs, arg, p, *is_real)?
698                }
699                Instruction::Assign(lhs, rhs) => self.append_assign(lhs, rhs)?,
700                Instruction::Fun(lhs, fun, arg, is_real) => {
701                    self.append_fun(lhs, fun, arg, *is_real)?
702                }
703                Instruction::Join(lhs, cond, true_val, false_val) => {
704                    self.depth -= 1;
705                    self.append_join(lhs, cond, true_val, false_val)?
706                }
707                Instruction::Label(id) => self.append_label(*id)?,
708                Instruction::IfElse(cond, id) => {
709                    self.append_if_else(cond, *id)?;
710                    self.depth += 1;
711                }
712                Instruction::Goto(id) => self.append_goto(*id)?,
713                Instruction::ExternalFun(lhs, op, args) => {
714                    self.append_external_fun(lhs, op, args)?
715                }
716            }
717        }
718
719        Ok(())
720    }
721
722    pub fn append_constant(&mut self, z: Complex<f64>) -> Result<usize> {
723        self.consts.push(z);
724        Ok(self.consts.len() - 1)
725    }
726
727    pub fn append_add(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
728        let args = self.consume_list(args)?;
729        let lhs = self.produce(lhs)?;
730        self.ssa.push(Instruction::Add(lhs, args, num_reals));
731        Ok(())
732    }
733
734    pub fn append_mul(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
735        let args = self.consume_list(args)?;
736        let lhs = self.produce(lhs)?;
737        self.ssa.push(Instruction::Mul(lhs, args, num_reals));
738        Ok(())
739    }
740
741    pub fn append_pow(&mut self, lhs: &Slot, arg: &Slot, p: i64, is_real: bool) -> Result<()> {
742        let arg = self.consume(arg)?;
743        let lhs = self.produce(lhs)?;
744        self.ssa.push(Instruction::Pow(lhs, arg, p, is_real));
745        Ok(())
746    }
747
748    pub fn append_powf(&mut self, lhs: &Slot, arg: &Slot, p: &Slot, is_real: bool) -> Result<()> {
749        let arg = self.consume(arg)?;
750        let p = self.consume(p)?;
751        let lhs = self.produce(lhs)?;
752        self.ssa.push(Instruction::Powf(lhs, arg, p, is_real));
753        Ok(())
754    }
755
756    pub fn append_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
757        let rhs = self.consume(rhs)?;
758        let lhs = self.produce(lhs)?;
759        self.ssa.push(Instruction::Assign(lhs, rhs));
760        Ok(())
761    }
762
763    pub fn append_label(&mut self, id: usize) -> Result<()> {
764        self.ssa.push(Instruction::Label(id));
765        Ok(())
766    }
767
768    pub fn append_if_else(&mut self, cond: &Slot, id: usize) -> Result<()> {
769        self.has_jump = true;
770        let cond = self.consume(cond)?;
771        self.ssa.push(Instruction::IfElse(cond, id));
772        Ok(())
773    }
774
775    pub fn append_goto(&mut self, id: usize) -> Result<()> {
776        self.last_label = self.last_label.max(id);
777        self.ssa.push(Instruction::Goto(id));
778        Ok(())
779    }
780
781    pub fn append_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
782        let args = self.consume_list(args)?;
783        let lhs = self.produce(lhs)?;
784        self.ssa
785            .push(Instruction::ExternalFun(lhs, op.to_string(), args));
786        Ok(())
787    }
788
789    pub fn append_fun(
790        &mut self,
791        lhs: &Slot,
792        fun: &BuiltinSymbol,
793        arg: &Slot,
794        is_real: bool,
795    ) -> Result<()> {
796        let arg = self.consume(arg)?;
797        let lhs = self.produce(lhs)?;
798        self.ssa.push(Instruction::Fun(lhs, *fun, arg, is_real));
799        Ok(())
800    }
801
802    pub fn append_join(
803        &mut self,
804        lhs: &Slot,
805        cond: &Slot,
806        true_val: &Slot,
807        false_val: &Slot,
808    ) -> Result<()> {
809        let cond = self.consume(cond)?;
810        let true_val = self.consume(true_val)?;
811        let false_val = self.consume(false_val)?;
812        let lhs = self.produce(lhs)?;
813        self.ssa
814            .push(Instruction::Join(lhs, cond, true_val, false_val));
815        Ok(())
816    }
817
818    fn create_static(&mut self) -> Result<Slot> {
819        let s = Slot::Static(self.count_statics);
820        self.counts.insert(self.count_statics, 0);
821        self.count_statics += 1;
822        Ok(s)
823    }
824
825    /// Produces a new Static variable if needed.
826    /// slot should be an LHS.
827    fn produce(&mut self, slot: &Slot) -> Result<Slot> {
828        match slot {
829            Slot::Temp(idx) => {
830                if self.depth > 0 {
831                    if let Some(Slot::Static(s)) = self.temps.get(idx) {
832                        *self.counts.get_mut(s).unwrap() += 1;
833                        return Ok(Slot::Static(*s));
834                    }
835                }
836
837                let s = self.create_static()?;
838                self.temps.insert(*idx, s);
839                Ok(s)
840            }
841            Slot::Out(idx) => Ok(Slot::Out(*idx)),
842            _ => Err(anyhow!("unacceptable lhs.")),
843        }
844    }
845
846    /// Consumes a slot.
847    /// slot should be an RHS.
848    fn consume(&mut self, slot: &Slot) -> Result<Slot> {
849        match slot {
850            Slot::Temp(idx) => {
851                if let Some(Slot::Static(s)) = self.temps.get(idx) {
852                    *self.counts.get_mut(s).unwrap() += 1;
853                    Ok(Slot::Static(*s))
854                } else {
855                    Err(anyhow!("Not a static reg."))
856                }
857            }
858            Slot::Out(idx) => Ok(Slot::Out(*idx)),
859            Slot::Param(idx) => Ok(Slot::Param(*idx)),
860            Slot::Const(idx) => Ok(Slot::Const(*idx)),
861            Slot::Static(_) | Slot::Arg(_) => Err(anyhow!("Undefined Static/Arg.")),
862        }
863    }
864
865    fn consume_list(&mut self, slots: &[Slot]) -> Result<Vec<Slot>> {
866        slots.iter().map(|s| self.consume(s)).collect()
867    }
868
869    /// The second pass. It translates the SSA-form into a Symjit model.
870    pub fn translate(&mut self) -> Result<(CellModel, HashSet<Loc>)> {
871        let ssa = std::mem::take(&mut self.ssa);
872
873        for line in ssa.iter() {
874            match line {
875                Instruction::Add(lhs, args, n) => self.translate_nary("plus", lhs, args, *n)?,
876                Instruction::Mul(lhs, args, n) => self.translate_nary("times", lhs, args, *n)?,
877                Instruction::Pow(lhs, arg, p, is_real) => {
878                    let p = Expr::from(*p as f64);
879                    self.translate_pow(lhs, arg, &p, *is_real)?
880                }
881                Instruction::Powf(lhs, arg, p, is_real) => {
882                    let p = self.expr(p, false);
883                    self.translate_pow(lhs, arg, &p, *is_real)?
884                }
885                Instruction::Assign(lhs, rhs) => self.translate_assign(lhs, rhs)?,
886                Instruction::Fun(lhs, fun, arg, is_real) => {
887                    self.translate_fun(lhs, fun, arg, *is_real)?
888                }
889                Instruction::Join(lhs, cond, true_val, false_val) => {
890                    self.translate_join(lhs, cond, true_val, false_val)?
891                }
892                Instruction::Label(id) => self.translate_label(*id)?,
893                Instruction::IfElse(cond, id) => self.translate_ifelse(cond, *id)?,
894                Instruction::Goto(id) => self.translate_goto(*id)?,
895                Instruction::ExternalFun(lhs, op, args) => {
896                    self.translate_external_fun(lhs, op, args)?
897                }
898            }
899        }
900
901        // Important! Outs are cached and should be written to final outputs.
902        for k in 0..self.outs.len() {
903            let out = Expr::var(&format!("Out{}", k));
904
905            if let Some(eq) = self.outs.get(&k) {
906                self.eqs.push(Expr::equation(&out, eq));
907            }
908        }
909
910        let mut params: Vec<Variable> = (0..=self.count_params.max(self.num_params.max(1) - 1))
911            .map(|idx| self.expr(&Slot::Param(idx), false).to_variable().unwrap())
912            .collect();
913
914        let mut states: Vec<Variable> = Vec::new();
915
916        if !self.config.symbolica() {
917            (params, states) = (states, params)
918        }
919
920        Ok((
921            CellModel {
922                iv: Expr::var("$_").to_variable().unwrap(),
923                params,
924                states,
925                algs: Vec::new(),
926                odes: Vec::new(),
927                obs: self.eqs.clone(),
928            },
929            self.reals.clone(),
930        ))
931    }
932
933    // The counterpart of consume for the second-pass
934    fn expr(&mut self, slot: &Slot, is_real: bool) -> Expr {
935        match slot {
936            Slot::Param(idx) => {
937                if is_real {
938                    self.reals.insert(Loc::Param(*idx as u32));
939                }
940                self.count_params = self.count_params.max(*idx);
941                Expr::var(&format!("Param{}", idx))
942            }
943            Slot::Out(idx) => {
944                if let Some(e) = self.outs.get(idx) {
945                    e.clone()
946                } else {
947                    Expr::var(&format!("Out{}", idx))
948                }
949            }
950            Slot::Temp(idx) => Expr::var(&format!("__Temp{}", idx)),
951            Slot::Const(idx) => {
952                let val = self.consts[*idx];
953                if val.im != 0.0 {
954                    Expr::binary("complex", &Expr::from(val.re), &Expr::from(val.im))
955                } else {
956                    Expr::from(self.consts[*idx].re)
957                }
958            }
959            Slot::Static(idx) => self
960                .cache
961                .remove(idx)
962                .unwrap_or(Expr::var(&format!("__Static{}", idx))),
963            Slot::Arg(idx) => Expr::var(&format!("__Arg{}", idx)),
964        }
965    }
966
967    // The counterpart of produce for the second-pass
968    fn assign(&mut self, lhs: &Slot, rhs: Expr) -> Result<()> {
969        if !self.has_jump {
970            if let Slot::Static(idx) = lhs {
971                // Important! If a static variable is used only once, it
972                // is pushed into the cache to be incorporated into the
973                // destination expression tree.
974                if self.counts.get(idx).is_some_and(|c| *c == 1) {
975                    self.cache.insert(*idx, rhs);
976                    return Ok(());
977                }
978            }
979
980            if let Slot::Out(idx) = lhs {
981                self.outs.insert(*idx, rhs.clone());
982                return Ok(());
983            }
984        }
985
986        let lhs = self.expr(lhs, false);
987        self.eqs.push(Expr::equation(&lhs, &rhs));
988        Ok(())
989    }
990
991    fn translate_nary(&mut self, op: &str, lhs: &Slot, args: &[Slot], n: usize) -> Result<()> {
992        let args: Vec<Expr> = args
993            .iter()
994            .enumerate()
995            .map(|(i, x)| self.expr(x, i < n))
996            .collect();
997        let p: Vec<&Expr> = args.iter().collect();
998
999        if n == 0 || n >= p.len() {
1000            self.assign(lhs, Expr::nary(op, &p))
1001        } else {
1002            let l = Expr::nary(op, &p[..n]);
1003            let r = Expr::nary(op, &p[n..]);
1004            self.assign(lhs, Expr::nary(op, &[&l, &r]))
1005        }
1006    }
1007
1008    fn translate_pow(&mut self, lhs: &Slot, arg: &Slot, power: &Expr, is_real: bool) -> Result<()> {
1009        let arg = self.expr(arg, is_real);
1010        self.assign(lhs, Expr::binary("power", &arg, power))
1011    }
1012
1013    fn translate_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
1014        let rhs = self.expr(rhs, false);
1015        self.assign(lhs, rhs)
1016    }
1017
1018    fn translate_fun(
1019        &mut self,
1020        lhs: &Slot,
1021        fun: &BuiltinSymbol,
1022        arg: &Slot,
1023        is_real: bool,
1024    ) -> Result<()> {
1025        let arg = self.expr(arg, is_real);
1026
1027        let op = match fun.0 {
1028            2 => "exp",
1029            3 => "ln",
1030            4 => "sin",
1031            5 => "cos",
1032            6 => {
1033                if is_real {
1034                    "real_root"
1035                } else {
1036                    "root"
1037                }
1038            }
1039            7 => "conjugate",
1040            _ => return Err(anyhow!("function is not defined.")),
1041        };
1042
1043        self.assign(lhs, Expr::unary(op, &arg))
1044    }
1045
1046    fn translate_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
1047        let n = args.len();
1048        assert!(n <= SLICE_CAP);
1049        let args: Vec<Expr> = args.iter().map(|a| self.expr(a, false)).collect();
1050
1051        if VirtualTable::from_str(op).is_ok() {
1052            if n == 1 {
1053                self.assign(lhs, Expr::unary(op, &args[0]))?;
1054            } else if n == 2 {
1055                self.assign(lhs, Expr::binary(op, &args[0], &args[1]))?;
1056            } else {
1057                return Err(anyhow!("wrong number of arguments to {:?}", op));
1058            }
1059        } else if self.config.is_intrinsic_unary(op) && n == 1 {
1060            self.assign(lhs, Expr::unary(op, &args[0]))?;
1061        } else if self.config.is_intrinsic_binary(op) && n == 2 {
1062            self.assign(lhs, Expr::binary(op, &args[0], &args[1]))?;
1063        } else {
1064            let temps: Vec<Slot> = (0..n).map(|_| self.create_static().unwrap()).collect();
1065            let slice: Vec<Slot> = (0..n).map(Slot::Arg).collect();
1066
1067            for i in 0..n {
1068                self.assign(&temps[i], args[i].clone())?;
1069            }
1070
1071            for i in 0..n {
1072                if let Slot::Static(idx) = temps[i] {
1073                    self.assign(&slice[i], Expr::var(&format!("__Static{}", idx)))?;
1074                }
1075            }
1076
1077            let op = format!("${}", op);
1078            self.assign(
1079                lhs,
1080                Expr::binary(&op, &Expr::from(0), &Expr::from(n as i32)),
1081            )?;
1082        }
1083
1084        Ok(())
1085    }
1086
1087    fn translate_label(&mut self, id: usize) -> Result<()> {
1088        self.eqs.push(Expr::special(&Expr::Label { id }));
1089        Ok(())
1090    }
1091
1092    fn translate_ifelse(&mut self, cond: &Slot, id: usize) -> Result<()> {
1093        // let cond = Expr::binary(
1094        //     "lt",
1095        //     &Expr::unary("abs", &self.expr(cond, false)),
1096        //     &Expr::from(f64::EPSILON),
1097        // );
1098
1099        self.conds.push(*cond);
1100        let if_clause = Expr::binary("eq", &self.expr(cond, false), &Expr::from(0.0));
1101        self.eqs.push(Expr::special(&Expr::BranchIf {
1102            cond: Box::new(if_clause),
1103            id,
1104            is_else: false,
1105        }));
1106        Ok(())
1107    }
1108
1109    fn translate_goto(&mut self, id: usize) -> Result<()> {
1110        if self.config.simd_branch() {
1111            // TODO: the commented out area should be uncommented, except it causes
1112            // a bug in some programs. The effects are not local and likely related
1113            // to instruction movements.
1114
1115            /*
1116            let cond = self.conds.pop().unwrap();
1117            self.conds.push(cond);
1118            let if_clause = Expr::binary("eq", &self.expr(&cond, false), &Expr::from(0.0));
1119            self.eqs.push(Expr::special(&Expr::BranchIf {
1120                cond: Box::new(if_clause),
1121                id,
1122                is_else: true,
1123            }));
1124            */
1125        } else {
1126            self.eqs.push(Expr::special(&Expr::Branch { id }));
1127        }
1128
1129        Ok(())
1130    }
1131
1132    fn translate_join(
1133        &mut self,
1134        lhs: &Slot,
1135        _cond: &Slot,
1136        true_val: &Slot,
1137        false_val: &Slot,
1138    ) -> Result<()> {
1139        // Join is essentially a Φ-function.
1140        let t = self.expr(true_val, false);
1141        let f = self.expr(false_val, false);
1142        let cond = self.conds.pop().unwrap();
1143        let mask = Expr::binary("eq", &self.expr(&cond, false), &Expr::from(0.0));
1144        self.assign(lhs, mask.ifelse(&f, &t))?;
1145        Ok(())
1146    }
1147
1148    pub fn set_num_params(&mut self, num_params: usize) {
1149        self.num_params = num_params
1150    }
1151
1152    pub fn compile(&mut self) -> Result<Application> {
1153        let (ml, reals) = self.translate()?;
1154        let prog = Program::new(&ml, self.config)?;
1155        let mut app = Application::new(prog, reals, std::mem::take(&mut self.df))?;
1156        app.prepare_simd();
1157        Ok(app)
1158    }
1159}
1160
1161impl Compiler {
1162    /// Compiles a Symbolica model.
1163    ///
1164    /// `json` is the JSON-encoded output of Symbolica `export_instructions`.
1165    ///
1166    /// Example:
1167    ///
1168    /// ```rust
1169    /// let params = vec![parse!("x"), parse!("y")];
1170    /// let eval = parse!("x + y^2")
1171    ///     .evaluator(&FunctionMap::new(), &params, OptimizationSettings::default())?
1172    ///
1173    /// let json = serde_json::to_string(&eval.export_instructions())?;
1174    /// let mut comp = Compiler::new();
1175    /// let mut app = comp.translate(&json)?;
1176    /// assert!(app.evaluate_single(&[2.0, 3.0]) == 11.0);
1177    /// ```
1178    pub fn translate(
1179        &mut self,
1180        json: String,
1181        df: Defuns,
1182        num_params: usize,
1183    ) -> Result<Application> {
1184        let mut translator = Translator::new(self.config, df);
1185
1186        let model: SymbolicaModel = if json.starts_with("[[{") {
1187            serde_json::from_str(json.as_str())?
1188        } else {
1189            Parser::new(json).parse()?
1190        };
1191
1192        translator.parse_model(&model)?;
1193        translator.set_num_params(num_params);
1194        let (ml, reals) = translator.translate()?;
1195
1196        let prog = Program::new(&ml, self.config)?;
1197        let df = Defuns::new();
1198        let mut app = Application::new(prog, reals, df)?;
1199
1200        app.prepare_simd();
1201
1202        // #[cfg(target_arch = "aarch64")]
1203        // if let Ok(app) = &mut app {
1204        //     // this is a hack to give enough delay to prevent a bus error
1205        //     app.dump("dump.bin", "scalar");
1206        //     std::fs::remove_file("dump.bin")?;
1207        // };
1208
1209        Ok(app)
1210    }
1211}