1#![warn(missing_docs, missing_debug_implementations)]
2
3use core::future::Future;
6use executor_core::{AnyExecutor, Executor, Task};
7use http_kit::utils::{AsyncRead, AsyncReadExt, AsyncWrite, Stream, StreamExt};
8use hyper::server::conn::{http1::Builder as Http1Builder, http2::Builder as Http2Builder};
9use skyzen_core::{Endpoint, Server};
10use std::pin::Pin;
11use std::ptr;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use tracing::error;
15
16mod service;
17pub use service::IntoService;
18
/// [`Server`] implementation backed by `hyper`.
///
/// Each accepted connection is first sniffed for the HTTP/2 client preface and
/// is then driven either by hyper's HTTP/2 connection builder or, otherwise, by
/// its HTTP/1 builder with protocol upgrades enabled.
#[derive(Clone, Copy, Debug, Default)]
pub struct Hyper;
22
/// Adapts a shared `executor_core` executor so that hyper (via the
/// `hyper::rt::Executor` impl below) can spawn its own futures on it; it is
/// handed to the HTTP/2 connection builder.
struct ExecutorWrapper<E>(Arc<E>);

impl<E> ExecutorWrapper<E> {
    /// Wraps an already-shared executor handle.
    const fn new(shared: Arc<E>) -> Self {
        Self(shared)
    }
}

// Written by hand instead of `#[derive(Clone)]`: the derive would demand
// `E: Clone`, while cloning the wrapper only needs to bump the `Arc` count.
impl<E> Clone for ExecutorWrapper<E> {
    fn clone(&self) -> Self {
        let Self(shared) = self;
        Self(Arc::clone(shared))
    }
}
36
37impl<Fut, E> hyper::rt::Executor<Fut> for ExecutorWrapper<E>
38where
39 Fut: Future + Send + 'static,
40 Fut::Output: Send + 'static,
41 E: executor_core::Executor + 'static,
42{
43 fn execute(&self, fut: Fut) {
44 self.0.spawn(fut).detach();
45 }
46}
47
48struct ConnectionWrapper<C>(C);
49
50impl<C: Unpin + AsyncRead> hyper::rt::Read for ConnectionWrapper<C> {
51 fn poll_read(
52 self: std::pin::Pin<&mut Self>,
53 cx: &mut std::task::Context<'_>,
54 mut buf: hyper::rt::ReadBufCursor<'_>,
55 ) -> std::task::Poll<Result<(), std::io::Error>> {
56 let inner = &mut self.get_mut().0;
57
58 let buffer = unsafe { &mut *(ptr::from_mut(buf.as_mut()) as *mut [u8]) };
61
62 match Pin::new(inner).poll_read(cx, buffer) {
63 Poll::Ready(Ok(n)) => {
64 unsafe {
66 buf.advance(n);
67 }
68 Poll::Ready(Ok(()))
69 }
70 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
71 Poll::Pending => Poll::Pending,
72 }
73 }
74}
75
76impl<C: AsyncWrite + Unpin> hyper::rt::Write for ConnectionWrapper<C> {
77 fn poll_write(
78 self: Pin<&mut Self>,
79 cx: &mut Context<'_>,
80 buf: &[u8],
81 ) -> Poll<Result<usize, std::io::Error>> {
82 let inner = &mut self.get_mut().0;
83 Pin::new(inner).poll_write(cx, buf)
84 }
85
86 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
87 let inner = &mut self.get_mut().0;
88 Pin::new(inner).poll_flush(cx)
89 }
90
91 fn poll_shutdown(
92 self: Pin<&mut Self>,
93 cx: &mut Context<'_>,
94 ) -> Poll<Result<(), std::io::Error>> {
95 let inner = &mut self.get_mut().0;
96 Pin::new(inner).poll_close(cx)
97 }
98}
99
/// A connection whose first bytes were already consumed (for protocol
/// sniffing) and are replayed to the next reader before reads fall through to
/// `inner`.
#[derive(Debug)]
struct Prefixed<C> {
    /// Bytes read ahead of time that still have to be handed back out.
    buffer: Vec<u8>,
    /// Index into `buffer` of the next byte to replay.
    pos: usize,
    /// The underlying connection.
    inner: C,
}

impl<C> Prefixed<C> {
    /// Wraps `inner` so that all of `buffer` is read back before any byte of
    /// `inner`.
    const fn new(inner: C, buffer: Vec<u8>) -> Self {
        Self {
            inner,
            buffer,
            pos: 0,
        }
    }
}

// Same as the auto-trait impl (`Vec<u8>` and `usize` are always `Unpin`);
// spelled out so the `C: Unpin` requirement is explicit.
impl<C: Unpin> Unpin for Prefixed<C> {}
118
119impl<C: AsyncRead + Unpin> AsyncRead for Prefixed<C> {
120 fn poll_read(
121 self: Pin<&mut Self>,
122 cx: &mut Context<'_>,
123 buf: &mut [u8],
124 ) -> Poll<Result<usize, std::io::Error>> {
125 let this = self.get_mut();
126 if this.pos < this.buffer.len() {
127 let available = this.buffer.len() - this.pos;
128 let n = available.min(buf.len());
129 buf[..n].copy_from_slice(&this.buffer[this.pos..this.pos + n]);
130 this.pos += n;
131 if this.pos == this.buffer.len() {
132 this.buffer.clear();
133 this.pos = 0;
134 }
135 return Poll::Ready(Ok(n));
136 }
137
138 Pin::new(&mut this.inner).poll_read(cx, buf)
139 }
140}
141
142impl<C: AsyncWrite + Unpin> AsyncWrite for Prefixed<C> {
143 fn poll_write(
144 self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 buf: &[u8],
147 ) -> Poll<Result<usize, std::io::Error>> {
148 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
149 }
150
151 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
152 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
153 }
154
155 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
156 Pin::new(&mut self.get_mut().inner).poll_close(cx)
157 }
158}
159
160impl Server for Hyper {
161 async fn serve<C, E>(
162 self,
163 executor: impl executor_core::Executor + 'static,
164 error_handler: impl Fn(E) + Send + Sync + 'static,
165 mut connections: impl Stream<Item = Result<C, E>> + Unpin + Send + 'static,
166 endpoint: impl Endpoint + Sync + Clone + 'static,
167 ) where
168 C: Unpin + Send + AsyncRead + AsyncWrite + 'static,
169 E: std::error::Error,
170 {
171 const HTTP2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
172
173 let executor = Arc::new(executor);
174 let hyper_executor = ExecutorWrapper::new(executor.clone());
175 let shared_executor: Arc<AnyExecutor> = Arc::new(AnyExecutor::new(executor.clone()));
176 while let Some(connection) = connections.next().await {
177 match connection {
178 Ok(connection) => {
179 let serve_executor = executor.clone();
180 let endpoint = endpoint.clone();
181 let hyper_executor = hyper_executor.clone();
182 let shared_executor = shared_executor.clone();
183 let serve_future = async move {
184 let (connection, is_h2) =
185 match sniff_protocol(connection, HTTP2_PREFACE).await {
186 Ok(result) => result,
187 Err(error) => {
188 error!("Failed to read connection preface: {error}");
189 return;
190 }
191 };
192
193 if is_h2 {
194 let builder = Http2Builder::new(hyper_executor);
195 let service = IntoService::new(endpoint, shared_executor);
196 if let Err(error) = builder
197 .serve_connection(ConnectionWrapper(connection), service)
198 .await
199 {
200 error!("Failed to serve Hyper h2 connection: {error}");
201 }
202 } else {
203 let builder = Http1Builder::new();
204 let service = IntoService::new(endpoint, shared_executor);
205 if let Err(error) = builder
206 .serve_connection(ConnectionWrapper(connection), service)
207 .with_upgrades()
208 .await
209 {
210 error!("Failed to serve Hyper h1 connection: {error}");
211 }
212 }
213 };
214 serve_executor.spawn(serve_future).detach();
215 }
216 Err(error) => error_handler(error),
217 }
218 }
219 }
220}
221
222async fn sniff_protocol<C>(mut stream: C, preface: &[u8]) -> std::io::Result<(Prefixed<C>, bool)>
223where
224 C: AsyncRead + AsyncWrite + Unpin,
225{
226 let mut buf = vec![0u8; preface.len()];
227 let n = stream.read(&mut buf).await?;
228 buf.truncate(n);
229 let is_h2 = buf.starts_with(preface);
230 Ok((Prefixed::new(stream, buf), is_h2))
231}