Skip to main content

tryaudex_core/
forward.rs

1use serde::{Deserialize, Serialize};
2
3use crate::audit::AuditEntry;
4use crate::error::{AvError, Result};
5
6/// Audit forwarding configuration in `[audit]` config section.
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ForwardConfig {
9    /// Forwarding destinations. Each entry is a URL:
10    /// - `s3://bucket-name/prefix/` — S3 with daily partitioned keys
11    /// - `cloudwatch://log-group-name` — CloudWatch Logs
12    /// - `https://...` — Generic webhook (POST JSON array)
13    #[serde(default)]
14    pub destinations: Vec<String>,
15    /// Batch size before flushing (default: 10)
16    pub batch_size: Option<usize>,
17    /// Whether to forward synchronously on each event (default: false, batched)
18    pub sync: Option<bool>,
19    /// Bearer token to include in the `Authorization` header for webhook requests.
20    /// If set, all `https://` destinations will receive `Authorization: Bearer <token>`.
21    /// The env var `AUDEX_WEBHOOK_BEARER_TOKEN` takes precedence if set, avoiding
22    /// plaintext secrets in the config file.
23    pub webhook_bearer_token: Option<String>,
24}
25
26impl ForwardConfig {
27    /// Return the effective webhook bearer token, preferring
28    /// `AUDEX_WEBHOOK_BEARER_TOKEN` env var over the config-file value.
29    pub fn resolve_bearer_token(&self) -> Option<String> {
30        std::env::var("AUDEX_WEBHOOK_BEARER_TOKEN")
31            .ok()
32            .or_else(|| self.webhook_bearer_token.clone())
33    }
34}
35
36/// In-memory buffer for batched forwarding.
37pub struct ForwardBuffer {
38    config: ForwardConfig,
39    buffer: Vec<AuditEntry>,
40}
41
42impl ForwardBuffer {
43    pub fn new(config: ForwardConfig) -> Self {
44        Self {
45            config,
46            buffer: Vec::new(),
47        }
48    }
49
50    /// Add an entry to the buffer. Flushes if batch size is reached.
51    /// Returns Ok(true) if a flush happened.
52    pub async fn push(&mut self, entry: AuditEntry) -> Result<bool> {
53        if self.config.destinations.is_empty() {
54            return Ok(false);
55        }
56
57        let sync = self.config.sync.unwrap_or(false);
58        if sync {
59            // R6-H14: resolve env override (`AUDEX_WEBHOOK_BEARER_TOKEN`)
60            // instead of always using the config-file value.
61            let bearer = self.config.resolve_bearer_token();
62            forward_entries(&self.config.destinations, bearer.as_deref(), &[entry]).await?;
63            return Ok(true);
64        }
65
66        self.buffer.push(entry);
67        let batch_size = self.config.batch_size.unwrap_or(10);
68        if self.buffer.len() >= batch_size {
69            self.flush().await?;
70            return Ok(true);
71        }
72        Ok(false)
73    }
74
75    /// Flush all buffered entries to all destinations.
76    pub async fn flush(&mut self) -> Result<()> {
77        if self.buffer.is_empty() || self.config.destinations.is_empty() {
78            return Ok(());
79        }
80        let entries = std::mem::take(&mut self.buffer);
81        // R6-H14: resolve env override (`AUDEX_WEBHOOK_BEARER_TOKEN`)
82        // instead of always using the config-file value.
83        let bearer = self.config.resolve_bearer_token();
84        forward_entries(&self.config.destinations, bearer.as_deref(), &entries).await
85    }
86}
87
88/// Forward entries to all configured destinations.
89async fn forward_entries(
90    destinations: &[String],
91    webhook_bearer_token: Option<&str>,
92    entries: &[AuditEntry],
93) -> Result<()> {
94    let mut errors = Vec::new();
95
96    for dest in destinations {
97        let result = if dest.starts_with("s3://") {
98            forward_to_s3(dest, entries).await
99        } else if dest.starts_with("cloudwatch://") {
100            forward_to_cloudwatch(dest, entries).await
101        } else if dest.starts_with("https://") {
102            forward_to_webhook(dest, entries, webhook_bearer_token).await
103        } else if dest.starts_with("http://") {
104            Err(AvError::InvalidPolicy(format!(
105                "Refusing to forward audit data over plaintext HTTP: {}. Use https:// instead.",
106                dest
107            )))
108        } else {
109            Err(AvError::InvalidPolicy(format!(
110                "Unknown audit forwarding destination: {}. Expected s3://, cloudwatch://, or https://",
111                dest
112            )))
113        };
114
115        if let Err(e) = result {
116            errors.push(format!("{}: {}", dest, e));
117        }
118    }
119
120    if errors.is_empty() {
121        Ok(())
122    } else {
123        Err(AvError::InvalidPolicy(format!(
124            "Audit forwarding errors:\n{}",
125            errors.join("\n")
126        )))
127    }
128}
129
130/// Forward audit entries to S3 as a JSONL file with daily partitioned keys.
131/// Format: s3://bucket/prefix/YYYY/MM/DD/audex-{timestamp}.jsonl
132///
133/// Uses the AWS CLI for proper SigV4-signed requests, inheriting ambient
134/// credentials from the environment (same approach as cleanup.rs).
135async fn forward_to_s3(dest: &str, entries: &[AuditEntry]) -> Result<()> {
136    let path = dest.strip_prefix("s3://").unwrap();
137    let (bucket, prefix) = path.split_once('/').unwrap_or((path, ""));
138
139    let now = chrono::Utc::now();
140    let key = format!(
141        "{}{}/audex-{}.jsonl",
142        if prefix.is_empty() { "" } else { prefix },
143        now.format("%Y/%m/%d"),
144        now.format("%Y%m%dT%H%M%SZ"),
145    );
146
147    let body = entries
148        .iter()
149        .filter_map(|e| serde_json::to_string(e).ok())
150        .collect::<Vec<_>>()
151        .join("\n");
152
153    let s3_uri = format!("s3://{}/{}", bucket, key);
154
155    // Write body to a temp file for `aws s3 cp` to read.
156    // Use UUID to avoid predictable names (symlink attack risk).
157    let tmp = std::env::temp_dir().join(format!("audex-fwd-{}.jsonl", uuid::Uuid::new_v4()));
158    std::fs::write(&tmp, &body)?;
159
160    let tmp_path = tmp.clone();
161    let output = tokio::task::spawn_blocking(move || {
162        std::process::Command::new("aws")
163            .args([
164                "s3",
165                "cp",
166                tmp_path.to_str().unwrap_or("-"),
167                &s3_uri,
168                "--content-type",
169                "application/x-ndjson",
170            ])
171            .output()
172    })
173    .await
174    .map_err(|e| AvError::InvalidPolicy(format!("S3 upload task panicked: {}", e)))?
175    .map_err(|e| AvError::InvalidPolicy(format!("Failed to run aws s3 cp: {}", e)))?;
176
177    // Clean up temp file regardless of outcome.
178    let _ = std::fs::remove_file(&tmp);
179
180    if !output.status.success() {
181        let stderr = String::from_utf8_lossy(&output.stderr);
182        return Err(AvError::InvalidPolicy(format!(
183            "S3 upload failed: {}",
184            &stderr[..stderr.len().min(200)]
185        )));
186    }
187
188    Ok(())
189}
190
191/// Forward audit entries to CloudWatch Logs.
192/// Format: cloudwatch://log-group-name
193///
194/// Uses the AWS CLI for proper SigV4-signed requests.
195async fn forward_to_cloudwatch(dest: &str, entries: &[AuditEntry]) -> Result<()> {
196    let log_group = dest.strip_prefix("cloudwatch://").unwrap();
197    let log_stream = format!("audex/{}", chrono::Utc::now().format("%Y/%m/%d"));
198
199    // Ensure the log stream exists (create is idempotent).
200    let lg = log_group.to_string();
201    let ls = log_stream.clone();
202    let _ = tokio::task::spawn_blocking(move || {
203        std::process::Command::new("aws")
204            .args([
205                "logs",
206                "create-log-stream",
207                "--log-group-name",
208                &lg,
209                "--log-stream-name",
210                &ls,
211            ])
212            .output()
213    })
214    .await;
215
216    let events: Vec<serde_json::Value> = entries
217        .iter()
218        .map(|e| {
219            serde_json::json!({
220                "timestamp": e.timestamp.timestamp_millis(),
221                "message": serde_json::to_string(e).unwrap_or_default()
222            })
223        })
224        .collect();
225
226    let events_json = serde_json::to_string(&events)
227        .map_err(|e| AvError::InvalidPolicy(format!("Failed to serialize events: {}", e)))?;
228
229    let lg2 = log_group.to_string();
230    let ls2 = log_stream;
231    let output = tokio::task::spawn_blocking(move || {
232        std::process::Command::new("aws")
233            .args([
234                "logs",
235                "put-log-events",
236                "--log-group-name",
237                &lg2,
238                "--log-stream-name",
239                &ls2,
240                "--log-events",
241                &events_json,
242            ])
243            .output()
244    })
245    .await
246    .map_err(|e| AvError::InvalidPolicy(format!("CloudWatch task panicked: {}", e)))?
247    .map_err(|e| AvError::InvalidPolicy(format!("Failed to run aws logs: {}", e)))?;
248
249    if !output.status.success() {
250        let stderr = String::from_utf8_lossy(&output.stderr);
251        return Err(AvError::InvalidPolicy(format!(
252            "CloudWatch Logs forward failed: {}",
253            &stderr[..stderr.len().min(200)]
254        )));
255    }
256
257    Ok(())
258}
259
260/// Forward audit entries to a generic webhook endpoint (SIEM, Splunk, etc).
261/// POSTs a JSON array of audit entries. If `bearer_token` is provided it is
262/// sent as `Authorization: Bearer <token>`.
263async fn forward_to_webhook(
264    url: &str,
265    entries: &[AuditEntry],
266    bearer_token: Option<&str>,
267) -> Result<()> {
268    let client = reqwest::Client::new();
269    let mut req = client
270        .post(url)
271        .header("Content-Type", "application/json")
272        .header("User-Agent", "audex-audit-forwarder/0.1")
273        .json(&entries);
274    if let Some(token) = bearer_token {
275        req = req.bearer_auth(token);
276    }
277    let resp = req
278        .send()
279        .await
280        .map_err(|e| AvError::InvalidPolicy(format!("Webhook forward failed: {}", e)))?;
281
282    if !resp.status().is_success() {
283        let status = resp.status();
284        let body = resp.text().await.unwrap_or_default();
285        return Err(AvError::InvalidPolicy(format!(
286            "Webhook returned {}: {}",
287            status,
288            &body[..body.len().min(200)]
289        )));
290    }
291
292    Ok(())
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::{AuditEntry, AuditEvent};
    use chrono::Utc;

    fn sample_entry() -> AuditEntry {
        AuditEntry {
            timestamp: Utc::now(),
            session_id: "test-session-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123456789012:role/TestRole".to_string(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: None,
            },
        }
    }

    #[test]
    fn test_forward_config_default() {
        let config = ForwardConfig::default();
        assert!(config.destinations.is_empty());
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
        assert!(config.webhook_bearer_token.is_none());
    }

    #[test]
    fn test_forward_config_deserialize() {
        let toml_str = r#"
destinations = ["s3://my-audit-bucket/audex/", "https://siem.example.com/ingest"]
batch_size = 5
sync = false
"#;
        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.destinations.len(), 2);
        assert_eq!(config.batch_size, Some(5));
        assert_eq!(config.sync, Some(false));
    }

    #[test]
    fn test_forward_config_deserialize_bearer_token() {
        let toml_str = r#"
destinations = ["https://siem.example.com/ingest"]
webhook_bearer_token = "secret-token"
"#;
        let config: ForwardConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.webhook_bearer_token.as_deref(), Some("secret-token"));
        // Optional knobs that are absent deserialize to None.
        assert!(config.batch_size.is_none());
        assert!(config.sync.is_none());
    }

    #[tokio::test]
    async fn test_buffer_no_destinations() {
        let config = ForwardConfig::default();
        let mut buffer = ForwardBuffer::new(config);
        let flushed = buffer.push(sample_entry()).await.unwrap();
        assert!(!flushed);
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_buffer_batching() {
        // `http://` destinations are rejected locally before any network I/O,
        // so this exercises the batching logic without contacting a server.
        let config = ForwardConfig {
            destinations: vec!["http://localhost/ingest".to_string()],
            batch_size: Some(3),
            sync: Some(false),
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);

        // Below the batch size: entries are buffered and no flush happens.
        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert!(!buffer.push(sample_entry()).await.unwrap());
        assert_eq!(buffer.buffer.len(), 2);

        // Reaching the batch size triggers a flush, which fails for http://.
        assert!(buffer.push(sample_entry()).await.is_err());
    }

    #[tokio::test]
    async fn test_sync_rejects_plaintext_http() {
        let config = ForwardConfig {
            destinations: vec!["http://siem.example.com/ingest".to_string()],
            sync: Some(true),
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);
        assert!(buffer.push(sample_entry()).await.is_err());
        // Sync mode never buffers.
        assert!(buffer.buffer.is_empty());
    }

    #[tokio::test]
    async fn test_sync_rejects_unknown_scheme() {
        let config = ForwardConfig {
            destinations: vec!["ftp://example.com/audit".to_string()],
            sync: Some(true),
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);
        assert!(buffer.push(sample_entry()).await.is_err());
    }

    #[tokio::test]
    async fn test_flush_empty_buffer_is_noop() {
        // Even with an unusable destination, flushing nothing must succeed.
        let config = ForwardConfig {
            destinations: vec!["http://localhost/ingest".to_string()],
            ..Default::default()
        };
        let mut buffer = ForwardBuffer::new(config);
        assert!(buffer.flush().await.is_ok());
    }

    #[test]
    fn test_entry_serialization() {
        let entry = sample_entry();
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("test-session-123"));
        assert!(json.contains("session_created"));
        // Roundtrip
        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.session_id, "test-session-123");
    }
}