torn_api_macros/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4
5#[proc_macro_derive(ApiCategory, attributes(api))]
6pub fn derive_api_category(input: TokenStream) -> TokenStream {
7    let ast = syn::parse(input).unwrap();
8
9    impl_api_category(&ast)
10}
11
12#[derive(Debug)]
13enum ApiField {
14    Property(syn::Ident),
15    Flattened,
16}
17
18#[derive(Debug)]
19struct ApiAttribute {
20    field: ApiField,
21    name: syn::Ident,
22    raw_value: String,
23    variant: syn::Ident,
24    type_name: proc_macro2::TokenStream,
25    with: Option<syn::Ident>,
26}
27
28fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream {
29    let name = &ast.ident;
30
31    let enum_ = match &ast.data {
32        syn::Data::Enum(data) => data,
33        _ => panic!("ApiCategory can only be derived for enums"),
34    };
35
36    let mut category: Option<String> = None;
37    for attr in &ast.attrs {
38        if attr.path().is_ident("api") {
39            attr.parse_nested_meta(|meta| {
40                if meta.path.is_ident("category") {
41                    let c: syn::LitStr = meta.value()?.parse()?;
42                    category = Some(c.value());
43                    Ok(())
44                } else {
45                    Err(meta.error("unknown attribute"))
46                }
47            })
48            .unwrap();
49        }
50    }
51
52    let category = category.expect("`category`");
53
54    let fields: Vec<_> = enum_
55        .variants
56        .iter()
57        .filter_map(|variant| {
58            let mut r#type: Option<String> = None;
59            let mut field: Option<ApiField> = None;
60            let mut with: Option<proc_macro2::Ident> = None;
61            for attr in &variant.attrs {
62                if attr.path().is_ident("api") {
63                    attr.parse_nested_meta(|meta| {
64                        if meta.path.is_ident("type") {
65                            let t: syn::LitStr = meta.value()?.parse()?;
66                            r#type = Some(t.value());
67                            Ok(())
68                        } else if meta.path.is_ident("with") {
69                            let w: syn::LitStr = meta.value()?.parse()?;
70                            with = Some(quote::format_ident!("{}", w.value()));
71                            Ok(())
72                        } else if meta.path.is_ident("field") {
73                            let f: syn::LitStr = meta.value()?.parse()?;
74                            field = Some(ApiField::Property(quote::format_ident!("{}", f.value())));
75                            Ok(())
76                        } else if meta.path.is_ident("flatten") {
77                            field = Some(ApiField::Flattened);
78                            Ok(())
79                        } else {
80                            Err(meta.error("unsupported attribute"))
81                        }
82                    })
83                    .unwrap();
84                    let name = format_ident!("{}", variant.ident.to_string().to_case(Case::Snake));
85                    let raw_value = variant.ident.to_string().to_lowercase();
86                    return Some(ApiAttribute {
87                        field: field.expect("field or flatten attribute must be specified"),
88                        raw_value,
89                        variant: variant.ident.clone(),
90                        type_name: r#type.expect("type must be specified").parse().unwrap(),
91                        name,
92                        with,
93                    });
94                }
95            }
96            None
97        })
98        .collect();
99
100    let accessors = fields.iter().map(
101        |ApiAttribute {
102             field,
103             name,
104             type_name,
105             with,
106             ..
107         }| match (field, with) {
108            (ApiField::Property(prop), None) => {
109                let prop_str = prop.to_string();
110                quote! {
111                    pub fn #name(&self) -> serde_json::Result<#type_name> {
112                        self.0.decode_field(#prop_str)
113                    }
114                }
115            }
116            (ApiField::Property(prop), Some(f)) => {
117                let prop_str = prop.to_string();
118                quote! {
119                    pub fn #name(&self) -> serde_json::Result<#type_name> {
120                        self.0.decode_field_with(#prop_str, #f)
121                    }
122                }
123            }
124            (ApiField::Flattened, None) => quote! {
125                pub fn #name(&self) -> serde_json::Result<#type_name> {
126                    self.0.decode()
127                }
128            },
129            (ApiField::Flattened, Some(_)) => todo!(),
130        },
131    );
132
133    let raw_values = fields.iter().map(
134        |ApiAttribute {
135             variant, raw_value, ..
136         }| {
137            quote! {
138                #name::#variant => #raw_value
139            }
140        },
141    );
142
143    let gen = quote! {
144        pub struct Response(pub crate::ApiResponse);
145
146        impl Response {
147            #(#accessors)*
148        }
149
150        impl From<crate::ApiResponse> for Response {
151            fn from(value: crate::ApiResponse) -> Self {
152                Self(value)
153            }
154        }
155
156        impl crate::ApiSelectionResponse for Response {
157            fn into_inner(self) -> crate::ApiResponse {
158                self.0
159            }
160        }
161
162        impl crate::ApiSelection for #name {
163            type Response = Response;
164
165            fn raw_value(self) -> &'static str {
166                match self {
167                    #(#raw_values,)*
168                }
169            }
170
171            fn category() -> &'static str {
172                #category
173            }
174        }
175    };
176
177    gen.into()
178}
179
180#[proc_macro_derive(IntoOwned, attributes(into_owned))]
181pub fn derive_into_owned(input: TokenStream) -> TokenStream {
182    let ast = syn::parse(input).unwrap();
183
184    impl_into_owned(&ast)
185}
186
187fn to_static_lt(ty: &mut syn::Type) -> bool {
188    let mut res = false;
189    match ty {
190        syn::Type::Path(path) => {
191            if let Some(syn::PathArguments::AngleBracketed(ab)) = path
192                .path
193                .segments
194                .last_mut()
195                .map(|s| &mut s.arguments)
196                .as_mut()
197            {
198                for mut arg in &mut ab.args {
199                    match &mut arg {
200                        syn::GenericArgument::Type(ty) => {
201                            if to_static_lt(ty) {
202                                res = true;
203                            }
204                        }
205                        syn::GenericArgument::Lifetime(lt) => {
206                            lt.ident = syn::Ident::new("static", proc_macro2::Span::call_site());
207                            res = true;
208                        }
209                        _ => (),
210                    }
211                }
212            }
213        }
214        syn::Type::Reference(r) => {
215            if let Some(lt) = r.lifetime.as_mut() {
216                lt.ident = syn::Ident::new("static", proc_macro2::Span::call_site());
217                res = true;
218            }
219            to_static_lt(&mut r.elem);
220        }
221        _ => (),
222    };
223    res
224}
225
226fn impl_into_owned(ast: &syn::DeriveInput) -> TokenStream {
227    let name = &ast.ident;
228    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
229
230    let mut identity = false;
231    for attr in &ast.attrs {
232        if attr.path().is_ident("into_owned") {
233            attr.parse_nested_meta(|meta| {
234                if meta.path.is_ident("identity") {
235                    identity = true;
236                    Ok(())
237                } else {
238                    Err(meta.error("unknown attribute"))
239                }
240            })
241            .unwrap();
242        }
243    }
244
245    if identity {
246        return quote! {
247            impl #impl_generics crate::into_owned::IntoOwned for #name #ty_generics #where_clause {
248                type Owned = Self;
249                fn into_owned(self) -> Self::Owned {
250                    self
251                }
252            }
253        }
254        .into();
255    }
256
257    let syn::Data::Struct(r#struct) = &ast.data else {
258        panic!("Only stucts are supported");
259    };
260
261    let syn::Fields::Named(named_fields) = &r#struct.fields else {
262        panic!("Only named fields are supported");
263    };
264
265    let vis = &ast.vis;
266
267    for attr in &ast.attrs {
268        if attr.path().is_ident("identity") {
269            //
270        }
271    }
272
273    let mut owned_fields = Vec::with_capacity(named_fields.named.len());
274    let mut fields = Vec::with_capacity(named_fields.named.len());
275
276    for field in &named_fields.named {
277        let field_name = &field.ident.as_ref().unwrap();
278        let mut ty = field.ty.clone();
279        let vis = &field.vis;
280
281        if to_static_lt(&mut ty) {
282            owned_fields
283                .push(quote! { #vis #field_name: <#ty as crate::into_owned::IntoOwned>::Owned });
284            fields.push(
285                quote! { #field_name: crate::into_owned::IntoOwned::into_owned(self.#field_name) },
286            );
287        } else {
288            owned_fields.push(quote! { #vis #field_name: #ty });
289            fields.push(quote! { #field_name: self.#field_name });
290        };
291    }
292
293    let owned_name = syn::Ident::new(
294        &format!("{}Owned", ast.ident),
295        proc_macro2::Span::call_site(),
296    );
297
298    let gen = quote! {
299        #[derive(Debug, Clone)]
300        #vis struct #owned_name {
301            #(#owned_fields,)*
302        }
303        impl #impl_generics crate::into_owned::IntoOwned for #name #ty_generics #where_clause {
304            type Owned = #owned_name;
305            fn into_owned(self) -> Self::Owned {
306                #owned_name {
307                    #(#fields,)*
308                }
309            }
310        }
311    };
312
313    gen.into()
314}