webwire_cli/codegen/rust.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3
4use crate::schema::{self, TypeRef};
5
6pub fn gen(doc: &schema::Document) -> String {
7    let stream = generate(doc);
8    let code = format!("{}", stream);
9    code
10}
11
12fn optional(stream: TokenStream) -> TokenStream {
13    quote! {
14        Option<#stream>
15    }
16}
17
18pub fn generate(doc: &schema::Document) -> TokenStream {
19    let namespace = gen_namespace(&doc.ns);
20    quote! {
21        #[allow(dead_code)]
22        #namespace
23    }
24}
25
26fn gen_namespace(ns: &schema::Namespace) -> TokenStream {
27    let mut stream = TokenStream::new();
28    for type_ in ns.types.values() {
29        let type_stream = gen_type(type_, &ns.path);
30        stream.extend(type_stream);
31    }
32    for service in ns.services.values() {
33        let service_stream = gen_service(service, &ns.path);
34        stream.extend(service_stream);
35        let provider_stream = gen_provider(service, &ns.path);
36        stream.extend(provider_stream);
37        let consumer_stream = gen_consumer(service, &ns.path);
38        stream.extend(consumer_stream);
39    }
40    for child_ns in ns.namespaces.values() {
41        let child_ns_name = quote::format_ident!("{}", child_ns.name());
42        let child_ns_stream = gen_namespace(child_ns);
43        stream.extend(quote! {
44            pub mod #child_ns_name {
45                #child_ns_stream
46            }
47        });
48    }
49    stream
50}
51
52fn gen_type(type_: &schema::UserDefinedType, ns: &[String]) -> TokenStream {
53    match type_ {
54        schema::UserDefinedType::Enum(enum_) => gen_enum(&enum_.borrow(), ns),
55        schema::UserDefinedType::Struct(struct_) => gen_struct(&struct_.borrow(), ns),
56        schema::UserDefinedType::Fieldset(fieldset) => gen_fieldset(&fieldset.borrow(), ns),
57    }
58}
59
60fn gen_enum(enum_: &schema::Enum, ns: &[String]) -> TokenStream {
61    let name = quote::format_ident!("{}", &enum_.fqtn.name);
62    let variants = gen_enum_variants(enum_, ns);
63    let mut stream = TokenStream::new();
64    stream.extend(quote! {
65        #[derive(Clone, Debug, Eq, PartialEq, ::serde::Serialize, ::serde::Deserialize)]
66        pub enum #name {
67            #variants
68        }
69    });
70    if let Some(extends) = &enum_.extends {
71        let extends_typeref = gen_typeref_ref(&extends, ns);
72        let mut matches = TokenStream::new();
73        let extends_enum = enum_.extends_enum().unwrap();
74        for variant in extends_enum.borrow().all_variants.iter() {
75            let variant_name = quote::format_ident!("{}", variant.name);
76            matches.extend(quote! {
77                #extends_typeref::#variant_name => #name::#variant_name,
78            });
79        }
80        stream.extend(quote! {
81            impl From<#extends_typeref> for #name {
82                fn from(other: #extends_typeref) -> Self {
83                    match other {
84                        #matches
85                    }
86                }
87            }
88        });
89    }
90    stream
91}
92
93fn gen_enum_variants(enum_: &schema::Enum, ns: &[String]) -> TokenStream {
94    let mut stream = TokenStream::new();
95    for variant in enum_.variants.iter() {
96        stream.extend(gen_enum_variant(variant, ns));
97    }
98    if let Some(extends) = enum_.extends_enum() {
99        stream.extend(gen_enum_variants(&extends.borrow(), ns));
100    }
101    stream
102}
103
104fn gen_enum_variant(variant: &schema::EnumVariant, ns: &[String]) -> TokenStream {
105    let name = quote::format_ident!("{}", variant.name);
106    if let Some(value_type) = &variant.value_type {
107        let value_type = gen_typeref(value_type, ns);
108        quote! {
109            #name(#value_type),
110        }
111    } else {
112        quote! {
113            #name,
114        }
115    }
116}
117
118fn gen_struct(struct_: &schema::Struct, ns: &[String]) -> TokenStream {
119    let name = quote::format_ident!("{}", &struct_.fqtn.name);
120    let fields = gen_struct_fields(struct_, ns);
121    quote! {
122        #[derive(Clone, Debug, Eq, PartialEq, ::serde::Serialize, ::serde::Deserialize, ::validator::Validate)]
123        pub struct #name {
124            #fields
125        }
126    }
127}
128
129fn gen_struct_fields(struct_: &schema::Struct, ns: &[String]) -> TokenStream {
130    let mut stream = TokenStream::new();
131    for field in struct_.fields.iter() {
132        stream.extend(gen_struct_field(field, ns))
133    }
134    stream
135}
136
137fn gen_struct_field(field: &schema::Field, ns: &[String]) -> TokenStream {
138    let name = quote::format_ident!("{}", field.name);
139    let mut type_ = gen_typeref(&field.type_, ns);
140    if field.optional {
141        type_ = optional(type_);
142    }
143    let validation_macros = gen_validation_macros(field);
144    quote! {
145        #validation_macros
146        pub #name: #type_,
147    }
148}
149
150fn gen_validation_macros(field: &schema::Field) -> TokenStream {
151    let mut rules = TokenStream::new();
152    match field.length {
153        (Some(min), Some(max)) => rules.extend(quote! { length(min=#min, max=#max) }),
154        (Some(min), None) => rules.extend(quote! { length(min=#min) }),
155        (None, Some(max)) => rules.extend(quote! { length(max=#max) }),
156        (None, None) => {}
157    }
158    if rules.is_empty() {
159        quote! {}
160    } else {
161        quote! {
162            #[validate(#rules)]
163        }
164    }
165}
166
167fn gen_fieldset(fieldset: &schema::Fieldset, ns: &[String]) -> TokenStream {
168    let name = quote::format_ident!("{}", &fieldset.fqtn.name);
169    let fields = gen_fieldset_fields(fieldset, ns);
170    quote! {
171        #[derive(Clone, Debug, Eq, PartialEq, ::serde::Serialize, ::serde::Deserialize, ::validator::Validate)]
172        pub struct #name {
173            #fields
174        }
175    }
176}
177
178fn gen_fieldset_fields(struct_: &schema::Fieldset, ns: &[String]) -> TokenStream {
179    let mut stream = TokenStream::new();
180    for field in struct_.fields.iter() {
181        stream.extend(gen_fieldset_field(field, ns))
182    }
183    stream
184}
185
186fn gen_fieldset_field(field: &schema::FieldsetField, ns: &[String]) -> TokenStream {
187    let name = quote::format_ident!("{}", field.name);
188    let mut type_ = gen_typeref(&field.field.as_ref().unwrap().type_, ns);
189    if field.optional {
190        type_ = optional(type_);
191    }
192    quote! {
193        pub #name: #type_,
194    }
195}
196
197fn gen_service(service: &schema::Service, ns: &[String]) -> TokenStream {
198    let service_name = quote::format_ident!("{}", &service.name);
199    let methods = gen_service_methods(&service, ns);
200    quote! {
201        #[::async_trait::async_trait]
202        pub trait #service_name {
203            type Error: Into<::webwire::ProviderError>;
204            #methods
205        }
206    }
207}
208
209fn gen_service_methods(service: &schema::Service, ns: &[String]) -> TokenStream {
210    let mut stream = TokenStream::new();
211    for method in service.methods.iter() {
212        let signature = gen_service_method_signature(method, ns);
213        stream.extend(quote! {
214            #signature;
215        })
216    }
217    stream
218}
219
220fn gen_service_method_signature(method: &schema::Method, ns: &[String]) -> TokenStream {
221    let name = quote::format_ident!("{}", method.name);
222    let input_arg = match &method.input {
223        Some(type_) => {
224            let input_type = gen_typeref(type_, ns);
225            quote! { input: & #input_type }
226        }
227        None => quote! {},
228    };
229    let output = match &method.output {
230        Some(type_) => gen_typeref(type_, ns),
231        None => quote! { () },
232    };
233    quote! {
234        async fn #name(&self, #input_arg) -> Result<#output, Self::Error>
235    }
236}
237
238fn gen_provider(service: &schema::Service, ns: &[String]) -> TokenStream {
239    let service_name = quote::format_ident!("{}", service.name);
240    let service_name_str = if ns.is_empty() {
241        service.name.to_owned()
242    } else {
243        format!("{}.{}", ns.join("."), &service.name)
244    };
245    let provider_name = quote::format_ident!("{}Provider", service.name);
246    let matches = gen_provider_matches(&service, ns);
247    quote! {
248        pub struct #provider_name<F>(pub F);
249        // NamedProvider impl
250        impl<F: Sync + Send, S: Sync + Send, T: Sync + Send> ::webwire::NamedProvider<S> for #provider_name<F>
251        where
252            F: Fn(::std::sync::Arc<S>) -> T,
253            T: #service_name + 'static,
254        {
255            const NAME: &'static str = #service_name_str;
256        }
257        // Provider impl
258        impl<F: Sync + Send, S: Sync + Send, T: Sync + Send> ::webwire::Provider<S> for #provider_name<F>
259        where
260            F: Fn(::std::sync::Arc<S>) -> T,
261            T: #service_name + 'static,
262        {
263            fn call(
264                &self,
265                session: &::std::sync::Arc<S>,
266                _service: &str,
267                method: &str,
268                input: ::bytes::Bytes,
269            ) -> ::futures::future::BoxFuture<'static, Result<::bytes::Bytes, ::webwire::ProviderError>> {
270                let service = self.0(session.clone());
271                match method {
272                    #matches
273                    _ => Box::pin(::futures::future::ready(Err(::webwire::ProviderError::MethodNotFound))),
274                }
275            }
276        }
277    }
278}
279
280fn gen_provider_matches(service: &schema::Service, ns: &[String]) -> TokenStream {
281    let mut stream = TokenStream::new();
282    for method in service.methods.iter() {
283        let name = quote::format_ident!("{}", method.name);
284        let name_str = &method.name;
285        let input = match &method.input {
286            Some(type_) => gen_typeref(type_, ns),
287            None => quote! { () },
288        };
289        /*
290        let output = match &method.output {
291            Some(type_) => gen_typeref(type_),
292            None => quote! { () },
293        };
294        */
295        let method_call = match &method.input {
296            None => quote! {
297                let output = service.#name().await.map_err(|e| e.into())?;
298            },
299            Some(type_) => {
300                let validation = if type_.is_scalar() {
301                    quote! {}
302                } else {
303                    quote! {
304                        ::validator::Validate::validate(&input).map_err(::webwire::ProviderError::ValidationError)?;
305                    }
306                };
307                quote! {
308                    let input = serde_json::from_slice::<#input>(&input)
309                            .map_err(::webwire::ProviderError::DeserializerError)?;
310                    #validation
311                    let output = service.#name(&input).await.map_err(|e| e.into())?;
312                }
313            }
314        };
315        stream.extend(quote! {
316            #name_str => Box::pin(async move {
317                #method_call
318                let response = serde_json::to_vec(&output)
319                    .map_err(|e| ::webwire::ProviderError::SerializerError(e))
320                    .map(::bytes::Bytes::from)?;
321                Ok(response)
322            }),
323        });
324    }
325    stream
326}
327
328fn gen_consumer(service: &schema::Service, ns: &[String]) -> TokenStream {
329    let consumer_name = quote::format_ident!("{}Consumer", service.name);
330    let consumer_methods = gen_consumer_methods(&service, ns);
331    quote! {
332        pub struct #consumer_name<'a>(pub &'a (dyn ::webwire::Consumer + ::std::marker::Sync + ::std::marker::Send));
333        impl<'a> #consumer_name<'a> {
334            #consumer_methods
335        }
336    }
337}
338
339fn gen_consumer_methods(service: &schema::Service, ns: &[String]) -> TokenStream {
340    let mut stream = TokenStream::new();
341    let service_name_str = if ns.is_empty() {
342        service.name.to_owned()
343    } else {
344        format!("{}.{}", ns.join("."), &service.name)
345    };
346    for method in service.methods.iter() {
347        let signature = gen_consumer_method_signature(method, ns);
348        let method_name_str = &method.name;
349        let serialization = match method.input {
350            Some(_) => quote! {
351                let data: ::bytes::Bytes = serde_json::to_vec(input)
352                    .map_err(|e| ::webwire::ConsumerError::SerializerError(e))?
353                    .into();
354            },
355            None => quote! {
356                let data = ::bytes::Bytes::new();
357            },
358        };
359        stream.extend(quote! {
360            #signature {
361                #serialization
362                let output = self.0.request(#service_name_str, #method_name_str, data).await?;
363                let response = ::serde_json::from_slice(&output)
364                    .map_err(|e| ::webwire::ConsumerError::DeserializerError(e))?;
365                Ok(response)
366            }
367        })
368    }
369    stream
370}
371
372fn gen_consumer_method_signature(method: &schema::Method, ns: &[String]) -> TokenStream {
373    let name = quote::format_ident!("{}", method.name);
374    let input_arg = match &method.input {
375        Some(type_) => {
376            let input_type = gen_typeref(type_, ns);
377            quote! { input: & #input_type }
378        }
379        None => quote! {},
380    };
381    let output = match &method.output {
382        Some(type_) => gen_typeref(type_, ns),
383        None => quote! { () },
384    };
385    quote! {
386        pub async fn #name(&self, #input_arg) -> Result<#output, ::webwire::ConsumerError>
387    }
388}
389
390fn gen_typeref(type_: &schema::Type, ns: &[String]) -> TokenStream {
391    match type_ {
392        schema::Type::None => quote! { () },
393        schema::Type::Boolean => quote! { bool },
394        schema::Type::Integer => quote! { i64 },
395        schema::Type::Float => quote! { f64 },
396        schema::Type::String => quote! { String },
397        schema::Type::UUID => quote! { ::uuid::Uuid },
398        schema::Type::Date => quote! { ::chrono::Date },
399        schema::Type::Time => quote! { ::chrono::Time },
400        schema::Type::DateTime => quote! { ::chrono::DateTime<::chrono::Utc> },
401        schema::Type::Option(some) => {
402            let some_type = gen_typeref(some, ns);
403            quote! { std::option::Option<#some_type> }
404        }
405        schema::Type::Result(ok, err) => {
406            let ok_type = gen_typeref(ok, ns);
407            let err_type = gen_typeref(err, ns);
408            quote! { std::result::Result<#ok_type, #err_type> }
409        }
410        // complex types
411        schema::Type::Array(array) => {
412            let item_type = gen_typeref(&array.item_type, ns);
413            quote! {
414                std::vec::Vec<#item_type>
415            }
416        }
417        schema::Type::Map(map) => {
418            let key_type = gen_typeref(&map.key_type, ns);
419            let value_type = gen_typeref(&map.value_type, ns);
420            quote! {
421                std::collections::HashMap<#key_type, #value_type>
422            }
423        }
424        // named
425        schema::Type::Ref(typeref) => {
426            gen_typeref_ref(typeref, ns)
427        }
428        schema::Type::Builtin(name) => {
429            // FIXME unwrap... igh!
430            let identifier: TokenStream = ::syn::parse_str(name).unwrap();
431            quote! {
432                #identifier
433            }
434        }
435    }
436}
437
438fn gen_typeref_ref(typeref: &TypeRef, ns: &[String]) -> TokenStream {
439    let mut generics_stream = TokenStream::new();
440    if !typeref.generics().is_empty() {
441        for generic in typeref.generics().iter() {
442            let type_ = gen_typeref(generic, ns);
443            generics_stream.extend(quote! {
444                #type_,
445            })
446        }
447        generics_stream = quote! {
448            < #generics_stream >
449        }
450    }
451    let typeref_fqtn = typeref.fqtn();
452    let common_ns = typeref_fqtn
453        .ns
454        .iter()
455        .zip(ns.iter())
456        .take_while(|(a, b)| a == b)
457        .count();
458    let relative_ns = ns[common_ns..]
459        .iter()
460        .map(|_| quote::format_ident!("super"))
461        .chain(
462            typeref_fqtn.ns[common_ns..]
463                .iter()
464                .map(|x| quote::format_ident!("{}", x)),
465        )
466        .fold(TokenStream::new(), |mut stream, name| {
467            let name = quote::format_ident!("{}", name);
468            stream.extend(quote! { #name :: });
469            stream
470        });
471    // FIXME fqtn
472    match &*typeref_fqtn.name {
473        // FIXME `None` should be made into a buitlin type
474        "None" => quote! { () },
475        name => {
476            let name = quote::format_ident!("{}", name);
477            quote! {
478                #relative_ns #name #generics_stream
479            }
480        }
481    }
482}