1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::sync::{Arc, Mutex};
4use time::OffsetDateTime;
5
6use crate::context::AuthContext;
7use crate::cookies::{ChunkedCookieStore, Cookie};
8#[cfg(feature = "jose")]
9use crate::crypto::symmetric_encode_jwt_with_salt;
10use crate::db::{Account, DbAdapter, Session, User};
11use crate::error::RustAuthError;
12use crate::session::{CreateSessionInput, SessionStore};
13use crate::user::{
14 CreateOAuthAccountInput, CreateUserInput, DbUserStore, UpdateAccountInput, UpdateUserInput,
15};
16
17use super::errors::OAuthUserInfoError;
18use super::tokens::{encrypt_oauth_tokens_for_storage, set_token_util};
19
/// Salt passed to `symmetric_encode_jwt_with_salt` when encoding the account
/// cookie in `set_account_cookie`.
#[cfg(feature = "jose")]
const ACCOUNT_COOKIE_SALT: &str = "better-auth-account";
/// Error code for an OAuth account that is already linked to another user.
/// Not referenced in this part of the file; presumably consumed by other OAuth
/// handlers in the crate — TODO confirm.
#[cfg(feature = "oauth")]
pub(crate) const ACCOUNT_ALREADY_LINKED_TO_DIFFERENT_USER: &str =
    "account_already_linked_to_different_user";
/// Error code for a provider email that does not match the linked user's email.
/// Not referenced in this part of the file; presumably consumed by other OAuth
/// handlers in the crate — TODO confirm.
#[cfg(feature = "oauth")]
pub(crate) const EMAIL_DOES_NOT_MATCH_LINKED_USER: &str = "email_doesn't_match";
32
/// Profile data reported by an OAuth provider for the user who is signing in.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthUserInfo {
    // Identifier reported by the provider; not read by the code in this file
    // (presumably the provider-side subject id — TODO confirm).
    pub id: String,
    // Display name; used for newly created users and, with
    // `override_user_info`, copied onto existing users.
    pub name: String,
    // Email address; `handle_oauth_user_info` lowercases it before lookup and
    // before creating a user.
    pub email: String,
    // Optional profile image; applied to new users and overridden users.
    pub image: Option<String>,
    // Provider's verification flag; drives marking existing users verified
    // and can permit implicit account linking.
    pub email_verified: bool,
    // Extra provider attributes as arbitrary JSON; not read in this file.
    pub raw_attributes: Option<Value>,
}
42
/// Provider account data (identifiers and tokens) obtained during an OAuth flow.
///
/// Tokens are passed through `encrypt_oauth_tokens_for_storage` when a new
/// account is created and through `set_token_util` when an existing one is
/// updated, before they reach the database.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthAccountInput {
    // Provider identifier; together with `account_id` it identifies the
    // linked account.
    pub provider_id: String,
    // Account identifier at the provider.
    pub account_id: String,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    pub access_token_expires_at: Option<OffsetDateTime>,
    pub refresh_token_expires_at: Option<OffsetDateTime>,
    // Granted scope string as returned by the provider (format not
    // interpreted here).
    pub scope: Option<String>,
}
54
/// Input to `handle_oauth_user_info`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HandleOAuthUserInfoInput {
    // Profile reported by the provider.
    pub user_info: OAuthUserInfo,
    // Provider account and tokens to link or refresh.
    pub account: OAuthAccountInput,
    // Not read by `handle_oauth_user_info`; presumably used by the caller for
    // redirects — TODO confirm.
    pub callback_url: Option<String>,
    // When set, unknown users get `OAuthUserInfoError::SignupDisabled`
    // instead of being created.
    pub disable_sign_up: bool,
    // When set, an already-linked user's name/image/email are overwritten
    // with the provider data.
    pub override_user_info: bool,
    // Caller-asserted trust for this provider, used by implicit linking.
    pub is_trusted_provider: bool,
    // When set, implicit linking requires a trusted provider; a verified
    // email alone is not enough.
    pub require_trusted_provider_for_implicit_link: bool,
}
65
/// Session and user returned by a successful OAuth sign-in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthSessionUser {
    pub session: Session,
    pub user: User,
}
71
/// Records produced by the sign-up transaction in `create_oauth_session_user`.
///
/// `account` is kept so the caller can encode it into the account cookie.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CreatedOAuthSessionUser {
    session: Session,
    user: User,
    account: Account,
}
78
/// Outcome of `handle_oauth_user_info`.
///
/// Exactly one of `data` / `error` is populated by the code in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandleOAuthUserInfoResult {
    // Session and user on success.
    pub data: Option<OAuthSessionUser>,
    // Expected refusal (e.g. sign-up disabled, account not linked).
    pub error: Option<OAuthUserInfoError>,
    // True when no existing user was found and one was created.
    pub is_register: bool,
    // Account-data cookies to set; empty unless
    // `options.account.store_account_cookie` is enabled.
    pub cookies: Vec<Cookie>,
}
86
/// Resolves the local user for an OAuth sign-in and issues a session for them.
///
/// Flow:
/// 1. Look up an existing user with `find_oauth_user`, using the lowercased
///    provider email plus the provider account id and provider id.
/// 2. Existing user with this provider account already linked: refresh the
///    stored account (`update_linked_account`). Then either overwrite the
///    profile (`override_user_info`), or mark the email verified when the
///    provider verified the same address.
/// 3. Existing user without a linked account: link one only if
///    `can_implicitly_link` allows it; otherwise report
///    `OAuthUserInfoError::AccountNotLinked`.
/// 4. No user: unless `disable_sign_up` is set (which reports
///    `OAuthUserInfoError::SignupDisabled`), create the user, account and
///    session together in `create_oauth_session_user`.
///
/// For existing users a session is created afterwards. When
/// `options.account.store_account_cookie` is enabled, the resulting account is
/// also encoded into account cookies.
///
/// Expected refusals are returned through `HandleOAuthUserInfoResult::error`
/// rather than as `Err`.
///
/// # Errors
///
/// Returns `Err` when a store call fails. Linking the account and creating the
/// session are both mapped to `RustAuthError::Adapter`. It also returns `Err`
/// when token encryption fails or when the account cookie cannot be built.
pub async fn handle_oauth_user_info(
    context: &AuthContext,
    adapter: &dyn DbAdapter,
    input: HandleOAuthUserInfoInput,
) -> Result<HandleOAuthUserInfoResult, RustAuthError> {
    let users = DbUserStore::with_schema(adapter, context.db_schema.clone());
    // Lowercased once; used both for the lookup and for a newly created user.
    let normalized_email = input.user_info.email.to_lowercase();
    let db_user = users
        .find_oauth_user(
            &normalized_email,
            &input.account.account_id,
            &input.account.provider_id,
        )
        .await?;
    let mut user = db_user.as_ref().map(|lookup| lookup.user.clone());
    // Account to encode into the account cookie; assigned on every branch
    // that does not return early.
    let account_cookie;
    let is_register = user.is_none();
    // Only set on sign-up, where the session is created inside the
    // user-creation transaction.
    let mut created_session = None;

    if let Some(lookup) = db_user {
        // Prefer the account resolved by the lookup; otherwise search the
        // user's accounts for this provider/account id pair.
        let linked_account = lookup.linked_account.or_else(|| {
            lookup
                .accounts
                .iter()
                .find(|account| {
                    account.provider_id == input.account.provider_id
                        && account.account_id == input.account.account_id
                })
                .cloned()
        });
        if let Some(linked_account) = linked_account {
            account_cookie = Some(
                update_linked_account(context, &users, &linked_account, &input.account).await?,
            );
            if input.override_user_info {
                user = override_linked_user_info(&users, &lookup.user, &input.user_info).await?;
            } else if input.user_info.email_verified
                && !lookup.user.email_verified
                && same_email(&input.user_info.email, &lookup.user.email)
            {
                // Keep the already-loaded user if the update returns no row.
                user = users
                    .update_user_email_verified(&lookup.user.id, true)
                    .await?
                    .or(user);
            }
        } else if !can_implicitly_link(context, &input) {
            return Ok(result_error(OAuthUserInfoError::AccountNotLinked, false));
        } else {
            // Implicit link: attach this provider account to the user found
            // by email.
            let account = account_input(context, &input.account, &lookup.user.id)?;
            let linked_account = users
                .link_account(account)
                .await
                .map_err(|_| RustAuthError::Adapter("unable to link OAuth account".to_owned()))?;
            account_cookie = Some(linked_account);
            if input.user_info.email_verified
                && !lookup.user.email_verified
                && same_email(&input.user_info.email, &lookup.user.email)
            {
                user = users
                    .update_user_email_verified(&lookup.user.id, true)
                    .await?
                    .or(Some(lookup.user));
            } else {
                user = Some(lookup.user);
            }
        }
    } else {
        if input.disable_sign_up {
            return Ok(result_error(OAuthUserInfoError::SignupDisabled, false));
        }
        let mut user_input = CreateUserInput::new(input.user_info.name.clone(), normalized_email)
            .email_verified(input.user_info.email_verified);
        if let Some(image) = input.user_info.image.clone() {
            user_input = user_input.image(image);
        }
        // The user id is not known yet; the account input gets an empty
        // user id here (presumably filled in by `create_oauth_user` — TODO
        // confirm).
        let created = create_oauth_session_user(
            context,
            adapter,
            user_input,
            account_input(context, &input.account, "")?,
        )
        .await?;
        account_cookie = Some(created.account);
        created_session = Some(created.session);
        user = Some(created.user);
    }

    let Some(user) = user else {
        return Ok(result_error(OAuthUserInfoError::UnableToCreateUser, false));
    };
    // Sign-up already produced a session; existing users get a fresh one.
    let session = match created_session {
        Some(session) => session,
        None => SessionStore::with_storage_and_schema(
            adapter,
            context.db_schema.clone(),
            context.secondary_storage(),
            context.options.session.store_session_in_database,
            context.options.session.preserve_session_in_database,
        )
        .create_session(CreateSessionInput::new(
            &user.id,
            OffsetDateTime::now_utc() + context.session_config.expires_in,
        ))
        .await
        .map_err(|_| RustAuthError::Adapter("unable to create OAuth session".to_owned()))?,
    };
    let cookies = if context.options.account.store_account_cookie {
        account_cookie
            .as_ref()
            .map(|account| set_account_cookie(context, account))
            .transpose()?
            .unwrap_or_default()
    } else {
        Vec::new()
    };
    Ok(HandleOAuthUserInfoResult {
        data: Some(OAuthSessionUser { session, user }),
        error: None,
        is_register,
        cookies,
    })
}
209
210async fn create_oauth_session_user(
211 context: &AuthContext,
212 adapter: &dyn DbAdapter,
213 user_input: CreateUserInput,
214 account_input: CreateOAuthAccountInput,
215) -> Result<CreatedOAuthSessionUser, RustAuthError> {
216 let result = Arc::new(Mutex::new(None));
217 let result_for_transaction = Arc::clone(&result);
218 let expires_in = context.session_config.expires_in;
219 let secondary_storage = context.secondary_storage();
220 let store_session_in_database = context.options.session.store_session_in_database;
221 let preserve_session_in_database = context.options.session.preserve_session_in_database;
222 let schema = context.db_schema.clone();
223 let transaction_status = adapter
224 .transaction(Box::new(move |transaction| {
225 let secondary_storage = secondary_storage.clone();
226 let schema = schema.clone();
227 Box::pin(async move {
228 let users = DbUserStore::with_schema(transaction.as_ref(), schema);
229 let created = users
230 .create_oauth_user(user_input, account_input)
231 .await
232 .map_err(|_| {
233 RustAuthError::Adapter("unable to create OAuth user".to_owned())
234 })?;
235 let session = SessionStore::with_storage(
236 transaction.as_ref(),
237 secondary_storage,
238 store_session_in_database,
239 preserve_session_in_database,
240 )
241 .create_session(CreateSessionInput::new(
242 &created.user.id,
243 OffsetDateTime::now_utc() + expires_in,
244 ))
245 .await
246 .map_err(|_| RustAuthError::Adapter("unable to create OAuth session".to_owned()))?;
247 store_created_oauth_session_user(
248 &result_for_transaction,
249 CreatedOAuthSessionUser {
250 session,
251 user: created.user,
252 account: created.account,
253 },
254 )?;
255 Ok(())
256 })
257 }))
258 .await;
259
260 match transaction_status {
261 Ok(()) => take_created_oauth_session_user(&result)?.ok_or_else(|| {
262 RustAuthError::Adapter(
263 "create OAuth session transaction completed without a result".to_owned(),
264 )
265 }),
266 Err(error) => Err(error),
267 }
268}
269
270fn store_created_oauth_session_user(
271 result: &Mutex<Option<CreatedOAuthSessionUser>>,
272 value: CreatedOAuthSessionUser,
273) -> Result<(), RustAuthError> {
274 let mut guard = result.lock().map_err(|_| RustAuthError::LockPoisoned {
275 context: "create OAuth session result",
276 })?;
277 *guard = Some(value);
278 Ok(())
279}
280
281fn take_created_oauth_session_user(
282 result: &Mutex<Option<CreatedOAuthSessionUser>>,
283) -> Result<Option<CreatedOAuthSessionUser>, RustAuthError> {
284 result
285 .lock()
286 .map_err(|_| RustAuthError::LockPoisoned {
287 context: "create OAuth session result",
288 })
289 .map(|mut guard| guard.take())
290}
291
292fn can_implicitly_link(context: &AuthContext, input: &HandleOAuthUserInfoInput) -> bool {
293 let linking = &context.options.account.account_linking;
294 if !linking.enabled || linking.disable_implicit_linking {
295 return false;
296 }
297 let trusted_providers = context
298 .trusted_providers_for_request(None)
299 .unwrap_or_default();
300 let trusted = input.is_trusted_provider
301 || trusted_providers
302 .iter()
303 .any(|provider| provider == &input.account.provider_id);
304 if input.require_trusted_provider_for_implicit_link {
305 return trusted;
306 }
307 trusted || input.user_info.email_verified
308}
309
310async fn update_linked_account(
311 context: &AuthContext,
312 users: &DbUserStore<'_>,
313 linked_account: &Account,
314 account: &OAuthAccountInput,
315) -> Result<Account, RustAuthError> {
316 if !context.options.account.update_account_on_sign_in {
317 return Ok(linked_account.clone());
318 }
319 let updated = users
320 .update_account(&linked_account.id, account_update_input(context, account)?)
321 .await?;
322 Ok(updated.unwrap_or_else(|| linked_account.clone()))
323}
324
325pub(crate) fn set_account_cookie(
326 context: &AuthContext,
327 account: &Account,
328) -> Result<Vec<Cookie>, RustAuthError> {
329 let max_age = context
330 .auth_cookies
331 .account_data
332 .attributes
333 .max_age
334 .unwrap_or(60 * 5);
335 #[cfg(feature = "jose")]
336 let data = symmetric_encode_jwt_with_salt(
337 account,
338 &context.secret_config,
339 ACCOUNT_COOKIE_SALT,
340 max_age,
341 )?;
342 #[cfg(not(feature = "jose"))]
343 let data = encode_account_cookie_data(account)?;
344 let mut attributes = context.auth_cookies.account_data.attributes.clone();
345 attributes.max_age = Some(max_age);
346 Ok(ChunkedCookieStore::new(
347 context.auth_cookies.account_data.name.clone(),
348 attributes,
349 "",
350 )
351 .chunk(&data))
352}
353
354#[cfg(not(feature = "jose"))]
355fn encode_account_cookie_data(_account: &Account) -> Result<String, RustAuthError> {
356 Err(RustAuthError::FeatureDisabled { feature: "jose" })
357}
358
359async fn override_linked_user_info(
360 users: &DbUserStore<'_>,
361 existing: &User,
362 provider: &OAuthUserInfo,
363) -> Result<Option<User>, RustAuthError> {
364 let normalized_email = provider.email.to_lowercase();
365 let email_verified = if normalized_email == existing.email {
366 existing.email_verified || provider.email_verified
367 } else {
368 provider.email_verified
369 };
370 let updated = users
371 .update_user(
372 &existing.id,
373 UpdateUserInput::new()
374 .name(provider.name.clone())
375 .image(provider.image.clone()),
376 )
377 .await?;
378 users
379 .update_user_email(&existing.id, &normalized_email, email_verified)
380 .await
381 .map(|user| user.or(updated))
382}
383
384fn account_input(
385 context: &AuthContext,
386 account: &OAuthAccountInput,
387 user_id: &str,
388) -> Result<CreateOAuthAccountInput, RustAuthError> {
389 let tokens = encrypt_oauth_tokens_for_storage(
390 account.access_token.as_deref(),
391 account.refresh_token.as_deref(),
392 account.id_token.as_deref(),
393 context,
394 )?;
395 Ok(CreateOAuthAccountInput {
396 id: None,
397 provider_id: account.provider_id.clone(),
398 account_id: account.account_id.clone(),
399 user_id: user_id.to_owned(),
400 access_token: tokens.access_token,
401 refresh_token: tokens.refresh_token,
402 id_token: tokens.id_token,
403 access_token_expires_at: account.access_token_expires_at,
404 refresh_token_expires_at: account.refresh_token_expires_at,
405 scope: account.scope.clone(),
406 })
407}
408
409fn account_update_input(
410 context: &AuthContext,
411 account: &OAuthAccountInput,
412) -> Result<UpdateAccountInput, RustAuthError> {
413 let mut input = UpdateAccountInput::default();
414 if account.access_token.is_some() {
415 input.access_token = Some(set_token_util(account.access_token.as_deref(), context)?);
416 }
417 if account.refresh_token.is_some() {
418 input.refresh_token = Some(set_token_util(account.refresh_token.as_deref(), context)?);
419 }
420 if account.id_token.is_some() {
421 input.id_token = Some(set_token_util(account.id_token.as_deref(), context)?);
422 }
423 if account.access_token_expires_at.is_some() {
424 input.access_token_expires_at = Some(account.access_token_expires_at);
425 }
426 if account.refresh_token_expires_at.is_some() {
427 input.refresh_token_expires_at = Some(account.refresh_token_expires_at);
428 }
429 if account.scope.is_some() {
430 input.scope = Some(account.scope.clone());
431 }
432 Ok(input)
433}
434
/// Returns `true` when the provider email and the stored user email are the
/// same address, ignoring case.
///
/// Case folding uses `str::to_lowercase`, the same normalization
/// `handle_oauth_user_info` applies before lookup and user creation. Addresses
/// that differ only in non-ASCII letter case (e.g. `É` vs `é`) therefore
/// compare equal; `eq_ignore_ascii_case` would miss them.
fn same_email(provider_email: &str, user_email: &str) -> bool {
    // The exact-match fast path avoids allocating two lowercase copies.
    provider_email == user_email || provider_email.to_lowercase() == user_email.to_lowercase()
}
438
439fn result_error(error: OAuthUserInfoError, is_register: bool) -> HandleOAuthUserInfoResult {
440 HandleOAuthUserInfoResult {
441 data: None,
442 error: Some(error),
443 is_register,
444 cookies: Vec::new(),
445 }
446}