1use crate::{
2 auth::{AuthDenial, AuthManager, Scope},
3 embeddings::EmbeddingClient,
4 query::{QueryRouter, SearchModeRecommendation},
5 rag::{RAGPipeline, SearchOptions, SliceLayer},
6 search::{HybridSearcher, SearchMode},
7};
8use anyhow::{Result, anyhow};
9use serde_json::{Value, json};
10use std::path::Path;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13
/// MCP protocol revision reported as `protocolVersion` in the `initialize` result.
pub const PROTOCOL_VERSION: &str = "2024-11-05";
/// Server name reported in `serverInfo.name` of the `initialize` result.
pub const SERVER_NAME: &str = "rust-memex";
16
17pub fn jsonrpc_error(id: Option<&Value>, code: i32, message: impl Into<String>) -> Value {
20 let message = message.into();
21
22 match id {
23 Some(id) if !id.is_null() => json!({
24 "jsonrpc": "2.0",
25 "error": {"code": code, "message": message},
26 "id": id
27 }),
28 _ => json!({
29 "jsonrpc": "2.0",
30 "error": {"code": code, "message": message}
31 }),
32 }
33}
34
35pub fn jsonrpc_success(id: &Value, result: Value) -> Value {
38 if id.is_null() {
39 json!({
40 "jsonrpc": "2.0",
41 "result": result
42 })
43 } else {
44 json!({
45 "jsonrpc": "2.0",
46 "id": id,
47 "result": result
48 })
49 }
50}
51
/// Transport over which an MCP request arrived.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum McpTransport {
    /// Line-delimited JSON-RPC over stdin/stdout.
    Stdio,
    /// HTTP with server-sent events.
    HttpSse,
}

impl McpTransport {
    /// Label added as `transport` to the `health` tool output, if any.
    ///
    /// Stdio reports no transport label; HTTP/SSE reports `"mcp-over-sse"`.
    fn health_transport(self) -> Option<&'static str> {
        let label = match self {
            Self::HttpSse => "mcp-over-sse",
            Self::Stdio => return None,
        };
        Some(label)
    }
}
66
67pub enum McpDispatch {
68 Notification,
69 Response(Value),
70}
71
72impl McpDispatch {
73 pub fn into_option(self) -> Option<Value> {
74 match self {
75 Self::Notification => None,
76 Self::Response(response) => Some(response),
77 }
78 }
79}
80
/// JSON-RPC methods this server answers (other than notifications).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum McpMethod {
    Initialize,
    ToolsList,
    ToolsCall,
}

impl McpMethod {
    /// Wire name of every supported method, paired with its variant.
    const NAMES: [(&'static str, Self); 3] = [
        ("initialize", Self::Initialize),
        ("tools/list", Self::ToolsList),
        ("tools/call", Self::ToolsCall),
    ];

    /// Resolves a JSON-RPC `method` string; unknown names yield `None`.
    fn from_name(name: &str) -> Option<Self> {
        Self::NAMES
            .iter()
            .find(|(label, _)| *label == name)
            .map(|&(_, method)| method)
    }
}
98
/// Tools exposed through `tools/list` / `tools/call`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum McpTool {
    /// `health`: server version, storage path and backend status.
    Health,
    /// `rag_index`: index a document from an allowed path.
    RagIndex,
    /// `memory_upsert`: insert or replace a text chunk by namespace + id.
    MemoryUpsert,
    /// `memory_get`: fetch a stored chunk by namespace + id.
    MemoryGet,
    /// `memory_search`: vector / BM25 / hybrid search within a namespace.
    MemorySearch,
    /// `memory_delete`: delete a chunk by namespace + id.
    MemoryDelete,
    /// `memory_purge_namespace`: delete every chunk in a namespace.
    MemoryPurgeNamespace,
    /// `namespace_create_token`: create an access token protecting a namespace.
    NamespaceCreateToken,
    /// `namespace_revoke_token`: remove token protection from a namespace.
    NamespaceRevokeToken,
    /// `namespace_list_protected`: list namespaces that have tokens.
    NamespaceListProtected,
    /// `namespace_security_status`: report whether any tokens exist.
    NamespaceSecurityStatus,
    /// `dive`: per-layer exploration of search results.
    Dive,
}
114
115impl McpTool {
116 const ALL: [Self; 12] = [
117 Self::Health,
118 Self::RagIndex,
119 Self::MemoryUpsert,
120 Self::MemoryGet,
121 Self::MemorySearch,
122 Self::MemoryDelete,
123 Self::MemoryPurgeNamespace,
124 Self::NamespaceCreateToken,
125 Self::NamespaceRevokeToken,
126 Self::NamespaceListProtected,
127 Self::NamespaceSecurityStatus,
128 Self::Dive,
129 ];
130
131 fn from_name(name: &str) -> Option<Self> {
132 match name {
133 "health" => Some(Self::Health),
134 "rag_index" => Some(Self::RagIndex),
135 "memory_upsert" => Some(Self::MemoryUpsert),
136 "memory_get" => Some(Self::MemoryGet),
137 "memory_search" => Some(Self::MemorySearch),
138 "memory_delete" => Some(Self::MemoryDelete),
139 "memory_purge_namespace" => Some(Self::MemoryPurgeNamespace),
140 "namespace_create_token" => Some(Self::NamespaceCreateToken),
141 "namespace_revoke_token" => Some(Self::NamespaceRevokeToken),
142 "namespace_list_protected" => Some(Self::NamespaceListProtected),
143 "namespace_security_status" => Some(Self::NamespaceSecurityStatus),
144 "dive" => Some(Self::Dive),
145 _ => None,
146 }
147 }
148
149 fn name(self) -> &'static str {
150 match self {
151 Self::Health => "health",
152 Self::RagIndex => "rag_index",
153 Self::MemoryUpsert => "memory_upsert",
154 Self::MemoryGet => "memory_get",
155 Self::MemorySearch => "memory_search",
156 Self::MemoryDelete => "memory_delete",
157 Self::MemoryPurgeNamespace => "memory_purge_namespace",
158 Self::NamespaceCreateToken => "namespace_create_token",
159 Self::NamespaceRevokeToken => "namespace_revoke_token",
160 Self::NamespaceListProtected => "namespace_list_protected",
161 Self::NamespaceSecurityStatus => "namespace_security_status",
162 Self::Dive => "dive",
163 }
164 }
165
166 fn definition(self) -> Value {
167 match self {
168 Self::Health => json!({
169 "name": self.name(),
170 "description": "Health/status of rust-memex server",
171 "inputSchema": {
172 "type": "object",
173 "properties": {},
174 "required": []
175 }
176 }),
177 Self::RagIndex => json!({
178 "name": self.name(),
179 "description": "Index a document for RAG",
180 "inputSchema": {
181 "type": "object",
182 "properties": {
183 "path": {"type": "string"},
184 "namespace": {"type": "string"}
185 },
186 "required": ["path"]
187 }
188 }),
189 Self::MemoryUpsert => json!({
190 "name": self.name(),
191 "description": "Upsert a text chunk into vector memory. If the namespace is protected, provide the access token.",
192 "inputSchema": {
193 "type": "object",
194 "properties": {
195 "namespace": {"type": "string"},
196 "id": {"type": "string"},
197 "text": {"type": "string"},
198 "metadata": {"type": "object"},
199 "token": {"type": "string", "description": "Access token for protected namespaces"}
200 },
201 "required": ["namespace", "id", "text"]
202 }
203 }),
204 Self::MemoryGet => json!({
205 "name": self.name(),
206 "description": "Get a stored chunk by namespace + id. If the namespace is protected, provide the access token.",
207 "inputSchema": {
208 "type": "object",
209 "properties": {
210 "namespace": {"type": "string"},
211 "id": {"type": "string"},
212 "token": {"type": "string", "description": "Access token for protected namespaces"}
213 },
214 "required": ["namespace", "id"]
215 }
216 }),
217 Self::MemorySearch => json!({
218 "name": self.name(),
219 "description": "Semantic search within a namespace. If the namespace is protected, provide the access token.",
220 "inputSchema": {
221 "type": "object",
222 "properties": {
223 "namespace": {"type": "string"},
224 "query": {"type": "string"},
225 "k": {"type": "integer", "default": 5},
226 "project": {"type": "string", "description": "Filter to documents whose metadata project/project_id matches this value"},
227 "deep": {"type": "boolean", "default": false, "description": "Include all onion layers instead of only outer summaries"},
228 "mode": {"type": "string", "enum": ["vector", "bm25", "hybrid"], "default": "hybrid", "description": "Search mode: vector (semantic), bm25 (keyword), hybrid (both)"},
229 "auto_route": {"type": "boolean", "default": false, "description": "Auto-detect query intent and select optimal search mode. Overrides mode when true."},
230 "token": {"type": "string", "description": "Access token for protected namespaces"}
231 },
232 "required": ["namespace", "query"]
233 }
234 }),
235 Self::MemoryDelete => json!({
236 "name": self.name(),
237 "description": "Delete a chunk by namespace + id. If the namespace is protected, provide the access token.",
238 "inputSchema": {
239 "type": "object",
240 "properties": {
241 "namespace": {"type": "string"},
242 "id": {"type": "string"},
243 "token": {"type": "string", "description": "Access token for protected namespaces"}
244 },
245 "required": ["namespace", "id"]
246 }
247 }),
248 Self::MemoryPurgeNamespace => json!({
249 "name": self.name(),
250 "description": "Delete all chunks in a namespace. If the namespace is protected, provide the access token.",
251 "inputSchema": {
252 "type": "object",
253 "properties": {
254 "namespace": {"type": "string"},
255 "token": {"type": "string", "description": "Access token for protected namespaces"}
256 },
257 "required": ["namespace"]
258 }
259 }),
260 Self::NamespaceCreateToken => json!({
261 "name": self.name(),
262 "description": "Create an access token for a namespace. Once created, the namespace will require this token for access.",
263 "inputSchema": {
264 "type": "object",
265 "properties": {
266 "namespace": {"type": "string", "description": "The namespace to protect with a token"},
267 "description": {"type": "string", "description": "Optional description for the token"}
268 },
269 "required": ["namespace"]
270 }
271 }),
272 Self::NamespaceRevokeToken => json!({
273 "name": self.name(),
274 "description": "Revoke the access token for a namespace, making it publicly accessible again.",
275 "inputSchema": {
276 "type": "object",
277 "properties": {
278 "namespace": {"type": "string", "description": "The namespace to remove token protection from"}
279 },
280 "required": ["namespace"]
281 }
282 }),
283 Self::NamespaceListProtected => json!({
284 "name": self.name(),
285 "description": "List all namespaces that have token protection enabled.",
286 "inputSchema": {
287 "type": "object",
288 "properties": {},
289 "required": []
290 }
291 }),
292 Self::NamespaceSecurityStatus => json!({
293 "name": self.name(),
294 "description": "Check if namespace security (token-based access control) is enabled.",
295 "inputSchema": {
296 "type": "object",
297 "properties": {},
298 "required": []
299 }
300 }),
301 Self::Dive => json!({
302 "name": self.name(),
303 "description": "Deep exploration with all onion layers. Shows ALL layers (outer/middle/inner/core), both BM25 and vector scores, full metadata, and related chunks.",
304 "inputSchema": {
305 "type": "object",
306 "properties": {
307 "namespace": {"type": "string", "description": "Namespace to search in"},
308 "query": {"type": "string", "description": "Search query text"},
309 "limit": {"type": "integer", "default": 5, "description": "Maximum results per layer"},
310 "verbose": {"type": "boolean", "default": false, "description": "Show full text and metadata"}
311 },
312 "required": ["namespace", "query"]
313 }
314 }),
315 }
316 }
317}
318
319pub fn shared_initialize_result() -> Value {
324 json!({
325 "protocolVersion": PROTOCOL_VERSION,
326 "serverInfo": {
327 "name": SERVER_NAME,
328 "version": env!("CARGO_PKG_VERSION")
329 },
330 "capabilities": {
331 "tools": {}
332 }
333 })
334}
335
336pub fn shared_tools_list_result() -> Value {
338 let tools: Vec<Value> = McpTool::ALL.into_iter().map(McpTool::definition).collect();
339 json!({ "tools": tools })
340}
341
/// Transport-agnostic MCP request handler; the transport is passed per request.
#[derive(Clone)]
pub struct McpCore {
    /// Pipeline used for indexing, memory CRUD and vector search.
    rag: Arc<RAGPipeline>,
    /// Optional searcher used by `memory_search` for non-vector modes.
    hybrid_searcher: Option<Arc<HybridSearcher>>,
    /// Embedding client behind an async mutex; locked for each query embedding.
    embedding_client: Arc<Mutex<EmbeddingClient>>,
    /// Maximum size of a (trimmed) request payload accepted by `handle_payload`.
    max_request_bytes: usize,
    /// Roots `rag_index` may read from; empty means `$HOME` and the current directory.
    allowed_paths: Vec<String>,
    /// Token store backing namespace protection.
    auth_manager: Arc<AuthManager>,
}
351
352impl McpCore {
353 pub fn new(
354 rag: Arc<RAGPipeline>,
355 hybrid_searcher: Option<Arc<HybridSearcher>>,
356 embedding_client: Arc<Mutex<EmbeddingClient>>,
357 max_request_bytes: usize,
358 allowed_paths: Vec<String>,
359 auth_manager: Arc<AuthManager>,
360 ) -> Self {
361 Self {
362 rag,
363 hybrid_searcher,
364 embedding_client,
365 max_request_bytes,
366 allowed_paths,
367 auth_manager,
368 }
369 }
370
371 pub fn rag(&self) -> Arc<RAGPipeline> {
372 self.rag.clone()
373 }
374
375 pub fn auth_manager(&self) -> &AuthManager {
379 &self.auth_manager
380 }
381
382 async fn verify_tool_access(&self, namespace: &str, token: Option<&str>) -> Result<()> {
393 let tokens = self.auth_manager.list_tokens().await;
394 let namespace_has_token = tokens
395 .iter()
396 .any(|entry| entry.has_namespace_access(namespace));
397
398 if !namespace_has_token {
399 return Ok(());
401 }
402
403 match token {
404 Some(plaintext) => match self
405 .auth_manager
406 .authorize(plaintext, &Scope::Write, Some(namespace))
407 .await
408 {
409 Ok(_) => Ok(()),
410 Err(AuthDenial::InvalidToken) | Err(AuthDenial::MissingToken) => Err(anyhow!(
411 "Access denied: invalid token for namespace '{}'",
412 namespace
413 )),
414 Err(denial) => Err(anyhow!("{}", denial)),
415 },
416 None => Err(anyhow!(
417 "Access denied: namespace '{}' requires a token. Use namespace_create_token to generate one.",
418 namespace
419 )),
420 }
421 }
422
423 pub fn hybrid_searcher(&self) -> Option<Arc<HybridSearcher>> {
424 self.hybrid_searcher.clone()
425 }
426
427 pub async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
428 self.embedding_client.lock().await.embed(query).await
429 }
430
431 pub async fn handle_request(&self, request: Value, transport: McpTransport) -> Option<Value> {
432 self.handle_jsonrpc_request(request, transport)
433 .await
434 .into_option()
435 }
436
437 pub async fn handle_payload(&self, payload: &str, transport: McpTransport) -> Option<Value> {
438 let request = match parse_jsonrpc_payload(payload, self.max_request_bytes) {
439 Ok(request) => request,
440 Err(response) => return Some(response),
441 };
442
443 self.handle_request(request, transport).await
444 }
445
446 pub async fn handle_jsonrpc_request(
447 &self,
448 request: Value,
449 transport: McpTransport,
450 ) -> McpDispatch {
451 let method_name = request["method"].as_str().unwrap_or("");
452
453 if method_name.starts_with("notifications/") {
454 return McpDispatch::Notification;
455 }
456
457 let id = match request.get("id") {
458 Some(value) if value.is_string() || value.is_number() => value.clone(),
459 _ => {
460 return McpDispatch::Response(json!({
461 "jsonrpc": "2.0",
462 "id": Value::Null,
463 "error": {
464 "code": -32600,
465 "message": "Invalid Request: missing or invalid 'id' field"
466 }
467 }));
468 }
469 };
470
471 let method = match McpMethod::from_name(method_name) {
472 Some(method) => method,
473 None => {
474 return McpDispatch::Response(jsonrpc_error(
475 Some(&id),
476 -32601,
477 format!("Unknown method: {}", method_name),
478 ));
479 }
480 };
481
482 let result = match method {
483 McpMethod::Initialize => shared_initialize_result(),
484 McpMethod::ToolsList => shared_tools_list_result(),
485 McpMethod::ToolsCall => match self.handle_tool_call(&request, &id, transport).await {
486 Ok(result) => result,
487 Err(response) => return McpDispatch::Response(response),
488 },
489 };
490
491 McpDispatch::Response(jsonrpc_success(&id, result))
492 }
493
494 async fn handle_tool_call(
495 &self,
496 request: &Value,
497 id: &Value,
498 transport: McpTransport,
499 ) -> std::result::Result<Value, Value> {
500 let tool_name = request["params"]["name"].as_str().unwrap_or("");
501 let tool = McpTool::from_name(tool_name).ok_or_else(|| {
502 jsonrpc_error(Some(id), -32601, format!("Unknown tool: {}", tool_name))
503 })?;
504 let args = &request["params"]["arguments"];
505
506 match tool {
507 McpTool::Health => {
508 let mut status = json!({
509 "version": env!("CARGO_PKG_VERSION"),
510 "db_path": self.rag.storage_manager().lance_path(),
511 "backend": "mlx",
512 "mlx_server": self.rag.mlx_connected_to(),
513 });
514
515 if let Some(transport_name) = transport.health_transport() {
516 status["transport"] = json!(transport_name);
517 }
518
519 Ok(text_result_from_json(&status))
520 }
521 McpTool::RagIndex => {
522 let path_str = args["path"].as_str().unwrap_or("");
523 let namespace = args["namespace"].as_str();
524
525 let validated_path = validate_path(path_str, &self.allowed_paths)
526 .map_err(|e| jsonrpc_error(Some(id), -32602, e.to_string()))?;
527
528 match self.rag.index_document(&validated_path, namespace).await {
529 Ok(_) => Ok(text_result(format!("Indexed: {}", path_str))),
530 Err(e) => Ok(tool_error(e)),
531 }
532 }
533 McpTool::MemoryUpsert => {
534 let namespace = args["namespace"].as_str().unwrap_or("default");
535 let token = args["token"].as_str();
536
537 self.verify_tool_access(namespace, token)
538 .await
539 .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
540
541 let item_id = args["id"].as_str().unwrap_or("").to_string();
542 let text = args["text"].as_str().unwrap_or("").to_string();
543 let metadata = args.get("metadata").cloned().unwrap_or_else(|| json!({}));
544
545 match self
546 .rag
547 .memory_upsert(namespace, item_id.clone(), text, metadata)
548 .await
549 {
550 Ok(_) => Ok(text_result(format!("Upserted {}", item_id))),
551 Err(e) => Ok(tool_error(e)),
552 }
553 }
554 McpTool::MemoryGet => {
555 let namespace = args["namespace"].as_str().unwrap_or("default");
556 let token = args["token"].as_str();
557
558 self.verify_tool_access(namespace, token)
559 .await
560 .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
561
562 let item_id = args["id"].as_str().unwrap_or("");
563 match self.rag.lookup_memory(namespace, item_id).await {
564 Ok(Some(doc)) => Ok(text_result_from_json(&doc)),
565 Ok(None) => Ok(text_result("Not found")),
566 Err(e) => Ok(tool_error(e)),
567 }
568 }
569 McpTool::MemorySearch => {
570 let namespace = args["namespace"].as_str().unwrap_or("default");
571 let token = args["token"].as_str();
572
573 self.verify_tool_access(namespace, token)
574 .await
575 .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
576
577 let query = args["query"].as_str().unwrap_or("");
578 let limit = requested_limit(args, 5);
579 let mode = requested_search_mode(query, args);
580 let options = requested_search_options(args);
581
582 if let Some(hybrid_result) = self
583 .try_hybrid_search(query, Some(namespace), limit, (mode, options.clone()), id)
584 .await?
585 {
586 return Ok(hybrid_result);
587 }
588
589 match self
590 .rag
591 .search_with_options(Some(namespace), query, limit, options)
592 .await
593 {
594 Ok(results) => Ok(text_result_from_json(&results)),
595 Err(e) => Ok(tool_error(e)),
596 }
597 }
598 McpTool::MemoryDelete => {
599 let namespace = args["namespace"].as_str().unwrap_or("default");
600 let token = args["token"].as_str();
601
602 self.verify_tool_access(namespace, token)
603 .await
604 .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
605
606 let item_id = args["id"].as_str().unwrap_or("");
607 match self.rag.remove_memory(namespace, item_id).await {
608 Ok(deleted) => Ok(text_result(format!("Deleted {} rows", deleted))),
609 Err(e) => Ok(tool_error(e)),
610 }
611 }
612 McpTool::MemoryPurgeNamespace => {
613 let namespace = args["namespace"].as_str().unwrap_or("default");
614 let token = args["token"].as_str();
615
616 self.verify_tool_access(namespace, token)
617 .await
618 .map_err(|e| jsonrpc_error(Some(id), -32603, e.to_string()))?;
619
620 match self.rag.clear_namespace(namespace).await {
621 Ok(deleted) => Ok(text_result(format!(
622 "Purged namespace '{}', removed {} rows",
623 namespace, deleted
624 ))),
625 Err(e) => Ok(tool_error(e)),
626 }
627 }
628 McpTool::NamespaceCreateToken => {
629 let namespace = args["namespace"].as_str().unwrap_or("");
630 let description = args["description"].as_str().map(ToOwned::to_owned);
631
632 if namespace.is_empty() {
633 return Ok(tool_error_message("Namespace is required"));
634 }
635
636 let description = description
646 .unwrap_or_else(|| format!("Auto-created for namespace '{}'", namespace));
647 let _ = self.auth_manager.revoke_token(namespace).await;
648 match self
649 .auth_manager
650 .create_token(
651 namespace.to_string(),
652 vec![Scope::Read, Scope::Write, Scope::Admin],
653 vec![namespace.to_string()],
654 None,
655 description,
656 )
657 .await
658 {
659 Ok(token) => Ok(text_result(format!(
660 "Token created for namespace '{}'. Store this token securely - it won't be shown again!\n\nToken: {}",
661 namespace, token
662 ))),
663 Err(e) => Ok(tool_error(e)),
664 }
665 }
666 McpTool::NamespaceRevokeToken => {
667 let namespace = args["namespace"].as_str().unwrap_or("");
668
669 if namespace.is_empty() {
670 return Ok(tool_error_message("Namespace is required"));
671 }
672
673 match self.auth_manager.revoke_token(namespace).await {
674 Ok(true) => Ok(text_result(format!(
675 "Token revoked for namespace '{}'. The namespace is now publicly accessible.",
676 namespace
677 ))),
678 Ok(false) => Ok(text_result(format!(
679 "No token found for namespace '{}'.",
680 namespace
681 ))),
682 Err(e) => Ok(tool_error(e)),
683 }
684 }
685 McpTool::NamespaceListProtected => {
686 let tokens = self.auth_manager.list_tokens().await;
687 let mut protected: std::collections::BTreeMap<String, (i64, Option<String>)> =
693 std::collections::BTreeMap::new();
694 for entry in &tokens {
695 let created_at = entry.created_at.timestamp();
696 let desc = Some(entry.description.clone());
697 for ns in &entry.namespaces {
698 if ns == "*" {
699 continue;
700 }
701 protected
702 .entry(ns.clone())
703 .and_modify(|existing| {
704 if created_at > existing.0 {
705 *existing = (created_at, desc.clone());
706 }
707 })
708 .or_insert_with(|| (created_at, desc.clone()));
709 }
710 }
711
712 if protected.is_empty() {
713 Ok(text_result(
714 "No namespaces are currently protected with tokens.",
715 ))
716 } else {
717 let list: Vec<Value> = protected
718 .into_iter()
719 .map(|(namespace, (created_at, description))| {
720 json!({
721 "namespace": namespace,
722 "created_at": created_at,
723 "description": description
724 })
725 })
726 .collect();
727 Ok(pretty_text_result_from_json(&list))
728 }
729 }
730 McpTool::NamespaceSecurityStatus => {
731 let has_any = self.auth_manager.has_any_tokens().await;
735 let tokens = self.auth_manager.list_tokens().await;
736 let protected_namespaces: std::collections::BTreeSet<String> = tokens
737 .iter()
738 .flat_map(|entry| entry.namespaces.iter().cloned())
739 .filter(|ns| ns != "*")
740 .collect();
741
742 Ok(text_result(format!(
743 "Namespace security: {}\nProtected namespaces: {}\n\nNote: When security is disabled, all namespaces are accessible without tokens.",
744 if has_any { "ENABLED" } else { "DISABLED" },
745 protected_namespaces.len()
746 )))
747 }
748 McpTool::Dive => {
749 let namespace = args["namespace"].as_str().unwrap_or("");
750 let query = args["query"].as_str().unwrap_or("");
751 let limit = args["limit"].as_u64().unwrap_or(5) as usize;
752 let verbose = args["verbose"].as_bool().unwrap_or(false);
753
754 if namespace.is_empty() || query.is_empty() {
755 return Err(jsonrpc_error(
756 Some(id),
757 -32602,
758 "namespace and query are required",
759 ));
760 }
761
762 let layers = [
763 (Some(SliceLayer::Outer), "outer"),
764 (Some(SliceLayer::Middle), "middle"),
765 (Some(SliceLayer::Inner), "inner"),
766 (Some(SliceLayer::Core), "core"),
767 ];
768
769 let mut all_results: Vec<Value> = Vec::new();
770
771 for (layer_filter, layer_name) in &layers {
772 match self
773 .rag
774 .memory_search_with_layer(namespace, query, limit, *layer_filter)
775 .await
776 {
777 Ok(results) => {
778 let layer_results: Vec<Value> = results
779 .iter()
780 .map(|result| {
781 let mut object = json!({
782 "id": result.id,
783 "score": result.score,
784 "keywords": result.keywords,
785 "layer": result.layer.map(|layer| layer.name()),
786 "can_expand": result.can_expand(),
787 "parent_id": result.parent_id,
788 });
789
790 if verbose {
791 object["text"] = json!(result.text);
792 object["metadata"] = result.metadata.clone();
793 object["children_ids"] = json!(result.children_ids);
794 } else {
795 let preview: String =
796 result.text.chars().take(200).collect();
797 object["preview"] = json!(preview);
798 }
799
800 object
801 })
802 .collect();
803
804 all_results.push(json!({
805 "layer": layer_name,
806 "count": results.len(),
807 "results": layer_results
808 }));
809 }
810 Err(e) => {
811 all_results.push(json!({
812 "layer": layer_name,
813 "error": e.to_string()
814 }));
815 }
816 }
817 }
818
819 Ok(pretty_text_result_from_json(&json!({
820 "query": query,
821 "namespace": namespace,
822 "limit_per_layer": limit,
823 "verbose": verbose,
824 "layers": all_results
825 })))
826 }
827 }
828 }
829
830 async fn try_hybrid_search(
831 &self,
832 query: &str,
833 namespace: Option<&str>,
834 limit: usize,
835 search: (SearchMode, SearchOptions),
836 id: &Value,
837 ) -> std::result::Result<Option<Value>, Value> {
838 let (mode, options) = search;
839 if mode == SearchMode::Vector {
840 return Ok(None);
841 }
842
843 let Some(hybrid_searcher) = &self.hybrid_searcher else {
844 return Ok(None);
845 };
846
847 let query_embedding = self
848 .embedding_client
849 .lock()
850 .await
851 .embed(query)
852 .await
853 .map_err(|e| jsonrpc_error(Some(id), -32603, format!("Embedding failed: {}", e)))?;
854
855 let results = hybrid_searcher
856 .search(query, query_embedding, namespace, limit, options)
857 .await
858 .map_err(|e| jsonrpc_error(Some(id), -32603, format!("Hybrid search failed: {}", e)))?;
859
860 let payload: Vec<Value> = results
861 .iter()
862 .map(|result| {
863 json!({
864 "id": result.id,
865 "namespace": result.namespace,
866 "text": result.document,
867 "score": result.combined_score,
868 "vector_score": result.vector_score,
869 "bm25_score": result.bm25_score,
870 "metadata": result.metadata
871 })
872 })
873 .collect();
874
875 Ok(Some(text_result_from_json(&payload)))
876 }
877}
878
879fn requested_search_mode(query: &str, args: &Value) -> SearchMode {
880 if args["auto_route"].as_bool().unwrap_or(false) {
881 let router = QueryRouter::new();
882 let decision = router.route(query);
883 match decision.recommended_mode.mode {
884 SearchModeRecommendation::Vector => SearchMode::Vector,
885 SearchModeRecommendation::Bm25 => SearchMode::Keyword,
886 SearchModeRecommendation::Hybrid => SearchMode::Hybrid,
887 }
888 } else {
889 match args["mode"].as_str() {
890 Some("vector") => SearchMode::Vector,
891 Some("bm25") | Some("keyword") => SearchMode::Keyword,
892 _ => SearchMode::Hybrid,
893 }
894 }
895}
896
897fn requested_layer_filter(args: &Value) -> Option<SliceLayer> {
898 if args["deep"].as_bool().unwrap_or(false) {
899 None
900 } else {
901 Some(SliceLayer::Outer)
902 }
903}
904
905fn requested_search_options(args: &Value) -> SearchOptions {
906 SearchOptions {
907 layer_filter: requested_layer_filter(args),
908 project_filter: args["project"]
909 .as_str()
910 .map(|value| value.trim().to_string())
911 .filter(|value| !value.is_empty()),
912 }
913}
914
915fn requested_limit(args: &Value, default: usize) -> usize {
916 args["k"]
917 .as_u64()
918 .or_else(|| args["limit"].as_u64())
919 .map(|value| value as usize)
920 .unwrap_or(default)
921}
922
923fn parse_jsonrpc_payload(
924 payload: &str,
925 max_request_bytes: usize,
926) -> std::result::Result<Value, Value> {
927 let trimmed = payload.trim();
928
929 if trimmed.len() > max_request_bytes {
930 return Err(jsonrpc_error(
931 None,
932 -32600,
933 format!(
934 "Request too large: {} bytes (max {})",
935 trimmed.len(),
936 max_request_bytes
937 ),
938 ));
939 }
940
941 serde_json::from_str(trimmed)
942 .map_err(|error| jsonrpc_error(None, -32700, format!("Parse error: {}", error)))
943}
944
945fn tool_error(error: impl ToString) -> Value {
946 tool_error_message(error.to_string())
947}
948
949fn tool_error_message(message: impl Into<String>) -> Value {
950 json!({
951 "error": {"message": message.into()}
952 })
953}
954
955fn text_result(text: impl Into<String>) -> Value {
956 json!({
957 "content": [{"type": "text", "text": text.into()}]
958 })
959}
960
961fn text_result_from_json<T: serde::Serialize>(value: &T) -> Value {
962 text_result(serde_json::to_string(value).unwrap_or_default())
963}
964
965fn pretty_text_result_from_json<T: serde::Serialize>(value: &T) -> Value {
966 text_result(serde_json::to_string_pretty(value).unwrap_or_default())
967}
968
969fn validate_path(path_str: &str, allowed_paths: &[String]) -> Result<std::path::PathBuf> {
972 if path_str.is_empty() {
973 return Err(anyhow!("Path cannot be empty"));
974 }
975
976 if path_str.contains("..") || path_str.contains('\0') || path_str.contains('\n') {
977 return Err(anyhow!(
978 "Path traversal detected: invalid sequences in '{}'",
979 path_str
980 ));
981 }
982
983 let canonical = crate::path_utils::sanitize_existing_path(path_str)?;
984
985 let is_safe = if allowed_paths.is_empty() {
986 let home = std::env::var("HOME")
987 .or_else(|_| std::env::var("USERPROFILE"))
988 .map(std::path::PathBuf::from)
989 .ok();
990 let cwd = std::env::current_dir().ok();
991
992 home.as_ref()
993 .map(|path| canonical.starts_with(path))
994 .unwrap_or(false)
995 || cwd
996 .as_ref()
997 .map(|path| canonical.starts_with(path))
998 .unwrap_or(false)
999 } else {
1000 allowed_paths.iter().any(|allowed| {
1001 let expanded_allowed = shellexpand::tilde(allowed).to_string();
1002 let allowed_path = Path::new(&expanded_allowed);
1003 let canonical_allowed = allowed_path
1004 .canonicalize()
1005 .unwrap_or_else(|_| std::path::PathBuf::from(&expanded_allowed));
1006
1007 canonical.starts_with(&canonical_allowed)
1008 })
1009 };
1010
1011 if !is_safe {
1012 let allowed_info = if allowed_paths.is_empty() {
1013 "$HOME and current working directory".to_string()
1014 } else {
1015 format!("configured paths: {:?}", allowed_paths)
1016 };
1017
1018 return Err(anyhow!(
1019 "Access denied: path '{}' is outside allowed directories ({})",
1020 path_str,
1021 allowed_info
1022 ));
1023 }
1024
1025 Ok(canonical)
1026}
1027
#[cfg(test)]
mod tests {
    use super::{
        McpMethod, McpTool, jsonrpc_error, jsonrpc_success, parse_jsonrpc_payload,
        requested_layer_filter, requested_limit, requested_search_options,
        shared_initialize_result, shared_tools_list_result,
    };
    use crate::rag::{SearchOptions, SliceLayer};
    use serde_json::{Value, json};

    #[test]
    fn jsonrpc_error_omits_missing_id() {
        let response = jsonrpc_error(None, -32600, "boom");
        assert_eq!(response["jsonrpc"], "2.0");
        assert_eq!(response["error"]["code"], -32600);
        assert_eq!(response.get("id"), None);
    }

    #[test]
    fn jsonrpc_success_omits_null_id() {
        let response = jsonrpc_success(&Value::Null, json!({"ok": true}));
        assert_eq!(response["jsonrpc"], "2.0");
        assert!(response["result"]["ok"].as_bool().unwrap());
        assert_eq!(response.get("id"), None);
    }

    #[test]
    fn initialize_advertises_only_tools_capability() {
        let response = shared_initialize_result();
        assert_eq!(response["protocolVersion"], "2024-11-05");
        assert_eq!(response["capabilities"], json!({ "tools": {} }));
    }

    #[test]
    fn tool_list_contains_extended_stdio_and_http_surface() {
        let result = shared_tools_list_result();
        let tools = result["tools"]
            .as_array()
            .expect("tools list should be an array");
        let names: Vec<&str> = tools
            .iter()
            .filter_map(|tool| tool["name"].as_str())
            .collect();

        assert!(names.contains(&"rag_index"));
        assert!(names.contains(&"memory_purge_namespace"));
        assert!(names.contains(&"namespace_create_token"));
        assert!(names.contains(&"dive"));
    }

    #[test]
    fn tool_list_advertises_every_tool_exactly_once() {
        let result = shared_tools_list_result();
        let names: Vec<&str> = result["tools"]
            .as_array()
            .expect("tools list should be an array")
            .iter()
            .filter_map(|tool| tool["name"].as_str())
            .collect();
        let unique: std::collections::BTreeSet<&str> = names.iter().copied().collect();

        assert_eq!(names.len(), McpTool::ALL.len());
        assert_eq!(unique.len(), names.len());
    }

    #[test]
    fn tool_names_round_trip_through_from_name() {
        for tool in McpTool::ALL {
            assert_eq!(McpTool::from_name(tool.name()), Some(tool));
        }
        assert_eq!(McpTool::from_name("not_a_tool"), None);
        assert_eq!(McpTool::from_name(""), None);
    }

    #[test]
    fn method_names_resolve() {
        assert_eq!(McpMethod::from_name("initialize"), Some(McpMethod::Initialize));
        assert_eq!(McpMethod::from_name("tools/list"), Some(McpMethod::ToolsList));
        assert_eq!(McpMethod::from_name("tools/call"), Some(McpMethod::ToolsCall));
        assert_eq!(McpMethod::from_name("resources/list"), None);
    }

    #[test]
    fn parse_jsonrpc_payload_rejects_oversized_requests() {
        let response = parse_jsonrpc_payload("123456", 5).expect_err("payload should be rejected");
        assert_eq!(response["error"]["code"], -32600);
        assert!(
            response["error"]["message"]
                .as_str()
                .unwrap_or("")
                .contains("Request too large")
        );
    }

    #[test]
    fn parse_jsonrpc_payload_returns_jsonrpc_parse_error() {
        let response = parse_jsonrpc_payload("{", 1024).expect_err("payload should not parse");
        assert_eq!(response["error"]["code"], -32700);
        assert!(
            response["error"]["message"]
                .as_str()
                .unwrap_or("")
                .contains("Parse error")
        );
    }

    #[test]
    fn parse_jsonrpc_payload_accepts_valid_json_with_whitespace() {
        let request = parse_jsonrpc_payload(
            "  {\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{}}  ",
            1024,
        )
        .expect("payload should parse");

        assert_eq!(request["method"], "initialize");
        assert_eq!(request["id"], 1);
    }

    #[test]
    fn requested_limit_prefers_request_k_over_default() {
        assert_eq!(requested_limit(&json!({"k": 17}), 5), 17);
        assert_eq!(requested_limit(&json!({}), 5), 5);
    }

    #[test]
    fn requested_limit_accepts_limit_alias() {
        assert_eq!(requested_limit(&json!({"limit": 11}), 5), 11);
    }

    #[test]
    fn requested_layer_filter_defaults_to_outer_only() {
        assert_eq!(requested_layer_filter(&json!({})), Some(SliceLayer::Outer));
    }

    #[test]
    fn requested_layer_filter_allows_deep_search() {
        assert_eq!(requested_layer_filter(&json!({"deep": true})), None);
    }

    #[test]
    fn requested_search_options_captures_project_filter() {
        assert_eq!(
            requested_search_options(&json!({"project": "Vista"})),
            SearchOptions {
                layer_filter: Some(SliceLayer::Outer),
                project_filter: Some("Vista".to_string()),
            }
        );
    }
}