1use crate::TRon;
8use crate::gate::{DenyCode, ToolCall, Verdict};
9use bote::Dispatcher;
10use bote::protocol::{JsonRpcRequest, JsonRpcResponse};
11
/// JSON-RPC error code carried by every response the security gate uses to block a tool call.
///
/// Chosen from the `-32000..=-32099` band that JSON-RPC 2.0 reserves for
/// implementation-defined server errors.
const SECURITY_DENIED: i32 = -32_001;
14
15pub struct SecurityGate {
21 tron: TRon,
22 inner: Dispatcher,
23}
24
25impl SecurityGate {
26 #[must_use]
28 pub fn new(tron: TRon, dispatcher: Dispatcher) -> Self {
29 Self {
30 tron,
31 inner: dispatcher,
32 }
33 }
34
35 #[must_use]
37 pub fn dispatcher_mut(&mut self) -> &mut Dispatcher {
38 &mut self.inner
39 }
40
41 #[must_use]
43 pub fn dispatcher(&self) -> &Dispatcher {
44 &self.inner
45 }
46
47 #[must_use]
49 pub fn tron(&self) -> &TRon {
50 &self.tron
51 }
52
53 pub fn register_tool_handlers(&mut self) {
61 use crate::tools;
62 let query = self.tron.query();
63 self.inner
64 .handle("tron_status", tools::status_handler(query.clone()));
65 self.inner
66 .handle("tron_risk", tools::risk_handler(query.clone()));
67 self.inner.handle("tron_audit", tools::audit_handler(query));
68 self.inner
69 .handle("tron_policy", tools::policy_handler(&self.tron));
70 }
71
72 pub async fn dispatch(
78 &self,
79 request: &JsonRpcRequest,
80 agent_id: &str,
81 ) -> Option<JsonRpcResponse> {
82 if request.method == "tools/call"
83 && let Some(denied) = self.check_tool_call(request, agent_id).await
84 {
85 return Some(denied);
86 }
87 self.inner.dispatch(request)
88 }
89
90 pub async fn dispatch_streaming(
92 &self,
93 request: &JsonRpcRequest,
94 agent_id: &str,
95 ) -> bote::DispatchOutcome {
96 if request.method == "tools/call"
97 && let Some(denied) = self.check_tool_call(request, agent_id).await
98 {
99 return bote::DispatchOutcome::Immediate(Some(denied));
100 }
101 self.inner.dispatch_streaming(request)
102 }
103
104 async fn check_tool_call(
107 &self,
108 request: &JsonRpcRequest,
109 agent_id: &str,
110 ) -> Option<JsonRpcResponse> {
111 let id = request.id.clone().unwrap_or(serde_json::Value::Null);
112 let tool_name = match request.params.get("name").and_then(|v| v.as_str()) {
113 Some(name) if !name.is_empty() => name,
114 _ => {
115 return Some(Self::deny_response(
116 id,
117 "missing or empty tool name in tools/call",
118 DenyCode::Unauthorized,
119 ));
120 }
121 };
122 let arguments = request
123 .params
124 .get("arguments")
125 .cloned()
126 .unwrap_or(serde_json::json!({}));
127
128 let call = ToolCall {
129 agent_id: agent_id.to_string(),
130 tool_name: tool_name.to_string(),
131 params: arguments,
132 timestamp: chrono::Utc::now(),
133 };
134
135 let verdict = self.tron.check(&call).await;
136 match verdict {
137 Verdict::Deny { reason, code } => {
138 tracing::warn!(
139 agent = agent_id,
140 tool = tool_name,
141 code = ?code,
142 "security gate denied tool call: {reason}"
143 );
144 Some(Self::deny_response(id, &reason, code))
145 }
146 Verdict::Flag { reason } => {
147 tracing::info!(
148 agent = agent_id,
149 tool = tool_name,
150 "security gate flagged tool call: {reason}"
151 );
152 None
154 }
155 Verdict::Allow => None,
156 }
157 }
158
159 fn deny_response(id: serde_json::Value, reason: &str, code: DenyCode) -> JsonRpcResponse {
161 JsonRpcResponse::error(id, SECURITY_DENIED, format!("security: {reason} [{code}]"))
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DefaultAction, TRonConfig};
    use bote::registry::{ToolDef, ToolRegistry, ToolSchema};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Build a gate around a dispatcher exposing a single `echo` tool that returns its
    /// parameters as text content.
    fn make_gate(config: TRonConfig) -> SecurityGate {
        let tron = TRon::new(config);
        let mut reg = ToolRegistry::new();
        reg.register(ToolDef {
            name: "echo".into(),
            description: "Echo input".into(),
            input_schema: ToolSchema {
                schema_type: "object".into(),
                properties: HashMap::new(),
                required: vec![],
            },
        });
        let mut dispatcher = Dispatcher::new(reg);
        dispatcher.handle(
            "echo",
            Arc::new(|params| {
                serde_json::json!({"content": [{"type": "text", "text": params.to_string()}]})
            }),
        );
        SecurityGate::new(tron, dispatcher)
    }

    /// Config that allows unknown agents and unknown tools, so the agent/tool defaults
    /// alone never deny a call.
    fn permissive_config() -> TRonConfig {
        TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        }
    }

    /// Build a `tools/call` request (id 1) for `tool_name` with the given arguments.
    fn tool_call_request(tool_name: &str, arguments: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest::new(1, "tools/call")
            .with_params(serde_json::json!({"name": tool_name, "arguments": arguments}))
    }

    #[tokio::test]
    async fn deny_unknown_agent() {
        let gate = make_gate(TRonConfig::default());
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "nobody").await.expect("gate must respond");
        let err = resp.error.expect("unknown agent must be denied");
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("unauthorized"));
    }

    #[tokio::test]
    async fn allow_known_agent() {
        let gate = make_gate(permissive_config());
        let req = tool_call_request("echo", serde_json::json!({"msg": "hello"}));
        let resp = gate.dispatch(&req, "agent-1").await.expect("gate must respond");
        assert!(resp.error.is_none());
        assert!(resp.result.is_some());
    }

    #[tokio::test]
    async fn allow_with_policy() {
        let gate = make_gate(TRonConfig::default());
        gate.tron()
            .load_policy(
                r#"
[agent."web-agent"]
allow = ["echo"]
"#,
            )
            .unwrap();
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "web-agent").await.expect("gate must respond");
        assert!(resp.error.is_none());
    }

    #[tokio::test]
    async fn deny_by_policy() {
        let gate = make_gate(TRonConfig::default());
        gate.tron()
            .load_policy(
                r#"
[agent."restricted"]
allow = ["tarang_*"]
deny = ["echo"]
"#,
            )
            .unwrap();
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "restricted").await.expect("gate must respond");
        let err = resp.error.expect("policy deny list must block echo");
        assert_eq!(err.code, SECURITY_DENIED);
    }

    #[tokio::test]
    async fn deny_injection() {
        let gate = make_gate(permissive_config());
        let req = tool_call_request(
            "echo",
            serde_json::json!({"q": "1 UNION SELECT * FROM passwords"}),
        );
        let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
        let err = resp.error.expect("injection payload must be denied");
        assert!(err.message.contains("injection_detected"));
    }

    #[tokio::test]
    async fn non_tool_call_passes_through() {
        let gate = make_gate(TRonConfig::default());
        // `initialize` is not a `tools/call`, so even an unknown agent is not checked.
        let req = JsonRpcRequest::new(1, "initialize");
        let resp = gate.dispatch(&req, "unknown-agent").await.expect("gate must respond");
        assert!(resp.result.is_some());
    }

    #[tokio::test]
    async fn tools_list_passes_through() {
        let gate = make_gate(TRonConfig::default());
        let req = JsonRpcRequest::new(1, "tools/list");
        let resp = gate.dispatch(&req, "unknown-agent").await.expect("gate must respond");
        let result = resp.result.expect("tools/list must succeed");
        let tools = result["tools"].as_array().expect("`tools` must be an array");
        assert_eq!(tools.len(), 1);
    }

    #[tokio::test]
    async fn rate_limit_through_gate() {
        let config = TRonConfig {
            scan_payloads: false,
            analyze_patterns: false,
            ..permissive_config()
        };
        let gate = make_gate(config);
        let req = tool_call_request("echo", serde_json::json!({}));
        // The first 60 calls fit within the default per-agent limit...
        for _ in 0..60 {
            let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
            assert!(resp.error.is_none());
        }
        // ...and the 61st is rejected.
        let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
        let err = resp.error.expect("61st call must be rate limited");
        assert!(err.message.contains("rate_limited"));
    }

    #[tokio::test]
    async fn streaming_dispatch_denied() {
        let gate = make_gate(TRonConfig::default());
        let req = tool_call_request("echo", serde_json::json!({}));
        match gate.dispatch_streaming(&req, "nobody").await {
            bote::DispatchOutcome::Immediate(Some(resp)) => {
                assert!(resp.error.is_some());
            }
            _ => panic!("expected Immediate(Some) for denied call"),
        }
    }

    #[tokio::test]
    async fn streaming_dispatch_allowed() {
        let gate = make_gate(permissive_config());
        let req = tool_call_request("echo", serde_json::json!({}));
        match gate.dispatch_streaming(&req, "agent").await {
            bote::DispatchOutcome::Immediate(Some(resp)) => {
                assert!(resp.error.is_none());
            }
            _ => panic!("expected Immediate(Some) for allowed sync tool"),
        }
    }

    #[tokio::test]
    async fn audit_logged_through_gate() {
        let config = TRonConfig {
            scan_payloads: false,
            analyze_patterns: false,
            ..permissive_config()
        };
        let gate = make_gate(config);
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "agent-1").await.expect("gate must respond");
        assert!(resp.error.is_none());

        let query = gate.tron().query();
        assert_eq!(query.total_events().await, 1);
    }

    #[tokio::test]
    async fn deny_missing_tool_name() {
        let gate = make_gate(permissive_config());
        let req =
            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({"arguments": {}}));
        let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
        let err = resp.error.expect("missing tool name must be denied");
        assert!(err.message.contains("missing"));
    }

    #[tokio::test]
    async fn deny_empty_tool_name() {
        let gate = make_gate(permissive_config());
        let req = tool_call_request("", serde_json::json!({}));
        let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
        let err = resp.error.expect("empty tool name must be denied");
        assert!(err.message.contains("missing"));
    }

    #[tokio::test]
    async fn deny_non_string_tool_name() {
        let gate = make_gate(permissive_config());
        let req = JsonRpcRequest::new(1, "tools/call")
            .with_params(serde_json::json!({"name": 42, "arguments": {}}));
        let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
        let err = resp.error.expect("non-string tool name must be denied");
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("missing"));
    }

    #[tokio::test]
    async fn deny_tool_call_without_params() {
        let gate = make_gate(permissive_config());
        let req = JsonRpcRequest::new(1, "tools/call");
        let resp = gate.dispatch(&req, "agent").await.expect("gate must respond");
        let err = resp.error.expect("tools/call without params must be denied");
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("missing"));
    }

    #[test]
    fn deny_response_format() {
        let resp = SecurityGate::deny_response(
            serde_json::json!(42),
            "rate limit exceeded",
            DenyCode::RateLimited,
        );
        assert_eq!(resp.id, serde_json::json!(42));
        let err = resp.error.expect("deny_response must carry an error");
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.starts_with("security: rate limit exceeded ["));
        assert!(err.message.contains("rate_limited"));
        assert!(err.message.contains("rate limit exceeded"));
    }
}