Skip to main content

vibeio_http/h2/
mod.rs

1mod date;
2mod options;
3
4pub use options::*;
5use tokio_util::sync::CancellationToken;
6
7use std::{
8    future::Future,
9    pin::Pin,
10    rc::Rc,
11    task::{Context, Poll},
12};
13
14use bytes::Bytes;
15use http::{Request, Response};
16use http_body::{Body, Frame};
17use http_body_util::BodyExt;
18
19use crate::{h2::date::DateCache, EarlyHints, HttpProtocol, Incoming};
20
21static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
22    http::header::HeaderName::from_static("keep-alive"),
23    http::header::HeaderName::from_static("proxy-connection"),
24    http::header::TRANSFER_ENCODING,
25    http::header::TE,
26    http::header::UPGRADE,
27];
28
29struct H2Body {
30    recv: h2::RecvStream,
31    data_done: bool,
32}
33
34impl H2Body {
35    #[inline]
36    fn new(recv: h2::RecvStream) -> Self {
37        Self {
38            recv,
39            data_done: false,
40        }
41    }
42}
43
44impl Body for H2Body {
45    type Data = Bytes;
46    type Error = std::io::Error;
47
48    #[inline]
49    fn poll_frame(
50        mut self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
53        if !self.data_done {
54            match self.recv.poll_data(cx) {
55                Poll::Ready(Some(Ok(data))) => return Poll::Ready(Some(Ok(Frame::data(data)))),
56                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
57                Poll::Ready(None) => self.data_done = true,
58                Poll::Pending => return Poll::Pending,
59            }
60        }
61
62        match self.recv.poll_trailers(cx) {
63            Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
64            Poll::Ready(Ok(None)) => Poll::Ready(None),
65            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
66            Poll::Pending => Poll::Pending,
67        }
68    }
69}
70
71#[inline]
72fn h2_error_to_io(error: h2::Error) -> std::io::Error {
73    if error.is_io() {
74        error.into_io().unwrap_or(std::io::Error::other("io error"))
75    } else {
76        std::io::Error::other(error)
77    }
78}
79
80#[inline]
81fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
82    std::io::Error::other(h2::Error::from(reason))
83}
84
85async fn wait_for_send_capacity(send: &mut h2::SendStream<Bytes>) -> Result<(), std::io::Error> {
86    send.reserve_capacity(1);
87
88    if send.capacity() == 0 {
89        std::future::poll_fn(|cx| loop {
90            match send.poll_capacity(cx) {
91                Poll::Ready(Some(Ok(0))) => {}
92                Poll::Ready(Some(Ok(_))) => return Poll::Ready(Ok(())),
93                Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(h2_error_to_io(err))),
94                Poll::Ready(None) => {
95                    return Poll::Ready(Err(std::io::Error::other(
96                        "send stream capacity unexpectedly closed",
97                    )))
98                }
99                Poll::Pending => return Poll::Pending,
100            }
101        })
102        .await
103    } else {
104        let reset_reason = std::future::poll_fn(|cx| match send.poll_reset(cx) {
105            Poll::Ready(Ok(reason)) => Poll::Ready(Ok(Some(reason))),
106            Poll::Ready(Err(err)) => Poll::Ready(Err(h2_error_to_io(err))),
107            Poll::Pending => Poll::Ready(Ok(None)),
108        })
109        .await?;
110        if let Some(reason) = reset_reason {
111            Err(h2_reason_to_io(reason))
112        } else {
113            Ok(())
114        }
115    }
116}
117
118/// An HTTP/2 connection handler.
119///
120/// `Http2` wraps an async I/O stream (`Io`) and drives the HTTP/2 server
121/// connection using the [`h2`] crate. It supports:
122///
123/// - Concurrent request stream handling
124/// - Streaming request/response bodies and trailers
125/// - Automatic `100 Continue` and `103 Early Hints` interim responses
126/// - Per-connection `Date` header caching
127/// - Graceful shutdown via a [`CancellationToken`]
128///
129/// # Construction
130///
131/// ```rust,ignore
132/// let http2 = Http2::new(tcp_stream, Http2Options::default());
133/// ```
134///
135/// # Serving requests
136///
137/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
138/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
139/// connection to completion.
140pub struct Http2<Io> {
141    io_to_handshake: Option<Io>,
142    date_header_value_cached: DateCache,
143    options: Http2Options,
144    cancel_token: Option<CancellationToken>,
145}
146
147impl<Io> Http2<Io>
148where
149    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
150{
151    /// Creates a new `Http2` connection handler wrapping the given I/O stream.
152    ///
153    /// The `options` value controls HTTP/2 protocol configuration, handshake
154    /// and accept timeouts, and optional behaviour such as automatic
155    /// `100 Continue` responses; see [`Http2Options`] for details.
156    ///
157    /// # Example
158    ///
159    /// ```rust,ignore
160    /// let http2 = Http2::new(tcp_stream, Http2Options::default());
161    /// ```
162    #[inline]
163    pub fn new(io: Io, options: Http2Options) -> Self {
164        Self {
165            io_to_handshake: Some(io),
166            date_header_value_cached: DateCache::default(),
167            options,
168            cancel_token: None,
169        }
170    }
171
172    /// Attaches a [`CancellationToken`] for graceful shutdown.
173    ///
174    /// When the token is cancelled, the handler sends HTTP/2 graceful shutdown
175    /// signals (GOAWAY), stops accepting new streams, and exits cleanly.
176    #[inline]
177    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
178        self.cancel_token = Some(token);
179        self
180    }
181}
182
183impl<Io> HttpProtocol for Http2<Io>
184where
185    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
186{
187    #[allow(clippy::manual_async_fn)]
188    #[inline]
189    fn handle<F, Fut, ResB, ResBE, ResE>(
190        mut self,
191        request_fn: F,
192    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
193    where
194        F: Fn(Request<super::Incoming>) -> Fut + 'static,
195        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
196        ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
197        ResE: std::error::Error,
198        ResBE: std::error::Error,
199    {
200        async move {
201            let request_fn = Rc::new(request_fn);
202            let handshake_fut = self.options.h2.handshake(
203                self.io_to_handshake
204                    .take()
205                    .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
206            );
207            let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
208                vibeio::time::timeout(timeout, handshake_fut).await
209            } else {
210                Ok(handshake_fut.await)
211            })
212            .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
213            .map_err(|e| {
214                if e.is_io() {
215                    e.into_io().unwrap_or(std::io::Error::other("io error"))
216                } else {
217                    std::io::Error::other(e)
218                }
219            })?;
220
221            while let Some(request) = {
222                let res = {
223                    let accept_fut_orig = h2.accept();
224                    let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
225                    let cancel_token = self.cancel_token.clone();
226                    let cancel_fut = async move {
227                        if let Some(token) = cancel_token {
228                            token.cancelled().await
229                        } else {
230                            futures_util::future::pending().await
231                        }
232                    };
233                    let cancel_fut_pin = std::pin::pin!(cancel_fut);
234                    let accept_fut =
235                        futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
236
237                    match if let Some(timeout) = self.options.accept_timeout {
238                        vibeio::time::timeout(timeout, accept_fut).await
239                    } else {
240                        Ok(accept_fut.await)
241                    } {
242                        Ok(futures_util::future::Either::Right((request, _))) => {
243                            (Some(request), false)
244                        }
245                        Ok(futures_util::future::Either::Left((_, _))) => {
246                            // Canceled
247                            (None, true)
248                        }
249                        Err(_) => {
250                            // Timeout
251                            (None, false)
252                        }
253                    }
254                };
255                match res {
256                    (Some(request), _) => request,
257                    (None, graceful) => {
258                        h2.graceful_shutdown();
259                        let _ = h2.accept().await;
260                        if graceful {
261                            return Ok(());
262                        }
263                        return Err(std::io::Error::new(
264                            std::io::ErrorKind::TimedOut,
265                            "accept timeout",
266                        ));
267                    }
268                }
269            } {
270                let (request, mut stream) = match request {
271                    Ok(d) => d,
272                    Err(e) if e.is_go_away() => {
273                        continue;
274                    }
275                    Err(e) if e.is_io() => {
276                        return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
277                    }
278                    Err(e) => {
279                        return Err(std::io::Error::other(e));
280                    }
281                };
282
283                let date_cache = self.date_header_value_cached.clone();
284                let request_fn = request_fn.clone();
285                let send_continue_response = self.options.send_continue_response;
286                vibeio::spawn(async move {
287                    let (request_parts, recv_stream) = request.into_parts();
288                    let request_body = Incoming::new(H2Body::new(recv_stream));
289                    let mut request = Request::from_parts(request_parts, request_body);
290
291                    // 100 Continue
292                    if send_continue_response {
293                        let is_100_continue = request
294                            .headers()
295                            .get(http::header::EXPECT)
296                            .and_then(|v| v.to_str().ok())
297                            .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
298                        if is_100_continue {
299                            let mut response = Response::new(());
300                            *response.status_mut() = http::StatusCode::CONTINUE;
301                            let _ = stream.send_informational(response).map_err(h2_error_to_io);
302                        }
303                    }
304
305                    let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
306                    let early_hints = EarlyHints::new(early_hints_tx);
307                    request.extensions_mut().insert(early_hints);
308
309                    let mut response_fut = std::pin::pin!(request_fn(request));
310                    let early_hints_rx = early_hints_rx;
311                    let response_result = loop {
312                        let early_hints_recv_fut = early_hints_rx.recv();
313                        let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
314                        let next = std::future::poll_fn(|cx| {
315                            match stream.poll_reset(cx) {
316                                Poll::Ready(Ok(reason)) => {
317                                    return Poll::Ready(Err(h2_reason_to_io(reason)));
318                                }
319                                Poll::Ready(Err(err)) => {
320                                    return Poll::Ready(Err(h2_error_to_io(err)));
321                                }
322                                Poll::Pending => {}
323                            }
324
325                            if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
326                                return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
327                            }
328
329                            match early_hints_recv_fut.as_mut().poll(cx) {
330                                Poll::Ready(Ok(msg)) => {
331                                    Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
332                                }
333                                Poll::Ready(Err(_)) => Poll::Pending,
334                                Poll::Pending => Poll::Pending,
335                            }
336                        })
337                        .await;
338
339                        match next {
340                            Ok(futures_util::future::Either::Left(response_result)) => {
341                                break response_result;
342                            }
343                            Ok(futures_util::future::Either::Right((headers, sender))) => {
344                                let mut response = Response::new(());
345                                *response.status_mut() = http::StatusCode::EARLY_HINTS;
346                                *response.headers_mut() = headers;
347                                sender
348                                    .into_inner()
349                                    .send(
350                                        stream.send_informational(response).map_err(h2_error_to_io),
351                                    )
352                                    .ok();
353                            }
354                            Err(_) => {
355                                return;
356                            }
357                        }
358                    };
359                    let Ok(mut response) = response_result else {
360                        // Return early if the request handler returns an error
361                        return;
362                    };
363
364                    {
365                        let response_headers = response.headers_mut();
366                        if let Some(http_date) = date_cache.get_date_header_value() {
367                            response_headers
368                                .entry(http::header::DATE)
369                                .or_insert(http_date);
370                        }
371                        if let Some(connection_header) = response_headers
372                            .remove(http::header::CONNECTION)
373                            .as_ref()
374                            .and_then(|v| v.to_str().ok())
375                        {
376                            for name in connection_header.split(',') {
377                                response_headers.remove(name.trim());
378                            }
379                        }
380                        while response_headers.remove(http::header::CONNECTION).is_some() {}
381                        for header in &HTTP2_INVALID_HEADERS {
382                            while response_headers.remove(header).is_some() {}
383                        }
384                    }
385
386                    let response_is_end_stream = response.body().is_end_stream();
387                    if !response_is_end_stream {
388                        if let Some(content_length) = response.body().size_hint().exact() {
389                            if !response
390                                .headers()
391                                .contains_key(http::header::CONTENT_LENGTH)
392                            {
393                                response
394                                    .headers_mut()
395                                    .insert(http::header::CONTENT_LENGTH, content_length.into());
396                            }
397                        }
398                    }
399
400                    let (response_parts, mut response_body) = response.into_parts();
401                    let mut send = match stream.send_response(
402                        Response::from_parts(response_parts, ()),
403                        response_is_end_stream,
404                    ) {
405                        Ok(send) => send,
406                        Err(_) => {
407                            return;
408                        }
409                    };
410
411                    if response_is_end_stream {
412                        return;
413                    }
414
415                    while let Some(chunk) = {
416                        let frame_fut = response_body.frame();
417                        let mut frame_fut = std::pin::pin!(frame_fut);
418                        match std::future::poll_fn(|cx| {
419                            match send.poll_reset(cx) {
420                                Poll::Ready(Ok(reason)) => {
421                                    return Poll::Ready(Err(h2_reason_to_io(reason)));
422                                }
423                                Poll::Ready(Err(err)) => {
424                                    return Poll::Ready(Err(h2_error_to_io(err)));
425                                }
426                                Poll::Pending => {}
427                            }
428
429                            match frame_fut.as_mut().poll(cx) {
430                                Poll::Ready(frame) => Poll::Ready(Ok(frame)),
431                                Poll::Pending => Poll::Pending,
432                            }
433                        })
434                        .await
435                        {
436                            Ok(frame) => frame,
437                            Err(_) => {
438                                return;
439                            }
440                        }
441                    } {
442                        match chunk {
443                            Ok(frame) => {
444                                if frame.is_data() {
445                                    match frame.into_data() {
446                                        Ok(data) => {
447                                            if wait_for_send_capacity(&mut send).await.is_err() {
448                                                return;
449                                            }
450                                            let is_end_stream = response_body.is_end_stream();
451                                            if send.send_data(data, is_end_stream).is_err() {
452                                                return;
453                                            }
454                                            if is_end_stream {
455                                                return;
456                                            }
457                                        }
458                                        Err(_) => {
459                                            return;
460                                        }
461                                    }
462                                } else if frame.is_trailers() {
463                                    match frame.into_trailers() {
464                                        Ok(trailers) => {
465                                            if send.send_trailers(trailers).is_err() {
466                                                return;
467                                            }
468                                            return;
469                                        }
470                                        Err(_) => {
471                                            return;
472                                        }
473                                    }
474                                }
475                            }
476                            Err(_) => {
477                                return;
478                            }
479                        }
480                    }
481                    let _ = send.send_data(Bytes::new(), true);
482                });
483            }
484
485            Ok(())
486        }
487    }
488}