Skip to main content

wae_request/
lib.rs

1//! WAE Request - 基于 hyper 的 HTTP 客户端
2//!
3//! 基于 hyper + hyper-tls 实现的 HTTP 客户端
4
5#![warn(missing_docs)]
6
7use bytes::Bytes;
8use http_body_util::{BodyExt, Full};
9use hyper::{Request, Uri};
10use hyper_util::{client::legacy::Client, rt::TokioExecutor};
11use serde::{Serialize, de::DeserializeOwned};
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::time::timeout;
14use tracing::{debug, error, info};
15use url::Url;
16use wae_types::{WaeError, WaeErrorKind, WaeResult};
17
/// Settings applied by [`HttpClient`] to every request it sends.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Time limit for a single request attempt; also reported in timeout errors.
    pub timeout: Duration,
    /// Intended TCP connect timeout.
    ///
    /// NOTE(review): nothing in this file reads this field — the shared hyper client is
    /// built without a connect timeout — so it currently has no effect.
    pub connect_timeout: Duration,
    /// Value sent as the `User-Agent` request header.
    pub user_agent: String,
    /// Number of retries allowed after the first attempt for retryable failures.
    pub max_retries: u32,
    /// Base retry delay; the wait before retry `n` is `retry_delay * n`.
    pub retry_delay: Duration,
    /// Headers added to every request sent with this configuration.
    pub default_headers: HashMap<String, String>,
}

impl Default for HttpClientConfig {
    /// 30 s timeout, 10 s connect timeout, `wae-request/0.1.0` user agent,
    /// 3 retries with a 1 s base delay, and no default headers.
    fn default() -> Self {
        let user_agent = String::from("wae-request/0.1.0");
        Self {
            timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            user_agent,
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            default_headers: HashMap::default(),
        }
    }
}
47
/// A fully buffered HTTP response.
#[derive(Debug)]
pub struct HttpResponse {
    /// Protocol version as text, e.g. `"HTTP/1.1"`.
    pub version: String,
    /// Numeric status code.
    pub status: u16,
    /// Canonical reason phrase for `status`, or an empty string if none is known.
    pub status_text: String,
    /// Response headers whose values are printable as strings; one value per name.
    pub headers: HashMap<String, String>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
62
63impl HttpResponse {
64    /// 解析 JSON 响应体
65    pub fn json<T: DeserializeOwned>(&self) -> WaeResult<T> {
66        serde_json::from_slice(&self.body).map_err(|e| WaeError::parse_error("JSON", e.to_string()))
67    }
68
69    /// 获取文本响应体
70    pub fn text(&self) -> WaeResult<String> {
71        String::from_utf8(self.body.clone()).map_err(|e| WaeError::parse_error("UTF-8", e.to_string()))
72    }
73
74    /// 检查是否成功
75    pub fn is_success(&self) -> bool {
76        self.status >= 200 && self.status < 300
77    }
78}
79
80/// hyper 客户端 (全局共享)
81static HYPER_CLIENT: std::sync::OnceLock<
82    Arc<Client<hyper_tls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>, Full<Bytes>>>,
83> = std::sync::OnceLock::new();
84
85fn get_hyper_client() -> Arc<Client<hyper_tls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>, Full<Bytes>>>
86{
87    HYPER_CLIENT
88        .get_or_init(|| {
89            let mut http = hyper_util::client::legacy::connect::HttpConnector::new();
90            http.enforce_http(false);
91
92            let https = hyper_tls::HttpsConnector::new_with_connector(http);
93
94            let client = Client::builder(TokioExecutor::new()).build(https);
95
96            Arc::new(client)
97        })
98        .clone()
99}
100
101/// HTTP 客户端
102#[derive(Debug, Clone)]
103pub struct HttpClient {
104    config: HttpClientConfig,
105}
106
107impl Default for HttpClient {
108    fn default() -> Self {
109        Self::new(HttpClientConfig::default())
110    }
111}
112
113impl HttpClient {
114    /// 创建新的 HTTP 客户端
115    pub fn new(config: HttpClientConfig) -> Self {
116        Self { config }
117    }
118
119    /// 使用默认配置创建
120    pub fn with_defaults() -> Self {
121        Self::default()
122    }
123
124    /// 获取配置
125    pub fn config(&self) -> &HttpClientConfig {
126        &self.config
127    }
128
129    /// 发送 GET 请求
130    pub async fn get(&self, url: &str) -> WaeResult<HttpResponse> {
131        self.request("GET", url, None, None).await
132    }
133
134    /// 发送带请求头的 GET 请求
135    pub async fn get_with_headers(&self, url: &str, headers: HashMap<String, String>) -> WaeResult<HttpResponse> {
136        self.request("GET", url, None, Some(headers)).await
137    }
138
139    /// 发送 POST JSON 请求
140    pub async fn post_json<T: Serialize>(&self, url: &str, body: &T) -> WaeResult<HttpResponse> {
141        let json_body = serde_json::to_vec(body).map_err(|_e| WaeError::serialization_failed("JSON"))?;
142
143        let mut headers = HashMap::new();
144        headers.insert("Content-Type".to_string(), "application/json".to_string());
145
146        self.request("POST", url, Some(json_body), Some(headers)).await
147    }
148
149    /// 发送带请求头的 POST 请求
150    pub async fn post_with_headers(
151        &self,
152        url: &str,
153        body: Vec<u8>,
154        headers: HashMap<String, String>,
155    ) -> WaeResult<HttpResponse> {
156        self.request("POST", url, Some(body), Some(headers)).await
157    }
158
159    /// 发送请求 (带重试)
160    pub async fn request(
161        &self,
162        method: &str,
163        url: &str,
164        body: Option<Vec<u8>>,
165        headers: Option<HashMap<String, String>>,
166    ) -> WaeResult<HttpResponse> {
167        let mut last_error = None;
168
169        for attempt in 0..=self.config.max_retries {
170            if attempt > 0 {
171                let delay = self.config.retry_delay * attempt;
172                debug!("Retry attempt {} after {:?}", attempt, delay);
173                tokio::time::sleep(delay).await;
174            }
175
176            match self.request_once(method, url, body.clone(), headers.clone()).await {
177                Ok(response) => {
178                    if response.is_success() {
179                        info!("Request succeeded on attempt {}", attempt);
180                        return Ok(response);
181                    }
182
183                    if Self::is_retryable_status(response.status) && attempt < self.config.max_retries {
184                        last_error = Some(WaeError::new(WaeErrorKind::RequestError {
185                            url: url.to_string(),
186                            reason: format!("HTTP {}: {}", response.status, String::from_utf8_lossy(&response.body)),
187                        }));
188                        continue;
189                    }
190
191                    return Err(WaeError::new(WaeErrorKind::RequestError {
192                        url: url.to_string(),
193                        reason: format!("HTTP {}: {}", response.status, String::from_utf8_lossy(&response.body)),
194                    }));
195                }
196                Err(e) => {
197                    error!("Request error on attempt {}: {}", attempt, e);
198                    if Self::is_retryable_error(&e) && attempt < self.config.max_retries {
199                        last_error = Some(e);
200                        continue;
201                    }
202                    return Err(e);
203                }
204            }
205        }
206
207        Err(last_error.unwrap_or_else(|| WaeError::operation_timeout("request", self.config.timeout.as_millis() as u64)))
208    }
209
210    /// 发送单次请求
211    async fn request_once(
212        &self,
213        method: &str,
214        url_str: &str,
215        body: Option<Vec<u8>>,
216        extra_headers: Option<HashMap<String, String>>,
217    ) -> WaeResult<HttpResponse> {
218        let _url = Url::parse(url_str).map_err(|e| WaeError::invalid_params("url", e.to_string()))?;
219        let uri = url_str.parse::<Uri>().map_err(|e| WaeError::invalid_params("url", e.to_string()))?;
220
221        let client = get_hyper_client();
222
223        let mut builder = Request::builder().method(method).uri(uri).header("User-Agent", &self.config.user_agent);
224
225        for (key, value) in &self.config.default_headers {
226            builder = builder.header(key, value);
227        }
228
229        if let Some(headers) = extra_headers {
230            for (key, value) in headers {
231                builder = builder.header(key, value);
232            }
233        }
234
235        let request = match body {
236            Some(b) => {
237                let len = b.len();
238                builder.header("Content-Length", len).body(Full::new(Bytes::from(b))).map_err(|e| {
239                    WaeError::new(WaeErrorKind::RequestError { url: url_str.to_string(), reason: e.to_string() })
240                })?
241            }
242            None => builder
243                .body(Full::new(Bytes::new()))
244                .map_err(|e| WaeError::new(WaeErrorKind::RequestError { url: url_str.to_string(), reason: e.to_string() }))?,
245        };
246
247        let response = timeout(self.config.timeout, client.request(request))
248            .await
249            .map_err(|_| WaeError::operation_timeout("request", self.config.timeout.as_millis() as u64))?
250            .map_err(|_e| WaeError::new(WaeErrorKind::ConnectionFailed { target: url_str.to_string() }))?;
251
252        let status = response.status().as_u16();
253        let status_text = response.status().canonical_reason().unwrap_or("").to_string();
254
255        let version = match response.version() {
256            hyper::Version::HTTP_09 => "HTTP/0.9".to_string(),
257            hyper::Version::HTTP_10 => "HTTP/1.0".to_string(),
258            hyper::Version::HTTP_11 => "HTTP/1.1".to_string(),
259            hyper::Version::HTTP_2 => "HTTP/2".to_string(),
260            hyper::Version::HTTP_3 => "HTTP/3".to_string(),
261            _ => "HTTP/Unknown".to_string(),
262        };
263
264        let mut headers = HashMap::new();
265        for (key, value) in response.headers() {
266            if let Ok(value_str) = value.to_str() {
267                headers.insert(key.as_str().to_string(), value_str.to_string());
268            }
269        }
270
271        let body_bytes = response
272            .into_body()
273            .collect()
274            .await
275            .map_err(|e| {
276                WaeError::new(WaeErrorKind::ProtocolError {
277                    protocol: "HTTP".to_string(),
278                    reason: format!("Read body failed: {}", e),
279                })
280            })?
281            .to_bytes();
282
283        Ok(HttpResponse { version, status, status_text, headers, body: body_bytes.to_vec() })
284    }
285
286    /// 判断状态码是否可重试
287    fn is_retryable_status(status: u16) -> bool {
288        matches!(status, 408 | 429 | 500 | 502 | 503 | 504)
289    }
290
291    /// 判断错误是否可重试
292    fn is_retryable_error(error: &WaeError) -> bool {
293        matches!(
294            error.kind.as_ref(),
295            WaeErrorKind::OperationTimeout { .. }
296                | WaeErrorKind::ConnectionFailed { .. }
297                | WaeErrorKind::DnsResolutionFailed { .. }
298        )
299    }
300}
301
302/// 请求构建器
303pub struct RequestBuilder {
304    client: HttpClient,
305    method: String,
306    url: String,
307    headers: HashMap<String, String>,
308    body: Option<Vec<u8>>,
309}
310
311impl RequestBuilder {
312    /// 创建新的请求构建器
313    pub fn new(client: HttpClient, method: &str, url: &str) -> Self {
314        Self { client, method: method.to_string(), url: url.to_string(), headers: HashMap::new(), body: None }
315    }
316
317    /// 添加请求头
318    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
319        self.headers.insert(key.into(), value.into());
320        self
321    }
322
323    /// 添加 Bearer Token
324    pub fn bearer_auth(self, token: impl Into<String>) -> Self {
325        self.header("Authorization", format!("Bearer {}", token.into()))
326    }
327
328    /// 设置 JSON 请求体
329    pub fn json<T: Serialize>(mut self, body: &T) -> Self {
330        if let Ok(json) = serde_json::to_vec(body) {
331            self.headers.insert("Content-Type".into(), "application/json".into());
332            self.body = Some(json);
333        }
334        self
335    }
336
337    /// 设置原始请求体
338    pub fn body(mut self, body: Vec<u8>) -> Self {
339        self.body = Some(body);
340        self
341    }
342
343    /// 发送请求
344    pub async fn send(self) -> WaeResult<HttpResponse> {
345        self.client.request(&self.method, &self.url, self.body, Some(self.headers)).await
346    }
347
348    /// 发送请求并解析 JSON
349    pub async fn send_json<T: DeserializeOwned>(self) -> WaeResult<T> {
350        let response = self.send().await?;
351        response.json()
352    }
353}
354
355/// 便捷函数:创建 GET 请求
356pub fn get(url: &str) -> RequestBuilder {
357    RequestBuilder::new(HttpClient::default(), "GET", url)
358}
359
360/// 便捷函数:创建 POST 请求
361pub fn post(url: &str) -> RequestBuilder {
362    RequestBuilder::new(HttpClient::default(), "POST", url)
363}
364
365/// 兼容旧 API 的 RequestClient
366#[derive(Debug, Clone)]
367pub struct RequestClient {
368    client: HttpClient,
369    config: RequestConfig,
370}
371
/// Legacy configuration for [`RequestClient`], with durations given as plain integers.
#[derive(Debug, Clone)]
pub struct RequestConfig {
    /// Request timeout, in seconds.
    pub timeout_secs: u64,
    /// Connect timeout, in seconds.
    ///
    /// NOTE(review): the shared hyper client in this file is built without a connect
    /// timeout, so this value currently has no effect.
    pub connect_timeout_secs: u64,
    /// Number of retries allowed after the first attempt.
    pub max_retries: u32,
    /// Base retry delay, in milliseconds.
    pub retry_delay_ms: u64,
    /// Value sent as the `User-Agent` header.
    pub user_agent: String,
}

impl Default for RequestConfig {
    /// 30 s timeout, 10 s connect timeout, 3 retries with a 1000 ms base delay,
    /// and the `wae-request/0.1.0` user agent.
    fn default() -> Self {
        Self {
            user_agent: String::from("wae-request/0.1.0"),
            timeout_secs: 30,
            connect_timeout_secs: 10,
            retry_delay_ms: 1000,
            max_retries: 3,
        }
    }
}
398
399impl RequestClient {
400    /// 创建新的 HTTP 客户端
401    pub fn new(config: RequestConfig) -> WaeResult<Self> {
402        let http_config = HttpClientConfig {
403            timeout: Duration::from_secs(config.timeout_secs),
404            connect_timeout: Duration::from_secs(config.connect_timeout_secs),
405            user_agent: config.user_agent.clone(),
406            max_retries: config.max_retries,
407            retry_delay: Duration::from_millis(config.retry_delay_ms),
408            default_headers: HashMap::new(),
409        };
410        Ok(Self { client: HttpClient::new(http_config), config })
411    }
412
413    /// 使用默认配置创建
414    pub fn with_defaults() -> WaeResult<Self> {
415        Self::new(RequestConfig::default())
416    }
417
418    /// 获取配置
419    pub fn config(&self) -> &RequestConfig {
420        &self.config
421    }
422
423    /// 发送 GET 请求并解析 JSON
424    pub async fn get<T: DeserializeOwned>(&self, url: &str) -> WaeResult<T> {
425        let response = self.client.get(url).await?;
426        response.json()
427    }
428
429    /// 发送 GET 请求返回原始响应
430    pub async fn get_raw(&self, url: &str) -> WaeResult<HttpResponse> {
431        self.client.get(url).await
432    }
433
434    /// 发送 POST JSON 请求并解析响应
435    pub async fn post<T: DeserializeOwned, B: Serialize>(&self, url: &str, body: &B) -> WaeResult<T> {
436        let response = self.client.post_json(url, body).await?;
437        response.json()
438    }
439
440    /// 发送 POST JSON 请求返回原始响应
441    pub async fn post_raw<B: Serialize>(&self, url: &str, body: &B) -> WaeResult<HttpResponse> {
442        self.client.post_json(url, body).await
443    }
444
445    /// 创建请求构建器
446    pub fn builder(&self) -> RequestBuilder {
447        RequestBuilder::new(self.client.clone(), "GET", "")
448    }
449}