sea_orm/database/restricted_connection.rs

1use crate::rbac::{
2    PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
3    RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions,
4    ResourceRequest,
5    entity::{role::RoleId, user::UserId},
6};
7use crate::{
8    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
9    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
10    TransactionSession, TransactionTrait,
11};
12use std::{
13    pin::Pin,
14    sync::{Arc, RwLock},
15};
16use tracing::instrument;
17
18/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
19/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
20/// currently.
21#[derive(Debug, Clone)]
22#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
23pub struct RestrictedConnection {
24    pub(crate) user_id: UserId,
25    pub(crate) conn: DatabaseConnection,
26}
27
28/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
29/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
30/// currently.
31#[derive(Debug)]
32pub struct RestrictedTransaction {
33    user_id: UserId,
34    conn: DatabaseTransaction,
35    rbac: RbacEngineMount,
36}
37
38#[derive(Debug, Default, Clone)]
39pub(crate) struct RbacEngineMount {
40    inner: Arc<RwLock<Option<RbacEngine>>>,
41}
42
43impl ConnectionTrait for RestrictedConnection {
44    fn get_database_backend(&self) -> DbBackend {
45        self.conn.get_database_backend()
46    }
47
48    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
49        Err(DbErr::RbacError(format!(
50            "Raw query is not supported: {stmt}"
51        )))
52    }
53
54    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
55        self.user_can_run(stmt)?;
56        self.conn.execute(stmt)
57    }
58
59    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
60        Err(DbErr::RbacError(format!(
61            "Raw query is not supported: {sql}"
62        )))
63    }
64
65    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
66        Err(DbErr::RbacError(format!(
67            "Raw query is not supported: {stmt}"
68        )))
69    }
70
71    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
72        self.user_can_run(stmt)?;
73        self.conn.query_one(stmt)
74    }
75
76    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
77        Err(DbErr::RbacError(format!(
78            "Raw query is not supported: {stmt}"
79        )))
80    }
81
82    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
83        self.user_can_run(stmt)?;
84        self.conn.query_all(stmt)
85    }
86}
87
88impl ConnectionTrait for RestrictedTransaction {
89    fn get_database_backend(&self) -> DbBackend {
90        self.conn.get_database_backend()
91    }
92
93    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
94        Err(DbErr::RbacError(format!(
95            "Raw query is not supported: {stmt}"
96        )))
97    }
98
99    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
100        self.user_can_run(stmt)?;
101        self.conn.execute(stmt)
102    }
103
104    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
105        Err(DbErr::RbacError(format!(
106            "Raw query is not supported: {sql}"
107        )))
108    }
109
110    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
111        Err(DbErr::RbacError(format!(
112            "Raw query is not supported: {stmt}"
113        )))
114    }
115
116    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
117        self.user_can_run(stmt)?;
118        self.conn.query_one(stmt)
119    }
120
121    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
122        Err(DbErr::RbacError(format!(
123            "Raw query is not supported: {stmt}"
124        )))
125    }
126
127    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
128        self.user_can_run(stmt)?;
129        self.conn.query_all(stmt)
130    }
131}
132
133impl RestrictedConnection {
134    /// Get the [`RbacUserId`] bounded to this connection.
135    pub fn user_id(&self) -> UserId {
136        self.user_id
137    }
138
139    /// Returns `()` if the current user can execute / query the given SQL statement.
140    /// Returns `DbErr::AccessDenied` otherwise.
141    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
142        self.conn.rbac.user_can_run(self.user_id, stmt)
143    }
144
145    /// Returns true if the current user can perform action on resource
146    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
147    where
148        P: Into<PermissionRequest>,
149        R: Into<ResourceRequest>,
150    {
151        self.conn.rbac.user_can(self.user_id, permission, resource)
152    }
153
154    /// Get current user's role and associated permissions.
155    /// This includes permissions "inherited" from child roles.
156    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
157        self.conn.rbac.user_role_permissions(self.user_id)
158    }
159
160    /// Get a list of all roles and their ranks.
161    /// Rank is defined as (1 + number of child roles).
162    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
163        self.conn.rbac.roles_and_ranks()
164    }
165
166    /// Get two lists of all resources and permissions, excluding wildcards.
167    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
168        self.conn.rbac.resources_and_permissions()
169    }
170
171    /// Get a list of edges walking the role hierarchy tree
172    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
173        self.conn.rbac.role_hierarchy_edges(role_id)
174    }
175
176    /// Get a list of permissions for the specific role, grouped by resources.
177    /// This does not include permissions of child roles.
178    pub fn role_permissions_by_resources(
179        &self,
180        role_id: RoleId,
181    ) -> Result<RbacPermissionsByResources, DbErr> {
182        self.conn.rbac.role_permissions_by_resources(role_id)
183    }
184}
185
186impl RestrictedTransaction {
187    /// Get the [`RbacUserId`] bounded to this connection.
188    pub fn user_id(&self) -> UserId {
189        self.user_id
190    }
191
192    /// Returns `()` if the current user can execute / query the given SQL statement.
193    /// Returns `DbErr::AccessDenied` otherwise.
194    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
195        self.rbac.user_can_run(self.user_id, stmt)
196    }
197
198    /// Returns true if the current user can perform action on resource
199    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
200    where
201        P: Into<PermissionRequest>,
202        R: Into<ResourceRequest>,
203    {
204        self.rbac.user_can(self.user_id, permission, resource)
205    }
206}
207
208impl TransactionTrait for RestrictedConnection {
209    type Transaction = RestrictedTransaction;
210
211    #[instrument(level = "trace")]
212    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
213        Ok(RestrictedTransaction {
214            user_id: self.user_id,
215            conn: self.conn.begin()?,
216            rbac: self.conn.rbac.clone(),
217        })
218    }
219
220    #[instrument(level = "trace")]
221    fn begin_with_config(
222        &self,
223        isolation_level: Option<IsolationLevel>,
224        access_mode: Option<AccessMode>,
225    ) -> Result<RestrictedTransaction, DbErr> {
226        Ok(RestrictedTransaction {
227            user_id: self.user_id,
228            conn: self.conn.begin_with_config(isolation_level, access_mode)?,
229            rbac: self.conn.rbac.clone(),
230        })
231    }
232
233    /// Execute the function inside a transaction.
234    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
235    #[instrument(level = "trace", skip(callback))]
236    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
237    where
238        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
239        E: std::fmt::Display + std::fmt::Debug,
240    {
241        let transaction = self.begin().map_err(TransactionError::Connection)?;
242        transaction.run(callback)
243    }
244
245    /// Execute the function inside a transaction.
246    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
247    #[instrument(level = "trace", skip(callback))]
248    fn transaction_with_config<F, T, E>(
249        &self,
250        callback: F,
251        isolation_level: Option<IsolationLevel>,
252        access_mode: Option<AccessMode>,
253    ) -> Result<T, TransactionError<E>>
254    where
255        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
256        E: std::fmt::Display + std::fmt::Debug,
257    {
258        let transaction = self
259            .begin_with_config(isolation_level, access_mode)
260            .map_err(TransactionError::Connection)?;
261        transaction.run(callback)
262    }
263}
264
265impl TransactionTrait for RestrictedTransaction {
266    type Transaction = RestrictedTransaction;
267
268    #[instrument(level = "trace")]
269    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
270        Ok(RestrictedTransaction {
271            user_id: self.user_id,
272            conn: self.conn.begin()?,
273            rbac: self.rbac.clone(),
274        })
275    }
276
277    #[instrument(level = "trace")]
278    fn begin_with_config(
279        &self,
280        isolation_level: Option<IsolationLevel>,
281        access_mode: Option<AccessMode>,
282    ) -> Result<RestrictedTransaction, DbErr> {
283        Ok(RestrictedTransaction {
284            user_id: self.user_id,
285            conn: self.conn.begin_with_config(isolation_level, access_mode)?,
286            rbac: self.rbac.clone(),
287        })
288    }
289
290    /// Execute the function inside a transaction.
291    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
292    #[instrument(level = "trace", skip(callback))]
293    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
294    where
295        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
296        E: std::fmt::Display + std::fmt::Debug,
297    {
298        let transaction = self.begin().map_err(TransactionError::Connection)?;
299        transaction.run(callback)
300    }
301
302    /// Execute the function inside a transaction.
303    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
304    #[instrument(level = "trace", skip(callback))]
305    fn transaction_with_config<F, T, E>(
306        &self,
307        callback: F,
308        isolation_level: Option<IsolationLevel>,
309        access_mode: Option<AccessMode>,
310    ) -> Result<T, TransactionError<E>>
311    where
312        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
313        E: std::fmt::Display + std::fmt::Debug,
314    {
315        let transaction = self
316            .begin_with_config(isolation_level, access_mode)
317            .map_err(TransactionError::Connection)?;
318        transaction.run(callback)
319    }
320}
321
322impl TransactionSession for RestrictedTransaction {
323    fn commit(self) -> Result<(), DbErr> {
324        self.commit()
325    }
326
327    fn rollback(self) -> Result<(), DbErr> {
328        self.rollback()
329    }
330}
331
332impl RestrictedTransaction {
333    /// Runs a transaction to completion passing through the result.
334    /// Rolling back the transaction on encountering an error.
335    #[instrument(level = "trace", skip(callback))]
336    fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
337    where
338        F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
339        E: std::fmt::Display + std::fmt::Debug,
340    {
341        let res = callback(&self).map_err(TransactionError::Transaction);
342        if res.is_ok() {
343            self.commit().map_err(TransactionError::Connection)?;
344        } else {
345            self.rollback().map_err(TransactionError::Connection)?;
346        }
347        res
348    }
349
350    /// Commit a transaction
351    #[instrument(level = "trace")]
352    pub fn commit(self) -> Result<(), DbErr> {
353        self.conn.commit()
354    }
355
356    /// Rolls back a transaction explicitly
357    #[instrument(level = "trace")]
358    pub fn rollback(self) -> Result<(), DbErr> {
359        self.conn.rollback()
360    }
361}
362
363impl RbacEngineMount {
364    pub fn is_some(&self) -> bool {
365        let engine = self.inner.read().expect("RBAC Engine died");
366        engine.is_some()
367    }
368
369    pub fn replace(&self, engine: RbacEngine) {
370        let mut inner = self.inner.write().expect("RBAC Engine died");
371        *inner = Some(engine);
372    }
373
374    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
375    where
376        P: Into<PermissionRequest>,
377        R: Into<ResourceRequest>,
378    {
379        let permission = permission.into();
380        let resource = resource.into();
381        // There is nothing we can do if RwLock is poisoned.
382        let holder = self.inner.read().expect("RBAC Engine died");
383        // Constructor of this struct should ensure engine is not None.
384        let engine = holder.as_ref().expect("RBAC Engine not set");
385        engine
386            .user_can(user_id, permission, resource)
387            .map_err(map_err)
388    }
389
390    pub fn user_can_run<S: StatementBuilder>(
391        &self,
392        user_id: UserId,
393        stmt: &S,
394    ) -> Result<(), DbErr> {
395        let audit = match stmt.audit() {
396            Ok(audit) => audit,
397            Err(err) => return Err(DbErr::RbacError(err.to_string())),
398        };
399        // There is nothing we can do if RwLock is poisoned.
400        let holder = self.inner.read().expect("RBAC Engine died");
401        // Constructor of this struct should ensure engine is not None.
402        let engine = holder.as_ref().expect("RBAC Engine not set");
403        for request in audit.requests {
404            let permission = || PermissionRequest {
405                action: request.access_type.as_str().to_owned(),
406            };
407            let resource = || ResourceRequest {
408                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
409                table: request.schema_table.1.to_string(),
410            };
411            if !engine
412                .user_can(user_id, permission(), resource())
413                .map_err(map_err)?
414            {
415                return Err(DbErr::AccessDenied {
416                    permission: permission().action.to_owned(),
417                    resource: resource().to_string(),
418                });
419            }
420        }
421        Ok(())
422    }
423
424    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
425        let holder = self.inner.read().expect("RBAC Engine died");
426        let engine = holder.as_ref().expect("RBAC Engine not set");
427        engine
428            .get_user_role_permissions(user_id)
429            .map_err(|err| DbErr::RbacError(err.to_string()))
430    }
431
432    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
433        let holder = self.inner.read().expect("RBAC Engine died");
434        let engine = holder.as_ref().expect("RBAC Engine not set");
435        engine
436            .get_roles_and_ranks()
437            .map_err(|err| DbErr::RbacError(err.to_string()))
438    }
439
440    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
441        let holder = self.inner.read().expect("RBAC Engine died");
442        let engine = holder.as_ref().expect("RBAC Engine not set");
443        Ok(engine.list_resources_and_permissions())
444    }
445
446    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
447        let holder = self.inner.read().expect("RBAC Engine died");
448        let engine = holder.as_ref().expect("RBAC Engine not set");
449        Ok(engine.list_role_hierarchy_edges(role_id))
450    }
451
452    pub fn role_permissions_by_resources(
453        &self,
454        role_id: RoleId,
455    ) -> Result<RbacPermissionsByResources, DbErr> {
456        let holder = self.inner.read().expect("RBAC Engine died");
457        let engine = holder.as_ref().expect("RBAC Engine not set");
458        engine
459            .list_role_permissions_by_resources(role_id)
460            .map_err(|err| DbErr::RbacError(err.to_string()))
461    }
462}
463
464fn map_err(err: RbacError) -> DbErr {
465    DbErr::RbacError(err.to_string())
466}