// userp/oauth.rs

pub mod link;
pub mod login;
pub mod provider;
pub mod refresh;
pub mod signup;

use crate::config::Allow;
use crate::core::CoreUserp;
use crate::traits::UserpCookies;
use crate::traits::UserpStore;

use self::link::OAuthLinkCallbackError;
use self::login::OAuthLoginCallbackError;
use self::provider::OAuthProvider;
use self::refresh::OAuthRefreshCallbackError;
use self::signup::OAuthSignupCallbackError;

use chrono::{DateTime, Duration, Utc};
use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse};
use oauth2::{AuthorizationCode, CsrfToken, RedirectUrl, TokenResponse};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{fmt::Display, sync::Arc};
use thiserror::Error;
use url::Url;
use uuid::Uuid;

/// Name of the cookie that holds the JSON-serialized `(CsrfToken, OAuthFlow)`
/// pair written when an OAuth flow starts and read back in the callback.
const OAUTH_DATA_KEY: &str = "userp-oauth-state";

/// A user profile as reported by an OAuth provider.
#[derive(Debug, Clone)]
pub struct OAuthProviderUser {
    /// The provider-side identifier of the user.
    pub id: String,
    /// Email address reported by the provider, if any.
    pub email: Option<String>,
    /// Display name reported by the provider, if any.
    pub name: Option<String>,
    /// Whether the provider reports `email` as verified.
    pub email_verified: bool,
}

38pub enum RefreshInitResult {
39    Redirect(Url),
40    Ok,
41}
42
43#[derive(Debug, Serialize, Deserialize, Clone)]
44pub enum OAuthFlow {
45    LogIn {
46        next: Option<String>,
47    },
48    SignUp {
49        next: Option<String>,
50    },
51    Link {
52        user_id: Uuid,
53        next: Option<String>,
54    },
55    Refresh {
56        token_id: Uuid,
57        next: Option<String>,
58    },
59}
60
61impl Display for OAuthFlow {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.write_str(match self {
64            OAuthFlow::LogIn { .. } => "LogIn",
65            OAuthFlow::SignUp { .. } => "SignUp",
66            OAuthFlow::Link { .. } => "Link",
67            OAuthFlow::Refresh { .. } => "Refresh",
68        })
69    }
70}
71
72#[derive(Debug, Clone)]
73pub struct OAuthConfig {
74    pub allow_login: Option<Allow>,
75    pub allow_signup: Option<Allow>,
76    pub allow_linking: bool,
77    pub base_url: Url,
78    pub providers: OAuthProviders,
79}
80
81impl OAuthConfig {
82    pub fn new(base_url: Url) -> Self {
83        Self {
84            base_url,
85            allow_login: None,
86            allow_signup: None,
87            allow_linking: true,
88            providers: Default::default(),
89        }
90    }
91
92    pub fn with_client(mut self, client: impl OAuthProvider + 'static) -> Self {
93        self.providers.0.push(Arc::new(client));
94        self
95    }
96
97    pub fn with_allow_signup(mut self, allow_signup: Allow) -> Self {
98        self.allow_signup = Some(allow_signup);
99        self
100    }
101
102    pub fn with_allow_login(mut self, allow_login: Allow) -> Self {
103        self.allow_login = Some(allow_login);
104        self
105    }
106
107    pub fn with_allow_linking(mut self, allow_linking: bool) -> Self {
108        self.allow_linking = allow_linking;
109        self
110    }
111}
112
113#[derive(Debug, Clone, Default)]
114pub struct OAuthProviders(pub(super) Vec<Arc<dyn OAuthProvider>>);
115
116impl OAuthProviders {
117    pub fn get(&self, name: &str) -> Option<&Arc<dyn OAuthProvider>> {
118        self.0.iter().find(|c| c.name() == name)
119    }
120}
121
122#[derive(Clone)]
123pub struct UnmatchedOAuthToken {
124    pub access_token: String,
125    pub refresh_token: Option<String>,
126    pub expires: Option<DateTime<Utc>>,
127    pub scopes: Vec<String>,
128    pub provider_name: String,
129    pub provider_user: OAuthProviderUser,
130}
131
132impl UnmatchedOAuthToken {
133    pub fn from_standard_token_response(
134        token_response: &StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
135        provider_name: &str,
136        provider_user: OAuthProviderUser,
137    ) -> Self {
138        Self {
139            access_token: token_response.access_token().secret().into(),
140            refresh_token: token_response.refresh_token().map(|rt| rt.secret().into()),
141            expires: token_response.expires_in().map(|d| Utc::now() + d),
142            scopes: token_response
143                .scopes()
144                .map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
145                .unwrap_or_default(),
146            provider_name: provider_name.into(),
147            provider_user,
148        }
149    }
150}
151
152pub trait OAuthToken: Send + Sync {
153    fn get_id(&self) -> Uuid;
154    fn get_user_id(&self) -> Uuid;
155    fn get_provider_name(&self) -> &str;
156    fn get_refresh_token(&self) -> &Option<String>;
157}
158
159#[derive(Error, Debug)]
160pub enum OAuthCallbackError {
161    #[error("No provider found with name: '{0}'")]
162    NoProvider(String),
163    #[error("No oauth flow & state data cookie found")]
164    NoOAuthDataCookie,
165    #[error("Misformed OAuthData: {0}")]
166    MisformedOAuthData(#[from] serde_json::Error),
167    #[error("CSRF tokens didn't match")]
168    CsrfMismatch,
169    #[error(transparent)]
170    ExchangeAuthorizationCodeError(#[from] anyhow::Error),
171}
172
173#[derive(Error, Debug)]
174pub enum OAuthGenericCallbackError<StoreError: std::error::Error> {
175    #[error(transparent)]
176    Callback(#[from] OAuthCallbackError),
177    #[error(transparent)]
178    Signup(#[from] OAuthSignupCallbackError<StoreError>),
179    #[error(transparent)]
180    Login(#[from] OAuthLoginCallbackError<StoreError>),
181    #[error(transparent)]
182    Link(#[from] OAuthLinkCallbackError<StoreError>),
183    #[error(transparent)]
184    Refresh(#[from] OAuthRefreshCallbackError<StoreError>),
185}
186
187impl<S: UserpStore, C: UserpCookies> CoreUserp<S, C> {
188    fn redirect_uri(&self, path: String, provider_name: &str) -> RedirectUrl {
189        let path = if path.ends_with('/') {
190            path
191        } else {
192            format!("{path}/")
193        };
194
195        let path = path.replace(":provider", provider_name);
196
197        RedirectUrl::from_url(self.oauth.base_url.join(path.as_str()).unwrap())
198    }
199
200    async fn oauth_init(
201        mut self,
202        path: String,
203        provider: Arc<dyn OAuthProvider>,
204        oauth_flow: OAuthFlow,
205    ) -> (Self, Url) {
206        let (auth_url, csrf_state) = provider.get_authorization_url_and_state(
207            &self.redirect_uri(path, provider.name()),
208            provider.scopes(),
209        );
210
211        self.cookies
212            .add(OAUTH_DATA_KEY, &json!((csrf_state, oauth_flow)).to_string());
213
214        (self, auth_url)
215    }
216
217    async fn oauth_callback_inner(
218        &self,
219        provider_name: String,
220        code: AuthorizationCode,
221        csrf_token: CsrfToken,
222        path: String,
223    ) -> Result<(UnmatchedOAuthToken, OAuthFlow, Arc<dyn OAuthProvider>), OAuthCallbackError> {
224        let provider = self
225            .oauth
226            .providers
227            .get(&provider_name)
228            .ok_or(OAuthCallbackError::NoProvider(provider_name.clone()))?;
229
230        let oauth_data = self
231            .cookies
232            .get(OAUTH_DATA_KEY)
233            .ok_or(OAuthCallbackError::NoOAuthDataCookie)?;
234
235        let (prev_csrf_token, oauth_flow) =
236            serde_json::from_str::<(CsrfToken, OAuthFlow)>(&oauth_data)?;
237
238        if csrf_token.secret() != prev_csrf_token.secret() {
239            return Err(OAuthCallbackError::CsrfMismatch);
240        }
241
242        let unmatched_token = provider
243            .exchange_authorization_code(
244                provider.name(),
245                &self.redirect_uri(path, &provider_name),
246                &code,
247            )
248            .await?;
249
250        Ok((unmatched_token, oauth_flow, provider.clone()))
251    }
252
253    #[must_use = "Don't forget to return the auth session as part of the response!"]
254    pub async fn oauth_generic_callback(
255        self,
256        provider_name: String,
257        code: AuthorizationCode,
258        state: CsrfToken,
259    ) -> Result<(Self, Option<String>), OAuthGenericCallbackError<S::Error>> {
260        let (unmatched_token, flow, provider) = self
261            .oauth_callback_inner(
262                provider_name.clone(),
263                code,
264                state,
265                self.routes.actions.signup_oauth_provider.clone(),
266            )
267            .await?;
268
269        Ok(match &flow {
270            OAuthFlow::LogIn { .. } => {
271                self.oauth_login_callback_inner(provider, unmatched_token, flow)
272                    .await?
273            }
274            OAuthFlow::SignUp { .. } => {
275                self.oauth_signup_callback_inner(provider, unmatched_token, flow)
276                    .await?
277            }
278            OAuthFlow::Link { .. } => {
279                let next = self
280                    .oauth_link_callback_inner(provider, unmatched_token, flow)
281                    .await?;
282
283                (self, next)
284            }
285            OAuthFlow::Refresh { .. } => {
286                let next = self
287                    .oauth_refresh_callback_inner(unmatched_token, flow)
288                    .await?;
289
290                (self, next)
291            }
292        })
293    }
294}