1use crate::errors::AppError;
8use crate::graph::{
9 bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::HashSet;
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
23#[derive(clap::Args)]
25#[command(
26 about = "Deep parallel multi-hop GraphRAG research via query decomposition",
27 after_long_help = "EXAMPLES:\n \
28 # Basic deep research\n \
29 sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
30 # With custom parameters\n \
31 sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
32 # Include full memory bodies in output\n \
33 sqlite-graphrag deep-research \"auth\" --with-bodies\n\n \
34 # Tune RRF and graph scoring\n \
35 sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
36)]
37pub struct DeepResearchArgs {
38 #[arg(
40 value_name = "QUERY",
41 allow_hyphen_values = true,
42 help = "Research query to decompose and search"
43 )]
44 pub query: String,
45 #[arg(
47 long,
48 short,
49 aliases = ["limit", "top-k"],
50 default_value_t = 20,
51 help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
52 )]
53 pub k: usize,
54 #[arg(
56 long,
57 default_value_t = 7,
58 help = "Maximum sub-queries (covers complex multi-hop queries)"
59 )]
60 pub max_sub_queries: usize,
61 #[arg(
63 long,
64 default_value_t = 3,
65 help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
66 )]
67 pub max_hops: usize,
68 #[arg(
70 long,
71 default_value_t = 0.3,
72 help = "Minimum edge weight for graph traversal"
73 )]
74 pub min_weight: f64,
75 #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
77 pub max_concurrency: Option<usize>,
78 #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
80 pub timeout: u64,
81 #[arg(
83 long,
84 default_value_t = false,
85 help = "Include full memory bodies in results"
86 )]
87 pub with_bodies: bool,
88 #[arg(
90 long,
91 default_value_t = 50,
92 help = "Maximum results after deduplication"
93 )]
94 pub max_results: usize,
95 #[arg(
97 long,
98 default_value_t = 60.0,
99 help = "RRF k parameter (higher = less weight on top ranks)"
100 )]
101 pub rrf_k: f64,
102 #[arg(
104 long,
105 default_value_t = 0.7,
106 help = "Graph score decay factor per hop (0.0-1.0)"
107 )]
108 pub graph_decay: f64,
109 #[arg(
111 long,
112 default_value_t = 0.05,
113 help = "Minimum score threshold for graph-expanded results"
114 )]
115 pub graph_min_score: f64,
116 #[arg(
118 long,
119 help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
120 )]
121 pub max_neighbors_per_hop: Option<usize>,
122 #[arg(
124 long,
125 help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
126 )]
127 pub namespace: Option<String>,
128 #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
130 pub mode: String,
131 #[arg(
133 long,
134 value_name = "USD",
135 help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
136 )]
137 pub max_cost_usd: Option<f64>,
138 #[arg(long, hide = true)]
140 pub json: bool,
141 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
143 pub db: Option<String>,
144}
145
146#[derive(Serialize)]
147struct SubQuery {
148 id: usize,
149 text: String,
150 source: &'static str,
151}
152
153#[derive(Serialize)]
154struct DeepResult {
155 name: String,
156 score: f64,
157 source: String,
158 sub_query_ids: Vec<usize>,
159 snippet: String,
160 #[serde(skip_serializing_if = "Option::is_none")]
161 body: Option<String>,
162 hop_distance: Option<usize>,
163}
164
165#[derive(Serialize, Clone)]
167struct EvidenceNode {
168 entity: String,
169 #[serde(skip_serializing_if = "Option::is_none")]
170 relation: Option<String>,
171 #[serde(skip_serializing_if = "Option::is_none")]
172 weight: Option<f64>,
173}
174
175#[derive(Serialize)]
184struct EvidenceChain {
185 from: String,
186 to: String,
187 path: Vec<EvidenceNode>,
188 total_weight: f64,
189 depth: usize,
190 sub_query_ids: Vec<usize>,
191}
192
193#[derive(Serialize)]
194struct ResearchStats {
195 sub_queries_total: usize,
196 sub_queries_completed: usize,
197 sub_queries_failed: usize,
198 sub_queries_timed_out: usize,
199 unique_memories_found: usize,
200 evidence_chains_found: usize,
201 elapsed_ms: u64,
202}
203
204#[derive(Serialize)]
205struct GraphContextEntity {
206 name: String,
207 entity_type: String,
208 degree: u32,
209}
210
211#[derive(Serialize)]
212struct GraphContextRel {
213 from: String,
214 to: String,
215 relation: String,
216 weight: f64,
217}
218
219#[derive(Serialize)]
220struct GraphContext {
221 entities: Vec<GraphContextEntity>,
222 relationships: Vec<GraphContextRel>,
223}
224
225#[derive(Serialize)]
226struct DeepResearchResponse {
227 query: String,
228 sub_queries: Vec<SubQuery>,
229 results: Vec<DeepResult>,
230 evidence_chains: Vec<EvidenceChain>,
231 #[serde(skip_serializing_if = "Option::is_none")]
232 graph_context: Option<GraphContext>,
233 stats: ResearchStats,
234}
235
/// Best hit kept per memory while merging sub-query results, laid out as
/// `(score, source, snippet, body, hop_distance, sub_query_ids)`.
type MergedHit = (
    f64,
    String,
    String,
    String,
    Option<usize>,
    Vec<usize>,
);
238
239struct SubQueryResult {
241 sub_query_id: usize,
242 hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
244 chains: Vec<EvidenceChain>,
246}
247
248#[tracing::instrument(skip_all, level = "debug", name = "deep_research")]
250pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
251 tracing::debug!(target: "deep_research", query = %args.query, k = args.k, "starting deep research");
252 let rt = tokio::runtime::Builder::new_multi_thread()
253 .worker_threads(2)
254 .enable_all()
255 .build()
256 .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
257 rt.block_on(run_async(args))
258}
259
260async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
262 let start = std::time::Instant::now();
263
264 if args.query.trim().is_empty() {
265 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
266 }
267
268 if args.max_cost_usd.is_some() && args.mode == "none" {
269 tracing::warn!(target: "deep_research", "--max-cost-usd has no effect without --mode claude-code/codex");
270 }
271
272 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
273 let paths = AppPaths::resolve(args.db.as_deref())?;
274 crate::storage::connection::ensure_db_ready(&paths)?;
275
276 let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
278 let sub_queries: Vec<SubQuery> = sub_query_texts
279 .iter()
280 .enumerate()
281 .map(|(i, text)| SubQuery {
282 id: i,
283 text: text.clone(),
284 source: if sub_query_texts.len() == 1 {
285 "original"
286 } else {
287 "decomposed"
288 },
289 })
290 .collect();
291
292 output::emit_progress_i18n(
296 "Computing per-sub-query embeddings...",
297 "Calculando embeddings por sub-consulta...",
298 );
299 let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
300 for sq_text in &sub_query_texts {
301 let emb = crate::embedder::embed_query_local(&paths.models, sq_text)?;
302 sub_embeddings.push(Arc::new(emb));
303 }
304
305 let cpu_count = std::thread::available_parallelism()
307 .map(|n| n.get())
308 .unwrap_or(4);
309 let permits = args
310 .max_concurrency
311 .unwrap_or_else(|| cpu_count.min(8))
312 .min(sub_queries.len())
313 .max(1);
314 let semaphore = Arc::new(Semaphore::new(permits));
315 let timeout_dur = std::time::Duration::from_secs(args.timeout);
316
317 let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
318
319 for (idx, sq_text) in sub_query_texts.iter().enumerate() {
320 let sem = Arc::clone(&semaphore);
321 let emb = Arc::clone(&sub_embeddings[idx]);
323 let ns = namespace.clone();
324 let db_path = paths.db.clone();
325 let query_text = sq_text.clone();
326 let k = args.k;
327 let max_hops = args.max_hops;
328 let min_weight = args.min_weight;
329 let rrf_k = args.rrf_k;
330 let graph_decay = args.graph_decay;
331 let graph_min_score = args.graph_min_score;
332 let max_neighbors_per_hop = args.max_neighbors_per_hop;
333
334 join_set.spawn(async move {
335 let _permit = sem
336 .acquire_owned()
337 .await
338 .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
339
340 let result = tokio::time::timeout(timeout_dur, async move {
342 execute_sub_query(
343 idx,
344 &query_text,
345 emb.as_slice(),
346 &ns,
347 &db_path,
348 k,
349 max_hops,
350 min_weight,
351 rrf_k,
352 graph_decay,
353 graph_min_score,
354 max_neighbors_per_hop,
355 )
356 })
357 .await;
358
359 match result {
360 Ok(inner) => inner.map_err(|e| (idx, e)),
361 Err(_) => Err((idx, "timeout".to_string())),
362 }
363 });
364 }
365
366 let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
368 let mut failed_count = 0usize;
369 let mut timed_out_count = 0usize;
370
371 while let Some(join_result) = join_set.join_next().await {
372 match join_result {
373 Ok(Ok(sqr)) => sub_query_results.push(sqr),
374 Ok(Err((_idx, reason))) => {
375 if reason == "timeout" {
376 timed_out_count += 1;
377 } else {
378 failed_count += 1;
379 }
380 tracing::warn!(target: "deep_research", sub_query_id = _idx, reason = %reason, "sub-query failed");
381 }
382 Err(join_err) => {
383 failed_count += 1;
384 if join_err.is_panic() {
385 tracing::error!(target: "deep_research", error = %join_err, "sub-query task panicked");
386 } else {
387 tracing::warn!(target: "deep_research", error = %join_err, "sub-query task cancelled");
388 }
389 }
390 }
391 }
392
393 let mut merged: crate::hash::AHashMap<i64, MergedHit> =
396 crate::hash::AHashMap::with_capacity_and_hasher(
397 sub_query_results.len() * args.k,
398 Default::default(),
399 );
400
401 for sqr in &sub_query_results {
402 for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
403 let entry = merged.entry(*mem_id).or_insert_with(|| {
404 (
405 *score,
406 source.clone(),
407 snippet.clone(),
408 body.clone(),
409 *hop,
410 Vec::new(),
411 )
412 });
413 if *score > entry.0 {
415 entry.0 = *score;
416 entry.1 = source.clone();
417 entry.2 = snippet.clone();
418 entry.3 = body.clone();
419 entry.4 = *hop;
420 }
421 if !entry.5.contains(&sqr.sub_query_id) {
422 entry.5.push(sqr.sub_query_id);
423 }
424 }
425 }
426
427 let conn = open_ro(&paths.db)?;
429 let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
430
431 let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
433 ranked.sort_by(|a, b| {
434 b.1 .0
435 .partial_cmp(&a.1 .0)
436 .unwrap_or(std::cmp::Ordering::Equal)
437 });
438 ranked.truncate(args.max_results);
439
440 for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
441 let name = match memories::read_full(&conn, mem_id)? {
442 Some(row) => row.name,
443 None => continue,
444 };
445 results.push(DeepResult {
446 name,
447 score,
448 source,
449 sub_query_ids: sq_ids,
450 snippet,
451 body: if args.with_bodies { Some(body) } else { None },
452 hop_distance: hop,
453 });
454 }
455
456 let completed_count = sub_query_results.len();
460 let mut evidence_chains: Vec<EvidenceChain> = Vec::with_capacity(completed_count * 2);
461 let mut seen_chain_keys: HashSet<String> = HashSet::with_capacity(completed_count * 2);
462
463 for sqr in sub_query_results {
464 for chain in sqr.chains {
465 let key = format!("{}->{}", chain.from, chain.to);
467 if seen_chain_keys.insert(key) {
468 evidence_chains.push(chain);
469 }
470 }
471 }
472
473 evidence_chains.retain(|c| c.depth >= 2);
475 evidence_chains.sort_by(|a, b| {
476 b.total_weight
477 .partial_cmp(&a.total_weight)
478 .unwrap_or(std::cmp::Ordering::Equal)
479 });
480
481 let unique_memories = results.len();
482 let evidence_count = evidence_chains.len();
483
484 let graph_context = if !results.is_empty() {
486 let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
487 let mut ctx_entities: Vec<GraphContextEntity> = Vec::with_capacity(results.len());
488 let mut ctx_rels: Vec<GraphContextRel> = Vec::with_capacity(results.len() * 2);
489 let mut seen_entity_ids: crate::hash::AHashSet<i64> =
490 crate::hash::AHashSet::with_capacity_and_hasher(results.len(), Default::default());
491
492 for name in &result_names {
493 if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
494 if seen_entity_ids.insert(eid) {
495 let etype: String = conn
496 .query_row(
497 "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
498 rusqlite::params![eid],
499 |r| r.get(0),
500 )
501 .unwrap_or_else(|_| "concept".to_string());
502 let degree: u32 = conn
503 .query_row(
504 "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
505 rusqlite::params![eid],
506 |r| r.get(0),
507 )
508 .unwrap_or(0);
509 ctx_entities.push(GraphContextEntity {
510 name: name.to_string(),
511 entity_type: etype,
512 degree,
513 });
514 }
515 }
516 }
517
518 let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
519 if entity_ids.len() >= 2 {
520 let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
521 let sql = format!(
522 "SELECT s.name, t.name, r.relation, r.weight \
523 FROM relationships r \
524 JOIN entities s ON s.id = r.source_id \
525 JOIN entities t ON t.id = r.target_id \
526 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
527 LIMIT 50"
528 );
529 if let Ok(mut stmt) = conn.prepare(&sql) {
530 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> =
531 Vec::with_capacity(entity_ids.len() * 2);
532 for id in &entity_ids {
533 params.push(Box::new(*id));
534 }
535 for id in &entity_ids {
536 params.push(Box::new(*id));
537 }
538 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
539 params.iter().map(|p| p.as_ref()).collect();
540 if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
541 Ok((
542 r.get::<_, String>(0)?,
543 r.get::<_, String>(1)?,
544 r.get::<_, String>(2)?,
545 r.get::<_, f64>(3)?,
546 ))
547 }) {
548 for row in rows.flatten() {
549 ctx_rels.push(GraphContextRel {
550 from: row.0,
551 to: row.1,
552 relation: row.2,
553 weight: row.3,
554 });
555 }
556 }
557 }
558 }
559
560 if ctx_entities.is_empty() {
561 None
562 } else {
563 Some(GraphContext {
564 entities: ctx_entities,
565 relationships: ctx_rels,
566 })
567 }
568 } else {
569 None
570 };
571
572 tracing::debug!(target: "deep_research",
573 total_results = results.len(),
574 total_chains = evidence_chains.len(),
575 "assembly complete"
576 );
577
578 output::emit_json(&DeepResearchResponse {
580 query: args.query,
581 sub_queries,
582 results,
583 evidence_chains,
584 graph_context,
585 stats: ResearchStats {
586 sub_queries_total: sub_query_texts.len(),
587 sub_queries_completed: completed_count,
588 sub_queries_failed: failed_count,
589 sub_queries_timed_out: timed_out_count,
590 unique_memories_found: unique_memories,
591 evidence_chains_found: evidence_count,
592 elapsed_ms: start.elapsed().as_millis() as u64,
593 },
594 })?;
595
596 Ok(())
597}
598
/// Splits a research query into at most `max` sub-queries using cheap textual heuristics.
///
/// Strategies are tried in order and the first one that yields parts wins:
/// 1. relational phrases such as `" caused by "`, matched ASCII-case-insensitively;
/// 2. semicolons;
/// 3. commas together with the conjunctions `and` / `AND` / `e` / `E`;
/// 4. for queries with at least three words longer than two bytes: the full query plus its
///    first and last word pairs.
///
/// When no strategy applies, or `query` is empty, the original query is returned as the only
/// element. The result is truncated to `max` entries.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    const RELATIONAL_PHRASES: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    // Deliberately not `with_capacity(max)`: `max` comes straight from the command line and a
    // very large value would make the reservation panic with a capacity overflow.
    let mut parts: Vec<String> = Vec::new();

    let mut remainder = query.to_string();
    let mut did_relational_split = false;
    for phrase in &RELATIONAL_PHRASES {
        // ASCII lowering never changes byte lengths, so an offset found in the lowered copy is
        // valid in `remainder` too. Full Unicode lowering can grow or shrink characters, which
        // made the slices below land on the wrong bytes (or panic off a char boundary).
        if let Some(pos) = remainder.to_ascii_lowercase().find(*phrase) {
            let left = remainder[..pos].trim().to_string();
            let right = remainder[pos + phrase.len()..].trim().to_string();
            if !left.is_empty() {
                parts.push(left);
            }
            if !right.is_empty() {
                remainder = right;
            }
            did_relational_split = true;
        }
    }
    if did_relational_split && !remainder.is_empty() {
        parts.push(remainder);
    }

    if parts.is_empty() {
        if query.contains(';') {
            parts.extend(
                query
                    .split(';')
                    .map(str::trim)
                    .filter(|piece| !piece.is_empty())
                    .map(String::from),
            );
        } else {
            // English and Portuguese conjunctions are treated like commas.
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            if normalized.contains(',') {
                parts.extend(
                    normalized
                        .split(',')
                        .map(str::trim)
                        .filter(|piece| !piece.is_empty())
                        .map(String::from),
                );
            }
        }
    }

    if parts.is_empty() {
        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
        if words.len() >= 3 {
            parts.push(query.to_string());
            parts.push(format!("{} {}", words[0], words[1]));
            parts.push(format!(
                "{} {}",
                words[words.len() - 2],
                words[words.len() - 1]
            ));
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    parts.truncate(max);
    parts
}
693
694fn reconstruct_path(
698 target_id: i64,
699 seed_entity_ids: &HashSet<i64>,
700 predecessor: &PredecessorMap,
701 entity_names: &crate::hash::AHashMap<i64, String>,
702) -> Option<(Vec<EvidenceNode>, f64)> {
703 let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::with_capacity(8);
704 let mut total_weight = 1.0_f64;
705 let mut current = target_id;
706
707 loop {
708 if seed_entity_ids.contains(¤t) {
709 break;
710 }
711 let (parent, relation, weight) = predecessor.get(¤t)?;
712 total_weight *= weight;
713 path_ids.push((current, Some(relation.clone()), Some(*weight)));
714 current = *parent;
715 }
716 path_ids.push((current, None, None));
718
719 path_ids.reverse();
721
722 let nodes: Vec<EvidenceNode> = path_ids
723 .into_iter()
724 .map(|(id, relation, weight)| EvidenceNode {
725 entity: entity_names
726 .get(&id)
727 .cloned()
728 .unwrap_or_else(|| format!("entity-{id}")),
729 relation,
730 weight,
731 })
732 .collect();
733
734 Some((nodes, total_weight))
735}
736
737#[allow(clippy::too_many_arguments)]
747fn execute_sub_query(
748 sub_query_id: usize,
749 query_text: &str,
750 embedding: &[f32],
751 namespace: &str,
752 db_path: &std::path::Path,
753 k: usize,
754 max_hops: usize,
755 min_weight: f64,
756 rrf_k: f64,
757 graph_decay: f64,
758 graph_min_score: f64,
759 max_neighbors_per_hop: Option<usize>,
760) -> Result<SubQueryResult, String> {
761 let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
762
763 let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
764 Vec::with_capacity(k * 2);
765 let mut seen_ids: crate::hash::AHashSet<i64> =
766 crate::hash::AHashSet::with_capacity_and_hasher(k * 2, Default::default());
767
768 let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
772 .map_err(|e| format!("knn_search failed: {e}"))?;
773 let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
774 tracing::debug!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), "KNN complete");
775
776 let knn_distance_map: crate::hash::AHashMap<i64, f64> = knn_results
778 .iter()
779 .map(|(id, dist)| (*id, *dist as f64))
780 .collect();
781
782 let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
784 Ok(rows) => rows,
785 Err(e) => {
786 tracing::warn!(target: "deep_research",
787 sub_query_id,
788 "FTS5 search failed, continuing with KNN only: {e}"
789 );
790 vec![]
791 }
792 };
793 let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
794 tracing::debug!(target: "deep_research", sub_query_id, fts_count = fts_ids.len(), "FTS complete");
795
796 let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
798 let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
799
800 let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
802 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
803 fused.truncate(k * 2);
804 tracing::debug!(target: "deep_research",
805 sub_query_id,
806 fused_count = fused.len(),
807 "RRF fusion complete"
808 );
809
810 if fused.is_empty() && !knn_ids.is_empty() {
811 tracing::warn!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
812 "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
813 }
814
815 for (memory_id, combined_score) in &fused {
816 if seen_ids.insert(*memory_id) {
817 let normalized = if max_possible > 0.0 {
818 combined_score / max_possible
819 } else {
820 0.0
821 };
822 let score = normalized.clamp(0.0, 1.0);
823 let in_knn = knn_distance_map.contains_key(memory_id);
824 let in_fts = fts_ids.contains(memory_id);
825 let source = match (in_knn, in_fts) {
826 (true, true) => "hybrid",
827 (true, false) => "knn",
828 (false, true) => "fts",
829 (false, false) => "graph",
830 };
831 if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
832 let snippet: String = row.body.chars().take(300).collect();
833 hits.push((
834 *memory_id,
835 score,
836 source.to_string(),
837 snippet,
838 row.body,
839 None,
840 ));
841 }
842 }
843 }
844
845 let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
848 let mut chains: Vec<EvidenceChain> = Vec::with_capacity(memory_ids.len());
849
850 if !memory_ids.is_empty() && max_hops > 0 {
851 let entity_knn = entities::knn_search(&conn, embedding, namespace, 5)
853 .inspect_err(|e| tracing::warn!(target: "deep_research", error = %e, "entity KNN search failed, skipping graph seed"))
854 .unwrap_or_default();
855 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
856
857 let top_seed_count = 5.min(memory_ids.len());
860 let top_memory_ids = &memory_ids[..top_seed_count];
861 let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
862 for &mem_id in top_memory_ids {
863 let mut stmt = conn
864 .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
865 .map_err(|e| format!("prepare failed: {e}"))?;
866 let ids: Vec<i64> = stmt
867 .query_map(rusqlite::params![mem_id], |r| r.get(0))
868 .map_err(|e| format!("query failed: {e}"))?
869 .filter_map(|r| r.ok())
870 .collect();
871 seed_entity_ids.extend(ids);
872 }
873 seed_entity_ids.sort_unstable();
874 seed_entity_ids.dedup();
875 tracing::debug!(target: "deep_research",
876 sub_query_id,
877 seed_count = seed_entity_ids.len(),
878 "seed entities collected"
879 );
880
881 let all_seed_ids: Vec<i64> = memory_ids
882 .iter()
883 .chain(entity_ids.iter())
884 .copied()
885 .collect();
886
887 if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
889 &conn,
890 &all_seed_ids,
891 namespace,
892 min_weight,
893 max_hops as u32,
894 max_neighbors_per_hop,
895 ) {
896 let seed_score_map: crate::hash::AHashMap<i64, f64> = fused
898 .iter()
899 .map(|(id, s)| {
900 let normalized = if max_possible > 0.0 {
901 s / max_possible
902 } else {
903 0.0
904 };
905 (*id, normalized.clamp(0.0, 1.0))
906 })
907 .collect();
908
909 for (graph_mem_id, hop) in graph_results {
910 if seen_ids.insert(graph_mem_id) {
911 let avg_seed_score: f64 = if seed_score_map.is_empty() {
916 0.5
917 } else {
918 let sum: f64 = seed_score_map.values().sum();
919 sum / seed_score_map.len() as f64
920 };
921 let graph_score =
922 (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
923
924 if graph_score < graph_min_score {
925 continue;
926 }
927
928 if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
929 let snippet: String = row.body.chars().take(300).collect();
930 hits.push((
931 graph_mem_id,
932 graph_score,
933 "graph".to_string(),
934 snippet,
935 row.body,
936 Some(hop as usize),
937 ));
938 }
939 }
940 }
941 }
942
943 if !seed_entity_ids.is_empty() {
946 let (entity_depth, predecessor) = bfs_with_predecessors(
947 &conn,
948 &seed_entity_ids,
949 namespace,
950 min_weight,
951 max_hops as u32,
952 max_neighbors_per_hop,
953 )
954 .unwrap_or_default();
955
956 tracing::debug!(target: "deep_research",
957 sub_query_id,
958 bfs_nodes = entity_depth.len(),
959 predecessors = predecessor.len(),
960 "BFS complete"
961 );
962
963 let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
964
965 let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
967 let mut entity_names: crate::hash::AHashMap<i64, String> =
968 crate::hash::AHashMap::with_capacity_and_hasher(
969 all_entity_ids.len(),
970 ahash::RandomState::default(),
971 );
972 for &eid in &all_entity_ids {
973 let name_res: rusqlite::Result<String> = conn.query_row(
974 "SELECT name FROM entities WHERE id = ?1",
975 rusqlite::params![eid],
976 |r| r.get(0),
977 );
978 if let Ok(name) = name_res {
979 entity_names.insert(eid, name);
980 }
981 }
982
983 for (&target_id, &_hop) in &entity_depth {
985 if seed_entity_set.contains(&target_id) {
986 continue;
987 }
988 if !predecessor.contains_key(&target_id) {
989 continue;
990 }
991 if let Some((path_nodes, total_weight)) =
992 reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
993 {
994 if path_nodes.len() < 2 {
995 continue;
996 }
997 let from = path_nodes
998 .first()
999 .map(|n| n.entity.clone())
1000 .unwrap_or_default();
1001 let to = path_nodes
1002 .last()
1003 .map(|n| n.entity.clone())
1004 .unwrap_or_default();
1005 let depth = path_nodes.len();
1006 chains.push(EvidenceChain {
1007 from,
1008 to,
1009 path: path_nodes,
1010 total_weight,
1011 depth,
1012 sub_query_ids: vec![sub_query_id],
1013 });
1014 }
1015 }
1016
1017 chains.sort_by(|a, b| {
1019 b.total_weight
1020 .partial_cmp(&a.total_weight)
1021 .unwrap_or(std::cmp::Ordering::Equal)
1022 });
1023 chains.truncate(20);
1024 tracing::debug!(target: "deep_research",
1025 sub_query_id,
1026 chains_count = chains.len(),
1027 "evidence chains built"
1028 );
1029 }
1030 }
1031
1032 Ok(SubQueryResult {
1033 sub_query_id,
1034 hits,
1035 chains,
1036 })
1037}
1038
1039#[cfg(test)]
1045mod tests {
1046 use super::*;
1047
1048 #[test]
1049 fn test_decompose_and_conjunction() {
1050 let result = decompose_query("A and B", 7);
1051 assert_eq!(result, vec!["A", "B"]);
1052 }
1053
1054 #[test]
1055 fn test_decompose_no_split() {
1056 let result = decompose_query("simple query", 7);
1057 assert_eq!(result, vec!["simple query"]);
1058 }
1059
1060 #[test]
1061 fn test_decompose_three_parts() {
1062 let result = decompose_query("A, B and C", 7);
1063 assert_eq!(result, vec!["A", "B", "C"]);
1064 }
1065
1066 #[test]
1067 fn test_decompose_portuguese_conjunctions() {
1068 let result = decompose_query("A e B", 7);
1069 assert_eq!(result, vec!["A", "B"]);
1070 }
1071
1072 #[test]
1073 fn test_decompose_max_cap() {
1074 let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
1075 let query = parts.join(", ");
1076 let result = decompose_query(&query, 7);
1077 assert!(
1078 result.len() <= 7,
1079 "expected at most 7 sub-queries, got {}",
1080 result.len()
1081 );
1082 }
1083
1084 #[test]
1085 fn test_decompose_empty_preserves_original() {
1086 let result = decompose_query("", 7);
1087 assert_eq!(result, vec![""]);
1088 }
1089
1090 #[test]
1091 fn test_decompose_semicolons() {
1092 let result = decompose_query("auth design; deployment config; logging", 7);
1093 assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
1094 }
1095
1096 #[test]
1097 fn test_decompose_relational_phrase() {
1098 let result = decompose_query("auth that caused deployment failure", 7);
1099 assert_eq!(result, vec!["auth", "deployment failure"]);
1100 }
1101
1102 #[test]
1103 fn test_sub_query_serialization() {
1104 let sq = SubQuery {
1105 id: 0,
1106 text: "test query".to_string(),
1107 source: "original",
1108 };
1109 let json = serde_json::to_value(&sq).expect("serialization failed");
1110 assert_eq!(json["id"], 0);
1111 assert_eq!(json["text"], "test query");
1112 assert_eq!(json["source"], "original");
1113 }
1114
1115 #[test]
1116 fn test_deep_result_omits_body_when_none() {
1117 let result = DeepResult {
1118 name: "test".to_string(),
1119 score: 0.9,
1120 source: "knn".to_string(),
1121 sub_query_ids: vec![0],
1122 snippet: "snippet".to_string(),
1123 body: None,
1124 hop_distance: None,
1125 };
1126 let json = serde_json::to_string(&result).expect("serialization failed");
1127 assert!(!json.contains("\"body\""), "body must be omitted when None");
1128 }
1129
1130 #[test]
1131 fn test_deep_result_includes_body_when_some() {
1132 let result = DeepResult {
1133 name: "test".to_string(),
1134 score: 0.9,
1135 source: "knn".to_string(),
1136 sub_query_ids: vec![0, 1],
1137 snippet: "snippet".to_string(),
1138 body: Some("full body content".to_string()),
1139 hop_distance: Some(2),
1140 };
1141 let json = serde_json::to_string(&result).expect("serialization failed");
1142 assert!(json.contains("\"body\""), "body must be present when Some");
1143 assert!(json.contains("full body content"));
1144 }
1145
1146 #[test]
1147 fn test_evidence_node_omits_none_fields() {
1148 let node = EvidenceNode {
1149 entity: "auth-module".to_string(),
1150 relation: None,
1151 weight: None,
1152 };
1153 let json = serde_json::to_string(&node).expect("serialization failed");
1154 assert!(
1155 !json.contains("\"relation\""),
1156 "relation must be omitted when None"
1157 );
1158 assert!(
1159 !json.contains("\"weight\""),
1160 "weight must be omitted when None"
1161 );
1162 }
1163
1164 #[test]
1165 fn test_research_stats_serialization() {
1166 let stats = ResearchStats {
1167 sub_queries_total: 3,
1168 sub_queries_completed: 2,
1169 sub_queries_failed: 1,
1170 sub_queries_timed_out: 0,
1171 unique_memories_found: 10,
1172 evidence_chains_found: 2,
1173 elapsed_ms: 1234,
1174 };
1175 let json = serde_json::to_value(&stats).expect("serialization failed");
1176 assert_eq!(json["sub_queries_total"], 3);
1177 assert_eq!(json["sub_queries_completed"], 2);
1178 assert_eq!(json["sub_queries_failed"], 1);
1179 assert_eq!(json["elapsed_ms"], 1234);
1180 }
1181
1182 #[test]
1183 fn test_deep_research_response_serialization() {
1184 let resp = DeepResearchResponse {
1185 query: "test query".to_string(),
1186 sub_queries: vec![SubQuery {
1187 id: 0,
1188 text: "test query".to_string(),
1189 source: "original",
1190 }],
1191 results: vec![],
1192 evidence_chains: vec![],
1193 graph_context: None,
1194 stats: ResearchStats {
1195 sub_queries_total: 1,
1196 sub_queries_completed: 1,
1197 sub_queries_failed: 0,
1198 sub_queries_timed_out: 0,
1199 unique_memories_found: 0,
1200 evidence_chains_found: 0,
1201 elapsed_ms: 42,
1202 },
1203 };
1204 let json = serde_json::to_value(&resp).expect("serialization failed");
1205 assert_eq!(json["query"], "test query");
1206 assert!(json["sub_queries"].is_array());
1207 assert!(json["results"].is_array());
1208 assert!(json["evidence_chains"].is_array());
1209 assert_eq!(json["stats"]["elapsed_ms"], 42);
1210 }
1211
1212 #[test]
1216 fn test_distinct_sub_queries_produce_distinct_texts() {
1217 let queries = [
1218 "authentication design decisions",
1219 "deployment configuration and infrastructure",
1220 ];
1221 assert_ne!(queries[0], queries[1]);
1223
1224 let decomposed = decompose_query(
1226 "authentication design decisions; deployment configuration and infrastructure",
1227 7,
1228 );
1229 assert_eq!(decomposed.len(), 2);
1230 assert_ne!(decomposed[0], decomposed[1]);
1231 }
1232
/// Fuses two equally weighted rankings (weight `1.0` each, constant `60.0`) with
/// `rrf_fuse` and checks that ids present in both lists outscore ids present in one.
///
/// Ids 1 and 2 occur in both rankings; id 3 occurs only in the first ranking and
/// id 4 only in the second, each at the last position.
#[test]
fn test_rrf_fuse_via_fusion_module() {
    use crate::storage::fusion::rrf_fuse;

    let knn_ranking: Vec<i64> = vec![1, 2, 3];
    let fts_ranking: Vec<i64> = vec![2, 1, 4];
    let fused = rrf_fuse(&[(1.0, &knn_ranking), (1.0, &fts_ranking)], 60.0);

    // Pair each doubly-ranked id with the singly-ranked id it is compared against.
    let (both_lists_1, knn_only_3) = (fused[&1], fused[&3]);
    let (both_lists_2, fts_only_4) = (fused[&2], fused[&4]);

    assert!(
        both_lists_1 > knn_only_3,
        "id 1 (both lists) must beat id 3 (knn-only rank 3)"
    );
    assert!(
        both_lists_2 > fts_only_4,
        "id 2 (both lists) must beat id 4 (fts-only rank 3)"
    );
}
1257
/// Serializes a three-node `EvidenceChain` and checks the JSON shape: string
/// `from`/`to`, a `path` array of length 3, a numeric `total_weight`, and `depth` 3.
#[test]
fn test_evidence_chain_has_from_to_and_path() {
    // The first node has no incoming edge, so its relation and weight are `None`.
    let hops = vec![
        EvidenceNode {
            entity: String::from("auth-module"),
            relation: None,
            weight: None,
        },
        EvidenceNode {
            entity: String::from("token-validator"),
            relation: Some(String::from("depends-on")),
            weight: Some(0.9),
        },
        EvidenceNode {
            entity: String::from("jwt-service"),
            relation: Some(String::from("uses")),
            weight: Some(0.8),
        },
    ];
    let chain = EvidenceChain {
        from: String::from("auth-module"),
        to: String::from("jwt-service"),
        path: hops,
        total_weight: 0.72,
        depth: 3,
        sub_query_ids: vec![0],
    };

    let value = serde_json::to_value(&chain).expect("serialization failed");

    assert!(
        value["from"].is_string(),
        "evidence chain must have 'from' field"
    );
    assert!(
        value["to"].is_string(),
        "evidence chain must have 'to' field"
    );
    assert!(
        value["path"].is_array(),
        "evidence chain must have 'path' array"
    );
    let path_nodes = value["path"].as_array().unwrap();
    assert_eq!(path_nodes.len(), 3);
    assert!(value["total_weight"].is_number(), "must have total_weight");
    assert_eq!(value["depth"], 3);
}
1303
/// Checks that `reconstruct_path` returns nodes ordered from the seed to the target.
///
/// The predecessor map encodes the chain 10 -> 20 -> 30 with edge weights 0.9 and
/// 0.8. The test expects three nodes named seed, middle, target in that order and a
/// total weight of 0.72 (equal to 0.9 * 0.8) within a 1e-6 tolerance.
#[test]
fn test_reconstruct_path_root_to_target_order() {
    let seed_set: HashSet<i64> = HashSet::from([10i64]);

    // Each entry maps a node to (predecessor id, relation label, edge weight).
    let mut predecessor: PredecessorMap = std::collections::HashMap::new();
    predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
    predecessor.insert(30, (20, "uses".to_string(), 0.8));

    let mut entity_names: crate::hash::AHashMap<i64, String> = crate::hash::AHashMap::default();
    for (id, name) in [
        (10, "seed-entity"),
        (20, "middle-entity"),
        (30, "target-entity"),
    ] {
        entity_names.insert(id, name.to_string());
    }

    // A single `expect` replaces the former `assert!(is_some())` + `unwrap()` pair:
    // a missing path still fails the test with the same message.
    let (nodes, weight) = reconstruct_path(30, &seed_set, &predecessor, &entity_names)
        .expect("path must be reconstructed");

    assert_eq!(nodes.len(), 3);
    assert_eq!(nodes[0].entity, "seed-entity");
    assert_eq!(nodes[1].entity, "middle-entity");
    assert_eq!(nodes[2].entity, "target-entity");
    assert!((weight - 0.72).abs() < 1e-6);
}
1328
/// Builds one depth-1 `EvidenceChain` (a single node, `from == to`) and checks that
/// a `depth >= 2` filter removes it.
// NOTE(review): the predicate is restated inline here rather than calling production
// code; presumably it mirrors the filter used when chains are assembled — confirm.
#[test]
fn test_evidence_chains_single_hop_filtered_out() {
    let lone_node = EvidenceNode {
        entity: String::from("a"),
        relation: None,
        weight: None,
    };
    let mut chains = vec![EvidenceChain {
        from: String::from("a"),
        to: String::from("a"),
        path: vec![lone_node],
        total_weight: 1.0,
        depth: 1,
        sub_query_ids: vec![0],
    }];

    // Keep only chains that span at least two levels.
    chains.retain(|candidate| candidate.depth >= 2);

    assert!(chains.is_empty(), "depth-1 chains must be filtered out");
}
1350
/// Checks that `bfs_with_predecessors` honours its optional neighbour cap.
///
/// An in-memory SQLite database gets a `relationships` table in which entity 1 has
/// five outgoing edges (targets 2..=6), all with weight 1.0 in namespace `ns`.
/// The traversal is run twice from seed 1: with `None` as the last argument it must
/// report the seed plus all five neighbours, and with `Some(2)` it must report the
/// seed plus two neighbours.
// NOTE(review): the `0.0` and `1` arguments are presumably the minimum edge weight
// and the hop limit — confirm against the `bfs_with_predecessors` signature.
#[test]
fn test_bfs_with_predecessors_respects_neighbor_cap() {
    use crate::graph::bfs_with_predecessors;
    use rusqlite::Connection;

    let conn = Connection::open_in_memory().unwrap();
    conn.execute_batch(
        "CREATE TABLE relationships (
            source_id INTEGER NOT NULL,
            target_id INTEGER NOT NULL,
            weight REAL NOT NULL,
            namespace TEXT NOT NULL,
            relation TEXT NOT NULL DEFAULT 'related'
        );",
    )
    .unwrap();

    // Fan out from entity 1 to five distinct targets; `relation` takes its default.
    for target in 2i64..=6 {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
            rusqlite::params![1i64, target, 1.0f64],
        )
        .unwrap();
    }

    let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
    // Compare the full size (seed + 5 neighbours) rather than `len() - 1`: the
    // subtraction would underflow on an empty result and hide the message below.
    assert_eq!(
        depth_uncapped.len(),
        6,
        "uncapped must discover all 5 neighbours (plus seed)"
    );

    let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
    assert_eq!(
        depth_capped.len(),
        3,
        "capped to 2 must yield seed + 2 neighbours"
    );
}
1395}