1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::server::{CredentialRequest, CredentialResponse};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct BrokerConfig {
12 pub url: String,
14 pub api_key: Option<String>,
16 #[serde(default = "default_timeout")]
18 pub timeout: u64,
19 pub default_provider: Option<String>,
21 pub default_role_arn: Option<String>,
23}
24
/// Default broker timeout, in seconds.
///
/// Referenced by `#[serde(default = "default_timeout")]` on
/// `BrokerConfig::timeout` and by `BrokerConfig::default()`.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
28
29impl Default for BrokerConfig {
30 fn default() -> Self {
31 Self {
32 url: "http://localhost:8080".to_string(),
33 api_key: None,
34 timeout: default_timeout(),
35 default_provider: None,
36 default_role_arn: None,
37 }
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct BatchCredentialRequest {
44 pub requests: Vec<CredentialRequest>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct BatchCredentialResponse {
50 pub results: Vec<BatchResultItem>,
51 pub total: usize,
52 pub succeeded: usize,
53 pub failed: usize,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct BatchResultItem {
59 pub index: usize,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub credentials: Option<CredentialResponse>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub error: Option<String>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct TokenExchangeRequest {
69 pub token: String,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct BrokerToken {
77 pub token: String,
78 pub expires_at: String,
79 pub request: CredentialRequest,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct RevokeRequest {
86 pub session_id: String,
87 pub reason: Option<String>,
88}
89
90pub struct BrokerClient {
92 config: BrokerConfig,
93}
94
95impl BrokerClient {
96 pub fn new(config: BrokerConfig) -> Self {
97 Self { config }
98 }
99
100 pub fn from_env() -> Result<Self> {
103 let url = std::env::var("AUDEX_BROKER_URL")
104 .unwrap_or_else(|_| "http://localhost:8080".to_string());
105 let api_key = std::env::var("AUDEX_BROKER_API_KEY").ok();
106 Ok(Self::new(BrokerConfig {
107 url,
108 api_key,
109 ..Default::default()
110 }))
111 }
112
113 pub async fn get_credentials(&self, request: &CredentialRequest) -> Result<CredentialResponse> {
115 let url = format!("{}/v1/credentials", self.config.url);
116 let body = serde_json::to_string(request)
117 .map_err(|e| AvError::Sts(format!("Failed to serialize request: {}", e)))?;
118
119 let response = self.http_post(&url, &body).await?;
120 serde_json::from_str(&response)
121 .map_err(|e| AvError::Sts(format!("Failed to parse credential response: {}", e)))
122 }
123
124 pub async fn get_credentials_simple(
126 &self,
127 allow: &str,
128 ttl: &str,
129 ) -> Result<CredentialResponse> {
130 let request = CredentialRequest {
131 allow: Some(allow.to_string()),
132 profile: None,
133 resource: None,
134 provider: self
135 .config
136 .default_provider
137 .clone()
138 .unwrap_or_else(|| "aws".to_string()),
139 ttl: ttl.to_string(),
140 role_arn: self.config.default_role_arn.clone(),
141 command: vec![],
142 agent_id: std::env::var("AUDEX_AGENT_ID").ok(),
143 };
144 self.get_credentials(&request).await
145 }
146
147 pub async fn batch_credentials(
149 &self,
150 requests: Vec<CredentialRequest>,
151 ) -> Result<BatchCredentialResponse> {
152 let url = format!("{}/v1/credentials/batch", self.config.url);
153 let batch = BatchCredentialRequest { requests };
154 let body = serde_json::to_string(&batch)
155 .map_err(|e| AvError::Sts(format!("Failed to serialize batch request: {}", e)))?;
156
157 let response = self.http_post(&url, &body).await?;
158 serde_json::from_str(&response)
159 .map_err(|e| AvError::Sts(format!("Failed to parse batch response: {}", e)))
160 }
161
162 pub async fn create_token(&self, request: &CredentialRequest) -> Result<BrokerToken> {
165 let url = format!("{}/v1/tokens", self.config.url);
166 let body = serde_json::to_string(request)
167 .map_err(|e| AvError::Sts(format!("Failed to serialize token request: {}", e)))?;
168
169 let response = self.http_post(&url, &body).await?;
170 serde_json::from_str(&response)
171 .map_err(|e| AvError::Sts(format!("Failed to parse token response: {}", e)))
172 }
173
174 pub async fn exchange_token(&self, token: &str) -> Result<CredentialResponse> {
176 let url = format!("{}/v1/tokens/exchange", self.config.url);
177 let req = TokenExchangeRequest {
178 token: token.to_string(),
179 };
180 let body = serde_json::to_string(&req)
181 .map_err(|e| AvError::Sts(format!("Failed to serialize exchange request: {}", e)))?;
182
183 let response = self.http_post(&url, &body).await?;
184 serde_json::from_str(&response)
185 .map_err(|e| AvError::Sts(format!("Failed to parse exchange response: {}", e)))
186 }
187
188 pub async fn revoke(&self, session_id: &str, reason: Option<&str>) -> Result<()> {
190 let url = format!("{}/v1/sessions/{}/revoke", self.config.url, session_id);
191 let req = RevokeRequest {
192 session_id: session_id.to_string(),
193 reason: reason.map(|r| r.to_string()),
194 };
195 let body = serde_json::to_string(&req)
196 .map_err(|e| AvError::Sts(format!("Failed to serialize revoke request: {}", e)))?;
197
198 self.http_post(&url, &body).await?;
199 Ok(())
200 }
201
202 pub async fn list_sessions(&self) -> Result<String> {
204 let url = format!("{}/v1/sessions", self.config.url);
205 self.http_get(&url).await
206 }
207
208 pub async fn health(&self) -> Result<String> {
210 let url = format!("{}/v1/health", self.config.url);
211 self.http_get(&url).await
212 }
213
214 async fn http_post(&self, url: &str, body: &str) -> Result<String> {
216 use std::io::{Read, Write};
217 use std::net::TcpStream;
218
219 let parsed = parse_url(url)?;
220 let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
221 .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
222 stream
223 .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
224 .ok();
225
226 let mut headers = vec![
227 ("Content-Type", "application/json"),
228 ("Connection", "close"),
229 ];
230 let auth_header;
231 if let Some(ref key) = self.config.api_key {
232 auth_header = format!("Bearer {}", key);
233 headers.push(("Authorization", &auth_header));
234 }
235 let agent_header;
236 if let Ok(agent_id) = std::env::var("AUDEX_AGENT_ID") {
237 agent_header = agent_id;
238 headers.push(("X-Audex-Agent-Id", &agent_header));
239 }
240
241 let mut request = format!(
242 "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Length: {}\r\n",
243 parsed.path,
244 parsed.host,
245 body.len()
246 );
247 for (key, value) in &headers {
248 request.push_str(&format!("{}: {}\r\n", key, value));
249 }
250 request.push_str("\r\n");
251 request.push_str(body);
252
253 let mut stream = stream;
254 stream
255 .write_all(request.as_bytes())
256 .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
257
258 let mut response = String::new();
259 stream
260 .read_to_string(&mut response)
261 .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
262
263 extract_http_body(&response)
264 }
265
266 async fn http_get(&self, url: &str) -> Result<String> {
268 use std::io::{Read, Write};
269 use std::net::TcpStream;
270
271 let parsed = parse_url(url)?;
272 let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
273 .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
274 stream
275 .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
276 .ok();
277
278 let mut headers: Vec<(&str, &str)> = vec![("Connection", "close")];
279 let auth_header;
280 if let Some(ref key) = self.config.api_key {
281 auth_header = format!("Bearer {}", key);
282 headers.push(("Authorization", &auth_header));
283 }
284
285 let mut request = format!("GET {} HTTP/1.1\r\nHost: {}\r\n", parsed.path, parsed.host);
286 for (key, value) in &headers {
287 request.push_str(&format!("{}: {}\r\n", key, value));
288 }
289 request.push_str("\r\n");
290
291 let mut stream = stream;
292 stream
293 .write_all(request.as_bytes())
294 .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
295
296 let mut response = String::new();
297 stream
298 .read_to_string(&mut response)
299 .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
300
301 extract_http_body(&response)
302 }
303}
304
/// Pieces of a broker URL, as split by `parse_url`.
struct ParsedUrl {
    /// Host name or address as written in the URL.
    host: String,
    /// TCP port to connect to.
    port: u16,
    /// Request target, always starting with `/`.
    path: String,
}
310
311fn parse_url(url: &str) -> Result<ParsedUrl> {
312 let without_scheme = url
313 .strip_prefix("https://")
314 .or_else(|| url.strip_prefix("http://"))
315 .ok_or_else(|| AvError::InvalidPolicy(format!("Invalid broker URL: {}", url)))?;
316
317 let default_port: u16 = if url.starts_with("https://") {
318 443
319 } else {
320 8080
321 };
322
323 let (host_port, path) = match without_scheme.find('/') {
324 Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
325 None => (without_scheme, "/"),
326 };
327
328 let (host, port) = match host_port.rfind(':') {
329 Some(idx) => {
330 let port = host_port[idx + 1..].parse::<u16>().unwrap_or(default_port);
331 (host_port[..idx].to_string(), port)
332 }
333 None => (host_port.to_string(), default_port),
334 };
335
336 Ok(ParsedUrl {
337 host,
338 port,
339 path: path.to_string(),
340 })
341}
342
343fn extract_http_body(response: &str) -> Result<String> {
344 if let Some(idx) = response.find("\r\n\r\n") {
345 let status_line = response.lines().next().unwrap_or("");
346 let status_code: u16 = status_line
347 .split_whitespace()
348 .nth(1)
349 .and_then(|s| s.parse().ok())
350 .unwrap_or(0);
351
352 let body = &response[idx + 4..];
353
354 if status_code >= 400 {
355 return Err(AvError::Sts(format!(
356 "Broker returned HTTP {}: {}",
357 status_code, body
358 )));
359 }
360
361 Ok(body.to_string())
362 } else {
363 Err(AvError::Sts(
364 "Invalid HTTP response from broker".to_string(),
365 ))
366 }
367}
368
369pub fn credentials_to_env_script(resp: &CredentialResponse, shell: &str) -> String {
372 let mut lines = Vec::new();
373 let export = match shell {
374 "fish" => "set -gx",
375 "powershell" | "pwsh" => "$env:",
376 _ => "export",
377 };
378
379 let creds = &resp.credentials;
380 if let Some(ref key) = creds.aws_access_key_id {
381 lines.push(format_env(export, "AWS_ACCESS_KEY_ID", key, shell));
382 }
383 if let Some(ref key) = creds.aws_secret_access_key {
384 lines.push(format_env(export, "AWS_SECRET_ACCESS_KEY", key, shell));
385 }
386 if let Some(ref token) = creds.aws_session_token {
387 lines.push(format_env(export, "AWS_SESSION_TOKEN", token, shell));
388 }
389 if let Some(ref token) = creds.gcp_access_token {
390 lines.push(format_env(
391 export,
392 "CLOUDSDK_AUTH_ACCESS_TOKEN",
393 token,
394 shell,
395 ));
396 }
397 if let Some(ref token) = creds.azure_token {
398 lines.push(format_env(export, "AZURE_ACCESS_TOKEN", token, shell));
399 }
400
401 lines.join("\n")
402}
403
/// Formats one environment-variable assignment for `shell`.
///
/// `export` is the assignment keyword chosen by the caller (`export`,
/// `set -gx` or `$env:`). The value is wrapped in double quotes. Every
/// character the target shell treats specially inside double quotes is
/// escaped, so the output can be `eval`ed or sourced safely even when a
/// broker-supplied value contains `$`, backquotes, quotes or backslashes.
/// Previously such values could trigger command substitution.
fn format_env(export: &str, key: &str, value: &str, shell: &str) -> String {
    let quoted = escape_for_double_quotes(value, shell);
    match shell {
        "powershell" | "pwsh" => format!("{}{}=\"{}\"", export, key, quoted),
        "fish" => format!("{} {} \"{}\"", export, key, quoted),
        _ => format!("{} {}=\"{}\"", export, key, quoted),
    }
}

/// Escapes `value` so it reads back literally inside a double-quoted string
/// in `shell`. Shell names are as in `format_env`; anything unrecognised is
/// treated as POSIX sh.
fn escape_for_double_quotes(value: &str, shell: &str) -> String {
    let powershell = matches!(shell, "powershell" | "pwsh");
    let mut escaped = String::with_capacity(value.len());
    for c in value.chars() {
        let special = if powershell {
            // Backtick is PowerShell's escape character; `$` starts an
            // expansion and typographic double quotes also close the string.
            matches!(c, '`' | '"' | '$' | '\u{201C}' | '\u{201D}' | '\u{201E}')
        } else if shell == "fish" {
            // fish only recognises \" \$ and \\ inside double quotes.
            matches!(c, '\\' | '"' | '$')
        } else {
            // POSIX sh/bash/zsh: \ " $ and ` stay special in double quotes.
            matches!(c, '\\' | '"' | '$' | '`')
        };
        if special {
            escaped.push(if powershell { '`' } else { '\\' });
        }
        escaped.push(c);
    }
    escaped
}
411
412pub fn credentials_to_json(resp: &CredentialResponse) -> Result<String> {
414 let mut map = HashMap::new();
415 map.insert("session_id", resp.session_id.as_str());
416 map.insert("provider", resp.provider.as_str());
417 map.insert("expires_at", resp.expires_at.as_str());
418
419 let creds = &resp.credentials;
420 if let Some(ref v) = creds.aws_access_key_id {
421 map.insert("aws_access_key_id", v);
422 }
423 if let Some(ref v) = creds.aws_secret_access_key {
424 map.insert("aws_secret_access_key", v);
425 }
426 if let Some(ref v) = creds.aws_session_token {
427 map.insert("aws_session_token", v);
428 }
429 if let Some(ref v) = creds.gcp_access_token {
430 map.insert("gcp_access_token", v);
431 }
432 if let Some(ref v) = creds.azure_token {
433 map.insert("azure_token", v);
434 }
435
436 serde_json::to_string_pretty(&map)
437 .map_err(|e| AvError::Sts(format!("Failed to serialize credentials: {}", e)))
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::CredentialEnvVars;

    /// Minimal AWS request shared by the serialization tests.
    fn s3_get_object_request() -> CredentialRequest {
        CredentialRequest {
            allow: Some("s3:GetObject".to_string()),
            profile: None,
            resource: None,
            provider: "aws".to_string(),
            ttl: "15m".to_string(),
            role_arn: None,
            command: vec![],
            agent_id: None,
        }
    }

    #[test]
    fn test_broker_config_default() {
        let config = BrokerConfig::default();
        assert_eq!(config.url, "http://localhost:8080");
        assert_eq!(config.timeout, 30);
        assert!(config.api_key.is_none());
        assert!(config.default_provider.is_none());
        assert!(config.default_role_arn.is_none());
    }

    #[test]
    fn test_broker_config_deserialize() {
        let toml_str = r#"
url = "https://audex.internal:8443"
api_key = "secret-key-123"
timeout = 60
default_provider = "gcp"
default_role_arn = "arn:aws:iam::123:role/MyRole"
"#;
        let config: BrokerConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.url, "https://audex.internal:8443");
        assert_eq!(config.api_key.as_deref(), Some("secret-key-123"));
        assert_eq!(config.timeout, 60);
        assert_eq!(config.default_provider.as_deref(), Some("gcp"));
        assert_eq!(
            config.default_role_arn.as_deref(),
            Some("arn:aws:iam::123:role/MyRole")
        );
    }

    #[test]
    fn test_broker_config_timeout_defaults_when_omitted() {
        let config: BrokerConfig = toml::from_str("url = \"http://broker:9000\"").unwrap();
        assert_eq!(config.url, "http://broker:9000");
        assert_eq!(config.timeout, 30);
        assert!(config.api_key.is_none());
    }

    #[test]
    fn test_batch_request_serialization() {
        let batch = BatchCredentialRequest {
            requests: vec![s3_get_object_request()],
        };
        let json = serde_json::to_string(&batch).unwrap();
        assert!(json.contains("\"requests\""));
        assert!(json.contains("s3:GetObject"));
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{"results":[{"index":0,"credentials":{"session_id":"abc","provider":"aws","expires_at":"2026-01-01T00:00:00Z","ttl_seconds":900,"credentials":{"aws_access_key_id":"AKID"}},"error":null}],"total":1,"succeeded":1,"failed":0}"#;
        let resp: BatchCredentialResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.total, 1);
        assert_eq!(resp.succeeded, 1);
        assert_eq!(resp.failed, 0);
        assert_eq!(resp.results.len(), 1);

        let item = &resp.results[0];
        assert_eq!(item.index, 0);
        assert!(item.error.is_none());
        let creds = item.credentials.as_ref().expect("credentials present");
        assert_eq!(creds.session_id, "abc");
        assert_eq!(creds.credentials.aws_access_key_id.as_deref(), Some("AKID"));
    }

    #[test]
    fn test_credentials_to_env_script_bash() {
        let resp = mock_credential_response();
        let script = credentials_to_env_script(&resp, "bash");
        assert!(script.contains("export AWS_ACCESS_KEY_ID=\"AKID123\""));
        assert!(script.contains("export AWS_SECRET_ACCESS_KEY=\"secret\""));
        assert!(script.contains("export AWS_SESSION_TOKEN=\"token\""));
        // Absent GCP/Azure tokens must not produce lines.
        assert_eq!(script.lines().count(), 3);
        assert!(!script.contains("AZURE_ACCESS_TOKEN"));
    }

    #[test]
    fn test_credentials_to_env_script_fish() {
        let resp = mock_credential_response();
        let script = credentials_to_env_script(&resp, "fish");
        assert!(script.contains("set -gx AWS_ACCESS_KEY_ID \"AKID123\""));
    }

    #[test]
    fn test_credentials_to_env_script_powershell() {
        let resp = mock_credential_response();
        let script = credentials_to_env_script(&resp, "powershell");
        assert!(script.contains("$env:AWS_ACCESS_KEY_ID=\"AKID123\""));
    }

    #[test]
    fn test_credentials_to_json() {
        let resp = mock_credential_response();
        let json = credentials_to_json(&resp).unwrap();
        assert!(json.contains("AKID123"));
        assert!(json.contains("session_id"));
        assert!(json.contains("test-session-123"));
        assert!(!json.contains("gcp_access_token"));
    }

    #[test]
    fn test_parse_url_with_port() {
        let parsed = parse_url("http://localhost:9090/v1/credentials").unwrap();
        assert_eq!(parsed.host, "localhost");
        assert_eq!(parsed.port, 9090);
        assert_eq!(parsed.path, "/v1/credentials");
    }

    #[test]
    fn test_parse_url_default_port() {
        let parsed = parse_url("http://localhost/v1/health").unwrap();
        assert_eq!(parsed.port, 8080);
    }

    #[test]
    fn test_parse_url_https_default_port() {
        let parsed = parse_url("https://audex.internal/v1/health").unwrap();
        assert_eq!(parsed.port, 443);
    }

    #[test]
    fn test_parse_url_without_path() {
        let parsed = parse_url("http://broker:9090").unwrap();
        assert_eq!(parsed.host, "broker");
        assert_eq!(parsed.port, 9090);
        assert_eq!(parsed.path, "/");
    }

    #[test]
    fn test_parse_url_rejects_unknown_scheme() {
        assert!(parse_url("ftp://broker/v1").is_err());
    }

    #[test]
    fn test_extract_http_body_success() {
        let resp = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"ok\":true}";
        let body = extract_http_body(resp).unwrap();
        assert_eq!(body, "{\"ok\":true}");
    }

    #[test]
    fn test_extract_http_body_no_content() {
        let body = extract_http_body("HTTP/1.1 204 No Content\r\n\r\n").unwrap();
        assert_eq!(body, "");
    }

    #[test]
    fn test_extract_http_body_error() {
        let resp = "HTTP/1.1 401 Unauthorized\r\n\r\n{\"error\":\"invalid key\"}";
        assert!(extract_http_body(resp).is_err());
    }

    #[test]
    fn test_extract_http_body_without_header_terminator() {
        assert!(extract_http_body("garbage").is_err());
    }

    #[test]
    fn test_broker_token_serialization() {
        let token = BrokerToken {
            token: "tok_abc123".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            request: s3_get_object_request(),
        };
        let json = serde_json::to_string(&token).unwrap();
        assert!(json.contains("tok_abc123"));
        assert!(json.contains("s3:GetObject"));
    }

    fn mock_credential_response() -> CredentialResponse {
        CredentialResponse {
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            ttl_seconds: 900,
            credentials: CredentialEnvVars {
                aws_access_key_id: Some("AKID123".to_string()),
                aws_secret_access_key: Some("secret".to_string()),
                aws_session_token: Some("token".to_string()),
                gcp_access_token: None,
                azure_token: None,
            },
        }
    }
}