sse_reqwest_client/lib.rs
1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![warn(rustdoc::missing_crate_level_docs)]
4
5use std::{
6 fmt,
7 future::Future,
8 num::NonZeroUsize,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll, ready},
12 time::Duration,
13};
14
15use bytes::Bytes;
16use futures_core::stream::Stream;
17use reqwest::{RequestBuilder, StatusCode, header::HeaderValue};
18use thiserror::Error;
19use tokio::time::{Instant, Sleep, sleep};
20
21pub use sse_core::SseRetryConfig;
22use sse_core::{
23 MessageEvent, PayloadTooLargeError, SseDecoder, SseEvent as SseEventCore, SseStream,
24 SseStreamError,
25};
26
27type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync>>;
28type ConnectFuture =
29 Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Send + Sync>>;
30
31/// An alias for [`Result`] with the error defaulting to [`Error`](enum@Error).
32pub type Result<T, E = Error> = std::result::Result<T, E>;
33
34/// Errors that can occur during the lifecycle of an [`EventSource`] connection.
35#[derive(Debug, Error)]
36pub enum Error {
37 /// The server responded with a non-200 HTTP status code.
38 #[error("unexpected HTTP status code: {0}")]
39 Status(StatusCode),
40 /// The [`RequestBuilder`] could not be cloned (e.g., it contains a streaming body).
41 #[error("request builder could not be cloned (e.g., non-restartable body stream)")]
42 UncloneableRequest,
43 /// The server's response lacked the `text/event-stream` Content-Type.
44 #[error("invalid response HTTP Content-Type")]
45 InvalidContentType,
46 /// The server's response did not contain a Content-Type header.
47 #[error("response HTTP Content-Type missing")]
48 MissingContentType,
49 /// The client exhausted all retry attempts without successfully reconnecting.
50 #[error("couldn't reconnect to SSE server in {0} attempts: {1}")]
51 Timeout(u32, SseErrorEvent),
52 /// The server sent an event payload that exceeded the configured buffer limit.
53 #[error("server sent an oversized payload exceeding the allotted buffer")]
54 PayloadTooLarge(#[from] PayloadTooLargeError),
55 /// The `Last-Event-ID` provided by the server contains bytes that cannot be
56 /// safely converted into a valid HTTP header.
57 #[error("Last-Event-ID cannot be converted to a valid HTTP header: {0}")]
58 InvalidLastEventId(reqwest::header::InvalidHeaderValue),
59}
60
/// Connection state, mirroring the JavaScript
/// [`EventSource.readyState`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource/readyState)
/// property of the [`EventSource`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API.
///
/// The discriminants match the numeric values used by the browser API.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ReadyState {
    /// Not connected yet, or waiting to reconnect after a dropped connection.
    Connecting = 0,
    /// Connected and receiving events.
    Open = 1,
    /// Stopped for good; no reconnection will be attempted.
    Closed = 2,
}
72
73enum State {
74 Disconnected,
75 Connecting(ConnectFuture),
76 Open,
77 Sleeping(Pin<Box<Sleep>>),
78 Closed,
79}
80
81impl fmt::Debug for State {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 State::Disconnected => f.write_str("Disconnected"),
85 State::Connecting(_) => f.write_str("Connecting(_)"),
86 State::Open => f.write_str("Open"),
87 State::Sleeping(fut) => f.debug_tuple("Sleeping").field(fut).finish(),
88 State::Closed => f.write_str("Closed"),
89 }
90 }
91}
92
93/// Transient errors that cause the stream to drop and trigger an automatic reconnection.
94#[derive(Debug, Error)]
95pub enum SseErrorEvent {
96 /// The server gracefully closed the TCP connection (EOF) while the stream was active.
97 #[error("server cleanly closed the connection (EOF)")]
98 Eof,
99
100 /// The server responded with an HTTP status code designated as retryable (e.g., 502, 503).
101 #[error("transient HTTP error: {0}")]
102 Http(StatusCode),
103
104 /// A network-level error occurred, such as a dropped socket, DNS failure, or read timeout.
105 #[error("network or transport error: {0}")]
106 Network(#[from] reqwest::Error),
107}
108
109/// High-level events emitted by the [`EventSource`] stream.
110#[derive(Debug)]
111pub enum SseEvent {
112 /// Emitted when the underlying HTTP connection is successfully established.
113 Open,
114 /// A parsed message event from the server.
115 Message(MessageEvent),
116 /// Emitted when the connection drops but the client is actively attempting to reconnect.
117 ///
118 /// This gives the application a chance to log the interruption or update UI state
119 /// while the exponential backoff handles the recovery in the background.
120 Error(SseErrorEvent),
121}
122
123impl SseEvent {
124 /// Consumes the event and returns the underlying [`MessageEvent`] if this is a standard message.
125 ///
126 /// This is particularly useful in stream combinators:
127 /// ```no_run
128 /// # use futures_util::TryStreamExt;
129 /// # use sse_reqwest_client::*;
130 /// # tokio_test::block_on(async {
131 /// # let client = reqwest::Client::new();
132 /// # let mut stream = client.get("https://example.com/events").into_event_source();
133 /// let messages: Vec<_> = stream
134 /// .try_filter_map(async |res| Ok(res.into_message()))
135 /// .try_collect()
136 /// .await?;
137 /// # Result::<()>::Ok(())
138 /// # });
139 /// ```
140 pub fn into_message(self) -> Option<MessageEvent> {
141 match self {
142 Self::Message(msg) => Some(msg),
143 Self::Open | Self::Error(_) => None,
144 }
145 }
146
147 /// Returns a reference to the underlying [`MessageEvent`] if this is a standard message.
148 pub fn as_message(&self) -> Option<&MessageEvent> {
149 match self {
150 Self::Message(msg) => Some(msg),
151 Self::Open | Self::Error(_) => None,
152 }
153 }
154
155 /// Returns a mutable reference to the underlying [`MessageEvent`] if this is a standard message.
156 pub fn as_message_mut(&mut self) -> Option<&mut MessageEvent> {
157 match self {
158 Self::Message(msg) => Some(msg),
159 Self::Open | Self::Error(_) => None,
160 }
161 }
162}
163
164impl From<MessageEvent> for SseEvent {
165 fn from(event: MessageEvent) -> Self {
166 Self::Message(event)
167 }
168}
169
170impl From<SseErrorEvent> for SseEvent {
171 fn from(err: SseErrorEvent) -> Self {
172 Self::Error(err)
173 }
174}
175
176/// Error indicating that an [`SseEvent`] could not be converted into a [`MessageEvent`].
177#[derive(Debug, Error)]
178#[error("couldn't convert Event::{} into a MessageEvent", match .0 {
179 SseEvent::Open => "Open",
180 SseEvent::Message(_) => "Message",
181 SseEvent::Error(_) => "Error"
182})]
183pub struct FromMessageEventError(pub SseEvent);
184
185impl TryFrom<SseEvent> for MessageEvent {
186 type Error = FromMessageEventError;
187
188 fn try_from(ev: SseEvent) -> Result<Self, Self::Error> {
189 match ev {
190 SseEvent::Message(msg) => Ok(msg),
191 ev => Err(FromMessageEventError(ev)),
192 }
193 }
194}
195
196/// A builder for configuring an [`EventSource`] connection.
197///
198/// # Example
199/// ```rust,no_run
200/// use std::{num::NonZeroUsize, time::Duration};
201/// use sse_reqwest_client::{RequestBuilderExt, EventSourceBuilder, SseRetryConfig};
202///
203/// # #[tokio::main]
204/// # async fn main() {
205/// let client = reqwest::Client::new();
206/// let req = client.get("https://api.example.com/stream");
207///
208/// // Create a stream with a strict 1MB payload limit and a custom retry delay
209/// let stream = req.into_event_source_builder()
210/// .retry_config(SseRetryConfig::new())
211/// .initial_reconnection_time(Duration::from_secs(5))
212/// .max_payload_size(NonZeroUsize::new(1024 * 1024).unwrap())
213/// .build();
214/// # }
215/// ```
216#[derive(Debug)]
217pub struct EventSourceBuilder {
218 req: RequestBuilder,
219 retry_config: SseRetryConfig,
220 reconnection_time_ms: u32,
221 max_payload_size: Option<NonZeroUsize>,
222 last_event_id: Option<Arc<str>>,
223 retry_transient_errors: bool,
224 successful_connection_threshold: Duration,
225}
226
227impl EventSourceBuilder {
228 /// Creates a new builder wrapping the given [`reqwest::RequestBuilder`].
229 #[must_use]
230 pub fn new(req: RequestBuilder) -> Self {
231 Self {
232 req,
233 reconnection_time_ms: 3000, // Default per SSE Spec
234 retry_config: SseRetryConfig::new(),
235 max_payload_size: None, // use default
236 last_event_id: None,
237 retry_transient_errors: false,
238 successful_connection_threshold: Duration::from_secs(5),
239 }
240 }
241
242 /// Applies a custom retry configuration for automatic reconnections.
243 #[inline]
244 #[must_use]
245 pub fn retry_config(mut self, retry_config: SseRetryConfig) -> Self {
246 self.retry_config = retry_config;
247 self
248 }
249
250 /// Sets the base delay to wait before attempting to reconnect.
251 ///
252 /// This delay may be overridden by the server using `retry` events.
253 #[inline]
254 #[must_use]
255 pub fn initial_reconnection_time(mut self, reconnection_time: Duration) -> Self {
256 self.reconnection_time_ms = reconnection_time
257 .as_millis()
258 .try_into()
259 .expect("Read duration too long");
260 self
261 }
262
263 /// Configures the maximum allowed byte size for a single event payload.
264 #[inline]
265 #[must_use]
266 pub fn max_payload_size(mut self, max_payload_size: NonZeroUsize) -> Self {
267 self.max_payload_size = Some(max_payload_size);
268 self
269 }
270
271 /// Sets the initial `Last-Event-ID` to send with the first connection request.
272 ///
273 /// This is useful for resuming a dropped stream from a previously saved state.
274 #[inline]
275 #[must_use]
276 pub fn last_event_id(mut self, id: impl Into<Arc<str>>) -> Self {
277 self.last_event_id = Some(id.into());
278 self
279 }
280
281 /// Enables automatic retries for transient HTTP status codes.
282 ///
283 /// By default, the [`EventSource`] strictly follows the WHATWG specification and will
284 /// permanently close the stream on any non-200 HTTP response.
285 ///
286 /// Setting this to `true` allows the client to automatically back off and retry when
287 /// encountering temporary proxy or server issues. The following status codes are
288 /// considered transient:
289 /// * `408 Request Timeout`
290 /// * `429 Too Many Requests`
291 /// * `502 Bad Gateway`
292 /// * `503 Service Unavailable`
293 /// * `504 Gateway Timeout`
294 #[inline]
295 #[must_use]
296 pub fn retry_transient_errors(mut self, retry: bool) -> Self {
297 self.retry_transient_errors = retry;
298 self
299 }
300
301 /// Sets the minimum duration a connection must remain open to be considered "successful"
302 /// and reset the exponential backoff counter. Defaults to 5 seconds.
303 #[inline]
304 #[must_use]
305 pub fn successful_connection_threshold(mut self, threshold: Duration) -> Self {
306 self.successful_connection_threshold = threshold;
307 self
308 }
309
310 /// Consumes the builder and returns the configured [`EventSource`].
311 #[must_use]
312 pub fn build(self) -> EventSource {
313 let mut decoder = match self.max_payload_size {
314 Some(max_payload_size) => SseDecoder::with_limit(max_payload_size),
315 None => SseDecoder::new(),
316 };
317 decoder.reconnect_with_id(self.last_event_id);
318
319 EventSource {
320 req: (self.req)
321 .header(reqwest::header::ACCEPT, "text/event-stream")
322 .header(reqwest::header::CACHE_CONTROL, "no-store"),
323 reconnection_time_ms: self.reconnection_time_ms,
324 connection_attempt: 0,
325 connected_since: None,
326 retry_config: self.retry_config,
327 retry_transient_errors: self.retry_transient_errors,
328 successful_connection_threshold: self.successful_connection_threshold,
329 stream: SseStream::with_decoder(decoder),
330 state: State::Disconnected,
331 }
332 }
333}
334
335/// A reconnecting stream of Server-Sent Events.
336pub struct EventSource {
337 req: RequestBuilder,
338 reconnection_time_ms: u32,
339 connection_attempt: u32,
340 connected_since: Option<Instant>,
341 retry_config: SseRetryConfig,
342 retry_transient_errors: bool,
343 successful_connection_threshold: Duration,
344 stream: SseStream<ByteStream>,
345 state: State,
346}
347
348impl fmt::Debug for EventSource {
349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 f.debug_struct("EventSource")
351 .field("req", &self.req)
352 .field("reconnection_time_ms", &self.reconnection_time_ms)
353 .field("connection_attempt", &self.connection_attempt)
354 .field("retry_config", &self.retry_config)
355 .field("retry_transient_errors", &self.retry_transient_errors)
356 .field("state", &self.state)
357 .field(
358 "stream.last_event_id()",
359 &self.stream.last_event_id().map(|id| &**id),
360 )
361 .field("stream.is_closed()", &self.stream.is_closed())
362 .finish_non_exhaustive()
363 }
364}
365
366impl EventSource {
367 /// Creates a new [`EventSource`] from the given request with default configurations.
368 #[must_use]
369 pub fn new(req: RequestBuilder) -> Self {
370 Self::builder(req).build()
371 }
372
373 /// Creates a builder to customize the [`EventSource`] before connecting.
374 #[must_use]
375 pub fn builder(req: RequestBuilder) -> EventSourceBuilder {
376 EventSourceBuilder::new(req)
377 }
378
379 /// Closes the underlying SSE connection.
380 ///
381 /// This method terminates the active HTTP request, effectively dropping the
382 /// inner stream. Calling [`close`](Self::close) is idempotent; if the connection is already
383 /// closed, this does nothing and is perfectly safe to call multiple times.
384 ///
385 /// While this halts all incoming events and stops the automatic reconnection
386 /// loop, it does not consume the [`EventSource`]. You can manually restart the
387 /// stream and initiate a fresh connection at any time by calling
388 /// [`force_reconnect()`](Self::force_reconnect).
389 ///
390 /// # Example
391 /// ```rust,no_run
392 /// # use futures_util::StreamExt;
393 /// # use sse_reqwest_client::{RequestBuilderExt, ReadyState};
394 /// # #[tokio::main]
395 /// # async fn main() {
396 /// let client = reqwest::Client::new();
397 /// let mut stream = client.get("https://api.example.com/stream").into_event_source();
398 ///
399 /// // ... later, to gracefully stop listening to events (e.g., app goes to background):
400 /// stream.close();
401 /// assert_eq!(stream.ready_state(), ReadyState::Closed);
402 ///
403 /// // The stream will now yield None
404 /// assert!(stream.next().await.is_none());
405 /// # }
406 /// ```
407 pub fn close(&mut self) {
408 self.stream.close();
409 self.state = State::Closed;
410 }
411
412 /// Returns the current connection state.
413 #[inline]
414 #[must_use]
415 pub fn ready_state(&self) -> ReadyState {
416 match &self.state {
417 State::Disconnected | State::Connecting(_) | State::Sleeping(_) => {
418 ReadyState::Connecting
419 }
420 State::Open => ReadyState::Open,
421 State::Closed => ReadyState::Closed,
422 }
423 }
424
425 /// Returns the most recently received `Last-Event-ID`, if any.
426 #[inline]
427 #[must_use]
428 pub fn last_event_id(&self) -> Option<&Arc<str>> {
429 self.stream.last_event_id()
430 }
431
432 /// Terminates the current connection and immediately attempts to reconnect.
433 ///
434 /// Because [`EventSource`] automatically handles network drops and reconnections, you typically
435 /// do not need to call this manually. However, it is useful in specific scenarios, such as:
436 ///
437 /// * **Bypassing Backoff:** The server state has heavily desynced, and you want to bypass the
438 /// current exponential backoff timer to reconnect instantly.
439 /// * **Manual Revival:** You previously called [`close()`](Self::close) to pause the stream,
440 /// and now want to resume listening for events.
441 ///
442 /// This method resets the connection attempt counter, meaning the next connection attempt will
443 /// happen immediately without any retry delay. And will continue with the exponential backoff
444 /// reset.
445 ///
446 /// # Example
447 /// ```rust,no_run
448 /// # use sse_reqwest_client::{RequestBuilderExt, ReadyState};
449 /// # #[tokio::main]
450 /// # async fn main() {
451 /// let client = reqwest::Client::new();
452 /// let mut stream = client.get("https://api.example.com/stream").into_event_source();
453 ///
454 /// // If your application detects via the OS that network connectivity was restored,
455 /// // you can manually trigger an immediate reconnect to bypass active backoff delays.
456 /// stream.force_reconnect();
457 /// assert_eq!(stream.ready_state(), ReadyState::Connecting);
458 /// # }
459 /// ```
460 #[inline]
461 pub fn force_reconnect(&mut self) {
462 self.stream.close();
463 self.connection_attempt = 0;
464 self.state = State::Disconnected;
465 }
466
467 /// Terminates the current connection and immediately attempts to reconnect,
468 /// explicitly overriding the `Last-Event-ID` sent to the server.
469 ///
470 /// This is useful if your application state has desynced and you need to
471 /// force the server to rewind or fast-forward to a specific point in the stream.
472 ///
473 /// See [force_reconnect()](Self::force_reconnect) for more info.
474 #[inline]
475 pub fn force_reconnect_with_id(&mut self, id: Option<Arc<str>>) {
476 self.stream.close_with_id(id);
477 self.connection_attempt = 0;
478 self.state = State::Disconnected;
479 }
480
481 fn go_to_sleep(&mut self, cause: SseErrorEvent) -> Result<SseEvent> {
482 if let Some(connected_since) = self.connected_since.take() {
483 if self.successful_connection_threshold <= connected_since.elapsed() {
484 self.connection_attempt = 0;
485 }
486 }
487
488 let wait_dur = (self.retry_config)
489 .calculate_backoff(self.reconnection_time_ms, self.connection_attempt);
490 self.connection_attempt += 1;
491 if let Some(dur) = wait_dur {
492 self.state = State::Sleeping(Box::pin(sleep(dur)));
493 Ok(SseEvent::Error(cause))
494 } else {
495 self.close();
496 Err(Error::Timeout(self.connection_attempt, cause))
497 }
498 }
499}
500
501impl Stream for EventSource {
502 type Item = Result<SseEvent>;
503
504 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
505 let slf = &mut *self;
506
507 loop {
508 match &mut slf.state {
509 State::Disconnected => {
510 let Some(mut req) = slf.req.try_clone() else {
511 slf.close();
512 return Poll::Ready(Some(Err(Error::UncloneableRequest)));
513 };
514
515 // TODO: Maybe we should error if the provided RequestBuilder already had a
516 // Last-Event-ID header.
517 if let Some(last_event_id) = slf.stream.last_event_id() {
518 match HeaderValue::from_str(last_event_id) {
519 Ok(val) => req = req.header("Last-Event-ID", val),
520 Err(err) => {
521 slf.close();
522 return Poll::Ready(Some(Err(Error::InvalidLastEventId(err))));
523 }
524 }
525 }
526
527 let fut = Box::pin(req.send());
528 slf.state = State::Connecting(fut);
529 }
530
531 State::Connecting(fut) => match ready!(fut.as_mut().poll(cx)) {
532 Ok(res) => {
533 let status = res.status();
534
535 if matches!(status, StatusCode::NO_CONTENT) {
536 slf.close();
537 return Poll::Ready(None);
538 }
539
540 let is_transient_error = matches!(
541 status,
542 StatusCode::REQUEST_TIMEOUT
543 | StatusCode::TOO_MANY_REQUESTS
544 | StatusCode::BAD_GATEWAY
545 | StatusCode::SERVICE_UNAVAILABLE
546 | StatusCode::GATEWAY_TIMEOUT
547 );
548
549 if slf.retry_transient_errors && is_transient_error {
550 return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Http(status))));
551 } else if status != StatusCode::OK {
552 slf.close();
553 return Poll::Ready(Some(Err(Error::Status(status))));
554 }
555
556 let Some(content_type) = res
557 .headers()
558 .get(reqwest::header::CONTENT_TYPE)
559 .map(|v| v.as_bytes())
560 else {
561 slf.close();
562 return Poll::Ready(Some(Err(Error::MissingContentType)));
563 };
564
565 const MIME_EVENT_STREAM: &str = "text/event-stream";
566 if !(content_type.starts_with(MIME_EVENT_STREAM.as_bytes())
567 && matches!(
568 content_type.get(MIME_EVENT_STREAM.len()),
569 None | Some(b';' | b' ' | b'\t')
570 ))
571 {
572 slf.close();
573 return Poll::Ready(Some(Err(Error::InvalidContentType)));
574 }
575
576 slf.state = State::Open;
577 slf.connected_since = Some(Instant::now());
578 slf.stream.attach(Box::pin(res.bytes_stream()));
579
580 return Poll::Ready(Some(Ok(SseEvent::Open)));
581 }
582 Err(err) => {
583 slf.close();
584 return Poll::Ready(Some(slf.go_to_sleep(err.into())));
585 }
586 },
587
588 State::Open => match ready!(Pin::new(&mut slf.stream).poll_next(cx)) {
589 Some(Ok(raw_event)) => match raw_event {
590 SseEventCore::Retry(ms) => slf.reconnection_time_ms = ms,
591 SseEventCore::Message(event) => return Poll::Ready(Some(Ok(event.into()))),
592 },
593 Some(Err(SseStreamError::PayloadTooLarge(err))) => {
594 slf.close();
595 return Poll::Ready(Some(Err(Error::PayloadTooLarge(err))));
596 }
597 Some(Err(SseStreamError::Inner(err))) => {
598 return Poll::Ready(Some(slf.go_to_sleep(err.into())));
599 }
600 None => return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Eof))),
601 },
602
603 State::Sleeping(sleep_fut) => {
604 ready!(sleep_fut.as_mut().poll(cx));
605 slf.state = State::Disconnected;
606 }
607
608 State::Closed => return Poll::Ready(None),
609 }
610 }
611 }
612}
613
/// Private module whose trait cannot be named outside this crate, sealing `RequestBuilderExt`.
mod sealed {
    /// Supertrait only implementable inside this crate.
    pub trait Sealed {}
}
617
618/// An extension trait for [`reqwest::RequestBuilder`] to ergonomically create SSE streams.
619pub trait RequestBuilderExt: sealed::Sealed {
620 /// Converts this request builder into an active [`EventSource`] with default settings.
621 fn into_event_source(self) -> EventSource;
622 /// Converts this request builder into an [`EventSourceBuilder`] for further configuration.
623 fn into_event_source_builder(self) -> EventSourceBuilder;
624}
625
626impl sealed::Sealed for RequestBuilder {}
627impl RequestBuilderExt for RequestBuilder {
628 fn into_event_source(self) -> EventSource {
629 EventSource::new(self)
630 }
631 fn into_event_source_builder(self) -> EventSourceBuilder {
632 EventSourceBuilder::new(self)
633 }
634}