tensorflow_internal_macros/
lib.rs

1#![recursion_limit = "128"]
2//! The package provides macros for internal usage in TensorFlow. No backwards
3//! compatibility guarantees are made.
4
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use proc_macro2::Literal;
9use proc_macro2::Span;
10use quote::quote;
11use quote::ToTokens;
12use syn::braced;
13use syn::parse::Parse;
14use syn::parse::ParseStream;
15use syn::parse_macro_input;
16use syn::punctuated::Punctuated;
17use syn::Error;
18use syn::Ident;
19use syn::LitStr;
20use syn::Result;
21use syn::Token;
22use syn::Type;
23
24#[derive(Clone)]
25struct Arg {
26    name: Ident,
27}
28
29impl Parse for Arg {
30    fn parse(input: ParseStream) -> Result<Self> {
31        let name = input.parse()?;
32        Ok(Arg { name })
33    }
34}
35
36struct Args {
37    args: Punctuated<Arg, Token![,]>,
38}
39
40impl Parse for Args {
41    fn parse(input: ParseStream) -> Result<Self> {
42        let list;
43        braced!(list in input);
44        Ok(Args {
45            args: list.parse_terminated(Arg::parse)?,
46        })
47    }
48}
49
50#[derive(Clone)]
51struct Attr {
52    optional: bool,
53    rust_name: Ident,
54    attr_type: Type,
55    c_name: LitStr,
56}
57
58impl Parse for Attr {
59    fn parse(input: ParseStream) -> Result<Self> {
60        let rust_name = input.parse()?;
61        let mut optional = false;
62        let lookahead = input.lookahead1();
63        if lookahead.peek(Token![?]) {
64            input.parse::<Token![?]>()?;
65            optional = true;
66        }
67        input.parse::<Token![:]>()?;
68        let attr_type = input.parse()?;
69        input.parse::<Token![=>]>()?;
70        let c_name = input.parse()?;
71        Ok(Attr {
72            optional,
73            rust_name,
74            attr_type,
75            c_name,
76        })
77    }
78}
79
80struct Attrs {
81    attrs: Punctuated<Attr, Token![,]>,
82}
83
84impl Parse for Attrs {
85    fn parse(input: ParseStream) -> Result<Self> {
86        let list;
87        braced!(list in input);
88        Ok(Attrs {
89            attrs: list.parse_terminated(Attr::parse)?,
90        })
91    }
92}
93
94struct DefineOpInput {
95    fn_name: Ident,
96    name: Ident,
97    op_name: LitStr,
98    deprecation_message: LitStr,
99    args: Vec<Arg>,
100    attrs: Vec<Attr>,
101}
102
103impl Parse for DefineOpInput {
104    fn parse(input: ParseStream) -> Result<Self> {
105        let fn_name = input.parse()?;
106        input.parse::<Token![,]>()?;
107        let name = input.parse()?;
108        input.parse::<Token![,]>()?;
109        let op_name = input.parse()?;
110        input.parse::<Token![,]>()?;
111        let deprecation_message = input.parse()?;
112        let mut args = Vec::new();
113        let mut attrs = Vec::new();
114        loop {
115            let lookahead = input.lookahead1();
116            if !lookahead.peek(Token![,]) {
117                break;
118            }
119            input.parse::<Token![,]>()?;
120            let ident: Ident = input.parse()?;
121            if ident == "args" {
122                let new_args: Args = input.parse()?;
123                args.extend(new_args.args);
124            } else if ident == "attrs" {
125                let new_attrs: Attrs = input.parse()?;
126                attrs.extend(new_attrs.attrs);
127            } else {
128                return Err(Error::new(Span::call_site(), "expected `attrs` or `args`"));
129            }
130        }
131        Ok(DefineOpInput {
132            fn_name,
133            name,
134            op_name,
135            deprecation_message,
136            args,
137            attrs,
138        })
139    }
140}
141
142struct AttrDefs<'a>(&'a [Attr]);
143
144impl<'a> ToTokens for AttrDefs<'a> {
145    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
146        for attr in self.0 {
147            let rust_name = &attr.rust_name;
148            let attr_type = &attr.attr_type;
149            if attr.optional {
150                tokens.extend(quote! { #rust_name: ::std::option::Option<#attr_type>, });
151            } else {
152                tokens.extend(quote! { #rust_name: #attr_type, });
153            }
154        }
155    }
156}
157
158struct AttrSetters<'a>(&'a [Attr]);
159
160impl<'a> ToTokens for AttrSetters<'a> {
161    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
162        for attr in self.0 {
163            let comment =
164                Literal::string(&format!("Sets the `{}` attribute.", attr.c_name.value()));
165            let rust_name = &attr.rust_name;
166            let attr_type = &attr.attr_type;
167            let mut needs_into = false;
168            let mut arg_type = attr_type.clone();
169            if attr_type == &syn::parse_str::<Type>("String").unwrap() {
170                needs_into = true;
171                // TODO: don't use parse
172                arg_type = syn::parse_str::<Type>("&str").unwrap()
173            };
174            let mut value = quote! { value };
175            if needs_into {
176                value = quote! { <#arg_type as ::std::convert::Into<#attr_type>>::into(#value) };
177            }
178            if attr.optional {
179                value = quote! { ::std::option::Option::Some(#value) };
180            }
181            tokens.extend(quote! {
182                #[doc = #comment]
183                pub fn #rust_name(mut self, value: #arg_type) -> Self {
184                    self.#rust_name = #value;
185                    self
186                }
187            });
188        }
189    }
190}
191
192struct BuildFnGenerics {
193    arg_count: usize,
194}
195
196impl ToTokens for BuildFnGenerics {
197    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
198        if self.arg_count == 0 {
199            return;
200        }
201        tokens.extend(quote! {<});
202        for i in 0..self.arg_count {
203            if i > 0 {
204                tokens.extend(quote! {,});
205            }
206            let arg = Ident::new(&format!("O{}", i + 1), Span::call_site());
207            tokens.extend(quote! {#arg: ::std::convert::Into<crate::Output>});
208        }
209        tokens.extend(quote! {>});
210    }
211}
212
213struct BuildFnArgs<'a> {
214    args: &'a [Arg],
215}
216
217impl<'a> ToTokens for BuildFnArgs<'a> {
218    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
219        for (i, arg) in self.args.iter().enumerate() {
220            let arg_name = &arg.name;
221            let arg_type = Ident::new(&format!("O{}", i + 1), Span::call_site());
222            tokens.extend(quote! {#arg_name: #arg_type, });
223        }
224    }
225}
226
227struct SetAttr<'a> {
228    attr: &'a Attr,
229}
230
231impl<'a> ToTokens for SetAttr<'a> {
232    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
233        let c_name = &self.attr.c_name;
234        let rust_name = &self.attr.rust_name;
235        let setter = |value| match self
236            .attr
237            .attr_type
238            .clone()
239            .into_token_stream()
240            .to_string()
241            .as_str()
242        {
243            "String" => quote! { nd.set_attr_string(#c_name, &#value)?; },
244            "DataType" => quote! { nd.set_attr_type(#c_name, #value)?; },
245            "bool" => quote! { nd.set_attr_bool(#c_name, #value)?; },
246            "i64" => quote! { nd.set_attr_int(#c_name, #value)?; },
247            "Shape" => quote! { nd.set_attr_shape(#c_name, &#value)?; },
248            ty => panic!(
249                "Unrecognized attribute type for {}: {}",
250                self.attr.rust_name, ty
251            ),
252        };
253        tokens.extend(if self.attr.optional {
254            let set = setter(quote! { *value });
255            quote! {
256                if let Some(value) = &self.#rust_name {
257                    #set
258                }
259            }
260        } else {
261            setter(quote! { self.#rust_name })
262        });
263    }
264}
265
266struct BuildFn<'a> {
267    op_name: &'a LitStr,
268    args: &'a [Arg],
269    attrs: &'a [Attr],
270}
271
272impl<'a> ToTokens for BuildFn<'a> {
273    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
274        let op_name = &self.op_name;
275        let build_fn_generics = BuildFnGenerics {
276            arg_count: self.args.len(),
277        };
278        let build_fn_args = BuildFnArgs { args: self.args };
279        let arg_names = self.args.iter().map(|arg| &arg.name);
280        let set_attrs = self.attrs.iter().map(|attr| SetAttr { attr });
281        tokens.extend(quote! {
282            #[doc = "Builds the `"]
283            #[doc = #op_name]
284            #[doc = "` operation."]
285            pub fn build#build_fn_generics(&self, #build_fn_args scope: &mut crate::Scope) -> crate::Result<crate::Operation> {
286                let name = scope.get_unique_name_for_op(#op_name);
287                let mut graph = scope.graph_mut();
288                let mut nd = graph.new_operation(#op_name, &name)?;
289                #(
290                    nd.add_input(#arg_names);
291                )*
292                for op in &self.control_inputs {
293                    nd.add_control_input(op);
294                }
295                #(#set_attrs)*
296                nd.finish()
297            }
298        });
299    }
300}
301
302struct ShortFn<'a> {
303    name: &'a Ident,
304    fn_name: &'a Ident,
305    deprecation_message: &'a LitStr,
306    args: &'a [Arg],
307}
308
309impl<'a> ToTokens for ShortFn<'a> {
310    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
311        let name = &self.name;
312        let fn_name = &self.fn_name;
313        let build_fn_generics = BuildFnGenerics {
314            arg_count: self.args.len(),
315        };
316        let build_fn_args = BuildFnArgs { args: self.args };
317        let arg_names = self.args.iter().map(|arg| &arg.name);
318        let mut docs = format!("Shorthand for `{}::new().build(scope)", name);
319        for arg in self.args {
320            docs.push_str(", ");
321            docs.push_str(&arg.name.to_string());
322        }
323        docs.push_str(")`.");
324        let deprecation_message = &self.deprecation_message;
325        tokens.extend(quote! {
326            #[doc = #docs]
327            #[allow(deprecated)]
328            #[deprecated(note = #deprecation_message, since = "0.15.0")]
329            pub fn #fn_name#build_fn_generics(#build_fn_args scope: &mut crate::Scope) -> crate::Result<crate::Operation> {
330                #name::new().build(#(#arg_names, )* scope)
331            }
332        });
333    }
334}
335
336#[proc_macro]
337pub fn define_op(input: TokenStream) -> TokenStream {
338    let input = parse_macro_input!(input as DefineOpInput);
339    let fn_name = input.fn_name;
340    let name = input.name;
341    let op_name = input.op_name;
342    let name_str = name.to_string();
343    let name_str_plus_period = name_str + ".";
344    let deprecation_message = input.deprecation_message;
345    let attr_defs = AttrDefs(&input.attrs);
346    let attr_setters = AttrSetters(&input.attrs);
347    let build_fn = BuildFn {
348        op_name: &op_name,
349        args: &input.args,
350        attrs: &input.attrs,
351    };
352    let short_fn = ShortFn {
353        name: &name,
354        fn_name: &fn_name,
355        deprecation_message: &deprecation_message,
356        args: &input.args,
357    };
358    let stream = quote! {
359        #[doc = "Builder for the `"]
360        #[doc = #op_name]
361        #[doc = "` operation."]
362        #[derive(Debug,Default)]
363        #[deprecated(note = #deprecation_message, since = "0.15.0")]
364        #[allow(deprecated)]
365        pub struct #name {
366            #attr_defs
367            control_inputs: Vec<crate::Operation>,
368        }
369
370        #[allow(deprecated)]
371        impl #name {
372            #[doc = "Creates a new"]
373            #[doc = #name_str_plus_period]
374            pub fn new() -> Self {
375                Self::default()
376            }
377
378            #attr_setters
379
380            /// Adds a control input.
381            pub fn add_control_input(mut self, op: crate::Operation) -> Self {
382                self.control_inputs.push(op);
383                self
384            }
385
386            #build_fn
387        }
388
389        #short_fn
390    };
391    stream.into()
392}