Skip to main content

vibeio_http/h2/
mod.rs

1mod date;
2mod options;
3
4pub use options::*;
5use tokio_util::sync::CancellationToken;
6
7use std::{
8    future::Future,
9    pin::Pin,
10    rc::Rc,
11    task::{Context, Poll},
12};
13
14use bytes::Bytes;
15use http::{Request, Response};
16use http_body::{Body, Frame};
17use http_body_util::BodyExt;
18
19use crate::{h2::date::DateCache, EarlyHints, HttpProtocol, Incoming};
20
21static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
22    http::header::HeaderName::from_static("keep-alive"),
23    http::header::HeaderName::from_static("proxy-connection"),
24    http::header::TRANSFER_ENCODING,
25    http::header::TE,
26    http::header::UPGRADE,
27];
28
29struct H2Body {
30    recv: h2::RecvStream,
31    data_done: bool,
32}
33
34impl H2Body {
35    #[inline]
36    fn new(recv: h2::RecvStream) -> Self {
37        Self {
38            recv,
39            data_done: false,
40        }
41    }
42}
43
44impl Body for H2Body {
45    type Data = Bytes;
46    type Error = std::io::Error;
47
48    #[inline]
49    fn poll_frame(
50        mut self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
53        if !self.data_done {
54            match self.recv.poll_data(cx) {
55                Poll::Ready(Some(Ok(data))) => {
56                    let _ = self.recv.flow_control().release_capacity(data.len());
57                    return Poll::Ready(Some(Ok(Frame::data(data))));
58                }
59                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
60                Poll::Ready(None) => self.data_done = true,
61                Poll::Pending => return Poll::Pending,
62            }
63        }
64
65        match self.recv.poll_trailers(cx) {
66            Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
67            Poll::Ready(Ok(None)) => Poll::Ready(None),
68            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
69            Poll::Pending => Poll::Pending,
70        }
71    }
72}
73
74#[inline]
75fn h2_error_to_io(error: h2::Error) -> std::io::Error {
76    if error.is_io() {
77        error.into_io().unwrap_or(std::io::Error::other("io error"))
78    } else {
79        std::io::Error::other(error)
80    }
81}
82
83#[inline]
84fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
85    std::io::Error::other(h2::Error::from(reason))
86}
87
88#[inline]
89async fn wait_for_send_capacity(
90    send: &mut h2::SendStream<Bytes>,
91    desired_capacity: usize,
92) -> Result<usize, std::io::Error> {
93    if desired_capacity == 0 {
94        return Ok(0);
95    }
96
97    send.reserve_capacity(desired_capacity);
98
99    if send.capacity() > 0 {
100        return Ok(send.capacity().min(desired_capacity));
101    }
102
103    std::future::poll_fn(|cx| loop {
104        match send.poll_reset(cx) {
105            Poll::Ready(Ok(reason)) => return Poll::Ready(Err(h2_reason_to_io(reason))),
106            Poll::Ready(Err(err)) => return Poll::Ready(Err(h2_error_to_io(err))),
107            Poll::Pending => {}
108        }
109
110        match send.poll_capacity(cx) {
111            Poll::Ready(Some(Ok(0))) => {}
112            Poll::Ready(Some(Ok(capacity))) => {
113                return Poll::Ready(Ok(capacity.min(desired_capacity)));
114            }
115            Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(h2_error_to_io(err))),
116            Poll::Ready(None) => {
117                return Poll::Ready(Err(std::io::Error::other(
118                    "send stream capacity unexpectedly closed",
119                )))
120            }
121            Poll::Pending => return Poll::Pending,
122        }
123    })
124    .await
125}
126
127/// An HTTP/2 connection handler.
128///
129/// `Http2` wraps an async I/O stream (`Io`) and drives the HTTP/2 server
130/// connection using the [`h2`] crate. It supports:
131///
132/// - Concurrent request stream handling
133/// - Streaming request/response bodies and trailers
134/// - Automatic `100 Continue` and `103 Early Hints` interim responses
135/// - Per-connection `Date` header caching
136/// - Graceful shutdown via a [`CancellationToken`]
137///
138/// # Construction
139///
140/// ```rust,ignore
141/// let http2 = Http2::new(tcp_stream, Http2Options::default());
142/// ```
143///
144/// # Serving requests
145///
146/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
147/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
148/// connection to completion.
149pub struct Http2<Io> {
150    io_to_handshake: Option<Io>,
151    date_header_value_cached: DateCache,
152    options: Http2Options,
153    cancel_token: Option<CancellationToken>,
154}
155
156impl<Io> Http2<Io>
157where
158    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
159{
160    /// Creates a new `Http2` connection handler wrapping the given I/O stream.
161    ///
162    /// The `options` value controls HTTP/2 protocol configuration, handshake
163    /// and accept timeouts, and optional behaviour such as automatic
164    /// `100 Continue` responses; see [`Http2Options`] for details.
165    ///
166    /// # Example
167    ///
168    /// ```rust,ignore
169    /// let http2 = Http2::new(tcp_stream, Http2Options::default());
170    /// ```
171    #[inline]
172    pub fn new(io: Io, options: Http2Options) -> Self {
173        Self {
174            io_to_handshake: Some(io),
175            date_header_value_cached: DateCache::default(),
176            options,
177            cancel_token: None,
178        }
179    }
180
181    /// Attaches a [`CancellationToken`] for graceful shutdown.
182    ///
183    /// When the token is cancelled, the handler sends HTTP/2 graceful shutdown
184    /// signals (GOAWAY), stops accepting new streams, and exits cleanly.
185    #[inline]
186    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
187        self.cancel_token = Some(token);
188        self
189    }
190}
191
192impl<Io> HttpProtocol for Http2<Io>
193where
194    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
195{
196    #[allow(clippy::manual_async_fn)]
197    #[inline]
198    fn handle<F, Fut, ResB, ResBE, ResE>(
199        mut self,
200        request_fn: F,
201    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
202    where
203        F: Fn(Request<super::Incoming>) -> Fut + 'static,
204        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
205        ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
206        ResE: std::error::Error,
207        ResBE: std::error::Error,
208    {
209        async move {
210            let request_fn = Rc::new(request_fn);
211            let handshake_fut = self.options.h2.handshake(
212                self.io_to_handshake
213                    .take()
214                    .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
215            );
216            let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
217                vibeio::time::timeout(timeout, handshake_fut).await
218            } else {
219                Ok(handshake_fut.await)
220            })
221            .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
222            .map_err(|e| {
223                if e.is_io() {
224                    e.into_io().unwrap_or(std::io::Error::other("io error"))
225                } else {
226                    std::io::Error::other(e)
227                }
228            })?;
229
230            while let Some(request) = {
231                let res = {
232                    let accept_fut_orig = h2.accept();
233                    let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
234                    let cancel_token = self.cancel_token.clone();
235                    let cancel_fut = async move {
236                        if let Some(token) = cancel_token {
237                            token.cancelled().await
238                        } else {
239                            futures_util::future::pending().await
240                        }
241                    };
242                    let cancel_fut_pin = std::pin::pin!(cancel_fut);
243                    let accept_fut =
244                        futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
245
246                    match if let Some(timeout) = self.options.accept_timeout {
247                        vibeio::time::timeout(timeout, accept_fut).await
248                    } else {
249                        Ok(accept_fut.await)
250                    } {
251                        Ok(futures_util::future::Either::Right((request, _))) => {
252                            (Some(request), false)
253                        }
254                        Ok(futures_util::future::Either::Left((_, _))) => {
255                            // Canceled
256                            (None, true)
257                        }
258                        Err(_) => {
259                            // Timeout
260                            (None, false)
261                        }
262                    }
263                };
264                match res {
265                    (Some(request), _) => request,
266                    (None, graceful) => {
267                        h2.graceful_shutdown();
268                        let _ = h2.accept().await;
269                        if graceful {
270                            return Ok(());
271                        }
272                        return Err(std::io::Error::new(
273                            std::io::ErrorKind::TimedOut,
274                            "accept timeout",
275                        ));
276                    }
277                }
278            } {
279                let (request, mut stream) = match request {
280                    Ok(d) => d,
281                    Err(e) if e.is_go_away() => {
282                        continue;
283                    }
284                    Err(e) if e.is_io() => {
285                        return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
286                    }
287                    Err(e) => {
288                        return Err(std::io::Error::other(e));
289                    }
290                };
291
292                let date_cache = self.date_header_value_cached.clone();
293                let request_fn = request_fn.clone();
294                let send_continue_response = self.options.send_continue_response;
295                vibeio::spawn(async move {
296                    let (request_parts, recv_stream) = request.into_parts();
297                    let request_body = Incoming::new(H2Body::new(recv_stream));
298                    let mut request = Request::from_parts(request_parts, request_body);
299
300                    // 100 Continue
301                    if send_continue_response {
302                        let is_100_continue = request
303                            .headers()
304                            .get(http::header::EXPECT)
305                            .and_then(|v| v.to_str().ok())
306                            .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
307                        if is_100_continue {
308                            let mut response = Response::new(());
309                            *response.status_mut() = http::StatusCode::CONTINUE;
310                            let _ = stream.send_informational(response).map_err(h2_error_to_io);
311                        }
312                    }
313
314                    let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
315                    let early_hints = EarlyHints::new(early_hints_tx);
316                    request.extensions_mut().insert(early_hints);
317
318                    let mut response_fut = std::pin::pin!(request_fn(request));
319                    let early_hints_rx = early_hints_rx;
320                    let response_result = loop {
321                        let early_hints_recv_fut = early_hints_rx.recv();
322                        let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
323                        let next = std::future::poll_fn(|cx| {
324                            match stream.poll_reset(cx) {
325                                Poll::Ready(Ok(reason)) => {
326                                    return Poll::Ready(Err(h2_reason_to_io(reason)));
327                                }
328                                Poll::Ready(Err(err)) => {
329                                    return Poll::Ready(Err(h2_error_to_io(err)));
330                                }
331                                Poll::Pending => {}
332                            }
333
334                            if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
335                                return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
336                            }
337
338                            match early_hints_recv_fut.as_mut().poll(cx) {
339                                Poll::Ready(Ok(msg)) => {
340                                    Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
341                                }
342                                Poll::Ready(Err(_)) => Poll::Pending,
343                                Poll::Pending => Poll::Pending,
344                            }
345                        })
346                        .await;
347
348                        match next {
349                            Ok(futures_util::future::Either::Left(response_result)) => {
350                                break response_result;
351                            }
352                            Ok(futures_util::future::Either::Right((headers, sender))) => {
353                                let mut response = Response::new(());
354                                *response.status_mut() = http::StatusCode::EARLY_HINTS;
355                                *response.headers_mut() = headers;
356                                sender
357                                    .into_inner()
358                                    .send(
359                                        stream.send_informational(response).map_err(h2_error_to_io),
360                                    )
361                                    .ok();
362                            }
363                            Err(_) => {
364                                return;
365                            }
366                        }
367                    };
368                    let Ok(mut response) = response_result else {
369                        // Return early if the request handler returns an error
370                        return;
371                    };
372
373                    {
374                        let response_headers = response.headers_mut();
375                        if let Some(http_date) = date_cache.get_date_header_value() {
376                            response_headers
377                                .entry(http::header::DATE)
378                                .or_insert(http_date);
379                        }
380                        if let Some(connection_header) = response_headers
381                            .remove(http::header::CONNECTION)
382                            .as_ref()
383                            .and_then(|v| v.to_str().ok())
384                        {
385                            for name in connection_header.split(',') {
386                                response_headers.remove(name.trim());
387                            }
388                        }
389                        while response_headers.remove(http::header::CONNECTION).is_some() {}
390                        for header in &HTTP2_INVALID_HEADERS {
391                            while response_headers.remove(header).is_some() {}
392                        }
393                    }
394
395                    let response_is_end_stream = response.body().is_end_stream();
396                    if !response_is_end_stream {
397                        if let Some(content_length) = response.body().size_hint().exact() {
398                            if !response
399                                .headers()
400                                .contains_key(http::header::CONTENT_LENGTH)
401                            {
402                                response
403                                    .headers_mut()
404                                    .insert(http::header::CONTENT_LENGTH, content_length.into());
405                            }
406                        }
407                    }
408
409                    let (response_parts, mut response_body) = response.into_parts();
410                    let mut send = match stream.send_response(
411                        Response::from_parts(response_parts, ()),
412                        response_is_end_stream,
413                    ) {
414                        Ok(send) => send,
415                        Err(_) => {
416                            return;
417                        }
418                    };
419
420                    if response_is_end_stream {
421                        return;
422                    }
423
424                    while let Some(chunk) = {
425                        let frame_fut = response_body.frame();
426                        let mut frame_fut = std::pin::pin!(frame_fut);
427                        match std::future::poll_fn(|cx| {
428                            match send.poll_reset(cx) {
429                                Poll::Ready(Ok(reason)) => {
430                                    return Poll::Ready(Err(h2_reason_to_io(reason)));
431                                }
432                                Poll::Ready(Err(err)) => {
433                                    return Poll::Ready(Err(h2_error_to_io(err)));
434                                }
435                                Poll::Pending => {}
436                            }
437
438                            match frame_fut.as_mut().poll(cx) {
439                                Poll::Ready(frame) => Poll::Ready(Ok(frame)),
440                                Poll::Pending => Poll::Pending,
441                            }
442                        })
443                        .await
444                        {
445                            Ok(frame) => frame,
446                            Err(_) => {
447                                return;
448                            }
449                        }
450                    } {
451                        match chunk {
452                            Ok(frame) => {
453                                if frame.is_data() {
454                                    match frame.into_data() {
455                                        Ok(mut data) => {
456                                            let response_is_end_stream =
457                                                response_body.is_end_stream();
458                                            if data.is_empty() {
459                                                if send
460                                                    .send_data(data, response_is_end_stream)
461                                                    .is_err()
462                                                {
463                                                    return;
464                                                }
465                                                if response_is_end_stream {
466                                                    return;
467                                                }
468                                                continue;
469                                            }
470
471                                            while !data.is_empty() {
472                                                let capacity = match wait_for_send_capacity(
473                                                    &mut send,
474                                                    data.len(),
475                                                )
476                                                .await
477                                                {
478                                                    Ok(capacity) => capacity,
479                                                    Err(_) => return,
480                                                };
481                                                let chunk = data.split_to(capacity.min(data.len()));
482                                                let is_end_stream =
483                                                    response_is_end_stream && data.is_empty();
484                                                if send.send_data(chunk, is_end_stream).is_err() {
485                                                    return;
486                                                }
487                                                if is_end_stream {
488                                                    return;
489                                                }
490                                            }
491                                        }
492                                        Err(_) => {
493                                            return;
494                                        }
495                                    }
496                                } else if frame.is_trailers() {
497                                    match frame.into_trailers() {
498                                        Ok(trailers) => {
499                                            if send.send_trailers(trailers).is_err() {
500                                                return;
501                                            }
502                                            return;
503                                        }
504                                        Err(_) => {
505                                            return;
506                                        }
507                                    }
508                                }
509                            }
510                            Err(_) => {
511                                return;
512                            }
513                        }
514                    }
515                    let _ = send.send_data(Bytes::new(), true);
516                });
517            }
518
519            Ok(())
520        }
521    }
522}