1use super::transaction::run_async_transaction_callback;
2use crate::{
3 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
4 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
5 TransactionSession, TransactionTrait,
6};
7use crate::{
8 TransactionOptions,
9 rbac::{
10 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
11 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
12 RbacUserRolePermissions, ResourceRequest,
13 entity::{role::RoleId, user::UserId},
14 },
15};
16use std::{
17 future::Future,
18 pin::Pin,
19 sync::{Arc, RwLock},
20};
21use tracing::instrument;
22
23#[derive(Debug, Clone)]
27#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
28pub struct RestrictedConnection {
29 pub(crate) user_id: UserId,
30 pub(crate) conn: DatabaseConnection,
31}
32
33#[derive(Debug)]
37pub struct RestrictedTransaction {
38 user_id: UserId,
39 conn: DatabaseTransaction,
40 rbac: RbacEngineMount,
41}
42
43#[derive(Debug, Default, Clone)]
44pub(crate) struct RbacEngineMount {
45 inner: Arc<RwLock<Option<RbacEngine>>>,
46}
47
48impl ConnectionTrait for RestrictedConnection {
49 fn get_database_backend(&self) -> DbBackend {
50 self.conn.get_database_backend()
51 }
52
53 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
54 Err(DbErr::RbacError(format!(
55 "Raw query is not supported: {stmt}"
56 )))
57 }
58
59 fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
60 self.user_can_run(stmt)?;
61 self.conn.execute(stmt)
62 }
63
64 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
65 Err(DbErr::RbacError(format!(
66 "Raw query is not supported: {sql}"
67 )))
68 }
69
70 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
71 Err(DbErr::RbacError(format!(
72 "Raw query is not supported: {stmt}"
73 )))
74 }
75
76 fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
77 self.user_can_run(stmt)?;
78 self.conn.query_one(stmt)
79 }
80
81 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
82 Err(DbErr::RbacError(format!(
83 "Raw query is not supported: {stmt}"
84 )))
85 }
86
87 fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
88 self.user_can_run(stmt)?;
89 self.conn.query_all(stmt)
90 }
91}
92
93impl ConnectionTrait for RestrictedTransaction {
94 fn get_database_backend(&self) -> DbBackend {
95 self.conn.get_database_backend()
96 }
97
98 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
99 Err(DbErr::RbacError(format!(
100 "Raw query is not supported: {stmt}"
101 )))
102 }
103
104 fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
105 self.user_can_run(stmt)?;
106 self.conn.execute(stmt)
107 }
108
109 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
110 Err(DbErr::RbacError(format!(
111 "Raw query is not supported: {sql}"
112 )))
113 }
114
115 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
116 Err(DbErr::RbacError(format!(
117 "Raw query is not supported: {stmt}"
118 )))
119 }
120
121 fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
122 self.user_can_run(stmt)?;
123 self.conn.query_one(stmt)
124 }
125
126 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
127 Err(DbErr::RbacError(format!(
128 "Raw query is not supported: {stmt}"
129 )))
130 }
131
132 fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
133 self.user_can_run(stmt)?;
134 self.conn.query_all(stmt)
135 }
136}
137
138impl RestrictedConnection {
139 pub fn user_id(&self) -> UserId {
141 self.user_id
142 }
143
144 #[instrument(level = "trace", skip(callback))]
148 pub fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
149 where
150 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
151 E: std::fmt::Display + std::fmt::Debug,
152 {
153 let transaction = self.begin().map_err(TransactionError::Connection)?;
154 run_async_transaction_callback(transaction, callback)
155 }
156
157 #[instrument(level = "trace", skip(callback))]
161 pub fn transaction_with_config<F, T, E>(
162 &self,
163 callback: F,
164 isolation_level: Option<IsolationLevel>,
165 access_mode: Option<AccessMode>,
166 ) -> Result<T, TransactionError<E>>
167 where
168 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
169 E: std::fmt::Display + std::fmt::Debug,
170 {
171 let transaction = self
172 .begin_with_config(isolation_level, access_mode)
173 .map_err(TransactionError::Connection)?;
174 run_async_transaction_callback(transaction, callback)
175 }
176
177 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
180 self.conn.rbac.user_can_run(self.user_id, stmt)
181 }
182
183 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
185 where
186 P: Into<PermissionRequest>,
187 R: Into<ResourceRequest>,
188 {
189 self.conn.rbac.user_can(self.user_id, permission, resource)
190 }
191
192 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
195 self.conn.rbac.user_role_permissions(self.user_id)
196 }
197
198 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
201 self.conn.rbac.roles_and_ranks()
202 }
203
204 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
206 self.conn.rbac.resources_and_permissions()
207 }
208
209 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
211 self.conn.rbac.role_hierarchy_edges(role_id)
212 }
213
214 pub fn role_permissions_by_resources(
217 &self,
218 role_id: RoleId,
219 ) -> Result<RbacPermissionsByResources, DbErr> {
220 self.conn.rbac.role_permissions_by_resources(role_id)
221 }
222}
223
224impl RestrictedTransaction {
225 pub fn user_id(&self) -> UserId {
227 self.user_id
228 }
229
230 #[instrument(level = "trace", skip(callback))]
234 pub fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
235 where
236 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
237 E: std::fmt::Display + std::fmt::Debug,
238 {
239 let transaction = self.begin().map_err(TransactionError::Connection)?;
240 run_async_transaction_callback(transaction, callback)
241 }
242
243 #[instrument(level = "trace", skip(callback))]
247 pub fn transaction_with_config<F, T, E>(
248 &self,
249 callback: F,
250 isolation_level: Option<IsolationLevel>,
251 access_mode: Option<AccessMode>,
252 ) -> Result<T, TransactionError<E>>
253 where
254 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
255 E: std::fmt::Display + std::fmt::Debug,
256 {
257 let transaction = self
258 .begin_with_config(isolation_level, access_mode)
259 .map_err(TransactionError::Connection)?;
260 run_async_transaction_callback(transaction, callback)
261 }
262
263 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
266 self.rbac.user_can_run(self.user_id, stmt)
267 }
268
269 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
271 where
272 P: Into<PermissionRequest>,
273 R: Into<ResourceRequest>,
274 {
275 self.rbac.user_can(self.user_id, permission, resource)
276 }
277}
278
279impl TransactionTrait for RestrictedConnection {
280 type Transaction = RestrictedTransaction;
281
282 #[instrument(level = "trace")]
283 fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
284 Ok(RestrictedTransaction {
285 user_id: self.user_id,
286 conn: self.conn.begin()?,
287 rbac: self.conn.rbac.clone(),
288 })
289 }
290
291 #[instrument(level = "trace")]
292 fn begin_with_config(
293 &self,
294 isolation_level: Option<IsolationLevel>,
295 access_mode: Option<AccessMode>,
296 ) -> Result<RestrictedTransaction, DbErr> {
297 Ok(RestrictedTransaction {
298 user_id: self.user_id,
299 conn: self.conn.begin_with_config(isolation_level, access_mode)?,
300 rbac: self.conn.rbac.clone(),
301 })
302 }
303
304 #[instrument(level = "trace")]
305 fn begin_with_options(
306 &self,
307 options: TransactionOptions,
308 ) -> Result<RestrictedTransaction, DbErr> {
309 Ok(RestrictedTransaction {
310 user_id: self.user_id,
311 conn: self.conn.begin_with_options(options)?,
312 rbac: self.conn.rbac.clone(),
313 })
314 }
315
316 #[instrument(level = "trace", skip(callback))]
319 fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
320 where
321 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
322 E: std::fmt::Display + std::fmt::Debug,
323 {
324 let transaction = self.begin().map_err(TransactionError::Connection)?;
325 transaction.run(callback)
326 }
327
328 #[instrument(level = "trace", skip(callback))]
331 fn transaction_with_config<F, T, E>(
332 &self,
333 callback: F,
334 isolation_level: Option<IsolationLevel>,
335 access_mode: Option<AccessMode>,
336 ) -> Result<T, TransactionError<E>>
337 where
338 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
339 E: std::fmt::Display + std::fmt::Debug,
340 {
341 let transaction = self
342 .begin_with_config(isolation_level, access_mode)
343 .map_err(TransactionError::Connection)?;
344 transaction.run(callback)
345 }
346}
347
348impl TransactionTrait for RestrictedTransaction {
349 type Transaction = RestrictedTransaction;
350
351 #[instrument(level = "trace")]
352 fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
353 Ok(RestrictedTransaction {
354 user_id: self.user_id,
355 conn: self.conn.begin()?,
356 rbac: self.rbac.clone(),
357 })
358 }
359
360 #[instrument(level = "trace")]
361 fn begin_with_config(
362 &self,
363 isolation_level: Option<IsolationLevel>,
364 access_mode: Option<AccessMode>,
365 ) -> Result<RestrictedTransaction, DbErr> {
366 Ok(RestrictedTransaction {
367 user_id: self.user_id,
368 conn: self.conn.begin_with_config(isolation_level, access_mode)?,
369 rbac: self.rbac.clone(),
370 })
371 }
372
373 #[instrument(level = "trace")]
374 fn begin_with_options(
375 &self,
376 options: TransactionOptions,
377 ) -> Result<RestrictedTransaction, DbErr> {
378 Ok(RestrictedTransaction {
379 user_id: self.user_id,
380 conn: self.conn.begin_with_options(options)?,
381 rbac: self.rbac.clone(),
382 })
383 }
384
385 #[instrument(level = "trace", skip(callback))]
388 fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
389 where
390 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
391 E: std::fmt::Display + std::fmt::Debug,
392 {
393 let transaction = self.begin().map_err(TransactionError::Connection)?;
394 transaction.run(callback)
395 }
396
397 #[instrument(level = "trace", skip(callback))]
400 fn transaction_with_config<F, T, E>(
401 &self,
402 callback: F,
403 isolation_level: Option<IsolationLevel>,
404 access_mode: Option<AccessMode>,
405 ) -> Result<T, TransactionError<E>>
406 where
407 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
408 E: std::fmt::Display + std::fmt::Debug,
409 {
410 let transaction = self
411 .begin_with_config(isolation_level, access_mode)
412 .map_err(TransactionError::Connection)?;
413 transaction.run(callback)
414 }
415}
416
417impl TransactionSession for RestrictedTransaction {
418 fn commit(self) -> Result<(), DbErr> {
419 self.commit()
420 }
421
422 fn rollback(self) -> Result<(), DbErr> {
423 self.rollback()
424 }
425}
426
427impl RestrictedTransaction {
428 #[instrument(level = "trace", skip(callback))]
431 fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
432 where
433 F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
434 E: std::fmt::Display + std::fmt::Debug,
435 {
436 let res = callback(&self).map_err(TransactionError::Transaction);
437 if res.is_ok() {
438 self.commit().map_err(TransactionError::Connection)?;
439 } else {
440 self.rollback().map_err(TransactionError::Connection)?;
441 }
442 res
443 }
444
445 #[instrument(level = "trace")]
447 pub fn commit(self) -> Result<(), DbErr> {
448 self.conn.commit()
449 }
450
451 #[instrument(level = "trace")]
453 pub fn rollback(self) -> Result<(), DbErr> {
454 self.conn.rollback()
455 }
456}
457
458impl RbacEngineMount {
459 pub fn is_some(&self) -> bool {
460 let engine = self.inner.read().expect("RBAC Engine died");
461 engine.is_some()
462 }
463
464 pub fn replace(&self, engine: RbacEngine) {
465 let mut inner = self.inner.write().expect("RBAC Engine died");
466 *inner = Some(engine);
467 }
468
469 pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
470 where
471 P: Into<PermissionRequest>,
472 R: Into<ResourceRequest>,
473 {
474 let permission = permission.into();
475 let resource = resource.into();
476 let holder = self.inner.read().expect("RBAC Engine died");
478 let engine = holder.as_ref().expect("RBAC Engine not set");
480 engine
481 .user_can(user_id, permission, resource)
482 .map_err(map_err)
483 }
484
485 pub fn user_can_run<S: StatementBuilder>(
486 &self,
487 user_id: UserId,
488 stmt: &S,
489 ) -> Result<(), DbErr> {
490 let audit = match stmt.audit() {
491 Ok(audit) => audit,
492 Err(err) => return Err(DbErr::RbacError(err.to_string())),
493 };
494 let holder = self.inner.read().expect("RBAC Engine died");
496 let engine = holder.as_ref().expect("RBAC Engine not set");
498 for request in audit.requests {
499 let permission = || PermissionRequest {
500 action: request.access_type.as_str().to_owned(),
501 };
502 let resource = || ResourceRequest {
503 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
504 table: request.schema_table.1.to_string(),
505 };
506 if !engine
507 .user_can(user_id, permission(), resource())
508 .map_err(map_err)?
509 {
510 return Err(DbErr::AccessDenied {
511 permission: permission().action.to_owned(),
512 resource: resource().to_string(),
513 });
514 }
515 }
516 Ok(())
517 }
518
519 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
520 let holder = self.inner.read().expect("RBAC Engine died");
521 let engine = holder.as_ref().expect("RBAC Engine not set");
522 engine
523 .get_user_role_permissions(user_id)
524 .map_err(|err| DbErr::RbacError(err.to_string()))
525 }
526
527 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
528 let holder = self.inner.read().expect("RBAC Engine died");
529 let engine = holder.as_ref().expect("RBAC Engine not set");
530 engine
531 .get_roles_and_ranks()
532 .map_err(|err| DbErr::RbacError(err.to_string()))
533 }
534
535 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
536 let holder = self.inner.read().expect("RBAC Engine died");
537 let engine = holder.as_ref().expect("RBAC Engine not set");
538 Ok(engine.list_resources_and_permissions())
539 }
540
541 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
542 let holder = self.inner.read().expect("RBAC Engine died");
543 let engine = holder.as_ref().expect("RBAC Engine not set");
544 Ok(engine.list_role_hierarchy_edges(role_id))
545 }
546
547 pub fn role_permissions_by_resources(
548 &self,
549 role_id: RoleId,
550 ) -> Result<RbacPermissionsByResources, DbErr> {
551 let holder = self.inner.read().expect("RBAC Engine died");
552 let engine = holder.as_ref().expect("RBAC Engine not set");
553 engine
554 .list_role_permissions_by_resources(role_id)
555 .map_err(|err| DbErr::RbacError(err.to_string()))
556 }
557}
558
559fn map_err(err: RbacError) -> DbErr {
560 DbErr::RbacError(err.to_string())
561}