1use eventsource_stream::Eventsource;
2use futures_core::Stream;
3use futures_util::StreamExt;
4use std::pin::Pin;
5use tokio_util::bytes::Bytes;
6
7use crate::api::error::SseParseError;
8
/// A single parsed Server-Sent Event, as produced by [`parse_sse_stream`].
///
/// Optional SSE fields are represented as `Option` rather than as empty strings.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the event's `event:` field, or `None` when the parser reported an empty name.
    pub event_type: Option<String>,
    /// The event's `data` payload, exactly as delivered by the underlying SSE parser.
    pub data: String,
    /// Value of the event's `id:` field, or `None` when the parser reported an empty id.
    pub id: Option<String>,
}
15
16pub type SseStream = Pin<Box<dyn Stream<Item = Result<SseEvent, SseParseError>> + Send>>;
17
18pub fn parse_sse_stream<S, E>(byte_stream: S) -> SseStream
19where
20 S: Stream<Item = Result<Bytes, E>> + Send + 'static,
21 E: std::error::Error + Send + 'static,
22{
23 let event_stream = byte_stream.eventsource().map(|result| {
24 result
25 .map(|event| SseEvent {
26 event_type: if event.event.is_empty() {
27 None
28 } else {
29 Some(event.event)
30 },
31 data: event.data,
32 id: if event.id.is_empty() {
33 None
34 } else {
35 Some(event.id)
36 },
37 })
38 .map_err(SseParseError::from)
39 });
40
41 Box::pin(event_stream)
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    // Each test feeds literal SSE wire text through `parse_sse_stream` and inspects the
    // resulting `SseEvent`s.

    #[tokio::test]
    async fn test_parse_simple_sse_event() {
        let sse_data = "event: message\ndata: {\"text\": \"hello\"}\n\n";
        let byte_stream =
            stream::once(async move { Ok::<_, std::io::Error>(Bytes::from(sse_data)) });

        let mut sse_stream = parse_sse_stream(byte_stream);

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.event_type, Some("message".to_string()));
        assert_eq!(event.data, "{\"text\": \"hello\"}");
    }

    #[tokio::test]
    async fn test_parse_multiple_sse_events() {
        let sse_data = "event: start\ndata: first\n\nevent: delta\ndata: second\n\n";
        let byte_stream =
            stream::once(async move { Ok::<_, std::io::Error>(Bytes::from(sse_data)) });

        let mut sse_stream = parse_sse_stream(byte_stream);

        let event1 = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event1.event_type, Some("start".to_string()));
        assert_eq!(event1.data, "first");

        let event2 = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event2.event_type, Some("delta".to_string()));
        assert_eq!(event2.data, "second");
    }

    /// A non-empty `id:` field must be surfaced as `Some`.
    #[tokio::test]
    async fn test_parse_event_with_id() {
        let sse_data = "id: 42\nevent: delta\ndata: payload\n\n";
        let byte_stream =
            stream::once(async move { Ok::<_, std::io::Error>(Bytes::from(sse_data)) });

        let mut sse_stream = parse_sse_stream(byte_stream);

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.id, Some("42".to_string()));
        assert_eq!(event.event_type, Some("delta".to_string()));
        assert_eq!(event.data, "payload");
    }

    /// With no `id:` field ever sent, the id must map to `None` rather than `Some("")`.
    #[tokio::test]
    async fn test_missing_id_maps_to_none() {
        let sse_data = "event: ping\ndata: x\n\n";
        let byte_stream =
            stream::once(async move { Ok::<_, std::io::Error>(Bytes::from(sse_data)) });

        let mut sse_stream = parse_sse_stream(byte_stream);

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.id, None);
        assert_eq!(event.data, "x");
    }

    /// Network reads rarely align with event boundaries. An event split across
    /// several byte chunks must still be reassembled into a single `SseEvent`.
    #[tokio::test]
    async fn test_event_split_across_chunks() {
        let chunks = vec![
            Ok::<_, std::io::Error>(Bytes::from("event: mess")),
            Ok(Bytes::from("age\ndata: hel")),
            Ok(Bytes::from("lo\n\n")),
        ];

        let mut sse_stream = parse_sse_stream(stream::iter(chunks));

        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.event_type, Some("message".to_string()));
        assert_eq!(event.data, "hello");
    }

    /// A failure of the underlying byte stream must surface as an `Err` item,
    /// not be silently dropped.
    #[tokio::test]
    async fn test_transport_error_is_surfaced() {
        let byte_stream = stream::iter(vec![Err::<Bytes, _>(std::io::Error::from(
            std::io::ErrorKind::BrokenPipe,
        ))]);

        let mut sse_stream = parse_sse_stream(byte_stream);

        let item = sse_stream
            .next()
            .await
            .expect("stream should yield the transport error");
        assert!(item.is_err());
    }
}