1use crate::{
2 DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult,
3 Statement, Transaction, debug_print, error::*,
4};
5use futures_util::Stream;
6use std::{
7 fmt::Debug,
8 pin::Pin,
9 sync::{
10 Arc, Mutex,
11 atomic::{AtomicUsize, Ordering},
12 },
13};
14use tracing::instrument;
15
16#[cfg(not(feature = "sync"))]
17type PinBoxStream = Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>;
18#[cfg(feature = "sync")]
19type PinBoxStream = Box<dyn Iterator<Item = Result<QueryResult, DbErr>>>;
20
/// Stateless entry point for the mock driver.
///
/// It carries no data; its associated functions decide whether a connection string is
/// handled by the mock driver and build a mock-backed `DatabaseConnection` for it.
#[derive(Debug)]
pub struct MockDatabaseConnector;
24
/// A connection whose statements are answered by a [`MockDatabaseTrait`] implementation
/// instead of a real database.
#[derive(Debug)]
pub struct MockDatabaseConnection {
    // Number of `execute` calls made so far; the pre-increment value is passed to the mocker.
    execute_counter: AtomicUsize,
    // Number of query calls made so far; shared by `query_one` and `query_all`
    // (and therefore by `fetch`, which goes through `query_all`).
    query_counter: AtomicUsize,
    // The mock backend, boxed as a trait object and guarded by a std mutex so it can be
    // mutated through `&self`.
    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
}
32
/// Operations a mock backend must provide so that [`MockDatabaseConnection`] can drive it.
///
/// Implementors must be `Send` and `Debug`; the connection stores them as a boxed trait
/// object behind a mutex.
pub trait MockDatabaseTrait: Send + Debug {
    /// Answers an execute-style statement.
    ///
    /// `counter` is the number of `execute` calls the owning connection had received
    /// before this one (it starts at zero).
    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;

    /// Answers a query statement with the rows it should yield.
    ///
    /// `counter` is the number of query calls the owning connection had received before
    /// this one (it starts at zero and is shared by `query_one` and `query_all`).
    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;

    /// Called when the owning connection is asked to begin a transaction.
    fn begin(&mut self);

    /// Called when the owning connection is asked to commit.
    fn commit(&mut self);

    /// Called when the owning connection is asked to roll back.
    fn rollback(&mut self);

    /// Returns the recorded [`Transaction`]s as an owned `Vec`.
    ///
    /// NOTE(review): the name suggests the log is emptied by this call; confirm against
    /// the implementor.
    fn drain_transaction_log(&mut self) -> Vec<Transaction>;

    /// Reports which [`DbBackend`] this mock stands in for.
    fn get_database_backend(&self) -> DbBackend;

    /// Connection health check; the result is returned unchanged by
    /// [`MockDatabaseConnection::ping`].
    fn ping(&self) -> Result<(), DbErr>;
}
59
60impl MockDatabaseConnector {
61 #[allow(unused_variables)]
63 pub fn accepts(string: &str) -> bool {
64 #[cfg(feature = "sqlx-mysql")]
65 if DbBackend::MySql.is_prefix_of(string) {
66 return true;
67 }
68 #[cfg(feature = "sqlx-postgres")]
69 if DbBackend::Postgres.is_prefix_of(string) {
70 return true;
71 }
72 #[cfg(feature = "sqlx-sqlite")]
73 if DbBackend::Sqlite.is_prefix_of(string) {
74 return true;
75 }
76 #[cfg(feature = "rusqlite")]
77 if DbBackend::Sqlite.is_prefix_of(string) {
78 return true;
79 }
80 false
81 }
82
83 #[allow(unused_variables)]
85 #[instrument(level = "trace")]
86 pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
87 macro_rules! connect_mock_db {
88 ( $syntax: expr ) => {
89 Ok(DatabaseConnectionType::MockDatabaseConnection(Arc::new(
90 MockDatabaseConnection::new(MockDatabase::new($syntax)),
91 ))
92 .into())
93 };
94 }
95
96 #[cfg(feature = "sqlx-mysql")]
97 if crate::SqlxMySqlConnector::accepts(string) {
98 return connect_mock_db!(DbBackend::MySql);
99 }
100 #[cfg(feature = "sqlx-postgres")]
101 if crate::SqlxPostgresConnector::accepts(string) {
102 return connect_mock_db!(DbBackend::Postgres);
103 }
104 #[cfg(feature = "sqlx-sqlite")]
105 if crate::SqlxSqliteConnector::accepts(string) {
106 return connect_mock_db!(DbBackend::Sqlite);
107 }
108 #[cfg(feature = "rusqlite")]
109 if crate::driver::rusqlite::RusqliteConnector::accepts(string) {
110 return connect_mock_db!(DbBackend::Sqlite);
111 }
112 connect_mock_db!(DbBackend::Postgres)
113 }
114}
115
116impl MockDatabaseConnection {
117 pub fn new<M>(m: M) -> Self
119 where
120 M: MockDatabaseTrait,
121 M: 'static,
122 {
123 Self {
124 execute_counter: AtomicUsize::new(0),
125 query_counter: AtomicUsize::new(0),
126 mocker: Mutex::new(Box::new(m)),
127 }
128 }
129
130 pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
131 &self.mocker
132 }
133
134 pub fn get_database_backend(&self) -> DbBackend {
140 self.mocker
141 .lock()
142 .expect("Fail to acquire mocker")
143 .get_database_backend()
144 }
145
146 #[instrument(level = "trace", skip(statement))]
148 pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
149 debug_print!("{}", statement);
150 let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
151 self.mocker
152 .lock()
153 .map_err(exec_err)?
154 .execute(counter, statement)
155 }
156
157 #[instrument(level = "trace", skip(statement))]
159 pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
160 debug_print!("{}", statement);
161 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
162 let result = self
163 .mocker
164 .lock()
165 .map_err(query_err)?
166 .query(counter, statement)?;
167 Ok(result.into_iter().next())
168 }
169
170 #[instrument(level = "trace", skip(statement))]
172 pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
173 debug_print!("{}", statement);
174 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
175 self.mocker
176 .lock()
177 .map_err(query_err)?
178 .query(counter, statement)
179 }
180
181 #[instrument(level = "trace", skip(statement))]
183 pub fn fetch(&self, statement: &Statement) -> PinBoxStream {
184 #[cfg(not(feature = "sync"))]
185 {
186 match self.query_all(statement.clone()) {
187 Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
188 Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
189 }
190 }
191 #[cfg(feature = "sync")]
192 {
193 match self.query_all(statement.clone()) {
194 Ok(v) => Box::new(v.into_iter().map(Ok)),
195 Err(e) => Box::new(Some(Err(e)).into_iter()),
196 }
197 }
198 }
199
200 #[instrument(level = "trace")]
206 pub fn begin(&self) {
207 self.mocker
208 .lock()
209 .expect("Failed to acquire mocker")
210 .begin()
211 }
212
213 #[instrument(level = "trace")]
219 pub fn commit(&self) {
220 self.mocker
221 .lock()
222 .expect("Failed to acquire mocker")
223 .commit()
224 }
225
226 #[instrument(level = "trace")]
232 pub fn rollback(&self) {
233 self.mocker
234 .lock()
235 .expect("Failed to acquire mocker")
236 .rollback()
237 }
238
239 pub fn ping(&self) -> Result<(), DbErr> {
241 self.mocker.lock().map_err(query_err)?.ping()
242 }
243}
244
/// Builds a [`crate::QueryStream`] over a mock connection from a
/// `(connection, statement, metric callback)` triple.
#[cfg(feature = "stream")]
impl
    From<(
        Arc<crate::MockDatabaseConnection>,
        Statement,
        Option<crate::metric::Callback>,
    )> for crate::QueryStream
{
    fn from(
        parts: (
            Arc<crate::MockDatabaseConnection>,
            Statement,
            Option<crate::metric::Callback>,
        ),
    ) -> Self {
        let (conn, stmt, metric_callback) = parts;
        // Tag the shared mock connection as the stream's inner connection.
        let inner = crate::InnerConnection::Mock(conn);
        Self::build(stmt, inner, metric_callback)
    }
}
263
264impl crate::DatabaseTransaction {
265 pub(crate) async fn new_mock(
266 inner: Arc<crate::MockDatabaseConnection>,
267 metric_callback: Option<crate::metric::Callback>,
268 ) -> Result<crate::DatabaseTransaction, DbErr> {
269 use futures_util::lock::Mutex;
270 let backend = inner.get_database_backend();
271 Self::begin(
272 Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
273 backend,
274 metric_callback,
275 true,
276 None,
277 None,
278 None,
279 )
280 .await
281 }
282}