1use crate::mock::{Mock, MockCall};
6use bytes::Bytes;
7use http::{HeaderMap, Method, StatusCode};
8use parking_lot::RwLock;
9use std::{collections::HashMap, sync::Arc};
10use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
11
12#[derive(Debug, Clone)]
14pub struct ServiceRequest {
15 pub method: Method,
17 pub path: String,
19 pub headers: HeaderMap,
21 pub body: Option<Bytes>,
23 pub timestamp: std::time::Instant,
25}
26
27#[derive(Debug, Clone)]
29pub struct ServiceResponse {
30 pub status: StatusCode,
32 pub headers: HeaderMap,
34 pub body: Bytes,
36}
37
38impl ServiceResponse {
39 pub fn ok(body: impl Into<Bytes>) -> Self {
41 Self { status: StatusCode::OK, headers: HeaderMap::new(), body: body.into() }
42 }
43
44 pub fn error(status: StatusCode, body: impl Into<Bytes>) -> Self {
46 Self { status, headers: HeaderMap::new(), body: body.into() }
47 }
48
49 pub fn header<K, V>(mut self, key: K, value: V) -> Self
51 where
52 K: TryInto<http::HeaderName>,
53 V: TryInto<http::HeaderValue>,
54 {
55 if let (Ok(key), Ok(value)) = (key.try_into(), value.try_into()) {
56 self.headers.append(key, value);
57 }
58 self
59 }
60}
61
62#[derive(Clone)]
64pub enum ServiceMatchRule {
65 Path(String),
67 PathAndMethod(String, Method),
69 Custom(Arc<dyn Fn(&ServiceRequest) -> bool + Send + Sync>),
71}
72
73impl std::fmt::Debug for ServiceMatchRule {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 ServiceMatchRule::Path(path) => f.debug_tuple("Path").field(path).finish(),
77 ServiceMatchRule::PathAndMethod(path, method) => f.debug_tuple("PathAndMethod").field(path).field(method).finish(),
78 ServiceMatchRule::Custom(_) => f.debug_tuple("Custom").field(&"<function>").finish(),
79 }
80 }
81}
82
83#[derive(Debug, Clone)]
85pub struct ServiceResponseConfig {
86 pub match_rule: ServiceMatchRule,
88 pub response: ServiceResponse,
90}
91
/// Expected per-path request counts, checked by `Mock::verify`.
#[derive(Debug, Default)]
pub struct ServiceExpectation {
    /// Exact number of requests expected for each path.
    pub expected_requests: HashMap<String, usize>,
    /// Optional free-form description of this expectation.
    pub description: Option<String>,
}

impl ServiceExpectation {
    /// Creates an expectation with no path counts and no description.
    pub fn new() -> Self {
        Self { expected_requests: HashMap::new(), description: None }
    }

    /// Requires exactly `count` requests to `path`.
    ///
    /// Calling this again with the same path replaces the earlier count.
    pub fn expect_request(mut self, path: impl Into<String>, count: usize) -> Self {
        let key: String = path.into();
        self.expected_requests.insert(key, count);
        self
    }

    /// Attaches a human-readable description, replacing any previous one.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        let text: String = desc.into();
        self.description = Some(text);
        self
    }
}
119
120pub struct MockExternalServiceBuilder {
122 responses: Vec<ServiceResponseConfig>,
123 expectation: ServiceExpectation,
124 requests: Arc<RwLock<Vec<ServiceRequest>>>,
125}
126
127impl MockExternalServiceBuilder {
128 pub fn new() -> Self {
130 Self { responses: Vec::new(), expectation: ServiceExpectation::default(), requests: Arc::new(RwLock::new(Vec::new())) }
131 }
132
133 pub fn respond_to_path(mut self, path: impl Into<String>, response: ServiceResponse) -> Self {
135 self.responses.push(ServiceResponseConfig { match_rule: ServiceMatchRule::Path(path.into()), response });
136 self
137 }
138
139 pub fn respond_to_path_and_method(mut self, path: impl Into<String>, method: Method, response: ServiceResponse) -> Self {
141 self.responses
142 .push(ServiceResponseConfig { match_rule: ServiceMatchRule::PathAndMethod(path.into(), method), response });
143 self
144 }
145
146 pub fn respond_with<F>(mut self, matcher: F, response: ServiceResponse) -> Self
148 where
149 F: Fn(&ServiceRequest) -> bool + Send + Sync + 'static,
150 {
151 self.responses.push(ServiceResponseConfig { match_rule: ServiceMatchRule::Custom(Arc::new(matcher)), response });
152 self
153 }
154
155 pub fn expect(mut self, expectation: ServiceExpectation) -> Self {
157 self.expectation = expectation;
158 self
159 }
160
161 pub fn build(self) -> MockExternalService {
163 MockExternalService { responses: self.responses, expectation: self.expectation, requests: self.requests }
164 }
165}
166
167impl Default for MockExternalServiceBuilder {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173pub struct MockExternalService {
177 responses: Vec<ServiceResponseConfig>,
178 expectation: ServiceExpectation,
179 requests: Arc<RwLock<Vec<ServiceRequest>>>,
180}
181
182impl MockExternalService {
183 pub fn handle_request(
185 &self,
186 method: Method,
187 path: String,
188 headers: HeaderMap,
189 body: Option<Bytes>,
190 ) -> TestingResult<ServiceResponse> {
191 let request = ServiceRequest {
192 method: method.clone(),
193 path: path.clone(),
194 headers: headers.clone(),
195 body: body.clone(),
196 timestamp: std::time::Instant::now(),
197 };
198
199 {
200 let mut requests = self.requests.write();
201 requests.push(request.clone());
202 }
203
204 for config in &self.responses {
205 if Self::matches(&request, &config.match_rule) {
206 return Ok(config.response.clone());
207 }
208 }
209
210 Err(WaeError::new(WaeErrorKind::MockError { reason: format!("No mock response configured for {} {}", method, path) }))
211 }
212
213 pub async fn handle_request_async(
215 &self,
216 method: Method,
217 path: String,
218 headers: HeaderMap,
219 body: Option<Bytes>,
220 ) -> TestingResult<ServiceResponse> {
221 self.handle_request(method, path, headers, body)
222 }
223
224 fn matches(request: &ServiceRequest, rule: &ServiceMatchRule) -> bool {
225 match rule {
226 ServiceMatchRule::Path(path) => request.path == *path,
227 ServiceMatchRule::PathAndMethod(path, method) => request.path == *path && request.method == *method,
228 ServiceMatchRule::Custom(matcher) => matcher(request),
229 }
230 }
231
232 pub fn requests(&self) -> Vec<ServiceRequest> {
234 self.requests.read().clone()
235 }
236
237 pub fn request_count(&self) -> usize {
239 self.requests.read().len()
240 }
241
242 pub fn request_count_by_path(&self, path: &str) -> usize {
244 self.requests.read().iter().filter(|r| r.path == path).count()
245 }
246}
247
248impl Mock for MockExternalService {
249 fn calls(&self) -> Vec<MockCall> {
250 self.requests
251 .read()
252 .iter()
253 .map(|r| MockCall { args: vec![r.method.to_string(), r.path.clone()], timestamp: r.timestamp })
254 .collect()
255 }
256
257 fn call_count(&self) -> usize {
258 self.request_count()
259 }
260
261 fn verify(&self) -> TestingResult<()> {
262 for (path, expected) in &self.expectation.expected_requests {
263 let actual = self.request_count_by_path(path);
264 if actual != *expected {
265 return Err(WaeError::new(WaeErrorKind::AssertionFailed {
266 message: format!("Expected {} requests for path '{}', but got {}", expected, path, actual),
267 }));
268 }
269 }
270
271 Ok(())
272 }
273
274 fn reset(&self) {
275 let mut requests = self.requests.write();
276 requests.clear();
277 }
278}