tonic_middleware/middleware.rs
1use std::task::{Context, Poll};
2
3use crate::ServiceBound;
4use async_trait::async_trait;
5use futures_util::future::BoxFuture;
6use tonic::body::Body;
7use tonic::codegen::http::Request;
8use tonic::codegen::http::Response;
9use tonic::codegen::Service;
10use tonic::server::NamedService;
11use tower::Layer;
12
13/// The `Middleware` trait defines a generic interface for middleware components
14/// in a grpc service chain.
15/// Implementors of this trait can modify, observe, or otherwise interact with requests and
16/// responses in the service pipeline
17///
18/// If you need just intercept requests, pls can [RequestInterceptor]
19///
20/// # Type Parameters
21///
22/// * `S`: A service bound that defines the requirements for the service being wrapped by
23/// the middleware.
24///
25/// See [examples on GitHub](https://github.com/teimuraz/tonic-middleware/tree/main/example)
26#[async_trait]
27pub trait Middleware<S>
28where
29 S: ServiceBound,
30{
31 /// Processes an incoming request and forwards it to the given service.
32 ///
33 /// Implementations may perform operations before or after forwarding the request,
34 /// such as logging, metrics collection, or request modification.
35 ///
36 /// # Parameters
37 ///
38 /// * `req`: The incoming request to process.
39 /// * `service`: The service to forward the processed request to.
40 ///
41 /// # Returns
42 ///
43 /// A `Result` containing the response from the service or an error if one occurred
44 /// during processing.
45 async fn call(&self, req: Request<Body>, service: S) -> Result<Response<Body>, S::Error>;
46}
47
48/// `MiddlewareFor` is a service wrapper that pairs a middleware with its target service.
49///
50/// # Type Parameters
51///
52/// * `S`: The service that this middleware is wrapping.
53/// * `M`: The middleware that is being applied to the service.
54#[derive(Clone)]
55pub struct MiddlewareFor<S, M>
56where
57 S: ServiceBound,
58 M: Middleware<S>,
59{
60 pub inner: S,
61 pub middleware: M,
62}
63
64impl<S, M> MiddlewareFor<S, M>
65where
66 S: ServiceBound,
67 M: Middleware<S>,
68{
69 /// Constructs a new `MiddlewareFor` with the given service and middleware.
70 ///
71 /// # Parameters
72 ///
73 /// * `inner`: The service that this middleware is wrapping.
74 /// * `middleware`: The middleware that is being applied to the service.
75 pub fn new(inner: S, middleware: M) -> Self {
76 MiddlewareFor { inner, middleware }
77 }
78}
79
80impl<S, M> Service<Request<Body>> for MiddlewareFor<S, M>
81where
82 S: ServiceBound,
83 S::Future: Send,
84 M: Middleware<S> + Send + Clone + 'static + Sync,
85{
86 type Response = S::Response;
87 type Error = S::Error;
88 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
89
90 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91 self.inner.poll_ready(cx)
92 }
93
94 fn call(&mut self, req: Request<Body>) -> Self::Future {
95 let middleware = self.middleware.clone();
96 let inner = self.inner.clone();
97 Box::pin(async move { middleware.call(req, inner).await })
98 }
99}
100
101impl<S, M> NamedService for MiddlewareFor<S, M>
102where
103 S: NamedService + ServiceBound,
104 M: Middleware<S>,
105{
106 const NAME: &'static str = S::NAME;
107}
108
/// Tower [`Layer`] that wraps services in a [`MiddlewareFor`] using a given middleware.
#[derive(Clone)]
pub struct MiddlewareLayer<M> {
    /// Middleware cloned into each service this layer wraps.
    middleware: M,
}

impl<M> MiddlewareLayer<M> {
    /// Builds a layer that will apply `middleware` to each service it wraps.
    ///
    /// # Parameters
    ///
    /// * `middleware`: the middleware to apply to services.
    pub fn new(middleware: M) -> Self {
        Self { middleware }
    }
}
126
127impl<S, M> Layer<S> for MiddlewareLayer<M>
128where
129 S: ServiceBound,
130 M: Middleware<S> + Clone,
131{
132 type Service = MiddlewareFor<S, M>;
133
134 fn layer(&self, inner: S) -> Self::Service {
135 MiddlewareFor::new(inner, self.middleware.clone())
136 }
137}