1use std::collections::HashMap;
2use super::{Attributes, Method, Service};
3use crate::{generate_doc_comment, generate_doc_comments, naive_snake_case};
4use proc_macro2::{Span, TokenStream, Spacing, Punct, TokenTree};
5use quote::quote;
6use syn::{Ident, Lit, LitStr};
7
8pub fn generate<T: Service>(
13 service: &T,
14 emit_package: bool,
15 proto_path: &str,
16 compile_well_known_types: bool,
17 attributes: &Attributes,
18 codec: &HashMap<String, String>,
19) -> TokenStream {
20 let methods = generate_methods(service, proto_path, compile_well_known_types, codec);
21 let endpoints = generate_endpoints(service, emit_package);
22
23 let server_service = quote::format_ident!("{}Server", service.name());
24 let server_trait = quote::format_ident!("{}", service.name());
25 let server_mod = quote::format_ident!("{}_server", naive_snake_case(&service.name()));
26 let generated_trait = generate_trait(
27 service,
28 proto_path,
29 compile_well_known_types,
30 server_trait.clone(),
31 );
32 let service_doc = generate_doc_comments(service.comment());
33 let package = if emit_package { service.package() } else { "go" };
34 let path = format!(
36 "{}{}{}",
37 package,
38 ".",
39 service.identifier()
40 );
41 let transport = generate_transport(&server_service, &server_trait, &path);
42 let mod_attributes = attributes.for_mod(package);
43 let struct_attributes = attributes.for_struct(&path);
44
45 let compression_enabled = cfg!(feature = "compression");
46
47 let compression_config_ty = if compression_enabled {
48 quote! { EnabledCompressionEncodings }
49 } else {
50 quote! { () }
51 };
52
53 let configure_compression_methods = if compression_enabled {
54 quote! {
55 pub fn accept_gzip(mut self) -> Self {
57 self.accept_compression_encodings.enable_gzip();
58 self
59 }
60
61 pub fn send_gzip(mut self) -> Self {
63 self.send_compression_encodings.enable_gzip();
64 self
65 }
66 }
67 } else {
68 quote! {}
69 };
70
71 quote! {
72 #(#mod_attributes)*
74 pub mod #server_mod {
75 #![allow(
76 unused_variables,
77 dead_code,
78 missing_docs,
79 clippy::let_unit_value,
81 )]
82 use tonic::codegen::*;
83
84 #generated_trait
85
86 #service_doc
87 #(#struct_attributes)*
88 #[derive(Debug)]
89 pub struct #server_service<T: #server_trait> {
90 pub inner: _Inner<T>,
91 accept_compression_encodings: #compression_config_ty,
92 send_compression_encodings: #compression_config_ty,
93 }
94
95 pub struct _Inner<T>(pub Arc<T>);
96
97 impl<T: #server_trait> #server_service<T> {
98 pub fn new(inner: T) -> Self {
99 let inner = Arc::new(inner);
100 let inner = _Inner(inner);
101 Self {
102 inner,
103 accept_compression_encodings: Default::default(),
104 send_compression_encodings: Default::default(),
105 }
106 }
107
108 pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
109 where
110 F: tonic::service::Interceptor,
111 {
112 InterceptedService::new(Self::new(inner), interceptor)
113 }
114
115 fn match_uri(path: &str) -> (&str, &str) {
116 let ss = path.split(".").collect::<Vec<&str>>();
117 if ss.len() > 1 {
118 let sx: Vec<&str> = ss.last().unwrap().split("/").collect();
119 if sx.len() == 2 {
120 (sx[0], sx[1])
121 } else {
122 ("", "")
123 }
124 } else {
125 let sx: Vec<&str> = ss[0].trim_start_matches("/").split("/").collect();
126 if sx.len() == 2 {
127 (sx[0], sx[1])
128 } else {
129 ("", "")
130 }
131 }
132 }
133
134 #configure_compression_methods
135 }
136
137 impl<T, B> tonic::codegen::Service<http::Request<B>> for #server_service<T>
138 where
139 T: #server_trait,
140 B: Body + Send + 'static,
141 B::Error: Into<StdError> + Send + 'static,
142 {
143 type Response = http::Response<tonic::body::BoxBody>;
144 type Error = Never;
145 type Future = BoxFuture<Self::Response, Self::Error>;
146
147 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148 Poll::Ready(Ok(()))
149 }
150
151 fn call(&mut self, req: http::Request<B>) -> Self::Future {
152 let inner = self.inner.clone();
153
154 let (_, method) = Self::match_uri(req.uri().path());
155 match method {
156 #methods
157
158 _ => Box::pin(async move {
159 Ok(http::Response::builder()
160 .status(200)
161 .header("grpc-status", "12")
162 .header("content-type", "application/grpc")
163 .body(empty_body())
164 .unwrap())
165 }),
166 }
167 }
168 }
169
170 impl<T: #server_trait> Clone for #server_service<T> {
171 fn clone(&self) -> Self {
172 let inner = self.inner.clone();
173 Self {
174 inner,
175 accept_compression_encodings: self.accept_compression_encodings,
176 send_compression_encodings: self.send_compression_encodings,
177 }
178 }
179 }
180
181 impl<T: #server_trait> Clone for _Inner<T> {
182 fn clone(&self) -> Self {
183 Self(self.0.clone())
184 }
185 }
186
187 impl<T: std::fmt::Debug> std::fmt::Debug for _Inner<T> {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 write!(f, "{:?}", self.0)
190 }
191 }
192
193 impl <T :#server_trait> #server_service<T> {
195 pub fn endpoints(&self) -> Vec<String> {
196 vec!(#endpoints)
197 }
198 }
199
200
201 #transport
202 }
203 }
204}
205
206fn generate_trait<T: Service>(
207 service: &T,
208 proto_path: &str,
209 compile_well_known_types: bool,
210 server_trait: Ident,
211) -> TokenStream {
212 let methods = generate_trait_methods(service, proto_path, compile_well_known_types);
213 let trait_doc = generate_doc_comment(&format!(
214 "Generated trait containing gRPC methods that should be implemented for use with {}Server.",
215 service.name()
216 ));
217
218 quote! {
219 #trait_doc
220 #[async_trait]
221 pub trait #server_trait : Send + Sync + 'static {
222 #methods
223 }
224 }
225}
226
227fn generate_trait_methods<T: Service>(
228 service: &T,
229 proto_path: &str,
230 compile_well_known_types: bool,
231) -> TokenStream {
232 let mut stream = TokenStream::new();
233
234 for method in service.methods() {
235 let name = quote::format_ident!("{}", method.name());
236
237 let (req_message, res_message) =
238 method.request_response_name(proto_path, compile_well_known_types);
239
240 let method_doc = generate_doc_comments(method.comment());
241
242 let method = match (method.client_streaming(), method.server_streaming()) {
243 (false, false) => {
244 quote! {
245 #method_doc
246 async fn #name(&self, request: tonic::Request<#req_message>)
247 -> Result<tonic::Response<#res_message>, tonic::Status>;
248 }
249 }
250 (true, false) => {
251 quote! {
252 #method_doc
253 async fn #name(&self, request: tonic::Request<tonic::Streaming<#req_message>>)
254 -> Result<tonic::Response<#res_message>, tonic::Status>;
255 }
256 }
257 (false, true) => {
258 let stream = quote::format_ident!("{}Stream", method.identifier());
259 let stream_doc = generate_doc_comment(&format!(
260 "Server streaming response type for the {} method.",
261 method.identifier()
262 ));
263
264 quote! {
265 #stream_doc
266 type #stream: futures_core::Stream<Item = Result<#res_message, tonic::Status>> + Send + 'static;
267
268 #method_doc
269 async fn #name(&self, request: tonic::Request<#req_message>)
270 -> Result<tonic::Response<Self::#stream>, tonic::Status>;
271 }
272 }
273 (true, true) => {
274 let stream = quote::format_ident!("{}Stream", method.identifier());
275 let stream_doc = generate_doc_comment(&format!(
276 "Server streaming response type for the {} method.",
277 method.identifier()
278 ));
279
280 quote! {
281 #stream_doc
282 type #stream: futures_core::Stream<Item = Result<#res_message, tonic::Status>> + Send + 'static;
283
284 #method_doc
285 async fn #name(&self, request: tonic::Request<tonic::Streaming<#req_message>>)
286 -> Result<tonic::Response<Self::#stream>, tonic::Status>;
287 }
288 }
289 };
290
291 stream.extend(method);
292 }
293
294 stream
295}
296
/// Emits a `tonic::transport::NamedService` impl for the generated server type whose
/// `NAME` constant is `service_name`.
#[cfg(feature = "transport")]
fn generate_transport(
    server_service: &syn::Ident,
    server_trait: &syn::Ident,
    service_name: &str,
) -> TokenStream {
    let name_literal = LitStr::new(service_name, Span::call_site());

    quote! {
        impl<T: #server_trait> tonic::transport::NamedService for #server_service<T> {
            const NAME: &'static str = #name_literal;
        }
    }
}
311
312#[cfg(not(feature = "transport"))]
313fn generate_transport(
314 _server_service: &syn::Ident,
315 _server_trait: &syn::Ident,
316 _service_name: &str,
317) -> TokenStream {
318 TokenStream::new()
319}
320
321fn generate_methods<T: Service>(
322 service: &T,
323 proto_path: &str,
324 compile_well_known_types: bool,
325 codec: &HashMap<String, String>,
326) -> TokenStream {
327 let mut stream = TokenStream::new();
328 for method in service.methods() {
329 let path = format!(
330 "{}",
331 method.identifier()
332 );
333 let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
334 let ident = quote::format_ident!("{}", method.name());
335 let server_trait = quote::format_ident!("{}", service.name());
336
337 let method_stream = match (method.client_streaming(), method.server_streaming()) {
338 (false, false) => generate_unary(
339 method,
340 proto_path,
341 compile_well_known_types,
342 ident,
343 server_trait,
344 codec,
345 ),
346
347 (false, true) => generate_server_streaming(
348 method,
349 proto_path,
350 compile_well_known_types,
351 ident.clone(),
352 server_trait,
353 ),
354 (true, false) => generate_client_streaming(
355 method,
356 proto_path,
357 compile_well_known_types,
358 ident.clone(),
359 server_trait,
360 ),
361
362 (true, true) => generate_streaming(
363 method,
364 proto_path,
365 compile_well_known_types,
366 ident.clone(),
367 server_trait,
368 ),
369 };
370
371 let method = quote! {
372 #method_path => {
373 #method_stream
374 }
375 };
376 stream.extend(method);
377 }
378
379 stream
380}
381
382fn generate_endpoints<T: Service>(
383 service: &T,
384 emit_package: bool,
385) -> TokenStream {
386 let mut stream = TokenStream::new();
387 let package = if emit_package { service.package() } else { "" };
388
389 for method in service.methods() {
390 let path = format!(
391 "{}{}{}.{}",
392 package,
393 if package.is_empty() {
394 ""
395 } else {
396 "."
397 },
398 service.identifier(),
399 method.identifier()
400 );
401
402 let method = quote! {
403 #path.to_string(),
404 };
405 stream.extend(method);
406 }
407
408
409 stream
410}
411
412fn generate_codec_call(
413 codec: &HashMap<String, String>,
414) -> TokenStream {
415 let mut stream = TokenStream::new();
416 for (k, v) in codec {
417 let codec_name = syn::parse_str::<syn::Path>(v).unwrap();
418 let p = TokenTree::Punct(Punct::new('|', Spacing::Alone).into());
419
420 for (i, s) in k.split("|").collect::<Vec<&str>>().iter().enumerate() {
421 if i != 0 {
422 stream.extend(quote! {#p});
423 }
424 let s = s.trim();
425 stream.extend(quote! {#s});
426 }
427
428 stream.extend(
429 quote! {
430 => {
431 let codec = #codec_name::default();
432 let mut grpc = tonic::server::Grpc::new(codec)
433 .apply_compression_config(accept_compression_encodings, send_compression_encodings);
434
435 let res = grpc.unary(method, req).await;
436 Ok(res)
437 }
438 }
439 );
440 }
441 stream
442}
443
444fn generate_unary<T: Method>(
445 method: &T,
446 proto_path: &str,
447 compile_well_known_types: bool,
448 method_ident: Ident,
449 server_trait: Ident,
450 codec: &HashMap<String, String>,
451) -> TokenStream {
452 let service_ident = quote::format_ident!("{}Svc", method.identifier());
455
456 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
457
458 let codec = generate_codec_call(codec);
459
460 quote! {
461 #[allow(non_camel_case_types)]
462 struct #service_ident<T: #server_trait >(pub Arc<T>);
463
464 impl<T: #server_trait> tonic::server::UnaryService<#request> for #service_ident<T> {
465 type Response = #response;
466 type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
467
468 fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
469 let inner = self.0.clone();
470 let fut = async move {
471 (*inner).#method_ident(request).await
472 };
473 Box::pin(fut)
474 }
475 }
476
477 let accept_compression_encodings = self.accept_compression_encodings;
478 let send_compression_encodings = self.send_compression_encodings;
479 let inner = self.inner.clone();
480 let fut = async move {
481 let inner = inner.0;
482 let method = #service_ident(inner);
483
484 let ct = req.headers().get("content-type").unwrap().to_str().unwrap();
485 match ct{
486 #codec
487
488 _ => Ok(http::Response::builder()
489 .status(200)
490 .header("grpc-status", "13")
491 .header("content-type", "application/grpc")
492 .body(empty_body())
493 .unwrap())
494 }
495 };
513
514 Box::pin(fut)
515 }
516}
517
518fn generate_server_streaming<T: Method>(
519 method: &T,
520 proto_path: &str,
521 compile_well_known_types: bool,
522 method_ident: Ident,
523 server_trait: Ident,
524) -> TokenStream {
525 let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
526
527 let service_ident = quote::format_ident!("{}Svc", method.identifier());
528
529 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
530
531 let response_stream = quote::format_ident!("{}Stream", method.identifier());
532
533 quote! {
534 #[allow(non_camel_case_types)]
535 struct #service_ident<T: #server_trait >(pub Arc<T>);
536
537 impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
538 type Response = #response;
539 type ResponseStream = T::#response_stream;
540 type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
541
542 fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
543 let inner = self.0.clone();
544 let fut = async move {
545 (*inner).#method_ident(request).await
546 };
547 Box::pin(fut)
548 }
549 }
550
551 let accept_compression_encodings = self.accept_compression_encodings;
552 let send_compression_encodings = self.send_compression_encodings;
553 let inner = self.inner.clone();
554 let fut = async move {
555 let inner = inner.0;
556 let method = #service_ident(inner);
557 let codec = #codec_name::default();
558
559 let mut grpc = tonic::server::Grpc::new(codec)
560 .apply_compression_config(accept_compression_encodings, send_compression_encodings);
561
562 let res = grpc.server_streaming(method, req).await;
563 Ok(res)
564 };
565
566 Box::pin(fut)
567 }
568}
569
570fn generate_client_streaming<T: Method>(
571 method: &T,
572 proto_path: &str,
573 compile_well_known_types: bool,
574 method_ident: Ident,
575 server_trait: Ident,
576) -> TokenStream {
577 let service_ident = quote::format_ident!("{}Svc", method.identifier());
578
579 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
580 let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
581
582 quote! {
583 #[allow(non_camel_case_types)]
584 struct #service_ident<T: #server_trait >(pub Arc<T>);
585
586 impl<T: #server_trait> tonic::server::ClientStreamingService<#request> for #service_ident<T>
587 {
588 type Response = #response;
589 type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
590
591 fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
592 let inner = self.0.clone();
593 let fut = async move {
594 (*inner).#method_ident(request).await
595
596 };
597 Box::pin(fut)
598 }
599 }
600
601 let accept_compression_encodings = self.accept_compression_encodings;
602 let send_compression_encodings = self.send_compression_encodings;
603 let inner = self.inner.clone();
604 let fut = async move {
605 let inner = inner.0;
606 let method = #service_ident(inner);
607 let codec = #codec_name::default();
608
609 let mut grpc = tonic::server::Grpc::new(codec)
610 .apply_compression_config(accept_compression_encodings, send_compression_encodings);
611
612 let res = grpc.client_streaming(method, req).await;
613 Ok(res)
614 };
615
616 Box::pin(fut)
617 }
618}
619
620fn generate_streaming<T: Method>(
621 method: &T,
622 proto_path: &str,
623 compile_well_known_types: bool,
624 method_ident: Ident,
625 server_trait: Ident,
626) -> TokenStream {
627 let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
628
629 let service_ident = quote::format_ident!("{}Svc", method.identifier());
630
631 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
632
633 let response_stream = quote::format_ident!("{}Stream", method.identifier());
634
635 quote! {
636 #[allow(non_camel_case_types)]
637 struct #service_ident<T: #server_trait>(pub Arc<T>);
638
639 impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
640 {
641 type Response = #response;
642 type ResponseStream = T::#response_stream;
643 type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
644
645 fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
646 let inner = self.0.clone();
647 let fut = async move {
648 (*inner).#method_ident(request).await
649 };
650 Box::pin(fut)
651 }
652 }
653
654 let accept_compression_encodings = self.accept_compression_encodings;
655 let send_compression_encodings = self.send_compression_encodings;
656 let inner = self.inner.clone();
657 let fut = async move {
658 let inner = inner.0;
659 let method = #service_ident(inner);
660 let codec = #codec_name::default();
661
662 let mut grpc = tonic::server::Grpc::new(codec)
663 .apply_compression_config(accept_compression_encodings, send_compression_encodings);
664
665 let res = grpc.streaming(method, req).await;
666 Ok(res)
667 };
668
669 Box::pin(fut)
670 }
671}