Skip to main content

vibeio_http/h2/
mod.rs

1// TODO: add support for extended CONNECT
2
3mod date;
4mod options;
5mod send;
6mod upgrade;
7
8pub use options::*;
9use pin_project_lite::pin_project;
10use tokio_util::sync::CancellationToken;
11
12use std::{
13    future::Future,
14    pin::Pin,
15    task::{Context, Poll},
16};
17
18use bytes::Bytes;
19use http::{Request, Response};
20use http_body::{Body, Frame};
21
22use crate::{
23    early_hints::EarlyHintsReceiver,
24    h2::{
25        date::DateCache,
26        send::{PipeToSendStream, SendBuf},
27    },
28    EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded,
29};
30
31static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
32    http::header::HeaderName::from_static("keep-alive"),
33    http::header::HeaderName::from_static("proxy-connection"),
34    http::header::CONNECTION,
35    http::header::TRANSFER_ENCODING,
36    http::header::UPGRADE,
37];
38
39pub(crate) struct H2Body {
40    recv: h2::RecvStream,
41    data_done: bool,
42}
43
44impl H2Body {
45    #[inline]
46    fn new(recv: h2::RecvStream) -> Self {
47        Self {
48            recv,
49            data_done: false,
50        }
51    }
52}
53
54impl Body for H2Body {
55    type Data = Bytes;
56    type Error = std::io::Error;
57
58    #[inline]
59    fn poll_frame(
60        mut self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
63        if !self.data_done {
64            match self.recv.poll_data(cx) {
65                Poll::Ready(Some(Ok(data))) => {
66                    let _ = self.recv.flow_control().release_capacity(data.len());
67                    return Poll::Ready(Some(Ok(Frame::data(data))));
68                }
69                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
70                Poll::Ready(None) => self.data_done = true,
71                Poll::Pending => return Poll::Pending,
72            }
73        }
74
75        match self.recv.poll_trailers(cx) {
76            Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
77            Poll::Ready(Ok(None)) => Poll::Ready(None),
78            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
79            Poll::Pending => Poll::Pending,
80        }
81    }
82}
83
84#[inline]
85pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
86    if error.is_io() {
87        error.into_io().unwrap_or(std::io::Error::other("io error"))
88    } else {
89        std::io::Error::other(error)
90    }
91}
92
93#[inline]
94pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
95    std::io::Error::other(h2::Error::from(reason))
96}
97
98#[inline]
99fn sanitize_response<ResB>(
100    response: &mut Response<ResB>,
101    send_date_header: bool,
102    date_cache: &DateCache,
103) where
104    ResB: Body<Data = bytes::Bytes>,
105{
106    let response_headers = response.headers_mut();
107    if send_date_header {
108        if let Some(http_date) = date_cache.get_date_header_value() {
109            response_headers
110                .entry(http::header::DATE)
111                .or_insert(http_date);
112        }
113    }
114    for header in &HTTP2_INVALID_HEADERS {
115        if let http::header::Entry::Occupied(entry) = response_headers.entry(header) {
116            entry.remove();
117        }
118    }
119    if response_headers
120        .get(http::header::TE)
121        .is_some_and(|v| v != "trailers")
122    {
123        response_headers.remove(http::header::TE);
124    }
125}
126
127struct PendingUpgrade {
128    tx: oneshot::Sender<Upgraded>,
129    upgraded: std::sync::Arc<std::sync::atomic::AtomicBool>,
130    recv_stream: h2::RecvStream,
131}
132
133pin_project! {
134    struct H2Stream<Fut, ResB>
135    where
136        Fut: Future,
137        ResB: Body<Data = bytes::Bytes>,
138    {
139        stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
140        #[pin]
141        state: H2StreamState<Fut, ResB>,
142    }
143}
144
145pin_project! {
146    #[project = H2StreamStateProj]
147    enum H2StreamState<Fut, ResB>
148    where
149        Fut: Future,
150        ResB: Body<Data = bytes::Bytes>,
151    {
152        Service {
153            #[pin]
154            response_fut: Fut,
155            early_hints_rx: EarlyHintsReceiver,
156            date_cache: DateCache,
157            send_date_header: bool,
158            upgrade: Option<PendingUpgrade>,
159            send_continue: bool,
160            early_hints_open: bool,
161        },
162        Body {
163            #[pin]
164            pipe: PipeToSendStream<ResB>,
165        },
166    }
167}
168
169impl<Fut, ResB> H2Stream<Fut, ResB>
170where
171    Fut: Future,
172    ResB: Body<Data = bytes::Bytes>,
173{
174    #[inline]
175    const fn new(
176        stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
177        response_fut: Fut,
178        early_hints_rx: EarlyHintsReceiver,
179        date_cache: DateCache,
180        send_date_header: bool,
181        upgrade: Option<PendingUpgrade>,
182        send_continue: bool,
183    ) -> Self {
184        Self {
185            stream,
186            state: H2StreamState::Service {
187                response_fut,
188                early_hints_rx,
189                date_cache,
190                send_date_header,
191                upgrade,
192                send_continue,
193                early_hints_open: true,
194            },
195        }
196    }
197}
198
199impl<Fut, ResB, ResBE, ResE> Future for H2Stream<Fut, ResB>
200where
201    Fut: Future<Output = Result<Response<ResB>, ResE>>,
202    ResB: Body<Data = bytes::Bytes, Error = ResBE>,
203    ResE: std::error::Error,
204    ResBE: std::error::Error,
205{
206    type Output = ();
207
208    #[inline]
209    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210        let mut this = self.project();
211
212        loop {
213            match this.state.as_mut().project() {
214                H2StreamStateProj::Service {
215                    response_fut,
216                    early_hints_rx,
217                    date_cache,
218                    send_date_header,
219                    upgrade,
220                    send_continue,
221                    early_hints_open,
222                } => {
223                    if *send_continue {
224                        let mut response = Response::new(());
225                        *response.status_mut() = http::StatusCode::CONTINUE;
226                        let _ = this
227                            .stream
228                            .send_informational(response)
229                            .map_err(h2_error_to_io);
230                        *send_continue = false;
231                    }
232
233                    if let Poll::Ready(response_result) = response_fut.poll(cx) {
234                        let Ok(mut response) = response_result else {
235                            return Poll::Ready(());
236                        };
237
238                        sanitize_response(&mut response, *send_date_header, date_cache);
239
240                        let response_is_end_stream = response.body().is_end_stream();
241                        if !response_is_end_stream {
242                            if let Some(content_length) = response.body().size_hint().exact() {
243                                if !response
244                                    .headers()
245                                    .contains_key(http::header::CONTENT_LENGTH)
246                                {
247                                    response.headers_mut().insert(
248                                        http::header::CONTENT_LENGTH,
249                                        content_length.into(),
250                                    );
251                                }
252                            }
253                        }
254
255                        let (response_parts, response_body) = response.into_parts();
256                        let Ok(send) = this.stream.send_response(
257                            Response::from_parts(response_parts, ()),
258                            response_is_end_stream && upgrade.is_none(),
259                        ) else {
260                            return Poll::Ready(());
261                        };
262
263                        if let Some(PendingUpgrade {
264                            tx,
265                            upgraded,
266                            recv_stream,
267                        }) = upgrade.take()
268                        {
269                            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
270                                let (upgraded, task) = self::upgrade::pair(send, recv_stream);
271                                let _ = tx.send(Upgraded::new(upgraded, None));
272                                vibeio::spawn(task);
273                                return Poll::Ready(());
274                            }
275                        }
276
277                        if response_is_end_stream {
278                            return Poll::Ready(());
279                        }
280
281                        this.state.set(H2StreamState::Body {
282                            pipe: PipeToSendStream::new(send, response_body),
283                        });
284                        continue;
285                    }
286
287                    match this.stream.poll_reset(cx) {
288                        Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => return Poll::Ready(()),
289                        Poll::Pending => {}
290                    }
291
292                    if *early_hints_open {
293                        match early_hints_rx.poll_recv(cx) {
294                            Poll::Ready(Some((headers, sender))) => {
295                                let mut response = Response::new(());
296                                *response.status_mut() = http::StatusCode::EARLY_HINTS;
297                                *response.headers_mut() = headers;
298                                sender
299                                    .into_inner()
300                                    .send(
301                                        this.stream
302                                            .send_informational(response)
303                                            .map_err(h2_error_to_io),
304                                    )
305                                    .ok();
306                                continue;
307                            }
308                            Poll::Ready(None) => {
309                                *early_hints_open = false;
310                                continue;
311                            }
312                            Poll::Pending => {}
313                        }
314                    }
315
316                    return Poll::Pending;
317                }
318                H2StreamStateProj::Body { pipe } => {
319                    return pipe.poll(cx).map(|_| ());
320                }
321            }
322        }
323    }
324}
325
326/// An HTTP/2 connection handler.
327///
328/// `Http2` wraps an async I/O stream (`Io`) and drives the HTTP/2 server
329/// connection using the [`h2`] crate. It supports:
330///
331/// - Concurrent request stream handling
332/// - Streaming request/response bodies and trailers
333/// - Automatic `100 Continue` and `103 Early Hints` interim responses
334/// - Per-connection `Date` header caching
335/// - Graceful shutdown via a [`CancellationToken`]
336///
337/// # Construction
338///
339/// ```rust,ignore
340/// let http2 = Http2::new(tcp_stream, Http2Options::default());
341/// ```
342///
343/// # Serving requests
344///
345/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
346/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
347/// connection to completion.
348pub struct Http2<Io> {
349    io_to_handshake: Option<Io>,
350    date_header_value_cached: DateCache,
351    options: Http2Options,
352    cancel_token: Option<CancellationToken>,
353}
354
355impl<Io> Http2<Io>
356where
357    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
358{
359    /// Creates a new `Http2` connection handler wrapping the given I/O stream.
360    ///
361    /// The `options` value controls HTTP/2 protocol configuration, handshake
362    /// and accept timeouts, and optional behaviour such as automatic
363    /// `100 Continue` responses; see [`Http2Options`] for details.
364    ///
365    /// # Example
366    ///
367    /// ```rust,ignore
368    /// let http2 = Http2::new(tcp_stream, Http2Options::default());
369    /// ```
370    #[inline]
371    pub fn new(io: Io, options: Http2Options) -> Self {
372        Self {
373            io_to_handshake: Some(io),
374            date_header_value_cached: DateCache::default(),
375            options,
376            cancel_token: None,
377        }
378    }
379
380    /// Attaches a [`CancellationToken`] for graceful shutdown.
381    ///
382    /// When the token is cancelled, the handler sends HTTP/2 graceful shutdown
383    /// signals (GOAWAY), stops accepting new streams, and exits cleanly.
384    #[inline]
385    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
386        self.cancel_token = Some(token);
387        self
388    }
389}
390
391impl<Io> HttpProtocol for Http2<Io>
392where
393    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
394{
395    #[allow(clippy::manual_async_fn)]
396    #[inline]
397    fn handle<F, Fut, ResB, ResBE, ResE>(
398        mut self,
399        request_fn: F,
400    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
401    where
402        F: Fn(Request<super::Incoming>) -> Fut + 'static,
403        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
404        ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
405        ResE: std::error::Error,
406        ResBE: std::error::Error,
407    {
408        async move {
409            let handshake_fut = self.options.h2.handshake(
410                self.io_to_handshake
411                    .take()
412                    .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
413            );
414            let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
415                vibeio::time::timeout(timeout, handshake_fut).await
416            } else {
417                Ok(handshake_fut.await)
418            })
419            .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
420            .map_err(|e| {
421                if e.is_io() {
422                    e.into_io().unwrap_or(std::io::Error::other("io error"))
423                } else {
424                    std::io::Error::other(e)
425                }
426            })?;
427
428            while let Some(request) = {
429                let res = {
430                    let accept_fut_orig = h2.accept();
431                    let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
432                    let cancel_token = self.cancel_token.clone();
433                    let cancel_fut = async move {
434                        if let Some(token) = cancel_token {
435                            token.cancelled().await
436                        } else {
437                            futures_util::future::pending().await
438                        }
439                    };
440                    let cancel_fut_pin = std::pin::pin!(cancel_fut);
441                    let accept_fut =
442                        futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
443
444                    match if let Some(timeout) = self.options.accept_timeout {
445                        vibeio::time::timeout(timeout, accept_fut).await
446                    } else {
447                        Ok(accept_fut.await)
448                    } {
449                        Ok(futures_util::future::Either::Right((request, _))) => {
450                            (Some(request), false)
451                        }
452                        Ok(futures_util::future::Either::Left((_, _))) => {
453                            // Canceled
454                            (None, true)
455                        }
456                        Err(_) => {
457                            // Timeout
458                            (None, false)
459                        }
460                    }
461                };
462                match res {
463                    (Some(request), _) => request,
464                    (None, graceful) => {
465                        h2.graceful_shutdown();
466                        let _ = h2.accept().await;
467                        if graceful {
468                            return Ok(());
469                        }
470                        return Err(std::io::Error::new(
471                            std::io::ErrorKind::TimedOut,
472                            "accept timeout",
473                        ));
474                    }
475                }
476            } {
477                let (request, stream) = match request {
478                    Ok(d) => d,
479                    Err(e) if e.is_go_away() => {
480                        continue;
481                    }
482                    Err(e) if e.is_io() => {
483                        return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
484                    }
485                    Err(e) => {
486                        return Err(std::io::Error::other(e));
487                    }
488                };
489
490                let date_cache = self.date_header_value_cached.clone();
491                let (request_parts, recv_stream) = request.into_parts();
492                let (request_body, upgrade) = if request_parts.method == http::Method::CONNECT {
493                    (Incoming::Empty, Some(recv_stream))
494                } else {
495                    (Incoming::H2(H2Body::new(recv_stream)), None)
496                };
497                let mut request = Request::from_parts(request_parts, request_body);
498
499                // 100 Continue
500                let is_100_continue = self.options.send_continue_response
501                    && request
502                        .headers()
503                        .get(http::header::EXPECT)
504                        .and_then(|v| v.to_str().ok())
505                        .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
506
507                // Install early hints
508                let (early_hints, early_hints_rx) = EarlyHints::new_lazy();
509                request.extensions_mut().insert(early_hints);
510
511                // Install HTTP upgrade
512                let upgrade = if let Some(recv_stream) = upgrade {
513                    let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
514                    let upgrade = Upgrade::new(upgrade_rx);
515                    let upgraded = upgrade.upgraded.clone();
516                    request.extensions_mut().insert(upgrade);
517                    Some(PendingUpgrade {
518                        tx: upgrade_tx,
519                        upgraded,
520                        recv_stream,
521                    })
522                } else {
523                    None
524                };
525
526                vibeio::spawn(H2Stream::new(
527                    stream,
528                    request_fn(request),
529                    early_hints_rx,
530                    date_cache,
531                    self.options.send_date_header,
532                    upgrade,
533                    is_100_continue,
534                ));
535            }
536
537            Ok(())
538        }
539    }
540}