trapeze_codegen/
service_generator.rs1use std::collections::HashMap;
2
3use prost_build::{Comments, Method, Service, ServiceGenerator};
4
/// Service generator that renders services from the bundled source templates.
///
/// Implements `prost_build::ServiceGenerator`; its `generate` method appends
/// the expanded service template to the output buffer.
5pub struct TtrpcServiceGenerator;
10
11impl ServiceGenerator for TtrpcServiceGenerator {
12 fn generate(&mut self, service: Service, buf: &mut String) {
13 let mut substitutions = service_substitutions(&service);
14
15 let make_client_method = |m| make_client_method(substitutions.clone(), m);
16 let make_trait_method = |m| make_trait_method(substitutions.clone(), m);
17 let make_dispatch_branch = |m| make_dispatch_branch(substitutions.clone(), m);
18
19 let methods = service.methods;
20
21 let client_methods: String = methods.iter().map(make_client_method).collect();
22 let trait_methods: String = methods.iter().map(make_trait_method).collect();
23 let dispatch_branches: String = methods.iter().map(make_dispatch_branch).collect();
24
25 substitutions.insert("client_methods", client_methods);
26 substitutions.insert("trait_methods", trait_methods);
27 substitutions.insert("dispatch_branches", dispatch_branches);
28
29 let service = replace(include_str!("../templates/service.rs"), substitutions);
30
31 buf.push_str(&service);
32 }
33}
34
35fn make_trait_method(mut substitutions: HashMap<&'static str, String>, method: &Method) -> String {
36 substitutions.extend(method_substitutions(method));
37
38 replace(include_str!("../templates/trait_method.rs"), substitutions)
39}
40
41fn make_dispatch_branch(
42 mut substitutions: HashMap<&'static str, String>,
43 method: &Method,
44) -> String {
45 substitutions.extend(method_substitutions(method));
46
47 replace(
48 include_str!("../templates/dispatch_branch.rs"),
49 substitutions,
50 )
51}
52
53fn make_client_method(mut substitutions: HashMap<&'static str, String>, method: &Method) -> String {
54 substitutions.extend(method_substitutions(method));
55
56 replace(include_str!("../templates/client_method.rs"), substitutions)
57}
58
59fn service_substitutions(service: &Service) -> HashMap<&'static str, String> {
60 let mut substitutions = HashMap::default();
61 substitutions.insert("service_comments", format_comments(&service.comments, 0));
62 substitutions.insert("service_name", service.name.clone());
63 substitutions.insert("service_package", service.package.clone());
64 substitutions.insert("service_proto_name", service.proto_name.clone());
65 substitutions.insert("service_module_name", camel2snake(&service.name));
66 substitutions
67}
68
69fn method_substitutions(method: &Method) -> HashMap<&'static str, String> {
70 let mut substitutions = HashMap::default();
71 let Method {
72 name,
73 proto_name,
74 input_type,
75 output_type,
76 client_streaming,
77 server_streaming,
78 comments,
79 ..
80 } = method;
81
82 let input_name = camel2snake(input_type);
83
84 let wrapper = match (*client_streaming, *server_streaming) {
85 (false, false) => "UnaryMethod",
86 (false, true) => "ServerStreamingMethod",
87 (true, false) => "ClientStreamingMethod",
88 (true, true) => "DuplexStreamingMethod",
89 };
90
91 let request_handler = match (*client_streaming, *server_streaming) {
92 (false, false) => "handle_unary_request",
93 (false, true) => "handle_server_streaming_request",
94 (true, false) => "handle_client_streaming_request",
95 (true, true) => "handle_duplex_streaming_request",
96 };
97
98 let input_type = if *client_streaming {
99 stream_for(input_type)
100 } else {
101 input_type.clone()
102 };
103
104 let output_type = if *server_streaming {
105 fallible_stream_for(output_type)
106 } else {
107 fallible_future_for(output_type)
108 };
109
110 let output_handler = if *server_streaming {
111 stream_handler(name)
112 } else {
113 future_handler(name)
114 };
115
116 let not_found = if *server_streaming {
117 not_found_stream()
118 } else {
119 not_found_future()
120 };
121
122 substitutions.insert("method_comments", format_comments(comments, 1));
123 substitutions.insert("method_name", name.clone());
124 substitutions.insert("method_proto_name", proto_name.clone());
125 substitutions.insert("method_input_name", input_name);
126 substitutions.insert("method_input_type", input_type);
127 substitutions.insert("method_output_type", output_type);
128 substitutions.insert("method_wrapper", wrapper.to_string());
129 substitutions.insert("method_request_handler", request_handler.to_string());
130 substitutions.insert("method_output_handler", output_handler);
131 substitutions.insert("method_not_found", not_found);
132 substitutions
133}
134
135fn format_comments(comments: &Comments, indent_level: u8) -> String {
136 let mut formatted = String::new();
137 comments.append_with_indent(indent_level, &mut formatted);
138 formatted
139}
140
/// Returns the source text of an `impl Future` type whose `Output` is `ty`,
/// with a `Send` bound.
fn future_for(ty: &str) -> String {
    ["impl trapeze::prelude::Future<Output = ", ty, "> + Send"].concat()
}
144
145fn fallible_future_for(ty: &str) -> String {
146 future_for(&format!("trapeze::Result<{ty}>"))
147}
148
/// Returns the source text of an `impl Stream` type whose `Item` is `ty`,
/// with a `Send` bound.
fn stream_for(ty: &str) -> String {
    ["impl trapeze::prelude::Stream<Item = ", ty, "> + Send"].concat()
}
152
153fn fallible_stream_for(ty: &str) -> String {
154 stream_for(&format!("trapeze::Result<{ty}>"))
155}
156
/// Expands every `__key__` placeholder in `src` with the value stored under
/// `key` in `substitutions`.
///
/// The template is scanned once, left to right. Text that is inserted for a
/// placeholder is never rescanned, so a value that happens to contain
/// another placeholder (for example inside a proto comment) is emitted
/// verbatim and the result does not depend on the map's iteration order.
/// Placeholders whose key is not in the map are left untouched.
fn replace(src: impl Into<String>, substitutions: HashMap<&'static str, String>) -> String {
    const DELIMITER: &str = "__";

    let src = src.into();
    let mut out = String::with_capacity(src.len());
    let mut rest = src.as_str();

    while let Some(start) = rest.find(DELIMITER) {
        out.push_str(&rest[..start]);

        let candidate = &rest[start..];
        let after_open = &candidate[DELIMITER.len()..];

        // Prefer the longest key so the outcome is unambiguous even if one
        // placeholder is a prefix of another.
        let hit = substitutions
            .iter()
            .filter(|(key, _)| {
                after_open.starts_with(**key) && after_open[key.len()..].starts_with(DELIMITER)
            })
            .max_by_key(|(key, _)| key.len());

        match hit {
            Some((key, value)) => {
                out.push_str(value);
                rest = &after_open[key.len() + DELIMITER.len()..];
            }
            None => {
                // Not a known placeholder: keep one underscore and retry from
                // the next byte so that `___key__` still matches at offset 1.
                out.push('_');
                rest = &candidate[1..];
            }
        }
    }

    out.push_str(rest);
    out
}
164
/// Returns the source text of an `async move` block that awaits
/// `target.<method_name>(input)`.
fn future_handler(method_name: &str) -> String {
    ["async move { target.", method_name, "(input).await }"].concat()
}
168
/// Returns the source text of a `trapeze::stream::stream!` block that
/// re-yields every value produced by `target.<method_name>(input)`.
fn stream_handler(method_name: &str) -> String {
    [
        "trapeze::stream::stream! { for await value in target.",
        method_name,
        "(input) { yield value; } }",
    ]
    .concat()
}
172
/// Returns the source text of an `async move` block that evaluates to
/// `Err(not_found)`.
///
/// NOTE(review): the `not_found` identifier is presumably bound by the
/// template that embeds this snippet — confirm against the templates.
fn not_found_future() -> String {
    String::from("async move { Err(not_found) }")
}
176
/// Returns the source text of a `trapeze::stream::stream!` block that
/// yields a single `Err(not_found)`.
///
/// NOTE(review): the `not_found` identifier is presumably bound by the
/// template that embeds this snippet — confirm against the templates.
fn not_found_stream() -> String {
    String::from("trapeze::stream::stream! { yield Err(not_found); }")
}
180
/// Converts the last `::`-separated segment of `name` to snake case.
///
/// Every uppercase character except the first character of the segment is
/// preceded by an underscore, and all characters are lowercased. Runs of
/// capitals are split letter by letter (`HTTPServer` becomes
/// `h_t_t_p_server`).
fn camel2snake(name: impl AsRef<str>) -> String {
    let path = name.as_ref();
    // `split` always yields at least one item, so the fallback is never used;
    // it only keeps this function free of a panic path.
    let ident = path.split("::").last().unwrap_or(path);

    let mut snake = String::with_capacity(ident.len());
    for (index, ch) in ident.chars().enumerate() {
        if index > 0 && ch.is_uppercase() {
            snake.push('_');
        }
        // `to_lowercase` may expand to several chars for some code points.
        snake.extend(ch.to_lowercase());
    }
    snake
}