1use anyhow::{anyhow, Result};
2use rayon::prelude::*;
3
4use crate::machine::MachineCode;
5use crate::runnable::Application;
6use crate::types::{ElemType, Element};
7use crate::utils::*;
8
/// Native-code evaluator detached ("sealed") from an [`Application`].
///
/// Holds the compiled machine code and the dimension counts copied from the application by
/// [`Applet::new`].
#[derive(Clone)]
pub struct Applet {
    /// Scalar machine code; every evaluation path in this file requires it to be `Some`.
    pub compiled: Option<MachineCode<f64>>,
    /// Optional SIMD machine code; used for SIMD element types and for matrix evaluation.
    pub compiled_simd: Option<MachineCode<f64>>,
    /// Copied from the application; not read by the evaluation methods in this file.
    pub use_simd: bool,
    /// When set, `evaluate_matrix` spreads rows over rayon's thread pool (only if `n > 1`).
    pub use_threads: bool,
    /// Copied from the application; not read by the evaluation methods in this file.
    pub count_states: usize,
    /// Number of `f64` inputs per row; used as the row stride into `args`.
    pub count_params: usize,
    /// Number of `f64` outputs per row; used as the row stride into `outs`.
    pub count_obs: usize,
    /// Copied from the application; not read by the evaluation methods in this file.
    pub count_diffs: usize,
}
20
21impl Applet {
22 pub fn new(app: Application) -> Result<Applet> {
23 if app.prog.config().is_bytecode() {
24 return Err(anyhow!("Bytecode Application cannot be sealed."));
25 }
26
27 Ok(Applet {
28 compiled: app.compiled,
29 compiled_simd: app.compiled_simd,
30 use_simd: app.use_simd,
31 use_threads: app.use_threads,
32 count_states: app.count_states,
33 count_params: app.count_params,
34 count_obs: app.count_obs,
35 count_diffs: app.count_diffs,
36 })
37 }
38
39 pub fn evaluate<T>(&mut self, args: &[T], outs: &mut [T])
41 where
42 T: Element,
43 {
44 let args = recast_as_f64(args);
45 let outs = recast_as_f64_mut(outs);
46
47 let simd = matches!(
48 T::get_type(T::default()),
49 ElemType::RealF64x2(_)
50 | ElemType::RealF64x4(_)
51 | ElemType::ComplexF64x2(_)
52 | ElemType::ComplexF64x4(_)
53 );
54
55 if let Some(f) = &self.compiled {
56 if !simd {
57 f.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
58 } else if let Some(g) = &self.compiled_simd {
59 g.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
60 }
61 }
62 }
63
64 #[inline(always)]
66 pub fn evaluate_single<T>(&mut self, args: &[T]) -> T
67 where
68 T: Element + Copy,
69 {
70 let mut outs = [T::default(); 1];
71 self.evaluate(args, &mut outs);
72 outs[0]
73 }
74
75 fn evaluate_row(
78 args: &[f64],
79 args_idx: usize,
80 outs: &[f64],
81 outs_idx: usize,
82 f: CompiledFunc<f64>,
83 transpose: bool,
84 ) -> i32 {
85 unsafe {
86 f(
87 outs.as_ptr().add(outs_idx),
88 std::ptr::null(),
89 if transpose { 1 } else { 0 },
90 args.as_ptr().add(args_idx),
91 )
92 }
93 }
94
95 fn evaluate_matrix_with_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
96 if let Some(f) = &self.compiled {
97 let count_params = self.count_params;
98 let count_obs = self.count_obs;
99 let f_scalar = f.func();
100
101 (0..n).into_par_iter().for_each(|t| {
102 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
103 });
104 }
105 }
106
107 fn evaluate_matrix_without_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
108 if let Some(f) = &self.compiled {
109 let count_params = self.count_params;
110 let count_obs = self.count_obs;
111 let f_scalar = f.func();
112
113 for t in 0..n {
114 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
115 }
116 }
117 }
118
119 fn evaluate_matrix_with_threads_simd(
120 &self,
121 args: &[f64],
122 outs: &mut [f64],
123 n: usize,
124 transpose: bool,
125 ) {
126 if let Some(f) = &self.compiled {
127 let count_params = self.count_params;
128 let count_obs = self.count_obs;
129
130 if let Some(compiled) = &self.compiled_simd {
131 let f_simd = compiled.func();
132 let f_scalar = f.func();
133 let lanes = compiled.count_lanes();
134 let step = if transpose { lanes } else { 1 };
135
136 (0..n / step).into_par_iter().for_each(|k| {
137 let top = k * lanes;
138 if Self::evaluate_row(
139 args,
140 top * count_params,
141 outs,
142 top * count_obs,
143 f_simd,
144 transpose,
145 ) != 0
146 {
147 for i in 0..lanes {
148 Self::evaluate_row(
149 args,
150 (top + i) * count_params,
151 outs,
152 (top + i) * count_obs,
153 f_scalar,
154 false,
155 );
156 }
157 }
158 });
159
160 for t in step * (n / step)..n {
161 Self::evaluate_row(
162 args,
163 t * count_params,
164 outs,
165 t * count_obs,
166 f_scalar,
167 false,
168 );
169 }
170 }
171 }
172 }
173
174 fn evaluate_matrix_without_threads_simd(
175 &self,
176 args: &[f64],
177 outs: &mut [f64],
178 n: usize,
179 transpose: bool,
180 ) {
181 if let Some(f) = &self.compiled {
182 let count_params = self.count_params;
183 let count_obs = self.count_obs;
184
185 if let Some(compiled) = &self.compiled_simd {
186 let f_simd = compiled.func();
187 let f_scalar = f.func();
188 let lanes = compiled.count_lanes();
189 let step = if transpose { lanes } else { 1 };
190
191 for k in 0..n / step {
192 let top = k * lanes;
193 if Self::evaluate_row(
194 args,
195 top * count_params,
196 outs,
197 top * count_obs,
198 f_simd,
199 transpose,
200 ) != 0
201 {
202 for i in 0..lanes {
203 Self::evaluate_row(
204 args,
205 (top + i) * count_params,
206 outs,
207 (top + i) * count_obs,
208 f_scalar,
209 false,
210 );
211 }
212 }
213 }
214
215 for t in step * (n / step)..n {
216 Self::evaluate_row(
217 args,
218 t * count_params,
219 outs,
220 t * count_obs,
221 f_scalar,
222 false,
223 );
224 }
225 }
226 }
227 }
228
229 fn evaluate_matrix_bytecode(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
230 let count_params = self.count_params;
231 let count_obs = self.count_obs;
232
233 for i in 0..n {
234 self.evaluate(
235 &args[i * count_params..(i + 1) * count_params],
236 &mut outs[i * count_obs..(i + 1) * count_obs],
237 );
238 }
239 }
240
241 pub fn evaluate_matrix<T>(&mut self, args: &[T], outs: &mut [T], n: usize)
246 where
247 T: Element,
248 {
249 let args = recast_as_f64(args);
250 let outs = recast_as_f64_mut(outs);
251
252 let transpose = !matches!(
253 T::get_type(T::default()),
254 ElemType::RealF64x2(_)
255 | ElemType::RealF64x4(_)
256 | ElemType::ComplexF64x2(_)
257 | ElemType::ComplexF64x4(_)
258 );
259
260 if self.use_threads && n > 1 {
261 if self.compiled_simd.is_some() {
262 self.evaluate_matrix_with_threads_simd(args, outs, n, transpose);
263 } else {
264 self.evaluate_matrix_with_threads(args, outs, n);
265 }
266 } else {
267 if self.compiled_simd.is_some() {
268 self.evaluate_matrix_without_threads_simd(args, outs, n, transpose);
269 } else {
270 self.evaluate_matrix_without_threads(args, outs, n);
271 }
272 }
273 }
274}
275
/// Reinterprets a slice of `T` as the flat slice of `f64`s it is made of.
///
/// Each `T` contributes `size_of::<T>() / size_of::<f64>()` values. A slice of `f64`s maps
/// to itself, and a slice of `n` values of a type built from four `f64`s maps to `4 * n`
/// values. The caller must only use types whose bytes are entirely `f64` data (no padding).
///
/// # Panics
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()`, since trailing bytes
/// would silently be dropped. Also panics if the slice is not aligned for `f64`, which
/// `slice::from_raw_parts` requires.
fn recast_as_f64<T>(v: &[T]) -> &[f64]
where
    T: Sized,
{
    let elem_size = std::mem::size_of::<T>();
    let word_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % word_size == 0,
        "element size {} is not a multiple of the f64 size",
        elem_size
    );
    assert!(
        v.as_ptr() as usize % std::mem::align_of::<f64>() == 0,
        "slice is not aligned for f64"
    );

    let len = elem_size / word_size * v.len();
    // SAFETY: the pointer is f64-aligned (asserted above), and `len` f64s span exactly the
    // bytes of `v`. The result borrows `v`, so it cannot outlive or alias a mutable borrow.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<f64>(), len) }
}
285
/// Reinterprets a mutable slice of `T` as the flat mutable slice of `f64`s it is made of.
///
/// Each `T` contributes `size_of::<T>() / size_of::<f64>()` values. Writes through the
/// returned slice modify `v` in place. The caller must only use types whose bytes are entirely
/// `f64` data (no padding).
///
/// # Panics
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()`, or if the slice is not
/// aligned for `f64`.
fn recast_as_f64_mut<T>(v: &mut [T]) -> &mut [f64]
where
    T: Sized,
{
    let elem_size = std::mem::size_of::<T>();
    let word_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % word_size == 0,
        "element size {} is not a multiple of the f64 size",
        elem_size
    );

    let len = elem_size / word_size * v.len();
    // Derive the pointer from `as_mut_ptr` (not `as_ptr`) so it carries write permission.
    let p = v.as_mut_ptr().cast::<f64>();
    assert!(
        p as usize % std::mem::align_of::<f64>() == 0,
        "slice is not aligned for f64"
    );
    // SAFETY: `p` is f64-aligned and obtained from the unique borrow `v`. `len` f64s span
    // exactly the bytes of `v`, and the result reborrows `v` mutably for its whole lifetime.
    unsafe { std::slice::from_raw_parts_mut(p, len) }
}