// tower_oauth2_resource_server/layer.rs

1use futures_util::{future::BoxFuture, Future};
2use http::{Request, Response};
3use pin_project::pin_project;
4use serde::de::DeserializeOwned;
5
6use std::{
7    pin::Pin,
8    task::{ready, Context, Poll},
9};
10use tower::{Layer, Service};
11
12use crate::{error::AuthError, server::OAuth2ResourceServer};
13
14trait Authorize<B> {
15    type Future: Future<Output = Result<Request<B>, AuthError>>;
16
17    fn authorize(&mut self, request: Request<B>) -> Self::Future;
18}
19
20impl<S, ReqBody, Claims> Authorize<ReqBody> for OAuth2ResourceServerService<S, Claims>
21where
22    Claims: DeserializeOwned + Clone + Send + Sync + 'static,
23    ReqBody: Send + 'static,
24{
25    type Future = BoxFuture<'static, Result<Request<ReqBody>, AuthError>>;
26
27    fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future {
28        let auth = self.auth_manager.clone();
29        Box::pin(async move { auth.authorize_request(request).await })
30    }
31}
32
33#[derive(Clone, Debug)]
34pub struct OAuth2ResourceServerLayer<Claims>
35where
36    Claims: DeserializeOwned,
37{
38    auth_manager: OAuth2ResourceServer<Claims>,
39}
40
41impl<S, Claims> Layer<S> for OAuth2ResourceServerLayer<Claims>
42where
43    Claims: Clone + DeserializeOwned + Send + 'static,
44{
45    type Service = OAuth2ResourceServerService<S, Claims>;
46
47    fn layer(&self, inner: S) -> Self::Service {
48        OAuth2ResourceServerService::new(inner, self.auth_manager.clone())
49    }
50}
51
52impl<Claims> OAuth2ResourceServerLayer<Claims>
53where
54    Claims: DeserializeOwned,
55{
56    pub(crate) fn new(auth_manager: OAuth2ResourceServer<Claims>) -> Self
57    where
58        Claims: DeserializeOwned,
59    {
60        OAuth2ResourceServerLayer { auth_manager }
61    }
62}
63
64#[derive(Clone, Debug)]
65pub struct OAuth2ResourceServerService<S, Claims>
66where
67    Claims: Clone + DeserializeOwned + Send + 'static,
68{
69    inner: S,
70    auth_manager: OAuth2ResourceServer<Claims>,
71}
72
73impl<S, Claims> OAuth2ResourceServerService<S, Claims>
74where
75    Claims: Clone + DeserializeOwned + Send + 'static,
76{
77    fn new(inner: S, auth_manager: OAuth2ResourceServer<Claims>) -> Self {
78        Self {
79            inner,
80            auth_manager,
81        }
82    }
83}
84
85impl<S, ReqBody, ResBody, Claims> Service<Request<ReqBody>>
86    for OAuth2ResourceServerService<S, Claims>
87where
88    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
89    ResBody: Default,
90    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
91    ReqBody: Send + 'static,
92{
93    type Response = S::Response;
94    type Error = S::Error;
95    type Future = ResponseFuture<S, ReqBody, Claims>;
96
97    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        self.inner.poll_ready(cx)
99    }
100
101    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
102        let inner = self.inner.clone();
103        let authorize = self.authorize(request);
104
105        ResponseFuture {
106            state: State::Authorize { authorize },
107            service: inner,
108        }
109    }
110}
111
112#[pin_project]
113pub struct ResponseFuture<S, ReqBody, Claims>
114where
115    S: Service<Request<ReqBody>>,
116    ReqBody: Send + 'static,
117    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
118{
119    #[pin]
120    state: State<<OAuth2ResourceServerService<S, Claims> as Authorize<ReqBody>>::Future, S::Future>,
121    service: S,
122}
123
124#[pin_project(project = StateProj)]
125enum State<A, SFut> {
126    Authorize {
127        #[pin]
128        authorize: A,
129    },
130    Authorized {
131        #[pin]
132        fut: SFut,
133    },
134}
135
136impl<S, ReqBody, B, Claims> Future for ResponseFuture<S, ReqBody, Claims>
137where
138    S: Service<Request<ReqBody>, Response = Response<B>>,
139    B: Default,
140    ReqBody: Send + 'static,
141    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
142{
143    type Output = Result<Response<B>, S::Error>;
144
145    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146        let mut this = self.project();
147
148        loop {
149            match this.state.as_mut().project() {
150                StateProj::Authorize { authorize } => {
151                    let auth = ready!(authorize.poll(cx));
152                    match auth {
153                        Ok(req) => {
154                            let fut = this.service.call(req);
155                            this.state.set(State::Authorized { fut })
156                        }
157                        Err(res) => {
158                            let response = Response::<B>::from(res);
159                            return Poll::Ready(Ok(response));
160                        }
161                    };
162                }
163                StateProj::Authorized { fut } => {
164                    return fut.poll(cx);
165                }
166            }
167        }
168    }
169}