Skip to main content

trillium_http/
upgrade.rs

1use crate::{
2    Buffer, Conn, Headers, HttpContext, Method, TypeSet, Version, h3::H3Connection,
3    received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8    borrow::Cow,
9    fmt::{self, Debug, Formatter},
10    io,
11    net::IpAddr,
12    pin::Pin,
13    str,
14    sync::Arc,
15    task::{self, Poll},
16};
17use trillium_macros::AsyncWrite;
18
/// This struct represents an http upgrade. It contains all of the data available on a [`Conn`],
/// as well as owning the underlying transport.
///
/// **Important implementation note**: When reading directly from the transport, ensure that you
/// read from `buffer` first if there are bytes in it. Alternatively, read directly from the
/// Upgrade, as that [`AsyncRead`] implementation will drain the buffer first before reading from
/// the transport.
// NOTE(review): accessor methods for the fields below are generated by the `Fieldwork` derive,
// configured by the `#[fieldwork(...)]` and per-field `#[field(...)]` attributes; the exact
// generated method set is determined by that crate — confirm against its documentation.
#[derive(AsyncWrite, Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// The http request headers
    request_headers: Headers,

    /// The request path as received, including any `?query` component.
    ///
    /// The generated getter is disabled here; the hand-written `path` and `querystring` methods
    /// split this value at the first `?`.
    #[field(get = false)]
    path: Cow<'static, str>,

    /// The http request method
    #[field(copy)]
    method: Method,

    /// Any state that has been accumulated on the Conn before negotiating the upgrade
    state: TypeSet,

    /// The underlying io (often a `TcpStream` or similar)
    // NOTE(review): presumably `#[async_write]` marks this as the field that the derived
    // `AsyncWrite` implementation forwards writes to — confirm against `trillium_macros`.
    #[async_write]
    transport: Transport,

    /// Any bytes that have been read from the underlying transport already.
    ///
    /// It is your responsibility to process these bytes before reading directly from the
    /// transport.
    // Setter/builder generation is switched off for this field; the hand-written `take_buffer`
    // method empties it instead.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    buffer: Buffer,

    /// The [`HttpContext`] shared for this server. `shared_state` reads from this context.
    #[field(deref = false)]
    context: Arc<HttpContext>,

    /// the ip address of the connection, if available
    #[field(copy)]
    peer_ip: Option<IpAddr>,

    /// the :authority http/3 pseudo-header
    authority: Option<Cow<'static, str>>,

    /// the :scheme http/3 pseudo-header
    scheme: Option<Cow<'static, str>>,

    /// the HTTP/3 connection associated with this upgrade, if this was an HTTP/3 connection
    // Only a (non-deref) getter is requested; all mutating/consuming accessors are disabled.
    #[field(
        get(deref = false),
        get_mut = false,
        set = false,
        with = false,
        into_field = false,
        take = false
    )]
    h3_connection: Option<Arc<H3Connection>>,

    /// the :protocol http/3 pseudo-header
    protocol: Option<Cow<'static, str>>,

    /// the http version
    // NOTE(review): presumably renames the generated accessors to `http_version` — confirm.
    #[field = "http_version"]
    version: Version,

    /// whether this connection was deemed secure by the handler stack
    secure: bool,
}
89
90impl<Transport> Upgrade<Transport> {
91    #[doc(hidden)]
92    pub fn new(
93        request_headers: Headers,
94        path: impl Into<Cow<'static, str>>,
95        method: Method,
96        transport: Transport,
97        buffer: Buffer,
98        version: Version,
99    ) -> Self {
100        Self {
101            request_headers,
102            path: path.into(),
103            method,
104            transport,
105            buffer,
106            state: TypeSet::new(),
107            context: Arc::default(),
108            peer_ip: None,
109            authority: None,
110            scheme: None,
111            h3_connection: None,
112            protocol: None,
113            secure: false,
114            version,
115        }
116    }
117
118    /// Take any buffered bytes
119    pub fn take_buffer(&mut self) -> Vec<u8> {
120        std::mem::take(&mut self.buffer).into()
121    }
122
123    #[doc(hidden)]
124    pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
125        (&mut self.buffer, &mut self.transport)
126    }
127
128    /// borrow the shared state [`TypeSet`] for this application
129    pub fn shared_state(&self) -> &TypeSet {
130        self.context.shared_state()
131    }
132
133    /// the http request path up to but excluding any query component
134    pub fn path(&self) -> &str {
135        match self.path.split_once('?') {
136            Some((path, _)) => path,
137            None => &self.path,
138        }
139    }
140
141    /// retrieves the query component of the path
142    pub fn querystring(&self) -> &str {
143        self.path
144            .split_once('?')
145            .map(|(_, query)| query)
146            .unwrap_or_default()
147    }
148
149    /// Modify the transport type of this upgrade.
150    ///
151    /// This is useful for boxing the transport in order to erase the type argument.
152    pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
153        self,
154        f: impl Fn(Transport) -> T,
155    ) -> Upgrade<T> {
156        Upgrade {
157            transport: f(self.transport),
158            path: self.path,
159            method: self.method,
160            state: self.state,
161            buffer: self.buffer,
162            request_headers: self.request_headers,
163            context: self.context,
164            peer_ip: self.peer_ip,
165            authority: self.authority,
166            scheme: self.scheme,
167            h3_connection: self.h3_connection,
168            protocol: self.protocol,
169            version: self.version,
170            secure: self.secure,
171        }
172    }
173}
174
175impl<Transport> Debug for Upgrade<Transport> {
176    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
177        f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
178            .field("request_headers", &self.request_headers)
179            .field("path", &self.path)
180            .field("method", &self.method)
181            .field("buffer", &self.buffer)
182            .field("context", &self.context)
183            .field("state", &self.state)
184            .field("transport", &format_args!(".."))
185            .field("peer_ip", &self.peer_ip)
186            .field("authority", &self.authority)
187            .field("scheme", &self.scheme)
188            .field("h3_connection", &self.h3_connection)
189            .field("protocol", &self.protocol)
190            .field("version", &self.version)
191            .field("secure", &self.secure)
192            .finish()
193    }
194}
195
196impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
197    fn from(conn: Conn<Transport>) -> Self {
198        let Conn {
199            request_headers,
200            path,
201            method,
202            state,
203            transport,
204            buffer,
205            context,
206            peer_ip,
207            authority,
208            scheme,
209            h3_connection,
210            protocol,
211            version,
212            secure,
213            ..
214        } = conn;
215
216        Self {
217            request_headers,
218            path,
219            method,
220            state,
221            transport,
222            buffer,
223            context,
224            peer_ip,
225            authority,
226            scheme,
227            h3_connection,
228            protocol,
229            version,
230            secure,
231        }
232    }
233}
234
235impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
236    fn poll_read(
237        mut self: Pin<&mut Self>,
238        cx: &mut task::Context<'_>,
239        buf: &mut [u8],
240    ) -> Poll<io::Result<usize>> {
241        let Self {
242            transport, buffer, ..
243        } = &mut *self;
244        read_buffered(buffer, transport, cx, buf)
245    }
246}