1use crate::rbac::{
2 PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
3 RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions,
4 ResourceRequest,
5 entity::{role::RoleId, user::UserId},
6};
7use crate::{
8 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
9 ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
10 TransactionSession, TransactionTrait,
11};
12use std::{
13 pin::Pin,
14 sync::{Arc, RwLock},
15};
16use tracing::instrument;
17
18#[derive(Debug, Clone)]
22#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
23pub struct RestrictedConnection {
24 pub(crate) user_id: UserId,
25 pub(crate) conn: DatabaseConnection,
26}
27
28#[derive(Debug)]
32pub struct RestrictedTransaction {
33 user_id: UserId,
34 conn: DatabaseTransaction,
35 rbac: RbacEngineMount,
36}
37
38#[derive(Debug, Default, Clone)]
39pub(crate) struct RbacEngineMount {
40 inner: Arc<RwLock<Option<RbacEngine>>>,
41}
42
43impl ConnectionTrait for RestrictedConnection {
44 fn get_database_backend(&self) -> DbBackend {
45 self.conn.get_database_backend()
46 }
47
48 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
49 Err(DbErr::RbacError(format!(
50 "Raw query is not supported: {stmt}"
51 )))
52 }
53
54 fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
55 self.user_can_run(stmt)?;
56 self.conn.execute(stmt)
57 }
58
59 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
60 Err(DbErr::RbacError(format!(
61 "Raw query is not supported: {sql}"
62 )))
63 }
64
65 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
66 Err(DbErr::RbacError(format!(
67 "Raw query is not supported: {stmt}"
68 )))
69 }
70
71 fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
72 self.user_can_run(stmt)?;
73 self.conn.query_one(stmt)
74 }
75
76 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
77 Err(DbErr::RbacError(format!(
78 "Raw query is not supported: {stmt}"
79 )))
80 }
81
82 fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
83 self.user_can_run(stmt)?;
84 self.conn.query_all(stmt)
85 }
86}
87
88impl ConnectionTrait for RestrictedTransaction {
89 fn get_database_backend(&self) -> DbBackend {
90 self.conn.get_database_backend()
91 }
92
93 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
94 Err(DbErr::RbacError(format!(
95 "Raw query is not supported: {stmt}"
96 )))
97 }
98
99 fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
100 self.user_can_run(stmt)?;
101 self.conn.execute(stmt)
102 }
103
104 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
105 Err(DbErr::RbacError(format!(
106 "Raw query is not supported: {sql}"
107 )))
108 }
109
110 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
111 Err(DbErr::RbacError(format!(
112 "Raw query is not supported: {stmt}"
113 )))
114 }
115
116 fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
117 self.user_can_run(stmt)?;
118 self.conn.query_one(stmt)
119 }
120
121 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
122 Err(DbErr::RbacError(format!(
123 "Raw query is not supported: {stmt}"
124 )))
125 }
126
127 fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
128 self.user_can_run(stmt)?;
129 self.conn.query_all(stmt)
130 }
131}
132
133impl RestrictedConnection {
134 pub fn user_id(&self) -> UserId {
136 self.user_id
137 }
138
139 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
142 self.conn.rbac.user_can_run(self.user_id, stmt)
143 }
144
145 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
147 where
148 P: Into<PermissionRequest>,
149 R: Into<ResourceRequest>,
150 {
151 self.conn.rbac.user_can(self.user_id, permission, resource)
152 }
153
154 pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
157 self.conn.rbac.user_role_permissions(self.user_id)
158 }
159
160 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
163 self.conn.rbac.roles_and_ranks()
164 }
165
166 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
168 self.conn.rbac.resources_and_permissions()
169 }
170
171 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
173 self.conn.rbac.role_hierarchy_edges(role_id)
174 }
175
176 pub fn role_permissions_by_resources(
179 &self,
180 role_id: RoleId,
181 ) -> Result<RbacPermissionsByResources, DbErr> {
182 self.conn.rbac.role_permissions_by_resources(role_id)
183 }
184}
185
186impl RestrictedTransaction {
187 pub fn user_id(&self) -> UserId {
189 self.user_id
190 }
191
192 pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
195 self.rbac.user_can_run(self.user_id, stmt)
196 }
197
198 pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
200 where
201 P: Into<PermissionRequest>,
202 R: Into<ResourceRequest>,
203 {
204 self.rbac.user_can(self.user_id, permission, resource)
205 }
206}
207
208impl TransactionTrait for RestrictedConnection {
209 type Transaction = RestrictedTransaction;
210
211 #[instrument(level = "trace")]
212 fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
213 Ok(RestrictedTransaction {
214 user_id: self.user_id,
215 conn: self.conn.begin()?,
216 rbac: self.conn.rbac.clone(),
217 })
218 }
219
220 #[instrument(level = "trace")]
221 fn begin_with_config(
222 &self,
223 isolation_level: Option<IsolationLevel>,
224 access_mode: Option<AccessMode>,
225 ) -> Result<RestrictedTransaction, DbErr> {
226 Ok(RestrictedTransaction {
227 user_id: self.user_id,
228 conn: self.conn.begin_with_config(isolation_level, access_mode)?,
229 rbac: self.conn.rbac.clone(),
230 })
231 }
232
233 #[instrument(level = "trace", skip(callback))]
236 fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
237 where
238 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
239 E: std::fmt::Display + std::fmt::Debug,
240 {
241 let transaction = self.begin().map_err(TransactionError::Connection)?;
242 transaction.run(callback)
243 }
244
245 #[instrument(level = "trace", skip(callback))]
248 fn transaction_with_config<F, T, E>(
249 &self,
250 callback: F,
251 isolation_level: Option<IsolationLevel>,
252 access_mode: Option<AccessMode>,
253 ) -> Result<T, TransactionError<E>>
254 where
255 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
256 E: std::fmt::Display + std::fmt::Debug,
257 {
258 let transaction = self
259 .begin_with_config(isolation_level, access_mode)
260 .map_err(TransactionError::Connection)?;
261 transaction.run(callback)
262 }
263}
264
265impl TransactionTrait for RestrictedTransaction {
266 type Transaction = RestrictedTransaction;
267
268 #[instrument(level = "trace")]
269 fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
270 Ok(RestrictedTransaction {
271 user_id: self.user_id,
272 conn: self.conn.begin()?,
273 rbac: self.rbac.clone(),
274 })
275 }
276
277 #[instrument(level = "trace")]
278 fn begin_with_config(
279 &self,
280 isolation_level: Option<IsolationLevel>,
281 access_mode: Option<AccessMode>,
282 ) -> Result<RestrictedTransaction, DbErr> {
283 Ok(RestrictedTransaction {
284 user_id: self.user_id,
285 conn: self.conn.begin_with_config(isolation_level, access_mode)?,
286 rbac: self.rbac.clone(),
287 })
288 }
289
290 #[instrument(level = "trace", skip(callback))]
293 fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
294 where
295 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
296 E: std::fmt::Display + std::fmt::Debug,
297 {
298 let transaction = self.begin().map_err(TransactionError::Connection)?;
299 transaction.run(callback)
300 }
301
302 #[instrument(level = "trace", skip(callback))]
305 fn transaction_with_config<F, T, E>(
306 &self,
307 callback: F,
308 isolation_level: Option<IsolationLevel>,
309 access_mode: Option<AccessMode>,
310 ) -> Result<T, TransactionError<E>>
311 where
312 F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
313 E: std::fmt::Display + std::fmt::Debug,
314 {
315 let transaction = self
316 .begin_with_config(isolation_level, access_mode)
317 .map_err(TransactionError::Connection)?;
318 transaction.run(callback)
319 }
320}
321
322impl TransactionSession for RestrictedTransaction {
323 fn commit(self) -> Result<(), DbErr> {
324 self.commit()
325 }
326
327 fn rollback(self) -> Result<(), DbErr> {
328 self.rollback()
329 }
330}
331
332impl RestrictedTransaction {
333 #[instrument(level = "trace", skip(callback))]
336 fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
337 where
338 F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
339 E: std::fmt::Display + std::fmt::Debug,
340 {
341 let res = callback(&self).map_err(TransactionError::Transaction);
342 if res.is_ok() {
343 self.commit().map_err(TransactionError::Connection)?;
344 } else {
345 self.rollback().map_err(TransactionError::Connection)?;
346 }
347 res
348 }
349
350 #[instrument(level = "trace")]
352 pub fn commit(self) -> Result<(), DbErr> {
353 self.conn.commit()
354 }
355
356 #[instrument(level = "trace")]
358 pub fn rollback(self) -> Result<(), DbErr> {
359 self.conn.rollback()
360 }
361}
362
363impl RbacEngineMount {
364 pub fn is_some(&self) -> bool {
365 let engine = self.inner.read().expect("RBAC Engine died");
366 engine.is_some()
367 }
368
369 pub fn replace(&self, engine: RbacEngine) {
370 let mut inner = self.inner.write().expect("RBAC Engine died");
371 *inner = Some(engine);
372 }
373
374 pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
375 where
376 P: Into<PermissionRequest>,
377 R: Into<ResourceRequest>,
378 {
379 let permission = permission.into();
380 let resource = resource.into();
381 let holder = self.inner.read().expect("RBAC Engine died");
383 let engine = holder.as_ref().expect("RBAC Engine not set");
385 engine
386 .user_can(user_id, permission, resource)
387 .map_err(map_err)
388 }
389
390 pub fn user_can_run<S: StatementBuilder>(
391 &self,
392 user_id: UserId,
393 stmt: &S,
394 ) -> Result<(), DbErr> {
395 let audit = match stmt.audit() {
396 Ok(audit) => audit,
397 Err(err) => return Err(DbErr::RbacError(err.to_string())),
398 };
399 let holder = self.inner.read().expect("RBAC Engine died");
401 let engine = holder.as_ref().expect("RBAC Engine not set");
403 for request in audit.requests {
404 let permission = || PermissionRequest {
405 action: request.access_type.as_str().to_owned(),
406 };
407 let resource = || ResourceRequest {
408 schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
409 table: request.schema_table.1.to_string(),
410 };
411 if !engine
412 .user_can(user_id, permission(), resource())
413 .map_err(map_err)?
414 {
415 return Err(DbErr::AccessDenied {
416 permission: permission().action.to_owned(),
417 resource: resource().to_string(),
418 });
419 }
420 }
421 Ok(())
422 }
423
424 pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
425 let holder = self.inner.read().expect("RBAC Engine died");
426 let engine = holder.as_ref().expect("RBAC Engine not set");
427 engine
428 .get_user_role_permissions(user_id)
429 .map_err(|err| DbErr::RbacError(err.to_string()))
430 }
431
432 pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
433 let holder = self.inner.read().expect("RBAC Engine died");
434 let engine = holder.as_ref().expect("RBAC Engine not set");
435 engine
436 .get_roles_and_ranks()
437 .map_err(|err| DbErr::RbacError(err.to_string()))
438 }
439
440 pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
441 let holder = self.inner.read().expect("RBAC Engine died");
442 let engine = holder.as_ref().expect("RBAC Engine not set");
443 Ok(engine.list_resources_and_permissions())
444 }
445
446 pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
447 let holder = self.inner.read().expect("RBAC Engine died");
448 let engine = holder.as_ref().expect("RBAC Engine not set");
449 Ok(engine.list_role_hierarchy_edges(role_id))
450 }
451
452 pub fn role_permissions_by_resources(
453 &self,
454 role_id: RoleId,
455 ) -> Result<RbacPermissionsByResources, DbErr> {
456 let holder = self.inner.read().expect("RBAC Engine died");
457 let engine = holder.as_ref().expect("RBAC Engine not set");
458 engine
459 .list_role_permissions_by_resources(role_id)
460 .map_err(|err| DbErr::RbacError(err.to_string()))
461 }
462}
463
464fn map_err(err: RbacError) -> DbErr {
465 DbErr::RbacError(err.to_string())
466}