1use std::{
2 collections::VecDeque,
3 num::ParseIntError,
4 str::Utf8Error,
5 task::{Context, Poll, ready},
6};
7
8use crate::Sse;
9use bytes::Buf;
10use futures_util::{Stream, TryStreamExt, stream::MapOk};
11use http_body::{Body, Frame};
12use http_body_util::{BodyDataStream, StreamBody};
13
14pin_project_lite::pin_project! {
15 pub struct SseStream<B: Body> {
16 #[pin]
17 body: BodyDataStream<B>,
18 parsed: VecDeque<Sse>,
19 current: Option<Sse>,
20 unfinished_line: Vec<u8>,
21 }
22}
23
24pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
25impl<E, S, D> SseStream<ByteStreamBody<S, D>>
26where
27 S: Stream<Item = Result<D, E>>,
28 E: std::error::Error,
29 D: Buf,
30 StreamBody<ByteStreamBody<S, D>>: Body,
31{
32 pub fn from_byte_stream(stream: S) -> Self {
36 let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
37 let body = StreamBody::new(stream);
38 Self {
39 body: BodyDataStream::new(body),
40 parsed: VecDeque::new(),
41 current: None,
42 unfinished_line: Vec::new(),
43 }
44 }
45}
46
47impl<B: Body> SseStream<B> {
48 pub fn new(body: B) -> Self {
50 Self {
51 body: BodyDataStream::new(body),
52 parsed: VecDeque::new(),
53 current: None,
54 unfinished_line: Vec::new(),
55 }
56 }
57}
58
/// Errors yielded by [`SseStream`] while reading and parsing a server-sent-events body.
///
/// `Debug` is derived: compact output (`{:?}`) reads `Variant` / `Variant(inner)`, and the
/// alternate form (`{:#?}`) pretty-prints the wrapped error.
#[derive(Debug)]
pub enum Error {
    /// The underlying body returned an error.
    Body(Box<dyn std::error::Error + Send + Sync>),
    /// A non-blank line had no `:` separator, or named an unknown field.
    InvalidLine,
    /// An event contained more than one `event` field.
    DuplicatedEventLine,
    /// An event contained more than one `id` field.
    DuplicatedIdLine,
    /// An event contained more than one `retry` field.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// A `retry` value could not be parsed as a `u64`.
    IntParse(ParseIntError),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Body(e) => write!(f, "body error: {e}"),
            Error::InvalidLine => f.write_str("invalid line"),
            Error::DuplicatedEventLine => f.write_str("duplicated event line"),
            Error::DuplicatedIdLine => f.write_str("duplicated id line"),
            Error::DuplicatedRetry => f.write_str("duplicated retry line"),
            Error::Utf8Parse(e) => write!(f, "utf8 parse error: {e}"),
            Error::IntParse(e) => write!(f, "int parse error: {e}"),
        }
    }
}

impl std::error::Error for Error {
    // Kept for callers of the deprecated `description` API; it mirrors the `Display` text
    // without the details of a wrapped error.
    fn description(&self) -> &str {
        match self {
            Error::Body(_) => "body error",
            Error::InvalidLine => "invalid line",
            Error::DuplicatedEventLine => "duplicated event line",
            Error::DuplicatedIdLine => "duplicated id line",
            Error::DuplicatedRetry => "duplicated retry line",
            Error::Utf8Parse(_) => "utf8 parse error",
            Error::IntParse(_) => "int parse error",
        }
    }

    /// Returns the wrapped error for `Body`, `Utf8Parse` and `IntParse`; `None` otherwise.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Body(e) => Some(e.as_ref()),
            Error::Utf8Parse(e) => Some(e),
            Error::IntParse(e) => Some(e),
            _ => None,
        }
    }
}
119
120impl<B: Body> Stream for SseStream<B>
121where
122 B::Error: std::error::Error + Send + Sync + 'static,
123{
124 type Item = Result<Sse, Error>;
125
126 fn poll_next(
127 mut self: std::pin::Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 ) -> Poll<Option<Self::Item>> {
130 let this = self.as_mut().project();
131 if let Some(sse) = this.parsed.pop_front() {
132 return Poll::Ready(Some(Ok(sse)));
133 }
134 let next_data = ready!(this.body.poll_next(cx));
135 match next_data {
136 Some(Ok(data)) => {
137 let chunk = data.chunk();
138
139 if chunk.is_empty() {
140 return self.poll_next(cx);
141 }
142 let mut lines = chunk.chunk_by(|maybe_nl, _| *maybe_nl != b'\n');
143 let first_line = lines.next().expect("frame is empty");
144 let mut new_unfinished_line = Vec::new();
145 let first_line = if !this.unfinished_line.is_empty() {
146 this.unfinished_line.extend(first_line);
147 std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
148 new_unfinished_line.as_ref()
149 } else {
150 first_line
151 };
152 let mut lines = std::iter::once(first_line).chain(lines);
153 *this.unfinished_line = loop {
154 let Some(line) = lines.next() else {
155 break Vec::new();
156 };
157 let line = if line.ends_with(b"\r\n") {
158 &line[..line.len() - 2]
159 } else if line.ends_with(b"\n") || line.ends_with(b"\r") {
160 &line[..line.len() - 1]
161 } else {
162 break line.to_vec();
163 };
164
165 if line.is_empty() {
166 if let Some(sse) = this.current.take() {
167 this.parsed.push_back(sse);
168 }
169 continue;
170 }
171 let Some(comma_index) = line.iter().position(|b| *b == b':') else {
173 #[cfg(feature = "tracing")]
174 tracing::warn!(?line, "invalid line, missing `:`");
175 return Poll::Ready(Some(Err(Error::InvalidLine)));
176 };
177 let field_name = &line[..comma_index];
178 let field_value = if line.len() > comma_index + 1 {
179 let field_value = &line[comma_index + 1..];
180 if field_value.starts_with(b" ") {
181 &field_value[1..]
182 } else {
183 field_value
184 }
185 } else {
186 b""
187 };
188 match field_name {
189 b"data" => {
190 let data_line =
191 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
192 if let Some(Sse { data, .. }) = this.current.as_mut() {
194 if data.is_none() {
195 data.replace(data_line.to_owned());
196 } else {
197 let data = data.as_mut().unwrap();
198 data.push('\n');
199 data.push_str(data_line);
200 }
201 } else {
202 this.current.replace(Sse {
203 event: None,
204 data: Some(data_line.to_owned()),
205 id: None,
206 retry: None,
207 });
208 }
209 }
210 b"event" => {
211 let event_value =
212 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
213 if let Some(Sse { event, .. }) = this.current.as_mut() {
214 if event.is_some() {
215 return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
216 } else {
217 event.replace(event_value.to_owned());
218 }
219 } else {
220 this.current.replace(Sse {
221 event: Some(event_value.to_owned()),
222 ..Default::default()
223 });
224 }
225 }
226 b"id" => {
227 let id_value =
228 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
229 if let Some(Sse { id, .. }) = this.current.as_mut() {
230 if id.is_some() {
231 return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
232 } else {
233 id.replace(id_value.to_owned());
234 }
235 } else {
236 this.current.replace(Sse {
237 id: Some(id_value.to_owned()),
238 ..Default::default()
239 });
240 }
241 }
242 b"retry" => {
243 let retry_value = std::str::from_utf8(field_value)
244 .map_err(Error::Utf8Parse)?
245 .trim_ascii();
246 let retry_value =
247 retry_value.parse::<u64>().map_err(Error::IntParse)?;
248 if let Some(Sse { retry, .. }) = this.current.as_mut() {
249 if retry.is_some() {
250 return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
251 } else {
252 retry.replace(retry_value);
253 }
254 } else {
255 this.current.replace(Sse {
256 retry: Some(retry_value),
257 ..Default::default()
258 });
259 }
260 }
261 b"" => {
262 #[cfg(feature = "tracing")]
263 if tracing::enabled!(tracing::Level::DEBUG) {
264 let comment =
266 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
267 tracing::debug!(?comment, "sse comment line");
268 }
269 }
270 _line => {
271 #[cfg(feature = "tracing")]
272 if tracing::enabled!(tracing::Level::WARN) {
273 tracing::warn!(line = ?_line, "invalid line: unknown field");
274 }
275 return Poll::Ready(Some(Err(Error::InvalidLine)));
276 }
277 }
278 };
279 self.poll_next(cx)
280 }
281 Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
282 None => {
283 if let Some(sse) = this.current.take() {
284 Poll::Ready(Some(Ok(sse)))
285 } else {
286 Poll::Ready(None)
287 }
288 }
289 }
290 }
291}