Skip to main content

wreq_proto/
upgrade.rs

1//! HTTP Upgrades
2//!
3//! This module deals with managing [HTTP Upgrades][mdn] in this crate. Since
4//! several concepts in HTTP allow for first talking HTTP, and then converting
5//! to a different protocol, this module conflates them into a single API.
6//! Those include:
7//!
8//! - HTTP/1.1 Upgrades
9//! - HTTP `CONNECT`
10//!
11//! You are responsible for any other pre-requisites to establish an upgrade,
12//! such as sending the appropriate headers, methods, and status codes. You can
13//! then use [`on`][] to grab a `Future` which will resolve to the upgraded
14//! connection object, or an error if the upgrade fails.
15//!
16//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism
17//!
18//! Sending an HTTP upgrade from the client involves setting
19//! either the appropriate method, if wanting to `CONNECT`, or headers such as
20//! `Upgrade` and `Connection`, on the `http::Request`. Once receiving the
21//! `http::Response` back, you must check for the specific information that the
22//! upgrade is agreed upon by the server (such as a `101` status code), and then
23//! get the `Future` from the `Response`.
24
25use std::{
26    error::Error as StdError,
27    fmt,
28    future::Future,
29    io,
30    pin::Pin,
31    sync::{Arc, Mutex},
32    task::{Context, Poll},
33};
34
35use bytes::Bytes;
36use tokio::{
37    io::{AsyncRead, AsyncWrite, ReadBuf},
38    sync::oneshot,
39};
40
41use self::rewind::Rewind;
42use super::{Error, Result};
43
44/// An upgraded HTTP connection.
45///
46/// This type holds a trait object internally of the original IO that
47/// was used to speak HTTP before the upgrade. It can be used directly
48/// as a [`AsyncRead`] or [`AsyncWrite`] for convenience.
49///
50/// Alternatively, if the exact type is known, this can be deconstructed
51/// into its parts.
52pub struct Upgraded {
53    io: Rewind<Box<dyn Io + Send>>,
54}
55
56/// A future for a possible HTTP upgrade.
57///
58/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
59#[derive(Clone)]
60pub struct OnUpgrade {
61    rx: Option<Arc<Mutex<oneshot::Receiver<Result<Upgraded>>>>>,
62}
63
64/// Gets a pending HTTP upgrade from this message.
65///
66/// This can be called on the following types:
67///
68/// - `http::Request<B>`
69/// - `http::Response<B>`
70/// - `&mut http::Request<B>`
71/// - `&mut http::Response<B>`
72#[inline]
73pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
74    msg.on_upgrade()
75}
76
77pub(crate) struct Pending {
78    tx: oneshot::Sender<Result<Upgraded>>,
79}
80
81pub(crate) fn pending() -> (Pending, OnUpgrade) {
82    let (tx, rx) = oneshot::channel();
83    (
84        Pending { tx },
85        OnUpgrade {
86            rx: Some(Arc::new(Mutex::new(rx))),
87        },
88    )
89}
90
91// ===== impl Upgraded =====
92
93impl Upgraded {
94    #[inline]
95    pub(crate) fn new<T>(io: T, read_buf: Bytes) -> Self
96    where
97        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
98    {
99        Upgraded {
100            io: Rewind::new_buffered(Box::new(io), read_buf),
101        }
102    }
103}
104
105impl AsyncRead for Upgraded {
106    #[inline]
107    fn poll_read(
108        mut self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        buf: &mut ReadBuf<'_>,
111    ) -> Poll<io::Result<()>> {
112        Pin::new(&mut self.io).poll_read(cx, buf)
113    }
114}
115
116impl AsyncWrite for Upgraded {
117    #[inline]
118    fn poll_write(
119        mut self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &[u8],
122    ) -> Poll<io::Result<usize>> {
123        Pin::new(&mut self.io).poll_write(cx, buf)
124    }
125
126    #[inline]
127    fn poll_write_vectored(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        bufs: &[io::IoSlice<'_>],
131    ) -> Poll<io::Result<usize>> {
132        Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
133    }
134
135    #[inline]
136    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137        Pin::new(&mut self.io).poll_flush(cx)
138    }
139
140    #[inline]
141    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142        Pin::new(&mut self.io).poll_shutdown(cx)
143    }
144
145    #[inline]
146    fn is_write_vectored(&self) -> bool {
147        self.io.is_write_vectored()
148    }
149}
150
151impl fmt::Debug for Upgraded {
152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        f.debug_struct("Upgraded").finish()
154    }
155}
156
157// ===== impl OnUpgrade =====
158
159impl OnUpgrade {
160    #[inline]
161    pub(super) fn none() -> Self {
162        OnUpgrade { rx: None }
163    }
164
165    #[inline]
166    pub(super) fn is_none(&self) -> bool {
167        self.rx.is_none()
168    }
169}
170
171impl Future for OnUpgrade {
172    type Output = Result<Upgraded, Error>;
173
174    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
175        match self.rx {
176            Some(ref rx) => Pin::new(&mut *rx.lock().unwrap())
177                .poll(cx)
178                .map(|res| match res {
179                    Ok(Ok(upgraded)) => Ok(upgraded),
180                    Ok(Err(err)) => Err(err),
181                    Err(_oneshot_canceled) => Err(Error::new_canceled().with(UpgradeExpected)),
182                }),
183            None => Poll::Ready(Err(Error::new_user_no_upgrade())),
184        }
185    }
186}
187
188// ===== impl Pending =====
189
190impl Pending {
191    #[inline]
192    pub(super) fn fulfill(self, upgraded: Upgraded) {
193        trace!("pending upgrade fulfill");
194        let _ = self.tx.send(Ok(upgraded));
195    }
196
197    /// Don't fulfill the pending Upgrade, but instead signal that
198    /// upgrades are handled manually.
199    #[inline]
200    pub(super) fn manual(self) {
201        trace!("pending upgrade handled manually");
202        let _ = self.tx.send(Err(Error::new_user_manual_upgrade()));
203    }
204}
205
206// ===== impl UpgradeExpected =====
207
/// Error cause attached when an upgrade was expected but the channel was
/// canceled before any outcome was delivered.
///
/// This likely means the actual `Conn` future wasn't polled and upgraded.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    /// Writes the fixed, human-readable description of this cause.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "upgrade expected but not completed";
        f.write_str(MESSAGE)
    }
}

// No underlying source: the default `Error` methods are sufficient.
impl StdError for UpgradeExpected {}
222
223// ===== impl Io =====
224
225trait Io: AsyncRead + AsyncWrite + Unpin + 'static {}
226
227impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
228
229mod sealed {
230    use super::OnUpgrade;
231
232    pub trait CanUpgrade {
233        fn on_upgrade(self) -> OnUpgrade;
234    }
235
236    impl<B> CanUpgrade for http::Request<B> {
237        fn on_upgrade(mut self) -> OnUpgrade {
238            self.extensions_mut()
239                .remove::<OnUpgrade>()
240                .unwrap_or_else(OnUpgrade::none)
241        }
242    }
243
244    impl<B> CanUpgrade for &'_ mut http::Request<B> {
245        fn on_upgrade(self) -> OnUpgrade {
246            self.extensions_mut()
247                .remove::<OnUpgrade>()
248                .unwrap_or_else(OnUpgrade::none)
249        }
250    }
251
252    impl<B> CanUpgrade for http::Response<B> {
253        fn on_upgrade(mut self) -> OnUpgrade {
254            self.extensions_mut()
255                .remove::<OnUpgrade>()
256                .unwrap_or_else(OnUpgrade::none)
257        }
258    }
259
260    impl<B> CanUpgrade for &'_ mut http::Response<B> {
261        fn on_upgrade(self) -> OnUpgrade {
262            self.extensions_mut()
263                .remove::<OnUpgrade>()
264                .unwrap_or_else(OnUpgrade::none)
265        }
266    }
267}
268
269mod rewind {
270    use std::{
271        cmp, io,
272        pin::Pin,
273        task::{Context, Poll},
274    };
275
276    use bytes::{Buf, Bytes};
277    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
278
279    /// Combine a buffer with an IO, rewinding reads to use the buffer.
280    #[derive(Debug)]
281    pub(crate) struct Rewind<T> {
282        pre: Option<Bytes>,
283        inner: T,
284    }
285
286    impl<T> Rewind<T> {
287        #[inline]
288        pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
289            Rewind {
290                pre: Some(buf),
291                inner: io,
292            }
293        }
294
295        #[cfg(test)]
296        pub(crate) fn rewind(&mut self, bs: Bytes) {
297            debug_assert!(self.pre.is_none());
298            self.pre = Some(bs);
299        }
300    }
301
302    impl<T> AsyncRead for Rewind<T>
303    where
304        T: AsyncRead + Unpin,
305    {
306        fn poll_read(
307            mut self: Pin<&mut Self>,
308            cx: &mut Context<'_>,
309            buf: &mut ReadBuf<'_>,
310        ) -> Poll<io::Result<()>> {
311            if let Some(mut prefix) = self.pre.take() {
312                // If there are no remaining bytes, let the bytes get dropped.
313                if !prefix.is_empty() {
314                    let copy_len = cmp::min(prefix.len(), buf.remaining());
315                    // TODO: There should be a way to do following two lines cleaner...
316                    buf.put_slice(&prefix[..copy_len]);
317                    prefix.advance(copy_len);
318                    // Put back what's left
319                    if !prefix.is_empty() {
320                        self.pre = Some(prefix);
321                    }
322
323                    return Poll::Ready(Ok(()));
324                }
325            }
326            Pin::new(&mut self.inner).poll_read(cx, buf)
327        }
328    }
329
330    impl<T> AsyncWrite for Rewind<T>
331    where
332        T: AsyncWrite + Unpin,
333    {
334        #[inline]
335        fn poll_write(
336            mut self: Pin<&mut Self>,
337            cx: &mut Context<'_>,
338            buf: &[u8],
339        ) -> Poll<io::Result<usize>> {
340            Pin::new(&mut self.inner).poll_write(cx, buf)
341        }
342
343        #[inline]
344        fn poll_write_vectored(
345            mut self: Pin<&mut Self>,
346            cx: &mut Context<'_>,
347            bufs: &[io::IoSlice<'_>],
348        ) -> Poll<io::Result<usize>> {
349            Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
350        }
351
352        #[inline]
353        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
354            Pin::new(&mut self.inner).poll_flush(cx)
355        }
356
357        #[inline]
358        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
359            Pin::new(&mut self.inner).poll_shutdown(cx)
360        }
361
362        #[inline]
363        fn is_write_vectored(&self) -> bool {
364            self.inner.is_write_vectored()
365        }
366    }
367
368    #[cfg(test)]
369    mod tests {
370        use bytes::Bytes;
371        use tokio::io::AsyncReadExt;
372
373        use super::Rewind;
374
375        #[tokio::test]
376        async fn partial_rewind() {
377            let underlying = [104, 101, 108, 108, 111];
378
379            let mock = tokio_test::io::Builder::new().read(&underlying).build();
380
381            let mut stream = Rewind::new_buffered(mock, Bytes::new());
382
383            // Read off some bytes, ensure we filled o1
384            let mut buf = [0; 2];
385            stream.read_exact(&mut buf).await.expect("read1");
386
387            // Rewind the stream so that it is as if we never read in the first place.
388            stream.rewind(Bytes::copy_from_slice(&buf[..]));
389
390            let mut buf = [0; 5];
391            stream.read_exact(&mut buf).await.expect("read1");
392
393            // At this point we should have read everything that was in the MockStream
394            assert_eq!(&buf, &underlying);
395        }
396
397        #[tokio::test]
398        async fn full_rewind() {
399            let underlying = [104, 101, 108, 108, 111];
400
401            let mock = tokio_test::io::Builder::new().read(&underlying).build();
402
403            let mut stream = Rewind::new_buffered(mock, Bytes::new());
404
405            let mut buf = [0; 5];
406            stream.read_exact(&mut buf).await.expect("read1");
407
408            // Rewind the stream so that it is as if we never read in the first place.
409            stream.rewind(Bytes::copy_from_slice(&buf[..]));
410
411            let mut buf = [0; 5];
412            stream.read_exact(&mut buf).await.expect("read1");
413
414            assert_eq!(&buf, &underlying);
415        }
416    }
417}