Skip to main content

tryaudex_core/
forward.rs

1use serde::{Deserialize, Serialize};
2
3use crate::audit::AuditEntry;
4use crate::error::{AvError, Result};
5
6/// Audit forwarding configuration in `[audit]` config section.
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ForwardConfig {
9    /// Forwarding destinations. Each entry is a URL:
10    /// - `s3://bucket-name/prefix/` — S3 with daily partitioned keys
11    /// - `cloudwatch://log-group-name` — CloudWatch Logs
12    /// - `https://...` — Generic webhook (POST JSON array)
13    #[serde(default)]
14    pub destinations: Vec<String>,
15    /// Batch size before flushing (default: 10)
16    pub batch_size: Option<usize>,
17    /// Whether to forward synchronously on each event (default: false, batched)
18    pub sync: Option<bool>,
19}
20
21/// In-memory buffer for batched forwarding.
22pub struct ForwardBuffer {
23    config: ForwardConfig,
24    buffer: Vec<AuditEntry>,
25}
26
27impl ForwardBuffer {
28    pub fn new(config: ForwardConfig) -> Self {
29        Self {
30            config,
31            buffer: Vec::new(),
32        }
33    }
34
35    /// Add an entry to the buffer. Flushes if batch size is reached.
36    /// Returns Ok(true) if a flush happened.
37    pub async fn push(&mut self, entry: AuditEntry) -> Result<bool> {
38        if self.config.destinations.is_empty() {
39            return Ok(false);
40        }
41
42        let sync = self.config.sync.unwrap_or(false);
43        if sync {
44            forward_entries(&self.config.destinations, &[entry]).await?;
45            return Ok(true);
46        }
47
48        self.buffer.push(entry);
49        let batch_size = self.config.batch_size.unwrap_or(10);
50        if self.buffer.len() >= batch_size {
51            self.flush().await?;
52            return Ok(true);
53        }
54        Ok(false)
55    }
56
57    /// Flush all buffered entries to all destinations.
58    pub async fn flush(&mut self) -> Result<()> {
59        if self.buffer.is_empty() || self.config.destinations.is_empty() {
60            return Ok(());
61        }
62        let entries = std::mem::take(&mut self.buffer);
63        forward_entries(&self.config.destinations, &entries).await
64    }
65}
66
67/// Forward entries to all configured destinations.
68async fn forward_entries(destinations: &[String], entries: &[AuditEntry]) -> Result<()> {
69    let mut errors = Vec::new();
70
71    for dest in destinations {
72        let result = if dest.starts_with("s3://") {
73            forward_to_s3(dest, entries).await
74        } else if dest.starts_with("cloudwatch://") {
75            forward_to_cloudwatch(dest, entries).await
76        } else if dest.starts_with("https://") || dest.starts_with("http://") {
77            forward_to_webhook(dest, entries).await
78        } else {
79            Err(AvError::InvalidPolicy(format!(
80                "Unknown audit forwarding destination: {}. Expected s3://, cloudwatch://, or https://",
81                dest
82            )))
83        };
84
85        if let Err(e) = result {
86            errors.push(format!("{}: {}", dest, e));
87        }
88    }
89
90    if errors.is_empty() {
91        Ok(())
92    } else {
93        Err(AvError::InvalidPolicy(format!(
94            "Audit forwarding errors:\n{}",
95            errors.join("\n")
96        )))
97    }
98}
99
100/// Forward audit entries to S3 as a JSONL file with daily partitioned keys.
101/// Format: s3://bucket/prefix/YYYY/MM/DD/audex-{timestamp}.jsonl
102async fn forward_to_s3(dest: &str, entries: &[AuditEntry]) -> Result<()> {
103    let path = dest.strip_prefix("s3://").unwrap();
104    let (bucket, prefix) = path.split_once('/').unwrap_or((path, ""));
105
106    let now = chrono::Utc::now();
107    let key = format!(
108        "{}{}/audex-{}.jsonl",
109        if prefix.is_empty() { "" } else { prefix },
110        now.format("%Y/%m/%d"),
111        now.format("%Y%m%dT%H%M%SZ"),
112    );
113
114    let body = entries
115        .iter()
116        .filter_map(|e| serde_json::to_string(e).ok())
117        .collect::<Vec<_>>()
118        .join("\n");
119
120    // Use AWS SDK PutObject via presigned URL or direct SDK
121    // For now, use the AWS CLI approach via the aws_sdk_s3 isn't a dep,
122    // so we use reqwest with SigV4-style — but simplest is just shelling out.
123    // Instead, we use the S3 PutObject REST API with credentials from env.
124    let region = std::env::var("AWS_REGION")
125        .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
126        .unwrap_or_else(|_| "us-east-1".to_string());
127
128    let endpoint = format!("https://{}.s3.{}.amazonaws.com/{}", bucket, region, key);
129
130    let client = reqwest::Client::new();
131    let resp = client
132        .put(&endpoint)
133        .header("Content-Type", "application/x-ndjson")
134        .body(body)
135        .send()
136        .await
137        .map_err(|e| AvError::InvalidPolicy(format!("S3 upload failed: {}", e)))?;
138
139    if !resp.status().is_success() {
140        let status = resp.status();
141        let body = resp.text().await.unwrap_or_default();
142        return Err(AvError::InvalidPolicy(format!(
143            "S3 upload returned {}: {}",
144            status,
145            &body[..body.len().min(200)]
146        )));
147    }
148
149    Ok(())
150}
151
152/// Forward audit entries to CloudWatch Logs.
153/// Format: cloudwatch://log-group-name
154async fn forward_to_cloudwatch(dest: &str, entries: &[AuditEntry]) -> Result<()> {
155    let log_group = dest.strip_prefix("cloudwatch://").unwrap();
156    let log_stream = format!("audex/{}", chrono::Utc::now().format("%Y/%m/%d"));
157
158    // CloudWatch Logs PutLogEvents via REST API
159    let region = std::env::var("AWS_REGION")
160        .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
161        .unwrap_or_else(|_| "us-east-1".to_string());
162
163    let events: Vec<serde_json::Value> = entries
164        .iter()
165        .map(|e| {
166            serde_json::json!({
167                "timestamp": e.timestamp.timestamp_millis(),
168                "message": serde_json::to_string(e).unwrap_or_default()
169            })
170        })
171        .collect();
172
173    let payload = serde_json::json!({
174        "logGroupName": log_group,
175        "logStreamName": log_stream,
176        "logEvents": events
177    });
178
179    let endpoint = format!("https://logs.{}.amazonaws.com", region);
180
181    let client = reqwest::Client::new();
182    let resp = client
183        .post(&endpoint)
184        .header("Content-Type", "application/x-amz-json-1.1")
185        .header("X-Amz-Target", "Logs_20140328.PutLogEvents")
186        .json(&payload)
187        .send()
188        .await
189        .map_err(|e| AvError::InvalidPolicy(format!("CloudWatch Logs forward failed: {}", e)))?;
190
191    if !resp.status().is_success() {
192        let status = resp.status();
193        let body = resp.text().await.unwrap_or_default();
194        return Err(AvError::InvalidPolicy(format!(
195            "CloudWatch Logs returned {}: {}",
196            status,
197            &body[..body.len().min(200)]
198        )));
199    }
200
201    Ok(())
202}
203
204/// Forward audit entries to a generic webhook endpoint (SIEM, Splunk, etc).
205/// POSTs a JSON array of audit entries.
206async fn forward_to_webhook(url: &str, entries: &[AuditEntry]) -> Result<()> {
207    let client = reqwest::Client::new();
208    let resp = client
209        .post(url)
210        .header("Content-Type", "application/json")
211        .header("User-Agent", "audex-audit-forwarder/0.1")
212        .json(&entries)
213        .send()
214        .await
215        .map_err(|e| AvError::InvalidPolicy(format!("Webhook forward failed: {}", e)))?;
216
217    if !resp.status().is_success() {
218        let status = resp.status();
219        let body = resp.text().await.unwrap_or_default();
220        return Err(AvError::InvalidPolicy(format!(
221            "Webhook returned {}: {}",
222            status,
223            &body[..body.len().min(200)]
224        )));
225    }
226
227    Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::{AuditEntry, AuditEvent};
    use chrono::Utc;

    /// Builds a representative `SessionCreated` audit entry for the tests.
    fn sample_entry() -> AuditEntry {
        AuditEntry {
            timestamp: Utc::now(),
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123456789012:role/TestRole".to_string(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: None,
            },
        }
    }

    #[test]
    fn test_forward_config_default() {
        let config = ForwardConfig::default();
        assert!(config.destinations.is_empty());
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
    }

    #[test]
    fn test_forward_config_deserialize() {
        let toml_str = r#"
destinations = ["s3://my-audit-bucket/audex/", "https://siem.example.com/ingest"]
batch_size = 5
sync = false
"#;
        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.destinations.len(), 2);
        assert_eq!(config.batch_size, Some(5));
        assert_eq!(config.sync, Some(false));
    }

    #[tokio::test]
    async fn test_buffer_no_destinations() {
        let config = ForwardConfig::default();
        let mut buffer = ForwardBuffer::new(config);
        let flushed = buffer.push(sample_entry()).await.unwrap();
        assert!(!flushed);
        // Without destinations the entry is discarded, not buffered.
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_buffer_batching() {
        let config = ForwardConfig {
            // A destination is required for entries to be buffered at all.
            // It is never contacted because the test stays below `batch_size`.
            destinations: vec!["https://siem.example.com/ingest".to_string()],
            batch_size: Some(3),
            sync: Some(false),
        };
        let mut buffer = ForwardBuffer::new(config);

        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert!(!buffer.push(sample_entry()).await.unwrap());
        // Two entries are held back, waiting for the third to trigger a flush.
        assert_eq!(buffer.buffer.len(), 2);
    }

    #[tokio::test]
    async fn test_flush_empty_buffer_is_noop() {
        let config = ForwardConfig {
            destinations: vec!["https://siem.example.com/ingest".to_string()],
            batch_size: Some(3),
            sync: Some(false),
        };
        let mut buffer = ForwardBuffer::new(config);

        // Nothing is buffered, so flush returns before touching the network.
        buffer.flush().await.unwrap();
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_destination_scheme_is_rejected() {
        // An unrecognised scheme is rejected locally, without any request.
        let destinations = vec!["ftp://example.com/audit".to_string()];
        let result = forward_entries(&destinations, &[sample_entry()]).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_entry_serialization() {
        let entry = sample_entry();
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("test-session-123"));
        assert!(json.contains("session_created"));
        // Roundtrip
        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.session_id, "test-session-123");
    }
}