sea_orm/driver/mock.rs

1use crate::{
2    DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult,
3    Statement, Transaction, debug_print, error::*,
4};
5use futures_util::Stream;
6use std::{
7    fmt::Debug,
8    pin::Pin,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicUsize, Ordering},
12    },
13};
14use tracing::instrument;
15
/// Stateless driver entry point for the [MockDatabase].
///
/// It carries no data; its associated functions decide whether a connection
/// string is accepted and build a mock-backed connection from it.
#[derive(Debug)]
pub struct MockDatabaseConnector;
19
20/// Defines a connection for the [MockDatabase]
21#[derive(Debug)]
22pub struct MockDatabaseConnection {
23    execute_counter: AtomicUsize,
24    query_counter: AtomicUsize,
25    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
26}
27
28/// A Trait for any type wanting to perform operations on the [MockDatabase]
29pub trait MockDatabaseTrait: Send + Debug {
30    /// Execute a statement in the [MockDatabase]
31    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;
32
33    /// Execute a SQL query in the [MockDatabase]
34    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
35
36    /// Create a transaction that can be committed atomically
37    fn begin(&mut self);
38
39    /// Commit a successful transaction atomically into the [MockDatabase]
40    fn commit(&mut self);
41
42    /// Roll back a transaction since errors were encountered
43    fn rollback(&mut self);
44
45    /// Get all logs from a [MockDatabase] and return a [Transaction]
46    fn drain_transaction_log(&mut self) -> Vec<Transaction>;
47
48    /// Get the backend being used in the [MockDatabase]
49    fn get_database_backend(&self) -> DbBackend;
50
51    /// Ping the [MockDatabase]
52    fn ping(&self) -> Result<(), DbErr>;
53}
54
55impl MockDatabaseConnector {
56    /// Check if the database URI given and the [DatabaseBackend](crate::DatabaseBackend) selected are the same
57    #[allow(unused_variables)]
58    pub fn accepts(string: &str) -> bool {
59        #[cfg(feature = "sqlx-mysql")]
60        if DbBackend::MySql.is_prefix_of(string) {
61            return true;
62        }
63        #[cfg(feature = "sqlx-postgres")]
64        if DbBackend::Postgres.is_prefix_of(string) {
65            return true;
66        }
67        #[cfg(feature = "sqlx-sqlite")]
68        if DbBackend::Sqlite.is_prefix_of(string) {
69            return true;
70        }
71        false
72    }
73
74    /// Connect to the [MockDatabase]
75    #[allow(unused_variables)]
76    #[instrument(level = "trace")]
77    pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
78        macro_rules! connect_mock_db {
79            ( $syntax: expr ) => {
80                Ok(DatabaseConnectionType::MockDatabaseConnection(Arc::new(
81                    MockDatabaseConnection::new(MockDatabase::new($syntax)),
82                ))
83                .into())
84            };
85        }
86
87        #[cfg(feature = "sqlx-mysql")]
88        if crate::SqlxMySqlConnector::accepts(string) {
89            return connect_mock_db!(DbBackend::MySql);
90        }
91        #[cfg(feature = "sqlx-postgres")]
92        if crate::SqlxPostgresConnector::accepts(string) {
93            return connect_mock_db!(DbBackend::Postgres);
94        }
95        #[cfg(feature = "sqlx-sqlite")]
96        if crate::SqlxSqliteConnector::accepts(string) {
97            return connect_mock_db!(DbBackend::Sqlite);
98        }
99        connect_mock_db!(DbBackend::Postgres)
100    }
101}
102
103impl MockDatabaseConnection {
104    /// Create a connection to the [MockDatabase]
105    pub fn new<M>(m: M) -> Self
106    where
107        M: MockDatabaseTrait,
108        M: 'static,
109    {
110        Self {
111            execute_counter: AtomicUsize::new(0),
112            query_counter: AtomicUsize::new(0),
113            mocker: Mutex::new(Box::new(m)),
114        }
115    }
116
117    pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
118        &self.mocker
119    }
120
121    /// Get the [DatabaseBackend](crate::DatabaseBackend) being used by the [MockDatabase]
122    ///
123    /// # Panics
124    ///
125    /// Will panic if the lock cannot be acquired.
126    pub fn get_database_backend(&self) -> DbBackend {
127        self.mocker
128            .lock()
129            .expect("Fail to acquire mocker")
130            .get_database_backend()
131    }
132
133    /// Execute the SQL statement in the [MockDatabase]
134    #[instrument(level = "trace")]
135    pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
136        debug_print!("{}", statement);
137        let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
138        self.mocker
139            .lock()
140            .map_err(exec_err)?
141            .execute(counter, statement)
142    }
143
144    /// Return one [QueryResult] if the query was successful
145    #[instrument(level = "trace")]
146    pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
147        debug_print!("{}", statement);
148        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
149        let result = self
150            .mocker
151            .lock()
152            .map_err(query_err)?
153            .query(counter, statement)?;
154        Ok(result.into_iter().next())
155    }
156
157    /// Return all [QueryResult]s if the query was successful
158    #[instrument(level = "trace")]
159    pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
160        debug_print!("{}", statement);
161        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
162        self.mocker
163            .lock()
164            .map_err(query_err)?
165            .query(counter, statement)
166    }
167
168    /// Return [QueryResult]s  from a multi-query operation
169    #[instrument(level = "trace")]
170    pub fn fetch(
171        &self,
172        statement: &Statement,
173    ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
174        match self.query_all(statement.clone()) {
175            Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
176            Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
177        }
178    }
179
180    /// Create a statement block  of SQL statements that execute together.
181    ///
182    /// # Panics
183    ///
184    /// Will panic if the lock cannot be acquired.
185    #[instrument(level = "trace")]
186    pub fn begin(&self) {
187        self.mocker
188            .lock()
189            .expect("Failed to acquire mocker")
190            .begin()
191    }
192
193    /// Commit a transaction atomically to the database
194    ///
195    /// # Panics
196    ///
197    /// Will panic if the lock cannot be acquired.
198    #[instrument(level = "trace")]
199    pub fn commit(&self) {
200        self.mocker
201            .lock()
202            .expect("Failed to acquire mocker")
203            .commit()
204    }
205
206    /// Roll back a faulty transaction
207    ///
208    /// # Panics
209    ///
210    /// Will panic if the lock cannot be acquired.
211    #[instrument(level = "trace")]
212    pub fn rollback(&self) {
213        self.mocker
214            .lock()
215            .expect("Failed to acquire mocker")
216            .rollback()
217    }
218
219    /// Checks if a connection to the database is still valid.
220    pub fn ping(&self) -> Result<(), DbErr> {
221        self.mocker.lock().map_err(query_err)?.ping()
222    }
223}
224
225impl
226    From<(
227        Arc<crate::MockDatabaseConnection>,
228        Statement,
229        Option<crate::metric::Callback>,
230    )> for crate::QueryStream
231{
232    fn from(
233        (conn, stmt, metric_callback): (
234            Arc<crate::MockDatabaseConnection>,
235            Statement,
236            Option<crate::metric::Callback>,
237        ),
238    ) -> Self {
239        crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
240    }
241}
242
243impl crate::DatabaseTransaction {
244    pub(crate) async fn new_mock(
245        inner: Arc<crate::MockDatabaseConnection>,
246        metric_callback: Option<crate::metric::Callback>,
247    ) -> Result<crate::DatabaseTransaction, DbErr> {
248        use futures_util::lock::Mutex;
249        let backend = inner.get_database_backend();
250        Self::begin(
251            Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
252            backend,
253            metric_callback,
254            None,
255            None,
256        )
257        .await
258    }
259}