1use std::collections::VecDeque;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use bytes::{Buf, Bytes, BytesMut};
11use futures_util::Stream;
12use memchr::memmem::Finder;
13use serde::de::DeserializeOwned;
14
15use crate::error::{Error, Result};
16use rust_genai_types::response::GenerateContentResponse;
17
/// A single dispatched Server-Sent Event.
///
/// `data` holds all `data:` lines of the event joined with `\n`; `event` and
/// `id` hold the last non-empty `event:` / `id:` values seen in the event
/// block, if any.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ServerSentEvent {
    /// Event type from the `event:` field, if present and non-empty.
    pub event: Option<String>,
    /// Payload: every `data:` line, joined with `\n`.
    pub data: String,
    /// Event id from the `id:` field, if present and non-empty.
    pub id: Option<String>,
}
25
26pub struct SseDecoder {
28 buffer: BytesMut,
29 finder_lf: Finder<'static>,
30 finder_cr: Finder<'static>,
31 finder_crlf: Finder<'static>,
32}
33
34impl SseDecoder {
35 #[must_use]
37 pub fn new() -> Self {
38 Self {
39 buffer: BytesMut::with_capacity(8192),
40 finder_lf: Finder::new(b"\n\n"),
41 finder_cr: Finder::new(b"\r\r"),
42 finder_crlf: Finder::new(b"\r\n\r\n"),
43 }
44 }
45
46 pub fn decode(&mut self, chunk: &[u8]) -> Vec<Result<ServerSentEvent>> {
48 self.buffer.extend_from_slice(chunk);
49 let mut events = Vec::with_capacity(4);
50
51 while let Some((pos, len)) = self.find_delimiter(&self.buffer) {
52 let event_bytes = self.buffer.split_to(pos);
53 self.buffer.advance(len);
54
55 match Self::parse_lines(&event_bytes) {
56 Ok(Some(event)) => events.push(Ok(event)),
57 Ok(None) => {}
58 Err(err) => events.push(Err(err)),
59 }
60 }
61
62 events
63 }
64
65 fn find_delimiter(&self, buf: &[u8]) -> Option<(usize, usize)> {
66 let best = self.finder_crlf.find(buf).map(|pos| (pos, 4));
67 let best = self
68 .finder_lf
69 .find(buf)
70 .map_or(best, |pos| Some(pick_min(best, pos, 2)));
71 self.finder_cr
72 .find(buf)
73 .map_or(best, |pos| Some(pick_min(best, pos, 2)))
74 }
75
76 fn parse_lines(data: &[u8]) -> Result<Option<ServerSentEvent>> {
77 if data.is_empty() {
78 return Ok(None);
79 }
80
81 let text = std::str::from_utf8(data).map_err(|err| Error::Parse {
82 message: err.to_string(),
83 })?;
84
85 let mut event: Option<String> = None;
86 let mut id: Option<String> = None;
87 let mut data_lines: Vec<String> = Vec::with_capacity(4);
88 let mut has_field = false;
89
90 for line in text.split('\n') {
91 let line = line.trim_end_matches('\r');
92 if line.is_empty() {
93 continue;
94 }
95 if line.starts_with(':') {
96 continue;
97 }
98
99 let (field, value) = match line.split_once(':') {
100 Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
101 None => (line, ""),
102 };
103
104 match field {
105 "event" => {
106 has_field = true;
107 if !value.is_empty() {
108 event = Some(value.to_string());
109 }
110 }
111 "data" => {
112 has_field = true;
113 data_lines.push(value.to_string());
114 }
115 "id" => {
116 has_field = true;
117 if !value.is_empty() {
118 id = Some(value.to_string());
119 }
120 }
121 _ => {}
122 }
123 }
124
125 if !has_field {
126 return Ok(None);
127 }
128
129 Ok(Some(ServerSentEvent {
130 event,
131 data: data_lines.join("\n"),
132 id,
133 }))
134 }
135}
136
137impl Default for SseDecoder {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
/// Chooses between the best `(position, length)` match found so far and a
/// new candidate at `pos` with length `len`.
///
/// The candidate wins only if there is no previous best or it starts strictly
/// earlier. On a tie, the existing best is kept.
const fn pick_min(best: Option<(usize, usize)>, pos: usize, len: usize) -> (usize, usize) {
    match best {
        Some((best_pos, best_len)) if best_pos <= pos => (best_pos, best_len),
        _ => (pos, len),
    }
}
155
156pub struct SseJsonStream<T> {
158 stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
159 decoder: SseDecoder,
160 pending: VecDeque<Result<ServerSentEvent>>,
161 done: bool,
162 done_signal: Option<Arc<AtomicBool>>,
163 _marker: PhantomData<T>,
164}
165
166impl<T> Unpin for SseJsonStream<T> {}
167
168impl<T> SseJsonStream<T> {
169 #[must_use]
171 pub fn new(response: reqwest::Response) -> Self {
172 Self::with_done_signal(response, None)
173 }
174
175 #[must_use]
177 pub fn with_done_signal(
178 response: reqwest::Response,
179 done_signal: Option<Arc<AtomicBool>>,
180 ) -> Self {
181 Self {
182 stream: Box::pin(response.bytes_stream()),
183 decoder: SseDecoder::new(),
184 pending: VecDeque::new(),
185 done: false,
186 done_signal,
187 _marker: PhantomData,
188 }
189 }
190}
191
192impl<T> Stream for SseJsonStream<T>
193where
194 T: DeserializeOwned,
195{
196 type Item = Result<T>;
197
198 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199 let this = self.get_mut();
200 loop {
201 if let Some(item) = this.pending.pop_front() {
202 match item {
203 Err(err) => return Poll::Ready(Some(Err(err))),
204 Ok(event) => {
205 if event.data == "[DONE]" {
206 if let Some(done_signal) = &this.done_signal {
207 done_signal.store(true, Ordering::Relaxed);
208 }
209 this.done = true;
210 continue;
211 }
212
213 let parsed = serde_json::from_str::<T>(&event.data).map_err(Error::from)?;
214 return Poll::Ready(Some(Ok(parsed)));
215 }
216 }
217 }
218
219 if this.done {
220 return Poll::Ready(None);
221 }
222
223 match this.stream.as_mut().poll_next(cx) {
224 Poll::Pending => return Poll::Pending,
225 Poll::Ready(None) => return Poll::Ready(None),
226 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
227 Poll::Ready(Some(Ok(bytes))) => {
228 let events = this.decoder.decode(&bytes);
229 for event in events {
230 this.pending.push_back(event);
231 }
232 }
233 }
234 }
235 }
236}
237
238pub fn parse_sse_stream(
240 response: reqwest::Response,
241) -> impl Stream<Item = Result<GenerateContentResponse>> {
242 parse_sse_stream_with::<GenerateContentResponse>(response)
243}
244
245#[must_use]
247pub fn parse_sse_stream_with<T>(response: reqwest::Response) -> SseJsonStream<T>
248where
249 T: DeserializeOwned,
250{
251 SseJsonStream::new(response)
252}
253
254#[must_use]
256pub fn parse_sse_stream_with_done_signal<T>(
257 response: reqwest::Response,
258 done_signal: Arc<AtomicBool>,
259) -> SseJsonStream<T>
260where
261 T: DeserializeOwned,
262{
263 SseJsonStream::with_done_signal(response, Some(done_signal))
264}
265
// Unit tests for `SseDecoder` framing/field parsing, plus end-to-end tests of
// `SseJsonStream` against a local wiremock HTTP server.
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use serde_json::Value;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Two LF-delimited events in a single chunk decode in order.
    #[test]
    fn test_sse_decoder_basic() {
        let mut decoder = SseDecoder::new();
        let chunk = b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\"World\"}\n\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].as_ref().unwrap().data, r#"{"text":"Hello"}"#);
        assert_eq!(events[1].as_ref().unwrap().data, r#"{"text":"World"}"#);
    }

    // `\r\n\r\n` is recognised as an event delimiter.
    #[test]
    fn test_sse_decoder_crlf() {
        let mut decoder = SseDecoder::new();
        let chunk = b"data: {\"text\":\"Hello\"}\r\n\r\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].as_ref().unwrap().data, r#"{"text":"Hello"}"#);
    }

    // `SseDecoder::default()` decodes the same as `SseDecoder::new()`.
    #[test]
    fn test_sse_decoder_default_works() {
        let mut decoder = SseDecoder::default();
        let chunk = b"data: {\"text\":\"Hello\"}\n\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
    }

    // A bare `data` line (no colon) yields an event with empty data; the
    // trailing extra `\n` stays buffered and produces no event.
    #[test]
    fn test_sse_decoder_line_without_colon_and_empty_lines() {
        let mut decoder = SseDecoder::new();
        let chunk = b"data\n\n\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].as_ref().unwrap().data, "");
    }

    // A block containing only a comment line dispatches nothing.
    #[test]
    fn test_sse_decoder_only_comments_returns_empty() {
        let mut decoder = SseDecoder::new();
        let chunk = b":comment\n\n";
        let events = decoder.decode(chunk);
        assert!(events.is_empty());
    }

    // The decoder passes `[DONE]` through verbatim; handling it is the
    // stream adapter's job.
    #[test]
    fn test_sse_done_signal() {
        let mut decoder = SseDecoder::new();
        let chunk = b"data: [DONE]\n\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].as_ref().unwrap().data, "[DONE]");
    }

    // `\r\r` is recognised as an event delimiter.
    #[test]
    fn test_sse_double_cr() {
        let mut decoder = SseDecoder::new();
        let chunk = b"data: {\"text\":\"Hello\"}\r\r";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].as_ref().unwrap().data, r#"{"text":"Hello"}"#);
    }

    // `event`/`id` fields are captured, comments are skipped, and multiple
    // `data` lines are joined with `\n`.
    #[test]
    fn test_sse_decoder_event_and_id() {
        let mut decoder = SseDecoder::new();
        let chunk = b":comment\nid: 7\nevent: update\ndata: line1\ndata: line2\n\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
        let event = events[0].as_ref().unwrap();
        assert_eq!(event.event.as_deref(), Some("update"));
        assert_eq!(event.id.as_deref(), Some("7"));
        assert_eq!(event.data, "line1\nline2");
    }

    // Invalid UTF-8 in a block surfaces as an `Err` item; an empty block
    // afterwards yields nothing.
    #[test]
    fn test_sse_decoder_invalid_utf8_and_empty() {
        let mut decoder = SseDecoder::new();
        let chunk = b"data: \xFF\xFF\n\n";
        let events = decoder.decode(chunk);
        assert_eq!(events.len(), 1);
        assert!(events[0].as_ref().is_err());

        let events = decoder.decode(b"\n\n");
        assert!(events.is_empty());
    }

    // End to end: an invalid-UTF-8 body surfaces as `Error::Parse` from the stream.
    #[tokio::test]
    async fn test_sse_json_stream_invalid_utf8() {
        let server = MockServer::start().await;
        let body = vec![0xFF, 0xFF, b'\n', b'\n'];
        Mock::given(method("GET"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_bytes(body),
            )
            .mount(&server)
            .await;

        let response = reqwest::Client::new()
            .get(server.uri())
            .send()
            .await
            .unwrap();
        let mut stream = parse_sse_stream_with::<Value>(response);
        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Parse { .. }));
    }

    // `pick_min` keeps whichever match starts earlier.
    #[test]
    fn test_pick_min_prefers_smaller_position() {
        assert_eq!(pick_min(Some((5, 2)), 2, 4), (2, 4));
        assert_eq!(pick_min(Some((2, 2)), 5, 4), (2, 2));
    }

    // End to end: a JSON event is yielded, then `[DONE]` ends the stream.
    #[tokio::test]
    async fn test_sse_json_stream_parses_and_done() {
        let server = MockServer::start().await;
        let body = "data: {\"value\":1}\n\ndata: [DONE]\n\n";
        Mock::given(method("GET"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(body),
            )
            .mount(&server)
            .await;

        let response = reqwest::Client::new()
            .get(server.uri())
            .send()
            .await
            .unwrap();
        let mut stream = parse_sse_stream_with::<Value>(response);
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first["value"], 1);
        assert!(stream.next().await.is_none());
    }

    // End to end: a non-JSON payload surfaces as `Error::Serialization`.
    #[tokio::test]
    async fn test_sse_json_stream_invalid_json() {
        let server = MockServer::start().await;
        let body = "data: {bad json}\n\n";
        Mock::given(method("GET"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(body),
            )
            .mount(&server)
            .await;

        let response = reqwest::Client::new()
            .get(server.uri())
            .send()
            .await
            .unwrap();
        let mut stream = parse_sse_stream_with::<Value>(response);
        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(err, Error::Serialization { .. }));
    }
}