1use std::future::Future;
9use std::marker::PhantomData;
10use std::pin::Pin;
11use std::sync::{Arc, LazyLock};
12use std::task::{Context, Poll};
13use std::time::Duration;
14
15use bytes::Bytes;
16use futures_util::stream::{BoxStream, StreamExt};
17use http::header::{HeaderName, HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
18use http_body::{Body, Frame, SizeHint};
19use serde::Serialize;
20use tokio::sync::{OwnedSemaphorePermit, Semaphore};
21use tokio::time::{interval_at, sleep, Instant, Interval, Sleep};
22
23use crate::body::{BoxError, RespBody};
24use crate::constants::TEXT_EVENT_STREAM;
25use crate::error::{Error, Result};
26use crate::extract::RequestContext;
27use crate::response::{IntoResponse, Response};
28
29const DEFAULT_HEARTBEAT: Duration = Duration::from_secs(15);
31const DEFAULT_MAX_EVENT_SIZE: usize = 256 * 1024;
35const HEARTBEAT_FRAME: &[u8] = b": ping\n\n";
37static HEARTBEAT_BYTES: LazyLock<Bytes> = LazyLock::new(|| Bytes::from_static(HEARTBEAT_FRAME));
40const X_ACCEL_BUFFERING: &str = "x-accel-buffering";
42
/// One Server-Sent Event being assembled by the caller.
///
/// The payload is either a typed value (`data`) or a pre-rendered string (`raw`);
/// all remaining fields are optional and are written only when set.
pub struct SseEvent<T> {
    /// Typed payload; serialized to JSON on conversion and preferred over `raw`.
    data: Option<T>,
    /// Pre-rendered payload, used only when `data` is absent.
    raw: Option<String>,
    /// Value of the `event:` field for this event.
    event: Option<String>,
    /// Value of the `id:` field for this event.
    id: Option<String>,
    /// Value of the `retry:` field, in milliseconds.
    retry: Option<u64>,
    /// Comment text written ahead of the other fields.
    comment: Option<String>,
}
56
57impl<T> SseEvent<T> {
58 pub fn new(data: T) -> Self {
60 Self {
61 data: Some(data),
62 raw: None,
63 event: None,
64 id: None,
65 retry: None,
66 comment: None,
67 }
68 }
69
70 pub fn raw(raw: impl Into<String>) -> Self {
72 Self {
73 data: None,
74 raw: Some(raw.into()),
75 event: None,
76 id: None,
77 retry: None,
78 comment: None,
79 }
80 }
81
82 pub fn event(mut self, name: impl Into<String>) -> Self {
84 self.event = Some(name.into());
85 self
86 }
87
88 pub fn id(mut self, id: impl ToString) -> Self {
90 self.id = Some(id.to_string());
91 self
92 }
93
94 pub fn retry_ms(mut self, ms: u64) -> Self {
96 self.retry = Some(ms);
97 self
98 }
99
100 pub fn comment(mut self, text: impl Into<String>) -> Self {
102 self.comment = Some(text.into());
103 self
104 }
105}
106
107impl<T: Serialize> SseEvent<T> {
108 fn into_raw(self) -> Result<RawEvent> {
110 let data = match (self.data, self.raw) {
111 (Some(data), _) => Some(serde_json::to_string(&data).map_err(|error| {
112 Error::internal(format!("failed to serialize SSE data: {error}"))
113 })?),
114 (None, Some(raw)) => Some(raw),
115 (None, None) => None,
116 };
117 Ok(RawEvent {
118 data,
119 event: self.event,
120 id: self.id,
121 retry: self.retry,
122 comment: self.comment,
123 })
124 }
125}
126
/// An event whose payload has already been reduced to a string, ready for encoding.
struct RawEvent {
    /// Payload text; written as one `data:` line per line of text.
    data: Option<String>,
    /// Explicit event name; the stream-wide default is used when absent.
    event: Option<String>,
    /// Event id.
    id: Option<String>,
    /// Retry value in milliseconds.
    retry: Option<u64>,
    /// Comment text; written as one `: ` line per line of text.
    comment: Option<String>,
}
135
136fn encode_event(event: &RawEvent, default_event: Option<&str>) -> Bytes {
141 let mut out = String::new();
142
143 if let Some(comment) = &event.comment {
144 for line in comment.split('\n') {
145 out.push_str(": ");
146 out.push_str(line);
147 out.push('\n');
148 }
149 }
150 if let Some(name) = event.event.as_deref().or(default_event) {
151 out.push_str("event: ");
152 push_single_line(&mut out, name, false);
155 out.push('\n');
156 }
157 if let Some(id) = &event.id {
158 out.push_str("id: ");
159 push_single_line(&mut out, id, true);
161 out.push('\n');
162 }
163 if let Some(retry) = event.retry {
164 out.push_str("retry: ");
165 out.push_str(&retry.to_string());
166 out.push('\n');
167 }
168 if let Some(data) = &event.data {
169 for line in data.split('\n') {
170 out.push_str("data: ");
171 out.push_str(line);
172 out.push('\n');
173 }
174 }
175 out.push('\n');
176
177 Bytes::from(out)
178}
179
/// Appends `value` to `out` with every `\r` and `\n` removed, so the result stays
/// on one line. When `strip_nul` is set, NUL characters are removed as well.
fn push_single_line(out: &mut String, value: &str, strip_nul: bool) {
    let is_dropped = |ch: char| matches!(ch, '\r' | '\n') || (strip_nul && ch == '\0');
    out.extend(value.chars().filter(|&ch| !is_dropped(ch)));
}
191
/// Per-response settings collected by the [`Sse`] builder methods.
struct SseConfig {
    /// Event name used for events that do not set their own.
    default_event: Option<String>,
    /// Period between heartbeat frames; `None` turns heartbeats off.
    heartbeat: Option<Duration>,
    /// Whether to send `Cache-Control: no-cache`.
    no_cache: bool,
    /// Whether to send `x-accel-buffering: no`.
    disable_proxy_buffering: bool,
    /// Largest encoded event, in bytes, that is still sent; larger ones are skipped.
    max_event_size: Option<usize>,
    /// Time after which the body finishes on its own.
    client_timeout: Option<Duration>,
    /// Payload of the final event emitted when the body finishes.
    done_event: Option<String>,
}
202
203impl Default for SseConfig {
204 fn default() -> Self {
205 Self {
206 default_event: None,
207 heartbeat: Some(DEFAULT_HEARTBEAT),
208 no_cache: true,
209 disable_proxy_buffering: true,
210 max_event_size: Some(DEFAULT_MAX_EVENT_SIZE),
211 client_timeout: None,
212 done_event: None,
213 }
214 }
215}
216
217pub struct Sse<T> {
219 events: BoxStream<'static, Result<RawEvent>>,
220 config: SseConfig,
221 _marker: PhantomData<fn() -> T>,
222}
223
224impl<T: Serialize + Send + 'static> Sse<T> {
225 pub fn new<S>(stream: S) -> Self
230 where
231 S: futures_core::Stream<Item = Result<T>> + Send + 'static,
232 {
233 let events = stream
234 .map(|item| item.and_then(|value| SseEvent::new(value).into_raw()))
235 .boxed();
236 Self::from_events(events)
237 }
238
239 pub fn events<S>(stream: S) -> Self
241 where
242 S: futures_core::Stream<Item = Result<SseEvent<T>>> + Send + 'static,
243 {
244 let events = stream.map(|item| item.and_then(SseEvent::into_raw)).boxed();
245 Self::from_events(events)
246 }
247
248 fn from_events(events: BoxStream<'static, Result<RawEvent>>) -> Self {
250 Self {
251 events,
252 config: SseConfig::default(),
253 _marker: PhantomData,
254 }
255 }
256
257 pub fn event(mut self, default: impl Into<String>) -> Self {
259 self.config.default_event = Some(default.into());
260 self
261 }
262
263 pub fn heartbeat(mut self, every: Duration) -> Self {
265 self.config.heartbeat = Some(every);
266 self
267 }
268
269 pub fn no_heartbeat(mut self) -> Self {
271 self.config.heartbeat = None;
272 self
273 }
274
275 pub fn no_cache(mut self, on: bool) -> Self {
277 self.config.no_cache = on;
278 self
279 }
280
281 pub fn disable_proxy_buffering(mut self, on: bool) -> Self {
283 self.config.disable_proxy_buffering = on;
284 self
285 }
286
287 pub fn max_event_size(mut self, bytes: usize) -> Self {
289 self.config.max_event_size = Some(bytes);
290 self
291 }
292
293 pub fn client_timeout(mut self, after: Duration) -> Self {
295 self.config.client_timeout = Some(after);
296 self
297 }
298
299 pub fn done_event(mut self, marker: impl Into<String>) -> Self {
301 self.config.done_event = Some(marker.into());
302 self
303 }
304}
305
306#[derive(Clone)]
312pub(crate) struct SseLimiter {
313 semaphore: Arc<Semaphore>,
314}
315
316impl SseLimiter {
317 pub(crate) fn new(limit: usize) -> Self {
318 Self {
319 semaphore: Arc::new(Semaphore::new(limit)),
320 }
321 }
322
323 fn try_acquire(&self) -> Option<OwnedSemaphorePermit> {
325 Arc::clone(&self.semaphore).try_acquire_owned().ok()
326 }
327}
328
329#[doc(hidden)]
336pub fn __sse_into_response<T>(ctx: &RequestContext, sse: Sse<T>) -> Result<Response> {
337 let permit = match ctx.state().get::<SseLimiter>() {
338 Some(limiter) => match limiter.try_acquire() {
339 Some(permit) => Some(permit),
340 None => {
341 return Err(Error::service_unavailable(
342 "the server is at its Server-Sent Events connection limit",
343 ));
344 }
345 },
346 None => None,
347 };
348 Ok(sse.into_response_with_permit(permit))
349}
350
351impl<T> IntoResponse for Sse<T> {
352 fn into_response(self) -> Response {
353 self.into_response_with_permit(None)
354 }
355}
356
357impl<T> Sse<T> {
358 fn into_response_with_permit(self, permit: Option<OwnedSemaphorePermit>) -> Response {
361 let Sse { events, config, .. } = self;
362
363 let heartbeat = config
364 .heartbeat
365 .map(|every| interval_at(Instant::now() + every, every));
366 let timeout = config.client_timeout.map(|after| Box::pin(sleep(after)));
367 let done = config.done_event.map(|marker| {
368 encode_event(
369 &RawEvent {
370 data: Some(marker),
371 event: None,
372 id: None,
373 retry: None,
374 comment: None,
375 },
376 config.default_event.as_deref(),
377 )
378 });
379
380 let body = SseBody {
381 events,
382 default_event: config.default_event,
383 max_event_size: config.max_event_size,
384 heartbeat,
385 timeout,
386 done,
387 finished: false,
388 _permit: permit,
389 };
390
391 let mut response = http::Response::new(RespBody::stream(body));
392 let headers = response.headers_mut();
393 headers.insert(CONTENT_TYPE, HeaderValue::from_static(TEXT_EVENT_STREAM));
394 if config.no_cache {
395 headers.insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
396 }
397 if config.disable_proxy_buffering {
398 headers.insert(
399 HeaderName::from_static(X_ACCEL_BUFFERING),
400 HeaderValue::from_static("no"),
401 );
402 }
403 response
404 }
405}
406
407struct SseBody {
413 events: BoxStream<'static, Result<RawEvent>>,
414 default_event: Option<String>,
415 max_event_size: Option<usize>,
416 heartbeat: Option<Interval>,
417 timeout: Option<Pin<Box<Sleep>>>,
418 done: Option<Bytes>,
419 finished: bool,
420 _permit: Option<OwnedSemaphorePermit>,
422}
423
424impl SseBody {
425 fn finish(&mut self) -> Poll<Option<Result<Frame<Bytes>, BoxError>>> {
427 self.finished = true;
428 Poll::Ready(self.done.take().map(|bytes| Ok(Frame::data(bytes))))
429 }
430}
431
432impl Body for SseBody {
433 type Data = Bytes;
434 type Error = BoxError;
435
436 fn poll_frame(
437 self: Pin<&mut Self>,
438 cx: &mut Context<'_>,
439 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
440 let this = self.get_mut();
441 if this.finished {
442 return Poll::Ready(None);
443 }
444
445 if let Some(timeout) = &mut this.timeout {
447 if timeout.as_mut().poll(cx).is_ready() {
448 return this.finish();
449 }
450 }
451
452 loop {
454 match this.events.poll_next_unpin(cx) {
455 Poll::Ready(Some(Ok(event))) => {
456 let bytes = encode_event(&event, this.default_event.as_deref());
457 if let Some(max) = this.max_event_size {
458 if bytes.len() > max {
459 tracing::warn!(
460 target: "tork",
461 event_bytes = bytes.len(),
462 max_event_size = max,
463 "SSE event exceeds max_event_size, skipping"
464 );
465 continue;
466 }
467 }
468 return Poll::Ready(Some(Ok(Frame::data(bytes))));
469 }
470 Poll::Ready(Some(Err(error))) => {
471 tracing::error!(target: "tork", error = %error, "SSE stream error");
473 return this.finish();
474 }
475 Poll::Ready(None) => return this.finish(),
476 Poll::Pending => break,
477 }
478 }
479
480 if let Some(heartbeat) = &mut this.heartbeat {
482 if heartbeat.poll_tick(cx).is_ready() {
483 return Poll::Ready(Some(Ok(Frame::data(HEARTBEAT_BYTES.clone()))));
484 }
485 }
486
487 Poll::Pending
488 }
489
490 fn is_end_stream(&self) -> bool {
491 self.finished
492 }
493
494 fn size_hint(&self) -> SizeHint {
495 SizeHint::default()
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use http_body_util::BodyExt;
    use serde_json::json;
    use std::time::Duration;

    /// Converts and encodes `event`, returning the wire text.
    fn encode<T: Serialize>(event: SseEvent<T>, default: Option<&str>) -> String {
        let raw = event.into_raw().expect("serialize");
        let wire = encode_event(&raw, default);
        String::from_utf8(wire.to_vec()).unwrap()
    }

    /// Drains the whole response body into a string.
    async fn body_to_string(response: Response) -> String {
        let collected = response.into_body().collect().await.unwrap();
        String::from_utf8(collected.to_bytes().to_vec()).unwrap()
    }

    /// A payload whose serialization always fails.
    #[derive(Debug)]
    struct BadSerialize;

    impl Serialize for BadSerialize {
        fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            Err(serde::ser::Error::custom("nope"))
        }
    }

    #[test]
    fn encodes_event_id_retry_and_data() {
        let event = SseEvent::new(json!({ "id": 1 }))
            .event("item")
            .id(7)
            .retry_ms(5000);
        assert_eq!(
            encode(event, None),
            "event: item\nid: 7\nretry: 5000\ndata: {\"id\":1}\n\n"
        );
    }

    #[test]
    fn encodes_raw_data_with_event() {
        let event = SseEvent::<()>::raw("[DONE]").event("done");
        assert_eq!(encode(event, None), "event: done\ndata: [DONE]\n\n");
    }

    #[test]
    fn falls_back_to_the_default_event_name() {
        let event = SseEvent::new(json!(1));
        assert_eq!(encode(event, Some("tick")), "event: tick\ndata: 1\n\n");
    }

    #[test]
    fn comment_and_multiline_raw_data_split_into_lines() {
        let event = SseEvent::<()>::raw("a\nb").comment("note");
        assert_eq!(encode(event, None), ": note\ndata: a\ndata: b\n\n");
    }

    #[test]
    fn event_name_and_id_cannot_inject_extra_fields() {
        // Line breaks in the name and id must not produce additional fields.
        let hostile = SseEvent::new(json!(1))
            .event("ping\nevent: admin\ndata: spoofed")
            .id("9\r\nid: 0\0");
        let text = encode(hostile, None);

        assert_eq!(text, "event: pingevent: admindata: spoofed\nid: 9id: 0\ndata: 1\n\n");
        assert_eq!(text.matches("\n\n").count(), 1, "exactly one event terminator");

        let lines_starting_with = |prefix: &str| {
            text.lines()
                .filter(|line| line.starts_with(prefix))
                .count()
        };
        assert_eq!(lines_starting_with("event: "), 1);
        assert_eq!(lines_starting_with("id: "), 1);
    }

    #[test]
    fn serialize_error_is_reported_for_typed_sse_events() {
        let Err(error) = SseEvent::new(BadSerialize).into_raw() else {
            panic!("expected serialization to fail");
        };
        assert!(error.message().starts_with("failed to serialize SSE data:"));
    }

    #[tokio::test]
    async fn builder_flags_toggle_headers_and_timeout_defaults() {
        let never = futures_util::stream::pending::<Result<serde_json::Value>>();
        let response = Sse::new(never)
            .event("tick")
            .no_cache(false)
            .disable_proxy_buffering(false)
            .no_heartbeat()
            .client_timeout(Duration::from_millis(20))
            .into_response();

        assert_eq!(response.status(), StatusCode::OK);
        let headers = response.headers();
        assert!(headers.get(CONTENT_TYPE).is_some());
        assert!(headers.get(CACHE_CONTROL).is_none());
        assert!(headers.get(X_ACCEL_BUFFERING).is_none());
    }

    #[tokio::test]
    async fn client_timeout_finishes_without_emitting_a_done_event() {
        let never = futures_util::stream::pending::<Result<serde_json::Value>>();
        let mut body = Sse::new(never)
            .client_timeout(Duration::from_millis(20))
            .into_response()
            .into_body();

        let frame = tokio::time::timeout(Duration::from_secs(1), body.frame())
            .await
            .expect("timeout should trigger");
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn events_builder_handles_prebuilt_events() {
        let prebuilt = vec![
            Ok::<_, Error>(SseEvent::new(json!({ "n": 1 })).event("tick")),
            Ok(SseEvent::raw("[DONE]").comment("final")),
        ];
        let response = Sse::events(futures_util::stream::iter(prebuilt))
            .event("default")
            .done_event("[END]")
            .into_response();

        let body = body_to_string(response).await;
        assert!(
            body.contains("event: tick\ndata: {\"n\":1}\n\n"),
            "body: {body}"
        );
        assert!(body.contains(": final"), "body: {body}");
        assert!(body.contains("data: [DONE]"), "body: {body}");
        assert!(
            body.trim_end().ends_with("data: [END]"),
            "done last: {body}"
        );
    }

    #[tokio::test]
    async fn into_response_sets_headers_and_streams_events() {
        let items = vec![Ok::<_, Error>(json!({ "n": 1 })), Ok(json!({ "n": 2 }))];
        let response = Sse::new(futures_util::stream::iter(items))
            .event("tick")
            .done_event("[DONE]")
            .into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(CONTENT_TYPE).unwrap(),
            TEXT_EVENT_STREAM
        );
        assert_eq!(response.headers().get(CACHE_CONTROL).unwrap(), "no-cache");
        assert_eq!(response.headers().get(X_ACCEL_BUFFERING).unwrap(), "no");

        let body = body_to_string(response).await;
        assert!(
            body.contains("event: tick\ndata: {\"n\":1}\n\n"),
            "body: {body}"
        );
        assert!(
            body.contains("event: tick\ndata: {\"n\":2}\n\n"),
            "body: {body}"
        );
        assert!(
            body.trim_end().ends_with("data: [DONE]"),
            "done last: {body}"
        );
    }

    #[tokio::test]
    async fn oversized_events_are_skipped() {
        let items = vec![
            Ok::<_, Error>(json!("tiny")),
            Ok(json!(
                "a really long value that exceeds the configured maximum size"
            )),
        ];
        let response = Sse::new(futures_util::stream::iter(items))
            .max_event_size(40)
            .into_response();

        let body = body_to_string(response).await;
        assert!(body.contains("data: \"tiny\""), "small kept: {body}");
        assert!(!body.contains("really long"), "large skipped: {body}");
    }

    #[tokio::test]
    async fn heartbeat_fires_while_the_source_is_idle() {
        let never = futures_util::stream::pending::<Result<serde_json::Value>>();
        let mut body = Sse::new(never)
            .heartbeat(Duration::from_millis(20))
            .into_response()
            .into_body();

        let frame = tokio::time::timeout(Duration::from_secs(2), body.frame())
            .await
            .expect("a heartbeat should arrive")
            .unwrap()
            .unwrap();
        assert_eq!(
            frame.into_data().unwrap(),
            Bytes::from_static(HEARTBEAT_FRAME)
        );
    }

    #[test]
    fn sse_limiter_caps_concurrent_permits_and_frees_them_on_drop() {
        let limiter = SseLimiter::new(2);

        let first = limiter.try_acquire().expect("first is under the cap");
        let second = limiter.try_acquire().expect("second reaches the cap");
        assert!(limiter.try_acquire().is_none(), "third is over the cap");

        // Releasing one permit makes exactly one slot available again.
        drop(first);
        let third = limiter.try_acquire().expect("a freed slot is reusable");

        drop(second);
        drop(third);
        assert!(limiter.try_acquire().is_some());
    }
}