1use serde::{Deserialize, Serialize};
2
3use crate::audit::AuditEntry;
4use crate::error::{AvError, Result};
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ForwardConfig {
9 #[serde(default)]
14 pub destinations: Vec<String>,
15 pub batch_size: Option<usize>,
17 pub sync: Option<bool>,
19}
20
21pub struct ForwardBuffer {
23 config: ForwardConfig,
24 buffer: Vec<AuditEntry>,
25}
26
27impl ForwardBuffer {
28 pub fn new(config: ForwardConfig) -> Self {
29 Self {
30 config,
31 buffer: Vec::new(),
32 }
33 }
34
35 pub async fn push(&mut self, entry: AuditEntry) -> Result<bool> {
38 if self.config.destinations.is_empty() {
39 return Ok(false);
40 }
41
42 let sync = self.config.sync.unwrap_or(false);
43 if sync {
44 forward_entries(&self.config.destinations, &[entry]).await?;
45 return Ok(true);
46 }
47
48 self.buffer.push(entry);
49 let batch_size = self.config.batch_size.unwrap_or(10);
50 if self.buffer.len() >= batch_size {
51 self.flush().await?;
52 return Ok(true);
53 }
54 Ok(false)
55 }
56
57 pub async fn flush(&mut self) -> Result<()> {
59 if self.buffer.is_empty() || self.config.destinations.is_empty() {
60 return Ok(());
61 }
62 let entries = std::mem::take(&mut self.buffer);
63 forward_entries(&self.config.destinations, &entries).await
64 }
65}
66
67async fn forward_entries(destinations: &[String], entries: &[AuditEntry]) -> Result<()> {
69 let mut errors = Vec::new();
70
71 for dest in destinations {
72 let result = if dest.starts_with("s3://") {
73 forward_to_s3(dest, entries).await
74 } else if dest.starts_with("cloudwatch://") {
75 forward_to_cloudwatch(dest, entries).await
76 } else if dest.starts_with("https://") || dest.starts_with("http://") {
77 forward_to_webhook(dest, entries).await
78 } else {
79 Err(AvError::InvalidPolicy(format!(
80 "Unknown audit forwarding destination: {}. Expected s3://, cloudwatch://, or https://",
81 dest
82 )))
83 };
84
85 if let Err(e) = result {
86 errors.push(format!("{}: {}", dest, e));
87 }
88 }
89
90 if errors.is_empty() {
91 Ok(())
92 } else {
93 Err(AvError::InvalidPolicy(format!(
94 "Audit forwarding errors:\n{}",
95 errors.join("\n")
96 )))
97 }
98}
99
100async fn forward_to_s3(dest: &str, entries: &[AuditEntry]) -> Result<()> {
103 let path = dest.strip_prefix("s3://").unwrap();
104 let (bucket, prefix) = path.split_once('/').unwrap_or((path, ""));
105
106 let now = chrono::Utc::now();
107 let key = format!(
108 "{}{}/audex-{}.jsonl",
109 if prefix.is_empty() { "" } else { prefix },
110 now.format("%Y/%m/%d"),
111 now.format("%Y%m%dT%H%M%SZ"),
112 );
113
114 let body = entries
115 .iter()
116 .filter_map(|e| serde_json::to_string(e).ok())
117 .collect::<Vec<_>>()
118 .join("\n");
119
120 let region = std::env::var("AWS_REGION")
125 .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
126 .unwrap_or_else(|_| "us-east-1".to_string());
127
128 let endpoint = format!("https://{}.s3.{}.amazonaws.com/{}", bucket, region, key);
129
130 let client = reqwest::Client::new();
131 let resp = client
132 .put(&endpoint)
133 .header("Content-Type", "application/x-ndjson")
134 .body(body)
135 .send()
136 .await
137 .map_err(|e| AvError::InvalidPolicy(format!("S3 upload failed: {}", e)))?;
138
139 if !resp.status().is_success() {
140 let status = resp.status();
141 let body = resp.text().await.unwrap_or_default();
142 return Err(AvError::InvalidPolicy(format!(
143 "S3 upload returned {}: {}",
144 status,
145 &body[..body.len().min(200)]
146 )));
147 }
148
149 Ok(())
150}
151
152async fn forward_to_cloudwatch(dest: &str, entries: &[AuditEntry]) -> Result<()> {
155 let log_group = dest.strip_prefix("cloudwatch://").unwrap();
156 let log_stream = format!("audex/{}", chrono::Utc::now().format("%Y/%m/%d"));
157
158 let region = std::env::var("AWS_REGION")
160 .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
161 .unwrap_or_else(|_| "us-east-1".to_string());
162
163 let events: Vec<serde_json::Value> = entries
164 .iter()
165 .map(|e| {
166 serde_json::json!({
167 "timestamp": e.timestamp.timestamp_millis(),
168 "message": serde_json::to_string(e).unwrap_or_default()
169 })
170 })
171 .collect();
172
173 let payload = serde_json::json!({
174 "logGroupName": log_group,
175 "logStreamName": log_stream,
176 "logEvents": events
177 });
178
179 let endpoint = format!("https://logs.{}.amazonaws.com", region);
180
181 let client = reqwest::Client::new();
182 let resp = client
183 .post(&endpoint)
184 .header("Content-Type", "application/x-amz-json-1.1")
185 .header("X-Amz-Target", "Logs_20140328.PutLogEvents")
186 .json(&payload)
187 .send()
188 .await
189 .map_err(|e| AvError::InvalidPolicy(format!("CloudWatch Logs forward failed: {}", e)))?;
190
191 if !resp.status().is_success() {
192 let status = resp.status();
193 let body = resp.text().await.unwrap_or_default();
194 return Err(AvError::InvalidPolicy(format!(
195 "CloudWatch Logs returned {}: {}",
196 status,
197 &body[..body.len().min(200)]
198 )));
199 }
200
201 Ok(())
202}
203
204async fn forward_to_webhook(url: &str, entries: &[AuditEntry]) -> Result<()> {
207 let client = reqwest::Client::new();
208 let resp = client
209 .post(url)
210 .header("Content-Type", "application/json")
211 .header("User-Agent", "audex-audit-forwarder/0.1")
212 .json(&entries)
213 .send()
214 .await
215 .map_err(|e| AvError::InvalidPolicy(format!("Webhook forward failed: {}", e)))?;
216
217 if !resp.status().is_success() {
218 let status = resp.status();
219 let body = resp.text().await.unwrap_or_default();
220 return Err(AvError::InvalidPolicy(format!(
221 "Webhook returned {}: {}",
222 status,
223 &body[..body.len().min(200)]
224 )));
225 }
226
227 Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::{AuditEntry, AuditEvent};
    use chrono::Utc;

    /// Builds a representative `SessionCreated` entry for the tests below.
    fn sample_entry() -> AuditEntry {
        AuditEntry {
            timestamp: Utc::now(),
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123456789012:role/TestRole".to_string(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: None,
            },
        }
    }

    #[test]
    fn test_forward_config_default() {
        let config = ForwardConfig::default();
        assert!(config.destinations.is_empty());
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
    }

    #[test]
    fn test_forward_config_deserialize() {
        let toml_str = r#"
destinations = ["s3://my-audit-bucket/audex/", "https://siem.example.com/ingest"]
batch_size = 5
sync = false
"#;
        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.destinations.len(), 2);
        assert_eq!(config.batch_size, Some(5));
        assert_eq!(config.sync, Some(false));
    }

    #[test]
    fn test_forward_config_deserialize_empty() {
        // Every field is optional in the configuration file.
        let config: ForwardConfig = toml::from_str("").unwrap();
        assert!(config.destinations.is_empty());
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
    }

    #[tokio::test]
    async fn test_buffer_no_destinations() {
        let config = ForwardConfig::default();
        let mut buffer = ForwardBuffer::new(config);
        let flushed = buffer.push(sample_entry()).await.unwrap();
        assert!(!flushed);
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_buffer_batching() {
        // A destination must be configured, otherwise `push` discards entries
        // before they reach the buffer. The batch size is never reached here,
        // so no network request is made.
        let config = ForwardConfig {
            destinations: vec!["https://siem.example.invalid/ingest".to_string()],
            batch_size: Some(3),
            sync: Some(false),
        };
        let mut buffer = ForwardBuffer::new(config);

        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert_eq!(buffer.buffer.len(), 2);
    }

    #[tokio::test]
    async fn test_flush_empty_buffer_is_noop() {
        let config = ForwardConfig {
            destinations: vec!["https://siem.example.invalid/ingest".to_string()],
            ..ForwardConfig::default()
        };
        let mut buffer = ForwardBuffer::new(config);
        buffer.flush().await.unwrap();
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_sync_unknown_destination_errors() {
        // Unknown schemes are rejected before any network I/O happens.
        let config = ForwardConfig {
            destinations: vec!["ftp://audit.example.com/logs".to_string()],
            batch_size: None,
            sync: Some(true),
        };
        let mut buffer = ForwardBuffer::new(config);
        assert!(buffer.push(sample_entry()).await.is_err());
        assert!(buffer.buffer.is_empty());
    }

    #[test]
    fn test_entry_serialization() {
        let entry = sample_entry();
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("test-session-123"));
        assert!(json.contains("session_created"));
        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.session_id, "test-session-123");
    }
}