1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct CredentialRequest {
6 pub allow: Option<String>,
8 pub profile: Option<String>,
10 pub resource: Option<String>,
12 #[serde(default = "default_provider")]
14 pub provider: String,
15 #[serde(default = "default_ttl")]
17 pub ttl: String,
18 pub role_arn: Option<String>,
20 #[serde(default)]
22 pub command: Vec<String>,
23 pub agent_id: Option<String>,
25}
26
/// Serde fallback for `CredentialRequest::provider`: always `"aws"`.
fn default_provider() -> String {
    String::from("aws")
}
30
/// Serde fallback for `CredentialRequest::ttl`: always `"15m"`.
fn default_ttl() -> String {
    String::from("15m")
}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct CredentialResponse {
38 pub session_id: String,
39 pub provider: String,
40 pub expires_at: String,
41 pub ttl_seconds: u64,
42 pub credentials: CredentialEnvVars,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CredentialEnvVars {
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub aws_access_key_id: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub aws_secret_access_key: Option<String>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub aws_session_token: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub gcp_access_token: Option<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub azure_token: Option<String>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ServerConfig {
63 #[serde(default = "default_port")]
65 pub port: u16,
66 #[serde(default = "default_bind")]
68 pub bind: String,
69 pub api_key: Option<String>,
71}
72
/// Serde fallback for `ServerConfig::port`: always `8080`.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 8080;
    DEFAULT_PORT
}
76
/// Serde fallback for `ServerConfig::bind`: always the loopback address `"127.0.0.1"`.
fn default_bind() -> String {
    String::from("127.0.0.1")
}
80
81impl Default for ServerConfig {
82 fn default() -> Self {
83 Self {
84 port: default_port(),
85 bind: default_bind(),
86 api_key: None,
87 }
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ApiError {
94 pub error: String,
95 pub code: u16,
96}
97
/// Renders a complete HTTP/1.1 response as a single string.
///
/// The response consists of the status line, a `Content-Type` header, a
/// `Content-Length` header holding the byte length of `body`, a blank line and
/// then `body` itself.
///
/// CR, LF and NUL characters are removed from `status_text` and `content_type`
/// so that neither value can break out of the line it is written on. `body` is
/// emitted unchanged.
pub fn http_response(status: u16, status_text: &str, content_type: &str, body: &str) -> String {
    // Drops the characters that could end a header line early.
    let strip_breaks = |value: &str| -> String {
        value
            .chars()
            .filter(|ch| !matches!(ch, '\r' | '\n' | '\0'))
            .collect()
    };

    let reason = strip_breaks(status_text);
    let media_type = strip_breaks(content_type);
    // `str::len` counts bytes, which is the unit Content-Length is expressed in.
    let content_length = body.len();

    format!(
        "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}",
        status, reason, media_type, content_length, body
    )
}
114
115pub fn json_ok(body: &str) -> String {
116 http_response(200, "OK", "application/json", body)
117}
118
119pub fn internal_error(context: &str, err: &dyn std::fmt::Display) -> String {
123 tracing::error!("{}: {}", context, err);
124 json_error(500, &format!("{}: internal error", context))
125}
126
127pub fn json_error(code: u16, message: &str) -> String {
128 let err = ApiError {
129 error: message.to_string(),
130 code,
131 };
132 let body = serde_json::to_string(&err).unwrap_or_else(|_| {
133 let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
135 format!("{{\"error\":\"{}\",\"code\":{}}}", escaped, code)
136 });
137 let status_text = match code {
138 400 => "Bad Request",
139 401 => "Unauthorized",
140 403 => "Forbidden",
141 404 => "Not Found",
142 405 => "Method Not Allowed",
143 413 => "Payload Too Large",
144 429 => "Too Many Requests",
145 500 => "Internal Server Error",
146 503 => "Service Unavailable",
147 _ => "Error",
148 };
149 http_response(code, status_text, "application/json", &body)
150}
151
/// Extracts the method, path and body from a raw HTTP request.
///
/// The method and path are the first two whitespace-separated tokens of the
/// request line (everything before the first CRLF, or the whole input if there
/// is none). The body is everything after the first blank line
/// (`"\r\n\r\n"`), or `""` when the input has no blank line.
///
/// Returns `None` when the request line has fewer than two tokens. Headers are
/// not parsed and the HTTP version token is not validated.
pub fn parse_http_request(raw: &str) -> Option<(&str, &str, &str)> {
    // `split` always yields at least one item, so this only selects the first line.
    let request_line = raw.split("\r\n").next()?;
    let mut tokens = request_line.split_whitespace();
    let method = tokens.next()?;
    let path = tokens.next()?;

    // Split at the FIRST blank line only: a body may itself contain "\r\n\r\n",
    // and taking just the second piece of a full split would truncate it.
    let body = raw
        .split_once("\r\n\r\n")
        .map_or("", |(_head, rest)| rest);

    Some((method, path, body))
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// A GET without a body yields the method, the path and an empty body.
    #[test]
    fn test_parse_http_request_get() {
        let parsed = parse_http_request("GET /v1/sessions HTTP/1.1\r\nHost: localhost\r\n\r\n");
        assert_eq!(parsed, Some(("GET", "/v1/sessions", "")));
    }

    /// A POST returns everything after the blank line as the body.
    #[test]
    fn test_parse_http_request_post() {
        let payload = r#"{"allow":"s3:GetObject"}"#;
        let raw = format!(
            "POST /v1/credentials HTTP/1.1\r\nContent-Type: application/json\r\n\r\n{}",
            payload
        );
        assert_eq!(
            parse_http_request(&raw),
            Some(("POST", "/v1/credentials", payload))
        );
    }

    /// `json_ok` produces a 200 response with a JSON content type.
    #[test]
    fn test_json_ok() {
        let response = json_ok(r#"{"status":"ok"}"#);
        assert!(response.contains("200 OK"));
        assert!(response.contains("application/json"));
    }

    /// `json_error` maps 400 to its reason phrase and carries the message.
    #[test]
    fn test_json_error() {
        let response = json_error(400, "bad request");
        assert!(response.contains("400 Bad Request"));
        assert!(response.contains("bad request"));
    }

    /// Fields missing from the input take their declared defaults.
    #[test]
    fn test_credential_request_defaults() {
        let CredentialRequest {
            provider,
            ttl,
            role_arn,
            ..
        } = serde_json::from_str(r#"{"allow":"s3:GetObject"}"#).unwrap();

        assert_eq!(provider, "aws");
        assert_eq!(ttl, "15m");
        assert!(role_arn.is_none());
    }

    /// `ServerConfig::default()` listens on loopback port 8080 with no API key.
    #[test]
    fn test_server_config_default() {
        let ServerConfig {
            port,
            bind,
            api_key,
        } = ServerConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(bind, "127.0.0.1");
        assert!(api_key.is_none());
    }
}