Skip to main content

sea_orm/database/
restricted_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
3    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
4    TransactionSession, TransactionTrait,
5};
6use crate::{
7    TransactionOptions,
8    rbac::{
9        PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
10        RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
11        RbacUserRolePermissions, ResourceRequest,
12        entity::{role::RoleId, user::UserId},
13    },
14};
15use std::{
16    pin::Pin,
17    sync::{Arc, RwLock},
18};
19use tracing::instrument;
20
21/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
22/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
23/// currently.
24#[derive(Debug, Clone)]
25#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
26pub struct RestrictedConnection {
27    pub(crate) user_id: UserId,
28    pub(crate) conn: DatabaseConnection,
29}
30
31/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
32/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
33/// currently.
34#[derive(Debug)]
35pub struct RestrictedTransaction {
36    user_id: UserId,
37    conn: DatabaseTransaction,
38    rbac: RbacEngineMount,
39}
40
41#[derive(Debug, Default, Clone)]
42pub(crate) struct RbacEngineMount {
43    inner: Arc<RwLock<Option<RbacEngine>>>,
44}
45
46impl ConnectionTrait for RestrictedConnection {
47    fn get_database_backend(&self) -> DbBackend {
48        self.conn.get_database_backend()
49    }
50
51    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
52        Err(DbErr::RbacError(format!(
53            "Raw query is not supported: {stmt}"
54        )))
55    }
56
57    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
58        self.user_can_run(stmt)?;
59        self.conn.execute(stmt)
60    }
61
62    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
63        Err(DbErr::RbacError(format!(
64            "Raw query is not supported: {sql}"
65        )))
66    }
67
68    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
69        Err(DbErr::RbacError(format!(
70            "Raw query is not supported: {stmt}"
71        )))
72    }
73
74    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
75        self.user_can_run(stmt)?;
76        self.conn.query_one(stmt)
77    }
78
79    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
80        Err(DbErr::RbacError(format!(
81            "Raw query is not supported: {stmt}"
82        )))
83    }
84
85    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
86        self.user_can_run(stmt)?;
87        self.conn.query_all(stmt)
88    }
89}
90
91impl ConnectionTrait for RestrictedTransaction {
92    fn get_database_backend(&self) -> DbBackend {
93        self.conn.get_database_backend()
94    }
95
96    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
97        Err(DbErr::RbacError(format!(
98            "Raw query is not supported: {stmt}"
99        )))
100    }
101
102    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
103        self.user_can_run(stmt)?;
104        self.conn.execute(stmt)
105    }
106
107    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
108        Err(DbErr::RbacError(format!(
109            "Raw query is not supported: {sql}"
110        )))
111    }
112
113    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
114        Err(DbErr::RbacError(format!(
115            "Raw query is not supported: {stmt}"
116        )))
117    }
118
119    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
120        self.user_can_run(stmt)?;
121        self.conn.query_one(stmt)
122    }
123
124    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
125        Err(DbErr::RbacError(format!(
126            "Raw query is not supported: {stmt}"
127        )))
128    }
129
130    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
131        self.user_can_run(stmt)?;
132        self.conn.query_all(stmt)
133    }
134}
135
136impl RestrictedConnection {
137    /// Get the [`RbacUserId`] bounded to this connection.
138    pub fn user_id(&self) -> UserId {
139        self.user_id
140    }
141
142    /// Returns `()` if the current user can execute / query the given SQL statement.
143    /// Returns `DbErr::AccessDenied` otherwise.
144    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
145        self.conn.rbac.user_can_run(self.user_id, stmt)
146    }
147
148    /// Returns true if the current user can perform action on resource
149    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
150    where
151        P: Into<PermissionRequest>,
152        R: Into<ResourceRequest>,
153    {
154        self.conn.rbac.user_can(self.user_id, permission, resource)
155    }
156
157    /// Get current user's role and associated permissions.
158    /// This includes permissions "inherited" from child roles.
159    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
160        self.conn.rbac.user_role_permissions(self.user_id)
161    }
162
163    /// Get a list of all roles and their ranks.
164    /// Rank is defined as (1 + number of child roles).
165    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
166        self.conn.rbac.roles_and_ranks()
167    }
168
169    /// Get two lists of all resources and permissions, excluding wildcards.
170    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
171        self.conn.rbac.resources_and_permissions()
172    }
173
174    /// Get a list of edges walking the role hierarchy tree
175    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
176        self.conn.rbac.role_hierarchy_edges(role_id)
177    }
178
179    /// Get a list of permissions for the specific role, grouped by resources.
180    /// This does not include permissions of child roles.
181    pub fn role_permissions_by_resources(
182        &self,
183        role_id: RoleId,
184    ) -> Result<RbacPermissionsByResources, DbErr> {
185        self.conn.rbac.role_permissions_by_resources(role_id)
186    }
187}
188
189impl RestrictedTransaction {
190    /// Get the [`RbacUserId`] bounded to this connection.
191    pub fn user_id(&self) -> UserId {
192        self.user_id
193    }
194
195    /// Returns `()` if the current user can execute / query the given SQL statement.
196    /// Returns `DbErr::AccessDenied` otherwise.
197    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
198        self.rbac.user_can_run(self.user_id, stmt)
199    }
200
201    /// Returns true if the current user can perform action on resource
202    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
203    where
204        P: Into<PermissionRequest>,
205        R: Into<ResourceRequest>,
206    {
207        self.rbac.user_can(self.user_id, permission, resource)
208    }
209}
210
211impl TransactionTrait for RestrictedConnection {
212    type Transaction = RestrictedTransaction;
213
214    #[instrument(level = "trace")]
215    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
216        Ok(RestrictedTransaction {
217            user_id: self.user_id,
218            conn: self.conn.begin()?,
219            rbac: self.conn.rbac.clone(),
220        })
221    }
222
223    #[instrument(level = "trace")]
224    fn begin_with_config(
225        &self,
226        isolation_level: Option<IsolationLevel>,
227        access_mode: Option<AccessMode>,
228    ) -> Result<RestrictedTransaction, DbErr> {
229        Ok(RestrictedTransaction {
230            user_id: self.user_id,
231            conn: self.conn.begin_with_config(isolation_level, access_mode)?,
232            rbac: self.conn.rbac.clone(),
233        })
234    }
235
236    #[instrument(level = "trace")]
237    fn begin_with_options(
238        &self,
239        options: TransactionOptions,
240    ) -> Result<RestrictedTransaction, DbErr> {
241        Ok(RestrictedTransaction {
242            user_id: self.user_id,
243            conn: self.conn.begin_with_options(options)?,
244            rbac: self.conn.rbac.clone(),
245        })
246    }
247
248    /// Execute the function inside a transaction.
249    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
250    #[instrument(level = "trace", skip(callback))]
251    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
252    where
253        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
254        E: std::fmt::Display + std::fmt::Debug,
255    {
256        let transaction = self.begin().map_err(TransactionError::Connection)?;
257        transaction.run(callback)
258    }
259
260    /// Execute the function inside a transaction.
261    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
262    #[instrument(level = "trace", skip(callback))]
263    fn transaction_with_config<F, T, E>(
264        &self,
265        callback: F,
266        isolation_level: Option<IsolationLevel>,
267        access_mode: Option<AccessMode>,
268    ) -> Result<T, TransactionError<E>>
269    where
270        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
271        E: std::fmt::Display + std::fmt::Debug,
272    {
273        let transaction = self
274            .begin_with_config(isolation_level, access_mode)
275            .map_err(TransactionError::Connection)?;
276        transaction.run(callback)
277    }
278}
279
280impl TransactionTrait for RestrictedTransaction {
281    type Transaction = RestrictedTransaction;
282
283    #[instrument(level = "trace")]
284    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
285        Ok(RestrictedTransaction {
286            user_id: self.user_id,
287            conn: self.conn.begin()?,
288            rbac: self.rbac.clone(),
289        })
290    }
291
292    #[instrument(level = "trace")]
293    fn begin_with_config(
294        &self,
295        isolation_level: Option<IsolationLevel>,
296        access_mode: Option<AccessMode>,
297    ) -> Result<RestrictedTransaction, DbErr> {
298        Ok(RestrictedTransaction {
299            user_id: self.user_id,
300            conn: self.conn.begin_with_config(isolation_level, access_mode)?,
301            rbac: self.rbac.clone(),
302        })
303    }
304
305    #[instrument(level = "trace")]
306    fn begin_with_options(
307        &self,
308        options: TransactionOptions,
309    ) -> Result<RestrictedTransaction, DbErr> {
310        Ok(RestrictedTransaction {
311            user_id: self.user_id,
312            conn: self.conn.begin_with_options(options)?,
313            rbac: self.rbac.clone(),
314        })
315    }
316
317    /// Execute the function inside a transaction.
318    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
319    #[instrument(level = "trace", skip(callback))]
320    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
321    where
322        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
323        E: std::fmt::Display + std::fmt::Debug,
324    {
325        let transaction = self.begin().map_err(TransactionError::Connection)?;
326        transaction.run(callback)
327    }
328
329    /// Execute the function inside a transaction.
330    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
331    #[instrument(level = "trace", skip(callback))]
332    fn transaction_with_config<F, T, E>(
333        &self,
334        callback: F,
335        isolation_level: Option<IsolationLevel>,
336        access_mode: Option<AccessMode>,
337    ) -> Result<T, TransactionError<E>>
338    where
339        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
340        E: std::fmt::Display + std::fmt::Debug,
341    {
342        let transaction = self
343            .begin_with_config(isolation_level, access_mode)
344            .map_err(TransactionError::Connection)?;
345        transaction.run(callback)
346    }
347}
348
349impl TransactionSession for RestrictedTransaction {
350    fn commit(self) -> Result<(), DbErr> {
351        self.commit()
352    }
353
354    fn rollback(self) -> Result<(), DbErr> {
355        self.rollback()
356    }
357}
358
359impl RestrictedTransaction {
360    /// Runs a transaction to completion passing through the result.
361    /// Rolling back the transaction on encountering an error.
362    #[instrument(level = "trace", skip(callback))]
363    fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
364    where
365        F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
366        E: std::fmt::Display + std::fmt::Debug,
367    {
368        let res = callback(&self).map_err(TransactionError::Transaction);
369        if res.is_ok() {
370            self.commit().map_err(TransactionError::Connection)?;
371        } else {
372            self.rollback().map_err(TransactionError::Connection)?;
373        }
374        res
375    }
376
377    /// Commit a transaction
378    #[instrument(level = "trace")]
379    pub fn commit(self) -> Result<(), DbErr> {
380        self.conn.commit()
381    }
382
383    /// Rolls back a transaction explicitly
384    #[instrument(level = "trace")]
385    pub fn rollback(self) -> Result<(), DbErr> {
386        self.conn.rollback()
387    }
388}
389
390impl RbacEngineMount {
391    pub fn is_some(&self) -> bool {
392        let engine = self.inner.read().expect("RBAC Engine died");
393        engine.is_some()
394    }
395
396    pub fn replace(&self, engine: RbacEngine) {
397        let mut inner = self.inner.write().expect("RBAC Engine died");
398        *inner = Some(engine);
399    }
400
401    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
402    where
403        P: Into<PermissionRequest>,
404        R: Into<ResourceRequest>,
405    {
406        let permission = permission.into();
407        let resource = resource.into();
408        // There is nothing we can do if RwLock is poisoned.
409        let holder = self.inner.read().expect("RBAC Engine died");
410        // Constructor of this struct should ensure engine is not None.
411        let engine = holder.as_ref().expect("RBAC Engine not set");
412        engine
413            .user_can(user_id, permission, resource)
414            .map_err(map_err)
415    }
416
417    pub fn user_can_run<S: StatementBuilder>(
418        &self,
419        user_id: UserId,
420        stmt: &S,
421    ) -> Result<(), DbErr> {
422        let audit = match stmt.audit() {
423            Ok(audit) => audit,
424            Err(err) => return Err(DbErr::RbacError(err.to_string())),
425        };
426        // There is nothing we can do if RwLock is poisoned.
427        let holder = self.inner.read().expect("RBAC Engine died");
428        // Constructor of this struct should ensure engine is not None.
429        let engine = holder.as_ref().expect("RBAC Engine not set");
430        for request in audit.requests {
431            let permission = || PermissionRequest {
432                action: request.access_type.as_str().to_owned(),
433            };
434            let resource = || ResourceRequest {
435                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
436                table: request.schema_table.1.to_string(),
437            };
438            if !engine
439                .user_can(user_id, permission(), resource())
440                .map_err(map_err)?
441            {
442                return Err(DbErr::AccessDenied {
443                    permission: permission().action.to_owned(),
444                    resource: resource().to_string(),
445                });
446            }
447        }
448        Ok(())
449    }
450
451    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
452        let holder = self.inner.read().expect("RBAC Engine died");
453        let engine = holder.as_ref().expect("RBAC Engine not set");
454        engine
455            .get_user_role_permissions(user_id)
456            .map_err(|err| DbErr::RbacError(err.to_string()))
457    }
458
459    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
460        let holder = self.inner.read().expect("RBAC Engine died");
461        let engine = holder.as_ref().expect("RBAC Engine not set");
462        engine
463            .get_roles_and_ranks()
464            .map_err(|err| DbErr::RbacError(err.to_string()))
465    }
466
467    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
468        let holder = self.inner.read().expect("RBAC Engine died");
469        let engine = holder.as_ref().expect("RBAC Engine not set");
470        Ok(engine.list_resources_and_permissions())
471    }
472
473    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
474        let holder = self.inner.read().expect("RBAC Engine died");
475        let engine = holder.as_ref().expect("RBAC Engine not set");
476        Ok(engine.list_role_hierarchy_edges(role_id))
477    }
478
479    pub fn role_permissions_by_resources(
480        &self,
481        role_id: RoleId,
482    ) -> Result<RbacPermissionsByResources, DbErr> {
483        let holder = self.inner.read().expect("RBAC Engine died");
484        let engine = holder.as_ref().expect("RBAC Engine not set");
485        engine
486            .list_role_permissions_by_resources(role_id)
487            .map_err(|err| DbErr::RbacError(err.to_string()))
488    }
489}
490
491fn map_err(err: RbacError) -> DbErr {
492    DbErr::RbacError(err.to_string())
493}