1use hashbrown::HashSet;
2use std::time::SystemTime;
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7use crate::tools::registry::RiskLevel;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11pub enum OversightDecision {
12 Allow,
13 Deny,
14 RequireApproval,
15}
16
/// Per-tool human-in-the-loop policy.
///
/// `always_deny_tools` is consulted before `auto_approve_tools`; a tool in
/// neither set falls back to `default_require_approval`.
#[derive(Debug, Clone)]
pub struct HitlPolicy {
    /// Tools that may run without asking a human.
    pub auto_approve_tools: HashSet<String>,
    /// Tools that are always denied, regardless of the other settings.
    pub always_deny_tools: HashSet<String>,
    /// For tools in neither set: `true` requires approval, `false` allows them.
    pub default_require_approval: bool,
}

impl Default for HitlPolicy {
    /// Returns the same fail-closed policy as `HitlPolicy::new()`: both tool
    /// sets empty and `default_require_approval = true`.
    ///
    /// A derived `Default` would set `default_require_approval` to `false`,
    /// silently allowing every unlisted tool. That is the opposite of what
    /// `new()` and `strict()` produce, and it is unsafe for an oversight gate.
    fn default() -> Self {
        Self {
            auto_approve_tools: HashSet::new(),
            always_deny_tools: HashSet::new(),
            default_require_approval: true,
        }
    }
}
27
28impl HitlPolicy {
29 pub fn new() -> Self {
30 Self {
31 auto_approve_tools: HashSet::new(),
32 always_deny_tools: HashSet::new(),
33 default_require_approval: true,
34 }
35 }
36
37 pub fn permissive() -> Self {
39 let mut policy = Self::new();
40 policy.default_require_approval = false;
41 for tool in &[
43 crate::config::constants::tools::LIST_FILES,
44 crate::config::constants::tools::READ_FILE,
45 "grep_search",
46 "view_file",
47 ] {
48 policy.auto_approve_tools.insert(tool.to_string());
49 }
50 policy
51 }
52
53 pub fn strict() -> Self {
55 Self {
56 auto_approve_tools: HashSet::new(),
57 always_deny_tools: HashSet::new(),
58 default_require_approval: true,
59 }
60 }
61
62 pub fn check_tool(&self, tool_name: &str) -> OversightDecision {
64 if self.always_deny_tools.contains(tool_name) {
65 OversightDecision::Deny
66 } else if self.auto_approve_tools.contains(tool_name) {
67 OversightDecision::Allow
68 } else if self.default_require_approval {
69 OversightDecision::RequireApproval
70 } else {
71 OversightDecision::Allow
72 }
73 }
74
75 pub fn whitelist_tool(&mut self, tool_name: impl Into<String>) {
77 let name = tool_name.into();
78 self.always_deny_tools.remove(&name);
79 self.auto_approve_tools.insert(name);
80 }
81
82 pub fn blacklist_tool(&mut self, tool_name: impl Into<String>) {
84 let name = tool_name.into();
85 self.auto_approve_tools.remove(&name);
86 self.always_deny_tools.insert(name);
87 }
88}
89
90#[derive(Debug, Clone)]
92pub struct HitlGate {
93 pub require_explainability: bool,
94 pub emergency_override_enabled: bool,
95 pub policy: HitlPolicy,
96}
97
98impl HitlGate {
99 pub fn new(require_explainability: bool, emergency_override_enabled: bool) -> Self {
100 Self {
101 require_explainability,
102 emergency_override_enabled,
103 policy: HitlPolicy::new(),
104 }
105 }
106
107 pub fn with_policy(policy: HitlPolicy) -> Self {
109 Self {
110 require_explainability: true,
111 emergency_override_enabled: false,
112 policy,
113 }
114 }
115
116 pub fn decide(&self, risk: RiskLevel) -> OversightDecision {
118 match risk {
119 RiskLevel::High => OversightDecision::RequireApproval,
120 RiskLevel::Medium => {
121 if self.require_explainability {
122 OversightDecision::RequireApproval
123 } else {
124 OversightDecision::Allow
125 }
126 }
127 RiskLevel::Low => OversightDecision::Allow,
128 RiskLevel::Critical => OversightDecision::RequireApproval,
129 }
130 }
131
132 pub fn decide_for_tool(&self, tool_name: &str, risk: RiskLevel) -> OversightDecision {
134 let policy_decision = self.policy.check_tool(tool_name);
136 if policy_decision == OversightDecision::Deny {
137 return OversightDecision::Deny;
138 }
139
140 if policy_decision == OversightDecision::Allow {
142 return OversightDecision::Allow;
143 }
144
145 self.decide(risk)
147 }
148
149 pub fn override_decision(
150 &self,
151 decision: OversightDecision,
152 reason: impl Into<String>,
153 trail: &mut HitlAuditTrail,
154 ) -> Result<()> {
155 if !self.emergency_override_enabled {
156 anyhow::bail!("emergency override disabled");
157 }
158
159 trail.record(decision, reason);
160 Ok(())
161 }
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct HitlEvent {
167 pub decision: OversightDecision,
168 pub reason: String,
169 pub tool_name: Option<String>,
170 pub at: SystemTime,
171}
172
173#[derive(Debug, Default, Clone)]
174pub struct HitlAuditTrail {
175 events: Vec<HitlEvent>,
176}
177
178impl HitlAuditTrail {
179 pub fn record(&mut self, decision: OversightDecision, reason: impl Into<String>) {
180 self.events.push(HitlEvent {
181 decision,
182 reason: reason.into(),
183 tool_name: None,
184 at: SystemTime::now(),
185 });
186 }
187
188 pub fn record_tool_decision(
190 &mut self,
191 tool_name: impl Into<String>,
192 decision: OversightDecision,
193 reason: impl Into<String>,
194 ) {
195 self.events.push(HitlEvent {
196 decision,
197 reason: reason.into(),
198 tool_name: Some(tool_name.into()),
199 at: SystemTime::now(),
200 });
201 }
202
203 pub fn events(&self) -> &[HitlEvent] {
204 &self.events
205 }
206
207 pub fn events_for_tool(&self, tool_name: &str) -> Vec<&HitlEvent> {
209 self.events
210 .iter()
211 .filter(|e| e.tool_name.as_deref() == Some(tool_name))
212 .collect()
213 }
214
215 pub fn to_json(&self) -> Result<String> {
217 serde_json::to_string_pretty(&self.events).map_err(Into::into)
218 }
219
220 pub fn statistics(&self) -> HitlStatistics {
222 let mut stats = HitlStatistics::default();
223 for event in &self.events {
224 match event.decision {
225 OversightDecision::Allow => stats.allowed += 1,
226 OversightDecision::Deny => stats.denied += 1,
227 OversightDecision::RequireApproval => stats.required_approval += 1,
228 }
229 }
230 stats.total = self.events.len();
231 stats
232 }
233
234 pub fn prune_old_events(&mut self, max_count: usize) -> usize {
236 if self.events.len() <= max_count {
237 return 0;
238 }
239 let excess = self.events.len() - max_count;
240 self.events.drain(0..excess);
241 excess
242 }
243}
244
/// Per-decision counts over the events of a `HitlAuditTrail`.
///
/// Derives `PartialEq`/`Eq` so whole snapshots can be compared directly,
/// e.g. in assertions or change detection.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct HitlStatistics {
    /// Total number of events counted.
    pub total: usize,
    /// Events whose decision was `Allow`.
    pub allowed: usize,
    /// Events whose decision was `Deny`.
    pub denied: usize,
    /// Events whose decision was `RequireApproval`.
    pub required_approval: usize,
}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn approval_required_for_high_risk() {
        let gate = HitlGate::new(true, true);
        assert_eq!(
            gate.decide(RiskLevel::High),
            OversightDecision::RequireApproval
        );
    }

    #[test]
    fn medium_risk_depends_on_explainability() {
        assert_eq!(
            HitlGate::new(true, false).decide(RiskLevel::Medium),
            OversightDecision::RequireApproval
        );
        assert_eq!(
            HitlGate::new(false, false).decide(RiskLevel::Medium),
            OversightDecision::Allow
        );
    }

    #[test]
    fn low_risk_allowed_and_critical_requires_approval() {
        let gate = HitlGate::new(true, false);
        assert_eq!(gate.decide(RiskLevel::Low), OversightDecision::Allow);
        assert_eq!(
            gate.decide(RiskLevel::Critical),
            OversightDecision::RequireApproval
        );
    }

    #[test]
    fn policy_whitelist_auto_approves() {
        let mut policy = HitlPolicy::new();
        policy.whitelist_tool("read_file");

        assert_eq!(policy.check_tool("read_file"), OversightDecision::Allow);
        assert_eq!(
            policy.check_tool("write_file"),
            OversightDecision::RequireApproval
        );
    }

    #[test]
    fn policy_blacklist_denies() {
        let mut policy = HitlPolicy::new();
        policy.blacklist_tool("dangerous_tool");

        assert_eq!(policy.check_tool("dangerous_tool"), OversightDecision::Deny);
    }

    #[test]
    fn blacklist_overrides_whitelist() {
        let mut policy = HitlPolicy::new();
        policy.whitelist_tool("tool");
        policy.blacklist_tool("tool");

        assert_eq!(policy.check_tool("tool"), OversightDecision::Deny);
    }

    #[test]
    fn decide_for_tool_respects_policy() {
        let policy = HitlPolicy::permissive();
        let gate = HitlGate::with_policy(policy);

        assert_eq!(
            gate.decide_for_tool("read_file", RiskLevel::High),
            OversightDecision::Allow
        );
    }

    #[test]
    fn decide_for_tool_denies_blacklisted_and_defers_unlisted() {
        let mut policy = HitlPolicy::new();
        policy.blacklist_tool("rm");
        let gate = HitlGate::with_policy(policy);

        // A policy Deny wins even when the risk rules would allow the action.
        assert_eq!(
            gate.decide_for_tool("rm", RiskLevel::Low),
            OversightDecision::Deny
        );
        // An unlisted tool under a fail-closed policy falls back to the risk rules.
        assert_eq!(
            gate.decide_for_tool("other", RiskLevel::High),
            OversightDecision::RequireApproval
        );
    }

    #[test]
    fn override_rejected_when_disabled() {
        let gate = HitlGate::new(true, false);
        let mut trail = HitlAuditTrail::default();

        assert!(gate
            .override_decision(OversightDecision::Allow, "urgent", &mut trail)
            .is_err());
        assert!(trail.events().is_empty());
    }

    #[test]
    fn override_recorded_when_enabled() {
        let gate = HitlGate::new(true, true);
        let mut trail = HitlAuditTrail::default();

        gate.override_decision(OversightDecision::Allow, "urgent", &mut trail)
            .unwrap();
        assert_eq!(trail.events().len(), 1);
        assert_eq!(trail.events()[0].reason, "urgent");
        assert_eq!(trail.events()[0].decision, OversightDecision::Allow);
        assert!(trail.events()[0].tool_name.is_none());
    }

    #[test]
    fn audit_trail_statistics() {
        let mut trail = HitlAuditTrail::default();
        trail.record(OversightDecision::Allow, "allowed 1");
        trail.record(OversightDecision::Allow, "allowed 2");
        trail.record(OversightDecision::Deny, "denied 1");

        let stats = trail.statistics();
        assert_eq!(stats.total, 3);
        assert_eq!(stats.allowed, 2);
        assert_eq!(stats.denied, 1);
        assert_eq!(stats.required_approval, 0);
    }

    #[test]
    fn events_for_tool_filters_by_name() {
        let mut trail = HitlAuditTrail::default();
        trail.record_tool_decision("a", OversightDecision::Allow, "first");
        trail.record_tool_decision("b", OversightDecision::Deny, "second");
        trail.record(OversightDecision::RequireApproval, "untagged");

        let a_events = trail.events_for_tool("a");
        assert_eq!(a_events.len(), 1);
        assert_eq!(a_events[0].reason, "first");
        assert!(trail.events_for_tool("missing").is_empty());
    }

    #[test]
    fn prune_keeps_most_recent_events() {
        let mut trail = HitlAuditTrail::default();
        for i in 0..5 {
            trail.record(OversightDecision::Allow, format!("event {}", i));
        }

        assert_eq!(trail.prune_old_events(2), 3);
        let reasons: Vec<&str> = trail.events().iter().map(|e| e.reason.as_str()).collect();
        assert_eq!(reasons, ["event 3", "event 4"]);
        // Already within the limit: nothing removed.
        assert_eq!(trail.prune_old_events(10), 0);
        assert_eq!(trail.events().len(), 2);
    }

    #[test]
    fn audit_trail_json_export() {
        let mut trail = HitlAuditTrail::default();
        trail.record_tool_decision("test_tool", OversightDecision::Allow, "test reason");

        let json = trail.to_json().unwrap();
        assert!(json.contains("test_tool"));
        assert!(json.contains("Allow"));
    }
}