Skip to main content

sse_stream/
stream.rs

1use std::{
2    collections::VecDeque,
3    num::ParseIntError,
4    str::Utf8Error,
5    task::{ready, Context, Poll},
6};
7
8use crate::Sse;
9use bytes::Buf;
10use futures_util::{stream::MapOk, Stream, TryStreamExt};
11use http_body::{Body, Frame};
12use http_body_util::{BodyDataStream, StreamBody};
13
/// Progress of detecting an optional UTF-8 byte-order mark (BOM) at the start of
/// the event stream.
///
/// A payload-free state tag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BomHeaderState {
    /// Initial state, before BOM detection has started.
    NotFoundYet,
    /// A chunk began with the BOM's first byte; the parser is waiting for enough
    /// bytes to decide whether a full BOM is present and must be stripped.
    Parsing,
    /// The BOM check has been done and is never repeated.
    Consumed,
}
20
/// UTF-8 encoding of U+FEFF, the byte-order mark the parser strips from the
/// beginning of the stream when present.
const BOM_HEADER: &[u8] = &[0xEF, 0xBB, 0xBF];
22
23pin_project_lite::pin_project! {
24    pub struct SseStream<B: Body> {
25        #[pin]
26        body: BodyDataStream<B>,
27        parsed: VecDeque<Sse>,
28        current: Option<Sse>,
29        unfinished_line: Vec<u8>,
30        mark_last_chunk_ending_with_cr: bool,
31        bom_header_state: BomHeaderState,
32    }
33}
34
35pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
36impl<E, S, D> SseStream<ByteStreamBody<S, D>>
37where
38    S: Stream<Item = Result<D, E>>,
39    E: std::error::Error,
40    D: Buf,
41    StreamBody<ByteStreamBody<S, D>>: Body,
42{
43    /// Create a new [`SseStream`] from a stream of [`Bytes`](bytes::Bytes).
44    ///
45    /// This is useful when you interact with clients don't provide response body directly list reqwest.
46    pub fn from_byte_stream(stream: S) -> Self {
47        let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
48        let body = StreamBody::new(stream);
49        Self {
50            body: BodyDataStream::new(body),
51            parsed: VecDeque::new(),
52            current: None,
53            unfinished_line: Vec::new(),
54            mark_last_chunk_ending_with_cr: false,
55            bom_header_state: BomHeaderState::NotFoundYet,
56        }
57    }
58}
59
60impl<B: Body> SseStream<B> {
61    /// Create a new [`SseStream`] from a [`Body`].
62    pub fn new(body: B) -> Self {
63        Self {
64            body: BodyDataStream::new(body),
65            parsed: VecDeque::new(),
66            current: None,
67            unfinished_line: Vec::new(),
68            mark_last_chunk_ending_with_cr: false,
69            bom_header_state: BomHeaderState::NotFoundYet,
70        }
71    }
72}
73
/// Errors yielded by [`SseStream`] while decoding an event stream.
pub enum Error {
    /// The underlying body (or byte stream) returned an error.
    Body(Box<dyn std::error::Error + Send + Sync>),
    /// A line had no `:` separator, or named a field other than `data`, `event`,
    /// `id`, `retry` or the empty (comment) field.
    InvalidLine,
    /// An event contained more than one `event` line.
    DuplicatedEventLine,
    /// An event contained more than one `id` line.
    DuplicatedIdLine,
    /// An event contained more than one `retry` line.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// A `retry` value could not be parsed as an unsigned 64-bit integer.
    IntParse(ParseIntError),
}
83
84impl std::fmt::Display for Error {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        match self {
87            Error::Body(e) => write!(f, "body error: {}", e),
88            Error::InvalidLine => write!(f, "invalid line"),
89            Error::DuplicatedEventLine => write!(f, "duplicated event line"),
90            Error::DuplicatedIdLine => write!(f, "duplicated id line"),
91            Error::DuplicatedRetry => write!(f, "duplicated retry line"),
92            Error::Utf8Parse(e) => write!(f, "utf8 parse error: {}", e),
93            Error::IntParse(e) => write!(f, "int parse error: {}", e),
94        }
95    }
96}
97
98impl std::fmt::Debug for Error {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        match self {
101            Error::Body(e) => write!(f, "Body({:?})", e),
102            Error::InvalidLine => write!(f, "InvalidLine"),
103            Error::DuplicatedEventLine => write!(f, "DuplicatedEventLine"),
104            Error::DuplicatedIdLine => write!(f, "DuplicatedIdLine"),
105            Error::DuplicatedRetry => write!(f, "DuplicatedRetry"),
106            Error::Utf8Parse(e) => write!(f, "Utf8Parse({:?})", e),
107            Error::IntParse(e) => write!(f, "IntParse({:?})", e),
108        }
109    }
110}
111
112impl std::error::Error for Error {
113    fn description(&self) -> &str {
114        match self {
115            Error::Body(_) => "body error",
116            Error::InvalidLine => "invalid line",
117            Error::DuplicatedEventLine => "duplicated event line",
118            Error::DuplicatedIdLine => "duplicated id line",
119            Error::DuplicatedRetry => "duplicated retry line",
120            Error::Utf8Parse(_) => "utf8 parse error",
121            Error::IntParse(_) => "int parse error",
122        }
123    }
124
125    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
126        match self {
127            Error::Body(e) => Some(e.as_ref()),
128            Error::Utf8Parse(e) => Some(e),
129            Error::IntParse(e) => Some(e),
130            _ => None,
131        }
132    }
133}
134
135impl<B: Body> Stream for SseStream<B>
136where
137    B::Error: std::error::Error + Send + Sync + 'static,
138{
139    type Item = Result<Sse, Error>;
140
141    fn poll_next(
142        mut self: std::pin::Pin<&mut Self>,
143        cx: &mut Context<'_>,
144    ) -> Poll<Option<Self::Item>> {
145        let this = self.as_mut().project();
146        if let Some(sse) = this.parsed.pop_front() {
147            return Poll::Ready(Some(Ok(sse)));
148        }
149        let next_data = ready!(this.body.poll_next(cx));
150        match next_data {
151            Some(Ok(mut data)) => {
152                loop {
153                    let mut bytes = data.chunk();
154                    let chunk_size = bytes.len();
155
156                    if *this.mark_last_chunk_ending_with_cr {
157                        if !bytes.is_empty() && bytes[0] == b'\n' {
158                            bytes = &bytes[1..];
159                        }
160                        *this.mark_last_chunk_ending_with_cr = false;
161                    }
162
163                    if bytes.is_empty() {
164                        return self.poll_next(cx);
165                    }
166                    if let BomHeaderState::NotFoundYet = this.bom_header_state {
167                        if bytes[0] == BOM_HEADER[0] {
168                            *this.bom_header_state = BomHeaderState::Parsing;
169                        }
170                    }
171                    // handling situation when the last line is end with `'\r'`. The next chunk may start with `'\n'`, but we should treat them as one line.
172                    if bytes.last().is_some_and(|b| *b == b'\r') {
173                        *this.mark_last_chunk_ending_with_cr = true;
174                    }
175                    let mut lines = bytes.chunk_by(|line_end, line_start| {
176                        !(
177                            // for line ending with `\n`, it can be either `\n` or `\r\n`
178                            *line_end == b'\n' ||
179                            // for line ending with `\r`
180                            (*line_end == b'\r' && *line_start != b'\n')
181                        )
182                    });
183                    let first_line = lines.next().expect("frame is empty");
184
185                    let mut new_unfinished_line = Vec::new();
186                    let mut first_line = if !this.unfinished_line.is_empty() {
187                        this.unfinished_line.extend(first_line);
188                        std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
189                        new_unfinished_line.as_ref()
190                    } else {
191                        first_line
192                    };
193
194                    if let BomHeaderState::Parsing = this.bom_header_state {
195                        if first_line.len() > BOM_HEADER.len() {
196                            if let Some(stripped) = first_line.strip_prefix(BOM_HEADER) {
197                                first_line = stripped
198                            }
199                            // we only check the BOM header only ONCE in the whole stream, it happens instantly when we receive the first line with enough length.
200                            *this.bom_header_state = BomHeaderState::Consumed;
201                        } else {
202                            this.unfinished_line.extend(first_line);
203                            return self.poll_next(cx);
204                        }
205                    }
206
207                    let mut lines = std::iter::once(first_line).chain(lines);
208                    *this.unfinished_line = loop {
209                        let Some(line) = lines.next() else {
210                            break Vec::new();
211                        };
212                        let line = if line.ends_with(b"\r\n") {
213                            &line[..line.len() - 2]
214                        } else if line.ends_with(b"\n") || line.ends_with(b"\r") {
215                            &line[..line.len() - 1]
216                        } else {
217                            break line.to_vec();
218                        };
219
220                        if line.is_empty() {
221                            if let Some(sse) = this.current.take() {
222                                this.parsed.push_back(sse);
223                            }
224                            continue;
225                        }
226                        // find comma
227                        let Some(comma_index) = line.iter().position(|b| *b == b':') else {
228                            #[cfg(feature = "tracing")]
229                            tracing::warn!(?line, "invalid line, missing `:`");
230                            return Poll::Ready(Some(Err(Error::InvalidLine)));
231                        };
232                        let field_name = &line[..comma_index];
233                        let field_value = if line.len() > comma_index + 1 {
234                            let field_value = &line[comma_index + 1..];
235                            if field_value.starts_with(b" ") {
236                                &field_value[1..]
237                            } else {
238                                field_value
239                            }
240                        } else {
241                            b""
242                        };
243                        match field_name {
244                            b"data" => {
245                                let data_line =
246                                    std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
247                                // merge data lines
248                                if let Some(Sse { data, .. }) = this.current.as_mut() {
249                                    if data.is_none() {
250                                        data.replace(data_line.to_owned());
251                                    } else {
252                                        let data = data.as_mut().unwrap();
253                                        data.push('\n');
254                                        data.push_str(data_line);
255                                    }
256                                } else {
257                                    this.current.replace(Sse {
258                                        event: None,
259                                        data: Some(data_line.to_owned()),
260                                        id: None,
261                                        retry: None,
262                                    });
263                                }
264                            }
265                            b"event" => {
266                                let event_value =
267                                    std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
268                                if let Some(Sse { event, .. }) = this.current.as_mut() {
269                                    if event.is_some() {
270                                        return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
271                                    } else {
272                                        event.replace(event_value.to_owned());
273                                    }
274                                } else {
275                                    this.current.replace(Sse {
276                                        event: Some(event_value.to_owned()),
277                                        ..Default::default()
278                                    });
279                                }
280                            }
281                            b"id" => {
282                                // Per spec: if the id field value contains U+0000 NULL,
283                                // the entire field MUST be ignored.
284                                if field_value.contains(&0u8) {
285                                    #[cfg(feature = "tracing")]
286                                    tracing::warn!(
287                                        ?line,
288                                        "id field contains NULL byte, ignoring per spec"
289                                    );
290                                    continue;
291                                }
292                                let id_value =
293                                    std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
294                                if let Some(Sse { id, .. }) = this.current.as_mut() {
295                                    if id.is_some() {
296                                        return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
297                                    } else {
298                                        id.replace(id_value.to_owned());
299                                    }
300                                } else {
301                                    this.current.replace(Sse {
302                                        id: Some(id_value.to_owned()),
303                                        ..Default::default()
304                                    });
305                                }
306                            }
307                            b"retry" => {
308                                let retry_value = std::str::from_utf8(field_value)
309                                    .map_err(Error::Utf8Parse)?
310                                    .trim_ascii();
311                                let retry_value =
312                                    retry_value.parse::<u64>().map_err(Error::IntParse)?;
313                                if let Some(Sse { retry, .. }) = this.current.as_mut() {
314                                    if retry.is_some() {
315                                        return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
316                                    } else {
317                                        retry.replace(retry_value);
318                                    }
319                                } else {
320                                    this.current.replace(Sse {
321                                        retry: Some(retry_value),
322                                        ..Default::default()
323                                    });
324                                }
325                            }
326                            b"" => {
327                                #[cfg(feature = "tracing")]
328                                if tracing::enabled!(tracing::Level::DEBUG) {
329                                    // a comment
330                                    let comment = std::str::from_utf8(field_value)
331                                        .map_err(Error::Utf8Parse)?;
332                                    tracing::debug!(?comment, "sse comment line");
333                                }
334                            }
335                            _line => {
336                                #[cfg(feature = "tracing")]
337                                if tracing::enabled!(tracing::Level::WARN) {
338                                    tracing::warn!(line = ?_line, "invalid line: unknown field");
339                                }
340                                return Poll::Ready(Some(Err(Error::InvalidLine)));
341                            }
342                        }
343                    };
344                    data.advance(chunk_size);
345                    if !data.has_remaining() {
346                        break;
347                    }
348                }
349                self.poll_next(cx)
350            }
351            Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
352            None => {
353                // When data stream terminated without empty line, we should discard last incomplate message.
354                Poll::Ready(None)
355            }
356        }
357    }
358}