use std::sync::Arc;

use serde::Deserialize;
use serde_json::{Value, json};

use crate::access::AccessTracker;
use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
use crate::db::score_with_decay;
use crate::graph::GraphStore;
use crate::item::{Item, ItemFilters};
use crate::retry::{RetryConfig, with_retry};
use crate::{Database, ListScope, StoreScope};

use super::protocol::{CallToolResult, Tool};
use super::server::ServerContext;

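/// Spawns `fut` as a background Tokio task and logs a panic instead of
/// propagating it, so background maintenance work cannot take a request down.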
fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
    tokio::spawn(async move {
        let result = tokio::task::spawn(fut).await;
        if let Err(e) = result {
            tracing::error!("Background task '{}' panicked: {:?}", name, e);
        }
    });
}

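/// Returns the tool definitions (`store`, `recall`, `list`, `forget`) advertised
/// by this server, including their JSON input schemas.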
pub fn get_tools() -> Vec<Tool> {
    let store_schema = {
        #[allow(unused_mut)]
        let mut props = json!({
            "content": {
                "type": "string",
                "description": "The content to store"
            },
            "scope": {
                "type": "string",
                "enum": ["project", "global"],
                "default": "project",
                "description": "Where to store: 'project' (current project) or 'global' (all projects)"
            }
        });

        #[cfg(feature = "bench")]
        {
            props.as_object_mut().unwrap().insert(
                "created_at".to_string(),
                json!({
                    "type": "number",
                    "description": "Override creation timestamp (Unix seconds). Benchmark builds only."
                }),
            );
        }

        json!({
            "type": "object",
            "properties": props,
            "required": ["content"]
        })
    };

    vec![
        Tool {
            name: "store".to_string(),
            description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
            input_schema: store_schema,
        },
        Tool {
            name: "recall".to_string(),
            description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "What to search for (semantic search)"
                    },
                    "limit": {
                        "type": "number",
                        "default": 5,
                        "description": "Maximum number of results"
                    }
                },
                "required": ["query"]
            }),
        },
        Tool {
            name: "list".to_string(),
            description: "List stored items.".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "limit": {
                        "type": "number",
                        "default": 10,
                        "description": "Maximum number of results"
                    },
                    "scope": {
                        "type": "string",
                        "enum": ["project", "global", "all"],
                        "default": "project",
                        "description": "Which items to list: 'project', 'global', or 'all'"
                    }
                }
            }),
        },
        Tool {
            name: "forget".to_string(),
            description: "Delete a stored item by its ID.".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "id": {
                        "type": "string",
                        "description": "The item ID to delete"
                    }
                },
                "required": ["id"]
            }),
        },
    ]
}

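/// Arguments for the `store` tool. `created_at` is only accepted when the
/// `bench` feature is enabled, so benchmark builds can override the timestamp.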
#[derive(Debug, Deserialize)]
pub struct StoreParams {
    pub content: String,
    #[serde(default)]
    pub scope: Option<String>,
    #[cfg(feature = "bench")]
    #[serde(default)]
    pub created_at: Option<i64>,
}

#[derive(Debug, Deserialize)]
pub struct RecallParams {
    pub query: String,
    #[serde(default)]
    pub limit: Option<usize>,
}

#[derive(Debug, Deserialize)]
pub struct ListParams {
    #[serde(default)]
    pub limit: Option<usize>,
    #[serde(default)]
    pub scope: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct ForgetParams {
    pub id: String,
}

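/// Feature toggles for the recall pipeline; the default configuration enables
/// every stage.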
pub struct RecallConfig {
    pub enable_graph_backfill: bool,
    pub enable_graph_expansion: bool,
    pub enable_co_access: bool,
    pub enable_decay_scoring: bool,
    pub enable_background_tasks: bool,
}

impl Default for RecallConfig {
    fn default() -> Self {
        Self {
            enable_graph_backfill: true,
            enable_graph_expansion: true,
            enable_co_access: true,
            enable_decay_scoring: true,
            enable_background_tasks: true,
        }
    }
}

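/// Output of [`recall_pipeline`]: scored results plus graph-expanded neighbors,
/// co-access suggestions, and the pre-decay similarity for each returned item.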
pub struct RecallResult {
    pub results: Vec<crate::item::SearchResult>,
    pub graph_expanded: Vec<Value>,
    pub suggested: Vec<Value>,
    pub raw_similarities: std::collections::HashMap<String, f32>,
}

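/// Dispatches a tool call by name, opening the database, access tracker, and
/// graph store fresh for each attempt and retrying errors that look transient.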
pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
    let config = RetryConfig::default();
    let args_for_retry = args.clone();

    let result = with_retry(&config, || {
        let ctx_ref = ctx;
        let name_ref = name;
        let args_clone = args_for_retry.clone();

        async move {
            let mut db = Database::open_with_embedder(
                &ctx_ref.db_path,
                ctx_ref.project_id.clone(),
                ctx_ref.embedder.clone(),
            )
            .await
            .map_err(|e| sanitize_err("Failed to open database", e))?;

            let tracker = AccessTracker::open(&ctx_ref.access_db_path)
                .map_err(|e| sanitize_err("Failed to open access tracker", e))?;

            let graph = GraphStore::open(&ctx_ref.access_db_path)
                .map_err(|e| sanitize_err("Failed to open graph store", e))?;

            let result = match name_ref {
                "store" => execute_store(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
                "recall" => execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
                "list" => execute_list(&mut db, args_clone).await,
                "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
                _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
            };

            if result.is_error.unwrap_or(false)
                && let Some(content) = result.content.first()
                && is_retryable_error(&content.text)
            {
                return Err(content.text.clone());
            }

            Ok(result)
        }
    })
    .await;

    match result {
        Ok(call_result) => call_result,
        Err(e) => {
            tracing::error!("Operation failed after retries: {}", e);
            CallToolResult::error("Operation failed after retries")
        }
    }
}

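/// Heuristic check for transient failures worth retrying: a case-insensitive
/// substring match against known error patterns.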
fn is_retryable_error(error_msg: &str) -> bool {
    let retryable_patterns = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "I/O error",
        "Failed to open",
        "Failed to connect",
    ];

    let lower = error_msg.to_lowercase();
    retryable_patterns
        .iter()
        .any(|p| lower.contains(&p.to_lowercase()))
}

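/// Handles the `store` tool: validates input, stores the item (scoped to the
/// current project unless `scope` is "global"), registers a graph node, and
/// queues any near-duplicate conflicts for consolidation.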
async fn execute_store(
    db: &mut Database,
    _tracker: &AccessTracker,
    graph: &GraphStore,
    ctx: &ServerContext,
    args: Option<Value>,
) -> CallToolResult {
    let params: StoreParams = match args {
        Some(v) => match serde_json::from_value(v) {
            Ok(p) => p,
            Err(e) => {
                tracing::debug!("Parameter validation failed: {}", e);
                return CallToolResult::error("Invalid parameters");
            }
        },
        None => return CallToolResult::error("Missing parameters"),
    };

    const MAX_CONTENT_BYTES: usize = 1_000_000;
    if params.content.len() > MAX_CONTENT_BYTES {
        return CallToolResult::error(format!(
            "Content too large: {} bytes (max {} bytes)",
            params.content.len(),
            MAX_CONTENT_BYTES
        ));
    }

    let scope = params
        .scope
        .as_deref()
        .map(|s| s.parse::<StoreScope>())
        .transpose();

    let scope = match scope {
        Ok(s) => s.unwrap_or(StoreScope::Project),
        Err(e) => return CallToolResult::error(e),
    };

    let mut item = Item::new(&params.content);

    #[cfg(feature = "bench")]
    if let Some(ts) = params.created_at {
        if let Some(dt) = chrono::DateTime::from_timestamp(ts, 0) {
            item = item.with_created_at(dt);
        }
    }

    if scope == StoreScope::Project
        && let Some(project_id) = db.project_id()
    {
        item = item.with_project_id(project_id);
    }

    match db.store_item(item).await {
        Ok(store_result) => {
            let new_id = store_result.id.clone();

            let now = chrono::Utc::now().timestamp();
            let project_id = db.project_id().map(|s| s.to_string());
            if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
                tracing::warn!("graph add_node failed: {}", e);
            }

            if !store_result.potential_conflicts.is_empty()
                && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
            {
                for conflict in &store_result.potential_conflicts {
                    if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
                    {
                        tracing::warn!("enqueue consolidation failed: {}", e);
                    }
                }
            }

            let mut result = json!({
                "success": true,
                "id": new_id,
                "message": format!("Stored in {} scope", scope)
            });

            if !store_result.potential_conflicts.is_empty() {
                let conflicts: Vec<Value> = store_result
                    .potential_conflicts
                    .iter()
                    .map(|c| {
                        json!({
                            "id": c.id,
                            "content": c.content,
                            "similarity": format!("{:.2}", c.similarity)
                        })
                    })
                    .collect();
                result["potential_conflicts"] = json!(conflicts);
            }

            CallToolResult::success(
                serde_json::to_string_pretty(&result)
                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
            )
        }
        Err(e) => sanitized_error("Failed to store item", e),
    }
}

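/// Core recall pipeline, shared by the `recall` tool: runs the semantic search,
/// backfills missing graph nodes, re-scores results with time decay plus a
/// validation/edge trust bonus, records accesses, then gathers graph-neighbor
/// expansions and co-access suggestions. Each stage can be switched off via
/// [`RecallConfig`].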
pub async fn recall_pipeline(
    db: &mut Database,
    tracker: &AccessTracker,
    graph: &GraphStore,
    query: &str,
    limit: usize,
    filters: ItemFilters,
    config: &RecallConfig,
) -> std::result::Result<RecallResult, String> {
    let mut results = db
        .search_items(query, limit, filters)
        .await
        .map_err(|e| format!("Search failed: {}", e))?;

    if results.is_empty() {
        return Ok(RecallResult {
            results: Vec::new(),
            graph_expanded: Vec::new(),
            suggested: Vec::new(),
            raw_similarities: std::collections::HashMap::new(),
        });
    }

    if config.enable_graph_backfill {
        for result in &results {
            if let Err(e) = graph.ensure_node_exists(
                &result.id,
                result.project_id.as_deref(),
                result.created_at.timestamp(),
            ) {
                tracing::warn!("ensure_node_exists failed: {}", e);
            }
        }
    }

    let mut raw_similarities: std::collections::HashMap<String, f32> =
        std::collections::HashMap::new();
    if config.enable_decay_scoring {
        let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        let access_records = tracker.get_accesses(&item_ids).unwrap_or_default();
        let validation_counts = tracker.get_validation_counts(&item_ids).unwrap_or_default();
        let edge_counts = graph.get_edge_counts(&item_ids).unwrap_or_default();
        let now = chrono::Utc::now().timestamp();

        for result in &mut results {
            raw_similarities.insert(result.id.clone(), result.similarity);

            let created_at = result.created_at.timestamp();
            let (access_count, last_accessed) = match access_records.get(&result.id) {
                Some(rec) => (rec.access_count, Some(rec.last_accessed_at)),
                None => (0, None),
            };

            let base_score = score_with_decay(
                result.similarity,
                now,
                created_at,
                access_count,
                last_accessed,
            );

            let validation_count = validation_counts.get(&result.id).copied().unwrap_or(0);
            let edge_count = edge_counts.get(&result.id).copied().unwrap_or(0);
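            // Boost well-connected, repeatedly validated items: a log-scaled
            // bonus per validation plus a small linear bonus per graph edge,
            // with the combined score capped at 1.0 below.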
            let trust_bonus =
                1.0 + 0.05 * (1.0 + validation_count as f64).ln() as f32 + 0.02 * edge_count as f32;

            result.similarity = (base_score * trust_bonus).min(1.0);
        }

        results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    for result in &results {
        let created_at = result.created_at.timestamp();
        if let Err(e) = tracker.record_access(&result.id, created_at) {
            tracing::warn!("record_access failed: {}", e);
        }
    }

    let existing_ids: std::collections::HashSet<String> =
        results.iter().map(|r| r.id.clone()).collect();

    let mut graph_expanded = Vec::new();
    if config.enable_graph_expansion {
        let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
        if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
            let neighbor_info: Vec<(String, String)> = neighbors
                .into_iter()
                .filter(|(id, _, _)| !existing_ids.contains(id))
                .map(|(id, rel_type, _)| (id, rel_type))
                .collect();

            let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
            if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
                let item_map: std::collections::HashMap<&str, &Item> =
                    items.iter().map(|item| (item.id.as_str(), item)).collect();

                for (neighbor_id, rel_type) in &neighbor_info {
                    if let Some(item) = item_map.get(neighbor_id.as_str()) {
                        let sr = crate::item::SearchResult::from_item(item, 0.05);
                        let mut entry = json!({
                            "id": sr.id,
                            "similarity": "graph",
                            "created": sr.created_at.to_rfc3339(),
                            "graph_expanded": true,
                            "rel_type": rel_type,
                        });
                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
                            (Some(current), Some(item_pid)) => current == item_pid,
                            (_, None) => true,
                            _ => false,
                        };
                        if same_project {
                            entry["content"] = json!(sr.content);
                        } else {
                            entry["cross_project"] = json!(true);
                        }
                        graph_expanded.push(entry);
                    }
                }
            }
        }
    }

    let mut suggested = Vec::new();
    if config.enable_co_access {
        let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
        if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
            let co_info: Vec<(String, i64)> = co_accessed
                .into_iter()
                .filter(|(id, _)| !existing_ids.contains(id))
                .collect();

            let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
            if let Ok(items) = db.get_items_batch(&co_ids).await {
                let item_map: std::collections::HashMap<&str, &Item> =
                    items.iter().map(|item| (item.id.as_str(), item)).collect();

                for (co_id, co_count) in &co_info {
                    if let Some(item) = item_map.get(co_id.as_str()) {
                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
                            (Some(current), Some(item_pid)) => current == item_pid,
                            (_, None) => true,
                            _ => false,
                        };
                        let mut entry = json!({
                            "id": item.id,
                            "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
                        });
                        if same_project {
                            entry["content"] = json!(truncate(&item.content, 100));
                        } else {
                            entry["cross_project"] = json!(true);
                        }
                        suggested.push(entry);
                    }
                }
            }
        }
    }

    Ok(RecallResult {
        results,
        graph_expanded,
        suggested,
        raw_similarities,
    })
}

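/// Handles the `recall` tool: runs [`recall_pipeline`], formats the results as
/// JSON, and kicks off background consolidation, co-access recording, and
/// periodic graph maintenance.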
async fn execute_recall(
    db: &mut Database,
    tracker: &AccessTracker,
    graph: &GraphStore,
    ctx: &ServerContext,
    args: Option<Value>,
) -> CallToolResult {
    let params: RecallParams = match args {
        Some(v) => match serde_json::from_value(v) {
            Ok(p) => p,
            Err(e) => {
                tracing::debug!("Parameter validation failed: {}", e);
                return CallToolResult::error("Invalid parameters");
            }
        },
        None => return CallToolResult::error("Missing parameters"),
    };

    const MAX_QUERY_BYTES: usize = 10_000;
    if params.query.len() > MAX_QUERY_BYTES {
        return CallToolResult::error(format!(
            "Query too large: {} bytes (max {} bytes)",
            params.query.len(),
            MAX_QUERY_BYTES
        ));
    }

    let limit = params.limit.unwrap_or(5).min(100);

    let filters = ItemFilters::new();

    let config = RecallConfig::default();

    let recall_result =
        match recall_pipeline(db, tracker, graph, &params.query, limit, filters, &config).await {
            Ok(r) => r,
            Err(e) => {
                tracing::error!("Recall failed: {}", e);
                return CallToolResult::error("Search failed");
            }
        };

    if recall_result.results.is_empty() {
        return CallToolResult::success("No items found matching your query.");
    }

    let results = &recall_result.results;

    let all_result_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
    let neighbors_map = graph
        .get_neighbors_mapped(&all_result_ids, 0.5)
        .unwrap_or_default();

    let formatted: Vec<Value> = results
        .iter()
        .map(|r| {
            let mut obj = json!({
                "id": r.id,
                "content": r.content,
                "similarity": format!("{:.2}", r.similarity),
                "created": r.created_at.to_rfc3339(),
            });

            if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
                && (raw_sim - r.similarity).abs() > 0.001
            {
                obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
            }

            if let Some(ref excerpt) = r.relevant_excerpt {
                obj["relevant_excerpt"] = json!(excerpt);
            }

            if let Some(ref current_pid) = ctx.project_id
                && let Some(ref item_pid) = r.project_id
                && item_pid != current_pid
            {
                obj["cross_project"] = json!(true);
            }

            if let Some(related) = neighbors_map.get(&r.id)
                && !related.is_empty()
            {
                obj["related_ids"] = json!(related);
            }

            obj
        })
        .collect();

    let mut result_json = json!({
        "count": results.len(),
        "results": formatted
    });

    if !recall_result.graph_expanded.is_empty() {
        result_json["graph_expanded"] = json!(recall_result.graph_expanded);
    }

    if !recall_result.suggested.is_empty() {
        result_json["suggested"] = json!(recall_result.suggested);
    }

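    // Kick off background consolidation of queued near-duplicate pairs,
    // bounded by the shared semaphore.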
    spawn_consolidation(
        Arc::new(ctx.db_path.clone()),
        Arc::new(ctx.access_db_path.clone()),
        ctx.project_id.clone(),
        ctx.embedder.clone(),
        ctx.consolidation_semaphore.clone(),
    );

    let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
    let access_db_path = ctx.access_db_path.clone();
    spawn_logged("co_access", async move {
        if let Ok(g) = GraphStore::open(&access_db_path) {
            if let Err(e) = g.record_co_access(&result_ids) {
                tracing::warn!("record_co_access failed: {}", e);
            }
        } else {
            tracing::warn!("co_access: failed to open graph store");
        }
    });

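    // Every 10th recall, run heavier maintenance in the background: cluster
    // detection on the graph plus cleanup of processed consolidation entries.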
    let run_count = ctx
        .recall_count
        .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    if run_count % 10 == 9 {
        let access_db_path = ctx.access_db_path.clone();
        spawn_logged("clustering", async move {
            if let Ok(g) = GraphStore::open(&access_db_path)
                && let Ok(clusters) = g.detect_clusters()
            {
                for (a, b, c) in &clusters {
                    let label = format!("cluster-{}", &a[..8.min(a.len())]);
                    if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
                        tracing::warn!("cluster add_related_edge failed: {}", e);
                    }
                    if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
                        tracing::warn!("cluster add_related_edge failed: {}", e);
                    }
                    if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
                        tracing::warn!("cluster add_related_edge failed: {}", e);
                    }
                }
                if !clusters.is_empty() {
                    tracing::info!("Detected {} clusters", clusters.len());
                }
            }
        });

        let access_db_path2 = ctx.access_db_path.clone();
        spawn_logged("consolidation_cleanup", async move {
            if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
                && let Err(e) = q.cleanup_processed()
            {
                tracing::warn!("consolidation queue cleanup failed: {}", e);
            }
        });
    }

    CallToolResult::success(
        serde_json::to_string_pretty(&result_json)
            .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
    )
}

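/// Handles the `list` tool: returns up to `limit` items in the requested scope,
/// with content previews truncated to 100 characters.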
async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
    let params: ListParams =
        args.and_then(|v| serde_json::from_value(v).ok())
            .unwrap_or(ListParams {
                limit: None,
                scope: None,
            });

    let limit = params.limit.unwrap_or(10).min(100);

    let filters = ItemFilters::new();

    let scope = params
        .scope
        .as_deref()
        .map(|s| s.parse::<ListScope>())
        .transpose();

    let scope = match scope {
        Ok(s) => s.unwrap_or(ListScope::Project),
        Err(e) => return CallToolResult::error(e),
    };

    match db.list_items(filters, Some(limit), scope).await {
        Ok(items) => {
            if items.is_empty() {
                CallToolResult::success("No items stored yet.")
            } else {
                let formatted: Vec<Value> = items
                    .iter()
                    .map(|item| {
                        let content_preview = truncate(&item.content, 100);
                        let mut obj = json!({
                            "id": item.id,
                            "content": content_preview,
                            "created": item.created_at.to_rfc3339(),
                        });

                        if item.is_chunked {
                            obj["chunked"] = json!(true);
                        }

                        obj
                    })
                    .collect();

                let result = json!({
                    "count": items.len(),
                    "items": formatted
                });

                CallToolResult::success(
                    serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
                        format!("{{\"error\": \"serialization failed: {}\"}}", e)
                    }),
                )
            }
        }
        Err(e) => sanitized_error("Failed to list items", e),
    }
}

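/// Handles the `forget` tool: refuses to delete items that belong to a
/// different project, then removes the item and its graph node.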
async fn execute_forget(
    db: &mut Database,
    graph: &GraphStore,
    ctx: &ServerContext,
    args: Option<Value>,
) -> CallToolResult {
    let params: ForgetParams = match args {
        Some(v) => match serde_json::from_value(v) {
            Ok(p) => p,
            Err(e) => {
                tracing::debug!("Parameter validation failed: {}", e);
                return CallToolResult::error("Invalid parameters");
            }
        },
        None => return CallToolResult::error("Missing parameters"),
    };

    if let Some(ref current_pid) = ctx.project_id {
        match db.get_item(&params.id).await {
            Ok(Some(item)) => {
                if let Some(ref item_pid) = item.project_id
                    && item_pid != current_pid
                {
                    return CallToolResult::error(format!(
                        "Cannot delete item {} from a different project",
                        params.id
                    ));
                }
            }
            Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
            Err(e) => {
                return sanitized_error("Failed to look up item", e);
            }
        }
    }

    match db.delete_item(&params.id).await {
        Ok(true) => {
            if let Err(e) = graph.remove_node(&params.id) {
                tracing::warn!("remove_node failed: {}", e);
            }

            let result = json!({
                "success": true,
                "message": format!("Deleted item: {}", params.id)
            });
            CallToolResult::success(
                serde_json::to_string_pretty(&result)
                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
            )
        }
        Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
        Err(e) => sanitized_error("Failed to delete item", e),
    }
}

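/// Logs the full error server-side and returns only the generic context string
/// to the client, so internal error details are not leaked in tool output.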
fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
    tracing::error!("{}: {}", context, err);
    CallToolResult::error(context.to_string())
}

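/// Same as [`sanitized_error`], but returns the sanitized message as a plain
/// `String` for use inside the retry closure.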
fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
    tracing::error!("{}: {}", context, err);
    context.to_string()
}

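/// Truncates `s` to at most `max_len` characters, appending "..." when content
/// is cut (e.g. `truncate("hello!", 5)` yields `"he..."`). Cuts on char
/// boundaries, so multi-byte UTF-8 input is safe.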
fn truncate(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        s.to_string()
    } else if max_len <= 3 {
        s.chars().take(max_len).collect()
    } else {
        let cut = s
            .char_indices()
            .nth(max_len - 3)
            .map(|(i, _)| i)
            .unwrap_or(s.len());
        format!("{}...", &s[..cut])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_small_max_len() {
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("hello", 1), "h");
        assert_eq!(truncate("hello", 2), "he");
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 3), "hi");
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello!", 5), "he...");
    }

    #[test]
    fn test_truncate_unicode() {
        assert_eq!(truncate("héllo wörld", 5), "hé...");
        assert_eq!(truncate("日本語テスト", 4), "日...");
    }

    use std::path::PathBuf;
    use std::sync::Mutex;
    use tokio::sync::Semaphore;

    async fn setup_test_context() -> (ServerContext, tempfile::TempDir) {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let access_db_path = tmp.path().join("access.db");

        let embedder = Arc::new(crate::Embedder::new().unwrap());
        let project_id = Some("test-project-00000001".to_string());

        let ctx = ServerContext {
            db_path,
            access_db_path,
            project_id,
            embedder,
            cwd: PathBuf::from("."),
            consolidation_semaphore: Arc::new(Semaphore::new(1)),
            recall_count: std::sync::atomic::AtomicU64::new(0),
            rate_limit: Mutex::new(super::super::server::RateLimitState {
                window_start_ms: 0,
                count: 0,
            }),
        };

        (ctx, tmp)
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_and_recall_roundtrip() {
        let (ctx, _tmp) = setup_test_context().await;

        let store_result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Rust is a systems programming language" })),
        )
        .await;
        assert!(
            store_result.is_error.is_none(),
            "Store should succeed: {:?}",
            store_result.content
        );

        let recall_result = execute_tool(
            &ctx,
            "recall",
            Some(json!({ "query": "systems programming language" })),
        )
        .await;
        assert!(recall_result.is_error.is_none(), "Recall should succeed");

        let text = &recall_result.content[0].text;
        assert!(
            text.contains("Rust is a systems programming language"),
            "Recall should return stored content, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_and_list() {
        let (ctx, _tmp) = setup_test_context().await;

        execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "First item for listing" })),
        )
        .await;
        execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Second item for listing" })),
        )
        .await;

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        assert!(list_result.is_error.is_none(), "List should succeed");

        let text = &list_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert_eq!(parsed["count"], 2, "Should list 2 items");
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_conflict_detection() {
        let (ctx, _tmp) = setup_test_context().await;

        execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "The quick brown fox jumps over the lazy dog" })),
        )
        .await;

        let result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "The quick brown fox jumps over the lazy dog" })),
        )
        .await;
        assert!(result.is_error.is_none(), "Store should succeed");

        let text = &result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert!(
            parsed.get("potential_conflicts").is_some(),
            "Should detect conflict for near-duplicate content, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_forget_removes_item() {
        let (ctx, _tmp) = setup_test_context().await;

        let store_result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Item to be forgotten" })),
        )
        .await;
        let text = &store_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        let item_id = parsed["id"].as_str().unwrap().to_string();

        let forget_result = execute_tool(&ctx, "forget", Some(json!({ "id": item_id }))).await;
        assert!(forget_result.is_error.is_none(), "Forget should succeed");

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        let text = &list_result.content[0].text;
        assert!(
            text.contains("No items stored yet"),
            "Should have no items after forget, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recall_empty_db() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "recall", Some(json!({ "query": "anything" }))).await;
        assert!(
            result.is_error.is_none(),
            "Recall on empty DB should not error"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("No items found"),
            "Should indicate no items found, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_rejects_oversized_content() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_content = "x".repeat(1_100_000);
        let result = execute_tool(&ctx, "store", Some(json!({ "content": large_content }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized content"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recall_rejects_oversized_query() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_query = "x".repeat(11_000);
        let result = execute_tool(&ctx, "recall", Some(json!({ "query": large_query }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized query"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_store_missing_params() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "store", None).await;
        assert!(result.is_error == Some(true), "Should error with no params");

        let result = execute_tool(&ctx, "store", Some(json!({}))).await;
        assert!(
            result.is_error == Some(true),
            "Should error with missing content"
        );
    }
}