Skip to main content

tuitbot_core/context/
retrieval.rs

1//! Vault fragment retrieval and citation engine.
2//!
3//! Retrieves account-scoped content chunks from the vault, builds structured
4//! citation records, and formats fragment text for LLM prompt injection.
5
6use std::collections::HashSet;
7
8use crate::error::StorageError;
9use crate::storage::provenance::ProvenanceRef;
10use crate::storage::watchtower::{self, ChunkWithNodeContext};
11use crate::storage::DbPool;
12
/// Maximum size of the vault fragment prompt section, measured in bytes.
///
/// Despite the name, `format_fragments_prompt` compares this against
/// `String::len()`, which counts UTF-8 bytes; prompts containing non-ASCII
/// text therefore hold fewer than this many characters.
pub const MAX_FRAGMENT_CHARS: usize = 2500;

/// Maximum number of fragments to include in context.
///
/// Typed as `u32` to match the `max_results` parameter of
/// `retrieve_vault_fragments`.
pub const MAX_FRAGMENTS: u32 = 5;

/// Maximum snippet length in citation records, in bytes, including the `...`
/// suffix appended by `truncate_text` when the chunk text is cut.
const CITATION_SNIPPET_LEN: usize = 120;
21
22// ============================================================================
23// Structs
24// ============================================================================
25
/// How a search result was matched: semantic embedding, keyword text search,
/// graph edge traversal, or a blend of multiple signals.
///
/// Serialized and deserialized in `snake_case`, e.g. `MatchReason::Semantic`
/// <-> `"semantic"` and `MatchReason::Hybrid` <-> `"hybrid"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MatchReason {
    /// Matched via embedding cosine similarity.
    Semantic,
    /// Matched via keyword / LIKE text search.
    Keyword,
    /// Matched via graph edge (wikilink, backlink, shared tag).
    Graph,
    /// Matched by two or more signal types.
    Hybrid,
}
40
/// A structured citation linking a prompt fragment back to its vault source.
///
/// Serialized with serde in field-declaration order. The optional graph and
/// ranking fields (`edge_type`, `edge_label`, `match_reason`, `score`) are
/// omitted from the serialized output entirely when they are `None`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct VaultCitation {
    /// ID of the content chunk.
    pub chunk_id: i64,
    /// ID of the parent content node.
    pub node_id: i64,
    /// Heading hierarchy path (e.g., "# Title > ## Section"). May be empty.
    pub heading_path: String,
    /// Relative file path of the source note.
    pub source_path: String,
    /// Title of the source note (if available).
    pub source_title: Option<String>,
    /// Short excerpt from the chunk text. When built by
    /// `fragment_from_chunk_with_context` it is capped at `CITATION_SNIPPET_LEN`
    /// bytes with a `...` suffix.
    pub snippet: String,
    /// Retrieval boost score, copied from the chunk's `retrieval_boost`.
    pub retrieval_boost: f64,
    /// Graph edge type (e.g., "wikilink", "backlink", "shared_tag"). None for non-graph citations.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub edge_type: Option<String>,
    /// Graph edge label for provenance tracking. None for non-graph citations.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub edge_label: Option<String>,
    /// How this citation was matched (semantic, keyword, graph, or hybrid).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub match_reason: Option<MatchReason>,
    /// Retrieval score from the ranking algorithm (RRF or raw similarity).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
71
/// Intermediate result pairing chunk text with citation metadata.
///
/// `chunk_text` holds the untruncated chunk; prompt formatting shortens it on
/// its own, while `citation.snippet` carries the short excerpt for citations.
#[derive(Debug, Clone)]
pub struct FragmentContext {
    /// The full chunk text for prompt inclusion.
    pub chunk_text: String,
    /// Citation metadata for this fragment.
    pub citation: VaultCitation,
}
80
81// ============================================================================
82// Retrieval
83// ============================================================================
84
85/// Retrieve vault fragments matching keywords, with optional selected-note bias.
86///
87/// When `selected_node_ids` is provided, chunks from those notes are retrieved
88/// first, then remaining slots are filled with keyword-matched results (deduplicated).
89pub async fn retrieve_vault_fragments(
90    pool: &DbPool,
91    account_id: &str,
92    keywords: &[String],
93    selected_node_ids: Option<&[i64]>,
94    max_results: u32,
95) -> Result<Vec<FragmentContext>, StorageError> {
96    let mut results: Vec<FragmentContext> = Vec::new();
97    let mut seen_ids: HashSet<i64> = HashSet::new();
98
99    // Step 1: If selected nodes provided, fetch their chunks first.
100    if let Some(node_ids) = selected_node_ids {
101        if !node_ids.is_empty() {
102            let biased = watchtower::get_chunks_for_nodes_with_context(
103                pool,
104                account_id,
105                node_ids,
106                max_results,
107            )
108            .await?;
109
110            for cwc in biased {
111                if seen_ids.insert(cwc.chunk.id) {
112                    results.push(fragment_from_chunk_with_context(cwc));
113                }
114                if results.len() >= max_results as usize {
115                    break;
116                }
117            }
118        }
119    }
120
121    // Step 2: Fill remaining slots with keyword search results.
122    if results.len() < max_results as usize && !keywords.is_empty() {
123        let remaining = max_results - results.len() as u32;
124        let kw_refs: Vec<&str> = keywords.iter().map(|s| s.as_str()).collect();
125        let keyword_results =
126            watchtower::search_chunks_with_context(pool, account_id, &kw_refs, remaining + 5)
127                .await?;
128
129        for cwc in keyword_results {
130            if seen_ids.insert(cwc.chunk.id) {
131                results.push(fragment_from_chunk_with_context(cwc));
132            }
133            if results.len() >= max_results as usize {
134                break;
135            }
136        }
137    }
138
139    Ok(results)
140}
141
142// ============================================================================
143// Formatting
144// ============================================================================
145
146/// Format fragment text as a prompt section with inline citations.
147///
148/// Output is capped at `MAX_FRAGMENT_CHARS`.
149pub fn format_fragments_prompt(fragments: &[FragmentContext]) -> String {
150    if fragments.is_empty() {
151        return String::new();
152    }
153
154    let mut block = String::from("\nRelevant knowledge from your notes:\n");
155
156    for (i, f) in fragments.iter().enumerate() {
157        let title = f
158            .citation
159            .source_title
160            .as_deref()
161            .unwrap_or(&f.citation.source_path);
162        let heading = if f.citation.heading_path.is_empty() {
163            String::new()
164        } else {
165            format!("[{}] ", f.citation.heading_path)
166        };
167        let preview = truncate_text(&f.chunk_text, 500);
168        let entry = format!("{}. {}(from: {}): \"{}\"\n", i + 1, heading, title, preview);
169
170        if block.len() + entry.len() > MAX_FRAGMENT_CHARS {
171            break;
172        }
173        block.push_str(&entry);
174    }
175
176    block.push_str("Reference these insights to ground your response in your own expertise.\n");
177
178    if block.len() > MAX_FRAGMENT_CHARS {
179        block.truncate(MAX_FRAGMENT_CHARS);
180    }
181    block
182}
183
184// ============================================================================
185// Citation builders
186// ============================================================================
187
188/// Extract `VaultCitation` records from fragment contexts.
189pub fn build_citations(fragments: &[FragmentContext]) -> Vec<VaultCitation> {
190    fragments.iter().map(|f| f.citation.clone()).collect()
191}
192
193// ============================================================================
194// Provenance converters
195// ============================================================================
196
197/// Convert `VaultCitation` records to `ProvenanceRef` for persistence.
198pub fn citations_to_provenance_refs(citations: &[VaultCitation]) -> Vec<ProvenanceRef> {
199    citations
200        .iter()
201        .map(|c| ProvenanceRef {
202            node_id: Some(c.node_id),
203            chunk_id: Some(c.chunk_id),
204            seed_id: None,
205            source_path: Some(c.source_path.clone()),
206            heading_path: Some(c.heading_path.clone()),
207            snippet: Some(c.snippet.clone()),
208            edge_type: c.edge_type.clone(),
209            edge_label: c.edge_label.clone(),
210            angle_kind: None,
211            signal_kind: None,
212            signal_text: None,
213            source_role: None,
214        })
215        .collect()
216}
217
218/// Serialize citations as a JSON array for the legacy `source_chunks_json` column.
219pub fn citations_to_chunks_json(citations: &[VaultCitation]) -> String {
220    let entries: Vec<serde_json::Value> = citations
221        .iter()
222        .map(|c| {
223            serde_json::json!({
224                "chunk_id": c.chunk_id,
225                "node_id": c.node_id,
226                "source_path": c.source_path,
227                "heading_path": c.heading_path,
228            })
229        })
230        .collect();
231    serde_json::to_string(&entries).unwrap_or_else(|_| "[]".to_string())
232}
233
234// ============================================================================
235// Selection identity resolution
236// ============================================================================
237
238/// Resolve a Ghostwriter selection payload to the best available indexed block identity.
239///
240/// Returns `(Option<node_id>, Option<chunk_id>)`. Both are `None` if the note
241/// isn't indexed yet. Resolution is best-effort — the `selected_text` is always
242/// the authoritative payload.
243pub async fn resolve_selection_identity(
244    pool: &DbPool,
245    account_id: &str,
246    file_path: &str,
247    heading_context: Option<&str>,
248) -> Result<(Option<i64>, Option<i64>), StorageError> {
249    let node = watchtower::find_node_by_path_for(pool, account_id, file_path).await?;
250
251    let node = match node {
252        Some(n) => n,
253        None => return Ok((None, None)),
254    };
255
256    let chunk =
257        watchtower::find_best_chunk_by_heading_for(pool, account_id, node.id, heading_context)
258            .await?;
259
260    Ok((Some(node.id), chunk.map(|c| c.id)))
261}
262
263// ============================================================================
264// Helpers
265// ============================================================================
266
267fn fragment_from_chunk_with_context(cwc: ChunkWithNodeContext) -> FragmentContext {
268    let snippet = truncate_text(&cwc.chunk.chunk_text, CITATION_SNIPPET_LEN);
269    FragmentContext {
270        chunk_text: cwc.chunk.chunk_text.clone(),
271        citation: VaultCitation {
272            chunk_id: cwc.chunk.id,
273            node_id: cwc.chunk.node_id,
274            heading_path: cwc.chunk.heading_path.clone(),
275            source_path: cwc.relative_path,
276            source_title: cwc.source_title,
277            snippet,
278            retrieval_boost: cwc.chunk.retrieval_boost,
279            edge_type: None,
280            edge_label: None,
281            match_reason: None,
282            score: None,
283        },
284    }
285}
286
/// Shorten `text` to at most `max_len` bytes, marking any cut with `...`.
///
/// Text that already fits is returned unchanged. Otherwise the kept prefix ends
/// on the nearest UTF-8 char boundary at or below `max_len - 3`, and `...` is
/// appended. When `max_len < 3` the result is just `"..."`, which is longer than
/// `max_len`.
fn truncate_text(text: &str, max_len: usize) -> String {
    if text.len() <= max_len {
        return text.to_string();
    }

    let limit = max_len.saturating_sub(3);
    // Index 0 is always a char boundary, so the search always succeeds.
    let cut = (0..=limit)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&text[..cut]);
    shortened.push_str("...");
    shortened
}
298
#[cfg(test)]
mod tests {
    // Unit tests for prompt formatting, citation conversion, serde shape and text
    // truncation. The database-backed functions are not exercised here.
    use super::*;

    fn make_fragment(chunk_id: i64, text: &str, path: &str) -> FragmentContext {
        FragmentContext {
            chunk_text: text.to_string(),
            citation: VaultCitation {
                chunk_id,
                node_id: chunk_id * 10,
                heading_path: String::new(),
                source_path: path.to_string(),
                source_title: None,
                snippet: text.chars().take(50).collect(),
                retrieval_boost: 1.0,
                edge_type: None,
                edge_label: None,
                match_reason: None,
                score: None,
            },
        }
    }

    fn sample_citation() -> VaultCitation {
        VaultCitation {
            chunk_id: 1,
            node_id: 10,
            heading_path: "# Guide > ## Setup".to_string(),
            source_path: "notes/guide.md".to_string(),
            source_title: Some("Installation Guide".to_string()),
            snippet: "Install with cargo install".to_string(),
            retrieval_boost: 1.0,
            edge_type: None,
            edge_label: None,
            match_reason: None,
            score: None,
        }
    }

    fn sample_fragment() -> FragmentContext {
        FragmentContext {
            chunk_text: "Install the CLI with cargo install tuitbot".to_string(),
            citation: sample_citation(),
        }
    }

    #[test]
    fn format_fragments_prompt_empty() {
        let result = format_fragments_prompt(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn format_fragments_prompt_single() {
        let f = make_fragment(1, "Some interesting insight about Rust", "notes/rust.md");
        let result = format_fragments_prompt(&[f]);
        assert!(result.contains("Relevant knowledge from your notes:"));
        assert!(result.contains("(from: notes/rust.md)"));
        assert!(result.contains("Some interesting insight about Rust"));
        assert!(result.contains("Reference these insights"));
    }

    #[test]
    fn format_fragments_single_with_heading() {
        let frags = vec![sample_fragment()];
        let result = format_fragments_prompt(&frags);
        assert!(result.contains("Relevant knowledge"));
        assert!(result.contains("Installation Guide"));
        assert!(result.contains("# Guide > ## Setup"));
        assert!(result.contains("Reference these insights"));
    }

    #[test]
    fn format_fragments_prompt_truncates_at_limit() {
        let big_text = "A".repeat(300);
        let fragments: Vec<FragmentContext> = (0..20)
            .map(|i| make_fragment(i, &big_text, &format!("notes/{i}.md")))
            .collect();
        let result = format_fragments_prompt(&fragments);
        assert!(result.len() <= MAX_FRAGMENT_CHARS);
    }

    #[test]
    fn format_fragments_prompt_ends_with_instruction() {
        let f = make_fragment(1, "short note", "notes/a.md");
        let result = format_fragments_prompt(&[f]);
        assert!(result.ends_with(
            "Reference these insights to ground your response in your own expertise.\n"
        ));
    }

    #[test]
    fn format_fragments_prompt_handles_multibyte_text() {
        // 800 bytes of two-byte chars: the 500-byte preview must cut on a char boundary.
        let f = make_fragment(1, &"é".repeat(400), "notes/accents.md");
        let result = format_fragments_prompt(&[f]);
        assert!(result.len() <= MAX_FRAGMENT_CHARS);
        assert!(result.contains(&format!("\"{}...\"", "é".repeat(248))));
    }

    #[test]
    fn format_fragments_multiple_items_numbered() {
        let mut f1 = sample_fragment();
        f1.citation.source_title = Some("First".to_string());
        let mut f2 = sample_fragment();
        f2.citation.source_title = Some("Second".to_string());
        let result = format_fragments_prompt(&[f1, f2]);
        assert!(result.contains("1."));
        assert!(result.contains("2."));
    }

    #[test]
    fn build_citations_empty() {
        let result = build_citations(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn build_citations_preserves_fields() {
        let f = make_fragment(42, "chunk text here", "vault/note.md");
        let citations = build_citations(&[f]);
        assert_eq!(citations.len(), 1);
        assert_eq!(citations[0].chunk_id, 42);
        assert_eq!(citations[0].node_id, 420);
        assert_eq!(citations[0].source_path, "vault/note.md");
        assert_eq!(citations[0].retrieval_boost, 1.0);
    }

    #[test]
    fn build_citations_returns_all() {
        let frags = vec![sample_fragment(), sample_fragment()];
        let citations = build_citations(&frags);
        assert_eq!(citations.len(), 2);
    }

    #[test]
    fn citations_to_provenance_refs_maps_fields() {
        let citation = VaultCitation {
            chunk_id: 5,
            node_id: 50,
            heading_path: "# Title > ## Section".to_string(),
            source_path: "docs/guide.md".to_string(),
            source_title: Some("Guide".to_string()),
            snippet: "snippet text".to_string(),
            retrieval_boost: 1.5,
            edge_type: None,
            edge_label: None,
            match_reason: None,
            score: None,
        };
        let refs = citations_to_provenance_refs(&[citation]);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].node_id, Some(50));
        assert_eq!(refs[0].chunk_id, Some(5));
        assert_eq!(refs[0].source_path.as_deref(), Some("docs/guide.md"));
        assert_eq!(
            refs[0].heading_path.as_deref(),
            Some("# Title > ## Section")
        );
        assert_eq!(refs[0].snippet.as_deref(), Some("snippet text"));
        assert!(refs[0].seed_id.is_none());
    }

    #[test]
    fn citations_to_provenance_refs_carries_edge_metadata() {
        let mut citation = sample_citation();
        citation.edge_type = Some("wikilink".to_string());
        citation.edge_label = Some("related".to_string());
        let refs = citations_to_provenance_refs(&[citation]);
        assert_eq!(refs[0].edge_type.as_deref(), Some("wikilink"));
        assert_eq!(refs[0].edge_label.as_deref(), Some("related"));
        assert!(refs[0].angle_kind.is_none());
        assert!(refs[0].source_role.is_none());
    }

    #[test]
    fn citations_to_chunks_json_empty() {
        let result = citations_to_chunks_json(&[]);
        assert_eq!(result, "[]");
    }

    #[test]
    fn citations_to_chunks_json_valid() {
        let citation = VaultCitation {
            chunk_id: 7,
            node_id: 70,
            heading_path: "# Intro".to_string(),
            source_path: "notes/intro.md".to_string(),
            source_title: None,
            snippet: "intro text".to_string(),
            retrieval_boost: 1.0,
            edge_type: None,
            edge_label: None,
            match_reason: None,
            score: None,
        };
        let result = citations_to_chunks_json(&[citation]);
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0]["chunk_id"], 7);
        assert_eq!(parsed[0]["node_id"], 70);
        assert_eq!(parsed[0]["source_path"], "notes/intro.md");
        assert_eq!(parsed[0]["heading_path"], "# Intro");
    }

    #[test]
    fn format_fragments_heading_path_empty() {
        let f = make_fragment(1, "some text", "path.md");
        let result = format_fragments_prompt(&[f]);
        assert!(!result.contains("[] "));
    }

    #[test]
    fn format_fragments_source_title_fallback() {
        let f = make_fragment(1, "content here", "vault/fallback.md");
        let result = format_fragments_prompt(&[f]);
        assert!(result.contains("vault/fallback.md"));
    }

    #[test]
    fn truncate_text_short_unchanged() {
        assert_eq!(truncate_text("hello", 10), "hello");
    }

    #[test]
    fn truncate_text_long_gets_ellipsis() {
        let result = truncate_text("hello world this is long", 10);
        assert!(result.ends_with("..."));
        // The ellipsis counts toward the limit.
        assert!(result.len() <= 10);
        assert_eq!(result, "hello w...");
    }

    #[test]
    fn truncate_text_exact_boundary() {
        let result = truncate_text("hello", 5);
        assert_eq!(result, "hello");
    }

    #[test]
    fn truncate_text_empty_string() {
        assert_eq!(truncate_text("", 10), "");
    }

    #[test]
    fn truncate_text_zero_max() {
        let result = truncate_text("hello", 0);
        // max_len=0, sub(3) saturates to 0, so "..."
        assert_eq!(result, "...");
    }

    #[test]
    fn truncate_text_respects_char_boundaries() {
        // "é" is two bytes; a naive cut at byte 2 would split it.
        assert_eq!(truncate_text("héllo", 5), "h...");
    }

    #[test]
    fn citations_to_provenance_refs_empty() {
        let refs = citations_to_provenance_refs(&[]);
        assert!(refs.is_empty());
    }

    #[test]
    fn citations_to_chunks_json_multiple() {
        let citations = vec![
            VaultCitation {
                chunk_id: 1,
                node_id: 10,
                heading_path: "# A".to_string(),
                source_path: "a.md".to_string(),
                source_title: None,
                snippet: "".to_string(),
                retrieval_boost: 1.0,
                edge_type: None,
                edge_label: None,
                match_reason: None,
                score: None,
            },
            VaultCitation {
                chunk_id: 2,
                node_id: 20,
                heading_path: "# B".to_string(),
                source_path: "b.md".to_string(),
                source_title: Some("B".to_string()),
                snippet: "".to_string(),
                retrieval_boost: 2.0,
                edge_type: None,
                edge_label: None,
                match_reason: None,
                score: None,
            },
        ];
        let json_str = citations_to_chunks_json(&citations);
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0]["chunk_id"], 1);
        assert_eq!(parsed[1]["chunk_id"], 2);
    }

    #[test]
    fn format_fragments_with_source_title() {
        let f = FragmentContext {
            chunk_text: "CLI tool for managing bots".to_string(),
            citation: VaultCitation {
                chunk_id: 1,
                node_id: 10,
                heading_path: "".to_string(),
                source_path: "vault/cli.md".to_string(),
                source_title: Some("CLI Guide".to_string()),
                snippet: "CLI tool...".to_string(),
                retrieval_boost: 1.0,
                edge_type: None,
                edge_label: None,
                match_reason: None,
                score: None,
            },
        };
        let result = format_fragments_prompt(&[f]);
        assert!(result.contains("CLI Guide"));
        assert!(!result.contains("vault/cli.md")); // title takes precedence
    }

    #[test]
    fn fragment_from_chunk_with_context_builds_correctly() {
        use crate::storage::watchtower::{ChunkWithNodeContext, ContentChunk};

        let cwc = ChunkWithNodeContext {
            chunk: ContentChunk {
                id: 42,
                account_id: "acct".to_string(),
                node_id: 100,
                heading_path: "# Title".to_string(),
                chunk_text: "Some chunk text for testing purposes".to_string(),
                chunk_hash: "hash".to_string(),
                chunk_index: 0,
                retrieval_boost: 1.5,
                status: "active".to_string(),
                created_at: "2026-01-01".to_string(),
                updated_at: "2026-01-01".to_string(),
            },
            relative_path: "notes/test.md".to_string(),
            source_title: Some("Test Note".to_string()),
        };

        let frag = fragment_from_chunk_with_context(cwc);
        assert_eq!(frag.citation.chunk_id, 42);
        assert_eq!(frag.citation.node_id, 100);
        assert_eq!(frag.citation.source_path, "notes/test.md");
        assert_eq!(frag.citation.source_title, Some("Test Note".to_string()));
        assert_eq!(frag.citation.heading_path, "# Title");
        assert!((frag.citation.retrieval_boost - 1.5).abs() < 0.001);
        assert_eq!(frag.chunk_text, "Some chunk text for testing purposes");
        // Short text fits in the snippet budget, so it is copied verbatim.
        assert_eq!(frag.citation.snippet, "Some chunk text for testing purposes");
        assert!(frag.citation.edge_type.is_none());
        assert!(frag.citation.match_reason.is_none());
        assert!(frag.citation.score.is_none());
    }

    #[test]
    fn fragment_from_chunk_with_context_truncates_snippet() {
        use crate::storage::watchtower::{ChunkWithNodeContext, ContentChunk};

        let long_text = "word ".repeat(60); // 300 bytes, over the snippet budget
        let cwc = ChunkWithNodeContext {
            chunk: ContentChunk {
                id: 7,
                account_id: "acct".to_string(),
                node_id: 70,
                heading_path: String::new(),
                chunk_text: long_text.clone(),
                chunk_hash: "hash".to_string(),
                chunk_index: 0,
                retrieval_boost: 1.0,
                status: "active".to_string(),
                created_at: "2026-01-01".to_string(),
                updated_at: "2026-01-01".to_string(),
            },
            relative_path: "notes/long.md".to_string(),
            source_title: None,
        };

        let frag = fragment_from_chunk_with_context(cwc);
        assert_eq!(frag.chunk_text, long_text);
        assert_eq!(frag.citation.snippet.len(), CITATION_SNIPPET_LEN);
        assert!(frag.citation.snippet.ends_with("..."));
    }

    #[test]
    fn vault_citation_clone() {
        let c = sample_citation();
        let c2 = c.clone();
        assert_eq!(c.chunk_id, c2.chunk_id);
        assert_eq!(c.heading_path, c2.heading_path);
    }

    #[test]
    fn vault_citation_serialization_omits_unset_optionals() {
        let json = serde_json::to_value(sample_citation()).unwrap();
        assert_eq!(json["chunk_id"], 1);
        assert_eq!(json["source_path"], "notes/guide.md");
        assert!(json.get("edge_type").is_none());
        assert!(json.get("edge_label").is_none());
        assert!(json.get("match_reason").is_none());
        assert!(json.get("score").is_none());
    }

    #[test]
    fn match_reason_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&MatchReason::Semantic).unwrap(),
            "\"semantic\""
        );
        assert_eq!(
            serde_json::to_string(&MatchReason::Hybrid).unwrap(),
            "\"hybrid\""
        );
        let parsed: MatchReason = serde_json::from_str("\"keyword\"").unwrap();
        assert_eq!(parsed, MatchReason::Keyword);
    }

    #[test]
    fn fragment_context_clone() {
        let f = sample_fragment();
        let f2 = f.clone();
        assert_eq!(f.chunk_text, f2.chunk_text);
        assert_eq!(f.citation.chunk_id, f2.citation.chunk_id);
    }

    #[test]
    fn constants_have_expected_values() {
        assert_eq!(MAX_FRAGMENT_CHARS, 2500);
        assert_eq!(MAX_FRAGMENTS, 5);
    }
}
637}