symjit/lib.rs
1#![allow(uncommon_codepoints)]
2
3//! Symjit (<https://github.com/siravan/symjit>) is a lightweight just-in-time (JIT)
//! optimizing compiler for mathematical expressions written in Rust. It was originally
5//! designed to compile SymPy (Python’s symbolic algebra package) expressions
6//! into machine code and to serve as a bridge between SymPy and numerical routines
7//! provided by NumPy and SciPy libraries.
8//!
9//! Symjit crate is the core compiler coupled to a Rust interface to expose the
10//! JIT functionality to the Rust ecosystem and allow Rust applications to
11//! generate code dynamically. Considering its origin, symjit is geared toward
12//! compiling mathematical expressions instead of being a general-purpose JIT
//! compiler. Therefore, the only supported types for variables are `f64`
//! (including its SIMD forms f64x4 and f64x2) and, implicitly, `bool` and `i32`.
15//!
//! Symjit emits AMD64 (x86-64), ARM64 (aarch64), and 64-bit RISC-V (riscv64) machine
//! code on Linux, Windows, and macOS platforms. SIMD is supported on x86-64
18//! and ARM64.
19//!
//! In Rust, there are two ways to construct expressions to pass to Symjit: using
//! Symbolica or using Symjit's standalone expression builder.
22//!
23//! # Symbolica
24//!
25//! Symbolica (<https://symbolica.io/>) is a fast Rust-based Computer Algebra System.
//! As of version 1.5 of Symbolica, Symjit is an optional backend for Symbolica. Therefore,
//! the previous Symbolica interface through the `Compiler` object is considered obsolete
//! and should not be used for new projects.
29//!
30//! # Standalone Expression Builder
31//!
32//! A second way to use Symjit is by using its standalone expression builder. Compared to
//! Symbolica, the expression builder is limited but is useful in situations where the goal
34//! is to compile an expression without extensive symbolic manipulations.
35//!
36//! The workflow to create, compile, and run expressions is:
37//!
38//! 1. Create terminals (variables and constants) and compose expressions using `Expr` methods:
39//! * Constructors: `var`, `from`, `unary`, `binary`, ...
40//! * Standard algebraic operations: `add`, `mul`, ...
41//! * Standard operators `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `!`.
42//! * Unary functions such as `sin`, `exp`, and other standard mathematical functions.
43//! * Binary functions such as `pow`, `min`, ...
44//! * IfElse operation `ifelse(cond, true_val, false_val)`.
45//! * Heaviside function: `heaviside(x)`, which returns 1 if `x >= 0`; otherwise 0.
46//! * Comparison methods `eq`, `ne`, `lt`, `le`, `gt`, and `ge`.
47//! * Looping constructs `sum` and `prod`.
48//! 2. Create a new `Compiler` object (say, `comp`) using one of its constructors.
49//! 3. Define user-defined functions by calling `comp.def_unary` and `comp.def_binary`
50//! (optional).
51//! 4. Compile by calling `comp.compile` or `comp.compile_params`. The result is of
52//! type `Application` (say, `app`).
53//! 5. Execute the compiled code using one of the `app`'s `call` functions:
54//! * `call(&[f64])`: scalar call.
55//! * `call_params(&[f64], &[f64])`: scalar call with parameters.
56//! * `call_simd(&[__m256d])`: simd call.
57//! * `call_simd_params(&[__m256d], &[f64])`: simd call with parameters.
58//! 6. Optionally, generate a standalone fast function to execute.
59//!
60//! Note that you can use the helper functions `var(&str) -> Expr`, `int(i32) -> Expr`,
61//! `double(f64) -> Expr`, and `boolean(bool) -> f64` to reduce clutter.
62//!
63//! # Examples
64//!
65//! ```rust
66//! use anyhow::Result;
67//! use symjit::{Compiler, Expr};
68//!
69//! pub fn test_scalar() -> Result<()> {
70//! let x = Expr::var("x");
71//! let y = Expr::var("y");
72//! let u = &x + &y;
73//! let v = &x * &y;
74//!
75//! let mut comp = Compiler::new();
76//! let mut app = comp.compile(&[x, y], &[u, v])?;
77//! let res = app.call(&[3.0, 5.0]);
78//! println!("{:?}", &res); // prints [8.0, 15.0]
79//!
80//! Ok(())
81//! }
82//! ```
83//!
84//! `test_scalar` is similar to the following basic example in Python/SymPy:
85//!
86//! ```python
87//! from symjit import compile_func
88//! from sympy import symbols
89//!
90//! x, y = symbols('x y')
91//! f = compile_func([x, y], [x+y, x*y])
92//! print(f(3.0, 5.0)) # prints [8.0, 15.0]
93//! ```
94//!
95//! A more elaborate example, showcasing having a parameter, changing the
96//! optimization level, and using SIMD:
97//!
98//! ```rust
99//! use anyhow::Result;
100//! use symjit::{var, Compiler, Expr};
101//!
102//! pub fn test_simd() -> Result<()> {
103//! use std::arch::x86_64::_mm256_loadu_pd;
104//!
105//! let x = var("x"); // note var instead of Expr::var
106//! let p = var("p"); // the parameter
107//!
108//! let u = &x.square() * &p; // x^2 * p
109//! let mut comp = Compiler::new();
110//! comp.opt_level(2); // optional (opt_level 0 to 2; default 1)
111//! let mut app = comp.compile_params(&[x], &[u], &[p])?;
112//!
113//! let a = &[1.0, 2.0, 3.0, 4.0];
114//! let a = unsafe { _mm256_loadu_pd(a.as_ptr()) };
115//! let res = app.call_simd_params(&[a], &[5.0])?;
116//! println!("{:?}", &res); // prints [__m256d(5.0, 20.0, 45.0, 80.0)]
117//! Ok(())
118//! }
119//! ```
120//!
121//! # Conditional Expression and Loops
122//!
123//! Many mathematical formulas need conditional expressions (`ifelse`) and loops.
124//! Following SymPy, Symjit uses reduction loops such as `sum` and `prod`. The following
//! example computes the exponential function from its Taylor series:
126//!
127//! ```rust
128//! use symjit::{int, var, Compiler};
129//!
130//! fn test_exp() -> Result<()> {
131//! let x = var("x");
132//! let i = var("i"); // loop variable
133//! let j = var("j"); // loop variable
134//!
//! // u = x^j / factorial(j) for j in 0..=50
136//! let u = x
137//! .pow(&j)
138//! .div(&i.prod(&i, &int(1), &j))
139//! .sum(&j, &int(0), &int(50));
140//!
141//! let mut app = Compiler::new().compile(&[x], &[u])?;
//! println!("{:?}", app.call(&[2.0])[0]); // returns exp(2.0) = 7.38905...
143//! Ok(())
144//! }
145//! ```
146//!
147//! An example showing how to calculate pi using the Leibniz formula:
148//!
149//! ```rust
150//! use symjit::{int, var, Compiler};
151//!
152//! fn test_pi() -> Result<()> {
153//! let n = var("n");
154//! let i = var("i"); // loop variable
155//! let j = var("j"); // loop variable
156//!
157//! // numer = if j % 2 == 0 { 4 } else { -4 }
158//! let numer = j.rem(&int(2)).eq(&int(0)).ifelse(&int(4), &int(-4));
159//! // denom = j * 2 + 1
160//! let denom = j.mul(&int(2)).add(&int(1));
161//! // v = numer / denom for j in 0..=n
//! let v = (&numer / &denom).sum(&j, &int(0), &n);
//!
//! let mut app = Compiler::new().compile(&[n], &[v])?;
//! println!("{:?}", app.call(&[100000000.0])[0]); // returns pi
166//! Ok(())
167//! }
168//! ```
169//!
170//! Note that here we are using explicit functions (`add`, `mul`, ...) instead of
171//! the overloaded operators for clarity.
172//!
173//! # Fast Functions
174//!
175//! `Application`'s call functions need to copy the input slice into the function
176//! memory area and then copy the output to a `Vec`. This process is acceptable
177//! for large and complex functions but incurs a penalty for small ones.
178//! Therefore, for a certain subset of applications, Symjit can compile to a
179//! *fast function* and return a function pointer. Examples:
180//!
181//! ```rust
182//! use anyhow::Result;
183//! use symjit::{int, var, Compiler, FastFunc};
184//!
185//! fn test_fast() -> Result<()> {
186//! let x = var("x");
187//! let y = var("y");
188//! let z = var("z");
189//! let u = &x * &(&y - &z).pow(&int(2)); // x * (y - z)^2
190//!
191//! let mut comp = Compiler::new();
192//! let mut app = comp.compile(&[x, y, z], &[u])?;
193//! let f = app.fast_func()?;
194//!
195//! if let FastFunc::F3(f, _) = f {
196//! // f is of type extern "C" fn(f64, f64, f64) -> f64
197//! let res = f(3.0, 5.0, 9.0);
198//! println!("fast\t{:?}", &res);
199//! }
200//!
201//! Ok(())
202//! }
203//! ```
204//!
205//! The conditions for a fast function are:
206//!
207//! * A fast function can have 1 to 8 arguments.
208//! * No SIMD and no parameters.
209//! * It returns only a single value.
210//!
211//! If these conditions are met, you can generate a fast function by calling
212//! `app.fast_func()`, which returns a `Result<FastFunc>`. `FastFunc` is an
213//! enum with eight variants `F1`, `F2`, ..., `F8`, corresponding to functions
214//! with 1 to 8 arguments.
215//!
216//! # User-Defined Functions
217//!
218//! Symjit functions can call into user-defined Rust functions. Currently,
219//! only the following function signatures are accepted:
220//!
221//! ```rust
222//! pub type UnaryFunc = extern "C" fn(f64) -> f64;
223//! pub type BinaryFunc = extern "C" fn(f64, f64) -> f64;
224//! ```
225//!
226//! For example:
227//!
228//! ```rust
229//! extern "C" fn f(x: f64) -> f64 {
230//! x.exp()
231//! }
232//!
233//! extern "C" fn g(x: f64, y: f64) -> f64 {
234//! x.ln() * y
235//! }
236//!
237//! fn test_external() -> Result<()> {
238//! let x = Expr::var("x");
239//! let u = Expr::unary("f_", &x);
240//! let v = &x * &Expr::binary("g_", &u, &x);
241//!
242//! // v(x) = x * (ln(exp(x)) * x) = x ^ 3
243//!
244//! let mut comp = Compiler::new();
245//! comp.def_unary("f_", f);
246//! comp.def_binary("g_", g);
247//! let mut app = comp.compile(&[x], &[v])?;
248//! println!("{:?}", app.call(&[5.0])[0]);
249//!
250//! Ok(())
251//! }
252//! ```
253//!
254//! # Dynamic Expressions
255//!
256//! All the examples up to this point use static expressions. Of course, it
257//! would have been easier just to use Rust expressions for these examples!
258//! The main utility of Symjit for Rust is for dynamic code generation. Here,
259//! we provide a simple example to calculate pi using Viete's formula
260//! (<https://en.wikipedia.org/wiki/Vi%C3%A8te%27s_formula>):
261//!
262//! ```rust
263//! fn test_pi_viete(silent: bool) -> Result<()> {
264//! let x = var("x");
265//! let mut u = int(1);
266//!
267//! for i in 0..50 {
268//! let mut t = x.clone();
269//!
270//! for _ in 0..i {
271//! t = &x + &(&x * &t.sqrt());
272//! }
273//!
274//! u = &u * &t.sqrt();
275//! }
276//!
277//! // u has 1275 = 50 * 51 / 2 sqrt operations
278//! let mut app = Compiler::new().compile(&[x], &[&int(2) / &u])?;
279//! println!("pi = \t{:?}", app.call(&[0.5])[0]);
280//! Ok(())
281//! }
282//! ```
283//!
284//! # C-Interface
285//!
286//! In addition to `Compiler`, this crate provides a C-style interface
287//! used by the Python (<https://github.com/siravan/symjit>) and Julia
288//! (<https://github.com/siravan/Symjit.jl>) packages. This interface
//! consists of crate functions such as `compile`, `execute`, and
//! `ptr_states`. It is not needed by the Rust interface but can be
//! used to link symjit to other programming languages.
292//!
293
294use std::collections::HashSet;
295use std::ffi::{c_char, CStr, CString};
296use std::fmt::Debug;
297use std::str::FromStr;
298
299mod allocator;
300mod amd;
301mod applet;
302mod arm;
303mod assembler;
304mod block;
305mod builder;
306mod code;
307mod compactor;
308pub mod compiler;
309mod complexify;
310mod composer;
311mod config;
312mod defuns;
313pub mod expr;
314mod generator;
315pub mod instruction;
316mod machine;
317mod matrix;
318mod memory;
319mod mir;
320mod model;
321mod node;
322mod parser;
323mod runnable;
324mod serializer;
325mod statement;
326mod symbol;
327mod types;
328mod utils;
329
330#[allow(non_upper_case_globals)]
331mod riscv64;
332
333use matrix::Matrix;
334use model::{CellModel, Program};
335
336pub use applet::Applet;
337pub use compiler::{Compiler, FastFunc, Translator};
338pub use composer::Composer;
339pub use config::Config;
340pub use defuns::Defuns;
341pub use expr::{double, int, var, Expr};
342pub use instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
343pub use num_complex::{Complex, ComplexFloat};
344pub use runnable::{Application, CompilerType};
345pub use serializer::MirWriter;
346pub use types::{ElemType, Element};
347pub use utils::{Compiled, Storage};
348
/// Outcome classification stored in a `CompilerResult` by the C-interface
/// entry points.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompilerStatus {
    /// The operation succeeded and the result holds an application.
    Ok,
    /// Initial value; left unchanged by failures that have no more specific
    /// status (e.g. a file that cannot be opened in `load`).
    Incomplete,
    /// An input C string was not valid UTF-8.
    InvalidUtf8,
    /// The model JSON or a saved application could not be parsed.
    ParseError,
    /// The requested compiler configuration (type or options) was rejected.
    InvalidCompiler,
    /// Building the program or the application failed.
    CompilationError,
}
358
/// Outcome of the C-interface constructors (`compile`, `translate`, `load`).
///
/// These constructors box the value and hand it to the caller as a raw
/// pointer; it is released with `finalize`.
pub struct CompilerResult {
    /// The compiled application, or `None` if the operation failed.
    app: Option<Application>,
    /// Classification of the operation's outcome.
    status: CompilerStatus,
    /// Null-terminated message returned by `check_status`. It starts as
    /// "Success" and is replaced when an error is recorded.
    msg: CString,
}
364
/// Builds the null-terminated diagnostic stored in `CompilerResult::msg`.
///
/// The message has the form `<msg>: <err:?>`. The label is inserted verbatim
/// (not `Debug`-quoted). A `CString` cannot contain interior NUL bytes, so any
/// NUL produced by the `Debug` rendering of `err` is escaped as `\0`. This
/// avoids a panic, which would otherwise unwind out of the `extern "C"`
/// functions that call this helper.
fn error_message<E: Debug>(msg: &str, err: E) -> CString {
    let text = format!("{msg}: {err:?}").replace('\0', "\\0");
    CString::new(text).expect("interior NUL bytes were escaped")
}
369
370/// Compiles a model.
371///
372/// * `model` is a json string encoding the model.
373/// * `ty` is the requested arch (amd, arm, native, or bytecode).
374/// * `opt`: compilation options.
375/// * `df`: user-defined functions.
376///
377/// # Safety
378/// * both model and ty are pointers to null-terminated strings.
379/// * The output is a raw pointer to a CompilerResults.
380///
381#[no_mangle]
382pub unsafe extern "C" fn compile(
383 model: *const c_char,
384 ty: *const c_char,
385 opt: u32,
386 df: *const Defuns,
387) -> *const CompilerResult {
388 let mut res = CompilerResult {
389 app: None,
390 status: CompilerStatus::Incomplete,
391 msg: CString::from_str("Success").unwrap(),
392 };
393
394 let model = unsafe {
395 match CStr::from_ptr(model).to_str() {
396 Ok(model) => model,
397 Err(msg) => {
398 res.status = CompilerStatus::InvalidUtf8;
399 res.msg = error_message("Invalid encoding", msg);
400 return Box::into_raw(Box::new(res)) as *const _;
401 }
402 }
403 };
404
405 let ty = unsafe {
406 match CStr::from_ptr(ty).to_str() {
407 Ok(ty) => ty,
408 Err(msg) => {
409 res.status = CompilerStatus::InvalidUtf8;
410 res.msg = error_message("Invalid compiler type", msg);
411 return Box::into_raw(Box::new(res)) as *const _;
412 }
413 }
414 };
415
416 let ml = match CellModel::load(model) {
417 Ok(ml) => ml,
418 Err(msg) => {
419 res.status = CompilerStatus::ParseError;
420 res.msg = error_message("Cannot parse JSON", msg);
421 return Box::into_raw(Box::new(res)) as *const _;
422 }
423 };
424
425 if let Ok(mut config) = Config::from_name(ty, opt) {
426 let df: Defuns = unsafe {
427 if df.is_null() {
428 Defuns::new()
429 } else {
430 (&*df).clone()
431 }
432 };
433
434 config.set_defuns(df);
435
436 let prog = match Program::new(&ml, config) {
437 Ok(prog) => prog,
438 Err(msg) => {
439 res.status = CompilerStatus::CompilationError;
440 res.msg = error_message("Compilation error (prog)", msg);
441 return Box::into_raw(Box::new(res)) as *const _;
442 }
443 };
444
445 let app = Application::new(prog, HashSet::new());
446
447 match app {
448 Ok(app) => {
449 res.status = CompilerStatus::Ok;
450 res.app = Some(app);
451 }
452 Err(msg) => {
453 res.status = CompilerStatus::CompilationError;
454 res.msg = error_message("Compilation error (app)", &msg);
455 }
456 }
457 } else {
458 res.status = CompilerStatus::InvalidCompiler;
459 res.msg = error_message("Config error", opt);
460 }
461
462 Box::into_raw(Box::new(res)) as *const _
463}
464
465/// Translates a Symbolica model.
466///
467/// * `json` is a json string encoding the output of `export_instructions`.
468/// * `ty` is the requested arch (amd, arm, native, or bytecode).
469/// * `opt`: compilation options.
470/// * `df`: user-defined functions (currently ignored).
471///
472/// # Safety
473/// * both model and ty are pointers to null-terminated strings.
474/// * The output is a raw pointer to a CompilerResults.
475///
476#[no_mangle]
477pub unsafe extern "C" fn translate(
478 json: *const c_char,
479 ty: *const c_char,
480 opt: u32,
481 df: *mut Defuns,
482 num_params: usize,
483) -> *const CompilerResult {
484 let mut res = CompilerResult {
485 app: None,
486 status: CompilerStatus::Incomplete,
487 msg: CString::from_str("Success").unwrap(),
488 };
489
490 let json = unsafe {
491 match CStr::from_ptr(json).to_str() {
492 Ok(json) => json,
493 Err(msg) => {
494 res.status = CompilerStatus::InvalidUtf8;
495 res.msg = error_message("Invalid encoding", msg);
496 return Box::into_raw(Box::new(res)) as *const _;
497 }
498 }
499 };
500
501 let ty = unsafe {
502 match CStr::from_ptr(ty).to_str() {
503 Ok(ty) => ty,
504 Err(msg) => {
505 res.status = CompilerStatus::InvalidUtf8;
506 res.msg = error_message("Invalid compiler type", msg);
507 return Box::into_raw(Box::new(res)) as *const _;
508 }
509 }
510 };
511
512 if let Ok(mut config) = Config::from_name(ty, opt) {
513 let df: Defuns = unsafe {
514 if df.is_null() {
515 Defuns::new()
516 } else {
517 (&*df).clone()
518 }
519 };
520
521 config.set_defuns(df);
522 let mut comp = Compiler::with_config(config);
523 let app = comp.translate(json.to_string(), num_params);
524
525 match app {
526 Ok(app) => {
527 res.app = Some(app);
528 res.status = CompilerStatus::Ok;
529 }
530 Err(msg) => {
531 res.status = CompilerStatus::InvalidCompiler;
532 res.msg = error_message("Compilation error", msg);
533 }
534 }
535 } else {
536 res.status = CompilerStatus::InvalidCompiler;
537 res.msg = error_message("Config error", opt);
538 }
539
540 Box::into_raw(Box::new(res)) as *const _
541}
542
543/// Checks the status of a `CompilerResult`.
544///
545/// Returns a null-terminated string representing the status message.
546///
547/// # Safety
548/// it is the responsibility of the calling function to ensure
549/// that q points to a valid CompilerResult.
550///
551#[no_mangle]
552pub unsafe extern "C" fn check_status(q: *const CompilerResult) -> *const c_char {
553 let q: &CompilerResult = unsafe { &*q };
554 q.msg.as_ptr() as *const _
555}
556
557/// Checks the status of a `CompilerResult`.
558///
559/// Returns a null-terminated string representing the status message.
560///
561/// # Safety
562/// it is the responsibility of the calling function to ensure
563/// that q points to a valid CompilerResult.
564///
565#[no_mangle]
566pub unsafe extern "C" fn save(q: *const CompilerResult, file: *const c_char) -> bool {
567 let q: &CompilerResult = unsafe { &*q };
568 let file = unsafe {
569 match CStr::from_ptr(file).to_str() {
570 Ok(file) => file,
571 Err(_) => return false,
572 }
573 };
574
575 if let Some(app) = &q.app {
576 if let Ok(mut fs) = std::fs::File::create(file) {
577 app.save(&mut fs).is_ok()
578 } else {
579 false
580 }
581 } else {
582 false
583 }
584}
585
586/// Checks the status of a `CompilerResult`.
587///
588/// Returns a null-terminated string representing the status message.
589///
590/// # Safety
591/// it is the responsibility of the calling function to ensure
592/// that q points to a valid CompilerResult.
593///
594#[no_mangle]
595pub unsafe extern "C" fn load(file: *const c_char, df: *mut Defuns) -> *const CompilerResult {
596 let mut res = CompilerResult {
597 app: None,
598 status: CompilerStatus::Incomplete,
599 msg: CString::from_str("Success").unwrap(),
600 };
601
602 let df: Defuns = unsafe {
603 if df.is_null() {
604 Defuns::new()
605 } else {
606 (&*df).clone()
607 }
608 };
609
610 let file = unsafe {
611 match CStr::from_ptr(file).to_str() {
612 Ok(file) => file,
613 Err(_) => return Box::into_raw(Box::new(res)) as *const _,
614 }
615 };
616
617 let fs = std::fs::File::open(file);
618
619 match fs {
620 Ok(mut fs) => match Application::load(&mut fs, &Config::from_defuns(df).unwrap()) {
621 Ok(app) => {
622 res.app = Some(app);
623 res.status = CompilerStatus::Ok;
624 }
625 Err(err) => {
626 res.status = CompilerStatus::ParseError;
627 res.msg = error_message("File parse error", &err);
628 }
629 },
630 Err(err) => {
631 res.msg = error_message("File I/O error", &err);
632 }
633 }
634
635 Box::into_raw(Box::new(res)) as *const _
636}
637
638/// Checks the status of a `CompilerResult`.
639///
640/// Returns a null-terminated string representing the status message.
641///
642/// # Safety
643/// it is the responsibility of the calling function to ensure
644/// that q points to a valid CompilerResult.
645///
646#[no_mangle]
647pub unsafe extern "C" fn get_config(q: *const CompilerResult) -> usize {
648 let q: &CompilerResult = unsafe { &*q };
649
650 match &q.app {
651 Some(app) => {
652 let config = app.prog.config();
653
654 let ty: usize = match config.ty {
655 CompilerType::Native => 0,
656 CompilerType::Amd => 1,
657 CompilerType::AmdAVX => 2,
658 CompilerType::AmdSSE => 3,
659 CompilerType::Arm => 4,
660 CompilerType::RiscV => 5,
661 CompilerType::ByteCode => 6,
662 CompilerType::Debug => 7,
663 };
664
665 (config.opt as usize) | (ty << 32)
666 }
667 None => 0,
668 }
669}
670
671/// Returns the number of state variables.
672///
673/// # Safety
674/// it is the responsibility of the calling function to ensure
675/// that q points to a valid CompilerResult.
676///
677#[no_mangle]
678pub unsafe extern "C" fn count_states(q: *const CompilerResult) -> usize {
679 let q: &CompilerResult = unsafe { &*q };
680 if let Some(app) = &q.app {
681 app.count_states
682 } else {
683 0
684 }
685}
686
687/// Returns the number of parameters.
688///
689/// # Safety
690/// it is the responsibility of the calling function to ensure
691/// that q points to a valid CompilerResult.
692///
693#[no_mangle]
694pub unsafe extern "C" fn count_params(q: *const CompilerResult) -> usize {
695 let q: &CompilerResult = unsafe { &*q };
696 if let Some(app) = &q.app {
697 app.count_params
698 } else {
699 0
700 }
701}
702
703/// Returns the number of observables (output).
704///
705/// # Safety
706/// it is the responsibility of the calling function to ensure
707/// that q points to a valid CompilerResult.
708///
709#[no_mangle]
710pub unsafe extern "C" fn count_obs(q: *const CompilerResult) -> usize {
711 let q: &CompilerResult = unsafe { &*q };
712 if let Some(app) = &q.app {
713 app.count_obs
714 } else {
715 0
716 }
717}
718
719/// Returns the number of differential equations.
720///
721/// Generally, it should be the same as the number of states.
722///
723/// # Safety
724/// it is the responsibility of the calling function to ensure
725/// that q points to a valid CompilerResult.
726///
727#[no_mangle]
728pub unsafe extern "C" fn count_diffs(q: *const CompilerResult) -> usize {
729 let q: &CompilerResult = unsafe { &*q };
730 if let Some(app) = &q.app {
731 app.count_diffs
732 } else {
733 0
734 }
735}
736
737/// Deprecated. Previously used for interfacing to DifferentialEquation.jl. It is
738/// replaced with <https://github.com/siravan/SymJit.jl>.
739///
740/// # Safety
741///
742/// Deprecated. No effects.
743#[no_mangle]
744pub unsafe extern "C" fn run(
745 _q: *mut CompilerResult,
746 _du: *mut f64,
747 _u: *const f64,
748 _ns: usize,
749 _p: *const f64,
750 _np: usize,
751 _t: f64,
752) -> bool {
753 // let q: &mut CompilerResult = unsafe { &mut *q };
754
755 // if let Some(app) = &mut q.app {
756 // if app.count_states != ns || app.count_params != np {
757 // return false;
758 // }
759
760 // let du: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(du, ns) };
761 // let u: &[f64] = unsafe { std::slice::from_raw_parts(u, ns) };
762 // let p: &[f64] = unsafe { std::slice::from_raw_parts(p, np) };
763 // app.call(du, u, p, t);
764 // true
765 // } else {
766 // false
767 // }
768 false
769}
770
771/// Executes the compiled function.
772///
773/// The calling routine should fill the states and parameters before
774/// calling `execute`. The result populates obs or diffs (as defined in
775/// model passed to `compile`).
776///
777/// # Safety
778/// it is the responsibility of the calling function to ensure
779/// that q points to a valid CompilerResult.
780///
781#[no_mangle]
782pub unsafe extern "C" fn execute(q: *mut CompilerResult) -> bool {
783 let q: &mut CompilerResult = unsafe { &mut *q };
784
785 if let Some(app) = &mut q.app {
786 app.exec();
787 true
788 } else {
789 false
790 }
791}
792
/// Executes the compiled function `n` times (vectorized).
///
/// The calling function provides `buf`, which is a k x n matrix of doubles,
/// where k is equal to `maximum(count_states, count_obs)`. The calling
/// function fills the first `count_states` rows of buf. The result is returned
/// in the first `count_obs` rows of buf, i.e., the same buffer serves as both
/// input and output.
///
/// Returns `false` (without touching `buf`) if `q` holds no application.
///
/// # Safety
/// it is the responsibility of the calling function to ensure
/// that q points to a valid CompilerResult.
///
/// In addition, buf should point to a valid matrix of correct size
/// (at least k * n doubles).
///
#[no_mangle]
pub unsafe extern "C" fn execute_vectorized(
    q: *mut CompilerResult,
    buf: *mut f64,
    n: usize,
) -> bool {
    let q: &mut CompilerResult = unsafe { &mut *q };

    if let Some(app) = &mut q.app {
        // Row count of `buf`: large enough for both the inputs and the outputs.
        let h = usize::max(app.count_states, app.count_obs);
        // NOTE(review): `states` and `obs` are each built from a `&mut [f64]`
        // covering the *same* memory, so two mutable views of `buf` coexist.
        // This likely violates Rust's `&mut` aliasing rules even if
        // `exec_vectorized` tolerates the overlap. Consider a raw-pointer
        // based view; confirm against `Matrix::from_buf`.
        let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
        let mut states = Matrix::from_buf(p, h, n);
        let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
        let mut obs = Matrix::from_buf(p, h, n);
        app.exec_vectorized(&mut states, &mut obs);
        true
    } else {
        false
    }
}
826
827/// Evaluates the compiled function. This is for Symbolica compatibility.
828///
829/// # Safety
830/// it is the responsibility of the calling function to ensure
831/// that q points to a valid CompilerResult.
832///
833#[no_mangle]
834pub unsafe extern "C" fn evaluate(
835 q: *mut CompilerResult,
836 args: *const f64,
837 nargs: usize,
838 outs: *mut f64,
839 nouts: usize,
840) -> bool {
841 let q: &mut CompilerResult = unsafe { &mut *q };
842
843 if let Some(app) = &mut q.app {
844 let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
845 let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
846 app.evaluate(args, outs);
847 true
848 } else {
849 false
850 }
851}
852
853/// Evaluates the compiled function. This is for Symbolica compatibility.
854///
855/// # Safety
856/// it is the responsibility of the calling function to ensure
857/// that q points to a valid CompilerResult.
858///
859#[no_mangle]
860pub unsafe extern "C" fn evaluate_matrix(
861 q: *mut CompilerResult,
862 args: *const f64,
863 nargs: usize,
864 outs: *mut f64,
865 nouts: usize,
866) -> bool {
867 let q: &mut CompilerResult = unsafe { &mut *q };
868
869 if let Some(app) = &mut q.app {
870 if app.count_params == 0 {
871 return false;
872 }
873
874 let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
875 let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
876 let n = nargs / app.count_params;
877 app.evaluate_matrix(args, outs, n);
878 true
879 } else {
880 false
881 }
882}
883
884/// Returns a pointer to the state variables (`count_states` doubles).
885///
886/// The function calling `execute` should write the state variables in this area.
887///
888/// # Safety
889/// it is the responsibility of the calling function to ensure
890/// that q points to a valid CompilerResult.
891///
892#[no_mangle]
893pub unsafe extern "C" fn ptr_states(q: *mut CompilerResult) -> *mut f64 {
894 let q: &mut CompilerResult = unsafe { &mut *q };
895 if let Some(app) = &mut q.app {
896 if let Some(f) = &mut app.compiled {
897 &mut f.mem_mut()[app.first_state] as *mut f64
898 } else {
899 &mut app.bytecode.mem_mut()[app.first_state] as *mut f64
900 }
901 } else {
902 std::ptr::null_mut()
903 }
904}
905
906/// Returns a pointer to the parameters (`count_params` doubles).
907///
908/// The function calling `execute` should write the parameters in this area.
909///
910/// # Safety
911/// it is the responsibility of the calling function to ensure
912/// that q points to a valid CompilerResult.
913///
914#[no_mangle]
915pub unsafe extern "C" fn ptr_params(q: *mut CompilerResult) -> *mut f64 {
916 let q: &mut CompilerResult = unsafe { &mut *q };
917 if let Some(app) = &mut q.app {
918 //&mut app.compiled.mem_mut()[app.first_param] as *mut f64
919 &mut app.params[app.first_param] as *mut f64
920 } else {
921 std::ptr::null_mut()
922 }
923}
924
925/// Returns a pointer to the observables (`count_obs` doubles).
926///
927/// The function calling `execute` reads the observables from this area.
928///
929/// # Safety
930/// it is the responsibility of the calling function to ensure
931/// that q points to a valid CompilerResult.
932///
933#[no_mangle]
934pub unsafe extern "C" fn ptr_obs(q: *mut CompilerResult) -> *const f64 {
935 let q: &CompilerResult = unsafe { &*q };
936 if let Some(app) = &q.app {
937 if let Some(f) = &app.compiled {
938 &f.mem()[app.first_obs] as *const f64
939 } else {
940 &app.bytecode.mem()[app.first_obs] as *const f64
941 }
942 } else {
943 std::ptr::null()
944 }
945}
946
947/// Returns a pointer to the differentials (`count_diffs` doubles).
948///
949/// The function calling `execute` reads the differentials from this area.
950///
951/// Note: whether the output is returned as observables or differentials is
952/// defined in the model.
953///
954/// # Safety
955/// it is the responsibility of the calling function to ensure
956/// that q points to a valid CompilerResult.
957///
958#[no_mangle]
959pub unsafe extern "C" fn ptr_diffs(q: *mut CompilerResult) -> *const f64 {
960 let q: &CompilerResult = unsafe { &*q };
961 if let Some(app) = &q.app {
962 if let Some(f) = &app.compiled {
963 &f.mem()[app.first_diff] as *const f64
964 } else {
965 &app.bytecode.mem()[app.first_diff] as *const f64
966 }
967 } else {
968 std::ptr::null()
969 }
970}
971
972/// Dumps the compiled binary code to a file (`name`).
973///
974/// This function is useful for debugging but is not necessary for
975/// normal operations.
976///
977/// # Safety
978/// it is the responsibility of the calling function to ensure
979/// that q points to a valid CompilerResult.
980///
981#[no_mangle]
982pub unsafe extern "C" fn dump(
983 q: *mut CompilerResult,
984 name: *const c_char,
985 what: *const c_char,
986) -> bool {
987 let q: &mut CompilerResult = unsafe { &mut *q };
988 if let Some(app) = &mut q.app {
989 let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
990 let what = unsafe { CStr::from_ptr(what).to_str().unwrap() };
991 app.dump(name, what)
992 } else {
993 false
994 }
995}
996
997/// Deallocates the CompilerResult pointed by `q`.
998///
999/// # Safety
1000/// it is the responsibility of the calling function to ensure
1001/// that q points to a valid CompilerResult and that after
1002/// calling this function, q is invalid and should not
1003/// be used anymore.
1004///
1005#[no_mangle]
1006pub unsafe extern "C" fn finalize(q: *mut CompilerResult) {
1007 if !q.is_null() {
1008 let _ = unsafe { Box::from_raw(q) };
1009 }
1010}
1011
1012/// Returns a null-terminated string representing the version.
1013///
1014/// Used for debugging.
1015///
1016/// # Safety
1017/// the return value is a null-terminated string that should not
1018/// be freed.
1019///
1020#[no_mangle]
1021pub unsafe extern "C" fn info() -> *const c_char {
1022 // let msg = c"symjit 1.3.3";
1023 let msg = CString::new(env!("CARGO_PKG_VERSION")).unwrap();
1024 msg.into_raw() as *const _
1025}
1026
1027/// Returns a pointer to the fast function if one can be compiled.
1028///
1029/// # Safety
1030/// 1. If the model cannot be compiled to a fast function, NULL is returned.
1031/// 2. A fast function code memory is leaked and is not deallocated.
1032///
1033#[no_mangle]
1034pub unsafe extern "C" fn fast_func(q: *mut CompilerResult) -> *const usize {
1035 let q: &mut CompilerResult = unsafe { &mut *q };
1036 if let Some(app) = &mut q.app {
1037 match app.get_fast() {
1038 Some(f) => f as *const usize,
1039 None => std::ptr::null(),
1040 }
1041 } else {
1042 std::ptr::null()
1043 }
1044}
1045
1046/// Interface for Sympy's LowLevelCallable.
1047///
1048/// # Safety
1049/// 1. If the model cannot be compiled to a fast function, NULL is returned.
1050/// 2. The resulting function lives as long as q does and should not be stored
1051/// separately.
1052///
1053#[no_mangle]
1054pub unsafe extern "C" fn callable_quad(n: usize, xx: *const f64, q: *mut CompilerResult) -> f64 {
1055 let q: &mut CompilerResult = unsafe { &mut *q };
1056 let xx: &[f64] = unsafe { std::slice::from_raw_parts(xx, n) };
1057
1058 if let Some(app) = &mut q.app {
1059 app.exec_callable(xx)
1060 } else {
1061 f64::NAN
1062 }
1063}
1064
/// Interface for SciPy's `LowLevelCallable` that calls a fast function directly.
///
/// `f` is a raw code pointer to a C-ABI function taking `n` `f64` arguments and
/// returning an `f64`. The `n` values in `xx` are passed to it as positional
/// arguments.
///
/// # Returns
/// The function's result, or NaN in any of these cases:
/// - `f` is NULL;
/// - `xx` is NULL while `n > 0`;
/// - `n` exceeds the supported maximum of 7 arguments.
///
/// Returning NaN avoids panicking across the `extern "C"` boundary.
///
/// # Safety
/// 1. `f` must be NULL or point to executable code with the C signature
///    `double f(double, ..., double)` taking exactly `n` arguments.
/// 2. When `n > 0`, `xx` must point to at least `n` initialized `f64` values.
/// 3. The code behind `f` must stay alive for the duration of the call.
///
#[no_mangle]
pub unsafe extern "C" fn callable_quad_fast(n: usize, xx: *const f64, f: *const usize) -> f64 {
    if f.is_null() {
        return f64::NAN;
    }

    // `slice::from_raw_parts` requires a non-null pointer even for a zero-length
    // slice, so the zero-argument case must not touch `xx`.
    let xx: &[f64] = if n == 0 {
        &[]
    } else if xx.is_null() {
        return f64::NAN;
    } else {
        // SAFETY: non-null and, per the contract, points to `n` initialized f64s.
        unsafe { std::slice::from_raw_parts(xx, n) }
    };

    // The target is machine code following the platform C calling convention,
    // so the pointer must be reinterpreted as an `extern "C"` fn. The Rust ABI of
    // a plain `fn` type is unspecified.
    // SAFETY (all transmutes): `f` is non-null and, per the contract, has
    // exactly the signature selected by `n`.
    match n {
        0 => {
            let f: extern "C" fn() -> f64 = unsafe { std::mem::transmute(f) };
            f()
        }
        1 => {
            let f: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0])
        }
        2 => {
            let f: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1])
        }
        3 => {
            let f: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2])
        }
        4 => {
            let f: extern "C" fn(f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3])
        }
        5 => {
            let f: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4])
        }
        6 => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4], xx[5])
        }
        7 => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4], xx[5], xx[6])
        }
        // Too many parameters for a fast func. Panicking here would unwind (or
        // abort) across the C boundary, so signal failure with NaN instead.
        _ => f64::NAN,
    }
}
1114
1115/// Interface for Sympy's LowLevelCallable (image filtering).
1116///
1117/// # Safety
1118/// 1. If the model cannot be compiled to a fast function, NULL is returned.
1119/// 2. The resulting function lives as long as q does and should not be stored
1120/// separately.
1121///
1122#[no_mangle]
1123pub unsafe extern "C" fn callable_filter(
1124 buffer: *const f64,
1125 filter_size: usize,
1126 return_value: *mut f64,
1127 q: *mut CompilerResult,
1128) -> i64 {
1129 let q: &mut CompilerResult = unsafe { &mut *q };
1130 let xx: &[f64] = unsafe { std::slice::from_raw_parts(buffer, filter_size) };
1131
1132 if let Some(app) = &mut q.app {
1133 let p: &mut f64 = unsafe { &mut *return_value };
1134 *p = app.exec_callable(xx);
1135 1
1136 } else {
1137 0
1138 }
1139}
1140
1141/************************************************/
1142
1143/// Creates an empty Matrix (a 2d array).
1144///
1145/// # Safety
1146/// It returns a pointer to the allocated Matrix, which needs to be
1147/// deallocated eventually.
1148///
1149#[no_mangle]
1150pub unsafe extern "C" fn create_matrix<'a>() -> *const Matrix<'a> {
1151 let mat = Matrix::new();
1152 Box::into_raw(Box::new(mat)) as *const Matrix
1153}
1154
1155/// Finalizes (deallocates) the Matrix.
1156///
1157/// # Safety
1158/// 1, mat should point to a valid Matrix object created by create_matrix.
1159/// 2. After finalize_matrix is called, mat is invalid.
1160///
1161#[no_mangle]
1162pub unsafe extern "C" fn finalize_matrix(mat: *mut Matrix) {
1163 if !mat.is_null() {
1164 let _ = unsafe { Box::from_raw(mat) };
1165 }
1166}
1167
1168/// Adds a row to the Matrix.
1169///
1170/// # Safety
1171/// 1, mat should point to a valid Matrix object created by create_matrix.
1172/// 2. v should point to a valid array of doubles of length at least n.
1173/// 3. v should remains valid for the lifespan of mat.
1174///
1175#[no_mangle]
1176pub unsafe extern "C" fn add_row(mat: *mut Matrix, v: *mut f64, n: usize) {
1177 let mat: &mut Matrix = unsafe { &mut *mat };
1178 mat.add_row(v, n);
1179}
1180
1181/// Executes (runs) the matrix model encoded by `q`.
1182///
1183/// # Safety
1184/// 1, q should point to a valid CompilerResult object.
1185/// 2. states should point to a valid Matrix of at least count_states rows.
1186/// 3. obs should point to a valid Matrix of at least count_obs rows.
1187///
1188#[no_mangle]
1189pub unsafe extern "C" fn execute_matrix(
1190 q: *mut CompilerResult,
1191 states: *mut Matrix,
1192 obs: *mut Matrix,
1193) -> bool {
1194 let q: &mut CompilerResult = unsafe { &mut *q };
1195 let states: &mut Matrix = unsafe { &mut *states };
1196 let obs: &mut Matrix = unsafe { &mut *obs };
1197
1198 if let Some(app) = &mut q.app {
1199 app.exec_vectorized(states, obs);
1200 true
1201 } else {
1202 false
1203 }
1204}
1205
1206/************************************************/
1207
1208/// Creates an empty `Defun` (a list of user-defined functions).
1209///
1210/// `Defuns` are used to pass user-defined functions (either Python
1211/// functions or symjit-compiled functions).
1212///
1213/// # Safety
1214/// It returns a pointer to the allocated Defun, which needs to be
1215/// deallocated eventually.
1216///
1217#[no_mangle]
1218pub unsafe extern "C" fn create_defuns() -> *const Defuns {
1219 let df = Defuns::new();
1220 Box::into_raw(Box::new(df)) as *const Defuns
1221}
1222
/// Finalization hook for a `Defuns` created by `create_defuns`.
///
/// NOTE(review): this function is currently a no-op. The deallocation code
/// below is commented out, so the `Defuns` object is NOT freed here. Presumably
/// this is deliberate: something else may take ownership or keep pointing into
/// it after compilation. Confirm this before re-enabling `Box::from_raw`,
/// because freeing it here could cause a double free or a dangling pointer.
///
/// # Safety
/// 1. df should point to a valid Defuns object created by create_defuns.
/// 2. After finalize_defuns is called, the caller should treat df as invalid.
///
#[no_mangle]
pub unsafe extern "C" fn finalize_defuns(_df: *mut Defuns) {
    // Intentionally disabled deallocation (see the NOTE above):
    // if !df.is_null() {
    //     let _ = unsafe { Box::from_raw(df) };
    // }
}
1235
1236/// Adds a new function to a `Defun`.
1237///
1238/// # Safety
1239/// 1, df should point to a valid Defun object created by create_defun.
1240/// 2. name should be a valid utf8 string.
1241/// 3. p should point to a valid C-styple function pointer that accepts
1242/// num_args double arguments.
1243///
1244#[no_mangle]
1245pub unsafe extern "C" fn add_func(
1246 df: *mut Defuns,
1247 name: *const c_char,
1248 p: *const usize,
1249 num_args: usize,
1250) {
1251 let df: &mut Defuns = unsafe { &mut *df };
1252 let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
1253 df.add_func(name, p, num_args);
1254}