1use std::error::Error;
4use std::fmt;
5use std::sync::{Arc, Mutex};
6
7use http::StatusCode;
8use time::{Duration, OffsetDateTime};
9
10use crate::api::ApiErrorResponse;
11use crate::crypto::password::{hash_password, verify_password};
12use crate::db::{DbAdapter, DbRecord, Session, User};
13use crate::error::RustAuthError;
14use crate::error_codes::ErrorCode;
15use crate::options::SecondaryStorage;
16use crate::session::{CreateSessionInput, SessionStore};
17use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
18
19pub type PasswordHashFn = fn(&str) -> Result<String, RustAuthError>;
20pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, RustAuthError>;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23pub enum AuthFlowErrorCode {
24 InvalidEmail,
25 InvalidPasswordLength,
26 InvalidEmailOrPassword,
27 UserAlreadyExists,
28 UserAlreadyExistsUseAnotherEmail,
29 EmailNotVerified,
30 FailedToCreateSession,
31 StorageError,
32}
33
34impl AuthFlowErrorCode {
35 pub fn as_str(self) -> &'static str {
36 match self {
37 Self::InvalidEmail => "INVALID_EMAIL",
38 Self::InvalidPasswordLength => "INVALID_PASSWORD_LENGTH",
39 Self::InvalidEmailOrPassword => "INVALID_EMAIL_OR_PASSWORD",
40 Self::UserAlreadyExists => crate::error_codes::USER_ALREADY_EXISTS,
41 Self::UserAlreadyExistsUseAnotherEmail => {
42 crate::error_codes::USER_ALREADY_EXISTS_USE_ANOTHER_EMAIL
43 }
44 Self::EmailNotVerified => "EMAIL_NOT_VERIFIED",
45 Self::FailedToCreateSession => "FAILED_TO_CREATE_SESSION",
46 Self::StorageError => "STORAGE_ERROR",
47 }
48 }
49
50 pub fn message(self) -> &'static str {
51 match self {
52 Self::InvalidEmail => "Invalid email",
53 Self::InvalidPasswordLength => "Invalid password length",
54 Self::InvalidEmailOrPassword => "Invalid email or password",
55 Self::UserAlreadyExists => "User already exists",
56 Self::UserAlreadyExistsUseAnotherEmail => "User already exists. Use another email.",
57 Self::EmailNotVerified => "Email not verified",
58 Self::FailedToCreateSession => "Failed to create session",
59 Self::StorageError => "Storage error",
60 }
61 }
62}
63
64impl ErrorCode for AuthFlowErrorCode {
65 fn as_str(&self) -> &str {
66 (*self).as_str()
67 }
68
69 fn message(&self) -> &str {
70 (*self).message()
71 }
72}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub struct AuthFlowError {
76 code: AuthFlowErrorCode,
77 message: String,
78}
79
80impl AuthFlowError {
81 pub fn new(code: AuthFlowErrorCode) -> Self {
82 Self {
83 code,
84 message: code.message().to_owned(),
85 }
86 }
87
88 pub fn storage(error: RustAuthError) -> Self {
89 Self {
90 code: AuthFlowErrorCode::StorageError,
91 message: error.to_string(),
92 }
93 }
94
95 pub fn code(&self) -> AuthFlowErrorCode {
96 self.code
97 }
98
99 pub fn code_str(&self) -> &'static str {
100 self.code.as_str()
101 }
102
103 pub fn message(&self) -> &str {
104 self.message.as_str()
105 }
106
107 pub fn http_status(&self) -> StatusCode {
108 match self.code {
109 AuthFlowErrorCode::InvalidEmailOrPassword => StatusCode::UNAUTHORIZED,
110 AuthFlowErrorCode::EmailNotVerified => StatusCode::FORBIDDEN,
111 AuthFlowErrorCode::StorageError | AuthFlowErrorCode::FailedToCreateSession => {
112 StatusCode::INTERNAL_SERVER_ERROR
113 }
114 AuthFlowErrorCode::InvalidEmail
115 | AuthFlowErrorCode::InvalidPasswordLength
116 | AuthFlowErrorCode::UserAlreadyExists
117 | AuthFlowErrorCode::UserAlreadyExistsUseAnotherEmail => StatusCode::BAD_REQUEST,
118 }
119 }
120
121 pub fn to_api_response(&self) -> ApiErrorResponse {
122 ApiErrorResponse {
123 code: self.code_str().to_owned(),
124 message: self.message().to_owned(),
125 original_message: None,
126 }
127 }
128}
129
130impl fmt::Display for AuthFlowError {
131 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
132 write!(formatter, "{}: {}", self.code.as_str(), self.message)
133 }
134}
135
136impl Error for AuthFlowError {}
137
138impl From<RustAuthError> for AuthFlowError {
139 fn from(error: RustAuthError) -> Self {
140 Self::storage(error)
141 }
142}
143
144#[derive(Clone)]
145pub struct EmailPasswordConfig {
146 pub session_expires_in: u64,
147 pub dont_remember_session_expires_in: u64,
148 pub min_password_length: usize,
149 pub max_password_length: usize,
150 pub require_email_verification: bool,
151 pub secondary_storage: Option<Arc<dyn SecondaryStorage>>,
152 pub store_session_in_database: bool,
153 pub preserve_session_in_database: bool,
154}
155
156impl Default for EmailPasswordConfig {
157 fn default() -> Self {
158 Self {
159 session_expires_in: 60 * 60 * 24 * 7,
160 dont_remember_session_expires_in: 60 * 60 * 24,
161 min_password_length: 8,
162 max_password_length: 128,
163 require_email_verification: false,
164 secondary_storage: None,
165 store_session_in_database: false,
166 preserve_session_in_database: false,
167 }
168 }
169}
170
171impl fmt::Debug for EmailPasswordConfig {
172 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
173 formatter
174 .debug_struct("EmailPasswordConfig")
175 .field("session_expires_in", &self.session_expires_in)
176 .field(
177 "dont_remember_session_expires_in",
178 &self.dont_remember_session_expires_in,
179 )
180 .field("min_password_length", &self.min_password_length)
181 .field("max_password_length", &self.max_password_length)
182 .field(
183 "require_email_verification",
184 &self.require_email_verification,
185 )
186 .field(
187 "secondary_storage",
188 &self
189 .secondary_storage
190 .as_ref()
191 .map(|_| "<secondary-storage>"),
192 )
193 .field("store_session_in_database", &self.store_session_in_database)
194 .field(
195 "preserve_session_in_database",
196 &self.preserve_session_in_database,
197 )
198 .finish()
199 }
200}
201
202#[derive(Debug, Clone, PartialEq)]
203pub struct SignUpInput {
204 pub name: String,
205 pub email: String,
206 pub password: String,
207 pub image: Option<String>,
208 pub username: Option<String>,
209 pub display_username: Option<String>,
210 pub remember_me: bool,
211 pub ip_address: Option<String>,
212 pub user_agent: Option<String>,
213 pub additional_user_fields: DbRecord,
214 pub additional_session_fields: DbRecord,
215}
216
217impl SignUpInput {
218 pub fn new(
219 name: impl Into<String>,
220 email: impl Into<String>,
221 password: impl Into<String>,
222 ) -> Self {
223 Self {
224 name: name.into(),
225 email: email.into(),
226 password: password.into(),
227 image: None,
228 username: None,
229 display_username: None,
230 remember_me: true,
231 ip_address: None,
232 user_agent: None,
233 additional_user_fields: DbRecord::new(),
234 additional_session_fields: DbRecord::new(),
235 }
236 }
237
238 #[must_use]
239 pub fn image(mut self, image: impl Into<String>) -> Self {
240 self.image = Some(image.into());
241 self
242 }
243
244 #[must_use]
245 pub fn username(mut self, username: impl Into<String>) -> Self {
246 self.username = Some(username.into());
247 self
248 }
249
250 #[must_use]
251 pub fn display_username(mut self, display_username: impl Into<String>) -> Self {
252 self.display_username = Some(display_username.into());
253 self
254 }
255
256 #[must_use]
257 pub fn remember_me(mut self, remember_me: bool) -> Self {
258 self.remember_me = remember_me;
259 self
260 }
261
262 #[must_use]
263 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
264 self.ip_address = Some(ip_address.into());
265 self
266 }
267
268 #[must_use]
269 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
270 self.user_agent = Some(user_agent.into());
271 self
272 }
273
274 #[must_use]
275 pub fn additional_user_fields(mut self, fields: DbRecord) -> Self {
276 self.additional_user_fields = fields;
277 self
278 }
279
280 #[must_use]
281 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
282 self.additional_session_fields = fields;
283 self
284 }
285}
286
287#[derive(Debug, Clone, PartialEq)]
288pub struct SignInInput {
289 pub email: String,
290 pub password: String,
291 pub remember_me: bool,
292 pub ip_address: Option<String>,
293 pub user_agent: Option<String>,
294 pub additional_session_fields: DbRecord,
295}
296
297impl SignInInput {
298 pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
299 Self {
300 email: email.into(),
301 password: password.into(),
302 remember_me: true,
303 ip_address: None,
304 user_agent: None,
305 additional_session_fields: DbRecord::new(),
306 }
307 }
308
309 #[must_use]
310 pub fn remember_me(mut self, remember_me: bool) -> Self {
311 self.remember_me = remember_me;
312 self
313 }
314
315 #[must_use]
316 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
317 self.ip_address = Some(ip_address.into());
318 self
319 }
320
321 #[must_use]
322 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
323 self.user_agent = Some(user_agent.into());
324 self
325 }
326
327 #[must_use]
328 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
329 self.additional_session_fields = fields;
330 self
331 }
332}
333
334#[derive(Debug, Clone, PartialEq, Eq)]
335pub struct EmailPasswordAuthResult {
336 pub user: User,
337 pub session: Session,
338}
339
340#[derive(Clone)]
341pub struct EmailPasswordAuth<'a> {
342 adapter: &'a dyn DbAdapter,
343 config: EmailPasswordConfig,
344 hash_password: PasswordHashFn,
345 verify_password: PasswordVerifyFn,
346}
347
348impl<'a> EmailPasswordAuth<'a> {
349 pub fn new(
350 adapter: &'a dyn DbAdapter,
351 config: EmailPasswordConfig,
352 hash_password: PasswordHashFn,
353 verify_password: PasswordVerifyFn,
354 ) -> Self {
355 Self {
356 adapter,
357 config,
358 hash_password,
359 verify_password,
360 }
361 }
362
363 pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
364 Self::new(adapter, config, hash_password, verify_password)
365 }
366
367 pub async fn sign_up(
368 &self,
369 input: SignUpInput,
370 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
371 self.validate_email_and_password(&input.email, &input.password)?;
372 let users = DbUserStore::new(self.adapter);
373 if users.find_user_by_email(&input.email).await?.is_some() {
374 return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
375 }
376
377 let password_hash = (self.hash_password)(&input.password)?;
378 let mut create_user = CreateUserInput::new(input.name, input.email)
379 .additional_fields(input.additional_user_fields);
380 if let Some(image) = input.image {
381 create_user = create_user.image(image);
382 }
383 if let Some(username) = input.username {
384 create_user = create_user.username(username);
385 }
386 if let Some(display_username) = input.display_username {
387 create_user = create_user.display_username(display_username);
388 }
389 let result = Arc::new(Mutex::new(None));
390 let result_for_transaction = Arc::clone(&result);
391 let config = self.config.clone();
392 let transaction_status = self
393 .adapter
394 .transaction(Box::new(move |transaction| {
395 Box::pin(async move {
396 let outcome = create_sign_up_records(SignUpRecordsInput {
397 adapter: transaction.as_ref(),
398 config: &config,
399 create_user,
400 password_hash,
401 remember_me: input.remember_me,
402 ip_address: input.ip_address,
403 user_agent: input.user_agent,
404 additional_session_fields: input.additional_session_fields,
405 })
406 .await;
407 match outcome {
408 Ok(result) => {
409 store_sign_up_result(&result_for_transaction, Ok(result))?;
410 Ok(())
411 }
412 Err(error) => {
413 let transaction_error = RustAuthError::Adapter(error.to_string());
414 store_sign_up_result(&result_for_transaction, Err(error))?;
415 Err(transaction_error)
416 }
417 }
418 })
419 }))
420 .await;
421
422 match transaction_status {
423 Ok(()) => match take_sign_up_result(&result)? {
424 Some(Ok(result)) => Ok(result),
425 Some(Err(error)) => Err(error),
426 None => Err(AuthFlowError::storage(RustAuthError::Adapter(
427 "sign-up transaction completed without a result".to_owned(),
428 ))),
429 },
430 Err(error) => match take_sign_up_result(&result)? {
431 Some(Err(auth_error)) => Err(auth_error),
432 _ => Err(AuthFlowError::storage(error)),
433 },
434 }
435 }
436
437 pub async fn sign_in(
438 &self,
439 input: SignInInput,
440 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
441 validate_email(&input.email)?;
442 let users = DbUserStore::new(self.adapter);
443 let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
444 else {
445 let _ = (self.hash_password)(&input.password);
446 return Err(AuthFlowError::new(
447 AuthFlowErrorCode::InvalidEmailOrPassword,
448 ));
449 };
450 let Some(account) = user_with_accounts
451 .accounts
452 .iter()
453 .find(|account| account.provider_id == "credential")
454 else {
455 let _ = (self.hash_password)(&input.password);
456 return Err(AuthFlowError::new(
457 AuthFlowErrorCode::InvalidEmailOrPassword,
458 ));
459 };
460 let Some(password_hash) = account.password.as_deref() else {
461 let _ = (self.hash_password)(&input.password);
462 return Err(AuthFlowError::new(
463 AuthFlowErrorCode::InvalidEmailOrPassword,
464 ));
465 };
466 if !(self.verify_password)(password_hash, &input.password)? {
467 return Err(AuthFlowError::new(
468 AuthFlowErrorCode::InvalidEmailOrPassword,
469 ));
470 }
471 if self.config.require_email_verification && !user_with_accounts.user.email_verified {
472 return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
473 }
474 let session = create_session_record(
475 self.adapter,
476 &self.config,
477 &user_with_accounts.user.id,
478 input.remember_me,
479 input.ip_address,
480 input.user_agent,
481 input.additional_session_fields,
482 )
483 .await?;
484
485 Ok(EmailPasswordAuthResult {
486 user: user_with_accounts.user,
487 session,
488 })
489 }
490
491 fn validate_email_and_password(
492 &self,
493 email: &str,
494 password: &str,
495 ) -> Result<(), AuthFlowError> {
496 validate_email(email)?;
497 if password.len() < self.config.min_password_length
498 || password.len() > self.config.max_password_length
499 {
500 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
501 }
502 Ok(())
503 }
504}
505
506struct SignUpRecordsInput<'a> {
507 adapter: &'a dyn DbAdapter,
508 config: &'a EmailPasswordConfig,
509 create_user: CreateUserInput,
510 password_hash: String,
511 remember_me: bool,
512 ip_address: Option<String>,
513 user_agent: Option<String>,
514 additional_session_fields: DbRecord,
515}
516
517async fn create_sign_up_records(
518 input: SignUpRecordsInput<'_>,
519) -> Result<EmailPasswordAuthResult, AuthFlowError> {
520 let users = DbUserStore::new(input.adapter);
521 let user = users.create_user(input.create_user).await?;
522 users
523 .create_credential_account(CreateCredentialAccountInput::new(
524 user.id.clone(),
525 input.password_hash,
526 ))
527 .await?;
528 let session = create_session_record(
529 input.adapter,
530 input.config,
531 &user.id,
532 input.remember_me,
533 input.ip_address,
534 input.user_agent,
535 input.additional_session_fields,
536 )
537 .await?;
538
539 Ok(EmailPasswordAuthResult { user, session })
540}
541
542async fn create_session_record(
543 adapter: &dyn DbAdapter,
544 config: &EmailPasswordConfig,
545 user_id: &str,
546 remember_me: bool,
547 ip_address: Option<String>,
548 user_agent: Option<String>,
549 additional_fields: DbRecord,
550) -> Result<Session, AuthFlowError> {
551 let expires_in = if remember_me {
552 config.session_expires_in
553 } else {
554 config.dont_remember_session_expires_in
555 };
556 let seconds = i64::try_from(expires_in)
557 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
558 let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
559 let mut input =
560 CreateSessionInput::new(user_id, expires_at).additional_fields(additional_fields);
561 if let Some(ip_address) = ip_address {
562 input = input.ip_address(ip_address);
563 }
564 if let Some(user_agent) = user_agent {
565 input = input.user_agent(user_agent);
566 }
567
568 SessionStore::with_storage(
569 adapter,
570 config.secondary_storage.clone(),
571 config.store_session_in_database,
572 config.preserve_session_in_database,
573 )
574 .create_session(input)
575 .await
576 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
577}
578
579fn store_sign_up_result(
580 result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
581 value: Result<EmailPasswordAuthResult, AuthFlowError>,
582) -> Result<(), RustAuthError> {
583 let mut guard = result.lock().map_err(|_| RustAuthError::LockPoisoned {
584 context: "sign-up result",
585 })?;
586 *guard = Some(value);
587 Ok(())
588}
589
590fn take_sign_up_result(
591 result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
592) -> Result<Option<Result<EmailPasswordAuthResult, AuthFlowError>>, AuthFlowError> {
593 result
594 .lock()
595 .map_err(|_| {
596 AuthFlowError::storage(RustAuthError::LockPoisoned {
597 context: "sign-up result",
598 })
599 })
600 .map(|mut guard| guard.take())
601}
602
603fn validate_email(email: &str) -> Result<(), AuthFlowError> {
604 let email = email.trim();
605 let Some((local, domain)) = email.split_once('@') else {
606 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
607 };
608 if local.is_empty()
609 || domain.is_empty()
610 || domain.starts_with('.')
611 || domain.ends_with('.')
612 || !domain.contains('.')
613 {
614 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
615 }
616 Ok(())
617}
618
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error_codes::ErrorCode;
    use http::StatusCode;

    /// Checks a code through the `ErrorCode` trait, not the inherent methods.
    fn assert_error_code(code: impl ErrorCode, expected_code: &str, expected_message: &str) {
        assert_eq!(code.as_str(), expected_code);
        assert_eq!(code.message(), expected_message);
    }

    #[test]
    fn auth_flow_error_code_implements_error_code_trait() {
        let cases = [
            (AuthFlowErrorCode::InvalidEmail, "INVALID_EMAIL", "Invalid email"),
            (
                AuthFlowErrorCode::InvalidPasswordLength,
                "INVALID_PASSWORD_LENGTH",
                "Invalid password length",
            ),
            (
                AuthFlowErrorCode::InvalidEmailOrPassword,
                "INVALID_EMAIL_OR_PASSWORD",
                "Invalid email or password",
            ),
            (
                AuthFlowErrorCode::UserAlreadyExists,
                crate::error_codes::USER_ALREADY_EXISTS,
                "User already exists",
            ),
            (
                AuthFlowErrorCode::UserAlreadyExistsUseAnotherEmail,
                crate::error_codes::USER_ALREADY_EXISTS_USE_ANOTHER_EMAIL,
                "User already exists. Use another email.",
            ),
            (
                AuthFlowErrorCode::EmailNotVerified,
                "EMAIL_NOT_VERIFIED",
                "Email not verified",
            ),
            (
                AuthFlowErrorCode::FailedToCreateSession,
                "FAILED_TO_CREATE_SESSION",
                "Failed to create session",
            ),
            (AuthFlowErrorCode::StorageError, "STORAGE_ERROR", "Storage error"),
        ];
        for (code, expected_code, expected_message) in cases {
            assert_error_code(code, expected_code, expected_message);
        }
    }

    #[test]
    fn auth_flow_error_to_api_response_preserves_code_and_message() {
        let error = AuthFlowError::new(AuthFlowErrorCode::InvalidEmailOrPassword);
        let response = error.to_api_response();
        assert_eq!(response.code, "INVALID_EMAIL_OR_PASSWORD");
        assert_eq!(response.message, "Invalid email or password");
        assert_eq!(response.original_message, None);
    }

    #[test]
    fn auth_flow_error_display_joins_code_and_message() {
        let error = AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified);
        assert_eq!(error.to_string(), "EMAIL_NOT_VERIFIED: Email not verified");
    }

    #[test]
    fn rust_auth_error_converts_to_storage_error() {
        let source = crate::error::RustAuthError::Adapter("boom".to_owned());
        let expected_message = source.to_string();
        let error = AuthFlowError::from(source);
        assert_eq!(error.code(), AuthFlowErrorCode::StorageError);
        assert_eq!(error.message(), expected_message);
        assert_eq!(error.http_status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn auth_flow_error_http_status_mapping_matches_route_helpers() {
        // Covers every variant rather than a sample.
        let cases = [
            (AuthFlowErrorCode::InvalidEmail, StatusCode::BAD_REQUEST),
            (AuthFlowErrorCode::InvalidPasswordLength, StatusCode::BAD_REQUEST),
            (
                AuthFlowErrorCode::InvalidEmailOrPassword,
                StatusCode::UNAUTHORIZED,
            ),
            (AuthFlowErrorCode::UserAlreadyExists, StatusCode::BAD_REQUEST),
            (
                AuthFlowErrorCode::UserAlreadyExistsUseAnotherEmail,
                StatusCode::BAD_REQUEST,
            ),
            (AuthFlowErrorCode::EmailNotVerified, StatusCode::FORBIDDEN),
            (
                AuthFlowErrorCode::FailedToCreateSession,
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
            (
                AuthFlowErrorCode::StorageError,
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
        ];
        for (code, expected_status) in cases {
            let error = AuthFlowError::new(code);
            assert_eq!(error.http_status(), expected_status);
        }
    }

    #[test]
    fn validate_email_accepts_plain_and_rejects_malformed_addresses() {
        assert!(validate_email("user@example.com").is_ok());
        assert!(validate_email("  user@example.com  ").is_ok());
        for invalid in [
            "",
            "user",
            "@example.com",
            "user@",
            "user@example",
            "user@.example.com",
            "user@example.com.",
        ] {
            assert_eq!(
                validate_email(invalid),
                Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail)),
                "{invalid:?} should be rejected"
            );
        }
    }
}