Skip to main content

rustauth_oauth/oauth2/
client.rs

1use std::collections::BTreeMap;
2
3use url::Url;
4
5use super::authorization_url::{create_authorization_url, AuthorizationUrlRequest};
6use super::client_credentials_token::{
7    create_client_credentials_token_request, ClientCredentialsTokenRequest,
8};
9use super::error::OAuthError;
10use super::http::{default_http_client, OAuthHttpClient, OAuthHttpClientConfig};
11use super::refresh_access_token::{create_refresh_access_token_request, RefreshAccessTokenRequest};
12use super::request::{post_form_with_client, ClientAuthentication, OAuthFormRequest};
13use super::tokens::{get_oauth2_tokens, get_primary_client_id, OAuth2Tokens, ProviderOptions};
14use super::types::{AuthorizationEndpoint, TokenEndpoint};
15use super::validate_authorization_code::{
16    create_authorization_code_request, AuthorizationCodeRequest,
17};
18
19/// Configured OAuth 2.0 client for a single provider (fixed authorization and token endpoints).
20#[derive(Debug, Clone)]
21pub struct OAuth2Client {
22    id: String,
23    authorization_endpoint: AuthorizationEndpoint,
24    token_endpoint: TokenEndpoint,
25    options: ProviderOptions,
26    default_scopes: Vec<String>,
27    scope_joiner: String,
28    authentication: ClientAuthentication,
29    http: OAuthHttpClient,
30}
31
32/// Builder for [`OAuth2Client`]. Validates endpoints and `client_id` at [`OAuth2ClientBuilder::build`].
33#[must_use = "OAuth2ClientBuilder must be built to produce a client"]
34pub struct OAuth2ClientBuilder {
35    id: String,
36    options: ProviderOptions,
37    authorization_endpoint: Option<AuthorizationEndpoint>,
38    token_endpoint: Option<TokenEndpoint>,
39    default_scopes: Vec<String>,
40    scope_joiner: String,
41    authentication: ClientAuthentication,
42    http: Option<OAuthHttpClient>,
43}
44
45impl OAuth2Client {
46    pub fn builder(
47        provider_id: impl Into<String>,
48        options: ProviderOptions,
49    ) -> OAuth2ClientBuilder {
50        OAuth2ClientBuilder {
51            id: provider_id.into(),
52            options,
53            authorization_endpoint: None,
54            token_endpoint: None,
55            default_scopes: Vec::new(),
56            scope_joiner: " ".to_owned(),
57            authentication: ClientAuthentication::Post,
58            http: None,
59        }
60    }
61
62    pub fn id(&self) -> &str {
63        &self.id
64    }
65
66    pub fn options(&self) -> &ProviderOptions {
67        &self.options
68    }
69
70    pub fn http(&self) -> &OAuthHttpClient {
71        &self.http
72    }
73
74    pub fn authorization_endpoint(&self) -> &AuthorizationEndpoint {
75        &self.authorization_endpoint
76    }
77
78    pub fn token_endpoint(&self) -> &TokenEndpoint {
79        &self.token_endpoint
80    }
81
82    pub fn authorization_url(
83        &self,
84        state: impl Into<String>,
85        redirect_uri: impl Into<String>,
86    ) -> Result<AuthorizationUrlBuilder<'_>, OAuthError> {
87        let state = state.into();
88        if state.is_empty() {
89            return Err(OAuthError::InvalidConfiguration(
90                "authorization state cannot be empty".to_owned(),
91            ));
92        }
93        let redirect_uri = redirect_uri.into();
94        url::Url::parse(
95            self.options
96                .redirect_uri
97                .as_deref()
98                .unwrap_or(&redirect_uri),
99        )?;
100        Ok(AuthorizationUrlBuilder {
101            client: self,
102            state,
103            redirect_uri,
104            code_verifier: None,
105            scopes: Vec::new(),
106            login_hint: None,
107            prompt: None,
108            access_type: None,
109            response_type: None,
110            response_mode: None,
111            display: None,
112            hd: None,
113            duration: None,
114            claims: Vec::new(),
115            additional_params: BTreeMap::new(),
116        })
117    }
118
119    pub fn exchange_code(
120        &self,
121        code: impl Into<String>,
122        redirect_uri: impl Into<String>,
123    ) -> Result<ExchangeCodeBuilder<'_>, OAuthError> {
124        Ok(ExchangeCodeBuilder {
125            client: self,
126            request: AuthorizationCodeRequest::try_new(code, redirect_uri, self.options.clone())?
127                .authentication(self.authentication),
128        })
129    }
130
131    pub fn refresh_token(
132        &self,
133        refresh_token: impl Into<String>,
134    ) -> Result<RefreshTokenBuilder<'_>, OAuthError> {
135        Ok(RefreshTokenBuilder {
136            client: self,
137            request: RefreshAccessTokenRequest::try_new(refresh_token, self.options.clone())?
138                .authentication(self.authentication),
139        })
140    }
141
142    pub fn client_credentials(&self) -> Result<ClientCredentialsBuilder<'_>, OAuthError> {
143        Ok(ClientCredentialsBuilder {
144            client: self,
145            request: ClientCredentialsTokenRequest::try_new(self.options.clone())?
146                .authentication(self.authentication),
147        })
148    }
149}
150
151impl OAuth2ClientBuilder {
152    pub fn authorization_endpoint(mut self, url: impl Into<String>) -> Result<Self, OAuthError> {
153        self.authorization_endpoint = Some(AuthorizationEndpoint::new(url)?);
154        Ok(self)
155    }
156
157    pub fn token_endpoint(mut self, url: impl Into<String>) -> Result<Self, OAuthError> {
158        self.token_endpoint = Some(TokenEndpoint::new(url)?);
159        Ok(self)
160    }
161
162    pub fn default_scope(mut self, scope: impl Into<String>) -> Self {
163        self.default_scopes.push(scope.into());
164        self
165    }
166
167    pub fn default_scopes(mut self, scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
168        self.default_scopes
169            .extend(scopes.into_iter().map(Into::into));
170        self
171    }
172
173    pub fn scope_joiner(mut self, joiner: impl Into<String>) -> Self {
174        self.scope_joiner = joiner.into();
175        self
176    }
177
178    pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
179        self.authentication = authentication;
180        self
181    }
182
183    pub fn http_client(mut self, http: OAuthHttpClient) -> Self {
184        self.http = Some(http);
185        self
186    }
187
188    pub fn http_config(mut self, config: OAuthHttpClientConfig) -> Result<Self, OAuthError> {
189        self.http = Some(OAuthHttpClient::from_config(config)?);
190        Ok(self)
191    }
192
193    pub fn build(self) -> Result<OAuth2Client, OAuthError> {
194        let authorization_endpoint = self
195            .authorization_endpoint
196            .ok_or(OAuthError::MissingOption("authorization_endpoint"))?;
197        let token_endpoint = self
198            .token_endpoint
199            .ok_or(OAuthError::MissingOption("token_endpoint"))?;
200        get_primary_client_id(&self.options.client_id)
201            .ok_or(OAuthError::MissingOption("client_id"))?;
202        let http = match self.http {
203            Some(http) => http,
204            None => default_http_client()?,
205        };
206        Ok(OAuth2Client {
207            id: self.id,
208            authorization_endpoint,
209            token_endpoint,
210            options: self.options,
211            default_scopes: self.default_scopes,
212            scope_joiner: self.scope_joiner,
213            authentication: self.authentication,
214            http,
215        })
216    }
217}
218
219/// Authorization URL builder returned by [`OAuth2Client::authorization_url`].
220#[must_use = "AuthorizationUrlBuilder must be built to produce a URL"]
221pub struct AuthorizationUrlBuilder<'a> {
222    client: &'a OAuth2Client,
223    state: String,
224    redirect_uri: String,
225    code_verifier: Option<String>,
226    scopes: Vec<String>,
227    login_hint: Option<String>,
228    prompt: Option<String>,
229    access_type: Option<String>,
230    response_type: Option<String>,
231    response_mode: Option<String>,
232    display: Option<String>,
233    hd: Option<String>,
234    duration: Option<String>,
235    claims: Vec<String>,
236    additional_params: BTreeMap<String, String>,
237}
238
239impl AuthorizationUrlBuilder<'_> {
240    pub fn code_verifier(mut self, code_verifier: impl Into<String>) -> Self {
241        self.code_verifier = Some(code_verifier.into());
242        self
243    }
244
245    pub fn scope(mut self, scope: impl Into<String>) -> Self {
246        self.scopes.push(scope.into());
247        self
248    }
249
250    pub fn scopes(mut self, scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
251        self.scopes.extend(scopes.into_iter().map(Into::into));
252        self
253    }
254
255    pub fn login_hint(mut self, login_hint: impl Into<String>) -> Self {
256        self.login_hint = Some(login_hint.into());
257        self
258    }
259
260    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
261        self.prompt = Some(prompt.into());
262        self
263    }
264
265    pub fn access_type(mut self, access_type: impl Into<String>) -> Self {
266        self.access_type = Some(access_type.into());
267        self
268    }
269
270    pub fn response_type(mut self, response_type: impl Into<String>) -> Self {
271        self.response_type = Some(response_type.into());
272        self
273    }
274
275    pub fn response_mode(mut self, response_mode: impl Into<String>) -> Self {
276        self.response_mode = Some(response_mode.into());
277        self
278    }
279
280    pub fn claim(mut self, claim: impl Into<String>) -> Self {
281        self.claims.push(claim.into());
282        self
283    }
284
285    pub fn duration(mut self, duration: impl Into<String>) -> Self {
286        self.duration = Some(duration.into());
287        self
288    }
289
290    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
291        self.additional_params.insert(key.into(), value.into());
292        self
293    }
294
295    pub fn build(self) -> Result<Url, OAuthError> {
296        let mut scopes = if !self.client.options.disable_default_scope {
297            self.client.default_scopes.clone()
298        } else {
299            Vec::new()
300        };
301        scopes.extend(self.client.options.scope.iter().cloned());
302        scopes.extend(self.scopes);
303
304        create_authorization_url(AuthorizationUrlRequest {
305            id: self.client.id.clone(),
306            options: self.client.options.clone(),
307            authorization_endpoint: self.client.authorization_endpoint.as_str().to_owned(),
308            redirect_uri: self.redirect_uri,
309            state: self.state,
310            code_verifier: self.code_verifier,
311            scopes,
312            login_hint: self.login_hint,
313            prompt: self.prompt.or_else(|| self.client.options.prompt.clone()),
314            access_type: self.access_type,
315            response_type: self.response_type,
316            response_mode: self
317                .response_mode
318                .or_else(|| self.client.options.response_mode.clone()),
319            display: self.display,
320            hd: self.hd,
321            duration: self.duration,
322            claims: self.claims,
323            additional_params: self.additional_params,
324            scope_joiner: self.client.scope_joiner.clone(),
325        })
326    }
327}
328
329/// Authorization-code exchange builder returned by [`OAuth2Client::exchange_code`].
330#[must_use = "ExchangeCodeBuilder must be sent or converted to a form request"]
331pub struct ExchangeCodeBuilder<'a> {
332    client: &'a OAuth2Client,
333    request: AuthorizationCodeRequest,
334}
335
336impl ExchangeCodeBuilder<'_> {
337    pub fn code_verifier(mut self, code_verifier: impl Into<String>) -> Self {
338        self.request = self.request.code_verifier(code_verifier);
339        self
340    }
341
342    pub fn device_id(mut self, device_id: impl Into<String>) -> Self {
343        self.request.device_id = Some(device_id.into());
344        self
345    }
346
347    pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
348        self.request = self.request.authentication(authentication);
349        self
350    }
351
352    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
353        self.request = self.request.header(key, value);
354        self
355    }
356
357    pub fn additional_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
358        self.request = self.request.additional_param(key, value);
359        self
360    }
361
362    pub fn override_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
363        self.request = self.request.override_param(key, value);
364        self
365    }
366
367    pub fn resource(mut self, resource: impl Into<String>) -> Self {
368        self.request = self.request.resource(resource);
369        self
370    }
371
372    pub fn into_form_request(self) -> Result<OAuthFormRequest, OAuthError> {
373        create_authorization_code_request(self.request)
374    }
375
376    pub async fn send(self) -> Result<OAuth2Tokens, OAuthError> {
377        exchange_authorization_code(
378            self.client.token_endpoint.as_str(),
379            self.request,
380            &self.client.http,
381        )
382        .await
383    }
384}
385
386/// Refresh-token builder returned by [`OAuth2Client::refresh_token`].
387#[must_use = "RefreshTokenBuilder must be sent or converted to a form request"]
388pub struct RefreshTokenBuilder<'a> {
389    client: &'a OAuth2Client,
390    request: RefreshAccessTokenRequest,
391}
392
393impl RefreshTokenBuilder<'_> {
394    pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
395        self.request = self.request.authentication(authentication);
396        self
397    }
398
399    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
400        self.request = self.request.header(key, value);
401        self
402    }
403
404    pub fn extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
405        self.request = self.request.extra_param(key, value);
406        self
407    }
408
409    pub fn resource(mut self, resource: impl Into<String>) -> Self {
410        self.request = self.request.resource(resource);
411        self
412    }
413
414    pub fn into_form_request(self) -> Result<OAuthFormRequest, OAuthError> {
415        create_refresh_access_token_request(self.request)
416    }
417
418    pub async fn send(self) -> Result<OAuth2Tokens, OAuthError> {
419        refresh_access_token_at(
420            self.client.token_endpoint.as_str(),
421            self.request,
422            &self.client.http,
423        )
424        .await
425    }
426}
427
428/// Client-credentials grant builder returned by [`OAuth2Client::client_credentials`].
429#[must_use = "ClientCredentialsBuilder must be sent or converted to a form request"]
430pub struct ClientCredentialsBuilder<'a> {
431    client: &'a OAuth2Client,
432    request: ClientCredentialsTokenRequest,
433}
434
435impl ClientCredentialsBuilder<'_> {
436    pub fn scope(mut self, scope: impl Into<String>) -> Self {
437        self.request = self.request.scope(scope);
438        self
439    }
440
441    pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
442        self.request = self.request.authentication(authentication);
443        self
444    }
445
446    pub fn resource(mut self, resource: impl Into<String>) -> Self {
447        self.request = self.request.resource(resource);
448        self
449    }
450
451    pub fn into_form_request(self) -> Result<OAuthFormRequest, OAuthError> {
452        create_client_credentials_token_request(self.request)
453    }
454
455    pub async fn send(self) -> Result<OAuth2Tokens, OAuthError> {
456        let request = create_client_credentials_token_request(self.request)?;
457        let data = post_form_with_client(
458            self.client.token_endpoint.as_str(),
459            request,
460            &self.client.http,
461        )
462        .await?;
463        get_oauth2_tokens(data)
464    }
465}
466
467/// Submits a prepared token form request (advanced / test flows).
468pub async fn submit_token_form(
469    token_endpoint: &str,
470    request: OAuthFormRequest,
471    client: &OAuthHttpClient,
472) -> Result<OAuth2Tokens, OAuthError> {
473    let data = post_form_with_client(token_endpoint, request, client).await?;
474    get_oauth2_tokens(data)
475}
476
477/// Exchanges an authorization code at a token endpoint (advanced / discovery-based flows).
478pub async fn exchange_authorization_code(
479    token_endpoint: &str,
480    request: AuthorizationCodeRequest,
481    client: &OAuthHttpClient,
482) -> Result<OAuth2Tokens, OAuthError> {
483    let form = create_authorization_code_request(request)?;
484    let data = post_form_with_client(token_endpoint, form, client).await?;
485    get_oauth2_tokens(data)
486}
487
488/// Refreshes an access token at a token endpoint (advanced / discovery-based flows).
489pub async fn refresh_access_token_at(
490    token_endpoint: &str,
491    request: RefreshAccessTokenRequest,
492    client: &OAuthHttpClient,
493) -> Result<OAuth2Tokens, OAuthError> {
494    let form = create_refresh_access_token_request(request)?;
495    let data = post_form_with_client(token_endpoint, form, client).await?;
496    get_oauth2_tokens(data)
497}