1#![doc = include_str!("../README.md")]
2
3use std::borrow::Cow;
4
5use bon::bon;
6use oauth2::{
7 CsrfToken, EndpointNotSet, EndpointSet, HttpClientError, RequestTokenError, TokenResponse,
8 basic::{BasicClient, BasicErrorResponse},
9};
10
11pub mod common;
12mod provider;
13pub mod types;
14
15pub use provider::SimpleOAuthProvider;
16use subtle::ConstantTimeEq;
17
18use crate::types::{AuthorizeUrl, OAuthCredentials, StandardTokenResponse, UserInfo};
19
20#[derive(Debug, thiserror::Error)]
21pub enum SimpleOAuthError {
22 #[error(transparent)]
23 Request(#[from] reqwest::Error),
24 #[error("invalid url: {0}")]
25 ParseUrl(#[from] oauth2::url::ParseError),
26 #[error("returned state did not match initial state")]
27 StateMismatch,
28 #[error("token exchange error: {0}")]
29 TokenExchange(#[from] RequestTokenError<HttpClientError<reqwest::Error>, BasicErrorResponse>),
30 #[error("deserialization error: {0}")]
31 Deserialization(#[from] serde_json::Error),
32}
33
34#[derive(Debug, Clone)]
35pub struct SimpleOAuthClient<P> {
36 http_client: reqwest::Client,
37 oauth_http_client: oauth2_reqwest::ReqwestClient,
38 oauth_client:
39 BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
40 provider: P,
41}
42
43#[bon]
44impl<P> SimpleOAuthClient<P>
45where
46 P: SimpleOAuthProvider,
47{
48 #[builder(on(String, into))]
49 pub fn new(
50 provider: P,
51 credentials: OAuthCredentials,
52 redirect_url: String,
53 http_client: Option<&reqwest::Client>,
54 ) -> Result<Self, SimpleOAuthError> {
55 let http_client = if let Some(client) = http_client {
56 client.to_owned()
57 } else {
58 reqwest::Client::builder()
59 .redirect(reqwest::redirect::Policy::none())
60 .build()?
61 };
62 let oauth_client = BasicClient::new(oauth2::ClientId::new(credentials.client_id))
63 .set_client_secret(oauth2::ClientSecret::new(credentials.client_secret))
64 .set_redirect_uri(oauth2::RedirectUrl::new(redirect_url)?)
65 .set_auth_uri(oauth2::AuthUrl::new(provider.authorize_url().into())?)
66 .set_token_uri(oauth2::TokenUrl::new(provider.token_url().into())?);
67
68 Ok(Self {
69 oauth_http_client: oauth2_reqwest::ReqwestClient::from(http_client.clone()),
70 http_client,
71 oauth_client,
72 provider,
73 })
74 }
75
76 #[builder(on(String, into), finish_fn(name = "build"))]
85 pub fn authorize_url(
86 &self,
87 redirect_url: Option<String>,
88 scopes: Option<&[&str]>,
89 ) -> Result<AuthorizeUrl, SimpleOAuthError> {
90 let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
91 let mut auth_request = self
92 .oauth_client
93 .authorize_url(CsrfToken::new_random)
94 .set_pkce_challenge(pkce_challenge)
95 .add_scopes(
96 scopes
97 .unwrap_or(self.provider.default_scopes())
98 .iter()
99 .map(|s| oauth2::Scope::new((*s).to_owned())),
100 );
101 if let Some(redirect_url) = redirect_url {
102 auth_request =
103 auth_request.set_redirect_uri(Cow::Owned(oauth2::RedirectUrl::new(redirect_url)?));
104 }
105 let (url, state) = auth_request.url();
106
107 Ok(AuthorizeUrl {
108 url,
109 state: state.into_secret(),
110 pkce_verifier: pkce_verifier.into_secret(),
111 })
112 }
113
114 #[builder(on(String, into), finish_fn(name = "build"))]
120 pub async fn exchange_code(
121 &self,
122 code: String,
123 state: &str,
124 initial_state: &str,
125 pkce_verifier: String,
126 redirect_url: Option<String>,
127 ) -> Result<StandardTokenResponse, SimpleOAuthError> {
128 if state.as_bytes().ct_ne(initial_state.as_bytes()).into() {
129 return Err(SimpleOAuthError::StateMismatch);
130 }
131 let mut token_request = self
132 .oauth_client
133 .exchange_code(oauth2::AuthorizationCode::new(code))
134 .set_pkce_verifier(oauth2::PkceCodeVerifier::new(pkce_verifier));
135 if let Some(redirect_url) = redirect_url {
136 token_request =
137 token_request.set_redirect_uri(Cow::Owned(oauth2::RedirectUrl::new(redirect_url)?));
138 }
139 let token = token_request.request_async(&self.oauth_http_client).await?;
140
141 Ok(StandardTokenResponse {
142 access_token: token.access_token().secret().to_owned(),
143 refresh_token: token.refresh_token().map(|s| s.secret().to_owned()),
144 expires_in: token.expires_in(),
145 })
146 }
147
148 #[builder(on(String, into), finish_fn(name = "build"))]
150 pub async fn exchange_refresh_token(
151 &self,
152 refresh_token: String,
153 ) -> Result<StandardTokenResponse, SimpleOAuthError> {
154 let token = self
155 .oauth_client
156 .exchange_refresh_token(&oauth2::RefreshToken::new(refresh_token))
157 .request_async(&self.oauth_http_client)
158 .await?;
159
160 Ok(StandardTokenResponse {
161 access_token: token.access_token().secret().to_owned(),
162 refresh_token: token.refresh_token().map(|s| s.secret().to_owned()),
163 expires_in: token.expires_in(),
164 })
165 }
166
167 pub async fn get_user_info(&self, access_token: &str) -> Result<UserInfo, SimpleOAuthError> {
170 let mut user_info_request = self
171 .http_client
172 .get(self.provider.user_info_url())
173 .bearer_auth(access_token);
174 for (name, val) in self.provider.additional_headers() {
175 user_info_request = user_info_request.header(name, val);
176 }
177
178 let user_info_val = user_info_request
179 .send()
180 .await?
181 .error_for_status()?
182 .json()
183 .await?;
184 let user_info = self.provider.extract_user_info(user_info_val)?;
185
186 Ok(user_info)
187 }
188}