1pub mod link;
2pub mod login;
3pub mod provider;
4pub mod refresh;
5pub mod signup;
6
7use crate::config::Allow;
8use crate::core::CoreUserp;
9use crate::traits::UserpCookies;
10use crate::traits::UserpStore;
11
12use self::link::OAuthLinkCallbackError;
13use self::login::OAuthLoginCallbackError;
14use self::provider::OAuthProvider;
15use self::refresh::OAuthRefreshCallbackError;
16use self::signup::OAuthSignupCallbackError;
17
18use chrono::{DateTime, Utc};
19use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse};
20use oauth2::{AuthorizationCode, CsrfToken, RedirectUrl, TokenResponse};
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23use std::{fmt::Display, sync::Arc};
24use thiserror::Error;
25use url::Url;
26use uuid::Uuid;
27
/// Cookie key under which `oauth_init` stores the JSON-encoded
/// `(CsrfToken, OAuthFlow)` pair that the OAuth callback reads back and verifies.
const OAUTH_DATA_KEY: &str = "userp-oauth-state";
29
/// A user account as described by an OAuth provider.
///
/// Carried inside `UnmatchedOAuthToken::provider_user` alongside the provider's name.
#[derive(Debug, Clone)]
pub struct OAuthProviderUser {
    /// The user's identifier at the provider (presumably stable per provider — confirm).
    pub id: String,
    /// Email address, if the provider supplied one.
    pub email: Option<String>,
    /// Display name, if the provider supplied one.
    pub name: Option<String>,
    /// Whether the provider reports `email` as verified (presumably `false` when no email is known — confirm).
    pub email_verified: bool,
}
37
38pub enum RefreshInitResult {
39 Redirect(Url),
40 Ok,
41}
42
/// The OAuth operation a provider round-trip belongs to.
///
/// `oauth_init` serializes this (with the CSRF state) into the OAuth data cookie
/// and the callback deserializes it to decide which handler runs. Because the
/// default serde representation uses the variant and field names, renaming them
/// changes the cookie format; cookies written by an older format then fail to
/// decode and surface as `OAuthCallbackError::MisformedOAuthData`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum OAuthFlow {
    /// Log in via the provider.
    LogIn {
        /// Optional value returned to the caller once the flow completes
        /// (presumably a post-flow redirect target — confirm).
        next: Option<String>,
    },
    /// Sign up via the provider.
    SignUp {
        /// Optional value returned to the caller once the flow completes.
        next: Option<String>,
    },
    /// Link a provider account to an existing user.
    Link {
        /// The local user being linked (presumably the signed-in user — confirm).
        user_id: Uuid,
        /// Optional value returned to the caller once the flow completes.
        next: Option<String>,
    },
    /// Re-authorize to refresh a stored token.
    Refresh {
        /// Id of the stored token being refreshed (presumably `OAuthToken::get_id` — verify).
        token_id: Uuid,
        /// Optional value returned to the caller once the flow completes.
        next: Option<String>,
    },
}
60
61impl Display for OAuthFlow {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.write_str(match self {
64 OAuthFlow::LogIn { .. } => "LogIn",
65 OAuthFlow::SignUp { .. } => "SignUp",
66 OAuthFlow::Link { .. } => "Link",
67 OAuthFlow::Refresh { .. } => "Refresh",
68 })
69 }
70}
71
/// OAuth settings: access policies, the base URL that redirect URIs are
/// resolved against, and the registered providers.
///
/// Built with `OAuthConfig::new` and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Login policy; `None` until `with_allow_login` is called
    /// (how `None` is interpreted is decided elsewhere — TODO confirm).
    pub allow_login: Option<Allow>,
    /// Signup policy; `None` until `with_allow_signup` is called.
    pub allow_signup: Option<Allow>,
    /// Whether provider accounts may be linked; `true` after `OAuthConfig::new`.
    pub allow_linking: bool,
    /// Base URL that route paths are joined onto when building provider redirect URIs.
    pub base_url: Url,
    /// Registered providers, looked up by name in the callbacks.
    pub providers: OAuthProviders,
}
80
81impl OAuthConfig {
82 pub fn new(base_url: Url) -> Self {
83 Self {
84 base_url,
85 allow_login: None,
86 allow_signup: None,
87 allow_linking: true,
88 providers: Default::default(),
89 }
90 }
91
92 pub fn with_client(mut self, client: impl OAuthProvider + 'static) -> Self {
93 self.providers.0.push(Arc::new(client));
94 self
95 }
96
97 pub fn with_allow_signup(mut self, allow_signup: Allow) -> Self {
98 self.allow_signup = Some(allow_signup);
99 self
100 }
101
102 pub fn with_allow_login(mut self, allow_login: Allow) -> Self {
103 self.allow_login = Some(allow_login);
104 self
105 }
106
107 pub fn with_allow_linking(mut self, allow_linking: bool) -> Self {
108 self.allow_linking = allow_linking;
109 self
110 }
111}
112
/// Ordered list of the configured OAuth providers.
///
/// Providers are held behind `Arc`, so cloning this collection (or an
/// `OAuthConfig`) shares the provider objects instead of duplicating them.
#[derive(Debug, Clone, Default)]
pub struct OAuthProviders(pub(super) Vec<Arc<dyn OAuthProvider>>);
115
116impl OAuthProviders {
117 pub fn get(&self, name: &str) -> Option<&Arc<dyn OAuthProvider>> {
118 self.0.iter().find(|c| c.name() == name)
119 }
120}
121
/// A token freshly obtained from a provider, together with the provider-side
/// user it belongs to.
///
/// Presumably "unmatched" because it is not yet tied to a local user/token
/// record — confirm against the login/signup/link handlers.
///
/// NOTE(review): this type holds token secrets and deliberately(?) does not
/// derive `Debug`; if `Debug` is ever added, consider redacting the secrets.
#[derive(Clone)]
pub struct UnmatchedOAuthToken {
    /// Access-token secret as issued by the provider.
    pub access_token: String,
    /// Refresh-token secret, if the provider issued one.
    pub refresh_token: Option<String>,
    /// Absolute expiry computed from the response's `expires_in`, if it had one.
    pub expires: Option<DateTime<Utc>>,
    /// Granted scopes from the response; empty when the response listed none.
    pub scopes: Vec<String>,
    /// Name of the provider that issued the token.
    pub provider_name: String,
    /// The provider's description of the token's user.
    pub provider_user: OAuthProviderUser,
}
131
132impl UnmatchedOAuthToken {
133 pub fn from_standard_token_response(
134 token_response: &StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
135 provider_name: &str,
136 provider_user: OAuthProviderUser,
137 ) -> Self {
138 Self {
139 access_token: token_response.access_token().secret().into(),
140 refresh_token: token_response.refresh_token().map(|rt| rt.secret().into()),
141 expires: token_response.expires_in().map(|d| Utc::now() + d),
142 scopes: token_response
143 .scopes()
144 .map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
145 .unwrap_or_default(),
146 provider_name: provider_name.into(),
147 provider_user,
148 }
149 }
150}
151
/// Accessors for an OAuth token persisted by the store.
///
/// `Send + Sync` is required so implementations can be shared across threads.
pub trait OAuthToken: Send + Sync {
    /// Id of this stored token (presumably the `token_id` used by `OAuthFlow::Refresh` — verify).
    fn get_id(&self) -> Uuid;
    /// Id of the local user that owns this token.
    fn get_user_id(&self) -> Uuid;
    /// Name of the issuing provider (presumably comparable with `OAuthProvider::name`).
    fn get_provider_name(&self) -> &str;
    /// Refresh-token secret, if one is stored.
    fn get_refresh_token(&self) -> &Option<String>;
}
158
159#[derive(Error, Debug)]
160pub enum OAuthCallbackError {
161 #[error("No provider found with name: '{0}'")]
162 NoProvider(String),
163 #[error("No oauth flow & state data cookie found")]
164 NoOAuthDataCookie,
165 #[error("Misformed OAuthData: {0}")]
166 MisformedOAuthData(#[from] serde_json::Error),
167 #[error("CSRF tokens didn't match")]
168 CsrfMismatch,
169 #[error(transparent)]
170 ExchangeAuthorizationCodeError(#[from] anyhow::Error),
171}
172
/// Error returned by `CoreUserp::oauth_generic_callback`.
///
/// Either the shared callback stage failed (`Callback`), or the handler for
/// the decoded `OAuthFlow` did. `StoreError` is the store's error type
/// (`S::Error` in `oauth_generic_callback`). Every variant is
/// `#[error(transparent)]`, so its message is the wrapped error's message.
#[derive(Error, Debug)]
pub enum OAuthGenericCallbackError<StoreError: std::error::Error> {
    /// Provider lookup, cookie decoding, CSRF check or code exchange failed.
    #[error(transparent)]
    Callback(#[from] OAuthCallbackError),
    /// The `OAuthFlow::SignUp` handler failed.
    #[error(transparent)]
    Signup(#[from] OAuthSignupCallbackError<StoreError>),
    /// The `OAuthFlow::LogIn` handler failed.
    #[error(transparent)]
    Login(#[from] OAuthLoginCallbackError<StoreError>),
    /// The `OAuthFlow::Link` handler failed.
    #[error(transparent)]
    Link(#[from] OAuthLinkCallbackError<StoreError>),
    /// The `OAuthFlow::Refresh` handler failed.
    #[error(transparent)]
    Refresh(#[from] OAuthRefreshCallbackError<StoreError>),
}
186
187impl<S: UserpStore, C: UserpCookies> CoreUserp<S, C> {
188 fn redirect_uri(&self, path: String, provider_name: &str) -> RedirectUrl {
189 let path = if path.ends_with('/') {
190 path
191 } else {
192 format!("{path}/")
193 };
194
195 let path = path.replace(":provider", provider_name);
196
197 RedirectUrl::from_url(self.oauth.base_url.join(path.as_str()).unwrap())
198 }
199
200 async fn oauth_init(
201 mut self,
202 path: String,
203 provider: Arc<dyn OAuthProvider>,
204 oauth_flow: OAuthFlow,
205 ) -> (Self, Url) {
206 let (auth_url, csrf_state) = provider.get_authorization_url_and_state(
207 &self.redirect_uri(path, provider.name()),
208 provider.scopes(),
209 );
210
211 self.cookies
212 .add(OAUTH_DATA_KEY, &json!((csrf_state, oauth_flow)).to_string());
213
214 (self, auth_url)
215 }
216
217 async fn oauth_callback_inner(
218 &self,
219 provider_name: String,
220 code: AuthorizationCode,
221 csrf_token: CsrfToken,
222 path: String,
223 ) -> Result<(UnmatchedOAuthToken, OAuthFlow, Arc<dyn OAuthProvider>), OAuthCallbackError> {
224 let provider = self
225 .oauth
226 .providers
227 .get(&provider_name)
228 .ok_or(OAuthCallbackError::NoProvider(provider_name.clone()))?;
229
230 let oauth_data = self
231 .cookies
232 .get(OAUTH_DATA_KEY)
233 .ok_or(OAuthCallbackError::NoOAuthDataCookie)?;
234
235 let (prev_csrf_token, oauth_flow) =
236 serde_json::from_str::<(CsrfToken, OAuthFlow)>(&oauth_data)?;
237
238 if csrf_token.secret() != prev_csrf_token.secret() {
239 return Err(OAuthCallbackError::CsrfMismatch);
240 }
241
242 let unmatched_token = provider
243 .exchange_authorization_code(
244 provider.name(),
245 &self.redirect_uri(path, &provider_name),
246 &code,
247 )
248 .await?;
249
250 Ok((unmatched_token, oauth_flow, provider.clone()))
251 }
252
253 #[must_use = "Don't forget to return the auth session as part of the response!"]
254 pub async fn oauth_generic_callback(
255 self,
256 provider_name: String,
257 code: AuthorizationCode,
258 state: CsrfToken,
259 ) -> Result<(Self, Option<String>), OAuthGenericCallbackError<S::Error>> {
260 let (unmatched_token, flow, provider) = self
261 .oauth_callback_inner(
262 provider_name.clone(),
263 code,
264 state,
265 self.routes.actions.signup_oauth_provider.clone(),
266 )
267 .await?;
268
269 Ok(match &flow {
270 OAuthFlow::LogIn { .. } => {
271 self.oauth_login_callback_inner(provider, unmatched_token, flow)
272 .await?
273 }
274 OAuthFlow::SignUp { .. } => {
275 self.oauth_signup_callback_inner(provider, unmatched_token, flow)
276 .await?
277 }
278 OAuthFlow::Link { .. } => {
279 let next = self
280 .oauth_link_callback_inner(provider, unmatched_token, flow)
281 .await?;
282
283 (self, next)
284 }
285 OAuthFlow::Refresh { .. } => {
286 let next = self
287 .oauth_refresh_callback_inner(unmatched_token, flow)
288 .await?;
289
290 (self, next)
291 }
292 })
293 }
294}