Skip to main content

rustauth_core/auth/oauth/
account_linking.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::sync::{Arc, Mutex};
4use time::OffsetDateTime;
5
6use crate::context::AuthContext;
7use crate::cookies::{ChunkedCookieStore, Cookie};
8#[cfg(feature = "jose")]
9use crate::crypto::symmetric_encode_jwt_with_salt;
10use crate::db::{Account, DbAdapter, Session, User};
11use crate::error::RustAuthError;
12use crate::session::{CreateSessionInput, SessionStore};
13use crate::user::{
14    CreateOAuthAccountInput, CreateUserInput, DbUserStore, UpdateAccountInput, UpdateUserInput,
15};
16
17use super::errors::OAuthUserInfoError;
18use super::tokens::{encrypt_oauth_tokens_for_storage, set_token_util};
19
/// Salt passed to `symmetric_encode_jwt_with_salt` when encoding the account cookie in
/// `set_account_cookie`. The value presumably mirrors the upstream Better Auth salt so
/// cookies stay interoperable — confirm before changing it, since existing cookies would
/// stop decoding.
#[cfg(feature = "jose")]
const ACCOUNT_COOKIE_SALT: &str = "better-auth-account";
// These account-linking error codes are only surfaced by the social OAuth
// routes (`api::routes::social`), which are gated behind the `oauth` feature.
// Gate the definitions the same way so they are not flagged as dead code when
// `rustauth-core` is built without the route layer (for example as a
// dependency of `rustauth-sso`).
/// Error code for a provider account that is already linked to a different user.
#[cfg(feature = "oauth")]
pub(crate) const ACCOUNT_ALREADY_LINKED_TO_DIFFERENT_USER: &str =
    "account_already_linked_to_different_user";
/// Error code for a provider email that does not match the user being linked.
/// The literal (including the apostrophe) is the emitted code value; it presumably
/// mirrors Better Auth's code verbatim — do not normalize it without checking callers.
#[cfg(feature = "oauth")]
pub(crate) const EMAIL_DOES_NOT_MATCH_LINKED_USER: &str = "email_doesn't_match";
32
/// Profile information reported by an OAuth provider for the signing-in user.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthUserInfo {
    /// Provider-side user id. Not read by the linking code in this file; the provider
    /// account is identified by `OAuthAccountInput::account_id` instead.
    pub id: String,
    /// Display name; used for newly created users and when `override_user_info` is set.
    pub name: String,
    /// Email address as reported by the provider. It is lowercased before being used for
    /// user lookup and before being stored.
    pub email: String,
    /// Optional profile image; copied onto newly created users and applied when
    /// `override_user_info` is set.
    pub image: Option<String>,
    /// Whether the provider reports the email as verified. Used to mark a matching user's
    /// email as verified and, unless trusted providers are required, to permit implicit
    /// account linking.
    pub email_verified: bool,
    /// Extra provider attributes (presumably the raw profile payload — TODO confirm).
    /// Not read by the linking code in this file.
    pub raw_attributes: Option<Value>,
}
42
/// Provider account data (identifiers, tokens, expiries, scope) from an OAuth sign-in.
///
/// When an already-linked account is refreshed, only the fields that are `Some` are
/// written (see `account_update_input`).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthAccountInput {
    /// Identifier of the OAuth provider; together with `account_id` it identifies the
    /// linked account.
    pub provider_id: String,
    /// Account identifier at the provider.
    pub account_id: String,
    /// Access token as received. Passed through `encrypt_oauth_tokens_for_storage`
    /// (new accounts) or `set_token_util` (updates) before it is stored.
    pub access_token: Option<String>,
    /// Refresh token as received; processed the same way as `access_token`.
    pub refresh_token: Option<String>,
    /// ID token as received; processed the same way as `access_token`.
    pub id_token: Option<String>,
    /// Expiry of `access_token`, if known.
    pub access_token_expires_at: Option<OffsetDateTime>,
    /// Expiry of `refresh_token`, if known.
    pub refresh_token_expires_at: Option<OffsetDateTime>,
    /// Granted scope, if reported.
    pub scope: Option<String>,
}
54
/// Input for `handle_oauth_user_info`: the provider profile and account plus the
/// policy flags that govern sign-up and account linking.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HandleOAuthUserInfoInput {
    /// Profile reported by the provider.
    pub user_info: OAuthUserInfo,
    /// Provider account (ids and tokens) to link or refresh.
    pub account: OAuthAccountInput,
    /// Callback URL of the sign-in flow. Not read by `handle_oauth_user_info`;
    /// presumably consumed by the calling route — TODO confirm.
    pub callback_url: Option<String>,
    /// When `true` and no matching user exists, the call returns
    /// `OAuthUserInfoError::SignupDisabled` instead of creating a user.
    pub disable_sign_up: bool,
    /// When `true`, an already-linked user's name, image, email and verification flag
    /// are overwritten from `user_info`.
    pub override_user_info: bool,
    /// Caller-asserted trust in this provider for implicit linking (in addition to the
    /// configured trusted-provider list).
    pub is_trusted_provider: bool,
    /// When `true`, implicit linking requires a trusted provider; a verified email alone
    /// is not sufficient.
    pub require_trusted_provider_for_implicit_link: bool,
}
65
/// The session issued by an OAuth sign-in together with the user it belongs to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthSessionUser {
    /// Newly created session.
    pub session: Session,
    /// User the session was created for (after any updates made during sign-in).
    pub user: User,
}
71
/// Records produced by the sign-up transaction in `create_oauth_session_user`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CreatedOAuthSessionUser {
    /// Session created for the new user.
    session: Session,
    /// Newly created user.
    user: User,
    /// Provider account linked to the new user.
    account: Account,
}
78
/// Outcome of `handle_oauth_user_info`.
///
/// On success `data` is `Some` and `error` is `None`. On a flow-level failure (for
/// example, linking not allowed) `data` is `None`, `error` is set and `cookies` is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandleOAuthUserInfoResult {
    /// Session and user on success.
    pub data: Option<OAuthSessionUser>,
    /// Flow-level failure reason, if any.
    pub error: Option<OAuthUserInfoError>,
    /// `true` when no existing user matched and a new one was created.
    pub is_register: bool,
    /// Account-data cookies to set (empty unless account cookies are enabled).
    pub cookies: Vec<Cookie>,
}
86
/// Resolves the local user for an OAuth sign-in, linking or creating records as needed,
/// and issues a session for that user.
///
/// Flow:
/// - Looks up an existing user by the lowercased provider email and the provider
///   account id / provider id.
/// - Existing user that already has this provider account:
///   - Optionally refreshes the stored account (`update_linked_account`).
///   - Then either overwrites the user's profile from the provider (`override_user_info`),
///     or marks the email verified when the provider reports the same address as verified.
/// - Existing user without this provider account:
///   - Links it only when `can_implicitly_link` allows.
///   - Otherwise returns `OAuthUserInfoError::AccountNotLinked`.
/// - No existing user:
///   - Returns `OAuthUserInfoError::SignupDisabled` when `disable_sign_up` is set.
///   - Otherwise creates the user, account and session in a single transaction.
///
/// For existing users a new session is created after the branch above. When
/// `options.account.store_account_cookie` is enabled, the resolved account is also
/// encoded into (possibly chunked) account-data cookies.
///
/// # Errors
///
/// Flow-level failures are reported through `HandleOAuthUserInfoResult::error` with
/// `data: None`. An `Err` is returned when a store call fails (user lookup/update,
/// account link/update, session creation), when token processing in `account_input`
/// or `update_linked_account` fails, or when `set_account_cookie` fails.
pub async fn handle_oauth_user_info(
    context: &AuthContext,
    adapter: &dyn DbAdapter,
    input: HandleOAuthUserInfoInput,
) -> Result<HandleOAuthUserInfoResult, RustAuthError> {
    let users = DbUserStore::with_schema(adapter, context.db_schema.clone());
    let normalized_email = input.user_info.email.to_lowercase();
    let db_user = users
        .find_oauth_user(
            &normalized_email,
            &input.account.account_id,
            &input.account.provider_id,
        )
        .await?;
    // Start from the pre-existing user (if any); the branches below replace it with the
    // updated or newly created record.
    let mut user = db_user.as_ref().map(|lookup| lookup.user.clone());
    // Account to encode into the account cookie; assigned on every non-error path below.
    let account_cookie;
    // No matching user means this sign-in registers a new one.
    let is_register = user.is_none();
    // Only set on the registration path, where the session is created inside the
    // sign-up transaction.
    let mut created_session = None;

    if let Some(lookup) = db_user {
        // Prefer the account the lookup matched directly; otherwise search the user's
        // accounts for this exact provider/account-id pair.
        let linked_account = lookup.linked_account.or_else(|| {
            lookup
                .accounts
                .iter()
                .find(|account| {
                    account.provider_id == input.account.provider_id
                        && account.account_id == input.account.account_id
                })
                .cloned()
        });
        if let Some(linked_account) = linked_account {
            account_cookie = Some(
                update_linked_account(context, &users, &linked_account, &input.account).await?,
            );
            if input.override_user_info {
                // May yield `None` if neither update returns a record, which is reported
                // as `UnableToCreateUser` below.
                user = override_linked_user_info(&users, &lookup.user, &input.user_info).await?;
            } else if input.user_info.email_verified
                && !lookup.user.email_verified
                && same_email(&input.user_info.email, &lookup.user.email)
            {
                // Fall back to the looked-up user if the update returns no record.
                user = users
                    .update_user_email_verified(&lookup.user.id, true)
                    .await?
                    .or(user);
            }
        } else if !can_implicitly_link(context, &input) {
            // The user exists but this provider account is not linked, and policy forbids
            // linking it implicitly.
            return Ok(result_error(OAuthUserInfoError::AccountNotLinked, false));
        } else {
            let account = account_input(context, &input.account, &lookup.user.id)?;
            let linked_account = users
                .link_account(account)
                .await
                .map_err(|_| RustAuthError::Adapter("unable to link OAuth account".to_owned()))?;
            account_cookie = Some(linked_account);
            if input.user_info.email_verified
                && !lookup.user.email_verified
                && same_email(&input.user_info.email, &lookup.user.email)
            {
                user = users
                    .update_user_email_verified(&lookup.user.id, true)
                    .await?
                    .or(Some(lookup.user));
            } else {
                user = Some(lookup.user);
            }
        }
    } else {
        if input.disable_sign_up {
            return Ok(result_error(OAuthUserInfoError::SignupDisabled, false));
        }
        let mut user_input = CreateUserInput::new(input.user_info.name.clone(), normalized_email)
            .email_verified(input.user_info.email_verified);
        if let Some(image) = input.user_info.image.clone() {
            user_input = user_input.image(image);
        }
        // The account is built with an empty user id; presumably `create_oauth_user`
        // fills it in once the user row exists — TODO confirm.
        let created = create_oauth_session_user(
            context,
            adapter,
            user_input,
            account_input(context, &input.account, "")?,
        )
        .await?;
        account_cookie = Some(created.account);
        created_session = Some(created.session);
        user = Some(created.user);
    }

    let Some(user) = user else {
        return Ok(result_error(OAuthUserInfoError::UnableToCreateUser, false));
    };
    // New users already got a session inside the sign-up transaction; existing users
    // get one here.
    let session = match created_session {
        Some(session) => session,
        None => SessionStore::with_storage_and_schema(
            adapter,
            context.db_schema.clone(),
            context.secondary_storage(),
            context.options.session.store_session_in_database,
            context.options.session.preserve_session_in_database,
        )
        .create_session(CreateSessionInput::new(
            &user.id,
            OffsetDateTime::now_utc() + context.session_config.expires_in,
        ))
        .await
        .map_err(|_| RustAuthError::Adapter("unable to create OAuth session".to_owned()))?,
    };
    let cookies = if context.options.account.store_account_cookie {
        account_cookie
            .as_ref()
            .map(|account| set_account_cookie(context, account))
            .transpose()?
            .unwrap_or_default()
    } else {
        Vec::new()
    };
    Ok(HandleOAuthUserInfoResult {
        data: Some(OAuthSessionUser { session, user }),
        error: None,
        is_register,
        cookies,
    })
}
209
210async fn create_oauth_session_user(
211    context: &AuthContext,
212    adapter: &dyn DbAdapter,
213    user_input: CreateUserInput,
214    account_input: CreateOAuthAccountInput,
215) -> Result<CreatedOAuthSessionUser, RustAuthError> {
216    let result = Arc::new(Mutex::new(None));
217    let result_for_transaction = Arc::clone(&result);
218    let expires_in = context.session_config.expires_in;
219    let secondary_storage = context.secondary_storage();
220    let store_session_in_database = context.options.session.store_session_in_database;
221    let preserve_session_in_database = context.options.session.preserve_session_in_database;
222    let schema = context.db_schema.clone();
223    let transaction_status = adapter
224        .transaction(Box::new(move |transaction| {
225            let secondary_storage = secondary_storage.clone();
226            let schema = schema.clone();
227            Box::pin(async move {
228                let users = DbUserStore::with_schema(transaction.as_ref(), schema);
229                let created = users
230                    .create_oauth_user(user_input, account_input)
231                    .await
232                    .map_err(|_| {
233                        RustAuthError::Adapter("unable to create OAuth user".to_owned())
234                    })?;
235                let session = SessionStore::with_storage(
236                    transaction.as_ref(),
237                    secondary_storage,
238                    store_session_in_database,
239                    preserve_session_in_database,
240                )
241                .create_session(CreateSessionInput::new(
242                    &created.user.id,
243                    OffsetDateTime::now_utc() + expires_in,
244                ))
245                .await
246                .map_err(|_| RustAuthError::Adapter("unable to create OAuth session".to_owned()))?;
247                store_created_oauth_session_user(
248                    &result_for_transaction,
249                    CreatedOAuthSessionUser {
250                        session,
251                        user: created.user,
252                        account: created.account,
253                    },
254                )?;
255                Ok(())
256            })
257        }))
258        .await;
259
260    match transaction_status {
261        Ok(()) => take_created_oauth_session_user(&result)?.ok_or_else(|| {
262            RustAuthError::Adapter(
263                "create OAuth session transaction completed without a result".to_owned(),
264            )
265        }),
266        Err(error) => Err(error),
267    }
268}
269
270fn store_created_oauth_session_user(
271    result: &Mutex<Option<CreatedOAuthSessionUser>>,
272    value: CreatedOAuthSessionUser,
273) -> Result<(), RustAuthError> {
274    let mut guard = result.lock().map_err(|_| RustAuthError::LockPoisoned {
275        context: "create OAuth session result",
276    })?;
277    *guard = Some(value);
278    Ok(())
279}
280
281fn take_created_oauth_session_user(
282    result: &Mutex<Option<CreatedOAuthSessionUser>>,
283) -> Result<Option<CreatedOAuthSessionUser>, RustAuthError> {
284    result
285        .lock()
286        .map_err(|_| RustAuthError::LockPoisoned {
287            context: "create OAuth session result",
288        })
289        .map(|mut guard| guard.take())
290}
291
292fn can_implicitly_link(context: &AuthContext, input: &HandleOAuthUserInfoInput) -> bool {
293    let linking = &context.options.account.account_linking;
294    if !linking.enabled || linking.disable_implicit_linking {
295        return false;
296    }
297    let trusted_providers = context
298        .trusted_providers_for_request(None)
299        .unwrap_or_default();
300    let trusted = input.is_trusted_provider
301        || trusted_providers
302            .iter()
303            .any(|provider| provider == &input.account.provider_id);
304    if input.require_trusted_provider_for_implicit_link {
305        return trusted;
306    }
307    trusted || input.user_info.email_verified
308}
309
310async fn update_linked_account(
311    context: &AuthContext,
312    users: &DbUserStore<'_>,
313    linked_account: &Account,
314    account: &OAuthAccountInput,
315) -> Result<Account, RustAuthError> {
316    if !context.options.account.update_account_on_sign_in {
317        return Ok(linked_account.clone());
318    }
319    let updated = users
320        .update_account(&linked_account.id, account_update_input(context, account)?)
321        .await?;
322    Ok(updated.unwrap_or_else(|| linked_account.clone()))
323}
324
325pub(crate) fn set_account_cookie(
326    context: &AuthContext,
327    account: &Account,
328) -> Result<Vec<Cookie>, RustAuthError> {
329    let max_age = context
330        .auth_cookies
331        .account_data
332        .attributes
333        .max_age
334        .unwrap_or(60 * 5);
335    #[cfg(feature = "jose")]
336    let data = symmetric_encode_jwt_with_salt(
337        account,
338        &context.secret_config,
339        ACCOUNT_COOKIE_SALT,
340        max_age,
341    )?;
342    #[cfg(not(feature = "jose"))]
343    let data = encode_account_cookie_data(account)?;
344    let mut attributes = context.auth_cookies.account_data.attributes.clone();
345    attributes.max_age = Some(max_age);
346    Ok(ChunkedCookieStore::new(
347        context.auth_cookies.account_data.name.clone(),
348        attributes,
349        "",
350    )
351    .chunk(&data))
352}
353
354#[cfg(not(feature = "jose"))]
355fn encode_account_cookie_data(_account: &Account) -> Result<String, RustAuthError> {
356    Err(RustAuthError::FeatureDisabled { feature: "jose" })
357}
358
359async fn override_linked_user_info(
360    users: &DbUserStore<'_>,
361    existing: &User,
362    provider: &OAuthUserInfo,
363) -> Result<Option<User>, RustAuthError> {
364    let normalized_email = provider.email.to_lowercase();
365    let email_verified = if normalized_email == existing.email {
366        existing.email_verified || provider.email_verified
367    } else {
368        provider.email_verified
369    };
370    let updated = users
371        .update_user(
372            &existing.id,
373            UpdateUserInput::new()
374                .name(provider.name.clone())
375                .image(provider.image.clone()),
376        )
377        .await?;
378    users
379        .update_user_email(&existing.id, &normalized_email, email_verified)
380        .await
381        .map(|user| user.or(updated))
382}
383
384fn account_input(
385    context: &AuthContext,
386    account: &OAuthAccountInput,
387    user_id: &str,
388) -> Result<CreateOAuthAccountInput, RustAuthError> {
389    let tokens = encrypt_oauth_tokens_for_storage(
390        account.access_token.as_deref(),
391        account.refresh_token.as_deref(),
392        account.id_token.as_deref(),
393        context,
394    )?;
395    Ok(CreateOAuthAccountInput {
396        id: None,
397        provider_id: account.provider_id.clone(),
398        account_id: account.account_id.clone(),
399        user_id: user_id.to_owned(),
400        access_token: tokens.access_token,
401        refresh_token: tokens.refresh_token,
402        id_token: tokens.id_token,
403        access_token_expires_at: account.access_token_expires_at,
404        refresh_token_expires_at: account.refresh_token_expires_at,
405        scope: account.scope.clone(),
406    })
407}
408
409fn account_update_input(
410    context: &AuthContext,
411    account: &OAuthAccountInput,
412) -> Result<UpdateAccountInput, RustAuthError> {
413    let mut input = UpdateAccountInput::default();
414    if account.access_token.is_some() {
415        input.access_token = Some(set_token_util(account.access_token.as_deref(), context)?);
416    }
417    if account.refresh_token.is_some() {
418        input.refresh_token = Some(set_token_util(account.refresh_token.as_deref(), context)?);
419    }
420    if account.id_token.is_some() {
421        input.id_token = Some(set_token_util(account.id_token.as_deref(), context)?);
422    }
423    if account.access_token_expires_at.is_some() {
424        input.access_token_expires_at = Some(account.access_token_expires_at);
425    }
426    if account.refresh_token_expires_at.is_some() {
427        input.refresh_token_expires_at = Some(account.refresh_token_expires_at);
428    }
429    if account.scope.is_some() {
430        input.scope = Some(account.scope.clone());
431    }
432    Ok(input)
433}
434
/// Returns `true` when a provider-reported email and a stored user email are the same
/// address, ignoring letter case.
///
/// Both sides are folded with `str::to_lowercase`, the same Unicode-aware normalization
/// applied to the provider email before the user lookup. `eq_ignore_ascii_case` would
/// treat non-ASCII pairs such as `É`/`é` as different, even though the lookup matched
/// them. Exact equality is checked first to avoid allocating in the common case.
fn same_email(provider_email: &str, user_email: &str) -> bool {
    provider_email == user_email || provider_email.to_lowercase() == user_email.to_lowercase()
}
438
439fn result_error(error: OAuthUserInfoError, is_register: bool) -> HandleOAuthUserInfoResult {
440    HandleOAuthUserInfoResult {
441        data: None,
442        error: Some(error),
443        is_register,
444        cookies: Vec::new(),
445    }
446}