1use crate::rbac::{
2 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
3 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions,
4 ResourceRequest,
5 entity::{role::RoleId, user::UserId},
6};
7use crate::{
8 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
9 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
10 TransactionSession, TransactionTrait,
11};
12use std::{
13 pin::Pin,
14 sync::{Arc, RwLock},
15};
16use tracing::instrument;
17
/// A database connection bound to a single user.
///
/// Statements built through a [`StatementBuilder`] are checked against the
/// RBAC engine mounted on the wrapped connection before they are forwarded;
/// raw SQL (`Statement` / `&str`) is rejected outright.
#[derive(Debug, Clone)]
#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
pub struct RestrictedConnection {
    /// The user whose permissions every statement is checked against.
    pub(crate) user_id: UserId,
    /// The unrestricted connection that allowed statements are forwarded to;
    /// it also carries the RBAC engine mount (`conn.rbac`).
    pub(crate) conn: DatabaseConnection,
}
27
/// A transaction opened from a [`RestrictedConnection`] (or nested inside
/// another `RestrictedTransaction`) that keeps enforcing the same user's
/// permissions on every builder-based statement and rejects raw SQL.
#[derive(Debug)]
pub struct RestrictedTransaction {
    /// The user whose permissions every statement is checked against.
    user_id: UserId,
    /// The underlying transaction that allowed statements are forwarded to.
    conn: DatabaseTransaction,
    /// A clone of the mount of the connection/transaction this was opened
    /// from; clones share the same engine slot.
    rbac: RbacEngineMount,
}
37
/// Shared slot holding the [`RbacEngine`] used to authorize statements.
///
/// The slot is behind an `Arc`, so cloning the mount is cheap and every clone
/// refers to the same slot: an engine installed through one clone is visible
/// through all of them. A `Default` mount starts out empty (`None`).
#[derive(Debug, Default, Clone)]
pub(crate) struct RbacEngineMount {
    /// `None` until an engine is installed; guarded by a read/write lock so
    /// many readers can authorize concurrently.
    inner: Arc<RwLock<Option<RbacEngine>>>,
}
42
43#[async_trait::async_trait]
44impl ConnectionTrait for RestrictedConnection {
45 fn get_database_backend(&self) -> DbBackend {
46 self.conn.get_database_backend()
47 }
48
49 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
50 Err(DbErr::RbacError(format!(
51 "Raw query is not supported: {stmt}"
52 )))
53 }
54
55 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
56 self.user_can_run(stmt)?;
57 self.conn.execute(stmt).await
58 }
59
60 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
61 Err(DbErr::RbacError(format!(
62 "Raw query is not supported: {sql}"
63 )))
64 }
65
66 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
67 Err(DbErr::RbacError(format!(
68 "Raw query is not supported: {stmt}"
69 )))
70 }
71
72 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
73 self.user_can_run(stmt)?;
74 self.conn.query_one(stmt).await
75 }
76
77 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
78 Err(DbErr::RbacError(format!(
79 "Raw query is not supported: {stmt}"
80 )))
81 }
82
83 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
84 self.user_can_run(stmt)?;
85 self.conn.query_all(stmt).await
86 }
87}
88
89#[async_trait::async_trait]
90impl ConnectionTrait for RestrictedTransaction {
91 fn get_database_backend(&self) -> DbBackend {
92 self.conn.get_database_backend()
93 }
94
95 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
96 Err(DbErr::RbacError(format!(
97 "Raw query is not supported: {stmt}"
98 )))
99 }
100
101 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
102 self.user_can_run(stmt)?;
103 self.conn.execute(stmt).await
104 }
105
106 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
107 Err(DbErr::RbacError(format!(
108 "Raw query is not supported: {sql}"
109 )))
110 }
111
112 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
113 Err(DbErr::RbacError(format!(
114 "Raw query is not supported: {stmt}"
115 )))
116 }
117
118 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
119 self.user_can_run(stmt)?;
120 self.conn.query_one(stmt).await
121 }
122
123 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
124 Err(DbErr::RbacError(format!(
125 "Raw query is not supported: {stmt}"
126 )))
127 }
128
129 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
130 self.user_can_run(stmt)?;
131 self.conn.query_all(stmt).await
132 }
133}
134
135impl RestrictedConnection {
136 pub fn user_id(&self) -> UserId {
138 self.user_id
139 }
140
141 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
144 self.conn.rbac.user_can_run(self.user_id, stmt)
145 }
146
147 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
149 where
150 P: Into<PermissionRequest>,
151 R: Into<ResourceRequest>,
152 {
153 self.conn.rbac.user_can(self.user_id, permission, resource)
154 }
155
156 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
159 self.conn.rbac.user_role_permissions(self.user_id)
160 }
161
162 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
165 self.conn.rbac.roles_and_ranks()
166 }
167
168 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
170 self.conn.rbac.resources_and_permissions()
171 }
172
173 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
175 self.conn.rbac.role_hierarchy_edges(role_id)
176 }
177
178 pub fn role_permissions_by_resources(
181 &self,
182 role_id: RoleId,
183 ) -> Result<RbacPermissionsByResources, DbErr> {
184 self.conn.rbac.role_permissions_by_resources(role_id)
185 }
186}
187
188impl RestrictedTransaction {
189 pub fn user_id(&self) -> UserId {
191 self.user_id
192 }
193
194 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
197 self.rbac.user_can_run(self.user_id, stmt)
198 }
199
200 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
202 where
203 P: Into<PermissionRequest>,
204 R: Into<ResourceRequest>,
205 {
206 self.rbac.user_can(self.user_id, permission, resource)
207 }
208}
209
210#[async_trait::async_trait]
211impl TransactionTrait for RestrictedConnection {
212 type Transaction = RestrictedTransaction;
213
214 #[instrument(level = "trace")]
215 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
216 Ok(RestrictedTransaction {
217 user_id: self.user_id,
218 conn: self.conn.begin().await?,
219 rbac: self.conn.rbac.clone(),
220 })
221 }
222
223 #[instrument(level = "trace")]
224 async fn begin_with_config(
225 &self,
226 isolation_level: Option<IsolationLevel>,
227 access_mode: Option<AccessMode>,
228 ) -> Result<RestrictedTransaction, DbErr> {
229 Ok(RestrictedTransaction {
230 user_id: self.user_id,
231 conn: self
232 .conn
233 .begin_with_config(isolation_level, access_mode)
234 .await?,
235 rbac: self.conn.rbac.clone(),
236 })
237 }
238
239 #[instrument(level = "trace", skip(callback))]
242 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
243 where
244 F: for<'c> FnOnce(
245 &'c RestrictedTransaction,
246 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
247 + Send,
248 T: Send,
249 E: std::fmt::Display + std::fmt::Debug + Send,
250 {
251 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
252 transaction.run(callback).await
253 }
254
255 #[instrument(level = "trace", skip(callback))]
258 async fn transaction_with_config<F, T, E>(
259 &self,
260 callback: F,
261 isolation_level: Option<IsolationLevel>,
262 access_mode: Option<AccessMode>,
263 ) -> Result<T, TransactionError<E>>
264 where
265 F: for<'c> FnOnce(
266 &'c RestrictedTransaction,
267 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
268 + Send,
269 T: Send,
270 E: std::fmt::Display + std::fmt::Debug + Send,
271 {
272 let transaction = self
273 .begin_with_config(isolation_level, access_mode)
274 .await
275 .map_err(TransactionError::Connection)?;
276 transaction.run(callback).await
277 }
278}
279
280#[async_trait::async_trait]
281impl TransactionTrait for RestrictedTransaction {
282 type Transaction = RestrictedTransaction;
283
284 #[instrument(level = "trace")]
285 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
286 Ok(RestrictedTransaction {
287 user_id: self.user_id,
288 conn: self.conn.begin().await?,
289 rbac: self.rbac.clone(),
290 })
291 }
292
293 #[instrument(level = "trace")]
294 async fn begin_with_config(
295 &self,
296 isolation_level: Option<IsolationLevel>,
297 access_mode: Option<AccessMode>,
298 ) -> Result<RestrictedTransaction, DbErr> {
299 Ok(RestrictedTransaction {
300 user_id: self.user_id,
301 conn: self
302 .conn
303 .begin_with_config(isolation_level, access_mode)
304 .await?,
305 rbac: self.rbac.clone(),
306 })
307 }
308
309 #[instrument(level = "trace", skip(callback))]
312 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
313 where
314 F: for<'c> FnOnce(
315 &'c RestrictedTransaction,
316 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
317 + Send,
318 T: Send,
319 E: std::fmt::Display + std::fmt::Debug + Send,
320 {
321 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
322 transaction.run(callback).await
323 }
324
325 #[instrument(level = "trace", skip(callback))]
328 async fn transaction_with_config<F, T, E>(
329 &self,
330 callback: F,
331 isolation_level: Option<IsolationLevel>,
332 access_mode: Option<AccessMode>,
333 ) -> Result<T, TransactionError<E>>
334 where
335 F: for<'c> FnOnce(
336 &'c RestrictedTransaction,
337 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
338 + Send,
339 T: Send,
340 E: std::fmt::Display + std::fmt::Debug + Send,
341 {
342 let transaction = self
343 .begin_with_config(isolation_level, access_mode)
344 .await
345 .map_err(TransactionError::Connection)?;
346 transaction.run(callback).await
347 }
348}
349
/// Exposes the inherent `RestrictedTransaction::commit` / `rollback` through
/// the [`TransactionSession`] trait.
#[async_trait::async_trait]
impl TransactionSession for RestrictedTransaction {
    /// Commits the underlying transaction.
    async fn commit(self) -> Result<(), DbErr> {
        // Method-call syntax picks the inherent `commit` (inherent methods are
        // tried before trait methods for the same receiver type), so this
        // delegates instead of recursing into the trait method.
        self.commit().await
    }

    /// Rolls back the underlying transaction.
    async fn rollback(self) -> Result<(), DbErr> {
        // Same resolution as above: this calls the inherent `rollback`.
        self.rollback().await
    }
}
360
361impl RestrictedTransaction {
362 #[instrument(level = "trace", skip(callback))]
365 async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
366 where
367 F: for<'b> FnOnce(
368 &'b RestrictedTransaction,
369 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
370 + Send,
371 T: Send,
372 E: std::fmt::Display + std::fmt::Debug + Send,
373 {
374 let res = callback(&self).await.map_err(TransactionError::Transaction);
375 if res.is_ok() {
376 self.commit().await.map_err(TransactionError::Connection)?;
377 } else {
378 self.rollback()
379 .await
380 .map_err(TransactionError::Connection)?;
381 }
382 res
383 }
384
385 #[instrument(level = "trace")]
387 pub async fn commit(self) -> Result<(), DbErr> {
388 self.conn.commit().await
389 }
390
391 #[instrument(level = "trace")]
393 pub async fn rollback(self) -> Result<(), DbErr> {
394 self.conn.rollback().await
395 }
396}
397
398impl RbacEngineMount {
399 pub fn is_some(&self) -> bool {
400 let engine = self.inner.read().expect("RBAC Engine died");
401 engine.is_some()
402 }
403
404 pub fn replace(&self, engine: RbacEngine) {
405 let mut inner = self.inner.write().expect("RBAC Engine died");
406 *inner = Some(engine);
407 }
408
409 pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
410 where
411 P: Into<PermissionRequest>,
412 R: Into<ResourceRequest>,
413 {
414 let permission = permission.into();
415 let resource = resource.into();
416 let holder = self.inner.read().expect("RBAC Engine died");
418 let engine = holder.as_ref().expect("RBAC Engine not set");
420 engine
421 .user_can(user_id, permission, resource)
422 .map_err(map_err)
423 }
424
425 pub fn user_can_run<S: StatementBuilder>(
426 &self,
427 user_id: UserId,
428 stmt: &S,
429 ) -> Result<(), DbErr> {
430 let audit = match stmt.audit() {
431 Ok(audit) => audit,
432 Err(err) => return Err(DbErr::RbacError(err.to_string())),
433 };
434 let holder = self.inner.read().expect("RBAC Engine died");
436 let engine = holder.as_ref().expect("RBAC Engine not set");
438 for request in audit.requests {
439 let permission = || PermissionRequest {
440 action: request.access_type.as_str().to_owned(),
441 };
442 let resource = || ResourceRequest {
443 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
444 table: request.schema_table.1.to_string(),
445 };
446 if !engine
447 .user_can(user_id, permission(), resource())
448 .map_err(map_err)?
449 {
450 return Err(DbErr::AccessDenied {
451 permission: permission().action.to_owned(),
452 resource: resource().to_string(),
453 });
454 }
455 }
456 Ok(())
457 }
458
459 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
460 let holder = self.inner.read().expect("RBAC Engine died");
461 let engine = holder.as_ref().expect("RBAC Engine not set");
462 engine
463 .get_user_role_permissions(user_id)
464 .map_err(|err| DbErr::RbacError(err.to_string()))
465 }
466
467 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
468 let holder = self.inner.read().expect("RBAC Engine died");
469 let engine = holder.as_ref().expect("RBAC Engine not set");
470 engine
471 .get_roles_and_ranks()
472 .map_err(|err| DbErr::RbacError(err.to_string()))
473 }
474
475 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
476 let holder = self.inner.read().expect("RBAC Engine died");
477 let engine = holder.as_ref().expect("RBAC Engine not set");
478 Ok(engine.list_resources_and_permissions())
479 }
480
481 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
482 let holder = self.inner.read().expect("RBAC Engine died");
483 let engine = holder.as_ref().expect("RBAC Engine not set");
484 Ok(engine.list_role_hierarchy_edges(role_id))
485 }
486
487 pub fn role_permissions_by_resources(
488 &self,
489 role_id: RoleId,
490 ) -> Result<RbacPermissionsByResources, DbErr> {
491 let holder = self.inner.read().expect("RBAC Engine died");
492 let engine = holder.as_ref().expect("RBAC Engine not set");
493 engine
494 .list_role_permissions_by_resources(role_id)
495 .map_err(|err| DbErr::RbacError(err.to_string()))
496 }
497}
498
499fn map_err(err: RbacError) -> DbErr {
500 DbErr::RbacError(err.to_string())
501}