Skip to main content

rust_memex/
mcp_protocol.rs

1use crate::{
2    auth::{AuthDenial, AuthManager, Scope},
3    embeddings::EmbeddingClient,
4    query::{QueryRouter, SearchModeRecommendation},
5    rag::{RAGPipeline, SearchOptions, SliceLayer},
6    search::{HybridSearcher, SearchMode},
7};
8use anyhow::{Result, anyhow};
9use serde_json::{Value, json};
10use std::path::Path;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13
14pub const PROTOCOL_VERSION: &str = "2024-11-05";
15pub const SERVER_NAME: &str = "rust-memex";
16
17/// Build a JSON-RPC 2.0 error response.
18/// Per JSON-RPC 2.0 spec, omits `id` field when it's null or absent.
19pub fn jsonrpc_error(id: Option<&Value>, code: i32, message: impl Into<String>) -> Value {
20    let message = message.into();
21
22    match id {
23        Some(id) if !id.is_null() => json!({
24            "jsonrpc": "2.0",
25            "error": {"code": code, "message": message},
26            "id": id
27        }),
28        _ => json!({
29            "jsonrpc": "2.0",
30            "error": {"code": code, "message": message}
31        }),
32    }
33}
34
35/// Build a JSON-RPC 2.0 success response.
36/// Per JSON-RPC 2.0 spec, omits `id` field when it's null.
37pub fn jsonrpc_success(id: &Value, result: Value) -> Value {
38    if id.is_null() {
39        json!({
40            "jsonrpc": "2.0",
41            "result": result
42        })
43    } else {
44        json!({
45            "jsonrpc": "2.0",
46            "id": id,
47            "result": result
48        })
49    }
50}
51
/// Transport over which an MCP request arrived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpTransport {
    /// Standard input/output transport.
    Stdio,
    /// HTTP transport using server-sent events.
    HttpSse,
}

impl McpTransport {
    /// Label reported in the `health` tool's `transport` field.
    ///
    /// Returns `None` for stdio, in which case the field is not added.
    fn health_transport(self) -> Option<&'static str> {
        match self {
            Self::HttpSse => Some("mcp-over-sse"),
            Self::Stdio => None,
        }
    }
}
66
67pub enum McpDispatch {
68    Notification,
69    Response(Value),
70}
71
72impl McpDispatch {
73    pub fn into_option(self) -> Option<Value> {
74        match self {
75            Self::Notification => None,
76            Self::Response(response) => Some(response),
77        }
78    }
79}
80
/// JSON-RPC methods the MCP dispatcher knows how to serve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum McpMethod {
    /// The `initialize` method.
    Initialize,
    /// The `tools/list` method.
    ToolsList,
    /// The `tools/call` method.
    ToolsCall,
}

impl McpMethod {
    /// Map a wire-level method name to its variant.
    ///
    /// Matching is exact and case-sensitive; any other name yields `None`.
    fn from_name(name: &str) -> Option<Self> {
        let method = match name {
            "initialize" => Self::Initialize,
            "tools/list" => Self::ToolsList,
            "tools/call" => Self::ToolsCall,
            _ => return None,
        };
        Some(method)
    }
}
98
/// Tools exposed through `tools/list` and `tools/call`.
///
/// Each variant corresponds to one wire-level tool name (shown in backticks).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum McpTool {
    /// `health` — server health/status.
    Health,
    /// `rag_index` — index a document for RAG.
    RagIndex,
    /// `memory_upsert` — upsert a text chunk into vector memory.
    MemoryUpsert,
    /// `memory_get` — fetch a stored chunk by namespace + id.
    MemoryGet,
    /// `memory_search` — search within a namespace.
    MemorySearch,
    /// `memory_delete` — delete a chunk by namespace + id.
    MemoryDelete,
    /// `memory_purge_namespace` — delete every chunk in a namespace.
    MemoryPurgeNamespace,
    /// `namespace_create_token` — create an access token for a namespace.
    NamespaceCreateToken,
    /// `namespace_revoke_token` — revoke a namespace's access token.
    NamespaceRevokeToken,
    /// `namespace_list_protected` — list token-protected namespaces.
    NamespaceListProtected,
    /// `namespace_security_status` — report whether token security is enabled.
    NamespaceSecurityStatus,
    /// `dive` — per-layer deep exploration of a namespace.
    Dive,
}
114
115impl McpTool {
116    const ALL: [Self; 12] = [
117        Self::Health,
118        Self::RagIndex,
119        Self::MemoryUpsert,
120        Self::MemoryGet,
121        Self::MemorySearch,
122        Self::MemoryDelete,
123        Self::MemoryPurgeNamespace,
124        Self::NamespaceCreateToken,
125        Self::NamespaceRevokeToken,
126        Self::NamespaceListProtected,
127        Self::NamespaceSecurityStatus,
128        Self::Dive,
129    ];
130
131    fn from_name(name: &str) -> Option<Self> {
132        match name {
133            "health" => Some(Self::Health),
134            "rag_index" => Some(Self::RagIndex),
135            "memory_upsert" => Some(Self::MemoryUpsert),
136            "memory_get" => Some(Self::MemoryGet),
137            "memory_search" => Some(Self::MemorySearch),
138            "memory_delete" => Some(Self::MemoryDelete),
139            "memory_purge_namespace" => Some(Self::MemoryPurgeNamespace),
140            "namespace_create_token" => Some(Self::NamespaceCreateToken),
141            "namespace_revoke_token" => Some(Self::NamespaceRevokeToken),
142            "namespace_list_protected" => Some(Self::NamespaceListProtected),
143            "namespace_security_status" => Some(Self::NamespaceSecurityStatus),
144            "dive" => Some(Self::Dive),
145            _ => None,
146        }
147    }
148
149    fn name(self) -> &'static str {
150        match self {
151            Self::Health => "health",
152            Self::RagIndex => "rag_index",
153            Self::MemoryUpsert => "memory_upsert",
154            Self::MemoryGet => "memory_get",
155            Self::MemorySearch => "memory_search",
156            Self::MemoryDelete => "memory_delete",
157            Self::MemoryPurgeNamespace => "memory_purge_namespace",
158            Self::NamespaceCreateToken => "namespace_create_token",
159            Self::NamespaceRevokeToken => "namespace_revoke_token",
160            Self::NamespaceListProtected => "namespace_list_protected",
161            Self::NamespaceSecurityStatus => "namespace_security_status",
162            Self::Dive => "dive",
163        }
164    }
165
166    fn definition(self) -> Value {
167        match self {
168            Self::Health => json!({
169                "name": self.name(),
170                "description": "Health/status of rust-memex server",
171                "inputSchema": {
172                    "type": "object",
173                    "properties": {},
174                    "required": []
175                }
176            }),
177            Self::RagIndex => json!({
178                "name": self.name(),
179                "description": "Index a document for RAG",
180                "inputSchema": {
181                    "type": "object",
182                    "properties": {
183                        "path": {"type": "string"},
184                        "namespace": {"type": "string"}
185                    },
186                    "required": ["path"]
187                }
188            }),
189            Self::MemoryUpsert => json!({
190                "name": self.name(),
191                "description": "Upsert a text chunk into vector memory. If the namespace is protected, provide the access token.",
192                "inputSchema": {
193                    "type": "object",
194                    "properties": {
195                        "namespace": {"type": "string"},
196                        "id": {"type": "string"},
197                        "text": {"type": "string"},
198                        "metadata": {"type": "object"},
199                        "token": {"type": "string", "description": "Access token for protected namespaces"}
200                    },
201                    "required": ["namespace", "id", "text"]
202                }
203            }),
204            Self::MemoryGet => json!({
205                "name": self.name(),
206                "description": "Get a stored chunk by namespace + id. If the namespace is protected, provide the access token.",
207                "inputSchema": {
208                    "type": "object",
209                    "properties": {
210                        "namespace": {"type": "string"},
211                        "id": {"type": "string"},
212                        "token": {"type": "string", "description": "Access token for protected namespaces"}
213                    },
214                    "required": ["namespace", "id"]
215                }
216            }),
217            Self::MemorySearch => json!({
218                "name": self.name(),
219                "description": "Semantic search within a namespace. If the namespace is protected, provide the access token.",
220                "inputSchema": {
221                    "type": "object",
222                    "properties": {
223                        "namespace": {"type": "string"},
224                        "query": {"type": "string"},
225                        "k": {"type": "integer", "default": 5},
226                        "project": {"type": "string", "description": "Filter to documents whose metadata project/project_id matches this value"},
227                        "deep": {"type": "boolean", "default": false, "description": "Include all onion layers instead of only outer summaries"},
228                        "mode": {"type": "string", "enum": ["vector", "bm25", "hybrid"], "default": "hybrid", "description": "Search mode: vector (semantic), bm25 (keyword), hybrid (both)"},
229                        "auto_route": {"type": "boolean", "default": false, "description": "Auto-detect query intent and select optimal search mode. Overrides mode when true."},
230                        "token": {"type": "string", "description": "Access token for protected namespaces"}
231                    },
232                    "required": ["namespace", "query"]
233                }
234            }),
235            Self::MemoryDelete => json!({
236                "name": self.name(),
237                "description": "Delete a chunk by namespace + id. If the namespace is protected, provide the access token.",
238                "inputSchema": {
239                    "type": "object",
240                    "properties": {
241                        "namespace": {"type": "string"},
242                        "id": {"type": "string"},
243                        "token": {"type": "string", "description": "Access token for protected namespaces"}
244                    },
245                    "required": ["namespace", "id"]
246                }
247            }),
248            Self::MemoryPurgeNamespace => json!({
249                "name": self.name(),
250                "description": "Delete all chunks in a namespace. If the namespace is protected, provide the access token.",
251                "inputSchema": {
252                    "type": "object",
253                    "properties": {
254                        "namespace": {"type": "string"},
255                        "token": {"type": "string", "description": "Access token for protected namespaces"}
256                    },
257                    "required": ["namespace"]
258                }
259            }),
260            Self::NamespaceCreateToken => json!({
261                "name": self.name(),
262                "description": "Create an access token for a namespace. Once created, the namespace will require this token for access.",
263                "inputSchema": {
264                    "type": "object",
265                    "properties": {
266                        "namespace": {"type": "string", "description": "The namespace to protect with a token"},
267                        "description": {"type": "string", "description": "Optional description for the token"}
268                    },
269                    "required": ["namespace"]
270                }
271            }),
272            Self::NamespaceRevokeToken => json!({
273                "name": self.name(),
274                "description": "Revoke the access token for a namespace, making it publicly accessible again.",
275                "inputSchema": {
276                    "type": "object",
277                    "properties": {
278                        "namespace": {"type": "string", "description": "The namespace to remove token protection from"}
279                    },
280                    "required": ["namespace"]
281                }
282            }),
283            Self::NamespaceListProtected => json!({
284                "name": self.name(),
285                "description": "List all namespaces that have token protection enabled.",
286                "inputSchema": {
287                    "type": "object",
288                    "properties": {},
289                    "required": []
290                }
291            }),
292            Self::NamespaceSecurityStatus => json!({
293                "name": self.name(),
294                "description": "Check if namespace security (token-based access control) is enabled.",
295                "inputSchema": {
296                    "type": "object",
297                    "properties": {},
298                    "required": []
299                }
300            }),
301            Self::Dive => json!({
302                "name": self.name(),
303                "description": "Deep exploration with all onion layers. Shows ALL layers (outer/middle/inner/core), both BM25 and vector scores, full metadata, and related chunks.",
304                "inputSchema": {
305                    "type": "object",
306                    "properties": {
307                        "namespace": {"type": "string", "description": "Namespace to search in"},
308                        "query": {"type": "string", "description": "Search query text"},
309                        "limit": {"type": "integer", "default": 5, "description": "Maximum results per layer"},
310                        "verbose": {"type": "boolean", "default": false, "description": "Show full text and metadata"}
311                    },
312                    "required": ["namespace", "query"]
313                }
314            }),
315        }
316    }
317}
318
319/// Shared `initialize` result used by every MCP transport.
320///
321/// rust-memex currently exposes a tools-only MCP surface. Do not advertise
322/// `resources` here until `resources/list` and related methods are implemented.
323pub fn shared_initialize_result() -> Value {
324    json!({
325        "protocolVersion": PROTOCOL_VERSION,
326        "serverInfo": {
327            "name": SERVER_NAME,
328            "version": env!("CARGO_PKG_VERSION")
329        },
330        "capabilities": {
331            "tools": {}
332        }
333    })
334}
335
336/// Shared `tools/list` result used by every MCP transport.
337pub fn shared_tools_list_result() -> Value {
338    let tools: Vec<Value> = McpTool::ALL.into_iter().map(McpTool::definition).collect();
339    json!({ "tools": tools })
340}
341
342#[derive(Clone)]
343pub struct McpCore {
344    rag: Arc<RAGPipeline>,
345    hybrid_searcher: Option<Arc<HybridSearcher>>,
346    embedding_client: Arc<Mutex<EmbeddingClient>>,
347    max_request_bytes: usize,
348    allowed_paths: Vec<String>,
349    auth_manager: Arc<AuthManager>,
350}
351
352impl McpCore {
353    pub fn new(
354        rag: Arc<RAGPipeline>,
355        hybrid_searcher: Option<Arc<HybridSearcher>>,
356        embedding_client: Arc<Mutex<EmbeddingClient>>,
357        max_request_bytes: usize,
358        allowed_paths: Vec<String>,
359        auth_manager: Arc<AuthManager>,
360    ) -> Self {
361        Self {
362            rag,
363            hybrid_searcher,
364            embedding_client,
365            max_request_bytes,
366            allowed_paths,
367            auth_manager,
368        }
369    }
370
371    pub fn rag(&self) -> Arc<RAGPipeline> {
372        self.rag.clone()
373    }
374
375    /// Access the unified auth manager (Track C replacement for the legacy
376    /// `NamespaceAccessManager`). Always available — if no tokens are
377    /// configured, every authorize() call for that namespace is permitted.
378    pub fn auth_manager(&self) -> &AuthManager {
379        &self.auth_manager
380    }
381
382    /// MCP-tool per-request namespace access check.
383    ///
384    /// Preserves the legacy semantic of `NamespaceAccessManager::verify_access`:
385    ///   * If no tokens are registered that cover `namespace`, access is allowed
386    ///     (namespace is "open").
387    ///   * If any token covers `namespace`, a matching plaintext token must be
388    ///     supplied via the MCP tool-call `token` argument and it must grant
389    ///     write scope for the namespace.
390    ///
391    /// Returns `Ok(())` on success, `Err(message)` on denial.
392    async fn verify_tool_access(&self, namespace: &str, token: Option<&str>) -> Result<()> {
393        let tokens = self.auth_manager.list_tokens().await;
394        let namespace_has_token = tokens
395            .iter()
396            .any(|entry| entry.has_namespace_access(namespace));
397
398        if !namespace_has_token {
399            // No tokens protect this namespace — legacy "open" behavior.
400            return Ok(());
401        }
402
403        match token {
404            Some(plaintext) => match self
405                .auth_manager
406                .authorize(plaintext, &Scope::Write, Some(namespace))
407                .await
408            {
409                Ok(_) => Ok(()),
410                Err(AuthDenial::InvalidToken) | Err(AuthDenial::MissingToken) => Err(anyhow!(
411                    "Access denied: invalid token for namespace '{}'",
412                    namespace
413                )),
414                Err(denial) => Err(anyhow!("{}", denial)),
415            },
416            None => Err(anyhow!(
417                "Access denied: namespace '{}' requires a token. Use namespace_create_token to generate one.",
418                namespace
419            )),
420        }
421    }
422
423    pub fn hybrid_searcher(&self) -> Option<Arc<HybridSearcher>> {
424        self.hybrid_searcher.clone()
425    }
426
    /// Embed `query` with the shared embedding client and return its vector.
    ///
    /// The client sits behind an async mutex. The lock guard is a temporary
    /// of the returned expression, so the lock is held while `embed` runs and
    /// concurrent callers wait for it. Errors from `embed` are returned as-is.
    pub async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        self.embedding_client.lock().await.embed(query).await
    }
430
431    pub async fn handle_request(&self, request: Value, transport: McpTransport) -> Option<Value> {
432        self.handle_jsonrpc_request(request, transport)
433            .await
434            .into_option()
435    }
436
437    pub async fn handle_payload(&self, payload: &str, transport: McpTransport) -> Option<Value> {
438        let request = match parse_jsonrpc_payload(payload, self.max_request_bytes) {
439            Ok(request) => request,
440            Err(response) => return Some(response),
441        };
442
443        self.handle_request(request, transport).await
444    }
445
446    pub async fn handle_jsonrpc_request(
447        &self,
448        request: Value,
449        transport: McpTransport,
450    ) -> McpDispatch {
451        let method_name = request["method"].as_str().unwrap_or("");
452
453        if method_name.starts_with("notifications/") {
454            return McpDispatch::Notification;
455        }
456
457        let id = match request.get("id") {
458            Some(value) if value.is_string() || value.is_number() => value.clone(),
459            _ => {
460                return McpDispatch::Response(json!({
461                    "jsonrpc": "2.0",
462                    "id": Value::Null,
463                    "error": {
464                        "code": -32600,
465                        "message": "Invalid Request: missing or invalid 'id' field"
466                    }
467                }));
468            }
469        };
470
471        let method = match McpMethod::from_name(method_name) {
472            Some(method) => method,
473            None => {
474                return McpDispatch::Response(jsonrpc_error(
475                    Some(&id),
476                    -32601,
477                    format!("Unknown method: {}", method_name),
478                ));
479            }
480        };
481
482        let result = match method {
483            McpMethod::Initialize => shared_initialize_result(),
484            McpMethod::ToolsList => shared_tools_list_result(),
485            McpMethod::ToolsCall => match self.handle_tool_call(&request, &id, transport).await {
486                Ok(result) => result,
487                Err(response) => return McpDispatch::Response(response),
488            },
489        };
490
491        McpDispatch::Response(jsonrpc_success(&id, result))
492    }
493
494    async fn handle_tool_call(
495        &self,
496        request: &Value,
497        id: &Value,
498        transport: McpTransport,
499    ) -> std::result::Result<Value, Value> {
500        let tool_name = request["params"]["name"].as_str().unwrap_or("");
501        let tool = McpTool::from_name(tool_name).ok_or_else(|| {
502            jsonrpc_error(Some(id), -32601, format!("Unknown tool: {}", tool_name))
503        })?;
504        let args = &request["params"]["arguments"];
505
506        match tool {
507            McpTool::Health => {
508                let mut status = json!({
509                    "version": env!("CARGO_PKG_VERSION"),
510                    "db_path": self.rag.storage_manager().lance_path(),
511                    "backend": "mlx",
512                    "mlx_server": self.rag.mlx_connected_to(),
513                });
514
515                if let Some(transport_name) = transport.health_transport() {
516                    status["transport"] = json!(transport_name);
517                }
518
519                Ok(text_result_from_json(&status))
520            }
521            McpTool::RagIndex => {
522                let path_str = args["path"].as_str().unwrap_or("");
523                let namespace = args["namespace"].as_str();
524
525                let validated_path = validate_path(path_str, &self.allowed_paths)
526                    .map_err(|e| jsonrpc_error(Some(id), -32602, e.to_string()))?;
527
528                match self.rag.index_document(&validated_path, namespace).await {
529                    Ok(_) => Ok(text_result(format!("Indexed: {}", path_str))),
530                    Err(e) => Ok(tool_error(e)),
531                }
532            }
533            McpTool::MemoryUpsert => {
534                let namespace = args["namespace"].as_str().unwrap_or("default");
535                let token = args["token"].as_str();
536
537                self.verify_tool_access(namespace, token)
538                    .await
539                    .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
540
541                let item_id = args["id"].as_str().unwrap_or("").to_string();
542                let text = args["text"].as_str().unwrap_or("").to_string();
543                let metadata = args.get("metadata").cloned().unwrap_or_else(|| json!({}));
544
545                match self
546                    .rag
547                    .memory_upsert(namespace, item_id.clone(), text, metadata)
548                    .await
549                {
550                    Ok(_) => Ok(text_result(format!("Upserted {}", item_id))),
551                    Err(e) => Ok(tool_error(e)),
552                }
553            }
554            McpTool::MemoryGet => {
555                let namespace = args["namespace"].as_str().unwrap_or("default");
556                let token = args["token"].as_str();
557
558                self.verify_tool_access(namespace, token)
559                    .await
560                    .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
561
562                let item_id = args["id"].as_str().unwrap_or("");
563                match self.rag.lookup_memory(namespace, item_id).await {
564                    Ok(Some(doc)) => Ok(text_result_from_json(&doc)),
565                    Ok(None) => Ok(text_result("Not found")),
566                    Err(e) => Ok(tool_error(e)),
567                }
568            }
569            McpTool::MemorySearch => {
570                let namespace = args["namespace"].as_str().unwrap_or("default");
571                let token = args["token"].as_str();
572
573                self.verify_tool_access(namespace, token)
574                    .await
575                    .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
576
577                let query = args["query"].as_str().unwrap_or("");
578                let limit = requested_limit(args, 5);
579                let mode = requested_search_mode(query, args);
580                let options = requested_search_options(args);
581
582                if let Some(hybrid_result) = self
583                    .try_hybrid_search(query, Some(namespace), limit, (mode, options.clone()), id)
584                    .await?
585                {
586                    return Ok(hybrid_result);
587                }
588
589                match self
590                    .rag
591                    .search_with_options(Some(namespace), query, limit, options)
592                    .await
593                {
594                    Ok(results) => Ok(text_result_from_json(&results)),
595                    Err(e) => Ok(tool_error(e)),
596                }
597            }
598            McpTool::MemoryDelete => {
599                let namespace = args["namespace"].as_str().unwrap_or("default");
600                let token = args["token"].as_str();
601
602                self.verify_tool_access(namespace, token)
603                    .await
604                    .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
605
606                let item_id = args["id"].as_str().unwrap_or("");
607                match self.rag.remove_memory(namespace, item_id).await {
608                    Ok(deleted) => Ok(text_result(format!("Deleted {} rows", deleted))),
609                    Err(e) => Ok(tool_error(e)),
610                }
611            }
612            McpTool::MemoryPurgeNamespace => {
613                let namespace = args["namespace"].as_str().unwrap_or("default");
614                let token = args["token"].as_str();
615
616                self.verify_tool_access(namespace, token)
617                    .await
618                    .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
619
620                match self.rag.clear_namespace(namespace).await {
621                    Ok(deleted) => Ok(text_result(format!(
622                        "Purged namespace '{}', removed {} rows",
623                        namespace, deleted
624                    ))),
625                    Err(e) => Ok(tool_error(e)),
626                }
627            }
628            McpTool::NamespaceCreateToken => {
629                let namespace = args["namespace"].as_str().unwrap_or("");
630                let description = args["description"].as_str().map(ToOwned::to_owned);
631
632                if namespace.is_empty() {
633                    return Ok(tool_error_message("Namespace is required"));
634                }
635
636                // Track C: map legacy per-namespace create to AuthManager.
637                // id == namespace preserves revoke-by-namespace and
638                // list-protected semantics without a separate mapping table.
639                // Scopes grant full access to the namespace. Wildcard
640                // namespaces are not issued via this MCP path.
641                //
642                // Legacy `TokenStore::create_token` overwrote on collision
643                // (HashMap::insert). Preserve that idempotence by revoking
644                // an existing entry before creating — effectively a rotate.
645                let description = description
646                    .unwrap_or_else(|| format!("Auto-created for namespace '{}'", namespace));
647                let _ = self.auth_manager.revoke_token(namespace).await;
648                match self
649                    .auth_manager
650                    .create_token(
651                        namespace.to_string(),
652                        vec![Scope::Read, Scope::Write, Scope::Admin],
653                        vec![namespace.to_string()],
654                        None,
655                        description,
656                    )
657                    .await
658                {
659                    Ok(token) => Ok(text_result(format!(
660                        "Token created for namespace '{}'. Store this token securely - it won't be shown again!\n\nToken: {}",
661                        namespace, token
662                    ))),
663                    Err(e) => Ok(tool_error(e)),
664                }
665            }
666            McpTool::NamespaceRevokeToken => {
667                let namespace = args["namespace"].as_str().unwrap_or("");
668
669                if namespace.is_empty() {
670                    return Ok(tool_error_message("Namespace is required"));
671                }
672
673                match self.auth_manager.revoke_token(namespace).await {
674                    Ok(true) => Ok(text_result(format!(
675                        "Token revoked for namespace '{}'. The namespace is now publicly accessible.",
676                        namespace
677                    ))),
678                    Ok(false) => Ok(text_result(format!(
679                        "No token found for namespace '{}'.",
680                        namespace
681                    ))),
682                    Err(e) => Ok(tool_error(e)),
683                }
684            }
685            McpTool::NamespaceListProtected => {
686                let tokens = self.auth_manager.list_tokens().await;
687                // Legacy shape: one row per protected namespace. Under the v2
688                // auth model a single token can cover many namespaces, so we
689                // flatten: every non-wildcard namespace listed by any token
690                // becomes an entry. Dedup by namespace, keeping the most
691                // recently created description.
692                let mut protected: std::collections::BTreeMap<String, (i64, Option<String>)> =
693                    std::collections::BTreeMap::new();
694                for entry in &tokens {
695                    let created_at = entry.created_at.timestamp();
696                    let desc = Some(entry.description.clone());
697                    for ns in &entry.namespaces {
698                        if ns == "*" {
699                            continue;
700                        }
701                        protected
702                            .entry(ns.clone())
703                            .and_modify(|existing| {
704                                if created_at > existing.0 {
705                                    *existing = (created_at, desc.clone());
706                                }
707                            })
708                            .or_insert_with(|| (created_at, desc.clone()));
709                    }
710                }
711
712                if protected.is_empty() {
713                    Ok(text_result(
714                        "No namespaces are currently protected with tokens.",
715                    ))
716                } else {
717                    let list: Vec<Value> = protected
718                        .into_iter()
719                        .map(|(namespace, (created_at, description))| {
720                            json!({
721                                "namespace": namespace,
722                                "created_at": created_at,
723                                "description": description
724                            })
725                        })
726                        .collect();
727                    Ok(pretty_text_result_from_json(&list))
728                }
729            }
730            McpTool::NamespaceSecurityStatus => {
731                // Track C: "enabled" now means "at least one token is
732                // configured". An empty store is equivalent to the old
733                // `enabled=false` state: every namespace is accessible.
734                let has_any = self.auth_manager.has_any_tokens().await;
735                let tokens = self.auth_manager.list_tokens().await;
736                let protected_namespaces: std::collections::BTreeSet<String> = tokens
737                    .iter()
738                    .flat_map(|entry| entry.namespaces.iter().cloned())
739                    .filter(|ns| ns != "*")
740                    .collect();
741
742                Ok(text_result(format!(
743                    "Namespace security: {}\nProtected namespaces: {}\n\nNote: When security is disabled, all namespaces are accessible without tokens.",
744                    if has_any { "ENABLED" } else { "DISABLED" },
745                    protected_namespaces.len()
746                )))
747            }
748            McpTool::Dive => {
749                let namespace = args["namespace"].as_str().unwrap_or("");
750                let query = args["query"].as_str().unwrap_or("");
751                let limit = args["limit"].as_u64().unwrap_or(5) as usize;
752                let verbose = args["verbose"].as_bool().unwrap_or(false);
753
754                if namespace.is_empty() || query.is_empty() {
755                    return Err(jsonrpc_error(
756                        Some(id),
757                        -32602,
758                        "namespace and query are required",
759                    ));
760                }
761
762                let layers = [
763                    (Some(SliceLayer::Outer), "outer"),
764                    (Some(SliceLayer::Middle), "middle"),
765                    (Some(SliceLayer::Inner), "inner"),
766                    (Some(SliceLayer::Core), "core"),
767                ];
768
769                let mut all_results: Vec<Value> = Vec::new();
770
771                for (layer_filter, layer_name) in &layers {
772                    match self
773                        .rag
774                        .memory_search_with_layer(namespace, query, limit, *layer_filter)
775                        .await
776                    {
777                        Ok(results) => {
778                            let layer_results: Vec<Value> = results
779                                .iter()
780                                .map(|result| {
781                                    let mut object = json!({
782                                        "id": result.id,
783                                        "score": result.score,
784                                        "keywords": result.keywords,
785                                        "layer": result.layer.map(|layer| layer.name()),
786                                        "can_expand": result.can_expand(),
787                                        "parent_id": result.parent_id,
788                                    });
789
790                                    if verbose {
791                                        object["text"] = json!(result.text);
792                                        object["metadata"] = result.metadata.clone();
793                                        object["children_ids"] = json!(result.children_ids);
794                                    } else {
795                                        let preview: String =
796                                            result.text.chars().take(200).collect();
797                                        object["preview"] = json!(preview);
798                                    }
799
800                                    object
801                                })
802                                .collect();
803
804                            all_results.push(json!({
805                                "layer": layer_name,
806                                "count": results.len(),
807                                "results": layer_results
808                            }));
809                        }
810                        Err(e) => {
811                            all_results.push(json!({
812                                "layer": layer_name,
813                                "error": e.to_string()
814                            }));
815                        }
816                    }
817                }
818
819                Ok(pretty_text_result_from_json(&json!({
820                    "query": query,
821                    "namespace": namespace,
822                    "limit_per_layer": limit,
823                    "verbose": verbose,
824                    "layers": all_results
825                })))
826            }
827        }
828    }
829
830    async fn try_hybrid_search(
831        &self,
832        query: &str,
833        namespace: Option<&str>,
834        limit: usize,
835        search: (SearchMode, SearchOptions),
836        id: &Value,
837    ) -> std::result::Result<Option<Value>, Value> {
838        let (mode, options) = search;
839        if mode == SearchMode::Vector {
840            return Ok(None);
841        }
842
843        let Some(hybrid_searcher) = &self.hybrid_searcher else {
844            return Ok(None);
845        };
846
847        let query_embedding = self
848            .embedding_client
849            .lock()
850            .await
851            .embed(query)
852            .await
853            .map_err(|e| jsonrpc_error(Some(id), -32603, format!("Embedding failed: {}", e)))?;
854
855        let results = hybrid_searcher
856            .search(query, query_embedding, namespace, limit, options)
857            .await
858            .map_err(|e| jsonrpc_error(Some(id), -32603, format!("Hybrid search failed: {}", e)))?;
859
860        let payload: Vec<Value> = results
861            .iter()
862            .map(|result| {
863                json!({
864                    "id": result.id,
865                    "namespace": result.namespace,
866                    "text": result.document,
867                    "score": result.combined_score,
868                    "vector_score": result.vector_score,
869                    "bm25_score": result.bm25_score,
870                    "metadata": result.metadata
871                })
872            })
873            .collect();
874
875        Ok(Some(text_result_from_json(&payload)))
876    }
877}
878
879fn requested_search_mode(query: &str, args: &Value) -> SearchMode {
880    if args["auto_route"].as_bool().unwrap_or(false) {
881        let router = QueryRouter::new();
882        let decision = router.route(query);
883        match decision.recommended_mode.mode {
884            SearchModeRecommendation::Vector => SearchMode::Vector,
885            SearchModeRecommendation::Bm25 => SearchMode::Keyword,
886            SearchModeRecommendation::Hybrid => SearchMode::Hybrid,
887        }
888    } else {
889        match args["mode"].as_str() {
890            Some("vector") => SearchMode::Vector,
891            Some("bm25") | Some("keyword") => SearchMode::Keyword,
892            _ => SearchMode::Hybrid,
893        }
894    }
895}
896
897fn requested_layer_filter(args: &Value) -> Option<SliceLayer> {
898    if args["deep"].as_bool().unwrap_or(false) {
899        None
900    } else {
901        Some(SliceLayer::Outer)
902    }
903}
904
905fn requested_search_options(args: &Value) -> SearchOptions {
906    SearchOptions {
907        layer_filter: requested_layer_filter(args),
908        project_filter: args["project"]
909            .as_str()
910            .map(|value| value.trim().to_string())
911            .filter(|value| !value.is_empty()),
912    }
913}
914
915fn requested_limit(args: &Value, default: usize) -> usize {
916    args["k"]
917        .as_u64()
918        .or_else(|| args["limit"].as_u64())
919        .map(|value| value as usize)
920        .unwrap_or(default)
921}
922
923fn parse_jsonrpc_payload(
924    payload: &str,
925    max_request_bytes: usize,
926) -> std::result::Result<Value, Value> {
927    let trimmed = payload.trim();
928
929    if trimmed.len() > max_request_bytes {
930        return Err(jsonrpc_error(
931            None,
932            -32600,
933            format!(
934                "Request too large: {} bytes (max {})",
935                trimmed.len(),
936                max_request_bytes
937            ),
938        ));
939    }
940
941    serde_json::from_str(trimmed)
942        .map_err(|error| jsonrpc_error(None, -32700, format!("Parse error: {}", error)))
943}
944
945fn tool_error(error: impl ToString) -> Value {
946    tool_error_message(error.to_string())
947}
948
949fn tool_error_message(message: impl Into<String>) -> Value {
950    json!({
951        "error": {"message": message.into()}
952    })
953}
954
955fn text_result(text: impl Into<String>) -> Value {
956    json!({
957        "content": [{"type": "text", "text": text.into()}]
958    })
959}
960
961fn text_result_from_json<T: serde::Serialize>(value: &T) -> Value {
962    text_result(serde_json::to_string(value).unwrap_or_default())
963}
964
965fn pretty_text_result_from_json<T: serde::Serialize>(value: &T) -> Value {
966    text_result(serde_json::to_string_pretty(value).unwrap_or_default())
967}
968
969/// Validates a file path to prevent path traversal attacks.
970/// Returns the canonicalized path if valid, or an error if the path is unsafe.
971fn validate_path(path_str: &str, allowed_paths: &[String]) -> Result<std::path::PathBuf> {
972    if path_str.is_empty() {
973        return Err(anyhow!("Path cannot be empty"));
974    }
975
976    if path_str.contains("..") || path_str.contains('\0') || path_str.contains('\n') {
977        return Err(anyhow!(
978            "Path traversal detected: invalid sequences in '{}'",
979            path_str
980        ));
981    }
982
983    let canonical = crate::path_utils::sanitize_existing_path(path_str)?;
984
985    let is_safe = if allowed_paths.is_empty() {
986        let home = std::env::var("HOME")
987            .or_else(|_| std::env::var("USERPROFILE"))
988            .map(std::path::PathBuf::from)
989            .ok();
990        let cwd = std::env::current_dir().ok();
991
992        home.as_ref()
993            .map(|path| canonical.starts_with(path))
994            .unwrap_or(false)
995            || cwd
996                .as_ref()
997                .map(|path| canonical.starts_with(path))
998                .unwrap_or(false)
999    } else {
1000        allowed_paths.iter().any(|allowed| {
1001            let expanded_allowed = shellexpand::tilde(allowed).to_string();
1002            let allowed_path = Path::new(&expanded_allowed);
1003            let canonical_allowed = allowed_path
1004                .canonicalize()
1005                .unwrap_or_else(|_| std::path::PathBuf::from(&expanded_allowed));
1006
1007            canonical.starts_with(&canonical_allowed)
1008        })
1009    };
1010
1011    if !is_safe {
1012        let allowed_info = if allowed_paths.is_empty() {
1013            "$HOME and current working directory".to_string()
1014        } else {
1015            format!("configured paths: {:?}", allowed_paths)
1016        };
1017
1018        return Err(anyhow!(
1019            "Access denied: path '{}' is outside allowed directories ({})",
1020            path_str,
1021            allowed_info
1022        ));
1023    }
1024
1025    Ok(canonical)
1026}
1027
/// Unit tests for the JSON-RPC helpers, the shared MCP results, and the
/// request-argument parsers defined in this module.
#[cfg(test)]
mod tests {
    use super::{
        jsonrpc_error, jsonrpc_success, parse_jsonrpc_payload, requested_layer_filter,
        requested_limit, requested_search_options, shared_initialize_result,
        shared_tools_list_result,
    };
    use crate::rag::{SearchOptions, SliceLayer};
    use serde_json::{Value, json};

    /// Returns the `error.message` string of a response, or `""` if absent.
    fn error_message(response: &Value) -> &str {
        response["error"]["message"].as_str().unwrap_or("")
    }

    #[test]
    fn jsonrpc_error_omits_missing_id() {
        let reply = jsonrpc_error(None, -32600, "boom");

        assert_eq!(reply["jsonrpc"], "2.0");
        assert_eq!(reply["error"]["code"], -32600);
        assert!(reply.get("id").is_none());
    }

    #[test]
    fn jsonrpc_success_omits_null_id() {
        let reply = jsonrpc_success(&Value::Null, json!({"ok": true}));

        assert_eq!(reply["jsonrpc"], "2.0");
        assert_eq!(reply["result"]["ok"].as_bool(), Some(true));
        assert!(reply.get("id").is_none());
    }

    #[test]
    fn initialize_advertises_only_tools_capability() {
        let init = shared_initialize_result();

        assert_eq!(init["protocolVersion"], "2024-11-05");
        assert_eq!(init["capabilities"], json!({ "tools": {} }));
    }

    #[test]
    fn tool_list_contains_extended_stdio_and_http_surface() {
        let listing = shared_tools_list_result();
        let names: Vec<&str> = listing["tools"]
            .as_array()
            .expect("tools list should be an array")
            .iter()
            .filter_map(|tool| tool["name"].as_str())
            .collect();

        for expected in [
            "rag_index",
            "memory_purge_namespace",
            "namespace_create_token",
            "dive",
        ] {
            assert!(names.contains(&expected));
        }
    }

    #[test]
    fn parse_jsonrpc_payload_rejects_oversized_requests() {
        let rejection =
            parse_jsonrpc_payload("123456", 5).expect_err("payload should be rejected");

        assert_eq!(rejection["error"]["code"], -32600);
        assert!(error_message(&rejection).contains("Request too large"));
    }

    #[test]
    fn parse_jsonrpc_payload_returns_jsonrpc_parse_error() {
        let rejection = parse_jsonrpc_payload("{", 1024).expect_err("payload should not parse");

        assert_eq!(rejection["error"]["code"], -32700);
        assert!(error_message(&rejection).contains("Parse error"));
    }

    #[test]
    fn parse_jsonrpc_payload_accepts_valid_json_with_whitespace() {
        let raw = "  {\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{}}  ";
        let parsed = parse_jsonrpc_payload(raw, 1024).expect("payload should parse");

        assert_eq!(parsed["method"], "initialize");
        assert_eq!(parsed["id"], 1);
    }

    #[test]
    fn requested_limit_prefers_request_k_over_default() {
        let with_k = json!({"k": 17});
        let without_k = json!({});

        assert_eq!(requested_limit(&with_k, 5), 17);
        assert_eq!(requested_limit(&without_k, 5), 5);
    }

    #[test]
    fn requested_limit_accepts_limit_alias() {
        let args = json!({"limit": 11});

        assert_eq!(requested_limit(&args, 5), 11);
    }

    #[test]
    fn requested_layer_filter_defaults_to_outer_only() {
        let args = json!({});

        assert_eq!(requested_layer_filter(&args), Some(SliceLayer::Outer));
    }

    #[test]
    fn requested_layer_filter_allows_deep_search() {
        let args = json!({"deep": true});

        assert_eq!(requested_layer_filter(&args), None);
    }

    #[test]
    fn requested_search_options_captures_project_filter() {
        let expected = SearchOptions {
            layer_filter: Some(SliceLayer::Outer),
            project_filter: Some("Vista".to_string()),
        };
        let actual = requested_search_options(&json!({"project": "Vista"}));

        assert_eq!(actual, expected);
    }
}