Skip to main content

tryaudex_core/
forward.rs

1use serde::{Deserialize, Serialize};
2
3use crate::audit::AuditEntry;
4use crate::error::{AvError, Result};
5
6/// Audit forwarding configuration in `[audit]` config section.
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ForwardConfig {
9    /// Forwarding destinations. Each entry is a URL:
10    /// - `s3://bucket-name/prefix/` — S3 with daily partitioned keys
11    /// - `cloudwatch://log-group-name` — CloudWatch Logs
12    /// - `https://...` — Generic webhook (POST JSON array)
13    #[serde(default)]
14    pub destinations: Vec<String>,
15    /// Batch size before flushing (default: 10)
16    pub batch_size: Option<usize>,
17    /// Whether to forward synchronously on each event (default: false, batched)
18    pub sync: Option<bool>,
19}
20
21/// In-memory buffer for batched forwarding.
22pub struct ForwardBuffer {
23    config: ForwardConfig,
24    buffer: Vec<AuditEntry>,
25}
26
27impl ForwardBuffer {
28    pub fn new(config: ForwardConfig) -> Self {
29        Self {
30            config,
31            buffer: Vec::new(),
32        }
33    }
34
35    /// Add an entry to the buffer. Flushes if batch size is reached.
36    /// Returns Ok(true) if a flush happened.
37    pub async fn push(&mut self, entry: AuditEntry) -> Result<bool> {
38        if self.config.destinations.is_empty() {
39            return Ok(false);
40        }
41
42        let sync = self.config.sync.unwrap_or(false);
43        if sync {
44            forward_entries(&self.config.destinations, &[entry]).await?;
45            return Ok(true);
46        }
47
48        self.buffer.push(entry);
49        let batch_size = self.config.batch_size.unwrap_or(10);
50        if self.buffer.len() >= batch_size {
51            self.flush().await?;
52            return Ok(true);
53        }
54        Ok(false)
55    }
56
57    /// Flush all buffered entries to all destinations.
58    pub async fn flush(&mut self) -> Result<()> {
59        if self.buffer.is_empty() || self.config.destinations.is_empty() {
60            return Ok(());
61        }
62        let entries = std::mem::take(&mut self.buffer);
63        forward_entries(&self.config.destinations, &entries).await
64    }
65}
66
67/// Forward entries to all configured destinations.
68async fn forward_entries(destinations: &[String], entries: &[AuditEntry]) -> Result<()> {
69    let mut errors = Vec::new();
70
71    for dest in destinations {
72        let result = if dest.starts_with("s3://") {
73            forward_to_s3(dest, entries).await
74        } else if dest.starts_with("cloudwatch://") {
75            forward_to_cloudwatch(dest, entries).await
76        } else if dest.starts_with("https://") || dest.starts_with("http://") {
77            forward_to_webhook(dest, entries).await
78        } else {
79            Err(AvError::InvalidPolicy(format!(
80                "Unknown audit forwarding destination: {}. Expected s3://, cloudwatch://, or https://",
81                dest
82            )))
83        };
84
85        if let Err(e) = result {
86            errors.push(format!("{}: {}", dest, e));
87        }
88    }
89
90    if errors.is_empty() {
91        Ok(())
92    } else {
93        Err(AvError::InvalidPolicy(format!(
94            "Audit forwarding errors:\n{}",
95            errors.join("\n")
96        )))
97    }
98}
99
100/// Forward audit entries to S3 as a JSONL file with daily partitioned keys.
101/// Format: s3://bucket/prefix/YYYY/MM/DD/audex-{timestamp}.jsonl
102async fn forward_to_s3(dest: &str, entries: &[AuditEntry]) -> Result<()> {
103    let path = dest.strip_prefix("s3://").unwrap();
104    let (bucket, prefix) = path.split_once('/').unwrap_or((path, ""));
105
106    let now = chrono::Utc::now();
107    let key = format!(
108        "{}{}/audex-{}.jsonl",
109        if prefix.is_empty() { "" } else { prefix },
110        now.format("%Y/%m/%d"),
111        now.format("%Y%m%dT%H%M%SZ"),
112    );
113
114    let body = entries
115        .iter()
116        .filter_map(|e| serde_json::to_string(e).ok())
117        .collect::<Vec<_>>()
118        .join("\n");
119
120    // Use AWS SDK PutObject via presigned URL or direct SDK
121    // For now, use the AWS CLI approach via the aws_sdk_s3 isn't a dep,
122    // so we use reqwest with SigV4-style — but simplest is just shelling out.
123    // Instead, we use the S3 PutObject REST API with credentials from env.
124    let region = std::env::var("AWS_REGION")
125        .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
126        .unwrap_or_else(|_| "us-east-1".to_string());
127
128    let endpoint = format!(
129        "https://{}.s3.{}.amazonaws.com/{}",
130        bucket, region, key
131    );
132
133    let client = reqwest::Client::new();
134    let resp = client
135        .put(&endpoint)
136        .header("Content-Type", "application/x-ndjson")
137        .body(body)
138        .send()
139        .await
140        .map_err(|e| AvError::InvalidPolicy(format!("S3 upload failed: {}", e)))?;
141
142    if !resp.status().is_success() {
143        let status = resp.status();
144        let body = resp.text().await.unwrap_or_default();
145        return Err(AvError::InvalidPolicy(format!(
146            "S3 upload returned {}: {}",
147            status,
148            &body[..body.len().min(200)]
149        )));
150    }
151
152    Ok(())
153}
154
155/// Forward audit entries to CloudWatch Logs.
156/// Format: cloudwatch://log-group-name
157async fn forward_to_cloudwatch(dest: &str, entries: &[AuditEntry]) -> Result<()> {
158    let log_group = dest.strip_prefix("cloudwatch://").unwrap();
159    let log_stream = format!("audex/{}", chrono::Utc::now().format("%Y/%m/%d"));
160
161    // CloudWatch Logs PutLogEvents via REST API
162    let region = std::env::var("AWS_REGION")
163        .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
164        .unwrap_or_else(|_| "us-east-1".to_string());
165
166    let events: Vec<serde_json::Value> = entries
167        .iter()
168        .map(|e| {
169            serde_json::json!({
170                "timestamp": e.timestamp.timestamp_millis(),
171                "message": serde_json::to_string(e).unwrap_or_default()
172            })
173        })
174        .collect();
175
176    let payload = serde_json::json!({
177        "logGroupName": log_group,
178        "logStreamName": log_stream,
179        "logEvents": events
180    });
181
182    let endpoint = format!("https://logs.{}.amazonaws.com", region);
183
184    let client = reqwest::Client::new();
185    let resp = client
186        .post(&endpoint)
187        .header("Content-Type", "application/x-amz-json-1.1")
188        .header("X-Amz-Target", "Logs_20140328.PutLogEvents")
189        .json(&payload)
190        .send()
191        .await
192        .map_err(|e| AvError::InvalidPolicy(format!("CloudWatch Logs forward failed: {}", e)))?;
193
194    if !resp.status().is_success() {
195        let status = resp.status();
196        let body = resp.text().await.unwrap_or_default();
197        return Err(AvError::InvalidPolicy(format!(
198            "CloudWatch Logs returned {}: {}",
199            status,
200            &body[..body.len().min(200)]
201        )));
202    }
203
204    Ok(())
205}
206
207/// Forward audit entries to a generic webhook endpoint (SIEM, Splunk, etc).
208/// POSTs a JSON array of audit entries.
209async fn forward_to_webhook(url: &str, entries: &[AuditEntry]) -> Result<()> {
210    let client = reqwest::Client::new();
211    let resp = client
212        .post(url)
213        .header("Content-Type", "application/json")
214        .header("User-Agent", "audex-audit-forwarder/0.1")
215        .json(&entries)
216        .send()
217        .await
218        .map_err(|e| AvError::InvalidPolicy(format!("Webhook forward failed: {}", e)))?;
219
220    if !resp.status().is_success() {
221        let status = resp.status();
222        let body = resp.text().await.unwrap_or_default();
223        return Err(AvError::InvalidPolicy(format!(
224            "Webhook returned {}: {}",
225            status,
226            &body[..body.len().min(200)]
227        )));
228    }
229
230    Ok(())
231}
232
233#[cfg(test)]
234mod tests {
235    use super::*;
236    use crate::audit::{AuditEvent, AuditEntry};
237    use chrono::Utc;
238
239    fn sample_entry() -> AuditEntry {
240        AuditEntry {
241            timestamp: Utc::now(),
242            session_id: "test-session-123".to_string(),
243            provider: "aws".to_string(),
244            event: AuditEvent::SessionCreated {
245                role_arn: "arn:aws:iam::123456789012:role/TestRole".to_string(),
246                ttl_seconds: 900,
247                budget: None,
248                allowed_actions: vec!["s3:GetObject".to_string()],
249                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
250                agent_id: None,
251            },
252        }
253    }
254
255    #[test]
256    fn test_forward_config_default() {
257        let config = ForwardConfig::default();
258        assert!(config.destinations.is_empty());
259        assert!(config.batch_size.is_none());
260        assert!(config.sync.is_none());
261    }
262
263    #[test]
264    fn test_forward_config_deserialize() {
265        let toml_str = r#"
266destinations = ["s3://my-audit-bucket/audex/", "https://siem.example.com/ingest"]
267batch_size = 5
268sync = false
269"#;
270        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
271        assert_eq!(config.destinations.len(), 2);
272        assert_eq!(config.batch_size, Some(5));
273        assert_eq!(config.sync, Some(false));
274    }
275
276    #[tokio::test]
277    async fn test_buffer_no_destinations() {
278        let config = ForwardConfig::default();
279        let mut buffer = ForwardBuffer::new(config);
280        let flushed = buffer.push(sample_entry()).await.unwrap();
281        assert!(!flushed);
282        assert!(buffer.buffer.is_empty());
283    }
284
285    #[tokio::test]
286    async fn test_buffer_batching() {
287        let config = ForwardConfig {
288            // Use a webhook that won't actually be called since we test buffer logic
289            destinations: vec![],
290            batch_size: Some(3),
291            sync: Some(false),
292        };
293        let mut buffer = ForwardBuffer::new(config);
294
295        // With no destinations, push returns false
296        buffer.push(sample_entry()).await.unwrap();
297        buffer.push(sample_entry()).await.unwrap();
298        assert_eq!(buffer.buffer.len(), 0); // no destinations = no buffering
299    }
300
301    #[test]
302    fn test_entry_serialization() {
303        let entry = sample_entry();
304        let json = serde_json::to_string(&entry).unwrap();
305        assert!(json.contains("test-session-123"));
306        assert!(json.contains("session_created"));
307        // Roundtrip
308        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
309        assert_eq!(parsed.session_id, "test-session-123");
310    }
311}