skyzen_hyper/
lib.rs

1#![warn(missing_docs, missing_debug_implementations)]
2
3//! The hyper backend of skyzen
4
5use core::future::Future;
6use executor_core::{AnyExecutor, Executor, Task};
7use http_kit::utils::{AsyncRead, AsyncReadExt, AsyncWrite, Stream, StreamExt};
8use hyper::server::conn::{http1::Builder as Http1Builder, http2::Builder as Http2Builder};
9use skyzen_core::{Endpoint, Server};
10use std::pin::Pin;
11use std::ptr;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use tracing::error;
15
16mod service;
17pub use service::IntoService;
18
/// Hyper-based [`Server`] implementation.
///
/// Each accepted connection is served by hyper: connections that open with the HTTP/2
/// client connection preface are handled by hyper's HTTP/2 server, all others by its
/// HTTP/1 server with upgrade support enabled.
#[derive(Debug, Default, Clone, Copy)]
pub struct Hyper;
22
/// Adapter that lets hyper spawn futures on a skyzen executor.
///
/// The executor is held behind an [`Arc`], so every clone of the wrapper shares the
/// same underlying executor.
struct ExecutorWrapper<E>(Arc<E>);

impl<E> ExecutorWrapper<E> {
    /// Wraps an executor that is already shared through an [`Arc`].
    const fn new(inner: Arc<E>) -> Self {
        ExecutorWrapper(inner)
    }
}

// Implemented by hand instead of derived: `#[derive(Clone)]` would add an `E: Clone`
// bound, while cloning the `Arc` works for any `E`.
impl<E> Clone for ExecutorWrapper<E> {
    fn clone(&self) -> Self {
        Self::new(Arc::clone(&self.0))
    }
}
36
37impl<Fut, E> hyper::rt::Executor<Fut> for ExecutorWrapper<E>
38where
39    Fut: Future + Send + 'static,
40    Fut::Output: Send + 'static,
41    E: executor_core::Executor + 'static,
42{
43    fn execute(&self, fut: Fut) {
44        self.0.spawn(fut).detach();
45    }
46}
47
/// Adapts a stream implementing the `AsyncRead`/`AsyncWrite` traits from
/// `http_kit::utils` to hyper's `hyper::rt::Read`/`hyper::rt::Write` I/O traits.
struct ConnectionWrapper<C>(C);
49
50impl<C: Unpin + AsyncRead> hyper::rt::Read for ConnectionWrapper<C> {
51    fn poll_read(
52        self: std::pin::Pin<&mut Self>,
53        cx: &mut std::task::Context<'_>,
54        mut buf: hyper::rt::ReadBufCursor<'_>,
55    ) -> std::task::Poll<Result<(), std::io::Error>> {
56        let inner = &mut self.get_mut().0;
57
58        // SAFETY: `buf.as_mut()` gives us a `&mut [MaybeUninit<u8>]`.
59        // We must cast it to `&mut [u8]` and guarantee we will only write `n` bytes and call `advance(n)`
60        let buffer = unsafe { &mut *(ptr::from_mut(buf.as_mut()) as *mut [u8]) };
61
62        match Pin::new(inner).poll_read(cx, buffer) {
63            Poll::Ready(Ok(n)) => {
64                // SAFETY: we just wrote `n` bytes into `buffer`, must now advance `n`
65                unsafe {
66                    buf.advance(n);
67                }
68                Poll::Ready(Ok(()))
69            }
70            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
71            Poll::Pending => Poll::Pending,
72        }
73    }
74}
75
76impl<C: AsyncWrite + Unpin> hyper::rt::Write for ConnectionWrapper<C> {
77    fn poll_write(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80        buf: &[u8],
81    ) -> Poll<Result<usize, std::io::Error>> {
82        let inner = &mut self.get_mut().0;
83        Pin::new(inner).poll_write(cx, buf)
84    }
85
86    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
87        let inner = &mut self.get_mut().0;
88        Pin::new(inner).poll_flush(cx)
89    }
90
91    fn poll_shutdown(
92        self: Pin<&mut Self>,
93        cx: &mut Context<'_>,
94    ) -> Poll<Result<(), std::io::Error>> {
95        let inner = &mut self.get_mut().0;
96        Pin::new(inner).poll_close(cx)
97    }
98}
99
/// Stream wrapper that replays bytes already read from `inner` before reading more.
///
/// `sniff_protocol` builds one from the bytes it consumed while inspecting the
/// connection. The `AsyncRead` impl hands those bytes out again before touching `inner`.
#[derive(Debug)]
struct Prefixed<C> {
    /// Bytes read ahead of time that still have to be replayed.
    buffer: Vec<u8>,
    /// Index of the next byte of `buffer` to hand out.
    pos: usize,
    /// The underlying stream.
    inner: C,
}

impl<C> Prefixed<C> {
    /// Wraps `inner` so that `buffer` is read back before any further data from `inner`.
    const fn new(inner: C, buffer: Vec<u8>) -> Self {
        let pos = 0;
        Self { inner, buffer, pos }
    }
}
116
// `Prefixed<C>` would already be `Unpin` whenever `C: Unpin` through the auto trait,
// since its other fields (`Vec<u8>`, `usize`) are `Unpin`. This impl states that
// explicitly. The `AsyncRead`/`AsyncWrite` impls rely on it when calling `Pin::get_mut`.
impl<C: Unpin> Unpin for Prefixed<C> {}
118
119impl<C: AsyncRead + Unpin> AsyncRead for Prefixed<C> {
120    fn poll_read(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123        buf: &mut [u8],
124    ) -> Poll<Result<usize, std::io::Error>> {
125        let this = self.get_mut();
126        if this.pos < this.buffer.len() {
127            let available = this.buffer.len() - this.pos;
128            let n = available.min(buf.len());
129            buf[..n].copy_from_slice(&this.buffer[this.pos..this.pos + n]);
130            this.pos += n;
131            if this.pos == this.buffer.len() {
132                this.buffer.clear();
133                this.pos = 0;
134            }
135            return Poll::Ready(Ok(n));
136        }
137
138        Pin::new(&mut this.inner).poll_read(cx, buf)
139    }
140}
141
142impl<C: AsyncWrite + Unpin> AsyncWrite for Prefixed<C> {
143    fn poll_write(
144        self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146        buf: &[u8],
147    ) -> Poll<Result<usize, std::io::Error>> {
148        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
149    }
150
151    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
152        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
153    }
154
155    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
156        Pin::new(&mut self.get_mut().inner).poll_close(cx)
157    }
158}
159
160impl Server for Hyper {
161    async fn serve<C, E>(
162        self,
163        executor: impl executor_core::Executor + 'static,
164        error_handler: impl Fn(E) + Send + Sync + 'static,
165        mut connections: impl Stream<Item = Result<C, E>> + Unpin + Send + 'static,
166        endpoint: impl Endpoint + Sync + Clone + 'static,
167    ) where
168        C: Unpin + Send + AsyncRead + AsyncWrite + 'static,
169        E: std::error::Error,
170    {
171        const HTTP2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
172
173        let executor = Arc::new(executor);
174        let hyper_executor = ExecutorWrapper::new(executor.clone());
175        let shared_executor: Arc<AnyExecutor> = Arc::new(AnyExecutor::new(executor.clone()));
176        while let Some(connection) = connections.next().await {
177            match connection {
178                Ok(connection) => {
179                    let serve_executor = executor.clone();
180                    let endpoint = endpoint.clone();
181                    let hyper_executor = hyper_executor.clone();
182                    let shared_executor = shared_executor.clone();
183                    let serve_future = async move {
184                        let (connection, is_h2) =
185                            match sniff_protocol(connection, HTTP2_PREFACE).await {
186                                Ok(result) => result,
187                                Err(error) => {
188                                    error!("Failed to read connection preface: {error}");
189                                    return;
190                                }
191                            };
192
193                        if is_h2 {
194                            let builder = Http2Builder::new(hyper_executor);
195                            let service = IntoService::new(endpoint, shared_executor);
196                            if let Err(error) = builder
197                                .serve_connection(ConnectionWrapper(connection), service)
198                                .await
199                            {
200                                error!("Failed to serve Hyper h2 connection: {error}");
201                            }
202                        } else {
203                            let builder = Http1Builder::new();
204                            let service = IntoService::new(endpoint, shared_executor);
205                            if let Err(error) = builder
206                                .serve_connection(ConnectionWrapper(connection), service)
207                                .with_upgrades()
208                                .await
209                            {
210                                error!("Failed to serve Hyper h1 connection: {error}");
211                            }
212                        }
213                    };
214                    serve_executor.spawn(serve_future).detach();
215                }
216                Err(error) => error_handler(error),
217            }
218        }
219    }
220}
221
222async fn sniff_protocol<C>(mut stream: C, preface: &[u8]) -> std::io::Result<(Prefixed<C>, bool)>
223where
224    C: AsyncRead + AsyncWrite + Unpin,
225{
226    let mut buf = vec![0u8; preface.len()];
227    let n = stream.read(&mut buf).await?;
228    buf.truncate(n);
229    let is_h2 = buf.starts_with(preface);
230    Ok((Prefixed::new(stream, buf), is_h2))
231}