sea_orm/driver/
mock.rs

1use crate::{
2    DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult,
3    Statement, Transaction, debug_print, error::*,
4};
5use futures_util::Stream;
6use std::{
7    fmt::Debug,
8    pin::Pin,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicUsize, Ordering},
12    },
13};
14use tracing::instrument;
15
16#[cfg(not(feature = "sync"))]
17type PinBoxStream = Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>;
18#[cfg(feature = "sync")]
19type PinBoxStream = Box<dyn Iterator<Item = Result<QueryResult, DbErr>>>;
20
/// Driver entry point for the [MockDatabase].
///
/// The type carries no state; its associated functions decide whether a connection
/// string is accepted and build a mock-backed [DatabaseConnection].
#[derive(Debug)]
pub struct MockDatabaseConnector;
24
25/// Defines a connection for the [MockDatabase]
26#[derive(Debug)]
27pub struct MockDatabaseConnection {
28    execute_counter: AtomicUsize,
29    query_counter: AtomicUsize,
30    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
31}
32
33/// A Trait for any type wanting to perform operations on the [MockDatabase]
34pub trait MockDatabaseTrait: Send + Debug {
35    /// Execute a statement in the [MockDatabase]
36    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;
37
38    /// Execute a SQL query in the [MockDatabase]
39    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
40
41    /// Create a transaction that can be committed atomically
42    fn begin(&mut self);
43
44    /// Commit a successful transaction atomically into the [MockDatabase]
45    fn commit(&mut self);
46
47    /// Roll back a transaction since errors were encountered
48    fn rollback(&mut self);
49
50    /// Get all logs from a [MockDatabase] and return a [Transaction]
51    fn drain_transaction_log(&mut self) -> Vec<Transaction>;
52
53    /// Get the backend being used in the [MockDatabase]
54    fn get_database_backend(&self) -> DbBackend;
55
56    /// Ping the [MockDatabase]
57    fn ping(&self) -> Result<(), DbErr>;
58}
59
60impl MockDatabaseConnector {
61    /// Check if the database URI given and the [DatabaseBackend](crate::DatabaseBackend) selected are the same
62    #[allow(unused_variables)]
63    pub fn accepts(string: &str) -> bool {
64        #[cfg(feature = "sqlx-mysql")]
65        if DbBackend::MySql.is_prefix_of(string) {
66            return true;
67        }
68        #[cfg(feature = "sqlx-postgres")]
69        if DbBackend::Postgres.is_prefix_of(string) {
70            return true;
71        }
72        #[cfg(feature = "sqlx-sqlite")]
73        if DbBackend::Sqlite.is_prefix_of(string) {
74            return true;
75        }
76        #[cfg(feature = "rusqlite")]
77        if DbBackend::Sqlite.is_prefix_of(string) {
78            return true;
79        }
80        false
81    }
82
83    /// Connect to the [MockDatabase]
84    #[allow(unused_variables)]
85    #[instrument(level = "trace")]
86    pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
87        macro_rules! connect_mock_db {
88            ( $syntax: expr ) => {
89                Ok(DatabaseConnectionType::MockDatabaseConnection(Arc::new(
90                    MockDatabaseConnection::new(MockDatabase::new($syntax)),
91                ))
92                .into())
93            };
94        }
95
96        #[cfg(feature = "sqlx-mysql")]
97        if crate::SqlxMySqlConnector::accepts(string) {
98            return connect_mock_db!(DbBackend::MySql);
99        }
100        #[cfg(feature = "sqlx-postgres")]
101        if crate::SqlxPostgresConnector::accepts(string) {
102            return connect_mock_db!(DbBackend::Postgres);
103        }
104        #[cfg(feature = "sqlx-sqlite")]
105        if crate::SqlxSqliteConnector::accepts(string) {
106            return connect_mock_db!(DbBackend::Sqlite);
107        }
108        #[cfg(feature = "rusqlite")]
109        if crate::driver::rusqlite::RusqliteConnector::accepts(string) {
110            return connect_mock_db!(DbBackend::Sqlite);
111        }
112        connect_mock_db!(DbBackend::Postgres)
113    }
114}
115
116impl MockDatabaseConnection {
117    /// Create a connection to the [MockDatabase]
118    pub fn new<M>(m: M) -> Self
119    where
120        M: MockDatabaseTrait,
121        M: 'static,
122    {
123        Self {
124            execute_counter: AtomicUsize::new(0),
125            query_counter: AtomicUsize::new(0),
126            mocker: Mutex::new(Box::new(m)),
127        }
128    }
129
130    pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
131        &self.mocker
132    }
133
134    /// Get the [DatabaseBackend](crate::DatabaseBackend) being used by the [MockDatabase]
135    ///
136    /// # Panics
137    ///
138    /// Will panic if the lock cannot be acquired.
139    pub fn get_database_backend(&self) -> DbBackend {
140        self.mocker
141            .lock()
142            .expect("Fail to acquire mocker")
143            .get_database_backend()
144    }
145
146    /// Execute the SQL statement in the [MockDatabase]
147    #[instrument(level = "trace")]
148    pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
149        debug_print!("{}", statement);
150        let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
151        self.mocker
152            .lock()
153            .map_err(exec_err)?
154            .execute(counter, statement)
155    }
156
157    /// Return one [QueryResult] if the query was successful
158    #[instrument(level = "trace")]
159    pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
160        debug_print!("{}", statement);
161        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
162        let result = self
163            .mocker
164            .lock()
165            .map_err(query_err)?
166            .query(counter, statement)?;
167        Ok(result.into_iter().next())
168    }
169
170    /// Return all [QueryResult]s if the query was successful
171    #[instrument(level = "trace")]
172    pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
173        debug_print!("{}", statement);
174        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
175        self.mocker
176            .lock()
177            .map_err(query_err)?
178            .query(counter, statement)
179    }
180
181    /// Return [QueryResult]s  from a multi-query operation
182    #[instrument(level = "trace")]
183    pub fn fetch(&self, statement: &Statement) -> PinBoxStream {
184        #[cfg(not(feature = "sync"))]
185        {
186            match self.query_all(statement.clone()) {
187                Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
188                Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
189            }
190        }
191        #[cfg(feature = "sync")]
192        {
193            match self.query_all(statement.clone()) {
194                Ok(v) => Box::new(v.into_iter().map(Ok)),
195                Err(e) => Box::new(Some(Err(e)).into_iter()),
196            }
197        }
198    }
199
200    /// Create a statement block  of SQL statements that execute together.
201    ///
202    /// # Panics
203    ///
204    /// Will panic if the lock cannot be acquired.
205    #[instrument(level = "trace")]
206    pub fn begin(&self) {
207        self.mocker
208            .lock()
209            .expect("Failed to acquire mocker")
210            .begin()
211    }
212
213    /// Commit a transaction atomically to the database
214    ///
215    /// # Panics
216    ///
217    /// Will panic if the lock cannot be acquired.
218    #[instrument(level = "trace")]
219    pub fn commit(&self) {
220        self.mocker
221            .lock()
222            .expect("Failed to acquire mocker")
223            .commit()
224    }
225
226    /// Roll back a faulty transaction
227    ///
228    /// # Panics
229    ///
230    /// Will panic if the lock cannot be acquired.
231    #[instrument(level = "trace")]
232    pub fn rollback(&self) {
233        self.mocker
234            .lock()
235            .expect("Failed to acquire mocker")
236            .rollback()
237    }
238
239    /// Checks if a connection to the database is still valid.
240    pub fn ping(&self) -> Result<(), DbErr> {
241        self.mocker.lock().map_err(query_err)?.ping()
242    }
243}
244
245impl
246    From<(
247        Arc<crate::MockDatabaseConnection>,
248        Statement,
249        Option<crate::metric::Callback>,
250    )> for crate::QueryStream
251{
252    fn from(
253        (conn, stmt, metric_callback): (
254            Arc<crate::MockDatabaseConnection>,
255            Statement,
256            Option<crate::metric::Callback>,
257        ),
258    ) -> Self {
259        crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
260    }
261}
262
263impl crate::DatabaseTransaction {
264    pub(crate) async fn new_mock(
265        inner: Arc<crate::MockDatabaseConnection>,
266        metric_callback: Option<crate::metric::Callback>,
267    ) -> Result<crate::DatabaseTransaction, DbErr> {
268        use futures_util::lock::Mutex;
269        let backend = inner.get_database_backend();
270        Self::begin(
271            Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
272            backend,
273            metric_callback,
274            None,
275            None,
276        )
277        .await
278    }
279}