1use std::time::Duration;
21
22use bytes::Bytes;
23use futures::Stream;
24
25use crate::body::{body_from_stream, BoxBody, BoxBodyError};
26use crate::response::IntoResponse;
27
/// A single Server-Sent Events message.
///
/// Create one with [`SseEvent::data`] or [`SseEvent::comment`], optionally
/// chain the `with_*` builders, and serialise it with [`SseEvent::to_wire`].
///
/// `PartialEq`/`Eq` are derived so events can be compared directly (e.g. in
/// tests or when de-duplicating), in addition to `Debug`, `Clone`, `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the `id:` field, if any. `\r`/`\n` are removed when serialised.
    pub id: Option<String>,
    /// Value of the `event:` (event type) field, if any. `\r`/`\n` are removed
    /// when serialised.
    pub event: Option<String>,
    /// Event payload; each `\n`-separated segment becomes its own `data:` line.
    pub data: String,
    /// Value of the `retry:` field in milliseconds, if any.
    pub retry: Option<u64>,
    /// Comment text. When set (only via [`SseEvent::comment`]), the event is
    /// serialised purely as `:`-prefixed comment lines and the other fields
    /// are not written.
    comment: Option<String>,
}
51
52impl SseEvent {
53 pub fn data(data: impl Into<String>) -> Self {
55 SseEvent {
56 data: data.into(),
57 ..Default::default()
58 }
59 }
60
61 pub fn comment(text: impl Into<String>) -> Self {
66 SseEvent {
67 comment: Some(text.into()),
68 ..Default::default()
69 }
70 }
71
72 pub fn with_id(mut self, id: impl Into<String>) -> Self {
74 self.id = Some(id.into());
75 self
76 }
77
78 pub fn with_event(mut self, event: impl Into<String>) -> Self {
80 self.event = Some(event.into());
81 self
82 }
83
84 pub fn with_retry(mut self, ms: u64) -> Self {
86 self.retry = Some(ms);
87 self
88 }
89
90 pub fn to_wire(&self) -> String {
97 if let Some(ref c) = self.comment {
98 let mut out = String::new();
99 for line in c.split('\n') {
100 out.push(':');
101 out.push_str(&line.replace('\r', ""));
102 out.push('\n');
103 }
104 out.push('\n');
105 return out;
106 }
107
108 let mut out = String::new();
109 if let Some(ref id) = self.id {
110 out.push_str("id: ");
111 out.push_str(&strip_cr_lf(id));
112 out.push('\n');
113 }
114 if let Some(ref ev) = self.event {
115 out.push_str("event: ");
116 out.push_str(&strip_cr_lf(ev));
117 out.push('\n');
118 }
119 if let Some(ms) = self.retry {
120 out.push_str("retry: ");
121 out.push_str(&ms.to_string());
122 out.push('\n');
123 }
124 for line in self.data.split('\n') {
125 out.push_str("data: ");
126 out.push_str(&line.replace('\r', ""));
127 out.push('\n');
128 }
129 out.push('\n');
130 out
131 }
132}
133
/// Returns a copy of `s` with every `\r` and `\n` removed.
///
/// Used for fields that must fit on a single wire line (`id`, `event`).
fn strip_cr_lf(s: &str) -> String {
    let mut single_line = s.to_owned();
    single_line.retain(|ch| !matches!(ch, '\r' | '\n'));
    single_line
}
137
/// An HTTP response that streams [`SseEvent`]s to the client.
///
/// Construct it with `SseResponse::new`. Its `IntoResponse` impl turns each
/// event into one body frame and sets the SSE response headers. Building one
/// and then dropping it does nothing, hence `#[must_use]`.
#[must_use = "an SseResponse does nothing until it is converted into an HTTP response"]
pub struct SseResponse<S> {
    stream: S,
}
146
147impl<S> SseResponse<S>
148where
149 S: Stream<Item = SseEvent> + Send + 'static,
150{
151 pub fn new(stream: S) -> Self {
153 SseResponse { stream }
154 }
155}
156
157impl<S> IntoResponse for SseResponse<S>
158where
159 S: Stream<Item = SseEvent> + Send + 'static,
160{
161 fn into_response(self) -> http::Response<BoxBody> {
162 use futures::StreamExt;
163 let framed = self.stream.map(|ev| {
164 let wire = ev.to_wire();
165 Ok::<_, BoxBodyError>(http_body::Frame::data(Bytes::from(wire)))
166 });
167 let body = body_from_stream(framed);
168 let mut res = http::Response::new(body);
169 let h = res.headers_mut();
170 h.insert(
171 http::header::CONTENT_TYPE,
172 http::HeaderValue::from_static("text/event-stream"),
173 );
174 h.insert(
175 http::header::CACHE_CONTROL,
176 http::HeaderValue::from_static("no-cache"),
177 );
178 h.insert(
179 http::header::CONNECTION,
180 http::HeaderValue::from_static("keep-alive"),
181 );
182 h.insert(
184 http::HeaderName::from_static("x-accel-buffering"),
185 http::HeaderValue::from_static("no"),
186 );
187 res
188 }
189}
190
191pub fn keep_alive<S>(stream: S, interval: Duration) -> impl Stream<Item = SseEvent>
198where
199 S: Stream<Item = SseEvent> + Send + 'static,
200{
201 let pings = futures::stream::unfold((), move |_| async move {
202 tokio::time::sleep(interval).await;
203 Some((SseEvent::comment("keepalive"), ()))
204 });
205 futures::stream::select(stream, pings)
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn data_event_wire_format() {
        let wire = SseEvent::data("hello").to_wire();
        assert_eq!(wire, "data: hello\n\n");
    }

    #[test]
    fn multi_line_data_splits_into_multiple_data_lines() {
        let wire = SseEvent::data("line one\nline two\nline three").to_wire();
        let expected = "data: line one\ndata: line two\ndata: line three\n\n";
        assert_eq!(wire, expected);
    }

    #[test]
    fn full_event_wire_format() {
        let wire = SseEvent::data("payload")
            .with_id("42")
            .with_event("update")
            .with_retry(3000)
            .to_wire();
        assert_eq!(wire, "id: 42\nevent: update\nretry: 3000\ndata: payload\n\n");
    }

    #[test]
    fn comment_event_wire_format() {
        assert_eq!(SseEvent::comment("keepalive").to_wire(), ":keepalive\n\n");
    }

    #[test]
    fn comment_with_newline_is_split() {
        assert_eq!(
            SseEvent::comment("line1\nline2").to_wire(),
            ":line1\n:line2\n\n"
        );
    }

    #[test]
    fn cr_lf_is_stripped_from_scalar_fields() {
        let wire = SseEvent::data("ok")
            .with_id("1\n2\r3")
            .with_event("a\nb")
            .to_wire();
        for line in ["id: 123\n", "event: ab\n"] {
            assert!(wire.contains(line), "{:?} not found in {:?}", line, wire);
        }
    }

    #[test]
    fn carriage_return_in_data_is_dropped() {
        assert_eq!(
            SseEvent::data("a\rb\nc\rd").to_wire(),
            "data: ab\ndata: cd\n\n"
        );
    }

    #[tokio::test]
    async fn sse_response_sets_required_headers() {
        let events = futures::stream::iter(vec![SseEvent::data("hi")]);
        let response = SseResponse::new(events).into_response();
        assert_eq!(response.status(), http::StatusCode::OK);

        let expected = [
            (http::header::CONTENT_TYPE, "text/event-stream"),
            (http::header::CACHE_CONTROL, "no-cache"),
            (http::header::CONNECTION, "keep-alive"),
            (http::HeaderName::from_static("x-accel-buffering"), "no"),
        ];
        for (name, value) in expected {
            let actual = response.headers().get(&name).unwrap();
            assert_eq!(actual, value, "unexpected value for header {}", name);
        }
    }

    #[tokio::test]
    async fn sse_response_streams_event_bytes() {
        use http_body_util::BodyExt;

        let events = futures::stream::iter(vec![
            SseEvent::data("first"),
            SseEvent::data("second").with_event("update"),
        ]);
        let body = SseResponse::new(events).into_response().into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        assert_eq!(
            std::str::from_utf8(&bytes).unwrap(),
            "data: first\n\nevent: update\ndata: second\n\n"
        );
    }

    #[tokio::test]
    async fn keep_alive_interleaves_pings() {
        use futures::StreamExt;

        // One real event followed by a source that never finishes, so the
        // only other items the merged stream can yield are keep-alive pings.
        let source = futures::stream::iter(vec![SseEvent::data("real")])
            .chain(futures::stream::pending::<SseEvent>());
        let mut merged = Box::pin(keep_alive(source, Duration::from_millis(20)));

        let mut saw_ping = false;
        for _ in 0..2 {
            let event = tokio::time::timeout(Duration::from_millis(200), merged.next())
                .await
                .unwrap()
                .unwrap();
            saw_ping |= event.comment.is_some();
        }
        assert!(saw_ping, "expected at least one keep-alive ping");
    }
}