1use crate::{
2 Buffer, Conn, Headers, HttpContext, KnownHeaderName, Method, ProtocolSession, ReceivedBody,
3 Status, TypeSet, Version,
4 h2::H2Connection,
5 h3::{Frame, H3Connection, H3Settings},
6 headers::qpack::{FieldSection, PseudoHeaders},
7 received_body::{H3TrailerFuture, ReceivedBodyState, write_chunk},
8 util::encoding,
9};
10use encoding_rs::Encoding;
11use fieldwork::Fieldwork;
12use futures_lite::{
13 AsyncWriteExt,
14 io::{AsyncRead, AsyncWrite},
15};
16use std::{
17 borrow::Cow,
18 fmt::{self, Debug, Formatter},
19 io::{self, IoSlice, Write},
20 net::IpAddr,
21 pin::Pin,
22 str,
23 sync::Arc,
24 task::{Context, Poll, ready},
25 time::Instant,
26};
27
/// How bytes written to an `Upgrade` are framed before they reach the transport.
///
/// The variant is chosen by `compute_write_state` from the HTTP version and the
/// outbound headers.
#[derive(Debug)]
pub(crate) enum WriteState {
    /// Writes are forwarded to the transport unchanged.
    Raw,
    /// Each write is wrapped as an HTTP/1.x chunk (`Transfer-Encoding: chunked`).
    H1Chunked(H1ChunkedState),
    /// Each write is prefixed with an HTTP/3 DATA frame header.
    H3Framed(H3FramedState),
}
40
/// Write-side bookkeeping for `WriteState::H1Chunked`.
#[derive(Debug, Default)]
pub(crate) struct H1ChunkedState {
    /// Encoded bytes that were accepted from the caller but have not yet been
    /// written to the transport.
    pub(crate) pending: Vec<u8>,
    /// Set once the terminating zero-length chunk (with or without trailers) has
    /// been queued; later non-empty writes fail with `BrokenPipe`.
    pub(crate) terminator_written: bool,
}
46
/// Write-side bookkeeping for `WriteState::H3Framed`.
#[derive(Debug, Default)]
pub(crate) struct H3FramedState {
    /// Encoded frame bytes that were accepted from the caller but have not yet
    /// been written to the transport.
    pub(crate) pending: Vec<u8>,
    /// Set once trailers have been encoded or the stream has been closed; later
    /// non-empty writes fail with `BrokenPipe`.
    pub(crate) terminator_written: bool,
}
52
53fn compute_write_state(version: Version, outbound_headers: &Headers) -> WriteState {
57 match version {
58 Version::Http1_0 | Version::Http1_1 if has_chunked_encoding(outbound_headers) => {
59 WriteState::H1Chunked(H1ChunkedState::default())
60 }
61 Version::Http3 => WriteState::H3Framed(H3FramedState::default()),
62 _ => WriteState::Raw,
63 }
64}
65
66fn has_chunked_encoding(headers: &Headers) -> bool {
69 headers
70 .get_str(KnownHeaderName::TransferEncoding)
71 .is_some_and(|v| {
72 v.split(',')
73 .any(|coding| coding.trim().eq_ignore_ascii_case("chunked"))
74 })
75}
76
77fn parse_content_length(inbound_headers: &Headers) -> Option<u64> {
80 if inbound_headers.has_header(KnownHeaderName::TransferEncoding) {
81 return None;
82 }
83 let raw = inbound_headers.get_str(KnownHeaderName::ContentLength)?;
84 match raw.parse() {
85 Ok(n) => Some(n),
86 Err(e) => {
87 log::warn!(
88 "Upgrade: ignoring unparseable Content-Length {raw:?}: {e}; inbound length \
89 validation disabled for this upgrade"
90 );
91 None
92 }
93 }
94}
95
96fn poll_drain_pending<T: AsyncWrite + Unpin>(
98 pending: &mut Vec<u8>,
99 cx: &mut Context<'_>,
100 transport: &mut T,
101) -> Poll<io::Result<()>> {
102 while !pending.is_empty() {
103 match Pin::new(&mut *transport).poll_write(cx, pending) {
104 Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
105 Poll::Ready(Ok(n)) => {
106 pending.drain(..n);
107 }
108 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
109 Poll::Pending => return Poll::Pending,
110 }
111 }
112 Poll::Ready(Ok(()))
113}
114
115fn best_effort_drain<T: AsyncWrite + Unpin>(
118 pending: &mut Vec<u8>,
119 cx: &mut Context<'_>,
120 transport: &mut T,
121) -> io::Result<()> {
122 while !pending.is_empty() {
123 match Pin::new(&mut *transport).poll_write(cx, pending) {
124 Poll::Ready(Ok(0)) => return Err(io::ErrorKind::WriteZero.into()),
125 Poll::Ready(Ok(n)) => {
126 pending.drain(..n);
127 }
128 Poll::Ready(Err(e)) => return Err(e),
129 Poll::Pending => break,
130 }
131 }
132 Ok(())
133}
134
135fn encode_h3_data_header(out: &mut Vec<u8>, payload_len: u64) {
138 let frame = Frame::Data(payload_len);
139 let header_len = frame.encoded_len();
140 let start = out.len();
141 out.resize(start + header_len, 0);
142 frame.encode(&mut out[start..]);
143}
144
/// An HTTP exchange whose transport has been handed over for direct reading and
/// writing.
///
/// Built from a `Conn` via `From<Conn>`, or directly through `new` /
/// `from_parts`. Reads go through `ReceivedBody`; writes are framed according
/// to `write_state`.
#[derive(Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// Headers received from the peer (the `Conn`'s request headers when
    /// converted from a `Conn`).
    #[field(deprecate(was = "request_headers", since = "1.3.0"))]
    pub(crate) received_headers: Headers,

    /// Headers sent to the peer (the `Conn`'s response headers when converted
    /// from a `Conn`).
    #[field(deprecate(was = "response_headers", since = "1.3.0"))]
    pub(crate) sent_headers: Headers,

    /// The stored path, which may carry a `?query` suffix. The generated getter
    /// is disabled; the hand-written `path()` and `querystring()` split it.
    #[field(get = false)]
    pub(crate) path: Cow<'static, str>,

    /// The HTTP method.
    #[field(copy)]
    pub(crate) method: Method,

    /// The `TypeSet` carried over from the `Conn`.
    pub(crate) state: TypeSet,

    /// The underlying transport that reads and writes are forwarded to.
    pub(crate) transport: Transport,

    /// Buffer handed to `ReceivedBody` alongside the transport on every read.
    /// Presumably bytes already read from the transport but not yet consumed;
    /// TODO confirm.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    pub(crate) buffer: Buffer,

    /// Shared `HttpContext`; source of `shared_state()` and of the
    /// configuration used for reads and HTTP/3 trailer encoding.
    #[field(deref = false)]
    pub(crate) context: Arc<HttpContext>,

    /// The peer's IP address, if known.
    #[field(copy)]
    pub(crate) peer_ip: Option<IpAddr>,

    /// Start instant: carried over from the `Conn`, or the moment of
    /// construction for `new` / `from_parts`.
    #[field(copy)]
    pub(crate) start_time: Instant,

    /// Authority carried over from the `Conn`, if any. Presumably the `Host` /
    /// `:authority` value; TODO confirm.
    pub(crate) authority: Option<Cow<'static, str>>,

    /// Scheme carried over from the `Conn`, if any.
    pub(crate) scheme: Option<Cow<'static, str>>,

    /// Protocol-level session handle; backs the `h2_*` / `h3_*` accessors and
    /// trailer submission.
    #[field = false]
    pub(crate) protocol_session: ProtocolSession,

    /// Protocol value carried over from the `Conn`, if any. Presumably the
    /// extended-CONNECT `:protocol` pseudo-header; TODO confirm.
    pub(crate) protocol: Option<Cow<'static, str>>,

    /// The HTTP version. NOTE(review): the attribute appears to rename the
    /// generated accessors to `http_version`; confirm against fieldwork docs.
    #[field = "http_version"]
    pub(crate) version: Version,

    /// The response status, if one was set.
    #[field(copy)]
    pub(crate) status: Option<Status>,

    /// Secure flag carried over from the `Conn`. Presumably whether the
    /// connection is TLS-protected; TODO confirm.
    pub(crate) secure: bool,

    /// Read-side progress, lent to `ReceivedBody` on every `poll_read`.
    #[field = false]
    pub(crate) received_body_state: ReceivedBodyState,

    /// Trailers received from the peer, if any. `From<Conn>` normalizes an empty
    /// trailer set to `None`.
    #[field(get, get_mut, take, set = false, with = false, into_field = false)]
    pub(crate) received_trailers: Option<Headers>,

    /// Parsed inbound `Content-Length`; `None` when absent, unparseable, or when
    /// a `Transfer-Encoding` header is present.
    #[field = false]
    pub(crate) content_length_in: Option<u64>,

    /// Outbound framing mode and its pending bytes.
    #[field = false]
    pub(crate) write_state: WriteState,

    /// Character encoding passed to `ReceivedBody`; derived from the received
    /// headers, or `WINDOWS_1252` for `new`.
    #[field = false]
    pub(crate) inbound_encoding: &'static Encoding,

    /// Slot lent to `ReceivedBody` for an HTTP/3 trailer decode future. Starts
    /// out as `None`.
    #[field = false]
    pub(crate) h3_trailer_decode_in: Option<H3TrailerFuture>,

    /// Buffer lent to `ReceivedBody` for HTTP/3 trailer payload bytes. Starts
    /// out empty.
    #[field = false]
    pub(crate) h3_trailer_payload_in: Vec<u8>,
}
257
258impl<Transport> Upgrade<Transport> {
259 #[doc(hidden)]
260 pub fn new(
261 received_headers: Headers,
262 path: impl Into<Cow<'static, str>>,
263 method: Method,
264 transport: Transport,
265 buffer: Buffer,
266 version: Version,
267 ) -> Self {
268 Self {
269 received_headers,
270 sent_headers: Headers::new(),
271 path: path.into(),
272 method,
273 transport,
274 buffer,
275 state: TypeSet::new(),
276 context: Arc::default(),
277 peer_ip: None,
278 start_time: Instant::now(),
279 authority: None,
280 scheme: None,
281 protocol_session: ProtocolSession::Http1,
282 protocol: None,
283 secure: false,
284 version,
285 status: None,
286 received_body_state: ReceivedBodyState::Raw { total: 0 },
287 received_trailers: None,
288 content_length_in: None,
289 write_state: WriteState::Raw,
290 inbound_encoding: encoding_rs::WINDOWS_1252,
291 h3_trailer_decode_in: None,
292 h3_trailer_payload_in: Vec::new(),
293 }
294 }
295
296 #[cfg(feature = "unstable")]
297 #[doc(hidden)]
298 #[allow(clippy::too_many_arguments)]
299 pub fn from_parts(
300 received_headers: Headers,
301 sent_headers: Headers,
302 path: Cow<'static, str>,
303 method: Method,
304 transport: Transport,
305 buffer: Buffer,
306 state: TypeSet,
307 context: Arc<HttpContext>,
308 peer_ip: Option<IpAddr>,
309 authority: Option<Cow<'static, str>>,
310 scheme: Option<Cow<'static, str>>,
311 protocol_session: ProtocolSession,
312 protocol: Option<Cow<'static, str>>,
313 version: Version,
314 status: Option<Status>,
315 secure: bool,
316 received_body_state: ReceivedBodyState,
317 received_trailers: Option<Headers>,
318 ) -> Self {
319 let write_state = compute_write_state(version, &sent_headers);
320 let content_length_in = parse_content_length(&received_headers);
321 let inbound_encoding = encoding(&received_headers);
322
323 Self {
324 received_headers,
325 sent_headers,
326 path,
327 method,
328 state,
329 transport,
330 buffer,
331 context,
332 peer_ip,
333 start_time: Instant::now(),
334 authority,
335 scheme,
336 protocol_session,
337 protocol,
338 version,
339 status,
340 secure,
341 received_body_state,
342 received_trailers,
343 content_length_in,
344 write_state,
345 inbound_encoding,
346 h3_trailer_decode_in: None,
347 h3_trailer_payload_in: Vec::new(),
348 }
349 }
350
351 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
353 self.protocol_session.h2_connection()
354 }
355
356 pub fn h2_stream_id(&self) -> Option<u32> {
358 self.protocol_session.h2_stream_id()
359 }
360
361 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
363 self.protocol_session.h3_connection()
364 }
365
366 pub fn h3_stream_id(&self) -> Option<u64> {
368 self.protocol_session.h3_stream_id()
369 }
370
371 pub fn take_buffer(&mut self) -> Vec<u8> {
373 std::mem::take(&mut self.buffer).into()
374 }
375
376 #[doc(hidden)]
377 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
378 (&mut self.buffer, &mut self.transport)
379 }
380
381 pub fn shared_state(&self) -> &TypeSet {
383 self.context.shared_state()
384 }
385
386 pub fn path(&self) -> &str {
388 match self.path.split_once('?') {
389 Some((path, _)) => path,
390 None => &self.path,
391 }
392 }
393
394 pub fn querystring(&self) -> &str {
396 self.path
397 .split_once('?')
398 .map(|(_, query)| query)
399 .unwrap_or_default()
400 }
401
402 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
406 self,
407 f: impl Fn(Transport) -> T,
408 ) -> Upgrade<T> {
409 Upgrade {
413 transport: f(self.transport),
414 path: self.path,
415 method: self.method,
416 state: self.state,
417 buffer: self.buffer,
418 received_headers: self.received_headers,
419 sent_headers: self.sent_headers,
420 context: self.context,
421 peer_ip: self.peer_ip,
422 start_time: self.start_time,
423 authority: self.authority,
424 scheme: self.scheme,
425 protocol_session: self.protocol_session,
426 protocol: self.protocol,
427 version: self.version,
428 status: self.status,
429 secure: self.secure,
430 received_body_state: self.received_body_state,
431 received_trailers: self.received_trailers,
432 content_length_in: self.content_length_in,
433 write_state: self.write_state,
434 inbound_encoding: self.inbound_encoding,
435 h3_trailer_decode_in: self.h3_trailer_decode_in,
436 h3_trailer_payload_in: self.h3_trailer_payload_in,
437 }
438 }
439}
440
441impl<Transport: AsyncWrite + Unpin> Upgrade<Transport> {
442 pub async fn send_trailers(self, trailers: Headers) -> io::Result<()> {
462 let Self {
463 mut transport,
464 mut write_state,
465 context,
466 protocol_session,
467 ..
468 } = self;
469
470 match &mut write_state {
471 WriteState::H1Chunked(state) => {
472 if state.terminator_written {
473 return Err(io::ErrorKind::BrokenPipe.into());
474 }
475 state.pending.extend_from_slice(b"0\r\n");
476 crate::conn::write_headers_or_trailers(&mut state.pending, &trailers, &context)
477 .map_err(io::Error::other)?;
478 state.pending.extend_from_slice(b"\r\n");
479 state.terminator_written = true;
480
481 transport.write_all(&state.pending).await?;
482 state.pending.clear();
483 transport.close().await
484 }
485 WriteState::H3Framed(state) => {
486 if state.terminator_written {
487 return Err(io::ErrorKind::BrokenPipe.into());
488 }
489 let Some((h3, stream_id)) = protocol_session.as_h3() else {
490 return Err(io::ErrorKind::NotConnected.into());
491 };
492 let max_field_section = h3
493 .peer_settings()
494 .and_then(H3Settings::max_field_section_size);
495 let initial_cap = context.config.request_buffer_initial_len;
496 let field_section = FieldSection::new(PseudoHeaders::default(), &trailers);
497 crate::conn::encode_field_section_h3(
498 &h3,
499 &field_section,
500 max_field_section,
501 initial_cap,
502 &mut state.pending,
503 stream_id,
504 )?;
505 state.terminator_written = true;
506
507 transport.write_all(&state.pending).await?;
508 state.pending.clear();
509 transport.close().await
510 }
511 WriteState::Raw => {
512 if let Some((h2, stream_id)) = protocol_session.as_h2() {
513 h2.submit_trailers(stream_id, trailers)
514 } else {
515 log::warn!(
516 "Upgrade::send_trailers called on a raw upgrade with no per-stream \
517 framing; trailers dropped. Set `Transfer-Encoding: chunked` on the \
518 outbound headers if you intend to emit trailers over HTTP/1.1."
519 );
520 Ok(())
521 }
522 }
523 }
524 }
525}
526
527impl<Transport> Debug for Upgrade<Transport> {
528 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
529 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
530 .field("received_headers", &self.received_headers)
531 .field("sent_headers", &self.sent_headers)
532 .field("path", &self.path)
533 .field("method", &self.method)
534 .field("buffer", &self.buffer)
535 .field("context", &self.context)
536 .field("state", &self.state)
537 .field("transport", &format_args!(".."))
538 .field("peer_ip", &self.peer_ip)
539 .field("start_time", &self.start_time)
540 .field("authority", &self.authority)
541 .field("scheme", &self.scheme)
542 .field("protocol_session", &self.protocol_session)
543 .field("protocol", &self.protocol)
544 .field("version", &self.version)
545 .field("status", &self.status)
546 .field("secure", &self.secure)
547 .field("received_body_state", &self.received_body_state)
548 .field("received_trailers", &self.received_trailers)
549 .field("content_length_in", &self.content_length_in)
550 .field("write_state", &self.write_state)
551 .field("inbound_encoding", &self.inbound_encoding.name())
552 .field(
553 "h3_trailer_decode_in",
554 &self
555 .h3_trailer_decode_in
556 .as_ref()
557 .map(|_| format_args!("..")),
558 )
559 .field(
560 "h3_trailer_payload_in_len",
561 &self.h3_trailer_payload_in.len(),
562 )
563 .finish()
564 }
565}
566
567impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
568 fn from(conn: Conn<Transport>) -> Self {
569 let Conn {
572 request_headers,
573 response_headers,
574 path,
575 method,
576 state,
577 transport,
578 buffer,
579 context,
580 peer_ip,
581 start_time,
582 authority,
583 scheme,
584 protocol_session,
585 protocol,
586 version,
587 status,
588 secure,
589 request_body_state,
590 request_trailers,
591 response_body,
592 after_send: _,
594 upgrade: _,
595 } = conn;
596
597 if let Some(body) = &response_body
598 && !body.is_empty()
599 {
600 log::warn!(
601 "Conn::upgrade() and a non-empty response body are both set; body is being \
602 discarded. The upgrade path is mutually exclusive with serving a response body."
603 );
604 }
605
606 let write_state = compute_write_state(version, &response_headers);
608 let content_length_in = parse_content_length(&request_headers);
609 let inbound_encoding = encoding(&request_headers);
610 let received_body_state = request_body_state;
611 let received_trailers = request_trailers.filter(|t| !t.is_empty());
612
613 Self {
614 received_headers: request_headers,
615 sent_headers: response_headers,
616 path,
617 method,
618 state,
619 transport,
620 buffer,
621 context,
622 peer_ip,
623 start_time,
624 authority,
625 scheme,
626 protocol_session,
627 protocol,
628 version,
629 status,
630 secure,
631 received_body_state,
632 received_trailers,
633 content_length_in,
634 write_state,
635 inbound_encoding,
636 h3_trailer_decode_in: None,
637 h3_trailer_payload_in: Vec::new(),
638 }
639 }
640}
641
642#[cfg(test)]
643mod tests;
644
/// Reads go through a `ReceivedBody` that is rebuilt on every poll from the
/// state stored on the `Upgrade`.
impl<Transport> AsyncRead for Upgrade<Transport>
where
    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
    /// Builds a temporary `ReceivedBody` over this upgrade's buffer, transport,
    /// and read-side state, then delegates the read to it.
    ///
    /// All progress lives in the fields lent to the `ReceivedBody`, so the
    /// temporary can be dropped and rebuilt between polls.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Split `self` into disjoint mutable borrows so that several fields can
        // be lent to the `ReceivedBody` at once.
        let Self {
            transport,
            buffer,
            received_body_state,
            content_length_in,
            context,
            protocol_session,
            received_trailers,
            h3_trailer_decode_in,
            h3_trailer_payload_in,
            inbound_encoding,
            ..
        } = &mut *self;

        // `with_protocol_session` takes the session by value, so pass a clone.
        let protocol_session = protocol_session.clone();
        let mut body: ReceivedBody<'_, Transport> = ReceivedBody::new_with_config(
            *content_length_in,
            buffer,
            transport,
            received_body_state,
            None,
            inbound_encoding,
            &context.config,
        )
        .with_trailers(received_trailers)
        .with_protocol_session(protocol_session)
        .with_h3_trailer_future(h3_trailer_decode_in)
        .with_h3_trailer_payload_buffer(h3_trailer_payload_in);

        Pin::new(&mut body).poll_read(cx, buf)
    }
}
686
687impl<Transport: AsyncWrite + Unpin> AsyncWrite for Upgrade<Transport> {
688 fn poll_write(
689 mut self: Pin<&mut Self>,
690 cx: &mut Context<'_>,
691 buf: &[u8],
692 ) -> Poll<io::Result<usize>> {
693 let Self {
694 transport,
695 write_state,
696 ..
697 } = &mut *self;
698 match write_state {
699 WriteState::Raw => Pin::new(transport).poll_write(cx, buf),
700 WriteState::H1Chunked(state) => {
701 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
702
703 if buf.is_empty() {
705 return Poll::Ready(Ok(0));
706 }
707
708 if state.terminator_written {
709 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
710 }
711
712 write_chunk(&mut state.pending, buf);
713 best_effort_drain(&mut state.pending, cx, transport)?;
714 Poll::Ready(Ok(buf.len()))
715 }
716 WriteState::H3Framed(state) => {
717 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
718
719 if buf.is_empty() {
720 return Poll::Ready(Ok(0));
721 }
722
723 if state.terminator_written {
724 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
725 }
726
727 encode_h3_data_header(&mut state.pending, buf.len() as u64);
728 state.pending.extend_from_slice(buf);
729 best_effort_drain(&mut state.pending, cx, transport)?;
730 Poll::Ready(Ok(buf.len()))
731 }
732 }
733 }
734
735 fn poll_write_vectored(
736 mut self: Pin<&mut Self>,
737 cx: &mut Context<'_>,
738 bufs: &[IoSlice<'_>],
739 ) -> Poll<io::Result<usize>> {
740 let Self {
741 transport,
742 write_state,
743 ..
744 } = &mut *self;
745 match write_state {
746 WriteState::Raw => Pin::new(transport).poll_write_vectored(cx, bufs),
747 WriteState::H1Chunked(state) => {
748 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
749 let total: usize = bufs.iter().map(|b| b.len()).sum();
750 if total == 0 {
751 return Poll::Ready(Ok(0));
752 }
753 if state.terminator_written {
754 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
755 }
756 let _ = write!(state.pending, "{total:X}\r\n");
759 for b in bufs {
760 state.pending.extend_from_slice(b);
761 }
762 state.pending.extend_from_slice(b"\r\n");
763 best_effort_drain(&mut state.pending, cx, transport)?;
764 Poll::Ready(Ok(total))
765 }
766 WriteState::H3Framed(state) => {
767 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
768 let total: usize = bufs.iter().map(|b| b.len()).sum();
769 if total == 0 {
770 return Poll::Ready(Ok(0));
771 }
772 if state.terminator_written {
773 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
774 }
775 encode_h3_data_header(&mut state.pending, total as u64);
778 for b in bufs {
779 state.pending.extend_from_slice(b);
780 }
781 best_effort_drain(&mut state.pending, cx, transport)?;
782 Poll::Ready(Ok(total))
783 }
784 }
785 }
786
787 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
788 let Self {
789 transport,
790 write_state,
791 ..
792 } = &mut *self;
793 match write_state {
794 WriteState::Raw => Pin::new(transport).poll_flush(cx),
795 WriteState::H1Chunked(state) => {
796 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
797 Pin::new(transport).poll_flush(cx)
798 }
799 WriteState::H3Framed(state) => {
800 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
801 Pin::new(transport).poll_flush(cx)
802 }
803 }
804 }
805
806 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
807 let Self {
808 transport,
809 write_state,
810 ..
811 } = &mut *self;
812 match write_state {
813 WriteState::Raw => Pin::new(transport).poll_close(cx),
814 WriteState::H1Chunked(state) => {
815 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
816 if !state.terminator_written {
817 state.pending.extend_from_slice(b"0\r\n\r\n");
818 state.terminator_written = true;
820 }
821 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
822 Pin::new(transport).poll_close(cx)
823 }
824 WriteState::H3Framed(state) => {
825 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
827 state.terminator_written = true;
828 Pin::new(transport).poll_close(cx)
829 }
830 }
831 }
832}