1mod input;
4mod record;
5
6use std::sync::{Arc, LazyLock, Mutex};
7
8use time::OffsetDateTime;
9
10use crate::context::AuthContext;
11use crate::crypto::random::generate_random_string;
12use crate::db::{
13 auth_schema, Account, AuthSchemaOptions, Count, Create, DbAdapter, DbRecord, DbSchema, DbValue,
14 Delete, DeleteMany, FindMany, FindOne, JoinOption, SchemaTable, Sort, SortDirection, Update,
15 User, Where,
16};
17use crate::error::RustAuthError;
18pub use input::{
19 CreateCredentialAccountInput, CreateOAuthAccountInput, CreateUserInput, UpdateAccountInput,
20 UpdateUserInput,
21};
22use record::{
23 account_from_record, user_from_record, ACCOUNT_FIELDS, USER_FIELDS, USER_FIELDS_WITH_USERNAME,
24};
25
26pub(super) const USER_MODEL: &str = "user";
27pub(super) const ACCOUNT_MODEL: &str = "account";
28const CREDENTIAL_PROVIDER_ID: &str = "credential";
29const DEFAULT_ID_LENGTH: usize = 32;
30
31fn default_auth_schema() -> &'static DbSchema {
32 static SCHEMA: LazyLock<DbSchema> = LazyLock::new(|| auth_schema(AuthSchemaOptions::default()));
33 &SCHEMA
34}
35
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct UserWithAccounts {
38 pub user: User,
39 pub accounts: Vec<Account>,
40}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct OAuthUserLookup {
44 pub user: User,
45 pub accounts: Vec<Account>,
46 pub linked_account: Option<Account>,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct CreateOAuthUserResult {
51 pub user: User,
52 pub account: Account,
53}
54
55#[derive(Clone)]
56pub struct DbUserStore<'a> {
57 adapter: &'a dyn DbAdapter,
58 schema: DbSchema,
59}
60
61impl<'a> DbUserStore<'a> {
62 pub fn new(adapter: &'a dyn DbAdapter) -> Self {
63 Self::with_schema(adapter, default_auth_schema().clone())
64 }
65
66 pub fn with_schema(adapter: &'a dyn DbAdapter, schema: DbSchema) -> Self {
67 Self { adapter, schema }
68 }
69
70 pub fn from_context(context: &'a AuthContext) -> Result<Self, RustAuthError> {
71 Ok(Self::with_schema(
72 context.adapter_ref()?,
73 context.db_schema.clone(),
74 ))
75 }
76
77 fn users(&self) -> Result<SchemaTable<'_>, RustAuthError> {
78 SchemaTable::new(&self.schema, USER_MODEL)
79 }
80
81 fn accounts(&self) -> Result<SchemaTable<'_>, RustAuthError> {
82 SchemaTable::new(&self.schema, ACCOUNT_MODEL)
83 }
84
85 fn parse_user(&self, record: DbRecord) -> Result<User, RustAuthError> {
86 user_from_record(self.users()?.map_record(record)?)
87 }
88
89 fn parse_account(&self, record: DbRecord) -> Result<Account, RustAuthError> {
90 account_from_record(self.accounts()?.map_record(record)?)
91 }
92
93 pub async fn create_user(&self, input: CreateUserInput) -> Result<User, RustAuthError> {
94 let now = OffsetDateTime::now_utc();
95 let id = input
96 .id
97 .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
98
99 let include_username_fields = input.username.is_some() || input.display_username.is_some();
100 let mut query = Create::new(USER_MODEL)
101 .data("id", DbValue::String(id))
102 .data("name", DbValue::String(input.name))
103 .data("email", DbValue::String(normalize_email(&input.email)))
104 .data("email_verified", DbValue::Boolean(input.email_verified))
105 .data("image", optional_string(input.image))
106 .data("created_at", DbValue::Timestamp(now))
107 .data("updated_at", DbValue::Timestamp(now))
108 .force_allow_id();
109 if include_username_fields {
110 query = query
111 .data("username", optional_string(input.username))
112 .data("display_username", optional_string(input.display_username))
113 .select(USER_FIELDS_WITH_USERNAME);
114 } else {
115 query = query.select(USER_FIELDS);
116 }
117
118 for (field, value) in input.additional_fields {
119 query = query.data(field, value);
120 }
121
122 let record = self.adapter.create(query).await?;
123
124 self.parse_user(record)
125 }
126
127 pub async fn create_credential_account(
128 &self,
129 input: CreateCredentialAccountInput,
130 ) -> Result<Account, RustAuthError> {
131 let now = OffsetDateTime::now_utc();
132 let id = input
133 .id
134 .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
135 let account_id = input.user_id.clone();
136
137 let record = self
138 .adapter
139 .create(
140 Create::new(ACCOUNT_MODEL)
141 .data("id", DbValue::String(id))
142 .data(
143 "provider_id",
144 DbValue::String(CREDENTIAL_PROVIDER_ID.to_owned()),
145 )
146 .data("account_id", DbValue::String(account_id))
147 .data("user_id", DbValue::String(input.user_id))
148 .data("access_token", DbValue::Null)
149 .data("refresh_token", DbValue::Null)
150 .data("id_token", DbValue::Null)
151 .data("access_token_expires_at", DbValue::Null)
152 .data("refresh_token_expires_at", DbValue::Null)
153 .data("scope", DbValue::Null)
154 .data("password", DbValue::String(input.password_hash))
155 .data("created_at", DbValue::Timestamp(now))
156 .data("updated_at", DbValue::Timestamp(now))
157 .select(ACCOUNT_FIELDS)
158 .force_allow_id(),
159 )
160 .await?;
161
162 self.parse_account(record)
163 }
164
165 pub async fn link_account(
166 &self,
167 input: CreateOAuthAccountInput,
168 ) -> Result<Account, RustAuthError> {
169 let now = OffsetDateTime::now_utc();
170 let id = input
171 .id
172 .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
173
174 let record = self
175 .adapter
176 .create(
177 Create::new(ACCOUNT_MODEL)
178 .data("id", DbValue::String(id))
179 .data("provider_id", DbValue::String(input.provider_id))
180 .data("account_id", DbValue::String(input.account_id))
181 .data("user_id", DbValue::String(input.user_id))
182 .data("access_token", optional_string(input.access_token))
183 .data("refresh_token", optional_string(input.refresh_token))
184 .data("id_token", optional_string(input.id_token))
185 .data(
186 "access_token_expires_at",
187 optional_timestamp(input.access_token_expires_at),
188 )
189 .data(
190 "refresh_token_expires_at",
191 optional_timestamp(input.refresh_token_expires_at),
192 )
193 .data("scope", optional_string(input.scope))
194 .data("password", DbValue::Null)
195 .data("created_at", DbValue::Timestamp(now))
196 .data("updated_at", DbValue::Timestamp(now))
197 .select(ACCOUNT_FIELDS)
198 .force_allow_id(),
199 )
200 .await?;
201
202 self.parse_account(record)
203 }
204
205 pub async fn create_oauth_user(
206 &self,
207 user: CreateUserInput,
208 mut account: CreateOAuthAccountInput,
209 ) -> Result<CreateOAuthUserResult, RustAuthError> {
210 let result = Arc::new(Mutex::new(None));
211 let result_for_transaction = Arc::clone(&result);
212 let schema = self.schema.clone();
213 let transaction_status = self
214 .adapter
215 .transaction(Box::new(move |transaction| {
216 let schema = schema.clone();
217 Box::pin(async move {
218 let users = DbUserStore::with_schema(transaction.as_ref(), schema);
219 let user = users.create_user(user).await?;
220 account.user_id = user.id.clone();
221 let account = users.link_account(account).await?;
222 store_create_oauth_user_result(
223 &result_for_transaction,
224 CreateOAuthUserResult { user, account },
225 )?;
226 Ok(())
227 })
228 }))
229 .await;
230
231 match transaction_status {
232 Ok(()) => take_create_oauth_user_result(&result)?.ok_or_else(|| {
233 RustAuthError::Adapter(
234 "create OAuth user transaction completed without a result".to_owned(),
235 )
236 }),
237 Err(error) => Err(error),
238 }
239 }
240
241 pub async fn find_user_by_email(&self, email: &str) -> Result<Option<User>, RustAuthError> {
242 let record = self
243 .adapter
244 .find_one(
245 FindOne::new(USER_MODEL)
246 .where_clause(Where::new("email", DbValue::String(normalize_email(email))))
247 .select(USER_FIELDS),
248 )
249 .await?;
250
251 record.map(|record| self.parse_user(record)).transpose()
252 }
253
254 pub async fn find_user_by_id(&self, user_id: &str) -> Result<Option<User>, RustAuthError> {
255 let record = self
256 .adapter
257 .find_one(
258 FindOne::new(USER_MODEL)
259 .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
260 .select(USER_FIELDS),
261 )
262 .await?;
263
264 record.map(|record| self.parse_user(record)).transpose()
265 }
266
267 pub async fn find_user_by_username(
268 &self,
269 username: &str,
270 ) -> Result<Option<User>, RustAuthError> {
271 let record = self
272 .adapter
273 .find_one(
274 FindOne::new(USER_MODEL)
275 .where_clause(Where::new("username", DbValue::String(username.to_owned())))
276 .select(USER_FIELDS_WITH_USERNAME),
277 )
278 .await?;
279
280 record.map(|record| self.parse_user(record)).transpose()
281 }
282
283 pub async fn list_users(
284 &self,
285 limit: Option<usize>,
286 offset: Option<usize>,
287 sort_field: Option<&str>,
288 sort_direction: SortDirection,
289 ) -> Result<Vec<User>, RustAuthError> {
290 let mut query = FindMany::new(USER_MODEL).select(USER_FIELDS);
291 if let Some(limit) = limit {
292 query = query.limit(limit);
293 }
294 if let Some(offset) = offset {
295 query = query.offset(offset);
296 }
297 if let Some(field) = sort_field {
298 query = query.sort_by(Sort::new(field, sort_direction));
299 }
300 self.adapter
301 .find_many(query)
302 .await?
303 .into_iter()
304 .map(|record| self.parse_user(record))
305 .collect()
306 }
307
308 pub async fn count_total_users(&self) -> Result<u64, RustAuthError> {
309 self.adapter.count(Count::new(USER_MODEL)).await
310 }
311
312 pub async fn find_user_by_username_with_accounts(
313 &self,
314 username: &str,
315 ) -> Result<Option<UserWithAccounts>, RustAuthError> {
316 let Some(user) = self.find_user_by_username(username).await? else {
317 return Ok(None);
318 };
319 let accounts = self.list_accounts_for_user(&user.id).await?;
320 Ok(Some(UserWithAccounts { user, accounts }))
321 }
322
323 pub async fn find_user_by_email_with_accounts(
324 &self,
325 email: &str,
326 ) -> Result<Option<UserWithAccounts>, RustAuthError> {
327 let Some(mut record) = self
328 .adapter
329 .find_one(
330 FindOne::new(USER_MODEL)
331 .where_clause(Where::new("email", DbValue::String(normalize_email(email))))
332 .select(USER_FIELDS)
333 .join(ACCOUNT_MODEL, JoinOption::enabled()),
334 )
335 .await?
336 else {
337 return Ok(None);
338 };
339
340 let joined_accounts = record.shift_remove(ACCOUNT_MODEL);
341 let user = self.parse_user(record)?;
342 let accounts = match joined_accounts {
343 Some(DbValue::RecordArray(accounts)) => accounts
344 .into_iter()
345 .map(|record| self.parse_account(record))
346 .collect::<Result<Vec<_>, _>>()?,
347 Some(DbValue::Null) => Vec::new(),
348 None => self.list_accounts_for_user(&user.id).await?,
349 Some(_) => {
350 return Err(RustAuthError::Adapter(
351 "joined account result must be an array".to_owned(),
352 ));
353 }
354 };
355 Ok(Some(UserWithAccounts { user, accounts }))
356 }
357
358 pub async fn find_oauth_user(
359 &self,
360 email: &str,
361 account_id: &str,
362 provider_id: &str,
363 ) -> Result<Option<OAuthUserLookup>, RustAuthError> {
364 let linked_account = self
365 .find_account_by_provider_account(account_id, provider_id)
366 .await?;
367 let user = if let Some(account) = &linked_account {
368 self.find_user_by_id(&account.user_id).await?
369 } else {
370 self.find_user_by_email(email).await?
371 };
372 let Some(user) = user else {
373 return Ok(None);
374 };
375 let accounts = self.list_accounts_for_user(&user.id).await?;
376 Ok(Some(OAuthUserLookup {
377 user,
378 accounts,
379 linked_account,
380 }))
381 }
382
383 pub async fn list_accounts_for_user(
384 &self,
385 user_id: &str,
386 ) -> Result<Vec<Account>, RustAuthError> {
387 self.adapter
388 .find_many(
389 FindMany::new(ACCOUNT_MODEL)
390 .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
391 .select(ACCOUNT_FIELDS),
392 )
393 .await?
394 .into_iter()
395 .map(|record| self.parse_account(record))
396 .collect()
397 }
398
399 pub async fn find_credential_account(
400 &self,
401 user_id: &str,
402 ) -> Result<Option<Account>, RustAuthError> {
403 let record = self
404 .adapter
405 .find_one(
406 FindOne::new(ACCOUNT_MODEL)
407 .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
408 .where_clause(Where::new(
409 "provider_id",
410 DbValue::String(CREDENTIAL_PROVIDER_ID.to_owned()),
411 ))
412 .select(ACCOUNT_FIELDS),
413 )
414 .await?;
415
416 record.map(|record| self.parse_account(record)).transpose()
417 }
418
419 pub async fn find_account_by_provider_account(
420 &self,
421 account_id: &str,
422 provider_id: &str,
423 ) -> Result<Option<Account>, RustAuthError> {
424 let record = self
425 .adapter
426 .find_one(
427 FindOne::new(ACCOUNT_MODEL)
428 .where_clause(Where::new(
429 "account_id",
430 DbValue::String(account_id.to_owned()),
431 ))
432 .where_clause(Where::new(
433 "provider_id",
434 DbValue::String(provider_id.to_owned()),
435 ))
436 .select(ACCOUNT_FIELDS),
437 )
438 .await?;
439
440 record.map(|record| self.parse_account(record)).transpose()
441 }
442
443 pub async fn update_account(
444 &self,
445 account_id: &str,
446 input: UpdateAccountInput,
447 ) -> Result<Option<Account>, RustAuthError> {
448 let mut query = Update::new(ACCOUNT_MODEL)
449 .where_clause(Where::new("id", DbValue::String(account_id.to_owned())))
450 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
451 if let Some(value) = input.access_token {
452 query = query.data("access_token", optional_string(value));
453 }
454 if let Some(value) = input.refresh_token {
455 query = query.data("refresh_token", optional_string(value));
456 }
457 if let Some(value) = input.id_token {
458 query = query.data("id_token", optional_string(value));
459 }
460 if let Some(value) = input.access_token_expires_at {
461 query = query.data("access_token_expires_at", optional_timestamp(value));
462 }
463 if let Some(value) = input.refresh_token_expires_at {
464 query = query.data("refresh_token_expires_at", optional_timestamp(value));
465 }
466 if let Some(value) = input.scope {
467 query = query.data("scope", optional_string(value));
468 }
469
470 self.adapter
471 .update(query)
472 .await?
473 .map(|record| self.parse_account(record))
474 .transpose()
475 }
476
477 pub async fn update_user(
478 &self,
479 user_id: &str,
480 input: UpdateUserInput,
481 ) -> Result<Option<User>, RustAuthError> {
482 if input.is_empty() {
483 return self.find_user_by_id(user_id).await;
484 }
485 let mut query = Update::new(USER_MODEL)
486 .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
487 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
488 if let Some(name) = input.name {
489 query = query.data("name", DbValue::String(name));
490 }
491 if let Some(image) = input.image {
492 query = query.data("image", optional_string(image));
493 }
494 if let Some(username) = input.username {
495 query = query.data("username", optional_string(username));
496 }
497 if let Some(display_username) = input.display_username {
498 query = query.data("display_username", optional_string(display_username));
499 }
500 for (field, value) in input.fields {
501 query = query.data(field, value);
502 }
503 for (field, value) in input.additional_fields {
504 query = query.data(field, value);
505 }
506
507 self.adapter
508 .update(query)
509 .await?
510 .map(|record| self.parse_user(record))
511 .transpose()
512 }
513
514 pub async fn update_credential_password(
515 &self,
516 user_id: &str,
517 password_hash: &str,
518 ) -> Result<Option<Account>, RustAuthError> {
519 self.adapter
520 .update(
521 Update::new(ACCOUNT_MODEL)
522 .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
523 .where_clause(Where::new(
524 "provider_id",
525 DbValue::String(CREDENTIAL_PROVIDER_ID.to_owned()),
526 ))
527 .data("password", DbValue::String(password_hash.to_owned()))
528 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
529 )
530 .await?
531 .map(|record| self.parse_account(record))
532 .transpose()
533 }
534
535 pub async fn update_user_email_verified(
536 &self,
537 user_id: &str,
538 email_verified: bool,
539 ) -> Result<Option<User>, RustAuthError> {
540 self.adapter
541 .update(
542 Update::new(USER_MODEL)
543 .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
544 .data("email_verified", DbValue::Boolean(email_verified))
545 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
546 )
547 .await?
548 .map(|record| self.parse_user(record))
549 .transpose()
550 }
551
552 pub async fn update_user_email(
553 &self,
554 user_id: &str,
555 email: &str,
556 email_verified: bool,
557 ) -> Result<Option<User>, RustAuthError> {
558 self.adapter
559 .update(
560 Update::new(USER_MODEL)
561 .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
562 .data("email", DbValue::String(normalize_email(email)))
563 .data("email_verified", DbValue::Boolean(email_verified))
564 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
565 )
566 .await?
567 .map(|record| self.parse_user(record))
568 .transpose()
569 }
570
571 pub async fn delete_account(&self, account_id: &str) -> Result<(), RustAuthError> {
572 self.adapter
573 .delete(
574 Delete::new(ACCOUNT_MODEL)
575 .where_clause(Where::new("id", DbValue::String(account_id.to_owned()))),
576 )
577 .await
578 }
579
580 pub async fn delete_user_accounts(&self, user_id: &str) -> Result<u64, RustAuthError> {
581 self.adapter
582 .delete_many(
583 DeleteMany::new(ACCOUNT_MODEL)
584 .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned()))),
585 )
586 .await
587 }
588
589 pub async fn delete_user(&self, user_id: &str) -> Result<(), RustAuthError> {
590 self.adapter
591 .delete(
592 Delete::new(USER_MODEL)
593 .where_clause(Where::new("id", DbValue::String(user_id.to_owned()))),
594 )
595 .await
596 }
597}
598
/// Canonicalizes an email address for storage and lookup by lowercasing it.
///
/// Lowercasing is Unicode-aware (`str::to_lowercase`). Surrounding whitespace is
/// left untouched.
fn normalize_email(email: &str) -> String {
    str::to_lowercase(email)
}
602
603fn optional_string(value: Option<String>) -> DbValue {
604 value.map(DbValue::String).unwrap_or(DbValue::Null)
605}
606
607fn optional_timestamp(value: Option<OffsetDateTime>) -> DbValue {
608 value.map(DbValue::Timestamp).unwrap_or(DbValue::Null)
609}
610
611fn store_create_oauth_user_result(
612 result: &Mutex<Option<CreateOAuthUserResult>>,
613 value: CreateOAuthUserResult,
614) -> Result<(), RustAuthError> {
615 let mut guard = result.lock().map_err(|_| RustAuthError::LockPoisoned {
616 context: "create OAuth user result",
617 })?;
618 *guard = Some(value);
619 Ok(())
620}
621
622fn take_create_oauth_user_result(
623 result: &Mutex<Option<CreateOAuthUserResult>>,
624) -> Result<Option<CreateOAuthUserResult>, RustAuthError> {
625 result
626 .lock()
627 .map_err(|_| RustAuthError::LockPoisoned {
628 context: "create OAuth user result",
629 })
630 .map(|mut guard| guard.take())
631}