1use bytes::Bytes;
21use futures_util::Stream;
22use http::{header, StatusCode};
23use http_body_util::Full;
24use std::fmt::Write;
25
26use crate::response::{IntoResponse, Response};
27
/// A single Server-Sent Events (SSE) message.
///
/// Create one with [`SseEvent::new`], optionally chain [`SseEvent::event`],
/// [`SseEvent::id`] and [`SseEvent::retry`], then render it to its wire form
/// with [`SseEvent::to_sse_string`].
#[derive(Debug, Clone, Default)]
pub struct SseEvent {
    /// Payload text. Every line of it is written as its own `data:` field.
    pub data: String,
    /// Optional event name, written as the `event:` field.
    pub event: Option<String>,
    /// Optional event identifier, written as the `id:` field.
    pub id: Option<String>,
    /// Optional value of the `retry:` field (the SSE specification interprets
    /// it as a reconnection delay in milliseconds).
    pub retry: Option<u64>,
}

impl SseEvent {
    /// Creates an event carrying `data`, with no event name, id or retry value.
    pub fn new(data: impl Into<String>) -> Self {
        Self {
            data: data.into(),
            event: None,
            id: None,
            retry: None,
        }
    }

    /// Sets the event name and returns the updated event.
    pub fn event(mut self, event: impl Into<String>) -> Self {
        self.event = Some(event.into());
        self
    }

    /// Sets the event identifier and returns the updated event.
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the `retry:` value and returns the updated event.
    pub fn retry(mut self, retry: u64) -> Self {
        self.retry = Some(retry);
        self
    }

    /// Renders the event in the `text/event-stream` wire format.
    ///
    /// Fields are written in the order `event`, `id`, `retry`, then one
    /// `data:` line per line of [`SseEvent::data`], followed by the blank line
    /// that terminates the event:
    ///
    /// ```text
    /// event: message
    /// id: 1
    /// retry: 3000
    /// data: first line
    /// data: second line
    ///
    /// ```
    ///
    /// Line-break handling:
    ///
    /// * `data` is split on `\n`, `\r\n` **and** bare `\r`, because SSE
    ///   parsers treat all three as line terminators. A trailing line break
    ///   does not produce an extra empty `data:` line, and empty data produces
    ///   no `data:` line at all.
    /// * `\r` and `\n` are removed from the `event` and `id` values; these are
    ///   single-line fields, and an embedded line break would otherwise start
    ///   a new, caller-unintended field on the wire.
    pub fn to_sse_string(&self) -> String {
        let mut output = String::new();

        if let Some(event) = self.event.as_deref() {
            Self::push_field(&mut output, "event", Self::single_line(event));
        }

        if let Some(id) = self.id.as_deref() {
            Self::push_field(&mut output, "id", Self::single_line(id));
        }

        if let Some(retry) = self.retry {
            Self::push_field(&mut output, "retry", retry);
        }

        // `str::lines` only understands `\n` and `\r\n`; fold bare `\r` into
        // `\n` first so it cannot slip through inside a `data:` line.
        let data: std::borrow::Cow<'_, str> = if self.data.contains('\r') {
            std::borrow::Cow::Owned(self.data.replace("\r\n", "\n").replace('\r', "\n"))
        } else {
            std::borrow::Cow::Borrowed(self.data.as_str())
        };
        for line in data.lines() {
            Self::push_field(&mut output, "data", line);
        }

        // Blank line: marks the end of the event.
        output.push('\n');

        output
    }

    /// Appends one `name: value` line to `output`.
    fn push_field(output: &mut String, name: &str, value: impl std::fmt::Display) {
        writeln!(output, "{}: {}", name, value).expect("formatting into a String cannot fail");
    }

    /// Returns `value` without any `\r` / `\n` characters, borrowing when
    /// there is nothing to remove.
    fn single_line(value: &str) -> std::borrow::Cow<'_, str> {
        let is_line_break = |c: char| matches!(c, '\r' | '\n');
        if value.contains(is_line_break) {
            std::borrow::Cow::Owned(value.chars().filter(|&c| !is_line_break(c)).collect())
        } else {
            std::borrow::Cow::Borrowed(value)
        }
    }
}
114
/// Wrapper that pairs an event stream with an optional keep-alive interval.
pub struct Sse<S> {
    // Stored by `new`, but not read by anything in this part of the file,
    // hence the lint allowance.
    #[allow(dead_code)]
    stream: S,
    /// Interval configured through [`Sse::keep_alive`]; `None` until it is set.
    keep_alive: Option<std::time::Duration>,
}

impl<S> Sse<S> {
    /// Wraps `stream` with no keep-alive interval configured.
    pub fn new(stream: S) -> Self {
        let keep_alive = None;
        Self { stream, keep_alive }
    }

    /// Records `interval` as the keep-alive interval and returns the wrapper.
    pub fn keep_alive(self, interval: std::time::Duration) -> Self {
        Self {
            keep_alive: Some(interval),
            ..self
        }
    }
}
157
158impl<S, E> IntoResponse for Sse<S>
162where
163 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
164 E: std::error::Error + Send + Sync + 'static,
165{
166 fn into_response(self) -> Response {
167 http::Response::builder()
175 .status(StatusCode::OK)
176 .header(header::CONTENT_TYPE, "text/event-stream")
177 .header(header::CACHE_CONTROL, "no-cache")
178 .header(header::CONNECTION, "keep-alive")
179 .body(Full::new(Bytes::new()))
180 .unwrap()
181 }
182}
183
184pub fn sse_from_iter<I, E>(
188 events: I,
189) -> Sse<futures_util::stream::Iter<std::vec::IntoIter<Result<SseEvent, E>>>>
190where
191 I: IntoIterator<Item = Result<SseEvent, E>>,
192{
193 use futures_util::stream;
194 let vec: Vec<_> = events.into_iter().collect();
195 Sse::new(stream::iter(vec))
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// A data-only event renders as one `data:` line plus the blank terminator.
    #[test]
    fn test_sse_event_basic() {
        let rendered = SseEvent::new("Hello, World!").to_sse_string();
        assert_eq!(rendered, "data: Hello, World!\n\n");
    }

    /// The event name shows up as an `event:` line next to the data.
    #[test]
    fn test_sse_event_with_event_type() {
        let rendered = SseEvent::new("Hello").event("greeting").to_sse_string();
        for expected in ["event: greeting\n", "data: Hello\n"] {
            assert!(rendered.contains(expected));
        }
    }

    /// The identifier shows up as an `id:` line next to the data.
    #[test]
    fn test_sse_event_with_id() {
        let rendered = SseEvent::new("Hello").id("123").to_sse_string();
        for expected in ["id: 123\n", "data: Hello\n"] {
            assert!(rendered.contains(expected));
        }
    }

    /// The retry value shows up as a `retry:` line next to the data.
    #[test]
    fn test_sse_event_with_retry() {
        let rendered = SseEvent::new("Hello").retry(5000).to_sse_string();
        for expected in ["retry: 5000\n", "data: Hello\n"] {
            assert!(rendered.contains(expected));
        }
    }

    /// Every line of multi-line data gets its own `data:` line.
    #[test]
    fn test_sse_event_multiline_data() {
        let rendered = SseEvent::new("Line 1\nLine 2\nLine 3").to_sse_string();
        for expected in ["data: Line 1\n", "data: Line 2\n", "data: Line 3\n"] {
            assert!(rendered.contains(expected));
        }
    }

    /// With every field set, all four field kinds appear and the event ends
    /// with a blank line.
    #[test]
    fn test_sse_event_full() {
        let rendered = SseEvent::new("Hello")
            .event("message")
            .id("1")
            .retry(3000)
            .to_sse_string();

        for expected in [
            "event: message\n",
            "id: 1\n",
            "retry: 3000\n",
            "data: Hello\n",
        ] {
            assert!(rendered.contains(expected));
        }

        assert!(rendered.ends_with("\n\n"));
    }

    /// The response reports `200 OK` and carries the three event-stream headers.
    #[test]
    fn test_sse_response_headers() {
        use futures_util::stream;

        let events = vec![Ok::<SseEvent, std::convert::Infallible>(SseEvent::new(
            "test",
        ))];
        let response = Sse::new(stream::iter(events)).into_response();

        assert_eq!(response.status(), StatusCode::OK);
        for (name, expected) in [
            (header::CONTENT_TYPE, "text/event-stream"),
            (header::CACHE_CONTROL, "no-cache"),
            (header::CONNECTION, "keep-alive"),
        ] {
            assert_eq!(response.headers().get(name).unwrap(), expected);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // For arbitrary single-line data and optional fields, the rendered
        // event contains every configured field, ends with a blank line, and
        // the response built from it carries the event-stream headers.
        #[test]
        fn prop_sse_response_format(
            // Single-line alphanumeric payload.
            data in "[a-zA-Z0-9 ]{1,50}",
            // Optional identifier-like event name.
            event_type in proptest::option::of("[a-zA-Z][a-zA-Z0-9_]{0,20}"),
            // Optional alphanumeric event id.
            event_id in proptest::option::of("[a-zA-Z0-9]{1,10}"),
            // Optional retry value.
            retry_time in proptest::option::of(1000u64..60000u64),
        ) {
            use futures_util::stream;

            let mut event = SseEvent::new(data.as_str());
            if let Some(name) = event_type.as_deref() {
                event = event.event(name);
            }
            if let Some(identifier) = event_id.as_deref() {
                event = event.id(identifier);
            }
            if let Some(delay) = retry_time {
                event = event.retry(delay);
            }

            let rendered = event.to_sse_string();

            prop_assert!(
                rendered.ends_with("\n\n"),
                "SSE event must end with double newline, got: {:?}",
                rendered
            );

            let data_field = format!("data: {}", data);
            prop_assert!(
                rendered.contains(&data_field),
                "SSE event must contain data field with 'data: ' prefix"
            );

            if let Some(name) = event_type.as_deref() {
                let event_field = format!("event: {}", name);
                prop_assert!(
                    rendered.contains(&event_field),
                    "SSE event must contain event type with 'event: ' prefix"
                );
            }

            if let Some(identifier) = event_id.as_deref() {
                let id_field = format!("id: {}", identifier);
                prop_assert!(
                    rendered.contains(&id_field),
                    "SSE event must contain ID with 'id: ' prefix"
                );
            }

            if let Some(delay) = retry_time {
                let retry_field = format!("retry: {}", delay);
                prop_assert!(
                    rendered.contains(&retry_field),
                    "SSE event must contain retry with 'retry: ' prefix"
                );
            }

            let events = vec![Ok::<SseEvent, std::convert::Infallible>(event)];
            let response = Sse::new(stream::iter(events)).into_response();
            let headers = response.headers();

            prop_assert_eq!(
                headers.get(header::CONTENT_TYPE).map(|v| v.to_str().unwrap()),
                Some("text/event-stream"),
                "SSE response must have Content-Type: text/event-stream"
            );

            prop_assert_eq!(
                headers.get(header::CACHE_CONTROL).map(|v| v.to_str().unwrap()),
                Some("no-cache"),
                "SSE response must have Cache-Control: no-cache"
            );

            prop_assert_eq!(
                headers.get(header::CONNECTION).map(|v| v.to_str().unwrap()),
                Some("keep-alive"),
                "SSE response must have Connection: keep-alive"
            );
        }

        // For data made of one to four `\n`-joined lines, each line appears
        // with a `data: ` prefix and the event ends with a blank line.
        #[test]
        fn prop_sse_multiline_data_format(
            lines in proptest::collection::vec("[a-zA-Z0-9 ]{1,30}", 1..5),
        ) {
            let rendered = SseEvent::new(lines.join("\n")).to_sse_string();

            for line in &lines {
                let data_field = format!("data: {}", line);
                prop_assert!(
                    rendered.contains(&data_field),
                    "Each line of multiline data must be prefixed with 'data: '"
                );
            }

            prop_assert!(
                rendered.ends_with("\n\n"),
                "SSE event must end with double newline"
            );
        }
    }
}