1use crate::errors::AppError;
8use crate::graph::{
9 bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
// Command-line arguments for the `deep-research` subcommand.
//
// Plain `//` comments are used in this struct on purpose: clap's derive can
// turn `///` doc comments into help text (e.g. a long about/help section),
// which could change the `--help` output defined by the explicit attributes.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n \
    # Basic deep research\n \
    sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
    # With custom parameters\n \
    sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
    # Include full memory bodies in output\n \
    sqlite-graphrag deep-research \"auth\" --with-bodies\n\n \
    # Tune RRF and graph scoring\n \
    sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
    // Positional research query; `run_async` rejects it when it is empty
    // after trimming.
    #[arg(value_name = "QUERY", help = "Research query to decompose and search")]
    pub query: String,
    // Per-sub-query budget: used as the limit for the memory KNN and FTS
    // searches; the fused list is then cut to `2 * k` candidates.
    #[arg(
        long,
        short,
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    // Upper bound on the number of sub-queries produced by `decompose_query`.
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    // Depth for graph expansion and entity BFS; 0 disables both, because
    // `execute_sub_query` only runs them when `max_hops > 0`.
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    // Minimum relationship weight followed during traversal and BFS.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    // Concurrent sub-query limit. When unset, `run_async` uses
    // min(available CPUs, 8). The result is always clamped to at most the
    // number of sub-queries and at least 1.
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    // Per-sub-query timeout in seconds; a timed-out sub-query is counted in
    // `stats.sub_queries_timed_out`.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    // When set, each result carries the full memory body in addition to the
    // snippet.
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    // Cap on the merged, ranked result list.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    // `k` constant handed to `rrf_fuse` / `rrf_max_possible`.
    #[arg(
        long,
        default_value_t = 60.0,
        help = "RRF k parameter (higher = less weight on top ranks)"
    )]
    pub rrf_k: f64,
    // Graph-expanded hits are scored `avg_seed_score * graph_decay^hop`.
    // The 0.0-1.0 range in the help text is not validated here.
    #[arg(
        long,
        default_value_t = 0.7,
        help = "Graph score decay factor per hop (0.0-1.0)"
    )]
    pub graph_decay: f64,
    // Graph-expanded hits scoring below this value are dropped.
    #[arg(
        long,
        default_value_t = 0.05,
        help = "Minimum score threshold for graph-expanded results"
    )]
    pub graph_min_score: f64,
    // Optional per-entity, per-hop neighbour cap forwarded to the graph
    // traversal and BFS helpers.
    #[arg(
        long,
        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
    )]
    pub max_neighbors_per_hop: Option<usize>,
    // Namespace override, resolved through `crate::namespace::resolve_namespace`.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Hidden LLM mode; the value parser currently accepts only "none".
    #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
    pub mode: String,
    // LLM cost cap. With `mode` restricted to "none" this only triggers a
    // warning in `run_async`.
    #[arg(
        long,
        value_name = "USD",
        help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
    )]
    pub max_cost_usd: Option<f64>,
    // Hidden flag, not read in this module; the response is emitted through
    // `output::emit_json` regardless.
    #[arg(long, hide = true)]
    pub json: bool,
    // Database path override (also read from SQLITE_GRAPHRAG_DB_PATH).
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Shared daemon options; `autostart_daemon` is used when embedding
    // sub-queries.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
142
/// One sub-query produced by `decompose_query`, echoed back in the response.
#[derive(Serialize)]
struct SubQuery {
    /// Zero-based position in the decomposition. Results and evidence chains
    /// reference it through their `sub_query_ids`.
    id: usize,
    /// Text that was embedded and searched for this sub-query.
    text: String,
    /// `"original"` or `"decomposed"`; `run_async` sets this label.
    source: &'static str,
}
149
/// One merged memory hit in the final response.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name, read back from the `memories` store.
    name: String,
    /// Best score this memory reached across all sub-queries. Both the fused
    /// and the graph scores are clamped to `[0, 1]`.
    score: f64,
    /// Origin of the best-scoring hit: `"hybrid"`, `"knn"`, `"fts"` or `"graph"`.
    source: String,
    /// Ids of every sub-query that returned this memory.
    sub_query_ids: Vec<usize>,
    /// First 300 characters of the memory body.
    snippet: String,
    /// Full body, present only with `--with-bodies`; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// Hop count for graph-expanded hits, `None` for direct KNN/FTS hits.
    /// This field is serialized as `null` rather than omitted.
    hop_distance: Option<usize>,
}
161
/// One entity on an evidence path.
#[derive(Serialize, Clone)]
struct EvidenceNode {
    /// Entity name, or `entity-{id}` when the name could not be looked up.
    entity: String,
    /// Relation of the edge leading into this node; `None` (omitted) for the
    /// seed node at the start of the path.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Weight of the edge leading into this node; `None` (omitted) for the seed.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
171
/// A multi-hop path through the entity graph from a seed entity to a
/// BFS-reached entity.
#[derive(Serialize)]
struct EvidenceChain {
    /// Entity name of the first node (the seed).
    from: String,
    /// Entity name of the last node (the reached target).
    to: String,
    /// Nodes ordered seed → target.
    path: Vec<EvidenceNode>,
    /// Product of the edge weights along `path`.
    total_weight: f64,
    /// Number of nodes in `path`, not edges; a single edge has depth 2.
    depth: usize,
    /// Sub-queries whose BFS produced this chain.
    sub_query_ids: Vec<usize>,
}
189
/// Execution counters reported under `stats` in the response.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries that returned a result.
    sub_queries_completed: usize,
    /// Sub-queries that errored, panicked or were cancelled (excluding timeouts).
    sub_queries_failed: usize,
    /// Sub-queries that exceeded `--timeout`.
    sub_queries_timed_out: usize,
    /// Number of results in the response (after `--max-results`).
    unique_memories_found: usize,
    /// Number of evidence chains in the response.
    evidence_chains_found: usize,
    /// Wall-clock milliseconds since `run_async` started.
    elapsed_ms: u64,
}
200
/// An entity whose name matched one of the result memory names.
#[derive(Serialize)]
struct GraphContextEntity {
    /// The matched name.
    name: String,
    /// Entity `type` column. Set to `"concept"` when that column is NULL or the
    /// lookup fails.
    entity_type: String,
    /// Number of relationship rows with this entity as source or target.
    degree: u32,
}
207
/// A relationship whose two endpoints are both graph-context entities.
#[derive(Serialize)]
struct GraphContextRel {
    /// Source entity name.
    from: String,
    /// Target entity name.
    to: String,
    /// Relation label.
    relation: String,
    /// Relationship weight.
    weight: f64,
}
215
/// Small subgraph around the results: the matched entities and up to 50
/// relationships among them.
#[derive(Serialize)]
struct GraphContext {
    entities: Vec<GraphContextEntity>,
    relationships: Vec<GraphContextRel>,
}
221
/// Top-level JSON document emitted by `deep-research`.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The query exactly as given on the command line.
    query: String,
    sub_queries: Vec<SubQuery>,
    /// Merged hits sorted by descending score.
    results: Vec<DeepResult>,
    /// Chains with at least two nodes, sorted by descending total weight.
    evidence_chains: Vec<EvidenceChain>,
    /// Omitted when there are no results or no result name matches an entity.
    #[serde(skip_serializing_if = "Option::is_none")]
    graph_context: Option<GraphContext>,
    stats: ResearchStats,
}
232
/// A memory hit merged across sub-queries:
/// `(best score, source, snippet, body, hop_distance, sub_query_ids)`.
/// The source, snippet, body and hop come from the best-scoring hit.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
235
/// Output of one `execute_sub_query` call.
struct SubQueryResult {
    /// Index of the sub-query in the decomposition.
    sub_query_id: usize,
    /// `(memory_id, score, source, snippet, body, hop_distance)` per hit,
    /// unique by memory id within this sub-query.
    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
    /// Evidence chains found by this sub-query's entity BFS.
    chains: Vec<EvidenceChain>,
}
244
245pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
247 let rt = tokio::runtime::Builder::new_multi_thread()
248 .worker_threads(2)
249 .enable_all()
250 .build()
251 .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
252 rt.block_on(run_async(args))
253}
254
255async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
257 let start = std::time::Instant::now();
258
259 if args.query.trim().is_empty() {
260 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
261 }
262
263 if args.max_cost_usd.is_some() && args.mode == "none" {
264 tracing::warn!("--max-cost-usd has no effect without --mode claude-code/codex");
265 }
266
267 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
268 let paths = AppPaths::resolve(args.db.as_deref())?;
269 crate::storage::connection::ensure_db_ready(&paths)?;
270
271 let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
273 let sub_queries: Vec<SubQuery> = sub_query_texts
274 .iter()
275 .enumerate()
276 .map(|(i, text)| SubQuery {
277 id: i,
278 text: text.clone(),
279 source: if sub_query_texts.len() == 1 {
280 "original"
281 } else {
282 "decomposed"
283 },
284 })
285 .collect();
286
287 output::emit_progress_i18n(
291 "Computing per-sub-query embeddings...",
292 "Calculando embeddings por sub-consulta...",
293 );
294 let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
295 for sq_text in &sub_query_texts {
296 let emb = crate::daemon::embed_query_or_local(
297 &paths.models,
298 sq_text,
299 args.daemon.autostart_daemon,
300 )?;
301 sub_embeddings.push(Arc::new(emb));
302 }
303
304 let cpu_count = std::thread::available_parallelism()
306 .map(|n| n.get())
307 .unwrap_or(4);
308 let permits = args
309 .max_concurrency
310 .unwrap_or_else(|| cpu_count.min(8))
311 .min(sub_queries.len())
312 .max(1);
313 let semaphore = Arc::new(Semaphore::new(permits));
314 let timeout_dur = std::time::Duration::from_secs(args.timeout);
315
316 let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
317
318 for (idx, sq_text) in sub_query_texts.iter().enumerate() {
319 let sem = Arc::clone(&semaphore);
320 let emb = Arc::clone(&sub_embeddings[idx]);
322 let ns = namespace.clone();
323 let db_path = paths.db.clone();
324 let query_text = sq_text.clone();
325 let k = args.k;
326 let max_hops = args.max_hops;
327 let min_weight = args.min_weight;
328 let rrf_k = args.rrf_k;
329 let graph_decay = args.graph_decay;
330 let graph_min_score = args.graph_min_score;
331 let max_neighbors_per_hop = args.max_neighbors_per_hop;
332
333 join_set.spawn(async move {
334 let _permit = sem
335 .acquire_owned()
336 .await
337 .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
338
339 let result = tokio::time::timeout(timeout_dur, async move {
341 execute_sub_query(
342 idx,
343 &query_text,
344 emb.as_slice(),
345 &ns,
346 &db_path,
347 k,
348 max_hops,
349 min_weight,
350 rrf_k,
351 graph_decay,
352 graph_min_score,
353 max_neighbors_per_hop,
354 )
355 })
356 .await;
357
358 match result {
359 Ok(inner) => inner.map_err(|e| (idx, e)),
360 Err(_) => Err((idx, "timeout".to_string())),
361 }
362 });
363 }
364
365 let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
367 let mut failed_count = 0usize;
368 let mut timed_out_count = 0usize;
369
370 while let Some(join_result) = join_set.join_next().await {
371 match join_result {
372 Ok(Ok(sqr)) => sub_query_results.push(sqr),
373 Ok(Err((_idx, reason))) => {
374 if reason == "timeout" {
375 timed_out_count += 1;
376 } else {
377 failed_count += 1;
378 }
379 tracing::warn!(sub_query_id = _idx, reason = %reason, "sub-query failed");
380 }
381 Err(join_err) => {
382 failed_count += 1;
383 if join_err.is_panic() {
384 tracing::error!("sub-query task panicked: {join_err}");
385 } else {
386 tracing::warn!("sub-query task cancelled: {join_err}");
387 }
388 }
389 }
390 }
391
392 let mut merged: HashMap<i64, MergedHit> = HashMap::new();
395
396 for sqr in &sub_query_results {
397 for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
398 let entry = merged.entry(*mem_id).or_insert_with(|| {
399 (
400 *score,
401 source.clone(),
402 snippet.clone(),
403 body.clone(),
404 *hop,
405 Vec::new(),
406 )
407 });
408 if *score > entry.0 {
410 entry.0 = *score;
411 entry.1 = source.clone();
412 entry.2 = snippet.clone();
413 entry.3 = body.clone();
414 entry.4 = *hop;
415 }
416 if !entry.5.contains(&sqr.sub_query_id) {
417 entry.5.push(sqr.sub_query_id);
418 }
419 }
420 }
421
422 let conn = open_ro(&paths.db)?;
424 let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
425
426 let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
428 ranked.sort_by(|a, b| {
429 b.1 .0
430 .partial_cmp(&a.1 .0)
431 .unwrap_or(std::cmp::Ordering::Equal)
432 });
433 ranked.truncate(args.max_results);
434
435 for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
436 let name = match memories::read_full(&conn, mem_id)? {
437 Some(row) => row.name,
438 None => continue,
439 };
440 results.push(DeepResult {
441 name,
442 score,
443 source,
444 sub_query_ids: sq_ids,
445 snippet,
446 body: if args.with_bodies { Some(body) } else { None },
447 hop_distance: hop,
448 });
449 }
450
451 let completed_count = sub_query_results.len();
455 let mut evidence_chains: Vec<EvidenceChain> = Vec::new();
456 let mut seen_chain_keys: HashSet<String> = HashSet::new();
457
458 for sqr in sub_query_results {
459 for chain in sqr.chains {
460 let key = format!("{}->{}", chain.from, chain.to);
462 if seen_chain_keys.insert(key) {
463 evidence_chains.push(chain);
464 }
465 }
466 }
467
468 evidence_chains.retain(|c| c.depth >= 2);
470 evidence_chains.sort_by(|a, b| {
471 b.total_weight
472 .partial_cmp(&a.total_weight)
473 .unwrap_or(std::cmp::Ordering::Equal)
474 });
475
476 let unique_memories = results.len();
477 let evidence_count = evidence_chains.len();
478
479 let graph_context = if !results.is_empty() {
481 let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
482 let mut ctx_entities: Vec<GraphContextEntity> = Vec::new();
483 let mut ctx_rels: Vec<GraphContextRel> = Vec::new();
484 let mut seen_entity_ids: HashSet<i64> = HashSet::new();
485
486 for name in &result_names {
487 if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
488 if seen_entity_ids.insert(eid) {
489 let etype: String = conn
490 .query_row(
491 "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
492 rusqlite::params![eid],
493 |r| r.get(0),
494 )
495 .unwrap_or_else(|_| "concept".to_string());
496 let degree: u32 = conn
497 .query_row(
498 "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
499 rusqlite::params![eid],
500 |r| r.get(0),
501 )
502 .unwrap_or(0);
503 ctx_entities.push(GraphContextEntity {
504 name: name.to_string(),
505 entity_type: etype,
506 degree,
507 });
508 }
509 }
510 }
511
512 let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
513 if entity_ids.len() >= 2 {
514 let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
515 let sql = format!(
516 "SELECT s.name, t.name, r.relation, r.weight \
517 FROM relationships r \
518 JOIN entities s ON s.id = r.source_id \
519 JOIN entities t ON t.id = r.target_id \
520 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
521 LIMIT 50"
522 );
523 if let Ok(mut stmt) = conn.prepare(&sql) {
524 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
525 for id in &entity_ids {
526 params.push(Box::new(*id));
527 }
528 for id in &entity_ids {
529 params.push(Box::new(*id));
530 }
531 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
532 params.iter().map(|p| p.as_ref()).collect();
533 if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
534 Ok((
535 r.get::<_, String>(0)?,
536 r.get::<_, String>(1)?,
537 r.get::<_, String>(2)?,
538 r.get::<_, f64>(3)?,
539 ))
540 }) {
541 for row in rows.flatten() {
542 ctx_rels.push(GraphContextRel {
543 from: row.0,
544 to: row.1,
545 relation: row.2,
546 weight: row.3,
547 });
548 }
549 }
550 }
551 }
552
553 if ctx_entities.is_empty() {
554 None
555 } else {
556 Some(GraphContext {
557 entities: ctx_entities,
558 relationships: ctx_rels,
559 })
560 }
561 } else {
562 None
563 };
564
565 tracing::debug!(
566 total_results = results.len(),
567 total_chains = evidence_chains.len(),
568 "assembly complete"
569 );
570
571 output::emit_json(&DeepResearchResponse {
573 query: args.query,
574 sub_queries,
575 results,
576 evidence_chains,
577 graph_context,
578 stats: ResearchStats {
579 sub_queries_total: sub_query_texts.len(),
580 sub_queries_completed: completed_count,
581 sub_queries_failed: failed_count,
582 sub_queries_timed_out: timed_out_count,
583 unique_memories_found: unique_memories,
584 evidence_chains_found: evidence_count,
585 elapsed_ms: start.elapsed().as_millis() as u64,
586 },
587 })?;
588
589 Ok(())
590}
591
/// Heuristically splits a research query into at most `max` sub-queries.
///
/// Strategies are tried in order; the first one that yields any parts wins:
/// 1. **Relational phrases** (`" related to "`, `" caused by "`, ...). The
///    query is cut at every occurrence, always at the earliest remaining
///    phrase, so `"A caused by B that caused C"` becomes `["A", "B", "C"]`.
///    Matching is ASCII case-insensitive.
/// 2. **Semicolons**, when the query contains more than one `;`-separated piece.
/// 3. **Conjunctions and commas**: `" and "`, `" AND "` and the Portuguese
///    `" e "` / `" E "` are treated as commas, then the query is split on commas.
/// 4. **Word fallback**: requires at least three words longer than two bytes.
///    It yields the full query plus its first and last word pairs.
///
/// Empty pieces are dropped and duplicates removed (first occurrence kept).
/// When no strategy applies, the query itself is returned unchanged.
/// The result is never empty: a `max` of 0 is treated as 1.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    // Surrounding spaces keep these from matching inside words.
    const RELATIONAL: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    let mut parts: Vec<String> = Vec::new();

    // Strategy 1: relational phrases. `to_ascii_lowercase` (unlike
    // `to_lowercase`) never changes byte lengths. Offsets found in `lower`
    // are therefore valid char-boundary indices into `rest`, even when the
    // text contains characters such as 'İ' whose full lowercase form is longer.
    let mut rest = query;
    let mut did_relational_split = false;
    loop {
        let lower = rest.to_ascii_lowercase();
        let earliest = RELATIONAL
            .iter()
            .filter_map(|phrase| lower.find(phrase).map(|pos| (pos, phrase.len())))
            .min_by_key(|&(pos, _)| pos);
        let (pos, len) = match earliest {
            Some(found) => found,
            None => break,
        };
        let left = rest[..pos].trim();
        if !left.is_empty() {
            parts.push(left.to_string());
        }
        rest = &rest[pos + len..];
        did_relational_split = true;
    }
    if did_relational_split {
        let tail = rest.trim();
        if !tail.is_empty() {
            parts.push(tail.to_string());
        }
    }

    if parts.is_empty() {
        // Strategy 2: explicit semicolon-separated sub-queries.
        let semi_parts: Vec<&str> = query.split(';').collect();
        if semi_parts.len() > 1 {
            for p in &semi_parts {
                let trimmed = p.trim();
                if !trimmed.is_empty() {
                    parts.push(trimmed.to_string());
                }
            }
        } else {
            // Strategy 3: conjunctions become commas, then split on commas.
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            let comma_parts: Vec<&str> = normalized.split(',').collect();
            if comma_parts.len() > 1 {
                for p in &comma_parts {
                    let trimmed = p.trim();
                    if !trimmed.is_empty() {
                        parts.push(trimmed.to_string());
                    }
                }
            }
        }
    }

    if parts.is_empty() {
        // Strategy 4: the full query plus its leading and trailing word pairs.
        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
        if words.len() >= 3 {
            parts.push(query.to_string());
            parts.push(format!("{} {}", words[0], words[1]));
            parts.push(format!(
                "{} {}",
                words[words.len() - 2],
                words[words.len() - 1]
            ));
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    // Repeated pieces (e.g. "auth; auth") would only duplicate work downstream.
    let mut seen: HashSet<String> = HashSet::new();
    parts.retain(|part| seen.insert(part.clone()));
    // Always keep at least one sub-query so the research never runs empty.
    parts.truncate(max.max(1));
    parts
}
686
687fn reconstruct_path(
691 target_id: i64,
692 seed_entity_ids: &HashSet<i64>,
693 predecessor: &PredecessorMap,
694 entity_names: &HashMap<i64, String>,
695) -> Option<(Vec<EvidenceNode>, f64)> {
696 let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::new();
697 let mut total_weight = 1.0_f64;
698 let mut current = target_id;
699
700 loop {
701 if seed_entity_ids.contains(¤t) {
702 break;
703 }
704 let (parent, relation, weight) = predecessor.get(¤t)?;
705 total_weight *= weight;
706 path_ids.push((current, Some(relation.clone()), Some(*weight)));
707 current = *parent;
708 }
709 path_ids.push((current, None, None));
711
712 path_ids.reverse();
714
715 let nodes: Vec<EvidenceNode> = path_ids
716 .into_iter()
717 .map(|(id, relation, weight)| EvidenceNode {
718 entity: entity_names
719 .get(&id)
720 .cloned()
721 .unwrap_or_else(|| format!("entity-{id}")),
722 relation,
723 weight,
724 })
725 .collect();
726
727 Some((nodes, total_weight))
728}
729
730#[allow(clippy::too_many_arguments)]
740fn execute_sub_query(
741 sub_query_id: usize,
742 query_text: &str,
743 embedding: &[f32],
744 namespace: &str,
745 db_path: &std::path::Path,
746 k: usize,
747 max_hops: usize,
748 min_weight: f64,
749 rrf_k: f64,
750 graph_decay: f64,
751 graph_min_score: f64,
752 max_neighbors_per_hop: Option<usize>,
753) -> Result<SubQueryResult, String> {
754 let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
755
756 let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
757 Vec::with_capacity(k * 2);
758 let mut seen_ids: HashSet<i64> = HashSet::new();
759
760 let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
764 .map_err(|e| format!("knn_search failed: {e}"))?;
765 let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
766 tracing::debug!(sub_query_id, knn_count = knn_ids.len(), "KNN complete");
767
768 let knn_distance_map: HashMap<i64, f64> = knn_results
770 .iter()
771 .map(|(id, dist)| (*id, *dist as f64))
772 .collect();
773
774 let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
776 Ok(rows) => rows,
777 Err(e) => {
778 tracing::warn!(
779 sub_query_id,
780 "FTS5 search failed, continuing with KNN only: {e}"
781 );
782 vec![]
783 }
784 };
785 let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
786 tracing::debug!(sub_query_id, fts_count = fts_ids.len(), "FTS complete");
787
788 let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
790 let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
791
792 let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
794 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
795 fused.truncate(k * 2);
796 tracing::debug!(
797 sub_query_id,
798 fused_count = fused.len(),
799 "RRF fusion complete"
800 );
801
802 if fused.is_empty() && !knn_ids.is_empty() {
803 tracing::warn!(sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
804 "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
805 }
806
807 for (memory_id, combined_score) in &fused {
808 if seen_ids.insert(*memory_id) {
809 let normalized = if max_possible > 0.0 {
810 combined_score / max_possible
811 } else {
812 0.0
813 };
814 let score = normalized.clamp(0.0, 1.0);
815 let in_knn = knn_distance_map.contains_key(memory_id);
816 let in_fts = fts_ids.contains(memory_id);
817 let source = match (in_knn, in_fts) {
818 (true, true) => "hybrid",
819 (true, false) => "knn",
820 (false, true) => "fts",
821 (false, false) => "graph",
822 };
823 if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
824 let snippet: String = row.body.chars().take(300).collect();
825 hits.push((
826 *memory_id,
827 score,
828 source.to_string(),
829 snippet,
830 row.body,
831 None,
832 ));
833 }
834 }
835 }
836
837 let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
840 let mut chains: Vec<EvidenceChain> = Vec::new();
841
842 if !memory_ids.is_empty() && max_hops > 0 {
843 let entity_knn = entities::knn_search(&conn, embedding, namespace, 5).unwrap_or_default();
845 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
846
847 let top_seed_count = 5.min(memory_ids.len());
850 let top_memory_ids = &memory_ids[..top_seed_count];
851 let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
852 for &mem_id in top_memory_ids {
853 let mut stmt = conn
854 .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
855 .map_err(|e| format!("prepare failed: {e}"))?;
856 let ids: Vec<i64> = stmt
857 .query_map(rusqlite::params![mem_id], |r| r.get(0))
858 .map_err(|e| format!("query failed: {e}"))?
859 .filter_map(|r| r.ok())
860 .collect();
861 seed_entity_ids.extend(ids);
862 }
863 seed_entity_ids.sort_unstable();
864 seed_entity_ids.dedup();
865 tracing::debug!(
866 sub_query_id,
867 seed_count = seed_entity_ids.len(),
868 "seed entities collected"
869 );
870
871 let all_seed_ids: Vec<i64> = memory_ids
872 .iter()
873 .chain(entity_ids.iter())
874 .copied()
875 .collect();
876
877 if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
879 &conn,
880 &all_seed_ids,
881 namespace,
882 min_weight,
883 max_hops as u32,
884 max_neighbors_per_hop,
885 ) {
886 let seed_score_map: HashMap<i64, f64> = fused
888 .iter()
889 .map(|(id, s)| {
890 let normalized = if max_possible > 0.0 {
891 s / max_possible
892 } else {
893 0.0
894 };
895 (*id, normalized.clamp(0.0, 1.0))
896 })
897 .collect();
898
899 for (graph_mem_id, hop) in graph_results {
900 if seen_ids.insert(graph_mem_id) {
901 let avg_seed_score: f64 = if seed_score_map.is_empty() {
906 0.5
907 } else {
908 let sum: f64 = seed_score_map.values().sum();
909 sum / seed_score_map.len() as f64
910 };
911 let graph_score =
912 (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
913
914 if graph_score < graph_min_score {
915 continue;
916 }
917
918 if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
919 let snippet: String = row.body.chars().take(300).collect();
920 hits.push((
921 graph_mem_id,
922 graph_score,
923 "graph".to_string(),
924 snippet,
925 row.body,
926 Some(hop as usize),
927 ));
928 }
929 }
930 }
931 }
932
933 if !seed_entity_ids.is_empty() {
936 let (entity_depth, predecessor) = bfs_with_predecessors(
937 &conn,
938 &seed_entity_ids,
939 namespace,
940 min_weight,
941 max_hops as u32,
942 max_neighbors_per_hop,
943 )
944 .unwrap_or_default();
945
946 tracing::debug!(
947 sub_query_id,
948 bfs_nodes = entity_depth.len(),
949 predecessors = predecessor.len(),
950 "BFS complete"
951 );
952
953 let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
954
955 let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
957 let mut entity_names: HashMap<i64, String> = HashMap::new();
958 for &eid in &all_entity_ids {
959 let name_res: rusqlite::Result<String> = conn.query_row(
960 "SELECT name FROM entities WHERE id = ?1",
961 rusqlite::params![eid],
962 |r| r.get(0),
963 );
964 if let Ok(name) = name_res {
965 entity_names.insert(eid, name);
966 }
967 }
968
969 for (&target_id, &_hop) in &entity_depth {
971 if seed_entity_set.contains(&target_id) {
972 continue;
973 }
974 if !predecessor.contains_key(&target_id) {
975 continue;
976 }
977 if let Some((path_nodes, total_weight)) =
978 reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
979 {
980 if path_nodes.len() < 2 {
981 continue;
982 }
983 let from = path_nodes
984 .first()
985 .map(|n| n.entity.clone())
986 .unwrap_or_default();
987 let to = path_nodes
988 .last()
989 .map(|n| n.entity.clone())
990 .unwrap_or_default();
991 let depth = path_nodes.len();
992 chains.push(EvidenceChain {
993 from,
994 to,
995 path: path_nodes,
996 total_weight,
997 depth,
998 sub_query_ids: vec![sub_query_id],
999 });
1000 }
1001 }
1002
1003 chains.sort_by(|a, b| {
1005 b.total_weight
1006 .partial_cmp(&a.total_weight)
1007 .unwrap_or(std::cmp::Ordering::Equal)
1008 });
1009 chains.truncate(20);
1010 tracing::debug!(
1011 sub_query_id,
1012 chains_count = chains.len(),
1013 "evidence chains built"
1014 );
1015 }
1016 }
1017
1018 Ok(SubQueryResult {
1019 sub_query_id,
1020 hits,
1021 chains,
1022 })
1023}
1024
// Unit tests covering:
// - query decomposition;
// - response serialization;
// - evidence-path reconstruction;
// - the neighbour cap of `bfs_with_predecessors`.
#[cfg(test)]
mod tests {
    use super::*;

    // " and " acts as a separator.
    #[test]
    fn test_decompose_and_conjunction() {
        let result = decompose_query("A and B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    // Two short words fall through every strategy, so the query is kept whole.
    #[test]
    fn test_decompose_no_split() {
        let result = decompose_query("simple query", 7);
        assert_eq!(result, vec!["simple query"]);
    }

    // Commas and " and " combine into a three-way split.
    #[test]
    fn test_decompose_three_parts() {
        let result = decompose_query("A, B and C", 7);
        assert_eq!(result, vec!["A", "B", "C"]);
    }

    // The Portuguese conjunction " e " is also a separator.
    #[test]
    fn test_decompose_portuguese_conjunctions() {
        let result = decompose_query("A e B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    // `max` caps the number of sub-queries.
    #[test]
    fn test_decompose_max_cap() {
        let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
        let query = parts.join(", ");
        let result = decompose_query(&query, 7);
        assert!(
            result.len() <= 7,
            "expected at most 7 sub-queries, got {}",
            result.len()
        );
    }

    // An empty query comes back unchanged as a single sub-query.
    #[test]
    fn test_decompose_empty_preserves_original() {
        let result = decompose_query("", 7);
        assert_eq!(result, vec![""]);
    }

    // Semicolons take precedence over conjunction splitting.
    #[test]
    fn test_decompose_semicolons() {
        let result = decompose_query("auth design; deployment config; logging", 7);
        assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
    }

    // A relational phrase splits the query around it.
    #[test]
    fn test_decompose_relational_phrase() {
        let result = decompose_query("auth that caused deployment failure", 7);
        assert_eq!(result, vec!["auth", "deployment failure"]);
    }

    // SubQuery serializes id, text and source.
    #[test]
    fn test_sub_query_serialization() {
        let sq = SubQuery {
            id: 0,
            text: "test query".to_string(),
            source: "original",
        };
        let json = serde_json::to_value(&sq).expect("serialization failed");
        assert_eq!(json["id"], 0);
        assert_eq!(json["text"], "test query");
        assert_eq!(json["source"], "original");
    }

    // `body: None` is skipped in JSON.
    #[test]
    fn test_deep_result_omits_body_when_none() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0],
            snippet: "snippet".to_string(),
            body: None,
            hop_distance: None,
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(!json.contains("\"body\""), "body must be omitted when None");
    }

    // `body: Some(..)` is serialized with its content.
    #[test]
    fn test_deep_result_includes_body_when_some() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0, 1],
            snippet: "snippet".to_string(),
            body: Some("full body content".to_string()),
            hop_distance: Some(2),
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(json.contains("\"body\""), "body must be present when Some");
        assert!(json.contains("full body content"));
    }

    // A seed node (no incoming edge) omits relation and weight.
    #[test]
    fn test_evidence_node_omits_none_fields() {
        let node = EvidenceNode {
            entity: "auth-module".to_string(),
            relation: None,
            weight: None,
        };
        let json = serde_json::to_string(&node).expect("serialization failed");
        assert!(
            !json.contains("\"relation\""),
            "relation must be omitted when None"
        );
        assert!(
            !json.contains("\"weight\""),
            "weight must be omitted when None"
        );
    }

    // ResearchStats field names are part of the JSON contract.
    #[test]
    fn test_research_stats_serialization() {
        let stats = ResearchStats {
            sub_queries_total: 3,
            sub_queries_completed: 2,
            sub_queries_failed: 1,
            sub_queries_timed_out: 0,
            unique_memories_found: 10,
            evidence_chains_found: 2,
            elapsed_ms: 1234,
        };
        let json = serde_json::to_value(&stats).expect("serialization failed");
        assert_eq!(json["sub_queries_total"], 3);
        assert_eq!(json["sub_queries_completed"], 2);
        assert_eq!(json["sub_queries_failed"], 1);
        assert_eq!(json["elapsed_ms"], 1234);
    }

    // Top-level response shape: arrays stay arrays even when empty.
    #[test]
    fn test_deep_research_response_serialization() {
        let resp = DeepResearchResponse {
            query: "test query".to_string(),
            sub_queries: vec![SubQuery {
                id: 0,
                text: "test query".to_string(),
                source: "original",
            }],
            results: vec![],
            evidence_chains: vec![],
            graph_context: None,
            stats: ResearchStats {
                sub_queries_total: 1,
                sub_queries_completed: 1,
                sub_queries_failed: 0,
                sub_queries_timed_out: 0,
                unique_memories_found: 0,
                evidence_chains_found: 0,
                elapsed_ms: 42,
            },
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "test query");
        assert!(json["sub_queries"].is_array());
        assert!(json["results"].is_array());
        assert!(json["evidence_chains"].is_array());
        assert_eq!(json["stats"]["elapsed_ms"], 42);
    }

    // Semicolon decomposition yields two distinct sub-query texts.
    #[test]
    fn test_distinct_sub_queries_produce_distinct_texts() {
        let queries = [
            "authentication design decisions",
            "deployment configuration and infrastructure",
        ];
        assert_ne!(queries[0], queries[1]);

        let decomposed = decompose_query(
            "authentication design decisions; deployment configuration and infrastructure",
            7,
        );
        assert_eq!(decomposed.len(), 2);
        assert_ne!(decomposed[0], decomposed[1]);
    }

    // Ids present in both ranked lists outscore ids present in only one.
    #[test]
    fn test_rrf_fuse_via_fusion_module() {
        use crate::storage::fusion::rrf_fuse;

        let knn_ids: Vec<i64> = vec![1, 2, 3];
        let fts_ids: Vec<i64> = vec![2, 1, 4];
        let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);

        let score_1 = scores[&1];
        let score_2 = scores[&2];
        let score_3 = scores[&3]; let score_4 = scores[&4]; assert!(
            score_1 > score_3,
            "id 1 (both lists) must beat id 3 (knn-only rank 3)"
        );
        assert!(
            score_2 > score_4,
            "id 2 (both lists) must beat id 4 (fts-only rank 3)"
        );
    }

    // EvidenceChain serializes from/to, the ordered path, weight and depth.
    #[test]
    fn test_evidence_chain_has_from_to_and_path() {
        let chain = EvidenceChain {
            from: "auth-module".to_string(),
            to: "jwt-service".to_string(),
            path: vec![
                EvidenceNode {
                    entity: "auth-module".to_string(),
                    relation: None,
                    weight: None,
                },
                EvidenceNode {
                    entity: "token-validator".to_string(),
                    relation: Some("depends-on".to_string()),
                    weight: Some(0.9),
                },
                EvidenceNode {
                    entity: "jwt-service".to_string(),
                    relation: Some("uses".to_string()),
                    weight: Some(0.8),
                },
            ],
            total_weight: 0.72,
            depth: 3,
            sub_query_ids: vec![0],
        };

        let json = serde_json::to_value(&chain).expect("serialization failed");
        assert!(
            json["from"].is_string(),
            "evidence chain must have 'from' field"
        );
        assert!(
            json["to"].is_string(),
            "evidence chain must have 'to' field"
        );
        assert!(
            json["path"].is_array(),
            "evidence chain must have 'path' array"
        );
        assert_eq!(json["path"].as_array().unwrap().len(), 3);
        assert!(json["total_weight"].is_number(), "must have total_weight");
        assert_eq!(json["depth"], 3);
    }

    // Two-hop path 10 -> 20 -> 30 comes back in seed-to-target order, with
    // weight equal to the product of its edges (0.9 * 0.8).
    #[test]
    fn test_reconstruct_path_root_to_target_order() {
        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
        let mut predecessor: HashMap<i64, (i64, String, f64)> = HashMap::new();
        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
        predecessor.insert(30, (20, "uses".to_string(), 0.8));
        let mut entity_names: HashMap<i64, String> = HashMap::new();
        entity_names.insert(10, "seed-entity".to_string());
        entity_names.insert(20, "middle-entity".to_string());
        entity_names.insert(30, "target-entity".to_string());

        let result = reconstruct_path(30, &seed_set, &predecessor, &entity_names);
        assert!(result.is_some(), "path must be reconstructed");
        let (nodes, weight) = result.unwrap();
        assert_eq!(nodes.len(), 3);
        assert_eq!(nodes[0].entity, "seed-entity");
        assert_eq!(nodes[1].entity, "middle-entity");
        assert_eq!(nodes[2].entity, "target-entity");
        assert!((weight - 0.72).abs() < 1e-6);
    }

    // Mirrors the `depth >= 2` retain in `run_async` on a local copy; it does
    // not call `run_async` itself.
    #[test]
    fn test_evidence_chains_single_hop_filtered_out() {
        let chain = EvidenceChain {
            from: "a".to_string(),
            to: "a".to_string(),
            path: vec![EvidenceNode {
                entity: "a".to_string(),
                relation: None,
                weight: None,
            }],
            total_weight: 1.0,
            depth: 1,
            sub_query_ids: vec![0],
        };
        let chains = vec![chain];
        let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
        assert!(retained.is_empty(), "depth-1 chains must be filtered out");
    }

    // A star graph (1 -> 2..=6) with a per-hop cap of 2 discovers only two
    // neighbours. The minimal `relationships` schema here is what this test
    // assumes the BFS queries; confirm it against `crate::graph`.
    #[test]
    fn test_bfs_with_predecessors_respects_neighbor_cap() {
        use crate::graph::bfs_with_predecessors;
        use rusqlite::Connection;

        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related'
            );",
        )
        .unwrap();

        for target in 2i64..=6 {
            conn.execute(
                "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
                rusqlite::params![1i64, target, 1.0f64],
            )
            .unwrap();
        }

        let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
        assert_eq!(
            depth_uncapped.len() - 1,
            5,
            "uncapped must discover all 5 neighbours (plus seed)"
        );

        let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
        assert_eq!(
            depth_capped.len(),
            3,
            "capped to 2 must yield seed + 2 neighbours"
        );
    }
}