1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::credentials::TempCredentials;
7use crate::error::{AvError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "method", rename_all = "lowercase")]
12pub enum VaultAuth {
13 Token { token: Option<String> },
15 Approle {
17 role_id: String,
18 secret_id: Option<String>,
19 },
20 Kubernetes {
22 role: String,
23 jwt_path: Option<String>,
25 },
26}
27
28impl Default for VaultAuth {
29 fn default() -> Self {
30 VaultAuth::Token { token: None }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct VaultConfig {
37 pub address: Option<String>,
40 #[serde(default)]
42 pub auth: VaultAuth,
43 pub mount: Option<String>,
45 pub role: Option<String>,
47 pub namespace: Option<String>,
49 #[serde(default)]
51 pub tls_skip_verify: bool,
52}
53
54impl VaultConfig {
55 pub fn resolve_address(&self) -> Result<String> {
57 self.address
58 .clone()
59 .or_else(|| std::env::var("VAULT_ADDR").ok())
60 .ok_or_else(|| {
61 AvError::InvalidPolicy(
62 "Vault address not set. Set vault.address in config or VAULT_ADDR env var"
63 .to_string(),
64 )
65 })
66 }
67
68 pub fn mount_path(&self) -> &str {
70 self.mount.as_deref().unwrap_or("aws")
71 }
72}
73
74#[derive(Debug, Deserialize)]
76struct VaultAuthResponse {
77 auth: Option<VaultAuthData>,
78}
79
80#[derive(Debug, Deserialize)]
81struct VaultAuthData {
82 client_token: String,
83}
84
85#[derive(Debug, Deserialize)]
87struct VaultSecretResponse {
88 data: Option<VaultAwsCredentials>,
89 lease_duration: Option<u64>,
90}
91
92#[derive(Debug, Deserialize)]
93struct VaultAwsCredentials {
94 access_key: String,
95 secret_key: String,
96 security_token: Option<String>,
97}
98
/// Authenticated client for requesting AWS credentials from Vault.
///
/// Built by `VaultIssuer::new`, which resolves the address and performs the login.
pub struct VaultIssuer {
    /// Base URL of the Vault server, e.g. `http://127.0.0.1:8200`.
    address: String,
    /// Client token sent as `X-Vault-Token` on credential requests.
    token: String,
    /// Mount path of the AWS secrets engine, e.g. `aws`.
    mount: String,
    /// Optional namespace sent as `X-Vault-Namespace`.
    namespace: Option<String>,
}
106
107impl VaultIssuer {
108 pub async fn new(config: &VaultConfig) -> Result<Self> {
110 let address = config.resolve_address()?;
111 let mount = config.mount_path().to_string();
112 let namespace = config.namespace.clone();
113
114 let token = match &config.auth {
115 VaultAuth::Token { token } => {
116 token
117 .clone()
118 .or_else(|| std::env::var("VAULT_TOKEN").ok())
119 .ok_or_else(|| {
120 AvError::InvalidPolicy(
121 "Vault token not set. Set vault.auth.token in config or VAULT_TOKEN env var".to_string(),
122 )
123 })?
124 }
125 VaultAuth::Approle { role_id, secret_id } => {
126 Self::auth_approle(&address, namespace.as_deref(), role_id, secret_id.as_deref())
127 .await?
128 }
129 VaultAuth::Kubernetes { role, jwt_path } => {
130 let default_path = "/var/run/secrets/kubernetes.io/serviceaccount/token";
131 let path = jwt_path.as_deref().unwrap_or(default_path);
132 let jwt = std::fs::read_to_string(path).map_err(|e| {
133 AvError::InvalidPolicy(format!(
134 "Failed to read Kubernetes service account token from {}: {}",
135 path, e
136 ))
137 })?;
138 Self::auth_kubernetes(&address, namespace.as_deref(), role, &jwt).await?
139 }
140 };
141
142 tracing::info!(address = %address, mount = %mount, "Connected to Vault");
143
144 Ok(Self {
145 address,
146 token,
147 mount,
148 namespace,
149 })
150 }
151
152 async fn auth_approle(
154 address: &str,
155 namespace: Option<&str>,
156 role_id: &str,
157 secret_id: Option<&str>,
158 ) -> Result<String> {
159 let url = format!("{}/v1/auth/approle/login", address);
160 let mut body = HashMap::new();
161 body.insert("role_id", role_id);
162 if let Some(sid) = secret_id {
163 body.insert("secret_id", sid);
164 }
165
166 let resp = vault_post(&url, namespace, None, &body).await?;
167 let auth_resp: VaultAuthResponse = serde_json::from_str(&resp)
168 .map_err(|e| AvError::Sts(format!("Failed to parse Vault auth response: {}", e)))?;
169
170 auth_resp
171 .auth
172 .map(|a| a.client_token)
173 .ok_or_else(|| AvError::Sts("Vault AppRole auth returned no token".to_string()))
174 }
175
176 async fn auth_kubernetes(
178 address: &str,
179 namespace: Option<&str>,
180 role: &str,
181 jwt: &str,
182 ) -> Result<String> {
183 let url = format!("{}/v1/auth/kubernetes/login", address);
184 let mut body = HashMap::new();
185 body.insert("role", role);
186 body.insert("jwt", jwt);
187
188 let resp = vault_post(&url, namespace, None, &body).await?;
189 let auth_resp: VaultAuthResponse = serde_json::from_str(&resp)
190 .map_err(|e| AvError::Sts(format!("Failed to parse Vault auth response: {}", e)))?;
191
192 auth_resp
193 .auth
194 .map(|a| a.client_token)
195 .ok_or_else(|| AvError::Sts("Vault Kubernetes auth returned no token".to_string()))
196 }
197
198 pub async fn issue(
203 &self,
204 vault_role: &str,
205 ttl: Duration,
206 ) -> Result<TempCredentials> {
207 let url = format!(
208 "{}/v1/{}/creds/{}",
209 self.address, self.mount, vault_role
210 );
211
212 let ttl_str = format!("{}s", ttl.as_secs());
213 let mut body = HashMap::new();
214 body.insert("ttl", ttl_str.as_str());
215
216 tracing::info!(
217 vault_role = %vault_role,
218 ttl = %ttl_str,
219 "Requesting credentials from Vault AWS secrets engine"
220 );
221
222 let resp =
223 vault_post(&url, self.namespace.as_deref(), Some(&self.token), &body).await?;
224 let secret: VaultSecretResponse = serde_json::from_str(&resp)
225 .map_err(|e| AvError::Sts(format!("Failed to parse Vault credential response: {}", e)))?;
226
227 let creds = secret.data.ok_or_else(|| {
228 AvError::Sts("Vault returned no credential data".to_string())
229 })?;
230
231 let lease_secs = secret.lease_duration.unwrap_or(ttl.as_secs());
232 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(lease_secs as i64);
233
234 Ok(TempCredentials {
235 access_key_id: creds.access_key,
236 secret_access_key: creds.secret_key,
237 session_token: creds.security_token.unwrap_or_default(),
238 expires_at,
239 })
240 }
241
242 pub async fn read_sts_creds(
245 &self,
246 vault_role: &str,
247 ttl: Duration,
248 ) -> Result<TempCredentials> {
249 let url = format!(
250 "{}/v1/{}/sts/{}",
251 self.address, self.mount, vault_role
252 );
253
254 let ttl_str = format!("{}s", ttl.as_secs());
255 let mut body = HashMap::new();
256 body.insert("ttl", ttl_str.as_str());
257
258 tracing::info!(
259 vault_role = %vault_role,
260 ttl = %ttl_str,
261 "Requesting STS credentials from Vault"
262 );
263
264 let resp =
265 vault_post(&url, self.namespace.as_deref(), Some(&self.token), &body).await?;
266 let secret: VaultSecretResponse = serde_json::from_str(&resp)
267 .map_err(|e| AvError::Sts(format!("Failed to parse Vault STS response: {}", e)))?;
268
269 let creds = secret.data.ok_or_else(|| {
270 AvError::Sts("Vault returned no STS credential data".to_string())
271 })?;
272
273 let lease_secs = secret.lease_duration.unwrap_or(ttl.as_secs());
274 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(lease_secs as i64);
275
276 Ok(TempCredentials {
277 access_key_id: creds.access_key,
278 secret_access_key: creds.secret_key,
279 session_token: creds.security_token.unwrap_or_default(),
280 expires_at,
281 })
282 }
283
284 pub async fn health_check(&self) -> Result<bool> {
286 let url = format!("{}/v1/sys/health", self.address);
287 match vault_get(&url, self.namespace.as_deref(), Some(&self.token)).await {
288 Ok(_) => Ok(true),
289 Err(_) => Ok(false),
290 }
291 }
292}
293
294async fn vault_post(
296 url: &str,
297 namespace: Option<&str>,
298 token: Option<&str>,
299 body: &HashMap<&str, &str>,
300) -> Result<String> {
301 let body_json = serde_json::to_string(body)
302 .map_err(|e| AvError::Sts(format!("Failed to serialize Vault request: {}", e)))?;
303
304 let mut headers = vec![("Content-Type", "application/json")];
305 let ns_header;
306 if let Some(ns) = namespace {
307 ns_header = ns.to_string();
308 headers.push(("X-Vault-Namespace", &ns_header));
309 }
310 let token_header;
311 if let Some(t) = token {
312 token_header = t.to_string();
313 headers.push(("X-Vault-Token", &token_header));
314 }
315
316 http_request("POST", url, &headers, Some(&body_json)).await
317}
318
319async fn vault_get(
321 url: &str,
322 namespace: Option<&str>,
323 token: Option<&str>,
324) -> Result<String> {
325 let mut headers: Vec<(&str, &str)> = vec![];
326 let ns_header;
327 if let Some(ns) = namespace {
328 ns_header = ns.to_string();
329 headers.push(("X-Vault-Namespace", &ns_header));
330 }
331 let token_header;
332 if let Some(t) = token {
333 token_header = t.to_string();
334 headers.push(("X-Vault-Token", &token_header));
335 }
336
337 http_request("GET", url, &headers, None).await
338}
339
340async fn http_request(
343 method: &str,
344 url: &str,
345 headers: &[(&str, &str)],
346 body: Option<&str>,
347) -> Result<String> {
348 use std::io::{Read, Write};
349 use std::net::TcpStream;
350
351 let parsed = parse_url(url)?;
352
353 let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
354 .map_err(|e| AvError::Sts(format!("Failed to connect to Vault at {}: {}", url, e)))?;
355 stream
356 .set_read_timeout(Some(Duration::from_secs(30)))
357 .ok();
358
359 let mut request = format!(
360 "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n",
361 method, parsed.path, parsed.host
362 );
363 for (key, value) in headers {
364 request.push_str(&format!("{}: {}\r\n", key, value));
365 }
366 if let Some(b) = body {
367 request.push_str(&format!("Content-Length: {}\r\n", b.len()));
368 }
369 request.push_str("\r\n");
370 if let Some(b) = body {
371 request.push_str(b);
372 }
373
374 let mut stream = stream;
377 stream
378 .write_all(request.as_bytes())
379 .map_err(|e| AvError::Sts(format!("Failed to send request to Vault: {}", e)))?;
380
381 let mut response = String::new();
382 stream
383 .read_to_string(&mut response)
384 .map_err(|e| AvError::Sts(format!("Failed to read Vault response: {}", e)))?;
385
386 if let Some(idx) = response.find("\r\n\r\n") {
388 let status_line = response.lines().next().unwrap_or("");
389 let status_code: u16 = status_line
390 .split_whitespace()
391 .nth(1)
392 .and_then(|s| s.parse().ok())
393 .unwrap_or(0);
394
395 if status_code >= 400 {
396 return Err(AvError::Sts(format!(
397 "Vault returned HTTP {}: {}",
398 status_code,
399 &response[idx + 4..]
400 )));
401 }
402
403 Ok(response[idx + 4..].to_string())
404 } else {
405 Err(AvError::Sts("Invalid HTTP response from Vault".to_string()))
406 }
407}
408
/// Pieces of an `http(s)://host[:port]/path` URL, as produced by `parse_url`.
struct ParsedUrl {
    /// Host name or address, without the port.
    host: String,
    /// TCP port; `parse_url` substitutes a default when the URL has none.
    port: u16,
    /// Everything from the first `/` after the authority (query string included), or `/`.
    path: String,
}
414
415fn parse_url(url: &str) -> Result<ParsedUrl> {
416 let without_scheme = if let Some(rest) = url.strip_prefix("https://") {
417 rest
418 } else if let Some(rest) = url.strip_prefix("http://") {
419 rest
420 } else {
421 return Err(AvError::InvalidPolicy(format!("Invalid Vault URL: {}", url)));
422 };
423
424 let default_port: u16 = 8200;
425
426 let (host_port, path) = match without_scheme.find('/') {
427 Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
428 None => (without_scheme, "/"),
429 };
430
431 let (host, port) = match host_port.rfind(':') {
432 Some(idx) => {
433 let port_str = &host_port[idx + 1..];
434 let port = port_str
435 .parse::<u16>()
436 .unwrap_or(default_port);
437 (host_port[..idx].to_string(), port)
438 }
439 None => (host_port.to_string(), default_port),
440 };
441
442 Ok(ParsedUrl {
443 host,
444 port,
445 path: path.to_string(),
446 })
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_url_with_port() {
        let parsed = parse_url("http://vault.example.com:8200/v1/aws/creds/my-role").unwrap();
        assert_eq!(parsed.host, "vault.example.com");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/v1/aws/creds/my-role");
    }

    #[test]
    fn test_parse_url_without_port() {
        let parsed = parse_url("https://vault.example.com/v1/sys/health").unwrap();
        assert_eq!(parsed.host, "vault.example.com");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/v1/sys/health");
    }

    #[test]
    fn test_parse_url_without_path() {
        let parsed = parse_url("http://vault:8200").unwrap();
        assert_eq!(parsed.host, "vault");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/");
    }

    #[test]
    fn test_parse_url_localhost() {
        let parsed = parse_url("http://127.0.0.1:8200/v1/auth/approle/login").unwrap();
        assert_eq!(parsed.host, "127.0.0.1");
        assert_eq!(parsed.port, 8200);
    }

    #[test]
    fn test_parse_url_invalid() {
        assert!(parse_url("ftp://vault.example.com").is_err());
    }

    #[test]
    fn test_vault_config_defaults() {
        let config = VaultConfig::default();
        assert_eq!(config.mount_path(), "aws");
        assert!(!config.tls_skip_verify);
        assert!(config.address.is_none());
    }

    #[test]
    fn test_vault_auth_default_is_token_without_value() {
        assert!(matches!(VaultAuth::default(), VaultAuth::Token { token: None }));
    }

    // The configured address takes precedence, so VAULT_ADDR is never consulted here.
    #[test]
    fn test_vault_config_resolve_address_from_config() {
        let config = VaultConfig {
            address: Some("http://localhost:8200".to_string()),
            ..Default::default()
        };
        assert_eq!(
            config.resolve_address().unwrap(),
            "http://localhost:8200"
        );
    }

    #[test]
    fn test_vault_config_custom_mount() {
        let config = VaultConfig {
            mount: Some("aws-prod".to_string()),
            ..Default::default()
        };
        assert_eq!(config.mount_path(), "aws-prod");
    }

    #[test]
    fn test_vault_config_missing_auth_defaults_to_token() {
        let config: VaultConfig = toml::from_str(r#"address = "http://vault:8200""#).unwrap();
        assert!(matches!(config.auth, VaultAuth::Token { token: None }));
        assert!(!config.tls_skip_verify);
        assert_eq!(config.mount_path(), "aws");
    }

    #[test]
    fn test_vault_config_token_auth() {
        let toml_str = r#"
[auth]
method = "token"
token = "s.example"
"#;
        let config: VaultConfig = toml::from_str(toml_str).unwrap();
        match config.auth {
            VaultAuth::Token { token } => assert_eq!(token.as_deref(), Some("s.example")),
            _ => panic!("Expected token auth"),
        }
    }

    #[test]
    fn test_vault_config_deserialize() {
        let toml_str = r#"
address = "https://vault.internal:8200"
mount = "aws-prod"
role = "audex-agent"
namespace = "engineering"

[auth]
method = "approle"
role_id = "abc-123"
secret_id = "def-456"
"#;
        let config: VaultConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.address.as_deref(), Some("https://vault.internal:8200"));
        assert_eq!(config.mount_path(), "aws-prod");
        assert_eq!(config.role.as_deref(), Some("audex-agent"));
        assert_eq!(config.namespace.as_deref(), Some("engineering"));
        match config.auth {
            VaultAuth::Approle { role_id, secret_id } => {
                assert_eq!(role_id, "abc-123");
                assert_eq!(secret_id.unwrap(), "def-456");
            }
            _ => panic!("Expected AppRole auth"),
        }
    }

    #[test]
    fn test_vault_config_kubernetes_auth() {
        let toml_str = r#"
address = "http://vault:8200"

[auth]
method = "kubernetes"
role = "audex"
jwt_path = "/var/run/secrets/token"
"#;
        let config: VaultConfig = toml::from_str(toml_str).unwrap();
        match config.auth {
            VaultAuth::Kubernetes { role, jwt_path } => {
                assert_eq!(role, "audex");
                assert_eq!(jwt_path.unwrap(), "/var/run/secrets/token");
            }
            _ => panic!("Expected Kubernetes auth"),
        }
    }
}