Skip to main content

tryaudex_core/
server.rs

1use serde::{Deserialize, Serialize};
2
3/// Request to issue scoped credentials via the server API.
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct CredentialRequest {
6    /// IAM actions to allow (e.g. "s3:GetObject,s3:ListBucket")
7    pub allow: Option<String>,
8    /// Named policy profile (e.g. "s3-readonly")
9    pub profile: Option<String>,
10    /// Resource ARN restriction
11    pub resource: Option<String>,
12    /// Cloud provider (aws, gcp, azure)
13    #[serde(default = "default_provider")]
14    pub provider: String,
15    /// TTL for credentials (e.g. "15m", "1h")
16    #[serde(default = "default_ttl")]
17    pub ttl: String,
18    /// IAM role ARN to assume
19    pub role_arn: Option<String>,
20    /// Command that will use these credentials (for audit)
21    #[serde(default)]
22    pub command: Vec<String>,
23    /// Agent identity
24    pub agent_id: Option<String>,
25}
26
/// Serde default for `CredentialRequest::provider`: the AWS provider.
fn default_provider() -> String {
    String::from("aws")
}
30
/// Serde default for `CredentialRequest::ttl`: fifteen minutes.
fn default_ttl() -> String {
    let ttl = "15m";
    ttl.to_owned()
}
34
35/// Response from the credential issuance endpoint.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct CredentialResponse {
38    pub session_id: String,
39    pub provider: String,
40    pub expires_at: String,
41    pub ttl_seconds: u64,
42    pub credentials: CredentialEnvVars,
43}
44
45/// Credential environment variables returned by the server.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CredentialEnvVars {
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub aws_access_key_id: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub aws_secret_access_key: Option<String>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub aws_session_token: Option<String>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub gcp_access_token: Option<String>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub azure_token: Option<String>,
58}
59
60/// Server configuration.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ServerConfig {
63    /// Port to listen on
64    #[serde(default = "default_port")]
65    pub port: u16,
66    /// Bind address
67    #[serde(default = "default_bind")]
68    pub bind: String,
69    /// Optional API key for authentication
70    pub api_key: Option<String>,
71}
72
/// Default listening port for `ServerConfig`.
fn default_port() -> u16 {
    const PORT: u16 = 8080;
    PORT
}
76
/// Default bind address for `ServerConfig`: the IPv4 loopback address, so the
/// server is not reachable from other hosts unless explicitly configured.
fn default_bind() -> String {
    String::from("127.0.0.1")
}
80
81impl Default for ServerConfig {
82    fn default() -> Self {
83        Self {
84            port: default_port(),
85            bind: default_bind(),
86            api_key: None,
87        }
88    }
89}
90
91/// API error response.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ApiError {
94    pub error: String,
95    pub code: u16,
96}
97
/// Render a complete HTTP/1.1 response with `Content-Type` and
/// `Content-Length` headers followed by `body`.
///
/// CR, LF, and NUL characters are removed from `status_text` and
/// `content_type` so attacker-influenced values cannot inject extra header
/// lines. `Content-Length` is the body's length in bytes.
pub fn http_response(status: u16, status_text: &str, content_type: &str, body: &str) -> String {
    let strip_breaks = |value: &str| -> String {
        value
            .chars()
            .filter(|&c| !matches!(c, '\r' | '\n' | '\0'))
            .collect()
    };
    let reason = strip_breaks(status_text);
    let mime = strip_breaks(content_type);
    let length = body.len();
    format!(
        "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}",
        status, reason, mime, length, body
    )
}
114
115pub fn json_ok(body: &str) -> String {
116    http_response(200, "OK", "application/json", body)
117}
118
119/// Return a 500 error with a generic message. Logs the actual error
120/// server-side to avoid leaking internal details (account IDs, ARNs)
121/// in HTTP responses.
122pub fn internal_error(context: &str, err: &dyn std::fmt::Display) -> String {
123    tracing::error!("{}: {}", context, err);
124    json_error(500, &format!("{}: internal error", context))
125}
126
127pub fn json_error(code: u16, message: &str) -> String {
128    let err = ApiError {
129        error: message.to_string(),
130        code,
131    };
132    let body = serde_json::to_string(&err).unwrap_or_else(|_| {
133        // Escape quotes and backslashes to prevent JSON injection in the fallback.
134        let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
135        format!("{{\"error\":\"{}\",\"code\":{}}}", escaped, code)
136    });
137    let status_text = match code {
138        400 => "Bad Request",
139        401 => "Unauthorized",
140        403 => "Forbidden",
141        404 => "Not Found",
142        405 => "Method Not Allowed",
143        413 => "Payload Too Large",
144        429 => "Too Many Requests",
145        500 => "Internal Server Error",
146        503 => "Service Unavailable",
147        _ => "Error",
148    };
149    http_response(code, status_text, "application/json", &body)
150}
151
/// Parse a minimal HTTP request into `(method, path, body)`.
///
/// Only the request line and the header/body separator are inspected. Headers
/// are skipped and `Content-Length` is not consulted. `body` is everything after
/// the first blank line (`"\r\n\r\n"`), or `""` if there is none.
///
/// Returns `None` if the request line lacks a method or a path.
pub fn parse_http_request(raw: &str) -> Option<(&str, &str, &str)> {
    // Split at the FIRST header terminator only. The body may itself contain
    // "\r\n\r\n" (e.g. CRLF pretty-printed JSON) and must not be truncated.
    let (head, body) = raw.split_once("\r\n\r\n").unwrap_or((raw, ""));
    let request_line = head.split("\r\n").next()?;
    let mut parts = request_line.split_whitespace();
    let method = parts.next()?;
    let path = parts.next()?;
    Some((method, path, body))
}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_http_request_get() {
        let raw = "GET /v1/sessions HTTP/1.1\r\nHost: localhost\r\n\r\n";
        let (method, path, body) = parse_http_request(raw).unwrap();
        assert_eq!(method, "GET");
        assert_eq!(path, "/v1/sessions");
        assert_eq!(body, "");
    }

    #[test]
    fn test_parse_http_request_post() {
        let raw = "POST /v1/credentials HTTP/1.1\r\nContent-Type: application/json\r\n\r\n{\"allow\":\"s3:GetObject\"}";
        let (method, path, body) = parse_http_request(raw).unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/v1/credentials");
        assert_eq!(body, "{\"allow\":\"s3:GetObject\"}");
    }

    #[test]
    fn test_parse_http_request_malformed_request_line() {
        // No method/path pair on the request line -> None.
        assert!(parse_http_request("").is_none());
        assert!(parse_http_request("GET\r\n\r\n").is_none());
    }

    #[test]
    fn test_http_response_strips_crlf_injection() {
        let resp = http_response(200, "OK\r\nSet-Cookie: x=1", "text/plain\r\nX-Evil: 1", "");
        assert!(!resp.contains("\r\nSet-Cookie"));
        assert!(!resp.contains("\r\nX-Evil"));
        assert!(resp.starts_with("HTTP/1.1 200 OKSet-Cookie: x=1\r\n"));
    }

    #[test]
    fn test_http_response_content_length_counts_bytes() {
        // "héllo" is 5 chars but 6 UTF-8 bytes.
        let resp = http_response(200, "OK", "text/plain", "héllo");
        assert!(resp.contains("Content-Length: 6\r\n"));
    }

    #[test]
    fn test_json_ok() {
        let resp = json_ok("{\"status\":\"ok\"}");
        assert!(resp.contains("200 OK"));
        assert!(resp.contains("application/json"));
    }

    #[test]
    fn test_json_error() {
        let resp = json_error(400, "bad request");
        assert!(resp.contains("400 Bad Request"));
        assert!(resp.contains("bad request"));
    }

    #[test]
    fn test_json_error_unknown_code_uses_generic_reason() {
        let resp = json_error(418, "teapot");
        assert!(resp.starts_with("HTTP/1.1 418 Error\r\n"));
    }

    #[test]
    fn test_internal_error_hides_details() {
        let resp = internal_error("issue credentials", &"arn:aws:iam::123456789012:role/Admin");
        assert!(resp.contains("500 Internal Server Error"));
        assert!(resp.contains("issue credentials: internal error"));
        assert!(!resp.contains("123456789012"));
    }

    #[test]
    fn test_credential_request_defaults() {
        let json = r#"{"allow":"s3:GetObject"}"#;
        let req: CredentialRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.provider, "aws");
        assert_eq!(req.ttl, "15m");
        assert!(req.role_arn.is_none());
        assert!(req.command.is_empty());
    }

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.bind, "127.0.0.1");
        assert!(config.api_key.is_none());
    }
}