sse_stream/
body.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll, ready},
4    time::Duration,
5};
6
7use bytes::Bytes;
8use futures_util::Stream;
9use http_body::{Body, Frame};
10
11use crate::Sse;
12
13pin_project_lite::pin_project! {
14    pub struct SseBody<S, T = NeverTimer> {
15        #[pin]
16        pub event_stream: S,
17        #[pin]
18        pub keep_alive: Option<KeepAliveStream<T>>,
19    }
20}
21
22impl<S, E> SseBody<S, NeverTimer>
23where
24    S: Stream<Item = Result<Sse, E>>,
25{
26    pub fn new(stream: S) -> Self {
27        Self {
28            event_stream: stream,
29            keep_alive: None,
30        }
31    }
32}
33
34impl<S, E, T> SseBody<S, T>
35where
36    S: Stream<Item = Result<Sse, E>>,
37    T: Timer,
38{
39    pub fn new_keep_alive(stream: S, keep_alive: KeepAlive) -> Self {
40        Self {
41            event_stream: stream,
42            keep_alive: Some(KeepAliveStream::new(keep_alive)),
43        }
44    }
45
46    pub fn with_keep_alive<T2: Timer>(self, keep_alive: KeepAlive) -> SseBody<S, T2> {
47        SseBody {
48            event_stream: self.event_stream,
49            keep_alive: Some(KeepAliveStream::new(keep_alive)),
50        }
51    }
52}
53
54impl<S, E, T> Body for SseBody<S, T>
55where
56    S: Stream<Item = Result<Sse, E>>,
57    T: Timer,
58{
59    type Data = Bytes;
60    type Error = E;
61
62    fn poll_frame(
63        self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
66        let this = self.project();
67
68        match this.event_stream.poll_next(cx) {
69            Poll::Pending => {
70                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
71                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
72                } else {
73                    Poll::Pending
74                }
75            }
76            Poll::Ready(Some(Ok(event))) => {
77                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
78                    keep_alive.reset();
79                }
80                Poll::Ready(Some(Ok(Frame::data(event.into()))))
81            }
82            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
83            Poll::Ready(None) => Poll::Ready(None),
84        }
85    }
86}
87
88/// Configure the interval between keep-alive messages, the content
89/// of each message, and the associated stream.
90#[derive(Debug, Clone)]
91#[must_use]
92pub struct KeepAlive {
93    event: Bytes,
94    max_interval: Duration,
95}
96
97impl KeepAlive {
98    /// Create a new `KeepAlive`.
99    pub fn new() -> Self {
100        Self {
101            event: Bytes::from_static(b":\n\n"),
102            max_interval: Duration::from_secs(15),
103        }
104    }
105
106    /// Customize the interval between keep-alive messages.
107    ///
108    /// Default is 15 seconds.
109    pub fn interval(mut self, time: Duration) -> Self {
110        self.max_interval = time;
111        self
112    }
113
114    /// Customize the event of the keep-alive message.
115    ///
116    /// Default is an empty comment.
117    ///
118    /// # Panics
119    ///
120    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
121    /// comments.
122    pub fn event(mut self, event: Sse) -> Self {
123        self.event = event.into();
124        self
125    }
126
127    /// Customize the event of the keep-alive message with a comment
128    pub fn comment(mut self, comment: &str) -> Self {
129        self.event = format!(": {}\n\n", comment).into();
130        self
131    }
132}
133
134impl Default for KeepAlive {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
/// A resettable timer that decides when keep-alive messages are due.
///
/// The timer is a `Future<Output = ()>` and is expected to complete once its deadline has
/// passed. The keep-alive machinery calls [`Timer::reset`] after every regular event and
/// after every keep-alive message to push the deadline forward. Implementations must
/// therefore keep working after they have fired.
pub trait Timer: Future<Output = ()> {
    /// Creates a timer armed for `duration`.
    ///
    /// Called with the keep-alive interval when keep-alive is configured.
    fn from_duration(duration: Duration) -> Self;

    /// Moves the timer's deadline to `deadline`.
    fn reset(self: Pin<&mut Self>, deadline: std::time::Instant);
}
144
/// A [`Timer`] that never fires.
///
/// This is the default timer type of [`SseBody`]. With it, no keep-alive messages are ever
/// produced, even when a [`KeepAlive`] configuration is supplied.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NeverTimer;
146
147impl Future for NeverTimer {
148    type Output = ();
149
150    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
151        Poll::Pending
152    }
153}
154
155impl Timer for NeverTimer {
156    fn from_duration(_: Duration) -> Self {
157        Self
158    }
159
160    fn reset(self: Pin<&mut Self>, _: std::time::Instant) {
161        // No-op
162    }
163}
164
165pin_project_lite::pin_project! {
166    #[derive(Debug)]
167    struct KeepAliveStream<S> {
168        keep_alive: KeepAlive,
169        #[pin]
170        alive_timer: S,
171    }
172}
173
174impl<S> KeepAliveStream<S>
175where
176    S: Timer,
177{
178    fn new(keep_alive: KeepAlive) -> Self {
179        Self {
180            alive_timer: S::from_duration(keep_alive.max_interval),
181            keep_alive,
182        }
183    }
184
185    fn reset(self: Pin<&mut Self>) {
186        let this = self.project();
187        this.alive_timer
188            .reset(std::time::Instant::now() + this.keep_alive.max_interval);
189    }
190
191    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
192        let this = self.as_mut().project();
193
194        ready!(this.alive_timer.poll(cx));
195
196        let event = this.keep_alive.event.clone();
197
198        self.reset();
199
200        Poll::Ready(event)
201    }
202}