Skip to main content

toro_auth_mongo/
lib.rs

1use async_trait::async_trait;
2use futures::TryStreamExt;
3use mongodb::{Client, Collection, Database, bson::doc};
4use serde::{Deserialize, Serialize};
5use std::{marker::PhantomData, str::FromStr};
6use toro_auth_core::{
7    ObjectId,
8    identity::{IdentityBackend, IdentityError},
9    session::{Session, SessionBackend, SessionError},
10};
11use uuid::Uuid;
12
/// Errors that can occur while constructing a [`MongoBackend`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MongoInitError {
    /// Creating the MongoDB client from the supplied connection URL failed.
    FailedToConnect,
}

impl std::fmt::Display for MongoInitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MongoInitError::FailedToConnect => f.write_str("failed to connect to MongoDB"),
        }
    }
}

// Implementing `Error` lets callers propagate this with `?` into `Box<dyn Error>`
// or similar error-aggregating types.
impl std::error::Error for MongoInitError {}
17
18#[derive(Clone)]
19pub struct MongoBackend<
20    T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
21> {
22    _mapper: PhantomData<T>,
23    identity_db: Collection<T>,
24    session_db: Collection<Session<T>>,
25}
26
27impl<T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static>
28    MongoBackend<T>
29{
30    pub fn new(db: Database) -> Self {
31        Self {
32            _mapper: PhantomData,
33            identity_db: db.collection("identity"),
34            session_db: db.collection("session"),
35        }
36    }
37
38    pub async fn from_url(url: String, db_name: String) -> Result<Self, MongoInitError> {
39        let client = Client::with_uri_str(url).await.map_err(|e| {
40            eprintln!("{e:#?}");
41            MongoInitError::FailedToConnect
42        })?;
43        let db = client.database(&db_name);
44        Ok(Self::new(db))
45    }
46
47    pub async fn search_identity(&self, username: String) -> Result<Vec<T>, IdentityError> {
48        let mut res = match self
49            .identity_db
50            .find(doc! {
51                "name": {
52                    "$regex": username,
53                    "$options": "i"
54                }
55            })
56            .await
57        {
58            Ok(res) => res,
59            Err(e) => {
60                eprintln!("{e}");
61                return Err(IdentityError::InternalServerError);
62            }
63        };
64
65        let mut identities = Vec::new();
66        while let Some(identity) = res.try_next().await.map_err(|e| {
67            eprintln!("{e:#?}");
68            IdentityError::InternalServerError
69        })? {
70            identities.push(identity);
71        }
72
73        Ok(identities)
74    }
75}
76
77#[async_trait]
78impl<T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static>
79    SessionBackend<T> for MongoBackend<T>
80{
81    async fn login(&self, username: String, password: String) -> Result<Session<T>, SessionError> {
82        let res = match self
83            .identity_db
84            .find_one(doc! {
85                "username": username,
86                "password": password
87            })
88            .await
89        {
90            Ok(res) => res,
91            Err(e) => {
92                eprintln!("{e}");
93                return Err(SessionError::InternalServerError);
94            }
95        };
96        let Some(identity) = res else {
97            return Err(SessionError::InvalidLogin);
98        };
99
100        let Some(user_id) = identity.id() else {
101            return Err(SessionError::InternalServerError);
102        };
103
104        let session = Session::new(Uuid::new_v4().into(), user_id.into());
105
106        let _ = match self.session_db.insert_one(session.clone()).await {
107            Ok(res) => res,
108            Err(e) => {
109                eprintln!("{e}");
110                return Err(SessionError::InternalServerError);
111            }
112        };
113
114        Ok(session)
115    }
116
117    async fn validate(&self, session_id: String) -> Result<T, SessionError> {
118        let res = match self
119            .session_db
120            .find_one(doc! {
121                "id": {
122                    "$eq": session_id
123                }
124            })
125            .await
126        {
127            Ok(res) => res,
128            Err(e) => {
129                eprintln!("{e}");
130                return Err(SessionError::InternalServerError);
131            }
132        };
133        let Some(session) = res else {
134            return Err(SessionError::InvalidOrMissingSession);
135        };
136
137        let identity = self.get_by_id(session.user_id).await.map_err(|_| {
138            eprintln!("Couldn't find related user");
139            SessionError::UserNotFound
140        })?;
141
142        Ok(identity)
143    }
144}
145
146#[async_trait]
147impl<T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static>
148    IdentityBackend<T> for MongoBackend<T>
149{
150    async fn get_all(&self) -> Result<Vec<T>, IdentityError> {
151        let mut res = match self.identity_db.find(doc! {}).await {
152            Ok(res) => res,
153            Err(e) => {
154                eprintln!("{e}");
155                return Err(IdentityError::InternalServerError);
156            }
157        };
158
159        let mut identities = Vec::new();
160        while let Some(identity) = res.try_next().await.map_err(|e| {
161            eprintln!("{e:#?}");
162            IdentityError::InternalServerError
163        })? {
164            identities.push(identity);
165        }
166
167        Ok(identities)
168    }
169
170    async fn create(&self, mut identity: T) -> Result<(), IdentityError> {
171        identity.set_id(Uuid::new_v4());
172
173        self.identity_db
174            .insert_one(identity.clone())
175            .await
176            .map_err(|e| {
177                eprintln!("{e:#?}");
178                IdentityError::InternalServerError
179            })?;
180
181        Ok(())
182    }
183
184    async fn get_by_username(&self, username: String) -> Result<Option<T>, IdentityError> {
185        self.identity_db
186            .find_one(doc! {
187                "username": {
188                    "$eq": username
189                }
190            })
191            .await
192            .map_err(|e| {
193                eprintln!("{e:#?}");
194                IdentityError::InternalServerError
195            })
196    }
197
198    async fn get_by_id(&self, id: String) -> Result<T, IdentityError> {
199        let res = match self
200            .identity_db
201            .find_one(doc! {
202                "id": {
203                    "$eq": id
204                }
205            })
206            .await
207        {
208            Ok(res) => res,
209            Err(e) => {
210                eprintln!("{e}");
211                return Err(IdentityError::InternalServerError);
212            }
213        };
214        let Some(identity) = res else {
215            return Err(IdentityError::NotFound);
216        };
217        Ok(identity)
218    }
219
220    async fn update_by_id(&self, id: String, identity: T) -> Result<(), IdentityError> {
221        let mut identity = identity;
222        identity.set_id(Uuid::from_str(&id).map_err(|_| IdentityError::InvalidId)?);
223        let res = match self
224            .identity_db
225            .replace_one(
226                doc! {
227                    "id": {
228                        "$eq": id
229                    }
230                },
231                identity,
232            )
233            .await
234        {
235            Ok(res) => res,
236            Err(e) => {
237                eprintln!("{e:#?}");
238                return Err(IdentityError::InternalServerError);
239            }
240        };
241
242        if res.matched_count <= 0 && res.modified_count <= 0 {
243            return Err(IdentityError::NotFound);
244        }
245
246        Ok(())
247    }
248
249    async fn delete_by_id(&self, id: String) -> Result<(), IdentityError> {
250        let res = match self
251            .identity_db
252            .delete_one(doc! {
253                "id": {
254                    "$eq": id
255                }
256            })
257            .await
258        {
259            Ok(res) => res,
260            Err(e) => {
261                eprintln!("{e}");
262                return Err(IdentityError::InternalServerError);
263            }
264        };
265
266        match res.deleted_count {
267            0 => Err(IdentityError::NotFound),
268            _ => Ok(()),
269        }
270    }
271}