sse_stream/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3// reference: https://html.spec.whatwg.org/multipage/server-sent-events.html
4use std::{
5    collections::VecDeque,
6    num::ParseIntError,
7    str::Utf8Error,
8    task::{Context, Poll, ready},
9};
10
11use bytes::Buf;
12use futures_util::{Stream, TryStreamExt, stream::MapOk};
13use http_body::{Body, Frame};
14use http_body_util::{BodyDataStream, StreamBody};
15
16pin_project_lite::pin_project! {
17    pub struct SseStream<B: Body> {
18        #[pin]
19        body: BodyDataStream<B>,
20        parsed: VecDeque<Sse>,
21        current: Option<Sse>,
22        unfinished_line: Vec<u8>,
23    }
24}
25
26pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
27impl<E, S, D> SseStream<ByteStreamBody<S, D>>
28where
29    S: Stream<Item = Result<D, E>>,
30    E: std::error::Error,
31    D: Buf,
32    StreamBody<ByteStreamBody<S, D>>: Body,
33{
34    /// Create a new [`SseStream`] from a stream of [`Bytes`](bytes::Bytes).
35    ///
36    /// This is useful when you interact with clients don't provide response body directly list reqwest.
37    pub fn from_byte_stream(stream: S) -> Self {
38        let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
39        let body = StreamBody::new(stream);
40        Self {
41            body: BodyDataStream::new(body),
42            parsed: VecDeque::new(),
43            current: None,
44            unfinished_line: Vec::new(),
45        }
46    }
47}
48
49impl<B: Body> SseStream<B> {
50    /// Create a new [`SseStream`] from a [`Body`].
51    pub fn new(body: B) -> Self {
52        Self {
53            body: BodyDataStream::new(body),
54            parsed: VecDeque::new(),
55            current: None,
56            unfinished_line: Vec::new(),
57        }
58    }
59}
60
/// A single server-sent event, assembled from the field lines preceding a blank line.
///
/// Every field is optional; for example, an event may carry only an `id` or a `retry`.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Sse {
    /// Value of the `event` field (the event type), if one was sent.
    pub event: Option<String>,
    /// Values of all `data` fields of the event, joined with `\n`.
    pub data: Option<String>,
    /// Value of the `id` field, if one was sent.
    pub id: Option<String>,
    /// Value of the `retry` field; per the SSE specification, a reconnection delay in
    /// milliseconds.
    pub retry: Option<u64>,
}

impl Sse {
    /// Returns `true` if the event carried an explicit `event` field.
    #[must_use]
    pub fn is_event(&self) -> bool {
        self.event.is_some()
    }

    /// Returns `true` if the event carried no `event` field.
    ///
    /// Note: an explicit `event: message` line makes this return `false`.
    #[must_use]
    pub fn is_message(&self) -> bool {
        self.event.is_none()
    }
}
77
/// Errors produced while reading or parsing a server-sent event stream.
pub enum Error {
    /// The underlying body failed to produce data.
    Body(Box<dyn std::error::Error + Send + Sync>),
    /// A line had no `:` separator or named an unknown field.
    InvalidLine,
    /// An event contained more than one `event` line.
    DuplicatedEventLine,
    /// An event contained more than one `id` line.
    DuplicatedIdLine,
    /// An event contained more than one `retry` line.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// A `retry` value could not be parsed as an unsigned integer.
    IntParse(ParseIntError),
}

impl Error {
    /// Short human-readable description shared by `Display` and `description`.
    fn summary(&self) -> &'static str {
        match self {
            Error::Body(_) => "body error",
            Error::InvalidLine => "invalid line",
            Error::DuplicatedEventLine => "duplicated event line",
            Error::DuplicatedIdLine => "duplicated id line",
            Error::DuplicatedRetry => "duplicated retry line",
            Error::Utf8Parse(_) => "utf8 parse error",
            Error::IntParse(_) => "int parse error",
        }
    }

    /// Variant name as printed by the `Debug` implementation.
    fn variant_name(&self) -> &'static str {
        match self {
            Error::Body(_) => "Body",
            Error::InvalidLine => "InvalidLine",
            Error::DuplicatedEventLine => "DuplicatedEventLine",
            Error::DuplicatedIdLine => "DuplicatedIdLine",
            Error::DuplicatedRetry => "DuplicatedRetry",
            Error::Utf8Parse(_) => "Utf8Parse",
            Error::IntParse(_) => "IntParse",
        }
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.summary())?;
        // Variants wrapping another error append that error's message.
        match self {
            Error::Body(inner) => write!(f, ": {}", inner),
            Error::Utf8Parse(inner) => write!(f, ": {}", inner),
            Error::IntParse(inner) => write!(f, ": {}", inner),
            Error::InvalidLine
            | Error::DuplicatedEventLine
            | Error::DuplicatedIdLine
            | Error::DuplicatedRetry => Ok(()),
        }
    }
}

impl std::fmt::Debug for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.variant_name();
        match self {
            Error::Body(inner) => write!(f, "{}({:?})", name, inner),
            Error::Utf8Parse(inner) => write!(f, "{}({:?})", name, inner),
            Error::IntParse(inner) => write!(f, "{}({:?})", name, inner),
            _ => f.write_str(name),
        }
    }
}

impl std::error::Error for Error {
    fn description(&self) -> &str {
        self.summary()
    }

    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Body(inner) => Some(&**inner),
            Error::Utf8Parse(inner) => Some(inner),
            Error::IntParse(inner) => Some(inner),
            Error::InvalidLine
            | Error::DuplicatedEventLine
            | Error::DuplicatedIdLine
            | Error::DuplicatedRetry => None,
        }
    }
}
138
139impl<B: Body> Stream for SseStream<B>
140where
141    B::Error: std::error::Error + Send + Sync + 'static,
142{
143    type Item = Result<Sse, Error>;
144
145    fn poll_next(
146        mut self: std::pin::Pin<&mut Self>,
147        cx: &mut Context<'_>,
148    ) -> Poll<Option<Self::Item>> {
149        let this = self.as_mut().project();
150        if let Some(sse) = this.parsed.pop_front() {
151            return Poll::Ready(Some(Ok(sse)));
152        }
153        let next_data = ready!(this.body.poll_next(cx));
154        match next_data {
155            Some(Ok(data)) => {
156                let chunk = data.chunk();
157
158                if chunk.is_empty() {
159                    return self.poll_next(cx);
160                }
161                let mut lines = chunk.chunk_by(|maybe_nl, _| *maybe_nl != b'\n');
162                let first_line = lines.next().expect("frame is empty");
163                let mut new_unfinished_line = Vec::new();
164                let first_line = if !this.unfinished_line.is_empty() {
165                    this.unfinished_line.extend(first_line);
166                    std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
167                    new_unfinished_line.as_ref()
168                } else {
169                    first_line
170                };
171                let mut lines = std::iter::once(first_line).chain(lines);
172                *this.unfinished_line = loop {
173                    let Some(line) = lines.next() else {
174                        break Vec::new();
175                    };
176                    let line = if line.ends_with(b"\r\n") {
177                        &line[..line.len() - 2]
178                    } else if line.ends_with(b"\n") || line.ends_with(b"\r") {
179                        &line[..line.len() - 1]
180                    } else {
181                        break line.to_vec();
182                    };
183
184                    if line.is_empty() {
185                        if let Some(sse) = this.current.take() {
186                            this.parsed.push_back(sse);
187                        }
188                        continue;
189                    }
190                    // find comma
191                    let Some(comma_index) = line.iter().position(|b| *b == b':') else {
192                        #[cfg(feature = "tracing")]
193                        tracing::warn!(?line, "invalid line, missing `:`");
194                        return Poll::Ready(Some(Err(Error::InvalidLine)));
195                    };
196                    let field_name = &line[..comma_index];
197                    let field_value = if line.len() > comma_index + 1 {
198                        let field_value = &line[comma_index + 1..];
199                        if field_value.starts_with(b" ") {
200                            &field_value[1..]
201                        } else {
202                            field_value
203                        }
204                    } else {
205                        b""
206                    };
207                    match field_name {
208                        b"data" => {
209                            let data_line =
210                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
211                            // merge data lines
212                            if let Some(Sse { data, .. }) = this.current.as_mut() {
213                                if data.is_none() {
214                                    data.replace(data_line.to_owned());
215                                } else {
216                                    let data = data.as_mut().unwrap();
217                                    data.push('\n');
218                                    data.push_str(data_line);
219                                }
220                            } else {
221                                this.current.replace(Sse {
222                                    event: None,
223                                    data: Some(data_line.to_owned()),
224                                    id: None,
225                                    retry: None,
226                                });
227                            }
228                        }
229                        b"event" => {
230                            let event_value =
231                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
232                            if let Some(Sse { event, .. }) = this.current.as_mut() {
233                                if event.is_some() {
234                                    return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
235                                } else {
236                                    event.replace(event_value.to_owned());
237                                }
238                            } else {
239                                this.current.replace(Sse {
240                                    event: Some(event_value.to_owned()),
241                                    ..Default::default()
242                                });
243                            }
244                        }
245                        b"id" => {
246                            let id_value =
247                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
248                            if let Some(Sse { id, .. }) = this.current.as_mut() {
249                                if id.is_some() {
250                                    return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
251                                } else {
252                                    id.replace(id_value.to_owned());
253                                }
254                            } else {
255                                this.current.replace(Sse {
256                                    id: Some(id_value.to_owned()),
257                                    ..Default::default()
258                                });
259                            }
260                        }
261                        b"retry" => {
262                            let retry_value = std::str::from_utf8(field_value)
263                                .map_err(Error::Utf8Parse)?
264                                .trim_ascii();
265                            let retry_value =
266                                retry_value.parse::<u64>().map_err(Error::IntParse)?;
267                            if let Some(Sse { retry, .. }) = this.current.as_mut() {
268                                if retry.is_some() {
269                                    return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
270                                } else {
271                                    retry.replace(retry_value);
272                                }
273                            } else {
274                                this.current.replace(Sse {
275                                    retry: Some(retry_value),
276                                    ..Default::default()
277                                });
278                            }
279                        }
280                        b"" => {
281                            #[cfg(feature = "tracing")]
282                            if tracing::enabled!(tracing::Level::DEBUG) {
283                                // a comment
284                                let comment =
285                                    std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
286                                tracing::debug!(?comment, "sse comment line");
287                            }
288                        }
289                        _line => {
290                            #[cfg(feature = "tracing")]
291                            if tracing::enabled!(tracing::Level::WARN) {
292                                tracing::warn!(line = ?_line, "invalid line: unknown field");
293                            }
294                            return Poll::Ready(Some(Err(Error::InvalidLine)));
295                        }
296                    }
297                };
298                self.poll_next(cx)
299            }
300            Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
301            None => {
302                if let Some(sse) = this.current.take() {
303                    Poll::Ready(Some(Ok(sse)))
304                } else {
305                    Poll::Ready(None)
306                }
307            }
308        }
309    }
310}