1use bytes::Bytes;
49use futures_util::Stream;
50use http::{header, StatusCode};
51
52use pin_project_lite::pin_project;
53use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef};
54use std::fmt::Write;
55use std::pin::Pin;
56use std::task::{Context, Poll};
57use std::time::Duration;
58
59use crate::response::{IntoResponse, Response};
60
61#[derive(Debug, Clone, Default)]
71pub struct SseEvent {
72 pub data: String,
74 pub event: Option<String>,
76 pub id: Option<String>,
78 pub retry: Option<u64>,
80 comment: Option<String>,
82}
83
84impl SseEvent {
85 pub fn new(data: impl Into<String>) -> Self {
87 Self {
88 data: data.into(),
89 event: None,
90 id: None,
91 retry: None,
92 comment: None,
93 }
94 }
95
96 pub fn comment(text: impl Into<String>) -> Self {
100 Self {
101 data: String::new(),
102 event: None,
103 id: None,
104 retry: None,
105 comment: Some(text.into()),
106 }
107 }
108
109 pub fn event(mut self, event: impl Into<String>) -> Self {
111 self.event = Some(event.into());
112 self
113 }
114
115 pub fn id(mut self, id: impl Into<String>) -> Self {
117 self.id = Some(id.into());
118 self
119 }
120
121 pub fn retry(mut self, retry: u64) -> Self {
123 self.retry = Some(retry);
124 self
125 }
126
127 pub fn json_data<T: serde::Serialize>(data: &T) -> Result<Self, serde_json::Error> {
129 Ok(Self::new(serde_json::to_string(data)?))
130 }
131
132 pub fn to_sse_string(&self) -> String {
142 let mut output = String::new();
143
144 if let Some(ref comment) = self.comment {
146 writeln!(output, ": {}", comment).unwrap();
147 output.push('\n');
148 return output;
149 }
150
151 if let Some(ref event) = self.event {
153 writeln!(output, "event: {}", event).unwrap();
154 }
155
156 if let Some(ref id) = self.id {
158 writeln!(output, "id: {}", id).unwrap();
159 }
160
161 if let Some(retry) = self.retry {
163 writeln!(output, "retry: {}", retry).unwrap();
164 }
165
166 for line in self.data.lines() {
168 writeln!(output, "data: {}", line).unwrap();
169 }
170
171 if self.data.is_empty() && self.comment.is_none() {
173 writeln!(output, "data:").unwrap();
174 }
175
176 output.push('\n');
178
179 output
180 }
181
182 pub fn to_bytes(&self) -> Bytes {
184 Bytes::from(self.to_sse_string())
185 }
186}
187
188#[derive(Debug, Clone)]
203pub struct KeepAlive {
204 interval: Duration,
206 text: String,
208}
209
210impl Default for KeepAlive {
211 fn default() -> Self {
212 Self {
213 interval: Duration::from_secs(15),
214 text: "keep-alive".to_string(),
215 }
216 }
217}
218
219impl KeepAlive {
220 pub fn new() -> Self {
222 Self::default()
223 }
224
225 pub fn interval(mut self, interval: Duration) -> Self {
227 self.interval = interval;
228 self
229 }
230
231 pub fn text(mut self, text: impl Into<String>) -> Self {
233 self.text = text.into();
234 self
235 }
236
237 pub fn get_interval(&self) -> Duration {
239 self.interval
240 }
241
242 pub fn event(&self) -> SseEvent {
244 SseEvent::comment(&self.text)
245 }
246}
247
248pub struct Sse<S> {
269 stream: S,
270 keep_alive: Option<KeepAlive>,
271}
272
273impl<S> Sse<S> {
274 pub fn new(stream: S) -> Self {
276 Self {
277 stream,
278 keep_alive: None,
279 }
280 }
281
282 pub fn keep_alive(mut self, config: KeepAlive) -> Self {
296 self.keep_alive = Some(config);
297 self
298 }
299
300 pub fn get_keep_alive(&self) -> Option<&KeepAlive> {
302 self.keep_alive.as_ref()
303 }
304}
305
306pin_project! {
308 pub struct SseStream<S> {
310 #[pin]
311 inner: S,
312 keep_alive: Option<KeepAlive>,
313 #[pin]
314 keep_alive_timer: Option<tokio::time::Interval>,
315 }
316}
317
318impl<S, E> Stream for SseStream<S>
319where
320 S: Stream<Item = Result<SseEvent, E>>,
321{
322 type Item = Result<Bytes, E>;
323
324 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
325 let this = self.project();
326
327 match this.inner.poll_next(cx) {
329 Poll::Ready(Some(Ok(event))) => {
330 return Poll::Ready(Some(Ok(event.to_bytes())));
331 }
332 Poll::Ready(Some(Err(e))) => {
333 return Poll::Ready(Some(Err(e)));
334 }
335 Poll::Ready(None) => {
336 return Poll::Ready(None);
337 }
338 Poll::Pending => {}
339 }
340
341 if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() {
343 if timer.poll_tick(cx).is_ready() {
344 if let Some(keep_alive) = this.keep_alive {
345 let event = keep_alive.event();
346 return Poll::Ready(Some(Ok(event.to_bytes())));
347 }
348 }
349 }
350
351 Poll::Pending
352 }
353}
354
355impl<S, E> IntoResponse for Sse<S>
359where
360 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
361 E: std::error::Error + Send + Sync + 'static,
362{
363 fn into_response(self) -> Response {
364 let timer = self.keep_alive.as_ref().map(|k| {
365 let mut interval = tokio::time::interval(k.interval);
366 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
367 interval
368 });
369
370 let stream = SseStream {
371 inner: self.stream,
372 keep_alive: self.keep_alive,
373 keep_alive_timer: timer,
374 };
375
376 use futures_util::StreamExt;
377 let stream =
378 stream.map(|res| res.map_err(|e| crate::error::ApiError::internal(e.to_string())));
379 let body = crate::response::Body::from_stream(stream);
380
381 http::Response::builder()
382 .status(StatusCode::OK)
383 .header(header::CONTENT_TYPE, "text/event-stream")
384 .header(header::CACHE_CONTROL, "no-cache")
385 .header(header::CONNECTION, "keep-alive")
386 .header("X-Accel-Buffering", "no") .body(body)
388 .unwrap()
389 }
390}
391
392impl<S> ResponseModifier for Sse<S> {
394 fn update_response(op: &mut Operation) {
395 let mut content = std::collections::HashMap::new();
396 content.insert(
397 "text/event-stream".to_string(),
398 MediaType {
399 schema: SchemaRef::Inline(serde_json::json!({
400 "type": "string",
401 "description": "Server-Sent Events stream. Events follow the SSE format: 'event: <type>\\ndata: <json>\\n\\n'",
402 "example": "event: message\ndata: {\"id\": 1, \"text\": \"Hello\"}\n\n"
403 })),
404 },
405 );
406
407 let response = ResponseSpec {
408 description: "Server-Sent Events stream for real-time updates".to_string(),
409 content: Some(content),
410 };
411 op.responses.insert("200".to_string(), response);
412 }
413}
414
415pub async fn collect_sse_events<S, E>(stream: S) -> Result<Bytes, E>
419where
420 S: Stream<Item = Result<SseEvent, E>> + Send,
421{
422 use futures_util::StreamExt;
423
424 let mut buffer = Vec::new();
425 futures_util::pin_mut!(stream);
426
427 while let Some(result) = stream.next().await {
428 let event = result?;
429 buffer.extend_from_slice(&event.to_bytes());
430 }
431
432 Ok(Bytes::from(buffer))
433}
434
435pub fn sse_response<I>(events: I) -> Response
452where
453 I: IntoIterator<Item = SseEvent>,
454{
455 let mut buffer = String::new();
456 for event in events {
457 buffer.push_str(&event.to_sse_string());
458 }
459
460 http::Response::builder()
461 .status(StatusCode::OK)
462 .header(header::CONTENT_TYPE, "text/event-stream")
463 .header(header::CACHE_CONTROL, "no-cache")
464 .header(header::CONNECTION, "keep-alive")
465 .header("X-Accel-Buffering", "no")
466 .body(crate::response::Body::from(buffer))
467 .unwrap()
468}
469
470pub fn sse_from_iter<I, E>(
474 events: I,
475) -> Sse<futures_util::stream::Iter<std::vec::IntoIter<Result<SseEvent, E>>>>
476where
477 I: IntoIterator<Item = Result<SseEvent, E>>,
478{
479 use futures_util::stream;
480 let vec: Vec<_> = events.into_iter().collect();
481 Sse::new(stream::iter(vec))
482}
483
// Tests for SSE event serialization (`SseEvent::to_sse_string`) and for the
// headers set by `Sse::into_response`.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // A plain data event serializes to a single `data:` line plus the blank
    // terminator line.
    #[test]
    fn test_sse_event_basic() {
        let event = SseEvent::new("Hello, World!");
        let output = event.to_sse_string();
        assert_eq!(output, "data: Hello, World!\n\n");
    }

    // Setting an event type adds an `event:` line.
    #[test]
    fn test_sse_event_with_event_type() {
        let event = SseEvent::new("Hello").event("greeting");
        let output = event.to_sse_string();
        assert!(output.contains("event: greeting\n"));
        assert!(output.contains("data: Hello\n"));
    }

    // Setting an ID adds an `id:` line.
    #[test]
    fn test_sse_event_with_id() {
        let event = SseEvent::new("Hello").id("123");
        let output = event.to_sse_string();
        assert!(output.contains("id: 123\n"));
        assert!(output.contains("data: Hello\n"));
    }

    // Setting a retry value adds a `retry:` line.
    #[test]
    fn test_sse_event_with_retry() {
        let event = SseEvent::new("Hello").retry(5000);
        let output = event.to_sse_string();
        assert!(output.contains("retry: 5000\n"));
        assert!(output.contains("data: Hello\n"));
    }

    // Each line of multi-line data gets its own `data:` prefix.
    #[test]
    fn test_sse_event_multiline_data() {
        let event = SseEvent::new("Line 1\nLine 2\nLine 3");
        let output = event.to_sse_string();
        assert!(output.contains("data: Line 1\n"));
        assert!(output.contains("data: Line 2\n"));
        assert!(output.contains("data: Line 3\n"));
    }

    // All optional fields together; the event must still end with the blank
    // terminator line.
    #[test]
    fn test_sse_event_full() {
        let event = SseEvent::new("Hello").event("message").id("1").retry(3000);
        let output = event.to_sse_string();

        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 1\n"));
        assert!(output.contains("retry: 3000\n"));
        assert!(output.contains("data: Hello\n"));

        assert!(output.ends_with("\n\n"));
    }

    // The streaming response carries status 200 and the SSE headers. No
    // keep-alive is configured here, so `into_response` does not create a timer.
    #[test]
    fn test_sse_response_headers() {
        use futures_util::stream;

        let events: Vec<Result<SseEvent, std::convert::Infallible>> =
            vec![Ok(SseEvent::new("test"))];
        let sse = Sse::new(stream::iter(events));
        let response = sse.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(
            response.headers().get(header::CONNECTION).unwrap(),
            "keep-alive"
        );
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: for random single-line data and optional event/id/retry,
        // the serialized event contains every field with its prefix and ends
        // with a blank line, and the response carries the SSE headers.
        #[test]
        fn prop_sse_response_format(
            data in "[a-zA-Z0-9 ]{1,50}",
            event_type in proptest::option::of("[a-zA-Z][a-zA-Z0-9_]{0,20}"),
            event_id in proptest::option::of("[a-zA-Z0-9]{1,10}"),
            retry_time in proptest::option::of(1000u64..60000u64),
        ) {
            use futures_util::stream;

            // Build the event from whichever optional fields were generated.
            let mut event = SseEvent::new(data.clone());
            if let Some(ref et) = event_type {
                event = event.event(et.clone());
            }
            if let Some(ref id) = event_id {
                event = event.id(id.clone());
            }
            if let Some(retry) = retry_time {
                event = event.retry(retry);
            }

            let sse_string = event.to_sse_string();

            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline, got: {:?}",
                sse_string
            );

            prop_assert!(
                sse_string.contains(&format!("data: {}", data)),
                "SSE event must contain data field with 'data: ' prefix"
            );

            if let Some(ref et) = event_type {
                prop_assert!(
                    sse_string.contains(&format!("event: {}", et)),
                    "SSE event must contain event type with 'event: ' prefix"
                );
            }

            if let Some(ref id) = event_id {
                prop_assert!(
                    sse_string.contains(&format!("id: {}", id)),
                    "SSE event must contain ID with 'id: ' prefix"
                );
            }

            if let Some(retry) = retry_time {
                prop_assert!(
                    sse_string.contains(&format!("retry: {}", retry)),
                    "SSE event must contain retry with 'retry: ' prefix"
                );
            }

            // The same event, wrapped in a streaming response, must carry the
            // SSE headers.
            let events: Vec<Result<SseEvent, std::convert::Infallible>> = vec![Ok(event)];
            let sse = Sse::new(stream::iter(events));
            let response = sse.into_response();

            prop_assert_eq!(
                response.headers().get(header::CONTENT_TYPE).map(|v| v.to_str().unwrap()),
                Some("text/event-stream"),
                "SSE response must have Content-Type: text/event-stream"
            );

            prop_assert_eq!(
                response.headers().get(header::CACHE_CONTROL).map(|v| v.to_str().unwrap()),
                Some("no-cache"),
                "SSE response must have Cache-Control: no-cache"
            );

            prop_assert_eq!(
                response.headers().get(header::CONNECTION).map(|v| v.to_str().unwrap()),
                Some("keep-alive"),
                "SSE response must have Connection: keep-alive"
            );
        }

        // Property: joining 1-4 non-empty lines with `\n` yields one prefixed
        // `data:` line per input line, and the event ends with a blank line.
        #[test]
        fn prop_sse_multiline_data_format(
            lines in proptest::collection::vec("[a-zA-Z0-9 ]{1,30}", 1..5),
        ) {
            let data = lines.join("\n");
            let event = SseEvent::new(data.clone());
            let sse_string = event.to_sse_string();

            for line in lines.iter() {
                prop_assert!(
                    sse_string.contains(&format!("data: {}", line)),
                    "Each line of multiline data must be prefixed with 'data: '"
                );
            }

            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline"
            );
        }
    }
}