1use bytes::Bytes;
49use futures_util::Stream;
50use http::{header, StatusCode};
51use http_body_util::Full;
52use pin_project_lite::pin_project;
53use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef};
54use std::fmt::Write;
55use std::pin::Pin;
56use std::task::{Context, Poll};
57use std::time::Duration;
58
59use crate::response::{IntoResponse, Response};
60
/// A single Server-Sent Event.
///
/// Build one with [`SseEvent::new`] (data event) or [`SseEvent::comment`]
/// (comment-only event) and serialize it with [`SseEvent::to_sse_string`] or
/// [`SseEvent::to_bytes`].
#[derive(Debug, Clone, Default)]
pub struct SseEvent {
    /// Event payload, written as one or more `data:` lines.
    pub data: String,
    /// Optional event type, written as an `event:` line.
    pub event: Option<String>,
    /// Optional event id, written as an `id:` line.
    pub id: Option<String>,
    /// Optional reconnection delay, written as a `retry:` line
    /// (the SSE spec interprets the value as milliseconds).
    pub retry: Option<u64>,
    /// When set, the event is serialized as a comment only and the other
    /// fields are not written.
    comment: Option<String>,
}
83
84impl SseEvent {
85 pub fn new(data: impl Into<String>) -> Self {
87 Self {
88 data: data.into(),
89 event: None,
90 id: None,
91 retry: None,
92 comment: None,
93 }
94 }
95
96 pub fn comment(text: impl Into<String>) -> Self {
100 Self {
101 data: String::new(),
102 event: None,
103 id: None,
104 retry: None,
105 comment: Some(text.into()),
106 }
107 }
108
109 pub fn event(mut self, event: impl Into<String>) -> Self {
111 self.event = Some(event.into());
112 self
113 }
114
115 pub fn id(mut self, id: impl Into<String>) -> Self {
117 self.id = Some(id.into());
118 self
119 }
120
121 pub fn retry(mut self, retry: u64) -> Self {
123 self.retry = Some(retry);
124 self
125 }
126
127 pub fn json_data<T: serde::Serialize>(data: &T) -> Result<Self, serde_json::Error> {
129 Ok(Self::new(serde_json::to_string(data)?))
130 }
131
132 pub fn to_sse_string(&self) -> String {
142 let mut output = String::new();
143
144 if let Some(ref comment) = self.comment {
146 writeln!(output, ": {}", comment).unwrap();
147 output.push('\n');
148 return output;
149 }
150
151 if let Some(ref event) = self.event {
153 writeln!(output, "event: {}", event).unwrap();
154 }
155
156 if let Some(ref id) = self.id {
158 writeln!(output, "id: {}", id).unwrap();
159 }
160
161 if let Some(retry) = self.retry {
163 writeln!(output, "retry: {}", retry).unwrap();
164 }
165
166 for line in self.data.lines() {
168 writeln!(output, "data: {}", line).unwrap();
169 }
170
171 if self.data.is_empty() && self.comment.is_none() {
173 writeln!(output, "data:").unwrap();
174 }
175
176 output.push('\n');
178
179 output
180 }
181
182 pub fn to_bytes(&self) -> Bytes {
184 Bytes::from(self.to_sse_string())
185 }
186}
187
/// Configuration for periodic keep-alive comments on an SSE connection.
///
/// The `Default` impl uses a 15-second interval and the comment text
/// `keep-alive`.
#[derive(Debug, Clone)]
pub struct KeepAlive {
    // How often a keep-alive comment should be emitted.
    interval: Duration,
    // Text of the keep-alive comment (serialized as `: <text>`).
    text: String,
}
209
210impl Default for KeepAlive {
211 fn default() -> Self {
212 Self {
213 interval: Duration::from_secs(15),
214 text: "keep-alive".to_string(),
215 }
216 }
217}
218
219impl KeepAlive {
220 pub fn new() -> Self {
222 Self::default()
223 }
224
225 pub fn interval(mut self, interval: Duration) -> Self {
227 self.interval = interval;
228 self
229 }
230
231 pub fn text(mut self, text: impl Into<String>) -> Self {
233 self.text = text.into();
234 self
235 }
236
237 pub fn get_interval(&self) -> Duration {
239 self.interval
240 }
241
242 pub fn event(&self) -> SseEvent {
244 SseEvent::comment(&self.text)
245 }
246}
247
/// Response wrapper for a stream of `Result<SseEvent, E>` items, sent with the
/// `text/event-stream` content type by its `IntoResponse` impl.
///
/// Create it with [`Sse::new`], and optionally attach a keep-alive
/// configuration with [`Sse::keep_alive`].
pub struct Sse<S> {
    // The event stream.
    stream: S,
    // Optional keep-alive configuration; `None` unless `keep_alive` is called.
    keep_alive: Option<KeepAlive>,
}
272
273impl<S> Sse<S> {
274 pub fn new(stream: S) -> Self {
276 Self {
277 stream,
278 keep_alive: None,
279 }
280 }
281
282 pub fn keep_alive(mut self, config: KeepAlive) -> Self {
296 self.keep_alive = Some(config);
297 self
298 }
299
300 pub fn get_keep_alive(&self) -> Option<&KeepAlive> {
302 self.keep_alive.as_ref()
303 }
304}
305
pin_project! {
    /// Stream adapter that serializes `SseEvent`s to bytes and, while the
    /// inner stream is pending, emits keep-alive comments on timer ticks.
    ///
    /// NOTE(review): no constructor for this type is visible in this chunk.
    pub struct SseStream<S> {
        // Source of events.
        #[pin]
        inner: S,
        // Keep-alive configuration; supplies the comment text.
        keep_alive: Option<KeepAlive>,
        // Timer driving keep-alive emission; `None` means no keep-alives.
        #[pin]
        keep_alive_timer: Option<tokio::time::Interval>,
    }
}
317
318impl<S, E> Stream for SseStream<S>
319where
320 S: Stream<Item = Result<SseEvent, E>>,
321{
322 type Item = Result<Bytes, E>;
323
324 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
325 let this = self.project();
326
327 match this.inner.poll_next(cx) {
329 Poll::Ready(Some(Ok(event))) => {
330 return Poll::Ready(Some(Ok(event.to_bytes())));
331 }
332 Poll::Ready(Some(Err(e))) => {
333 return Poll::Ready(Some(Err(e)));
334 }
335 Poll::Ready(None) => {
336 return Poll::Ready(None);
337 }
338 Poll::Pending => {}
339 }
340
341 if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() {
343 if timer.poll_tick(cx).is_ready() {
344 if let Some(keep_alive) = this.keep_alive {
345 let event = keep_alive.event();
346 return Poll::Ready(Some(Ok(event.to_bytes())));
347 }
348 }
349 }
350
351 Poll::Pending
352 }
353}
354
355impl<S, E> IntoResponse for Sse<S>
359where
360 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
361 E: std::error::Error + Send + Sync + 'static,
362{
363 fn into_response(self) -> Response {
364 let _ = self.stream; let _ = self.keep_alive; http::Response::builder()
376 .status(StatusCode::OK)
377 .header(header::CONTENT_TYPE, "text/event-stream")
378 .header(header::CACHE_CONTROL, "no-cache")
379 .header(header::CONNECTION, "keep-alive")
380 .header("X-Accel-Buffering", "no") .body(Full::new(Bytes::new()))
382 .unwrap()
383 }
384}
385
386impl<S> ResponseModifier for Sse<S> {
388 fn update_response(op: &mut Operation) {
389 let mut content = std::collections::HashMap::new();
390 content.insert(
391 "text/event-stream".to_string(),
392 MediaType {
393 schema: SchemaRef::Inline(serde_json::json!({
394 "type": "string",
395 "description": "Server-Sent Events stream. Events follow the SSE format: 'event: <type>\\ndata: <json>\\n\\n'",
396 "example": "event: message\ndata: {\"id\": 1, \"text\": \"Hello\"}\n\n"
397 })),
398 },
399 );
400
401 let response = ResponseSpec {
402 description: "Server-Sent Events stream for real-time updates".to_string(),
403 content: Some(content),
404 };
405 op.responses.insert("200".to_string(), response);
406 }
407}
408
409pub async fn collect_sse_events<S, E>(stream: S) -> Result<Bytes, E>
413where
414 S: Stream<Item = Result<SseEvent, E>> + Send,
415{
416 use futures_util::StreamExt;
417
418 let mut buffer = Vec::new();
419 futures_util::pin_mut!(stream);
420
421 while let Some(result) = stream.next().await {
422 let event = result?;
423 buffer.extend_from_slice(&event.to_bytes());
424 }
425
426 Ok(Bytes::from(buffer))
427}
428
429pub fn sse_response<I>(events: I) -> Response
446where
447 I: IntoIterator<Item = SseEvent>,
448{
449 let mut buffer = String::new();
450 for event in events {
451 buffer.push_str(&event.to_sse_string());
452 }
453
454 http::Response::builder()
455 .status(StatusCode::OK)
456 .header(header::CONTENT_TYPE, "text/event-stream")
457 .header(header::CACHE_CONTROL, "no-cache")
458 .header(header::CONNECTION, "keep-alive")
459 .header("X-Accel-Buffering", "no")
460 .body(Full::new(Bytes::from(buffer)))
461 .unwrap()
462}
463
464pub fn sse_from_iter<I, E>(
468 events: I,
469) -> Sse<futures_util::stream::Iter<std::vec::IntoIter<Result<SseEvent, E>>>>
470where
471 I: IntoIterator<Item = Result<SseEvent, E>>,
472{
473 use futures_util::stream;
474 let vec: Vec<_> = events.into_iter().collect();
475 Sse::new(stream::iter(vec))
476}
477
// Unit and property tests for SSE event serialization and response headers.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // A data-only event serializes to one `data:` line plus the terminating blank line.
    #[test]
    fn test_sse_event_basic() {
        let event = SseEvent::new("Hello, World!");
        let output = event.to_sse_string();
        assert_eq!(output, "data: Hello, World!\n\n");
    }

    // The event type is written as an `event:` line alongside the data.
    #[test]
    fn test_sse_event_with_event_type() {
        let event = SseEvent::new("Hello").event("greeting");
        let output = event.to_sse_string();
        assert!(output.contains("event: greeting\n"));
        assert!(output.contains("data: Hello\n"));
    }

    // The id is written as an `id:` line alongside the data.
    #[test]
    fn test_sse_event_with_id() {
        let event = SseEvent::new("Hello").id("123");
        let output = event.to_sse_string();
        assert!(output.contains("id: 123\n"));
        assert!(output.contains("data: Hello\n"));
    }

    // The retry value is written as a `retry:` line alongside the data.
    #[test]
    fn test_sse_event_with_retry() {
        let event = SseEvent::new("Hello").retry(5000);
        let output = event.to_sse_string();
        assert!(output.contains("retry: 5000\n"));
        assert!(output.contains("data: Hello\n"));
    }

    // Each line of multi-line data gets its own `data:` prefix.
    #[test]
    fn test_sse_event_multiline_data() {
        let event = SseEvent::new("Line 1\nLine 2\nLine 3");
        let output = event.to_sse_string();
        assert!(output.contains("data: Line 1\n"));
        assert!(output.contains("data: Line 2\n"));
        assert!(output.contains("data: Line 3\n"));
    }

    // All optional fields together, and the event ends with a blank line.
    #[test]
    fn test_sse_event_full() {
        let event = SseEvent::new("Hello").event("message").id("1").retry(3000);
        let output = event.to_sse_string();

        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 1\n"));
        assert!(output.contains("retry: 3000\n"));
        assert!(output.contains("data: Hello\n"));

        // The blank line is what dispatches the event on the client.
        assert!(output.ends_with("\n\n"));
    }

    // `Sse::into_response` sets the SSE status and headers.
    #[test]
    fn test_sse_response_headers() {
        use futures_util::stream;

        let events: Vec<Result<SseEvent, std::convert::Infallible>> =
            vec![Ok(SseEvent::new("test"))];
        let sse = Sse::new(stream::iter(events));
        let response = sse.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(
            response.headers().get(header::CONNECTION).unwrap(),
            "keep-alive"
        );
    }

    // Property tests: 100 generated cases each.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Serialized events contain every set field with its prefix, and the
        // response built from them carries the SSE headers.
        #[test]
        fn prop_sse_response_format(
            // Single-line data (no line breaks in the generated alphabet).
            data in "[a-zA-Z0-9 ]{1,50}",
            // Optional event type starting with a letter.
            event_type in proptest::option::of("[a-zA-Z][a-zA-Z0-9_]{0,20}"),
            // Optional alphanumeric id.
            event_id in proptest::option::of("[a-zA-Z0-9]{1,10}"),
            // Optional retry value.
            retry_time in proptest::option::of(1000u64..60000u64),
        ) {
            use futures_util::stream;

            // Build the event with whichever optional fields were generated.
            let mut event = SseEvent::new(data.clone());
            if let Some(ref et) = event_type {
                event = event.event(et.clone());
            }
            if let Some(ref id) = event_id {
                event = event.id(id.clone());
            }
            if let Some(retry) = retry_time {
                event = event.retry(retry);
            }

            let sse_string = event.to_sse_string();

            // Terminating blank line.
            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline, got: {:?}",
                sse_string
            );

            // Data field.
            prop_assert!(
                sse_string.contains(&format!("data: {}", data)),
                "SSE event must contain data field with 'data: ' prefix"
            );

            // Event type, when set.
            if let Some(ref et) = event_type {
                prop_assert!(
                    sse_string.contains(&format!("event: {}", et)),
                    "SSE event must contain event type with 'event: ' prefix"
                );
            }

            // Id, when set.
            if let Some(ref id) = event_id {
                prop_assert!(
                    sse_string.contains(&format!("id: {}", id)),
                    "SSE event must contain ID with 'id: ' prefix"
                );
            }

            // Retry, when set.
            if let Some(retry) = retry_time {
                prop_assert!(
                    sse_string.contains(&format!("retry: {}", retry)),
                    "SSE event must contain retry with 'retry: ' prefix"
                );
            }

            // Response headers for an `Sse` wrapping the same event.
            let events: Vec<Result<SseEvent, std::convert::Infallible>> = vec![Ok(event)];
            let sse = Sse::new(stream::iter(events));
            let response = sse.into_response();

            prop_assert_eq!(
                response.headers().get(header::CONTENT_TYPE).map(|v| v.to_str().unwrap()),
                Some("text/event-stream"),
                "SSE response must have Content-Type: text/event-stream"
            );

            prop_assert_eq!(
                response.headers().get(header::CACHE_CONTROL).map(|v| v.to_str().unwrap()),
                Some("no-cache"),
                "SSE response must have Cache-Control: no-cache"
            );

            prop_assert_eq!(
                response.headers().get(header::CONNECTION).map(|v| v.to_str().unwrap()),
                Some("keep-alive"),
                "SSE response must have Connection: keep-alive"
            );
        }

        // Multi-line data: every generated line appears with a `data: ` prefix.
        #[test]
        fn prop_sse_multiline_data_format(
            lines in proptest::collection::vec("[a-zA-Z0-9 ]{1,30}", 1..5),
        ) {
            let data = lines.join("\n");
            let event = SseEvent::new(data.clone());
            let sse_string = event.to_sse_string();

            for line in lines.iter() {
                prop_assert!(
                    sse_string.contains(&format!("data: {}", line)),
                    "Each line of multiline data must be prefixed with 'data: '"
                );
            }

            // Terminating blank line.
            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline"
            );
        }
    }
}