tower_oauth2_resource_server/
server.rs

1use core::fmt;
2use std::sync::Arc;
3
4use futures_util::future::join_all;
5use http::Request;
6use log::debug;
7use serde::de::DeserializeOwned;
8
9use crate::{
10    auth_resolver::AuthorizerResolver,
11    authorizer::token_authorizer::Authorizer,
12    claims::DefaultClaims,
13    error::{AuthError, StartupError},
14    jwt_extract::{BearerTokenJwtExtractor, JwtExtractor},
15    layer::OAuth2ResourceServerLayer,
16    tenant::TenantConfiguration,
17};
18
19/// OAuth2ResourceServer
20///
21/// This is the actual middleware.
22/// May be turned into a tower layer by calling [into_layer](OAuth2ResourceServer::into_layer).
23#[derive(Clone)]
24pub struct OAuth2ResourceServer<Claims = DefaultClaims> {
25    authorizers: Vec<Authorizer<Claims>>,
26    jwt_extractor: Arc<dyn JwtExtractor + Send + Sync>,
27    auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
28}
29
30impl<Claims> OAuth2ResourceServer<Claims>
31where
32    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
33{
34    pub(crate) async fn new(
35        tenant_configurations: Vec<TenantConfiguration>,
36        auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
37    ) -> Result<OAuth2ResourceServer<Claims>, StartupError> {
38        let authorizers = join_all(
39            tenant_configurations
40                .into_iter()
41                .map(Authorizer::<Claims>::new)
42                .collect::<Vec<_>>(),
43        )
44        .await
45        .into_iter()
46        .collect::<Result<Vec<_>, StartupError>>()?;
47
48        Ok(OAuth2ResourceServer {
49            jwt_extractor: Arc::new(BearerTokenJwtExtractor {}),
50            authorizers,
51            auth_resolver,
52        })
53    }
54
55    pub(crate) async fn authorize_request<Body>(
56        &self,
57        mut request: Request<Body>,
58    ) -> Result<Request<Body>, AuthError> {
59        let token = match self.jwt_extractor.extract_jwt(request.headers()) {
60            Ok(token) => token,
61            Err(e) => {
62                debug!("JWT extraction failed: {}", e);
63                return Err(e);
64            }
65        };
66        let authorizer = self
67            .auth_resolver
68            .as_ref()
69            .select_authorizer(&self.authorizers, request.headers(), &token)
70            .ok_or(AuthError::AuthorizerNotFound)?;
71        match authorizer.validate(&token) {
72            Ok(res) => {
73                debug!("JWT validation successful ({})", authorizer.identifier());
74                request.extensions_mut().insert(res);
75                Ok(request)
76            }
77            Err(e) => {
78                debug!(
79                    "JWT validation failed ({}) : {}",
80                    authorizer.identifier(),
81                    e
82                );
83                Err(e)
84            }
85        }
86    }
87}
88
89impl<Claims> fmt::Debug for OAuth2ResourceServer<Claims>
90where
91    Claims: DeserializeOwned,
92{
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_struct("OAuth2AuthenticationManager").finish()
95    }
96}
97
98impl<Claims> OAuth2ResourceServer<Claims>
99where
100    Claims: Clone + DeserializeOwned,
101{
102    /// Returns a [tower layer](https://docs.rs/tower/latest/tower/trait.Layer.html).
103    pub fn into_layer(&self) -> OAuth2ResourceServerLayer<Claims> {
104        OAuth2ResourceServerLayer::new(self.clone())
105    }
106}