rustauth_plugins/generic_oauth/
provider.rs1use rustauth_oauth::oauth2::{
2 create_authorization_code_request, create_authorization_url,
3 create_refresh_access_token_request, exchange_authorization_code, refresh_access_token_at,
4 AuthorizationCodeRequest, AuthorizationUrlRequest, OAuth2Tokens, OAuth2UserInfo, OAuthError,
5 OAuthFormRequest, OAuthHttpClient, ProviderOptions, RefreshAccessTokenRequest,
6 SocialAuthorizationCodeRequest, SocialAuthorizationUrlRequest, SocialIdTokenRequest,
7 SocialOAuthProvider, SocialProviderFuture,
8};
9use url::Url;
10
11use super::config::{GenericOAuthConfig, GenericOAuthTokenRequest};
12use super::discovery::{resolve_http_client, DiscoveryCache};
13use super::user_info;
14
// NOTE(review): `Debug` is derived, so formatting a provider also formats its
// `config`; confirm `GenericOAuthConfig`'s `Debug` impl redacts any client
// secret before providers are logged.
/// OAuth 2.0 provider whose endpoints, scopes and optional hooks come from a
/// [`GenericOAuthConfig`] instead of being hard-coded per vendor.
#[derive(Debug, Clone)]
pub struct GenericOAuthProvider {
    /// Provider configuration: endpoints, scopes, PKCE flag and optional
    /// override hooks (`get_token`, `get_user_info`, `revoke_token`, ...).
    config: GenericOAuthConfig,
    /// Fallback source for the token endpoint when `config.token_url` is
    /// unset; `None` for providers built with [`GenericOAuthProvider::new`].
    discovery_cache: Option<DiscoveryCache>,
    /// Client used for discovery, code exchange, token refresh and user-info
    /// requests; resolved from `config` at construction time.
    http_client: OAuthHttpClient,
}
27
28impl GenericOAuthProvider {
29 pub fn new(config: GenericOAuthConfig) -> Self {
30 let http_client = resolve_http_client(&config);
31 Self {
32 config,
33 discovery_cache: None,
34 http_client,
35 }
36 }
37
38 pub(crate) fn with_discovery_cache(
39 config: GenericOAuthConfig,
40 discovery_cache: DiscoveryCache,
41 ) -> Self {
42 let http_client = resolve_http_client(&config);
43 Self {
44 config,
45 discovery_cache: Some(discovery_cache),
46 http_client,
47 }
48 }
49
50 pub fn config(&self) -> &GenericOAuthConfig {
51 &self.config
52 }
53
54 pub fn authorization_code_request(
55 &self,
56 input: SocialAuthorizationCodeRequest,
57 ) -> Result<OAuthFormRequest, OAuthError> {
58 create_authorization_code_request(self.authorization_code_input(input)?)
59 }
60
61 pub fn refresh_access_token_request(
62 &self,
63 refresh_token: impl Into<String>,
64 ) -> Result<OAuthFormRequest, OAuthError> {
65 create_refresh_access_token_request(RefreshAccessTokenRequest {
66 refresh_token: refresh_token.into(),
67 options: self.config.provider_options(),
68 authentication: self.config.authentication,
69 extra_params: self.config.token_url_params.clone(),
70 ..RefreshAccessTokenRequest::default()
71 })
72 }
73
74 fn resolve_code_verifier(
75 &self,
76 code_verifier: Option<String>,
77 ) -> Result<Option<String>, OAuthError> {
78 if !self.config.pkce {
79 return Ok(None);
80 }
81 code_verifier
82 .ok_or(OAuthError::MissingOption("code_verifier"))
83 .map(Some)
84 }
85
86 fn authorization_code_input(
87 &self,
88 input: SocialAuthorizationCodeRequest,
89 ) -> Result<AuthorizationCodeRequest, OAuthError> {
90 Ok(AuthorizationCodeRequest {
91 code: input.code,
92 redirect_uri: input.redirect_uri,
93 options: self.config.provider_options(),
94 code_verifier: self.resolve_code_verifier(input.code_verifier)?,
95 device_id: input.device_id,
96 authentication: self.config.authentication,
97 headers: super::discovery::headers(&self.config.authorization_headers),
98 additional_params: self.config.token_url_params.clone(),
99 ..AuthorizationCodeRequest::default()
100 })
101 }
102
103 async fn token_endpoint(&self) -> Result<String, OAuthError> {
104 if let Some(token_url) = &self.config.token_url {
105 return Ok(token_url.clone());
106 }
107 let Some(discovery_cache) = &self.discovery_cache else {
108 return Err(OAuthError::InvalidResponse(
109 "Invalid OAuth configuration. Token URL not found.".to_owned(),
110 ));
111 };
112 let discovery = discovery_cache
113 .fetch(&self.config, &self.http_client)
114 .await
115 .map_err(|error| OAuthError::InvalidResponse(error.to_string()))?
116 .ok_or_else(|| {
117 OAuthError::InvalidResponse(
118 "Invalid OAuth configuration. Token URL not found.".to_owned(),
119 )
120 })?;
121 discovery.token_endpoint.ok_or_else(|| {
122 OAuthError::InvalidResponse(
123 "Invalid OAuth configuration. Token URL not found.".to_owned(),
124 )
125 })
126 }
127}
128
129impl SocialOAuthProvider for GenericOAuthProvider {
130 fn id(&self) -> &str {
131 &self.config.provider_id
132 }
133
134 fn name(&self) -> &str {
135 &self.config.provider_id
136 }
137
138 fn provider_options(&self) -> ProviderOptions {
139 self.config.provider_options()
140 }
141
142 fn create_authorization_url(
143 &self,
144 input: SocialAuthorizationUrlRequest,
145 ) -> Result<Url, OAuthError> {
146 let Some(authorization_endpoint) = self.config.authorization_url.clone() else {
147 return Err(OAuthError::InvalidResponse(
148 "Invalid OAuth configuration".to_owned(),
149 ));
150 };
151 create_authorization_url(AuthorizationUrlRequest {
152 id: self.config.provider_id.clone(),
153 options: self.config.provider_options(),
154 authorization_endpoint,
155 redirect_uri: input.redirect_uri,
156 state: input.state,
157 code_verifier: self.resolve_code_verifier(input.code_verifier)?,
158 scopes: self.config.scopes(input.scopes),
159 prompt: self.config.prompt.clone(),
160 access_type: self.config.access_type.clone(),
161 response_type: self.config.response_type.clone(),
162 response_mode: self.config.response_mode.clone(),
163 login_hint: input.login_hint,
164 additional_params: self.config.authorization_url_params.clone(),
165 ..AuthorizationUrlRequest::default()
166 })
167 }
168
169 fn validate_authorization_code(
170 &self,
171 input: SocialAuthorizationCodeRequest,
172 ) -> SocialProviderFuture<'_, OAuth2Tokens> {
173 Box::pin(async move {
174 if let Some(get_token) = &self.config.get_token {
175 return get_token(GenericOAuthTokenRequest {
176 code: input.code,
177 redirect_uri: self
178 .config
179 .redirect_uri
180 .clone()
181 .unwrap_or(input.redirect_uri),
182 code_verifier: self.resolve_code_verifier(input.code_verifier)?,
183 device_id: input.device_id,
184 })
185 .await;
186 }
187 let token_endpoint = self.token_endpoint().await?;
188 exchange_authorization_code(
189 &token_endpoint,
190 self.authorization_code_input(input)?,
191 &self.http_client,
192 )
193 .await
194 })
195 }
196
197 fn get_user_info(
198 &self,
199 tokens: OAuth2Tokens,
200 _provider_user: Option<serde_json::Value>,
201 ) -> SocialProviderFuture<'_, Option<OAuth2UserInfo>> {
202 Box::pin(async move {
203 let user = if let Some(get_user_info) = &self.config.get_user_info {
204 get_user_info(tokens).await?
205 } else {
206 user_info::get_user_info(
207 &tokens,
208 self.config.user_info_url.as_deref(),
209 &self.http_client,
210 )
211 .await?
212 };
213 if let Some(map_profile) = &self.config.map_profile_to_user {
214 if let Some(user) = user {
215 return map_profile(user).await.map(Some);
216 }
217 return Ok(None);
218 }
219 Ok(user)
220 })
221 }
222
223 fn verify_id_token(&self, input: SocialIdTokenRequest) -> SocialProviderFuture<'_, bool> {
224 Box::pin(async move {
225 if let Some(verify_id_token) = &self.config.verify_id_token {
226 return verify_id_token(input).await;
227 }
228 Ok(false)
229 })
230 }
231
232 fn refresh_access_token(
233 &self,
234 refresh_token_value: String,
235 ) -> SocialProviderFuture<'_, OAuth2Tokens> {
236 Box::pin(async move {
237 if let Some(refresh_access_token) = &self.config.refresh_access_token {
238 return refresh_access_token(refresh_token_value).await;
239 }
240 let token_endpoint = self.token_endpoint().await?;
241 refresh_access_token_at(
242 &token_endpoint,
243 RefreshAccessTokenRequest {
244 refresh_token: refresh_token_value,
245 options: self.config.provider_options(),
246 authentication: self.config.authentication,
247 extra_params: self.config.token_url_params.clone(),
248 ..RefreshAccessTokenRequest::default()
249 },
250 &self.http_client,
251 )
252 .await
253 })
254 }
255
256 fn revoke_token(&self, token: String) -> SocialProviderFuture<'_, ()> {
257 Box::pin(async move {
258 if let Some(revoke_token) = &self.config.revoke_token {
259 return revoke_token(token).await;
260 }
261 Err(OAuthError::InvalidResponse(format!(
262 "provider does not support token revocation for token `{token}`"
263 )))
264 })
265 }
266}