Skip to main content

sediment/mcp/
tools.rs

1//! MCP Tool definitions for Sediment
2//!
3//! 4 tools: store, recall, list, forget
4
5use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::{Value, json};
9
10use crate::access::AccessTracker;
11use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
12use crate::db::score_with_decay;
13use crate::graph::GraphStore;
14use crate::item::{Item, ItemFilters};
15use crate::retry::{RetryConfig, with_retry};
16use crate::{Database, ListScope, StoreScope};
17
18use super::protocol::{CallToolResult, Tool};
19use super::server::ServerContext;
20
21/// Spawn a background task with panic logging. If the task panics, the panic
22/// is caught and logged as an error instead of silently disappearing.
23fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
24    tokio::spawn(async move {
25        let result = tokio::task::spawn(fut).await;
26        if let Err(e) = result {
27            tracing::error!("Background task '{}' panicked: {:?}", name, e);
28        }
29    });
30}
31
32/// Get all available tools (4 total)
33pub fn get_tools() -> Vec<Tool> {
34    let store_schema = {
35        #[allow(unused_mut)]
36        let mut props = json!({
37            "content": {
38                "type": "string",
39                "description": "The content to store"
40            },
41            "scope": {
42                "type": "string",
43                "enum": ["project", "global"],
44                "default": "project",
45                "description": "Where to store: 'project' (current project) or 'global' (all projects)"
46            }
47        });
48
49        #[cfg(feature = "bench")]
50        {
51            props.as_object_mut().unwrap().insert(
52                "created_at".to_string(),
53                json!({
54                    "type": "number",
55                    "description": "Override creation timestamp (Unix seconds). Benchmark builds only."
56                }),
57            );
58        }
59
60        json!({
61            "type": "object",
62            "properties": props,
63            "required": ["content"]
64        })
65    };
66
67    vec![
68        Tool {
69            name: "store".to_string(),
70            description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
71            input_schema: store_schema,
72        },
73        Tool {
74            name: "recall".to_string(),
75            description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
76            input_schema: json!({
77                "type": "object",
78                "properties": {
79                    "query": {
80                        "type": "string",
81                        "description": "What to search for (semantic search)"
82                    },
83                    "limit": {
84                        "type": "number",
85                        "default": 5,
86                        "description": "Maximum number of results"
87                    }
88                },
89                "required": ["query"]
90            }),
91        },
92        Tool {
93            name: "list".to_string(),
94            description: "List stored items.".to_string(),
95            input_schema: json!({
96                "type": "object",
97                "properties": {
98                    "limit": {
99                        "type": "number",
100                        "default": 10,
101                        "description": "Maximum number of results"
102                    },
103                    "scope": {
104                        "type": "string",
105                        "enum": ["project", "global", "all"],
106                        "default": "project",
107                        "description": "Which items to list: 'project', 'global', or 'all'"
108                    }
109                }
110            }),
111        },
112        Tool {
113            name: "forget".to_string(),
114            description: "Delete a stored item by its ID.".to_string(),
115            input_schema: json!({
116                "type": "object",
117                "properties": {
118                    "id": {
119                        "type": "string",
120                        "description": "The item ID to delete"
121                    }
122                },
123                "required": ["id"]
124            }),
125        },
126    ]
127}
128
129// ========== Parameter Structs ==========
130
131#[derive(Debug, Deserialize)]
132pub struct StoreParams {
133    pub content: String,
134    #[serde(default)]
135    pub scope: Option<String>,
136    /// Override creation timestamp (Unix seconds). Benchmark builds only.
137    #[cfg(feature = "bench")]
138    #[serde(default)]
139    pub created_at: Option<i64>,
140}
141
142#[derive(Debug, Deserialize)]
143pub struct RecallParams {
144    pub query: String,
145    #[serde(default)]
146    pub limit: Option<usize>,
147}
148
149#[derive(Debug, Deserialize)]
150pub struct ListParams {
151    #[serde(default)]
152    pub limit: Option<usize>,
153    #[serde(default)]
154    pub scope: Option<String>,
155}
156
157#[derive(Debug, Deserialize)]
158pub struct ForgetParams {
159    pub id: String,
160}
161
162// ========== Recall Configuration ==========
163
/// Controls which graph and scoring features are enabled during recall.
///
/// Used by benchmarks to measure the impact of individual features. The
/// [`Default`] configuration enables every feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecallConfig {
    /// Add a graph node for every search hit (lazy backfill).
    pub enable_graph_backfill: bool,
    /// Attach 1-hop graph neighbours of the top hits to the recall result.
    pub enable_graph_expansion: bool,
    /// Attach items frequently co-accessed with the top hits as suggestions.
    pub enable_co_access: bool,
    /// Re-rank hits by decay scoring plus a trust bonus, keeping the raw similarities.
    pub enable_decay_scoring: bool,
    /// Flag for fire-and-forget background work after a recall.
    /// NOTE: not read by `recall_pipeline` itself.
    pub enable_background_tasks: bool,
}

impl Default for RecallConfig {
    /// Every feature enabled.
    fn default() -> Self {
        Self {
            enable_graph_backfill: true,
            enable_graph_expansion: true,
            enable_co_access: true,
            enable_decay_scoring: true,
            enable_background_tasks: true,
        }
    }
}
185
186/// Result of a recall pipeline execution (for benchmark consumption).
187pub struct RecallResult {
188    pub results: Vec<crate::item::SearchResult>,
189    pub graph_expanded: Vec<Value>,
190    pub suggested: Vec<Value>,
191    /// Raw (pre-decay/trust) similarity scores, keyed by item ID
192    pub raw_similarities: std::collections::HashMap<String, f32>,
193}
194
195// ========== Tool Execution ==========
196
197pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
198    let config = RetryConfig::default();
199    let args_for_retry = args.clone();
200
201    let result = with_retry(&config, || {
202        let ctx_ref = ctx;
203        let name_ref = name;
204        let args_clone = args_for_retry.clone();
205
206        async move {
207            // Open fresh connection with shared embedder
208            let mut db = Database::open_with_embedder(
209                &ctx_ref.db_path,
210                ctx_ref.project_id.clone(),
211                ctx_ref.embedder.clone(),
212            )
213            .await
214            .map_err(|e| sanitize_err("Failed to open database", e))?;
215
216            // Open graph store (shares access.db) — needed by store, recall, forget
217            let graph = GraphStore::open(&ctx_ref.access_db_path)
218                .map_err(|e| sanitize_err("Failed to open graph store", e))?;
219
220            let result = match name_ref {
221                "store" => execute_store(&mut db, &graph, ctx_ref, args_clone).await,
222                "recall" => {
223                    // Access tracker only needed for recall (decay scoring)
224                    let tracker = AccessTracker::open(&ctx_ref.access_db_path)
225                        .map_err(|e| sanitize_err("Failed to open access tracker", e))?;
226                    execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await
227                }
228                "list" => execute_list(&mut db, args_clone).await,
229                "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
230                _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
231            };
232
233            if result.is_error.unwrap_or(false)
234                && let Some(content) = result.content.first()
235                && is_retryable_error(&content.text)
236            {
237                return Err(content.text.clone());
238            }
239
240            Ok(result)
241        }
242    })
243    .await;
244
245    match result {
246        Ok(call_result) => call_result,
247        Err(e) => {
248            tracing::error!("Operation failed after retries: {}", e);
249            CallToolResult::error("Operation failed after retries")
250        }
251    }
252}
253
/// Heuristically classify an error message as transient, i.e. worth retrying.
///
/// The message is lowercased and searched for known transient-failure phrases:
/// connection problems, timeouts, busy or locked resources, I/O errors, and
/// failures to open or connect.
///
/// A phrase only counts when it begins at a word start (the beginning of the
/// message, or right after a non-alphanumeric character). This stops incidental
/// substrings from triggering pointless retries — e.g. the "lock" inside "block"
/// or "clock", which can show up in user-supplied IDs echoed back in messages such
/// as "Item not found: block-1".
fn is_retryable_error(error_msg: &str) -> bool {
    // Stored lowercase so only the message needs converting. "deadlock" is listed
    // explicitly because word-start matching no longer finds "lock" inside it.
    const RETRYABLE_PATTERNS: [&str; 9] = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "deadlock",
        "i/o error",
        "failed to open",
        "failed to connect",
    ];

    let lower = error_msg.to_lowercase();

    // Walk every character, tracking whether the previous one was part of a word,
    // and test the patterns only at positions where a new word begins.
    let mut prev_is_word_char = false;
    lower.char_indices().any(|(idx, ch)| {
        let at_word_start = !prev_is_word_char;
        prev_is_word_char = ch.is_alphanumeric();
        at_word_start
            && RETRYABLE_PATTERNS
                .iter()
                .any(|pattern| lower[idx..].starts_with(pattern))
    })
}
271
272// ========== Tool Implementations ==========
273
274async fn execute_store(
275    db: &mut Database,
276    graph: &GraphStore,
277    ctx: &ServerContext,
278    args: Option<Value>,
279) -> CallToolResult {
280    let params: StoreParams = match args {
281        Some(v) => match serde_json::from_value(v) {
282            Ok(p) => p,
283            Err(e) => {
284                tracing::debug!("Parameter validation failed: {}", e);
285                return CallToolResult::error("Invalid parameters");
286            }
287        },
288        None => return CallToolResult::error("Missing parameters"),
289    };
290
291    if params.content.trim().is_empty() {
292        return CallToolResult::error("Content must not be empty");
293    }
294
295    // Reject oversized content to prevent OOM during embedding/chunking.
296    // Intentionally byte-based (not char-based): memory allocation is proportional
297    // to byte length, so this is the correct metric for OOM prevention.
298    const MAX_CONTENT_BYTES: usize = 1_000_000;
299    if params.content.len() > MAX_CONTENT_BYTES {
300        return CallToolResult::error(format!(
301            "Content too large: {} bytes (max {} bytes)",
302            params.content.len(),
303            MAX_CONTENT_BYTES
304        ));
305    }
306
307    // Parse scope
308    let scope = params
309        .scope
310        .as_deref()
311        .map(|s| s.parse::<StoreScope>())
312        .transpose();
313
314    let scope = match scope {
315        Ok(s) => s.unwrap_or(StoreScope::Project),
316        Err(e) => return CallToolResult::error(e),
317    };
318
319    // Build item
320    let mut item = Item::new(&params.content);
321
322    // Override created_at if provided (benchmark builds only)
323    #[cfg(feature = "bench")]
324    if let Some(ts) = params.created_at {
325        if let Some(dt) = chrono::DateTime::from_timestamp(ts, 0) {
326            item = item.with_created_at(dt);
327        }
328    }
329
330    // Set project_id based on scope
331    if scope == StoreScope::Project
332        && let Some(project_id) = db.project_id()
333    {
334        item = item.with_project_id(project_id);
335    }
336
337    match db.store_item(item).await {
338        Ok(store_result) => {
339            let new_id = store_result.id.clone();
340
341            // Create graph node
342            let now = chrono::Utc::now().timestamp();
343            let project_id = db.project_id().map(|s| s.to_string());
344            if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
345                tracing::warn!("graph add_node failed: {}", e);
346            }
347
348            // Enqueue consolidation candidates from conflicts
349            if !store_result.potential_conflicts.is_empty()
350                && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
351            {
352                for conflict in &store_result.potential_conflicts {
353                    if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
354                    {
355                        tracing::warn!("enqueue consolidation failed: {}", e);
356                    }
357                }
358            }
359
360            let mut result = json!({
361                "success": true,
362                "id": new_id,
363                "message": format!("Stored in {} scope", scope)
364            });
365
366            if !store_result.potential_conflicts.is_empty() {
367                let conflicts: Vec<Value> = store_result
368                    .potential_conflicts
369                    .iter()
370                    .map(|c| {
371                        json!({
372                            "id": c.id,
373                            "content": c.content,
374                            "similarity": format!("{:.2}", c.similarity)
375                        })
376                    })
377                    .collect();
378                result["potential_conflicts"] = json!(conflicts);
379            }
380
381            CallToolResult::success(
382                serde_json::to_string_pretty(&result)
383                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
384            )
385        }
386        Err(e) => sanitized_error("Failed to store item", e),
387    }
388}
389
390/// Core recall pipeline, extracted for benchmarking.
391///
392/// Performs: vector search, optional decay scoring, optional graph backfill,
393/// optional 1-hop graph expansion, and optional co-access suggestions.
394pub async fn recall_pipeline(
395    db: &mut Database,
396    tracker: &AccessTracker,
397    graph: &GraphStore,
398    query: &str,
399    limit: usize,
400    filters: ItemFilters,
401    config: &RecallConfig,
402) -> std::result::Result<RecallResult, String> {
403    let mut results = db
404        .search_items(query, limit, filters)
405        .await
406        .map_err(|e| format!("Search failed: {}", e))?;
407
408    if results.is_empty() {
409        return Ok(RecallResult {
410            results: Vec::new(),
411            graph_expanded: Vec::new(),
412            suggested: Vec::new(),
413            raw_similarities: std::collections::HashMap::new(),
414        });
415    }
416
417    // Lazy graph backfill (uses project_id from SearchResult, no extra queries)
418    if config.enable_graph_backfill {
419        for result in &results {
420            if let Err(e) = graph.add_node(
421                &result.id,
422                result.project_id.as_deref(),
423                result.created_at.timestamp(),
424            ) {
425                tracing::warn!("graph backfill failed: {}", e);
426            }
427        }
428    }
429
430    // Decay scoring — preserve raw similarity for transparency
431    let mut raw_similarities: std::collections::HashMap<String, f32> =
432        std::collections::HashMap::new();
433    if config.enable_decay_scoring {
434        let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
435        let decay_data = tracker.get_decay_data(&item_ids).unwrap_or_default();
436        let edge_counts = graph.get_edge_counts(&item_ids).unwrap_or_default();
437        let now = chrono::Utc::now().timestamp();
438
439        for result in &mut results {
440            raw_similarities.insert(result.id.clone(), result.similarity);
441
442            let created_at = result.created_at.timestamp();
443            let (access_count, last_accessed, validation_count) = match decay_data.get(&result.id) {
444                Some(data) => (
445                    data.access_count,
446                    Some(data.last_accessed_at),
447                    data.validation_count,
448                ),
449                None => (0, None, 0),
450            };
451
452            let base_score = score_with_decay(
453                result.similarity,
454                now,
455                created_at,
456                access_count,
457                last_accessed,
458            );
459
460            let edge_count = edge_counts.get(&result.id).copied().unwrap_or(0);
461            let trust_bonus = 1.0
462                + 0.05 * (1.0 + validation_count as f64).ln() as f32
463                + 0.005 * edge_count as f32;
464
465            result.similarity = (base_score * trust_bonus).min(1.0);
466        }
467
468        results.sort_by(|a, b| {
469            b.similarity
470                .partial_cmp(&a.similarity)
471                .unwrap_or(std::cmp::Ordering::Equal)
472        });
473    }
474
475    // Record access
476    for result in &results {
477        let created_at = result.created_at.timestamp();
478        if let Err(e) = tracker.record_access(&result.id, created_at) {
479            tracing::warn!("record_access failed: {}", e);
480        }
481    }
482
483    // Graph expansion
484    let existing_ids: std::collections::HashSet<String> =
485        results.iter().map(|r| r.id.clone()).collect();
486
487    let mut graph_expanded = Vec::new();
488    if config.enable_graph_expansion {
489        let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
490        if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
491            // Collect neighbor IDs not already in results, then batch fetch
492            let neighbor_info: Vec<(String, String)> = neighbors
493                .into_iter()
494                .filter(|(id, _, _)| !existing_ids.contains(id))
495                .map(|(id, rel_type, _)| (id, rel_type))
496                .collect();
497
498            let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
499            if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
500                let item_map: std::collections::HashMap<&str, &Item> =
501                    items.iter().map(|item| (item.id.as_str(), item)).collect();
502
503                for (neighbor_id, rel_type) in &neighbor_info {
504                    if let Some(item) = item_map.get(neighbor_id.as_str()) {
505                        let sr = crate::item::SearchResult::from_item(item, 0.05);
506                        let mut entry = json!({
507                            "id": sr.id,
508                            "similarity": "graph",
509                            "created": sr.created_at.to_rfc3339(),
510                            "graph_expanded": true,
511                            "rel_type": rel_type,
512                        });
513                        // Only include content for same-project or global items
514                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
515                            (Some(current), Some(item_pid)) => current == item_pid,
516                            (_, None) => true,
517                            _ => false,
518                        };
519                        if same_project {
520                            entry["content"] = json!(sr.content);
521                        } else {
522                            entry["cross_project"] = json!(true);
523                        }
524                        graph_expanded.push(entry);
525                    }
526                }
527            }
528        }
529    }
530
531    // Co-access suggestions (batch fetch)
532    let mut suggested = Vec::new();
533    if config.enable_co_access {
534        let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
535        if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
536            let co_info: Vec<(String, i64)> = co_accessed
537                .into_iter()
538                .filter(|(id, _)| !existing_ids.contains(id))
539                .collect();
540
541            let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
542            if let Ok(items) = db.get_items_batch(&co_ids).await {
543                let item_map: std::collections::HashMap<&str, &Item> =
544                    items.iter().map(|item| (item.id.as_str(), item)).collect();
545
546                for (co_id, co_count) in &co_info {
547                    if let Some(item) = item_map.get(co_id.as_str()) {
548                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
549                            (Some(current), Some(item_pid)) => current == item_pid,
550                            (_, None) => true,
551                            _ => false,
552                        };
553                        let mut entry = json!({
554                            "id": item.id,
555                            "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
556                        });
557                        if same_project {
558                            entry["content"] = json!(truncate(&item.content, 100));
559                        } else {
560                            entry["cross_project"] = json!(true);
561                        }
562                        suggested.push(entry);
563                    }
564                }
565            }
566        }
567    }
568
569    Ok(RecallResult {
570        results,
571        graph_expanded,
572        suggested,
573        raw_similarities,
574    })
575}
576
577async fn execute_recall(
578    db: &mut Database,
579    tracker: &AccessTracker,
580    graph: &GraphStore,
581    ctx: &ServerContext,
582    args: Option<Value>,
583) -> CallToolResult {
584    let params: RecallParams = match args {
585        Some(v) => match serde_json::from_value(v) {
586            Ok(p) => p,
587            Err(e) => {
588                tracing::debug!("Parameter validation failed: {}", e);
589                return CallToolResult::error("Invalid parameters");
590            }
591        },
592        None => return CallToolResult::error("Missing parameters"),
593    };
594
595    if params.query.trim().is_empty() {
596        return CallToolResult::error("Query must not be empty");
597    }
598
599    // Reject oversized queries to prevent OOM during tokenization.
600    // The model truncates to 512 tokens (~2KB of English text), so anything
601    // beyond 10KB is wasted processing.
602    const MAX_QUERY_BYTES: usize = 10_000;
603    if params.query.len() > MAX_QUERY_BYTES {
604        return CallToolResult::error(format!(
605            "Query too large: {} bytes (max {} bytes)",
606            params.query.len(),
607            MAX_QUERY_BYTES
608        ));
609    }
610
611    let limit = params.limit.unwrap_or(5).min(100);
612
613    let filters = ItemFilters::new();
614
615    let config = RecallConfig::default();
616
617    let recall_result =
618        match recall_pipeline(db, tracker, graph, &params.query, limit, filters, &config).await {
619            Ok(r) => r,
620            Err(e) => {
621                tracing::error!("Recall failed: {}", e);
622                return CallToolResult::error("Search failed");
623            }
624        };
625
626    if recall_result.results.is_empty() {
627        return CallToolResult::success("No items found matching your query.");
628    }
629
630    let results = &recall_result.results;
631
632    // Batch-fetch neighbors for all result IDs
633    let all_result_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
634    let neighbors_map = graph
635        .get_neighbors_mapped(&all_result_ids, 0.5)
636        .unwrap_or_default();
637
638    let formatted: Vec<Value> = results
639        .iter()
640        .map(|r| {
641            let mut obj = json!({
642                "id": r.id,
643                "content": r.content,
644                "similarity": format!("{:.2}", r.similarity),
645                "created": r.created_at.to_rfc3339(),
646            });
647
648            // Include raw (pre-decay) similarity when decay scoring was applied
649            if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
650                && (raw_sim - r.similarity).abs() > 0.001
651            {
652                obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
653            }
654
655            if let Some(ref excerpt) = r.relevant_excerpt {
656                obj["relevant_excerpt"] = json!(excerpt);
657            }
658
659            // Cross-project flag
660            if let Some(ref current_pid) = ctx.project_id
661                && let Some(ref item_pid) = r.project_id
662                && item_pid != current_pid
663            {
664                obj["cross_project"] = json!(true);
665            }
666
667            // Related IDs from graph (batch lookup)
668            if let Some(related) = neighbors_map.get(&r.id)
669                && !related.is_empty()
670            {
671                obj["related_ids"] = json!(related);
672            }
673
674            obj
675        })
676        .collect();
677
678    let mut result_json = json!({
679        "count": results.len(),
680        "results": formatted
681    });
682
683    if !recall_result.graph_expanded.is_empty() {
684        result_json["graph_expanded"] = json!(recall_result.graph_expanded);
685    }
686
687    if !recall_result.suggested.is_empty() {
688        result_json["suggested"] = json!(recall_result.suggested);
689    }
690
691    // Fire-and-forget: background consolidation (Phase 2b)
692    spawn_consolidation(
693        Arc::new(ctx.db_path.clone()),
694        Arc::new(ctx.access_db_path.clone()),
695        ctx.project_id.clone(),
696        ctx.embedder.clone(),
697        ctx.consolidation_semaphore.clone(),
698    );
699
700    // Fire-and-forget: co-access recording (Phase 3a)
701    let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
702    let access_db_path = ctx.access_db_path.clone();
703    spawn_logged("co_access", async move {
704        if let Ok(g) = GraphStore::open(&access_db_path) {
705            if let Err(e) = g.record_co_access(&result_ids) {
706                tracing::warn!("record_co_access failed: {}", e);
707            }
708        } else {
709            tracing::warn!("co_access: failed to open graph store");
710        }
711    });
712
713    // Periodic maintenance: every 10th recall
714    let run_count = ctx
715        .recall_count
716        .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
717    if run_count % 10 == 9 {
718        // Clustering
719        let access_db_path = ctx.access_db_path.clone();
720        spawn_logged("clustering", async move {
721            if let Ok(g) = GraphStore::open(&access_db_path)
722                && let Ok(clusters) = g.detect_clusters()
723            {
724                for (a, b, c) in &clusters {
725                    let label = format!("cluster-{}", &a[..8.min(a.len())]);
726                    if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
727                        tracing::warn!("cluster add_related_edge failed: {}", e);
728                    }
729                    if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
730                        tracing::warn!("cluster add_related_edge failed: {}", e);
731                    }
732                    if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
733                        tracing::warn!("cluster add_related_edge failed: {}", e);
734                    }
735                }
736                if !clusters.is_empty() {
737                    tracing::info!("Detected {} clusters", clusters.len());
738                }
739            }
740        });
741
742        // Consolidation queue cleanup: purge old processed entries
743        let access_db_path2 = ctx.access_db_path.clone();
744        spawn_logged("consolidation_cleanup", async move {
745            if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
746                && let Err(e) = q.cleanup_processed()
747            {
748                tracing::warn!("consolidation queue cleanup failed: {}", e);
749            }
750        });
751    }
752
753    CallToolResult::success(
754        serde_json::to_string_pretty(&result_json)
755            .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
756    )
757}
758
759async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
760    let params: ListParams =
761        args.and_then(|v| serde_json::from_value(v).ok())
762            .unwrap_or(ListParams {
763                limit: None,
764                scope: None,
765            });
766
767    let limit = params.limit.unwrap_or(10).min(100);
768
769    let scope = params
770        .scope
771        .as_deref()
772        .map(|s| s.parse::<ListScope>())
773        .transpose();
774
775    let scope = match scope {
776        Ok(s) => s.unwrap_or(ListScope::Project),
777        Err(e) => return CallToolResult::error(e),
778    };
779
780    match db.list_items(Some(limit), scope).await {
781        Ok(items) => {
782            if items.is_empty() {
783                CallToolResult::success("No items stored yet.")
784            } else {
785                let formatted: Vec<Value> = items
786                    .iter()
787                    .map(|item| {
788                        let content_preview = truncate(&item.content, 100);
789                        let mut obj = json!({
790                            "id": item.id,
791                            "content": content_preview,
792                            "created": item.created_at.to_rfc3339(),
793                        });
794
795                        if item.is_chunked {
796                            obj["chunked"] = json!(true);
797                        }
798
799                        obj
800                    })
801                    .collect();
802
803                let result = json!({
804                    "count": items.len(),
805                    "items": formatted
806                });
807
808                CallToolResult::success(
809                    serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
810                        format!("{{\"error\": \"serialization failed: {}\"}}", e)
811                    }),
812                )
813            }
814        }
815        Err(e) => sanitized_error("Failed to list items", e),
816    }
817}
818
819async fn execute_forget(
820    db: &mut Database,
821    graph: &GraphStore,
822    ctx: &ServerContext,
823    args: Option<Value>,
824) -> CallToolResult {
825    let params: ForgetParams = match args {
826        Some(v) => match serde_json::from_value(v) {
827            Ok(p) => p,
828            Err(e) => {
829                tracing::debug!("Parameter validation failed: {}", e);
830                return CallToolResult::error("Invalid parameters");
831            }
832        },
833        None => return CallToolResult::error("Missing parameters"),
834    };
835
836    // Access control: verify the item belongs to the current project (or is global)
837    if let Some(ref current_pid) = ctx.project_id {
838        match db.get_item(&params.id).await {
839            Ok(Some(item)) => {
840                if let Some(ref item_pid) = item.project_id
841                    && item_pid != current_pid
842                {
843                    return CallToolResult::error(format!(
844                        "Cannot delete item {} from a different project",
845                        params.id
846                    ));
847                }
848            }
849            Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
850            Err(e) => {
851                return sanitized_error("Failed to look up item", e);
852            }
853        }
854    }
855
856    match db.delete_item(&params.id).await {
857        Ok(true) => {
858            // Remove from graph
859            if let Err(e) = graph.remove_node(&params.id) {
860                tracing::warn!("remove_node failed: {}", e);
861            }
862
863            let result = json!({
864                "success": true,
865                "message": format!("Deleted item: {}", params.id)
866            });
867            CallToolResult::success(
868                serde_json::to_string_pretty(&result)
869                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
870            )
871        }
872        Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
873        Err(e) => sanitized_error("Failed to delete item", e),
874    }
875}
876
877// ========== Utilities ==========
878
879/// Log a detailed internal error and return a sanitized message to the MCP client.
880/// This prevents leaking file paths, database internals, or OS details.
881fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
882    tracing::error!("{}: {}", context, err);
883    CallToolResult::error(context.to_string())
884}
885
886/// Like `sanitized_error` but returns a String for use inside `map_err` chains.
887fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
888    tracing::error!("{}: {}", context, err);
889    context.to_string()
890}
891
/// Truncate `s` to at most `max_len` characters (Unicode scalar values).
///
/// - Strings already within the limit are returned unchanged.
/// - When truncation is needed and `max_len > 3`, the result keeps the first
///   `max_len - 3` characters followed by `"..."`, so it is exactly `max_len`
///   characters long.
/// - When `max_len <= 3` there is no room for the ellipsis, so the first
///   `max_len` characters are returned as-is.
///
/// Only the first `max_len + 1` characters are ever inspected. Previewing a
/// very large string therefore does not require scanning it in full. Slicing
/// always happens on a char boundary, so multi-byte text never panics.
fn truncate(s: &str, max_len: usize) -> String {
    // Byte length is an upper bound on char count, so short inputs skip the
    // char walk entirely. Otherwise look only as far as the (max_len + 1)-th
    // char instead of counting the whole string.
    if s.len() <= max_len || s.chars().nth(max_len).is_none() {
        return s.to_string();
    }
    if max_len <= 3 {
        // Not enough room for the "..." suffix; just take max_len chars.
        return s.chars().take(max_len).collect();
    }
    // Byte offset of the first dropped char. It always exists because the
    // string has more than max_len chars; fall back to the full length
    // rather than panic.
    let cut = s
        .char_indices()
        .nth(max_len - 3)
        .map_or(s.len(), |(i, _)| i);
    format!("{}...", &s[..cut])
}
907
#[cfg(test)]
mod tests {
    use super::*;

    use std::path::PathBuf;
    use std::sync::Mutex;
    use tokio::sync::Semaphore;

    // ========== Unit Tests ==========

    #[test]
    fn test_truncate_small_max_len() {
        // Bug #25: truncate should not panic when max_len < 3
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("hello", 1), "h");
        assert_eq!(truncate("hello", 2), "he");
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 3), "hi"); // shorter than max, no truncation
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello!", 5), "he...");
    }

    #[test]
    fn test_truncate_unicode() {
        assert_eq!(truncate("héllo wörld", 5), "hé...");
        assert_eq!(truncate("日本語テスト", 4), "日...");
    }

    // ========== Integration Tests ==========

    /// Create a ServerContext with temp dirs for integration testing.
    ///
    /// The returned `TempDir` owns the on-disk data. Keep it alive for the
    /// duration of the test, because dropping it deletes the directory.
    async fn setup_test_context() -> (ServerContext, tempfile::TempDir) {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let access_db_path = tmp.path().join("access.db");

        let embedder = Arc::new(crate::Embedder::new().unwrap());
        let project_id = Some("test-project-00000001".to_string());

        let ctx = ServerContext {
            db_path,
            access_db_path,
            project_id,
            embedder,
            cwd: PathBuf::from("."),
            consolidation_semaphore: Arc::new(Semaphore::new(1)),
            recall_count: std::sync::atomic::AtomicU64::new(0),
            rate_limit: Mutex::new(super::super::server::RateLimitState {
                window_start_ms: 0,
                count: 0,
            }),
        };

        (ctx, tmp)
    }

    /// Parse the first content block of a tool result as JSON.
    ///
    /// On failure the panic includes the raw tool text, so an unexpected error
    /// message is visible instead of a bare parse error.
    fn parse_json(result: &CallToolResult) -> Value {
        let text = &result.content[0].text;
        serde_json::from_str(text)
            .unwrap_or_else(|e| panic!("expected JSON tool output ({}), got: {}", e, text))
    }

    /// Store `content` via the `store` tool and return the parsed JSON response.
    ///
    /// Asserts that the store succeeded. Without this check, a failed setup
    /// store surfaces as a misleading failure later in the test.
    async fn store_ok(ctx: &ServerContext, content: &str) -> Value {
        let result = execute_tool(ctx, "store", Some(json!({ "content": content }))).await;
        assert!(
            result.is_error.is_none(),
            "Store should succeed: {:?}",
            result.content
        );
        parse_json(&result)
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_and_recall_roundtrip() {
        let (ctx, _tmp) = setup_test_context().await;

        store_ok(&ctx, "Rust is a systems programming language").await;

        // Recall by query
        let recall_result = execute_tool(
            &ctx,
            "recall",
            Some(json!({ "query": "systems programming language" })),
        )
        .await;
        assert!(recall_result.is_error.is_none(), "Recall should succeed");

        let text = &recall_result.content[0].text;
        assert!(
            text.contains("Rust is a systems programming language"),
            "Recall should return stored content, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_and_list() {
        let (ctx, _tmp) = setup_test_context().await;

        store_ok(&ctx, "First item for listing").await;
        store_ok(&ctx, "Second item for listing").await;

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        assert!(list_result.is_error.is_none(), "List should succeed");

        let parsed = parse_json(&list_result);
        assert_eq!(parsed["count"], 2, "Should list 2 items");
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_conflict_detection() {
        let (ctx, _tmp) = setup_test_context().await;

        let content = "The quick brown fox jumps over the lazy dog";
        store_ok(&ctx, content).await;

        // Storing identical content again should report a potential conflict.
        let parsed = store_ok(&ctx, content).await;
        assert!(
            parsed.get("potential_conflicts").is_some(),
            "Should detect conflict for near-duplicate content, got: {}",
            parsed
        );
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_forget_removes_item() {
        let (ctx, _tmp) = setup_test_context().await;

        let stored = store_ok(&ctx, "Item to be forgotten").await;
        let item_id = stored["id"]
            .as_str()
            .unwrap_or_else(|| panic!("store response should include an id, got: {}", stored))
            .to_string();

        let forget_result = execute_tool(&ctx, "forget", Some(json!({ "id": item_id }))).await;
        assert!(forget_result.is_error.is_none(), "Forget should succeed");

        // List should be empty
        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        let text = &list_result.content[0].text;
        assert!(
            text.contains("No items stored yet"),
            "Should have no items after forget, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_recall_empty_db() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "recall", Some(json!({ "query": "anything" }))).await;
        assert!(
            result.is_error.is_none(),
            "Recall on empty DB should not error"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("No items found"),
            "Should indicate no items found, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_rejects_oversized_content() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_content = "x".repeat(1_100_000); // >1MB
        let result = execute_tool(&ctx, "store", Some(json!({ "content": large_content }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized content"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_recall_rejects_oversized_query() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_query = "x".repeat(11_000); // >10KB
        let result = execute_tool(&ctx, "recall", Some(json!({ "query": large_query }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized query"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_missing_params() {
        let (ctx, _tmp) = setup_test_context().await;

        // No params at all
        let result = execute_tool(&ctx, "store", None).await;
        assert!(result.is_error == Some(true), "Should error with no params");

        // Empty object (missing required 'content')
        let result = execute_tool(&ctx, "store", Some(json!({}))).await;
        assert!(
            result.is_error == Some(true),
            "Should error with missing content"
        );
    }
}