1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::server::{CredentialRequest, CredentialResponse};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct BrokerConfig {
12 pub url: String,
14 pub api_key: Option<String>,
16 #[serde(default = "default_timeout")]
18 pub timeout: u64,
19 pub default_provider: Option<String>,
21 pub default_role_arn: Option<String>,
23}
24
/// Serde default for the `timeout` field of `BrokerConfig`, in seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
28
29impl Default for BrokerConfig {
30 fn default() -> Self {
31 Self {
32 url: "http://localhost:8080".to_string(),
33 api_key: None,
34 timeout: default_timeout(),
35 default_provider: None,
36 default_role_arn: None,
37 }
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct BatchCredentialRequest {
44 pub requests: Vec<CredentialRequest>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct BatchCredentialResponse {
50 pub results: Vec<BatchResultItem>,
51 pub total: usize,
52 pub succeeded: usize,
53 pub failed: usize,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct BatchResultItem {
59 pub index: usize,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub credentials: Option<CredentialResponse>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub error: Option<String>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct TokenExchangeRequest {
69 pub token: String,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct BrokerToken {
77 pub token: String,
78 pub expires_at: String,
79 pub request: CredentialRequest,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct RevokeRequest {
86 pub session_id: String,
87 pub reason: Option<String>,
88}
89
90pub struct BrokerClient {
92 config: BrokerConfig,
93}
94
95impl BrokerClient {
96 pub fn new(config: BrokerConfig) -> Self {
97 Self { config }
98 }
99
100 pub fn from_env() -> Result<Self> {
103 let url = std::env::var("AUDEX_BROKER_URL")
104 .unwrap_or_else(|_| "http://localhost:8080".to_string());
105 let api_key = std::env::var("AUDEX_BROKER_API_KEY").ok();
106 Ok(Self::new(BrokerConfig {
107 url,
108 api_key,
109 ..Default::default()
110 }))
111 }
112
113 pub async fn get_credentials(&self, request: &CredentialRequest) -> Result<CredentialResponse> {
115 let url = format!("{}/v1/credentials", self.config.url);
116 let body = serde_json::to_string(request)
117 .map_err(|e| AvError::Sts(format!("Failed to serialize request: {}", e)))?;
118
119 let response = self.http_post(&url, &body).await?;
120 serde_json::from_str(&response)
121 .map_err(|e| AvError::Sts(format!("Failed to parse credential response: {}", e)))
122 }
123
124 pub async fn get_credentials_simple(
126 &self,
127 allow: &str,
128 ttl: &str,
129 ) -> Result<CredentialResponse> {
130 let request = CredentialRequest {
131 allow: Some(allow.to_string()),
132 profile: None,
133 resource: None,
134 provider: self
135 .config
136 .default_provider
137 .clone()
138 .unwrap_or_else(|| "aws".to_string()),
139 ttl: ttl.to_string(),
140 role_arn: self.config.default_role_arn.clone(),
141 command: vec![],
142 agent_id: std::env::var("AUDEX_AGENT_ID").ok(),
143 };
144 self.get_credentials(&request).await
145 }
146
147 pub async fn batch_credentials(
149 &self,
150 requests: Vec<CredentialRequest>,
151 ) -> Result<BatchCredentialResponse> {
152 let url = format!("{}/v1/credentials/batch", self.config.url);
153 let batch = BatchCredentialRequest { requests };
154 let body = serde_json::to_string(&batch)
155 .map_err(|e| AvError::Sts(format!("Failed to serialize batch request: {}", e)))?;
156
157 let response = self.http_post(&url, &body).await?;
158 serde_json::from_str(&response)
159 .map_err(|e| AvError::Sts(format!("Failed to parse batch response: {}", e)))
160 }
161
162 pub async fn create_token(&self, request: &CredentialRequest) -> Result<BrokerToken> {
165 let url = format!("{}/v1/tokens", self.config.url);
166 let body = serde_json::to_string(request)
167 .map_err(|e| AvError::Sts(format!("Failed to serialize token request: {}", e)))?;
168
169 let response = self.http_post(&url, &body).await?;
170 serde_json::from_str(&response)
171 .map_err(|e| AvError::Sts(format!("Failed to parse token response: {}", e)))
172 }
173
174 pub async fn exchange_token(&self, token: &str) -> Result<CredentialResponse> {
176 let url = format!("{}/v1/tokens/exchange", self.config.url);
177 let req = TokenExchangeRequest {
178 token: token.to_string(),
179 };
180 let body = serde_json::to_string(&req)
181 .map_err(|e| AvError::Sts(format!("Failed to serialize exchange request: {}", e)))?;
182
183 let response = self.http_post(&url, &body).await?;
184 serde_json::from_str(&response)
185 .map_err(|e| AvError::Sts(format!("Failed to parse exchange response: {}", e)))
186 }
187
188 pub async fn revoke(&self, session_id: &str, reason: Option<&str>) -> Result<()> {
190 crate::validate::session_id(session_id)?;
191 let url = format!("{}/v1/sessions/{}/revoke", self.config.url, session_id);
192 let req = RevokeRequest {
193 session_id: session_id.to_string(),
194 reason: reason.map(|r| r.to_string()),
195 };
196 let body = serde_json::to_string(&req)
197 .map_err(|e| AvError::Sts(format!("Failed to serialize revoke request: {}", e)))?;
198
199 self.http_post(&url, &body).await?;
200 Ok(())
201 }
202
203 pub async fn list_sessions(&self) -> Result<String> {
205 let url = format!("{}/v1/sessions", self.config.url);
206 self.http_get(&url).await
207 }
208
209 pub async fn health(&self) -> Result<String> {
211 let url = format!("{}/v1/health", self.config.url);
212 self.http_get(&url).await
213 }
214
215 async fn http_post(&self, url: &str, body: &str) -> Result<String> {
217 self.http_request("POST", url, Some(body)).await
218 }
219
220 async fn http_get(&self, url: &str) -> Result<String> {
222 self.http_request("GET", url, None).await
223 }
224
225 async fn http_request(&self, method: &str, url: &str, body: Option<&str>) -> Result<String> {
232 use tokio::io::{AsyncReadExt, AsyncWriteExt};
233
234 let parsed = parse_url(url)?;
235
236 let api_key = match &self.config.api_key {
240 Some(k) => {
241 validate_header_value("api_key", k)?;
242 Some(k.clone())
243 }
244 None => None,
245 };
246 let agent_id = std::env::var("AUDEX_AGENT_ID").ok();
247 if let Some(ref a) = agent_id {
248 validate_header_value("AUDEX_AGENT_ID", a)?;
249 }
250
251 let mut request = String::new();
252 request.push_str(&format!("{} {} HTTP/1.1\r\n", method, parsed.path));
253 request.push_str(&format!("Host: {}\r\n", parsed.host_header));
255 request.push_str("Connection: close\r\n");
256 if let Some(b) = body {
257 request.push_str("Content-Type: application/json\r\n");
258 request.push_str(&format!("Content-Length: {}\r\n", b.len()));
259 }
260 if let Some(ref key) = api_key {
261 request.push_str(&format!("Authorization: Bearer {}\r\n", key));
262 }
263 if let Some(ref a) = agent_id {
264 request.push_str(&format!("X-Audex-Agent-Id: {}\r\n", a));
265 }
266 request.push_str("\r\n");
267 if let Some(b) = body {
268 request.push_str(b);
269 }
270
271 let deadline = Duration::from_secs(self.config.timeout);
272 let exchange = async {
273 let mut stream =
274 tokio::net::TcpStream::connect(format!("{}:{}", parsed.connect_host, parsed.port))
275 .await
276 .map_err(|e| {
277 AvError::Sts(format!("Failed to connect to broker at {}: {}", url, e))
278 })?;
279
280 stream
281 .write_all(request.as_bytes())
282 .await
283 .map_err(|e| AvError::Sts(format!("Failed to send broker request: {}", e)))?;
284
285 let mut response = Vec::new();
290 let mut limited = stream.take(2 * 1024 * 1024);
291 limited
292 .read_to_end(&mut response)
293 .await
294 .map_err(|e| AvError::Sts(format!("Failed to read broker response: {}", e)))?;
295 let response = String::from_utf8_lossy(&response).into_owned();
296 extract_http_body(&response)
297 };
298
299 tokio::time::timeout(deadline, exchange)
300 .await
301 .map_err(|_| {
302 AvError::Sts(format!(
303 "Broker request timed out after {}s",
304 deadline.as_secs()
305 ))
306 })?
307 }
308}
309
310fn validate_header_value(name: &str, value: &str) -> Result<()> {
313 if value.bytes().any(|b| b == b'\r' || b == b'\n' || b == 0) {
314 return Err(AvError::InvalidPolicy(format!(
315 "Broker header value for `{}` contains forbidden characters (CR/LF/NUL)",
316 name
317 )));
318 }
319 Ok(())
320}
321
/// Pieces of a broker URL as produced by `parse_url`.
#[derive(Debug)]
struct ParsedUrl {
    /// Host as written in the URL, without IPv6 brackets.
    // Not read by the non-test code in this file, hence the lint exemption.
    #[allow(dead_code)]
    host: String,
    /// Host in the form used for the `Host:` header (IPv6 literals bracketed).
    host_header: String,
    /// Host in the form used to build the socket address (IPv6 literals
    /// bracketed).
    connect_host: String,
    /// TCP port to connect to.
    port: u16,
    /// Everything from the first `/` after the authority, or `/` if absent.
    path: String,
}
334
335fn parse_url(url: &str) -> Result<ParsedUrl> {
336 use std::net::IpAddr;
337 use std::str::FromStr;
338
339 if url.starts_with("https://") {
342 return Err(AvError::InvalidPolicy(
343 "Broker client does not support HTTPS. Use http:// or place the broker \
344 behind a TLS-terminating reverse proxy."
345 .to_string(),
346 ));
347 }
348
349 let without_scheme = url.strip_prefix("http://").ok_or_else(|| {
350 AvError::InvalidPolicy(format!("Invalid broker URL (expected http://): {}", url))
351 })?;
352
353 let default_port: u16 = 8080;
354
355 let (host_port, path) = match without_scheme.find('/') {
356 Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
357 None => (without_scheme, "/"),
358 };
359
360 let (raw_host, port_str): (&str, Option<&str>) = if let Some(rest) = host_port.strip_prefix('[')
365 {
366 match rest.find(']') {
367 Some(end) => {
368 let host = &rest[..end];
369 let after = &rest[end + 1..];
370 let port = if after.is_empty() {
371 None
372 } else if let Some(p) = after.strip_prefix(':') {
373 Some(p)
374 } else {
375 return Err(AvError::InvalidPolicy(format!(
376 "Broker URL has malformed IPv6 host: {}",
377 url
378 )));
379 };
380 (host, port)
381 }
382 None => {
383 return Err(AvError::InvalidPolicy(format!(
384 "Broker URL has unterminated IPv6 bracket: {}",
385 url
386 )));
387 }
388 }
389 } else {
390 match host_port.rfind(':') {
391 Some(idx) => (&host_port[..idx], Some(&host_port[idx + 1..])),
392 None => (host_port, None),
393 }
394 };
395
396 let port = match port_str {
401 Some(s) => s.parse::<u16>().map_err(|_| {
402 AvError::InvalidPolicy(format!("Broker URL has invalid port `{}` in {}", s, url))
403 })?,
404 None => default_port,
405 };
406
407 if raw_host.is_empty() {
408 return Err(AvError::InvalidPolicy(format!(
409 "Broker URL has empty host: {}",
410 url
411 )));
412 }
413
414 let is_loopback = if let Ok(ip) = IpAddr::from_str(raw_host) {
421 ip.is_loopback()
422 } else {
423 raw_host.eq_ignore_ascii_case("localhost")
424 };
425 if !is_loopback {
426 return Err(AvError::InvalidPolicy(format!(
427 "Broker URL '{}' uses plaintext HTTP to a non-loopback host. \
428 This would send credentials in cleartext. Use a loopback address \
429 (localhost/127.0.0.1) or place the broker behind a TLS-terminating \
430 reverse proxy and use the AUDEX_BROKER_URL env var with http://localhost.",
431 url
432 )));
433 }
434
435 let is_ipv6 = raw_host.contains(':');
439 let (host_header, connect_host) = if is_ipv6 {
440 (format!("[{}]", raw_host), format!("[{}]", raw_host))
441 } else {
442 (raw_host.to_string(), raw_host.to_string())
443 };
444
445 Ok(ParsedUrl {
446 host: raw_host.to_string(),
447 host_header,
448 connect_host,
449 port,
450 path: path.to_string(),
451 })
452}
453
454fn extract_http_body(response: &str) -> Result<String> {
455 if let Some(idx) = response.find("\r\n\r\n") {
456 let status_line = response.lines().next().unwrap_or("");
457 let status_code: u16 = status_line
458 .split_whitespace()
459 .nth(1)
460 .and_then(|s| s.parse().ok())
461 .unwrap_or(0);
462
463 let body = &response[idx + 4..];
464
465 if status_code >= 400 {
466 return Err(AvError::Sts(format!(
467 "Broker returned HTTP {}: {}",
468 status_code, body
469 )));
470 }
471
472 Ok(body.to_string())
473 } else {
474 Err(AvError::Sts(
475 "Invalid HTTP response from broker".to_string(),
476 ))
477 }
478}
479
480pub fn credentials_to_env_script(resp: &CredentialResponse, shell: &str) -> String {
483 let mut lines = Vec::new();
484 let export = match shell {
485 "fish" => "set -gx",
486 "powershell" | "pwsh" => "$env:",
487 _ => "export",
488 };
489
490 let creds = &resp.credentials;
491 if let Some(ref key) = creds.aws_access_key_id {
492 lines.push(format_env(export, "AWS_ACCESS_KEY_ID", key, shell));
493 }
494 if let Some(ref key) = creds.aws_secret_access_key {
495 lines.push(format_env(export, "AWS_SECRET_ACCESS_KEY", key, shell));
496 }
497 if let Some(ref token) = creds.aws_session_token {
498 lines.push(format_env(export, "AWS_SESSION_TOKEN", token, shell));
499 }
500 if let Some(ref token) = creds.gcp_access_token {
501 lines.push(format_env(
502 export,
503 "CLOUDSDK_AUTH_ACCESS_TOKEN",
504 token,
505 shell,
506 ));
507 }
508 if let Some(ref token) = creds.azure_token {
509 lines.push(format_env(export, "AZURE_ACCESS_TOKEN", token, shell));
510 }
511
512 lines.join("\n")
513}
514
/// Formats one environment-variable assignment for the given shell.
///
/// `export` is the shell's assignment prefix (`export`, `set -gx`, `$env:`),
/// `key` the variable name, and `value` the text to assign. The value is
/// always emitted inside double quotes, escaped with the rules of the target
/// shell so it is neither expanded nor altered when the line is evaluated:
///
/// * `powershell` / `pwsh`: the escape character is the backtick; it is put
///   before `` ` ``, `$`, and every double-quote character PowerShell
///   recognizes (including the typographic ones). Backslashes are literal.
/// * `fish`: a backslash is put before `\`, `"`, and `$`, the only
///   characters fish treats specially inside double quotes.
/// * anything else (POSIX-style shells): a backslash is put before `\`, `"`,
///   `` ` ``, and `$`. Each `!` is emitted as a separate single-quoted
///   segment, because a backslash before `!` inside double quotes would stay
///   in the value.
fn format_env(export: &str, key: &str, value: &str, shell: &str) -> String {
    match shell {
        "powershell" | "pwsh" => {
            let escaped = escape_chars(
                value,
                '`',
                &['`', '"', '$', '\u{201C}', '\u{201D}', '\u{201E}'],
            );
            format!("{}{}=\"{}\"", export, key, escaped)
        }
        "fish" => {
            let escaped = escape_chars(value, '\\', &['\\', '"', '$']);
            format!("{} {} \"{}\"", export, key, escaped)
        }
        _ => {
            // Escaping never introduces `!`, so the replacement only touches
            // characters from the original value: close the double quotes,
            // emit '!' single-quoted, and reopen them.
            let escaped =
                escape_chars(value, '\\', &['\\', '"', '`', '$']).replace('!', "\"'!'\"");
            format!("{} {}=\"{}\"", export, key, escaped)
        }
    }
}

/// Returns `value` with `escape` inserted before every character in `special`.
fn escape_chars(value: &str, escape: char, special: &[char]) -> String {
    let mut out = String::with_capacity(value.len());
    for c in value.chars() {
        if special.contains(&c) {
            out.push(escape);
        }
        out.push(c);
    }
    out
}
530
531pub fn credentials_to_json(resp: &CredentialResponse) -> Result<String> {
533 let mut map = HashMap::new();
534 map.insert("session_id", resp.session_id.as_str());
535 map.insert("provider", resp.provider.as_str());
536 map.insert("expires_at", resp.expires_at.as_str());
537
538 let creds = &resp.credentials;
539 if let Some(ref v) = creds.aws_access_key_id {
540 map.insert("aws_access_key_id", v);
541 }
542 if let Some(ref v) = creds.aws_secret_access_key {
543 map.insert("aws_secret_access_key", v);
544 }
545 if let Some(ref v) = creds.aws_session_token {
546 map.insert("aws_session_token", v);
547 }
548 if let Some(ref v) = creds.gcp_access_token {
549 map.insert("gcp_access_token", v);
550 }
551 if let Some(ref v) = creds.azure_token {
552 map.insert("azure_token", v);
553 }
554
555 serde_json::to_string_pretty(&map)
556 .map_err(|e| AvError::Sts(format!("Failed to serialize credentials: {}", e)))
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::CredentialEnvVars;

    /// Minimal AWS `s3:GetObject` request shared by the serialization tests.
    fn sample_request() -> CredentialRequest {
        CredentialRequest {
            allow: Some("s3:GetObject".to_string()),
            profile: None,
            resource: None,
            provider: "aws".to_string(),
            ttl: "15m".to_string(),
            role_arn: None,
            command: vec![],
            agent_id: None,
        }
    }

    /// AWS-only response used by the formatting tests.
    fn mock_credential_response() -> CredentialResponse {
        let credentials = CredentialEnvVars {
            aws_access_key_id: Some("AKID123".to_string()),
            aws_secret_access_key: Some("secret".to_string()),
            aws_session_token: Some("token".to_string()),
            gcp_access_token: None,
            azure_token: None,
        };
        CredentialResponse {
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            ttl_seconds: 900,
            credentials,
        }
    }

    #[test]
    fn test_broker_config_default() {
        let BrokerConfig {
            url,
            api_key,
            timeout,
            ..
        } = BrokerConfig::default();
        assert_eq!(url, "http://localhost:8080");
        assert_eq!(timeout, 30);
        assert!(api_key.is_none());
    }

    #[test]
    fn test_broker_config_deserialize() {
        // Deserialization does not vet the URL; that happens per request.
        let toml_str = r#"
url = "https://audex.internal:8443"
api_key = "secret-key-123"
timeout = 60
default_provider = "gcp"
default_role_arn = "arn:aws:iam::123:role/MyRole"
"#;
        let parsed: BrokerConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(parsed.url, "https://audex.internal:8443");
        assert_eq!(parsed.api_key.as_deref(), Some("secret-key-123"));
        assert_eq!(parsed.timeout, 60);
        assert_eq!(parsed.default_provider.as_deref(), Some("gcp"));
    }

    #[test]
    fn test_batch_request_serialization() {
        let batch = BatchCredentialRequest {
            requests: vec![sample_request()],
        };
        let encoded = serde_json::to_string(&batch).unwrap();
        assert!(encoded.contains("s3:GetObject"));
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{"results":[{"index":0,"credentials":{"session_id":"abc","provider":"aws","expires_at":"2026-01-01T00:00:00Z","ttl_seconds":900,"credentials":{"aws_access_key_id":"AKID"}},"error":null}],"total":1,"succeeded":1,"failed":0}"#;
        let decoded: BatchCredentialResponse = serde_json::from_str(json).unwrap();
        assert_eq!(decoded.total, 1);
        assert_eq!(decoded.succeeded, 1);
    }

    #[test]
    fn test_credentials_to_env_script_bash() {
        let script = credentials_to_env_script(&mock_credential_response(), "bash");
        for expected in [
            "export AWS_ACCESS_KEY_ID=\"AKID123\"",
            "export AWS_SECRET_ACCESS_KEY=\"secret\"",
            "export AWS_SESSION_TOKEN=\"token\"",
        ]
        .iter()
        {
            assert!(script.contains(expected));
        }
    }

    #[test]
    fn test_credentials_to_env_script_fish() {
        let script = credentials_to_env_script(&mock_credential_response(), "fish");
        assert!(script.contains("set -gx AWS_ACCESS_KEY_ID \"AKID123\""));
    }

    #[test]
    fn test_credentials_to_env_script_powershell() {
        let script = credentials_to_env_script(&mock_credential_response(), "powershell");
        assert!(script.contains("$env:AWS_ACCESS_KEY_ID=\"AKID123\""));
    }

    #[test]
    fn test_credentials_to_json() {
        let rendered = credentials_to_json(&mock_credential_response()).unwrap();
        assert!(rendered.contains("AKID123"));
        assert!(rendered.contains("session_id"));
    }

    #[test]
    fn test_parse_url_with_port() {
        let ParsedUrl {
            host, port, path, ..
        } = parse_url("http://localhost:9090/v1/credentials").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 9090);
        assert_eq!(path, "/v1/credentials");
    }

    #[test]
    fn test_parse_url_default_port() {
        let parsed = parse_url("http://localhost/v1/health").unwrap();
        assert_eq!(parsed.port, 8080);
    }

    #[test]
    fn test_parse_url_https_rejected() {
        let outcome = parse_url("https://audex.internal/v1/health");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("HTTPS"),
            "Error should mention HTTPS: {}",
            message
        );
    }

    #[test]
    fn test_parse_url_non_loopback_rejected() {
        // Private-network IP literal.
        let outcome = parse_url("http://10.0.0.5:8080/v1/credentials");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("plaintext HTTP"),
            "Error should mention plaintext: {}",
            message
        );

        // Host name other than `localhost`.
        assert!(parse_url("http://audex.internal:8080/v1/credentials").is_err());
    }

    #[test]
    fn test_parse_url_loopback_accepted() {
        let loopback_urls = [
            "http://localhost:8080/v1/creds",
            "http://127.0.0.1:8080/v1/creds",
            "http://[::1]:8080/v1/creds",
            "http://127.0.0.2:8080/v1/creds",
            "http://127.1.2.3:8080/v1/creds",
        ];
        for url in loopback_urls.iter() {
            assert!(parse_url(url).is_ok());
        }
    }

    #[test]
    fn test_parse_url_invalid_port_rejected() {
        let err = parse_url("http://localhost:abcd/v1/creds").unwrap_err();
        assert!(err.to_string().contains("invalid port"), "got: {}", err);
    }

    #[test]
    fn test_parse_url_unspecified_address_rejected() {
        assert!(parse_url("http://0.0.0.0:8080/v1/creds").is_err());
    }

    #[test]
    fn test_parse_url_ipv6_bracketed_with_no_port() {
        let parsed = parse_url("http://[::1]/v1/creds").unwrap();
        assert_eq!(parsed.host, "::1");
        assert_eq!(parsed.port, 8080);
        assert_eq!(parsed.host_header, "[::1]");
    }

    #[test]
    fn test_validate_header_value_rejects_crlf() {
        for bad in ["abc\r\nX-Evil: 1", "abc\nevil", "abc\0nul"].iter() {
            assert!(validate_header_value("api_key", bad).is_err());
        }
        assert!(validate_header_value("api_key", "normal-key").is_ok());
    }

    #[test]
    fn test_extract_http_body_success() {
        let raw = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"ok\":true}";
        assert_eq!(extract_http_body(raw).unwrap(), "{\"ok\":true}");
    }

    #[test]
    fn test_extract_http_body_error() {
        let raw = "HTTP/1.1 401 Unauthorized\r\n\r\n{\"error\":\"invalid key\"}";
        assert!(extract_http_body(raw).is_err());
    }

    #[test]
    fn test_broker_token_serialization() {
        let token = BrokerToken {
            token: "tok_abc123".to_string(),
            expires_at: "2026-01-01T00:15:00Z".to_string(),
            request: sample_request(),
        };
        let encoded = serde_json::to_string(&token).unwrap();
        assert!(encoded.contains("tok_abc123"));
        assert!(encoded.contains("s3:GetObject"));
    }
}