1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::credentials::TempCredentials;
7use crate::error::{AvError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "method", rename_all = "lowercase")]
12pub enum VaultAuth {
13 Token { token: Option<String> },
15 Approle {
17 role_id: String,
18 secret_id: Option<String>,
19 },
20 Kubernetes {
22 role: String,
23 jwt_path: Option<String>,
25 },
26}
27
28impl Default for VaultAuth {
29 fn default() -> Self {
30 VaultAuth::Token { token: None }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct VaultConfig {
37 pub address: Option<String>,
40 #[serde(default)]
42 pub auth: VaultAuth,
43 pub mount: Option<String>,
45 pub role: Option<String>,
47 pub namespace: Option<String>,
49 #[serde(default)]
51 pub tls_skip_verify: bool,
52}
53
54impl VaultConfig {
55 pub fn resolve_address(&self) -> Result<String> {
57 self.address
58 .clone()
59 .or_else(|| std::env::var("VAULT_ADDR").ok())
60 .ok_or_else(|| {
61 AvError::InvalidPolicy(
62 "Vault address not set. Set vault.address in config or VAULT_ADDR env var"
63 .to_string(),
64 )
65 })
66 }
67
68 pub fn mount_path(&self) -> &str {
70 self.mount.as_deref().unwrap_or("aws")
71 }
72}
73
74#[derive(Debug, Deserialize)]
76struct VaultAuthResponse {
77 auth: Option<VaultAuthData>,
78}
79
80#[derive(Debug, Deserialize)]
81struct VaultAuthData {
82 client_token: String,
83}
84
85#[derive(Debug, Deserialize)]
87struct VaultSecretResponse {
88 data: Option<VaultAwsCredentials>,
89 lease_duration: Option<u64>,
90}
91
92#[derive(Debug, Deserialize)]
93struct VaultAwsCredentials {
94 access_key: String,
95 secret_key: String,
96 security_token: Option<String>,
97}
98
/// Requests temporary AWS credentials from a Vault AWS secrets engine.
///
/// Holds an already-obtained Vault client token; build one with `VaultIssuer::new`.
/// It does not implement `Debug`, so the token cannot leak through `{:?}` formatting.
pub struct VaultIssuer {
    /// Base URL of the Vault server, e.g. `http://127.0.0.1:8200`.
    address: String,
    /// Client token sent as `X-Vault-Token` on credential requests.
    token: String,
    /// Mount path of the AWS secrets engine, e.g. `aws`.
    mount: String,
    /// Optional Vault namespace, sent as `X-Vault-Namespace`.
    namespace: Option<String>,
}
106
107impl VaultIssuer {
108 pub async fn new(config: &VaultConfig) -> Result<Self> {
110 let address = config.resolve_address()?;
111 let mount = config.mount_path().to_string();
112 let namespace = config.namespace.clone();
113
114 let token = match &config.auth {
115 VaultAuth::Token { token } => {
116 token
117 .clone()
118 .or_else(|| std::env::var("VAULT_TOKEN").ok())
119 .ok_or_else(|| {
120 AvError::InvalidPolicy(
121 "Vault token not set. Set vault.auth.token in config or VAULT_TOKEN env var".to_string(),
122 )
123 })?
124 }
125 VaultAuth::Approle { role_id, secret_id } => {
126 Self::auth_approle(&address, namespace.as_deref(), role_id, secret_id.as_deref())
127 .await?
128 }
129 VaultAuth::Kubernetes { role, jwt_path } => {
130 let default_path = "/var/run/secrets/kubernetes.io/serviceaccount/token";
131 let path = jwt_path.as_deref().unwrap_or(default_path);
132 let jwt = std::fs::read_to_string(path).map_err(|e| {
133 AvError::InvalidPolicy(format!(
134 "Failed to read Kubernetes service account token from {}: {}",
135 path, e
136 ))
137 })?;
138 Self::auth_kubernetes(&address, namespace.as_deref(), role, &jwt).await?
139 }
140 };
141
142 tracing::info!(address = %address, mount = %mount, "Connected to Vault");
143
144 Ok(Self {
145 address,
146 token,
147 mount,
148 namespace,
149 })
150 }
151
152 async fn auth_approle(
154 address: &str,
155 namespace: Option<&str>,
156 role_id: &str,
157 secret_id: Option<&str>,
158 ) -> Result<String> {
159 let url = format!("{}/v1/auth/approle/login", address);
160 let mut body = HashMap::new();
161 body.insert("role_id", role_id);
162 if let Some(sid) = secret_id {
163 body.insert("secret_id", sid);
164 }
165
166 let resp = vault_post(&url, namespace, None, &body).await?;
167 let auth_resp: VaultAuthResponse = serde_json::from_str(&resp)
168 .map_err(|e| AvError::Sts(format!("Failed to parse Vault auth response: {}", e)))?;
169
170 auth_resp
171 .auth
172 .map(|a| a.client_token)
173 .ok_or_else(|| AvError::Sts("Vault AppRole auth returned no token".to_string()))
174 }
175
176 async fn auth_kubernetes(
178 address: &str,
179 namespace: Option<&str>,
180 role: &str,
181 jwt: &str,
182 ) -> Result<String> {
183 let url = format!("{}/v1/auth/kubernetes/login", address);
184 let mut body = HashMap::new();
185 body.insert("role", role);
186 body.insert("jwt", jwt);
187
188 let resp = vault_post(&url, namespace, None, &body).await?;
189 let auth_resp: VaultAuthResponse = serde_json::from_str(&resp)
190 .map_err(|e| AvError::Sts(format!("Failed to parse Vault auth response: {}", e)))?;
191
192 auth_resp
193 .auth
194 .map(|a| a.client_token)
195 .ok_or_else(|| AvError::Sts("Vault Kubernetes auth returned no token".to_string()))
196 }
197
198 pub async fn issue(&self, vault_role: &str, ttl: Duration) -> Result<TempCredentials> {
203 let url = format!("{}/v1/{}/creds/{}", self.address, self.mount, vault_role);
204
205 let ttl_str = format!("{}s", ttl.as_secs());
206 let mut body = HashMap::new();
207 body.insert("ttl", ttl_str.as_str());
208
209 tracing::info!(
210 vault_role = %vault_role,
211 ttl = %ttl_str,
212 "Requesting credentials from Vault AWS secrets engine"
213 );
214
215 let resp = vault_post(&url, self.namespace.as_deref(), Some(&self.token), &body).await?;
216 let secret: VaultSecretResponse = serde_json::from_str(&resp).map_err(|e| {
217 AvError::Sts(format!("Failed to parse Vault credential response: {}", e))
218 })?;
219
220 let creds = secret
221 .data
222 .ok_or_else(|| AvError::Sts("Vault returned no credential data".to_string()))?;
223
224 let lease_secs = secret.lease_duration.unwrap_or(ttl.as_secs());
225 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(lease_secs as i64);
226
227 Ok(TempCredentials {
228 access_key_id: creds.access_key,
229 secret_access_key: creds.secret_key,
230 session_token: creds.security_token.unwrap_or_default(),
231 expires_at,
232 })
233 }
234
235 pub async fn read_sts_creds(&self, vault_role: &str, ttl: Duration) -> Result<TempCredentials> {
238 let url = format!("{}/v1/{}/sts/{}", self.address, self.mount, vault_role);
239
240 let ttl_str = format!("{}s", ttl.as_secs());
241 let mut body = HashMap::new();
242 body.insert("ttl", ttl_str.as_str());
243
244 tracing::info!(
245 vault_role = %vault_role,
246 ttl = %ttl_str,
247 "Requesting STS credentials from Vault"
248 );
249
250 let resp = vault_post(&url, self.namespace.as_deref(), Some(&self.token), &body).await?;
251 let secret: VaultSecretResponse = serde_json::from_str(&resp)
252 .map_err(|e| AvError::Sts(format!("Failed to parse Vault STS response: {}", e)))?;
253
254 let creds = secret
255 .data
256 .ok_or_else(|| AvError::Sts("Vault returned no STS credential data".to_string()))?;
257
258 let lease_secs = secret.lease_duration.unwrap_or(ttl.as_secs());
259 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(lease_secs as i64);
260
261 Ok(TempCredentials {
262 access_key_id: creds.access_key,
263 secret_access_key: creds.secret_key,
264 session_token: creds.security_token.unwrap_or_default(),
265 expires_at,
266 })
267 }
268
269 pub async fn health_check(&self) -> Result<bool> {
271 let url = format!("{}/v1/sys/health", self.address);
272 match vault_get(&url, self.namespace.as_deref(), Some(&self.token)).await {
273 Ok(_) => Ok(true),
274 Err(_) => Ok(false),
275 }
276 }
277}
278
279async fn vault_post(
281 url: &str,
282 namespace: Option<&str>,
283 token: Option<&str>,
284 body: &HashMap<&str, &str>,
285) -> Result<String> {
286 let body_json = serde_json::to_string(body)
287 .map_err(|e| AvError::Sts(format!("Failed to serialize Vault request: {}", e)))?;
288
289 let mut headers = vec![("Content-Type", "application/json")];
290 let ns_header;
291 if let Some(ns) = namespace {
292 ns_header = ns.to_string();
293 headers.push(("X-Vault-Namespace", &ns_header));
294 }
295 let token_header;
296 if let Some(t) = token {
297 token_header = t.to_string();
298 headers.push(("X-Vault-Token", &token_header));
299 }
300
301 http_request("POST", url, &headers, Some(&body_json)).await
302}
303
304async fn vault_get(url: &str, namespace: Option<&str>, token: Option<&str>) -> Result<String> {
306 let mut headers: Vec<(&str, &str)> = vec![];
307 let ns_header;
308 if let Some(ns) = namespace {
309 ns_header = ns.to_string();
310 headers.push(("X-Vault-Namespace", &ns_header));
311 }
312 let token_header;
313 if let Some(t) = token {
314 token_header = t.to_string();
315 headers.push(("X-Vault-Token", &token_header));
316 }
317
318 http_request("GET", url, &headers, None).await
319}
320
321async fn http_request(
324 method: &str,
325 url: &str,
326 headers: &[(&str, &str)],
327 body: Option<&str>,
328) -> Result<String> {
329 use std::io::{Read, Write};
330 use std::net::TcpStream;
331
332 let parsed = parse_url(url)?;
333
334 let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
335 .map_err(|e| AvError::Sts(format!("Failed to connect to Vault at {}: {}", url, e)))?;
336 stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
337
338 let mut request = format!(
339 "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n",
340 method, parsed.path, parsed.host
341 );
342 for (key, value) in headers {
343 request.push_str(&format!("{}: {}\r\n", key, value));
344 }
345 if let Some(b) = body {
346 request.push_str(&format!("Content-Length: {}\r\n", b.len()));
347 }
348 request.push_str("\r\n");
349 if let Some(b) = body {
350 request.push_str(b);
351 }
352
353 let mut stream = stream;
356 stream
357 .write_all(request.as_bytes())
358 .map_err(|e| AvError::Sts(format!("Failed to send request to Vault: {}", e)))?;
359
360 let mut response = String::new();
361 stream
362 .read_to_string(&mut response)
363 .map_err(|e| AvError::Sts(format!("Failed to read Vault response: {}", e)))?;
364
365 if let Some(idx) = response.find("\r\n\r\n") {
367 let status_line = response.lines().next().unwrap_or("");
368 let status_code: u16 = status_line
369 .split_whitespace()
370 .nth(1)
371 .and_then(|s| s.parse().ok())
372 .unwrap_or(0);
373
374 if status_code >= 400 {
375 return Err(AvError::Sts(format!(
376 "Vault returned HTTP {}: {}",
377 status_code,
378 &response[idx + 4..]
379 )));
380 }
381
382 Ok(response[idx + 4..].to_string())
383 } else {
384 Err(AvError::Sts("Invalid HTTP response from Vault".to_string()))
385 }
386}
387
/// Pieces of a Vault URL as produced by `parse_url`.
struct ParsedUrl {
    /// Host name or IP literal, without scheme or port.
    host: String,
    /// TCP port: the explicit one from the URL, or the 8200 default.
    port: u16,
    /// Request target starting at the first `/` after the authority (`/` when absent),
    /// including any query string.
    path: String,
}
393
394fn parse_url(url: &str) -> Result<ParsedUrl> {
395 let without_scheme = if let Some(rest) = url.strip_prefix("https://") {
396 rest
397 } else if let Some(rest) = url.strip_prefix("http://") {
398 rest
399 } else {
400 return Err(AvError::InvalidPolicy(format!(
401 "Invalid Vault URL: {}",
402 url
403 )));
404 };
405
406 let default_port: u16 = 8200;
407
408 let (host_port, path) = match without_scheme.find('/') {
409 Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
410 None => (without_scheme, "/"),
411 };
412
413 let (host, port) = match host_port.rfind(':') {
414 Some(idx) => {
415 let port_str = &host_port[idx + 1..];
416 let port = port_str.parse::<u16>().unwrap_or(default_port);
417 (host_port[..idx].to_string(), port)
418 }
419 None => (host_port.to_string(), default_port),
420 };
421
422 Ok(ParsedUrl {
423 host,
424 port,
425 path: path.to_string(),
426 })
427}
428
#[cfg(test)]
mod tests {
    //! Offline unit tests for URL parsing and `VaultConfig` handling; nothing here talks to
    //! a Vault server.
    use super::*;

    #[test]
    fn test_parse_url_with_port() {
        let parsed = parse_url("http://vault.example.com:8200/v1/aws/creds/my-role").unwrap();
        assert_eq!(parsed.host, "vault.example.com");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/v1/aws/creds/my-role");
    }

    #[test]
    fn test_parse_url_without_port() {
        // A URL without a port falls back to Vault's 8200, even for the https scheme.
        let parsed = parse_url("https://vault.example.com/v1/sys/health").unwrap();
        assert_eq!(parsed.host, "vault.example.com");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/v1/sys/health");
    }

    #[test]
    fn test_parse_url_without_path() {
        let parsed = parse_url("http://vault:8200").unwrap();
        assert_eq!(parsed.host, "vault");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/");
    }

    #[test]
    fn test_parse_url_localhost() {
        let parsed = parse_url("http://127.0.0.1:8200/v1/auth/approle/login").unwrap();
        assert_eq!(parsed.host, "127.0.0.1");
        assert_eq!(parsed.port, 8200);
    }

    #[test]
    fn test_parse_url_invalid() {
        assert!(parse_url("ftp://vault.example.com").is_err());
    }

    #[test]
    fn test_vault_config_defaults() {
        let config = VaultConfig::default();
        assert_eq!(config.mount_path(), "aws");
        assert!(!config.tls_skip_verify);
        assert!(config.address.is_none());
    }

    // Uses an explicit `address` rather than `VAULT_ADDR`: environment variables are
    // process-global and the test harness runs tests in parallel, so setting one here
    // could make other tests flaky.
    #[test]
    fn test_vault_config_resolve_address_explicit() {
        let config = VaultConfig {
            address: Some("http://localhost:8200".to_string()),
            ..Default::default()
        };
        assert_eq!(config.resolve_address().unwrap(), "http://localhost:8200");
    }

    #[test]
    fn test_vault_config_custom_mount() {
        let config = VaultConfig {
            mount: Some("aws-prod".to_string()),
            ..Default::default()
        };
        assert_eq!(config.mount_path(), "aws-prod");
    }

    #[test]
    fn test_vault_config_deserialize() {
        let toml_str = r#"
address = "https://vault.internal:8200"
mount = "aws-prod"
role = "audex-agent"
namespace = "engineering"

[auth]
method = "approle"
role_id = "abc-123"
secret_id = "def-456"
"#;
        let config: VaultConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(
            config.address.as_deref(),
            Some("https://vault.internal:8200")
        );
        assert_eq!(config.mount_path(), "aws-prod");
        assert_eq!(config.role.as_deref(), Some("audex-agent"));
        assert_eq!(config.namespace.as_deref(), Some("engineering"));
        match config.auth {
            VaultAuth::Approle { role_id, secret_id } => {
                assert_eq!(role_id, "abc-123");
                assert_eq!(secret_id.unwrap(), "def-456");
            }
            _ => panic!("Expected AppRole auth"),
        }
    }

    #[test]
    fn test_vault_config_kubernetes_auth() {
        let toml_str = r#"
address = "http://vault:8200"

[auth]
method = "kubernetes"
role = "audex"
jwt_path = "/var/run/secrets/token"
"#;
        let config: VaultConfig = toml::from_str(toml_str).unwrap();
        match config.auth {
            VaultAuth::Kubernetes { role, jwt_path } => {
                assert_eq!(role, "audex");
                assert_eq!(jwt_path.unwrap(), "/var/run/secrets/token");
            }
            _ => panic!("Expected Kubernetes auth"),
        }
    }

    // An omitted `[auth]` table must fall back to token auth via `#[serde(default)]`.
    #[test]
    fn test_vault_config_auth_defaults_to_token() {
        let config: VaultConfig = toml::from_str("").unwrap();
        assert!(config.address.is_none());
        match config.auth {
            VaultAuth::Token { token } => assert!(token.is_none()),
            _ => panic!("Expected token auth"),
        }
    }
}