// tower_oauth2_resource_server/layer.rs

1use futures_util::{future::BoxFuture, Future};
2use http::{Request, Response};
3use pin_project::pin_project;
4use serde::de::DeserializeOwned;
5
6use std::{
7    pin::Pin,
8    task::{ready, Context, Poll},
9};
10use tower::{Layer, Service};
11
12use crate::{error::AuthError, server::OAuth2ResourceServer};
13
14trait Authorize<B> {
15    type Future: Future<Output = Result<Request<B>, AuthError>>;
16
17    fn authorize(&mut self, request: Request<B>) -> Self::Future;
18}
19
20impl<S, ReqBody, Claims> Authorize<ReqBody> for OAuth2ResourceServerService<S, Claims>
21where
22    Claims: DeserializeOwned + Clone + Send + Sync + 'static,
23    ReqBody: Send + 'static,
24{
25    type Future = BoxFuture<'static, Result<Request<ReqBody>, AuthError>>;
26
27    fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future {
28        let auth = self.auth_manager.clone();
29        Box::pin(async move { auth.authorize_request(request).await })
30    }
31}
32
33#[derive(Clone, Debug)]
34pub struct OAuth2ResourceServerLayer<Claims> {
35    auth_manager: OAuth2ResourceServer<Claims>,
36}
37
38impl<S, Claims> Layer<S> for OAuth2ResourceServerLayer<Claims>
39where
40    Claims: Clone + DeserializeOwned + Send + 'static,
41{
42    type Service = OAuth2ResourceServerService<S, Claims>;
43
44    fn layer(&self, inner: S) -> Self::Service {
45        OAuth2ResourceServerService::new(inner, self.auth_manager.clone())
46    }
47}
48
49impl<Claims> OAuth2ResourceServerLayer<Claims> {
50    pub(crate) fn new(auth_manager: OAuth2ResourceServer<Claims>) -> Self {
51        OAuth2ResourceServerLayer { auth_manager }
52    }
53}
54
55#[derive(Clone, Debug)]
56pub struct OAuth2ResourceServerService<S, Claims> {
57    inner: S,
58    auth_manager: OAuth2ResourceServer<Claims>,
59}
60
61impl<S, Claims> OAuth2ResourceServerService<S, Claims> {
62    fn new(inner: S, auth_manager: OAuth2ResourceServer<Claims>) -> Self {
63        Self {
64            inner,
65            auth_manager,
66        }
67    }
68}
69
70impl<S, ReqBody, ResBody, Claims> Service<Request<ReqBody>>
71    for OAuth2ResourceServerService<S, Claims>
72where
73    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
74    ResBody: Default,
75    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
76    ReqBody: Send + 'static,
77{
78    type Response = S::Response;
79    type Error = S::Error;
80    type Future = ResponseFuture<S, ReqBody, Claims>;
81
82    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83        self.inner.poll_ready(cx)
84    }
85
86    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
87        let inner = self.inner.clone();
88        let authorize = self.authorize(request);
89
90        ResponseFuture {
91            state: State::Authorize { authorize },
92            service: inner,
93        }
94    }
95}
96
97#[pin_project]
98pub struct ResponseFuture<S, ReqBody, Claims>
99where
100    S: Service<Request<ReqBody>>,
101    ReqBody: Send + 'static,
102    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
103{
104    #[pin]
105    state: State<<OAuth2ResourceServerService<S, Claims> as Authorize<ReqBody>>::Future, S::Future>,
106    service: S,
107}
108
109#[pin_project(project = StateProj)]
110enum State<A, SFut> {
111    Authorize {
112        #[pin]
113        authorize: A,
114    },
115    Authorized {
116        #[pin]
117        fut: SFut,
118    },
119}
120
121impl<S, ReqBody, B, Claims> Future for ResponseFuture<S, ReqBody, Claims>
122where
123    S: Service<Request<ReqBody>, Response = Response<B>>,
124    B: Default,
125    ReqBody: Send + 'static,
126    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
127{
128    type Output = Result<Response<B>, S::Error>;
129
130    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
131        let mut this = self.project();
132
133        loop {
134            match this.state.as_mut().project() {
135                StateProj::Authorize { authorize } => {
136                    let auth = ready!(authorize.poll(cx));
137                    match auth {
138                        Ok(req) => {
139                            let fut = this.service.call(req);
140                            this.state.set(State::Authorized { fut })
141                        }
142                        Err(res) => {
143                            let response = Response::<B>::from(res);
144                            return Poll::Ready(Ok(response));
145                        }
146                    };
147                }
148                StateProj::Authorized { fut } => {
149                    return fut.poll(cx);
150                }
151            }
152        }
153    }
154}