1use reqwest::Client;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use std::collections::HashMap;
9use std::time::Duration;
10
11use crate::error::{ZtkError, ZtkResult};
12
/// Request timeout, in seconds, that [`HttpClient::new`] applies when its
/// `timeout` argument is `None`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Primary API endpoint (HTTPS, port 10001); used by [`HttpClient::with_defaults`].
pub const DEFAULT_BASE_URL: &str = "https://api.zhetaoke.com:10001";

/// Secondary API endpoint (plain HTTP, port 10000). Not used by this module's
/// constructors; callers can pass it to [`HttpClient::new`] or
/// [`HttpClient::get_with_base_url`] explicitly.
pub const BACKUP_BASE_URL: &str = "http://api.zhetaoke.cn:10000";
21
22#[derive(Debug, Clone)]
26pub struct HttpClient {
27 client: Client,
29 base_url: String,
31}
32
33impl HttpClient {
34 pub fn new(base_url: impl Into<String>, timeout: Option<Duration>) -> ZtkResult<Self> {
45 let timeout = timeout.unwrap_or(Duration::from_secs(DEFAULT_TIMEOUT_SECS));
46
47 let client = Client::builder()
48 .timeout(timeout)
49 .build()
50 .map_err(ZtkError::Network)?;
51
52 Ok(Self {
53 client,
54 base_url: base_url.into(),
55 })
56 }
57
58 pub fn with_defaults() -> ZtkResult<Self> {
62 Self::new(DEFAULT_BASE_URL, None)
63 }
64
65 pub fn base_url(&self) -> &str {
67 &self.base_url
68 }
69
70 pub async fn get<T, P>(&self, path: &str, params: &P) -> ZtkResult<T>
81 where
82 T: DeserializeOwned,
83 P: Serialize + ?Sized,
84 {
85 self.get_with_base_url(&self.base_url, path, params).await
86 }
87
88 pub async fn get_with_base_url<T, P>(
100 &self,
101 base_url: &str,
102 path: &str,
103 params: &P,
104 ) -> ZtkResult<T>
105 where
106 T: DeserializeOwned,
107 P: Serialize + ?Sized,
108 {
109 let url = format!("{}{}", base_url, path);
110
111 let query_string = self.serialize_params(params)?;
113 let full_url = if query_string.is_empty() {
114 url
115 } else {
116 format!("{}?{}", url, query_string)
117 };
118
119 let response = self
120 .client
121 .get(&full_url)
122 .send()
123 .await
124 .map_err(ZtkError::Network)?;
125
126 self.handle_response(response).await
127 }
128
129 pub async fn post_form<T, P>(&self, path: &str, params: &P) -> ZtkResult<T>
140 where
141 T: DeserializeOwned,
142 P: Serialize + ?Sized,
143 {
144 let url = format!("{}{}", self.base_url, path);
145
146 let form_data = self.serialize_params(params)?;
148
149 let response = self
150 .client
151 .post(&url)
152 .header("Content-Type", "application/x-www-form-urlencoded")
153 .body(form_data)
154 .send()
155 .await
156 .map_err(ZtkError::Network)?;
157
158 self.handle_response(response).await
159 }
160
161 pub async fn post_json<T, B>(&self, path: &str, body: &B) -> ZtkResult<T>
172 where
173 T: DeserializeOwned,
174 B: Serialize + ?Sized,
175 {
176 let url = format!("{}{}", self.base_url, path);
177
178 let response = self
179 .client
180 .post(&url)
181 .json(body)
182 .send()
183 .await
184 .map_err(ZtkError::Network)?;
185
186 self.handle_response(response).await
187 }
188
189 fn serialize_params<P: Serialize + ?Sized>(&self, params: &P) -> ZtkResult<String> {
199 let json_value = serde_json::to_value(params)?;
201
202 let map = match json_value {
203 serde_json::Value::Object(map) => map,
204 _ => return Ok(String::new()),
205 };
206
207 let mut pairs: Vec<String> = Vec::new();
208 for (key, value) in map {
209 let value_str = match value {
210 serde_json::Value::Null => continue, serde_json::Value::String(s) => s,
212 serde_json::Value::Number(n) => n.to_string(),
213 serde_json::Value::Bool(b) => b.to_string(),
214 _ => serde_json::to_string(&value)?,
215 };
216
217 let encoded_value = url_encode(&value_str);
219 pairs.push(format!("{}={}", key, encoded_value));
220 }
221
222 Ok(pairs.join("&"))
223 }
224
225 async fn handle_response<T: DeserializeOwned>(
235 &self,
236 response: reqwest::Response,
237 ) -> ZtkResult<T> {
238 let status = response.status();
239 let text = response.text().await.map_err(ZtkError::Network)?;
240
241 if !status.is_success() {
243 return Err(ZtkError::api(
244 status.as_u16() as i32,
245 format!("HTTP 错误: {}", text),
246 ));
247 }
248
249 if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&text) {
252 if api_error.status != 200 && api_error.status != 0 {
253 return Err(ZtkError::api(api_error.status, api_error.msg));
254 }
255 }
256
257 serde_json::from_str(&text).map_err(ZtkError::Parse)
259 }
260}
261
262#[derive(Debug, serde::Deserialize)]
264struct ApiErrorResponse {
265 #[serde(default)]
267 status: i32,
268 #[serde(default)]
270 msg: String,
271}
272
/// Percent-encodes `input`.
///
/// Every byte of the UTF-8 encoding is written as `%XX` (uppercase hex) except
/// the RFC 3986 unreserved characters `A-Z a-z 0-9 - _ . ~`, which are copied
/// through unchanged. Spaces therefore become `%20`, not `+`.
pub fn url_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // Worst case every byte expands to three characters.
    let mut encoded = String::with_capacity(input.len() * 3);

    for byte in input.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            // Table lookup avoids allocating a temporary String per escaped byte.
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }

    encoded
}
303
304pub fn url_decode(input: &str) -> ZtkResult<String> {
316 let mut decoded = Vec::with_capacity(input.len());
317 let mut chars = input.chars().peekable();
318
319 while let Some(c) = chars.next() {
320 if c == '%' {
321 let hex1 = chars
323 .next()
324 .ok_or_else(|| ZtkError::url_encode("无效的 URL 编码: 缺少十六进制字符"))?;
325 let hex2 = chars
326 .next()
327 .ok_or_else(|| ZtkError::url_encode("无效的 URL 编码: 缺少十六进制字符"))?;
328
329 let hex_str: String = [hex1, hex2].iter().collect();
330 let byte = u8::from_str_radix(&hex_str, 16)
331 .map_err(|_| ZtkError::url_encode(format!("无效的十六进制字符: {}", hex_str)))?;
332
333 decoded.push(byte);
334 } else if c == '+' {
335 decoded.push(b' ');
337 } else {
338 decoded.push(c as u8);
339 }
340 }
341
342 String::from_utf8(decoded).map_err(|e| ZtkError::url_encode(format!("UTF-8 解码失败: {}", e)))
343}
344
345pub fn build_params_with_appkey<P: Serialize>(
356 appkey: &str,
357 params: &P,
358) -> ZtkResult<HashMap<String, String>> {
359 let json_value = serde_json::to_value(params)?;
360
361 let mut map = HashMap::new();
362 map.insert("appkey".to_string(), appkey.to_string());
363
364 if let serde_json::Value::Object(obj) = json_value {
365 for (key, value) in obj {
366 let value_str = match value {
367 serde_json::Value::Null => continue,
368 serde_json::Value::String(s) => s,
369 serde_json::Value::Number(n) => n.to_string(),
370 serde_json::Value::Bool(b) => b.to_string(),
371 _ => serde_json::to_string(&value)?,
372 };
373 map.insert(key, value_str);
374 }
375 }
376
377 Ok(map)
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_url_encode_basic() {
        assert_eq!(url_encode("hello"), "hello");
        assert_eq!(url_encode("hello world"), "hello%20world");
        assert_eq!(url_encode("a=b&c=d"), "a%3Db%26c%3Dd");
    }

    #[test]
    fn test_url_encode_chinese() {
        let input = "淘口令";
        let encoded = url_encode(input);
        assert!(encoded.contains('%'));
        // Non-ASCII characters must not survive; each UTF-8 byte becomes %XX.
        assert!(!encoded.contains("淘"));
        assert_eq!(encoded, "%E6%B7%98%E5%8F%A3%E4%BB%A4");
    }

    #[test]
    fn test_url_encode_special_chars() {
        // RFC 3986 unreserved punctuation is left as-is.
        assert_eq!(url_encode("test-value"), "test-value");
        assert_eq!(url_encode("test_value"), "test_value");
        assert_eq!(url_encode("test.value"), "test.value");
        assert_eq!(url_encode("test~value"), "test~value");
    }

    #[test]
    fn test_url_decode_basic() {
        assert_eq!(url_decode("hello").unwrap(), "hello");
        assert_eq!(url_decode("hello%20world").unwrap(), "hello world");
        assert_eq!(url_decode("a%3Db%26c%3Dd").unwrap(), "a=b&c=d");
    }

    #[test]
    fn test_url_decode_plus_sign() {
        assert_eq!(url_decode("hello+world").unwrap(), "hello world");
    }

    #[test]
    fn test_url_encode_decode_roundtrip() {
        let test_cases = vec![
            "hello world",
            "淘口令测试",
            "a=b&c=d",
            "special!@#$%^&*()",
            "mixed 中文 and English",
        ];

        for input in test_cases {
            let encoded = url_encode(input);
            let decoded = url_decode(&encoded).unwrap();
            assert_eq!(decoded, input, "Round-trip failed for: {}", input);
        }
    }

    #[test]
    fn test_url_decode_invalid() {
        // Truncated escapes: one or both hex digits missing.
        assert!(url_decode("%").is_err());
        assert!(url_decode("%2").is_err());

        // Non-hexadecimal digits after `%`.
        assert!(url_decode("%GG").is_err());
    }

    #[test]
    fn test_default_base_url() {
        assert_eq!(DEFAULT_BASE_URL, "https://api.zhetaoke.com:10001");
    }

    #[test]
    fn test_backup_base_url() {
        assert_eq!(BACKUP_BASE_URL, "http://api.zhetaoke.cn:10000");
    }

    #[test]
    fn test_http_client_creation() {
        let client = HttpClient::new("https://example.com", None);
        assert!(client.is_ok());

        let client = client.unwrap();
        assert_eq!(client.base_url(), "https://example.com");
    }

    #[test]
    fn test_http_client_with_defaults() {
        let client = HttpClient::with_defaults();
        assert!(client.is_ok());

        let client = client.unwrap();
        assert_eq!(client.base_url(), DEFAULT_BASE_URL);
    }

    #[test]
    fn test_build_params_with_appkey() {
        #[derive(Serialize)]
        struct TestParams {
            name: String,
            value: i32,
        }

        let params = TestParams {
            name: "test".to_string(),
            value: 123,
        };

        let result = build_params_with_appkey("my_appkey", &params).unwrap();

        assert_eq!(result.get("appkey"), Some(&"my_appkey".to_string()));
        assert_eq!(result.get("name"), Some(&"test".to_string()));
        assert_eq!(result.get("value"), Some(&"123".to_string()));
    }

    #[test]
    fn test_build_params_with_optional_fields() {
        #[derive(Serialize)]
        struct TestParams {
            required: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            optional: Option<String>,
        }

        let params = TestParams {
            required: "value".to_string(),
            optional: None,
        };

        let result = build_params_with_appkey("my_appkey", &params).unwrap();

        assert_eq!(result.get("appkey"), Some(&"my_appkey".to_string()));
        assert_eq!(result.get("required"), Some(&"value".to_string()));
        assert!(!result.contains_key("optional"));
    }
}