1#![warn(missing_docs)]
3#![cfg_attr(feature = "cargo-clippy", allow(clippy::style))]
4
5use core::task;
6use core::pin::Pin;
7use core::future::Future;
8
/// Name of the response header from which the gRPC status code is read.
const GRPC_STATUS_HEADER_CODE: &str = "grpc-status";
10
11pub trait Interceptor {
13 fn on_request(&self, headers: &mut tonic::metadata::MetadataMap, extensions: &mut http::Extensions) -> Option<tonic::Status>;
19
20 fn on_response(&self, status: tonic::Code, _headers: &mut http::HeaderMap, _extensions: &http::Extensions);
22}
23
24impl<I: Interceptor> Interceptor for std::sync::Arc<I> {
25 #[inline(always)]
26 fn on_request(&self, headers: &mut tonic::metadata::MetadataMap, extensions: &mut http::Extensions) -> Option<tonic::Status> {
27 Interceptor::on_request(self.as_ref(), headers, extensions)
28 }
29
30 #[inline(always)]
31 fn on_response(&self, status: tonic::Code, headers: &mut http::HeaderMap, extensions: &http::Extensions) {
32 Interceptor::on_response(self.as_ref(), status, headers, extensions)
33 }
34}
35
/// Tower layer that holds an interceptor and attaches a clone of it to every
/// service it wraps.
#[repr(transparent)]
#[derive(Clone)]
pub struct InterceptorLayer<I>(I);
40
41impl<S, I: Interceptor + Clone> tower_layer::Layer<S> for InterceptorLayer<I> {
42 type Service = InterceptorService<I, S>;
43
44 #[inline(always)]
45 fn layer(&self, inner: S) -> Self::Service {
46 InterceptorService::new(self.0.clone(), inner)
47 }
48}
49
/// Service that runs an interceptor's hooks around calls to the wrapped
/// service.
///
/// `Clone` is derived (available when both `I` and `S` are `Clone`) because
/// tower middleware is routinely cloned by the servers and stacks hosting it.
#[derive(Clone)]
pub struct InterceptorService<I, S> {
    /// Interceptor whose hooks are invoked for every request.
    interceptor: I,
    /// Wrapped service that handles requests the interceptor lets through.
    inner: S,
}
55
56impl<I, S> InterceptorService<I, S> {
57 #[inline(always)]
58 pub fn new(interceptor: I, inner: S) -> Self {
60 Self {
61 interceptor,
62 inner
63 }
64 }
65}
66
67impl<ReqBody, ResBody: Default, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>, I: Interceptor + Clone> tower_service::Service<http::Request<ReqBody>> for InterceptorService<I, S> {
68 type Response = S::Response;
69 type Error = S::Error;
70 type Future = InterceptorFut<I, S::Future>;
71
72 #[inline(always)]
73 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
74 self.inner.poll_ready(cx)
75 }
76
77 #[inline(always)]
78 fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
79 let (mut parts, body) = req.into_parts();
80
81 let mut headers = tonic::metadata::MetadataMap::from_headers(parts.headers);
82 match self.interceptor.on_request(&mut headers, &mut parts.extensions) {
83 None => {
84 parts.headers = headers.into_headers();
85 req = http::Request::from_parts(parts, body);
86 InterceptorFut::fut(self.interceptor.clone(), self.inner.call(req))
87 }
88 Some(status) => InterceptorFut::status(self.interceptor.clone(), status),
89 }
90 }
91}
92
93pub struct InterceptorFut<I, F> {
95 interceptor: I,
96 inner: Result<F, tonic::Status>,
97}
98
99impl<I, F> InterceptorFut<I, F> {
100 #[inline(always)]
101 fn status(interceptor: I, status: tonic::Status) -> Self {
102 Self {
103 interceptor,
104 inner: Err(status),
105 }
106 }
107
108 #[inline(always)]
109 fn fut(interceptor: I, fut: F) -> Self {
110 Self {
111 interceptor,
112 inner: Ok(fut),
113 }
114 }
115}
116
117
118impl<ResBody: Default, E, I: Interceptor, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for InterceptorFut<I, F> {
119 type Output = F::Output;
120
121 fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
122 let (intercepter, fut) = unsafe {
123 let this = self.get_unchecked_mut();
124 let fut = match this.inner.as_mut() {
125 Ok(fut) => Pin::new_unchecked(fut),
126 Err(status) => {
127 let mut resp = http::Response::new(Default::default());
128 resp.headers_mut().insert(http::header::CONTENT_TYPE, http::header::HeaderValue::from_static("application/grpc"));
129 let _ = status.add_header(resp.headers_mut());
130 return task::Poll::Ready(Ok(resp));
131 }
132 };
133 (&this.interceptor, fut)
134 };
135 match Future::poll(fut, ctx) {
136 task::Poll::Ready(Result::Ok(resp)) => {
137 let (mut parts, body) = resp.into_parts();
138
139 let status = parts.headers.get(GRPC_STATUS_HEADER_CODE).map(|header| tonic::Code::from_bytes(header.as_bytes())).unwrap_or(tonic::Code::Unknown);
140
141 intercepter.on_response(status, &mut parts.headers, &parts.extensions);
142 task::Poll::Ready(Ok(http::Response::from_parts(parts, body)))
143 },
144 task::Poll::Ready(Result::Err(error)) => task::Poll::Ready(Err(error)),
145 task::Poll::Pending => task::Poll::Pending,
146 }
147 }
148}
149
/// Interceptor assembled from a pair of callables, one per hook.
#[derive(Clone)]
pub struct InterceptorFn<OnReq, OnResp> {
    /// Callable invoked for the `on_request` hook.
    pub on_request: OnReq,
    /// Callable invoked for the `on_response` hook.
    pub on_response: OnResp,
}
158
159impl<OnReq: Fn(&mut tonic::metadata::MetadataMap, &mut http::Extensions) -> Option<tonic::Status>, OnResp: Fn(tonic::Code, &mut http::HeaderMap, &http::Extensions)> Interceptor for InterceptorFn<OnReq, OnResp> {
160
161 #[inline(always)]
162 fn on_request(&self, headers: &mut tonic::metadata::MetadataMap, extensions: &mut http::Extensions) -> Option<tonic::Status> {
163 (self.on_request)(headers, extensions)
164 }
165
166 #[inline(always)]
167 fn on_response(&self, status: tonic::Code, headers: &mut http::HeaderMap, extensions: &http::Extensions) {
168 (self.on_response)(status, headers, extensions)
169 }
170}
171
172#[inline(always)]
173pub fn interceptor<I: Interceptor>(interceptor: I) -> InterceptorLayer<I> {
175 InterceptorLayer(interceptor)
176}