shield_sea_orm/
user.rs

1use std::ops::Deref;
2
3use async_trait::async_trait;
4use sea_orm::{DatabaseConnection, ModelTrait, prelude::Uuid};
5use serde::Serialize;
6use shield::{EmailAddress, StorageError};
7
8#[cfg(feature = "entity")]
9use crate::entities::entity;
10use crate::entities::{email_address, user};
11
12#[derive(Clone, Debug)]
13pub struct User {
14    database: DatabaseConnection,
15    user: user::Model,
16    #[cfg(feature = "entity")]
17    entity: entity::Model,
18}
19
20impl User {
21    pub(crate) fn new(
22        database: DatabaseConnection,
23        user: user::Model,
24        #[cfg(feature = "entity")] entity: entity::Model,
25    ) -> Self {
26        Self {
27            database,
28            user,
29            #[cfg(feature = "entity")]
30            entity,
31        }
32    }
33
34    #[cfg(feature = "entity")]
35    pub fn entity(&self) -> &entity::Model {
36        &self.entity
37    }
38}
39
40impl Deref for User {
41    type Target = user::Model;
42
43    fn deref(&self) -> &Self::Target {
44        &self.user
45    }
46}
47
48#[derive(Serialize)]
49#[serde(rename_all = "camelCase")]
50pub struct Additional {
51    #[cfg(feature = "entity")]
52    entity_id: String,
53}
54
55#[async_trait]
56impl shield::User for User {
57    fn id(&self) -> String {
58        self.user.id.to_string()
59    }
60
61    fn name(&self) -> Option<String> {
62        #[cfg(feature = "entity")]
63        {
64            Some(self.entity.name.clone())
65        }
66
67        #[cfg(not(feature = "entity"))]
68        {
69            Some(self.user.name.clone())
70        }
71    }
72
73    async fn email_addresses(&self) -> Result<Vec<EmailAddress>, StorageError> {
74        #[cfg(feature = "entity")]
75        {
76            self.entity
77                .find_related(email_address::Entity)
78                .all(&self.database)
79                .await
80                .map_err(|err| StorageError::Engine(err.to_string()))
81                .map(|email_addresses| {
82                    email_addresses
83                        .into_iter()
84                        .map(|email_address| {
85                            EmailAddress::from(EmailAddressWithUserId(email_address, self.user.id))
86                        })
87                        .collect()
88                })
89        }
90
91        #[cfg(not(feature = "entity"))]
92        {
93            self.user
94                .find_related(email_address::Entity)
95                .all(&self.database)
96                .await
97                .map_err(|err| StorageError::Engine(err.to_string()))
98                .map(|email_addresses| {
99                    email_addresses
100                        .into_iter()
101                        .map(EmailAddress::from)
102                        .collect()
103                })
104        }
105    }
106
107    fn additional(&self) -> Option<impl Serialize> {
108        Some(Additional {
109            #[cfg(feature = "entity")]
110            entity_id: self.user.entity_id.to_string(),
111        })
112    }
113}
114
115#[cfg(not(feature = "entity"))]
116impl From<email_address::Model> for EmailAddress {
117    fn from(value: email_address::Model) -> Self {
118        Self {
119            id: value.id.to_string(),
120            email: value.email,
121            is_primary: value.is_primary,
122            is_verified: value.is_verified,
123            verification_token: value.verification_token,
124            verification_token_expired_at: value.verification_token_expired_at,
125            verified_at: value.verified_at,
126            user_id: value.user_id.to_string(),
127        }
128    }
129}
130
/// An email address row paired with the id of the user it should be reported under.
///
/// Converting it into `EmailAddress` uses the paired id for `user_id` instead of any
/// owner column on the row. With the `entity` feature the rows are loaded via the entity,
/// so the user's id is supplied separately.
struct EmailAddressWithUserId(email_address::Model, Uuid);
132
133impl From<EmailAddressWithUserId> for EmailAddress {
134    fn from(EmailAddressWithUserId(value, user_id): EmailAddressWithUserId) -> Self {
135        Self {
136            id: value.id.to_string(),
137            email: value.email,
138            is_primary: value.is_primary,
139            is_verified: value.is_verified,
140            verification_token: value.verification_token,
141            verification_token_expired_at: value.verification_token_expired_at,
142            verified_at: value.verified_at,
143            user_id: user_id.to_string(),
144        }
145    }
146}