// wiremock_grpc_macros/lib.rs
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
    braced,
    ext::IdentExt,
    parse::{Parse, ParseStream},
    punctuated::Punctuated,
    Ident, Result, Token,
};

11#[proc_macro]
94pub fn generate_svc(input: TokenStream) -> TokenStream {
95 let service_def = syn::parse_macro_input!(input as ServiceDefinition);
96 service_def.generate().into()
97}
98
99struct ServiceDefinition {
100 package: String,
101 service_name: Ident,
102 server_name: Ident,
103 methods: Punctuated<Ident, Token![,]>,
104}
105
106impl Parse for ServiceDefinition {
107 fn parse(input: ParseStream) -> Result<Self> {
108 let _package_kw: Ident = input.parse()?;
110 if _package_kw != "package" {
111 return Err(syn::Error::new(
112 _package_kw.span(),
113 "expected `package` keyword",
114 ));
115 }
116
117 let first: Ident = input.parse()?;
119 let mut package = first.to_string();
120 while input.peek(Token![.]) {
121 let _dot: Token![.] = input.parse()?;
122 let next: Ident = input.parse()?;
123 package.push('.');
124 package.push_str(&next.to_string());
125 }
126
127 let _semi: Token![;] = input.parse()?;
128
129 let _service_kw: Ident = input.parse()?;
131 if _service_kw != "service" {
132 return Err(syn::Error::new(
133 _service_kw.span(),
134 "expected `service` keyword",
135 ));
136 }
137 let service_name: Ident = input.parse()?;
138
139 let server_name = if input.peek(Token![as]) {
140 let _as: Token![as] = input.parse()?;
141 input.parse()?
142 } else {
143 format_ident!("{}MockServer", service_name)
144 };
145
146 let content;
147 braced!(content in input);
148
149 let methods = content.parse_terminated(Ident::parse, Token![,])?;
150
151 Ok(ServiceDefinition {
152 package,
153 service_name,
154 server_name,
155 methods,
156 })
157 }
158}
159
160impl ServiceDefinition {
161 fn generate(&self) -> TokenStream2 {
162 let ext_trait = self.generate_ext_trait();
163 let mock_server = self.generate_mock_server();
164
165 quote! {
166 #ext_trait
167 #mock_server
168 }
169 }
170
171 fn generate_ext_trait(&self) -> TokenStream2 {
172 let trait_name = format_ident!("{}TypeSafeExt", self.service_name);
173 let package = &self.package;
174 let service_name = &self.service_name;
175
176 let method_signatures: Vec<_> = self
177 .methods
178 .iter()
179 .map(|method| {
180 let fn_name = format_ident!("path_{}", to_snake_case(&method.to_string()));
181 quote! {
182 fn #fn_name(&self) -> Self;
183 }
184 })
185 .collect();
186
187 let method_impls: Vec<_> = self
188 .methods
189 .iter()
190 .map(|method| {
191 let fn_name = format_ident!("path_{}", to_snake_case(&method.to_string()));
192 let path = format!("/{}.{}/{}", package, service_name, method);
193 quote! {
194 fn #fn_name(&self) -> Self {
195 #[expect(deprecated)]
196 self.path(#path)
197 }
198 }
199 })
200 .collect();
201
202 quote! {
203 pub trait #trait_name {
204 #(#method_signatures)*
205 }
206
207 impl #trait_name for wiremock_grpc::WhenBuilder {
208 #(#method_impls)*
209 }
210 }
211 }
212
213 fn generate_mock_server(&self) -> TokenStream2 {
214 let server_name = &self.server_name;
215 let package = &self.package;
216 let service_name = &self.service_name;
217 let prefix = format!("{}.{}", package, service_name);
218
219 quote! {
220 #[derive(Clone)]
221 pub struct #server_name(wiremock_grpc::GrpcServer);
222
223 impl ::std::ops::Deref for #server_name {
224 type Target = wiremock_grpc::GrpcServer;
225
226 fn deref(&self) -> &Self::Target {
227 &self.0
228 }
229 }
230
231 impl ::std::ops::DerefMut for #server_name {
232 fn deref_mut(&mut self) -> &mut Self::Target {
233 &mut self.0
234 }
235 }
236
237 impl<B> wiremock_grpc::tonic::codegen::Service<wiremock_grpc::tonic::codegen::http::Request<B>> for #server_name
238 where
239 B: wiremock_grpc::http_body::Body + Send + 'static,
240 B::Error: Into<wiremock_grpc::tonic::codegen::StdError> + Send + 'static,
241 {
242 type Response = wiremock_grpc::tonic::codegen::http::Response<wiremock_grpc::tonic::body::Body>;
243 type Error = ::std::convert::Infallible;
244 type Future = wiremock_grpc::tonic::codegen::BoxFuture<Self::Response, Self::Error>;
245
246 fn poll_ready(
247 &mut self,
248 _cx: &mut ::std::task::Context<'_>,
249 ) -> ::std::task::Poll<Result<(), Self::Error>> {
250 ::std::task::Poll::Ready(Ok(()))
251 }
252
253 fn call(&mut self, req: wiremock_grpc::tonic::codegen::http::Request<B>) -> Self::Future {
254 self.0.handle_request(req)
255 }
256 }
257
258 impl wiremock_grpc::tonic::server::NamedService for #server_name {
259 const NAME: &'static str = #prefix;
260 }
261
262 impl #server_name {
263 pub async fn start_default() -> Self {
264 let port = wiremock_grpc::GrpcServer::find_unused_port()
265 .await
266 .expect("Unable to find an open port");
267
268 Self(wiremock_grpc::GrpcServer::new(port)).start_internal().await
269 }
270
271 pub async fn start(port: u16) -> Self {
272 Self(wiremock_grpc::GrpcServer::new(port)).start_internal().await
273 }
274
275 pub async fn start_with_addr(addr: ::std::net::SocketAddr) -> Self {
276 Self(wiremock_grpc::GrpcServer::with_addr(addr)).start_internal().await
277 }
278
279 async fn start_internal(&mut self) -> Self {
280 let address = self.address().clone();
281 let thread = ::tokio::spawn(
282 wiremock_grpc::tonic::transport::Server::builder()
283 .add_service(self.clone())
284 .serve(address),
285 );
286 self._start(thread).await;
287 self.to_owned()
288 }
289 }
290 }
291 }
292}
293
/// Converts an `UpperCamelCase` or `camelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase letter in two cases:
/// - the preceding character is not uppercase (`GetWeather` -> `get_weather`);
/// - the letter ends an acronym and starts a new word, i.e. the next
///   character is lowercase (`HTTPServer` -> `http_server`).
///
/// No underscore is inserted at the very start. None is inserted when the
/// output already ends in one either (`Get_Weather` -> `get_weather`, not
/// `get__weather`). A doubled underscore would make the generated `path_*`
/// method trigger rustc's `non_snake_case` lint.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    let mut chars = s.chars().peekable();
    let mut prev: Option<char> = None;

    while let Some(ch) = chars.next() {
        if ch.is_uppercase() {
            if let Some(p) = prev {
                let next_is_lower = matches!(chars.peek(), Some(n) if n.is_lowercase());
                let starts_word = !p.is_uppercase() || next_is_lower;
                if starts_word && !result.ends_with('_') {
                    result.push('_');
                }
            }
            // `to_lowercase` may yield more than one char for some Unicode letters.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
        prev = Some(ch);
    }
    result
}

#[test]
fn test_to_snake_case() {
    // (input, expected snake_case output)
    let cases = [
        ("HTTPServer", "http_server"),
        ("GetWeather", "get_weather"),
        ("nothing", "nothing"),
        ("", ""),
    ];
    for &(input, expected) in &cases {
        assert_eq!(to_snake_case(input), expected, "input: {input:?}");
    }
}