1use parking_lot::RwLock;
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10
/// An HTTP response: status code, header map, and raw body bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpResponse {
    /// Numeric HTTP status code (for example `200` or `404`).
    pub status: u16,
    /// Response headers. Keys are stored exactly as given; no case
    /// normalization is performed, so lookups are case-sensitive.
    pub headers: HashMap<String, String>,
    /// Raw body bytes; `body_string` and `body_json` provide decoded views.
    pub body: Vec<u8>,
}
21
22impl HttpResponse {
23 pub fn new(status: u16, body: impl Into<Vec<u8>>) -> Self {
25 Self {
26 status,
27 headers: HashMap::new(),
28 body: body.into(),
29 }
30 }
31
32 pub fn json(status: u16, value: &serde_json::Value) -> Self {
34 let body = serde_json::to_vec(value).expect("Failed to serialize JSON");
35 let mut headers = HashMap::new();
36 headers.insert("content-type".to_string(), "application/json".to_string());
37 Self {
38 status,
39 headers,
40 body,
41 }
42 }
43
44 pub fn body_string(&self) -> String {
46 String::from_utf8_lossy(&self.body).into_owned()
47 }
48
49 pub fn body_json(&self) -> Result<serde_json::Value, serde_json::Error> {
51 serde_json::from_slice(&self.body)
52 }
53
54 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
56 self.headers.insert(key.into(), value.into());
57 self
58 }
59}
60
/// A captured outgoing request; `MockHttp` records one for every call to
/// `request`, whether or not a rule matched it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpRequest {
    /// Method exactly as passed by the caller (not case-normalized).
    pub method: String,
    /// Full request URL as passed by the caller.
    pub url: String,
    /// Request headers as passed by the caller.
    pub headers: HashMap<String, String>,
    /// Request body; empty when the caller passed `None`.
    pub body: Vec<u8>,
}
73
/// Errors returned by [`HttpProvider::request`] implementations.
#[derive(Debug, Clone, thiserror::Error)]
pub enum HttpError {
    /// A connection could not be established; the payload is interpolated
    /// into the display message.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// The request timed out.
    #[error("Request timed out")]
    Timeout,
    /// Returned by `MockHttp` when no rule matched the request and
    /// `fail_on_unmatched` is enabled.
    #[error("No mock rule matched for {method} {url}")]
    NoMockMatch { method: String, url: String },
    /// Catch-all; the payload is the entire display message. `RealHttp`
    /// uses this for its "not implemented" error.
    #[error("{0}")]
    Other(String),
}
90
/// Asynchronous HTTP client abstraction, implemented in this module by
/// [`RealHttp`] and [`MockHttp`].
pub trait HttpProvider: Send + Sync {
    /// Sends a `method` request to `url` with `headers` and an optional `body`.
    ///
    /// The returned future is boxed and `Send`. Its `'_` lifetime is tied to
    /// `self` only (via lifetime elision), so it may borrow `self` but not
    /// `method` or `url`; implementations must copy whatever they need from
    /// those before constructing the future.
    fn request(
        &self,
        method: &str,
        url: &str,
        headers: HashMap<String, String>,
        body: Option<Vec<u8>>,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<HttpResponse, HttpError>> + Send + '_>,
    >;

    /// Returns `true` for test doubles (`MockHttp` returns `true`, `RealHttp`
    /// returns `false`).
    fn is_mock(&self) -> bool;
}
107
/// Provider intended for real network access.
///
/// It currently only carries a request timeout; its `HttpProvider`
/// implementation does not perform any I/O yet.
#[derive(Debug, Clone)]
pub struct RealHttp {
    timeout: Duration,
}

impl RealHttp {
    /// Timeout used by [`RealHttp::new`] and `Default`: 30 seconds.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a provider with the default 30-second timeout.
    pub fn new() -> Self {
        Self::with_timeout(Self::DEFAULT_TIMEOUT)
    }

    /// Creates a provider with a caller-chosen timeout.
    pub fn with_timeout(timeout: Duration) -> Self {
        RealHttp { timeout }
    }

    /// Returns the configured timeout.
    pub fn timeout(&self) -> Duration {
        let RealHttp { timeout } = self;
        *timeout
    }
}

impl Default for RealHttp {
    /// Same as [`RealHttp::new`].
    fn default() -> Self {
        RealHttp::with_timeout(RealHttp::DEFAULT_TIMEOUT)
    }
}
139
140impl HttpProvider for RealHttp {
141 fn request(
142 &self,
143 _method: &str,
144 _url: &str,
145 _headers: HashMap<String, String>,
146 _body: Option<Vec<u8>>,
147 ) -> std::pin::Pin<
148 Box<dyn std::future::Future<Output = Result<HttpResponse, HttpError>> + Send + '_>,
149 > {
150 let timeout = self.timeout;
151 Box::pin(async move {
154 let _ = timeout; Err(HttpError::Other(
158 "Real HTTP not implemented - use MockHttp for testing".to_string(),
159 ))
160 })
161 }
162
163 fn is_mock(&self) -> bool {
164 false
165 }
166}
167
/// A single canned-response rule consulted by `MockHttp`.
#[derive(Clone)]
pub struct MockHttpRule {
    /// Method to match, or `None` to match any method.
    /// `MockHttpRule::with_method` stores it upper-cased.
    pub method: Option<String>,
    /// Regex tested against the request URL with `is_match`, so it matches
    /// anywhere in the URL unless anchored with `^` / `$`.
    pub url_pattern: Regex,
    /// Response cloned and returned each time the rule matches.
    pub response: HttpResponse,
    /// Optional delay awaited with `tokio::time::sleep` before responding.
    pub latency: Option<Duration>,
    /// Maximum number of matches; `None` means unlimited. Once the limit is
    /// reached the rule is skipped and later rules are tried.
    pub times: Option<usize>,
    /// How many times this rule has matched so far.
    matched_count: usize,
}
184
185impl MockHttpRule {
186 pub fn new(url_pattern: &str, response: HttpResponse) -> Self {
188 Self {
189 method: None,
190 url_pattern: Regex::new(url_pattern).expect("Invalid URL regex pattern"),
191 response,
192 latency: None,
193 times: None,
194 matched_count: 0,
195 }
196 }
197
198 pub fn with_method(mut self, method: &str) -> Self {
200 self.method = Some(method.to_uppercase());
201 self
202 }
203
204 pub fn with_latency(mut self, latency: Duration) -> Self {
206 self.latency = Some(latency);
207 self
208 }
209
210 pub fn times(mut self, n: usize) -> Self {
212 self.times = Some(n);
213 self
214 }
215
216 fn matches(&self, method: &str, url: &str) -> bool {
218 if let Some(ref expected_method) = self.method {
220 if expected_method != method.to_uppercase().as_str() {
221 return false;
222 }
223 }
224
225 if let Some(limit) = self.times {
227 if self.matched_count >= limit {
228 return false;
229 }
230 }
231
232 self.url_pattern.is_match(url)
234 }
235}
236
237impl std::fmt::Debug for MockHttpRule {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("MockHttpRule")
240 .field("method", &self.method)
241 .field("url_pattern", &self.url_pattern.as_str())
242 .field("response_status", &self.response.status)
243 .field("latency", &self.latency)
244 .field("times", &self.times)
245 .finish()
246 }
247}
248
249pub struct MockHttp {
270 rules: RwLock<Vec<MockHttpRule>>,
271 requests: RwLock<Vec<HttpRequest>>,
272 fail_on_unmatched: bool,
273}
274
275impl MockHttp {
276 pub fn new() -> Self {
278 Self {
279 rules: RwLock::new(Vec::new()),
280 requests: RwLock::new(Vec::new()),
281 fail_on_unmatched: true,
282 }
283 }
284
285 pub fn rule(self, rule: MockHttpRule) -> Self {
287 self.rules.write().push(rule);
288 self
289 }
290
291 pub fn fail_on_unmatched(mut self, fail: bool) -> Self {
293 self.fail_on_unmatched = fail;
294 self
295 }
296
297 pub fn on_get(self, url_pattern: &str) -> MockHttpBuilder {
299 MockHttpBuilder {
300 mock: self,
301 method: Some("GET".to_string()),
302 url_pattern: url_pattern.to_string(),
303 latency: None,
304 times: None,
305 }
306 }
307
308 pub fn on_post(self, url_pattern: &str) -> MockHttpBuilder {
310 MockHttpBuilder {
311 mock: self,
312 method: Some("POST".to_string()),
313 url_pattern: url_pattern.to_string(),
314 latency: None,
315 times: None,
316 }
317 }
318
319 pub fn on_put(self, url_pattern: &str) -> MockHttpBuilder {
321 MockHttpBuilder {
322 mock: self,
323 method: Some("PUT".to_string()),
324 url_pattern: url_pattern.to_string(),
325 latency: None,
326 times: None,
327 }
328 }
329
330 pub fn on_delete(self, url_pattern: &str) -> MockHttpBuilder {
332 MockHttpBuilder {
333 mock: self,
334 method: Some("DELETE".to_string()),
335 url_pattern: url_pattern.to_string(),
336 latency: None,
337 times: None,
338 }
339 }
340
341 pub fn on_any(self, url_pattern: &str) -> MockHttpBuilder {
343 MockHttpBuilder {
344 mock: self,
345 method: None,
346 url_pattern: url_pattern.to_string(),
347 latency: None,
348 times: None,
349 }
350 }
351
352 pub fn requests(&self) -> Vec<HttpRequest> {
354 self.requests.read().clone()
355 }
356
357 pub fn clear_requests(&self) {
359 self.requests.write().clear();
360 }
361
362 pub fn assert_request_made(&self, method: &str, url_pattern: &str) -> bool {
364 let re = Regex::new(url_pattern).expect("Invalid URL pattern");
365 let requests = self.requests.read();
366 requests
367 .iter()
368 .any(|r| r.method.eq_ignore_ascii_case(method) && re.is_match(&r.url))
369 }
370
371 pub fn request_count(&self) -> usize {
373 self.requests.read().len()
374 }
375}
376
377impl Default for MockHttp {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
383impl HttpProvider for MockHttp {
384 fn request(
385 &self,
386 method: &str,
387 url: &str,
388 headers: HashMap<String, String>,
389 body: Option<Vec<u8>>,
390 ) -> std::pin::Pin<
391 Box<dyn std::future::Future<Output = Result<HttpResponse, HttpError>> + Send + '_>,
392 > {
393 self.requests.write().push(HttpRequest {
395 method: method.to_string(),
396 url: url.to_string(),
397 headers: headers.clone(),
398 body: body.clone().unwrap_or_default(),
399 });
400
401 let mut rules = self.rules.write();
403 let matched = rules.iter_mut().find(|rule| rule.matches(method, url));
404
405 match matched {
406 Some(rule) => {
407 rule.matched_count += 1;
408 let response = rule.response.clone();
409 let latency = rule.latency;
410
411 Box::pin(async move {
412 if let Some(delay) = latency {
413 tokio::time::sleep(delay).await;
414 }
415 Ok(response)
416 })
417 }
418 None => {
419 if self.fail_on_unmatched {
420 let method = method.to_string();
421 let url = url.to_string();
422 Box::pin(async move { Err(HttpError::NoMockMatch { method, url }) })
423 } else {
424 Box::pin(async move { Ok(HttpResponse::new(404, b"Not Found".to_vec())) })
426 }
427 }
428 }
429 }
430
431 fn is_mock(&self) -> bool {
432 true
433 }
434}
435
436pub struct MockHttpBuilder {
438 mock: MockHttp,
439 method: Option<String>,
440 url_pattern: String,
441 latency: Option<Duration>,
442 times: Option<usize>,
443}
444
445impl MockHttpBuilder {
446 pub fn with_latency(mut self, latency: Duration) -> Self {
448 self.latency = latency.into();
449 self
450 }
451
452 pub fn times(mut self, n: usize) -> Self {
454 self.times = Some(n);
455 self
456 }
457
458 pub fn respond(self, response: HttpResponse) -> MockHttp {
460 let mut rule = MockHttpRule::new(&self.url_pattern, response);
461 rule.method = self.method;
462 rule.latency = self.latency;
463 rule.times = self.times;
464 self.mock.rule(rule)
465 }
466
467 pub fn respond_json(self, status: u16, value: serde_json::Value) -> MockHttp {
469 self.respond(HttpResponse::json(status, &value))
470 }
471
472 pub fn respond_text(self, status: u16, text: &str) -> MockHttp {
474 let mut response = HttpResponse::new(status, text.as_bytes().to_vec());
475 response
476 .headers
477 .insert("content-type".to_string(), "text/plain".to_string());
478 self.respond(response)
479 }
480
481 pub fn respond_error(self, status: u16, message: &str) -> MockHttp {
483 self.respond_json(status, serde_json::json!({"error": message}))
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn mock_http_matches_get() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users/\d+$")
            .respond_json(200, json!({"name": "Alice"}));

        let response = mock
            .request(
                "GET",
                "https://api.example.com/users/123",
                HashMap::new(),
                None,
            )
            .await
            .unwrap();

        assert_eq!(response.status, 200);
        let body: serde_json::Value = response.body_json().unwrap();
        assert_eq!(body["name"], "Alice");
    }

    #[tokio::test]
    async fn mock_http_matches_post() {
        let mock = MockHttp::new()
            .on_post(r"^https://api\.example\.com/users$")
            .respond_json(201, json!({"id": 42}));

        let response = mock
            .request(
                "POST",
                "https://api.example.com/users",
                HashMap::new(),
                Some(b"{}".to_vec()),
            )
            .await
            .unwrap();

        assert_eq!(response.status, 201);
    }

    #[tokio::test]
    async fn mock_http_fails_on_unmatched() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .respond_json(200, json!([]));

        let result = mock
            .request("GET", "https://api.example.com/other", HashMap::new(), None)
            .await;

        assert!(matches!(result, Err(HttpError::NoMockMatch { .. })));
    }

    // With `fail_on_unmatched(false)` a miss becomes a 404 instead of an error,
    // and the request is still recorded.
    #[tokio::test]
    async fn mock_http_returns_404_when_not_failing_on_unmatched() {
        let mock = MockHttp::new().fail_on_unmatched(false);

        let response = mock
            .request("GET", "https://api.example.com/missing", HashMap::new(), None)
            .await
            .unwrap();

        assert_eq!(response.status, 404);
        assert_eq!(response.body_string(), "Not Found");
        assert_eq!(mock.request_count(), 1);
    }

    #[tokio::test]
    async fn mock_http_rejects_wrong_method() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .respond_json(200, json!([]));

        let result = mock
            .request("POST", "https://api.example.com/users", HashMap::new(), None)
            .await;

        assert!(matches!(result, Err(HttpError::NoMockMatch { .. })));
    }

    #[tokio::test]
    async fn mock_http_method_match_ignores_case() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .respond_json(200, json!([]));

        let response = mock
            .request("get", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap();

        assert_eq!(response.status, 200);
    }

    #[tokio::test]
    async fn mock_http_records_requests() {
        let mock = MockHttp::new().on_get(r".*").respond_json(200, json!({}));

        mock.request("GET", "https://example.com/a", HashMap::new(), None)
            .await
            .unwrap();
        mock.request("GET", "https://example.com/b", HashMap::new(), None)
            .await
            .unwrap();

        assert_eq!(mock.request_count(), 2);
        assert!(mock.assert_request_made("GET", r"example\.com/a"));
        assert!(mock.assert_request_made("GET", r"example\.com/b"));
    }

    #[tokio::test]
    async fn mock_http_records_headers_and_body() {
        let mock = MockHttp::new().on_post(r".*").respond_text(204, "");
        let mut headers = HashMap::new();
        headers.insert("x-trace".to_string(), "abc".to_string());

        mock.request(
            "POST",
            "https://example.com/upload",
            headers,
            Some(b"payload".to_vec()),
        )
        .await
        .unwrap();

        let recorded = mock.requests();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].method, "POST");
        assert_eq!(
            recorded[0].headers.get("x-trace").map(String::as_str),
            Some("abc")
        );
        assert_eq!(recorded[0].body, b"payload".to_vec());
    }

    #[tokio::test]
    async fn mock_http_respond_text_sets_content_type() {
        let mock = MockHttp::new().on_any(r"/health$").respond_text(200, "ok");

        let response = mock
            .request("HEAD", "https://example.com/health", HashMap::new(), None)
            .await
            .unwrap();

        assert_eq!(response.body_string(), "ok");
        assert_eq!(
            response.headers.get("content-type").map(String::as_str),
            Some("text/plain")
        );
    }

    #[tokio::test]
    async fn mock_http_times_limit() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .times(2)
            .respond_json(200, json!([]));

        mock.request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap();
        mock.request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap();

        let result = mock
            .request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await;
        assert!(matches!(result, Err(HttpError::NoMockMatch { .. })));
    }

    // Once a limited rule is exhausted, the next matching rule takes over.
    #[tokio::test]
    async fn mock_http_falls_through_exhausted_rule() {
        let mock = MockHttp::new()
            .on_get(r"/users$")
            .times(1)
            .respond_json(200, json!({"page": 1}))
            .on_get(r"/users$")
            .respond_json(200, json!({"page": 2}));

        let first: serde_json::Value = mock
            .request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap()
            .body_json()
            .unwrap();
        let second: serde_json::Value = mock
            .request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap()
            .body_json()
            .unwrap();

        assert_eq!(first["page"], 1);
        assert_eq!(second["page"], 2);
    }
}