1use std::{
2 collections::VecDeque,
3 num::ParseIntError,
4 str::Utf8Error,
5 task::{ready, Context, Poll},
6};
7
8use crate::Sse;
9use bytes::Buf;
10use futures_util::{stream::MapOk, Stream, TryStreamExt};
11use http_body::{Body, Frame};
12use http_body_util::{BodyDataStream, StreamBody};
13
/// Progress of stripping an optional UTF-8 byte-order mark from the start of the stream.
#[derive(Debug)]
enum BomHeaderState {
    /// Still deciding: holds the leading bytes received so far, all of which are a prefix
    /// of [`BOM_HEADER`] (it starts out empty).
    Parsing(Vec<u8>),
    /// The decision has been made (and a BOM, if present, removed); later bytes are used
    /// unchanged.
    Consumed,
}
19
/// UTF-8 encoding of U+FEFF (the byte-order mark), skipped if it starts the stream.
const BOM_HEADER: &[u8] = b"\xEF\xBB\xBF";
21
22fn try_consume_bom_header(buf: &[u8]) -> Option<&[u8]> {
26 if buf.len() < BOM_HEADER.len() {
27 if BOM_HEADER.starts_with(buf) {
28 None
29 } else {
30 Some(buf)
31 }
32 } else if buf.starts_with(BOM_HEADER) {
33 Some(&buf[BOM_HEADER.len()..])
34 } else {
35 Some(buf)
36 }
37}
38
pin_project_lite::pin_project! {
    /// Stream of [`Sse`] events decoded from the data frames of an HTTP [`Body`].
    pub struct SseStream<B: Body> {
        // Data frames of the wrapped body.
        #[pin]
        body: BodyDataStream<B>,
        // Completed events waiting to be yielded.
        parsed: VecDeque<Sse>,
        // Event being assembled from the field lines seen since the last blank line.
        current: Option<Sse>,
        // Received bytes not yet parsed into lines (normally the tail of an incomplete line).
        unfinished_line: Vec<u8>,
        // Whether a leading byte-order mark has been dealt with yet.
        bom_header_state: BomHeaderState,
    }
}
49
/// Body type used by [`SseStream::from_byte_stream`]: every `Ok` item of the byte stream `S`
/// is mapped to a data [`Frame`] and the result is wrapped in a [`StreamBody`].
pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
51impl<E, S, D> SseStream<ByteStreamBody<S, D>>
52where
53 S: Stream<Item = Result<D, E>>,
54 E: std::error::Error,
55 D: Buf,
56 StreamBody<ByteStreamBody<S, D>>: Body,
57{
58 pub fn from_byte_stream(stream: S) -> Self {
62 let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
63 let body = StreamBody::new(stream);
64 Self {
65 body: BodyDataStream::new(body),
66 parsed: VecDeque::new(),
67 current: None,
68 unfinished_line: Vec::new(),
69 bom_header_state: BomHeaderState::Parsing(Vec::new()),
70 }
71 }
72}
73
74impl<B: Body> SseStream<B> {
75 pub fn new(body: B) -> Self {
77 Self {
78 body: BodyDataStream::new(body),
79 parsed: VecDeque::new(),
80 current: None,
81 unfinished_line: Vec::new(),
82 bom_header_state: BomHeaderState::Parsing(Vec::new()),
83 }
84 }
85}
86
/// Errors produced while reading or parsing an event stream.
pub enum Error {
    /// The underlying body returned an error.
    Body(Box<dyn std::error::Error + Send + Sync>),
    /// A line had no `:` separator, or named a field other than `data`, `event`, `id`,
    /// `retry` or the empty (comment) field.
    InvalidLine,
    /// A second `event` line appeared before the blank line that dispatches the event.
    DuplicatedEventLine,
    /// A second `id` line appeared before the blank line that dispatches the event.
    DuplicatedIdLine,
    /// A second `retry` line appeared before the blank line that dispatches the event.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// A `retry` value (after trimming ASCII whitespace) did not parse as a `u64`.
    IntParse(ParseIntError),
}
96
97impl std::fmt::Display for Error {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 match self {
100 Error::Body(e) => write!(f, "body error: {}", e),
101 Error::InvalidLine => write!(f, "invalid line"),
102 Error::DuplicatedEventLine => write!(f, "duplicated event line"),
103 Error::DuplicatedIdLine => write!(f, "duplicated id line"),
104 Error::DuplicatedRetry => write!(f, "duplicated retry line"),
105 Error::Utf8Parse(e) => write!(f, "utf8 parse error: {}", e),
106 Error::IntParse(e) => write!(f, "int parse error: {}", e),
107 }
108 }
109}
110
111impl std::fmt::Debug for Error {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 match self {
114 Error::Body(e) => write!(f, "Body({:?})", e),
115 Error::InvalidLine => write!(f, "InvalidLine"),
116 Error::DuplicatedEventLine => write!(f, "DuplicatedEventLine"),
117 Error::DuplicatedIdLine => write!(f, "DuplicatedIdLine"),
118 Error::DuplicatedRetry => write!(f, "DuplicatedRetry"),
119 Error::Utf8Parse(e) => write!(f, "Utf8Parse({:?})", e),
120 Error::IntParse(e) => write!(f, "IntParse({:?})", e),
121 }
122 }
123}
124
125impl std::error::Error for Error {
126 fn description(&self) -> &str {
127 match self {
128 Error::Body(_) => "body error",
129 Error::InvalidLine => "invalid line",
130 Error::DuplicatedEventLine => "duplicated event line",
131 Error::DuplicatedIdLine => "duplicated id line",
132 Error::DuplicatedRetry => "duplicated retry line",
133 Error::Utf8Parse(_) => "utf8 parse error",
134 Error::IntParse(_) => "int parse error",
135 }
136 }
137
138 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
139 match self {
140 Error::Body(e) => Some(e.as_ref()),
141 Error::Utf8Parse(e) => Some(e),
142 Error::IntParse(e) => Some(e),
143 _ => None,
144 }
145 }
146}
147
148impl<B: Body> Stream for SseStream<B>
149where
150 B::Error: std::error::Error + Send + Sync + 'static,
151{
152 type Item = Result<Sse, Error>;
153
154 fn poll_next(
155 mut self: std::pin::Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 ) -> Poll<Option<Self::Item>> {
158 let this = self.as_mut().project();
159 if let Some(sse) = this.parsed.pop_front() {
160 return Poll::Ready(Some(Ok(sse)));
161 }
162 let next_data = ready!(this.body.poll_next(cx));
163 match next_data {
164 Some(Ok(data)) => {
165 let stripped_vec = if let BomHeaderState::Parsing(buf) = this.bom_header_state {
166 buf.extend_from_slice(data.chunk());
167 if let Some(stripped) = try_consume_bom_header(buf) {
168 let stripped_vec = stripped.to_vec();
169 *this.bom_header_state = BomHeaderState::Consumed;
170 Some(stripped_vec)
171 } else {
172 return self.poll_next(cx);
173 }
174 } else {
175 None
176 };
177
178 let chunk = stripped_vec.as_deref().unwrap_or(data.chunk());
179
180 if chunk.is_empty() {
181 return self.poll_next(cx);
182 }
183 let mut lines = chunk.chunk_by(|maybe_nl, _| *maybe_nl != b'\n');
184 let first_line = lines.next().expect("frame is empty");
185 let mut new_unfinished_line = Vec::new();
186 let first_line = if !this.unfinished_line.is_empty() {
187 this.unfinished_line.extend(first_line);
188 std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
189 new_unfinished_line.as_ref()
190 } else {
191 first_line
192 };
193 let mut lines = std::iter::once(first_line).chain(lines);
194 *this.unfinished_line = loop {
195 let Some(line) = lines.next() else {
196 break Vec::new();
197 };
198 let line = if line.ends_with(b"\r\n") {
199 &line[..line.len() - 2]
200 } else if line.ends_with(b"\n") || line.ends_with(b"\r") {
201 &line[..line.len() - 1]
202 } else {
203 break line.to_vec();
204 };
205
206 if line.is_empty() {
207 if let Some(sse) = this.current.take() {
208 this.parsed.push_back(sse);
209 }
210 continue;
211 }
212 let Some(comma_index) = line.iter().position(|b| *b == b':') else {
214 #[cfg(feature = "tracing")]
215 tracing::warn!(?line, "invalid line, missing `:`");
216 return Poll::Ready(Some(Err(Error::InvalidLine)));
217 };
218 let field_name = &line[..comma_index];
219 let field_value = if line.len() > comma_index + 1 {
220 let field_value = &line[comma_index + 1..];
221 if field_value.starts_with(b" ") {
222 &field_value[1..]
223 } else {
224 field_value
225 }
226 } else {
227 b""
228 };
229 match field_name {
230 b"data" => {
231 let data_line =
232 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
233 if let Some(Sse { data, .. }) = this.current.as_mut() {
235 if data.is_none() {
236 data.replace(data_line.to_owned());
237 } else {
238 let data = data.as_mut().unwrap();
239 data.push('\n');
240 data.push_str(data_line);
241 }
242 } else {
243 this.current.replace(Sse {
244 event: None,
245 data: Some(data_line.to_owned()),
246 id: None,
247 retry: None,
248 });
249 }
250 }
251 b"event" => {
252 let event_value =
253 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
254 if let Some(Sse { event, .. }) = this.current.as_mut() {
255 if event.is_some() {
256 return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
257 } else {
258 event.replace(event_value.to_owned());
259 }
260 } else {
261 this.current.replace(Sse {
262 event: Some(event_value.to_owned()),
263 ..Default::default()
264 });
265 }
266 }
267 b"id" => {
268 let id_value =
269 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
270 if let Some(Sse { id, .. }) = this.current.as_mut() {
271 if id.is_some() {
272 return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
273 } else {
274 id.replace(id_value.to_owned());
275 }
276 } else {
277 this.current.replace(Sse {
278 id: Some(id_value.to_owned()),
279 ..Default::default()
280 });
281 }
282 }
283 b"retry" => {
284 let retry_value = std::str::from_utf8(field_value)
285 .map_err(Error::Utf8Parse)?
286 .trim_ascii();
287 let retry_value =
288 retry_value.parse::<u64>().map_err(Error::IntParse)?;
289 if let Some(Sse { retry, .. }) = this.current.as_mut() {
290 if retry.is_some() {
291 return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
292 } else {
293 retry.replace(retry_value);
294 }
295 } else {
296 this.current.replace(Sse {
297 retry: Some(retry_value),
298 ..Default::default()
299 });
300 }
301 }
302 b"" => {
303 #[cfg(feature = "tracing")]
304 if tracing::enabled!(tracing::Level::DEBUG) {
305 let comment =
307 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
308 tracing::debug!(?comment, "sse comment line");
309 }
310 }
311 _line => {
312 #[cfg(feature = "tracing")]
313 if tracing::enabled!(tracing::Level::WARN) {
314 tracing::warn!(line = ?_line, "invalid line: unknown field");
315 }
316 return Poll::Ready(Some(Err(Error::InvalidLine)));
317 }
318 }
319 };
320 self.poll_next(cx)
321 }
322 Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
323 None => {
324 if let Some(sse) = this.current.take() {
325 Poll::Ready(Some(Ok(sse)))
326 } else {
327 Poll::Ready(None)
328 }
329 }
330 }
331 }
332}