tower_oauth2_resource_server/
server.rs1use core::fmt;
2use std::sync::Arc;
3
4use futures_util::future::join_all;
5use http::Request;
6use log::debug;
7use serde::de::DeserializeOwned;
8
9use crate::{
10 auth_resolver::AuthorizerResolver,
11 authorizer::token_authorizer::Authorizer,
12 claims::DefaultClaims,
13 error::{AuthError, StartupError},
14 jwt_extract::{BearerTokenJwtExtractor, JwtExtractor},
15 layer::OAuth2ResourceServerLayer,
16 tenant::TenantConfiguration,
17};
18
19#[derive(Clone)]
24pub struct OAuth2ResourceServer<Claims = DefaultClaims> {
25 authorizers: Vec<Authorizer<Claims>>,
26 jwt_extractor: Arc<dyn JwtExtractor + Send + Sync>,
27 auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
28}
29
30impl<Claims> OAuth2ResourceServer<Claims>
31where
32 Claims: Clone + DeserializeOwned + Send + Sync + 'static,
33{
34 pub(crate) async fn new(
35 tenant_configurations: Vec<TenantConfiguration>,
36 auth_resolver: Arc<dyn AuthorizerResolver<Claims>>,
37 ) -> Result<OAuth2ResourceServer<Claims>, StartupError> {
38 let authorizers = join_all(
39 tenant_configurations
40 .into_iter()
41 .map(Authorizer::<Claims>::new)
42 .collect::<Vec<_>>(),
43 )
44 .await
45 .into_iter()
46 .collect::<Result<Vec<_>, StartupError>>()?;
47
48 Ok(OAuth2ResourceServer {
49 jwt_extractor: Arc::new(BearerTokenJwtExtractor {}),
50 authorizers,
51 auth_resolver,
52 })
53 }
54
55 pub(crate) async fn authorize_request<Body>(
56 &self,
57 mut request: Request<Body>,
58 ) -> Result<Request<Body>, AuthError> {
59 let token = match self.jwt_extractor.extract_jwt(request.headers()) {
60 Ok(token) => token,
61 Err(e) => {
62 debug!("JWT extraction failed: {}", e);
63 return Err(e);
64 }
65 };
66 let authorizer = self
67 .auth_resolver
68 .as_ref()
69 .select_authorizer(&self.authorizers, request.headers(), &token)
70 .ok_or(AuthError::AuthorizerNotFound)?;
71 match authorizer.validate(&token) {
72 Ok(res) => {
73 debug!("JWT validation successful ({})", authorizer.identifier());
74 request.extensions_mut().insert(res);
75 Ok(request)
76 }
77 Err(e) => {
78 debug!(
79 "JWT validation failed ({}) : {}",
80 authorizer.identifier(),
81 e
82 );
83 Err(e)
84 }
85 }
86 }
87}
88
89impl<Claims> fmt::Debug for OAuth2ResourceServer<Claims>
90where
91 Claims: DeserializeOwned,
92{
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_struct("OAuth2AuthenticationManager").finish()
95 }
96}
97
98impl<Claims> OAuth2ResourceServer<Claims>
99where
100 Claims: Clone + DeserializeOwned,
101{
102 pub fn into_layer(&self) -> OAuth2ResourceServerLayer<Claims> {
104 OAuth2ResourceServerLayer::new(self.clone())
105 }
106}