Skip to main content

sea_orm/database/
restricted_connection.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
4    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
5    TransactionSession, TransactionTrait,
6};
7use crate::{
8    TransactionOptions,
9    rbac::{
10        PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
11        RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
12        RbacUserRolePermissions, ResourceRequest,
13        entity::{role::RoleId, user::UserId},
14    },
15};
16use std::{
17    future::Future,
18    pin::Pin,
19    sync::{Arc, RwLock},
20};
21use tracing::instrument;
22
23/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
24/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
25/// currently.
26#[derive(Debug, Clone)]
27#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
28pub struct RestrictedConnection {
29    pub(crate) user_id: UserId,
30    pub(crate) conn: DatabaseConnection,
31}
32
33/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
34/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
35/// currently.
36#[derive(Debug)]
37pub struct RestrictedTransaction {
38    user_id: UserId,
39    conn: DatabaseTransaction,
40    rbac: RbacEngineMount,
41}
42
43#[derive(Debug, Default, Clone)]
44pub(crate) struct RbacEngineMount {
45    inner: Arc<RwLock<Option<RbacEngine>>>,
46}
47
48#[async_trait::async_trait]
49impl ConnectionTrait for RestrictedConnection {
50    fn get_database_backend(&self) -> DbBackend {
51        self.conn.get_database_backend()
52    }
53
54    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
55        Err(DbErr::RbacError(format!(
56            "Raw query is not supported: {stmt}"
57        )))
58    }
59
60    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
61        self.user_can_run(stmt)?;
62        self.conn.execute(stmt).await
63    }
64
65    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
66        Err(DbErr::RbacError(format!(
67            "Raw query is not supported: {sql}"
68        )))
69    }
70
71    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
72        Err(DbErr::RbacError(format!(
73            "Raw query is not supported: {stmt}"
74        )))
75    }
76
77    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
78        self.user_can_run(stmt)?;
79        self.conn.query_one(stmt).await
80    }
81
82    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
83        Err(DbErr::RbacError(format!(
84            "Raw query is not supported: {stmt}"
85        )))
86    }
87
88    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
89        self.user_can_run(stmt)?;
90        self.conn.query_all(stmt).await
91    }
92}
93
94#[async_trait::async_trait]
95impl ConnectionTrait for RestrictedTransaction {
96    fn get_database_backend(&self) -> DbBackend {
97        self.conn.get_database_backend()
98    }
99
100    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
101        Err(DbErr::RbacError(format!(
102            "Raw query is not supported: {stmt}"
103        )))
104    }
105
106    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
107        self.user_can_run(stmt)?;
108        self.conn.execute(stmt).await
109    }
110
111    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
112        Err(DbErr::RbacError(format!(
113            "Raw query is not supported: {sql}"
114        )))
115    }
116
117    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
118        Err(DbErr::RbacError(format!(
119            "Raw query is not supported: {stmt}"
120        )))
121    }
122
123    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
124        self.user_can_run(stmt)?;
125        self.conn.query_one(stmt).await
126    }
127
128    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
129        Err(DbErr::RbacError(format!(
130            "Raw query is not supported: {stmt}"
131        )))
132    }
133
134    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
135        self.user_can_run(stmt)?;
136        self.conn.query_all(stmt).await
137    }
138}
139
140impl RestrictedConnection {
141    /// Get the [`RbacUserId`] bounded to this connection.
142    pub fn user_id(&self) -> UserId {
143        self.user_id
144    }
145
146    /// Execute the function inside a transaction.
147    /// If the function returns an error, the transaction will be rolled back.
148    /// Otherwise, the transaction will be committed.
149    #[instrument(level = "trace", skip(callback))]
150    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
151    where
152        F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
153        T: Send,
154        E: std::fmt::Display + std::fmt::Debug + Send,
155    {
156        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
157        run_async_transaction_callback(transaction, callback).await
158    }
159
160    /// Execute the function inside a transaction with isolation level and/or access mode.
161    /// If the function returns an error, the transaction will be rolled back.
162    /// Otherwise, the transaction will be committed.
163    #[instrument(level = "trace", skip(callback))]
164    pub async fn transaction_with_config_async<F, T, E>(
165        &self,
166        callback: F,
167        isolation_level: Option<IsolationLevel>,
168        access_mode: Option<AccessMode>,
169    ) -> Result<T, TransactionError<E>>
170    where
171        F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
172        T: Send,
173        E: std::fmt::Display + std::fmt::Debug + Send,
174    {
175        let transaction = self
176            .begin_with_config(isolation_level, access_mode)
177            .await
178            .map_err(TransactionError::Connection)?;
179        run_async_transaction_callback(transaction, callback).await
180    }
181
182    /// Returns `()` if the current user can execute / query the given SQL statement.
183    /// Returns `DbErr::AccessDenied` otherwise.
184    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
185        self.conn.rbac.user_can_run(self.user_id, stmt)
186    }
187
188    /// Returns true if the current user can perform action on resource
189    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
190    where
191        P: Into<PermissionRequest>,
192        R: Into<ResourceRequest>,
193    {
194        self.conn.rbac.user_can(self.user_id, permission, resource)
195    }
196
197    /// Get current user's role and associated permissions.
198    /// This includes permissions "inherited" from child roles.
199    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
200        self.conn.rbac.user_role_permissions(self.user_id)
201    }
202
203    /// Get a list of all roles and their ranks.
204    /// Rank is defined as (1 + number of child roles).
205    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
206        self.conn.rbac.roles_and_ranks()
207    }
208
209    /// Get two lists of all resources and permissions, excluding wildcards.
210    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
211        self.conn.rbac.resources_and_permissions()
212    }
213
214    /// Get a list of edges walking the role hierarchy tree
215    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
216        self.conn.rbac.role_hierarchy_edges(role_id)
217    }
218
219    /// Get a list of permissions for the specific role, grouped by resources.
220    /// This does not include permissions of child roles.
221    pub fn role_permissions_by_resources(
222        &self,
223        role_id: RoleId,
224    ) -> Result<RbacPermissionsByResources, DbErr> {
225        self.conn.rbac.role_permissions_by_resources(role_id)
226    }
227}
228
229impl RestrictedTransaction {
230    /// Get the [`RbacUserId`] bounded to this connection.
231    pub fn user_id(&self) -> UserId {
232        self.user_id
233    }
234
235    /// Execute the function inside a transaction.
236    /// If the function returns an error, the transaction will be rolled back.
237    /// Otherwise, the transaction will be committed.
238    #[instrument(level = "trace", skip(callback))]
239    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
240    where
241        F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
242        T: Send,
243        E: std::fmt::Display + std::fmt::Debug + Send,
244    {
245        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
246        run_async_transaction_callback(transaction, callback).await
247    }
248
249    /// Execute the function inside a transaction with isolation level and/or access mode.
250    /// If the function returns an error, the transaction will be rolled back.
251    /// Otherwise, the transaction will be committed.
252    #[instrument(level = "trace", skip(callback))]
253    pub async fn transaction_with_config_async<F, T, E>(
254        &self,
255        callback: F,
256        isolation_level: Option<IsolationLevel>,
257        access_mode: Option<AccessMode>,
258    ) -> Result<T, TransactionError<E>>
259    where
260        F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
261        T: Send,
262        E: std::fmt::Display + std::fmt::Debug + Send,
263    {
264        let transaction = self
265            .begin_with_config(isolation_level, access_mode)
266            .await
267            .map_err(TransactionError::Connection)?;
268        run_async_transaction_callback(transaction, callback).await
269    }
270
271    /// Returns `()` if the current user can execute / query the given SQL statement.
272    /// Returns `DbErr::AccessDenied` otherwise.
273    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
274        self.rbac.user_can_run(self.user_id, stmt)
275    }
276
277    /// Returns true if the current user can perform action on resource
278    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
279    where
280        P: Into<PermissionRequest>,
281        R: Into<ResourceRequest>,
282    {
283        self.rbac.user_can(self.user_id, permission, resource)
284    }
285}
286
287#[async_trait::async_trait]
288impl TransactionTrait for RestrictedConnection {
289    type Transaction = RestrictedTransaction;
290
291    #[instrument(level = "trace")]
292    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
293        Ok(RestrictedTransaction {
294            user_id: self.user_id,
295            conn: self.conn.begin().await?,
296            rbac: self.conn.rbac.clone(),
297        })
298    }
299
300    #[instrument(level = "trace")]
301    async fn begin_with_config(
302        &self,
303        isolation_level: Option<IsolationLevel>,
304        access_mode: Option<AccessMode>,
305    ) -> Result<RestrictedTransaction, DbErr> {
306        Ok(RestrictedTransaction {
307            user_id: self.user_id,
308            conn: self
309                .conn
310                .begin_with_config(isolation_level, access_mode)
311                .await?,
312            rbac: self.conn.rbac.clone(),
313        })
314    }
315
316    #[instrument(level = "trace")]
317    async fn begin_with_options(
318        &self,
319        options: TransactionOptions,
320    ) -> Result<RestrictedTransaction, DbErr> {
321        Ok(RestrictedTransaction {
322            user_id: self.user_id,
323            conn: self.conn.begin_with_options(options).await?,
324            rbac: self.conn.rbac.clone(),
325        })
326    }
327
328    /// Execute the function inside a transaction.
329    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
330    #[instrument(level = "trace", skip(callback))]
331    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
332    where
333        F: for<'c> FnOnce(
334                &'c RestrictedTransaction,
335            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
336            + Send,
337        T: Send,
338        E: std::fmt::Display + std::fmt::Debug + Send,
339    {
340        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
341        transaction.run(callback).await
342    }
343
344    /// Execute the function inside a transaction.
345    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
346    #[instrument(level = "trace", skip(callback))]
347    async fn transaction_with_config<F, T, E>(
348        &self,
349        callback: F,
350        isolation_level: Option<IsolationLevel>,
351        access_mode: Option<AccessMode>,
352    ) -> Result<T, TransactionError<E>>
353    where
354        F: for<'c> FnOnce(
355                &'c RestrictedTransaction,
356            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
357            + Send,
358        T: Send,
359        E: std::fmt::Display + std::fmt::Debug + Send,
360    {
361        let transaction = self
362            .begin_with_config(isolation_level, access_mode)
363            .await
364            .map_err(TransactionError::Connection)?;
365        transaction.run(callback).await
366    }
367}
368
369#[async_trait::async_trait]
370impl TransactionTrait for RestrictedTransaction {
371    type Transaction = RestrictedTransaction;
372
373    #[instrument(level = "trace")]
374    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
375        Ok(RestrictedTransaction {
376            user_id: self.user_id,
377            conn: self.conn.begin().await?,
378            rbac: self.rbac.clone(),
379        })
380    }
381
382    #[instrument(level = "trace")]
383    async fn begin_with_config(
384        &self,
385        isolation_level: Option<IsolationLevel>,
386        access_mode: Option<AccessMode>,
387    ) -> Result<RestrictedTransaction, DbErr> {
388        Ok(RestrictedTransaction {
389            user_id: self.user_id,
390            conn: self
391                .conn
392                .begin_with_config(isolation_level, access_mode)
393                .await?,
394            rbac: self.rbac.clone(),
395        })
396    }
397
398    #[instrument(level = "trace")]
399    async fn begin_with_options(
400        &self,
401        options: TransactionOptions,
402    ) -> Result<RestrictedTransaction, DbErr> {
403        Ok(RestrictedTransaction {
404            user_id: self.user_id,
405            conn: self.conn.begin_with_options(options).await?,
406            rbac: self.rbac.clone(),
407        })
408    }
409
410    /// Execute the function inside a transaction.
411    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
412    #[instrument(level = "trace", skip(callback))]
413    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
414    where
415        F: for<'c> FnOnce(
416                &'c RestrictedTransaction,
417            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
418            + Send,
419        T: Send,
420        E: std::fmt::Display + std::fmt::Debug + Send,
421    {
422        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
423        transaction.run(callback).await
424    }
425
426    /// Execute the function inside a transaction.
427    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
428    #[instrument(level = "trace", skip(callback))]
429    async fn transaction_with_config<F, T, E>(
430        &self,
431        callback: F,
432        isolation_level: Option<IsolationLevel>,
433        access_mode: Option<AccessMode>,
434    ) -> Result<T, TransactionError<E>>
435    where
436        F: for<'c> FnOnce(
437                &'c RestrictedTransaction,
438            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
439            + Send,
440        T: Send,
441        E: std::fmt::Display + std::fmt::Debug + Send,
442    {
443        let transaction = self
444            .begin_with_config(isolation_level, access_mode)
445            .await
446            .map_err(TransactionError::Connection)?;
447        transaction.run(callback).await
448    }
449}
450
451#[async_trait::async_trait]
452impl TransactionSession for RestrictedTransaction {
453    async fn commit(self) -> Result<(), DbErr> {
454        self.commit().await
455    }
456
457    async fn rollback(self) -> Result<(), DbErr> {
458        self.rollback().await
459    }
460}
461
462impl RestrictedTransaction {
463    /// Runs a transaction to completion passing through the result.
464    /// Rolling back the transaction on encountering an error.
465    #[instrument(level = "trace", skip(callback))]
466    async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
467    where
468        F: for<'b> FnOnce(
469                &'b RestrictedTransaction,
470            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
471            + Send,
472        T: Send,
473        E: std::fmt::Display + std::fmt::Debug + Send,
474    {
475        let res = callback(&self).await.map_err(TransactionError::Transaction);
476        if res.is_ok() {
477            self.commit().await.map_err(TransactionError::Connection)?;
478        } else {
479            self.rollback()
480                .await
481                .map_err(TransactionError::Connection)?;
482        }
483        res
484    }
485
486    /// Commit a transaction
487    #[instrument(level = "trace")]
488    pub async fn commit(self) -> Result<(), DbErr> {
489        self.conn.commit().await
490    }
491
492    /// Rolls back a transaction explicitly
493    #[instrument(level = "trace")]
494    pub async fn rollback(self) -> Result<(), DbErr> {
495        self.conn.rollback().await
496    }
497}
498
499impl RbacEngineMount {
500    pub fn is_some(&self) -> bool {
501        let engine = self.inner.read().expect("RBAC Engine died");
502        engine.is_some()
503    }
504
505    pub fn replace(&self, engine: RbacEngine) {
506        let mut inner = self.inner.write().expect("RBAC Engine died");
507        *inner = Some(engine);
508    }
509
510    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
511    where
512        P: Into<PermissionRequest>,
513        R: Into<ResourceRequest>,
514    {
515        let permission = permission.into();
516        let resource = resource.into();
517        // There is nothing we can do if RwLock is poisoned.
518        let holder = self.inner.read().expect("RBAC Engine died");
519        // Constructor of this struct should ensure engine is not None.
520        let engine = holder.as_ref().expect("RBAC Engine not set");
521        engine
522            .user_can(user_id, permission, resource)
523            .map_err(map_err)
524    }
525
526    pub fn user_can_run<S: StatementBuilder>(
527        &self,
528        user_id: UserId,
529        stmt: &S,
530    ) -> Result<(), DbErr> {
531        let audit = match stmt.audit() {
532            Ok(audit) => audit,
533            Err(err) => return Err(DbErr::RbacError(err.to_string())),
534        };
535        // There is nothing we can do if RwLock is poisoned.
536        let holder = self.inner.read().expect("RBAC Engine died");
537        // Constructor of this struct should ensure engine is not None.
538        let engine = holder.as_ref().expect("RBAC Engine not set");
539        for request in audit.requests {
540            let permission = || PermissionRequest {
541                action: request.access_type.as_str().to_owned(),
542            };
543            let resource = || ResourceRequest {
544                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
545                table: request.schema_table.1.to_string(),
546            };
547            if !engine
548                .user_can(user_id, permission(), resource())
549                .map_err(map_err)?
550            {
551                return Err(DbErr::AccessDenied {
552                    permission: permission().action.to_owned(),
553                    resource: resource().to_string(),
554                });
555            }
556        }
557        Ok(())
558    }
559
560    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
561        let holder = self.inner.read().expect("RBAC Engine died");
562        let engine = holder.as_ref().expect("RBAC Engine not set");
563        engine
564            .get_user_role_permissions(user_id)
565            .map_err(|err| DbErr::RbacError(err.to_string()))
566    }
567
568    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
569        let holder = self.inner.read().expect("RBAC Engine died");
570        let engine = holder.as_ref().expect("RBAC Engine not set");
571        engine
572            .get_roles_and_ranks()
573            .map_err(|err| DbErr::RbacError(err.to_string()))
574    }
575
576    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
577        let holder = self.inner.read().expect("RBAC Engine died");
578        let engine = holder.as_ref().expect("RBAC Engine not set");
579        Ok(engine.list_resources_and_permissions())
580    }
581
582    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
583        let holder = self.inner.read().expect("RBAC Engine died");
584        let engine = holder.as_ref().expect("RBAC Engine not set");
585        Ok(engine.list_role_hierarchy_edges(role_id))
586    }
587
588    pub fn role_permissions_by_resources(
589        &self,
590        role_id: RoleId,
591    ) -> Result<RbacPermissionsByResources, DbErr> {
592        let holder = self.inner.read().expect("RBAC Engine died");
593        let engine = holder.as_ref().expect("RBAC Engine not set");
594        engine
595            .list_role_permissions_by_resources(role_id)
596            .map_err(|err| DbErr::RbacError(err.to_string()))
597    }
598}
599
600fn map_err(err: RbacError) -> DbErr {
601    DbErr::RbacError(err.to_string())
602}