1use crate::{
2 debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult,
3 Statement, Transaction,
4};
5use futures_util::Stream;
6use std::{
7 fmt::Debug,
8 pin::Pin,
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc, Mutex,
12 },
13};
14use tracing::instrument;
15
/// Entry point for building mock [`DatabaseConnection`]s from a connection string.
///
/// Carries no state; see its associated `accepts` and `connect` functions.
#[derive(Debug)]
pub struct MockDatabaseConnector;
19
/// A connection that forwards every operation to a boxed [`MockDatabaseTrait`]
/// implementation instead of talking to a real database.
#[derive(Debug)]
pub struct MockDatabaseConnection {
    /// Source of the `counter` argument passed to [`MockDatabaseTrait::execute`];
    /// starts at 0 and is bumped once per `execute` call.
    execute_counter: AtomicUsize,
    /// Source of the `counter` argument passed to [`MockDatabaseTrait::query`];
    /// starts at 0 and is shared by `query_one`, `query_all` and `fetch`.
    query_counter: AtomicUsize,
    /// The mocker itself. It sits behind a `Mutex` because the trait's methods take
    /// `&mut self` while this connection's methods only take `&self`.
    mocker: Mutex<Box<dyn MockDatabaseTrait>>,
}
27
/// Behaviour plugged into a [`MockDatabaseConnection`]; every connection operation is
/// delegated to an implementation of this trait.
///
/// `Send` is what makes the connection's `Mutex<Box<dyn MockDatabaseTrait>>` shareable
/// between threads, and `Debug` is required because the connection derives `Debug`.
pub trait MockDatabaseTrait: Send + Debug {
    /// Handles an execute-style statement.
    ///
    /// `counter` comes from the connection's execute counter, which starts at 0 and
    /// increases by one per `MockDatabaseConnection::execute` call.
    fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;

    /// Handles a query statement and returns all resulting rows.
    ///
    /// `counter` comes from the connection's query counter, which starts at 0 and is
    /// shared by `query_one`, `query_all` and `fetch`.
    fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;

    /// Invoked by `MockDatabaseConnection::begin`; presumably marks the start of a
    /// transaction.
    fn begin(&mut self);

    /// Invoked by `MockDatabaseConnection::commit`; presumably commits the open
    /// transaction.
    fn commit(&mut self);

    /// Invoked by `MockDatabaseConnection::rollback`; presumably discards the open
    /// transaction.
    fn rollback(&mut self);

    /// Returns the transactions recorded so far.
    // NOTE(review): the name suggests the log is emptied as it is returned — confirm in
    // implementors.
    fn drain_transaction_log(&mut self) -> Vec<Transaction>;

    /// Returns the backend dialect this mocker emulates.
    fn get_database_backend(&self) -> DbBackend;

    /// Invoked by `MockDatabaseConnection::ping` to report connection health.
    fn ping(&self) -> Result<(), DbErr>;
}
54
55impl MockDatabaseConnector {
56 #[allow(unused_variables)]
58 pub fn accepts(string: &str) -> bool {
59 #[cfg(feature = "sqlx-mysql")]
60 if DbBackend::MySql.is_prefix_of(string) {
61 return true;
62 }
63 #[cfg(feature = "sqlx-postgres")]
64 if DbBackend::Postgres.is_prefix_of(string) {
65 return true;
66 }
67 #[cfg(feature = "sqlx-sqlite")]
68 if DbBackend::Sqlite.is_prefix_of(string) {
69 return true;
70 }
71 false
72 }
73
74 #[allow(unused_variables)]
76 #[instrument(level = "trace")]
77 pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
78 macro_rules! connect_mock_db {
79 ( $syntax: expr ) => {
80 Ok(DatabaseConnection::MockDatabaseConnection(Arc::new(
81 MockDatabaseConnection::new(MockDatabase::new($syntax)),
82 )))
83 };
84 }
85
86 #[cfg(feature = "sqlx-mysql")]
87 if crate::SqlxMySqlConnector::accepts(string) {
88 return connect_mock_db!(DbBackend::MySql);
89 }
90 #[cfg(feature = "sqlx-postgres")]
91 if crate::SqlxPostgresConnector::accepts(string) {
92 return connect_mock_db!(DbBackend::Postgres);
93 }
94 #[cfg(feature = "sqlx-sqlite")]
95 if crate::SqlxSqliteConnector::accepts(string) {
96 return connect_mock_db!(DbBackend::Sqlite);
97 }
98 connect_mock_db!(DbBackend::Postgres)
99 }
100}
101
102impl MockDatabaseConnection {
103 pub fn new<M>(m: M) -> Self
105 where
106 M: MockDatabaseTrait,
107 M: 'static,
108 {
109 Self {
110 execute_counter: AtomicUsize::new(0),
111 query_counter: AtomicUsize::new(0),
112 mocker: Mutex::new(Box::new(m)),
113 }
114 }
115
116 pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
117 &self.mocker
118 }
119
120 pub fn get_database_backend(&self) -> DbBackend {
126 self.mocker
127 .lock()
128 .expect("Fail to acquire mocker")
129 .get_database_backend()
130 }
131
132 #[instrument(level = "trace")]
134 pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
135 debug_print!("{}", statement);
136 let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
137 self.mocker
138 .lock()
139 .map_err(exec_err)?
140 .execute(counter, statement)
141 }
142
143 #[instrument(level = "trace")]
145 pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
146 debug_print!("{}", statement);
147 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
148 let result = self
149 .mocker
150 .lock()
151 .map_err(query_err)?
152 .query(counter, statement)?;
153 Ok(result.into_iter().next())
154 }
155
156 #[instrument(level = "trace")]
158 pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
159 debug_print!("{}", statement);
160 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
161 self.mocker
162 .lock()
163 .map_err(query_err)?
164 .query(counter, statement)
165 }
166
167 #[instrument(level = "trace")]
169 pub fn fetch(
170 &self,
171 statement: &Statement,
172 ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
173 match self.query_all(statement.clone()) {
174 Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
175 Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
176 }
177 }
178
179 #[instrument(level = "trace")]
185 pub fn begin(&self) {
186 self.mocker
187 .lock()
188 .expect("Failed to acquire mocker")
189 .begin()
190 }
191
192 #[instrument(level = "trace")]
198 pub fn commit(&self) {
199 self.mocker
200 .lock()
201 .expect("Failed to acquire mocker")
202 .commit()
203 }
204
205 #[instrument(level = "trace")]
211 pub fn rollback(&self) {
212 self.mocker
213 .lock()
214 .expect("Failed to acquire mocker")
215 .rollback()
216 }
217
218 pub fn ping(&self) -> Result<(), DbErr> {
220 self.mocker.lock().map_err(query_err)?.ping()
221 }
222}
223
224impl
225 From<(
226 Arc<crate::MockDatabaseConnection>,
227 Statement,
228 Option<crate::metric::Callback>,
229 )> for crate::QueryStream
230{
231 fn from(
232 (conn, stmt, metric_callback): (
233 Arc<crate::MockDatabaseConnection>,
234 Statement,
235 Option<crate::metric::Callback>,
236 ),
237 ) -> Self {
238 crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
239 }
240}
241
242impl crate::DatabaseTransaction {
243 pub(crate) async fn new_mock(
244 inner: Arc<crate::MockDatabaseConnection>,
245 metric_callback: Option<crate::metric::Callback>,
246 ) -> Result<crate::DatabaseTransaction, DbErr> {
247 use futures_util::lock::Mutex;
248 let backend = inner.get_database_backend();
249 Self::begin(
250 Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
251 backend,
252 metric_callback,
253 None,
254 None,
255 )
256 .await
257 }
258}