Skip to main content

rust_genai/
sse.rs

1//! SSE (Server-Sent Events) stream decoding utilities.
2
3use std::collections::VecDeque;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use bytes::{Buf, Bytes, BytesMut};
11use futures_util::Stream;
12use memchr::memmem::Finder;
13use serde::de::DeserializeOwned;
14
15use crate::error::{Error, Result};
16use rust_genai_types::response::GenerateContentResponse;
17
/// A single decoded Server-Sent Event.
///
/// Events compare by value, which makes them convenient to assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerSentEvent {
    /// Value of the `event:` field, if the event carried a non-empty one.
    pub event: Option<String>,
    /// All `data:` field values of the event, joined with `\n`.
    pub data: String,
    /// Value of the `id:` field, if the event carried a non-empty one.
    pub id: Option<String>,
}
25
26/// SSE 解码器。
27pub struct SseDecoder {
28    buffer: BytesMut,
29    finder_lf: Finder<'static>,
30    finder_cr: Finder<'static>,
31    finder_crlf: Finder<'static>,
32}
33
34impl SseDecoder {
35    /// 创建新的 SSE 解码器。
36    #[must_use]
37    pub fn new() -> Self {
38        Self {
39            buffer: BytesMut::with_capacity(8192),
40            finder_lf: Finder::new(b"\n\n"),
41            finder_cr: Finder::new(b"\r\r"),
42            finder_crlf: Finder::new(b"\r\n\r\n"),
43        }
44    }
45
46    /// 解码一个 chunk,返回完整的 SSE 事件。
47    pub fn decode(&mut self, chunk: &[u8]) -> Vec<Result<ServerSentEvent>> {
48        self.buffer.extend_from_slice(chunk);
49        let mut events = Vec::with_capacity(4);
50
51        while let Some((pos, len)) = self.find_delimiter(&self.buffer) {
52            let event_bytes = self.buffer.split_to(pos);
53            self.buffer.advance(len);
54
55            match Self::parse_lines(&event_bytes) {
56                Ok(Some(event)) => events.push(Ok(event)),
57                Ok(None) => {}
58                Err(err) => events.push(Err(err)),
59            }
60        }
61
62        events
63    }
64
65    fn find_delimiter(&self, buf: &[u8]) -> Option<(usize, usize)> {
66        let best = self.finder_crlf.find(buf).map(|pos| (pos, 4));
67        let best = self
68            .finder_lf
69            .find(buf)
70            .map_or(best, |pos| Some(pick_min(best, pos, 2)));
71        self.finder_cr
72            .find(buf)
73            .map_or(best, |pos| Some(pick_min(best, pos, 2)))
74    }
75
76    fn parse_lines(data: &[u8]) -> Result<Option<ServerSentEvent>> {
77        if data.is_empty() {
78            return Ok(None);
79        }
80
81        let text = std::str::from_utf8(data).map_err(|err| Error::Parse {
82            message: err.to_string(),
83        })?;
84
85        let mut event: Option<String> = None;
86        let mut id: Option<String> = None;
87        let mut data_lines: Vec<String> = Vec::with_capacity(4);
88        let mut has_field = false;
89
90        for line in text.split('\n') {
91            let line = line.trim_end_matches('\r');
92            if line.is_empty() {
93                continue;
94            }
95            if line.starts_with(':') {
96                continue;
97            }
98
99            let (field, value) = match line.split_once(':') {
100                Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
101                None => (line, ""),
102            };
103
104            match field {
105                "event" => {
106                    has_field = true;
107                    if !value.is_empty() {
108                        event = Some(value.to_string());
109                    }
110                }
111                "data" => {
112                    has_field = true;
113                    data_lines.push(value.to_string());
114                }
115                "id" => {
116                    has_field = true;
117                    if !value.is_empty() {
118                        id = Some(value.to_string());
119                    }
120                }
121                _ => {}
122            }
123        }
124
125        if !has_field {
126            return Ok(None);
127        }
128
129        Ok(Some(ServerSentEvent {
130            event,
131            data: data_lines.join("\n"),
132            id,
133        }))
134    }
135}
136
137impl Default for SseDecoder {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
/// Chooses between the current best delimiter match and a new candidate.
///
/// Each match is an `(offset, delimiter_len)` pair. The candidate `(pos, len)`
/// wins only when there is no current best or when it starts strictly earlier;
/// on equal offsets the existing best is kept.
const fn pick_min(best: Option<(usize, usize)>, pos: usize, len: usize) -> (usize, usize) {
    if let Some((best_pos, best_len)) = best {
        if best_pos <= pos {
            return (best_pos, best_len);
        }
    }
    (pos, len)
}
155
156/// SSE JSON Stream 包装器(泛型)。
157pub struct SseJsonStream<T> {
158    stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
159    decoder: SseDecoder,
160    pending: VecDeque<Result<ServerSentEvent>>,
161    done: bool,
162    done_signal: Option<Arc<AtomicBool>>,
163    _marker: PhantomData<T>,
164}
165
166impl<T> Unpin for SseJsonStream<T> {}
167
168impl<T> SseJsonStream<T> {
169    /// 从 HTTP 响应创建 SSE 流。
170    #[must_use]
171    pub fn new(response: reqwest::Response) -> Self {
172        Self::with_done_signal(response, None)
173    }
174
175    /// 从 HTTP 响应创建带显式完成标记的 SSE 流。
176    #[must_use]
177    pub fn with_done_signal(
178        response: reqwest::Response,
179        done_signal: Option<Arc<AtomicBool>>,
180    ) -> Self {
181        Self {
182            stream: Box::pin(response.bytes_stream()),
183            decoder: SseDecoder::new(),
184            pending: VecDeque::new(),
185            done: false,
186            done_signal,
187            _marker: PhantomData,
188        }
189    }
190}
191
192impl<T> Stream for SseJsonStream<T>
193where
194    T: DeserializeOwned,
195{
196    type Item = Result<T>;
197
198    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199        let this = self.get_mut();
200        loop {
201            if let Some(item) = this.pending.pop_front() {
202                match item {
203                    Err(err) => return Poll::Ready(Some(Err(err))),
204                    Ok(event) => {
205                        if event.data == "[DONE]" {
206                            if let Some(done_signal) = &this.done_signal {
207                                done_signal.store(true, Ordering::Relaxed);
208                            }
209                            this.done = true;
210                            continue;
211                        }
212
213                        let parsed = serde_json::from_str::<T>(&event.data).map_err(Error::from)?;
214                        return Poll::Ready(Some(Ok(parsed)));
215                    }
216                }
217            }
218
219            if this.done {
220                return Poll::Ready(None);
221            }
222
223            match this.stream.as_mut().poll_next(cx) {
224                Poll::Pending => return Poll::Pending,
225                Poll::Ready(None) => return Poll::Ready(None),
226                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
227                Poll::Ready(Some(Ok(bytes))) => {
228                    let events = this.decoder.decode(&bytes);
229                    for event in events {
230                        this.pending.push_back(event);
231                    }
232                }
233            }
234        }
235    }
236}
237
238/// 便捷函数:从 reqwest Response 创建 SSE 流。
239pub fn parse_sse_stream(
240    response: reqwest::Response,
241) -> impl Stream<Item = Result<GenerateContentResponse>> {
242    parse_sse_stream_with::<GenerateContentResponse>(response)
243}
244
245/// 泛型 SSE JSON 流解析器。
246#[must_use]
247pub fn parse_sse_stream_with<T>(response: reqwest::Response) -> SseJsonStream<T>
248where
249    T: DeserializeOwned,
250{
251    SseJsonStream::new(response)
252}
253
254/// 泛型 SSE JSON 流解析器,支持追踪显式 `[DONE]` 标记。
255#[must_use]
256pub fn parse_sse_stream_with_done_signal<T>(
257    response: reqwest::Response,
258    done_signal: Arc<AtomicBool>,
259) -> SseJsonStream<T>
260where
261    T: DeserializeOwned,
262{
263    SseJsonStream::with_done_signal(response, Some(done_signal))
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use serde_json::Value;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// JSON payload shared by several decoder tests.
    const HELLO_JSON: &str = r#"{"text":"Hello"}"#;

    /// Returns the `data` of the `index`-th decoded event, panicking if that
    /// entry is an error.
    fn payload(events: &[Result<ServerSentEvent>], index: usize) -> &str {
        events[index].as_ref().unwrap().data.as_str()
    }

    /// Base mock response: HTTP 200 declared as an event stream.
    fn sse_template() -> ResponseTemplate {
        ResponseTemplate::new(200).insert_header("content-type", "text/event-stream")
    }

    /// Mounts `template` for GET requests on `server`, performs the request
    /// and wraps the response in a JSON SSE stream.
    async fn open_stream(server: &MockServer, template: ResponseTemplate) -> SseJsonStream<Value> {
        Mock::given(method("GET"))
            .respond_with(template)
            .mount(server)
            .await;

        let response = reqwest::Client::new()
            .get(server.uri())
            .send()
            .await
            .unwrap();
        parse_sse_stream_with::<Value>(response)
    }

    #[test]
    fn test_sse_decoder_basic() {
        let chunk = b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\"World\"}\n\n";
        let events = SseDecoder::new().decode(chunk);
        assert_eq!(events.len(), 2);
        assert_eq!(payload(&events, 0), HELLO_JSON);
        assert_eq!(payload(&events, 1), r#"{"text":"World"}"#);
    }

    #[test]
    fn test_sse_decoder_crlf() {
        let chunk = b"data: {\"text\":\"Hello\"}\r\n\r\n";
        let events = SseDecoder::new().decode(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(payload(&events, 0), HELLO_JSON);
    }

    #[test]
    fn test_sse_decoder_default_works() {
        let chunk = b"data: {\"text\":\"Hello\"}\n\n";
        let events = SseDecoder::default().decode(chunk);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_sse_decoder_line_without_colon_and_empty_lines() {
        let events = SseDecoder::new().decode(b"data\n\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(payload(&events, 0), "");
    }

    #[test]
    fn test_sse_decoder_only_comments_returns_empty() {
        let events = SseDecoder::new().decode(b":comment\n\n");
        assert!(events.is_empty());
    }

    #[test]
    fn test_sse_done_signal() {
        let events = SseDecoder::new().decode(b"data: [DONE]\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(payload(&events, 0), "[DONE]");
    }

    #[test]
    fn test_sse_double_cr() {
        let chunk = b"data: {\"text\":\"Hello\"}\r\r";
        let events = SseDecoder::new().decode(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(payload(&events, 0), HELLO_JSON);
    }

    #[test]
    fn test_sse_decoder_event_and_id() {
        let chunk = b":comment\nid: 7\nevent: update\ndata: line1\ndata: line2\n\n";
        let events = SseDecoder::new().decode(chunk);
        assert_eq!(events.len(), 1);

        let decoded = events[0].as_ref().unwrap();
        assert_eq!(decoded.event.as_deref(), Some("update"));
        assert_eq!(decoded.id.as_deref(), Some("7"));
        assert_eq!(decoded.data, "line1\nline2");
    }

    #[test]
    fn test_sse_decoder_invalid_utf8_and_empty() {
        // The same decoder is reused on purpose: the second call checks that
        // nothing is left over from the broken event.
        let mut decoder = SseDecoder::new();

        let broken = decoder.decode(b"data: \xFF\xFF\n\n");
        assert_eq!(broken.len(), 1);
        assert!(broken[0].as_ref().is_err());

        let blank = decoder.decode(b"\n\n");
        assert!(blank.is_empty());
    }

    #[tokio::test]
    async fn test_sse_json_stream_invalid_utf8() {
        let server = MockServer::start().await;
        let body = vec![0xFF, 0xFF, b'\n', b'\n'];
        let mut stream = open_stream(&server, sse_template().set_body_bytes(body)).await;

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Parse { .. }));
    }

    #[test]
    fn test_pick_min_prefers_smaller_position() {
        assert_eq!(pick_min(Some((5, 2)), 2, 4), (2, 4));
        assert_eq!(pick_min(Some((2, 2)), 5, 4), (2, 2));
    }

    #[tokio::test]
    async fn test_sse_json_stream_parses_and_done() {
        let server = MockServer::start().await;
        let body = "data: {\"value\":1}\n\ndata: [DONE]\n\n";
        let mut stream = open_stream(&server, sse_template().set_body_string(body)).await;

        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first["value"], 1);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_sse_json_stream_invalid_json() {
        let server = MockServer::start().await;
        let body = "data: {bad json}\n\n";
        let mut stream = open_stream(&server, sse_template().set_body_string(body)).await;

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Serialization { .. }));
    }
}
435}