1use super::transaction::run_async_transaction_callback;
2use crate::{
3 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
4 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
5 TransactionSession, TransactionTrait,
6};
7use crate::{
8 TransactionOptions,
9 rbac::{
10 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
11 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
12 RbacUserRolePermissions, ResourceRequest,
13 entity::{role::RoleId, user::UserId},
14 },
15};
16use std::{
17 future::Future,
18 pin::Pin,
19 sync::{Arc, RwLock},
20};
21use tracing::instrument;
22
23#[derive(Debug, Clone)]
27#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
28pub struct RestrictedConnection {
29 pub(crate) user_id: UserId,
30 pub(crate) conn: DatabaseConnection,
31}
32
33#[derive(Debug)]
37pub struct RestrictedTransaction {
38 user_id: UserId,
39 conn: DatabaseTransaction,
40 rbac: RbacEngineMount,
41}
42
43#[derive(Debug, Default, Clone)]
44pub(crate) struct RbacEngineMount {
45 inner: Arc<RwLock<Option<RbacEngine>>>,
46}
47
48#[async_trait::async_trait]
49impl ConnectionTrait for RestrictedConnection {
50 fn get_database_backend(&self) -> DbBackend {
51 self.conn.get_database_backend()
52 }
53
54 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
55 Err(DbErr::RbacError(format!(
56 "Raw query is not supported: {stmt}"
57 )))
58 }
59
60 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
61 self.user_can_run(stmt)?;
62 self.conn.execute(stmt).await
63 }
64
65 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
66 Err(DbErr::RbacError(format!(
67 "Raw query is not supported: {sql}"
68 )))
69 }
70
71 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
72 Err(DbErr::RbacError(format!(
73 "Raw query is not supported: {stmt}"
74 )))
75 }
76
77 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
78 self.user_can_run(stmt)?;
79 self.conn.query_one(stmt).await
80 }
81
82 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
83 Err(DbErr::RbacError(format!(
84 "Raw query is not supported: {stmt}"
85 )))
86 }
87
88 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
89 self.user_can_run(stmt)?;
90 self.conn.query_all(stmt).await
91 }
92}
93
94#[async_trait::async_trait]
95impl ConnectionTrait for RestrictedTransaction {
96 fn get_database_backend(&self) -> DbBackend {
97 self.conn.get_database_backend()
98 }
99
100 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
101 Err(DbErr::RbacError(format!(
102 "Raw query is not supported: {stmt}"
103 )))
104 }
105
106 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
107 self.user_can_run(stmt)?;
108 self.conn.execute(stmt).await
109 }
110
111 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
112 Err(DbErr::RbacError(format!(
113 "Raw query is not supported: {sql}"
114 )))
115 }
116
117 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
118 Err(DbErr::RbacError(format!(
119 "Raw query is not supported: {stmt}"
120 )))
121 }
122
123 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
124 self.user_can_run(stmt)?;
125 self.conn.query_one(stmt).await
126 }
127
128 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
129 Err(DbErr::RbacError(format!(
130 "Raw query is not supported: {stmt}"
131 )))
132 }
133
134 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
135 self.user_can_run(stmt)?;
136 self.conn.query_all(stmt).await
137 }
138}
139
140impl RestrictedConnection {
141 pub fn user_id(&self) -> UserId {
143 self.user_id
144 }
145
146 #[instrument(level = "trace", skip(callback))]
150 pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
151 where
152 F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
153 T: Send,
154 E: std::fmt::Display + std::fmt::Debug + Send,
155 {
156 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
157 run_async_transaction_callback(transaction, callback).await
158 }
159
160 #[instrument(level = "trace", skip(callback))]
164 pub async fn transaction_with_config_async<F, T, E>(
165 &self,
166 callback: F,
167 isolation_level: Option<IsolationLevel>,
168 access_mode: Option<AccessMode>,
169 ) -> Result<T, TransactionError<E>>
170 where
171 F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
172 T: Send,
173 E: std::fmt::Display + std::fmt::Debug + Send,
174 {
175 let transaction = self
176 .begin_with_config(isolation_level, access_mode)
177 .await
178 .map_err(TransactionError::Connection)?;
179 run_async_transaction_callback(transaction, callback).await
180 }
181
182 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
185 self.conn.rbac.user_can_run(self.user_id, stmt)
186 }
187
188 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
190 where
191 P: Into<PermissionRequest>,
192 R: Into<ResourceRequest>,
193 {
194 self.conn.rbac.user_can(self.user_id, permission, resource)
195 }
196
197 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
200 self.conn.rbac.user_role_permissions(self.user_id)
201 }
202
203 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
206 self.conn.rbac.roles_and_ranks()
207 }
208
209 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
211 self.conn.rbac.resources_and_permissions()
212 }
213
214 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
216 self.conn.rbac.role_hierarchy_edges(role_id)
217 }
218
219 pub fn role_permissions_by_resources(
222 &self,
223 role_id: RoleId,
224 ) -> Result<RbacPermissionsByResources, DbErr> {
225 self.conn.rbac.role_permissions_by_resources(role_id)
226 }
227}
228
229impl RestrictedTransaction {
230 pub fn user_id(&self) -> UserId {
232 self.user_id
233 }
234
235 #[instrument(level = "trace", skip(callback))]
239 pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
240 where
241 F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
242 T: Send,
243 E: std::fmt::Display + std::fmt::Debug + Send,
244 {
245 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
246 run_async_transaction_callback(transaction, callback).await
247 }
248
249 #[instrument(level = "trace", skip(callback))]
253 pub async fn transaction_with_config_async<F, T, E>(
254 &self,
255 callback: F,
256 isolation_level: Option<IsolationLevel>,
257 access_mode: Option<AccessMode>,
258 ) -> Result<T, TransactionError<E>>
259 where
260 F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
261 T: Send,
262 E: std::fmt::Display + std::fmt::Debug + Send,
263 {
264 let transaction = self
265 .begin_with_config(isolation_level, access_mode)
266 .await
267 .map_err(TransactionError::Connection)?;
268 run_async_transaction_callback(transaction, callback).await
269 }
270
271 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
274 self.rbac.user_can_run(self.user_id, stmt)
275 }
276
277 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
279 where
280 P: Into<PermissionRequest>,
281 R: Into<ResourceRequest>,
282 {
283 self.rbac.user_can(self.user_id, permission, resource)
284 }
285}
286
287#[async_trait::async_trait]
288impl TransactionTrait for RestrictedConnection {
289 type Transaction = RestrictedTransaction;
290
291 #[instrument(level = "trace")]
292 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
293 Ok(RestrictedTransaction {
294 user_id: self.user_id,
295 conn: self.conn.begin().await?,
296 rbac: self.conn.rbac.clone(),
297 })
298 }
299
300 #[instrument(level = "trace")]
301 async fn begin_with_config(
302 &self,
303 isolation_level: Option<IsolationLevel>,
304 access_mode: Option<AccessMode>,
305 ) -> Result<RestrictedTransaction, DbErr> {
306 Ok(RestrictedTransaction {
307 user_id: self.user_id,
308 conn: self
309 .conn
310 .begin_with_config(isolation_level, access_mode)
311 .await?,
312 rbac: self.conn.rbac.clone(),
313 })
314 }
315
316 #[instrument(level = "trace")]
317 async fn begin_with_options(
318 &self,
319 options: TransactionOptions,
320 ) -> Result<RestrictedTransaction, DbErr> {
321 Ok(RestrictedTransaction {
322 user_id: self.user_id,
323 conn: self.conn.begin_with_options(options).await?,
324 rbac: self.conn.rbac.clone(),
325 })
326 }
327
328 #[instrument(level = "trace", skip(callback))]
331 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
332 where
333 F: for<'c> FnOnce(
334 &'c RestrictedTransaction,
335 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
336 + Send,
337 T: Send,
338 E: std::fmt::Display + std::fmt::Debug + Send,
339 {
340 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
341 transaction.run(callback).await
342 }
343
344 #[instrument(level = "trace", skip(callback))]
347 async fn transaction_with_config<F, T, E>(
348 &self,
349 callback: F,
350 isolation_level: Option<IsolationLevel>,
351 access_mode: Option<AccessMode>,
352 ) -> Result<T, TransactionError<E>>
353 where
354 F: for<'c> FnOnce(
355 &'c RestrictedTransaction,
356 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
357 + Send,
358 T: Send,
359 E: std::fmt::Display + std::fmt::Debug + Send,
360 {
361 let transaction = self
362 .begin_with_config(isolation_level, access_mode)
363 .await
364 .map_err(TransactionError::Connection)?;
365 transaction.run(callback).await
366 }
367}
368
369#[async_trait::async_trait]
370impl TransactionTrait for RestrictedTransaction {
371 type Transaction = RestrictedTransaction;
372
373 #[instrument(level = "trace")]
374 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
375 Ok(RestrictedTransaction {
376 user_id: self.user_id,
377 conn: self.conn.begin().await?,
378 rbac: self.rbac.clone(),
379 })
380 }
381
382 #[instrument(level = "trace")]
383 async fn begin_with_config(
384 &self,
385 isolation_level: Option<IsolationLevel>,
386 access_mode: Option<AccessMode>,
387 ) -> Result<RestrictedTransaction, DbErr> {
388 Ok(RestrictedTransaction {
389 user_id: self.user_id,
390 conn: self
391 .conn
392 .begin_with_config(isolation_level, access_mode)
393 .await?,
394 rbac: self.rbac.clone(),
395 })
396 }
397
398 #[instrument(level = "trace")]
399 async fn begin_with_options(
400 &self,
401 options: TransactionOptions,
402 ) -> Result<RestrictedTransaction, DbErr> {
403 Ok(RestrictedTransaction {
404 user_id: self.user_id,
405 conn: self.conn.begin_with_options(options).await?,
406 rbac: self.rbac.clone(),
407 })
408 }
409
410 #[instrument(level = "trace", skip(callback))]
413 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
414 where
415 F: for<'c> FnOnce(
416 &'c RestrictedTransaction,
417 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
418 + Send,
419 T: Send,
420 E: std::fmt::Display + std::fmt::Debug + Send,
421 {
422 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
423 transaction.run(callback).await
424 }
425
426 #[instrument(level = "trace", skip(callback))]
429 async fn transaction_with_config<F, T, E>(
430 &self,
431 callback: F,
432 isolation_level: Option<IsolationLevel>,
433 access_mode: Option<AccessMode>,
434 ) -> Result<T, TransactionError<E>>
435 where
436 F: for<'c> FnOnce(
437 &'c RestrictedTransaction,
438 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
439 + Send,
440 T: Send,
441 E: std::fmt::Display + std::fmt::Debug + Send,
442 {
443 let transaction = self
444 .begin_with_config(isolation_level, access_mode)
445 .await
446 .map_err(TransactionError::Connection)?;
447 transaction.run(callback).await
448 }
449}
450
451#[async_trait::async_trait]
452impl TransactionSession for RestrictedTransaction {
453 async fn commit(self) -> Result<(), DbErr> {
454 self.commit().await
455 }
456
457 async fn rollback(self) -> Result<(), DbErr> {
458 self.rollback().await
459 }
460}
461
462impl RestrictedTransaction {
463 #[instrument(level = "trace", skip(callback))]
466 async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
467 where
468 F: for<'b> FnOnce(
469 &'b RestrictedTransaction,
470 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
471 + Send,
472 T: Send,
473 E: std::fmt::Display + std::fmt::Debug + Send,
474 {
475 let res = callback(&self).await.map_err(TransactionError::Transaction);
476 if res.is_ok() {
477 self.commit().await.map_err(TransactionError::Connection)?;
478 } else {
479 self.rollback()
480 .await
481 .map_err(TransactionError::Connection)?;
482 }
483 res
484 }
485
486 #[instrument(level = "trace")]
488 pub async fn commit(self) -> Result<(), DbErr> {
489 self.conn.commit().await
490 }
491
492 #[instrument(level = "trace")]
494 pub async fn rollback(self) -> Result<(), DbErr> {
495 self.conn.rollback().await
496 }
497}
498
499impl RbacEngineMount {
500 pub fn is_some(&self) -> bool {
501 let engine = self.inner.read().expect("RBAC Engine died");
502 engine.is_some()
503 }
504
505 pub fn replace(&self, engine: RbacEngine) {
506 let mut inner = self.inner.write().expect("RBAC Engine died");
507 *inner = Some(engine);
508 }
509
510 pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
511 where
512 P: Into<PermissionRequest>,
513 R: Into<ResourceRequest>,
514 {
515 let permission = permission.into();
516 let resource = resource.into();
517 let holder = self.inner.read().expect("RBAC Engine died");
519 let engine = holder.as_ref().expect("RBAC Engine not set");
521 engine
522 .user_can(user_id, permission, resource)
523 .map_err(map_err)
524 }
525
526 pub fn user_can_run<S: StatementBuilder>(
527 &self,
528 user_id: UserId,
529 stmt: &S,
530 ) -> Result<(), DbErr> {
531 let audit = match stmt.audit() {
532 Ok(audit) => audit,
533 Err(err) => return Err(DbErr::RbacError(err.to_string())),
534 };
535 let holder = self.inner.read().expect("RBAC Engine died");
537 let engine = holder.as_ref().expect("RBAC Engine not set");
539 for request in audit.requests {
540 let permission = || PermissionRequest {
541 action: request.access_type.as_str().to_owned(),
542 };
543 let resource = || ResourceRequest {
544 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
545 table: request.schema_table.1.to_string(),
546 };
547 if !engine
548 .user_can(user_id, permission(), resource())
549 .map_err(map_err)?
550 {
551 return Err(DbErr::AccessDenied {
552 permission: permission().action.to_owned(),
553 resource: resource().to_string(),
554 });
555 }
556 }
557 Ok(())
558 }
559
560 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
561 let holder = self.inner.read().expect("RBAC Engine died");
562 let engine = holder.as_ref().expect("RBAC Engine not set");
563 engine
564 .get_user_role_permissions(user_id)
565 .map_err(|err| DbErr::RbacError(err.to_string()))
566 }
567
568 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
569 let holder = self.inner.read().expect("RBAC Engine died");
570 let engine = holder.as_ref().expect("RBAC Engine not set");
571 engine
572 .get_roles_and_ranks()
573 .map_err(|err| DbErr::RbacError(err.to_string()))
574 }
575
576 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
577 let holder = self.inner.read().expect("RBAC Engine died");
578 let engine = holder.as_ref().expect("RBAC Engine not set");
579 Ok(engine.list_resources_and_permissions())
580 }
581
582 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
583 let holder = self.inner.read().expect("RBAC Engine died");
584 let engine = holder.as_ref().expect("RBAC Engine not set");
585 Ok(engine.list_role_hierarchy_edges(role_id))
586 }
587
588 pub fn role_permissions_by_resources(
589 &self,
590 role_id: RoleId,
591 ) -> Result<RbacPermissionsByResources, DbErr> {
592 let holder = self.inner.read().expect("RBAC Engine died");
593 let engine = holder.as_ref().expect("RBAC Engine not set");
594 engine
595 .list_role_permissions_by_resources(role_id)
596 .map_err(|err| DbErr::RbacError(err.to_string()))
597 }
598}
599
600fn map_err(err: RbacError) -> DbErr {
601 DbErr::RbacError(err.to_string())
602}