1use crate::{
2 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
3 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
4 TransactionSession, TransactionTrait,
5};
6use crate::{
7 TransactionOptions,
8 rbac::{
9 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
10 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
11 RbacUserRolePermissions, ResourceRequest,
12 entity::{role::RoleId, user::UserId},
13 },
14};
15use std::{
16 pin::Pin,
17 sync::{Arc, RwLock},
18};
19use tracing::instrument;
20
#[derive(Debug, Clone)]
/// A [`DatabaseConnection`] wrapper bound to a single user.
///
/// Every statement-builder query is checked for that user against the RBAC
/// engine before being forwarded to the wrapped connection. Raw SQL entry points
/// (`execute_raw`, `execute_unprepared`, `query_one_raw`, `query_all_raw`) are
/// always rejected with [`DbErr::RbacError`]; only builders, which expose
/// `audit()`, go through the permission check.
#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
pub struct RestrictedConnection {
    /// The user on whose behalf every query is authorized.
    pub(crate) user_id: UserId,
    /// The wrapped connection; its `rbac` mount supplies the engine used for checks.
    pub(crate) conn: DatabaseConnection,
}
30
#[derive(Debug)]
/// A [`DatabaseTransaction`] wrapper bound to a single user.
///
/// Like [`RestrictedConnection`], it authorizes every statement-builder query
/// through the RBAC engine and rejects raw SQL.
pub struct RestrictedTransaction {
    /// The user on whose behalf every query is authorized.
    user_id: UserId,
    /// The wrapped transaction that authorized statements are forwarded to.
    conn: DatabaseTransaction,
    /// Handle to the RBAC engine, cloned from whichever connection or
    /// transaction began this one.
    rbac: RbacEngineMount,
}
40
#[derive(Debug, Default, Clone)]
/// Shared, replaceable slot holding the RBAC engine.
///
/// Clones share the same `Arc`, so an engine installed through one handle is
/// visible through every clone. The derived `Default` starts the slot empty
/// (`None`).
pub(crate) struct RbacEngineMount {
    // `RwLock` lets many concurrent permission checks read while `replace` swaps the engine.
    inner: Arc<RwLock<Option<RbacEngine>>>,
}
45
46#[async_trait::async_trait]
47impl ConnectionTrait for RestrictedConnection {
48 fn get_database_backend(&self) -> DbBackend {
49 self.conn.get_database_backend()
50 }
51
52 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
53 Err(DbErr::RbacError(format!(
54 "Raw query is not supported: {stmt}"
55 )))
56 }
57
58 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
59 self.user_can_run(stmt)?;
60 self.conn.execute(stmt).await
61 }
62
63 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
64 Err(DbErr::RbacError(format!(
65 "Raw query is not supported: {sql}"
66 )))
67 }
68
69 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
70 Err(DbErr::RbacError(format!(
71 "Raw query is not supported: {stmt}"
72 )))
73 }
74
75 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
76 self.user_can_run(stmt)?;
77 self.conn.query_one(stmt).await
78 }
79
80 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
81 Err(DbErr::RbacError(format!(
82 "Raw query is not supported: {stmt}"
83 )))
84 }
85
86 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
87 self.user_can_run(stmt)?;
88 self.conn.query_all(stmt).await
89 }
90}
91
92#[async_trait::async_trait]
93impl ConnectionTrait for RestrictedTransaction {
94 fn get_database_backend(&self) -> DbBackend {
95 self.conn.get_database_backend()
96 }
97
98 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
99 Err(DbErr::RbacError(format!(
100 "Raw query is not supported: {stmt}"
101 )))
102 }
103
104 async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
105 self.user_can_run(stmt)?;
106 self.conn.execute(stmt).await
107 }
108
109 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
110 Err(DbErr::RbacError(format!(
111 "Raw query is not supported: {sql}"
112 )))
113 }
114
115 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
116 Err(DbErr::RbacError(format!(
117 "Raw query is not supported: {stmt}"
118 )))
119 }
120
121 async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
122 self.user_can_run(stmt)?;
123 self.conn.query_one(stmt).await
124 }
125
126 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
127 Err(DbErr::RbacError(format!(
128 "Raw query is not supported: {stmt}"
129 )))
130 }
131
132 async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
133 self.user_can_run(stmt)?;
134 self.conn.query_all(stmt).await
135 }
136}
137
138impl RestrictedConnection {
139 pub fn user_id(&self) -> UserId {
141 self.user_id
142 }
143
144 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
147 self.conn.rbac.user_can_run(self.user_id, stmt)
148 }
149
150 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
152 where
153 P: Into<PermissionRequest>,
154 R: Into<ResourceRequest>,
155 {
156 self.conn.rbac.user_can(self.user_id, permission, resource)
157 }
158
159 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
162 self.conn.rbac.user_role_permissions(self.user_id)
163 }
164
165 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
168 self.conn.rbac.roles_and_ranks()
169 }
170
171 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
173 self.conn.rbac.resources_and_permissions()
174 }
175
176 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
178 self.conn.rbac.role_hierarchy_edges(role_id)
179 }
180
181 pub fn role_permissions_by_resources(
184 &self,
185 role_id: RoleId,
186 ) -> Result<RbacPermissionsByResources, DbErr> {
187 self.conn.rbac.role_permissions_by_resources(role_id)
188 }
189}
190
191impl RestrictedTransaction {
192 pub fn user_id(&self) -> UserId {
194 self.user_id
195 }
196
197 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
200 self.rbac.user_can_run(self.user_id, stmt)
201 }
202
203 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
205 where
206 P: Into<PermissionRequest>,
207 R: Into<ResourceRequest>,
208 {
209 self.rbac.user_can(self.user_id, permission, resource)
210 }
211}
212
213#[async_trait::async_trait]
214impl TransactionTrait for RestrictedConnection {
215 type Transaction = RestrictedTransaction;
216
217 #[instrument(level = "trace")]
218 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
219 Ok(RestrictedTransaction {
220 user_id: self.user_id,
221 conn: self.conn.begin().await?,
222 rbac: self.conn.rbac.clone(),
223 })
224 }
225
226 #[instrument(level = "trace")]
227 async fn begin_with_config(
228 &self,
229 isolation_level: Option<IsolationLevel>,
230 access_mode: Option<AccessMode>,
231 ) -> Result<RestrictedTransaction, DbErr> {
232 Ok(RestrictedTransaction {
233 user_id: self.user_id,
234 conn: self
235 .conn
236 .begin_with_config(isolation_level, access_mode)
237 .await?,
238 rbac: self.conn.rbac.clone(),
239 })
240 }
241
242 #[instrument(level = "trace")]
243 async fn begin_with_options(
244 &self,
245 options: TransactionOptions,
246 ) -> Result<RestrictedTransaction, DbErr> {
247 Ok(RestrictedTransaction {
248 user_id: self.user_id,
249 conn: self.conn.begin_with_options(options).await?,
250 rbac: self.conn.rbac.clone(),
251 })
252 }
253
254 #[instrument(level = "trace", skip(callback))]
257 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
258 where
259 F: for<'c> FnOnce(
260 &'c RestrictedTransaction,
261 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
262 + Send,
263 T: Send,
264 E: std::fmt::Display + std::fmt::Debug + Send,
265 {
266 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
267 transaction.run(callback).await
268 }
269
270 #[instrument(level = "trace", skip(callback))]
273 async fn transaction_with_config<F, T, E>(
274 &self,
275 callback: F,
276 isolation_level: Option<IsolationLevel>,
277 access_mode: Option<AccessMode>,
278 ) -> Result<T, TransactionError<E>>
279 where
280 F: for<'c> FnOnce(
281 &'c RestrictedTransaction,
282 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
283 + Send,
284 T: Send,
285 E: std::fmt::Display + std::fmt::Debug + Send,
286 {
287 let transaction = self
288 .begin_with_config(isolation_level, access_mode)
289 .await
290 .map_err(TransactionError::Connection)?;
291 transaction.run(callback).await
292 }
293}
294
295#[async_trait::async_trait]
296impl TransactionTrait for RestrictedTransaction {
297 type Transaction = RestrictedTransaction;
298
299 #[instrument(level = "trace")]
300 async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
301 Ok(RestrictedTransaction {
302 user_id: self.user_id,
303 conn: self.conn.begin().await?,
304 rbac: self.rbac.clone(),
305 })
306 }
307
308 #[instrument(level = "trace")]
309 async fn begin_with_config(
310 &self,
311 isolation_level: Option<IsolationLevel>,
312 access_mode: Option<AccessMode>,
313 ) -> Result<RestrictedTransaction, DbErr> {
314 Ok(RestrictedTransaction {
315 user_id: self.user_id,
316 conn: self
317 .conn
318 .begin_with_config(isolation_level, access_mode)
319 .await?,
320 rbac: self.rbac.clone(),
321 })
322 }
323
324 #[instrument(level = "trace")]
325 async fn begin_with_options(
326 &self,
327 options: TransactionOptions,
328 ) -> Result<RestrictedTransaction, DbErr> {
329 Ok(RestrictedTransaction {
330 user_id: self.user_id,
331 conn: self.conn.begin_with_options(options).await?,
332 rbac: self.rbac.clone(),
333 })
334 }
335
336 #[instrument(level = "trace", skip(callback))]
339 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
340 where
341 F: for<'c> FnOnce(
342 &'c RestrictedTransaction,
343 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
344 + Send,
345 T: Send,
346 E: std::fmt::Display + std::fmt::Debug + Send,
347 {
348 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
349 transaction.run(callback).await
350 }
351
352 #[instrument(level = "trace", skip(callback))]
355 async fn transaction_with_config<F, T, E>(
356 &self,
357 callback: F,
358 isolation_level: Option<IsolationLevel>,
359 access_mode: Option<AccessMode>,
360 ) -> Result<T, TransactionError<E>>
361 where
362 F: for<'c> FnOnce(
363 &'c RestrictedTransaction,
364 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
365 + Send,
366 T: Send,
367 E: std::fmt::Display + std::fmt::Debug + Send,
368 {
369 let transaction = self
370 .begin_with_config(isolation_level, access_mode)
371 .await
372 .map_err(TransactionError::Connection)?;
373 transaction.run(callback).await
374 }
375}
376
377#[async_trait::async_trait]
378impl TransactionSession for RestrictedTransaction {
379 async fn commit(self) -> Result<(), DbErr> {
380 self.commit().await
381 }
382
383 async fn rollback(self) -> Result<(), DbErr> {
384 self.rollback().await
385 }
386}
387
388impl RestrictedTransaction {
389 #[instrument(level = "trace", skip(callback))]
392 async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
393 where
394 F: for<'b> FnOnce(
395 &'b RestrictedTransaction,
396 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
397 + Send,
398 T: Send,
399 E: std::fmt::Display + std::fmt::Debug + Send,
400 {
401 let res = callback(&self).await.map_err(TransactionError::Transaction);
402 if res.is_ok() {
403 self.commit().await.map_err(TransactionError::Connection)?;
404 } else {
405 self.rollback()
406 .await
407 .map_err(TransactionError::Connection)?;
408 }
409 res
410 }
411
412 #[instrument(level = "trace")]
414 pub async fn commit(self) -> Result<(), DbErr> {
415 self.conn.commit().await
416 }
417
418 #[instrument(level = "trace")]
420 pub async fn rollback(self) -> Result<(), DbErr> {
421 self.conn.rollback().await
422 }
423}
424
425impl RbacEngineMount {
426 pub fn is_some(&self) -> bool {
427 let engine = self.inner.read().expect("RBAC Engine died");
428 engine.is_some()
429 }
430
431 pub fn replace(&self, engine: RbacEngine) {
432 let mut inner = self.inner.write().expect("RBAC Engine died");
433 *inner = Some(engine);
434 }
435
436 pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
437 where
438 P: Into<PermissionRequest>,
439 R: Into<ResourceRequest>,
440 {
441 let permission = permission.into();
442 let resource = resource.into();
443 let holder = self.inner.read().expect("RBAC Engine died");
445 let engine = holder.as_ref().expect("RBAC Engine not set");
447 engine
448 .user_can(user_id, permission, resource)
449 .map_err(map_err)
450 }
451
452 pub fn user_can_run<S: StatementBuilder>(
453 &self,
454 user_id: UserId,
455 stmt: &S,
456 ) -> Result<(), DbErr> {
457 let audit = match stmt.audit() {
458 Ok(audit) => audit,
459 Err(err) => return Err(DbErr::RbacError(err.to_string())),
460 };
461 let holder = self.inner.read().expect("RBAC Engine died");
463 let engine = holder.as_ref().expect("RBAC Engine not set");
465 for request in audit.requests {
466 let permission = || PermissionRequest {
467 action: request.access_type.as_str().to_owned(),
468 };
469 let resource = || ResourceRequest {
470 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
471 table: request.schema_table.1.to_string(),
472 };
473 if !engine
474 .user_can(user_id, permission(), resource())
475 .map_err(map_err)?
476 {
477 return Err(DbErr::AccessDenied {
478 permission: permission().action.to_owned(),
479 resource: resource().to_string(),
480 });
481 }
482 }
483 Ok(())
484 }
485
486 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
487 let holder = self.inner.read().expect("RBAC Engine died");
488 let engine = holder.as_ref().expect("RBAC Engine not set");
489 engine
490 .get_user_role_permissions(user_id)
491 .map_err(|err| DbErr::RbacError(err.to_string()))
492 }
493
494 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
495 let holder = self.inner.read().expect("RBAC Engine died");
496 let engine = holder.as_ref().expect("RBAC Engine not set");
497 engine
498 .get_roles_and_ranks()
499 .map_err(|err| DbErr::RbacError(err.to_string()))
500 }
501
502 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
503 let holder = self.inner.read().expect("RBAC Engine died");
504 let engine = holder.as_ref().expect("RBAC Engine not set");
505 Ok(engine.list_resources_and_permissions())
506 }
507
508 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
509 let holder = self.inner.read().expect("RBAC Engine died");
510 let engine = holder.as_ref().expect("RBAC Engine not set");
511 Ok(engine.list_role_hierarchy_edges(role_id))
512 }
513
514 pub fn role_permissions_by_resources(
515 &self,
516 role_id: RoleId,
517 ) -> Result<RbacPermissionsByResources, DbErr> {
518 let holder = self.inner.read().expect("RBAC Engine died");
519 let engine = holder.as_ref().expect("RBAC Engine not set");
520 engine
521 .list_role_permissions_by_resources(role_id)
522 .map_err(|err| DbErr::RbacError(err.to_string()))
523 }
524}
525
526fn map_err(err: RbacError) -> DbErr {
527 DbErr::RbacError(err.to_string())
528}