1mod options;
2mod tests;
3mod writebuf;
4mod zerocopy;
5
6pub use options::*;
7pub use zerocopy::*;
8
/// Native handle type of the underlying connection as handed to zero-copy body writers:
/// a raw file descriptor on Unix.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Native handle type of the underlying connection as handed to zero-copy body writers:
/// a raw `HANDLE` on Windows.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
13
14use std::{
15 future::Future,
16 io::IoSlice,
17 mem::MaybeUninit,
18 pin::Pin,
19 str::FromStr,
20 task::{Context, Poll},
21 time::UNIX_EPOCH,
22};
23
24use bytes::{Buf, Bytes, BytesMut};
25use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
26use http_body::Body;
27use http_body_util::{BodyExt, Empty};
28use kanal::AsyncReceiver;
29use memchr::{memchr3_iter, memmem};
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio_util::sync::CancellationToken;
32
33use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
34
/// Uppercase hexadecimal digits used when formatting chunk-size lines.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// While streaming a response body, once the write buffer holds at least this much
/// (as measured by `WriteBuf::len`) it is written out before more body data is queued.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16384;
37
/// Server side of an HTTP/1.x connection over the transport `Io`.
pub struct Http1<Io> {
    /// Transport that requests are read from and responses are written to.
    io: Io,
    /// Connection options (size limits, timeouts, optional protocol features).
    options: options::Http1Options,
    /// Graceful-shutdown token; once cancelled, no further requests are served.
    cancel_token: Option<CancellationToken>,
    /// Reusable scratch storage for `httparse`, sized by `max_header_count`.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value and the time it was formatted at.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map of the previous response, kept to reuse its allocation for the next request.
    cached_headers: Option<HeaderMap>,
    /// Bytes read from `io` that have not been consumed yet.
    read_buf: BytesMut,
    /// Reusable buffer holding a serialized response status line and headers.
    response_head_buf: Vec<u8>,
    /// Pending output slices, written to `io` in batches.
    write_buf: WriteBuf,
}
80
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Wraps this handler in an [`Http1Zerocopy`].
    ///
    /// Only available on Linux with the `h1-zerocopy` feature. Presumably the wrapper supplies
    /// the zero-copy writer used for responses carrying a [`ZerocopyResponse`] extension
    /// (see the `zerocopy` module).
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
105
106impl<Io> Http1<Io>
107where
108 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
109{
110 #[inline]
121 pub fn new(io: Io, options: options::Http1Options) -> Self {
122 let read_buf = BytesMut::with_capacity(options.max_header_size);
124 let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
125 Box::new_uninit_slice(options.max_header_count);
126 Self {
127 io,
128 options,
129 cancel_token: None,
130 parsed_headers,
131 date_header_value_cached: None,
132 cached_headers: None,
133 read_buf,
134 response_head_buf: Vec::with_capacity(1024),
135 write_buf: WriteBuf::new(),
136 }
137 }
138
139 #[inline]
140 fn get_date_header_value(&mut self) -> &str {
141 let now = std::time::SystemTime::now();
142 if self.date_header_value_cached.as_ref().is_none_or(|v| {
143 v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
144 != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
145 }) {
146 let value = httpdate::fmt_http_date(now).to_string();
147 self.date_header_value_cached = Some((value, now));
148 }
149 self.date_header_value_cached
150 .as_ref()
151 .map(|v| v.0.as_str())
152 .unwrap_or("")
153 }
154
155 #[inline]
165 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
166 self.cancel_token = Some(token);
167 self
168 }
169
170 #[inline]
171 async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
172 if self.read_buf.remaining() < 1024 {
173 self.read_buf.reserve(1024);
174 }
175 let spare_capacity = self.read_buf.spare_capacity_mut();
176 let n = self
178 .io
179 .read(unsafe {
180 &mut *std::ptr::slice_from_raw_parts_mut(
181 spare_capacity.as_mut_ptr() as *mut u8,
182 spare_capacity.len(),
183 )
184 })
185 .await?;
186 if n == 0 {
187 return Ok(0);
188 }
189 unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
190 Ok(n)
191 }
192
193 #[inline]
194 async fn read_body_fn(
195 &mut self,
196 body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
197 content_length: u64,
198 ) -> Result<(), std::io::Error> {
199 let mut remaining = content_length;
200 let mut just_started = true;
201 while remaining > 0 {
202 let have_to_read_buf = !just_started || self.read_buf.is_empty();
203 just_started = false;
204 if have_to_read_buf {
205 let n = self.fill_buf().await?;
206 if n == 0 {
207 break;
208 }
209 }
210 let chunk = self
211 .read_buf
212 .split_to(
213 self.read_buf
214 .len()
215 .min(remaining.min(usize::MAX as u64) as usize),
216 )
217 .freeze();
218 remaining -= chunk.len() as u64;
219
220 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
221 }
222 Ok(())
223 }
224
225 #[inline]
226 async fn read_body_chunk(
227 &mut self,
228 would_have_trailers: bool,
229 ) -> Result<bytes::Bytes, std::io::Error> {
230 let len = {
231 let mut len_buf_pos: usize = 0;
233 let mut just_started = true;
234 loop {
235 if len_buf_pos >= 48 {
236 return Err(std::io::Error::new(
237 std::io::ErrorKind::InvalidData,
238 "chunk length buffer overflow",
239 ));
240 }
241
242 let begin_search = len_buf_pos.saturating_sub(1);
243
244 let have_to_read_buf = !just_started || self.read_buf.is_empty();
245 just_started = false;
246 if have_to_read_buf {
247 let n = self.fill_buf().await?;
248 if n == 0 {
249 return Err(std::io::Error::new(
250 std::io::ErrorKind::UnexpectedEof,
251 "unexpected EOF",
252 ));
253 }
254 len_buf_pos += n;
255 } else {
256 len_buf_pos += self.read_buf.len();
257 }
258
259 if let Some(pos) =
260 memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
261 {
262 let numbers = std::str::from_utf8(&self.read_buf[..begin_search + pos])
263 .map_err(|_| {
264 std::io::Error::new(
265 std::io::ErrorKind::InvalidData,
266 "invalid chunk length",
267 )
268 })?;
269 let len = usize::from_str_radix(numbers, 16).map_err(|_| {
270 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
271 })?;
272 self.read_buf.advance(begin_search + pos + 2);
274 break len;
275 }
276 }
277 };
278 let mut read = 0;
280 if len == 0 && would_have_trailers {
281 return Ok(bytes::Bytes::new()); }
283 let mut just_started = true;
284 let Some(len_plus_two) = len.checked_add(2) else {
286 return Err(std::io::Error::new(
287 std::io::ErrorKind::InvalidData,
288 "chunk length too large",
289 ));
290 };
291 while read < len_plus_two {
292 let have_to_read_buf = !just_started || self.read_buf.is_empty();
293 just_started = false;
294 if have_to_read_buf {
295 let n = self.fill_buf().await?;
296 if n == 0 {
297 return Err(std::io::Error::new(
298 std::io::ErrorKind::UnexpectedEof,
299 "unexpected EOF",
300 ));
301 }
302 read += n;
303 } else {
304 read += self.read_buf.len();
305 }
306 }
307 let chunk = self.read_buf.split_to(len).freeze();
308 self.read_buf.advance(2); Ok(chunk)
310 }
311
312 #[inline]
313 async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
314 let mut bytes_read: usize = 0;
316 let mut just_started = true;
317 while bytes_read < self.options.max_header_size {
318 let old_bytes_read = bytes_read;
319 let begin_search = old_bytes_read.saturating_sub(3);
320
321 let have_to_read_buf = !just_started || self.read_buf.is_empty();
322 just_started = false;
323 if have_to_read_buf {
324 let n = self.fill_buf().await?;
325 if n == 0 {
326 return Err(std::io::Error::new(
327 std::io::ErrorKind::UnexpectedEof,
328 "unexpected EOF",
329 ));
330 }
331 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
332 } else {
333 bytes_read =
334 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
335 }
336
337 if bytes_read >= 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
338 return Ok(None);
340 }
341
342 if let Some(separator_index) =
343 memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
344 {
345 let to_parse_length = begin_search + separator_index + 4;
346 let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
347
348 let mut httparse_trailers =
350 vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
351 let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
352 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
353 if let httparse::Status::Complete((_, trailers)) = status {
354 let mut trailers_constructed = HeaderMap::new();
355 for header in trailers {
356 if header == &httparse::EMPTY_HEADER {
357 break;
359 }
360 let name = HeaderName::from_bytes(header.name.as_bytes())
361 .map_err(|e| std::io::Error::other(e.to_string()))?;
362 let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
363 let value_len = header.value.len();
364 let value = unsafe {
366 HeaderValue::from_maybe_shared_unchecked(
367 buf_ro.slice(value_start..(value_start + value_len)),
368 )
369 };
370 trailers_constructed.append(name, value);
371 }
372
373 return Ok(Some(trailers_constructed));
374 } else {
375 return Err(std::io::Error::new(
376 std::io::ErrorKind::InvalidInput,
377 "trailer headers incomplete",
378 ));
379 }
380 }
381 }
382 Err(std::io::Error::new(
383 std::io::ErrorKind::InvalidData,
384 "request too large",
385 ))
386 }
387
388 #[inline]
389 async fn read_chunked_body_fn(
390 &mut self,
391 body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
392 would_have_trailers: bool,
393 ) -> Result<(), std::io::Error> {
394 loop {
395 let chunk = self.read_body_chunk(would_have_trailers).await?;
396 if chunk.is_empty() {
397 break;
398 }
399
400 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
401 }
402 if would_have_trailers {
403 let trailers = self.read_trailers().await?;
405 if let Some(trailers) = trailers {
406 let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
407 }
408 }
409 Ok(())
410 }
411
412 #[inline]
413 async fn read_request(
414 &mut self,
415 ) -> Result<
416 Option<(
417 Request<Incoming>,
418 kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
419 )>,
420 std::io::Error,
421 > {
422 let (request, body_tx) = {
424 let Some((head, headers)) = self.get_head().await? else {
425 return Ok(None);
426 };
427 let headers = unsafe {
429 std::mem::transmute::<
430 &mut [MaybeUninit<httparse::Header<'static>>],
431 &mut [MaybeUninit<httparse::Header<'_>>],
432 >(headers)
433 };
434 let mut req = httparse::Request::new(&mut []);
435 let status = req
436 .parse_with_uninit_headers(&head, headers)
437 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
438 if status.is_partial() {
439 return Err(std::io::Error::new(
440 std::io::ErrorKind::InvalidData,
441 "partial request head",
442 ));
443 }
444
445 let (body_tx, body_rx) = kanal::bounded_async(2);
447 let request_body = Http1Body {
448 inner: Box::pin(body_rx),
449 };
450 let mut request = Request::new(Incoming::H1(request_body));
451 match req.version {
452 Some(0) => *request.version_mut() = http::Version::HTTP_10,
453 Some(1) => *request.version_mut() = http::Version::HTTP_11,
454 _ => *request.version_mut() = http::Version::HTTP_11,
455 };
456 if let Some(method) = req.method {
457 *request.method_mut() = Method::from_bytes(method.as_bytes())
458 .map_err(|e| std::io::Error::other(e.to_string()))?;
459 }
460 if let Some(path) = req.path {
461 *request.uri_mut() =
462 Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
463 }
464 let mut header_map = self.cached_headers.take().unwrap_or_default();
465 header_map.clear();
466 let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
467 if additional_capacity > 0 {
468 header_map.reserve(additional_capacity);
469 }
470 for header in req.headers {
471 if header == &httparse::EMPTY_HEADER {
472 break;
474 }
475 let name = HeaderName::from_bytes(header.name.as_bytes())
476 .map_err(|e| std::io::Error::other(e.to_string()))?;
477 let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
478 let value_len = header.value.len();
479 let value = unsafe {
481 HeaderValue::from_maybe_shared_unchecked(
482 head.slice(value_start..(value_start + value_len)),
483 )
484 };
485 header_map.append(name, value);
486 }
487 *request.headers_mut() = header_map;
488
489 (request, body_tx)
490 };
491 Ok(Some((request, body_tx)))
492 }
493
494 #[inline]
495 async fn get_head(
496 &mut self,
497 ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
498 {
499 let mut request_line_read = false;
500 let mut bytes_read: usize = 0;
501 let mut whitespace_trimmed = None;
502 let mut just_started = true;
503 while bytes_read < self.options.max_header_size {
504 let old_bytes_read = bytes_read;
505 let begin_search = old_bytes_read.saturating_sub(3);
506
507 let have_to_read_buf = !just_started || self.read_buf.is_empty();
508 just_started = false;
509 if have_to_read_buf {
510 let n = self.fill_buf().await?;
511 if n == 0 {
512 if whitespace_trimmed.is_none() {
513 return Ok(None);
514 }
515 return Err(std::io::Error::new(
516 std::io::ErrorKind::UnexpectedEof,
517 "unexpected EOF",
518 ));
519 }
520 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
521 } else {
522 bytes_read =
523 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
524 }
525
526 if whitespace_trimmed.is_none() {
527 whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
528 .iter()
529 .position(|b| !b.is_ascii_whitespace());
530 }
531
532 if let Some(whitespace_trimmed) = whitespace_trimmed {
533 if !request_line_read {
535 let memchr = memchr3_iter(
536 b' ',
537 b'\r',
538 b'\n',
539 &self.read_buf[whitespace_trimmed..bytes_read],
540 );
541 let mut spaces = 0;
542 for separator_index in memchr {
543 if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
544 if spaces >= 2 {
545 return Err(std::io::Error::new(
546 std::io::ErrorKind::InvalidInput,
547 "bad request first line",
548 ));
549 }
550 spaces += 1;
551 } else if spaces == 2 {
552 request_line_read = true;
553 break;
554 } else {
555 return Err(std::io::Error::new(
556 std::io::ErrorKind::InvalidInput,
557 "bad request first line",
558 ));
559 }
560 }
561 }
562
563 if request_line_read {
564 let begin_search = begin_search.max(whitespace_trimmed);
565 if let Some((separator_index, separator_len)) =
566 search_header_body_separator(&self.read_buf[begin_search..bytes_read])
567 {
568 let to_parse_length =
569 begin_search + separator_index + separator_len - whitespace_trimmed;
570 self.read_buf.advance(whitespace_trimmed);
571 let head = self.read_buf.split_to(to_parse_length);
572 return Ok(Some((head.freeze(), &mut self.parsed_headers)));
573 }
574 }
575 }
576 }
577 Err(std::io::Error::new(
578 std::io::ErrorKind::InvalidData,
579 "request too large",
580 ))
581 }
582
583 #[inline]
584 async fn write_response<Z, ZFut>(
585 &mut self,
586 mut response: Response<
587 impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
588 >,
589 version: Version,
590 write_trailers: bool,
591 zerocopy_fn: Option<Z>,
592 ) -> Result<(), std::io::Error>
593 where
594 Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
595 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
596 {
597 if self.options.send_date_header {
599 response.headers_mut().insert(
600 header::DATE,
601 HeaderValue::from_str(self.get_date_header_value())
602 .map_err(|e| std::io::Error::other(e.to_string()))?,
603 );
604 }
605
606 if let Some(suggested_content_length) = response.body().size_hint().exact() {
608 let headers = response.headers_mut();
609 if !headers.contains_key(header::CONTENT_LENGTH) {
610 headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
611 }
612 }
613
614 let chunked = response
615 .headers()
616 .get(header::TRANSFER_ENCODING)
617 .map(|v| {
618 v.to_str().ok().is_some_and(|s| {
619 s.split(',')
620 .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
621 })
622 })
623 .unwrap_or_else(|| {
624 response
625 .headers()
626 .get(header::CONTENT_LENGTH)
627 .and_then(|v| v.to_str().ok())
628 .is_none_or(|s| s.parse::<u64>().is_err())
629 });
630
631 if chunked {
632 response.headers_mut().insert(
633 header::TRANSFER_ENCODING,
634 HeaderValue::from_static("chunked"),
635 );
636 while response
637 .headers_mut()
638 .remove(header::CONTENT_LENGTH)
639 .is_some()
640 {}
641 }
642
643 let (parts, mut body) = response.into_parts();
644
645 self.response_head_buf.clear();
646 let estimated_head_len = 30 + parts.headers.len() * 30; if self.response_head_buf.capacity() < estimated_head_len {
648 self.response_head_buf
649 .reserve(estimated_head_len - self.response_head_buf.capacity());
650 }
651 let head = &mut self.response_head_buf;
652 if version == Version::HTTP_10 {
653 head.extend_from_slice(b"HTTP/1.0 ");
654 } else {
655 head.extend_from_slice(b"HTTP/1.1 ");
656 }
657 let status = parts.status;
658 head.extend_from_slice(status.as_str().as_bytes());
659 if let Some(canonical_reason) = status.canonical_reason() {
660 head.extend_from_slice(b" ");
661 head.extend_from_slice(canonical_reason.as_bytes());
662 }
663 head.extend_from_slice(b"\r\n");
664 for (name, value) in &parts.headers {
665 head.extend_from_slice(name.as_str().as_bytes());
666 head.extend_from_slice(b": ");
667 head.extend_from_slice(value.as_bytes());
668 head.extend_from_slice(b"\r\n");
669 }
670 head.extend_from_slice(b"\r\n");
671 unsafe {
672 self.write_buf.push(IoSlice::new(head));
673 }
674
675 if !chunked {
676 if let Some(content_length) = parts
677 .headers
678 .get(header::CONTENT_LENGTH)
679 .and_then(|v| v.to_str().ok())
680 .and_then(|s| s.parse::<u64>().ok())
681 {
682 if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
683 if let Some(mut zerocopy_fn) = zerocopy_fn {
684 unsafe {
686 self.write_buf
687 .flush(&mut self.io, self.options.enable_vectored_write)
688 .await?
689 };
690 zerocopy_fn(
691 zero_copy.handle,
692 unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
694 content_length,
695 )
696 .await?;
697 self.io.flush().await?;
698 let reclaimed_headers = parts.headers;
699 self.cached_headers = Some(reclaimed_headers);
700 return Ok(());
701 }
702 }
703 }
704 }
705
706 let mut trailers_written = false;
707 while let Some(chunk) = body.frame().await {
708 let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
709 match chunk.into_data() {
710 Ok(data) => {
711 if chunked {
712 let mut chunk_size_buf = [0u8; 18];
713 let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
714 self.write_buf.push_copy(chunk_size);
715 self.write_buf.push_bytes(data);
716 unsafe {
717 self.write_buf.push(IoSlice::new(b"\r\n"));
718 }
719 } else {
720 self.write_buf.push_bytes(data);
721 }
722 while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
723 let bytes_written = unsafe {
724 self.write_buf
725 .write(&mut self.io, self.options.enable_vectored_write)
726 .await?
727 };
728 if bytes_written == 0 {
729 return Err(std::io::ErrorKind::WriteZero.into());
730 }
731 }
732 }
733 Err(chunk) => {
734 if let Ok(trailers) = chunk.into_trailers() {
735 if write_trailers {
736 unsafe {
737 self.write_buf.push(IoSlice::new(b"0\r\n"));
738 for (name, value) in &trailers {
739 self.write_buf.push_copy(name.as_str().as_bytes());
740 self.write_buf.push(IoSlice::new(b": "));
741 self.write_buf.push_copy(value.as_bytes());
742 self.write_buf.push(IoSlice::new(b"\r\n"));
743 }
744 self.write_buf.push(IoSlice::new(b"\r\n"));
745 }
746 trailers_written = true;
747 }
748 break;
749 }
750 }
751 };
752 }
753 if chunked && !trailers_written {
754 unsafe {
756 self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
757 }
758 }
759 unsafe {
760 self.write_buf
761 .flush(&mut self.io, self.options.enable_vectored_write)
762 .await?;
763 }
764 self.io.flush().await?;
765 let reclaimed_headers = parts.headers;
766 self.cached_headers = Some(reclaimed_headers);
767
768 Ok(())
769 }
770
771 #[inline]
772 async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
773 if version == Version::HTTP_10 {
774 self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
775 } else {
776 self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
777 }
778 self.io.flush().await?;
779
780 Ok(())
781 }
782
783 #[inline]
784 async fn write_early_hints(
785 &mut self,
786 version: Version,
787 headers: http::HeaderMap,
788 ) -> Result<(), std::io::Error> {
789 let mut head = Vec::new();
790 if version == Version::HTTP_10 {
791 head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
792 } else {
793 head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
794 }
795 let mut current_header_name = None;
796 for (name, value) in headers {
797 if let Some(name) = name {
798 current_header_name = Some(name);
799 };
800 if let Some(current_header_name) = ¤t_header_name {
801 head.extend_from_slice(current_header_name.as_str().as_bytes());
802 if value.is_empty() {
803 head.extend_from_slice(b":\r\n");
804 continue;
805 }
806 head.extend_from_slice(b": ");
807 head.extend_from_slice(value.as_bytes());
808 head.extend_from_slice(b"\r\n");
809 }
810 }
811 head.extend_from_slice(b"\r\n");
812
813 self.io.write_all(&head).await?;
814
815 Ok(())
816 }
817
818 #[inline]
819 pub(crate) async fn handle_with_error_fn_and_zerocopy<
820 F,
821 Fut,
822 ResB,
823 ResBE,
824 ResE,
825 EF,
826 EFut,
827 EResB,
828 EResBE,
829 EResE,
830 ZF,
831 ZFut,
832 >(
833 mut self,
834 request_fn: F,
835 error_fn: EF,
836 mut zerocopy_fn: Option<ZF>,
837 ) -> Result<(), std::io::Error>
838 where
839 F: Fn(Request<Incoming>) -> Fut + 'static,
840 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
841 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
842 ResE: std::error::Error,
843 ResBE: std::error::Error,
844 EF: FnOnce(bool) -> EFut,
845 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
846 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
847 EResE: std::error::Error,
848 EResBE: std::error::Error,
849 ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
850 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
851 {
852 let mut keep_alive = true;
853
854 while keep_alive {
855 let (mut request, body_tx) = match if let Some(timeout) =
856 self.options.header_read_timeout
857 {
858 vibeio::time::timeout(timeout, async {
859 if let Some(token) = self.cancel_token.clone() {
860 token.run_until_cancelled(self.read_request()).await
861 } else {
862 Some(self.read_request().await)
863 }
864 })
865 .await
866 } else {
867 Ok(Some(self.read_request().await))
868 } {
869 Ok(Some(Ok(Some(d)))) => d,
870 Ok(Some(Ok(None))) => {
871 return Ok(());
872 }
873 Ok(Some(Err(e))) => {
874 if let Ok(mut response) = error_fn(false).await {
876 response
877 .headers_mut()
878 .insert(header::CONNECTION, HeaderValue::from_static("close"));
879
880 let _ = self
881 .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
882 .await;
883 }
884 return Err(e);
885 }
886 Ok(None) => {
887 return Ok(());
889 }
890 Err(_) => {
891 if let Ok(mut response) = error_fn(true).await {
893 response
894 .headers_mut()
895 .insert(header::CONNECTION, HeaderValue::from_static("close"));
896
897 let _ = self
898 .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
899 .await;
900 }
901 return Err(std::io::Error::new(
902 std::io::ErrorKind::TimedOut,
903 "header read timeout",
904 ));
905 }
906 };
907
908 let connection_header_split = request
910 .headers()
911 .get(header::CONNECTION)
912 .and_then(|v| v.to_str().ok())
913 .map(|v| v.split(",").map(|v| v.trim()));
914 let is_connection_close = connection_header_split
915 .clone()
916 .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
917 let is_connection_keep_alive = connection_header_split
918 .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
919 keep_alive = !is_connection_close
920 && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
921
922 let version = request.version();
923
924 if self.options.send_continue_response {
926 let is_100_continue = request
927 .headers()
928 .get(header::EXPECT)
929 .and_then(|v| v.to_str().ok())
930 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
931 if is_100_continue {
932 self.write_100_continue(version).await?;
933 }
934 }
935
936 let early_hints_fut = if self.options.enable_early_hints {
938 let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
939 request.extensions_mut().insert(early_hints);
940 let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
944 futures_util::future::Either::Left(async move {
945 while let Some((headers, sender)) =
946 std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
947 {
948 sender
949 .into_inner()
950 .send(mut_self.write_early_hints(version, headers).await)
951 .ok();
952 }
953 futures_util::future::pending::<Result<(), std::io::Error>>().await
954 })
955 } else {
956 futures_util::future::Either::Right(futures_util::future::pending::<
957 Result<(), std::io::Error>,
958 >())
959 };
960
961 let content_length = request
963 .headers()
964 .get(header::CONTENT_LENGTH)
965 .and_then(|v| v.to_str().ok())
966 .and_then(|v| v.parse::<u64>().ok())
967 .unwrap_or(0);
968 let chunked = request
969 .headers()
970 .get(header::TRANSFER_ENCODING)
971 .and_then(|v| v.to_str().ok())
972 .is_some_and(|v| {
973 v.split(',')
974 .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
975 });
976 let has_trailers = request
977 .headers()
978 .get(header::TRAILER)
979 .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
980 .unwrap_or(false);
981 let write_trailers = request
982 .headers()
983 .get(header::TE)
984 .and_then(|v| v.to_str().ok())
985 .map(|v| {
986 v.split(',')
987 .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
988 })
989 .unwrap_or(false);
990
991 let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
993 let upgrade = Upgrade::new(upgrade_rx);
994 let upgraded = upgrade.upgraded.clone();
995 request.extensions_mut().insert(upgrade);
996
997 let mut response = {
999 let read_body_fut = async {
1000 if chunked {
1001 self.read_chunked_body_fn(body_tx, has_trailers).await
1002 } else {
1003 self.read_body_fn(body_tx, content_length).await
1004 }
1005 };
1006 let read_body_fut_pin = std::pin::pin!(read_body_fut);
1007 let request_fut = request_fn(request);
1008 let request_fut_pin = std::pin::pin!(request_fut);
1009 let early_hints_fut_pin = std::pin::pin!(early_hints_fut);
1010
1011 let select_read_body_either =
1012 futures_util::future::select(request_fut_pin, early_hints_fut_pin);
1013 let select_either =
1014 futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
1015
1016 let (response, body_fut) = match select_either {
1017 futures_util::future::Either::Left((result, request_fut)) => {
1018 result?;
1019 (
1020 match request_fut.await {
1021 futures_util::future::Either::Left((response, _)) => response,
1022 futures_util::future::Either::Right((_, _)) => unreachable!(),
1023 },
1024 None,
1025 )
1026 }
1027 futures_util::future::Either::Right((response, read_body_fut)) => (
1028 match response {
1029 futures_util::future::Either::Left((response, _)) => response,
1030 futures_util::future::Either::Right((_, _)) => unreachable!(),
1031 },
1032 Some(read_body_fut),
1033 ),
1034 };
1035
1036 if let Some(body_fut) = body_fut {
1038 body_fut.await?;
1039 }
1040
1041 response.map_err(|e| std::io::Error::other(e.to_string()))?
1042 };
1043
1044 let mut was_upgraded = false;
1045 if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
1046 was_upgraded = true;
1047 response
1048 .headers_mut()
1049 .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1050 } else if keep_alive {
1051 if version == Version::HTTP_10
1052 || response.headers().contains_key(header::CONNECTION)
1053 {
1054 response
1055 .headers_mut()
1056 .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
1057 }
1058 } else if version == Version::HTTP_11
1059 || response.headers().contains_key(header::CONNECTION)
1060 {
1061 response
1062 .headers_mut()
1063 .insert(header::CONNECTION, HeaderValue::from_static("close"));
1064 }
1065
1066 self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
1068 .await?;
1069
1070 if was_upgraded {
1071 let frozen_buf = self.read_buf.freeze();
1073 let _ = upgrade_tx.send(Upgraded::new(
1074 self.io,
1075 if frozen_buf.is_empty() {
1076 None
1077 } else {
1078 Some(frozen_buf)
1079 },
1080 ));
1081 return Ok(());
1082 }
1083
1084 if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
1085 break;
1087 }
1088 }
1089 Ok(())
1090 }
1091}
1092
1093impl<Io> HttpProtocol for Http1<Io>
1094where
1095 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1096{
1097 #[inline]
1098 fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1099 self,
1100 request_fn: F,
1101 error_fn: EF,
1102 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1103 where
1104 F: Fn(Request<Incoming>) -> Fut + 'static,
1105 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1106 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1107 ResE: std::error::Error,
1108 ResBE: std::error::Error,
1109 EF: FnOnce(bool) -> EFut,
1110 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1111 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
1112 EResE: std::error::Error,
1113 EResBE: std::error::Error,
1114 {
1115 #[allow(clippy::type_complexity)]
1116 let no_zerocopy: Option<
1117 Box<
1118 dyn FnMut(
1119 RawHandle,
1120 &Io,
1121 u64,
1122 ) -> Box<
1123 dyn std::future::Future<Output = Result<(), std::io::Error>>
1124 + Unpin
1125 + Send
1126 + Sync,
1127 >,
1128 >,
1129 > = None;
1130 self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1131 }
1132
1133 #[inline]
1134 fn handle<F, Fut, ResB, ResBE, ResE>(
1135 self,
1136 request_fn: F,
1137 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1138 where
1139 F: Fn(Request<Incoming>) -> Fut + 'static,
1140 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1141 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1142 ResE: std::error::Error,
1143 ResBE: std::error::Error,
1144 {
1145 self.handle_with_error_fn(request_fn, |is_timeout| async move {
1146 let mut response = Response::builder();
1147 if is_timeout {
1148 response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1149 } else {
1150 response = response.status(http::StatusCode::BAD_REQUEST);
1151 }
1152 response.body(Empty::new())
1153 })
1154 }
1155}
1156
/// Request body handed to the request handler.
///
/// Frames arrive over a bounded channel that the connection handler fills while it reads the
/// body from the transport.
pub(crate) struct Http1Body {
    /// Receiving half of the body channel. An `Err` item carries an I/O error, and the channel
    /// closing marks the end of the body.
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
}
1161
1162impl Body for Http1Body {
1163 type Data = bytes::Bytes;
1164 type Error = std::io::Error;
1165
1166 #[inline]
1167 fn poll_frame(
1168 self: Pin<&mut Self>,
1169 cx: &mut Context<'_>,
1170 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1171 match std::pin::pin!(self.inner.recv()).poll(cx) {
1172 Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1173 Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
1174 Poll::Ready(Err(_)) => Poll::Ready(None),
1175 Poll::Pending => Poll::Pending,
1176 }
1177 }
1178}
1179
/// Finds the blank line that ends an HTTP head.
///
/// Returns `(index, length)` of the first separator in `slice`: either `\r\n\r\n`
/// (length 4) or a bare `\n\n` (length 2). Returns `None` if neither occurs.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    if slice.len() < 2 {
        return None;
    }
    slice.iter().enumerate().find_map(|(idx, &byte)| {
        let rest = &slice[idx + 1..];
        match byte {
            b'\r' if rest.starts_with(b"\n\r\n") => Some((idx, 4)),
            b'\n' if rest.first() == Some(&b'\n') => Some((idx, 2)),
            _ => None,
        }
    })
}
1199
1200#[inline]
1202fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1203 let mut n = len;
1204 let mut pos = dst.len() - 2;
1205 loop {
1206 pos -= 1;
1207 dst[pos] = HEX_DIGITS[n & 0xF];
1208 n >>= 4;
1209 if n == 0 {
1210 break;
1211 }
1212 }
1213 dst[dst.len() - 2] = b'\r';
1214 dst[dst.len() - 1] = b'\n';
1215 &dst[pos..]
1216}