1use std::collections::HashSet;
2
3use super::{Attributes, Method, Service};
4use crate::{
5 format_method_name, format_method_path, format_service_name, generate_doc_comment,
6 generate_doc_comments, naive_snake_case,
7};
8use proc_macro2::{Span, TokenStream};
9use quote::quote;
10use syn::{Ident, Lit, LitStr};
11
12#[allow(clippy::too_many_arguments)]
13pub(crate) fn generate_internal<T: Service>(
14 service: &T,
15 emit_package: bool,
16 proto_path: &str,
17 compile_well_known_types: bool,
18 attributes: &Attributes,
19 disable_comments: &HashSet<String>,
20 use_arc_self: bool,
21 generate_default_stubs: bool,
22) -> TokenStream {
23 let methods = generate_methods(
24 service,
25 emit_package,
26 proto_path,
27 compile_well_known_types,
28 use_arc_self,
29 generate_default_stubs,
30 );
31
32 let server_service = quote::format_ident!("{}Server", service.name());
33 let server_trait = quote::format_ident!("{}", service.name());
34 let server_mod = quote::format_ident!("{}_server", naive_snake_case(service.name()));
35 let trait_attributes = attributes.for_trait(service.name());
36 let generated_trait = generate_trait(
37 service,
38 emit_package,
39 proto_path,
40 compile_well_known_types,
41 server_trait.clone(),
42 disable_comments,
43 use_arc_self,
44 generate_default_stubs,
45 trait_attributes,
46 );
47 let package = if emit_package { service.package() } else { "" };
48 let service_name = format_service_name(service, emit_package);
50
51 let service_doc = if disable_comments.contains(&service_name) {
52 TokenStream::new()
53 } else {
54 generate_doc_comments(service.comment())
55 };
56
57 let named = generate_named(&server_service, &service_name);
58 let mod_attributes = attributes.for_mod(package);
59 let struct_attributes = attributes.for_struct(&service_name);
60
61 let configure_compression_methods = quote! {
62 #[must_use]
64 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
65 self.accept_compression_encodings.enable(encoding);
66 self
67 }
68
69 #[must_use]
71 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
72 self.send_compression_encodings.enable(encoding);
73 self
74 }
75 };
76
77 let configure_max_message_size_methods = quote! {
78 #[must_use]
82 pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
83 self.max_decoding_message_size = Some(limit);
84 self
85 }
86
87 #[must_use]
91 pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
92 self.max_encoding_message_size = Some(limit);
93 self
94 }
95 };
96
97 quote! {
98 #(#mod_attributes)*
100 pub mod #server_mod {
101 #![allow(
102 unused_variables,
103 dead_code,
104 missing_docs,
105 clippy::wildcard_imports,
106 clippy::let_unit_value,
108 )]
109 use tonic::codegen::*;
110
111 #generated_trait
112
113 #service_doc
114 #(#struct_attributes)*
115 #[derive(Debug)]
116 pub struct #server_service<T> {
117 inner: Arc<T>,
118 accept_compression_encodings: EnabledCompressionEncodings,
119 send_compression_encodings: EnabledCompressionEncodings,
120 max_decoding_message_size: Option<usize>,
121 max_encoding_message_size: Option<usize>,
122 }
123
124 impl<T> #server_service<T> {
125 pub fn new(inner: T) -> Self {
126 Self::from_arc(Arc::new(inner))
127 }
128
129 pub fn from_arc(inner: Arc<T>) -> Self {
130 Self {
131 inner,
132 accept_compression_encodings: Default::default(),
133 send_compression_encodings: Default::default(),
134 max_decoding_message_size: None,
135 max_encoding_message_size: None,
136 }
137 }
138
139 pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
140 where
141 F: tonic::service::Interceptor,
142 {
143 InterceptedService::new(Self::new(inner), interceptor)
144 }
145
146 #configure_compression_methods
147
148 #configure_max_message_size_methods
149 }
150
151 impl<T, B> tonic::codegen::Service<http::Request<B>> for #server_service<T>
152 where
153 T: #server_trait,
154 B: Body + std::marker::Send + 'static,
155 B::Error: Into<StdError> + std::marker::Send + 'static,
156 {
157 type Response = http::Response<tonic::body::Body>;
158 type Error = std::convert::Infallible;
159 type Future = BoxFuture<Self::Response, Self::Error>;
160
161 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
162 Poll::Ready(Ok(()))
163 }
164
165 fn call(&mut self, req: http::Request<B>) -> Self::Future {
166 match req.uri().path() {
167 #methods
168
169 _ => Box::pin(async move {
170 let mut response = http::Response::new(tonic::body::Body::default());
171 let headers = response.headers_mut();
172 headers.insert(tonic::Status::GRPC_STATUS, (tonic::Code::Unimplemented as i32).into());
173 headers.insert(http::header::CONTENT_TYPE, tonic::metadata::GRPC_CONTENT_TYPE);
174 Ok(response)
175 }),
176 }
177 }
178 }
179
180 impl<T> Clone for #server_service<T> {
181 fn clone(&self) -> Self {
182 let inner = self.inner.clone();
183 Self {
184 inner,
185 accept_compression_encodings: self.accept_compression_encodings,
186 send_compression_encodings: self.send_compression_encodings,
187 max_decoding_message_size: self.max_decoding_message_size,
188 max_encoding_message_size: self.max_encoding_message_size,
189 }
190 }
191 }
192
193 #named
194 }
195 }
196}
197
198#[allow(clippy::too_many_arguments)]
199fn generate_trait<T: Service>(
200 service: &T,
201 emit_package: bool,
202 proto_path: &str,
203 compile_well_known_types: bool,
204 server_trait: Ident,
205 disable_comments: &HashSet<String>,
206 use_arc_self: bool,
207 generate_default_stubs: bool,
208 trait_attributes: Vec<syn::Attribute>,
209) -> TokenStream {
210 let methods = generate_trait_methods(
211 service,
212 emit_package,
213 proto_path,
214 compile_well_known_types,
215 disable_comments,
216 use_arc_self,
217 generate_default_stubs,
218 );
219 let trait_doc = generate_doc_comment(format!(
220 " Generated trait containing gRPC methods that should be implemented for use with {}Server.",
221 service.name()
222 ));
223
224 quote! {
225 #trait_doc
226 #(#trait_attributes)*
227 #[async_trait]
228 pub trait #server_trait : std::marker::Send + std::marker::Sync + 'static {
229 #methods
230 }
231 }
232}
233
234fn generate_trait_methods<T: Service>(
235 service: &T,
236 emit_package: bool,
237 proto_path: &str,
238 compile_well_known_types: bool,
239 disable_comments: &HashSet<String>,
240 use_arc_self: bool,
241 generate_default_stubs: bool,
242) -> TokenStream {
243 let mut stream = TokenStream::new();
244
245 for method in service.methods() {
246 let name = quote::format_ident!("{}", method.name());
247
248 let (req_message, res_message) =
249 method.request_response_name(proto_path, compile_well_known_types);
250
251 let method_doc =
252 if disable_comments.contains(&format_method_name(service, method, emit_package)) {
253 TokenStream::new()
254 } else {
255 generate_doc_comments(method.comment())
256 };
257
258 let self_param = if use_arc_self {
259 quote!(self: std::sync::Arc<Self>)
260 } else {
261 quote!(&self)
262 };
263
264 let method = match (
265 method.client_streaming(),
266 method.server_streaming(),
267 generate_default_stubs,
268 ) {
269 (false, false, true) => {
270 quote! {
271 #method_doc
272 async fn #name(#self_param, request: tonic::Request<#req_message>)
273 -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
274 Err(tonic::Status::unimplemented("Not yet implemented"))
275 }
276 }
277 }
278 (false, false, false) => {
279 quote! {
280 #method_doc
281 async fn #name(#self_param, request: tonic::Request<#req_message>)
282 -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
283 }
284 }
285 (true, false, true) => {
286 quote! {
287 #method_doc
288 async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
289 -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
290 Err(tonic::Status::unimplemented("Not yet implemented"))
291 }
292 }
293 }
294 (true, false, false) => {
295 quote! {
296 #method_doc
297 async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
298 -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
299 }
300 }
301 (false, true, true) => {
302 quote! {
303 #method_doc
304 async fn #name(#self_param, request: tonic::Request<#req_message>)
305 -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
306 Err(tonic::Status::unimplemented("Not yet implemented"))
307 }
308 }
309 }
310 (false, true, false) => {
311 let stream = quote::format_ident!("{}Stream", method.identifier());
312 let stream_doc = generate_doc_comment(format!(
313 " Server streaming response type for the {} method.",
314 method.identifier()
315 ));
316
317 quote! {
318 #stream_doc
319 type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
320
321 #method_doc
322 async fn #name(#self_param, request: tonic::Request<#req_message>)
323 -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
324 }
325 }
326 (true, true, true) => {
327 quote! {
328 #method_doc
329 async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
330 -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
331 Err(tonic::Status::unimplemented("Not yet implemented"))
332 }
333 }
334 }
335 (true, true, false) => {
336 let stream = quote::format_ident!("{}Stream", method.identifier());
337 let stream_doc = generate_doc_comment(format!(
338 " Server streaming response type for the {} method.",
339 method.identifier()
340 ));
341
342 quote! {
343 #stream_doc
344 type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
345
346 #method_doc
347 async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
348 -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
349 }
350 }
351 };
352
353 stream.extend(method);
354 }
355
356 stream
357}
358
359fn generate_named(server_service: &syn::Ident, service_name: &str) -> TokenStream {
360 let service_name = syn::LitStr::new(service_name, proc_macro2::Span::call_site());
361 let name_doc = generate_doc_comment(" Generated gRPC service name");
362
363 quote! {
364 #name_doc
365 pub const SERVICE_NAME: &str = #service_name;
366
367 impl<T> tonic::server::NamedService for #server_service<T> {
368 const NAME: &'static str = SERVICE_NAME;
369 }
370 }
371}
372
373fn generate_methods<T: Service>(
374 service: &T,
375 emit_package: bool,
376 proto_path: &str,
377 compile_well_known_types: bool,
378 use_arc_self: bool,
379 generate_default_stubs: bool,
380) -> TokenStream {
381 let mut stream = TokenStream::new();
382
383 for method in service.methods() {
384 let path = format_method_path(service, method, emit_package);
385 let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
386 let ident = quote::format_ident!("{}", method.name());
387 let server_trait = quote::format_ident!("{}", service.name());
388
389 let method_stream = match (method.client_streaming(), method.server_streaming()) {
390 (false, false) => generate_unary(
391 method,
392 proto_path,
393 compile_well_known_types,
394 ident,
395 server_trait,
396 use_arc_self,
397 ),
398
399 (false, true) => generate_server_streaming(
400 method,
401 proto_path,
402 compile_well_known_types,
403 ident.clone(),
404 server_trait,
405 use_arc_self,
406 generate_default_stubs,
407 ),
408 (true, false) => generate_client_streaming(
409 method,
410 proto_path,
411 compile_well_known_types,
412 ident.clone(),
413 server_trait,
414 use_arc_self,
415 ),
416
417 (true, true) => generate_streaming(
418 method,
419 proto_path,
420 compile_well_known_types,
421 ident.clone(),
422 server_trait,
423 use_arc_self,
424 generate_default_stubs,
425 ),
426 };
427
428 let method = quote! {
429 #method_path => {
430 #method_stream
431 }
432 };
433 stream.extend(method);
434 }
435
436 stream
437}
438
439fn generate_unary<T: Method>(
440 method: &T,
441 proto_path: &str,
442 compile_well_known_types: bool,
443 method_ident: Ident,
444 server_trait: Ident,
445 use_arc_self: bool,
446) -> TokenStream {
447 let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
448
449 let service_ident = quote::format_ident!("{}Svc", method.identifier());
450
451 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
452
453 let inner_arg = if use_arc_self {
454 quote!(inner)
455 } else {
456 quote!(&inner)
457 };
458
459 quote! {
460 #[allow(non_camel_case_types)]
461 struct #service_ident<T: #server_trait >(pub Arc<T>);
462
463 impl<T: #server_trait> tonic::server::UnaryService<#request> for #service_ident<T> {
464 type Response = #response;
465 type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
466
467 fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
468 let inner = Arc::clone(&self.0);
469 let fut = async move {
470 <T as #server_trait>::#method_ident(#inner_arg, request).await
471 };
472 Box::pin(fut)
473 }
474 }
475
476 let accept_compression_encodings = self.accept_compression_encodings;
477 let send_compression_encodings = self.send_compression_encodings;
478 let max_decoding_message_size = self.max_decoding_message_size;
479 let max_encoding_message_size = self.max_encoding_message_size;
480 let inner = self.inner.clone();
481 let fut = async move {
482 let method = #service_ident(inner);
483 let codec = #codec_name::default();
484
485 let mut grpc = tonic::server::Grpc::new(codec)
486 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
487 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
488
489 let res = grpc.unary(method, req).await;
490 Ok(res)
491 };
492
493 Box::pin(fut)
494 }
495}
496
497fn generate_server_streaming<T: Method>(
498 method: &T,
499 proto_path: &str,
500 compile_well_known_types: bool,
501 method_ident: Ident,
502 server_trait: Ident,
503 use_arc_self: bool,
504 generate_default_stubs: bool,
505) -> TokenStream {
506 let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
507
508 let service_ident = quote::format_ident!("{}Svc", method.identifier());
509
510 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
511
512 let response_stream = if !generate_default_stubs {
513 let stream = quote::format_ident!("{}Stream", method.identifier());
514 quote!(type ResponseStream = T::#stream)
515 } else {
516 quote!(type ResponseStream = BoxStream<#response>)
517 };
518
519 let inner_arg = if use_arc_self {
520 quote!(inner)
521 } else {
522 quote!(&inner)
523 };
524
525 quote! {
526 #[allow(non_camel_case_types)]
527 struct #service_ident<T: #server_trait >(pub Arc<T>);
528
529 impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
530 type Response = #response;
531 #response_stream;
532 type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
533
534 fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
535 let inner = Arc::clone(&self.0);
536 let fut = async move {
537 <T as #server_trait>::#method_ident(#inner_arg, request).await
538 };
539 Box::pin(fut)
540 }
541 }
542
543 let accept_compression_encodings = self.accept_compression_encodings;
544 let send_compression_encodings = self.send_compression_encodings;
545 let max_decoding_message_size = self.max_decoding_message_size;
546 let max_encoding_message_size = self.max_encoding_message_size;
547 let inner = self.inner.clone();
548 let fut = async move {
549 let method = #service_ident(inner);
550 let codec = #codec_name::default();
551
552 let mut grpc = tonic::server::Grpc::new(codec)
553 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
554 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
555
556 let res = grpc.server_streaming(method, req).await;
557 Ok(res)
558 };
559
560 Box::pin(fut)
561 }
562}
563
564fn generate_client_streaming<T: Method>(
565 method: &T,
566 proto_path: &str,
567 compile_well_known_types: bool,
568 method_ident: Ident,
569 server_trait: Ident,
570 use_arc_self: bool,
571) -> TokenStream {
572 let service_ident = quote::format_ident!("{}Svc", method.identifier());
573
574 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
575 let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
576
577 let inner_arg = if use_arc_self {
578 quote!(inner)
579 } else {
580 quote!(&inner)
581 };
582
583 quote! {
584 #[allow(non_camel_case_types)]
585 struct #service_ident<T: #server_trait >(pub Arc<T>);
586
587 impl<T: #server_trait> tonic::server::ClientStreamingService<#request> for #service_ident<T>
588 {
589 type Response = #response;
590 type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
591
592 fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
593 let inner = Arc::clone(&self.0);
594 let fut = async move {
595 <T as #server_trait>::#method_ident(#inner_arg, request).await
596 };
597 Box::pin(fut)
598 }
599 }
600
601 let accept_compression_encodings = self.accept_compression_encodings;
602 let send_compression_encodings = self.send_compression_encodings;
603 let max_decoding_message_size = self.max_decoding_message_size;
604 let max_encoding_message_size = self.max_encoding_message_size;
605 let inner = self.inner.clone();
606 let fut = async move {
607 let method = #service_ident(inner);
608 let codec = #codec_name::default();
609
610 let mut grpc = tonic::server::Grpc::new(codec)
611 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
612 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
613
614 let res = grpc.client_streaming(method, req).await;
615 Ok(res)
616 };
617
618 Box::pin(fut)
619 }
620}
621
622fn generate_streaming<T: Method>(
623 method: &T,
624 proto_path: &str,
625 compile_well_known_types: bool,
626 method_ident: Ident,
627 server_trait: Ident,
628 use_arc_self: bool,
629 generate_default_stubs: bool,
630) -> TokenStream {
631 let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
632
633 let service_ident = quote::format_ident!("{}Svc", method.identifier());
634
635 let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
636
637 let response_stream = if !generate_default_stubs {
638 let stream = quote::format_ident!("{}Stream", method.identifier());
639 quote!(type ResponseStream = T::#stream)
640 } else {
641 quote!(type ResponseStream = BoxStream<#response>)
642 };
643
644 let inner_arg = if use_arc_self {
645 quote!(inner)
646 } else {
647 quote!(&inner)
648 };
649
650 quote! {
651 #[allow(non_camel_case_types)]
652 struct #service_ident<T: #server_trait>(pub Arc<T>);
653
654 impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
655 {
656 type Response = #response;
657 #response_stream;
658 type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
659
660 fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
661 let inner = Arc::clone(&self.0);
662 let fut = async move {
663 <T as #server_trait>::#method_ident(#inner_arg, request).await
664 };
665 Box::pin(fut)
666 }
667 }
668
669 let accept_compression_encodings = self.accept_compression_encodings;
670 let send_compression_encodings = self.send_compression_encodings;
671 let max_decoding_message_size = self.max_decoding_message_size;
672 let max_encoding_message_size = self.max_encoding_message_size;
673 let inner = self.inner.clone();
674 let fut = async move {
675 let method = #service_ident(inner);
676 let codec = #codec_name::default();
677
678 let mut grpc = tonic::server::Grpc::new(codec)
679 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
680 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
681
682 let res = grpc.streaming(method, req).await;
683 Ok(res)
684 };
685
686 Box::pin(fut)
687 }
688}