Skip to main content

vibeio_http/h1/
mod.rs

1mod options;
2mod tests;
3mod upgrade;
4mod writebuf;
5mod zerocopy;
6
7pub use options::*;
8pub use upgrade::*;
9pub use zerocopy::*;
10
/// Raw OS handle type passed to the zero-copy send callback (a raw file
/// descriptor on Unix targets).
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Raw OS handle type passed to the zero-copy send callback (a raw Windows
/// `HANDLE` on Windows targets).
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
15
16use std::{
17    io::IoSlice,
18    mem::MaybeUninit,
19    pin::Pin,
20    str::FromStr,
21    task::{Context, Poll},
22    time::UNIX_EPOCH,
23};
24
25use async_channel::Receiver;
26use bytes::{Buf, Bytes, BytesMut};
27use futures_util::stream::Stream;
28use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
29use http_body::Body;
30use http_body_util::{BodyExt, Empty};
31use memchr::{memchr3_iter, memmem};
32use tokio::io::{AsyncReadExt, AsyncWriteExt};
33use tokio_util::sync::CancellationToken;
34
35use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming};
36
/// Uppercase hexadecimal digits: entry `i` is the ASCII digit for the value `i`.
const HEX_DIGITS: &[u8; 16] = &[
    b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7',
    b'8', b'9', b'A', b'B', b'C', b'D', b'E', b'F',
];
38
/// An HTTP/1.x connection handler.
///
/// `Http1` wraps an async I/O stream (`Io`) and provides a complete
/// HTTP/1.0 and HTTP/1.1 server implementation, including:
///
/// - Request head parsing (via [`httparse`])
/// - Streaming request bodies (content-length and chunked transfer-encoding)
/// - Chunked response encoding and trailer support
/// - `100 Continue` and `103 Early Hints` interim responses
/// - HTTP connection upgrades (e.g. WebSocket)
/// - Optional zero-copy response sending on Linux (see `Http1::zerocopy`)
/// - Keep-alive connection reuse
/// - Graceful shutdown via a [`CancellationToken`]
///
/// # Construction
///
/// ```rust,ignore
/// let http1 = Http1::new(tcp_stream, Http1Options::default());
/// ```
///
/// # Serving requests
///
/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
/// connection to completion:
///
/// ```rust,ignore
/// http1.handle(|req| async move {
///     Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Hello!"))))
/// }).await?;
/// ```
pub struct Http1<Io> {
    /// The wrapped connection stream; all reads and writes go through it.
    io: Io,
    /// Limits and feature switches (header size/count limits, `Date` header,
    /// vectored writes, ...).
    options: options::Http1Options,
    /// Optional shutdown token, installed by `graceful_shutdown_token`.
    cancel_token: Option<CancellationToken>,
    /// Header slots handed to `httparse` when parsing a request head; sized by
    /// `options.max_header_count` and reused for every request.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value and the time it was formatted;
    /// regenerated once the whole-second Unix timestamp changes.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map reclaimed from the last written response, reused to hold
    /// the next request's headers and avoid reallocating.
    cached_headers: Option<HeaderMap>,
    /// Bytes read from `io` that have not been consumed yet (request heads
    /// and body data).
    read_buf: BytesMut,
    /// Outgoing `IoSlice`s queued while a response is being written.
    write_buf: WriteBuf,
}
80
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Switches this handler into zero-copy mode by wrapping it in an
    /// [`Http1Zerocopy`].
    ///
    /// Responses whose body carries a `ZerocopyResponse` extension (installed
    /// via [`install_zerocopy`]) are sent from that extension's file
    /// descriptor using emulated sendfile, so the data does not pass through
    /// user space. Responses without the extension take the regular path.
    ///
    /// Compiled only on Linux (`target_os = "linux"`) with the `h1-zerocopy`
    /// feature, and only for `Io` types implementing
    /// [`vibeio::io::AsInnerRawHandle`].
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
105
106impl<Io> Http1<Io>
107where
108    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
109{
110    /// Creates a new `Http1` connection handler wrapping the given I/O stream.
111    ///
112    /// The `options` value controls limits, timeouts, and optional features;
113    /// see [`Http1Options`] for details.
114    ///
115    /// # Example
116    ///
117    /// ```rust,ignore
118    /// let http1 = Http1::new(tcp_stream, Http1Options::default());
119    /// ```
120    #[inline]
121    pub fn new(io: Io, options: options::Http1Options) -> Self {
122        // Safety: u8 is a primitive type, so we can safely assume initialization
123        let read_buf = BytesMut::with_capacity(options.max_header_size);
124        let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
125            Box::new_uninit_slice(options.max_header_count);
126        Self {
127            io,
128            options,
129            cancel_token: None,
130            parsed_headers,
131            date_header_value_cached: None,
132            cached_headers: None,
133            read_buf,
134            write_buf: WriteBuf::new(),
135        }
136    }
137
138    #[inline]
139    fn get_date_header_value(&mut self) -> &str {
140        let now = std::time::SystemTime::now();
141        if self.date_header_value_cached.as_ref().is_none_or(|v| {
142            v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
143                != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
144        }) {
145            let value = httpdate::fmt_http_date(now).to_string();
146            self.date_header_value_cached = Some((value, now));
147        }
148        self.date_header_value_cached
149            .as_ref()
150            .map(|v| v.0.as_str())
151            .unwrap_or("")
152    }
153
154    /// Attaches a [`CancellationToken`] for graceful shutdown.
155    ///
156    /// After the current in-flight request has been fully handled and its
157    /// response written, the connection loop checks whether the token has been
158    /// cancelled. If it has, the loop exits cleanly instead of waiting for the
159    /// next request.
160    ///
161    /// This allows the server to drain active connections without abruptly
162    /// closing them mid-response.
163    #[inline]
164    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
165        self.cancel_token = Some(token);
166        self
167    }
168
169    #[inline]
170    async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
171        if self.read_buf.remaining() < 1024 {
172            self.read_buf.reserve(1024);
173        }
174        let spare_capacity = self.read_buf.spare_capacity_mut();
175        // Safety: The buffer is are read only after the request head has been parsed
176        let n = self
177            .io
178            .read(unsafe {
179                &mut *std::ptr::slice_from_raw_parts_mut(
180                    spare_capacity.as_mut_ptr() as *mut u8,
181                    spare_capacity.len(),
182                )
183            })
184            .await?;
185        if n == 0 {
186            return Ok(0);
187        }
188        unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
189        Ok(n)
190    }
191
192    #[inline]
193    async fn read_body_fn(
194        &mut self,
195        body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
196        content_length: u64,
197    ) -> Result<(), std::io::Error> {
198        let mut remaining = content_length;
199        let mut just_started = true;
200        while remaining > 0 {
201            let have_to_read_buf = !just_started || self.read_buf.is_empty();
202            just_started = false;
203            if have_to_read_buf {
204                let n = self.fill_buf().await?;
205                if n == 0 {
206                    break;
207                }
208            }
209            let chunk = self
210                .read_buf
211                .split_to(
212                    self.read_buf
213                        .len()
214                        .min(remaining.min(usize::MAX as u64) as usize),
215                )
216                .freeze();
217            remaining -= chunk.len() as u64;
218
219            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
220        }
221        body_tx.close(); // Close the body_tx channel to signal EOF
222        Ok(())
223    }
224
225    #[inline]
226    async fn read_body_chunk(
227        &mut self,
228        would_have_trailers: bool,
229    ) -> Result<bytes::Bytes, std::io::Error> {
230        let len = {
231            // Safety: u8 is a primitive type, so we can safely assume initialization
232            let mut len_buf_pos: usize = 0;
233            let mut just_started = true;
234            loop {
235                if len_buf_pos >= 48 {
236                    return Err(std::io::Error::new(
237                        std::io::ErrorKind::InvalidData,
238                        "chunk length buffer overflow",
239                    ));
240                }
241
242                let begin_search = len_buf_pos.saturating_sub(1);
243
244                let have_to_read_buf = !just_started || self.read_buf.is_empty();
245                just_started = false;
246                if have_to_read_buf {
247                    let n = self.fill_buf().await?;
248                    if n == 0 {
249                        return Err(std::io::Error::new(
250                            std::io::ErrorKind::UnexpectedEof,
251                            "unexpected EOF",
252                        ));
253                    }
254                    len_buf_pos += n;
255                } else {
256                    len_buf_pos += self.read_buf.len();
257                }
258
259                if let Some(pos) =
260                    memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
261                {
262                    let numbers =
263                        std::str::from_utf8(&self.read_buf[begin_search..begin_search + pos])
264                            .map_err(|_| {
265                                std::io::Error::new(
266                                    std::io::ErrorKind::InvalidData,
267                                    "invalid chunk length",
268                                )
269                            })?;
270                    let len = usize::from_str_radix(numbers, 16).map_err(|_| {
271                        std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
272                    })?;
273                    // Ignore the trailing CRLF
274                    self.read_buf.advance(begin_search + pos + 2);
275                    break len;
276                }
277            }
278        };
279        // Safety: u8 is a primitive type, so we can safely assume initialization
280        let mut read = 0;
281        if len == 0 && would_have_trailers {
282            return Ok(bytes::Bytes::new()); // Empty terminating chunk
283        }
284        let mut just_started = true;
285        // + 2, because we need to read the trailing CRLF
286        while read < len + 2 {
287            let have_to_read_buf = !just_started || self.read_buf.is_empty();
288            just_started = false;
289            if have_to_read_buf {
290                let n = self.fill_buf().await?;
291                if n == 0 {
292                    return Err(std::io::Error::new(
293                        std::io::ErrorKind::UnexpectedEof,
294                        "unexpected EOF",
295                    ));
296                }
297                read += n;
298            } else {
299                read += self.read_buf.len();
300            }
301        }
302        let chunk = self.read_buf.split_to(len).freeze();
303        self.read_buf.advance(2); // Ignore the trailing CRLF
304        Ok(chunk)
305    }
306
307    #[inline]
308    async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
309        // Safety: u8 is a primitive type, so we can safely assume initialization
310        let mut bytes_read: usize = 0;
311        let mut just_started = true;
312        while bytes_read < self.options.max_header_size {
313            let old_bytes_read = bytes_read;
314            let begin_search = old_bytes_read.saturating_sub(3);
315
316            let have_to_read_buf = !just_started || self.read_buf.is_empty();
317            just_started = false;
318            if have_to_read_buf {
319                let n = self.fill_buf().await?;
320                if n == 0 {
321                    return Err(std::io::Error::new(
322                        std::io::ErrorKind::UnexpectedEof,
323                        "unexpected EOF",
324                    ));
325                }
326                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
327            } else {
328                bytes_read =
329                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
330            }
331
332            if bytes_read > 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
333                // No trailers, return None
334                return Ok(None);
335            }
336
337            if let Some(separator_index) =
338                memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
339            {
340                let to_parse_length = begin_search + separator_index + 4;
341                let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
342
343                // Parse trailers using `httparse` crate's header parsing
344                let mut httparse_trailers =
345                    vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
346                let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
347                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
348                if let httparse::Status::Complete((_, trailers)) = status {
349                    let mut trailers_constructed = HeaderMap::new();
350                    for header in trailers {
351                        if header == &httparse::EMPTY_HEADER {
352                            // No more headers...
353                            break;
354                        }
355                        let name = HeaderName::from_bytes(header.name.as_bytes())
356                            .map_err(|e| std::io::Error::other(e.to_string()))?;
357                        let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
358                        let value_len = header.value.len();
359                        // Safety: the header value is already validated by httparse
360                        let value = unsafe {
361                            HeaderValue::from_maybe_shared_unchecked(
362                                buf_ro.slice(value_start..(value_start + value_len)),
363                            )
364                        };
365                        trailers_constructed.append(name, value);
366                    }
367
368                    return Ok(Some(trailers_constructed));
369                } else {
370                    return Err(std::io::Error::new(
371                        std::io::ErrorKind::InvalidInput,
372                        "trailer headers incomplete",
373                    ));
374                }
375            }
376        }
377        Err(std::io::Error::new(
378            std::io::ErrorKind::InvalidData,
379            "request too large",
380        ))
381    }
382
383    #[inline]
384    async fn read_chunked_body_fn(
385        &mut self,
386        body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
387        would_have_trailers: bool,
388    ) -> Result<(), std::io::Error> {
389        loop {
390            let chunk = self.read_body_chunk(would_have_trailers).await?;
391            if chunk.is_empty() {
392                break;
393            }
394
395            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
396        }
397        if would_have_trailers {
398            // Trailers
399            let trailers = self.read_trailers().await?;
400            if let Some(trailers) = trailers {
401                let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
402            }
403        }
404        body_tx.close(); // Close the body_tx channel to signal EOF
405        Ok(())
406    }
407
408    #[inline]
409    async fn read_request(
410        &mut self,
411    ) -> Result<
412        Option<(
413            Request<Incoming>,
414            async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
415        )>,
416        std::io::Error,
417    > {
418        // Parse HTTP request using httparse
419        let (request, body_tx) = {
420            let Some((head, headers)) = self.get_head().await? else {
421                return Ok(None);
422            };
423            // Safety: The headers are read only after the request head has been parsed
424            let headers = unsafe {
425                std::mem::transmute::<
426                    &mut [MaybeUninit<httparse::Header<'static>>],
427                    &mut [MaybeUninit<httparse::Header<'_>>],
428                >(headers)
429            };
430            let mut req = httparse::Request::new(&mut []);
431            let status = req
432                .parse_with_uninit_headers(&head, headers)
433                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
434            if status.is_partial() {
435                return Err(std::io::Error::new(
436                    std::io::ErrorKind::InvalidData,
437                    "partial request head",
438                ));
439            }
440
441            // Convert httparse HTTP request to `http` one
442            let (body_tx, body_rx) = async_channel::bounded(2);
443            let request_body = Http1Body {
444                inner: Box::pin(body_rx),
445            };
446            let mut request = Request::new(Incoming::new(request_body));
447            match req.version {
448                Some(0) => *request.version_mut() = http::Version::HTTP_10,
449                Some(1) => *request.version_mut() = http::Version::HTTP_11,
450                _ => *request.version_mut() = http::Version::HTTP_11,
451            };
452            if let Some(method) = req.method {
453                *request.method_mut() = Method::from_bytes(method.as_bytes())
454                    .map_err(|e| std::io::Error::other(e.to_string()))?;
455            }
456            if let Some(path) = req.path {
457                *request.uri_mut() =
458                    Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
459            }
460            let mut header_map = self.cached_headers.take().unwrap_or_default();
461            header_map.clear();
462            let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
463            if additional_capacity > 0 {
464                header_map.reserve(additional_capacity);
465            }
466            for header in req.headers {
467                if header == &httparse::EMPTY_HEADER {
468                    // No more headers...
469                    break;
470                }
471                let name = HeaderName::from_bytes(header.name.as_bytes())
472                    .map_err(|e| std::io::Error::other(e.to_string()))?;
473                let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
474                let value_len = header.value.len();
475                // Safety: the header value is already validated by httparse
476                let value = unsafe {
477                    HeaderValue::from_maybe_shared_unchecked(
478                        head.slice(value_start..(value_start + value_len)),
479                    )
480                };
481                header_map.append(name, value);
482            }
483            *request.headers_mut() = header_map;
484
485            (request, body_tx)
486        };
487        Ok(Some((request, body_tx)))
488    }
489
490    #[inline]
491    async fn get_head(
492        &mut self,
493    ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
494    {
495        let mut request_line_read = false;
496        let mut bytes_read: usize = 0;
497        let mut whitespace_trimmed = None;
498        let mut just_started = true;
499        while bytes_read < self.options.max_header_size {
500            let old_bytes_read = bytes_read;
501            let begin_search = old_bytes_read.saturating_sub(3);
502
503            let have_to_read_buf = !just_started || self.read_buf.is_empty();
504            just_started = false;
505            if have_to_read_buf {
506                let n = self.fill_buf().await?;
507                if n == 0 {
508                    if whitespace_trimmed.is_none() {
509                        return Ok(None);
510                    }
511                    return Err(std::io::Error::new(
512                        std::io::ErrorKind::UnexpectedEof,
513                        "unexpected EOF",
514                    ));
515                }
516                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
517            } else {
518                bytes_read =
519                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
520            }
521
522            if whitespace_trimmed.is_none() {
523                whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
524                    .iter()
525                    .position(|b| !b.is_ascii_whitespace());
526            }
527
528            if let Some(whitespace_trimmed) = whitespace_trimmed {
529                // Validate first line (request line) before checking for header/body separator
530                if !request_line_read {
531                    let memchr = memchr3_iter(
532                        b' ',
533                        b'\r',
534                        b'\n',
535                        &self.read_buf[whitespace_trimmed..bytes_read],
536                    );
537                    let mut spaces = 0;
538                    for separator_index in memchr {
539                        if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
540                            if spaces >= 2 {
541                                return Err(std::io::Error::new(
542                                    std::io::ErrorKind::InvalidInput,
543                                    "bad request first line",
544                                ));
545                            }
546                            spaces += 1;
547                        } else if spaces == 2 {
548                            request_line_read = true;
549                            break;
550                        } else {
551                            return Err(std::io::Error::new(
552                                std::io::ErrorKind::InvalidInput,
553                                "bad request first line",
554                            ));
555                        }
556                    }
557                }
558
559                if request_line_read {
560                    let begin_search = begin_search.max(whitespace_trimmed);
561                    if let Some((separator_index, separator_len)) =
562                        search_header_body_separator(&self.read_buf[begin_search..bytes_read])
563                    {
564                        let to_parse_length =
565                            begin_search + separator_index + separator_len - whitespace_trimmed;
566                        self.read_buf.advance(whitespace_trimmed);
567                        let head = self.read_buf.split_to(to_parse_length);
568                        return Ok(Some((head.freeze(), &mut self.parsed_headers)));
569                    }
570                }
571            }
572        }
573        Err(std::io::Error::new(
574            std::io::ErrorKind::InvalidData,
575            "request too large",
576        ))
577    }
578
579    #[inline]
580    async fn write_response<Z, ZFut>(
581        &mut self,
582        mut response: Response<
583            impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
584        >,
585        version: Version,
586        write_trailers: bool,
587        zerocopy_fn: Option<Z>,
588    ) -> Result<(), std::io::Error>
589    where
590        Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
591        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
592    {
593        // Date header
594        if self.options.send_date_header {
595            response.headers_mut().insert(
596                header::DATE,
597                HeaderValue::from_str(self.get_date_header_value())
598                    .map_err(|e| std::io::Error::other(e.to_string()))?,
599            );
600        }
601
602        // If the body has a size hint, set the Content-Length header if it's not already set
603        if let Some(suggested_content_length) = response.body().size_hint().exact() {
604            let headers = response.headers_mut();
605            if !headers.contains_key(header::CONTENT_LENGTH) {
606                headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
607            }
608        }
609
610        let chunked = response
611            .headers()
612            .get(header::TRANSFER_ENCODING)
613            .map(|v| {
614                v.to_str().ok().is_some_and(|s| {
615                    s.split(',')
616                        .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
617                })
618            })
619            .unwrap_or_else(|| {
620                response
621                    .headers()
622                    .get(header::CONTENT_LENGTH)
623                    .and_then(|v| v.to_str().ok())
624                    .is_none_or(|s| s.parse::<u64>().is_err())
625            });
626
627        if chunked {
628            response.headers_mut().insert(
629                header::TRANSFER_ENCODING,
630                HeaderValue::from_static("chunked"),
631            );
632            while response
633                .headers_mut()
634                .remove(header::CONTENT_LENGTH)
635                .is_some()
636            {}
637        }
638
639        let (parts, mut body) = response.into_parts();
640
641        let mut head = Vec::with_capacity(30 + parts.headers.len() * 30); // Similar to Hyper's heuristic
642        if version == Version::HTTP_10 {
643            head.extend_from_slice(b"HTTP/1.0 ");
644        } else {
645            head.extend_from_slice(b"HTTP/1.1 ");
646        }
647        let status = parts.status;
648        head.extend_from_slice(status.as_str().as_bytes());
649        if let Some(canonical_reason) = status.canonical_reason() {
650            head.extend_from_slice(b" ");
651            head.extend_from_slice(canonical_reason.as_bytes());
652        }
653        head.extend_from_slice(b"\r\n");
654        for (name, value) in &parts.headers {
655            head.extend_from_slice(name.as_str().as_bytes());
656            head.extend_from_slice(b": ");
657            head.extend_from_slice(value.as_bytes());
658            head.extend_from_slice(b"\r\n");
659        }
660        head.extend_from_slice(b"\r\n");
661        unsafe {
662            self.write_buf.push(IoSlice::new(&head));
663        }
664
665        if !chunked {
666            if let Some(content_length) = parts
667                .headers
668                .get(header::CONTENT_LENGTH)
669                .and_then(|v| v.to_str().ok())
670                .and_then(|s| s.parse::<u64>().ok())
671            {
672                if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
673                    if let Some(mut zerocopy_fn) = zerocopy_fn {
674                        // Zerocopy
675                        unsafe {
676                            self.write_buf
677                                .flush(&mut self.io, self.options.enable_vectored_write)
678                                .await?
679                        };
680                        zerocopy_fn(
681                            zero_copy.handle,
682                            // Safety: the lifetime of the static reference is bound by the lifetime of the Io struct
683                            unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
684                            content_length,
685                        )
686                        .await?;
687                        self.io.flush().await?;
688                        let reclaimed_headers = parts.headers;
689                        self.cached_headers = Some(reclaimed_headers);
690                        return Ok(());
691                    }
692                }
693            }
694        }
695
696        let mut trailers_written = false;
697        while let Some(chunk) = body.frame().await {
698            let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
699            match chunk.into_data() {
700                Ok(data) => {
701                    if chunked {
702                        let mut data_len_buf = Vec::with_capacity(16);
703                        write_chunk_size(&mut data_len_buf, data.len());
704                        unsafe {
705                            self.write_buf.push(IoSlice::new(&data_len_buf));
706                            self.write_buf.push(IoSlice::new(&data));
707                            self.write_buf.push(IoSlice::new(b"\r\n"));
708                            self.write_buf
709                                .write(&mut self.io, self.options.enable_vectored_write)
710                                .await?;
711                        };
712                    } else {
713                        unsafe {
714                            self.write_buf.push(IoSlice::new(&data));
715                            self.write_buf
716                                .write(&mut self.io, self.options.enable_vectored_write)
717                                .await?;
718                        }
719                    }
720                }
721                Err(chunk) => {
722                    if let Ok(trailers) = chunk.into_trailers() {
723                        if write_trailers {
724                            unsafe {
725                                self.write_buf.push(IoSlice::new(b"0\r\n"));
726                                for (name, value) in &trailers {
727                                    self.write_buf.push(IoSlice::new(name.as_str().as_bytes()));
728                                    self.write_buf.push(IoSlice::new(b": "));
729                                    self.write_buf.push(IoSlice::new(value.as_bytes()));
730                                    self.write_buf.push(IoSlice::new(b"\r\n"));
731                                }
732                                self.write_buf.push(IoSlice::new(b"\r\n"));
733                                self.write_buf
734                                    .write(&mut self.io, self.options.enable_vectored_write)
735                                    .await?;
736                            }
737                            trailers_written = true;
738                        }
739                        break;
740                    }
741                }
742            };
743        }
744        if chunked && !trailers_written {
745            // Terminating chunk
746            unsafe {
747                self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
748            }
749        }
750        unsafe {
751            self.write_buf
752                .flush(&mut self.io, self.options.enable_vectored_write)
753                .await?;
754        }
755        self.io.flush().await?;
756        let reclaimed_headers = parts.headers;
757        self.cached_headers = Some(reclaimed_headers);
758
759        Ok(())
760    }
761
762    #[inline]
763    async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
764        if version == Version::HTTP_10 {
765            self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
766        } else {
767            self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
768        }
769        self.io.flush().await?;
770
771        Ok(())
772    }
773
774    #[inline]
775    async fn write_early_hints(
776        &mut self,
777        version: Version,
778        headers: http::HeaderMap,
779    ) -> Result<(), std::io::Error> {
780        let mut head = Vec::new();
781        if version == Version::HTTP_10 {
782            head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
783        } else {
784            head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
785        }
786        let mut current_header_name = None;
787        for (name, value) in headers {
788            if let Some(name) = name {
789                current_header_name = Some(name);
790            };
791            if let Some(current_header_name) = &current_header_name {
792                head.extend_from_slice(current_header_name.as_str().as_bytes());
793                if value.is_empty() {
794                    head.extend_from_slice(b":\r\n");
795                    continue;
796                }
797                head.extend_from_slice(b": ");
798                head.extend_from_slice(value.as_bytes());
799                head.extend_from_slice(b"\r\n");
800            }
801        }
802        head.extend_from_slice(b"\r\n");
803
804        self.io.write_all(&head).await?;
805
806        Ok(())
807    }
808
    /// Runs the HTTP/1.x request/response loop on this connection until it ends.
    ///
    /// Each iteration works as follows:
    ///
    /// 1. Read one request head, bounded by `options.header_read_timeout` when set.
    /// 2. Run `request_fn` concurrently with reading the request body and, when
    ///    `options.enable_early_hints` is set, with relaying `103 Early Hints`.
    /// 3. Write the response through `write_response`.
    ///
    /// The loop repeats while the connection is keep-alive.
    ///
    /// * `request_fn` - produces the response for each request.
    /// * `error_fn` - invoked at most once, when the request head cannot be read.
    ///   It is called as `error_fn(false)` for a read/parse error and `error_fn(true)`
    ///   when the header read timeout elapses. Its response is sent with
    ///   `Connection: close` (write errors ignored), and the method then returns an error.
    /// * `zerocopy_fn` - optional hook passed through to every `write_response` call.
    ///
    /// Returns `Ok(())` in any of these cases:
    ///
    /// - `read_request` yields no further request;
    /// - a non-keep-alive exchange completes;
    /// - the stream has been handed off to an HTTP upgrade;
    /// - the cancellation token is found cancelled after a response.
    ///
    /// # Errors
    ///
    /// Returns any of the following:
    ///
    /// - the `read_request` error;
    /// - an `ErrorKind::TimedOut` error when the header timeout elapses;
    /// - request-body read errors;
    /// - handler errors, stringified into `std::io::Error::other`;
    /// - errors from writing `100 Continue` or the response.
    #[inline]
    pub(crate) async fn handle_with_error_fn_and_zerocopy<
        F,
        Fut,
        ResB,
        ResBE,
        ResE,
        EF,
        EFut,
        EResB,
        EResBE,
        EResE,
        ZF,
        ZFut,
    >(
        mut self,
        request_fn: F,
        error_fn: EF,
        mut zerocopy_fn: Option<ZF>,
    ) -> Result<(), std::io::Error>
    where
        F: Fn(Request<Incoming>) -> Fut + 'static,
        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
        ResE: std::error::Error,
        ResBE: std::error::Error,
        EF: FnOnce(bool) -> EFut,
        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
        EResE: std::error::Error,
        EResBE: std::error::Error,
        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
    {
        // Recomputed from each request's headers; the loop exits once it is false.
        let mut keep_alive = true;

        while keep_alive {
            // Read the next request head. The timeout wraps every call, so it also bounds
            // how long an idle keep-alive connection may wait for its next request.
            // NOTE(review): such an idle timeout is reported through `error_fn(true)` and
            // its response is written to the idle client. Confirm this is intended.
            let (mut request, body_tx) = match if let Some(timeout) =
                self.options.header_read_timeout
            {
                vibeio::time::timeout(timeout, self.read_request()).await
            } else {
                Ok(self.read_request().await)
            } {
                Ok(Ok(Some(d))) => d,
                // No further request on this connection.
                Ok(Ok(None)) => {
                    return Ok(());
                }
                Ok(Err(e)) => {
                    // Parse error
                    if let Ok(mut response) = error_fn(false).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        // Best effort: the parse error is returned regardless.
                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(e);
                }
                Err(_) => {
                    // Timeout error
                    if let Ok(mut response) = error_fn(true).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        // Best effort: the timeout error is returned regardless.
                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "header read timeout",
                    ));
                }
            };

            // Connection header detection
            // Comma-separated tokens of `Connection`, trimmed and compared case-insensitively.
            let connection_header_split = request
                .headers()
                .get(header::CONNECTION)
                .and_then(|v| v.to_str().ok())
                .map(|v| v.split(",").map(|v| v.trim()));
            let is_connection_close = connection_header_split
                .clone()
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
            let is_connection_keep_alive = connection_header_split
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
            // `close` always wins. Otherwise HTTP/1.1 persists by default, while other
            // versions need an explicit `keep-alive` token.
            keep_alive = !is_connection_close
                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);

            let version = request.version();

            // 100 Continue
            // Answer `Expect: 100-continue` before the handler starts consuming the body.
            if self.options.send_continue_response {
                let is_100_continue = request
                    .headers()
                    .get(header::EXPECT)
                    .and_then(|v| v.to_str().ok())
                    .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
                if is_100_continue {
                    self.write_100_continue(version).await?;
                }
            }

            // 103 Early Hints
            // The handler receives an `EarlyHints` sender as a request extension. This future
            // writes each hint it receives and reports the write result back to the handler
            // through the per-hint reply sender. Write errors are not propagated here.
            let early_hints_fut = if self.options.enable_early_hints {
                let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
                let early_hints = EarlyHints::new(early_hints_tx);
                request.extensions_mut().insert(early_hints);
                // Safety: the function below is used only in futures_util::future::select
                // Also, another function that would borrow self would read data,
                // while this function would write data
                // NOTE(review): `mut_self` is a second `&mut Self` that stays live while
                // `read_body_fut` below also borrows `self` (presumably via `&mut self`), and
                // both futures are polled inside the same `select`. Two simultaneously live
                // `&mut` to one value is undefined behavior regardless of which fields each
                // side touches (and both presumably use `self.io`). Consider splitting the
                // read and write halves into separately borrowed parts.
                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
                Some(async {
                    let early_hints_rx = early_hints_rx;
                    while let Ok((headers, sender)) = early_hints_rx.recv().await {
                        sender
                            .into_inner()
                            .send(mut_self.write_early_hints(version, headers).await)
                            .ok();
                    }
                    // Never resolve, even after the channel closes; see the `select` below.
                    futures_util::future::pending::<Result<(), std::io::Error>>().await
                })
            } else {
                None
            };

            // Content-Length header
            // NOTE(review): a missing *or unparsable* Content-Length becomes 0. RFC 9112 §6.3
            // treats an invalid Content-Length as an unrecoverable framing error (400 + close).
            // Reading it as 0 presumably leaves the body bytes to be parsed as the next
            // request. Confirm that `read_request` rejects this first.
            let content_length = request
                .headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(0);
            // Chunked decoding is used if any Transfer-Encoding token is `chunked`. It takes
            // precedence over Content-Length in the body reader below.
            // NOTE(review): RFC 9112 §6.1 requires `chunked` to be the final coding, but a
            // value such as `chunked, gzip` is accepted here as chunked.
            let chunked = request
                .headers()
                .get(header::TRANSFER_ENCODING)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
                });
            // True when the request declares trailer fields via a non-empty `Trailer`
            // header. Forwarded to `read_chunked_body_fn`.
            let has_trailers = request
                .headers()
                .get(header::TRAILER)
                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
                .unwrap_or(false);
            // True when the client advertised `TE: trailers`. Passed to `write_response`
            // as its trailer flag.
            let write_trailers = request
                .headers()
                .get(header::TE)
                .and_then(|v| v.to_str().ok())
                .map(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
                })
                .unwrap_or(false);

            // Install HTTP upgrade
            // `upgraded` is a shared flag read after the handler returns (presumably set
            // through the `Upgrade` extension). `upgrade_tx` later delivers the stream.
            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
            let upgrade = Upgrade::new(upgrade_rx);
            let upgraded = upgrade.upgraded.clone();
            request.extensions_mut().insert(upgrade);

            // Get HTTP response
            // Three futures race: body reading, the handler, and the early-hints writer.
            // The early-hints future never resolves (it ends in `pending`), which is why
            // its `Right` arms below are `unreachable!()`.
            let mut response = {
                let read_body_fut = async {
                    if chunked {
                        self.read_chunked_body_fn(&body_tx, has_trailers).await
                    } else {
                        self.read_body_fn(&body_tx, content_length).await
                    }
                };
                let read_body_fut_pin = std::pin::pin!(read_body_fut);
                let request_fut = request_fn(request);
                let request_fut_pin = std::pin::pin!(request_fut);
                let early_hints_fut: Pin<
                    Box<dyn std::future::Future<Output = Result<(), std::io::Error>>>,
                > = if let Some(early_hints) = early_hints_fut {
                    Box::pin(early_hints)
                } else {
                    Box::pin(futures_util::future::pending::<Result<(), std::io::Error>>())
                };

                let select_read_body_either =
                    futures_util::future::select(request_fut_pin, early_hints_fut);
                let select_either =
                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;

                let (response, body_fut) = match select_either {
                    // The body was fully read first: surface read errors, then wait for
                    // the handler.
                    futures_util::future::Either::Left((result, request_fut)) => {
                        result?;
                        (
                            match request_fut.await {
                                futures_util::future::Either::Left((response, _)) => response,
                                futures_util::future::Either::Right((_, _)) => unreachable!(),
                            },
                            None,
                        )
                    }
                    // The handler answered first: keep the unfinished body reader so the
                    // rest of the body is drained below.
                    futures_util::future::Either::Right((response, read_body_fut)) => (
                        match response {
                            futures_util::future::Either::Left((response, _)) => response,
                            futures_util::future::Either::Right((_, _)) => unreachable!(),
                        },
                        Some(read_body_fut),
                    ),
                };

                // Drain away remaining body
                if let Some(body_fut) = body_fut {
                    body_fut.await?;
                }

                response.map_err(|e| std::io::Error::other(e.to_string()))?
            };

            // Reflect the connection outcome in the response's `Connection` header:
            // - `upgrade` for upgrades;
            // - `keep-alive` when persistent (always for HTTP/1.0, otherwise only replacing
            //   a header the handler already set);
            // - `close` when not persistent (always for HTTP/1.1, otherwise only replacing
            //   a header the handler already set).
            // NOTE(review): on a keep-alive connection, a handler-set `Connection` header
            // (e.g. `close`) is overwritten with `keep-alive` and the loop continues, so a
            // handler cannot force the connection closed. Confirm this is intended.
            let mut was_upgraded = false;
            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
                was_upgraded = true;
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
            } else if keep_alive {
                if version == Version::HTTP_10
                    || response.headers().contains_key(header::CONNECTION)
                {
                    response
                        .headers_mut()
                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
                }
            } else if version == Version::HTTP_11
                || response.headers().contains_key(header::CONNECTION)
            {
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
            }

            // Write response to IO
            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
                .await?;

            if was_upgraded {
                // HTTP upgrade
                // Hand the raw stream, plus any bytes already read into `read_buf`, to the
                // upgrade receiver. The send result is ignored.
                let frozen_buf = self.read_buf.freeze();
                let _ = upgrade_tx.send(Upgraded::new(
                    self.io,
                    if frozen_buf.is_empty() {
                        None
                    } else {
                        Some(frozen_buf)
                    },
                ));
                return Ok(());
            }

            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
                // Graceful shutdown requested, break out of loop
                break;
            }
        }
        Ok(())
    }
1075}
1076
1077impl<Io> HttpProtocol for Http1<Io>
1078where
1079    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1080{
1081    #[inline]
1082    fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1083        self,
1084        request_fn: F,
1085        error_fn: EF,
1086    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1087    where
1088        F: Fn(Request<Incoming>) -> Fut + 'static,
1089        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1090        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1091        ResE: std::error::Error,
1092        ResBE: std::error::Error,
1093        EF: FnOnce(bool) -> EFut,
1094        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1095        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
1096        EResE: std::error::Error,
1097        EResBE: std::error::Error,
1098    {
1099        #[allow(clippy::type_complexity)]
1100        let no_zerocopy: Option<
1101            Box<
1102                dyn FnMut(
1103                    RawHandle,
1104                    &Io,
1105                    u64,
1106                ) -> Box<
1107                    dyn std::future::Future<Output = Result<(), std::io::Error>>
1108                        + Unpin
1109                        + Send
1110                        + Sync,
1111                >,
1112            >,
1113        > = None;
1114        self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1115    }
1116
1117    #[inline]
1118    fn handle<F, Fut, ResB, ResBE, ResE>(
1119        self,
1120        request_fn: F,
1121    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1122    where
1123        F: Fn(Request<Incoming>) -> Fut + 'static,
1124        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1125        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1126        ResE: std::error::Error,
1127        ResBE: std::error::Error,
1128    {
1129        self.handle_with_error_fn(request_fn, |is_timeout| async move {
1130            let mut response = Response::builder();
1131            if is_timeout {
1132                response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1133            } else {
1134                response = response.status(http::StatusCode::BAD_REQUEST);
1135            }
1136            response.body(Empty::new())
1137        })
1138    }
1139}
1140
1141struct Http1Body {
1142    #[allow(clippy::type_complexity)]
1143    inner: Pin<Box<Receiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
1144}
1145
1146impl Body for Http1Body {
1147    type Data = bytes::Bytes;
1148    type Error = std::io::Error;
1149
1150    #[inline]
1151    fn poll_frame(
1152        mut self: Pin<&mut Self>,
1153        cx: &mut Context<'_>,
1154    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1155        match self.inner.as_mut().poll_next(cx) {
1156            Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1157            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
1158            Poll::Ready(None) => Poll::Ready(None),
1159            Poll::Pending => Poll::Pending,
1160        }
1161    }
1162}
1163
/// Locates the blank line that terminates an HTTP/1.x header section.
///
/// Scans `slice` from the start and reports the first position where either a CRLF
/// blank line (`\r\n\r\n`) or a bare-LF blank line (`\n\n`) begins.
///
/// Returns `Some((start, len))`, where `start` is the index of the separator's first
/// byte and `len` is its length (4 for `\r\n\r\n`, 2 for `\n\n`). Returns `None` if
/// neither sequence occurs.
///
/// A CRLF followed by a bare LF (`\r\n\n`) is reported as the `\n\n` starting at its
/// LF. An LF followed by a CRLF (`\n\r\n`) is not recognized.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    // No separator fits in fewer than two bytes.
    if slice.len() < 2 {
        return None;
    }
    (0..slice.len()).find_map(|start| {
        let tail = &slice[start..];
        if tail.starts_with(b"\r\n\r\n") {
            Some((start, 4))
        } else if tail.starts_with(b"\n\n") {
            Some((start, 2))
        } else {
            None
        }
    })
}
1183
1184/// Writes the chunk size to the given buffer in hexadecimal format, followed by `\r\n`.
1185#[inline]
1186fn write_chunk_size(dst: &mut Vec<u8>, len: usize) {
1187    let mut buf = [0u8; 18];
1188    let mut n = len;
1189    let mut pos = buf.len();
1190    loop {
1191        pos -= 1;
1192        buf[pos] = HEX_DIGITS[n & 0xF];
1193        n >>= 4;
1194        if n == 0 {
1195            break;
1196        }
1197    }
1198    dst.extend_from_slice(&buf[pos..]);
1199    dst.extend_from_slice(b"\r\n");
1200}