Skip to main content

wae_request/
lib.rs

1//! WAE Request - 纯 Tokio HTTP 客户端
2//!
3//! 基于 tokio + tokio-rustls 实现的 HTTP/1.1 客户端
4//! 不依赖 hyper 或 reqwest
5
6#![warn(missing_docs)]
7
8use serde::{Serialize, de::DeserializeOwned};
9use std::{collections::HashMap, fmt, sync::Arc, time::Duration};
10use tokio::{
11    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
12    net::TcpStream,
13    time::timeout,
14};
15use tokio_rustls::{TlsConnector, rustls::pki_types::ServerName};
16use tracing::{debug, error, info};
17use url::Url;
18use wae_types::{NetworkErrorKind, WaeError, WaeResult};
19
/// Errors produced by [`HttpClient`] and [`HttpResponse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HttpError {
    /// The URL could not be parsed, has no host, or uses an unsupported scheme.
    InvalidUrl(String),

    /// Host name resolution failed.
    DnsFailed(String),

    /// The connection could not be established, or writing the request failed.
    ConnectionFailed(String),

    /// The TLS handshake failed or the server name was invalid.
    TlsError(String),

    /// A connect, write, or read deadline elapsed.
    Timeout,

    /// The peer sent something that is not valid HTTP/1.1, or the request could not be framed.
    ProtocolError(String),

    /// The response body could not be decoded (JSON or UTF-8).
    ParseError(String),

    /// The request body could not be serialized.
    SerializationError(String),

    /// The server answered with a non-success status.
    StatusError {
        /// HTTP status code.
        status: u16,
        /// Response body, lossily decoded as UTF-8.
        body: String,
    },
}
55
56impl fmt::Display for HttpError {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            HttpError::InvalidUrl(msg) => write!(f, "Invalid URL: {}", msg),
60            HttpError::DnsFailed(msg) => write!(f, "DNS resolution failed: {}", msg),
61            HttpError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
62            HttpError::TlsError(msg) => write!(f, "TLS error: {}", msg),
63            HttpError::Timeout => write!(f, "Request timeout"),
64            HttpError::ProtocolError(msg) => write!(f, "HTTP protocol error: {}", msg),
65            HttpError::ParseError(msg) => write!(f, "Response parse error: {}", msg),
66            HttpError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
67            HttpError::StatusError { status, body } => write!(f, "HTTP {}: {}", status, body),
68        }
69    }
70}
71
72impl std::error::Error for HttpError {}
73
/// Settings shared by every request an [`HttpClient`] sends.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Deadline applied separately to writing the request and to reading the response.
    pub timeout: Duration,
    /// Deadline for establishing the TCP connection (the TLS handshake is not covered).
    pub connect_timeout: Duration,
    /// Value sent in the `User-Agent` request header.
    pub user_agent: String,
    /// Number of extra attempts made after the first one for retryable failures.
    pub max_retries: u32,
    /// Base retry delay; retry attempt `n` waits `retry_delay * n` first.
    pub retry_delay: Duration,
    /// Headers added to every request.
    pub default_headers: HashMap<String, String>,
}

impl Default for HttpClientConfig {
    /// Defaults: 30 s timeout, 10 s connect timeout, 3 retries with a 1 s base delay,
    /// no default headers.
    fn default() -> Self {
        let default_headers = HashMap::new();
        Self {
            default_headers,
            retry_delay: Duration::from_secs(1),
            max_retries: 3,
            user_agent: String::from("wae-request/0.1.0"),
            connect_timeout: Duration::from_secs(10),
            timeout: Duration::from_secs(30),
        }
    }
}
103
104/// HTTP 响应
105#[derive(Debug)]
106pub struct HttpResponse {
107    /// HTTP 版本
108    pub version: String,
109    /// 状态码
110    pub status: u16,
111    /// 状态文本
112    pub status_text: String,
113    /// 响应头
114    pub headers: HashMap<String, String>,
115    /// 响应体
116    pub body: Vec<u8>,
117}
118
119impl HttpResponse {
120    /// 解析 JSON 响应体
121    pub fn json<T: DeserializeOwned>(&self) -> Result<T, HttpError> {
122        serde_json::from_slice(&self.body).map_err(|e| HttpError::ParseError(format!("JSON parse error: {}", e)))
123    }
124
125    /// 获取文本响应体
126    pub fn text(&self) -> Result<String, HttpError> {
127        String::from_utf8(self.body.clone()).map_err(|e| HttpError::ParseError(format!("UTF-8 decode error: {}", e)))
128    }
129
130    /// 检查是否成功
131    pub fn is_success(&self) -> bool {
132        self.status >= 200 && self.status < 300
133    }
134}
135
136/// TLS 连接器 (全局共享)
137static TLS_CONNECTOR: std::sync::OnceLock<Arc<TlsConnector>> = std::sync::OnceLock::new();
138
139fn get_tls_connector() -> Arc<TlsConnector> {
140    TLS_CONNECTOR
141        .get_or_init(|| {
142            let mut roots = tokio_rustls::rustls::RootCertStore::empty();
143            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
144            let config = tokio_rustls::rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth();
145            Arc::new(TlsConnector::from(Arc::new(config)))
146        })
147        .clone()
148}
149
150/// HTTP 客户端
151#[derive(Debug, Clone)]
152pub struct HttpClient {
153    config: HttpClientConfig,
154}
155
156impl Default for HttpClient {
157    fn default() -> Self {
158        Self::new(HttpClientConfig::default())
159    }
160}
161
162impl HttpClient {
163    /// 创建新的 HTTP 客户端
164    pub fn new(config: HttpClientConfig) -> Self {
165        Self { config }
166    }
167
168    /// 使用默认配置创建
169    pub fn with_defaults() -> Self {
170        Self::default()
171    }
172
173    /// 获取配置
174    pub fn config(&self) -> &HttpClientConfig {
175        &self.config
176    }
177
178    /// 发送 GET 请求
179    pub async fn get(&self, url: &str) -> Result<HttpResponse, HttpError> {
180        self.request("GET", url, None, None).await
181    }
182
183    /// 发送带请求头的 GET 请求
184    pub async fn get_with_headers(&self, url: &str, headers: HashMap<String, String>) -> Result<HttpResponse, HttpError> {
185        self.request("GET", url, None, Some(headers)).await
186    }
187
188    /// 发送 POST JSON 请求
189    pub async fn post_json<T: Serialize>(&self, url: &str, body: &T) -> Result<HttpResponse, HttpError> {
190        let json_body = serde_json::to_vec(body).map_err(|e| HttpError::SerializationError(e.to_string()))?;
191
192        let mut headers = HashMap::new();
193        headers.insert("Content-Type".to_string(), "application/json".to_string());
194
195        self.request("POST", url, Some(json_body), Some(headers)).await
196    }
197
198    /// 发送带请求头的 POST 请求
199    pub async fn post_with_headers(
200        &self,
201        url: &str,
202        body: Vec<u8>,
203        headers: HashMap<String, String>,
204    ) -> Result<HttpResponse, HttpError> {
205        self.request("POST", url, Some(body), Some(headers)).await
206    }
207
208    /// 发送请求 (带重试)
209    pub async fn request(
210        &self,
211        method: &str,
212        url: &str,
213        body: Option<Vec<u8>>,
214        headers: Option<HashMap<String, String>>,
215    ) -> Result<HttpResponse, HttpError> {
216        let mut last_error = None;
217
218        for attempt in 0..=self.config.max_retries {
219            if attempt > 0 {
220                let delay = self.config.retry_delay * attempt;
221                debug!("Retry attempt {} after {:?}", attempt, delay);
222                tokio::time::sleep(delay).await;
223            }
224
225            match self.request_once(method, url, body.clone(), headers.clone()).await {
226                Ok(response) => {
227                    if response.is_success() {
228                        info!("Request succeeded on attempt {}", attempt);
229                        return Ok(response);
230                    }
231
232                    if Self::is_retryable_status(response.status) && attempt < self.config.max_retries {
233                        last_error = Some(HttpError::StatusError {
234                            status: response.status,
235                            body: String::from_utf8_lossy(&response.body).to_string(),
236                        });
237                        continue;
238                    }
239
240                    return Err(HttpError::StatusError {
241                        status: response.status,
242                        body: String::from_utf8_lossy(&response.body).to_string(),
243                    });
244                }
245                Err(e) => {
246                    error!("Request error on attempt {}: {}", attempt, e);
247                    if Self::is_retryable_error(&e) && attempt < self.config.max_retries {
248                        last_error = Some(e);
249                        continue;
250                    }
251                    return Err(e);
252                }
253            }
254        }
255
256        Err(last_error.unwrap_or(HttpError::Timeout))
257    }
258
259    /// 发送单次请求
260    async fn request_once(
261        &self,
262        method: &str,
263        url_str: &str,
264        body: Option<Vec<u8>>,
265        extra_headers: Option<HashMap<String, String>>,
266    ) -> Result<HttpResponse, HttpError> {
267        let url = Url::parse(url_str).map_err(|e| HttpError::InvalidUrl(e.to_string()))?;
268
269        let host = url.host_str().ok_or_else(|| HttpError::InvalidUrl("Missing host".into()))?;
270        let port = url.port().unwrap_or(if url.scheme() == "https" { 443 } else { 80 });
271        let path = url.path();
272        let query = url.query().map(|q| format!("?{}", q)).unwrap_or_default();
273        let uri = format!("{}{}", path, query);
274
275        let is_https = url.scheme() == "https";
276
277        let connect_result =
278            timeout(self.config.connect_timeout, TcpStream::connect((host, port))).await.map_err(|_| HttpError::Timeout)?;
279
280        let tcp_stream = connect_result.map_err(|e| HttpError::ConnectionFailed(format!("TCP connect failed: {}", e)))?;
281
282        tcp_stream.set_nodelay(true).ok();
283
284        let response = if is_https {
285            let connector = get_tls_connector();
286            let server_name = ServerName::try_from(host.to_string())
287                .map_err(|e| HttpError::TlsError(format!("Invalid server name: {}", e)))?;
288
289            let tls_stream = connector
290                .connect(server_name, tcp_stream)
291                .await
292                .map_err(|e| HttpError::TlsError(format!("TLS handshake failed: {}", e)))?;
293
294            let (reader, writer) = tokio::io::split(tls_stream);
295            self.send_http_request(reader, writer, method, host, &uri, body, extra_headers).await?
296        }
297        else {
298            let (reader, writer) = tcp_stream.into_split();
299            self.send_http_request(reader, writer, method, host, &uri, body, extra_headers).await?
300        };
301
302        Ok(response)
303    }
304
305    /// 发送 HTTP 请求并读取响应
306    #[allow(clippy::too_many_arguments)]
307    async fn send_http_request<R, W>(
308        &self,
309        reader: R,
310        mut writer: W,
311        method: &str,
312        host: &str,
313        uri: &str,
314        body: Option<Vec<u8>>,
315        extra_headers: Option<HashMap<String, String>>,
316    ) -> Result<HttpResponse, HttpError>
317    where
318        R: AsyncReadExt + Unpin,
319        W: AsyncWriteExt + Unpin,
320    {
321        let body_len = body.as_ref().map(|b| b.len()).unwrap_or(0);
322
323        let mut request =
324            format!("{} {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: {}\r\n", method, uri, host, self.config.user_agent);
325
326        if body_len > 0 {
327            request.push_str(&format!("Content-Length: {}\r\n", body_len));
328        }
329
330        for (key, value) in &self.config.default_headers {
331            request.push_str(&format!("{}: {}\r\n", key, value));
332        }
333
334        if let Some(headers) = extra_headers {
335            for (key, value) in headers {
336                request.push_str(&format!("{}: {}\r\n", key, value));
337            }
338        }
339
340        request.push_str("Connection: close\r\n\r\n");
341
342        let mut request_bytes = request.into_bytes();
343        if let Some(b) = body {
344            request_bytes.extend(b);
345        }
346
347        timeout(self.config.timeout, async {
348            writer
349                .write_all(&request_bytes)
350                .await
351                .map_err(|e| HttpError::ConnectionFailed(format!("Write request failed: {}", e)))?;
352            writer.flush().await.map_err(|e| HttpError::ConnectionFailed(format!("Flush failed: {}", e)))?;
353            Ok::<_, HttpError>(())
354        })
355        .await
356        .map_err(|_| HttpError::Timeout)??;
357
358        let response = timeout(self.config.timeout, self.read_response(reader)).await.map_err(|_| HttpError::Timeout)??;
359
360        Ok(response)
361    }
362
363    /// 读取 HTTP 响应
364    async fn read_response<R: AsyncReadExt + Unpin>(&self, reader: R) -> Result<HttpResponse, HttpError> {
365        let mut buf_reader = BufReader::new(reader);
366        let mut status_line = String::new();
367
368        buf_reader
369            .read_line(&mut status_line)
370            .await
371            .map_err(|e| HttpError::ProtocolError(format!("Read status line failed: {}", e)))?;
372
373        let status_parts: Vec<&str> = status_line.trim().splitn(3, ' ').collect();
374        if status_parts.len() < 2 {
375            return Err(HttpError::ProtocolError("Invalid status line".into()));
376        }
377
378        let version = status_parts[0].to_string();
379        let status: u16 = status_parts[1].parse().map_err(|_| HttpError::ProtocolError("Invalid status code".into()))?;
380        let status_text = status_parts.get(2).unwrap_or(&"").to_string();
381
382        let mut headers = HashMap::new();
383        loop {
384            let mut line = String::new();
385            buf_reader
386                .read_line(&mut line)
387                .await
388                .map_err(|e| HttpError::ProtocolError(format!("Read header failed: {}", e)))?;
389
390            if line == "\r\n" || line.is_empty() {
391                break;
392            }
393
394            if let Some((key, value)) = line.split_once(':') {
395                headers.insert(key.trim().to_string(), value.trim().to_string());
396            }
397        }
398
399        let content_length: Option<usize> = headers.get("content-length").and_then(|v| v.parse().ok());
400
401        let mut body = Vec::new();
402
403        if let Some(len) = content_length {
404            body.resize(len, 0);
405            buf_reader.read_exact(&mut body).await.map_err(|e| HttpError::ProtocolError(format!("Read body failed: {}", e)))?;
406        }
407        else {
408            buf_reader
409                .read_to_end(&mut body)
410                .await
411                .map_err(|e| HttpError::ProtocolError(format!("Read body failed: {}", e)))?;
412        }
413
414        Ok(HttpResponse { version, status, status_text, headers, body })
415    }
416
417    /// 判断状态码是否可重试
418    fn is_retryable_status(status: u16) -> bool {
419        matches!(status, 408 | 429 | 500 | 502 | 503 | 504)
420    }
421
422    /// 判断错误是否可重试
423    fn is_retryable_error(error: &HttpError) -> bool {
424        matches!(error, HttpError::Timeout | HttpError::ConnectionFailed(_) | HttpError::DnsFailed(_))
425    }
426}
427
428/// 请求构建器
429pub struct RequestBuilder {
430    client: HttpClient,
431    method: String,
432    url: String,
433    headers: HashMap<String, String>,
434    body: Option<Vec<u8>>,
435}
436
437impl RequestBuilder {
438    /// 创建新的请求构建器
439    pub fn new(client: HttpClient, method: &str, url: &str) -> Self {
440        Self { client, method: method.to_string(), url: url.to_string(), headers: HashMap::new(), body: None }
441    }
442
443    /// 添加请求头
444    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
445        self.headers.insert(key.into(), value.into());
446        self
447    }
448
449    /// 添加 Bearer Token
450    pub fn bearer_auth(self, token: impl Into<String>) -> Self {
451        self.header("Authorization", format!("Bearer {}", token.into()))
452    }
453
454    /// 设置 JSON 请求体
455    pub fn json<T: Serialize>(mut self, body: &T) -> Self {
456        if let Ok(json) = serde_json::to_vec(body) {
457            self.headers.insert("Content-Type".into(), "application/json".into());
458            self.body = Some(json);
459        }
460        self
461    }
462
463    /// 设置原始请求体
464    pub fn body(mut self, body: Vec<u8>) -> Self {
465        self.body = Some(body);
466        self
467    }
468
469    /// 发送请求
470    pub async fn send(self) -> Result<HttpResponse, HttpError> {
471        self.client.request(&self.method, &self.url, self.body, Some(self.headers)).await
472    }
473
474    /// 发送请求并解析 JSON
475    pub async fn send_json<T: DeserializeOwned>(self) -> Result<T, HttpError> {
476        let response = self.send().await?;
477        response.json()
478    }
479}
480
481/// 便捷函数:创建 GET 请求
482pub fn get(url: &str) -> RequestBuilder {
483    RequestBuilder::new(HttpClient::default(), "GET", url)
484}
485
486/// 便捷函数:创建 POST 请求
487pub fn post(url: &str) -> RequestBuilder {
488    RequestBuilder::new(HttpClient::default(), "POST", url)
489}
490
491/// 将 HttpError 转换为 WaeError
492impl From<HttpError> for WaeError {
493    fn from(error: HttpError) -> Self {
494        match error {
495            HttpError::InvalidUrl(msg) => WaeError::network(NetworkErrorKind::ConnectionFailed).with_param("message", msg),
496            HttpError::Timeout => WaeError::network(NetworkErrorKind::Timeout),
497            HttpError::ConnectionFailed(msg) => {
498                WaeError::network(NetworkErrorKind::ConnectionFailed).with_param("message", msg)
499            }
500            HttpError::DnsFailed(msg) => WaeError::network(NetworkErrorKind::DnsFailed).with_param("detail", msg),
501            HttpError::TlsError(msg) => WaeError::network(NetworkErrorKind::TlsError).with_param("message", msg),
502            HttpError::StatusError { status, body } => WaeError::network(NetworkErrorKind::ProtocolError)
503                .with_param("status", status.to_string())
504                .with_param("body", body),
505            HttpError::ProtocolError(msg) => WaeError::network(NetworkErrorKind::ProtocolError).with_param("message", msg),
506            HttpError::ParseError(msg) => WaeError::internal(format!("Response parse error: {}", msg)),
507            HttpError::SerializationError(msg) => WaeError::internal(format!("Serialization error: {}", msg)),
508        }
509    }
510}
511
512/// 兼容旧 API 的 RequestClient
513#[derive(Debug, Clone)]
514pub struct RequestClient {
515    client: HttpClient,
516    config: RequestConfig,
517}
518
/// Legacy configuration, expressed in plain integer units.
#[derive(Debug, Clone)]
pub struct RequestConfig {
    /// Request timeout in seconds.
    pub timeout_secs: u64,
    /// Connect timeout in seconds.
    pub connect_timeout_secs: u64,
    /// Number of extra attempts after the first one.
    pub max_retries: u32,
    /// Base retry delay in milliseconds.
    pub retry_delay_ms: u64,
    /// Value sent in the `User-Agent` header.
    pub user_agent: String,
}

impl Default for RequestConfig {
    /// Defaults: 30 s timeout, 10 s connect timeout, 3 retries with a 1000 ms base delay.
    fn default() -> Self {
        Self {
            user_agent: String::from("wae-request/0.1.0"),
            retry_delay_ms: 1_000,
            max_retries: 3,
            connect_timeout_secs: 10,
            timeout_secs: 30,
        }
    }
}
545
546impl RequestClient {
547    /// 创建新的 HTTP 客户端
548    pub fn new(config: RequestConfig) -> WaeResult<Self> {
549        let http_config = HttpClientConfig {
550            timeout: Duration::from_secs(config.timeout_secs),
551            connect_timeout: Duration::from_secs(config.connect_timeout_secs),
552            user_agent: config.user_agent.clone(),
553            max_retries: config.max_retries,
554            retry_delay: Duration::from_millis(config.retry_delay_ms),
555            default_headers: HashMap::new(),
556        };
557        Ok(Self { client: HttpClient::new(http_config), config })
558    }
559
560    /// 使用默认配置创建
561    pub fn with_defaults() -> WaeResult<Self> {
562        Self::new(RequestConfig::default())
563    }
564
565    /// 获取配置
566    pub fn config(&self) -> &RequestConfig {
567        &self.config
568    }
569
570    /// 发送 GET 请求并解析 JSON
571    pub async fn get<T: DeserializeOwned>(&self, url: &str) -> WaeResult<T> {
572        let response = self.client.get(url).await.map_err(WaeError::from)?;
573        response.json().map_err(WaeError::from)
574    }
575
576    /// 发送 GET 请求返回原始响应
577    pub async fn get_raw(&self, url: &str) -> WaeResult<HttpResponse> {
578        self.client.get(url).await.map_err(WaeError::from)
579    }
580
581    /// 发送 POST JSON 请求并解析响应
582    pub async fn post<T: DeserializeOwned, B: Serialize>(&self, url: &str, body: &B) -> WaeResult<T> {
583        let response = self.client.post_json(url, body).await.map_err(WaeError::from)?;
584        response.json().map_err(WaeError::from)
585    }
586
587    /// 发送 POST JSON 请求返回原始响应
588    pub async fn post_raw<B: Serialize>(&self, url: &str, body: &B) -> WaeResult<HttpResponse> {
589        self.client.post_json(url, body).await.map_err(WaeError::from)
590    }
591
592    /// 创建请求构建器
593    pub fn builder(&self) -> RequestBuilder {
594        RequestBuilder::new(self.client.clone(), "GET", "")
595    }
596}