1use async_trait::async_trait;
2use futures::TryStreamExt;
3use mongodb::{Client, Collection, Database, bson::doc};
4use serde::{Deserialize, Serialize};
5use std::{marker::PhantomData, str::FromStr};
6use toro_auth_core::{
7 ObjectId,
8 identity::{IdentityBackend, IdentityError},
9 session::{Session, SessionBackend, SessionError},
10};
11use uuid::Uuid;
12
/// Errors that can occur while constructing a `MongoBackend`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MongoInitError {
    /// A MongoDB client could not be created from the supplied connection string.
    /// `MongoBackend::from_url` prints the underlying driver error to stderr
    /// before returning this variant.
    FailedToConnect,
}

impl std::fmt::Display for MongoInitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FailedToConnect => f.write_str("failed to connect to MongoDB"),
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>`-style error types.
impl std::error::Error for MongoInitError {}
17
18#[derive(Clone)]
19pub struct MongoBackend<
20 T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
21> {
22 _mapper: PhantomData<T>,
23 identity_db: Collection<T>,
24 session_db: Collection<Session<T>>,
25}
26
27impl<T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static>
28 MongoBackend<T>
29{
30 pub fn new(db: Database) -> Self {
31 Self {
32 _mapper: PhantomData,
33 identity_db: db.collection("identity"),
34 session_db: db.collection("session"),
35 }
36 }
37
38 pub async fn from_url(url: String, db_name: String) -> Result<Self, MongoInitError> {
39 let client = Client::with_uri_str(url).await.map_err(|e| {
40 eprintln!("{e:#?}");
41 MongoInitError::FailedToConnect
42 })?;
43 let db = client.database(&db_name);
44 Ok(Self::new(db))
45 }
46
47 pub async fn search_identity(&self, username: String) -> Result<Vec<T>, IdentityError> {
48 let mut res = match self
49 .identity_db
50 .find(doc! {
51 "name": {
52 "$regex": username,
53 "$options": "i"
54 }
55 })
56 .await
57 {
58 Ok(res) => res,
59 Err(e) => {
60 eprintln!("{e}");
61 return Err(IdentityError::InternalServerError);
62 }
63 };
64
65 let mut identities = Vec::new();
66 while let Some(identity) = res.try_next().await.map_err(|e| {
67 eprintln!("{e:#?}");
68 IdentityError::InternalServerError
69 })? {
70 identities.push(identity);
71 }
72
73 Ok(identities)
74 }
75}
76
77#[async_trait]
78impl<T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static>
79 SessionBackend<T> for MongoBackend<T>
80{
81 async fn login(&self, username: String, password: String) -> Result<Session<T>, SessionError> {
82 let res = match self
83 .identity_db
84 .find_one(doc! {
85 "username": username,
86 "password": password
87 })
88 .await
89 {
90 Ok(res) => res,
91 Err(e) => {
92 eprintln!("{e}");
93 return Err(SessionError::InternalServerError);
94 }
95 };
96 let Some(identity) = res else {
97 return Err(SessionError::InvalidLogin);
98 };
99
100 let Some(user_id) = identity.id() else {
101 return Err(SessionError::InternalServerError);
102 };
103
104 let session = Session::new(Uuid::new_v4().into(), user_id.into());
105
106 let _ = match self.session_db.insert_one(session.clone()).await {
107 Ok(res) => res,
108 Err(e) => {
109 eprintln!("{e}");
110 return Err(SessionError::InternalServerError);
111 }
112 };
113
114 Ok(session)
115 }
116
117 async fn validate(&self, session_id: String) -> Result<T, SessionError> {
118 let res = match self
119 .session_db
120 .find_one(doc! {
121 "id": {
122 "$eq": session_id
123 }
124 })
125 .await
126 {
127 Ok(res) => res,
128 Err(e) => {
129 eprintln!("{e}");
130 return Err(SessionError::InternalServerError);
131 }
132 };
133 let Some(session) = res else {
134 return Err(SessionError::InvalidOrMissingSession);
135 };
136
137 let identity = self.get_by_id(session.user_id).await.map_err(|_| {
138 eprintln!("Couldn't find related user");
139 SessionError::UserNotFound
140 })?;
141
142 Ok(identity)
143 }
144}
145
146#[async_trait]
147impl<T: ObjectId + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static>
148 IdentityBackend<T> for MongoBackend<T>
149{
150 async fn get_all(&self) -> Result<Vec<T>, IdentityError> {
151 let mut res = match self.identity_db.find(doc! {}).await {
152 Ok(res) => res,
153 Err(e) => {
154 eprintln!("{e}");
155 return Err(IdentityError::InternalServerError);
156 }
157 };
158
159 let mut identities = Vec::new();
160 while let Some(identity) = res.try_next().await.map_err(|e| {
161 eprintln!("{e:#?}");
162 IdentityError::InternalServerError
163 })? {
164 identities.push(identity);
165 }
166
167 Ok(identities)
168 }
169
170 async fn create(&self, mut identity: T) -> Result<(), IdentityError> {
171 identity.set_id(Uuid::new_v4());
172
173 self.identity_db
174 .insert_one(identity.clone())
175 .await
176 .map_err(|e| {
177 eprintln!("{e:#?}");
178 IdentityError::InternalServerError
179 })?;
180
181 Ok(())
182 }
183
184 async fn get_by_username(&self, username: String) -> Result<Option<T>, IdentityError> {
185 self.identity_db
186 .find_one(doc! {
187 "username": {
188 "$eq": username
189 }
190 })
191 .await
192 .map_err(|e| {
193 eprintln!("{e:#?}");
194 IdentityError::InternalServerError
195 })
196 }
197
198 async fn get_by_id(&self, id: String) -> Result<T, IdentityError> {
199 let res = match self
200 .identity_db
201 .find_one(doc! {
202 "id": {
203 "$eq": id
204 }
205 })
206 .await
207 {
208 Ok(res) => res,
209 Err(e) => {
210 eprintln!("{e}");
211 return Err(IdentityError::InternalServerError);
212 }
213 };
214 let Some(identity) = res else {
215 return Err(IdentityError::NotFound);
216 };
217 Ok(identity)
218 }
219
220 async fn update_by_id(&self, id: String, identity: T) -> Result<(), IdentityError> {
221 let mut identity = identity;
222 identity.set_id(Uuid::from_str(&id).map_err(|_| IdentityError::InvalidId)?);
223 let res = match self
224 .identity_db
225 .replace_one(
226 doc! {
227 "id": {
228 "$eq": id
229 }
230 },
231 identity,
232 )
233 .await
234 {
235 Ok(res) => res,
236 Err(e) => {
237 eprintln!("{e:#?}");
238 return Err(IdentityError::InternalServerError);
239 }
240 };
241
242 if res.matched_count <= 0 && res.modified_count <= 0 {
243 return Err(IdentityError::NotFound);
244 }
245
246 Ok(())
247 }
248
249 async fn delete_by_id(&self, id: String) -> Result<(), IdentityError> {
250 let res = match self
251 .identity_db
252 .delete_one(doc! {
253 "id": {
254 "$eq": id
255 }
256 })
257 .await
258 {
259 Ok(res) => res,
260 Err(e) => {
261 eprintln!("{e}");
262 return Err(IdentityError::InternalServerError);
263 }
264 };
265
266 match res.deleted_count {
267 0 => Err(IdentityError::NotFound),
268 _ => Ok(()),
269 }
270 }
271}