shield_credentials/
email_password.rs

1use std::{pin::Pin, sync::Arc};
2
3use async_trait::async_trait;
4use serde::Deserialize;
5use shield::{Form, Input, InputType, InputTypeEmail, InputTypePassword, ShieldError, User};
6
7use crate::Credentials;
8
9#[derive(Debug, Deserialize)]
10pub struct EmailPasswordData {
11    pub email: String,
12    pub password: String,
13}
14
15type SignInFn<U> = dyn Fn(EmailPasswordData) -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
16    + Send
17    + Sync;
18
19pub struct EmailPasswordCredentials<U: User> {
20    sign_in_fn: Arc<SignInFn<U>>,
21}
22
23impl<U: User> EmailPasswordCredentials<U> {
24    pub fn new(
25        sign_in_fn: impl Fn(
26            EmailPasswordData,
27        )
28            -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
29        + Send
30        + Sync
31        + 'static,
32    ) -> Self {
33        Self {
34            sign_in_fn: Arc::new(sign_in_fn),
35        }
36    }
37}
38
39#[async_trait]
40impl<U: User> Credentials<U, EmailPasswordData> for EmailPasswordCredentials<U> {
41    fn form(&self) -> Form {
42        Form {
43            inputs: vec![
44                Input {
45                    name: "email".to_owned(),
46                    label: Some("Email address".to_owned()),
47                    r#type: InputType::Email(InputTypeEmail {
48                        autocomplete: Some("email".to_owned()),
49                        placeholder: Some("Email address".to_owned()),
50                        required: Some(true),
51                        ..Default::default()
52                    }),
53                    value: None,
54                },
55                Input {
56                    name: "password".to_owned(),
57                    label: Some("Password".to_owned()),
58                    r#type: InputType::Password(InputTypePassword {
59                        autocomplete: Some("current-password".to_owned()),
60                        placeholder: Some("Password".to_owned()),
61                        required: Some(true),
62                        ..Default::default()
63                    }),
64                    value: None,
65                },
66            ],
67        }
68    }
69
70    async fn sign_in(&self, data: EmailPasswordData) -> Result<U, ShieldError> {
71        (self.sign_in_fn)(data).await
72    }
73}
74
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use shield::{EmailAddress, ShieldError, StorageError, User};

    use super::{EmailPasswordCredentials, EmailPasswordData};
    use crate::Credentials;

    /// Email address the test callback accepts.
    const EMAIL: &str = "test@example.com";
    /// Password the test callback accepts.
    const PASSWORD: &str = "test";
    /// Validation message returned by the test callback for any other combination.
    const REJECTION: &str = "Incorrect email and password combination.";

    /// Minimal `User` implementation returned by the test sign-in callback.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TestUser {
        id: String,
        name: Option<String>,
    }

    #[async_trait]
    impl User for TestUser {
        fn id(&self) -> String {
            self.id.clone()
        }

        fn name(&self) -> Option<String> {
            self.name.clone()
        }

        async fn email_addresses(&self) -> Result<Vec<EmailAddress>, StorageError> {
            Ok(vec![])
        }

        fn additional(&self) -> Option<impl Serialize> {
            None::<()>
        }
    }

    /// Builds an owned sign-in payload from borrowed strings.
    fn sign_in_data(email: &str, password: &str) -> EmailPasswordData {
        EmailPasswordData {
            email: email.to_owned(),
            password: password.to_owned(),
        }
    }

    /// A wrong password is rejected with the callback's message; the right one yields the user.
    #[tokio::test]
    async fn email_password_credentials() -> Result<(), ShieldError> {
        let credentials = EmailPasswordCredentials::new(|data: EmailPasswordData| {
            Box::pin(async move {
                match (data.email.as_str(), data.password.as_str()) {
                    (EMAIL, PASSWORD) => Ok(TestUser {
                        id: "1".to_owned(),
                        name: Some("Test".to_owned()),
                    }),
                    _ => Err(ShieldError::Validation(REJECTION.to_owned())),
                }
            })
        });

        let rejected = credentials
            .sign_in(sign_in_data(EMAIL, "incorrect"))
            .await;
        assert!(rejected.is_err_and(|err| err.to_string().contains(REJECTION)));

        let user = credentials.sign_in(sign_in_data(EMAIL, PASSWORD)).await?;
        assert_eq!(user.name, Some("Test".to_owned()));

        Ok(())
    }
}