rust_ad_core_macros/
lib.rs

//! **I do not recommend using this directly, please see [rust-ad](https://crates.io/crates/rust-ad).**
2//!
3//! Internal proc-macro functionality.
4use proc_macro::{TokenStream, TokenTree};
5use rust_ad_consts::{INTERNAL_FORWARD_PREFIX, INTERNAL_REVERSE_PREFIX};
6
7/// Same result as performing `forward_derivative_macro` and `reverse_derivative_macro` consecutively.
8///
9/// But this does it a little neater and more efficiently.
10#[proc_macro]
11pub fn combined_derivative_macro(item: TokenStream) -> TokenStream {
12    // eprintln!("\nitem:\n{:?}\n",item);
13    let mut iter = item.into_iter();
14    let name = match iter.next() {
15        Some(TokenTree::Ident(ident)) => ident,
16        _ => panic!("No function ident"),
17    };
18    let vec = iter.collect::<Vec<_>>();
19    assert_eq!(vec.len() % 2, 0, "Bad punctuation");
20    let num = (vec.len() - 1) / 2;
21    let mut iter = vec.chunks_exact(2);
22
23    let default = match iter.next() {
24        Some(item) => {
25            let (punc, lit) = (&item[0], &item[1]);
26            match (punc, lit) {
27                (TokenTree::Punct(_), TokenTree::Literal(default)) => default,
28                _ => panic!("Bad default value"),
29            }
30        }
31        _ => panic!("No default value"),
32    };
33
34    let iter = iter.enumerate();
35    let arg_fmt_str = (0..num)
36        .map(|i| format!("args[{}],", i))
37        .collect::<String>();
38
39    let der_functions = iter
40        .map(|(index, item)| {
41            let (punc, lit) = (&item[0], &item[1]);
42            match (punc, lit) {
43                (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!(
44                    "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n",
45                    index, format_str, arg_fmt_str
46                ),
47                _ => panic!("Bad format strings"),
48            }
49        })
50        .collect::<String>();
51    let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>();
52
53    let for_out_str = format!(
54        "pub static {}{}: FgdType = {{\n{}\n\tfgd::<{{ {} }},{{ &[{}] }}>\n}};",
55        INTERNAL_FORWARD_PREFIX, name, der_functions, default, fn_fmt_str
56    );
57    let rev_out_str = format!(
58        "pub static {}{}: RgdType = {{\n{}\n\trgd::<{{ {} }},{{ &[{}] }}>\n}};",
59        INTERNAL_REVERSE_PREFIX, name, der_functions, default, fn_fmt_str
60    );
61    let out_str = format!("{}\n{}", for_out_str, rev_out_str);
62    // eprintln!("out_str: \n{}\n",out_str);
63    out_str.parse().unwrap()
64}
65
66/// Generates forward derivative functions.
67/// ```ignore
68/// static outer_test: FgdType = {
69///     const base_fn: DFn = |args:&[String]| -> String { format!("{0}-{1}",args[0],args[1]) };
70///     const exponent_fn: DFn = |args:&[String]| -> String { format!("{0}*{1}+{0}",args[0],args[1]) };
71///     fgd::<"0f32",{&[base_fn, exponent_fn]}>
72/// };
73/// ```
74/// Is equivalent to
75/// ```ignore
76/// forward_derivative_macro!(outer_test,"0f32","{0}-{1}","{0}*{1}+{0}");
77/// ```
78#[proc_macro]
79pub fn forward_derivative_macro(item: TokenStream) -> TokenStream {
80    // eprintln!("\nitem:\n{:?}\n",item);
81    let mut iter = item.into_iter();
82    let name = match iter.next() {
83        Some(TokenTree::Ident(ident)) => ident,
84        _ => panic!("No function ident"),
85    };
86    let vec = iter.collect::<Vec<_>>();
87    assert_eq!(vec.len() % 2, 0, "Bad punctuation");
88    let num = (vec.len() - 1) / 2;
89    let mut iter = vec.chunks_exact(2);
90
91    let default = match iter.next() {
92        Some(item) => {
93            let (punc, lit) = (&item[0], &item[1]);
94            match (punc, lit) {
95                (TokenTree::Punct(_), TokenTree::Literal(default)) => default,
96                _ => panic!("Bad default value"),
97            }
98        }
99        _ => panic!("No default value"),
100    };
101
102    let iter = iter.enumerate();
103    let arg_fmt_str = (0..num)
104        .map(|i| format!("args[{}],", i))
105        .collect::<String>();
106
107    let der_functions = iter
108        .map(|(index, item)| {
109            let (punc, lit) = (&item[0], &item[1]);
110            match (punc, lit) {
111                (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!(
112                    "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n",
113                    index, format_str, arg_fmt_str
114                ),
115                _ => panic!("Bad format strings"),
116            }
117        })
118        .collect::<String>();
119    let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>();
120    let out_str = format!(
121        "pub static {}{}: FgdType = {{\n{}\n\tfgd::<{{ {} }},{{ &[{}] }}>\n}};",
122        INTERNAL_FORWARD_PREFIX, name, der_functions, default, fn_fmt_str
123    );
124    // eprintln!("out_str: \n{}\n",out_str);
125    out_str.parse().unwrap()
126}
127/// Generates reverse derivative functions.
128#[proc_macro]
129pub fn reverse_derivative_macro(item: TokenStream) -> TokenStream {
130    // eprintln!("\nitem:\n{:?}\n",item);
131    let mut iter = item.into_iter();
132    let name = match iter.next() {
133        Some(TokenTree::Ident(ident)) => ident,
134        _ => panic!("No function ident"),
135    };
136    let vec = iter.collect::<Vec<_>>();
137    assert_eq!(vec.len() % 2, 0, "Bad punctuation");
138    let num = (vec.len() - 1) / 2;
139    let mut iter = vec.chunks_exact(2);
140
141    let default = match iter.next() {
142        Some(item) => {
143            let (punc, lit) = (&item[0], &item[1]);
144            match (punc, lit) {
145                (TokenTree::Punct(_), TokenTree::Literal(default)) => default,
146                _ => panic!("Bad default value"),
147            }
148        }
149        _ => panic!("No default value"),
150    };
151
152    let iter = iter.enumerate();
153    let arg_fmt_str = (0..num)
154        .map(|i| format!("args[{}],", i))
155        .collect::<String>();
156
157    let der_functions = iter
158        .map(|(index, item)| {
159            let (punc, lit) = (&item[0], &item[1]);
160            match (punc, lit) {
161                (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!(
162                    "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n",
163                    index, format_str, arg_fmt_str
164                ),
165                _ => panic!("Bad format strings"),
166            }
167        })
168        .collect::<String>();
169    let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>();
170    let out_str = format!(
171        "pub static {}{}: RgdType = {{\n{}\n\trgd::<{{ {} }},{{ &[{}] }}>\n}};",
172        INTERNAL_REVERSE_PREFIX, name, der_functions, default, fn_fmt_str
173    );
174    // eprintln!("out_str: \n{}\n",out_str);
175    out_str.parse().unwrap()
176}
177
178/// `format!()` but:
179/// 1. only allows positional arguments e.g. `{0}`, `{1}`, etc.
180/// 2. allows unused arguments.
181#[proc_macro]
182pub fn compose(item: TokenStream) -> TokenStream {
183    // eprintln!("item: {}",item);
184    let mut iter = item.into_iter();
185    let fmt_str = match iter.next() {
186        Some(TokenTree::Literal(l)) => l.to_string(),
187        _ => panic!("No fmt str"),
188    };
189    let vec = iter.skip(1).collect::<Vec<_>>();
190    let component_iter = vec.split(|t| match t {
191        TokenTree::Punct(p) => p.as_char() == ',',
192        _ => false,
193    });
194    let components = component_iter
195        .map(|component_slice| {
196            component_slice
197                .iter()
198                .map(|c| c.to_string())
199                .collect::<String>()
200        })
201        .collect::<Vec<_>>();
202
203    let mut bytes_string = Vec::from(&fmt_str.as_bytes()[1..fmt_str.len() - 1]);
204    let mut i = 0;
205    let mut out_str = String::from("let mut temp = String::new();");
206    while i < bytes_string.len() {
207        if bytes_string[i] == b'}' {
208            // Removes opening '}'
209            let index_str = String::from_utf8(bytes_string.drain(0..i).collect::<Vec<_>>())
210                .expect("compose: utf8");
211            let index: usize = index_str.parse().expect("compose: parse");
212            out_str.push_str(&format!(
213                "\n\ttemp.push_str(&{}.to_string());",
214                components[index]
215            ));
216            // Removes'}'
217            bytes_string.remove(0);
218            i = 0;
219        } else if bytes_string[i] == b'{' {
220            let segment = String::from_utf8(bytes_string.drain(0..i).collect::<Vec<_>>())
221                .expect("compose: utf8");
222            out_str.push_str(&format!("\n\ttemp.push_str(\"{}\");", segment));
223            // Removes '{'
224            bytes_string.remove(0);
225            i = 0;
226        } else {
227            i += 1;
228        }
229    }
230    let segment = String::from_utf8(bytes_string).expect("compose: utf8");
231    out_str.push_str(&format!("\n\ttemp.push_str(\"{}\");", segment));
232
233    let out_str = format!("{{\n\t{}\n\ttemp\n}}", out_str);
234    // eprintln!("out_str: {}",out_str);
235    out_str.parse().unwrap()
236}
237
238// TODO Can we replace these with declarative macros like `der!` (and then just move them into `rust-ad-core`)?
239/// Gets internal forward derivative function identifier
240#[proc_macro]
241pub fn f(item: TokenStream) -> TokenStream {
242    let mut items = item.into_iter();
243    let function_ident = match items.next() {
244        Some(proc_macro::TokenTree::Ident(ident)) => ident,
245        _ => panic!("Requires function identifier"),
246    };
247    let call_str = format!("{}{}", INTERNAL_FORWARD_PREFIX, function_ident);
248    call_str.parse().unwrap()
249}
250/// Gets internal reverse derivative function identifier
251#[proc_macro]
252pub fn r(item: TokenStream) -> TokenStream {
253    let mut items = item.into_iter();
254    let function_ident = match items.next() {
255        Some(proc_macro::TokenTree::Ident(ident)) => ident,
256        _ => panic!("Requires function identifier"),
257    };
258    let call_str = format!("{}{}", INTERNAL_REVERSE_PREFIX, function_ident);
259    call_str.parse().unwrap()
260}