1use crate::{
2 DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult,
3 Statement, Transaction, debug_print, error::*,
4};
5use std::{
6 fmt::Debug,
7 pin::Pin,
8 sync::{
9 Arc, Mutex,
10 atomic::{AtomicUsize, Ordering},
11 },
12};
13use tracing::instrument;
14
15#[cfg(not(feature = "sync"))]
16type PinBoxStream = Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>;
17#[cfg(feature = "sync")]
18type PinBoxStream = Box<dyn Iterator<Item = Result<QueryResult, DbErr>>>;
19
/// Stateless entry point for building mock database connections.
///
/// It picks a mock backend from a connection string; see its `accepts` and
/// `connect` associated functions.
#[derive(Debug)]
pub struct MockDatabaseConnector;
23
24#[derive(Debug)]
26pub struct MockDatabaseConnection {
27 execute_counter: AtomicUsize,
28 query_counter: AtomicUsize,
29 mocker: Mutex<Box<dyn MockDatabaseTrait>>,
30}
31
32pub trait MockDatabaseTrait: Debug {
34 fn execute(&mut self, counter: usize, stmt: Statement) -> Result<ExecResult, DbErr>;
36
37 fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
39
40 fn begin(&mut self);
42
43 fn commit(&mut self);
45
46 fn rollback(&mut self);
48
49 fn drain_transaction_log(&mut self) -> Vec<Transaction>;
51
52 fn get_database_backend(&self) -> DbBackend;
54
55 fn ping(&self) -> Result<(), DbErr>;
57}
58
59impl MockDatabaseConnector {
60 #[allow(unused_variables)]
62 pub fn accepts(string: &str) -> bool {
63 #[cfg(feature = "sqlx-mysql")]
64 if DbBackend::MySql.is_prefix_of(string) {
65 return true;
66 }
67 #[cfg(feature = "sqlx-postgres")]
68 if DbBackend::Postgres.is_prefix_of(string) {
69 return true;
70 }
71 #[cfg(feature = "sqlx-sqlite")]
72 if DbBackend::Sqlite.is_prefix_of(string) {
73 return true;
74 }
75 #[cfg(feature = "rusqlite")]
76 if DbBackend::Sqlite.is_prefix_of(string) {
77 return true;
78 }
79 false
80 }
81
82 #[allow(unused_variables)]
84 #[instrument(level = "trace")]
85 pub fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
86 macro_rules! connect_mock_db {
87 ( $syntax: expr ) => {
88 Ok(DatabaseConnectionType::MockDatabaseConnection(Arc::new(
89 MockDatabaseConnection::new(MockDatabase::new($syntax)),
90 ))
91 .into())
92 };
93 }
94
95 #[cfg(feature = "sqlx-mysql")]
96 if crate::SqlxMySqlConnector::accepts(string) {
97 return connect_mock_db!(DbBackend::MySql);
98 }
99 #[cfg(feature = "sqlx-postgres")]
100 if crate::SqlxPostgresConnector::accepts(string) {
101 return connect_mock_db!(DbBackend::Postgres);
102 }
103 #[cfg(feature = "sqlx-sqlite")]
104 if crate::SqlxSqliteConnector::accepts(string) {
105 return connect_mock_db!(DbBackend::Sqlite);
106 }
107 #[cfg(feature = "rusqlite")]
108 if crate::driver::rusqlite::RusqliteConnector::accepts(string) {
109 return connect_mock_db!(DbBackend::Sqlite);
110 }
111 connect_mock_db!(DbBackend::Postgres)
112 }
113}
114
115impl MockDatabaseConnection {
116 pub fn new<M>(m: M) -> Self
118 where
119 M: MockDatabaseTrait,
120 M: 'static,
121 {
122 Self {
123 execute_counter: AtomicUsize::new(0),
124 query_counter: AtomicUsize::new(0),
125 mocker: Mutex::new(Box::new(m)),
126 }
127 }
128
129 pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
130 &self.mocker
131 }
132
133 pub fn get_database_backend(&self) -> DbBackend {
139 self.mocker
140 .lock()
141 .expect("Fail to acquire mocker")
142 .get_database_backend()
143 }
144
145 #[instrument(level = "trace")]
147 pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
148 debug_print!("{}", statement);
149 let counter = self.execute_counter.fetch_add(1, Ordering::SeqCst);
150 self.mocker
151 .lock()
152 .map_err(exec_err)?
153 .execute(counter, statement)
154 }
155
156 #[instrument(level = "trace")]
158 pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
159 debug_print!("{}", statement);
160 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
161 let result = self
162 .mocker
163 .lock()
164 .map_err(query_err)?
165 .query(counter, statement)?;
166 Ok(result.into_iter().next())
167 }
168
169 #[instrument(level = "trace")]
171 pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
172 debug_print!("{}", statement);
173 let counter = self.query_counter.fetch_add(1, Ordering::SeqCst);
174 self.mocker
175 .lock()
176 .map_err(query_err)?
177 .query(counter, statement)
178 }
179
180 #[instrument(level = "trace")]
182 pub fn fetch(&self, statement: &Statement) -> PinBoxStream {
183 #[cfg(not(feature = "sync"))]
184 {
185 match self.query_all(statement.clone()) {
186 Ok(v) => Box::new(futures_util::stream::iter(v.into_iter().map(Ok))),
187 Err(e) => Box::new(futures_util::stream::iter(Some(Err(e)).into_iter())),
188 }
189 }
190 #[cfg(feature = "sync")]
191 {
192 match self.query_all(statement.clone()) {
193 Ok(v) => Box::new(v.into_iter().map(Ok)),
194 Err(e) => Box::new(Some(Err(e)).into_iter()),
195 }
196 }
197 }
198
199 #[instrument(level = "trace")]
205 pub fn begin(&self) {
206 self.mocker
207 .lock()
208 .expect("Failed to acquire mocker")
209 .begin()
210 }
211
212 #[instrument(level = "trace")]
218 pub fn commit(&self) {
219 self.mocker
220 .lock()
221 .expect("Failed to acquire mocker")
222 .commit()
223 }
224
225 #[instrument(level = "trace")]
231 pub fn rollback(&self) {
232 self.mocker
233 .lock()
234 .expect("Failed to acquire mocker")
235 .rollback()
236 }
237
238 pub fn ping(&self) -> Result<(), DbErr> {
240 self.mocker.lock().map_err(query_err)?.ping()
241 }
242}
243
244impl
245 From<(
246 Arc<crate::MockDatabaseConnection>,
247 Statement,
248 Option<crate::metric::Callback>,
249 )> for crate::QueryStream
250{
251 fn from(
252 (conn, stmt, metric_callback): (
253 Arc<crate::MockDatabaseConnection>,
254 Statement,
255 Option<crate::metric::Callback>,
256 ),
257 ) -> Self {
258 crate::QueryStream::build(stmt, crate::InnerConnection::Mock(conn), metric_callback)
259 }
260}
261
262impl crate::DatabaseTransaction {
263 pub(crate) fn new_mock(
264 inner: Arc<crate::MockDatabaseConnection>,
265 metric_callback: Option<crate::metric::Callback>,
266 ) -> Result<crate::DatabaseTransaction, DbErr> {
267 use std::sync::Mutex;
268 let backend = inner.get_database_backend();
269 Self::begin(
270 Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))),
271 backend,
272 metric_callback,
273 None,
274 None,
275 )
276 }
277}