1use std::collections::HashSet;
7
8use crate::error::StorageError;
9use crate::storage::provenance::ProvenanceRef;
10use crate::storage::watchtower::{self, ChunkWithNodeContext};
11use crate::storage::DbPool;
12
/// Upper bound on the size of the prompt block built by `format_fragments_prompt`.
///
/// Despite the name, the limit is compared against `String::len`, i.e. UTF-8 bytes.
pub const MAX_FRAGMENT_CHARS: usize = 2500;

/// Cap on the number of fragments.
///
/// NOTE(review): not referenced elsewhere in this file; presumably callers pass it as
/// `max_results` to `retrieve_vault_fragments` — confirm.
pub const MAX_FRAGMENTS: u32 = 5;

/// Length limit passed to `truncate_text` when building a citation snippet.
const CITATION_SNIPPET_LEN: usize = 120;
21
/// Tag describing why a chunk was matched during retrieval.
///
/// Serialized and deserialized with `snake_case` variant names (e.g. `"semantic"`).
/// Nothing in this file assigns a variant; citations built here leave `match_reason` as `None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MatchReason {
    /// NOTE(review): presumably an embedding-similarity match — confirm against the producer.
    Semantic,
    /// NOTE(review): presumably a keyword/full-text match — confirm against the producer.
    Keyword,
    /// NOTE(review): presumably reached via a graph edge — confirm against the producer.
    Graph,
    /// NOTE(review): presumably a combination of several signals — confirm against the producer.
    Hybrid,
}
40
/// Citation metadata identifying the vault chunk a fragment was taken from.
///
/// Serialize-only. The four trailing optional fields are left out of the serialized output
/// when they are `None`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct VaultCitation {
    /// Id of the cited chunk.
    pub chunk_id: i64,
    /// Id of the node the chunk belongs to.
    pub node_id: i64,
    /// Heading path of the chunk; may be empty.
    pub heading_path: String,
    /// Path of the source note (filled from the chunk context's `relative_path`).
    pub source_path: String,
    /// Title of the source note, when one is known.
    pub source_title: Option<String>,
    /// Shortened copy of the chunk text.
    pub snippet: String,
    /// Retrieval boost copied from the chunk.
    pub retrieval_boost: f64,
    /// Edge type; always `None` for citations built in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub edge_type: Option<String>,
    /// Edge label; always `None` for citations built in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub edge_label: Option<String>,
    /// Why the chunk matched; always `None` for citations built in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub match_reason: Option<MatchReason>,
    /// Match score; always `None` for citations built in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
71
/// A retrieved chunk paired with the citation that describes where it came from.
#[derive(Debug, Clone)]
pub struct FragmentContext {
    /// Full, untruncated text of the chunk.
    pub chunk_text: String,
    /// Citation metadata for the chunk.
    pub citation: VaultCitation,
}
80
81pub async fn retrieve_vault_fragments(
90 pool: &DbPool,
91 account_id: &str,
92 keywords: &[String],
93 selected_node_ids: Option<&[i64]>,
94 max_results: u32,
95) -> Result<Vec<FragmentContext>, StorageError> {
96 let mut results: Vec<FragmentContext> = Vec::new();
97 let mut seen_ids: HashSet<i64> = HashSet::new();
98
99 if let Some(node_ids) = selected_node_ids {
101 if !node_ids.is_empty() {
102 let biased = watchtower::get_chunks_for_nodes_with_context(
103 pool,
104 account_id,
105 node_ids,
106 max_results,
107 )
108 .await?;
109
110 for cwc in biased {
111 if seen_ids.insert(cwc.chunk.id) {
112 results.push(fragment_from_chunk_with_context(cwc));
113 }
114 if results.len() >= max_results as usize {
115 break;
116 }
117 }
118 }
119 }
120
121 if results.len() < max_results as usize && !keywords.is_empty() {
123 let remaining = max_results - results.len() as u32;
124 let kw_refs: Vec<&str> = keywords.iter().map(|s| s.as_str()).collect();
125 let keyword_results =
126 watchtower::search_chunks_with_context(pool, account_id, &kw_refs, remaining + 5)
127 .await?;
128
129 for cwc in keyword_results {
130 if seen_ids.insert(cwc.chunk.id) {
131 results.push(fragment_from_chunk_with_context(cwc));
132 }
133 if results.len() >= max_results as usize {
134 break;
135 }
136 }
137 }
138
139 Ok(results)
140}
141
142pub fn format_fragments_prompt(fragments: &[FragmentContext]) -> String {
150 if fragments.is_empty() {
151 return String::new();
152 }
153
154 let mut block = String::from("\nRelevant knowledge from your notes:\n");
155
156 for (i, f) in fragments.iter().enumerate() {
157 let title = f
158 .citation
159 .source_title
160 .as_deref()
161 .unwrap_or(&f.citation.source_path);
162 let heading = if f.citation.heading_path.is_empty() {
163 String::new()
164 } else {
165 format!("[{}] ", f.citation.heading_path)
166 };
167 let preview = truncate_text(&f.chunk_text, 500);
168 let entry = format!("{}. {}(from: {}): \"{}\"\n", i + 1, heading, title, preview);
169
170 if block.len() + entry.len() > MAX_FRAGMENT_CHARS {
171 break;
172 }
173 block.push_str(&entry);
174 }
175
176 block.push_str("Reference these insights to ground your response in your own expertise.\n");
177
178 if block.len() > MAX_FRAGMENT_CHARS {
179 block.truncate(MAX_FRAGMENT_CHARS);
180 }
181 block
182}
183
184pub fn build_citations(fragments: &[FragmentContext]) -> Vec<VaultCitation> {
190 fragments.iter().map(|f| f.citation.clone()).collect()
191}
192
193pub fn citations_to_provenance_refs(citations: &[VaultCitation]) -> Vec<ProvenanceRef> {
199 citations
200 .iter()
201 .map(|c| ProvenanceRef {
202 node_id: Some(c.node_id),
203 chunk_id: Some(c.chunk_id),
204 seed_id: None,
205 source_path: Some(c.source_path.clone()),
206 heading_path: Some(c.heading_path.clone()),
207 snippet: Some(c.snippet.clone()),
208 edge_type: c.edge_type.clone(),
209 edge_label: c.edge_label.clone(),
210 angle_kind: None,
211 signal_kind: None,
212 signal_text: None,
213 source_role: None,
214 })
215 .collect()
216}
217
218pub fn citations_to_chunks_json(citations: &[VaultCitation]) -> String {
220 let entries: Vec<serde_json::Value> = citations
221 .iter()
222 .map(|c| {
223 serde_json::json!({
224 "chunk_id": c.chunk_id,
225 "node_id": c.node_id,
226 "source_path": c.source_path,
227 "heading_path": c.heading_path,
228 })
229 })
230 .collect();
231 serde_json::to_string(&entries).unwrap_or_else(|_| "[]".to_string())
232}
233
234pub async fn resolve_selection_identity(
244 pool: &DbPool,
245 account_id: &str,
246 file_path: &str,
247 heading_context: Option<&str>,
248) -> Result<(Option<i64>, Option<i64>), StorageError> {
249 let node = watchtower::find_node_by_path_for(pool, account_id, file_path).await?;
250
251 let node = match node {
252 Some(n) => n,
253 None => return Ok((None, None)),
254 };
255
256 let chunk =
257 watchtower::find_best_chunk_by_heading_for(pool, account_id, node.id, heading_context)
258 .await?;
259
260 Ok((Some(node.id), chunk.map(|c| c.id)))
261}
262
263fn fragment_from_chunk_with_context(cwc: ChunkWithNodeContext) -> FragmentContext {
268 let snippet = truncate_text(&cwc.chunk.chunk_text, CITATION_SNIPPET_LEN);
269 FragmentContext {
270 chunk_text: cwc.chunk.chunk_text.clone(),
271 citation: VaultCitation {
272 chunk_id: cwc.chunk.id,
273 node_id: cwc.chunk.node_id,
274 heading_path: cwc.chunk.heading_path.clone(),
275 source_path: cwc.relative_path,
276 source_title: cwc.source_title,
277 snippet,
278 retrieval_boost: cwc.chunk.retrieval_boost,
279 edge_type: None,
280 edge_label: None,
281 match_reason: None,
282 score: None,
283 },
284 }
285}
286
/// Shortens `text` to roughly `max_len` bytes, marking the cut with a trailing `...`.
///
/// Text that already fits is returned unchanged. Otherwise the text is cut at the last UTF-8
/// character boundary at or before `max_len - 3` bytes and `...` is appended. When `max_len`
/// is below 3 the result is just `...`, which is longer than `max_len`.
fn truncate_text(text: &str, max_len: usize) -> String {
    if text.len() <= max_len {
        return text.to_owned();
    }

    let budget = max_len.saturating_sub(3);
    // Walk back from the budget to the nearest character boundary; index 0 always is one.
    let cut = (0..=budget)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&text[..cut]);
    shortened.push_str("...");
    shortened
}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Citation with the given ids and path; every other field is empty, `None`, or the
    /// neutral boost of `1.0`. Tests override what they need with struct-update syntax.
    fn base_citation(chunk_id: i64, node_id: i64, source_path: &str) -> VaultCitation {
        VaultCitation {
            chunk_id,
            node_id,
            heading_path: String::new(),
            source_path: source_path.to_owned(),
            source_title: None,
            snippet: String::new(),
            retrieval_boost: 1.0,
            edge_type: None,
            edge_label: None,
            match_reason: None,
            score: None,
        }
    }

    /// Fragment without heading or title whose node id is ten times the chunk id.
    fn make_fragment(chunk_id: i64, text: &str, path: &str) -> FragmentContext {
        FragmentContext {
            chunk_text: text.to_owned(),
            citation: VaultCitation {
                snippet: text.chars().take(50).collect(),
                ..base_citation(chunk_id, chunk_id * 10, path)
            },
        }
    }

    /// Citation with both a heading path and a source title.
    fn sample_citation() -> VaultCitation {
        VaultCitation {
            heading_path: "# Guide > ## Setup".to_owned(),
            source_title: Some("Installation Guide".to_owned()),
            snippet: "Install with cargo install".to_owned(),
            ..base_citation(1, 10, "notes/guide.md")
        }
    }

    /// Fragment wrapping `sample_citation`.
    fn sample_fragment() -> FragmentContext {
        FragmentContext {
            chunk_text: "Install the CLI with cargo install tuitbot".to_owned(),
            citation: sample_citation(),
        }
    }

    /// Asserts that `haystack` contains each of `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(haystack.contains(needle));
        }
    }

    #[test]
    fn format_fragments_prompt_empty() {
        assert!(format_fragments_prompt(&[]).is_empty());
    }

    #[test]
    fn format_fragments_prompt_single() {
        let fragment = make_fragment(1, "Some interesting insight about Rust", "notes/rust.md");
        let prompt = format_fragments_prompt(&[fragment]);
        assert_contains_all(
            &prompt,
            &[
                "Relevant knowledge from your notes:",
                "(from: notes/rust.md)",
                "Some interesting insight about Rust",
                "Reference these insights",
            ],
        );
    }

    #[test]
    fn format_fragments_single_with_heading() {
        let prompt = format_fragments_prompt(&[sample_fragment()]);
        assert_contains_all(
            &prompt,
            &[
                "Relevant knowledge",
                "Installation Guide",
                "# Guide > ## Setup",
                "Reference these insights",
            ],
        );
    }

    #[test]
    fn format_fragments_prompt_truncates_at_limit() {
        let big_text = "A".repeat(300);
        let mut fragments = Vec::new();
        for i in 0..20 {
            fragments.push(make_fragment(i, &big_text, &format!("notes/{i}.md")));
        }
        let prompt = format_fragments_prompt(&fragments);
        assert!(prompt.len() <= MAX_FRAGMENT_CHARS);
    }

    #[test]
    fn format_fragments_multiple_items_numbered() {
        let first = FragmentContext {
            citation: VaultCitation {
                source_title: Some("First".to_owned()),
                ..sample_citation()
            },
            ..sample_fragment()
        };
        let second = FragmentContext {
            citation: VaultCitation {
                source_title: Some("Second".to_owned()),
                ..sample_citation()
            },
            ..sample_fragment()
        };
        let prompt = format_fragments_prompt(&[first, second]);
        assert_contains_all(&prompt, &["1.", "2."]);
    }

    #[test]
    fn build_citations_empty() {
        assert!(build_citations(&[]).is_empty());
    }

    #[test]
    fn build_citations_preserves_fields() {
        let fragment = make_fragment(42, "chunk text here", "vault/note.md");
        let citations = build_citations(&[fragment]);
        assert_eq!(citations.len(), 1);

        let only = &citations[0];
        assert_eq!(only.chunk_id, 42);
        assert_eq!(only.node_id, 420);
        assert_eq!(only.source_path, "vault/note.md");
        assert_eq!(only.retrieval_boost, 1.0);
    }

    #[test]
    fn build_citations_returns_all() {
        let fragments = [sample_fragment(), sample_fragment()];
        assert_eq!(build_citations(&fragments).len(), 2);
    }

    #[test]
    fn citations_to_provenance_refs_maps_fields() {
        let citation = VaultCitation {
            heading_path: "# Title > ## Section".to_owned(),
            source_title: Some("Guide".to_owned()),
            snippet: "snippet text".to_owned(),
            retrieval_boost: 1.5,
            ..base_citation(5, 50, "docs/guide.md")
        };
        let refs = citations_to_provenance_refs(&[citation]);
        assert_eq!(refs.len(), 1);

        let only = &refs[0];
        assert_eq!(only.node_id, Some(50));
        assert_eq!(only.chunk_id, Some(5));
        assert_eq!(only.source_path.as_deref(), Some("docs/guide.md"));
        assert_eq!(only.heading_path.as_deref(), Some("# Title > ## Section"));
        assert_eq!(only.snippet.as_deref(), Some("snippet text"));
        assert!(only.seed_id.is_none());
    }

    #[test]
    fn citations_to_chunks_json_empty() {
        assert_eq!(citations_to_chunks_json(&[]), "[]");
    }

    #[test]
    fn citations_to_chunks_json_valid() {
        let citation = VaultCitation {
            heading_path: "# Intro".to_owned(),
            snippet: "intro text".to_owned(),
            ..base_citation(7, 70, "notes/intro.md")
        };
        let json = citations_to_chunks_json(&[citation]);
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.len(), 1);

        let entry = &parsed[0];
        assert_eq!(entry["chunk_id"], 7);
        assert_eq!(entry["node_id"], 70);
        assert_eq!(entry["source_path"], "notes/intro.md");
        assert_eq!(entry["heading_path"], "# Intro");
    }

    #[test]
    fn format_fragments_heading_path_empty() {
        let prompt = format_fragments_prompt(&[make_fragment(1, "some text", "path.md")]);
        assert!(!prompt.contains("[] "));
    }

    #[test]
    fn format_fragments_source_title_fallback() {
        let fragment = make_fragment(1, "content here", "vault/fallback.md");
        let prompt = format_fragments_prompt(&[fragment]);
        assert!(prompt.contains("vault/fallback.md"));
    }

    #[test]
    fn truncate_text_short_unchanged() {
        assert_eq!(truncate_text("hello", 10), "hello");
    }

    #[test]
    fn truncate_text_long_gets_ellipsis() {
        let shortened = truncate_text("hello world this is long", 10);
        assert!(shortened.ends_with("..."));
        assert!(shortened.len() <= 13);
    }

    #[test]
    fn truncate_text_exact_boundary() {
        assert_eq!(truncate_text("hello", 5), "hello");
    }

    #[test]
    fn truncate_text_empty_string() {
        assert_eq!(truncate_text("", 10), "");
    }

    #[test]
    fn truncate_text_zero_max() {
        assert_eq!(truncate_text("hello", 0), "...");
    }

    #[test]
    fn citations_to_provenance_refs_empty() {
        assert!(citations_to_provenance_refs(&[]).is_empty());
    }

    #[test]
    fn citations_to_chunks_json_multiple() {
        let citations = [
            VaultCitation {
                heading_path: "# A".to_owned(),
                ..base_citation(1, 10, "a.md")
            },
            VaultCitation {
                heading_path: "# B".to_owned(),
                source_title: Some("B".to_owned()),
                retrieval_boost: 2.0,
                ..base_citation(2, 20, "b.md")
            },
        ];
        let json = citations_to_chunks_json(&citations);
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0]["chunk_id"], 1);
        assert_eq!(parsed[1]["chunk_id"], 2);
    }

    #[test]
    fn format_fragments_with_source_title() {
        let fragment = FragmentContext {
            chunk_text: "CLI tool for managing bots".to_owned(),
            citation: VaultCitation {
                source_title: Some("CLI Guide".to_owned()),
                snippet: "CLI tool...".to_owned(),
                ..base_citation(1, 10, "vault/cli.md")
            },
        };
        let prompt = format_fragments_prompt(&[fragment]);
        assert!(prompt.contains("CLI Guide"));
        // The title replaces the path, so the path must not appear.
        assert!(!prompt.contains("vault/cli.md"));
    }

    #[test]
    fn fragment_from_chunk_with_context_builds_correctly() {
        use crate::storage::watchtower::{ChunkWithNodeContext, ContentChunk};

        let chunk = ContentChunk {
            id: 42,
            account_id: "acct".to_owned(),
            node_id: 100,
            heading_path: "# Title".to_owned(),
            chunk_text: "Some chunk text for testing purposes".to_owned(),
            chunk_hash: "hash".to_owned(),
            chunk_index: 0,
            retrieval_boost: 1.5,
            status: "active".to_owned(),
            created_at: "2026-01-01".to_owned(),
            updated_at: "2026-01-01".to_owned(),
        };
        let with_context = ChunkWithNodeContext {
            chunk,
            relative_path: "notes/test.md".to_owned(),
            source_title: Some("Test Note".to_owned()),
        };

        let fragment = fragment_from_chunk_with_context(with_context);
        let citation = &fragment.citation;
        assert_eq!(citation.chunk_id, 42);
        assert_eq!(citation.node_id, 100);
        assert_eq!(citation.source_path, "notes/test.md");
        assert_eq!(citation.source_title, Some("Test Note".to_string()));
        assert_eq!(citation.heading_path, "# Title");
        assert!((citation.retrieval_boost - 1.5).abs() < 0.001);
        assert_eq!(fragment.chunk_text, "Some chunk text for testing purposes");
    }

    #[test]
    fn vault_citation_clone() {
        let original = sample_citation();
        let copy = original.clone();
        assert_eq!(original.chunk_id, copy.chunk_id);
        assert_eq!(original.heading_path, copy.heading_path);
    }

    #[test]
    fn fragment_context_clone() {
        let original = sample_fragment();
        let copy = original.clone();
        assert_eq!(original.chunk_text, copy.chunk_text);
        assert_eq!(original.citation.chunk_id, copy.citation.chunk_id);
    }

    #[test]
    fn constants_have_expected_values() {
        assert_eq!(MAX_FRAGMENT_CHARS, 2500);
        assert_eq!(MAX_FRAGMENTS, 5);
    }
}