Skip to main content

sqlite_graphrag/commands/
deep_research.rs

1//! Handler for the `deep-research` CLI subcommand.
2//!
3//! Orchestrates parallel multi-hop GraphRAG search via query decomposition.
4//! The workload is I/O-bound (SQLite WAL reads), so tokio is used instead of
5//! rayon. Each sub-query opens its own read-only connection.
6
7use crate::errors::AppError;
8use crate::graph::{
9    bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::HashSet;
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
/// Arguments for the `deep-research` subcommand.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n  \
        # Basic deep research\n  \
        sqlite-graphrag deep-research \"auth architecture decisions\"\n\n  \
        # With custom parameters\n  \
        sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n  \
        # Include full memory bodies in output\n  \
        sqlite-graphrag deep-research \"auth\" --with-bodies\n\n  \
        # Tune RRF and graph scoring\n  \
        sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
    // Rejected by run_async when empty or whitespace-only.
    /// Research query to decompose and search.
    #[arg(
        value_name = "QUERY",
        allow_hyphen_values = true,
        help = "Research query to decompose and search"
    )]
    pub query: String,
    // Exposed as `--k` / `-k`; `--limit` and `--top-k` are accepted as hidden
    // aliases. Used as the KNN and FTS5 fetch size of every sub-query.
    /// Results per sub-query (Recall@20 captures 95%+ relevant hits).
    #[arg(
        long,
        short,
        aliases = ["limit", "top-k"],
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    // Passed to decompose_query as its cap on the number of sub-queries.
    /// Maximum sub-queries from decomposition (covers complex multi-hop queries).
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    // Forwarded unchanged to execute_sub_query for every sub-query.
    /// Multi-hop graph traversal depth (sweet spot: 2-3 hops).
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    // Forwarded unchanged to execute_sub_query for every sub-query.
    /// Minimum edge weight for graph traversal.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    // `None` resolves to min(available CPUs, 8) in run_async (4 CPUs assumed
    // when unknown). The value is then capped at the number of sub-queries
    // and raised to at least 1.
    /// Maximum concurrent sub-queries (default: min(cpus, 8)).
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    // Applied to each sub-query individually, starting once the sub-query has
    // acquired its concurrency permit.
    /// Timeout per sub-query in seconds.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    // When false, the `body` field is omitted from every result in the JSON.
    /// Include full memory bodies in results.
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    // Applied to the merged, score-ranked results of all sub-queries. Memories
    // that can no longer be read are skipped after this cap, so fewer results
    // may be returned.
    /// Maximum results after deduplication.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    // Forwarded to rrf_fuse / rrf_max_possible inside execute_sub_query.
    /// RRF k parameter controlling score smoothing (higher = less weight on top ranks).
    #[arg(
        long,
        default_value_t = 60.0,
        help = "RRF k parameter (higher = less weight on top ranks)"
    )]
    pub rrf_k: f64,
    // The 0.0-1.0 range mentioned in the help text is not enforced by the parser.
    /// Decay factor applied to graph scores per hop (score = seed_score * decay^hop).
    #[arg(
        long,
        default_value_t = 0.7,
        help = "Graph score decay factor per hop (0.0-1.0)"
    )]
    pub graph_decay: f64,
    // Forwarded unchanged to execute_sub_query for every sub-query.
    /// Minimum score threshold for graph-expanded results (filters noise).
    #[arg(
        long,
        default_value_t = 0.05,
        help = "Minimum score threshold for graph-expanded results"
    )]
    pub graph_min_score: f64,
    // Forwarded unchanged to execute_sub_query for every sub-query.
    /// Limit top-k neighbours followed per entity per hop (None = unlimited).
    #[arg(
        long,
        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
    )]
    pub max_neighbors_per_hop: Option<usize>,
    // Resolved by crate::namespace::resolve_namespace. The env/default
    // fallback described in the help text is implemented there, not here.
    /// Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global).
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Hidden, and the value parser only accepts "none", so this is currently
    // always "none".
    /// Research mode: `none` (local heuristic, default), `claude-code`, `codex` (v1.1.0).
    #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
    pub mode: String,
    // Since `mode` can only be "none", setting this currently just makes
    // run_async log a warning.
    /// Maximum LLM cost in USD (effective with --mode claude-code/codex, reserved for v1.1.0).
    #[arg(
        long,
        value_name = "USD",
        help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
    )]
    pub max_cost_usd: Option<f64>,
    // Not read by run_async; the command always prints JSON.
    /// JSON output (always on, kept for consistency).
    #[arg(long, hide = true)]
    pub json: bool,
    // Falls back to the SQLITE_GRAPHRAG_DB_PATH env var; resolved by AppPaths::resolve.
    /// Database path.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
145
/// One sub-query produced by query decomposition, as reported in the JSON output.
#[derive(Serialize)]
struct SubQuery {
    /// Position of the sub-query in decomposition order. Results and evidence
    /// chains refer to it through their `sub_query_ids`.
    id: usize,
    /// Sub-query text. It is embedded when possible and always searched via FTS5.
    text: String,
    /// `"original"` when decomposition produced a single sub-query,
    /// `"decomposed"` otherwise.
    source: &'static str,
}
152
/// One merged search result (a memory) in the JSON response.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name.
    name: String,
    /// Best score this memory reached in any sub-query.
    score: f64,
    /// Source label of the best-scoring hit (e.g. `hybrid`, `knn`, `fts`).
    source: String,
    /// Every sub-query that returned this memory.
    sub_query_ids: Vec<usize>,
    /// Leading excerpt of the memory body (the first 300 characters for direct
    /// search hits).
    snippet: String,
    /// Full memory body; only present when `--with-bodies` is given.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// Graph hop distance. `None` for direct search hits; serialized as `null`
    /// when absent.
    hop_distance: Option<usize>,
}
164
/// A node in a reconstructed evidence path.
#[derive(Serialize, Clone)]
struct EvidenceNode {
    /// Entity name. `reconstruct_path` substitutes `entity-<id>` when the name
    /// is unknown.
    entity: String,
    /// Relation of the edge that reached this node from its predecessor.
    /// `None` (and omitted from JSON) for the path's root/seed node.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Weight of that edge. `None` (and omitted from JSON) for the root node.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
174
/// A directed evidence chain reconstructed from BFS predecessors.
///
/// Fields:
/// - `from`: name of the seed (source) entity.
/// - `to`: name of the terminal (target) entity.
/// - `path`: ordered nodes from `from` to `to`. When built from
///   `reconstruct_path`, this includes both endpoints (the seed first, without
///   relation/weight), not only the intermediate nodes.
///   NOTE(review): confirm against the builder in `execute_sub_query`.
/// - `total_weight`: product of edge weights along the path.
/// - `depth`: hop count of the chain; `run_async` drops chains with
///   `depth < 2`. NOTE(review): presumably the number of edges in `path`;
///   confirm.
/// - `sub_query_ids`: which sub-queries produced this chain.
#[derive(Serialize)]
struct EvidenceChain {
    from: String,
    to: String,
    path: Vec<EvidenceNode>,
    total_weight: f64,
    depth: usize,
    sub_query_ids: Vec<usize>,
}
192
/// Run-level statistics included in the JSON response.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries whose task returned successfully.
    sub_queries_completed: usize,
    /// Sub-queries that returned an error, panicked or were cancelled.
    /// Timeouts are counted separately.
    sub_queries_failed: usize,
    /// Sub-queries that exceeded the per-sub-query `--timeout`.
    sub_queries_timed_out: usize,
    /// Number of results emitted: counted after the `--max-results` cap and
    /// after skipping memories that could not be read back.
    unique_memories_found: usize,
    /// Number of evidence chains emitted.
    evidence_chains_found: usize,
    /// Milliseconds elapsed since `run_async` started, taken just before the
    /// JSON is emitted.
    elapsed_ms: u64,
    /// True when embedding failed for at least one sub-query, which then ran
    /// FTS5-only.
    vec_degraded: bool,
}
204
/// An entity in the graph context attached to the response.
#[derive(Serialize)]
struct GraphContextEntity {
    /// Name used for the entity lookup; `run_async` looks entities up by the
    /// result memory's name.
    name: String,
    /// The entity's `type` column, or `"concept"` when it is NULL or cannot be read.
    entity_type: String,
    /// Number of relationships in which the entity is source or target
    /// (0 if the count query fails).
    degree: u32,
}
211
/// A relationship whose source and target are both graph-context entities.
#[derive(Serialize)]
struct GraphContextRel {
    /// Source entity name.
    from: String,
    /// Target entity name.
    to: String,
    /// Relationship label (`relationships.relation`).
    relation: String,
    /// Relationship weight (`relationships.weight`).
    weight: f64,
}
219
/// Entities and relationships surrounding the result memories.
///
/// `run_async` omits it when no result name maps to an entity.
#[derive(Serialize)]
struct GraphContext {
    /// One entry per distinct entity matched by a result name.
    entities: Vec<GraphContextEntity>,
    /// Relationships among those entities (at most 50).
    relationships: Vec<GraphContextRel>,
}
225
/// Top-level JSON document emitted by `deep-research`.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The original research query.
    query: String,
    /// Sub-queries produced by decomposition, in id order.
    sub_queries: Vec<SubQuery>,
    /// Merged results, highest score first, capped at `--max-results`.
    results: Vec<DeepResult>,
    /// Evidence chains, highest `total_weight` first.
    evidence_chains: Vec<EvidenceChain>,
    /// Entity/relationship context for the results; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    graph_context: Option<GraphContext>,
    /// Run statistics.
    stats: ResearchStats,
}
236
/// Aggregated hit data for one memory across all sub-queries:
/// `(score, source_label, snippet, body, hop_distance, sub_query_ids)`.
///
/// `run_async` keeps the score, source, snippet, body and hop of the
/// best-scoring hit, and collects every sub-query id that returned the memory.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
239
240/// Intermediate result from a single sub-query execution.
241struct SubQueryResult {
242    sub_query_id: usize,
243    /// (memory_id, score, source_label, snippet, body, hop_distance)
244    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
245    /// Evidence chains reconstructed from BFS.
246    chains: Vec<EvidenceChain>,
247}
248
249/// Sync entry point — builds a tokio runtime for the async fan-out.
250#[tracing::instrument(skip_all, level = "debug", name = "deep_research")]
251pub fn run(
252    args: DeepResearchArgs,
253    llm_backend: crate::cli::LlmBackendChoice,
254    embedding_backend: crate::cli::EmbeddingBackendChoice,
255) -> Result<(), AppError> {
256    tracing::debug!(target: "deep_research", query = %args.query, k = args.k, "starting deep research");
257    let rt = tokio::runtime::Builder::new_multi_thread()
258        .worker_threads(2)
259        .enable_all()
260        .build()
261        .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
262    rt.block_on(run_async(args, llm_backend, embedding_backend))
263}
264
265/// Main async logic: decompose, fan-out, assemble, emit JSON.
266async fn run_async(
267    args: DeepResearchArgs,
268    llm_backend: crate::cli::LlmBackendChoice,
269    embedding_backend: crate::cli::EmbeddingBackendChoice,
270) -> Result<(), AppError> {
271    let start = std::time::Instant::now();
272
273    if args.query.trim().is_empty() {
274        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
275    }
276
277    if args.max_cost_usd.is_some() && args.mode == "none" {
278        tracing::warn!(target: "deep_research", "--max-cost-usd has no effect without --mode claude-code/codex");
279    }
280
281    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
282    let paths = AppPaths::resolve(args.db.as_deref())?;
283    crate::storage::connection::ensure_db_ready(&paths)?;
284
285    // Phase 1: Query decomposition (sync, pure logic).
286    let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
287    let sub_queries: Vec<SubQuery> = sub_query_texts
288        .iter()
289        .enumerate()
290        .map(|(i, text)| SubQuery {
291            id: i,
292            text: text.clone(),
293            source: if sub_query_texts.len() == 1 {
294                "original"
295            } else {
296                "decomposed"
297            },
298        })
299        .collect();
300
301    // GAP-DEEPRESEARCH-001 FIX (v1.0.89): use graceful degradation path
302    // instead of hard-fail. When LLM is unavailable (OAuth expired, timeout,
303    // slots exhausted), fall back to FTS5-only search per sub-query — same
304    // contract as `recall` and `hybrid-search`.
305    output::emit_progress_i18n(
306        "Computing per-sub-query embeddings...",
307        "Calculando embeddings por sub-consulta...",
308    );
309    let mut sub_embeddings: Vec<Option<Arc<Vec<f32>>>> = Vec::with_capacity(sub_query_texts.len());
310    let mut vec_degraded = false;
311    for sq_text in &sub_query_texts {
312        match crate::embedder::try_embed_query_with_embedding_choice(
313            &paths.models,
314            sq_text,
315            embedding_backend,
316            llm_backend,
317        ) {
318            Ok((v, _backend)) => sub_embeddings.push(Some(Arc::new(v))),
319            Err(reason) => {
320                tracing::warn!(target: "deep_research", fallback_reason = %reason, reason_code = %reason.reason_code(), "embedding failed for sub-query; falling back to FTS5");
321                sub_embeddings.push(None);
322                vec_degraded = true;
323            }
324        }
325    }
326
327    // Phase 2: Fan-out — parallel sub-query execution.
328    let cpu_count = std::thread::available_parallelism()
329        .map(|n| n.get())
330        .unwrap_or(4);
331    let permits = args
332        .max_concurrency
333        .unwrap_or_else(|| cpu_count.min(8))
334        .min(sub_queries.len())
335        .max(1);
336    let semaphore = Arc::new(Semaphore::new(permits));
337    let timeout_dur = std::time::Duration::from_secs(args.timeout);
338
339    let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
340
341    for (idx, sq_text) in sub_query_texts.iter().enumerate() {
342        let sem = Arc::clone(&semaphore);
343        // GAP-DEEPRESEARCH-001 FIX: pass Optional embedding (None = FTS5-only).
344        let emb = sub_embeddings[idx].clone();
345        let ns = namespace.clone();
346        let db_path = paths.db.clone();
347        let query_text = sq_text.clone();
348        let k = args.k;
349        let max_hops = args.max_hops;
350        let min_weight = args.min_weight;
351        let rrf_k = args.rrf_k;
352        let graph_decay = args.graph_decay;
353        let graph_min_score = args.graph_min_score;
354        let max_neighbors_per_hop = args.max_neighbors_per_hop;
355
356        join_set.spawn(async move {
357            let _permit = sem
358                .acquire_owned()
359                .await
360                .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
361
362            // Dereference the Arc to obtain a &[f32] slice for the sync function.
363            let result = tokio::time::timeout(timeout_dur, async move {
364                execute_sub_query(
365                    idx,
366                    &query_text,
367                    emb.as_ref().map(|v| v.as_slice()),
368                    &ns,
369                    &db_path,
370                    k,
371                    max_hops,
372                    min_weight,
373                    rrf_k,
374                    graph_decay,
375                    graph_min_score,
376                    max_neighbors_per_hop,
377                )
378            })
379            .await;
380
381            match result {
382                Ok(inner) => inner.map_err(|e| (idx, e)),
383                Err(_) => Err((idx, "timeout".to_string())),
384            }
385        });
386    }
387
388    // Collect results incrementally.
389    let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
390    let mut failed_count = 0usize;
391    let mut timed_out_count = 0usize;
392
393    while let Some(join_result) = join_set.join_next().await {
394        match join_result {
395            Ok(Ok(sqr)) => sub_query_results.push(sqr),
396            Ok(Err((_idx, reason))) => {
397                if reason == "timeout" {
398                    timed_out_count += 1;
399                } else {
400                    failed_count += 1;
401                }
402                tracing::warn!(target: "deep_research", sub_query_id = _idx, reason = %reason, "sub-query failed");
403            }
404            Err(join_err) => {
405                failed_count += 1;
406                if join_err.is_panic() {
407                    tracing::error!(target: "deep_research", error = %join_err, "sub-query task panicked");
408                } else {
409                    tracing::warn!(target: "deep_research", error = %join_err, "sub-query task cancelled");
410                }
411            }
412        }
413    }
414
415    // Phase 3: Evidence assembly — merge, dedup, rank.
416    // Aggregate hits: memory_id -> (best_score, source, snippet, body, hop_distance, sub_query_ids)
417    let mut merged: crate::hash::AHashMap<i64, MergedHit> =
418        crate::hash::AHashMap::with_capacity_and_hasher(
419            sub_query_results.len() * args.k,
420            Default::default(),
421        );
422
423    for sqr in &sub_query_results {
424        for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
425            let entry = merged.entry(*mem_id).or_insert_with(|| {
426                (
427                    *score,
428                    source.clone(),
429                    snippet.clone(),
430                    body.clone(),
431                    *hop,
432                    Vec::new(),
433                )
434            });
435            // Keep best score.
436            if *score > entry.0 {
437                entry.0 = *score;
438                entry.1 = source.clone();
439                entry.2 = snippet.clone();
440                entry.3 = body.clone();
441                entry.4 = *hop;
442            }
443            if !entry.5.contains(&sqr.sub_query_id) {
444                entry.5.push(sqr.sub_query_id);
445            }
446        }
447    }
448
449    // Resolve memory names for merged results.
450    let conn = open_ro(&paths.db)?;
451    let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
452
453    // Sort by score descending.
454    let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
455    ranked.sort_by(|a, b| {
456        b.1 .0
457            .partial_cmp(&a.1 .0)
458            .unwrap_or(std::cmp::Ordering::Equal)
459    });
460    ranked.truncate(args.max_results);
461
462    for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
463        let name = match memories::read_full(&conn, mem_id)? {
464            Some(row) => row.name,
465            None => continue,
466        };
467        results.push(DeepResult {
468            name,
469            score,
470            source,
471            sub_query_ids: sq_ids,
472            snippet,
473            body: if args.with_bodies { Some(body) } else { None },
474            hop_distance: hop,
475        });
476    }
477
478    // GAP-09/10 FIX: Collect evidence chains from reconstructed BFS paths.
479    // The old code appended flat node pairs from a global SELECT; now each
480    // sub-query returns directed EvidenceChain structs (from, to, path).
481    let completed_count = sub_query_results.len();
482    let mut evidence_chains: Vec<EvidenceChain> = Vec::with_capacity(completed_count * 2);
483    let mut seen_chain_keys: HashSet<String> = HashSet::with_capacity(completed_count * 2);
484
485    for sqr in sub_query_results {
486        for chain in sqr.chains {
487            // Deduplicate chains by (from, to) pair.
488            let key = format!("{}->{}", chain.from, chain.to);
489            if seen_chain_keys.insert(key) {
490                evidence_chains.push(chain);
491            }
492        }
493    }
494
495    // Sort evidence chains by total_weight descending, discard single-hop trivial chains.
496    evidence_chains.retain(|c| c.depth >= 2);
497    evidence_chains.sort_by(|a, b| {
498        b.total_weight
499            .partial_cmp(&a.total_weight)
500            .unwrap_or(std::cmp::Ordering::Equal)
501    });
502
503    let unique_memories = results.len();
504    let evidence_count = evidence_chains.len();
505
506    // MEDIUM-01b: Build graph_context with entities and relationships from result memories.
507    let graph_context = if !results.is_empty() {
508        let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
509        let mut ctx_entities: Vec<GraphContextEntity> = Vec::with_capacity(results.len());
510        let mut ctx_rels: Vec<GraphContextRel> = Vec::with_capacity(results.len() * 2);
511        let mut seen_entity_ids: crate::hash::AHashSet<i64> =
512            crate::hash::AHashSet::with_capacity_and_hasher(results.len(), Default::default());
513
514        for name in &result_names {
515            if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
516                if seen_entity_ids.insert(eid) {
517                    let etype: String = conn
518                        .query_row(
519                            "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
520                            rusqlite::params![eid],
521                            |r| r.get(0),
522                        )
523                        .unwrap_or_else(|_| "concept".to_string());
524                    let degree: u32 = conn
525                        .query_row(
526                            "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
527                            rusqlite::params![eid],
528                            |r| r.get(0),
529                        )
530                        .unwrap_or(0);
531                    ctx_entities.push(GraphContextEntity {
532                        name: name.to_string(),
533                        entity_type: etype,
534                        degree,
535                    });
536                }
537            }
538        }
539
540        let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
541        if entity_ids.len() >= 2 {
542            let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
543            let sql = format!(
544                "SELECT s.name, t.name, r.relation, r.weight \
545                 FROM relationships r \
546                 JOIN entities s ON s.id = r.source_id \
547                 JOIN entities t ON t.id = r.target_id \
548                 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
549                 LIMIT 50"
550            );
551            if let Ok(mut stmt) = conn.prepare(&sql) {
552                let mut params: Vec<Box<dyn rusqlite::types::ToSql>> =
553                    Vec::with_capacity(entity_ids.len() * 2);
554                for id in &entity_ids {
555                    params.push(Box::new(*id));
556                }
557                for id in &entity_ids {
558                    params.push(Box::new(*id));
559                }
560                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
561                    params.iter().map(|p| p.as_ref()).collect();
562                if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
563                    Ok((
564                        r.get::<_, String>(0)?,
565                        r.get::<_, String>(1)?,
566                        r.get::<_, String>(2)?,
567                        r.get::<_, f64>(3)?,
568                    ))
569                }) {
570                    for row in rows.flatten() {
571                        ctx_rels.push(GraphContextRel {
572                            from: row.0,
573                            to: row.1,
574                            relation: row.2,
575                            weight: row.3,
576                        });
577                    }
578                }
579            }
580        }
581
582        if ctx_entities.is_empty() {
583            None
584        } else {
585            Some(GraphContext {
586                entities: ctx_entities,
587                relationships: ctx_rels,
588            })
589        }
590    } else {
591        None
592    };
593
594    tracing::debug!(target: "deep_research",
595        total_results = results.len(),
596        total_chains = evidence_chains.len(),
597        "assembly complete"
598    );
599
600    // Phase 4: JSON output.
601    output::emit_json(&DeepResearchResponse {
602        query: args.query,
603        sub_queries,
604        results,
605        evidence_chains,
606        graph_context,
607        stats: ResearchStats {
608            sub_queries_total: sub_query_texts.len(),
609            sub_queries_completed: completed_count,
610            sub_queries_failed: failed_count,
611            sub_queries_timed_out: timed_out_count,
612            unique_memories_found: unique_memories,
613            evidence_chains_found: evidence_count,
614            elapsed_ms: start.elapsed().as_millis() as u64,
615            vec_degraded,
616        },
617    })?;
618
619    Ok(())
620}
621
/// Heuristic query decomposition into at most `max` sub-queries.
///
/// Strategies are tried in order; the first one that produces parts wins:
///
/// 1. Relational phrases (`" that caused "`, `" depending on "`,
///    `" related to "`, `" connected to "`, `" linked to "`, `" caused by "`,
///    `" followed by "`), matched ASCII case-insensitively. Each phrase splits
///    the remaining text at most once, at its first occurrence. The text left
///    of the phrase becomes a part; the text to its right is carried on to the
///    next phrase.
/// 2. Semicolons.
/// 3. Commas, plus the conjunctions ` and ` / ` AND ` (English) and
///    ` e ` / ` E ` (Portuguese), which are treated as commas.
/// 4. For queries with at least three words longer than two bytes: the whole
///    query, its first word pair and its last word pair.
///
/// Parts are trimmed, empty parts are skipped, and exact duplicates are
/// removed (the first occurrence is kept). If nothing splits, the query itself
/// is the only sub-query. `max` is treated as at least 1, so a usable
/// sub-query is always returned.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    if query.is_empty() {
        return vec![query.to_string()];
    }
    let max = max.max(1);

    // `max` comes straight from the CLI, so it is deliberately not used as a
    // capacity hint: a huge value would allocate that many slots up front.
    let mut parts: Vec<String> = Vec::new();

    // Split by relational phrases first (most specific).
    let relational = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];
    let mut text = query.to_string();
    let mut did_relational_split = false;
    for phrase in &relational {
        // The phrases are ASCII, so ASCII-only lowercasing keeps byte offsets
        // aligned with `text`. Full Unicode lowercasing can change byte
        // lengths (e.g. 'İ' -> "i̇"), which made the slicing below mis-cut or
        // panic.
        if let Some(pos) = text.to_ascii_lowercase().find(phrase) {
            let left = text[..pos].trim();
            if !left.is_empty() {
                parts.push(left.to_string());
            }
            // Always advance past the phrase. A dangling phrase at the end
            // leaves an empty remainder, which is dropped below instead of
            // re-adding the whole unsplit text.
            text = text[pos + phrase.len()..].trim().to_string();
            did_relational_split = true;
        }
    }
    if did_relational_split && !text.is_empty() {
        parts.push(text);
    }

    // If no relational split, try conjunctions and delimiters.
    if parts.is_empty() {
        // Split by semicolons first.
        let semi_parts: Vec<&str> = query.split(';').collect();
        if semi_parts.len() > 1 {
            parts.extend(
                semi_parts
                    .iter()
                    .map(|p| p.trim())
                    .filter(|p| !p.is_empty())
                    .map(str::to_string),
            );
        } else {
            // Split by commas and conjunctions: " and " and " e " (Portuguese)
            // are rewritten to commas before splitting.
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            let comma_parts: Vec<&str> = normalized.split(',').collect();
            if comma_parts.len() > 1 {
                parts.extend(
                    comma_parts
                        .iter()
                        .map(|p| p.trim())
                        .filter(|p| !p.is_empty())
                        .map(str::to_string),
                );
            }
        }
    }

    // If still no split, try word-pair decomposition for multi-word queries.
    if parts.is_empty() {
        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
        if words.len() >= 3 {
            parts.push(query.to_string());
            parts.push(format!("{} {}", words[0], words[1]));
            parts.push(format!(
                "{} {}",
                words[words.len() - 2],
                words[words.len() - 1]
            ));
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    // Drop exact duplicates (e.g. "auth, auth") so no sub-query runs twice.
    let mut seen: HashSet<String> = HashSet::with_capacity(parts.len());
    parts.retain(|p| seen.insert(p.clone()));

    // Cap at max.
    parts.truncate(max);
    parts
}
716
717/// Reconstruct a directed path from `target_entity_id` back to a seed using the
718/// predecessor map built by BFS.  Returns the path nodes from root to target
719/// plus the accumulated edge weights.
720fn reconstruct_path(
721    target_id: i64,
722    seed_entity_ids: &HashSet<i64>,
723    predecessor: &PredecessorMap,
724    entity_names: &crate::hash::AHashMap<i64, String>,
725) -> Option<(Vec<EvidenceNode>, f64)> {
726    let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::with_capacity(8);
727    let mut total_weight = 1.0_f64;
728    let mut current = target_id;
729
730    loop {
731        if seed_entity_ids.contains(&current) {
732            break;
733        }
734        let (parent, relation, weight) = predecessor.get(&current)?;
735        total_weight *= weight;
736        path_ids.push((current, Some(relation.clone()), Some(*weight)));
737        current = *parent;
738    }
739    // Push the seed entity (root).
740    path_ids.push((current, None, None));
741
742    // Reverse so path goes from seed → target.
743    path_ids.reverse();
744
745    let nodes: Vec<EvidenceNode> = path_ids
746        .into_iter()
747        .map(|(id, relation, weight)| EvidenceNode {
748            entity: entity_names
749                .get(&id)
750                .cloned()
751                .unwrap_or_else(|| format!("entity-{id}")),
752            relation,
753            weight,
754        })
755        .collect();
756
757    Some((nodes, total_weight))
758}
759
760/// Execute a single sub-query: hybrid search (KNN + FTS fused via RRF) + graph traversal.
761///
762/// GAP-07 fix: receives the embedding for THIS sub-query (not the shared original).
763/// GAP-08/11 fix: uses rrf_fuse() for proper score fusion instead of hardcoded 0.5.
764/// GAP-09/10 fix: builds directed evidence chains filtered to discovered entities.
765/// GAP-17: respects max_neighbors_per_hop cap in BFS.
766///
767/// Runs synchronously on a blocking thread (called from a tokio spawn context).
768/// Each call opens its own read-only SQLite connection to leverage WAL concurrency.
769#[allow(clippy::too_many_arguments)]
770fn execute_sub_query(
771    sub_query_id: usize,
772    query_text: &str,
773    embedding: Option<&[f32]>,
774    namespace: &str,
775    db_path: &std::path::Path,
776    k: usize,
777    max_hops: usize,
778    min_weight: f64,
779    rrf_k: f64,
780    graph_decay: f64,
781    graph_min_score: f64,
782    max_neighbors_per_hop: Option<usize>,
783) -> Result<SubQueryResult, String> {
784    let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
785
786    let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
787        Vec::with_capacity(k * 2);
788    let mut seen_ids: crate::hash::AHashSet<i64> =
789        crate::hash::AHashSet::with_capacity_and_hasher(k * 2, Default::default());
790
791    // --- GAP-08/11 FIX: Use RRF fusion for KNN + FTS instead of hardcoded 0.5 ---
792
793    // 1. KNN vector search — collect ranked IDs (skipped when embedding unavailable).
794    let (knn_ids, knn_distance_map) = if let Some(emb) = embedding {
795        let knn_results = memories::knn_search(&conn, emb, &[namespace.to_string()], None, k)
796            .map_err(|e| format!("knn_search failed: {e}"))?;
797        let ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
798        tracing::debug!(target: "deep_research", sub_query_id, knn_count = ids.len(), "KNN complete");
799        let dist_map: crate::hash::AHashMap<i64, f64> = knn_results
800            .iter()
801            .map(|(id, dist)| (*id, *dist as f64))
802            .collect();
803        (ids, dist_map)
804    } else {
805        tracing::debug!(target: "deep_research", sub_query_id, "KNN skipped (no embedding); FTS5-only");
806        (vec![], crate::hash::AHashMap::default())
807    };
808
809    // 2. FTS5 search — collect ranked IDs.
810    let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
811        Ok(rows) => rows,
812        Err(e) => {
813            tracing::warn!(target: "deep_research",
814                sub_query_id,
815                "FTS5 search failed, continuing with KNN only: {e}"
816            );
817            vec![]
818        }
819    };
820    let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
821    tracing::debug!(target: "deep_research", sub_query_id, fts_count = fts_ids.len(), "FTS complete");
822
823    // 3. Fuse via RRF.
824    let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
825    let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
826
827    // 4. Sort fused results and build hits.
828    let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
829    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
830    fused.truncate(k * 2);
831    tracing::debug!(target: "deep_research",
832        sub_query_id,
833        fused_count = fused.len(),
834        "RRF fusion complete"
835    );
836
837    if fused.is_empty() && !knn_ids.is_empty() {
838        tracing::warn!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
839            "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
840    }
841
842    for (memory_id, combined_score) in &fused {
843        if seen_ids.insert(*memory_id) {
844            let normalized = if max_possible > 0.0 {
845                combined_score / max_possible
846            } else {
847                0.0
848            };
849            let score = normalized.clamp(0.0, 1.0);
850            let in_knn = knn_distance_map.contains_key(memory_id);
851            let in_fts = fts_ids.contains(memory_id);
852            let source = match (in_knn, in_fts) {
853                (true, true) => "hybrid",
854                (true, false) => "knn",
855                (false, true) => "fts",
856                (false, false) => "graph",
857            };
858            if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
859                let snippet: String = row.body.chars().take(300).collect();
860                hits.push((
861                    *memory_id,
862                    score,
863                    source.to_string(),
864                    snippet,
865                    row.body,
866                    None,
867                ));
868            }
869        }
870    }
871
872    // 5. Graph traversal from discovered memories.
873    // GAP-09/10 FIX: entity KNN also uses this sub-query's embedding.
874    let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
875    let mut chains: Vec<EvidenceChain> = Vec::with_capacity(memory_ids.len());
876
877    if !memory_ids.is_empty() && max_hops > 0 {
878        // Seed entities from KNN on entity vectors (skipped when embedding unavailable).
879        let entity_ids: Vec<i64> = if let Some(emb) = embedding {
880            entities::knn_search(&conn, emb, namespace, 5)
881                .inspect_err(|e| tracing::warn!(target: "deep_research", error = %e, "entity KNN search failed, skipping graph seed"))
882                .unwrap_or_default()
883                .iter()
884                .map(|(id, _)| *id)
885                .collect()
886        } else {
887            vec![]
888        };
889
890        // HIGH-01 FIX: limit seeds to top-5 memories by score to prevent
891        // BFS from starting at every node when k >= total memories.
892        let top_seed_count = 5.min(memory_ids.len());
893        let top_memory_ids = &memory_ids[..top_seed_count];
894        let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
895        for &mem_id in top_memory_ids {
896            let mut stmt = conn
897                .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
898                .map_err(|e| format!("prepare failed: {e}"))?;
899            let ids: Vec<i64> = stmt
900                .query_map(rusqlite::params![mem_id], |r| r.get(0))
901                .map_err(|e| format!("query failed: {e}"))?
902                .filter_map(|r| r.ok())
903                .collect();
904            seed_entity_ids.extend(ids);
905        }
906        seed_entity_ids.sort_unstable();
907        seed_entity_ids.dedup();
908        tracing::debug!(target: "deep_research",
909            sub_query_id,
910            seed_count = seed_entity_ids.len(),
911            "seed entities collected"
912        );
913
914        let all_seed_ids: Vec<i64> = memory_ids
915            .iter()
916            .chain(entity_ids.iter())
917            .copied()
918            .collect();
919
920        // Graph traversal with hop scores.
921        if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
922            &conn,
923            &all_seed_ids,
924            namespace,
925            min_weight,
926            max_hops as u32,
927            max_neighbors_per_hop,
928        ) {
929            // Build seed score map from RRF-fused scores for graph decay computation.
930            let seed_score_map: crate::hash::AHashMap<i64, f64> = fused
931                .iter()
932                .map(|(id, s)| {
933                    let normalized = if max_possible > 0.0 {
934                        s / max_possible
935                    } else {
936                        0.0
937                    };
938                    (*id, normalized.clamp(0.0, 1.0))
939                })
940                .collect();
941
942            for (graph_mem_id, hop) in graph_results {
943                if seen_ids.insert(graph_mem_id) {
944                    // GAP-08/11 FIX: graph score = seed_score * decay^hop * edge_weight.
945                    // For the seed score, use the best score among the seed memories that
946                    // transitively reached this graph memory (approximate with the average
947                    // seed score since we don't track the exact path yet).
948                    let avg_seed_score: f64 = if seed_score_map.is_empty() {
949                        0.5
950                    } else {
951                        let sum: f64 = seed_score_map.values().sum();
952                        sum / seed_score_map.len() as f64
953                    };
954                    let graph_score =
955                        (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
956
957                    if graph_score < graph_min_score {
958                        continue;
959                    }
960
961                    if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
962                        let snippet: String = row.body.chars().take(300).collect();
963                        hits.push((
964                            graph_mem_id,
965                            graph_score,
966                            "graph".to_string(),
967                            snippet,
968                            row.body,
969                            Some(hop as usize),
970                        ));
971                    }
972                }
973            }
974        }
975
976        // GAP-09/10 FIX: Build directed evidence chains using BFS with predecessor map,
977        // filtered to entities discovered in this sub-query.
978        if !seed_entity_ids.is_empty() {
979            let (entity_depth, predecessor) = bfs_with_predecessors(
980                &conn,
981                &seed_entity_ids,
982                namespace,
983                min_weight,
984                max_hops as u32,
985                max_neighbors_per_hop,
986            )
987            .unwrap_or_default();
988
989            tracing::debug!(target: "deep_research",
990                sub_query_id,
991                bfs_nodes = entity_depth.len(),
992                predecessors = predecessor.len(),
993                "BFS complete"
994            );
995
996            let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
997
998            // Collect entity IDs we need names for.
999            let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
1000            let mut entity_names: crate::hash::AHashMap<i64, String> =
1001                crate::hash::AHashMap::with_capacity_and_hasher(
1002                    all_entity_ids.len(),
1003                    ahash::RandomState::default(),
1004                );
1005            for &eid in &all_entity_ids {
1006                let name_res: rusqlite::Result<String> = conn.query_row(
1007                    "SELECT name FROM entities WHERE id = ?1",
1008                    rusqlite::params![eid],
1009                    |r| r.get(0),
1010                );
1011                if let Ok(name) = name_res {
1012                    entity_names.insert(eid, name);
1013                }
1014            }
1015
1016            // Reconstruct a path for each non-seed entity that has a predecessor.
1017            for (&target_id, &_hop) in &entity_depth {
1018                if seed_entity_set.contains(&target_id) {
1019                    continue;
1020                }
1021                if !predecessor.contains_key(&target_id) {
1022                    continue;
1023                }
1024                if let Some((path_nodes, total_weight)) =
1025                    reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
1026                {
1027                    if path_nodes.len() < 2 {
1028                        continue;
1029                    }
1030                    let from = path_nodes
1031                        .first()
1032                        .map(|n| n.entity.clone())
1033                        .unwrap_or_default();
1034                    let to = path_nodes
1035                        .last()
1036                        .map(|n| n.entity.clone())
1037                        .unwrap_or_default();
1038                    let depth = path_nodes.len();
1039                    chains.push(EvidenceChain {
1040                        from,
1041                        to,
1042                        path: path_nodes,
1043                        total_weight,
1044                        depth,
1045                        sub_query_ids: vec![sub_query_id],
1046                    });
1047                }
1048            }
1049
1050            // Sort chains by total_weight descending and cap to avoid huge output.
1051            chains.sort_by(|a, b| {
1052                b.total_weight
1053                    .partial_cmp(&a.total_weight)
1054                    .unwrap_or(std::cmp::Ordering::Equal)
1055            });
1056            chains.truncate(20);
1057            tracing::debug!(target: "deep_research",
1058                sub_query_id,
1059                chains_count = chains.len(),
1060                "evidence chains built"
1061            );
1062        }
1063    }
1064
1065    Ok(SubQueryResult {
1066        sub_query_id,
1067        hits,
1068        chains,
1069    })
1070}
1071
1072// ────────────────────────────────────────────────────────────────────────────
1073// Re-export sub_query_results field initialisation for the stats counter.
1074// The field is moved out of run_async after the join loop; we need to shadow it.
1075// ────────────────────────────────────────────────────────────────────────────
1076
// Unit tests for query decomposition, serialization of the response types,
// RRF fusion, evidence-chain path reconstruction, and the per-hop neighbour
// cap of `bfs_with_predecessors`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decompose_and_conjunction() {
        let result = decompose_query("A and B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    #[test]
    fn test_decompose_no_split() {
        let result = decompose_query("simple query", 7);
        assert_eq!(result, vec!["simple query"]);
    }

    #[test]
    fn test_decompose_three_parts() {
        let result = decompose_query("A, B and C", 7);
        assert_eq!(result, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_decompose_portuguese_conjunctions() {
        let result = decompose_query("A e B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    #[test]
    fn test_decompose_max_cap() {
        let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
        let query = parts.join(", ");
        let result = decompose_query(&query, 7);
        assert!(
            result.len() <= 7,
            "expected at most 7 sub-queries, got {}",
            result.len()
        );
    }

    #[test]
    fn test_decompose_empty_preserves_original() {
        let result = decompose_query("", 7);
        assert_eq!(result, vec![""]);
    }

    #[test]
    fn test_decompose_semicolons() {
        let result = decompose_query("auth design; deployment config; logging", 7);
        assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
    }

    #[test]
    fn test_decompose_relational_phrase() {
        let result = decompose_query("auth that caused deployment failure", 7);
        assert_eq!(result, vec!["auth", "deployment failure"]);
    }

    #[test]
    fn test_sub_query_serialization() {
        let sq = SubQuery {
            id: 0,
            text: "test query".to_string(),
            source: "original",
        };
        let json = serde_json::to_value(&sq).expect("serialization failed");
        assert_eq!(json["id"], 0);
        assert_eq!(json["text"], "test query");
        assert_eq!(json["source"], "original");
    }

    #[test]
    fn test_deep_result_omits_body_when_none() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0],
            snippet: "snippet".to_string(),
            body: None,
            hop_distance: None,
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(!json.contains("\"body\""), "body must be omitted when None");
    }

    #[test]
    fn test_deep_result_includes_body_when_some() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0, 1],
            snippet: "snippet".to_string(),
            body: Some("full body content".to_string()),
            hop_distance: Some(2),
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(json.contains("\"body\""), "body must be present when Some");
        assert!(json.contains("full body content"));
    }

    #[test]
    fn test_evidence_node_omits_none_fields() {
        let node = EvidenceNode {
            entity: "auth-module".to_string(),
            relation: None,
            weight: None,
        };
        let json = serde_json::to_string(&node).expect("serialization failed");
        assert!(
            !json.contains("\"relation\""),
            "relation must be omitted when None"
        );
        assert!(
            !json.contains("\"weight\""),
            "weight must be omitted when None"
        );
    }

    #[test]
    fn test_research_stats_serialization() {
        let stats = ResearchStats {
            sub_queries_total: 3,
            sub_queries_completed: 2,
            sub_queries_failed: 1,
            sub_queries_timed_out: 0,
            unique_memories_found: 10,
            evidence_chains_found: 2,
            elapsed_ms: 1234,
            vec_degraded: false,
        };
        let json = serde_json::to_value(&stats).expect("serialization failed");
        assert_eq!(json["sub_queries_total"], 3);
        assert_eq!(json["sub_queries_completed"], 2);
        assert_eq!(json["sub_queries_failed"], 1);
        assert_eq!(json["elapsed_ms"], 1234);
    }

    #[test]
    fn test_deep_research_response_serialization() {
        let resp = DeepResearchResponse {
            query: "test query".to_string(),
            sub_queries: vec![SubQuery {
                id: 0,
                text: "test query".to_string(),
                source: "original",
            }],
            results: vec![],
            evidence_chains: vec![],
            graph_context: None,
            stats: ResearchStats {
                sub_queries_total: 1,
                sub_queries_completed: 1,
                sub_queries_failed: 0,
                sub_queries_timed_out: 0,
                unique_memories_found: 0,
                evidence_chains_found: 0,
                elapsed_ms: 42,
                vec_degraded: false,
            },
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "test query");
        assert!(json["sub_queries"].is_array());
        assert!(json["results"].is_array());
        assert!(json["evidence_chains"].is_array());
        assert_eq!(json["stats"]["elapsed_ms"], 42);
    }

    // ---- GAP-07 regression: different sub-queries produce distinct embeddings ----
    // Distinct sub-query texts are the prerequisite for distinct embedding inputs
    // (and therefore distinct search results). This checks that semicolon
    // decomposition keeps the two clauses apart.
    #[test]
    fn test_distinct_sub_queries_produce_distinct_texts() {
        let queries = [
            "authentication design decisions",
            "deployment configuration and infrastructure",
        ];
        // Build the compound query from the clauses so the input cannot drift
        // away from the clauses this test is about.
        let compound = queries.join("; ");

        let decomposed = decompose_query(&compound, 7);
        assert_eq!(decomposed.len(), 2);
        assert_ne!(decomposed[0], decomposed[1]);
    }

    // ---- GAP-08/11 regression: rrf_fuse integration via fusion module ----
    #[test]
    fn test_rrf_fuse_via_fusion_module() {
        use crate::storage::fusion::rrf_fuse;

        let knn_ids: Vec<i64> = vec![1, 2, 3];
        let fts_ids: Vec<i64> = vec![2, 1, 4];
        let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);

        // Items appearing in both lists must score higher than items in only one list.
        let score_1 = scores[&1];
        let score_2 = scores[&2];
        let score_3 = scores[&3]; // knn only, rank 3
        let score_4 = scores[&4]; // fts only, rank 3

        assert!(
            score_1 > score_3,
            "id 1 (both lists) must beat id 3 (knn-only rank 3)"
        );
        assert!(
            score_2 > score_4,
            "id 2 (both lists) must beat id 4 (fts-only rank 3)"
        );
    }

    // ---- GAP-09/10 regression: evidence chains must be directed paths ----
    #[test]
    fn test_evidence_chain_has_from_to_and_path() {
        let chain = EvidenceChain {
            from: "auth-module".to_string(),
            to: "jwt-service".to_string(),
            path: vec![
                EvidenceNode {
                    entity: "auth-module".to_string(),
                    relation: None,
                    weight: None,
                },
                EvidenceNode {
                    entity: "token-validator".to_string(),
                    relation: Some("depends-on".to_string()),
                    weight: Some(0.9),
                },
                EvidenceNode {
                    entity: "jwt-service".to_string(),
                    relation: Some("uses".to_string()),
                    weight: Some(0.8),
                },
            ],
            total_weight: 0.72,
            depth: 3,
            sub_query_ids: vec![0],
        };

        let json = serde_json::to_value(&chain).expect("serialization failed");
        assert!(
            json["from"].is_string(),
            "evidence chain must have 'from' field"
        );
        assert!(
            json["to"].is_string(),
            "evidence chain must have 'to' field"
        );
        assert!(
            json["path"].is_array(),
            "evidence chain must have 'path' array"
        );
        assert_eq!(json["path"].as_array().unwrap().len(), 3);
        assert!(json["total_weight"].is_number(), "must have total_weight");
        assert_eq!(json["depth"], 3);
    }

    // ---- GAP-10 regression: reconstruct_path returns correct node order ----
    #[test]
    fn test_reconstruct_path_root_to_target_order() {
        // Build a simple chain: entity 10 (seed) -> entity 20 -> entity 30 (target)
        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
        let mut predecessor: PredecessorMap = std::collections::HashMap::new();
        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
        predecessor.insert(30, (20, "uses".to_string(), 0.8));
        let mut entity_names: crate::hash::AHashMap<i64, String> = crate::hash::AHashMap::default();
        entity_names.insert(10, "seed-entity".to_string());
        entity_names.insert(20, "middle-entity".to_string());
        entity_names.insert(30, "target-entity".to_string());

        let (nodes, weight) = reconstruct_path(30, &seed_set, &predecessor, &entity_names)
            .expect("path must be reconstructed");
        // Path must be [seed, middle, target]
        assert_eq!(nodes.len(), 3);
        assert_eq!(nodes[0].entity, "seed-entity");
        assert_eq!(nodes[1].entity, "middle-entity");
        assert_eq!(nodes[2].entity, "target-entity");
        // total_weight = 0.9 * 0.8
        assert!((weight - 0.72).abs() < 1e-6);
    }

    // ---- GAP-09 regression: evidence chains must NOT be present for 1-hop trivial pairs ----
    #[test]
    fn test_evidence_chains_single_hop_filtered_out() {
        // A chain of depth 1 (only root node) should be discarded.
        let chain = EvidenceChain {
            from: "a".to_string(),
            to: "a".to_string(),
            path: vec![EvidenceNode {
                entity: "a".to_string(),
                relation: None,
                weight: None,
            }],
            total_weight: 1.0,
            depth: 1,
            sub_query_ids: vec![0],
        };
        // Simulate the filter: retain chains with depth >= 2. This mirrors the
        // `path_nodes.len() < 2` guard in the sub-query search, where `depth`
        // is set to `path_nodes.len()`.
        let chains = vec![chain];
        let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
        assert!(retained.is_empty(), "depth-1 chains must be filtered out");
    }

    // ---- GAP-17 regression: bfs_with_predecessors honours max_neighbors_per_hop ----
    #[test]
    fn test_bfs_with_predecessors_respects_neighbor_cap() {
        use crate::graph::bfs_with_predecessors;
        use rusqlite::Connection;

        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related'
             );",
        )
        .unwrap();

        // Seed entity 1 has 5 neighbours.
        for target in 2i64..=6 {
            conn.execute(
                "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
                rusqlite::params![1i64, target, 1.0f64],
            )
            .unwrap();
        }

        // Without cap: all 5 neighbours reached. Compare against the full count
        // (seed + 5) rather than `len() - 1`, which would panic on usize
        // underflow if the map came back empty and hide the assertion message.
        let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
        assert_eq!(
            depth_uncapped.len(),
            6,
            "uncapped must discover all 5 neighbours (plus seed)"
        );

        // With cap=2: only 2 neighbours survive. All weights are equal here, so
        // which two is unspecified; only the count is checked.
        let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
        // seed + 2 neighbours = 3 entries.
        assert_eq!(
            depth_capped.len(),
            3,
            "capped to 2 must yield seed + 2 neighbours"
        );
    }
}
1426            3,
1427            "capped to 2 must yield seed + 2 neighbours"
1428        );
1429    }
1430}