tower_oauth2_resource_server/
tenant.rs1use std::{collections::HashSet, time::Duration};
2
3use jsonwebtoken::{Algorithm, jwk::JwkSet};
4use mockall_double::double;
5use url::Url;
6
7use crate::{error::StartupError, oidc::OidcConfig, validation::ClaimsValidationSpec};
8
9#[double]
10use crate::oidc::OidcDiscovery;
11
12pub fn default_allowed_algorithms() -> HashSet<Algorithm> {
19 HashSet::from([
20 Algorithm::RS256,
21 Algorithm::RS384,
22 Algorithm::RS512,
23 Algorithm::ES256,
24 Algorithm::ES384,
25 Algorithm::PS256,
26 Algorithm::PS384,
27 Algorithm::PS512,
28 Algorithm::EdDSA,
29 ])
30}
31
32#[derive(Debug, Clone)]
33pub(crate) enum TenantKind {
34 JwksUrl {
35 jwks_url: Url,
36 jwks_refresh_interval: Duration,
37 },
38 Static {
39 jwks: JwkSet,
40 },
41}
42
43#[derive(Debug, Clone)]
44pub struct TenantConfiguration {
45 pub(crate) identifier: String,
46 pub(crate) claims_validation_spec: ClaimsValidationSpec,
47 pub(crate) allowed_algorithms: HashSet<Algorithm>,
48 pub(crate) kind: TenantKind,
49}
50
51impl TenantConfiguration {
52 pub fn builder(issuer_url: impl Into<String>) -> TenantConfigurationBuilder {
71 TenantConfigurationBuilder::new(issuer_url)
72 }
73
74 pub fn static_builder(jwks: impl Into<String>) -> TenantStaticConfigurationBuilder {
78 TenantStaticConfigurationBuilder::new(jwks)
79 }
80}
81
82pub struct TenantConfigurationBuilder {
83 issuer_url: String,
84 identifier: Option<String>,
85 jwks_url: Option<String>,
86 audiences: Vec<String>,
87 jwk_set_refresh_interval: Option<Duration>,
88 claims_validation_spec: Option<ClaimsValidationSpec>,
89 allowed_algorithms: Option<HashSet<Algorithm>>,
90}
91
92impl TenantConfigurationBuilder {
93 fn new(issuer_url: impl Into<String>) -> Self {
94 Self {
95 issuer_url: issuer_url.into(),
96 identifier: None,
97 jwks_url: None,
98 audiences: Vec::new(),
99 jwk_set_refresh_interval: None,
100 claims_validation_spec: None,
101 allowed_algorithms: None,
102 }
103 }
104
105 pub fn identifier(mut self, identifier: &str) -> Self {
112 self.identifier = Some(identifier.to_string());
113 self
114 }
115
116 pub fn jwks_url(mut self, jwks_url: impl Into<String>) -> Self {
123 self.jwks_url = Some(jwks_url.into());
124 self
125 }
126
127 pub fn audiences(mut self, audiences: &[impl ToString]) -> Self {
131 self.audiences = audiences.iter().map(|aud| aud.to_string()).collect();
132 self
133 }
134
135 pub fn jwks_refresh_interval(mut self, jwk_set_refresh_interval: Duration) -> Self {
142 self.jwk_set_refresh_interval = Some(jwk_set_refresh_interval);
143 self
144 }
145
146 pub fn claims_validation(mut self, claims_validation: ClaimsValidationSpec) -> Self {
150 self.claims_validation_spec = Some(claims_validation);
151 self
152 }
153
154 pub fn allowed_algorithms(mut self, algorithms: &[Algorithm]) -> Self {
162 self.allowed_algorithms = Some(algorithms.iter().copied().collect());
163 self
164 }
165
166 pub async fn build(self) -> Result<TenantConfiguration, StartupError> {
168 let identifier = match self.identifier {
169 Some(id) => id,
170 None => self.issuer_url.clone(),
171 };
172
173 let issuer_url = Url::parse(&self.issuer_url)
174 .map_err(|_| StartupError::InvalidParameter("Invalid issuer_url format".to_string()))?;
175
176 let jwks_url = self
177 .jwks_url
178 .as_deref()
179 .map(|jwks_url| {
180 Url::parse(jwks_url).map_err(|_| {
181 StartupError::InvalidParameter("Invalid jwks_url format".to_string())
182 })
183 })
184 .transpose()?;
185
186 let oidc_config = if jwks_url.is_some() {
187 None
188 } else {
189 Some(
190 OidcDiscovery::discover(&issuer_url)
191 .await
192 .map_err(|e| StartupError::OidcDiscoveryFailed(e.to_string()))?,
193 )
194 };
195
196 let claims_validation_spec = self
197 .claims_validation_spec
198 .unwrap_or(recommended_claims_spec(&self.audiences, &oidc_config));
199
200 let allowed_algorithms = self
201 .allowed_algorithms
202 .unwrap_or_else(default_allowed_algorithms);
203
204 let jwks_url = match jwks_url {
205 Some(jwks_url) => jwks_url,
206 None => match oidc_config {
207 Some(oidc_config) => oidc_config.jwks_uri,
208 None => {
209 return Err(StartupError::InvalidParameter(
210 "Failed to resolve JWKS URL".to_string(),
211 ));
212 }
213 },
214 };
215
216 let kind = TenantKind::JwksUrl {
217 jwks_url,
218 jwks_refresh_interval: self
219 .jwk_set_refresh_interval
220 .unwrap_or(Duration::from_secs(60)),
221 };
222
223 Ok(TenantConfiguration {
224 identifier,
225 claims_validation_spec,
226 allowed_algorithms,
227 kind,
228 })
229 }
230}
231
232pub struct TenantStaticConfigurationBuilder {
233 identifier: Option<String>,
234 audiences: Vec<String>,
235 claims_validation_spec: Option<ClaimsValidationSpec>,
236 allowed_algorithms: Option<HashSet<Algorithm>>,
237 jwks: String,
238}
239
240impl TenantStaticConfigurationBuilder {
241 fn new(jwks: impl Into<String>) -> Self {
242 Self {
243 jwks: jwks.into(),
244 identifier: None,
245 audiences: Vec::new(),
246 claims_validation_spec: None,
247 allowed_algorithms: None,
248 }
249 }
250
251 pub fn identifier(mut self, identifier: &str) -> Self {
260 self.identifier = Some(identifier.to_string());
261 self
262 }
263
264 pub fn audiences(mut self, audiences: &[impl ToString]) -> Self {
268 self.audiences = audiences.iter().map(|aud| aud.to_string()).collect();
269 self
270 }
271
272 pub fn claims_validation(mut self, claims_validation: ClaimsValidationSpec) -> Self {
276 self.claims_validation_spec = Some(claims_validation);
277 self
278 }
279
280 pub fn allowed_algorithms(mut self, algorithms: &[Algorithm]) -> Self {
288 self.allowed_algorithms = Some(algorithms.iter().copied().collect());
289 self
290 }
291
292 pub fn build(self) -> Result<TenantConfiguration, StartupError> {
294 let identifier = self.identifier.unwrap_or_else(|| String::from("static"));
295
296 let claims_validation_spec = self
297 .claims_validation_spec
298 .unwrap_or(recommended_claims_spec(&self.audiences, &None));
299
300 let allowed_algorithms = self
301 .allowed_algorithms
302 .unwrap_or_else(default_allowed_algorithms);
303
304 let jwks = serde_json::from_str(&self.jwks)
305 .map_err(|e| StartupError::InvalidParameter(format!("Failed to parse JWKS: {e}")))?;
306
307 let kind = TenantKind::Static { jwks };
308
309 Ok(TenantConfiguration {
310 identifier,
311 claims_validation_spec,
312 allowed_algorithms,
313 kind,
314 })
315 }
316}
317
318fn recommended_claims_spec(
319 audiences: &Vec<String>,
320 oidc_config: &Option<OidcConfig>,
321) -> ClaimsValidationSpec {
322 let mut claims_spec = ClaimsValidationSpec::new().exp(true);
323 if !audiences.is_empty() {
324 claims_spec = claims_spec.aud(audiences);
325 }
326
327 if let Some(config) = &oidc_config {
328 if let Some(claims_supported) = &config.claims_supported {
329 if claims_supported.contains(&"nbf".to_owned()) {
330 claims_spec = claims_spec.nbf(true);
331 }
332 }
333 claims_spec = claims_spec.iss(config.issuer.as_str());
334 }
335 claims_spec
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oidc::{MockOidcDiscovery, OidcConfig};
    use std::sync::Mutex;

    // The mocked static `OidcDiscovery::discover` is shared between tests, so every test that
    // installs an expectation on it is serialized through this lock.
    static MTX: Mutex<()> = Mutex::new(());

    #[tokio::test]
    async fn test_should_perform_oidc_discovery() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_should_skip_oidc_discovery_if_jwks_url_set() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .build()
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_should_use_issuer_as_identifier() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await;

        assert!(result.is_ok());
        let tenant = result.unwrap();
        assert_eq!(tenant.identifier, "http://some-issuer.com");
    }

    #[tokio::test]
    async fn test_custom_identifier_overrides_issuer() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .identifier("custom-identifier")
            .build()
            .await;

        assert!(result.is_ok());
        let tenant = result.unwrap();
        assert_eq!(tenant.identifier, "custom-identifier");
    }

    #[tokio::test]
    async fn test_valid_issuer_url_required() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let result = TenantConfigurationBuilder::new("not-a-url").build().await;

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            StartupError::InvalidParameter("Invalid issuer_url format".to_owned())
        )
    }

    #[tokio::test]
    async fn test_valid_jwks_url_required() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .jwks_url("not-a-url")
            .build()
            .await;

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            StartupError::InvalidParameter("Invalid jwks_url format".to_owned())
        )
    }

    #[tokio::test]
    async fn test_provides_recommended_claims_validation_spec() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .audiences(&["https://some-resource-server.com"])
            .build()
            .await;

        assert!(result.is_ok());
        // `iss` comes from the (mocked) discovery document, not from the builder's issuer URL.
        assert_eq!(
            result.unwrap().claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .iss("http://some-issuer.com")
                .aud(&vec!["https://some-resource-server.com".to_owned()])
        );
    }

    #[tokio::test]
    async fn test_custom_claims_validation_spec_overrides_recommended() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let claims_validation = ClaimsValidationSpec::new().exp(false);
        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .audiences(&["https://some-resource-server.com"])
            .claims_validation(claims_validation.clone())
            .build()
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().claims_validation_spec, claims_validation);
    }

    #[tokio::test]
    async fn test_uses_discovered_jwks_url() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().returning(|_| Ok(default_oidc_config())).once();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await
            .unwrap();

        match tenant.kind {
            TenantKind::JwksUrl { jwks_url, .. } => {
                assert_eq!(jwks_url.as_str(), "http://some-issuer.com/jwks")
            }
            TenantKind::Static { .. } => panic!("expected a JWKS URL tenant"),
        }
    }

    #[tokio::test]
    async fn test_defaults_with_explicit_jwks_url() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .build()
            .await
            .unwrap();

        assert_eq!(tenant.allowed_algorithms, default_allowed_algorithms());
        // Without a discovery document the recommended spec has no `iss` / `nbf` checks.
        assert_eq!(
            tenant.claims_validation_spec,
            ClaimsValidationSpec::new().exp(true)
        );
        match tenant.kind {
            TenantKind::JwksUrl {
                jwks_url,
                jwks_refresh_interval,
            } => {
                assert_eq!(jwks_url.as_str(), "https://some-issuer.com/jwks");
                assert_eq!(jwks_refresh_interval, Duration::from_secs(60));
            }
            TenantKind::Static { .. } => panic!("expected a JWKS URL tenant"),
        }
    }

    #[tokio::test]
    async fn test_custom_refresh_interval_and_algorithms() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .jwks_refresh_interval(Duration::from_secs(5))
            .allowed_algorithms(&[Algorithm::RS256])
            .build()
            .await
            .unwrap();

        assert_eq!(tenant.allowed_algorithms, HashSet::from([Algorithm::RS256]));
        assert!(matches!(
            tenant.kind,
            TenantKind::JwksUrl { jwks_refresh_interval, .. }
                if jwks_refresh_interval == Duration::from_secs(5)
        ));
    }

    #[tokio::test]
    async fn test_recommended_spec_validates_nbf_when_supported() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect()
            .returning(|_| {
                Ok(OidcConfig {
                    claims_supported: Some(vec!["nbf".to_owned()]),
                    ..default_oidc_config()
                })
            })
            .once();

        let tenant = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await
            .unwrap();

        assert_eq!(
            tenant.claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .nbf(true)
                .iss("http://some-issuer.com")
        );
    }

    #[test]
    fn test_static_build() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks).build().unwrap();

        assert_eq!(t.identifier, "static");
        assert!(matches!(t.kind, TenantKind::Static { .. }));
    }

    #[test]
    fn test_static_build_invalid_jwks() {
        let jwks = " {}";
        let e = TenantStaticConfigurationBuilder::new(jwks)
            .build()
            .unwrap_err();
        assert!(matches!(e, StartupError::InvalidParameter { .. }))
    }

    #[test]
    fn test_static_build_custom_identifier() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks)
            .identifier("custom")
            .build()
            .unwrap();

        assert_eq!(t.identifier, "custom");
        assert!(matches!(t.kind, TenantKind::Static { .. }));
    }

    #[test]
    fn test_static_provides_recommended_claims_validation_spec() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks)
            .audiences(&["https://some-resource-server.com"])
            .build()
            .unwrap();

        assert_eq!(
            t.claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .aud(&vec!["https://some-resource-server.com".to_owned()])
        );
    }

    /// A JSON JWK Set containing a single RSA public key.
    fn mock_jwks() -> String {
        let modulus = "oEz_RrupHP9d9XiFbXLoJMwG-75Z18t4ziBy2PHTZHxkHOep7aFeNj-13NmIcL4ooj-2nxrLhWbgA2iBaWr95wKkf5peTsc-5Q6-B2uCcn9xPSQK08Y_jNVhtly3mAOdsT4Y9mQIO_oqaqEyzutypZBEu-18NkbGVwkNhG9sxvUjFXHvMoJs5iwILaDA2FhuEioIDzOy-ZjD8p928ye2v8CdPWl1xPxoBXd2KIe3RkocRDxLeeBg3wH8a9tQ5Z7fOmiXiAI8_lN57zYf078yazvLUlKzCo1pQoR25MU51d7zgI_I7H2Fb5PZGcCmfvN1Up41OfEQyMLL6JYyoP23XQ";
        let exponent = "AQAB";
        serde_json::json!({
            "keys": [{
                "kty": "RSA",
                "kid": "test-kid",
                "n": modulus,
                "e": exponent
            }]
        })
        .to_string()
    }

    /// Discovery document returned by the mocked `OidcDiscovery::discover`.
    fn default_oidc_config() -> OidcConfig {
        OidcConfig {
            jwks_uri: "http://some-issuer.com/jwks".parse::<Url>().unwrap(),
            issuer: "http://some-issuer.com".to_owned(),
            claims_supported: None,
        }
    }
}