symjit/
lib.rs

#![allow(uncommon_codepoints)]

//! Symjit (<https://github.com/siravan/symjit>) is a lightweight just-in-time (JIT)
//! optimizing compiler for mathematical expressions written in Rust. It was originally
//! designed to compile SymPy (Python’s symbolic algebra package) expressions
//! into machine code and to serve as a bridge between SymPy and the numerical routines
//! provided by the NumPy and SciPy libraries.
//!
//! The Symjit crate is the core compiler coupled to a Rust interface that exposes the
//! JIT functionality to the Rust ecosystem and allows Rust applications to
//! generate code dynamically. Considering its origin, Symjit is geared toward
//! compiling mathematical expressions instead of being a general-purpose JIT
//! compiler. Therefore, the only supported types for variables are `f64`
//! (including SIMD `f64x4` and `f64x2`) and, implicitly, `bool` and `i32`.
//!
//! Symjit emits AMD64 (x86-64), ARM64 (aarch64), and 64-bit RISC-V (riscv64) machine
//! code on Linux, Windows, and macOS platforms. SIMD is supported on x86-64
//! and ARM64.
//!
//! In Rust, there are two ways to construct expressions to pass to Symjit: using
//! Symbolica or using Symjit's standalone expression builder.
//!
//! # Symbolica
//!
//! Symbolica (<https://symbolica.io/>) is a fast Rust-based computer algebra system.
//! As of version 1.5 of Symbolica, Symjit is an optional backend for Symbolica. Therefore,
//! the previous interface through the `Compiler` object is considered obsolete and should not
//! be used for new projects.
//!
//! # Standalone Expression Builder
//!
//! A second way to use Symjit is through its standalone expression builder. Compared to
//! Symbolica, the expression builder is limited but is useful when the goal
//! is to compile an expression without extensive symbolic manipulation.
//!
//! The workflow to create, compile, and run expressions is:
//!
//! 1. Create terminals (variables and constants) and compose expressions using `Expr` methods:
//!    * Constructors: `var`, `from`, `unary`, `binary`, ...
//!    * Standard algebraic operations: `add`, `mul`, ...
//!    * Standard operators: `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `!`.
//!    * Unary functions such as `sin`, `exp`, and other standard mathematical functions.
//!    * Binary functions such as `pow`, `min`, ...
//!    * Conditional operation: `ifelse(cond, true_val, false_val)`.
//!    * Heaviside function: `heaviside(x)`, which returns 1 if `x >= 0`; otherwise 0.
//!    * Comparison methods: `eq`, `ne`, `lt`, `le`, `gt`, and `ge`.
//!    * Looping constructs: `sum` and `prod`.
//! 2. Create a new `Compiler` object (say, `comp`) using one of its constructors.
//! 3. Optionally, register user-defined functions by calling `comp.def_unary` and
//!    `comp.def_binary`.
//! 4. Compile by calling `comp.compile` or `comp.compile_params`. The result is of
//!    type `Application` (say, `app`).
//! 5. Execute the compiled code using one of `app`'s `call` functions:
//!    * `call(&[f64])`: scalar call.
//!    * `call_params(&[f64], &[f64])`: scalar call with parameters.
//!    * `call_simd(&[__m256d])`: SIMD call.
//!    * `call_simd_params(&[__m256d], &[f64])`: SIMD call with parameters.
//! 6. Optionally, generate a standalone fast function to execute.
//!
//! Note that you can use the helper functions `var(&str) -> Expr`, `int(i32) -> Expr`,
//! `double(f64) -> Expr`, and `boolean(bool) -> f64` to reduce clutter.
//!
//! # Examples
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{Compiler, Expr};
//!
//! pub fn test_scalar() -> Result<()> {
//!     let x = Expr::var("x");
//!     let y = Expr::var("y");
//!     let u = &x + &y;
//!     let v = &x * &y;
//!
//!     let mut comp = Compiler::new();
//!     let mut app = comp.compile(&[x, y], &[u, v])?;
//!     let res = app.call(&[3.0, 5.0]);
//!     println!("{:?}", &res);   // prints [8.0, 15.0]
//!
//!     Ok(())
//! }
//! ```
//!
//! `test_scalar` is similar to the following basic example in Python/SymPy:
//!
//! ```python
//! from symjit import compile_func
//! from sympy import symbols
//!
//! x, y = symbols('x y')
//! f = compile_func([x, y], [x+y, x*y])
//! print(f(3.0, 5.0))  # prints [8.0, 15.0]
//! ```
//!
//! A more elaborate example that uses a parameter, changes the
//! optimization level, and calls the SIMD interface:
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{var, Compiler, Expr};
//!
//! pub fn test_simd() -> Result<()> {
//!     use std::arch::x86_64::_mm256_loadu_pd;
//!
//!     let x = var("x");   // note var instead of Expr::var
//!     let p = var("p");   // the parameter
//!
//!     let u = &x.square() * &p;    // x^2 * p
//!     let mut comp = Compiler::new();
//!     comp.opt_level(2);  // optional (opt_level 0 to 2; default 1)
//!     let mut app = comp.compile_params(&[x], &[u], &[p])?;
//!
//!     let a = &[1.0, 2.0, 3.0, 4.0];
//!     let a = unsafe { _mm256_loadu_pd(a.as_ptr()) };
//!     let res = app.call_simd_params(&[a], &[5.0])?;
//!     println!("{:?}", &res);   // prints [__m256d(5.0, 20.0, 45.0, 80.0)]
//!     Ok(())
//! }
//! ```
//!
//! # Conditional Expressions and Loops
//!
//! Many mathematical formulas need conditional expressions (`ifelse`) and loops.
//! Following SymPy, Symjit uses reduction loops such as `sum` and `prod`. The following
//! example computes the exponential function from its Taylor series:
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
//!
//! fn test_exp() -> Result<()> {
//!     let x = var("x");
//!     let i = var("i");   // loop variable
//!     let j = var("j");   // loop variable
//!
//!     // u = sum of x^j / factorial(j) for j in 0..=50
//!     let u = x
//!         .pow(&j)
//!         .div(&i.prod(&i, &int(1), &j))
//!         .sum(&j, &int(0), &int(50));
//!
//!     let mut app = Compiler::new().compile(&[x], &[u])?;
//!     println!("{:?}", app.call(&[2.0])[0]); // exp(2.0) = 7.38905...
//!     Ok(())
//! }
//! ```
//!
//! An example showing how to calculate pi using the Leibniz formula:
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
//!
//! fn test_pi() -> Result<()> {
//!     let n = var("n");
//!     let j = var("j");   // loop variable
//!
//!     // numer = if j % 2 == 0 { 4 } else { -4 }
//!     let numer = j.rem(&int(2)).eq(&int(0)).ifelse(&int(4), &int(-4));
//!     // denom = j * 2 + 1
//!     let denom = j.mul(&int(2)).add(&int(1));
//!     // v = sum of numer / denom for j in 0..=n
//!     let v = (&numer / &denom).sum(&j, &int(0), &n);
//!
//!     let mut app = Compiler::new().compile(&[n], &[v])?;
//!     println!("{:?}", app.call(&[100000000.0])[0]); // approximates pi
//!     Ok(())
//! }
//! ```
//!
//! Note that here we mostly use explicit methods (`add`, `mul`, ...) instead of
//! the overloaded operators for clarity.
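//!
//! For cross-checking the two reductions above, the following is a minimal plain-Rust
//! sketch of the same sums, written without Symjit. The function names `exp_series`
//! and `leibniz_pi` are illustrative only and are not part of this crate.
//!
//! ```rust
//! /// Taylor series of exp(x), summed for j in 0..=n (mirrors the `sum` in `test_exp`).
//! fn exp_series(x: f64, n: u32) -> f64 {
//!     let mut total = 0.0;
//!     for j in 0..=n {
//!         // Mirrors the inner `prod`: factorial(j) is the product of i for i in 1..=j.
//!         // An empty product (j = 0) is 1.
//!         let fact: f64 = (1..=j).map(f64::from).product();
//!         total += x.powi(j as i32) / fact;
//!     }
//!     total
//! }
//!
//! /// Leibniz series for pi, summed for j in 0..=n (mirrors `test_pi`).
//! fn leibniz_pi(n: u64) -> f64 {
//!     (0..=n)
//!         .map(|j| {
//!             let numer = if j % 2 == 0 { 4.0 } else { -4.0 };
//!             numer / (2 * j + 1) as f64
//!         })
//!         .sum()
//! }
//! ```
//!
//! The Leibniz series converges slowly (the error is on the order of `1 / n`), which
//! is why `test_pi` uses a large upper bound.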
//!
//! # Fast Functions
//!
//! `Application`'s call functions need to copy the input slice into the function
//! memory area and then copy the output to a `Vec`. This process is acceptable
//! for large and complex functions but incurs a penalty for small ones.
//! Therefore, for a certain subset of applications, Symjit can compile to a
//! *fast function* and return a function pointer. For example:
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler, FastFunc};
//!
//! fn test_fast() -> Result<()> {
//!     let x = var("x");
//!     let y = var("y");
//!     let z = var("z");
//!     let u = &x * &(&y - &z).pow(&int(2));    // x * (y - z)^2
//!
//!     let mut comp = Compiler::new();
//!     let mut app = comp.compile(&[x, y, z], &[u])?;
//!     let f = app.fast_func()?;
//!
//!     if let FastFunc::F3(f, _) = f {
//!         // f is of type extern "C" fn(f64, f64, f64) -> f64
//!         let res = f(3.0, 5.0, 9.0);
//!         println!("fast\t{:?}", &res);
//!     }
//!
//!     Ok(())
//! }
//! ```
//!
//! The conditions for a fast function are:
//!
//! * A fast function can have 1 to 8 arguments.
//! * No SIMD and no parameters.
//! * It returns only a single value.
//!
//! If these conditions are met, you can generate a fast function by calling
//! `app.fast_func()`, which returns a `Result<FastFunc>`. `FastFunc` is an
//! enum with eight variants `F1`, `F2`, ..., `F8`, corresponding to functions
//! with 1 to 8 arguments.
//!
//! # User-Defined Functions
//!
//! Symjit functions can call into user-defined Rust functions. Currently,
//! only the following function signatures are accepted:
//!
//! ```rust
//! pub type UnaryFunc = extern "C" fn(f64) -> f64;
//! pub type BinaryFunc = extern "C" fn(f64, f64) -> f64;
//! ```
//!
//! For example:
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{Compiler, Expr};
//!
//! extern "C" fn f(x: f64) -> f64 {
//!     x.exp()
//! }
//!
//! extern "C" fn g(x: f64, y: f64) -> f64 {
//!     x.ln() * y
//! }
//!
//! fn test_external() -> Result<()> {
//!     let x = Expr::var("x");
//!     let u = Expr::unary("f_", &x);
//!     let v = &x * &Expr::binary("g_", &u, &x);
//!
//!     // v(x) = x * (ln(exp(x)) * x) = x ^ 3
//!
//!     let mut comp = Compiler::new();
//!     comp.def_unary("f_", f);
//!     comp.def_binary("g_", g);
//!     let mut app = comp.compile(&[x], &[v])?;
//!     println!("{:?}", app.call(&[5.0])[0]);
//!
//!     Ok(())
//! }
//! ```
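//!
//! As a sanity check of the identity in the comment above, the same composition can
//! be evaluated directly through plain function pointers, without compiling anything.
//! This is only a sketch: the local type aliases repeat the two signatures shown
//! earlier, and the helper `v_direct` is hypothetical and not part of this crate.
//!
//! ```rust
//! // Local aliases with the same shapes as the accepted signatures.
//! type UnaryFunc = extern "C" fn(f64) -> f64;
//! type BinaryFunc = extern "C" fn(f64, f64) -> f64;
//!
//! extern "C" fn f(x: f64) -> f64 {
//!     x.exp()
//! }
//!
//! extern "C" fn g(x: f64, y: f64) -> f64 {
//!     x.ln() * y
//! }
//!
//! /// v(x) = x * g(f(x), x) = x * (ln(exp(x)) * x) = x^3
//! fn v_direct(f: UnaryFunc, g: BinaryFunc, x: f64) -> f64 {
//!     x * g(f(x), x)
//! }
//! ```
//!
//! For `x = 5.0`, `v_direct(f, g, 5.0)` should agree with `5.0^3 = 125` up to
//! floating-point rounding.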
//!
//! # Dynamic Expressions
//!
//! All the examples up to this point use static expressions. Of course, it
//! would have been easier just to use Rust expressions for these examples!
//! The main utility of Symjit for Rust is dynamic code generation. Here,
//! we provide a simple example that calculates pi using Viète's formula
//! (<https://en.wikipedia.org/wiki/Vi%C3%A8te%27s_formula>):
//!
//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
//!
//! fn test_pi_viete(silent: bool) -> Result<()> {
//!     let x = var("x");
//!     let mut u = int(1);
//!
//!     for i in 0..50 {
//!         let mut t = x.clone();
//!
//!         for _ in 0..i {
//!             t = &x + &(&x * &t.sqrt());
//!         }
//!
//!         u = &u * &t.sqrt();
//!     }
//!
//!     // u has 1275 = 50 * 51 / 2 sqrt operations
//!     let mut app = Compiler::new().compile(&[x], &[&int(2) / &u])?;
//!     println!("pi = \t{:?}", app.call(&[0.5])[0]);
//!     Ok(())
//! }
//! ```
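//!
//! To check the value produced by the dynamically built expression, the same product
//! can be evaluated iteratively in plain Rust. This is a minimal sketch for
//! comparison only; the function name `viete_pi` is illustrative and not part of
//! this crate. Each factor is `sqrt(1/2 + 1/2 * previous_factor)`, starting from
//! `sqrt(1/2)`, which matches the nested expression built in the loop above
//! for `x = 0.5`.
//!
//! ```rust
//! /// Approximates pi as 2 divided by the product of the first `terms` Viète factors.
//! fn viete_pi(terms: usize) -> f64 {
//!     let mut factor = 0.5_f64.sqrt(); // first factor: sqrt(1/2)
//!     let mut product = 1.0;
//!
//!     for _ in 0..terms {
//!         product *= factor;
//!         // next factor: sqrt(1/2 + 1/2 * factor)
//!         factor = (0.5 + 0.5 * factor).sqrt();
//!     }
//!
//!     2.0 / product
//! }
//! ```
//!
//! Unlike the Symjit version, this loop reuses the previous factor instead of
//! rebuilding every nested square root, so it is a reference for the result and
//! not for the amount of generated code.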
//!
//! # C-Interface
//!
//! In addition to `Compiler`, this crate provides a C-style interface
//! used by the Python (<https://github.com/siravan/symjit>) and Julia
//! (<https://github.com/siravan/Symjit.jl>) packages. This interface
//! is composed of crate functions such as `compile`, `execute`, and
//! `ptr_states`. It is not needed by the Rust interface but can be
//! used to link Symjit to other programming languages.
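//!
//! As one illustration of this interface, `get_config` returns a single integer that
//! combines `config.opt` with a compiler type code shifted left by 32 bits. The
//! following is a hedged sketch of how a caller might split that value again. The
//! helper `unpack_config` is hypothetical and not part of this crate, and it assumes
//! a 64-bit `usize` and that the options fit in the low 32 bits.
//!
//! ```rust
//! /// Splits a value shaped like `opt | (ty << 32)` into `(opt, ty)`.
//! /// Hypothetical helper for illustration; not exported by this crate.
//! fn unpack_config(packed: u64) -> (u32, u32) {
//!     let opt = (packed & 0xFFFF_FFFF) as u32; // low 32 bits: options
//!     let ty = (packed >> 32) as u32;          // high 32 bits: compiler type code
//!     (opt, ty)
//! }
//! ```
//!
//! The type codes used by `get_config` are 0 for `Native`, 1 for `Amd`, 2 for
//! `AmdAVX`, 3 for `AmdSSE`, 4 for `Arm`, 5 for `RiscV`, 6 for `ByteCode`, and
//! 7 for `Debug`.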
//!

use std::collections::HashSet;
use std::ffi::{c_char, CStr, CString};
use std::fmt::Debug;
use std::str::FromStr;

mod allocator;
mod amd;
mod applet;
mod arm;
mod assembler;
mod block;
mod builder;
mod code;
mod compactor;
pub mod compiler;
mod complexify;
mod composer;
mod config;
mod defuns;
pub mod expr;
mod generator;
pub mod instruction;
mod machine;
mod matrix;
mod memory;
mod mir;
mod model;
mod node;
mod parser;
mod runnable;
mod serializer;
mod statement;
mod symbol;
mod types;
mod utils;

#[allow(non_upper_case_globals)]
mod riscv64;

use matrix::Matrix;
use model::{CellModel, Program};

pub use applet::Applet;
pub use compiler::{Compiler, FastFunc, Translator};
pub use composer::Composer;
pub use config::Config;
pub use defuns::Defuns;
pub use expr::{double, int, var, Expr};
pub use instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
pub use num_complex::{Complex, ComplexFloat};
pub use runnable::{Application, CompilerType};
pub use serializer::MirWriter;
pub use types::{ElemType, Element};
pub use utils::{Compiled, Storage};

#[derive(Debug, Clone, Copy)]
pub enum CompilerStatus {
    Ok,
    Incomplete,
    InvalidUtf8,
    ParseError,
    InvalidCompiler,
    CompilationError,
}

pub struct CompilerResult {
    app: Option<Application>,
    status: CompilerStatus,
    msg: CString,
}

fn error_message<E: Debug>(msg: &str, err: E) -> CString {
    let s = format!("{:?}: {:?}", msg, err);
    CString::from_str(&s).unwrap()
}

/// Compiles a model.
///
/// * `model` is a JSON string encoding the model.
/// * `ty` is the requested arch (amd, arm, native, or bytecode).
/// * `opt`: compilation options.
/// * `df`: user-defined functions.
///
/// # Safety
///
/// * Both `model` and `ty` must be pointers to null-terminated strings.
/// * `df` must be either null or a pointer to a valid `Defuns`.
/// * The output is a raw pointer to a `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn compile(
    model: *const c_char,
    ty: *const c_char,
    opt: u32,
    df: *const Defuns,
) -> *const CompilerResult {
    let mut res = CompilerResult {
        app: None,
        status: CompilerStatus::Incomplete,
        msg: CString::from_str("Success").unwrap(),
    };

    let model = unsafe {
        match CStr::from_ptr(model).to_str() {
            Ok(model) => model,
            Err(msg) => {
                res.status = CompilerStatus::InvalidUtf8;
                res.msg = error_message("Invalid encoding", msg);
                return Box::into_raw(Box::new(res)) as *const _;
            }
        }
    };

    let ty = unsafe {
        match CStr::from_ptr(ty).to_str() {
            Ok(ty) => ty,
            Err(msg) => {
                res.status = CompilerStatus::InvalidUtf8;
                res.msg = error_message("Invalid compiler type", msg);
                return Box::into_raw(Box::new(res)) as *const _;
            }
        }
    };

    let ml = match CellModel::load(model) {
        Ok(ml) => ml,
        Err(msg) => {
            res.status = CompilerStatus::ParseError;
            res.msg = error_message("Cannot parse JSON", msg);
            return Box::into_raw(Box::new(res)) as *const _;
        }
    };

    if let Ok(mut config) = Config::from_name(ty, opt) {
        let df: Defuns = unsafe {
            if df.is_null() {
                Defuns::new()
            } else {
                (&*df).clone()
            }
        };

        config.set_defuns(df);

        let prog = match Program::new(&ml, config) {
            Ok(prog) => prog,
            Err(msg) => {
                res.status = CompilerStatus::CompilationError;
                res.msg = error_message("Compilation error (prog)", msg);
                return Box::into_raw(Box::new(res)) as *const _;
            }
        };

        let app = Application::new(prog, HashSet::new());

        match app {
            Ok(app) => {
                res.status = CompilerStatus::Ok;
                res.app = Some(app);
            }
            Err(msg) => {
                res.status = CompilerStatus::CompilationError;
                res.msg = error_message("Compilation error (app)", &msg);
            }
        }
    } else {
        res.status = CompilerStatus::InvalidCompiler;
        res.msg = error_message("Config error", opt);
    }

    Box::into_raw(Box::new(res)) as *const _
}

/// Translates a Symbolica model.
///
/// * `json` is a JSON string encoding the output of `export_instructions`.
/// * `ty` is the requested arch (amd, arm, native, or bytecode).
/// * `opt`: compilation options.
/// * `df`: user-defined functions (currently ignored).
/// * `num_params`: the number of parameters, forwarded to `Compiler::translate`.
///
/// # Safety
///
/// * Both `json` and `ty` must be pointers to null-terminated strings.
/// * `df` must be either null or a pointer to a valid `Defuns`.
/// * The output is a raw pointer to a `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn translate(
    json: *const c_char,
    ty: *const c_char,
    opt: u32,
    df: *mut Defuns,
    num_params: usize,
) -> *const CompilerResult {
    let mut res = CompilerResult {
        app: None,
        status: CompilerStatus::Incomplete,
        msg: CString::from_str("Success").unwrap(),
    };

    let json = unsafe {
        match CStr::from_ptr(json).to_str() {
            Ok(json) => json,
            Err(msg) => {
                res.status = CompilerStatus::InvalidUtf8;
                res.msg = error_message("Invalid encoding", msg);
                return Box::into_raw(Box::new(res)) as *const _;
            }
        }
    };

    let ty = unsafe {
        match CStr::from_ptr(ty).to_str() {
            Ok(ty) => ty,
            Err(msg) => {
                res.status = CompilerStatus::InvalidUtf8;
                res.msg = error_message("Invalid compiler type", msg);
                return Box::into_raw(Box::new(res)) as *const _;
            }
        }
    };

    if let Ok(mut config) = Config::from_name(ty, opt) {
        let df: Defuns = unsafe {
            if df.is_null() {
                Defuns::new()
            } else {
                (&*df).clone()
            }
        };

        config.set_defuns(df);
        let mut comp = Compiler::with_config(config);
        let app = comp.translate(json.to_string(), num_params);

        match app {
            Ok(app) => {
                res.app = Some(app);
                res.status = CompilerStatus::Ok;
            }
            Err(msg) => {
                res.status = CompilerStatus::InvalidCompiler;
                res.msg = error_message("Compilation error", msg);
            }
        }
    } else {
        res.status = CompilerStatus::InvalidCompiler;
        res.msg = error_message("Config error", opt);
    }

    Box::into_raw(Box::new(res)) as *const _
}

/// Checks the status of a `CompilerResult`.
///
/// Returns a null-terminated string representing the status message.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn check_status(q: *const CompilerResult) -> *const c_char {
    let q: &CompilerResult = unsafe { &*q };
    q.msg.as_ptr() as *const _
}

/// Saves the compiled application held by a `CompilerResult` to a file.
///
/// Returns `true` on success. Returns `false` if `file` is not valid UTF-8,
/// if there is no compiled application, or if the file cannot be created
/// or written.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult` and that `file` points
/// to a null-terminated string.
///
#[no_mangle]
pub unsafe extern "C" fn save(q: *const CompilerResult, file: *const c_char) -> bool {
    let q: &CompilerResult = unsafe { &*q };
    let file = unsafe {
        match CStr::from_ptr(file).to_str() {
            Ok(file) => file,
            Err(_) => return false,
        }
    };

    if let Some(app) = &q.app {
        if let Ok(mut fs) = std::fs::File::create(file) {
            app.save(&mut fs).is_ok()
        } else {
            false
        }
    } else {
        false
    }
}

/// Loads a previously saved application from a file.
///
/// Returns a raw pointer to a new `CompilerResult`. Use `check_status`
/// to retrieve the status message.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `file` points to a null-terminated string and that `df` is
/// either null or a pointer to a valid `Defuns`.
///
#[no_mangle]
pub unsafe extern "C" fn load(file: *const c_char, df: *mut Defuns) -> *const CompilerResult {
    let mut res = CompilerResult {
        app: None,
        status: CompilerStatus::Incomplete,
        msg: CString::from_str("Success").unwrap(),
    };

    let df: Defuns = unsafe {
        if df.is_null() {
            Defuns::new()
        } else {
            (&*df).clone()
        }
    };

    let file = unsafe {
        match CStr::from_ptr(file).to_str() {
            Ok(file) => file,
            Err(msg) => {
                // Report the failure instead of returning the default "Success" message.
                res.status = CompilerStatus::InvalidUtf8;
                res.msg = error_message("Invalid encoding", msg);
                return Box::into_raw(Box::new(res)) as *const _;
            }
        }
    };

    let fs = std::fs::File::open(file);

    match fs {
        Ok(mut fs) => match Application::load(&mut fs, &Config::from_defuns(df).unwrap()) {
            Ok(app) => {
                res.app = Some(app);
                res.status = CompilerStatus::Ok;
            }
            Err(err) => {
                res.status = CompilerStatus::ParseError;
                res.msg = error_message("File parse error", &err);
            }
        },
        Err(err) => {
            res.msg = error_message("File I/O error", &err);
        }
    }

    Box::into_raw(Box::new(res)) as *const _
}

/// Returns the configuration of a `CompilerResult`, packed into a `usize`.
///
/// The result is `config.opt` combined with the compiler type code shifted
/// left by 32 bits. The type codes are listed in the `match` below.
/// Returns 0 if there is no compiled application.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn get_config(q: *const CompilerResult) -> usize {
    let q: &CompilerResult = unsafe { &*q };

    match &q.app {
        Some(app) => {
            let config = app.prog.config();

            let ty: usize = match config.ty {
                CompilerType::Native => 0,
                CompilerType::Amd => 1,
                CompilerType::AmdAVX => 2,
                CompilerType::AmdSSE => 3,
                CompilerType::Arm => 4,
                CompilerType::RiscV => 5,
                CompilerType::ByteCode => 6,
                CompilerType::Debug => 7,
            };

            (config.opt as usize) | (ty << 32)
        }
        None => 0,
    }
}

/// Returns the number of state variables.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn count_states(q: *const CompilerResult) -> usize {
    let q: &CompilerResult = unsafe { &*q };
    if let Some(app) = &q.app {
        app.count_states
    } else {
        0
    }
}

/// Returns the number of parameters.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn count_params(q: *const CompilerResult) -> usize {
    let q: &CompilerResult = unsafe { &*q };
    if let Some(app) = &q.app {
        app.count_params
    } else {
        0
    }
}

/// Returns the number of observables (outputs).
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn count_obs(q: *const CompilerResult) -> usize {
    let q: &CompilerResult = unsafe { &*q };
    if let Some(app) = &q.app {
        app.count_obs
    } else {
        0
    }
}

/// Returns the number of differential equations.
///
/// Generally, it should be the same as the number of states.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn count_diffs(q: *const CompilerResult) -> usize {
    let q: &CompilerResult = unsafe { &*q };
    if let Some(app) = &q.app {
        app.count_diffs
    } else {
        0
    }
}

/// Deprecated. Previously used for interfacing to DifferentialEquations.jl. It is
/// replaced with <https://github.com/siravan/SymJit.jl>.
///
/// # Safety
///
/// Deprecated. No effects; always returns `false`.
#[no_mangle]
pub unsafe extern "C" fn run(
    _q: *mut CompilerResult,
    _du: *mut f64,
    _u: *const f64,
    _ns: usize,
    _p: *const f64,
    _np: usize,
    _t: f64,
) -> bool {
    // let q: &mut CompilerResult = unsafe { &mut *q };

    // if let Some(app) = &mut q.app {
    //     if app.count_states != ns || app.count_params != np {
    //         return false;
    //     }

    //     let du: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(du, ns) };
    //     let u: &[f64] = unsafe { std::slice::from_raw_parts(u, ns) };
    //     let p: &[f64] = unsafe { std::slice::from_raw_parts(p, np) };
    //     app.call(du, u, p, t);
    //     true
    // } else {
    //     false
    // }
    false
}

/// Executes the compiled function.
///
/// The calling routine should fill the states and parameters before
/// calling `execute`. The result populates obs or diffs (as defined in
/// the model passed to `compile`).
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn execute(q: *mut CompilerResult) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };

    if let Some(app) = &mut q.app {
        app.exec();
        true
    } else {
        false
    }
}

/// Executes the compiled function `n` times (vectorized).
///
/// The calling function provides `buf`, which is a k x n matrix of doubles,
/// where k is equal to `max(count_states, count_obs)`. The calling
/// function fills the first `count_states` rows of `buf`. The result is returned
/// in the first `count_obs` rows of `buf`.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
/// In addition, `buf` should point to a valid matrix of the correct size.
///
#[no_mangle]
pub unsafe extern "C" fn execute_vectorized(
    q: *mut CompilerResult,
    buf: *mut f64,
    n: usize,
) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };

    if let Some(app) = &mut q.app {
        let h = usize::max(app.count_states, app.count_obs);
        let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
        let mut states = Matrix::from_buf(p, h, n);
        let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
        let mut obs = Matrix::from_buf(p, h, n);
        app.exec_vectorized(&mut states, &mut obs);
        true
    } else {
        false
    }
}

/// Evaluates the compiled function. This is for Symbolica compatibility.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`, that `args` points to
/// `nargs` doubles, and that `outs` points to `nouts` doubles.
///
#[no_mangle]
pub unsafe extern "C" fn evaluate(
    q: *mut CompilerResult,
    args: *const f64,
    nargs: usize,
    outs: *mut f64,
    nouts: usize,
) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };

    if let Some(app) = &mut q.app {
        let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
        let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
        app.evaluate(args, outs);
        true
    } else {
        false
    }
}

/// Evaluates the compiled function for `nargs / count_params` sets of
/// arguments stored contiguously in `args`. This is for Symbolica compatibility.
///
/// Returns `false` if there is no compiled application or if
/// `count_params` is zero.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`, that `args` points to
/// `nargs` doubles, and that `outs` points to `nouts` doubles.
///
#[no_mangle]
pub unsafe extern "C" fn evaluate_matrix(
    q: *mut CompilerResult,
    args: *const f64,
    nargs: usize,
    outs: *mut f64,
    nouts: usize,
) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };

    if let Some(app) = &mut q.app {
        if app.count_params == 0 {
            return false;
        }

        let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
        let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
        let n = nargs / app.count_params;
        app.evaluate_matrix(args, outs, n);
        true
    } else {
        false
    }
}

/// Returns a pointer to the state variables (`count_states` doubles).
///
/// The function calling `execute` should write the state variables in this area.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn ptr_states(q: *mut CompilerResult) -> *mut f64 {
    let q: &mut CompilerResult = unsafe { &mut *q };
    if let Some(app) = &mut q.app {
        if let Some(f) = &mut app.compiled {
            &mut f.mem_mut()[app.first_state] as *mut f64
        } else {
            &mut app.bytecode.mem_mut()[app.first_state] as *mut f64
        }
    } else {
        std::ptr::null_mut()
    }
}

/// Returns a pointer to the parameters (`count_params` doubles).
///
/// The function calling `execute` should write the parameters in this area.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn ptr_params(q: *mut CompilerResult) -> *mut f64 {
    let q: &mut CompilerResult = unsafe { &mut *q };
    if let Some(app) = &mut q.app {
        //&mut app.compiled.mem_mut()[app.first_param] as *mut f64
        &mut app.params[app.first_param] as *mut f64
    } else {
        std::ptr::null_mut()
    }
}

/// Returns a pointer to the observables (`count_obs` doubles).
///
/// The function calling `execute` reads the observables from this area.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn ptr_obs(q: *mut CompilerResult) -> *const f64 {
    let q: &CompilerResult = unsafe { &*q };
    if let Some(app) = &q.app {
        if let Some(f) = &app.compiled {
            &f.mem()[app.first_obs] as *const f64
        } else {
            &app.bytecode.mem()[app.first_obs] as *const f64
        }
    } else {
        std::ptr::null()
    }
}

/// Returns a pointer to the differentials (`count_diffs` doubles).
///
/// The function calling `execute` reads the differentials from this area.
///
/// Note: whether the output is returned as observables or differentials is
/// defined in the model.
///
/// # Safety
///
/// It is the responsibility of the calling function to ensure
/// that `q` points to a valid `CompilerResult`.
///
#[no_mangle]
pub unsafe extern "C" fn ptr_diffs(q: *mut CompilerResult) -> *const f64 {
    let q: &CompilerResult = unsafe { &*q };
    if let Some(app) = &q.app {
        if let Some(f) = &app.compiled {
            &f.mem()[app.first_diff] as *const f64
        } else {
            &app.bytecode.mem()[app.first_diff] as *const f64
        }
    } else {
        std::ptr::null()
    }
}

972/// Dumps the compiled binary code to a file (`name`).
973///
974/// This function is useful for debugging but is not necessary for
975/// normal operations.
976///
977/// # Safety
978///     it is the responsibility of the calling function to ensure
979///     that q points to a valid CompilerResult.
980///
#[no_mangle]
pub unsafe extern "C" fn dump(
    q: *mut CompilerResult,
    name: *const c_char,
    what: *const c_char,
) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };
    if let Some(app) = &mut q.app {
        let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
        let what = unsafe { CStr::from_ptr(what).to_str().unwrap() };
        app.dump(name, what)
    } else {
        false
    }
}

/// Deallocates the CompilerResult pointed to by `q`.
///
/// A null `q` is ignored.
///
/// # Safety
///     it is the responsibility of the calling function to ensure
///     that q points to a valid CompilerResult and that after
///     calling this function, q is invalid and should not
///     be used anymore.
///
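/// # Examples
///
/// A minimal sketch of the ownership hand-off that `finalize` completes,
/// assuming the handle was originally produced with `Box::into_raw`.
/// `Handle`, `make_handle`, and `free_handle` are hypothetical stand-ins
/// and are not part of this crate.
///
/// ```rust
/// // Hypothetical stand-in for `CompilerResult`.
/// struct Handle {
///     value: f64,
/// }
///
/// // Moves a heap allocation out of Rust's ownership and into a raw pointer.
/// fn make_handle(value: f64) -> *mut Handle {
///     Box::into_raw(Box::new(Handle { value }))
/// }
///
/// // Reclaims and drops the allocation; a null pointer is ignored.
/// unsafe fn free_handle(p: *mut Handle) {
///     if !p.is_null() {
///         drop(unsafe { Box::from_raw(p) });
///     }
/// }
///
/// let p = make_handle(1.5);
/// assert_eq!(unsafe { (*p).value }, 1.5);
/// unsafe { free_handle(p) };
/// // `p` is dangling from here on and must not be dereferenced or freed again.
/// unsafe { free_handle(std::ptr::null_mut()) };
/// ```
///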
#[no_mangle]
pub unsafe extern "C" fn finalize(q: *mut CompilerResult) {
    if !q.is_null() {
        let _ = unsafe { Box::from_raw(q) };
    }
}

/// Returns a null-terminated string representing the version.
///
/// Used for debugging. Each call allocates a new string, which is never
/// deallocated.
///
/// # Safety
///     the return value is a null-terminated string that should not
///     be freed.
///
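/// # Examples
///
/// A minimal sketch of how a caller on the Rust side might read such a
/// string without taking ownership of it. `version_ptr` is a hypothetical
/// stand-in for `info`, and the version text is made up for illustration.
///
/// ```rust
/// use std::ffi::{c_char, CStr, CString};
///
/// // Hypothetical stand-in for `info`: hands out a leaked C string.
/// fn version_ptr() -> *const c_char {
///     CString::new("1.2.3").unwrap().into_raw() as *const c_char
/// }
///
/// let p = version_ptr();
/// // Borrow the bytes only; the allocation behind `p` stays alive.
/// let s = unsafe { CStr::from_ptr(p) }.to_str().unwrap();
/// assert_eq!(s, "1.2.3");
/// ```
///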
#[no_mangle]
pub unsafe extern "C" fn info() -> *const c_char {
    // let msg = c"symjit 1.3.3";
    let msg = CString::new(env!("CARGO_PKG_VERSION")).unwrap();
    msg.into_raw() as *const _
}

/// Returns a pointer to the fast function if one can be compiled.
///
/// # Safety
///     1. q should point to a valid CompilerResult.
///     2. If the model cannot be compiled to a fast function, NULL is returned.
///     3. A fast function code memory is leaked and is not deallocated.
///
#[no_mangle]
pub unsafe extern "C" fn fast_func(q: *mut CompilerResult) -> *const usize {
    let q: &mut CompilerResult = unsafe { &mut *q };
    if let Some(app) = &mut q.app {
        match app.get_fast() {
            Some(f) => f as *const usize,
            None => std::ptr::null(),
        }
    } else {
        std::ptr::null()
    }
}

/// Interface for SciPy's LowLevelCallable (numerical integration).
///
/// Returns NaN if `q.app` is `None`.
///
/// # Safety
///     1. q should point to a valid CompilerResult and should remain valid
///         for as long as the callable is in use.
///     2. xx should point to a valid array of doubles of length at least n.
///
#[no_mangle]
pub unsafe extern "C" fn callable_quad(n: usize, xx: *const f64, q: *mut CompilerResult) -> f64 {
    let q: &mut CompilerResult = unsafe { &mut *q };
    let xx: &[f64] = unsafe { std::slice::from_raw_parts(xx, n) };

    if let Some(app) = &mut q.app {
        app.exec_callable(xx)
    } else {
        f64::NAN
    }
}

/// Interface for SciPy's LowLevelCallable (numerical integration) that calls
/// a fast function directly.
///
/// `f` is called with the first `n` elements of `xx` as its arguments.
/// Panics if `n` is greater than 7.
///
/// # Safety
///     1. f should point to a function that accepts exactly n double
///         arguments and returns a double (for example, a non-null
///         pointer returned by fast_func).
///     2. xx should point to a valid array of doubles of length at least n.
///
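/// # Examples
///
/// A minimal sketch of the arity dispatch used here: a type-erased pointer
/// is converted back to a function pointer whose argument count matches the
/// length of the input slice. `add` and `call_erased` are hypothetical
/// stand-ins; an ordinary Rust function takes the place of JIT-compiled
/// code, and only two arities are shown.
///
/// ```rust
/// // Hypothetical stand-in for a compiled two-argument function.
/// fn add(x: f64, y: f64) -> f64 {
///     x + y
/// }
///
/// // Calls `f` with the arity implied by `xx.len()`.
/// // The caller must guarantee that `f` really has that signature.
/// unsafe fn call_erased(f: *const usize, xx: &[f64]) -> f64 {
///     match xx.len() {
///         1 => {
///             let f: fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
///             f(xx[0])
///         }
///         2 => {
///             let f: fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
///             f(xx[0], xx[1])
///         }
///         // Unsupported arities yield NaN in this sketch.
///         _ => f64::NAN,
///     }
/// }
///
/// // Erase the signature by casting the function pointer to a raw pointer.
/// let g: fn(f64, f64) -> f64 = add;
/// let p = g as *const usize;
/// assert_eq!(unsafe { call_erased(p, &[1.0, 2.0]) }, 3.0);
/// ```
///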
#[no_mangle]
pub unsafe extern "C" fn callable_quad_fast(n: usize, xx: *const f64, f: *const usize) -> f64 {
    let xx: &[f64] = unsafe { std::slice::from_raw_parts(xx, n) };

    match n {
        0 => {
            let f: fn() -> f64 = unsafe { std::mem::transmute(f) };
            f()
        }
        1 => {
            let f: fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0])
        }
        2 => {
            let f: fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1])
        }
        3 => {
            let f: fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2])
        }
        4 => {
            let f: fn(f64, f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3])
        }
        5 => {
            let f: fn(f64, f64, f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4])
        }
        6 => {
            let f: fn(f64, f64, f64, f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4], xx[5])
        }
        7 => {
            let f: fn(f64, f64, f64, f64, f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4], xx[5], xx[6])
        }
        _ => {
            panic!("too many parameters for a fast func");
        }
    }
}

/// Interface for SciPy's LowLevelCallable (image filtering).
///
/// Writes the result to `return_value` and returns 1 on success. Returns 0
/// if `q.app` is `None`.
///
/// # Safety
///     1. q should point to a valid CompilerResult and should remain valid
///         for as long as the callable is in use.
///     2. buffer should point to a valid array of doubles of length at
///         least filter_size.
///     3. return_value should point to a writable double.
///
#[no_mangle]
pub unsafe extern "C" fn callable_filter(
    buffer: *const f64,
    filter_size: usize,
    return_value: *mut f64,
    q: *mut CompilerResult,
) -> i64 {
    let q: &mut CompilerResult = unsafe { &mut *q };
    let xx: &[f64] = unsafe { std::slice::from_raw_parts(buffer, filter_size) };

    if let Some(app) = &mut q.app {
        let p: &mut f64 = unsafe { &mut *return_value };
        *p = app.exec_callable(xx);
        1
    } else {
        0
    }
}

/************************************************/

/// Creates an empty Matrix (a 2d array).
///
/// # Safety
///     It returns a pointer to the allocated Matrix, which needs to be
///     deallocated eventually by calling finalize_matrix.
///
#[no_mangle]
pub unsafe extern "C" fn create_matrix<'a>() -> *const Matrix<'a> {
    let mat = Matrix::new();
    Box::into_raw(Box::new(mat)) as *const Matrix
}

/// Finalizes (deallocates) the Matrix.
///
/// # Safety
///     1. mat should point to a valid Matrix object created by create_matrix.
///     2. After finalize_matrix is called, mat is invalid.
///
#[no_mangle]
pub unsafe extern "C" fn finalize_matrix(mat: *mut Matrix) {
    if !mat.is_null() {
        let _ = unsafe { Box::from_raw(mat) };
    }
}

/// Adds a row to the Matrix.
///
/// # Safety
///     1. mat should point to a valid Matrix object created by create_matrix.
///     2. v should point to a valid array of doubles of length at least n.
///     3. v should remain valid for the lifespan of mat.
///
#[no_mangle]
pub unsafe extern "C" fn add_row(mat: *mut Matrix, v: *mut f64, n: usize) {
    let mat: &mut Matrix = unsafe { &mut *mat };
    mat.add_row(v, n);
}

/// Executes (runs) the matrix model encoded by `q`.
///
/// Returns `false` if `q.app` is `None`.
///
/// # Safety
///     1. q should point to a valid CompilerResult object.
///     2. states should point to a valid Matrix of at least count_states rows.
///     3. obs should point to a valid Matrix of at least count_obs rows.
///
#[no_mangle]
pub unsafe extern "C" fn execute_matrix(
    q: *mut CompilerResult,
    states: *mut Matrix,
    obs: *mut Matrix,
) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };
    let states: &mut Matrix = unsafe { &mut *states };
    let obs: &mut Matrix = unsafe { &mut *obs };

    if let Some(app) = &mut q.app {
        app.exec_vectorized(states, obs);
        true
    } else {
        false
    }
}

/************************************************/

/// Creates an empty `Defuns` (a list of user-defined functions).
///
/// `Defuns` are used to pass user-defined functions (either Python
/// functions or symjit-compiled functions).
///
/// # Safety
///     It returns a pointer to the allocated Defuns, which needs to be
///     deallocated eventually.
///
#[no_mangle]
pub unsafe extern "C" fn create_defuns() -> *const Defuns {
    let df = Defuns::new();
    Box::into_raw(Box::new(df)) as *const Defuns
}

/// Finalizes a `Defuns`.
///
/// Currently a no-op: the deallocation code is commented out, so the
/// `Defuns` is not freed.
///
/// # Safety
///     1. df should point to a valid Defuns object created by create_defuns.
///     2. After finalize_defuns is called, df should not be used anymore.
///
#[no_mangle]
pub unsafe extern "C" fn finalize_defuns(_df: *mut Defuns) {
    // if !df.is_null() {
    //     let _ = unsafe { Box::from_raw(df) };
    // }
}

/// Adds a new function to a `Defuns`.
///
/// # Safety
///     1. df should point to a valid Defuns object created by create_defuns.
///     2. name should be a valid null-terminated UTF-8 string.
///     3. p should point to a valid C-style function pointer that accepts
///         num_args double arguments.
///
#[no_mangle]
pub unsafe extern "C" fn add_func(
    df: *mut Defuns,
    name: *const c_char,
    p: *const usize,
    num_args: usize,
) {
    let df: &mut Defuns = unsafe { &mut *df };
    let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
    df.add_func(name, p, num_args);
}