turul_mcp_aws_lambda/
streaming.rs1use bytes::Bytes;
8use futures::{Stream, StreamExt};
9use tracing::debug;
10
11use crate::error::{LambdaError, Result};
12
/// Formats `data` as one Server-Sent Events (SSE) event.
///
/// Output, in order:
/// - `id: <event_id>` when `event_id` is `Some`;
/// - `event: <event_type>` when `event_type` is `Some`;
/// - one `data: <line>` per line of `data`;
/// - a blank line, which terminates the event.
///
/// `data` is split on every line terminator SSE recognises: `\r\n`, `\n`, and a lone `\r`.
/// Each piece becomes its own `data:` field, so the client reassembles the text joined with
/// `\n`. At least one `data:` line is always written. An empty payload therefore still yields
/// an event the client dispatches; a block without any `data:` field is dropped by clients.
/// A trailing newline in `data` is preserved as a final empty `data:` line.
///
/// `event_type` and `event_id` are single-line fields. Any CR/LF characters in them are
/// removed, so caller-supplied values cannot end the field early and inject extra fields.
pub fn format_sse_event(data: &str, event_type: Option<&str>, event_id: Option<&str>) -> String {
    // Rough size hint: the payload plus prefixes and terminators for a typical short event.
    let mut event = String::with_capacity(data.len() + 32);

    if let Some(id) = event_id {
        push_single_line_field(&mut event, "id", id);
    }

    if let Some(event_type) = event_type {
        push_single_line_field(&mut event, "event", event_type);
    }

    // Collapse CRLF first so it counts as one terminator, then split on the remaining LF / lone CR.
    // Unlike `lines()`, `split` yields "" for empty input and keeps a trailing empty piece.
    let normalized = data.replace("\r\n", "\n");
    for line in normalized.split(['\n', '\r']) {
        event.push_str("data: ");
        event.push_str(line);
        event.push('\n');
    }

    // The blank line dispatches the event on the client.
    event.push('\n');
    event
}

/// Appends `name: value\n` to `buf`, dropping any CR/LF in `value` so the field stays on one line.
fn push_single_line_field(buf: &mut String, name: &str, value: &str) {
    buf.push_str(name);
    buf.push_str(": ");
    buf.extend(value.chars().filter(|&c| !matches!(c, '\r' | '\n')));
    buf.push('\n');
}
36
37pub fn create_sse_stream<T>(
42 events: Vec<T>,
43 formatter: impl Fn(&T) -> String + Send + 'static,
44) -> impl Stream<Item = Result<Bytes>> + Send + 'static
45where
46 T: Send + 'static,
47{
48 async_stream::stream! {
49 for event in events {
50 let sse_data = formatter(&event);
51 let bytes = Bytes::from(sse_data);
52 yield Ok(bytes);
53
54 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
56 }
57 }
58}
59
60pub fn create_heartbeat_stream(
65 interval_secs: u64,
66) -> impl Stream<Item = Result<Bytes>> + Send + 'static {
67 async_stream::stream! {
68 let mut interval = tokio::time::interval(
69 tokio::time::Duration::from_secs(interval_secs)
70 );
71
72 loop {
73 interval.tick().await;
74
75 let heartbeat = format_sse_event(
76 "heartbeat",
77 Some("heartbeat"),
78 Some(&chrono::Utc::now().timestamp().to_string())
79 );
80
81 yield Ok(Bytes::from(heartbeat));
82 }
83 }
84}
85
86pub fn merge_sse_streams<S1, S2>(
91 stream1: S1,
92 stream2: S2,
93) -> impl Stream<Item = Result<Bytes>> + Send + 'static
94where
95 S1: Stream<Item = Result<Bytes>> + Send + 'static,
96 S2: Stream<Item = Result<Bytes>> + Send + 'static,
97{
98 use futures::stream::select;
99
100 select(stream1.map(|item| (1, item)), stream2.map(|item| (2, item))).map(|(_, result)| result)
101}
102
103pub fn validate_sse_event(event: &str) -> Result<()> {
108 if event.contains('\0') {
110 return Err(LambdaError::Sse(
111 "SSE events cannot contain null bytes".to_string(),
112 ));
113 }
114
115 if event.len() > 1_048_576 {
117 debug!("Warning: SSE event is very large ({} bytes)", event.len());
119 }
120
121 if event.contains('\r') && !event.contains("\r\n") {
123 return Err(LambdaError::Sse(
124 "SSE events should use LF or CRLF line endings, not standalone CR".to_string(),
125 ));
126 }
127
128 Ok(())
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    #[test]
    fn test_format_sse_event() {
        let event = format_sse_event("Hello, World!", Some("message"), Some("123"));

        // Fields come in a fixed order and the event is terminated by a blank line.
        assert_eq!(event, "id: 123\nevent: message\ndata: Hello, World!\n\n");
    }

    #[test]
    fn test_format_multiline_event() {
        let event = format_sse_event("Line 1\nLine 2\nLine 3", None, None);

        // Each payload line becomes its own `data:` field.
        assert_eq!(event, "data: Line 1\ndata: Line 2\ndata: Line 3\n\n");
    }

    #[tokio::test]
    async fn test_create_sse_stream() {
        let events = vec!["event1", "event2", "event3"];
        let stream = create_sse_stream(events, |s| format_sse_event(s, Some("test"), None));

        // Drain the whole stream so ordering and completeness are checked, not just the head.
        let chunks: Vec<String> = stream
            .map(|item| String::from_utf8(item.unwrap().to_vec()).unwrap())
            .collect()
            .await;

        assert_eq!(
            chunks,
            vec![
                "event: test\ndata: event1\n\n",
                "event: test\ndata: event2\n\n",
                "event: test\ndata: event3\n\n",
            ]
        );
    }

    #[test]
    fn test_validate_sse_event() {
        assert!(validate_sse_event("Normal event").is_ok());
        assert!(validate_sse_event("Event\nwith\nnewlines").is_ok());
        assert!(validate_sse_event("Event with\0null byte").is_err());
        assert!(validate_sse_event("Event with\rstandalone CR").is_err());
        assert!(validate_sse_event("Event with\r\nCRLF").is_ok());
    }

    #[tokio::test]
    async fn test_merge_streams() {
        let stream1 = stream::iter(vec![
            Ok(Bytes::from("stream1-1")),
            Ok(Bytes::from("stream1-2")),
        ]);

        let stream2 = stream::iter(vec![
            Ok(Bytes::from("stream2-1")),
            Ok(Bytes::from("stream2-2")),
        ]);

        let merged = merge_sse_streams(stream1, stream2);
        let mut payloads: Vec<Bytes> = merged.map(|item| item.unwrap()).collect().await;

        // Interleaving order is an implementation detail; check that every item arrives once.
        payloads.sort();
        assert_eq!(
            payloads,
            vec![
                Bytes::from("stream1-1"),
                Bytes::from("stream1-2"),
                Bytes::from("stream2-1"),
                Bytes::from("stream2-2"),
            ]
        );
    }
}