1use crate::{
2 DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult,
3 Statement, Transaction, debug_print, error::*,
4};
5use futures_util::Stream;
6use std::{
7 fmt::Debug,
8 pin::Pin,
9 sync::{
10 Arc, Mutex,
11 atomic::{AtomicUsize, Ordering},
12 },
13};
14use tracing::instrument;
15
/// Stateless entry point for creating mock database connections.
///
/// It exposes `accepts`, which checks a connection string against the backends
/// whose sqlx features are enabled, and `connect`, which builds a
/// `DatabaseConnection` backed by a `MockDatabase`.
#[derive(Debug)]
pub struct MockDatabaseConnector;
19
/// A connection whose statements are answered by a user-supplied mocker
/// instead of a real database.
#[derive(Debug)]
pub struct MockDatabaseConnection {
    // Number of `execute` calls made so far; the pre-increment value is passed
    // to the mocker as its `counter` argument.
    execute_counter: AtomicUsize,
    // Number of query calls (`query_one` / `query_all`, and therefore `fetch`)
    // made so far; shared by all query paths.
    query_counter: AtomicUsize,
    // The mock implementation, behind a mutex so `&self` methods can call its
    // `&mut self` trait methods.
    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
}
27
/// Behaviour a mock backend must provide to stand in for a real database
/// inside `MockDatabaseConnection`.
///
/// Implementors must be `Send + Debug`. `MockDatabaseConnection` stores one as
/// `Box<dyn MockDatabaseTrait>` behind a `std::sync::Mutex`.
pub trait MockDatabaseTrait: Send + Debug {
    /// Produces the result for a statement sent through
    /// `MockDatabaseConnection::execute`.
    ///
    /// When invoked by `MockDatabaseConnection`, `counter` is the number of
    /// `execute` calls made on that connection before this one (starting at 0).
    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;

    /// Produces the rows for a statement sent through
    /// `MockDatabaseConnection::query_one` / `query_all`.
    ///
    /// When invoked by `MockDatabaseConnection`, `counter` is the number of
    /// query calls made on that connection before this one (starting at 0);
    /// `query_one` and `query_all` share this counter.
    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;

    /// Called by `MockDatabaseConnection::begin`.
    fn begin(&mut self);

    /// Called by `MockDatabaseConnection::commit`.
    fn commit(&mut self);

    /// Called by `MockDatabaseConnection::rollback`.
    fn rollback(&mut self);

    /// Returns the transactions recorded by the mocker.
    // NOTE(review): the name suggests the log is emptied as it is returned;
    // that is up to the implementor and is not enforced here — confirm.
    fn drain_transaction_log(&mut self) -> Vec<Transaction>;

    /// The backend this mock pretends to be.
    fn get_database_backend(&self) -> DbBackend;

    /// Called by `MockDatabaseConnection::ping`.
    fn ping(&self) -> Result<(), DbErr>;
}
54
55impl MockDatabaseConnector {
56 #[allow(unused_variables)]
58 pub fn accepts(string: &str) -> bool {
59 #[cfg(feature = "sqlx-mysql")]
60 if DbBackend::MySql.is_prefix_of(string) {
61 return true;
62 }
63 #[cfg(feature = "sqlx-postgres")]
64 if DbBackend::Postgres.is_prefix_of(string) {
65 return true;
66 }
67 #[cfg(feature = "sqlx-sqlite")]
68 if DbBackend::Sqlite.is_prefix_of(string) {
69 return true;
70 }
71 false
72 }
73
74 #[allow(unused_variables)]
76 #[instrument(level = "trace")]
77 pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
78 macro_rules! connect_mock_db {
79 ( $syntax: expr ) => {
80 Ok(DatabaseConnectionType::MockDatabaseConnection(Arc::new(
81 MockDatabaseConnection::new(MockDatabase::new($syntax)),
82 ))
83 .into())
84 };
85 }
86
87 #[cfg(feature = "sqlx-mysql")]
88 if crate::SqlxMySqlConnector::accepts(string) {
89 return connect_mock_db!(DbBackend::MySql);
90 }
91 #[cfg(feature = "sqlx-postgres")]
92 if crate::SqlxPostgresConnector::accepts(string) {
93 return connect_mock_db!(DbBackend::Postgres);
94 }
95 #[cfg(feature = "sqlx-sqlite")]
96 if crate::SqlxSqliteConnector::accepts(string) {
97 return connect_mock_db!(DbBackend::Sqlite);
98 }
99 connect_mock_db!(DbBackend::Postgres)
100 }
101}
102
103impl MockDatabaseConnection {
104 pub fn new<M>(m: M) -> Self
106 where
107 M: MockDatabaseTrait,
108 M: 'static,
109 {
110 Self {
111 execute_counter: AtomicUsize::new(0),
112 query_counter: AtomicUsize::new(0),
113 mocker: Mutex::new(Box::new(m)),
114 }
115 }
116
117 pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
118 &self.mocker
119 }
120
121 pub fn get_database_backend(&self) -> DbBackend {
127 self.mocker
128 .lock()
129 .expect("Fail to acquire mocker")
130 .get_database_backend()
131 }
132
133 #[instrument(level = "trace")]
135 pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
136 debug_print!("{}", statement);
137 let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
138 self.mocker
139 .lock()
140 .map_err(exec_err)?
141 .execute(counter, statement)
142 }
143
144 #[instrument(level = "trace")]
146 pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
147 debug_print!("{}", statement);
148 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
149 let result = self
150 .mocker
151 .lock()
152 .map_err(query_err)?
153 .query(counter, statement)?;
154 Ok(result.into_iter().next())
155 }
156
157 #[instrument(level = "trace")]
159 pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
160 debug_print!("{}", statement);
161 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
162 self.mocker
163 .lock()
164 .map_err(query_err)?
165 .query(counter, statement)
166 }
167
168 #[instrument(level = "trace")]
170 pub fn fetch(
171 &self,
172 statement: &Statement,
173 ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
174 match self.query_all(statement.clone()) {
175 Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
176 Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
177 }
178 }
179
180 #[instrument(level = "trace")]
186 pub fn begin(&self) {
187 self.mocker
188 .lock()
189 .expect("Failed to acquire mocker")
190 .begin()
191 }
192
193 #[instrument(level = "trace")]
199 pub fn commit(&self) {
200 self.mocker
201 .lock()
202 .expect("Failed to acquire mocker")
203 .commit()
204 }
205
206 #[instrument(level = "trace")]
212 pub fn rollback(&self) {
213 self.mocker
214 .lock()
215 .expect("Failed to acquire mocker")
216 .rollback()
217 }
218
219 pub fn ping(&self) -> Result<(), DbErr> {
221 self.mocker.lock().map_err(query_err)?.ping()
222 }
223}
224
225impl
226 From<(
227 Arc<crate::MockDatabaseConnection>,
228 Statement,
229 Option<crate::metric::Callback>,
230 )> for crate::QueryStream
231{
232 fn from(
233 (conn, stmt, metric_callback): (
234 Arc<crate::MockDatabaseConnection>,
235 Statement,
236 Option<crate::metric::Callback>,
237 ),
238 ) -> Self {
239 crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
240 }
241}
242
243impl crate::DatabaseTransaction {
244 pub(crate) async fn new_mock(
245 inner: Arc<crate::MockDatabaseConnection>,
246 metric_callback: Option<crate::metric::Callback>,
247 ) -> Result<crate::DatabaseTransaction, DbErr> {
248 use futures_util::lock::Mutex;
249 let backend = inner.get_database_backend();
250 Self::begin(
251 Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
252 backend,
253 metric_callback,
254 None,
255 None,
256 )
257 .await
258 }
259}