Skip to main content

symjit/
compiler.rs

1#[cfg(target_arch = "x86_64")]
2use std::arch::x86_64::{__m256d, _mm256_setzero_pd};
3
4use std::collections::{HashMap, HashSet};
5
6use anyhow::{anyhow, Result};
7use num_complex::Complex;
8use rayon::prelude::*;
9
10use crate::config::Config;
11use crate::defuns::Defuns;
12use crate::expr::Expr;
13pub use crate::instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
14use crate::model::{CellModel, Equation, Program, Variable};
15use crate::parser::Parser;
16use crate::symbol::Loc;
17use crate::utils::CompiledFunc;
18use crate::Application;
19
20// #[derive(Debug)]
/// Front end that compiles lists of `Expr` values into a callable `Application`.
pub struct Compiler {
    /// Settings handed to `Program::new` on every compile.
    config: Config,
    /// User-defined functions registered through `def_unary` / `def_binary`;
    /// passed by reference to `Application::new`.
    df: Defuns,
}
25
// Stand-in for the AVX vector type on targets other than x86-64, so that
// signatures mentioning `__m256d` (e.g. the `call_simd` stubs) still compile.
#[cfg(not(target_arch = "x86_64"))]
#[allow(non_camel_case_types)]
type __m256d = [f64; 4];
29
30/// The central hub of the Rust interface. It compiles a list of
31/// variables and expressions into a callable object (of type `Application`).
32///
33/// # Workflow
34///
35/// 1. Create terminals (variables and constants) and compose expressions using `Expr` methods:
36///    * Constructors: `var`, `from`, `unary`, `binary`, ...
37///    * Standard algebraic operations: `add`, `mul`, ...
38///    * Standard operators `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `!`.
39///    * Unary functions such as `sin`, `exp`, and other standard mathematical functions.
40///    * Binary functions such as `pow`, `min`, ...
41///    * IfElse operation `ifelse(cond, true_val, false_val)`.
///    * Heaviside function: `heaviside(x)`, which returns 1 if `x >= 0`; otherwise 0.
43///    * Comparison methods `eq`, `ne`, `lt`, `le`, `gt`, and `ge`.
44///    * Looping constructs `sum` and `prod`.
/// 2. Create a new `Compiler` object (say, `comp`) using one of its constructors: `new()`
///    or `with_config(config: Config)`.
47/// 3. Fine-tune the optimization passes using `opt_level`, `simd`, `fastmath`,
48///    and `cse` methods (optional).
/// 4. Define user-defined functions by calling `comp.def_unary` and `comp.def_binary`
50///    (optional).
51/// 5. Compile by calling `comp.compile` or `comp.compile_params`. The result is of
52///    type `Application` (say, `app`).
53/// 6. Execute the compiled code using one of the `app`'s `call` functions:
54///    * `call(&[f64])`: scalar call.
55///    * `call_params(&[f64], &[f64])`: scalar call with parameters.
56///    * `call_simd(&[__m256d])`: simd call.
57///    * `call_simd_params(&[__m256d], &[f64])`: simd call with parameters.
58/// 7. Optionally, generate a standalone fast function to execute.
59///
60///
61/// # Examples
62///
63/// ```rust
64/// use anyhow::Result;
65/// use symjit::{Compiler, Expr};
66///
67/// pub fn main() -> Result<()> {
68///     let x = Expr::var("x");
69///     let y = Expr::var("y");
70///     let u = &x + &y;
71///     let v = &x * &y;
72///
73///     let mut config = Config::default();
74///     config.set_opt_level(2);
75///     let mut comp = Compiler::with_config(config);
76///     let mut app = comp.compile(&[x, y], &[u, v])?;
77///     let res = app.call(&[3.0, 5.0]);
78///     println!("{:?}", &res);
79///
80///     Ok(())
81/// }
82/// ```
83impl Compiler {
84    /// Creates a new `Compiler` object with default settings.
85    pub fn new() -> Compiler {
86        Compiler {
87            config: Config::default(),
88            df: Defuns::new(),
89        }
90    }
91
92    pub fn with_config(config: Config) -> Compiler {
93        Compiler {
94            config,
95            df: Defuns::new(),
96        }
97    }
98
99    /// Compiles a model.
100    ///
101    /// `states` is a list of variables, created by `Expr::var`.
102    /// `obs` is a list of expressions.
103    pub fn compile(&mut self, states: &[Expr], obs: &[Expr]) -> Result<Application> {
104        self.compile_params(states, obs, &[])
105    }
106
107    /// Compiles a model with parameters.
108    ///
109    /// `states` is a list of variables, created by `Expr::var`.
110    /// `obs` is a list of expressions.
111    /// `params` is a list of parameters, created by `Expr::var`.
112    ///
113    /// Note: for scalar functions, the difference between states and params
114    ///     is mostly by convenion. However, they are different in SIMD cases,
115    ///     as params are always f64.
116    pub fn compile_params(
117        &mut self,
118        states: &[Expr],
119        obs: &[Expr],
120        params: &[Expr],
121    ) -> Result<Application> {
122        let mut vars: Vec<Variable> = Vec::new();
123
124        for state in states.iter() {
125            let v = state.to_variable()?;
126            vars.push(v);
127        }
128
129        let mut ps: Vec<Variable> = Vec::new();
130
131        for p in params.iter() {
132            let v = p.to_variable()?;
133            ps.push(v);
134        }
135
136        let mut eqs: Vec<Equation> = Vec::new();
137
138        for (i, expr) in obs.iter().enumerate() {
139            let name = format!("${}", i);
140            let lhs = Expr::var(&name);
141            eqs.push(Expr::equation(&lhs, expr));
142        }
143
144        let ml = CellModel {
145            iv: Expr::var("$_").to_variable()?,
146            params: ps,
147            states: vars,
148            algs: Vec::new(),
149            odes: Vec::new(),
150            obs: eqs,
151        };
152
153        let prog = Program::new(&ml, self.config)?;
154        // let df = Defuns::new();
155        let mut app = Application::new(prog, HashSet::new(), &self.df);
156
157        #[cfg(target_arch = "aarch64")]
158        if let Ok(app) = &mut app {
159            // this is a hack to give enough delay to prevent a bus error
160            app.dump("dump.bin", "scalar");
161            std::fs::remove_file("dump.bin")?;
162        };
163
164        app
165    }
166
167    /// Registers a user-defined unary function.
168    pub fn def_unary(&mut self, op: &str, f: extern "C" fn(f64) -> f64) {
169        self.df.add_unary(op, f)
170    }
171
172    /// Registers a user-defined binary function.
173    pub fn def_binary(&mut self, op: &str, f: extern "C" fn(f64, f64) -> f64) {
174        self.df.add_binary(op, f)
175    }
176}
177
/// Reinterprets a slice of `f64` as a slice of 256-bit vectors, four `f64`
/// lanes per vector.
///
/// # Panics
///
/// Panics if `a.len()` is not a multiple of 4. In builds with debug assertions
/// it also panics if `a` does not start on a `__m256d`-aligned address.
///
/// # Safety
///
/// The caller must guarantee that `a.as_ptr()` is aligned for `__m256d`;
/// `std::slice::from_raw_parts` requires an aligned pointer.
#[cfg(target_arch = "x86_64")]
unsafe fn simd_slice(a: &[f64]) -> &[__m256d] {
    assert!(a.len() & 3 == 0);
    let p = a.as_ptr() as *const __m256d;
    debug_assert!(
        (p as usize) % std::mem::align_of::<__m256d>() == 0,
        "simd_slice: input is not aligned for __m256d"
    );
    // SAFETY: the length is a whole number of four-lane vectors (asserted
    // above), the memory is borrowed from `a` for the returned lifetime, and
    // alignment is the caller's obligation (see `# Safety`).
    unsafe { std::slice::from_raw_parts(p, a.len() >> 2) }
}
185
/// Reinterprets a mutable slice of `f64` as a mutable slice of 256-bit
/// vectors, four `f64` lanes per vector.
///
/// # Panics
///
/// Panics if `a.len()` is not a multiple of 4. In builds with debug assertions
/// it also panics if `a` does not start on a `__m256d`-aligned address.
///
/// # Safety
///
/// The caller must guarantee that `a.as_mut_ptr()` is aligned for `__m256d`;
/// `std::slice::from_raw_parts_mut` requires an aligned pointer.
#[cfg(target_arch = "x86_64")]
unsafe fn simd_slice_mut(a: &mut [f64]) -> &mut [__m256d] {
    assert!(a.len() & 3 == 0);
    let p = a.as_mut_ptr() as *mut __m256d;
    debug_assert!(
        (p as usize) % std::mem::align_of::<__m256d>() == 0,
        "simd_slice_mut: input is not aligned for __m256d"
    );
    // SAFETY: the length is a whole number of four-lane vectors (asserted
    // above), the memory is exclusively borrowed from `a` for the returned
    // lifetime, and alignment is the caller's obligation (see `# Safety`).
    unsafe { std::slice::from_raw_parts_mut(p, a.len() >> 2) }
}
194
/// A directly callable function pointer returned by `Application::fast_func`.
///
/// Variant `Fk` wraps an `extern "C"` function taking `k` `f64` arguments and
/// returning one `f64`. The second field is a borrow of the `Application` that
/// `fast_func` was called on.
/// NOTE(review): presumably the borrow exists so the pointer cannot outlive the
/// compiled code it points into — confirm.
pub enum FastFunc<'a> {
    F1(extern "C" fn(f64) -> f64, &'a Application),
    F2(extern "C" fn(f64, f64) -> f64, &'a Application),
    F3(extern "C" fn(f64, f64, f64) -> f64, &'a Application),
    F4(extern "C" fn(f64, f64, f64, f64) -> f64, &'a Application),
    F5(
        extern "C" fn(f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
    F6(
        extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
    F7(
        extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
    F8(
        extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
}
217
218impl Application {
219    /// Calls the compiled function.
220    ///
221    /// `args` is a slice of f64 values, corresponding to the states.
222    ///
223    /// The output is a `Vec<f64>`, corresponding to the observables (the expressions passed
224    /// to `compile`).
225    pub fn call(&mut self, args: &[f64]) -> Vec<f64> {
226        {
227            let mem = self.compiled.mem_mut();
228            let states = &mut mem[self.first_state..self.first_state + self.count_states];
229            states.copy_from_slice(args);
230        }
231
232        self.compiled.exec(&self.params[..]);
233
234        let obs = {
235            let mem = self.compiled.mem();
236            &mem[self.first_obs..self.first_obs + self.count_obs]
237        };
238
239        obs.to_vec()
240    }
241
242    /// Sets the params and calls the compiled function.
243    ///
244    /// `args` is a slice of f64 values, corresponding to the states.
245    /// `params` is a slice of f64 values, corresponding to the params.
246    ///
247    /// The output is a `Vec<f64>`, corresponding to the observables (the expressions passed
248    /// to `compile`).
249    pub fn call_params(&mut self, args: &[f64], params: &[f64]) -> Vec<f64> {
250        {
251            let mem = self.compiled.mem_mut();
252            let states = &mut mem[self.first_state..self.first_state + self.count_states];
253            states.copy_from_slice(args);
254        }
255
256        self.compiled.exec(params);
257
258        let obs = {
259            let mem = self.compiled.mem();
260            &mem[self.first_obs..self.first_obs + self.count_obs]
261        };
262
263        obs.to_vec()
264    }
265
266    /// Generic evaluate function for compiled Symbolica expressions
267    #[inline(always)]
268    pub fn evaluate<T: Sized + Copy>(&mut self, args: &[T], outs: &mut [T]) {
269        if self.prog.config().is_bytecode() {
270            let mut stack: Vec<f64> = vec![0.0; self.prog.builder.block().sym_table.num_stack];
271            let mut regs = [0.0; 32];
272
273            let outs: &mut [f64] = unsafe { std::mem::transmute(outs) };
274            let args: &[f64] = unsafe { std::mem::transmute(args) };
275
276            self.mir.exec_instruction(outs, &mut stack, &mut regs, args);
277
278            return;
279        }
280
281        let f = self.compiled.func();
282
283        f(
284            outs.as_ptr() as *mut f64,
285            std::ptr::null(),
286            0,
287            args.as_ptr() as *const f64,
288        );
289    }
290
291    /// Generic evaluate_single function for compiled Symbolica expressions
292    #[inline(always)]
293    pub fn evaluate_single<T: Sized + Copy + Default>(&mut self, args: &[T]) -> T {
294        let mut outs = [T::default(); 1];
295        self.evaluate(args, &mut outs);
296        outs[0]
297    }
298
299    /// Generic SIMD evaluate function for compiled Symbolica expressions
300    #[inline(always)]
301    pub fn evaluate_simd<T: Sized + Copy>(&mut self, args: &[T], outs: &mut [T]) {
302        if let Some(g) = &mut self.compiled_simd {
303            let f = g.func();
304
305            f(
306                outs.as_ptr() as *mut f64,
307                std::ptr::null(),
308                0,
309                args.as_ptr() as *const f64,
310            );
311        }
312    }
313
314    /// Generic SIMD evaluate_single function for compiled Symbolica expressions
315    #[inline(always)]
316    pub fn evaluate_simd_single<T: Sized + Copy + Default>(&mut self, args: &[T]) -> T {
317        let mut outs = [T::default(); 1];
318        self.evaluate_simd(args, &mut outs);
319        outs[0]
320    }
321
    /// Calls the compiled function `f` once on a window of the flat buffers.
    ///
    /// `f` receives a pointer to `outs[outs_idx]`, a null second argument,
    /// `1` or `0` depending on `transpose`, and a pointer to `args[args_idx]`.
    /// The `i32` returned by `f` is passed back unchanged; the SIMD matrix
    /// drivers treat a non-zero value as "redo these rows with the scalar
    /// function".
    ///
    /// The offsets are not bounds-checked here.
    /// NOTE(review): `outs` is a shared slice, yet every caller passes a buffer
    /// it expects to be filled — presumably `f` writes through this pointer;
    /// confirm that this is sound for the parallel callers.
    fn evaluate_row(
        args: &[f64],
        args_idx: usize,
        outs: &[f64],
        outs_idx: usize,
        f: CompiledFunc<f64>,
        transpose: bool,
    ) -> i32 {
        unsafe {
            f(
                outs.as_ptr().add(outs_idx),
                std::ptr::null(),
                if transpose { 1 } else { 0 },
                args.as_ptr().add(args_idx),
            )
        }
    }
339
340    /// Generic evaluate function for compiled Symbolica expressions
341    fn evaluate_matrix_with_threads(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
342        let count_params = self.count_params;
343        let count_obs = self.count_obs;
344        let f_scalar = self.compiled.func();
345
346        (0..n).into_par_iter().for_each(|t| {
347            Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
348        });
349    }
350
351    /// Generic evaluate function for compiled Symbolica expressions
352    fn evaluate_matrix_without_threads(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
353        let count_params = self.count_params;
354        let count_obs = self.count_obs;
355        let f_scalar = self.compiled.func();
356
357        for t in 0..n {
358            Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
359        }
360    }
361
362    fn evaluate_matrix_with_threads_simd(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
363        let count_params = self.count_params;
364        let count_obs = self.count_obs;
365
366        if let Some(compiled) = &self.compiled_simd {
367            let f_simd = compiled.func();
368            let f_scalar = self.compiled.func();
369            let lanes = compiled.count_lanes();
370
371            (0..n / lanes).into_par_iter().for_each(|k| {
372                let top = k * lanes;
373                if Self::evaluate_row(
374                    args,
375                    top * count_params,
376                    outs,
377                    top * count_obs,
378                    f_simd,
379                    true,
380                ) != 0
381                {
382                    for i in 0..lanes {
383                        Self::evaluate_row(
384                            args,
385                            (top + i) * count_params,
386                            outs,
387                            (top + i) * count_obs,
388                            f_scalar,
389                            false,
390                        );
391                    }
392                }
393            });
394
395            for t in lanes * (n / lanes)..n {
396                Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
397            }
398        }
399    }
400
401    fn evaluate_matrix_without_threads_simd(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
402        let count_params = self.count_params;
403        let count_obs = self.count_obs;
404
405        if let Some(compiled) = &self.compiled_simd {
406            let f_simd = compiled.func();
407            let f_scalar = self.compiled.func();
408            let lanes = compiled.count_lanes();
409
410            for k in 0..n / lanes {
411                let top = k * lanes;
412                if Self::evaluate_row(
413                    args,
414                    top * count_params,
415                    outs,
416                    top * count_obs,
417                    f_simd,
418                    true,
419                ) != 0
420                {
421                    for i in 0..lanes {
422                        Self::evaluate_row(
423                            args,
424                            (top + i) * count_params,
425                            outs,
426                            (top + i) * count_obs,
427                            f_scalar,
428                            false,
429                        );
430                    }
431                }
432            }
433
434            for t in lanes * (n / lanes)..n {
435                Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
436            }
437        }
438    }
439
440    pub fn evaluate_matrix_bytecode(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
441        let count_params = self.count_params;
442        let count_obs = self.count_obs;
443
444        for i in 0..n {
445            self.evaluate(
446                &args[i * count_params..(i + 1) * count_params],
447                &mut outs[i * count_obs..(i + 1) * count_obs],
448            );
449        }
450    }
451
452    /// Generic evaluate function for compiled Symbolica expressions
453    pub fn evaluate_matrix(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
454        if self.prog.config().is_bytecode() {
455            self.evaluate_matrix_bytecode(args, outs, n);
456        } else if self.use_threads {
457            if self.compiled_simd.is_some() {
458                self.evaluate_matrix_with_threads_simd(args, outs, n);
459            } else {
460                self.evaluate_matrix_with_threads(args, outs, n);
461            }
462        } else {
463            if self.compiled_simd.is_some() {
464                self.evaluate_matrix_without_threads_simd(args, outs, n);
465            } else {
466                self.evaluate_matrix_without_threads(args, outs, n);
467            }
468        }
469    }
470
471    pub fn evaluate_complex_matrix(
472        &mut self,
473        args: &[Complex<f64>],
474        outs: &mut [Complex<f64>],
475        n: usize,
476    ) {
477        let args = recast_complex_vec(args);
478        let outs = recast_complex_vec_mut(outs);
479
480        if self.prog.config().is_bytecode() {
481            self.evaluate_matrix_bytecode(args, outs, n);
482        } else if self.use_threads {
483            if self.compiled_simd.is_some() {
484                self.evaluate_matrix_with_threads_simd(args, outs, n);
485            } else {
486                self.evaluate_matrix_with_threads(args, outs, n);
487            }
488        } else {
489            if self.compiled_simd.is_some() {
490                self.evaluate_matrix_without_threads_simd(args, outs, n);
491            } else {
492                self.evaluate_matrix_without_threads(args, outs, n);
493            }
494        }
495    }
496
497    /// Generic evaluate function for compiled Symbolica expressions
498    pub fn evaluate_simd_matrix<T: Sized + Copy>(&mut self, args: &[T], outs: &mut [T], n: usize) {
499        let args_size = args.len() / n;
500        let outs_size = outs.len() / n;
501
502        for (p, q) in args.chunks(args_size).zip(outs.chunks_mut(outs_size)) {
503            self.evaluate_simd(p, q);
504        }
505    }
506
507    /// Calls the compiled SIMD function.
508    ///
509    /// `args` is a slice of __m256d values, corresponding to the states.
510    ///
511    /// The output is an `Result` wrapping `Vec<__m256d>`, corresponding to the observables
512    /// (the expressions passed to `compile`).
513    ///
514    /// Note: currently, this function only works on X86-64 CPUs with the AVX extension. Intel
515    /// introduced the AVX instruction set in 2011; therefore, most intel and AMD processors
516    /// support it. If SIMD is not supported, this function returns `None`.
517    ///
518    #[cfg(target_arch = "x86_64")]
519    pub fn call_simd(&mut self, args: &[__m256d]) -> Result<Vec<__m256d>> {
520        if let Some(f) = &mut self.compiled_simd {
521            {
522                let mem = f.mem_mut();
523                let states = unsafe {
524                    simd_slice_mut(
525                        &mut mem[self.first_state * 4..(self.first_state + self.count_states) * 4],
526                    )
527                };
528                states.copy_from_slice(args);
529            }
530
531            f.exec(&self.params);
532
533            {
534                let mem = f.mem();
535                let obs = unsafe {
536                    simd_slice(&mem[self.first_obs * 4..(self.first_obs + self.count_obs) * 4])
537                };
538                let mut res = unsafe { vec![_mm256_setzero_pd(); self.count_obs] };
539                res.copy_from_slice(obs);
540                Ok(res)
541            }
542        } else {
543            self.prepare_simd();
544            if self.compiled_simd.is_some() {
545                self.call_simd(args)
546            } else {
547                Err(anyhow!("cannot compile SIMD"))
548            }
549        }
550    }
551
552    #[cfg(not(target_arch = "x86_64"))]
553    pub unsafe fn call_simd(&mut self, _args: &[__m256d]) -> Result<Vec<__m256d>> {
554        Err(anyhow!("cannot compile SIMD"))
555    }
556
557    /// Sets the params and calls the compiled SIMD function.
558    ///
559    /// `args` is a slice of __m256d values, corresponding to the states.
560    ///
561    /// `params` is a slice of f64 values.
562    ///
563    /// The output is a `Result` wrapping a `Vec<__m256d>`, corresponding to the observables
564    /// (the expressions passed to `compile`).
565    ///
566    /// Note: currently, this function only works on X86-64 CPUs with the AVX extension. Intel
567    /// introduced the AVX instruction set in 2011; therefore, most intel and AMD processors
568    /// support it. If SIMD is not supported, this function returns `None`.
569    ///
570    #[cfg(target_arch = "x86_64")]
571    pub fn call_simd_params(&mut self, args: &[__m256d], params: &[f64]) -> Result<Vec<__m256d>> {
572        if let Some(f) = &mut self.compiled_simd {
573            {
574                let mem = f.mem_mut();
575                let states = unsafe {
576                    simd_slice_mut(
577                        &mut mem[self.first_state * 4..(self.first_state + self.count_states) * 4],
578                    )
579                };
580                states.copy_from_slice(args);
581            }
582
583            f.exec(params);
584
585            {
586                let mem = f.mem();
587                let obs = unsafe {
588                    simd_slice(&mem[self.first_obs * 4..(self.first_obs + self.count_obs) * 4])
589                };
590                let mut res = unsafe { vec![_mm256_setzero_pd(); self.count_obs] };
591                res.copy_from_slice(obs);
592                Ok(res)
593            }
594        } else {
595            self.prepare_simd();
596            if self.compiled_simd.is_some() {
597                self.call_simd_params(args, params)
598            } else {
599                Err(anyhow!("cannot compile SIMD"))
600            }
601        }
602    }
603
604    #[cfg(not(target_arch = "x86_64"))]
605    pub unsafe fn call_simd_params(
606        &mut self,
607        _args: &[__m256d],
608        _params: &[f64],
609    ) -> Result<Vec<__m256d>> {
610        Err(anyhow!("cannot compile SIMD"))
611    }
612
613    /// Returns a fast function.
614    ///
615    /// `Application` call functions need to copy the input argument slice into
616    /// the function memory area and then copy the output to a `Vec`. This process
617    /// is acceptable for large and complex functions but incurs a penalty for
618    /// small functions. Therefore, for a certain subset of applications, Symjit
619    /// can compile a fast funcction and return a function pointer. Examples:
620    ///
621    /// ```rust
622    /// fn test_fast() -> Result<()> {
623    ///     let x = Expr::var("x");
624    ///     let y = Expr::var("y");
625    ///     let z = Expr::var("z");
626    ///     let u = &x * &(&y - &z).pow(&Expr::from(2));
627    ///
628    ///     let mut comp = Compiler::new();
629    ///     let mut app = comp.compile(&[x, y, z], &[u])?;
630    ///     let f = app.fast_func()?;
631    ///
632    ///     if let FastFunc::F3(f, _) = f {
633    ///         let res = f(3.0, 5.0, 9.0);
634    ///         println!("fast\t{:?}", &res);
635    ///     }
636    ///
637    ///     Ok(())
638    /// }
639    /// ```
640    ///
641    /// The conditions for a fast function are:
642    ///
643    /// * A fast function can have 1 to 8 arguments.
644    /// * No SIMD and no parameters.
645    /// * It returns only a single value.
646    ///
647    /// If these conditions are met, you can generate a fast functin by calling
648    /// `app.fast_func()`, with a return type of `Result<FastFunc>`. `FastFunc` is an
649    /// enum with eight variants `F1, `F2`, ..., `F8`, corresponding to
650    /// functions with 1 to 8 arguments.
651    ///
652    pub fn fast_func(&mut self) -> Result<FastFunc<'_>> {
653        let f = self.get_fast();
654
655        if let Some(f) = f {
656            match self.count_states {
657                1 => {
658                    let g: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
659                    Ok(FastFunc::F1(g, self))
660                }
661                2 => {
662                    let g: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
663                    Ok(FastFunc::F2(g, self))
664                }
665                3 => {
666                    let g: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
667                    Ok(FastFunc::F3(g, self))
668                }
669                4 => {
670                    let g: extern "C" fn(f64, f64, f64, f64) -> f64 =
671                        unsafe { std::mem::transmute(f) };
672                    Ok(FastFunc::F4(g, self))
673                }
674                5 => {
675                    let g: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
676                        unsafe { std::mem::transmute(f) };
677                    Ok(FastFunc::F5(g, self))
678                }
679                6 => {
680                    let g: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
681                        unsafe { std::mem::transmute(f) };
682                    Ok(FastFunc::F6(g, self))
683                }
684                7 => {
685                    let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
686                        unsafe { std::mem::transmute(f) };
687                    Ok(FastFunc::F7(g, self))
688                }
689                8 => {
690                    let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64 =
691                        unsafe { std::mem::transmute(f) };
692                    Ok(FastFunc::F8(g, self))
693                }
694                _ => Err(anyhow!("not a fast function")),
695            }
696        } else {
697            Err(anyhow!("not a fast function"))
698        }
699    }
700}
701
702pub fn recast_complex_vec(v: &[Complex<f64>]) -> &[f64] {
703    let n = v.len();
704    let p: *const f64 = unsafe { std::mem::transmute(v.as_ptr()) };
705    let q: &[f64] = unsafe { std::slice::from_raw_parts(p, 2 * n) };
706    q
707}
708
709pub fn recast_complex_vec_mut(v: &mut [Complex<f64>]) -> &mut [f64] {
710    let n = v.len();
711    let p: *mut f64 = unsafe { std::mem::transmute(v.as_mut_ptr()) };
712    let q: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(p, 2 * n) };
713    q
714}
715
716/************************* Symbolica *****************************/
717
/// Translates Symbolica IR (generated by export_instructions) into a Symjit Model
#[derive(Debug, Clone)]
pub struct Translator {
    // Settings supplied to `Translator::new`.
    config: Config,
    // Copy of the user-defined function registry supplied to `Translator::new`.
    df: Defuns,
    // Instructions appended by the `append_*` methods during the first pass.
    ssa: Vec<Instruction>,
    consts: Vec<Complex<f64>>, // constants
    count_params: usize,
    count_statics: usize,
    eqs: Vec<Equation>,            // Symjit Equations (output)
    temps: HashMap<usize, Slot>,   // Temp idx => Static idx
    counts: HashMap<usize, usize>, // Static idx => number of usage on the RHS
    cache: HashMap<usize, Expr>,   // cache of Static variables (Static idx => Expr)
    outs: HashMap<usize, Expr>,    // cache of Outs (Out idx => Expr)
    reals: HashSet<Loc>,
    num_params: usize,
    // Set once any `IfElse` instruction has been appended.
    has_jump: bool,
    // Largest label id passed to `append_goto` so far.
    last_label: usize,
    // Raised after each `IfElse` and lowered before each `Join` in `convert`.
    depth: usize,
    conds: Vec<Slot>,
}
739
740impl Translator {
741    pub fn new(config: Config, df: &Defuns) -> Translator {
742        Translator {
743            config,
744            df: df.clone(),
745            ssa: Vec::new(),
746            consts: Vec::new(),
747            count_params: 0,
748            count_statics: 0,
749            eqs: Vec::new(),
750            temps: HashMap::new(),
751            counts: HashMap::new(),
752            cache: HashMap::new(),
753            outs: HashMap::new(),
754            reals: HashSet::new(),
755            num_params: 0,
756            has_jump: false,
757            last_label: 0,
758            depth: 0,
759            conds: Vec::new(),
760        }
761    }
762
763    pub fn parse_model(&mut self, model: &SymbolicaModel) -> Result<()> {
764        for c in model.2.iter() {
765            let val = Complex::new(c.value().re, c.value().im);
766            self.consts.push(val);
767        }
768
769        self.convert(model)?;
770        Ok(())
771    }
772
773    /// The first pass by converting Symbolica IR into
774    /// Static-Single-Assingment (SSA) Form
775    fn convert(&mut self, model: &SymbolicaModel) -> Result<()> {
776        for line in model.0.iter() {
777            match line {
778                Instruction::Add(lhs, args, num_reals) => self.append_add(lhs, args, *num_reals)?,
779                Instruction::Mul(lhs, args, num_reals) => self.append_mul(lhs, args, *num_reals)?,
780                Instruction::Pow(lhs, arg, p, is_real) => {
781                    self.append_pow(lhs, arg, *p, *is_real)?
782                }
783                Instruction::Powf(lhs, arg, p, is_real) => {
784                    self.append_powf(lhs, arg, p, *is_real)?
785                }
786                Instruction::Assign(lhs, rhs) => self.append_assign(lhs, rhs)?,
787                Instruction::Fun(lhs, fun, arg, is_real) => {
788                    self.append_fun(lhs, fun, arg, *is_real)?
789                }
790                Instruction::Join(lhs, cond, true_val, false_val) => {
791                    self.depth -= 1;
792                    self.append_join(lhs, cond, true_val, false_val)?
793                }
794                Instruction::Label(id) => self.append_label(*id)?,
795                Instruction::IfElse(cond, id) => {
796                    self.append_if_else(cond, *id)?;
797                    self.depth += 1;
798                }
799                Instruction::Goto(id) => self.append_goto(*id)?,
800                Instruction::ExternalFun(lhs, op, args) => {
801                    self.append_external_fun(lhs, op, args)?
802                }
803            }
804        }
805
806        Ok(())
807    }
808
809    pub fn append_constant(&mut self, z: Complex<f64>) -> Result<usize> {
810        self.consts.push(z);
811        Ok(self.consts.len() - 1)
812    }
813
814    pub fn append_add(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
815        let args = self.consume_list(args)?;
816        let lhs = self.produce(lhs)?;
817        self.ssa.push(Instruction::Add(lhs, args, num_reals));
818        Ok(())
819    }
820
821    pub fn append_mul(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
822        let args = self.consume_list(args)?;
823        let lhs = self.produce(lhs)?;
824        self.ssa.push(Instruction::Mul(lhs, args, num_reals));
825        Ok(())
826    }
827
828    pub fn append_pow(&mut self, lhs: &Slot, arg: &Slot, p: i64, is_real: bool) -> Result<()> {
829        let arg = self.consume(arg)?;
830        let lhs = self.produce(lhs)?;
831        self.ssa.push(Instruction::Pow(lhs, arg, p, is_real));
832        Ok(())
833    }
834
835    pub fn append_powf(&mut self, lhs: &Slot, arg: &Slot, p: &Slot, is_real: bool) -> Result<()> {
836        let arg = self.consume(arg)?;
837        let p = self.consume(p)?;
838        let lhs = self.produce(lhs)?;
839        self.ssa.push(Instruction::Powf(lhs, arg, p, is_real));
840        Ok(())
841    }
842
843    pub fn append_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
844        let rhs = self.consume(rhs)?;
845        let lhs = self.produce(lhs)?;
846        self.ssa.push(Instruction::Assign(lhs, rhs));
847        Ok(())
848    }
849
850    pub fn append_label(&mut self, id: usize) -> Result<()> {
851        self.ssa.push(Instruction::Label(id));
852        Ok(())
853    }
854
855    pub fn append_if_else(&mut self, cond: &Slot, id: usize) -> Result<()> {
856        self.has_jump = true;
857        let cond = self.consume(cond)?;
858        self.ssa.push(Instruction::IfElse(cond, id));
859        Ok(())
860    }
861
862    pub fn append_goto(&mut self, id: usize) -> Result<()> {
863        self.last_label = self.last_label.max(id);
864        self.ssa.push(Instruction::Goto(id));
865        Ok(())
866    }
867
868    pub fn append_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
869        let args = self.consume_list(args)?;
870        let lhs = self.produce(lhs)?;
871        self.ssa
872            .push(Instruction::ExternalFun(lhs, op.to_string(), args));
873        Ok(())
874    }
875
876    pub fn append_fun(
877        &mut self,
878        lhs: &Slot,
879        fun: &BuiltinSymbol,
880        arg: &Slot,
881        is_real: bool,
882    ) -> Result<()> {
883        let arg = self.consume(arg)?;
884        let lhs = self.produce(lhs)?;
885        self.ssa.push(Instruction::Fun(lhs, *fun, arg, is_real));
886        Ok(())
887    }
888
889    pub fn append_join(
890        &mut self,
891        lhs: &Slot,
892        cond: &Slot,
893        true_val: &Slot,
894        false_val: &Slot,
895    ) -> Result<()> {
896        let cond = self.consume(cond)?;
897        let true_val = self.consume(true_val)?;
898        let false_val = self.consume(false_val)?;
899        let lhs = self.produce(lhs)?;
900        self.ssa
901            .push(Instruction::Join(lhs, cond, true_val, false_val));
902        Ok(())
903    }
904
905    /// Produces a new Static variable if needed.
906    /// slot should be an LHS.
907    fn produce(&mut self, slot: &Slot) -> Result<Slot> {
908        match slot {
909            Slot::Temp(idx) => {
910                if self.depth > 0 {
911                    if let Some(Slot::Static(s)) = self.temps.get(idx) {
912                        *self.counts.get_mut(s).unwrap() += 1;
913                        return Ok(Slot::Static(*s));
914                    }
915                }
916
917                let s = Slot::Static(self.count_statics);
918                self.counts.insert(self.count_statics, 0);
919                self.count_statics += 1;
920                self.temps.insert(*idx, s);
921                Ok(s)
922            }
923            Slot::Out(idx) => Ok(Slot::Out(*idx)),
924            _ => Err(anyhow!("unacceptable lhs.")),
925        }
926    }
927
928    /// Consumes a slot.
929    /// slot should be an RHS.
930    fn consume(&mut self, slot: &Slot) -> Result<Slot> {
931        match slot {
932            Slot::Temp(idx) => {
933                if let Some(Slot::Static(s)) = self.temps.get(idx) {
934                    *self.counts.get_mut(s).unwrap() += 1;
935                    Ok(Slot::Static(*s))
936                } else {
937                    Err(anyhow!("Not a static reg."))
938                }
939            }
940            Slot::Out(idx) => Ok(Slot::Out(*idx)),
941            Slot::Param(idx) => Ok(Slot::Param(*idx)),
942            Slot::Const(idx) => Ok(Slot::Const(*idx)),
943            Slot::Static(_) => Err(anyhow!("Undefined Static.")),
944        }
945    }
946
947    fn consume_list(&mut self, slots: &[Slot]) -> Result<Vec<Slot>> {
948        slots.iter().map(|s| self.consume(s)).collect()
949    }
950
951    /// The second pass. It translates the SSA-form into a Symjit model.
952    pub fn translate(&mut self) -> Result<(CellModel, HashSet<Loc>)> {
953        let ssa = std::mem::take(&mut self.ssa);
954
955        for line in ssa.iter() {
956            match line {
957                Instruction::Add(lhs, args, n) => self.translate_nary("plus", lhs, args, *n)?,
958                Instruction::Mul(lhs, args, n) => self.translate_nary("times", lhs, args, *n)?,
959                Instruction::Pow(lhs, arg, p, is_real) => {
960                    let p = Expr::from(*p as f64);
961                    self.translate_pow(lhs, arg, &p, *is_real)?
962                }
963                Instruction::Powf(lhs, arg, p, is_real) => {
964                    let p = self.expr(p, false);
965                    self.translate_pow(lhs, arg, &p, *is_real)?
966                }
967                Instruction::Assign(lhs, rhs) => self.translate_assign(lhs, rhs)?,
968                Instruction::Fun(lhs, fun, arg, is_real) => {
969                    self.translate_fun(lhs, fun, arg, *is_real)?
970                }
971                Instruction::Join(lhs, cond, true_val, false_val) => {
972                    self.translate_join(lhs, cond, true_val, false_val)?
973                }
974                Instruction::Label(id) => self.translate_label(*id)?,
975                Instruction::IfElse(cond, id) => self.translate_ifelse(cond, *id)?,
976                Instruction::Goto(id) => self.translate_goto(*id)?,
977                Instruction::ExternalFun(lhs, op, args) => {
978                    self.translate_external_fun(lhs, op, args)?
979                }
980            }
981        }
982
983        // Important! Outs are cached and should be written to final outputs.
984        for k in 0..self.outs.len() {
985            let out = Expr::var(&format!("Out{}", k));
986
987            if let Some(eq) = self.outs.get(&k) {
988                self.eqs.push(Expr::equation(&out, eq));
989            }
990        }
991
992        let mut params: Vec<Variable> = (0..=self.count_params.max(self.num_params.max(1) - 1))
993            .map(|idx| self.expr(&Slot::Param(idx), false).to_variable().unwrap())
994            .collect();
995
996        let mut states: Vec<Variable> = Vec::new();
997
998        if !self.config.symbolica() {
999            (params, states) = (states, params)
1000        }
1001
1002        Ok((
1003            CellModel {
1004                iv: Expr::var("$_").to_variable().unwrap(),
1005                params,
1006                states,
1007                algs: Vec::new(),
1008                odes: Vec::new(),
1009                obs: self.eqs.clone(),
1010            },
1011            self.reals.clone(),
1012        ))
1013    }
1014
1015    // The counterpart of consume for the second-pass
1016    fn expr(&mut self, slot: &Slot, is_real: bool) -> Expr {
1017        match slot {
1018            Slot::Param(idx) => {
1019                if is_real {
1020                    self.reals.insert(Loc::Param(*idx as u32));
1021                }
1022                self.count_params = self.count_params.max(*idx);
1023                Expr::var(&format!("Param{}", idx))
1024            }
1025            Slot::Out(idx) => {
1026                if let Some(e) = self.outs.get(idx) {
1027                    e.clone()
1028                } else {
1029                    Expr::var(&format!("Out{}", idx))
1030                }
1031            }
1032            Slot::Temp(idx) => Expr::var(&format!("__Temp{}", idx)),
1033            Slot::Const(idx) => {
1034                let val = self.consts[*idx];
1035                if val.im != 0.0 {
1036                    Expr::binary("complex", &Expr::from(val.re), &Expr::from(val.im))
1037                } else {
1038                    Expr::from(self.consts[*idx].re)
1039                }
1040            }
1041            Slot::Static(idx) => self
1042                .cache
1043                .remove(idx)
1044                .unwrap_or(Expr::var(&format!("__Static{}", idx))),
1045        }
1046    }
1047
1048    // The counterpart of produce for the second-pass
1049    fn assign(&mut self, lhs: &Slot, rhs: Expr) -> Result<()> {
1050        if !self.has_jump {
1051            if let Slot::Static(idx) = lhs {
1052                // Important! If a static variable is used only once, it
1053                // is pushed into the cache to be incorporated into the
1054                // destination expression tree.
1055                if self.counts.get(idx).is_some_and(|c| *c == 1) {
1056                    self.cache.insert(*idx, rhs);
1057                    return Ok(());
1058                }
1059            }
1060
1061            if let Slot::Out(idx) = lhs {
1062                self.outs.insert(*idx, rhs.clone());
1063                return Ok(());
1064            }
1065        }
1066
1067        let lhs = self.expr(lhs, false);
1068        self.eqs.push(Expr::equation(&lhs, &rhs));
1069        Ok(())
1070    }
1071
1072    fn translate_nary(&mut self, op: &str, lhs: &Slot, args: &[Slot], n: usize) -> Result<()> {
1073        let args: Vec<Expr> = args
1074            .iter()
1075            .enumerate()
1076            .map(|(i, x)| self.expr(x, i < n))
1077            .collect();
1078        let p: Vec<&Expr> = args.iter().collect();
1079
1080        if n == 0 || n >= p.len() {
1081            self.assign(lhs, Expr::nary(op, &p))
1082        } else {
1083            let l = Expr::nary(op, &p[..n]);
1084            let r = Expr::nary(op, &p[n..]);
1085            self.assign(lhs, Expr::nary(op, &[&l, &r]))
1086        }
1087    }
1088
1089    fn translate_pow(&mut self, lhs: &Slot, arg: &Slot, power: &Expr, is_real: bool) -> Result<()> {
1090        let arg = self.expr(arg, is_real);
1091        self.assign(lhs, Expr::binary("power", &arg, power))
1092    }
1093
1094    fn translate_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
1095        let rhs = self.expr(rhs, false);
1096        self.assign(lhs, rhs)
1097    }
1098
1099    fn translate_fun(
1100        &mut self,
1101        lhs: &Slot,
1102        fun: &BuiltinSymbol,
1103        arg: &Slot,
1104        is_real: bool,
1105    ) -> Result<()> {
1106        let arg = self.expr(arg, is_real);
1107
1108        let op = match fun.0 {
1109            2 => "exp",
1110            3 => "ln",
1111            4 => "sin",
1112            5 => "cos",
1113            6 => {
1114                if is_real {
1115                    "real_root"
1116                } else {
1117                    "root"
1118                }
1119            }
1120            7 => "conjugate",
1121            _ => return Err(anyhow!("function is not defined.")),
1122        };
1123
1124        self.assign(lhs, Expr::unary(op, &arg))
1125    }
1126
1127    fn translate_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
1128        match args.len() {
1129            1 => {
1130                let arg = self.expr(&args[0], false);
1131                self.assign(lhs, Expr::unary(op, &arg))?;
1132            }
1133            2 => {
1134                let l = self.expr(&args[0], false);
1135                let r = self.expr(&args[1], false);
1136                self.assign(lhs, Expr::binary(op, &l, &r))?;
1137            }
1138            _ => {
1139                return Err(anyhow!(
1140                    "only unary and binary external functions are supported"
1141                ))
1142            }
1143        }
1144
1145        Ok(())
1146    }
1147
1148    fn translate_label(&mut self, id: usize) -> Result<()> {
1149        self.eqs.push(Expr::special(&Expr::Label { id }));
1150        Ok(())
1151    }
1152
1153    fn translate_ifelse(&mut self, cond: &Slot, id: usize) -> Result<()> {
1154        // let cond = Expr::binary(
1155        //     "lt",
1156        //     &Expr::unary("abs", &self.expr(cond, false)),
1157        //     &Expr::from(f64::EPSILON),
1158        // );
1159
1160        self.conds.push(*cond);
1161        let if_clause = Expr::binary("eq", &self.expr(cond, false), &Expr::from(0.0));
1162        self.eqs.push(Expr::special(&Expr::BranchIf {
1163            cond: Box::new(if_clause),
1164            id,
1165            is_else: false,
1166        }));
1167        Ok(())
1168    }
1169
1170    fn translate_goto(&mut self, id: usize) -> Result<()> {
1171        if self.config.simd_branch() {
1172            // TODO: the commented out area should be uncommented, except it causes
1173            // a bug in some programs. The effects are not local and likely related
1174            // to instruction movements.
1175
1176            /*
1177            let cond = self.conds.pop().unwrap();
1178            self.conds.push(cond);
1179            let if_clause = Expr::binary("eq", &self.expr(&cond, false), &Expr::from(0.0));
1180            self.eqs.push(Expr::special(&Expr::BranchIf {
1181                cond: Box::new(if_clause),
1182                id,
1183                is_else: true,
1184            }));
1185            */
1186        } else {
1187            self.eqs.push(Expr::special(&Expr::Branch { id }));
1188        }
1189
1190        Ok(())
1191    }
1192
1193    fn translate_join(
1194        &mut self,
1195        lhs: &Slot,
1196        _cond: &Slot,
1197        true_val: &Slot,
1198        false_val: &Slot,
1199    ) -> Result<()> {
1200        // Join is essentially a Φ-function.
1201        let t = self.expr(true_val, false);
1202        let f = self.expr(false_val, false);
1203        let cond = self.conds.pop().unwrap();
1204        let mask = Expr::binary("eq", &self.expr(&cond, false), &Expr::from(0.0));
1205        self.assign(lhs, mask.ifelse(&f, &t))?;
1206        Ok(())
1207    }
1208
1209    pub fn set_num_params(&mut self, num_params: usize) {
1210        self.num_params = num_params
1211    }
1212
1213    pub fn compile(&mut self) -> Result<Application> {
1214        let (ml, reals) = self.translate()?;
1215        let prog = Program::new(&ml, self.config)?;
1216        let mut app = Application::new(prog, reals, &self.df)?;
1217        app.prepare_simd();
1218        Ok(app)
1219    }
1220}
1221
1222impl Compiler {
1223    /// Compiles a Symbolica model.
1224    ///
1225    /// `json` is the JSON-encoded output of Symbolica `export_instructions`.
1226    ///
1227    /// Example:
1228    ///
1229    /// ```rust
1230    /// let params = vec![parse!("x"), parse!("y")];
1231    /// let eval = parse!("x + y^2")
1232    ///     .evaluator(&FunctionMap::new(), &params, OptimizationSettings::default())?
1233    ///
1234    /// let json = serde_json::to_string(&eval.export_instructions())?;
1235    /// let mut comp = Compiler::new();
1236    /// let mut app = comp.translate(&json)?;
1237    /// assert!(app.evaluate_single(&[2.0, 3.0]) == 11.0);
1238    /// ```
1239    pub fn translate(
1240        &mut self,
1241        json: String,
1242        df: &Defuns,
1243        num_params: usize,
1244    ) -> Result<Application> {
1245        let mut translator = Translator::new(self.config, df);
1246
1247        let model: SymbolicaModel = if json.starts_with("[[{") {
1248            serde_json::from_str(json.as_str())?
1249        } else {
1250            Parser::new(json).parse()?
1251        };
1252
1253        translator.parse_model(&model)?;
1254        translator.set_num_params(num_params);
1255        let (ml, reals) = translator.translate()?;
1256
1257        let prog = Program::new(&ml, self.config)?;
1258        let df = Defuns::new();
1259        let mut app = Application::new(prog, reals, &df)?;
1260
1261        app.prepare_simd();
1262
1263        // #[cfg(target_arch = "aarch64")]
1264        // if let Ok(app) = &mut app {
1265        //     // this is a hack to give enough delay to prevent a bus error
1266        //     app.dump("dump.bin", "scalar");
1267        //     std::fs::remove_file("dump.bin")?;
1268        // };
1269
1270        Ok(app)
1271    }
1272}