1use crate::errors::AppError;
8use crate::graph::traverse_from_memories_with_hops;
9use crate::output;
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use crate::storage::{entities, memories};
13
14use serde::Serialize;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use tokio::sync::Semaphore;
18use tokio::task::JoinSet;
19
// Command-line arguments for the `deep-research` subcommand.
//
// NOTE: plain `//` comments are used on purpose in this struct. clap's derive
// macro turns `///` doc comments into help text, which would change the CLI
// output for fields that have no explicit `help = ...`.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n \
    # Basic deep research\n \
    sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
    # With custom parameters\n \
    sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
    # Include full memory bodies in output\n \
    sqlite-graphrag deep-research \"auth\" --with-bodies"
)]
pub struct DeepResearchArgs {
    // Free-text query (positional). It is split into sub-queries by
    // `decompose_query` and embedded once as a whole.
    #[arg(value_name = "QUERY", help = "Research query to decompose and search")]
    pub query: String,
    // Result limit handed to both the KNN and the FTS search of each sub-query.
    #[arg(
        long,
        short,
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    // Upper bound passed to `decompose_query`.
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    // Hop limit for graph traversal; 0 skips the graph/evidence stage entirely.
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    // Weight threshold used by the graph traversal and by the relationship
    // (evidence) SQL filter `r.weight >= ?`.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    // Semaphore size for concurrent sub-queries. When absent, `run_async` uses
    // min(available CPUs, 8); the value is always clamped to at least 1.
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    // Converted with `Duration::from_secs` into the per-sub-query timeout.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    // When set, each result carries the full memory body in addition to the snippet.
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    // The ranked, merged hit list is truncated to this many entries.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    // Optional namespace override, resolved through `crate::namespace::resolve_namespace`.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Hidden flag. It is not read by the code visible in this file; the
    // response is always emitted through `output::emit_json`.
    #[arg(long, hide = true)]
    pub json: bool,
    // Optional database path override, resolved through `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Shared daemon options; `autostart_daemon` is consulted when embedding the query.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
100
/// One sub-query derived from the user's query, as reported in the response.
#[derive(Serialize)]
struct SubQuery {
    /// Zero-based position of the sub-query in the decomposed list.
    id: usize,
    /// Text of the sub-query.
    text: String,
    /// `"original"` when the query was not split (exactly one sub-query),
    /// `"decomposed"` otherwise.
    source: &'static str,
}
107
/// A single merged memory hit in the final response.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name, looked up from storage by memory ID.
    name: String,
    /// Best score observed for this memory across all sub-queries.
    score: f64,
    /// Retrieval stage that produced the best score: `"knn"`, `"fts"` or `"graph"`.
    source: String,
    /// IDs of the sub-queries that returned this memory.
    sub_query_ids: Vec<usize>,
    /// Leading characters of the memory body (up to 300).
    snippet: String,
    /// Full memory body; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// Hop count for graph-derived hits, `None` for KNN/FTS hits.
    /// Not skipped when `None`, so it serializes as `null`.
    hop_distance: Option<usize>,
}
119
/// A sequence of entity nodes reported as supporting evidence.
#[derive(Serialize)]
struct EvidenceChain {
    /// Nodes of the chain, in order.
    path: Vec<EvidenceNode>,
    /// Number of nodes in `path`.
    depth: usize,
    /// Sub-query IDs associated with this chain.
    sub_query_ids: Vec<usize>,
}
126
/// One entity in an evidence chain, optionally annotated with an outgoing relation.
#[derive(Serialize)]
struct EvidenceNode {
    /// Entity name.
    entity: String,
    /// Relation label attached to this node; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Relation weight attached to this node; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
135
/// Execution statistics attached to the response.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries that returned a result set.
    sub_queries_completed: usize,
    /// Sub-queries that failed (errors, panicked or cancelled tasks).
    sub_queries_failed: usize,
    /// Sub-queries reported as timed out.
    sub_queries_timed_out: usize,
    /// Number of entries in the response's `results` list.
    unique_memories_found: usize,
    /// Number of entries in the response's `evidence_chains` list.
    evidence_chains_found: usize,
    /// Wall-clock duration of the whole run, in milliseconds.
    elapsed_ms: u64,
}
146
/// Top-level JSON document emitted by the `deep-research` command.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The query exactly as supplied by the user.
    query: String,
    /// Sub-queries the query was decomposed into.
    sub_queries: Vec<SubQuery>,
    /// Merged and ranked memory hits.
    results: Vec<DeepResult>,
    /// Deduplicated evidence chains.
    evidence_chains: Vec<EvidenceChain>,
    /// Execution statistics for the run.
    stats: ResearchStats,
}
155
/// Best-known data for one memory while merging hits across sub-queries:
/// `(score, source, snippet, body, hop_distance, sub_query_ids)`.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
158
/// Raw output of a single sub-query before cross-query merging.
struct SubQueryResult {
    /// ID of the sub-query that produced this result.
    sub_query_id: usize,
    /// Hits as `(memory_id, score, source, snippet, body, hop_distance)`.
    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
    /// Flat list of evidence nodes collected for this sub-query.
    evidence: Vec<EvidenceNode>,
}
167
168pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
170 let rt = tokio::runtime::Builder::new_multi_thread()
171 .worker_threads(2)
172 .enable_all()
173 .build()
174 .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
175 rt.block_on(run_async(args))
176}
177
178async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
180 let start = std::time::Instant::now();
181
182 if args.query.trim().is_empty() {
183 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
184 }
185
186 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
187 let paths = AppPaths::resolve(args.db.as_deref())?;
188 crate::storage::connection::ensure_db_ready(&paths)?;
189
190 let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
192 let sub_queries: Vec<SubQuery> = sub_query_texts
193 .iter()
194 .enumerate()
195 .map(|(i, text)| SubQuery {
196 id: i,
197 text: text.clone(),
198 source: if sub_query_texts.len() == 1 {
199 "original"
200 } else {
201 "decomposed"
202 },
203 })
204 .collect();
205
206 output::emit_progress_i18n(
208 "Computing query embedding...",
209 "Calculando embedding da consulta...",
210 );
211 let embedding = Arc::new(crate::daemon::embed_query_or_local(
212 &paths.models,
213 &args.query,
214 args.daemon.autostart_daemon,
215 )?);
216
217 let cpu_count = std::thread::available_parallelism()
219 .map(|n| n.get())
220 .unwrap_or(4);
221 let permits = args
222 .max_concurrency
223 .unwrap_or_else(|| cpu_count.min(8))
224 .min(sub_queries.len())
225 .max(1);
226 let semaphore = Arc::new(Semaphore::new(permits));
227 let timeout_dur = std::time::Duration::from_secs(args.timeout);
228
229 let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
230
231 for (idx, sq_text) in sub_query_texts.iter().enumerate() {
232 let sem = Arc::clone(&semaphore);
233 let emb = Arc::clone(&embedding);
234 let ns = namespace.clone();
235 let db_path = paths.db.clone();
236 let query_text = sq_text.clone();
237 let k = args.k;
238 let max_hops = args.max_hops;
239 let min_weight = args.min_weight;
240
241 join_set.spawn(async move {
242 let _permit = sem
243 .acquire_owned()
244 .await
245 .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
246
247 let result = tokio::time::timeout(timeout_dur, async {
248 execute_sub_query(
249 idx,
250 &query_text,
251 &emb,
252 &ns,
253 &db_path,
254 k,
255 max_hops,
256 min_weight,
257 )
258 })
259 .await;
260
261 match result {
262 Ok(inner) => inner.map_err(|e| (idx, e)),
263 Err(_) => Err((idx, "timeout".to_string())),
264 }
265 });
266 }
267
268 let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
270 let mut failed_count = 0usize;
271 let mut timed_out_count = 0usize;
272
273 while let Some(join_result) = join_set.join_next().await {
274 match join_result {
275 Ok(Ok(sqr)) => sub_query_results.push(sqr),
276 Ok(Err((_idx, reason))) => {
277 if reason == "timeout" {
278 timed_out_count += 1;
279 } else {
280 failed_count += 1;
281 }
282 tracing::warn!(sub_query_id = _idx, reason = %reason, "sub-query failed");
283 }
284 Err(join_err) => {
285 failed_count += 1;
286 if join_err.is_panic() {
287 tracing::error!("sub-query task panicked: {join_err}");
288 } else {
289 tracing::warn!("sub-query task cancelled: {join_err}");
290 }
291 }
292 }
293 }
294
295 let mut merged: HashMap<i64, MergedHit> = HashMap::new();
298
299 for sqr in &sub_query_results {
300 for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
301 let entry = merged.entry(*mem_id).or_insert_with(|| {
302 (
303 *score,
304 source.clone(),
305 snippet.clone(),
306 body.clone(),
307 *hop,
308 Vec::new(),
309 )
310 });
311 if *score > entry.0 {
313 entry.0 = *score;
314 entry.1 = source.clone();
315 entry.2 = snippet.clone();
316 entry.3 = body.clone();
317 entry.4 = *hop;
318 }
319 if !entry.5.contains(&sqr.sub_query_id) {
320 entry.5.push(sqr.sub_query_id);
321 }
322 }
323 }
324
325 let conn = open_ro(&paths.db)?;
327 let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
328
329 let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
331 ranked.sort_by(|a, b| {
332 b.1 .0
333 .partial_cmp(&a.1 .0)
334 .unwrap_or(std::cmp::Ordering::Equal)
335 });
336 ranked.truncate(args.max_results);
337
338 for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
339 let name = match memories::read_full(&conn, mem_id)? {
340 Some(row) => row.name,
341 None => continue,
342 };
343 results.push(DeepResult {
344 name,
345 score,
346 source,
347 sub_query_ids: sq_ids,
348 snippet,
349 body: if args.with_bodies { Some(body) } else { None },
350 hop_distance: hop,
351 });
352 }
353
354 let mut evidence_chains: Vec<EvidenceChain> = Vec::new();
356 let mut seen_chain_keys: HashSet<String> = HashSet::new();
357
358 for sqr in &sub_query_results {
359 if sqr.evidence.is_empty() {
360 continue;
361 }
362 let key: String = sqr
364 .evidence
365 .iter()
366 .map(|n| n.entity.as_str())
367 .collect::<Vec<_>>()
368 .join("->");
369 if seen_chain_keys.insert(key) {
370 evidence_chains.push(EvidenceChain {
371 depth: sqr.evidence.len(),
372 path: sqr
373 .evidence
374 .iter()
375 .map(|n| EvidenceNode {
376 entity: n.entity.clone(),
377 relation: n.relation.clone(),
378 weight: n.weight,
379 })
380 .collect(),
381 sub_query_ids: vec![sqr.sub_query_id],
382 });
383 }
384 }
385
386 let unique_memories = results.len();
387 let evidence_count = evidence_chains.len();
388
389 output::emit_json(&DeepResearchResponse {
391 query: args.query,
392 sub_queries,
393 results,
394 evidence_chains,
395 stats: ResearchStats {
396 sub_queries_total: sub_query_texts.len(),
397 sub_queries_completed: sub_query_results.len(),
398 sub_queries_failed: failed_count,
399 sub_queries_timed_out: timed_out_count,
400 unique_memories_found: unique_memories,
401 evidence_chains_found: evidence_count,
402 elapsed_ms: start.elapsed().as_millis() as u64,
403 },
404 })?;
405
406 Ok(())
407}
408
/// Splits a research query into independent sub-queries.
///
/// Strategies are tried in this order and the first one that yields parts wins:
/// 1. Relational phrases (" that caused ", " related to ", ...), matched
///    ASCII-case-insensitively. Each phrase is tried once, in list order,
///    against the text remaining to the right of the previous split.
/// 2. Semicolons.
/// 3. Commas, after rewriting " and " / " AND " / " e " / " E " to ", ".
///
/// Parts are trimmed and empty parts are dropped. If nothing splits, the
/// original query is returned unchanged as the single element (including the
/// empty string).
///
/// `max` caps the number of returned parts. It is treated as at least 1 so a
/// splittable query never produces an empty list, matching the unsplittable
/// case which always returns one element.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    const RELATIONAL_PHRASES: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    let mut parts: Vec<String> = Vec::new();

    let mut rest: &str = query;
    let mut did_relational_split = false;
    for phrase in RELATIONAL_PHRASES.iter() {
        // ASCII lowercasing keeps every byte offset identical to `rest`.
        // Full Unicode lowercasing can change the byte length (e.g. 'İ'),
        // which would make `pos` point at the wrong place in `rest` or even
        // inside a multi-byte character.
        let lowered = rest.to_ascii_lowercase();
        if let Some(pos) = lowered.find(*phrase) {
            let left = rest[..pos].trim();
            let right = rest[pos + phrase.len()..].trim();
            if !left.is_empty() {
                parts.push(left.to_string());
            }
            // With nothing to the right, keep searching the same text.
            if !right.is_empty() {
                rest = right;
            }
            did_relational_split = true;
        }
    }
    if did_relational_split && !rest.is_empty() {
        parts.push(rest.to_string());
    }

    if parts.is_empty() {
        if query.contains(';') {
            parts.extend(
                query
                    .split(';')
                    .map(str::trim)
                    .filter(|piece| !piece.is_empty())
                    .map(String::from),
            );
        } else {
            // English and Portuguese conjunctions become commas before splitting.
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            if normalized.contains(',') {
                parts.extend(
                    normalized
                        .split(',')
                        .map(str::trim)
                        .filter(|piece| !piece.is_empty())
                        .map(String::from),
                );
            }
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    parts.truncate(max.max(1));
    parts
}
490
491#[allow(clippy::too_many_arguments)]
496fn execute_sub_query(
497 sub_query_id: usize,
498 query_text: &str,
499 embedding: &[f32],
500 namespace: &str,
501 db_path: &std::path::Path,
502 k: usize,
503 max_hops: usize,
504 min_weight: f64,
505) -> Result<SubQueryResult, String> {
506 let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
507
508 let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
509 Vec::with_capacity(k * 2);
510 let mut seen_ids: HashSet<i64> = HashSet::new();
511
512 let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
514 .map_err(|e| format!("knn_search failed: {e}"))?;
515
516 for (memory_id, distance) in &knn_results {
517 if seen_ids.insert(*memory_id) {
518 let score = 1.0 - (*distance as f64);
519 let score = score.clamp(0.0, 1.0);
520 if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
521 let snippet: String = row.body.chars().take(300).collect();
522 hits.push((
523 *memory_id,
524 score,
525 "knn".to_string(),
526 snippet,
527 row.body,
528 None,
529 ));
530 }
531 }
532 }
533
534 match memories::fts_search(&conn, query_text, namespace, None, k) {
536 Ok(fts_rows) => {
537 for row in fts_rows {
538 if seen_ids.insert(row.id) {
539 let snippet: String = row.body.chars().take(300).collect();
541 hits.push((row.id, 0.5, "fts".to_string(), snippet, row.body, None));
542 }
543 }
544 }
545 Err(e) => {
546 tracing::warn!(
547 sub_query_id,
548 "FTS5 search failed, continuing with KNN only: {e}"
549 );
550 }
551 }
552
553 let mut evidence: Vec<EvidenceNode> = Vec::new();
555 let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
556
557 if !memory_ids.is_empty() && max_hops > 0 {
558 let entity_knn = entities::knn_search(&conn, embedding, namespace, 5).unwrap_or_default();
560 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
561
562 let all_seed_ids: Vec<i64> = memory_ids
563 .iter()
564 .chain(entity_ids.iter())
565 .copied()
566 .collect();
567
568 if let Ok(graph_results) = traverse_from_memories_with_hops(
569 &conn,
570 &all_seed_ids,
571 namespace,
572 min_weight,
573 max_hops as u32,
574 ) {
575 for (graph_mem_id, hop) in graph_results {
576 if seen_ids.insert(graph_mem_id) {
577 let graph_distance = 1.0 - 1.0 / (hop as f64 + 1.0);
578 let score = 1.0 - graph_distance;
579 if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
580 let snippet: String = row.body.chars().take(300).collect();
581 hits.push((
582 graph_mem_id,
583 score,
584 "graph".to_string(),
585 snippet,
586 row.body,
587 Some(hop as usize),
588 ));
589 }
590 }
591 }
592 }
593
594 let entity_sql = "\
596 SELECT se.name, te.name, r.relation, r.weight
597 FROM relationships r
598 JOIN entities se ON se.id = r.source_id
599 JOIN entities te ON te.id = r.target_id
600 WHERE r.namespace = ?1 AND r.weight >= ?2
601 ORDER BY r.weight DESC
602 LIMIT 20";
603 if let Ok(mut stmt) = conn.prepare(entity_sql) {
604 if let Ok(rows) = stmt.query_map(rusqlite::params![namespace, min_weight], |r| {
605 Ok((
606 r.get::<_, String>(0)?,
607 r.get::<_, String>(1)?,
608 r.get::<_, String>(2)?,
609 r.get::<_, f64>(3)?,
610 ))
611 }) {
612 for row in rows.flatten() {
613 evidence.push(EvidenceNode {
614 entity: row.0,
615 relation: Some(row.2),
616 weight: Some(row.3),
617 });
618 evidence.push(EvidenceNode {
620 entity: row.1,
621 relation: None,
622 weight: None,
623 });
624 }
625 }
626 }
627 }
628
629 Ok(SubQueryResult {
630 sub_query_id,
631 hits,
632 evidence,
633 })
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    // ---- decompose_query -------------------------------------------------

    /// " and " separates two sub-queries.
    #[test]
    fn test_decompose_and_conjunction() {
        assert_eq!(decompose_query("A and B", 7), ["A", "B"]);
    }

    /// A query without any separator comes back untouched.
    #[test]
    fn test_decompose_no_split() {
        assert_eq!(decompose_query("simple query", 7), ["simple query"]);
    }

    /// Commas and " and " can be mixed in one query.
    #[test]
    fn test_decompose_three_parts() {
        assert_eq!(decompose_query("A, B and C", 7), ["A", "B", "C"]);
    }

    /// The Portuguese conjunction " e " is treated like " and ".
    #[test]
    fn test_decompose_portuguese_conjunctions() {
        assert_eq!(decompose_query("A e B", 7), ["A", "B"]);
    }

    /// Ten comma-separated parts are capped at the requested maximum.
    #[test]
    fn test_decompose_max_cap() {
        let query = (0..10)
            .map(|i| format!("part{i}"))
            .collect::<Vec<String>>()
            .join(", ");
        let count = decompose_query(&query, 7).len();
        assert!(count <= 7, "expected at most 7 sub-queries, got {}", count);
    }

    /// The empty query is returned as a single empty sub-query.
    #[test]
    fn test_decompose_empty_preserves_original() {
        assert_eq!(decompose_query("", 7), [""]);
    }

    /// Semicolons separate sub-queries and surrounding blanks are trimmed.
    #[test]
    fn test_decompose_semicolons() {
        let pieces = decompose_query("auth design; deployment config; logging", 7);
        assert_eq!(pieces, ["auth design", "deployment config", "logging"]);
    }

    /// A relational phrase splits the query around it.
    #[test]
    fn test_decompose_relational_phrase() {
        let pieces = decompose_query("auth that caused deployment failure", 7);
        assert_eq!(pieces, ["auth", "deployment failure"]);
    }

    // ---- serialization ---------------------------------------------------

    /// Every `SubQuery` field appears under its own name.
    #[test]
    fn test_sub_query_serialization() {
        let value = serde_json::to_value(SubQuery {
            id: 0,
            text: String::from("test query"),
            source: "original",
        })
        .expect("serialization failed");

        assert_eq!(value["id"], 0);
        assert_eq!(value["text"], "test query");
        assert_eq!(value["source"], "original");
    }

    /// `body: None` leaves no "body" key in the output.
    #[test]
    fn test_deep_result_omits_body_when_none() {
        let hit = DeepResult {
            name: String::from("test"),
            score: 0.9,
            source: String::from("knn"),
            sub_query_ids: vec![0],
            snippet: String::from("snippet"),
            body: None,
            hop_distance: None,
        };
        let serialized = serde_json::to_string(&hit).expect("serialization failed");
        assert!(
            !serialized.contains("\"body\""),
            "body must be omitted when None"
        );
    }

    /// `body: Some(..)` is written out together with its content.
    #[test]
    fn test_deep_result_includes_body_when_some() {
        let hit = DeepResult {
            name: String::from("test"),
            score: 0.9,
            source: String::from("knn"),
            sub_query_ids: vec![0, 1],
            snippet: String::from("snippet"),
            body: Some(String::from("full body content")),
            hop_distance: Some(2),
        };
        let serialized = serde_json::to_string(&hit).expect("serialization failed");
        assert!(
            serialized.contains("\"body\""),
            "body must be present when Some"
        );
        assert!(serialized.contains("full body content"));
    }

    /// Optional `EvidenceNode` fields vanish from the output when unset.
    #[test]
    fn test_evidence_node_omits_none_fields() {
        let serialized = serde_json::to_string(&EvidenceNode {
            entity: String::from("auth-module"),
            relation: None,
            weight: None,
        })
        .expect("serialization failed");

        assert!(
            !serialized.contains("\"relation\""),
            "relation must be omitted when None"
        );
        assert!(
            !serialized.contains("\"weight\""),
            "weight must be omitted when None"
        );
    }

    /// Counters in `ResearchStats` serialize under their field names.
    #[test]
    fn test_research_stats_serialization() {
        let value = serde_json::to_value(ResearchStats {
            sub_queries_total: 3,
            sub_queries_completed: 2,
            sub_queries_failed: 1,
            sub_queries_timed_out: 0,
            unique_memories_found: 10,
            evidence_chains_found: 2,
            elapsed_ms: 1234,
        })
        .expect("serialization failed");

        assert_eq!(value["sub_queries_total"], 3);
        assert_eq!(value["sub_queries_completed"], 2);
        assert_eq!(value["sub_queries_failed"], 1);
        assert_eq!(value["elapsed_ms"], 1234);
    }

    /// The full response exposes its collections as arrays and nests the stats.
    #[test]
    fn test_deep_research_response_serialization() {
        let only_sub_query = SubQuery {
            id: 0,
            text: String::from("test query"),
            source: "original",
        };
        let stats = ResearchStats {
            sub_queries_total: 1,
            sub_queries_completed: 1,
            sub_queries_failed: 0,
            sub_queries_timed_out: 0,
            unique_memories_found: 0,
            evidence_chains_found: 0,
            elapsed_ms: 42,
        };
        let response = DeepResearchResponse {
            query: String::from("test query"),
            sub_queries: vec![only_sub_query],
            results: Vec::new(),
            evidence_chains: Vec::new(),
            stats,
        };

        let value = serde_json::to_value(&response).expect("serialization failed");
        assert_eq!(value["query"], "test query");
        assert!(value["sub_queries"].is_array());
        assert!(value["results"].is_array());
        assert!(value["evidence_chains"].is_array());
        assert_eq!(value["stats"]["elapsed_ms"], 42);
    }
}