Skip to main content

zeph_tools/
audit.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::path::Path;
5
6use crate::config::AuditConfig;
7
/// Writes serialized [`AuditEntry`] records to a configured destination.
///
/// Build one with [`AuditLogger::from_config`] and record entries with [`AuditLogger::log`].
#[derive(Debug)]
pub struct AuditLogger {
    /// Sink that receives every serialized entry.
    destination: AuditDestination,
}
12
/// Sink selected for audit output.
#[derive(Debug)]
enum AuditDestination {
    /// Entries are emitted as `tracing` info events with target `"audit"`.
    Stdout,
    /// Entries are appended to this file, one JSON document per line.
    /// The mutex is held for the duration of each write and flush.
    File(tokio::sync::Mutex<tokio::fs::File>),
}
18
/// One audit record describing a single tool invocation (or a non-executor event such as a
/// policy check).
///
/// Serialized to JSON by [`AuditLogger::log`]. Optional fields are omitted when `None` and
/// boolean flags are omitted when `false`, so a record only carries the keys that apply to it.
#[derive(serde::Serialize)]
// The booleans below are independent flags, each serialized under its own key.
#[allow(clippy::struct_excessive_bools)]
pub struct AuditEntry {
    /// Time the entry was recorded, as a string.
    /// NOTE(review): presumably the Unix-seconds value from `chrono_now` — confirm against callers.
    pub timestamp: String,
    /// Identifier of the tool that was invoked (for example `"shell"`).
    pub tool: String,
    /// Textual form of the invocation being audited (for example the shell command line).
    pub command: String,
    /// Outcome of the invocation; see [`AuditResult`].
    pub result: AuditResult,
    /// Elapsed time of the invocation — milliseconds, going by the field name.
    pub duration_ms: u64,
    /// Fine-grained error category label from the taxonomy. `None` for successful executions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_category: Option<String>,
    /// High-level error domain for recovery dispatch. `None` for successful executions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_domain: Option<String>,
    /// Invocation phase in which the error occurred per arXiv:2601.16280 taxonomy.
    /// `None` for successful executions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_phase: Option<String>,
    /// Provenance of the tool result. `None` for non-executor audit entries (e.g. policy checks).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub claim_source: Option<crate::executor::ClaimSource>,
    /// MCP server ID for tool calls routed through `McpToolExecutor`. `None` for native tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mcp_server_id: Option<String>,
    /// Tool output was flagged by regex injection detection. Omitted from the JSON when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub injection_flagged: bool,
    /// Tool output was flagged as anomalous by the embedding guard.
    /// Raw cosine distance is NOT stored (prevents threshold reverse-engineering).
    /// Omitted from the JSON when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub embedding_anomalous: bool,
    /// Tool result crossed the MCP-to-ACP trust boundary (MCP tool result served to an ACP client).
    /// Omitted from the JSON when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub cross_boundary_mcp_to_acp: bool,
    /// Decision recorded by the adversarial policy agent before execution.
    ///
    /// Values: `"allow"`, `"deny:<reason>"`, `"error:<message>"`.
    /// `None` when adversarial policy is disabled or not applicable.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub adversarial_policy_decision: Option<String>,
    /// Process exit code for shell tool executions. `None` for non-shell tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exit_code: Option<i32>,
    /// Whether tool output was truncated before storage. Omitted from the JSON when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub truncated: bool,
}
66
/// Outcome of an audited invocation.
///
/// Serialized as an internally tagged object: the lowercase variant name appears under the
/// `"type"` key and any variant fields sit alongside it, e.g.
/// `{"type":"blocked","reason":"..."}`.
#[derive(serde::Serialize)]
#[serde(tag = "type")]
pub enum AuditResult {
    /// The invocation succeeded.
    #[serde(rename = "success")]
    Success,
    /// The invocation was refused; `reason` carries the explanation.
    #[serde(rename = "blocked")]
    Blocked { reason: String },
    /// The invocation failed; `message` carries the error text.
    #[serde(rename = "error")]
    Error { message: String },
    /// The invocation timed out.
    #[serde(rename = "timeout")]
    Timeout,
    /// A rollback was recorded.
    /// NOTE(review): `restored` and `deleted` look like counts of items restored and deleted —
    /// confirm against the code that constructs this variant.
    #[serde(rename = "rollback")]
    Rollback { restored: usize, deleted: usize },
}
81
82impl AuditLogger {
83    /// Create a new `AuditLogger` from config.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if a file destination cannot be opened.
88    pub async fn from_config(config: &AuditConfig) -> Result<Self, std::io::Error> {
89        let destination = if config.destination == "stdout" {
90            AuditDestination::Stdout
91        } else {
92            let file = tokio::fs::OpenOptions::new()
93                .create(true)
94                .append(true)
95                .open(Path::new(&config.destination))
96                .await?;
97            AuditDestination::File(tokio::sync::Mutex::new(file))
98        };
99
100        Ok(Self { destination })
101    }
102
103    pub async fn log(&self, entry: &AuditEntry) {
104        let json = match serde_json::to_string(entry) {
105            Ok(j) => j,
106            Err(err) => {
107                tracing::error!("audit entry serialization failed: {err}");
108                return;
109            }
110        };
111
112        match &self.destination {
113            AuditDestination::Stdout => {
114                tracing::info!(target: "audit", "{json}");
115            }
116            AuditDestination::File(file) => {
117                use tokio::io::AsyncWriteExt;
118                let mut f = file.lock().await;
119                let line = format!("{json}\n");
120                if let Err(e) = f.write_all(line.as_bytes()).await {
121                    tracing::error!("failed to write audit log: {e}");
122                } else if let Err(e) = f.flush().await {
123                    tracing::error!("failed to flush audit log: {e}");
124                }
125            }
126        }
127    }
128}
129
130/// Log a per-tool risk summary at startup when `audit.tool_risk_summary = true`.
131///
132/// Each entry records tool name, privilege level (static mapping by tool id), and the
133/// expected input sanitization method. This is a design-time inventory label —
134/// NOT a runtime guarantee that sanitization is functioning correctly.
135pub fn log_tool_risk_summary(tool_ids: &[&str]) {
136    // Static privilege mapping: tool id prefix → (privilege level, expected sanitization).
137    // "high" = can execute arbitrary OS commands; "medium" = network/filesystem access;
138    // "low" = schema-validated parameters only.
139    fn classify(id: &str) -> (&'static str, &'static str) {
140        if id.starts_with("shell") || id == "bash" || id == "exec" {
141            ("high", "env_blocklist + command_blocklist")
142        } else if id.starts_with("web_scrape") || id == "fetch" || id.starts_with("scrape") {
143            ("medium", "validate_url + SSRF + domain_policy")
144        } else if id.starts_with("file_write")
145            || id.starts_with("file_read")
146            || id.starts_with("file")
147        {
148            ("medium", "path_sandbox")
149        } else {
150            ("low", "schema_only")
151        }
152    }
153
154    for &id in tool_ids {
155        let (privilege, sanitization) = classify(id);
156        tracing::info!(
157            tool = id,
158            privilege_level = privilege,
159            expected_sanitization = sanitization,
160            "tool risk summary"
161        );
162    }
163}
164
/// Return the current wall-clock time as whole seconds since the Unix epoch, in decimal.
///
/// Despite the name, no calendar formatting is applied: the result is a plain integer string
/// such as `"1234567890"`. If the system clock reads earlier than the epoch, `"0"` is returned.
#[must_use]
pub fn chrono_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let epoch_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    epoch_secs.to_string()
}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline entry shared by the tests: a successful `shell` call with every optional field
    /// unset and every flag `false`.
    ///
    /// Tests override only the fields they exercise via struct update syntax, so adding a field
    /// to `AuditEntry` means touching this one literal instead of every test.
    fn base_entry() -> AuditEntry {
        AuditEntry {
            timestamp: "0".into(),
            tool: "shell".into(),
            command: "echo".into(),
            result: AuditResult::Success,
            duration_ms: 1,
            error_category: None,
            error_domain: None,
            error_phase: None,
            claim_source: None,
            mcp_server_id: None,
            injection_flagged: false,
            embedding_anomalous: false,
            cross_boundary_mcp_to_acp: false,
            adversarial_policy_decision: None,
            exit_code: None,
            truncated: false,
        }
    }

    #[test]
    fn audit_entry_serialization() {
        let entry = AuditEntry {
            timestamp: "1234567890".into(),
            command: "echo hello".into(),
            duration_ms: 42,
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"success\""));
        assert!(json.contains("\"tool\":\"shell\""));
        assert!(json.contains("\"duration_ms\":42"));
    }

    #[test]
    fn audit_result_blocked_serialization() {
        let entry = AuditEntry {
            command: "sudo rm".into(),
            result: AuditResult::Blocked {
                reason: "blocked command: sudo".into(),
            },
            duration_ms: 0,
            error_category: Some("policy_blocked".to_owned()),
            error_domain: Some("action".to_owned()),
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"blocked\""));
        assert!(json.contains("\"reason\""));
    }

    #[test]
    fn audit_result_error_serialization() {
        let entry = AuditEntry {
            command: "bad".into(),
            result: AuditResult::Error {
                message: "exec failed".into(),
            },
            duration_ms: 0,
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"error\""));
    }

    #[test]
    fn audit_result_timeout_serialization() {
        let entry = AuditEntry {
            command: "sleep 999".into(),
            result: AuditResult::Timeout,
            duration_ms: 30000,
            error_category: Some("timeout".to_owned()),
            error_domain: Some("system".to_owned()),
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"timeout\""));
    }

    #[tokio::test]
    async fn audit_logger_stdout() {
        let config = AuditConfig {
            enabled: true,
            destination: "stdout".into(),
            ..Default::default()
        };
        let logger = AuditLogger::from_config(&config).await.unwrap();
        let entry = AuditEntry {
            command: "echo test".into(),
            ..base_entry()
        };
        // Smoke test: logging to the tracing sink must not panic.
        logger.log(&entry).await;
    }

    #[tokio::test]
    async fn audit_logger_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("audit.log");
        let config = AuditConfig {
            enabled: true,
            destination: path.display().to_string(),
            ..Default::default()
        };
        let logger = AuditLogger::from_config(&config).await.unwrap();
        let entry = AuditEntry {
            command: "echo test".into(),
            ..base_entry()
        };
        logger.log(&entry).await;

        let content = tokio::fs::read_to_string(&path).await.unwrap();
        assert!(content.contains("\"tool\":\"shell\""));
    }

    // Despite the name, this covers the open failure in `from_config` (the parent directory
    // does not exist), not a write failure inside `log`.
    #[tokio::test]
    async fn audit_logger_file_write_error_logged() {
        let config = AuditConfig {
            enabled: true,
            destination: "/nonexistent/dir/audit.log".into(),
            ..Default::default()
        };
        let result = AuditLogger::from_config(&config).await;
        assert!(result.is_err());
    }

    #[test]
    fn claim_source_serde_roundtrip() {
        use crate::executor::ClaimSource;
        let cases = [
            (ClaimSource::Shell, "\"shell\""),
            (ClaimSource::FileSystem, "\"file_system\""),
            (ClaimSource::WebScrape, "\"web_scrape\""),
            (ClaimSource::Mcp, "\"mcp\""),
            (ClaimSource::A2a, "\"a2a\""),
            (ClaimSource::CodeSearch, "\"code_search\""),
            (ClaimSource::Diagnostics, "\"diagnostics\""),
            (ClaimSource::Memory, "\"memory\""),
        ];
        for (variant, expected_json) in cases {
            let serialized = serde_json::to_string(&variant).unwrap();
            assert_eq!(serialized, expected_json, "serialize {variant:?}");
            let deserialized: ClaimSource = serde_json::from_str(&serialized).unwrap();
            assert_eq!(deserialized, variant, "deserialize {variant:?}");
        }
    }

    #[test]
    fn audit_entry_claim_source_none_omitted() {
        let entry = base_entry();
        let json = serde_json::to_string(&entry).unwrap();
        assert!(
            !json.contains("claim_source"),
            "claim_source must be omitted when None: {json}"
        );
    }

    #[test]
    fn audit_entry_claim_source_some_present() {
        use crate::executor::ClaimSource;
        let entry = AuditEntry {
            claim_source: Some(ClaimSource::Shell),
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(
            json.contains("\"claim_source\":\"shell\""),
            "expected claim_source=shell in JSON: {json}"
        );
    }

    #[tokio::test]
    async fn audit_logger_multiple_entries() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("audit.log");
        let config = AuditConfig {
            enabled: true,
            destination: path.display().to_string(),
            ..Default::default()
        };
        let logger = AuditLogger::from_config(&config).await.unwrap();

        for i in 0..5 {
            let entry = AuditEntry {
                timestamp: i.to_string(),
                command: format!("cmd{i}"),
                duration_ms: i,
                ..base_entry()
            };
            logger.log(&entry).await;
        }

        // One JSON document per line, one line per logged entry.
        let content = tokio::fs::read_to_string(&path).await.unwrap();
        assert_eq!(content.lines().count(), 5);
    }

    #[test]
    fn audit_entry_exit_code_serialized() {
        let entry = AuditEntry {
            command: "echo hi".into(),
            duration_ms: 5,
            exit_code: Some(0),
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(
            json.contains("\"exit_code\":0"),
            "exit_code must be serialized: {json}"
        );
    }

    #[test]
    fn audit_entry_exit_code_none_omitted() {
        let entry = AuditEntry {
            tool: "file".into(),
            command: "read /tmp/x".into(),
            ..base_entry()
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(
            !json.contains("exit_code"),
            "exit_code None must be omitted: {json}"
        );
    }

    #[test]
    fn log_tool_risk_summary_does_not_panic() {
        log_tool_risk_summary(&[
            "shell",
            "bash",
            "exec",
            "web_scrape",
            "fetch",
            "scrape_page",
            "file_write",
            "file_read",
            "file_delete",
            "memory_search",
            "unknown_tool",
        ]);
    }

    #[test]
    fn log_tool_risk_summary_empty_input_does_not_panic() {
        log_tool_risk_summary(&[]);
    }
}