sea_orm/driver/
mock.rs

1use crate::{
2    debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult,
3    Statement, Transaction,
4};
5use futures_util::Stream;
6use std::{
7    fmt::Debug,
8    pin::Pin,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc, Mutex,
12    },
13};
14use tracing::instrument;
15
/// Driver entry point that hands out connections backed by a [MockDatabase].
///
/// This is a stateless marker type; all functionality lives in its
/// associated functions.
#[derive(Debug)]
pub struct MockDatabaseConnector;
19
20/// Defines a connection for the [MockDatabase]
21#[derive(Debug)]
22pub struct MockDatabaseConnection {
23    execute_counter: AtomicUsize,
24    query_counter: AtomicUsize,
25    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
26}
27
28/// A Trait for any type wanting to perform operations on the [MockDatabase]
29pub trait MockDatabaseTrait: Send + Debug {
30    /// Execute a statement in the [MockDatabase]
31    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;
32
33    /// Execute a SQL query in the [MockDatabase]
34    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
35
36    /// Create a transaction that can be committed atomically
37    fn begin(&mut self);
38
39    /// Commit a successful transaction atomically into the [MockDatabase]
40    fn commit(&mut self);
41
42    /// Roll back a transaction since errors were encountered
43    fn rollback(&mut self);
44
45    /// Get all logs from a [MockDatabase] and return a [Transaction]
46    fn drain_transaction_log(&mut self) -> Vec<Transaction>;
47
48    /// Get the backend being used in the [MockDatabase]
49    fn get_database_backend(&self) -> DbBackend;
50
51    /// Ping the [MockDatabase]
52    fn ping(&self) -> Result<(), DbErr>;
53}
54
55impl MockDatabaseConnector {
56    /// Check if the database URI given and the [DatabaseBackend](crate::DatabaseBackend) selected are the same
57    #[allow(unused_variables)]
58    pub fn accepts(string: &str) -> bool {
59        #[cfg(feature = "sqlx-mysql")]
60        if DbBackend::MySql.is_prefix_of(string) {
61            return true;
62        }
63        #[cfg(feature = "sqlx-postgres")]
64        if DbBackend::Postgres.is_prefix_of(string) {
65            return true;
66        }
67        #[cfg(feature = "sqlx-sqlite")]
68        if DbBackend::Sqlite.is_prefix_of(string) {
69            return true;
70        }
71        false
72    }
73
74    /// Connect to the [MockDatabase]
75    #[allow(unused_variables)]
76    #[instrument(level = "trace")]
77    pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
78        macro_rules! connect_mock_db {
79            ( $syntax: expr ) => {
80                Ok(DatabaseConnection::MockDatabaseConnection(Arc::new(
81                    MockDatabaseConnection::new(MockDatabase::new($syntax)),
82                )))
83            };
84        }
85
86        #[cfg(feature = "sqlx-mysql")]
87        if crate::SqlxMySqlConnector::accepts(string) {
88            return connect_mock_db!(DbBackend::MySql);
89        }
90        #[cfg(feature = "sqlx-postgres")]
91        if crate::SqlxPostgresConnector::accepts(string) {
92            return connect_mock_db!(DbBackend::Postgres);
93        }
94        #[cfg(feature = "sqlx-sqlite")]
95        if crate::SqlxSqliteConnector::accepts(string) {
96            return connect_mock_db!(DbBackend::Sqlite);
97        }
98        connect_mock_db!(DbBackend::Postgres)
99    }
100}
101
102impl MockDatabaseConnection {
103    /// Create a connection to the [MockDatabase]
104    pub fn new<M: 'static>(m: M) -> Self
105    where
106        M: MockDatabaseTrait,
107    {
108        Self {
109            execute_counter: AtomicUsize::new(0),
110            query_counter: AtomicUsize::new(0),
111            mocker: Mutex::new(Box::new(m)),
112        }
113    }
114
115    pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
116        &self.mocker
117    }
118
119    /// Get the [DatabaseBackend](crate::DatabaseBackend) being used by the [MockDatabase]
120    ///
121    /// # Panics
122    ///
123    /// Will panic if the lock cannot be acquired.
124    pub fn get_database_backend(&self) -> DbBackend {
125        self.mocker
126            .lock()
127            .expect("Fail to acquire mocker")
128            .get_database_backend()
129    }
130
131    /// Execute the SQL statement in the [MockDatabase]
132    #[instrument(level = "trace")]
133    pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
134        debug_print!("{}", statement);
135        let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
136        self.mocker
137            .lock()
138            .map_err(exec_err)?
139            .execute(counter, statement)
140    }
141
142    /// Return one [QueryResult] if the query was successful
143    #[instrument(level = "trace")]
144    pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
145        debug_print!("{}", statement);
146        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
147        let result = self
148            .mocker
149            .lock()
150            .map_err(query_err)?
151            .query(counter, statement)?;
152        Ok(result.into_iter().next())
153    }
154
155    /// Return all [QueryResult]s if the query was successful
156    #[instrument(level = "trace")]
157    pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
158        debug_print!("{}", statement);
159        let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
160        self.mocker
161            .lock()
162            .map_err(query_err)?
163            .query(counter, statement)
164    }
165
166    /// Return [QueryResult]s  from a multi-query operation
167    #[instrument(level = "trace")]
168    pub fn fetch(
169        &self,
170        statement: &Statement,
171    ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
172        match self.query_all(statement.clone()) {
173            Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
174            Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
175        }
176    }
177
178    /// Create a statement block  of SQL statements that execute together.
179    #[instrument(level = "trace")]
180    pub fn begin(&self) {
181        self.mocker
182            .lock()
183            .expect("Failed to acquire mocker")
184            .begin()
185    }
186
187    /// Commit a transaction atomically to the database
188    #[instrument(level = "trace")]
189    pub fn commit(&self) {
190        self.mocker
191            .lock()
192            .expect("Failed to acquire mocker")
193            .commit()
194    }
195
196    /// Roll back a faulty transaction
197    #[instrument(level = "trace")]
198    pub fn rollback(&self) {
199        self.mocker
200            .lock()
201            .expect("Failed to acquire mocker")
202            .rollback()
203    }
204
205    /// Checks if a connection to the database is still valid.
206    pub fn ping(&self) -> Result<(), DbErr> {
207        self.mocker.lock().map_err(query_err)?.ping()
208    }
209}
210
211impl
212    From<(
213        Arc<crate::MockDatabaseConnection>,
214        Statement,
215        Option<crate::metric::Callback>,
216    )> for crate::QueryStream
217{
218    fn from(
219        (conn, stmt, metric_callback): (
220            Arc<crate::MockDatabaseConnection>,
221            Statement,
222            Option<crate::metric::Callback>,
223        ),
224    ) -> Self {
225        crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
226    }
227}
228
229impl crate::DatabaseTransaction {
230    pub(crate) async fn new_mock(
231        inner: Arc<crate::MockDatabaseConnection>,
232        metric_callback: Option<crate::metric::Callback>,
233    ) -> Result<crate::DatabaseTransaction, DbErr> {
234        use futures_util::lock::Mutex;
235        let backend = inner.get_database_backend();
236        Self::begin(
237            Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
238            backend,
239            metric_callback,
240            None,
241            None,
242        )
243        .await
244    }
245}