1use crate::errors::AppError;
8use crate::graph::{
9 bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::HashSet;
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
// CLI arguments for the `deep-research` subcommand (consumed by `run`).
//
// These are plain `//` comments on purpose: clap's derive turns `///` doc comments
// into about/help/long-help text, so converting them would change `--help` output.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n \
    # Basic deep research\n \
    sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
    # With custom parameters\n \
    sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
    # Include full memory bodies in output\n \
    sqlite-graphrag deep-research \"auth\" --with-bodies\n\n \
    # Tune RRF and graph scoring\n \
    sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
    // Free-text research question. `run_async` rejects it when it is blank after
    // trimming; hyphen-leading values are allowed so queries like "-foo" parse.
    #[arg(
        value_name = "QUERY",
        allow_hyphen_values = true,
        help = "Research query to decompose and search"
    )]
    pub query: String,
    // Passed as the result limit to both the vector KNN search and the FTS5 search of
    // every sub-query; the fused list is then cut to `2 * k` before graph expansion.
    #[arg(
        long,
        short,
        aliases = ["limit", "top-k"],
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    // Upper bound handed to `decompose_query`.
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    // Hop limit for both memory graph expansion and the entity BFS that builds
    // evidence chains; 0 skips both.
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    // Forwarded as `min_weight` to the graph traversal and BFS helpers.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    // When absent, `run_async` uses min(available CPUs, 8); the value is then capped
    // at the number of sub-queries and raised to at least 1.
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    // Per-sub-query time limit in seconds, applied by `run_async`.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    // When false, result bodies are omitted from the JSON (only snippets are sent).
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    // Number of merged results kept after de-duplicating across sub-queries.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    // `k` constant for reciprocal-rank fusion (passed to `rrf_fuse`/`rrf_max_possible`).
    #[arg(
        long,
        default_value_t = 60.0,
        help = "RRF k parameter (higher = less weight on top ranks)"
    )]
    pub rrf_k: f64,
    // Graph hits are scored as `avg_seed_score * graph_decay^hop`. The help text
    // states 0.0-1.0 but the range is not validated here.
    #[arg(
        long,
        default_value_t = 0.7,
        help = "Graph score decay factor per hop (0.0-1.0)"
    )]
    pub graph_decay: f64,
    // Graph-expanded memories scoring below this are dropped.
    #[arg(
        long,
        default_value_t = 0.05,
        help = "Minimum score threshold for graph-expanded results"
    )]
    pub graph_min_score: f64,
    // Optional per-hop fan-out cap forwarded to the graph traversal and BFS helpers.
    #[arg(
        long,
        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
    )]
    pub max_neighbors_per_hop: Option<usize>,
    // Resolved through `crate::namespace::resolve_namespace`; the env-var/default
    // fallback named in the help text is presumably implemented there.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Hidden; the value parser accepts only "none".
    // NOTE(review): `--max-cost-usd` help mentions claude-code/codex modes that this
    // parser rejects — confirm whether those modes are planned or the help is stale.
    #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
    pub mode: String,
    // In the visible code this only triggers a warning when `mode` is "none".
    #[arg(
        long,
        value_name = "USD",
        help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
    )]
    pub max_cost_usd: Option<f64>,
    // Hidden flag not read by `run_async`; presumably kept for CLI compatibility
    // (output is always JSON here) — TODO confirm.
    #[arg(long, hide = true)]
    pub json: bool,
    // Database path override, also read from SQLITE_GRAPHRAG_DB_PATH; resolved by
    // `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
145
/// One sub-query produced by `decompose_query`, echoed back in the response.
#[derive(Serialize)]
struct SubQuery {
    /// Zero-based position in the decomposition; referenced by `sub_query_ids`.
    id: usize,
    /// The sub-query text that was embedded and searched.
    text: String,
    /// `"original"` when decomposition yielded a single query, otherwise `"decomposed"`.
    source: &'static str,
}
152
/// One memory in the final, merged result list.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name as read from storage.
    name: String,
    /// Best score this memory received across all sub-queries (clamped to [0, 1]
    /// when it was computed).
    score: f64,
    /// Origin of the best-scoring hit: `"hybrid"`, `"knn"`, `"fts"` or `"graph"`.
    source: String,
    /// Ids of every sub-query that surfaced this memory.
    sub_query_ids: Vec<usize>,
    /// First 300 characters of the memory body.
    snippet: String,
    /// Full body, only present with `--with-bodies`.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// Graph hop count for graph-expanded hits; serialized as `null` otherwise.
    hop_distance: Option<usize>,
}
164
/// One entity on an evidence path.
#[derive(Serialize, Clone)]
struct EvidenceNode {
    /// Entity name, or `"entity-{id}"` when the name could not be loaded.
    entity: String,
    /// Relation of the edge leading into this node; absent for the starting seed.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Weight of the edge leading into this node; absent for the starting seed.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
174
/// A path through the entity graph from a seed entity to an entity reached by BFS.
#[derive(Serialize)]
struct EvidenceChain {
    /// Entity name of the first node (the seed).
    from: String,
    /// Entity name of the last node (the reached entity).
    to: String,
    /// Nodes in order from `from` to `to`.
    path: Vec<EvidenceNode>,
    /// Product of the edge weights along `path`.
    total_weight: f64,
    /// Number of nodes in `path` (i.e. edges + 1), not the number of hops.
    depth: usize,
    /// Sub-queries whose BFS produced this chain.
    sub_query_ids: Vec<usize>,
}
192
/// Run statistics attached to every response.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries that returned a result.
    sub_queries_completed: usize,
    /// Sub-queries that failed for a reason other than timing out (including task panics).
    sub_queries_failed: usize,
    /// Sub-queries reported as timed out.
    sub_queries_timed_out: usize,
    /// Length of the final `results` list.
    unique_memories_found: usize,
    /// Length of the final `evidence_chains` list.
    evidence_chains_found: usize,
    /// Wall-clock time from the start of `run_async` to emission.
    elapsed_ms: u64,
    /// True when at least one sub-query fell back to FTS5 because embedding failed.
    vec_degraded: bool,
}
204
/// An entity whose name matches one of the result memories' names.
#[derive(Serialize)]
struct GraphContextEntity {
    /// Entity (and matching memory) name.
    name: String,
    /// Entity `type` column, `"concept"` when NULL or unreadable.
    entity_type: String,
    /// Count of relationships with this entity as source or target.
    degree: u32,
}
211
/// A relationship whose both endpoints are graph-context entities.
#[derive(Serialize)]
struct GraphContextRel {
    /// Source entity name.
    from: String,
    /// Target entity name.
    to: String,
    /// Relation label.
    relation: String,
    /// Relationship weight.
    weight: f64,
}
219
/// Small subgraph around the result memories: matched entities plus the
/// relationships among them (at most 50).
#[derive(Serialize)]
struct GraphContext {
    entities: Vec<GraphContextEntity>,
    relationships: Vec<GraphContextRel>,
}
225
/// Top-level JSON document emitted by `deep-research`.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The query exactly as given on the command line.
    query: String,
    /// Decomposed sub-queries, in id order.
    sub_queries: Vec<SubQuery>,
    /// Merged results, highest score first.
    results: Vec<DeepResult>,
    /// Evidence chains, strongest (`total_weight`) first.
    evidence_chains: Vec<EvidenceChain>,
    /// Omitted when no result name matched an entity.
    #[serde(skip_serializing_if = "Option::is_none")]
    graph_context: Option<GraphContext>,
    stats: ResearchStats,
}
236
/// A memory hit merged across sub-queries:
/// `(best_score, source, snippet, body, hop_distance, sub_query_ids)`, where the
/// first five fields come from the best-scoring hit and the last lists every
/// sub-query that found the memory.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
239
/// Output of one `execute_sub_query` call.
struct SubQueryResult {
    /// Id of the sub-query that produced this result.
    sub_query_id: usize,
    /// `(memory_id, score, source, snippet, full_body, hop_distance)` per hit, where
    /// `hop_distance` is set only for graph-expanded hits.
    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
    /// Evidence chains found from this sub-query's seed entities.
    chains: Vec<EvidenceChain>,
}
248
249#[tracing::instrument(skip_all, level = "debug", name = "deep_research")]
251pub fn run(
252 args: DeepResearchArgs,
253 llm_backend: crate::cli::LlmBackendChoice,
254 embedding_backend: crate::cli::EmbeddingBackendChoice,
255) -> Result<(), AppError> {
256 tracing::debug!(target: "deep_research", query = %args.query, k = args.k, "starting deep research");
257 let rt = tokio::runtime::Builder::new_multi_thread()
258 .worker_threads(2)
259 .enable_all()
260 .build()
261 .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
262 rt.block_on(run_async(args, llm_backend, embedding_backend))
263}
264
265async fn run_async(
267 args: DeepResearchArgs,
268 llm_backend: crate::cli::LlmBackendChoice,
269 embedding_backend: crate::cli::EmbeddingBackendChoice,
270) -> Result<(), AppError> {
271 let start = std::time::Instant::now();
272
273 if args.query.trim().is_empty() {
274 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
275 }
276
277 if args.max_cost_usd.is_some() && args.mode == "none" {
278 tracing::warn!(target: "deep_research", "--max-cost-usd has no effect without --mode claude-code/codex");
279 }
280
281 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
282 let paths = AppPaths::resolve(args.db.as_deref())?;
283 crate::storage::connection::ensure_db_ready(&paths)?;
284
285 let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
287 let sub_queries: Vec<SubQuery> = sub_query_texts
288 .iter()
289 .enumerate()
290 .map(|(i, text)| SubQuery {
291 id: i,
292 text: text.clone(),
293 source: if sub_query_texts.len() == 1 {
294 "original"
295 } else {
296 "decomposed"
297 },
298 })
299 .collect();
300
301 output::emit_progress_i18n(
306 "Computing per-sub-query embeddings...",
307 "Calculando embeddings por sub-consulta...",
308 );
309 let mut sub_embeddings: Vec<Option<Arc<Vec<f32>>>> = Vec::with_capacity(sub_query_texts.len());
310 let mut vec_degraded = false;
311 for sq_text in &sub_query_texts {
312 match crate::embedder::try_embed_query_with_embedding_choice(
313 &paths.models,
314 sq_text,
315 embedding_backend,
316 llm_backend,
317 ) {
318 Ok((v, _backend)) => sub_embeddings.push(Some(Arc::new(v))),
319 Err(reason) => {
320 tracing::warn!(target: "deep_research", fallback_reason = %reason, reason_code = %reason.reason_code(), "embedding failed for sub-query; falling back to FTS5");
321 sub_embeddings.push(None);
322 vec_degraded = true;
323 }
324 }
325 }
326
327 let cpu_count = std::thread::available_parallelism()
329 .map(|n| n.get())
330 .unwrap_or(4);
331 let permits = args
332 .max_concurrency
333 .unwrap_or_else(|| cpu_count.min(8))
334 .min(sub_queries.len())
335 .max(1);
336 let semaphore = Arc::new(Semaphore::new(permits));
337 let timeout_dur = std::time::Duration::from_secs(args.timeout);
338
339 let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
340
341 for (idx, sq_text) in sub_query_texts.iter().enumerate() {
342 let sem = Arc::clone(&semaphore);
343 let emb = sub_embeddings[idx].clone();
345 let ns = namespace.clone();
346 let db_path = paths.db.clone();
347 let query_text = sq_text.clone();
348 let k = args.k;
349 let max_hops = args.max_hops;
350 let min_weight = args.min_weight;
351 let rrf_k = args.rrf_k;
352 let graph_decay = args.graph_decay;
353 let graph_min_score = args.graph_min_score;
354 let max_neighbors_per_hop = args.max_neighbors_per_hop;
355
356 join_set.spawn(async move {
357 let _permit = sem
358 .acquire_owned()
359 .await
360 .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
361
362 let result = tokio::time::timeout(timeout_dur, async move {
364 execute_sub_query(
365 idx,
366 &query_text,
367 emb.as_ref().map(|v| v.as_slice()),
368 &ns,
369 &db_path,
370 k,
371 max_hops,
372 min_weight,
373 rrf_k,
374 graph_decay,
375 graph_min_score,
376 max_neighbors_per_hop,
377 )
378 })
379 .await;
380
381 match result {
382 Ok(inner) => inner.map_err(|e| (idx, e)),
383 Err(_) => Err((idx, "timeout".to_string())),
384 }
385 });
386 }
387
388 let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
390 let mut failed_count = 0usize;
391 let mut timed_out_count = 0usize;
392
393 while let Some(join_result) = join_set.join_next().await {
394 match join_result {
395 Ok(Ok(sqr)) => sub_query_results.push(sqr),
396 Ok(Err((_idx, reason))) => {
397 if reason == "timeout" {
398 timed_out_count += 1;
399 } else {
400 failed_count += 1;
401 }
402 tracing::warn!(target: "deep_research", sub_query_id = _idx, reason = %reason, "sub-query failed");
403 }
404 Err(join_err) => {
405 failed_count += 1;
406 if join_err.is_panic() {
407 tracing::error!(target: "deep_research", error = %join_err, "sub-query task panicked");
408 } else {
409 tracing::warn!(target: "deep_research", error = %join_err, "sub-query task cancelled");
410 }
411 }
412 }
413 }
414
415 let mut merged: crate::hash::AHashMap<i64, MergedHit> =
418 crate::hash::AHashMap::with_capacity_and_hasher(
419 sub_query_results.len() * args.k,
420 Default::default(),
421 );
422
423 for sqr in &sub_query_results {
424 for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
425 let entry = merged.entry(*mem_id).or_insert_with(|| {
426 (
427 *score,
428 source.clone(),
429 snippet.clone(),
430 body.clone(),
431 *hop,
432 Vec::new(),
433 )
434 });
435 if *score > entry.0 {
437 entry.0 = *score;
438 entry.1 = source.clone();
439 entry.2 = snippet.clone();
440 entry.3 = body.clone();
441 entry.4 = *hop;
442 }
443 if !entry.5.contains(&sqr.sub_query_id) {
444 entry.5.push(sqr.sub_query_id);
445 }
446 }
447 }
448
449 let conn = open_ro(&paths.db)?;
451 let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
452
453 let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
455 ranked.sort_by(|a, b| {
456 b.1 .0
457 .partial_cmp(&a.1 .0)
458 .unwrap_or(std::cmp::Ordering::Equal)
459 });
460 ranked.truncate(args.max_results);
461
462 for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
463 let name = match memories::read_full(&conn, mem_id)? {
464 Some(row) => row.name,
465 None => continue,
466 };
467 results.push(DeepResult {
468 name,
469 score,
470 source,
471 sub_query_ids: sq_ids,
472 snippet,
473 body: if args.with_bodies { Some(body) } else { None },
474 hop_distance: hop,
475 });
476 }
477
478 let completed_count = sub_query_results.len();
482 let mut evidence_chains: Vec<EvidenceChain> = Vec::with_capacity(completed_count * 2);
483 let mut seen_chain_keys: HashSet<String> = HashSet::with_capacity(completed_count * 2);
484
485 for sqr in sub_query_results {
486 for chain in sqr.chains {
487 let key = format!("{}->{}", chain.from, chain.to);
489 if seen_chain_keys.insert(key) {
490 evidence_chains.push(chain);
491 }
492 }
493 }
494
495 evidence_chains.retain(|c| c.depth >= 2);
497 evidence_chains.sort_by(|a, b| {
498 b.total_weight
499 .partial_cmp(&a.total_weight)
500 .unwrap_or(std::cmp::Ordering::Equal)
501 });
502
503 let unique_memories = results.len();
504 let evidence_count = evidence_chains.len();
505
506 let graph_context = if !results.is_empty() {
508 let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
509 let mut ctx_entities: Vec<GraphContextEntity> = Vec::with_capacity(results.len());
510 let mut ctx_rels: Vec<GraphContextRel> = Vec::with_capacity(results.len() * 2);
511 let mut seen_entity_ids: crate::hash::AHashSet<i64> =
512 crate::hash::AHashSet::with_capacity_and_hasher(results.len(), Default::default());
513
514 for name in &result_names {
515 if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
516 if seen_entity_ids.insert(eid) {
517 let etype: String = conn
518 .query_row(
519 "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
520 rusqlite::params![eid],
521 |r| r.get(0),
522 )
523 .unwrap_or_else(|_| "concept".to_string());
524 let degree: u32 = conn
525 .query_row(
526 "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
527 rusqlite::params![eid],
528 |r| r.get(0),
529 )
530 .unwrap_or(0);
531 ctx_entities.push(GraphContextEntity {
532 name: name.to_string(),
533 entity_type: etype,
534 degree,
535 });
536 }
537 }
538 }
539
540 let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
541 if entity_ids.len() >= 2 {
542 let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
543 let sql = format!(
544 "SELECT s.name, t.name, r.relation, r.weight \
545 FROM relationships r \
546 JOIN entities s ON s.id = r.source_id \
547 JOIN entities t ON t.id = r.target_id \
548 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
549 LIMIT 50"
550 );
551 if let Ok(mut stmt) = conn.prepare(&sql) {
552 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> =
553 Vec::with_capacity(entity_ids.len() * 2);
554 for id in &entity_ids {
555 params.push(Box::new(*id));
556 }
557 for id in &entity_ids {
558 params.push(Box::new(*id));
559 }
560 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
561 params.iter().map(|p| p.as_ref()).collect();
562 if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
563 Ok((
564 r.get::<_, String>(0)?,
565 r.get::<_, String>(1)?,
566 r.get::<_, String>(2)?,
567 r.get::<_, f64>(3)?,
568 ))
569 }) {
570 for row in rows.flatten() {
571 ctx_rels.push(GraphContextRel {
572 from: row.0,
573 to: row.1,
574 relation: row.2,
575 weight: row.3,
576 });
577 }
578 }
579 }
580 }
581
582 if ctx_entities.is_empty() {
583 None
584 } else {
585 Some(GraphContext {
586 entities: ctx_entities,
587 relationships: ctx_rels,
588 })
589 }
590 } else {
591 None
592 };
593
594 tracing::debug!(target: "deep_research",
595 total_results = results.len(),
596 total_chains = evidence_chains.len(),
597 "assembly complete"
598 );
599
600 output::emit_json(&DeepResearchResponse {
602 query: args.query,
603 sub_queries,
604 results,
605 evidence_chains,
606 graph_context,
607 stats: ResearchStats {
608 sub_queries_total: sub_query_texts.len(),
609 sub_queries_completed: completed_count,
610 sub_queries_failed: failed_count,
611 sub_queries_timed_out: timed_out_count,
612 unique_memories_found: unique_memories,
613 evidence_chains_found: evidence_count,
614 elapsed_ms: start.elapsed().as_millis() as u64,
615 vec_degraded,
616 },
617 })?;
618
619 Ok(())
620}
621
/// Split a research query into at most `max` sub-queries using simple heuristics.
///
/// Strategies are tried in order; the first one that produces any parts wins:
/// 1. Relational phrases such as `" that caused "` or `" related to "`, matched
///    ASCII-case-insensitively. Each phrase is tried once, in list order, against the
///    text remaining after the previous split.
/// 2. Splitting on `;`.
/// 3. Splitting on `,` after rewriting the conjunctions `" and "` / `" AND "` and the
///    Portuguese `" e "` / `" E "` to `", "`.
/// 4. For queries with at least three words longer than two bytes: the whole query,
///    its first two such words, and its last two such words.
///
/// When no strategy applies (or `query` is empty), the query itself is the only
/// element. Duplicate sub-queries are removed (first occurrence kept), and the result
/// always contains at least one element, even when `max` is 0.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    if query.is_empty() {
        return vec![query.to_string()];
    }

    const RELATIONAL: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    // Not pre-sized with `max`: it is a user-supplied bound and may be enormous.
    let mut parts: Vec<String> = Vec::new();

    let mut text = query.to_string();
    let mut did_relational_split = false;
    for &phrase in &RELATIONAL {
        // ASCII lowercasing keeps every byte offset identical to `text`, so `pos` is a
        // valid char boundary there. Full Unicode lowercasing can change byte lengths
        // (e.g. KELVIN SIGN -> 'k'), which made the slicing below panic.
        let Some(pos) = text.to_ascii_lowercase().find(phrase) else {
            continue;
        };
        let left = text[..pos].trim().to_string();
        let right = text[pos + phrase.len()..].trim().to_string();
        if !left.is_empty() {
            parts.push(left);
        }
        if !right.is_empty() {
            text = right;
        }
        did_relational_split = true;
    }
    if did_relational_split && !text.is_empty() {
        parts.push(text);
    }

    if parts.is_empty() {
        if query.contains(';') {
            parts.extend(
                query
                    .split(';')
                    .map(str::trim)
                    .filter(|p| !p.is_empty())
                    .map(str::to_string),
            );
        } else {
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            if normalized.contains(',') {
                parts.extend(
                    normalized
                        .split(',')
                        .map(str::trim)
                        .filter(|p| !p.is_empty())
                        .map(str::to_string),
                );
            }
        }
    }

    if parts.is_empty() {
        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
        if words.len() >= 3 {
            parts.push(query.to_string());
            parts.push(format!("{} {}", words[0], words[1]));
            parts.push(format!(
                "{} {}",
                words[words.len() - 2],
                words[words.len() - 1]
            ));
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    // Searching the same text twice only wastes work and duplicates sub-query ids.
    let mut seen: HashSet<String> = HashSet::with_capacity(parts.len());
    parts.retain(|p| seen.insert(p.clone()));
    // Like the single-query paths above, never return an empty decomposition.
    parts.truncate(max.max(1));
    parts
}
716
717fn reconstruct_path(
721 target_id: i64,
722 seed_entity_ids: &HashSet<i64>,
723 predecessor: &PredecessorMap,
724 entity_names: &crate::hash::AHashMap<i64, String>,
725) -> Option<(Vec<EvidenceNode>, f64)> {
726 let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::with_capacity(8);
727 let mut total_weight = 1.0_f64;
728 let mut current = target_id;
729
730 loop {
731 if seed_entity_ids.contains(¤t) {
732 break;
733 }
734 let (parent, relation, weight) = predecessor.get(¤t)?;
735 total_weight *= weight;
736 path_ids.push((current, Some(relation.clone()), Some(*weight)));
737 current = *parent;
738 }
739 path_ids.push((current, None, None));
741
742 path_ids.reverse();
744
745 let nodes: Vec<EvidenceNode> = path_ids
746 .into_iter()
747 .map(|(id, relation, weight)| EvidenceNode {
748 entity: entity_names
749 .get(&id)
750 .cloned()
751 .unwrap_or_else(|| format!("entity-{id}")),
752 relation,
753 weight,
754 })
755 .collect();
756
757 Some((nodes, total_weight))
758}
759
760#[allow(clippy::too_many_arguments)]
770fn execute_sub_query(
771 sub_query_id: usize,
772 query_text: &str,
773 embedding: Option<&[f32]>,
774 namespace: &str,
775 db_path: &std::path::Path,
776 k: usize,
777 max_hops: usize,
778 min_weight: f64,
779 rrf_k: f64,
780 graph_decay: f64,
781 graph_min_score: f64,
782 max_neighbors_per_hop: Option<usize>,
783) -> Result<SubQueryResult, String> {
784 let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
785
786 let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
787 Vec::with_capacity(k * 2);
788 let mut seen_ids: crate::hash::AHashSet<i64> =
789 crate::hash::AHashSet::with_capacity_and_hasher(k * 2, Default::default());
790
791 let (knn_ids, knn_distance_map) = if let Some(emb) = embedding {
795 let knn_results = memories::knn_search(&conn, emb, &[namespace.to_string()], None, k)
796 .map_err(|e| format!("knn_search failed: {e}"))?;
797 let ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
798 tracing::debug!(target: "deep_research", sub_query_id, knn_count = ids.len(), "KNN complete");
799 let dist_map: crate::hash::AHashMap<i64, f64> = knn_results
800 .iter()
801 .map(|(id, dist)| (*id, *dist as f64))
802 .collect();
803 (ids, dist_map)
804 } else {
805 tracing::debug!(target: "deep_research", sub_query_id, "KNN skipped (no embedding); FTS5-only");
806 (vec![], crate::hash::AHashMap::default())
807 };
808
809 let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
811 Ok(rows) => rows,
812 Err(e) => {
813 tracing::warn!(target: "deep_research",
814 sub_query_id,
815 "FTS5 search failed, continuing with KNN only: {e}"
816 );
817 vec![]
818 }
819 };
820 let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
821 tracing::debug!(target: "deep_research", sub_query_id, fts_count = fts_ids.len(), "FTS complete");
822
823 let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
825 let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
826
827 let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
829 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
830 fused.truncate(k * 2);
831 tracing::debug!(target: "deep_research",
832 sub_query_id,
833 fused_count = fused.len(),
834 "RRF fusion complete"
835 );
836
837 if fused.is_empty() && !knn_ids.is_empty() {
838 tracing::warn!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
839 "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
840 }
841
842 for (memory_id, combined_score) in &fused {
843 if seen_ids.insert(*memory_id) {
844 let normalized = if max_possible > 0.0 {
845 combined_score / max_possible
846 } else {
847 0.0
848 };
849 let score = normalized.clamp(0.0, 1.0);
850 let in_knn = knn_distance_map.contains_key(memory_id);
851 let in_fts = fts_ids.contains(memory_id);
852 let source = match (in_knn, in_fts) {
853 (true, true) => "hybrid",
854 (true, false) => "knn",
855 (false, true) => "fts",
856 (false, false) => "graph",
857 };
858 if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
859 let snippet: String = row.body.chars().take(300).collect();
860 hits.push((
861 *memory_id,
862 score,
863 source.to_string(),
864 snippet,
865 row.body,
866 None,
867 ));
868 }
869 }
870 }
871
872 let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
875 let mut chains: Vec<EvidenceChain> = Vec::with_capacity(memory_ids.len());
876
877 if !memory_ids.is_empty() && max_hops > 0 {
878 let entity_ids: Vec<i64> = if let Some(emb) = embedding {
880 entities::knn_search(&conn, emb, namespace, 5)
881 .inspect_err(|e| tracing::warn!(target: "deep_research", error = %e, "entity KNN search failed, skipping graph seed"))
882 .unwrap_or_default()
883 .iter()
884 .map(|(id, _)| *id)
885 .collect()
886 } else {
887 vec![]
888 };
889
890 let top_seed_count = 5.min(memory_ids.len());
893 let top_memory_ids = &memory_ids[..top_seed_count];
894 let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
895 for &mem_id in top_memory_ids {
896 let mut stmt = conn
897 .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
898 .map_err(|e| format!("prepare failed: {e}"))?;
899 let ids: Vec<i64> = stmt
900 .query_map(rusqlite::params![mem_id], |r| r.get(0))
901 .map_err(|e| format!("query failed: {e}"))?
902 .filter_map(|r| r.ok())
903 .collect();
904 seed_entity_ids.extend(ids);
905 }
906 seed_entity_ids.sort_unstable();
907 seed_entity_ids.dedup();
908 tracing::debug!(target: "deep_research",
909 sub_query_id,
910 seed_count = seed_entity_ids.len(),
911 "seed entities collected"
912 );
913
914 let all_seed_ids: Vec<i64> = memory_ids
915 .iter()
916 .chain(entity_ids.iter())
917 .copied()
918 .collect();
919
920 if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
922 &conn,
923 &all_seed_ids,
924 namespace,
925 min_weight,
926 max_hops as u32,
927 max_neighbors_per_hop,
928 ) {
929 let seed_score_map: crate::hash::AHashMap<i64, f64> = fused
931 .iter()
932 .map(|(id, s)| {
933 let normalized = if max_possible > 0.0 {
934 s / max_possible
935 } else {
936 0.0
937 };
938 (*id, normalized.clamp(0.0, 1.0))
939 })
940 .collect();
941
942 for (graph_mem_id, hop) in graph_results {
943 if seen_ids.insert(graph_mem_id) {
944 let avg_seed_score: f64 = if seed_score_map.is_empty() {
949 0.5
950 } else {
951 let sum: f64 = seed_score_map.values().sum();
952 sum / seed_score_map.len() as f64
953 };
954 let graph_score =
955 (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
956
957 if graph_score < graph_min_score {
958 continue;
959 }
960
961 if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
962 let snippet: String = row.body.chars().take(300).collect();
963 hits.push((
964 graph_mem_id,
965 graph_score,
966 "graph".to_string(),
967 snippet,
968 row.body,
969 Some(hop as usize),
970 ));
971 }
972 }
973 }
974 }
975
976 if !seed_entity_ids.is_empty() {
979 let (entity_depth, predecessor) = bfs_with_predecessors(
980 &conn,
981 &seed_entity_ids,
982 namespace,
983 min_weight,
984 max_hops as u32,
985 max_neighbors_per_hop,
986 )
987 .unwrap_or_default();
988
989 tracing::debug!(target: "deep_research",
990 sub_query_id,
991 bfs_nodes = entity_depth.len(),
992 predecessors = predecessor.len(),
993 "BFS complete"
994 );
995
996 let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
997
998 let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
1000 let mut entity_names: crate::hash::AHashMap<i64, String> =
1001 crate::hash::AHashMap::with_capacity_and_hasher(
1002 all_entity_ids.len(),
1003 ahash::RandomState::default(),
1004 );
1005 for &eid in &all_entity_ids {
1006 let name_res: rusqlite::Result<String> = conn.query_row(
1007 "SELECT name FROM entities WHERE id = ?1",
1008 rusqlite::params![eid],
1009 |r| r.get(0),
1010 );
1011 if let Ok(name) = name_res {
1012 entity_names.insert(eid, name);
1013 }
1014 }
1015
1016 for (&target_id, &_hop) in &entity_depth {
1018 if seed_entity_set.contains(&target_id) {
1019 continue;
1020 }
1021 if !predecessor.contains_key(&target_id) {
1022 continue;
1023 }
1024 if let Some((path_nodes, total_weight)) =
1025 reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
1026 {
1027 if path_nodes.len() < 2 {
1028 continue;
1029 }
1030 let from = path_nodes
1031 .first()
1032 .map(|n| n.entity.clone())
1033 .unwrap_or_default();
1034 let to = path_nodes
1035 .last()
1036 .map(|n| n.entity.clone())
1037 .unwrap_or_default();
1038 let depth = path_nodes.len();
1039 chains.push(EvidenceChain {
1040 from,
1041 to,
1042 path: path_nodes,
1043 total_weight,
1044 depth,
1045 sub_query_ids: vec![sub_query_id],
1046 });
1047 }
1048 }
1049
1050 chains.sort_by(|a, b| {
1052 b.total_weight
1053 .partial_cmp(&a.total_weight)
1054 .unwrap_or(std::cmp::Ordering::Equal)
1055 });
1056 chains.truncate(20);
1057 tracing::debug!(target: "deep_research",
1058 sub_query_id,
1059 chains_count = chains.len(),
1060 "evidence chains built"
1061 );
1062 }
1063 }
1064
1065 Ok(SubQueryResult {
1066 sub_query_id,
1067 hits,
1068 chains,
1069 })
1070}
1071
1072#[cfg(test)]
1078mod tests {
1079 use super::*;
1080
1081 #[test]
1082 fn test_decompose_and_conjunction() {
1083 let result = decompose_query("A and B", 7);
1084 assert_eq!(result, vec!["A", "B"]);
1085 }
1086
1087 #[test]
1088 fn test_decompose_no_split() {
1089 let result = decompose_query("simple query", 7);
1090 assert_eq!(result, vec!["simple query"]);
1091 }
1092
1093 #[test]
1094 fn test_decompose_three_parts() {
1095 let result = decompose_query("A, B and C", 7);
1096 assert_eq!(result, vec!["A", "B", "C"]);
1097 }
1098
1099 #[test]
1100 fn test_decompose_portuguese_conjunctions() {
1101 let result = decompose_query("A e B", 7);
1102 assert_eq!(result, vec!["A", "B"]);
1103 }
1104
1105 #[test]
1106 fn test_decompose_max_cap() {
1107 let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
1108 let query = parts.join(", ");
1109 let result = decompose_query(&query, 7);
1110 assert!(
1111 result.len() <= 7,
1112 "expected at most 7 sub-queries, got {}",
1113 result.len()
1114 );
1115 }
1116
1117 #[test]
1118 fn test_decompose_empty_preserves_original() {
1119 let result = decompose_query("", 7);
1120 assert_eq!(result, vec![""]);
1121 }
1122
1123 #[test]
1124 fn test_decompose_semicolons() {
1125 let result = decompose_query("auth design; deployment config; logging", 7);
1126 assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
1127 }
1128
1129 #[test]
1130 fn test_decompose_relational_phrase() {
1131 let result = decompose_query("auth that caused deployment failure", 7);
1132 assert_eq!(result, vec!["auth", "deployment failure"]);
1133 }
1134
1135 #[test]
1136 fn test_sub_query_serialization() {
1137 let sq = SubQuery {
1138 id: 0,
1139 text: "test query".to_string(),
1140 source: "original",
1141 };
1142 let json = serde_json::to_value(&sq).expect("serialization failed");
1143 assert_eq!(json["id"], 0);
1144 assert_eq!(json["text"], "test query");
1145 assert_eq!(json["source"], "original");
1146 }
1147
1148 #[test]
1149 fn test_deep_result_omits_body_when_none() {
1150 let result = DeepResult {
1151 name: "test".to_string(),
1152 score: 0.9,
1153 source: "knn".to_string(),
1154 sub_query_ids: vec![0],
1155 snippet: "snippet".to_string(),
1156 body: None,
1157 hop_distance: None,
1158 };
1159 let json = serde_json::to_string(&result).expect("serialization failed");
1160 assert!(!json.contains("\"body\""), "body must be omitted when None");
1161 }
1162
1163 #[test]
1164 fn test_deep_result_includes_body_when_some() {
1165 let result = DeepResult {
1166 name: "test".to_string(),
1167 score: 0.9,
1168 source: "knn".to_string(),
1169 sub_query_ids: vec![0, 1],
1170 snippet: "snippet".to_string(),
1171 body: Some("full body content".to_string()),
1172 hop_distance: Some(2),
1173 };
1174 let json = serde_json::to_string(&result).expect("serialization failed");
1175 assert!(json.contains("\"body\""), "body must be present when Some");
1176 assert!(json.contains("full body content"));
1177 }
1178
1179 #[test]
1180 fn test_evidence_node_omits_none_fields() {
1181 let node = EvidenceNode {
1182 entity: "auth-module".to_string(),
1183 relation: None,
1184 weight: None,
1185 };
1186 let json = serde_json::to_string(&node).expect("serialization failed");
1187 assert!(
1188 !json.contains("\"relation\""),
1189 "relation must be omitted when None"
1190 );
1191 assert!(
1192 !json.contains("\"weight\""),
1193 "weight must be omitted when None"
1194 );
1195 }
1196
1197 #[test]
1198 fn test_research_stats_serialization() {
1199 let stats = ResearchStats {
1200 sub_queries_total: 3,
1201 sub_queries_completed: 2,
1202 sub_queries_failed: 1,
1203 sub_queries_timed_out: 0,
1204 unique_memories_found: 10,
1205 evidence_chains_found: 2,
1206 elapsed_ms: 1234,
1207 vec_degraded: false,
1208 };
1209 let json = serde_json::to_value(&stats).expect("serialization failed");
1210 assert_eq!(json["sub_queries_total"], 3);
1211 assert_eq!(json["sub_queries_completed"], 2);
1212 assert_eq!(json["sub_queries_failed"], 1);
1213 assert_eq!(json["elapsed_ms"], 1234);
1214 }
1215
#[test]
fn test_deep_research_response_serialization() {
    // Minimal response: the original query as the only sub-query, with no
    // results, no evidence chains and no graph context.
    let only_sub_query = SubQuery {
        id: 0,
        text: "test query".to_string(),
        source: "original",
    };
    let stats = ResearchStats {
        sub_queries_total: 1,
        sub_queries_completed: 1,
        sub_queries_failed: 0,
        sub_queries_timed_out: 0,
        unique_memories_found: 0,
        evidence_chains_found: 0,
        elapsed_ms: 42,
        vec_degraded: false,
    };
    let response = DeepResearchResponse {
        query: "test query".to_string(),
        sub_queries: vec![only_sub_query],
        results: vec![],
        evidence_chains: vec![],
        graph_context: None,
        stats,
    };

    let value = serde_json::to_value(&response).expect("serialization failed");
    assert_eq!(value["query"], "test query");
    // Collection fields are serialized as JSON arrays even when empty.
    for key in ["sub_queries", "results", "evidence_chains"] {
        assert!(value[key].is_array());
    }
    assert_eq!(value["stats"]["elapsed_ms"], 42);
}
1246
#[test]
fn test_distinct_sub_queries_produce_distinct_texts() {
    // Two clearly different topics. The decomposition input is built from these
    // exact parts, so the fixture and the expected sub-query count cannot drift apart.
    // (The previous `assert_ne!` on the two literals exercised no production code.)
    let parts = [
        "authentication design decisions",
        "deployment configuration and infrastructure",
    ];
    let input = parts.join("; ");

    let decomposed = decompose_query(&input, 7);
    assert_eq!(decomposed.len(), parts.len());
    assert_ne!(decomposed[0], decomposed[1]);
}
1267
#[test]
fn test_rrf_fuse_via_fusion_module() {
    use crate::storage::fusion::rrf_fuse;

    // Two equally weighted rankings: ids 1 and 2 appear in both,
    // id 3 only in the KNN list and id 4 only in the FTS list.
    let knn_ranking: Vec<i64> = vec![1, 2, 3];
    let fts_ranking: Vec<i64> = vec![2, 1, 4];
    let fused = rrf_fuse(&[(1.0, &knn_ranking), (1.0, &fts_ranking)], 60.0);

    let (in_both_1, in_both_2) = (fused[&1], fused[&2]);
    let (knn_only_3, fts_only_4) = (fused[&3], fused[&4]);

    assert!(
        in_both_1 > knn_only_3,
        "id 1 (both lists) must beat id 3 (knn-only rank 3)"
    );
    assert!(
        in_both_2 > fts_only_4,
        "id 2 (both lists) must beat id 4 (fts-only rank 3)"
    );
}
1292
#[test]
fn test_evidence_chain_has_from_to_and_path() {
    // Local constructor that keeps the three-node fixture compact.
    let node = |entity: &str, relation: Option<&str>, weight| EvidenceNode {
        entity: entity.to_string(),
        relation: relation.map(str::to_string),
        weight,
    };

    let chain = EvidenceChain {
        from: "auth-module".to_string(),
        to: "jwt-service".to_string(),
        path: vec![
            node("auth-module", None, None),
            node("token-validator", Some("depends-on"), Some(0.9)),
            node("jwt-service", Some("uses"), Some(0.8)),
        ],
        total_weight: 0.72,
        depth: 3,
        sub_query_ids: vec![0],
    };

    let value = serde_json::to_value(&chain).expect("serialization failed");
    assert!(
        value["from"].is_string(),
        "evidence chain must have 'from' field"
    );
    assert!(
        value["to"].is_string(),
        "evidence chain must have 'to' field"
    );
    assert!(
        value["path"].is_array(),
        "evidence chain must have 'path' array"
    );
    assert_eq!(value["path"].as_array().unwrap().len(), 3);
    assert!(value["total_weight"].is_number(), "must have total_weight");
    assert_eq!(value["depth"], 3);
}
1338
#[test]
fn test_reconstruct_path_root_to_target_order() {
    // Chain 10 -> 20 -> 30 with 10 as the only seed.
    let seed_set: HashSet<i64> = std::iter::once(10i64).collect();

    // Each entry maps a node to (predecessor, relation, edge weight).
    let mut predecessor: PredecessorMap = std::collections::HashMap::new();
    for (node, link) in [
        (20, (10, "depends-on".to_string(), 0.9)),
        (30, (20, "uses".to_string(), 0.8)),
    ] {
        predecessor.insert(node, link);
    }

    let mut entity_names: crate::hash::AHashMap<i64, String> = crate::hash::AHashMap::default();
    for (id, name) in [
        (10, "seed-entity"),
        (20, "middle-entity"),
        (30, "target-entity"),
    ] {
        entity_names.insert(id, name.to_string());
    }

    let (nodes, weight) = reconstruct_path(30, &seed_set, &predecessor, &entity_names)
        .expect("path must be reconstructed");

    // Nodes must be ordered from the seed (root) to the target.
    assert_eq!(nodes.len(), 3);
    for (node, expected) in nodes
        .iter()
        .zip(["seed-entity", "middle-entity", "target-entity"])
    {
        assert_eq!(node.entity, expected);
    }
    // 0.72 is the product of the two edge weights (0.9 * 0.8).
    assert!((weight - 0.72).abs() < 1e-6);
}
1363
#[test]
fn test_evidence_chains_single_hop_filtered_out() {
    // A degenerate chain: a single node that starts and ends at "a".
    let self_loop = EvidenceChain {
        from: "a".to_string(),
        to: "a".to_string(),
        path: vec![EvidenceNode {
            entity: "a".to_string(),
            relation: None,
            weight: None,
        }],
        total_weight: 1.0,
        depth: 1,
        sub_query_ids: vec![0],
    };

    // NOTE(review): this applies the `depth >= 2` predicate inline rather than
    // calling the production filter, so it only pins the predicate itself.
    let mut chains = vec![self_loop];
    chains.retain(|c| c.depth >= 2);
    assert!(chains.is_empty(), "depth-1 chains must be filtered out");
}
1385
#[test]
fn test_bfs_with_predecessors_respects_neighbor_cap() {
    use crate::graph::bfs_with_predecessors;
    use rusqlite::Connection;

    let conn = Connection::open_in_memory().unwrap();
    conn.execute_batch(
        "CREATE TABLE relationships (
            source_id INTEGER NOT NULL,
            target_id INTEGER NOT NULL,
            weight REAL NOT NULL,
            namespace TEXT NOT NULL,
            relation TEXT NOT NULL DEFAULT 'related'
        );",
    )
    .unwrap();

    // Star graph: seed 1 has five outgoing edges (to 2..=6), all weight 1.0 in 'ns'.
    for target in 2i64..=6 {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
            rusqlite::params![1i64, target, 1.0f64],
        )
        .unwrap();
    }

    // Compare against seed + 5 directly. The old `len() - 1` form would underflow
    // `usize` on an empty result and panic with an overflow instead of this message.
    let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
    assert_eq!(
        depth_uncapped.len(),
        1 + 5,
        "uncapped must discover all 5 neighbours (plus seed)"
    );

    // With the neighbour cap set to 2, only two of the five neighbours may be reached.
    let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
    assert_eq!(
        depth_capped.len(),
        3,
        "capped to 2 must yield seed + 2 neighbours"
    );
}
1430}