1use bytes::Bytes;
49use futures_util::Stream;
50use http::{header, StatusCode};
51
52use pin_project_lite::pin_project;
53use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef};
54use std::collections::BTreeMap;
55use std::fmt::Write;
56use std::pin::Pin;
57use std::task::{Context, Poll};
58use std::time::Duration;
59
60use crate::response::{IntoResponse, Response};
61
/// A single Server-Sent Event.
///
/// Build one with [`SseEvent::new`], [`SseEvent::json_data`] or [`SseEvent::comment`], then
/// serialize it with [`SseEvent::to_sse_string`] / [`SseEvent::to_bytes`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Event payload; serialized as one `data:` line per line of text.
    pub data: String,
    /// Optional event type, serialized as the `event:` field.
    pub event: Option<String>,
    /// Optional event ID, serialized as the `id:` field.
    pub id: Option<String>,
    /// Optional reconnection delay, serialized as the `retry:` field
    /// (milliseconds, per the SSE specification).
    pub retry: Option<u64>,
    // When set, the event is serialized as a comment and every other field is ignored.
    comment: Option<String>,
}
84
85impl SseEvent {
86 pub fn new(data: impl Into<String>) -> Self {
88 Self {
89 data: data.into(),
90 event: None,
91 id: None,
92 retry: None,
93 comment: None,
94 }
95 }
96
97 pub fn comment(text: impl Into<String>) -> Self {
101 Self {
102 data: String::new(),
103 event: None,
104 id: None,
105 retry: None,
106 comment: Some(text.into()),
107 }
108 }
109
110 pub fn event(mut self, event: impl Into<String>) -> Self {
112 self.event = Some(event.into());
113 self
114 }
115
116 pub fn id(mut self, id: impl Into<String>) -> Self {
118 self.id = Some(id.into());
119 self
120 }
121
122 pub fn retry(mut self, retry: u64) -> Self {
124 self.retry = Some(retry);
125 self
126 }
127
128 pub fn json_data<T: serde::Serialize>(data: &T) -> Result<Self, serde_json::Error> {
130 Ok(Self::new(serde_json::to_string(data)?))
131 }
132
133 pub fn to_sse_string(&self) -> String {
143 let mut output = String::new();
144
145 if let Some(ref comment) = self.comment {
147 writeln!(output, ": {}", comment).unwrap();
148 output.push('\n');
149 return output;
150 }
151
152 if let Some(ref event) = self.event {
154 writeln!(output, "event: {}", event).unwrap();
155 }
156
157 if let Some(ref id) = self.id {
159 writeln!(output, "id: {}", id).unwrap();
160 }
161
162 if let Some(retry) = self.retry {
164 writeln!(output, "retry: {}", retry).unwrap();
165 }
166
167 for line in self.data.lines() {
169 writeln!(output, "data: {}", line).unwrap();
170 }
171
172 if self.data.is_empty() && self.comment.is_none() {
174 writeln!(output, "data:").unwrap();
175 }
176
177 output.push('\n');
179
180 output
181 }
182
183 pub fn to_bytes(&self) -> Bytes {
185 Bytes::from(self.to_sse_string())
186 }
187}
188
/// Keep-alive settings for an [`Sse`] response: the timer period and the comment text sent
/// when it elapses.
///
/// The `Default` implementation sends a `keep-alive` comment with a 15-second period.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeepAlive {
    // Period of the keep-alive timer.
    interval: Duration,
    // Text of the keep-alive comment (serialized as `: <text>`).
    text: String,
}
210
211impl Default for KeepAlive {
212 fn default() -> Self {
213 Self {
214 interval: Duration::from_secs(15),
215 text: "keep-alive".to_string(),
216 }
217 }
218}
219
220impl KeepAlive {
221 pub fn new() -> Self {
223 Self::default()
224 }
225
226 pub fn interval(mut self, interval: Duration) -> Self {
228 self.interval = interval;
229 self
230 }
231
232 pub fn text(mut self, text: impl Into<String>) -> Self {
234 self.text = text.into();
235 self
236 }
237
238 pub fn get_interval(&self) -> Duration {
240 self.interval
241 }
242
243 pub fn event(&self) -> SseEvent {
245 SseEvent::comment(&self.text)
246 }
247}
248
/// A Server-Sent Events response: an event stream plus optional keep-alive settings.
///
/// Create it with [`Sse::new`] and optionally configure [`Sse::keep_alive`]. It becomes an
/// HTTP response through its `IntoResponse` implementation, which requires `S` to yield
/// `Result<SseEvent, E>`.
pub struct Sse<S> {
    // Source of events; the item type is constrained by the `IntoResponse` impl, not here.
    stream: S,
    // Keep-alive settings; `None` means no keep-alive timer is created.
    keep_alive: Option<KeepAlive>,
}
273
274impl<S> Sse<S> {
275 pub fn new(stream: S) -> Self {
277 Self {
278 stream,
279 keep_alive: None,
280 }
281 }
282
283 pub fn keep_alive(mut self, config: KeepAlive) -> Self {
297 self.keep_alive = Some(config);
298 self
299 }
300
301 pub fn get_keep_alive(&self) -> Option<&KeepAlive> {
303 self.keep_alive.as_ref()
304 }
305}
306
pin_project! {
    /// Stream adapter that serializes [`SseEvent`]s to bytes and interleaves keep-alive
    /// comments produced by a timer.
    ///
    /// Its fields are private, so it is only built inside this module (by
    /// `Sse::into_response`).
    pub struct SseStream<S> {
        // The wrapped event stream; pinned because it is polled in place.
        #[pin]
        inner: S,
        // Keep-alive settings used to build each keep-alive comment.
        keep_alive: Option<KeepAlive>,
        // Timer whose ticks trigger keep-alive comments; `None` disables them.
        #[pin]
        keep_alive_timer: Option<tokio::time::Interval>,
    }
}
318
319impl<S, E> Stream for SseStream<S>
320where
321 S: Stream<Item = Result<SseEvent, E>>,
322{
323 type Item = Result<Bytes, E>;
324
325 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326 let this = self.project();
327
328 match this.inner.poll_next(cx) {
330 Poll::Ready(Some(Ok(event))) => {
331 return Poll::Ready(Some(Ok(event.to_bytes())));
332 }
333 Poll::Ready(Some(Err(e))) => {
334 return Poll::Ready(Some(Err(e)));
335 }
336 Poll::Ready(None) => {
337 return Poll::Ready(None);
338 }
339 Poll::Pending => {}
340 }
341
342 if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() {
344 if timer.poll_tick(cx).is_ready() {
345 if let Some(keep_alive) = this.keep_alive {
346 let event = keep_alive.event();
347 return Poll::Ready(Some(Ok(event.to_bytes())));
348 }
349 }
350 }
351
352 Poll::Pending
353 }
354}
355
356impl<S, E> IntoResponse for Sse<S>
360where
361 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
362 E: std::error::Error + Send + Sync + 'static,
363{
364 fn into_response(self) -> Response {
365 let timer = self.keep_alive.as_ref().map(|k| {
366 let mut interval = tokio::time::interval(k.interval);
367 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
368 interval
369 });
370
371 let stream = SseStream {
372 inner: self.stream,
373 keep_alive: self.keep_alive,
374 keep_alive_timer: timer,
375 };
376
377 use futures_util::StreamExt;
378 let stream =
379 stream.map(|res| res.map_err(|e| crate::error::ApiError::internal(e.to_string())));
380 let body = crate::response::Body::from_stream(stream);
381
382 http::Response::builder()
383 .status(StatusCode::OK)
384 .header(header::CONTENT_TYPE, "text/event-stream")
385 .header(header::CACHE_CONTROL, "no-cache")
386 .header(header::CONNECTION, "keep-alive")
387 .header("X-Accel-Buffering", "no") .body(body)
389 .unwrap()
390 }
391}
392
393impl<S> ResponseModifier for Sse<S> {
395 fn update_response(op: &mut Operation) {
396 let mut content = BTreeMap::new();
397 content.insert(
398 "text/event-stream".to_string(),
399 MediaType {
400 schema: Some(SchemaRef::Inline(serde_json::json!({
401 "type": "string",
402 "description": "Server-Sent Events stream. Events follow the SSE format: 'event: <type>\\ndata: <json>\\n\\n'",
403 }))),
404 example: Some(serde_json::json!("event: message\ndata: {\"id\": 1, \"text\": \"Hello\"}\n\n")),
405 },
406 );
407
408 let response = ResponseSpec {
409 description: "Server-Sent Events stream for real-time updates".to_string(),
410 content,
411 headers: BTreeMap::new(),
412 };
413 op.responses.insert("200".to_string(), response);
414 }
415}
416
417pub async fn collect_sse_events<S, E>(stream: S) -> Result<Bytes, E>
421where
422 S: Stream<Item = Result<SseEvent, E>> + Send,
423{
424 use futures_util::StreamExt;
425
426 let mut buffer = Vec::new();
427 futures_util::pin_mut!(stream);
428
429 while let Some(result) = stream.next().await {
430 let event = result?;
431 buffer.extend_from_slice(&event.to_bytes());
432 }
433
434 Ok(Bytes::from(buffer))
435}
436
437pub fn sse_response<I>(events: I) -> Response
454where
455 I: IntoIterator<Item = SseEvent>,
456{
457 let mut buffer = String::new();
458 for event in events {
459 buffer.push_str(&event.to_sse_string());
460 }
461
462 http::Response::builder()
463 .status(StatusCode::OK)
464 .header(header::CONTENT_TYPE, "text/event-stream")
465 .header(header::CACHE_CONTROL, "no-cache")
466 .header(header::CONNECTION, "keep-alive")
467 .header("X-Accel-Buffering", "no")
468 .body(crate::response::Body::from(buffer))
469 .unwrap()
470}
471
472pub fn sse_from_iter<I, E>(
476 events: I,
477) -> Sse<futures_util::stream::Iter<std::vec::IntoIter<Result<SseEvent, E>>>>
478where
479 I: IntoIterator<Item = Result<SseEvent, E>>,
480{
481 use futures_util::stream;
482 let vec: Vec<_> = events.into_iter().collect();
483 Sse::new(stream::iter(vec))
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_sse_event_basic() {
        let event = SseEvent::new("Hello, World!");
        let output = event.to_sse_string();
        assert_eq!(output, "data: Hello, World!\n\n");
    }

    #[test]
    fn test_sse_event_with_event_type() {
        let event = SseEvent::new("Hello").event("greeting");
        let output = event.to_sse_string();
        assert!(output.contains("event: greeting\n"));
        assert!(output.contains("data: Hello\n"));
    }

    #[test]
    fn test_sse_event_with_id() {
        let event = SseEvent::new("Hello").id("123");
        let output = event.to_sse_string();
        assert!(output.contains("id: 123\n"));
        assert!(output.contains("data: Hello\n"));
    }

    #[test]
    fn test_sse_event_with_retry() {
        let event = SseEvent::new("Hello").retry(5000);
        let output = event.to_sse_string();
        assert!(output.contains("retry: 5000\n"));
        assert!(output.contains("data: Hello\n"));
    }

    #[test]
    fn test_sse_event_multiline_data() {
        let event = SseEvent::new("Line 1\nLine 2\nLine 3");
        let output = event.to_sse_string();
        assert!(output.contains("data: Line 1\n"));
        assert!(output.contains("data: Line 2\n"));
        assert!(output.contains("data: Line 3\n"));
    }

    #[test]
    fn test_sse_event_full() {
        let event = SseEvent::new("Hello").event("message").id("1").retry(3000);
        let output = event.to_sse_string();

        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 1\n"));
        assert!(output.contains("retry: 3000\n"));
        assert!(output.contains("data: Hello\n"));

        // Every event is terminated by a blank line.
        assert!(output.ends_with("\n\n"));
    }

    #[test]
    fn test_sse_event_empty_data() {
        // Empty data still produces a `data:` line so the client dispatches the event.
        assert_eq!(SseEvent::new("").to_sse_string(), "data:\n\n");
    }

    #[test]
    fn test_sse_comment_event() {
        // Comment events serialize as `: <text>` followed by the terminating blank line.
        assert_eq!(SseEvent::comment("ping").to_sse_string(), ": ping\n\n");
    }

    #[test]
    fn test_sse_json_data() {
        let event = SseEvent::json_data(&serde_json::json!({ "id": 1 })).unwrap();
        assert_eq!(event.data, "{\"id\":1}");
    }

    #[test]
    fn test_keep_alive_defaults_and_builders() {
        let default = KeepAlive::new();
        assert_eq!(default.get_interval(), Duration::from_secs(15));
        assert_eq!(default.event().to_sse_string(), ": keep-alive\n\n");

        let custom = KeepAlive::new()
            .interval(Duration::from_secs(5))
            .text("ping");
        assert_eq!(custom.get_interval(), Duration::from_secs(5));
        assert_eq!(custom.event().to_sse_string(), ": ping\n\n");
    }

    #[test]
    fn test_sse_keep_alive_configuration() {
        use futures_util::stream;

        let events: Vec<Result<SseEvent, std::convert::Infallible>> = Vec::new();
        let sse = Sse::new(stream::iter(events));
        assert!(sse.get_keep_alive().is_none());

        let sse = sse.keep_alive(KeepAlive::new().interval(Duration::from_secs(30)));
        assert_eq!(
            sse.get_keep_alive().map(KeepAlive::get_interval),
            Some(Duration::from_secs(30))
        );
    }

    #[test]
    fn test_sse_response_headers() {
        use futures_util::stream;

        let events: Vec<Result<SseEvent, std::convert::Infallible>> =
            vec![Ok(SseEvent::new("test"))];
        let sse = Sse::new(stream::iter(events));
        let response = sse.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(
            response.headers().get(header::CONNECTION).unwrap(),
            "keep-alive"
        );
        assert_eq!(response.headers().get("X-Accel-Buffering").unwrap(), "no");
    }

    #[test]
    fn test_sse_response_helper_headers() {
        let response = sse_response(vec![SseEvent::new("a"), SseEvent::new("b")]);

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(response.headers().get("X-Accel-Buffering").unwrap(), "no");
    }

    #[test]
    fn test_sse_from_iter_has_no_keep_alive() {
        let sse = sse_from_iter(vec![Ok::<_, std::convert::Infallible>(SseEvent::new(
            "x",
        ))]);
        assert!(sse.get_keep_alive().is_none());
        assert_eq!(sse.into_response().status(), StatusCode::OK);
    }

    proptest! {
        // Property tests for the SSE wire format and response headers.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_sse_response_format(
            // Single-line payload without characters that need escaping.
            data in "[a-zA-Z0-9 ]{1,50}",
            // Optional event type.
            event_type in proptest::option::of("[a-zA-Z][a-zA-Z0-9_]{0,20}"),
            // Optional event ID.
            event_id in proptest::option::of("[a-zA-Z0-9]{1,10}"),
            // Optional retry delay.
            retry_time in proptest::option::of(1000u64..60000u64),
        ) {
            use futures_util::stream;

            let mut event = SseEvent::new(data.clone());
            if let Some(ref et) = event_type {
                event = event.event(et.clone());
            }
            if let Some(ref id) = event_id {
                event = event.id(id.clone());
            }
            if let Some(retry) = retry_time {
                event = event.retry(retry);
            }

            let sse_string = event.to_sse_string();

            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline, got: {:?}",
                sse_string
            );

            prop_assert!(
                sse_string.contains(&format!("data: {}", data)),
                "SSE event must contain data field with 'data: ' prefix"
            );

            if let Some(ref et) = event_type {
                prop_assert!(
                    sse_string.contains(&format!("event: {}", et)),
                    "SSE event must contain event type with 'event: ' prefix"
                );
            }

            if let Some(ref id) = event_id {
                prop_assert!(
                    sse_string.contains(&format!("id: {}", id)),
                    "SSE event must contain ID with 'id: ' prefix"
                );
            }

            if let Some(retry) = retry_time {
                prop_assert!(
                    sse_string.contains(&format!("retry: {}", retry)),
                    "SSE event must contain retry with 'retry: ' prefix"
                );
            }

            // The streaming response must carry the SSE headers.
            let events: Vec<Result<SseEvent, std::convert::Infallible>> = vec![Ok(event)];
            let sse = Sse::new(stream::iter(events));
            let response = sse.into_response();

            prop_assert_eq!(
                response.headers().get(header::CONTENT_TYPE).map(|v| v.to_str().unwrap()),
                Some("text/event-stream"),
                "SSE response must have Content-Type: text/event-stream"
            );

            prop_assert_eq!(
                response.headers().get(header::CACHE_CONTROL).map(|v| v.to_str().unwrap()),
                Some("no-cache"),
                "SSE response must have Cache-Control: no-cache"
            );

            prop_assert_eq!(
                response.headers().get(header::CONNECTION).map(|v| v.to_str().unwrap()),
                Some("keep-alive"),
                "SSE response must have Connection: keep-alive"
            );
        }

        #[test]
        fn prop_sse_multiline_data_format(
            // Between one and four non-empty lines.
            lines in proptest::collection::vec("[a-zA-Z0-9 ]{1,30}", 1..5),
        ) {
            let data = lines.join("\n");
            let event = SseEvent::new(data.clone());
            let sse_string = event.to_sse_string();

            for line in lines.iter() {
                prop_assert!(
                    sse_string.contains(&format!("data: {}", line)),
                    "Each line of multiline data must be prefixed with 'data: '"
                );
            }

            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline"
            );
        }
    }
}