1use std::{
6 future::Future,
7 pin::Pin,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use bytes::{BufMut, Bytes, BytesMut};
13use futures::Stream;
14use http::{HeaderValue, header};
15use http_body::Frame;
16use paste::paste;
17use pin_project::pin_project;
18use tokio::time::{Instant, Sleep};
19
20use super::IntoResponse;
21use crate::{body::Body, error::BoxError, response::Response};
22
/// A response wrapper that carries a stream of [`Event`]s together with an optional
/// keep-alive configuration.
pub struct Sse<S> {
    // The user-supplied source of events.
    stream: S,
    // Keep-alive settings; `None` until `keep_alive` is called on the builder.
    keep_alive: Option<KeepAlive>,
}
30
31impl<S> Sse<S> {
32 pub fn new(stream: S) -> Self {
34 Self {
35 stream,
36 keep_alive: None,
37 }
38 }
39
40 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
42 self.keep_alive = Some(keep_alive);
43 self
44 }
45}
46
47impl<S, E> IntoResponse for Sse<S>
48where
49 S: Stream<Item = Result<Event, E>> + Send + Sync + 'static,
50 E: Into<BoxError>,
51{
52 fn into_response(self) -> Response {
53 Response::builder()
54 .header(
55 header::CONTENT_TYPE,
56 HeaderValue::from_str(mime::TEXT_EVENT_STREAM.essence_str()).expect("infallible"),
57 )
58 .header(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"))
59 .body(Body::from_body(SseBody {
60 stream: self.stream,
61 keep_alive: self.keep_alive.map(KeepAliveStream::new),
62 }))
63 .expect("infallible")
64 }
65}
66
// Body type produced by `Sse::into_response`: the event stream plus an optional
// keep-alive timer. Both fields are structurally pinned.
#[pin_project]
struct SseBody<S> {
    // The stream of events to forward as data frames.
    #[pin]
    stream: S,
    // Keep-alive timer; `None` when keep-alive was not configured.
    #[pin]
    keep_alive: Option<KeepAliveStream>,
}
74
75impl<S, E> http_body::Body for SseBody<S>
76where
77 S: Stream<Item = Result<Event, E>>,
78{
79 type Data = Bytes;
80 type Error = E;
81
82 fn poll_frame(
83 self: Pin<&mut Self>,
84 cx: &mut Context<'_>,
85 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
86 let this = self.project();
87 match this.stream.poll_next(cx) {
89 Poll::Pending => {
90 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
92 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
93 } else {
94 Poll::Pending
95 }
96 }
97 Poll::Ready(Some(Ok(event))) => {
98 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
100 keep_alive.reset();
101 }
102 Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
103 }
104 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
105 Poll::Ready(None) => Poll::Ready(None),
106 }
107 }
108}
109
/// A single server-sent event, built up field by field.
#[must_use]
#[derive(Default)]
pub struct Event {
    // Bytes of the fields written so far.
    buffer: BytesMut,
    // Tracks which of the once-only fields (data, event, id, retry) have been set.
    flags: EventFlags,
}
118
119impl Event {
120 const DATA: &'static str = "data";
121 const EVENT: &'static str = "event";
122 const ID: &'static str = "id";
123 const RETRY: &'static str = "retry";
124
125 pub fn new() -> Self {
127 Self::default()
128 }
129
130 pub fn event<T>(mut self, event: T) -> Self
137 where
138 T: AsRef<str>,
139 {
140 assert!(
141 !self.flags.contains_event(),
142 "Each `Event` cannot have more than one event field",
143 );
144 self.flags.set_event();
145
146 self.field(Self::EVENT, event.as_ref());
147
148 self
149 }
150
151 pub fn data<T>(mut self, data: T) -> Self
159 where
160 T: AsRef<str>,
161 {
162 assert!(
163 !self.flags.contains_data(),
164 "Each `Event` cannot have more than one data",
165 );
166 self.flags.set_data();
167
168 for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
169 self.field(Self::DATA, line);
170 }
171
172 self
173 }
174
175 #[cfg(feature = "json")]
183 pub fn json<T>(mut self, data: &T) -> Result<Self, crate::utils::json::Error>
184 where
185 T: serde::Serialize,
186 {
187 assert!(
188 !self.flags.contains_data(),
189 "Each `Event` cannot have more than one data",
190 );
191 self.flags.set_data();
192
193 self.buffer.extend_from_slice(Self::DATA.as_bytes());
194 self.buffer.put_u8(b':');
195 self.buffer.put_u8(b' ');
196
197 let mut writer = self.buffer.writer();
198 crate::utils::json::serialize_to_writer(&mut writer, data)?;
199 self.buffer = writer.into_inner();
200
201 Ok(self)
202 }
203
204 pub fn id<T>(mut self, id: T) -> Self
211 where
212 T: AsRef<str>,
213 {
214 assert!(
215 !self.flags.contains_id(),
216 "Each `Event` cannot have more than one id",
217 );
218 self.flags.set_id();
219
220 self.field(Self::ID, id.as_ref().as_bytes());
221
222 self
223 }
224
225 pub fn retry(mut self, timeout: Duration) -> Self {
231 assert!(
232 !self.flags.contains_retry(),
233 "Each `Event` cannot have more than one retry field",
234 );
235 self.flags.set_retry();
236
237 self.buffer.extend_from_slice(Self::RETRY.as_bytes());
238 self.buffer.put_u8(b':');
239 self.buffer.put_u8(b' ');
240 self.buffer
241 .extend_from_slice(itoa::Buffer::new().format(timeout.as_millis()).as_bytes());
242 self.buffer.put_u8(b'\n');
243
244 self
245 }
246
247 pub fn comment<T>(mut self, comment: T) -> Self
253 where
254 T: AsRef<str>,
255 {
256 self.field("", comment.as_ref().as_bytes());
257 self
258 }
259
260 fn field<V>(&mut self, key: &'static str, val: V)
261 where
262 V: AsRef<[u8]>,
263 {
264 let val = val.as_ref();
265 assert_eq!(
266 memchr::memchr2(b'\r', b'\n', val),
267 None,
268 "Field should not contain `\\r` or `\\n`",
269 );
270
271 self.buffer.extend_from_slice(key.as_bytes());
272 self.buffer.put_u8(b':');
273 self.buffer.put_u8(b' ');
274 self.buffer.extend_from_slice(val);
275 self.buffer.put_u8(b'\n');
276 }
277
278 fn finalize(mut self) -> Bytes {
279 self.buffer.put_u8(b'\n');
280 self.buffer.freeze()
281 }
282}
283
/// Configuration for the keep-alive message sent while the event stream is idle.
pub struct KeepAlive {
    // Finalized bytes of the keep-alive message.
    event: Bytes,
    // Idle time after which the keep-alive message is sent.
    max_interval: Duration,
}
289
290impl KeepAlive {
291 pub fn new() -> Self {
293 Self {
294 event: Bytes::from_static(b":\n\n"),
295 max_interval: Duration::from_secs(15),
296 }
297 }
298
299 pub fn interval(mut self, interval: Duration) -> Self {
303 self.max_interval = interval;
304 self
305 }
306
307 pub fn text<T>(mut self, text: T) -> Self
315 where
316 T: AsRef<str>,
317 {
318 self.event = Event::new().comment(text).finalize();
319 self
320 }
321
322 pub fn event(mut self, event: Event) -> Self {
326 self.event = event.finalize();
327 self
328 }
329}
330
331impl Default for KeepAlive {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
// Runtime state for keep-alive: the configuration plus the timer that counts down to
// the next keep-alive message.
#[pin_project]
struct KeepAliveStream {
    // Message bytes and interval to use.
    keep_alive: KeepAlive,
    // Timer that completes when the next keep-alive message is due.
    #[pin]
    alive_timer: Sleep,
}
343
344impl KeepAliveStream {
345 fn new(keep_alive: KeepAlive) -> Self {
346 Self {
347 alive_timer: tokio::time::sleep(keep_alive.max_interval),
348 keep_alive,
349 }
350 }
351
352 fn reset(self: Pin<&mut Self>) {
353 let this = self.project();
354 this.alive_timer
355 .reset(Instant::now() + this.keep_alive.max_interval);
356 }
357
358 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
359 let this = self.as_mut().project();
360
361 if this.alive_timer.poll(cx).is_pending() {
362 return Poll::Pending;
363 }
364
365 let event = self.keep_alive.event.clone();
366 self.reset();
367
368 Poll::Ready(event)
369 }
370}
371
372fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
374 MemchrSplit {
375 needle,
376 haystack: Some(haystack),
377 }
378}
379
/// Iterator state for [`memchr_split`].
struct MemchrSplit<'a> {
    // The separator byte.
    needle: u8,
    // The part of the input not yet yielded; `None` once the last piece has been returned.
    haystack: Option<&'a [u8]>,
}
384
385impl<'a> Iterator for MemchrSplit<'a> {
386 type Item = &'a [u8];
387 fn next(&mut self) -> Option<Self::Item> {
388 let haystack = self.haystack?;
389 if let Some(pos) = memchr::memchr(self.needle, haystack) {
390 let (front, back) = haystack.split_at(pos);
391 self.haystack = Some(&back[1..]);
392 Some(front)
393 } else {
394 self.haystack.take()
395 }
396 }
397}
398
// Generates a tuple struct `$name($type)` deriving `Default`, and for every
// `FLAG = value` entry:
//   * an associated constant named after the flag in upper case,
//   * a `set_<flag>` method that ORs the constant into the stored value,
//   * a `contains_<flag>` method that checks whether all bits of the constant are set.
// The method names use the flag name in lower case (built with `paste!`).
macro_rules! define_bitflag {
    (struct $name:ident($type:ty) { $( $flag:ident = $val:tt, )+ }) => {
        #[derive(Default)]
        struct $name($type);

        impl $name {
            $(
                paste! {
                    // Bit mask of this flag.
                    const [<$flag:upper>]: $type = $val;

                    // Turns the flag on; bits that are already set stay set.
                    #[inline]
                    fn [<set_ $flag:lower>](&mut self) {
                        self.0 |= Self::[<$flag:upper>];
                    }

                    // True when every bit of the mask is set.
                    #[inline]
                    fn [<contains_ $flag:lower>](&self) -> bool {
                        self.0 & Self::[<$flag:upper>] == Self::[<$flag:upper>]
                    }
                }
            )+
        }
    }
}
423
// Flags recording which once-only fields of an `Event` have been set. Each flag uses
// its own bit of the `u8`.
define_bitflag! {
    struct EventFlags(u8) {
        DATA    = 0b0001,
        EVENT   = 0b0010,
        ID      = 0b0100,
        RETRY   = 0b1000,
    }
}
432
#[cfg(test)]
mod sse_tests {
    use std::{convert::Infallible, time::Duration};

    use ahash::AHashMap;
    use async_stream::stream;
    use faststr::FastStr;
    use futures::{Stream, StreamExt, stream};
    use http::{header, method::Method};
    use http_body_util::BodyExt;

    use super::{Event, KeepAlive, Sse, memchr_split};
    use crate::{
        body::Body,
        server::route::{MethodRouter, any},
    };

    impl Event {
        /// Finalizes the event and returns its bytes as a `String`.
        ///
        /// Uses a checked conversion instead of `unsafe`; a failure here would mean the
        /// event buffer contained invalid UTF-8.
        fn into_string(self) -> String {
            String::from_utf8(self.finalize().to_vec()).expect("event is not valid UTF-8")
        }
    }

    #[test]
    fn event_build() {
        // An empty event is just the terminating blank line.
        assert_eq!(Event::new().into_string(), "\n");

        // Single-field events.
        assert_eq!(
            Event::new().event("sse-event").into_string(),
            "event: sse-event\n\n",
        );
        assert_eq!(
            Event::new().data("text-data").into_string(),
            "data: text-data\n\n",
        );
        assert_eq!(Event::new().id("seq-001").into_string(), "id: seq-001\n\n");
        assert_eq!(
            Event::new().retry(Duration::from_secs(1)).into_string(),
            "retry: 1000\n\n",
        );
        assert_eq!(
            Event::new().comment("comment").into_string(),
            ": comment\n\n",
        );

        // Multi-line data becomes one `data` line per input line.
        assert_eq!(
            Event::new().data("114\n514\n1919\n810").into_string(),
            "data: 114\ndata: 514\ndata: 1919\ndata: 810\n\n",
        );

        // All fields together, in the order they were set.
        assert_eq!(
            Event::new()
                .event("ping")
                .data("hello\nworld")
                .id("first")
                .retry(Duration::from_secs(15))
                .comment("test comment")
                .into_string(),
            "event: ping\ndata: hello\ndata: world\nid: first\nretry: 15000\n: test comment\n\n",
        );
    }

    #[test]
    #[should_panic]
    fn multi_event() {
        let _ = Event::new().event("ping").event("pong").into_string();
    }

    #[test]
    #[should_panic]
    fn multi_data() {
        let _ = Event::new().data("data1").data("data2").into_string();
    }

    #[test]
    #[should_panic]
    fn multi_id() {
        let _ = Event::new().id("ping-1").id("ping-2").into_string();
    }

    #[test]
    #[should_panic]
    fn multi_retry() {
        let _ = Event::new()
            .retry(Duration::from_secs(1))
            .retry(Duration::from_secs(1))
            .into_string();
    }

    #[test]
    fn multi_comment() {
        // Comments, unlike the other fields, may be repeated.
        assert_eq!(
            Event::new()
                .comment("114514")
                .comment("1919810")
                .into_string(),
            ": 114514\n: 1919810\n\n",
        );
    }

    #[test]
    fn memchr_splitting() {
        // Empty input and leading/trailing/adjacent separators all yield empty pieces.
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }

    /// Parses the text of one event into a `field name -> value` map.
    ///
    /// Comment lines (empty field name) are stored under the key `"comment"`. Repeated
    /// fields are joined with `\n`. Lines without a `": "` separator are ignored.
    fn parse_event(s: &str) -> AHashMap<String, String> {
        use std::collections::hash_map::Entry;

        let mut res: AHashMap<String, String> = AHashMap::new();

        for line in s.split('\n') {
            // Blank lines have no separator either, so they are skipped here as well.
            let Some(pos) = line.find(": ") else {
                continue;
            };
            let key = match &line[..pos] {
                "" => "comment",
                name => name,
            };
            let val = &line[pos + 2..];

            // One lookup per line: append to an existing value or insert the first one.
            match res.entry(key.to_owned()) {
                Entry::Occupied(mut slot) => {
                    let joined = slot.get_mut();
                    joined.push('\n');
                    joined.push_str(val);
                }
                Entry::Vacant(slot) => {
                    slot.insert(val.to_owned());
                }
            }
        }

        res
    }

    /// Pulls the next data frame from `body` and parses it with `parse_event`.
    async fn poll_event(body: &mut Body) -> AHashMap<String, String> {
        let data = body
            .frame()
            .await
            .expect("No frame found")
            .expect("Failed to pull frame")
            .into_data()
            .expect("Frame is not data");
        let s = FastStr::from_bytes(data).expect("Frame data is not a valid string");
        parse_event(&s)
    }

    #[tokio::test]
    async fn simple_event() {
        async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
            Sse::new(
                stream::iter(vec![
                    Event::new().event("ping").data("-"),
                    Event::new().event("ping").data("biu"),
                    Event::new()
                        .event("pong")
                        .id("pong")
                        .retry(Duration::from_secs(1))
                        .comment(""),
                ])
                .map(Ok),
            )
        }
        let router: MethodRouter<Option<Body>> = any(sse_handler);
        let resp = router.call_route(Method::GET, None).await;
        let (parts, mut body) = resp.into_parts();
        assert_eq!(
            parts
                .headers
                .get(header::CONTENT_TYPE)
                .expect("`Content-Type` does not exist")
                .to_str()
                .expect("`Content-Type` is not a valid string"),
            mime::TEXT_EVENT_STREAM.essence_str(),
        );
        assert_eq!(
            parts
                .headers
                .get(header::CACHE_CONTROL)
                .expect("`Cache-Control` does not exist")
                .to_str()
                .expect("`Cache-Control` is not a valid string"),
            "no-cache",
        );

        // First event.
        let event = poll_event(&mut body).await;
        assert_eq!(event.len(), 2);
        assert_eq!(event.get("event").unwrap(), "ping");
        assert_eq!(event.get("data").unwrap(), "-");

        // Second event.
        let event = poll_event(&mut body).await;
        assert_eq!(event.len(), 2);
        assert_eq!(event.get("event").unwrap(), "ping");
        assert_eq!(event.get("data").unwrap(), "biu");

        // Third event.
        let event = poll_event(&mut body).await;
        assert_eq!(event.len(), 4);
        assert_eq!(event.get("event").unwrap(), "pong");
        assert_eq!(event.get("id").unwrap(), "pong");
        assert_eq!(event.get("retry").unwrap(), "1000");
        assert_eq!(event.get("comment").unwrap(), "");
    }

    #[tokio::test]
    async fn keep_alive() {
        async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
            let stream = stream! {
                loop {
                    yield Ok(Event::new().event("ping"));
                    tokio::time::sleep(Duration::from_secs(5)).await;
                }
            };

            Sse::new(stream).keep_alive(
                KeepAlive::new()
                    .interval(Duration::from_secs(1))
                    .text("do not kill me"),
            )
        }

        let router: MethodRouter<Option<Body>> = any(sse_handler);
        let resp = router.call_route(Method::GET, None).await;
        let (_, mut body) = resp.into_parts();

        // The stream yields its first event immediately.
        let event_fields = poll_event(&mut body).await;
        assert_eq!(event_fields.get("event").unwrap(), "ping");

        // While the stream sleeps, keep-alive messages arrive.
        for _ in 0..4 {
            let event_fields = poll_event(&mut body).await;
            assert_eq!(event_fields.get("comment").unwrap(), "do not kill me");
        }

        // Then the stream yields again.
        let event_fields = poll_event(&mut body).await;
        assert_eq!(event_fields.get("event").unwrap(), "ping");
    }

    #[tokio::test]
    async fn keep_alive_ends_when_the_stream_ends() {
        async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
            let stream = stream! {
                tokio::time::sleep(Duration::from_secs(5)).await;
                yield Ok(Event::new().event("ping"));
            };

            Sse::new(stream).keep_alive(
                KeepAlive::new()
                    .interval(Duration::from_secs(1))
                    .text("do not kill me"),
            )
        }

        let router: MethodRouter<Option<Body>> = any(sse_handler);
        let resp = router.call_route(Method::GET, None).await;
        let (_, mut body) = resp.into_parts();

        // Keep-alive messages arrive while the stream is still sleeping.
        for _ in 0..4 {
            let event_fields = poll_event(&mut body).await;
            assert_eq!(event_fields.get("comment").unwrap(), "do not kill me");
        }

        // The stream's only event.
        let event_fields = poll_event(&mut body).await;
        assert_eq!(event_fields.get("event").unwrap(), "ping");

        // Once the stream has ended, the body ends too; no more keep-alive frames.
        assert!(body.frame().await.is_none());
    }
}