Skip to main content

rust_genai/
sse.rs

1//! SSE (Server-Sent Events) stream decoding utilities.
2
3use std::collections::VecDeque;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use bytes::{Buf, Bytes, BytesMut};
9use futures_util::Stream;
10use memchr::memmem::Finder;
11use serde::de::DeserializeOwned;
12
13use crate::error::{Error, Result};
14use rust_genai_types::response::GenerateContentResponse;
15
/// A single Server-Sent Event as produced by [`SseDecoder`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerSentEvent {
    /// Value of the last non-empty `event:` field of the event, if any.
    pub event: Option<String>,
    /// Values of all `data:` fields of the event, joined with `\n`.
    pub data: String,
    /// Value of the last non-empty `id:` field of the event, if any.
    pub id: Option<String>,
}
23
24/// SSE 解码器。
25pub struct SseDecoder {
26    buffer: BytesMut,
27    finder_lf: Finder<'static>,
28    finder_cr: Finder<'static>,
29    finder_crlf: Finder<'static>,
30}
31
32impl SseDecoder {
33    /// 创建新的 SSE 解码器。
34    #[must_use]
35    pub fn new() -> Self {
36        Self {
37            buffer: BytesMut::with_capacity(8192),
38            finder_lf: Finder::new(b"\n\n"),
39            finder_cr: Finder::new(b"\r\r"),
40            finder_crlf: Finder::new(b"\r\n\r\n"),
41        }
42    }
43
44    /// 解码一个 chunk,返回完整的 SSE 事件。
45    pub fn decode(&mut self, chunk: &[u8]) -> Vec<Result<ServerSentEvent>> {
46        self.buffer.extend_from_slice(chunk);
47        let mut events = Vec::with_capacity(4);
48
49        while let Some((pos, len)) = self.find_delimiter(&self.buffer) {
50            let event_bytes = self.buffer.split_to(pos);
51            self.buffer.advance(len);
52
53            match Self::parse_lines(&event_bytes) {
54                Ok(Some(event)) => events.push(Ok(event)),
55                Ok(None) => {}
56                Err(err) => events.push(Err(err)),
57            }
58        }
59
60        events
61    }
62
63    fn find_delimiter(&self, buf: &[u8]) -> Option<(usize, usize)> {
64        let best = self.finder_crlf.find(buf).map(|pos| (pos, 4));
65        let best = self
66            .finder_lf
67            .find(buf)
68            .map_or(best, |pos| Some(pick_min(best, pos, 2)));
69        self.finder_cr
70            .find(buf)
71            .map_or(best, |pos| Some(pick_min(best, pos, 2)))
72    }
73
74    fn parse_lines(data: &[u8]) -> Result<Option<ServerSentEvent>> {
75        if data.is_empty() {
76            return Ok(None);
77        }
78
79        let text = std::str::from_utf8(data).map_err(|err| Error::Parse {
80            message: err.to_string(),
81        })?;
82
83        let mut event: Option<String> = None;
84        let mut id: Option<String> = None;
85        let mut data_lines: Vec<String> = Vec::with_capacity(4);
86        let mut has_field = false;
87
88        for line in text.split('\n') {
89            let line = line.trim_end_matches('\r');
90            if line.is_empty() {
91                continue;
92            }
93            if line.starts_with(':') {
94                continue;
95            }
96
97            let (field, value) = match line.split_once(':') {
98                Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
99                None => (line, ""),
100            };
101
102            match field {
103                "event" => {
104                    has_field = true;
105                    if !value.is_empty() {
106                        event = Some(value.to_string());
107                    }
108                }
109                "data" => {
110                    has_field = true;
111                    data_lines.push(value.to_string());
112                }
113                "id" => {
114                    has_field = true;
115                    if !value.is_empty() {
116                        id = Some(value.to_string());
117                    }
118                }
119                _ => {}
120            }
121        }
122
123        if !has_field {
124            return Ok(None);
125        }
126
127        Ok(Some(ServerSentEvent {
128            event,
129            data: data_lines.join("\n"),
130            id,
131        }))
132    }
133}
134
135impl Default for SseDecoder {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
/// Chooses between the best delimiter found so far and a new candidate at `pos` of
/// length `len`.
///
/// Keeps whichever starts earlier. On equal positions the existing `best` wins.
const fn pick_min(best: Option<(usize, usize)>, pos: usize, len: usize) -> (usize, usize) {
    match best {
        Some((best_pos, best_len)) if best_pos <= pos => (best_pos, best_len),
        _ => (pos, len),
    }
}
153
154/// SSE JSON Stream 包装器(泛型)。
155pub struct SseJsonStream<T> {
156    stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
157    decoder: SseDecoder,
158    pending: VecDeque<Result<ServerSentEvent>>,
159    done: bool,
160    _marker: PhantomData<T>,
161}
162
163impl<T> Unpin for SseJsonStream<T> {}
164
165impl<T> SseJsonStream<T> {
166    /// 从 HTTP 响应创建 SSE 流。
167    #[must_use]
168    pub fn new(response: reqwest::Response) -> Self {
169        Self {
170            stream: Box::pin(response.bytes_stream()),
171            decoder: SseDecoder::new(),
172            pending: VecDeque::new(),
173            done: false,
174            _marker: PhantomData,
175        }
176    }
177}
178
179impl<T> Stream for SseJsonStream<T>
180where
181    T: DeserializeOwned,
182{
183    type Item = Result<T>;
184
185    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
186        let this = self.get_mut();
187        loop {
188            if let Some(item) = this.pending.pop_front() {
189                match item {
190                    Err(err) => return Poll::Ready(Some(Err(err))),
191                    Ok(event) => {
192                        if event.data == "[DONE]" {
193                            this.done = true;
194                            continue;
195                        }
196
197                        let parsed = serde_json::from_str::<T>(&event.data).map_err(Error::from)?;
198                        return Poll::Ready(Some(Ok(parsed)));
199                    }
200                }
201            }
202
203            if this.done {
204                return Poll::Ready(None);
205            }
206
207            match this.stream.as_mut().poll_next(cx) {
208                Poll::Pending => return Poll::Pending,
209                Poll::Ready(None) => return Poll::Ready(None),
210                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
211                Poll::Ready(Some(Ok(bytes))) => {
212                    let events = this.decoder.decode(&bytes);
213                    for event in events {
214                        this.pending.push_back(event);
215                    }
216                }
217            }
218        }
219    }
220}
221
222/// 便捷函数:从 reqwest Response 创建 SSE 流。
223pub fn parse_sse_stream(
224    response: reqwest::Response,
225) -> impl Stream<Item = Result<GenerateContentResponse>> {
226    parse_sse_stream_with::<GenerateContentResponse>(response)
227}
228
229/// 泛型 SSE JSON 流解析器。
230#[must_use]
231pub fn parse_sse_stream_with<T>(response: reqwest::Response) -> SseJsonStream<T>
232where
233    T: DeserializeOwned,
234{
235    SseJsonStream::new(response)
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use serde_json::Value;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Decodes `chunk` with a fresh decoder and unwraps every resulting event.
    fn decode_ok(chunk: &[u8]) -> Vec<ServerSentEvent> {
        SseDecoder::new()
            .decode(chunk)
            .into_iter()
            .map(|event| event.expect("event should decode"))
            .collect()
    }

    /// Starts a mock server that answers any GET with `body` as an SSE payload, then
    /// requests it.
    ///
    /// The server is returned together with the response so it stays alive while the
    /// body is being streamed.
    async fn serve_sse(body: impl Into<Vec<u8>>) -> (MockServer, reqwest::Response) {
        let body: Vec<u8> = body.into();
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_bytes(body),
            )
            .mount(&server)
            .await;
        let response = reqwest::Client::new()
            .get(server.uri())
            .send()
            .await
            .expect("mock server should respond");
        (server, response)
    }

    #[test]
    fn test_sse_decoder_basic() {
        let events = decode_ok(b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\"World\"}\n\n");
        let payloads: Vec<&str> = events.iter().map(|event| event.data.as_str()).collect();
        assert_eq!(payloads, [r#"{"text":"Hello"}"#, r#"{"text":"World"}"#]);
    }

    #[test]
    fn test_sse_decoder_crlf() {
        let events = decode_ok(b"data: {\"text\":\"Hello\"}\r\n\r\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, r#"{"text":"Hello"}"#);
    }

    #[test]
    fn test_sse_decoder_default_works() {
        let events = SseDecoder::default().decode(b"data: {\"text\":\"Hello\"}\n\n");
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_sse_decoder_line_without_colon_and_empty_lines() {
        let events = decode_ok(b"data\n\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "");
    }

    #[test]
    fn test_sse_decoder_only_comments_returns_empty() {
        assert!(SseDecoder::new().decode(b":comment\n\n").is_empty());
    }

    #[test]
    fn test_sse_done_signal() {
        let events = decode_ok(b"data: [DONE]\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "[DONE]");
    }

    #[test]
    fn test_sse_double_cr() {
        let events = decode_ok(b"data: {\"text\":\"Hello\"}\r\r");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, r#"{"text":"Hello"}"#);
    }

    #[test]
    fn test_sse_decoder_event_and_id() {
        let events = decode_ok(b":comment\nid: 7\nevent: update\ndata: line1\ndata: line2\n\n");
        assert_eq!(events.len(), 1);
        let event = &events[0];
        assert_eq!(event.event.as_deref(), Some("update"));
        assert_eq!(event.id.as_deref(), Some("7"));
        assert_eq!(event.data, "line1\nline2");
    }

    #[test]
    fn test_sse_decoder_invalid_utf8_and_empty() {
        let mut decoder = SseDecoder::new();

        let events = decoder.decode(b"data: \xFF\xFF\n\n");
        assert_eq!(events.len(), 1);
        assert!(events[0].is_err());

        assert!(decoder.decode(b"\n\n").is_empty());
    }

    #[tokio::test]
    async fn test_sse_json_stream_invalid_utf8() {
        let (_server, response) = serve_sse(vec![0xFF, 0xFF, b'\n', b'\n']).await;
        let mut stream = parse_sse_stream_with::<Value>(response);
        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Parse { .. }));
    }

    #[test]
    fn test_pick_min_prefers_smaller_position() {
        let cases = [(Some((5, 2)), 2, 4, (2, 4)), (Some((2, 2)), 5, 4, (2, 2))];
        for &(best, pos, len, expected) in &cases {
            assert_eq!(pick_min(best, pos, len), expected);
        }
    }

    #[tokio::test]
    async fn test_sse_json_stream_parses_and_done() {
        let (_server, response) = serve_sse("data: {\"value\":1}\n\ndata: [DONE]\n\n").await;
        let mut stream = parse_sse_stream_with::<Value>(response);
        let first = stream
            .next()
            .await
            .expect("stream should yield an item")
            .expect("item should parse");
        assert_eq!(first["value"], 1);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_sse_json_stream_invalid_json() {
        let (_server, response) = serve_sse("data: {bad json}\n\n").await;
        let mut stream = parse_sse_stream_with::<Value>(response);
        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Serialization { .. }));
    }
}