1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use ::*;
// Rebuilds the variant `$ctor` after substituting into each of its boxed children.
//
// Every `$child` (a `Box` holding a node) has `apply1($sym, $v)` called on it and the
// result is boxed again, in the same order, as an argument to `$ctor`.
// `$sym` and `$v` are expanded once per child, so an argument such as
// `value.clone()` is evaluated separately for each child.
macro_rules! recursive_apply {
    ($ctor:ident ($($child:ident),+), $sym:expr, $v:expr) => {
        $ctor(
            $(
                Box::new($child.apply1($sym, $v))
            ),+
        )
    };
}
// Constant-folds a node whose operands are all numeric.
//
// When `is_num()` is true for a clone of every listed operand (checked left to right,
// short-circuiting), each operand name is rebound to `operand.val().unwrap()` and
// `Num($body)` is produced. Otherwise the `$fallback` expression is evaluated instead.
// `Num` is resolved at the call site.
macro_rules! simplify_nn {
    ($($operand:ident),+ => $body:expr, else $fallback:expr) => {
        if $($operand.clone().is_num())&&+ {
            $(
                let $operand = $operand.val().unwrap();
            )+
            Num($body)
        } else {
            $fallback
        }
    };
}
impl Expr {
    /// Substitutes `value` for every `Symbol` whose name equals `name_raw`, then
    /// constant-folds the rebuilt node if all of its operands are numeric.
    ///
    /// Composite nodes recurse into every child, and each child receives its own clone
    /// of `value`. Because each recursive call folds its own result, folding propagates
    /// bottom-up through the tree.
    ///
    /// Folding uses the matching `f64` operation:
    ///
    /// * `Add` → `+`
    /// * `Mul` → `*`
    /// * `Pow(a, b)` → `a.powf(b)`
    /// * `Log(a, b)` → `a.log(b)`, the logarithm of `a` in base `b`
    /// * `Sin`/`Cos`/`Arcsin`/`Arccos`/`Arctan` → `sin`/`cos`/`asin`/`acos`/`atan`
    ///
    /// `Symbol`s with other names and `Num`s are returned unchanged.
    pub fn apply1<T, S: AsRef<str>>(self, name_raw: S, value: T) -> Expr
    where
        T: Clone + Into<Expr>,
    {
        use Expr::*;
        let name = name_raw.as_ref();

        // Step 1: substitute, recursing into children.
        let partial = match self {
            // `s` is already owned here, so it is moved back rather than cloned.
            Symbol(s) => {
                if s == name {
                    value.into()
                } else {
                    Symbol(s)
                }
            }
            Num(n) => Num(n),
            Add(a, b) => recursive_apply!(Add(a, b), name, value.clone()),
            Mul(a, b) => recursive_apply!(Mul(a, b), name, value.clone()),
            Pow(a, b) => recursive_apply!(Pow(a, b), name, value.clone()),
            Log(a, b) => recursive_apply!(Log(a, b), name, value.clone()),
            Sin(x) => recursive_apply!(Sin(x), name, value.clone()),
            Cos(x) => recursive_apply!(Cos(x), name, value.clone()),
            Arcsin(x) => recursive_apply!(Arcsin(x), name, value.clone()),
            Arccos(x) => recursive_apply!(Arccos(x), name, value.clone()),
            Arctan(x) => recursive_apply!(Arctan(x), name, value.clone()),
        };

        // Step 2: fold. `partial` is consumed instead of deep-cloned just to inspect
        // its variant. When the operands are not all numeric, the node is rebuilt
        // from the very same boxes.
        match partial {
            Add(a, b) => simplify_nn!(a, b => a + b, else Add(a, b)),
            Mul(a, b) => simplify_nn!(a, b => a * b, else Mul(a, b)),
            Pow(a, b) => simplify_nn!(a, b => a.powf(b), else Pow(a, b)),
            Log(a, b) => simplify_nn!(a, b => a.log(b), else Log(a, b)),
            Sin(x) => simplify_nn!(x => x.sin(), else Sin(x)),
            Cos(x) => simplify_nn!(x => x.cos(), else Cos(x)),
            Arcsin(x) => simplify_nn!(x => x.asin(), else Arcsin(x)),
            Arccos(x) => simplify_nn!(x => x.acos(), else Arccos(x)),
            Arctan(x) => simplify_nn!(x => x.atan(), else Arctan(x)),
            other => other,
        }
    }
}
#[test]
fn one() {
    // Quadratic 3x² - 2x + 1. `^` (BitXor) binds looser than `*`/`+`, hence the parentheses.
    let a = 3f64;
    let b = -2f64;
    let c = 1f64;
    let quad = a * (s!(x) ^ 2.) + b * s!(x) + c;
    // At x = 3: 3·9 - 2·3 + 1 = 22. Every intermediate value is exactly representable,
    // so an exact comparison is safe. The original test computed this but never checked it.
    assert_eq!(quad.apply1("x", 3.).val().unwrap(), 22.);
}
#[test]
fn batch() {
    // x·y·z evaluated at (2, 3, 4) must collapse to the number 24.
    let product = s!(x) * s!(y) * s!(z);
    let evaluated = apply!(product, x = 2., y = 3., z = 4.);
    let result = evaluated.val().unwrap();
    assert_eq!(result, 24.);
}
#[test]
fn expr() {
    let field = (s!(x) ^ 2.) + s!(y).sqrt();
    let x_t = 2. * s!(t) + 3.;
    // `^` is `BitXor`, which binds looser than `-`, so the square must be parenthesised.
    // Without the parentheses this parses as `(t - t)^2`, which is identically 0.
    let y_t = s!(t) - (s!(t) ^ 2.);
    let field_t = apply!(field, x = x_t, y = y_t);
    assert_eq!(
        format!("{:?}", field_t),
        "Add(Pow(Add(Mul(Num(2.0), Symbol(\"t\")), Num(3.0)), Num(2.0)), Pow(Add(Symbol(\"t\"), Mul(Pow(Symbol(\"t\"), Num(2.0)), Num(-1.0))), Num(0.5)))"
    );
    let at_half = field_t.clone();
    // t = 1: (2 + 3)^2 + sqrt(1 - 1) = 25.
    assert_eq!(apply!(field_t, t = 1.).val().unwrap(), 25.);
    // t = 0.5: (1 + 3)^2 + sqrt(0.5 - 0.25) = 16.5.
    // This point tells `t - t^2` apart from the mis-parsed `(t - t)^2`, which would give 16.
    let v = apply!(at_half, t = 0.5).val().unwrap();
    assert!((v - 16.5).abs() < 1e-12);
}