1#[cfg(target_arch = "x86_64")]
2use std::arch::x86_64::{__m256d, _mm256_setzero_pd};
3
4use std::collections::{HashMap, HashSet};
5
6use anyhow::{anyhow, Result};
7use num_complex::Complex;
8use rayon::prelude::*;
9
10use crate::config::Config;
11use crate::defuns::Defuns;
12use crate::expr::Expr;
13pub use crate::instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
14use crate::model::{CellModel, Equation, Program, Variable};
15use crate::parser::Parser;
16use crate::symbol::Loc;
17use crate::utils::CompiledFunc;
18use crate::Application;
19
/// Front end that turns symbolic expressions into a compiled [`Application`].
///
/// Holds the [`Config`] used for every compilation and the user-defined
/// functions registered through `def_unary` / `def_binary`.
pub struct Compiler {
    // Copied into every `Program::new` call made by this compiler.
    config: Config,
    // User-defined functions; `compile_params` passes them to `Application::new`.
    df: Defuns,
}
25
// Stand-in for the x86_64 `__m256d` vector type (four `f64` lanes), so the
// non-x86 `call_simd*` signatures still compile on other architectures.
// NOTE(review): unlike the real intrinsic type, this array is only 8-byte aligned.
#[cfg(not(target_arch = "x86_64"))]
#[allow(non_camel_case_types)]
type __m256d = [f64; 4];
29
30impl Compiler {
84 pub fn new() -> Compiler {
86 Compiler {
87 config: Config::default(),
88 df: Defuns::new(),
89 }
90 }
91
92 pub fn with_config(config: Config) -> Compiler {
93 Compiler {
94 config,
95 df: Defuns::new(),
96 }
97 }
98
99 pub fn compile(&mut self, states: &[Expr], obs: &[Expr]) -> Result<Application> {
104 self.compile_params(states, obs, &[])
105 }
106
107 pub fn compile_params(
117 &mut self,
118 states: &[Expr],
119 obs: &[Expr],
120 params: &[Expr],
121 ) -> Result<Application> {
122 let mut vars: Vec<Variable> = Vec::new();
123
124 for state in states.iter() {
125 let v = state.to_variable()?;
126 vars.push(v);
127 }
128
129 let mut ps: Vec<Variable> = Vec::new();
130
131 for p in params.iter() {
132 let v = p.to_variable()?;
133 ps.push(v);
134 }
135
136 let mut eqs: Vec<Equation> = Vec::new();
137
138 for (i, expr) in obs.iter().enumerate() {
139 let name = format!("${}", i);
140 let lhs = Expr::var(&name);
141 eqs.push(Expr::equation(&lhs, expr));
142 }
143
144 let ml = CellModel {
145 iv: Expr::var("$_").to_variable()?,
146 params: ps,
147 states: vars,
148 algs: Vec::new(),
149 odes: Vec::new(),
150 obs: eqs,
151 };
152
153 let prog = Program::new(&ml, self.config)?;
154 let mut app = Application::new(prog, HashSet::new(), &self.df);
156
157 #[cfg(target_arch = "aarch64")]
158 if let Ok(app) = &mut app {
159 app.dump("dump.bin", "scalar");
161 std::fs::remove_file("dump.bin")?;
162 };
163
164 app
165 }
166
167 pub fn def_unary(&mut self, op: &str, f: extern "C" fn(f64) -> f64) {
169 self.df.add_unary(op, f)
170 }
171
172 pub fn def_binary(&mut self, op: &str, f: extern "C" fn(f64, f64) -> f64) {
174 self.df.add_binary(op, f)
175 }
176}
177
/// Views a flat `f64` buffer as a slice of 256-bit vectors (four `f64` lanes each).
///
/// # Panics
/// Panics if `a.len()` is not a multiple of 4.
///
/// # Safety
/// `a.as_ptr()` must be aligned to `align_of::<__m256d>()` (32 bytes), as
/// required by `slice::from_raw_parts`. Only the length is checked here.
#[cfg(target_arch = "x86_64")]
unsafe fn simd_slice(a: &[f64]) -> &[__m256d] {
    assert!(a.len() & 3 == 0);
    let vectors = a.len() / 4;
    // SAFETY: the caller guarantees alignment, and `vectors * 32` bytes equals the
    // byte length of `a`, so the view covers exactly the borrowed buffer.
    unsafe { std::slice::from_raw_parts(a.as_ptr().cast::<__m256d>(), vectors) }
}
185
/// Mutable counterpart of `simd_slice`: views a flat `f64` buffer as 256-bit
/// vectors of four `f64` lanes that alias the original storage.
///
/// # Panics
/// Panics if `a.len()` is not a multiple of 4.
///
/// # Safety
/// `a.as_mut_ptr()` must be aligned to `align_of::<__m256d>()` (32 bytes), as
/// required by `slice::from_raw_parts_mut`. Only the length is checked here.
#[cfg(target_arch = "x86_64")]
unsafe fn simd_slice_mut(a: &mut [f64]) -> &mut [__m256d] {
    assert!(a.len() & 3 == 0);
    let vectors = a.len() / 4;
    let base = a.as_mut_ptr().cast::<__m256d>();
    // SAFETY: alignment is the caller's obligation; `vectors * 32` bytes is the
    // exact byte length of `a`, and the result reborrows `a` exclusively.
    unsafe { std::slice::from_raw_parts_mut(base, vectors) }
}
194
/// A compiled function exposed as a native `extern "C"` pointer that takes one
/// `f64` per state (1 to 8 states) and returns one `f64`.
///
/// Each variant also carries a shared borrow of the [`Application`] it was
/// obtained from (see `Application::fast_func`), so the application stays
/// borrowed while the pointer is held.
pub enum FastFunc<'a> {
    /// Function of one state.
    F1(extern "C" fn(f64) -> f64, &'a Application),
    /// Function of two states.
    F2(extern "C" fn(f64, f64) -> f64, &'a Application),
    /// Function of three states.
    F3(extern "C" fn(f64, f64, f64) -> f64, &'a Application),
    /// Function of four states.
    F4(extern "C" fn(f64, f64, f64, f64) -> f64, &'a Application),
    /// Function of five states.
    F5(
        extern "C" fn(f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
    /// Function of six states.
    F6(
        extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
    /// Function of seven states.
    F7(
        extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
    /// Function of eight states.
    F8(
        extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64,
        &'a Application,
    ),
}
217
218impl Application {
219 pub fn call(&mut self, args: &[f64]) -> Vec<f64> {
226 {
227 let mem = self.compiled.mem_mut();
228 let states = &mut mem[self.first_state..self.first_state + self.count_states];
229 states.copy_from_slice(args);
230 }
231
232 self.compiled.exec(&self.params[..]);
233
234 let obs = {
235 let mem = self.compiled.mem();
236 &mem[self.first_obs..self.first_obs + self.count_obs]
237 };
238
239 obs.to_vec()
240 }
241
242 pub fn call_params(&mut self, args: &[f64], params: &[f64]) -> Vec<f64> {
250 {
251 let mem = self.compiled.mem_mut();
252 let states = &mut mem[self.first_state..self.first_state + self.count_states];
253 states.copy_from_slice(args);
254 }
255
256 self.compiled.exec(params);
257
258 let obs = {
259 let mem = self.compiled.mem();
260 &mem[self.first_obs..self.first_obs + self.count_obs]
261 };
262
263 obs.to_vec()
264 }
265
266 #[inline(always)]
268 pub fn evaluate<T: Sized + Copy>(&mut self, args: &[T], outs: &mut [T]) {
269 if self.prog.config().is_bytecode() {
270 let mut stack: Vec<f64> = vec![0.0; self.prog.builder.block().sym_table.num_stack];
271 let mut regs = [0.0; 32];
272
273 let outs: &mut [f64] = unsafe { std::mem::transmute(outs) };
274 let args: &[f64] = unsafe { std::mem::transmute(args) };
275
276 self.mir.exec_instruction(outs, &mut stack, &mut regs, args);
277
278 return;
279 }
280
281 let f = self.compiled.func();
282
283 f(
284 outs.as_ptr() as *mut f64,
285 std::ptr::null(),
286 0,
287 args.as_ptr() as *const f64,
288 );
289 }
290
291 #[inline(always)]
293 pub fn evaluate_single<T: Sized + Copy + Default>(&mut self, args: &[T]) -> T {
294 let mut outs = [T::default(); 1];
295 self.evaluate(args, &mut outs);
296 outs[0]
297 }
298
299 #[inline(always)]
301 pub fn evaluate_simd<T: Sized + Copy>(&mut self, args: &[T], outs: &mut [T]) {
302 if let Some(g) = &mut self.compiled_simd {
303 let f = g.func();
304
305 f(
306 outs.as_ptr() as *mut f64,
307 std::ptr::null(),
308 0,
309 args.as_ptr() as *const f64,
310 );
311 }
312 }
313
314 #[inline(always)]
316 pub fn evaluate_simd_single<T: Sized + Copy + Default>(&mut self, args: &[T]) -> T {
317 let mut outs = [T::default(); 1];
318 self.evaluate_simd(args, &mut outs);
319 outs[0]
320 }
321
322 fn evaluate_row(
323 args: &[f64],
324 args_idx: usize,
325 outs: &[f64],
326 outs_idx: usize,
327 f: CompiledFunc<f64>,
328 transpose: bool,
329 ) -> i32 {
330 unsafe {
331 f(
332 outs.as_ptr().add(outs_idx),
333 std::ptr::null(),
334 if transpose { 1 } else { 0 },
335 args.as_ptr().add(args_idx),
336 )
337 }
338 }
339
340 fn evaluate_matrix_with_threads(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
342 let count_params = self.count_params;
343 let count_obs = self.count_obs;
344 let f_scalar = self.compiled.func();
345
346 (0..n).into_par_iter().for_each(|t| {
347 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
348 });
349 }
350
351 fn evaluate_matrix_without_threads(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
353 let count_params = self.count_params;
354 let count_obs = self.count_obs;
355 let f_scalar = self.compiled.func();
356
357 for t in 0..n {
358 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
359 }
360 }
361
362 fn evaluate_matrix_with_threads_simd(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
363 let count_params = self.count_params;
364 let count_obs = self.count_obs;
365
366 if let Some(compiled) = &self.compiled_simd {
367 let f_simd = compiled.func();
368 let f_scalar = self.compiled.func();
369 let lanes = compiled.count_lanes();
370
371 (0..n / lanes).into_par_iter().for_each(|k| {
372 let top = k * lanes;
373 if Self::evaluate_row(
374 args,
375 top * count_params,
376 outs,
377 top * count_obs,
378 f_simd,
379 true,
380 ) != 0
381 {
382 for i in 0..lanes {
383 Self::evaluate_row(
384 args,
385 (top + i) * count_params,
386 outs,
387 (top + i) * count_obs,
388 f_scalar,
389 false,
390 );
391 }
392 }
393 });
394
395 for t in lanes * (n / lanes)..n {
396 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
397 }
398 }
399 }
400
401 fn evaluate_matrix_without_threads_simd(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
402 let count_params = self.count_params;
403 let count_obs = self.count_obs;
404
405 if let Some(compiled) = &self.compiled_simd {
406 let f_simd = compiled.func();
407 let f_scalar = self.compiled.func();
408 let lanes = compiled.count_lanes();
409
410 for k in 0..n / lanes {
411 let top = k * lanes;
412 if Self::evaluate_row(
413 args,
414 top * count_params,
415 outs,
416 top * count_obs,
417 f_simd,
418 true,
419 ) != 0
420 {
421 for i in 0..lanes {
422 Self::evaluate_row(
423 args,
424 (top + i) * count_params,
425 outs,
426 (top + i) * count_obs,
427 f_scalar,
428 false,
429 );
430 }
431 }
432 }
433
434 for t in lanes * (n / lanes)..n {
435 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
436 }
437 }
438 }
439
440 pub fn evaluate_matrix_bytecode(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
441 let count_params = self.count_params;
442 let count_obs = self.count_obs;
443
444 for i in 0..n {
445 self.evaluate(
446 &args[i * count_params..(i + 1) * count_params],
447 &mut outs[i * count_obs..(i + 1) * count_obs],
448 );
449 }
450 }
451
452 pub fn evaluate_matrix(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
454 if self.prog.config().is_bytecode() {
455 self.evaluate_matrix_bytecode(args, outs, n);
456 } else if self.use_threads {
457 if self.compiled_simd.is_some() {
458 self.evaluate_matrix_with_threads_simd(args, outs, n);
459 } else {
460 self.evaluate_matrix_with_threads(args, outs, n);
461 }
462 } else {
463 if self.compiled_simd.is_some() {
464 self.evaluate_matrix_without_threads_simd(args, outs, n);
465 } else {
466 self.evaluate_matrix_without_threads(args, outs, n);
467 }
468 }
469 }
470
471 pub fn evaluate_complex_matrix(
472 &mut self,
473 args: &[Complex<f64>],
474 outs: &mut [Complex<f64>],
475 n: usize,
476 ) {
477 let args = recast_complex_vec(args);
478 let outs = recast_complex_vec_mut(outs);
479
480 if self.prog.config().is_bytecode() {
481 self.evaluate_matrix_bytecode(args, outs, n);
482 } else if self.use_threads {
483 if self.compiled_simd.is_some() {
484 self.evaluate_matrix_with_threads_simd(args, outs, n);
485 } else {
486 self.evaluate_matrix_with_threads(args, outs, n);
487 }
488 } else {
489 if self.compiled_simd.is_some() {
490 self.evaluate_matrix_without_threads_simd(args, outs, n);
491 } else {
492 self.evaluate_matrix_without_threads(args, outs, n);
493 }
494 }
495 }
496
497 pub fn evaluate_simd_matrix<T: Sized + Copy>(&mut self, args: &[T], outs: &mut [T], n: usize) {
499 let args_size = args.len() / n;
500 let outs_size = outs.len() / n;
501
502 for (p, q) in args.chunks(args_size).zip(outs.chunks_mut(outs_size)) {
503 self.evaluate_simd(p, q);
504 }
505 }
506
507 #[cfg(target_arch = "x86_64")]
519 pub fn call_simd(&mut self, args: &[__m256d]) -> Result<Vec<__m256d>> {
520 if let Some(f) = &mut self.compiled_simd {
521 {
522 let mem = f.mem_mut();
523 let states = unsafe {
524 simd_slice_mut(
525 &mut mem[self.first_state * 4..(self.first_state + self.count_states) * 4],
526 )
527 };
528 states.copy_from_slice(args);
529 }
530
531 f.exec(&self.params);
532
533 {
534 let mem = f.mem();
535 let obs = unsafe {
536 simd_slice(&mem[self.first_obs * 4..(self.first_obs + self.count_obs) * 4])
537 };
538 let mut res = unsafe { vec![_mm256_setzero_pd(); self.count_obs] };
539 res.copy_from_slice(obs);
540 Ok(res)
541 }
542 } else {
543 self.prepare_simd();
544 if self.compiled_simd.is_some() {
545 self.call_simd(args)
546 } else {
547 Err(anyhow!("cannot compile SIMD"))
548 }
549 }
550 }
551
552 #[cfg(not(target_arch = "x86_64"))]
553 pub unsafe fn call_simd(&mut self, _args: &[__m256d]) -> Result<Vec<__m256d>> {
554 Err(anyhow!("cannot compile SIMD"))
555 }
556
557 #[cfg(target_arch = "x86_64")]
571 pub fn call_simd_params(&mut self, args: &[__m256d], params: &[f64]) -> Result<Vec<__m256d>> {
572 if let Some(f) = &mut self.compiled_simd {
573 {
574 let mem = f.mem_mut();
575 let states = unsafe {
576 simd_slice_mut(
577 &mut mem[self.first_state * 4..(self.first_state + self.count_states) * 4],
578 )
579 };
580 states.copy_from_slice(args);
581 }
582
583 f.exec(params);
584
585 {
586 let mem = f.mem();
587 let obs = unsafe {
588 simd_slice(&mem[self.first_obs * 4..(self.first_obs + self.count_obs) * 4])
589 };
590 let mut res = unsafe { vec![_mm256_setzero_pd(); self.count_obs] };
591 res.copy_from_slice(obs);
592 Ok(res)
593 }
594 } else {
595 self.prepare_simd();
596 if self.compiled_simd.is_some() {
597 self.call_simd_params(args, params)
598 } else {
599 Err(anyhow!("cannot compile SIMD"))
600 }
601 }
602 }
603
604 #[cfg(not(target_arch = "x86_64"))]
605 pub unsafe fn call_simd_params(
606 &mut self,
607 _args: &[__m256d],
608 _params: &[f64],
609 ) -> Result<Vec<__m256d>> {
610 Err(anyhow!("cannot compile SIMD"))
611 }
612
613 pub fn fast_func(&mut self) -> Result<FastFunc<'_>> {
653 let f = self.get_fast();
654
655 if let Some(f) = f {
656 match self.count_states {
657 1 => {
658 let g: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
659 Ok(FastFunc::F1(g, self))
660 }
661 2 => {
662 let g: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
663 Ok(FastFunc::F2(g, self))
664 }
665 3 => {
666 let g: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
667 Ok(FastFunc::F3(g, self))
668 }
669 4 => {
670 let g: extern "C" fn(f64, f64, f64, f64) -> f64 =
671 unsafe { std::mem::transmute(f) };
672 Ok(FastFunc::F4(g, self))
673 }
674 5 => {
675 let g: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
676 unsafe { std::mem::transmute(f) };
677 Ok(FastFunc::F5(g, self))
678 }
679 6 => {
680 let g: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
681 unsafe { std::mem::transmute(f) };
682 Ok(FastFunc::F6(g, self))
683 }
684 7 => {
685 let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
686 unsafe { std::mem::transmute(f) };
687 Ok(FastFunc::F7(g, self))
688 }
689 8 => {
690 let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64 =
691 unsafe { std::mem::transmute(f) };
692 Ok(FastFunc::F8(g, self))
693 }
694 _ => Err(anyhow!("not a fast function")),
695 }
696 } else {
697 Err(anyhow!("not a fast function"))
698 }
699 }
700}
701
702pub fn recast_complex_vec(v: &[Complex<f64>]) -> &[f64] {
703 let n = v.len();
704 let p: *const f64 = unsafe { std::mem::transmute(v.as_ptr()) };
705 let q: &[f64] = unsafe { std::slice::from_raw_parts(p, 2 * n) };
706 q
707}
708
709pub fn recast_complex_vec_mut(v: &mut [Complex<f64>]) -> &mut [f64] {
710 let n = v.len();
711 let p: *mut f64 = unsafe { std::mem::transmute(v.as_mut_ptr()) };
712 let q: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(p, 2 * n) };
713 q
714}
715
/// Translates a Symbolica instruction list (`SymbolicaModel`) into a
/// `CellModel` of equations that can be compiled into an [`Application`].
#[derive(Debug, Clone)]
pub struct Translator {
    // Passed to `Program::new` by `compile`.
    config: Config,
    // User-defined functions; passed to `Application::new` by `compile`.
    df: Defuns,
    // Instructions after slot renaming, filled by the `append_*` methods.
    ssa: Vec<Instruction>,
    // Constant pool indexed by `Slot::Const`.
    consts: Vec<Complex<f64>>,
    // Highest `Slot::Param` index seen by `expr` (an index, not a count).
    count_params: usize,
    // Number of `Static` slots allocated so far by `produce`.
    count_statics: usize,
    // Equations emitted by `translate`.
    eqs: Vec<Equation>,
    // Mapping from `Temp` index to its allocated `Static` slot.
    temps: HashMap<usize, Slot>,
    // Per `Static` index: references counted in `consume` (and reuse in `produce`).
    counts: HashMap<usize, usize>,
    // Pending expressions of single-use statics, inlined at their use in `expr`.
    cache: HashMap<usize, Expr>,
    // Output expressions by `Out` index (only filled when there are no jumps).
    outs: HashMap<usize, Expr>,
    // Parameter locations that were used as real-valued.
    reals: HashSet<Loc>,
    // Externally requested parameter count (see `set_num_params`).
    num_params: usize,
    // Set once an `IfElse` is appended; disables inlining in `assign`.
    has_jump: bool,
    // Largest label id seen in a `Goto`.
    last_label: usize,
    // Current `IfElse` nesting depth while converting.
    depth: usize,
    // Stack of open `IfElse` conditions while translating.
    conds: Vec<Slot>,
}
739
740impl Translator {
741 pub fn new(config: Config, df: &Defuns) -> Translator {
742 Translator {
743 config,
744 df: df.clone(),
745 ssa: Vec::new(),
746 consts: Vec::new(),
747 count_params: 0,
748 count_statics: 0,
749 eqs: Vec::new(),
750 temps: HashMap::new(),
751 counts: HashMap::new(),
752 cache: HashMap::new(),
753 outs: HashMap::new(),
754 reals: HashSet::new(),
755 num_params: 0,
756 has_jump: false,
757 last_label: 0,
758 depth: 0,
759 conds: Vec::new(),
760 }
761 }
762
763 pub fn parse_model(&mut self, model: &SymbolicaModel) -> Result<()> {
764 for c in model.2.iter() {
765 let val = Complex::new(c.value().re, c.value().im);
766 self.consts.push(val);
767 }
768
769 self.convert(model)?;
770 Ok(())
771 }
772
773 fn convert(&mut self, model: &SymbolicaModel) -> Result<()> {
776 for line in model.0.iter() {
777 match line {
778 Instruction::Add(lhs, args, num_reals) => self.append_add(lhs, args, *num_reals)?,
779 Instruction::Mul(lhs, args, num_reals) => self.append_mul(lhs, args, *num_reals)?,
780 Instruction::Pow(lhs, arg, p, is_real) => {
781 self.append_pow(lhs, arg, *p, *is_real)?
782 }
783 Instruction::Powf(lhs, arg, p, is_real) => {
784 self.append_powf(lhs, arg, p, *is_real)?
785 }
786 Instruction::Assign(lhs, rhs) => self.append_assign(lhs, rhs)?,
787 Instruction::Fun(lhs, fun, arg, is_real) => {
788 self.append_fun(lhs, fun, arg, *is_real)?
789 }
790 Instruction::Join(lhs, cond, true_val, false_val) => {
791 self.depth -= 1;
792 self.append_join(lhs, cond, true_val, false_val)?
793 }
794 Instruction::Label(id) => self.append_label(*id)?,
795 Instruction::IfElse(cond, id) => {
796 self.append_if_else(cond, *id)?;
797 self.depth += 1;
798 }
799 Instruction::Goto(id) => self.append_goto(*id)?,
800 Instruction::ExternalFun(lhs, op, args) => {
801 self.append_external_fun(lhs, op, args)?
802 }
803 }
804 }
805
806 Ok(())
807 }
808
809 pub fn append_constant(&mut self, z: Complex<f64>) -> Result<usize> {
810 self.consts.push(z);
811 Ok(self.consts.len() - 1)
812 }
813
814 pub fn append_add(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
815 let args = self.consume_list(args)?;
816 let lhs = self.produce(lhs)?;
817 self.ssa.push(Instruction::Add(lhs, args, num_reals));
818 Ok(())
819 }
820
821 pub fn append_mul(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
822 let args = self.consume_list(args)?;
823 let lhs = self.produce(lhs)?;
824 self.ssa.push(Instruction::Mul(lhs, args, num_reals));
825 Ok(())
826 }
827
828 pub fn append_pow(&mut self, lhs: &Slot, arg: &Slot, p: i64, is_real: bool) -> Result<()> {
829 let arg = self.consume(arg)?;
830 let lhs = self.produce(lhs)?;
831 self.ssa.push(Instruction::Pow(lhs, arg, p, is_real));
832 Ok(())
833 }
834
835 pub fn append_powf(&mut self, lhs: &Slot, arg: &Slot, p: &Slot, is_real: bool) -> Result<()> {
836 let arg = self.consume(arg)?;
837 let p = self.consume(p)?;
838 let lhs = self.produce(lhs)?;
839 self.ssa.push(Instruction::Powf(lhs, arg, p, is_real));
840 Ok(())
841 }
842
843 pub fn append_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
844 let rhs = self.consume(rhs)?;
845 let lhs = self.produce(lhs)?;
846 self.ssa.push(Instruction::Assign(lhs, rhs));
847 Ok(())
848 }
849
850 pub fn append_label(&mut self, id: usize) -> Result<()> {
851 self.ssa.push(Instruction::Label(id));
852 Ok(())
853 }
854
855 pub fn append_if_else(&mut self, cond: &Slot, id: usize) -> Result<()> {
856 self.has_jump = true;
857 let cond = self.consume(cond)?;
858 self.ssa.push(Instruction::IfElse(cond, id));
859 Ok(())
860 }
861
862 pub fn append_goto(&mut self, id: usize) -> Result<()> {
863 self.last_label = self.last_label.max(id);
864 self.ssa.push(Instruction::Goto(id));
865 Ok(())
866 }
867
868 pub fn append_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
869 let args = self.consume_list(args)?;
870 let lhs = self.produce(lhs)?;
871 self.ssa
872 .push(Instruction::ExternalFun(lhs, op.to_string(), args));
873 Ok(())
874 }
875
876 pub fn append_fun(
877 &mut self,
878 lhs: &Slot,
879 fun: &BuiltinSymbol,
880 arg: &Slot,
881 is_real: bool,
882 ) -> Result<()> {
883 let arg = self.consume(arg)?;
884 let lhs = self.produce(lhs)?;
885 self.ssa.push(Instruction::Fun(lhs, *fun, arg, is_real));
886 Ok(())
887 }
888
889 pub fn append_join(
890 &mut self,
891 lhs: &Slot,
892 cond: &Slot,
893 true_val: &Slot,
894 false_val: &Slot,
895 ) -> Result<()> {
896 let cond = self.consume(cond)?;
897 let true_val = self.consume(true_val)?;
898 let false_val = self.consume(false_val)?;
899 let lhs = self.produce(lhs)?;
900 self.ssa
901 .push(Instruction::Join(lhs, cond, true_val, false_val));
902 Ok(())
903 }
904
905 fn produce(&mut self, slot: &Slot) -> Result<Slot> {
908 match slot {
909 Slot::Temp(idx) => {
910 if self.depth > 0 {
911 if let Some(Slot::Static(s)) = self.temps.get(idx) {
912 *self.counts.get_mut(s).unwrap() += 1;
913 return Ok(Slot::Static(*s));
914 }
915 }
916
917 let s = Slot::Static(self.count_statics);
918 self.counts.insert(self.count_statics, 0);
919 self.count_statics += 1;
920 self.temps.insert(*idx, s);
921 Ok(s)
922 }
923 Slot::Out(idx) => Ok(Slot::Out(*idx)),
924 _ => Err(anyhow!("unacceptable lhs.")),
925 }
926 }
927
928 fn consume(&mut self, slot: &Slot) -> Result<Slot> {
931 match slot {
932 Slot::Temp(idx) => {
933 if let Some(Slot::Static(s)) = self.temps.get(idx) {
934 *self.counts.get_mut(s).unwrap() += 1;
935 Ok(Slot::Static(*s))
936 } else {
937 Err(anyhow!("Not a static reg."))
938 }
939 }
940 Slot::Out(idx) => Ok(Slot::Out(*idx)),
941 Slot::Param(idx) => Ok(Slot::Param(*idx)),
942 Slot::Const(idx) => Ok(Slot::Const(*idx)),
943 Slot::Static(_) => Err(anyhow!("Undefined Static.")),
944 }
945 }
946
947 fn consume_list(&mut self, slots: &[Slot]) -> Result<Vec<Slot>> {
948 slots.iter().map(|s| self.consume(s)).collect()
949 }
950
951 pub fn translate(&mut self) -> Result<(CellModel, HashSet<Loc>)> {
953 let ssa = std::mem::take(&mut self.ssa);
954
955 for line in ssa.iter() {
956 match line {
957 Instruction::Add(lhs, args, n) => self.translate_nary("plus", lhs, args, *n)?,
958 Instruction::Mul(lhs, args, n) => self.translate_nary("times", lhs, args, *n)?,
959 Instruction::Pow(lhs, arg, p, is_real) => {
960 let p = Expr::from(*p as f64);
961 self.translate_pow(lhs, arg, &p, *is_real)?
962 }
963 Instruction::Powf(lhs, arg, p, is_real) => {
964 let p = self.expr(p, false);
965 self.translate_pow(lhs, arg, &p, *is_real)?
966 }
967 Instruction::Assign(lhs, rhs) => self.translate_assign(lhs, rhs)?,
968 Instruction::Fun(lhs, fun, arg, is_real) => {
969 self.translate_fun(lhs, fun, arg, *is_real)?
970 }
971 Instruction::Join(lhs, cond, true_val, false_val) => {
972 self.translate_join(lhs, cond, true_val, false_val)?
973 }
974 Instruction::Label(id) => self.translate_label(*id)?,
975 Instruction::IfElse(cond, id) => self.translate_ifelse(cond, *id)?,
976 Instruction::Goto(id) => self.translate_goto(*id)?,
977 Instruction::ExternalFun(lhs, op, args) => {
978 self.translate_external_fun(lhs, op, args)?
979 }
980 }
981 }
982
983 for k in 0..self.outs.len() {
985 let out = Expr::var(&format!("Out{}", k));
986
987 if let Some(eq) = self.outs.get(&k) {
988 self.eqs.push(Expr::equation(&out, eq));
989 }
990 }
991
992 let mut params: Vec<Variable> = (0..=self.count_params.max(self.num_params.max(1) - 1))
993 .map(|idx| self.expr(&Slot::Param(idx), false).to_variable().unwrap())
994 .collect();
995
996 let mut states: Vec<Variable> = Vec::new();
997
998 if !self.config.symbolica() {
999 (params, states) = (states, params)
1000 }
1001
1002 Ok((
1003 CellModel {
1004 iv: Expr::var("$_").to_variable().unwrap(),
1005 params,
1006 states,
1007 algs: Vec::new(),
1008 odes: Vec::new(),
1009 obs: self.eqs.clone(),
1010 },
1011 self.reals.clone(),
1012 ))
1013 }
1014
1015 fn expr(&mut self, slot: &Slot, is_real: bool) -> Expr {
1017 match slot {
1018 Slot::Param(idx) => {
1019 if is_real {
1020 self.reals.insert(Loc::Param(*idx as u32));
1021 }
1022 self.count_params = self.count_params.max(*idx);
1023 Expr::var(&format!("Param{}", idx))
1024 }
1025 Slot::Out(idx) => {
1026 if let Some(e) = self.outs.get(idx) {
1027 e.clone()
1028 } else {
1029 Expr::var(&format!("Out{}", idx))
1030 }
1031 }
1032 Slot::Temp(idx) => Expr::var(&format!("__Temp{}", idx)),
1033 Slot::Const(idx) => {
1034 let val = self.consts[*idx];
1035 if val.im != 0.0 {
1036 Expr::binary("complex", &Expr::from(val.re), &Expr::from(val.im))
1037 } else {
1038 Expr::from(self.consts[*idx].re)
1039 }
1040 }
1041 Slot::Static(idx) => self
1042 .cache
1043 .remove(idx)
1044 .unwrap_or(Expr::var(&format!("__Static{}", idx))),
1045 }
1046 }
1047
1048 fn assign(&mut self, lhs: &Slot, rhs: Expr) -> Result<()> {
1050 if !self.has_jump {
1051 if let Slot::Static(idx) = lhs {
1052 if self.counts.get(idx).is_some_and(|c| *c == 1) {
1056 self.cache.insert(*idx, rhs);
1057 return Ok(());
1058 }
1059 }
1060
1061 if let Slot::Out(idx) = lhs {
1062 self.outs.insert(*idx, rhs.clone());
1063 return Ok(());
1064 }
1065 }
1066
1067 let lhs = self.expr(lhs, false);
1068 self.eqs.push(Expr::equation(&lhs, &rhs));
1069 Ok(())
1070 }
1071
1072 fn translate_nary(&mut self, op: &str, lhs: &Slot, args: &[Slot], n: usize) -> Result<()> {
1073 let args: Vec<Expr> = args
1074 .iter()
1075 .enumerate()
1076 .map(|(i, x)| self.expr(x, i < n))
1077 .collect();
1078 let p: Vec<&Expr> = args.iter().collect();
1079
1080 if n == 0 || n >= p.len() {
1081 self.assign(lhs, Expr::nary(op, &p))
1082 } else {
1083 let l = Expr::nary(op, &p[..n]);
1084 let r = Expr::nary(op, &p[n..]);
1085 self.assign(lhs, Expr::nary(op, &[&l, &r]))
1086 }
1087 }
1088
1089 fn translate_pow(&mut self, lhs: &Slot, arg: &Slot, power: &Expr, is_real: bool) -> Result<()> {
1090 let arg = self.expr(arg, is_real);
1091 self.assign(lhs, Expr::binary("power", &arg, power))
1092 }
1093
1094 fn translate_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
1095 let rhs = self.expr(rhs, false);
1096 self.assign(lhs, rhs)
1097 }
1098
1099 fn translate_fun(
1100 &mut self,
1101 lhs: &Slot,
1102 fun: &BuiltinSymbol,
1103 arg: &Slot,
1104 is_real: bool,
1105 ) -> Result<()> {
1106 let arg = self.expr(arg, is_real);
1107
1108 let op = match fun.0 {
1109 2 => "exp",
1110 3 => "ln",
1111 4 => "sin",
1112 5 => "cos",
1113 6 => {
1114 if is_real {
1115 "real_root"
1116 } else {
1117 "root"
1118 }
1119 }
1120 7 => "conjugate",
1121 _ => return Err(anyhow!("function is not defined.")),
1122 };
1123
1124 self.assign(lhs, Expr::unary(op, &arg))
1125 }
1126
1127 fn translate_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
1128 match args.len() {
1129 1 => {
1130 let arg = self.expr(&args[0], false);
1131 self.assign(lhs, Expr::unary(op, &arg))?;
1132 }
1133 2 => {
1134 let l = self.expr(&args[0], false);
1135 let r = self.expr(&args[1], false);
1136 self.assign(lhs, Expr::binary(op, &l, &r))?;
1137 }
1138 _ => {
1139 return Err(anyhow!(
1140 "only unary and binary external functions are supported"
1141 ))
1142 }
1143 }
1144
1145 Ok(())
1146 }
1147
1148 fn translate_label(&mut self, id: usize) -> Result<()> {
1149 self.eqs.push(Expr::special(&Expr::Label { id }));
1150 Ok(())
1151 }
1152
1153 fn translate_ifelse(&mut self, cond: &Slot, id: usize) -> Result<()> {
1154 self.conds.push(*cond);
1161 let if_clause = Expr::binary("eq", &self.expr(cond, false), &Expr::from(0.0));
1162 self.eqs.push(Expr::special(&Expr::BranchIf {
1163 cond: Box::new(if_clause),
1164 id,
1165 is_else: false,
1166 }));
1167 Ok(())
1168 }
1169
1170 fn translate_goto(&mut self, id: usize) -> Result<()> {
1171 if self.config.simd_branch() {
1172 } else {
1187 self.eqs.push(Expr::special(&Expr::Branch { id }));
1188 }
1189
1190 Ok(())
1191 }
1192
1193 fn translate_join(
1194 &mut self,
1195 lhs: &Slot,
1196 _cond: &Slot,
1197 true_val: &Slot,
1198 false_val: &Slot,
1199 ) -> Result<()> {
1200 let t = self.expr(true_val, false);
1202 let f = self.expr(false_val, false);
1203 let cond = self.conds.pop().unwrap();
1204 let mask = Expr::binary("eq", &self.expr(&cond, false), &Expr::from(0.0));
1205 self.assign(lhs, mask.ifelse(&f, &t))?;
1206 Ok(())
1207 }
1208
1209 pub fn set_num_params(&mut self, num_params: usize) {
1210 self.num_params = num_params
1211 }
1212
1213 pub fn compile(&mut self) -> Result<Application> {
1214 let (ml, reals) = self.translate()?;
1215 let prog = Program::new(&ml, self.config)?;
1216 let mut app = Application::new(prog, reals, &self.df)?;
1217 app.prepare_simd();
1218 Ok(app)
1219 }
1220}
1221
1222impl Compiler {
1223 pub fn translate(
1240 &mut self,
1241 json: String,
1242 df: &Defuns,
1243 num_params: usize,
1244 ) -> Result<Application> {
1245 let mut translator = Translator::new(self.config, df);
1246
1247 let model: SymbolicaModel = if json.starts_with("[[{") {
1248 serde_json::from_str(json.as_str())?
1249 } else {
1250 Parser::new(json).parse()?
1251 };
1252
1253 translator.parse_model(&model)?;
1254 translator.set_num_params(num_params);
1255 let (ml, reals) = translator.translate()?;
1256
1257 let prog = Program::new(&ml, self.config)?;
1258 let df = Defuns::new();
1259 let mut app = Application::new(prog, reals, &df)?;
1260
1261 app.prepare_simd();
1262
1263 Ok(app)
1271 }
1272}