// synaptic_middleware/security.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::SynapticError;
7
8use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};
9
/// Risk level for a tool call.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// lets callers compare levels directly (e.g. `risk >= RiskLevel::High`).
/// `Hash` is derived so levels can be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum RiskLevel {
    /// Least severe level.
    None,
    /// Low risk.
    Low,
    /// Medium risk.
    Medium,
    /// High risk.
    High,
    /// Most severe level.
    Critical,
}
19
20/// Assesses the risk level of a tool call.
21#[async_trait]
22pub trait SecurityAnalyzer: Send + Sync {
23    async fn assess(&self, tool_name: &str, args: &Value) -> Result<RiskLevel, SynapticError>;
24}
25
26/// Determines whether a tool call at a given risk level requires confirmation.
27#[async_trait]
28pub trait ConfirmationPolicy: Send + Sync {
29    async fn should_confirm(&self, tool_name: &str, risk: RiskLevel)
30        -> Result<bool, SynapticError>;
31}
32
33/// Callback for obtaining user confirmation before executing a risky tool call.
34#[async_trait]
35pub trait SecurityConfirmationCallback: Send + Sync {
36    async fn confirm(
37        &self,
38        tool_name: &str,
39        args: &Value,
40        risk: RiskLevel,
41    ) -> Result<bool, SynapticError>;
42}
43
44/// Rule-based security analyzer that maps tool names and argument patterns to risk levels.
45pub struct RuleBasedAnalyzer {
46    tool_risks: HashMap<String, RiskLevel>,
47    arg_patterns: Vec<ArgPattern>,
48    default_risk: RiskLevel,
49}
50
51/// A pattern that elevates risk when matched in tool arguments.
52struct ArgPattern {
53    key: String,
54    pattern: String,
55    risk: RiskLevel,
56}
57
58impl RuleBasedAnalyzer {
59    pub fn new() -> Self {
60        Self {
61            tool_risks: HashMap::new(),
62            arg_patterns: Vec::new(),
63            default_risk: RiskLevel::Low,
64        }
65    }
66
67    /// Set the default risk level for unknown tools.
68    pub fn with_default_risk(mut self, risk: RiskLevel) -> Self {
69        self.default_risk = risk;
70        self
71    }
72
73    /// Set the risk level for a specific tool.
74    pub fn with_tool_risk(mut self, tool_name: impl Into<String>, risk: RiskLevel) -> Self {
75        self.tool_risks.insert(tool_name.into(), risk);
76        self
77    }
78
79    /// Add an argument pattern that elevates risk.
80    /// If the argument value for `key` contains `pattern`, the risk is elevated to `risk`.
81    pub fn with_arg_pattern(
82        mut self,
83        key: impl Into<String>,
84        pattern: impl Into<String>,
85        risk: RiskLevel,
86    ) -> Self {
87        self.arg_patterns.push(ArgPattern {
88            key: key.into(),
89            pattern: pattern.into(),
90            risk,
91        });
92        self
93    }
94}
95
96impl Default for RuleBasedAnalyzer {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102#[async_trait]
103impl SecurityAnalyzer for RuleBasedAnalyzer {
104    async fn assess(&self, tool_name: &str, args: &Value) -> Result<RiskLevel, SynapticError> {
105        let mut risk = self
106            .tool_risks
107            .get(tool_name)
108            .copied()
109            .unwrap_or(self.default_risk);
110
111        // Check argument patterns - elevate risk if matched
112        for pattern in &self.arg_patterns {
113            if let Some(val) = args.get(&pattern.key) {
114                let val_str = match val {
115                    Value::String(s) => s.clone(),
116                    other => other.to_string(),
117                };
118                if val_str.contains(&pattern.pattern) && pattern.risk > risk {
119                    risk = pattern.risk;
120                }
121            }
122        }
123
124        Ok(risk)
125    }
126}
127
128/// Confirms tool calls that meet or exceed a risk threshold.
129pub struct ThresholdConfirmationPolicy {
130    threshold: RiskLevel,
131}
132
133impl ThresholdConfirmationPolicy {
134    pub fn new(threshold: RiskLevel) -> Self {
135        Self { threshold }
136    }
137}
138
139#[async_trait]
140impl ConfirmationPolicy for ThresholdConfirmationPolicy {
141    async fn should_confirm(
142        &self,
143        _tool_name: &str,
144        risk: RiskLevel,
145    ) -> Result<bool, SynapticError> {
146        Ok(risk >= self.threshold)
147    }
148}
149
150/// Middleware that assesses tool call risk and optionally requires confirmation.
151pub struct SecurityMiddleware {
152    analyzer: Arc<dyn SecurityAnalyzer>,
153    policy: Arc<dyn ConfirmationPolicy>,
154    callback: Arc<dyn SecurityConfirmationCallback>,
155    /// Tools that bypass security checks entirely.
156    bypass: HashSet<String>,
157}
158
159impl SecurityMiddleware {
160    pub fn new(
161        analyzer: Arc<dyn SecurityAnalyzer>,
162        policy: Arc<dyn ConfirmationPolicy>,
163        callback: Arc<dyn SecurityConfirmationCallback>,
164    ) -> Self {
165        Self {
166            analyzer,
167            policy,
168            callback,
169            bypass: HashSet::new(),
170        }
171    }
172
173    /// Add tools that should bypass security checks.
174    pub fn with_bypass(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
175        self.bypass = tools.into_iter().map(|s| s.into()).collect();
176        self
177    }
178}
179
180#[async_trait]
181impl AgentMiddleware for SecurityMiddleware {
182    async fn wrap_tool_call(
183        &self,
184        request: ToolCallRequest,
185        next: &dyn ToolCaller,
186    ) -> Result<Value, SynapticError> {
187        let tool_name = &request.call.name;
188
189        // Check bypass list
190        if self.bypass.contains(tool_name) {
191            return next.call(request).await;
192        }
193
194        // Assess risk
195        let risk = self
196            .analyzer
197            .assess(tool_name, &request.call.arguments)
198            .await?;
199
200        // Check if confirmation is needed
201        let needs_confirm = self.policy.should_confirm(tool_name, risk).await?;
202
203        if needs_confirm {
204            let confirmed = self
205                .callback
206                .confirm(tool_name, &request.call.arguments, risk)
207                .await?;
208            if !confirmed {
209                return Err(SynapticError::Tool(format!(
210                    "tool call '{}' rejected by security policy (risk: {:?})",
211                    tool_name, risk
212                )));
213            }
214        }
215
216        next.call(request).await
217    }
218}