Skip to main content

trillium_http/
conn.rs

1use crate::{
2    Body, Buffer, Headers, HttpContext, KnownHeaderName,
3    KnownHeaderName::Host,
4    Method, ProtocolSession, ReceivedBody, Status, Swansong, TypeSet, Version,
5    after_send::{AfterSend, SendStatus},
6    h2::H2Connection,
7    h3::H3Connection,
8    headers::hpack::FieldSection,
9    liveness::{CancelOnDisconnect, LivenessFut},
10    received_body::ReceivedBodyState,
11    util::encoding,
12};
13
/// Header names whose semantics only apply at the HTTP/1 layer.
///
/// HTTP/2 (RFC 9113 §8.2.2) and HTTP/3 (RFC 9114 §4.2) call these
/// "connection-specific" fields and forbid them in requests and responses.
/// [`validate_h2h3_request`] rejects any request carrying one of them. `TE` is
/// intentionally absent: it is the one such field the RFCs permit on requests
/// (restricted to `trailers`), and it is checked separately by the validator.
pub(super) const H1_ONLY_HEADERS: [KnownHeaderName; 5] = [
    KnownHeaderName::Connection,
    KnownHeaderName::KeepAlive,
    KnownHeaderName::ProxyConnection,
    KnownHeaderName::TransferEncoding,
    KnownHeaderName::Upgrade,
];
25
/// Validated request pseudo-headers + headers, the common output of
/// [`validate_h2h3_request`].
pub(super) struct ValidatedRequest {
    /// the `:method` pseudo-header (always required)
    pub method: Method,
    /// the `:path` pseudo-header; `/` for a CONNECT request that sent no
    /// (or an empty) `:path`
    pub path: Cow<'static, str>,
    /// the `:authority` pseudo-header, if sent; always present for CONNECT
    pub authority: Option<Cow<'static, str>>,
    /// the `:scheme` pseudo-header, if sent; always present for non-CONNECT requests
    pub scheme: Option<Cow<'static, str>>,
    /// the `:protocol` pseudo-header (extended CONNECT), if sent
    pub protocol: Option<Cow<'static, str>>,
    /// the regular header fields from the field section, converted to owned [`Headers`]
    pub request_headers: Headers,
}
36
37/// Shared HTTP/2 + HTTP/3 request-validation per RFC 9113 and RFC 9114.
38///
39/// Both protocols apply the same malformed-message rules to incoming requests:
40/// no `:status` pseudo, required `:method`, non-empty `:path` (or CONNECT default),
41/// `:scheme` required for non-CONNECT, `:authority` required for CONNECT, `:authority`
42/// or `Host` required when `:scheme` is `http`/`https`, no `Host`/`:authority`
43/// mismatch, no [`H1_ONLY_HEADERS`], and `TE` restricted to `trailers`. Returns `None`
44/// on any violation; the caller maps to its protocol-specific error code.
45pub(super) fn validate_h2h3_request(
46    mut field_section: FieldSection<'static>,
47) -> Option<ValidatedRequest> {
48    let pseudo_headers = field_section.pseudo_headers_mut();
49
50    // `:status` is response-only; reject it on requests.
51    if pseudo_headers.status().is_some() {
52        return None;
53    }
54
55    let method = pseudo_headers.take_method();
56    let path = pseudo_headers.take_path();
57    let authority = pseudo_headers.take_authority();
58    let scheme = pseudo_headers.take_scheme();
59    let protocol = pseudo_headers.take_protocol();
60    let request_headers = field_section.into_headers().into_owned();
61
62    if let Some(host) = request_headers.get_str(Host)
63        && let Some(authority) = &authority
64        && host != authority.as_ref()
65    {
66        return None;
67    }
68
69    if H1_ONLY_HEADERS
70        .into_iter()
71        .any(|name| request_headers.has_header(name))
72    {
73        return None;
74    }
75
76    let method = method?;
77
78    if method != Method::Connect && scheme.is_none() {
79        return None;
80    }
81
82    let path = match (method, path) {
83        (_, Some(path)) if !path.is_empty() => path,
84        (Method::Connect, _) => Cow::Borrowed("/"),
85        _ => return None,
86    };
87
88    if method == Method::Connect && authority.is_none() {
89        return None;
90    }
91
92    // When :scheme names a scheme with a mandatory authority component, the request
93    // MUST carry either :authority or a Host header. The spec gives "http" and "https"
94    // as the canonical examples; we also include "ws" and "wss" (same
95    // hierarchical-with-mandatory-authority shape) so the rule applies consistently if
96    // a non-standard sender uses those. Exotic schemes without mandatory authority
97    // (file, data, mailto, urn) are exempt; CONNECT is handled above.
98    if method != Method::Connect
99        && matches!(scheme.as_deref(), Some("http" | "https" | "ws" | "wss"))
100        && authority.is_none()
101        && request_headers.get_str(Host).is_none()
102    {
103        return None;
104    }
105
106    match request_headers.get_str(KnownHeaderName::Te) {
107        None | Some("trailers") => {}
108        _ => return None,
109    }
110
111    Some(ValidatedRequest {
112        method,
113        path,
114        authority,
115        scheme,
116        protocol,
117        request_headers,
118    })
119}
120use encoding_rs::Encoding;
121use futures_lite::{
122    future,
123    io::{AsyncRead, AsyncWrite},
124};
125use std::{
126    borrow::Cow,
127    fmt::{self, Debug, Formatter},
128    future::Future,
129    net::IpAddr,
130    pin::pin,
131    str,
132    sync::Arc,
133    time::Instant,
134};
135mod h1;
136mod h2;
137mod h3;
138pub(crate) use h1::write_headers_or_trailers;
139pub(crate) use h3::{H3FirstFrame, encode_field_section_h3};
140
/// An HTTP connection.
///
/// This struct represents both the request and the response, and holds the
/// transport over which the response will be sent.
#[derive(fieldwork::Fieldwork)]
pub struct Conn<Transport> {
    #[field(get)]
    /// the shared [`HttpContext`]
    pub(crate) context: Arc<HttpContext>,

    /// request [headers](Headers)
    #[field(get, get_mut)]
    pub(crate) request_headers: Headers,

    /// response [headers](Headers)
    #[field(get, get_mut)]
    pub(crate) response_headers: Headers,

    /// the stored request target: the path followed by any `?query`.
    ///
    /// No accessor is generated for this field; read it through [`Conn::path()`],
    /// [`Conn::querystring`], or [`Conn::path_and_query`].
    pub(crate) path: Cow<'static, str>,

    /// the http method for this conn's request
    ///
    /// ```
    /// # use trillium_http::{Conn, Method};
    /// let mut conn = Conn::new_synthetic(Method::Get, "/some/path?and&a=query", ());
    /// assert_eq!(conn.method(), Method::Get);
    /// ```
    #[field(get, set, copy)]
    pub(crate) method: Method,

    /// the http status for this conn, if set
    #[field(get, copy)]
    pub(crate) status: Option<Status>,

    /// The HTTP protocol version in use on this connection.
    ///
    /// ```
    /// # use trillium_http::{Conn, Method, Version};
    /// let conn = Conn::new_synthetic(Method::Get, "/", ());
    /// assert_eq!(conn.http_version(), Version::Http1_1);
    /// ```
    #[field(get = http_version, copy)]
    pub(crate) version: Version,

    /// the [state typemap](TypeSet) for this conn
    #[field(get, get_mut)]
    pub(crate) state: TypeSet,

    /// the response [body](Body)
    ///
    /// ```
    /// # use trillium_testing::HttpTest;
    /// HttpTest::new(|conn| async move { conn.with_response_body("hello") })
    ///     .get("/")
    ///     .block()
    ///     .assert_body("hello");
    ///
    /// HttpTest::new(|conn| async move { conn.with_response_body(String::from("world")) })
    ///     .get("/")
    ///     .block()
    ///     .assert_body("world");
    ///
    /// HttpTest::new(|conn| async move { conn.with_response_body(vec![99, 97, 116]) })
    ///     .get("/")
    ///     .block()
    ///     .assert_body("cat");
    /// ```
    #[field(get, set, into, option_set_some, take, with)]
    pub(crate) response_body: Option<Body>,

    /// the transport
    ///
    /// This should only be used to call your own custom methods on the transport that do not read
    /// or write any data. Calling any method that reads from or writes to the transport will
    /// disrupt the HTTP protocol. If you're looking to transition from HTTP to another protocol,
    /// use an HTTP upgrade.
    #[field(get, get_mut)]
    pub(crate) transport: Transport,

    /// internal [`Buffer`] for this conn; no accessor is generated for it.
    /// NOTE(review): presumably holds bytes read from the transport that have not yet
    /// been consumed — confirm against the h1 codec.
    pub(crate) buffer: Buffer,

    /// the [`ReceivedBodyState`] tracking how much of the request body has been read.
    pub(crate) request_body_state: ReceivedBodyState,

    /// the callbacks registered through [`Conn::after_send()`]; they are chained in
    /// registration order.
    pub(crate) after_send: AfterSend,

    /// whether the connection is secure
    ///
    /// note that this does not necessarily indicate that the transport itself is secure, as it may
    /// indicate that `trillium_http` is behind a trusted reverse proxy that has terminated tls and
    /// provided appropriate headers to indicate this.
    #[field(get, set, rename_predicates)]
    pub(crate) secure: bool,

    /// The [`Instant`] that the first header bytes for this conn were
    /// received, before any processing or parsing has been performed.
    #[field(get, copy)]
    pub(crate) start_time: Instant,

    /// The IP Address for the connection, if available
    #[field(set, get, copy, into)]
    pub(crate) peer_ip: Option<IpAddr>,

    /// the `:authority` pseudo-header
    #[field(set, get, into)]
    pub(crate) authority: Option<Cow<'static, str>>,

    /// the `:scheme` pseudo-header
    #[field(set, get, into)]
    pub(crate) scheme: Option<Cow<'static, str>>,

    /// the [`ProtocolSession`] for this conn — the per-protocol session state
    /// (h2/h3 connection driver and stream id) bundled into a single enum so the
    /// "set together" invariant is enforced at the type level. `Http1` for
    /// h1 / synthetic conns.
    pub(crate) protocol_session: ProtocolSession,

    /// the `:protocol` pseudo-header (extended CONNECT)
    #[field(set, get, into)]
    pub(crate) protocol: Option<Cow<'static, str>>,

    /// request trailers, populated after the request body has been fully read
    #[field(get, get_mut)]
    pub(crate) request_trailers: Option<Headers>,

    /// Marker set via [`Conn::upgrade`]; one of the conditions checked by
    /// [`Conn::should_upgrade`].
    pub(crate) upgrade: bool,
}
268
269impl<Transport> Debug for Conn<Transport> {
270    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
271        f.debug_struct("Conn")
272            .field("context", &self.context)
273            .field("request_headers", &self.request_headers)
274            .field("response_headers", &self.response_headers)
275            .field("path", &self.path)
276            .field("method", &self.method)
277            .field("status", &self.status)
278            .field("version", &self.version)
279            .field("state", &self.state)
280            .field("response_body", &self.response_body)
281            .field("transport", &format_args!(".."))
282            .field("buffer", &format_args!(".."))
283            .field("request_body_state", &self.request_body_state)
284            .field("secure", &self.secure)
285            .field("after_send", &format_args!(".."))
286            .field("start_time", &self.start_time)
287            .field("peer_ip", &self.peer_ip)
288            .field("authority", &self.authority)
289            .field("scheme", &self.scheme)
290            .field("protocol", &self.protocol)
291            .field("protocol_session", &self.protocol_session)
292            .field("request_trailers", &self.request_trailers)
293            .field("upgrade", &self.upgrade)
294            .finish()
295    }
296}
297
298impl<Transport> Conn<Transport>
299where
300    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
301{
302    /// Returns the shared state typemap for this conn.
303    pub fn shared_state(&self) -> &TypeSet {
304        &self.context.shared_state
305    }
306
307    /// sets the http status code from any `TryInto<Status>`.
308    ///
309    /// ```
310    /// # use trillium_http::Status;
311    /// # trillium_testing::HttpTest::new(|mut conn| async move {
312    /// assert!(conn.status().is_none());
313    ///
314    /// conn.set_status(200); // a status can be set as a u16
315    /// assert_eq!(conn.status().unwrap(), Status::Ok);
316    ///
317    /// conn.set_status(Status::ImATeapot); // or as a Status
318    /// assert_eq!(conn.status().unwrap(), Status::ImATeapot);
319    /// conn
320    /// # }).get("/").block().assert_status(Status::ImATeapot);
321    /// ```
322    pub fn set_status(&mut self, status: impl TryInto<Status>) -> &mut Self {
323        self.status = Some(status.try_into().unwrap_or_else(|_| {
324            log::error!("attempted to set an invalid status code");
325            Status::InternalServerError
326        }));
327        self
328    }
329
330    /// sets the http status code from any `TryInto<Status>`, returning Conn
331    #[must_use]
332    pub fn with_status(mut self, status: impl TryInto<Status>) -> Self {
333        self.set_status(status);
334        self
335    }
336
337    /// retrieves the path part of the request url, up to and excluding any query component
338    /// ```
339    /// # use trillium_testing::HttpTest;
340    /// HttpTest::new(|mut conn| async move {
341    ///     assert_eq!(conn.path(), "/some/path");
342    ///     conn.with_status(200)
343    /// })
344    /// .get("/some/path?and&a=query")
345    /// .block()
346    /// .assert_ok();
347    /// ```
348    pub fn path(&self) -> &str {
349        match self.path.split_once('?') {
350            Some((path, _)) => path,
351            None => &self.path,
352        }
353    }
354
355    /// retrieves the combined path and any query
356    pub fn path_and_query(&self) -> &str {
357        &self.path
358    }
359
360    /// retrieves the query component of the path, or an empty &str
361    ///
362    /// ```
363    /// # use trillium_testing::HttpTest;
364    /// let server = HttpTest::new(|conn| async move {
365    ///     let querystring = conn.querystring().to_string();
366    ///     conn.with_response_body(querystring).with_status(200)
367    /// });
368    ///
369    /// server
370    ///     .get("/some/path?and&a=query")
371    ///     .block()
372    ///     .assert_body("and&a=query");
373    ///
374    /// server.get("/some/path").block().assert_body("");
375    /// ```
376    pub fn querystring(&self) -> &str {
377        self.path
378            .split_once('?')
379            .map(|(_, query)| query)
380            .unwrap_or_default()
381    }
382
383    /// get the host for this conn, if it exists
384    pub fn host(&self) -> Option<&str> {
385        self.request_headers.get_str(Host)
386    }
387
388    /// set the host for this conn
389    pub fn set_host(&mut self, host: String) -> &mut Self {
390        self.request_headers.insert(Host, host);
391        self
392    }
393
394    /// Cancels and drops the future if reading from the transport results in an error or empty read
395    ///
396    /// The use of this method is not advised if your connected http client employs pipelining
397    /// (rarely seen in the wild), as it will buffer an unbounded number of requests one byte at a
398    /// time
399    ///
400    /// If the client disconnects from the conn's transport, this function will return None. If the
401    /// future completes without disconnection, this future will return Some containing the output
402    /// of the future.
403    ///
404    /// Note that the inner future cannot borrow conn, so you will need to clone or take any
405    /// information needed to execute the future prior to executing this method.
406    ///
407    /// # Example
408    ///
409    /// ```rust
410    /// # use futures_lite::{AsyncRead, AsyncWrite};
411    /// # use trillium_http::{Conn, Method};
412    /// async fn something_slow_and_cancel_safe() -> String {
413    ///     String::from("this was not actually slow")
414    /// }
415    /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
416    /// where
417    ///     T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
418    /// {
419    ///     let Some(returned_body) = conn
420    ///         .cancel_on_disconnect(async { something_slow_and_cancel_safe().await })
421    ///         .await
422    ///     else {
423    ///         return conn;
424    ///     };
425    ///     conn.with_response_body(returned_body).with_status(200)
426    /// }
427    /// ```
428    pub async fn cancel_on_disconnect<'a, Fut>(&'a mut self, fut: Fut) -> Option<Fut::Output>
429    where
430        Fut: Future + Send + 'a,
431    {
432        CancelOnDisconnect(self, pin!(fut)).await
433    }
434
435    /// Check if the transport is connected by attempting to read from the transport
436    ///
437    /// # Example
438    ///
439    /// This is best to use at appropriate points in a long-running handler, like:
440    ///
441    /// ```rust
442    /// # use futures_lite::{AsyncRead, AsyncWrite};
443    /// # use trillium_http::{Conn, Method};
444    /// # async fn something_slow_but_not_cancel_safe() {}
445    /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
446    /// where
447    ///     T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
448    /// {
449    ///     for _ in 0..100 {
450    ///         if conn.is_disconnected().await {
451    ///             return conn;
452    ///         }
453    ///         something_slow_but_not_cancel_safe().await;
454    ///     }
455    ///     conn.with_status(200)
456    /// }
457    /// ```
458    pub async fn is_disconnected(&mut self) -> bool {
459        future::poll_once(LivenessFut::new(self)).await.is_some()
460    }
461
462    /// returns the [`encoding_rs::Encoding`] for this request, as determined from the mime-type
463    /// charset, if available
464    ///
465    /// ```
466    /// # use trillium_testing::HttpTest;
467    /// HttpTest::new(|mut conn| async move {
468    ///     assert_eq!(conn.request_encoding(), encoding_rs::WINDOWS_1252); // the default
469    ///
470    ///     conn.request_headers_mut()
471    ///         .insert("content-type", "text/plain;charset=utf-16");
472    ///     assert_eq!(conn.request_encoding(), encoding_rs::UTF_16LE);
473    ///
474    ///     conn.with_status(200)
475    /// })
476    /// .get("/")
477    /// .block()
478    /// .assert_ok();
479    /// ```
480    pub fn request_encoding(&self) -> &'static Encoding {
481        encoding(&self.request_headers)
482    }
483
484    /// returns the [`encoding_rs::Encoding`] for this response, as
485    /// determined from the mime-type charset, if available
486    ///
487    /// ```
488    /// # use trillium_testing::HttpTest;
489    /// HttpTest::new(|mut conn| async move {
490    ///     assert_eq!(conn.response_encoding(), encoding_rs::WINDOWS_1252); // the default
491    ///     conn.response_headers_mut()
492    ///         .insert("content-type", "text/plain;charset=utf-16");
493    ///
494    ///     assert_eq!(conn.response_encoding(), encoding_rs::UTF_16LE);
495    ///
496    ///     conn.with_status(200)
497    /// })
498    /// .get("/")
499    /// .block()
500    /// .assert_ok();
501    /// ```
502    pub fn response_encoding(&self) -> &'static Encoding {
503        encoding(&self.response_headers)
504    }
505
506    /// returns a [`ReceivedBody`] that references this conn. the conn
507    /// retains all data and holds the singular transport, but the
508    /// `ReceivedBody` provides an interface to read body content.
509    ///
510    /// If the request included an `Expect: 100-continue` header, the 100 Continue response is sent
511    /// lazily on the first read from the returned [`ReceivedBody`].
512    /// ```
513    /// # use trillium_testing::HttpTest;
514    /// let server = HttpTest::new(|mut conn| async move {
515    ///     let request_body = conn.request_body();
516    ///     assert_eq!(request_body.content_length(), Some(5));
517    ///     assert_eq!(request_body.read_string().await.unwrap(), "hello");
518    ///     conn.with_status(200)
519    /// });
520    ///
521    /// server.post("/").with_body("hello").block().assert_ok();
522    /// ```
523    pub fn request_body(&mut self) -> ReceivedBody<'_, Transport> {
524        let needs_100_continue = self.needs_100_continue();
525        let body = self.build_request_body();
526        if needs_100_continue {
527            body.with_send_100_continue()
528        } else {
529            body
530        }
531    }
532
533    /// returns a clone of the [`swansong::Swansong`] for this Conn. use
534    /// this to gracefully stop long-running futures and streams
535    /// inside of handler functions
536    pub fn swansong(&self) -> Swansong {
537        self.protocol_session
538            .h3_connection()
539            .map_or_else(|| self.context.swansong.clone(), |h| h.swansong().clone())
540    }
541
542    /// Registers a function to call after the http response has been
543    /// completely transferred.
544    ///
545    /// The callback is guaranteed to fire **exactly once** before the conn is
546    /// dropped. Either the codec's send path invokes it with the real outcome,
547    /// or — if the conn is dropped before send completes (handler panic,
548    /// transport error, mid-write disconnect) — the drop fallback invokes it
549    /// with a `SendStatus` whose `is_success()` returns false. Multiple
550    /// registrations on the same conn chain in registration order.
551    ///
552    /// Because firing is ordered by send-completion rather than handler return,
553    /// this is the right hook for instrumentation that wants to report what the
554    /// peer actually observed.
555    ///
556    /// This is a sync function and should be computationally lightweight. If
557    /// your _application_ needs additional async processing, use your runtime's
558    /// task spawn within this hook. If your _library_ needs additional async
559    /// processing in an `after_send` hook, please open an issue.
560    pub fn after_send<F>(&mut self, after_send: F)
561    where
562        F: FnOnce(SendStatus) + Send + Sync + 'static,
563    {
564        self.after_send.append(after_send);
565    }
566
567    /// applies a mapping function from one transport to another. This
568    /// is particularly useful for boxing the transport. unless you're
569    /// sure this is what you're looking for, you probably don't want
570    /// to be using this
571    pub fn map_transport<NewTransport>(
572        self,
573        f: impl Fn(Transport) -> NewTransport,
574    ) -> Conn<NewTransport>
575    where
576        NewTransport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
577    {
578        // Manual respread: rustc treats `Conn<Transport>` and `Conn<NewTransport>` as
579        // disjoint types and rejects `..self` without the unstable
580        // `type_changing_struct_update` feature. If a new field is added to `Conn`,
581        // update this respread, `Upgrade::map_transport`, and `From<Conn> for Upgrade`
582        // (`upgrade.rs`) — they share this drift hazard.
583        Conn {
584            context: self.context,
585            request_headers: self.request_headers,
586            response_headers: self.response_headers,
587            method: self.method,
588            response_body: self.response_body,
589            path: self.path,
590            status: self.status,
591            version: self.version,
592            state: self.state,
593            transport: f(self.transport),
594            buffer: self.buffer,
595            request_body_state: self.request_body_state,
596            secure: self.secure,
597            after_send: self.after_send,
598            start_time: self.start_time,
599            peer_ip: self.peer_ip,
600            authority: self.authority,
601            scheme: self.scheme,
602            protocol: self.protocol,
603            protocol_session: self.protocol_session,
604            request_trailers: self.request_trailers,
605            upgrade: self.upgrade,
606        }
607    }
608
609    /// whether this conn is suitable for an http upgrade to another protocol
610    pub fn should_upgrade(&self) -> bool {
611        self.upgrade
612            || (self.method() == Method::Connect && self.status == Some(Status::Ok))
613            || self.status == Some(Status::SwitchingProtocols)
614    }
615
616    /// Mark this conn to be handed off as an upgrade once the response headers are sent.
617    /// Set the response status (typically `200`) and any headers describing the upgraded
618    /// byte stream before calling; the handler's `upgrade` method receives an [`Upgrade`]
619    /// with per-protocol framing applied on its `AsyncRead`/`AsyncWrite`.
620    #[doc(hidden)]
621    #[must_use]
622    pub fn upgrade(mut self) -> Self {
623        self.upgrade = true;
624        self
625    }
626
627    #[doc(hidden)]
628    pub fn finalize_headers(&mut self) {
629        if self.version == Version::Http3 {
630            self.finalize_response_headers_h3();
631        } else {
632            self.finalize_response_headers_1x();
633        }
634    }
635
636    /// the [`H2Connection`] driver for this conn, if this is an HTTP/2 request
637    pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
638        self.protocol_session.h2_connection()
639    }
640
641    /// the h2 stream id for this conn, if this is an HTTP/2 request
642    pub fn h2_stream_id(&self) -> Option<u32> {
643        self.protocol_session.h2_stream_id()
644    }
645
646    /// the [`H3Connection`] driver for this conn, if this is an HTTP/3 request
647    pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
648        self.protocol_session.h3_connection()
649    }
650
651    /// the h3 stream id for this conn, if this is an HTTP/3 request
652    pub fn h3_stream_id(&self) -> Option<u64> {
653        self.protocol_session.h3_stream_id()
654    }
655}