1#![warn(
21 missing_docs,
22 missing_debug_implementations,
23 rust_2018_idioms,
24 unreachable_pub
25)]
26#![doc(
27 html_logo_url = "https://raw.githubusercontent.com/tokio-rs/website/master/public/img/icons/tonic.svg"
28)]
29#![doc(html_root_url = "https://docs.rs/tonic-prost-build/0.14.0")]
30#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")]
31
32use proc_macro2::TokenStream;
33use prost_build::{Method, Service};
34use quote::{ToTokens, quote};
35use std::cell::RefCell;
36use std::{
37 collections::HashSet,
38 ffi::OsString,
39 io,
40 path::{Path, PathBuf},
41};
42use tonic_build::{Attributes, CodeGenBuilder};
43
44#[cfg(test)]
45mod tests;
46
47pub use tonic_build::{
49 Attributes as TonicAttributes, Method as TonicMethod, Service as TonicService, manual,
50};
51
52pub use prost_build::Config;
54pub use prost_types::FileDescriptorSet;
55
56pub fn configure() -> Builder {
60 Builder {
61 build_client: true,
62 build_server: true,
63 build_transport: true,
64 file_descriptor_set_path: None,
65 skip_protoc_run: false,
66 out_dir: None,
67 extern_path: Vec::new(),
68 field_attributes: Vec::new(),
69 message_attributes: Vec::new(),
70 enum_attributes: Vec::new(),
71 type_attributes: Vec::new(),
72 boxed: Vec::new(),
73 btree_map: None,
74 bytes: None,
75 server_attributes: Attributes::default(),
76 client_attributes: Attributes::default(),
77 proto_path: "super".to_string(),
78 compile_well_known_types: false,
79 emit_package: true,
80 with_extended_rust_types: false,
81 protoc_args: Vec::new(),
82 include_file: None,
83 emit_rerun_if_changed: std::env::var_os("CARGO").is_some(),
84 disable_comments: HashSet::default(),
85 use_arc_self: false,
86 generate_default_stubs: false,
87 codec_path: "tonic_prost::ProstCodec".to_string(),
88 skip_debug: HashSet::default(),
89 }
90}
91
92pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
97 let proto_path: &Path = proto.as_ref();
98
99 let proto_dir = proto_path
101 .parent()
102 .expect("proto file should reside in a directory");
103
104 self::configure().compile_protos(&[proto_path], &[proto_dir])
105}
106
107pub fn compile_fds(fds: prost_types::FileDescriptorSet) -> io::Result<()> {
109 self::configure().compile_fds(fds)
110}
111
/// Request/response types emitted verbatim (rather than as paths under the
/// proto module) when extended Rust types are enabled: the unit type plus
/// the scalar primitives.
const EXTENDED_NON_PATH_TYPE_ALLOWLIST: &[&str] = &[
    "()", "bool", "i32", "i64", "u32", "u64", "f32", "f64",
];

/// Request/response types emitted verbatim by default: only the unit type.
const DEFAULT_NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];

thread_local! {
    /// The non-path type allowlist consulted during code generation on the
    /// current thread.
    ///
    /// It starts as the default list (`()` only).
    /// `Builder::compile_with_config` switches it to the extended list when
    /// `with_extended_rust_types` is set and restores the default list before
    /// returning.
    pub static NON_PATH_TYPE_ALLOWLIST: RefCell<&'static [&'static str]> = const {
        RefCell::new(DEFAULT_NON_PATH_TYPE_ALLOWLIST)
    };
}
126
127struct TonicBuildService {
129 prost_service: Service,
130 methods: Vec<TonicBuildMethod>,
131}
132
133impl TonicBuildService {
134 fn new(prost_service: Service, codec_path: String) -> Self {
135 Self {
136 methods: prost_service
138 .methods
139 .iter()
140 .map(|prost_method| TonicBuildMethod {
141 prost_method: prost_method.clone(),
142 codec_path: codec_path.clone(),
143 })
144 .collect(),
145 prost_service,
146 }
147 }
148}
149
150struct TonicBuildMethod {
152 prost_method: Method,
153 codec_path: String,
154}
155
156impl tonic_build::Service for TonicBuildService {
157 type Method = TonicBuildMethod;
158 type Comment = String;
159
160 fn name(&self) -> &str {
161 &self.prost_service.name
162 }
163
164 fn package(&self) -> &str {
165 &self.prost_service.package
166 }
167
168 fn identifier(&self) -> &str {
169 &self.prost_service.proto_name
170 }
171
172 fn methods(&self) -> &[Self::Method] {
173 &self.methods
174 }
175
176 fn comment(&self) -> &[Self::Comment] {
177 &self.prost_service.comments.leading
178 }
179}
180
181impl tonic_build::Method for TonicBuildMethod {
182 type Comment = String;
183
184 fn name(&self) -> &str {
185 &self.prost_method.name
186 }
187
188 fn identifier(&self) -> &str {
189 &self.prost_method.proto_name
190 }
191
192 fn client_streaming(&self) -> bool {
193 self.prost_method.client_streaming
194 }
195
196 fn server_streaming(&self) -> bool {
197 self.prost_method.server_streaming
198 }
199
200 fn comment(&self) -> &[Self::Comment] {
201 &self.prost_method.comments.leading
202 }
203
204 fn request_response_name(
205 &self,
206 proto_path: &str,
207 compile_well_known_types: bool,
208 ) -> (TokenStream, TokenStream) {
209 let request = if is_google_type(&self.prost_method.input_type) && !compile_well_known_types
210 {
211 match self.prost_method.input_type.as_str() {
213 ".google.protobuf.Empty" => quote!(()),
214 ".google.protobuf.Any" => quote!(::prost_types::Any),
215 ".google.protobuf.StringValue" => quote!(::prost::alloc::string::String),
216 _ => {
217 let type_name = self
219 .prost_method
220 .input_type
221 .trim_start_matches(".google.protobuf.")
222 .to_string();
223 syn::parse_str::<syn::Path>(&format!("::prost_types::{type_name}"))
224 .unwrap()
225 .to_token_stream()
226 }
227 }
228 } else if is_non_path_type(&self.prost_method.input_type) {
229 self.prost_method.input_type.parse::<TokenStream>().unwrap()
230 } else {
231 if self.prost_method.input_type.starts_with("::")
233 || self.prost_method.input_type.starts_with("crate::")
234 {
235 self.prost_method.input_type.parse::<TokenStream>().unwrap()
237 } else {
238 let rust_type = self.prost_method.input_type.replace('.', "::");
240 let rust_type = rust_type.trim_start_matches("::");
242 syn::parse_str::<syn::Path>(&format!("{proto_path}::{rust_type}"))
243 .unwrap()
244 .to_token_stream()
245 }
246 };
247
248 let response =
249 if is_google_type(&self.prost_method.output_type) && !compile_well_known_types {
250 match self.prost_method.output_type.as_str() {
252 ".google.protobuf.Empty" => quote!(()),
253 ".google.protobuf.Any" => quote!(::prost_types::Any),
254 ".google.protobuf.StringValue" => quote!(::prost::alloc::string::String),
255 _ => {
256 let type_name = self
258 .prost_method
259 .output_type
260 .trim_start_matches(".google.protobuf.")
261 .to_string();
262 syn::parse_str::<syn::Path>(&format!("::prost_types::{type_name}"))
263 .unwrap()
264 .to_token_stream()
265 }
266 }
267 } else if is_non_path_type(&self.prost_method.output_type) {
268 self.prost_method
269 .output_type
270 .parse::<TokenStream>()
271 .unwrap()
272 } else {
273 if self.prost_method.output_type.starts_with("::")
275 || self.prost_method.output_type.starts_with("crate::")
276 {
277 self.prost_method
279 .output_type
280 .parse::<TokenStream>()
281 .unwrap()
282 } else {
283 let rust_type = self.prost_method.output_type.replace('.', "::");
285 let rust_type = rust_type.trim_start_matches("::");
287 syn::parse_str::<syn::Path>(&format!("{proto_path}::{rust_type}"))
288 .unwrap()
289 .to_token_stream()
290 }
291 };
292
293 (request, response)
294 }
295
296 fn codec_path(&self) -> &str {
297 &self.codec_path
298 }
299
300 fn deprecated(&self) -> bool {
301 self.prost_method.options.deprecated()
302 }
303}
304
305fn is_non_path_type(ty: &str) -> bool {
306 NON_PATH_TYPE_ALLOWLIST.with(|allowlist| {
307 allowlist
308 .borrow()
309 .iter()
310 .any(|allowlist_type| ty.ends_with(allowlist_type))
311 })
312}
313
/// Returns `true` if `ty` is a fully-qualified protobuf type inside the
/// `google.protobuf` package (e.g. `.google.protobuf.Empty`).
///
/// The trailing `.` in the prefix is required. Without it, a user package
/// such as `google.protobuf_ext` would be mistaken for a well-known type and
/// mapped to a mangled `::prost_types` path.
fn is_google_type(ty: &str) -> bool {
    ty.starts_with(".google.protobuf.")
}
317
318#[derive(Debug)]
320struct ServiceGenerator {
321 build_client: bool,
322 build_server: bool,
323 build_transport: bool,
324 client_attributes: Attributes,
325 server_attributes: Attributes,
326 use_arc_self: bool,
327 generate_default_stubs: bool,
328 proto_path: String,
329 compile_well_known_types: bool,
330 codec_path: String,
331 disable_comments: HashSet<String>,
332}
333
334impl ServiceGenerator {
335 #[allow(clippy::too_many_arguments)]
337 fn new(
338 build_client: bool,
339 build_server: bool,
340 build_transport: bool,
341 client_attributes: Attributes,
342 server_attributes: Attributes,
343 use_arc_self: bool,
344 generate_default_stubs: bool,
345 proto_path: String,
346 compile_well_known_types: bool,
347 codec_path: String,
348 disable_comments: HashSet<String>,
349 ) -> Self {
350 ServiceGenerator {
351 build_client,
352 build_server,
353 build_transport,
354 client_attributes,
355 server_attributes,
356 use_arc_self,
357 generate_default_stubs,
358 proto_path,
359 compile_well_known_types,
360 codec_path,
361 disable_comments,
362 }
363 }
364}
365
366impl prost_build::ServiceGenerator for ServiceGenerator {
367 fn generate(&mut self, service: Service, buf: &mut String) {
368 let tonic_service = TonicBuildService::new(service, self.codec_path.clone());
369
370 let mut builder = CodeGenBuilder::new();
371 builder
372 .emit_package(true)
373 .build_transport(self.build_transport)
374 .compile_well_known_types(self.compile_well_known_types)
375 .disable_comments(self.disable_comments.clone())
376 .use_arc_self(self.use_arc_self)
377 .generate_default_stubs(self.generate_default_stubs);
378
379 let mut tokens = TokenStream::new();
380
381 if self.build_client {
382 builder.attributes(self.client_attributes.clone());
383 let client_code = builder.generate_client(&tonic_service, &self.proto_path);
384 tokens.extend(client_code);
385 }
386
387 if self.build_server {
388 builder.attributes(self.server_attributes.clone());
389 let server_code = builder.generate_server(&tonic_service, &self.proto_path);
390 tokens.extend(server_code);
391 }
392
393 let formatted = prettyplease::unparse(&syn::parse2(tokens).unwrap());
394 buf.push_str(&formatted);
395 }
396}
397
/// Configuration for generating tonic service code alongside prost messages.
///
/// Create one with [`configure`]. Each setter consumes and returns the
/// builder, and one of the `compile_*` methods runs the generation.
#[derive(Debug, Clone)]
pub struct Builder {
    /// Whether client code is generated.
    build_client: bool,
    /// Whether server code is generated.
    build_server: bool,
    /// Forwarded to tonic's `CodeGenBuilder::build_transport`.
    build_transport: bool,
    /// If set, forwarded to `Config::file_descriptor_set_path`.
    file_descriptor_set_path: Option<PathBuf>,
    /// If true, `Config::skip_protoc_run` is called.
    skip_protoc_run: bool,
    /// Output directory; `None` means the `OUT_DIR` environment variable.
    out_dir: Option<PathBuf>,
    /// `(protobuf path, Rust path)` pairs forwarded to `Config::extern_path`.
    extern_path: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::field_attribute`.
    field_attributes: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::message_attribute`.
    message_attributes: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::enum_attribute`.
    enum_attributes: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::type_attribute`.
    type_attributes: Vec<(String, String)>,
    /// Paths forwarded to `Config::boxed`.
    boxed: Vec<String>,
    /// `None` until `btree_map` is first called; then forwarded to `Config::btree_map`.
    btree_map: Option<Vec<String>>,
    /// `None` until `bytes` is first called; then forwarded to `Config::bytes`.
    bytes: Option<Vec<String>>,
    /// Module/struct/trait attributes applied to generated server code.
    server_attributes: Attributes,
    /// Module/struct attributes applied to generated client code.
    client_attributes: Attributes,
    /// Rust path prefix for relative request/response types (default `super`).
    proto_path: String,
    /// Whether `google.protobuf` types are compiled instead of mapped to `prost_types`.
    compile_well_known_types: bool,
    /// Value set by `emit_package`.
    /// NOTE(review): not forwarded to the service generator, which always
    /// emits the package; confirm intended.
    emit_package: bool,
    /// Whether primitive scalars are also accepted verbatim as request/response types.
    with_extended_rust_types: bool,
    /// Extra arguments forwarded to `Config::protoc_arg`.
    protoc_args: Vec<OsString>,
    /// If set, forwarded to `Config::include_file`.
    include_file: Option<PathBuf>,
    /// Value set by `emit_rerun_if_changed`; defaults to whether `CARGO` is set.
    emit_rerun_if_changed: bool,
    /// Paths whose comments are not emitted in generated service code.
    disable_comments: HashSet<String>,
    /// Forwarded to tonic's `CodeGenBuilder::use_arc_self`.
    use_arc_self: bool,
    /// Forwarded to tonic's `CodeGenBuilder::generate_default_stubs`.
    generate_default_stubs: bool,
    /// Rust path of the codec used by every generated method.
    codec_path: String,
    /// Paths forwarded to `Config::skip_debug` when non-empty.
    skip_debug: HashSet<String>,
}
430
431impl Builder {
432 pub fn build_client(mut self, enable: bool) -> Self {
434 self.build_client = enable;
435 self
436 }
437
438 pub fn build_server(mut self, enable: bool) -> Self {
440 self.build_server = enable;
441 self
442 }
443
444 pub fn build_transport(mut self, enable: bool) -> Self {
446 self.build_transport = enable;
447 self
448 }
449
450 pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
455 self.out_dir = Some(out_dir.as_ref().to_path_buf());
456 self
457 }
458
459 pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
465 self.extern_path.push((
466 proto_path.as_ref().to_string(),
467 rust_path.as_ref().to_string(),
468 ));
469 self
470 }
471
472 pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
476 self.field_attributes
477 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
478 self
479 }
480
481 pub fn message_attribute<P: AsRef<str>, A: AsRef<str>>(
485 mut self,
486 path: P,
487 attribute: A,
488 ) -> Self {
489 self.message_attributes
490 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
491 self
492 }
493
494 pub fn enum_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
498 self.enum_attributes
499 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
500 self
501 }
502
503 pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
507 self.type_attributes
508 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
509 self
510 }
511
512 pub fn boxed<P: AsRef<str>>(mut self, path: P) -> Self {
516 self.boxed.push(path.as_ref().to_string());
517 self
518 }
519
520 pub fn btree_map<P: AsRef<str>>(mut self, path: P) -> Self {
524 match &mut self.btree_map {
525 Some(paths) => paths.push(path.as_ref().to_string()),
526 None => self.btree_map = Some(vec![path.as_ref().to_string()]),
527 }
528 self
529 }
530
531 pub fn bytes<P: AsRef<str>>(mut self, path: P) -> Self {
535 match &mut self.bytes {
536 Some(paths) => paths.push(path.as_ref().to_string()),
537 None => self.bytes = Some(vec![path.as_ref().to_string()]),
538 }
539 self
540 }
541
542 pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
545 mut self,
546 path: P,
547 attribute: A,
548 ) -> Self {
549 self.server_attributes
550 .push_mod(path.as_ref(), attribute.as_ref());
551 self
552 }
553
554 pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
557 self.server_attributes
558 .push_struct(path.as_ref(), attribute.as_ref());
559 self
560 }
561
562 pub fn trait_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
565 self.server_attributes
566 .push_trait(path.as_ref(), attribute.as_ref());
567 self
568 }
569
570 pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
573 mut self,
574 path: P,
575 attribute: A,
576 ) -> Self {
577 self.client_attributes
578 .push_mod(path.as_ref(), attribute.as_ref());
579 self
580 }
581
582 pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
585 self.client_attributes
586 .push_struct(path.as_ref(), attribute.as_ref());
587 self
588 }
589
590 pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
603 self.proto_path = proto_path.as_ref().to_string();
604 self
605 }
606
607 pub fn compile_well_known_types(mut self, enable: bool) -> Self {
611 self.compile_well_known_types = enable;
612 self
613 }
614
615 pub fn with_extended_rust_types(mut self, enable: bool) -> Self {
619 self.with_extended_rust_types = enable;
620 self
621 }
622
623 pub fn emit_package(mut self, enable: bool) -> Self {
627 self.emit_package = enable;
628 self
629 }
630
631 pub fn file_descriptor_set_path(mut self, path: impl AsRef<Path>) -> Self {
635 self.file_descriptor_set_path = Some(path.as_ref().to_path_buf());
636 self
637 }
638
639 pub fn skip_protoc_run(mut self) -> Self {
643 self.skip_protoc_run = true;
644 self
645 }
646
647 pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
651 self.protoc_args.push(arg.as_ref().into());
652 self
653 }
654
655 pub fn include_file(mut self, path: impl AsRef<Path>) -> Self {
659 self.include_file = Some(path.as_ref().to_path_buf());
660 self
661 }
662
663 pub fn emit_rerun_if_changed(mut self, enable: bool) -> Self {
667 self.emit_rerun_if_changed = enable;
668 self
669 }
670
671 pub fn disable_comments<I, S>(mut self, path: I) -> Self
675 where
676 I: IntoIterator<Item = S>,
677 S: AsRef<str>,
678 {
679 self.disable_comments
680 .extend(path.into_iter().map(|s| s.as_ref().to_string()));
681 self
682 }
683
684 pub fn use_arc_self(mut self, enable: bool) -> Self {
686 self.use_arc_self = enable;
687 self
688 }
689
690 pub fn generate_default_stubs(mut self, enable: bool) -> Self {
693 self.generate_default_stubs = enable;
694 self
695 }
696
697 pub fn codec_path(mut self, path: impl AsRef<str>) -> Self {
699 self.codec_path = path.as_ref().to_string();
700 self
701 }
702
703 pub fn skip_debug<I, S>(mut self, paths: I) -> Self
708 where
709 I: IntoIterator<Item = S>,
710 S: AsRef<str>,
711 {
712 self.skip_debug
713 .extend(paths.into_iter().map(|s| s.as_ref().to_string()));
714 self
715 }
716
717 pub fn compile_protos<P>(self, protos: &[P], includes: &[P]) -> io::Result<()>
719 where
720 P: AsRef<Path>,
721 {
722 self.compile_with_config(Config::new(), protos, includes)
723 }
724
725 pub fn compile_with_config<P>(
730 self,
731 mut config: Config,
732 protos: &[P],
733 includes: &[P],
734 ) -> io::Result<()>
735 where
736 P: AsRef<Path>,
737 {
738 struct Defer;
741 impl Drop for Defer {
742 fn drop(&mut self) {
743 NON_PATH_TYPE_ALLOWLIST.set(DEFAULT_NON_PATH_TYPE_ALLOWLIST);
744 }
745 }
746
747 let _defer_guard = Defer;
748
749 let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
750 out_dir.clone()
751 } else {
752 PathBuf::from(std::env::var("OUT_DIR").unwrap())
753 };
754
755 config.out_dir(&out_dir);
756
757 for (proto_path, rust_path) in &self.extern_path {
758 config.extern_path(proto_path, rust_path);
759 }
760
761 for (prost_path, attr) in &self.field_attributes {
762 config.field_attribute(prost_path, attr);
763 }
764
765 for (prost_path, attr) in &self.message_attributes {
766 config.message_attribute(prost_path, attr);
767 }
768
769 for (prost_path, attr) in &self.enum_attributes {
770 config.enum_attribute(prost_path, attr);
771 }
772
773 for (prost_path, attr) in &self.type_attributes {
774 config.type_attribute(prost_path, attr);
775 }
776
777 for prost_path in &self.boxed {
778 config.boxed(prost_path);
779 }
780
781 if let Some(ref paths) = self.btree_map {
782 config.btree_map(paths);
783 }
784
785 if let Some(ref paths) = self.bytes {
786 config.bytes(paths);
787 }
788
789 if self.compile_well_known_types {
790 config.compile_well_known_types();
791 }
792
793 for arg in &self.protoc_args {
794 config.protoc_arg(arg);
795 }
796
797 if let Some(path) = &self.include_file {
798 config.include_file(path);
799 }
800
801 if self.with_extended_rust_types {
802 NON_PATH_TYPE_ALLOWLIST.set(EXTENDED_NON_PATH_TYPE_ALLOWLIST);
803 }
804
805 if !self.skip_debug.is_empty() {
810 config.skip_debug(self.skip_debug.clone());
811 }
812
813 if let Some(path) = &self.file_descriptor_set_path {
814 config.file_descriptor_set_path(path);
815 }
816
817 if self.skip_protoc_run {
818 config.skip_protoc_run();
819 }
820
821 if self.build_client || self.build_server {
822 let service_generator = ServiceGenerator::new(
823 self.build_client,
824 self.build_server,
825 self.build_transport,
826 self.client_attributes,
827 self.server_attributes,
828 self.use_arc_self,
829 self.generate_default_stubs,
830 self.proto_path,
831 self.compile_well_known_types,
832 self.codec_path.clone(),
833 self.disable_comments,
834 );
835
836 config.service_generator(Box::new(service_generator));
837 };
838
839 config.compile_protos(protos, includes)?;
840
841 Ok(())
842 }
843
844 pub fn compile_fds(self, fds: prost_types::FileDescriptorSet) -> io::Result<()> {
846 self.compile_fds_with_config(fds, Config::new())
847 }
848
849 pub fn compile_fds_with_config(
851 self,
852 fds: prost_types::FileDescriptorSet,
853 mut config: Config,
854 ) -> io::Result<()> {
855 let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
856 out_dir.clone()
857 } else {
858 PathBuf::from(std::env::var("OUT_DIR").unwrap())
859 };
860
861 config.out_dir(&out_dir);
862
863 for (proto_path, rust_path) in &self.extern_path {
864 config.extern_path(proto_path, rust_path);
865 }
866
867 for (prost_path, attr) in &self.field_attributes {
868 config.field_attribute(prost_path, attr);
869 }
870
871 for (prost_path, attr) in &self.message_attributes {
872 config.message_attribute(prost_path, attr);
873 }
874
875 for (prost_path, attr) in &self.enum_attributes {
876 config.enum_attribute(prost_path, attr);
877 }
878
879 for (prost_path, attr) in &self.type_attributes {
880 config.type_attribute(prost_path, attr);
881 }
882
883 for prost_path in &self.boxed {
884 config.boxed(prost_path);
885 }
886
887 if let Some(ref paths) = self.btree_map {
888 config.btree_map(paths);
889 }
890
891 if let Some(ref paths) = self.bytes {
892 config.bytes(paths);
893 }
894
895 if self.compile_well_known_types {
896 config.compile_well_known_types();
897 }
898
899 for arg in &self.protoc_args {
900 config.protoc_arg(arg);
901 }
902
903 if let Some(path) = &self.include_file {
904 config.include_file(path);
905 }
906
907 if !self.skip_debug.is_empty() {
912 config.skip_debug(self.skip_debug.clone());
913 }
914
915 if let Some(path) = &self.file_descriptor_set_path {
916 config.file_descriptor_set_path(path);
917 }
918
919 if self.skip_protoc_run {
920 config.skip_protoc_run();
921 }
922
923 if self.build_client || self.build_server {
924 let service_generator = ServiceGenerator::new(
925 self.build_client,
926 self.build_server,
927 self.build_transport,
928 self.client_attributes,
929 self.server_attributes,
930 self.use_arc_self,
931 self.generate_default_stubs,
932 self.proto_path,
933 self.compile_well_known_types,
934 self.codec_path.clone(),
935 self.disable_comments,
936 );
937
938 config.service_generator(Box::new(service_generator));
939 };
940
941 config.compile_fds(fds)?;
942
943 Ok(())
944 }
945
946 pub fn service_generator(self) -> Box<dyn prost_build::ServiceGenerator> {
949 Box::new(ServiceGenerator::new(
950 self.build_client,
951 self.build_server,
952 self.build_transport,
953 self.client_attributes,
954 self.server_attributes,
955 self.use_arc_self,
956 self.generate_default_stubs,
957 self.proto_path,
958 self.compile_well_known_types,
959 self.codec_path.clone(),
960 self.disable_comments,
961 ))
962 }
963}