1use crate::errors::AppError;
8use crate::graph::{
9 bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::HashSet;
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
// Command-line arguments for the `deep-research` subcommand.
//
// NOTE: plain `//` comments are used on purpose in this struct. clap's derive
// turns `///` doc comments into help text, so doc comments here could change
// what `--help` prints.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n \
    # Basic deep research\n \
    sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
    # With custom parameters\n \
    sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
    # Include full memory bodies in output\n \
    sqlite-graphrag deep-research \"auth\" --with-bodies\n\n \
    # Tune RRF and graph scoring\n \
    sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
    // Positional research query. Rejected when blank (after trimming) and then
    // handed to `decompose_query`.
    #[arg(
        value_name = "QUERY",
        allow_hyphen_values = true,
        help = "Research query to decompose and search"
    )]
    pub query: String,
    // Per-sub-query result count passed to both the KNN and FTS searches; the
    // fused list is truncated to twice this value.
    #[arg(
        long,
        short,
        aliases = ["limit", "top-k"],
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    // Upper bound on the number of sub-queries produced by `decompose_query`.
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    // Graph traversal depth. A value of 0 skips graph expansion and evidence
    // chains entirely.
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    // Forwarded unchanged to the graph traversal helpers.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    // Semaphore size for sub-query tasks. When absent, min(cpu count, 8) is
    // used; the value is then capped at the sub-query count and floored at 1.
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    // Seconds handed to `tokio::time::timeout` around each sub-query.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    // When set, each result carries the full memory body in addition to the
    // snippet.
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    // The merged, ranked hit list is truncated to this many entries.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    // Passed to `rrf_fuse` and `rrf_max_possible`.
    #[arg(
        long,
        default_value_t = 60.0,
        help = "RRF k parameter (higher = less weight on top ranks)"
    )]
    pub rrf_k: f64,
    // Graph-expanded hits are scored as average seed score * decay^hop.
    #[arg(
        long,
        default_value_t = 0.7,
        help = "Graph score decay factor per hop (0.0-1.0)"
    )]
    pub graph_decay: f64,
    // Graph-expanded hits scoring below this value are dropped.
    #[arg(
        long,
        default_value_t = 0.05,
        help = "Minimum score threshold for graph-expanded results"
    )]
    pub graph_min_score: f64,
    // Forwarded unchanged to the graph traversal helpers.
    #[arg(
        long,
        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
    )]
    pub max_neighbors_per_hop: Option<usize>,
    // Resolved through `crate::namespace::resolve_namespace`.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Hidden flag; the value parser only accepts "none".
    #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
    pub mode: String,
    // In the code visible here this only triggers a warning when `mode` is
    // "none"; no cost accounting is performed in this file.
    #[arg(
        long,
        value_name = "USD",
        help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
    )]
    pub max_cost_usd: Option<f64>,
    // Hidden flag; not read anywhere in this file (output is always JSON).
    // NOTE(review): presumably kept for CLI compatibility - confirm.
    #[arg(long, hide = true)]
    pub json: bool,
    // Database path override, passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Shared daemon options; `autostart_daemon` is read when embedding
    // sub-queries.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
147
/// One sub-query derived from the user's research query, as reported in the
/// JSON response.
#[derive(Serialize)]
struct SubQuery {
    /// Zero-based position of the sub-query in the decomposition output.
    id: usize,
    /// Text that is embedded and searched for this sub-query.
    text: String,
    /// `"original"` when decomposition produced exactly one sub-query,
    /// otherwise `"decomposed"`.
    source: &'static str,
}
154
/// A single memory in the final, merged result list.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name as read back from storage.
    name: String,
    /// Best score seen for this memory across all sub-queries.
    score: f64,
    /// Retrieval channel of the best-scoring hit: `"hybrid"`, `"knn"`,
    /// `"fts"` or `"graph"`.
    source: String,
    /// Ids of every sub-query that returned this memory.
    sub_query_ids: Vec<usize>,
    /// First 300 characters of the memory body.
    snippet: String,
    /// Full memory body; only populated when `--with-bodies` is set and
    /// omitted from the JSON otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// Hop distance for graph-expanded hits, `None` for direct KNN/FTS hits.
    /// There is no skip attribute, so `None` serializes as `null`.
    hop_distance: Option<usize>,
}
166
/// One step of an evidence chain: an entity plus the edge that led to it.
#[derive(Serialize, Clone)]
struct EvidenceNode {
    /// Entity name, or `entity-<id>` when the name could not be looked up.
    entity: String,
    /// Relation label of the edge from the previous node; `None` (and omitted
    /// from the JSON) on the first node of a path.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Weight of the edge from the previous node; `None` (and omitted from
    /// the JSON) on the first node of a path.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
176
/// A path through the entity graph from a seed entity to a reached entity.
#[derive(Serialize)]
struct EvidenceChain {
    /// Entity name of the first node in `path`.
    from: String,
    /// Entity name of the last node in `path`.
    to: String,
    /// Nodes in seed-to-target order.
    path: Vec<EvidenceNode>,
    /// Product of the edge weights along `path`.
    total_weight: f64,
    /// Number of nodes in `path` (not the number of edges).
    depth: usize,
    /// Ids of the sub-queries that produced this chain.
    sub_query_ids: Vec<usize>,
}
194
/// Run statistics reported alongside the results.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries that returned a result set.
    sub_queries_completed: usize,
    /// Sub-queries that failed for any reason other than a timeout, including
    /// panicked or cancelled tasks.
    sub_queries_failed: usize,
    /// Sub-queries reported as timed out.
    sub_queries_timed_out: usize,
    /// Number of entries in the final result list.
    unique_memories_found: usize,
    /// Number of evidence chains kept after deduplication.
    evidence_chains_found: usize,
    /// Wall-clock time from the start of the run until the response is built,
    /// in milliseconds.
    elapsed_ms: u64,
}
205
/// An entity whose name matches one of the result memories.
#[derive(Serialize)]
struct GraphContextEntity {
    /// Entity name (identical to the matching result's name).
    name: String,
    /// Entity type from the `entities` table; `"concept"` when the column is
    /// NULL or the lookup fails.
    entity_type: String,
    /// Count of relationships where the entity is either source or target;
    /// 0 when the count query fails.
    degree: u32,
}
212
/// A relationship between two entities that both appear in the graph context.
#[derive(Serialize)]
struct GraphContextRel {
    /// Name of the source entity.
    from: String,
    /// Name of the target entity.
    to: String,
    /// Relation label stored on the edge.
    relation: String,
    /// Edge weight stored on the relationship row.
    weight: f64,
}
220
/// Entity/relationship neighbourhood attached to the response when at least
/// one result name resolves to an entity.
#[derive(Serialize)]
struct GraphContext {
    /// Entities matched by result name, without duplicates.
    entities: Vec<GraphContextEntity>,
    /// Relationships whose source and target are both in `entities`; the
    /// query that fills this is limited to 50 rows.
    relationships: Vec<GraphContextRel>,
}
226
/// Top-level JSON document emitted by the `deep-research` command.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The query exactly as supplied on the command line.
    query: String,
    /// Sub-queries the query was decomposed into.
    sub_queries: Vec<SubQuery>,
    /// Merged hits, ranked by descending score.
    results: Vec<DeepResult>,
    /// Deduplicated evidence chains, ordered by descending total weight.
    evidence_chains: Vec<EvidenceChain>,
    /// Omitted from the JSON when no result name resolves to an entity.
    #[serde(skip_serializing_if = "Option::is_none")]
    graph_context: Option<GraphContext>,
    /// Counters and timing for the run.
    stats: ResearchStats,
}
237
/// Best hit for one memory after merging all sub-queries, in field order:
/// `(score, source, snippet, body, hop_distance, sub_query_ids)`.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
240
/// Output of one sub-query, before merging across sub-queries.
struct SubQueryResult {
    /// Id of the sub-query that produced these results.
    sub_query_id: usize,
    /// Hits in field order:
    /// `(memory_id, score, source, snippet, body, hop_distance)`.
    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
    /// Evidence chains found for this sub-query.
    chains: Vec<EvidenceChain>,
}
249
250#[tracing::instrument(skip_all, level = "debug", name = "deep_research")]
252pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
253 tracing::debug!(target: "deep_research", query = %args.query, k = args.k, "starting deep research");
254 let rt = tokio::runtime::Builder::new_multi_thread()
255 .worker_threads(2)
256 .enable_all()
257 .build()
258 .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
259 rt.block_on(run_async(args))
260}
261
262async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
264 let start = std::time::Instant::now();
265
266 if args.query.trim().is_empty() {
267 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
268 }
269
270 if args.max_cost_usd.is_some() && args.mode == "none" {
271 tracing::warn!(target: "deep_research", "--max-cost-usd has no effect without --mode claude-code/codex");
272 }
273
274 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
275 let paths = AppPaths::resolve(args.db.as_deref())?;
276 crate::storage::connection::ensure_db_ready(&paths)?;
277
278 let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
280 let sub_queries: Vec<SubQuery> = sub_query_texts
281 .iter()
282 .enumerate()
283 .map(|(i, text)| SubQuery {
284 id: i,
285 text: text.clone(),
286 source: if sub_query_texts.len() == 1 {
287 "original"
288 } else {
289 "decomposed"
290 },
291 })
292 .collect();
293
294 output::emit_progress_i18n(
298 "Computing per-sub-query embeddings...",
299 "Calculando embeddings por sub-consulta...",
300 );
301 let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
302 for sq_text in &sub_query_texts {
303 let emb = crate::daemon::embed_query_or_local(
304 &paths.models,
305 sq_text,
306 args.daemon.autostart_daemon,
307 )?;
308 sub_embeddings.push(Arc::new(emb));
309 }
310
311 let cpu_count = std::thread::available_parallelism()
313 .map(|n| n.get())
314 .unwrap_or(4);
315 let permits = args
316 .max_concurrency
317 .unwrap_or_else(|| cpu_count.min(8))
318 .min(sub_queries.len())
319 .max(1);
320 let semaphore = Arc::new(Semaphore::new(permits));
321 let timeout_dur = std::time::Duration::from_secs(args.timeout);
322
323 let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
324
325 for (idx, sq_text) in sub_query_texts.iter().enumerate() {
326 let sem = Arc::clone(&semaphore);
327 let emb = Arc::clone(&sub_embeddings[idx]);
329 let ns = namespace.clone();
330 let db_path = paths.db.clone();
331 let query_text = sq_text.clone();
332 let k = args.k;
333 let max_hops = args.max_hops;
334 let min_weight = args.min_weight;
335 let rrf_k = args.rrf_k;
336 let graph_decay = args.graph_decay;
337 let graph_min_score = args.graph_min_score;
338 let max_neighbors_per_hop = args.max_neighbors_per_hop;
339
340 join_set.spawn(async move {
341 let _permit = sem
342 .acquire_owned()
343 .await
344 .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
345
346 let result = tokio::time::timeout(timeout_dur, async move {
348 execute_sub_query(
349 idx,
350 &query_text,
351 emb.as_slice(),
352 &ns,
353 &db_path,
354 k,
355 max_hops,
356 min_weight,
357 rrf_k,
358 graph_decay,
359 graph_min_score,
360 max_neighbors_per_hop,
361 )
362 })
363 .await;
364
365 match result {
366 Ok(inner) => inner.map_err(|e| (idx, e)),
367 Err(_) => Err((idx, "timeout".to_string())),
368 }
369 });
370 }
371
372 let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
374 let mut failed_count = 0usize;
375 let mut timed_out_count = 0usize;
376
377 while let Some(join_result) = join_set.join_next().await {
378 match join_result {
379 Ok(Ok(sqr)) => sub_query_results.push(sqr),
380 Ok(Err((_idx, reason))) => {
381 if reason == "timeout" {
382 timed_out_count += 1;
383 } else {
384 failed_count += 1;
385 }
386 tracing::warn!(target: "deep_research", sub_query_id = _idx, reason = %reason, "sub-query failed");
387 }
388 Err(join_err) => {
389 failed_count += 1;
390 if join_err.is_panic() {
391 tracing::error!(target: "deep_research", error = %join_err, "sub-query task panicked");
392 } else {
393 tracing::warn!(target: "deep_research", error = %join_err, "sub-query task cancelled");
394 }
395 }
396 }
397 }
398
399 let mut merged: crate::hash::AHashMap<i64, MergedHit> =
402 crate::hash::AHashMap::with_capacity_and_hasher(
403 sub_query_results.len() * args.k,
404 Default::default(),
405 );
406
407 for sqr in &sub_query_results {
408 for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
409 let entry = merged.entry(*mem_id).or_insert_with(|| {
410 (
411 *score,
412 source.clone(),
413 snippet.clone(),
414 body.clone(),
415 *hop,
416 Vec::new(),
417 )
418 });
419 if *score > entry.0 {
421 entry.0 = *score;
422 entry.1 = source.clone();
423 entry.2 = snippet.clone();
424 entry.3 = body.clone();
425 entry.4 = *hop;
426 }
427 if !entry.5.contains(&sqr.sub_query_id) {
428 entry.5.push(sqr.sub_query_id);
429 }
430 }
431 }
432
433 let conn = open_ro(&paths.db)?;
435 let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
436
437 let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
439 ranked.sort_by(|a, b| {
440 b.1 .0
441 .partial_cmp(&a.1 .0)
442 .unwrap_or(std::cmp::Ordering::Equal)
443 });
444 ranked.truncate(args.max_results);
445
446 for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
447 let name = match memories::read_full(&conn, mem_id)? {
448 Some(row) => row.name,
449 None => continue,
450 };
451 results.push(DeepResult {
452 name,
453 score,
454 source,
455 sub_query_ids: sq_ids,
456 snippet,
457 body: if args.with_bodies { Some(body) } else { None },
458 hop_distance: hop,
459 });
460 }
461
462 let completed_count = sub_query_results.len();
466 let mut evidence_chains: Vec<EvidenceChain> = Vec::with_capacity(completed_count * 2);
467 let mut seen_chain_keys: HashSet<String> = HashSet::with_capacity(completed_count * 2);
468
469 for sqr in sub_query_results {
470 for chain in sqr.chains {
471 let key = format!("{}->{}", chain.from, chain.to);
473 if seen_chain_keys.insert(key) {
474 evidence_chains.push(chain);
475 }
476 }
477 }
478
479 evidence_chains.retain(|c| c.depth >= 2);
481 evidence_chains.sort_by(|a, b| {
482 b.total_weight
483 .partial_cmp(&a.total_weight)
484 .unwrap_or(std::cmp::Ordering::Equal)
485 });
486
487 let unique_memories = results.len();
488 let evidence_count = evidence_chains.len();
489
490 let graph_context = if !results.is_empty() {
492 let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
493 let mut ctx_entities: Vec<GraphContextEntity> = Vec::with_capacity(results.len());
494 let mut ctx_rels: Vec<GraphContextRel> = Vec::with_capacity(results.len() * 2);
495 let mut seen_entity_ids: crate::hash::AHashSet<i64> =
496 crate::hash::AHashSet::with_capacity_and_hasher(results.len(), Default::default());
497
498 for name in &result_names {
499 if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
500 if seen_entity_ids.insert(eid) {
501 let etype: String = conn
502 .query_row(
503 "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
504 rusqlite::params![eid],
505 |r| r.get(0),
506 )
507 .unwrap_or_else(|_| "concept".to_string());
508 let degree: u32 = conn
509 .query_row(
510 "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
511 rusqlite::params![eid],
512 |r| r.get(0),
513 )
514 .unwrap_or(0);
515 ctx_entities.push(GraphContextEntity {
516 name: name.to_string(),
517 entity_type: etype,
518 degree,
519 });
520 }
521 }
522 }
523
524 let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
525 if entity_ids.len() >= 2 {
526 let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
527 let sql = format!(
528 "SELECT s.name, t.name, r.relation, r.weight \
529 FROM relationships r \
530 JOIN entities s ON s.id = r.source_id \
531 JOIN entities t ON t.id = r.target_id \
532 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
533 LIMIT 50"
534 );
535 if let Ok(mut stmt) = conn.prepare(&sql) {
536 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> =
537 Vec::with_capacity(entity_ids.len() * 2);
538 for id in &entity_ids {
539 params.push(Box::new(*id));
540 }
541 for id in &entity_ids {
542 params.push(Box::new(*id));
543 }
544 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
545 params.iter().map(|p| p.as_ref()).collect();
546 if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
547 Ok((
548 r.get::<_, String>(0)?,
549 r.get::<_, String>(1)?,
550 r.get::<_, String>(2)?,
551 r.get::<_, f64>(3)?,
552 ))
553 }) {
554 for row in rows.flatten() {
555 ctx_rels.push(GraphContextRel {
556 from: row.0,
557 to: row.1,
558 relation: row.2,
559 weight: row.3,
560 });
561 }
562 }
563 }
564 }
565
566 if ctx_entities.is_empty() {
567 None
568 } else {
569 Some(GraphContext {
570 entities: ctx_entities,
571 relationships: ctx_rels,
572 })
573 }
574 } else {
575 None
576 };
577
578 tracing::debug!(target: "deep_research",
579 total_results = results.len(),
580 total_chains = evidence_chains.len(),
581 "assembly complete"
582 );
583
584 output::emit_json(&DeepResearchResponse {
586 query: args.query,
587 sub_queries,
588 results,
589 evidence_chains,
590 graph_context,
591 stats: ResearchStats {
592 sub_queries_total: sub_query_texts.len(),
593 sub_queries_completed: completed_count,
594 sub_queries_failed: failed_count,
595 sub_queries_timed_out: timed_out_count,
596 unique_memories_found: unique_memories,
597 evidence_chains_found: evidence_count,
598 elapsed_ms: start.elapsed().as_millis() as u64,
599 },
600 })?;
601
602 Ok(())
603}
604
/// Splits a research query into at most `max` sub-queries using heuristics.
///
/// Strategies are tried in order and the first one that yields parts wins:
/// 1. Relational phrases (`" caused by "`, `" related to "`, ...), matched
///    ASCII-case-insensitively, one split per phrase in list order.
/// 2. Semicolons.
/// 3. Commas and the conjunctions `" and "`, `" AND "`, `" e "`, `" E "`.
/// 4. Keyword fallback: when the query has at least three words longer than
///    two bytes, return the full query plus its first and last word pairs.
///
/// When nothing applies (including an empty query) the query is returned
/// unchanged as the only element. The result is truncated to `max` entries.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    const RELATIONAL: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    // No `with_capacity(max)`: `max` comes from the command line and a huge
    // value would panic with a capacity overflow.
    let mut parts: Vec<String> = Vec::new();

    let mut text = query.to_string();
    let mut did_relational_split = false;
    for phrase in RELATIONAL.iter().copied() {
        // ASCII lowercasing never changes byte lengths, so `pos` is a valid
        // byte offset into `text`. Full Unicode lowercasing can shift offsets
        // (e.g. 'İ' grows from 2 to 3 bytes), which would slice at the wrong
        // place or panic on a char boundary.
        let Some(pos) = text.to_ascii_lowercase().find(phrase) else {
            continue;
        };
        let left = text[..pos].trim().to_string();
        let right = text[pos + phrase.len()..].trim().to_string();
        if !left.is_empty() {
            parts.push(left);
        }
        if !right.is_empty() {
            // Later phrases are searched only in the remainder.
            text = right;
        }
        did_relational_split = true;
    }
    if did_relational_split && !text.is_empty() {
        parts.push(text);
    }

    if parts.is_empty() {
        if query.contains(';') {
            parts.extend(
                query
                    .split(';')
                    .map(str::trim)
                    .filter(|p| !p.is_empty())
                    .map(str::to_string),
            );
        } else {
            // Treat English and Portuguese conjunctions like commas.
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            if normalized.contains(',') {
                parts.extend(
                    normalized
                        .split(',')
                        .map(str::trim)
                        .filter(|p| !p.is_empty())
                        .map(str::to_string),
                );
            }
        }
    }

    if parts.is_empty() {
        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
        if words.len() >= 3 {
            parts.push(query.to_string());
            parts.push(format!("{} {}", words[0], words[1]));
            parts.push(format!(
                "{} {}",
                words[words.len() - 2],
                words[words.len() - 1]
            ));
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    parts.truncate(max);
    parts
}
699
700fn reconstruct_path(
704 target_id: i64,
705 seed_entity_ids: &HashSet<i64>,
706 predecessor: &PredecessorMap,
707 entity_names: &crate::hash::AHashMap<i64, String>,
708) -> Option<(Vec<EvidenceNode>, f64)> {
709 let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::with_capacity(8);
710 let mut total_weight = 1.0_f64;
711 let mut current = target_id;
712
713 loop {
714 if seed_entity_ids.contains(¤t) {
715 break;
716 }
717 let (parent, relation, weight) = predecessor.get(¤t)?;
718 total_weight *= weight;
719 path_ids.push((current, Some(relation.clone()), Some(*weight)));
720 current = *parent;
721 }
722 path_ids.push((current, None, None));
724
725 path_ids.reverse();
727
728 let nodes: Vec<EvidenceNode> = path_ids
729 .into_iter()
730 .map(|(id, relation, weight)| EvidenceNode {
731 entity: entity_names
732 .get(&id)
733 .cloned()
734 .unwrap_or_else(|| format!("entity-{id}")),
735 relation,
736 weight,
737 })
738 .collect();
739
740 Some((nodes, total_weight))
741}
742
743#[allow(clippy::too_many_arguments)]
753fn execute_sub_query(
754 sub_query_id: usize,
755 query_text: &str,
756 embedding: &[f32],
757 namespace: &str,
758 db_path: &std::path::Path,
759 k: usize,
760 max_hops: usize,
761 min_weight: f64,
762 rrf_k: f64,
763 graph_decay: f64,
764 graph_min_score: f64,
765 max_neighbors_per_hop: Option<usize>,
766) -> Result<SubQueryResult, String> {
767 let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
768
769 let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
770 Vec::with_capacity(k * 2);
771 let mut seen_ids: crate::hash::AHashSet<i64> =
772 crate::hash::AHashSet::with_capacity_and_hasher(k * 2, Default::default());
773
774 let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
778 .map_err(|e| format!("knn_search failed: {e}"))?;
779 let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
780 tracing::debug!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), "KNN complete");
781
782 let knn_distance_map: crate::hash::AHashMap<i64, f64> = knn_results
784 .iter()
785 .map(|(id, dist)| (*id, *dist as f64))
786 .collect();
787
788 let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
790 Ok(rows) => rows,
791 Err(e) => {
792 tracing::warn!(target: "deep_research",
793 sub_query_id,
794 "FTS5 search failed, continuing with KNN only: {e}"
795 );
796 vec![]
797 }
798 };
799 let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
800 tracing::debug!(target: "deep_research", sub_query_id, fts_count = fts_ids.len(), "FTS complete");
801
802 let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
804 let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
805
806 let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
808 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
809 fused.truncate(k * 2);
810 tracing::debug!(target: "deep_research",
811 sub_query_id,
812 fused_count = fused.len(),
813 "RRF fusion complete"
814 );
815
816 if fused.is_empty() && !knn_ids.is_empty() {
817 tracing::warn!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
818 "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
819 }
820
821 for (memory_id, combined_score) in &fused {
822 if seen_ids.insert(*memory_id) {
823 let normalized = if max_possible > 0.0 {
824 combined_score / max_possible
825 } else {
826 0.0
827 };
828 let score = normalized.clamp(0.0, 1.0);
829 let in_knn = knn_distance_map.contains_key(memory_id);
830 let in_fts = fts_ids.contains(memory_id);
831 let source = match (in_knn, in_fts) {
832 (true, true) => "hybrid",
833 (true, false) => "knn",
834 (false, true) => "fts",
835 (false, false) => "graph",
836 };
837 if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
838 let snippet: String = row.body.chars().take(300).collect();
839 hits.push((
840 *memory_id,
841 score,
842 source.to_string(),
843 snippet,
844 row.body,
845 None,
846 ));
847 }
848 }
849 }
850
851 let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
854 let mut chains: Vec<EvidenceChain> = Vec::with_capacity(memory_ids.len());
855
856 if !memory_ids.is_empty() && max_hops > 0 {
857 let entity_knn = entities::knn_search(&conn, embedding, namespace, 5)
859 .inspect_err(|e| tracing::warn!(target: "deep_research", error = %e, "entity KNN search failed, skipping graph seed"))
860 .unwrap_or_default();
861 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
862
863 let top_seed_count = 5.min(memory_ids.len());
866 let top_memory_ids = &memory_ids[..top_seed_count];
867 let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
868 for &mem_id in top_memory_ids {
869 let mut stmt = conn
870 .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
871 .map_err(|e| format!("prepare failed: {e}"))?;
872 let ids: Vec<i64> = stmt
873 .query_map(rusqlite::params![mem_id], |r| r.get(0))
874 .map_err(|e| format!("query failed: {e}"))?
875 .filter_map(|r| r.ok())
876 .collect();
877 seed_entity_ids.extend(ids);
878 }
879 seed_entity_ids.sort_unstable();
880 seed_entity_ids.dedup();
881 tracing::debug!(target: "deep_research",
882 sub_query_id,
883 seed_count = seed_entity_ids.len(),
884 "seed entities collected"
885 );
886
887 let all_seed_ids: Vec<i64> = memory_ids
888 .iter()
889 .chain(entity_ids.iter())
890 .copied()
891 .collect();
892
893 if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
895 &conn,
896 &all_seed_ids,
897 namespace,
898 min_weight,
899 max_hops as u32,
900 max_neighbors_per_hop,
901 ) {
902 let seed_score_map: crate::hash::AHashMap<i64, f64> = fused
904 .iter()
905 .map(|(id, s)| {
906 let normalized = if max_possible > 0.0 {
907 s / max_possible
908 } else {
909 0.0
910 };
911 (*id, normalized.clamp(0.0, 1.0))
912 })
913 .collect();
914
915 for (graph_mem_id, hop) in graph_results {
916 if seen_ids.insert(graph_mem_id) {
917 let avg_seed_score: f64 = if seed_score_map.is_empty() {
922 0.5
923 } else {
924 let sum: f64 = seed_score_map.values().sum();
925 sum / seed_score_map.len() as f64
926 };
927 let graph_score =
928 (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
929
930 if graph_score < graph_min_score {
931 continue;
932 }
933
934 if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
935 let snippet: String = row.body.chars().take(300).collect();
936 hits.push((
937 graph_mem_id,
938 graph_score,
939 "graph".to_string(),
940 snippet,
941 row.body,
942 Some(hop as usize),
943 ));
944 }
945 }
946 }
947 }
948
949 if !seed_entity_ids.is_empty() {
952 let (entity_depth, predecessor) = bfs_with_predecessors(
953 &conn,
954 &seed_entity_ids,
955 namespace,
956 min_weight,
957 max_hops as u32,
958 max_neighbors_per_hop,
959 )
960 .unwrap_or_default();
961
962 tracing::debug!(target: "deep_research",
963 sub_query_id,
964 bfs_nodes = entity_depth.len(),
965 predecessors = predecessor.len(),
966 "BFS complete"
967 );
968
969 let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
970
971 let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
973 let mut entity_names: crate::hash::AHashMap<i64, String> =
974 crate::hash::AHashMap::with_capacity_and_hasher(
975 all_entity_ids.len(),
976 ahash::RandomState::default(),
977 );
978 for &eid in &all_entity_ids {
979 let name_res: rusqlite::Result<String> = conn.query_row(
980 "SELECT name FROM entities WHERE id = ?1",
981 rusqlite::params![eid],
982 |r| r.get(0),
983 );
984 if let Ok(name) = name_res {
985 entity_names.insert(eid, name);
986 }
987 }
988
989 for (&target_id, &_hop) in &entity_depth {
991 if seed_entity_set.contains(&target_id) {
992 continue;
993 }
994 if !predecessor.contains_key(&target_id) {
995 continue;
996 }
997 if let Some((path_nodes, total_weight)) =
998 reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
999 {
1000 if path_nodes.len() < 2 {
1001 continue;
1002 }
1003 let from = path_nodes
1004 .first()
1005 .map(|n| n.entity.clone())
1006 .unwrap_or_default();
1007 let to = path_nodes
1008 .last()
1009 .map(|n| n.entity.clone())
1010 .unwrap_or_default();
1011 let depth = path_nodes.len();
1012 chains.push(EvidenceChain {
1013 from,
1014 to,
1015 path: path_nodes,
1016 total_weight,
1017 depth,
1018 sub_query_ids: vec![sub_query_id],
1019 });
1020 }
1021 }
1022
1023 chains.sort_by(|a, b| {
1025 b.total_weight
1026 .partial_cmp(&a.total_weight)
1027 .unwrap_or(std::cmp::Ordering::Equal)
1028 });
1029 chains.truncate(20);
1030 tracing::debug!(target: "deep_research",
1031 sub_query_id,
1032 chains_count = chains.len(),
1033 "evidence chains built"
1034 );
1035 }
1036 }
1037
1038 Ok(SubQueryResult {
1039 sub_query_id,
1040 hits,
1041 chains,
1042 })
1043}
1044
1045#[cfg(test)]
1051mod tests {
1052 use super::*;
1053
1054 #[test]
1055 fn test_decompose_and_conjunction() {
1056 let result = decompose_query("A and B", 7);
1057 assert_eq!(result, vec!["A", "B"]);
1058 }
1059
1060 #[test]
1061 fn test_decompose_no_split() {
1062 let result = decompose_query("simple query", 7);
1063 assert_eq!(result, vec!["simple query"]);
1064 }
1065
1066 #[test]
1067 fn test_decompose_three_parts() {
1068 let result = decompose_query("A, B and C", 7);
1069 assert_eq!(result, vec!["A", "B", "C"]);
1070 }
1071
1072 #[test]
1073 fn test_decompose_portuguese_conjunctions() {
1074 let result = decompose_query("A e B", 7);
1075 assert_eq!(result, vec!["A", "B"]);
1076 }
1077
1078 #[test]
1079 fn test_decompose_max_cap() {
1080 let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
1081 let query = parts.join(", ");
1082 let result = decompose_query(&query, 7);
1083 assert!(
1084 result.len() <= 7,
1085 "expected at most 7 sub-queries, got {}",
1086 result.len()
1087 );
1088 }
1089
1090 #[test]
1091 fn test_decompose_empty_preserves_original() {
1092 let result = decompose_query("", 7);
1093 assert_eq!(result, vec![""]);
1094 }
1095
1096 #[test]
1097 fn test_decompose_semicolons() {
1098 let result = decompose_query("auth design; deployment config; logging", 7);
1099 assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
1100 }
1101
1102 #[test]
1103 fn test_decompose_relational_phrase() {
1104 let result = decompose_query("auth that caused deployment failure", 7);
1105 assert_eq!(result, vec!["auth", "deployment failure"]);
1106 }
1107
1108 #[test]
1109 fn test_sub_query_serialization() {
1110 let sq = SubQuery {
1111 id: 0,
1112 text: "test query".to_string(),
1113 source: "original",
1114 };
1115 let json = serde_json::to_value(&sq).expect("serialization failed");
1116 assert_eq!(json["id"], 0);
1117 assert_eq!(json["text"], "test query");
1118 assert_eq!(json["source"], "original");
1119 }
1120
1121 #[test]
1122 fn test_deep_result_omits_body_when_none() {
1123 let result = DeepResult {
1124 name: "test".to_string(),
1125 score: 0.9,
1126 source: "knn".to_string(),
1127 sub_query_ids: vec![0],
1128 snippet: "snippet".to_string(),
1129 body: None,
1130 hop_distance: None,
1131 };
1132 let json = serde_json::to_string(&result).expect("serialization failed");
1133 assert!(!json.contains("\"body\""), "body must be omitted when None");
1134 }
1135
1136 #[test]
1137 fn test_deep_result_includes_body_when_some() {
1138 let result = DeepResult {
1139 name: "test".to_string(),
1140 score: 0.9,
1141 source: "knn".to_string(),
1142 sub_query_ids: vec![0, 1],
1143 snippet: "snippet".to_string(),
1144 body: Some("full body content".to_string()),
1145 hop_distance: Some(2),
1146 };
1147 let json = serde_json::to_string(&result).expect("serialization failed");
1148 assert!(json.contains("\"body\""), "body must be present when Some");
1149 assert!(json.contains("full body content"));
1150 }
1151
1152 #[test]
1153 fn test_evidence_node_omits_none_fields() {
1154 let node = EvidenceNode {
1155 entity: "auth-module".to_string(),
1156 relation: None,
1157 weight: None,
1158 };
1159 let json = serde_json::to_string(&node).expect("serialization failed");
1160 assert!(
1161 !json.contains("\"relation\""),
1162 "relation must be omitted when None"
1163 );
1164 assert!(
1165 !json.contains("\"weight\""),
1166 "weight must be omitted when None"
1167 );
1168 }
1169
1170 #[test]
1171 fn test_research_stats_serialization() {
1172 let stats = ResearchStats {
1173 sub_queries_total: 3,
1174 sub_queries_completed: 2,
1175 sub_queries_failed: 1,
1176 sub_queries_timed_out: 0,
1177 unique_memories_found: 10,
1178 evidence_chains_found: 2,
1179 elapsed_ms: 1234,
1180 };
1181 let json = serde_json::to_value(&stats).expect("serialization failed");
1182 assert_eq!(json["sub_queries_total"], 3);
1183 assert_eq!(json["sub_queries_completed"], 2);
1184 assert_eq!(json["sub_queries_failed"], 1);
1185 assert_eq!(json["elapsed_ms"], 1234);
1186 }
1187
1188 #[test]
1189 fn test_deep_research_response_serialization() {
1190 let resp = DeepResearchResponse {
1191 query: "test query".to_string(),
1192 sub_queries: vec![SubQuery {
1193 id: 0,
1194 text: "test query".to_string(),
1195 source: "original",
1196 }],
1197 results: vec![],
1198 evidence_chains: vec![],
1199 graph_context: None,
1200 stats: ResearchStats {
1201 sub_queries_total: 1,
1202 sub_queries_completed: 1,
1203 sub_queries_failed: 0,
1204 sub_queries_timed_out: 0,
1205 unique_memories_found: 0,
1206 evidence_chains_found: 0,
1207 elapsed_ms: 42,
1208 },
1209 };
1210 let json = serde_json::to_value(&resp).expect("serialization failed");
1211 assert_eq!(json["query"], "test query");
1212 assert!(json["sub_queries"].is_array());
1213 assert!(json["results"].is_array());
1214 assert!(json["evidence_chains"].is_array());
1215 assert_eq!(json["stats"]["elapsed_ms"], 42);
1216 }
1217
1218 #[test]
1222 fn test_distinct_sub_queries_produce_distinct_texts() {
1223 let queries = [
1224 "authentication design decisions",
1225 "deployment configuration and infrastructure",
1226 ];
1227 assert_ne!(queries[0], queries[1]);
1229
1230 let decomposed = decompose_query(
1232 "authentication design decisions; deployment configuration and infrastructure",
1233 7,
1234 );
1235 assert_eq!(decomposed.len(), 2);
1236 assert_ne!(decomposed[0], decomposed[1]);
1237 }
1238
/// Fuses two equally weighted ranked id lists with `rrf_fuse` and checks that
/// ids present in both lists score higher than ids present in only one.
#[test]
fn test_rrf_fuse_via_fusion_module() {
    use crate::storage::fusion::rrf_fuse;

    // ids 1 and 2 appear in both lists; 3 is only in the knn list, 4 only in the fts list.
    let knn_ids: Vec<i64> = vec![1, 2, 3];
    let fts_ids: Vec<i64> = vec![2, 1, 4];
    let fused = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);

    let (both_lists_1, both_lists_2) = (fused[&1], fused[&2]);
    let (knn_only_3, fts_only_4) = (fused[&3], fused[&4]);

    assert!(
        both_lists_1 > knn_only_3,
        "id 1 (both lists) must beat id 3 (knn-only rank 3)"
    );
    assert!(
        both_lists_2 > fts_only_4,
        "id 2 (both lists) must beat id 4 (fts-only rank 3)"
    );
}
1263
/// Serializes a three-node `EvidenceChain` and checks that the JSON exposes
/// `from`, `to`, `path`, `total_weight` and `depth` with the expected kinds.
#[test]
fn test_evidence_chain_has_from_to_and_path() {
    // Small builder so the three path nodes read as one line each.
    let node = |entity: &str, relation: Option<&str>, weight| EvidenceNode {
        entity: entity.to_string(),
        relation: relation.map(str::to_string),
        weight,
    };

    let chain = EvidenceChain {
        from: "auth-module".to_string(),
        to: "jwt-service".to_string(),
        // The first node carries neither a relation nor a weight.
        path: vec![
            node("auth-module", None, None),
            node("token-validator", Some("depends-on"), Some(0.9)),
            node("jwt-service", Some("uses"), Some(0.8)),
        ],
        total_weight: 0.72,
        depth: 3,
        sub_query_ids: vec![0],
    };

    let value = serde_json::to_value(&chain).expect("serialization failed");

    assert!(
        value["from"].is_string(),
        "evidence chain must have 'from' field"
    );
    assert!(
        value["to"].is_string(),
        "evidence chain must have 'to' field"
    );
    assert!(
        value["path"].is_array(),
        "evidence chain must have 'path' array"
    );
    let path_len = value["path"].as_array().unwrap().len();
    assert_eq!(path_len, 3);
    assert!(value["total_weight"].is_number(), "must have total_weight");
    assert_eq!(value["depth"], 3);
}
1309
/// `reconstruct_path` must return the nodes ordered from the seed entity to the
/// target entity, together with a weight of 0.72 (= 0.9 * 0.8) for this fixture.
#[test]
fn test_reconstruct_path_root_to_target_order() {
    // Predecessor links: 30 <- 20 via "uses" (0.8), 20 <- 10 via "depends-on" (0.9); 10 is the seed.
    let seed_set: HashSet<i64> = HashSet::from([10i64]);
    let predecessor: PredecessorMap = std::collections::HashMap::from([
        (20, (10, "depends-on".to_string(), 0.9)),
        (30, (20, "uses".to_string(), 0.8)),
    ]);

    let mut entity_names: crate::hash::AHashMap<i64, String> = crate::hash::AHashMap::default();
    for (id, name) in [
        (10, "seed-entity"),
        (20, "middle-entity"),
        (30, "target-entity"),
    ] {
        entity_names.insert(id, name.to_string());
    }

    let (nodes, weight) = reconstruct_path(30, &seed_set, &predecessor, &entity_names)
        .expect("path must be reconstructed");

    assert_eq!(nodes.len(), 3);
    let expected_order = ["seed-entity", "middle-entity", "target-entity"];
    for (idx, expected) in expected_order.into_iter().enumerate() {
        assert_eq!(nodes[idx].entity, expected);
    }
    assert!((weight - 0.72).abs() < 1e-6);
}
1334
/// A chain with `depth == 1` must not survive a `depth >= 2` filter.
#[test]
fn test_evidence_chains_single_hop_filtered_out() {
    // Single-node chain from "a" to itself.
    let lone_node = EvidenceNode {
        entity: "a".to_string(),
        relation: None,
        weight: None,
    };
    let mut chains = vec![EvidenceChain {
        from: "a".to_string(),
        to: "a".to_string(),
        path: vec![lone_node],
        total_weight: 1.0,
        depth: 1,
        sub_query_ids: vec![0],
    }];

    // Keep only chains whose depth is at least 2.
    chains.retain(|c| c.depth >= 2);
    assert!(chains.is_empty(), "depth-1 chains must be filtered out");
}
1356
/// Builds an in-memory `relationships` table where node 1 has five outgoing
/// edges, then checks that `bfs_with_predecessors` reaches all five neighbours
/// when the last argument is `None` and only two when it is `Some(2)`.
#[test]
fn test_bfs_with_predecessors_respects_neighbor_cap() {
    use crate::graph::bfs_with_predecessors;
    use rusqlite::Connection;

    let conn = Connection::open_in_memory().unwrap();
    conn.execute_batch(
        "CREATE TABLE relationships (
            source_id INTEGER NOT NULL,
            target_id INTEGER NOT NULL,
            weight REAL NOT NULL,
            namespace TEXT NOT NULL,
            relation TEXT NOT NULL DEFAULT 'related'
        );",
    )
    .unwrap();

    // Fan-out fixture: 1 -> 2, 3, 4, 5, 6, each with weight 1.0 in namespace 'ns'.
    let insert_edge = "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')";
    (2i64..=6).for_each(|target| {
        conn.execute(insert_edge, rusqlite::params![1i64, target, 1.0f64])
            .unwrap();
    });

    // No cap: the result holds the seed plus every neighbour.
    let (reached_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
    assert_eq!(
        reached_uncapped.len() - 1,
        5,
        "uncapped must discover all 5 neighbours (plus seed)"
    );

    // Cap of 2 (last argument is presumably the per-node neighbour cap): seed plus two neighbours.
    let (reached_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
    assert_eq!(
        reached_capped.len(),
        3,
        "capped to 2 must yield seed + 2 neighbours"
    );
}
1401}