vault_mgmt_lib/
http.rs

1use std::{convert::Infallible, sync::Arc};
2
3use http::Response;
4use http_body_util::{combinators::BoxBody, BodyExt, Empty};
5use hyper::{
6    body::{Body, Bytes},
7    Request,
8};
9use hyper_util::rt::TokioIo;
10use rustls::crypto::ring;
11use secrecy::{ExposeSecret, Secret};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::rustls::{pki_types, RootCertStore};
14use tracing::*;
15
16pub type BytesBody = BoxBody<Bytes, Infallible>;
17
18/// Send HTTP requests
19#[async_trait::async_trait]
20pub trait HttpRequest<B>
21where
22    B: Body,
23{
24    /// Send an HTTP request and return the response
25    async fn send_request(&mut self, req: Request<B>) -> hyper::Result<Response<Bytes>>;
26
27    /// Wait until the connection is ready to send requests
28    async fn ready(&mut self) -> anyhow::Result<()>;
29}
30
31/// Forward HTTP requests over a connection stream
32pub struct HttpForwarderService<B>
33where
34    B: Body,
35{
36    sender: hyper::client::conn::http1::SendRequest<B>,
37}
38
39impl<B> HttpForwarderService<B>
40where
41    B: Body<Data = Bytes, Error = Infallible> + Send + 'static,
42{
43    /// Forward HTTP requests over a connection stream
44    pub async fn http<T>(stream: T) -> anyhow::Result<HttpForwarderService<B>>
45    where
46        T: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
47    {
48        let io = TokioIo::new(stream);
49
50        let (sender, connection) = hyper::client::conn::http1::handshake(io).await?;
51
52        tokio::spawn(async move {
53            if let Err(e) = connection.await {
54                warn!("Error in connection: {}", e);
55            }
56        });
57
58        Ok(Self { sender })
59    }
60
61    /// Wrap the connection stream in TLS and forward HTTP requests over it
62    /// The domain is used to verify the TLS certificate
63    /// The native root certificates are used to verify the TLS certificate
64    ///
65    /// TODO: allow customizing the TLS configuration
66    pub async fn https<T>(domain: &str, stream: T) -> anyhow::Result<HttpForwarderService<B>>
67    where
68        T: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
69    {
70        let tls_stream = setup_tls(domain, stream).await?;
71
72        HttpForwarderService::http(tls_stream).await
73    }
74}
75
76#[async_trait::async_trait]
77impl<B> HttpRequest<B> for HttpForwarderService<B>
78where
79    B: Body<Data = Bytes, Error = Infallible> + Send + 'static,
80{
81    async fn send_request(&mut self, req: Request<B>) -> hyper::Result<Response<Bytes>> {
82        let (parts, body) = self.sender.send_request(req).await?.into_parts();
83        let body = body.boxed().collect().await?.to_bytes();
84        Ok(Response::from_parts(parts, body))
85    }
86
87    async fn ready(&mut self) -> anyhow::Result<()> {
88        self.sender
89            .ready()
90            .await
91            .map_err(|e| anyhow::anyhow!("{}", e))
92    }
93}
94
95pub(crate) async fn setup_tls<T>(
96    domain: &str,
97    stream: T,
98) -> anyhow::Result<tokio_rustls::client::TlsStream<T>>
99where
100    T: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
101{
102    let mut root_cert_store = RootCertStore::empty();
103
104    for cert in rustls_native_certs::load_native_certs()
105        .map_err(|e| anyhow::anyhow!("could not load platform certs: {}", e))?
106    {
107        root_cert_store.add(cert).unwrap();
108    }
109
110    let tls = tokio_rustls::rustls::ClientConfig::builder_with_provider(Arc::new(
111        ring::default_provider(),
112    ))
113    .with_safe_default_protocol_versions()?
114    .with_root_certificates(root_cert_store)
115    .with_no_client_auth();
116
117    let tls_stream = tokio_rustls::TlsConnector::from(Arc::new(tls))
118        .connect(pki_types::ServerName::try_from(domain)?.to_owned(), stream)
119        .await?;
120
121    Ok(tls_stream)
122}
123
124pub const VAULT_PORT: u16 = 8200;
125
126pub(crate) fn vault_request() -> http::request::Builder {
127    hyper::Request::builder()
128        .header("Host", "127.0.0.1")
129        .header("x-Vault-Request", "true")
130}
131
132pub(crate) fn vault_request_with_token(token: Secret<String>) -> http::request::Builder {
133    vault_request().header("X-Vault-Token", token.expose_secret())
134}
135
136const SEAL_STATUS_URL: &str = "/v1/sys/seal-status";
137pub(crate) fn seal_status_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
138    vault_request()
139        .uri(SEAL_STATUS_URL)
140        .method(hyper::Method::GET)
141        .body(body)
142}
143
144const UNSEAL_URL: &str = "/v1/sys/unseal";
145pub(crate) fn unseal_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
146    vault_request()
147        .uri(UNSEAL_URL)
148        .method(hyper::Method::PUT)
149        .body(body)
150}
151
152pub(crate) fn get_unseal_keys_request(
153    path: &str,
154    token: Secret<String>,
155) -> http::Result<Request<BytesBody>> {
156    vault_request_with_token(token)
157        .uri(path)
158        .method(hyper::Method::GET)
159        .body(Empty::<Bytes>::new().boxed())
160}
161
162const INIT_URL: &str = "/v1/sys/init";
163pub(crate) fn init_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
164    vault_request()
165        .uri(INIT_URL)
166        .method(hyper::Method::PUT)
167        .body(body)
168}
169
170const RAFT_JOIN_URL: &str = "/v1/sys/storage/raft/join";
171pub(crate) fn raft_join_request(body: BytesBody) -> http::Result<Request<BytesBody>> {
172    vault_request()
173        .uri(RAFT_JOIN_URL)
174        .method(hyper::Method::POST)
175        .body(body)
176}
177
178const RAFT_CONFIGURATION_URL: &str = "/v1/sys/storage/raft/configuration";
179pub(crate) fn raft_configuration_request(
180    token: Secret<String>,
181    body: BytesBody,
182) -> http::Result<Request<BytesBody>> {
183    vault_request_with_token(token)
184        .uri(RAFT_CONFIGURATION_URL)
185        .method(hyper::Method::GET)
186        .body(body)
187}
188
189const STEP_DOWN_URL: &str = "/v1/sys/step-down";
190pub(crate) fn step_down_request(
191    token: Secret<String>,
192    body: BytesBody,
193) -> http::Result<Request<BytesBody>> {
194    vault_request_with_token(token)
195        .uri(STEP_DOWN_URL)
196        .method(hyper::Method::PUT)
197        .body(body)
198}
199
#[cfg(test)]
mod tests {
    use http::StatusCode;
    use http_body_util::Empty;
    use hyper::body::Bytes;
    use wiremock::{matchers::any, Mock, MockServer, ResponseTemplate};

    use super::{HttpForwarderService, HttpRequest};

    /// A plain-HTTP request sent through the forwarder reaches the mock
    /// server exactly once and yields a successful status.
    #[tokio::test]
    async fn http_forward_works() {
        let mock_server = MockServer::start().await;

        Mock::given(any())
            .respond_with(ResponseTemplate::new(StatusCode::OK))
            .expect(1)
            .mount(&mock_server)
            .await;

        // The mock server reports a full URI; the TCP connect needs host:port.
        let server_uri = mock_server.uri();
        let socket_addr = server_uri.strip_prefix("http://").unwrap();
        let tcp = tokio::net::TcpStream::connect(socket_addr).await.unwrap();

        let mut client = HttpForwarderService::http(tcp).await.unwrap();

        let request = hyper::Request::builder()
            .method(hyper::Method::GET)
            .uri("/")
            .body(Empty::<Bytes>::new())
            .unwrap();

        let response = client.send_request(request).await.unwrap();

        assert!(response.status().is_success());
    }

    // TODO: do not use remote host for testing
    /// A request sent over TLS to a remote host yields a success or
    /// redirection status.
    #[ignore = "connecting to google.com"]
    #[tokio::test]
    async fn https_forward_works() {
        const DOMAIN: &str = "google.com";

        let remote = format!("{}:443", DOMAIN);
        let tcp = tokio::net::TcpStream::connect(remote.as_str())
            .await
            .unwrap();

        let mut client = HttpForwarderService::https(DOMAIN, tcp).await.unwrap();

        let request = hyper::Request::builder()
            .method(hyper::Method::GET)
            .uri("/")
            .header("Host", DOMAIN)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let response = client.send_request(request).await.unwrap();
        let status = response.status();

        assert!(status.is_success() || status.is_redirection());
    }
}