1use serde::{Deserialize, Serialize};
2
3use crate::audit::AuditEntry;
4use crate::error::{AvError, Result};
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ForwardConfig {
9 #[serde(default)]
14 pub destinations: Vec<String>,
15 pub batch_size: Option<usize>,
17 pub sync: Option<bool>,
19 pub webhook_bearer_token: Option<String>,
24}
25
26impl ForwardConfig {
27 pub fn resolve_bearer_token(&self) -> Option<String> {
30 std::env::var("AUDEX_WEBHOOK_BEARER_TOKEN")
31 .ok()
32 .or_else(|| self.webhook_bearer_token.clone())
33 }
34}
35
36pub struct ForwardBuffer {
38 config: ForwardConfig,
39 buffer: Vec<AuditEntry>,
40}
41
42impl ForwardBuffer {
43 pub fn new(config: ForwardConfig) -> Self {
44 Self {
45 config,
46 buffer: Vec::new(),
47 }
48 }
49
50 pub async fn push(&mut self, entry: AuditEntry) -> Result<bool> {
53 if self.config.destinations.is_empty() {
54 return Ok(false);
55 }
56
57 let sync = self.config.sync.unwrap_or(false);
58 if sync {
59 let bearer = self.config.resolve_bearer_token();
62 forward_entries(&self.config.destinations, bearer.as_deref(), &[entry]).await?;
63 return Ok(true);
64 }
65
66 self.buffer.push(entry);
67 let batch_size = self.config.batch_size.unwrap_or(10);
68 if self.buffer.len() >= batch_size {
69 self.flush().await?;
70 return Ok(true);
71 }
72 Ok(false)
73 }
74
75 pub async fn flush(&mut self) -> Result<()> {
77 if self.buffer.is_empty() || self.config.destinations.is_empty() {
78 return Ok(());
79 }
80 let entries = std::mem::take(&mut self.buffer);
81 let bearer = self.config.resolve_bearer_token();
84 forward_entries(&self.config.destinations, bearer.as_deref(), &entries).await
85 }
86}
87
88async fn forward_entries(
90 destinations: &[String],
91 webhook_bearer_token: Option<&str>,
92 entries: &[AuditEntry],
93) -> Result<()> {
94 let mut errors = Vec::new();
95
96 for dest in destinations {
97 let result = if dest.starts_with("s3://") {
98 forward_to_s3(dest, entries).await
99 } else if dest.starts_with("cloudwatch://") {
100 forward_to_cloudwatch(dest, entries).await
101 } else if dest.starts_with("https://") {
102 forward_to_webhook(dest, entries, webhook_bearer_token).await
103 } else if dest.starts_with("http://") {
104 Err(AvError::InvalidPolicy(format!(
105 "Refusing to forward audit data over plaintext HTTP: {}. Use https:// instead.",
106 dest
107 )))
108 } else {
109 Err(AvError::InvalidPolicy(format!(
110 "Unknown audit forwarding destination: {}. Expected s3://, cloudwatch://, or https://",
111 dest
112 )))
113 };
114
115 if let Err(e) = result {
116 errors.push(format!("{}: {}", dest, e));
117 }
118 }
119
120 if errors.is_empty() {
121 Ok(())
122 } else {
123 Err(AvError::InvalidPolicy(format!(
124 "Audit forwarding errors:\n{}",
125 errors.join("\n")
126 )))
127 }
128}
129
130async fn forward_to_s3(dest: &str, entries: &[AuditEntry]) -> Result<()> {
136 let path = dest.strip_prefix("s3://").unwrap();
137 let (bucket, prefix) = path.split_once('/').unwrap_or((path, ""));
138
139 let now = chrono::Utc::now();
140 let key = format!(
141 "{}{}/audex-{}.jsonl",
142 if prefix.is_empty() { "" } else { prefix },
143 now.format("%Y/%m/%d"),
144 now.format("%Y%m%dT%H%M%SZ"),
145 );
146
147 let body = entries
148 .iter()
149 .filter_map(|e| serde_json::to_string(e).ok())
150 .collect::<Vec<_>>()
151 .join("\n");
152
153 let s3_uri = format!("s3://{}/{}", bucket, key);
154
155 let tmp = std::env::temp_dir().join(format!("audex-fwd-{}.jsonl", uuid::Uuid::new_v4()));
158 std::fs::write(&tmp, &body)?;
159
160 let tmp_path = tmp.clone();
161 let output = tokio::task::spawn_blocking(move || {
162 std::process::Command::new("aws")
163 .args([
164 "s3",
165 "cp",
166 tmp_path.to_str().unwrap_or("-"),
167 &s3_uri,
168 "--content-type",
169 "application/x-ndjson",
170 ])
171 .output()
172 })
173 .await
174 .map_err(|e| AvError::InvalidPolicy(format!("S3 upload task panicked: {}", e)))?
175 .map_err(|e| AvError::InvalidPolicy(format!("Failed to run aws s3 cp: {}", e)))?;
176
177 let _ = std::fs::remove_file(&tmp);
179
180 if !output.status.success() {
181 let stderr = String::from_utf8_lossy(&output.stderr);
182 return Err(AvError::InvalidPolicy(format!(
183 "S3 upload failed: {}",
184 &stderr[..stderr.len().min(200)]
185 )));
186 }
187
188 Ok(())
189}
190
191async fn forward_to_cloudwatch(dest: &str, entries: &[AuditEntry]) -> Result<()> {
196 let log_group = dest.strip_prefix("cloudwatch://").unwrap();
197 let log_stream = format!("audex/{}", chrono::Utc::now().format("%Y/%m/%d"));
198
199 let lg = log_group.to_string();
201 let ls = log_stream.clone();
202 let _ = tokio::task::spawn_blocking(move || {
203 std::process::Command::new("aws")
204 .args([
205 "logs",
206 "create-log-stream",
207 "--log-group-name",
208 &lg,
209 "--log-stream-name",
210 &ls,
211 ])
212 .output()
213 })
214 .await;
215
216 let events: Vec<serde_json::Value> = entries
217 .iter()
218 .map(|e| {
219 serde_json::json!({
220 "timestamp": e.timestamp.timestamp_millis(),
221 "message": serde_json::to_string(e).unwrap_or_default()
222 })
223 })
224 .collect();
225
226 let events_json = serde_json::to_string(&events)
227 .map_err(|e| AvError::InvalidPolicy(format!("Failed to serialize events: {}", e)))?;
228
229 let lg2 = log_group.to_string();
230 let ls2 = log_stream;
231 let output = tokio::task::spawn_blocking(move || {
232 std::process::Command::new("aws")
233 .args([
234 "logs",
235 "put-log-events",
236 "--log-group-name",
237 &lg2,
238 "--log-stream-name",
239 &ls2,
240 "--log-events",
241 &events_json,
242 ])
243 .output()
244 })
245 .await
246 .map_err(|e| AvError::InvalidPolicy(format!("CloudWatch task panicked: {}", e)))?
247 .map_err(|e| AvError::InvalidPolicy(format!("Failed to run aws logs: {}", e)))?;
248
249 if !output.status.success() {
250 let stderr = String::from_utf8_lossy(&output.stderr);
251 return Err(AvError::InvalidPolicy(format!(
252 "CloudWatch Logs forward failed: {}",
253 &stderr[..stderr.len().min(200)]
254 )));
255 }
256
257 Ok(())
258}
259
260async fn forward_to_webhook(
264 url: &str,
265 entries: &[AuditEntry],
266 bearer_token: Option<&str>,
267) -> Result<()> {
268 let client = reqwest::Client::new();
269 let mut req = client
270 .post(url)
271 .header("Content-Type", "application/json")
272 .header("User-Agent", "audex-audit-forwarder/0.1")
273 .json(&entries);
274 if let Some(token) = bearer_token {
275 req = req.bearer_auth(token);
276 }
277 let resp = req
278 .send()
279 .await
280 .map_err(|e| AvError::InvalidPolicy(format!("Webhook forward failed: {}", e)))?;
281
282 if !resp.status().is_success() {
283 let status = resp.status();
284 let body = resp.text().await.unwrap_or_default();
285 return Err(AvError::InvalidPolicy(format!(
286 "Webhook returned {}: {}",
287 status,
288 &body[..body.len().min(200)]
289 )));
290 }
291
292 Ok(())
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::{AuditEntry, AuditEvent};
    use chrono::Utc;

    /// A destination scheme that `forward_entries` rejects locally, without any network access.
    /// Its error lets the tests observe exactly when forwarding is attempted.
    const UNROUTABLE: &str = "unknown://sink";

    /// Builds a representative `SessionCreated` audit entry.
    fn sample_entry() -> AuditEntry {
        AuditEntry {
            timestamp: Utc::now(),
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123456789012:role/TestRole".to_string(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: None,
            },
        }
    }

    #[test]
    fn test_forward_config_default() {
        let config = ForwardConfig::default();
        assert!(config.destinations.is_empty());
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
    }

    #[test]
    fn test_forward_config_deserialize() {
        let toml_str = r#"
destinations = ["s3://my-audit-bucket/audex/", "https://siem.example.com/ingest"]
batch_size = 5
sync = false
"#;
        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.destinations.len(), 2);
        assert_eq!(config.batch_size, Some(5));
        assert_eq!(config.sync, Some(false));
    }

    #[tokio::test]
    async fn test_buffer_no_destinations() {
        let config = ForwardConfig::default();
        let mut buffer = ForwardBuffer::new(config);
        let flushed = buffer.push(sample_entry()).await.unwrap();
        assert!(!flushed);
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_buffer_batching() {
        // Destinations must be non-empty, otherwise `push` discards entries before buffering.
        let config = ForwardConfig {
            destinations: vec![UNROUTABLE.to_string()],
            batch_size: Some(3),
            sync: Some(false),
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);

        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert_eq!(buffer.buffer.len(), 2);

        // The third entry reaches the batch size and triggers a flush, which fails locally.
        assert!(buffer.push(sample_entry()).await.is_err());
    }

    #[tokio::test]
    async fn test_sync_mode_forwards_immediately() {
        let config = ForwardConfig {
            destinations: vec![UNROUTABLE.to_string()],
            sync: Some(true),
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);

        // Forwarding is attempted on the very first push, and nothing is buffered.
        assert!(buffer.push(sample_entry()).await.is_err());
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_flush_empty_buffer_is_noop() {
        let config = ForwardConfig {
            destinations: vec![UNROUTABLE.to_string()],
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);
        buffer.flush().await.unwrap();
    }

    #[tokio::test]
    async fn test_plaintext_http_destination_rejected() {
        let dest = "http://siem.example.com/ingest".to_string();
        let err = forward_entries(&[dest], None, &[sample_entry()])
            .await
            .unwrap_err();
        match err {
            AvError::InvalidPolicy(msg) => {
                assert!(msg.starts_with("Audit forwarding errors:"));
                assert!(msg.contains("http://siem.example.com/ingest"));
            }
            other => panic!("unexpected error: {}", other),
        }
    }

    #[test]
    fn test_entry_serialization() {
        let entry = sample_entry();
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("test-session-123"));
        assert!(json.contains("session_created"));
        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.session_id, "test-session-123");
    }
}