1mod date;
2mod options;
3mod send;
4
5pub use options::*;
6use tokio_util::sync::CancellationToken;
7
8use std::{
9 future::Future,
10 pin::Pin,
11 rc::Rc,
12 task::{Context, Poll},
13};
14
15use bytes::Bytes;
16use http::{Request, Response};
17use http_body::{Body, Frame};
18
19use crate::{
20 h2::{date::DateCache, send::PipeToSendStream},
21 EarlyHints, HttpProtocol, Incoming,
22};
23
24static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
25 http::header::HeaderName::from_static("keep-alive"),
26 http::header::HeaderName::from_static("proxy-connection"),
27 http::header::CONNECTION,
28 http::header::TRANSFER_ENCODING,
29 http::header::UPGRADE,
30];
31
32pub(crate) struct H2Body {
33 recv: h2::RecvStream,
34 data_done: bool,
35}
36
37impl H2Body {
38 #[inline]
39 fn new(recv: h2::RecvStream) -> Self {
40 Self {
41 recv,
42 data_done: false,
43 }
44 }
45}
46
47impl Body for H2Body {
48 type Data = Bytes;
49 type Error = std::io::Error;
50
51 #[inline]
52 fn poll_frame(
53 mut self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
56 if !self.data_done {
57 match self.recv.poll_data(cx) {
58 Poll::Ready(Some(Ok(data))) => {
59 let _ = self.recv.flow_control().release_capacity(data.len());
60 return Poll::Ready(Some(Ok(Frame::data(data))));
61 }
62 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
63 Poll::Ready(None) => self.data_done = true,
64 Poll::Pending => return Poll::Pending,
65 }
66 }
67
68 match self.recv.poll_trailers(cx) {
69 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
70 Poll::Ready(Ok(None)) => Poll::Ready(None),
71 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
72 Poll::Pending => Poll::Pending,
73 }
74 }
75}
76
77#[inline]
78pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
79 if error.is_io() {
80 error.into_io().unwrap_or(std::io::Error::other("io error"))
81 } else {
82 std::io::Error::other(error)
83 }
84}
85
86#[inline]
87pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
88 std::io::Error::other(h2::Error::from(reason))
89}
90
91pub struct Http2<Io> {
114 io_to_handshake: Option<Io>,
115 date_header_value_cached: DateCache,
116 options: Http2Options,
117 cancel_token: Option<CancellationToken>,
118}
119
120impl<Io> Http2<Io>
121where
122 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
123{
124 #[inline]
136 pub fn new(io: Io, options: Http2Options) -> Self {
137 Self {
138 io_to_handshake: Some(io),
139 date_header_value_cached: DateCache::default(),
140 options,
141 cancel_token: None,
142 }
143 }
144
145 #[inline]
150 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
151 self.cancel_token = Some(token);
152 self
153 }
154}
155
156impl<Io> HttpProtocol for Http2<Io>
157where
158 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
159{
160 #[allow(clippy::manual_async_fn)]
161 #[inline]
162 fn handle<F, Fut, ResB, ResBE, ResE>(
163 mut self,
164 request_fn: F,
165 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
166 where
167 F: Fn(Request<super::Incoming>) -> Fut + 'static,
168 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
169 ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
170 ResE: std::error::Error,
171 ResBE: std::error::Error,
172 {
173 async move {
174 let request_fn = Rc::new(request_fn);
175 let handshake_fut = self.options.h2.handshake(
176 self.io_to_handshake
177 .take()
178 .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
179 );
180 let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
181 vibeio::time::timeout(timeout, handshake_fut).await
182 } else {
183 Ok(handshake_fut.await)
184 })
185 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
186 .map_err(|e| {
187 if e.is_io() {
188 e.into_io().unwrap_or(std::io::Error::other("io error"))
189 } else {
190 std::io::Error::other(e)
191 }
192 })?;
193
194 while let Some(request) = {
195 let res = {
196 let accept_fut_orig = h2.accept();
197 let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
198 let cancel_token = self.cancel_token.clone();
199 let cancel_fut = async move {
200 if let Some(token) = cancel_token {
201 token.cancelled().await
202 } else {
203 futures_util::future::pending().await
204 }
205 };
206 let cancel_fut_pin = std::pin::pin!(cancel_fut);
207 let accept_fut =
208 futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
209
210 match if let Some(timeout) = self.options.accept_timeout {
211 vibeio::time::timeout(timeout, accept_fut).await
212 } else {
213 Ok(accept_fut.await)
214 } {
215 Ok(futures_util::future::Either::Right((request, _))) => {
216 (Some(request), false)
217 }
218 Ok(futures_util::future::Either::Left((_, _))) => {
219 (None, true)
221 }
222 Err(_) => {
223 (None, false)
225 }
226 }
227 };
228 match res {
229 (Some(request), _) => request,
230 (None, graceful) => {
231 h2.graceful_shutdown();
232 let _ = h2.accept().await;
233 if graceful {
234 return Ok(());
235 }
236 return Err(std::io::Error::new(
237 std::io::ErrorKind::TimedOut,
238 "accept timeout",
239 ));
240 }
241 }
242 } {
243 let (request, mut stream) = match request {
244 Ok(d) => d,
245 Err(e) if e.is_go_away() => {
246 continue;
247 }
248 Err(e) if e.is_io() => {
249 return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
250 }
251 Err(e) => {
252 return Err(std::io::Error::other(e));
253 }
254 };
255
256 let date_cache = self.date_header_value_cached.clone();
257 let request_fn = request_fn.clone();
258 vibeio::spawn(async move {
259 let (request_parts, recv_stream) = request.into_parts();
260 let request_body = Incoming::H2(H2Body::new(recv_stream));
261 let mut request = Request::from_parts(request_parts, request_body);
262
263 if self.options.send_continue_response {
265 let is_100_continue = request
266 .headers()
267 .get(http::header::EXPECT)
268 .and_then(|v| v.to_str().ok())
269 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
270 if is_100_continue {
271 let mut response = Response::new(());
272 *response.status_mut() = http::StatusCode::CONTINUE;
273 let _ = stream.send_informational(response).map_err(h2_error_to_io);
274 }
275 }
276
277 let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
278 let early_hints = EarlyHints::new(early_hints_tx);
279 request.extensions_mut().insert(early_hints);
280
281 let mut response_fut = std::pin::pin!(request_fn(request));
282 let early_hints_rx = early_hints_rx;
283 let response_result = loop {
284 let early_hints_recv_fut = early_hints_rx.recv();
285 let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
286 let next = std::future::poll_fn(|cx| {
287 match stream.poll_reset(cx) {
288 Poll::Ready(Ok(reason)) => {
289 return Poll::Ready(Err(h2_reason_to_io(reason)));
290 }
291 Poll::Ready(Err(err)) => {
292 return Poll::Ready(Err(h2_error_to_io(err)));
293 }
294 Poll::Pending => {}
295 }
296
297 if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
298 return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
299 }
300
301 match early_hints_recv_fut.as_mut().poll(cx) {
302 Poll::Ready(Ok(msg)) => {
303 Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
304 }
305 Poll::Ready(Err(_)) => Poll::Pending,
306 Poll::Pending => Poll::Pending,
307 }
308 })
309 .await;
310
311 match next {
312 Ok(futures_util::future::Either::Left(response_result)) => {
313 break response_result;
314 }
315 Ok(futures_util::future::Either::Right((headers, sender))) => {
316 let mut response = Response::new(());
317 *response.status_mut() = http::StatusCode::EARLY_HINTS;
318 *response.headers_mut() = headers;
319 sender
320 .into_inner()
321 .send(
322 stream.send_informational(response).map_err(h2_error_to_io),
323 )
324 .ok();
325 }
326 Err(_) => {
327 return;
328 }
329 }
330 };
331 let Ok(mut response) = response_result else {
332 return;
334 };
335
336 {
337 let response_headers = response.headers_mut();
338 if self.options.send_date_header {
339 if let Some(http_date) = date_cache.get_date_header_value() {
340 response_headers
341 .entry(http::header::DATE)
342 .or_insert(http_date);
343 }
344 }
345 for header in &HTTP2_INVALID_HEADERS {
346 if let http::header::Entry::Occupied(entry) =
347 response_headers.entry(header)
348 {
349 entry.remove();
350 }
351 }
352 if response_headers
353 .get(http::header::TE)
354 .is_some_and(|v| v != "trailers")
355 {
356 response_headers.remove(http::header::TE);
357 }
358 }
359
360 let response_is_end_stream = response.body().is_end_stream();
361 if !response_is_end_stream {
362 if let Some(content_length) = response.body().size_hint().exact() {
363 if !response
364 .headers()
365 .contains_key(http::header::CONTENT_LENGTH)
366 {
367 response
368 .headers_mut()
369 .insert(http::header::CONTENT_LENGTH, content_length.into());
370 }
371 }
372 }
373
374 let (response_parts, mut response_body) = response.into_parts();
375 let Ok(send) = stream.send_response(
376 Response::from_parts(response_parts, ()),
377 response_is_end_stream,
378 ) else {
379 return;
380 };
381
382 if response_is_end_stream {
383 return;
384 }
385
386 let _ = PipeToSendStream::new(send, &mut response_body).await;
387 });
388 }
389
390 Ok(())
391 }
392 }
393}