Skip to main content

rustauth_core/auth/
email_password.rs

1//! Email/password auth service built on top of core DB stores.
2
3use std::error::Error;
4use std::fmt;
5use std::sync::{Arc, Mutex};
6
7use http::StatusCode;
8use time::{Duration, OffsetDateTime};
9
10use crate::api::ApiErrorResponse;
11use crate::crypto::password::{hash_password, verify_password};
12use crate::db::{DbAdapter, DbRecord, Session, User};
13use crate::error::RustAuthError;
14use crate::error_codes::ErrorCode;
15use crate::options::SecondaryStorage;
16use crate::session::{CreateSessionInput, SessionStore};
17use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
18
19pub type PasswordHashFn = fn(&str) -> Result<String, RustAuthError>;
20pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, RustAuthError>;
21
/// Machine-readable failure reasons produced by the email/password flows.
///
/// [`AuthFlowErrorCode::as_str`] gives the wire code and [`AuthFlowErrorCode::message`] the
/// default human-readable text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthFlowErrorCode {
    /// The email address failed the syntactic check.
    InvalidEmail,
    /// The sign-up password is outside the configured length bounds.
    InvalidPasswordLength,
    /// Sign-in rejection used alike for an unknown user, a missing credential/hash, and a
    /// wrong password.
    InvalidEmailOrPassword,
    /// Sign-up was attempted for an email that already belongs to a user.
    UserAlreadyExists,
    /// Same condition as `UserAlreadyExists`, with a message asking for another email.
    UserAlreadyExistsUseAnotherEmail,
    /// Verification is required by the config and the user's email is not verified.
    EmailNotVerified,
    /// The session could not be built or persisted.
    FailedToCreateSession,
    /// A lower-level `RustAuthError` (adapter, hashing, lock) occurred.
    StorageError,
}
33
34impl AuthFlowErrorCode {
35    pub fn as_str(self) -> &'static str {
36        match self {
37            Self::InvalidEmail => "INVALID_EMAIL",
38            Self::InvalidPasswordLength => "INVALID_PASSWORD_LENGTH",
39            Self::InvalidEmailOrPassword => "INVALID_EMAIL_OR_PASSWORD",
40            Self::UserAlreadyExists => crate::error_codes::USER_ALREADY_EXISTS,
41            Self::UserAlreadyExistsUseAnotherEmail => {
42                crate::error_codes::USER_ALREADY_EXISTS_USE_ANOTHER_EMAIL
43            }
44            Self::EmailNotVerified => "EMAIL_NOT_VERIFIED",
45            Self::FailedToCreateSession => "FAILED_TO_CREATE_SESSION",
46            Self::StorageError => "STORAGE_ERROR",
47        }
48    }
49
50    pub fn message(self) -> &'static str {
51        match self {
52            Self::InvalidEmail => "Invalid email",
53            Self::InvalidPasswordLength => "Invalid password length",
54            Self::InvalidEmailOrPassword => "Invalid email or password",
55            Self::UserAlreadyExists => "User already exists",
56            Self::UserAlreadyExistsUseAnotherEmail => "User already exists. Use another email.",
57            Self::EmailNotVerified => "Email not verified",
58            Self::FailedToCreateSession => "Failed to create session",
59            Self::StorageError => "Storage error",
60        }
61    }
62}
63
64impl ErrorCode for AuthFlowErrorCode {
65    fn as_str(&self) -> &str {
66        (*self).as_str()
67    }
68
69    fn message(&self) -> &str {
70        (*self).message()
71    }
72}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub struct AuthFlowError {
76    code: AuthFlowErrorCode,
77    message: String,
78}
79
80impl AuthFlowError {
81    pub fn new(code: AuthFlowErrorCode) -> Self {
82        Self {
83            code,
84            message: code.message().to_owned(),
85        }
86    }
87
88    pub fn storage(error: RustAuthError) -> Self {
89        Self {
90            code: AuthFlowErrorCode::StorageError,
91            message: error.to_string(),
92        }
93    }
94
95    pub fn code(&self) -> AuthFlowErrorCode {
96        self.code
97    }
98
99    pub fn code_str(&self) -> &'static str {
100        self.code.as_str()
101    }
102
103    pub fn message(&self) -> &str {
104        self.message.as_str()
105    }
106
107    pub fn http_status(&self) -> StatusCode {
108        match self.code {
109            AuthFlowErrorCode::InvalidEmailOrPassword => StatusCode::UNAUTHORIZED,
110            AuthFlowErrorCode::EmailNotVerified => StatusCode::FORBIDDEN,
111            AuthFlowErrorCode::StorageError | AuthFlowErrorCode::FailedToCreateSession => {
112                StatusCode::INTERNAL_SERVER_ERROR
113            }
114            AuthFlowErrorCode::InvalidEmail
115            | AuthFlowErrorCode::InvalidPasswordLength
116            | AuthFlowErrorCode::UserAlreadyExists
117            | AuthFlowErrorCode::UserAlreadyExistsUseAnotherEmail => StatusCode::BAD_REQUEST,
118        }
119    }
120
121    pub fn to_api_response(&self) -> ApiErrorResponse {
122        ApiErrorResponse {
123            code: self.code_str().to_owned(),
124            message: self.message().to_owned(),
125            original_message: None,
126        }
127    }
128}
129
130impl fmt::Display for AuthFlowError {
131    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(formatter, "{}: {}", self.code.as_str(), self.message)
133    }
134}
135
136impl Error for AuthFlowError {}
137
138impl From<RustAuthError> for AuthFlowError {
139    fn from(error: RustAuthError) -> Self {
140        Self::storage(error)
141    }
142}
143
144#[derive(Clone)]
145pub struct EmailPasswordConfig {
146    pub session_expires_in: u64,
147    pub dont_remember_session_expires_in: u64,
148    pub min_password_length: usize,
149    pub max_password_length: usize,
150    pub require_email_verification: bool,
151    pub secondary_storage: Option<Arc<dyn SecondaryStorage>>,
152    pub store_session_in_database: bool,
153    pub preserve_session_in_database: bool,
154}
155
156impl Default for EmailPasswordConfig {
157    fn default() -> Self {
158        Self {
159            session_expires_in: 60 * 60 * 24 * 7,
160            dont_remember_session_expires_in: 60 * 60 * 24,
161            min_password_length: 8,
162            max_password_length: 128,
163            require_email_verification: false,
164            secondary_storage: None,
165            store_session_in_database: false,
166            preserve_session_in_database: false,
167        }
168    }
169}
170
171impl fmt::Debug for EmailPasswordConfig {
172    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
173        formatter
174            .debug_struct("EmailPasswordConfig")
175            .field("session_expires_in", &self.session_expires_in)
176            .field(
177                "dont_remember_session_expires_in",
178                &self.dont_remember_session_expires_in,
179            )
180            .field("min_password_length", &self.min_password_length)
181            .field("max_password_length", &self.max_password_length)
182            .field(
183                "require_email_verification",
184                &self.require_email_verification,
185            )
186            .field(
187                "secondary_storage",
188                &self
189                    .secondary_storage
190                    .as_ref()
191                    .map(|_| "<secondary-storage>"),
192            )
193            .field("store_session_in_database", &self.store_session_in_database)
194            .field(
195                "preserve_session_in_database",
196                &self.preserve_session_in_database,
197            )
198            .finish()
199    }
200}
201
202#[derive(Debug, Clone, PartialEq)]
203pub struct SignUpInput {
204    pub name: String,
205    pub email: String,
206    pub password: String,
207    pub image: Option<String>,
208    pub username: Option<String>,
209    pub display_username: Option<String>,
210    pub remember_me: bool,
211    pub ip_address: Option<String>,
212    pub user_agent: Option<String>,
213    pub additional_user_fields: DbRecord,
214    pub additional_session_fields: DbRecord,
215}
216
217impl SignUpInput {
218    pub fn new(
219        name: impl Into<String>,
220        email: impl Into<String>,
221        password: impl Into<String>,
222    ) -> Self {
223        Self {
224            name: name.into(),
225            email: email.into(),
226            password: password.into(),
227            image: None,
228            username: None,
229            display_username: None,
230            remember_me: true,
231            ip_address: None,
232            user_agent: None,
233            additional_user_fields: DbRecord::new(),
234            additional_session_fields: DbRecord::new(),
235        }
236    }
237
238    #[must_use]
239    pub fn image(mut self, image: impl Into<String>) -> Self {
240        self.image = Some(image.into());
241        self
242    }
243
244    #[must_use]
245    pub fn username(mut self, username: impl Into<String>) -> Self {
246        self.username = Some(username.into());
247        self
248    }
249
250    #[must_use]
251    pub fn display_username(mut self, display_username: impl Into<String>) -> Self {
252        self.display_username = Some(display_username.into());
253        self
254    }
255
256    #[must_use]
257    pub fn remember_me(mut self, remember_me: bool) -> Self {
258        self.remember_me = remember_me;
259        self
260    }
261
262    #[must_use]
263    pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
264        self.ip_address = Some(ip_address.into());
265        self
266    }
267
268    #[must_use]
269    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
270        self.user_agent = Some(user_agent.into());
271        self
272    }
273
274    #[must_use]
275    pub fn additional_user_fields(mut self, fields: DbRecord) -> Self {
276        self.additional_user_fields = fields;
277        self
278    }
279
280    #[must_use]
281    pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
282        self.additional_session_fields = fields;
283        self
284    }
285}
286
287#[derive(Debug, Clone, PartialEq)]
288pub struct SignInInput {
289    pub email: String,
290    pub password: String,
291    pub remember_me: bool,
292    pub ip_address: Option<String>,
293    pub user_agent: Option<String>,
294    pub additional_session_fields: DbRecord,
295}
296
297impl SignInInput {
298    pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
299        Self {
300            email: email.into(),
301            password: password.into(),
302            remember_me: true,
303            ip_address: None,
304            user_agent: None,
305            additional_session_fields: DbRecord::new(),
306        }
307    }
308
309    #[must_use]
310    pub fn remember_me(mut self, remember_me: bool) -> Self {
311        self.remember_me = remember_me;
312        self
313    }
314
315    #[must_use]
316    pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
317        self.ip_address = Some(ip_address.into());
318        self
319    }
320
321    #[must_use]
322    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
323        self.user_agent = Some(user_agent.into());
324        self
325    }
326
327    #[must_use]
328    pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
329        self.additional_session_fields = fields;
330        self
331    }
332}
333
334#[derive(Debug, Clone, PartialEq, Eq)]
335pub struct EmailPasswordAuthResult {
336    pub user: User,
337    pub session: Session,
338}
339
340#[derive(Clone)]
341pub struct EmailPasswordAuth<'a> {
342    adapter: &'a dyn DbAdapter,
343    config: EmailPasswordConfig,
344    hash_password: PasswordHashFn,
345    verify_password: PasswordVerifyFn,
346}
347
348impl<'a> EmailPasswordAuth<'a> {
349    pub fn new(
350        adapter: &'a dyn DbAdapter,
351        config: EmailPasswordConfig,
352        hash_password: PasswordHashFn,
353        verify_password: PasswordVerifyFn,
354    ) -> Self {
355        Self {
356            adapter,
357            config,
358            hash_password,
359            verify_password,
360        }
361    }
362
363    pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
364        Self::new(adapter, config, hash_password, verify_password)
365    }
366
367    pub async fn sign_up(
368        &self,
369        input: SignUpInput,
370    ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
371        self.validate_email_and_password(&input.email, &input.password)?;
372        let users = DbUserStore::new(self.adapter);
373        if users.find_user_by_email(&input.email).await?.is_some() {
374            return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
375        }
376
377        let password_hash = (self.hash_password)(&input.password)?;
378        let mut create_user = CreateUserInput::new(input.name, input.email)
379            .additional_fields(input.additional_user_fields);
380        if let Some(image) = input.image {
381            create_user = create_user.image(image);
382        }
383        if let Some(username) = input.username {
384            create_user = create_user.username(username);
385        }
386        if let Some(display_username) = input.display_username {
387            create_user = create_user.display_username(display_username);
388        }
389        let result = Arc::new(Mutex::new(None));
390        let result_for_transaction = Arc::clone(&result);
391        let config = self.config.clone();
392        let transaction_status = self
393            .adapter
394            .transaction(Box::new(move |transaction| {
395                Box::pin(async move {
396                    let outcome = create_sign_up_records(SignUpRecordsInput {
397                        adapter: transaction.as_ref(),
398                        config: &config,
399                        create_user,
400                        password_hash,
401                        remember_me: input.remember_me,
402                        ip_address: input.ip_address,
403                        user_agent: input.user_agent,
404                        additional_session_fields: input.additional_session_fields,
405                    })
406                    .await;
407                    match outcome {
408                        Ok(result) => {
409                            store_sign_up_result(&result_for_transaction, Ok(result))?;
410                            Ok(())
411                        }
412                        Err(error) => {
413                            let transaction_error = RustAuthError::Adapter(error.to_string());
414                            store_sign_up_result(&result_for_transaction, Err(error))?;
415                            Err(transaction_error)
416                        }
417                    }
418                })
419            }))
420            .await;
421
422        match transaction_status {
423            Ok(()) => match take_sign_up_result(&result)? {
424                Some(Ok(result)) => Ok(result),
425                Some(Err(error)) => Err(error),
426                None => Err(AuthFlowError::storage(RustAuthError::Adapter(
427                    "sign-up transaction completed without a result".to_owned(),
428                ))),
429            },
430            Err(error) => match take_sign_up_result(&result)? {
431                Some(Err(auth_error)) => Err(auth_error),
432                _ => Err(AuthFlowError::storage(error)),
433            },
434        }
435    }
436
437    pub async fn sign_in(
438        &self,
439        input: SignInInput,
440    ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
441        validate_email(&input.email)?;
442        let users = DbUserStore::new(self.adapter);
443        let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
444        else {
445            let _ = (self.hash_password)(&input.password);
446            return Err(AuthFlowError::new(
447                AuthFlowErrorCode::InvalidEmailOrPassword,
448            ));
449        };
450        let Some(account) = user_with_accounts
451            .accounts
452            .iter()
453            .find(|account| account.provider_id == "credential")
454        else {
455            let _ = (self.hash_password)(&input.password);
456            return Err(AuthFlowError::new(
457                AuthFlowErrorCode::InvalidEmailOrPassword,
458            ));
459        };
460        let Some(password_hash) = account.password.as_deref() else {
461            let _ = (self.hash_password)(&input.password);
462            return Err(AuthFlowError::new(
463                AuthFlowErrorCode::InvalidEmailOrPassword,
464            ));
465        };
466        if !(self.verify_password)(password_hash, &input.password)? {
467            return Err(AuthFlowError::new(
468                AuthFlowErrorCode::InvalidEmailOrPassword,
469            ));
470        }
471        if self.config.require_email_verification && !user_with_accounts.user.email_verified {
472            return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
473        }
474        let session = create_session_record(
475            self.adapter,
476            &self.config,
477            &user_with_accounts.user.id,
478            input.remember_me,
479            input.ip_address,
480            input.user_agent,
481            input.additional_session_fields,
482        )
483        .await?;
484
485        Ok(EmailPasswordAuthResult {
486            user: user_with_accounts.user,
487            session,
488        })
489    }
490
491    fn validate_email_and_password(
492        &self,
493        email: &str,
494        password: &str,
495    ) -> Result<(), AuthFlowError> {
496        validate_email(email)?;
497        if password.len() < self.config.min_password_length
498            || password.len() > self.config.max_password_length
499        {
500            return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
501        }
502        Ok(())
503    }
504}
505
506struct SignUpRecordsInput<'a> {
507    adapter: &'a dyn DbAdapter,
508    config: &'a EmailPasswordConfig,
509    create_user: CreateUserInput,
510    password_hash: String,
511    remember_me: bool,
512    ip_address: Option<String>,
513    user_agent: Option<String>,
514    additional_session_fields: DbRecord,
515}
516
517async fn create_sign_up_records(
518    input: SignUpRecordsInput<'_>,
519) -> Result<EmailPasswordAuthResult, AuthFlowError> {
520    let users = DbUserStore::new(input.adapter);
521    let user = users.create_user(input.create_user).await?;
522    users
523        .create_credential_account(CreateCredentialAccountInput::new(
524            user.id.clone(),
525            input.password_hash,
526        ))
527        .await?;
528    let session = create_session_record(
529        input.adapter,
530        input.config,
531        &user.id,
532        input.remember_me,
533        input.ip_address,
534        input.user_agent,
535        input.additional_session_fields,
536    )
537    .await?;
538
539    Ok(EmailPasswordAuthResult { user, session })
540}
541
542async fn create_session_record(
543    adapter: &dyn DbAdapter,
544    config: &EmailPasswordConfig,
545    user_id: &str,
546    remember_me: bool,
547    ip_address: Option<String>,
548    user_agent: Option<String>,
549    additional_fields: DbRecord,
550) -> Result<Session, AuthFlowError> {
551    let expires_in = if remember_me {
552        config.session_expires_in
553    } else {
554        config.dont_remember_session_expires_in
555    };
556    let seconds = i64::try_from(expires_in)
557        .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
558    let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
559    let mut input =
560        CreateSessionInput::new(user_id, expires_at).additional_fields(additional_fields);
561    if let Some(ip_address) = ip_address {
562        input = input.ip_address(ip_address);
563    }
564    if let Some(user_agent) = user_agent {
565        input = input.user_agent(user_agent);
566    }
567
568    SessionStore::with_storage(
569        adapter,
570        config.secondary_storage.clone(),
571        config.store_session_in_database,
572        config.preserve_session_in_database,
573    )
574    .create_session(input)
575    .await
576    .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
577}
578
579fn store_sign_up_result(
580    result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
581    value: Result<EmailPasswordAuthResult, AuthFlowError>,
582) -> Result<(), RustAuthError> {
583    let mut guard = result.lock().map_err(|_| RustAuthError::LockPoisoned {
584        context: "sign-up result",
585    })?;
586    *guard = Some(value);
587    Ok(())
588}
589
590fn take_sign_up_result(
591    result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
592) -> Result<Option<Result<EmailPasswordAuthResult, AuthFlowError>>, AuthFlowError> {
593    result
594        .lock()
595        .map_err(|_| {
596            AuthFlowError::storage(RustAuthError::LockPoisoned {
597                context: "sign-up result",
598            })
599        })
600        .map(|mut guard| guard.take())
601}
602
603fn validate_email(email: &str) -> Result<(), AuthFlowError> {
604    let email = email.trim();
605    let Some((local, domain)) = email.split_once('@') else {
606        return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
607    };
608    if local.is_empty()
609        || domain.is_empty()
610        || domain.starts_with('.')
611        || domain.ends_with('.')
612        || !domain.contains('.')
613    {
614        return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
615    }
616    Ok(())
617}
618
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RustAuthError;
    use crate::error_codes::ErrorCode;
    use http::StatusCode;

    /// Checks code/message through the `ErrorCode` trait rather than the inherent methods.
    fn assert_error_code(code: impl ErrorCode, expected_code: &str, expected_message: &str) {
        assert_eq!(code.as_str(), expected_code);
        assert_eq!(code.message(), expected_message);
    }

    #[test]
    fn auth_flow_error_code_implements_error_code_trait() {
        let cases = [
            (AuthFlowErrorCode::InvalidEmail, "INVALID_EMAIL", "Invalid email"),
            (
                AuthFlowErrorCode::InvalidPasswordLength,
                "INVALID_PASSWORD_LENGTH",
                "Invalid password length",
            ),
            (
                AuthFlowErrorCode::InvalidEmailOrPassword,
                "INVALID_EMAIL_OR_PASSWORD",
                "Invalid email or password",
            ),
            (
                AuthFlowErrorCode::UserAlreadyExists,
                crate::error_codes::USER_ALREADY_EXISTS,
                "User already exists",
            ),
            (
                AuthFlowErrorCode::UserAlreadyExistsUseAnotherEmail,
                crate::error_codes::USER_ALREADY_EXISTS_USE_ANOTHER_EMAIL,
                "User already exists. Use another email.",
            ),
            (
                AuthFlowErrorCode::EmailNotVerified,
                "EMAIL_NOT_VERIFIED",
                "Email not verified",
            ),
            (
                AuthFlowErrorCode::FailedToCreateSession,
                "FAILED_TO_CREATE_SESSION",
                "Failed to create session",
            ),
            (AuthFlowErrorCode::StorageError, "STORAGE_ERROR", "Storage error"),
        ];
        for (code, expected_code, expected_message) in cases {
            assert_error_code(code, expected_code, expected_message);
        }
    }

    #[test]
    fn auth_flow_error_to_api_response_preserves_code_and_message() {
        let error = AuthFlowError::new(AuthFlowErrorCode::InvalidEmailOrPassword);
        let response = error.to_api_response();
        assert_eq!(response.code, "INVALID_EMAIL_OR_PASSWORD");
        assert_eq!(response.message, "Invalid email or password");
        assert_eq!(response.original_message, None);
    }

    #[test]
    fn auth_flow_error_http_status_mapping_matches_route_helpers() {
        // Every variant is listed so a new or re-mapped code cannot slip through unnoticed.
        let cases = [
            (AuthFlowErrorCode::InvalidEmail, StatusCode::BAD_REQUEST),
            (AuthFlowErrorCode::InvalidPasswordLength, StatusCode::BAD_REQUEST),
            (
                AuthFlowErrorCode::InvalidEmailOrPassword,
                StatusCode::UNAUTHORIZED,
            ),
            (AuthFlowErrorCode::UserAlreadyExists, StatusCode::BAD_REQUEST),
            (
                AuthFlowErrorCode::UserAlreadyExistsUseAnotherEmail,
                StatusCode::BAD_REQUEST,
            ),
            (AuthFlowErrorCode::EmailNotVerified, StatusCode::FORBIDDEN),
            (
                AuthFlowErrorCode::FailedToCreateSession,
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
            (
                AuthFlowErrorCode::StorageError,
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
        ];
        for (code, expected_status) in cases {
            let error = AuthFlowError::new(code);
            assert_eq!(error.http_status(), expected_status, "{code:?}");
        }
    }

    #[test]
    fn auth_flow_error_display_is_code_then_message() {
        let error = AuthFlowError::new(AuthFlowErrorCode::InvalidEmail);
        assert_eq!(error.to_string(), "INVALID_EMAIL: Invalid email");
    }

    #[test]
    fn rust_auth_error_converts_to_storage_error() {
        let source = RustAuthError::Adapter("boom".to_owned());
        let expected_message = source.to_string();
        let error = AuthFlowError::from(source);
        assert_eq!(error.code(), AuthFlowErrorCode::StorageError);
        assert_eq!(error.message(), expected_message);
        assert_eq!(error.http_status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn validate_email_accepts_simple_addresses() {
        assert!(validate_email("user@example.com").is_ok());
        assert!(validate_email("first.last@mail.example.co.uk").is_ok());
    }

    #[test]
    fn validate_email_rejects_structurally_invalid_addresses() {
        let bad = [
            "",
            "userexample.com",
            "@example.com",
            "user@",
            "user@localhost",
            "user@.example.com",
            "user@example.com.",
        ];
        for email in bad {
            let error = validate_email(email).expect_err(email);
            assert_eq!(error.code(), AuthFlowErrorCode::InvalidEmail);
        }
    }
}