Skip to main content

tryaudex_core/
broker.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::server::{CredentialRequest, CredentialResponse};
8
9/// Broker client configuration.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct BrokerConfig {
12    /// Audex server URL (e.g. "http://localhost:8080")
13    pub url: String,
14    /// API key for authentication
15    pub api_key: Option<String>,
16    /// Request timeout in seconds (default: 30)
17    #[serde(default = "default_timeout")]
18    pub timeout: u64,
19    /// Default provider for requests
20    pub default_provider: Option<String>,
21    /// Default role ARN for requests
22    pub default_role_arn: Option<String>,
23}
24
/// Timeout, in seconds, used when a configuration does not specify one.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
28
29impl Default for BrokerConfig {
30    fn default() -> Self {
31        Self {
32            url: "http://localhost:8080".to_string(),
33            api_key: None,
34            timeout: default_timeout(),
35            default_provider: None,
36            default_role_arn: None,
37        }
38    }
39}
40
41/// Batch credential request — issue multiple scoped credentials in one call.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct BatchCredentialRequest {
44    pub requests: Vec<CredentialRequest>,
45}
46
47/// Batch credential response.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct BatchCredentialResponse {
50    pub results: Vec<BatchResultItem>,
51    pub total: usize,
52    pub succeeded: usize,
53    pub failed: usize,
54}
55
56/// Individual result in a batch response.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct BatchResultItem {
59    pub index: usize,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub credentials: Option<CredentialResponse>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub error: Option<String>,
64}
65
66/// Token exchange request — exchange a short-lived broker token for credentials.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct TokenExchangeRequest {
69    /// Broker-issued token
70    pub token: String,
71}
72
73/// Broker token — a short-lived, single-use token that can be exchanged for credentials.
74/// Useful for passing to subprocesses without exposing the API key.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct BrokerToken {
77    pub token: String,
78    pub expires_at: String,
79    /// The pre-configured credential request bound to this token.
80    pub request: CredentialRequest,
81}
82
83/// Credential revocation request.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct RevokeRequest {
86    pub session_id: String,
87    pub reason: Option<String>,
88}
89
90/// Broker API client for programmatic credential requests.
91pub struct BrokerClient {
92    config: BrokerConfig,
93}
94
95impl BrokerClient {
96    pub fn new(config: BrokerConfig) -> Self {
97        Self { config }
98    }
99
100    /// Create from environment variables.
101    /// Reads AUDEX_BROKER_URL and AUDEX_BROKER_API_KEY.
102    pub fn from_env() -> Result<Self> {
103        let url = std::env::var("AUDEX_BROKER_URL")
104            .unwrap_or_else(|_| "http://localhost:8080".to_string());
105        let api_key = std::env::var("AUDEX_BROKER_API_KEY").ok();
106        Ok(Self::new(BrokerConfig {
107            url,
108            api_key,
109            ..Default::default()
110        }))
111    }
112
113    /// Request scoped credentials from the broker.
114    pub async fn get_credentials(&self, request: &CredentialRequest) -> Result<CredentialResponse> {
115        let url = format!("{}/v1/credentials", self.config.url);
116        let body = serde_json::to_string(request)
117            .map_err(|e| AvError::Sts(format!("Failed to serialize request: {}", e)))?;
118
119        let response = self.http_post(&url, &body).await?;
120        serde_json::from_str(&response)
121            .map_err(|e| AvError::Sts(format!("Failed to parse credential response: {}", e)))
122    }
123
124    /// Request credentials with simple parameters (convenience method).
125    pub async fn get_credentials_simple(
126        &self,
127        allow: &str,
128        ttl: &str,
129    ) -> Result<CredentialResponse> {
130        let request = CredentialRequest {
131            allow: Some(allow.to_string()),
132            profile: None,
133            resource: None,
134            provider: self
135                .config
136                .default_provider
137                .clone()
138                .unwrap_or_else(|| "aws".to_string()),
139            ttl: ttl.to_string(),
140            role_arn: self.config.default_role_arn.clone(),
141            command: vec![],
142            agent_id: std::env::var("AUDEX_AGENT_ID").ok(),
143        };
144        self.get_credentials(&request).await
145    }
146
147    /// Batch request — issue multiple credentials in one API call.
148    pub async fn batch_credentials(
149        &self,
150        requests: Vec<CredentialRequest>,
151    ) -> Result<BatchCredentialResponse> {
152        let url = format!("{}/v1/credentials/batch", self.config.url);
153        let batch = BatchCredentialRequest { requests };
154        let body = serde_json::to_string(&batch)
155            .map_err(|e| AvError::Sts(format!("Failed to serialize batch request: {}", e)))?;
156
157        let response = self.http_post(&url, &body).await?;
158        serde_json::from_str(&response)
159            .map_err(|e| AvError::Sts(format!("Failed to parse batch response: {}", e)))
160    }
161
162    /// Create a broker token — a single-use token that can be exchanged for credentials.
163    /// Useful for passing to untrusted subprocesses.
164    pub async fn create_token(&self, request: &CredentialRequest) -> Result<BrokerToken> {
165        let url = format!("{}/v1/tokens", self.config.url);
166        let body = serde_json::to_string(request)
167            .map_err(|e| AvError::Sts(format!("Failed to serialize token request: {}", e)))?;
168
169        let response = self.http_post(&url, &body).await?;
170        serde_json::from_str(&response)
171            .map_err(|e| AvError::Sts(format!("Failed to parse token response: {}", e)))
172    }
173
174    /// Exchange a broker token for credentials.
175    pub async fn exchange_token(&self, token: &str) -> Result<CredentialResponse> {
176        let url = format!("{}/v1/tokens/exchange", self.config.url);
177        let req = TokenExchangeRequest {
178            token: token.to_string(),
179        };
180        let body = serde_json::to_string(&req)
181            .map_err(|e| AvError::Sts(format!("Failed to serialize exchange request: {}", e)))?;
182
183        let response = self.http_post(&url, &body).await?;
184        serde_json::from_str(&response)
185            .map_err(|e| AvError::Sts(format!("Failed to parse exchange response: {}", e)))
186    }
187
188    /// Revoke a session's credentials.
189    pub async fn revoke(&self, session_id: &str, reason: Option<&str>) -> Result<()> {
190        let url = format!("{}/v1/sessions/{}/revoke", self.config.url, session_id);
191        let req = RevokeRequest {
192            session_id: session_id.to_string(),
193            reason: reason.map(|r| r.to_string()),
194        };
195        let body = serde_json::to_string(&req)
196            .map_err(|e| AvError::Sts(format!("Failed to serialize revoke request: {}", e)))?;
197
198        self.http_post(&url, &body).await?;
199        Ok(())
200    }
201
202    /// List active sessions from the broker.
203    pub async fn list_sessions(&self) -> Result<String> {
204        let url = format!("{}/v1/sessions", self.config.url);
205        self.http_get(&url).await
206    }
207
208    /// Get server health status.
209    pub async fn health(&self) -> Result<String> {
210        let url = format!("{}/v1/health", self.config.url);
211        self.http_get(&url).await
212    }
213
214    /// Make an HTTP POST request to the broker.
215    async fn http_post(&self, url: &str, body: &str) -> Result<String> {
216        use std::io::{Read, Write};
217        use std::net::TcpStream;
218
219        let parsed = parse_url(url)?;
220        let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
221            .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
222        stream
223            .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
224            .ok();
225
226        let mut headers = vec![
227            ("Content-Type", "application/json"),
228            ("Connection", "close"),
229        ];
230        let auth_header;
231        if let Some(ref key) = self.config.api_key {
232            auth_header = format!("Bearer {}", key);
233            headers.push(("Authorization", &auth_header));
234        }
235        let agent_header;
236        if let Ok(agent_id) = std::env::var("AUDEX_AGENT_ID") {
237            agent_header = agent_id;
238            headers.push(("X-Audex-Agent-Id", &agent_header));
239        }
240
241        let mut request = format!(
242            "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Length: {}\r\n",
243            parsed.path,
244            parsed.host,
245            body.len()
246        );
247        for (key, value) in &headers {
248            request.push_str(&format!("{}: {}\r\n", key, value));
249        }
250        request.push_str("\r\n");
251        request.push_str(body);
252
253        let mut stream = stream;
254        stream
255            .write_all(request.as_bytes())
256            .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
257
258        let mut response = String::new();
259        stream
260            .read_to_string(&mut response)
261            .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
262
263        extract_http_body(&response)
264    }
265
266    /// Make an HTTP GET request to the broker.
267    async fn http_get(&self, url: &str) -> Result<String> {
268        use std::io::{Read, Write};
269        use std::net::TcpStream;
270
271        let parsed = parse_url(url)?;
272        let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
273            .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
274        stream
275            .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
276            .ok();
277
278        let mut headers: Vec<(&str, &str)> = vec![("Connection", "close")];
279        let auth_header;
280        if let Some(ref key) = self.config.api_key {
281            auth_header = format!("Bearer {}", key);
282            headers.push(("Authorization", &auth_header));
283        }
284
285        let mut request = format!("GET {} HTTP/1.1\r\nHost: {}\r\n", parsed.path, parsed.host);
286        for (key, value) in &headers {
287            request.push_str(&format!("{}: {}\r\n", key, value));
288        }
289        request.push_str("\r\n");
290
291        let mut stream = stream;
292        stream
293            .write_all(request.as_bytes())
294            .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
295
296        let mut response = String::new();
297        stream
298            .read_to_string(&mut response)
299            .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
300
301        extract_http_body(&response)
302    }
303}
304
/// The pieces of a broker URL needed to open a socket and write a request line.
struct ParsedUrl {
    /// Host portion of the URL, without the port.
    host: String,
    /// TCP port to connect to.
    port: u16,
    /// Path portion of the URL, beginning with `/`.
    path: String,
}
310
311fn parse_url(url: &str) -> Result<ParsedUrl> {
312    let without_scheme = url
313        .strip_prefix("https://")
314        .or_else(|| url.strip_prefix("http://"))
315        .ok_or_else(|| AvError::InvalidPolicy(format!("Invalid broker URL: {}", url)))?;
316
317    let default_port: u16 = if url.starts_with("https://") {
318        443
319    } else {
320        8080
321    };
322
323    let (host_port, path) = match without_scheme.find('/') {
324        Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
325        None => (without_scheme, "/"),
326    };
327
328    let (host, port) = match host_port.rfind(':') {
329        Some(idx) => {
330            let port = host_port[idx + 1..].parse::<u16>().unwrap_or(default_port);
331            (host_port[..idx].to_string(), port)
332        }
333        None => (host_port.to_string(), default_port),
334    };
335
336    Ok(ParsedUrl {
337        host,
338        port,
339        path: path.to_string(),
340    })
341}
342
343fn extract_http_body(response: &str) -> Result<String> {
344    if let Some(idx) = response.find("\r\n\r\n") {
345        let status_line = response.lines().next().unwrap_or("");
346        let status_code: u16 = status_line
347            .split_whitespace()
348            .nth(1)
349            .and_then(|s| s.parse().ok())
350            .unwrap_or(0);
351
352        let body = &response[idx + 4..];
353
354        if status_code >= 400 {
355            return Err(AvError::Sts(format!(
356                "Broker returned HTTP {}: {}",
357                status_code, body
358            )));
359        }
360
361        Ok(body.to_string())
362    } else {
363        Err(AvError::Sts(
364            "Invalid HTTP response from broker".to_string(),
365        ))
366    }
367}
368
369/// Generate environment variable export commands for credentials.
370/// Useful for shell integration and subprocess injection.
371pub fn credentials_to_env_script(resp: &CredentialResponse, shell: &str) -> String {
372    let mut lines = Vec::new();
373    let export = match shell {
374        "fish" => "set -gx",
375        "powershell" | "pwsh" => "$env:",
376        _ => "export",
377    };
378
379    let creds = &resp.credentials;
380    if let Some(ref key) = creds.aws_access_key_id {
381        lines.push(format_env(export, "AWS_ACCESS_KEY_ID", key, shell));
382    }
383    if let Some(ref key) = creds.aws_secret_access_key {
384        lines.push(format_env(export, "AWS_SECRET_ACCESS_KEY", key, shell));
385    }
386    if let Some(ref token) = creds.aws_session_token {
387        lines.push(format_env(export, "AWS_SESSION_TOKEN", token, shell));
388    }
389    if let Some(ref token) = creds.gcp_access_token {
390        lines.push(format_env(
391            export,
392            "CLOUDSDK_AUTH_ACCESS_TOKEN",
393            token,
394            shell,
395        ));
396    }
397    if let Some(ref token) = creds.azure_token {
398        lines.push(format_env(export, "AZURE_ACCESS_TOKEN", token, shell));
399    }
400
401    lines.join("\n")
402}
403
/// Format one variable assignment for the given shell.
///
/// `export` is the shell keyword/prefix chosen by the caller (`export`,
/// `set -gx` or `$env:`). The value is placed inside double quotes, and every
/// character that is special inside a double-quoted string of that shell is
/// escaped, so a credential containing `"`, `$`, a backtick or a backslash
/// cannot terminate the string or trigger expansion when the script is
/// evaluated. Values without such characters are emitted exactly as before.
fn format_env(export: &str, key: &str, value: &str, shell: &str) -> String {
    match shell {
        "powershell" | "pwsh" => {
            // PowerShell escapes with a backtick and also accepts the
            // typographic quotes U+201C/U+201D/U+201E as string delimiters.
            let quoted = escape_for_double_quotes(
                value,
                '`',
                &['`', '"', '$', '\u{201C}', '\u{201D}', '\u{201E}'],
            );
            format!("{}{}=\"{}\"", export, key, quoted)
        }
        "fish" => {
            // In fish double quotes only `\`, `"` and `$` are special.
            let quoted = escape_for_double_quotes(value, '\\', &['\\', '"', '$']);
            format!("{} {} \"{}\"", export, key, quoted)
        }
        _ => {
            // POSIX-style shells additionally run backtick command substitution.
            let quoted = escape_for_double_quotes(value, '\\', &['\\', '"', '$', '`']);
            format!("{} {}=\"{}\"", export, key, quoted)
        }
    }
}

/// Prefix each character of `value` that appears in `special` with `escape`.
fn escape_for_double_quotes(value: &str, escape: char, special: &[char]) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        if special.contains(&ch) {
            out.push(escape);
        }
        out.push(ch);
    }
    out
}
411
412/// Generate a JSON credentials document suitable for file-based injection.
413pub fn credentials_to_json(resp: &CredentialResponse) -> Result<String> {
414    let mut map = HashMap::new();
415    map.insert("session_id", resp.session_id.as_str());
416    map.insert("provider", resp.provider.as_str());
417    map.insert("expires_at", resp.expires_at.as_str());
418
419    let creds = &resp.credentials;
420    if let Some(ref v) = creds.aws_access_key_id {
421        map.insert("aws_access_key_id", v);
422    }
423    if let Some(ref v) = creds.aws_secret_access_key {
424        map.insert("aws_secret_access_key", v);
425    }
426    if let Some(ref v) = creds.aws_session_token {
427        map.insert("aws_session_token", v);
428    }
429    if let Some(ref v) = creds.gcp_access_token {
430        map.insert("gcp_access_token", v);
431    }
432    if let Some(ref v) = creds.azure_token {
433        map.insert("azure_token", v);
434    }
435
436    serde_json::to_string_pretty(&map)
437        .map_err(|e| AvError::Sts(format!("Failed to serialize credentials: {}", e)))
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::CredentialEnvVars;

    /// Minimal AWS request shared by the serialization tests.
    fn s3_read_request() -> CredentialRequest {
        CredentialRequest {
            allow: Some("s3:GetObject".to_string()),
            profile: None,
            resource: None,
            provider: "aws".to_string(),
            ttl: "15m".to_string(),
            role_arn: None,
            command: vec![],
            agent_id: None,
        }
    }

    /// AWS-only credential response used by the formatting tests.
    fn mock_credential_response() -> CredentialResponse {
        let credentials = CredentialEnvVars {
            aws_access_key_id: Some("AKID123".to_string()),
            aws_secret_access_key: Some("secret".to_string()),
            aws_session_token: Some("token".to_string()),
            gcp_access_token: None,
            azure_token: None,
        };
        CredentialResponse {
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            ttl_seconds: 900,
            credentials,
        }
    }

    #[test]
    fn test_broker_config_default() {
        let defaults = BrokerConfig::default();
        assert_eq!(defaults.url, "http://localhost:8080");
        assert_eq!(defaults.timeout, 30);
        assert!(defaults.api_key.is_none());
    }

    #[test]
    fn test_broker_config_deserialize() {
        let toml_str = r#"
url = "https://audex.internal:8443"
api_key = "secret-key-123"
timeout = 60
default_provider = "gcp"
default_role_arn = "arn:aws:iam::123:role/MyRole"
"#;
        let parsed: BrokerConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(parsed.url, "https://audex.internal:8443");
        assert_eq!(parsed.api_key.as_deref(), Some("secret-key-123"));
        assert_eq!(parsed.timeout, 60);
        assert_eq!(parsed.default_provider.as_deref(), Some("gcp"));
    }

    #[test]
    fn test_batch_request_serialization() {
        let batch = BatchCredentialRequest {
            requests: vec![s3_read_request()],
        };
        let encoded = serde_json::to_string(&batch).unwrap();
        assert!(encoded.contains("s3:GetObject"));
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{"results":[{"index":0,"credentials":{"session_id":"abc","provider":"aws","expires_at":"2026-01-01T00:00:00Z","ttl_seconds":900,"credentials":{"aws_access_key_id":"AKID"}},"error":null}],"total":1,"succeeded":1,"failed":0}"#;
        let decoded: BatchCredentialResponse = serde_json::from_str(json).unwrap();
        assert_eq!(decoded.total, 1);
        assert_eq!(decoded.succeeded, 1);
    }

    #[test]
    fn test_credentials_to_env_script_bash() {
        let script = credentials_to_env_script(&mock_credential_response(), "bash");
        assert!(script.contains("export AWS_ACCESS_KEY_ID=\"AKID123\""));
        assert!(script.contains("export AWS_SECRET_ACCESS_KEY=\"secret\""));
        assert!(script.contains("export AWS_SESSION_TOKEN=\"token\""));
    }

    #[test]
    fn test_credentials_to_env_script_fish() {
        let script = credentials_to_env_script(&mock_credential_response(), "fish");
        assert!(script.contains("set -gx AWS_ACCESS_KEY_ID \"AKID123\""));
    }

    #[test]
    fn test_credentials_to_env_script_powershell() {
        let script = credentials_to_env_script(&mock_credential_response(), "powershell");
        assert!(script.contains("$env:AWS_ACCESS_KEY_ID=\"AKID123\""));
    }

    #[test]
    fn test_credentials_to_json() {
        let document = credentials_to_json(&mock_credential_response()).unwrap();
        assert!(document.contains("AKID123"));
        assert!(document.contains("session_id"));
    }

    #[test]
    fn test_parse_url_with_port() {
        let ParsedUrl { host, port, path } =
            parse_url("http://localhost:9090/v1/credentials").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 9090);
        assert_eq!(path, "/v1/credentials");
    }

    #[test]
    fn test_parse_url_default_port() {
        let port = parse_url("http://localhost/v1/health").unwrap().port;
        assert_eq!(port, 8080);
    }

    #[test]
    fn test_parse_url_https_default_port() {
        let port = parse_url("https://audex.internal/v1/health").unwrap().port;
        assert_eq!(port, 443);
    }

    #[test]
    fn test_extract_http_body_success() {
        let raw = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"ok\":true}";
        assert_eq!(extract_http_body(raw).unwrap(), "{\"ok\":true}");
    }

    #[test]
    fn test_extract_http_body_error() {
        let raw = "HTTP/1.1 401 Unauthorized\r\n\r\n{\"error\":\"invalid key\"}";
        assert!(extract_http_body(raw).is_err());
    }

    #[test]
    fn test_broker_token_serialization() {
        let token = BrokerToken {
            token: "tok_abc123".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            request: s3_read_request(),
        };
        let encoded = serde_json::to_string(&token).unwrap();
        assert!(encoded.contains("tok_abc123"));
        assert!(encoded.contains("s3:GetObject"));
    }
}