Skip to main content

vibeio_http/h1/
mod.rs

1mod options;
2mod tests;
3mod upgrade;
4mod writebuf;
5mod zerocopy;
6
7pub use options::*;
8pub use upgrade::*;
9pub use zerocopy::*;
10
/// Platform-native handle type passed to zero-copy send routines
/// (a raw file descriptor on Unix-family targets).
#[cfg(target_family = "unix")]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Platform-native handle type passed to zero-copy send routines
/// (a raw `HANDLE` on Windows-family targets).
#[cfg(target_family = "windows")]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
15
16use std::{
17    io::IoSlice,
18    mem::MaybeUninit,
19    pin::Pin,
20    str::FromStr,
21    task::{Context, Poll},
22    time::UNIX_EPOCH,
23};
24
25use async_channel::Receiver;
26use bytes::{Buf, Bytes, BytesMut};
27use futures_util::stream::Stream;
28use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
29use http_body::Body;
30use http_body_util::{BodyExt, Empty};
31use memchr::{memchr3_iter, memmem};
32use tokio::io::{AsyncReadExt, AsyncWriteExt};
33use tokio_util::sync::CancellationToken;
34
35use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming};
36
/// Uppercase hexadecimal digit table (`HEX_DIGITS[n]` is the ASCII digit for
/// nibble `n`); presumably used when encoding chunk sizes — confirm at call site.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// When at least this many bytes of response body are queued in the write
/// buffer, they are written out before the next body frame is polled.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16 * 1024;
39
40/// An HTTP/1.x connection handler.
41///
42/// `Http1` wraps an async I/O stream (`Io`) and provides a complete
43/// HTTP/1.0 and HTTP/1.1 server implementation, including:
44///
45/// - Request head parsing (via [`httparse`])
46/// - Streaming request bodies (content-length and chunked transfer-encoding)
47/// - Chunked response encoding and trailer support
48/// - `100 Continue` and `103 Early Hints` interim responses
49/// - HTTP connection upgrades (e.g. WebSocket)
50/// - Optional zero-copy response sending on Linux (see `Http1::zerocopy`)
51/// - Keep-alive connection reuse
52/// - Graceful shutdown via a [`CancellationToken`]
53///
54/// # Construction
55///
56/// ```rust,ignore
57/// let http1 = Http1::new(tcp_stream, Http1Options::default());
58/// ```
59///
60/// # Serving requests
61///
62/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
63/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
64/// connection to completion:
65///
66/// ```rust,ignore
67/// http1.handle(|req| async move {
68///     Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Hello!"))))
69/// }).await?;
70/// ```
71pub struct Http1<Io> {
72    io: Io,
73    options: options::Http1Options,
74    cancel_token: Option<CancellationToken>,
75    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
76    date_header_value_cached: Option<(String, std::time::SystemTime)>,
77    cached_headers: Option<HeaderMap>,
78    read_buf: BytesMut,
79    response_head_buf: Vec<u8>,
80    write_buf: WriteBuf,
81}
82
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Wraps this handler in an [`Http1Zerocopy`], which sends response
    /// bodies through emulated sendfile (Linux only) so the data does not pass
    /// through user space.
    ///
    /// Zero-copy is used only for responses that carry a `ZerocopyResponse`
    /// extension (installed via [`install_zerocopy`]) holding the source file
    /// descriptor. Any other response is written the usual way.
    ///
    /// Compiled only on Linux with the `h1-zerocopy` feature enabled, and only
    /// for `Io` types implementing [`vibeio::io::AsInnerRawHandle`].
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
107
108impl<Io> Http1<Io>
109where
110    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
111{
112    /// Creates a new `Http1` connection handler wrapping the given I/O stream.
113    ///
114    /// The `options` value controls limits, timeouts, and optional features;
115    /// see [`Http1Options`] for details.
116    ///
117    /// # Example
118    ///
119    /// ```rust,ignore
120    /// let http1 = Http1::new(tcp_stream, Http1Options::default());
121    /// ```
122    #[inline]
123    pub fn new(io: Io, options: options::Http1Options) -> Self {
124        // Safety: u8 is a primitive type, so we can safely assume initialization
125        let read_buf = BytesMut::with_capacity(options.max_header_size);
126        let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
127            Box::new_uninit_slice(options.max_header_count);
128        Self {
129            io,
130            options,
131            cancel_token: None,
132            parsed_headers,
133            date_header_value_cached: None,
134            cached_headers: None,
135            read_buf,
136            response_head_buf: Vec::with_capacity(1024),
137            write_buf: WriteBuf::new(),
138        }
139    }
140
141    #[inline]
142    fn get_date_header_value(&mut self) -> &str {
143        let now = std::time::SystemTime::now();
144        if self.date_header_value_cached.as_ref().is_none_or(|v| {
145            v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
146                != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
147        }) {
148            let value = httpdate::fmt_http_date(now).to_string();
149            self.date_header_value_cached = Some((value, now));
150        }
151        self.date_header_value_cached
152            .as_ref()
153            .map(|v| v.0.as_str())
154            .unwrap_or("")
155    }
156
157    /// Attaches a [`CancellationToken`] for graceful shutdown.
158    ///
159    /// After the current in-flight request has been fully handled and its
160    /// response written, the connection loop checks whether the token has been
161    /// cancelled. If it has, the loop exits cleanly instead of waiting for the
162    /// next request.
163    ///
164    /// This allows the server to drain active connections without abruptly
165    /// closing them mid-response.
166    #[inline]
167    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
168        self.cancel_token = Some(token);
169        self
170    }
171
172    #[inline]
173    async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
174        if self.read_buf.remaining() < 1024 {
175            self.read_buf.reserve(1024);
176        }
177        let spare_capacity = self.read_buf.spare_capacity_mut();
178        // Safety: The buffer is are read only after the request head has been parsed
179        let n = self
180            .io
181            .read(unsafe {
182                &mut *std::ptr::slice_from_raw_parts_mut(
183                    spare_capacity.as_mut_ptr() as *mut u8,
184                    spare_capacity.len(),
185                )
186            })
187            .await?;
188        if n == 0 {
189            return Ok(0);
190        }
191        unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
192        Ok(n)
193    }
194
195    #[inline]
196    async fn read_body_fn(
197        &mut self,
198        body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
199        content_length: u64,
200    ) -> Result<(), std::io::Error> {
201        let mut remaining = content_length;
202        let mut just_started = true;
203        while remaining > 0 {
204            let have_to_read_buf = !just_started || self.read_buf.is_empty();
205            just_started = false;
206            if have_to_read_buf {
207                let n = self.fill_buf().await?;
208                if n == 0 {
209                    break;
210                }
211            }
212            let chunk = self
213                .read_buf
214                .split_to(
215                    self.read_buf
216                        .len()
217                        .min(remaining.min(usize::MAX as u64) as usize),
218                )
219                .freeze();
220            remaining -= chunk.len() as u64;
221
222            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
223        }
224        body_tx.close(); // Close the body_tx channel to signal EOF
225        Ok(())
226    }
227
228    #[inline]
229    async fn read_body_chunk(
230        &mut self,
231        would_have_trailers: bool,
232    ) -> Result<bytes::Bytes, std::io::Error> {
233        let len = {
234            // Safety: u8 is a primitive type, so we can safely assume initialization
235            let mut len_buf_pos: usize = 0;
236            let mut just_started = true;
237            loop {
238                if len_buf_pos >= 48 {
239                    return Err(std::io::Error::new(
240                        std::io::ErrorKind::InvalidData,
241                        "chunk length buffer overflow",
242                    ));
243                }
244
245                let begin_search = len_buf_pos.saturating_sub(1);
246
247                let have_to_read_buf = !just_started || self.read_buf.is_empty();
248                just_started = false;
249                if have_to_read_buf {
250                    let n = self.fill_buf().await?;
251                    if n == 0 {
252                        return Err(std::io::Error::new(
253                            std::io::ErrorKind::UnexpectedEof,
254                            "unexpected EOF",
255                        ));
256                    }
257                    len_buf_pos += n;
258                } else {
259                    len_buf_pos += self.read_buf.len();
260                }
261
262                if let Some(pos) =
263                    memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
264                {
265                    let numbers =
266                        std::str::from_utf8(&self.read_buf[begin_search..begin_search + pos])
267                            .map_err(|_| {
268                                std::io::Error::new(
269                                    std::io::ErrorKind::InvalidData,
270                                    "invalid chunk length",
271                                )
272                            })?;
273                    let len = usize::from_str_radix(numbers, 16).map_err(|_| {
274                        std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
275                    })?;
276                    // Ignore the trailing CRLF
277                    self.read_buf.advance(begin_search + pos + 2);
278                    break len;
279                }
280            }
281        };
282        // Safety: u8 is a primitive type, so we can safely assume initialization
283        let mut read = 0;
284        if len == 0 && would_have_trailers {
285            return Ok(bytes::Bytes::new()); // Empty terminating chunk
286        }
287        let mut just_started = true;
288        // + 2, because we need to read the trailing CRLF
289        while read < len + 2 {
290            let have_to_read_buf = !just_started || self.read_buf.is_empty();
291            just_started = false;
292            if have_to_read_buf {
293                let n = self.fill_buf().await?;
294                if n == 0 {
295                    return Err(std::io::Error::new(
296                        std::io::ErrorKind::UnexpectedEof,
297                        "unexpected EOF",
298                    ));
299                }
300                read += n;
301            } else {
302                read += self.read_buf.len();
303            }
304        }
305        let chunk = self.read_buf.split_to(len).freeze();
306        self.read_buf.advance(2); // Ignore the trailing CRLF
307        Ok(chunk)
308    }
309
310    #[inline]
311    async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
312        // Safety: u8 is a primitive type, so we can safely assume initialization
313        let mut bytes_read: usize = 0;
314        let mut just_started = true;
315        while bytes_read < self.options.max_header_size {
316            let old_bytes_read = bytes_read;
317            let begin_search = old_bytes_read.saturating_sub(3);
318
319            let have_to_read_buf = !just_started || self.read_buf.is_empty();
320            just_started = false;
321            if have_to_read_buf {
322                let n = self.fill_buf().await?;
323                if n == 0 {
324                    return Err(std::io::Error::new(
325                        std::io::ErrorKind::UnexpectedEof,
326                        "unexpected EOF",
327                    ));
328                }
329                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
330            } else {
331                bytes_read =
332                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
333            }
334
335            if bytes_read > 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
336                // No trailers, return None
337                return Ok(None);
338            }
339
340            if let Some(separator_index) =
341                memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
342            {
343                let to_parse_length = begin_search + separator_index + 4;
344                let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
345
346                // Parse trailers using `httparse` crate's header parsing
347                let mut httparse_trailers =
348                    vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
349                let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
350                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
351                if let httparse::Status::Complete((_, trailers)) = status {
352                    let mut trailers_constructed = HeaderMap::new();
353                    for header in trailers {
354                        if header == &httparse::EMPTY_HEADER {
355                            // No more headers...
356                            break;
357                        }
358                        let name = HeaderName::from_bytes(header.name.as_bytes())
359                            .map_err(|e| std::io::Error::other(e.to_string()))?;
360                        let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
361                        let value_len = header.value.len();
362                        // Safety: the header value is already validated by httparse
363                        let value = unsafe {
364                            HeaderValue::from_maybe_shared_unchecked(
365                                buf_ro.slice(value_start..(value_start + value_len)),
366                            )
367                        };
368                        trailers_constructed.append(name, value);
369                    }
370
371                    return Ok(Some(trailers_constructed));
372                } else {
373                    return Err(std::io::Error::new(
374                        std::io::ErrorKind::InvalidInput,
375                        "trailer headers incomplete",
376                    ));
377                }
378            }
379        }
380        Err(std::io::Error::new(
381            std::io::ErrorKind::InvalidData,
382            "request too large",
383        ))
384    }
385
386    #[inline]
387    async fn read_chunked_body_fn(
388        &mut self,
389        body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
390        would_have_trailers: bool,
391    ) -> Result<(), std::io::Error> {
392        loop {
393            let chunk = self.read_body_chunk(would_have_trailers).await?;
394            if chunk.is_empty() {
395                break;
396            }
397
398            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
399        }
400        if would_have_trailers {
401            // Trailers
402            let trailers = self.read_trailers().await?;
403            if let Some(trailers) = trailers {
404                let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
405            }
406        }
407        body_tx.close(); // Close the body_tx channel to signal EOF
408        Ok(())
409    }
410
411    #[inline]
412    async fn read_request(
413        &mut self,
414    ) -> Result<
415        Option<(
416            Request<Incoming>,
417            async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
418        )>,
419        std::io::Error,
420    > {
421        // Parse HTTP request using httparse
422        let (request, body_tx) = {
423            let Some((head, headers)) = self.get_head().await? else {
424                return Ok(None);
425            };
426            // Safety: The headers are read only after the request head has been parsed
427            let headers = unsafe {
428                std::mem::transmute::<
429                    &mut [MaybeUninit<httparse::Header<'static>>],
430                    &mut [MaybeUninit<httparse::Header<'_>>],
431                >(headers)
432            };
433            let mut req = httparse::Request::new(&mut []);
434            let status = req
435                .parse_with_uninit_headers(&head, headers)
436                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
437            if status.is_partial() {
438                return Err(std::io::Error::new(
439                    std::io::ErrorKind::InvalidData,
440                    "partial request head",
441                ));
442            }
443
444            // Convert httparse HTTP request to `http` one
445            let (body_tx, body_rx) = async_channel::bounded(2);
446            let request_body = Http1Body {
447                inner: Box::pin(body_rx),
448            };
449            let mut request = Request::new(Incoming::new(request_body));
450            match req.version {
451                Some(0) => *request.version_mut() = http::Version::HTTP_10,
452                Some(1) => *request.version_mut() = http::Version::HTTP_11,
453                _ => *request.version_mut() = http::Version::HTTP_11,
454            };
455            if let Some(method) = req.method {
456                *request.method_mut() = Method::from_bytes(method.as_bytes())
457                    .map_err(|e| std::io::Error::other(e.to_string()))?;
458            }
459            if let Some(path) = req.path {
460                *request.uri_mut() =
461                    Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
462            }
463            let mut header_map = self.cached_headers.take().unwrap_or_default();
464            header_map.clear();
465            let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
466            if additional_capacity > 0 {
467                header_map.reserve(additional_capacity);
468            }
469            for header in req.headers {
470                if header == &httparse::EMPTY_HEADER {
471                    // No more headers...
472                    break;
473                }
474                let name = HeaderName::from_bytes(header.name.as_bytes())
475                    .map_err(|e| std::io::Error::other(e.to_string()))?;
476                let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
477                let value_len = header.value.len();
478                // Safety: the header value is already validated by httparse
479                let value = unsafe {
480                    HeaderValue::from_maybe_shared_unchecked(
481                        head.slice(value_start..(value_start + value_len)),
482                    )
483                };
484                header_map.append(name, value);
485            }
486            *request.headers_mut() = header_map;
487
488            (request, body_tx)
489        };
490        Ok(Some((request, body_tx)))
491    }
492
493    #[inline]
494    async fn get_head(
495        &mut self,
496    ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
497    {
498        let mut request_line_read = false;
499        let mut bytes_read: usize = 0;
500        let mut whitespace_trimmed = None;
501        let mut just_started = true;
502        while bytes_read < self.options.max_header_size {
503            let old_bytes_read = bytes_read;
504            let begin_search = old_bytes_read.saturating_sub(3);
505
506            let have_to_read_buf = !just_started || self.read_buf.is_empty();
507            just_started = false;
508            if have_to_read_buf {
509                let n = self.fill_buf().await?;
510                if n == 0 {
511                    if whitespace_trimmed.is_none() {
512                        return Ok(None);
513                    }
514                    return Err(std::io::Error::new(
515                        std::io::ErrorKind::UnexpectedEof,
516                        "unexpected EOF",
517                    ));
518                }
519                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
520            } else {
521                bytes_read =
522                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
523            }
524
525            if whitespace_trimmed.is_none() {
526                whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
527                    .iter()
528                    .position(|b| !b.is_ascii_whitespace());
529            }
530
531            if let Some(whitespace_trimmed) = whitespace_trimmed {
532                // Validate first line (request line) before checking for header/body separator
533                if !request_line_read {
534                    let memchr = memchr3_iter(
535                        b' ',
536                        b'\r',
537                        b'\n',
538                        &self.read_buf[whitespace_trimmed..bytes_read],
539                    );
540                    let mut spaces = 0;
541                    for separator_index in memchr {
542                        if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
543                            if spaces >= 2 {
544                                return Err(std::io::Error::new(
545                                    std::io::ErrorKind::InvalidInput,
546                                    "bad request first line",
547                                ));
548                            }
549                            spaces += 1;
550                        } else if spaces == 2 {
551                            request_line_read = true;
552                            break;
553                        } else {
554                            return Err(std::io::Error::new(
555                                std::io::ErrorKind::InvalidInput,
556                                "bad request first line",
557                            ));
558                        }
559                    }
560                }
561
562                if request_line_read {
563                    let begin_search = begin_search.max(whitespace_trimmed);
564                    if let Some((separator_index, separator_len)) =
565                        search_header_body_separator(&self.read_buf[begin_search..bytes_read])
566                    {
567                        let to_parse_length =
568                            begin_search + separator_index + separator_len - whitespace_trimmed;
569                        self.read_buf.advance(whitespace_trimmed);
570                        let head = self.read_buf.split_to(to_parse_length);
571                        return Ok(Some((head.freeze(), &mut self.parsed_headers)));
572                    }
573                }
574            }
575        }
576        Err(std::io::Error::new(
577            std::io::ErrorKind::InvalidData,
578            "request too large",
579        ))
580    }
581
582    #[inline]
583    async fn write_response<Z, ZFut>(
584        &mut self,
585        mut response: Response<
586            impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
587        >,
588        version: Version,
589        write_trailers: bool,
590        zerocopy_fn: Option<Z>,
591    ) -> Result<(), std::io::Error>
592    where
593        Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
594        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
595    {
596        // Date header
597        if self.options.send_date_header {
598            response.headers_mut().insert(
599                header::DATE,
600                HeaderValue::from_str(self.get_date_header_value())
601                    .map_err(|e| std::io::Error::other(e.to_string()))?,
602            );
603        }
604
605        // If the body has a size hint, set the Content-Length header if it's not already set
606        if let Some(suggested_content_length) = response.body().size_hint().exact() {
607            let headers = response.headers_mut();
608            if !headers.contains_key(header::CONTENT_LENGTH) {
609                headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
610            }
611        }
612
613        let chunked = response
614            .headers()
615            .get(header::TRANSFER_ENCODING)
616            .map(|v| {
617                v.to_str().ok().is_some_and(|s| {
618                    s.split(',')
619                        .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
620                })
621            })
622            .unwrap_or_else(|| {
623                response
624                    .headers()
625                    .get(header::CONTENT_LENGTH)
626                    .and_then(|v| v.to_str().ok())
627                    .is_none_or(|s| s.parse::<u64>().is_err())
628            });
629
630        if chunked {
631            response.headers_mut().insert(
632                header::TRANSFER_ENCODING,
633                HeaderValue::from_static("chunked"),
634            );
635            while response
636                .headers_mut()
637                .remove(header::CONTENT_LENGTH)
638                .is_some()
639            {}
640        }
641
642        let (parts, mut body) = response.into_parts();
643
644        self.response_head_buf.clear();
645        let estimated_head_len = 30 + parts.headers.len() * 30; // Similar to Hyper's heuristic
646        if self.response_head_buf.capacity() < estimated_head_len {
647            self.response_head_buf
648                .reserve(estimated_head_len - self.response_head_buf.capacity());
649        }
650        let head = &mut self.response_head_buf;
651        if version == Version::HTTP_10 {
652            head.extend_from_slice(b"HTTP/1.0 ");
653        } else {
654            head.extend_from_slice(b"HTTP/1.1 ");
655        }
656        let status = parts.status;
657        head.extend_from_slice(status.as_str().as_bytes());
658        if let Some(canonical_reason) = status.canonical_reason() {
659            head.extend_from_slice(b" ");
660            head.extend_from_slice(canonical_reason.as_bytes());
661        }
662        head.extend_from_slice(b"\r\n");
663        for (name, value) in &parts.headers {
664            head.extend_from_slice(name.as_str().as_bytes());
665            head.extend_from_slice(b": ");
666            head.extend_from_slice(value.as_bytes());
667            head.extend_from_slice(b"\r\n");
668        }
669        head.extend_from_slice(b"\r\n");
670        unsafe {
671            self.write_buf.push(IoSlice::new(head));
672        }
673
674        if !chunked {
675            if let Some(content_length) = parts
676                .headers
677                .get(header::CONTENT_LENGTH)
678                .and_then(|v| v.to_str().ok())
679                .and_then(|s| s.parse::<u64>().ok())
680            {
681                if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
682                    if let Some(mut zerocopy_fn) = zerocopy_fn {
683                        // Zerocopy
684                        unsafe {
685                            self.write_buf
686                                .flush(&mut self.io, self.options.enable_vectored_write)
687                                .await?
688                        };
689                        zerocopy_fn(
690                            zero_copy.handle,
691                            // Safety: the lifetime of the static reference is bound by the lifetime of the Io struct
692                            unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
693                            content_length,
694                        )
695                        .await?;
696                        self.io.flush().await?;
697                        let reclaimed_headers = parts.headers;
698                        self.cached_headers = Some(reclaimed_headers);
699                        return Ok(());
700                    }
701                }
702            }
703        }
704
705        let mut trailers_written = false;
706        while let Some(chunk) = body.frame().await {
707            let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
708            match chunk.into_data() {
709                Ok(data) => {
710                    if chunked {
711                        let mut chunk_size_buf = [0u8; 18];
712                        let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
713                        self.write_buf.push_copy(chunk_size);
714                        self.write_buf.push_bytes(data);
715                        unsafe {
716                            self.write_buf.push(IoSlice::new(b"\r\n"));
717                        }
718                    } else {
719                        self.write_buf.push_bytes(data);
720                    }
721                    while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
722                        unsafe {
723                            self.write_buf
724                                .write(&mut self.io, self.options.enable_vectored_write)
725                                .await?;
726                        }
727                    }
728                }
729                Err(chunk) => {
730                    if let Ok(trailers) = chunk.into_trailers() {
731                        if write_trailers {
732                            unsafe {
733                                self.write_buf.push(IoSlice::new(b"0\r\n"));
734                                for (name, value) in &trailers {
735                                    self.write_buf.push_copy(name.as_str().as_bytes());
736                                    self.write_buf.push(IoSlice::new(b": "));
737                                    self.write_buf.push_copy(value.as_bytes());
738                                    self.write_buf.push(IoSlice::new(b"\r\n"));
739                                }
740                                self.write_buf.push(IoSlice::new(b"\r\n"));
741                            }
742                            trailers_written = true;
743                        }
744                        break;
745                    }
746                }
747            };
748        }
749        if chunked && !trailers_written {
750            // Terminating chunk
751            unsafe {
752                self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
753            }
754        }
755        unsafe {
756            self.write_buf
757                .flush(&mut self.io, self.options.enable_vectored_write)
758                .await?;
759        }
760        self.io.flush().await?;
761        let reclaimed_headers = parts.headers;
762        self.cached_headers = Some(reclaimed_headers);
763
764        Ok(())
765    }
766
767    #[inline]
768    async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
769        if version == Version::HTTP_10 {
770            self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
771        } else {
772            self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
773        }
774        self.io.flush().await?;
775
776        Ok(())
777    }
778
779    #[inline]
780    async fn write_early_hints(
781        &mut self,
782        version: Version,
783        headers: http::HeaderMap,
784    ) -> Result<(), std::io::Error> {
785        let mut head = Vec::new();
786        if version == Version::HTTP_10 {
787            head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
788        } else {
789            head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
790        }
791        let mut current_header_name = None;
792        for (name, value) in headers {
793            if let Some(name) = name {
794                current_header_name = Some(name);
795            };
796            if let Some(current_header_name) = &current_header_name {
797                head.extend_from_slice(current_header_name.as_str().as_bytes());
798                if value.is_empty() {
799                    head.extend_from_slice(b":\r\n");
800                    continue;
801                }
802                head.extend_from_slice(b": ");
803                head.extend_from_slice(value.as_bytes());
804                head.extend_from_slice(b"\r\n");
805            }
806        }
807        head.extend_from_slice(b"\r\n");
808
809        self.io.write_all(&head).await?;
810
811        Ok(())
812    }
813
814    #[inline]
815    pub(crate) async fn handle_with_error_fn_and_zerocopy<
816        F,
817        Fut,
818        ResB,
819        ResBE,
820        ResE,
821        EF,
822        EFut,
823        EResB,
824        EResBE,
825        EResE,
826        ZF,
827        ZFut,
828    >(
829        mut self,
830        request_fn: F,
831        error_fn: EF,
832        mut zerocopy_fn: Option<ZF>,
833    ) -> Result<(), std::io::Error>
834    where
835        F: Fn(Request<Incoming>) -> Fut + 'static,
836        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
837        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
838        ResE: std::error::Error,
839        ResBE: std::error::Error,
840        EF: FnOnce(bool) -> EFut,
841        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
842        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
843        EResE: std::error::Error,
844        EResBE: std::error::Error,
845        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
846        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
847    {
848        let mut keep_alive = true;
849
850        while keep_alive {
851            let (mut request, body_tx) = match if let Some(timeout) =
852                self.options.header_read_timeout
853            {
854                vibeio::time::timeout(timeout, self.read_request()).await
855            } else {
856                Ok(self.read_request().await)
857            } {
858                Ok(Ok(Some(d))) => d,
859                Ok(Ok(None)) => {
860                    return Ok(());
861                }
862                Ok(Err(e)) => {
863                    // Parse error
864                    if let Ok(mut response) = error_fn(false).await {
865                        response
866                            .headers_mut()
867                            .insert(header::CONNECTION, HeaderValue::from_static("close"));
868
869                        let _ = self
870                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
871                            .await;
872                    }
873                    return Err(e);
874                }
875                Err(_) => {
876                    // Timeout error
877                    if let Ok(mut response) = error_fn(true).await {
878                        response
879                            .headers_mut()
880                            .insert(header::CONNECTION, HeaderValue::from_static("close"));
881
882                        let _ = self
883                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
884                            .await;
885                    }
886                    return Err(std::io::Error::new(
887                        std::io::ErrorKind::TimedOut,
888                        "header read timeout",
889                    ));
890                }
891            };
892
893            // Connection header detection
894            let connection_header_split = request
895                .headers()
896                .get(header::CONNECTION)
897                .and_then(|v| v.to_str().ok())
898                .map(|v| v.split(",").map(|v| v.trim()));
899            let is_connection_close = connection_header_split
900                .clone()
901                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
902            let is_connection_keep_alive = connection_header_split
903                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
904            keep_alive = !is_connection_close
905                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
906
907            let version = request.version();
908
909            // 100 Continue
910            if self.options.send_continue_response {
911                let is_100_continue = request
912                    .headers()
913                    .get(header::EXPECT)
914                    .and_then(|v| v.to_str().ok())
915                    .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
916                if is_100_continue {
917                    self.write_100_continue(version).await?;
918                }
919            }
920
921            // 103 Early Hints
922            let early_hints_fut = if self.options.enable_early_hints {
923                let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
924                let early_hints = EarlyHints::new(early_hints_tx);
925                request.extensions_mut().insert(early_hints);
926                // Safety: the function below is used only in futures_util::future::select
927                // Also, another function that would borrow self would read data,
928                // while this function would write data
929                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
930                Some(async {
931                    let early_hints_rx = early_hints_rx;
932                    while let Ok((headers, sender)) = early_hints_rx.recv().await {
933                        sender
934                            .into_inner()
935                            .send(mut_self.write_early_hints(version, headers).await)
936                            .ok();
937                    }
938                    futures_util::future::pending::<Result<(), std::io::Error>>().await
939                })
940            } else {
941                None
942            };
943
944            // Content-Length header
945            let content_length = request
946                .headers()
947                .get(header::CONTENT_LENGTH)
948                .and_then(|v| v.to_str().ok())
949                .and_then(|v| v.parse::<u64>().ok())
950                .unwrap_or(0);
951            let chunked = request
952                .headers()
953                .get(header::TRANSFER_ENCODING)
954                .and_then(|v| v.to_str().ok())
955                .is_some_and(|v| {
956                    v.split(',')
957                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
958                });
959            let has_trailers = request
960                .headers()
961                .get(header::TRAILER)
962                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
963                .unwrap_or(false);
964            let write_trailers = request
965                .headers()
966                .get(header::TE)
967                .and_then(|v| v.to_str().ok())
968                .map(|v| {
969                    v.split(',')
970                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
971                })
972                .unwrap_or(false);
973
974            // Install HTTP upgrade
975            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
976            let upgrade = Upgrade::new(upgrade_rx);
977            let upgraded = upgrade.upgraded.clone();
978            request.extensions_mut().insert(upgrade);
979
980            // Get HTTP response
981            let mut response = {
982                let read_body_fut = async {
983                    if chunked {
984                        self.read_chunked_body_fn(&body_tx, has_trailers).await
985                    } else {
986                        self.read_body_fn(&body_tx, content_length).await
987                    }
988                };
989                let read_body_fut_pin = std::pin::pin!(read_body_fut);
990                let request_fut = request_fn(request);
991                let request_fut_pin = std::pin::pin!(request_fut);
992                let early_hints_fut: Pin<
993                    Box<dyn std::future::Future<Output = Result<(), std::io::Error>>>,
994                > = if let Some(early_hints) = early_hints_fut {
995                    Box::pin(early_hints)
996                } else {
997                    Box::pin(futures_util::future::pending::<Result<(), std::io::Error>>())
998                };
999
1000                let select_read_body_either =
1001                    futures_util::future::select(request_fut_pin, early_hints_fut);
1002                let select_either =
1003                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
1004
1005                let (response, body_fut) = match select_either {
1006                    futures_util::future::Either::Left((result, request_fut)) => {
1007                        result?;
1008                        (
1009                            match request_fut.await {
1010                                futures_util::future::Either::Left((response, _)) => response,
1011                                futures_util::future::Either::Right((_, _)) => unreachable!(),
1012                            },
1013                            None,
1014                        )
1015                    }
1016                    futures_util::future::Either::Right((response, read_body_fut)) => (
1017                        match response {
1018                            futures_util::future::Either::Left((response, _)) => response,
1019                            futures_util::future::Either::Right((_, _)) => unreachable!(),
1020                        },
1021                        Some(read_body_fut),
1022                    ),
1023                };
1024
1025                // Drain away remaining body
1026                if let Some(body_fut) = body_fut {
1027                    body_fut.await?;
1028                }
1029
1030                response.map_err(|e| std::io::Error::other(e.to_string()))?
1031            };
1032
1033            let mut was_upgraded = false;
1034            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
1035                was_upgraded = true;
1036                response
1037                    .headers_mut()
1038                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1039            } else if keep_alive {
1040                if version == Version::HTTP_10
1041                    || response.headers().contains_key(header::CONNECTION)
1042                {
1043                    response
1044                        .headers_mut()
1045                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
1046                }
1047            } else if version == Version::HTTP_11
1048                || response.headers().contains_key(header::CONNECTION)
1049            {
1050                response
1051                    .headers_mut()
1052                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
1053            }
1054
1055            // Write response to IO
1056            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
1057                .await?;
1058
1059            if was_upgraded {
1060                // HTTP upgrade
1061                let frozen_buf = self.read_buf.freeze();
1062                let _ = upgrade_tx.send(Upgraded::new(
1063                    self.io,
1064                    if frozen_buf.is_empty() {
1065                        None
1066                    } else {
1067                        Some(frozen_buf)
1068                    },
1069                ));
1070                return Ok(());
1071            }
1072
1073            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
1074                // Graceful shutdown requested, break out of loop
1075                break;
1076            }
1077        }
1078        Ok(())
1079    }
1080}
1081
1082impl<Io> HttpProtocol for Http1<Io>
1083where
1084    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1085{
1086    #[inline]
1087    fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1088        self,
1089        request_fn: F,
1090        error_fn: EF,
1091    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1092    where
1093        F: Fn(Request<Incoming>) -> Fut + 'static,
1094        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1095        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1096        ResE: std::error::Error,
1097        ResBE: std::error::Error,
1098        EF: FnOnce(bool) -> EFut,
1099        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1100        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
1101        EResE: std::error::Error,
1102        EResBE: std::error::Error,
1103    {
1104        #[allow(clippy::type_complexity)]
1105        let no_zerocopy: Option<
1106            Box<
1107                dyn FnMut(
1108                    RawHandle,
1109                    &Io,
1110                    u64,
1111                ) -> Box<
1112                    dyn std::future::Future<Output = Result<(), std::io::Error>>
1113                        + Unpin
1114                        + Send
1115                        + Sync,
1116                >,
1117            >,
1118        > = None;
1119        self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1120    }
1121
1122    #[inline]
1123    fn handle<F, Fut, ResB, ResBE, ResE>(
1124        self,
1125        request_fn: F,
1126    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1127    where
1128        F: Fn(Request<Incoming>) -> Fut + 'static,
1129        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1130        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1131        ResE: std::error::Error,
1132        ResBE: std::error::Error,
1133    {
1134        self.handle_with_error_fn(request_fn, |is_timeout| async move {
1135            let mut response = Response::builder();
1136            if is_timeout {
1137                response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1138            } else {
1139                response = response.status(http::StatusCode::BAD_REQUEST);
1140            }
1141            response.body(Empty::new())
1142        })
1143    }
1144}
1145
1146struct Http1Body {
1147    #[allow(clippy::type_complexity)]
1148    inner: Pin<Box<Receiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
1149}
1150
1151impl Body for Http1Body {
1152    type Data = bytes::Bytes;
1153    type Error = std::io::Error;
1154
1155    #[inline]
1156    fn poll_frame(
1157        mut self: Pin<&mut Self>,
1158        cx: &mut Context<'_>,
1159    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1160        match self.inner.as_mut().poll_next(cx) {
1161            Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1162            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
1163            Poll::Ready(None) => Poll::Ready(None),
1164            Poll::Pending => Poll::Pending,
1165        }
1166    }
1167}
1168
/// Finds the blank line that ends an HTTP message head.
///
/// Both `\r\n\r\n` and bare `\n\n` are accepted; whichever starts earliest wins.
/// Returns `(offset, separator_len)`, where `offset` is the index of the separator's first
/// byte and `separator_len` is 4 or 2. Returns `None` when no separator is present.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    // Neither separator fits in fewer than two bytes.
    if slice.len() < 2 {
        return None;
    }
    (0..slice.len()).find_map(|offset| {
        let tail = &slice[offset..];
        if tail.starts_with(b"\r\n\r\n") {
            Some((offset, 4))
        } else if tail.starts_with(b"\n\n") {
            Some((offset, 2))
        } else {
            None
        }
    })
}
1188
1189/// Writes the chunk size to the given buffer in hexadecimal format, followed by `\r\n`.
1190#[inline]
1191fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1192    let mut n = len;
1193    let mut pos = dst.len() - 2;
1194    loop {
1195        pos -= 1;
1196        dst[pos] = HEX_DIGITS[n & 0xF];
1197        n >>= 4;
1198        if n == 0 {
1199            break;
1200        }
1201    }
1202    dst[dst.len() - 2] = b'\r';
1203    dst[dst.len() - 1] = b'\n';
1204    &dst[pos..]
1205}