1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::server::{CredentialRequest, CredentialResponse};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct BrokerConfig {
12 pub url: String,
14 pub api_key: Option<String>,
16 #[serde(default = "default_timeout")]
18 pub timeout: u64,
19 pub default_provider: Option<String>,
21 pub default_role_arn: Option<String>,
23}
24
/// Timeout, in seconds, applied when the configuration does not specify one.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
28
29impl Default for BrokerConfig {
30 fn default() -> Self {
31 Self {
32 url: "http://localhost:8080".to_string(),
33 api_key: None,
34 timeout: default_timeout(),
35 default_provider: None,
36 default_role_arn: None,
37 }
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct BatchCredentialRequest {
44 pub requests: Vec<CredentialRequest>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct BatchCredentialResponse {
50 pub results: Vec<BatchResultItem>,
51 pub total: usize,
52 pub succeeded: usize,
53 pub failed: usize,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct BatchResultItem {
59 pub index: usize,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub credentials: Option<CredentialResponse>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub error: Option<String>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct TokenExchangeRequest {
69 pub token: String,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct BrokerToken {
77 pub token: String,
78 pub expires_at: String,
79 pub request: CredentialRequest,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct RevokeRequest {
86 pub session_id: String,
87 pub reason: Option<String>,
88}
89
90pub struct BrokerClient {
92 config: BrokerConfig,
93}
94
95impl BrokerClient {
96 pub fn new(config: BrokerConfig) -> Self {
97 Self { config }
98 }
99
100 pub fn from_env() -> Result<Self> {
103 let url = std::env::var("AUDEX_BROKER_URL").unwrap_or_else(|_| {
104 "http://localhost:8080".to_string()
105 });
106 let api_key = std::env::var("AUDEX_BROKER_API_KEY").ok();
107 Ok(Self::new(BrokerConfig {
108 url,
109 api_key,
110 ..Default::default()
111 }))
112 }
113
114 pub async fn get_credentials(&self, request: &CredentialRequest) -> Result<CredentialResponse> {
116 let url = format!("{}/v1/credentials", self.config.url);
117 let body = serde_json::to_string(request)
118 .map_err(|e| AvError::Sts(format!("Failed to serialize request: {}", e)))?;
119
120 let response = self.http_post(&url, &body).await?;
121 serde_json::from_str(&response)
122 .map_err(|e| AvError::Sts(format!("Failed to parse credential response: {}", e)))
123 }
124
125 pub async fn get_credentials_simple(
127 &self,
128 allow: &str,
129 ttl: &str,
130 ) -> Result<CredentialResponse> {
131 let request = CredentialRequest {
132 allow: Some(allow.to_string()),
133 profile: None,
134 resource: None,
135 provider: self.config.default_provider.clone().unwrap_or_else(|| "aws".to_string()),
136 ttl: ttl.to_string(),
137 role_arn: self.config.default_role_arn.clone(),
138 command: vec![],
139 agent_id: std::env::var("AUDEX_AGENT_ID").ok(),
140 };
141 self.get_credentials(&request).await
142 }
143
144 pub async fn batch_credentials(
146 &self,
147 requests: Vec<CredentialRequest>,
148 ) -> Result<BatchCredentialResponse> {
149 let url = format!("{}/v1/credentials/batch", self.config.url);
150 let batch = BatchCredentialRequest { requests };
151 let body = serde_json::to_string(&batch)
152 .map_err(|e| AvError::Sts(format!("Failed to serialize batch request: {}", e)))?;
153
154 let response = self.http_post(&url, &body).await?;
155 serde_json::from_str(&response)
156 .map_err(|e| AvError::Sts(format!("Failed to parse batch response: {}", e)))
157 }
158
159 pub async fn create_token(&self, request: &CredentialRequest) -> Result<BrokerToken> {
162 let url = format!("{}/v1/tokens", self.config.url);
163 let body = serde_json::to_string(request)
164 .map_err(|e| AvError::Sts(format!("Failed to serialize token request: {}", e)))?;
165
166 let response = self.http_post(&url, &body).await?;
167 serde_json::from_str(&response)
168 .map_err(|e| AvError::Sts(format!("Failed to parse token response: {}", e)))
169 }
170
171 pub async fn exchange_token(&self, token: &str) -> Result<CredentialResponse> {
173 let url = format!("{}/v1/tokens/exchange", self.config.url);
174 let req = TokenExchangeRequest {
175 token: token.to_string(),
176 };
177 let body = serde_json::to_string(&req)
178 .map_err(|e| AvError::Sts(format!("Failed to serialize exchange request: {}", e)))?;
179
180 let response = self.http_post(&url, &body).await?;
181 serde_json::from_str(&response)
182 .map_err(|e| AvError::Sts(format!("Failed to parse exchange response: {}", e)))
183 }
184
185 pub async fn revoke(&self, session_id: &str, reason: Option<&str>) -> Result<()> {
187 let url = format!("{}/v1/sessions/{}/revoke", self.config.url, session_id);
188 let req = RevokeRequest {
189 session_id: session_id.to_string(),
190 reason: reason.map(|r| r.to_string()),
191 };
192 let body = serde_json::to_string(&req)
193 .map_err(|e| AvError::Sts(format!("Failed to serialize revoke request: {}", e)))?;
194
195 self.http_post(&url, &body).await?;
196 Ok(())
197 }
198
199 pub async fn list_sessions(&self) -> Result<String> {
201 let url = format!("{}/v1/sessions", self.config.url);
202 self.http_get(&url).await
203 }
204
205 pub async fn health(&self) -> Result<String> {
207 let url = format!("{}/v1/health", self.config.url);
208 self.http_get(&url).await
209 }
210
211 async fn http_post(&self, url: &str, body: &str) -> Result<String> {
213 use std::io::{Read, Write};
214 use std::net::TcpStream;
215
216 let parsed = parse_url(url)?;
217 let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
218 .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
219 stream
220 .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
221 .ok();
222
223 let mut headers = vec![
224 ("Content-Type", "application/json"),
225 ("Connection", "close"),
226 ];
227 let auth_header;
228 if let Some(ref key) = self.config.api_key {
229 auth_header = format!("Bearer {}", key);
230 headers.push(("Authorization", &auth_header));
231 }
232 let agent_header;
233 if let Ok(agent_id) = std::env::var("AUDEX_AGENT_ID") {
234 agent_header = agent_id;
235 headers.push(("X-Audex-Agent-Id", &agent_header));
236 }
237
238 let mut request = format!(
239 "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Length: {}\r\n",
240 parsed.path, parsed.host, body.len()
241 );
242 for (key, value) in &headers {
243 request.push_str(&format!("{}: {}\r\n", key, value));
244 }
245 request.push_str("\r\n");
246 request.push_str(body);
247
248 let mut stream = stream;
249 stream
250 .write_all(request.as_bytes())
251 .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
252
253 let mut response = String::new();
254 stream
255 .read_to_string(&mut response)
256 .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
257
258 extract_http_body(&response)
259 }
260
261 async fn http_get(&self, url: &str) -> Result<String> {
263 use std::io::{Read, Write};
264 use std::net::TcpStream;
265
266 let parsed = parse_url(url)?;
267 let stream = TcpStream::connect(format!("{}:{}", parsed.host, parsed.port))
268 .map_err(|e| AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e)))?;
269 stream
270 .set_read_timeout(Some(Duration::from_secs(self.config.timeout)))
271 .ok();
272
273 let mut headers: Vec<(&str, &str)> = vec![("Connection", "close")];
274 let auth_header;
275 if let Some(ref key) = self.config.api_key {
276 auth_header = format!("Bearer {}", key);
277 headers.push(("Authorization", &auth_header));
278 }
279
280 let mut request = format!(
281 "GET {} HTTP/1.1\r\nHost: {}\r\n",
282 parsed.path, parsed.host
283 );
284 for (key, value) in &headers {
285 request.push_str(&format!("{}: {}\r\n", key, value));
286 }
287 request.push_str("\r\n");
288
289 let mut stream = stream;
290 stream
291 .write_all(request.as_bytes())
292 .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
293
294 let mut response = String::new();
295 stream
296 .read_to_string(&mut response)
297 .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
298
299 extract_http_body(&response)
300 }
301}
302
/// The pieces of a broker URL needed to open a connection and build a request.
struct ParsedUrl {
    /// Host name or address to connect to.
    host: String,
    /// TCP port to connect to.
    port: u16,
    /// Request target, always starting with `/`.
    path: String,
}
308
309fn parse_url(url: &str) -> Result<ParsedUrl> {
310 let without_scheme = url
311 .strip_prefix("https://")
312 .or_else(|| url.strip_prefix("http://"))
313 .ok_or_else(|| AvError::InvalidPolicy(format!("Invalid broker URL: {}", url)))?;
314
315 let default_port: u16 = if url.starts_with("https://") { 443 } else { 8080 };
316
317 let (host_port, path) = match without_scheme.find('/') {
318 Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
319 None => (without_scheme, "/"),
320 };
321
322 let (host, port) = match host_port.rfind(':') {
323 Some(idx) => {
324 let port = host_port[idx + 1..].parse::<u16>().unwrap_or(default_port);
325 (host_port[..idx].to_string(), port)
326 }
327 None => (host_port.to_string(), default_port),
328 };
329
330 Ok(ParsedUrl {
331 host,
332 port,
333 path: path.to_string(),
334 })
335}
336
337fn extract_http_body(response: &str) -> Result<String> {
338 if let Some(idx) = response.find("\r\n\r\n") {
339 let status_line = response.lines().next().unwrap_or("");
340 let status_code: u16 = status_line
341 .split_whitespace()
342 .nth(1)
343 .and_then(|s| s.parse().ok())
344 .unwrap_or(0);
345
346 let body = &response[idx + 4..];
347
348 if status_code >= 400 {
349 return Err(AvError::Sts(format!(
350 "Broker returned HTTP {}: {}",
351 status_code, body
352 )));
353 }
354
355 Ok(body.to_string())
356 } else {
357 Err(AvError::Sts("Invalid HTTP response from broker".to_string()))
358 }
359}
360
361pub fn credentials_to_env_script(resp: &CredentialResponse, shell: &str) -> String {
364 let mut lines = Vec::new();
365 let export = match shell {
366 "fish" => "set -gx",
367 "powershell" | "pwsh" => "$env:",
368 _ => "export",
369 };
370
371 let creds = &resp.credentials;
372 if let Some(ref key) = creds.aws_access_key_id {
373 lines.push(format_env(export, "AWS_ACCESS_KEY_ID", key, shell));
374 }
375 if let Some(ref key) = creds.aws_secret_access_key {
376 lines.push(format_env(export, "AWS_SECRET_ACCESS_KEY", key, shell));
377 }
378 if let Some(ref token) = creds.aws_session_token {
379 lines.push(format_env(export, "AWS_SESSION_TOKEN", token, shell));
380 }
381 if let Some(ref token) = creds.gcp_access_token {
382 lines.push(format_env(export, "CLOUDSDK_AUTH_ACCESS_TOKEN", token, shell));
383 }
384 if let Some(ref token) = creds.azure_token {
385 lines.push(format_env(export, "AZURE_ACCESS_TOKEN", token, shell));
386 }
387
388 lines.join("\n")
389}
390
/// Formats one environment-variable assignment for the given shell.
///
/// `export` is the shell's assignment keyword/prefix (`export`, `set -gx`,
/// `$env:`), `key` the variable name and `value` its value. The value is
/// placed inside double quotes, and every character that is special inside
/// double quotes for that shell is escaped so the value is taken literally:
///
/// * PowerShell (`"powershell"`/`"pwsh"`): `` ` ``, `"`, `$` escaped with `` ` ``
/// * fish: `\`, `"`, `$` escaped with `\`
/// * everything else (POSIX-style): `\`, `"`, `$`, `` ` `` escaped with `\`
fn format_env(export: &str, key: &str, value: &str, shell: &str) -> String {
    match shell {
        "powershell" | "pwsh" => format!(
            "{}{}=\"{}\"",
            export,
            key,
            escape_double_quoted(value, '`', &['`', '"', '$'])
        ),
        "fish" => format!(
            "{} {} \"{}\"",
            export,
            key,
            escape_double_quoted(value, '\\', &['\\', '"', '$'])
        ),
        _ => format!(
            "{} {}=\"{}\"",
            export,
            key,
            escape_double_quoted(value, '\\', &['\\', '"', '$', '`'])
        ),
    }
}

/// Prefixes every character of `value` found in `special` with `escape`.
fn escape_double_quoted(value: &str, escape: char, special: &[char]) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if special.contains(&ch) {
            escaped.push(escape);
        }
        escaped.push(ch);
    }
    escaped
}
398
399pub fn credentials_to_json(resp: &CredentialResponse) -> Result<String> {
401 let mut map = HashMap::new();
402 map.insert("session_id", resp.session_id.as_str());
403 map.insert("provider", resp.provider.as_str());
404 map.insert("expires_at", resp.expires_at.as_str());
405
406 let creds = &resp.credentials;
407 if let Some(ref v) = creds.aws_access_key_id {
408 map.insert("aws_access_key_id", v);
409 }
410 if let Some(ref v) = creds.aws_secret_access_key {
411 map.insert("aws_secret_access_key", v);
412 }
413 if let Some(ref v) = creds.aws_session_token {
414 map.insert("aws_session_token", v);
415 }
416 if let Some(ref v) = creds.gcp_access_token {
417 map.insert("gcp_access_token", v);
418 }
419 if let Some(ref v) = creds.azure_token {
420 map.insert("azure_token", v);
421 }
422
423 serde_json::to_string_pretty(&map)
424 .map_err(|e| AvError::Sts(format!("Failed to serialize credentials: {}", e)))
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::CredentialEnvVars;

    /// Minimal AWS request shared by the serialization tests.
    fn sample_request() -> CredentialRequest {
        CredentialRequest {
            allow: Some("s3:GetObject".to_string()),
            profile: None,
            resource: None,
            provider: "aws".to_string(),
            ttl: "15m".to_string(),
            role_arn: None,
            command: vec![],
            agent_id: None,
        }
    }

    /// AWS-only credential response used by the formatting tests.
    fn mock_credential_response() -> CredentialResponse {
        let credentials = CredentialEnvVars {
            aws_access_key_id: Some("AKID123".to_string()),
            aws_secret_access_key: Some("secret".to_string()),
            aws_session_token: Some("token".to_string()),
            gcp_access_token: None,
            azure_token: None,
        };
        CredentialResponse {
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            ttl_seconds: 900,
            credentials,
        }
    }

    // --- configuration -----------------------------------------------------

    #[test]
    fn test_broker_config_default() {
        let config = BrokerConfig::default();
        assert_eq!(config.url, "http://localhost:8080");
        assert_eq!(config.timeout, 30);
        assert!(config.api_key.is_none());
    }

    #[test]
    fn test_broker_config_deserialize() {
        let toml_str = r#"
url = "https://audex.internal:8443"
api_key = "secret-key-123"
timeout = 60
default_provider = "gcp"
default_role_arn = "arn:aws:iam::123:role/MyRole"
"#;
        let config: BrokerConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.url, "https://audex.internal:8443");
        assert_eq!(config.api_key.as_deref(), Some("secret-key-123"));
        assert_eq!(config.timeout, 60);
        assert_eq!(config.default_provider.as_deref(), Some("gcp"));
    }

    // --- wire types --------------------------------------------------------

    #[test]
    fn test_batch_request_serialization() {
        let batch = BatchCredentialRequest {
            requests: vec![sample_request()],
        };
        let json = serde_json::to_string(&batch).unwrap();
        assert!(json.contains("s3:GetObject"));
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{"results":[{"index":0,"credentials":{"session_id":"abc","provider":"aws","expires_at":"2026-01-01T00:00:00Z","ttl_seconds":900,"credentials":{"aws_access_key_id":"AKID"}},"error":null}],"total":1,"succeeded":1,"failed":0}"#;
        let resp: BatchCredentialResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.total, 1);
        assert_eq!(resp.succeeded, 1);
    }

    #[test]
    fn test_broker_token_serialization() {
        let token = BrokerToken {
            token: "tok_abc123".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            request: sample_request(),
        };
        let json = serde_json::to_string(&token).unwrap();
        assert!(json.contains("tok_abc123"));
        assert!(json.contains("s3:GetObject"));
    }

    // --- credential formatting ---------------------------------------------

    #[test]
    fn test_credentials_to_env_script_bash() {
        let script = credentials_to_env_script(&mock_credential_response(), "bash");
        assert!(script.contains("export AWS_ACCESS_KEY_ID=\"AKID123\""));
        assert!(script.contains("export AWS_SECRET_ACCESS_KEY=\"secret\""));
        assert!(script.contains("export AWS_SESSION_TOKEN=\"token\""));
    }

    #[test]
    fn test_credentials_to_env_script_fish() {
        let script = credentials_to_env_script(&mock_credential_response(), "fish");
        assert!(script.contains("set -gx AWS_ACCESS_KEY_ID \"AKID123\""));
    }

    #[test]
    fn test_credentials_to_env_script_powershell() {
        let script = credentials_to_env_script(&mock_credential_response(), "powershell");
        assert!(script.contains("$env:AWS_ACCESS_KEY_ID=\"AKID123\""));
    }

    #[test]
    fn test_credentials_to_json() {
        let json = credentials_to_json(&mock_credential_response()).unwrap();
        assert!(json.contains("AKID123"));
        assert!(json.contains("session_id"));
    }

    // --- URL parsing ---------------------------------------------------------

    #[test]
    fn test_parse_url_with_port() {
        let parsed = parse_url("http://localhost:9090/v1/credentials").unwrap();
        assert_eq!(parsed.host, "localhost");
        assert_eq!(parsed.port, 9090);
        assert_eq!(parsed.path, "/v1/credentials");
    }

    #[test]
    fn test_parse_url_default_port() {
        let parsed = parse_url("http://localhost/v1/health").unwrap();
        assert_eq!(parsed.port, 8080);
    }

    #[test]
    fn test_parse_url_https_default_port() {
        let parsed = parse_url("https://audex.internal/v1/health").unwrap();
        assert_eq!(parsed.port, 443);
    }

    // --- HTTP response handling ----------------------------------------------

    #[test]
    fn test_extract_http_body_success() {
        let resp = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"ok\":true}";
        let body = extract_http_body(resp).unwrap();
        assert_eq!(body, "{\"ok\":true}");
    }

    #[test]
    fn test_extract_http_body_error() {
        let resp = "HTTP/1.1 401 Unauthorized\r\n\r\n{\"error\":\"invalid key\"}";
        assert!(extract_http_body(resp).is_err());
    }
}