1use serde::{Deserialize, Serialize};
2
3use crate::audit::AuditEntry;
4use crate::error::{AvError, Result};
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ForwardConfig {
9 #[serde(default)]
14 pub destinations: Vec<String>,
15 pub batch_size: Option<usize>,
17 pub sync: Option<bool>,
19}
20
21pub struct ForwardBuffer {
23 config: ForwardConfig,
24 buffer: Vec<AuditEntry>,
25}
26
27impl ForwardBuffer {
28 pub fn new(config: ForwardConfig) -> Self {
29 Self {
30 config,
31 buffer: Vec::new(),
32 }
33 }
34
35 pub async fn push(&mut self, entry: AuditEntry) -> Result<bool> {
38 if self.config.destinations.is_empty() {
39 return Ok(false);
40 }
41
42 let sync = self.config.sync.unwrap_or(false);
43 if sync {
44 forward_entries(&self.config.destinations, &[entry]).await?;
45 return Ok(true);
46 }
47
48 self.buffer.push(entry);
49 let batch_size = self.config.batch_size.unwrap_or(10);
50 if self.buffer.len() >= batch_size {
51 self.flush().await?;
52 return Ok(true);
53 }
54 Ok(false)
55 }
56
57 pub async fn flush(&mut self) -> Result<()> {
59 if self.buffer.is_empty() || self.config.destinations.is_empty() {
60 return Ok(());
61 }
62 let entries = std::mem::take(&mut self.buffer);
63 forward_entries(&self.config.destinations, &entries).await
64 }
65}
66
67async fn forward_entries(destinations: &[String], entries: &[AuditEntry]) -> Result<()> {
69 let mut errors = Vec::new();
70
71 for dest in destinations {
72 let result = if dest.starts_with("s3://") {
73 forward_to_s3(dest, entries).await
74 } else if dest.starts_with("cloudwatch://") {
75 forward_to_cloudwatch(dest, entries).await
76 } else if dest.starts_with("https://") || dest.starts_with("http://") {
77 forward_to_webhook(dest, entries).await
78 } else {
79 Err(AvError::InvalidPolicy(format!(
80 "Unknown audit forwarding destination: {}. Expected s3://, cloudwatch://, or https://",
81 dest
82 )))
83 };
84
85 if let Err(e) = result {
86 errors.push(format!("{}: {}", dest, e));
87 }
88 }
89
90 if errors.is_empty() {
91 Ok(())
92 } else {
93 Err(AvError::InvalidPolicy(format!(
94 "Audit forwarding errors:\n{}",
95 errors.join("\n")
96 )))
97 }
98}
99
100async fn forward_to_s3(dest: &str, entries: &[AuditEntry]) -> Result<()> {
103 let path = dest.strip_prefix("s3://").unwrap();
104 let (bucket, prefix) = path.split_once('/').unwrap_or((path, ""));
105
106 let now = chrono::Utc::now();
107 let key = format!(
108 "{}{}/audex-{}.jsonl",
109 if prefix.is_empty() { "" } else { prefix },
110 now.format("%Y/%m/%d"),
111 now.format("%Y%m%dT%H%M%SZ"),
112 );
113
114 let body = entries
115 .iter()
116 .filter_map(|e| serde_json::to_string(e).ok())
117 .collect::<Vec<_>>()
118 .join("\n");
119
120 let region = std::env::var("AWS_REGION")
125 .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
126 .unwrap_or_else(|_| "us-east-1".to_string());
127
128 let endpoint = format!(
129 "https://{}.s3.{}.amazonaws.com/{}",
130 bucket, region, key
131 );
132
133 let client = reqwest::Client::new();
134 let resp = client
135 .put(&endpoint)
136 .header("Content-Type", "application/x-ndjson")
137 .body(body)
138 .send()
139 .await
140 .map_err(|e| AvError::InvalidPolicy(format!("S3 upload failed: {}", e)))?;
141
142 if !resp.status().is_success() {
143 let status = resp.status();
144 let body = resp.text().await.unwrap_or_default();
145 return Err(AvError::InvalidPolicy(format!(
146 "S3 upload returned {}: {}",
147 status,
148 &body[..body.len().min(200)]
149 )));
150 }
151
152 Ok(())
153}
154
155async fn forward_to_cloudwatch(dest: &str, entries: &[AuditEntry]) -> Result<()> {
158 let log_group = dest.strip_prefix("cloudwatch://").unwrap();
159 let log_stream = format!("audex/{}", chrono::Utc::now().format("%Y/%m/%d"));
160
161 let region = std::env::var("AWS_REGION")
163 .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
164 .unwrap_or_else(|_| "us-east-1".to_string());
165
166 let events: Vec<serde_json::Value> = entries
167 .iter()
168 .map(|e| {
169 serde_json::json!({
170 "timestamp": e.timestamp.timestamp_millis(),
171 "message": serde_json::to_string(e).unwrap_or_default()
172 })
173 })
174 .collect();
175
176 let payload = serde_json::json!({
177 "logGroupName": log_group,
178 "logStreamName": log_stream,
179 "logEvents": events
180 });
181
182 let endpoint = format!("https://logs.{}.amazonaws.com", region);
183
184 let client = reqwest::Client::new();
185 let resp = client
186 .post(&endpoint)
187 .header("Content-Type", "application/x-amz-json-1.1")
188 .header("X-Amz-Target", "Logs_20140328.PutLogEvents")
189 .json(&payload)
190 .send()
191 .await
192 .map_err(|e| AvError::InvalidPolicy(format!("CloudWatch Logs forward failed: {}", e)))?;
193
194 if !resp.status().is_success() {
195 let status = resp.status();
196 let body = resp.text().await.unwrap_or_default();
197 return Err(AvError::InvalidPolicy(format!(
198 "CloudWatch Logs returned {}: {}",
199 status,
200 &body[..body.len().min(200)]
201 )));
202 }
203
204 Ok(())
205}
206
207async fn forward_to_webhook(url: &str, entries: &[AuditEntry]) -> Result<()> {
210 let client = reqwest::Client::new();
211 let resp = client
212 .post(url)
213 .header("Content-Type", "application/json")
214 .header("User-Agent", "audex-audit-forwarder/0.1")
215 .json(&entries)
216 .send()
217 .await
218 .map_err(|e| AvError::InvalidPolicy(format!("Webhook forward failed: {}", e)))?;
219
220 if !resp.status().is_success() {
221 let status = resp.status();
222 let body = resp.text().await.unwrap_or_default();
223 return Err(AvError::InvalidPolicy(format!(
224 "Webhook returned {}: {}",
225 status,
226 &body[..body.len().min(200)]
227 )));
228 }
229
230 Ok(())
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::{AuditEntry, AuditEvent};
    use chrono::Utc;

    /// Destination with an unsupported scheme. `forward_entries` rejects it
    /// before doing any network I/O, so the send path can be tested offline.
    const UNSUPPORTED_DEST: &str = "ftp://audit.example.invalid/ingest";

    fn sample_entry() -> AuditEntry {
        AuditEntry {
            timestamp: Utc::now(),
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123456789012:role/TestRole".to_string(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: None,
            },
        }
    }

    #[test]
    fn test_forward_config_default() {
        let config = ForwardConfig::default();
        assert!(config.destinations.is_empty());
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
    }

    #[test]
    fn test_forward_config_deserialize() {
        let toml_str = r#"
destinations = ["s3://my-audit-bucket/audex/", "https://siem.example.com/ingest"]
batch_size = 5
sync = false
"#;
        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.destinations.len(), 2);
        assert_eq!(config.batch_size, Some(5));
        assert_eq!(config.sync, Some(false));
    }

    #[tokio::test]
    async fn test_buffer_no_destinations() {
        let config = ForwardConfig::default();
        let mut buffer = ForwardBuffer::new(config);
        let flushed = buffer.push(sample_entry()).await.unwrap();
        assert!(!flushed);
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_buffer_batching() {
        // A destination must be configured, otherwise `push` returns before
        // buffering anything. Only two entries are pushed with a batch size
        // of three, so no flush (and no network request) is triggered.
        let config = ForwardConfig {
            destinations: vec!["https://siem.example.invalid/ingest".to_string()],
            batch_size: Some(3),
            sync: Some(false),
        };
        let mut buffer = ForwardBuffer::new(config);

        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert_eq!(buffer.buffer.len(), 1);
        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert_eq!(buffer.buffer.len(), 2);
    }

    #[tokio::test]
    async fn test_flush_empty_buffer_is_noop() {
        let config = ForwardConfig {
            destinations: vec![UNSUPPORTED_DEST.to_string()],
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);
        // With nothing buffered, the unsupported destination is never contacted.
        buffer.flush().await.unwrap();
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_sync_push_propagates_forwarding_error() {
        let config = ForwardConfig {
            destinations: vec![UNSUPPORTED_DEST.to_string()],
            batch_size: None,
            sync: Some(true),
        };
        let mut buffer = ForwardBuffer::new(config);
        assert!(buffer.push(sample_entry()).await.is_err());
        // Sync mode never buffers, even when forwarding fails.
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_destination_scheme_rejected() {
        let destinations = vec![UNSUPPORTED_DEST.to_string()];
        let result = forward_entries(&destinations, &[sample_entry()]).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_entry_serialization() {
        let entry = sample_entry();
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("test-session-123"));
        assert!(json.contains("session_created"));
        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.session_id, "test-session-123");
    }
}