Skip to main content

wae_testing/
http_client.rs

1//! HTTP 测试客户端模块
2//!
3//! 提供用于测试 HTTP 服务的客户端工具,支持链式 API 风格调用。
4
5use bytes::Bytes;
6use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
7use serde::Serialize;
8use std::collections::HashMap;
9use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
10
11/// HTTP 测试客户端
12///
13/// 用于构建 HTTP 请求的测试工具。
14#[derive(Debug, Clone)]
15pub struct TestClient {
16    base_url: Option<String>,
17    headers: HeaderMap,
18    query_params: HashMap<String, String>,
19}
20
21impl TestClient {
22    /// 创建新的测试客户端
23    pub fn new() -> Self {
24        Self { base_url: None, headers: HeaderMap::new(), query_params: HashMap::new() }
25    }
26
27    /// 设置基础 URL
28    pub fn base_url(mut self, url: impl Into<String>) -> Self {
29        self.base_url = Some(url.into());
30        self
31    }
32
33    /// 添加请求头
34    pub fn header<K, V>(mut self, key: K, value: V) -> Self
35    where
36        K: TryInto<HeaderName>,
37        V: TryInto<HeaderValue>,
38    {
39        if let Ok(key) = key.try_into()
40            && let Ok(value) = value.try_into()
41        {
42            self.headers.append(key, value);
43        }
44        self
45    }
46
47    /// 设置多个请求头
48    pub fn headers(mut self, headers: HeaderMap) -> Self {
49        self.headers.extend(headers);
50        self
51    }
52
53    /// 添加查询参数
54    pub fn query_param<K, V>(mut self, key: K, value: V) -> Self
55    where
56        K: Into<String>,
57        V: Into<String>,
58    {
59        self.query_params.insert(key.into(), value.into());
60        self
61    }
62
63    /// 设置多个查询参数
64    pub fn query_params(mut self, params: HashMap<String, String>) -> Self {
65        self.query_params.extend(params);
66        self
67    }
68
69    /// 创建 GET 请求
70    pub fn get(self, url: impl Into<String>) -> RequestBuilder {
71        self.request(http::Method::GET, url)
72    }
73
74    /// 创建 POST 请求
75    pub fn post(self, url: impl Into<String>) -> RequestBuilder {
76        self.request(http::Method::POST, url)
77    }
78
79    /// 创建 PUT 请求
80    pub fn put(self, url: impl Into<String>) -> RequestBuilder {
81        self.request(http::Method::PUT, url)
82    }
83
84    /// 创建 DELETE 请求
85    pub fn delete(self, url: impl Into<String>) -> RequestBuilder {
86        self.request(http::Method::DELETE, url)
87    }
88
89    /// 创建 PATCH 请求
90    pub fn patch(self, url: impl Into<String>) -> RequestBuilder {
91        self.request(http::Method::PATCH, url)
92    }
93
94    /// 创建 HEAD 请求
95    pub fn head(self, url: impl Into<String>) -> RequestBuilder {
96        self.request(http::Method::HEAD, url)
97    }
98
99    /// 创建 OPTIONS 请求
100    pub fn options(self, url: impl Into<String>) -> RequestBuilder {
101        self.request(http::Method::OPTIONS, url)
102    }
103
104    /// 创建自定义 HTTP 方法的请求
105    pub fn request<M, U>(self, method: M, url: U) -> RequestBuilder
106    where
107        M: Into<http::Method>,
108        U: Into<String>,
109    {
110        RequestBuilder::new(method, url, self.base_url, self.headers, self.query_params)
111    }
112}
113
114impl Default for TestClient {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120/// 请求构建器
121///
122/// 用于构建 HTTP 请求的链式构建器。
123#[derive(Debug, Clone)]
124pub struct RequestBuilder {
125    method: http::Method,
126    url: String,
127    base_url: Option<String>,
128    headers: HeaderMap,
129    query_params: HashMap<String, String>,
130    body: Option<RequestBody>,
131}
132
133/// 请求体类型
134#[derive(Debug, Clone)]
135enum RequestBody {
136    /// JSON 请求体
137    Json(Bytes),
138    /// 文本请求体
139    Text(String),
140    /// 字节请求体
141    Bytes(Bytes),
142}
143
144impl RequestBuilder {
145    /// 创建新的请求构建器
146    fn new<M, U>(method: M, url: U, base_url: Option<String>, headers: HeaderMap, query_params: HashMap<String, String>) -> Self
147    where
148        M: Into<http::Method>,
149        U: Into<String>,
150    {
151        Self { method: method.into(), url: url.into(), base_url, headers, query_params, body: None }
152    }
153
154    /// 添加请求头
155    pub fn header<K, V>(mut self, key: K, value: V) -> Self
156    where
157        K: TryInto<HeaderName>,
158        V: TryInto<HeaderValue>,
159    {
160        if let Ok(key) = key.try_into()
161            && let Ok(value) = value.try_into()
162        {
163            self.headers.append(key, value);
164        }
165        self
166    }
167
168    /// 设置多个请求头
169    pub fn headers(mut self, headers: HeaderMap) -> Self {
170        self.headers.extend(headers);
171        self
172    }
173
174    /// 添加查询参数
175    pub fn query_param<K, V>(mut self, key: K, value: V) -> Self
176    where
177        K: Into<String>,
178        V: Into<String>,
179    {
180        self.query_params.insert(key.into(), value.into());
181        self
182    }
183
184    /// 设置多个查询参数
185    pub fn query_params(mut self, params: HashMap<String, String>) -> Self {
186        self.query_params.extend(params);
187        self
188    }
189
190    /// 设置 JSON 请求体
191    pub fn json<T: Serialize>(mut self, data: &T) -> TestingResult<Self> {
192        let bytes = serde_json::to_vec(data).map_err(|e| WaeError::new(WaeErrorKind::JsonError { reason: e.to_string() }))?;
193        self.body = Some(RequestBody::Json(Bytes::from(bytes)));
194        Ok(self)
195    }
196
197    /// 设置文本请求体
198    pub fn text(mut self, data: impl Into<String>) -> Self {
199        self.body = Some(RequestBody::Text(data.into()));
200        self
201    }
202
203    /// 设置字节请求体
204    pub fn bytes(mut self, data: impl Into<Bytes>) -> Self {
205        self.body = Some(RequestBody::Bytes(data.into()));
206        self
207    }
208
209    /// 构建完整 URL
210    fn build_url(&self) -> String {
211        let mut url_str = if let Some(ref base_url) = self.base_url {
212            if self.url.starts_with("http://") || self.url.starts_with("https://") {
213                self.url.clone()
214            }
215            else {
216                format!("{}{}", base_url, self.url)
217            }
218        }
219        else {
220            self.url.clone()
221        };
222
223        if !self.query_params.is_empty() {
224            let separator = if url_str.contains('?') { "&" } else { "?" };
225            let params: Vec<String> = self.query_params.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
226            url_str.push_str(separator);
227            url_str.push_str(&params.join("&"));
228        }
229
230        url_str
231    }
232
233    /// 获取 HTTP 方法
234    pub fn method(&self) -> &http::Method {
235        &self.method
236    }
237
238    /// 获取 URL
239    pub fn url(&self) -> String {
240        self.build_url()
241    }
242
243    /// 获取请求头
244    pub fn get_headers(&self) -> &HeaderMap {
245        &self.headers
246    }
247
248    /// 获取请求体
249    pub fn body(&self) -> Option<Bytes> {
250        match &self.body {
251            Some(RequestBody::Json(b)) => Some(b.clone()),
252            Some(RequestBody::Text(t)) => Some(Bytes::from(t.clone())),
253            Some(RequestBody::Bytes(b)) => Some(b.clone()),
254            None => None,
255        }
256    }
257
258    /// 创建测试响应(用于单元测试)
259    pub fn create_response(&self, status: StatusCode, headers: HeaderMap, body: Bytes) -> TestResponse {
260        TestResponse { status, headers, body, request_method: self.method.clone(), request_url: self.build_url() }
261    }
262}
263
264/// HTTP 响应
265///
266/// 包含 HTTP 响应的状态码、响应头和响应体。
267#[derive(Debug, Clone)]
268pub struct TestResponse {
269    /// 响应状态码
270    pub status: StatusCode,
271    /// 响应头
272    pub headers: HeaderMap,
273    /// 响应体
274    pub body: Bytes,
275    /// 请求方法(用于测试追踪)
276    request_method: http::Method,
277    /// 请求 URL(用于测试追踪)
278    request_url: String,
279}
280
281impl TestResponse {
282    /// 创建新的测试响应
283    pub fn new(status: StatusCode, headers: HeaderMap, body: Bytes) -> Self {
284        Self { status, headers, body, request_method: http::Method::GET, request_url: String::new() }
285    }
286
287    /// 获取响应状态码
288    pub fn status(&self) -> StatusCode {
289        self.status
290    }
291
292    /// 检查响应状态码是否成功 (200-299)
293    pub fn is_success(&self) -> bool {
294        self.status.is_success()
295    }
296
297    /// 获取响应头
298    pub fn headers(&self) -> &HeaderMap {
299        &self.headers
300    }
301
302    /// 获取指定响应头
303    pub fn header<K: TryInto<HeaderName>>(&self, key: K) -> Option<&HeaderValue> {
304        if let Ok(key) = key.try_into() { self.headers.get(key) } else { None }
305    }
306
307    /// 获取响应体字节
308    pub fn body(&self) -> &Bytes {
309        &self.body
310    }
311
312    /// 将响应体解析为字符串
313    pub fn text(&self) -> TestingResult<String> {
314        String::from_utf8(self.body.to_vec())
315            .map_err(|e| WaeError::new(WaeErrorKind::ParseError { type_name: "String".to_string(), reason: e.to_string() }))
316    }
317
318    /// 将响应体解析为 JSON
319    pub fn json<T: serde::de::DeserializeOwned>(&self) -> TestingResult<T> {
320        serde_json::from_slice(&self.body).map_err(|e| WaeError::new(WaeErrorKind::JsonError { reason: e.to_string() }))
321    }
322
323    /// 断言响应状态码
324    pub fn assert_status(&self, status: StatusCode) -> TestingResult<&Self> {
325        if self.status != status {
326            return Err(WaeError::new(WaeErrorKind::AssertionFailed {
327                message: format!("Expected status {}, got {}", status, self.status),
328            }));
329        }
330        Ok(self)
331    }
332
333    /// 断言响应是成功状态码 (200-299)
334    pub fn assert_success(&self) -> TestingResult<&Self> {
335        if !self.is_success() {
336            return Err(WaeError::new(WaeErrorKind::AssertionFailed {
337                message: format!("Expected success status, got {}", self.status),
338            }));
339        }
340        Ok(self)
341    }
342
343    /// 断言响应头存在
344    pub fn assert_header<K: TryInto<HeaderName>>(&self, key: K) -> TestingResult<&Self> {
345        if let Ok(key) = key.try_into()
346            && self.headers.contains_key(&key)
347        {
348            return Ok(self);
349        }
350        Err(WaeError::new(WaeErrorKind::AssertionFailed { message: "Expected header not found".to_string() }))
351    }
352
353    /// 断言响应头值
354    pub fn assert_header_eq<K, V>(&self, key: K, value: V) -> TestingResult<&Self>
355    where
356        K: TryInto<HeaderName>,
357        V: AsRef<str>,
358    {
359        if let Ok(key) = key.try_into()
360            && let Some(header_value) = self.headers.get(&key)
361            && let Ok(hv_str) = header_value.to_str()
362            && hv_str == value.as_ref()
363        {
364            return Ok(self);
365        }
366        Err(WaeError::new(WaeErrorKind::AssertionFailed { message: format!("Expected header value {}", value.as_ref()) }))
367    }
368
369    /// 断言响应体包含指定文本
370    pub fn assert_body_contains(&self, text: &str) -> TestingResult<&Self> {
371        let body_str = self.text()?;
372        if !body_str.contains(text) {
373            return Err(WaeError::new(WaeErrorKind::AssertionFailed {
374                message: "Expected text not found in response body".to_string(),
375            }));
376        }
377        Ok(self)
378    }
379
380    /// 获取请求方法(用于测试追踪)
381    pub fn request_method(&self) -> &http::Method {
382        &self.request_method
383    }
384
385    /// 获取请求 URL(用于测试追踪)
386    pub fn request_url(&self) -> &str {
387        &self.request_url
388    }
389}