tower_oauth2_resource_server/
server.rs

1use core::fmt;
2use std::sync::Arc;
3
4use futures_util::future::join_all;
5use http::Request;
6use log::debug;
7use serde::de::DeserializeOwned;
8
9use crate::{
10    auth_resolver::AuthorizerResolver,
11    authorizer::token_authorizer::Authorizer,
12    claims::DefaultClaims,
13    error::{AuthError, StartupError},
14    error_handler::{DefaultErrorHandler, ErrorHandler},
15    jwt_resolver::{BearerTokenResolver, DefaultBearerTokenResolver, request_ref},
16    layer::OAuth2ResourceServerLayer,
17    tenant::TenantConfiguration,
18};
19
20/// OAuth2ResourceServer
21///
22/// This is the actual middleware.
23/// May be turned into a tower layer by calling [into_layer](OAuth2ResourceServer::into_layer).
24#[derive(Clone)]
25pub struct OAuth2ResourceServer<Claims = DefaultClaims> {
26    authorizers: Vec<Authorizer<Claims>>,
27    bearer_token_resolver: Arc<dyn BearerTokenResolver + Send + Sync>,
28    auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
29}
30
31impl<Claims> OAuth2ResourceServer<Claims>
32where
33    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
34{
35    pub(crate) async fn new(
36        tenant_configurations: Vec<TenantConfiguration>,
37        auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
38        bearer_token_resolver: Option<Arc<dyn BearerTokenResolver + Send + Sync>>,
39    ) -> Result<OAuth2ResourceServer<Claims>, StartupError> {
40        let authorizers = join_all(
41            tenant_configurations
42                .into_iter()
43                .map(Authorizer::<Claims>::new)
44                .collect::<Vec<_>>(),
45        )
46        .await
47        .into_iter()
48        .collect::<Result<Vec<_>, StartupError>>()?;
49
50        Ok(OAuth2ResourceServer {
51            bearer_token_resolver: bearer_token_resolver
52                .unwrap_or_else(|| Arc::new(DefaultBearerTokenResolver {})),
53            authorizers,
54            auth_resolver,
55        })
56    }
57
58    pub(crate) async fn authorize_request<Body>(
59        &self,
60        mut request: Request<Body>,
61    ) -> Result<Request<Body>, AuthError> {
62        let req_ref = request_ref(&request);
63        let token = match self.bearer_token_resolver.resolve(&req_ref) {
64            Ok(token) => token,
65            Err(e) => {
66                debug!("JWT extraction failed: {}", e);
67                return Err(e);
68            }
69        };
70        let authorizer = self
71            .auth_resolver
72            .as_ref()
73            .select_authorizer(&self.authorizers, request.headers(), &token)
74            .ok_or(AuthError::AuthorizerNotFound)?;
75        match authorizer.validate(&token) {
76            Ok(res) => {
77                debug!("JWT validation successful ({})", authorizer.identifier());
78                request.extensions_mut().insert(res);
79                Ok(request)
80            }
81            Err(e) => {
82                debug!(
83                    "JWT validation failed ({}) : {}",
84                    authorizer.identifier(),
85                    e
86                );
87                Err(e)
88            }
89        }
90    }
91}
92
93impl<Claims> fmt::Debug for OAuth2ResourceServer<Claims> {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.debug_struct("OAuth2AuthenticationManager").finish()
96    }
97}
98
99impl<Claims> OAuth2ResourceServer<Claims>
100where
101    Claims: Clone,
102{
103    /// Returns a [tower layer](https://docs.rs/tower/latest/tower/trait.Layer.html).
104    pub fn into_layer<ResBody>(&self) -> OAuth2ResourceServerLayer<ResBody, Claims>
105    where
106        ResBody: Default,
107    {
108        OAuth2ResourceServerLayer::new(self.clone(), Arc::new(DefaultErrorHandler))
109    }
110
111    /// Returns a [tower layer](https://docs.rs/tower/latest/tower/trait.Layer.html) that uses a custom [ErrorHandler] implementation.
112    pub fn into_layer_with_error_handler<ResBody>(
113        &self,
114        error_handler: Arc<dyn ErrorHandler<ResBody>>,
115    ) -> OAuth2ResourceServerLayer<ResBody, Claims> {
116        OAuth2ResourceServerLayer::new(self.clone(), error_handler)
117    }
118}