1use std::collections::BTreeMap;
2
3use url::Url;
4
5use super::authorization_url::{create_authorization_url, AuthorizationUrlRequest};
6use super::client_credentials_token::{
7 create_client_credentials_token_request, ClientCredentialsTokenRequest,
8};
9use super::error::OAuthError;
10use super::http::{default_http_client, OAuthHttpClient, OAuthHttpClientConfig};
11use super::refresh_access_token::{create_refresh_access_token_request, RefreshAccessTokenRequest};
12use super::request::{post_form_with_client, ClientAuthentication, OAuthFormRequest};
13use super::tokens::{get_oauth2_tokens, get_primary_client_id, OAuth2Tokens, ProviderOptions};
14use super::types::{AuthorizationEndpoint, TokenEndpoint};
15use super::validate_authorization_code::{
16 create_authorization_code_request, AuthorizationCodeRequest,
17};
18
/// A configured OAuth 2.0 provider client.
///
/// Bundles the provider id, the authorization and token endpoints, the provider
/// options, default scopes and their joiner, the client authentication method,
/// and the HTTP client used to post token requests. Create one through
/// [`OAuth2Client::builder`].
#[derive(Debug, Clone)]
pub struct OAuth2Client {
    /// Provider identifier given to [`OAuth2Client::builder`].
    id: String,
    /// Endpoint used as the base of generated authorization URLs.
    authorization_endpoint: AuthorizationEndpoint,
    /// Endpoint that code-exchange, refresh, and client-credentials forms are posted to.
    token_endpoint: TokenEndpoint,
    /// Provider options; cloned into every request this client creates.
    options: ProviderOptions,
    /// Scopes placed first in authorization URLs unless `options.disable_default_scope` is set.
    default_scopes: Vec<String>,
    /// Separator handed to `create_authorization_url` for joining scopes.
    scope_joiner: String,
    /// Authentication method applied to each token request when it is created.
    authentication: ClientAuthentication,
    /// HTTP client used by the builders' `send` methods.
    http: OAuthHttpClient,
}
31
32#[must_use = "OAuth2ClientBuilder must be built to produce a client"]
34pub struct OAuth2ClientBuilder {
35 id: String,
36 options: ProviderOptions,
37 authorization_endpoint: Option<AuthorizationEndpoint>,
38 token_endpoint: Option<TokenEndpoint>,
39 default_scopes: Vec<String>,
40 scope_joiner: String,
41 authentication: ClientAuthentication,
42 http: Option<OAuthHttpClient>,
43}
44
45impl OAuth2Client {
46 pub fn builder(
47 provider_id: impl Into<String>,
48 options: ProviderOptions,
49 ) -> OAuth2ClientBuilder {
50 OAuth2ClientBuilder {
51 id: provider_id.into(),
52 options,
53 authorization_endpoint: None,
54 token_endpoint: None,
55 default_scopes: Vec::new(),
56 scope_joiner: " ".to_owned(),
57 authentication: ClientAuthentication::Post,
58 http: None,
59 }
60 }
61
62 pub fn id(&self) -> &str {
63 &self.id
64 }
65
66 pub fn options(&self) -> &ProviderOptions {
67 &self.options
68 }
69
70 pub fn http(&self) -> &OAuthHttpClient {
71 &self.http
72 }
73
74 pub fn authorization_endpoint(&self) -> &AuthorizationEndpoint {
75 &self.authorization_endpoint
76 }
77
78 pub fn token_endpoint(&self) -> &TokenEndpoint {
79 &self.token_endpoint
80 }
81
82 pub fn authorization_url(
83 &self,
84 state: impl Into<String>,
85 redirect_uri: impl Into<String>,
86 ) -> Result<AuthorizationUrlBuilder<'_>, OAuthError> {
87 let state = state.into();
88 if state.is_empty() {
89 return Err(OAuthError::InvalidConfiguration(
90 "authorization state cannot be empty".to_owned(),
91 ));
92 }
93 let redirect_uri = redirect_uri.into();
94 url::Url::parse(
95 self.options
96 .redirect_uri
97 .as_deref()
98 .unwrap_or(&redirect_uri),
99 )?;
100 Ok(AuthorizationUrlBuilder {
101 client: self,
102 state,
103 redirect_uri,
104 code_verifier: None,
105 scopes: Vec::new(),
106 login_hint: None,
107 prompt: None,
108 access_type: None,
109 response_type: None,
110 response_mode: None,
111 display: None,
112 hd: None,
113 duration: None,
114 claims: Vec::new(),
115 additional_params: BTreeMap::new(),
116 })
117 }
118
119 pub fn exchange_code(
120 &self,
121 code: impl Into<String>,
122 redirect_uri: impl Into<String>,
123 ) -> Result<ExchangeCodeBuilder<'_>, OAuthError> {
124 Ok(ExchangeCodeBuilder {
125 client: self,
126 request: AuthorizationCodeRequest::try_new(code, redirect_uri, self.options.clone())?
127 .authentication(self.authentication),
128 })
129 }
130
131 pub fn refresh_token(
132 &self,
133 refresh_token: impl Into<String>,
134 ) -> Result<RefreshTokenBuilder<'_>, OAuthError> {
135 Ok(RefreshTokenBuilder {
136 client: self,
137 request: RefreshAccessTokenRequest::try_new(refresh_token, self.options.clone())?
138 .authentication(self.authentication),
139 })
140 }
141
142 pub fn client_credentials(&self) -> Result<ClientCredentialsBuilder<'_>, OAuthError> {
143 Ok(ClientCredentialsBuilder {
144 client: self,
145 request: ClientCredentialsTokenRequest::try_new(self.options.clone())?
146 .authentication(self.authentication),
147 })
148 }
149}
150
151impl OAuth2ClientBuilder {
152 pub fn authorization_endpoint(mut self, url: impl Into<String>) -> Result<Self, OAuthError> {
153 self.authorization_endpoint = Some(AuthorizationEndpoint::new(url)?);
154 Ok(self)
155 }
156
157 pub fn token_endpoint(mut self, url: impl Into<String>) -> Result<Self, OAuthError> {
158 self.token_endpoint = Some(TokenEndpoint::new(url)?);
159 Ok(self)
160 }
161
162 pub fn default_scope(mut self, scope: impl Into<String>) -> Self {
163 self.default_scopes.push(scope.into());
164 self
165 }
166
167 pub fn default_scopes(mut self, scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
168 self.default_scopes
169 .extend(scopes.into_iter().map(Into::into));
170 self
171 }
172
173 pub fn scope_joiner(mut self, joiner: impl Into<String>) -> Self {
174 self.scope_joiner = joiner.into();
175 self
176 }
177
178 pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
179 self.authentication = authentication;
180 self
181 }
182
183 pub fn http_client(mut self, http: OAuthHttpClient) -> Self {
184 self.http = Some(http);
185 self
186 }
187
188 pub fn http_config(mut self, config: OAuthHttpClientConfig) -> Result<Self, OAuthError> {
189 self.http = Some(OAuthHttpClient::from_config(config)?);
190 Ok(self)
191 }
192
193 pub fn build(self) -> Result<OAuth2Client, OAuthError> {
194 let authorization_endpoint = self
195 .authorization_endpoint
196 .ok_or(OAuthError::MissingOption("authorization_endpoint"))?;
197 let token_endpoint = self
198 .token_endpoint
199 .ok_or(OAuthError::MissingOption("token_endpoint"))?;
200 get_primary_client_id(&self.options.client_id)
201 .ok_or(OAuthError::MissingOption("client_id"))?;
202 let http = match self.http {
203 Some(http) => http,
204 None => default_http_client()?,
205 };
206 Ok(OAuth2Client {
207 id: self.id,
208 authorization_endpoint,
209 token_endpoint,
210 options: self.options,
211 default_scopes: self.default_scopes,
212 scope_joiner: self.scope_joiner,
213 authentication: self.authentication,
214 http,
215 })
216 }
217}
218
219#[must_use = "AuthorizationUrlBuilder must be built to produce a URL"]
221pub struct AuthorizationUrlBuilder<'a> {
222 client: &'a OAuth2Client,
223 state: String,
224 redirect_uri: String,
225 code_verifier: Option<String>,
226 scopes: Vec<String>,
227 login_hint: Option<String>,
228 prompt: Option<String>,
229 access_type: Option<String>,
230 response_type: Option<String>,
231 response_mode: Option<String>,
232 display: Option<String>,
233 hd: Option<String>,
234 duration: Option<String>,
235 claims: Vec<String>,
236 additional_params: BTreeMap<String, String>,
237}
238
239impl AuthorizationUrlBuilder<'_> {
240 pub fn code_verifier(mut self, code_verifier: impl Into<String>) -> Self {
241 self.code_verifier = Some(code_verifier.into());
242 self
243 }
244
245 pub fn scope(mut self, scope: impl Into<String>) -> Self {
246 self.scopes.push(scope.into());
247 self
248 }
249
250 pub fn scopes(mut self, scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
251 self.scopes.extend(scopes.into_iter().map(Into::into));
252 self
253 }
254
255 pub fn login_hint(mut self, login_hint: impl Into<String>) -> Self {
256 self.login_hint = Some(login_hint.into());
257 self
258 }
259
260 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
261 self.prompt = Some(prompt.into());
262 self
263 }
264
265 pub fn access_type(mut self, access_type: impl Into<String>) -> Self {
266 self.access_type = Some(access_type.into());
267 self
268 }
269
270 pub fn response_type(mut self, response_type: impl Into<String>) -> Self {
271 self.response_type = Some(response_type.into());
272 self
273 }
274
275 pub fn response_mode(mut self, response_mode: impl Into<String>) -> Self {
276 self.response_mode = Some(response_mode.into());
277 self
278 }
279
280 pub fn claim(mut self, claim: impl Into<String>) -> Self {
281 self.claims.push(claim.into());
282 self
283 }
284
285 pub fn duration(mut self, duration: impl Into<String>) -> Self {
286 self.duration = Some(duration.into());
287 self
288 }
289
290 pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
291 self.additional_params.insert(key.into(), value.into());
292 self
293 }
294
295 pub fn build(self) -> Result<Url, OAuthError> {
296 let mut scopes = if !self.client.options.disable_default_scope {
297 self.client.default_scopes.clone()
298 } else {
299 Vec::new()
300 };
301 scopes.extend(self.client.options.scope.iter().cloned());
302 scopes.extend(self.scopes);
303
304 create_authorization_url(AuthorizationUrlRequest {
305 id: self.client.id.clone(),
306 options: self.client.options.clone(),
307 authorization_endpoint: self.client.authorization_endpoint.as_str().to_owned(),
308 redirect_uri: self.redirect_uri,
309 state: self.state,
310 code_verifier: self.code_verifier,
311 scopes,
312 login_hint: self.login_hint,
313 prompt: self.prompt.or_else(|| self.client.options.prompt.clone()),
314 access_type: self.access_type,
315 response_type: self.response_type,
316 response_mode: self
317 .response_mode
318 .or_else(|| self.client.options.response_mode.clone()),
319 display: self.display,
320 hd: self.hd,
321 duration: self.duration,
322 claims: self.claims,
323 additional_params: self.additional_params,
324 scope_joiner: self.client.scope_joiner.clone(),
325 })
326 }
327}
328
/// Pending authorization-code exchange, created by [`OAuth2Client::exchange_code`].
///
/// Finish it with `send` (posts to the client's token endpoint) or
/// `into_form_request` (returns the form without sending it).
#[must_use = "ExchangeCodeBuilder must be sent or converted to a form request"]
pub struct ExchangeCodeBuilder<'a> {
    /// Client whose token endpoint and HTTP client `send` uses.
    client: &'a OAuth2Client,
    /// Request being configured; it starts with the client's authentication method.
    request: AuthorizationCodeRequest,
}
335
336impl ExchangeCodeBuilder<'_> {
337 pub fn code_verifier(mut self, code_verifier: impl Into<String>) -> Self {
338 self.request = self.request.code_verifier(code_verifier);
339 self
340 }
341
342 pub fn device_id(mut self, device_id: impl Into<String>) -> Self {
343 self.request.device_id = Some(device_id.into());
344 self
345 }
346
347 pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
348 self.request = self.request.authentication(authentication);
349 self
350 }
351
352 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
353 self.request = self.request.header(key, value);
354 self
355 }
356
357 pub fn additional_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
358 self.request = self.request.additional_param(key, value);
359 self
360 }
361
362 pub fn override_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
363 self.request = self.request.override_param(key, value);
364 self
365 }
366
367 pub fn resource(mut self, resource: impl Into<String>) -> Self {
368 self.request = self.request.resource(resource);
369 self
370 }
371
372 pub fn into_form_request(self) -> Result<OAuthFormRequest, OAuthError> {
373 create_authorization_code_request(self.request)
374 }
375
376 pub async fn send(self) -> Result<OAuth2Tokens, OAuthError> {
377 exchange_authorization_code(
378 self.client.token_endpoint.as_str(),
379 self.request,
380 &self.client.http,
381 )
382 .await
383 }
384}
385
/// Pending refresh-token request, created by [`OAuth2Client::refresh_token`].
///
/// Finish it with `send` (posts to the client's token endpoint) or
/// `into_form_request` (returns the form without sending it).
#[must_use = "RefreshTokenBuilder must be sent or converted to a form request"]
pub struct RefreshTokenBuilder<'a> {
    /// Client whose token endpoint and HTTP client `send` uses.
    client: &'a OAuth2Client,
    /// Request being configured; it starts with the client's authentication method.
    request: RefreshAccessTokenRequest,
}
392
393impl RefreshTokenBuilder<'_> {
394 pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
395 self.request = self.request.authentication(authentication);
396 self
397 }
398
399 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
400 self.request = self.request.header(key, value);
401 self
402 }
403
404 pub fn extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
405 self.request = self.request.extra_param(key, value);
406 self
407 }
408
409 pub fn resource(mut self, resource: impl Into<String>) -> Self {
410 self.request = self.request.resource(resource);
411 self
412 }
413
414 pub fn into_form_request(self) -> Result<OAuthFormRequest, OAuthError> {
415 create_refresh_access_token_request(self.request)
416 }
417
418 pub async fn send(self) -> Result<OAuth2Tokens, OAuthError> {
419 refresh_access_token_at(
420 self.client.token_endpoint.as_str(),
421 self.request,
422 &self.client.http,
423 )
424 .await
425 }
426}
427
/// Pending client-credentials token request, created by
/// [`OAuth2Client::client_credentials`].
///
/// Finish it with `send` (posts to the client's token endpoint) or
/// `into_form_request` (returns the form without sending it).
#[must_use = "ClientCredentialsBuilder must be sent or converted to a form request"]
pub struct ClientCredentialsBuilder<'a> {
    /// Client whose token endpoint and HTTP client `send` uses.
    client: &'a OAuth2Client,
    /// Request being configured; it starts with the client's authentication method.
    request: ClientCredentialsTokenRequest,
}
434
435impl ClientCredentialsBuilder<'_> {
436 pub fn scope(mut self, scope: impl Into<String>) -> Self {
437 self.request = self.request.scope(scope);
438 self
439 }
440
441 pub fn authentication(mut self, authentication: ClientAuthentication) -> Self {
442 self.request = self.request.authentication(authentication);
443 self
444 }
445
446 pub fn resource(mut self, resource: impl Into<String>) -> Self {
447 self.request = self.request.resource(resource);
448 self
449 }
450
451 pub fn into_form_request(self) -> Result<OAuthFormRequest, OAuthError> {
452 create_client_credentials_token_request(self.request)
453 }
454
455 pub async fn send(self) -> Result<OAuth2Tokens, OAuthError> {
456 let request = create_client_credentials_token_request(self.request)?;
457 let data = post_form_with_client(
458 self.client.token_endpoint.as_str(),
459 request,
460 &self.client.http,
461 )
462 .await?;
463 get_oauth2_tokens(data)
464 }
465}
466
467pub async fn submit_token_form(
469 token_endpoint: &str,
470 request: OAuthFormRequest,
471 client: &OAuthHttpClient,
472) -> Result<OAuth2Tokens, OAuthError> {
473 let data = post_form_with_client(token_endpoint, request, client).await?;
474 get_oauth2_tokens(data)
475}
476
477pub async fn exchange_authorization_code(
479 token_endpoint: &str,
480 request: AuthorizationCodeRequest,
481 client: &OAuthHttpClient,
482) -> Result<OAuth2Tokens, OAuthError> {
483 let form = create_authorization_code_request(request)?;
484 let data = post_form_with_client(token_endpoint, form, client).await?;
485 get_oauth2_tokens(data)
486}
487
488pub async fn refresh_access_token_at(
490 token_endpoint: &str,
491 request: RefreshAccessTokenRequest,
492 client: &OAuthHttpClient,
493) -> Result<OAuth2Tokens, OAuthError> {
494 let form = create_refresh_access_token_request(request)?;
495 let data = post_form_with_client(token_endpoint, form, client).await?;
496 get_oauth2_tokens(data)
497}