spa_server_derive/
lib.rs

1mod embed;
2mod utils;
3
4use embed::impl_embed;
5use proc_macro2::{Span, TokenStream};
6use quote::quote;
7use syn::{
8    parse_macro_input, spanned::Spanned, DeriveInput, Error, FnArg, Meta, NestedMeta, Pat, Path,
9    Result,
10};
11use utils::{FromLit, LitWrap};
12
13#[proc_macro_derive(SPAServer, attributes(spa_server))]
14pub fn derive_spa_server(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
15    let input = parse_macro_input!(input as DeriveInput);
16    expand(input)
17        .unwrap_or_else(|e| e.into_compile_error())
18        .into()
19}
20
21fn get_name_value<'a, T, I, S>(metas: I, key: S) -> Option<T>
22where
23    T: FromLit,
24    I: Iterator<Item = &'a Meta>,
25    S: AsRef<str>,
26{
27    for m in metas {
28        if let Meta::NameValue(nv) = m {
29            if let Some(ident) = nv.path.get_ident() {
30                if ident == key.as_ref() {
31                    let lw = LitWrap { inner: &nv.lit };
32                    if let Ok(r) = lw.parse() {
33                        return Some(r);
34                    }
35                }
36            }
37        }
38    }
39
40    None
41}
42
43fn get_path<'a>(metas: impl Iterator<Item = &'a Meta>, key: impl AsRef<str>) -> bool {
44    for m in metas {
45        if let Meta::Path(p) = m {
46            if let Some(ident) = p.get_ident() {
47                if ident == key.as_ref() {
48                    return true;
49                }
50            }
51        }
52    }
53
54    false
55}
56
57fn expand(input: DeriveInput) -> Result<TokenStream> {
58    let name = &input.ident;
59    let attrs = &input.attrs;
60    let mut opt = Options::default();
61    for attr in attrs {
62        if let Meta::List(l) = attr.parse_meta()? {
63            if let Some(id) = l.path.get_ident() {
64                if id != "spa_server" {
65                    return Err(Error::new(l.span(), "only support attribute spa_server"));
66                }
67            }
68
69            let metas = l.nested.iter().filter_map(|x| {
70                if let NestedMeta::Meta(m) = x {
71                    Some(m)
72                } else {
73                    None
74                }
75            });
76
77            opt.static_files = get_name_value(metas.clone(), "static_files").ok_or(Error::new(
78                Span::call_site(),
79                "must set static files path in attribute",
80            ))?;
81
82            opt.cors = get_path(metas.clone(), "cors");
83            if !opt.cors {
84                if let Some(cors) = get_name_value(metas.clone(), "cors") {
85                    opt.cors = cors;
86                }
87            }
88
89            for m in metas {
90                match m {
91                    Meta::List(l) => {
92                        if let Some(id) = l.path.get_ident() {
93                            if id == "apis" {
94                                for api in &l.nested {
95                                    if let NestedMeta::Meta(meta) = api {
96                                        if let Meta::List(pl) = meta {
97                                            if let Some(iid) = pl.path.get_ident() {
98                                                if iid == "api" {
99                                                    let mut api_path = Vec::new();
100                                                    let mut prefix = None;
101                                                    for ppl in &pl.nested {
102                                                        if let NestedMeta::Meta(mm) = ppl {
103                                                            match mm {
104                                                                Meta::Path(p) => {
105                                                                    api_path.push(p.clone())
106                                                                }
107                                                                Meta::NameValue(nv) => {
108                                                                    if let Some(iiid) =
109                                                                        nv.path.get_ident()
110                                                                    {
111                                                                        if iiid == "prefix" {
112                                                                            let lw = LitWrap {
113                                                                                inner: &nv.lit,
114                                                                            };
115                                                                            if let Ok(r) =
116                                                                                lw.parse::<String>()
117                                                                            {
118                                                                                prefix = Some(r);
119                                                                            }
120                                                                        }
121                                                                    }
122                                                                }
123                                                                _ => {}
124                                                            }
125                                                        }
126                                                    }
127
128                                                    opt.apis.push(Api {
129                                                        path: api_path,
130                                                        prefix,
131                                                    });
132                                                }
133                                            }
134                                        }
135                                    }
136                                }
137                            } else if id == "identity" {
138                                let mut identity = Identity::default();
139                                for nm in &l.nested {
140                                    if let NestedMeta::Meta(meta) = nm {
141                                        if let Meta::NameValue(nv) = meta {
142                                            if let Some(iid) = nv.path.get_ident() {
143                                                if iid == "name" {
144                                                    let lit = LitWrap { inner: &nv.lit };
145                                                    if let Ok(name) = lit.parse() {
146                                                        identity.name = name;
147                                                    }
148                                                } else if iid == "age" {
149                                                    let lit = LitWrap { inner: &nv.lit };
150                                                    if let Ok(age) = lit.parse() {
151                                                        identity.age = age;
152                                                    }
153                                                }
154                                            }
155                                        }
156                                    }
157                                }
158
159                                if !identity.name.is_empty() && identity.age != 0 {
160                                    opt.identity = Some(identity);
161                                }
162                            }
163                        }
164                    }
165                    _ => {}
166                }
167            }
168        }
169    }
170
171    let mut services = Vec::new();
172    for api in opt.apis {
173        let api_list = api.path;
174        match api.prefix {
175            Some(p) => {
176                services.push(quote! {
177                    .service(
178                        web::scope(#p)
179                        #(.service(#api_list))*
180                        .app_data(data.clone())
181                    )
182                });
183            }
184            None => {
185                services.push(quote! {
186                    #(.service(#api_list))*
187                    .app_data(data.clone())
188                });
189            }
190        }
191    }
192
193    let cors = if opt.cors {
194        quote! { .wrap(spa_server::re_export::Cors::permissive()) }
195    } else {
196        TokenStream::new()
197    };
198
199    let identity = if let Some(id) = opt.identity {
200        let name = id.name;
201        let age = id.age;
202        quote! {
203            .wrap(spa_server::re_export::IdentityService::new(
204                spa_server::re_export::CookieIdentityPolicy::new(&[0; 32])
205                    .name(#name)
206                    .max_age_time(spa_server::Duration::minutes(#age))
207                    .http_only(true)
208                    .secure(false)
209            ))
210        }
211    } else {
212        TokenStream::new()
213    };
214
215    let embed_tokens = impl_embed(name, &opt.static_files, None);
216
217    Ok(quote! {
218        use spa_server::re_export::{
219            App, HttpServer, rt::System, web, Files
220        };
221        use spa_server::{Embed, Filenames};
222        use std::borrow::Cow;
223
224        impl #name {
225            pub async fn run(self, port: u16) -> Result<(), Box<dyn std::error::Error>> {
226                let root_path = spa_server::release_asset::<#name>()?;
227                let data = web::Data::new(self);
228
229                HttpServer::new(move || {
230                    App::new()
231                        #identity
232                        #cors
233                        #(#services)*
234                        .data(root_path.clone())
235                        .service(spa_server::index)
236                        .service(Files::new("/", root_path.clone()).index_file("index.html"))
237                })
238                .bind(format!("0.0.0.0:{}", port))?
239                .run()
240                .await?;
241
242                Ok(())
243            }
244        }
245
246        #embed_tokens
247    })
248}
249
250#[derive(Default)]
251struct Options {
252    apis: Vec<Api>,
253    static_files: String,
254    cors: bool,
255    identity: Option<Identity>,
256}
257
258#[derive(Default)]
259struct Api {
260    path: Vec<Path>,
261    prefix: Option<String>,
262}
263
/// Settings for the cookie identity middleware, taken from
/// `identity(name = "...", age = <minutes>)`.
///
/// Both fields are read when generating the middleware, so no
/// `dead_code` allowance is needed.
#[derive(Default)]
struct Identity {
    /// Cookie name passed to `CookieIdentityPolicy::name`.
    name: String,
    /// Cookie max age, in minutes (fed to `Duration::minutes`).
    age: i64,
}
270
271#[proc_macro_attribute]
272pub fn main(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
273    let mut input = syn::parse_macro_input!(item as syn::ItemFn);
274    let attrs = &input.attrs;
275    let vis = &input.vis;
276    let sig = &mut input.sig;
277    let body = &input.block;
278
279    if sig.asyncness.is_none() {
280        return Error::new_spanned(sig.fn_token, "only async fn is supported")
281            .to_compile_error()
282            .into();
283    }
284
285    sig.asyncness = None;
286
287    (quote! {
288        #(#attrs)*
289        #vis #sig {
290            spa_server::re_export::rt::System::new()
291                .block_on(async move { #body })
292        }
293    })
294    .into()
295}
296
297mod route;
298
299macro_rules! method_macro {
300    (
301        $($variant:ident, $method:ident,)+
302    ) => {
303        $(
304            #[proc_macro_attribute]
305            pub fn $method(args: proc_macro::TokenStream, input: proc_macro::TokenStream) -> proc_macro::TokenStream {
306                route::with_method(Some(route::MethodType::$variant), args, input)
307            }
308        )+
309    };
310}
311
312method_macro! {
313    Get,       get,
314    Post,      post,
315    Put,       put,
316    Delete,    delete,
317    Head,      head,
318    Connect,   connect,
319    Options,   options,
320    Trace,     trace,
321    Patch,     patch,
322}
323
324#[proc_macro_attribute]
325pub fn error_to_json(
326    _: proc_macro::TokenStream,
327    item: proc_macro::TokenStream,
328) -> proc_macro::TokenStream {
329    let mut input = syn::parse_macro_input!(item as syn::ItemFn);
330    let attrs = &input.attrs;
331    let vis = &input.vis;
332    let sig = &mut input.sig;
333    let body = &input.block;
334
335    let mut sig_impl = sig.clone();
336    sig_impl.ident = syn::Ident::new(&format!("_{}_impl", sig.ident), sig.span());
337    let sig_impl_ident = &sig_impl.ident;
338
339    let mut args = Vec::new();
340    for i in &sig.inputs {
341        if let FnArg::Typed(p) = i {
342            if let Pat::Ident(id) = &*p.pat {
343                args.push(id.ident.clone());
344            }
345        }
346    }
347
348    (quote! {
349        #[allow(unused_mut)]
350        #(#attrs)*
351        #vis #sig {
352            Ok(match #sig_impl_ident(#(#args),*).await {
353                Ok(a) => a,
354                Err(e) => spa_server::re_export::HttpResponse::Ok().json(spa_server::quick_err(format!("{:?}", e)))
355            })
356        }
357
358        #vis #sig_impl {
359            #body
360        }
361    })
362    .into()
363}