tonic_build/
lib.rs

1#![doc = include_str!("../README.md")]
2#![recursion_limit = "256"]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/tokio-rs/website/master/public/img/icons/tonic.svg"
5)]
6#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")]
7#![doc(test(no_crate_inject, attr(deny(rust_2018_idioms))))]
8#![cfg_attr(docsrs, feature(doc_auto_cfg))]
9
10use proc_macro2::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream};
11use quote::TokenStreamExt;
12
13// Prost functionality has been moved to tonic-prost-build
14
15pub mod manual;
16
17/// Service code generation for client
18mod client;
19/// Service code generation for Server
20mod server;
21
22mod code_gen;
23pub use code_gen::CodeGenBuilder;
24
25/// Service generation trait.
26///
27/// This trait can be implemented and consumed
28/// by `client::generate` and `server::generate`
29/// to allow any codegen module to generate service
30/// abstractions.
31pub trait Service {
32    /// Comment type.
33    type Comment: AsRef<str>;
34
35    /// Method type.
36    type Method: Method;
37
38    /// Name of service.
39    fn name(&self) -> &str;
40    /// Package name of service.
41    fn package(&self) -> &str;
42    /// Identifier used to generate type name.
43    fn identifier(&self) -> &str;
44    /// Methods provided by service.
45    fn methods(&self) -> &[Self::Method];
46    /// Get comments about this item.
47    fn comment(&self) -> &[Self::Comment];
48}
49
50/// Method generation trait.
51///
52/// Each service contains a set of generic
53/// `Methods`'s that will be used by codegen
54/// to generate abstraction implementations for
55/// the provided methods.
56pub trait Method {
57    /// Comment type.
58    type Comment: AsRef<str>;
59
60    /// Name of method.
61    fn name(&self) -> &str;
62    /// Identifier used to generate type name.
63    fn identifier(&self) -> &str;
64    /// Path to the codec.
65    fn codec_path(&self) -> &str;
66    /// Method is streamed by client.
67    fn client_streaming(&self) -> bool;
68    /// Method is streamed by server.
69    fn server_streaming(&self) -> bool;
70    /// Get comments about this item.
71    fn comment(&self) -> &[Self::Comment];
72    /// Method is deprecated.
73    fn deprecated(&self) -> bool {
74        false
75    }
76    /// Type name of request and response.
77    fn request_response_name(
78        &self,
79        proto_path: &str,
80        compile_well_known_types: bool,
81    ) -> (TokenStream, TokenStream);
82}
83
/// Extra attributes to attach to generated `mod`, `struct` and `trait` items.
///
/// Every entry is a `(pattern, attribute)` pair of strings.
#[derive(Clone, Debug, Default)]
pub struct Attributes {
    /// Pairs registered for `mod` items.
    module: Vec<(String, String)>,
    /// Pairs registered for `struct` items.
    structure: Vec<(String, String)>,
    /// Pairs registered for `trait` items.
    trait_attributes: Vec<(String, String)>,
}
94
95impl Attributes {
96    fn for_mod(&self, name: &str) -> Vec<syn::Attribute> {
97        generate_attributes(name, &self.module)
98    }
99
100    fn for_struct(&self, name: &str) -> Vec<syn::Attribute> {
101        generate_attributes(name, &self.structure)
102    }
103
104    fn for_trait(&self, name: &str) -> Vec<syn::Attribute> {
105        generate_attributes(name, &self.trait_attributes)
106    }
107
108    /// Add an attribute that will be added to `mod` items matching the given pattern.
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// # use tonic_build::*;
114    /// let mut attributes = Attributes::default();
115    /// attributes.push_mod("my.proto.package", r#"#[cfg(feature = "server")]"#);
116    /// ```
117    pub fn push_mod(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
118        self.module.push((pattern.into(), attr.into()));
119    }
120
121    /// Add an attribute that will be added to `struct` items matching the given pattern.
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// # use tonic_build::*;
127    /// let mut attributes = Attributes::default();
128    /// attributes.push_struct("EchoService", "#[derive(PartialEq)]");
129    /// ```
130    pub fn push_struct(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
131        self.structure.push((pattern.into(), attr.into()));
132    }
133
134    /// Add an attribute that will be added to `trait` items matching the given pattern.
135    ///
136    /// # Examples
137    ///
138    /// ```
139    /// # use tonic_build::*;
140    /// let mut attributes = Attributes::default();
141    /// attributes.push_trait("Server", "#[mockall::automock]");
142    /// ```
143    pub fn push_trait(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
144        self.trait_attributes.push((pattern.into(), attr.into()));
145    }
146}
147
148fn format_service_name<T: Service>(service: &T, emit_package: bool) -> String {
149    let package = if emit_package { service.package() } else { "" };
150    format!(
151        "{}{}{}",
152        package,
153        if package.is_empty() { "" } else { "." },
154        service.identifier(),
155    )
156}
157
158fn format_method_path<T: Service>(service: &T, method: &T::Method, emit_package: bool) -> String {
159    format!(
160        "/{}/{}",
161        format_service_name(service, emit_package),
162        method.identifier()
163    )
164}
165
166fn format_method_name<T: Service>(service: &T, method: &T::Method, emit_package: bool) -> String {
167    format!(
168        "{}.{}",
169        format_service_name(service, emit_package),
170        method.identifier()
171    )
172}
173
174// Generates attributes given a list of (`pattern`, `attribute`) pairs. If `pattern` matches `name`, `attribute` will be included.
175fn generate_attributes<'a>(
176    name: &str,
177    attrs: impl IntoIterator<Item = &'a (String, String)>,
178) -> Vec<syn::Attribute> {
179    attrs
180        .into_iter()
181        .filter(|(matcher, _)| match_name(matcher, name))
182        .flat_map(|(_, attr)| {
183            // attributes cannot be parsed directly, so we pretend they're on a struct
184            syn::parse_str::<syn::DeriveInput>(&format!("{attr}\nstruct fake;"))
185                .unwrap()
186                .attrs
187        })
188        .collect::<Vec<_>>()
189}
190
191fn generate_deprecated() -> TokenStream {
192    let mut deprecated_stream = TokenStream::new();
193    deprecated_stream.append(Ident::new("deprecated", Span::call_site()));
194
195    let group = Group::new(Delimiter::Bracket, deprecated_stream);
196
197    let mut stream = TokenStream::new();
198    stream.append(Punct::new('#', Spacing::Alone));
199    stream.append(group);
200
201    stream
202}
203
204// Generate a singular line of a doc comment
205fn generate_doc_comment<S: AsRef<str>>(comment: S) -> TokenStream {
206    let comment = comment.as_ref();
207
208    let comment = if !comment.starts_with(' ') {
209        format!(" {comment}")
210    } else {
211        comment.to_string()
212    };
213
214    let mut doc_stream = TokenStream::new();
215
216    doc_stream.append(Ident::new("doc", Span::call_site()));
217    doc_stream.append(Punct::new('=', Spacing::Alone));
218    doc_stream.append(Literal::string(comment.as_ref()));
219
220    let group = Group::new(Delimiter::Bracket, doc_stream);
221
222    let mut stream = TokenStream::new();
223    stream.append(Punct::new('#', Spacing::Alone));
224    stream.append(group);
225    stream
226}
227
228// Generate a larger doc comment composed of many lines of doc comments
229fn generate_doc_comments<T: AsRef<str>>(comments: &[T]) -> TokenStream {
230    let mut stream = TokenStream::new();
231
232    for comment in comments {
233        stream.extend(generate_doc_comment(comment));
234    }
235
236    stream
237}
238
// Checks whether a path pattern matches a given path.
//
// - An empty pattern matches nothing.
// - The pattern `"."` matches every path, and a pattern equal to the path always matches.
// - A pattern starting with `.` is a prefix match: its `.`-separated segments must equal the
//   leading segments of `path`.
// - Any other pattern is a suffix match: its segments must equal the trailing segments of `path`.
//
// Segments are compared whole, so `.m` does not match `.my.protos`.
pub(crate) fn match_name(pattern: &str, path: &str) -> bool {
    if pattern.is_empty() {
        return false;
    }
    if pattern == "." || pattern == path {
        return true;
    }

    let pattern_segments: Vec<&str> = pattern.split('.').collect();
    let path_segments: Vec<&str> = path.split('.').collect();

    // `starts_with('.')` inspects the first character rather than slicing at byte index 1,
    // which would panic when the pattern begins with a multi-byte character.
    // The slice `starts_with`/`ends_with` helpers return `false` when the pattern has more
    // segments than the path, so no separate length check is required.
    if pattern.starts_with('.') {
        path_segments.starts_with(&pattern_segments)
    } else {
        path_segments.ends_with(&pattern_segments)
    }
}
264
// Naive snake-case conversion: every character is ASCII-lowercased, and an underscore is
// inserted in front of every uppercase character except the very first character.
fn naive_snake_case(name: &str) -> String {
    let mut snake = String::new();

    for (position, ch) in name.chars().enumerate() {
        if position > 0 && ch.is_uppercase() {
            snake.push('_');
        }
        snake.push(ch.to_ascii_lowercase());
    }

    snake
}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_match_name() {
        // (pattern, path, whether the pattern is expected to match the path)
        let cases = [
            // The catch-all pattern.
            (".", ".my.protos", true),
            (".", ".protos", true),
            // Exact and prefix matches.
            (".my", ".my", true),
            (".my", ".my.protos", true),
            (".my.protos.Service", ".my.protos.Service", true),
            // Suffix match.
            ("Service", ".my.protos.Service", true),
            // Partial segments never match.
            (".m", ".my.protos", false),
            (".p", ".protos", false),
            (".my", ".myy", false),
            // A leading dot anchors the pattern at the start of the path.
            (".protos", ".my.protos", false),
            (".Service", ".my.protos.Service", false),
            // Matching is case-sensitive.
            ("service", ".my.protos.Service", false),
        ];

        for &(pattern, path, expected) in &cases {
            assert_eq!(
                match_name(pattern, path),
                expected,
                "pattern {pattern:?} against path {path:?}"
            );
        }
    }

    #[test]
    fn test_snake_case() {
        let cases = [
            ("Service", "service"),
            ("ThatHasALongName", "that_has_a_long_name"),
            ("greeter", "greeter"),
            ("ABCServiceX", "a_b_c_service_x"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(naive_snake_case(input), expected);
        }
    }
}
317}