Skip to main content

rustauth_core/user/
mod.rs

1//! Database-backed user and credential account helpers.
2
3mod input;
4mod record;
5
6use std::sync::{Arc, LazyLock, Mutex};
7
8use time::OffsetDateTime;
9
10use crate::context::AuthContext;
11use crate::crypto::random::generate_random_string;
12use crate::db::{
13    auth_schema, Account, AuthSchemaOptions, Count, Create, DbAdapter, DbRecord, DbSchema, DbValue,
14    Delete, DeleteMany, FindMany, FindOne, JoinOption, SchemaTable, Sort, SortDirection, Update,
15    User, Where,
16};
17use crate::error::RustAuthError;
18pub use input::{
19    CreateCredentialAccountInput, CreateOAuthAccountInput, CreateUserInput, UpdateAccountInput,
20    UpdateUserInput,
21};
22use record::{
23    account_from_record, user_from_record, ACCOUNT_FIELDS, USER_FIELDS, USER_FIELDS_WITH_USERNAME,
24};
25
26pub(super) const USER_MODEL: &str = "user";
27pub(super) const ACCOUNT_MODEL: &str = "account";
28const CREDENTIAL_PROVIDER_ID: &str = "credential";
29const DEFAULT_ID_LENGTH: usize = 32;
30
31fn default_auth_schema() -> &'static DbSchema {
32    static SCHEMA: LazyLock<DbSchema> = LazyLock::new(|| auth_schema(AuthSchemaOptions::default()));
33    &SCHEMA
34}
35
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct UserWithAccounts {
38    pub user: User,
39    pub accounts: Vec<Account>,
40}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct OAuthUserLookup {
44    pub user: User,
45    pub accounts: Vec<Account>,
46    pub linked_account: Option<Account>,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct CreateOAuthUserResult {
51    pub user: User,
52    pub account: Account,
53}
54
55#[derive(Clone)]
56pub struct DbUserStore<'a> {
57    adapter: &'a dyn DbAdapter,
58    schema: DbSchema,
59}
60
61impl<'a> DbUserStore<'a> {
62    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
63        Self::with_schema(adapter, default_auth_schema().clone())
64    }
65
66    pub fn with_schema(adapter: &'a dyn DbAdapter, schema: DbSchema) -> Self {
67        Self { adapter, schema }
68    }
69
70    pub fn from_context(context: &'a AuthContext) -> Result<Self, RustAuthError> {
71        Ok(Self::with_schema(
72            context.adapter_ref()?,
73            context.db_schema.clone(),
74        ))
75    }
76
77    fn users(&self) -> Result<SchemaTable<'_>, RustAuthError> {
78        SchemaTable::new(&self.schema, USER_MODEL)
79    }
80
81    fn accounts(&self) -> Result<SchemaTable<'_>, RustAuthError> {
82        SchemaTable::new(&self.schema, ACCOUNT_MODEL)
83    }
84
85    fn parse_user(&self, record: DbRecord) -> Result<User, RustAuthError> {
86        user_from_record(self.users()?.map_record(record)?)
87    }
88
89    fn parse_account(&self, record: DbRecord) -> Result<Account, RustAuthError> {
90        account_from_record(self.accounts()?.map_record(record)?)
91    }
92
93    pub async fn create_user(&self, input: CreateUserInput) -> Result<User, RustAuthError> {
94        let now = OffsetDateTime::now_utc();
95        let id = input
96            .id
97            .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
98
99        let include_username_fields = input.username.is_some() || input.display_username.is_some();
100        let mut query = Create::new(USER_MODEL)
101            .data("id", DbValue::String(id))
102            .data("name", DbValue::String(input.name))
103            .data("email", DbValue::String(normalize_email(&input.email)))
104            .data("email_verified", DbValue::Boolean(input.email_verified))
105            .data("image", optional_string(input.image))
106            .data("created_at", DbValue::Timestamp(now))
107            .data("updated_at", DbValue::Timestamp(now))
108            .force_allow_id();
109        if include_username_fields {
110            query = query
111                .data("username", optional_string(input.username))
112                .data("display_username", optional_string(input.display_username))
113                .select(USER_FIELDS_WITH_USERNAME);
114        } else {
115            query = query.select(USER_FIELDS);
116        }
117
118        for (field, value) in input.additional_fields {
119            query = query.data(field, value);
120        }
121
122        let record = self.adapter.create(query).await?;
123
124        self.parse_user(record)
125    }
126
127    pub async fn create_credential_account(
128        &self,
129        input: CreateCredentialAccountInput,
130    ) -> Result<Account, RustAuthError> {
131        let now = OffsetDateTime::now_utc();
132        let id = input
133            .id
134            .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
135        let account_id = input.user_id.clone();
136
137        let record = self
138            .adapter
139            .create(
140                Create::new(ACCOUNT_MODEL)
141                    .data("id", DbValue::String(id))
142                    .data(
143                        "provider_id",
144                        DbValue::String(CREDENTIAL_PROVIDER_ID.to_owned()),
145                    )
146                    .data("account_id", DbValue::String(account_id))
147                    .data("user_id", DbValue::String(input.user_id))
148                    .data("access_token", DbValue::Null)
149                    .data("refresh_token", DbValue::Null)
150                    .data("id_token", DbValue::Null)
151                    .data("access_token_expires_at", DbValue::Null)
152                    .data("refresh_token_expires_at", DbValue::Null)
153                    .data("scope", DbValue::Null)
154                    .data("password", DbValue::String(input.password_hash))
155                    .data("created_at", DbValue::Timestamp(now))
156                    .data("updated_at", DbValue::Timestamp(now))
157                    .select(ACCOUNT_FIELDS)
158                    .force_allow_id(),
159            )
160            .await?;
161
162        self.parse_account(record)
163    }
164
165    pub async fn link_account(
166        &self,
167        input: CreateOAuthAccountInput,
168    ) -> Result<Account, RustAuthError> {
169        let now = OffsetDateTime::now_utc();
170        let id = input
171            .id
172            .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
173
174        let record = self
175            .adapter
176            .create(
177                Create::new(ACCOUNT_MODEL)
178                    .data("id", DbValue::String(id))
179                    .data("provider_id", DbValue::String(input.provider_id))
180                    .data("account_id", DbValue::String(input.account_id))
181                    .data("user_id", DbValue::String(input.user_id))
182                    .data("access_token", optional_string(input.access_token))
183                    .data("refresh_token", optional_string(input.refresh_token))
184                    .data("id_token", optional_string(input.id_token))
185                    .data(
186                        "access_token_expires_at",
187                        optional_timestamp(input.access_token_expires_at),
188                    )
189                    .data(
190                        "refresh_token_expires_at",
191                        optional_timestamp(input.refresh_token_expires_at),
192                    )
193                    .data("scope", optional_string(input.scope))
194                    .data("password", DbValue::Null)
195                    .data("created_at", DbValue::Timestamp(now))
196                    .data("updated_at", DbValue::Timestamp(now))
197                    .select(ACCOUNT_FIELDS)
198                    .force_allow_id(),
199            )
200            .await?;
201
202        self.parse_account(record)
203    }
204
205    pub async fn create_oauth_user(
206        &self,
207        user: CreateUserInput,
208        mut account: CreateOAuthAccountInput,
209    ) -> Result<CreateOAuthUserResult, RustAuthError> {
210        let result = Arc::new(Mutex::new(None));
211        let result_for_transaction = Arc::clone(&result);
212        let schema = self.schema.clone();
213        let transaction_status = self
214            .adapter
215            .transaction(Box::new(move |transaction| {
216                let schema = schema.clone();
217                Box::pin(async move {
218                    let users = DbUserStore::with_schema(transaction.as_ref(), schema);
219                    let user = users.create_user(user).await?;
220                    account.user_id = user.id.clone();
221                    let account = users.link_account(account).await?;
222                    store_create_oauth_user_result(
223                        &result_for_transaction,
224                        CreateOAuthUserResult { user, account },
225                    )?;
226                    Ok(())
227                })
228            }))
229            .await;
230
231        match transaction_status {
232            Ok(()) => take_create_oauth_user_result(&result)?.ok_or_else(|| {
233                RustAuthError::Adapter(
234                    "create OAuth user transaction completed without a result".to_owned(),
235                )
236            }),
237            Err(error) => Err(error),
238        }
239    }
240
241    pub async fn find_user_by_email(&self, email: &str) -> Result<Option<User>, RustAuthError> {
242        let record = self
243            .adapter
244            .find_one(
245                FindOne::new(USER_MODEL)
246                    .where_clause(Where::new("email", DbValue::String(normalize_email(email))))
247                    .select(USER_FIELDS),
248            )
249            .await?;
250
251        record.map(|record| self.parse_user(record)).transpose()
252    }
253
254    pub async fn find_user_by_id(&self, user_id: &str) -> Result<Option<User>, RustAuthError> {
255        let record = self
256            .adapter
257            .find_one(
258                FindOne::new(USER_MODEL)
259                    .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
260                    .select(USER_FIELDS),
261            )
262            .await?;
263
264        record.map(|record| self.parse_user(record)).transpose()
265    }
266
267    pub async fn find_user_by_username(
268        &self,
269        username: &str,
270    ) -> Result<Option<User>, RustAuthError> {
271        let record = self
272            .adapter
273            .find_one(
274                FindOne::new(USER_MODEL)
275                    .where_clause(Where::new("username", DbValue::String(username.to_owned())))
276                    .select(USER_FIELDS_WITH_USERNAME),
277            )
278            .await?;
279
280        record.map(|record| self.parse_user(record)).transpose()
281    }
282
283    pub async fn list_users(
284        &self,
285        limit: Option<usize>,
286        offset: Option<usize>,
287        sort_field: Option<&str>,
288        sort_direction: SortDirection,
289    ) -> Result<Vec<User>, RustAuthError> {
290        let mut query = FindMany::new(USER_MODEL).select(USER_FIELDS);
291        if let Some(limit) = limit {
292            query = query.limit(limit);
293        }
294        if let Some(offset) = offset {
295            query = query.offset(offset);
296        }
297        if let Some(field) = sort_field {
298            query = query.sort_by(Sort::new(field, sort_direction));
299        }
300        self.adapter
301            .find_many(query)
302            .await?
303            .into_iter()
304            .map(|record| self.parse_user(record))
305            .collect()
306    }
307
308    pub async fn count_total_users(&self) -> Result<u64, RustAuthError> {
309        self.adapter.count(Count::new(USER_MODEL)).await
310    }
311
312    pub async fn find_user_by_username_with_accounts(
313        &self,
314        username: &str,
315    ) -> Result<Option<UserWithAccounts>, RustAuthError> {
316        let Some(user) = self.find_user_by_username(username).await? else {
317            return Ok(None);
318        };
319        let accounts = self.list_accounts_for_user(&user.id).await?;
320        Ok(Some(UserWithAccounts { user, accounts }))
321    }
322
323    pub async fn find_user_by_email_with_accounts(
324        &self,
325        email: &str,
326    ) -> Result<Option<UserWithAccounts>, RustAuthError> {
327        let Some(mut record) = self
328            .adapter
329            .find_one(
330                FindOne::new(USER_MODEL)
331                    .where_clause(Where::new("email", DbValue::String(normalize_email(email))))
332                    .select(USER_FIELDS)
333                    .join(ACCOUNT_MODEL, JoinOption::enabled()),
334            )
335            .await?
336        else {
337            return Ok(None);
338        };
339
340        let joined_accounts = record.shift_remove(ACCOUNT_MODEL);
341        let user = self.parse_user(record)?;
342        let accounts = match joined_accounts {
343            Some(DbValue::RecordArray(accounts)) => accounts
344                .into_iter()
345                .map(|record| self.parse_account(record))
346                .collect::<Result<Vec<_>, _>>()?,
347            Some(DbValue::Null) => Vec::new(),
348            None => self.list_accounts_for_user(&user.id).await?,
349            Some(_) => {
350                return Err(RustAuthError::Adapter(
351                    "joined account result must be an array".to_owned(),
352                ));
353            }
354        };
355        Ok(Some(UserWithAccounts { user, accounts }))
356    }
357
358    pub async fn find_oauth_user(
359        &self,
360        email: &str,
361        account_id: &str,
362        provider_id: &str,
363    ) -> Result<Option<OAuthUserLookup>, RustAuthError> {
364        let linked_account = self
365            .find_account_by_provider_account(account_id, provider_id)
366            .await?;
367        let user = if let Some(account) = &linked_account {
368            self.find_user_by_id(&account.user_id).await?
369        } else {
370            self.find_user_by_email(email).await?
371        };
372        let Some(user) = user else {
373            return Ok(None);
374        };
375        let accounts = self.list_accounts_for_user(&user.id).await?;
376        Ok(Some(OAuthUserLookup {
377            user,
378            accounts,
379            linked_account,
380        }))
381    }
382
383    pub async fn list_accounts_for_user(
384        &self,
385        user_id: &str,
386    ) -> Result<Vec<Account>, RustAuthError> {
387        self.adapter
388            .find_many(
389                FindMany::new(ACCOUNT_MODEL)
390                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
391                    .select(ACCOUNT_FIELDS),
392            )
393            .await?
394            .into_iter()
395            .map(|record| self.parse_account(record))
396            .collect()
397    }
398
399    pub async fn find_credential_account(
400        &self,
401        user_id: &str,
402    ) -> Result<Option<Account>, RustAuthError> {
403        let record = self
404            .adapter
405            .find_one(
406                FindOne::new(ACCOUNT_MODEL)
407                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
408                    .where_clause(Where::new(
409                        "provider_id",
410                        DbValue::String(CREDENTIAL_PROVIDER_ID.to_owned()),
411                    ))
412                    .select(ACCOUNT_FIELDS),
413            )
414            .await?;
415
416        record.map(|record| self.parse_account(record)).transpose()
417    }
418
419    pub async fn find_account_by_provider_account(
420        &self,
421        account_id: &str,
422        provider_id: &str,
423    ) -> Result<Option<Account>, RustAuthError> {
424        let record = self
425            .adapter
426            .find_one(
427                FindOne::new(ACCOUNT_MODEL)
428                    .where_clause(Where::new(
429                        "account_id",
430                        DbValue::String(account_id.to_owned()),
431                    ))
432                    .where_clause(Where::new(
433                        "provider_id",
434                        DbValue::String(provider_id.to_owned()),
435                    ))
436                    .select(ACCOUNT_FIELDS),
437            )
438            .await?;
439
440        record.map(|record| self.parse_account(record)).transpose()
441    }
442
443    pub async fn update_account(
444        &self,
445        account_id: &str,
446        input: UpdateAccountInput,
447    ) -> Result<Option<Account>, RustAuthError> {
448        let mut query = Update::new(ACCOUNT_MODEL)
449            .where_clause(Where::new("id", DbValue::String(account_id.to_owned())))
450            .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
451        if let Some(value) = input.access_token {
452            query = query.data("access_token", optional_string(value));
453        }
454        if let Some(value) = input.refresh_token {
455            query = query.data("refresh_token", optional_string(value));
456        }
457        if let Some(value) = input.id_token {
458            query = query.data("id_token", optional_string(value));
459        }
460        if let Some(value) = input.access_token_expires_at {
461            query = query.data("access_token_expires_at", optional_timestamp(value));
462        }
463        if let Some(value) = input.refresh_token_expires_at {
464            query = query.data("refresh_token_expires_at", optional_timestamp(value));
465        }
466        if let Some(value) = input.scope {
467            query = query.data("scope", optional_string(value));
468        }
469
470        self.adapter
471            .update(query)
472            .await?
473            .map(|record| self.parse_account(record))
474            .transpose()
475    }
476
477    pub async fn update_user(
478        &self,
479        user_id: &str,
480        input: UpdateUserInput,
481    ) -> Result<Option<User>, RustAuthError> {
482        if input.is_empty() {
483            return self.find_user_by_id(user_id).await;
484        }
485        let mut query = Update::new(USER_MODEL)
486            .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
487            .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
488        if let Some(name) = input.name {
489            query = query.data("name", DbValue::String(name));
490        }
491        if let Some(image) = input.image {
492            query = query.data("image", optional_string(image));
493        }
494        if let Some(username) = input.username {
495            query = query.data("username", optional_string(username));
496        }
497        if let Some(display_username) = input.display_username {
498            query = query.data("display_username", optional_string(display_username));
499        }
500        for (field, value) in input.fields {
501            query = query.data(field, value);
502        }
503        for (field, value) in input.additional_fields {
504            query = query.data(field, value);
505        }
506
507        self.adapter
508            .update(query)
509            .await?
510            .map(|record| self.parse_user(record))
511            .transpose()
512    }
513
514    pub async fn update_credential_password(
515        &self,
516        user_id: &str,
517        password_hash: &str,
518    ) -> Result<Option<Account>, RustAuthError> {
519        self.adapter
520            .update(
521                Update::new(ACCOUNT_MODEL)
522                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
523                    .where_clause(Where::new(
524                        "provider_id",
525                        DbValue::String(CREDENTIAL_PROVIDER_ID.to_owned()),
526                    ))
527                    .data("password", DbValue::String(password_hash.to_owned()))
528                    .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
529            )
530            .await?
531            .map(|record| self.parse_account(record))
532            .transpose()
533    }
534
535    pub async fn update_user_email_verified(
536        &self,
537        user_id: &str,
538        email_verified: bool,
539    ) -> Result<Option<User>, RustAuthError> {
540        self.adapter
541            .update(
542                Update::new(USER_MODEL)
543                    .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
544                    .data("email_verified", DbValue::Boolean(email_verified))
545                    .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
546            )
547            .await?
548            .map(|record| self.parse_user(record))
549            .transpose()
550    }
551
552    pub async fn update_user_email(
553        &self,
554        user_id: &str,
555        email: &str,
556        email_verified: bool,
557    ) -> Result<Option<User>, RustAuthError> {
558        self.adapter
559            .update(
560                Update::new(USER_MODEL)
561                    .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
562                    .data("email", DbValue::String(normalize_email(email)))
563                    .data("email_verified", DbValue::Boolean(email_verified))
564                    .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
565            )
566            .await?
567            .map(|record| self.parse_user(record))
568            .transpose()
569    }
570
571    pub async fn delete_account(&self, account_id: &str) -> Result<(), RustAuthError> {
572        self.adapter
573            .delete(
574                Delete::new(ACCOUNT_MODEL)
575                    .where_clause(Where::new("id", DbValue::String(account_id.to_owned()))),
576            )
577            .await
578    }
579
580    pub async fn delete_user_accounts(&self, user_id: &str) -> Result<u64, RustAuthError> {
581        self.adapter
582            .delete_many(
583                DeleteMany::new(ACCOUNT_MODEL)
584                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned()))),
585            )
586            .await
587    }
588
589    pub async fn delete_user(&self, user_id: &str) -> Result<(), RustAuthError> {
590        self.adapter
591            .delete(
592                Delete::new(USER_MODEL)
593                    .where_clause(Where::new("id", DbValue::String(user_id.to_owned()))),
594            )
595            .await
596    }
597}
598
/// Canonical form used for storing and comparing emails: the full Unicode lowercase
/// of the input (no trimming or other rewriting).
fn normalize_email(email: &str) -> String {
    str::to_lowercase(email)
}
602
603fn optional_string(value: Option<String>) -> DbValue {
604    value.map(DbValue::String).unwrap_or(DbValue::Null)
605}
606
607fn optional_timestamp(value: Option<OffsetDateTime>) -> DbValue {
608    value.map(DbValue::Timestamp).unwrap_or(DbValue::Null)
609}
610
611fn store_create_oauth_user_result(
612    result: &Mutex<Option<CreateOAuthUserResult>>,
613    value: CreateOAuthUserResult,
614) -> Result<(), RustAuthError> {
615    let mut guard = result.lock().map_err(|_| RustAuthError::LockPoisoned {
616        context: "create OAuth user result",
617    })?;
618    *guard = Some(value);
619    Ok(())
620}
621
622fn take_create_oauth_user_result(
623    result: &Mutex<Option<CreateOAuthUserResult>>,
624) -> Result<Option<CreateOAuthUserResult>, RustAuthError> {
625    result
626        .lock()
627        .map_err(|_| RustAuthError::LockPoisoned {
628            context: "create OAuth user result",
629        })
630        .map(|mut guard| guard.take())
631}