sea_orm/driver/
mock.rs

1use crate::{
2    DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult,
3    Statement, Transaction, debug_print, error::*,
4};
5use std::{
6    fmt::Debug,
7    pin::Pin,
8    sync::{
9        Arc, Mutex,
10        atomic::{AtomicUsize, Ordering},
11    },
12};
13use tracing::instrument;
14
15#[cfg(not(feature = "sync"))]
16type PinBoxStream = Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>;
17#[cfg(feature = "sync")]
18type PinBoxStream = Box<dyn Iterator<Item = Result<QueryResult, DbErr>>>;
19
/// Defines a database driver for the [MockDatabase]
///
/// This is a unit struct that carries no state: its associated functions
/// `accepts` and `connect` take no `self` and are called on the type itself.
#[derive(Debug)]
pub struct MockDatabaseConnector;
23
/// Defines a connection for the [MockDatabase]
///
/// Holds a boxed [MockDatabaseTrait] implementation behind a [Mutex] and keeps
/// two independent call counters, one for executed statements and one for
/// queries, whose current values are handed to the mocker on every call.
#[derive(Debug)]
pub struct MockDatabaseConnection {
    /// Number of `execute` calls made so far; starts at zero.
    execute_counter: AtomicUsize,
    /// Number of `query_one` / `query_all` calls made so far; starts at zero.
    query_counter: AtomicUsize,
    /// The mock implementation every operation is forwarded to, locked per call.
    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
}
31
/// A Trait for any type wanting to perform operations on the [MockDatabase]
///
/// [MockDatabaseConnection] stores an implementor as a boxed trait object behind a
/// mutex and forwards each of its operations to the method of the same name here.
/// Implementors must also be [Debug].
pub trait MockDatabaseTrait: Debug {
    /// Execute a statement in the [MockDatabase]
    ///
    /// `counter` is chosen by the caller; [MockDatabaseConnection] passes the number
    /// of `execute` calls made on it before this one, counting from zero.
    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;

    /// Execute a SQL query in the [MockDatabase] and return every resulting row
    ///
    /// `counter` is chosen by the caller; [MockDatabaseConnection] passes the number
    /// of query calls made on it before this one, counting from zero.
    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;

    /// Create a transaction that can be committed atomically
    fn begin(&mut self);

    /// Commit a successful transaction atomically into the [MockDatabase]
    fn commit(&mut self);

    /// Roll back a transaction since errors were encountered
    fn rollback(&mut self);

    /// Get all logs from a [MockDatabase], returned as a list of [Transaction]s
    fn drain_transaction_log(&mut self) -> Vec<Transaction>;

    /// Get the backend being used in the [MockDatabase]
    fn get_database_backend(&self) -> DbBackend;

    /// Ping the [MockDatabase]; a failure is reported through the returned [DbErr]
    fn ping(&self) -> Result<(), DbErr>;
}
58
59impl MockDatabaseConnector {
60    /// Check if the database URI given and the [DatabaseBackend](crate::DatabaseBackend) selected are the same
61    #[allow(unused_variables)]
62    pub fn accepts(string: &str) -> bool {
63        #[cfg(feature = "sqlx-mysql")]
64        if DbBackend::MySql.is_prefix_of(string) {
65            return true;
66        }
67        #[cfg(feature = "sqlx-postgres")]
68        if DbBackend::Postgres.is_prefix_of(string) {
69            return true;
70        }
71        #[cfg(feature = "sqlx-sqlite")]
72        if DbBackend::Sqlite.is_prefix_of(string) {
73            return true;
74        }
75        #[cfg(feature = "rusqlite")]
76        if DbBackend::Sqlite.is_prefix_of(string) {
77            return true;
78        }
79        false
80    }
81
82    /// Connect to the [MockDatabase]
83    #[allow(unused_variables)]
84    #[instrument(level = "trace")]
85    pub fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
86        macro_rules! connect_mock_db {
87            ( $syntax: expr ) => {
88                Ok(DatabaseConnectionType::MockDatabaseConnection(Arc::new(
89                    MockDatabaseConnection::new(MockDatabase::new($syntax)),
90                ))
91                .into())
92            };
93        }
94
95        #[cfg(feature = "sqlx-mysql")]
96        if crate::SqlxMySqlConnector::accepts(string) {
97            return connect_mock_db!(DbBackend::MySql);
98        }
99        #[cfg(feature = "sqlx-postgres")]
100        if crate::SqlxPostgresConnector::accepts(string) {
101            return connect_mock_db!(DbBackend::Postgres);
102        }
103        #[cfg(feature = "sqlx-sqlite")]
104        if crate::SqlxSqliteConnector::accepts(string) {
105            return connect_mock_db!(DbBackend::Sqlite);
106        }
107        #[cfg(feature = "rusqlite")]
108        if crate::driver::rusqlite::RusqliteConnector::accepts(string) {
109            return connect_mock_db!(DbBackend::Sqlite);
110        }
111        connect_mock_db!(DbBackend::Postgres)
112    }
113}
114
115impl MockDatabaseConnection {
116    /// Create a connection to the [MockDatabase]
117    pub fn new<M>(m: M) -> Self
118    where
119        M: MockDatabaseTrait,
120        M: 'static,
121    {
122        Self {
123            execute_counter: AtomicUsize::new(0),
124            query_counter: AtomicUsize::new(0),
125            mocker: Mutex::new(Box::new(m)),
126        }
127    }
128
129    pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
130        &self.mocker
131    }
132
133    /// Get the [DatabaseBackend](crate::DatabaseBackend) being used by the [MockDatabase]
134    ///
135    /// # Panics
136    ///
137    /// Will panic if the lock cannot be acquired.
138    pub fn get_database_backend(&self) -> DbBackend {
139        self.mocker
140            .lock()
141            .expect("Fail to acquire mocker")
142            .get_database_backend()
143    }
144
145    /// Execute the SQL statement in the [MockDatabase]
146    #[instrument(level = "trace")]
147    pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
148        debug_print!("{}", statement);
149        let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
150        self.mocker
151            .lock()
152            .map_err(exec_err)?
153            .execute(counter, statement)
154    }
155
156    /// Return one [QueryResult] if the query was successful
157    #[instrument(level = "trace")]
158    pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
159        debug_print!("{}", statement);
160        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
161        let result = self
162            .mocker
163            .lock()
164            .map_err(query_err)?
165            .query(counter, statement)?;
166        Ok(result.into_iter().next())
167    }
168
169    /// Return all [QueryResult]s if the query was successful
170    #[instrument(level = "trace")]
171    pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
172        debug_print!("{}", statement);
173        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
174        self.mocker
175            .lock()
176            .map_err(query_err)?
177            .query(counter, statement)
178    }
179
180    /// Return [QueryResult]s  from a multi-query operation
181    #[instrument(level = "trace")]
182    pub fn fetch(&self, statement: &Statement) -> PinBoxStream {
183        #[cfg(not(feature = "sync"))]
184        {
185            match self.query_all(statement.clone()) {
186                Ok(v) => Box::new(futures_util::stream::iter(v.into_iter().map(Ok))),
187                Err(e) => Box::new(futures_util::stream::iter(Some(Err(e)).into_iter())),
188            }
189        }
190        #[cfg(feature = "sync")]
191        {
192            match self.query_all(statement.clone()) {
193                Ok(v) => Box::new(v.into_iter().map(Ok)),
194                Err(e) => Box::new(Some(Err(e)).into_iter()),
195            }
196        }
197    }
198
199    /// Create a statement block  of SQL statements that execute together.
200    ///
201    /// # Panics
202    ///
203    /// Will panic if the lock cannot be acquired.
204    #[instrument(level = "trace")]
205    pub fn begin(&self) {
206        self.mocker
207            .lock()
208            .expect("Failed to acquire mocker")
209            .begin()
210    }
211
212    /// Commit a transaction atomically to the database
213    ///
214    /// # Panics
215    ///
216    /// Will panic if the lock cannot be acquired.
217    #[instrument(level = "trace")]
218    pub fn commit(&self) {
219        self.mocker
220            .lock()
221            .expect("Failed to acquire mocker")
222            .commit()
223    }
224
225    /// Roll back a faulty transaction
226    ///
227    /// # Panics
228    ///
229    /// Will panic if the lock cannot be acquired.
230    #[instrument(level = "trace")]
231    pub fn rollback(&self) {
232        self.mocker
233            .lock()
234            .expect("Failed to acquire mocker")
235            .rollback()
236    }
237
238    /// Checks if a connection to the database is still valid.
239    pub fn ping(&self) -> Result<(), DbErr> {
240        self.mocker.lock().map_err(query_err)?.ping()
241    }
242}
243
244impl
245    From<(
246        Arc<crate::MockDatabaseConnection>,
247        Statement,
248        Option<crate::metric::Callback>,
249    )> for crate::QueryStream
250{
251    fn from(
252        (conn, stmt, metric_callback): (
253            Arc<crate::MockDatabaseConnection>,
254            Statement,
255            Option<crate::metric::Callback>,
256        ),
257    ) -> Self {
258        crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
259    }
260}
261
262impl crate::DatabaseTransaction {
263    pub(crate) fn new_mock(
264        inner: Arc<crate::MockDatabaseConnection>,
265        metric_callback: Option<crate::metric::Callback>,
266    ) -> Result<crate::DatabaseTransaction, DbErr> {
267        use std::sync::Mutex;
268        let backend = inner.get_database_backend();
269        Self::begin(
270            Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
271            backend,
272            metric_callback,
273            None,
274            None,
275        )
276    }
277}