trapeze_codegen/
service_generator.rs

1use std::collections::HashMap;
2
3use prost_build::{Comments, Method, Service, ServiceGenerator};
4
/// A [`ServiceGenerator`] that turns a protobuf service descriptor into Rust code for a `ttrpc`
/// service.
///
/// For every service it emits a trait describing the service's methods, and implements that
/// trait for a `trapeze::Client`. To build a server, implement the generated trait on your own
/// type.
///
/// The generator is stateless, so it is cheap to copy and can be default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct TtrpcServiceGenerator;
10
11impl ServiceGenerator for TtrpcServiceGenerator {
12    fn generate(&mut self, service: Service, buf: &mut String) {
13        let mut substitutions = service_substitutions(&service);
14
15        let make_client_method = |m| make_client_method(substitutions.clone(), m);
16        let make_trait_method = |m| make_trait_method(substitutions.clone(), m);
17        let make_dispatch_branch = |m| make_dispatch_branch(substitutions.clone(), m);
18
19        let methods = service.methods;
20
21        let client_methods: String = methods.iter().map(make_client_method).collect();
22        let trait_methods: String = methods.iter().map(make_trait_method).collect();
23        let dispatch_branches: String = methods.iter().map(make_dispatch_branch).collect();
24
25        substitutions.insert("client_methods", client_methods);
26        substitutions.insert("trait_methods", trait_methods);
27        substitutions.insert("dispatch_branches", dispatch_branches);
28
29        let service = replace(include_str!("../templates/service.rs"), substitutions);
30
31        buf.push_str(&service);
32    }
33}
34
35fn make_trait_method(mut substitutions: HashMap<&'static str, String>, method: &Method) -> String {
36    substitutions.extend(method_substitutions(method));
37
38    replace(include_str!("../templates/trait_method.rs"), substitutions)
39}
40
41fn make_dispatch_branch(
42    mut substitutions: HashMap<&'static str, String>,
43    method: &Method,
44) -> String {
45    substitutions.extend(method_substitutions(method));
46
47    replace(
48        include_str!("../templates/dispatch_branch.rs"),
49        substitutions,
50    )
51}
52
53fn make_client_method(mut substitutions: HashMap<&'static str, String>, method: &Method) -> String {
54    substitutions.extend(method_substitutions(method));
55
56    replace(include_str!("../templates/client_method.rs"), substitutions)
57}
58
59fn service_substitutions(service: &Service) -> HashMap<&'static str, String> {
60    let mut substitutions = HashMap::default();
61    substitutions.insert("service_comments", format_comments(&service.comments, 0));
62    substitutions.insert("service_name", service.name.clone());
63    substitutions.insert("service_package", service.package.clone());
64    substitutions.insert("service_proto_name", service.proto_name.clone());
65    substitutions.insert("service_module_name", camel2snake(&service.name));
66    substitutions
67}
68
69fn method_substitutions(method: &Method) -> HashMap<&'static str, String> {
70    let mut substitutions = HashMap::default();
71    let Method {
72        name,
73        proto_name,
74        input_type,
75        output_type,
76        client_streaming,
77        server_streaming,
78        comments,
79        ..
80    } = method;
81
82    let input_name = camel2snake(input_type);
83
84    let wrapper = match (*client_streaming, *server_streaming) {
85        (false, false) => "UnaryMethod",
86        (false, true) => "ServerStreamingMethod",
87        (true, false) => "ClientStreamingMethod",
88        (true, true) => "DuplexStreamingMethod",
89    };
90
91    let request_handler = match (*client_streaming, *server_streaming) {
92        (false, false) => "handle_unary_request",
93        (false, true) => "handle_server_streaming_request",
94        (true, false) => "handle_client_streaming_request",
95        (true, true) => "handle_duplex_streaming_request",
96    };
97
98    let input_type = if *client_streaming {
99        stream_for(input_type)
100    } else {
101        input_type.clone()
102    };
103
104    let output_type = if *server_streaming {
105        fallible_stream_for(output_type)
106    } else {
107        fallible_future_for(output_type)
108    };
109
110    let output_handler = if *server_streaming {
111        stream_handler(name)
112    } else {
113        future_handler(name)
114    };
115
116    let not_found = if *server_streaming {
117        not_found_stream()
118    } else {
119        not_found_future()
120    };
121
122    substitutions.insert("method_comments", format_comments(comments, 1));
123    substitutions.insert("method_name", name.clone());
124    substitutions.insert("method_proto_name", proto_name.clone());
125    substitutions.insert("method_input_name", input_name);
126    substitutions.insert("method_input_type", input_type);
127    substitutions.insert("method_output_type", output_type);
128    substitutions.insert("method_wrapper", wrapper.to_string());
129    substitutions.insert("method_request_handler", request_handler.to_string());
130    substitutions.insert("method_output_handler", output_handler);
131    substitutions.insert("method_not_found", not_found);
132    substitutions
133}
134
135fn format_comments(comments: &Comments, indent_level: u8) -> String {
136    let mut formatted = String::new();
137    comments.append_with_indent(indent_level, &mut formatted);
138    formatted
139}
140
/// Returns the `impl Future<Output = ty> + Send` type expression (via `trapeze::prelude`).
fn future_for(ty: &str) -> String {
    format!("impl trapeze::prelude::Future<Output = {}> + Send", ty)
}
144
145fn fallible_future_for(ty: &str) -> String {
146    future_for(&format!("trapeze::Result<{ty}>"))
147}
148
/// Returns the `impl Stream<Item = ty> + Send` type expression (via `trapeze::prelude`).
fn stream_for(ty: &str) -> String {
    format!("impl trapeze::prelude::Stream<Item = {item}> + Send", item = ty)
}
152
153fn fallible_stream_for(ty: &str) -> String {
154    stream_for(&format!("trapeze::Result<{ty}>"))
155}
156
/// Expands every `__key__` placeholder in `src` whose `key` is present in `substitutions`.
///
/// The template is scanned once, left to right. Text inserted for a placeholder is never
/// rescanned. As a result, a substituted value that happens to contain `__other_key__` (e.g.
/// a proto doc comment, or an already-rendered fragment) is emitted verbatim. The output is
/// deterministic regardless of the `HashMap`'s iteration order.
///
/// Placeholders whose key is not in `substitutions` are left unchanged. If several keys could
/// match at the same position, the longest key wins.
fn replace(src: impl Into<String>, substitutions: HashMap<&'static str, String>) -> String {
    let src = src.into();
    let mut out = String::with_capacity(src.len());
    let mut rest = src.as_str();

    while let Some(start) = rest.find("__") {
        out.push_str(&rest[..start]);
        let tail = &rest[start..];

        // Find a key `k` such that `tail` starts with `__k__`.
        let hit = substitutions
            .iter()
            .filter_map(|(key, value)| {
                let after = tail[2..].strip_prefix(*key)?.strip_prefix("__")?;
                Some((key.len(), value, after))
            })
            .max_by_key(|&(len, _, _)| len);

        match hit {
            Some((_, value, after)) => {
                out.push_str(value);
                rest = after;
            }
            None => {
                // Not a known placeholder: emit one underscore and retry one byte later, so
                // runs like `___key__` still match at the correct offset. `_` is ASCII,
                // so slicing at 1 stays on a char boundary.
                out.push('_');
                rest = &tail[1..];
            }
        }
    }

    out.push_str(rest);
    out
}
164
/// Returns an `async move` block expression that awaits `target.<method_name>(input)`.
fn future_handler(method_name: &str) -> String {
    format!("async move {{ target.{}(input).await }}", method_name)
}
168
/// Returns a `trapeze::stream::stream!` expression that re-yields every item produced by
/// `target.<method_name>(input)`.
fn stream_handler(method_name: &str) -> String {
    format!(
        "trapeze::stream::stream! {{ for await value in target.{name}(input) {{ yield value; }} }}",
        name = method_name
    )
}
172
/// Returns an `async move` block expression that resolves to `Err(not_found)`.
fn not_found_future() -> String {
    String::from("async move { Err(not_found) }")
}
176
/// Returns a `trapeze::stream::stream!` expression that yields a single `Err(not_found)`.
fn not_found_stream() -> String {
    "trapeze::stream::stream! { yield Err(not_found); }".to_owned()
}
180
/// Converts the last `::`-separated segment of `name` from CamelCase to snake_case.
///
/// Every uppercase character except the first is preceded by `_`, and every character is
/// lowercased. Consecutive capitals are therefore split individually (`HTTPService` →
/// `h_t_t_p_service`).
fn camel2snake(name: impl AsRef<str>) -> String {
    // `rsplit` always yields at least one segment, so the default is never used.
    let segment = name.as_ref().rsplit("::").next().unwrap_or_default();

    let mut snake = String::with_capacity(segment.len() + 4);
    for (index, ch) in segment.chars().enumerate() {
        if index != 0 && ch.is_uppercase() {
            snake.push('_');
        }
        snake.extend(ch.to_lowercase());
    }
    snake
}