1mod date;
4mod options;
5mod send;
6mod upgrade;
7
8pub use options::*;
9use pin_project_lite::pin_project;
10use tokio_util::sync::CancellationToken;
11
12use std::{
13 future::Future,
14 pin::Pin,
15 task::{Context, Poll},
16};
17
18use bytes::Bytes;
19use http::{Request, Response};
20use http_body::{Body, Frame};
21
22use crate::{
23 early_hints::EarlyHintsReceiver,
24 h2::{
25 date::DateCache,
26 send::{PipeToSendStream, SendBuf},
27 },
28 EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded,
29};
30
31static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
32 http::header::HeaderName::from_static("keep-alive"),
33 http::header::HeaderName::from_static("proxy-connection"),
34 http::header::CONNECTION,
35 http::header::TRANSFER_ENCODING,
36 http::header::UPGRADE,
37];
38
/// Request body backed by the receive half of an h2 stream.
///
/// DATA frames are yielded first; once the peer signals the end of data,
/// `data_done` is set and the trailers (if any) are polled instead.
pub(crate) struct H2Body {
    /// Receive half of the HTTP/2 stream carrying the request body.
    recv: h2::RecvStream,
    /// Set once `poll_data` has returned `None`, so later polls go to the trailers.
    data_done: bool,
}
43
44impl H2Body {
45 #[inline]
46 fn new(recv: h2::RecvStream) -> Self {
47 Self {
48 recv,
49 data_done: false,
50 }
51 }
52}
53
54impl Body for H2Body {
55 type Data = Bytes;
56 type Error = std::io::Error;
57
58 #[inline]
59 fn poll_frame(
60 mut self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
63 if !self.data_done {
64 match self.recv.poll_data(cx) {
65 Poll::Ready(Some(Ok(data))) => {
66 let _ = self.recv.flow_control().release_capacity(data.len());
67 return Poll::Ready(Some(Ok(Frame::data(data))));
68 }
69 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
70 Poll::Ready(None) => self.data_done = true,
71 Poll::Pending => return Poll::Pending,
72 }
73 }
74
75 match self.recv.poll_trailers(cx) {
76 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
77 Poll::Ready(Ok(None)) => Poll::Ready(None),
78 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
79 Poll::Pending => Poll::Pending,
80 }
81 }
82}
83
84#[inline]
85pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
86 if error.is_io() {
87 error.into_io().unwrap_or(std::io::Error::other("io error"))
88 } else {
89 std::io::Error::other(error)
90 }
91}
92
93#[inline]
94pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
95 std::io::Error::other(h2::Error::from(reason))
96}
97
98#[inline]
99fn sanitize_response<ResB>(
100 response: &mut Response<ResB>,
101 send_date_header: bool,
102 date_cache: &DateCache,
103) where
104 ResB: Body<Data = bytes::Bytes>,
105{
106 let response_headers = response.headers_mut();
107 if send_date_header {
108 if let Some(http_date) = date_cache.get_date_header_value() {
109 response_headers
110 .entry(http::header::DATE)
111 .or_insert(http_date);
112 }
113 }
114 for header in &HTTP2_INVALID_HEADERS {
115 if let http::header::Entry::Occupied(entry) = response_headers.entry(header) {
116 entry.remove();
117 }
118 }
119 if response_headers
120 .get(http::header::TE)
121 .is_some_and(|v| v != "trailers")
122 {
123 response_headers.remove(http::header::TE);
124 }
125}
126
/// State kept for a CONNECT request until its response has been produced and it is
/// known whether the stream is being upgraded.
struct PendingUpgrade {
    /// Delivers the `Upgraded` I/O handle to the `Upgrade` extension given to the service.
    tx: oneshot::Sender<Upgraded>,
    /// Flag shared with that `Upgrade` extension (cloned from `upgrade.upgraded`);
    /// presumably set when the service requests the upgrade — TODO confirm in `Upgrade`.
    upgraded: std::sync::Arc<std::sync::atomic::AtomicBool>,
    /// Receive half of the stream, withheld from the request body so it can be paired
    /// with the send half to build the upgraded I/O.
    recv_stream: h2::RecvStream,
}
132
// Per-stream server task: owns the stream's response handle and advances through the
// `H2StreamState` machine (waiting for the service, then piping the response body).
pin_project! {
    struct H2Stream<Fut, ResB>
    where
        Fut: Future,
        ResB: Body<Data = bytes::Bytes>,
    {
        // Response half of the h2 stream: used to send informational (1xx) and final
        // responses and to watch for a peer reset via `poll_reset`.
        stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
        // Current phase; pinned because it holds the service future or the body pipe.
        #[pin]
        state: H2StreamState<Fut, ResB>,
    }
}
144
// Phases of an `H2Stream`.
pin_project! {
    #[project = H2StreamStateProj]
    enum H2StreamState<Fut, ResB>
    where
        Fut: Future,
        ResB: Body<Data = bytes::Bytes>,
    {
        // Waiting for the service's response future. Meanwhile a 100 Continue may be
        // sent, early hints (103) are forwarded and peer resets are watched.
        Service {
            #[pin]
            response_fut: Fut,
            // Yields early-hint header maps together with a reply sender.
            early_hints_rx: EarlyHintsReceiver,
            date_cache: DateCache,
            send_date_header: bool,
            // `Some` only for CONNECT requests.
            upgrade: Option<PendingUpgrade>,
            // Whether a `100 Continue` still has to be sent; cleared once attempted.
            send_continue: bool,
            // Cleared once the early-hints channel reports that it is closed.
            early_hints_open: bool,
        },
        // Response headers are sent; the body is being streamed to the peer.
        Body {
            #[pin]
            pipe: PipeToSendStream<ResB>,
        },
    }
}
168
169impl<Fut, ResB> H2Stream<Fut, ResB>
170where
171 Fut: Future,
172 ResB: Body<Data = bytes::Bytes>,
173{
174 #[inline]
175 const fn new(
176 stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
177 response_fut: Fut,
178 early_hints_rx: EarlyHintsReceiver,
179 date_cache: DateCache,
180 send_date_header: bool,
181 upgrade: Option<PendingUpgrade>,
182 send_continue: bool,
183 ) -> Self {
184 Self {
185 stream,
186 state: H2StreamState::Service {
187 response_fut,
188 early_hints_rx,
189 date_cache,
190 send_date_header,
191 upgrade,
192 send_continue,
193 early_hints_open: true,
194 },
195 }
196 }
197}
198
199impl<Fut, ResB, ResBE, ResE> Future for H2Stream<Fut, ResB>
200where
201 Fut: Future<Output = Result<Response<ResB>, ResE>>,
202 ResB: Body<Data = bytes::Bytes, Error = ResBE>,
203 ResE: std::error::Error,
204 ResBE: std::error::Error,
205{
206 type Output = ();
207
208 #[inline]
209 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210 let mut this = self.project();
211
212 loop {
213 match this.state.as_mut().project() {
214 H2StreamStateProj::Service {
215 response_fut,
216 early_hints_rx,
217 date_cache,
218 send_date_header,
219 upgrade,
220 send_continue,
221 early_hints_open,
222 } => {
223 if *send_continue {
224 let mut response = Response::new(());
225 *response.status_mut() = http::StatusCode::CONTINUE;
226 let _ = this
227 .stream
228 .send_informational(response)
229 .map_err(h2_error_to_io);
230 *send_continue = false;
231 }
232
233 if let Poll::Ready(response_result) = response_fut.poll(cx) {
234 let Ok(mut response) = response_result else {
235 return Poll::Ready(());
236 };
237
238 sanitize_response(&mut response, *send_date_header, date_cache);
239
240 let response_is_end_stream = response.body().is_end_stream();
241 if !response_is_end_stream {
242 if let Some(content_length) = response.body().size_hint().exact() {
243 if !response
244 .headers()
245 .contains_key(http::header::CONTENT_LENGTH)
246 {
247 response.headers_mut().insert(
248 http::header::CONTENT_LENGTH,
249 content_length.into(),
250 );
251 }
252 }
253 }
254
255 let (response_parts, response_body) = response.into_parts();
256 let Ok(send) = this.stream.send_response(
257 Response::from_parts(response_parts, ()),
258 response_is_end_stream && upgrade.is_none(),
259 ) else {
260 return Poll::Ready(());
261 };
262
263 if let Some(PendingUpgrade {
264 tx,
265 upgraded,
266 recv_stream,
267 }) = upgrade.take()
268 {
269 if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
270 let (upgraded, task) = self::upgrade::pair(send, recv_stream);
271 let _ = tx.send(Upgraded::new(upgraded, None));
272 vibeio::spawn(task);
273 return Poll::Ready(());
274 }
275 }
276
277 if response_is_end_stream {
278 return Poll::Ready(());
279 }
280
281 this.state.set(H2StreamState::Body {
282 pipe: PipeToSendStream::new(send, response_body),
283 });
284 continue;
285 }
286
287 match this.stream.poll_reset(cx) {
288 Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => return Poll::Ready(()),
289 Poll::Pending => {}
290 }
291
292 if *early_hints_open {
293 match early_hints_rx.poll_recv(cx) {
294 Poll::Ready(Some((headers, sender))) => {
295 let mut response = Response::new(());
296 *response.status_mut() = http::StatusCode::EARLY_HINTS;
297 *response.headers_mut() = headers;
298 sender
299 .into_inner()
300 .send(
301 this.stream
302 .send_informational(response)
303 .map_err(h2_error_to_io),
304 )
305 .ok();
306 continue;
307 }
308 Poll::Ready(None) => {
309 *early_hints_open = false;
310 continue;
311 }
312 Poll::Pending => {}
313 }
314 }
315
316 return Poll::Pending;
317 }
318 H2StreamStateProj::Body { pipe } => {
319 return pipe.poll(cx).map(|_| ());
320 }
321 }
322 }
323 }
324}
325
/// HTTP/2 server-side connection handler over the transport `Io`.
pub struct Http2<Io> {
    /// Transport for the h2 handshake; taken (left `None`) when `handle` starts.
    io_to_handshake: Option<Io>,
    /// Date-header cache; cloned into every spawned stream task.
    date_header_value_cached: DateCache,
    /// Connection options (h2 builder, handshake/accept timeouts, date header,
    /// 100-continue handling).
    options: Http2Options,
    /// Optional token; once cancelled, the connection is shut down gracefully.
    cancel_token: Option<CancellationToken>,
}
354
355impl<Io> Http2<Io>
356where
357 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
358{
359 #[inline]
371 pub fn new(io: Io, options: Http2Options) -> Self {
372 Self {
373 io_to_handshake: Some(io),
374 date_header_value_cached: DateCache::default(),
375 options,
376 cancel_token: None,
377 }
378 }
379
380 #[inline]
385 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
386 self.cancel_token = Some(token);
387 self
388 }
389}
390
391impl<Io> HttpProtocol for Http2<Io>
392where
393 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
394{
395 #[allow(clippy::manual_async_fn)]
396 #[inline]
397 fn handle<F, Fut, ResB, ResBE, ResE>(
398 mut self,
399 request_fn: F,
400 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
401 where
402 F: Fn(Request<super::Incoming>) -> Fut + 'static,
403 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
404 ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
405 ResE: std::error::Error,
406 ResBE: std::error::Error,
407 {
408 async move {
409 let handshake_fut = self.options.h2.handshake(
410 self.io_to_handshake
411 .take()
412 .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
413 );
414 let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
415 vibeio::time::timeout(timeout, handshake_fut).await
416 } else {
417 Ok(handshake_fut.await)
418 })
419 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
420 .map_err(|e| {
421 if e.is_io() {
422 e.into_io().unwrap_or(std::io::Error::other("io error"))
423 } else {
424 std::io::Error::other(e)
425 }
426 })?;
427
428 while let Some(request) = {
429 let res = {
430 let accept_fut_orig = h2.accept();
431 let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
432 let cancel_token = self.cancel_token.clone();
433 let cancel_fut = async move {
434 if let Some(token) = cancel_token {
435 token.cancelled().await
436 } else {
437 futures_util::future::pending().await
438 }
439 };
440 let cancel_fut_pin = std::pin::pin!(cancel_fut);
441 let accept_fut =
442 futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
443
444 match if let Some(timeout) = self.options.accept_timeout {
445 vibeio::time::timeout(timeout, accept_fut).await
446 } else {
447 Ok(accept_fut.await)
448 } {
449 Ok(futures_util::future::Either::Right((request, _))) => {
450 (Some(request), false)
451 }
452 Ok(futures_util::future::Either::Left((_, _))) => {
453 (None, true)
455 }
456 Err(_) => {
457 (None, false)
459 }
460 }
461 };
462 match res {
463 (Some(request), _) => request,
464 (None, graceful) => {
465 h2.graceful_shutdown();
466 let _ = h2.accept().await;
467 if graceful {
468 return Ok(());
469 }
470 return Err(std::io::Error::new(
471 std::io::ErrorKind::TimedOut,
472 "accept timeout",
473 ));
474 }
475 }
476 } {
477 let (request, stream) = match request {
478 Ok(d) => d,
479 Err(e) if e.is_go_away() => {
480 continue;
481 }
482 Err(e) if e.is_io() => {
483 return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
484 }
485 Err(e) => {
486 return Err(std::io::Error::other(e));
487 }
488 };
489
490 let date_cache = self.date_header_value_cached.clone();
491 let (request_parts, recv_stream) = request.into_parts();
492 let (request_body, upgrade) = if request_parts.method == http::Method::CONNECT {
493 (Incoming::Empty, Some(recv_stream))
494 } else {
495 (Incoming::H2(H2Body::new(recv_stream)), None)
496 };
497 let mut request = Request::from_parts(request_parts, request_body);
498
499 let is_100_continue = self.options.send_continue_response
501 && request
502 .headers()
503 .get(http::header::EXPECT)
504 .and_then(|v| v.to_str().ok())
505 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
506
507 let (early_hints, early_hints_rx) = EarlyHints::new_lazy();
509 request.extensions_mut().insert(early_hints);
510
511 let upgrade = if let Some(recv_stream) = upgrade {
513 let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
514 let upgrade = Upgrade::new(upgrade_rx);
515 let upgraded = upgrade.upgraded.clone();
516 request.extensions_mut().insert(upgrade);
517 Some(PendingUpgrade {
518 tx: upgrade_tx,
519 upgraded,
520 recv_stream,
521 })
522 } else {
523 None
524 };
525
526 vibeio::spawn(H2Stream::new(
527 stream,
528 request_fn(request),
529 early_hints_rx,
530 date_cache,
531 self.options.send_date_header,
532 upgrade,
533 is_100_continue,
534 ));
535 }
536
537 Ok(())
538 }
539 }
540}