1mod options;
2mod tests;
3mod upgrade;
4mod writebuf;
5mod zerocopy;
6
7pub use options::*;
8pub use upgrade::*;
9pub use zerocopy::*;
10
/// OS-level handle handed to zero-copy send callbacks: a raw file descriptor
/// on Unix.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// OS-level handle handed to zero-copy send callbacks: a raw `HANDLE` on
/// Windows.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
15
16use std::{
17 io::IoSlice,
18 mem::MaybeUninit,
19 pin::Pin,
20 str::FromStr,
21 task::{Context, Poll},
22 time::UNIX_EPOCH,
23};
24
25use async_channel::Receiver;
26use bytes::{Buf, Bytes, BytesMut};
27use futures_util::stream::Stream;
28use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
29use http_body::Body;
30use http_body_util::{BodyExt, Empty};
31use memchr::{memchr3_iter, memmem};
32use tokio::io::{AsyncReadExt, AsyncWriteExt};
33use tokio_util::sync::CancellationToken;
34
35use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming};
36
/// Uppercase hexadecimal digit table used by `write_chunk_size` to encode
/// chunk-size lines of the chunked transfer coding.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
38
/// HTTP/1.x server-side connection state over an I/O object `Io`.
pub struct Http1<Io> {
    /// Transport that requests are read from and responses are written to.
    io: Io,
    /// Connection configuration (size limits, timeouts, optional features).
    options: options::Http1Options,
    /// Optional graceful-shutdown token. Once it is cancelled, the connection
    /// stops after finishing the response currently being written.
    cancel_token: Option<CancellationToken>,
    /// Reusable scratch slots for `httparse` request headers, sized by
    /// `options.max_header_count`. The `'static` lifetime is a placeholder:
    /// entries are only meaningful while the head they were parsed from lives.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value and the time it was generated.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map reclaimed from the previous response, reused to hold the
    /// next request's headers without reallocating.
    cached_headers: Option<HeaderMap>,
    /// Bytes read from `io` that have not been consumed yet.
    read_buf: BytesMut,
    /// Pending outgoing I/O slices written/flushed to `io`.
    write_buf: WriteBuf,
}
80
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Converts this handler into its zero-copy variant.
    ///
    /// Only available on Linux with the `h1-zerocopy` feature, and only for
    /// transports that expose their raw handle. The connection state is moved
    /// into the returned [`Http1Zerocopy`] unchanged.
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
105
106impl<Io> Http1<Io>
107where
108 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
109{
110 #[inline]
121 pub fn new(io: Io, options: options::Http1Options) -> Self {
122 let read_buf = BytesMut::with_capacity(options.max_header_size);
124 let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
125 Box::new_uninit_slice(options.max_header_count);
126 Self {
127 io,
128 options,
129 cancel_token: None,
130 parsed_headers,
131 date_header_value_cached: None,
132 cached_headers: None,
133 read_buf,
134 write_buf: WriteBuf::new(),
135 }
136 }
137
138 #[inline]
139 fn get_date_header_value(&mut self) -> &str {
140 let now = std::time::SystemTime::now();
141 if self.date_header_value_cached.as_ref().is_none_or(|v| {
142 v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
143 != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
144 }) {
145 let value = httpdate::fmt_http_date(now).to_string();
146 self.date_header_value_cached = Some((value, now));
147 }
148 self.date_header_value_cached
149 .as_ref()
150 .map(|v| v.0.as_str())
151 .unwrap_or("")
152 }
153
154 #[inline]
164 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
165 self.cancel_token = Some(token);
166 self
167 }
168
169 #[inline]
170 async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
171 if self.read_buf.remaining() < 1024 {
172 self.read_buf.reserve(1024);
173 }
174 let spare_capacity = self.read_buf.spare_capacity_mut();
175 let n = self
177 .io
178 .read(unsafe {
179 &mut *std::ptr::slice_from_raw_parts_mut(
180 spare_capacity.as_mut_ptr() as *mut u8,
181 spare_capacity.len(),
182 )
183 })
184 .await?;
185 if n == 0 {
186 return Ok(0);
187 }
188 unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
189 Ok(n)
190 }
191
192 #[inline]
193 async fn read_body_fn(
194 &mut self,
195 body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
196 content_length: u64,
197 ) -> Result<(), std::io::Error> {
198 let mut remaining = content_length;
199 let mut just_started = true;
200 while remaining > 0 {
201 let have_to_read_buf = !just_started || self.read_buf.is_empty();
202 just_started = false;
203 if have_to_read_buf {
204 let n = self.fill_buf().await?;
205 if n == 0 {
206 break;
207 }
208 }
209 let chunk = self
210 .read_buf
211 .split_to(
212 self.read_buf
213 .len()
214 .min(remaining.min(usize::MAX as u64) as usize),
215 )
216 .freeze();
217 remaining -= chunk.len() as u64;
218
219 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
220 }
221 body_tx.close(); Ok(())
223 }
224
225 #[inline]
226 async fn read_body_chunk(
227 &mut self,
228 would_have_trailers: bool,
229 ) -> Result<bytes::Bytes, std::io::Error> {
230 let len = {
231 let mut len_buf_pos: usize = 0;
233 let mut just_started = true;
234 loop {
235 if len_buf_pos >= 48 {
236 return Err(std::io::Error::new(
237 std::io::ErrorKind::InvalidData,
238 "chunk length buffer overflow",
239 ));
240 }
241
242 let begin_search = len_buf_pos.saturating_sub(1);
243
244 let have_to_read_buf = !just_started || self.read_buf.is_empty();
245 just_started = false;
246 if have_to_read_buf {
247 let n = self.fill_buf().await?;
248 if n == 0 {
249 return Err(std::io::Error::new(
250 std::io::ErrorKind::UnexpectedEof,
251 "unexpected EOF",
252 ));
253 }
254 len_buf_pos += n;
255 } else {
256 len_buf_pos += self.read_buf.len();
257 }
258
259 if let Some(pos) =
260 memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
261 {
262 let numbers =
263 std::str::from_utf8(&self.read_buf[begin_search..begin_search + pos])
264 .map_err(|_| {
265 std::io::Error::new(
266 std::io::ErrorKind::InvalidData,
267 "invalid chunk length",
268 )
269 })?;
270 let len = usize::from_str_radix(numbers, 16).map_err(|_| {
271 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
272 })?;
273 self.read_buf.advance(begin_search + pos + 2);
275 break len;
276 }
277 }
278 };
279 let mut read = 0;
281 if len == 0 && would_have_trailers {
282 return Ok(bytes::Bytes::new()); }
284 let mut just_started = true;
285 while read < len + 2 {
287 let have_to_read_buf = !just_started || self.read_buf.is_empty();
288 just_started = false;
289 if have_to_read_buf {
290 let n = self.fill_buf().await?;
291 if n == 0 {
292 return Err(std::io::Error::new(
293 std::io::ErrorKind::UnexpectedEof,
294 "unexpected EOF",
295 ));
296 }
297 read += n;
298 } else {
299 read += self.read_buf.len();
300 }
301 }
302 let chunk = self.read_buf.split_to(len).freeze();
303 self.read_buf.advance(2); Ok(chunk)
305 }
306
307 #[inline]
308 async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
309 let mut bytes_read: usize = 0;
311 let mut just_started = true;
312 while bytes_read < self.options.max_header_size {
313 let old_bytes_read = bytes_read;
314 let begin_search = old_bytes_read.saturating_sub(3);
315
316 let have_to_read_buf = !just_started || self.read_buf.is_empty();
317 just_started = false;
318 if have_to_read_buf {
319 let n = self.fill_buf().await?;
320 if n == 0 {
321 return Err(std::io::Error::new(
322 std::io::ErrorKind::UnexpectedEof,
323 "unexpected EOF",
324 ));
325 }
326 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
327 } else {
328 bytes_read =
329 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
330 }
331
332 if bytes_read > 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
333 return Ok(None);
335 }
336
337 if let Some(separator_index) =
338 memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
339 {
340 let to_parse_length = begin_search + separator_index + 4;
341 let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
342
343 let mut httparse_trailers =
345 vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
346 let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
347 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
348 if let httparse::Status::Complete((_, trailers)) = status {
349 let mut trailers_constructed = HeaderMap::new();
350 for header in trailers {
351 if header == &httparse::EMPTY_HEADER {
352 break;
354 }
355 let name = HeaderName::from_bytes(header.name.as_bytes())
356 .map_err(|e| std::io::Error::other(e.to_string()))?;
357 let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
358 let value_len = header.value.len();
359 let value = unsafe {
361 HeaderValue::from_maybe_shared_unchecked(
362 buf_ro.slice(value_start..(value_start + value_len)),
363 )
364 };
365 trailers_constructed.append(name, value);
366 }
367
368 return Ok(Some(trailers_constructed));
369 } else {
370 return Err(std::io::Error::new(
371 std::io::ErrorKind::InvalidInput,
372 "trailer headers incomplete",
373 ));
374 }
375 }
376 }
377 Err(std::io::Error::new(
378 std::io::ErrorKind::InvalidData,
379 "request too large",
380 ))
381 }
382
383 #[inline]
384 async fn read_chunked_body_fn(
385 &mut self,
386 body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
387 would_have_trailers: bool,
388 ) -> Result<(), std::io::Error> {
389 loop {
390 let chunk = self.read_body_chunk(would_have_trailers).await?;
391 if chunk.is_empty() {
392 break;
393 }
394
395 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
396 }
397 if would_have_trailers {
398 let trailers = self.read_trailers().await?;
400 if let Some(trailers) = trailers {
401 let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
402 }
403 }
404 body_tx.close(); Ok(())
406 }
407
408 #[inline]
409 async fn read_request(
410 &mut self,
411 ) -> Result<
412 Option<(
413 Request<Incoming>,
414 async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
415 )>,
416 std::io::Error,
417 > {
418 let (request, body_tx) = {
420 let Some((head, headers)) = self.get_head().await? else {
421 return Ok(None);
422 };
423 let headers = unsafe {
425 std::mem::transmute::<
426 &mut [MaybeUninit<httparse::Header<'static>>],
427 &mut [MaybeUninit<httparse::Header<'_>>],
428 >(headers)
429 };
430 let mut req = httparse::Request::new(&mut []);
431 let status = req
432 .parse_with_uninit_headers(&head, headers)
433 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
434 if status.is_partial() {
435 return Err(std::io::Error::new(
436 std::io::ErrorKind::InvalidData,
437 "partial request head",
438 ));
439 }
440
441 let (body_tx, body_rx) = async_channel::bounded(2);
443 let request_body = Http1Body {
444 inner: Box::pin(body_rx),
445 };
446 let mut request = Request::new(Incoming::new(request_body));
447 match req.version {
448 Some(0) => *request.version_mut() = http::Version::HTTP_10,
449 Some(1) => *request.version_mut() = http::Version::HTTP_11,
450 _ => *request.version_mut() = http::Version::HTTP_11,
451 };
452 if let Some(method) = req.method {
453 *request.method_mut() = Method::from_bytes(method.as_bytes())
454 .map_err(|e| std::io::Error::other(e.to_string()))?;
455 }
456 if let Some(path) = req.path {
457 *request.uri_mut() =
458 Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
459 }
460 let mut header_map = self.cached_headers.take().unwrap_or_default();
461 header_map.clear();
462 let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
463 if additional_capacity > 0 {
464 header_map.reserve(additional_capacity);
465 }
466 for header in req.headers {
467 if header == &httparse::EMPTY_HEADER {
468 break;
470 }
471 let name = HeaderName::from_bytes(header.name.as_bytes())
472 .map_err(|e| std::io::Error::other(e.to_string()))?;
473 let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
474 let value_len = header.value.len();
475 let value = unsafe {
477 HeaderValue::from_maybe_shared_unchecked(
478 head.slice(value_start..(value_start + value_len)),
479 )
480 };
481 header_map.append(name, value);
482 }
483 *request.headers_mut() = header_map;
484
485 (request, body_tx)
486 };
487 Ok(Some((request, body_tx)))
488 }
489
490 #[inline]
491 async fn get_head(
492 &mut self,
493 ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
494 {
495 let mut request_line_read = false;
496 let mut bytes_read: usize = 0;
497 let mut whitespace_trimmed = None;
498 let mut just_started = true;
499 while bytes_read < self.options.max_header_size {
500 let old_bytes_read = bytes_read;
501 let begin_search = old_bytes_read.saturating_sub(3);
502
503 let have_to_read_buf = !just_started || self.read_buf.is_empty();
504 just_started = false;
505 if have_to_read_buf {
506 let n = self.fill_buf().await?;
507 if n == 0 {
508 if whitespace_trimmed.is_none() {
509 return Ok(None);
510 }
511 return Err(std::io::Error::new(
512 std::io::ErrorKind::UnexpectedEof,
513 "unexpected EOF",
514 ));
515 }
516 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
517 } else {
518 bytes_read =
519 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
520 }
521
522 if whitespace_trimmed.is_none() {
523 whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
524 .iter()
525 .position(|b| !b.is_ascii_whitespace());
526 }
527
528 if let Some(whitespace_trimmed) = whitespace_trimmed {
529 if !request_line_read {
531 let memchr = memchr3_iter(
532 b' ',
533 b'\r',
534 b'\n',
535 &self.read_buf[whitespace_trimmed..bytes_read],
536 );
537 let mut spaces = 0;
538 for separator_index in memchr {
539 if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
540 if spaces >= 2 {
541 return Err(std::io::Error::new(
542 std::io::ErrorKind::InvalidInput,
543 "bad request first line",
544 ));
545 }
546 spaces += 1;
547 } else if spaces == 2 {
548 request_line_read = true;
549 break;
550 } else {
551 return Err(std::io::Error::new(
552 std::io::ErrorKind::InvalidInput,
553 "bad request first line",
554 ));
555 }
556 }
557 }
558
559 if request_line_read {
560 let begin_search = begin_search.max(whitespace_trimmed);
561 if let Some((separator_index, separator_len)) =
562 search_header_body_separator(&self.read_buf[begin_search..bytes_read])
563 {
564 let to_parse_length =
565 begin_search + separator_index + separator_len - whitespace_trimmed;
566 self.read_buf.advance(whitespace_trimmed);
567 let head = self.read_buf.split_to(to_parse_length);
568 return Ok(Some((head.freeze(), &mut self.parsed_headers)));
569 }
570 }
571 }
572 }
573 Err(std::io::Error::new(
574 std::io::ErrorKind::InvalidData,
575 "request too large",
576 ))
577 }
578
579 #[inline]
580 async fn write_response<Z, ZFut>(
581 &mut self,
582 mut response: Response<
583 impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
584 >,
585 version: Version,
586 write_trailers: bool,
587 zerocopy_fn: Option<Z>,
588 ) -> Result<(), std::io::Error>
589 where
590 Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
591 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
592 {
593 if self.options.send_date_header {
595 response.headers_mut().insert(
596 header::DATE,
597 HeaderValue::from_str(self.get_date_header_value())
598 .map_err(|e| std::io::Error::other(e.to_string()))?,
599 );
600 }
601
602 if let Some(suggested_content_length) = response.body().size_hint().exact() {
604 let headers = response.headers_mut();
605 if !headers.contains_key(header::CONTENT_LENGTH) {
606 headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
607 }
608 }
609
610 let chunked = response
611 .headers()
612 .get(header::TRANSFER_ENCODING)
613 .map(|v| {
614 v.to_str().ok().is_some_and(|s| {
615 s.split(',')
616 .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
617 })
618 })
619 .unwrap_or_else(|| {
620 response
621 .headers()
622 .get(header::CONTENT_LENGTH)
623 .and_then(|v| v.to_str().ok())
624 .is_none_or(|s| s.parse::<u64>().is_err())
625 });
626
627 if chunked {
628 response.headers_mut().insert(
629 header::TRANSFER_ENCODING,
630 HeaderValue::from_static("chunked"),
631 );
632 while response
633 .headers_mut()
634 .remove(header::CONTENT_LENGTH)
635 .is_some()
636 {}
637 }
638
639 let (parts, mut body) = response.into_parts();
640
641 let mut head = Vec::with_capacity(30 + parts.headers.len() * 30); if version == Version::HTTP_10 {
643 head.extend_from_slice(b"HTTP/1.0 ");
644 } else {
645 head.extend_from_slice(b"HTTP/1.1 ");
646 }
647 let status = parts.status;
648 head.extend_from_slice(status.as_str().as_bytes());
649 if let Some(canonical_reason) = status.canonical_reason() {
650 head.extend_from_slice(b" ");
651 head.extend_from_slice(canonical_reason.as_bytes());
652 }
653 head.extend_from_slice(b"\r\n");
654 for (name, value) in &parts.headers {
655 head.extend_from_slice(name.as_str().as_bytes());
656 head.extend_from_slice(b": ");
657 head.extend_from_slice(value.as_bytes());
658 head.extend_from_slice(b"\r\n");
659 }
660 head.extend_from_slice(b"\r\n");
661 unsafe {
662 self.write_buf.push(IoSlice::new(&head));
663 }
664
665 if !chunked {
666 if let Some(content_length) = parts
667 .headers
668 .get(header::CONTENT_LENGTH)
669 .and_then(|v| v.to_str().ok())
670 .and_then(|s| s.parse::<u64>().ok())
671 {
672 if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
673 if let Some(mut zerocopy_fn) = zerocopy_fn {
674 unsafe {
676 self.write_buf
677 .flush(&mut self.io, self.options.enable_vectored_write)
678 .await?
679 };
680 zerocopy_fn(
681 zero_copy.handle,
682 unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
684 content_length,
685 )
686 .await?;
687 self.io.flush().await?;
688 let reclaimed_headers = parts.headers;
689 self.cached_headers = Some(reclaimed_headers);
690 return Ok(());
691 }
692 }
693 }
694 }
695
696 let mut trailers_written = false;
697 while let Some(chunk) = body.frame().await {
698 let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
699 match chunk.into_data() {
700 Ok(data) => {
701 if chunked {
702 let mut data_len_buf = Vec::with_capacity(16);
703 write_chunk_size(&mut data_len_buf, data.len());
704 unsafe {
705 self.write_buf.push(IoSlice::new(&data_len_buf));
706 self.write_buf.push(IoSlice::new(&data));
707 self.write_buf.push(IoSlice::new(b"\r\n"));
708 self.write_buf
709 .write(&mut self.io, self.options.enable_vectored_write)
710 .await?;
711 };
712 } else {
713 unsafe {
714 self.write_buf.push(IoSlice::new(&data));
715 self.write_buf
716 .write(&mut self.io, self.options.enable_vectored_write)
717 .await?;
718 }
719 }
720 }
721 Err(chunk) => {
722 if let Ok(trailers) = chunk.into_trailers() {
723 if write_trailers {
724 unsafe {
725 self.write_buf.push(IoSlice::new(b"0\r\n"));
726 for (name, value) in &trailers {
727 self.write_buf.push(IoSlice::new(name.as_str().as_bytes()));
728 self.write_buf.push(IoSlice::new(b": "));
729 self.write_buf.push(IoSlice::new(value.as_bytes()));
730 self.write_buf.push(IoSlice::new(b"\r\n"));
731 }
732 self.write_buf.push(IoSlice::new(b"\r\n"));
733 self.write_buf
734 .write(&mut self.io, self.options.enable_vectored_write)
735 .await?;
736 }
737 trailers_written = true;
738 }
739 break;
740 }
741 }
742 };
743 }
744 if chunked && !trailers_written {
745 unsafe {
747 self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
748 }
749 }
750 unsafe {
751 self.write_buf
752 .flush(&mut self.io, self.options.enable_vectored_write)
753 .await?;
754 }
755 self.io.flush().await?;
756 let reclaimed_headers = parts.headers;
757 self.cached_headers = Some(reclaimed_headers);
758
759 Ok(())
760 }
761
762 #[inline]
763 async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
764 if version == Version::HTTP_10 {
765 self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
766 } else {
767 self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
768 }
769 self.io.flush().await?;
770
771 Ok(())
772 }
773
774 #[inline]
775 async fn write_early_hints(
776 &mut self,
777 version: Version,
778 headers: http::HeaderMap,
779 ) -> Result<(), std::io::Error> {
780 let mut head = Vec::new();
781 if version == Version::HTTP_10 {
782 head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
783 } else {
784 head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
785 }
786 let mut current_header_name = None;
787 for (name, value) in headers {
788 if let Some(name) = name {
789 current_header_name = Some(name);
790 };
791 if let Some(current_header_name) = ¤t_header_name {
792 head.extend_from_slice(current_header_name.as_str().as_bytes());
793 if value.is_empty() {
794 head.extend_from_slice(b":\r\n");
795 continue;
796 }
797 head.extend_from_slice(b": ");
798 head.extend_from_slice(value.as_bytes());
799 head.extend_from_slice(b"\r\n");
800 }
801 }
802 head.extend_from_slice(b"\r\n");
803
804 self.io.write_all(&head).await?;
805
806 Ok(())
807 }
808
809 #[inline]
810 pub(crate) async fn handle_with_error_fn_and_zerocopy<
811 F,
812 Fut,
813 ResB,
814 ResBE,
815 ResE,
816 EF,
817 EFut,
818 EResB,
819 EResBE,
820 EResE,
821 ZF,
822 ZFut,
823 >(
824 mut self,
825 request_fn: F,
826 error_fn: EF,
827 mut zerocopy_fn: Option<ZF>,
828 ) -> Result<(), std::io::Error>
829 where
830 F: Fn(Request<Incoming>) -> Fut + 'static,
831 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
832 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
833 ResE: std::error::Error,
834 ResBE: std::error::Error,
835 EF: FnOnce(bool) -> EFut,
836 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
837 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
838 EResE: std::error::Error,
839 EResBE: std::error::Error,
840 ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
841 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
842 {
843 let mut keep_alive = true;
844
845 while keep_alive {
846 let (mut request, body_tx) = match if let Some(timeout) =
847 self.options.header_read_timeout
848 {
849 vibeio::time::timeout(timeout, self.read_request()).await
850 } else {
851 Ok(self.read_request().await)
852 } {
853 Ok(Ok(Some(d))) => d,
854 Ok(Ok(None)) => {
855 return Ok(());
856 }
857 Ok(Err(e)) => {
858 if let Ok(mut response) = error_fn(false).await {
860 response
861 .headers_mut()
862 .insert(header::CONNECTION, HeaderValue::from_static("close"));
863
864 let _ = self
865 .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
866 .await;
867 }
868 return Err(e);
869 }
870 Err(_) => {
871 if let Ok(mut response) = error_fn(true).await {
873 response
874 .headers_mut()
875 .insert(header::CONNECTION, HeaderValue::from_static("close"));
876
877 let _ = self
878 .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
879 .await;
880 }
881 return Err(std::io::Error::new(
882 std::io::ErrorKind::TimedOut,
883 "header read timeout",
884 ));
885 }
886 };
887
888 let connection_header_split = request
890 .headers()
891 .get(header::CONNECTION)
892 .and_then(|v| v.to_str().ok())
893 .map(|v| v.split(",").map(|v| v.trim()));
894 let is_connection_close = connection_header_split
895 .clone()
896 .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
897 let is_connection_keep_alive = connection_header_split
898 .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
899 keep_alive = !is_connection_close
900 && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
901
902 let version = request.version();
903
904 if self.options.send_continue_response {
906 let is_100_continue = request
907 .headers()
908 .get(header::EXPECT)
909 .and_then(|v| v.to_str().ok())
910 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
911 if is_100_continue {
912 self.write_100_continue(version).await?;
913 }
914 }
915
916 let early_hints_fut = if self.options.enable_early_hints {
918 let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
919 let early_hints = EarlyHints::new(early_hints_tx);
920 request.extensions_mut().insert(early_hints);
921 let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
925 Some(async {
926 let early_hints_rx = early_hints_rx;
927 while let Ok((headers, sender)) = early_hints_rx.recv().await {
928 sender
929 .into_inner()
930 .send(mut_self.write_early_hints(version, headers).await)
931 .ok();
932 }
933 futures_util::future::pending::<Result<(), std::io::Error>>().await
934 })
935 } else {
936 None
937 };
938
939 let content_length = request
941 .headers()
942 .get(header::CONTENT_LENGTH)
943 .and_then(|v| v.to_str().ok())
944 .and_then(|v| v.parse::<u64>().ok())
945 .unwrap_or(0);
946 let chunked = request
947 .headers()
948 .get(header::TRANSFER_ENCODING)
949 .and_then(|v| v.to_str().ok())
950 .is_some_and(|v| {
951 v.split(',')
952 .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
953 });
954 let has_trailers = request
955 .headers()
956 .get(header::TRAILER)
957 .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
958 .unwrap_or(false);
959 let write_trailers = request
960 .headers()
961 .get(header::TE)
962 .and_then(|v| v.to_str().ok())
963 .map(|v| {
964 v.split(',')
965 .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
966 })
967 .unwrap_or(false);
968
969 let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
971 let upgrade = Upgrade::new(upgrade_rx);
972 let upgraded = upgrade.upgraded.clone();
973 request.extensions_mut().insert(upgrade);
974
975 let mut response = {
977 let read_body_fut = async {
978 if chunked {
979 self.read_chunked_body_fn(&body_tx, has_trailers).await
980 } else {
981 self.read_body_fn(&body_tx, content_length).await
982 }
983 };
984 let read_body_fut_pin = std::pin::pin!(read_body_fut);
985 let request_fut = request_fn(request);
986 let request_fut_pin = std::pin::pin!(request_fut);
987 let early_hints_fut: Pin<
988 Box<dyn std::future::Future<Output = Result<(), std::io::Error>>>,
989 > = if let Some(early_hints) = early_hints_fut {
990 Box::pin(early_hints)
991 } else {
992 Box::pin(futures_util::future::pending::<Result<(), std::io::Error>>())
993 };
994
995 let select_read_body_either =
996 futures_util::future::select(request_fut_pin, early_hints_fut);
997 let select_either =
998 futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
999
1000 let (response, body_fut) = match select_either {
1001 futures_util::future::Either::Left((result, request_fut)) => {
1002 result?;
1003 (
1004 match request_fut.await {
1005 futures_util::future::Either::Left((response, _)) => response,
1006 futures_util::future::Either::Right((_, _)) => unreachable!(),
1007 },
1008 None,
1009 )
1010 }
1011 futures_util::future::Either::Right((response, read_body_fut)) => (
1012 match response {
1013 futures_util::future::Either::Left((response, _)) => response,
1014 futures_util::future::Either::Right((_, _)) => unreachable!(),
1015 },
1016 Some(read_body_fut),
1017 ),
1018 };
1019
1020 if let Some(body_fut) = body_fut {
1022 body_fut.await?;
1023 }
1024
1025 response.map_err(|e| std::io::Error::other(e.to_string()))?
1026 };
1027
1028 let mut was_upgraded = false;
1029 if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
1030 was_upgraded = true;
1031 response
1032 .headers_mut()
1033 .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1034 } else if keep_alive {
1035 if version == Version::HTTP_10
1036 || response.headers().contains_key(header::CONNECTION)
1037 {
1038 response
1039 .headers_mut()
1040 .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
1041 }
1042 } else if version == Version::HTTP_11
1043 || response.headers().contains_key(header::CONNECTION)
1044 {
1045 response
1046 .headers_mut()
1047 .insert(header::CONNECTION, HeaderValue::from_static("close"));
1048 }
1049
1050 self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
1052 .await?;
1053
1054 if was_upgraded {
1055 let frozen_buf = self.read_buf.freeze();
1057 let _ = upgrade_tx.send(Upgraded::new(
1058 self.io,
1059 if frozen_buf.is_empty() {
1060 None
1061 } else {
1062 Some(frozen_buf)
1063 },
1064 ));
1065 return Ok(());
1066 }
1067
1068 if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
1069 break;
1071 }
1072 }
1073 Ok(())
1074 }
1075}
1076
1077impl<Io> HttpProtocol for Http1<Io>
1078where
1079 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1080{
1081 #[inline]
1082 fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1083 self,
1084 request_fn: F,
1085 error_fn: EF,
1086 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1087 where
1088 F: Fn(Request<Incoming>) -> Fut + 'static,
1089 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1090 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1091 ResE: std::error::Error,
1092 ResBE: std::error::Error,
1093 EF: FnOnce(bool) -> EFut,
1094 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1095 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
1096 EResE: std::error::Error,
1097 EResBE: std::error::Error,
1098 {
1099 #[allow(clippy::type_complexity)]
1100 let no_zerocopy: Option<
1101 Box<
1102 dyn FnMut(
1103 RawHandle,
1104 &Io,
1105 u64,
1106 ) -> Box<
1107 dyn std::future::Future<Output = Result<(), std::io::Error>>
1108 + Unpin
1109 + Send
1110 + Sync,
1111 >,
1112 >,
1113 > = None;
1114 self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1115 }
1116
1117 #[inline]
1118 fn handle<F, Fut, ResB, ResBE, ResE>(
1119 self,
1120 request_fn: F,
1121 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1122 where
1123 F: Fn(Request<Incoming>) -> Fut + 'static,
1124 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1125 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1126 ResE: std::error::Error,
1127 ResBE: std::error::Error,
1128 {
1129 self.handle_with_error_fn(request_fn, |is_timeout| async move {
1130 let mut response = Response::builder();
1131 if is_timeout {
1132 response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1133 } else {
1134 response = response.status(http::StatusCode::BAD_REQUEST);
1135 }
1136 response.body(Empty::new())
1137 })
1138 }
1139}
1140
/// Request body handed to handlers.
///
/// Frames (data, trailers, or a read error) arrive over a channel fed by the
/// connection's body-reading task.
struct Http1Body {
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<Receiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
}
1145
1146impl Body for Http1Body {
1147 type Data = bytes::Bytes;
1148 type Error = std::io::Error;
1149
1150 #[inline]
1151 fn poll_frame(
1152 mut self: Pin<&mut Self>,
1153 cx: &mut Context<'_>,
1154 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1155 match self.inner.as_mut().poll_next(cx) {
1156 Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1157 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
1158 Poll::Ready(None) => Poll::Ready(None),
1159 Poll::Pending => Poll::Pending,
1160 }
1161 }
1162}
1163
/// Locates the blank line that terminates an HTTP head.
///
/// Returns `(index, len)`, where `index` is the terminator's offset in
/// `slice` and `len` is its length: 4 for `\r\n\r\n`, 2 for a bare `\n\n`.
/// The slice is scanned left to right and the first terminator wins. Returns
/// `None` if neither form occurs.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    slice.iter().enumerate().find_map(|(i, &byte)| match byte {
        b'\r' if slice[i + 1..].starts_with(b"\n\r\n") => Some((i, 4)),
        b'\n' if slice.get(i + 1) == Some(&b'\n') => Some((i, 2)),
        _ => None,
    })
}
1183
1184#[inline]
1186fn write_chunk_size(dst: &mut Vec<u8>, len: usize) {
1187 let mut buf = [0u8; 18];
1188 let mut n = len;
1189 let mut pos = buf.len();
1190 loop {
1191 pos -= 1;
1192 buf[pos] = HEX_DIGITS[n & 0xF];
1193 n >>= 4;
1194 if n == 0 {
1195 break;
1196 }
1197 }
1198 dst.extend_from_slice(&buf[pos..]);
1199 dst.extend_from_slice(b"\r\n");
1200}