// sh_layer1/streaming/mod.rs

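//! Streaming module.
//!
//! Handles streaming responses delivered over SSE, WebSocket and HTTP.
//!
//! ## Module layout
//! - `sse`: SSE parser and event handling
//! - `websocket`: WebSocket adapter
//! - `http`: HTTP streaming adapter
//! - `providers`: provider-specific streaming formats for LLM APIs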
10
11pub mod http;
12pub mod providers;
13pub mod sse;
14pub mod websocket;
15
16// Re-export from submodules
17pub use http::{HttpAdapter, HttpConfig, HttpRequest, HttpResponseStream, SseStream};
18pub use providers::{
19    ContentBlockType, ContentDelta, StreamEvent, StreamProvider, StreamState, StreamUsage,
20};
21pub use sse::{SseEvent, SseParser};
22pub use websocket::{WebSocketAdapter, WebSocketConfig, WebSocketMessage, WebSocketMessageStream};
23
24use anyhow::Result;
25use futures::Stream;
26use reqwest::Response;
27use std::collections::VecDeque;
28use std::pin::Pin;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31use std::task::{Context, Poll};
32
33// Re-export provider types used in parse_sse_event
34pub use providers::{AnthropicStreamEvent, OllamaStreamChunk, OpenAiStreamChunk};
35
36/// 流处理器(兼容旧 API)
37pub struct StreamHandler;
38
39impl StreamHandler {
40    /// 创建 SSE 流
41    pub fn create_sse_stream(
42        source: impl Stream<Item = Result<String>> + Send + 'static,
43    ) -> impl Stream<Item = Result<String>> {
44        use futures::StreamExt;
45
46        source.map(|item| match item {
47            Ok(data) => Ok(format!("data: {}\n\n", data)),
48            Err(e) => Err(e),
49        })
50    }
51}
52
53/// 可中断的流式响应
54pub struct AbortableStream<S> {
55    inner: S,
56    abort_flag: Arc<AtomicBool>,
57}
58
59impl<S> AbortableStream<S> {
60    /// 创建可中断的流
61    pub fn new(inner: S, abort_flag: Arc<AtomicBool>) -> Self {
62        Self { inner, abort_flag }
63    }
64
65    /// 检查是否已中断
66    pub fn is_aborted(&self) -> bool {
67        self.abort_flag.load(Ordering::Relaxed)
68    }
69}
70
71impl<S, T> Stream for AbortableStream<S>
72where
73    S: Stream<Item = Result<T>> + Unpin,
74{
75    type Item = Result<T>;
76
77    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78        if self.abort_flag.load(Ordering::Relaxed) {
79            return Poll::Ready(None);
80        }
81        Pin::new(&mut self.inner).poll_next(cx)
82    }
83}
84
85/// 流式消息(兼容旧 API)
86pub struct MessageStream {
87    response: Response,
88    parser: SseParser,
89    pending: VecDeque<StreamEvent>,
90    done: bool,
91    state: StreamState,
92    provider: StreamProvider,
93}
94
95impl MessageStream {
96    /// 创建新的流式消息
97    pub fn new(response: Response, provider: StreamProvider, model: String) -> Self {
98        let parser = SseParser::new().with_context(
99            match provider {
100                StreamProvider::Anthropic | StreamProvider::AnthropicCompatible => "Anthropic",
101                StreamProvider::OpenAI | StreamProvider::OpenAICompatible => "OpenAI",
102                StreamProvider::Gemini => "Gemini",
103                StreamProvider::AzureOpenAI => "AzureOpenAI",
104                StreamProvider::Bedrock => "Bedrock",
105                StreamProvider::Ollama => "Ollama",
106            },
107            &model,
108        );
109        Self {
110            response,
111            parser,
112            pending: VecDeque::new(),
113            done: false,
114            state: StreamState::new(model),
115            provider,
116        }
117    }
118
119    /// 获取下一个事件
120    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>> {
121        loop {
122            if let Some(event) = self.pending.pop_front() {
123                return Ok(Some(event));
124            }
125
126            if self.done {
127                let _remaining = self.parser.finish()?;
128                for event in self.state.finish() {
129                    self.pending.push_back(event);
130                }
131                if let Some(event) = self.pending.pop_front() {
132                    return Ok(Some(event));
133                }
134                return Ok(None);
135            }
136
137            match self.response.chunk().await? {
138                Some(chunk) => {
139                    let sse_events = self.parser.push(&chunk)?;
140                    for sse_event in sse_events {
141                        let events = self.parse_sse_event(&sse_event)?;
142                        self.pending.extend(events);
143                    }
144                }
145                None => {
146                    self.done = true;
147                }
148            }
149        }
150    }
151
152    fn parse_sse_event(
153        &mut self,
154        event: &crate::streaming::sse::SseEvent,
155    ) -> Result<Vec<StreamEvent>> {
156        use crate::streaming::providers::*;
157
158        match self.provider {
159            StreamProvider::Anthropic | StreamProvider::AnthropicCompatible => {
160                let anthropic_event: AnthropicStreamEvent = serde_json::from_str(&event.data)?;
161                Ok(self.state.ingest_anthropic(anthropic_event))
162            }
163            StreamProvider::OpenAI | StreamProvider::OpenAICompatible => {
164                let openai_chunk: OpenAiStreamChunk = serde_json::from_str(&event.data)?;
165                Ok(self.state.ingest_openai(openai_chunk))
166            }
167            StreamProvider::Gemini => {
168                let openai_chunk: OpenAiStreamChunk = serde_json::from_str(&event.data)?;
169                Ok(self.state.ingest_openai(openai_chunk))
170            }
171            StreamProvider::AzureOpenAI => {
172                let openai_chunk: OpenAiStreamChunk = serde_json::from_str(&event.data)?;
173                Ok(self.state.ingest_openai(openai_chunk))
174            }
175            StreamProvider::Bedrock => {
176                let anthropic_event: AnthropicStreamEvent = serde_json::from_str(&event.data)?;
177                Ok(self.state.ingest_anthropic(anthropic_event))
178            }
179            StreamProvider::Ollama => {
180                let ollama_chunk: OllamaStreamChunk = serde_json::from_str(&event.data)?;
181                Ok(self.state.ingest_ollama(ollama_chunk))
182            }
183        }
184    }
185
186    /// 收集所有文本内容
187    pub async fn collect_text(&mut self) -> Result<String> {
188        let mut text = String::new();
189        while let Some(event) = self.next_event().await? {
190            if let StreamEvent::ContentBlockDelta {
191                delta: ContentDelta::Text(t),
192                ..
193            } = event
194            {
195                text.push_str(&t);
196            }
197        }
198        Ok(text)
199    }
200}
201
202/// 回调类型
203pub type OnChunkCallback = Box<dyn Fn(&str) + Send + Sync>;
204
205/// 带回调的流式响应
206pub struct CallbackStream {
207    inner: MessageStream,
208    on_chunk: Option<OnChunkCallback>,
209    abort_flag: Arc<AtomicBool>,
210}
211
212impl CallbackStream {
213    /// 创建带回调的流
214    pub fn new(inner: MessageStream, on_chunk: Option<OnChunkCallback>) -> Self {
215        Self {
216            inner,
217            on_chunk,
218            abort_flag: Arc::new(AtomicBool::new(false)),
219        }
220    }
221
222    /// 获取中断标志
223    pub fn abort_flag(&self) -> Arc<AtomicBool> {
224        Arc::clone(&self.abort_flag)
225    }
226
227    /// 请求中断
228    pub fn abort(&self) {
229        self.abort_flag.store(true, Ordering::Relaxed);
230    }
231
232    /// 获取下一个事件
233    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>> {
234        if self.abort_flag.load(Ordering::Relaxed) {
235            return Ok(None);
236        }
237
238        let event = self.inner.next_event().await?;
239
240        // 触发回调
241        if let Some(ref callback) = self.on_chunk {
242            if let Some(StreamEvent::ContentBlockDelta {
243                delta: ContentDelta::Text(t),
244                ..
245            }) = event.as_ref()
246            {
247                callback(t);
248            }
249        }
250
251        Ok(event)
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    #[test]
    fn abortable_stream_respects_abort_flag() {
        let abort_flag = Arc::new(AtomicBool::new(true));
        let inner = stream::iter(vec![Ok("test".to_string())]);
        let mut aborted = AbortableStream::new(inner, abort_flag);

        let result = futures::executor::block_on_stream(&mut aborted).next();
        assert!(
            result.is_none(),
            "aborted stream should return None immediately"
        );
    }

    #[test]
    fn abortable_stream_passes_items_through_until_aborted() {
        let abort_flag = Arc::new(AtomicBool::new(false));
        let items: Vec<Result<String>> = vec![Ok("first".to_string()), Ok("second".to_string())];
        let wrapped = AbortableStream::new(stream::iter(items), Arc::clone(&abort_flag));
        let mut blocking = futures::executor::block_on_stream(wrapped);

        let first = blocking
            .next()
            .expect("stream should yield before abort")
            .expect("item should be Ok");
        assert_eq!(first, "first");

        abort_flag.store(true, Ordering::Relaxed);
        assert!(
            blocking.next().is_none(),
            "stream should end once the flag is set"
        );
    }

    #[test]
    fn create_sse_stream_frames_single_line_payload() {
        let source = stream::iter(vec![Ok::<_, anyhow::Error>("hello".to_string())]);
        let framed = futures::executor::block_on_stream(Box::pin(
            StreamHandler::create_sse_stream(source),
        ))
        .collect::<Result<Vec<String>>>()
        .expect("source yields no errors");
        assert_eq!(framed, vec!["data: hello\n\n".to_string()]);
    }
}