ztk_rust_sdk/common/
http.rs

1//! HTTP 请求封装
2//!
3//! 封装 GET/POST 请求方法,处理 URL 编码和响应解析
4
5use reqwest::Client;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use std::collections::HashMap;
9use std::time::Duration;
10
11use crate::error::{ZtkError, ZtkResult};
12
/// Per-request timeout in seconds, used by `HttpClient::new` when no timeout is supplied.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Primary Zhetaoke API endpoint (HTTPS, port 10001); used by `HttpClient::with_defaults`.
pub const DEFAULT_BASE_URL: &str = "https://api.zhetaoke.com:10001";

/// Secondary Zhetaoke API endpoint (plain HTTP, port 10000).
pub const BACKUP_BASE_URL: &str = "http://api.zhetaoke.cn:10000";
21
22/// HTTP 客户端封装
23///
24/// 提供统一的 HTTP 请求方法,处理 URL 编码和响应解析
25#[derive(Debug, Clone)]
26pub struct HttpClient {
27    /// reqwest 客户端
28    client: Client,
29    /// API 基础地址
30    base_url: String,
31}
32
33impl HttpClient {
34    /// 创建新的 HTTP 客户端
35    ///
36    /// # Arguments
37    ///
38    /// * `base_url` - API 基础地址
39    /// * `timeout` - 请求超时时间 (可选)
40    ///
41    /// # Returns
42    ///
43    /// 返回 HttpClient 实例或错误
44    pub fn new(base_url: impl Into<String>, timeout: Option<Duration>) -> ZtkResult<Self> {
45        let timeout = timeout.unwrap_or(Duration::from_secs(DEFAULT_TIMEOUT_SECS));
46
47        let client = Client::builder()
48            .timeout(timeout)
49            .build()
50            .map_err(ZtkError::Network)?;
51
52        Ok(Self {
53            client,
54            base_url: base_url.into(),
55        })
56    }
57
58    /// 使用默认配置创建 HTTP 客户端
59    ///
60    /// 使用默认的 API 基础地址和超时时间
61    pub fn with_defaults() -> ZtkResult<Self> {
62        Self::new(DEFAULT_BASE_URL, None)
63    }
64
65    /// 获取基础 URL
66    pub fn base_url(&self) -> &str {
67        &self.base_url
68    }
69
70    /// 发送 GET 请求
71    ///
72    /// # Arguments
73    ///
74    /// * `path` - API 路径
75    /// * `params` - 查询参数
76    ///
77    /// # Returns
78    ///
79    /// 返回反序列化后的响应数据或错误
80    pub async fn get<T, P>(&self, path: &str, params: &P) -> ZtkResult<T>
81    where
82        T: DeserializeOwned,
83        P: Serialize + ?Sized,
84    {
85        self.get_with_base_url(&self.base_url, path, params).await
86    }
87
88    /// 发送 GET 请求 (使用自定义 base_url)
89    ///
90    /// # Arguments
91    ///
92    /// * `base_url` - 自定义的 API 基础地址
93    /// * `path` - API 路径
94    /// * `params` - 查询参数
95    ///
96    /// # Returns
97    ///
98    /// 返回反序列化后的响应数据或错误
99    pub async fn get_with_base_url<T, P>(
100        &self,
101        base_url: &str,
102        path: &str,
103        params: &P,
104    ) -> ZtkResult<T>
105    where
106        T: DeserializeOwned,
107        P: Serialize + ?Sized,
108    {
109        let url = format!("{}{}", base_url, path);
110
111        // 将参数序列化为查询字符串
112        let query_string = self.serialize_params(params)?;
113        let full_url = if query_string.is_empty() {
114            url
115        } else {
116            format!("{}?{}", url, query_string)
117        };
118
119        let response = self
120            .client
121            .get(&full_url)
122            .send()
123            .await
124            .map_err(ZtkError::Network)?;
125
126        self.handle_response(response).await
127    }
128
129    /// 发送 POST 请求 (表单格式)
130    ///
131    /// # Arguments
132    ///
133    /// * `path` - API 路径
134    /// * `params` - 表单参数
135    ///
136    /// # Returns
137    ///
138    /// 返回反序列化后的响应数据或错误
139    pub async fn post_form<T, P>(&self, path: &str, params: &P) -> ZtkResult<T>
140    where
141        T: DeserializeOwned,
142        P: Serialize + ?Sized,
143    {
144        let url = format!("{}{}", self.base_url, path);
145
146        // 将参数序列化为表单数据
147        let form_data = self.serialize_params(params)?;
148
149        let response = self
150            .client
151            .post(&url)
152            .header("Content-Type", "application/x-www-form-urlencoded")
153            .body(form_data)
154            .send()
155            .await
156            .map_err(ZtkError::Network)?;
157
158        self.handle_response(response).await
159    }
160
161    /// 发送 POST 请求 (JSON 格式)
162    ///
163    /// # Arguments
164    ///
165    /// * `path` - API 路径
166    /// * `body` - JSON 请求体
167    ///
168    /// # Returns
169    ///
170    /// 返回反序列化后的响应数据或错误
171    pub async fn post_json<T, B>(&self, path: &str, body: &B) -> ZtkResult<T>
172    where
173        T: DeserializeOwned,
174        B: Serialize + ?Sized,
175    {
176        let url = format!("{}{}", self.base_url, path);
177
178        let response = self
179            .client
180            .post(&url)
181            .json(body)
182            .send()
183            .await
184            .map_err(ZtkError::Network)?;
185
186        self.handle_response(response).await
187    }
188
189    /// 序列化参数为 URL 编码的查询字符串
190    ///
191    /// # Arguments
192    ///
193    /// * `params` - 要序列化的参数
194    ///
195    /// # Returns
196    ///
197    /// 返回 URL 编码的查询字符串或错误
198    fn serialize_params<P: Serialize + ?Sized>(&self, params: &P) -> ZtkResult<String> {
199        // 先序列化为 JSON,然后转换为 HashMap
200        let json_value = serde_json::to_value(params)?;
201
202        let map = match json_value {
203            serde_json::Value::Object(map) => map,
204            _ => return Ok(String::new()),
205        };
206
207        let mut pairs: Vec<String> = Vec::new();
208        for (key, value) in map {
209            let value_str = match value {
210                serde_json::Value::Null => continue, // 跳过 null 值
211                serde_json::Value::String(s) => s,
212                serde_json::Value::Number(n) => n.to_string(),
213                serde_json::Value::Bool(b) => b.to_string(),
214                _ => serde_json::to_string(&value)?,
215            };
216
217            // URL 编码
218            let encoded_value = url_encode(&value_str);
219            pairs.push(format!("{}={}", key, encoded_value));
220        }
221
222        Ok(pairs.join("&"))
223    }
224
225    /// 处理 HTTP 响应
226    ///
227    /// # Arguments
228    ///
229    /// * `response` - HTTP 响应
230    ///
231    /// # Returns
232    ///
233    /// 返回反序列化后的数据或错误
234    async fn handle_response<T: DeserializeOwned>(
235        &self,
236        response: reqwest::Response,
237    ) -> ZtkResult<T> {
238        let status = response.status();
239        let text = response.text().await.map_err(ZtkError::Network)?;
240
241        // 检查 HTTP 状态码
242        if !status.is_success() {
243            return Err(ZtkError::api(
244                status.as_u16() as i32,
245                format!("HTTP 错误: {}", text),
246            ));
247        }
248
249        // 尝试解析为 API 响应
250        // 首先检查是否是 API 错误响应
251        if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&text) {
252            if api_error.status != 200 && api_error.status != 0 {
253                return Err(ZtkError::api(api_error.status, api_error.msg));
254            }
255        }
256
257        // 解析为目标类型
258        serde_json::from_str(&text).map_err(ZtkError::Parse)
259    }
260}
261
262/// API 错误响应结构
263#[derive(Debug, serde::Deserialize)]
264struct ApiErrorResponse {
265    /// 状态码
266    #[serde(default)]
267    status: i32,
268    /// 错误消息
269    #[serde(default)]
270    msg: String,
271}
272
/// Percent-encodes `input` byte by byte.
///
/// The RFC 3986 unreserved characters (`A-Z`, `a-z`, `0-9`, `-`, `_`, `.`, `~`) are kept
/// as-is. Every other byte, including each byte of a multi-byte UTF-8 sequence and the
/// space character, becomes `%XX` with upper-case hex digits.
///
/// # Arguments
///
/// * `input` - The string to encode.
///
/// # Returns
///
/// The percent-encoded string.
pub fn url_encode(input: &str) -> String {
    // Lookup table for hex digits, so no String is allocated per escaped byte.
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(input.len() * 3);

    for byte in input.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }

    encoded
}
303
304/// URL 解码函数
305///
306/// 对 URL 编码的字符串进行解码
307///
308/// # Arguments
309///
310/// * `input` - URL 编码的字符串
311///
312/// # Returns
313///
314/// 返回解码后的字符串或错误
315pub fn url_decode(input: &str) -> ZtkResult<String> {
316    let mut decoded = Vec::with_capacity(input.len());
317    let mut chars = input.chars().peekable();
318
319    while let Some(c) = chars.next() {
320        if c == '%' {
321            // 读取两个十六进制字符
322            let hex1 = chars
323                .next()
324                .ok_or_else(|| ZtkError::url_encode("无效的 URL 编码: 缺少十六进制字符"))?;
325            let hex2 = chars
326                .next()
327                .ok_or_else(|| ZtkError::url_encode("无效的 URL 编码: 缺少十六进制字符"))?;
328
329            let hex_str: String = [hex1, hex2].iter().collect();
330            let byte = u8::from_str_radix(&hex_str, 16)
331                .map_err(|_| ZtkError::url_encode(format!("无效的十六进制字符: {}", hex_str)))?;
332
333            decoded.push(byte);
334        } else if c == '+' {
335            // + 号解码为空格
336            decoded.push(b' ');
337        } else {
338            decoded.push(c as u8);
339        }
340    }
341
342    String::from_utf8(decoded).map_err(|e| ZtkError::url_encode(format!("UTF-8 解码失败: {}", e)))
343}
344
345/// 构建带有 appkey 的请求参数
346///
347/// # Arguments
348///
349/// * `appkey` - 折淘客 AppKey
350/// * `params` - 其他请求参数
351///
352/// # Returns
353///
354/// 返回包含 appkey 的参数 HashMap
355pub fn build_params_with_appkey<P: Serialize>(
356    appkey: &str,
357    params: &P,
358) -> ZtkResult<HashMap<String, String>> {
359    let json_value = serde_json::to_value(params)?;
360
361    let mut map = HashMap::new();
362    map.insert("appkey".to_string(), appkey.to_string());
363
364    if let serde_json::Value::Object(obj) = json_value {
365        for (key, value) in obj {
366            let value_str = match value {
367                serde_json::Value::Null => continue,
368                serde_json::Value::String(s) => s,
369                serde_json::Value::Number(n) => n.to_string(),
370                serde_json::Value::Bool(b) => b.to_string(),
371                _ => serde_json::to_string(&value)?,
372            };
373            map.insert(key, value_str);
374        }
375    }
376
377    Ok(map)
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_url_encode_basic() {
        let cases = [
            ("hello", "hello"),
            ("hello world", "hello%20world"),
            ("a=b&c=d", "a%3Db%26c%3Dd"),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(url_encode(raw), expected);
        }
    }

    #[test]
    fn test_url_encode_chinese() {
        // Every byte of a CJK character is outside the unreserved set.
        let encoded = url_encode("淘口令");
        assert!(encoded.contains('%'));
        assert!(!encoded.contains("淘"));
    }

    #[test]
    fn test_url_encode_special_chars() {
        // Unreserved punctuation must pass through untouched.
        for &sample in &["test-value", "test_value", "test.value", "test~value"] {
            assert_eq!(url_encode(sample), sample);
        }
    }

    #[test]
    fn test_url_decode_basic() {
        let cases = [
            ("hello", "hello"),
            ("hello%20world", "hello world"),
            ("a%3Db%26c%3Dd", "a=b&c=d"),
        ];
        for &(encoded, expected) in &cases {
            assert_eq!(url_decode(encoded).unwrap(), expected);
        }
    }

    #[test]
    fn test_url_decode_plus_sign() {
        let decoded = url_decode("hello+world").unwrap();
        assert_eq!(decoded, "hello world");
    }

    #[test]
    fn test_url_encode_decode_roundtrip() {
        let inputs = [
            "hello world",
            "淘口令测试",
            "a=b&c=d",
            "special!@#$%^&*()",
            "mixed 中文 and English",
        ];

        for input in inputs.iter().copied() {
            let decoded = url_decode(&url_encode(input)).unwrap();
            assert_eq!(decoded, input, "Round-trip failed for: {}", input);
        }
    }

    #[test]
    fn test_url_decode_invalid() {
        // Truncated escapes ("%", "%2") and non-hex digits ("%GG") must all be rejected.
        for &bad in &["%", "%2", "%GG"] {
            assert!(url_decode(bad).is_err(), "expected an error for {:?}", bad);
        }
    }

    #[test]
    fn test_default_base_url() {
        assert_eq!(DEFAULT_BASE_URL, "https://api.zhetaoke.com:10001");
    }

    #[test]
    fn test_backup_base_url() {
        assert_eq!(BACKUP_BASE_URL, "http://api.zhetaoke.cn:10000");
    }

    #[test]
    fn test_http_client_creation() {
        let client = HttpClient::new("https://example.com", None)
            .expect("client construction with the default timeout should succeed");
        assert_eq!(client.base_url(), "https://example.com");
    }

    #[test]
    fn test_http_client_with_defaults() {
        let client = HttpClient::with_defaults()
            .expect("client construction with default settings should succeed");
        assert_eq!(client.base_url(), DEFAULT_BASE_URL);
    }

    #[test]
    fn test_build_params_with_appkey() {
        #[derive(Serialize)]
        struct TestParams {
            name: String,
            value: i32,
        }

        let params = TestParams {
            name: "test".to_string(),
            value: 123,
        };
        let result = build_params_with_appkey("my_appkey", &params).unwrap();

        let expected = [("appkey", "my_appkey"), ("name", "test"), ("value", "123")];
        for &(key, value) in &expected {
            assert_eq!(result.get(key), Some(&value.to_string()), "key {}", key);
        }
    }

    #[test]
    fn test_build_params_with_optional_fields() {
        #[derive(Serialize)]
        struct TestParams {
            required: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            optional: Option<String>,
        }

        let params = TestParams {
            required: "value".to_string(),
            optional: None,
        };
        let result = build_params_with_appkey("my_appkey", &params).unwrap();

        let expected = [("appkey", "my_appkey"), ("required", "value")];
        for &(key, value) in &expected {
            assert_eq!(result.get(key), Some(&value.to_string()), "key {}", key);
        }
        assert!(!result.contains_key("optional"));
    }
}