shield_credentials/
email_password.rs1use std::{pin::Pin, sync::Arc};
2
3use async_trait::async_trait;
4use serde::Deserialize;
5use shield::{Form, Input, InputType, InputTypeEmail, InputTypePassword, ShieldError, User};
6
7use crate::Credentials;
8
9#[derive(Debug, Deserialize)]
10pub struct EmailPasswordData {
11 pub email: String,
12 pub password: String,
13}
14
15type SignInFn<U> = dyn Fn(EmailPasswordData) -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
16 + Send
17 + Sync;
18
19pub struct EmailPasswordCredentials<U: User> {
20 sign_in_fn: Arc<SignInFn<U>>,
21}
22
23impl<U: User> EmailPasswordCredentials<U> {
24 pub fn new(
25 sign_in_fn: impl Fn(
26 EmailPasswordData,
27 )
28 -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
29 + Send
30 + Sync
31 + 'static,
32 ) -> Self {
33 Self {
34 sign_in_fn: Arc::new(sign_in_fn),
35 }
36 }
37}
38
39#[async_trait]
40impl<U: User> Credentials<U, EmailPasswordData> for EmailPasswordCredentials<U> {
41 fn form(&self) -> Form {
42 Form {
43 inputs: vec![
44 Input {
45 name: "email".to_owned(),
46 label: Some("Email address".to_owned()),
47 r#type: InputType::Email(InputTypeEmail {
48 autocomplete: Some("email".to_owned()),
49 placeholder: Some("Email address".to_owned()),
50 required: Some(true),
51 ..Default::default()
52 }),
53 value: None,
54 },
55 Input {
56 name: "password".to_owned(),
57 label: Some("Password".to_owned()),
58 r#type: InputType::Password(InputTypePassword {
59 autocomplete: Some("current-password".to_owned()),
60 placeholder: Some("Password".to_owned()),
61 required: Some(true),
62 ..Default::default()
63 }),
64 value: None,
65 },
66 ],
67 }
68 }
69
70 async fn sign_in(&self, data: EmailPasswordData) -> Result<U, ShieldError> {
71 (self.sign_in_fn)(data).await
72 }
73}
74
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use shield::{EmailAddress, ShieldError, StorageError, User};

    use crate::Credentials;

    use super::{EmailPasswordCredentials, EmailPasswordData};

    /// Error message returned by the test callback for any rejected attempt.
    const INVALID_CREDENTIALS: &str = "Incorrect email and password combination.";

    /// Minimal [`User`] implementation returned by the test sign-in callback.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TestUser {
        id: String,
        name: Option<String>,
    }

    #[async_trait]
    impl User for TestUser {
        fn id(&self) -> String {
            self.id.clone()
        }

        fn name(&self) -> Option<String> {
            self.name.clone()
        }

        async fn email_addresses(&self) -> Result<Vec<EmailAddress>, StorageError> {
            Ok(vec![])
        }

        fn additional(&self) -> Option<impl Serialize> {
            None::<()>
        }
    }

    /// Builds credentials that accept only `test@example.com` / `test`.
    fn test_credentials() -> EmailPasswordCredentials<TestUser> {
        EmailPasswordCredentials::new(|data: EmailPasswordData| {
            Box::pin(async move {
                if data.email == "test@example.com" && data.password == "test" {
                    Ok(TestUser {
                        id: "1".to_owned(),
                        name: Some("Test".to_owned()),
                    })
                } else {
                    Err(ShieldError::Validation(INVALID_CREDENTIALS.to_owned()))
                }
            })
        })
    }

    /// Returns `true` when signing in with `email` / `password` fails with the
    /// invalid-credentials message.
    async fn rejects(
        credentials: &EmailPasswordCredentials<TestUser>,
        email: &str,
        password: &str,
    ) -> bool {
        credentials
            .sign_in(EmailPasswordData {
                email: email.to_owned(),
                password: password.to_owned(),
            })
            .await
            .is_err_and(|err| err.to_string().contains(INVALID_CREDENTIALS))
    }

    #[test]
    fn email_password_form() {
        let form = test_credentials().form();
        let names: Vec<&str> = form
            .inputs
            .iter()
            .map(|input| input.name.as_str())
            .collect();

        // The field names must match the `EmailPasswordData` keys, in display order.
        assert_eq!(names, ["email", "password"]);
    }

    #[tokio::test]
    async fn email_password_credentials() -> Result<(), ShieldError> {
        let credentials = test_credentials();

        // Wrong password and wrong email must both be rejected.
        assert!(rejects(&credentials, "test@example.com", "incorrect").await);
        assert!(rejects(&credentials, "other@example.com", "test").await);

        let user = credentials
            .sign_in(EmailPasswordData {
                email: "test@example.com".to_owned(),
                password: "test".to_owned(),
            })
            .await?;

        assert_eq!(user.id, "1");
        assert_eq!(user.name, Some("Test".to_owned()));

        Ok(())
    }
}