tower_oauth2_resource_server/
tenant.rs

1use std::{collections::HashSet, time::Duration};
2
3use jsonwebtoken::{Algorithm, jwk::JwkSet};
4use mockall_double::double;
5use url::Url;
6
7use crate::{error::StartupError, oidc::OidcConfig, validation::ClaimsValidationSpec};
8
9#[double]
10use crate::oidc::OidcDiscovery;
11
12/// Returns the default set of allowed algorithms for JWT validation.
13///
14/// This includes all standard asymmetric algorithms that are considered secure.
15/// HMAC algorithms (HS256, HS384, HS512) are excluded by default as they are
16/// symmetric and typically not appropriate for OAuth2/OIDC flows where the
17/// authorization server and resource server are separate entities.
18pub fn default_allowed_algorithms() -> HashSet<Algorithm> {
19    HashSet::from([
20        Algorithm::RS256,
21        Algorithm::RS384,
22        Algorithm::RS512,
23        Algorithm::ES256,
24        Algorithm::ES384,
25        Algorithm::PS256,
26        Algorithm::PS384,
27        Algorithm::PS512,
28        Algorithm::EdDSA,
29    ])
30}
31
32#[derive(Debug, Clone)]
33pub(crate) enum TenantKind {
34    JwksUrl {
35        jwks_url: Url,
36        jwks_refresh_interval: Duration,
37    },
38    Static {
39        jwks: JwkSet,
40    },
41}
42
43#[derive(Debug, Clone)]
44pub struct TenantConfiguration {
45    pub(crate) identifier: String,
46    pub(crate) claims_validation_spec: ClaimsValidationSpec,
47    pub(crate) allowed_algorithms: HashSet<Algorithm>,
48    pub(crate) kind: TenantKind,
49}
50
51impl TenantConfiguration {
52    /// Build a tenant configuration for issuer_url (what authorization server to use).
53    ///
54    /// On startup, the OIDC Provider Configuration endpoint of the
55    /// authorization server will be queried in order to
56    /// self-configure the middleware.
57    ///
58    /// If `issuer_url` is set to `https://authorization-server.com/issuer`,
59    /// at least one of the following endpoints need to available.
60    ///
61    /// - `https://authorization-server.com/issuer/.well-known/openid-configuration`
62    /// - `https://authorization-server.com/.well-known/openid-configuration/issuer`
63    /// - `https://authorization-server.com/.well-known/oauth-authorization-server/issuer`
64    ///
65    /// A consequence of the self-configuration is that the authorization server
66    /// must be available when the middleware is started.
67    /// In cases where the middleware must be able to start independently from
68    /// the authorization server, the `jwks_url` property can be set.
69    /// This will prevent the self-configuration on start up.
70    pub fn builder(issuer_url: impl Into<String>) -> TenantConfigurationBuilder {
71        TenantConfigurationBuilder::new(issuer_url)
72    }
73
74    /// Build a tenant configuration for a static JWK Set
75    ///
76    /// Format of `jwks` must follow the "JWK Set Format" as defined in  RFC 7517
77    pub fn static_builder(jwks: impl Into<String>) -> TenantStaticConfigurationBuilder {
78        TenantStaticConfigurationBuilder::new(jwks)
79    }
80}
81
82pub struct TenantConfigurationBuilder {
83    issuer_url: String,
84    identifier: Option<String>,
85    jwks_url: Option<String>,
86    audiences: Vec<String>,
87    jwk_set_refresh_interval: Option<Duration>,
88    claims_validation_spec: Option<ClaimsValidationSpec>,
89    allowed_algorithms: Option<HashSet<Algorithm>>,
90}
91
92impl TenantConfigurationBuilder {
93    fn new(issuer_url: impl Into<String>) -> Self {
94        Self {
95            issuer_url: issuer_url.into(),
96            identifier: None,
97            jwks_url: None,
98            audiences: Vec::new(),
99            jwk_set_refresh_interval: None,
100            claims_validation_spec: None,
101            allowed_algorithms: None,
102        }
103    }
104
105    /// Set an identifier for the tenant.
106    ///
107    /// Can be accessed on a [Authorizer](crate::authorizer::token_authorizer::Authorizer) in
108    /// order to identify what authorization server the authorizer is configured for.
109    ///
110    /// Defaults to `issuer_url`.
111    pub fn identifier(mut self, identifier: &str) -> Self {
112        self.identifier = Some(identifier.to_string());
113        self
114    }
115
116    /// Set the jwks_url (what url to query valid public keys from).
117    ///
118    /// This url is normally fetched by calling the OIDC Provider Configuration endpoint
119    /// of the authorization server.
120    /// Only provide this property if the middleware must be able to start
121    /// independently from the authorization server.
122    pub fn jwks_url(mut self, jwks_url: impl Into<String>) -> Self {
123        self.jwks_url = Some(jwks_url.into());
124        self
125    }
126
127    /// Set the expected audiences.
128    ///
129    /// Used to validate `aud` claim of JWTs.
130    pub fn audiences(mut self, audiences: &[impl ToString]) -> Self {
131        self.audiences = audiences.iter().map(|aud| aud.to_string()).collect();
132        self
133    }
134
135    /// Set the interval for rotating jwks.
136    ///
137    /// The `jwks_url` is periodically queried in order to update
138    /// public keys that JWT signatures will be validated against.
139    ///
140    /// Default value is `Duration::from_secs(60)`.
141    pub fn jwks_refresh_interval(mut self, jwk_set_refresh_interval: Duration) -> Self {
142        self.jwk_set_refresh_interval = Some(jwk_set_refresh_interval);
143        self
144    }
145
146    /// Set what claims of JWTs to validate.
147    ///
148    /// By default, `iss`, `exp`, `aud` and possibly `nbf` will be validated.
149    pub fn claims_validation(mut self, claims_validation: ClaimsValidationSpec) -> Self {
150        self.claims_validation_spec = Some(claims_validation);
151        self
152    }
153
154    /// Set the allowed algorithms for JWT validation.
155    ///
156    /// By default, all standard asymmetric algorithms are allowed (RS256, RS384, RS512,
157    /// ES256, ES384, PS256, PS384, PS512, EdDSA). HMAC algorithms are excluded by default.
158    ///
159    /// Use this method to restrict the allowed algorithms if your authorization server
160    /// only uses specific algorithms.
161    pub fn allowed_algorithms(mut self, algorithms: &[Algorithm]) -> Self {
162        self.allowed_algorithms = Some(algorithms.iter().copied().collect());
163        self
164    }
165
166    /// Construct a TenantConfiguration.
167    pub async fn build(self) -> Result<TenantConfiguration, StartupError> {
168        let identifier = match self.identifier {
169            Some(id) => id,
170            None => self.issuer_url.clone(),
171        };
172
173        let issuer_url = Url::parse(&self.issuer_url)
174            .map_err(|_| StartupError::InvalidParameter("Invalid issuer_url format".to_string()))?;
175
176        let jwks_url = self
177            .jwks_url
178            .as_deref()
179            .map(|jwks_url| {
180                Url::parse(jwks_url).map_err(|_| {
181                    StartupError::InvalidParameter("Invalid jwks_url format".to_string())
182                })
183            })
184            .transpose()?;
185
186        let oidc_config = if jwks_url.is_some() {
187            None
188        } else {
189            Some(
190                OidcDiscovery::discover(&issuer_url)
191                    .await
192                    .map_err(|e| StartupError::OidcDiscoveryFailed(e.to_string()))?,
193            )
194        };
195
196        let claims_validation_spec = self
197            .claims_validation_spec
198            .unwrap_or(recommended_claims_spec(&self.audiences, &oidc_config));
199
200        let allowed_algorithms = self
201            .allowed_algorithms
202            .unwrap_or_else(default_allowed_algorithms);
203
204        let jwks_url = match jwks_url {
205            Some(jwks_url) => jwks_url,
206            None => match oidc_config {
207                Some(oidc_config) => oidc_config.jwks_uri,
208                None => {
209                    return Err(StartupError::InvalidParameter(
210                        "Failed to resolve JWKS URL".to_string(),
211                    ));
212                }
213            },
214        };
215
216        let kind = TenantKind::JwksUrl {
217            jwks_url,
218            jwks_refresh_interval: self
219                .jwk_set_refresh_interval
220                .unwrap_or(Duration::from_secs(60)),
221        };
222
223        Ok(TenantConfiguration {
224            identifier,
225            claims_validation_spec,
226            allowed_algorithms,
227            kind,
228        })
229    }
230}
231
232pub struct TenantStaticConfigurationBuilder {
233    identifier: Option<String>,
234    audiences: Vec<String>,
235    claims_validation_spec: Option<ClaimsValidationSpec>,
236    allowed_algorithms: Option<HashSet<Algorithm>>,
237    jwks: String,
238}
239
240impl TenantStaticConfigurationBuilder {
241    fn new(jwks: impl Into<String>) -> Self {
242        Self {
243            jwks: jwks.into(),
244            identifier: None,
245            audiences: Vec::new(),
246            claims_validation_spec: None,
247            allowed_algorithms: None,
248        }
249    }
250
251    /// Set an identifier for the tenant.
252    ///
253    /// Can be accessed on a [Authorizer](crate::authorizer::token_authorizer::Authorizer) in
254    /// order to identify what authorization server the authorizer is configured for.
255    ///
256    /// Used to validate the the iss
257    ///
258    /// Defaults to `static`.
259    pub fn identifier(mut self, identifier: &str) -> Self {
260        self.identifier = Some(identifier.to_string());
261        self
262    }
263
264    /// Set the expected audiences.
265    ///
266    /// Used to validate `aud` claim of JWTs.
267    pub fn audiences(mut self, audiences: &[impl ToString]) -> Self {
268        self.audiences = audiences.iter().map(|aud| aud.to_string()).collect();
269        self
270    }
271
272    /// Set what claims of JWTs to validate.
273    ///
274    /// By default, `iss`, `exp`, `aud` and possibly `nbf` will be validated.
275    pub fn claims_validation(mut self, claims_validation: ClaimsValidationSpec) -> Self {
276        self.claims_validation_spec = Some(claims_validation);
277        self
278    }
279
280    /// Set the allowed algorithms for JWT validation.
281    ///
282    /// By default, all standard asymmetric algorithms are allowed (RS256, RS384, RS512,
283    /// ES256, ES384, PS256, PS384, PS512, EdDSA). HMAC algorithms are excluded by default.
284    ///
285    /// Use this method to restrict the allowed algorithms if your authorization server
286    /// only uses specific algorithms.
287    pub fn allowed_algorithms(mut self, algorithms: &[Algorithm]) -> Self {
288        self.allowed_algorithms = Some(algorithms.iter().copied().collect());
289        self
290    }
291
292    /// Construct a TenantConfiguration.
293    pub fn build(self) -> Result<TenantConfiguration, StartupError> {
294        let identifier = self.identifier.unwrap_or_else(|| String::from("static"));
295
296        let claims_validation_spec = self
297            .claims_validation_spec
298            .unwrap_or(recommended_claims_spec(&self.audiences, &None));
299
300        let allowed_algorithms = self
301            .allowed_algorithms
302            .unwrap_or_else(default_allowed_algorithms);
303
304        let jwks = serde_json::from_str(&self.jwks)
305            .map_err(|e| StartupError::InvalidParameter(format!("Failed to parse JWKS: {e}")))?;
306
307        let kind = TenantKind::Static { jwks };
308
309        Ok(TenantConfiguration {
310            identifier,
311            claims_validation_spec,
312            allowed_algorithms,
313            kind,
314        })
315    }
316}
317
318fn recommended_claims_spec(
319    audiences: &Vec<String>,
320    oidc_config: &Option<OidcConfig>,
321) -> ClaimsValidationSpec {
322    let mut claims_spec = ClaimsValidationSpec::new().exp(true);
323    if !audiences.is_empty() {
324        claims_spec = claims_spec.aud(audiences);
325    }
326
327    if let Some(config) = &oidc_config {
328        if let Some(claims_supported) = &config.claims_supported {
329            if claims_supported.contains(&"nbf".to_owned()) {
330                claims_spec = claims_spec.nbf(true);
331            }
332        }
333        claims_spec = claims_spec.iss(config.issuer.as_str());
334    }
335    claims_spec
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oidc::{MockOidcDiscovery, OidcConfig};
    use std::sync::Mutex;

    // `MockOidcDiscovery::discover_context()` sets expectations on a mocked static
    // method. Such expectations are shared process-wide, so every test that touches
    // them takes this lock to avoid interfering with concurrently running tests.
    static MTX: Mutex<()> = Mutex::new(());

    #[tokio::test]
    async fn test_should_perform_oidc_discovery() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_should_skip_oidc_discovery_if_jwks_url_set() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .build()
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_should_use_issuer_as_identifier() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await;

        assert!(result.is_ok());
        let tenant = result.unwrap();
        assert_eq!(tenant.identifier, "http://some-issuer.com");
    }

    #[tokio::test]
    async fn test_custom_identifier_overrides_issuer() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .identifier("custom-identifier")
            .build()
            .await;

        assert!(result.is_ok());
        let tenant = result.unwrap();
        assert_eq!(tenant.identifier, "custom-identifier");
    }

    #[tokio::test]
    async fn test_valid_issuer_url_required() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let result = TenantConfigurationBuilder::new("not-a-url").build().await;

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            StartupError::InvalidParameter("Invalid issuer_url format".to_owned())
        )
    }

    #[tokio::test]
    async fn test_valid_jwks_url_required() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .jwks_url("not-a-url")
            .build()
            .await;

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            StartupError::InvalidParameter("Invalid jwks_url format".to_owned())
        )
    }

    #[tokio::test]
    async fn test_provides_recommended_claims_validation_spec() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .audiences(&["https://some-resource-server.com"])
            .build()
            .await;

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .iss("http://some-issuer.com")
                .aud(&vec!["https://some-resource-server.com".to_owned()])
        );
    }

    #[tokio::test]
    async fn test_custom_claims_validation_spec_overrides_recommended() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let claims_validation = ClaimsValidationSpec::new().exp(false);
        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .audiences(&["https://some-resource-server.com"])
            .claims_validation(claims_validation.clone())
            .build()
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().claims_validation_spec, claims_validation);
    }

    #[tokio::test]
    async fn test_uses_discovered_jwks_url_and_default_refresh_interval() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await
            .unwrap();

        match tenant.kind {
            TenantKind::JwksUrl {
                jwks_url,
                jwks_refresh_interval,
            } => {
                assert_eq!(jwks_url.as_str(), "http://some-issuer.com/jwks");
                assert_eq!(jwks_refresh_interval, Duration::from_secs(60));
            }
            TenantKind::Static { .. } => panic!("expected a JWKS URL tenant"),
        }
    }

    #[tokio::test]
    async fn test_explicit_jwks_url_and_refresh_interval_are_used() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .jwks_refresh_interval(Duration::from_secs(5))
            .build()
            .await
            .unwrap();

        match tenant.kind {
            TenantKind::JwksUrl {
                jwks_url,
                jwks_refresh_interval,
            } => {
                assert_eq!(jwks_url.as_str(), "https://some-issuer.com/jwks");
                assert_eq!(jwks_refresh_interval, Duration::from_secs(5));
            }
            TenantKind::Static { .. } => panic!("expected a JWKS URL tenant"),
        }
    }

    #[tokio::test]
    async fn test_default_allowed_algorithms_exclude_hmac() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .build()
            .await
            .unwrap();

        assert_eq!(tenant.allowed_algorithms, default_allowed_algorithms());
        for hmac in [Algorithm::HS256, Algorithm::HS384, Algorithm::HS512] {
            assert!(!tenant.allowed_algorithms.contains(&hmac));
        }
    }

    #[tokio::test]
    async fn test_custom_allowed_algorithms_override_default() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .allowed_algorithms(&[Algorithm::RS256])
            .build()
            .await
            .unwrap();

        assert_eq!(tenant.allowed_algorithms, HashSet::from([Algorithm::RS256]));
    }

    #[test]
    fn test_static_build() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks).build().unwrap();

        assert_eq!(t.identifier, "static");
        assert!(matches!(t.kind, TenantKind::Static { .. }));
    }

    #[test]
    fn test_static_build_invalid_jwks() {
        let jwks = " {}";
        let e = TenantStaticConfigurationBuilder::new(jwks)
            .build()
            .unwrap_err();
        assert!(matches!(e, StartupError::InvalidParameter { .. }))
    }

    #[test]
    fn test_static_build_custom_identifier() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks)
            .identifier("custom")
            .build()
            .unwrap();

        assert_eq!(t.identifier, "custom");
        assert!(matches!(t.kind, TenantKind::Static { .. }));
    }

    #[test]
    fn test_static_provides_recommended_claims_validation_spec() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks)
            .audiences(&["https://some-resource-server.com"])
            .build()
            .unwrap();

        assert_eq!(
            t.claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .aud(&vec!["https://some-resource-server.com".to_owned()])
        );
    }

    #[test]
    fn test_static_default_allowed_algorithms() {
        let t = TenantStaticConfigurationBuilder::new(mock_jwks())
            .build()
            .unwrap();

        assert_eq!(t.allowed_algorithms, default_allowed_algorithms());
    }

    #[test]
    fn test_static_custom_allowed_algorithms_override_default() {
        let t = TenantStaticConfigurationBuilder::new(mock_jwks())
            .allowed_algorithms(&[Algorithm::RS256, Algorithm::ES256])
            .build()
            .unwrap();

        assert_eq!(
            t.allowed_algorithms,
            HashSet::from([Algorithm::RS256, Algorithm::ES256])
        );
    }

    fn mock_jwks() -> String {
        let modulus = "oEz_RrupHP9d9XiFbXLoJMwG-75Z18t4ziBy2PHTZHxkHOep7aFeNj-13NmIcL4ooj-2nxrLhWbgA2iBaWr95wKkf5peTsc-5Q6-B2uCcn9xPSQK08Y_jNVhtly3mAOdsT4Y9mQIO_oqaqEyzutypZBEu-18NkbGVwkNhG9sxvUjFXHvMoJs5iwILaDA2FhuEioIDzOy-ZjD8p928ye2v8CdPWl1xPxoBXd2KIe3RkocRDxLeeBg3wH8a9tQ5Z7fOmiXiAI8_lN57zYf078yazvLUlKzCo1pQoR25MU51d7zgI_I7H2Fb5PZGcCmfvN1Up41OfEQyMLL6JYyoP23XQ";
        let exponent = "AQAB";
        serde_json::json!({
            "keys": [{
                "kty": "RSA",
                "kid": "test-kid",
                "n": modulus,
                "e": exponent
            }]
        })
        .to_string()
    }

    fn default_oidc_config() -> OidcConfig {
        OidcConfig {
            jwks_uri: "http://some-issuer.com/jwks".parse::<Url>().unwrap(),
            issuer: "http://some-issuer.com".to_owned(),
            claims_supported: None,
        }
    }
}