1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#![feature(iter_intersperse)]
#![feature(adt_const_params)]
use std::collections::HashMap;
mod derivatives;
pub use derivatives::{cumulative_derivative_wrt_rt, Type};
mod dict;
pub use dict::*;
pub mod traits;
use traits::*;
mod utils;
// `const` items already have a `'static` lifetime, so plain `&str` is used
// (clippy: `redundant_static_lifetimes`).

/// Prefix prepended to a name to form its derivative identifier (used by `der!`).
pub const DERIVATIVE_PREFIX: &str = "__der_";
/// Prefix for forward-mode identifiers (NOTE(review): usage not visible in this file).
pub const FORWARD_MODE_PREFIX: &str = "__for_";
/// Prefix for reverse-mode identifiers (NOTE(review): usage not visible in this file).
pub const REVERSE_MODE_PREFIX: &str = "__rev_";
/// Prefix for function-related identifiers.
///
/// The name is misspelled ("PREFFIX") but kept as-is for compatibility with existing users.
pub const FUNCTION_PREFFIX: &str = "f";
/// Prefix for receiver-related identifiers.
pub const RECEIVER_PREFIX: &str = "r";
/// Suffix for return-value-related identifiers.
pub const RETURN_SUFFIX: &str = "rtn";
/// Resolves the type name (e.g. `"f32"`) of a simple expression.
///
/// Supported expressions:
/// - paths: the first path segment is looked up in `type_map` (so `a::b` resolves as `a`);
/// - literals: delegated to [`literal_type`], which requires an explicit type suffix.
///
/// # Errors
///
/// Returns `Err` when a variable is missing from `type_map`, when a literal has no
/// supported suffix, or when the expression kind is not supported.
pub fn expr_type(expr: &syn::Expr, type_map: &HashMap<String, String>) -> Result<String, String> {
    match expr {
        syn::Expr::Path(path_expr) => {
            // Checked access instead of `segments[0]`, so a degenerate path yields an
            // error rather than an index-out-of-bounds panic.
            let var = path_expr
                .path
                .segments
                .first()
                .ok_or_else(|| String::from("expr_type: empty path"))?
                .ident
                .to_string();
            type_map.get(&var).cloned().ok_or_else(|| {
                format!(
                    "expr_type: `{}` not found in type map `{:?}`",
                    var, type_map
                )
            })
        }
        syn::Expr::Lit(lit_expr) => literal_type(lit_expr),
        // Report unsupported expressions through the `Result` (the function's declared
        // contract) instead of panicking, so callers propagating with `?` can recover.
        _ => Err(String::from("expr_type: unsupported type")),
    }
}
/// Returns the type named by a numeric literal's suffix, e.g. `"f32"` for `10.2f32`.
///
/// Literals are never type-inferred: each one must carry an explicit suffix. The suffix
/// is read with `syn`'s `suffix()` accessor rather than by slicing the end of the token
/// text. Slicing misread hexadecimal integers such as `0x1f32` (hex digits `1f32`, no
/// suffix) as `f32`.
///
/// Accepted suffixes:
/// - float literals: `f32`, `f64`;
/// - integer literals: `i8`/`u8`, `i16`/`u16`, `i32`/`u32`, `i64`/`u64`, `i128`/`u128`,
///   plus `f32`/`f64`, because a token like `1f32` is lexed as an integer literal that
///   carries a float suffix.
///
/// # Errors
///
/// Returns `Err` for unsuffixed numeric literals, unsupported suffixes (including
/// `isize`/`usize`), and any non-numeric literal.
pub fn literal_type(expr_lit: &syn::ExprLit) -> Result<String, String> {
    match &expr_lit.lit {
        syn::Lit::Float(float_lit) => {
            let suffix = float_lit.suffix();
            match suffix {
                "f32" | "f64" => Ok(String::from(suffix)),
                // No suffix at all, e.g. `1.0`.
                "" => Err(
                    "All literals need a type suffix e.g. `10.2f32` -- Bad float literal (len)"
                        .into(),
                ),
                _ => Err(
                    "All literals need a type suffix e.g. `10.2f32` -- Bad float literal (type)"
                        .into(),
                ),
            }
        }
        syn::Lit::Int(int_lit) => {
            let suffix = int_lit.suffix();
            match suffix {
                "i8" | "u8" | "i16" | "u16" | "i32" | "u32" | "i64" | "u64" | "i128" | "u128"
                | "f32" | "f64" => Ok(String::from(suffix)),
                _ => Err(
                    "All literals need a type suffix e.g. `10.2f32` -- Bad integer literal".into(),
                ),
            }
        }
        _ => Err("Unsupported literal (only integer and float literals are supported)".into()),
    }
}
/// Builds a derivative identifier by prepending `DERIVATIVE_PREFIX` to `$a`.
///
/// `$a` can be any `Display` value; e.g. `der!("x")` yields `"__der_x"`.
#[macro_export]
macro_rules! der {
    ($a:expr) => {{
        // `$crate` (not `crate`): an exported macro may be expanded inside another crate,
        // where `crate::` would resolve to the caller's crate root instead of this one.
        format!("{}{}", $crate::DERIVATIVE_PREFIX, $a)
    }};
}
/// Joins two `Display` values into a `<a>_wrt_<b>` identifier.
///
/// For example, `wrt!("y", "x")` produces `"y_wrt_x"`.
#[macro_export]
macro_rules! wrt {
    ($a:expr, $b:expr) => {{
        format!("{lhs}_wrt_{rhs}", lhs = $a, rhs = $b)
    }};
}
/// Builds the [`MethodSignature`] of a method call such as `a.powi(b)`.
///
/// The signature consists of the method name, the receiver's type, and the types of the
/// arguments in order, all resolved through [`expr_type`].
///
/// # Panics
///
/// Panics if the receiver's or any argument's type cannot be resolved.
pub fn method_signature(
    method_expr: &syn::ExprMethodCall,
    type_map: &HashMap<String, String>,
) -> MethodSignature {
    let syn::ExprMethodCall {
        receiver,
        method,
        args,
        ..
    } = method_expr;
    let name = method.to_string();
    let receiver_type = expr_type(receiver, type_map).expect("method_signature: bad expr");
    let mut arg_types = Vec::with_capacity(args.len());
    for arg in args {
        arg_types.push(expr_type(arg, type_map).expect("method_signature: bad arg type"));
    }
    MethodSignature::new(name, receiver_type, arg_types)
}
/// Builds the [`FunctionSignature`] of a call expression such as `f(a, b)`.
///
/// The function name is taken from the first segment of the callee path. The argument
/// types are resolved, in order, through [`expr_type`].
///
/// # Panics
///
/// Panics if any argument's type cannot be resolved, or if `func.path()` fails
/// (presumably because the callee is not a path expression).
pub fn function_signature(
    function_expr: &syn::ExprCall,
    type_map: &HashMap<String, String>,
) -> FunctionSignature {
    let arg_types = function_expr
        .args
        .iter()
        .map(|arg| expr_type(arg, type_map).expect("function_signature: bad arg type"))
        .collect::<Vec<_>>();
    // `path()` is not a `syn::Expr` method; NOTE(review): presumably an extension
    // trait brought in via `traits::*`. The panic label now names this function
    // (it previously said `propagate_types`).
    let func_ident_str = function_expr
        .func
        .path()
        .expect("function_signature: func not path")
        .path
        .segments[0]
        .ident
        .to_string();
    FunctionSignature::new(func_ident_str, arg_types)
}
/// Builds the [`OperationSignature`] of a binary expression such as `a + b`.
///
/// The signature is formed from the left operand's type, the operator, and the right
/// operand's type, with operand types resolved through [`expr_type`].
///
/// # Panics
///
/// Panics if either operand's type cannot be resolved.
pub fn operation_signature(
    operation_expr: &syn::ExprBinary,
    type_map: &HashMap<String, String>,
) -> OperationSignature {
    let syn::ExprBinary {
        left, op, right, ..
    } = operation_expr;
    let lhs_type = expr_type(left, type_map).expect("operation_signature: bad left");
    let rhs_type = expr_type(right, type_map).expect("operation_signature: bad right");
    OperationSignature::from((lhs_type, *op, rhs_type))
}