1use bytes::Bytes;
6use http::{Extensions, Method, Uri, Version};
7use http_body::Body;
8use pin_project::pin_project;
9use std::{
10 fmt,
11 future::Future,
12 mem,
13 pin::Pin,
14 task::{Context, Poll},
15};
16use tonic::body::Body as TonicBody;
17use tonic::metadata::MetadataMap;
18use tonic::{Request, Status};
19use tower_layer::Layer;
20use tower_service::Service;
21
/// Boxed, thread-safe error type used for the inner service's errors and for response-body
/// errors.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
23
/// An interceptor that inspects or rewrites a request asynchronously before it reaches the
/// wrapped service.
///
/// The interceptor is handed a body-less [`Request<()>`] holding the incoming request's
/// metadata and extensions. Resolving to `Ok(request)` forwards that request's metadata and
/// extensions to the inner service. Resolving to `Err(status)` skips the inner service, and the
/// status is converted into the HTTP response instead.
pub trait AsyncInterceptor {
    /// Future produced by [`AsyncInterceptor::call`].
    type Future: Future<Output = Result<Request<()>, Status>>;

    /// Start intercepting `request`. See the trait documentation for how the result is used.
    fn call(&mut self, request: Request<()>) -> Self::Future;
}
59
60impl<F, U> AsyncInterceptor for F
61where
62 F: FnMut(Request<()>) -> U,
63 U: Future<Output = Result<Request<()>, Status>>,
64{
65 type Future = U;
66
67 fn call(&mut self, request: Request<()>) -> Self::Future {
68 self(request)
69 }
70}
71
72pub fn async_interceptor<F>(f: F) -> AsyncInterceptorLayer<F>
76where
77 F: AsyncInterceptor,
78{
79 AsyncInterceptorLayer { f }
80}
81
/// A [`Layer`] that wraps services in an [`AsyncInterceptedService`] driven by the interceptor
/// it holds. Created by [`async_interceptor`].
#[derive(Debug, Clone, Copy)]
pub struct AsyncInterceptorLayer<F> {
    // The interceptor; each wrapped service receives its own clone.
    f: F,
}
90
91impl<S, F> Layer<S> for AsyncInterceptorLayer<F>
92where
93 S: Clone,
94 F: AsyncInterceptor + Clone,
95{
96 type Service = AsyncInterceptedService<S, F>;
97
98 fn layer(&self, service: S) -> Self::Service {
99 AsyncInterceptedService::new(service, self.f.clone())
100 }
101}
102
/// The parts of an [`http::Request`] that are kept aside while the interceptor runs.
///
/// These are the URI, method and HTTP version, captured before conversion to a tonic
/// [`Request`] in `decompose`, together with the request body.
#[derive(Debug)]
struct DecomposedRequest<ReqBody> {
    uri: Uri,
    method: Method,
    http_version: Version,
    // The request body; the interceptor only ever sees a `Request<()>`.
    msg: ReqBody,
}
111
112fn request_into_parts<Msg>(mut req: Request<Msg>) -> (MetadataMap, Extensions, Msg) {
114 let metadata = mem::take(req.metadata_mut());
116 let extensions = mem::take(req.extensions_mut());
117 (metadata, extensions, req.into_inner())
118}
119
120fn request_from_parts<Msg>(
122 msg: Msg,
123 metadata: MetadataMap,
124 extensions: Extensions,
125) -> Request<Msg> {
126 let mut req = Request::new(msg);
127 *req.metadata_mut() = metadata;
128 *req.extensions_mut() = extensions;
129 req
130}
131
132fn request_into_http<Msg>(
134 msg: Msg,
135 uri: http::Uri,
136 method: http::Method,
137 version: http::Version,
138 metadata: MetadataMap,
139 extensions: Extensions,
140) -> http::Request<Msg> {
141 let mut request = http::Request::new(msg);
142 *request.version_mut() = version;
143 *request.method_mut() = method;
144 *request.uri_mut() = uri;
145 *request.headers_mut() = metadata.into_headers();
146 *request.extensions_mut() = extensions;
147
148 request
149}
150
151fn decompose<ReqBody>(req: http::Request<ReqBody>) -> (DecomposedRequest<ReqBody>, Request<()>) {
160 let uri = req.uri().clone();
161 let method = req.method().clone();
162 let http_version = req.version();
163 let req = Request::from_http(req);
164 let (metadata, extensions, msg) = request_into_parts(req);
165
166 let dreq = DecomposedRequest {
167 uri,
168 method,
169 http_version,
170 msg,
171 };
172 let req_without_body = request_from_parts((), metadata, extensions);
173
174 (dreq, req_without_body)
175}
176
177fn recompose<ReqBody>(
179 dreq: DecomposedRequest<ReqBody>,
180 modified_req: Request<()>,
181) -> http::Request<ReqBody> {
182 let (metadata, extensions, _) = request_into_parts(modified_req);
183
184 request_into_http(
185 dreq.msg,
186 dreq.uri,
187 dreq.method,
188 dreq.http_version,
189 metadata,
190 extensions,
191 )
192}
193
/// A service that runs an [`AsyncInterceptor`] on every request before forwarding it to the
/// wrapped service.
///
/// Built by [`AsyncInterceptorLayer`] or [`AsyncInterceptedService::new`].
#[derive(Clone, Copy)]
pub struct AsyncInterceptedService<S, F> {
    // The wrapped service.
    inner: S,
    // The interceptor applied to every request.
    f: F,
}
202
203impl<S, F> AsyncInterceptedService<S, F> {
204 pub fn new(service: S, f: F) -> Self {
207 Self { inner: service, f }
208 }
209}
210
211impl<S, F> fmt::Debug for AsyncInterceptedService<S, F>
212where
213 S: fmt::Debug,
214{
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 f.debug_struct("AsyncInterceptedService")
217 .field("inner", &self.inner)
218 .field("f", &format_args!("{}", std::any::type_name::<F>()))
219 .finish()
220 }
221}
222
223impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for AsyncInterceptedService<S, F>
224where
225 F: AsyncInterceptor + Clone,
226 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone,
227 S::Error: Into<Error>,
228 ReqBody: Default,
229 ResBody: Default + Body<Data = Bytes> + Send + 'static,
230 ResBody::Error: Into<Error>,
231{
232 type Response = http::Response<TonicBody>;
233 type Error = S::Error;
234 type Future = AsyncResponseFuture<S, F::Future, ReqBody>;
235
236 #[inline]
237 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
238 self.inner.poll_ready(cx)
239 }
240
241 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
242 let clone = self.inner.clone();
246 let inner = std::mem::replace(&mut self.inner, clone);
247
248 AsyncResponseFuture::new(req, &mut self.f, inner)
249 }
250}
251
/// Reports the inner service's `NAME` as its own. Presumably this lets the wrapper be routed
/// under the wrapped service's name; confirm against router usage.
impl<S, F> tonic::server::NamedService for AsyncInterceptedService<S, F>
where
    S: tonic::server::NamedService,
{
    const NAME: &'static str = S::NAME;
}
259
/// Future for the second stage of an intercepted call. It is either the inner service's
/// response future or an immediate response built from the interceptor's rejection
/// [`Status`].
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<F> {
    #[pin]
    kind: Kind<F>,
}
267
268impl<F> ResponseFuture<F> {
269 fn future(future: F) -> Self {
270 Self {
271 kind: Kind::Future(future),
272 }
273 }
274
275 fn status(status: Status) -> Self {
276 Self {
277 kind: Kind::Status(Some(status)),
278 }
279 }
280}
281
/// State of a [`ResponseFuture`].
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
    /// Polling the inner service's future.
    Future(#[pin] F),
    /// Rejected by the interceptor. The status is taken out on the first poll.
    Status(Option<Status>),
}
288
289impl<F, E, B> Future for ResponseFuture<F>
290where
291 F: Future<Output = Result<http::Response<B>, E>>,
292 E: Into<Error>,
293 B: Default + Body<Data = Bytes> + Send + 'static,
294 B::Error: Into<Error>,
295{
296 type Output = Result<http::Response<TonicBody>, E>;
297
298 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
299 match self.project().kind.project() {
300 KindProj::Future(future) => future
301 .poll(cx)
302 .map(|result| result.map(|resp| resp.map(TonicBody::new))),
303 KindProj::Status(status) => {
304 let response = status.take().unwrap().into_http();
305 Poll::Ready(Ok(response))
306 }
307 }
308 }
309}
310
/// An `Option`-like slot whose `Some` payload is structurally pinned.
///
/// A pinned future can be stored, polled in place, and later replaced or cleared through
/// `Pin::set`.
#[pin_project(project = PinnedOptionProj)]
#[derive(Debug)]
enum PinnedOption<F> {
    Some(#[pin] F),
    None,
}
317
/// Response future of [`AsyncInterceptedService`].
///
/// Runs in two stages. First the interceptor future (`interceptor_fut`) is driven to
/// completion. Its result is then used to create `inner_fut`, which is polled until the
/// response is ready.
#[pin_project(project = AsyncResponseFutureProj)]
#[derive(Debug)]
pub struct AsyncResponseFuture<S, I, ReqBody>
where
    S: Service<http::Request<ReqBody>>,
    S::Error: Into<Error>,
    I: Future<Output = Result<Request<()>, Status>>,
{
    // Stage 1: the interceptor future; reset to `None` once it has resolved.
    #[pin]
    interceptor_fut: PinnedOption<I>,
    // Stage 2: set once the interceptor has resolved.
    #[pin]
    inner_fut: PinnedOption<ResponseFuture<S::Future>>,
    // The service instance moved out of the `AsyncInterceptedService` after `poll_ready`;
    // called once the interceptor accepts the request.
    inner: S,
    // Request parts the interceptor does not see, held until recomposition.
    dreq: DecomposedRequest<ReqBody>,
}
337
338impl<S, I, ReqBody> AsyncResponseFuture<S, I, ReqBody>
339where
340 S: Service<http::Request<ReqBody>>,
341 S::Error: Into<Error>,
342 I: Future<Output = Result<Request<()>, Status>>,
343 ReqBody: Default,
344{
345 fn new<A: AsyncInterceptor<Future = I>>(
346 req: http::Request<ReqBody>,
347 interceptor: &mut A,
348 inner: S,
349 ) -> Self {
350 let (dreq, req_without_body) = decompose(req);
351 let interceptor_fut = interceptor.call(req_without_body);
352
353 AsyncResponseFuture {
354 interceptor_fut: PinnedOption::Some(interceptor_fut),
355 inner_fut: PinnedOption::None,
356 inner,
357 dreq,
358 }
359 }
360
361 fn create_inner_fut(
364 this: &mut AsyncResponseFutureProj<'_, S, I, ReqBody>,
365 intercepted_req: Result<Request<()>, Status>,
366 ) -> ResponseFuture<S::Future> {
367 match intercepted_req {
368 Ok(req) => {
369 let msg = mem::take(&mut this.dreq.msg);
373 let movable_dreq = DecomposedRequest {
374 uri: this.dreq.uri.clone(),
375 method: this.dreq.method.clone(),
376 http_version: this.dreq.http_version,
377 msg,
378 };
379 let modified_req_with_body = recompose(movable_dreq, req);
380
381 ResponseFuture::future(this.inner.call(modified_req_with_body))
382 }
383 Err(status) => ResponseFuture::status(status),
384 }
385 }
386}
387
388impl<S, I, ReqBody, ResBody> Future for AsyncResponseFuture<S, I, ReqBody>
389where
390 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
391 I: Future<Output = Result<Request<()>, Status>>,
392 S::Error: Into<Error>,
393 ReqBody: Default,
394 ResBody: Default + Body<Data = Bytes> + Send + 'static,
395 ResBody::Error: Into<Error>,
396{
397 type Output = Result<http::Response<TonicBody>, S::Error>;
398
399 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
400 let mut this = self.project();
401
402 if let PinnedOptionProj::Some(f) = this.interceptor_fut.as_mut().project() {
404 match f.poll(cx) {
405 Poll::Ready(intercepted_req) => {
406 let inner_fut = AsyncResponseFuture::<S, I, ReqBody>::create_inner_fut(
407 &mut this,
408 intercepted_req,
409 );
410 this.inner_fut.set(PinnedOption::Some(inner_fut));
412 this.interceptor_fut.set(PinnedOption::None);
413 }
414 Poll::Pending => return Poll::Pending,
415 }
416 }
417 let inner_fut = match this.inner_fut.project() {
419 PinnedOptionProj::None => panic!(),
420 PinnedOptionProj::Some(f) => f,
421 };
422
423 inner_fut.poll(cx)
424 }
425}
426
// Unit tests: each wraps a `tower::service_fn` leaf service with `async_interceptor` and drives
// a single request through it with `oneshot`.
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use http_body_util::Empty;
    use std::future;
    use tower::ServiceExt;

    // An extension inserted by the interceptor must be visible to the leaf service.
    #[tokio::test]
    async fn propagates_added_extensions() {
        #[derive(Clone)]
        struct TestExtension {
            data: String,
        }
        let test_extension_data = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.extensions_mut().insert(TestExtension {
                data: test_extension_data.to_owned(),
            });

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let maybe_extension = req.extensions().get::<TestExtension>();
                assert!(maybe_extension.is_some());
                assert_eq!(maybe_extension.unwrap().data, test_extension_data);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }

    // Metadata inserted by the interceptor must reach the leaf service.
    #[tokio::test]
    async fn propagates_added_metadata() {
        let test_metadata_key = "test_key";
        let test_metadata_val = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.metadata_mut()
                .insert(test_metadata_key, test_metadata_val.parse().unwrap());

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let maybe_metadata = req.metadata().get(test_metadata_key);
                assert!(maybe_metadata.is_some());
                assert_eq!(maybe_metadata.unwrap(), test_metadata_val);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }

    // A header present on the incoming request (here `user-agent`) must be visible both to the
    // interceptor, as metadata, and to the leaf service, as a header.
    #[tokio::test]
    async fn doesnt_remove_headers_from_request() {
        let layer = async_interceptor(|request: Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );
            future::ready(Ok(request))
        });

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(
                    request
                        .headers()
                        .get("user-agent")
                        .expect("missing in leaf service"),
                    "test-tonic"
                );

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    // An interceptor rejection must come back as a successful HTTP response that matches
    // `Status::into_http` for the same status.
    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http::<TonicBody>();

        let layer = async_interceptor(|_: Request<()>| {
            future::ready(Err(Status::permission_denied(message)))
        });

        let svc = layer.layer(tower::service_fn(|_: http::Request<Empty<Bytes>>| async {
            Ok::<_, Status>(http::Response::new(Empty::new()))
        }));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    // The HTTP method, which the interceptor never sees, must be restored unchanged.
    #[tokio::test]
    async fn doesnt_change_http_method() {
        let layer = async_interceptor(|request: Request<()>| future::ready(Ok(request)));

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(request.method(), http::Method::OPTIONS);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }
}