1mod options;
2mod tests;
3mod writebuf;
4mod zerocopy;
5
6pub use options::*;
7pub use zerocopy::*;
8
/// Platform-native handle type handed to zero-copy response senders: a raw file
/// descriptor on Unix.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Platform-native handle type handed to zero-copy response senders: a raw `HANDLE`
/// on Windows.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
13
14use std::{
15 future::Future,
16 io::IoSlice,
17 mem::MaybeUninit,
18 pin::Pin,
19 str::FromStr,
20 task::{Context, Poll},
21 time::UNIX_EPOCH,
22};
23
24use bytes::{Buf, Bytes, BytesMut};
25use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
26use http_body::Body;
27use http_body_util::{BodyExt, Empty};
28use kanal::AsyncReceiver;
29use memchr::{memchr3_iter, memmem};
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio_util::sync::CancellationToken;
32
33use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
34
/// Uppercase hexadecimal digits used when encoding chunk sizes.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// Once at least this many bytes are queued in the write buffer, `write_response`
/// writes them out before polling the response body for more data.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16384;
37
/// An HTTP/1.x server connection: owns the transport plus the buffers that are reused
/// across requests on a keep-alive connection.
pub struct Http1<Io> {
    /// The underlying transport.
    io: Io,
    /// Connection settings (size limits, timeouts, optional features).
    options: options::Http1Options,
    /// When set and cancelled, the request loop stops after the response in flight.
    cancel_token: Option<CancellationToken>,
    /// Scratch header slots handed to httparse; `max_header_count` entries long.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value and the time it was formatted; refreshed
    /// whenever the Unix second changes.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map reclaimed from the previous response, reused for the next request.
    cached_headers: Option<HeaderMap>,
    /// Bytes read from `io` but not yet consumed. Survives between requests (so
    /// pipelined data is kept) and is handed to `Upgraded` on protocol upgrade.
    read_buf: BytesMut,
    /// Reusable buffer for the serialized status line and response headers.
    response_head_buf: Vec<u8>,
    /// Outbound slices queued for (batched) writing to `io`.
    write_buf: WriteBuf,
}
80
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Consumes this connection and wraps it in [`Http1Zerocopy`].
    ///
    /// Only compiled on Linux with the `h1-zerocopy` feature enabled.
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
105
106impl<Io> Http1<Io>
107where
108 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
109{
110 #[inline]
121 pub fn new(io: Io, options: options::Http1Options) -> Self {
122 let read_buf = BytesMut::with_capacity(options.max_header_size);
124 let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
125 Box::new_uninit_slice(options.max_header_count);
126 Self {
127 io,
128 options,
129 cancel_token: None,
130 parsed_headers,
131 date_header_value_cached: None,
132 cached_headers: None,
133 read_buf,
134 response_head_buf: Vec::with_capacity(1024),
135 write_buf: WriteBuf::new(),
136 }
137 }
138
139 #[inline]
140 fn get_date_header_value(&mut self) -> &str {
141 let now = std::time::SystemTime::now();
142 if self.date_header_value_cached.as_ref().is_none_or(|v| {
143 v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
144 != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
145 }) {
146 let value = httpdate::fmt_http_date(now).to_string();
147 self.date_header_value_cached = Some((value, now));
148 }
149 self.date_header_value_cached
150 .as_ref()
151 .map(|v| v.0.as_str())
152 .unwrap_or("")
153 }
154
155 #[inline]
165 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
166 self.cancel_token = Some(token);
167 self
168 }
169
170 #[inline]
171 async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
172 if self.read_buf.remaining() < 1024 {
173 self.read_buf.reserve(1024);
174 }
175 let spare_capacity = self.read_buf.spare_capacity_mut();
176 let n = self
178 .io
179 .read(unsafe {
180 &mut *std::ptr::slice_from_raw_parts_mut(
181 spare_capacity.as_mut_ptr() as *mut u8,
182 spare_capacity.len(),
183 )
184 })
185 .await?;
186 if n == 0 {
187 return Ok(0);
188 }
189 unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
190 Ok(n)
191 }
192
193 #[inline]
194 async fn read_body_fn(
195 &mut self,
196 body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
197 content_length: u64,
198 ) -> Result<(), std::io::Error> {
199 let mut remaining = content_length;
200 let mut just_started = true;
201 while remaining > 0 {
202 let have_to_read_buf = !just_started || self.read_buf.is_empty();
203 just_started = false;
204 if have_to_read_buf {
205 let n = self.fill_buf().await?;
206 if n == 0 {
207 break;
208 }
209 }
210 let chunk = self
211 .read_buf
212 .split_to(
213 self.read_buf
214 .len()
215 .min(remaining.min(usize::MAX as u64) as usize),
216 )
217 .freeze();
218 remaining -= chunk.len() as u64;
219
220 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
221 }
222 Ok(())
223 }
224
225 #[inline]
226 async fn read_body_chunk(
227 &mut self,
228 would_have_trailers: bool,
229 ) -> Result<bytes::Bytes, std::io::Error> {
230 let len = {
231 let mut len_buf_pos: usize = 0;
233 let mut just_started = true;
234 loop {
235 if len_buf_pos >= 48 {
236 return Err(std::io::Error::new(
237 std::io::ErrorKind::InvalidData,
238 "chunk length buffer overflow",
239 ));
240 }
241
242 let begin_search = len_buf_pos.saturating_sub(1);
243
244 let have_to_read_buf = !just_started || self.read_buf.is_empty();
245 just_started = false;
246 if have_to_read_buf {
247 let n = self.fill_buf().await?;
248 if n == 0 {
249 return Err(std::io::Error::new(
250 std::io::ErrorKind::UnexpectedEof,
251 "unexpected EOF",
252 ));
253 }
254 len_buf_pos += n;
255 } else {
256 len_buf_pos += self.read_buf.len();
257 }
258
259 if let Some(pos) =
260 memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
261 {
262 let numbers =
263 std::str::from_utf8(&self.read_buf[begin_search..begin_search + pos])
264 .map_err(|_| {
265 std::io::Error::new(
266 std::io::ErrorKind::InvalidData,
267 "invalid chunk length",
268 )
269 })?;
270 let len = usize::from_str_radix(numbers, 16).map_err(|_| {
271 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
272 })?;
273 self.read_buf.advance(begin_search + pos + 2);
275 break len;
276 }
277 }
278 };
279 let mut read = 0;
281 if len == 0 && would_have_trailers {
282 return Ok(bytes::Bytes::new()); }
284 let mut just_started = true;
285 while read < len + 2 {
287 let have_to_read_buf = !just_started || self.read_buf.is_empty();
288 just_started = false;
289 if have_to_read_buf {
290 let n = self.fill_buf().await?;
291 if n == 0 {
292 return Err(std::io::Error::new(
293 std::io::ErrorKind::UnexpectedEof,
294 "unexpected EOF",
295 ));
296 }
297 read += n;
298 } else {
299 read += self.read_buf.len();
300 }
301 }
302 let chunk = self.read_buf.split_to(len).freeze();
303 self.read_buf.advance(2); Ok(chunk)
305 }
306
307 #[inline]
308 async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
309 let mut bytes_read: usize = 0;
311 let mut just_started = true;
312 while bytes_read < self.options.max_header_size {
313 let old_bytes_read = bytes_read;
314 let begin_search = old_bytes_read.saturating_sub(3);
315
316 let have_to_read_buf = !just_started || self.read_buf.is_empty();
317 just_started = false;
318 if have_to_read_buf {
319 let n = self.fill_buf().await?;
320 if n == 0 {
321 return Err(std::io::Error::new(
322 std::io::ErrorKind::UnexpectedEof,
323 "unexpected EOF",
324 ));
325 }
326 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
327 } else {
328 bytes_read =
329 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
330 }
331
332 if bytes_read > 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
333 return Ok(None);
335 }
336
337 if let Some(separator_index) =
338 memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
339 {
340 let to_parse_length = begin_search + separator_index + 4;
341 let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
342
343 let mut httparse_trailers =
345 vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
346 let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
347 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
348 if let httparse::Status::Complete((_, trailers)) = status {
349 let mut trailers_constructed = HeaderMap::new();
350 for header in trailers {
351 if header == &httparse::EMPTY_HEADER {
352 break;
354 }
355 let name = HeaderName::from_bytes(header.name.as_bytes())
356 .map_err(|e| std::io::Error::other(e.to_string()))?;
357 let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
358 let value_len = header.value.len();
359 let value = unsafe {
361 HeaderValue::from_maybe_shared_unchecked(
362 buf_ro.slice(value_start..(value_start + value_len)),
363 )
364 };
365 trailers_constructed.append(name, value);
366 }
367
368 return Ok(Some(trailers_constructed));
369 } else {
370 return Err(std::io::Error::new(
371 std::io::ErrorKind::InvalidInput,
372 "trailer headers incomplete",
373 ));
374 }
375 }
376 }
377 Err(std::io::Error::new(
378 std::io::ErrorKind::InvalidData,
379 "request too large",
380 ))
381 }
382
383 #[inline]
384 async fn read_chunked_body_fn(
385 &mut self,
386 body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
387 would_have_trailers: bool,
388 ) -> Result<(), std::io::Error> {
389 loop {
390 let chunk = self.read_body_chunk(would_have_trailers).await?;
391 if chunk.is_empty() {
392 break;
393 }
394
395 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
396 }
397 if would_have_trailers {
398 let trailers = self.read_trailers().await?;
400 if let Some(trailers) = trailers {
401 let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
402 }
403 }
404 Ok(())
405 }
406
407 #[inline]
408 async fn read_request(
409 &mut self,
410 ) -> Result<
411 Option<(
412 Request<Incoming>,
413 kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
414 )>,
415 std::io::Error,
416 > {
417 let (request, body_tx) = {
419 let Some((head, headers)) = self.get_head().await? else {
420 return Ok(None);
421 };
422 let headers = unsafe {
424 std::mem::transmute::<
425 &mut [MaybeUninit<httparse::Header<'static>>],
426 &mut [MaybeUninit<httparse::Header<'_>>],
427 >(headers)
428 };
429 let mut req = httparse::Request::new(&mut []);
430 let status = req
431 .parse_with_uninit_headers(&head, headers)
432 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
433 if status.is_partial() {
434 return Err(std::io::Error::new(
435 std::io::ErrorKind::InvalidData,
436 "partial request head",
437 ));
438 }
439
440 let (body_tx, body_rx) = kanal::bounded_async(2);
442 let request_body = Http1Body {
443 inner: Box::pin(body_rx),
444 };
445 let mut request = Request::new(Incoming::H1(request_body));
446 match req.version {
447 Some(0) => *request.version_mut() = http::Version::HTTP_10,
448 Some(1) => *request.version_mut() = http::Version::HTTP_11,
449 _ => *request.version_mut() = http::Version::HTTP_11,
450 };
451 if let Some(method) = req.method {
452 *request.method_mut() = Method::from_bytes(method.as_bytes())
453 .map_err(|e| std::io::Error::other(e.to_string()))?;
454 }
455 if let Some(path) = req.path {
456 *request.uri_mut() =
457 Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
458 }
459 let mut header_map = self.cached_headers.take().unwrap_or_default();
460 header_map.clear();
461 let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
462 if additional_capacity > 0 {
463 header_map.reserve(additional_capacity);
464 }
465 for header in req.headers {
466 if header == &httparse::EMPTY_HEADER {
467 break;
469 }
470 let name = HeaderName::from_bytes(header.name.as_bytes())
471 .map_err(|e| std::io::Error::other(e.to_string()))?;
472 let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
473 let value_len = header.value.len();
474 let value = unsafe {
476 HeaderValue::from_maybe_shared_unchecked(
477 head.slice(value_start..(value_start + value_len)),
478 )
479 };
480 header_map.append(name, value);
481 }
482 *request.headers_mut() = header_map;
483
484 (request, body_tx)
485 };
486 Ok(Some((request, body_tx)))
487 }
488
489 #[inline]
490 async fn get_head(
491 &mut self,
492 ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
493 {
494 let mut request_line_read = false;
495 let mut bytes_read: usize = 0;
496 let mut whitespace_trimmed = None;
497 let mut just_started = true;
498 while bytes_read < self.options.max_header_size {
499 let old_bytes_read = bytes_read;
500 let begin_search = old_bytes_read.saturating_sub(3);
501
502 let have_to_read_buf = !just_started || self.read_buf.is_empty();
503 just_started = false;
504 if have_to_read_buf {
505 let n = self.fill_buf().await?;
506 if n == 0 {
507 if whitespace_trimmed.is_none() {
508 return Ok(None);
509 }
510 return Err(std::io::Error::new(
511 std::io::ErrorKind::UnexpectedEof,
512 "unexpected EOF",
513 ));
514 }
515 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
516 } else {
517 bytes_read =
518 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
519 }
520
521 if whitespace_trimmed.is_none() {
522 whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
523 .iter()
524 .position(|b| !b.is_ascii_whitespace());
525 }
526
527 if let Some(whitespace_trimmed) = whitespace_trimmed {
528 if !request_line_read {
530 let memchr = memchr3_iter(
531 b' ',
532 b'\r',
533 b'\n',
534 &self.read_buf[whitespace_trimmed..bytes_read],
535 );
536 let mut spaces = 0;
537 for separator_index in memchr {
538 if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
539 if spaces >= 2 {
540 return Err(std::io::Error::new(
541 std::io::ErrorKind::InvalidInput,
542 "bad request first line",
543 ));
544 }
545 spaces += 1;
546 } else if spaces == 2 {
547 request_line_read = true;
548 break;
549 } else {
550 return Err(std::io::Error::new(
551 std::io::ErrorKind::InvalidInput,
552 "bad request first line",
553 ));
554 }
555 }
556 }
557
558 if request_line_read {
559 let begin_search = begin_search.max(whitespace_trimmed);
560 if let Some((separator_index, separator_len)) =
561 search_header_body_separator(&self.read_buf[begin_search..bytes_read])
562 {
563 let to_parse_length =
564 begin_search + separator_index + separator_len - whitespace_trimmed;
565 self.read_buf.advance(whitespace_trimmed);
566 let head = self.read_buf.split_to(to_parse_length);
567 return Ok(Some((head.freeze(), &mut self.parsed_headers)));
568 }
569 }
570 }
571 }
572 Err(std::io::Error::new(
573 std::io::ErrorKind::InvalidData,
574 "request too large",
575 ))
576 }
577
578 #[inline]
579 async fn write_response<Z, ZFut>(
580 &mut self,
581 mut response: Response<
582 impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
583 >,
584 version: Version,
585 write_trailers: bool,
586 zerocopy_fn: Option<Z>,
587 ) -> Result<(), std::io::Error>
588 where
589 Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
590 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
591 {
592 if self.options.send_date_header {
594 response.headers_mut().insert(
595 header::DATE,
596 HeaderValue::from_str(self.get_date_header_value())
597 .map_err(|e| std::io::Error::other(e.to_string()))?,
598 );
599 }
600
601 if let Some(suggested_content_length) = response.body().size_hint().exact() {
603 let headers = response.headers_mut();
604 if !headers.contains_key(header::CONTENT_LENGTH) {
605 headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
606 }
607 }
608
609 let chunked = response
610 .headers()
611 .get(header::TRANSFER_ENCODING)
612 .map(|v| {
613 v.to_str().ok().is_some_and(|s| {
614 s.split(',')
615 .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
616 })
617 })
618 .unwrap_or_else(|| {
619 response
620 .headers()
621 .get(header::CONTENT_LENGTH)
622 .and_then(|v| v.to_str().ok())
623 .is_none_or(|s| s.parse::<u64>().is_err())
624 });
625
626 if chunked {
627 response.headers_mut().insert(
628 header::TRANSFER_ENCODING,
629 HeaderValue::from_static("chunked"),
630 );
631 while response
632 .headers_mut()
633 .remove(header::CONTENT_LENGTH)
634 .is_some()
635 {}
636 }
637
638 let (parts, mut body) = response.into_parts();
639
640 self.response_head_buf.clear();
641 let estimated_head_len = 30 + parts.headers.len() * 30; if self.response_head_buf.capacity() < estimated_head_len {
643 self.response_head_buf
644 .reserve(estimated_head_len - self.response_head_buf.capacity());
645 }
646 let head = &mut self.response_head_buf;
647 if version == Version::HTTP_10 {
648 head.extend_from_slice(b"HTTP/1.0 ");
649 } else {
650 head.extend_from_slice(b"HTTP/1.1 ");
651 }
652 let status = parts.status;
653 head.extend_from_slice(status.as_str().as_bytes());
654 if let Some(canonical_reason) = status.canonical_reason() {
655 head.extend_from_slice(b" ");
656 head.extend_from_slice(canonical_reason.as_bytes());
657 }
658 head.extend_from_slice(b"\r\n");
659 for (name, value) in &parts.headers {
660 head.extend_from_slice(name.as_str().as_bytes());
661 head.extend_from_slice(b": ");
662 head.extend_from_slice(value.as_bytes());
663 head.extend_from_slice(b"\r\n");
664 }
665 head.extend_from_slice(b"\r\n");
666 unsafe {
667 self.write_buf.push(IoSlice::new(head));
668 }
669
670 if !chunked {
671 if let Some(content_length) = parts
672 .headers
673 .get(header::CONTENT_LENGTH)
674 .and_then(|v| v.to_str().ok())
675 .and_then(|s| s.parse::<u64>().ok())
676 {
677 if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
678 if let Some(mut zerocopy_fn) = zerocopy_fn {
679 unsafe {
681 self.write_buf
682 .flush(&mut self.io, self.options.enable_vectored_write)
683 .await?
684 };
685 zerocopy_fn(
686 zero_copy.handle,
687 unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
689 content_length,
690 )
691 .await?;
692 self.io.flush().await?;
693 let reclaimed_headers = parts.headers;
694 self.cached_headers = Some(reclaimed_headers);
695 return Ok(());
696 }
697 }
698 }
699 }
700
701 let mut trailers_written = false;
702 while let Some(chunk) = body.frame().await {
703 let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
704 match chunk.into_data() {
705 Ok(data) => {
706 if chunked {
707 let mut chunk_size_buf = [0u8; 18];
708 let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
709 self.write_buf.push_copy(chunk_size);
710 self.write_buf.push_bytes(data);
711 unsafe {
712 self.write_buf.push(IoSlice::new(b"\r\n"));
713 }
714 } else {
715 self.write_buf.push_bytes(data);
716 }
717 while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
718 unsafe {
719 self.write_buf
720 .write(&mut self.io, self.options.enable_vectored_write)
721 .await?;
722 }
723 }
724 }
725 Err(chunk) => {
726 if let Ok(trailers) = chunk.into_trailers() {
727 if write_trailers {
728 unsafe {
729 self.write_buf.push(IoSlice::new(b"0\r\n"));
730 for (name, value) in &trailers {
731 self.write_buf.push_copy(name.as_str().as_bytes());
732 self.write_buf.push(IoSlice::new(b": "));
733 self.write_buf.push_copy(value.as_bytes());
734 self.write_buf.push(IoSlice::new(b"\r\n"));
735 }
736 self.write_buf.push(IoSlice::new(b"\r\n"));
737 }
738 trailers_written = true;
739 }
740 break;
741 }
742 }
743 };
744 }
745 if chunked && !trailers_written {
746 unsafe {
748 self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
749 }
750 }
751 unsafe {
752 self.write_buf
753 .flush(&mut self.io, self.options.enable_vectored_write)
754 .await?;
755 }
756 self.io.flush().await?;
757 let reclaimed_headers = parts.headers;
758 self.cached_headers = Some(reclaimed_headers);
759
760 Ok(())
761 }
762
763 #[inline]
764 async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
765 if version == Version::HTTP_10 {
766 self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
767 } else {
768 self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
769 }
770 self.io.flush().await?;
771
772 Ok(())
773 }
774
775 #[inline]
776 async fn write_early_hints(
777 &mut self,
778 version: Version,
779 headers: http::HeaderMap,
780 ) -> Result<(), std::io::Error> {
781 let mut head = Vec::new();
782 if version == Version::HTTP_10 {
783 head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
784 } else {
785 head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
786 }
787 let mut current_header_name = None;
788 for (name, value) in headers {
789 if let Some(name) = name {
790 current_header_name = Some(name);
791 };
792 if let Some(current_header_name) = ¤t_header_name {
793 head.extend_from_slice(current_header_name.as_str().as_bytes());
794 if value.is_empty() {
795 head.extend_from_slice(b":\r\n");
796 continue;
797 }
798 head.extend_from_slice(b": ");
799 head.extend_from_slice(value.as_bytes());
800 head.extend_from_slice(b"\r\n");
801 }
802 }
803 head.extend_from_slice(b"\r\n");
804
805 self.io.write_all(&head).await?;
806
807 Ok(())
808 }
809
    /// Drives the HTTP/1.x request loop on this connection.
    ///
    /// Each iteration reads a request head, runs `request_fn` while streaming the
    /// request body to it, and writes the response. The loop repeats while the
    /// connection is keep-alive.
    ///
    /// - `request_fn` is called once per request.
    /// - `error_fn` is called at most once, when reading a request head fails. Its
    ///   argument is `true` when `header_read_timeout` expired. Its response is sent with
    ///   `Connection: close`, and any error writing it is ignored.
    /// - `zerocopy_fn` is passed through to `write_response` for every response.
    ///
    /// Returns `Ok(())` in four cases:
    /// - the peer closes between requests;
    /// - a request is upgraded (the transport is handed to the `Upgraded` value);
    /// - keep-alive ends;
    /// - the graceful-shutdown token is found cancelled after a response.
    ///
    /// A head read failure returns that error, or `TimedOut` on timeout. Other errors
    /// (body reading, handler errors, writing) are propagated.
    #[inline]
    pub(crate) async fn handle_with_error_fn_and_zerocopy<
        F,
        Fut,
        ResB,
        ResBE,
        ResE,
        EF,
        EFut,
        EResB,
        EResBE,
        EResE,
        ZF,
        ZFut,
    >(
        mut self,
        request_fn: F,
        error_fn: EF,
        mut zerocopy_fn: Option<ZF>,
    ) -> Result<(), std::io::Error>
    where
        F: Fn(Request<Incoming>) -> Fut + 'static,
        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
        ResE: std::error::Error,
        ResBE: std::error::Error,
        EF: FnOnce(bool) -> EFut,
        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
        EResE: std::error::Error,
        EResBE: std::error::Error,
        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
    {
        let mut keep_alive = true;

        while keep_alive {
            // Read the next request head, optionally bounded by `header_read_timeout`.
            let (mut request, body_tx) = match if let Some(timeout) =
                self.options.header_read_timeout
            {
                vibeio::time::timeout(timeout, self.read_request()).await
            } else {
                Ok(self.read_request().await)
            } {
                Ok(Ok(Some(d))) => d,
                // Peer closed before starting another request.
                Ok(Ok(None)) => {
                    return Ok(());
                }
                // Malformed or failed read: best-effort error response, then close.
                Ok(Err(e)) => {
                    if let Ok(mut response) = error_fn(false).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(e);
                }
                // Header read timeout elapsed.
                Err(_) => {
                    if let Ok(mut response) = error_fn(true).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "header read timeout",
                    ));
                }
            };

            // Keep-alive: HTTP/1.1 by default, HTTP/1.0 only with `Connection: keep-alive`;
            // `Connection: close` always wins.
            let connection_header_split = request
                .headers()
                .get(header::CONNECTION)
                .and_then(|v| v.to_str().ok())
                .map(|v| v.split(",").map(|v| v.trim()));
            let is_connection_close = connection_header_split
                .clone()
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
            let is_connection_keep_alive = connection_header_split
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
            keep_alive = !is_connection_close
                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);

            let version = request.version();

            if self.options.send_continue_response {
                let is_100_continue = request
                    .headers()
                    .get(header::EXPECT)
                    .and_then(|v| v.to_str().ok())
                    .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
                if is_100_continue {
                    self.write_100_continue(version).await?;
                }
            }

            // Either a task that writes early hints requested by the handler, or a future
            // that never completes. Both variants end in `pending()`, so this future
            // never resolves on its own.
            let early_hints_fut = if self.options.enable_early_hints {
                let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
                request.extensions_mut().insert(early_hints);
                // NOTE(review): this identity transmute only erases the borrow, creating a
                // second live `&mut Self` alongside the one `read_body_fut` takes below.
                // Rust's aliasing rules do not allow that even when the accesses are
                // interleaved on a single task — worth revisiting.
                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
                futures_util::future::Either::Left(async move {
                    while let Some((headers, sender)) =
                        std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
                    {
                        sender
                            .into_inner()
                            .send(mut_self.write_early_hints(version, headers).await)
                            .ok();
                    }
                    futures_util::future::pending::<Result<(), std::io::Error>>().await
                })
            } else {
                futures_util::future::Either::Right(futures_util::future::pending::<
                    Result<(), std::io::Error>,
                >())
            };

            // NOTE(review): a missing or unparsable Content-Length falls back to a body of
            // 0 bytes here; any unread body bytes are then parsed as the next request.
            let content_length = request
                .headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(0);
            // Chunked takes precedence over Content-Length when both are present.
            let chunked = request
                .headers()
                .get(header::TRANSFER_ENCODING)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
                });
            // Only a non-empty `Trailer` header makes the body reader expect trailers.
            let has_trailers = request
                .headers()
                .get(header::TRAILER)
                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
                .unwrap_or(false);
            // Response trailers are sent only if the client advertised `TE: trailers`.
            let write_trailers = request
                .headers()
                .get(header::TE)
                .and_then(|v| v.to_str().ok())
                .map(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
                })
                .unwrap_or(false);

            // Lets the handler request a protocol upgrade; `upgraded` is read after the
            // handler has produced its response.
            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
            let upgrade = Upgrade::new(upgrade_rx);
            let upgraded = upgrade.upgraded.clone();
            request.extensions_mut().insert(upgrade);

            // Run the body reader, the handler and the early-hints writer concurrently on
            // this task; the body is always fully read before the response is written.
            let mut response = {
                let read_body_fut = async {
                    if chunked {
                        self.read_chunked_body_fn(body_tx, has_trailers).await
                    } else {
                        self.read_body_fn(body_tx, content_length).await
                    }
                };
                let read_body_fut_pin = std::pin::pin!(read_body_fut);
                let request_fut = request_fn(request);
                let request_fut_pin = std::pin::pin!(request_fut);
                let early_hints_fut_pin = std::pin::pin!(early_hints_fut);

                let select_read_body_either =
                    futures_util::future::select(request_fut_pin, early_hints_fut_pin);
                let select_either =
                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;

                // `early_hints_fut` never completes, so only the handler can win the inner
                // select; the `Right` arms below are unreachable.
                let (response, body_fut) = match select_either {
                    // Body finished first: surface its error, then wait for the handler.
                    futures_util::future::Either::Left((result, request_fut)) => {
                        result?;
                        (
                            match request_fut.await {
                                futures_util::future::Either::Left((response, _)) => response,
                                futures_util::future::Either::Right((_, _)) => unreachable!(),
                            },
                            None,
                        )
                    }
                    // Handler finished first: the rest of the body still has to be drained.
                    futures_util::future::Either::Right((response, read_body_fut)) => (
                        match response {
                            futures_util::future::Either::Left((response, _)) => response,
                            futures_util::future::Either::Right((_, _)) => unreachable!(),
                        },
                        Some(read_body_fut),
                    ),
                };

                if let Some(body_fut) = body_fut {
                    body_fut.await?;
                }

                response.map_err(|e| std::io::Error::other(e.to_string()))?
            };

            // Set the response's Connection header:
            // - "upgrade" for upgrades;
            // - "keep-alive" when keeping an HTTP/1.0 connection (or overriding the
            //   handler's value);
            // - "close" when closing an HTTP/1.1 connection (or overriding the handler's
            //   value).
            let mut was_upgraded = false;
            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
                was_upgraded = true;
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
            } else if keep_alive {
                if version == Version::HTTP_10
                    || response.headers().contains_key(header::CONNECTION)
                {
                    response
                        .headers_mut()
                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
                }
            } else if version == Version::HTTP_11
                || response.headers().contains_key(header::CONNECTION)
            {
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
            }

            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
                .await?;

            // Hand the transport and any already-buffered bytes to the upgraded side.
            if was_upgraded {
                let frozen_buf = self.read_buf.freeze();
                let _ = upgrade_tx.send(Upgraded::new(
                    self.io,
                    if frozen_buf.is_empty() {
                        None
                    } else {
                        Some(frozen_buf)
                    },
                ));
                return Ok(());
            }

            // Graceful shutdown: stop after the response in flight.
            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
                break;
            }
        }
        Ok(())
    }
1072}
1073
1074impl<Io> HttpProtocol for Http1<Io>
1075where
1076 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1077{
1078 #[inline]
1079 fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1080 self,
1081 request_fn: F,
1082 error_fn: EF,
1083 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1084 where
1085 F: Fn(Request<Incoming>) -> Fut + 'static,
1086 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1087 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1088 ResE: std::error::Error,
1089 ResBE: std::error::Error,
1090 EF: FnOnce(bool) -> EFut,
1091 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1092 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
1093 EResE: std::error::Error,
1094 EResBE: std::error::Error,
1095 {
1096 #[allow(clippy::type_complexity)]
1097 let no_zerocopy: Option<
1098 Box<
1099 dyn FnMut(
1100 RawHandle,
1101 &Io,
1102 u64,
1103 ) -> Box<
1104 dyn std::future::Future<Output = Result<(), std::io::Error>>
1105 + Unpin
1106 + Send
1107 + Sync,
1108 >,
1109 >,
1110 > = None;
1111 self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1112 }
1113
1114 #[inline]
1115 fn handle<F, Fut, ResB, ResBE, ResE>(
1116 self,
1117 request_fn: F,
1118 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1119 where
1120 F: Fn(Request<Incoming>) -> Fut + 'static,
1121 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1122 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1123 ResE: std::error::Error,
1124 ResBE: std::error::Error,
1125 {
1126 self.handle_with_error_fn(request_fn, |is_timeout| async move {
1127 let mut response = Response::builder();
1128 if is_timeout {
1129 response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1130 } else {
1131 response = response.status(http::StatusCode::BAD_REQUEST);
1132 }
1133 response.body(Empty::new())
1134 })
1135 }
1136}
1137
/// Request body handed to the handler. Its frames arrive over a channel fed by the
/// connection's body reader.
pub(crate) struct Http1Body {
    // Receiving half of the bounded channel created in `read_request`.
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
}
1142
1143impl Body for Http1Body {
1144 type Data = bytes::Bytes;
1145 type Error = std::io::Error;
1146
1147 #[inline]
1148 fn poll_frame(
1149 self: Pin<&mut Self>,
1150 cx: &mut Context<'_>,
1151 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1152 match std::pin::pin!(self.inner.recv()).poll(cx) {
1153 Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1154 Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
1155 Poll::Ready(Err(_)) => Poll::Ready(None),
1156 Poll::Pending => Poll::Pending,
1157 }
1158 }
1159}
1160
/// Locates the blank line that ends a request head in `slice`.
///
/// Returns `(index, length)` of whichever separator starts first: `\r\n\r\n`
/// (length 4) or `\n\n` (length 2). Returns `None` if neither occurs.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    if slice.len() < 2 {
        return None;
    }
    slice
        .iter()
        .enumerate()
        .find_map(|(idx, &byte)| match byte {
            b'\r' if slice[idx + 1..].starts_with(b"\n\r\n") => Some((idx, 4)),
            b'\n' if slice.get(idx + 1) == Some(&b'\n') => Some((idx, 2)),
            _ => None,
        })
}
1180
1181#[inline]
1183fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1184 let mut n = len;
1185 let mut pos = dst.len() - 2;
1186 loop {
1187 pos -= 1;
1188 dst[pos] = HEX_DIGITS[n & 0xF];
1189 n >>= 4;
1190 if n == 0 {
1191 break;
1192 }
1193 }
1194 dst[dst.len() - 2] = b'\r';
1195 dst[dst.len() - 1] = b'\n';
1196 &dst[pos..]
1197}