trillium_http/conn.rs
1use crate::{
2 Body, Buffer, Headers, HttpContext, KnownHeaderName,
3 KnownHeaderName::Host,
4 Method, ProtocolSession, ReceivedBody, Status, Swansong, TypeSet, Version,
5 after_send::{AfterSend, SendStatus},
6 h2::H2Connection,
7 h3::H3Connection,
8 headers::hpack::FieldSection,
9 liveness::{CancelOnDisconnect, LivenessFut},
10 received_body::ReceivedBodyState,
11 util::encoding,
12};
13
/// Header names whose semantics only apply at the HTTP/1 layer.
///
/// HTTP/2 (RFC 9113) and HTTP/3 (RFC 9114) call these "connection-specific"
/// and forbid them in requests and responses. [`validate_h2h3_request`] treats a
/// request carrying any of these names as malformed.
pub(super) const H1_ONLY_HEADERS: [KnownHeaderName; 5] = [
    KnownHeaderName::Connection,
    KnownHeaderName::KeepAlive,
    KnownHeaderName::ProxyConnection,
    KnownHeaderName::TransferEncoding,
    KnownHeaderName::Upgrade,
];
25
/// Validated request pseudo-headers + headers, the common output of
/// [`validate_h2h3_request`].
pub(super) struct ValidatedRequest {
    /// The `:method` pseudo-header. Validation fails when it is absent, so this is
    /// always populated.
    pub method: Method,
    /// The `:path` pseudo-header. Never empty: a CONNECT request without a usable
    /// `:path` is given `/`, and any other request without one fails validation.
    pub path: Cow<'static, str>,
    /// The `:authority` pseudo-header, if the request carried one. Always `Some` for
    /// CONNECT requests.
    pub authority: Option<Cow<'static, str>>,
    /// The `:scheme` pseudo-header, if the request carried one. Always `Some` for
    /// non-CONNECT requests.
    pub scheme: Option<Cow<'static, str>>,
    /// The `:protocol` pseudo-header, if the request carried one.
    pub protocol: Option<Cow<'static, str>>,
    /// The owned headers extracted from the field section once the pseudo-headers
    /// above have been taken out of it.
    pub request_headers: Headers,
}
36
37/// Shared HTTP/2 + HTTP/3 request-validation per RFC 9113 and RFC 9114.
38///
39/// Both protocols apply the same malformed-message rules to incoming requests:
40/// no `:status` pseudo, required `:method`, non-empty `:path` (or CONNECT default),
41/// `:scheme` required for non-CONNECT, `:authority` required for CONNECT, `:authority`
42/// or `Host` required when `:scheme` is `http`/`https`, no `Host`/`:authority`
43/// mismatch, no [`H1_ONLY_HEADERS`], and `TE` restricted to `trailers`. Returns `None`
44/// on any violation; the caller maps to its protocol-specific error code.
45pub(super) fn validate_h2h3_request(
46 mut field_section: FieldSection<'static>,
47) -> Option<ValidatedRequest> {
48 let pseudo_headers = field_section.pseudo_headers_mut();
49
50 // `:status` is response-only; reject it on requests.
51 if pseudo_headers.status().is_some() {
52 return None;
53 }
54
55 let method = pseudo_headers.take_method();
56 let path = pseudo_headers.take_path();
57 let authority = pseudo_headers.take_authority();
58 let scheme = pseudo_headers.take_scheme();
59 let protocol = pseudo_headers.take_protocol();
60 let request_headers = field_section.into_headers().into_owned();
61
62 if let Some(host) = request_headers.get_str(Host)
63 && let Some(authority) = &authority
64 && host != authority.as_ref()
65 {
66 return None;
67 }
68
69 if H1_ONLY_HEADERS
70 .into_iter()
71 .any(|name| request_headers.has_header(name))
72 {
73 return None;
74 }
75
76 let method = method?;
77
78 if method != Method::Connect && scheme.is_none() {
79 return None;
80 }
81
82 let path = match (method, path) {
83 (_, Some(path)) if !path.is_empty() => path,
84 (Method::Connect, _) => Cow::Borrowed("/"),
85 _ => return None,
86 };
87
88 if method == Method::Connect && authority.is_none() {
89 return None;
90 }
91
92 // When :scheme names a scheme with a mandatory authority component, the request
93 // MUST carry either :authority or a Host header. The spec gives "http" and "https"
94 // as the canonical examples; we also include "ws" and "wss" (same
95 // hierarchical-with-mandatory-authority shape) so the rule applies consistently if
96 // a non-standard sender uses those. Exotic schemes without mandatory authority
97 // (file, data, mailto, urn) are exempt; CONNECT is handled above.
98 if method != Method::Connect
99 && matches!(scheme.as_deref(), Some("http" | "https" | "ws" | "wss"))
100 && authority.is_none()
101 && request_headers.get_str(Host).is_none()
102 {
103 return None;
104 }
105
106 match request_headers.get_str(KnownHeaderName::Te) {
107 None | Some("trailers") => {}
108 _ => return None,
109 }
110
111 Some(ValidatedRequest {
112 method,
113 path,
114 authority,
115 scheme,
116 protocol,
117 request_headers,
118 })
119}
120use encoding_rs::Encoding;
121use futures_lite::{
122 future,
123 io::{AsyncRead, AsyncWrite},
124};
125use std::{
126 borrow::Cow,
127 fmt::{self, Debug, Formatter},
128 future::Future,
129 net::IpAddr,
130 pin::pin,
131 str,
132 sync::Arc,
133 time::Instant,
134};
135mod h1;
136mod h2;
137mod h3;
138pub(crate) use h1::write_headers_or_trailers;
139pub(crate) use h3::{H3FirstFrame, encode_field_section_h3};
140
/// An HTTP connection.
///
/// This struct represents both the request and the response, and holds the
/// transport over which the response will be sent.
#[derive(fieldwork::Fieldwork)]
pub struct Conn<Transport> {
    #[field(get)]
    /// the shared [`HttpContext`]
    pub(crate) context: Arc<HttpContext>,

    /// request [headers](Headers)
    #[field(get, get_mut)]
    pub(crate) request_headers: Headers,

    /// response [headers](Headers)
    #[field(get, get_mut)]
    pub(crate) response_headers: Headers,

    // The request path together with any `?query` suffix. `Conn::path`,
    // `Conn::querystring`, and `Conn::path_and_query` are all views over this value.
    pub(crate) path: Cow<'static, str>,

    /// the http method for this conn's request
    ///
    /// ```
    /// # use trillium_http::{Conn, Method};
    /// let mut conn = Conn::new_synthetic(Method::Get, "/some/path?and&a=query", ());
    /// assert_eq!(conn.method(), Method::Get);
    /// ```
    #[field(get, set, copy)]
    pub(crate) method: Method,

    /// the http status for this conn, if set
    #[field(get, copy)]
    pub(crate) status: Option<Status>,

    /// The HTTP protocol version in use on this connection.
    ///
    /// ```
    /// # use trillium_http::{Conn, Method, Version};
    /// let conn = Conn::new_synthetic(Method::Get, "/", ());
    /// assert_eq!(conn.http_version(), Version::Http1_1);
    /// ```
    #[field(get = http_version, copy)]
    pub(crate) version: Version,

    /// the [state typemap](TypeSet) for this conn
    #[field(get, get_mut)]
    pub(crate) state: TypeSet,

    /// the response [body](Body)
    ///
    /// ```
    /// # use trillium_testing::HttpTest;
    /// HttpTest::new(|conn| async move { conn.with_response_body("hello") })
    ///     .get("/")
    ///     .block()
    ///     .assert_body("hello");
    ///
    /// HttpTest::new(|conn| async move { conn.with_response_body(String::from("world")) })
    ///     .get("/")
    ///     .block()
    ///     .assert_body("world");
    ///
    /// HttpTest::new(|conn| async move { conn.with_response_body(vec![99, 97, 116]) })
    ///     .get("/")
    ///     .block()
    ///     .assert_body("cat");
    /// ```
    #[field(get, set, into, option_set_some, take, with)]
    pub(crate) response_body: Option<Body>,

    /// the transport
    ///
    /// This should only be used to call your own custom methods on the transport that do not read
    /// or write any data. Calling any method that reads from or writes to the transport will
    /// disrupt the HTTP protocol. If you're looking to transition from HTTP to another protocol,
    /// use an HTTP upgrade.
    #[field(get, get_mut)]
    pub(crate) transport: Transport,

    // NOTE(review): presumably holds bytes already read from the transport but not yet
    // consumed by the parser — confirm against the codec modules.
    pub(crate) buffer: Buffer,

    // NOTE(review): presumably tracks how far reading of the request body has
    // progressed — confirm against `ReceivedBody`.
    pub(crate) request_body_state: ReceivedBodyState,

    // Callbacks registered through `Conn::after_send`, which appends to this value.
    pub(crate) after_send: AfterSend,

    /// whether the connection is secure
    ///
    /// note that this does not necessarily indicate that the transport itself is secure, as it may
    /// indicate that `trillium_http` is behind a trusted reverse proxy that has terminated tls and
    /// provided appropriate headers to indicate this.
    #[field(get, set, rename_predicates)]
    pub(crate) secure: bool,

    /// The [`Instant`] that the first header bytes for this conn were
    /// received, before any processing or parsing has been performed.
    #[field(get, copy)]
    pub(crate) start_time: Instant,

    /// The IP Address for the connection, if available
    #[field(set, get, copy, into)]
    pub(crate) peer_ip: Option<IpAddr>,

    /// the `:authority` pseudo-header
    #[field(set, get, into)]
    pub(crate) authority: Option<Cow<'static, str>>,

    /// the `:scheme` pseudo-header
    #[field(set, get, into)]
    pub(crate) scheme: Option<Cow<'static, str>>,

    /// the [`ProtocolSession`] for this conn — the per-protocol session state
    /// (h2/h3 connection driver and stream id) bundled into a single enum so the
    /// "set together" invariant is enforced at the type level. `Http1` for
    /// h1 / synthetic conns.
    pub(crate) protocol_session: ProtocolSession,

    /// the `:protocol` pseudo-header (extended CONNECT)
    #[field(set, get, into)]
    pub(crate) protocol: Option<Cow<'static, str>>,

    /// request trailers, populated after the request body has been fully read
    #[field(get, get_mut)]
    pub(crate) request_trailers: Option<Headers>,

    /// Marker set via [`Conn::upgrade`]. When `true`, [`Conn::should_upgrade`] reports
    /// `true` regardless of method or status.
    pub(crate) upgrade: bool,
}
268
269impl<Transport> Debug for Conn<Transport> {
270 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
271 f.debug_struct("Conn")
272 .field("context", &self.context)
273 .field("request_headers", &self.request_headers)
274 .field("response_headers", &self.response_headers)
275 .field("path", &self.path)
276 .field("method", &self.method)
277 .field("status", &self.status)
278 .field("version", &self.version)
279 .field("state", &self.state)
280 .field("response_body", &self.response_body)
281 .field("transport", &format_args!(".."))
282 .field("buffer", &format_args!(".."))
283 .field("request_body_state", &self.request_body_state)
284 .field("secure", &self.secure)
285 .field("after_send", &format_args!(".."))
286 .field("start_time", &self.start_time)
287 .field("peer_ip", &self.peer_ip)
288 .field("authority", &self.authority)
289 .field("scheme", &self.scheme)
290 .field("protocol", &self.protocol)
291 .field("protocol_session", &self.protocol_session)
292 .field("request_trailers", &self.request_trailers)
293 .field("upgrade", &self.upgrade)
294 .finish()
295 }
296}
297
298impl<Transport> Conn<Transport>
299where
300 Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
301{
302 /// Returns the shared state typemap for this conn.
303 pub fn shared_state(&self) -> &TypeSet {
304 &self.context.shared_state
305 }
306
307 /// sets the http status code from any `TryInto<Status>`.
308 ///
309 /// ```
310 /// # use trillium_http::Status;
311 /// # trillium_testing::HttpTest::new(|mut conn| async move {
312 /// assert!(conn.status().is_none());
313 ///
314 /// conn.set_status(200); // a status can be set as a u16
315 /// assert_eq!(conn.status().unwrap(), Status::Ok);
316 ///
317 /// conn.set_status(Status::ImATeapot); // or as a Status
318 /// assert_eq!(conn.status().unwrap(), Status::ImATeapot);
319 /// conn
320 /// # }).get("/").block().assert_status(Status::ImATeapot);
321 /// ```
322 pub fn set_status(&mut self, status: impl TryInto<Status>) -> &mut Self {
323 self.status = Some(status.try_into().unwrap_or_else(|_| {
324 log::error!("attempted to set an invalid status code");
325 Status::InternalServerError
326 }));
327 self
328 }
329
330 /// sets the http status code from any `TryInto<Status>`, returning Conn
331 #[must_use]
332 pub fn with_status(mut self, status: impl TryInto<Status>) -> Self {
333 self.set_status(status);
334 self
335 }
336
337 /// retrieves the path part of the request url, up to and excluding any query component
338 /// ```
339 /// # use trillium_testing::HttpTest;
340 /// HttpTest::new(|mut conn| async move {
341 /// assert_eq!(conn.path(), "/some/path");
342 /// conn.with_status(200)
343 /// })
344 /// .get("/some/path?and&a=query")
345 /// .block()
346 /// .assert_ok();
347 /// ```
348 pub fn path(&self) -> &str {
349 match self.path.split_once('?') {
350 Some((path, _)) => path,
351 None => &self.path,
352 }
353 }
354
355 /// retrieves the combined path and any query
356 pub fn path_and_query(&self) -> &str {
357 &self.path
358 }
359
360 /// retrieves the query component of the path, or an empty &str
361 ///
362 /// ```
363 /// # use trillium_testing::HttpTest;
364 /// let server = HttpTest::new(|conn| async move {
365 /// let querystring = conn.querystring().to_string();
366 /// conn.with_response_body(querystring).with_status(200)
367 /// });
368 ///
369 /// server
370 /// .get("/some/path?and&a=query")
371 /// .block()
372 /// .assert_body("and&a=query");
373 ///
374 /// server.get("/some/path").block().assert_body("");
375 /// ```
376 pub fn querystring(&self) -> &str {
377 self.path
378 .split_once('?')
379 .map(|(_, query)| query)
380 .unwrap_or_default()
381 }
382
383 /// get the host for this conn, if it exists
384 pub fn host(&self) -> Option<&str> {
385 self.request_headers.get_str(Host)
386 }
387
388 /// set the host for this conn
389 pub fn set_host(&mut self, host: String) -> &mut Self {
390 self.request_headers.insert(Host, host);
391 self
392 }
393
394 /// Cancels and drops the future if reading from the transport results in an error or empty read
395 ///
396 /// The use of this method is not advised if your connected http client employs pipelining
397 /// (rarely seen in the wild), as it will buffer an unbounded number of requests one byte at a
398 /// time
399 ///
400 /// If the client disconnects from the conn's transport, this function will return None. If the
401 /// future completes without disconnection, this future will return Some containing the output
402 /// of the future.
403 ///
404 /// Note that the inner future cannot borrow conn, so you will need to clone or take any
405 /// information needed to execute the future prior to executing this method.
406 ///
407 /// # Example
408 ///
409 /// ```rust
410 /// # use futures_lite::{AsyncRead, AsyncWrite};
411 /// # use trillium_http::{Conn, Method};
412 /// async fn something_slow_and_cancel_safe() -> String {
413 /// String::from("this was not actually slow")
414 /// }
415 /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
416 /// where
417 /// T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
418 /// {
419 /// let Some(returned_body) = conn
420 /// .cancel_on_disconnect(async { something_slow_and_cancel_safe().await })
421 /// .await
422 /// else {
423 /// return conn;
424 /// };
425 /// conn.with_response_body(returned_body).with_status(200)
426 /// }
427 /// ```
428 pub async fn cancel_on_disconnect<'a, Fut>(&'a mut self, fut: Fut) -> Option<Fut::Output>
429 where
430 Fut: Future + Send + 'a,
431 {
432 CancelOnDisconnect(self, pin!(fut)).await
433 }
434
435 /// Check if the transport is connected by attempting to read from the transport
436 ///
437 /// # Example
438 ///
439 /// This is best to use at appropriate points in a long-running handler, like:
440 ///
441 /// ```rust
442 /// # use futures_lite::{AsyncRead, AsyncWrite};
443 /// # use trillium_http::{Conn, Method};
444 /// # async fn something_slow_but_not_cancel_safe() {}
445 /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
446 /// where
447 /// T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
448 /// {
449 /// for _ in 0..100 {
450 /// if conn.is_disconnected().await {
451 /// return conn;
452 /// }
453 /// something_slow_but_not_cancel_safe().await;
454 /// }
455 /// conn.with_status(200)
456 /// }
457 /// ```
458 pub async fn is_disconnected(&mut self) -> bool {
459 future::poll_once(LivenessFut::new(self)).await.is_some()
460 }
461
462 /// returns the [`encoding_rs::Encoding`] for this request, as determined from the mime-type
463 /// charset, if available
464 ///
465 /// ```
466 /// # use trillium_testing::HttpTest;
467 /// HttpTest::new(|mut conn| async move {
468 /// assert_eq!(conn.request_encoding(), encoding_rs::WINDOWS_1252); // the default
469 ///
470 /// conn.request_headers_mut()
471 /// .insert("content-type", "text/plain;charset=utf-16");
472 /// assert_eq!(conn.request_encoding(), encoding_rs::UTF_16LE);
473 ///
474 /// conn.with_status(200)
475 /// })
476 /// .get("/")
477 /// .block()
478 /// .assert_ok();
479 /// ```
480 pub fn request_encoding(&self) -> &'static Encoding {
481 encoding(&self.request_headers)
482 }
483
484 /// returns the [`encoding_rs::Encoding`] for this response, as
485 /// determined from the mime-type charset, if available
486 ///
487 /// ```
488 /// # use trillium_testing::HttpTest;
489 /// HttpTest::new(|mut conn| async move {
490 /// assert_eq!(conn.response_encoding(), encoding_rs::WINDOWS_1252); // the default
491 /// conn.response_headers_mut()
492 /// .insert("content-type", "text/plain;charset=utf-16");
493 ///
494 /// assert_eq!(conn.response_encoding(), encoding_rs::UTF_16LE);
495 ///
496 /// conn.with_status(200)
497 /// })
498 /// .get("/")
499 /// .block()
500 /// .assert_ok();
501 /// ```
502 pub fn response_encoding(&self) -> &'static Encoding {
503 encoding(&self.response_headers)
504 }
505
506 /// returns a [`ReceivedBody`] that references this conn. the conn
507 /// retains all data and holds the singular transport, but the
508 /// `ReceivedBody` provides an interface to read body content.
509 ///
510 /// If the request included an `Expect: 100-continue` header, the 100 Continue response is sent
511 /// lazily on the first read from the returned [`ReceivedBody`].
512 /// ```
513 /// # use trillium_testing::HttpTest;
514 /// let server = HttpTest::new(|mut conn| async move {
515 /// let request_body = conn.request_body();
516 /// assert_eq!(request_body.content_length(), Some(5));
517 /// assert_eq!(request_body.read_string().await.unwrap(), "hello");
518 /// conn.with_status(200)
519 /// });
520 ///
521 /// server.post("/").with_body("hello").block().assert_ok();
522 /// ```
523 pub fn request_body(&mut self) -> ReceivedBody<'_, Transport> {
524 let needs_100_continue = self.needs_100_continue();
525 let body = self.build_request_body();
526 if needs_100_continue {
527 body.with_send_100_continue()
528 } else {
529 body
530 }
531 }
532
533 /// returns a clone of the [`swansong::Swansong`] for this Conn. use
534 /// this to gracefully stop long-running futures and streams
535 /// inside of handler functions
536 pub fn swansong(&self) -> Swansong {
537 self.protocol_session
538 .h3_connection()
539 .map_or_else(|| self.context.swansong.clone(), |h| h.swansong().clone())
540 }
541
542 /// Registers a function to call after the http response has been
543 /// completely transferred.
544 ///
545 /// The callback is guaranteed to fire **exactly once** before the conn is
546 /// dropped. Either the codec's send path invokes it with the real outcome,
547 /// or — if the conn is dropped before send completes (handler panic,
548 /// transport error, mid-write disconnect) — the drop fallback invokes it
549 /// with a `SendStatus` whose `is_success()` returns false. Multiple
550 /// registrations on the same conn chain in registration order.
551 ///
552 /// Because firing is ordered by send-completion rather than handler return,
553 /// this is the right hook for instrumentation that wants to report what the
554 /// peer actually observed.
555 ///
556 /// This is a sync function and should be computationally lightweight. If
557 /// your _application_ needs additional async processing, use your runtime's
558 /// task spawn within this hook. If your _library_ needs additional async
559 /// processing in an `after_send` hook, please open an issue.
560 pub fn after_send<F>(&mut self, after_send: F)
561 where
562 F: FnOnce(SendStatus) + Send + Sync + 'static,
563 {
564 self.after_send.append(after_send);
565 }
566
567 /// applies a mapping function from one transport to another. This
568 /// is particularly useful for boxing the transport. unless you're
569 /// sure this is what you're looking for, you probably don't want
570 /// to be using this
571 pub fn map_transport<NewTransport>(
572 self,
573 f: impl Fn(Transport) -> NewTransport,
574 ) -> Conn<NewTransport>
575 where
576 NewTransport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
577 {
578 // Manual respread: rustc treats `Conn<Transport>` and `Conn<NewTransport>` as
579 // disjoint types and rejects `..self` without the unstable
580 // `type_changing_struct_update` feature. If a new field is added to `Conn`,
581 // update this respread, `Upgrade::map_transport`, and `From<Conn> for Upgrade`
582 // (`upgrade.rs`) — they share this drift hazard.
583 Conn {
584 context: self.context,
585 request_headers: self.request_headers,
586 response_headers: self.response_headers,
587 method: self.method,
588 response_body: self.response_body,
589 path: self.path,
590 status: self.status,
591 version: self.version,
592 state: self.state,
593 transport: f(self.transport),
594 buffer: self.buffer,
595 request_body_state: self.request_body_state,
596 secure: self.secure,
597 after_send: self.after_send,
598 start_time: self.start_time,
599 peer_ip: self.peer_ip,
600 authority: self.authority,
601 scheme: self.scheme,
602 protocol: self.protocol,
603 protocol_session: self.protocol_session,
604 request_trailers: self.request_trailers,
605 upgrade: self.upgrade,
606 }
607 }
608
609 /// whether this conn is suitable for an http upgrade to another protocol
610 pub fn should_upgrade(&self) -> bool {
611 self.upgrade
612 || (self.method() == Method::Connect && self.status == Some(Status::Ok))
613 || self.status == Some(Status::SwitchingProtocols)
614 }
615
616 /// Mark this conn to be handed off as an upgrade once the response headers are sent.
617 /// Set the response status (typically `200`) and any headers describing the upgraded
618 /// byte stream before calling; the handler's `upgrade` method receives an [`Upgrade`]
619 /// with per-protocol framing applied on its `AsyncRead`/`AsyncWrite`.
620 #[doc(hidden)]
621 #[must_use]
622 pub fn upgrade(mut self) -> Self {
623 self.upgrade = true;
624 self
625 }
626
627 #[doc(hidden)]
628 pub fn finalize_headers(&mut self) {
629 if self.version == Version::Http3 {
630 self.finalize_response_headers_h3();
631 } else {
632 self.finalize_response_headers_1x();
633 }
634 }
635
636 /// the [`H2Connection`] driver for this conn, if this is an HTTP/2 request
637 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
638 self.protocol_session.h2_connection()
639 }
640
641 /// the h2 stream id for this conn, if this is an HTTP/2 request
642 pub fn h2_stream_id(&self) -> Option<u32> {
643 self.protocol_session.h2_stream_id()
644 }
645
646 /// the [`H3Connection`] driver for this conn, if this is an HTTP/3 request
647 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
648 self.protocol_session.h3_connection()
649 }
650
651 /// the h3 stream id for this conn, if this is an HTTP/3 request
652 pub fn h3_stream_id(&self) -> Option<u64> {
653 self.protocol_session.h3_stream_id()
654 }
655}