1use bytes::Bytes;
49use futures_util::Stream;
50use http::{header, StatusCode};
51use http_body_util::Full;
52use pin_project_lite::pin_project;
53use std::fmt::Write;
54use std::pin::Pin;
55use std::task::{Context, Poll};
56use std::time::Duration;
57
58use crate::response::{IntoResponse, Response};
59
/// A single server-sent event, serialized with [`SseEvent::to_sse_string`].
///
/// Build one with [`SseEvent::new`] (a data event), [`SseEvent::json_data`]
/// (JSON-encoded data) or [`SseEvent::comment`] (a comment), then chain the
/// `event`, `id` and `retry` builder methods.
#[derive(Debug, Clone, Default)]
pub struct SseEvent {
    /// Event payload, written as `data:` line(s).
    pub data: String,
    /// Optional event type, written as an `event:` line.
    pub event: Option<String>,
    /// Optional event ID, written as an `id:` line.
    pub id: Option<String>,
    /// Optional `retry:` value; presumably a reconnection delay in
    /// milliseconds (the unit the SSE format uses) — confirm with callers.
    pub retry: Option<u64>,
    /// Comment text; set only by [`SseEvent::comment`]. When present the event
    /// is serialized as a comment and the other fields are not written.
    comment: Option<String>,
}
82
83impl SseEvent {
84 pub fn new(data: impl Into<String>) -> Self {
86 Self {
87 data: data.into(),
88 event: None,
89 id: None,
90 retry: None,
91 comment: None,
92 }
93 }
94
95 pub fn comment(text: impl Into<String>) -> Self {
99 Self {
100 data: String::new(),
101 event: None,
102 id: None,
103 retry: None,
104 comment: Some(text.into()),
105 }
106 }
107
108 pub fn event(mut self, event: impl Into<String>) -> Self {
110 self.event = Some(event.into());
111 self
112 }
113
114 pub fn id(mut self, id: impl Into<String>) -> Self {
116 self.id = Some(id.into());
117 self
118 }
119
120 pub fn retry(mut self, retry: u64) -> Self {
122 self.retry = Some(retry);
123 self
124 }
125
126 pub fn json_data<T: serde::Serialize>(data: &T) -> Result<Self, serde_json::Error> {
128 Ok(Self::new(serde_json::to_string(data)?))
129 }
130
131 pub fn to_sse_string(&self) -> String {
141 let mut output = String::new();
142
143 if let Some(ref comment) = self.comment {
145 writeln!(output, ": {}", comment).unwrap();
146 output.push('\n');
147 return output;
148 }
149
150 if let Some(ref event) = self.event {
152 writeln!(output, "event: {}", event).unwrap();
153 }
154
155 if let Some(ref id) = self.id {
157 writeln!(output, "id: {}", id).unwrap();
158 }
159
160 if let Some(retry) = self.retry {
162 writeln!(output, "retry: {}", retry).unwrap();
163 }
164
165 for line in self.data.lines() {
167 writeln!(output, "data: {}", line).unwrap();
168 }
169
170 if self.data.is_empty() && self.comment.is_none() {
172 writeln!(output, "data:").unwrap();
173 }
174
175 output.push('\n');
177
178 output
179 }
180
181 pub fn to_bytes(&self) -> Bytes {
183 Bytes::from(self.to_sse_string())
184 }
185}
186
/// Settings for periodic keep-alive comments on an SSE connection.
///
/// Create it with [`KeepAlive::new`] or [`Default`], adjust it with the
/// `interval` and `text` builder methods, and turn it into a comment event
/// with [`KeepAlive::event`].
#[derive(Debug, Clone)]
pub struct KeepAlive {
    /// How often a keep-alive should be sent (15 seconds by default).
    interval: Duration,
    /// Text carried by the keep-alive comment (`"keep-alive"` by default).
    text: String,
}
208
209impl Default for KeepAlive {
210 fn default() -> Self {
211 Self {
212 interval: Duration::from_secs(15),
213 text: "keep-alive".to_string(),
214 }
215 }
216}
217
218impl KeepAlive {
219 pub fn new() -> Self {
221 Self::default()
222 }
223
224 pub fn interval(mut self, interval: Duration) -> Self {
226 self.interval = interval;
227 self
228 }
229
230 pub fn text(mut self, text: impl Into<String>) -> Self {
232 self.text = text.into();
233 self
234 }
235
236 pub fn get_interval(&self) -> Duration {
238 self.interval
239 }
240
241 pub fn event(&self) -> SseEvent {
243 SseEvent::comment(&self.text)
244 }
245}
246
/// Server-sent-events response wrapper around a stream of events.
///
/// The `IntoResponse` impl requires `stream` to yield `Result<SseEvent, E>`
/// items. Keep-alive is off until [`Sse::keep_alive`] is called.
pub struct Sse<S> {
    /// The event source.
    stream: S,
    /// Keep-alive configuration; `None` unless set via [`Sse::keep_alive`].
    keep_alive: Option<KeepAlive>,
}
271
272impl<S> Sse<S> {
273 pub fn new(stream: S) -> Self {
275 Self {
276 stream,
277 keep_alive: None,
278 }
279 }
280
281 pub fn keep_alive(mut self, config: KeepAlive) -> Self {
295 self.keep_alive = Some(config);
296 self
297 }
298
299 pub fn get_keep_alive(&self) -> Option<&KeepAlive> {
301 self.keep_alive.as_ref()
302 }
303}
304
pin_project! {
    /// Stream adapter that serializes each [`SseEvent`] yielded by `inner`
    /// into [`Bytes`] and, when a keep-alive timer and keep-alive settings are
    /// both present, emits the keep-alive comment on timer ticks while `inner`
    /// has nothing ready.
    pub struct SseStream<S> {
        // Event source; structurally pinned so `!Unpin` streams can be polled.
        #[pin]
        inner: S,
        // Keep-alive settings whose comment is emitted when the timer ticks.
        keep_alive: Option<KeepAlive>,
        // NOTE(review): presumably built from `keep_alive`'s interval; the
        // construction site is not visible here — confirm.
        #[pin]
        keep_alive_timer: Option<tokio::time::Interval>,
    }
}
316
317impl<S, E> Stream for SseStream<S>
318where
319 S: Stream<Item = Result<SseEvent, E>>,
320{
321 type Item = Result<Bytes, E>;
322
323 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324 let this = self.project();
325
326 match this.inner.poll_next(cx) {
328 Poll::Ready(Some(Ok(event))) => {
329 return Poll::Ready(Some(Ok(event.to_bytes())));
330 }
331 Poll::Ready(Some(Err(e))) => {
332 return Poll::Ready(Some(Err(e)));
333 }
334 Poll::Ready(None) => {
335 return Poll::Ready(None);
336 }
337 Poll::Pending => {}
338 }
339
340 if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() {
342 if timer.poll_tick(cx).is_ready() {
343 if let Some(keep_alive) = this.keep_alive {
344 let event = keep_alive.event();
345 return Poll::Ready(Some(Ok(event.to_bytes())));
346 }
347 }
348 }
349
350 Poll::Pending
351 }
352}
353
354impl<S, E> IntoResponse for Sse<S>
358where
359 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
360 E: std::error::Error + Send + Sync + 'static,
361{
362 fn into_response(self) -> Response {
363 let _ = self.stream; let _ = self.keep_alive; http::Response::builder()
375 .status(StatusCode::OK)
376 .header(header::CONTENT_TYPE, "text/event-stream")
377 .header(header::CACHE_CONTROL, "no-cache")
378 .header(header::CONNECTION, "keep-alive")
379 .header("X-Accel-Buffering", "no") .body(Full::new(Bytes::new()))
381 .unwrap()
382 }
383}
384
385pub async fn collect_sse_events<S, E>(stream: S) -> Result<Bytes, E>
389where
390 S: Stream<Item = Result<SseEvent, E>> + Send,
391{
392 use futures_util::StreamExt;
393
394 let mut buffer = Vec::new();
395 futures_util::pin_mut!(stream);
396
397 while let Some(result) = stream.next().await {
398 let event = result?;
399 buffer.extend_from_slice(&event.to_bytes());
400 }
401
402 Ok(Bytes::from(buffer))
403}
404
405pub fn sse_response<I>(events: I) -> Response
422where
423 I: IntoIterator<Item = SseEvent>,
424{
425 let mut buffer = String::new();
426 for event in events {
427 buffer.push_str(&event.to_sse_string());
428 }
429
430 http::Response::builder()
431 .status(StatusCode::OK)
432 .header(header::CONTENT_TYPE, "text/event-stream")
433 .header(header::CACHE_CONTROL, "no-cache")
434 .header(header::CONNECTION, "keep-alive")
435 .header("X-Accel-Buffering", "no")
436 .body(Full::new(Bytes::from(buffer)))
437 .unwrap()
438}
439
440pub fn sse_from_iter<I, E>(
444 events: I,
445) -> Sse<futures_util::stream::Iter<std::vec::IntoIter<Result<SseEvent, E>>>>
446where
447 I: IntoIterator<Item = Result<SseEvent, E>>,
448{
449 use futures_util::stream;
450 let vec: Vec<_> = events.into_iter().collect();
451 Sse::new(stream::iter(vec))
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_sse_event_basic() {
        let event = SseEvent::new("Hello, World!");
        let output = event.to_sse_string();
        assert_eq!(output, "data: Hello, World!\n\n");
    }

    #[test]
    fn test_sse_event_with_event_type() {
        let event = SseEvent::new("Hello").event("greeting");
        let output = event.to_sse_string();
        assert!(output.contains("event: greeting\n"));
        assert!(output.contains("data: Hello\n"));
    }

    #[test]
    fn test_sse_event_with_id() {
        let event = SseEvent::new("Hello").id("123");
        let output = event.to_sse_string();
        assert!(output.contains("id: 123\n"));
        assert!(output.contains("data: Hello\n"));
    }

    #[test]
    fn test_sse_event_with_retry() {
        let event = SseEvent::new("Hello").retry(5000);
        let output = event.to_sse_string();
        assert!(output.contains("retry: 5000\n"));
        assert!(output.contains("data: Hello\n"));
    }

    #[test]
    fn test_sse_event_multiline_data() {
        let event = SseEvent::new("Line 1\nLine 2\nLine 3");
        let output = event.to_sse_string();
        assert!(output.contains("data: Line 1\n"));
        assert!(output.contains("data: Line 2\n"));
        assert!(output.contains("data: Line 3\n"));
    }

    #[test]
    fn test_sse_event_full() {
        let event = SseEvent::new("Hello").event("message").id("1").retry(3000);
        let output = event.to_sse_string();

        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 1\n"));
        assert!(output.contains("retry: 3000\n"));
        assert!(output.contains("data: Hello\n"));

        assert!(output.ends_with("\n\n"));
    }

    // Pins the exact field order: event, id, retry, then data.
    #[test]
    fn test_sse_event_exact_field_order() {
        let output = SseEvent::new("Hello")
            .event("message")
            .id("1")
            .retry(3000)
            .to_sse_string();
        assert_eq!(output, "event: message\nid: 1\nretry: 3000\ndata: Hello\n\n");
    }

    // Empty data still produces a data line so the event is dispatched.
    #[test]
    fn test_sse_event_empty_data() {
        assert_eq!(SseEvent::new("").to_sse_string(), "data:\n\n");
    }

    #[test]
    fn test_sse_comment() {
        assert_eq!(SseEvent::comment("ping").to_sse_string(), ": ping\n\n");
    }

    #[test]
    fn test_sse_event_json_data() {
        let event = SseEvent::json_data(&vec![1, 2, 3]).unwrap();
        assert_eq!(event.to_sse_string(), "data: [1,2,3]\n\n");
    }

    #[test]
    fn test_keep_alive_defaults_and_builder() {
        let default = KeepAlive::new();
        assert_eq!(default.get_interval(), Duration::from_secs(15));
        assert_eq!(default.event().to_sse_string(), ": keep-alive\n\n");

        let custom = KeepAlive::new()
            .interval(Duration::from_secs(1))
            .text("ping");
        assert_eq!(custom.get_interval(), Duration::from_secs(1));
        assert_eq!(custom.event().to_sse_string(), ": ping\n\n");
    }

    #[test]
    fn test_sse_keep_alive_config() {
        use futures_util::stream;

        let events: Vec<Result<SseEvent, std::convert::Infallible>> = Vec::new();
        let sse = Sse::new(stream::iter(events));
        assert!(sse.get_keep_alive().is_none());

        let sse = sse.keep_alive(KeepAlive::new().interval(Duration::from_secs(5)));
        assert_eq!(
            sse.get_keep_alive().map(KeepAlive::get_interval),
            Some(Duration::from_secs(5))
        );
    }

    #[test]
    fn test_sse_from_iter_has_no_keep_alive() {
        let sse = sse_from_iter(vec![Ok::<_, std::convert::Infallible>(SseEvent::new("a"))]);
        assert!(sse.get_keep_alive().is_none());
    }

    #[test]
    fn test_sse_response_helper_headers() {
        let response = sse_response(vec![SseEvent::new("a"), SseEvent::new("b")]);

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(response.headers().get("X-Accel-Buffering").unwrap(), "no");
    }

    #[test]
    fn test_sse_response_headers() {
        use futures_util::stream;

        let events: Vec<Result<SseEvent, std::convert::Infallible>> =
            vec![Ok(SseEvent::new("test"))];
        let sse = Sse::new(stream::iter(events));
        let response = sse.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(
            response.headers().get(header::CONNECTION).unwrap(),
            "keep-alive"
        );
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_sse_response_format(
            data in "[a-zA-Z0-9 ]{1,50}",
            event_type in proptest::option::of("[a-zA-Z][a-zA-Z0-9_]{0,20}"),
            event_id in proptest::option::of("[a-zA-Z0-9]{1,10}"),
            retry_time in proptest::option::of(1000u64..60000u64),
        ) {
            use futures_util::stream;

            let mut event = SseEvent::new(data.clone());
            if let Some(ref et) = event_type {
                event = event.event(et.clone());
            }
            if let Some(ref id) = event_id {
                event = event.id(id.clone());
            }
            if let Some(retry) = retry_time {
                event = event.retry(retry);
            }

            let sse_string = event.to_sse_string();

            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline, got: {:?}",
                sse_string
            );

            prop_assert!(
                sse_string.contains(&format!("data: {}", data)),
                "SSE event must contain data field with 'data: ' prefix"
            );

            if let Some(ref et) = event_type {
                prop_assert!(
                    sse_string.contains(&format!("event: {}", et)),
                    "SSE event must contain event type with 'event: ' prefix"
                );
            }

            if let Some(ref id) = event_id {
                prop_assert!(
                    sse_string.contains(&format!("id: {}", id)),
                    "SSE event must contain ID with 'id: ' prefix"
                );
            }

            if let Some(retry) = retry_time {
                prop_assert!(
                    sse_string.contains(&format!("retry: {}", retry)),
                    "SSE event must contain retry with 'retry: ' prefix"
                );
            }

            let events: Vec<Result<SseEvent, std::convert::Infallible>> = vec![Ok(event)];
            let sse = Sse::new(stream::iter(events));
            let response = sse.into_response();

            prop_assert_eq!(
                response.headers().get(header::CONTENT_TYPE).map(|v| v.to_str().unwrap()),
                Some("text/event-stream"),
                "SSE response must have Content-Type: text/event-stream"
            );

            prop_assert_eq!(
                response.headers().get(header::CACHE_CONTROL).map(|v| v.to_str().unwrap()),
                Some("no-cache"),
                "SSE response must have Cache-Control: no-cache"
            );

            prop_assert_eq!(
                response.headers().get(header::CONNECTION).map(|v| v.to_str().unwrap()),
                Some("keep-alive"),
                "SSE response must have Connection: keep-alive"
            );
        }

        #[test]
        fn prop_sse_multiline_data_format(
            lines in proptest::collection::vec("[a-zA-Z0-9 ]{1,30}", 1..5),
        ) {
            let data = lines.join("\n");
            let event = SseEvent::new(data.clone());
            let sse_string = event.to_sse_string();

            for line in lines.iter() {
                prop_assert!(
                    sse_string.contains(&format!("data: {}", line)),
                    "Each line of multiline data must be prefixed with 'data: '"
                );
            }

            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline"
            );
        }
    }
}