1use crate::errors::AppError;
8use crate::graph::{
9 bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
// Command-line arguments for the `deep-research` subcommand.
//
// NOTE: plain `//` comments are used on purpose in this struct. With clap's
// derive API, `///` doc comments can feed into the generated help output, so
// adding them here could change what users see.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n \
    # Basic deep research\n \
    sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
    # With custom parameters\n \
    sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
    # Include full memory bodies in output\n \
    sqlite-graphrag deep-research \"auth\" --with-bodies\n\n \
    # Tune RRF and graph scoring\n \
    sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
    // Free-text research query; rejected when blank and otherwise split into
    // sub-queries by `decompose_query`.
    #[arg(value_name = "QUERY", help = "Research query to decompose and search")]
    pub query: String,
    // Per-sub-query result limit handed to both the KNN and the FTS search.
    #[arg(
        long,
        short,
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    // Upper bound on how many sub-queries `decompose_query` may return.
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    // Graph traversal depth; a value of 0 skips graph expansion entirely.
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    // Forwarded to the graph traversal helpers as the minimum edge weight.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    // Number of semaphore permits; when absent the value is min(cpus, 8).
    // In both cases it is clamped to the range [1, number of sub-queries].
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    // Per-sub-query timeout, converted with `Duration::from_secs`.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    // When set, each result carries the full memory body in addition to the snippet.
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    // The merged, ranked hit list is truncated to this many entries.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    // Passed to `rrf_fuse` / `rrf_max_possible` as the RRF constant.
    #[arg(
        long,
        default_value_t = 60.0,
        help = "RRF k parameter (higher = less weight on top ranks)"
    )]
    pub rrf_k: f64,
    // Graph-expanded hits are scored as `avg_seed_score * graph_decay^hop`.
    #[arg(
        long,
        default_value_t = 0.7,
        help = "Graph score decay factor per hop (0.0-1.0)"
    )]
    pub graph_decay: f64,
    // Graph-expanded hits scoring below this value are dropped.
    #[arg(
        long,
        default_value_t = 0.2,
        help = "Minimum score threshold for graph-expanded results"
    )]
    pub graph_min_score: f64,
    // Forwarded unchanged to the graph traversal helpers.
    #[arg(
        long,
        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
    )]
    pub max_neighbors_per_hop: Option<usize>,
    // Optional namespace override, resolved through `resolve_namespace`.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Hidden flag; it is not read anywhere in the visible part of this file.
    #[arg(long, hide = true)]
    pub json: bool,
    // Optional database path, also settable through SQLITE_GRAPHRAG_DB_PATH.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Flattened daemon options; `autostart_daemon` is consulted when embedding.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
132
/// One sub-query produced from the user's query, as reported in the response.
#[derive(Serialize)]
struct SubQuery {
    /// Zero-based position of the sub-query in the decomposition.
    id: usize,
    /// Text that is embedded and searched for this sub-query.
    text: String,
    /// `"original"` when the query was not split, `"decomposed"` otherwise.
    source: &'static str,
}
139
/// A single ranked memory in the final response.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name as read back from storage.
    name: String,
    /// Best score observed for this memory across all sub-queries.
    score: f64,
    /// Retrieval channel of the best-scoring hit: `"knn"`, `"fts"` or `"graph"`.
    source: String,
    /// Ids of every sub-query that surfaced this memory.
    sub_query_ids: Vec<usize>,
    /// First 300 characters of the memory body.
    snippet: String,
    /// Full body; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// Hop count for graph-expanded hits, `None` for direct search hits.
    hop_distance: Option<usize>,
}
151
/// One entity on an evidence path.
#[derive(Serialize, Clone)]
struct EvidenceNode {
    /// Entity name, or `entity-<id>` when the name could not be looked up.
    entity: String,
    /// Relation on the edge that led to this entity; `None` for the seed node.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Weight of the edge that led to this entity; `None` for the seed node.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
161
/// A seed-to-target path through the entity graph.
#[derive(Serialize)]
struct EvidenceChain {
    /// Name of the first node on the path (the seed entity).
    from: String,
    /// Name of the last node on the path (the reached entity).
    to: String,
    /// Nodes ordered from seed to target.
    path: Vec<EvidenceNode>,
    /// Product of the edge weights along the path.
    total_weight: f64,
    /// Number of nodes in `path` (not the number of edges).
    depth: usize,
    /// Sub-queries that produced this chain.
    sub_query_ids: Vec<usize>,
}
179
/// Execution counters attached to the response.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries that returned a result.
    sub_queries_completed: usize,
    /// Sub-queries that failed for any reason other than a timeout.
    sub_queries_failed: usize,
    /// Sub-queries whose failure reason was `"timeout"`.
    sub_queries_timed_out: usize,
    /// Number of entries in the final `results` list.
    unique_memories_found: usize,
    /// Number of entries in the final `evidence_chains` list.
    evidence_chains_found: usize,
    /// Wall-clock time measured from the start of `run_async`, in milliseconds.
    elapsed_ms: u64,
}
190
/// Top-level JSON document emitted by the command.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The query exactly as supplied by the user.
    query: String,
    /// Sub-queries derived from `query`.
    sub_queries: Vec<SubQuery>,
    /// Merged and ranked memories.
    results: Vec<DeepResult>,
    /// De-duplicated evidence chains, sorted by descending total weight.
    evidence_chains: Vec<EvidenceChain>,
    /// Execution counters.
    stats: ResearchStats,
}
199
/// Best hit kept per memory while merging sub-query results, as
/// `(score, source, snippet, body, hop_distance, sub_query_ids)`.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
202
/// Output of one `execute_sub_query` call.
struct SubQueryResult {
    /// Id of the sub-query that produced these hits.
    sub_query_id: usize,
    /// Hits as `(memory_id, score, source, snippet, body, hop_distance)`.
    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
    /// Evidence chains discovered for this sub-query.
    chains: Vec<EvidenceChain>,
}
211
212pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
214 let rt = tokio::runtime::Builder::new_multi_thread()
215 .worker_threads(2)
216 .enable_all()
217 .build()
218 .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
219 rt.block_on(run_async(args))
220}
221
222async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
224 let start = std::time::Instant::now();
225
226 if args.query.trim().is_empty() {
227 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
228 }
229
230 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
231 let paths = AppPaths::resolve(args.db.as_deref())?;
232 crate::storage::connection::ensure_db_ready(&paths)?;
233
234 let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
236 let sub_queries: Vec<SubQuery> = sub_query_texts
237 .iter()
238 .enumerate()
239 .map(|(i, text)| SubQuery {
240 id: i,
241 text: text.clone(),
242 source: if sub_query_texts.len() == 1 {
243 "original"
244 } else {
245 "decomposed"
246 },
247 })
248 .collect();
249
250 output::emit_progress_i18n(
254 "Computing per-sub-query embeddings...",
255 "Calculando embeddings por sub-consulta...",
256 );
257 let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
258 for sq_text in &sub_query_texts {
259 let emb = crate::daemon::embed_query_or_local(
260 &paths.models,
261 sq_text,
262 args.daemon.autostart_daemon,
263 )?;
264 sub_embeddings.push(Arc::new(emb));
265 }
266
267 let cpu_count = std::thread::available_parallelism()
269 .map(|n| n.get())
270 .unwrap_or(4);
271 let permits = args
272 .max_concurrency
273 .unwrap_or_else(|| cpu_count.min(8))
274 .min(sub_queries.len())
275 .max(1);
276 let semaphore = Arc::new(Semaphore::new(permits));
277 let timeout_dur = std::time::Duration::from_secs(args.timeout);
278
279 let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
280
281 for (idx, sq_text) in sub_query_texts.iter().enumerate() {
282 let sem = Arc::clone(&semaphore);
283 let emb = Arc::clone(&sub_embeddings[idx]);
285 let ns = namespace.clone();
286 let db_path = paths.db.clone();
287 let query_text = sq_text.clone();
288 let k = args.k;
289 let max_hops = args.max_hops;
290 let min_weight = args.min_weight;
291 let rrf_k = args.rrf_k;
292 let graph_decay = args.graph_decay;
293 let graph_min_score = args.graph_min_score;
294 let max_neighbors_per_hop = args.max_neighbors_per_hop;
295
296 join_set.spawn(async move {
297 let _permit = sem
298 .acquire_owned()
299 .await
300 .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
301
302 let result = tokio::time::timeout(timeout_dur, async move {
304 execute_sub_query(
305 idx,
306 &query_text,
307 emb.as_slice(),
308 &ns,
309 &db_path,
310 k,
311 max_hops,
312 min_weight,
313 rrf_k,
314 graph_decay,
315 graph_min_score,
316 max_neighbors_per_hop,
317 )
318 })
319 .await;
320
321 match result {
322 Ok(inner) => inner.map_err(|e| (idx, e)),
323 Err(_) => Err((idx, "timeout".to_string())),
324 }
325 });
326 }
327
328 let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
330 let mut failed_count = 0usize;
331 let mut timed_out_count = 0usize;
332
333 while let Some(join_result) = join_set.join_next().await {
334 match join_result {
335 Ok(Ok(sqr)) => sub_query_results.push(sqr),
336 Ok(Err((_idx, reason))) => {
337 if reason == "timeout" {
338 timed_out_count += 1;
339 } else {
340 failed_count += 1;
341 }
342 tracing::warn!(sub_query_id = _idx, reason = %reason, "sub-query failed");
343 }
344 Err(join_err) => {
345 failed_count += 1;
346 if join_err.is_panic() {
347 tracing::error!("sub-query task panicked: {join_err}");
348 } else {
349 tracing::warn!("sub-query task cancelled: {join_err}");
350 }
351 }
352 }
353 }
354
355 let mut merged: HashMap<i64, MergedHit> = HashMap::new();
358
359 for sqr in &sub_query_results {
360 for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
361 let entry = merged.entry(*mem_id).or_insert_with(|| {
362 (
363 *score,
364 source.clone(),
365 snippet.clone(),
366 body.clone(),
367 *hop,
368 Vec::new(),
369 )
370 });
371 if *score > entry.0 {
373 entry.0 = *score;
374 entry.1 = source.clone();
375 entry.2 = snippet.clone();
376 entry.3 = body.clone();
377 entry.4 = *hop;
378 }
379 if !entry.5.contains(&sqr.sub_query_id) {
380 entry.5.push(sqr.sub_query_id);
381 }
382 }
383 }
384
385 let conn = open_ro(&paths.db)?;
387 let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
388
389 let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
391 ranked.sort_by(|a, b| {
392 b.1 .0
393 .partial_cmp(&a.1 .0)
394 .unwrap_or(std::cmp::Ordering::Equal)
395 });
396 ranked.truncate(args.max_results);
397
398 for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
399 let name = match memories::read_full(&conn, mem_id)? {
400 Some(row) => row.name,
401 None => continue,
402 };
403 results.push(DeepResult {
404 name,
405 score,
406 source,
407 sub_query_ids: sq_ids,
408 snippet,
409 body: if args.with_bodies { Some(body) } else { None },
410 hop_distance: hop,
411 });
412 }
413
414 let completed_count = sub_query_results.len();
418 let mut evidence_chains: Vec<EvidenceChain> = Vec::new();
419 let mut seen_chain_keys: HashSet<String> = HashSet::new();
420
421 for sqr in sub_query_results {
422 for chain in sqr.chains {
423 let key = format!("{}->{}", chain.from, chain.to);
425 if seen_chain_keys.insert(key) {
426 evidence_chains.push(chain);
427 }
428 }
429 }
430
431 evidence_chains.retain(|c| c.depth >= 2);
433 evidence_chains.sort_by(|a, b| {
434 b.total_weight
435 .partial_cmp(&a.total_weight)
436 .unwrap_or(std::cmp::Ordering::Equal)
437 });
438
439 let unique_memories = results.len();
440 let evidence_count = evidence_chains.len();
441
442 output::emit_json(&DeepResearchResponse {
444 query: args.query,
445 sub_queries,
446 results,
447 evidence_chains,
448 stats: ResearchStats {
449 sub_queries_total: sub_query_texts.len(),
450 sub_queries_completed: completed_count,
451 sub_queries_failed: failed_count,
452 sub_queries_timed_out: timed_out_count,
453 unique_memories_found: unique_memories,
454 evidence_chains_found: evidence_count,
455 elapsed_ms: start.elapsed().as_millis() as u64,
456 },
457 })?;
458
459 Ok(())
460}
461
/// Splits a research query into at most `max` sub-queries.
///
/// Strategies are tried in this order and the first one that yields any part
/// wins:
/// 1. relational phrases (`" caused by "`, `" related to "`, ...), matched
///    ASCII-case-insensitively; each phrase is tried once, in list order, on
///    the text that remains to the right of the previous split;
/// 2. semicolons;
/// 3. commas, after rewriting `" and "`, `" AND "`, `" e "` and `" E "` to
///    commas.
///
/// Every part is trimmed and empty parts are dropped. When nothing splits,
/// the original query is returned unchanged as the only element. `max` is
/// treated as at least 1, so a non-empty query always yields a sub-query.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    const RELATIONAL: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    let mut parts: Vec<String> = Vec::new();

    let mut text = query.to_string();
    let mut did_relational_split = false;
    for phrase in RELATIONAL.iter() {
        // ASCII-only lowercasing never changes byte lengths, so the offset
        // found here is a valid index into `text`. Full Unicode lowercasing
        // can grow or shrink characters and would make the slices below
        // panic or cut in the wrong place.
        if let Some(pos) = text.to_ascii_lowercase().find(phrase) {
            let left = text[..pos].trim().to_string();
            let right = text[pos + phrase.len()..].trim().to_string();
            if !left.is_empty() {
                parts.push(left);
            }
            // Always continue with the right-hand side, even when it is
            // empty; otherwise a trailing phrase would re-emit the whole text.
            text = right;
            did_relational_split = true;
        }
    }
    if did_relational_split && !text.is_empty() {
        parts.push(text);
    }

    if parts.is_empty() {
        if query.contains(';') {
            parts.extend(
                query
                    .split(';')
                    .map(str::trim)
                    .filter(|p| !p.is_empty())
                    .map(str::to_string),
            );
        } else {
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            if normalized.contains(',') {
                parts.extend(
                    normalized
                        .split(',')
                        .map(str::trim)
                        .filter(|p| !p.is_empty())
                        .map(str::to_string),
                );
            }
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    // The unsplit paths above always return one element, so keep at least one
    // here as well when `max` is 0.
    parts.truncate(max.max(1));
    parts
}
543
544fn reconstruct_path(
548 target_id: i64,
549 seed_entity_ids: &HashSet<i64>,
550 predecessor: &PredecessorMap,
551 entity_names: &HashMap<i64, String>,
552) -> Option<(Vec<EvidenceNode>, f64)> {
553 let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::new();
554 let mut total_weight = 1.0_f64;
555 let mut current = target_id;
556
557 loop {
558 if seed_entity_ids.contains(¤t) {
559 break;
560 }
561 let (parent, relation, weight) = predecessor.get(¤t)?;
562 total_weight *= weight;
563 path_ids.push((current, Some(relation.clone()), Some(*weight)));
564 current = *parent;
565 }
566 path_ids.push((current, None, None));
568
569 path_ids.reverse();
571
572 let nodes: Vec<EvidenceNode> = path_ids
573 .into_iter()
574 .map(|(id, relation, weight)| EvidenceNode {
575 entity: entity_names
576 .get(&id)
577 .cloned()
578 .unwrap_or_else(|| format!("entity-{id}")),
579 relation,
580 weight,
581 })
582 .collect();
583
584 Some((nodes, total_weight))
585}
586
587#[allow(clippy::too_many_arguments)]
597fn execute_sub_query(
598 sub_query_id: usize,
599 query_text: &str,
600 embedding: &[f32],
601 namespace: &str,
602 db_path: &std::path::Path,
603 k: usize,
604 max_hops: usize,
605 min_weight: f64,
606 rrf_k: f64,
607 graph_decay: f64,
608 graph_min_score: f64,
609 max_neighbors_per_hop: Option<usize>,
610) -> Result<SubQueryResult, String> {
611 let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
612
613 let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
614 Vec::with_capacity(k * 2);
615 let mut seen_ids: HashSet<i64> = HashSet::new();
616
617 let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
621 .map_err(|e| format!("knn_search failed: {e}"))?;
622 let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
623
624 let knn_distance_map: HashMap<i64, f64> = knn_results
626 .iter()
627 .map(|(id, dist)| (*id, *dist as f64))
628 .collect();
629
630 let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
632 Ok(rows) => rows,
633 Err(e) => {
634 tracing::warn!(
635 sub_query_id,
636 "FTS5 search failed, continuing with KNN only: {e}"
637 );
638 vec![]
639 }
640 };
641 let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
642
643 let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
645 let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
646
647 let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
649 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
650 fused.truncate(k * 2);
651
652 for (memory_id, combined_score) in &fused {
653 if seen_ids.insert(*memory_id) {
654 let normalized = if max_possible > 0.0 {
655 combined_score / max_possible
656 } else {
657 0.0
658 };
659 let score = normalized.clamp(0.0, 1.0);
660 let source = if knn_distance_map.contains_key(memory_id) {
661 "knn"
662 } else {
663 "fts"
664 };
665 if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
666 let snippet: String = row.body.chars().take(300).collect();
667 hits.push((
668 *memory_id,
669 score,
670 source.to_string(),
671 snippet,
672 row.body,
673 None,
674 ));
675 }
676 }
677 }
678
679 let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
682 let mut chains: Vec<EvidenceChain> = Vec::new();
683
684 if !memory_ids.is_empty() && max_hops > 0 {
685 let entity_knn = entities::knn_search(&conn, embedding, namespace, 5).unwrap_or_default();
687 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
688
689 let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
691 for &mem_id in &memory_ids {
692 let mut stmt = conn
693 .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
694 .map_err(|e| format!("prepare failed: {e}"))?;
695 let ids: Vec<i64> = stmt
696 .query_map(rusqlite::params![mem_id], |r| r.get(0))
697 .map_err(|e| format!("query failed: {e}"))?
698 .filter_map(|r| r.ok())
699 .collect();
700 seed_entity_ids.extend(ids);
701 }
702 seed_entity_ids.sort_unstable();
703 seed_entity_ids.dedup();
704
705 let all_seed_ids: Vec<i64> = memory_ids
706 .iter()
707 .chain(entity_ids.iter())
708 .copied()
709 .collect();
710
711 if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
713 &conn,
714 &all_seed_ids,
715 namespace,
716 min_weight,
717 max_hops as u32,
718 max_neighbors_per_hop,
719 ) {
720 let seed_score_map: HashMap<i64, f64> = fused
722 .iter()
723 .map(|(id, s)| {
724 let normalized = if max_possible > 0.0 {
725 s / max_possible
726 } else {
727 0.0
728 };
729 (*id, normalized.clamp(0.0, 1.0))
730 })
731 .collect();
732
733 for (graph_mem_id, hop) in graph_results {
734 if seen_ids.insert(graph_mem_id) {
735 let avg_seed_score: f64 = if seed_score_map.is_empty() {
740 0.5
741 } else {
742 let sum: f64 = seed_score_map.values().sum();
743 sum / seed_score_map.len() as f64
744 };
745 let graph_score =
746 (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
747
748 if graph_score < graph_min_score {
749 continue;
750 }
751
752 if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
753 let snippet: String = row.body.chars().take(300).collect();
754 hits.push((
755 graph_mem_id,
756 graph_score,
757 "graph".to_string(),
758 snippet,
759 row.body,
760 Some(hop as usize),
761 ));
762 }
763 }
764 }
765 }
766
767 if !seed_entity_ids.is_empty() {
770 let (entity_depth, predecessor) = bfs_with_predecessors(
771 &conn,
772 &seed_entity_ids,
773 namespace,
774 min_weight,
775 max_hops as u32,
776 max_neighbors_per_hop,
777 )
778 .unwrap_or_default();
779
780 let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
781
782 let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
784 let mut entity_names: HashMap<i64, String> = HashMap::new();
785 for &eid in &all_entity_ids {
786 let name_res: rusqlite::Result<String> = conn.query_row(
787 "SELECT name FROM entities WHERE id = ?1",
788 rusqlite::params![eid],
789 |r| r.get(0),
790 );
791 if let Ok(name) = name_res {
792 entity_names.insert(eid, name);
793 }
794 }
795
796 for (&target_id, &_hop) in &entity_depth {
798 if seed_entity_set.contains(&target_id) {
799 continue;
800 }
801 if !predecessor.contains_key(&target_id) {
802 continue;
803 }
804 if let Some((path_nodes, total_weight)) =
805 reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
806 {
807 if path_nodes.len() < 2 {
808 continue;
809 }
810 let from = path_nodes
811 .first()
812 .map(|n| n.entity.clone())
813 .unwrap_or_default();
814 let to = path_nodes
815 .last()
816 .map(|n| n.entity.clone())
817 .unwrap_or_default();
818 let depth = path_nodes.len();
819 chains.push(EvidenceChain {
820 from,
821 to,
822 path: path_nodes,
823 total_weight,
824 depth,
825 sub_query_ids: vec![sub_query_id],
826 });
827 }
828 }
829
830 chains.sort_by(|a, b| {
832 b.total_weight
833 .partial_cmp(&a.total_weight)
834 .unwrap_or(std::cmp::Ordering::Equal)
835 });
836 chains.truncate(20);
837 }
838 }
839
840 Ok(SubQueryResult {
841 sub_query_id,
842 hits,
843 chains,
844 })
845}
846
// Unit tests for query decomposition, response serialization, path
// reconstruction and the neighbour cap of the BFS helper.
#[cfg(test)]
mod tests {
    use super::*;

    // " and " is treated as a separator.
    #[test]
    fn test_decompose_and_conjunction() {
        let result = decompose_query("A and B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    // A query without any separator comes back unchanged.
    #[test]
    fn test_decompose_no_split() {
        let result = decompose_query("simple query", 7);
        assert_eq!(result, vec!["simple query"]);
    }

    // Commas and " and " can be mixed in one query.
    #[test]
    fn test_decompose_three_parts() {
        let result = decompose_query("A, B and C", 7);
        assert_eq!(result, vec!["A", "B", "C"]);
    }

    // " e " is treated as a separator as well.
    #[test]
    fn test_decompose_portuguese_conjunctions() {
        let result = decompose_query("A e B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    // The number of sub-queries never exceeds the requested maximum.
    #[test]
    fn test_decompose_max_cap() {
        let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
        let query = parts.join(", ");
        let result = decompose_query(&query, 7);
        assert!(
            result.len() <= 7,
            "expected at most 7 sub-queries, got {}",
            result.len()
        );
    }

    // An empty query is returned as a single empty sub-query.
    #[test]
    fn test_decompose_empty_preserves_original() {
        let result = decompose_query("", 7);
        assert_eq!(result, vec![""]);
    }

    // Semicolons split the query and each part is trimmed.
    #[test]
    fn test_decompose_semicolons() {
        let result = decompose_query("auth design; deployment config; logging", 7);
        assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
    }

    // A relational phrase splits the query and is itself removed.
    #[test]
    fn test_decompose_relational_phrase() {
        let result = decompose_query("auth that caused deployment failure", 7);
        assert_eq!(result, vec!["auth", "deployment failure"]);
    }

    // SubQuery serializes all three fields under their own names.
    #[test]
    fn test_sub_query_serialization() {
        let sq = SubQuery {
            id: 0,
            text: "test query".to_string(),
            source: "original",
        };
        let json = serde_json::to_value(&sq).expect("serialization failed");
        assert_eq!(json["id"], 0);
        assert_eq!(json["text"], "test query");
        assert_eq!(json["source"], "original");
    }

    // `body` is skipped in the JSON when it is None.
    #[test]
    fn test_deep_result_omits_body_when_none() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0],
            snippet: "snippet".to_string(),
            body: None,
            hop_distance: None,
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(!json.contains("\"body\""), "body must be omitted when None");
    }

    // `body` is present in the JSON when it is Some.
    #[test]
    fn test_deep_result_includes_body_when_some() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0, 1],
            snippet: "snippet".to_string(),
            body: Some("full body content".to_string()),
            hop_distance: Some(2),
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(json.contains("\"body\""), "body must be present when Some");
        assert!(json.contains("full body content"));
    }

    // `relation` and `weight` are skipped in the JSON when they are None.
    #[test]
    fn test_evidence_node_omits_none_fields() {
        let node = EvidenceNode {
            entity: "auth-module".to_string(),
            relation: None,
            weight: None,
        };
        let json = serde_json::to_string(&node).expect("serialization failed");
        assert!(
            !json.contains("\"relation\""),
            "relation must be omitted when None"
        );
        assert!(
            !json.contains("\"weight\""),
            "weight must be omitted when None"
        );
    }

    // ResearchStats counters round-trip through JSON under their field names.
    #[test]
    fn test_research_stats_serialization() {
        let stats = ResearchStats {
            sub_queries_total: 3,
            sub_queries_completed: 2,
            sub_queries_failed: 1,
            sub_queries_timed_out: 0,
            unique_memories_found: 10,
            evidence_chains_found: 2,
            elapsed_ms: 1234,
        };
        let json = serde_json::to_value(&stats).expect("serialization failed");
        assert_eq!(json["sub_queries_total"], 3);
        assert_eq!(json["sub_queries_completed"], 2);
        assert_eq!(json["sub_queries_failed"], 1);
        assert_eq!(json["elapsed_ms"], 1234);
    }

    // The full response exposes its collections as arrays and nests the stats.
    #[test]
    fn test_deep_research_response_serialization() {
        let resp = DeepResearchResponse {
            query: "test query".to_string(),
            sub_queries: vec![SubQuery {
                id: 0,
                text: "test query".to_string(),
                source: "original",
            }],
            results: vec![],
            evidence_chains: vec![],
            stats: ResearchStats {
                sub_queries_total: 1,
                sub_queries_completed: 1,
                sub_queries_failed: 0,
                sub_queries_timed_out: 0,
                unique_memories_found: 0,
                evidence_chains_found: 0,
                elapsed_ms: 42,
            },
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "test query");
        assert!(json["sub_queries"].is_array());
        assert!(json["results"].is_array());
        assert!(json["evidence_chains"].is_array());
        assert_eq!(json["stats"]["elapsed_ms"], 42);
    }

    // A semicolon split wins over " and ", so the second part keeps its "and".
    #[test]
    fn test_distinct_sub_queries_produce_distinct_texts() {
        let queries = [
            "authentication design decisions",
            "deployment configuration and infrastructure",
        ];
        assert_ne!(queries[0], queries[1]);

        let decomposed = decompose_query(
            "authentication design decisions; deployment configuration and infrastructure",
            7,
        );
        assert_eq!(decomposed.len(), 2);
        assert_ne!(decomposed[0], decomposed[1]);
    }

    // Ids present in both rankings must outscore ids present in only one.
    #[test]
    fn test_rrf_fuse_via_fusion_module() {
        use crate::storage::fusion::rrf_fuse;

        let knn_ids: Vec<i64> = vec![1, 2, 3];
        let fts_ids: Vec<i64> = vec![2, 1, 4];
        let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);

        let score_1 = scores[&1];
        let score_2 = scores[&2];
        let score_3 = scores[&3];
        let score_4 = scores[&4];
        assert!(
            score_1 > score_3,
            "id 1 (both lists) must beat id 3 (knn-only rank 3)"
        );
        assert!(
            score_2 > score_4,
            "id 2 (both lists) must beat id 4 (fts-only rank 3)"
        );
    }

    // An evidence chain serializes its endpoints, path, weight and depth.
    #[test]
    fn test_evidence_chain_has_from_to_and_path() {
        let chain = EvidenceChain {
            from: "auth-module".to_string(),
            to: "jwt-service".to_string(),
            path: vec![
                EvidenceNode {
                    entity: "auth-module".to_string(),
                    relation: None,
                    weight: None,
                },
                EvidenceNode {
                    entity: "token-validator".to_string(),
                    relation: Some("depends-on".to_string()),
                    weight: Some(0.9),
                },
                EvidenceNode {
                    entity: "jwt-service".to_string(),
                    relation: Some("uses".to_string()),
                    weight: Some(0.8),
                },
            ],
            total_weight: 0.72,
            depth: 3,
            sub_query_ids: vec![0],
        };

        let json = serde_json::to_value(&chain).expect("serialization failed");
        assert!(
            json["from"].is_string(),
            "evidence chain must have 'from' field"
        );
        assert!(
            json["to"].is_string(),
            "evidence chain must have 'to' field"
        );
        assert!(
            json["path"].is_array(),
            "evidence chain must have 'path' array"
        );
        assert_eq!(json["path"].as_array().unwrap().len(), 3);
        assert!(json["total_weight"].is_number(), "must have total_weight");
        assert_eq!(json["depth"], 3);
    }

    // The rebuilt path runs seed -> middle -> target and multiplies the weights.
    #[test]
    fn test_reconstruct_path_root_to_target_order() {
        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
        let mut predecessor: HashMap<i64, (i64, String, f64)> = HashMap::new();
        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
        predecessor.insert(30, (20, "uses".to_string(), 0.8));
        let mut entity_names: HashMap<i64, String> = HashMap::new();
        entity_names.insert(10, "seed-entity".to_string());
        entity_names.insert(20, "middle-entity".to_string());
        entity_names.insert(30, "target-entity".to_string());

        let result = reconstruct_path(30, &seed_set, &predecessor, &entity_names);
        assert!(result.is_some(), "path must be reconstructed");
        let (nodes, weight) = result.unwrap();
        assert_eq!(nodes.len(), 3);
        assert_eq!(nodes[0].entity, "seed-entity");
        assert_eq!(nodes[1].entity, "middle-entity");
        assert_eq!(nodes[2].entity, "target-entity");
        assert!((weight - 0.72).abs() < 1e-6);
    }

    // Applies the same `depth >= 2` predicate that `run_async` uses to a
    // single-node chain; this exercises the predicate, not `run_async` itself.
    #[test]
    fn test_evidence_chains_single_hop_filtered_out() {
        let chain = EvidenceChain {
            from: "a".to_string(),
            to: "a".to_string(),
            path: vec![EvidenceNode {
                entity: "a".to_string(),
                relation: None,
                weight: None,
            }],
            total_weight: 1.0,
            depth: 1,
            sub_query_ids: vec![0],
        };
        let chains = vec![chain];
        let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
        assert!(retained.is_empty(), "depth-1 chains must be filtered out");
    }

    // With five outgoing edges from entity 1, an uncapped BFS reaches all five
    // neighbours and a cap of 2 reaches only two (the seed is counted too).
    #[test]
    fn test_bfs_with_predecessors_respects_neighbor_cap() {
        use crate::graph::bfs_with_predecessors;
        use rusqlite::Connection;

        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related'
            );",
        )
        .unwrap();

        // Star graph: 1 -> 2, 1 -> 3, ..., 1 -> 6.
        for target in 2i64..=6 {
            conn.execute(
                "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
                rusqlite::params![1i64, target, 1.0f64],
            )
            .unwrap();
        }

        let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
        assert_eq!(
            depth_uncapped.len() - 1,
            5,
            "uncapped must discover all 5 neighbours (plus seed)"
        );

        let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
        assert_eq!(
            depth_capped.len(),
            3,
            "capped to 2 must yield seed + 2 neighbours"
        );
    }
}