1use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::{Value, json};
9
10use crate::access::AccessTracker;
11use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
12use crate::db::score_with_decay;
13use crate::graph::GraphStore;
14use crate::item::{Item, ItemFilters};
15use crate::retry::{RetryConfig, with_retry};
16use crate::{Database, ListScope, StoreScope};
17
18use super::protocol::{CallToolResult, Tool};
19use super::server::ServerContext;
20
21fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
24 tokio::spawn(async move {
25 let result = tokio::task::spawn(fut).await;
26 if let Err(e) = result {
27 tracing::error!("Background task '{}' panicked: {:?}", name, e);
28 }
29 });
30}
31
32pub fn get_tools() -> Vec<Tool> {
34 vec![
35 Tool {
36 name: "store".to_string(),
37 description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
38 input_schema: json!({
39 "type": "object",
40 "properties": {
41 "content": {
42 "type": "string",
43 "description": "The content to store"
44 },
45 "scope": {
46 "type": "string",
47 "enum": ["project", "global"],
48 "default": "project",
49 "description": "Where to store: 'project' (current project) or 'global' (all projects)"
50 }
51 },
52 "required": ["content"]
53 }),
54 },
55 Tool {
56 name: "recall".to_string(),
57 description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
58 input_schema: json!({
59 "type": "object",
60 "properties": {
61 "query": {
62 "type": "string",
63 "description": "What to search for (semantic search)"
64 },
65 "limit": {
66 "type": "number",
67 "default": 5,
68 "description": "Maximum number of results"
69 }
70 },
71 "required": ["query"]
72 }),
73 },
74 Tool {
75 name: "list".to_string(),
76 description: "List stored items.".to_string(),
77 input_schema: json!({
78 "type": "object",
79 "properties": {
80 "limit": {
81 "type": "number",
82 "default": 10,
83 "description": "Maximum number of results"
84 },
85 "scope": {
86 "type": "string",
87 "enum": ["project", "global", "all"],
88 "default": "project",
89 "description": "Which items to list: 'project', 'global', or 'all'"
90 }
91 }
92 }),
93 },
94 Tool {
95 name: "forget".to_string(),
96 description: "Delete a stored item by its ID.".to_string(),
97 input_schema: json!({
98 "type": "object",
99 "properties": {
100 "id": {
101 "type": "string",
102 "description": "The item ID to delete"
103 }
104 },
105 "required": ["id"]
106 }),
107 },
108 ]
109}
110
/// Arguments of the `store` tool, deserialized from the caller's JSON.
#[derive(Debug, Deserialize)]
pub struct StoreParams {
    /// Text to store. `execute_store` rejects values longer than 1,000,000 bytes.
    pub content: String,
    /// Optional scope name. `execute_store` parses it as a `StoreScope` and
    /// falls back to `StoreScope::Project` when it is absent.
    #[serde(default)]
    pub scope: Option<String>,
}
119
/// Arguments of the `recall` tool, deserialized from the caller's JSON.
#[derive(Debug, Deserialize)]
pub struct RecallParams {
    /// Search text. `execute_recall` rejects values longer than 10,000 bytes.
    pub query: String,
    /// Optional result limit. `execute_recall` uses 5 when it is absent and
    /// clamps the value to at most 100.
    #[serde(default)]
    pub limit: Option<usize>,
}
126
/// Arguments of the `list` tool, deserialized from the caller's JSON.
#[derive(Debug, Deserialize)]
pub struct ListParams {
    /// Optional result limit. `execute_list` uses 10 when it is absent and
    /// clamps the value to at most 100.
    #[serde(default)]
    pub limit: Option<usize>,
    /// Optional scope name. `execute_list` parses it as a `ListScope` and
    /// falls back to `ListScope::Project` when it is absent.
    #[serde(default)]
    pub scope: Option<String>,
}
134
/// Arguments of the `forget` tool, deserialized from the caller's JSON.
#[derive(Debug, Deserialize)]
pub struct ForgetParams {
    /// Identifier of the item to delete.
    pub id: String,
}
139
/// Feature switches for [`recall_pipeline`].
///
/// Every switch defaults to `true` (see the `Default` impl below); callers can
/// disable individual stages with struct-update syntax, e.g.
/// `RecallConfig { enable_co_access: false, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecallConfig {
    /// When set, every search hit is registered in the graph store via
    /// `ensure_node_exists` before scoring.
    pub enable_graph_backfill: bool,
    /// When set, graph neighbours of the top five hits are returned as
    /// `graph_expanded` entries.
    pub enable_graph_expansion: bool,
    /// When set, items co-accessed with the top three hits are returned as
    /// `suggested` entries.
    pub enable_co_access: bool,
    /// When set, similarities are re-scored with `score_with_decay` plus a
    /// trust bonus, and the pre-decay values are kept in `raw_similarities`.
    pub enable_decay_scoring: bool,
    /// NOTE(review): not read by any code visible in this file; presumably
    /// consulted by callers that spawn background work — confirm.
    pub enable_background_tasks: bool,
}

impl Default for RecallConfig {
    /// Enables every pipeline stage.
    fn default() -> Self {
        Self {
            enable_graph_backfill: true,
            enable_graph_expansion: true,
            enable_co_access: true,
            enable_decay_scoring: true,
            enable_background_tasks: true,
        }
    }
}
163
/// Output of [`recall_pipeline`].
pub struct RecallResult {
    /// Search hits. When decay scoring is enabled their `similarity` holds the
    /// re-scored value and the list is sorted by it in descending order.
    pub results: Vec<crate::item::SearchResult>,
    /// JSON entries for graph neighbours of the top hits that were not already
    /// among `results`; empty when graph expansion is disabled.
    pub graph_expanded: Vec<Value>,
    /// JSON entries for items co-accessed with the top hits that were not
    /// already among `results`; empty when co-access suggestions are disabled.
    pub suggested: Vec<Value>,
    /// Item id -> similarity as returned by the search, recorded before decay
    /// scoring overwrote it. Only populated when decay scoring is enabled.
    pub raw_similarities: std::collections::HashMap<String, f32>,
}
172
173pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
176 let config = RetryConfig::default();
177 let args_for_retry = args.clone();
178
179 let result = with_retry(&config, || {
180 let ctx_ref = ctx;
181 let name_ref = name;
182 let args_clone = args_for_retry.clone();
183
184 async move {
185 let mut db = Database::open_with_embedder(
187 &ctx_ref.db_path,
188 ctx_ref.project_id.clone(),
189 ctx_ref.embedder.clone(),
190 )
191 .await
192 .map_err(|e| sanitize_err("Failed to open database", e))?;
193
194 let tracker = AccessTracker::open(&ctx_ref.access_db_path)
196 .map_err(|e| sanitize_err("Failed to open access tracker", e))?;
197
198 let graph = GraphStore::open(&ctx_ref.access_db_path)
200 .map_err(|e| sanitize_err("Failed to open graph store", e))?;
201
202 let result = match name_ref {
203 "store" => execute_store(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
204 "recall" => execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
205 "list" => execute_list(&mut db, args_clone).await,
206 "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
207 _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
208 };
209
210 if result.is_error.unwrap_or(false)
211 && let Some(content) = result.content.first()
212 && is_retryable_error(&content.text)
213 {
214 return Err(content.text.clone());
215 }
216
217 Ok(result)
218 }
219 })
220 .await;
221
222 match result {
223 Ok(call_result) => call_result,
224 Err(e) => {
225 tracing::error!("Operation failed after retries: {}", e);
226 CallToolResult::error("Operation failed after retries")
227 }
228 }
229}
230
/// Reports whether an error message looks like a transient failure.
///
/// The check is a case-insensitive substring match against a fixed list of
/// phrases (connection problems, timeouts, locks, I/O errors, open/connect
/// failures). Because it is a plain substring test, a short phrase such as
/// `lock` also matches inside longer words.
fn is_retryable_error(error_msg: &str) -> bool {
    // Stored in lowercase so only the incoming message needs lowering.
    const RETRYABLE_PATTERNS: [&str; 8] = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "i/o error",
        "failed to open",
        "failed to connect",
    ];

    let lower = error_msg.to_lowercase();
    RETRYABLE_PATTERNS.iter().any(|p| lower.contains(p))
}
248
249async fn execute_store(
252 db: &mut Database,
253 _tracker: &AccessTracker,
254 graph: &GraphStore,
255 ctx: &ServerContext,
256 args: Option<Value>,
257) -> CallToolResult {
258 let params: StoreParams = match args {
259 Some(v) => match serde_json::from_value(v) {
260 Ok(p) => p,
261 Err(e) => {
262 tracing::debug!("Parameter validation failed: {}", e);
263 return CallToolResult::error("Invalid parameters");
264 }
265 },
266 None => return CallToolResult::error("Missing parameters"),
267 };
268
269 const MAX_CONTENT_BYTES: usize = 1_000_000;
273 if params.content.len() > MAX_CONTENT_BYTES {
274 return CallToolResult::error(format!(
275 "Content too large: {} bytes (max {} bytes)",
276 params.content.len(),
277 MAX_CONTENT_BYTES
278 ));
279 }
280
281 let scope = params
283 .scope
284 .as_deref()
285 .map(|s| s.parse::<StoreScope>())
286 .transpose();
287
288 let scope = match scope {
289 Ok(s) => s.unwrap_or(StoreScope::Project),
290 Err(e) => return CallToolResult::error(e),
291 };
292
293 let mut item = Item::new(¶ms.content);
295
296 if scope == StoreScope::Project
298 && let Some(project_id) = db.project_id()
299 {
300 item = item.with_project_id(project_id);
301 }
302
303 match db.store_item(item).await {
304 Ok(store_result) => {
305 let new_id = store_result.id.clone();
306
307 let now = chrono::Utc::now().timestamp();
309 let project_id = db.project_id().map(|s| s.to_string());
310 if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
311 tracing::warn!("graph add_node failed: {}", e);
312 }
313
314 if !store_result.potential_conflicts.is_empty()
316 && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
317 {
318 for conflict in &store_result.potential_conflicts {
319 if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
320 {
321 tracing::warn!("enqueue consolidation failed: {}", e);
322 }
323 }
324 }
325
326 let mut result = json!({
327 "success": true,
328 "id": new_id,
329 "message": format!("Stored in {} scope", scope)
330 });
331
332 if !store_result.potential_conflicts.is_empty() {
333 let conflicts: Vec<Value> = store_result
334 .potential_conflicts
335 .iter()
336 .map(|c| {
337 json!({
338 "id": c.id,
339 "content": c.content,
340 "similarity": format!("{:.2}", c.similarity)
341 })
342 })
343 .collect();
344 result["potential_conflicts"] = json!(conflicts);
345 }
346
347 CallToolResult::success(
348 serde_json::to_string_pretty(&result)
349 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
350 )
351 }
352 Err(e) => sanitized_error("Failed to store item", e),
353 }
354}
355
356pub async fn recall_pipeline(
361 db: &mut Database,
362 tracker: &AccessTracker,
363 graph: &GraphStore,
364 query: &str,
365 limit: usize,
366 filters: ItemFilters,
367 config: &RecallConfig,
368) -> std::result::Result<RecallResult, String> {
369 let mut results = db
370 .search_items(query, limit, filters)
371 .await
372 .map_err(|e| format!("Search failed: {}", e))?;
373
374 if results.is_empty() {
375 return Ok(RecallResult {
376 results: Vec::new(),
377 graph_expanded: Vec::new(),
378 suggested: Vec::new(),
379 raw_similarities: std::collections::HashMap::new(),
380 });
381 }
382
383 if config.enable_graph_backfill {
385 for result in &results {
386 if let Err(e) = graph.ensure_node_exists(
387 &result.id,
388 result.project_id.as_deref(),
389 result.created_at.timestamp(),
390 ) {
391 tracing::warn!("ensure_node_exists failed: {}", e);
392 }
393 }
394 }
395
396 let mut raw_similarities: std::collections::HashMap<String, f32> =
398 std::collections::HashMap::new();
399 if config.enable_decay_scoring {
400 let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
401 let access_records = tracker.get_accesses(&item_ids).unwrap_or_default();
402 let validation_counts = tracker.get_validation_counts(&item_ids).unwrap_or_default();
403 let edge_counts = graph.get_edge_counts(&item_ids).unwrap_or_default();
404 let now = chrono::Utc::now().timestamp();
405
406 for result in &mut results {
407 raw_similarities.insert(result.id.clone(), result.similarity);
408
409 let created_at = result.created_at.timestamp();
410 let (access_count, last_accessed) = match access_records.get(&result.id) {
411 Some(rec) => (rec.access_count, Some(rec.last_accessed_at)),
412 None => (0, None),
413 };
414
415 let base_score = score_with_decay(
416 result.similarity,
417 now,
418 created_at,
419 access_count,
420 last_accessed,
421 );
422
423 let validation_count = validation_counts.get(&result.id).copied().unwrap_or(0);
424 let edge_count = edge_counts.get(&result.id).copied().unwrap_or(0);
425 let trust_bonus =
426 1.0 + 0.05 * (1.0 + validation_count as f64).ln() as f32 + 0.02 * edge_count as f32;
427
428 result.similarity = (base_score * trust_bonus).min(1.0);
429 }
430
431 results.sort_by(|a, b| {
432 b.similarity
433 .partial_cmp(&a.similarity)
434 .unwrap_or(std::cmp::Ordering::Equal)
435 });
436 }
437
438 for result in &results {
440 let created_at = result.created_at.timestamp();
441 if let Err(e) = tracker.record_access(&result.id, created_at) {
442 tracing::warn!("record_access failed: {}", e);
443 }
444 }
445
446 let existing_ids: std::collections::HashSet<String> =
448 results.iter().map(|r| r.id.clone()).collect();
449
450 let mut graph_expanded = Vec::new();
451 if config.enable_graph_expansion {
452 let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
453 if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
454 let neighbor_info: Vec<(String, String)> = neighbors
456 .into_iter()
457 .filter(|(id, _, _)| !existing_ids.contains(id))
458 .map(|(id, rel_type, _)| (id, rel_type))
459 .collect();
460
461 let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
462 if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
463 let item_map: std::collections::HashMap<&str, &Item> =
464 items.iter().map(|item| (item.id.as_str(), item)).collect();
465
466 for (neighbor_id, rel_type) in &neighbor_info {
467 if let Some(item) = item_map.get(neighbor_id.as_str()) {
468 let sr = crate::item::SearchResult::from_item(item, 0.05);
469 let mut entry = json!({
470 "id": sr.id,
471 "similarity": "graph",
472 "created": sr.created_at.to_rfc3339(),
473 "graph_expanded": true,
474 "rel_type": rel_type,
475 });
476 let same_project = match (db.project_id(), item.project_id.as_deref()) {
478 (Some(current), Some(item_pid)) => current == item_pid,
479 (_, None) => true,
480 _ => false,
481 };
482 if same_project {
483 entry["content"] = json!(sr.content);
484 } else {
485 entry["cross_project"] = json!(true);
486 }
487 graph_expanded.push(entry);
488 }
489 }
490 }
491 }
492 }
493
494 let mut suggested = Vec::new();
496 if config.enable_co_access {
497 let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
498 if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
499 let co_info: Vec<(String, i64)> = co_accessed
500 .into_iter()
501 .filter(|(id, _)| !existing_ids.contains(id))
502 .collect();
503
504 let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
505 if let Ok(items) = db.get_items_batch(&co_ids).await {
506 let item_map: std::collections::HashMap<&str, &Item> =
507 items.iter().map(|item| (item.id.as_str(), item)).collect();
508
509 for (co_id, co_count) in &co_info {
510 if let Some(item) = item_map.get(co_id.as_str()) {
511 let same_project = match (db.project_id(), item.project_id.as_deref()) {
512 (Some(current), Some(item_pid)) => current == item_pid,
513 (_, None) => true,
514 _ => false,
515 };
516 let mut entry = json!({
517 "id": item.id,
518 "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
519 });
520 if same_project {
521 entry["content"] = json!(truncate(&item.content, 100));
522 } else {
523 entry["cross_project"] = json!(true);
524 }
525 suggested.push(entry);
526 }
527 }
528 }
529 }
530 }
531
532 Ok(RecallResult {
533 results,
534 graph_expanded,
535 suggested,
536 raw_similarities,
537 })
538}
539
540async fn execute_recall(
541 db: &mut Database,
542 tracker: &AccessTracker,
543 graph: &GraphStore,
544 ctx: &ServerContext,
545 args: Option<Value>,
546) -> CallToolResult {
547 let params: RecallParams = match args {
548 Some(v) => match serde_json::from_value(v) {
549 Ok(p) => p,
550 Err(e) => {
551 tracing::debug!("Parameter validation failed: {}", e);
552 return CallToolResult::error("Invalid parameters");
553 }
554 },
555 None => return CallToolResult::error("Missing parameters"),
556 };
557
558 const MAX_QUERY_BYTES: usize = 10_000;
562 if params.query.len() > MAX_QUERY_BYTES {
563 return CallToolResult::error(format!(
564 "Query too large: {} bytes (max {} bytes)",
565 params.query.len(),
566 MAX_QUERY_BYTES
567 ));
568 }
569
570 let limit = params.limit.unwrap_or(5).min(100);
571
572 let filters = ItemFilters::new();
573
574 let config = RecallConfig::default();
575
576 let recall_result =
577 match recall_pipeline(db, tracker, graph, ¶ms.query, limit, filters, &config).await {
578 Ok(r) => r,
579 Err(e) => {
580 tracing::error!("Recall failed: {}", e);
581 return CallToolResult::error("Search failed");
582 }
583 };
584
585 if recall_result.results.is_empty() {
586 return CallToolResult::success("No items found matching your query.");
587 }
588
589 let results = &recall_result.results;
590
591 let all_result_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
593 let neighbors_map = graph
594 .get_neighbors_mapped(&all_result_ids, 0.5)
595 .unwrap_or_default();
596
597 let formatted: Vec<Value> = results
598 .iter()
599 .map(|r| {
600 let mut obj = json!({
601 "id": r.id,
602 "content": r.content,
603 "similarity": format!("{:.2}", r.similarity),
604 "created": r.created_at.to_rfc3339(),
605 });
606
607 if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
609 && (raw_sim - r.similarity).abs() > 0.001
610 {
611 obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
612 }
613
614 if let Some(ref excerpt) = r.relevant_excerpt {
615 obj["relevant_excerpt"] = json!(excerpt);
616 }
617
618 if let Some(ref current_pid) = ctx.project_id
620 && let Some(ref item_pid) = r.project_id
621 && item_pid != current_pid
622 {
623 obj["cross_project"] = json!(true);
624 }
625
626 if let Some(related) = neighbors_map.get(&r.id)
628 && !related.is_empty()
629 {
630 obj["related_ids"] = json!(related);
631 }
632
633 obj
634 })
635 .collect();
636
637 let mut result_json = json!({
638 "count": results.len(),
639 "results": formatted
640 });
641
642 if !recall_result.graph_expanded.is_empty() {
643 result_json["graph_expanded"] = json!(recall_result.graph_expanded);
644 }
645
646 if !recall_result.suggested.is_empty() {
647 result_json["suggested"] = json!(recall_result.suggested);
648 }
649
650 spawn_consolidation(
652 Arc::new(ctx.db_path.clone()),
653 Arc::new(ctx.access_db_path.clone()),
654 ctx.project_id.clone(),
655 ctx.embedder.clone(),
656 ctx.consolidation_semaphore.clone(),
657 );
658
659 let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
661 let access_db_path = ctx.access_db_path.clone();
662 spawn_logged("co_access", async move {
663 if let Ok(g) = GraphStore::open(&access_db_path) {
664 if let Err(e) = g.record_co_access(&result_ids) {
665 tracing::warn!("record_co_access failed: {}", e);
666 }
667 } else {
668 tracing::warn!("co_access: failed to open graph store");
669 }
670 });
671
672 let run_count = ctx
674 .recall_count
675 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
676 if run_count % 10 == 9 {
677 let access_db_path = ctx.access_db_path.clone();
679 spawn_logged("clustering", async move {
680 if let Ok(g) = GraphStore::open(&access_db_path)
681 && let Ok(clusters) = g.detect_clusters()
682 {
683 for (a, b, c) in &clusters {
684 let label = format!("cluster-{}", &a[..8.min(a.len())]);
685 if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
686 tracing::warn!("cluster add_related_edge failed: {}", e);
687 }
688 if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
689 tracing::warn!("cluster add_related_edge failed: {}", e);
690 }
691 if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
692 tracing::warn!("cluster add_related_edge failed: {}", e);
693 }
694 }
695 if !clusters.is_empty() {
696 tracing::info!("Detected {} clusters", clusters.len());
697 }
698 }
699 });
700
701 let access_db_path2 = ctx.access_db_path.clone();
703 spawn_logged("consolidation_cleanup", async move {
704 if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
705 && let Err(e) = q.cleanup_processed()
706 {
707 tracing::warn!("consolidation queue cleanup failed: {}", e);
708 }
709 });
710 }
711
712 CallToolResult::success(
713 serde_json::to_string_pretty(&result_json)
714 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
715 )
716}
717
718async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
719 let params: ListParams =
720 args.and_then(|v| serde_json::from_value(v).ok())
721 .unwrap_or(ListParams {
722 limit: None,
723 scope: None,
724 });
725
726 let limit = params.limit.unwrap_or(10).min(100);
727
728 let filters = ItemFilters::new();
729
730 let scope = params
731 .scope
732 .as_deref()
733 .map(|s| s.parse::<ListScope>())
734 .transpose();
735
736 let scope = match scope {
737 Ok(s) => s.unwrap_or(ListScope::Project),
738 Err(e) => return CallToolResult::error(e),
739 };
740
741 match db.list_items(filters, Some(limit), scope).await {
742 Ok(items) => {
743 if items.is_empty() {
744 CallToolResult::success("No items stored yet.")
745 } else {
746 let formatted: Vec<Value> = items
747 .iter()
748 .map(|item| {
749 let content_preview = truncate(&item.content, 100);
750 let mut obj = json!({
751 "id": item.id,
752 "content": content_preview,
753 "created": item.created_at.to_rfc3339(),
754 });
755
756 if item.is_chunked {
757 obj["chunked"] = json!(true);
758 }
759
760 obj
761 })
762 .collect();
763
764 let result = json!({
765 "count": items.len(),
766 "items": formatted
767 });
768
769 CallToolResult::success(
770 serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
771 format!("{{\"error\": \"serialization failed: {}\"}}", e)
772 }),
773 )
774 }
775 }
776 Err(e) => sanitized_error("Failed to list items", e),
777 }
778}
779
780async fn execute_forget(
781 db: &mut Database,
782 graph: &GraphStore,
783 ctx: &ServerContext,
784 args: Option<Value>,
785) -> CallToolResult {
786 let params: ForgetParams = match args {
787 Some(v) => match serde_json::from_value(v) {
788 Ok(p) => p,
789 Err(e) => {
790 tracing::debug!("Parameter validation failed: {}", e);
791 return CallToolResult::error("Invalid parameters");
792 }
793 },
794 None => return CallToolResult::error("Missing parameters"),
795 };
796
797 if let Some(ref current_pid) = ctx.project_id {
799 match db.get_item(¶ms.id).await {
800 Ok(Some(item)) => {
801 if let Some(ref item_pid) = item.project_id
802 && item_pid != current_pid
803 {
804 return CallToolResult::error(format!(
805 "Cannot delete item {} from a different project",
806 params.id
807 ));
808 }
809 }
810 Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
811 Err(e) => {
812 return sanitized_error("Failed to look up item", e);
813 }
814 }
815 }
816
817 match db.delete_item(¶ms.id).await {
818 Ok(true) => {
819 if let Err(e) = graph.remove_node(¶ms.id) {
821 tracing::warn!("remove_node failed: {}", e);
822 }
823
824 let result = json!({
825 "success": true,
826 "message": format!("Deleted item: {}", params.id)
827 });
828 CallToolResult::success(
829 serde_json::to_string_pretty(&result)
830 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
831 )
832 }
833 Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
834 Err(e) => sanitized_error("Failed to delete item", e),
835 }
836}
837
838fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
843 tracing::error!("{}: {}", context, err);
844 CallToolResult::error(context.to_string())
845}
846
847fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
849 tracing::error!("{}: {}", context, err);
850 context.to_string()
851}
852
/// Shortens `s` to at most `max_len` characters, marking the cut with `...`.
///
/// * Strings of `max_len` characters or fewer are returned unchanged.
/// * For `max_len <= 3` there is no room for an ellipsis, so the first
///   `max_len` characters are returned as-is.
/// * Otherwise the first `max_len - 3` characters are kept and `...` is
///   appended, giving exactly `max_len` characters.
///
/// Lengths are counted in `char`s and cuts always land on a character
/// boundary, so multi-byte text is never split mid-character. At most
/// `max_len + 1` characters are examined, so the cost does not grow with the
/// length of `s`.
fn truncate(s: &str, max_len: usize) -> String {
    // Byte offset of the first character that does not fit. If there is no
    // such character the whole string fits.
    let overflow_at = match s.char_indices().nth(max_len) {
        Some((idx, _)) => idx,
        None => return s.to_string(),
    };

    if max_len <= 3 {
        return s[..overflow_at].to_string();
    }

    // The string has more than `max_len` characters, so this index exists;
    // the fallback only keeps the expression total.
    let cut = s
        .char_indices()
        .nth(max_len - 3)
        .map_or(s.len(), |(idx, _)| idx);
    format!("{}...", &s[..cut])
}
868
#[cfg(test)]
mod tests {
    use super::*;

    use std::path::PathBuf;
    use std::sync::Mutex;
    use tokio::sync::Semaphore;

    /// `truncate` around the point where an ellipsis no longer fits.
    #[test]
    fn test_truncate_small_max_len() {
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("hello", 1), "h");
        assert_eq!(truncate("hello", 2), "he");
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 3), "hi");
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello!", 5), "he...");
        assert_eq!(truncate("", 5), "");
        assert_eq!(truncate("abcdef", 4), "a...");
    }

    /// `truncate` counts characters, not bytes, and never splits a character.
    #[test]
    fn test_truncate_unicode() {
        assert_eq!(truncate("héllo wörld", 5), "hé...");
        assert_eq!(truncate("日本語テスト", 4), "日...");
    }

    // NOTE(review): the tests below are `#[ignore]`d. Presumably this is
    // because `setup_test_context` builds a real `Embedder`, which may need
    // model assets that are not available in every environment — confirm and
    // record the reason on the attribute.

    /// Builds a `ServerContext` backed by a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the context so that it stays alive
    /// (and the directory stays on disk) for the duration of the test.
    async fn setup_test_context() -> (ServerContext, tempfile::TempDir) {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let access_db_path = tmp.path().join("access.db");

        let embedder = Arc::new(crate::Embedder::new().unwrap());
        let project_id = Some("test-project-00000001".to_string());

        let ctx = ServerContext {
            db_path,
            access_db_path,
            project_id,
            embedder,
            cwd: PathBuf::from("."),
            consolidation_semaphore: Arc::new(Semaphore::new(1)),
            recall_count: std::sync::atomic::AtomicU64::new(0),
            rate_limit: Mutex::new(super::super::server::RateLimitState {
                window_start_ms: 0,
                count: 0,
            }),
        };

        (ctx, tmp)
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_and_recall_roundtrip() {
        let (ctx, _tmp) = setup_test_context().await;

        let store_result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Rust is a systems programming language" })),
        )
        .await;
        assert!(
            store_result.is_error.is_none(),
            "Store should succeed: {:?}",
            store_result.content
        );

        let recall_result = execute_tool(
            &ctx,
            "recall",
            Some(json!({ "query": "systems programming language" })),
        )
        .await;
        assert!(recall_result.is_error.is_none(), "Recall should succeed");

        let text = &recall_result.content[0].text;
        assert!(
            text.contains("Rust is a systems programming language"),
            "Recall should return stored content, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_and_list() {
        let (ctx, _tmp) = setup_test_context().await;

        // Check each setup store so a storage failure is reported here rather
        // than as a confusing count mismatch below.
        for content in ["First item for listing", "Second item for listing"] {
            let stored = execute_tool(&ctx, "store", Some(json!({ "content": content }))).await;
            assert!(
                stored.is_error.is_none(),
                "Store should succeed: {:?}",
                stored.content
            );
        }

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        assert!(list_result.is_error.is_none(), "List should succeed");

        let text = &list_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert_eq!(parsed["count"], 2, "Should list 2 items");
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_conflict_detection() {
        let (ctx, _tmp) = setup_test_context().await;

        let first = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "The quick brown fox jumps over the lazy dog" })),
        )
        .await;
        assert!(
            first.is_error.is_none(),
            "First store should succeed: {:?}",
            first.content
        );

        // Storing the same text again should flag the first item as a conflict.
        let result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "The quick brown fox jumps over the lazy dog" })),
        )
        .await;
        assert!(result.is_error.is_none(), "Store should succeed");

        let text = &result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert!(
            parsed.get("potential_conflicts").is_some(),
            "Should detect conflict for near-duplicate content, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_forget_removes_item() {
        let (ctx, _tmp) = setup_test_context().await;

        let store_result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Item to be forgotten" })),
        )
        .await;
        assert!(
            store_result.is_error.is_none(),
            "Store should succeed: {:?}",
            store_result.content
        );
        let text = &store_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        let item_id = parsed["id"].as_str().unwrap().to_string();

        let forget_result = execute_tool(&ctx, "forget", Some(json!({ "id": item_id }))).await;
        assert!(forget_result.is_error.is_none(), "Forget should succeed");

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        let text = &list_result.content[0].text;
        assert!(
            text.contains("No items stored yet"),
            "Should have no items after forget, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recall_empty_db() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "recall", Some(json!({ "query": "anything" }))).await;
        assert!(
            result.is_error.is_none(),
            "Recall on empty DB should not error"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("No items found"),
            "Should indicate no items found, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_rejects_oversized_content() {
        let (ctx, _tmp) = setup_test_context().await;

        // Just over the 1,000,000-byte limit enforced by `execute_store`.
        let large_content = "x".repeat(1_100_000);
        let result = execute_tool(&ctx, "store", Some(json!({ "content": large_content }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized content"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recall_rejects_oversized_query() {
        let (ctx, _tmp) = setup_test_context().await;

        // Just over the 10,000-byte limit enforced by `execute_recall`.
        let large_query = "x".repeat(11_000);
        let result = execute_tool(&ctx, "recall", Some(json!({ "query": large_query }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized query"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_missing_params() {
        let (ctx, _tmp) = setup_test_context().await;

        // No arguments at all.
        let result = execute_tool(&ctx, "store", None).await;
        assert!(result.is_error == Some(true), "Should error with no params");

        // Arguments present but the required `content` field is missing.
        let result = execute_tool(&ctx, "store", Some(json!({}))).await;
        assert!(
            result.is_error == Some(true),
            "Should error with missing content"
        );
    }
}