sea_orm/database/restricted_connection.rs

1use crate::rbac::{
2    PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
3    RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions,
4    ResourceRequest,
5    entity::{role::RoleId, user::UserId},
6};
7use crate::{
8    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
9    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
10    TransactionSession, TransactionTrait,
11};
12use std::{
13    pin::Pin,
14    sync::{Arc, RwLock},
15};
16use tracing::instrument;
17
18/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
19/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
20/// currently.
21#[derive(Debug, Clone)]
22#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
23pub struct RestrictedConnection {
24    pub(crate) user_id: UserId,
25    pub(crate) conn: DatabaseConnection,
26}
27
28/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
29/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
30/// currently.
31#[derive(Debug)]
32pub struct RestrictedTransaction {
33    user_id: UserId,
34    conn: DatabaseTransaction,
35    rbac: RbacEngineMount,
36}
37
38#[derive(Debug, Default, Clone)]
39pub(crate) struct RbacEngineMount {
40    inner: Arc<RwLock<Option<RbacEngine>>>,
41}
42
43#[async_trait::async_trait]
44impl ConnectionTrait for RestrictedConnection {
45    fn get_database_backend(&self) -> DbBackend {
46        self.conn.get_database_backend()
47    }
48
49    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
50        Err(DbErr::RbacError(format!(
51            "Raw query is not supported: {stmt}"
52        )))
53    }
54
55    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
56        self.user_can_run(stmt)?;
57        self.conn.execute(stmt).await
58    }
59
60    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
61        Err(DbErr::RbacError(format!(
62            "Raw query is not supported: {sql}"
63        )))
64    }
65
66    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
67        Err(DbErr::RbacError(format!(
68            "Raw query is not supported: {stmt}"
69        )))
70    }
71
72    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
73        self.user_can_run(stmt)?;
74        self.conn.query_one(stmt).await
75    }
76
77    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
78        Err(DbErr::RbacError(format!(
79            "Raw query is not supported: {stmt}"
80        )))
81    }
82
83    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
84        self.user_can_run(stmt)?;
85        self.conn.query_all(stmt).await
86    }
87}
88
89#[async_trait::async_trait]
90impl ConnectionTrait for RestrictedTransaction {
91    fn get_database_backend(&self) -> DbBackend {
92        self.conn.get_database_backend()
93    }
94
95    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
96        Err(DbErr::RbacError(format!(
97            "Raw query is not supported: {stmt}"
98        )))
99    }
100
101    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
102        self.user_can_run(stmt)?;
103        self.conn.execute(stmt).await
104    }
105
106    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
107        Err(DbErr::RbacError(format!(
108            "Raw query is not supported: {sql}"
109        )))
110    }
111
112    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
113        Err(DbErr::RbacError(format!(
114            "Raw query is not supported: {stmt}"
115        )))
116    }
117
118    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
119        self.user_can_run(stmt)?;
120        self.conn.query_one(stmt).await
121    }
122
123    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
124        Err(DbErr::RbacError(format!(
125            "Raw query is not supported: {stmt}"
126        )))
127    }
128
129    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
130        self.user_can_run(stmt)?;
131        self.conn.query_all(stmt).await
132    }
133}
134
135impl RestrictedConnection {
136    /// Get the [`RbacUserId`] bounded to this connection.
137    pub fn user_id(&self) -> UserId {
138        self.user_id
139    }
140
141    /// Returns `()` if the current user can execute / query the given SQL statement.
142    /// Returns `DbErr::AccessDenied` otherwise.
143    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
144        self.conn.rbac.user_can_run(self.user_id, stmt)
145    }
146
147    /// Returns true if the current user can perform action on resource
148    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
149    where
150        P: Into<PermissionRequest>,
151        R: Into<ResourceRequest>,
152    {
153        self.conn.rbac.user_can(self.user_id, permission, resource)
154    }
155
156    /// Get current user's role and associated permissions.
157    /// This includes permissions "inherited" from child roles.
158    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
159        self.conn.rbac.user_role_permissions(self.user_id)
160    }
161
162    /// Get a list of all roles and their ranks.
163    /// Rank is defined as (1 + number of child roles).
164    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
165        self.conn.rbac.roles_and_ranks()
166    }
167
168    /// Get two lists of all resources and permissions, excluding wildcards.
169    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
170        self.conn.rbac.resources_and_permissions()
171    }
172
173    /// Get a list of edges walking the role hierarchy tree
174    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
175        self.conn.rbac.role_hierarchy_edges(role_id)
176    }
177
178    /// Get a list of permissions for the specific role, grouped by resources.
179    /// This does not include permissions of child roles.
180    pub fn role_permissions_by_resources(
181        &self,
182        role_id: RoleId,
183    ) -> Result<RbacPermissionsByResources, DbErr> {
184        self.conn.rbac.role_permissions_by_resources(role_id)
185    }
186}
187
188impl RestrictedTransaction {
189    /// Get the [`RbacUserId`] bounded to this connection.
190    pub fn user_id(&self) -> UserId {
191        self.user_id
192    }
193
194    /// Returns `()` if the current user can execute / query the given SQL statement.
195    /// Returns `DbErr::AccessDenied` otherwise.
196    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
197        self.rbac.user_can_run(self.user_id, stmt)
198    }
199
200    /// Returns true if the current user can perform action on resource
201    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
202    where
203        P: Into<PermissionRequest>,
204        R: Into<ResourceRequest>,
205    {
206        self.rbac.user_can(self.user_id, permission, resource)
207    }
208}
209
210#[async_trait::async_trait]
211impl TransactionTrait for RestrictedConnection {
212    type Transaction = RestrictedTransaction;
213
214    #[instrument(level = "trace")]
215    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
216        Ok(RestrictedTransaction {
217            user_id: self.user_id,
218            conn: self.conn.begin().await?,
219            rbac: self.conn.rbac.clone(),
220        })
221    }
222
223    #[instrument(level = "trace")]
224    async fn begin_with_config(
225        &self,
226        isolation_level: Option<IsolationLevel>,
227        access_mode: Option<AccessMode>,
228    ) -> Result<RestrictedTransaction, DbErr> {
229        Ok(RestrictedTransaction {
230            user_id: self.user_id,
231            conn: self
232                .conn
233                .begin_with_config(isolation_level, access_mode)
234                .await?,
235            rbac: self.conn.rbac.clone(),
236        })
237    }
238
239    /// Execute the function inside a transaction.
240    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
241    #[instrument(level = "trace", skip(callback))]
242    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
243    where
244        F: for<'c> FnOnce(
245                &'c RestrictedTransaction,
246            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
247            + Send,
248        T: Send,
249        E: std::fmt::Display + std::fmt::Debug + Send,
250    {
251        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
252        transaction.run(callback).await
253    }
254
255    /// Execute the function inside a transaction.
256    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
257    #[instrument(level = "trace", skip(callback))]
258    async fn transaction_with_config<F, T, E>(
259        &self,
260        callback: F,
261        isolation_level: Option<IsolationLevel>,
262        access_mode: Option<AccessMode>,
263    ) -> Result<T, TransactionError<E>>
264    where
265        F: for<'c> FnOnce(
266                &'c RestrictedTransaction,
267            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
268            + Send,
269        T: Send,
270        E: std::fmt::Display + std::fmt::Debug + Send,
271    {
272        let transaction = self
273            .begin_with_config(isolation_level, access_mode)
274            .await
275            .map_err(TransactionError::Connection)?;
276        transaction.run(callback).await
277    }
278}
279
280#[async_trait::async_trait]
281impl TransactionTrait for RestrictedTransaction {
282    type Transaction = RestrictedTransaction;
283
284    #[instrument(level = "trace")]
285    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
286        Ok(RestrictedTransaction {
287            user_id: self.user_id,
288            conn: self.conn.begin().await?,
289            rbac: self.rbac.clone(),
290        })
291    }
292
293    #[instrument(level = "trace")]
294    async fn begin_with_config(
295        &self,
296        isolation_level: Option<IsolationLevel>,
297        access_mode: Option<AccessMode>,
298    ) -> Result<RestrictedTransaction, DbErr> {
299        Ok(RestrictedTransaction {
300            user_id: self.user_id,
301            conn: self
302                .conn
303                .begin_with_config(isolation_level, access_mode)
304                .await?,
305            rbac: self.rbac.clone(),
306        })
307    }
308
309    /// Execute the function inside a transaction.
310    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
311    #[instrument(level = "trace", skip(callback))]
312    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
313    where
314        F: for<'c> FnOnce(
315                &'c RestrictedTransaction,
316            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
317            + Send,
318        T: Send,
319        E: std::fmt::Display + std::fmt::Debug + Send,
320    {
321        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
322        transaction.run(callback).await
323    }
324
325    /// Execute the function inside a transaction.
326    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
327    #[instrument(level = "trace", skip(callback))]
328    async fn transaction_with_config<F, T, E>(
329        &self,
330        callback: F,
331        isolation_level: Option<IsolationLevel>,
332        access_mode: Option<AccessMode>,
333    ) -> Result<T, TransactionError<E>>
334    where
335        F: for<'c> FnOnce(
336                &'c RestrictedTransaction,
337            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
338            + Send,
339        T: Send,
340        E: std::fmt::Display + std::fmt::Debug + Send,
341    {
342        let transaction = self
343            .begin_with_config(isolation_level, access_mode)
344            .await
345            .map_err(TransactionError::Connection)?;
346        transaction.run(callback).await
347    }
348}
349
350#[async_trait::async_trait]
351impl TransactionSession for RestrictedTransaction {
352    async fn commit(self) -> Result<(), DbErr> {
353        self.commit().await
354    }
355
356    async fn rollback(self) -> Result<(), DbErr> {
357        self.rollback().await
358    }
359}
360
361impl RestrictedTransaction {
362    /// Runs a transaction to completion passing through the result.
363    /// Rolling back the transaction on encountering an error.
364    #[instrument(level = "trace", skip(callback))]
365    async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
366    where
367        F: for<'b> FnOnce(
368                &'b RestrictedTransaction,
369            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
370            + Send,
371        T: Send,
372        E: std::fmt::Display + std::fmt::Debug + Send,
373    {
374        let res = callback(&self).await.map_err(TransactionError::Transaction);
375        if res.is_ok() {
376            self.commit().await.map_err(TransactionError::Connection)?;
377        } else {
378            self.rollback()
379                .await
380                .map_err(TransactionError::Connection)?;
381        }
382        res
383    }
384
385    /// Commit a transaction
386    #[instrument(level = "trace")]
387    pub async fn commit(self) -> Result<(), DbErr> {
388        self.conn.commit().await
389    }
390
391    /// Rolls back a transaction explicitly
392    #[instrument(level = "trace")]
393    pub async fn rollback(self) -> Result<(), DbErr> {
394        self.conn.rollback().await
395    }
396}
397
398impl RbacEngineMount {
399    pub fn is_some(&self) -> bool {
400        let engine = self.inner.read().expect("RBAC Engine died");
401        engine.is_some()
402    }
403
404    pub fn replace(&self, engine: RbacEngine) {
405        let mut inner = self.inner.write().expect("RBAC Engine died");
406        *inner = Some(engine);
407    }
408
409    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
410    where
411        P: Into<PermissionRequest>,
412        R: Into<ResourceRequest>,
413    {
414        let permission = permission.into();
415        let resource = resource.into();
416        // There is nothing we can do if RwLock is poisoned.
417        let holder = self.inner.read().expect("RBAC Engine died");
418        // Constructor of this struct should ensure engine is not None.
419        let engine = holder.as_ref().expect("RBAC Engine not set");
420        engine
421            .user_can(user_id, permission, resource)
422            .map_err(map_err)
423    }
424
425    pub fn user_can_run<S: StatementBuilder>(
426        &self,
427        user_id: UserId,
428        stmt: &S,
429    ) -> Result<(), DbErr> {
430        let audit = match stmt.audit() {
431            Ok(audit) => audit,
432            Err(err) => return Err(DbErr::RbacError(err.to_string())),
433        };
434        // There is nothing we can do if RwLock is poisoned.
435        let holder = self.inner.read().expect("RBAC Engine died");
436        // Constructor of this struct should ensure engine is not None.
437        let engine = holder.as_ref().expect("RBAC Engine not set");
438        for request in audit.requests {
439            let permission = || PermissionRequest {
440                action: request.access_type.as_str().to_owned(),
441            };
442            let resource = || ResourceRequest {
443                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
444                table: request.schema_table.1.to_string(),
445            };
446            if !engine
447                .user_can(user_id, permission(), resource())
448                .map_err(map_err)?
449            {
450                return Err(DbErr::AccessDenied {
451                    permission: permission().action.to_owned(),
452                    resource: resource().to_string(),
453                });
454            }
455        }
456        Ok(())
457    }
458
459    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
460        let holder = self.inner.read().expect("RBAC Engine died");
461        let engine = holder.as_ref().expect("RBAC Engine not set");
462        engine
463            .get_user_role_permissions(user_id)
464            .map_err(|err| DbErr::RbacError(err.to_string()))
465    }
466
467    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
468        let holder = self.inner.read().expect("RBAC Engine died");
469        let engine = holder.as_ref().expect("RBAC Engine not set");
470        engine
471            .get_roles_and_ranks()
472            .map_err(|err| DbErr::RbacError(err.to_string()))
473    }
474
475    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
476        let holder = self.inner.read().expect("RBAC Engine died");
477        let engine = holder.as_ref().expect("RBAC Engine not set");
478        Ok(engine.list_resources_and_permissions())
479    }
480
481    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
482        let holder = self.inner.read().expect("RBAC Engine died");
483        let engine = holder.as_ref().expect("RBAC Engine not set");
484        Ok(engine.list_role_hierarchy_edges(role_id))
485    }
486
487    pub fn role_permissions_by_resources(
488        &self,
489        role_id: RoleId,
490    ) -> Result<RbacPermissionsByResources, DbErr> {
491        let holder = self.inner.read().expect("RBAC Engine died");
492        let engine = holder.as_ref().expect("RBAC Engine not set");
493        engine
494            .list_role_permissions_by_resources(role_id)
495            .map_err(|err| DbErr::RbacError(err.to_string()))
496    }
497}
498
499fn map_err(err: RbacError) -> DbErr {
500    DbErr::RbacError(err.to_string())
501}