1use crate::mock::{Mock, MockCall};
6use parking_lot::RwLock;
7use std::{collections::HashMap, sync::Arc};
8use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
9
/// One call recorded by [`MockDatabase`]: which query ran, with what
/// parameters, and when it was logged.
#[derive(Clone, Debug)]
pub struct DatabaseQuery {
    /// Name the caller passed to the query method.
    pub query_name: String,
    /// Positional parameters supplied with the call, in the caller's order.
    pub params: Vec<String>,
    /// Moment the call was appended to the mock's query log.
    pub timestamp: std::time::Instant,
}
20
/// Canned response a [`MockDatabase`] hands back when it is queried.
#[derive(Debug, Clone)]
pub enum DatabaseResult<T> {
    /// Every query succeeds with a clone of this value.
    Return(T),
    /// Every query fails with a mock error carrying this message.
    Error(String),
    /// Each query is answered by looking its name up in this map; names that
    /// are missing from the map fail with a mock error.
    NamedReturns(HashMap<String, T>),
}

// The constructors only move their argument into a variant, so they need no
// `T: Clone` bound; cloning is only required where a result is handed out.
impl<T> DatabaseResult<T> {
    /// Builds [`DatabaseResult::Return`] wrapping `value`.
    pub fn return_value(value: T) -> Self {
        DatabaseResult::Return(value)
    }

    /// Builds [`DatabaseResult::Error`] carrying `msg` as the failure reason.
    pub fn error(msg: impl Into<String>) -> Self {
        DatabaseResult::Error(msg.into())
    }

    /// Builds [`DatabaseResult::NamedReturns`] from a query-name → value map.
    pub fn named_returns(map: HashMap<String, T>) -> Self {
        DatabaseResult::NamedReturns(map)
    }
}
48
/// Expected per-query call counts for a [`MockDatabase`], checked by its
/// `Mock::verify` implementation.
#[derive(Debug, Default)]
pub struct DatabaseExpectation {
    /// Exact number of calls expected, keyed by query name.
    pub expected_queries: HashMap<String, usize>,
    /// Optional free-form description of this set of expectations.
    pub description: Option<String>,
}

impl DatabaseExpectation {
    /// Creates an expectation set with no queries and no description.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Expects `query_name` to be called exactly `count` times.
    ///
    /// Calling this again for the same name replaces the earlier count.
    // `#[must_use]`: the builder is consumed, so ignoring the result silently
    // throws the configuration away.
    #[must_use]
    pub fn expect_query(mut self, query_name: impl Into<String>, count: usize) -> Self {
        self.expected_queries.insert(query_name.into(), count);
        self
    }

    /// Sets (or replaces) the description of this expectation set.
    #[must_use]
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }
}
76
77pub struct MockDatabaseBuilder<T> {
79 result: Option<DatabaseResult<T>>,
80 expectation: DatabaseExpectation,
81 queries: Arc<RwLock<Vec<DatabaseQuery>>>,
82}
83
84impl<T: Clone + Send + Sync + 'static> MockDatabaseBuilder<T> {
85 pub fn new() -> Self {
87 Self { result: None, expectation: DatabaseExpectation::default(), queries: Arc::new(RwLock::new(Vec::new())) }
88 }
89
90 pub fn return_value(mut self, value: T) -> Self {
92 self.result = Some(DatabaseResult::return_value(value));
93 self
94 }
95
96 pub fn error(mut self, msg: impl Into<String>) -> Self {
98 self.result = Some(DatabaseResult::error(msg));
99 self
100 }
101
102 pub fn named_returns(mut self, map: HashMap<String, T>) -> Self {
104 self.result = Some(DatabaseResult::named_returns(map));
105 self
106 }
107
108 pub fn expect(mut self, expectation: DatabaseExpectation) -> Self {
110 self.expectation = expectation;
111 self
112 }
113
114 pub fn build(self) -> MockDatabase<T> {
116 MockDatabase { result: self.result, expectation: self.expectation, queries: self.queries }
117 }
118}
119
120impl<T: Clone + Send + Sync + 'static> Default for MockDatabaseBuilder<T> {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126pub struct MockDatabase<T> {
130 result: Option<DatabaseResult<T>>,
131 expectation: DatabaseExpectation,
132 queries: Arc<RwLock<Vec<DatabaseQuery>>>,
133}
134
135impl<T: Clone> MockDatabase<T> {
136 pub fn query(&self, query_name: impl Into<String>, params: Vec<String>) -> TestingResult<T> {
138 let query_name = query_name.into();
139 {
140 let mut queries = self.queries.write();
141 queries.push(DatabaseQuery { query_name: query_name.clone(), params, timestamp: std::time::Instant::now() });
142 }
143
144 match &self.result {
145 Some(DatabaseResult::Return(v)) => Ok(v.clone()),
146 Some(DatabaseResult::Error(e)) => Err(WaeError::new(WaeErrorKind::MockError { reason: e.clone() })),
147 Some(DatabaseResult::NamedReturns(map)) => {
148 if let Some(value) = map.get(&query_name) {
149 Ok(value.clone())
150 }
151 else {
152 Err(WaeError::new(WaeErrorKind::MockError { reason: format!("No mock result for query: {}", query_name) }))
153 }
154 }
155 None => Err(WaeError::new(WaeErrorKind::MockError { reason: "No mock result configured".to_string() })),
156 }
157 }
158
159 pub async fn query_async(&self, query_name: impl Into<String>, params: Vec<String>) -> TestingResult<T> {
161 self.query(query_name, params)
162 }
163
164 pub fn queries(&self) -> Vec<DatabaseQuery> {
166 self.queries.read().clone()
167 }
168
169 pub fn query_count(&self) -> usize {
171 self.queries.read().len()
172 }
173
174 pub fn query_count_by_name(&self, query_name: &str) -> usize {
176 self.queries.read().iter().filter(|q| q.query_name == query_name).count()
177 }
178}
179
180impl<T: Clone + Send + Sync + 'static> Mock for MockDatabase<T> {
181 fn calls(&self) -> Vec<MockCall> {
182 self.queries
183 .read()
184 .iter()
185 .map(|q| MockCall {
186 args: vec![q.query_name.clone()].into_iter().chain(q.params.clone()).collect(),
187 timestamp: q.timestamp,
188 })
189 .collect()
190 }
191
192 fn call_count(&self) -> usize {
193 self.query_count()
194 }
195
196 fn verify(&self) -> TestingResult<()> {
197 for (query_name, expected) in &self.expectation.expected_queries {
198 let actual = self.query_count_by_name(query_name);
199 if actual != *expected {
200 return Err(WaeError::new(WaeErrorKind::AssertionFailed {
201 message: format!("Expected {} calls for query '{}', but got {}", expected, query_name, actual),
202 }));
203 }
204 }
205
206 Ok(())
207 }
208
209 fn reset(&self) {
210 let mut queries = self.queries.write();
211 queries.clear();
212 }
213}