Skip to main content

sqlite_graphrag/commands/
deep_research.rs

1//! Handler for the `deep-research` CLI subcommand.
2//!
3//! Orchestrates parallel multi-hop GraphRAG search via query decomposition.
4//! The workload is I/O-bound (SQLite WAL reads), so tokio is used instead of
5//! rayon. Each sub-query opens its own read-only connection.
6
7use crate::errors::AppError;
8use crate::graph::{
9    bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
23/// Arguments for the `deep-research` subcommand.
24#[derive(clap::Args)]
25#[command(
26    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
27    after_long_help = "EXAMPLES:\n  \
28        # Basic deep research\n  \
29        sqlite-graphrag deep-research \"auth architecture decisions\"\n\n  \
30        # With custom parameters\n  \
31        sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n  \
32        # Include full memory bodies in output\n  \
33        sqlite-graphrag deep-research \"auth\" --with-bodies\n\n  \
34        # Tune RRF and graph scoring\n  \
35        sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
36)]
37pub struct DeepResearchArgs {
38    /// Research query to decompose and search.
39    #[arg(value_name = "QUERY", help = "Research query to decompose and search")]
40    pub query: String,
41    /// Results per sub-query (Recall@20 captures 95%+ relevant hits).
42    #[arg(
43        long,
44        short,
45        default_value_t = 20,
46        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
47    )]
48    pub k: usize,
49    /// Maximum sub-queries from decomposition (covers complex multi-hop queries).
50    #[arg(
51        long,
52        default_value_t = 7,
53        help = "Maximum sub-queries (covers complex multi-hop queries)"
54    )]
55    pub max_sub_queries: usize,
56    /// Multi-hop graph traversal depth (sweet spot: 2-3 hops).
57    #[arg(
58        long,
59        default_value_t = 3,
60        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
61    )]
62    pub max_hops: usize,
63    /// Minimum edge weight for graph traversal.
64    #[arg(
65        long,
66        default_value_t = 0.3,
67        help = "Minimum edge weight for graph traversal"
68    )]
69    pub min_weight: f64,
70    /// Maximum concurrent sub-queries (default: min(cpus, 8)).
71    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
72    pub max_concurrency: Option<usize>,
73    /// Timeout per sub-query in seconds.
74    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
75    pub timeout: u64,
76    /// Include full memory bodies in results.
77    #[arg(
78        long,
79        default_value_t = false,
80        help = "Include full memory bodies in results"
81    )]
82    pub with_bodies: bool,
83    /// Maximum results after deduplication.
84    #[arg(
85        long,
86        default_value_t = 50,
87        help = "Maximum results after deduplication"
88    )]
89    pub max_results: usize,
90    /// RRF k parameter controlling score smoothing (higher = less weight on top ranks).
91    #[arg(
92        long,
93        default_value_t = 60.0,
94        help = "RRF k parameter (higher = less weight on top ranks)"
95    )]
96    pub rrf_k: f64,
97    /// Decay factor applied to graph scores per hop (score = seed_score * decay^hop).
98    #[arg(
99        long,
100        default_value_t = 0.7,
101        help = "Graph score decay factor per hop (0.0-1.0)"
102    )]
103    pub graph_decay: f64,
104    /// Minimum score threshold for graph-expanded results (filters noise).
105    #[arg(
106        long,
107        default_value_t = 0.2,
108        help = "Minimum score threshold for graph-expanded results"
109    )]
110    pub graph_min_score: f64,
111    /// Limit top-k neighbours followed per entity per hop (None = unlimited).
112    #[arg(
113        long,
114        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
115    )]
116    pub max_neighbors_per_hop: Option<usize>,
117    /// Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global).
118    #[arg(
119        long,
120        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
121    )]
122    pub namespace: Option<String>,
123    /// JSON output (always on, kept for consistency).
124    #[arg(long, hide = true)]
125    pub json: bool,
126    /// Database path.
127    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
128    pub db: Option<String>,
129    #[command(flatten)]
130    pub daemon: crate::cli::DaemonOpts,
131}
132
133#[derive(Serialize)]
134struct SubQuery {
135    id: usize,
136    text: String,
137    source: &'static str,
138}
139
140#[derive(Serialize)]
141struct DeepResult {
142    name: String,
143    score: f64,
144    source: String,
145    sub_query_ids: Vec<usize>,
146    snippet: String,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    body: Option<String>,
149    hop_distance: Option<usize>,
150}
151
152/// A node in a reconstructed evidence path.
153#[derive(Serialize, Clone)]
154struct EvidenceNode {
155    entity: String,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    relation: Option<String>,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    weight: Option<f64>,
160}
161
162/// A directed evidence chain reconstructed from BFS predecessors.
163///
164/// Fields:
165/// - `from`: name of the seed (source) entity.
166/// - `to`: name of the terminal (target) entity.
167/// - `path`: ordered list of intermediate nodes from `from` to `to`.
168/// - `total_weight`: product of edge weights along the path.
169/// - `sub_query_ids`: which sub-queries produced this chain.
170#[derive(Serialize)]
171struct EvidenceChain {
172    from: String,
173    to: String,
174    path: Vec<EvidenceNode>,
175    total_weight: f64,
176    depth: usize,
177    sub_query_ids: Vec<usize>,
178}
179
180#[derive(Serialize)]
181struct ResearchStats {
182    sub_queries_total: usize,
183    sub_queries_completed: usize,
184    sub_queries_failed: usize,
185    sub_queries_timed_out: usize,
186    unique_memories_found: usize,
187    evidence_chains_found: usize,
188    elapsed_ms: u64,
189}
190
191#[derive(Serialize)]
192struct DeepResearchResponse {
193    query: String,
194    sub_queries: Vec<SubQuery>,
195    results: Vec<DeepResult>,
196    evidence_chains: Vec<EvidenceChain>,
197    stats: ResearchStats,
198}
199
200/// Aggregated hit data: (score, source_label, snippet, body, hop_distance, sub_query_ids).
201type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
202
203/// Intermediate result from a single sub-query execution.
204struct SubQueryResult {
205    sub_query_id: usize,
206    /// (memory_id, score, source_label, snippet, body, hop_distance)
207    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
208    /// Evidence chains reconstructed from BFS.
209    chains: Vec<EvidenceChain>,
210}
211
212/// Sync entry point — builds a tokio runtime for the async fan-out.
213pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
214    let rt = tokio::runtime::Builder::new_multi_thread()
215        .worker_threads(2)
216        .enable_all()
217        .build()
218        .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
219    rt.block_on(run_async(args))
220}
221
222/// Main async logic: decompose, fan-out, assemble, emit JSON.
223async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
224    let start = std::time::Instant::now();
225
226    if args.query.trim().is_empty() {
227        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
228    }
229
230    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
231    let paths = AppPaths::resolve(args.db.as_deref())?;
232    crate::storage::connection::ensure_db_ready(&paths)?;
233
234    // Phase 1: Query decomposition (sync, pure logic).
235    let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
236    let sub_queries: Vec<SubQuery> = sub_query_texts
237        .iter()
238        .enumerate()
239        .map(|(i, text)| SubQuery {
240            id: i,
241            text: text.clone(),
242            source: if sub_query_texts.len() == 1 {
243                "original"
244            } else {
245                "decomposed"
246            },
247        })
248        .collect();
249
250    // GAP-07 FIX: compute ONE embedding PER sub-query text (sequential — daemon serialises).
251    // The previous code used a single embedding for args.query shared across all sub-queries,
252    // making decomposition cosmetic.  We now build a Vec<Arc<Vec<f32>>> indexed by sub-query.
253    output::emit_progress_i18n(
254        "Computing per-sub-query embeddings...",
255        "Calculando embeddings por sub-consulta...",
256    );
257    let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
258    for sq_text in &sub_query_texts {
259        let emb = crate::daemon::embed_query_or_local(
260            &paths.models,
261            sq_text,
262            args.daemon.autostart_daemon,
263        )?;
264        sub_embeddings.push(Arc::new(emb));
265    }
266
267    // Phase 2: Fan-out — parallel sub-query execution.
268    let cpu_count = std::thread::available_parallelism()
269        .map(|n| n.get())
270        .unwrap_or(4);
271    let permits = args
272        .max_concurrency
273        .unwrap_or_else(|| cpu_count.min(8))
274        .min(sub_queries.len())
275        .max(1);
276    let semaphore = Arc::new(Semaphore::new(permits));
277    let timeout_dur = std::time::Duration::from_secs(args.timeout);
278
279    let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
280
281    for (idx, sq_text) in sub_query_texts.iter().enumerate() {
282        let sem = Arc::clone(&semaphore);
283        // GAP-07 FIX: pass embedding for THIS specific sub-query.
284        let emb = Arc::clone(&sub_embeddings[idx]);
285        let ns = namespace.clone();
286        let db_path = paths.db.clone();
287        let query_text = sq_text.clone();
288        let k = args.k;
289        let max_hops = args.max_hops;
290        let min_weight = args.min_weight;
291        let rrf_k = args.rrf_k;
292        let graph_decay = args.graph_decay;
293        let graph_min_score = args.graph_min_score;
294        let max_neighbors_per_hop = args.max_neighbors_per_hop;
295
296        join_set.spawn(async move {
297            let _permit = sem
298                .acquire_owned()
299                .await
300                .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
301
302            // Dereference the Arc to obtain a &[f32] slice for the sync function.
303            let result = tokio::time::timeout(timeout_dur, async move {
304                execute_sub_query(
305                    idx,
306                    &query_text,
307                    emb.as_slice(),
308                    &ns,
309                    &db_path,
310                    k,
311                    max_hops,
312                    min_weight,
313                    rrf_k,
314                    graph_decay,
315                    graph_min_score,
316                    max_neighbors_per_hop,
317                )
318            })
319            .await;
320
321            match result {
322                Ok(inner) => inner.map_err(|e| (idx, e)),
323                Err(_) => Err((idx, "timeout".to_string())),
324            }
325        });
326    }
327
328    // Collect results incrementally.
329    let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
330    let mut failed_count = 0usize;
331    let mut timed_out_count = 0usize;
332
333    while let Some(join_result) = join_set.join_next().await {
334        match join_result {
335            Ok(Ok(sqr)) => sub_query_results.push(sqr),
336            Ok(Err((_idx, reason))) => {
337                if reason == "timeout" {
338                    timed_out_count += 1;
339                } else {
340                    failed_count += 1;
341                }
342                tracing::warn!(sub_query_id = _idx, reason = %reason, "sub-query failed");
343            }
344            Err(join_err) => {
345                failed_count += 1;
346                if join_err.is_panic() {
347                    tracing::error!("sub-query task panicked: {join_err}");
348                } else {
349                    tracing::warn!("sub-query task cancelled: {join_err}");
350                }
351            }
352        }
353    }
354
355    // Phase 3: Evidence assembly — merge, dedup, rank.
356    // Aggregate hits: memory_id -> (best_score, source, snippet, body, hop_distance, sub_query_ids)
357    let mut merged: HashMap<i64, MergedHit> = HashMap::new();
358
359    for sqr in &sub_query_results {
360        for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
361            let entry = merged.entry(*mem_id).or_insert_with(|| {
362                (
363                    *score,
364                    source.clone(),
365                    snippet.clone(),
366                    body.clone(),
367                    *hop,
368                    Vec::new(),
369                )
370            });
371            // Keep best score.
372            if *score > entry.0 {
373                entry.0 = *score;
374                entry.1 = source.clone();
375                entry.2 = snippet.clone();
376                entry.3 = body.clone();
377                entry.4 = *hop;
378            }
379            if !entry.5.contains(&sqr.sub_query_id) {
380                entry.5.push(sqr.sub_query_id);
381            }
382        }
383    }
384
385    // Resolve memory names for merged results.
386    let conn = open_ro(&paths.db)?;
387    let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
388
389    // Sort by score descending.
390    let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
391    ranked.sort_by(|a, b| {
392        b.1 .0
393            .partial_cmp(&a.1 .0)
394            .unwrap_or(std::cmp::Ordering::Equal)
395    });
396    ranked.truncate(args.max_results);
397
398    for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
399        let name = match memories::read_full(&conn, mem_id)? {
400            Some(row) => row.name,
401            None => continue,
402        };
403        results.push(DeepResult {
404            name,
405            score,
406            source,
407            sub_query_ids: sq_ids,
408            snippet,
409            body: if args.with_bodies { Some(body) } else { None },
410            hop_distance: hop,
411        });
412    }
413
414    // GAP-09/10 FIX: Collect evidence chains from reconstructed BFS paths.
415    // The old code appended flat node pairs from a global SELECT; now each
416    // sub-query returns directed EvidenceChain structs (from, to, path).
417    let completed_count = sub_query_results.len();
418    let mut evidence_chains: Vec<EvidenceChain> = Vec::new();
419    let mut seen_chain_keys: HashSet<String> = HashSet::new();
420
421    for sqr in sub_query_results {
422        for chain in sqr.chains {
423            // Deduplicate chains by (from, to) pair.
424            let key = format!("{}->{}", chain.from, chain.to);
425            if seen_chain_keys.insert(key) {
426                evidence_chains.push(chain);
427            }
428        }
429    }
430
431    // Sort evidence chains by total_weight descending, discard single-hop trivial chains.
432    evidence_chains.retain(|c| c.depth >= 2);
433    evidence_chains.sort_by(|a, b| {
434        b.total_weight
435            .partial_cmp(&a.total_weight)
436            .unwrap_or(std::cmp::Ordering::Equal)
437    });
438
439    let unique_memories = results.len();
440    let evidence_count = evidence_chains.len();
441
442    // Phase 4: JSON output.
443    output::emit_json(&DeepResearchResponse {
444        query: args.query,
445        sub_queries,
446        results,
447        evidence_chains,
448        stats: ResearchStats {
449            sub_queries_total: sub_query_texts.len(),
450            sub_queries_completed: completed_count,
451            sub_queries_failed: failed_count,
452            sub_queries_timed_out: timed_out_count,
453            unique_memories_found: unique_memories,
454            evidence_chains_found: evidence_count,
455            elapsed_ms: start.elapsed().as_millis() as u64,
456        },
457    })?;
458
459    Ok(())
460}
461
462/// Heuristic query decomposition: splits by conjunctions, commas, semicolons,
463/// relational phrases, and extracts explicit entities (kebab-case or quoted).
464fn decompose_query(query: &str, max: usize) -> Vec<String> {
465    if query.is_empty() {
466        return vec![query.to_string()];
467    }
468
469    let mut parts: Vec<String> = Vec::new();
470
471    // Split by relational phrases first (most specific).
472    let relational = [
473        " that caused ",
474        " depending on ",
475        " related to ",
476        " connected to ",
477        " linked to ",
478        " caused by ",
479        " followed by ",
480    ];
481    let mut text = query.to_string();
482    let mut did_relational_split = false;
483    for phrase in &relational {
484        if text.to_lowercase().contains(phrase) {
485            let lower = text.to_lowercase();
486            if let Some(pos) = lower.find(phrase) {
487                let left = text[..pos].trim().to_string();
488                let right = text[pos + phrase.len()..].trim().to_string();
489                if !left.is_empty() {
490                    parts.push(left);
491                }
492                if !right.is_empty() {
493                    text = right;
494                }
495                did_relational_split = true;
496            }
497        }
498    }
499    if did_relational_split && !text.is_empty() {
500        parts.push(text.clone());
501    }
502
503    // If no relational split, try conjunctions and delimiters.
504    if parts.is_empty() {
505        // Split by semicolons first.
506        let semi_parts: Vec<&str> = query.split(';').collect();
507        if semi_parts.len() > 1 {
508            for p in &semi_parts {
509                let trimmed = p.trim();
510                if !trimmed.is_empty() {
511                    parts.push(trimmed.to_string());
512                }
513            }
514        } else {
515            // Split by commas and conjunctions.
516            // Replace " and " and " e " (Portuguese) with comma, then split.
517            let normalized = query
518                .replace(" and ", ", ")
519                .replace(" AND ", ", ")
520                .replace(" e ", ", ")
521                .replace(" E ", ", ");
522            let comma_parts: Vec<&str> = normalized.split(',').collect();
523            if comma_parts.len() > 1 {
524                for p in &comma_parts {
525                    let trimmed = p.trim();
526                    if !trimmed.is_empty() {
527                        parts.push(trimmed.to_string());
528                    }
529                }
530            }
531        }
532    }
533
534    // If still no split, return original as single sub-query.
535    if parts.is_empty() {
536        return vec![query.to_string()];
537    }
538
539    // Cap at max.
540    parts.truncate(max);
541    parts
542}
543
544/// Reconstruct a directed path from `target_entity_id` back to a seed using the
545/// predecessor map built by BFS.  Returns the path nodes from root to target
546/// plus the accumulated edge weights.
547fn reconstruct_path(
548    target_id: i64,
549    seed_entity_ids: &HashSet<i64>,
550    predecessor: &PredecessorMap,
551    entity_names: &HashMap<i64, String>,
552) -> Option<(Vec<EvidenceNode>, f64)> {
553    let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::new();
554    let mut total_weight = 1.0_f64;
555    let mut current = target_id;
556
557    loop {
558        if seed_entity_ids.contains(&current) {
559            break;
560        }
561        let (parent, relation, weight) = predecessor.get(&current)?;
562        total_weight *= weight;
563        path_ids.push((current, Some(relation.clone()), Some(*weight)));
564        current = *parent;
565    }
566    // Push the seed entity (root).
567    path_ids.push((current, None, None));
568
569    // Reverse so path goes from seed → target.
570    path_ids.reverse();
571
572    let nodes: Vec<EvidenceNode> = path_ids
573        .into_iter()
574        .map(|(id, relation, weight)| EvidenceNode {
575            entity: entity_names
576                .get(&id)
577                .cloned()
578                .unwrap_or_else(|| format!("entity-{id}")),
579            relation,
580            weight,
581        })
582        .collect();
583
584    Some((nodes, total_weight))
585}
586
587/// Execute a single sub-query: hybrid search (KNN + FTS fused via RRF) + graph traversal.
588///
589/// GAP-07 fix: receives the embedding for THIS sub-query (not the shared original).
590/// GAP-08/11 fix: uses rrf_fuse() for proper score fusion instead of hardcoded 0.5.
591/// GAP-09/10 fix: builds directed evidence chains filtered to discovered entities.
592/// GAP-17: respects max_neighbors_per_hop cap in BFS.
593///
594/// Runs synchronously on a blocking thread (called from a tokio spawn context).
595/// Each call opens its own read-only SQLite connection to leverage WAL concurrency.
596#[allow(clippy::too_many_arguments)]
597fn execute_sub_query(
598    sub_query_id: usize,
599    query_text: &str,
600    embedding: &[f32],
601    namespace: &str,
602    db_path: &std::path::Path,
603    k: usize,
604    max_hops: usize,
605    min_weight: f64,
606    rrf_k: f64,
607    graph_decay: f64,
608    graph_min_score: f64,
609    max_neighbors_per_hop: Option<usize>,
610) -> Result<SubQueryResult, String> {
611    let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
612
613    let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
614        Vec::with_capacity(k * 2);
615    let mut seen_ids: HashSet<i64> = HashSet::new();
616
617    // --- GAP-08/11 FIX: Use RRF fusion for KNN + FTS instead of hardcoded 0.5 ---
618
619    // 1. KNN vector search — collect ranked IDs.
620    let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
621        .map_err(|e| format!("knn_search failed: {e}"))?;
622    let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
623
624    // Build distance map for score computation.
625    let knn_distance_map: HashMap<i64, f64> = knn_results
626        .iter()
627        .map(|(id, dist)| (*id, *dist as f64))
628        .collect();
629
630    // 2. FTS5 search — collect ranked IDs.
631    let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
632        Ok(rows) => rows,
633        Err(e) => {
634            tracing::warn!(
635                sub_query_id,
636                "FTS5 search failed, continuing with KNN only: {e}"
637            );
638            vec![]
639        }
640    };
641    let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
642
643    // 3. Fuse via RRF.
644    let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
645    let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
646
647    // 4. Sort fused results and build hits.
648    let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
649    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
650    fused.truncate(k * 2);
651
652    for (memory_id, combined_score) in &fused {
653        if seen_ids.insert(*memory_id) {
654            let normalized = if max_possible > 0.0 {
655                combined_score / max_possible
656            } else {
657                0.0
658            };
659            let score = normalized.clamp(0.0, 1.0);
660            let source = if knn_distance_map.contains_key(memory_id) {
661                "knn"
662            } else {
663                "fts"
664            };
665            if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
666                let snippet: String = row.body.chars().take(300).collect();
667                hits.push((
668                    *memory_id,
669                    score,
670                    source.to_string(),
671                    snippet,
672                    row.body,
673                    None,
674                ));
675            }
676        }
677    }
678
679    // 5. Graph traversal from discovered memories.
680    // GAP-09/10 FIX: entity KNN also uses this sub-query's embedding.
681    let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
682    let mut chains: Vec<EvidenceChain> = Vec::new();
683
684    if !memory_ids.is_empty() && max_hops > 0 {
685        // Seed entities from KNN on entity vectors using THIS sub-query's embedding.
686        let entity_knn = entities::knn_search(&conn, embedding, namespace, 5).unwrap_or_default();
687        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
688
689        // Gather seed entity IDs from discovered memories too.
690        let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
691        for &mem_id in &memory_ids {
692            let mut stmt = conn
693                .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
694                .map_err(|e| format!("prepare failed: {e}"))?;
695            let ids: Vec<i64> = stmt
696                .query_map(rusqlite::params![mem_id], |r| r.get(0))
697                .map_err(|e| format!("query failed: {e}"))?
698                .filter_map(|r| r.ok())
699                .collect();
700            seed_entity_ids.extend(ids);
701        }
702        seed_entity_ids.sort_unstable();
703        seed_entity_ids.dedup();
704
705        let all_seed_ids: Vec<i64> = memory_ids
706            .iter()
707            .chain(entity_ids.iter())
708            .copied()
709            .collect();
710
711        // Graph traversal with hop scores.
712        if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
713            &conn,
714            &all_seed_ids,
715            namespace,
716            min_weight,
717            max_hops as u32,
718            max_neighbors_per_hop,
719        ) {
720            // Build seed score map from RRF-fused scores for graph decay computation.
721            let seed_score_map: HashMap<i64, f64> = fused
722                .iter()
723                .map(|(id, s)| {
724                    let normalized = if max_possible > 0.0 {
725                        s / max_possible
726                    } else {
727                        0.0
728                    };
729                    (*id, normalized.clamp(0.0, 1.0))
730                })
731                .collect();
732
733            for (graph_mem_id, hop) in graph_results {
734                if seen_ids.insert(graph_mem_id) {
735                    // GAP-08/11 FIX: graph score = seed_score * decay^hop * edge_weight.
736                    // For the seed score, use the best score among the seed memories that
737                    // transitively reached this graph memory (approximate with the average
738                    // seed score since we don't track the exact path yet).
739                    let avg_seed_score: f64 = if seed_score_map.is_empty() {
740                        0.5
741                    } else {
742                        let sum: f64 = seed_score_map.values().sum();
743                        sum / seed_score_map.len() as f64
744                    };
745                    let graph_score =
746                        (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
747
748                    if graph_score < graph_min_score {
749                        continue;
750                    }
751
752                    if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
753                        let snippet: String = row.body.chars().take(300).collect();
754                        hits.push((
755                            graph_mem_id,
756                            graph_score,
757                            "graph".to_string(),
758                            snippet,
759                            row.body,
760                            Some(hop as usize),
761                        ));
762                    }
763                }
764            }
765        }
766
767        // GAP-09/10 FIX: Build directed evidence chains using BFS with predecessor map,
768        // filtered to entities discovered in this sub-query.
769        if !seed_entity_ids.is_empty() {
770            let (entity_depth, predecessor) = bfs_with_predecessors(
771                &conn,
772                &seed_entity_ids,
773                namespace,
774                min_weight,
775                max_hops as u32,
776                max_neighbors_per_hop,
777            )
778            .unwrap_or_default();
779
780            let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
781
782            // Collect entity IDs we need names for.
783            let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
784            let mut entity_names: HashMap<i64, String> = HashMap::new();
785            for &eid in &all_entity_ids {
786                let name_res: rusqlite::Result<String> = conn.query_row(
787                    "SELECT name FROM entities WHERE id = ?1",
788                    rusqlite::params![eid],
789                    |r| r.get(0),
790                );
791                if let Ok(name) = name_res {
792                    entity_names.insert(eid, name);
793                }
794            }
795
796            // Reconstruct a path for each non-seed entity that has a predecessor.
797            for (&target_id, &_hop) in &entity_depth {
798                if seed_entity_set.contains(&target_id) {
799                    continue;
800                }
801                if !predecessor.contains_key(&target_id) {
802                    continue;
803                }
804                if let Some((path_nodes, total_weight)) =
805                    reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
806                {
807                    if path_nodes.len() < 2 {
808                        continue;
809                    }
810                    let from = path_nodes
811                        .first()
812                        .map(|n| n.entity.clone())
813                        .unwrap_or_default();
814                    let to = path_nodes
815                        .last()
816                        .map(|n| n.entity.clone())
817                        .unwrap_or_default();
818                    let depth = path_nodes.len();
819                    chains.push(EvidenceChain {
820                        from,
821                        to,
822                        path: path_nodes,
823                        total_weight,
824                        depth,
825                        sub_query_ids: vec![sub_query_id],
826                    });
827                }
828            }
829
830            // Sort chains by total_weight descending and cap to avoid huge output.
831            chains.sort_by(|a, b| {
832                b.total_weight
833                    .partial_cmp(&a.total_weight)
834                    .unwrap_or(std::cmp::Ordering::Equal)
835            });
836            chains.truncate(20);
837        }
838    }
839
840    Ok(SubQueryResult {
841        sub_query_id,
842        hits,
843        chains,
844    })
845}
846
847// ────────────────────────────────────────────────────────────────────────────
848// Re-export sub_query_results field initialisation for the stats counter.
849// The field is moved out of run_async after the join loop; we need to shadow it.
850// ────────────────────────────────────────────────────────────────────────────
851
852#[cfg(test)]
853mod tests {
854    use super::*;
855
856    #[test]
857    fn test_decompose_and_conjunction() {
858        let result = decompose_query("A and B", 7);
859        assert_eq!(result, vec!["A", "B"]);
860    }
861
862    #[test]
863    fn test_decompose_no_split() {
864        let result = decompose_query("simple query", 7);
865        assert_eq!(result, vec!["simple query"]);
866    }
867
868    #[test]
869    fn test_decompose_three_parts() {
870        let result = decompose_query("A, B and C", 7);
871        assert_eq!(result, vec!["A", "B", "C"]);
872    }
873
874    #[test]
875    fn test_decompose_portuguese_conjunctions() {
876        let result = decompose_query("A e B", 7);
877        assert_eq!(result, vec!["A", "B"]);
878    }
879
880    #[test]
881    fn test_decompose_max_cap() {
882        let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
883        let query = parts.join(", ");
884        let result = decompose_query(&query, 7);
885        assert!(
886            result.len() <= 7,
887            "expected at most 7 sub-queries, got {}",
888            result.len()
889        );
890    }
891
892    #[test]
893    fn test_decompose_empty_preserves_original() {
894        let result = decompose_query("", 7);
895        assert_eq!(result, vec![""]);
896    }
897
898    #[test]
899    fn test_decompose_semicolons() {
900        let result = decompose_query("auth design; deployment config; logging", 7);
901        assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
902    }
903
904    #[test]
905    fn test_decompose_relational_phrase() {
906        let result = decompose_query("auth that caused deployment failure", 7);
907        assert_eq!(result, vec!["auth", "deployment failure"]);
908    }
909
910    #[test]
911    fn test_sub_query_serialization() {
912        let sq = SubQuery {
913            id: 0,
914            text: "test query".to_string(),
915            source: "original",
916        };
917        let json = serde_json::to_value(&sq).expect("serialization failed");
918        assert_eq!(json["id"], 0);
919        assert_eq!(json["text"], "test query");
920        assert_eq!(json["source"], "original");
921    }
922
923    #[test]
924    fn test_deep_result_omits_body_when_none() {
925        let result = DeepResult {
926            name: "test".to_string(),
927            score: 0.9,
928            source: "knn".to_string(),
929            sub_query_ids: vec![0],
930            snippet: "snippet".to_string(),
931            body: None,
932            hop_distance: None,
933        };
934        let json = serde_json::to_string(&result).expect("serialization failed");
935        assert!(!json.contains("\"body\""), "body must be omitted when None");
936    }
937
938    #[test]
939    fn test_deep_result_includes_body_when_some() {
940        let result = DeepResult {
941            name: "test".to_string(),
942            score: 0.9,
943            source: "knn".to_string(),
944            sub_query_ids: vec![0, 1],
945            snippet: "snippet".to_string(),
946            body: Some("full body content".to_string()),
947            hop_distance: Some(2),
948        };
949        let json = serde_json::to_string(&result).expect("serialization failed");
950        assert!(json.contains("\"body\""), "body must be present when Some");
951        assert!(json.contains("full body content"));
952    }
953
954    #[test]
955    fn test_evidence_node_omits_none_fields() {
956        let node = EvidenceNode {
957            entity: "auth-module".to_string(),
958            relation: None,
959            weight: None,
960        };
961        let json = serde_json::to_string(&node).expect("serialization failed");
962        assert!(
963            !json.contains("\"relation\""),
964            "relation must be omitted when None"
965        );
966        assert!(
967            !json.contains("\"weight\""),
968            "weight must be omitted when None"
969        );
970    }
971
972    #[test]
973    fn test_research_stats_serialization() {
974        let stats = ResearchStats {
975            sub_queries_total: 3,
976            sub_queries_completed: 2,
977            sub_queries_failed: 1,
978            sub_queries_timed_out: 0,
979            unique_memories_found: 10,
980            evidence_chains_found: 2,
981            elapsed_ms: 1234,
982        };
983        let json = serde_json::to_value(&stats).expect("serialization failed");
984        assert_eq!(json["sub_queries_total"], 3);
985        assert_eq!(json["sub_queries_completed"], 2);
986        assert_eq!(json["sub_queries_failed"], 1);
987        assert_eq!(json["elapsed_ms"], 1234);
988    }
989
990    #[test]
991    fn test_deep_research_response_serialization() {
992        let resp = DeepResearchResponse {
993            query: "test query".to_string(),
994            sub_queries: vec![SubQuery {
995                id: 0,
996                text: "test query".to_string(),
997                source: "original",
998            }],
999            results: vec![],
1000            evidence_chains: vec![],
1001            stats: ResearchStats {
1002                sub_queries_total: 1,
1003                sub_queries_completed: 1,
1004                sub_queries_failed: 0,
1005                sub_queries_timed_out: 0,
1006                unique_memories_found: 0,
1007                evidence_chains_found: 0,
1008                elapsed_ms: 42,
1009            },
1010        };
1011        let json = serde_json::to_value(&resp).expect("serialization failed");
1012        assert_eq!(json["query"], "test query");
1013        assert!(json["sub_queries"].is_array());
1014        assert!(json["results"].is_array());
1015        assert!(json["evidence_chains"].is_array());
1016        assert_eq!(json["stats"]["elapsed_ms"], 42);
1017    }
1018
1019    // ---- GAP-07 regression: different sub-queries produce distinct embeddings ----
1020    // We test decompose_query returns texts that *would* produce distinct embeddings
1021    // (different text inputs → different embedding inputs → different search results).
1022    #[test]
1023    fn test_distinct_sub_queries_produce_distinct_texts() {
1024        let queries = [
1025            "authentication design decisions",
1026            "deployment configuration and infrastructure",
1027        ];
1028        // These two texts must be different strings (prerequisite for distinct embeddings).
1029        assert_ne!(queries[0], queries[1]);
1030
1031        // decompose_query with semicolons must preserve distinct texts.
1032        let decomposed = decompose_query(
1033            "authentication design decisions; deployment configuration and infrastructure",
1034            7,
1035        );
1036        assert_eq!(decomposed.len(), 2);
1037        assert_ne!(decomposed[0], decomposed[1]);
1038    }
1039
1040    // ---- GAP-08/11 regression: rrf_fuse integration via fusion module ----
1041    #[test]
1042    fn test_rrf_fuse_via_fusion_module() {
1043        use crate::storage::fusion::rrf_fuse;
1044
1045        let knn_ids: Vec<i64> = vec![1, 2, 3];
1046        let fts_ids: Vec<i64> = vec![2, 1, 4];
1047        let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);
1048
1049        // Items appearing in both lists must score higher than items in only one list.
1050        let score_1 = scores[&1];
1051        let score_2 = scores[&2];
1052        let score_3 = scores[&3]; // knn only, rank 3
1053        let score_4 = scores[&4]; // fts only, rank 3
1054
1055        assert!(
1056            score_1 > score_3,
1057            "id 1 (both lists) must beat id 3 (knn-only rank 3)"
1058        );
1059        assert!(
1060            score_2 > score_4,
1061            "id 2 (both lists) must beat id 4 (fts-only rank 3)"
1062        );
1063    }
1064
1065    // ---- GAP-09/10 regression: evidence chains must be directed paths ----
1066    #[test]
1067    fn test_evidence_chain_has_from_to_and_path() {
1068        let chain = EvidenceChain {
1069            from: "auth-module".to_string(),
1070            to: "jwt-service".to_string(),
1071            path: vec![
1072                EvidenceNode {
1073                    entity: "auth-module".to_string(),
1074                    relation: None,
1075                    weight: None,
1076                },
1077                EvidenceNode {
1078                    entity: "token-validator".to_string(),
1079                    relation: Some("depends-on".to_string()),
1080                    weight: Some(0.9),
1081                },
1082                EvidenceNode {
1083                    entity: "jwt-service".to_string(),
1084                    relation: Some("uses".to_string()),
1085                    weight: Some(0.8),
1086                },
1087            ],
1088            total_weight: 0.72,
1089            depth: 3,
1090            sub_query_ids: vec![0],
1091        };
1092
1093        let json = serde_json::to_value(&chain).expect("serialization failed");
1094        assert!(
1095            json["from"].is_string(),
1096            "evidence chain must have 'from' field"
1097        );
1098        assert!(
1099            json["to"].is_string(),
1100            "evidence chain must have 'to' field"
1101        );
1102        assert!(
1103            json["path"].is_array(),
1104            "evidence chain must have 'path' array"
1105        );
1106        assert_eq!(json["path"].as_array().unwrap().len(), 3);
1107        assert!(json["total_weight"].is_number(), "must have total_weight");
1108        assert_eq!(json["depth"], 3);
1109    }
1110
1111    // ---- GAP-10 regression: reconstruct_path returns correct node order ----
1112    #[test]
1113    fn test_reconstruct_path_root_to_target_order() {
1114        // Build a simple chain: entity 10 (seed) -> entity 20 -> entity 30 (target)
1115        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
1116        let mut predecessor: HashMap<i64, (i64, String, f64)> = HashMap::new();
1117        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
1118        predecessor.insert(30, (20, "uses".to_string(), 0.8));
1119        let mut entity_names: HashMap<i64, String> = HashMap::new();
1120        entity_names.insert(10, "seed-entity".to_string());
1121        entity_names.insert(20, "middle-entity".to_string());
1122        entity_names.insert(30, "target-entity".to_string());
1123
1124        let result = reconstruct_path(30, &seed_set, &predecessor, &entity_names);
1125        assert!(result.is_some(), "path must be reconstructed");
1126        let (nodes, weight) = result.unwrap();
1127        // Path must be [seed, middle, target]
1128        assert_eq!(nodes.len(), 3);
1129        assert_eq!(nodes[0].entity, "seed-entity");
1130        assert_eq!(nodes[1].entity, "middle-entity");
1131        assert_eq!(nodes[2].entity, "target-entity");
1132        // total_weight = 0.9 * 0.8
1133        assert!((weight - 0.72).abs() < 1e-6);
1134    }
1135
1136    // ---- GAP-09 regression: evidence chains must NOT be present for 1-hop trivial pairs ----
1137    #[test]
1138    fn test_evidence_chains_single_hop_filtered_out() {
1139        // A chain of depth 1 (only root node) should be discarded.
1140        let chain = EvidenceChain {
1141            from: "a".to_string(),
1142            to: "a".to_string(),
1143            path: vec![EvidenceNode {
1144                entity: "a".to_string(),
1145                relation: None,
1146                weight: None,
1147            }],
1148            total_weight: 1.0,
1149            depth: 1,
1150            sub_query_ids: vec![0],
1151        };
1152        // Simulate the filter: retain chains with depth >= 2.
1153        let chains = vec![chain];
1154        let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
1155        assert!(retained.is_empty(), "depth-1 chains must be filtered out");
1156    }
1157
1158    // ---- GAP-17 regression: bfs_with_predecessors honours max_neighbors_per_hop ----
1159    #[test]
1160    fn test_bfs_with_predecessors_respects_neighbor_cap() {
1161        use crate::graph::bfs_with_predecessors;
1162        use rusqlite::Connection;
1163
1164        let conn = Connection::open_in_memory().unwrap();
1165        conn.execute_batch(
1166            "CREATE TABLE relationships (
1167                source_id INTEGER NOT NULL,
1168                target_id INTEGER NOT NULL,
1169                weight REAL NOT NULL,
1170                namespace TEXT NOT NULL,
1171                relation TEXT NOT NULL DEFAULT 'related'
1172             );",
1173        )
1174        .unwrap();
1175
1176        // Seed entity 1 has 5 neighbours.
1177        for target in 2i64..=6 {
1178            conn.execute(
1179                "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
1180                rusqlite::params![1i64, target, 1.0f64],
1181            )
1182            .unwrap();
1183        }
1184
1185        // Without cap: all 5 neighbours reached.
1186        let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
1187        assert_eq!(
1188            depth_uncapped.len() - 1,
1189            5,
1190            "uncapped must discover all 5 neighbours (plus seed)"
1191        );
1192
1193        // With cap=2: only top-2 neighbours (by weight; all equal here so first 2 returned).
1194        let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
1195        // seed + 2 neighbours = 3 entries.
1196        assert_eq!(
1197            depth_capped.len(),
1198            3,
1199            "capped to 2 must yield seed + 2 neighbours"
1200        );
1201    }
1202}