Skip to main content

vibeio_http/h2/
mod.rs

1// TODO: add support for extended CONNECT
2
3mod date;
4mod options;
5mod send;
6mod upgrade;
7
8pub use options::*;
9use pin_project_lite::pin_project;
10use tokio_util::sync::CancellationToken;
11
12use std::{
13    future::Future,
14    pin::Pin,
15    sync::{atomic::AtomicBool, Arc},
16    task::{Context, Poll},
17};
18
19use bytes::Bytes;
20use http::{Request, Response};
21use http_body::{Body, Frame};
22
23use crate::{
24    early_hints::EarlyHintsReceiver,
25    h2::{
26        date::DateCache,
27        send::{PipeToSendStream, SendBuf},
28    },
29    EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded,
30};
31
32static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
33    http::header::HeaderName::from_static("keep-alive"),
34    http::header::HeaderName::from_static("proxy-connection"),
35    http::header::CONNECTION,
36    http::header::TRANSFER_ENCODING,
37    http::header::UPGRADE,
38];
39
40pub(crate) struct H2Body {
41    recv: h2::RecvStream,
42    data_done: bool,
43    send_continue_body: Option<Arc<AtomicBool>>,
44}
45
46impl H2Body {
47    #[inline]
48    fn new(recv: h2::RecvStream, send_continue_body: Option<Arc<AtomicBool>>) -> Self {
49        Self {
50            recv,
51            data_done: false,
52            send_continue_body,
53        }
54    }
55}
56
57impl Body for H2Body {
58    type Data = Bytes;
59    type Error = std::io::Error;
60
61    #[inline]
62    fn poll_frame(
63        mut self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
66        if !self.data_done {
67            match self.recv.poll_data(cx) {
68                Poll::Ready(Some(Ok(data))) => {
69                    let _ = self.recv.flow_control().release_capacity(data.len());
70                    return Poll::Ready(Some(Ok(Frame::data(data))));
71                }
72                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
73                Poll::Ready(None) => self.data_done = true,
74                Poll::Pending => {
75                    if let Some(scb) = self.send_continue_body.as_ref() {
76                        scb.store(true, std::sync::atomic::Ordering::Relaxed);
77                    }
78                    return Poll::Pending;
79                }
80            }
81        }
82
83        match self.recv.poll_trailers(cx) {
84            Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
85            Poll::Ready(Ok(None)) => Poll::Ready(None),
86            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
87            Poll::Pending => {
88                if let Some(scb) = self.send_continue_body.as_ref() {
89                    scb.store(true, std::sync::atomic::Ordering::Relaxed);
90                }
91                Poll::Pending
92            }
93        }
94    }
95}
96
97#[inline]
98pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
99    if error.is_io() {
100        error.into_io().unwrap_or(std::io::Error::other("io error"))
101    } else {
102        std::io::Error::other(error)
103    }
104}
105
106#[inline]
107pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
108    std::io::Error::other(h2::Error::from(reason))
109}
110
111#[inline]
112fn sanitize_response<ResB>(
113    response: &mut Response<ResB>,
114    send_date_header: bool,
115    date_cache: &DateCache,
116) where
117    ResB: Body<Data = bytes::Bytes>,
118{
119    let response_headers = response.headers_mut();
120    if send_date_header {
121        if let Some(http_date) = date_cache.get_date_header_value() {
122            response_headers
123                .entry(http::header::DATE)
124                .or_insert(http_date);
125        }
126    }
127    for header in &HTTP2_INVALID_HEADERS {
128        if let http::header::Entry::Occupied(entry) = response_headers.entry(header) {
129            entry.remove();
130        }
131    }
132    if response_headers
133        .get(http::header::TE)
134        .is_some_and(|v| v != "trailers")
135    {
136        response_headers.remove(http::header::TE);
137    }
138}
139
140struct PendingUpgrade {
141    tx: oneshot::Sender<Upgraded>,
142    upgraded: std::sync::Arc<std::sync::atomic::AtomicBool>,
143    recv_stream: h2::RecvStream,
144}
145
// Task spawned for every accepted request stream. It first waits for the
// service's response (`H2StreamState::Service`). Then, when the response has a
// body, it pipes that body to the peer (`H2StreamState::Body`).
pin_project! {
    struct H2Stream<Fut, ResB>
    where
        Fut: Future,
        ResB: Body<Data = bytes::Bytes>,
    {
        // Send half of the stream: used for interim responses, the final
        // response headers, and for observing resets from the peer.
        stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
        // Current phase of the stream (pinned because it holds the service
        // future and the body pipe).
        #[pin]
        state: H2StreamState<Fut, ResB>,
    }
}
157
// Phases of an `H2Stream`.
pin_project! {
    #[project = H2StreamStateProj]
    enum H2StreamState<Fut, ResB>
    where
        Fut: Future,
        ResB: Body<Data = bytes::Bytes>,
    {
        // Waiting for the service's response. Interim responses
        // (`100 Continue`, `103 Early Hints`) may be sent meanwhile.
        Service {
            // Future returned by the user's request handler.
            #[pin]
            response_fut: Fut,
            // Receives early-hints header maps sent through the request's
            // `EarlyHints` extension.
            early_hints_rx: EarlyHintsReceiver,
            // Per-connection cache of the `Date` header value.
            date_cache: DateCache,
            // Whether a `Date` header should be added to the response.
            send_date_header: bool,
            // Present only for CONNECT requests that may be upgraded.
            upgrade: Option<PendingUpgrade>,
            // Request carried `Expect: 100-continue` and the option is enabled.
            send_continue: bool,
            // Cleared once the early-hints channel is closed.
            early_hints_open: bool,
            // Raised by `H2Body` when the consumer polls for body data that
            // has not arrived yet.
            send_continue_body: Option<Arc<AtomicBool>>,
            // Whether the `100 Continue` decision has already been made
            // (sent or deliberately skipped).
            continue_sent: bool
        },
        // Response headers have been sent; the body is being streamed.
        Body {
            #[pin]
            pipe: PipeToSendStream<ResB>,
        },
    }
}
183
184impl<Fut, ResB> H2Stream<Fut, ResB>
185where
186    Fut: Future,
187    ResB: Body<Data = bytes::Bytes>,
188{
189    #[allow(clippy::too_many_arguments)]
190    #[inline]
191    const fn new(
192        stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
193        response_fut: Fut,
194        early_hints_rx: EarlyHintsReceiver,
195        date_cache: DateCache,
196        send_date_header: bool,
197        upgrade: Option<PendingUpgrade>,
198        send_continue: bool,
199        send_continue_body: Option<Arc<AtomicBool>>,
200    ) -> Self {
201        Self {
202            stream,
203            state: H2StreamState::Service {
204                response_fut,
205                early_hints_rx,
206                date_cache,
207                send_date_header,
208                upgrade,
209                send_continue,
210                early_hints_open: true,
211                send_continue_body,
212                continue_sent: false,
213            },
214        }
215    }
216}
217
218impl<Fut, ResB, ResBE, ResE> Future for H2Stream<Fut, ResB>
219where
220    Fut: Future<Output = Result<Response<ResB>, ResE>>,
221    ResB: Body<Data = bytes::Bytes, Error = ResBE>,
222    ResE: std::error::Error,
223    ResBE: std::error::Error,
224{
225    type Output = ();
226
227    #[inline]
228    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229        let mut this = self.project();
230
231        loop {
232            match this.state.as_mut().project() {
233                H2StreamStateProj::Service {
234                    response_fut,
235                    early_hints_rx,
236                    date_cache,
237                    send_date_header,
238                    upgrade,
239                    send_continue,
240                    early_hints_open,
241                    send_continue_body,
242                    continue_sent,
243                } => {
244                    if let Poll::Ready(response_result) = response_fut.poll(cx) {
245                        let Ok(mut response) = response_result else {
246                            return Poll::Ready(());
247                        };
248
249                        sanitize_response(&mut response, *send_date_header, date_cache);
250
251                        let response_is_end_stream = response.body().is_end_stream();
252                        if !response_is_end_stream {
253                            if let Some(content_length) = response.body().size_hint().exact() {
254                                if !response
255                                    .headers()
256                                    .contains_key(http::header::CONTENT_LENGTH)
257                                {
258                                    response.headers_mut().insert(
259                                        http::header::CONTENT_LENGTH,
260                                        content_length.into(),
261                                    );
262                                }
263                            }
264                        }
265
266                        if *send_continue && !*continue_sent {
267                            if !response.status().is_client_error()
268                                && !response.status().is_server_error()
269                            {
270                                let mut response = Response::new(());
271                                *response.status_mut() = http::StatusCode::CONTINUE;
272                                let _ = this
273                                    .stream
274                                    .send_informational(response)
275                                    .map_err(h2_error_to_io);
276                            }
277                            *continue_sent = true;
278                        }
279
280                        let (response_parts, response_body) = response.into_parts();
281                        let Ok(send) = this.stream.send_response(
282                            Response::from_parts(response_parts, ()),
283                            response_is_end_stream && upgrade.is_none(),
284                        ) else {
285                            return Poll::Ready(());
286                        };
287
288                        if let Some(PendingUpgrade {
289                            tx,
290                            upgraded,
291                            recv_stream,
292                        }) = upgrade.take()
293                        {
294                            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
295                                let (upgraded, task) = self::upgrade::pair(send, recv_stream);
296                                let _ = tx.send(Upgraded::new(upgraded, None));
297                                vibeio::spawn(task);
298                                return Poll::Ready(());
299                            }
300                        }
301
302                        if response_is_end_stream {
303                            return Poll::Ready(());
304                        }
305
306                        this.state.set(H2StreamState::Body {
307                            pipe: PipeToSendStream::new(send, response_body),
308                        });
309                        continue;
310                    }
311
312                    match this.stream.poll_reset(cx) {
313                        Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => return Poll::Ready(()),
314                        Poll::Pending => {}
315                    }
316
317                    if *send_continue
318                        && !*continue_sent
319                        && send_continue_body
320                            .as_ref()
321                            .is_some_and(|scb| scb.load(std::sync::atomic::Ordering::Relaxed))
322                    {
323                        let mut response = Response::new(());
324                        *response.status_mut() = http::StatusCode::CONTINUE;
325                        let _ = this
326                            .stream
327                            .send_informational(response)
328                            .map_err(h2_error_to_io);
329                        *continue_sent = true;
330                    }
331
332                    if *early_hints_open {
333                        match early_hints_rx.poll_recv(cx) {
334                            Poll::Ready(Some((headers, sender))) => {
335                                let mut response = Response::new(());
336                                *response.status_mut() = http::StatusCode::EARLY_HINTS;
337                                *response.headers_mut() = headers;
338                                sender
339                                    .into_inner()
340                                    .send(
341                                        this.stream
342                                            .send_informational(response)
343                                            .map_err(h2_error_to_io),
344                                    )
345                                    .ok();
346                                continue;
347                            }
348                            Poll::Ready(None) => {
349                                *early_hints_open = false;
350                                continue;
351                            }
352                            Poll::Pending => {}
353                        }
354                    }
355
356                    return Poll::Pending;
357                }
358                H2StreamStateProj::Body { pipe } => {
359                    return pipe.poll(cx).map(|_| ());
360                }
361            }
362        }
363    }
364}
365
/// An HTTP/2 connection handler.
///
/// `Http2` wraps an async I/O stream (`Io`) and drives the HTTP/2 server
/// connection using the [`h2`] crate. It supports:
///
/// - Concurrent request stream handling (each stream runs in its own task)
/// - Streaming request/response bodies and trailers
/// - Optional automatic `100 Continue` interim responses, and `103 Early Hints`
///   sent through the [`EarlyHints`] request extension
/// - Per-connection `Date` header caching
/// - Graceful shutdown via a [`CancellationToken`]
///
/// # Construction
///
/// ```rust,ignore
/// let http2 = Http2::new(tcp_stream, Http2Options::default());
/// ```
///
/// # Serving requests
///
/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
/// connection to completion.
pub struct Http2<Io> {
    /// The I/O stream, held until `handle` takes it for the HTTP/2 handshake.
    io_to_handshake: Option<Io>,
    /// Cached `Date` header value; a clone is handed to every stream task.
    date_header_value_cached: DateCache,
    /// h2 builder settings, handshake/accept timeouts and feature flags.
    options: Http2Options,
    /// When set, cancelling this token starts a graceful shutdown.
    cancel_token: Option<CancellationToken>,
}
394
395impl<Io> Http2<Io>
396where
397    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
398{
399    /// Creates a new `Http2` connection handler wrapping the given I/O stream.
400    ///
401    /// The `options` value controls HTTP/2 protocol configuration, handshake
402    /// and accept timeouts, and optional behaviour such as automatic
403    /// `100 Continue` responses; see [`Http2Options`] for details.
404    ///
405    /// # Example
406    ///
407    /// ```rust,ignore
408    /// let http2 = Http2::new(tcp_stream, Http2Options::default());
409    /// ```
410    #[inline]
411    pub fn new(io: Io, options: Http2Options) -> Self {
412        Self {
413            io_to_handshake: Some(io),
414            date_header_value_cached: DateCache::default(),
415            options,
416            cancel_token: None,
417        }
418    }
419
420    /// Attaches a [`CancellationToken`] for graceful shutdown.
421    ///
422    /// When the token is cancelled, the handler sends HTTP/2 graceful shutdown
423    /// signals (GOAWAY), stops accepting new streams, and exits cleanly.
424    #[inline]
425    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
426        self.cancel_token = Some(token);
427        self
428    }
429}
430
431impl<Io> HttpProtocol for Http2<Io>
432where
433    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
434{
435    #[allow(clippy::manual_async_fn)]
436    #[inline]
437    fn handle<F, Fut, ResB, ResBE, ResE>(
438        mut self,
439        request_fn: F,
440    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
441    where
442        F: Fn(Request<super::Incoming>) -> Fut + 'static,
443        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
444        ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
445        ResE: std::error::Error,
446        ResBE: std::error::Error,
447    {
448        async move {
449            let handshake_fut = self.options.h2.handshake(
450                self.io_to_handshake
451                    .take()
452                    .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
453            );
454            let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
455                vibeio::time::timeout(timeout, handshake_fut).await
456            } else {
457                Ok(handshake_fut.await)
458            })
459            .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
460            .map_err(|e| {
461                if e.is_io() {
462                    e.into_io().unwrap_or(std::io::Error::other("io error"))
463                } else {
464                    std::io::Error::other(e)
465                }
466            })?;
467
468            while let Some(request) = {
469                let res = {
470                    let accept_fut_orig = h2.accept();
471                    let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
472                    let cancel_token = self.cancel_token.clone();
473                    let cancel_fut = async move {
474                        if let Some(token) = cancel_token {
475                            token.cancelled().await
476                        } else {
477                            futures_util::future::pending().await
478                        }
479                    };
480                    let cancel_fut_pin = std::pin::pin!(cancel_fut);
481                    let accept_fut =
482                        futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
483
484                    match if let Some(timeout) = self.options.accept_timeout {
485                        vibeio::time::timeout(timeout, accept_fut).await
486                    } else {
487                        Ok(accept_fut.await)
488                    } {
489                        Ok(futures_util::future::Either::Right((request, _))) => {
490                            (Some(request), false)
491                        }
492                        Ok(futures_util::future::Either::Left((_, _))) => {
493                            // Canceled
494                            (None, true)
495                        }
496                        Err(_) => {
497                            // Timeout
498                            (None, false)
499                        }
500                    }
501                };
502                match res {
503                    (Some(request), _) => request,
504                    (None, graceful) => {
505                        h2.graceful_shutdown();
506                        let _ = h2.accept().await;
507                        if graceful {
508                            return Ok(());
509                        }
510                        return Err(std::io::Error::new(
511                            std::io::ErrorKind::TimedOut,
512                            "accept timeout",
513                        ));
514                    }
515                }
516            } {
517                let (request, stream) = match request {
518                    Ok(d) => d,
519                    Err(e) if e.is_go_away() => {
520                        continue;
521                    }
522                    Err(e) if e.is_io() => {
523                        return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
524                    }
525                    Err(e) => {
526                        return Err(std::io::Error::other(e));
527                    }
528                };
529
530                // 100 Continue
531                let is_100_continue = self.options.send_continue_response
532                    && request
533                        .headers()
534                        .get(http::header::EXPECT)
535                        .and_then(|v| v.to_str().ok())
536                        .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
537
538                let date_cache = self.date_header_value_cached.clone();
539                let send_continue_body = is_100_continue.then(|| Arc::new(AtomicBool::new(false)));
540                let (request_parts, recv_stream) = request.into_parts();
541                let (request_body, upgrade) = if request_parts.method == http::Method::CONNECT {
542                    (Incoming::Empty, Some(recv_stream))
543                } else {
544                    (
545                        Incoming::H2(H2Body::new(recv_stream, send_continue_body.clone())),
546                        None,
547                    )
548                };
549                let mut request = Request::from_parts(request_parts, request_body);
550
551                // Install early hints
552                let (early_hints, early_hints_rx) = EarlyHints::new_lazy();
553                request.extensions_mut().insert(early_hints);
554
555                // Install HTTP upgrade
556                let upgrade = if let Some(recv_stream) = upgrade {
557                    let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
558                    let upgrade = Upgrade::new(upgrade_rx);
559                    let upgraded = upgrade.upgraded.clone();
560                    request.extensions_mut().insert(upgrade);
561                    Some(PendingUpgrade {
562                        tx: upgrade_tx,
563                        upgraded,
564                        recv_stream,
565                    })
566                } else {
567                    None
568                };
569
570                vibeio::spawn(H2Stream::new(
571                    stream,
572                    request_fn(request),
573                    early_hints_rx,
574                    date_cache,
575                    self.options.send_date_header,
576                    upgrade,
577                    is_100_continue,
578                    send_continue_body,
579                ));
580            }
581
582            Ok(())
583        }
584    }
585}