1use proc_macro::{TokenStream, TokenTree};
5use rust_ad_consts::{INTERNAL_FORWARD_PREFIX, INTERNAL_REVERSE_PREFIX};
6
7#[proc_macro]
11pub fn combined_derivative_macro(item: TokenStream) -> TokenStream {
12 let mut iter = item.into_iter();
14 let name = match iter.next() {
15 Some(TokenTree::Ident(ident)) => ident,
16 _ => panic!("No function ident"),
17 };
18 let vec = iter.collect::<Vec<_>>();
19 assert_eq!(vec.len() % 2, 0, "Bad punctuation");
20 let num = (vec.len() - 1) / 2;
21 let mut iter = vec.chunks_exact(2);
22
23 let default = match iter.next() {
24 Some(item) => {
25 let (punc, lit) = (&item[0], &item[1]);
26 match (punc, lit) {
27 (TokenTree::Punct(_), TokenTree::Literal(default)) => default,
28 _ => panic!("Bad default value"),
29 }
30 }
31 _ => panic!("No default value"),
32 };
33
34 let iter = iter.enumerate();
35 let arg_fmt_str = (0..num)
36 .map(|i| format!("args[{}],", i))
37 .collect::<String>();
38
39 let der_functions = iter
40 .map(|(index, item)| {
41 let (punc, lit) = (&item[0], &item[1]);
42 match (punc, lit) {
43 (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!(
44 "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n",
45 index, format_str, arg_fmt_str
46 ),
47 _ => panic!("Bad format strings"),
48 }
49 })
50 .collect::<String>();
51 let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>();
52
53 let for_out_str = format!(
54 "pub static {}{}: FgdType = {{\n{}\n\tfgd::<{{ {} }},{{ &[{}] }}>\n}};",
55 INTERNAL_FORWARD_PREFIX, name, der_functions, default, fn_fmt_str
56 );
57 let rev_out_str = format!(
58 "pub static {}{}: RgdType = {{\n{}\n\trgd::<{{ {} }},{{ &[{}] }}>\n}};",
59 INTERNAL_REVERSE_PREFIX, name, der_functions, default, fn_fmt_str
60 );
61 let out_str = format!("{}\n{}", for_out_str, rev_out_str);
62 out_str.parse().unwrap()
64}
65
66#[proc_macro]
79pub fn forward_derivative_macro(item: TokenStream) -> TokenStream {
80 let mut iter = item.into_iter();
82 let name = match iter.next() {
83 Some(TokenTree::Ident(ident)) => ident,
84 _ => panic!("No function ident"),
85 };
86 let vec = iter.collect::<Vec<_>>();
87 assert_eq!(vec.len() % 2, 0, "Bad punctuation");
88 let num = (vec.len() - 1) / 2;
89 let mut iter = vec.chunks_exact(2);
90
91 let default = match iter.next() {
92 Some(item) => {
93 let (punc, lit) = (&item[0], &item[1]);
94 match (punc, lit) {
95 (TokenTree::Punct(_), TokenTree::Literal(default)) => default,
96 _ => panic!("Bad default value"),
97 }
98 }
99 _ => panic!("No default value"),
100 };
101
102 let iter = iter.enumerate();
103 let arg_fmt_str = (0..num)
104 .map(|i| format!("args[{}],", i))
105 .collect::<String>();
106
107 let der_functions = iter
108 .map(|(index, item)| {
109 let (punc, lit) = (&item[0], &item[1]);
110 match (punc, lit) {
111 (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!(
112 "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n",
113 index, format_str, arg_fmt_str
114 ),
115 _ => panic!("Bad format strings"),
116 }
117 })
118 .collect::<String>();
119 let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>();
120 let out_str = format!(
121 "pub static {}{}: FgdType = {{\n{}\n\tfgd::<{{ {} }},{{ &[{}] }}>\n}};",
122 INTERNAL_FORWARD_PREFIX, name, der_functions, default, fn_fmt_str
123 );
124 out_str.parse().unwrap()
126}
127#[proc_macro]
129pub fn reverse_derivative_macro(item: TokenStream) -> TokenStream {
130 let mut iter = item.into_iter();
132 let name = match iter.next() {
133 Some(TokenTree::Ident(ident)) => ident,
134 _ => panic!("No function ident"),
135 };
136 let vec = iter.collect::<Vec<_>>();
137 assert_eq!(vec.len() % 2, 0, "Bad punctuation");
138 let num = (vec.len() - 1) / 2;
139 let mut iter = vec.chunks_exact(2);
140
141 let default = match iter.next() {
142 Some(item) => {
143 let (punc, lit) = (&item[0], &item[1]);
144 match (punc, lit) {
145 (TokenTree::Punct(_), TokenTree::Literal(default)) => default,
146 _ => panic!("Bad default value"),
147 }
148 }
149 _ => panic!("No default value"),
150 };
151
152 let iter = iter.enumerate();
153 let arg_fmt_str = (0..num)
154 .map(|i| format!("args[{}],", i))
155 .collect::<String>();
156
157 let der_functions = iter
158 .map(|(index, item)| {
159 let (punc, lit) = (&item[0], &item[1]);
160 match (punc, lit) {
161 (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!(
162 "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n",
163 index, format_str, arg_fmt_str
164 ),
165 _ => panic!("Bad format strings"),
166 }
167 })
168 .collect::<String>();
169 let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>();
170 let out_str = format!(
171 "pub static {}{}: RgdType = {{\n{}\n\trgd::<{{ {} }},{{ &[{}] }}>\n}};",
172 INTERNAL_REVERSE_PREFIX, name, der_functions, default, fn_fmt_str
173 );
174 out_str.parse().unwrap()
176}
177
178#[proc_macro]
182pub fn compose(item: TokenStream) -> TokenStream {
183 let mut iter = item.into_iter();
185 let fmt_str = match iter.next() {
186 Some(TokenTree::Literal(l)) => l.to_string(),
187 _ => panic!("No fmt str"),
188 };
189 let vec = iter.skip(1).collect::<Vec<_>>();
190 let component_iter = vec.split(|t| match t {
191 TokenTree::Punct(p) => p.as_char() == ',',
192 _ => false,
193 });
194 let components = component_iter
195 .map(|component_slice| {
196 component_slice
197 .iter()
198 .map(|c| c.to_string())
199 .collect::<String>()
200 })
201 .collect::<Vec<_>>();
202
203 let mut bytes_string = Vec::from(&fmt_str.as_bytes()[1..fmt_str.len() - 1]);
204 let mut i = 0;
205 let mut out_str = String::from("let mut temp = String::new();");
206 while i < bytes_string.len() {
207 if bytes_string[i] == b'}' {
208 let index_str = String::from_utf8(bytes_string.drain(0..i).collect::<Vec<_>>())
210 .expect("compose: utf8");
211 let index: usize = index_str.parse().expect("compose: parse");
212 out_str.push_str(&format!(
213 "\n\ttemp.push_str(&{}.to_string());",
214 components[index]
215 ));
216 bytes_string.remove(0);
218 i = 0;
219 } else if bytes_string[i] == b'{' {
220 let segment = String::from_utf8(bytes_string.drain(0..i).collect::<Vec<_>>())
221 .expect("compose: utf8");
222 out_str.push_str(&format!("\n\ttemp.push_str(\"{}\");", segment));
223 bytes_string.remove(0);
225 i = 0;
226 } else {
227 i += 1;
228 }
229 }
230 let segment = String::from_utf8(bytes_string).expect("compose: utf8");
231 out_str.push_str(&format!("\n\ttemp.push_str(\"{}\");", segment));
232
233 let out_str = format!("{{\n\t{}\n\ttemp\n}}", out_str);
234 out_str.parse().unwrap()
236}
237
238#[proc_macro]
241pub fn f(item: TokenStream) -> TokenStream {
242 let mut items = item.into_iter();
243 let function_ident = match items.next() {
244 Some(proc_macro::TokenTree::Ident(ident)) => ident,
245 _ => panic!("Requires function identifier"),
246 };
247 let call_str = format!("{}{}", INTERNAL_FORWARD_PREFIX, function_ident);
248 call_str.parse().unwrap()
249}
250#[proc_macro]
252pub fn r(item: TokenStream) -> TokenStream {
253 let mut items = item.into_iter();
254 let function_ident = match items.next() {
255 Some(proc_macro::TokenTree::Ident(ident)) => ident,
256 _ => panic!("Requires function identifier"),
257 };
258 let call_str = format!("{}{}", INTERNAL_REVERSE_PREFIX, function_ident);
259 call_str.parse().unwrap()
260}