Skip to main content

sea_orm/database/
restricted_connection.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
4    ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
5    TransactionSession, TransactionTrait,
6};
7use crate::{
8    TransactionOptions,
9    rbac::{
10        PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
11        RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
12        RbacUserRolePermissions, ResourceRequest,
13        entity::{role::RoleId, user::UserId},
14    },
15};
16use std::{
17    future::Future,
18    pin::Pin,
19    sync::{Arc, RwLock},
20};
21use tracing::instrument;
22
23/// Wrapper of [`DatabaseConnection`] that performs authorization on all executed
24/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
25/// currently.
26#[derive(Debug, Clone)]
27#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
28pub struct RestrictedConnection {
29    pub(crate) user_id: UserId,
30    pub(crate) conn: DatabaseConnection,
31}
32
33/// Wrapper of [`DatabaseTransaction`] that performs authorization on all executed
34/// queries for the current user. Note that raw SQL [`Statement`] is not allowed
35/// currently.
36#[derive(Debug)]
37pub struct RestrictedTransaction {
38    user_id: UserId,
39    conn: DatabaseTransaction,
40    rbac: RbacEngineMount,
41}
42
43#[derive(Debug, Default, Clone)]
44pub(crate) struct RbacEngineMount {
45    inner: Arc<RwLock<Option<RbacEngine>>>,
46}
47
48impl ConnectionTrait for RestrictedConnection {
49    fn get_database_backend(&self) -> DbBackend {
50        self.conn.get_database_backend()
51    }
52
53    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
54        Err(DbErr::RbacError(format!(
55            "Raw query is not supported: {stmt}"
56        )))
57    }
58
59    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
60        self.user_can_run(stmt)?;
61        self.conn.execute(stmt)
62    }
63
64    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
65        Err(DbErr::RbacError(format!(
66            "Raw query is not supported: {sql}"
67        )))
68    }
69
70    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
71        Err(DbErr::RbacError(format!(
72            "Raw query is not supported: {stmt}"
73        )))
74    }
75
76    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
77        self.user_can_run(stmt)?;
78        self.conn.query_one(stmt)
79    }
80
81    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
82        Err(DbErr::RbacError(format!(
83            "Raw query is not supported: {stmt}"
84        )))
85    }
86
87    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
88        self.user_can_run(stmt)?;
89        self.conn.query_all(stmt)
90    }
91}
92
93impl ConnectionTrait for RestrictedTransaction {
94    fn get_database_backend(&self) -> DbBackend {
95        self.conn.get_database_backend()
96    }
97
98    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
99        Err(DbErr::RbacError(format!(
100            "Raw query is not supported: {stmt}"
101        )))
102    }
103
104    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
105        self.user_can_run(stmt)?;
106        self.conn.execute(stmt)
107    }
108
109    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
110        Err(DbErr::RbacError(format!(
111            "Raw query is not supported: {sql}"
112        )))
113    }
114
115    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
116        Err(DbErr::RbacError(format!(
117            "Raw query is not supported: {stmt}"
118        )))
119    }
120
121    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
122        self.user_can_run(stmt)?;
123        self.conn.query_one(stmt)
124    }
125
126    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
127        Err(DbErr::RbacError(format!(
128            "Raw query is not supported: {stmt}"
129        )))
130    }
131
132    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
133        self.user_can_run(stmt)?;
134        self.conn.query_all(stmt)
135    }
136}
137
138impl RestrictedConnection {
139    /// Get the [`RbacUserId`] bounded to this connection.
140    pub fn user_id(&self) -> UserId {
141        self.user_id
142    }
143
144    /// Execute the function inside a transaction.
145    /// If the function returns an error, the transaction will be rolled back.
146    /// Otherwise, the transaction will be committed.
147    #[instrument(level = "trace", skip(callback))]
148    pub fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
149    where
150        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
151        E: std::fmt::Display + std::fmt::Debug,
152    {
153        let transaction = self.begin().map_err(TransactionError::Connection)?;
154        run_async_transaction_callback(transaction, callback)
155    }
156
157    /// Execute the function inside a transaction with isolation level and/or access mode.
158    /// If the function returns an error, the transaction will be rolled back.
159    /// Otherwise, the transaction will be committed.
160    #[instrument(level = "trace", skip(callback))]
161    pub fn transaction_with_config<F, T, E>(
162        &self,
163        callback: F,
164        isolation_level: Option<IsolationLevel>,
165        access_mode: Option<AccessMode>,
166    ) -> Result<T, TransactionError<E>>
167    where
168        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
169        E: std::fmt::Display + std::fmt::Debug,
170    {
171        let transaction = self
172            .begin_with_config(isolation_level, access_mode)
173            .map_err(TransactionError::Connection)?;
174        run_async_transaction_callback(transaction, callback)
175    }
176
177    /// Returns `()` if the current user can execute / query the given SQL statement.
178    /// Returns `DbErr::AccessDenied` otherwise.
179    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
180        self.conn.rbac.user_can_run(self.user_id, stmt)
181    }
182
183    /// Returns true if the current user can perform action on resource
184    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
185    where
186        P: Into<PermissionRequest>,
187        R: Into<ResourceRequest>,
188    {
189        self.conn.rbac.user_can(self.user_id, permission, resource)
190    }
191
192    /// Get current user's role and associated permissions.
193    /// This includes permissions "inherited" from child roles.
194    pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
195        self.conn.rbac.user_role_permissions(self.user_id)
196    }
197
198    /// Get a list of all roles and their ranks.
199    /// Rank is defined as (1 + number of child roles).
200    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
201        self.conn.rbac.roles_and_ranks()
202    }
203
204    /// Get two lists of all resources and permissions, excluding wildcards.
205    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
206        self.conn.rbac.resources_and_permissions()
207    }
208
209    /// Get a list of edges walking the role hierarchy tree
210    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
211        self.conn.rbac.role_hierarchy_edges(role_id)
212    }
213
214    /// Get a list of permissions for the specific role, grouped by resources.
215    /// This does not include permissions of child roles.
216    pub fn role_permissions_by_resources(
217        &self,
218        role_id: RoleId,
219    ) -> Result<RbacPermissionsByResources, DbErr> {
220        self.conn.rbac.role_permissions_by_resources(role_id)
221    }
222}
223
224impl RestrictedTransaction {
225    /// Get the [`RbacUserId`] bounded to this connection.
226    pub fn user_id(&self) -> UserId {
227        self.user_id
228    }
229
230    /// Execute the function inside a transaction.
231    /// If the function returns an error, the transaction will be rolled back.
232    /// Otherwise, the transaction will be committed.
233    #[instrument(level = "trace", skip(callback))]
234    pub fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
235    where
236        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
237        E: std::fmt::Display + std::fmt::Debug,
238    {
239        let transaction = self.begin().map_err(TransactionError::Connection)?;
240        run_async_transaction_callback(transaction, callback)
241    }
242
243    /// Execute the function inside a transaction with isolation level and/or access mode.
244    /// If the function returns an error, the transaction will be rolled back.
245    /// Otherwise, the transaction will be committed.
246    #[instrument(level = "trace", skip(callback))]
247    pub fn transaction_with_config<F, T, E>(
248        &self,
249        callback: F,
250        isolation_level: Option<IsolationLevel>,
251        access_mode: Option<AccessMode>,
252    ) -> Result<T, TransactionError<E>>
253    where
254        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
255        E: std::fmt::Display + std::fmt::Debug,
256    {
257        let transaction = self
258            .begin_with_config(isolation_level, access_mode)
259            .map_err(TransactionError::Connection)?;
260        run_async_transaction_callback(transaction, callback)
261    }
262
263    /// Returns `()` if the current user can execute / query the given SQL statement.
264    /// Returns `DbErr::AccessDenied` otherwise.
265    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
266        self.rbac.user_can_run(self.user_id, stmt)
267    }
268
269    /// Returns true if the current user can perform action on resource
270    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
271    where
272        P: Into<PermissionRequest>,
273        R: Into<ResourceRequest>,
274    {
275        self.rbac.user_can(self.user_id, permission, resource)
276    }
277}
278
279impl TransactionTrait for RestrictedConnection {
280    type Transaction = RestrictedTransaction;
281
282    #[instrument(level = "trace")]
283    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
284        Ok(RestrictedTransaction {
285            user_id: self.user_id,
286            conn: self.conn.begin()?,
287            rbac: self.conn.rbac.clone(),
288        })
289    }
290
291    #[instrument(level = "trace")]
292    fn begin_with_config(
293        &self,
294        isolation_level: Option<IsolationLevel>,
295        access_mode: Option<AccessMode>,
296    ) -> Result<RestrictedTransaction, DbErr> {
297        Ok(RestrictedTransaction {
298            user_id: self.user_id,
299            conn: self.conn.begin_with_config(isolation_level, access_mode)?,
300            rbac: self.conn.rbac.clone(),
301        })
302    }
303
304    #[instrument(level = "trace")]
305    fn begin_with_options(
306        &self,
307        options: TransactionOptions,
308    ) -> Result<RestrictedTransaction, DbErr> {
309        Ok(RestrictedTransaction {
310            user_id: self.user_id,
311            conn: self.conn.begin_with_options(options)?,
312            rbac: self.conn.rbac.clone(),
313        })
314    }
315
316    /// Execute the function inside a transaction.
317    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
318    #[instrument(level = "trace", skip(callback))]
319    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
320    where
321        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
322        E: std::fmt::Display + std::fmt::Debug,
323    {
324        let transaction = self.begin().map_err(TransactionError::Connection)?;
325        transaction.run(callback)
326    }
327
328    /// Execute the function inside a transaction.
329    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
330    #[instrument(level = "trace", skip(callback))]
331    fn transaction_with_config<F, T, E>(
332        &self,
333        callback: F,
334        isolation_level: Option<IsolationLevel>,
335        access_mode: Option<AccessMode>,
336    ) -> Result<T, TransactionError<E>>
337    where
338        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
339        E: std::fmt::Display + std::fmt::Debug,
340    {
341        let transaction = self
342            .begin_with_config(isolation_level, access_mode)
343            .map_err(TransactionError::Connection)?;
344        transaction.run(callback)
345    }
346}
347
348impl TransactionTrait for RestrictedTransaction {
349    type Transaction = RestrictedTransaction;
350
351    #[instrument(level = "trace")]
352    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
353        Ok(RestrictedTransaction {
354            user_id: self.user_id,
355            conn: self.conn.begin()?,
356            rbac: self.rbac.clone(),
357        })
358    }
359
360    #[instrument(level = "trace")]
361    fn begin_with_config(
362        &self,
363        isolation_level: Option<IsolationLevel>,
364        access_mode: Option<AccessMode>,
365    ) -> Result<RestrictedTransaction, DbErr> {
366        Ok(RestrictedTransaction {
367            user_id: self.user_id,
368            conn: self.conn.begin_with_config(isolation_level, access_mode)?,
369            rbac: self.rbac.clone(),
370        })
371    }
372
373    #[instrument(level = "trace")]
374    fn begin_with_options(
375        &self,
376        options: TransactionOptions,
377    ) -> Result<RestrictedTransaction, DbErr> {
378        Ok(RestrictedTransaction {
379            user_id: self.user_id,
380            conn: self.conn.begin_with_options(options)?,
381            rbac: self.rbac.clone(),
382        })
383    }
384
385    /// Execute the function inside a transaction.
386    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
387    #[instrument(level = "trace", skip(callback))]
388    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
389    where
390        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
391        E: std::fmt::Display + std::fmt::Debug,
392    {
393        let transaction = self.begin().map_err(TransactionError::Connection)?;
394        transaction.run(callback)
395    }
396
397    /// Execute the function inside a transaction.
398    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
399    #[instrument(level = "trace", skip(callback))]
400    fn transaction_with_config<F, T, E>(
401        &self,
402        callback: F,
403        isolation_level: Option<IsolationLevel>,
404        access_mode: Option<AccessMode>,
405    ) -> Result<T, TransactionError<E>>
406    where
407        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
408        E: std::fmt::Display + std::fmt::Debug,
409    {
410        let transaction = self
411            .begin_with_config(isolation_level, access_mode)
412            .map_err(TransactionError::Connection)?;
413        transaction.run(callback)
414    }
415}
416
/// Exposes [`RestrictedTransaction`]'s commit / rollback through the
/// [`TransactionSession`] trait.
impl TransactionSession for RestrictedTransaction {
    /// Commits the transaction by delegating to the inherent
    /// `RestrictedTransaction::commit`.
    fn commit(self) -> Result<(), DbErr> {
        // Inherent methods take precedence over trait methods during method
        // resolution, so this calls `RestrictedTransaction::commit` and does not
        // recurse into this trait method.
        self.commit()
    }

    /// Rolls back the transaction by delegating to the inherent
    /// `RestrictedTransaction::rollback`.
    fn rollback(self) -> Result<(), DbErr> {
        // Resolves to the inherent method (not this trait method), as with `commit`.
        self.rollback()
    }
}
426
427impl RestrictedTransaction {
428    /// Runs a transaction to completion passing through the result.
429    /// Rolling back the transaction on encountering an error.
430    #[instrument(level = "trace", skip(callback))]
431    fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
432    where
433        F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
434        E: std::fmt::Display + std::fmt::Debug,
435    {
436        let res = callback(&self).map_err(TransactionError::Transaction);
437        if res.is_ok() {
438            self.commit().map_err(TransactionError::Connection)?;
439        } else {
440            self.rollback().map_err(TransactionError::Connection)?;
441        }
442        res
443    }
444
445    /// Commit a transaction
446    #[instrument(level = "trace")]
447    pub fn commit(self) -> Result<(), DbErr> {
448        self.conn.commit()
449    }
450
451    /// Rolls back a transaction explicitly
452    #[instrument(level = "trace")]
453    pub fn rollback(self) -> Result<(), DbErr> {
454        self.conn.rollback()
455    }
456}
457
458impl RbacEngineMount {
459    pub fn is_some(&self) -> bool {
460        let engine = self.inner.read().expect("RBAC Engine died");
461        engine.is_some()
462    }
463
464    pub fn replace(&self, engine: RbacEngine) {
465        let mut inner = self.inner.write().expect("RBAC Engine died");
466        *inner = Some(engine);
467    }
468
469    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
470    where
471        P: Into<PermissionRequest>,
472        R: Into<ResourceRequest>,
473    {
474        let permission = permission.into();
475        let resource = resource.into();
476        // There is nothing we can do if RwLock is poisoned.
477        let holder = self.inner.read().expect("RBAC Engine died");
478        // Constructor of this struct should ensure engine is not None.
479        let engine = holder.as_ref().expect("RBAC Engine not set");
480        engine
481            .user_can(user_id, permission, resource)
482            .map_err(map_err)
483    }
484
485    pub fn user_can_run<S: StatementBuilder>(
486        &self,
487        user_id: UserId,
488        stmt: &S,
489    ) -> Result<(), DbErr> {
490        let audit = match stmt.audit() {
491            Ok(audit) => audit,
492            Err(err) => return Err(DbErr::RbacError(err.to_string())),
493        };
494        // There is nothing we can do if RwLock is poisoned.
495        let holder = self.inner.read().expect("RBAC Engine died");
496        // Constructor of this struct should ensure engine is not None.
497        let engine = holder.as_ref().expect("RBAC Engine not set");
498        for request in audit.requests {
499            let permission = || PermissionRequest {
500                action: request.access_type.as_str().to_owned(),
501            };
502            let resource = || ResourceRequest {
503                schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
504                table: request.schema_table.1.to_string(),
505            };
506            if !engine
507                .user_can(user_id, permission(), resource())
508                .map_err(map_err)?
509            {
510                return Err(DbErr::AccessDenied {
511                    permission: permission().action.to_owned(),
512                    resource: resource().to_string(),
513                });
514            }
515        }
516        Ok(())
517    }
518
519    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
520        let holder = self.inner.read().expect("RBAC Engine died");
521        let engine = holder.as_ref().expect("RBAC Engine not set");
522        engine
523            .get_user_role_permissions(user_id)
524            .map_err(|err| DbErr::RbacError(err.to_string()))
525    }
526
527    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
528        let holder = self.inner.read().expect("RBAC Engine died");
529        let engine = holder.as_ref().expect("RBAC Engine not set");
530        engine
531            .get_roles_and_ranks()
532            .map_err(|err| DbErr::RbacError(err.to_string()))
533    }
534
535    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
536        let holder = self.inner.read().expect("RBAC Engine died");
537        let engine = holder.as_ref().expect("RBAC Engine not set");
538        Ok(engine.list_resources_and_permissions())
539    }
540
541    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
542        let holder = self.inner.read().expect("RBAC Engine died");
543        let engine = holder.as_ref().expect("RBAC Engine not set");
544        Ok(engine.list_role_hierarchy_edges(role_id))
545    }
546
547    pub fn role_permissions_by_resources(
548        &self,
549        role_id: RoleId,
550    ) -> Result<RbacPermissionsByResources, DbErr> {
551        let holder = self.inner.read().expect("RBAC Engine died");
552        let engine = holder.as_ref().expect("RBAC Engine not set");
553        engine
554            .list_role_permissions_by_resources(role_id)
555            .map_err(|err| DbErr::RbacError(err.to_string()))
556    }
557}
558
559fn map_err(err: RbacError) -> DbErr {
560    DbErr::RbacError(err.to_string())
561}