Skip to main content

simple_oauth/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::borrow::Cow;
4
5use bon::bon;
6use oauth2::{
7    CsrfToken, EndpointNotSet, EndpointSet, HttpClientError, RequestTokenError, TokenResponse,
8    basic::{BasicClient, BasicErrorResponse},
9};
10
11pub mod common;
12mod provider;
13pub mod types;
14
15pub use provider::SimpleOAuthProvider;
16use subtle::ConstantTimeEq;
17
18use crate::types::{AuthorizeUrl, OAuthCredentials, StandardTokenResponse, UserInfo};
19
20#[derive(Debug, thiserror::Error)]
21pub enum SimpleOAuthError {
22    #[error(transparent)]
23    Request(#[from] reqwest::Error),
24    #[error("invalid url: {0}")]
25    ParseUrl(#[from] oauth2::url::ParseError),
26    #[error("returned state did not match initial state")]
27    StateMismatch,
28    #[error("token exchange error: {0}")]
29    TokenExchange(#[from] RequestTokenError<HttpClientError<reqwest::Error>, BasicErrorResponse>),
30    #[error("deserialization error: {0}")]
31    Deserialization(#[from] serde_json::Error),
32}
33
34#[derive(Debug, Clone)]
35pub struct SimpleOAuthClient<P> {
36    http_client: reqwest::Client,
37    oauth_http_client: oauth2_reqwest::ReqwestClient,
38    oauth_client:
39        BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
40    provider: P,
41}
42
43#[bon]
44impl<P> SimpleOAuthClient<P>
45where
46    P: SimpleOAuthProvider,
47{
48    #[builder(on(String, into))]
49    pub fn new(
50        provider: P,
51        credentials: OAuthCredentials,
52        redirect_url: String,
53        http_client: Option<&reqwest::Client>,
54    ) -> Result<Self, SimpleOAuthError> {
55        let http_client = if let Some(client) = http_client {
56            client.to_owned()
57        } else {
58            reqwest::Client::builder()
59                .redirect(reqwest::redirect::Policy::none())
60                .build()?
61        };
62        let oauth_client = BasicClient::new(oauth2::ClientId::new(credentials.client_id))
63            .set_client_secret(oauth2::ClientSecret::new(credentials.client_secret))
64            .set_redirect_uri(oauth2::RedirectUrl::new(redirect_url)?)
65            .set_auth_uri(oauth2::AuthUrl::new(provider.authorize_url().into())?)
66            .set_token_uri(oauth2::TokenUrl::new(provider.token_url().into())?);
67
68        Ok(Self {
69            oauth_http_client: oauth2_reqwest::ReqwestClient::from(http_client.clone()),
70            http_client,
71            oauth_client,
72            provider,
73        })
74    }
75
76    /// Build the URL to navigate the user to for authorization. **Make sure to save the returned state and
77    /// PKCE verifier in a secure location, typically in a server-side cache or session.**
78    ///
79    /// If scopes are not provided, will use default limited scopes for the provider to access basic user info (user ID and name only).
80    /// If more access is needed (e.g. email), make sure to specify all required scopes.
81    ///
82    /// You can optionally override the redirect URL, but make sure to pass in the exact same URL when calling
83    /// `exchange_code()`.
84    #[builder(on(String, into), finish_fn(name = "build"))]
85    pub fn authorize_url(
86        &self,
87        redirect_url: Option<String>,
88        scopes: Option<&[&str]>,
89    ) -> Result<AuthorizeUrl, SimpleOAuthError> {
90        let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
91        let mut auth_request = self
92            .oauth_client
93            .authorize_url(CsrfToken::new_random)
94            .set_pkce_challenge(pkce_challenge)
95            .add_scopes(
96                scopes
97                    .unwrap_or(self.provider.default_scopes())
98                    .iter()
99                    .map(|s| oauth2::Scope::new((*s).to_owned())),
100            );
101        if let Some(redirect_url) = redirect_url {
102            auth_request =
103                auth_request.set_redirect_uri(Cow::Owned(oauth2::RedirectUrl::new(redirect_url)?));
104        }
105        let (url, state) = auth_request.url();
106
107        Ok(AuthorizeUrl {
108            url,
109            state: state.into_secret(),
110            pkce_verifier: pkce_verifier.into_secret(),
111        })
112    }
113
114    /// Exchange the returned code after authorization for an access/refresh token. Along with the
115    /// returned code and state, you will need to specify the saved PKCE verifier and initial state
116    /// (the state will be verified using a timing-resistant algorithm).
117    ///
118    /// If you set the redirect URL when calling `authorize_url()`, you must set the same URL here as well.
119    #[builder(on(String, into), finish_fn(name = "build"))]
120    pub async fn exchange_code(
121        &self,
122        code: String,
123        state: &str,
124        initial_state: &str,
125        pkce_verifier: String,
126        redirect_url: Option<String>,
127    ) -> Result<StandardTokenResponse, SimpleOAuthError> {
128        if state.as_bytes().ct_ne(initial_state.as_bytes()).into() {
129            return Err(SimpleOAuthError::StateMismatch);
130        }
131        let mut token_request = self
132            .oauth_client
133            .exchange_code(oauth2::AuthorizationCode::new(code))
134            .set_pkce_verifier(oauth2::PkceCodeVerifier::new(pkce_verifier));
135        if let Some(redirect_url) = redirect_url {
136            token_request =
137                token_request.set_redirect_uri(Cow::Owned(oauth2::RedirectUrl::new(redirect_url)?));
138        }
139        let token = token_request.request_async(&self.oauth_http_client).await?;
140
141        Ok(StandardTokenResponse {
142            access_token: token.access_token().secret().to_owned(),
143            refresh_token: token.refresh_token().map(|s| s.secret().to_owned()),
144            expires_in: token.expires_in(),
145        })
146    }
147
148    /// Exchange the refresh token for a new access token
149    #[builder(on(String, into), finish_fn(name = "build"))]
150    pub async fn exchange_refresh_token(
151        &self,
152        refresh_token: String,
153    ) -> Result<StandardTokenResponse, SimpleOAuthError> {
154        let token = self
155            .oauth_client
156            .exchange_refresh_token(&oauth2::RefreshToken::new(refresh_token))
157            .request_async(&self.oauth_http_client)
158            .await?;
159
160        Ok(StandardTokenResponse {
161            access_token: token.access_token().secret().to_owned(),
162            refresh_token: token.refresh_token().map(|s| s.secret().to_owned()),
163            expires_in: token.expires_in(),
164        })
165    }
166
167    /// Retrieve user info from the provider using the access token. This is a convenience
168    /// method for when you only need basic info (e.g. id, name, email, avatar).
169    pub async fn get_user_info(&self, access_token: &str) -> Result<UserInfo, SimpleOAuthError> {
170        let mut user_info_request = self
171            .http_client
172            .get(self.provider.user_info_url())
173            .bearer_auth(access_token);
174        for (name, val) in self.provider.additional_headers() {
175            user_info_request = user_info_request.header(name, val);
176        }
177
178        let user_info_val = user_info_request
179            .send()
180            .await?
181            .error_for_status()?
182            .json()
183            .await?;
184        let user_info = self.provider.extract_user_info(user_info_val)?;
185
186        Ok(user_info)
187    }
188}