1use crate::{
2 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
3 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
4 TransactionSession, TransactionTrait,
5};
6use crate::{
7 TransactionOptions,
8 rbac::{
9 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
10 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
11 RbacUserRolePermissions, ResourceRequest,
12 entity::{role::RoleId, user::UserId},
13 },
14};
15use std::{
16 pin::Pin,
17 sync::{Arc, RwLock},
18};
19use tracing::instrument;
20
21#[derive(Debug, Clone)]
25#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
26pub struct RestrictedConnection {
27 pub(crate) user_id: UserId,
28 pub(crate) conn: DatabaseConnection,
29}
30
31#[derive(Debug)]
35pub struct RestrictedTransaction {
36 user_id: UserId,
37 conn: DatabaseTransaction,
38 rbac: RbacEngineMount,
39}
40
41#[derive(Debug, Default, Clone)]
42pub(crate) struct RbacEngineMount {
43 inner: Arc<RwLock<Option<RbacEngine>>>,
44}
45
46impl ConnectionTrait for RestrictedConnection {
47 fn get_database_backend(&self) -> DbBackend {
48 self.conn.get_database_backend()
49 }
50
51 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
52 Err(DbErr::RbacError(format!(
53 "Raw query is not supported: {stmt}"
54 )))
55 }
56
57 fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
58 self.user_can_run(stmt)?;
59 self.conn.execute(stmt)
60 }
61
62 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
63 Err(DbErr::RbacError(format!(
64 "Raw query is not supported: {sql}"
65 )))
66 }
67
68 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
69 Err(DbErr::RbacError(format!(
70 "Raw query is not supported: {stmt}"
71 )))
72 }
73
74 fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
75 self.user_can_run(stmt)?;
76 self.conn.query_one(stmt)
77 }
78
79 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
80 Err(DbErr::RbacError(format!(
81 "Raw query is not supported: {stmt}"
82 )))
83 }
84
85 fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
86 self.user_can_run(stmt)?;
87 self.conn.query_all(stmt)
88 }
89}
90
91impl ConnectionTrait for RestrictedTransaction {
92 fn get_database_backend(&self) -> DbBackend {
93 self.conn.get_database_backend()
94 }
95
96 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
97 Err(DbErr::RbacError(format!(
98 "Raw query is not supported: {stmt}"
99 )))
100 }
101
102 fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
103 self.user_can_run(stmt)?;
104 self.conn.execute(stmt)
105 }
106
107 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
108 Err(DbErr::RbacError(format!(
109 "Raw query is not supported: {sql}"
110 )))
111 }
112
113 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
114 Err(DbErr::RbacError(format!(
115 "Raw query is not supported: {stmt}"
116 )))
117 }
118
119 fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
120 self.user_can_run(stmt)?;
121 self.conn.query_one(stmt)
122 }
123
124 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
125 Err(DbErr::RbacError(format!(
126 "Raw query is not supported: {stmt}"
127 )))
128 }
129
130 fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
131 self.user_can_run(stmt)?;
132 self.conn.query_all(stmt)
133 }
134}
135
136impl RestrictedConnection {
137 pub fn user_id(&self) -> UserId {
139 self.user_id
140 }
141
142 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
145 self.conn.rbac.user_can_run(self.user_id, stmt)
146 }
147
148 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
150 where
151 P: Into<PermissionRequest>,
152 R: Into<ResourceRequest>,
153 {
154 self.conn.rbac.user_can(self.user_id, permission, resource)
155 }
156
157 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
160 self.conn.rbac.user_role_permissions(self.user_id)
161 }
162
163 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
166 self.conn.rbac.roles_and_ranks()
167 }
168
169 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
171 self.conn.rbac.resources_and_permissions()
172 }
173
174 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
176 self.conn.rbac.role_hierarchy_edges(role_id)
177 }
178
179 pub fn role_permissions_by_resources(
182 &self,
183 role_id: RoleId,
184 ) -> Result<RbacPermissionsByResources, DbErr> {
185 self.conn.rbac.role_permissions_by_resources(role_id)
186 }
187}
188
189impl RestrictedTransaction {
190 pub fn user_id(&self) -> UserId {
192 self.user_id
193 }
194
195 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
198 self.rbac.user_can_run(self.user_id, stmt)
199 }
200
201 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
203 where
204 P: Into<PermissionRequest>,
205 R: Into<ResourceRequest>,
206 {
207 self.rbac.user_can(self.user_id, permission, resource)
208 }
209}
210
211impl TransactionTrait for RestrictedConnection {
212 type Transaction = RestrictedTransaction;
213
214 #[instrument(level = "trace")]
215 fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
216 Ok(RestrictedTransaction {
217 user_id: self.user_id,
218 conn: self.conn.begin()?,
219 rbac: self.conn.rbac.clone(),
220 })
221 }
222
223 #[instrument(level = "trace")]
224 fn begin_with_config(
225 &self,
226 isolation_level: Option<IsolationLevel>,
227 access_mode: Option<AccessMode>,
228 ) -> Result<RestrictedTransaction, DbErr> {
229 Ok(RestrictedTransaction {
230 user_id: self.user_id,
231 conn: self.conn.begin_with_config(isolation_level, access_mode)?,
232 rbac: self.conn.rbac.clone(),
233 })
234 }
235
236 #[instrument(level = "trace")]
237 fn begin_with_options(
238 &self,
239 options: TransactionOptions,
240 ) -> Result<RestrictedTransaction, DbErr> {
241 Ok(RestrictedTransaction {
242 user_id: self.user_id,
243 conn: self.conn.begin_with_options(options)?,
244 rbac: self.conn.rbac.clone(),
245 })
246 }
247
248 #[instrument(level = "trace", skip(callback))]
251 fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
252 where
253 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
254 E: std::fmt::Display + std::fmt::Debug,
255 {
256 let transaction = self.begin().map_err(TransactionError::Connection)?;
257 transaction.run(callback)
258 }
259
260 #[instrument(level = "trace", skip(callback))]
263 fn transaction_with_config<F, T, E>(
264 &self,
265 callback: F,
266 isolation_level: Option<IsolationLevel>,
267 access_mode: Option<AccessMode>,
268 ) -> Result<T, TransactionError<E>>
269 where
270 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
271 E: std::fmt::Display + std::fmt::Debug,
272 {
273 let transaction = self
274 .begin_with_config(isolation_level, access_mode)
275 .map_err(TransactionError::Connection)?;
276 transaction.run(callback)
277 }
278}
279
280impl TransactionTrait for RestrictedTransaction {
281 type Transaction = RestrictedTransaction;
282
283 #[instrument(level = "trace")]
284 fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
285 Ok(RestrictedTransaction {
286 user_id: self.user_id,
287 conn: self.conn.begin()?,
288 rbac: self.rbac.clone(),
289 })
290 }
291
292 #[instrument(level = "trace")]
293 fn begin_with_config(
294 &self,
295 isolation_level: Option<IsolationLevel>,
296 access_mode: Option<AccessMode>,
297 ) -> Result<RestrictedTransaction, DbErr> {
298 Ok(RestrictedTransaction {
299 user_id: self.user_id,
300 conn: self.conn.begin_with_config(isolation_level, access_mode)?,
301 rbac: self.rbac.clone(),
302 })
303 }
304
305 #[instrument(level = "trace")]
306 fn begin_with_options(
307 &self,
308 options: TransactionOptions,
309 ) -> Result<RestrictedTransaction, DbErr> {
310 Ok(RestrictedTransaction {
311 user_id: self.user_id,
312 conn: self.conn.begin_with_options(options)?,
313 rbac: self.rbac.clone(),
314 })
315 }
316
317 #[instrument(level = "trace", skip(callback))]
320 fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
321 where
322 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
323 E: std::fmt::Display + std::fmt::Debug,
324 {
325 let transaction = self.begin().map_err(TransactionError::Connection)?;
326 transaction.run(callback)
327 }
328
329 #[instrument(level = "trace", skip(callback))]
332 fn transaction_with_config<F, T, E>(
333 &self,
334 callback: F,
335 isolation_level: Option<IsolationLevel>,
336 access_mode: Option<AccessMode>,
337 ) -> Result<T, TransactionError<E>>
338 where
339 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
340 E: std::fmt::Display + std::fmt::Debug,
341 {
342 let transaction = self
343 .begin_with_config(isolation_level, access_mode)
344 .map_err(TransactionError::Connection)?;
345 transaction.run(callback)
346 }
347}
348
349impl TransactionSession for RestrictedTransaction {
350 fn commit(self) -> Result<(), DbErr> {
351 self.commit()
352 }
353
354 fn rollback(self) -> Result<(), DbErr> {
355 self.rollback()
356 }
357}
358
359impl RestrictedTransaction {
360 #[instrument(level = "trace", skip(callback))]
363 fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
364 where
365 F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
366 E: std::fmt::Display + std::fmt::Debug,
367 {
368 let res = callback(&self).map_err(TransactionError::Transaction);
369 if res.is_ok() {
370 self.commit().map_err(TransactionError::Connection)?;
371 } else {
372 self.rollback().map_err(TransactionError::Connection)?;
373 }
374 res
375 }
376
377 #[instrument(level = "trace")]
379 pub fn commit(self) -> Result<(), DbErr> {
380 self.conn.commit()
381 }
382
383 #[instrument(level = "trace")]
385 pub fn rollback(self) -> Result<(), DbErr> {
386 self.conn.rollback()
387 }
388}
389
390impl RbacEngineMount {
391 pub fn is_some(&self) -> bool {
392 let engine = self.inner.read().expect("RBAC Engine died");
393 engine.is_some()
394 }
395
396 pub fn replace(&self, engine: RbacEngine) {
397 let mut inner = self.inner.write().expect("RBAC Engine died");
398 *inner = Some(engine);
399 }
400
401 pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
402 where
403 P: Into<PermissionRequest>,
404 R: Into<ResourceRequest>,
405 {
406 let permission = permission.into();
407 let resource = resource.into();
408 let holder = self.inner.read().expect("RBAC Engine died");
410 let engine = holder.as_ref().expect("RBAC Engine not set");
412 engine
413 .user_can(user_id, permission, resource)
414 .map_err(map_err)
415 }
416
417 pub fn user_can_run<S: StatementBuilder>(
418 &self,
419 user_id: UserId,
420 stmt: &S,
421 ) -> Result<(), DbErr> {
422 let audit = match stmt.audit() {
423 Ok(audit) => audit,
424 Err(err) => return Err(DbErr::RbacError(err.to_string())),
425 };
426 let holder = self.inner.read().expect("RBAC Engine died");
428 let engine = holder.as_ref().expect("RBAC Engine not set");
430 for request in audit.requests {
431 let permission = || PermissionRequest {
432 action: request.access_type.as_str().to_owned(),
433 };
434 let resource = || ResourceRequest {
435 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
436 table: request.schema_table.1.to_string(),
437 };
438 if !engine
439 .user_can(user_id, permission(), resource())
440 .map_err(map_err)?
441 {
442 return Err(DbErr::AccessDenied {
443 permission: permission().action.to_owned(),
444 resource: resource().to_string(),
445 });
446 }
447 }
448 Ok(())
449 }
450
451 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
452 let holder = self.inner.read().expect("RBAC Engine died");
453 let engine = holder.as_ref().expect("RBAC Engine not set");
454 engine
455 .get_user_role_permissions(user_id)
456 .map_err(|err| DbErr::RbacError(err.to_string()))
457 }
458
459 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
460 let holder = self.inner.read().expect("RBAC Engine died");
461 let engine = holder.as_ref().expect("RBAC Engine not set");
462 engine
463 .get_roles_and_ranks()
464 .map_err(|err| DbErr::RbacError(err.to_string()))
465 }
466
467 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
468 let holder = self.inner.read().expect("RBAC Engine died");
469 let engine = holder.as_ref().expect("RBAC Engine not set");
470 Ok(engine.list_resources_and_permissions())
471 }
472
473 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
474 let holder = self.inner.read().expect("RBAC Engine died");
475 let engine = holder.as_ref().expect("RBAC Engine not set");
476 Ok(engine.list_role_hierarchy_edges(role_id))
477 }
478
479 pub fn role_permissions_by_resources(
480 &self,
481 role_id: RoleId,
482 ) -> Result<RbacPermissionsByResources, DbErr> {
483 let holder = self.inner.read().expect("RBAC Engine died");
484 let engine = holder.as_ref().expect("RBAC Engine not set");
485 engine
486 .list_role_permissions_by_resources(role_id)
487 .map_err(|err| DbErr::RbacError(err.to_string()))
488 }
489}
490
491fn map_err(err: RbacError) -> DbErr {
492 DbErr::RbacError(err.to_string())
493}