trillium_http/
upgrade.rs

1use crate::{Conn, Headers, Method, StateSet, Stopper};
2use futures_lite::{AsyncRead, AsyncWrite};
3use std::{
4    fmt::{self, Debug, Formatter},
5    io,
6    pin::Pin,
7    str,
8    task::{Context, Poll},
9};
10use trillium_macros::AsyncWrite;
11
/**
This open (pub fields) struct represents an HTTP upgrade. It carries
the request data taken from the originating [`Conn`] (request headers,
path, method and accumulated state) together with ownership of the
underlying transport, any bytes already read from that transport, and
a [`Stopper`].

Important implementation note: when reading directly from the
transport, first consume any bytes held in `buffer`. Alternatively,
read from the Upgrade itself, as its [`AsyncRead`] implementation
drains the buffer before reading from the transport.
*/
#[derive(AsyncWrite)]
pub struct Upgrade<Transport> {
    /// The http request headers
    pub request_headers: Headers,
    /// The request path as received, which may still include a `?query`
    /// component. See [`Upgrade::path()`] and [`Upgrade::querystring()`]
    /// for the two halves.
    pub path: String,
    /// The http request method
    pub method: Method,
    /// Any state that has been accumulated on the Conn before negotiating the upgrade
    pub state: StateSet,
    /// The underlying io (often a `TcpStream` or similar).
    ///
    /// Marked `#[async_write]`, presumably so the derived `AsyncWrite`
    /// implementation forwards writes to this field.
    #[async_write]
    pub transport: Transport,
    /// Any bytes that have already been read from the underlying
    /// transport. It is your responsibility to process these bytes
    /// before reading directly from the transport. May be `None`.
    pub buffer: Option<Vec<u8>>,
    /// A [`Stopper`] which can and should be used to gracefully shut
    /// down any long running streams or futures associated with this
    /// upgrade
    pub stopper: Stopper,
}
45
46impl<Transport> Upgrade<Transport> {
47    /// see [`request_headers`]
48    #[deprecated = "directly access the request_headers field"]
49    pub fn headers(&self) -> &Headers {
50        &self.request_headers
51    }
52
53    /// the http request path up to but excluding any query component
54    pub fn path(&self) -> &str {
55        match self.path.split_once('?') {
56            Some((path, _)) => path,
57            None => &self.path,
58        }
59    }
60
61    /// retrieves the query component of the path
62    pub fn querystring(&self) -> &str {
63        self.path
64            .split_once('?')
65            .map(|(_, query)| query)
66            .unwrap_or_default()
67    }
68
69    /// the http method
70    pub fn method(&self) -> &Method {
71        &self.method
72    }
73
74    /// any state that has been accumulated on the Conn before
75    /// negotiating the upgrade.
76    pub fn state(&self) -> &StateSet {
77        &self.state
78    }
79
80    /// Modify the transport type of this upgrade.
81    ///
82    /// This is useful for boxing the transport in order to erase the type argument.
83    pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
84        self,
85        f: impl Fn(Transport) -> T,
86    ) -> Upgrade<T> {
87        Upgrade {
88            transport: f(self.transport),
89            path: self.path,
90            method: self.method,
91            state: self.state,
92            buffer: self.buffer,
93            request_headers: self.request_headers,
94            stopper: self.stopper,
95        }
96    }
97}
98
99impl<Transport> Debug for Upgrade<Transport> {
100    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
101        f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
102            .field("request_headers", &self.request_headers)
103            .field("path", &self.path)
104            .field("method", &self.method)
105            .field(
106                "buffer",
107                &self.buffer.as_deref().map(String::from_utf8_lossy),
108            )
109            .field("stopper", &self.stopper)
110            .field("state", &self.state)
111            .field("transport", &"..")
112            .finish()
113    }
114}
115
116impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
117    fn from(conn: Conn<Transport>) -> Self {
118        let Conn {
119            request_headers,
120            path,
121            method,
122            state,
123            transport,
124            buffer,
125            stopper,
126            ..
127        } = conn;
128
129        Self {
130            request_headers,
131            path,
132            method,
133            state,
134            transport,
135            buffer: if buffer.is_empty() {
136                None
137            } else {
138                Some(buffer.into())
139            },
140            stopper,
141        }
142    }
143}
144
145impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
146    fn poll_read(
147        mut self: Pin<&mut Self>,
148        cx: &mut Context<'_>,
149        buf: &mut [u8],
150    ) -> Poll<io::Result<usize>> {
151        match self.buffer.take() {
152            Some(mut buffer) if !buffer.is_empty() => {
153                let len = buffer.len();
154                if len > buf.len() {
155                    log::trace!(
156                        "have {} bytes of pending data but can only use {}",
157                        len,
158                        buf.len()
159                    );
160                    let remaining = buffer.split_off(buf.len());
161                    buf.copy_from_slice(&buffer[..]);
162                    self.buffer = Some(remaining);
163                    Poll::Ready(Ok(buf.len()))
164                } else {
165                    log::trace!("have {} bytes of pending data, using all of it", len);
166                    buf[..len].copy_from_slice(&buffer);
167                    self.buffer = None;
168                    match Pin::new(&mut self.transport).poll_read(cx, &mut buf[len..]) {
169                        Poll::Ready(Ok(e)) => Poll::Ready(Ok(e + len)),
170                        Poll::Pending => Poll::Ready(Ok(len)),
171                        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
172                    }
173                }
174            }
175
176            _ => Pin::new(&mut self.transport).poll_read(cx, buf),
177        }
178    }
179}