1mod date;
2mod options;
3
4pub use options::*;
5use tokio_util::sync::CancellationToken;
6
7use std::{
8 future::Future,
9 pin::Pin,
10 rc::Rc,
11 task::{Context, Poll},
12};
13
14use bytes::Bytes;
15use http::{Request, Response};
16use http_body::{Body, Frame};
17use http_body_util::BodyExt;
18
19use crate::{h2::date::DateCache, EarlyHints, HttpProtocol, Incoming};
20
21static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
22 http::header::HeaderName::from_static("keep-alive"),
23 http::header::HeaderName::from_static("proxy-connection"),
24 http::header::TRANSFER_ENCODING,
25 http::header::TE,
26 http::header::UPGRADE,
27];
28
29struct H2Body {
30 recv: h2::RecvStream,
31 data_done: bool,
32}
33
34impl H2Body {
35 #[inline]
36 fn new(recv: h2::RecvStream) -> Self {
37 Self {
38 recv,
39 data_done: false,
40 }
41 }
42}
43
44impl Body for H2Body {
45 type Data = Bytes;
46 type Error = std::io::Error;
47
48 #[inline]
49 fn poll_frame(
50 mut self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
53 if !self.data_done {
54 match self.recv.poll_data(cx) {
55 Poll::Ready(Some(Ok(data))) => return Poll::Ready(Some(Ok(Frame::data(data)))),
56 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
57 Poll::Ready(None) => self.data_done = true,
58 Poll::Pending => return Poll::Pending,
59 }
60 }
61
62 match self.recv.poll_trailers(cx) {
63 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
64 Poll::Ready(Ok(None)) => Poll::Ready(None),
65 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
66 Poll::Pending => Poll::Pending,
67 }
68 }
69}
70
71#[inline]
72fn h2_error_to_io(error: h2::Error) -> std::io::Error {
73 if error.is_io() {
74 error.into_io().unwrap_or(std::io::Error::other("io error"))
75 } else {
76 std::io::Error::other(error)
77 }
78}
79
80#[inline]
81fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
82 std::io::Error::other(h2::Error::from(reason))
83}
84
85async fn wait_for_send_capacity(send: &mut h2::SendStream<Bytes>) -> Result<(), std::io::Error> {
86 send.reserve_capacity(1);
87
88 if send.capacity() == 0 {
89 std::future::poll_fn(|cx| loop {
90 match send.poll_capacity(cx) {
91 Poll::Ready(Some(Ok(0))) => {}
92 Poll::Ready(Some(Ok(_))) => return Poll::Ready(Ok(())),
93 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(h2_error_to_io(err))),
94 Poll::Ready(None) => {
95 return Poll::Ready(Err(std::io::Error::other(
96 "send stream capacity unexpectedly closed",
97 )))
98 }
99 Poll::Pending => return Poll::Pending,
100 }
101 })
102 .await
103 } else {
104 let reset_reason = std::future::poll_fn(|cx| match send.poll_reset(cx) {
105 Poll::Ready(Ok(reason)) => Poll::Ready(Ok(Some(reason))),
106 Poll::Ready(Err(err)) => Poll::Ready(Err(h2_error_to_io(err))),
107 Poll::Pending => Poll::Ready(Ok(None)),
108 })
109 .await?;
110 if let Some(reason) = reset_reason {
111 Err(h2_reason_to_io(reason))
112 } else {
113 Ok(())
114 }
115 }
116}
117
118pub struct Http2<Io> {
141 io_to_handshake: Option<Io>,
142 date_header_value_cached: DateCache,
143 options: Http2Options,
144 cancel_token: Option<CancellationToken>,
145}
146
147impl<Io> Http2<Io>
148where
149 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
150{
151 #[inline]
163 pub fn new(io: Io, options: Http2Options) -> Self {
164 Self {
165 io_to_handshake: Some(io),
166 date_header_value_cached: DateCache::default(),
167 options,
168 cancel_token: None,
169 }
170 }
171
172 #[inline]
177 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
178 self.cancel_token = Some(token);
179 self
180 }
181}
182
183impl<Io> HttpProtocol for Http2<Io>
184where
185 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
186{
187 #[allow(clippy::manual_async_fn)]
188 #[inline]
189 fn handle<F, Fut, ResB, ResBE, ResE>(
190 mut self,
191 request_fn: F,
192 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
193 where
194 F: Fn(Request<super::Incoming>) -> Fut + 'static,
195 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
196 ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
197 ResE: std::error::Error,
198 ResBE: std::error::Error,
199 {
200 async move {
201 let request_fn = Rc::new(request_fn);
202 let handshake_fut = self.options.h2.handshake(
203 self.io_to_handshake
204 .take()
205 .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
206 );
207 let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
208 vibeio::time::timeout(timeout, handshake_fut).await
209 } else {
210 Ok(handshake_fut.await)
211 })
212 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
213 .map_err(|e| {
214 if e.is_io() {
215 e.into_io().unwrap_or(std::io::Error::other("io error"))
216 } else {
217 std::io::Error::other(e)
218 }
219 })?;
220
221 while let Some(request) = {
222 let res = {
223 let accept_fut_orig = h2.accept();
224 let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
225 let cancel_token = self.cancel_token.clone();
226 let cancel_fut = async move {
227 if let Some(token) = cancel_token {
228 token.cancelled().await
229 } else {
230 futures_util::future::pending().await
231 }
232 };
233 let cancel_fut_pin = std::pin::pin!(cancel_fut);
234 let accept_fut =
235 futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
236
237 match if let Some(timeout) = self.options.accept_timeout {
238 vibeio::time::timeout(timeout, accept_fut).await
239 } else {
240 Ok(accept_fut.await)
241 } {
242 Ok(futures_util::future::Either::Right((request, _))) => {
243 (Some(request), false)
244 }
245 Ok(futures_util::future::Either::Left((_, _))) => {
246 (None, true)
248 }
249 Err(_) => {
250 (None, false)
252 }
253 }
254 };
255 match res {
256 (Some(request), _) => request,
257 (None, graceful) => {
258 h2.graceful_shutdown();
259 let _ = h2.accept().await;
260 if graceful {
261 return Ok(());
262 }
263 return Err(std::io::Error::new(
264 std::io::ErrorKind::TimedOut,
265 "accept timeout",
266 ));
267 }
268 }
269 } {
270 let (request, mut stream) = match request {
271 Ok(d) => d,
272 Err(e) if e.is_go_away() => {
273 continue;
274 }
275 Err(e) if e.is_io() => {
276 return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
277 }
278 Err(e) => {
279 return Err(std::io::Error::other(e));
280 }
281 };
282
283 let date_cache = self.date_header_value_cached.clone();
284 let request_fn = request_fn.clone();
285 let send_continue_response = self.options.send_continue_response;
286 vibeio::spawn(async move {
287 let (request_parts, recv_stream) = request.into_parts();
288 let request_body = Incoming::new(H2Body::new(recv_stream));
289 let mut request = Request::from_parts(request_parts, request_body);
290
291 if send_continue_response {
293 let is_100_continue = request
294 .headers()
295 .get(http::header::EXPECT)
296 .and_then(|v| v.to_str().ok())
297 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
298 if is_100_continue {
299 let mut response = Response::new(());
300 *response.status_mut() = http::StatusCode::CONTINUE;
301 let _ = stream.send_informational(response).map_err(h2_error_to_io);
302 }
303 }
304
305 let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
306 let early_hints = EarlyHints::new(early_hints_tx);
307 request.extensions_mut().insert(early_hints);
308
309 let mut response_fut = std::pin::pin!(request_fn(request));
310 let early_hints_rx = early_hints_rx;
311 let response_result = loop {
312 let early_hints_recv_fut = early_hints_rx.recv();
313 let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
314 let next = std::future::poll_fn(|cx| {
315 match stream.poll_reset(cx) {
316 Poll::Ready(Ok(reason)) => {
317 return Poll::Ready(Err(h2_reason_to_io(reason)));
318 }
319 Poll::Ready(Err(err)) => {
320 return Poll::Ready(Err(h2_error_to_io(err)));
321 }
322 Poll::Pending => {}
323 }
324
325 if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
326 return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
327 }
328
329 match early_hints_recv_fut.as_mut().poll(cx) {
330 Poll::Ready(Ok(msg)) => {
331 Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
332 }
333 Poll::Ready(Err(_)) => Poll::Pending,
334 Poll::Pending => Poll::Pending,
335 }
336 })
337 .await;
338
339 match next {
340 Ok(futures_util::future::Either::Left(response_result)) => {
341 break response_result;
342 }
343 Ok(futures_util::future::Either::Right((headers, sender))) => {
344 let mut response = Response::new(());
345 *response.status_mut() = http::StatusCode::EARLY_HINTS;
346 *response.headers_mut() = headers;
347 sender
348 .into_inner()
349 .send(
350 stream.send_informational(response).map_err(h2_error_to_io),
351 )
352 .ok();
353 }
354 Err(_) => {
355 return;
356 }
357 }
358 };
359 let Ok(mut response) = response_result else {
360 return;
362 };
363
364 {
365 let response_headers = response.headers_mut();
366 if let Some(http_date) = date_cache.get_date_header_value() {
367 response_headers
368 .entry(http::header::DATE)
369 .or_insert(http_date);
370 }
371 if let Some(connection_header) = response_headers
372 .remove(http::header::CONNECTION)
373 .as_ref()
374 .and_then(|v| v.to_str().ok())
375 {
376 for name in connection_header.split(',') {
377 response_headers.remove(name.trim());
378 }
379 }
380 while response_headers.remove(http::header::CONNECTION).is_some() {}
381 for header in &HTTP2_INVALID_HEADERS {
382 while response_headers.remove(header).is_some() {}
383 }
384 }
385
386 let response_is_end_stream = response.body().is_end_stream();
387 if !response_is_end_stream {
388 if let Some(content_length) = response.body().size_hint().exact() {
389 if !response
390 .headers()
391 .contains_key(http::header::CONTENT_LENGTH)
392 {
393 response
394 .headers_mut()
395 .insert(http::header::CONTENT_LENGTH, content_length.into());
396 }
397 }
398 }
399
400 let (response_parts, mut response_body) = response.into_parts();
401 let mut send = match stream.send_response(
402 Response::from_parts(response_parts, ()),
403 response_is_end_stream,
404 ) {
405 Ok(send) => send,
406 Err(_) => {
407 return;
408 }
409 };
410
411 if response_is_end_stream {
412 return;
413 }
414
415 while let Some(chunk) = {
416 let frame_fut = response_body.frame();
417 let mut frame_fut = std::pin::pin!(frame_fut);
418 match std::future::poll_fn(|cx| {
419 match send.poll_reset(cx) {
420 Poll::Ready(Ok(reason)) => {
421 return Poll::Ready(Err(h2_reason_to_io(reason)));
422 }
423 Poll::Ready(Err(err)) => {
424 return Poll::Ready(Err(h2_error_to_io(err)));
425 }
426 Poll::Pending => {}
427 }
428
429 match frame_fut.as_mut().poll(cx) {
430 Poll::Ready(frame) => Poll::Ready(Ok(frame)),
431 Poll::Pending => Poll::Pending,
432 }
433 })
434 .await
435 {
436 Ok(frame) => frame,
437 Err(_) => {
438 return;
439 }
440 }
441 } {
442 match chunk {
443 Ok(frame) => {
444 if frame.is_data() {
445 match frame.into_data() {
446 Ok(data) => {
447 if wait_for_send_capacity(&mut send).await.is_err() {
448 return;
449 }
450 let is_end_stream = response_body.is_end_stream();
451 if send.send_data(data, is_end_stream).is_err() {
452 return;
453 }
454 if is_end_stream {
455 return;
456 }
457 }
458 Err(_) => {
459 return;
460 }
461 }
462 } else if frame.is_trailers() {
463 match frame.into_trailers() {
464 Ok(trailers) => {
465 if send.send_trailers(trailers).is_err() {
466 return;
467 }
468 return;
469 }
470 Err(_) => {
471 return;
472 }
473 }
474 }
475 }
476 Err(_) => {
477 return;
478 }
479 }
480 }
481 let _ = send.send_data(Bytes::new(), true);
482 });
483 }
484
485 Ok(())
486 }
487 }
488}