// Source: sigil_protocol/mcp_server.rs

//! Reference SIGIL MCP Server.
//!
//! An embeddable MCP (Model Context Protocol) server that any Rust
//! application can use to expose SIGIL-protected tools over JSON-RPC 2.0.
//!
//! ```rust,no_run
//! use sigil_protocol::mcp_server::{SigilMcpServer, ToolDef};
//! use sigil_protocol::{SensitivityScanner, AuditLogger, AuditEvent};
//! use std::sync::Arc;
//!
//! struct MyScanner;
//! impl SensitivityScanner for MyScanner {
//!     fn scan(&self, _text: &str) -> Option<String> { None }
//! }
//! struct MyAudit;
//! impl AuditLogger for MyAudit {
//!     fn log(&self, _e: &AuditEvent) -> anyhow::Result<()> { Ok(()) }
//! }
//!
//! let scanner = Arc::new(MyScanner);
//! let audit = Arc::new(MyAudit);
//! let mut server = SigilMcpServer::new("my-server", "1.0.0", scanner, audit);
//! server.register_tool(ToolDef {
//!     name: "get_weather".into(),
//!     description: "Get current weather".into(),
//!     parameters_schema: serde_json::json!({"type": "object"}),
//!     handler: Box::new(|_args| Box::pin(async move {
//!         Ok(serde_json::json!({"temp": 22}))
//!     })),
//! });
//! ```

use crate::{AuditEvent, AuditEventType, AuditLogger, SensitivityScanner, TrustLevel};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

// ── JSON-RPC 2.0 types ─────────────────────────────────────────

42#[derive(Debug, Deserialize)]
43pub struct JsonRpcRequest {
44    pub jsonrpc: String,
45    pub id: Option<serde_json::Value>,
46    pub method: String,
47    #[serde(default)]
48    pub params: serde_json::Value,
49}
50
51#[derive(Debug, Serialize)]
52pub struct JsonRpcResponse {
53    pub jsonrpc: String,
54    pub id: serde_json::Value,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub result: Option<serde_json::Value>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub error: Option<JsonRpcError>,
59}
60
61#[derive(Debug, Serialize)]
62pub struct JsonRpcError {
63    pub code: i32,
64    pub message: String,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub data: Option<serde_json::Value>,
67}
68
// ── MCP Tool definition ─────────────────────────────────────────

71/// Async handler for a tool call.
72pub type ToolHandler = Box<
73    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = anyhow::Result<serde_json::Value>> + Send>>
74        + Send
75        + Sync,
76>;
77
78/// A tool definition to register with the SIGIL MCP server.
79pub struct ToolDef {
80    pub name: String,
81    pub description: String,
82    pub parameters_schema: serde_json::Value,
83    pub handler: ToolHandler,
84}
85
// ── SIGIL MCP Server ────────────────────────────────────────────

88/// A reference SIGIL-secured MCP server.
89///
90/// Wraps any set of tools with:
91/// - **Input scanning** — tool arguments are scanned for secrets before execution
92/// - **Output scanning** — tool results are scanned for secrets before returning
93/// - **Audit logging** — every tool invocation is logged
94/// - **Trust gating** — tools can require a minimum trust level
95pub struct SigilMcpServer<S: SensitivityScanner, A: AuditLogger> {
96    name: String,
97    version: String,
98    tools: HashMap<String, ToolEntry>,
99    scanner: Arc<S>,
100    audit: Arc<A>,
101    /// Minimum trust level required for this server (default: Low).
102    required_trust: TrustLevel,
103}
104
105struct ToolEntry {
106    description: String,
107    schema: serde_json::Value,
108    handler: ToolHandler,
109    /// Per-tool trust level override (None = use server default).
110    required_trust: Option<TrustLevel>,
111}
112
113impl<S: SensitivityScanner, A: AuditLogger> SigilMcpServer<S, A> {
114    /// Create a new SIGIL MCP server.
115    pub fn new(name: &str, version: &str, scanner: Arc<S>, audit: Arc<A>) -> Self {
116        Self {
117            name: name.to_string(),
118            version: version.to_string(),
119            tools: HashMap::new(),
120            scanner,
121            audit,
122            required_trust: TrustLevel::Low,
123        }
124    }
125
126    /// Set the minimum trust level for the entire server.
127    pub fn set_required_trust(&mut self, level: TrustLevel) {
128        self.required_trust = level;
129    }
130
131    /// Register a tool.
132    pub fn register_tool(&mut self, tool: ToolDef) {
133        self.tools.insert(
134            tool.name.clone(),
135            ToolEntry {
136                description: tool.description,
137                schema: tool.parameters_schema,
138                handler: tool.handler,
139                required_trust: None,
140            },
141        );
142    }
143
144    /// Register a tool with a specific trust requirement.
145    pub fn register_tool_with_trust(&mut self, tool: ToolDef, trust: TrustLevel) {
146        self.tools.insert(
147            tool.name.clone(),
148            ToolEntry {
149                description: tool.description,
150                schema: tool.parameters_schema,
151                handler: tool.handler,
152                required_trust: Some(trust),
153            },
154        );
155    }
156
157    /// Handle an incoming JSON-RPC 2.0 request string.
158    ///
159    /// Returns the JSON-RPC response string. All tool arguments and results
160    /// are scanned by the SIGIL `SensitivityScanner`, and every invocation
161    /// is logged via the `AuditLogger`.
162    pub async fn handle_request(
163        &self,
164        request: &str,
165        caller_trust: TrustLevel,
166    ) -> String {
167        let req: JsonRpcRequest = match serde_json::from_str(request) {
168            Ok(r) => r,
169            Err(e) => {
170                return serde_json::to_string(&JsonRpcResponse {
171                    jsonrpc: "2.0".into(),
172                    id: serde_json::Value::Null,
173                    result: None,
174                    error: Some(JsonRpcError {
175                        code: -32700,
176                        message: format!("Parse error: {e}"),
177                        data: None,
178                    }),
179                })
180                .unwrap_or_default();
181            }
182        };
183
184        let id = req.id.clone().unwrap_or(serde_json::Value::Null);
185
186        let response = match req.method.as_str() {
187            "initialize" => self.handle_initialize(&id),
188            "tools/list" => self.handle_tools_list(&id),
189            "tools/call" => self.handle_tools_call(&id, req.params, caller_trust).await,
190            _ => JsonRpcResponse {
191                jsonrpc: "2.0".into(),
192                id,
193                result: None,
194                error: Some(JsonRpcError {
195                    code: -32601,
196                    message: format!("Method not found: {}", req.method),
197                    data: None,
198                }),
199            },
200        };
201
202        serde_json::to_string(&response).unwrap_or_default()
203    }
204
205    fn handle_initialize(&self, id: &serde_json::Value) -> JsonRpcResponse {
206        JsonRpcResponse {
207            jsonrpc: "2.0".into(),
208            id: id.clone(),
209            result: Some(serde_json::json!({
210                "protocolVersion": "2024-11-05",
211                "serverInfo": {
212                    "name": self.name,
213                    "version": self.version,
214                },
215                "capabilities": {
216                    "tools": { "listChanged": false },
217                },
218                "sigil": {
219                    "version": "0.1.0",
220                    "requiredTrust": format!("{:?}", self.required_trust),
221                }
222            })),
223            error: None,
224        }
225    }
226
227    fn handle_tools_list(&self, id: &serde_json::Value) -> JsonRpcResponse {
228        let tools: Vec<serde_json::Value> = self
229            .tools
230            .iter()
231            .map(|(name, entry)| {
232                serde_json::json!({
233                    "name": name,
234                    "description": entry.description,
235                    "inputSchema": entry.schema,
236                })
237            })
238            .collect();
239
240        JsonRpcResponse {
241            jsonrpc: "2.0".into(),
242            id: id.clone(),
243            result: Some(serde_json::json!({ "tools": tools })),
244            error: None,
245        }
246    }
247
248    async fn handle_tools_call(
249        &self,
250        id: &serde_json::Value,
251        params: serde_json::Value,
252        caller_trust: TrustLevel,
253    ) -> JsonRpcResponse {
254        let tool_name = params
255            .get("name")
256            .and_then(|v| v.as_str())
257            .unwrap_or("")
258            .to_string();
259
260        let arguments = params
261            .get("arguments")
262            .cloned()
263            .unwrap_or(serde_json::json!({}));
264
265        // ── Lookup tool ──
266        let entry = match self.tools.get(&tool_name) {
267            Some(e) => e,
268            None => {
269                return JsonRpcResponse {
270                    jsonrpc: "2.0".into(),
271                    id: id.clone(),
272                    result: None,
273                    error: Some(JsonRpcError {
274                        code: -32602,
275                        message: format!("Unknown tool: {tool_name}"),
276                        data: None,
277                    }),
278                };
279            }
280        };
281
282        // ── SIGIL: Trust gate ──
283        let required = entry.required_trust.unwrap_or(self.required_trust);
284        if (caller_trust as u8) < (required as u8) {
285            let _ = self.audit.log(&AuditEvent::new(AuditEventType::PolicyViolation).with_action(
286                format!("Trust gate: {tool_name} requires {required:?}, caller has {caller_trust:?}"),
287                "high".into(),
288                false,
289                false,
290            ));
291            return JsonRpcResponse {
292                jsonrpc: "2.0".into(),
293                id: id.clone(),
294                result: None,
295                error: Some(JsonRpcError {
296                    code: -32001,
297                    message: format!(
298                        "SIGIL trust gate: tool '{tool_name}' requires {required:?} trust"
299                    ),
300                    data: None,
301                }),
302            };
303        }
304
305        // ── SIGIL: Scan input arguments ──
306        let args_str = serde_json::to_string(&arguments).unwrap_or_default();
307        let input_scan = self.scanner.scan(&args_str);
308        if input_scan.is_some() {
309            let _ = self.audit.log(&AuditEvent::new(AuditEventType::SigilInterception).with_action(
310                format!("Input scan: secrets detected in {tool_name} arguments"),
311                "high".into(),
312                true,
313                false,
314            ));
315        }
316
317        // ── Execute tool ──
318        let result = (entry.handler)(arguments).await;
319
320        match result {
321            Ok(output) => {
322                // ── SIGIL: Scan output ──
323                let output_str = serde_json::to_string(&output).unwrap_or_default();
324                let output_scan = self.scanner.scan(&output_str);
325
326                let _ = self.audit.log(&AuditEvent::new(AuditEventType::McpToolGated).with_action(
327                    format!(
328                        "MCP tool {tool_name}: input_secrets={}, output_secrets={}",
329                        input_scan.is_some(),
330                        output_scan.is_some()
331                    ),
332                    "low".into(),
333                    true,
334                    true,
335                ));
336
337                JsonRpcResponse {
338                    jsonrpc: "2.0".into(),
339                    id: id.clone(),
340                    result: Some(serde_json::json!({
341                        "content": [{
342                            "type": "text",
343                            "text": output_str,
344                        }],
345                        "isError": false,
346                        "sigil": {
347                            "inputSecrets": input_scan.is_some(),
348                            "outputSecrets": output_scan.is_some(),
349                        }
350                    })),
351                    error: None,
352                }
353            }
354            Err(e) => JsonRpcResponse {
355                jsonrpc: "2.0".into(),
356                id: id.clone(),
357                result: Some(serde_json::json!({
358                    "content": [{
359                        "type": "text",
360                        "text": format!("Error: {e}"),
361                    }],
362                    "isError": true,
363                })),
364                error: None,
365            },
366        }
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Scanner stub: any text containing `sk-` is reported as a secret.
    struct TestScanner;
    impl SensitivityScanner for TestScanner {
        fn scan(&self, text: &str) -> Option<String> {
            if text.contains("sk-") {
                Some("OpenAI Key".into())
            } else {
                None
            }
        }
    }

    /// Audit stub that only counts how many events were logged.
    struct TestAudit {
        log_count: AtomicU32,
    }
    impl TestAudit {
        fn new() -> Self {
            Self {
                log_count: AtomicU32::new(0),
            }
        }
        fn count(&self) -> u32 {
            self.log_count.load(Ordering::SeqCst)
        }
    }
    impl AuditLogger for TestAudit {
        fn log(&self, _event: &AuditEvent) -> anyhow::Result<()> {
            self.log_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    type TestServer = SigilMcpServer<TestScanner, TestAudit>;

    /// Server with an unrestricted `echo` tool and a High-trust `admin_reset` tool.
    fn make_server() -> TestServer {
        let mut server = SigilMcpServer::new(
            "test-server",
            "0.1.0",
            Arc::new(TestScanner),
            Arc::new(TestAudit::new()),
        );

        server.register_tool(ToolDef {
            name: "echo".into(),
            description: "Echo input back".into(),
            parameters_schema: serde_json::json!({"type": "object"}),
            handler: Box::new(|args| Box::pin(async move { Ok(args) })),
        });

        server.register_tool_with_trust(
            ToolDef {
                name: "admin_reset".into(),
                description: "Dangerous admin operation".into(),
                parameters_schema: serde_json::json!({"type": "object"}),
                handler: Box::new(|_| {
                    Box::pin(async move { Ok(serde_json::json!({"status": "reset"})) })
                }),
            },
            TrustLevel::High,
        );

        server
    }

    /// Send `request` to `server` and parse the response string as JSON.
    async fn roundtrip(server: &TestServer, request: &str, trust: TrustLevel) -> serde_json::Value {
        let raw = server.handle_request(request, trust).await;
        serde_json::from_str(&raw).unwrap()
    }

    #[tokio::test]
    async fn initialize_returns_server_info() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        assert_eq!(parsed["result"]["serverInfo"]["name"], "test-server");
        assert!(parsed["result"]["sigil"].is_object());
    }

    #[tokio::test]
    async fn tools_list_returns_registered_tools() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        let tools = parsed["result"]["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 2);
        let names: Vec<&str> = tools.iter().map(|t| t["name"].as_str().unwrap()).collect();
        assert!(names.contains(&"echo"));
        assert!(names.contains(&"admin_reset"));
    }

    #[tokio::test]
    async fn tools_call_echo_succeeds() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"echo","arguments":{"message":"hello"}}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        let text = parsed["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("hello"));
        assert_eq!(parsed["result"]["isError"], false);
    }

    #[tokio::test]
    async fn tools_call_unknown_tool_returns_error() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"nonexistent","arguments":{}}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        let message = parsed["error"]["message"].as_str().unwrap();
        assert!(message.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn trust_gate_blocks_low_trust_from_high_trust_tool() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"admin_reset","arguments":{}}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        let message = parsed["error"]["message"].as_str().unwrap();
        assert!(message.contains("trust gate"));
    }

    #[tokio::test]
    async fn trust_gate_allows_high_trust_for_high_trust_tool() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":6,"method":"tools/call","params":{"name":"admin_reset","arguments":{}}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::High).await;
        assert!(parsed["error"].is_null());
        let text = parsed["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("reset"));
    }

    #[tokio::test]
    async fn sigil_scan_detects_secrets_in_arguments() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":7,"method":"tools/call","params":{"name":"echo","arguments":{"key":"sk-abc123def456"}}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        // Tool still executes, but SIGIL metadata flags secrets
        assert_eq!(parsed["result"]["sigil"]["inputSecrets"], true);
        // Audit log was called (at least 2: interception + tool gated)
        assert!(server.audit.count() >= 2);
    }

    #[tokio::test]
    async fn sigil_scan_no_secrets_in_clean_input() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":8,"method":"tools/call","params":{"name":"echo","arguments":{"message":"safe text"}}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        assert_eq!(parsed["result"]["sigil"]["inputSecrets"], false);
        assert_eq!(parsed["result"]["sigil"]["outputSecrets"], false);
    }

    #[tokio::test]
    async fn invalid_json_returns_parse_error() {
        let server = make_server();
        let parsed = roundtrip(&server, "not json", TrustLevel::Low).await;
        assert_eq!(parsed["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn unknown_method_returns_method_not_found() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":10,"method":"resources/list","params":{}}"#;
        let parsed = roundtrip(&server, req, TrustLevel::Low).await;
        assert_eq!(parsed["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn audit_logged_for_every_tool_call() {
        let server = make_server();
        let req = r#"{"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"name":"echo","arguments":{"msg":"hi"}}}"#;
        let before = server.audit.count();
        server.handle_request(req, TrustLevel::Low).await;
        let after = server.audit.count();
        assert!(after > before, "Audit log should record tool invocation");
    }
}