//! symjit — runnable.rs: builds a runnable `Application` from a compiled
//! `Program` (native machine code, SIMD, fast-path, and bytecode backends).

1use anyhow::{anyhow, Result};
2use std::collections::HashSet;
3use std::io::{Read, Write};
4
5use crate::amd::{AmdFamily, AmdGenerator};
6use crate::applet::Applet;
7use crate::arm::{ArmGenerator, ArmSimdGenerator};
8use crate::complexify::Complexifier;
9use crate::config::Config;
10use crate::generator::Generator;
11use crate::machine::MachineCode;
12use crate::matrix::{combine_matrixes, Matrix};
13use crate::mir::{CompiledMir, Mir};
14use crate::model::Program;
15use crate::riscv64::RiscV;
16use crate::symbol::Loc;
17use crate::utils::*;
18
19use rayon::prelude::*;
20
/// Selects which code-generation backend is used when building an
/// [`Application`]. Stored in the program `Config` and read in
/// `Application::new`.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum CompilerType {
    /// generates bytecode (interpreter).
    ByteCode,
    /// generates code for the detected CPU (default)
    Native,
    /// generates x86-64 (AMD64) code.
    Amd,
    /// generates AVX code for x86-64 architecture.
    AmdAVX,
    /// generates SSE2 code for x86-64 architecture.
    AmdSSE,
    /// generates aarch64 (ARM64) code.
    Arm,
    /// generates riscv64 (RISC V) code.
    RiscV,
    /// debug mode, generates both bytecode and native codes
    /// and compares the outputs.
    Debug,
}
41
#[repr(C)] // to ensure binary compatibility with Applet
pub struct Application {
    // Applet compatibility
    // Important! The order of these fields is critical and should be
    // the same as the order of Applet fields.
    // (`as_applet` transmutes `&Application` to `&Applet`, so this struct
    // must be a layout-prefix of `Applet` — see the SAFETY note there.)
    /// Ahead-of-time compiled scalar machine code; `None` for bytecode-only.
    pub compiled: Option<MachineCode<f64>>,
    /// Lazily compiled SIMD variant (see `prepare_simd`).
    pub compiled_simd: Option<MachineCode<f64>>,
    pub use_simd: bool,
    pub use_threads: bool,
    pub count_states: usize,
    pub count_params: usize,
    pub count_obs: usize,
    pub count_diffs: usize,
    pub config: Config,
    // Non-Applet fields
    pub prog: Program,
    /// Lazily compiled fast-path function (see `prepare_fast`).
    pub compiled_fast: Option<MachineCode<f64>>,
    /// Interpreter fallback; also owns the MIR used for lazy compilation.
    pub bytecode: CompiledMir,
    /// Parameter vector passed to every exec call (len = count_params + 1).
    pub params: Vec<f64>,
    pub can_fast: bool,
    // Offsets into the memory buffer: [states | obs | diffs].
    pub first_state: usize,
    pub first_param: usize,
    pub first_obs: usize,
    pub first_diff: usize,
}
67
impl Application {
    /// Builds an `Application` from a parsed `Program`.
    ///
    /// `reals` lists the symbol locations that remain real-valued when the
    /// program is complexified (only consulted when the config is complex).
    ///
    /// Scalar native code is compiled eagerly for explicitly requested
    /// backends; SIMD and fast-path variants are compiled lazily later
    /// (`prepare_simd` / `prepare_fast`). Bytecode is always compiled and
    /// acts as the interpreter fallback.
    pub fn new(mut prog: Program, reals: HashSet<Loc>) -> Result<Application> {
        // Memory layout: [states | obs | diffs]; params live in `self.params`.
        let first_state = 0;
        let first_param = 0;
        let first_obs = first_state + prog.count_states;
        let first_diff = first_obs + prog.count_obs;

        let count_states = prog.count_states;
        let count_params = prog.count_params;
        let count_obs = prog.count_obs;
        let count_diffs = prog.count_diffs;

        // One extra slot beyond the declared parameter count.
        let params = vec![0.0; count_params + 1];

        let config = prog.config().clone();

        let mut mir = prog.builder.compile_mir()?;

        if config.is_complex() {
            mir = Complexifier::new(&reals, config.clone()).complexify(&mir)?;
        }

        // NOTE(review): `Native` and `Amd` fall into the `_` arm and are
        // rejected here — presumably the config resolves them to a concrete
        // backend before this point; confirm against `Config::compiler_type`.
        let compiled = match config.compiler_type() {
            CompilerType::AmdAVX => Some(Self::compile_avx(&mir, &mut prog)?),
            CompilerType::AmdSSE => Some(Self::compile_sse(&mir, &mut prog)?),
            CompilerType::Arm => Some(Self::compile_arm(&mir, &mut prog)?),
            CompilerType::RiscV => Some(Self::compile_riscv(&mir, &mut prog)?),
            CompilerType::ByteCode => None,
            CompilerType::Debug => {
                println!("`ty = debug` is deprecated");
                None
            }
            _ => return Err(anyhow!("unrecognized `ty`")),
        };

        let use_simd = config.use_simd() && prog.count_loops == 0;
        let use_threads = config.use_threads() && prog.mem_size() < 128;

        // Fast path only for tiny parameterless single-observable programs.
        let can_fast = config.may_fast()
            && count_states <= 8
            && count_params == 0
            && count_obs == 1
            && count_diffs == 0;

        // bytecode takes the ownership of mir
        let bytecode = Self::compile_bytecode(mir, &mut prog)?;

        Ok(Application {
            prog,
            compiled,
            compiled_simd: None,
            compiled_fast: None,
            bytecode,
            params,
            use_simd,
            use_threads,
            can_fast,
            first_state,
            first_param,
            first_obs,
            first_diff,
            count_states,
            count_params,
            count_obs,
            count_diffs,
            config,
        })
    }
137
    /// Consumes the application and wraps it in an `Applet`.
    pub fn seal(self) -> Result<Applet> {
        Applet::new(self)
    }

    /// Reinterprets `&self` as `&Applet` without moving or copying.
    ///
    /// SAFETY: relies on `#[repr(C)]` and the field-order contract documented
    /// on the struct — the leading fields of `Application` must match the
    /// layout of `Applet` exactly. Any field reordering breaks this.
    pub fn as_applet(&self) -> &Applet {
        unsafe { std::mem::transmute(self) }
    }
145
146    /********************* compile_* functions *************************/
147
    /// Shared driver for all scalar/SIMD backends: runs the MIR through the
    /// given `Generator` and packages the emitted bytes into `MachineCode`.
    ///
    /// `size` is the length of the zero-initialized working memory buffer
    /// (SIMD backends multiply it by the lane count); `arch` is a tag string
    /// ("x86_64", "aarch64", "riscv64"); `lanes` is the SIMD width (1 = scalar).
    fn compile<G: Generator>(
        mir: &Mir,
        prog: &mut Program,
        mut generator: G,
        size: usize,
        arch: &str,
        lanes: usize,
    ) -> Result<MachineCode<f64>> {
        let mem: Vec<f64> = vec![0.0; size];
        prog.builder.compile_from_mir(
            mir,
            &mut generator,
            prog.count_states,
            prog.count_obs,
            prog.count_params,
        )?;

        // `false` = not a fast-path function.
        Ok(MachineCode::new(arch, generator.bytes(), mem, false, lanes))
    }
167
    /// Shared driver for the fast-path backends: compiles a specialized
    /// function that returns the observable at `idx_ret` directly, with no
    /// working memory buffer.
    fn compile_fast<G: Generator>(
        mir: &Mir,
        prog: &mut Program,
        mut generator: G,
        idx_ret: u32,
        arch: &str,
    ) -> Result<MachineCode<f64>> {
        // Fast functions do not use the memory buffer.
        let mem: Vec<f64> = Vec::new();
        prog.builder.compile_fast_from_mir(
            mir,
            &mut generator,
            prog.count_states,
            prog.count_obs,
            // NOTE(review): u32 -> i32 cast; assumes idx_ret < i32::MAX.
            idx_ret as i32,
        )?;

        // `true` = fast-path function, always single-lane.
        Ok(MachineCode::new(arch, generator.bytes(), mem, true, 1))
    }
186
    /// Wraps the MIR into an interpretable `CompiledMir`, taking ownership of
    /// `mir` and allocating its memory and evaluation-stack buffers.
    fn compile_bytecode(mir: Mir, prog: &mut Program) -> Result<CompiledMir> {
        let mem: Vec<f64> = vec![0.0; prog.mem_size()];
        let stack: Vec<f64> = vec![0.0; prog.builder.block().sym_table.num_stack];

        Ok(CompiledMir::new(mir, mem, stack))
    }
193
    /// Scalar x86-64 code using SSE2 instructions.
    fn compile_sse(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
        Self::compile::<AmdGenerator>(
            mir,
            prog,
            AmdGenerator::new(AmdFamily::SSEScalar, prog.config().clone()),
            prog.mem_size(),
            "x86_64",
            1,
        )
    }

    /// Scalar x86-64 code using AVX instructions.
    fn compile_avx(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
        Self::compile::<AmdGenerator>(
            mir,
            prog,
            AmdGenerator::new(AmdFamily::AvxScalar, prog.config().clone()),
            prog.mem_size(),
            "x86_64",
            1,
        )
    }

    /// Vector (4-lane) x86-64 AVX code; memory buffer scaled by lane count.
    fn compile_avx_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
        Self::compile::<AmdGenerator>(
            mir,
            prog,
            AmdGenerator::new(AmdFamily::AvxVector, prog.config().clone()),
            prog.mem_size() * 4,
            "x86_64",
            4,
        )
    }

    /// Scalar aarch64 code.
    fn compile_arm(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
        Self::compile::<ArmGenerator>(
            mir,
            prog,
            ArmGenerator::new(prog.config().clone()),
            prog.mem_size(),
            "aarch64",
            1,
        )
    }

    /// Vector (2-lane) aarch64 NEON code; memory buffer scaled by lane count.
    fn compile_arm_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
        Self::compile::<ArmSimdGenerator>(
            mir,
            prog,
            ArmSimdGenerator::new(prog.config().clone()),
            prog.mem_size() * 2,
            "aarch64",
            2,
        )
    }

    /// Scalar riscv64 code.
    fn compile_riscv(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
        Self::compile::<RiscV>(
            mir,
            prog,
            RiscV::new(prog.config().clone()),
            prog.mem_size(),
            "riscv64",
            1,
        )
    }
259
260    fn compile_amd_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
261        if prog.config().has_avx() {
262            Self::compile_fast(
263                mir,
264                prog,
265                AmdGenerator::new(AmdFamily::AvxScalar, prog.config().clone()),
266                idx_ret,
267                "x86_64",
268            )
269        } else {
270            Self::compile_fast(
271                mir,
272                prog,
273                AmdGenerator::new(AmdFamily::SSEScalar, prog.config().clone()),
274                idx_ret,
275                "x86_64",
276            )
277        }
278    }
279
    /// Fast-path compilation for aarch64.
    fn compile_arm_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
        Self::compile_fast(
            mir,
            prog,
            ArmGenerator::new(prog.config().clone()),
            idx_ret,
            "aarch64",
        )
    }

    /// Fast-path compilation for riscv64.
    fn compile_riscv_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
        Self::compile_fast(
            mir,
            prog,
            RiscV::new(prog.config().clone()),
            idx_ret,
            "riscv64",
        )
    }
299
300    /**********************************************************/
301
302    #[inline]
303    pub fn exec(&mut self) {
304        if let Some(compiled) = &mut self.compiled {
305            compiled.exec(&self.params[..])
306        } else {
307            self.bytecode.exec(&self.params[..]);
308        }
309    }
310
311    pub fn exec_callable(&mut self, xx: &[f64]) -> f64 {
312        if let Some(compiled) = &mut self.compiled {
313            let mem = compiled.mem_mut();
314            mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
315            compiled.exec(&self.params[..]);
316            compiled.mem()[self.first_obs]
317        } else {
318            let mem = self.bytecode.mem_mut();
319            mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
320            self.bytecode.exec(&self.params[..]);
321            self.bytecode.mem()[self.first_obs]
322        }
323    }
324
    /// Lazily compiles the SIMD variant for the current architecture.
    ///
    /// No-op once `compiled_simd` is populated or when SIMD is disabled.
    /// Compilation errors are swallowed (`.ok()`): the scalar path remains a
    /// valid fallback, so failure here is deliberately non-fatal.
    pub fn prepare_simd(&mut self) {
        // SIMD compilation is lazy!
        if self.compiled_simd.is_none() && self.use_simd {
            if self.config.has_avx() {
                self.compiled_simd =
                    Self::compile_avx_simd(&self.bytecode.mir, &mut self.prog).ok();
            } else if self.config.is_arm64() {
                self.compiled_simd =
                    Self::compile_arm_simd(&self.bytecode.mir, &mut self.prog).ok();
            }
        };
    }
337
338    fn prepare_fast(&mut self) {
339        // fast func compilation is lazy!
340        if self.compiled_simd.is_none() && self.can_fast {
341            if self.config.is_amd64() {
342                self.compiled_fast = Self::compile_amd_fast(
343                    &self.bytecode.mir,
344                    &mut self.prog,
345                    self.first_obs as u32,
346                )
347                .ok();
348            } else if self.config.is_arm64() {
349                self.compiled_fast = Self::compile_arm_fast(
350                    &self.bytecode.mir,
351                    &mut self.prog,
352                    self.first_obs as u32,
353                )
354                .ok();
355            } else if self.config.is_riscv64() {
356                self.compiled_fast = Self::compile_riscv_fast(
357                    &self.bytecode.mir,
358                    &mut self.prog,
359                    self.first_obs as u32,
360                )
361                .ok();
362            }
363        };
364    }
365
    /// Returns the fast-path function pointer, compiling it on first use.
    /// `None` when fast mode is unavailable or compilation failed.
    pub fn get_fast(&mut self) -> Option<CompiledFunc<f64>> {
        self.prepare_fast();
        self.compiled_fast.as_ref().map(|c| c.func())
    }
370
    /// Evaluates the program over every column of `states`, writing results
    /// into `obs`. Dispatches to the best available strategy: direct memory
    /// loop, SIMD + threads, or scalar + threads.
    ///
    /// NOTE(review): this is a silent no-op when `compiled` is `None`
    /// (bytecode-only builds) — confirm callers use `exec_vectorized_simple`
    /// in that case.
    pub fn exec_vectorized(&mut self, states: &mut Matrix, obs: &mut Matrix) {
        if let Some(compiled) = &self.compiled {
            // Backends without indirect addressing must copy through memory.
            if !compiled.support_indirect() {
                self.exec_vectorized_simple(states, obs);
                return;
            }

            self.prepare_simd();

            if let Some(simd) = &self.compiled_simd {
                self.exec_vectorized_simd(states, obs, self.use_threads, simd.count_lanes());
            } else {
                self.exec_vectorized_scalar(states, obs, self.use_threads);
            }
        }
    }
387
    /// Column-by-column evaluation that copies each state column into the
    /// working memory, executes once, and copies the observables back out.
    /// Works for both native and bytecode backends; no SIMD, no threads.
    ///
    /// Panics if `states.ncols != obs.ncols`.
    pub fn exec_vectorized_simple(&mut self, states: &Matrix, obs: &mut Matrix) {
        assert!(states.ncols == obs.ncols);
        let n = states.ncols;
        let params = &self.params[..];

        if let Some(compiled) = &mut self.compiled {
            for t in 0..n {
                // Inner scopes bound the mutable borrow of the memory buffer
                // so `exec` can borrow it again.
                {
                    let mem = compiled.mem_mut();
                    for i in 0..self.count_states {
                        mem[self.first_state + i] = states.get(i, t);
                    }
                }

                compiled.exec(params);

                {
                    let mem = compiled.mem_mut();
                    for i in 0..self.count_obs {
                        obs.set(i, t, mem[self.first_obs + i]);
                    }
                }
            }
        } else {
            // Same loop against the bytecode interpreter's memory.
            for t in 0..n {
                {
                    let mem = self.bytecode.mem_mut();
                    for i in 0..self.count_states {
                        mem[self.first_state + i] = states.get(i, t);
                    }
                }

                self.bytecode.exec(params);

                {
                    let mem = self.bytecode.mem_mut();
                    for i in 0..self.count_obs {
                        obs.set(i, t, mem[self.first_obs + i]);
                    }
                }
            }
        }
    }
431
    /// Invokes the compiled function for column `t` of the combined matrix.
    ///
    /// NOTE(review): the first argument (memory base) is passed as null —
    /// presumably indirect-addressing code reads/writes through the matrix
    /// pointer `p` instead; confirm against the generated calling convention.
    fn exec_single(t: usize, v: &Matrix, params: &[f64], f: CompiledFunc<f64>) {
        let p = v.p.as_ptr();
        f(std::ptr::null(), p, t, params.as_ptr());
    }
436
    /// Evaluates every column with the scalar compiled function, optionally
    /// in parallel via rayon. `states` and `obs` are combined into a single
    /// matrix view that the generated code addresses directly.
    ///
    /// Panics if `states.ncols != obs.ncols`. No-op when `compiled` is `None`.
    pub fn exec_vectorized_scalar(&mut self, states: &mut Matrix, obs: &mut Matrix, threads: bool) {
        if let Some(compiled) = &mut self.compiled {
            assert!(states.ncols == obs.ncols);
            let n = states.ncols;
            let f = compiled.func();
            let params = &self.params[..];
            let v = combine_matrixes(states, obs);

            if threads {
                (0..n)
                    .into_par_iter()
                    .for_each(|t| Self::exec_single(t, &v, params, f));
            } else {
                (0..n)
                    .for_each(|t| Self::exec_single(t, &v, params, f));
            }
        }
    }
456
    /// Evaluates columns in SIMD chunks of `l` lanes, then finishes the
    /// remainder (`n0..n`) with the scalar function.
    ///
    /// Panics if `states.ncols != obs.ncols`.
    /// NOTE(review): if `compiled_simd` is `None`, columns `0..n0` are never
    /// evaluated — callers (see `exec_vectorized`) only reach here after
    /// SIMD compilation succeeded; confirm no other caller exists.
    pub fn exec_vectorized_simd(
        &mut self,
        states: &mut Matrix,
        obs: &mut Matrix,
        threads: bool,
        l: usize,
    ) {
        if let Some(compiled) = &mut self.compiled {
            assert!(states.ncols == obs.ncols);
            let n = states.ncols;
            let params = &self.params[..];
            // First column index not covered by full SIMD chunks.
            let n0 = l * (n / l);
            let v = combine_matrixes(states, obs);

            if let Some(g) = &mut self.compiled_simd {
                let f = g.func();
                // Each SIMD call processes `l` consecutive columns, so the
                // chunk index runs over 0..n/l.
                if threads {
                    (0..n / l)
                        .into_par_iter()
                        .for_each(|t| Self::exec_single(t, &v, params, f));
                } else {
                    (0..n / l).for_each(|t| Self::exec_single(t, &v, params, f));
                }
            }

            // Scalar tail for the last n - n0 columns.
            let f = compiled.func();

            if threads {
                (n0..n)
                    .into_par_iter()
                    .for_each(|t| Self::exec_single(t, &v, params, f));
            } else {
                (n0..n).for_each(|t| Self::exec_single(t, &v, params, f));
            }
        }
    }
493
    /// Writes the requested compiled artifact to file `name` for inspection.
    ///
    /// `what` selects the artifact: "scalar", "simd", or "fast" (the lazy
    /// variants are compiled on demand first). Returns `false` when the
    /// artifact does not exist or `what` is unrecognized.
    pub fn dump(&mut self, name: &str, what: &str) -> bool {
        match what {
            "scalar" => {
                if let Some(f) = &self.compiled {
                    f.dump(name);
                    true
                } else {
                    false
                }
            }
            "simd" => {
                self.prepare_simd();

                if let Some(f) = &self.compiled_simd {
                    f.dump(name);
                    true
                } else {
                    false
                }
            }
            "fast" => {
                self.prepare_fast();

                if let Some(f) = &self.compiled_fast {
                    f.dump(name);
                    true
                } else {
                    false
                }
            }
            _ => false,
        }
    }
527
528    pub fn dumps(&self) -> Vec<u8> {
529        if let Some(f) = &self.compiled {
530            f.dumps()
531        } else {
532            Vec::new()
533        }
534    }
535
536    /************************** save/load ******************************/
537
538    const MAGIC: usize = 0x40568795410d08e9;
539}
540
541impl Storage for Application {
542    fn save(&self, stream: &mut impl Write) -> Result<()> {
543        stream.write_all(&Self::MAGIC.to_le_bytes())?;
544
545        let version: usize = 1;
546        stream.write_all(&version.to_le_bytes())?;
547
548        self.prog.save(stream)?;
549
550        let mut mask: usize = 0;
551
552        if self.compiled.is_some() && self.compiled.as_ref().unwrap().as_machine().is_some() {
553            mask |= 1;
554        };
555
556        if self.compiled_fast.is_some()
557            && self.compiled_fast.as_ref().unwrap().as_machine().is_some()
558        {
559            mask |= 2;
560        }
561
562        if self.compiled_simd.is_some()
563            && self.compiled_simd.as_ref().unwrap().as_machine().is_some()
564        {
565            mask |= 4;
566        }
567
568        stream.write_all(&mask.to_le_bytes())?;
569
570        if let Some(compiled) = &self.compiled {
571            compiled.as_machine().unwrap().save(stream)?;
572        }
573
574        if let Some(compiled) = &self.compiled_fast {
575            compiled.as_machine().unwrap().save(stream)?;
576        }
577
578        if let Some(compiled) = &self.compiled_simd {
579            compiled.as_machine().unwrap().save(stream)?;
580        }
581
582        Ok(())
583    }
584
585    fn load(stream: &mut impl Read) -> Result<Self> {
586        let mut bytes: [u8; 8] = [0; 8];
587
588        stream.read_exact(&mut bytes)?;
589
590        if usize::from_le_bytes(bytes) != Self::MAGIC {
591            return Err(anyhow!("invalid magic number"));
592        }
593
594        stream.read_exact(&mut bytes)?;
595
596        if usize::from_le_bytes(bytes) != 1 {
597            return Err(anyhow!("invalid sjb version"));
598        }
599
600        let mut prog = Program::load(stream)?;
601
602        stream.read_exact(&mut bytes)?;
603        let mask = usize::from_le_bytes(bytes);
604
605        let compiled: Option<MachineCode<f64>> = if mask & 1 != 0 {
606            Some(MachineCode::load(stream)?)
607        } else {
608            None
609        };
610
611        let compiled_fast: Option<MachineCode<f64>> = if mask & 2 != 0 {
612            Some(MachineCode::load(stream)?)
613        } else {
614            None
615        };
616
617        let compiled_simd: Option<MachineCode<f64>> = if mask & 4 != 0 {
618            Some(MachineCode::load(stream)?)
619        } else {
620            None
621        };
622
623        let first_state = 0;
624        let first_param = 0;
625        let first_obs = first_state + prog.count_states;
626        let first_diff = first_obs + prog.count_obs;
627
628        let count_states = prog.count_states;
629        let count_params = prog.count_params;
630        let count_obs = prog.count_obs;
631        let count_diffs = prog.count_diffs;
632
633        let params = vec![0.0; count_params + 1];
634
635        let config = prog.config().clone();
636        let mir = Mir::new(config.clone());
637
638        let use_simd = config.use_simd() && prog.count_loops == 0;
639        let use_threads = config.use_threads() && prog.mem_size() < 128;
640
641        let can_fast = config.may_fast()
642            && count_states <= 8
643            && count_params == 0
644            && count_obs == 1
645            && count_diffs == 0;
646
647        let bytecode = Self::compile_bytecode(mir, &mut prog)?;
648
649        Ok(Application {
650            prog,
651            compiled,
652            compiled_simd,
653            compiled_fast,
654            bytecode,
655            params,
656            use_simd,
657            use_threads,
658            can_fast,
659            first_state,
660            first_param,
661            first_obs,
662            first_diff,
663            count_states,
664            count_params,
665            count_obs,
666            count_diffs,
667            config,
668        })
669    }
670}