Skip to main content

sea_orm/database/
restricted_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
3    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
4    TransactionSession, TransactionTrait,
5};
6use crate::{
7    TransactionOptions,
8    rbac::{
9        PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
10        RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
11        RbacUserRolePermissions, ResourceRequest,
12        entity::{role::RoleId, user::UserId},
13    },
14};
15use std::{
16    pin::Pin,
17    sync::{Arc, RwLock},
18};
19use tracing::instrument;
20
21/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
22/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
23/// currently.
24#[derive(Debug, Clone)]
25#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
26pub struct RestrictedConnection {
27    pub(crate) user_id: UserId,
28    pub(crate) conn: DatabaseConnection,
29}
30
31/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
32/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
33/// currently.
34#[derive(Debug)]
35pub struct RestrictedTransaction {
36    user_id: UserId,
37    conn: DatabaseTransaction,
38    rbac: RbacEngineMount,
39}
40
41#[derive(Debug, Default, Clone)]
42pub(crate) struct RbacEngineMount {
43    inner: Arc<RwLock<Option<RbacEngine>>>,
44}
45
46#[async_trait::async_trait]
47impl ConnectionTrait for RestrictedConnection {
48    fn get_database_backend(&self) -> DbBackend {
49        self.conn.get_database_backend()
50    }
51
52    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
53        Err(DbErr::RbacError(format!(
54            "Raw query is not supported: {stmt}"
55        )))
56    }
57
58    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
59        self.user_can_run(stmt)?;
60        self.conn.execute(stmt).await
61    }
62
63    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
64        Err(DbErr::RbacError(format!(
65            "Raw query is not supported: {sql}"
66        )))
67    }
68
69    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
70        Err(DbErr::RbacError(format!(
71            "Raw query is not supported: {stmt}"
72        )))
73    }
74
75    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
76        self.user_can_run(stmt)?;
77        self.conn.query_one(stmt).await
78    }
79
80    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
81        Err(DbErr::RbacError(format!(
82            "Raw query is not supported: {stmt}"
83        )))
84    }
85
86    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
87        self.user_can_run(stmt)?;
88        self.conn.query_all(stmt).await
89    }
90}
91
92#[async_trait::async_trait]
93impl ConnectionTrait for RestrictedTransaction {
94    fn get_database_backend(&self) -> DbBackend {
95        self.conn.get_database_backend()
96    }
97
98    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
99        Err(DbErr::RbacError(format!(
100            "Raw query is not supported: {stmt}"
101        )))
102    }
103
104    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
105        self.user_can_run(stmt)?;
106        self.conn.execute(stmt).await
107    }
108
109    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
110        Err(DbErr::RbacError(format!(
111            "Raw query is not supported: {sql}"
112        )))
113    }
114
115    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
116        Err(DbErr::RbacError(format!(
117            "Raw query is not supported: {stmt}"
118        )))
119    }
120
121    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
122        self.user_can_run(stmt)?;
123        self.conn.query_one(stmt).await
124    }
125
126    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
127        Err(DbErr::RbacError(format!(
128            "Raw query is not supported: {stmt}"
129        )))
130    }
131
132    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
133        self.user_can_run(stmt)?;
134        self.conn.query_all(stmt).await
135    }
136}
137
138impl RestrictedConnection {
139    /// Get the [`RbacUserId`] bounded to this connection.
140    pub fn user_id(&self) -> UserId {
141        self.user_id
142    }
143
144    /// Returns `()` if the current user can execute / query the given SQL statement.
145    /// Returns `DbErr::AccessDenied` otherwise.
146    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
147        self.conn.rbac.user_can_run(self.user_id, stmt)
148    }
149
150    /// Returns true if the current user can perform action on resource
151    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
152    where
153        P: Into<PermissionRequest>,
154        R: Into<ResourceRequest>,
155    {
156        self.conn.rbac.user_can(self.user_id, permission, resource)
157    }
158
159    /// Get current user's role and associated permissions.
160    /// This includes permissions "inherited" from child roles.
161    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
162        self.conn.rbac.user_role_permissions(self.user_id)
163    }
164
165    /// Get a list of all roles and their ranks.
166    /// Rank is defined as (1 + number of child roles).
167    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
168        self.conn.rbac.roles_and_ranks()
169    }
170
171    /// Get two lists of all resources and permissions, excluding wildcards.
172    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
173        self.conn.rbac.resources_and_permissions()
174    }
175
176    /// Get a list of edges walking the role hierarchy tree
177    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
178        self.conn.rbac.role_hierarchy_edges(role_id)
179    }
180
181    /// Get a list of permissions for the specific role, grouped by resources.
182    /// This does not include permissions of child roles.
183    pub fn role_permissions_by_resources(
184        &self,
185        role_id: RoleId,
186    ) -> Result<RbacPermissionsByResources, DbErr> {
187        self.conn.rbac.role_permissions_by_resources(role_id)
188    }
189}
190
191impl RestrictedTransaction {
192    /// Get the [`RbacUserId`] bounded to this connection.
193    pub fn user_id(&self) -> UserId {
194        self.user_id
195    }
196
197    /// Returns `()` if the current user can execute / query the given SQL statement.
198    /// Returns `DbErr::AccessDenied` otherwise.
199    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
200        self.rbac.user_can_run(self.user_id, stmt)
201    }
202
203    /// Returns true if the current user can perform action on resource
204    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
205    where
206        P: Into<PermissionRequest>,
207        R: Into<ResourceRequest>,
208    {
209        self.rbac.user_can(self.user_id, permission, resource)
210    }
211}
212
213#[async_trait::async_trait]
214impl TransactionTrait for RestrictedConnection {
215    type Transaction = RestrictedTransaction;
216
217    #[instrument(level = "trace")]
218    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
219        Ok(RestrictedTransaction {
220            user_id: self.user_id,
221            conn: self.conn.begin().await?,
222            rbac: self.conn.rbac.clone(),
223        })
224    }
225
226    #[instrument(level = "trace")]
227    async fn begin_with_config(
228        &self,
229        isolation_level: Option<IsolationLevel>,
230        access_mode: Option<AccessMode>,
231    ) -> Result<RestrictedTransaction, DbErr> {
232        Ok(RestrictedTransaction {
233            user_id: self.user_id,
234            conn: self
235                .conn
236                .begin_with_config(isolation_level, access_mode)
237                .await?,
238            rbac: self.conn.rbac.clone(),
239        })
240    }
241
242    #[instrument(level = "trace")]
243    async fn begin_with_options(
244        &self,
245        options: TransactionOptions,
246    ) -> Result<RestrictedTransaction, DbErr> {
247        Ok(RestrictedTransaction {
248            user_id: self.user_id,
249            conn: self.conn.begin_with_options(options).await?,
250            rbac: self.conn.rbac.clone(),
251        })
252    }
253
254    /// Execute the function inside a transaction.
255    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
256    #[instrument(level = "trace", skip(callback))]
257    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
258    where
259        F: for<'c> FnOnce(
260                &'c RestrictedTransaction,
261            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
262            + Send,
263        T: Send,
264        E: std::fmt::Display + std::fmt::Debug + Send,
265    {
266        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
267        transaction.run(callback).await
268    }
269
270    /// Execute the function inside a transaction.
271    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
272    #[instrument(level = "trace", skip(callback))]
273    async fn transaction_with_config<F, T, E>(
274        &self,
275        callback: F,
276        isolation_level: Option<IsolationLevel>,
277        access_mode: Option<AccessMode>,
278    ) -> Result<T, TransactionError<E>>
279    where
280        F: for<'c> FnOnce(
281                &'c RestrictedTransaction,
282            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
283            + Send,
284        T: Send,
285        E: std::fmt::Display + std::fmt::Debug + Send,
286    {
287        let transaction = self
288            .begin_with_config(isolation_level, access_mode)
289            .await
290            .map_err(TransactionError::Connection)?;
291        transaction.run(callback).await
292    }
293}
294
295#[async_trait::async_trait]
296impl TransactionTrait for RestrictedTransaction {
297    type Transaction = RestrictedTransaction;
298
299    #[instrument(level = "trace")]
300    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
301        Ok(RestrictedTransaction {
302            user_id: self.user_id,
303            conn: self.conn.begin().await?,
304            rbac: self.rbac.clone(),
305        })
306    }
307
308    #[instrument(level = "trace")]
309    async fn begin_with_config(
310        &self,
311        isolation_level: Option<IsolationLevel>,
312        access_mode: Option<AccessMode>,
313    ) -> Result<RestrictedTransaction, DbErr> {
314        Ok(RestrictedTransaction {
315            user_id: self.user_id,
316            conn: self
317                .conn
318                .begin_with_config(isolation_level, access_mode)
319                .await?,
320            rbac: self.rbac.clone(),
321        })
322    }
323
324    #[instrument(level = "trace")]
325    async fn begin_with_options(
326        &self,
327        options: TransactionOptions,
328    ) -> Result<RestrictedTransaction, DbErr> {
329        Ok(RestrictedTransaction {
330            user_id: self.user_id,
331            conn: self.conn.begin_with_options(options).await?,
332            rbac: self.rbac.clone(),
333        })
334    }
335
336    /// Execute the function inside a transaction.
337    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
338    #[instrument(level = "trace", skip(callback))]
339    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
340    where
341        F: for<'c> FnOnce(
342                &'c RestrictedTransaction,
343            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
344            + Send,
345        T: Send,
346        E: std::fmt::Display + std::fmt::Debug + Send,
347    {
348        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
349        transaction.run(callback).await
350    }
351
352    /// Execute the function inside a transaction.
353    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
354    #[instrument(level = "trace", skip(callback))]
355    async fn transaction_with_config<F, T, E>(
356        &self,
357        callback: F,
358        isolation_level: Option<IsolationLevel>,
359        access_mode: Option<AccessMode>,
360    ) -> Result<T, TransactionError<E>>
361    where
362        F: for<'c> FnOnce(
363                &'c RestrictedTransaction,
364            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
365            + Send,
366        T: Send,
367        E: std::fmt::Display + std::fmt::Debug + Send,
368    {
369        let transaction = self
370            .begin_with_config(isolation_level, access_mode)
371            .await
372            .map_err(TransactionError::Connection)?;
373        transaction.run(callback).await
374    }
375}
376
377#[async_trait::async_trait]
378impl TransactionSession for RestrictedTransaction {
379    async fn commit(self) -> Result<(), DbErr> {
380        self.commit().await
381    }
382
383    async fn rollback(self) -> Result<(), DbErr> {
384        self.rollback().await
385    }
386}
387
388impl RestrictedTransaction {
389    /// Runs a transaction to completion passing through the result.
390    /// Rolling back the transaction on encountering an error.
391    #[instrument(level = "trace", skip(callback))]
392    async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
393    where
394        F: for<'b> FnOnce(
395                &'b RestrictedTransaction,
396            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
397            + Send,
398        T: Send,
399        E: std::fmt::Display + std::fmt::Debug + Send,
400    {
401        let res = callback(&self).await.map_err(TransactionError::Transaction);
402        if res.is_ok() {
403            self.commit().await.map_err(TransactionError::Connection)?;
404        } else {
405            self.rollback()
406                .await
407                .map_err(TransactionError::Connection)?;
408        }
409        res
410    }
411
412    /// Commit a transaction
413    #[instrument(level = "trace")]
414    pub async fn commit(self) -> Result<(), DbErr> {
415        self.conn.commit().await
416    }
417
418    /// Rolls back a transaction explicitly
419    #[instrument(level = "trace")]
420    pub async fn rollback(self) -> Result<(), DbErr> {
421        self.conn.rollback().await
422    }
423}
424
425impl RbacEngineMount {
426    pub fn is_some(&self) -> bool {
427        let engine = self.inner.read().expect("RBAC Engine died");
428        engine.is_some()
429    }
430
431    pub fn replace(&self, engine: RbacEngine) {
432        let mut inner = self.inner.write().expect("RBAC Engine died");
433        *inner = Some(engine);
434    }
435
436    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
437    where
438        P: Into<PermissionRequest>,
439        R: Into<ResourceRequest>,
440    {
441        let permission = permission.into();
442        let resource = resource.into();
443        // There is nothing we can do if RwLock is poisoned.
444        let holder = self.inner.read().expect("RBAC Engine died");
445        // Constructor of this struct should ensure engine is not None.
446        let engine = holder.as_ref().expect("RBAC Engine not set");
447        engine
448            .user_can(user_id, permission, resource)
449            .map_err(map_err)
450    }
451
452    pub fn user_can_run<S: StatementBuilder>(
453        &self,
454        user_id: UserId,
455        stmt: &S,
456    ) -> Result<(), DbErr> {
457        let audit = match stmt.audit() {
458            Ok(audit) => audit,
459            Err(err) => return Err(DbErr::RbacError(err.to_string())),
460        };
461        // There is nothing we can do if RwLock is poisoned.
462        let holder = self.inner.read().expect("RBAC Engine died");
463        // Constructor of this struct should ensure engine is not None.
464        let engine = holder.as_ref().expect("RBAC Engine not set");
465        for request in audit.requests {
466            let permission = || PermissionRequest {
467                action: request.access_type.as_str().to_owned(),
468            };
469            let resource = || ResourceRequest {
470                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
471                table: request.schema_table.1.to_string(),
472            };
473            if !engine
474                .user_can(user_id, permission(), resource())
475                .map_err(map_err)?
476            {
477                return Err(DbErr::AccessDenied {
478                    permission: permission().action.to_owned(),
479                    resource: resource().to_string(),
480                });
481            }
482        }
483        Ok(())
484    }
485
486    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
487        let holder = self.inner.read().expect("RBAC Engine died");
488        let engine = holder.as_ref().expect("RBAC Engine not set");
489        engine
490            .get_user_role_permissions(user_id)
491            .map_err(|err| DbErr::RbacError(err.to_string()))
492    }
493
494    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
495        let holder = self.inner.read().expect("RBAC Engine died");
496        let engine = holder.as_ref().expect("RBAC Engine not set");
497        engine
498            .get_roles_and_ranks()
499            .map_err(|err| DbErr::RbacError(err.to_string()))
500    }
501
502    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
503        let holder = self.inner.read().expect("RBAC Engine died");
504        let engine = holder.as_ref().expect("RBAC Engine not set");
505        Ok(engine.list_resources_and_permissions())
506    }
507
508    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
509        let holder = self.inner.read().expect("RBAC Engine died");
510        let engine = holder.as_ref().expect("RBAC Engine not set");
511        Ok(engine.list_role_hierarchy_edges(role_id))
512    }
513
514    pub fn role_permissions_by_resources(
515        &self,
516        role_id: RoleId,
517    ) -> Result<RbacPermissionsByResources, DbErr> {
518        let holder = self.inner.read().expect("RBAC Engine died");
519        let engine = holder.as_ref().expect("RBAC Engine not set");
520        engine
521            .list_role_permissions_by_resources(role_id)
522            .map_err(|err| DbErr::RbacError(err.to_string()))
523    }
524}
525
526fn map_err(err: RbacError) -> DbErr {
527    DbErr::RbacError(err.to_string())
528}