Skip to main content

vibeio_http/h1/
mod.rs

1mod options;
2mod tests;
3mod writebuf;
4mod zerocopy;
5
6pub use options::*;
7pub use zerocopy::*;
8
/// Native OS handle type handed to zero-copy senders on Unix: a raw file descriptor.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Native OS handle type handed to zero-copy senders on Windows: a raw `HANDLE`.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
13
14use std::{
15    future::Future,
16    io::IoSlice,
17    mem::MaybeUninit,
18    pin::Pin,
19    str::FromStr,
20    task::{Context, Poll},
21    time::UNIX_EPOCH,
22};
23
24use bytes::{Buf, Bytes, BytesMut};
25use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
26use http_body::Body;
27use http_body_util::{BodyExt, Empty};
28use kanal::AsyncReceiver;
29use memchr::{memchr3_iter, memmem};
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio_util::sync::CancellationToken;
32
33use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
34
/// Uppercase hexadecimal digits, in ascending order of value (`HEX_DIGITS[n]` is the digit for `n`).
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// Once at least this many bytes are queued in the response write buffer, they are
/// written to the socket before more body data is queued (16 KiB).
const WRITE_BUF_BATCH_THRESHOLD: usize = 16 * 1024;
37
38/// An HTTP/1.x connection handler.
39///
40/// `Http1` wraps an async I/O stream (`Io`) and provides a complete
41/// HTTP/1.0 and HTTP/1.1 server implementation, including:
42///
43/// - Request head parsing (via [`httparse`])
44/// - Streaming request bodies (content-length and chunked transfer-encoding)
45/// - Chunked response encoding and trailer support
46/// - `100 Continue` and `103 Early Hints` interim responses
47/// - HTTP connection upgrades (e.g. WebSocket)
48/// - Optional zero-copy response sending on Linux (see `Http1::zerocopy`)
49/// - Keep-alive connection reuse
50/// - Graceful shutdown via a [`CancellationToken`]
51///
52/// # Construction
53///
54/// ```rust,ignore
55/// let http1 = Http1::new(tcp_stream, Http1Options::default());
56/// ```
57///
58/// # Serving requests
59///
60/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
61/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
62/// connection to completion:
63///
64/// ```rust,ignore
65/// http1.handle(|req| async move {
66///     Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Hello!"))))
67/// }).await?;
68/// ```
69pub struct Http1<Io> {
70    io: Io,
71    options: options::Http1Options,
72    cancel_token: Option<CancellationToken>,
73    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
74    date_header_value_cached: Option<(String, std::time::SystemTime)>,
75    cached_headers: Option<HeaderMap>,
76    read_buf: BytesMut,
77    response_head_buf: Vec<u8>,
78    write_buf: WriteBuf,
79}
80
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Turns this handler into an [`Http1Zerocopy`], which sends response bodies
    /// from a file descriptor without copying the data through user space
    /// (emulated sendfile).
    ///
    /// A response takes the zero-copy path only when it carries a
    /// `ZerocopyResponse` extension (installed with [`install_zerocopy`]) that
    /// holds the descriptor to send from. All other responses are written the
    /// usual way.
    ///
    /// This method exists only on Linux with the `h1-zerocopy` feature enabled,
    /// and only for `Io` types implementing [`vibeio::io::AsInnerRawHandle`].
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
105
106impl<Io> Http1<Io>
107where
108    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
109{
110    /// Creates a new `Http1` connection handler wrapping the given I/O stream.
111    ///
112    /// The `options` value controls limits, timeouts, and optional features;
113    /// see [`Http1Options`] for details.
114    ///
115    /// # Example
116    ///
117    /// ```rust,ignore
118    /// let http1 = Http1::new(tcp_stream, Http1Options::default());
119    /// ```
120    #[inline]
121    pub fn new(io: Io, options: options::Http1Options) -> Self {
122        // Safety: u8 is a primitive type, so we can safely assume initialization
123        let read_buf = BytesMut::with_capacity(options.max_header_size);
124        let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
125            Box::new_uninit_slice(options.max_header_count);
126        Self {
127            io,
128            options,
129            cancel_token: None,
130            parsed_headers,
131            date_header_value_cached: None,
132            cached_headers: None,
133            read_buf,
134            response_head_buf: Vec::with_capacity(1024),
135            write_buf: WriteBuf::new(),
136        }
137    }
138
139    #[inline]
140    fn get_date_header_value(&mut self) -> &str {
141        let now = std::time::SystemTime::now();
142        if self.date_header_value_cached.as_ref().is_none_or(|v| {
143            v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
144                != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
145        }) {
146            let value = httpdate::fmt_http_date(now).to_string();
147            self.date_header_value_cached = Some((value, now));
148        }
149        self.date_header_value_cached
150            .as_ref()
151            .map(|v| v.0.as_str())
152            .unwrap_or("")
153    }
154
155    /// Attaches a [`CancellationToken`] for graceful shutdown.
156    ///
157    /// After the current in-flight request has been fully handled and its
158    /// response written, the connection loop checks whether the token has been
159    /// cancelled. If it has, the loop exits cleanly instead of waiting for the
160    /// next request.
161    ///
162    /// This allows the server to drain active connections without abruptly
163    /// closing them mid-response.
164    #[inline]
165    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
166        self.cancel_token = Some(token);
167        self
168    }
169
170    #[inline]
171    async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
172        if self.read_buf.remaining() < 1024 {
173            self.read_buf.reserve(1024);
174        }
175        let spare_capacity = self.read_buf.spare_capacity_mut();
176        // Safety: The buffer is are read only after the request head has been parsed
177        let n = self
178            .io
179            .read(unsafe {
180                &mut *std::ptr::slice_from_raw_parts_mut(
181                    spare_capacity.as_mut_ptr() as *mut u8,
182                    spare_capacity.len(),
183                )
184            })
185            .await?;
186        if n == 0 {
187            return Ok(0);
188        }
189        unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
190        Ok(n)
191    }
192
193    #[inline]
194    async fn read_body_fn(
195        &mut self,
196        body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
197        content_length: u64,
198    ) -> Result<(), std::io::Error> {
199        let mut remaining = content_length;
200        let mut just_started = true;
201        while remaining > 0 {
202            let have_to_read_buf = !just_started || self.read_buf.is_empty();
203            just_started = false;
204            if have_to_read_buf {
205                let n = self.fill_buf().await?;
206                if n == 0 {
207                    break;
208                }
209            }
210            let chunk = self
211                .read_buf
212                .split_to(
213                    self.read_buf
214                        .len()
215                        .min(remaining.min(usize::MAX as u64) as usize),
216                )
217                .freeze();
218            remaining -= chunk.len() as u64;
219
220            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
221        }
222        Ok(())
223    }
224
225    #[inline]
226    async fn read_body_chunk(
227        &mut self,
228        would_have_trailers: bool,
229    ) -> Result<bytes::Bytes, std::io::Error> {
230        let len = {
231            // Safety: u8 is a primitive type, so we can safely assume initialization
232            let mut len_buf_pos: usize = 0;
233            let mut just_started = true;
234            loop {
235                if len_buf_pos >= 48 {
236                    return Err(std::io::Error::new(
237                        std::io::ErrorKind::InvalidData,
238                        "chunk length buffer overflow",
239                    ));
240                }
241
242                let begin_search = len_buf_pos.saturating_sub(1);
243
244                let have_to_read_buf = !just_started || self.read_buf.is_empty();
245                just_started = false;
246                if have_to_read_buf {
247                    let n = self.fill_buf().await?;
248                    if n == 0 {
249                        return Err(std::io::Error::new(
250                            std::io::ErrorKind::UnexpectedEof,
251                            "unexpected EOF",
252                        ));
253                    }
254                    len_buf_pos += n;
255                } else {
256                    len_buf_pos += self.read_buf.len();
257                }
258
259                if let Some(pos) =
260                    memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
261                {
262                    let numbers =
263                        std::str::from_utf8(&self.read_buf[begin_search..begin_search + pos])
264                            .map_err(|_| {
265                                std::io::Error::new(
266                                    std::io::ErrorKind::InvalidData,
267                                    "invalid chunk length",
268                                )
269                            })?;
270                    let len = usize::from_str_radix(numbers, 16).map_err(|_| {
271                        std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
272                    })?;
273                    // Ignore the trailing CRLF
274                    self.read_buf.advance(begin_search + pos + 2);
275                    break len;
276                }
277            }
278        };
279        // Safety: u8 is a primitive type, so we can safely assume initialization
280        let mut read = 0;
281        if len == 0 && would_have_trailers {
282            return Ok(bytes::Bytes::new()); // Empty terminating chunk
283        }
284        let mut just_started = true;
285        // + 2, because we need to read the trailing CRLF
286        while read < len + 2 {
287            let have_to_read_buf = !just_started || self.read_buf.is_empty();
288            just_started = false;
289            if have_to_read_buf {
290                let n = self.fill_buf().await?;
291                if n == 0 {
292                    return Err(std::io::Error::new(
293                        std::io::ErrorKind::UnexpectedEof,
294                        "unexpected EOF",
295                    ));
296                }
297                read += n;
298            } else {
299                read += self.read_buf.len();
300            }
301        }
302        let chunk = self.read_buf.split_to(len).freeze();
303        self.read_buf.advance(2); // Ignore the trailing CRLF
304        Ok(chunk)
305    }
306
307    #[inline]
308    async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
309        // Safety: u8 is a primitive type, so we can safely assume initialization
310        let mut bytes_read: usize = 0;
311        let mut just_started = true;
312        while bytes_read < self.options.max_header_size {
313            let old_bytes_read = bytes_read;
314            let begin_search = old_bytes_read.saturating_sub(3);
315
316            let have_to_read_buf = !just_started || self.read_buf.is_empty();
317            just_started = false;
318            if have_to_read_buf {
319                let n = self.fill_buf().await?;
320                if n == 0 {
321                    return Err(std::io::Error::new(
322                        std::io::ErrorKind::UnexpectedEof,
323                        "unexpected EOF",
324                    ));
325                }
326                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
327            } else {
328                bytes_read =
329                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
330            }
331
332            if bytes_read > 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
333                // No trailers, return None
334                return Ok(None);
335            }
336
337            if let Some(separator_index) =
338                memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
339            {
340                let to_parse_length = begin_search + separator_index + 4;
341                let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
342
343                // Parse trailers using `httparse` crate's header parsing
344                let mut httparse_trailers =
345                    vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
346                let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
347                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
348                if let httparse::Status::Complete((_, trailers)) = status {
349                    let mut trailers_constructed = HeaderMap::new();
350                    for header in trailers {
351                        if header == &httparse::EMPTY_HEADER {
352                            // No more headers...
353                            break;
354                        }
355                        let name = HeaderName::from_bytes(header.name.as_bytes())
356                            .map_err(|e| std::io::Error::other(e.to_string()))?;
357                        let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
358                        let value_len = header.value.len();
359                        // Safety: the header value is already validated by httparse
360                        let value = unsafe {
361                            HeaderValue::from_maybe_shared_unchecked(
362                                buf_ro.slice(value_start..(value_start + value_len)),
363                            )
364                        };
365                        trailers_constructed.append(name, value);
366                    }
367
368                    return Ok(Some(trailers_constructed));
369                } else {
370                    return Err(std::io::Error::new(
371                        std::io::ErrorKind::InvalidInput,
372                        "trailer headers incomplete",
373                    ));
374                }
375            }
376        }
377        Err(std::io::Error::new(
378            std::io::ErrorKind::InvalidData,
379            "request too large",
380        ))
381    }
382
383    #[inline]
384    async fn read_chunked_body_fn(
385        &mut self,
386        body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
387        would_have_trailers: bool,
388    ) -> Result<(), std::io::Error> {
389        loop {
390            let chunk = self.read_body_chunk(would_have_trailers).await?;
391            if chunk.is_empty() {
392                break;
393            }
394
395            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
396        }
397        if would_have_trailers {
398            // Trailers
399            let trailers = self.read_trailers().await?;
400            if let Some(trailers) = trailers {
401                let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
402            }
403        }
404        Ok(())
405    }
406
407    #[inline]
408    async fn read_request(
409        &mut self,
410    ) -> Result<
411        Option<(
412            Request<Incoming>,
413            kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
414        )>,
415        std::io::Error,
416    > {
417        // Parse HTTP request using httparse
418        let (request, body_tx) = {
419            let Some((head, headers)) = self.get_head().await? else {
420                return Ok(None);
421            };
422            // Safety: The headers are read only after the request head has been parsed
423            let headers = unsafe {
424                std::mem::transmute::<
425                    &mut [MaybeUninit<httparse::Header<'static>>],
426                    &mut [MaybeUninit<httparse::Header<'_>>],
427                >(headers)
428            };
429            let mut req = httparse::Request::new(&mut []);
430            let status = req
431                .parse_with_uninit_headers(&head, headers)
432                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
433            if status.is_partial() {
434                return Err(std::io::Error::new(
435                    std::io::ErrorKind::InvalidData,
436                    "partial request head",
437                ));
438            }
439
440            // Convert httparse HTTP request to `http` one
441            let (body_tx, body_rx) = kanal::bounded_async(2);
442            let request_body = Http1Body {
443                inner: Box::pin(body_rx),
444            };
445            let mut request = Request::new(Incoming::H1(request_body));
446            match req.version {
447                Some(0) => *request.version_mut() = http::Version::HTTP_10,
448                Some(1) => *request.version_mut() = http::Version::HTTP_11,
449                _ => *request.version_mut() = http::Version::HTTP_11,
450            };
451            if let Some(method) = req.method {
452                *request.method_mut() = Method::from_bytes(method.as_bytes())
453                    .map_err(|e| std::io::Error::other(e.to_string()))?;
454            }
455            if let Some(path) = req.path {
456                *request.uri_mut() =
457                    Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
458            }
459            let mut header_map = self.cached_headers.take().unwrap_or_default();
460            header_map.clear();
461            let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
462            if additional_capacity > 0 {
463                header_map.reserve(additional_capacity);
464            }
465            for header in req.headers {
466                if header == &httparse::EMPTY_HEADER {
467                    // No more headers...
468                    break;
469                }
470                let name = HeaderName::from_bytes(header.name.as_bytes())
471                    .map_err(|e| std::io::Error::other(e.to_string()))?;
472                let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
473                let value_len = header.value.len();
474                // Safety: the header value is already validated by httparse
475                let value = unsafe {
476                    HeaderValue::from_maybe_shared_unchecked(
477                        head.slice(value_start..(value_start + value_len)),
478                    )
479                };
480                header_map.append(name, value);
481            }
482            *request.headers_mut() = header_map;
483
484            (request, body_tx)
485        };
486        Ok(Some((request, body_tx)))
487    }
488
489    #[inline]
490    async fn get_head(
491        &mut self,
492    ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
493    {
494        let mut request_line_read = false;
495        let mut bytes_read: usize = 0;
496        let mut whitespace_trimmed = None;
497        let mut just_started = true;
498        while bytes_read < self.options.max_header_size {
499            let old_bytes_read = bytes_read;
500            let begin_search = old_bytes_read.saturating_sub(3);
501
502            let have_to_read_buf = !just_started || self.read_buf.is_empty();
503            just_started = false;
504            if have_to_read_buf {
505                let n = self.fill_buf().await?;
506                if n == 0 {
507                    if whitespace_trimmed.is_none() {
508                        return Ok(None);
509                    }
510                    return Err(std::io::Error::new(
511                        std::io::ErrorKind::UnexpectedEof,
512                        "unexpected EOF",
513                    ));
514                }
515                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
516            } else {
517                bytes_read =
518                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
519            }
520
521            if whitespace_trimmed.is_none() {
522                whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
523                    .iter()
524                    .position(|b| !b.is_ascii_whitespace());
525            }
526
527            if let Some(whitespace_trimmed) = whitespace_trimmed {
528                // Validate first line (request line) before checking for header/body separator
529                if !request_line_read {
530                    let memchr = memchr3_iter(
531                        b' ',
532                        b'\r',
533                        b'\n',
534                        &self.read_buf[whitespace_trimmed..bytes_read],
535                    );
536                    let mut spaces = 0;
537                    for separator_index in memchr {
538                        if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
539                            if spaces >= 2 {
540                                return Err(std::io::Error::new(
541                                    std::io::ErrorKind::InvalidInput,
542                                    "bad request first line",
543                                ));
544                            }
545                            spaces += 1;
546                        } else if spaces == 2 {
547                            request_line_read = true;
548                            break;
549                        } else {
550                            return Err(std::io::Error::new(
551                                std::io::ErrorKind::InvalidInput,
552                                "bad request first line",
553                            ));
554                        }
555                    }
556                }
557
558                if request_line_read {
559                    let begin_search = begin_search.max(whitespace_trimmed);
560                    if let Some((separator_index, separator_len)) =
561                        search_header_body_separator(&self.read_buf[begin_search..bytes_read])
562                    {
563                        let to_parse_length =
564                            begin_search + separator_index + separator_len - whitespace_trimmed;
565                        self.read_buf.advance(whitespace_trimmed);
566                        let head = self.read_buf.split_to(to_parse_length);
567                        return Ok(Some((head.freeze(), &mut self.parsed_headers)));
568                    }
569                }
570            }
571        }
572        Err(std::io::Error::new(
573            std::io::ErrorKind::InvalidData,
574            "request too large",
575        ))
576    }
577
578    #[inline]
579    async fn write_response<Z, ZFut>(
580        &mut self,
581        mut response: Response<
582            impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
583        >,
584        version: Version,
585        write_trailers: bool,
586        zerocopy_fn: Option<Z>,
587    ) -> Result<(), std::io::Error>
588    where
589        Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
590        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
591    {
592        // Date header
593        if self.options.send_date_header {
594            response.headers_mut().insert(
595                header::DATE,
596                HeaderValue::from_str(self.get_date_header_value())
597                    .map_err(|e| std::io::Error::other(e.to_string()))?,
598            );
599        }
600
601        // If the body has a size hint, set the Content-Length header if it's not already set
602        if let Some(suggested_content_length) = response.body().size_hint().exact() {
603            let headers = response.headers_mut();
604            if !headers.contains_key(header::CONTENT_LENGTH) {
605                headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
606            }
607        }
608
609        let chunked = response
610            .headers()
611            .get(header::TRANSFER_ENCODING)
612            .map(|v| {
613                v.to_str().ok().is_some_and(|s| {
614                    s.split(',')
615                        .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
616                })
617            })
618            .unwrap_or_else(|| {
619                response
620                    .headers()
621                    .get(header::CONTENT_LENGTH)
622                    .and_then(|v| v.to_str().ok())
623                    .is_none_or(|s| s.parse::<u64>().is_err())
624            });
625
626        if chunked {
627            response.headers_mut().insert(
628                header::TRANSFER_ENCODING,
629                HeaderValue::from_static("chunked"),
630            );
631            while response
632                .headers_mut()
633                .remove(header::CONTENT_LENGTH)
634                .is_some()
635            {}
636        }
637
638        let (parts, mut body) = response.into_parts();
639
640        self.response_head_buf.clear();
641        let estimated_head_len = 30 + parts.headers.len() * 30; // Similar to Hyper's heuristic
642        if self.response_head_buf.capacity() < estimated_head_len {
643            self.response_head_buf
644                .reserve(estimated_head_len - self.response_head_buf.capacity());
645        }
646        let head = &mut self.response_head_buf;
647        if version == Version::HTTP_10 {
648            head.extend_from_slice(b"HTTP/1.0 ");
649        } else {
650            head.extend_from_slice(b"HTTP/1.1 ");
651        }
652        let status = parts.status;
653        head.extend_from_slice(status.as_str().as_bytes());
654        if let Some(canonical_reason) = status.canonical_reason() {
655            head.extend_from_slice(b" ");
656            head.extend_from_slice(canonical_reason.as_bytes());
657        }
658        head.extend_from_slice(b"\r\n");
659        for (name, value) in &parts.headers {
660            head.extend_from_slice(name.as_str().as_bytes());
661            head.extend_from_slice(b": ");
662            head.extend_from_slice(value.as_bytes());
663            head.extend_from_slice(b"\r\n");
664        }
665        head.extend_from_slice(b"\r\n");
666        unsafe {
667            self.write_buf.push(IoSlice::new(head));
668        }
669
670        if !chunked {
671            if let Some(content_length) = parts
672                .headers
673                .get(header::CONTENT_LENGTH)
674                .and_then(|v| v.to_str().ok())
675                .and_then(|s| s.parse::<u64>().ok())
676            {
677                if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
678                    if let Some(mut zerocopy_fn) = zerocopy_fn {
679                        // Zerocopy
680                        unsafe {
681                            self.write_buf
682                                .flush(&mut self.io, self.options.enable_vectored_write)
683                                .await?
684                        };
685                        zerocopy_fn(
686                            zero_copy.handle,
687                            // Safety: the lifetime of the static reference is bound by the lifetime of the Io struct
688                            unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
689                            content_length,
690                        )
691                        .await?;
692                        self.io.flush().await?;
693                        let reclaimed_headers = parts.headers;
694                        self.cached_headers = Some(reclaimed_headers);
695                        return Ok(());
696                    }
697                }
698            }
699        }
700
701        let mut trailers_written = false;
702        while let Some(chunk) = body.frame().await {
703            let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
704            match chunk.into_data() {
705                Ok(data) => {
706                    if chunked {
707                        let mut chunk_size_buf = [0u8; 18];
708                        let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
709                        self.write_buf.push_copy(chunk_size);
710                        self.write_buf.push_bytes(data);
711                        unsafe {
712                            self.write_buf.push(IoSlice::new(b"\r\n"));
713                        }
714                    } else {
715                        self.write_buf.push_bytes(data);
716                    }
717                    while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
718                        unsafe {
719                            self.write_buf
720                                .write(&mut self.io, self.options.enable_vectored_write)
721                                .await?;
722                        }
723                    }
724                }
725                Err(chunk) => {
726                    if let Ok(trailers) = chunk.into_trailers() {
727                        if write_trailers {
728                            unsafe {
729                                self.write_buf.push(IoSlice::new(b"0\r\n"));
730                                for (name, value) in &trailers {
731                                    self.write_buf.push_copy(name.as_str().as_bytes());
732                                    self.write_buf.push(IoSlice::new(b": "));
733                                    self.write_buf.push_copy(value.as_bytes());
734                                    self.write_buf.push(IoSlice::new(b"\r\n"));
735                                }
736                                self.write_buf.push(IoSlice::new(b"\r\n"));
737                            }
738                            trailers_written = true;
739                        }
740                        break;
741                    }
742                }
743            };
744        }
745        if chunked && !trailers_written {
746            // Terminating chunk
747            unsafe {
748                self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
749            }
750        }
751        unsafe {
752            self.write_buf
753                .flush(&mut self.io, self.options.enable_vectored_write)
754                .await?;
755        }
756        self.io.flush().await?;
757        let reclaimed_headers = parts.headers;
758        self.cached_headers = Some(reclaimed_headers);
759
760        Ok(())
761    }
762
763    #[inline]
764    async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
765        if version == Version::HTTP_10 {
766            self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
767        } else {
768            self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
769        }
770        self.io.flush().await?;
771
772        Ok(())
773    }
774
775    #[inline]
776    async fn write_early_hints(
777        &mut self,
778        version: Version,
779        headers: http::HeaderMap,
780    ) -> Result<(), std::io::Error> {
781        let mut head = Vec::new();
782        if version == Version::HTTP_10 {
783            head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
784        } else {
785            head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
786        }
787        let mut current_header_name = None;
788        for (name, value) in headers {
789            if let Some(name) = name {
790                current_header_name = Some(name);
791            };
792            if let Some(current_header_name) = &current_header_name {
793                head.extend_from_slice(current_header_name.as_str().as_bytes());
794                if value.is_empty() {
795                    head.extend_from_slice(b":\r\n");
796                    continue;
797                }
798                head.extend_from_slice(b": ");
799                head.extend_from_slice(value.as_bytes());
800                head.extend_from_slice(b"\r\n");
801            }
802        }
803        head.extend_from_slice(b"\r\n");
804
805        self.io.write_all(&head).await?;
806
807        Ok(())
808    }
809
    /// Drives the HTTP/1.x request/response loop on this connection. The loop
    /// stops when the connection should not be kept alive, the connection is
    /// upgraded, or a graceful shutdown is observed.
    ///
    /// Each iteration does the following, in order:
    /// 1. Reads one request head, bounded by `header_read_timeout` when set.
    /// 2. Decides keep-alive from the `Connection` header and the HTTP version.
    /// 3. Optionally answers `Expect: 100-continue`.
    /// 4. Installs the early-hints sender and the `Upgrade` slot as request extensions.
    /// 5. Races the request handler (together with the early-hints writer)
    ///    against the request-body reader.
    /// 6. Writes the response with `write_response`, forwarding `zerocopy_fn`.
    ///
    /// Parameters:
    /// - `request_fn`: called once per request to produce the response.
    /// - `error_fn`: called at most once. It receives `true` on a header-read
    ///   timeout and `false` when `read_request` fails. If it returns `Ok`, its
    ///   response is sent with `Connection: close` before this function returns
    ///   the error.
    /// - `zerocopy_fn`: optional callback passed through to `write_response`.
    ///
    /// # Errors
    ///
    /// Returns:
    /// - the error from `read_request`;
    /// - a `TimedOut` error on header-read timeout;
    /// - any I/O error from reading the body, writing a `100 Continue`, or
    ///   writing the response;
    /// - the handler's error, converted with `std::io::Error::other`.
    #[inline]
    pub(crate) async fn handle_with_error_fn_and_zerocopy<
        F,
        Fut,
        ResB,
        ResBE,
        ResE,
        EF,
        EFut,
        EResB,
        EResBE,
        EResE,
        ZF,
        ZFut,
    >(
        mut self,
        request_fn: F,
        error_fn: EF,
        mut zerocopy_fn: Option<ZF>,
    ) -> Result<(), std::io::Error>
    where
        F: Fn(Request<Incoming>) -> Fut + 'static,
        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
        ResE: std::error::Error,
        ResBE: std::error::Error,
        EF: FnOnce(bool) -> EFut,
        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
        EResE: std::error::Error,
        EResBE: std::error::Error,
        // NOTE(review): the callback is handed `&'static Io` even though the I/O
        // object is owned by this connection. Confirm it never retains the reference.
        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
    {
        let mut keep_alive = true;

        while keep_alive {
            // Read the next request head, bounded by `header_read_timeout` when
            // configured. The un-timed branch is wrapped in `Ok` so both branches
            // share the same nested `Result` shape.
            let (mut request, body_tx) = match if let Some(timeout) =
                self.options.header_read_timeout
            {
                vibeio::time::timeout(timeout, self.read_request()).await
            } else {
                Ok(self.read_request().await)
            } {
                Ok(Ok(Some(d))) => d,
                Ok(Ok(None)) => {
                    // No further request. Presumably the peer closed the
                    // connection; confirm against `read_request`.
                    return Ok(());
                }
                Ok(Err(e)) => {
                    // Parse error
                    // Best-effort error response. A write failure is ignored
                    // because the original error is returned regardless.
                    if let Ok(mut response) = error_fn(false).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(e);
                }
                Err(_) => {
                    // Timeout error
                    // Best-effort 408-style response before giving up on the connection.
                    if let Ok(mut response) = error_fn(true).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "header read timeout",
                    ));
                }
            };

            // Connection header detection
            // A `close` token always wins. Otherwise the connection persists for
            // HTTP/1.1, or when the client sends an explicit `keep-alive` token.
            let connection_header_split = request
                .headers()
                .get(header::CONNECTION)
                .and_then(|v| v.to_str().ok())
                .map(|v| v.split(",").map(|v| v.trim()));
            let is_connection_close = connection_header_split
                .clone()
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
            let is_connection_keep_alive = connection_header_split
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
            keep_alive = !is_connection_close
                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);

            let version = request.version();

            // 100 Continue
            // NOTE(review): this branch does not itself check `version`. RFC 9110
            // §10.1.1 requires a `100-continue` expectation in an HTTP/1.0 request
            // to be ignored.
            if self.options.send_continue_response {
                let is_100_continue = request
                    .headers()
                    .get(header::EXPECT)
                    .and_then(|v| v.to_str().ok())
                    .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
                if is_100_continue {
                    self.write_100_continue(version).await?;
                }
            }

            // 103 Early Hints
            // Both branches are futures that never complete on their own (each
            // ends in `pending`). The `select` below therefore only resolves
            // through the request future, which is why the `Right` arms are
            // `unreachable!()`.
            let early_hints_fut = if self.options.enable_early_hints {
                let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
                request.extensions_mut().insert(early_hints);
                // Safety: the function below is used only in futures_util::future::select
                // Also, another function that would borrow self would read data,
                // while this function would write data
                // NOTE(review): the identity transmute only erases the borrow. While
                // this future and `read_body_fut` are both alive there are two live
                // `&mut Self` to the same value. That is undefined behaviour under
                // Rust's aliasing rules even if they touch disjoint fields.
                // Also, if the request future completes while a 103 write is
                // pending, this future is never polled again and a partial interim
                // response may be left on the wire. Confirm both points.
                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
                futures_util::future::Either::Left(async move {
                    while let Some((headers, sender)) =
                        std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
                    {
                        // Report the write result back to whoever requested the hints.
                        sender
                            .into_inner()
                            .send(mut_self.write_early_hints(version, headers).await)
                            .ok();
                    }
                    futures_util::future::pending::<Result<(), std::io::Error>>().await
                })
            } else {
                futures_util::future::Either::Right(futures_util::future::pending::<
                    Result<(), std::io::Error>,
                >())
            };

            // Content-Length header
            // NOTE(review): a missing or unparsable Content-Length becomes 0 instead
            // of being rejected. RFC 9112 §6.3 treats an invalid Content-Length as an
            // unrecoverable error; confirm this leniency is intended.
            let content_length = request
                .headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(0);
            // A `chunked` transfer-coding takes precedence over Content-Length when
            // choosing the body reader below.
            let chunked = request
                .headers()
                .get(header::TRANSFER_ENCODING)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
                });
            // Whether the request declares trailer fields (non-empty `Trailer`
            // header). Passed to the chunked body reader.
            let has_trailers = request
                .headers()
                .get(header::TRAILER)
                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
                .unwrap_or(false);
            // Response trailers are only emitted if the client sent `TE: trailers`.
            let write_trailers = request
                .headers()
                .get(header::TE)
                .and_then(|v| v.to_str().ok())
                .map(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
                })
                .unwrap_or(false);

            // Install HTTP upgrade
            // `upgraded` is a shared flag checked after the handler runs.
            // Presumably the handler sets it through the `Upgrade` extension.
            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
            let upgrade = Upgrade::new(upgrade_rx);
            let upgraded = upgrade.upgraded.clone();
            request.extensions_mut().insert(upgrade);

            // Get HTTP response
            // The body reader runs concurrently with the handler so a handler
            // that streams the request body can make progress.
            let mut response = {
                let read_body_fut = async {
                    if chunked {
                        self.read_chunked_body_fn(body_tx, has_trailers).await
                    } else {
                        self.read_body_fn(body_tx, content_length).await
                    }
                };
                let read_body_fut_pin = std::pin::pin!(read_body_fut);
                let request_fut = request_fn(request);
                let request_fut_pin = std::pin::pin!(request_fut);
                let early_hints_fut_pin = std::pin::pin!(early_hints_fut);

                let select_read_body_either =
                    futures_util::future::select(request_fut_pin, early_hints_fut_pin);
                let select_either =
                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;

                let (response, body_fut) = match select_either {
                    // Body fully read first: surface read errors, then wait for the handler.
                    futures_util::future::Either::Left((result, request_fut)) => {
                        result?;
                        (
                            match request_fut.await {
                                futures_util::future::Either::Left((response, _)) => response,
                                futures_util::future::Either::Right((_, _)) => unreachable!(),
                            },
                            None,
                        )
                    }
                    // Handler finished first: keep the body reader so the rest of the
                    // request body is consumed before the next request head is read.
                    futures_util::future::Either::Right((response, read_body_fut)) => (
                        match response {
                            futures_util::future::Either::Left((response, _)) => response,
                            futures_util::future::Either::Right((_, _)) => unreachable!(),
                        },
                        Some(read_body_fut),
                    ),
                };

                // Drain away remaining body
                if let Some(body_fut) = body_fut {
                    body_fut.await?;
                }

                response.map_err(|e| std::io::Error::other(e.to_string()))?
            };

            // Set the response's `Connection` header to match the decision above.
            let mut was_upgraded = false;
            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
                was_upgraded = true;
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
            } else if keep_alive {
                // HTTP/1.0 needs an explicit `keep-alive` to stay open.
                // NOTE(review): if the handler set its own `Connection` header (e.g.
                // `close`), it is overwritten with `keep-alive` here and the
                // connection stays open. Confirm this is intended.
                if version == Version::HTTP_10
                    || response.headers().contains_key(header::CONNECTION)
                {
                    response
                        .headers_mut()
                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
                }
            } else if version == Version::HTTP_11
                || response.headers().contains_key(header::CONNECTION)
            {
                // HTTP/1.1 is persistent by default, so closing must be announced.
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
            }

            // Write response to IO
            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
                .await?;

            if was_upgraded {
                // HTTP upgrade
                // Hand the raw I/O, plus any bytes still buffered in `read_buf`, to
                // the upgrade receiver. The connection is no longer HTTP/1.
                let frozen_buf = self.read_buf.freeze();
                let _ = upgrade_tx.send(Upgraded::new(
                    self.io,
                    if frozen_buf.is_empty() {
                        None
                    } else {
                        Some(frozen_buf)
                    },
                ));
                return Ok(());
            }

            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
                // Graceful shutdown requested, break out of loop
                break;
            }
        }
        Ok(())
    }
1072}
1073
1074impl<Io> HttpProtocol for Http1<Io>
1075where
1076    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1077{
1078    #[inline]
1079    fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1080        self,
1081        request_fn: F,
1082        error_fn: EF,
1083    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1084    where
1085        F: Fn(Request<Incoming>) -> Fut + 'static,
1086        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1087        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1088        ResE: std::error::Error,
1089        ResBE: std::error::Error,
1090        EF: FnOnce(bool) -> EFut,
1091        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1092        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
1093        EResE: std::error::Error,
1094        EResBE: std::error::Error,
1095    {
1096        #[allow(clippy::type_complexity)]
1097        let no_zerocopy: Option<
1098            Box<
1099                dyn FnMut(
1100                    RawHandle,
1101                    &Io,
1102                    u64,
1103                ) -> Box<
1104                    dyn std::future::Future<Output = Result<(), std::io::Error>>
1105                        + Unpin
1106                        + Send
1107                        + Sync,
1108                >,
1109            >,
1110        > = None;
1111        self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1112    }
1113
1114    #[inline]
1115    fn handle<F, Fut, ResB, ResBE, ResE>(
1116        self,
1117        request_fn: F,
1118    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1119    where
1120        F: Fn(Request<Incoming>) -> Fut + 'static,
1121        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1122        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1123        ResE: std::error::Error,
1124        ResBE: std::error::Error,
1125    {
1126        self.handle_with_error_fn(request_fn, |is_timeout| async move {
1127            let mut response = Response::builder();
1128            if is_timeout {
1129                response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1130            } else {
1131                response = response.status(http::StatusCode::BAD_REQUEST);
1132            }
1133            response.body(Empty::new())
1134        })
1135    }
1136}
1137
1138pub(crate) struct Http1Body {
1139    #[allow(clippy::type_complexity)]
1140    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
1141}
1142
1143impl Body for Http1Body {
1144    type Data = bytes::Bytes;
1145    type Error = std::io::Error;
1146
1147    #[inline]
1148    fn poll_frame(
1149        self: Pin<&mut Self>,
1150        cx: &mut Context<'_>,
1151    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1152        match std::pin::pin!(self.inner.recv()).poll(cx) {
1153            Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1154            Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
1155            Poll::Ready(Err(_)) => Poll::Ready(None),
1156            Poll::Pending => Poll::Pending,
1157        }
1158    }
1159}
1160
/// Locates the blank line that ends an HTTP/1 header section.
///
/// Scans `slice` from the front. It returns the offset of the first
/// separator found, together with that separator's length:
/// - `(offset, 4)` for `\r\n\r\n`;
/// - `(offset, 2)` for a bare `\n\n`.
///
/// Returns `None` when no separator is present. Inputs shorter than two
/// bytes always yield `None`.
///
/// NOTE(review): a bare LF followed by CRLF (`\n\r\n`) is not recognised as a
/// separator, although lenient parsers may accept it.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    // Nothing shorter than two bytes can hold even the `\n\n` form.
    if slice.len() < 2 {
        return None;
    }
    slice.iter().enumerate().find_map(|(offset, &byte)| {
        let tail = &slice[offset..];
        match byte {
            b'\r' if tail.starts_with(b"\r\n\r\n") => Some((offset, 4)),
            b'\n' if tail.starts_with(b"\n\n") => Some((offset, 2)),
            _ => None,
        }
    })
}
1180
1181/// Writes the chunk size to the given buffer in hexadecimal format, followed by `\r\n`.
1182#[inline]
1183fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1184    let mut n = len;
1185    let mut pos = dst.len() - 2;
1186    loop {
1187        pos -= 1;
1188        dst[pos] = HEX_DIGITS[n & 0xF];
1189        n >>= 4;
1190        if n == 0 {
1191            break;
1192        }
1193    }
1194    dst[dst.len() - 2] = b'\r';
1195    dst[dst.len() - 1] = b'\n';
1196    &dst[pos..]
1197}