1use anyhow::{anyhow, Result};
2use std::collections::HashSet;
3use std::io::{Read, Write};
4
5use crate::amd::{AmdComplexGenerator, AmdSSEGenerator, AmdScalarGenerator, AmdVectorGenerator};
6use crate::applet::Applet;
7use crate::arm::{ArmComplexGenerator, ArmGenerator, ArmSimdGenerator};
8use crate::complexify::Complexifier;
9use crate::config::Config;
10use crate::generator::Generator;
11use crate::machine::MachineCode;
12use crate::matrix::{combine_matrixes, Matrix};
13use crate::mir::{CompiledMir, Mir};
14use crate::model::Program;
15use crate::riscv64::RiscV;
16use crate::symbol::Loc;
17use crate::utils::*;
18
19use rayon::prelude::*;
20
/// Selects the code generator used when an `Application` is compiled.
///
/// `Application::compile_ty` maps `AmdAVX`, `AmdSSE`, `Arm` and `RiscV` to a
/// native generator, produces no native code for `ByteCode` and `Debug`, and
/// rejects every other variant with an "unrecognized `ty`" error.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum CompilerType {
    /// No native code is generated; `Application::exec` then runs the bytecode.
    ByteCode,
    /// NOTE(review): presumably "use the host architecture"; `compile_ty`
    /// rejects it, so it must be resolved before that point — confirm.
    Native,
    /// NOTE(review): presumably a generic x86-64 request refined into
    /// `AmdAVX`/`AmdSSE` elsewhere; `compile_ty` rejects it — confirm.
    Amd,
    /// x86-64 code produced by `Application::compile_avx`.
    AmdAVX,
    /// x86-64 code produced by `Application::compile_sse`.
    AmdSSE,
    /// aarch64 code produced by `Application::compile_arm`.
    Arm,
    /// riscv64 code produced by `Application::compile_riscv`.
    RiscV,
    /// Deprecated: `compile_ty` prints a notice and generates no native code.
    Debug,
}
41
/// A compiled program together with every artifact needed to execute it.
///
/// The struct is `#[repr(C)]` and `Application::as_applet` transmutes a
/// reference to it into `&Applet`.
/// NOTE(review): the field order therefore presumably has to stay in sync
/// with `Applet` (defined elsewhere) — confirm before reordering fields.
#[repr(C)] pub struct Application {
    /// Scalar native code chosen by `compile_ty`; `None` when no native code was generated.
    pub compiled: Option<MachineCode<f64>>,
    /// SIMD native code, built on demand by `prepare_simd`.
    pub compiled_simd: Option<MachineCode<f64>>,
    /// `config.use_simd()` and the program has no loops (`prog.count_loops == 0`).
    pub use_simd: bool,
    /// Copy of `config.use_threads()`; passed to the vectorized executors.
    pub use_threads: bool,
    /// Copied from `prog.count_states`.
    pub count_states: usize,
    /// Copied from `prog.count_params`.
    pub count_params: usize,
    /// Copied from `prog.count_obs`.
    pub count_obs: usize,
    /// Copied from `prog.count_diffs`.
    pub count_diffs: usize,
    /// Clone of the program's configuration taken in `with_mir`.
    pub config: Config,
    /// The program this application was built from.
    pub prog: Program,
    /// "Fast" native code, built on demand by `prepare_fast`.
    pub compiled_fast: Option<MachineCode<f64>>,
    /// Bytecode form of the (possibly complexified) MIR; used when `compiled` is `None`.
    pub bytecode: CompiledMir,
    /// Parameter buffer handed to every execution; created with `count_params + 1` zeros.
    pub params: Vec<f64>,
    /// Whether a fast function may be built: `config.may_fast()`, at most 8
    /// states, no parameters, exactly one observable and no diffs.
    pub can_fast: bool,
    /// Index of the first state in the memory buffer (set to 0).
    pub first_state: usize,
    /// Index of the first parameter (set to 0).
    pub first_param: usize,
    /// Index of the first observable: `first_state + count_states`.
    pub first_obs: usize,
    /// Index of the first diff: `first_obs + count_obs`.
    pub first_diff: usize,
    /// Locations handed to the `Complexifier`; serialized by `save`.
    pub reals: HashSet<Loc>,
    /// Clone of the MIR before complexification; `Some` only when `config.is_complex()`.
    pub original: Option<Mir>,
}
69
70impl Application {
71 pub fn new(mut prog: Program, reals: HashSet<Loc>) -> Result<Application> {
72 if !reals.is_empty() {
77 prog.builder.config.set_fast_complex(false);
78 }
79
80 let mut mir = Mir::new(prog.config().clone());
81 prog.builder.compile_mir(&mut mir)?;
82 prog.builder.optimize_mir(&mut mir)?;
83 Self::with_mir(prog, reals, mir)
84 }
85
86 pub fn with_mir(mut prog: Program, reals: HashSet<Loc>, mut mir: Mir) -> Result<Application> {
87 let first_state = 0;
88 let first_param = 0;
89 let first_obs = first_state + prog.count_states;
90 let first_diff = first_obs + prog.count_obs;
91
92 let count_states = prog.count_states;
93 let count_params = prog.count_params;
94 let count_obs = prog.count_obs;
95 let count_diffs = prog.count_diffs;
96
97 let params = vec![0.0; count_params + 1];
98
99 let config = prog.config().clone();
100 let mut original: Option<Mir> = None;
101 let compiled: Option<MachineCode<f64>>;
102
103 if config.is_complex() {
104 original = Some(mir.clone());
105 let complexified = Complexifier::new(&reals, config.clone()).complexify(&mir)?;
106
107 if config.fast_complex() {
108 compiled = Self::compile_ty(&config, &mir, &mut prog)?;
113 } else {
114 compiled = Self::compile_ty(&config, &complexified, &mut prog)?;
115 }
116
117 mir = complexified;
118 } else {
119 compiled = Self::compile_ty(&config, &mir, &mut prog)?;
120 }
121
122 let use_simd = config.use_simd() && prog.count_loops == 0;
123 let use_threads = config.use_threads();
124
125 let can_fast = config.may_fast()
126 && count_states <= 8
127 && count_params == 0
128 && count_obs == 1
129 && count_diffs == 0;
130
131 let bytecode = Self::compile_bytecode(mir, &mut prog)?;
133
134 Ok(Application {
135 prog,
136 compiled,
137 compiled_simd: None,
138 compiled_fast: None,
139 bytecode,
140 params,
141 use_simd,
142 use_threads,
143 can_fast,
144 first_state,
145 first_param,
146 first_obs,
147 first_diff,
148 count_states,
149 count_params,
150 count_obs,
151 count_diffs,
152 config,
153 reals,
154 original,
155 })
156 }
157
158 fn compile_ty(
159 config: &Config,
160 mir: &Mir,
161 prog: &mut Program,
162 ) -> Result<Option<MachineCode<f64>>> {
163 let compiled = match config.compiler_type() {
164 CompilerType::AmdAVX => Some(Self::compile_avx(mir, prog)?),
165 CompilerType::AmdSSE => Some(Self::compile_sse(mir, prog)?),
166 CompilerType::Arm => Some(Self::compile_arm(mir, prog)?),
167 CompilerType::RiscV => Some(Self::compile_riscv(mir, prog)?),
168 CompilerType::ByteCode => None,
169 CompilerType::Debug => {
170 println!("`ty = debug` is deprecated");
171 None
172 }
173 _ => return Err(anyhow!("unrecognized `ty`")),
174 };
175
176 Ok(compiled)
177 }
178
/// Consumes the application and wraps it in an `Applet` via `Applet::new`.
///
/// # Errors
/// Returns whatever error `Applet::new` reports.
pub fn seal(self) -> Result<Applet> {
    Applet::new(self)
}
182
/// Reinterprets a borrowed `Application` as a borrowed `Applet` without copying.
///
/// NOTE(review): this relies on `Applet` (defined elsewhere) having a memory
/// layout compatible with the `#[repr(C)]` `Application` — confirm the two
/// types are kept in sync.
pub fn as_applet(&self) -> &Applet {
    // SAFETY: cannot be verified from this file; presumably `Applet` is a
    // layout-compatible view of `Application` — TODO confirm.
    unsafe { std::mem::transmute(self) }
}
186
187 fn compile<G: Generator>(
190 mir: &Mir,
191 prog: &mut Program,
192 mut generator: G,
193 size: usize,
194 arch: &str,
195 lanes: usize,
196 ) -> Result<MachineCode<f64>> {
197 let mem: Vec<f64> = vec![0.0; size];
198 prog.builder.compile_from_mir(
199 mir,
200 &mut generator,
201 prog.count_states,
202 prog.count_obs,
203 prog.count_params,
204 )?;
205
206 Ok(MachineCode::new(
207 arch,
208 generator.bytes(),
209 mem,
210 false,
211 lanes,
212 prog.config().huge(),
213 ))
214 }
215
216 fn compile_fast<G: Generator>(
217 mir: &Mir,
218 prog: &mut Program,
219 mut generator: G,
220 idx_ret: u32,
221 arch: &str,
222 ) -> Result<MachineCode<f64>> {
223 let mem: Vec<f64> = Vec::new();
224 prog.builder.compile_fast_from_mir(
225 mir,
226 &mut generator,
227 prog.count_states,
228 prog.count_obs,
229 idx_ret as i32,
230 )?;
231
232 Ok(MachineCode::new(
233 arch,
234 generator.bytes(),
235 mem,
236 true,
237 1,
238 prog.config().huge(),
239 ))
240 }
241
242 fn compile_bytecode(mir: Mir, prog: &mut Program) -> Result<CompiledMir> {
243 let mem: Vec<f64> = vec![0.0; prog.mem_size()];
244 let stack: Vec<f64> = vec![0.0; prog.builder.stack_size()];
245
246 Ok(CompiledMir::new(mir, mem, stack))
247 }
248
249 fn compile_sse(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
250 Self::compile::<AmdSSEGenerator>(
251 mir,
252 prog,
253 AmdSSEGenerator::new(prog.config().clone()),
254 prog.mem_size(),
255 "x86_64",
256 1,
257 )
258 }
259
260 fn compile_avx(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
261 if prog.config().is_complex() && prog.config().fast_complex() {
262 Self::compile::<AmdComplexGenerator>(
263 mir,
264 prog,
265 AmdComplexGenerator::new(prog.config().clone()),
266 prog.mem_size(),
267 "x86_64",
268 1,
269 )
270 } else {
271 Self::compile::<AmdScalarGenerator>(
272 mir,
273 prog,
274 AmdScalarGenerator::new(prog.config().clone()),
275 prog.mem_size(),
276 "x86_64",
277 1,
278 )
279 }
280 }
281
282 fn compile_avx_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
283 Self::compile::<AmdVectorGenerator>(
284 mir,
285 prog,
286 AmdVectorGenerator::new(prog.config().clone()),
287 prog.mem_size() * 4,
288 "x86_64",
289 4,
290 )
291 }
292
293 fn compile_arm(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
294 if prog.config().is_complex() && prog.config().fast_complex() {
295 Self::compile::<ArmComplexGenerator>(
296 mir,
297 prog,
298 ArmComplexGenerator::new(prog.config().clone()),
299 prog.mem_size(),
300 "aarch64",
301 1,
302 )
303 } else {
304 Self::compile::<ArmGenerator>(
305 mir,
306 prog,
307 ArmGenerator::new(prog.config().clone()),
308 prog.mem_size(),
309 "aarch64",
310 1,
311 )
312 }
313 }
314
315 fn compile_arm_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
316 Self::compile::<ArmSimdGenerator>(
317 mir,
318 prog,
319 ArmSimdGenerator::new(prog.config().clone()),
320 prog.mem_size() * 2,
321 "aarch64",
322 2,
323 )
324 }
325
326 fn compile_riscv(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
327 Self::compile::<RiscV>(
328 mir,
329 prog,
330 RiscV::new(prog.config().clone()),
331 prog.mem_size(),
332 "riscv64",
333 1,
334 )
335 }
336
337 fn compile_amd_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
338 if prog.config().has_avx() {
339 Self::compile_fast(
340 mir,
341 prog,
342 AmdScalarGenerator::new(prog.config().clone()),
343 idx_ret,
344 "x86_64",
345 )
346 } else {
347 Self::compile_fast(
348 mir,
349 prog,
350 AmdSSEGenerator::new(prog.config().clone()),
351 idx_ret,
352 "x86_64",
353 )
354 }
355 }
356
357 fn compile_arm_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
358 Self::compile_fast(
359 mir,
360 prog,
361 ArmGenerator::new(prog.config().clone()),
362 idx_ret,
363 "aarch64",
364 )
365 }
366
367 fn compile_riscv_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
368 Self::compile_fast(
369 mir,
370 prog,
371 RiscV::new(prog.config().clone()),
372 idx_ret,
373 "riscv64",
374 )
375 }
376
377 #[inline]
380 pub fn exec(&mut self) {
381 if let Some(compiled) = &mut self.compiled {
382 compiled.exec(&self.params[..])
383 } else {
384 self.bytecode.exec(&self.params[..]);
385 }
386 }
387
388 pub fn exec_callable(&mut self, xx: &[f64]) -> f64 {
389 if let Some(compiled) = &mut self.compiled {
390 let mem = compiled.mem_mut();
391 mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
392 compiled.exec(&self.params[..]);
393 compiled.mem()[self.first_obs]
394 } else {
395 let mem = self.bytecode.mem_mut();
396 mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
397 self.bytecode.exec(&self.params[..]);
398 self.bytecode.mem()[self.first_obs]
399 }
400 }
401
402 pub fn prepare_simd(&mut self) {
403 if self.compiled_simd.is_none() && self.use_simd {
405 if self.config.has_avx() {
406 self.compiled_simd =
407 Self::compile_avx_simd(&self.bytecode.mir, &mut self.prog).ok();
408 } else if self.config.is_arm64() {
409 self.compiled_simd =
410 Self::compile_arm_simd(&self.bytecode.mir, &mut self.prog).ok();
411 }
412 };
413 }
414
415 fn prepare_fast(&mut self) {
416 if self.compiled_simd.is_none() && self.can_fast {
418 if self.config.is_amd64() {
419 self.compiled_fast = Self::compile_amd_fast(
420 &self.bytecode.mir,
421 &mut self.prog,
422 self.first_obs as u32,
423 )
424 .ok();
425 } else if self.config.is_arm64() {
426 self.compiled_fast = Self::compile_arm_fast(
427 &self.bytecode.mir,
428 &mut self.prog,
429 self.first_obs as u32,
430 )
431 .ok();
432 } else if self.config.is_riscv64() {
433 self.compiled_fast = Self::compile_riscv_fast(
434 &self.bytecode.mir,
435 &mut self.prog,
436 self.first_obs as u32,
437 )
438 .ok();
439 }
440 };
441 }
442
443 pub fn get_fast(&mut self) -> Option<CompiledFunc<f64>> {
444 self.prepare_fast();
445 self.compiled_fast.as_ref().map(|c| c.func())
446 }
447
448 pub fn exec_vectorized(&mut self, states: &mut Matrix, obs: &mut Matrix) {
449 if let Some(compiled) = &self.compiled {
450 if !compiled.support_indirect() {
451 self.exec_vectorized_simple(states, obs);
452 return;
453 }
454
455 self.prepare_simd();
456
457 if let Some(simd) = &self.compiled_simd {
458 self.exec_vectorized_simd(states, obs, self.use_threads, simd.count_lanes());
459 } else {
460 self.exec_vectorized_scalar(states, obs, self.use_threads);
461 }
462 }
463 }
464
465 pub fn exec_vectorized_simple(&mut self, states: &Matrix, obs: &mut Matrix) {
466 assert!(states.ncols == obs.ncols);
467 let n = states.ncols;
468 let params = &self.params[..];
469
470 if let Some(compiled) = &mut self.compiled {
471 for t in 0..n {
472 {
473 let mem = compiled.mem_mut();
474 for i in 0..self.count_states {
475 mem[self.first_state + i] = states.get(i, t);
476 }
477 }
478
479 compiled.exec(params);
480
481 {
482 let mem = compiled.mem_mut();
483 for i in 0..self.count_obs {
484 obs.set(i, t, mem[self.first_obs + i]);
485 }
486 }
487 }
488 } else {
489 for t in 0..n {
490 {
491 let mem = self.bytecode.mem_mut();
492 for i in 0..self.count_states {
493 mem[self.first_state + i] = states.get(i, t);
494 }
495 }
496
497 self.bytecode.exec(params);
498
499 {
500 let mem = self.bytecode.mem_mut();
501 for i in 0..self.count_obs {
502 obs.set(i, t, mem[self.first_obs + i]);
503 }
504 }
505 }
506 }
507 }
508
509 fn exec_single(t: usize, v: &Matrix, params: &[f64], f: CompiledFunc<f64>) {
510 let p = v.p.as_ptr();
511 f(std::ptr::null(), p, t, params.as_ptr());
512 }
513
514 pub fn exec_vectorized_scalar(&mut self, states: &mut Matrix, obs: &mut Matrix, threads: bool) {
515 if let Some(compiled) = &mut self.compiled {
516 assert!(states.ncols == obs.ncols);
517 let n = states.ncols;
518 let f = compiled.func();
519 let params = &self.params[..];
520 let v = combine_matrixes(states, obs);
521
522 if threads {
523 (0..n)
524 .into_par_iter()
525 .for_each(|t| Self::exec_single(t, &v, params, f));
526 } else {
527 (0..n)
528 .for_each(|t| Self::exec_single(t, &v, params, f));
530 }
531 }
532 }
533
534 pub fn exec_vectorized_simd(
535 &mut self,
536 states: &mut Matrix,
537 obs: &mut Matrix,
538 threads: bool,
539 l: usize,
540 ) {
541 if let Some(compiled) = &mut self.compiled {
542 assert!(states.ncols == obs.ncols);
543 let n = states.ncols;
544 let params = &self.params[..];
545 let n0 = l * (n / l);
546 let v = combine_matrixes(states, obs);
547
548 if let Some(g) = &mut self.compiled_simd {
549 let f = g.func();
550 if threads {
551 (0..n / l)
552 .into_par_iter()
553 .for_each(|t| Self::exec_single(t, &v, params, f));
554 } else {
555 (0..n / l).for_each(|t| Self::exec_single(t, &v, params, f));
556 }
557 }
558
559 let f = compiled.func();
560
561 if threads {
562 (n0..n)
563 .into_par_iter()
564 .for_each(|t| Self::exec_single(t, &v, params, f));
565 } else {
566 (n0..n).for_each(|t| Self::exec_single(t, &v, params, f));
567 }
568 }
569 }
570
571 pub fn dump(&mut self, name: &str, what: &str) -> bool {
572 match what {
573 "scalar" => {
574 if let Some(f) = &self.compiled {
575 f.dump(name);
576 true
577 } else {
578 false
579 }
580 }
581 "simd" => {
582 self.prepare_simd();
583
584 if let Some(f) = &self.compiled_simd {
585 f.dump(name);
586 true
587 } else {
588 false
589 }
590 }
591 "fast" => {
592 self.prepare_fast();
593
594 if let Some(f) = &self.compiled_fast {
595 f.dump(name);
596 true
597 } else {
598 false
599 }
600 }
601 "bytecode" => {
602 self.bytecode.dump(name);
603 true
604 }
605 "stats" => {
606 let size = if let Some(f) = &self.compiled {
607 f.as_machine().unwrap().size
608 } else {
609 0
610 };
611 self.bytecode.mir.print_stats(name, size);
612 true
613 }
614 _ => false,
615 }
616 }
617
618 pub fn dumps(&self) -> Vec<u8> {
619 if let Some(f) = &self.compiled {
620 f.dumps()
621 } else {
622 Vec::new()
623 }
624 }
625
/// Magic number written first by `Storage::save` and checked by
/// `Storage::load` to recognize a serialized `Application`.
/// NOTE(review): the literal only fits a 64-bit `usize` — confirm that 32-bit
/// targets are out of scope.
const MAGIC: usize = 0x40568795410d08e9;
629}
630
631fn save_reals(stream: &mut impl Write, reals: &HashSet<Loc>) -> Result<()> {
632 let num_elems = reals.len();
633 stream.write_all(&num_elems.to_le_bytes())?;
634
635 for r in reals.iter() {
636 let b = match r {
637 Loc::Mem(idx) => 0x100000000 | (*idx as usize),
638 Loc::Stack(idx) => 0x200000000 | (*idx as usize),
639 Loc::Param(idx) => 0x300000000 | (*idx as usize),
640 };
641 stream.write_all(&b.to_le_bytes())?;
642 }
643
644 Ok(())
645}
646
647fn load_reals(stream: &mut impl Read) -> Result<HashSet<Loc>> {
648 let mut bytes: [u8; 8] = [0; 8];
649
650 stream.read_exact(&mut bytes)?;
651 let num_elems = usize::from_le_bytes(bytes);
652
653 let mut reals: HashSet<Loc> = HashSet::new();
654
655 for _ in 0..num_elems {
656 stream.read_exact(&mut bytes)?;
657 let b = usize::from_le_bytes(bytes);
658
659 let r = match b >> 32 {
660 1 => Loc::Mem((b & 0xffffffff) as u32),
661 2 => Loc::Stack((b & 0xffffffff) as u32),
662 3 => Loc::Param((b & 0xffffffff) as u32),
663 _ => return Err(anyhow!("invalid loc")),
664 };
665 reals.insert(r);
666 }
667
668 Ok(reals)
669}
670
671impl Storage for Application {
672 fn save(&self, stream: &mut impl Write) -> Result<()> {
673 stream.write_all(&Self::MAGIC.to_le_bytes())?;
674
675 let version: usize = 3;
676 stream.write_all(&version.to_le_bytes())?;
677
678 self.prog.save(stream)?;
679
680 let mut mask: usize = 0;
681
682 if self.compiled.is_some() && self.compiled.as_ref().unwrap().as_machine().is_some() {
683 mask |= 1;
684 };
685
686 if self.compiled_fast.is_some()
687 && self.compiled_fast.as_ref().unwrap().as_machine().is_some()
688 {
689 mask |= 2;
690 }
691
692 if self.compiled_simd.is_some()
693 && self.compiled_simd.as_ref().unwrap().as_machine().is_some()
694 {
695 mask |= 4;
696 }
697
698 stream.write_all(&mask.to_le_bytes())?;
699
700 match &self.original {
701 Some(mir) => mir.save(stream)?,
702 None => self.bytecode.mir.save(stream)?,
703 }
704
705 save_reals(stream, &self.reals)?;
706
707 Ok(())
708 }
709
710 fn load(stream: &mut impl Read, config: &Config) -> Result<Self> {
711 let mut bytes: [u8; 8] = [0; 8];
712
713 stream.read_exact(&mut bytes)?;
714
715 if usize::from_le_bytes(bytes) != Self::MAGIC {
716 return Err(anyhow!("invalid magic number (Application)"));
717 }
718
719 stream.read_exact(&mut bytes)?;
720
721 if usize::from_le_bytes(bytes) != 3 {
722 return Err(anyhow!("invalid sjb version"));
723 }
724
725 let prog = Program::load(stream, config)?;
726
727 stream.read_exact(&mut bytes)?;
728 let mask = usize::from_le_bytes(bytes);
729
730 let mir = Mir::load(stream, prog.config())?;
731
732 let reals = load_reals(stream)?;
733
734 let mut app = Application::with_mir(prog, reals, mir)?;
735
736 if mask & 2 != 0 {
737 app.prepare_fast();
738 }
739
740 if mask & 4 != 0 {
741 app.prepare_simd();
742 }
743
744 Ok(app)
745 }
746}