1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
use std::task::{Context, Poll};

use crate::ServiceBound;
use async_trait::async_trait;
use futures_util::future::BoxFuture;
use tonic::body::BoxBody;
use tonic::codegen::http::Request;
use tonic::codegen::Service;
use tonic::server::NamedService;
use tonic::Status;
use tower::Layer;

/// The `RequestInterceptor` trait enables the interception and processing of incoming requests
/// within your service pipeline. It is particularly useful for operations such as
/// authentication, enriching requests with additional metadata, or rejecting requests based on
/// certain criteria before they reach the service logic.
///
/// If your requirements extend beyond request interception — for example, you need to interact
/// with both the request and the response, or to perform actions after the service call has
/// been made — consider implementing `Middleware` instead.
///
/// See [examples on GitHub](https://github.com/teimuraz/tonic-middleware/tree/main/example)
#[async_trait]
pub trait RequestInterceptor {
    /// Intercepts an incoming request, allowing for inspection, modification, or early rejection
    /// with a `Status` error.
    ///
    /// # Parameters
    ///
    /// * `req`: The incoming `Request` to be intercepted.
    ///
    /// # Returns
    ///
    /// Returns either the potentially modified request for further processing, or a `Status`
    /// error to halt processing with a specific error response.
    ///
    /// # Errors
    ///
    /// Return `Err(status)` to reject the request. When the interceptor is applied through
    /// `InterceptorFor`, the status is converted into an HTTP response and the wrapped service
    /// is not called.
    async fn intercept(&self, req: Request<BoxBody>) -> Result<Request<BoxBody>, Status>;
}

/// `InterceptorFor` wraps a service with a [`RequestInterceptor`], enabling request-level
/// interception before the request reaches the service logic.
///
/// The struct itself places no bounds on its type parameters. The trait requirements live on
/// the `impl` blocks that actually need them, so the wrapper can be named, built and cloned
/// without repeating the bounds.
///
/// # Type Parameters
///
/// * `S`: The service being wrapped.
/// * `I`: The `RequestInterceptor` that will preprocess the requests.
#[derive(Clone)]
pub struct InterceptorFor<S, I> {
    /// The wrapped service.
    pub inner: S,
    /// The interceptor applied to incoming requests before they are forwarded to `inner`.
    pub interceptor: I,
}

impl<S, I> InterceptorFor<S, I> {
    /// Creates a new `InterceptorFor` with the provided service and interceptor.
    ///
    /// # Parameters
    ///
    /// * `inner`: The service being wrapped.
    /// * `interceptor`: The interceptor that will preprocess the requests.
    pub fn new(inner: S, interceptor: I) -> Self {
        Self { inner, interceptor }
    }
}

impl<S, I> Service<Request<BoxBody>> for InterceptorFor<S, I>
where
    S: ServiceBound,
    S::Future: Send,
    I: RequestInterceptor + Send + Clone + 'static + Sync,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated to the wrapped service; the interceptor is not consulted.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Runs the interceptor on `req`, then forwards the (possibly modified) request to the
    /// wrapped service.
    ///
    /// If the interceptor returns `Err(status)`, the status is converted into an HTTP response
    /// via `Status::into_http` and the wrapped service is not called.
    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        let interceptor = self.interceptor.clone();
        // `poll_ready` drove `self.inner` to readiness, not any clone of it. Move that ready
        // service into the future and leave a fresh clone in its place for the next
        // `poll_ready`/`call` cycle. Calling a never-polled clone would violate the
        // `Service` contract.
        let fresh = self.inner.clone();
        let mut ready_inner = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            match interceptor.intercept(req).await {
                Ok(req) => ready_inner.call(req).await,
                Err(status) => Ok(status.into_http()),
            }
        })
    }
}

/// Exposes the wrapped service's name unchanged, so `InterceptorFor<S, I>` reports the same
/// `NamedService::NAME` as `S`.
impl<S, I> NamedService for InterceptorFor<S, I>
where
    I: RequestInterceptor,
    S: NamedService,
{
    const NAME: &'static str = <S as NamedService>::NAME;
}

/// A tower [`Layer`] that wraps services in an `InterceptorFor` using a specific interceptor.
///
/// # Type Parameters
///
/// * `I`: The `RequestInterceptor` implementation.
#[derive(Clone)]
pub struct RequestInterceptorLayer<I> {
    /// Interceptor that is cloned into every service this layer wraps.
    interceptor: I,
}

impl<I> RequestInterceptorLayer<I> {
    /// Builds a layer around the given interceptor.
    ///
    /// # Parameters
    ///
    /// * `interceptor`: The interceptor to apply to services.
    pub fn new(interceptor: I) -> Self {
        Self { interceptor }
    }
}

/// Wraps each inner service in an `InterceptorFor` that owns its own clone of the layer's
/// interceptor.
impl<S, I> Layer<S> for RequestInterceptorLayer<I>
where
    I: Clone + RequestInterceptor,
{
    type Service = InterceptorFor<S, I>;

    fn layer(&self, inner: S) -> Self::Service {
        let interceptor = I::clone(&self.interceptor);
        InterceptorFor::new(inner, interceptor)
    }
}