synwire_agent/sandbox/
threshold_gate.rs1use std::collections::HashSet;
4
5use tokio::sync::Mutex;
6
7use synwire_core::BoxFuture;
8use synwire_core::sandbox::approval::{
9 ApprovalCallback, ApprovalDecision, ApprovalRequest, RiskLevel,
10};
11
12pub struct ThresholdGate {
15 threshold: RiskLevel,
16 inner: Box<dyn ApprovalCallback>,
17 always_allowed: Mutex<HashSet<String>>,
19}
20
21impl std::fmt::Debug for ThresholdGate {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("ThresholdGate")
24 .field("threshold", &self.threshold)
25 .finish_non_exhaustive()
26 }
27}
28
29impl ThresholdGate {
30 pub fn new(threshold: RiskLevel, inner: impl ApprovalCallback + 'static) -> Self {
32 Self {
33 threshold,
34 inner: Box::new(inner),
35 always_allowed: Mutex::new(HashSet::new()),
36 }
37 }
38
39 async fn is_always_allowed(&self, operation: &str) -> bool {
40 self.always_allowed.lock().await.contains(operation)
41 }
42
43 async fn record_always_allowed(&self, operation: &str) {
44 let _ = self
45 .always_allowed
46 .lock()
47 .await
48 .insert(operation.to_string());
49 }
50}
51
52impl ApprovalCallback for ThresholdGate {
53 fn request(&self, req: ApprovalRequest) -> BoxFuture<'_, ApprovalDecision> {
54 Box::pin(async move {
55 if self.is_always_allowed(&req.operation).await {
57 return ApprovalDecision::Allow;
58 }
59
60 if req.risk <= self.threshold {
62 return ApprovalDecision::Allow;
63 }
64
65 let operation = req.operation.clone();
67 let decision = self.inner.request(req).await;
68
69 if matches!(decision, ApprovalDecision::AllowAlways) {
70 self.record_always_allowed(&operation).await;
71 return ApprovalDecision::AllowAlways;
72 }
73
74 decision
75 })
76 }
77}
78
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use synwire_core::sandbox::approval::AutoDenyCallback;

    /// Builds a request for the `test_op` operation at the given risk level.
    fn req(risk: RiskLevel) -> ApprovalRequest {
        req_for("test_op", risk)
    }

    /// Builds a request for an arbitrary operation name at the given risk level.
    fn req_for(operation: &str, risk: RiskLevel) -> ApprovalRequest {
        ApprovalRequest {
            operation: operation.to_string(),
            description: "test".to_string(),
            risk,
            timeout_secs: None,
            context: serde_json::json!({}),
        }
    }

    /// Inner callback that always answers `AllowAlways`.
    struct AllowAlwaysCallback;

    impl ApprovalCallback for AllowAlwaysCallback {
        fn request(&self, _req: ApprovalRequest) -> BoxFuture<'_, ApprovalDecision> {
            Box::pin(async { ApprovalDecision::AllowAlways })
        }
    }

    /// Inner callback that answers a one-off `Allow` and counts how often it is asked.
    struct CountingAllowCallback {
        calls: Arc<AtomicUsize>,
    }

    impl ApprovalCallback for CountingAllowCallback {
        fn request(&self, _req: ApprovalRequest) -> BoxFuture<'_, ApprovalDecision> {
            let _ = self.calls.fetch_add(1, Ordering::SeqCst);
            Box::pin(async { ApprovalDecision::Allow })
        }
    }

    #[tokio::test]
    async fn test_auto_approve_below_threshold() {
        let gate = ThresholdGate::new(RiskLevel::Medium, AutoDenyCallback);
        let decision = gate.request(req(RiskLevel::Low)).await;
        assert!(matches!(decision, ApprovalDecision::Allow));
    }

    #[tokio::test]
    async fn test_auto_approve_at_threshold() {
        // The threshold is inclusive: a request exactly at it must not be delegated.
        let gate = ThresholdGate::new(RiskLevel::Medium, AutoDenyCallback);
        let decision = gate.request(req(RiskLevel::Medium)).await;
        assert!(matches!(decision, ApprovalDecision::Allow));
    }

    #[tokio::test]
    async fn test_delegate_above_threshold() {
        let gate = ThresholdGate::new(RiskLevel::Low, AutoDenyCallback);
        let decision = gate.request(req(RiskLevel::High)).await;
        assert!(matches!(decision, ApprovalDecision::Deny));
    }

    #[tokio::test]
    async fn test_allow_always_caching() {
        let gate = ThresholdGate::new(RiskLevel::None, AllowAlwaysCallback);

        // First request is delegated and the inner `AllowAlways` is passed through.
        let d1 = gate.request(req(RiskLevel::High)).await;
        assert!(matches!(d1, ApprovalDecision::AllowAlways));

        // Second request hits the cache, even at a higher risk level.
        let d2 = gate.request(req(RiskLevel::Critical)).await;
        assert!(matches!(d2, ApprovalDecision::Allow));
    }

    #[tokio::test]
    async fn test_allow_always_cache_is_per_operation() {
        let gate = ThresholdGate::new(RiskLevel::None, AllowAlwaysCallback);

        let first = gate.request(req_for("op_a", RiskLevel::High)).await;
        assert!(matches!(first, ApprovalDecision::AllowAlways));

        // A different operation is not covered by op_a's cache entry, so it is
        // delegated again and the inner `AllowAlways` comes back (not a cached `Allow`).
        let other = gate.request(req_for("op_b", RiskLevel::High)).await;
        assert!(matches!(other, ApprovalDecision::AllowAlways));
    }

    #[tokio::test]
    async fn test_plain_allow_is_not_cached() {
        let calls = Arc::new(AtomicUsize::new(0));
        let gate = ThresholdGate::new(
            RiskLevel::None,
            CountingAllowCallback {
                calls: Arc::clone(&calls),
            },
        );

        let first = gate.request(req(RiskLevel::High)).await;
        let second = gate.request(req(RiskLevel::High)).await;

        assert!(matches!(first, ApprovalDecision::Allow));
        assert!(matches!(second, ApprovalDecision::Allow));
        // A one-off `Allow` must not populate the cache, so both requests are delegated.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}