winter_maybe_async/
lib.rs

// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Expr, ImplItem, ItemFn, ItemImpl, ItemTrait, TraitItem, TraitItemFn};
9
10/// Parses a function (regular or trait) and conditionally adds the `async` keyword depending on
11/// the `async` feature flag being enabled.
12///
13/// For example:
14/// ```ignore
15/// trait ExampleTrait {
16///     #[maybe_async]
17///     fn say_hello(&self);
18///
19///     #[maybe_async]
20///     fn get_hello(&self) -> String;
21/// }
22///
23///
24/// #[maybe_async]
25/// fn hello_world() {
26///     // ...
27/// }
28/// ```
29///
30/// When the `async` feature is enabled, will be transformed into:
31/// ```ignore
32/// trait ExampleTrait {
33///     async fn say_hello(&self);
34///
35///     async fn get_hello(&self) -> String;
36/// }
37///
38///
39/// async fn hello_world() {
40///     // ...
41/// }
42/// ```
43#[proc_macro_attribute]
44pub fn maybe_async(_attr: TokenStream, input: TokenStream) -> TokenStream {
45    if let Ok(func) = syn::parse::<ItemFn>(input.clone()) {
46        if cfg!(feature = "async") {
47            let ItemFn { attrs, vis, sig, block } = func;
48            quote! {
49                #(#attrs)* #vis async #sig { #block }
50            }
51            .into()
52        } else {
53            quote!(#func).into()
54        }
55    } else if let Ok(func) = syn::parse::<TraitItemFn>(input.clone()) {
56        if cfg!(feature = "async") {
57            let TraitItemFn { attrs, sig, default, semi_token } = func;
58            quote! {
59                #(#attrs)* async #sig #default #semi_token
60            }
61            .into()
62        } else {
63            quote!(#func).into()
64        }
65    } else {
66        input
67    }
68}
69
70/// Conditionally add `async` keyword to functions.
71///
72/// Parses a trait or an `impl` block and conditionally adds the `async` keyword to methods that
73/// are annotated with `#[maybe_async]`, depending on the `async` feature flag being enabled.
74/// Additionally, if applied to a trait definition or impl block, it will add
75/// `#[async_trait::async_trait(?Send)]` to the it.
76///
77/// For example, given the following trait definition:
78/// ```ignore
79/// #[maybe_async_trait]
80/// trait ExampleTrait {
81///     #[maybe_async]
82///     fn hello_world(&self);
83///
84///     fn get_hello(&self) -> String;
85/// }
86/// ```
87///
88/// And the following implementation:
89/// ```ignore
90/// #[maybe_async_trait]
91/// impl ExampleTrait for MyStruct {
92///     #[maybe_async]
93///     fn hello_world(&self) {
94///         // ...
95///     }
96///
97///     fn get_hello(&self) -> String {
98///         // ...
99///     }
100/// }
101/// ```
102///
103/// When the `async` feature is enabled, this will be transformed into:
104/// ```ignore
105/// #[async_trait::async_trait(?Send)]
106/// trait ExampleTrait {
107///     async fn hello_world(&self);
108///
109///     fn get_hello(&self) -> String;
110/// }
111///
112/// #[async_trait::async_trait(?Send)]
113/// impl ExampleTrait for MyStruct {
114///     async fn hello_world(&self) {
115///         // ...
116///     }
117///
118///     fn get_hello(&self) -> String {
119///         // ...
120///     }
121/// }
122/// ```
123///
124/// When the `async` feature is disabled, the code remains unchanged, and neither the `async`
125/// keyword nor the `#[async_trait::async_trait(?Send)]` attribute is applied.
126#[proc_macro_attribute]
127pub fn maybe_async_trait(_attr: TokenStream, input: TokenStream) -> TokenStream {
128    // Try parsing the input as a trait definition
129    if let Ok(mut trait_item) = syn::parse::<ItemTrait>(input.clone()) {
130        let output = if cfg!(feature = "async") {
131            for item in &mut trait_item.items {
132                if let TraitItem::Fn(method) = item {
133                    // Remove the #[maybe_async] and make method async
134                    method.attrs.retain(|attr| {
135                        if attr.path().is_ident("maybe_async") {
136                            method.sig.asyncness = Some(syn::token::Async::default());
137                            false
138                        } else {
139                            true
140                        }
141                    });
142                }
143            }
144
145            quote! {
146                #[async_trait::async_trait(?Send)]
147                #trait_item
148            }
149        } else {
150            quote! {
151                #trait_item
152            }
153        };
154
155        return output.into();
156    }
157    // Check if it is an Impl block
158    else if let Ok(mut impl_item) = syn::parse::<ItemImpl>(input.clone()) {
159        let output = if cfg!(feature = "async") {
160            for item in &mut impl_item.items {
161                if let ImplItem::Fn(method) = item {
162                    // Remove #[maybe_async] and make method async
163                    method.attrs.retain(|attr| {
164                        if attr.path().is_ident("maybe_async") {
165                            method.sig.asyncness = Some(syn::token::Async::default());
166                            false // Remove the attribute
167                        } else {
168                            true // Keep other attributes
169                        }
170                    });
171                }
172            }
173            quote! {
174                #[async_trait::async_trait(?Send)]
175                #impl_item
176            }
177        } else {
178            quote! {
179                #[cfg(not(feature = "async"))]
180                #impl_item
181            }
182        };
183
184        return output.into();
185    }
186
187    // If input is neither a trait nor an impl block, emit a compile-time error
188    quote! {
189        compile_error!("`maybe_async_trait` can only be applied to trait definitions and trait impl blocks");
190    }.into()
191}
192
193/// Parses an expression and conditionally adds the `.await` keyword at the end of it depending on
194/// the `async` feature flag being enabled.
195///
196/// ```ignore
197/// #[maybe_async]
198/// fn hello_world() {
199///     // Adding `maybe_await` to an expression
200///     let w = maybe_await!(world());
201///
202///     println!("hello {}", w);
203/// }
204///
205/// #[maybe_async]
206/// fn world() -> String {
207///     "world".to_string()
208/// }
209/// ```
210///
211/// When the `async` feature is enabled, will be transformed into:
212/// ```ignore
213/// async fn hello_world() {
214///     let w = world().await;
215///
216///     println!("hello {}", w);
217/// }
218///
219/// async fn world() -> String {
220///     "world".to_string()
221/// }
222/// ```
223#[proc_macro]
224pub fn maybe_await(input: TokenStream) -> TokenStream {
225    let item = parse_macro_input!(input as Expr);
226
227    let quote = if cfg!(feature = "async") {
228        quote!(#item.await)
229    } else {
230        quote!(#item)
231    };
232
233    quote.into()
234}