Skip to main content

sse_reqwest_client/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![warn(rustdoc::missing_crate_level_docs)]
4
5use std::{
6    fmt,
7    future::Future,
8    num::NonZeroUsize,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll, ready},
12    time::Duration,
13};
14
15use bytes::Bytes;
16use futures_core::stream::Stream;
17use reqwest::{RequestBuilder, StatusCode, header::HeaderValue};
18use thiserror::Error;
19use tokio::time::{Instant, Sleep, sleep};
20
21pub use sse_core::SseRetryConfig;
22use sse_core::{
23    MessageEvent, PayloadTooLargeError, SseDecoder, SseEvent as SseEventCore, SseStream,
24    SseStreamError,
25};
26
// Type-erased, heap-pinned stream of response body chunks. This is the inner stream type
// that `EventSource` attaches to its `SseStream` once a connection has been accepted.
27type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync>>;
// Type-erased, heap-pinned future for an in-flight request. It is stored in
// `State::Connecting` and resolves to the HTTP response or a transport error.
28type ConnectFuture =
29    Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Send + Sync>>;
30
31/// An alias for [`std::result::Result`] with the error type defaulting to [`Error`](enum@Error).
32pub type Result<T, E = Error> = std::result::Result<T, E>;
33
34/// Errors that can occur during the lifecycle of an [`EventSource`] connection.
///
/// Every variant is terminal: the [`EventSource`] closes itself before yielding the error,
/// so later polls return `None` until a reconnect is forced manually.
35#[derive(Debug, Error)]
36pub enum Error {
37    /// The server responded with an unexpected (non-200) HTTP status code.
    ///
    /// `204 No Content` is not reported through this variant (it ends the stream without an
    /// error), and neither are transient statuses when retrying them has been enabled.
38    #[error("unexpected HTTP status code: {0}")]
39    Status(StatusCode),
40    /// The [`RequestBuilder`] could not be cloned (e.g., it contains a streaming body).
    ///
    /// The request is cloned before every connection attempt, so it must be cloneable.
41    #[error("request builder could not be cloned (e.g., non-restartable body stream)")]
42    UncloneableRequest,
43    /// The response's Content-Type header was present but was not `text/event-stream`.
44    #[error("invalid response HTTP Content-Type")]
45    InvalidContentType,
46    /// The server's response did not contain a Content-Type header.
47    #[error("response HTTP Content-Type missing")]
48    MissingContentType,
49    /// The client exhausted all retry attempts without successfully reconnecting.
    ///
    /// Carries the attempt counter at the time of giving up and the last transient failure.
50    #[error("couldn't reconnect to SSE server in {0} attempts: {1}")]
51    Timeout(u32, SseErrorEvent),
52    /// The server sent an event payload that exceeded the configured buffer limit.
53    #[error("server sent an oversized payload exceeding the allotted buffer")]
54    PayloadTooLarge(#[from] PayloadTooLargeError),
55    /// The current `Last-Event-ID` (received from the server or supplied by the caller)
56    /// contains bytes that cannot be converted into a valid HTTP header value.
57    #[error("Last-Event-ID cannot be converted to a valid HTTP header: {0}")]
58    InvalidLastEventId(reqwest::header::InvalidHeaderValue),
59}
60
61/// The connection state mapping to the JavaScript [`EventSource`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API [`readyState`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource/readyState).
///
/// The variants carry the explicit discriminants `0`, `1` and `2`.
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
63#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
64pub enum ReadyState {
65    /// The connection has not yet been established, or it was closed and the client is reconnecting.
    ///
    /// This also covers the time spent waiting out a backoff delay between attempts.
66    Connecting = 0,
67    /// The connection is open and ready to receive events.
68    Open = 1,
69    /// The connection is permanently closed and will not reconnect on its own.
70    Closed = 2,
71}
72
// Internal connection state machine, advanced by `EventSource`'s `Stream::poll_next`.
73enum State {
    // No request in flight; the next poll clones the request and starts connecting.
74    Disconnected,
    // A request has been sent and its response future is being polled.
75    Connecting(ConnectFuture),
    // The response was accepted and events are being read from the attached body stream.
76    Open,
    // Waiting out the backoff delay; once it elapses the state returns to `Disconnected`.
77    Sleeping(Pin<Box<Sleep>>),
    // Closed; polling yields `None` until a reconnect is forced.
78    Closed,
79}
80
81impl fmt::Debug for State {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            State::Disconnected => f.write_str("Disconnected"),
85            State::Connecting(_) => f.write_str("Connecting(_)"),
86            State::Open => f.write_str("Open"),
87            State::Sleeping(fut) => f.debug_tuple("Sleeping").field(fut).finish(),
88            State::Closed => f.write_str("Closed"),
89        }
90    }
91}
92
93/// Transient errors that cause the stream to drop and trigger an automatic reconnection.
///
/// These are surfaced as [`SseEvent::Error`] while retries remain, and as the cause inside
/// [`Error::Timeout`](enum@Error) once the retry budget is exhausted.
94#[derive(Debug, Error)]
95pub enum SseErrorEvent {
96    /// The server gracefully closed the TCP connection (EOF) while the stream was active.
97    #[error("server cleanly closed the connection (EOF)")]
98    Eof,
99
100    /// The server responded with an HTTP status code designated as retryable (e.g., 502, 503).
    ///
    /// Only produced when retrying transient HTTP errors has been enabled on the builder.
101    #[error("transient HTTP error: {0}")]
102    Http(StatusCode),
103
104    /// A network-level error occurred, such as a dropped socket, DNS failure, or read timeout.
    ///
    /// Raised both when sending the request fails and when reading the response body fails.
105    #[error("network or transport error: {0}")]
106    Network(#[from] reqwest::Error),
107}
108
109/// High-level events emitted by the [`EventSource`] stream.
110#[derive(Debug)]
111pub enum SseEvent {
112    /// Emitted when the underlying HTTP connection is successfully established.
    ///
    /// This is yielded once per (re)connection, after the response status and Content-Type
    /// have been validated and before any message from that connection.
113    Open,
114    /// A parsed message event from the server.
115    Message(MessageEvent),
116    /// Emitted when the connection drops but the client is actively attempting to reconnect.
117    ///
118    /// This gives the application a chance to log the interruption or update UI state
119    /// while the exponential backoff handles the recovery in the background.
120    Error(SseErrorEvent),
121}
122
123impl SseEvent {
124    /// Consumes the event and returns the underlying [`MessageEvent`] if this is a standard message.
125    ///
126    /// This is particularly useful in stream combinators:
127    /// ```no_run
128    /// # use futures_util::TryStreamExt;
129    /// # use sse_reqwest_client::*;
130    /// # tokio_test::block_on(async {
131    /// # let client = reqwest::Client::new();
132    /// # let mut stream = client.get("https://example.com/events").into_event_source();
133    /// let messages: Vec<_> = stream
134    ///     .try_filter_map(async |res| Ok(res.into_message()))
135    ///     .try_collect()
136    ///     .await?;
137    /// # Result::<()>::Ok(())
138    /// # });
139    /// ```
140    pub fn into_message(self) -> Option<MessageEvent> {
141        match self {
142            Self::Message(msg) => Some(msg),
143            Self::Open | Self::Error(_) => None,
144        }
145    }
146
147    /// Returns a reference to the underlying [`MessageEvent`] if this is a standard message.
148    pub fn as_message(&self) -> Option<&MessageEvent> {
149        match self {
150            Self::Message(msg) => Some(msg),
151            Self::Open | Self::Error(_) => None,
152        }
153    }
154
155    /// Returns a mutable reference to the underlying [`MessageEvent`] if this is a standard message.
156    pub fn as_message_mut(&mut self) -> Option<&mut MessageEvent> {
157        match self {
158            Self::Message(msg) => Some(msg),
159            Self::Open | Self::Error(_) => None,
160        }
161    }
162}
163
164impl From<MessageEvent> for SseEvent {
165    fn from(event: MessageEvent) -> Self {
166        Self::Message(event)
167    }
168}
169
170impl From<SseErrorEvent> for SseEvent {
171    fn from(err: SseErrorEvent) -> Self {
172        Self::Error(err)
173    }
174}
175
176/// Error indicating that an [`SseEvent`] could not be converted into a [`MessageEvent`].
177#[derive(Debug, Error)]
178#[error("couldn't convert Event::{} into a MessageEvent", match .0 {
179    SseEvent::Open => "Open",
180    SseEvent::Message(_) => "Message",
181    SseEvent::Error(_) => "Error"
182})]
183pub struct FromMessageEventError(pub SseEvent);
184
185impl TryFrom<SseEvent> for MessageEvent {
186    type Error = FromMessageEventError;
187
188    fn try_from(ev: SseEvent) -> Result<Self, Self::Error> {
189        match ev {
190            SseEvent::Message(msg) => Ok(msg),
191            ev => Err(FromMessageEventError(ev)),
192        }
193    }
194}
195
196/// A builder for configuring an [`EventSource`] connection.
197///
198/// # Example
199/// ```rust,no_run
200/// use std::{num::NonZeroUsize, time::Duration};
201/// use sse_reqwest_client::{RequestBuilderExt, EventSourceBuilder, SseRetryConfig};
202///
203/// # #[tokio::main]
204/// # async fn main() {
205/// let client = reqwest::Client::new();
206/// let req = client.get("https://api.example.com/stream");
207///
208/// // Create a stream with a strict 1 MiB payload limit and a custom retry delay
209/// let stream = req.into_event_source_builder()
210///     .retry_config(SseRetryConfig::new())
211///     .initial_reconnection_time(Duration::from_secs(5))
212///     .max_payload_size(NonZeroUsize::new(1024 * 1024).unwrap())
213///     .build();
214/// # }
215/// ```
216#[derive(Debug)]
217pub struct EventSourceBuilder {
    // Request template handed to the `EventSource` on `build`.
218    req: RequestBuilder,
    // Backoff policy used to compute the delay between reconnection attempts.
219    retry_config: SseRetryConfig,
    // Base reconnection delay in milliseconds (3000 by default).
220    reconnection_time_ms: u32,
    // Per-event payload limit; `None` keeps the decoder's default.
221    max_payload_size: Option<NonZeroUsize>,
    // `Last-Event-ID` to seed the decoder with before the first connection.
222    last_event_id: Option<Arc<str>>,
    // Whether 408/429/502/503/504 responses are retried instead of closing the stream.
223    retry_transient_errors: bool,
    // Minimum time a connection must stay open for the backoff counter to be reset.
224    successful_connection_threshold: Duration,
225}
226
227impl EventSourceBuilder {
228    /// Creates a new builder wrapping the given [`reqwest::RequestBuilder`].
229    #[must_use]
230    pub fn new(req: RequestBuilder) -> Self {
231        Self {
232            req,
233            reconnection_time_ms: 3000, // Default per SSE Spec
234            retry_config: SseRetryConfig::new(),
235            max_payload_size: None, // use default
236            last_event_id: None,
237            retry_transient_errors: false,
238            successful_connection_threshold: Duration::from_secs(5),
239        }
240    }
241
242    /// Applies a custom retry configuration for automatic reconnections.
243    #[inline]
244    #[must_use]
245    pub fn retry_config(mut self, retry_config: SseRetryConfig) -> Self {
246        self.retry_config = retry_config;
247        self
248    }
249
250    /// Sets the base delay to wait before attempting to reconnect.
251    ///
252    /// This delay may be overridden by the server using `retry` events.
253    #[inline]
254    #[must_use]
255    pub fn initial_reconnection_time(mut self, reconnection_time: Duration) -> Self {
256        self.reconnection_time_ms = reconnection_time
257            .as_millis()
258            .try_into()
259            .expect("Read duration too long");
260        self
261    }
262
263    /// Configures the maximum allowed byte size for a single event payload.
264    #[inline]
265    #[must_use]
266    pub fn max_payload_size(mut self, max_payload_size: NonZeroUsize) -> Self {
267        self.max_payload_size = Some(max_payload_size);
268        self
269    }
270
271    /// Sets the initial `Last-Event-ID` to send with the first connection request.
272    ///
273    /// This is useful for resuming a dropped stream from a previously saved state.
274    #[inline]
275    #[must_use]
276    pub fn last_event_id(mut self, id: impl Into<Arc<str>>) -> Self {
277        self.last_event_id = Some(id.into());
278        self
279    }
280
281    /// Enables automatic retries for transient HTTP status codes.
282    ///
283    /// By default, the [`EventSource`] strictly follows the WHATWG specification and will
284    /// permanently close the stream on any non-200 HTTP response.
285    ///
286    /// Setting this to `true` allows the client to automatically back off and retry when
287    /// encountering temporary proxy or server issues. The following status codes are
288    /// considered transient:
289    /// * `408 Request Timeout`
290    /// * `429 Too Many Requests`
291    /// * `502 Bad Gateway`
292    /// * `503 Service Unavailable`
293    /// * `504 Gateway Timeout`
294    #[inline]
295    #[must_use]
296    pub fn retry_transient_errors(mut self, retry: bool) -> Self {
297        self.retry_transient_errors = retry;
298        self
299    }
300
301    /// Sets the minimum duration a connection must remain open to be considered "successful"
302    /// and reset the exponential backoff counter. Defaults to 5 seconds.
303    #[inline]
304    #[must_use]
305    pub fn successful_connection_threshold(mut self, threshold: Duration) -> Self {
306        self.successful_connection_threshold = threshold;
307        self
308    }
309
310    /// Consumes the builder and returns the configured [`EventSource`].
311    #[must_use]
312    pub fn build(self) -> EventSource {
313        let mut decoder = match self.max_payload_size {
314            Some(max_payload_size) => SseDecoder::with_limit(max_payload_size),
315            None => SseDecoder::new(),
316        };
317        decoder.reconnect_with_id(self.last_event_id);
318
319        EventSource {
320            req: (self.req)
321                .header(reqwest::header::ACCEPT, "text/event-stream")
322                .header(reqwest::header::CACHE_CONTROL, "no-store"),
323            reconnection_time_ms: self.reconnection_time_ms,
324            connection_attempt: 0,
325            connected_since: None,
326            retry_config: self.retry_config,
327            retry_transient_errors: self.retry_transient_errors,
328            successful_connection_threshold: self.successful_connection_threshold,
329            stream: SseStream::with_decoder(decoder),
330            state: State::Disconnected,
331        }
332    }
333}
334
335/// A reconnecting stream of Server-Sent Events.
///
/// The source is driven entirely by polling it as a stream: the request is sent on the
/// first poll, and reconnections (with backoff) happen on later polls.
336pub struct EventSource {
    // Request template; cloned for every connection attempt.
337    req: RequestBuilder,
    // Base reconnection delay in milliseconds; updated by server `retry` events.
338    reconnection_time_ms: u32,
    // Number of backoff rounds since the last reset; input to the backoff calculation.
339    connection_attempt: u32,
    // When the current connection was opened; taken when a retry is scheduled.
340    connected_since: Option<Instant>,
    // Backoff policy used to compute the delay between reconnection attempts.
341    retry_config: SseRetryConfig,
    // Whether transient HTTP statuses are retried instead of closing the stream.
342    retry_transient_errors: bool,
    // Minimum open time after which a dropped connection resets the backoff counter.
343    successful_connection_threshold: Duration,
    // Event decoder wrapped around the currently attached response body, if any.
344    stream: SseStream<ByteStream>,
    // Current position in the connection state machine.
345    state: State,
346}
347
348impl fmt::Debug for EventSource {
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        f.debug_struct("EventSource")
351            .field("req", &self.req)
352            .field("reconnection_time_ms", &self.reconnection_time_ms)
353            .field("connection_attempt", &self.connection_attempt)
354            .field("retry_config", &self.retry_config)
355            .field("retry_transient_errors", &self.retry_transient_errors)
356            .field("state", &self.state)
357            .field(
358                "stream.last_event_id()",
359                &self.stream.last_event_id().map(|id| &**id),
360            )
361            .field("stream.is_closed()", &self.stream.is_closed())
362            .finish_non_exhaustive()
363    }
364}
365
366impl EventSource {
367    /// Creates a new [`EventSource`] from the given request with default configurations.
368    #[must_use]
369    pub fn new(req: RequestBuilder) -> Self {
370        Self::builder(req).build()
371    }
372
373    /// Creates a builder to customize the [`EventSource`] before connecting.
374    #[must_use]
375    pub fn builder(req: RequestBuilder) -> EventSourceBuilder {
376        EventSourceBuilder::new(req)
377    }
378
379    /// Closes the underlying SSE connection.
380    ///
381    /// This method terminates the active HTTP request, effectively dropping the
382    /// inner stream. Calling [`close`](Self::close) is idempotent; if the connection is already
383    /// closed, this does nothing and is perfectly safe to call multiple times.
384    ///
385    /// While this halts all incoming events and stops the automatic reconnection
386    /// loop, it does not consume the [`EventSource`]. You can manually restart the
387    /// stream and initiate a fresh connection at any time by calling
388    /// [`force_reconnect()`](Self::force_reconnect).
389    ///
390    /// # Example
391    /// ```rust,no_run
392    /// # use futures_util::StreamExt;
393    /// # use sse_reqwest_client::{RequestBuilderExt, ReadyState};
394    /// # #[tokio::main]
395    /// # async fn main() {
396    /// let client = reqwest::Client::new();
397    /// let mut stream = client.get("https://api.example.com/stream").into_event_source();
398    ///
399    /// // ... later, to gracefully stop listening to events (e.g., app goes to background):
400    /// stream.close();
401    /// assert_eq!(stream.ready_state(), ReadyState::Closed);
402    ///
403    /// // The stream will now yield None
404    /// assert!(stream.next().await.is_none());
405    /// # }
406    /// ```
407    pub fn close(&mut self) {
408        self.stream.close();
409        self.state = State::Closed;
410    }
411
412    /// Returns the current connection state.
413    #[inline]
414    #[must_use]
415    pub fn ready_state(&self) -> ReadyState {
416        match &self.state {
417            State::Disconnected | State::Connecting(_) | State::Sleeping(_) => {
418                ReadyState::Connecting
419            }
420            State::Open => ReadyState::Open,
421            State::Closed => ReadyState::Closed,
422        }
423    }
424
425    /// Returns the most recently received `Last-Event-ID`, if any.
426    #[inline]
427    #[must_use]
428    pub fn last_event_id(&self) -> Option<&Arc<str>> {
429        self.stream.last_event_id()
430    }
431
432    /// Terminates the current connection and immediately attempts to reconnect.
433    ///
434    /// Because [`EventSource`] automatically handles network drops and reconnections, you typically
435    /// do not need to call this manually. However, it is useful in specific scenarios, such as:
436    ///
437    /// * **Bypassing Backoff:** The server state has heavily desynced, and you want to bypass the
438    ///   current exponential backoff timer to reconnect instantly.
439    /// * **Manual Revival:** You previously called [`close()`](Self::close) to pause the stream,
440    ///   and now want to resume listening for events.
441    ///
442    /// This method resets the connection attempt counter, meaning the next connection attempt will
443    /// happen immediately without any retry delay. And will continue with the exponential backoff
444    /// reset.
445    ///
446    /// # Example
447    /// ```rust,no_run
448    /// # use sse_reqwest_client::{RequestBuilderExt, ReadyState};
449    /// # #[tokio::main]
450    /// # async fn main() {
451    /// let client = reqwest::Client::new();
452    /// let mut stream = client.get("https://api.example.com/stream").into_event_source();
453    ///
454    /// // If your application detects via the OS that network connectivity was restored,
455    /// // you can manually trigger an immediate reconnect to bypass active backoff delays.
456    /// stream.force_reconnect();
457    /// assert_eq!(stream.ready_state(), ReadyState::Connecting);
458    /// # }
459    /// ```
460    #[inline]
461    pub fn force_reconnect(&mut self) {
462        self.stream.close();
463        self.connection_attempt = 0;
464        self.state = State::Disconnected;
465    }
466
467    /// Terminates the current connection and immediately attempts to reconnect,
468    /// explicitly overriding the `Last-Event-ID` sent to the server.
469    ///
470    /// This is useful if your application state has desynced and you need to
471    /// force the server to rewind or fast-forward to a specific point in the stream.
472    ///
473    /// See [force_reconnect()](Self::force_reconnect) for more info.
474    #[inline]
475    pub fn force_reconnect_with_id(&mut self, id: Option<Arc<str>>) {
476        self.stream.close_with_id(id);
477        self.connection_attempt = 0;
478        self.state = State::Disconnected;
479    }
480
481    fn go_to_sleep(&mut self, cause: SseErrorEvent) -> Result<SseEvent> {
482        if let Some(connected_since) = self.connected_since.take() {
483            if self.successful_connection_threshold <= connected_since.elapsed() {
484                self.connection_attempt = 0;
485            }
486        }
487
488        let wait_dur = (self.retry_config)
489            .calculate_backoff(self.reconnection_time_ms, self.connection_attempt);
490        self.connection_attempt += 1;
491        if let Some(dur) = wait_dur {
492            self.state = State::Sleeping(Box::pin(sleep(dur)));
493            Ok(SseEvent::Error(cause))
494        } else {
495            self.close();
496            Err(Error::Timeout(self.connection_attempt, cause))
497        }
498    }
499}
500
501impl Stream for EventSource {
502    type Item = Result<SseEvent>;
503
504    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
505        let slf = &mut *self;
506
507        loop {
508            match &mut slf.state {
509                State::Disconnected => {
510                    let Some(mut req) = slf.req.try_clone() else {
511                        slf.close();
512                        return Poll::Ready(Some(Err(Error::UncloneableRequest)));
513                    };
514
515                    // TODO: Maybe we should error if the provided RequestBuilder already had a
516                    //       Last-Event-ID header.
517                    if let Some(last_event_id) = slf.stream.last_event_id() {
518                        match HeaderValue::from_str(last_event_id) {
519                            Ok(val) => req = req.header("Last-Event-ID", val),
520                            Err(err) => {
521                                slf.close();
522                                return Poll::Ready(Some(Err(Error::InvalidLastEventId(err))));
523                            }
524                        }
525                    }
526
527                    let fut = Box::pin(req.send());
528                    slf.state = State::Connecting(fut);
529                }
530
531                State::Connecting(fut) => match ready!(fut.as_mut().poll(cx)) {
532                    Ok(res) => {
533                        let status = res.status();
534
535                        if matches!(status, StatusCode::NO_CONTENT) {
536                            slf.close();
537                            return Poll::Ready(None);
538                        }
539
540                        let is_transient_error = matches!(
541                            status,
542                            StatusCode::REQUEST_TIMEOUT
543                                | StatusCode::TOO_MANY_REQUESTS
544                                | StatusCode::BAD_GATEWAY
545                                | StatusCode::SERVICE_UNAVAILABLE
546                                | StatusCode::GATEWAY_TIMEOUT
547                        );
548
549                        if slf.retry_transient_errors && is_transient_error {
550                            return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Http(status))));
551                        } else if status != StatusCode::OK {
552                            slf.close();
553                            return Poll::Ready(Some(Err(Error::Status(status))));
554                        }
555
556                        let Some(content_type) = res
557                            .headers()
558                            .get(reqwest::header::CONTENT_TYPE)
559                            .map(|v| v.as_bytes())
560                        else {
561                            slf.close();
562                            return Poll::Ready(Some(Err(Error::MissingContentType)));
563                        };
564
565                        const MIME_EVENT_STREAM: &str = "text/event-stream";
566                        if !(content_type.starts_with(MIME_EVENT_STREAM.as_bytes())
567                            && matches!(
568                                content_type.get(MIME_EVENT_STREAM.len()),
569                                None | Some(b';' | b' ' | b'\t')
570                            ))
571                        {
572                            slf.close();
573                            return Poll::Ready(Some(Err(Error::InvalidContentType)));
574                        }
575
576                        slf.state = State::Open;
577                        slf.connected_since = Some(Instant::now());
578                        slf.stream.attach(Box::pin(res.bytes_stream()));
579
580                        return Poll::Ready(Some(Ok(SseEvent::Open)));
581                    }
582                    Err(err) => {
583                        slf.close();
584                        return Poll::Ready(Some(slf.go_to_sleep(err.into())));
585                    }
586                },
587
588                State::Open => match ready!(Pin::new(&mut slf.stream).poll_next(cx)) {
589                    Some(Ok(raw_event)) => match raw_event {
590                        SseEventCore::Retry(ms) => slf.reconnection_time_ms = ms,
591                        SseEventCore::Message(event) => return Poll::Ready(Some(Ok(event.into()))),
592                    },
593                    Some(Err(SseStreamError::PayloadTooLarge(err))) => {
594                        slf.close();
595                        return Poll::Ready(Some(Err(Error::PayloadTooLarge(err))));
596                    }
597                    Some(Err(SseStreamError::Inner(err))) => {
598                        return Poll::Ready(Some(slf.go_to_sleep(err.into())));
599                    }
600                    None => return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Eof))),
601                },
602
603                State::Sleeping(sleep_fut) => {
604                    ready!(sleep_fut.as_mut().poll(cx));
605                    slf.state = State::Disconnected;
606                }
607
608                State::Closed => return Poll::Ready(None),
609            }
610        }
611    }
612}
613
// Private module holding the marker supertrait of `RequestBuilderExt`. The trait is `pub`
// but the module is not, so code outside this crate cannot name (or implement) it.
614mod sealed {
615    pub trait Sealed {}
616}
617
618/// An extension trait for [`reqwest::RequestBuilder`] to ergonomically create SSE streams.
///
/// This trait is sealed: it has a private supertrait and cannot be implemented outside
/// of this crate.
619pub trait RequestBuilderExt: sealed::Sealed {
620    /// Converts this request builder into an [`EventSource`] with default settings.
    ///
    /// The request is not sent until the returned stream is polled.
621    fn into_event_source(self) -> EventSource;
622    /// Converts this request builder into an [`EventSourceBuilder`] for further configuration.
623    fn into_event_source_builder(self) -> EventSourceBuilder;
624}
625
626impl sealed::Sealed for RequestBuilder {}
627impl RequestBuilderExt for RequestBuilder {
628    fn into_event_source(self) -> EventSource {
629        EventSource::new(self)
630    }
631    fn into_event_source_builder(self) -> EventSourceBuilder {
632        EventSourceBuilder::new(self)
633    }
634}