1use crate::{
2 Buffer, Conn, Headers, HttpContext, KnownHeaderName, Method, ProtocolSession, ReceivedBody,
3 Status, TypeSet, Version,
4 h2::H2Connection,
5 h3::{Frame, H3Connection, H3Settings},
6 headers::qpack::{FieldSection, PseudoHeaders},
7 received_body::{H3TrailerFuture, ReceivedBodyState, write_chunk},
8 util::encoding,
9};
10use encoding_rs::Encoding;
11use fieldwork::Fieldwork;
12use futures_lite::{
13 AsyncWriteExt,
14 io::{AsyncRead, AsyncWrite},
15};
16use std::{
17 borrow::Cow,
18 fmt::{self, Debug, Formatter},
19 io::{self, IoSlice, Write},
20 net::IpAddr,
21 pin::Pin,
22 str,
23 sync::Arc,
24 task::{Context, Poll, ready},
25 time::Instant,
26};
27
28#[derive(Debug)]
31pub(crate) enum WriteState {
32 Raw,
35 H1Chunked(H1ChunkedState),
37 H3Framed(H3FramedState),
39}
40
/// Buffered output and end-of-body tracking for HTTP/1.x chunked writes.
#[derive(Default, Debug)]
pub(crate) struct H1ChunkedState {
    /// Already-framed bytes that the transport has not accepted yet.
    pub(crate) pending: Vec<u8>,
    /// Set once the terminating zero-length chunk has been queued; body writes are
    /// rejected after that point.
    pub(crate) terminator_written: bool,
}
46
/// Buffered output and end-of-body tracking for HTTP/3 DATA-framed writes.
#[derive(Default, Debug)]
pub(crate) struct H3FramedState {
    /// Already-framed bytes that the transport has not accepted yet.
    pub(crate) pending: Vec<u8>,
    /// Set once the outbound body has been finished (by closing or by sending trailers);
    /// body writes are rejected after that point.
    pub(crate) terminator_written: bool,
}
52
53fn compute_write_state(version: Version, outbound_headers: &Headers) -> WriteState {
57 match version {
58 Version::Http1_0 | Version::Http1_1 if has_chunked_encoding(outbound_headers) => {
59 WriteState::H1Chunked(H1ChunkedState::default())
60 }
61 Version::Http3 => WriteState::H3Framed(H3FramedState::default()),
62 _ => WriteState::Raw,
63 }
64}
65
66fn has_chunked_encoding(headers: &Headers) -> bool {
69 headers
70 .get_str(KnownHeaderName::TransferEncoding)
71 .is_some_and(|v| {
72 v.split(',')
73 .any(|coding| coding.trim().eq_ignore_ascii_case("chunked"))
74 })
75}
76
77fn parse_content_length(inbound_headers: &Headers) -> Option<u64> {
80 if inbound_headers.has_header(KnownHeaderName::TransferEncoding) {
81 return None;
82 }
83 let raw = inbound_headers.get_str(KnownHeaderName::ContentLength)?;
84 match raw.parse() {
85 Ok(n) => Some(n),
86 Err(e) => {
87 log::warn!(
88 "Upgrade: ignoring unparseable Content-Length {raw:?}: {e}; inbound length \
89 validation disabled for this upgrade"
90 );
91 None
92 }
93 }
94}
95
96fn poll_drain_pending<T: AsyncWrite + Unpin>(
98 pending: &mut Vec<u8>,
99 cx: &mut Context<'_>,
100 transport: &mut T,
101) -> Poll<io::Result<()>> {
102 while !pending.is_empty() {
103 match Pin::new(&mut *transport).poll_write(cx, pending) {
104 Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
105 Poll::Ready(Ok(n)) => {
106 pending.drain(..n);
107 }
108 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
109 Poll::Pending => return Poll::Pending,
110 }
111 }
112 Poll::Ready(Ok(()))
113}
114
115fn best_effort_drain<T: AsyncWrite + Unpin>(
118 pending: &mut Vec<u8>,
119 cx: &mut Context<'_>,
120 transport: &mut T,
121) -> io::Result<()> {
122 while !pending.is_empty() {
123 match Pin::new(&mut *transport).poll_write(cx, pending) {
124 Poll::Ready(Ok(0)) => return Err(io::ErrorKind::WriteZero.into()),
125 Poll::Ready(Ok(n)) => {
126 pending.drain(..n);
127 }
128 Poll::Ready(Err(e)) => return Err(e),
129 Poll::Pending => break,
130 }
131 }
132 Ok(())
133}
134
135fn encode_h3_data_header(out: &mut Vec<u8>, payload_len: u64) {
138 let frame = Frame::Data(payload_len);
139 let header_len = frame.encoded_len();
140 let start = out.len();
141 out.resize(start + header_len, 0);
142 frame.encode(&mut out[start..]);
143}
144
145#[derive(Fieldwork)]
151#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
152pub struct Upgrade<Transport> {
153 #[field(deprecate(was = "request_headers", since = "1.3.0"))]
155 pub(crate) received_headers: Headers,
156
157 #[field(deprecate(was = "response_headers", since = "1.3.0"))]
160 pub(crate) sent_headers: Headers,
161
162 #[field(get = false)]
164 pub(crate) path: Cow<'static, str>,
165
166 #[field(copy)]
168 pub(crate) method: Method,
169
170 pub(crate) state: TypeSet,
172
173 pub(crate) transport: Transport,
175
176 #[field(deref = "[u8]", into_field = false, set = false, with = false)]
181 pub(crate) buffer: Buffer,
182
183 #[field(deref = false)]
185 pub(crate) context: Arc<HttpContext>,
186
187 #[field(copy)]
189 pub(crate) peer_ip: Option<IpAddr>,
190
191 #[field(copy)]
193 pub(crate) start_time: Instant,
194
195 pub(crate) authority: Option<Cow<'static, str>>,
197
198 pub(crate) scheme: Option<Cow<'static, str>>,
200
201 #[field = false]
204 pub(crate) protocol_session: ProtocolSession,
205
206 pub(crate) protocol: Option<Cow<'static, str>>,
208
209 #[field = "http_version"]
211 pub(crate) version: Version,
212
213 #[field(copy)]
216 pub(crate) status: Option<Status>,
217
218 pub(crate) secure: bool,
220
221 #[field = false]
225 pub(crate) received_body_state: ReceivedBodyState,
226
227 #[field(get, get_mut, take, set = false, with = false, into_field = false)]
230 pub(crate) received_trailers: Option<Headers>,
231
232 #[field = false]
234 pub(crate) content_length_in: Option<u64>,
235
236 #[field = false]
238 pub(crate) write_state: WriteState,
239
240 #[field = false]
243 pub(crate) inbound_encoding: &'static Encoding,
244
245 #[field = false]
249 pub(crate) h3_trailer_decode_in: Option<H3TrailerFuture>,
250
251 #[field = false]
255 pub(crate) h3_trailer_payload_in: Vec<u8>,
256}
257
258impl<Transport> Upgrade<Transport> {
259 #[doc(hidden)]
260 pub fn new(
261 received_headers: Headers,
262 path: impl Into<Cow<'static, str>>,
263 method: Method,
264 transport: Transport,
265 buffer: Buffer,
266 version: Version,
267 ) -> Self {
268 Self {
269 received_headers,
270 sent_headers: Headers::new(),
271 path: path.into(),
272 method,
273 transport,
274 buffer,
275 state: TypeSet::new(),
276 context: Arc::default(),
277 peer_ip: None,
278 start_time: Instant::now(),
279 authority: None,
280 scheme: None,
281 protocol_session: ProtocolSession::Http1,
282 protocol: None,
283 secure: false,
284 version,
285 status: None,
286 received_body_state: ReceivedBodyState::Raw { total: 0 },
287 received_trailers: None,
288 content_length_in: None,
289 write_state: WriteState::Raw,
290 inbound_encoding: encoding_rs::WINDOWS_1252,
291 h3_trailer_decode_in: None,
292 h3_trailer_payload_in: Vec::new(),
293 }
294 }
295
296 #[cfg(feature = "unstable")]
297 #[doc(hidden)]
298 #[allow(clippy::too_many_arguments)]
299 pub fn from_parts(
300 received_headers: Headers,
301 sent_headers: Headers,
302 path: Cow<'static, str>,
303 method: Method,
304 transport: Transport,
305 buffer: Buffer,
306 state: TypeSet,
307 context: Arc<HttpContext>,
308 peer_ip: Option<IpAddr>,
309 authority: Option<Cow<'static, str>>,
310 scheme: Option<Cow<'static, str>>,
311 protocol_session: ProtocolSession,
312 protocol: Option<Cow<'static, str>>,
313 version: Version,
314 status: Option<Status>,
315 secure: bool,
316 received_body_state: ReceivedBodyState,
317 received_trailers: Option<Headers>,
318 ) -> Self {
319 let write_state = compute_write_state(version, &sent_headers);
320 let content_length_in = parse_content_length(&received_headers);
321 let inbound_encoding = encoding(&received_headers);
322
323 Self {
324 received_headers,
325 sent_headers,
326 path,
327 method,
328 state,
329 transport,
330 buffer,
331 context,
332 peer_ip,
333 start_time: Instant::now(),
334 authority,
335 scheme,
336 protocol_session,
337 protocol,
338 version,
339 status,
340 secure,
341 received_body_state,
342 received_trailers,
343 content_length_in,
344 write_state,
345 inbound_encoding,
346 h3_trailer_decode_in: None,
347 h3_trailer_payload_in: Vec::new(),
348 }
349 }
350
351 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
353 self.protocol_session.h2_connection()
354 }
355
356 pub fn h2_stream_id(&self) -> Option<u32> {
358 self.protocol_session.h2_stream_id()
359 }
360
361 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
363 self.protocol_session.h3_connection()
364 }
365
366 pub fn h3_stream_id(&self) -> Option<u64> {
368 self.protocol_session.h3_stream_id()
369 }
370
371 pub fn take_buffer(&mut self) -> Vec<u8> {
373 std::mem::take(&mut self.buffer).into()
374 }
375
376 #[doc(hidden)]
377 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
378 (&mut self.buffer, &mut self.transport)
379 }
380
381 pub fn shared_state(&self) -> &TypeSet {
383 self.context.shared_state()
384 }
385
386 pub fn path(&self) -> &str {
388 match self.path.split_once('?') {
389 Some((path, _)) => path,
390 None => &self.path,
391 }
392 }
393
394 pub fn querystring(&self) -> &str {
396 self.path
397 .split_once('?')
398 .map(|(_, query)| query)
399 .unwrap_or_default()
400 }
401
402 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
406 self,
407 f: impl Fn(Transport) -> T,
408 ) -> Upgrade<T> {
409 Upgrade {
413 transport: f(self.transport),
414 path: self.path,
415 method: self.method,
416 state: self.state,
417 buffer: self.buffer,
418 received_headers: self.received_headers,
419 sent_headers: self.sent_headers,
420 context: self.context,
421 peer_ip: self.peer_ip,
422 start_time: self.start_time,
423 authority: self.authority,
424 scheme: self.scheme,
425 protocol_session: self.protocol_session,
426 protocol: self.protocol,
427 version: self.version,
428 status: self.status,
429 secure: self.secure,
430 received_body_state: self.received_body_state,
431 received_trailers: self.received_trailers,
432 content_length_in: self.content_length_in,
433 write_state: self.write_state,
434 inbound_encoding: self.inbound_encoding,
435 h3_trailer_decode_in: self.h3_trailer_decode_in,
436 h3_trailer_payload_in: self.h3_trailer_payload_in,
437 }
438 }
439}
440
441impl<Transport: AsyncWrite + Unpin> Upgrade<Transport> {
442 pub async fn send_trailers(self, trailers: Headers) -> io::Result<()> {
462 let Self {
463 mut transport,
464 mut write_state,
465 context,
466 protocol_session,
467 ..
468 } = self;
469
470 match &mut write_state {
471 WriteState::H1Chunked(state) => {
472 if state.terminator_written {
473 return Err(io::ErrorKind::BrokenPipe.into());
474 }
475 state.pending.extend_from_slice(b"0\r\n");
476 crate::conn::write_headers_or_trailers(&mut state.pending, &trailers, &context)
477 .map_err(io::Error::other)?;
478 state.pending.extend_from_slice(b"\r\n");
479 state.terminator_written = true;
480
481 transport.write_all(&state.pending).await?;
482 state.pending.clear();
483 transport.close().await
484 }
485 WriteState::H3Framed(state) => {
486 if state.terminator_written {
487 return Err(io::ErrorKind::BrokenPipe.into());
488 }
489 let Some((h3, stream_id)) = protocol_session.as_h3() else {
490 return Err(io::ErrorKind::NotConnected.into());
491 };
492 let max_field_section = h3
493 .peer_settings()
494 .and_then(H3Settings::max_field_section_size);
495 let field_section = FieldSection::new(PseudoHeaders::default(), &trailers);
496 crate::conn::encode_field_section_h3(
497 &h3,
498 &field_section,
499 max_field_section,
500 &mut state.pending,
501 stream_id,
502 )?;
503 state.terminator_written = true;
504
505 transport.write_all(&state.pending).await?;
506 state.pending.clear();
507 transport.close().await
508 }
509 WriteState::Raw => {
510 if let Some((h2, stream_id)) = protocol_session.as_h2() {
511 h2.submit_trailers(stream_id, trailers)
512 } else {
513 log::warn!(
514 "Upgrade::send_trailers called on a raw upgrade with no per-stream \
515 framing; trailers dropped. Set `Transfer-Encoding: chunked` on the \
516 outbound headers if you intend to emit trailers over HTTP/1.1."
517 );
518 Ok(())
519 }
520 }
521 }
522 }
523}
524
525impl<Transport> Debug for Upgrade<Transport> {
526 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
527 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
528 .field("received_headers", &self.received_headers)
529 .field("sent_headers", &self.sent_headers)
530 .field("path", &self.path)
531 .field("method", &self.method)
532 .field("buffer", &self.buffer)
533 .field("context", &self.context)
534 .field("state", &self.state)
535 .field("transport", &format_args!(".."))
536 .field("peer_ip", &self.peer_ip)
537 .field("start_time", &self.start_time)
538 .field("authority", &self.authority)
539 .field("scheme", &self.scheme)
540 .field("protocol_session", &self.protocol_session)
541 .field("protocol", &self.protocol)
542 .field("version", &self.version)
543 .field("status", &self.status)
544 .field("secure", &self.secure)
545 .field("received_body_state", &self.received_body_state)
546 .field("received_trailers", &self.received_trailers)
547 .field("content_length_in", &self.content_length_in)
548 .field("write_state", &self.write_state)
549 .field("inbound_encoding", &self.inbound_encoding.name())
550 .field(
551 "h3_trailer_decode_in",
552 &self
553 .h3_trailer_decode_in
554 .as_ref()
555 .map(|_| format_args!("..")),
556 )
557 .field(
558 "h3_trailer_payload_in_len",
559 &self.h3_trailer_payload_in.len(),
560 )
561 .finish()
562 }
563}
564
565impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
566 fn from(conn: Conn<Transport>) -> Self {
567 let Conn {
570 request_headers,
571 response_headers,
572 path,
573 method,
574 state,
575 transport,
576 buffer,
577 context,
578 peer_ip,
579 start_time,
580 authority,
581 scheme,
582 protocol_session,
583 protocol,
584 version,
585 status,
586 secure,
587 request_body_state,
588 request_trailers,
589 response_body,
590 after_send: _,
592 upgrade: _,
593 } = conn;
594
595 if let Some(body) = &response_body
596 && !body.is_empty()
597 {
598 log::warn!(
599 "Conn::upgrade() and a non-empty response body are both set; body is being \
600 discarded. The upgrade path is mutually exclusive with serving a response body."
601 );
602 }
603
604 let write_state = compute_write_state(version, &response_headers);
606 let content_length_in = parse_content_length(&request_headers);
607 let inbound_encoding = encoding(&request_headers);
608 let received_body_state = request_body_state;
609 let received_trailers = request_trailers.filter(|t| !t.is_empty());
610
611 Self {
612 received_headers: request_headers,
613 sent_headers: response_headers,
614 path,
615 method,
616 state,
617 transport,
618 buffer,
619 context,
620 peer_ip,
621 start_time,
622 authority,
623 scheme,
624 protocol_session,
625 protocol,
626 version,
627 status,
628 secure,
629 received_body_state,
630 received_trailers,
631 content_length_in,
632 write_state,
633 inbound_encoding,
634 h3_trailer_decode_in: None,
635 h3_trailer_payload_in: Vec::new(),
636 }
637 }
638}
639
640#[cfg(test)]
641mod tests;
642
643impl<Transport> AsyncRead for Upgrade<Transport>
644where
645 Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
646{
647 fn poll_read(
648 mut self: Pin<&mut Self>,
649 cx: &mut Context<'_>,
650 buf: &mut [u8],
651 ) -> Poll<io::Result<usize>> {
652 let Self {
653 transport,
654 buffer,
655 received_body_state,
656 content_length_in,
657 context,
658 protocol_session,
659 received_trailers,
660 h3_trailer_decode_in,
661 h3_trailer_payload_in,
662 inbound_encoding,
663 ..
664 } = &mut *self;
665
666 let protocol_session = protocol_session.clone();
667 let mut body: ReceivedBody<'_, Transport> = ReceivedBody::new_with_config(
668 *content_length_in,
669 buffer,
670 transport,
671 received_body_state,
672 None,
673 inbound_encoding,
674 &context.config,
675 )
676 .with_trailers(received_trailers)
677 .with_protocol_session(protocol_session)
678 .with_h3_trailer_future(h3_trailer_decode_in)
679 .with_h3_trailer_payload_buffer(h3_trailer_payload_in);
680
681 Pin::new(&mut body).poll_read(cx, buf)
682 }
683}
684
685impl<Transport: AsyncWrite + Unpin> AsyncWrite for Upgrade<Transport> {
686 fn poll_write(
687 mut self: Pin<&mut Self>,
688 cx: &mut Context<'_>,
689 buf: &[u8],
690 ) -> Poll<io::Result<usize>> {
691 let Self {
692 transport,
693 write_state,
694 ..
695 } = &mut *self;
696 match write_state {
697 WriteState::Raw => Pin::new(transport).poll_write(cx, buf),
698 WriteState::H1Chunked(state) => {
699 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
700
701 if buf.is_empty() {
703 return Poll::Ready(Ok(0));
704 }
705
706 if state.terminator_written {
707 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
708 }
709
710 write_chunk(&mut state.pending, buf);
711 best_effort_drain(&mut state.pending, cx, transport)?;
712 Poll::Ready(Ok(buf.len()))
713 }
714 WriteState::H3Framed(state) => {
715 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
716
717 if buf.is_empty() {
718 return Poll::Ready(Ok(0));
719 }
720
721 if state.terminator_written {
722 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
723 }
724
725 encode_h3_data_header(&mut state.pending, buf.len() as u64);
726 state.pending.extend_from_slice(buf);
727 best_effort_drain(&mut state.pending, cx, transport)?;
728 Poll::Ready(Ok(buf.len()))
729 }
730 }
731 }
732
733 fn poll_write_vectored(
734 mut self: Pin<&mut Self>,
735 cx: &mut Context<'_>,
736 bufs: &[IoSlice<'_>],
737 ) -> Poll<io::Result<usize>> {
738 let Self {
739 transport,
740 write_state,
741 ..
742 } = &mut *self;
743 match write_state {
744 WriteState::Raw => Pin::new(transport).poll_write_vectored(cx, bufs),
745 WriteState::H1Chunked(state) => {
746 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
747 let total: usize = bufs.iter().map(|b| b.len()).sum();
748 if total == 0 {
749 return Poll::Ready(Ok(0));
750 }
751 if state.terminator_written {
752 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
753 }
754 let _ = write!(state.pending, "{total:X}\r\n");
757 for b in bufs {
758 state.pending.extend_from_slice(b);
759 }
760 state.pending.extend_from_slice(b"\r\n");
761 best_effort_drain(&mut state.pending, cx, transport)?;
762 Poll::Ready(Ok(total))
763 }
764 WriteState::H3Framed(state) => {
765 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
766 let total: usize = bufs.iter().map(|b| b.len()).sum();
767 if total == 0 {
768 return Poll::Ready(Ok(0));
769 }
770 if state.terminator_written {
771 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
772 }
773 encode_h3_data_header(&mut state.pending, total as u64);
776 for b in bufs {
777 state.pending.extend_from_slice(b);
778 }
779 best_effort_drain(&mut state.pending, cx, transport)?;
780 Poll::Ready(Ok(total))
781 }
782 }
783 }
784
785 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
786 let Self {
787 transport,
788 write_state,
789 ..
790 } = &mut *self;
791 match write_state {
792 WriteState::Raw => Pin::new(transport).poll_flush(cx),
793 WriteState::H1Chunked(state) => {
794 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
795 Pin::new(transport).poll_flush(cx)
796 }
797 WriteState::H3Framed(state) => {
798 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
799 Pin::new(transport).poll_flush(cx)
800 }
801 }
802 }
803
804 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
805 let Self {
806 transport,
807 write_state,
808 ..
809 } = &mut *self;
810 match write_state {
811 WriteState::Raw => Pin::new(transport).poll_close(cx),
812 WriteState::H1Chunked(state) => {
813 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
814 if !state.terminator_written {
815 state.pending.extend_from_slice(b"0\r\n\r\n");
816 state.terminator_written = true;
818 }
819 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
820 Pin::new(transport).poll_close(cx)
821 }
822 WriteState::H3Framed(state) => {
823 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
825 state.terminator_written = true;
826 Pin::new(transport).poll_close(cx)
827 }
828 }
829 }
830}