// tower_oauth2_resource_server/server.rs

1use core::fmt;
2use std::sync::Arc;
3
4use futures_util::future::join_all;
5use http::Request;
6use log::debug;
7use serde::de::DeserializeOwned;
8
9use crate::{
10    auth_resolver::AuthorizerResolver,
11    authorizer::token_authorizer::Authorizer,
12    claims::DefaultClaims,
13    error::{AuthError, StartupError},
14    error_handler::{DefaultErrorHandler, ErrorHandler},
15    jwt_extract::{BearerTokenJwtExtractor, JwtExtractor},
16    layer::OAuth2ResourceServerLayer,
17    tenant::TenantConfiguration,
18};
19
20/// OAuth2ResourceServer
21///
22/// This is the actual middleware.
23/// May be turned into a tower layer by calling [into_layer](OAuth2ResourceServer::into_layer).
24#[derive(Clone)]
25pub struct OAuth2ResourceServer<Claims = DefaultClaims> {
26    authorizers: Vec<Authorizer<Claims>>,
27    jwt_extractor: Arc<dyn JwtExtractor + Send + Sync>,
28    auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
29}
30
31impl<Claims> OAuth2ResourceServer<Claims>
32where
33    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
34{
35    pub(crate) async fn new(
36        tenant_configurations: Vec<TenantConfiguration>,
37        auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
38    ) -> Result<OAuth2ResourceServer<Claims>, StartupError> {
39        let authorizers = join_all(
40            tenant_configurations
41                .into_iter()
42                .map(Authorizer::<Claims>::new)
43                .collect::<Vec<_>>(),
44        )
45        .await
46        .into_iter()
47        .collect::<Result<Vec<_>, StartupError>>()?;
48
49        Ok(OAuth2ResourceServer {
50            jwt_extractor: Arc::new(BearerTokenJwtExtractor {}),
51            authorizers,
52            auth_resolver,
53        })
54    }
55
56    pub(crate) async fn authorize_request<Body>(
57        &self,
58        mut request: Request<Body>,
59    ) -> Result<Request<Body>, AuthError> {
60        let token = match self.jwt_extractor.extract_jwt(request.headers()) {
61            Ok(token) => token,
62            Err(e) => {
63                debug!("JWT extraction failed: {}", e);
64                return Err(e);
65            }
66        };
67        let authorizer = self
68            .auth_resolver
69            .as_ref()
70            .select_authorizer(&self.authorizers, request.headers(), &token)
71            .ok_or(AuthError::AuthorizerNotFound)?;
72        match authorizer.validate(&token) {
73            Ok(res) => {
74                debug!("JWT validation successful ({})", authorizer.identifier());
75                request.extensions_mut().insert(res);
76                Ok(request)
77            }
78            Err(e) => {
79                debug!(
80                    "JWT validation failed ({}) : {}",
81                    authorizer.identifier(),
82                    e
83                );
84                Err(e)
85            }
86        }
87    }
88}
89
90impl<Claims> fmt::Debug for OAuth2ResourceServer<Claims> {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        f.debug_struct("OAuth2AuthenticationManager").finish()
93    }
94}
95
96impl<Claims> OAuth2ResourceServer<Claims>
97where
98    Claims: Clone,
99{
100    /// Returns a [tower layer](https://docs.rs/tower/latest/tower/trait.Layer.html).
101    pub fn into_layer<ResBody>(&self) -> OAuth2ResourceServerLayer<ResBody, Claims>
102    where
103        ResBody: Default,
104    {
105        OAuth2ResourceServerLayer::new(self.clone(), Arc::new(DefaultErrorHandler))
106    }
107
108    /// Returns a [tower layer](https://docs.rs/tower/latest/tower/trait.Layer.html) that uses a custom [ErrorHandler] implementation.
109    pub fn into_layer_with_error_handler<ResBody>(
110        &self,
111        error_handler: Arc<dyn ErrorHandler<ResBody>>,
112    ) -> OAuth2ResourceServerLayer<ResBody, Claims> {
113        OAuth2ResourceServerLayer::new(self.clone(), error_handler)
114    }
115}