1use http::Request;
2use http_body_util::{BodyExt, Full};
3use hyper::body::Bytes;
4use k8s_openapi::api::core::v1::Pod;
5use kube::Api;
6use secrecy::Secret;
7use tracing::*;
8
9use crate::{init_request, raft_join_request, BytesBody, HttpRequest, PodApi, VAULT_PORT};
10
11#[derive(Debug, serde::Serialize)]
12pub struct InitRequest {
13 pub secret_shares: u8,
14 pub secret_threshold: u8,
15 pub stored_shares: u8,
16 pub pgp_keys: serde_json::Value,
17 pub recovery_shares: u8,
18 pub recovery_threshold: u8,
19 pub recovery_pgp_keys: serde_json::Value,
20 pub root_token_pgp_key: String,
21}
22
23impl Default for InitRequest {
24 fn default() -> Self {
25 Self {
26 secret_shares: 3,
27 secret_threshold: 2,
28 stored_shares: 0,
29 pgp_keys: serde_json::Value::Null,
30 recovery_shares: 0,
31 recovery_threshold: 0,
32 recovery_pgp_keys: serde_json::Value::Null,
33 root_token_pgp_key: "".to_string(),
34 }
35 }
36}
37
/// Decoded JSON response of a successful init call.
///
/// Every field is wrapped in `secrecy::Secret` so the key material and root
/// token are not casually exposed (e.g. through the derived `Debug`).
#[derive(Clone, Debug, serde::Deserialize)]
pub struct InitResult {
    /// Keys as returned in the response's `keys` array.
    pub keys: Vec<Secret<String>>,
    /// Keys as returned in the response's `keys_base64` array.
    pub keys_base64: Vec<Secret<String>>,
    /// Token from the response's `root_token` field.
    pub root_token: Secret<String>,
}
44
/// Something that can initialize a Vault instance.
///
/// A blanket implementation exists for every `HttpRequest<BytesBody>` client.
#[async_trait::async_trait]
pub trait Init {
    /// Send `req` to the init endpoint and return the decoded [`InitResult`].
    async fn init(&mut self, req: InitRequest) -> anyhow::Result<InitResult>;
}
51
52#[async_trait::async_trait]
53impl<T> Init for T
54where
55 T: HttpRequest<BytesBody> + Send + Sync + 'static,
56{
57 async fn init(&mut self, req: InitRequest) -> anyhow::Result<InitResult> {
58 let body = serde_json::ser::to_string(&req)?;
59
60 let http_req = init_request(Full::new(Bytes::from(body.to_string())).boxed())?;
61
62 let (parts, body) = self.send_request(http_req).await?.into_parts();
63
64 let body = String::from_utf8(body.to_vec())?;
65
66 if parts.status != hyper::StatusCode::OK {
67 return Err(anyhow::anyhow!("initializing: {}", body));
68 }
69
70 let response: InitResult = serde_json::from_str(&body)?;
71
72 Ok(response)
73 }
74}
75
/// Something that can ask a Vault instance to join an existing raft cluster.
///
/// A blanket implementation exists for every `HttpRequest<BytesBody>` client.
#[async_trait::async_trait]
pub trait RaftJoin {
    /// Request a join, sending `join_to` as the `leader_api_addr`.
    async fn raft_join(&mut self, join_to: &str) -> anyhow::Result<()>;
}
82
83#[async_trait::async_trait]
84impl<T> RaftJoin for T
85where
86 T: HttpRequest<BytesBody> + Send + Sync + 'static,
87{
88 async fn raft_join(&mut self, join_to: &str) -> anyhow::Result<()> {
89 let body = serde_json::json!({
90 "leader_api_addr": join_to,
91 });
92
93 let http_req = raft_join_request(Full::new(Bytes::from(body.to_string())).boxed())?;
94
95 let (parts, body) = self.send_request(http_req).await?.into_parts();
96
97 let body = String::from_utf8(body.to_vec())?;
98
99 if parts.status != hyper::StatusCode::OK {
100 return Err(anyhow::anyhow!("raft-joining: {}", body));
101 }
102
103 Ok(())
104 }
105}
106
107#[tracing::instrument(skip_all)]
108pub async fn init(domain: String, api: &Api<Pod>, pod_name: &str) -> anyhow::Result<InitResult> {
109 let pod = api.get(pod_name).await?;
110
111 info!("initializing: {}", pod_name);
112
113 let pods = PodApi::new(api.clone(), true, domain);
114 let mut pf = pods
115 .http(
116 pod.metadata
117 .name
118 .clone()
119 .ok_or(anyhow::anyhow!("pod does not have a name"))?
120 .as_str(),
121 VAULT_PORT,
122 )
123 .await?;
124 pf.ready().await?;
125
126 let body = serde_json::json!({
127 "secret_shares": 3,
128 "secret_threshold": 2,
129 "stored_shares": 0,
130 "pgp_keys": serde_json::Value::Null,
131 "recovery_shares": 0,
132 "recovery_threshold": 0,
133 "recovery_pgp_keys": serde_json::Value::Null,
134 "root_token_pgp_key": "",
135 });
136
137 let http_req = Request::builder()
138 .uri("/v1/sys/init")
139 .header("Host", "127.0.0.1")
140 .header("X-Vault-Request", "true")
141 .method(hyper::Method::PUT)
142 .body(Full::new(Bytes::from(body.to_string())).boxed())?;
143
144 let (parts, body) = pf.send_request(http_req).await?.into_parts();
145
146 let body = String::from_utf8(body.to_vec())?;
147
148 if parts.status != hyper::StatusCode::OK {
149 return Err(anyhow::anyhow!("{}", body));
150 }
151
152 let response: InitResult = serde_json::from_str(&body)?;
153
154 Ok(response)
155}
156
157#[tracing::instrument(skip_all)]
158pub async fn raft_join(
159 domain: String,
160 api: &Api<Pod>,
161 pod_name: &str,
162 join_to: &str,
163) -> anyhow::Result<()> {
164 let pod = api.get(pod_name).await?;
165
166 info!(
167 "raft joining: {} to {}",
168 pod.metadata
169 .name
170 .clone()
171 .ok_or(anyhow::anyhow!("pod does not have a name"))?,
172 join_to,
173 );
174
175 let pods = PodApi::new(api.clone(), true, domain);
176 let mut pf = pods
177 .http(
178 pod.metadata
179 .name
180 .clone()
181 .ok_or(anyhow::anyhow!("pod does not have a name"))?
182 .as_str(),
183 VAULT_PORT,
184 )
185 .await?;
186 pf.ready().await?;
187
188 let body = serde_json::json!({
189 "leader_api_addr": join_to,
190 });
191
192 let http_req = Request::builder()
193 .uri("/v1/sys/storage/raft/join")
194 .header("Host", "127.0.0.1")
195 .header("X-Vault-Request", "true")
196 .method(hyper::Method::POST)
197 .body(Full::new(Bytes::from(body.to_string())).boxed())?;
198
199 let (parts, body) = pf.send_request(http_req).await?.into_parts();
200
201 let body = String::from_utf8(body.to_vec())?;
202
203 if parts.status != hyper::StatusCode::OK {
204 return Err(anyhow::anyhow!("{}", body));
205 }
206
207 Ok(())
208}
209
#[cfg(test)]
mod tests {
    use http::{Method, StatusCode};
    use wiremock::{
        matchers::{body_json, header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use crate::{HttpForwarderService, Init, InitRequest, RaftJoin};

    /// `Init::init` must `PUT` the serialized default request to
    /// `/v1/sys/init` and decode the returned keys.
    #[tokio::test]
    async fn init_calls_api() {
        let mock_server = MockServer::start().await;

        Mock::given(method(Method::PUT))
            .and(path("/v1/sys/init"))
            .and(header("X-Vault-Request", "true"))
            .and(body_json(serde_json::json!({
                "secret_shares": 3,
                "secret_threshold": 2,
                "stored_shares": 0,
                "pgp_keys": serde_json::Value::Null,
                "recovery_shares": 0,
                "recovery_threshold": 0,
                "recovery_pgp_keys": serde_json::Value::Null,
                "root_token_pgp_key": "",
            })))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK).set_body_json(serde_json::json!({
                    "keys": vec!["abc"],
                    "keys_base64": vec!["YWJj"],
                    "root_token": "def",
                })),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let mut client = HttpForwarderService::http(
            tokio::net::TcpStream::connect(mock_server.uri().strip_prefix("http://").unwrap())
                .await
                .unwrap(),
        )
        .await
        .unwrap();

        // `expect` (instead of `assert!(is_ok())`) prints the error on failure.
        let result = client
            .init(InitRequest::default())
            .await
            .expect("init should succeed against the mock server");

        assert_eq!(result.keys.len(), 1);
        assert_eq!(result.keys_base64.len(), 1);
    }

    /// `RaftJoin::raft_join` must `POST` the leader address to
    /// `/v1/sys/storage/raft/join`.
    #[tokio::test]
    async fn raft_join_calls_api() {
        let mock_server = MockServer::start().await;

        Mock::given(method(Method::POST))
            .and(path("/v1/sys/storage/raft/join"))
            .and(header("X-Vault-Request", "true"))
            .and(body_json(serde_json::json!({
                "leader_api_addr": "other-instance",
            })))
            .respond_with(ResponseTemplate::new(StatusCode::OK))
            .expect(1)
            .mount(&mock_server)
            .await;

        let mut client = HttpForwarderService::http(
            tokio::net::TcpStream::connect(mock_server.uri().strip_prefix("http://").unwrap())
                .await
                .unwrap(),
        )
        .await
        .unwrap();

        client
            .raft_join("other-instance")
            .await
            .expect("raft_join should succeed against the mock server");
    }
}