Skip to main content

securitydept_oidc_client/
client.rs

1use std::{borrow::Cow, cmp::min, sync::Arc, time::Duration};
2
3use base64::Engine;
4use chrono::Utc;
5use openidconnect::{
6    AccessToken, AuthType, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret,
7    CsrfToken, DeviceAuthorizationUrl, DeviceCodeErrorResponse, DeviceCodeErrorResponseType,
8    EndpointMaybeSet, EndpointNotSet, EndpointSet, IntrospectionUrl, Nonce, OAuth2TokenResponse,
9    PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, Scope,
10    StandardErrorResponse, StandardTokenResponse, SubjectIdentifier, TokenResponse,
11    core::{
12        CoreAuthDisplay, CoreAuthPrompt, CoreClientAuthMethod, CoreDeviceAuthorizationResponse,
13        CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
14        CoreJwsSigningAlgorithm, CoreRevocableToken, CoreRevocationErrorResponse,
15        CoreTokenIntrospectionResponse, CoreTokenType,
16    },
17    reqwest,
18};
19use securitydept_oauth_provider::{OAuthProviderRuntime, ProviderMetadataWithExtra};
20use securitydept_utils::observability::{
21    AuthFlowDiagnosis, AuthFlowDiagnosisField, AuthFlowDiagnosisOutcome, AuthFlowOperation,
22    DiagnosedResult,
23};
24use url::Url;
25
26#[cfg(not(feature = "claims-script"))]
27use crate::claims::DefaultClaimsChecker;
28#[cfg(feature = "claims-script")]
29use crate::claims::ScriptClaimsChecker;
30use crate::{
31    ClaimsCheckResult, ExtraOidcClaims, IdTokenClaimsWithExtra, OidcCodeCallbackSearchParams,
32    OidcCodeExchangeResult, OidcCodeFlowAuthorizationRequest, OidcDeviceAuthorizationResult,
33    OidcDeviceTokenPollResult, OidcDeviceTokenResult, OidcRevocableToken, PendingOauthStore,
34    PendingOauthStoreConfig, UserInfoClaimsWithExtra, UserInfoExchangeResult,
35    claims::ClaimsChecker,
36    config::OidcClientConfig,
37    error::{OidcError, OidcResult},
38    models::{IdTokenFieldsWithExtra, OidcCodeCallbackResult, OidcRefreshTokenResult},
39};
40
/// Token-endpoint response type used by this client: a standard OAuth 2.0
/// token response whose extra fields are [`IdTokenFieldsWithExtra`] (carrying
/// the ID token) and whose token type is [`CoreTokenType`].
pub type TokenResponseWithExtra = StandardTokenResponse<IdTokenFieldsWithExtra, CoreTokenType>;
42
/// `openidconnect::Client` specialised with this crate's extra claims
/// ([`ExtraOidcClaims`]), token response ([`TokenResponseWithExtra`]) and the
/// `Core*` algorithm / response types.
///
/// The six trailing type parameters are openidconnect endpoint typestates
/// (`EndpointSet`, `EndpointNotSet` or `EndpointMaybeSet`). Their positional
/// order is: authorization, device authorization, introspection, revocation,
/// token, userinfo. All of them default to `EndpointNotSet`.
pub type ClientWithExtra<
    HasAuthUrl = EndpointNotSet,
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasTokenUrl = EndpointNotSet,
    HasUserInfoUrl = EndpointNotSet,
> = Client<
    ExtraOidcClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    TokenResponseWithExtra,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    // Endpoint typestates follow, in the order documented above.
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
69
/// Base client type built from provider discovery metadata.
///
/// The authorization endpoint is always set. The token and userinfo endpoints
/// are "maybe set", i.e. present only if the metadata provides them. The
/// device-authorization, introspection and revocation endpoints are left unset
/// at the type level.
pub type DiscoveredClientWithExtra = ClientWithExtra<
    EndpointSet,      // authorization
    EndpointNotSet,   // device authorization
    EndpointNotSet,   // introspection
    EndpointNotSet,   // revocation
    EndpointMaybeSet, // token
    EndpointMaybeSet, // userinfo
>;
78
/// Like [`DiscoveredClientWithExtra`], but with the device-authorization
/// endpoint set, so that `exchange_device_code` is available.
// NOTE(review): presumably produced by `fresh_device_authorization_client`
// (defined outside this view) — confirm.
type DeviceAuthorizationClientWithExtra = ClientWithExtra<
    EndpointSet,      // authorization
    EndpointSet,      // device authorization
    EndpointNotSet,   // introspection
    EndpointNotSet,   // revocation
    EndpointMaybeSet, // token
    EndpointMaybeSet, // userinfo
>;
87
/// Like [`DiscoveredClientWithExtra`], but with the revocation endpoint set,
/// so that `revoke_token` is available.
// NOTE(review): presumably produced by `fresh_revocation_client` (defined
// outside this view) — confirm.
type RevocationClientWithExtra = ClientWithExtra<
    EndpointSet,      // authorization
    EndpointNotSet,   // device authorization
    EndpointNotSet,   // introspection
    EndpointSet,      // revocation
    EndpointMaybeSet, // token
    EndpointMaybeSet, // userinfo
>;
96
97struct OptionalClientEndpoints {
98    _introspection_endpoint: Option<IntrospectionUrl>,
99    revocation_endpoint: Option<RevocationUrl>,
100    device_authorization_endpoint: Option<DeviceAuthorizationUrl>,
101}
102
103struct BuiltClientWithExtra {
104    client: DiscoveredClientWithExtra,
105    optional_endpoints: OptionalClientEndpoints,
106}
107
/// Wraps the OIDC discovered client for login/callback flows, plus the device
/// authorization, token refresh, token revocation and user-info exchange
/// flows.
///
/// The redirect URI is resolved dynamically per-request so that
/// `external_base_url = "auto"` can produce the correct absolute callback URL
/// based on the incoming request headers.
pub struct OidcClient<PS>
where
    PS: PendingOauthStore,
{
    /// Client configuration, including the pending-store configuration.
    config: OidcClientConfig<PS::Config>,
    /// Shared provider runtime: supplies OIDC metadata and the HTTP client.
    provider: Arc<OAuthProviderRuntime>,
    /// Client built from provider metadata at construction time. It is cloned
    /// per request when a redirect URI has to be attached.
    base_client: DiscoveredClientWithExtra,
    /// Claims checker: script-based with the `claims-script` feature,
    /// otherwise the default checker.
    #[cfg(feature = "claims-script")]
    claims_checker: ScriptClaimsChecker,
    #[cfg(not(feature = "claims-script"))]
    claims_checker: DefaultClaimsChecker,
    /// Scopes added to device-authorization requests. It is initialised empty
    /// and presumably populated by `with_runtime_flags` (not shown) — confirm.
    scopes: Vec<String>,
    /// PKCE toggle, reported in flow diagnostics. It is initialised `false`
    /// and presumably set by `with_runtime_flags` (not shown) — confirm.
    pkce_enabled: bool,
    /// Store of in-flight authorization state (CSRF state, nonce, optional
    /// PKCE verifier and extra data), consumed on callback.
    pending_oauth_store: PS,
}
128
129impl<PS> OidcClient<PS>
130where
131    PS: PendingOauthStore,
132{
133    pub async fn from_config(config: OidcClientConfig<PS::Config>) -> OidcResult<Self> {
134        config.validate()?;
135        let provider = Arc::new(OAuthProviderRuntime::from_config(config.provider_config()).await?);
136        Self::from_provider(provider, config).await
137    }
138
139    pub async fn from_provider(
140        provider: Arc<OAuthProviderRuntime>,
141        config: OidcClientConfig<PS::Config>,
142    ) -> OidcResult<Self> {
143        config.validate()?;
144
145        let built_client = build_client(&config, provider.oidc_provider_metadata().await?)
146            .map_err(|e| OidcError::Metadata {
147                message: format!("Failed to build OIDC client from provider metadata: {e}"),
148            })?;
149
150        #[cfg(feature = "claims-script")]
151        let claims_checker =
152            ScriptClaimsChecker::from_file(config.claims_check_script.as_deref()).await?;
153        #[cfg(not(feature = "claims-script"))]
154        let claims_checker = DefaultClaimsChecker;
155
156        Ok(Self {
157            pending_oauth_store: PS::from_config_opt(config.pending_store.as_ref()),
158            config,
159            provider,
160            base_client: built_client.client,
161            claims_checker,
162            scopes: vec![],
163            pkce_enabled: false,
164        }
165        .with_runtime_flags())
166    }
167
168    pub fn provider(&self) -> &Arc<OAuthProviderRuntime> {
169        &self.provider
170    }
171
172    pub async fn handle_code_authorize(
173        &self,
174        external_base_url: &Url,
175    ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
176        self.handle_code_authorize_with_redirect_override(external_base_url, None)
177            .await
178    }
179
180    pub async fn handle_code_authorize_with_redirect_override(
181        &self,
182        external_base_url: &Url,
183        redirect_url_override: Option<&str>,
184    ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
185        self.handle_code_authorize_with_redirect_override_and_extra_data(
186            external_base_url,
187            redirect_url_override,
188            None,
189        )
190        .await
191    }
192
193    pub async fn handle_code_authorize_with_redirect_override_and_extra_data(
194        &self,
195        external_base_url: &Url,
196        redirect_url_override: Option<&str>,
197        extra_data: Option<serde_json::Value>,
198    ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
199        let authorization_request =
200            self.authorize_url_with_redirect_override(external_base_url, redirect_url_override)?;
201        self.pending_oauth_store
202            .insert(
203                authorization_request.csrf_token.secret().to_string(),
204                authorization_request.nonce.secret().to_string(),
205                authorization_request.pkce_verifier_secret.clone(),
206                extra_data,
207            )
208            .await?;
209        Ok(authorization_request)
210    }
211
212    pub async fn handle_device_authorize(&self) -> OidcResult<OidcDeviceAuthorizationResult> {
213        let client = self.fresh_device_authorization_client().await?;
214        let mut request = client.exchange_device_code();
215
216        for scope in &self.scopes {
217            request = request.add_scope(Scope::new(scope.clone()));
218        }
219
220        let details: CoreDeviceAuthorizationResponse = request
221            .request_async(self.provider.http_client())
222            .await
223            .map_err(|e| OidcError::DeviceAuthorization {
224                message: format!("Device authorization request failed: {e}"),
225            })?;
226
227        Ok(OidcDeviceAuthorizationResult {
228            device_code: details.device_code().secret().to_string(),
229            user_code: details.user_code().secret().to_string(),
230            verification_uri: details.verification_uri().to_string(),
231            verification_uri_complete: details
232                .verification_uri_complete()
233                .map(|value| value.secret().to_string()),
234            expires_in: details.expires_in(),
235            interval: Some(details.interval()),
236        })
237    }
238
239    pub async fn handle_device_token_poll(
240        &self,
241        device_authorization: &OidcDeviceAuthorizationResult,
242        current_interval: Option<Duration>,
243    ) -> OidcResult<OidcDeviceTokenPollResult> {
244        let current_interval = current_interval.unwrap_or_else(|| {
245            device_authorization.poll_interval(self.config.device_poll_interval)
246        });
247
248        match self.request_device_token_once(device_authorization).await? {
249            DeviceTokenPollResponse::Complete(token_response) => {
250                let token_result = self.build_device_token_result(*token_response).await?;
251                Ok(OidcDeviceTokenPollResult::Complete {
252                    token_result: Box::new(token_result),
253                })
254            }
255            DeviceTokenPollResponse::Pending => Ok(OidcDeviceTokenPollResult::Pending {
256                interval: current_interval,
257            }),
258            DeviceTokenPollResponse::SlowDown => Ok(OidcDeviceTokenPollResult::SlowDown {
259                interval: current_interval.saturating_add(Duration::from_secs(5)),
260            }),
261            DeviceTokenPollResponse::Denied { error_description } => {
262                Ok(OidcDeviceTokenPollResult::Denied { error_description })
263            }
264            DeviceTokenPollResponse::Expired { error_description } => {
265                Ok(OidcDeviceTokenPollResult::Expired { error_description })
266            }
267        }
268    }
269
270    pub async fn handle_device_token_poll_until_complete(
271        &self,
272        device_authorization: &OidcDeviceAuthorizationResult,
273        timeout: Option<Duration>,
274    ) -> OidcResult<OidcDeviceTokenResult> {
275        let started_at = std::time::Instant::now();
276        let mut interval = device_authorization.poll_interval(self.config.device_poll_interval);
277
278        // Enforce a minimum interval of 1 second to prevent busy-polling
279        // when the server returns interval=0.
280        const MIN_POLL_INTERVAL: Duration = Duration::from_secs(1);
281
282        loop {
283            if let Some(timeout) = timeout {
284                let elapsed = started_at.elapsed();
285                if elapsed >= timeout {
286                    return Err(OidcError::DeviceTokenPoll {
287                        message: format!(
288                            "Device token polling timed out after {} seconds",
289                            timeout.as_secs()
290                        ),
291                    });
292                }
293            }
294
295            match self
296                .handle_device_token_poll(device_authorization, Some(interval))
297                .await?
298            {
299                OidcDeviceTokenPollResult::Complete { token_result } => return Ok(*token_result),
300                OidcDeviceTokenPollResult::Pending {
301                    interval: next_interval,
302                }
303                | OidcDeviceTokenPollResult::SlowDown {
304                    interval: next_interval,
305                } => {
306                    interval = next_interval.max(MIN_POLL_INTERVAL);
307                    let sleep_duration = if let Some(timeout) = timeout {
308                        let remaining = timeout.saturating_sub(started_at.elapsed());
309                        min(interval, remaining)
310                    } else {
311                        interval
312                    };
313                    tokio::time::sleep(sleep_duration).await;
314                }
315                OidcDeviceTokenPollResult::Denied { error_description } => {
316                    return Err(OidcError::DeviceTokenPoll {
317                        message: format_device_token_terminal_message(
318                            "access_denied",
319                            error_description.as_deref(),
320                        ),
321                    });
322                }
323                OidcDeviceTokenPollResult::Expired { error_description } => {
324                    return Err(OidcError::DeviceTokenPoll {
325                        message: format_device_token_terminal_message(
326                            "expired_token",
327                            error_description.as_deref(),
328                        ),
329                    });
330                }
331            }
332        }
333    }
334
335    pub async fn handle_code_callback(
336        &self,
337        search_params: OidcCodeCallbackSearchParams,
338        external_base_url: &Url,
339    ) -> OidcResult<OidcCodeCallbackResult> {
340        self.handle_code_callback_with_redirect_override_diagnosed(
341            search_params,
342            external_base_url,
343            None,
344        )
345        .await
346        .into_result()
347    }
348
349    pub async fn handle_code_callback_with_redirect_override(
350        &self,
351        search_params: OidcCodeCallbackSearchParams,
352        external_base_url: &Url,
353        redirect_url_override: Option<&str>,
354    ) -> OidcResult<OidcCodeCallbackResult> {
355        self.handle_code_callback_with_redirect_override_diagnosed(
356            search_params,
357            external_base_url,
358            redirect_url_override,
359        )
360        .await
361        .into_result()
362    }
363
364    pub async fn handle_code_callback_with_redirect_override_diagnosed(
365        &self,
366        search_params: OidcCodeCallbackSearchParams,
367        external_base_url: &Url,
368        redirect_url_override: Option<&str>,
369    ) -> DiagnosedResult<OidcCodeCallbackResult, OidcError> {
370        let diagnosis = AuthFlowDiagnosis::started(AuthFlowOperation::OIDC_CALLBACK)
371            .field("redirect_override", redirect_url_override)
372            .field(
373                AuthFlowDiagnosisField::EXTERNAL_BASE_URL,
374                external_base_url.as_str(),
375            )
376            .field("pkce_enabled", self.pkce_enabled)
377            .field(
378                AuthFlowDiagnosisField::HAS_STATE,
379                search_params.state.is_some(),
380            )
381            .field(
382                AuthFlowDiagnosisField::HAS_CODE,
383                !search_params.code.is_empty(),
384            );
385
386        let code = &search_params.code;
387        let state = search_params
388            .state
389            .as_ref()
390            .ok_or_else(|| OidcError::CSRFValidation {
391                message: "Missing state parameter in callback (required for CSRF validation)"
392                    .to_string(),
393            });
394
395        let state = match state {
396            Ok(state) => state,
397            Err(error) => {
398                return DiagnosedResult::failure(
399                    diagnosis
400                        .with_outcome(AuthFlowDiagnosisOutcome::Rejected)
401                        .field(AuthFlowDiagnosisField::FAILURE_STAGE, "csrf_validation"),
402                    error,
403                );
404            }
405        };
406
407        let pending = match self.pending_oauth_store.take(state).await {
408            Ok(pending) => pending.ok_or_else(|| OidcError::PendingOauth {
409                source: "Invalid or expired state (reuse or unknown); try logging in again"
410                    .to_string()
411                    .into(),
412            }),
413            Err(error) => {
414                return DiagnosedResult::failure(
415                    diagnosis
416                        .with_outcome(AuthFlowDiagnosisOutcome::Failed)
417                        .field(AuthFlowDiagnosisField::FAILURE_STAGE, "pending_oauth_store"),
418                    error,
419                );
420            }
421        };
422
423        let pending = match pending {
424            Ok(pending) => pending,
425            Err(error) => {
426                return DiagnosedResult::failure(
427                    diagnosis
428                        .with_outcome(AuthFlowDiagnosisOutcome::Rejected)
429                        .field(AuthFlowDiagnosisField::FAILURE_STAGE, "pending_oauth_state"),
430                    error,
431                );
432            }
433        };
434
435        let nonce = openidconnect::Nonce::new(pending.nonce.clone());
436        let code_verifier = pending.code_verifier;
437
438        let code_exchange = self
439            .exchange_code_with_redirect_override(
440                external_base_url,
441                code,
442                &nonce,
443                code_verifier.as_deref(),
444                redirect_url_override,
445            )
446            .await;
447
448        let code_exchange = match code_exchange {
449            Ok(code_exchange) => code_exchange,
450            Err(error) => {
451                return DiagnosedResult::failure(
452                    diagnosis
453                        .with_outcome(AuthFlowDiagnosisOutcome::Failed)
454                        .field(AuthFlowDiagnosisField::FAILURE_STAGE, "token_exchange"),
455                    error,
456                );
457            }
458        };
459
460        let claims_check_result = self
461            .check_claims(
462                &code_exchange.id_token_claims,
463                code_exchange.user_info_claims.as_ref(),
464            )
465            .await;
466
467        let claims_check_result = match claims_check_result {
468            Ok(claims_check_result) => claims_check_result,
469            Err(error) => {
470                return DiagnosedResult::failure(
471                    diagnosis
472                        .with_outcome(AuthFlowDiagnosisOutcome::Failed)
473                        .field(AuthFlowDiagnosisField::FAILURE_STAGE, "claims_check"),
474                    error,
475                );
476            }
477        };
478
479        let result = OidcCodeCallbackResult {
480            code: search_params.code,
481            pkce_verifier_secret: code_verifier,
482            state: search_params.state,
483            nonce: pending.nonce,
484            pending_extra_data: pending.extra_data,
485            access_token: code_exchange.access_token,
486            access_token_expiration: code_exchange.access_token_expiration,
487            id_token: code_exchange.id_token,
488            refresh_token: code_exchange.refresh_token,
489            id_token_claims: code_exchange.id_token_claims,
490            user_info_claims: code_exchange.user_info_claims,
491            claims_check_result,
492        };
493
494        DiagnosedResult::success(
495            diagnosis
496                .with_outcome(AuthFlowDiagnosisOutcome::Succeeded)
497                .field(
498                    AuthFlowDiagnosisField::SUBJECT,
499                    result.id_token_claims.subject().to_string(),
500                )
501                .field("has_refresh_token", result.refresh_token.is_some())
502                .field("has_user_info_claims", result.user_info_claims.is_some()),
503            result,
504        )
505    }
506
507    pub async fn handle_token_refresh(
508        &self,
509        refresh_token: String,
510        // optional previous id_token to prevent not return new id_token
511        id_token: Option<String>,
512    ) -> OidcResult<OidcRefreshTokenResult> {
513        self.handle_token_refresh_diagnosed(refresh_token, id_token)
514            .await
515            .into_result()
516    }
517
518    pub async fn handle_token_refresh_diagnosed(
519        &self,
520        refresh_token: String,
521        id_token: Option<String>,
522    ) -> DiagnosedResult<OidcRefreshTokenResult, OidcError> {
523        let diagnosis = AuthFlowDiagnosis::started(AuthFlowOperation::OIDC_TOKEN_REFRESH)
524            .field("has_previous_id_token", id_token.is_some())
525            .field("pkce_enabled", self.pkce_enabled);
526
527        let client = match self.fresh_client().await {
528            Ok(client) => client,
529            Err(error) => {
530                return DiagnosedResult::failure(
531                    diagnosis
532                        .with_outcome(AuthFlowDiagnosisOutcome::Failed)
533                        .field(
534                            AuthFlowDiagnosisField::FAILURE_STAGE,
535                            "client_metadata_refresh",
536                        ),
537                    error,
538                );
539            }
540        };
541        let refresh_token = RefreshToken::new(refresh_token);
542        let now = Utc::now();
543
544        let token_request =
545            client
546                .exchange_refresh_token(&refresh_token)
547                .map_err(|e| OidcError::TokenRefresh {
548                    message: format!("Token endpoint not set or config error: {e}"),
549                });
550
551        let token_request = match token_request {
552            Ok(token_request) => token_request,
553            Err(error) => {
554                return DiagnosedResult::failure(
555                    diagnosis
556                        .with_outcome(AuthFlowDiagnosisOutcome::Failed)
557                        .field(
558                            AuthFlowDiagnosisField::FAILURE_STAGE,
559                            "token_refresh_request_build",
560                        ),
561                    error,
562                );
563            }
564        };
565
566        let token_response = token_request
567            .request_async(self.provider.http_client())
568            .await
569            .map_err(|e| OidcError::TokenRefresh {
570                message: format!("Refresh token request failed: {e}"),
571            });
572
573        let token_response = match token_response {
574            Ok(token_response) => token_response,
575            Err(error) => {
576                return DiagnosedResult::failure(
577                    diagnosis
578                        .with_outcome(AuthFlowDiagnosisOutcome::Failed)
579                        .field(AuthFlowDiagnosisField::FAILURE_STAGE, "token_refresh"),
580                    error,
581                );
582            }
583        };
584
585        let access_token = token_response.access_token().secret().clone();
586        let access_token_expiration = token_response
587            .expires_in()
588            .map(|expires_in| now + expires_in);
589        let refresh_token = token_response
590            .refresh_token()
591            .map(|value| value.secret().clone());
592        let id_token = token_response
593            .id_token()
594            .map(|value| value.to_string())
595            .or(id_token);
596
597        let mut result = OidcRefreshTokenResult {
598            access_token,
599            access_token_expiration,
600            refresh_token,
601            id_token,
602            user_info_claims: None,
603            claims_check_result: None,
604            id_token_claims: None,
605        };
606
607        // Validate required scopes after successful refresh.
608        if let Err(error) = self.check_required_scopes(token_response.scopes()) {
609            return DiagnosedResult::failure(
610                diagnosis
611                    .with_outcome(AuthFlowDiagnosisOutcome::Failed)
612                    .field("failure_stage", "scope_validation"),
613                error,
614            );
615        }
616
617        if let Some(next_id_token) = token_response.extra_fields().id_token() {
618            let id_token_verifier = client.id_token_verifier();
619            let id_token_claims = next_id_token
620                .claims(&id_token_verifier, |_nonce: Option<&Nonce>| Ok(()))
621                .map_err(|e| OidcError::TokenRefresh {
622                    message: format!("Failed to verify refreshed ID token: {e}"),
623                });
624            let id_token_claims = match id_token_claims {
625                Ok(id_token_claims) => id_token_claims,
626                Err(error) => {
627                    return DiagnosedResult::failure(
628                        diagnosis
629                            .with_outcome(AuthFlowDiagnosisOutcome::Failed)
630                            .field("failure_stage", "id_token_verification"),
631                        error,
632                    );
633                }
634            };
635            let user_info_claims = if client.user_info_url().is_some() {
636                match self
637                    .request_userinfo(
638                        &client,
639                        self.provider.http_client(),
640                        token_response.access_token().clone(),
641                        Some(id_token_claims.subject().clone()),
642                    )
643                    .await
644                {
645                    Ok(user_info_claims) => Some(user_info_claims),
646                    Err(error) => {
647                        return DiagnosedResult::failure(
648                            diagnosis
649                                .with_outcome(AuthFlowDiagnosisOutcome::Failed)
650                                .field("failure_stage", "userinfo_exchange"),
651                            error,
652                        );
653                    }
654                }
655            } else {
656                None
657            };
658            let claims_check_result = self
659                .check_claims(id_token_claims, user_info_claims.as_ref())
660                .await;
661            let claims_check_result = match claims_check_result {
662                Ok(claims_check_result) => claims_check_result,
663                Err(error) => {
664                    return DiagnosedResult::failure(
665                        diagnosis
666                            .with_outcome(AuthFlowDiagnosisOutcome::Failed)
667                            .field("failure_stage", "claims_check"),
668                        error,
669                    );
670                }
671            };
672            result.id_token = Some(next_id_token.to_string());
673            result.id_token_claims = Some(id_token_claims.clone());
674            result.user_info_claims = user_info_claims;
675            result.claims_check_result = Some(claims_check_result);
676        }
677
678        DiagnosedResult::success(
679            diagnosis
680                .with_outcome(AuthFlowDiagnosisOutcome::Succeeded)
681                .field("has_refresh_token", result.refresh_token.is_some())
682                .field("has_new_id_token", result.id_token.is_some())
683                .field(
684                    "has_claims_check_result",
685                    result.claims_check_result.is_some(),
686                ),
687            result,
688        )
689    }
690
691    pub async fn handle_token_revoke(&self, token: OidcRevocableToken) -> OidcResult<()> {
692        let client = self.fresh_revocation_client().await?;
693        let token: CoreRevocableToken = match token {
694            OidcRevocableToken::AccessToken(token) => AccessToken::new(token).into(),
695            OidcRevocableToken::RefreshToken(token) => RefreshToken::new(token).into(),
696        };
697
698        client
699            .revoke_token(token)
700            .map_err(|e| OidcError::TokenRevocation {
701                message: format!("Revocation endpoint not set or config error: {e}"),
702            })?
703            .request_async(self.provider.http_client())
704            .await
705            .map_err(|e| OidcError::TokenRevocation {
706                message: format!("Token revocation request failed: {e}"),
707            })
708    }
709
710    /// Shared `user_info` exchange helper for backend modes.
711    ///
712    /// Given a raw `id_token` string and a bearer `access_token`, this method:
713    ///
714    /// 1. Decodes and verifies the ID token (nonce validation is skipped, since
715    ///    this is a server-side post-flow call, not an in-flight callback).
716    /// 2. Optionally calls the provider's userinfo endpoint (if available).
717    /// 3. Runs `check_claims` to produce a `ClaimsCheckResult`.
718    ///
719    /// Backend OIDC presets (pure, mediated, etc.) should call
720    /// this helper rather than reimplementing the user-info protocol stack.
721    pub async fn handle_user_info_exchange(
722        &self,
723        id_token_raw: &str,
724        access_token: &str,
725    ) -> OidcResult<UserInfoExchangeResult> {
726        let client = self.fresh_client().await?;
727        let id_token_verifier = client.id_token_verifier();
728
729        // Parse the raw ID token string into the typed token via serde.
730        let id_token: openidconnect::IdToken<
731            ExtraOidcClaims,
732            CoreGenderClaim,
733            CoreJweContentEncryptionAlgorithm,
734            CoreJwsSigningAlgorithm,
735        > = serde_json::from_value(serde_json::Value::String(id_token_raw.to_string())).map_err(
736            |e| OidcError::Claims {
737                message: format!("Failed to parse ID token string in user_info exchange: {e}"),
738            },
739        )?;
740
741        // Verify and decode — skip nonce for server-side post-flow calls.
742        let id_token_claims = id_token
743            .claims(&id_token_verifier, |_nonce: Option<&Nonce>| Ok(()))
744            .map_err(|e| OidcError::Claims {
745                message: format!("Failed to verify ID token in user_info exchange: {e}"),
746            })?;
747
748        let access_token_obj = AccessToken::new(access_token.to_string());
749
750        let user_info_claims = if client.user_info_url().is_some() {
751            Some(
752                self.request_userinfo(
753                    &client,
754                    self.provider.http_client(),
755                    access_token_obj,
756                    Some(id_token_claims.subject().clone()),
757                )
758                .await?,
759            )
760        } else {
761            None
762        };
763
764        let claims_check_result = self
765            .check_claims(id_token_claims, user_info_claims.as_ref())
766            .await?;
767
768        let issuer = id_token_claims.issuer().url().to_string();
769
770        Ok(UserInfoExchangeResult {
771            subject: id_token_claims.subject().to_string(),
772            display_name: claims_check_result.display_name,
773            picture: claims_check_result.picture,
774            issuer: Some(issuer),
775            claims: Some(claims_check_result.claims),
776        })
777    }
778
779    async fn request_userinfo(
780        &self,
781        client: &DiscoveredClientWithExtra,
782        http_client: &reqwest::Client,
783        access_token: openidconnect::AccessToken,
784        expected_subject: Option<SubjectIdentifier>,
785    ) -> OidcResult<UserInfoClaimsWithExtra> {
786        client
787            .user_info(access_token, expected_subject)
788            .map_err(|e| OidcError::Claims {
789                message: format!("UserInfo request configuration failed: {e}"),
790            })?
791            .request_async(http_client)
792            .await
793            .map_err(|e| OidcError::Claims {
794                message: format!("UserInfo request failed: {e}"),
795            })
796    }
797
798    async fn check_claims(
799        &self,
800        id_token_claims: &IdTokenClaimsWithExtra,
801        user_info_claims: Option<&UserInfoClaimsWithExtra>,
802    ) -> OidcResult<ClaimsCheckResult> {
803        self.claims_checker
804            .check_claims(id_token_claims, user_info_claims)
805            .await
806    }
807
808    fn resolve_redirect_url(
809        &self,
810        external_base_url: &Url,
811        redirect_url_override: Option<&str>,
812    ) -> OidcResult<Url> {
813        external_base_url
814            .join(redirect_url_override.unwrap_or(&self.config.redirect_url))
815            .map_err(|e| OidcError::RedirectUrl { source: e })
816    }
817
818    fn client_with_redirect_override(
819        &self,
820        external_base_url: &Url,
821        redirect_url_override: Option<&str>,
822    ) -> OidcResult<DiscoveredClientWithExtra> {
823        let redirect_url = self.resolve_redirect_url(external_base_url, redirect_url_override)?;
824        Ok(self
825            .base_client
826            .clone()
827            .set_redirect_uri(RedirectUrl::from_url(redirect_url)))
828    }
829
830    async fn fresh_client(&self) -> OidcResult<DiscoveredClientWithExtra> {
831        Ok(self.fresh_client_parts().await?.client)
832    }
833
834    async fn fresh_client_parts(&self) -> OidcResult<BuiltClientWithExtra> {
835        build_client(&self.config, self.provider.oidc_provider_metadata().await?).map_err(|e| {
836            OidcError::Metadata {
837                message: format!("Failed to rebuild OIDC client from provider metadata: {e}"),
838            }
839        })
840    }
841
842    async fn fresh_device_authorization_client(
843        &self,
844    ) -> OidcResult<DeviceAuthorizationClientWithExtra> {
845        let built_client = self.fresh_client_parts().await?;
846        let device_authorization_endpoint = built_client
847            .optional_endpoints
848            .device_authorization_endpoint
849            .ok_or_else(|| OidcError::DeviceAuthorization {
850                message: "Device authorization endpoint not set or config error: device \
851                          authorization endpoint URL is not set"
852                    .to_string(),
853            })?;
854
855        Ok(built_client
856            .client
857            .set_device_authorization_url(device_authorization_endpoint))
858    }
859
860    async fn fresh_revocation_client(&self) -> OidcResult<RevocationClientWithExtra> {
861        let built_client = self.fresh_client_parts().await?;
862        let revocation_endpoint = built_client
863            .optional_endpoints
864            .revocation_endpoint
865            .ok_or_else(|| OidcError::TokenRevocation {
866                message: "Revocation endpoint not set or config error: revocation endpoint URL is \
867                          not set"
868                    .to_string(),
869            })?;
870
871        Ok(built_client.client.set_revocation_url(revocation_endpoint))
872    }
873
874    async fn fresh_client_with_redirect_override(
875        &self,
876        external_base_url: &Url,
877        redirect_url_override: Option<&str>,
878    ) -> OidcResult<DiscoveredClientWithExtra> {
879        let redirect_url = self.resolve_redirect_url(external_base_url, redirect_url_override)?;
880        Ok(self
881            .fresh_client()
882            .await?
883            .set_redirect_uri(RedirectUrl::from_url(redirect_url)))
884    }
885
886    pub fn authorize_url(
887        &self,
888        external_base_url: &Url,
889    ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
890        self.authorize_url_with_redirect_override(external_base_url, None)
891    }
892
893    pub fn authorize_url_with_redirect_override(
894        &self,
895        external_base_url: &Url,
896        redirect_url_override: Option<&str>,
897    ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
898        let client =
899            self.client_with_redirect_override(external_base_url, redirect_url_override)?;
900
901        let mut req = client.authorize_url(
902            AuthenticationFlow::<openidconnect::core::CoreResponseType>::AuthorizationCode,
903            CsrfToken::new_random,
904            Nonce::new_random,
905        );
906
907        let pkce_verifier_secret = if self.pkce_enabled {
908            let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
909            req = req.set_pkce_challenge(pkce_challenge);
910            Some(pkce_verifier.into_secret())
911        } else {
912            None
913        };
914
915        for scope in &self.scopes {
916            req = req.add_scope(Scope::new(scope.clone()));
917        }
918
919        let (authorization_url, csrf_token, nonce) = req.url();
920        Ok(OidcCodeFlowAuthorizationRequest {
921            authorization_url,
922            csrf_token,
923            nonce,
924            pkce_verifier_secret,
925        })
926    }
927
928    pub async fn exchange_code(
929        &self,
930        external_base_url: &Url,
931        code: &str,
932        nonce: &Nonce,
933        pkce_verifier_secret: Option<&str>,
934    ) -> OidcResult<OidcCodeExchangeResult> {
935        self.exchange_code_with_redirect_override(
936            external_base_url,
937            code,
938            nonce,
939            pkce_verifier_secret,
940            None,
941        )
942        .await
943    }
944
945    pub async fn exchange_code_with_redirect_override(
946        &self,
947        external_base_url: &Url,
948        code: &str,
949        nonce: &Nonce,
950        pkce_verifier_secret: Option<&str>,
951        redirect_url_override: Option<&str>,
952    ) -> OidcResult<OidcCodeExchangeResult> {
953        let client = self
954            .fresh_client_with_redirect_override(external_base_url, redirect_url_override)
955            .await?;
956
957        let mut token_request = client
958            .exchange_code(AuthorizationCode::new(code.to_string()))
959            .map_err(|e| OidcError::TokenExchange {
960                message: format!("Token endpoint not set or config error: {e}"),
961            })?;
962
963        if let Some(secret) = pkce_verifier_secret {
964            token_request =
965                token_request.set_pkce_verifier(PkceCodeVerifier::new(secret.to_string()));
966        }
967
968        let token_response = token_request
969            .request_async(self.provider.http_client())
970            .await
971            .map_err(|e| OidcError::TokenExchange {
972                message: format!("Token exchange request failed: {e}"),
973            })?;
974
975        let id_token_verifier = client.id_token_verifier();
976        let id_token =
977            token_response
978                .extra_fields()
979                .id_token()
980                .ok_or_else(|| OidcError::TokenExchange {
981                    message: "Missing ID token in token response".to_string(),
982                })?;
983
984        let id_token_claims =
985            id_token
986                .claims(&id_token_verifier, nonce)
987                .map_err(|e| OidcError::TokenExchange {
988                    message: format!("Failed to verify ID token: {e}"),
989                })?;
990
991        let now = Utc::now();
992        let id_token = id_token.to_string();
993        let access_token = token_response.access_token().secret().clone();
994        let access_token_expiration = token_response
995            .expires_in()
996            .map(|expires_in| now + expires_in);
997        let refresh_token = token_response
998            .refresh_token()
999            .map(|value| value.secret().clone());
1000
1001        let user_info_claims = if client.user_info_url().is_some() {
1002            Some(
1003                self.request_userinfo(
1004                    &client,
1005                    self.provider.http_client(),
1006                    token_response.access_token().clone(),
1007                    Some(id_token_claims.subject().clone()),
1008                )
1009                .await?,
1010            )
1011        } else {
1012            None
1013        };
1014
1015        // Validate required scopes after successful exchange.
1016        self.check_required_scopes(token_response.scopes())?;
1017
1018        Ok(OidcCodeExchangeResult {
1019            id_token,
1020            id_token_claims: id_token_claims.to_owned(),
1021            refresh_token,
1022            access_token,
1023            access_token_expiration,
1024            user_info_claims,
1025        })
1026    }
1027
1028    /// Verify that the token response's `scope` field covers all
1029    /// `required_scopes` configured for this client.
1030    ///
1031    /// A `None` scope field in the response is treated as "unknown" and
1032    /// the check is skipped (the provider chose not to echo back the scope).
1033    /// Returns `Err(OidcError::ScopeValidation)` listing the missing scopes
1034    /// when the check fails.
1035    fn check_required_scopes(
1036        &self,
1037        response_scopes: Option<&Vec<openidconnect::Scope>>,
1038    ) -> OidcResult<()> {
1039        if self.config.required_scopes.is_empty() {
1040            return Ok(());
1041        }
1042        let granted = match response_scopes {
1043            Some(scopes) => scopes,
1044            // Provider omitted the scope field — skip check per RFC 6749 §5.1.
1045            None => return Ok(()),
1046        };
1047        let granted_strs: Vec<&str> = granted.iter().map(|s| s.as_str()).collect();
1048        let missing: Vec<String> = self
1049            .config
1050            .required_scopes
1051            .iter()
1052            .filter(|req| !granted_strs.contains(&req.as_str()))
1053            .cloned()
1054            .collect();
1055        if missing.is_empty() {
1056            Ok(())
1057        } else {
1058            Err(OidcError::ScopeValidation { missing })
1059        }
1060    }
1061
1062    fn with_runtime_flags(mut self) -> Self {
1063        self.scopes = self.config.scopes.clone();
1064        self.pkce_enabled = self.config.pkce_enabled;
1065        self
1066    }
1067
1068    async fn request_device_token_once(
1069        &self,
1070        device_authorization: &OidcDeviceAuthorizationResult,
1071    ) -> OidcResult<DeviceTokenPollResponse> {
1072        let client = self.fresh_client().await?;
1073        let token_url = client
1074            .token_uri()
1075            .cloned()
1076            .ok_or_else(|| OidcError::DeviceTokenPoll {
1077                message: "Token endpoint not set for device token polling".to_string(),
1078            })?;
1079
1080        let auth_type = self.resolve_token_endpoint_auth_type().await?;
1081        let mut params = vec![
1082            (
1083                Cow::Borrowed("grant_type"),
1084                Cow::Borrowed("urn:ietf:params:oauth:grant-type:device_code"),
1085            ),
1086            (
1087                Cow::Borrowed("device_code"),
1088                Cow::Owned(device_authorization.device_code.clone()),
1089            ),
1090        ];
1091
1092        if matches!(auth_type, AuthType::RequestBody) {
1093            params.push((
1094                Cow::Borrowed("client_id"),
1095                Cow::Owned(self.config.client_id.clone()),
1096            ));
1097            if let Some(client_secret) = self.config.client_secret.as_ref() {
1098                params.push((
1099                    Cow::Borrowed("client_secret"),
1100                    Cow::Owned(client_secret.clone()),
1101                ));
1102            }
1103        }
1104
1105        let mut request = self
1106            .provider
1107            .http_client()
1108            .post(token_url.url().clone())
1109            .header(reqwest::header::ACCEPT, "application/json")
1110            .form(&params);
1111
1112        if matches!(auth_type, AuthType::BasicAuth) {
1113            let client_secret =
1114                self.config
1115                    .client_secret
1116                    .as_ref()
1117                    .ok_or_else(|| OidcError::DeviceTokenPoll {
1118                        message: "client_secret is required for basic token endpoint auth"
1119                            .to_string(),
1120                    })?;
1121            let credentials = format!(
1122                "{}:{}",
1123                form_urlencode(&self.config.client_id),
1124                form_urlencode(client_secret)
1125            );
1126            let header_value = format!(
1127                "Basic {}",
1128                base64::engine::general_purpose::STANDARD.encode(credentials)
1129            );
1130            request = request.header(reqwest::header::AUTHORIZATION, header_value);
1131        }
1132
1133        let response = request
1134            .send()
1135            .await
1136            .map_err(|e| OidcError::DeviceTokenPoll {
1137                message: format!("Device token poll request failed: {e}"),
1138            })?;
1139        let status = response.status();
1140        let body = response
1141            .bytes()
1142            .await
1143            .map_err(|e| OidcError::DeviceTokenPoll {
1144                message: format!("Failed to read device token poll response: {e}"),
1145            })?;
1146
1147        if status.is_success() {
1148            let token_response =
1149                serde_json::from_slice::<TokenResponseWithExtra>(&body).map_err(|e| {
1150                    OidcError::DeviceTokenPoll {
1151                        message: format!(
1152                            "Failed to parse device token response: {e}; body: {}",
1153                            String::from_utf8_lossy(&body)
1154                        ),
1155                    }
1156                })?;
1157            return Ok(DeviceTokenPollResponse::Complete(Box::new(token_response)));
1158        }
1159
1160        let error_response =
1161            serde_json::from_slice::<DeviceCodeErrorResponse>(&body).map_err(|e| {
1162                OidcError::DeviceTokenPoll {
1163                    message: format!(
1164                        "Device token poll failed with HTTP {} and an unparseable body: {e}; \
1165                         body: {}",
1166                        status,
1167                        String::from_utf8_lossy(&body)
1168                    ),
1169                }
1170            })?;
1171
1172        match error_response.error() {
1173            DeviceCodeErrorResponseType::AuthorizationPending => {
1174                Ok(DeviceTokenPollResponse::Pending)
1175            }
1176            DeviceCodeErrorResponseType::SlowDown => Ok(DeviceTokenPollResponse::SlowDown),
1177            DeviceCodeErrorResponseType::AccessDenied => Ok(DeviceTokenPollResponse::Denied {
1178                error_description: error_response.error_description().cloned(),
1179            }),
1180            DeviceCodeErrorResponseType::ExpiredToken => Ok(DeviceTokenPollResponse::Expired {
1181                error_description: error_response.error_description().cloned(),
1182            }),
1183            other => Err(OidcError::DeviceTokenPoll {
1184                message: format!("Device token poll returned terminal error: {other}"),
1185            }),
1186        }
1187    }
1188
1189    async fn build_device_token_result(
1190        &self,
1191        token_response: TokenResponseWithExtra,
1192    ) -> OidcResult<OidcDeviceTokenResult> {
1193        let client = self.fresh_client().await?;
1194        let id_token_verifier = client.id_token_verifier();
1195        let id_token =
1196            token_response
1197                .extra_fields()
1198                .id_token()
1199                .ok_or_else(|| OidcError::DeviceTokenPoll {
1200                    message: "Missing ID token in device token response".to_string(),
1201                })?;
1202        let id_token_claims = id_token
1203            .claims(&id_token_verifier, |_nonce: Option<&Nonce>| Ok(()))
1204            .map_err(|e| OidcError::DeviceTokenPoll {
1205                message: format!("Failed to verify device-flow ID token: {e}"),
1206            })?;
1207
1208        let now = Utc::now();
1209        let access_token = token_response.access_token().secret().clone();
1210        let access_token_expiration = token_response
1211            .expires_in()
1212            .map(|expires_in| now + expires_in);
1213        let refresh_token = token_response
1214            .refresh_token()
1215            .map(|value| value.secret().clone());
1216
1217        let user_info_claims = if client.user_info_url().is_some() {
1218            Some(
1219                self.request_userinfo(
1220                    &client,
1221                    self.provider.http_client(),
1222                    token_response.access_token().clone(),
1223                    Some(id_token_claims.subject().clone()),
1224                )
1225                .await?,
1226            )
1227        } else {
1228            None
1229        };
1230        let claims_check_result = self
1231            .check_claims(id_token_claims, user_info_claims.as_ref())
1232            .await?;
1233
1234        Ok(OidcDeviceTokenResult {
1235            access_token,
1236            access_token_expiration,
1237            id_token: id_token.to_string(),
1238            refresh_token,
1239            id_token_claims: id_token_claims.to_owned(),
1240            user_info_claims,
1241            claims_check_result,
1242        })
1243    }
1244
1245    async fn resolve_token_endpoint_auth_type(&self) -> OidcResult<AuthType> {
1246        let metadata = self.provider.oidc_provider_metadata().await?;
1247        let supported = metadata.token_endpoint_auth_methods_supported();
1248
1249        if self.config.client_secret.is_none() {
1250            return Ok(AuthType::RequestBody);
1251        }
1252
1253        let supports_basic = supported
1254            .is_none_or(|methods| methods.contains(&CoreClientAuthMethod::ClientSecretBasic));
1255        if supports_basic {
1256            return Ok(AuthType::BasicAuth);
1257        }
1258
1259        let supports_request_body = supported.is_some_and(|methods| {
1260            methods.contains(&CoreClientAuthMethod::ClientSecretPost)
1261                || methods.contains(&CoreClientAuthMethod::None)
1262        });
1263        if supports_request_body {
1264            return Ok(AuthType::RequestBody);
1265        }
1266
1267        Err(OidcError::DeviceTokenPoll {
1268            message: "The provider only advertises unsupported token endpoint auth methods for \
1269                      device polling"
1270                .to_string(),
1271        })
1272    }
1273}
1274
/// Outcome of a single device access token request, as classified by
/// `request_device_token_once`.
enum DeviceTokenPollResponse {
    /// Provider answered `authorization_pending`: the user has not finished yet.
    Pending,
    /// Provider answered `slow_down`: keep polling, but with a longer interval
    /// (RFC 8628 §3.5).
    SlowDown,
    /// Provider answered `access_denied`, with its optional `error_description`.
    Denied { error_description: Option<String> },
    /// Provider answered `expired_token`, with its optional `error_description`.
    Expired { error_description: Option<String> },
    // The token response is boxed so this large variant does not inflate the size
    // of every `DeviceTokenPollResponse` value (clippy::large_enum_variant). The
    // other data-carrying arms still hold an `Option<String>`, so the enum is
    // larger than a single pointer.
    Complete(Box<TokenResponseWithExtra>),
}
1284
1285fn form_urlencode(value: &str) -> String {
1286    url::form_urlencoded::byte_serialize(value.as_bytes()).collect()
1287}
1288
/// Builds the message reported when device token polling stops on a terminal error.
///
/// The provider's `error_description`, when present, is appended after `": "`.
fn format_device_token_terminal_message(
    error_code: &str,
    error_description: Option<&str>,
) -> String {
    let mut message = format!("Device token polling stopped with {error_code}");
    if let Some(description) = error_description {
        message.push_str(": ");
        message.push_str(description);
    }
    message
}
1300
1301fn build_client(
1302    config: &OidcClientConfig<impl PendingOauthStoreConfig>,
1303    metadata: ProviderMetadataWithExtra,
1304) -> Result<BuiltClientWithExtra, String> {
1305    let client_id = ClientId::new(config.client_id.clone());
1306    let client_secret = config
1307        .client_secret
1308        .as_ref()
1309        .map(|value| ClientSecret::new(value.clone()));
1310
1311    let introspection_endpoint = metadata
1312        .additional_metadata()
1313        .introspection_endpoint
1314        .as_ref()
1315        .map(|value| IntrospectionUrl::new(value.clone()))
1316        .transpose()
1317        .map_err(|e| format!("Invalid introspection_endpoint: {e}"))?;
1318    let revocation_endpoint = metadata
1319        .additional_metadata()
1320        .revocation_endpoint
1321        .as_ref()
1322        .map(|value| RevocationUrl::new(value.clone()))
1323        .transpose()
1324        .map_err(|e| format!("Invalid revocation_endpoint: {e}"))?;
1325    let device_authorization_endpoint = metadata
1326        .additional_metadata()
1327        .device_authorization_endpoint
1328        .as_ref()
1329        .map(|value| DeviceAuthorizationUrl::new(value.clone()))
1330        .transpose()
1331        .map_err(|e| format!("Invalid device_authorization_endpoint: {e}"))?;
1332
1333    Ok(BuiltClientWithExtra {
1334        client: ClientWithExtra::from_provider_metadata(metadata, client_id, client_secret),
1335        optional_endpoints: OptionalClientEndpoints {
1336            _introspection_endpoint: introspection_endpoint,
1337            revocation_endpoint,
1338            device_authorization_endpoint,
1339        },
1340    })
1341}