1use crate::{AuditEvent, AuditEventType, AuditLogger, SensitivityScanner, TrustLevel};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
/// An incoming JSON-RPC request, as deserialized by [`SigilMcpServer::handle_request`].
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string sent by the client (JSON-RPC 2.0 clients send `"2.0"`).
    pub jsonrpc: String,
    /// Request id; `None` when the field is absent or `null`.
    pub id: Option<serde_json::Value>,
    /// Method name, e.g. `initialize`, `tools/list`, or `tools/call`.
    pub method: String,
    /// Method parameters; `serde_json::Value::Null` when the field is omitted.
    #[serde(default)]
    pub params: serde_json::Value,
}
50
/// An outgoing JSON-RPC response.
///
/// Responses built by [`SigilMcpServer`] set exactly one of `result` or `error`;
/// whichever is `None` is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
    /// Protocol version; always `"2.0"` in responses built by this module.
    pub jsonrpc: String,
    /// Id echoed from the request, or `null` when it is unavailable.
    pub id: serde_json::Value,
    /// Method result on success; skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error details on failure; skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
60
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Serialize)]
pub struct JsonRpcError {
    /// Numeric error code (e.g. `-32700` parse error, `-32601` method not found).
    pub code: i32,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional extra error payload; skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
68
69pub type ToolHandler = Box<
73 dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = anyhow::Result<serde_json::Value>> + Send>>
74 + Send
75 + Sync,
76>;
77
78pub struct ToolDef {
80 pub name: String,
81 pub description: String,
82 pub parameters_schema: serde_json::Value,
83 pub handler: ToolHandler,
84}
85
86pub struct SigilMcpServer<S: SensitivityScanner, A: AuditLogger> {
96 name: String,
97 version: String,
98 tools: HashMap<String, ToolEntry>,
99 scanner: Arc<S>,
100 audit: Arc<A>,
101 required_trust: TrustLevel,
103}
104
105struct ToolEntry {
106 description: String,
107 schema: serde_json::Value,
108 handler: ToolHandler,
109 required_trust: Option<TrustLevel>,
111}
112
113impl<S: SensitivityScanner, A: AuditLogger> SigilMcpServer<S, A> {
114 pub fn new(name: &str, version: &str, scanner: Arc<S>, audit: Arc<A>) -> Self {
116 Self {
117 name: name.to_string(),
118 version: version.to_string(),
119 tools: HashMap::new(),
120 scanner,
121 audit,
122 required_trust: TrustLevel::Low,
123 }
124 }
125
126 pub fn set_required_trust(&mut self, level: TrustLevel) {
128 self.required_trust = level;
129 }
130
131 pub fn register_tool(&mut self, tool: ToolDef) {
133 self.tools.insert(
134 tool.name.clone(),
135 ToolEntry {
136 description: tool.description,
137 schema: tool.parameters_schema,
138 handler: tool.handler,
139 required_trust: None,
140 },
141 );
142 }
143
144 pub fn register_tool_with_trust(&mut self, tool: ToolDef, trust: TrustLevel) {
146 self.tools.insert(
147 tool.name.clone(),
148 ToolEntry {
149 description: tool.description,
150 schema: tool.parameters_schema,
151 handler: tool.handler,
152 required_trust: Some(trust),
153 },
154 );
155 }
156
157 pub async fn handle_request(
163 &self,
164 request: &str,
165 caller_trust: TrustLevel,
166 ) -> String {
167 let req: JsonRpcRequest = match serde_json::from_str(request) {
168 Ok(r) => r,
169 Err(e) => {
170 return serde_json::to_string(&JsonRpcResponse {
171 jsonrpc: "2.0".into(),
172 id: serde_json::Value::Null,
173 result: None,
174 error: Some(JsonRpcError {
175 code: -32700,
176 message: format!("Parse error: {e}"),
177 data: None,
178 }),
179 })
180 .unwrap_or_default();
181 }
182 };
183
184 let id = req.id.clone().unwrap_or(serde_json::Value::Null);
185
186 let response = match req.method.as_str() {
187 "initialize" => self.handle_initialize(&id),
188 "tools/list" => self.handle_tools_list(&id),
189 "tools/call" => self.handle_tools_call(&id, req.params, caller_trust).await,
190 _ => JsonRpcResponse {
191 jsonrpc: "2.0".into(),
192 id,
193 result: None,
194 error: Some(JsonRpcError {
195 code: -32601,
196 message: format!("Method not found: {}", req.method),
197 data: None,
198 }),
199 },
200 };
201
202 serde_json::to_string(&response).unwrap_or_default()
203 }
204
205 fn handle_initialize(&self, id: &serde_json::Value) -> JsonRpcResponse {
206 JsonRpcResponse {
207 jsonrpc: "2.0".into(),
208 id: id.clone(),
209 result: Some(serde_json::json!({
210 "protocolVersion": "2024-11-05",
211 "serverInfo": {
212 "name": self.name,
213 "version": self.version,
214 },
215 "capabilities": {
216 "tools": { "listChanged": false },
217 },
218 "sigil": {
219 "version": "0.1.0",
220 "requiredTrust": format!("{:?}", self.required_trust),
221 }
222 })),
223 error: None,
224 }
225 }
226
227 fn handle_tools_list(&self, id: &serde_json::Value) -> JsonRpcResponse {
228 let tools: Vec<serde_json::Value> = self
229 .tools
230 .iter()
231 .map(|(name, entry)| {
232 serde_json::json!({
233 "name": name,
234 "description": entry.description,
235 "inputSchema": entry.schema,
236 })
237 })
238 .collect();
239
240 JsonRpcResponse {
241 jsonrpc: "2.0".into(),
242 id: id.clone(),
243 result: Some(serde_json::json!({ "tools": tools })),
244 error: None,
245 }
246 }
247
248 async fn handle_tools_call(
249 &self,
250 id: &serde_json::Value,
251 params: serde_json::Value,
252 caller_trust: TrustLevel,
253 ) -> JsonRpcResponse {
254 let tool_name = params
255 .get("name")
256 .and_then(|v| v.as_str())
257 .unwrap_or("")
258 .to_string();
259
260 let arguments = params
261 .get("arguments")
262 .cloned()
263 .unwrap_or(serde_json::json!({}));
264
265 let entry = match self.tools.get(&tool_name) {
267 Some(e) => e,
268 None => {
269 return JsonRpcResponse {
270 jsonrpc: "2.0".into(),
271 id: id.clone(),
272 result: None,
273 error: Some(JsonRpcError {
274 code: -32602,
275 message: format!("Unknown tool: {tool_name}"),
276 data: None,
277 }),
278 };
279 }
280 };
281
282 let required = entry.required_trust.unwrap_or(self.required_trust);
284 if (caller_trust as u8) < (required as u8) {
285 let _ = self.audit.log(&AuditEvent::new(AuditEventType::PolicyViolation).with_action(
286 format!("Trust gate: {tool_name} requires {required:?}, caller has {caller_trust:?}"),
287 "high".into(),
288 false,
289 false,
290 ));
291 return JsonRpcResponse {
292 jsonrpc: "2.0".into(),
293 id: id.clone(),
294 result: None,
295 error: Some(JsonRpcError {
296 code: -32001,
297 message: format!(
298 "SIGIL trust gate: tool '{tool_name}' requires {required:?} trust"
299 ),
300 data: None,
301 }),
302 };
303 }
304
305 let args_str = serde_json::to_string(&arguments).unwrap_or_default();
307 let input_scan = self.scanner.scan(&args_str);
308 if input_scan.is_some() {
309 let _ = self.audit.log(&AuditEvent::new(AuditEventType::SigilInterception).with_action(
310 format!("Input scan: secrets detected in {tool_name} arguments"),
311 "high".into(),
312 true,
313 false,
314 ));
315 }
316
317 let result = (entry.handler)(arguments).await;
319
320 match result {
321 Ok(output) => {
322 let output_str = serde_json::to_string(&output).unwrap_or_default();
324 let output_scan = self.scanner.scan(&output_str);
325
326 let _ = self.audit.log(&AuditEvent::new(AuditEventType::McpToolGated).with_action(
327 format!(
328 "MCP tool {tool_name}: input_secrets={}, output_secrets={}",
329 input_scan.is_some(),
330 output_scan.is_some()
331 ),
332 "low".into(),
333 true,
334 true,
335 ));
336
337 JsonRpcResponse {
338 jsonrpc: "2.0".into(),
339 id: id.clone(),
340 result: Some(serde_json::json!({
341 "content": [{
342 "type": "text",
343 "text": output_str,
344 }],
345 "isError": false,
346 "sigil": {
347 "inputSecrets": input_scan.is_some(),
348 "outputSecrets": output_scan.is_some(),
349 }
350 })),
351 error: None,
352 }
353 }
354 Err(e) => JsonRpcResponse {
355 jsonrpc: "2.0".into(),
356 id: id.clone(),
357 result: Some(serde_json::json!({
358 "content": [{
359 "type": "text",
360 "text": format!("Error: {e}"),
361 }],
362 "isError": true,
363 })),
364 error: None,
365 },
366 }
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Scanner stub: reports an "OpenAI Key" whenever the text contains `sk-`.
    struct TestScanner;

    impl SensitivityScanner for TestScanner {
        fn scan(&self, text: &str) -> Option<String> {
            text.contains("sk-").then(|| String::from("OpenAI Key"))
        }
    }

    /// Audit stub that only counts how many events were logged.
    struct TestAudit {
        log_count: AtomicU32,
    }

    impl TestAudit {
        fn new() -> Self {
            Self {
                log_count: AtomicU32::new(0),
            }
        }

        /// Number of events logged so far.
        fn count(&self) -> u32 {
            self.log_count.load(Ordering::SeqCst)
        }
    }

    impl AuditLogger for TestAudit {
        fn log(&self, _event: &AuditEvent) -> anyhow::Result<()> {
            self.log_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    type TestServer = SigilMcpServer<TestScanner, TestAudit>;

    /// Builds a server with an `echo` tool (server-wide trust) and an
    /// `admin_reset` tool that requires `TrustLevel::High`.
    fn make_server() -> TestServer {
        let mut server = SigilMcpServer::new(
            "test-server",
            "0.1.0",
            Arc::new(TestScanner),
            Arc::new(TestAudit::new()),
        );

        server.register_tool(ToolDef {
            name: "echo".into(),
            description: "Echo input back".into(),
            parameters_schema: serde_json::json!({"type": "object"}),
            handler: Box::new(|args| Box::pin(async move { Ok(args) })),
        });

        let admin_reset = ToolDef {
            name: "admin_reset".into(),
            description: "Dangerous admin operation".into(),
            parameters_schema: serde_json::json!({"type": "object"}),
            handler: Box::new(|_| {
                Box::pin(async move { Ok(serde_json::json!({"status": "reset"})) })
            }),
        };
        server.register_tool_with_trust(admin_reset, TrustLevel::High);

        server
    }

    /// Sends `request` at the given trust level and parses the JSON-RPC reply.
    async fn call_server(
        server: &TestServer,
        request: &str,
        trust: TrustLevel,
    ) -> serde_json::Value {
        let reply = server.handle_request(request, trust).await;
        serde_json::from_str(&reply).expect("server reply should be valid JSON")
    }

    #[tokio::test]
    async fn initialize_returns_server_info() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#,
            TrustLevel::Low,
        )
        .await;

        assert_eq!(reply["result"]["serverInfo"]["name"], "test-server");
        assert!(reply["result"]["sigil"].is_object());
    }

    #[tokio::test]
    async fn tools_list_returns_registered_tools() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}"#,
            TrustLevel::Low,
        )
        .await;

        let tools = reply["result"]["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 2);

        let names: Vec<&str> = tools
            .iter()
            .map(|tool| tool["name"].as_str().unwrap())
            .collect();
        for expected in &["echo", "admin_reset"] {
            assert!(names.contains(expected));
        }
    }

    #[tokio::test]
    async fn tools_call_echo_succeeds() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"echo","arguments":{"message":"hello"}}}"#,
            TrustLevel::Low,
        )
        .await;

        let text = reply["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("hello"));
        assert_eq!(reply["result"]["isError"], false);
    }

    #[tokio::test]
    async fn tools_call_unknown_tool_returns_error() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"nonexistent","arguments":{}}}"#,
            TrustLevel::Low,
        )
        .await;

        let message = reply["error"]["message"].as_str().unwrap();
        assert!(message.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn trust_gate_blocks_low_trust_from_high_trust_tool() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"admin_reset","arguments":{}}}"#,
            TrustLevel::Low,
        )
        .await;

        let message = reply["error"]["message"].as_str().unwrap();
        assert!(message.contains("trust gate"));
    }

    #[tokio::test]
    async fn trust_gate_allows_high_trust_for_high_trust_tool() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":6,"method":"tools/call","params":{"name":"admin_reset","arguments":{}}}"#,
            TrustLevel::High,
        )
        .await;

        assert!(reply["error"].is_null());
        let text = reply["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("reset"));
    }

    #[tokio::test]
    async fn sigil_scan_detects_secrets_in_arguments() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":7,"method":"tools/call","params":{"name":"echo","arguments":{"key":"sk-abc123def456"}}}"#,
            TrustLevel::Low,
        )
        .await;

        assert_eq!(reply["result"]["sigil"]["inputSecrets"], true);
        // At least the interception event plus the tool-call event.
        assert!(server.audit.count() >= 2);
    }

    #[tokio::test]
    async fn sigil_scan_no_secrets_in_clean_input() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":8,"method":"tools/call","params":{"name":"echo","arguments":{"message":"safe text"}}}"#,
            TrustLevel::Low,
        )
        .await;

        let sigil = &reply["result"]["sigil"];
        assert_eq!(sigil["inputSecrets"], false);
        assert_eq!(sigil["outputSecrets"], false);
    }

    #[tokio::test]
    async fn invalid_json_returns_parse_error() {
        let server = make_server();
        let reply = call_server(&server, "not json", TrustLevel::Low).await;
        assert_eq!(reply["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn unknown_method_returns_method_not_found() {
        let server = make_server();
        let reply = call_server(
            &server,
            r#"{"jsonrpc":"2.0","id":10,"method":"resources/list","params":{}}"#,
            TrustLevel::Low,
        )
        .await;

        assert_eq!(reply["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn audit_logged_for_every_tool_call() {
        let server = make_server();
        let request =
            r#"{"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"name":"echo","arguments":{"msg":"hi"}}}"#;

        let logged_before = server.audit.count();
        let _ = server.handle_request(request, TrustLevel::Low).await;
        let logged_after = server.audit.count();

        assert!(
            logged_after > logged_before,
            "Audit log should record tool invocation"
        );
    }
}