1use crate::{
2 debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult,
3 Statement, Transaction,
4};
5use futures_util::Stream;
6use std::{
7 fmt::Debug,
8 pin::Pin,
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc, Mutex,
12 },
13};
14use tracing::instrument;
15
/// Connector that produces [`MockDatabaseConnection`]s instead of real
/// database connections.
///
/// The type carries no state, so it is cheap to copy and has a trivial
/// `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockDatabaseConnector;
19
20#[derive(Debug)]
22pub struct MockDatabaseConnection {
23 execute_counter: AtomicUsize,
24 query_counter: AtomicUsize,
25 mocker: Mutex<Box<dyn MockDatabaseTrait>>,
26}
27
28pub trait MockDatabaseTrait: Send + Debug {
30 fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;
32
33 fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
35
36 fn begin(&mut self);
38
39 fn commit(&mut self);
41
42 fn rollback(&mut self);
44
45 fn drain_transaction_log(&mut self) -> Vec<Transaction>;
47
48 fn get_database_backend(&self) -> DbBackend;
50
51 fn ping(&self) -> Result<(), DbErr>;
53}
54
55impl MockDatabaseConnector {
56 #[allow(unused_variables)]
58 pub fn accepts(string: &str) -> bool {
59 #[cfg(feature = "sqlx-mysql")]
60 if DbBackend::MySql.is_prefix_of(string) {
61 return true;
62 }
63 #[cfg(feature = "sqlx-postgres")]
64 if DbBackend::Postgres.is_prefix_of(string) {
65 return true;
66 }
67 #[cfg(feature = "sqlx-sqlite")]
68 if DbBackend::Sqlite.is_prefix_of(string) {
69 return true;
70 }
71 false
72 }
73
74 #[allow(unused_variables)]
76 #[instrument(level = "trace")]
77 pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
78 macro_rules! connect_mock_db {
79 ( $syntax: expr ) => {
80 Ok(DatabaseConnection::MockDatabaseConnection(Arc::new(
81 MockDatabaseConnection::new(MockDatabase::new($syntax)),
82 )))
83 };
84 }
85
86 #[cfg(feature = "sqlx-mysql")]
87 if crate::SqlxMySqlConnector::accepts(string) {
88 return connect_mock_db!(DbBackend::MySql);
89 }
90 #[cfg(feature = "sqlx-postgres")]
91 if crate::SqlxPostgresConnector::accepts(string) {
92 return connect_mock_db!(DbBackend::Postgres);
93 }
94 #[cfg(feature = "sqlx-sqlite")]
95 if crate::SqlxSqliteConnector::accepts(string) {
96 return connect_mock_db!(DbBackend::Sqlite);
97 }
98 connect_mock_db!(DbBackend::Postgres)
99 }
100}
101
102impl MockDatabaseConnection {
103 pub fn new<M: 'static>(m: M) -> Self
105 where
106 M: MockDatabaseTrait,
107 {
108 Self {
109 execute_counter: AtomicUsize::new(0),
110 query_counter: AtomicUsize::new(0),
111 mocker: Mutex::new(Box::new(m)),
112 }
113 }
114
115 pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
116 &self.mocker
117 }
118
119 pub fn get_database_backend(&self) -> DbBackend {
125 self.mocker
126 .lock()
127 .expect("Fail to acquire mocker")
128 .get_database_backend()
129 }
130
131 #[instrument(level = "trace")]
133 pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
134 debug_print!("{}", statement);
135 let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
136 self.mocker
137 .lock()
138 .map_err(exec_err)?
139 .execute(counter, statement)
140 }
141
142 #[instrument(level = "trace")]
144 pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
145 debug_print!("{}", statement);
146 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
147 let result = self
148 .mocker
149 .lock()
150 .map_err(query_err)?
151 .query(counter, statement)?;
152 Ok(result.into_iter().next())
153 }
154
155 #[instrument(level = "trace")]
157 pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
158 debug_print!("{}", statement);
159 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
160 self.mocker
161 .lock()
162 .map_err(query_err)?
163 .query(counter, statement)
164 }
165
166 #[instrument(level = "trace")]
168 pub fn fetch(
169 &self,
170 statement: &Statement,
171 ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
172 match self.query_all(statement.clone()) {
173 Ok(v) => Box::pin(futures_util::stream::iter(v.into_iter().map(Ok))),
174 Err(e) => Box::pin(futures_util::stream::iter(Some(Err(e)).into_iter())),
175 }
176 }
177
178 #[instrument(level = "trace")]
180 pub fn begin(&self) {
181 self.mocker
182 .lock()
183 .expect("Failed to acquire mocker")
184 .begin()
185 }
186
187 #[instrument(level = "trace")]
189 pub fn commit(&self) {
190 self.mocker
191 .lock()
192 .expect("Failed to acquire mocker")
193 .commit()
194 }
195
196 #[instrument(level = "trace")]
198 pub fn rollback(&self) {
199 self.mocker
200 .lock()
201 .expect("Failed to acquire mocker")
202 .rollback()
203 }
204
205 pub fn ping(&self) -> Result<(), DbErr> {
207 self.mocker.lock().map_err(query_err)?.ping()
208 }
209}
210
211impl
212 From<(
213 Arc<crate::MockDatabaseConnection>,
214 Statement,
215 Option<crate::metric::Callback>,
216 )> for crate::QueryStream
217{
218 fn from(
219 (conn, stmt, metric_callback): (
220 Arc<crate::MockDatabaseConnection>,
221 Statement,
222 Option<crate::metric::Callback>,
223 ),
224 ) -> Self {
225 crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
226 }
227}
228
229impl crate::DatabaseTransaction {
230 pub(crate) async fn new_mock(
231 inner: Arc<crate::MockDatabaseConnection>,
232 metric_callback: Option<crate::metric::Callback>,
233 ) -> Result<crate::DatabaseTransaction, DbErr> {
234 use futures_util::lock::Mutex;
235 let backend = inner.get_database_backend();
236 Self::begin(
237 Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
238 backend,
239 metric_callback,
240 None,
241 None,
242 )
243 .await
244 }
245}