userman_auth/
lib.rs

1use futures::StreamExt;
2use futures::TryStreamExt;
3use haikunator::Haikunator;
4use log::error;
5use mongodb::bson::doc;
6use mongodb::bson::oid::ObjectId;
7use mongodb::options::ClientOptions;
8use mongodb::{Client, Database};
9use role::RoleItems;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14pub mod app;
15mod error;
16pub mod role;
17
18use app::App;
19use role::Role;
20
21pub use error::AuthError;
22pub type Result<T> = std::result::Result<T, AuthError>;
23
24const APPS: &str = "apps";
25const ROLES: &str = "roles";
26
27fn serialize_oid_as_string<S>(oid: &ObjectId, serializer: S) -> std::result::Result<S::Ok, S::Error>
28where
29    S: serde::Serializer,
30{
31    serializer.serialize_str(oid.to_string().as_str())
32}
33
34fn serialize_option_oid_as_string<S>(
35    oid: &Option<ObjectId>,
36    serializer: S,
37) -> std::result::Result<S::Ok, S::Error>
38where
39    S: serde::Serializer,
40{
41    match oid {
42        Some(ref t) => serializer.serialize_some(t.to_string().as_str()),
43        None => serializer.serialize_none(),
44    }
45}
46
47#[derive(Clone, Debug)]
48pub struct Roles(Arc<RwLock<HashMap<String, Role>>>);
49
50impl Default for Roles {
51    fn default() -> Self {
52        Self(Arc::new(RwLock::new(HashMap::new())))
53    }
54}
55
56impl Roles {
57    async fn set(&self, src: HashMap<String, Role>) {
58        let mut lock = self.0.write().await;
59        *lock = src;
60    }
61
62    async fn get<'r, T: Into<&'r str>>(&self, name: T) -> Option<Role> {
63        let lock = self.0.read().await;
64        lock.get(name.into()).cloned()
65    }
66}
67
68#[derive(Clone, Debug)]
69pub struct Auth {
70    roles: Roles,
71    database: Database,
72    app_name: String,
73}
74
75impl Auth {
76    pub async fn add_role_items(&self, role_names: Vec<String>) -> RoleItems {
77        let mut parent = RoleItems::default();
78
79        for name in role_names {
80            if let Some(role) = self.roles.get(name.as_str()).await {
81                role.items.add(&mut parent);
82            }
83        }
84
85        parent
86    }
87}
88
89#[derive(Debug)]
90pub struct MongoDB {
91    uri: String,
92    db_name: String,
93    client_name: String,
94}
95
96impl Default for MongoDB {
97    fn default() -> Self {
98        Self {
99            uri: String::from("mongodb://localhost:27017"),
100            db_name: String::from("umt"),
101            client_name: Haikunator::default().haikunate(),
102        }
103    }
104}
105
106#[derive(Debug)]
107pub struct AuthBuilder {
108    mongodb: MongoDB,
109    app_name: String,
110}
111
112impl AuthBuilder {
113    pub fn mongodb_uri<T: Into<String>>(&mut self, src: T) -> &mut Self {
114        self.mongodb.uri = src.into();
115        self
116    }
117
118    pub fn mongodb_db_name<T: Into<String>>(&mut self, src: T) -> &mut Self {
119        self.mongodb.db_name = src.into();
120        self
121    }
122
123    pub fn mongodb_app_name<T: Into<String>>(&mut self, src: T) -> &mut Self {
124        self.mongodb.client_name = src.into();
125        self
126    }
127
128    pub async fn build(self) -> Result<Auth> {
129        let mut client_options = ClientOptions::parse(&self.mongodb.uri)
130            .await
131            .map_err(AuthError::MongoParseUri)?;
132
133        client_options.app_name = Some(self.mongodb.client_name.to_owned());
134
135        let client = Client::with_options(client_options).map_err(AuthError::MongoCreateClient)?;
136
137        let database = client.database(&self.mongodb.db_name);
138
139        Ok(Auth {
140            roles: Roles::default(),
141            database,
142            app_name: self.app_name,
143        })
144    }
145}
146
147impl Auth {
148    pub fn builder<T: Into<String>>(app_name: T) -> AuthBuilder {
149        AuthBuilder {
150            mongodb: MongoDB::default(),
151            app_name: app_name.into(),
152        }
153    }
154
155    async fn update_roles(&self) -> Result<()> {
156        // get app id
157        let app = self
158            .database
159            .collection::<App>(APPS)
160            .find_one(doc! { "name": &self.app_name }, None)
161            .await
162            .map_err(AuthError::MongoFindOne)?;
163
164        match app {
165            Some(t) => {
166                let mut cursor = self
167                    .database
168                    .collection::<Role>(ROLES)
169                    .find(
170                        doc! {
171                            "app": t.id()
172                        },
173                        None,
174                    )
175                    .await
176                    .map_err(AuthError::MongoFind)?;
177
178                let mut roles = HashMap::new();
179
180                while let Some(role) = cursor
181                    .try_next()
182                    .await
183                    .map_err(AuthError::MongoReadCursor)?
184                {
185                    roles.insert(role.name.clone(), role);
186                }
187
188                self.roles.set(roles).await;
189
190                Ok(())
191            }
192            None => Err(AuthError::MissingAppInDatabase),
193        }
194    }
195
196    pub async fn init(&self) -> Result<()> {
197        self.update_roles().await?;
198
199        let ref_self = self.clone();
200
201        tokio::spawn(async move {
202            let mut change_stream = match ref_self
203                .database
204                .collection::<Role>(ROLES)
205                .watch(vec![], None)
206                .await
207                .map_err(AuthError::MongoWatchChangeStream)
208            {
209                Ok(t) => t,
210                Err(err) => {
211                    return error!("{}", err);
212                }
213            };
214
215            while let Some(Ok(_)) = change_stream.next().await {
216                if let Err(err) = ref_self.update_roles().await {
217                    error!("{}", err);
218                }
219            }
220        });
221
222        Ok(())
223    }
224}