tower_oauth2_resource_server/
layer.rs1use futures_util::{future::BoxFuture, Future};
2use http::{Request, Response};
3use pin_project::pin_project;
4use serde::de::DeserializeOwned;
5
6use std::{
7 pin::Pin,
8 task::{ready, Context, Poll},
9};
10use tower::{Layer, Service};
11
12use crate::{error::AuthError, server::OAuth2ResourceServer};
13
14trait Authorize<B> {
15 type Future: Future<Output = Result<Request<B>, AuthError>>;
16
17 fn authorize(&mut self, request: Request<B>) -> Self::Future;
18}
19
20impl<S, ReqBody, Claims> Authorize<ReqBody> for OAuth2ResourceServerService<S, Claims>
21where
22 Claims: DeserializeOwned + Clone + Send + Sync + 'static,
23 ReqBody: Send + 'static,
24{
25 type Future = BoxFuture<'static, Result<Request<ReqBody>, AuthError>>;
26
27 fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future {
28 let auth = self.auth_manager.clone();
29 Box::pin(async move { auth.authorize_request(request).await })
30 }
31}
32
33#[derive(Clone, Debug)]
34pub struct OAuth2ResourceServerLayer<Claims> {
35 auth_manager: OAuth2ResourceServer<Claims>,
36}
37
38impl<S, Claims> Layer<S> for OAuth2ResourceServerLayer<Claims>
39where
40 Claims: Clone + DeserializeOwned + Send + 'static,
41{
42 type Service = OAuth2ResourceServerService<S, Claims>;
43
44 fn layer(&self, inner: S) -> Self::Service {
45 OAuth2ResourceServerService::new(inner, self.auth_manager.clone())
46 }
47}
48
49impl<Claims> OAuth2ResourceServerLayer<Claims> {
50 pub(crate) fn new(auth_manager: OAuth2ResourceServer<Claims>) -> Self {
51 OAuth2ResourceServerLayer { auth_manager }
52 }
53}
54
55#[derive(Clone, Debug)]
56pub struct OAuth2ResourceServerService<S, Claims> {
57 inner: S,
58 auth_manager: OAuth2ResourceServer<Claims>,
59}
60
61impl<S, Claims> OAuth2ResourceServerService<S, Claims> {
62 fn new(inner: S, auth_manager: OAuth2ResourceServer<Claims>) -> Self {
63 Self {
64 inner,
65 auth_manager,
66 }
67 }
68}
69
70impl<S, ReqBody, ResBody, Claims> Service<Request<ReqBody>>
71 for OAuth2ResourceServerService<S, Claims>
72where
73 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
74 ResBody: Default,
75 Claims: Clone + DeserializeOwned + Send + Sync + 'static,
76 ReqBody: Send + 'static,
77{
78 type Response = S::Response;
79 type Error = S::Error;
80 type Future = ResponseFuture<S, ReqBody, Claims>;
81
82 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83 self.inner.poll_ready(cx)
84 }
85
86 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
87 let inner = self.inner.clone();
88 let authorize = self.authorize(request);
89
90 ResponseFuture {
91 state: State::Authorize { authorize },
92 service: inner,
93 }
94 }
95}
96
97#[pin_project]
98pub struct ResponseFuture<S, ReqBody, Claims>
99where
100 S: Service<Request<ReqBody>>,
101 ReqBody: Send + 'static,
102 Claims: Clone + DeserializeOwned + Send + Sync + 'static,
103{
104 #[pin]
105 state: State<<OAuth2ResourceServerService<S, Claims> as Authorize<ReqBody>>::Future, S::Future>,
106 service: S,
107}
108
109#[pin_project(project = StateProj)]
110enum State<A, SFut> {
111 Authorize {
112 #[pin]
113 authorize: A,
114 },
115 Authorized {
116 #[pin]
117 fut: SFut,
118 },
119}
120
121impl<S, ReqBody, B, Claims> Future for ResponseFuture<S, ReqBody, Claims>
122where
123 S: Service<Request<ReqBody>, Response = Response<B>>,
124 B: Default,
125 ReqBody: Send + 'static,
126 Claims: Clone + DeserializeOwned + Send + Sync + 'static,
127{
128 type Output = Result<Response<B>, S::Error>;
129
130 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
131 let mut this = self.project();
132
133 loop {
134 match this.state.as_mut().project() {
135 StateProj::Authorize { authorize } => {
136 let auth = ready!(authorize.poll(cx));
137 match auth {
138 Ok(req) => {
139 let fut = this.service.call(req);
140 this.state.set(State::Authorized { fut })
141 }
142 Err(res) => {
143 let response = Response::<B>::from(res);
144 return Poll::Ready(Ok(response));
145 }
146 };
147 }
148 StateProj::Authorized { fut } => {
149 return fut.poll(cx);
150 }
151 }
152 }
153 }
154}