tonic_middleware/
middleware.rs1use std::task::{Context, Poll};
2
3use crate::ServiceBound;
4use async_trait::async_trait;
5use futures_util::future::BoxFuture;
6use tonic::body::BoxBody;
7use tonic::codegen::http::Request;
8use tonic::codegen::http::Response;
9use tonic::codegen::Service;
10use tonic::server::NamedService;
11use tower::Layer;
12
13#[async_trait]
27pub trait Middleware<S>
28where
29 S: ServiceBound,
30{
31 async fn call(&self, req: Request<BoxBody>, service: S) -> Result<Response<BoxBody>, S::Error>;
46}
47
48#[derive(Clone)]
55pub struct MiddlewareFor<S, M>
56where
57 S: ServiceBound,
58 M: Middleware<S>,
59{
60 pub inner: S,
61 pub middleware: M,
62}
63
64impl<S, M> MiddlewareFor<S, M>
65where
66 S: ServiceBound,
67 M: Middleware<S>,
68{
69 pub fn new(inner: S, middleware: M) -> Self {
76 MiddlewareFor { inner, middleware }
77 }
78}
79
80impl<S, M> Service<Request<BoxBody>> for MiddlewareFor<S, M>
81where
82 S: ServiceBound,
83 S::Future: Send,
84 M: Middleware<S> + Send + Clone + 'static + Sync,
85{
86 type Response = S::Response;
87 type Error = S::Error;
88 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
89
90 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91 self.inner.poll_ready(cx)
92 }
93
94 fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
95 let middleware = self.middleware.clone();
96 let inner = self.inner.clone();
97 Box::pin(async move { middleware.call(req, inner).await })
98 }
99}
100
101impl<S, M> NamedService for MiddlewareFor<S, M>
102where
103 S: NamedService + ServiceBound,
104 M: Middleware<S>,
105{
106 const NAME: &'static str = S::NAME;
107}
108
/// A tower [`Layer`] that wraps services in [`MiddlewareFor`] using a stored
/// middleware value.
#[derive(Clone)]
pub struct MiddlewareLayer<M> {
    /// Middleware cloned into every service produced by this layer.
    middleware: M,
}
115
116impl<M> MiddlewareLayer<M> {
117 pub fn new(middleware: M) -> Self {
123 MiddlewareLayer { middleware }
124 }
125}
126
127impl<S, M> Layer<S> for MiddlewareLayer<M>
128where
129 S: ServiceBound,
130 M: Middleware<S> + Clone,
131{
132 type Service = MiddlewareFor<S, M>;
133
134 fn layer(&self, inner: S) -> Self::Service {
135 MiddlewareFor::new(inner, self.middleware.clone())
136 }
137}