1use serde::{Deserialize, Serialize};
2use url::Url;
3
4use crate::options::{OidcConfig, TokenEndpointAuthentication};
5
/// Discovery-document fields that must be present with a non-empty value.
///
/// [`validate_discovery_document`] walks this list, in this order, and reports
/// every field whose value is empty.
pub const REQUIRED_DISCOVERY_FIELDS: &[&str] = &[
    "issuer",
    "authorization_endpoint",
    "token_endpoint",
    "jwks_uri",
];
15
/// Provider metadata deserialized from an OIDC discovery endpoint by
/// [`fetch_discovery_document`].
///
/// The four required fields use `#[serde(default)]`, so a missing field
/// deserializes to an empty string instead of failing the whole parse. This
/// lets [`validate_discovery_document`] report every missing field at once.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OidcDiscoveryDocument {
    /// Issuer identifier; compared with the expected issuer during validation.
    #[serde(default)]
    pub issuer: String,
    /// Authorization endpoint. Relative values are resolved against the issuer
    /// when the document is normalized.
    #[serde(default)]
    pub authorization_endpoint: String,
    /// Token endpoint. Relative values are resolved against the issuer.
    #[serde(default)]
    pub token_endpoint: String,
    /// JWKS document URL. Relative values are resolved against the issuer.
    #[serde(default)]
    pub jwks_uri: String,
    /// Optional userinfo endpoint.
    pub userinfo_endpoint: Option<String>,
    /// Optional token revocation endpoint.
    pub revocation_endpoint: Option<String>,
    /// Optional end-session (logout) endpoint.
    pub end_session_endpoint: Option<String>,
    /// Optional token introspection endpoint.
    pub introspection_endpoint: Option<String>,
    /// Client authentication methods advertised for the token endpoint. These
    /// are consulted by [`select_token_endpoint_authentication`].
    pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
    /// Advertised scopes; copied into [`HydratedOidcDiscovery::scopes_supported`].
    pub scopes_supported: Option<Vec<String>>,
    /// Advertised OAuth response types.
    pub response_types_supported: Option<Vec<String>>,
    /// Advertised subject identifier types.
    pub subject_types_supported: Option<Vec<String>>,
    /// Advertised ID-token signing algorithms.
    pub id_token_signing_alg_values_supported: Option<Vec<String>>,
    /// Advertised claims.
    pub claims_supported: Option<Vec<String>>,
    /// Advertised PKCE code challenge methods.
    pub code_challenge_methods_supported: Option<Vec<String>>,
}
38
/// Returns `true` only when `endpoint` holds a non-empty string.
///
/// `None` and `Some("")` are both treated as "not configured".
pub fn is_configured_oidc_endpoint(endpoint: Option<&str>) -> bool {
    matches!(endpoint, Some(url) if !url.is_empty())
}
46
/// Chooses the caller-configured endpoint when it is non-empty; otherwise
/// falls back to the value taken from the discovery document.
fn merge_required_endpoint(existing: Option<&str>, discovered: String) -> String {
    match existing {
        Some(configured) if !configured.is_empty() => configured.to_owned(),
        _ => discovered,
    }
}
53
/// Prefers a non-empty caller-configured endpoint and otherwise uses whatever
/// discovery produced, which may itself be `None`.
fn merge_optional_endpoint(existing: Option<&str>, discovered: Option<String>) -> Option<String> {
    existing
        .filter(|configured| !configured.is_empty())
        .map(str::to_owned)
        .or(discovered)
}
60
/// Maps `Some("")` to `None` and passes every other value through unchanged.
fn non_empty_endpoint(endpoint: Option<&str>) -> Option<&str> {
    match endpoint {
        Some(value) if value.is_empty() => None,
        other => other,
    }
}
64
/// Builds the conventional discovery URL for `issuer`.
///
/// All trailing slashes are stripped from the issuer before
/// `/.well-known/openid-configuration` is appended, so any issuer path is kept.
pub fn compute_discovery_url(issuer: &str) -> String {
    let base = issuer.trim_end_matches('/');
    format!("{base}/.well-known/openid-configuration")
}
71
72pub fn normalize_url(value: &str) -> Result<String, url::ParseError> {
73 Url::parse(value).map(|url| url.to_string())
74}
75
76pub fn normalize_absolute_http_url(
81 field: &'static str,
82 value: &str,
83) -> Result<String, OidcDiscoveryError> {
84 validate_trusted_url(field, value, &|_| true)?;
85 Url::parse(value)
86 .map(|url| url.to_string())
87 .map_err(|source| OidcDiscoveryError::InvalidUrl {
88 field,
89 reason: source.to_string(),
90 })
91}
92
93pub fn normalize_endpoint_url(
96 field: &'static str,
97 endpoint: &str,
98 issuer: &str,
99) -> Result<String, OidcDiscoveryError> {
100 normalize_endpoint(field, endpoint, issuer)
101}
102
103pub fn validate_issuer_url(value: &str) -> Result<String, openidconnect::url::ParseError> {
104 openidconnect::IssuerUrl::new(value.to_owned()).map(|issuer| issuer.to_string())
105}
106
/// Endpoint configuration built by merging caller-supplied values with a
/// fetched and normalized discovery document.
///
/// Non-empty caller-supplied endpoints take precedence over discovered ones.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HydratedOidcDiscovery {
    /// Issuer identifier.
    pub issuer: String,
    /// Discovery document URL that was fetched.
    pub discovery_endpoint: String,
    /// Authorization endpoint.
    pub authorization_endpoint: String,
    /// Token endpoint.
    pub token_endpoint: String,
    /// JWKS URL (the discovery document's `jwks_uri`).
    pub jwks_endpoint: String,
    /// Userinfo endpoint, if configured or discovered.
    pub user_info_endpoint: Option<String>,
    /// Revocation endpoint, if configured or discovered.
    pub revocation_endpoint: Option<String>,
    /// End-session endpoint, if configured or discovered.
    pub end_session_endpoint: Option<String>,
    /// Introspection endpoint, if configured or discovered.
    pub introspection_endpoint: Option<String>,
    /// Token endpoint client authentication method. It is the caller's choice
    /// when given, otherwise it is selected from the advertised methods.
    pub token_endpoint_authentication: TokenEndpointAuthentication,
    /// `scopes_supported` as advertised by the discovery document.
    pub scopes_supported: Option<Vec<String>>,
}
121
122pub async fn discover_oidc_config(
123 issuer: &str,
124 discovery_endpoint: Option<&str>,
125 existing: PartialOidcDiscoveryConfig<'_>,
126 client: &reqwest::Client,
127) -> Result<HydratedOidcDiscovery, OidcDiscoveryError> {
128 discover_oidc_config_with_origin_validator(
129 issuer,
130 discovery_endpoint,
131 existing,
132 |_| true,
133 client,
134 )
135 .await
136}
137
138pub async fn discover_oidc_config_with_origin_validator<F>(
139 issuer: &str,
140 discovery_endpoint: Option<&str>,
141 existing: PartialOidcDiscoveryConfig<'_>,
142 is_trusted_origin: F,
143 client: &reqwest::Client,
144) -> Result<HydratedOidcDiscovery, OidcDiscoveryError>
145where
146 F: Fn(&str) -> bool,
147{
148 let discovery_endpoint = discovery_endpoint
149 .map(str::to_owned)
150 .or_else(|| existing.discovery_endpoint.map(str::to_owned))
151 .unwrap_or_else(|| compute_discovery_url(issuer));
152 validate_trusted_url(
153 "discovery_endpoint",
154 &discovery_endpoint,
155 &is_trusted_origin,
156 )?;
157 let document = fetch_discovery_document(&discovery_endpoint, client).await?;
158 validate_discovery_document(&document, issuer)?;
159 let normalized = normalize_discovery_document(document, issuer)?;
160 let token_endpoint_authentication =
161 select_token_endpoint_authentication(&normalized, existing.token_endpoint_authentication);
162
163 let hydrated = HydratedOidcDiscovery {
164 issuer: existing
165 .issuer
166 .map(str::to_owned)
167 .unwrap_or(normalized.issuer),
168 discovery_endpoint,
169 authorization_endpoint: merge_required_endpoint(
170 existing.authorization_endpoint,
171 normalized.authorization_endpoint,
172 ),
173 token_endpoint: merge_required_endpoint(existing.token_endpoint, normalized.token_endpoint),
174 jwks_endpoint: merge_required_endpoint(existing.jwks_endpoint, normalized.jwks_uri),
175 user_info_endpoint: merge_optional_endpoint(
176 existing.user_info_endpoint,
177 normalized.userinfo_endpoint,
178 ),
179 revocation_endpoint: merge_optional_endpoint(
180 existing.revocation_endpoint,
181 normalized.revocation_endpoint,
182 ),
183 end_session_endpoint: merge_optional_endpoint(
184 existing.end_session_endpoint,
185 normalized.end_session_endpoint,
186 ),
187 introspection_endpoint: merge_optional_endpoint(
188 existing.introspection_endpoint,
189 normalized.introspection_endpoint,
190 ),
191 token_endpoint_authentication,
192 scopes_supported: normalized.scopes_supported,
193 };
194 validate_trusted_url(
195 "authorization_endpoint",
196 &hydrated.authorization_endpoint,
197 &is_trusted_origin,
198 )?;
199 validate_trusted_url(
200 "token_endpoint",
201 &hydrated.token_endpoint,
202 &is_trusted_origin,
203 )?;
204 validate_trusted_url("jwks_uri", &hydrated.jwks_endpoint, &is_trusted_origin)?;
205 if let Some(user_info_endpoint) = &hydrated.user_info_endpoint {
206 validate_trusted_url("userinfo_endpoint", user_info_endpoint, &is_trusted_origin)?;
207 }
208 if let Some(revocation_endpoint) = &hydrated.revocation_endpoint {
209 validate_trusted_url(
210 "revocation_endpoint",
211 revocation_endpoint,
212 &is_trusted_origin,
213 )?;
214 }
215 if let Some(end_session_endpoint) = &hydrated.end_session_endpoint {
216 validate_trusted_url(
217 "end_session_endpoint",
218 end_session_endpoint,
219 &is_trusted_origin,
220 )?;
221 }
222 if let Some(introspection_endpoint) = &hydrated.introspection_endpoint {
223 validate_trusted_url(
224 "introspection_endpoint",
225 introspection_endpoint,
226 &is_trusted_origin,
227 )?;
228 }
229 Ok(hydrated)
230}
231
/// Read access to the endpoint URLs of an OIDC client configuration.
///
/// [`validate_configured_oidc_endpoint_origins`] uses this trait to check each
/// configured endpoint.
pub trait OidcEndpointConfig {
    /// URL of the discovery document.
    fn discovery_endpoint(&self) -> &str;
    /// Authorization endpoint, if configured.
    fn authorization_endpoint(&self) -> Option<&str>;
    /// Token endpoint, if configured.
    fn token_endpoint(&self) -> Option<&str>;
    /// Userinfo endpoint, if configured.
    fn user_info_endpoint(&self) -> Option<&str>;
    /// JWKS endpoint, if configured.
    fn jwks_endpoint(&self) -> Option<&str>;
    /// Revocation endpoint, if configured.
    fn revocation_endpoint(&self) -> Option<&str>;
    /// End-session endpoint, if configured.
    fn end_session_endpoint(&self) -> Option<&str>;
    /// Introspection endpoint, if configured.
    fn introspection_endpoint(&self) -> Option<&str>;
}
242
243impl OidcEndpointConfig for OidcConfig {
244 fn discovery_endpoint(&self) -> &str {
245 &self.discovery_endpoint
246 }
247
248 fn authorization_endpoint(&self) -> Option<&str> {
249 self.authorization_endpoint.as_deref()
250 }
251
252 fn token_endpoint(&self) -> Option<&str> {
253 self.token_endpoint.as_deref()
254 }
255
256 fn user_info_endpoint(&self) -> Option<&str> {
257 self.user_info_endpoint.as_deref()
258 }
259
260 fn jwks_endpoint(&self) -> Option<&str> {
261 self.jwks_endpoint.as_deref()
262 }
263
264 fn revocation_endpoint(&self) -> Option<&str> {
265 self.revocation_endpoint.as_deref()
266 }
267
268 fn end_session_endpoint(&self) -> Option<&str> {
269 self.end_session_endpoint.as_deref()
270 }
271
272 fn introspection_endpoint(&self) -> Option<&str> {
273 self.introspection_endpoint.as_deref()
274 }
275}
276
277pub fn validate_configured_oidc_endpoint_origins<C, F>(
278 config: &C,
279 is_trusted_origin: F,
280) -> Result<(), OidcDiscoveryError>
281where
282 C: OidcEndpointConfig + ?Sized,
283 F: Fn(&str) -> bool,
284{
285 validate_trusted_url(
286 "discovery_endpoint",
287 config.discovery_endpoint(),
288 &is_trusted_origin,
289 )?;
290 if let Some(endpoint) = config.authorization_endpoint() {
291 validate_trusted_url("authorization_endpoint", endpoint, &is_trusted_origin)?;
292 }
293 if let Some(endpoint) = config.token_endpoint() {
294 validate_trusted_url("token_endpoint", endpoint, &is_trusted_origin)?;
295 }
296 if let Some(endpoint) = config.user_info_endpoint() {
297 validate_trusted_url("userinfo_endpoint", endpoint, &is_trusted_origin)?;
298 }
299 if let Some(endpoint) = config.jwks_endpoint() {
300 validate_trusted_url("jwks_uri", endpoint, &is_trusted_origin)?;
301 }
302 if let Some(endpoint) = config.revocation_endpoint() {
303 validate_trusted_url("revocation_endpoint", endpoint, &is_trusted_origin)?;
304 }
305 if let Some(endpoint) = config.end_session_endpoint() {
306 validate_trusted_url("end_session_endpoint", endpoint, &is_trusted_origin)?;
307 }
308 if let Some(endpoint) = config.introspection_endpoint() {
309 validate_trusted_url("introspection_endpoint", endpoint, &is_trusted_origin)?;
310 }
311 Ok(())
312}
313
/// OIDC settings the caller already has, borrowed for the duration of discovery.
///
/// When a [`HydratedOidcDiscovery`] is built, configured values are preferred
/// over what the discovery document advertises. `None` defers to discovery; for
/// the endpoint URL fields, an empty string also defers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PartialOidcDiscoveryConfig<'a> {
    /// Issuer to report instead of the discovered one.
    pub issuer: Option<&'a str>,
    /// Discovery URL, used when no explicit discovery endpoint is passed.
    pub discovery_endpoint: Option<&'a str>,
    /// Configured authorization endpoint.
    pub authorization_endpoint: Option<&'a str>,
    /// Configured token endpoint.
    pub token_endpoint: Option<&'a str>,
    /// Configured userinfo endpoint.
    pub user_info_endpoint: Option<&'a str>,
    /// Configured JWKS endpoint.
    pub jwks_endpoint: Option<&'a str>,
    /// Configured revocation endpoint.
    pub revocation_endpoint: Option<&'a str>,
    /// Configured end-session endpoint.
    pub end_session_endpoint: Option<&'a str>,
    /// Configured introspection endpoint.
    pub introspection_endpoint: Option<&'a str>,
    /// Token endpoint authentication method; when set, it is used as-is
    /// instead of being derived from the advertised methods.
    pub token_endpoint_authentication: Option<TokenEndpointAuthentication>,
}
327
/// Flow step whose endpoint needs are checked before deciding on runtime discovery.
///
/// Currently both variants require the same endpoints (see
/// `OidcRuntimeRequirement::is_satisfied`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OidcRuntimeRequirement {
    /// Requirement for the sign-in step.
    SignIn,
    /// Requirement for the callback step.
    Callback,
}
333
334impl OidcRuntimeRequirement {
335 pub fn is_satisfied(self, config: &OidcConfig) -> bool {
336 let _ = self;
341 is_configured_oidc_endpoint(config.authorization_endpoint.as_deref())
342 && is_configured_oidc_endpoint(config.token_endpoint.as_deref())
343 && is_configured_oidc_endpoint(config.jwks_endpoint.as_deref())
344 }
345}
346
347pub fn needs_runtime_discovery(config: &OidcConfig, requirement: OidcRuntimeRequirement) -> bool {
348 !requirement.is_satisfied(config)
349}
350
351pub async fn ensure_runtime_oidc_config_with_origin_validator<F>(
352 issuer: &str,
353 config: OidcConfig,
354 requirement: OidcRuntimeRequirement,
355 is_trusted_origin: F,
356 validate_configured_origins: bool,
357 client: &reqwest::Client,
358) -> Result<OidcConfig, OidcDiscoveryError>
359where
360 F: Fn(&str) -> bool,
361{
362 if !needs_runtime_discovery(&config, requirement) {
363 if validate_configured_origins {
364 validate_configured_oidc_endpoint_origins(&config, &is_trusted_origin)?;
365 }
366 return Ok(config);
367 }
368
369 let hydrated = discover_oidc_config_with_origin_validator(
370 issuer,
371 (!config.discovery_endpoint.is_empty()).then_some(config.discovery_endpoint.as_str()),
372 PartialOidcDiscoveryConfig {
373 issuer: Some(config.issuer.as_str()),
374 discovery_endpoint: (!config.discovery_endpoint.is_empty())
375 .then_some(config.discovery_endpoint.as_str()),
376 authorization_endpoint: non_empty_endpoint(config.authorization_endpoint.as_deref()),
377 token_endpoint: non_empty_endpoint(config.token_endpoint.as_deref()),
378 user_info_endpoint: non_empty_endpoint(config.user_info_endpoint.as_deref()),
379 jwks_endpoint: non_empty_endpoint(config.jwks_endpoint.as_deref()),
380 revocation_endpoint: non_empty_endpoint(config.revocation_endpoint.as_deref()),
381 end_session_endpoint: non_empty_endpoint(config.end_session_endpoint.as_deref()),
382 introspection_endpoint: non_empty_endpoint(config.introspection_endpoint.as_deref()),
383 token_endpoint_authentication: config.token_endpoint_authentication,
384 },
385 &is_trusted_origin,
386 client,
387 )
388 .await?;
389
390 let hydrated_config = OidcConfig {
391 issuer: hydrated.issuer,
392 pkce: config.pkce,
393 client_id: config.client_id,
394 client_secret: config.client_secret,
395 discovery_endpoint: hydrated.discovery_endpoint,
396 authorization_endpoint: Some(hydrated.authorization_endpoint),
397 token_endpoint: Some(hydrated.token_endpoint),
398 user_info_endpoint: hydrated.user_info_endpoint,
399 jwks_endpoint: Some(hydrated.jwks_endpoint),
400 revocation_endpoint: hydrated.revocation_endpoint,
401 end_session_endpoint: hydrated.end_session_endpoint,
402 introspection_endpoint: hydrated.introspection_endpoint,
403 token_endpoint_authentication: Some(hydrated.token_endpoint_authentication),
404 scopes: config.scopes,
405 mapping: config.mapping,
406 override_user_info: config.override_user_info,
407 };
408
409 if validate_configured_origins {
410 validate_configured_oidc_endpoint_origins(&hydrated_config, &is_trusted_origin)?;
411 }
412 Ok(hydrated_config)
413}
414
/// Errors produced while fetching, validating or normalizing OIDC discovery data.
///
/// `OidcDiscoveryError::code` gives a stable string identifier for each
/// variant, and `OidcDiscoveryError::status` gives the HTTP status to report.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum OidcDiscoveryError {
    /// Any other failure reported by the HTTP client, including non-success
    /// statuses other than 404 and 408.
    #[error("OIDC discovery request failed: {0}")]
    Request(String),
    /// The discovery endpoint answered HTTP 404.
    #[error("OIDC discovery endpoint not found")]
    NotFound,
    /// The request timed out or the server answered HTTP 408.
    #[error("OIDC discovery request timed out")]
    Timeout,
    /// The response body could not be read as a discovery document.
    #[error("OIDC discovery endpoint returned invalid JSON: {0}")]
    InvalidJson(String),
    /// A URL (discovered or configured) was rejected by the caller's trust check.
    #[error("OIDC discovery document contains untrusted URL for `{field}`: {url}")]
    UntrustedOrigin { field: &'static str, url: String },
    /// Exactly one required field was missing or empty.
    #[error("OIDC discovery document is missing required field `{0}`")]
    MissingField(&'static str),
    /// More than one required field was missing or empty.
    #[error("OIDC discovery document is missing required fields: {0:?}")]
    MissingFields(Vec<&'static str>),
    /// The document's issuer did not match the expected issuer, after
    /// ignoring one trailing slash on each side.
    #[error("OIDC discovery issuer mismatch")]
    IssuerMismatch,
    /// A URL failed to parse, could not be resolved against the issuer, or
    /// used a scheme other than `http`/`https`.
    #[error("OIDC discovery document contains invalid URL for `{field}`: {reason}")]
    InvalidUrl { field: &'static str, reason: String },
}
436
437impl OidcDiscoveryError {
438 pub fn code(&self) -> &'static str {
439 match self {
440 Self::Timeout => "discovery_timeout",
441 Self::NotFound => "discovery_not_found",
442 Self::InvalidJson(_) => "discovery_invalid_json",
443 Self::InvalidUrl { .. } => "discovery_invalid_url",
444 Self::UntrustedOrigin { .. } => "discovery_untrusted_origin",
445 Self::IssuerMismatch => "issuer_mismatch",
446 Self::MissingField(_) | Self::MissingFields(_) => "discovery_incomplete",
447 Self::Request(_) => "discovery_unexpected_error",
448 }
449 }
450
451 pub fn status(&self) -> http::StatusCode {
452 match self {
453 Self::Timeout | Self::Request(_) => http::StatusCode::BAD_GATEWAY,
454 Self::NotFound
455 | Self::InvalidJson(_)
456 | Self::InvalidUrl { .. }
457 | Self::UntrustedOrigin { .. }
458 | Self::IssuerMismatch
459 | Self::MissingField(_)
460 | Self::MissingFields(_) => http::StatusCode::BAD_REQUEST,
461 }
462 }
463}
464
465pub fn validate_discovery_url<F>(url: &str, is_trusted_origin: F) -> Result<(), OidcDiscoveryError>
467where
468 F: Fn(&str) -> bool,
469{
470 validate_trusted_url("discovery_endpoint", url, &is_trusted_origin)
471}
472
473pub async fn fetch_discovery_document(
475 discovery_endpoint: &str,
476 client: &reqwest::Client,
477) -> Result<OidcDiscoveryDocument, OidcDiscoveryError> {
478 let response = client
479 .get(discovery_endpoint)
480 .header("accept", "application/json")
481 .timeout(std::time::Duration::from_secs(10))
482 .send()
483 .await
484 .map_err(classify_reqwest_error)?;
485 let status = response.status();
486 if status == http::StatusCode::NOT_FOUND {
487 return Err(OidcDiscoveryError::NotFound);
488 }
489 if status == http::StatusCode::REQUEST_TIMEOUT {
490 return Err(OidcDiscoveryError::Timeout);
491 }
492 let response = response
493 .error_for_status()
494 .map_err(classify_reqwest_error)?;
495 response
496 .json::<OidcDiscoveryDocument>()
497 .await
498 .map_err(|error| OidcDiscoveryError::InvalidJson(error.to_string()))
499}
500
501fn classify_reqwest_error(error: reqwest::Error) -> OidcDiscoveryError {
502 if error.is_timeout() {
503 return OidcDiscoveryError::Timeout;
504 }
505 if error.status() == Some(http::StatusCode::NOT_FOUND) {
506 return OidcDiscoveryError::NotFound;
507 }
508 OidcDiscoveryError::Request(error.to_string())
509}
510
511pub fn validate_discovery_document(
513 document: &OidcDiscoveryDocument,
514 issuer: &str,
515) -> Result<(), OidcDiscoveryError> {
516 let mut missing = Vec::new();
517 for field in REQUIRED_DISCOVERY_FIELDS {
518 let is_empty = match *field {
519 "issuer" => document.issuer.is_empty(),
520 "authorization_endpoint" => document.authorization_endpoint.is_empty(),
521 "token_endpoint" => document.token_endpoint.is_empty(),
522 "jwks_uri" => document.jwks_uri.is_empty(),
523 _ => false,
524 };
525 if is_empty {
526 missing.push(*field);
527 }
528 }
529 if !missing.is_empty() {
530 return Err(if missing.len() == 1 {
531 OidcDiscoveryError::MissingField(missing[0])
532 } else {
533 OidcDiscoveryError::MissingFields(missing)
534 });
535 }
536 if trim_trailing_slash(&document.issuer) != trim_trailing_slash(issuer) {
537 return Err(OidcDiscoveryError::IssuerMismatch);
538 }
539 Ok(())
540}
541
542pub fn normalize_discovery_urls<F>(
544 document: OidcDiscoveryDocument,
545 issuer: &str,
546 is_trusted_origin: F,
547) -> Result<OidcDiscoveryDocument, OidcDiscoveryError>
548where
549 F: Fn(&str) -> bool,
550{
551 let normalized = normalize_discovery_document(document, issuer)?;
552 validate_trusted_url(
553 "authorization_endpoint",
554 &normalized.authorization_endpoint,
555 &is_trusted_origin,
556 )?;
557 validate_trusted_url(
558 "token_endpoint",
559 &normalized.token_endpoint,
560 &is_trusted_origin,
561 )?;
562 validate_trusted_url("jwks_uri", &normalized.jwks_uri, &is_trusted_origin)?;
563 if let Some(userinfo_endpoint) = &normalized.userinfo_endpoint {
564 validate_trusted_url("userinfo_endpoint", userinfo_endpoint, &is_trusted_origin)?;
565 }
566 if let Some(revocation_endpoint) = &normalized.revocation_endpoint {
567 validate_trusted_url(
568 "revocation_endpoint",
569 revocation_endpoint,
570 &is_trusted_origin,
571 )?;
572 }
573 if let Some(end_session_endpoint) = &normalized.end_session_endpoint {
574 validate_trusted_url(
575 "end_session_endpoint",
576 end_session_endpoint,
577 &is_trusted_origin,
578 )?;
579 }
580 if let Some(introspection_endpoint) = &normalized.introspection_endpoint {
581 validate_trusted_url(
582 "introspection_endpoint",
583 introspection_endpoint,
584 &is_trusted_origin,
585 )?;
586 }
587 Ok(normalized)
588}
589
590fn normalize_discovery_document(
591 mut document: OidcDiscoveryDocument,
592 issuer: &str,
593) -> Result<OidcDiscoveryDocument, OidcDiscoveryError> {
594 document.authorization_endpoint = normalize_endpoint(
595 "authorization_endpoint",
596 &document.authorization_endpoint,
597 issuer,
598 )?;
599 document.token_endpoint =
600 normalize_endpoint("token_endpoint", &document.token_endpoint, issuer)?;
601 document.jwks_uri = normalize_endpoint("jwks_uri", &document.jwks_uri, issuer)?;
602 document.userinfo_endpoint = document
603 .userinfo_endpoint
604 .as_deref()
605 .map(|endpoint| normalize_endpoint("userinfo_endpoint", endpoint, issuer))
606 .transpose()?;
607 document.revocation_endpoint = document
608 .revocation_endpoint
609 .as_deref()
610 .map(|endpoint| normalize_endpoint("revocation_endpoint", endpoint, issuer))
611 .transpose()?;
612 document.end_session_endpoint = document
613 .end_session_endpoint
614 .as_deref()
615 .map(|endpoint| normalize_endpoint("end_session_endpoint", endpoint, issuer))
616 .transpose()?;
617 document.introspection_endpoint = document
618 .introspection_endpoint
619 .as_deref()
620 .map(|endpoint| normalize_endpoint("introspection_endpoint", endpoint, issuer))
621 .transpose()?;
622 Ok(document)
623}
624
625fn normalize_endpoint(
626 field: &'static str,
627 endpoint: &str,
628 issuer: &str,
629) -> Result<String, OidcDiscoveryError> {
630 if let Ok(url) = Url::parse(endpoint) {
631 ensure_supported_url_scheme(field, &url)?;
632 return Ok(url.to_string());
633 }
634
635 let issuer_url = Url::parse(issuer).map_err(|source| OidcDiscoveryError::InvalidUrl {
636 field,
637 reason: source.to_string(),
638 })?;
639 let origin = issuer_url.origin().ascii_serialization();
640 let base_path = issuer_url.path().trim_end_matches('/');
641 let endpoint_path = endpoint.trim_start_matches('/');
642 let url = Url::parse(&format!("{origin}{base_path}/{endpoint_path}")).map_err(|source| {
643 OidcDiscoveryError::InvalidUrl {
644 field,
645 reason: source.to_string(),
646 }
647 })?;
648 ensure_supported_url_scheme(field, &url)?;
649 Ok(url.to_string())
650}
651
652fn validate_trusted_url<F>(
653 field: &'static str,
654 value: &str,
655 is_trusted_origin: &F,
656) -> Result<(), OidcDiscoveryError>
657where
658 F: Fn(&str) -> bool,
659{
660 let url = Url::parse(value).map_err(|source| OidcDiscoveryError::InvalidUrl {
661 field,
662 reason: source.to_string(),
663 })?;
664 ensure_supported_url_scheme(field, &url)?;
665 if !is_trusted_origin(value) {
666 return Err(OidcDiscoveryError::UntrustedOrigin {
667 field,
668 url: value.to_owned(),
669 });
670 }
671 Ok(())
672}
673
674fn ensure_supported_url_scheme(field: &'static str, url: &Url) -> Result<(), OidcDiscoveryError> {
675 if matches!(url.scheme(), "http" | "https") {
676 return Ok(());
677 }
678 Err(OidcDiscoveryError::InvalidUrl {
679 field,
680 reason: format!("unsupported URL scheme `{}`", url.scheme()),
681 })
682}
683
684pub fn select_token_endpoint_authentication(
686 document: &OidcDiscoveryDocument,
687 existing: Option<TokenEndpointAuthentication>,
688) -> TokenEndpointAuthentication {
689 if let Some(existing) = existing {
690 return existing;
691 }
692 let Some(supported) = &document.token_endpoint_auth_methods_supported else {
693 return TokenEndpointAuthentication::ClientSecretBasic;
694 };
695 if supported
696 .iter()
697 .any(|method| method == "client_secret_basic")
698 {
699 return TokenEndpointAuthentication::ClientSecretBasic;
700 }
701 if supported
702 .iter()
703 .any(|method| method == "client_secret_post")
704 {
705 return TokenEndpointAuthentication::ClientSecretPost;
706 }
707 TokenEndpointAuthentication::ClientSecretBasic
708}
709
/// Removes at most one trailing `/` from `value`.
fn trim_trailing_slash(value: &str) -> &str {
    let mut chars = value.chars();
    if chars.next_back() == Some('/') {
        chars.as_str()
    } else {
        value
    }
}
713
714#[cfg(test)]
715mod tests {
716 use super::*;
717
718 fn discovery_document(issuer: &str) -> OidcDiscoveryDocument {
719 OidcDiscoveryDocument {
720 issuer: issuer.to_owned(),
721 authorization_endpoint: format!("{issuer}/authorize"),
722 token_endpoint: format!("{issuer}/token"),
723 jwks_uri: format!("{issuer}/keys"),
724 userinfo_endpoint: Some(format!("{issuer}/userinfo")),
725 revocation_endpoint: None,
726 end_session_endpoint: None,
727 introspection_endpoint: None,
728 token_endpoint_auth_methods_supported: None,
729 scopes_supported: None,
730 response_types_supported: None,
731 subject_types_supported: None,
732 id_token_signing_alg_values_supported: None,
733 claims_supported: None,
734 code_challenge_methods_supported: None,
735 }
736 }
737
738 #[test]
739 fn normalizes_relative_discovery_endpoints_against_issuer_path(
740 ) -> Result<(), OidcDiscoveryError> {
741 assert_eq!(
742 normalize_endpoint(
743 "token_endpoint",
744 "oauth/token",
745 "https://idp.example.com/tenant"
746 )?,
747 "https://idp.example.com/tenant/oauth/token"
748 );
749 assert_eq!(
750 normalize_endpoint("jwks_uri", "/keys", "https://idp.example.com/tenant")?,
751 "https://idp.example.com/tenant/keys"
752 );
753 let document = normalize_discovery_document(
754 OidcDiscoveryDocument {
755 issuer: "https://idp.example.com/tenant".to_owned(),
756 authorization_endpoint: "authorize".to_owned(),
757 token_endpoint: "token".to_owned(),
758 jwks_uri: "keys".to_owned(),
759 userinfo_endpoint: Some("userinfo".to_owned()),
760 revocation_endpoint: Some("revoke".to_owned()),
761 end_session_endpoint: Some("endsession".to_owned()),
762 introspection_endpoint: Some("introspect".to_owned()),
763 token_endpoint_auth_methods_supported: None,
764 scopes_supported: None,
765 response_types_supported: None,
766 subject_types_supported: None,
767 id_token_signing_alg_values_supported: None,
768 claims_supported: None,
769 code_challenge_methods_supported: None,
770 },
771 "https://idp.example.com/tenant",
772 )?;
773 assert_eq!(
774 document.revocation_endpoint.as_deref(),
775 Some("https://idp.example.com/tenant/revoke")
776 );
777 assert_eq!(
778 document.end_session_endpoint.as_deref(),
779 Some("https://idp.example.com/tenant/endsession")
780 );
781 assert_eq!(
782 document.introspection_endpoint.as_deref(),
783 Some("https://idp.example.com/tenant/introspect")
784 );
785 Ok(())
786 }
787
788 #[test]
789 fn discovery_url_preserves_issuer_path() {
790 assert_eq!(
791 compute_discovery_url("https://idp.example.com/tenant/v1/"),
792 "https://idp.example.com/tenant/v1/.well-known/openid-configuration"
793 );
794 }
795
796 #[test]
797 fn absolute_http_url_api_rejects_relative_and_non_http_values() -> Result<(), OidcDiscoveryError>
798 {
799 assert!(normalize_absolute_http_url("discovery_endpoint", "/relative").is_err());
800 assert!(
801 normalize_absolute_http_url("discovery_endpoint", "ftp://idp.example.com").is_err()
802 );
803 assert_eq!(
804 normalize_absolute_http_url("discovery_endpoint", "https://idp.example.com")?,
805 "https://idp.example.com/"
806 );
807 Ok::<(), OidcDiscoveryError>(())
808 }
809
810 #[test]
811 fn normalize_endpoint_resolves_relative_urls_with_duplicate_slashes(
812 ) -> Result<(), OidcDiscoveryError> {
813 assert_eq!(
814 normalize_endpoint(
815 "token_endpoint",
816 "//oauth2/token",
817 "https://idp.example.com/base//",
818 )?,
819 "https://idp.example.com/base/oauth2/token"
820 );
821 assert_eq!(
822 normalize_endpoint(
823 "token_endpoint",
824 "oauth2/token",
825 "https://idp.example.com/base/"
826 )?,
827 "https://idp.example.com/base/oauth2/token"
828 );
829 Ok(())
830 }
831
832 #[test]
833 fn endpoint_url_api_resolves_relative_values_against_issuer_path(
834 ) -> Result<(), OidcDiscoveryError> {
835 assert_eq!(
836 normalize_endpoint_url(
837 "authorization_endpoint",
838 "/oauth2/authorize",
839 "https://idp.example.com/tenant/",
840 )?,
841 "https://idp.example.com/tenant/oauth2/authorize"
842 );
843 assert!(normalize_endpoint_url(
844 "authorization_endpoint",
845 "ftp://idp.example.com/authorize",
846 "https://idp.example.com/tenant/",
847 )
848 .is_err());
849 Ok::<(), OidcDiscoveryError>(())
850 }
851
852 #[test]
853 fn is_configured_oidc_endpoint_treats_empty_string_as_missing() {
854 assert!(!is_configured_oidc_endpoint(None));
855 assert!(!is_configured_oidc_endpoint(Some("")));
856 assert!(is_configured_oidc_endpoint(Some(
857 "https://idp.example.com/oauth2/v1/authorize"
858 )));
859 }
860
861 #[test]
862 fn runtime_discovery_treats_empty_string_endpoints_as_missing() {
863 let config = OidcConfig {
864 issuer: "https://idp.example.com".to_owned(),
865 pkce: true,
866 client_id: "client".to_owned(),
867 client_secret: "secret".into(),
868 discovery_endpoint: compute_discovery_url("https://idp.example.com"),
869 authorization_endpoint: Some(String::new()),
870 token_endpoint: Some("https://idp.example.com/token".to_owned()),
871 user_info_endpoint: None,
872 jwks_endpoint: Some("https://idp.example.com/keys".to_owned()),
873 revocation_endpoint: None,
874 end_session_endpoint: None,
875 introspection_endpoint: None,
876 token_endpoint_authentication: None,
877 scopes: None,
878 mapping: None,
879 override_user_info: false,
880 };
881
882 assert!(needs_runtime_discovery(
883 &config,
884 OidcRuntimeRequirement::SignIn
885 ));
886 assert!(needs_runtime_discovery(
887 &config,
888 OidcRuntimeRequirement::Callback
889 ));
890 assert!(!is_configured_oidc_endpoint(
891 config.authorization_endpoint.as_deref()
892 ));
893 }
894
895 #[test]
896 fn runtime_discovery_requirements_match_sign_in_and_callback_needs() {
897 let mut config = OidcConfig {
898 issuer: "https://idp.example.com".to_owned(),
899 pkce: true,
900 client_id: "client".to_owned(),
901 client_secret: "secret".into(),
902 discovery_endpoint: compute_discovery_url("https://idp.example.com"),
903 authorization_endpoint: None,
904 token_endpoint: Some("https://idp.example.com/token".to_owned()),
905 user_info_endpoint: Some("https://idp.example.com/userinfo".to_owned()),
906 jwks_endpoint: None,
907 revocation_endpoint: None,
908 end_session_endpoint: None,
909 introspection_endpoint: None,
910 token_endpoint_authentication: None,
911 scopes: None,
912 mapping: None,
913 override_user_info: false,
914 };
915
916 assert!(needs_runtime_discovery(
917 &config,
918 OidcRuntimeRequirement::SignIn
919 ));
920 assert!(needs_runtime_discovery(
921 &config,
922 OidcRuntimeRequirement::Callback
923 ));
924
925 config.authorization_endpoint = Some("https://idp.example.com/authorize".to_owned());
926 assert!(needs_runtime_discovery(
927 &config,
928 OidcRuntimeRequirement::SignIn
929 ));
930 assert!(needs_runtime_discovery(
931 &config,
932 OidcRuntimeRequirement::Callback
933 ));
934
935 config.user_info_endpoint = None;
936 config.jwks_endpoint = Some("https://idp.example.com/keys".to_owned());
937 assert!(!needs_runtime_discovery(
938 &config,
939 OidcRuntimeRequirement::SignIn
940 ));
941 assert!(!needs_runtime_discovery(
942 &config,
943 OidcRuntimeRequirement::Callback
944 ));
945 }
946
947 #[test]
948 fn discovery_errors_expose_stable_codes_and_statuses() {
949 assert_eq!(
950 OidcDiscoveryError::MissingField("issuer").code(),
951 "discovery_incomplete"
952 );
953 assert_eq!(
954 OidcDiscoveryError::MissingFields(vec!["issuer", "jwks_uri"]).code(),
955 "discovery_incomplete"
956 );
957 assert_eq!(OidcDiscoveryError::IssuerMismatch.code(), "issuer_mismatch");
958 assert_eq!(
959 OidcDiscoveryError::InvalidUrl {
960 field: "authorization_endpoint",
961 reason: "bad URL".to_owned(),
962 }
963 .code(),
964 "discovery_invalid_url"
965 );
966 assert_eq!(
967 OidcDiscoveryError::Timeout.status(),
968 http::StatusCode::BAD_GATEWAY
969 );
970 assert_eq!(
971 OidcDiscoveryError::InvalidJson("bad".to_owned()).status(),
972 http::StatusCode::BAD_REQUEST
973 );
974 }
975
976 #[test]
977 fn discovery_validation_reports_all_missing_required_fields(
978 ) -> Result<(), Box<dyn std::error::Error>> {
979 let document: OidcDiscoveryDocument = serde_json::from_str(
980 r#"{
981 "issuer":"https://idp.example.com"
982 }"#,
983 )?;
984
985 let error = match validate_discovery_document(&document, "https://idp.example.com") {
986 Ok(()) => return Err("expected incomplete discovery document".into()),
987 Err(error) => error,
988 };
989
990 assert_eq!(error.code(), "discovery_incomplete");
991 assert!(matches!(
992 error,
993 OidcDiscoveryError::MissingFields(fields)
994 if fields == vec!["authorization_endpoint", "token_endpoint", "jwks_uri"]
995 ));
996 Ok(())
997 }
998
#[test]
fn discovery_validation_reports_each_missing_required_field() {
    const ISSUER: &str = "https://idp.example.com";

    // Start from a complete fixture document and blank exactly one required field.
    let blanked = |clear: fn(&mut OidcDiscoveryDocument)| {
        let mut document = discovery_document(ISSUER);
        clear(&mut document);
        document
    };

    let cases = [
        ("issuer", blanked(|d| d.issuer.clear())),
        ("authorization_endpoint", blanked(|d| d.authorization_endpoint.clear())),
        ("token_endpoint", blanked(|d| d.token_endpoint.clear())),
        ("jwks_uri", blanked(|d| d.jwks_uri.clear())),
    ];

    for (field, document) in cases {
        let result = validate_discovery_document(&document, ISSUER);
        assert!(matches!(
            result,
            Err(OidcDiscoveryError::MissingField(missing)) if missing == field
        ));
    }
}
1037
#[test]
fn discovery_validation_normalizes_issuer_trailing_slash() {
    // Each pair is (issuer advertised by the document, issuer expected by the caller).
    // A trailing slash on either side must not be treated as a mismatch.
    let pairs = [
        ("https://idp.example.com/", "https://idp.example.com"),
        ("https://idp.example.com", "https://idp.example.com/"),
    ];
    for (advertised, expected) in pairs {
        let document = discovery_document(advertised);
        assert!(validate_discovery_document(&document, expected).is_ok());
    }
}
1045
#[test]
fn discovery_validation_rejects_issuer_mismatch() {
    // The document claims a different issuer than the one the caller configured.
    let forged = discovery_document("https://evil.example.com");
    let outcome = validate_discovery_document(&forged, "https://idp.example.com");
    assert!(matches!(outcome, Err(OidcDiscoveryError::IssuerMismatch)));
}
1054
#[test]
fn required_discovery_fields_match_upstream_contract() {
    // Pins both the membership and the order of the required metadata keys.
    let expected: [&str; 4] = [
        "issuer",
        "authorization_endpoint",
        "token_endpoint",
        "jwks_uri",
    ];
    assert_eq!(REQUIRED_DISCOVERY_FIELDS, expected.as_slice());
}
1067
#[test]
fn validate_discovery_url_rejects_invalid_and_untrusted_urls() {
    // Unparseable input and a non-HTTP(S) scheme are rejected even when every origin is trusted.
    for malformed in ["not-a-url", "ftp://idp.example.com/config"] {
        let result = validate_discovery_url(malformed, |_| true);
        assert!(matches!(result, Err(OidcDiscoveryError::InvalidUrl { .. })));
    }

    // A well-formed URL is refused when the origin validator says no...
    let untrusted = validate_discovery_url(
        "https://untrusted.example.com/.well-known/openid-configuration",
        |_| false,
    );
    assert!(matches!(
        untrusted,
        Err(OidcDiscoveryError::UntrustedOrigin { .. })
    ));

    // ...and accepted when it is both well-formed and trusted.
    let trusted = validate_discovery_url(
        "https://idp.example.com/.well-known/openid-configuration",
        |_| true,
    );
    assert!(trusted.is_ok());
}
1091
#[test]
fn normalize_discovery_urls_rejects_untrusted_required_endpoints(
) -> Result<(), Box<dyn std::error::Error>> {
    // Every endpoint is given as a relative path (presumably resolved against the issuer by
    // `normalize_discovery_urls`); the validator below refuses exactly one of them per case.
    let document = OidcDiscoveryDocument {
        issuer: String::from("https://idp.example.com"),
        authorization_endpoint: String::from("/oauth2/authorize"),
        token_endpoint: String::from("/oauth2/token"),
        jwks_uri: String::from("/.well-known/jwks.json"),
        userinfo_endpoint: Some(String::from("/userinfo")),
        revocation_endpoint: Some(String::from("/revoke")),
        end_session_endpoint: Some(String::from("/endsession")),
        introspection_endpoint: Some(String::from("/introspection")),
        token_endpoint_auth_methods_supported: None,
        scopes_supported: None,
        response_types_supported: None,
        subject_types_supported: None,
        id_token_signing_alg_values_supported: None,
        claims_supported: None,
        code_challenge_methods_supported: None,
    };

    // (suffix the validator refuses, field name the resulting error must mention)
    let cases = [
        ("/oauth2/token", "token_endpoint"),
        ("/oauth2/authorize", "authorization_endpoint"),
        ("/.well-known/jwks.json", "jwks_uri"),
        ("/userinfo", "userinfo_endpoint"),
        ("/revoke", "revocation_endpoint"),
        ("/endsession", "end_session_endpoint"),
        ("/introspection", "introspection_endpoint"),
    ];

    for (suffix, field_hint) in cases {
        let error = normalize_discovery_urls(document.clone(), "https://idp.example.com", |url| {
            !url.ends_with(suffix)
        })
        .err()
        .ok_or_else(|| format!("expected untrusted {field_hint}"))?;
        assert_eq!(error.code(), "discovery_untrusted_origin");
        assert!(error.to_string().contains(field_hint));
    }
    Ok(())
}
1134
#[test]
fn token_endpoint_authentication_prefers_existing_config_value() {
    // An explicitly configured method is returned as-is.
    let document = discovery_document("https://idp.example.com");
    let configured = Some(TokenEndpointAuthentication::ClientSecretPost);
    let selected = select_token_endpoint_authentication(&document, configured);
    assert_eq!(selected, TokenEndpointAuthentication::ClientSecretPost);
}
1146
#[test]
fn token_endpoint_authentication_prefers_client_secret_basic_when_both_supported() {
    // The IdP lists `client_secret_post` first, yet `client_secret_basic` must be chosen.
    let document = OidcDiscoveryDocument {
        token_endpoint_auth_methods_supported: Some(Vec::from(
            ["client_secret_post", "client_secret_basic"].map(String::from),
        )),
        ..discovery_document("https://idp.example.com")
    };
    let selected = select_token_endpoint_authentication(&document, None);
    assert_eq!(selected, TokenEndpointAuthentication::ClientSecretBasic);
}
1159
#[test]
fn token_endpoint_authentication_selects_client_secret_post_when_only_supported() {
    // With `client_secret_post` as the sole advertised method, it is the one selected.
    let document = OidcDiscoveryDocument {
        token_endpoint_auth_methods_supported: Some(vec![String::from("client_secret_post")]),
        ..discovery_document("https://idp.example.com")
    };
    let selected = select_token_endpoint_authentication(&document, None);
    assert_eq!(selected, TokenEndpointAuthentication::ClientSecretPost);
}
1170
#[test]
fn normalize_absolute_http_url_accepts_http_and_https() -> Result<(), OidcDiscoveryError> {
    // Both schemes are accepted and these already-normalized URLs come back unchanged.
    for url in ["http://idp.example.com/path", "https://idp.example.com/path"] {
        let normalized = normalize_absolute_http_url("discovery_endpoint", url)?;
        assert_eq!(normalized, url);
    }
    Ok(())
}
1183
#[test]
fn token_endpoint_authentication_defaults_for_empty_or_unsupported_methods() {
    // Neither an empty list nor a list of only unsupported methods moves selection
    // away from the `ClientSecretBasic` fallback.
    let mut document = discovery_document("https://idp.example.com");
    let advertised: [Vec<String>; 2] = [
        Vec::new(),
        vec![
            String::from("private_key_jwt"),
            String::from("tls_client_auth"),
        ],
    ];
    for methods in advertised {
        document.token_endpoint_auth_methods_supported = Some(methods);
        let selected = select_token_endpoint_authentication(&document, None);
        assert_eq!(selected, TokenEndpointAuthentication::ClientSecretBasic);
    }
}
1202
#[test]
fn discovery_validation_accepts_document_without_optional_metadata(
) -> Result<(), Box<dyn std::error::Error>> {
    // Only the four required keys are present.
    let minimal_json = r#"{
    "issuer":"https://idp.example.com",
    "authorization_endpoint":"https://idp.example.com/authorize",
    "token_endpoint":"https://idp.example.com/token",
    "jwks_uri":"https://idp.example.com/keys"
}"#;
    let document: OidcDiscoveryDocument = serde_json::from_str(minimal_json)?;
    validate_discovery_document(&document, "https://idp.example.com")?;

    // Absent optional metadata must deserialize to `None` rather than to empty values.
    assert!(document.userinfo_endpoint.is_none());
    assert!(document.response_types_supported.is_none());
    assert!(document.subject_types_supported.is_none());
    assert!(document.id_token_signing_alg_values_supported.is_none());
    assert!(document.claims_supported.is_none());
    assert!(document.code_challenge_methods_supported.is_none());
    Ok(())
}
1224
1225 #[tokio::test]
1226 async fn fetch_discovery_document_classifies_http_and_json_errors(
1227 ) -> Result<(), Box<dyn std::error::Error>> {
1228 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1229 let address = listener.local_addr()?;
1230 tokio::spawn(async move {
1231 while let Ok((mut stream, _)) = listener.accept().await {
1232 tokio::spawn(async move {
1233 let mut buffer = [0_u8; 1024];
1234 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1235 else {
1236 return;
1237 };
1238 let request = String::from_utf8_lossy(&buffer[..read]);
1239 let (status, body) = if request.starts_with("GET /missing ") {
1240 ("404 Not Found", "not found")
1241 } else if request.starts_with("GET /server-error ") {
1242 ("500 Internal Server Error", "server error")
1243 } else if request.starts_with("GET /timeout-status ") {
1244 ("408 Request Timeout", "timeout")
1245 } else if request.starts_with("GET /empty ") {
1246 ("200 OK", "")
1247 } else {
1248 ("200 OK", "not-json")
1249 };
1250 let response = format!(
1251 "HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1252 body.len()
1253 );
1254 let _ =
1255 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1256 });
1257 }
1258 });
1259
1260 let client = reqwest::Client::new();
1261 let missing_error =
1262 match fetch_discovery_document(&format!("http://{address}/missing"), &client).await {
1263 Ok(_) => return Err("expected missing discovery document to fail".into()),
1264 Err(error) => error,
1265 };
1266 assert_eq!(missing_error.code(), "discovery_not_found");
1267
1268 let server_error = match fetch_discovery_document(
1269 &format!("http://{address}/server-error"),
1270 &client,
1271 )
1272 .await
1273 {
1274 Ok(_) => return Err("expected server error discovery document to fail".into()),
1275 Err(error) => error,
1276 };
1277 assert_eq!(server_error.code(), "discovery_unexpected_error");
1278
1279 let timeout_error =
1280 match fetch_discovery_document(&format!("http://{address}/timeout-status"), &client)
1281 .await
1282 {
1283 Ok(_) => return Err("expected timeout discovery document to fail".into()),
1284 Err(error) => error,
1285 };
1286 assert_eq!(timeout_error.code(), "discovery_timeout");
1287
1288 let empty_response_error =
1289 match fetch_discovery_document(&format!("http://{address}/empty"), &client).await {
1290 Ok(_) => return Err("expected empty discovery document to fail".into()),
1291 Err(error) => error,
1292 };
1293 assert_eq!(empty_response_error.code(), "discovery_invalid_json");
1294
1295 let invalid_json_error = match fetch_discovery_document(
1296 &format!("http://{address}/invalid-json"),
1297 &client,
1298 )
1299 .await
1300 {
1301 Ok(_) => return Err("expected invalid JSON discovery document to fail".into()),
1302 Err(error) => error,
1303 };
1304 assert_eq!(invalid_json_error.code(), "discovery_invalid_json");
1305 Ok(())
1306 }
1307
1308 #[tokio::test]
1309 async fn discovery_rejects_untrusted_discovered_endpoint_origins(
1310 ) -> Result<(), Box<dyn std::error::Error>> {
1311 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1312 let address = listener.local_addr()?;
1313 let base_url = format!("http://{address}");
1314 let server_base_url = base_url.clone();
1315 tokio::spawn(async move {
1316 while let Ok((mut stream, _)) = listener.accept().await {
1317 let server_base_url = server_base_url.clone();
1318 tokio::spawn(async move {
1319 let mut buffer = [0_u8; 1024];
1320 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1321 else {
1322 return;
1323 };
1324 let request = String::from_utf8_lossy(&buffer[..read]);
1325 let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1326 format!(
1327 r#"{{
1328 "issuer":"{server_base_url}",
1329 "authorization_endpoint":"{server_base_url}/authorize",
1330 "token_endpoint":"https://untrusted.example.com/token",
1331 "jwks_uri":"{server_base_url}/keys",
1332 "userinfo_endpoint":"{server_base_url}/userinfo"
1333 }}"#
1334 )
1335 } else {
1336 r#"{"error":"not_found"}"#.to_owned()
1337 };
1338 let response = format!(
1339 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1340 body.len()
1341 );
1342 let _ =
1343 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1344 });
1345 }
1346 });
1347
1348 let error = match discover_oidc_config_with_origin_validator(
1349 &base_url,
1350 None,
1351 PartialOidcDiscoveryConfig::default(),
1352 |url| url.starts_with(&base_url),
1353 &reqwest::Client::new(),
1354 )
1355 .await
1356 {
1357 Ok(_) => return Err("expected untrusted discovered endpoint to fail".into()),
1358 Err(error) => error,
1359 };
1360 assert_eq!(error.code(), "discovery_untrusted_origin");
1361 Ok(())
1362 }
1363
1364 #[tokio::test]
1365 async fn discovery_rejects_untrusted_optional_endpoint_origins(
1366 ) -> Result<(), Box<dyn std::error::Error>> {
1367 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1368 let address = listener.local_addr()?;
1369 let base_url = format!("http://{address}");
1370 let server_base_url = base_url.clone();
1371 tokio::spawn(async move {
1372 while let Ok((mut stream, _)) = listener.accept().await {
1373 let server_base_url = server_base_url.clone();
1374 tokio::spawn(async move {
1375 let mut buffer = [0_u8; 1024];
1376 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1377 else {
1378 return;
1379 };
1380 let request = String::from_utf8_lossy(&buffer[..read]);
1381 let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1382 format!(
1383 r#"{{
1384 "issuer":"{server_base_url}",
1385 "authorization_endpoint":"{server_base_url}/authorize",
1386 "token_endpoint":"{server_base_url}/token",
1387 "jwks_uri":"{server_base_url}/keys",
1388 "revocation_endpoint":"https://untrusted.example.com/revoke"
1389 }}"#
1390 )
1391 } else {
1392 r#"{"error":"not_found"}"#.to_owned()
1393 };
1394 let response = format!(
1395 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1396 body.len()
1397 );
1398 let _ =
1399 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1400 });
1401 }
1402 });
1403
1404 let error = match discover_oidc_config_with_origin_validator(
1405 &base_url,
1406 None,
1407 PartialOidcDiscoveryConfig::default(),
1408 |url| url.starts_with(&base_url),
1409 &reqwest::Client::new(),
1410 )
1411 .await
1412 {
1413 Ok(_) => return Err("expected untrusted optional endpoint to fail".into()),
1414 Err(error) => error,
1415 };
1416 assert_eq!(error.code(), "discovery_untrusted_origin");
1417 assert!(error.to_string().contains("revocation_endpoint"));
1418 Ok(())
1419 }
1420
1421 #[tokio::test]
1422 async fn discover_ignores_empty_existing_endpoint_overrides(
1423 ) -> Result<(), Box<dyn std::error::Error>> {
1424 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1425 let address = listener.local_addr()?;
1426 let base_url = format!("http://{address}");
1427 let server_base_url = base_url.clone();
1428 tokio::spawn(async move {
1429 while let Ok((mut stream, _)) = listener.accept().await {
1430 let server_base_url = server_base_url.clone();
1431 tokio::spawn(async move {
1432 let mut buffer = [0_u8; 1024];
1433 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1434 else {
1435 return;
1436 };
1437 let request = String::from_utf8_lossy(&buffer[..read]);
1438 let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1439 format!(
1440 r#"{{
1441 "issuer":"{server_base_url}",
1442 "authorization_endpoint":"{server_base_url}/authorize",
1443 "token_endpoint":"{server_base_url}/token",
1444 "jwks_uri":"{server_base_url}/keys"
1445 }}"#
1446 )
1447 } else {
1448 r#"{"error":"not_found"}"#.to_owned()
1449 };
1450 let response = format!(
1451 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1452 body.len()
1453 );
1454 let _ =
1455 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1456 });
1457 }
1458 });
1459
1460 let hydrated = discover_oidc_config_with_origin_validator(
1461 &base_url,
1462 None,
1463 PartialOidcDiscoveryConfig {
1464 authorization_endpoint: Some(""),
1465 ..PartialOidcDiscoveryConfig::default()
1466 },
1467 |url| url.starts_with(&base_url),
1468 &reqwest::Client::new(),
1469 )
1470 .await?;
1471
1472 assert_eq!(
1473 hydrated.authorization_endpoint,
1474 format!("{base_url}/authorize")
1475 );
1476 Ok(())
1477 }
1478
1479 #[tokio::test]
1480 async fn discovery_preserves_user_supplied_endpoints_over_discovered_values(
1481 ) -> Result<(), Box<dyn std::error::Error>> {
1482 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1483 let address = listener.local_addr()?;
1484 let base_url = format!("http://{address}");
1485 let server_base_url = base_url.clone();
1486 tokio::spawn(async move {
1487 while let Ok((mut stream, _)) = listener.accept().await {
1488 let server_base_url = server_base_url.clone();
1489 tokio::spawn(async move {
1490 let mut buffer = [0_u8; 1024];
1491 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1492 else {
1493 return;
1494 };
1495 let request = String::from_utf8_lossy(&buffer[..read]);
1496 let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1497 format!(
1498 r#"{{
1499 "issuer":"{server_base_url}",
1500 "authorization_endpoint":"{server_base_url}/discovered/authorize",
1501 "token_endpoint":"{server_base_url}/discovered/token",
1502 "jwks_uri":"{server_base_url}/discovered/keys",
1503 "userinfo_endpoint":"{server_base_url}/discovered/userinfo",
1504 "token_endpoint_auth_methods_supported":["client_secret_post"]
1505 }}"#
1506 )
1507 } else {
1508 r#"{"error":"not_found"}"#.to_owned()
1509 };
1510 let response = format!(
1511 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1512 body.len()
1513 );
1514 let _ =
1515 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1516 });
1517 }
1518 });
1519
1520 let custom_authorization_endpoint = format!("{base_url}/custom/authorize");
1521 let custom_token_endpoint = format!("{base_url}/custom/token");
1522 let custom_user_info_endpoint = format!("{base_url}/custom/userinfo");
1523 let custom_jwks_endpoint = format!("{base_url}/custom/keys");
1524 let existing = PartialOidcDiscoveryConfig {
1525 authorization_endpoint: Some(&custom_authorization_endpoint),
1526 token_endpoint: Some(&custom_token_endpoint),
1527 user_info_endpoint: Some(&custom_user_info_endpoint),
1528 jwks_endpoint: Some(&custom_jwks_endpoint),
1529 token_endpoint_authentication: Some(TokenEndpointAuthentication::ClientSecretBasic),
1530 ..PartialOidcDiscoveryConfig::default()
1531 };
1532
1533 let hydrated = discover_oidc_config_with_origin_validator(
1534 &base_url,
1535 None,
1536 existing,
1537 |url| url.starts_with(&base_url),
1538 &reqwest::Client::new(),
1539 )
1540 .await?;
1541
1542 assert_eq!(
1543 hydrated.authorization_endpoint,
1544 custom_authorization_endpoint
1545 );
1546 assert_eq!(hydrated.token_endpoint, custom_token_endpoint);
1547 assert_eq!(hydrated.jwks_endpoint, custom_jwks_endpoint);
1548 assert_eq!(
1549 hydrated.user_info_endpoint.as_deref(),
1550 Some(custom_user_info_endpoint.as_str())
1551 );
1552 assert_eq!(
1553 hydrated.token_endpoint_authentication,
1554 TokenEndpointAuthentication::ClientSecretBasic
1555 );
1556 Ok(())
1557 }
1558
1559 #[tokio::test]
1560 async fn discover_uses_custom_and_existing_discovery_endpoints(
1561 ) -> Result<(), Box<dyn std::error::Error>> {
1562 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1563 let address = listener.local_addr()?;
1564 let base_url = format!("http://{address}");
1565 let server_base_url = base_url.clone();
1566 tokio::spawn(async move {
1567 while let Ok((mut stream, _)) = listener.accept().await {
1568 let server_base_url = server_base_url.clone();
1569 tokio::spawn(async move {
1570 let mut buffer = [0_u8; 4096];
1571 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1572 else {
1573 return;
1574 };
1575 let request = String::from_utf8_lossy(&buffer[..read]);
1576 let body = if request.contains("GET /custom/.well-known/openid-configuration ")
1577 {
1578 format!(
1579 r#"{{
1580 "issuer":"{server_base_url}",
1581 "authorization_endpoint":"{server_base_url}/authorize",
1582 "token_endpoint":"{server_base_url}/token",
1583 "jwks_uri":"{server_base_url}/keys"
1584 }}"#
1585 )
1586 } else if request.contains("GET /tenant/.well-known/openid-configuration ") {
1587 format!(
1588 r#"{{
1589 "issuer":"{server_base_url}",
1590 "authorization_endpoint":"{server_base_url}/tenant/authorize",
1591 "token_endpoint":"{server_base_url}/tenant/token",
1592 "jwks_uri":"{server_base_url}/tenant/keys"
1593 }}"#
1594 )
1595 } else {
1596 r#"{"error":"not_found"}"#.to_owned()
1597 };
1598 let response = format!(
1599 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1600 body.len()
1601 );
1602 let _ =
1603 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1604 });
1605 }
1606 });
1607
1608 let custom_endpoint = format!("{base_url}/custom/.well-known/openid-configuration");
1609 let custom = discover_oidc_config_with_origin_validator(
1610 &base_url,
1611 Some(&custom_endpoint),
1612 PartialOidcDiscoveryConfig::default(),
1613 |url| url.starts_with(&base_url),
1614 &reqwest::Client::new(),
1615 )
1616 .await?;
1617 assert_eq!(custom.discovery_endpoint, custom_endpoint);
1618
1619 let existing_endpoint = format!("{base_url}/tenant/.well-known/openid-configuration");
1620 let existing = discover_oidc_config_with_origin_validator(
1621 &base_url,
1622 None,
1623 PartialOidcDiscoveryConfig {
1624 discovery_endpoint: Some(&existing_endpoint),
1625 ..PartialOidcDiscoveryConfig::default()
1626 },
1627 |url| url.starts_with(&base_url),
1628 &reqwest::Client::new(),
1629 )
1630 .await?;
1631 assert_eq!(existing.discovery_endpoint, existing_endpoint);
1632 assert_eq!(
1633 existing.authorization_endpoint,
1634 format!("{base_url}/tenant/authorize")
1635 );
1636 Ok(())
1637 }
1638
1639 #[tokio::test]
1640 async fn discover_includes_scopes_supported_and_ignores_unknown_fields(
1641 ) -> Result<(), Box<dyn std::error::Error>> {
1642 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1643 let address = listener.local_addr()?;
1644 let base_url = format!("http://{address}");
1645 let server_base_url = base_url.clone();
1646 tokio::spawn(async move {
1647 while let Ok((mut stream, _)) = listener.accept().await {
1648 let server_base_url = server_base_url.clone();
1649 tokio::spawn(async move {
1650 let mut buffer = [0_u8; 1024];
1651 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1652 else {
1653 return;
1654 };
1655 let request = String::from_utf8_lossy(&buffer[..read]);
1656 let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1657 format!(
1658 r#"{{
1659 "issuer":"{server_base_url}",
1660 "authorization_endpoint":"{server_base_url}/authorize",
1661 "token_endpoint":"{server_base_url}/token",
1662 "jwks_uri":"{server_base_url}/keys",
1663 "scopes_supported":["openid","profile","email","custom"],
1664 "x-vendor-feature":true,
1665 "custom_logout_endpoint":"{server_base_url}/logout"
1666 }}"#
1667 )
1668 } else {
1669 r#"{"error":"not_found"}"#.to_owned()
1670 };
1671 let response = format!(
1672 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1673 body.len()
1674 );
1675 let _ =
1676 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1677 });
1678 }
1679 });
1680
1681 let hydrated = discover_oidc_config_with_origin_validator(
1682 &base_url,
1683 None,
1684 PartialOidcDiscoveryConfig::default(),
1685 |url| url.starts_with(&base_url),
1686 &reqwest::Client::new(),
1687 )
1688 .await?;
1689
1690 assert_eq!(
1691 hydrated.scopes_supported,
1692 Some(vec![
1693 "openid".to_owned(),
1694 "profile".to_owned(),
1695 "email".to_owned(),
1696 "custom".to_owned()
1697 ])
1698 );
1699 assert_eq!(hydrated.user_info_endpoint, None);
1700 Ok(())
1701 }
1702
1703 #[tokio::test]
1704 async fn discover_rejects_untrusted_main_discovery_url(
1705 ) -> Result<(), Box<dyn std::error::Error>> {
1706 let error = match discover_oidc_config_with_origin_validator(
1707 "https://idp.example.com",
1708 None,
1709 PartialOidcDiscoveryConfig::default(),
1710 |_| false,
1711 &reqwest::Client::new(),
1712 )
1713 .await
1714 {
1715 Ok(_) => return Err("expected untrusted discovery URL to fail".into()),
1716 Err(error) => error,
1717 };
1718 assert_eq!(error.code(), "discovery_untrusted_origin");
1719 assert!(error.to_string().contains("discovery_endpoint"));
1720 Ok(())
1721 }
1722
1723 #[tokio::test]
1724 async fn ensure_runtime_returns_unchanged_config_when_discovery_not_needed(
1725 ) -> Result<(), Box<dyn std::error::Error>> {
1726 let config = OidcConfig {
1727 issuer: "https://idp.example.com".to_owned(),
1728 pkce: true,
1729 client_id: "client-id".to_owned(),
1730 client_secret: "client-secret".into(),
1731 discovery_endpoint: compute_discovery_url("https://idp.example.com"),
1732 authorization_endpoint: Some("https://idp.example.com/authorize".to_owned()),
1733 token_endpoint: Some("https://idp.example.com/token".to_owned()),
1734 user_info_endpoint: Some("https://idp.example.com/userinfo".to_owned()),
1735 jwks_endpoint: Some("https://idp.example.com/keys".to_owned()),
1736 revocation_endpoint: None,
1737 end_session_endpoint: None,
1738 introspection_endpoint: None,
1739 token_endpoint_authentication: None,
1740 scopes: Some(vec!["openid".to_owned()]),
1741 mapping: None,
1742 override_user_info: false,
1743 };
1744
1745 let unchanged = ensure_runtime_oidc_config_with_origin_validator(
1746 "https://idp.example.com",
1747 config.clone(),
1748 OidcRuntimeRequirement::Callback,
1749 |_| true,
1750 false,
1751 &reqwest::Client::new(),
1752 )
1753 .await?;
1754
1755 assert_eq!(unchanged.client_id, config.client_id);
1756 assert_eq!(
1757 unchanged.client_secret.expose_secret(),
1758 config.client_secret.expose_secret()
1759 );
1760 assert_eq!(unchanged.pkce, config.pkce);
1761 assert_eq!(unchanged.scopes, config.scopes);
1762 assert_eq!(
1763 unchanged.authorization_endpoint,
1764 config.authorization_endpoint
1765 );
1766 Ok(())
1767 }
1768
1769 #[tokio::test]
1770 async fn ensure_runtime_throws_when_discovery_fails() -> Result<(), Box<dyn std::error::Error>>
1771 {
1772 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1773 let address = listener.local_addr()?;
1774 let base_url = format!("http://{address}");
1775 tokio::spawn(async move {
1776 while let Ok((mut stream, _)) = listener.accept().await {
1777 tokio::spawn(async move {
1778 let mut buffer = [0_u8; 1024];
1779 let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await;
1780 let response =
1781 "HTTP/1.1 404 Not Found\r\ncontent-type: application/json\r\ncontent-length: 2\r\nconnection: close\r\n\r\n{}";
1782 let _ =
1783 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1784 });
1785 }
1786 });
1787
1788 let config = OidcConfig {
1789 issuer: base_url.clone(),
1790 pkce: true,
1791 client_id: "client-id".to_owned(),
1792 client_secret: "client-secret".into(),
1793 discovery_endpoint: compute_discovery_url(&base_url),
1794 authorization_endpoint: None,
1795 token_endpoint: None,
1796 user_info_endpoint: None,
1797 jwks_endpoint: None,
1798 revocation_endpoint: None,
1799 end_session_endpoint: None,
1800 introspection_endpoint: None,
1801 token_endpoint_authentication: None,
1802 scopes: None,
1803 mapping: None,
1804 override_user_info: false,
1805 };
1806
1807 let error = match ensure_runtime_oidc_config_with_origin_validator(
1808 &base_url,
1809 config,
1810 OidcRuntimeRequirement::SignIn,
1811 |_| true,
1812 false,
1813 &reqwest::Client::new(),
1814 )
1815 .await
1816 {
1817 Ok(_) => return Err("expected runtime discovery failure".into()),
1818 Err(error) => error,
1819 };
1820 assert_eq!(error.code(), "discovery_not_found");
1821 Ok(())
1822 }
1823
1824 #[tokio::test]
1825 async fn runtime_discovery_preserves_only_explicit_request_scopes(
1826 ) -> Result<(), Box<dyn std::error::Error>> {
1827 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1828 let address = listener.local_addr()?;
1829 let base_url = format!("http://{address}");
1830 let server_base_url = base_url.clone();
1831 tokio::spawn(async move {
1832 while let Ok((mut stream, _)) = listener.accept().await {
1833 let server_base_url = server_base_url.clone();
1834 tokio::spawn(async move {
1835 let mut buffer = [0_u8; 1024];
1836 let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1837 else {
1838 return;
1839 };
1840 let request = String::from_utf8_lossy(&buffer[..read]);
1841 let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1842 format!(
1843 r#"{{
1844 "issuer":"{server_base_url}",
1845 "authorization_endpoint":"{server_base_url}/authorize",
1846 "token_endpoint":"{server_base_url}/token",
1847 "jwks_uri":"{server_base_url}/keys",
1848 "scopes_supported":["openid","profile"]
1849 }}"#
1850 )
1851 } else {
1852 r#"{"error":"not_found"}"#.to_owned()
1853 };
1854 let response = format!(
1855 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1856 body.len()
1857 );
1858 let _ =
1859 tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1860 });
1861 }
1862 });
1863
1864 let config = OidcConfig {
1865 issuer: base_url.clone(),
1866 pkce: true,
1867 client_id: "client".to_owned(),
1868 client_secret: "secret".into(),
1869 discovery_endpoint: compute_discovery_url(&base_url),
1870 authorization_endpoint: None,
1871 token_endpoint: None,
1872 user_info_endpoint: None,
1873 jwks_endpoint: None,
1874 revocation_endpoint: None,
1875 end_session_endpoint: None,
1876 introspection_endpoint: None,
1877 token_endpoint_authentication: None,
1878 scopes: None,
1879 mapping: None,
1880 override_user_info: false,
1881 };
1882
1883 let hydrated = ensure_runtime_oidc_config_with_origin_validator(
1884 &base_url,
1885 config,
1886 OidcRuntimeRequirement::SignIn,
1887 |url| url.starts_with(&base_url),
1888 false,
1889 &reqwest::Client::new(),
1890 )
1891 .await?;
1892
1893 assert_eq!(hydrated.scopes, None);
1894
1895 let explicit_config = OidcConfig {
1896 scopes: Some(vec!["openid".to_owned(), "email".to_owned()]),
1897 authorization_endpoint: None,
1898 token_endpoint: None,
1899 jwks_endpoint: None,
1900 ..hydrated
1901 };
1902 let explicit_hydrated = ensure_runtime_oidc_config_with_origin_validator(
1903 &base_url,
1904 explicit_config,
1905 OidcRuntimeRequirement::SignIn,
1906 |url| url.starts_with(&base_url),
1907 false,
1908 &reqwest::Client::new(),
1909 )
1910 .await?;
1911
1912 assert_eq!(
1913 explicit_hydrated.scopes,
1914 Some(vec!["openid".to_owned(), "email".to_owned()])
1915 );
1916 Ok(())
1917 }
1918}