vault_mgmt_lib/
init.rs

1use http::Request;
2use http_body_util::{BodyExt, Full};
3use hyper::body::Bytes;
4use k8s_openapi::api::core::v1::Pod;
5use kube::Api;
6use secrecy::Secret;
7use tracing::*;
8
9use crate::{init_request, raft_join_request, BytesBody, HttpRequest, PodApi, VAULT_PORT};
10
11#[derive(Debug, serde::Serialize)]
12pub struct InitRequest {
13    pub secret_shares: u8,
14    pub secret_threshold: u8,
15    pub stored_shares: u8,
16    pub pgp_keys: serde_json::Value,
17    pub recovery_shares: u8,
18    pub recovery_threshold: u8,
19    pub recovery_pgp_keys: serde_json::Value,
20    pub root_token_pgp_key: String,
21}
22
23impl Default for InitRequest {
24    fn default() -> Self {
25        Self {
26            secret_shares: 3,
27            secret_threshold: 2,
28            stored_shares: 0,
29            pgp_keys: serde_json::Value::Null,
30            recovery_shares: 0,
31            recovery_threshold: 0,
32            recovery_pgp_keys: serde_json::Value::Null,
33            root_token_pgp_key: "".to_string(),
34        }
35    }
36}
37
38#[derive(Clone, Debug, serde::Deserialize)]
39pub struct InitResult {
40    pub keys: Vec<Secret<String>>,
41    pub keys_base64: Vec<Secret<String>>,
42    pub root_token: Secret<String>,
43}
44
45/// Init a vault process
46#[async_trait::async_trait]
47pub trait Init {
48    /// Init a vault process
49    async fn init(&mut self, req: InitRequest) -> anyhow::Result<InitResult>;
50}
51
52#[async_trait::async_trait]
53impl<T> Init for T
54where
55    T: HttpRequest<BytesBody> + Send + Sync + 'static,
56{
57    async fn init(&mut self, req: InitRequest) -> anyhow::Result<InitResult> {
58        let body = serde_json::ser::to_string(&req)?;
59
60        let http_req = init_request(Full::new(Bytes::from(body.to_string())).boxed())?;
61
62        let (parts, body) = self.send_request(http_req).await?.into_parts();
63
64        let body = String::from_utf8(body.to_vec())?;
65
66        if parts.status != hyper::StatusCode::OK {
67            return Err(anyhow::anyhow!("initializing: {}", body));
68        }
69
70        let response: InitResult = serde_json::from_str(&body)?;
71
72        Ok(response)
73    }
74}
75
76/// Join a vault process to a raft cluster
77#[async_trait::async_trait]
78pub trait RaftJoin {
79    /// Join a vault process to a raft cluster
80    async fn raft_join(&mut self, join_to: &str) -> anyhow::Result<()>;
81}
82
83#[async_trait::async_trait]
84impl<T> RaftJoin for T
85where
86    T: HttpRequest<BytesBody> + Send + Sync + 'static,
87{
88    async fn raft_join(&mut self, join_to: &str) -> anyhow::Result<()> {
89        let body = serde_json::json!({
90            "leader_api_addr": join_to,
91        });
92
93        let http_req = raft_join_request(Full::new(Bytes::from(body.to_string())).boxed())?;
94
95        let (parts, body) = self.send_request(http_req).await?.into_parts();
96
97        let body = String::from_utf8(body.to_vec())?;
98
99        if parts.status != hyper::StatusCode::OK {
100            return Err(anyhow::anyhow!("raft-joining: {}", body));
101        }
102
103        Ok(())
104    }
105}
106
107#[tracing::instrument(skip_all)]
108pub async fn init(domain: String, api: &Api<Pod>, pod_name: &str) -> anyhow::Result<InitResult> {
109    let pod = api.get(pod_name).await?;
110
111    info!("initializing: {}", pod_name);
112
113    let pods = PodApi::new(api.clone(), true, domain);
114    let mut pf = pods
115        .http(
116            pod.metadata
117                .name
118                .clone()
119                .ok_or(anyhow::anyhow!("pod does not have a name"))?
120                .as_str(),
121            VAULT_PORT,
122        )
123        .await?;
124    pf.ready().await?;
125
126    let body = serde_json::json!({
127        "secret_shares": 3,
128        "secret_threshold": 2,
129        "stored_shares": 0,
130        "pgp_keys": serde_json::Value::Null,
131        "recovery_shares": 0,
132        "recovery_threshold": 0,
133        "recovery_pgp_keys": serde_json::Value::Null,
134        "root_token_pgp_key": "",
135    });
136
137    let http_req = Request::builder()
138        .uri("/v1/sys/init")
139        .header("Host", "127.0.0.1")
140        .header("X-Vault-Request", "true")
141        .method(hyper::Method::PUT)
142        .body(Full::new(Bytes::from(body.to_string())).boxed())?;
143
144    let (parts, body) = pf.send_request(http_req).await?.into_parts();
145
146    let body = String::from_utf8(body.to_vec())?;
147
148    if parts.status != hyper::StatusCode::OK {
149        return Err(anyhow::anyhow!("{}", body));
150    }
151
152    let response: InitResult = serde_json::from_str(&body)?;
153
154    Ok(response)
155}
156
157#[tracing::instrument(skip_all)]
158pub async fn raft_join(
159    domain: String,
160    api: &Api<Pod>,
161    pod_name: &str,
162    join_to: &str,
163) -> anyhow::Result<()> {
164    let pod = api.get(pod_name).await?;
165
166    info!(
167        "raft joining: {} to {}",
168        pod.metadata
169            .name
170            .clone()
171            .ok_or(anyhow::anyhow!("pod does not have a name"))?,
172        join_to,
173    );
174
175    let pods = PodApi::new(api.clone(), true, domain);
176    let mut pf = pods
177        .http(
178            pod.metadata
179                .name
180                .clone()
181                .ok_or(anyhow::anyhow!("pod does not have a name"))?
182                .as_str(),
183            VAULT_PORT,
184        )
185        .await?;
186    pf.ready().await?;
187
188    let body = serde_json::json!({
189        "leader_api_addr": join_to,
190    });
191
192    let http_req = Request::builder()
193        .uri("/v1/sys/storage/raft/join")
194        .header("Host", "127.0.0.1")
195        .header("X-Vault-Request", "true")
196        .method(hyper::Method::POST)
197        .body(Full::new(Bytes::from(body.to_string())).boxed())?;
198
199    let (parts, body) = pf.send_request(http_req).await?.into_parts();
200
201    let body = String::from_utf8(body.to_vec())?;
202
203    if parts.status != hyper::StatusCode::OK {
204        return Err(anyhow::anyhow!("{}", body));
205    }
206
207    Ok(())
208}
209
#[cfg(test)]
mod tests {
    use http::{Method, StatusCode};
    use secrecy::ExposeSecret;
    use wiremock::{
        matchers::{body_json, header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use crate::{HttpForwarderService, Init, InitRequest, RaftJoin};

    /// `Init::init` must send the serialized default request to `PUT /v1/sys/init`
    /// and hand back the values from the response body.
    #[tokio::test]
    async fn init_calls_api() {
        let mock_server = MockServer::start().await;

        Mock::given(method(Method::PUT))
            .and(path("/v1/sys/init"))
            .and(header("X-Vault-Request", "true"))
            // Pin the request payload as well, mirroring the raft join test.
            .and(body_json(serde_json::json!({
                "secret_shares": 3,
                "secret_threshold": 2,
                "stored_shares": 0,
                "pgp_keys": null,
                "recovery_shares": 0,
                "recovery_threshold": 0,
                "recovery_pgp_keys": null,
                "root_token_pgp_key": "",
            })))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK).set_body_json(serde_json::json!({
                    "keys": vec!["abc"],
                    "keys_base64": vec!["YWJj"],
                    "root_token": "def",
                })),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let mut client = HttpForwarderService::http(
            tokio::net::TcpStream::connect(mock_server.uri().strip_prefix("http://").unwrap())
                .await
                .unwrap(),
        )
        .await
        .unwrap();

        // `expect` prints the underlying error on failure instead of a bare `false`.
        let result = client
            .init(InitRequest::default())
            .await
            .expect("init request should succeed");

        assert_eq!(result.keys.len(), 1);
        assert_eq!(result.keys[0].expose_secret().as_str(), "abc");
        assert_eq!(result.keys_base64.len(), 1);
        assert_eq!(result.keys_base64[0].expose_secret().as_str(), "YWJj");
        assert_eq!(result.root_token.expose_secret().as_str(), "def");
    }

    /// `RaftJoin::raft_join` must post the leader address to
    /// `POST /v1/sys/storage/raft/join`.
    #[tokio::test]
    async fn raft_join_calls_api() {
        let mock_server = MockServer::start().await;

        Mock::given(method(Method::POST))
            .and(path("/v1/sys/storage/raft/join"))
            .and(header("X-Vault-Request", "true"))
            .and(body_json(serde_json::json!({
                "leader_api_addr": "other-instance",
            })))
            .respond_with(ResponseTemplate::new(StatusCode::OK))
            .expect(1)
            .mount(&mock_server)
            .await;

        let mut client = HttpForwarderService::http(
            tokio::net::TcpStream::connect(mock_server.uri().strip_prefix("http://").unwrap())
                .await
                .unwrap(),
        )
        .await
        .unwrap();

        let outcome = client.raft_join("other-instance").await;

        assert!(outcome.is_ok(), "raft join failed: {:?}", outcome.err());
    }
}