1use crate::rbac::{
2 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
3 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions,
4 ResourceRequest,
5 entity::{role::RoleId, user::UserId},
6};
7use crate::{
8 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
9 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
10 TransactionSession, TransactionTrait,
11};
12use std::{
13 pin::Pin,
14 sync::{Arc, RwLock},
15};
16use tracing::instrument;
17
18#[derive(Debug, Clone)]
22#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
23pub struct RestrictedConnection {
24 pub(crate) user_id: UserId,
25 pub(crate) conn: DatabaseConnection,
26}
27
28#[derive(Debug)]
32pub struct RestrictedTransaction {
33 user_id: UserId,
34 conn: DatabaseTransaction,
35 rbac: RbacEngineMount,
36}
37
38#[derive(Debug, Default, Clone)]
39pub(crate) struct RbacEngineMount {
40 inner: Arc<RwLock<Option<RbacEngine>>>,
41}
42
43#[async_trait::async_trait]
44impl ConnectionTrait for RestrictedConnection {
45 fn get_database_backend(&self) -> DbBackend {
46 self.conn.get_database_backend()
47 }
48
49 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
50 Err(DbErr::RbacError(format!(
51 "Raw query is not supported: {stmt}"
52 )))
53 }
54
55 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
56 self.user_can_run(stmt)?;
57 self.conn.execute(stmt).await
58 }
59
60 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
61 Err(DbErr::RbacError(format!(
62 "Raw query is not supported: {sql}"
63 )))
64 }
65
66 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
67 Err(DbErr::RbacError(format!(
68 "Raw query is not supported: {stmt}"
69 )))
70 }
71
72 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
73 self.user_can_run(stmt)?;
74 self.conn.query_one(stmt).await
75 }
76
77 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
78 Err(DbErr::RbacError(format!(
79 "Raw query is not supported: {stmt}"
80 )))
81 }
82
83 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
84 self.user_can_run(stmt)?;
85 self.conn.query_all(stmt).await
86 }
87}
88
89#[async_trait::async_trait]
90impl ConnectionTrait for RestrictedTransaction {
91 fn get_database_backend(&self) -> DbBackend {
92 self.conn.get_database_backend()
93 }
94
95 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
96 Err(DbErr::RbacError(format!(
97 "Raw query is not supported: {stmt}"
98 )))
99 }
100
101 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
102 self.user_can_run(stmt)?;
103 self.conn.execute(stmt).await
104 }
105
106 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
107 Err(DbErr::RbacError(format!(
108 "Raw query is not supported: {sql}"
109 )))
110 }
111
112 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
113 Err(DbErr::RbacError(format!(
114 "Raw query is not supported: {stmt}"
115 )))
116 }
117
118 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
119 self.user_can_run(stmt)?;
120 self.conn.query_one(stmt).await
121 }
122
123 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
124 Err(DbErr::RbacError(format!(
125 "Raw query is not supported: {stmt}"
126 )))
127 }
128
129 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
130 self.user_can_run(stmt)?;
131 self.conn.query_all(stmt).await
132 }
133}
134
135impl RestrictedConnection {
136 pub fn user_id(&self) -> UserId {
138 self.user_id
139 }
140
141 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
144 self.conn.rbac.user_can_run(self.user_id, stmt)
145 }
146
147 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
150 self.conn.rbac.user_role_permissions(self.user_id)
151 }
152
153 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
156 self.conn.rbac.roles_and_ranks()
157 }
158
159 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
161 self.conn.rbac.resources_and_permissions()
162 }
163
164 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
166 self.conn.rbac.role_hierarchy_edges(role_id)
167 }
168
169 pub fn role_permissions_by_resources(
172 &self,
173 role_id: RoleId,
174 ) -> Result<RbacPermissionsByResources, DbErr> {
175 self.conn.rbac.role_permissions_by_resources(role_id)
176 }
177}
178
179impl RestrictedTransaction {
180 pub fn user_id(&self) -> UserId {
182 self.user_id
183 }
184
185 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
188 self.rbac.user_can_run(self.user_id, stmt)
189 }
190}
191
192#[async_trait::async_trait]
193impl TransactionTrait for RestrictedConnection {
194 type Transaction = RestrictedTransaction;
195
196 #[instrument(level = "trace")]
197 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
198 Ok(RestrictedTransaction {
199 user_id: self.user_id,
200 conn: self.conn.begin().await?,
201 rbac: self.conn.rbac.clone(),
202 })
203 }
204
205 #[instrument(level = "trace")]
206 async fn begin_with_config(
207 &self,
208 isolation_level: Option<IsolationLevel>,
209 access_mode: Option<AccessMode>,
210 ) -> Result<RestrictedTransaction, DbErr> {
211 Ok(RestrictedTransaction {
212 user_id: self.user_id,
213 conn: self
214 .conn
215 .begin_with_config(isolation_level, access_mode)
216 .await?,
217 rbac: self.conn.rbac.clone(),
218 })
219 }
220
221 #[instrument(level = "trace", skip(callback))]
224 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
225 where
226 F: for<'c> FnOnce(
227 &'c RestrictedTransaction,
228 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
229 + Send,
230 T: Send,
231 E: std::fmt::Display + std::fmt::Debug + Send,
232 {
233 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
234 transaction.run(callback).await
235 }
236
237 #[instrument(level = "trace", skip(callback))]
240 async fn transaction_with_config<F, T, E>(
241 &self,
242 callback: F,
243 isolation_level: Option<IsolationLevel>,
244 access_mode: Option<AccessMode>,
245 ) -> Result<T, TransactionError<E>>
246 where
247 F: for<'c> FnOnce(
248 &'c RestrictedTransaction,
249 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
250 + Send,
251 T: Send,
252 E: std::fmt::Display + std::fmt::Debug + Send,
253 {
254 let transaction = self
255 .begin_with_config(isolation_level, access_mode)
256 .await
257 .map_err(TransactionError::Connection)?;
258 transaction.run(callback).await
259 }
260}
261
262#[async_trait::async_trait]
263impl TransactionTrait for RestrictedTransaction {
264 type Transaction = RestrictedTransaction;
265
266 #[instrument(level = "trace")]
267 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
268 Ok(RestrictedTransaction {
269 user_id: self.user_id,
270 conn: self.conn.begin().await?,
271 rbac: self.rbac.clone(),
272 })
273 }
274
275 #[instrument(level = "trace")]
276 async fn begin_with_config(
277 &self,
278 isolation_level: Option<IsolationLevel>,
279 access_mode: Option<AccessMode>,
280 ) -> Result<RestrictedTransaction, DbErr> {
281 Ok(RestrictedTransaction {
282 user_id: self.user_id,
283 conn: self
284 .conn
285 .begin_with_config(isolation_level, access_mode)
286 .await?,
287 rbac: self.rbac.clone(),
288 })
289 }
290
291 #[instrument(level = "trace", skip(callback))]
294 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
295 where
296 F: for<'c> FnOnce(
297 &'c RestrictedTransaction,
298 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
299 + Send,
300 T: Send,
301 E: std::fmt::Display + std::fmt::Debug + Send,
302 {
303 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
304 transaction.run(callback).await
305 }
306
307 #[instrument(level = "trace", skip(callback))]
310 async fn transaction_with_config<F, T, E>(
311 &self,
312 callback: F,
313 isolation_level: Option<IsolationLevel>,
314 access_mode: Option<AccessMode>,
315 ) -> Result<T, TransactionError<E>>
316 where
317 F: for<'c> FnOnce(
318 &'c RestrictedTransaction,
319 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
320 + Send,
321 T: Send,
322 E: std::fmt::Display + std::fmt::Debug + Send,
323 {
324 let transaction = self
325 .begin_with_config(isolation_level, access_mode)
326 .await
327 .map_err(TransactionError::Connection)?;
328 transaction.run(callback).await
329 }
330}
331
332#[async_trait::async_trait]
333impl TransactionSession for RestrictedTransaction {
334 async fn commit(self) -> Result<(), DbErr> {
335 self.commit().await
336 }
337
338 async fn rollback(self) -> Result<(), DbErr> {
339 self.rollback().await
340 }
341}
342
343impl RestrictedTransaction {
344 #[instrument(level = "trace", skip(callback))]
347 async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
348 where
349 F: for<'b> FnOnce(
350 &'b RestrictedTransaction,
351 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
352 + Send,
353 T: Send,
354 E: std::fmt::Display + std::fmt::Debug + Send,
355 {
356 let res = callback(&self).await.map_err(TransactionError::Transaction);
357 if res.is_ok() {
358 self.commit().await.map_err(TransactionError::Connection)?;
359 } else {
360 self.rollback()
361 .await
362 .map_err(TransactionError::Connection)?;
363 }
364 res
365 }
366
367 #[instrument(level = "trace")]
369 pub async fn commit(self) -> Result<(), DbErr> {
370 self.conn.commit().await
371 }
372
373 #[instrument(level = "trace")]
375 pub async fn rollback(self) -> Result<(), DbErr> {
376 self.conn.rollback().await
377 }
378}
379
380impl RbacEngineMount {
381 pub fn is_some(&self) -> bool {
382 let engine = self.inner.read().expect("RBAC Engine died");
383 engine.is_some()
384 }
385
386 pub fn replace(&self, engine: RbacEngine) {
387 let mut inner = self.inner.write().expect("RBAC Engine died");
388 *inner = Some(engine);
389 }
390
391 pub fn user_can_run<S: StatementBuilder>(
392 &self,
393 user_id: UserId,
394 stmt: &S,
395 ) -> Result<(), DbErr> {
396 let audit = match stmt.audit() {
397 Ok(audit) => audit,
398 Err(err) => return Err(DbErr::RbacError(err.to_string())),
399 };
400 for request in audit.requests {
401 let holder = self.inner.read().expect("RBAC Engine died");
403 let engine = holder.as_ref().expect("RBAC Engine not set");
405 let permission = || PermissionRequest {
406 action: request.access_type.as_str().to_owned(),
407 };
408 let resource = || ResourceRequest {
409 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
410 table: request.schema_table.1.to_string(),
411 };
412 if !engine
413 .user_can(user_id, permission(), resource())
414 .map_err(map_err)?
415 {
416 let r = resource();
417 return Err(DbErr::AccessDenied {
418 permission: permission().action.to_owned(),
419 resource: format!(
420 "{}{}{}",
421 if let Some(schema) = &r.schema {
422 schema
423 } else {
424 ""
425 },
426 if r.schema.is_some() { "." } else { "" },
427 r.table
428 ),
429 });
430 }
431 }
432 Ok(())
433 }
434
435 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
436 let holder = self.inner.read().expect("RBAC Engine died");
437 let engine = holder.as_ref().expect("RBAC Engine not set");
438 engine
439 .get_user_role_permissions(user_id)
440 .map_err(|err| DbErr::RbacError(err.to_string()))
441 }
442
443 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
444 let holder = self.inner.read().expect("RBAC Engine died");
445 let engine = holder.as_ref().expect("RBAC Engine not set");
446 engine
447 .get_roles_and_ranks()
448 .map_err(|err| DbErr::RbacError(err.to_string()))
449 }
450
451 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
452 let holder = self.inner.read().expect("RBAC Engine died");
453 let engine = holder.as_ref().expect("RBAC Engine not set");
454 Ok(engine.list_resources_and_permissions())
455 }
456
457 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
458 let holder = self.inner.read().expect("RBAC Engine died");
459 let engine = holder.as_ref().expect("RBAC Engine not set");
460 Ok(engine.list_role_hierarchy_edges(role_id))
461 }
462
463 pub fn role_permissions_by_resources(
464 &self,
465 role_id: RoleId,
466 ) -> Result<RbacPermissionsByResources, DbErr> {
467 let holder = self.inner.read().expect("RBAC Engine died");
468 let engine = holder.as_ref().expect("RBAC Engine not set");
469 engine
470 .list_role_permissions_by_resources(role_id)
471 .map_err(|err| DbErr::RbacError(err.to_string()))
472 }
473}
474
475fn map_err(err: RbacError) -> DbErr {
476 DbErr::RbacError(err.to_string())
477}