Skip to main content

steer_core/api/
sse.rs

1use eventsource_stream::Eventsource;
2use futures_core::Stream;
3use futures_util::StreamExt;
4use std::pin::Pin;
5use tokio_util::bytes::Bytes;
6
7use crate::api::error::SseParseError;
8
/// A single Server-Sent Event as produced by [`parse_sse_stream`].
///
/// Optional SSE fields are represented as `Option<String>`, where an empty value
/// reported by the parser is mapped to `None`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// The event type (`event:` field), or `None` when the parsed type was empty.
    pub event_type: Option<String>,
    /// The event payload as assembled by the parser from the `data:` lines.
    pub data: String,
    /// The event id (`id:` field), or `None` when the parsed id was empty.
    pub id: Option<String>,
}
15
16pub type SseStream = Pin<Box<dyn Stream<Item = Result<SseEvent, SseParseError>> + Send>>;
17
18pub fn parse_sse_stream<S, E>(byte_stream: S) -> SseStream
19where
20    S: Stream<Item = Result<Bytes, E>> + Send + 'static,
21    E: std::error::Error + Send + 'static,
22{
23    let event_stream = byte_stream.eventsource().map(|result| {
24        result
25            .map(|event| SseEvent {
26                event_type: if event.event.is_empty() {
27                    None
28                } else {
29                    Some(event.event)
30                },
31                data: event.data,
32                id: if event.id.is_empty() {
33                    None
34                } else {
35                    Some(event.id)
36                },
37            })
38            .map_err(SseParseError::from)
39    });
40
41    Box::pin(event_stream)
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    /// Builds a byte stream that yields each string as a separate `Bytes` chunk,
    /// so tests can control where chunk boundaries fall.
    fn byte_stream(
        chunks: &[&'static str],
    ) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static {
        let items: Vec<Result<Bytes, std::io::Error>> =
            chunks.iter().map(|chunk| Ok(Bytes::from(*chunk))).collect();
        stream::iter(items)
    }

    #[tokio::test]
    async fn test_parse_simple_sse_event() {
        let mut sse_stream =
            parse_sse_stream(byte_stream(&["event: message\ndata: {\"text\": \"hello\"}\n\n"]));

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.event_type, Some("message".to_string()));
        assert_eq!(event.data, "{\"text\": \"hello\"}");
    }

    #[tokio::test]
    async fn test_parse_multiple_sse_events() {
        let mut sse_stream = parse_sse_stream(byte_stream(&[
            "event: start\ndata: first\n\nevent: delta\ndata: second\n\n",
        ]));

        let event1 = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event1.event_type, Some("start".to_string()));
        assert_eq!(event1.data, "first");

        let event2 = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event2.event_type, Some("delta".to_string()));
        assert_eq!(event2.data, "second");
    }

    #[tokio::test]
    async fn test_parse_event_id() {
        let mut sse_stream = parse_sse_stream(byte_stream(&["id: 42\ndata: payload\n\n"]));

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.id, Some("42".to_string()));
        assert_eq!(event.data, "payload");
    }

    #[tokio::test]
    async fn test_event_split_across_chunks() {
        // Boundaries deliberately fall inside a field name and inside the payload.
        let mut sse_stream =
            parse_sse_stream(byte_stream(&["event: delta\nda", "ta: hel", "lo\n\n"]));

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.event_type, Some("delta".to_string()));
        assert_eq!(event.data, "hello");
    }

    #[tokio::test]
    async fn test_multiline_data_is_joined_with_newline() {
        let mut sse_stream =
            parse_sse_stream(byte_stream(&["data: line1\ndata: line2\n\n"]));

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.data, "line1\nline2");
    }
}