// sea_orm/database/restricted_connection.rs

1use crate::rbac::{
2    PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
3    RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions,
4    ResourceRequest,
5    entity::{role::RoleId, user::UserId},
6};
7use crate::{
8    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
9    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
10    TransactionSession, TransactionTrait,
11};
12use std::{
13    pin::Pin,
14    sync::{Arc, RwLock},
15};
16use tracing::instrument;
17
18/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
19/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
20/// currently.
21#[derive(Debug, Clone)]
22#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
23pub struct RestrictedConnection {
24    pub(crate) user_id: UserId,
25    pub(crate) conn: DatabaseConnection,
26}
27
28/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
29/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
30/// currently.
31#[derive(Debug)]
32pub struct RestrictedTransaction {
33    user_id: UserId,
34    conn: DatabaseTransaction,
35    rbac: RbacEngineMount,
36}
37
38#[derive(Debug, Default, Clone)]
39pub(crate) struct RbacEngineMount {
40    inner: Arc<RwLock<Option<RbacEngine>>>,
41}
42
43#[async_trait::async_trait]
44impl ConnectionTrait for RestrictedConnection {
45    fn get_database_backend(&self) -> DbBackend {
46        self.conn.get_database_backend()
47    }
48
49    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
50        Err(DbErr::RbacError(format!(
51            "Raw query is not supported: {stmt}"
52        )))
53    }
54
55    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
56        self.user_can_run(stmt)?;
57        self.conn.execute(stmt).await
58    }
59
60    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
61        Err(DbErr::RbacError(format!(
62            "Raw query is not supported: {sql}"
63        )))
64    }
65
66    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
67        Err(DbErr::RbacError(format!(
68            "Raw query is not supported: {stmt}"
69        )))
70    }
71
72    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
73        self.user_can_run(stmt)?;
74        self.conn.query_one(stmt).await
75    }
76
77    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
78        Err(DbErr::RbacError(format!(
79            "Raw query is not supported: {stmt}"
80        )))
81    }
82
83    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
84        self.user_can_run(stmt)?;
85        self.conn.query_all(stmt).await
86    }
87}
88
89#[async_trait::async_trait]
90impl ConnectionTrait for RestrictedTransaction {
91    fn get_database_backend(&self) -> DbBackend {
92        self.conn.get_database_backend()
93    }
94
95    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
96        Err(DbErr::RbacError(format!(
97            "Raw query is not supported: {stmt}"
98        )))
99    }
100
101    async fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
102        self.user_can_run(stmt)?;
103        self.conn.execute(stmt).await
104    }
105
106    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
107        Err(DbErr::RbacError(format!(
108            "Raw query is not supported: {sql}"
109        )))
110    }
111
112    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
113        Err(DbErr::RbacError(format!(
114            "Raw query is not supported: {stmt}"
115        )))
116    }
117
118    async fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
119        self.user_can_run(stmt)?;
120        self.conn.query_one(stmt).await
121    }
122
123    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
124        Err(DbErr::RbacError(format!(
125            "Raw query is not supported: {stmt}"
126        )))
127    }
128
129    async fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
130        self.user_can_run(stmt)?;
131        self.conn.query_all(stmt).await
132    }
133}
134
135impl RestrictedConnection {
136    /// Get the [`RbacUserId`] bounded to this connection.
137    pub fn user_id(&self) -> UserId {
138        self.user_id
139    }
140
141    /// Returns `()` if the current user can execute / query the given SQL statement.
142    /// Returns `DbErr` otherwise.
143    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
144        self.conn.rbac.user_can_run(self.user_id, stmt)
145    }
146
147    /// Get current user's role and associated permissions.
148    /// This includes permissions "inherited" from child roles.
149    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
150        self.conn.rbac.user_role_permissions(self.user_id)
151    }
152
153    /// Get a list of all roles and their ranks.
154    /// Rank is defined as (1 + number of child roles).
155    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
156        self.conn.rbac.roles_and_ranks()
157    }
158
159    /// Get two lists of all resources and permissions, excluding wildcards.
160    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
161        self.conn.rbac.resources_and_permissions()
162    }
163
164    /// Get a list of edges walking the role hierarchy tree
165    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
166        self.conn.rbac.role_hierarchy_edges(role_id)
167    }
168
169    /// Get a list of permissions for the specific role, grouped by resources.
170    /// This does not include permissions of child roles.
171    pub fn role_permissions_by_resources(
172        &self,
173        role_id: RoleId,
174    ) -> Result<RbacPermissionsByResources, DbErr> {
175        self.conn.rbac.role_permissions_by_resources(role_id)
176    }
177}
178
179impl RestrictedTransaction {
180    /// Get the [`RbacUserId`] bounded to this connection.
181    pub fn user_id(&self) -> UserId {
182        self.user_id
183    }
184
185    /// Returns `()` if the current user can execute / query the given SQL statement.
186    /// Returns `DbErr` otherwise.
187    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
188        self.rbac.user_can_run(self.user_id, stmt)
189    }
190}
191
192#[async_trait::async_trait]
193impl TransactionTrait for RestrictedConnection {
194    type Transaction = RestrictedTransaction;
195
196    #[instrument(level = "trace")]
197    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
198        Ok(RestrictedTransaction {
199            user_id: self.user_id,
200            conn: self.conn.begin().await?,
201            rbac: self.conn.rbac.clone(),
202        })
203    }
204
205    #[instrument(level = "trace")]
206    async fn begin_with_config(
207        &self,
208        isolation_level: Option<IsolationLevel>,
209        access_mode: Option<AccessMode>,
210    ) -> Result<RestrictedTransaction, DbErr> {
211        Ok(RestrictedTransaction {
212            user_id: self.user_id,
213            conn: self
214                .conn
215                .begin_with_config(isolation_level, access_mode)
216                .await?,
217            rbac: self.conn.rbac.clone(),
218        })
219    }
220
221    /// Execute the function inside a transaction.
222    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
223    #[instrument(level = "trace", skip(callback))]
224    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
225    where
226        F: for<'c> FnOnce(
227                &'c RestrictedTransaction,
228            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
229            + Send,
230        T: Send,
231        E: std::fmt::Display + std::fmt::Debug + Send,
232    {
233        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
234        transaction.run(callback).await
235    }
236
237    /// Execute the function inside a transaction.
238    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
239    #[instrument(level = "trace", skip(callback))]
240    async fn transaction_with_config<F, T, E>(
241        &self,
242        callback: F,
243        isolation_level: Option<IsolationLevel>,
244        access_mode: Option<AccessMode>,
245    ) -> Result<T, TransactionError<E>>
246    where
247        F: for<'c> FnOnce(
248                &'c RestrictedTransaction,
249            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
250            + Send,
251        T: Send,
252        E: std::fmt::Display + std::fmt::Debug + Send,
253    {
254        let transaction = self
255            .begin_with_config(isolation_level, access_mode)
256            .await
257            .map_err(TransactionError::Connection)?;
258        transaction.run(callback).await
259    }
260}
261
262#[async_trait::async_trait]
263impl TransactionTrait for RestrictedTransaction {
264    type Transaction = RestrictedTransaction;
265
266    #[instrument(level = "trace")]
267    async fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
268        Ok(RestrictedTransaction {
269            user_id: self.user_id,
270            conn: self.conn.begin().await?,
271            rbac: self.rbac.clone(),
272        })
273    }
274
275    #[instrument(level = "trace")]
276    async fn begin_with_config(
277        &self,
278        isolation_level: Option<IsolationLevel>,
279        access_mode: Option<AccessMode>,
280    ) -> Result<RestrictedTransaction, DbErr> {
281        Ok(RestrictedTransaction {
282            user_id: self.user_id,
283            conn: self
284                .conn
285                .begin_with_config(isolation_level, access_mode)
286                .await?,
287            rbac: self.rbac.clone(),
288        })
289    }
290
291    /// Execute the function inside a transaction.
292    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
293    #[instrument(level = "trace", skip(callback))]
294    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
295    where
296        F: for<'c> FnOnce(
297                &'c RestrictedTransaction,
298            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
299            + Send,
300        T: Send,
301        E: std::fmt::Display + std::fmt::Debug + Send,
302    {
303        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
304        transaction.run(callback).await
305    }
306
307    /// Execute the function inside a transaction.
308    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
309    #[instrument(level = "trace", skip(callback))]
310    async fn transaction_with_config<F, T, E>(
311        &self,
312        callback: F,
313        isolation_level: Option<IsolationLevel>,
314        access_mode: Option<AccessMode>,
315    ) -> Result<T, TransactionError<E>>
316    where
317        F: for<'c> FnOnce(
318                &'c RestrictedTransaction,
319            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
320            + Send,
321        T: Send,
322        E: std::fmt::Display + std::fmt::Debug + Send,
323    {
324        let transaction = self
325            .begin_with_config(isolation_level, access_mode)
326            .await
327            .map_err(TransactionError::Connection)?;
328        transaction.run(callback).await
329    }
330}
331
332#[async_trait::async_trait]
333impl TransactionSession for RestrictedTransaction {
334    async fn commit(self) -> Result<(), DbErr> {
335        self.commit().await
336    }
337
338    async fn rollback(self) -> Result<(), DbErr> {
339        self.rollback().await
340    }
341}
342
343impl RestrictedTransaction {
344    /// Runs a transaction to completion passing through the result.
345    /// Rolling back the transaction on encountering an error.
346    #[instrument(level = "trace", skip(callback))]
347    async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
348    where
349        F: for<'b> FnOnce(
350                &'b RestrictedTransaction,
351            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
352            + Send,
353        T: Send,
354        E: std::fmt::Display + std::fmt::Debug + Send,
355    {
356        let res = callback(&self).await.map_err(TransactionError::Transaction);
357        if res.is_ok() {
358            self.commit().await.map_err(TransactionError::Connection)?;
359        } else {
360            self.rollback()
361                .await
362                .map_err(TransactionError::Connection)?;
363        }
364        res
365    }
366
367    /// Commit a transaction
368    #[instrument(level = "trace")]
369    pub async fn commit(self) -> Result<(), DbErr> {
370        self.conn.commit().await
371    }
372
373    /// Rolls back a transaction explicitly
374    #[instrument(level = "trace")]
375    pub async fn rollback(self) -> Result<(), DbErr> {
376        self.conn.rollback().await
377    }
378}
379
380impl RbacEngineMount {
381    pub fn is_some(&self) -> bool {
382        let engine = self.inner.read().expect("RBAC Engine died");
383        engine.is_some()
384    }
385
386    pub fn replace(&self, engine: RbacEngine) {
387        let mut inner = self.inner.write().expect("RBAC Engine died");
388        *inner = Some(engine);
389    }
390
391    pub fn user_can_run<S: StatementBuilder>(
392        &self,
393        user_id: UserId,
394        stmt: &S,
395    ) -> Result<(), DbErr> {
396        let audit = match stmt.audit() {
397            Ok(audit) => audit,
398            Err(err) => return Err(DbErr::RbacError(err.to_string())),
399        };
400        for request in audit.requests {
401            // There is nothing we can do if RwLock is poisoned.
402            let holder = self.inner.read().expect("RBAC Engine died");
403            // Constructor of this struct should ensure engine is not None.
404            let engine = holder.as_ref().expect("RBAC Engine not set");
405            let permission = || PermissionRequest {
406                action: request.access_type.as_str().to_owned(),
407            };
408            let resource = || ResourceRequest {
409                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
410                table: request.schema_table.1.to_string(),
411            };
412            if !engine
413                .user_can(user_id, permission(), resource())
414                .map_err(map_err)?
415            {
416                let r = resource();
417                return Err(DbErr::AccessDenied {
418                    permission: permission().action.to_owned(),
419                    resource: format!(
420                        "{}{}{}",
421                        if let Some(schema) = &r.schema {
422                            schema
423                        } else {
424                            ""
425                        },
426                        if r.schema.is_some() { "." } else { "" },
427                        r.table
428                    ),
429                });
430            }
431        }
432        Ok(())
433    }
434
435    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
436        let holder = self.inner.read().expect("RBAC Engine died");
437        let engine = holder.as_ref().expect("RBAC Engine not set");
438        engine
439            .get_user_role_permissions(user_id)
440            .map_err(|err| DbErr::RbacError(err.to_string()))
441    }
442
443    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
444        let holder = self.inner.read().expect("RBAC Engine died");
445        let engine = holder.as_ref().expect("RBAC Engine not set");
446        engine
447            .get_roles_and_ranks()
448            .map_err(|err| DbErr::RbacError(err.to_string()))
449    }
450
451    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
452        let holder = self.inner.read().expect("RBAC Engine died");
453        let engine = holder.as_ref().expect("RBAC Engine not set");
454        Ok(engine.list_resources_and_permissions())
455    }
456
457    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
458        let holder = self.inner.read().expect("RBAC Engine died");
459        let engine = holder.as_ref().expect("RBAC Engine not set");
460        Ok(engine.list_role_hierarchy_edges(role_id))
461    }
462
463    pub fn role_permissions_by_resources(
464        &self,
465        role_id: RoleId,
466    ) -> Result<RbacPermissionsByResources, DbErr> {
467        let holder = self.inner.read().expect("RBAC Engine died");
468        let engine = holder.as_ref().expect("RBAC Engine not set");
469        engine
470            .list_role_permissions_by_resources(role_id)
471            .map_err(|err| DbErr::RbacError(err.to_string()))
472    }
473}
474
475fn map_err(err: RbacError) -> DbErr {
476    DbErr::RbacError(err.to_string())
477}