Skip to main content

vtcode_core/safety/
hitl.rs

1use hashbrown::HashSet;
2use std::time::SystemTime;
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7use crate::tools::registry::RiskLevel;
8
9/// Decision rendered by the human-in-the-loop gate.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11pub enum OversightDecision {
12    Allow,
13    Deny,
14    RequireApproval,
15}
16
17/// Formal HITL policy configuration with explicit whitelist/blacklist
18#[derive(Debug, Clone, Default)]
19pub struct HitlPolicy {
20    /// Tools that are automatically approved without user confirmation
21    pub auto_approve_tools: HashSet<String>,
22    /// Tools that are always denied regardless of other settings
23    pub always_deny_tools: HashSet<String>,
24    /// Whether to require approval for tools not in either list
25    pub default_require_approval: bool,
26}
27
28impl HitlPolicy {
29    pub fn new() -> Self {
30        Self {
31            auto_approve_tools: HashSet::new(),
32            always_deny_tools: HashSet::new(),
33            default_require_approval: true,
34        }
35    }
36
37    /// Create a permissive policy (auto-approve common safe tools)
38    pub fn permissive() -> Self {
39        let mut policy = Self::new();
40        policy.default_require_approval = false;
41        // Common safe read-only tools
42        for tool in &[
43            crate::config::constants::tools::LIST_FILES,
44            crate::config::constants::tools::READ_FILE,
45            "grep_search",
46            "view_file",
47        ] {
48            policy.auto_approve_tools.insert(tool.to_string());
49        }
50        policy
51    }
52
53    /// Create a strict policy (require approval for everything)
54    pub fn strict() -> Self {
55        Self {
56            auto_approve_tools: HashSet::new(),
57            always_deny_tools: HashSet::new(),
58            default_require_approval: true,
59        }
60    }
61
62    /// Check policy for a specific tool
63    pub fn check_tool(&self, tool_name: &str) -> OversightDecision {
64        if self.always_deny_tools.contains(tool_name) {
65            OversightDecision::Deny
66        } else if self.auto_approve_tools.contains(tool_name) {
67            OversightDecision::Allow
68        } else if self.default_require_approval {
69            OversightDecision::RequireApproval
70        } else {
71            OversightDecision::Allow
72        }
73    }
74
75    /// Add a tool to the auto-approve whitelist
76    pub fn whitelist_tool(&mut self, tool_name: impl Into<String>) {
77        let name = tool_name.into();
78        self.always_deny_tools.remove(&name);
79        self.auto_approve_tools.insert(name);
80    }
81
82    /// Add a tool to the always-deny blacklist
83    pub fn blacklist_tool(&mut self, tool_name: impl Into<String>) {
84        let name = tool_name.into();
85        self.auto_approve_tools.remove(&name);
86        self.always_deny_tools.insert(name);
87    }
88}
89
90/// Configurable gate that maps risk into oversight requirements.
91#[derive(Debug, Clone)]
92pub struct HitlGate {
93    pub require_explainability: bool,
94    pub emergency_override_enabled: bool,
95    pub policy: HitlPolicy,
96}
97
98impl HitlGate {
99    pub fn new(require_explainability: bool, emergency_override_enabled: bool) -> Self {
100        Self {
101            require_explainability,
102            emergency_override_enabled,
103            policy: HitlPolicy::new(),
104        }
105    }
106
107    /// Create gate with a specific policy
108    pub fn with_policy(policy: HitlPolicy) -> Self {
109        Self {
110            require_explainability: true,
111            emergency_override_enabled: false,
112            policy,
113        }
114    }
115
116    /// Decide based on risk level (original behavior)
117    pub fn decide(&self, risk: RiskLevel) -> OversightDecision {
118        match risk {
119            RiskLevel::High => OversightDecision::RequireApproval,
120            RiskLevel::Medium => {
121                if self.require_explainability {
122                    OversightDecision::RequireApproval
123                } else {
124                    OversightDecision::Allow
125                }
126            }
127            RiskLevel::Low => OversightDecision::Allow,
128            RiskLevel::Critical => OversightDecision::RequireApproval,
129        }
130    }
131
132    /// Decide for a specific tool, considering both risk and policy
133    pub fn decide_for_tool(&self, tool_name: &str, risk: RiskLevel) -> OversightDecision {
134        // Policy blacklist takes absolute precedence
135        let policy_decision = self.policy.check_tool(tool_name);
136        if policy_decision == OversightDecision::Deny {
137            return OversightDecision::Deny;
138        }
139
140        // Policy whitelist overrides risk-based decision
141        if policy_decision == OversightDecision::Allow {
142            return OversightDecision::Allow;
143        }
144
145        // Fall back to risk-based decision
146        self.decide(risk)
147    }
148
149    pub fn override_decision(
150        &self,
151        decision: OversightDecision,
152        reason: impl Into<String>,
153        trail: &mut HitlAuditTrail,
154    ) -> Result<()> {
155        if !self.emergency_override_enabled {
156            anyhow::bail!("emergency override disabled");
157        }
158
159        trail.record(decision, reason);
160        Ok(())
161    }
162}
163
164/// Audit record for HITL decisions.
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct HitlEvent {
167    pub decision: OversightDecision,
168    pub reason: String,
169    pub tool_name: Option<String>,
170    pub at: SystemTime,
171}
172
173#[derive(Debug, Default, Clone)]
174pub struct HitlAuditTrail {
175    events: Vec<HitlEvent>,
176}
177
178impl HitlAuditTrail {
179    pub fn record(&mut self, decision: OversightDecision, reason: impl Into<String>) {
180        self.events.push(HitlEvent {
181            decision,
182            reason: reason.into(),
183            tool_name: None,
184            at: SystemTime::now(),
185        });
186    }
187
188    /// Record a decision for a specific tool
189    pub fn record_tool_decision(
190        &mut self,
191        tool_name: impl Into<String>,
192        decision: OversightDecision,
193        reason: impl Into<String>,
194    ) {
195        self.events.push(HitlEvent {
196            decision,
197            reason: reason.into(),
198            tool_name: Some(tool_name.into()),
199            at: SystemTime::now(),
200        });
201    }
202
203    pub fn events(&self) -> &[HitlEvent] {
204        &self.events
205    }
206
207    /// Get events for a specific tool
208    pub fn events_for_tool(&self, tool_name: &str) -> Vec<&HitlEvent> {
209        self.events
210            .iter()
211            .filter(|e| e.tool_name.as_deref() == Some(tool_name))
212            .collect()
213    }
214
215    /// Export audit trail as JSON for security logging
216    pub fn to_json(&self) -> Result<String> {
217        serde_json::to_string_pretty(&self.events).map_err(Into::into)
218    }
219
220    /// Get statistics about decisions
221    pub fn statistics(&self) -> HitlStatistics {
222        let mut stats = HitlStatistics::default();
223        for event in &self.events {
224            match event.decision {
225                OversightDecision::Allow => stats.allowed += 1,
226                OversightDecision::Deny => stats.denied += 1,
227                OversightDecision::RequireApproval => stats.required_approval += 1,
228            }
229        }
230        stats.total = self.events.len();
231        stats
232    }
233
234    /// Clear old events (for memory management)
235    pub fn prune_old_events(&mut self, max_count: usize) -> usize {
236        if self.events.len() <= max_count {
237            return 0;
238        }
239        let excess = self.events.len() - max_count;
240        self.events.drain(0..excess);
241        excess
242    }
243}
244
/// Per-decision counts over a set of HITL events.
///
/// Derives `PartialEq`/`Eq` so that two snapshots can be compared directly,
/// for example with `assert_eq!` in tests.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct HitlStatistics {
    /// Total number of events counted.
    pub total: usize,
    /// Events whose decision was `Allow`.
    pub allowed: usize,
    /// Events whose decision was `Deny`.
    pub denied: usize,
    /// Events whose decision was `RequireApproval`.
    pub required_approval: usize,
}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn approval_required_for_high_risk() {
        let gate = HitlGate::new(true, true);
        assert_eq!(
            gate.decide(RiskLevel::High),
            OversightDecision::RequireApproval
        );
    }

    #[test]
    fn policy_whitelist_auto_approves() {
        let mut policy = HitlPolicy::new();
        policy.whitelist_tool("read_file");

        assert_eq!(policy.check_tool("read_file"), OversightDecision::Allow);
        assert_eq!(
            policy.check_tool("write_file"),
            OversightDecision::RequireApproval
        );
    }

    #[test]
    fn policy_blacklist_denies() {
        let mut policy = HitlPolicy::new();
        policy.blacklist_tool("dangerous_tool");

        assert_eq!(policy.check_tool("dangerous_tool"), OversightDecision::Deny);
    }

    #[test]
    fn blacklist_overrides_whitelist() {
        let mut policy = HitlPolicy::new();
        policy.whitelist_tool("tool");
        policy.blacklist_tool("tool");

        // Blacklist should win
        assert_eq!(policy.check_tool("tool"), OversightDecision::Deny);
    }

    #[test]
    fn decide_for_tool_respects_policy() {
        // Start from a policy that requires approval by default. Under
        // `permissive()` every non-blacklisted tool is allowed, so a whitelist
        // hit could not be told apart from the default.
        let mut policy = HitlPolicy::new();
        policy.whitelist_tool("read_file");
        policy.blacklist_tool("dangerous_tool");
        let gate = HitlGate::with_policy(policy);

        // Whitelisted tool is allowed even with high risk
        assert_eq!(
            gate.decide_for_tool("read_file", RiskLevel::High),
            OversightDecision::Allow
        );
        // Unlisted tool falls back to the risk-based decision
        assert_eq!(
            gate.decide_for_tool("write_file", RiskLevel::High),
            OversightDecision::RequireApproval
        );
        // Blacklisted tool is denied even with low risk
        assert_eq!(
            gate.decide_for_tool("dangerous_tool", RiskLevel::Low),
            OversightDecision::Deny
        );
    }

    #[test]
    fn permissive_policy_allows_whitelisted_tool() {
        let gate = HitlGate::with_policy(HitlPolicy::permissive());

        // Use the shared constant so the test does not depend on its literal value.
        assert_eq!(
            gate.decide_for_tool(crate::config::constants::tools::READ_FILE, RiskLevel::High),
            OversightDecision::Allow
        );
    }

    #[test]
    fn audit_trail_statistics() {
        let mut trail = HitlAuditTrail::default();
        trail.record(OversightDecision::Allow, "allowed 1");
        trail.record(OversightDecision::Allow, "allowed 2");
        trail.record(OversightDecision::Deny, "denied 1");

        let stats = trail.statistics();
        assert_eq!(stats.total, 3);
        assert_eq!(stats.allowed, 2);
        assert_eq!(stats.denied, 1);
        assert_eq!(stats.required_approval, 0);
    }

    #[test]
    fn audit_trail_json_export() {
        let mut trail = HitlAuditTrail::default();
        trail.record_tool_decision("test_tool", OversightDecision::Allow, "test reason");

        let json = trail.to_json().unwrap();
        assert!(json.contains("test_tool"));
        assert!(json.contains("Allow"));
    }

    #[test]
    fn prune_keeps_newest_events() {
        let mut trail = HitlAuditTrail::default();
        trail.record(OversightDecision::Allow, "first");
        trail.record(OversightDecision::Deny, "second");
        trail.record(OversightDecision::RequireApproval, "third");

        // Nothing to prune while the trail fits within the limit.
        assert_eq!(trail.prune_old_events(3), 0);
        assert_eq!(trail.events().len(), 3);

        // The oldest events are dropped first.
        assert_eq!(trail.prune_old_events(1), 2);
        assert_eq!(trail.events().len(), 1);
        assert_eq!(trail.events()[0].reason, "third");
    }
}