Skip to main content

synwire_agent/sandbox/
threshold_gate.rs

1//! Threshold-based approval gate.
2
3use std::collections::HashSet;
4
5use tokio::sync::Mutex;
6
7use synwire_core::BoxFuture;
8use synwire_core::sandbox::approval::{
9    ApprovalCallback, ApprovalDecision, ApprovalRequest, RiskLevel,
10};
11
12/// Auto-approves operations up to a given risk level; delegates higher-risk
13/// operations to an inner callback.
14pub struct ThresholdGate {
15    threshold: RiskLevel,
16    inner: Box<dyn ApprovalCallback>,
17    /// Operations that have been globally approved via `AllowAlways`.
18    always_allowed: Mutex<HashSet<String>>,
19}
20
21impl std::fmt::Debug for ThresholdGate {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("ThresholdGate")
24            .field("threshold", &self.threshold)
25            .finish_non_exhaustive()
26    }
27}
28
29impl ThresholdGate {
30    /// Create a new threshold gate.
31    pub fn new(threshold: RiskLevel, inner: impl ApprovalCallback + 'static) -> Self {
32        Self {
33            threshold,
34            inner: Box::new(inner),
35            always_allowed: Mutex::new(HashSet::new()),
36        }
37    }
38
39    async fn is_always_allowed(&self, operation: &str) -> bool {
40        self.always_allowed.lock().await.contains(operation)
41    }
42
43    async fn record_always_allowed(&self, operation: &str) {
44        let _ = self
45            .always_allowed
46            .lock()
47            .await
48            .insert(operation.to_string());
49    }
50}
51
52impl ApprovalCallback for ThresholdGate {
53    fn request(&self, req: ApprovalRequest) -> BoxFuture<'_, ApprovalDecision> {
54        Box::pin(async move {
55            // Check AllowAlways cache first.
56            if self.is_always_allowed(&req.operation).await {
57                return ApprovalDecision::Allow;
58            }
59
60            // Auto-approve if within threshold.
61            if req.risk <= self.threshold {
62                return ApprovalDecision::Allow;
63            }
64
65            // Delegate to inner callback for higher-risk operations.
66            let operation = req.operation.clone();
67            let decision = self.inner.request(req).await;
68
69            if matches!(decision, ApprovalDecision::AllowAlways) {
70                self.record_always_allowed(&operation).await;
71                return ApprovalDecision::AllowAlways;
72            }
73
74            decision
75        })
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use synwire_core::sandbox::approval::AutoDenyCallback;

    /// Builds a request for the fixed operation `test_op` at the given risk.
    fn req(risk: RiskLevel) -> ApprovalRequest {
        ApprovalRequest {
            operation: "test_op".to_string(),
            description: "test".to_string(),
            risk,
            timeout_secs: None,
            context: serde_json::json!({}),
        }
    }

    /// Inner callback that returns a fixed decision and counts how many times
    /// the gate consulted it.
    struct CountingCallback {
        decide: fn() -> ApprovalDecision,
        calls: Arc<AtomicUsize>,
    }

    impl CountingCallback {
        /// Returns the callback plus a shared counter that stays readable after
        /// the callback has been moved into a gate.
        fn new(decide: fn() -> ApprovalDecision) -> (Self, Arc<AtomicUsize>) {
            let calls = Arc::new(AtomicUsize::new(0));
            let callback = Self {
                decide,
                calls: Arc::clone(&calls),
            };
            (callback, calls)
        }
    }

    impl ApprovalCallback for CountingCallback {
        fn request(&self, _req: ApprovalRequest) -> BoxFuture<'_, ApprovalDecision> {
            // Counted at call time; the gate awaits the future immediately.
            let _ = self.calls.fetch_add(1, Ordering::SeqCst);
            let decide = self.decide;
            Box::pin(async move { decide() })
        }
    }

    #[tokio::test]
    async fn test_auto_approve_below_threshold() {
        let gate = ThresholdGate::new(RiskLevel::Medium, AutoDenyCallback);
        let decision = gate.request(req(RiskLevel::Low)).await;
        assert!(matches!(decision, ApprovalDecision::Allow));
    }

    #[tokio::test]
    async fn test_auto_approve_at_threshold() {
        // The comparison is inclusive: a request exactly at the threshold is
        // approved without reaching the (denying) inner callback.
        let gate = ThresholdGate::new(RiskLevel::Medium, AutoDenyCallback);
        let decision = gate.request(req(RiskLevel::Medium)).await;
        assert!(matches!(decision, ApprovalDecision::Allow));
    }

    #[tokio::test]
    async fn test_delegate_above_threshold() {
        let gate = ThresholdGate::new(RiskLevel::Low, AutoDenyCallback);
        let decision = gate.request(req(RiskLevel::High)).await;
        assert!(matches!(decision, ApprovalDecision::Deny));
    }

    #[tokio::test]
    async fn test_allow_always_caching() {
        let (inner, calls) = CountingCallback::new(|| ApprovalDecision::AllowAlways);
        let gate = ThresholdGate::new(RiskLevel::None, inner);

        // First request: above threshold, delegates to inner → AllowAlways.
        let d1 = gate.request(req(RiskLevel::High)).await;
        assert!(matches!(d1, ApprovalDecision::AllowAlways));
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Second request: served from the cache; inner is not called again.
        let d2 = gate.request(req(RiskLevel::Critical)).await;
        assert!(matches!(d2, ApprovalDecision::Allow));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_deny_is_not_cached() {
        let (inner, calls) = CountingCallback::new(|| ApprovalDecision::Deny);
        let gate = ThresholdGate::new(RiskLevel::Low, inner);

        // Only `AllowAlways` is remembered, so every above-threshold request
        // must reach the inner callback.
        for _ in 0..2 {
            let decision = gate.request(req(RiskLevel::High)).await;
            assert!(matches!(decision, ApprovalDecision::Deny));
        }
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}
127}