Skip to main content

tryaudex_core/
broker.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::server::{CredentialRequest, CredentialResponse};
8
9/// Broker client configuration.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct BrokerConfig {
12    /// Audex server URL (e.g. "http://localhost:8080")
13    pub url: String,
14    /// API key for authentication
15    pub api_key: Option<String>,
16    /// Request timeout in seconds (default: 30)
17    #[serde(default = "default_timeout")]
18    pub timeout: u64,
19    /// Default provider for requests
20    pub default_provider: Option<String>,
21    /// Default role ARN for requests
22    pub default_role_arn: Option<String>,
23}
24
/// Serde default for `BrokerConfig::timeout`: 30 seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
28
29impl Default for BrokerConfig {
30    fn default() -> Self {
31        Self {
32            url: "http://localhost:8080".to_string(),
33            api_key: None,
34            timeout: default_timeout(),
35            default_provider: None,
36            default_role_arn: None,
37        }
38    }
39}
40
41/// Batch credential request — issue multiple scoped credentials in one call.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct BatchCredentialRequest {
44    pub requests: Vec<CredentialRequest>,
45}
46
47/// Batch credential response.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct BatchCredentialResponse {
50    pub results: Vec<BatchResultItem>,
51    pub total: usize,
52    pub succeeded: usize,
53    pub failed: usize,
54}
55
56/// Individual result in a batch response.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct BatchResultItem {
59    pub index: usize,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub credentials: Option<CredentialResponse>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub error: Option<String>,
64}
65
66/// Token exchange request — exchange a short-lived broker token for credentials.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct TokenExchangeRequest {
69    /// Broker-issued token
70    pub token: String,
71}
72
73/// Broker token — a short-lived, single-use token that can be exchanged for credentials.
74/// Useful for passing to subprocesses without exposing the API key.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct BrokerToken {
77    pub token: String,
78    pub expires_at: String,
79    /// The pre-configured credential request bound to this token.
80    pub request: CredentialRequest,
81}
82
83/// Credential revocation request.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct RevokeRequest {
86    pub session_id: String,
87    pub reason: Option<String>,
88}
89
90/// Broker API client for programmatic credential requests.
91pub struct BrokerClient {
92    config: BrokerConfig,
93}
94
95impl BrokerClient {
96    pub fn new(config: BrokerConfig) -> Self {
97        Self { config }
98    }
99
100    /// Create from environment variables.
101    /// Reads AUDEX_BROKER_URL and AUDEX_BROKER_API_KEY.
102    pub fn from_env() -> Result<Self> {
103        let url = std::env::var("AUDEX_BROKER_URL").unwrap_or_else(|_| {
104            "http://localhost:8080".to_string()
105        });
106        let api_key = std::env::var("AUDEX_BROKER_API_KEY").ok();
107        Ok(Self::new(BrokerConfig {
108            url,
109            api_key,
110            ..Default::default()
111        }))
112    }
113
114    /// Request scoped credentials from the broker.
115    pub async fn get_credentials(&self, request: &CredentialRequest) -> Result<CredentialResponse> {
116        let url = format!("{}/v1/credentials", self.config.url);
117        let body = serde_json::to_string(request)
118            .map_err(|e| AvError::Sts(format!("Failed to serialize request: {}", e)))?;
119
120        let response = self.http_post(&url, &body).await?;
121        serde_json::from_str(&response)
122            .map_err(|e| AvError::Sts(format!("Failed to parse credential response: {}", e)))
123    }
124
125    /// Request credentials with simple parameters (convenience method).
126    pub async fn get_credentials_simple(
127        &self,
128        allow: &str,
129        ttl: &str,
130    ) -> Result<CredentialResponse> {
131        let request = CredentialRequest {
132            allow: Some(allow.to_string()),
133            profile: None,
134            resource: None,
135            provider: self.config.default_provider.clone().unwrap_or_else(|| "aws".to_string()),
136            ttl: ttl.to_string(),
137            role_arn: self.config.default_role_arn.clone(),
138            command: vec![],
139            agent_id: std::env::var("AUDEX_AGENT_ID").ok(),
140        };
141        self.get_credentials(&request).await
142    }
143
144    /// Batch request — issue multiple credentials in one API call.
145    pub async fn batch_credentials(
146        &self,
147        requests: Vec<CredentialRequest>,
148    ) -> Result<BatchCredentialResponse> {
149        let url = format!("{}/v1/credentials/batch", self.config.url);
150        let batch = BatchCredentialRequest { requests };
151        let body = serde_json::to_string(&batch)
152            .map_err(|e| AvError::Sts(format!("Failed to serialize batch request: {}", e)))?;
153
154        let response = self.http_post(&url, &body).await?;
155        serde_json::from_str(&response)
156            .map_err(|e| AvError::Sts(format!("Failed to parse batch response: {}", e)))
157    }
158
159    /// Create a broker token — a single-use token that can be exchanged for credentials.
160    /// Useful for passing to untrusted subprocesses.
161    pub async fn create_token(&self, request: &CredentialRequest) -> Result<BrokerToken> {
162        let url = format!("{}/v1/tokens", self.config.url);
163        let body = serde_json::to_string(request)
164            .map_err(|e| AvError::Sts(format!("Failed to serialize token request: {}", e)))?;
165
166        let response = self.http_post(&url, &body).await?;
167        serde_json::from_str(&response)
168            .map_err(|e| AvError::Sts(format!("Failed to parse token response: {}", e)))
169    }
170
171    /// Exchange a broker token for credentials.
172    pub async fn exchange_token(&self, token: &str) -> Result<CredentialResponse> {
173        let url = format!("{}/v1/tokens/exchange", self.config.url);
174        let req = TokenExchangeRequest {
175            token: token.to_string(),
176        };
177        let body = serde_json::to_string(&req)
178            .map_err(|e| AvError::Sts(format!("Failed to serialize exchange request: {}", e)))?;
179
180        let response = self.http_post(&url, &body).await?;
181        serde_json::from_str(&response)
182            .map_err(|e| AvError::Sts(format!("Failed to parse exchange response: {}", e)))
183    }
184
185    /// Revoke a session's credentials.
186    pub async fn revoke(&self, session_id: &str, reason: Option<&str>) -> Result<()> {
187        let url = format!("{}/v1/sessions/{}/revoke", self.config.url, session_id);
188        let req = RevokeRequest {
189            session_id: session_id.to_string(),
190            reason: reason.map(|r| r.to_string()),
191        };
192        let body = serde_json::to_string(&req)
193            .map_err(|e| AvError::Sts(format!("Failed to serialize revoke request: {}", e)))?;
194
195        self.http_post(&url, &body).await?;
196        Ok(())
197    }
198
199    /// List active sessions from the broker.
200    pub async fn list_sessions(&self) -> Result<String> {
201        let url = format!("{}/v1/sessions", self.config.url);
202        self.http_get(&url).await
203    }
204
205    /// Get server health status.
206    pub async fn health(&self) -> Result<String> {
207        let url = format!("{}/v1/health", self.config.url);
208        self.http_get(&url).await
209    }
210
211    /// Make an HTTP POST request to the broker.
212    async fn http_post(&self, url: &str, body: &str) -> Result<String> {
213        use std::io::{Read, Write};
214        use std::net::TcpStream;
215
216        let parsed = parse_url(url)?;
217        let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
218            .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
219        stream
220            .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
221            .ok();
222
223        let mut headers = vec![
224            ("Content-Type", "application/json"),
225            ("Connection", "close"),
226        ];
227        let auth_header;
228        if let Some(ref key) = self.config.api_key {
229            auth_header = format!("Bearer {}", key);
230            headers.push(("Authorization", &auth_header));
231        }
232        let agent_header;
233        if let Ok(agent_id) = std::env::var("AUDEX_AGENT_ID") {
234            agent_header = agent_id;
235            headers.push(("X-Audex-Agent-Id", &agent_header));
236        }
237
238        let mut request = format!(
239            "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Length: {}\r\n",
240            parsed.path, parsed.host, body.len()
241        );
242        for (key, value) in &headers {
243            request.push_str(&format!("{}: {}\r\n", key, value));
244        }
245        request.push_str("\r\n");
246        request.push_str(body);
247
248        let mut stream = stream;
249        stream
250            .write_all(request.as_bytes())
251            .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
252
253        let mut response = String::new();
254        stream
255            .read_to_string(&mut response)
256            .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
257
258        extract_http_body(&response)
259    }
260
261    /// Make an HTTP GET request to the broker.
262    async fn http_get(&self, url: &str) -> Result<String> {
263        use std::io::{Read, Write};
264        use std::net::TcpStream;
265
266        let parsed = parse_url(url)?;
267        let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
268            .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
269        stream
270            .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
271            .ok();
272
273        let mut headers: Vec<(&str, &str)> = vec![("Connection", "close")];
274        let auth_header;
275        if let Some(ref key) = self.config.api_key {
276            auth_header = format!("Bearer {}", key);
277            headers.push(("Authorization", &auth_header));
278        }
279
280        let mut request = format!(
281            "GET {} HTTP/1.1\r\nHost: {}\r\n",
282            parsed.path, parsed.host
283        );
284        for (key, value) in &headers {
285            request.push_str(&format!("{}: {}\r\n", key, value));
286        }
287        request.push_str("\r\n");
288
289        let mut stream = stream;
290        stream
291            .write_all(request.as_bytes())
292            .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
293
294        let mut response = String::new();
295        stream
296            .read_to_string(&mut response)
297            .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
298
299        extract_http_body(&response)
300    }
301}
302
/// The pieces of a broker URL needed to open a connection and build a request line.
struct ParsedUrl {
    /// Host portion of the URL, without the port.
    host: String,
    /// TCP port to connect to.
    port: u16,
    /// Request path, starting with `/`.
    path: String,
}
308
309fn parse_url(url: &str) -> Result<ParsedUrl> {
310    let without_scheme = url
311        .strip_prefix("https://")
312        .or_else(|| url.strip_prefix("http://"))
313        .ok_or_else(|| AvError::InvalidPolicy(format!("Invalid broker URL: {}", url)))?;
314
315    let default_port: u16 = if url.starts_with("https://") { 443 } else { 8080 };
316
317    let (host_port, path) = match without_scheme.find('/') {
318        Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
319        None => (without_scheme, "/"),
320    };
321
322    let (host, port) = match host_port.rfind(':') {
323        Some(idx) => {
324            let port = host_port[idx + 1..].parse::<u16>().unwrap_or(default_port);
325            (host_port[..idx].to_string(), port)
326        }
327        None => (host_port.to_string(), default_port),
328    };
329
330    Ok(ParsedUrl {
331        host,
332        port,
333        path: path.to_string(),
334    })
335}
336
337fn extract_http_body(response: &str) -> Result<String> {
338    if let Some(idx) = response.find("\r\n\r\n") {
339        let status_line = response.lines().next().unwrap_or("");
340        let status_code: u16 = status_line
341            .split_whitespace()
342            .nth(1)
343            .and_then(|s| s.parse().ok())
344            .unwrap_or(0);
345
346        let body = &response[idx + 4..];
347
348        if status_code >= 400 {
349            return Err(AvError::Sts(format!(
350                "Broker returned HTTP {}: {}",
351                status_code, body
352            )));
353        }
354
355        Ok(body.to_string())
356    } else {
357        Err(AvError::Sts("Invalid HTTP response from broker".to_string()))
358    }
359}
360
361/// Generate environment variable export commands for credentials.
362/// Useful for shell integration and subprocess injection.
363pub fn credentials_to_env_script(resp: &CredentialResponse, shell: &str) -> String {
364    let mut lines = Vec::new();
365    let export = match shell {
366        "fish" => "set -gx",
367        "powershell" | "pwsh" => "$env:",
368        _ => "export",
369    };
370
371    let creds = &resp.credentials;
372    if let Some(ref key) = creds.aws_access_key_id {
373        lines.push(format_env(export, "AWS_ACCESS_KEY_ID", key, shell));
374    }
375    if let Some(ref key) = creds.aws_secret_access_key {
376        lines.push(format_env(export, "AWS_SECRET_ACCESS_KEY", key, shell));
377    }
378    if let Some(ref token) = creds.aws_session_token {
379        lines.push(format_env(export, "AWS_SESSION_TOKEN", token, shell));
380    }
381    if let Some(ref token) = creds.gcp_access_token {
382        lines.push(format_env(export, "CLOUDSDK_AUTH_ACCESS_TOKEN", token, shell));
383    }
384    if let Some(ref token) = creds.azure_token {
385        lines.push(format_env(export, "AZURE_ACCESS_TOKEN", token, shell));
386    }
387
388    lines.join("\n")
389}
390
/// Formats one environment-variable assignment for the given shell.
///
/// * `export` — the shell's assignment prefix (`export`, `set -gx`, or `$env:`).
/// * `key` — variable name, written as given.
/// * `value` — variable value; placed inside double quotes with every character that
///   is special there escaped, so a value containing `"`, `$`, `` ` ``, or `\` can
///   neither break the quoting nor trigger expansion when the script is evaluated.
/// * `shell` — `"powershell"`/`"pwsh"`, `"fish"`, or anything else for POSIX-style.
///
/// Values without special characters produce exactly the same text as before.
fn format_env(export: &str, key: &str, value: &str, shell: &str) -> String {
    /// Prefixes every character of `value` that appears in `special` with `escape`.
    fn escape_quoted(value: &str, escape: char, special: &[char]) -> String {
        let mut escaped = String::with_capacity(value.len());
        for c in value.chars() {
            if special.contains(&c) {
                escaped.push(escape);
            }
            escaped.push(c);
        }
        escaped
    }

    match shell {
        // PowerShell uses the backtick as its escape character.
        "powershell" | "pwsh" => format!(
            "{}{}=\"{}\"",
            export,
            key,
            escape_quoted(value, '`', &['`', '"', '$'])
        ),
        // Inside fish double quotes only `\`, `"`, and `$` are special.
        "fish" => format!(
            "{} {} \"{}\"",
            export,
            key,
            escape_quoted(value, '\\', &['\\', '"', '$'])
        ),
        // POSIX double quotes additionally treat the backtick as special.
        _ => format!(
            "{} {}=\"{}\"",
            export,
            key,
            escape_quoted(value, '\\', &['\\', '"', '$', '`'])
        ),
    }
}
398
399/// Generate a JSON credentials document suitable for file-based injection.
400pub fn credentials_to_json(resp: &CredentialResponse) -> Result<String> {
401    let mut map = HashMap::new();
402    map.insert("session_id", resp.session_id.as_str());
403    map.insert("provider", resp.provider.as_str());
404    map.insert("expires_at", resp.expires_at.as_str());
405
406    let creds = &resp.credentials;
407    if let Some(ref v) = creds.aws_access_key_id {
408        map.insert("aws_access_key_id", v);
409    }
410    if let Some(ref v) = creds.aws_secret_access_key {
411        map.insert("aws_secret_access_key", v);
412    }
413    if let Some(ref v) = creds.aws_session_token {
414        map.insert("aws_session_token", v);
415    }
416    if let Some(ref v) = creds.gcp_access_token {
417        map.insert("gcp_access_token", v);
418    }
419    if let Some(ref v) = creds.azure_token {
420        map.insert("azure_token", v);
421    }
422
423    serde_json::to_string_pretty(&map)
424        .map_err(|e| AvError::Sts(format!("Failed to serialize credentials: {}", e)))
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::CredentialEnvVars;

    /// AWS `s3:GetObject` request with a 15 minute TTL, shared by serialization tests.
    fn s3_read_request() -> CredentialRequest {
        CredentialRequest {
            allow: Some("s3:GetObject".to_string()),
            profile: None,
            resource: None,
            provider: "aws".to_string(),
            ttl: "15m".to_string(),
            role_arn: None,
            command: vec![],
            agent_id: None,
        }
    }

    /// AWS credential response with all three AWS fields set and no GCP/Azure token.
    fn mock_credential_response() -> CredentialResponse {
        let credentials = CredentialEnvVars {
            aws_access_key_id: Some("AKID123".to_string()),
            aws_secret_access_key: Some("secret".to_string()),
            aws_session_token: Some("token".to_string()),
            gcp_access_token: None,
            azure_token: None,
        };
        CredentialResponse {
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            ttl_seconds: 900,
            credentials,
        }
    }

    // --- configuration -----------------------------------------------------

    #[test]
    fn test_broker_config_default() {
        let defaults = BrokerConfig::default();
        assert_eq!(defaults.url, "http://localhost:8080");
        assert_eq!(defaults.timeout, 30);
        assert!(defaults.api_key.is_none());
    }

    #[test]
    fn test_broker_config_deserialize() {
        let toml_str = r#"
url = "https://audex.internal:8443"
api_key = "secret-key-123"
timeout = 60
default_provider = "gcp"
default_role_arn = "arn:aws:iam::123:role/MyRole"
"#;
        let parsed: BrokerConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(parsed.url, "https://audex.internal:8443");
        assert_eq!(parsed.api_key.as_deref(), Some("secret-key-123"));
        assert_eq!(parsed.timeout, 60);
        assert_eq!(parsed.default_provider.as_deref(), Some("gcp"));
    }

    // --- wire types --------------------------------------------------------

    #[test]
    fn test_batch_request_serialization() {
        let batch = BatchCredentialRequest {
            requests: vec![s3_read_request()],
        };
        let encoded = serde_json::to_string(&batch).unwrap();
        assert!(encoded.contains("s3:GetObject"));
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{"results":[{"index":0,"credentials":{"session_id":"abc","provider":"aws","expires_at":"2026-01-01T00:00:00Z","ttl_seconds":900,"credentials":{"aws_access_key_id":"AKID"}},"error":null}],"total":1,"succeeded":1,"failed":0}"#;
        let decoded: BatchCredentialResponse = serde_json::from_str(json).unwrap();
        assert_eq!(decoded.total, 1);
        assert_eq!(decoded.succeeded, 1);
    }

    #[test]
    fn test_broker_token_serialization() {
        let token = BrokerToken {
            token: "tok_abc123".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            request: s3_read_request(),
        };
        let encoded = serde_json::to_string(&token).unwrap();
        assert!(encoded.contains("tok_abc123"));
        assert!(encoded.contains("s3:GetObject"));
    }

    // --- credential rendering ----------------------------------------------

    #[test]
    fn test_credentials_to_env_script_bash() {
        let script = credentials_to_env_script(&mock_credential_response(), "bash");
        assert!(script.contains("export AWS_ACCESS_KEY_ID=\"AKID123\""));
        assert!(script.contains("export AWS_SECRET_ACCESS_KEY=\"secret\""));
        assert!(script.contains("export AWS_SESSION_TOKEN=\"token\""));
    }

    #[test]
    fn test_credentials_to_env_script_fish() {
        let script = credentials_to_env_script(&mock_credential_response(), "fish");
        assert!(script.contains("set -gx AWS_ACCESS_KEY_ID \"AKID123\""));
    }

    #[test]
    fn test_credentials_to_env_script_powershell() {
        let script = credentials_to_env_script(&mock_credential_response(), "powershell");
        assert!(script.contains("$env:AWS_ACCESS_KEY_ID=\"AKID123\""));
    }

    #[test]
    fn test_credentials_to_json() {
        let rendered = credentials_to_json(&mock_credential_response()).unwrap();
        assert!(rendered.contains("AKID123"));
        assert!(rendered.contains("session_id"));
    }

    // --- URL parsing -------------------------------------------------------

    #[test]
    fn test_parse_url_with_port() {
        let ParsedUrl { host, port, path } =
            parse_url("http://localhost:9090/v1/credentials").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 9090);
        assert_eq!(path, "/v1/credentials");
    }

    #[test]
    fn test_parse_url_default_port() {
        let port = parse_url("http://localhost/v1/health").unwrap().port;
        assert_eq!(port, 8080);
    }

    #[test]
    fn test_parse_url_https_default_port() {
        let port = parse_url("https://audex.internal/v1/health").unwrap().port;
        assert_eq!(port, 443);
    }

    // --- HTTP response handling --------------------------------------------

    #[test]
    fn test_extract_http_body_success() {
        let raw = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"ok\":true}";
        assert_eq!(extract_http_body(raw).unwrap(), "{\"ok\":true}");
    }

    #[test]
    fn test_extract_http_body_error() {
        let raw = "HTTP/1.1 401 Unauthorized\r\n\r\n{\"error\":\"invalid key\"}";
        assert!(extract_http_body(raw).is_err());
    }
}