Skip to main content

rustauth_plugins/generic_oauth/
provider.rs

1use rustauth_oauth::oauth2::{
2    create_authorization_code_request, create_authorization_url,
3    create_refresh_access_token_request, exchange_authorization_code, refresh_access_token_at,
4    AuthorizationCodeRequest, AuthorizationUrlRequest, OAuth2Tokens, OAuth2UserInfo, OAuthError,
5    OAuthFormRequest, OAuthHttpClient, ProviderOptions, RefreshAccessTokenRequest,
6    SocialAuthorizationCodeRequest, SocialAuthorizationUrlRequest, SocialIdTokenRequest,
7    SocialOAuthProvider, SocialProviderFuture,
8};
9use url::Url;
10
11use super::config::{GenericOAuthConfig, GenericOAuthTokenRequest};
12use super::discovery::{resolve_http_client, DiscoveryCache};
13use super::user_info;
14
15/// Social provider implementation used by the generic OAuth plugin.
16///
17/// `SocialOAuthProvider::create_authorization_url` is synchronous, so providers that only
18/// define `discovery_url` cannot resolve their authorization endpoint through this trait method.
19/// Use the plugin routes (`/sign-in/oauth2`, `/oauth2/callback/:providerId`, `/oauth2/link`) as
20/// the canonical flow for discovery-only generic providers.
21#[derive(Debug, Clone)]
22pub struct GenericOAuthProvider {
23    config: GenericOAuthConfig,
24    discovery_cache: Option<DiscoveryCache>,
25    http_client: OAuthHttpClient,
26}
27
28impl GenericOAuthProvider {
29    pub fn new(config: GenericOAuthConfig) -> Self {
30        let http_client = resolve_http_client(&config);
31        Self {
32            config,
33            discovery_cache: None,
34            http_client,
35        }
36    }
37
38    pub(crate) fn with_discovery_cache(
39        config: GenericOAuthConfig,
40        discovery_cache: DiscoveryCache,
41    ) -> Self {
42        let http_client = resolve_http_client(&config);
43        Self {
44            config,
45            discovery_cache: Some(discovery_cache),
46            http_client,
47        }
48    }
49
50    pub fn config(&self) -> &GenericOAuthConfig {
51        &self.config
52    }
53
54    pub fn authorization_code_request(
55        &self,
56        input: SocialAuthorizationCodeRequest,
57    ) -> Result<OAuthFormRequest, OAuthError> {
58        create_authorization_code_request(self.authorization_code_input(input)?)
59    }
60
61    pub fn refresh_access_token_request(
62        &self,
63        refresh_token: impl Into<String>,
64    ) -> Result<OAuthFormRequest, OAuthError> {
65        create_refresh_access_token_request(RefreshAccessTokenRequest {
66            refresh_token: refresh_token.into(),
67            options: self.config.provider_options(),
68            authentication: self.config.authentication,
69            extra_params: self.config.token_url_params.clone(),
70            ..RefreshAccessTokenRequest::default()
71        })
72    }
73
74    fn resolve_code_verifier(
75        &self,
76        code_verifier: Option<String>,
77    ) -> Result<Option<String>, OAuthError> {
78        if !self.config.pkce {
79            return Ok(None);
80        }
81        code_verifier
82            .ok_or(OAuthError::MissingOption("code_verifier"))
83            .map(Some)
84    }
85
86    fn authorization_code_input(
87        &self,
88        input: SocialAuthorizationCodeRequest,
89    ) -> Result<AuthorizationCodeRequest, OAuthError> {
90        Ok(AuthorizationCodeRequest {
91            code: input.code,
92            redirect_uri: input.redirect_uri,
93            options: self.config.provider_options(),
94            code_verifier: self.resolve_code_verifier(input.code_verifier)?,
95            device_id: input.device_id,
96            authentication: self.config.authentication,
97            headers: super::discovery::headers(&self.config.authorization_headers),
98            additional_params: self.config.token_url_params.clone(),
99            ..AuthorizationCodeRequest::default()
100        })
101    }
102
103    async fn token_endpoint(&self) -> Result<String, OAuthError> {
104        if let Some(token_url) = &self.config.token_url {
105            return Ok(token_url.clone());
106        }
107        let Some(discovery_cache) = &self.discovery_cache else {
108            return Err(OAuthError::InvalidResponse(
109                "Invalid OAuth configuration. Token URL not found.".to_owned(),
110            ));
111        };
112        let discovery = discovery_cache
113            .fetch(&self.config, &self.http_client)
114            .await
115            .map_err(|error| OAuthError::InvalidResponse(error.to_string()))?
116            .ok_or_else(|| {
117                OAuthError::InvalidResponse(
118                    "Invalid OAuth configuration. Token URL not found.".to_owned(),
119                )
120            })?;
121        discovery.token_endpoint.ok_or_else(|| {
122            OAuthError::InvalidResponse(
123                "Invalid OAuth configuration. Token URL not found.".to_owned(),
124            )
125        })
126    }
127}
128
129impl SocialOAuthProvider for GenericOAuthProvider {
130    fn id(&self) -> &str {
131        &self.config.provider_id
132    }
133
134    fn name(&self) -> &str {
135        &self.config.provider_id
136    }
137
138    fn provider_options(&self) -> ProviderOptions {
139        self.config.provider_options()
140    }
141
142    fn create_authorization_url(
143        &self,
144        input: SocialAuthorizationUrlRequest,
145    ) -> Result<Url, OAuthError> {
146        let Some(authorization_endpoint) = self.config.authorization_url.clone() else {
147            return Err(OAuthError::InvalidResponse(
148                "Invalid OAuth configuration".to_owned(),
149            ));
150        };
151        create_authorization_url(AuthorizationUrlRequest {
152            id: self.config.provider_id.clone(),
153            options: self.config.provider_options(),
154            authorization_endpoint,
155            redirect_uri: input.redirect_uri,
156            state: input.state,
157            code_verifier: self.resolve_code_verifier(input.code_verifier)?,
158            scopes: self.config.scopes(input.scopes),
159            prompt: self.config.prompt.clone(),
160            access_type: self.config.access_type.clone(),
161            response_type: self.config.response_type.clone(),
162            response_mode: self.config.response_mode.clone(),
163            login_hint: input.login_hint,
164            additional_params: self.config.authorization_url_params.clone(),
165            ..AuthorizationUrlRequest::default()
166        })
167    }
168
169    fn validate_authorization_code(
170        &self,
171        input: SocialAuthorizationCodeRequest,
172    ) -> SocialProviderFuture<'_, OAuth2Tokens> {
173        Box::pin(async move {
174            if let Some(get_token) = &self.config.get_token {
175                return get_token(GenericOAuthTokenRequest {
176                    code: input.code,
177                    redirect_uri: self
178                        .config
179                        .redirect_uri
180                        .clone()
181                        .unwrap_or(input.redirect_uri),
182                    code_verifier: self.resolve_code_verifier(input.code_verifier)?,
183                    device_id: input.device_id,
184                })
185                .await;
186            }
187            let token_endpoint = self.token_endpoint().await?;
188            exchange_authorization_code(
189                &token_endpoint,
190                self.authorization_code_input(input)?,
191                &self.http_client,
192            )
193            .await
194        })
195    }
196
197    fn get_user_info(
198        &self,
199        tokens: OAuth2Tokens,
200        _provider_user: Option<serde_json::Value>,
201    ) -> SocialProviderFuture<'_, Option<OAuth2UserInfo>> {
202        Box::pin(async move {
203            let user = if let Some(get_user_info) = &self.config.get_user_info {
204                get_user_info(tokens).await?
205            } else {
206                user_info::get_user_info(
207                    &tokens,
208                    self.config.user_info_url.as_deref(),
209                    &self.http_client,
210                )
211                .await?
212            };
213            if let Some(map_profile) = &self.config.map_profile_to_user {
214                if let Some(user) = user {
215                    return map_profile(user).await.map(Some);
216                }
217                return Ok(None);
218            }
219            Ok(user)
220        })
221    }
222
223    fn verify_id_token(&self, input: SocialIdTokenRequest) -> SocialProviderFuture<'_, bool> {
224        Box::pin(async move {
225            if let Some(verify_id_token) = &self.config.verify_id_token {
226                return verify_id_token(input).await;
227            }
228            Ok(false)
229        })
230    }
231
232    fn refresh_access_token(
233        &self,
234        refresh_token_value: String,
235    ) -> SocialProviderFuture<'_, OAuth2Tokens> {
236        Box::pin(async move {
237            if let Some(refresh_access_token) = &self.config.refresh_access_token {
238                return refresh_access_token(refresh_token_value).await;
239            }
240            let token_endpoint = self.token_endpoint().await?;
241            refresh_access_token_at(
242                &token_endpoint,
243                RefreshAccessTokenRequest {
244                    refresh_token: refresh_token_value,
245                    options: self.config.provider_options(),
246                    authentication: self.config.authentication,
247                    extra_params: self.config.token_url_params.clone(),
248                    ..RefreshAccessTokenRequest::default()
249                },
250                &self.http_client,
251            )
252            .await
253        })
254    }
255
256    fn revoke_token(&self, token: String) -> SocialProviderFuture<'_, ()> {
257        Box::pin(async move {
258            if let Some(revoke_token) = &self.config.revoke_token {
259                return revoke_token(token).await;
260            }
261            Err(OAuthError::InvalidResponse(format!(
262                "provider does not support token revocation for token `{token}`"
263            )))
264        })
265    }
266}