1use std::time::Duration;
4
5use http_body_util::BodyExt;
6use serde::de::DeserializeOwned;
7
8use super::client::StreamingBody;
9use crate::error::{Error, Result};
10
/// One Server-Sent Events message decoded from a [`TestSseStream`].
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so events can be printed,
/// stored and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestSseEvent {
    /// Value of the `event:` field, if the block contained one.
    event: Option<String>,
    /// All `data:` lines of the block, joined with `\n`.
    data: String,
    /// Value of the `id:` field, if the block contained one.
    id: Option<String>,
}
17
18impl TestSseEvent {
19 pub fn event(&self) -> Option<&str> {
21 self.event.as_deref()
22 }
23
24 pub fn data(&self) -> &str {
26 &self.data
27 }
28
29 pub fn id(&self) -> Option<&str> {
31 self.id.as_deref()
32 }
33
34 pub fn json<T: DeserializeOwned>(&self) -> Result<T> {
36 serde_json::from_str(&self.data)
37 .map_err(|error| Error::internal(format!("event data is not valid JSON: {error}")))
38 }
39}
40
41pub struct TestSseStream {
47 body: StreamingBody,
48 buffer: String,
49 done: bool,
50}
51
52impl TestSseStream {
53 pub(crate) fn new(body: StreamingBody) -> Self {
54 Self {
55 body,
56 buffer: String::new(),
57 done: false,
58 }
59 }
60
61 pub async fn next_event(&mut self) -> Result<Option<TestSseEvent>> {
63 loop {
64 if let Some(block) = self.take_block() {
66 if let Some(event) = parse_event(&block) {
67 return Ok(Some(event));
68 }
69 continue;
70 }
71 if self.done {
72 return Ok(None);
73 }
74 match self.body.frame().await {
76 Some(Ok(frame)) => {
77 if let Ok(data) = frame.into_data() {
78 self.buffer.push_str(&String::from_utf8_lossy(&data));
79 }
80 }
81 Some(Err(error)) => {
82 return Err(Error::internal(format!("event stream error: {error}")));
83 }
84 None => {
85 self.done = true;
86 if !self.buffer.trim().is_empty() {
89 let block = std::mem::take(&mut self.buffer);
90 if let Some(event) = parse_event(&block) {
91 return Ok(Some(event));
92 }
93 }
94 return Ok(None);
95 }
96 }
97 }
98 }
99
100 pub async fn next_event_timeout(&mut self, timeout: Duration) -> Result<Option<TestSseEvent>> {
103 tokio::time::timeout(timeout, self.next_event())
104 .await
105 .map_err(|_| {
106 Error::internal("timed out waiting for an event").with_code("SSE_TIMEOUT")
107 })?
108 }
109
110 fn take_block(&mut self) -> Option<String> {
113 let index = self.buffer.find("\n\n")?;
114 let block: String = self.buffer.drain(..index + 2).collect();
115 Some(block)
116 }
117}
118
119fn parse_event(block: &str) -> Option<TestSseEvent> {
122 let mut event = None;
123 let mut id = None;
124 let mut data_lines: Vec<&str> = Vec::new();
125 let mut has_field = false;
126
127 for line in block.lines() {
128 if line.is_empty() || line.starts_with(':') {
129 continue; }
131 let (field, value) = match line.split_once(':') {
132 Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
133 None => (line, ""),
134 };
135 match field {
136 "event" => {
137 event = Some(value.to_owned());
138 has_field = true;
139 }
140 "id" => {
141 id = Some(value.to_owned());
142 has_field = true;
143 }
144 "data" => {
145 data_lines.push(value);
146 has_field = true;
147 }
148 "retry" => has_field = true,
149 _ => {}
150 }
151 }
152
153 if !has_field {
154 return None;
155 }
156 Some(TestSseEvent {
157 event,
158 data: data_lines.join("\n"),
159 id,
160 })
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::BoxError;
    use bytes::Bytes;
    use futures_util::stream;
    use http_body::Frame;
    use http_body_util::StreamBody;
    use serde::Deserialize;

    /// One item yielded by the fake response body.
    type Chunk = std::result::Result<Frame<Bytes>, BoxError>;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Payload {
        value: i64,
    }

    /// Builds a streaming body that yields `chunks` in order and then ends.
    fn body_from_chunks(chunks: Vec<Chunk>) -> StreamingBody {
        Box::pin(StreamBody::new(stream::iter(chunks)))
    }

    #[test]
    fn parse_event_collects_name_id_and_multiline_data() {
        let parsed = parse_event("event: tick\nid: 7\ndata: first\ndata: second\n\n")
            .expect("block with fields should parse");

        assert_eq!(parsed.event(), Some("tick"));
        assert_eq!(parsed.id(), Some("7"));
        assert_eq!(parsed.data(), "first\nsecond");
    }

    #[test]
    fn parse_event_skips_heartbeat_only_blocks() {
        let parsed = parse_event(": keep-alive\n\n");
        assert!(parsed.is_none());
    }

    #[test]
    fn event_json_reports_invalid_payload() {
        let parsed = parse_event("data: not-json\n\n").expect("data block should parse");

        let Err(error) = parsed.json::<Payload>() else {
            panic!("expected a JSON error");
        };
        assert!(error.message().starts_with("event data is not valid JSON:"));
    }

    #[tokio::test]
    async fn next_event_parses_trailing_block_at_end_of_stream() {
        let frame = Frame::data(Bytes::from_static(b"event: tick\ndata: {\"value\":1}"));
        let mut sse = TestSseStream::new(body_from_chunks(vec![Ok(frame)]));

        let first = sse
            .next_event()
            .await
            .expect("reading the stream should succeed")
            .expect("trailing block should yield an event");
        assert_eq!(first.event(), Some("tick"));
        assert_eq!(
            first.json::<Payload>().expect("payload should be JSON"),
            Payload { value: 1 }
        );

        let after = sse.next_event().await.expect("end of stream is not an error");
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn next_event_reports_stream_errors() {
        let failure: BoxError = Box::new(std::io::Error::other("boom"));
        let mut sse = TestSseStream::new(body_from_chunks(vec![Err(failure)]));

        let Err(error) = sse.next_event().await else {
            panic!("expected stream error");
        };
        assert!(error.message().contains("event stream error: boom"));
    }

    #[tokio::test]
    async fn next_event_timeout_reports_deadline() {
        let never = stream::pending::<Chunk>();
        let body: StreamingBody = Box::pin(StreamBody::new(never));
        let mut sse = TestSseStream::new(body);

        let Err(error) = sse.next_event_timeout(Duration::from_millis(5)).await else {
            panic!("expected timeout");
        };
        assert_eq!(error.code(), "SSE_TIMEOUT");
    }
}