volo_http/server/response/
sse.rs

//! SSE (Server-Sent Events) support
//!
//! See [`Sse`] and [`Event`] for more details.
4
5use std::{
6    future::Future,
7    pin::Pin,
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use bytes::{BufMut, Bytes, BytesMut};
13use futures::Stream;
14use http::{HeaderValue, header};
15use http_body::Frame;
16use paste::paste;
17use pin_project::pin_project;
18use tokio::time::{Instant, Sleep};
19
20use super::IntoResponse;
21use crate::{body::Body, error::BoxError, response::Response};
22
23/// Response of [SSE][sse] (Server-Sent Events), inclusing a stream with SSE [`Event`]s.
24///
25/// [sse]: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
26pub struct Sse<S> {
27    stream: S,
28    keep_alive: Option<KeepAlive>,
29}
30
31impl<S> Sse<S> {
32    /// Create a new SSE response with the given stream of [`Event`]s.
33    pub fn new(stream: S) -> Self {
34        Self {
35            stream,
36            keep_alive: None,
37        }
38    }
39
40    /// Configure a [`KeepAlive`] for sending keep-alive messages.
41    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
42        self.keep_alive = Some(keep_alive);
43        self
44    }
45}
46
47impl<S, E> IntoResponse for Sse<S>
48where
49    S: Stream<Item = Result<Event, E>> + Send + Sync + 'static,
50    E: Into<BoxError>,
51{
52    fn into_response(self) -> Response {
53        Response::builder()
54            .header(
55                header::CONTENT_TYPE,
56                HeaderValue::from_str(mime::TEXT_EVENT_STREAM.essence_str()).expect("infallible"),
57            )
58            .header(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"))
59            .body(Body::from_body(SseBody {
60                stream: self.stream,
61                keep_alive: self.keep_alive.map(KeepAliveStream::new),
62            }))
63            .expect("infallible")
64    }
65}
66
67#[pin_project]
68struct SseBody<S> {
69    #[pin]
70    stream: S,
71    #[pin]
72    keep_alive: Option<KeepAliveStream>,
73}
74
75impl<S, E> http_body::Body for SseBody<S>
76where
77    S: Stream<Item = Result<Event, E>>,
78{
79    type Data = Bytes;
80    type Error = E;
81
82    fn poll_frame(
83        self: Pin<&mut Self>,
84        cx: &mut Context<'_>,
85    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
86        let this = self.project();
87        // Firstly, we should poll SSE stream
88        match this.stream.poll_next(cx) {
89            Poll::Pending => {
90                // If the SSE stream is unavailable, poll the keep-alive stream
91                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
92                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
93                } else {
94                    Poll::Pending
95                }
96            }
97            Poll::Ready(Some(Ok(event))) => {
98                // The SSE stream is available, reset deadline of keep-alive stream
99                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
100                    keep_alive.reset();
101                }
102                Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
103            }
104            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
105            Poll::Ready(None) => Poll::Ready(None),
106        }
107    }
108}
109
110/// Message of [`Sse`]. Each event has some lines of specified fields including `event`, `data`,
111/// `id`, `retry` and comment.
112#[must_use]
113#[derive(Default)]
114pub struct Event {
115    buffer: BytesMut,
116    flags: EventFlags,
117}
118
119impl Event {
120    const DATA: &'static str = "data";
121    const EVENT: &'static str = "event";
122    const ID: &'static str = "id";
123    const RETRY: &'static str = "retry";
124
125    /// Create an empty [`Event`]
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    /// Set the event field (`event: <event-name>`) for the event message.
131    ///
132    /// # Panics
133    ///
134    /// - Panics if the event field has already set.
135    /// - Panics if the event name contains `\r` or `\n`.
136    pub fn event<T>(mut self, event: T) -> Self
137    where
138        T: AsRef<str>,
139    {
140        assert!(
141            !self.flags.contains_event(),
142            "Each `Event` cannot have more than one event field",
143        );
144        self.flags.set_event();
145
146        self.field(Self::EVENT, event.as_ref());
147
148        self
149    }
150
151    /// Set the data field(s) (`data: <content>`) for the event message.
152    ///
153    /// Each line of contents will be added `data: ` prefix when sending.
154    ///
155    /// # Panics
156    ///
157    /// - Panics if the data field has already set through [`Self::data`] or [`Self::json`].
158    pub fn data<T>(mut self, data: T) -> Self
159    where
160        T: AsRef<str>,
161    {
162        assert!(
163            !self.flags.contains_data(),
164            "Each `Event` cannot have more than one data",
165        );
166        self.flags.set_data();
167
168        for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
169            self.field(Self::DATA, line);
170        }
171
172        self
173    }
174
175    /// Set the data field (`data: <content>`) by serialized data for the event message.
176    ///
177    /// Each line of contents will be added `data: ` prefix when sending.
178    ///
179    /// # Panics
180    ///
181    /// - Panics if the data field has already set through [`Self::data`] or [`Self::json`].
182    #[cfg(feature = "json")]
183    pub fn json<T>(mut self, data: &T) -> Result<Self, crate::utils::json::Error>
184    where
185        T: serde::Serialize,
186    {
187        assert!(
188            !self.flags.contains_data(),
189            "Each `Event` cannot have more than one data",
190        );
191        self.flags.set_data();
192
193        self.buffer.extend_from_slice(Self::DATA.as_bytes());
194        self.buffer.put_u8(b':');
195        self.buffer.put_u8(b' ');
196
197        let mut writer = self.buffer.writer();
198        crate::utils::json::serialize_to_writer(&mut writer, data)?;
199        self.buffer = writer.into_inner();
200
201        Ok(self)
202    }
203
204    /// Set the id field (`id: <id>`) for the event message.
205    ///
206    /// # Panics
207    ///
208    /// - Panics if the id field has already set.
209    /// - Panics if the id contains `\r` or `\n`.
210    pub fn id<T>(mut self, id: T) -> Self
211    where
212        T: AsRef<str>,
213    {
214        assert!(
215            !self.flags.contains_id(),
216            "Each `Event` cannot have more than one id",
217        );
218        self.flags.set_id();
219
220        self.field(Self::ID, id.as_ref().as_bytes());
221
222        self
223    }
224
225    /// Set the retry field (`retry: <timeout>`) for the event message.
226    ///
227    /// # Panics
228    ///
229    /// - Panics if the timeout field has already set.
230    pub fn retry(mut self, timeout: Duration) -> Self {
231        assert!(
232            !self.flags.contains_retry(),
233            "Each `Event` cannot have more than one retry field",
234        );
235        self.flags.set_retry();
236
237        self.buffer.extend_from_slice(Self::RETRY.as_bytes());
238        self.buffer.put_u8(b':');
239        self.buffer.put_u8(b' ');
240        self.buffer
241            .extend_from_slice(itoa::Buffer::new().format(timeout.as_millis()).as_bytes());
242        self.buffer.put_u8(b'\n');
243
244        self
245    }
246
247    /// Add a comment field (`: <comment-text>`).
248    ///
249    /// # Panics
250    ///
251    /// - Panics if the comment text contains `\r` or `\n`.
252    pub fn comment<T>(mut self, comment: T) -> Self
253    where
254        T: AsRef<str>,
255    {
256        self.field("", comment.as_ref().as_bytes());
257        self
258    }
259
260    fn field<V>(&mut self, key: &'static str, val: V)
261    where
262        V: AsRef<[u8]>,
263    {
264        let val = val.as_ref();
265        assert_eq!(
266            memchr::memchr2(b'\r', b'\n', val),
267            None,
268            "Field should not contain `\\r` or `\\n`",
269        );
270
271        self.buffer.extend_from_slice(key.as_bytes());
272        self.buffer.put_u8(b':');
273        self.buffer.put_u8(b' ');
274        self.buffer.extend_from_slice(val);
275        self.buffer.put_u8(b'\n');
276    }
277
278    fn finalize(mut self) -> Bytes {
279        self.buffer.put_u8(b'\n');
280        self.buffer.freeze()
281    }
282}
283
284/// Configure a interval to send a message for keeping SSE connection alive.
285pub struct KeepAlive {
286    event: Bytes,
287    max_interval: Duration,
288}
289
290impl KeepAlive {
291    /// Create a new `KeepAlive` with an empty comment as message.
292    pub fn new() -> Self {
293        Self {
294            event: Bytes::from_static(b":\n\n"),
295            max_interval: Duration::from_secs(15),
296        }
297    }
298
299    /// Set the interval between keep-alive messages.
300    ///
301    /// Default is 15 seconds.
302    pub fn interval(mut self, interval: Duration) -> Self {
303        self.max_interval = interval;
304        self
305    }
306
307    /// Set the comment text for the keep-alive message.
308    ///
309    /// Default is an empty comment.
310    ///
311    /// # Panics
312    ///
313    /// - Panics if the comment text contains `\r` or `\n`.
314    pub fn text<T>(mut self, text: T) -> Self
315    where
316        T: AsRef<str>,
317    {
318        self.event = Event::new().comment(text).finalize();
319        self
320    }
321
322    /// Set the event of keep-alive message.
323    ///
324    /// Default is an empty comment.
325    pub fn event(mut self, event: Event) -> Self {
326        self.event = event.finalize();
327        self
328    }
329}
330
331impl Default for KeepAlive {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337#[pin_project]
338struct KeepAliveStream {
339    keep_alive: KeepAlive,
340    #[pin]
341    alive_timer: Sleep,
342}
343
344impl KeepAliveStream {
345    fn new(keep_alive: KeepAlive) -> Self {
346        Self {
347            alive_timer: tokio::time::sleep(keep_alive.max_interval),
348            keep_alive,
349        }
350    }
351
352    fn reset(self: Pin<&mut Self>) {
353        let this = self.project();
354        this.alive_timer
355            .reset(Instant::now() + this.keep_alive.max_interval);
356    }
357
358    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
359        let this = self.as_mut().project();
360
361        if this.alive_timer.poll(cx).is_pending() {
362            return Poll::Pending;
363        }
364
365        let event = self.keep_alive.event.clone();
366        self.reset();
367
368        Poll::Ready(event)
369    }
370}
371
372// Copied from `axum/src/response/sse.rs`
373fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
374    MemchrSplit {
375        needle,
376        haystack: Some(haystack),
377    }
378}
379
380struct MemchrSplit<'a> {
381    needle: u8,
382    haystack: Option<&'a [u8]>,
383}
384
385impl<'a> Iterator for MemchrSplit<'a> {
386    type Item = &'a [u8];
387    fn next(&mut self) -> Option<Self::Item> {
388        let haystack = self.haystack?;
389        if let Some(pos) = memchr::memchr(self.needle, haystack) {
390            let (front, back) = haystack.split_at(pos);
391            self.haystack = Some(&back[1..]);
392            Some(front)
393        } else {
394            self.haystack.take()
395        }
396    }
397}
398
/// Declares a tiny bit-set newtype named `$set` over the integer type `$repr`.
///
/// For every `FLAG = mask,` entry it generates an associated constant `FLAG`, a `set_flag`
/// method that turns those bits on, and a `contains_flag` method that reports whether all of
/// those bits are set.
macro_rules! define_bitflag {
    (struct $set:ident($repr:ty) { $( $bit:ident = $mask:tt, )+ }) => {
        #[derive(Default)]
        struct $set($repr);

        impl $set {
            $(
                paste! {
                    const [<$bit:upper>]: $repr = $mask;

                    #[inline]
                    fn [<set_ $bit:lower>](&mut self) {
                        self.0 |= Self::[<$bit:upper>];
                    }

                    #[inline]
                    fn [<contains_ $bit:lower>](&self) -> bool {
                        let wanted = Self::[<$bit:upper>];
                        (self.0 & wanted) == wanted
                    }
                }
            )+
        }
    };
}
423
424define_bitflag! {
425    struct EventFlags(u8) {
426        DATA    = 0b0001,
427        EVENT   = 0b0010,
428        ID      = 0b0100,
429        RETRY   = 0b1000,
430    }
431}
432
#[cfg(test)]
mod sse_tests {
    use std::{convert::Infallible, time::Duration};

    use ahash::AHashMap;
    use async_stream::stream;
    use faststr::FastStr;
    use futures::{Stream, StreamExt, stream};
    use http::{header, method::Method};
    use http_body_util::BodyExt;

    use super::{Event, KeepAlive, Sse, memchr_split};
    use crate::{
        body::Body,
        server::route::{MethodRouter, any},
    };

    impl Event {
        /// Finalizes the event and returns it as a `String` for easy comparison.
        fn into_string(self) -> String {
            // Events are built from `&str` input and ASCII digits, so this never fails. A test
            // helper has no reason to rely on `unsafe` for that.
            String::from_utf8(self.finalize().to_vec()).expect("SSE event is not valid UTF-8")
        }
    }

    #[test]
    fn event_build() {
        // Empty event
        assert_eq!(Event::new().into_string(), "\n");

        // Single field
        assert_eq!(
            Event::new().event("sse-event").into_string(),
            "event: sse-event\n\n",
        );
        assert_eq!(
            Event::new().data("text-data").into_string(),
            "data: text-data\n\n",
        );
        assert_eq!(Event::new().id("seq-001").into_string(), "id: seq-001\n\n");
        assert_eq!(
            Event::new().retry(Duration::from_secs(1)).into_string(),
            "retry: 1000\n\n",
        );
        assert_eq!(
            Event::new().comment("comment").into_string(),
            ": comment\n\n",
        );

        // Multi-line data
        assert_eq!(
            Event::new().data("114\n514\n1919\n810").into_string(),
            "data: 114\ndata: 514\ndata: 1919\ndata: 810\n\n",
        );

        // Multi-field event
        assert_eq!(
            Event::new()
                .event("ping")
                .data("hello\nworld")
                .id("first")
                .retry(Duration::from_secs(15))
                .comment("test comment")
                .into_string(),
            "event: ping\ndata: hello\ndata: world\nid: first\nretry: 15000\n: test comment\n\n",
        );
    }

    #[test]
    #[should_panic(expected = "Each `Event` cannot have more than one event field")]
    fn multi_event() {
        let _ = Event::new().event("ping").event("pong").into_string();
    }

    #[test]
    #[should_panic(expected = "Each `Event` cannot have more than one data")]
    fn multi_data() {
        let _ = Event::new().data("data1").data("data2").into_string();
    }

    #[test]
    #[should_panic(expected = "Each `Event` cannot have more than one id")]
    fn multi_id() {
        let _ = Event::new().id("ping-1").id("ping-2").into_string();
    }

    #[test]
    #[should_panic(expected = "Each `Event` cannot have more than one retry field")]
    fn multi_retry() {
        let _ = Event::new()
            .retry(Duration::from_secs(1))
            .retry(Duration::from_secs(1))
            .into_string();
    }

    #[test]
    #[should_panic(expected = "Field should not contain")]
    fn data_with_carriage_return() {
        // `data` only splits on `\n`, so a `\r\n` line ending leaves a `\r` behind.
        let _ = Event::new().data("line1\r\nline2").into_string();
    }

    #[test]
    // This will not panic
    fn multi_comment() {
        assert_eq!(
            Event::new()
                .comment("114514")
                .comment("1919810")
                .into_string(),
            ": 114514\n: 1919810\n\n",
        );
    }

    #[test]
    // Copied from `axum/src/response/sse.rs`
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }

    /// Parses one serialized event into a `field name -> value` map.
    ///
    /// Comment lines (`: text`) are stored under the key `comment`, and repeated fields are
    /// joined with `\n`. Lines without `": "` are ignored; this includes empty lines and a bare
    /// `:` comment.
    fn parse_event(s: &str) -> AHashMap<String, String> {
        let mut res: AHashMap<String, String> = AHashMap::new();

        for line in s.split('\n') {
            let Some((key, val)) = line.split_once(": ") else {
                continue;
            };
            let key = if key.is_empty() { "comment" } else { key };
            res.entry(key.to_owned())
                .and_modify(|acc| {
                    acc.push('\n');
                    acc.push_str(val);
                })
                .or_insert_with(|| val.to_owned());
        }

        res
    }

    async fn poll_event(body: &mut Body) -> AHashMap<String, String> {
        let data = body
            .frame()
            .await
            .expect("No frame found")
            .expect("Failed to pull frame")
            .into_data()
            .expect("Frame is not data");
        let s = FastStr::from_bytes(data).expect("Frame data is not a valid string");
        parse_event(&s)
    }

    #[tokio::test]
    async fn simple_event() {
        async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
            Sse::new(
                stream::iter(vec![
                    Event::new().event("ping").data("-"),
                    Event::new().event("ping").data("biu"),
                    Event::new()
                        .event("pong")
                        .id("pong")
                        .retry(Duration::from_secs(1))
                        .comment(""),
                ])
                .map(Ok),
            )
        }
        let router: MethodRouter<Option<Body>> = any(sse_handler);
        let resp = router.call_route(Method::GET, None).await;
        let (parts, mut body) = resp.into_parts();
        assert_eq!(
            parts
                .headers
                .get(header::CONTENT_TYPE)
                .expect("`Content-Type` does not exist")
                .to_str()
                .expect("`Content-Type` is not a valid string"),
            mime::TEXT_EVENT_STREAM.essence_str(),
        );
        assert_eq!(
            parts
                .headers
                .get(header::CACHE_CONTROL)
                .expect("`Cache-Control` does not exist")
                .to_str()
                .expect("`Cache-Control` is not a valid string"),
            "no-cache",
        );

        // Event 1
        let event = poll_event(&mut body).await;
        assert_eq!(event.len(), 2);
        assert_eq!(event.get("event").unwrap(), "ping");
        assert_eq!(event.get("data").unwrap(), "-");

        // Event 2
        let event = poll_event(&mut body).await;
        assert_eq!(event.len(), 2);
        assert_eq!(event.get("event").unwrap(), "ping");
        assert_eq!(event.get("data").unwrap(), "biu");

        // Event 3
        let event = poll_event(&mut body).await;
        assert_eq!(event.len(), 4);
        assert_eq!(event.get("event").unwrap(), "pong");
        assert_eq!(event.get("id").unwrap(), "pong");
        assert_eq!(event.get("retry").unwrap(), "1000");
        assert_eq!(event.get("comment").unwrap(), "");
    }

    #[tokio::test]
    async fn keep_alive() {
        async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
            let stream = stream! {
                loop {
                    yield Ok(Event::new().event("ping"));
                    tokio::time::sleep(Duration::from_secs(5)).await;
                }
            };

            Sse::new(stream).keep_alive(
                KeepAlive::new()
                    .interval(Duration::from_secs(1))
                    .text("do not kill me"),
            )
        }

        let router: MethodRouter<Option<Body>> = any(sse_handler);
        let resp = router.call_route(Method::GET, None).await;
        let (_, mut body) = resp.into_parts();

        // The first message is event
        let event_fields = poll_event(&mut body).await;
        assert_eq!(event_fields.get("event").unwrap(), "ping");

        // Then 4 keep-alive messages
        for _ in 0..4 {
            let event_fields = poll_event(&mut body).await;
            assert_eq!(event_fields.get("comment").unwrap(), "do not kill me");
        }

        // After 5 seconds, event is coming
        let event_fields = poll_event(&mut body).await;
        assert_eq!(event_fields.get("event").unwrap(), "ping");
    }

    #[tokio::test]
    async fn keep_alive_ends_when_the_stream_ends() {
        async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
            let stream = stream! {
                // Sleep 5 seconds and send only one event
                tokio::time::sleep(Duration::from_secs(5)).await;
                yield Ok(Event::new().event("ping"));
            };

            Sse::new(stream).keep_alive(
                KeepAlive::new()
                    .interval(Duration::from_secs(1))
                    .text("do not kill me"),
            )
        }

        let router: MethodRouter<Option<Body>> = any(sse_handler);
        let resp = router.call_route(Method::GET, None).await;
        let (_, mut body) = resp.into_parts();

        // 4 comments before event
        for _ in 0..4 {
            let event_fields = poll_event(&mut body).await;
            assert_eq!(event_fields.get("comment").unwrap(), "do not kill me");
        }

        // Event is coming
        let event_fields = poll_event(&mut body).await;
        assert_eq!(event_fields.get("event").unwrap(), "ping");

        // Stream finished
        assert!(body.frame().await.is_none());
    }
}