Skip to main content

shield_credentials/
username_password.rs

1use std::{pin::Pin, sync::Arc};
2
3use async_trait::async_trait;
4use serde::Deserialize;
5use shield::{Form, Input, InputType, InputTypePassword, InputTypeText, ShieldError, User};
6
7use crate::Credentials;
8
9#[derive(Debug, Deserialize)]
10pub struct UsernamePasswordData {
11    pub username: String,
12    pub password: String,
13}
14
15type SignInFn<U> = dyn Fn(UsernamePasswordData) -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
16    + Send
17    + Sync;
18
19pub struct UsernamePasswordCredentials<U: User> {
20    sign_in_fn: Arc<SignInFn<U>>,
21}
22
23impl<U: User> UsernamePasswordCredentials<U> {
24    pub fn new(
25        sign_in_fn: impl Fn(
26            UsernamePasswordData,
27        )
28            -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
29        + Send
30        + Sync
31        + 'static,
32    ) -> Self {
33        Self {
34            sign_in_fn: Arc::new(sign_in_fn),
35        }
36    }
37}
38
39#[async_trait]
40impl<U: User> Credentials<U, UsernamePasswordData> for UsernamePasswordCredentials<U> {
41    fn form(&self) -> Form {
42        Form {
43            inputs: vec![
44                Input {
45                    name: "username".to_owned(),
46                    label: Some("Username".to_owned()),
47                    r#type: InputType::Text(InputTypeText {
48                        autocomplete: Some("username".to_owned()),
49                        placeholder: Some("Username".to_owned()),
50                        required: Some(true),
51                        ..Default::default()
52                    }),
53                    value: None,
54                    addon_start: None,
55                    addon_end: None,
56                },
57                Input {
58                    name: "password".to_owned(),
59                    label: Some("Password".to_owned()),
60                    r#type: InputType::Password(InputTypePassword {
61                        autocomplete: Some("current-password".to_owned()),
62                        placeholder: Some("Password".to_owned()),
63                        required: Some(true),
64                        ..Default::default()
65                    }),
66                    value: None,
67                    addon_start: None,
68                    addon_end: None,
69                },
70            ],
71        }
72    }
73
74    async fn sign_in(&self, data: UsernamePasswordData) -> Result<U, ShieldError> {
75        (self.sign_in_fn)(data).await
76    }
77}
78
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use shield::{EmailAddress, ShieldError, StorageError, User};

    use crate::Credentials;

    use super::{UsernamePasswordCredentials, UsernamePasswordData};

    /// Error text the test callback reports when credentials do not match.
    const INVALID_MESSAGE: &str = "Incorrect username and password combination.";

    /// Minimal [`User`] implementation used only by these tests.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TestUser {
        id: String,
        name: Option<String>,
    }

    #[async_trait]
    impl User for TestUser {
        fn id(&self) -> String {
            self.id.clone()
        }

        fn name(&self) -> Option<String> {
            self.name.clone()
        }

        async fn email_addresses(&self) -> Result<Vec<EmailAddress>, StorageError> {
            Ok(Vec::new())
        }

        fn additional(&self) -> Option<impl Serialize> {
            None::<()>
        }
    }

    /// Checks two things: a wrong password is rejected with the callback's
    /// validation error, and the correct pair resolves to the expected user.
    #[tokio::test]
    async fn username_password_credentials() -> Result<(), ShieldError> {
        let credentials = UsernamePasswordCredentials::new(|data: UsernamePasswordData| {
            Box::pin(async move {
                let accepted = data.username == "test" && data.password == "test";
                if accepted {
                    Ok(TestUser {
                        id: "1".to_owned(),
                        name: Some("Test".to_owned()),
                    })
                } else {
                    Err(ShieldError::Validation(INVALID_MESSAGE.to_owned()))
                }
            })
        });

        let wrong_password = UsernamePasswordData {
            username: "test".to_owned(),
            password: "incorrect".to_owned(),
        };
        let rejected = credentials.sign_in(wrong_password).await;
        assert!(matches!(&rejected, Err(err) if err.to_string().contains(INVALID_MESSAGE)));

        let correct = UsernamePasswordData {
            username: "test".to_owned(),
            password: "test".to_owned(),
        };
        let user = credentials.sign_in(correct).await?;
        assert_eq!(user.name.as_deref(), Some("Test"));

        Ok(())
    }
}