Skip to main content

sh_layer1/streaming/
http.rs

1//! HTTP 流式适配器
2//!
3//! 提供基于 reqwest 的 HTTP 流式响应处理。
4//!
5//! ## 特性
6//! - 流式请求/响应
7//! - SSE (Server-Sent Events) 支持
8//! - 请求超时和重试
9//! - 请求/响应拦截器
10
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use futures::Stream;
14use reqwest::{Client, Response, StatusCode};
15use std::collections::VecDeque;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20use crate::streaming::providers::StreamEvent;
21use crate::streaming::sse::SseEvent;
22use crate::streaming::sse::SseParser;
23
/// HTTP client settings.
///
/// Every duration field is expressed in milliseconds.
#[derive(Clone, Debug)]
pub struct HttpConfig {
    /// Request timeout in milliseconds (passed to the `reqwest` client builder).
    pub timeout_ms: u64,
    /// Connection-establishment timeout in milliseconds.
    pub connect_timeout_ms: u64,
    /// How many times a failed request may be retried after the first attempt.
    pub max_retries: u32,
    /// Base delay between retries in milliseconds (the retry loop scales it
    /// exponentially per attempt).
    pub retry_interval_ms: u64,
    /// Streaming read timeout in milliseconds.
    /// NOTE(review): not read by any code in this file — confirm it is used elsewhere.
    pub stream_timeout_ms: u64,
}
38
39impl Default for HttpConfig {
40    fn default() -> Self {
41        Self {
42            timeout_ms: 30000,
43            connect_timeout_ms: 10000,
44            max_retries: 3,
45            retry_interval_ms: 1000,
46            stream_timeout_ms: 60000,
47        }
48    }
49}
50
/// HTTP request method supported by [`HttpRequest`].
#[derive(Clone, Debug)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `PATCH`
    Patch,
}
60
61/// HTTP 请求配置
62#[derive(Debug, Clone)]
63pub struct HttpRequest {
64    /// URL
65    pub url: String,
66    /// 方法
67    pub method: HttpMethod,
68    /// 头部
69    pub headers: Vec<(String, String)>,
70    /// 请求体
71    pub body: Option<serde_json::Value>,
72}
73
74impl HttpRequest {
75    /// 创建 GET 请求
76    pub fn get(url: impl Into<String>) -> Self {
77        Self {
78            url: url.into(),
79            method: HttpMethod::Get,
80            headers: Vec::new(),
81            body: None,
82        }
83    }
84
85    /// 创建 POST 请求
86    pub fn post(url: impl Into<String>, body: serde_json::Value) -> Self {
87        Self {
88            url: url.into(),
89            method: HttpMethod::Post,
90            headers: Vec::new(),
91            body: Some(body),
92        }
93    }
94
95    /// 添加头部
96    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
97        self.headers.push((key.into(), value.into()));
98        self
99    }
100
101    /// 添加 Authorization 头部
102    pub fn bearer_auth(mut self, token: impl Into<String>) -> Self {
103        self.headers.push((
104            "Authorization".to_string(),
105            format!("Bearer {}", token.into()),
106        ));
107        self
108    }
109
110    /// 添加 API Key 头部
111    pub fn api_key(mut self, key: impl Into<String>) -> Self {
112        self.headers.push(("x-api-key".to_string(), key.into()));
113        self
114    }
115}
116
117/// HTTP 响应流
118///
119/// 包装 reqwest Response,提供流式数据读取
120pub struct HttpResponseStream {
121    response: Response,
122    parser: SseParser,
123    pending: VecDeque<StreamEvent>,
124    done: bool,
125    abort_flag: Arc<AtomicBool>,
126}
127
128impl HttpResponseStream {
129    /// 创建新的响应流
130    pub fn new(response: Response, abort_flag: Arc<AtomicBool>) -> Self {
131        Self {
132            response,
133            parser: SseParser::new(),
134            pending: VecDeque::new(),
135            done: false,
136            abort_flag,
137        }
138    }
139
140    /// 获取响应状态码
141    pub fn status(&self) -> StatusCode {
142        self.response.status()
143    }
144
145    /// 获取响应头部
146    pub fn headers(&self) -> &reqwest::header::HeaderMap {
147        self.response.headers()
148    }
149
150    /// 读取下一个事件
151    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>> {
152        // 使用简化的 SSE 事件返回
153        loop {
154            if self.abort_flag.load(Ordering::Relaxed) {
155                return Ok(None);
156            }
157
158            if let Some(event) = self.pending.pop_front() {
159                return Ok(Some(event));
160            }
161
162            if self.done {
163                let _remaining = self.parser.finish()?;
164                if let Some(event) = self.pending.pop_front() {
165                    return Ok(Some(event));
166                }
167                return Ok(None);
168            }
169
170            match self.response.chunk().await? {
171                Some(chunk) => {
172                    let sse_events = self.parser.push(&chunk)?;
173                    for _sse_event in sse_events {
174                        // 将 SSE 事件转换为流事件(简化版)
175                        self.pending.push_back(StreamEvent::MessageStart {
176                            id: String::new(),
177                            model: String::new(),
178                        });
179                    }
180                }
181                None => {
182                    self.done = true;
183                }
184            }
185        }
186    }
187
188    /// 收集所有文本内容
189    pub async fn collect_text(&mut self) -> Result<String> {
190        let mut result = String::new();
191        while let Some(event) = self.next_event().await? {
192            if let StreamEvent::ContentBlockDelta {
193                delta: crate::streaming::providers::ContentDelta::Text(t),
194                ..
195            } = event
196            {
197                result.push_str(&t);
198            }
199        }
200        Ok(result)
201    }
202
203    /// 创建 SSE 事件流
204    pub fn into_sse_stream(mut self) -> impl Stream<Item = Result<SseEvent>> {
205        async_stream::stream! {
206            loop {
207                if self.abort_flag.load(Ordering::Relaxed) {
208                    break;
209                }
210
211                match self.response.chunk().await {
212                    Ok(Some(chunk)) => {
213                        let events = self.parser.push(&chunk)?;
214                        for event in events {
215                            yield Ok(event);
216                        }
217                    }
218                    Ok(None) => {
219                        let remaining = self.parser.finish()?;
220                        for event in remaining {
221                            yield Ok(event);
222                        }
223                        break;
224                    }
225                    Err(e) => {
226                        yield Err(anyhow!("Stream error: {}", e));
227                        break;
228                    }
229                }
230            }
231        }
232    }
233}
234
235/// HTTP 适配器
236///
237/// 提供高级 HTTP 请求功能,包括流式响应处理
238pub struct HttpAdapter {
239    /// HTTP 客户端
240    client: Client,
241    /// 配置
242    config: HttpConfig,
243    /// 中断标志
244    abort_flag: Arc<AtomicBool>,
245}
246
247impl HttpAdapter {
248    /// 创建新的 HTTP 适配器
249    pub fn new() -> Self {
250        Self::with_config(HttpConfig::default())
251    }
252
253    /// 创建带配置的 HTTP 适配器
254    pub fn with_config(config: HttpConfig) -> Self {
255        let client = Client::builder()
256            .timeout(Duration::from_millis(config.timeout_ms))
257            .connect_timeout(Duration::from_millis(config.connect_timeout_ms))
258            .build()
259            .expect("Failed to create HTTP client");
260
261        Self {
262            client,
263            config,
264            abort_flag: Arc::new(AtomicBool::new(false)),
265        }
266    }
267
268    /// 获取中断标志
269    pub fn abort_flag(&self) -> Arc<AtomicBool> {
270        Arc::clone(&self.abort_flag)
271    }
272
273    /// 请求中断
274    pub fn abort(&self) {
275        self.abort_flag.store(true, Ordering::Relaxed);
276    }
277
278    /// 重置中断标志
279    pub fn reset(&self) {
280        self.abort_flag.store(false, Ordering::Relaxed);
281    }
282
283    /// 检查是否已中断
284    pub fn is_aborted(&self) -> bool {
285        self.abort_flag.load(Ordering::Relaxed)
286    }
287
288    /// 执行 HTTP 请求
289    pub async fn request(&self, request: HttpRequest) -> Result<Response> {
290        self.request_with_retry(request, self.config.max_retries)
291            .await
292    }
293
294    /// 执行带重试的 HTTP 请求
295    async fn request_with_retry(&self, request: HttpRequest, max_retries: u32) -> Result<Response> {
296        let mut attempts = 0;
297
298        loop {
299            if self.is_aborted() {
300                return Err(anyhow!("Request aborted"));
301            }
302
303            attempts += 1;
304
305            let result = self.execute_request(&request).await;
306
307            match result {
308                Ok(response) => {
309                    let status = response.status();
310                    if status.is_success() {
311                        return Ok(response);
312                    }
313
314                    // 如果是可重试的错误状态码,继续重试
315                    if Self::is_retryable_status(status) && attempts <= max_retries {
316                        tracing::warn!(
317                            "HTTP request failed with status {}, attempt {}/{}",
318                            status,
319                            attempts,
320                            max_retries
321                        );
322                        let delay = Duration::from_millis(
323                            self.config.retry_interval_ms * (1 << (attempts - 1)),
324                        );
325                        tokio::time::sleep(delay).await;
326                        continue;
327                    }
328
329                    // 其他错误,返回响应体内容
330                    let body = response.text().await.unwrap_or_default();
331                    return Err(anyhow!("HTTP {}: {}", status, body));
332                }
333                Err(e) => {
334                    // 网络错误等可重试
335                    if Self::is_retryable_error(&e) && attempts <= max_retries {
336                        tracing::warn!(
337                            "HTTP request error: {}, attempt {}/{}",
338                            e,
339                            attempts,
340                            max_retries
341                        );
342                        let delay = Duration::from_millis(
343                            self.config.retry_interval_ms * (1 << (attempts - 1)),
344                        );
345                        tokio::time::sleep(delay).await;
346                        continue;
347                    }
348                    return Err(e);
349                }
350            }
351        }
352    }
353
354    /// 执行单次 HTTP 请求
355    async fn execute_request(&self, request: &HttpRequest) -> Result<Response> {
356        let builder = match request.method {
357            HttpMethod::Get => self.client.get(&request.url),
358            HttpMethod::Post => self.client.post(&request.url),
359            HttpMethod::Put => self.client.put(&request.url),
360            HttpMethod::Delete => self.client.delete(&request.url),
361            HttpMethod::Patch => self.client.patch(&request.url),
362        };
363
364        // 添加头部
365        let builder = request
366            .headers
367            .iter()
368            .fold(builder, |b, (k, v)| b.header(k, v));
369
370        // 添加请求体
371        let builder = if let Some(body) = &request.body {
372            builder.json(body)
373        } else {
374            builder
375        };
376
377        let response = builder.send().await?;
378        Ok(response)
379    }
380
381    /// 执行流式 HTTP 请求
382    pub async fn request_stream(&self, request: HttpRequest) -> Result<HttpResponseStream> {
383        let response = self.request(request).await?;
384        Ok(HttpResponseStream::new(response, self.abort_flag.clone()))
385    }
386
387    /// 执行 SSE 流式请求
388    pub async fn request_sse(&self, request: HttpRequest) -> Result<SseStream> {
389        let builder = self.client.post(&request.url);
390
391        let builder = request
392            .headers
393            .iter()
394            .fold(builder, |b, (k, v)| b.header(k, v));
395
396        let builder = if let Some(body) = &request.body {
397            builder.json(body)
398        } else {
399            builder
400        };
401
402        let builder = builder.header("Accept", "text/event-stream");
403
404        let response = builder.send().await?;
405
406        let status = response.status();
407        if !status.is_success() {
408            let body = response.text().await.unwrap_or_default();
409            return Err(anyhow!(
410                "SSE request failed with status {}: {}",
411                status,
412                body
413            ));
414        }
415
416        Ok(SseStream::new(response, self.abort_flag.clone()))
417    }
418
419    /// 检查状态码是否可重试
420    fn is_retryable_status(status: StatusCode) -> bool {
421        matches!(status.as_u16(), 429 | 500 | 502 | 503 | 504)
422    }
423
424    /// 检查错误是否可重试
425    fn is_retryable_error(error: &anyhow::Error) -> bool {
426        let msg = error.to_string().to_lowercase();
427        msg.contains("timeout")
428            || msg.contains("connection")
429            || msg.contains("network")
430            || msg.contains("429")
431            || msg.contains("overloaded")
432    }
433}
434
435impl Default for HttpAdapter {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
441/// SSE 流
442///
443/// 专门处理 Server-Sent Events 的流式响应
444pub struct SseStream {
445    response: Response,
446    parser: SseParser,
447    abort_flag: Arc<AtomicBool>,
448    done: bool,
449}
450
451impl SseStream {
452    /// 创建新的 SSE 流
453    pub fn new(response: Response, abort_flag: Arc<AtomicBool>) -> Self {
454        Self {
455            response,
456            parser: SseParser::new(),
457            abort_flag,
458            done: false,
459        }
460    }
461
462    /// 获取下一个 SSE 事件
463    pub async fn next_event(&mut self) -> Result<Option<SseEvent>> {
464        loop {
465            if self.abort_flag.load(Ordering::Relaxed) {
466                return Ok(None);
467            }
468
469            if self.done {
470                let remaining = self.parser.finish()?;
471                if remaining.is_empty() {
472                    return Ok(None);
473                }
474                // 返回第一个剩余事件
475                return Ok(remaining.into_iter().next());
476            }
477
478            match self.response.chunk().await? {
479                Some(chunk) => {
480                    let events = self.parser.push(&chunk)?;
481                    if !events.is_empty() {
482                        return Ok(Some(events.into_iter().next().unwrap()));
483                    }
484                }
485                None => {
486                    self.done = true;
487                }
488            }
489        }
490    }
491
492    /// 收集所有事件
493    pub async fn collect_events(&mut self) -> Result<Vec<SseEvent>> {
494        let mut events = Vec::new();
495        while let Some(event) = self.next_event().await? {
496            events.push(event);
497        }
498        Ok(events)
499    }
500}
501
/// HTTP adapter trait.
///
/// Defines the standard interface for issuing HTTP requests. Implementors must
/// be `Send + Sync`; methods are async via `#[async_trait]`.
#[async_trait]
pub trait HttpAdapterTrait: Send + Sync {
    /// Performs an HTTP GET request to `url` and returns a `String`
    /// (the `HttpAdapter` implementation returns the response body text).
    async fn get(&self, url: &str) -> Result<String>;

    /// Performs an HTTP POST request to `url` with the JSON `body` and returns
    /// a `String` (the `HttpAdapter` implementation returns the response body text).
    async fn post(&self, url: &str, body: serde_json::Value) -> Result<String>;

    /// Performs a streaming HTTP POST request and returns an [`HttpResponseStream`].
    async fn post_stream(&self, url: &str, body: serde_json::Value) -> Result<HttpResponseStream>;

    /// Performs a Server-Sent Events POST request and returns an [`SseStream`].
    async fn post_sse(&self, url: &str, body: serde_json::Value) -> Result<SseStream>;
}
519
520#[async_trait]
521impl HttpAdapterTrait for HttpAdapter {
522    async fn get(&self, url: &str) -> Result<String> {
523        let request = HttpRequest::get(url);
524        let response = self.request(request).await?;
525        let text = response.text().await?;
526        Ok(text)
527    }
528
529    async fn post(&self, url: &str, body: serde_json::Value) -> Result<String> {
530        let request = HttpRequest::post(url, body);
531        let response = self.request(request).await?;
532        let text = response.text().await?;
533        Ok(text)
534    }
535
536    async fn post_stream(&self, url: &str, body: serde_json::Value) -> Result<HttpResponseStream> {
537        let request = HttpRequest::post(url, body);
538        self.request_stream(request).await
539    }
540
541    async fn post_sse(&self, url: &str, body: serde_json::Value) -> Result<SseStream> {
542        let request = HttpRequest::post(url, body).header("Accept", "text/event-stream");
543        self.request_sse(request).await
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration values match the documented defaults.
    #[test]
    fn test_http_config_default() {
        let HttpConfig {
            timeout_ms,
            connect_timeout_ms,
            max_retries,
            ..
        } = HttpConfig::default();
        assert_eq!(timeout_ms, 30000);
        assert_eq!(connect_timeout_ms, 10000);
        assert_eq!(max_retries, 3);
    }

    /// Builder methods append headers without touching the URL.
    #[test]
    fn test_http_request_builder() {
        let req = HttpRequest::get("https://api.example.com")
            .bearer_auth("token123")
            .header("X-Custom", "value");

        assert_eq!(req.url, "https://api.example.com");
        assert_eq!(req.headers.len(), 2);
    }

    /// `post` stores the method and body as given.
    #[test]
    fn test_http_request_post() {
        let payload = serde_json::json!({"key": "value"});
        let req = HttpRequest::post("https://api.example.com", payload.clone());

        assert_eq!(req.url, "https://api.example.com");
        assert!(matches!(req.method, HttpMethod::Post));
        assert_eq!(req.body, Some(payload));
    }

    /// 429 and the listed 5xx codes are retryable; common 4xx codes are not.
    #[test]
    fn test_is_retryable_status() {
        let retryable = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ];
        for status in retryable {
            assert!(
                HttpAdapter::is_retryable_status(status),
                "{} should be retryable",
                status
            );
        }

        let fatal = [
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::NOT_FOUND,
        ];
        for status in fatal {
            assert!(
                !HttpAdapter::is_retryable_status(status),
                "{} should not be retryable",
                status
            );
        }
    }

    /// A fresh adapter starts un-aborted.
    #[test]
    fn test_http_adapter_creation() {
        assert!(!HttpAdapter::new().is_aborted());
    }

    /// `abort` sets the flag and `reset` clears it.
    #[test]
    fn test_http_adapter_abort() {
        let adapter = HttpAdapter::new();

        adapter.abort();
        assert!(adapter.is_aborted());

        adapter.reset();
        assert!(!adapter.is_aborted());
    }
}