sea_orm/driver/mock.rs

1use crate::{
2    debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult,
3    Statement, Transaction,
4};
5use futures_util::Stream;
6use std::{
7    fmt::Debug,
8    pin::Pin,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc, Mutex,
12    },
13};
14use tracing::instrument;
15
/// Driver entry point for the [MockDatabase].
///
/// A stateless marker type: all of its functionality lives in associated
/// functions (`accepts`, `connect`).
#[derive(Debug)]
pub struct MockDatabaseConnector;
19
20/// Defines a connection for the [MockDatabase]
21#[derive(Debug)]
22pub struct MockDatabaseConnection {
23    execute_counter: AtomicUsize,
24    query_counter: AtomicUsize,
25    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
26}
27
28/// A Trait for any type wanting to perform operations on the [MockDatabase]
29pub trait MockDatabaseTrait: Send + Debug {
30    /// Execute a statement in the [MockDatabase]
31    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;
32
33    /// Execute a SQL query in the [MockDatabase]
34    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
35
36    /// Create a transaction that can be committed atomically
37    fn begin(&mut self);
38
39    /// Commit a successful transaction atomically into the [MockDatabase]
40    fn commit(&mut self);
41
42    /// Roll back a transaction since errors were encountered
43    fn rollback(&mut self);
44
45    /// Get all logs from a [MockDatabase] and return a [Transaction]
46    fn drain_transaction_log(&mut self) -> Vec<Transaction>;
47
48    /// Get the backend being used in the [MockDatabase]
49    fn get_database_backend(&self) -> DbBackend;
50
51    /// Ping the [MockDatabase]
52    fn ping(&self) -> Result<(), DbErr>;
53}
54
55impl MockDatabaseConnector {
56    /// Check if the database URI given and the [DatabaseBackend](crate::DatabaseBackend) selected are the same
57    #[allow(unused_variables)]
58    pub fn accepts(string: &str) -> bool {
59        #[cfg(feature = "sqlx-mysql")]
60        if DbBackend::MySql.is_prefix_of(string) {
61            return true;
62        }
63        #[cfg(feature = "sqlx-postgres")]
64        if DbBackend::Postgres.is_prefix_of(string) {
65            return true;
66        }
67        #[cfg(feature = "sqlx-sqlite")]
68        if DbBackend::Sqlite.is_prefix_of(string) {
69            return true;
70        }
71        false
72    }
73
74    /// Connect to the [MockDatabase]
75    #[allow(unused_variables)]
76    #[instrument(level = "trace")]
77    pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
78        macro_rules! connect_mock_db {
79            ( $syntax: expr ) => {
80                Ok(DatabaseConnection::MockDatabaseConnection(Arc::new(
81                    MockDatabaseConnection::new(MockDatabase::new($syntax)),
82                )))
83            };
84        }
85
86        #[cfg(feature = "sqlx-mysql")]
87        if crate::SqlxMySqlConnector::accepts(string) {
88            return connect_mock_db!(DbBackend::MySql);
89        }
90        #[cfg(feature = "sqlx-postgres")]
91        if crate::SqlxPostgresConnector::accepts(string) {
92            return connect_mock_db!(DbBackend::Postgres);
93        }
94        #[cfg(feature = "sqlx-sqlite")]
95        if crate::SqlxSqliteConnector::accepts(string) {
96            return connect_mock_db!(DbBackend::Sqlite);
97        }
98        connect_mock_db!(DbBackend::Postgres)
99    }
100}
101
102impl MockDatabaseConnection {
103    /// Create a connection to the [MockDatabase]
104    pub fn new<M>(m: M) -> Self
105    where
106        M: MockDatabaseTrait,
107        M: 'static,
108    {
109        Self {
110            execute_counter: AtomicUsize::new(0),
111            query_counter: AtomicUsize::new(0),
112            mocker: Mutex::new(Box::new(m)),
113        }
114    }
115
116    pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
117        &self.mocker
118    }
119
120    /// Get the [DatabaseBackend](crate::DatabaseBackend) being used by the [MockDatabase]
121    ///
122    /// # Panics
123    ///
124    /// Will panic if the lock cannot be acquired.
125    pub fn get_database_backend(&self) -> DbBackend {
126        self.mocker
127            .lock()
128            .expect("Fail to acquire mocker")
129            .get_database_backend()
130    }
131
132    /// Execute the SQL statement in the [MockDatabase]
133    #[instrument(level = "trace")]
134    pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
135        debug_print!("{}", statement);
136        let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
137        self.mocker
138            .lock()
139            .map_err(exec_err)?
140            .execute(counter, statement)
141    }
142
143    /// Return one [QueryResult] if the query was successful
144    #[instrument(level = "trace")]
145    pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
146        debug_print!("{}", statement);
147        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
148        let result = self
149            .mocker
150            .lock()
151            .map_err(query_err)?
152            .query(counter, statement)?;
153        Ok(result.into_iter().next())
154    }
155
156    /// Return all [QueryResult]s if the query was successful
157    #[instrument(level = "trace")]
158    pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
159        debug_print!("{}", statement);
160        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
161        self.mocker
162            .lock()
163            .map_err(query_err)?
164            .query(counter, statement)
165    }
166
167    /// Return [QueryResult]s  from a multi-query operation
168    #[instrument(level = "trace")]
169    pub fn fetch(
170        &self,
171        statement: &Statement,
172    ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
173        match self.query_all(statement.clone()) {
174            Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
175            Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
176        }
177    }
178
179    /// Create a statement block  of SQL statements that execute together.
180    ///
181    /// # Panics
182    ///
183    /// Will panic if the lock cannot be acquired.
184    #[instrument(level = "trace")]
185    pub fn begin(&self) {
186        self.mocker
187            .lock()
188            .expect("Failed to acquire mocker")
189            .begin()
190    }
191
192    /// Commit a transaction atomically to the database
193    ///
194    /// # Panics
195    ///
196    /// Will panic if the lock cannot be acquired.
197    #[instrument(level = "trace")]
198    pub fn commit(&self) {
199        self.mocker
200            .lock()
201            .expect("Failed to acquire mocker")
202            .commit()
203    }
204
205    /// Roll back a faulty transaction
206    ///
207    /// # Panics
208    ///
209    /// Will panic if the lock cannot be acquired.
210    #[instrument(level = "trace")]
211    pub fn rollback(&self) {
212        self.mocker
213            .lock()
214            .expect("Failed to acquire mocker")
215            .rollback()
216    }
217
218    /// Checks if a connection to the database is still valid.
219    pub fn ping(&self) -> Result<(), DbErr> {
220        self.mocker.lock().map_err(query_err)?.ping()
221    }
222}
223
224impl
225    From<(
226        Arc<crate::MockDatabaseConnection>,
227        Statement,
228        Option<crate::metric::Callback>,
229    )> for crate::QueryStream
230{
231    fn from(
232        (conn, stmt, metric_callback): (
233            Arc<crate::MockDatabaseConnection>,
234            Statement,
235            Option<crate::metric::Callback>,
236        ),
237    ) -> Self {
238        crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
239    }
240}
241
242impl crate::DatabaseTransaction {
243    pub(crate) async fn new_mock(
244        inner: Arc<crate::MockDatabaseConnection>,
245        metric_callback: Option<crate::metric::Callback>,
246    ) -> Result<crate::DatabaseTransaction, DbErr> {
247        use futures_util::lock::Mutex;
248        let backend = inner.get_database_backend();
249        Self::begin(
250            Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
251            backend,
252            metric_callback,
253            None,
254            None,
255        )
256        .await
257    }
258}