tensorflow_internal_macros/
lib.rs1#![recursion_limit = "128"]
2extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use proc_macro2::Literal;
9use proc_macro2::Span;
10use quote::quote;
11use quote::ToTokens;
12use syn::braced;
13use syn::parse::Parse;
14use syn::parse::ParseStream;
15use syn::parse_macro_input;
16use syn::punctuated::Punctuated;
17use syn::Error;
18use syn::Ident;
19use syn::LitStr;
20use syn::Result;
21use syn::Token;
22use syn::Type;
23
24#[derive(Clone)]
25struct Arg {
26 name: Ident,
27}
28
29impl Parse for Arg {
30 fn parse(input: ParseStream) -> Result<Self> {
31 let name = input.parse()?;
32 Ok(Arg { name })
33 }
34}
35
36struct Args {
37 args: Punctuated<Arg, Token![,]>,
38}
39
40impl Parse for Args {
41 fn parse(input: ParseStream) -> Result<Self> {
42 let list;
43 braced!(list in input);
44 Ok(Args {
45 args: list.parse_terminated(Arg::parse)?,
46 })
47 }
48}
49
50#[derive(Clone)]
51struct Attr {
52 optional: bool,
53 rust_name: Ident,
54 attr_type: Type,
55 c_name: LitStr,
56}
57
58impl Parse for Attr {
59 fn parse(input: ParseStream) -> Result<Self> {
60 let rust_name = input.parse()?;
61 let mut optional = false;
62 let lookahead = input.lookahead1();
63 if lookahead.peek(Token![?]) {
64 input.parse::<Token![?]>()?;
65 optional = true;
66 }
67 input.parse::<Token![:]>()?;
68 let attr_type = input.parse()?;
69 input.parse::<Token![=>]>()?;
70 let c_name = input.parse()?;
71 Ok(Attr {
72 optional,
73 rust_name,
74 attr_type,
75 c_name,
76 })
77 }
78}
79
80struct Attrs {
81 attrs: Punctuated<Attr, Token![,]>,
82}
83
84impl Parse for Attrs {
85 fn parse(input: ParseStream) -> Result<Self> {
86 let list;
87 braced!(list in input);
88 Ok(Attrs {
89 attrs: list.parse_terminated(Attr::parse)?,
90 })
91 }
92}
93
94struct DefineOpInput {
95 fn_name: Ident,
96 name: Ident,
97 op_name: LitStr,
98 deprecation_message: LitStr,
99 args: Vec<Arg>,
100 attrs: Vec<Attr>,
101}
102
103impl Parse for DefineOpInput {
104 fn parse(input: ParseStream) -> Result<Self> {
105 let fn_name = input.parse()?;
106 input.parse::<Token![,]>()?;
107 let name = input.parse()?;
108 input.parse::<Token![,]>()?;
109 let op_name = input.parse()?;
110 input.parse::<Token![,]>()?;
111 let deprecation_message = input.parse()?;
112 let mut args = Vec::new();
113 let mut attrs = Vec::new();
114 loop {
115 let lookahead = input.lookahead1();
116 if !lookahead.peek(Token![,]) {
117 break;
118 }
119 input.parse::<Token![,]>()?;
120 let ident: Ident = input.parse()?;
121 if ident == "args" {
122 let new_args: Args = input.parse()?;
123 args.extend(new_args.args);
124 } else if ident == "attrs" {
125 let new_attrs: Attrs = input.parse()?;
126 attrs.extend(new_attrs.attrs);
127 } else {
128 return Err(Error::new(Span::call_site(), "expected `attrs` or `args`"));
129 }
130 }
131 Ok(DefineOpInput {
132 fn_name,
133 name,
134 op_name,
135 deprecation_message,
136 args,
137 attrs,
138 })
139 }
140}
141
142struct AttrDefs<'a>(&'a [Attr]);
143
144impl<'a> ToTokens for AttrDefs<'a> {
145 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
146 for attr in self.0 {
147 let rust_name = &attr.rust_name;
148 let attr_type = &attr.attr_type;
149 if attr.optional {
150 tokens.extend(quote! { #rust_name: ::std::option::Option<#attr_type>, });
151 } else {
152 tokens.extend(quote! { #rust_name: #attr_type, });
153 }
154 }
155 }
156}
157
158struct AttrSetters<'a>(&'a [Attr]);
159
160impl<'a> ToTokens for AttrSetters<'a> {
161 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
162 for attr in self.0 {
163 let comment =
164 Literal::string(&format!("Sets the `{}` attribute.", attr.c_name.value()));
165 let rust_name = &attr.rust_name;
166 let attr_type = &attr.attr_type;
167 let mut needs_into = false;
168 let mut arg_type = attr_type.clone();
169 if attr_type == &syn::parse_str::<Type>("String").unwrap() {
170 needs_into = true;
171 arg_type = syn::parse_str::<Type>("&str").unwrap()
173 };
174 let mut value = quote! { value };
175 if needs_into {
176 value = quote! { <#arg_type as ::std::convert::Into<#attr_type>>::into(#value) };
177 }
178 if attr.optional {
179 value = quote! { ::std::option::Option::Some(#value) };
180 }
181 tokens.extend(quote! {
182 #[doc = #comment]
183 pub fn #rust_name(mut self, value: #arg_type) -> Self {
184 self.#rust_name = #value;
185 self
186 }
187 });
188 }
189 }
190}
191
192struct BuildFnGenerics {
193 arg_count: usize,
194}
195
196impl ToTokens for BuildFnGenerics {
197 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
198 if self.arg_count == 0 {
199 return;
200 }
201 tokens.extend(quote! {<});
202 for i in 0..self.arg_count {
203 if i > 0 {
204 tokens.extend(quote! {,});
205 }
206 let arg = Ident::new(&format!("O{}", i + 1), Span::call_site());
207 tokens.extend(quote! {#arg: ::std::convert::Into<crate::Output>});
208 }
209 tokens.extend(quote! {>});
210 }
211}
212
213struct BuildFnArgs<'a> {
214 args: &'a [Arg],
215}
216
217impl<'a> ToTokens for BuildFnArgs<'a> {
218 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
219 for (i, arg) in self.args.iter().enumerate() {
220 let arg_name = &arg.name;
221 let arg_type = Ident::new(&format!("O{}", i + 1), Span::call_site());
222 tokens.extend(quote! {#arg_name: #arg_type, });
223 }
224 }
225}
226
227struct SetAttr<'a> {
228 attr: &'a Attr,
229}
230
231impl<'a> ToTokens for SetAttr<'a> {
232 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
233 let c_name = &self.attr.c_name;
234 let rust_name = &self.attr.rust_name;
235 let setter = |value| match self
236 .attr
237 .attr_type
238 .clone()
239 .into_token_stream()
240 .to_string()
241 .as_str()
242 {
243 "String" => quote! { nd.set_attr_string(#c_name, &#value)?; },
244 "DataType" => quote! { nd.set_attr_type(#c_name, #value)?; },
245 "bool" => quote! { nd.set_attr_bool(#c_name, #value)?; },
246 "i64" => quote! { nd.set_attr_int(#c_name, #value)?; },
247 "Shape" => quote! { nd.set_attr_shape(#c_name, &#value)?; },
248 ty => panic!(
249 "Unrecognized attribute type for {}: {}",
250 self.attr.rust_name, ty
251 ),
252 };
253 tokens.extend(if self.attr.optional {
254 let set = setter(quote! { *value });
255 quote! {
256 if let Some(value) = &self.#rust_name {
257 #set
258 }
259 }
260 } else {
261 setter(quote! { self.#rust_name })
262 });
263 }
264}
265
266struct BuildFn<'a> {
267 op_name: &'a LitStr,
268 args: &'a [Arg],
269 attrs: &'a [Attr],
270}
271
272impl<'a> ToTokens for BuildFn<'a> {
273 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
274 let op_name = &self.op_name;
275 let build_fn_generics = BuildFnGenerics {
276 arg_count: self.args.len(),
277 };
278 let build_fn_args = BuildFnArgs { args: self.args };
279 let arg_names = self.args.iter().map(|arg| &arg.name);
280 let set_attrs = self.attrs.iter().map(|attr| SetAttr { attr });
281 tokens.extend(quote! {
282 #[doc = "Builds the `"]
283 #[doc = #op_name]
284 #[doc = "` operation."]
285 pub fn build#build_fn_generics(&self, #build_fn_args scope: &mut crate::Scope) -> crate::Result<crate::Operation> {
286 let name = scope.get_unique_name_for_op(#op_name);
287 let mut graph = scope.graph_mut();
288 let mut nd = graph.new_operation(#op_name, &name)?;
289 #(
290 nd.add_input(#arg_names);
291 )*
292 for op in &self.control_inputs {
293 nd.add_control_input(op);
294 }
295 #(#set_attrs)*
296 nd.finish()
297 }
298 });
299 }
300}
301
302struct ShortFn<'a> {
303 name: &'a Ident,
304 fn_name: &'a Ident,
305 deprecation_message: &'a LitStr,
306 args: &'a [Arg],
307}
308
309impl<'a> ToTokens for ShortFn<'a> {
310 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
311 let name = &self.name;
312 let fn_name = &self.fn_name;
313 let build_fn_generics = BuildFnGenerics {
314 arg_count: self.args.len(),
315 };
316 let build_fn_args = BuildFnArgs { args: self.args };
317 let arg_names = self.args.iter().map(|arg| &arg.name);
318 let mut docs = format!("Shorthand for `{}::new().build(scope)", name);
319 for arg in self.args {
320 docs.push_str(", ");
321 docs.push_str(&arg.name.to_string());
322 }
323 docs.push_str(")`.");
324 let deprecation_message = &self.deprecation_message;
325 tokens.extend(quote! {
326 #[doc = #docs]
327 #[allow(deprecated)]
328 #[deprecated(note = #deprecation_message, since = "0.15.0")]
329 pub fn #fn_name#build_fn_generics(#build_fn_args scope: &mut crate::Scope) -> crate::Result<crate::Operation> {
330 #name::new().build(#(#arg_names, )* scope)
331 }
332 });
333 }
334}
335
336#[proc_macro]
337pub fn define_op(input: TokenStream) -> TokenStream {
338 let input = parse_macro_input!(input as DefineOpInput);
339 let fn_name = input.fn_name;
340 let name = input.name;
341 let op_name = input.op_name;
342 let name_str = name.to_string();
343 let name_str_plus_period = name_str + ".";
344 let deprecation_message = input.deprecation_message;
345 let attr_defs = AttrDefs(&input.attrs);
346 let attr_setters = AttrSetters(&input.attrs);
347 let build_fn = BuildFn {
348 op_name: &op_name,
349 args: &input.args,
350 attrs: &input.attrs,
351 };
352 let short_fn = ShortFn {
353 name: &name,
354 fn_name: &fn_name,
355 deprecation_message: &deprecation_message,
356 args: &input.args,
357 };
358 let stream = quote! {
359 #[doc = "Builder for the `"]
360 #[doc = #op_name]
361 #[doc = "` operation."]
362 #[derive(Debug,Default)]
363 #[deprecated(note = #deprecation_message, since = "0.15.0")]
364 #[allow(deprecated)]
365 pub struct #name {
366 #attr_defs
367 control_inputs: Vec<crate::Operation>,
368 }
369
370 #[allow(deprecated)]
371 impl #name {
372 #[doc = "Creates a new"]
373 #[doc = #name_str_plus_period]
374 pub fn new() -> Self {
375 Self::default()
376 }
377
378 #attr_setters
379
380 pub fn add_control_input(mut self, op: crate::Operation) -> Self {
382 self.control_inputs.push(op);
383 self
384 }
385
386 #build_fn
387 }
388
389 #short_fn
390 };
391 stream.into()
392}