// Source: symjit/applet.rs

1use anyhow::{anyhow, Result};
2use rayon::prelude::*;
3
4use crate::machine::MachineCode;
5use crate::runnable::Application;
6use crate::types::{ElemType, Element};
7use crate::utils::*;
8
9#[derive(Clone)]
10pub struct Applet {
11    pub compiled: Option<MachineCode<f64>>,
12    pub compiled_simd: Option<MachineCode<f64>>,
13    pub use_simd: bool,
14    pub use_threads: bool,
15    pub count_states: usize,
16    pub count_params: usize,
17    pub count_obs: usize,
18    pub count_diffs: usize,
19}
20
21impl Applet {
22    pub fn new(app: Application) -> Result<Applet> {
23        if app.prog.config().is_bytecode() {
24            return Err(anyhow!("Bytecode Application cannot be sealed."));
25        }
26
27        Ok(Applet {
28            compiled: app.compiled,
29            compiled_simd: app.compiled_simd,
30            use_simd: app.use_simd,
31            use_threads: app.use_threads,
32            count_states: app.count_states,
33            count_params: app.count_params,
34            count_obs: app.count_obs,
35            count_diffs: app.count_diffs,
36        })
37    }
38
39    /// Generic evaluate function for compiled Symbolica expressions
40    pub fn evaluate<T>(&mut self, args: &[T], outs: &mut [T])
41    where
42        T: Element,
43    {
44        let args = recast_as_f64(args);
45        let outs = recast_as_f64_mut(outs);
46
47        let simd = matches!(
48            T::get_type(T::default()),
49            ElemType::RealF64x2(_)
50                | ElemType::RealF64x4(_)
51                | ElemType::ComplexF64x2(_)
52                | ElemType::ComplexF64x4(_)
53        );
54
55        if let Some(f) = &self.compiled {
56            if !simd {
57                f.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
58            } else if let Some(g) = &self.compiled_simd {
59                g.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
60            }
61        }
62    }
63
64    /// Generic evaluate_single function for compiled Symbolica expressions
65    #[inline(always)]
66    pub fn evaluate_single<T>(&mut self, args: &[T]) -> T
67    where
68        T: Element + Copy,
69    {
70        let mut outs = [T::default(); 1];
71        self.evaluate(args, &mut outs);
72        outs[0]
73    }
74
75    /// Evaluates a single logical row. It could be a combinatino of multiple
76    /// physical rows because of implicit SIMD.
77    fn evaluate_row(
78        args: &[f64],
79        args_idx: usize,
80        outs: &[f64],
81        outs_idx: usize,
82        f: CompiledFunc<f64>,
83        transpose: bool,
84    ) -> i32 {
85        unsafe {
86            f(
87                outs.as_ptr().add(outs_idx),
88                std::ptr::null(),
89                if transpose { 1 } else { 0 },
90                args.as_ptr().add(args_idx),
91            )
92        }
93    }
94
95    fn evaluate_matrix_with_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
96        if let Some(f) = &self.compiled {
97            let count_params = self.count_params;
98            let count_obs = self.count_obs;
99            let f_scalar = f.func();
100
101            (0..n).into_par_iter().for_each(|t| {
102                Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
103            });
104        }
105    }
106
107    fn evaluate_matrix_without_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
108        if let Some(f) = &self.compiled {
109            let count_params = self.count_params;
110            let count_obs = self.count_obs;
111            let f_scalar = f.func();
112
113            for t in 0..n {
114                Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
115            }
116        }
117    }
118
119    fn evaluate_matrix_with_threads_simd(
120        &self,
121        args: &[f64],
122        outs: &mut [f64],
123        n: usize,
124        transpose: bool,
125    ) {
126        if let Some(f) = &self.compiled {
127            let count_params = self.count_params;
128            let count_obs = self.count_obs;
129
130            if let Some(compiled) = &self.compiled_simd {
131                let f_simd = compiled.func();
132                let f_scalar = f.func();
133                let lanes = compiled.count_lanes();
134                let step = if transpose { lanes } else { 1 };
135
136                (0..n / step).into_par_iter().for_each(|k| {
137                    let top = k * lanes;
138                    if Self::evaluate_row(
139                        args,
140                        top * count_params,
141                        outs,
142                        top * count_obs,
143                        f_simd,
144                        transpose,
145                    ) != 0
146                    {
147                        for i in 0..lanes {
148                            Self::evaluate_row(
149                                args,
150                                (top + i) * count_params,
151                                outs,
152                                (top + i) * count_obs,
153                                f_scalar,
154                                false,
155                            );
156                        }
157                    }
158                });
159
160                for t in step * (n / step)..n {
161                    Self::evaluate_row(
162                        args,
163                        t * count_params,
164                        outs,
165                        t * count_obs,
166                        f_scalar,
167                        false,
168                    );
169                }
170            }
171        }
172    }
173
174    fn evaluate_matrix_without_threads_simd(
175        &self,
176        args: &[f64],
177        outs: &mut [f64],
178        n: usize,
179        transpose: bool,
180    ) {
181        if let Some(f) = &self.compiled {
182            let count_params = self.count_params;
183            let count_obs = self.count_obs;
184
185            if let Some(compiled) = &self.compiled_simd {
186                let f_simd = compiled.func();
187                let f_scalar = f.func();
188                let lanes = compiled.count_lanes();
189                let step = if transpose { lanes } else { 1 };
190
191                for k in 0..n / step {
192                    let top = k * lanes;
193                    if Self::evaluate_row(
194                        args,
195                        top * count_params,
196                        outs,
197                        top * count_obs,
198                        f_simd,
199                        transpose,
200                    ) != 0
201                    {
202                        for i in 0..lanes {
203                            Self::evaluate_row(
204                                args,
205                                (top + i) * count_params,
206                                outs,
207                                (top + i) * count_obs,
208                                f_scalar,
209                                false,
210                            );
211                        }
212                    }
213                }
214
215                for t in step * (n / step)..n {
216                    Self::evaluate_row(
217                        args,
218                        t * count_params,
219                        outs,
220                        t * count_obs,
221                        f_scalar,
222                        false,
223                    );
224                }
225            }
226        }
227    }
228
229    fn evaluate_matrix_bytecode(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
230        let count_params = self.count_params;
231        let count_obs = self.count_obs;
232
233        for i in 0..n {
234            self.evaluate(
235                &args[i * count_params..(i + 1) * count_params],
236                &mut outs[i * count_obs..(i + 1) * count_obs],
237            );
238        }
239    }
240
241    /// Generic evaluate function for compiled Symbolica expressions
242    /// The main entry point to compute matrices.
243    /// The actual dispatched method depends on the configuration and the
244    /// type of the arguments.
245    pub fn evaluate_matrix<T>(&mut self, args: &[T], outs: &mut [T], n: usize)
246    where
247        T: Element,
248    {
249        let args = recast_as_f64(args);
250        let outs = recast_as_f64_mut(outs);
251
252        let transpose = !matches!(
253            T::get_type(T::default()),
254            ElemType::RealF64x2(_)
255                | ElemType::RealF64x4(_)
256                | ElemType::ComplexF64x2(_)
257                | ElemType::ComplexF64x4(_)
258        );
259
260        if self.use_threads && n > 1 {
261            if self.compiled_simd.is_some() {
262                self.evaluate_matrix_with_threads_simd(args, outs, n, transpose);
263            } else {
264                self.evaluate_matrix_with_threads(args, outs, n);
265            }
266        } else {
267            if self.compiled_simd.is_some() {
268                self.evaluate_matrix_without_threads_simd(args, outs, n, transpose);
269            } else {
270                self.evaluate_matrix_without_threads(args, outs, n);
271            }
272        }
273    }
274}
275
/// Reinterprets a slice of `T` as the flat slice of `f64` values it is made of.
///
/// The result covers exactly the same memory, with
/// `size_of::<T>() / size_of::<f64>()` values per element.
///
/// # Panics
///
/// Panics if `T` cannot be viewed as whole `f64`s. Its size must be a multiple
/// of `size_of::<f64>()`, and its alignment at least `align_of::<f64>()`.
/// Both conditions are fixed for a given `T`.
fn recast_as_f64<T>(v: &[T]) -> &[f64]
where
    T: Sized,
{
    assert!(
        std::mem::size_of::<T>() % std::mem::size_of::<f64>() == 0
            && std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "element type cannot be viewed as a sequence of f64"
    );
    let per_elem = std::mem::size_of::<T>() / std::mem::size_of::<f64>();
    // SAFETY: the assertion makes the pointer f64-aligned. It also makes
    // `per_elem * v.len()` f64s span exactly the bytes of `v`, and the result
    // borrows `v`. Assumes `T` has no padding bytes (presumably true for the
    // f64-based Element types).
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<f64>(), per_elem * v.len()) }
}
285
/// Mutable counterpart of `recast_as_f64`: views a slice of `T` as the flat
/// slice of `f64` values it is made of.
///
/// Writes through the returned slice change the original elements.
///
/// # Panics
///
/// Panics if `T` cannot be viewed as whole `f64`s. Its size must be a multiple
/// of `size_of::<f64>()`, and its alignment at least `align_of::<f64>()`.
fn recast_as_f64_mut<T>(v: &mut [T]) -> &mut [f64]
where
    T: Sized,
{
    assert!(
        std::mem::size_of::<T>() % std::mem::size_of::<f64>() == 0
            && std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "element type cannot be viewed as a sequence of f64"
    );
    let per_elem = std::mem::size_of::<T>() / std::mem::size_of::<f64>();
    let len = per_elem * v.len();
    // `as_mut_ptr` (not `as_ptr`) keeps write permission. A pointer from a
    // shared reborrow must never be written through.
    let p = v.as_mut_ptr().cast::<f64>();
    // SAFETY: `p` is f64-aligned and `len` f64s span exactly the bytes of the
    // uniquely borrowed `v`, whose lifetime the result inherits. Assumes `T`
    // has no padding bytes (presumably true for the f64-based Element types).
    unsafe { std::slice::from_raw_parts_mut(p, len) }
}