//! `tryaudex_core/vault.rs` — HashiCorp Vault credential backend.
1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::credentials::TempCredentials;
7use crate::error::{AvError, Result};
8
9/// Vault authentication method.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "method", rename_all = "lowercase")]
12pub enum VaultAuth {
13    /// Use a static Vault token (or VAULT_TOKEN env var).
14    Token { token: Option<String> },
15    /// AppRole authentication.
16    Approle {
17        role_id: String,
18        secret_id: Option<String>,
19    },
20    /// Kubernetes service account auth.
21    Kubernetes {
22        role: String,
23        /// Path to the service account token (default: /var/run/secrets/kubernetes.io/serviceaccount/token)
24        jwt_path: Option<String>,
25    },
26}
27
28impl Default for VaultAuth {
29    fn default() -> Self {
30        VaultAuth::Token { token: None }
31    }
32}
33
34/// Configuration for the Vault credential backend.
35#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct VaultConfig {
37    /// Vault server address (e.g. "https://vault.example.com:8200").
38    /// Falls back to VAULT_ADDR env var.
39    pub address: Option<String>,
40    /// Authentication method (default: token).
41    #[serde(default)]
42    pub auth: VaultAuth,
43    /// AWS secrets engine mount path (default: "aws").
44    pub mount: Option<String>,
45    /// Vault role name for generating credentials.
46    pub role: Option<String>,
47    /// Vault namespace (for Vault Enterprise).
48    pub namespace: Option<String>,
49    /// Skip TLS verification (NOT recommended for production).
50    #[serde(default)]
51    pub tls_skip_verify: bool,
52}
53
54impl VaultConfig {
55    /// Resolve the Vault address from config or VAULT_ADDR env var.
56    pub fn resolve_address(&self) -> Result<String> {
57        self.address
58            .clone()
59            .or_else(|| std::env::var("VAULT_ADDR").ok())
60            .ok_or_else(|| {
61                AvError::InvalidPolicy(
62                    "Vault address not set. Set vault.address in config or VAULT_ADDR env var"
63                        .to_string(),
64                )
65            })
66    }
67
68    /// Resolve the secrets engine mount path (default: "aws").
69    pub fn mount_path(&self) -> &str {
70        self.mount.as_deref().unwrap_or("aws")
71    }
72}
73
74/// Response from Vault's auth endpoints.
75#[derive(Debug, Deserialize)]
76struct VaultAuthResponse {
77    auth: Option<VaultAuthData>,
78}
79
80#[derive(Debug, Deserialize)]
81struct VaultAuthData {
82    client_token: String,
83}
84
85/// Response from Vault's AWS secrets engine.
86#[derive(Debug, Deserialize)]
87struct VaultSecretResponse {
88    data: Option<VaultAwsCredentials>,
89    lease_duration: Option<u64>,
90}
91
92#[derive(Debug, Deserialize)]
93struct VaultAwsCredentials {
94    access_key: String,
95    secret_key: String,
96    security_token: Option<String>,
97}
98
99/// Client for issuing credentials via HashiCorp Vault's AWS secrets engine.
100pub struct VaultIssuer {
101    address: String,
102    token: String,
103    mount: String,
104    namespace: Option<String>,
105}
106
107impl VaultIssuer {
108    /// Create a new VaultIssuer by authenticating with the configured method.
109    pub async fn new(config: &VaultConfig) -> Result<Self> {
110        let address = config.resolve_address()?;
111        let mount = config.mount_path().to_string();
112        let namespace = config.namespace.clone();
113
114        let token = match &config.auth {
115            VaultAuth::Token { token } => {
116                token
117                    .clone()
118                    .or_else(|| std::env::var("VAULT_TOKEN").ok())
119                    .ok_or_else(|| {
120                        AvError::InvalidPolicy(
121                            "Vault token not set. Set vault.auth.token in config or VAULT_TOKEN env var".to_string(),
122                        )
123                    })?
124            }
125            VaultAuth::Approle { role_id, secret_id } => {
126                Self::auth_approle(&address, namespace.as_deref(), role_id, secret_id.as_deref())
127                    .await?
128            }
129            VaultAuth::Kubernetes { role, jwt_path } => {
130                let default_path = "/var/run/secrets/kubernetes.io/serviceaccount/token";
131                let path = jwt_path.as_deref().unwrap_or(default_path);
132                let jwt = std::fs::read_to_string(path).map_err(|e| {
133                    AvError::InvalidPolicy(format!(
134                        "Failed to read Kubernetes service account token from {}: {}",
135                        path, e
136                    ))
137                })?;
138                Self::auth_kubernetes(&address, namespace.as_deref(), role, &jwt).await?
139            }
140        };
141
142        tracing::info!(address = %address, mount = %mount, "Connected to Vault");
143
144        Ok(Self {
145            address,
146            token,
147            mount,
148            namespace,
149        })
150    }
151
152    /// Authenticate via AppRole and return a client token.
153    async fn auth_approle(
154        address: &str,
155        namespace: Option<&str>,
156        role_id: &str,
157        secret_id: Option<&str>,
158    ) -> Result<String> {
159        let url = format!("{}/v1/auth/approle/login", address);
160        let mut body = HashMap::new();
161        body.insert("role_id", role_id);
162        if let Some(sid) = secret_id {
163            body.insert("secret_id", sid);
164        }
165
166        let resp = vault_post(&url, namespace, None, &body).await?;
167        let auth_resp: VaultAuthResponse = serde_json::from_str(&resp)
168            .map_err(|e| AvError::Sts(format!("Failed to parse Vault auth response: {}", e)))?;
169
170        auth_resp
171            .auth
172            .map(|a| a.client_token)
173            .ok_or_else(|| AvError::Sts("Vault AppRole auth returned no token".to_string()))
174    }
175
176    /// Authenticate via Kubernetes service account and return a client token.
177    async fn auth_kubernetes(
178        address: &str,
179        namespace: Option<&str>,
180        role: &str,
181        jwt: &str,
182    ) -> Result<String> {
183        let url = format!("{}/v1/auth/kubernetes/login", address);
184        let mut body = HashMap::new();
185        body.insert("role", role);
186        body.insert("jwt", jwt);
187
188        let resp = vault_post(&url, namespace, None, &body).await?;
189        let auth_resp: VaultAuthResponse = serde_json::from_str(&resp)
190            .map_err(|e| AvError::Sts(format!("Failed to parse Vault auth response: {}", e)))?;
191
192        auth_resp
193            .auth
194            .map(|a| a.client_token)
195            .ok_or_else(|| AvError::Sts("Vault Kubernetes auth returned no token".to_string()))
196    }
197
198    /// Issue temporary AWS credentials via Vault's AWS secrets engine.
199    ///
200    /// Uses the `/v1/{mount}/creds/{role}` endpoint to generate dynamic credentials.
201    /// The `ttl` parameter is passed as a request parameter to control credential lifetime.
202    pub async fn issue(&self, vault_role: &str, ttl: Duration) -> Result<TempCredentials> {
203        let url = format!("{}/v1/{}/creds/{}", self.address, self.mount, vault_role);
204
205        let ttl_str = format!("{}s", ttl.as_secs());
206        let mut body = HashMap::new();
207        body.insert("ttl", ttl_str.as_str());
208
209        tracing::info!(
210            vault_role = %vault_role,
211            ttl = %ttl_str,
212            "Requesting credentials from Vault AWS secrets engine"
213        );
214
215        let resp = vault_post(&url, self.namespace.as_deref(), Some(&self.token), &body).await?;
216        let secret: VaultSecretResponse = serde_json::from_str(&resp).map_err(|e| {
217            AvError::Sts(format!("Failed to parse Vault credential response: {}", e))
218        })?;
219
220        let creds = secret
221            .data
222            .ok_or_else(|| AvError::Sts("Vault returned no credential data".to_string()))?;
223
224        let lease_secs = secret.lease_duration.unwrap_or(ttl.as_secs());
225        let expires_at = chrono::Utc::now() + chrono::Duration::seconds(lease_secs as i64);
226
227        Ok(TempCredentials {
228            access_key_id: creds.access_key,
229            secret_access_key: creds.secret_key,
230            session_token: creds.security_token.unwrap_or_default(),
231            expires_at,
232        })
233    }
234
235    /// Read a Vault secret from an arbitrary path (for STS credential generation
236    /// where the Vault role uses assumed_role or federation_token type).
237    pub async fn read_sts_creds(&self, vault_role: &str, ttl: Duration) -> Result<TempCredentials> {
238        let url = format!("{}/v1/{}/sts/{}", self.address, self.mount, vault_role);
239
240        let ttl_str = format!("{}s", ttl.as_secs());
241        let mut body = HashMap::new();
242        body.insert("ttl", ttl_str.as_str());
243
244        tracing::info!(
245            vault_role = %vault_role,
246            ttl = %ttl_str,
247            "Requesting STS credentials from Vault"
248        );
249
250        let resp = vault_post(&url, self.namespace.as_deref(), Some(&self.token), &body).await?;
251        let secret: VaultSecretResponse = serde_json::from_str(&resp)
252            .map_err(|e| AvError::Sts(format!("Failed to parse Vault STS response: {}", e)))?;
253
254        let creds = secret
255            .data
256            .ok_or_else(|| AvError::Sts("Vault returned no STS credential data".to_string()))?;
257
258        let lease_secs = secret.lease_duration.unwrap_or(ttl.as_secs());
259        let expires_at = chrono::Utc::now() + chrono::Duration::seconds(lease_secs as i64);
260
261        Ok(TempCredentials {
262            access_key_id: creds.access_key,
263            secret_access_key: creds.secret_key,
264            session_token: creds.security_token.unwrap_or_default(),
265            expires_at,
266        })
267    }
268
269    /// Check if Vault is healthy and the secrets engine is accessible.
270    pub async fn health_check(&self) -> Result<bool> {
271        let url = format!("{}/v1/sys/health", self.address);
272        match vault_get(&url, self.namespace.as_deref(), Some(&self.token)).await {
273            Ok(_) => Ok(true),
274            Err(_) => Ok(false),
275        }
276    }
277}
278
279/// Make a POST request to Vault.
280async fn vault_post(
281    url: &str,
282    namespace: Option<&str>,
283    token: Option<&str>,
284    body: &HashMap<&str, &str>,
285) -> Result<String> {
286    let body_json = serde_json::to_string(body)
287        .map_err(|e| AvError::Sts(format!("Failed to serialize Vault request: {}", e)))?;
288
289    let mut headers = vec![("Content-Type", "application/json")];
290    let ns_header;
291    if let Some(ns) = namespace {
292        ns_header = ns.to_string();
293        headers.push(("X-Vault-Namespace", &ns_header));
294    }
295    let token_header;
296    if let Some(t) = token {
297        token_header = t.to_string();
298        headers.push(("X-Vault-Token", &token_header));
299    }
300
301    http_request("POST", url, &headers, Some(&body_json)).await
302}
303
304/// Make a GET request to Vault.
305async fn vault_get(url: &str, namespace: Option<&str>, token: Option<&str>) -> Result<String> {
306    let mut headers: Vec<(&str, &str)> = vec![];
307    let ns_header;
308    if let Some(ns) = namespace {
309        ns_header = ns.to_string();
310        headers.push(("X-Vault-Namespace", &ns_header));
311    }
312    let token_header;
313    if let Some(t) = token {
314        token_header = t.to_string();
315        headers.push(("X-Vault-Token", &token_header));
316    }
317
318    http_request("GET", url, &headers, None).await
319}
320
321/// Minimal HTTP client using std::net::TcpStream (no external HTTP crate dependency).
322/// Supports http:// and https:// via rustls if available.
323async fn http_request(
324    method: &str,
325    url: &str,
326    headers: &[(&str, &str)],
327    body: Option<&str>,
328) -> Result<String> {
329    use std::io::{Read, Write};
330    use std::net::TcpStream;
331
332    let parsed = parse_url(url)?;
333
334    let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
335        .map_err(|e| AvError::Sts(format!("Failed to connect to Vault at {}: {}", url, e)))?;
336    stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
337
338    let mut request = format!(
339        "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n",
340        method, parsed.path, parsed.host
341    );
342    for (key, value) in headers {
343        request.push_str(&format!("{}: {}\r\n", key, value));
344    }
345    if let Some(b) = body {
346        request.push_str(&format!("Content-Length: {}\r\n", b.len()));
347    }
348    request.push_str("\r\n");
349    if let Some(b) = body {
350        request.push_str(b);
351    }
352
353    // For simplicity, use plaintext HTTP. In production, Vault should be
354    // accessed via a local agent or TLS-terminated proxy.
355    let mut stream = stream;
356    stream
357        .write_all(request.as_bytes())
358        .map_err(|e| AvError::Sts(format!("Failed to send request to Vault: {}", e)))?;
359
360    let mut response = String::new();
361    stream
362        .read_to_string(&mut response)
363        .map_err(|e| AvError::Sts(format!("Failed to read Vault response: {}", e)))?;
364
365    // Extract body from HTTP response
366    if let Some(idx) = response.find("\r\n\r\n") {
367        let status_line = response.lines().next().unwrap_or("");
368        let status_code: u16 = status_line
369            .split_whitespace()
370            .nth(1)
371            .and_then(|s| s.parse().ok())
372            .unwrap_or(0);
373
374        if status_code >= 400 {
375            return Err(AvError::Sts(format!(
376                "Vault returned HTTP {}: {}",
377                status_code,
378                &response[idx + 4..]
379            )));
380        }
381
382        Ok(response[idx + 4..].to_string())
383    } else {
384        Err(AvError::Sts("Invalid HTTP response from Vault".to_string()))
385    }
386}
387
388struct ParsedUrl {
389    host: String,
390    port: u16,
391    path: String,
392}
393
394fn parse_url(url: &str) -> Result<ParsedUrl> {
395    let without_scheme = if let Some(rest) = url.strip_prefix("https://") {
396        rest
397    } else if let Some(rest) = url.strip_prefix("http://") {
398        rest
399    } else {
400        return Err(AvError::InvalidPolicy(format!(
401            "Invalid Vault URL: {}",
402            url
403        )));
404    };
405
406    let default_port: u16 = 8200;
407
408    let (host_port, path) = match without_scheme.find('/') {
409        Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
410        None => (without_scheme, "/"),
411    };
412
413    let (host, port) = match host_port.rfind(':') {
414        Some(idx) => {
415            let port_str = &host_port[idx + 1..];
416            let port = port_str.parse::<u16>().unwrap_or(default_port);
417            (host_port[..idx].to_string(), port)
418        }
419        None => (host_port.to_string(), default_port),
420    };
421
422    Ok(ParsedUrl {
423        host,
424        port,
425        path: path.to_string(),
426    })
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialize a `VaultConfig` from an inline TOML snippet.
    fn config_from_toml(toml_str: &str) -> VaultConfig {
        toml::from_str(toml_str).expect("test TOML should deserialize")
    }

    #[test]
    fn test_parse_url_with_port() {
        let parsed = parse_url("http://vault.example.com:8200/v1/aws/creds/my-role").unwrap();
        assert_eq!(parsed.host, "vault.example.com");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/v1/aws/creds/my-role");
    }

    #[test]
    fn test_parse_url_without_port() {
        let parsed = parse_url("https://vault.example.com/v1/sys/health").unwrap();
        assert_eq!(parsed.host, "vault.example.com");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/v1/sys/health");
    }

    #[test]
    fn test_parse_url_localhost() {
        let parsed = parse_url("http://127.0.0.1:8200/v1/auth/approle/login").unwrap();
        assert_eq!(parsed.host, "127.0.0.1");
        assert_eq!(parsed.port, 8200);
    }

    #[test]
    fn test_parse_url_without_path_defaults_to_root() {
        let parsed = parse_url("http://vault:8200").unwrap();
        assert_eq!(parsed.host, "vault");
        assert_eq!(parsed.port, 8200);
        assert_eq!(parsed.path, "/");
    }

    #[test]
    fn test_parse_url_invalid() {
        assert!(parse_url("ftp://vault.example.com").is_err());
    }

    #[test]
    fn test_vault_config_defaults() {
        let config = VaultConfig::default();
        assert_eq!(config.mount_path(), "aws");
        assert!(!config.tls_skip_verify);
        assert!(config.address.is_none());
        assert!(matches!(config.auth, VaultAuth::Token { token: None }));
    }

    /// An explicit `address` is returned as-is (no environment lookup involved).
    #[test]
    fn test_vault_config_resolve_address_explicit() {
        let config = VaultConfig {
            address: Some("http://localhost:8200".to_string()),
            ..Default::default()
        };
        assert_eq!(config.resolve_address().unwrap(), "http://localhost:8200");
    }

    #[test]
    fn test_vault_config_custom_mount() {
        let config = VaultConfig {
            mount: Some("aws-prod".to_string()),
            ..Default::default()
        };
        assert_eq!(config.mount_path(), "aws-prod");
    }

    #[test]
    fn test_vault_config_deserialize() {
        let config = config_from_toml(
            r#"
address = "https://vault.internal:8200"
mount = "aws-prod"
role = "audex-agent"
namespace = "engineering"

[auth]
method = "approle"
role_id = "abc-123"
secret_id = "def-456"
"#,
        );
        assert_eq!(
            config.address.as_deref(),
            Some("https://vault.internal:8200")
        );
        assert_eq!(config.mount_path(), "aws-prod");
        assert_eq!(config.role.as_deref(), Some("audex-agent"));
        assert_eq!(config.namespace.as_deref(), Some("engineering"));
        match config.auth {
            VaultAuth::Approle { role_id, secret_id } => {
                assert_eq!(role_id, "abc-123");
                assert_eq!(secret_id.as_deref(), Some("def-456"));
            }
            _ => panic!("Expected AppRole auth"),
        }
    }

    #[test]
    fn test_vault_config_kubernetes_auth() {
        let config = config_from_toml(
            r#"
address = "http://vault:8200"

[auth]
method = "kubernetes"
role = "audex"
jwt_path = "/var/run/secrets/token"
"#,
        );
        match config.auth {
            VaultAuth::Kubernetes { role, jwt_path } => {
                assert_eq!(role, "audex");
                assert_eq!(jwt_path.as_deref(), Some("/var/run/secrets/token"));
            }
            _ => panic!("Expected Kubernetes auth"),
        }
    }

    #[test]
    fn test_vault_config_token_auth() {
        let config = config_from_toml(
            r#"
address = "http://vault:8200"

[auth]
method = "token"
token = "s.test-token"
"#,
        );
        match config.auth {
            VaultAuth::Token { token } => assert_eq!(token.as_deref(), Some("s.test-token")),
            _ => panic!("Expected token auth"),
        }
    }

    /// `#[serde(default)]` on `auth` means an omitted `[auth]` table yields token auth.
    #[test]
    fn test_vault_config_auth_defaults_to_token() {
        let config = config_from_toml(r#"address = "http://vault:8200""#);
        assert!(matches!(config.auth, VaultAuth::Token { token: None }));
        assert_eq!(config.mount_path(), "aws");
    }
}