1use std::collections::VecDeque;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use bytes::{Buf, Bytes, BytesMut};
9use futures_util::Stream;
10use memchr::memmem::Finder;
11use serde::de::DeserializeOwned;
12
13use crate::error::{Error, Result};
14use rust_genai_types::response::GenerateContentResponse;
15
/// A single decoded server-sent event.
///
/// Produced by [`SseDecoder::decode`] once a complete event block (terminated
/// by a blank line) has been received.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerSentEvent {
    /// Value of the last non-empty `event:` field in the block, if any.
    pub event: Option<String>,
    /// All `data:` field values of the block, joined with `\n`.
    pub data: String,
    /// Value of the last non-empty `id:` field in the block, if any.
    pub id: Option<String>,
}
23
24pub struct SseDecoder {
26 buffer: BytesMut,
27 finder_lf: Finder<'static>,
28 finder_cr: Finder<'static>,
29 finder_crlf: Finder<'static>,
30}
31
32impl SseDecoder {
33 #[must_use]
35 pub fn new() -> Self {
36 Self {
37 buffer: BytesMut::with_capacity(8192),
38 finder_lf: Finder::new(b"\n\n"),
39 finder_cr: Finder::new(b"\r\r"),
40 finder_crlf: Finder::new(b"\r\n\r\n"),
41 }
42 }
43
44 pub fn decode(&mut self, chunk: &[u8]) -> Vec<Result<ServerSentEvent>> {
46 self.buffer.extend_from_slice(chunk);
47 let mut events = Vec::with_capacity(4);
48
49 while let Some((pos, len)) = self.find_delimiter(&self.buffer) {
50 let event_bytes = self.buffer.split_to(pos);
51 self.buffer.advance(len);
52
53 match Self::parse_lines(&event_bytes) {
54 Ok(Some(event)) => events.push(Ok(event)),
55 Ok(None) => {}
56 Err(err) => events.push(Err(err)),
57 }
58 }
59
60 events
61 }
62
63 fn find_delimiter(&self, buf: &[u8]) -> Option<(usize, usize)> {
64 let best = self.finder_crlf.find(buf).map(|pos| (pos, 4));
65 let best = self
66 .finder_lf
67 .find(buf)
68 .map_or(best, |pos| Some(pick_min(best, pos, 2)));
69 self.finder_cr
70 .find(buf)
71 .map_or(best, |pos| Some(pick_min(best, pos, 2)))
72 }
73
74 fn parse_lines(data: &[u8]) -> Result<Option<ServerSentEvent>> {
75 if data.is_empty() {
76 return Ok(None);
77 }
78
79 let text = std::str::from_utf8(data).map_err(|err| Error::Parse {
80 message: err.to_string(),
81 })?;
82
83 let mut event: Option<String> = None;
84 let mut id: Option<String> = None;
85 let mut data_lines: Vec<String> = Vec::with_capacity(4);
86 let mut has_field = false;
87
88 for line in text.split('\n') {
89 let line = line.trim_end_matches('\r');
90 if line.is_empty() {
91 continue;
92 }
93 if line.starts_with(':') {
94 continue;
95 }
96
97 let (field, value) = match line.split_once(':') {
98 Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
99 None => (line, ""),
100 };
101
102 match field {
103 "event" => {
104 has_field = true;
105 if !value.is_empty() {
106 event = Some(value.to_string());
107 }
108 }
109 "data" => {
110 has_field = true;
111 data_lines.push(value.to_string());
112 }
113 "id" => {
114 has_field = true;
115 if !value.is_empty() {
116 id = Some(value.to_string());
117 }
118 }
119 _ => {}
120 }
121 }
122
123 if !has_field {
124 return Ok(None);
125 }
126
127 Ok(Some(ServerSentEvent {
128 event,
129 data: data_lines.join("\n"),
130 id,
131 }))
132 }
133}
134
135impl Default for SseDecoder {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
/// Chooses between the current best `(position, length)` candidate and a new
/// one at `pos` with length `len`, keeping whichever starts earlier.
///
/// The existing candidate wins ties; with no existing candidate the new one
/// is returned.
const fn pick_min(best: Option<(usize, usize)>, pos: usize, len: usize) -> (usize, usize) {
    match best {
        Some((held_pos, held_len)) if held_pos <= pos => (held_pos, held_len),
        _ => (pos, len),
    }
}
153
154pub struct SseJsonStream<T> {
156 stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
157 decoder: SseDecoder,
158 pending: VecDeque<Result<ServerSentEvent>>,
159 done: bool,
160 _marker: PhantomData<T>,
161}
162
163impl<T> Unpin for SseJsonStream<T> {}
164
165impl<T> SseJsonStream<T> {
166 #[must_use]
168 pub fn new(response: reqwest::Response) -> Self {
169 Self {
170 stream: Box::pin(response.bytes_stream()),
171 decoder: SseDecoder::new(),
172 pending: VecDeque::new(),
173 done: false,
174 _marker: PhantomData,
175 }
176 }
177}
178
179impl<T> Stream for SseJsonStream<T>
180where
181 T: DeserializeOwned,
182{
183 type Item = Result<T>;
184
185 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
186 let this = self.get_mut();
187 loop {
188 if let Some(item) = this.pending.pop_front() {
189 match item {
190 Err(err) => return Poll::Ready(Some(Err(err))),
191 Ok(event) => {
192 if event.data == "[DONE]" {
193 this.done = true;
194 continue;
195 }
196
197 let parsed = serde_json::from_str::<T>(&event.data).map_err(Error::from)?;
198 return Poll::Ready(Some(Ok(parsed)));
199 }
200 }
201 }
202
203 if this.done {
204 return Poll::Ready(None);
205 }
206
207 match this.stream.as_mut().poll_next(cx) {
208 Poll::Pending => return Poll::Pending,
209 Poll::Ready(None) => return Poll::Ready(None),
210 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
211 Poll::Ready(Some(Ok(bytes))) => {
212 let events = this.decoder.decode(&bytes);
213 for event in events {
214 this.pending.push_back(event);
215 }
216 }
217 }
218 }
219 }
220}
221
222pub fn parse_sse_stream(
224 response: reqwest::Response,
225) -> impl Stream<Item = Result<GenerateContentResponse>> {
226 parse_sse_stream_with::<GenerateContentResponse>(response)
227}
228
229#[must_use]
231pub fn parse_sse_stream_with<T>(response: reqwest::Response) -> SseJsonStream<T>
232where
233 T: DeserializeOwned,
234{
235 SseJsonStream::new(response)
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use serde_json::Value;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Runs `chunk` through a freshly constructed decoder.
    fn decode_fresh(chunk: &[u8]) -> Vec<Result<ServerSentEvent>> {
        SseDecoder::new().decode(chunk)
    }

    /// Borrows the `data` payload of the event at `index`; panics if that
    /// slot holds an error.
    fn data_at(events: &[Result<ServerSentEvent>], index: usize) -> &str {
        events[index].as_ref().unwrap().data.as_str()
    }

    /// A 200 response template carrying the SSE content-type header; the
    /// caller adds the body.
    fn event_stream_template() -> ResponseTemplate {
        ResponseTemplate::new(200).insert_header("content-type", "text/event-stream")
    }

    /// Starts a mock server answering GET with `template`, requests it, and
    /// wraps the response in a JSON SSE stream. The server is returned too so
    /// that it stays alive for as long as the test reads from the stream.
    async fn open_stream(template: ResponseTemplate) -> (MockServer, SseJsonStream<Value>) {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(template)
            .mount(&server)
            .await;

        let response = reqwest::Client::new()
            .get(server.uri())
            .send()
            .await
            .unwrap();
        (server, parse_sse_stream_with::<Value>(response))
    }

    #[test]
    fn test_sse_decoder_basic() {
        let events =
            decode_fresh(b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\"World\"}\n\n");
        assert_eq!(events.len(), 2);
        assert_eq!(data_at(&events, 0), r#"{"text":"Hello"}"#);
        assert_eq!(data_at(&events, 1), r#"{"text":"World"}"#);
    }

    #[test]
    fn test_sse_decoder_crlf() {
        let events = decode_fresh(b"data: {\"text\":\"Hello\"}\r\n\r\n");
        assert_eq!(events.len(), 1);
        assert_eq!(data_at(&events, 0), r#"{"text":"Hello"}"#);
    }

    #[test]
    fn test_sse_decoder_default_works() {
        let mut decoder = SseDecoder::default();
        let events = decoder.decode(b"data: {\"text\":\"Hello\"}\n\n");
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_sse_decoder_line_without_colon_and_empty_lines() {
        let events = decode_fresh(b"data\n\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(data_at(&events, 0), "");
    }

    #[test]
    fn test_sse_decoder_only_comments_returns_empty() {
        let events = decode_fresh(b":comment\n\n");
        assert!(events.is_empty());
    }

    #[test]
    fn test_sse_done_signal() {
        let events = decode_fresh(b"data: [DONE]\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(data_at(&events, 0), "[DONE]");
    }

    #[test]
    fn test_sse_double_cr() {
        let events = decode_fresh(b"data: {\"text\":\"Hello\"}\r\r");
        assert_eq!(events.len(), 1);
        assert_eq!(data_at(&events, 0), r#"{"text":"Hello"}"#);
    }

    #[test]
    fn test_sse_decoder_event_and_id() {
        let events =
            decode_fresh(b":comment\nid: 7\nevent: update\ndata: line1\ndata: line2\n\n");
        assert_eq!(events.len(), 1);

        let event = events[0].as_ref().unwrap();
        assert_eq!(event.event.as_deref(), Some("update"));
        assert_eq!(event.id.as_deref(), Some("7"));
        assert_eq!(event.data, "line1\nline2");
    }

    #[test]
    fn test_sse_decoder_invalid_utf8_and_empty() {
        // The same decoder is reused on purpose: the second call checks that
        // a lone delimiter after the bad block produces nothing.
        let mut decoder = SseDecoder::new();

        let events = decoder.decode(b"data: \xFF\xFF\n\n");
        assert_eq!(events.len(), 1);
        assert!(events[0].is_err());

        let events = decoder.decode(b"\n\n");
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_sse_json_stream_invalid_utf8() {
        let body = vec![0xFF, 0xFF, b'\n', b'\n'];
        let (_server, mut stream) =
            open_stream(event_stream_template().set_body_bytes(body)).await;

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Parse { .. }));
    }

    #[test]
    fn test_pick_min_prefers_smaller_position() {
        assert_eq!(pick_min(Some((5, 2)), 2, 4), (2, 4));
        assert_eq!(pick_min(Some((2, 2)), 5, 4), (2, 2));
    }

    #[tokio::test]
    async fn test_sse_json_stream_parses_and_done() {
        let body = "data: {\"value\":1}\n\ndata: [DONE]\n\n";
        let (_server, mut stream) =
            open_stream(event_stream_template().set_body_string(body)).await;

        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first["value"], 1);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_sse_json_stream_invalid_json() {
        let body = "data: {bad json}\n\n";
        let (_server, mut stream) =
            open_stream(event_stream_template().set_body_string(body)).await;

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Serialization { .. }));
    }
}