1use std::collections::{HashMap, HashSet};
2
3use anyhow::{anyhow, Result};
4use num_complex::Complex;
5use rayon::prelude::*;
6
7use crate::code::VirtualTable;
8use crate::config::{Config, SLICE_CAP};
9use crate::defuns::Defuns;
10use crate::expr::Expr;
11pub use crate::instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
12use crate::model::{CellModel, Equation, Program, Variable};
13use crate::parser::Parser;
14use crate::symbol::Loc;
15use crate::types::{ElemType, Element};
16use crate::utils::{Compiled, CompiledFunc};
17use crate::Application;
18
19pub struct Compiler {
21 config: Config,
22 df: Defuns,
23}
24
// Fallback definition of `__m256d` for targets other than x86_64, where the
// `core::arch::x86_64::__m256d` intrinsic type is unavailable: four `f64`s as a plain array.
// NOTE(review): no use of this alias is visible in this file; confirm it is still needed.
#[cfg(not(target_arch = "x86_64"))]
// The lowercase, underscore-prefixed name mirrors the x86_64 intrinsic type on purpose.
#[allow(non_camel_case_types)]
type __m256d = [f64; 4];
28
29impl Compiler {
83 pub fn new() -> Compiler {
85 Compiler {
86 config: Config::default(),
87 df: Defuns::new(),
88 }
89 }
90
91 pub fn with_config(config: Config) -> Compiler {
92 Compiler {
93 config,
94 df: Defuns::new(),
95 }
96 }
97
98 pub fn compile(&mut self, states: &[Expr], obs: &[Expr]) -> Result<Application> {
103 self.compile_params(states, obs, &[])
104 }
105
106 pub fn compile_params(
116 &mut self,
117 states: &[Expr],
118 obs: &[Expr],
119 params: &[Expr],
120 ) -> Result<Application> {
121 let mut vars: Vec<Variable> = Vec::new();
122
123 for state in states.iter() {
124 let v = state.to_variable()?;
125 vars.push(v);
126 }
127
128 let mut ps: Vec<Variable> = Vec::new();
129
130 for p in params.iter() {
131 let v = p.to_variable()?;
132 ps.push(v);
133 }
134
135 let mut eqs: Vec<Equation> = Vec::new();
136
137 for (i, expr) in obs.iter().enumerate() {
138 let name = format!("${}", i);
139 let lhs = Expr::var(&name);
140 eqs.push(Expr::equation(&lhs, expr));
141 }
142
143 let ml = CellModel {
144 iv: Expr::var("$_").to_variable()?,
145 params: ps,
146 states: vars,
147 algs: Vec::new(),
148 odes: Vec::new(),
149 obs: eqs,
150 };
151
152 let prog = Program::new(&ml, self.config)?;
153 let mut app = Application::new(prog, HashSet::new(), std::mem::take(&mut self.df))?;
155 Ok(app)
165 }
166
167 pub fn def_unary(&mut self, op: &str, f: extern "C" fn(f64) -> f64) {
169 self.df.add_unary(op, f)
170 }
171
172 pub fn def_binary(&mut self, op: &str, f: extern "C" fn(f64, f64) -> f64) {
174 self.df.add_binary(op, f)
175 }
176}
177
178pub enum FastFunc<'a> {
179 F1(extern "C" fn(f64) -> f64, &'a Application),
180 F2(extern "C" fn(f64, f64) -> f64, &'a Application),
181 F3(extern "C" fn(f64, f64, f64) -> f64, &'a Application),
182 F4(extern "C" fn(f64, f64, f64, f64) -> f64, &'a Application),
183 F5(
184 extern "C" fn(f64, f64, f64, f64, f64) -> f64,
185 &'a Application,
186 ),
187 F6(
188 extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64,
189 &'a Application,
190 ),
191 F7(
192 extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64,
193 &'a Application,
194 ),
195 F8(
196 extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64,
197 &'a Application,
198 ),
199}
200
201impl Application {
202 pub fn call(&mut self, args: &[f64]) -> Vec<f64> {
209 if let Some(f) = &mut self.compiled {
210 {
211 let mem = f.mem_mut();
212 let states = &mut mem[self.first_state..self.first_state + self.count_states];
213 states.copy_from_slice(args);
214 }
215
216 f.exec(&self.params[..]);
217
218 let obs = {
219 let mem = f.mem();
220 &mem[self.first_obs..self.first_obs + self.count_obs]
221 };
222
223 obs.to_vec()
224 } else {
225 Vec::new()
226 }
227 }
228
229 pub fn call_params(&mut self, args: &[f64], params: &[f64]) -> Vec<f64> {
237 if let Some(f) = &mut self.compiled {
238 {
239 let mem = f.mem_mut();
240 let states = &mut mem[self.first_state..self.first_state + self.count_states];
241 states.copy_from_slice(args);
242 }
243
244 f.exec(params);
245
246 let obs = {
247 let mem = f.mem();
248 &mem[self.first_obs..self.first_obs + self.count_obs]
249 };
250
251 obs.to_vec()
252 } else {
253 Vec::new()
254 }
255 }
256
257 pub fn interpret<T>(&mut self, args: &[T], outs: &mut [T])
258 where
259 T: Element,
260 {
261 let args = recast_as_f64(args);
262 let outs = recast_as_f64_mut(outs);
263
264 let mut regs = [0.0; 32];
265 self.bytecode
266 .mir
267 .exec_instruction(outs, &mut self.bytecode.stack, &mut regs, args);
268 }
269
270 pub fn interpret_matrix(&mut self, args: &[f64], outs: &mut [f64], n: usize) {
271 let count_params = self.count_params;
272 let count_obs = self.count_obs;
273
274 for i in 0..n {
275 self.interpret(
276 &args[i * count_params..(i + 1) * count_params],
277 &mut outs[i * count_obs..(i + 1) * count_obs],
278 );
279 }
280 }
281
282 pub fn evaluate<T>(&self, args: &[T], outs: &mut [T])
284 where
285 T: Element,
286 {
287 let args = recast_as_f64(args);
288 let outs = recast_as_f64_mut(outs);
289
290 let simd = matches!(
291 T::get_type(T::default()),
292 ElemType::RealF64x2(_)
293 | ElemType::RealF64x4(_)
294 | ElemType::ComplexF64x2(_)
295 | ElemType::ComplexF64x4(_)
296 );
297
298 if let Some(f) = &self.compiled {
299 if !simd {
300 f.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
301 } else if let Some(g) = &self.compiled_simd {
302 g.func()(outs.as_mut_ptr(), std::ptr::null(), 0, args.as_ptr());
303 }
304 }
305 }
306
307 #[inline(always)]
309 pub fn evaluate_single<T>(&self, args: &[T]) -> T
310 where
311 T: Element + Copy,
312 {
313 let mut outs = [T::default(); 1];
314 self.evaluate(args, &mut outs);
315 outs[0]
316 }
317
318 fn evaluate_row(
321 args: &[f64],
322 args_idx: usize,
323 outs: &[f64],
324 outs_idx: usize,
325 f: CompiledFunc<f64>,
326 transpose: bool,
327 ) -> i32 {
328 unsafe {
329 f(
330 outs.as_ptr().add(outs_idx),
331 std::ptr::null(),
332 if transpose { 1 } else { 0 },
333 args.as_ptr().add(args_idx),
334 )
335 }
336 }
337
338 fn evaluate_matrix_with_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
339 if let Some(f) = &self.compiled {
340 let count_params = self.count_params;
341 let count_obs = self.count_obs;
342 let f_scalar = f.func();
343
344 (0..n).into_par_iter().for_each(|t| {
345 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
346 });
347 }
348 }
349
350 fn evaluate_matrix_without_threads(&self, args: &[f64], outs: &mut [f64], n: usize) {
351 if let Some(f) = &self.compiled {
352 let count_params = self.count_params;
353 let count_obs = self.count_obs;
354 let f_scalar = f.func();
355
356 for t in 0..n {
357 Self::evaluate_row(args, t * count_params, outs, t * count_obs, f_scalar, false);
358 }
359 }
360 }
361
362 fn evaluate_matrix_with_threads_simd(
363 &self,
364 args: &[f64],
365 outs: &mut [f64],
366 n: usize,
367 transpose: bool,
368 ) {
369 if let Some(f) = &self.compiled {
370 let count_params = self.count_params;
371 let count_obs = self.count_obs;
372
373 if let Some(compiled) = &self.compiled_simd {
374 let f_simd = compiled.func();
375 let f_scalar = f.func();
376 let lanes = compiled.count_lanes();
377 let step = if transpose { lanes } else { 1 };
378
379 (0..n / step).into_par_iter().for_each(|k| {
380 let top = k * lanes;
381 if Self::evaluate_row(
382 args,
383 top * count_params,
384 outs,
385 top * count_obs,
386 f_simd,
387 transpose,
388 ) != 0
389 {
390 for i in 0..lanes {
391 Self::evaluate_row(
392 args,
393 (top + i) * count_params,
394 outs,
395 (top + i) * count_obs,
396 f_scalar,
397 false,
398 );
399 }
400 }
401 });
402
403 for t in step * (n / step)..n {
404 Self::evaluate_row(
405 args,
406 t * count_params,
407 outs,
408 t * count_obs,
409 f_scalar,
410 false,
411 );
412 }
413 }
414 }
415 }
416
417 fn evaluate_matrix_without_threads_simd(
418 &self,
419 args: &[f64],
420 outs: &mut [f64],
421 n: usize,
422 transpose: bool,
423 ) {
424 if let Some(f) = &self.compiled {
425 let count_params = self.count_params;
426 let count_obs = self.count_obs;
427
428 if let Some(compiled) = &self.compiled_simd {
429 let f_simd = compiled.func();
430 let f_scalar = f.func();
431 let lanes = compiled.count_lanes();
432 let step = if transpose { lanes } else { 1 };
433
434 for k in 0..n / step {
435 let top = k * lanes;
436 if Self::evaluate_row(
437 args,
438 top * count_params,
439 outs,
440 top * count_obs,
441 f_simd,
442 transpose,
443 ) != 0
444 {
445 for i in 0..lanes {
446 Self::evaluate_row(
447 args,
448 (top + i) * count_params,
449 outs,
450 (top + i) * count_obs,
451 f_scalar,
452 false,
453 );
454 }
455 }
456 }
457
458 for t in step * (n / step)..n {
459 Self::evaluate_row(
460 args,
461 t * count_params,
462 outs,
463 t * count_obs,
464 f_scalar,
465 false,
466 );
467 }
468 }
469 }
470 }
471
472 pub fn evaluate_matrix<T>(&self, args: &[T], outs: &mut [T], n: usize)
477 where
478 T: Element,
479 {
480 let args = recast_as_f64(args);
481 let outs = recast_as_f64_mut(outs);
482
483 let transpose = !matches!(
484 T::get_type(T::default()),
485 ElemType::RealF64x2(_)
486 | ElemType::RealF64x4(_)
487 | ElemType::ComplexF64x2(_)
488 | ElemType::ComplexF64x4(_)
489 );
490
491 if self.use_threads && n > 1 {
492 if self.compiled_simd.is_some() {
493 self.evaluate_matrix_with_threads_simd(args, outs, n, transpose);
494 } else {
495 self.evaluate_matrix_with_threads(args, outs, n);
496 }
497 } else {
498 if self.compiled_simd.is_some() {
499 self.evaluate_matrix_without_threads_simd(args, outs, n, transpose);
500 } else {
501 self.evaluate_matrix_without_threads(args, outs, n);
502 }
503 }
504 }
505
506 pub fn fast_func(&mut self) -> Result<FastFunc<'_>> {
546 let f = self.get_fast();
547
548 if let Some(f) = f {
549 match self.count_states {
550 1 => {
551 let g: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
552 Ok(FastFunc::F1(g, self))
553 }
554 2 => {
555 let g: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
556 Ok(FastFunc::F2(g, self))
557 }
558 3 => {
559 let g: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
560 Ok(FastFunc::F3(g, self))
561 }
562 4 => {
563 let g: extern "C" fn(f64, f64, f64, f64) -> f64 =
564 unsafe { std::mem::transmute(f) };
565 Ok(FastFunc::F4(g, self))
566 }
567 5 => {
568 let g: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
569 unsafe { std::mem::transmute(f) };
570 Ok(FastFunc::F5(g, self))
571 }
572 6 => {
573 let g: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
574 unsafe { std::mem::transmute(f) };
575 Ok(FastFunc::F6(g, self))
576 }
577 7 => {
578 let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
579 unsafe { std::mem::transmute(f) };
580 Ok(FastFunc::F7(g, self))
581 }
582 8 => {
583 let g: extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64 =
584 unsafe { std::mem::transmute(f) };
585 Ok(FastFunc::F8(g, self))
586 }
587 _ => Err(anyhow!("not a fast function")),
588 }
589 } else {
590 Err(anyhow!("not a fast function"))
591 }
592 }
593}
594
595pub fn recast_complex_vec(v: &[Complex<f64>]) -> &[f64] {
596 let n = v.len();
597 let p: *const f64 = unsafe { std::mem::transmute(v.as_ptr()) };
598 let q: &[f64] = unsafe { std::slice::from_raw_parts(p, 2 * n) };
599 q
600}
601
602pub fn recast_complex_vec_mut(v: &mut [Complex<f64>]) -> &mut [f64] {
603 let n = v.len();
604 let p: *mut f64 = unsafe { std::mem::transmute(v.as_mut_ptr()) };
605 let q: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(p, 2 * n) };
606 q
607}
608
/// Reinterprets a slice of `T` as a flat slice of `f64`, without copying.
///
/// Each element is viewed as `size_of::<T>() / size_of::<f64>()` consecutive `f64`s. This is
/// only meaningful for element types made purely of `f64`s (e.g. `f64`, `[f64; N]`, complex or
/// multi-lane wrappers around `f64`). NOTE(review): confirm no other `T` is passed.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()`, or if the data is not
/// aligned for `f64`. Either condition would make the reinterpretation unsound or silently
/// drop data.
pub fn recast_as_f64<T>(v: &[T]) -> &[f64]
where
    T: Sized,
{
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0,
        "recast_as_f64: size of T must be a multiple of the size of f64"
    );

    let len = (elem_size / f64_size) * v.len();
    if len == 0 {
        // Avoid building a slice from a dangling pointer that may not be f64-aligned.
        return &[];
    }

    let p: *const f64 = v.as_ptr().cast();
    assert!(
        p as usize % std::mem::align_of::<f64>() == 0,
        "recast_as_f64: data is not aligned for f64"
    );
    // SAFETY: `p` is non-null, f64-aligned (checked above), and covers exactly
    // `len * size_of::<f64>() == v.len() * size_of::<T>()` initialized bytes borrowed from `v`.
    unsafe { std::slice::from_raw_parts(p, len) }
}
618
/// Mutable counterpart of `recast_as_f64`: reinterprets a slice of `T` as a flat, writable
/// slice of `f64`, without copying.
///
/// Only meaningful for element types made purely of `f64`s. Writing arbitrary `f64`s into
/// other types could create invalid values. NOTE(review): confirm no other `T` is passed.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()`, or if the data is not
/// aligned for `f64`.
pub fn recast_as_f64_mut<T>(v: &mut [T]) -> &mut [f64]
where
    T: Sized,
{
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0,
        "recast_as_f64_mut: size of T must be a multiple of the size of f64"
    );

    let len = (elem_size / f64_size) * v.len();
    if len == 0 {
        return &mut [];
    }

    // Derive the pointer from the *mutable* borrow: the previous `v.as_ptr()` came from a
    // shared reborrow, and writing through such a pointer is undefined behavior.
    let p: *mut f64 = v.as_mut_ptr().cast();
    assert!(
        p as usize % std::mem::align_of::<f64>() == 0,
        "recast_as_f64_mut: data is not aligned for f64"
    );
    // SAFETY: `p` is non-null, f64-aligned (checked above), derived from the exclusive borrow
    // `v`, and covers exactly `v.len() * size_of::<T>()` bytes.
    unsafe { std::slice::from_raw_parts_mut(p, len) }
}
628
629#[derive(Debug, Clone)]
633pub struct Translator {
634 config: Config,
635 df: Defuns,
636 ssa: Vec<Instruction>,
637 consts: Vec<Complex<f64>>, count_params: usize,
639 count_statics: usize,
640 eqs: Vec<Equation>, temps: HashMap<usize, Slot>, counts: HashMap<usize, usize>, cache: HashMap<usize, Expr>, outs: HashMap<usize, Expr>, reals: HashSet<Loc>,
646 num_params: usize,
647 has_jump: bool,
648 last_label: usize,
649 depth: usize,
650 conds: Vec<Slot>,
651}
652
653impl Translator {
654 pub fn new(config: Config, df: Defuns) -> Translator {
655 Translator {
656 config,
657 df,
658 ssa: Vec::new(),
659 consts: Vec::new(),
660 count_params: 0,
661 count_statics: 0,
662 eqs: Vec::new(),
663 temps: HashMap::new(),
664 counts: HashMap::new(),
665 cache: HashMap::new(),
666 outs: HashMap::new(),
667 reals: HashSet::new(),
668 num_params: 0,
669 has_jump: false,
670 last_label: 0,
671 depth: 0,
672 conds: Vec::new(),
673 }
674 }
675
676 pub fn parse_model(&mut self, model: &SymbolicaModel) -> Result<()> {
677 for c in model.2.iter() {
678 let val = Complex::new(c.value().re, c.value().im);
679 self.consts.push(val);
680 }
681
682 self.convert(model)?;
683 Ok(())
684 }
685
686 fn convert(&mut self, model: &SymbolicaModel) -> Result<()> {
689 for line in model.0.iter() {
690 match line {
691 Instruction::Add(lhs, args, num_reals) => self.append_add(lhs, args, *num_reals)?,
692 Instruction::Mul(lhs, args, num_reals) => self.append_mul(lhs, args, *num_reals)?,
693 Instruction::Pow(lhs, arg, p, is_real) => {
694 self.append_pow(lhs, arg, *p, *is_real)?
695 }
696 Instruction::Powf(lhs, arg, p, is_real) => {
697 self.append_powf(lhs, arg, p, *is_real)?
698 }
699 Instruction::Assign(lhs, rhs) => self.append_assign(lhs, rhs)?,
700 Instruction::Fun(lhs, fun, arg, is_real) => {
701 self.append_fun(lhs, fun, arg, *is_real)?
702 }
703 Instruction::Join(lhs, cond, true_val, false_val) => {
704 self.depth -= 1;
705 self.append_join(lhs, cond, true_val, false_val)?
706 }
707 Instruction::Label(id) => self.append_label(*id)?,
708 Instruction::IfElse(cond, id) => {
709 self.append_if_else(cond, *id)?;
710 self.depth += 1;
711 }
712 Instruction::Goto(id) => self.append_goto(*id)?,
713 Instruction::ExternalFun(lhs, op, args) => {
714 self.append_external_fun(lhs, op, args)?
715 }
716 }
717 }
718
719 Ok(())
720 }
721
722 pub fn append_constant(&mut self, z: Complex<f64>) -> Result<usize> {
723 self.consts.push(z);
724 Ok(self.consts.len() - 1)
725 }
726
727 pub fn append_add(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
728 let args = self.consume_list(args)?;
729 let lhs = self.produce(lhs)?;
730 self.ssa.push(Instruction::Add(lhs, args, num_reals));
731 Ok(())
732 }
733
734 pub fn append_mul(&mut self, lhs: &Slot, args: &[Slot], num_reals: usize) -> Result<()> {
735 let args = self.consume_list(args)?;
736 let lhs = self.produce(lhs)?;
737 self.ssa.push(Instruction::Mul(lhs, args, num_reals));
738 Ok(())
739 }
740
741 pub fn append_pow(&mut self, lhs: &Slot, arg: &Slot, p: i64, is_real: bool) -> Result<()> {
742 let arg = self.consume(arg)?;
743 let lhs = self.produce(lhs)?;
744 self.ssa.push(Instruction::Pow(lhs, arg, p, is_real));
745 Ok(())
746 }
747
748 pub fn append_powf(&mut self, lhs: &Slot, arg: &Slot, p: &Slot, is_real: bool) -> Result<()> {
749 let arg = self.consume(arg)?;
750 let p = self.consume(p)?;
751 let lhs = self.produce(lhs)?;
752 self.ssa.push(Instruction::Powf(lhs, arg, p, is_real));
753 Ok(())
754 }
755
756 pub fn append_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
757 let rhs = self.consume(rhs)?;
758 let lhs = self.produce(lhs)?;
759 self.ssa.push(Instruction::Assign(lhs, rhs));
760 Ok(())
761 }
762
763 pub fn append_label(&mut self, id: usize) -> Result<()> {
764 self.ssa.push(Instruction::Label(id));
765 Ok(())
766 }
767
768 pub fn append_if_else(&mut self, cond: &Slot, id: usize) -> Result<()> {
769 self.has_jump = true;
770 let cond = self.consume(cond)?;
771 self.ssa.push(Instruction::IfElse(cond, id));
772 Ok(())
773 }
774
775 pub fn append_goto(&mut self, id: usize) -> Result<()> {
776 self.last_label = self.last_label.max(id);
777 self.ssa.push(Instruction::Goto(id));
778 Ok(())
779 }
780
781 pub fn append_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
782 let args = self.consume_list(args)?;
783 let lhs = self.produce(lhs)?;
784 self.ssa
785 .push(Instruction::ExternalFun(lhs, op.to_string(), args));
786 Ok(())
787 }
788
789 pub fn append_fun(
790 &mut self,
791 lhs: &Slot,
792 fun: &BuiltinSymbol,
793 arg: &Slot,
794 is_real: bool,
795 ) -> Result<()> {
796 let arg = self.consume(arg)?;
797 let lhs = self.produce(lhs)?;
798 self.ssa.push(Instruction::Fun(lhs, *fun, arg, is_real));
799 Ok(())
800 }
801
802 pub fn append_join(
803 &mut self,
804 lhs: &Slot,
805 cond: &Slot,
806 true_val: &Slot,
807 false_val: &Slot,
808 ) -> Result<()> {
809 let cond = self.consume(cond)?;
810 let true_val = self.consume(true_val)?;
811 let false_val = self.consume(false_val)?;
812 let lhs = self.produce(lhs)?;
813 self.ssa
814 .push(Instruction::Join(lhs, cond, true_val, false_val));
815 Ok(())
816 }
817
818 fn create_static(&mut self) -> Result<Slot> {
819 let s = Slot::Static(self.count_statics);
820 self.counts.insert(self.count_statics, 0);
821 self.count_statics += 1;
822 Ok(s)
823 }
824
825 fn produce(&mut self, slot: &Slot) -> Result<Slot> {
828 match slot {
829 Slot::Temp(idx) => {
830 if self.depth > 0 {
831 if let Some(Slot::Static(s)) = self.temps.get(idx) {
832 *self.counts.get_mut(s).unwrap() += 1;
833 return Ok(Slot::Static(*s));
834 }
835 }
836
837 let s = self.create_static()?;
838 self.temps.insert(*idx, s);
839 Ok(s)
840 }
841 Slot::Out(idx) => Ok(Slot::Out(*idx)),
842 _ => Err(anyhow!("unacceptable lhs.")),
843 }
844 }
845
846 fn consume(&mut self, slot: &Slot) -> Result<Slot> {
849 match slot {
850 Slot::Temp(idx) => {
851 if let Some(Slot::Static(s)) = self.temps.get(idx) {
852 *self.counts.get_mut(s).unwrap() += 1;
853 Ok(Slot::Static(*s))
854 } else {
855 Err(anyhow!("Not a static reg."))
856 }
857 }
858 Slot::Out(idx) => Ok(Slot::Out(*idx)),
859 Slot::Param(idx) => Ok(Slot::Param(*idx)),
860 Slot::Const(idx) => Ok(Slot::Const(*idx)),
861 Slot::Static(_) | Slot::Arg(_) => Err(anyhow!("Undefined Static/Arg.")),
862 }
863 }
864
865 fn consume_list(&mut self, slots: &[Slot]) -> Result<Vec<Slot>> {
866 slots.iter().map(|s| self.consume(s)).collect()
867 }
868
869 pub fn translate(&mut self) -> Result<(CellModel, HashSet<Loc>)> {
871 let ssa = std::mem::take(&mut self.ssa);
872
873 for line in ssa.iter() {
874 match line {
875 Instruction::Add(lhs, args, n) => self.translate_nary("plus", lhs, args, *n)?,
876 Instruction::Mul(lhs, args, n) => self.translate_nary("times", lhs, args, *n)?,
877 Instruction::Pow(lhs, arg, p, is_real) => {
878 let p = Expr::from(*p as f64);
879 self.translate_pow(lhs, arg, &p, *is_real)?
880 }
881 Instruction::Powf(lhs, arg, p, is_real) => {
882 let p = self.expr(p, false);
883 self.translate_pow(lhs, arg, &p, *is_real)?
884 }
885 Instruction::Assign(lhs, rhs) => self.translate_assign(lhs, rhs)?,
886 Instruction::Fun(lhs, fun, arg, is_real) => {
887 self.translate_fun(lhs, fun, arg, *is_real)?
888 }
889 Instruction::Join(lhs, cond, true_val, false_val) => {
890 self.translate_join(lhs, cond, true_val, false_val)?
891 }
892 Instruction::Label(id) => self.translate_label(*id)?,
893 Instruction::IfElse(cond, id) => self.translate_ifelse(cond, *id)?,
894 Instruction::Goto(id) => self.translate_goto(*id)?,
895 Instruction::ExternalFun(lhs, op, args) => {
896 self.translate_external_fun(lhs, op, args)?
897 }
898 }
899 }
900
901 for k in 0..self.outs.len() {
903 let out = Expr::var(&format!("Out{}", k));
904
905 if let Some(eq) = self.outs.get(&k) {
906 self.eqs.push(Expr::equation(&out, eq));
907 }
908 }
909
910 let mut params: Vec<Variable> = (0..=self.count_params.max(self.num_params.max(1) - 1))
911 .map(|idx| self.expr(&Slot::Param(idx), false).to_variable().unwrap())
912 .collect();
913
914 let mut states: Vec<Variable> = Vec::new();
915
916 if !self.config.symbolica() {
917 (params, states) = (states, params)
918 }
919
920 Ok((
921 CellModel {
922 iv: Expr::var("$_").to_variable().unwrap(),
923 params,
924 states,
925 algs: Vec::new(),
926 odes: Vec::new(),
927 obs: self.eqs.clone(),
928 },
929 self.reals.clone(),
930 ))
931 }
932
933 fn expr(&mut self, slot: &Slot, is_real: bool) -> Expr {
935 match slot {
936 Slot::Param(idx) => {
937 if is_real {
938 self.reals.insert(Loc::Param(*idx as u32));
939 }
940 self.count_params = self.count_params.max(*idx);
941 Expr::var(&format!("Param{}", idx))
942 }
943 Slot::Out(idx) => {
944 if let Some(e) = self.outs.get(idx) {
945 e.clone()
946 } else {
947 Expr::var(&format!("Out{}", idx))
948 }
949 }
950 Slot::Temp(idx) => Expr::var(&format!("__Temp{}", idx)),
951 Slot::Const(idx) => {
952 let val = self.consts[*idx];
953 if val.im != 0.0 {
954 Expr::binary("complex", &Expr::from(val.re), &Expr::from(val.im))
955 } else {
956 Expr::from(self.consts[*idx].re)
957 }
958 }
959 Slot::Static(idx) => self
960 .cache
961 .remove(idx)
962 .unwrap_or(Expr::var(&format!("__Static{}", idx))),
963 Slot::Arg(idx) => Expr::var(&format!("__Arg{}", idx)),
964 }
965 }
966
967 fn assign(&mut self, lhs: &Slot, rhs: Expr) -> Result<()> {
969 if !self.has_jump {
970 if let Slot::Static(idx) = lhs {
971 if self.counts.get(idx).is_some_and(|c| *c == 1) {
975 self.cache.insert(*idx, rhs);
976 return Ok(());
977 }
978 }
979
980 if let Slot::Out(idx) = lhs {
981 self.outs.insert(*idx, rhs.clone());
982 return Ok(());
983 }
984 }
985
986 let lhs = self.expr(lhs, false);
987 self.eqs.push(Expr::equation(&lhs, &rhs));
988 Ok(())
989 }
990
991 fn translate_nary(&mut self, op: &str, lhs: &Slot, args: &[Slot], n: usize) -> Result<()> {
992 let args: Vec<Expr> = args
993 .iter()
994 .enumerate()
995 .map(|(i, x)| self.expr(x, i < n))
996 .collect();
997 let p: Vec<&Expr> = args.iter().collect();
998
999 if n == 0 || n >= p.len() {
1000 self.assign(lhs, Expr::nary(op, &p))
1001 } else {
1002 let l = Expr::nary(op, &p[..n]);
1003 let r = Expr::nary(op, &p[n..]);
1004 self.assign(lhs, Expr::nary(op, &[&l, &r]))
1005 }
1006 }
1007
1008 fn translate_pow(&mut self, lhs: &Slot, arg: &Slot, power: &Expr, is_real: bool) -> Result<()> {
1009 let arg = self.expr(arg, is_real);
1010 self.assign(lhs, Expr::binary("power", &arg, power))
1011 }
1012
1013 fn translate_assign(&mut self, lhs: &Slot, rhs: &Slot) -> Result<()> {
1014 let rhs = self.expr(rhs, false);
1015 self.assign(lhs, rhs)
1016 }
1017
1018 fn translate_fun(
1019 &mut self,
1020 lhs: &Slot,
1021 fun: &BuiltinSymbol,
1022 arg: &Slot,
1023 is_real: bool,
1024 ) -> Result<()> {
1025 let arg = self.expr(arg, is_real);
1026
1027 let op = match fun.0 {
1028 2 => "exp",
1029 3 => "ln",
1030 4 => "sin",
1031 5 => "cos",
1032 6 => {
1033 if is_real {
1034 "real_root"
1035 } else {
1036 "root"
1037 }
1038 }
1039 7 => "conjugate",
1040 _ => return Err(anyhow!("function is not defined.")),
1041 };
1042
1043 self.assign(lhs, Expr::unary(op, &arg))
1044 }
1045
1046 fn translate_external_fun(&mut self, lhs: &Slot, op: &str, args: &[Slot]) -> Result<()> {
1047 let n = args.len();
1048 assert!(n <= SLICE_CAP);
1049 let args: Vec<Expr> = args.iter().map(|a| self.expr(a, false)).collect();
1050
1051 if VirtualTable::from_str(op).is_ok() {
1052 if n == 1 {
1053 self.assign(lhs, Expr::unary(op, &args[0]))?;
1054 } else if n == 2 {
1055 self.assign(lhs, Expr::binary(op, &args[0], &args[1]))?;
1056 } else {
1057 return Err(anyhow!("wrong number of arguments to {:?}", op));
1058 }
1059 } else if self.config.is_intrinsic_unary(op) && n == 1 {
1060 self.assign(lhs, Expr::unary(op, &args[0]))?;
1061 } else if self.config.is_intrinsic_binary(op) && n == 2 {
1062 self.assign(lhs, Expr::binary(op, &args[0], &args[1]))?;
1063 } else {
1064 let temps: Vec<Slot> = (0..n).map(|_| self.create_static().unwrap()).collect();
1065 let slice: Vec<Slot> = (0..n).map(Slot::Arg).collect();
1066
1067 for i in 0..n {
1068 self.assign(&temps[i], args[i].clone())?;
1069 }
1070
1071 for i in 0..n {
1072 if let Slot::Static(idx) = temps[i] {
1073 self.assign(&slice[i], Expr::var(&format!("__Static{}", idx)))?;
1074 }
1075 }
1076
1077 let op = format!("${}", op);
1078 self.assign(
1079 lhs,
1080 Expr::binary(&op, &Expr::from(0), &Expr::from(n as i32)),
1081 )?;
1082 }
1083
1084 Ok(())
1085 }
1086
1087 fn translate_label(&mut self, id: usize) -> Result<()> {
1088 self.eqs.push(Expr::special(&Expr::Label { id }));
1089 Ok(())
1090 }
1091
1092 fn translate_ifelse(&mut self, cond: &Slot, id: usize) -> Result<()> {
1093 self.conds.push(*cond);
1100 let if_clause = Expr::binary("eq", &self.expr(cond, false), &Expr::from(0.0));
1101 self.eqs.push(Expr::special(&Expr::BranchIf {
1102 cond: Box::new(if_clause),
1103 id,
1104 is_else: false,
1105 }));
1106 Ok(())
1107 }
1108
1109 fn translate_goto(&mut self, id: usize) -> Result<()> {
1110 if self.config.simd_branch() {
1111 } else {
1126 self.eqs.push(Expr::special(&Expr::Branch { id }));
1127 }
1128
1129 Ok(())
1130 }
1131
1132 fn translate_join(
1133 &mut self,
1134 lhs: &Slot,
1135 _cond: &Slot,
1136 true_val: &Slot,
1137 false_val: &Slot,
1138 ) -> Result<()> {
1139 let t = self.expr(true_val, false);
1141 let f = self.expr(false_val, false);
1142 let cond = self.conds.pop().unwrap();
1143 let mask = Expr::binary("eq", &self.expr(&cond, false), &Expr::from(0.0));
1144 self.assign(lhs, mask.ifelse(&f, &t))?;
1145 Ok(())
1146 }
1147
1148 pub fn set_num_params(&mut self, num_params: usize) {
1149 self.num_params = num_params
1150 }
1151
1152 pub fn compile(&mut self) -> Result<Application> {
1153 let (ml, reals) = self.translate()?;
1154 let prog = Program::new(&ml, self.config)?;
1155 let mut app = Application::new(prog, reals, std::mem::take(&mut self.df))?;
1156 app.prepare_simd();
1157 Ok(app)
1158 }
1159}
1160
1161impl Compiler {
1162 pub fn translate(
1179 &mut self,
1180 json: String,
1181 df: Defuns,
1182 num_params: usize,
1183 ) -> Result<Application> {
1184 let mut translator = Translator::new(self.config, df);
1185
1186 let model: SymbolicaModel = if json.starts_with("[[{") {
1187 serde_json::from_str(json.as_str())?
1188 } else {
1189 Parser::new(json).parse()?
1190 };
1191
1192 translator.parse_model(&model)?;
1193 translator.set_num_params(num_params);
1194 let (ml, reals) = translator.translate()?;
1195
1196 let prog = Program::new(&ml, self.config)?;
1197 let df = Defuns::new();
1198 let mut app = Application::new(prog, reals, df)?;
1199
1200 app.prepare_simd();
1201
1202 Ok(app)
1210 }
1211}