1use std::{convert::Infallible, sync::Arc};
2
3use http::Response;
4use http_body_util::{combinators::BoxBody, BodyExt, Empty};
5use hyper::{
6 body::{Body, Bytes},
7 Request,
8};
9use hyper_util::rt::TokioIo;
10use rustls::crypto::ring;
11use secrecy::{ExposeSecret, Secret};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::rustls::{pki_types, RootCertStore};
14use tracing::*;
15
/// Boxed, type-erased HTTP body yielding `Bytes` chunks.
///
/// The error type is `Infallible`, so reading a `BytesBody` can never fail.
pub type BytesBody = BoxBody<Bytes, Infallible>;
17
18#[async_trait::async_trait]
20pub trait HttpRequest<B>
21where
22 B: Body,
23{
24 async fn send_request(&mut self, req: Request<B>) -> hyper::Result<Response<Bytes>>;
26
27 async fn ready(&mut self) -> anyhow::Result<()>;
29}
30
31pub struct HttpForwarderService<B>
33where
34 B: Body,
35{
36 sender: hyper::client::conn::http1::SendRequest<B>,
37}
38
39impl<B> HttpForwarderService<B>
40where
41 B: Body<Data = Bytes, Error = Infallible> + Send + 'static,
42{
43 pub async fn http<T>(stream: T) -> anyhow::Result<HttpForwarderService<B>>
45 where
46 T: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
47 {
48 let io = TokioIo::new(stream);
49
50 let (sender, connection) = hyper::client::conn::http1::handshake(io).await?;
51
52 tokio::spawn(async move {
53 if let Err(e) = connection.await {
54 warn!("Error in connection: {}", e);
55 }
56 });
57
58 Ok(Self { sender })
59 }
60
61 pub async fn https<T>(domain: &str, stream: T) -> anyhow::Result<HttpForwarderService<B>>
67 where
68 T: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
69 {
70 let tls_stream = setup_tls(domain, stream).await?;
71
72 HttpForwarderService::http(tls_stream).await
73 }
74}
75
76#[async_trait::async_trait]
77impl<B> HttpRequest<B> for HttpForwarderService<B>
78where
79 B: Body<Data = Bytes, Error = Infallible> + Send + 'static,
80{
81 async fn send_request(&mut self, req: Request<B>) -> hyper::Result<Response<Bytes>> {
82 let (parts, body) = self.sender.send_request(req).await?.into_parts();
83 let body = body.boxed().collect().await?.to_bytes();
84 Ok(Response::from_parts(parts, body))
85 }
86
87 async fn ready(&mut self) -> anyhow::Result<()> {
88 self.sender
89 .ready()
90 .await
91 .map_err(|e| anyhow::anyhow!("{}", e))
92 }
93}
94
95pub(crate) async fn setup_tls<T>(
96 domain: &str,
97 stream: T,
98) -> anyhow::Result<tokio_rustls::client::TlsStream<T>>
99where
100 T: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
101{
102 let mut root_cert_store = RootCertStore::empty();
103
104 for cert in rustls_native_certs::load_native_certs()
105 .map_err(|e| anyhow::anyhow!("could not load platform certs: {}", e))?
106 {
107 root_cert_store.add(cert).unwrap();
108 }
109
110 let tls = tokio_rustls::rustls::ClientConfig::builder_with_provider(Arc::new(
111 ring::default_provider(),
112 ))
113 .with_safe_default_protocol_versions()?
114 .with_root_certificates(root_cert_store)
115 .with_no_client_auth();
116
117 let tls_stream = tokio_rustls::TlsConnector::from(Arc::new(tls))
118 .connect(pki_types::ServerName::try_from(domain)?.to_owned(), stream)
119 .await?;
120
121 Ok(tls_stream)
122}
123
124pub const VAULT_PORT: u16 = 8200;
125
126pub(crate) fn vault_request() -> http::request::Builder {
127 hyper::Request::builder()
128 .header("Host", "127.0.0.1")
129 .header("x-Vault-Request", "true")
130}
131
132pub(crate) fn vault_request_with_token(token: Secret<String>) -> http::request::Builder {
133 vault_request().header("X-Vault-Token", token.expose_secret())
134}
135
136const SEAL_STATUS_URL: &str = "/v1/sys/seal-status";
137pub(crate) fn seal_status_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
138 vault_request()
139 .uri(SEAL_STATUS_URL)
140 .method(hyper::Method::GET)
141 .body(body)
142}
143
144const UNSEAL_URL: &str = "/v1/sys/unseal";
145pub(crate) fn unseal_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
146 vault_request()
147 .uri(UNSEAL_URL)
148 .method(hyper::Method::PUT)
149 .body(body)
150}
151
152pub(crate) fn get_unseal_keys_request(
153 path: &str,
154 token: Secret<String>,
155) -> http::Result<Request<BytesBody>> {
156 vault_request_with_token(token)
157 .uri(path)
158 .method(hyper::Method::GET)
159 .body(Empty::<Bytes>::new().boxed())
160}
161
162const INIT_URL: &str = "/v1/sys/init";
163pub(crate) fn init_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
164 vault_request()
165 .uri(INIT_URL)
166 .method(hyper::Method::PUT)
167 .body(body)
168}
169
170const RAFT_JOIN_URL: &str = "/v1/sys/storage/raft/join";
171pub(crate) fn raft_join_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
172 vault_request()
173 .uri(RAFT_JOIN_URL)
174 .method(hyper::Method::POST)
175 .body(body)
176}
177
178const RAFT_CONFIGURATION_URL: &str = "/v1/sys/storage/raft/configuration";
179pub(crate) fn raft_configuration_request(
180 token: Secret<String>,
181 body: BytesBody,
182) -> http::Result<Request<BytesBody>> {
183 vault_request_with_token(token)
184 .uri(RAFT_CONFIGURATION_URL)
185 .method(hyper::Method::GET)
186 .body(body)
187}
188
189const STEP_DOWN_URL: &str = "/v1/sys/step-down";
190pub(crate) fn step_down_request(
191 token: Secret<String>,
192 body: BytesBody,
193) -> http::Result<Request<BytesBody>> {
194 vault_request_with_token(token)
195 .uri(STEP_DOWN_URL)
196 .method(hyper::Method::PUT)
197 .body(body)
198}
199
#[cfg(test)]
mod tests {
    use http::StatusCode;
    use http_body_util::Empty;
    use hyper::body::Bytes;
    use wiremock::{matchers::any, Mock, MockServer, ResponseTemplate};

    use crate::http::{HttpForwarderService, HttpRequest};

    /// A plain-HTTP forwarder delivers exactly one request to a local mock
    /// server and returns its successful response.
    #[tokio::test]
    async fn http_forward_works() {
        let mock_server = MockServer::start().await;

        // `expect(1)` is verified when `mock_server` is dropped.
        Mock::given(any())
            .respond_with(ResponseTemplate::new(StatusCode::OK))
            .expect(1)
            .mount(&mock_server)
            .await;

        // Connect straight to the mock's socket address instead of
        // re-parsing it out of the URI string.
        let stream = tokio::net::TcpStream::connect(*mock_server.address())
            .await
            .unwrap();
        let mut client = HttpForwarderService::http(stream).await.unwrap();

        let http_req = hyper::Request::builder()
            .uri("/")
            .method(hyper::Method::GET)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let (parts, _) = client.send_request(http_req).await.unwrap().into_parts();

        assert!(parts.status.is_success());
    }

    /// An HTTPS forwarder can complete a TLS handshake and a request against
    /// a real public host. Ignored by default because it needs network access.
    #[ignore = "connecting to google.com"]
    #[tokio::test]
    async fn https_forward_works() {
        const DOMAIN: &str = "google.com";

        let stream = tokio::net::TcpStream::connect((DOMAIN, 443))
            .await
            .unwrap();

        let mut pf = HttpForwarderService::https(DOMAIN, stream).await.unwrap();

        let http_req = hyper::Request::builder()
            .uri("/")
            .header("Host", DOMAIN)
            .method(hyper::Method::GET)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let (parts, _) = pf.send_request(http_req).await.unwrap().into_parts();

        assert!(parts.status.is_success() || parts.status.is_redirection());
    }
}