tonic_build_codec/
server.rs

1use std::collections::HashMap;
2use super::{Attributes, Method, Service};
3use crate::{generate_doc_comment, generate_doc_comments, naive_snake_case};
4use proc_macro2::{Span, TokenStream, Spacing, Punct, TokenTree};
5use quote::quote;
6use syn::{Ident, Lit, LitStr};
7
8/// Generate service for Server.
9///
10/// This takes some `Service` and will generate a `TokenStream` that contains
11/// a public module containing the server service and handler trait.
12pub fn generate<T: Service>(
13    service: &T,
14    emit_package: bool,
15    proto_path: &str,
16    compile_well_known_types: bool,
17    attributes: &Attributes,
18    codec: &HashMap<String, String>,
19) -> TokenStream {
20    let methods = generate_methods(service, proto_path, compile_well_known_types, codec);
21    let endpoints = generate_endpoints(service, emit_package);
22
23    let server_service = quote::format_ident!("{}Server", service.name());
24    let server_trait = quote::format_ident!("{}", service.name());
25    let server_mod = quote::format_ident!("{}_server", naive_snake_case(&service.name()));
26    let generated_trait = generate_trait(
27        service,
28        proto_path,
29        compile_well_known_types,
30        server_trait.clone(),
31    );
32    let service_doc = generate_doc_comments(service.comment());
33    let package = if emit_package { service.package() } else { "go" };
34    // Transport based implementations
35    let path = format!(
36        "{}{}{}",
37        package,
38        ".",
39        service.identifier()
40    );
41    let transport = generate_transport(&server_service, &server_trait, &path);
42    let mod_attributes = attributes.for_mod(package);
43    let struct_attributes = attributes.for_struct(&path);
44
45    let compression_enabled = cfg!(feature = "compression");
46
47    let compression_config_ty = if compression_enabled {
48        quote! { EnabledCompressionEncodings }
49    } else {
50        quote! { () }
51    };
52
53    let configure_compression_methods = if compression_enabled {
54        quote! {
55            /// Enable decompressing requests with `gzip`.
56            pub fn accept_gzip(mut self) -> Self {
57                self.accept_compression_encodings.enable_gzip();
58                self
59            }
60
61            /// Compress responses with `gzip`, if the client supports it.
62            pub fn send_gzip(mut self) -> Self {
63                self.send_compression_encodings.enable_gzip();
64                self
65            }
66        }
67    } else {
68        quote! {}
69    };
70
71    quote! {
72        /// Generated server implementations.
73        #(#mod_attributes)*
74        pub mod #server_mod {
75            #![allow(
76                unused_variables,
77                dead_code,
78                missing_docs,
79                // will trigger if compression is disabled
80                clippy::let_unit_value,
81            )]
82            use tonic::codegen::*;
83
84            #generated_trait
85
86            #service_doc
87            #(#struct_attributes)*
88            #[derive(Debug)]
89            pub struct #server_service<T: #server_trait> {
90                pub inner: _Inner<T>,
91                accept_compression_encodings: #compression_config_ty,
92                send_compression_encodings: #compression_config_ty,
93            }
94
95            pub struct _Inner<T>(pub Arc<T>);
96
97            impl<T: #server_trait> #server_service<T> {
98                pub fn new(inner: T) -> Self {
99                    let inner = Arc::new(inner);
100                    let inner = _Inner(inner);
101                    Self {
102                        inner,
103                        accept_compression_encodings: Default::default(),
104                        send_compression_encodings: Default::default(),
105                    }
106                }
107
108                pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
109                where
110                    F: tonic::service::Interceptor,
111                {
112                    InterceptedService::new(Self::new(inner), interceptor)
113                }
114
115                fn match_uri(path: &str) -> (&str, &str) {
116                    let ss = path.split(".").collect::<Vec<&str>>();
117                    if ss.len() > 1 {
118                        let sx: Vec<&str> = ss.last().unwrap().split("/").collect();
119                        if sx.len() == 2 {
120                            (sx[0], sx[1])
121                        } else {
122                            ("", "")
123                        }
124                    } else {
125                        let sx: Vec<&str> = ss[0].trim_start_matches("/").split("/").collect();
126                        if sx.len() == 2 {
127                            (sx[0], sx[1])
128                        } else {
129                            ("", "")
130                        }
131                    }
132               }
133
134                #configure_compression_methods
135            }
136
137            impl<T, B> tonic::codegen::Service<http::Request<B>> for #server_service<T>
138                where
139                    T: #server_trait,
140                    B: Body + Send + 'static,
141                    B::Error: Into<StdError> + Send + 'static,
142            {
143                type Response = http::Response<tonic::body::BoxBody>;
144                type Error = Never;
145                type Future = BoxFuture<Self::Response, Self::Error>;
146
147                fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148                    Poll::Ready(Ok(()))
149                }
150
151                fn call(&mut self, req: http::Request<B>) -> Self::Future {
152                    let inner = self.inner.clone();
153
154                    let (_, method) = Self::match_uri(req.uri().path());
155                    match method {
156                        #methods
157
158                        _ => Box::pin(async move {
159                            Ok(http::Response::builder()
160                               .status(200)
161                               .header("grpc-status", "12")
162                               .header("content-type", "application/grpc")
163                               .body(empty_body())
164                               .unwrap())
165                        }),
166                    }
167                }
168            }
169
170            impl<T: #server_trait> Clone for #server_service<T> {
171                fn clone(&self) -> Self {
172                    let inner = self.inner.clone();
173                    Self {
174                        inner,
175                        accept_compression_encodings: self.accept_compression_encodings,
176                        send_compression_encodings: self.send_compression_encodings,
177                    }
178                }
179            }
180
181            impl<T: #server_trait> Clone for _Inner<T> {
182                fn clone(&self) -> Self {
183                    Self(self.0.clone())
184                }
185            }
186
187            impl<T: std::fmt::Debug> std::fmt::Debug for _Inner<T> {
188                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189                   write!(f, "{:?}", self.0)
190                }
191            }
192
193            // export endpoint
194            impl <T :#server_trait> #server_service<T> {
195                pub fn endpoints(&self) -> Vec<String> {
196                    vec!(#endpoints)
197                }
198            }
199
200
201            #transport
202        }
203    }
204}
205
206fn generate_trait<T: Service>(
207    service: &T,
208    proto_path: &str,
209    compile_well_known_types: bool,
210    server_trait: Ident,
211) -> TokenStream {
212    let methods = generate_trait_methods(service, proto_path, compile_well_known_types);
213    let trait_doc = generate_doc_comment(&format!(
214        "Generated trait containing gRPC methods that should be implemented for use with {}Server.",
215        service.name()
216    ));
217
218    quote! {
219        #trait_doc
220        #[async_trait]
221        pub trait #server_trait : Send + Sync + 'static {
222            #methods
223        }
224    }
225}
226
227fn generate_trait_methods<T: Service>(
228    service: &T,
229    proto_path: &str,
230    compile_well_known_types: bool,
231) -> TokenStream {
232    let mut stream = TokenStream::new();
233
234    for method in service.methods() {
235        let name = quote::format_ident!("{}", method.name());
236
237        let (req_message, res_message) =
238            method.request_response_name(proto_path, compile_well_known_types);
239
240        let method_doc = generate_doc_comments(method.comment());
241
242        let method = match (method.client_streaming(), method.server_streaming()) {
243            (false, false) => {
244                quote! {
245                    #method_doc
246                    async fn #name(&self, request: tonic::Request<#req_message>)
247                        -> Result<tonic::Response<#res_message>, tonic::Status>;
248                }
249            }
250            (true, false) => {
251                quote! {
252                    #method_doc
253                    async fn #name(&self, request: tonic::Request<tonic::Streaming<#req_message>>)
254                        -> Result<tonic::Response<#res_message>, tonic::Status>;
255                }
256            }
257            (false, true) => {
258                let stream = quote::format_ident!("{}Stream", method.identifier());
259                let stream_doc = generate_doc_comment(&format!(
260                    "Server streaming response type for the {} method.",
261                    method.identifier()
262                ));
263
264                quote! {
265                    #stream_doc
266                    type #stream: futures_core::Stream<Item = Result<#res_message, tonic::Status>> + Send + 'static;
267
268                    #method_doc
269                    async fn #name(&self, request: tonic::Request<#req_message>)
270                        -> Result<tonic::Response<Self::#stream>, tonic::Status>;
271                }
272            }
273            (true, true) => {
274                let stream = quote::format_ident!("{}Stream", method.identifier());
275                let stream_doc = generate_doc_comment(&format!(
276                    "Server streaming response type for the {} method.",
277                    method.identifier()
278                ));
279
280                quote! {
281                    #stream_doc
282                    type #stream: futures_core::Stream<Item = Result<#res_message, tonic::Status>> + Send + 'static;
283
284                    #method_doc
285                    async fn #name(&self, request: tonic::Request<tonic::Streaming<#req_message>>)
286                        -> Result<tonic::Response<Self::#stream>, tonic::Status>;
287                }
288            }
289        };
290
291        stream.extend(method);
292    }
293
294    stream
295}
296
/// Generate the `tonic::transport::NamedService` impl for the server wrapper.
///
/// `service_name` is emitted verbatim as the string value of the `NAME` constant.
#[cfg(feature = "transport")]
fn generate_transport(
    server_service: &syn::Ident,
    server_trait: &syn::Ident,
    service_name: &str,
) -> TokenStream {
    let name_lit = syn::LitStr::new(service_name, proc_macro2::Span::call_site());

    quote! {
        impl<T: #server_trait> tonic::transport::NamedService for #server_service<T> {
            const NAME: &'static str = #name_lit;
        }
    }
}
311
312#[cfg(not(feature = "transport"))]
313fn generate_transport(
314    _server_service: &syn::Ident,
315    _server_trait: &syn::Ident,
316    _service_name: &str,
317) -> TokenStream {
318    TokenStream::new()
319}
320
321fn generate_methods<T: Service>(
322    service: &T,
323    proto_path: &str,
324    compile_well_known_types: bool,
325    codec: &HashMap<String, String>,
326) -> TokenStream {
327    let mut stream = TokenStream::new();
328    for method in service.methods() {
329        let path = format!(
330            "{}",
331            method.identifier()
332        );
333        let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
334        let ident = quote::format_ident!("{}", method.name());
335        let server_trait = quote::format_ident!("{}", service.name());
336
337        let method_stream = match (method.client_streaming(), method.server_streaming()) {
338            (false, false) => generate_unary(
339                method,
340                proto_path,
341                compile_well_known_types,
342                ident,
343                server_trait,
344                codec,
345            ),
346
347            (false, true) => generate_server_streaming(
348                method,
349                proto_path,
350                compile_well_known_types,
351                ident.clone(),
352                server_trait,
353            ),
354            (true, false) => generate_client_streaming(
355                method,
356                proto_path,
357                compile_well_known_types,
358                ident.clone(),
359                server_trait,
360            ),
361
362            (true, true) => generate_streaming(
363                method,
364                proto_path,
365                compile_well_known_types,
366                ident.clone(),
367                server_trait,
368            ),
369        };
370
371        let method = quote! {
372            #method_path => {
373                #method_stream
374            }
375        };
376        stream.extend(method);
377    }
378
379    stream
380}
381
382fn generate_endpoints<T: Service>(
383    service: &T,
384    emit_package: bool,
385) -> TokenStream {
386    let mut stream = TokenStream::new();
387    let package = if emit_package { service.package() } else { "" };
388
389    for method in service.methods() {
390        let path = format!(
391            "{}{}{}.{}",
392            package,
393            if package.is_empty() {
394                ""
395            } else {
396                "."
397            },
398            service.identifier(),
399            method.identifier()
400        );
401
402        let method = quote! {
403             #path.to_string(),
404        };
405        stream.extend(method);
406    }
407
408
409    stream
410}
411
412fn generate_codec_call(
413    codec: &HashMap<String, String>,
414) -> TokenStream {
415    let mut stream = TokenStream::new();
416    for (k, v) in codec {
417        let codec_name = syn::parse_str::<syn::Path>(v).unwrap();
418        let p = TokenTree::Punct(Punct::new('|', Spacing::Alone).into());
419
420        for (i, s) in k.split("|").collect::<Vec<&str>>().iter().enumerate() {
421            if i != 0 {
422                stream.extend(quote! {#p});
423            }
424            let s = s.trim();
425            stream.extend(quote! {#s});
426        }
427
428        stream.extend(
429            quote! {
430                => {
431                    let codec = #codec_name::default();
432                    let mut grpc = tonic::server::Grpc::new(codec)
433                        .apply_compression_config(accept_compression_encodings, send_compression_encodings);
434
435                    let res = grpc.unary(method, req).await;
436                    Ok(res)
437                }
438            }
439        );
440    }
441    stream
442}
443
444fn generate_unary<T: Method>(
445    method: &T,
446    proto_path: &str,
447    compile_well_known_types: bool,
448    method_ident: Ident,
449    server_trait: Ident,
450    codec: &HashMap<String, String>,
451) -> TokenStream {
452    // let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
453
454    let service_ident = quote::format_ident!("{}Svc", method.identifier());
455
456    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
457
458    let codec = generate_codec_call(codec);
459
460    quote! {
461        #[allow(non_camel_case_types)]
462        struct #service_ident<T: #server_trait >(pub Arc<T>);
463
464        impl<T: #server_trait> tonic::server::UnaryService<#request> for #service_ident<T> {
465            type Response = #response;
466            type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
467
468            fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
469                let inner = self.0.clone();
470                let fut = async move {
471                    (*inner).#method_ident(request).await
472                };
473                Box::pin(fut)
474            }
475        }
476
477        let accept_compression_encodings = self.accept_compression_encodings;
478        let send_compression_encodings = self.send_compression_encodings;
479        let inner = self.inner.clone();
480        let fut = async move {
481            let inner = inner.0;
482            let method = #service_ident(inner);
483
484            let ct = req.headers().get("content-type").unwrap().to_str().unwrap();
485            match ct{
486                #codec
487
488                _ => Ok(http::Response::builder()
489                        .status(200)
490                        .header("grpc-status", "13")
491                        .header("content-type", "application/grpc")
492                        .body(empty_body())
493                        .unwrap())
494            }
495            // if ct == "application/grpc" {
496            //     let codec = tonic::codec::ProstCodec::default();
497            //
498            //     let mut grpc = tonic::server::Grpc::new(codec)
499            //         .apply_compression_config(accept_compression_encodings, send_compression_encodings);
500            //
501            //     let res = grpc.unary(method, req).await;
502            //     Ok(res)
503            // } else {
504            //     let codec = crate::lib::codec::JsonCodec::default();
505            //
506            //     let mut grpc = tonic::server::Grpc::new(codec)
507            //         .apply_compression_config(accept_compression_encodings, send_compression_encodings);
508            //
509            //     let res = grpc.unary(method, req).await;
510            //     Ok(res)
511            // }
512        };
513
514        Box::pin(fut)
515    }
516}
517
518fn generate_server_streaming<T: Method>(
519    method: &T,
520    proto_path: &str,
521    compile_well_known_types: bool,
522    method_ident: Ident,
523    server_trait: Ident,
524) -> TokenStream {
525    let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
526
527    let service_ident = quote::format_ident!("{}Svc", method.identifier());
528
529    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
530
531    let response_stream = quote::format_ident!("{}Stream", method.identifier());
532
533    quote! {
534        #[allow(non_camel_case_types)]
535        struct #service_ident<T: #server_trait >(pub Arc<T>);
536
537        impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
538            type Response = #response;
539            type ResponseStream = T::#response_stream;
540            type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
541
542            fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
543                let inner = self.0.clone();
544                let fut = async move {
545                    (*inner).#method_ident(request).await
546                };
547                Box::pin(fut)
548            }
549        }
550
551        let accept_compression_encodings = self.accept_compression_encodings;
552        let send_compression_encodings = self.send_compression_encodings;
553        let inner = self.inner.clone();
554        let fut = async move {
555            let inner = inner.0;
556            let method = #service_ident(inner);
557            let codec = #codec_name::default();
558
559            let mut grpc = tonic::server::Grpc::new(codec)
560                .apply_compression_config(accept_compression_encodings, send_compression_encodings);
561
562            let res = grpc.server_streaming(method, req).await;
563            Ok(res)
564        };
565
566        Box::pin(fut)
567    }
568}
569
570fn generate_client_streaming<T: Method>(
571    method: &T,
572    proto_path: &str,
573    compile_well_known_types: bool,
574    method_ident: Ident,
575    server_trait: Ident,
576) -> TokenStream {
577    let service_ident = quote::format_ident!("{}Svc", method.identifier());
578
579    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
580    let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
581
582    quote! {
583        #[allow(non_camel_case_types)]
584        struct #service_ident<T: #server_trait >(pub Arc<T>);
585
586        impl<T: #server_trait> tonic::server::ClientStreamingService<#request> for #service_ident<T>
587        {
588            type Response = #response;
589            type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
590
591            fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
592                let inner = self.0.clone();
593                let fut = async move {
594                    (*inner).#method_ident(request).await
595
596                };
597                Box::pin(fut)
598            }
599        }
600
601        let accept_compression_encodings = self.accept_compression_encodings;
602        let send_compression_encodings = self.send_compression_encodings;
603        let inner = self.inner.clone();
604        let fut = async move {
605            let inner = inner.0;
606            let method = #service_ident(inner);
607            let codec = #codec_name::default();
608
609            let mut grpc = tonic::server::Grpc::new(codec)
610                .apply_compression_config(accept_compression_encodings, send_compression_encodings);
611
612            let res = grpc.client_streaming(method, req).await;
613            Ok(res)
614        };
615
616        Box::pin(fut)
617    }
618}
619
620fn generate_streaming<T: Method>(
621    method: &T,
622    proto_path: &str,
623    compile_well_known_types: bool,
624    method_ident: Ident,
625    server_trait: Ident,
626) -> TokenStream {
627    let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
628
629    let service_ident = quote::format_ident!("{}Svc", method.identifier());
630
631    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
632
633    let response_stream = quote::format_ident!("{}Stream", method.identifier());
634
635    quote! {
636        #[allow(non_camel_case_types)]
637        struct #service_ident<T: #server_trait>(pub Arc<T>);
638
639        impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
640        {
641            type Response = #response;
642            type ResponseStream = T::#response_stream;
643            type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
644
645            fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
646                let inner = self.0.clone();
647                let fut = async move {
648                    (*inner).#method_ident(request).await
649                };
650                Box::pin(fut)
651            }
652        }
653
654        let accept_compression_encodings = self.accept_compression_encodings;
655        let send_compression_encodings = self.send_compression_encodings;
656        let inner = self.inner.clone();
657        let fut = async move {
658            let inner = inner.0;
659            let method = #service_ident(inner);
660            let codec = #codec_name::default();
661
662            let mut grpc = tonic::server::Grpc::new(codec)
663                .apply_compression_config(accept_compression_encodings, send_compression_encodings);
664
665            let res = grpc.streaming(method, req).await;
666            Ok(res)
667        };
668
669        Box::pin(fut)
670    }
671}