Skip to main content

tork_core/testing/
sse.rs

1//! A reader for Server-Sent Events responses.
2
3use std::time::Duration;
4
5use http_body_util::BodyExt;
6use serde::de::DeserializeOwned;
7
8use super::client::StreamingBody;
9use crate::error::{Error, Result};
10
/// A single event parsed from a Server-Sent Events stream.
///
/// Derives `Debug`, `Clone`, `PartialEq`, and `Eq` so tests can print events,
/// compare them with `assert_eq!`, and call `unwrap_err`/`expect_err` on
/// results that carry them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestSseEvent {
    /// Value of the `event:` field, if the block had one.
    event: Option<String>,
    /// Values of the block's `data:` lines, joined with `\n`.
    data: String,
    /// Value of the `id:` field, if the block had one.
    id: Option<String>,
}
17
18impl TestSseEvent {
19    /// The event name (the `event:` field), if any.
20    pub fn event(&self) -> Option<&str> {
21        self.event.as_deref()
22    }
23
24    /// The event payload (the joined `data:` lines).
25    pub fn data(&self) -> &str {
26        &self.data
27    }
28
29    /// The event id (the `id:` field), if any.
30    pub fn id(&self) -> Option<&str> {
31        self.id.as_deref()
32    }
33
34    /// Parses the event data as JSON.
35    pub fn json<T: DeserializeOwned>(&self) -> Result<T> {
36        serde_json::from_str(&self.data)
37            .map_err(|error| Error::internal(format!("event data is not valid JSON: {error}")))
38    }
39}
40
41/// A reader over a `text/event-stream` response body.
42///
43/// Returned by [`sse`](super::TestRequestBuilder::sse). Pull events with
44/// [`next_event`](TestSseStream::next_event). Heartbeat comments are skipped, so
45/// every returned event carries a name, data, or id.
46pub struct TestSseStream {
47    body: StreamingBody,
48    buffer: String,
49    done: bool,
50}
51
52impl TestSseStream {
53    pub(crate) fn new(body: StreamingBody) -> Self {
54        Self {
55            body,
56            buffer: String::new(),
57            done: false,
58        }
59    }
60
61    /// Returns the next event, or `None` once the stream ends.
62    pub async fn next_event(&mut self) -> Result<Option<TestSseEvent>> {
63        loop {
64            // Emit any complete, non-heartbeat event already in the buffer.
65            if let Some(block) = self.take_block() {
66                if let Some(event) = parse_event(&block) {
67                    return Ok(Some(event));
68                }
69                continue;
70            }
71            if self.done {
72                return Ok(None);
73            }
74            // Otherwise read the next body frame into the buffer.
75            match self.body.frame().await {
76                Some(Ok(frame)) => {
77                    if let Ok(data) = frame.into_data() {
78                        self.buffer.push_str(&String::from_utf8_lossy(&data));
79                    }
80                }
81                Some(Err(error)) => {
82                    return Err(Error::internal(format!("event stream error: {error}")));
83                }
84                None => {
85                    self.done = true;
86                    // A trailing block without its blank-line terminator is still
87                    // worth parsing once the stream ends.
88                    if !self.buffer.trim().is_empty() {
89                        let block = std::mem::take(&mut self.buffer);
90                        if let Some(event) = parse_event(&block) {
91                            return Ok(Some(event));
92                        }
93                    }
94                    return Ok(None);
95                }
96            }
97        }
98    }
99
100    /// Like [`next_event`](TestSseStream::next_event) but fails if no event
101    /// arrives within `timeout`.
102    pub async fn next_event_timeout(&mut self, timeout: Duration) -> Result<Option<TestSseEvent>> {
103        tokio::time::timeout(timeout, self.next_event())
104            .await
105            .map_err(|_| {
106                Error::internal("timed out waiting for an event").with_code("SSE_TIMEOUT")
107            })?
108    }
109
110    /// Removes and returns the next complete event block (terminated by a blank
111    /// line) from the buffer, if one is present.
112    fn take_block(&mut self) -> Option<String> {
113        let index = self.buffer.find("\n\n")?;
114        let block: String = self.buffer.drain(..index + 2).collect();
115        Some(block)
116    }
117}
118
119/// Parses one event block into an event, or `None` for a heartbeat/comment-only
120/// block (no name, data, or id).
121fn parse_event(block: &str) -> Option<TestSseEvent> {
122    let mut event = None;
123    let mut id = None;
124    let mut data_lines: Vec<&str> = Vec::new();
125    let mut has_field = false;
126
127    for line in block.lines() {
128        if line.is_empty() || line.starts_with(':') {
129            continue; // blank line or comment (heartbeat)
130        }
131        let (field, value) = match line.split_once(':') {
132            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
133            None => (line, ""),
134        };
135        match field {
136            "event" => {
137                event = Some(value.to_owned());
138                has_field = true;
139            }
140            "id" => {
141                id = Some(value.to_owned());
142                has_field = true;
143            }
144            "data" => {
145                data_lines.push(value);
146                has_field = true;
147            }
148            "retry" => has_field = true,
149            _ => {}
150        }
151    }
152
153    if !has_field {
154        return None;
155    }
156    Some(TestSseEvent {
157        event,
158        data: data_lines.join("\n"),
159        id,
160    })
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::BoxError;
    use bytes::Bytes;
    use futures_util::stream;
    use http_body::Frame;
    use http_body_util::StreamBody;
    use serde::Deserialize;

    /// One item of a scripted body: a frame, or the error the body fails with.
    type FrameResult = std::result::Result<Frame<Bytes>, BoxError>;

    /// JSON shape used to exercise `TestSseEvent::json`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Payload {
        value: i64,
    }

    /// Builds a streaming body that yields `chunks` in order and then ends.
    fn body_from_chunks(chunks: Vec<FrameResult>) -> StreamingBody {
        let frames = stream::iter(chunks);
        Box::pin(StreamBody::new(frames))
    }

    #[test]
    fn parse_event_collects_name_id_and_multiline_data() {
        let block = "event: tick\nid: 7\ndata: first\ndata: second\n\n";

        let parsed = parse_event(block).unwrap();

        assert_eq!(parsed.event(), Some("tick"));
        assert_eq!(parsed.id(), Some("7"));
        assert_eq!(parsed.data(), "first\nsecond");
    }

    #[test]
    fn parse_event_skips_heartbeat_only_blocks() {
        let parsed = parse_event(": keep-alive\n\n");

        assert!(parsed.is_none());
    }

    #[test]
    fn event_json_reports_invalid_payload() {
        let parsed = parse_event("data: not-json\n\n").unwrap();

        let failure = parsed.json::<Payload>().unwrap_err();

        assert!(failure.message().starts_with("event data is not valid JSON:"));
    }

    #[tokio::test]
    async fn next_event_parses_trailing_block_at_end_of_stream() {
        // The single frame ends without a blank line, so the event only
        // becomes available once the body reports end of stream.
        let frame = Frame::data(Bytes::from_static(b"event: tick\ndata: {\"value\":1}"));
        let mut stream = TestSseStream::new(body_from_chunks(vec![Ok(frame)]));

        let first = stream.next_event().await.unwrap().unwrap();
        assert_eq!(first.event(), Some("tick"));
        assert_eq!(first.json::<Payload>().unwrap(), Payload { value: 1 });

        let second = stream.next_event().await.unwrap();
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn next_event_reports_stream_errors() {
        let failure: BoxError = Box::new(std::io::Error::other("boom"));
        let mut stream = TestSseStream::new(body_from_chunks(vec![Err(failure)]));

        let Err(error) = stream.next_event().await else {
            panic!("expected stream error");
        };

        assert!(error.message().contains("event stream error: boom"));
    }

    #[tokio::test]
    async fn next_event_timeout_reports_deadline() {
        // A body that never yields a frame forces the deadline to fire.
        let never = stream::pending::<FrameResult>();
        let body: StreamingBody = Box::pin(StreamBody::new(never));
        let mut stream = TestSseStream::new(body);

        let Err(error) = stream.next_event_timeout(Duration::from_millis(5)).await else {
            panic!("expected timeout");
        };

        assert_eq!(error.code(), "SSE_TIMEOUT");
    }
}