Skip to main content

sqlite_graphrag/commands/
deep_research.rs

1//! Handler for the `deep-research` CLI subcommand.
2//!
3//! Orchestrates parallel multi-hop GraphRAG search via query decomposition.
4//! The workload is I/O-bound (SQLite WAL reads), so tokio is used instead of
5//! rayon. Each sub-query opens its own read-only connection.
6
7use crate::errors::AppError;
8use crate::graph::{
9    bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::HashSet;
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
23/// Arguments for the `deep-research` subcommand.
24#[derive(clap::Args)]
25#[command(
26    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
27    after_long_help = "EXAMPLES:\n  \
28        # Basic deep research\n  \
29        sqlite-graphrag deep-research \"auth architecture decisions\"\n\n  \
30        # With custom parameters\n  \
31        sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n  \
32        # Include full memory bodies in output\n  \
33        sqlite-graphrag deep-research \"auth\" --with-bodies\n\n  \
34        # Tune RRF and graph scoring\n  \
35        sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
36)]
37pub struct DeepResearchArgs {
38    /// Research query to decompose and search.
39    #[arg(
40        value_name = "QUERY",
41        allow_hyphen_values = true,
42        help = "Research query to decompose and search"
43    )]
44    pub query: String,
45    /// Results per sub-query (Recall@20 captures 95%+ relevant hits).
46    #[arg(
47        long,
48        short,
49        aliases = ["limit", "top-k"],
50        default_value_t = 20,
51        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
52    )]
53    pub k: usize,
54    /// Maximum sub-queries from decomposition (covers complex multi-hop queries).
55    #[arg(
56        long,
57        default_value_t = 7,
58        help = "Maximum sub-queries (covers complex multi-hop queries)"
59    )]
60    pub max_sub_queries: usize,
61    /// Multi-hop graph traversal depth (sweet spot: 2-3 hops).
62    #[arg(
63        long,
64        default_value_t = 3,
65        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
66    )]
67    pub max_hops: usize,
68    /// Minimum edge weight for graph traversal.
69    #[arg(
70        long,
71        default_value_t = 0.3,
72        help = "Minimum edge weight for graph traversal"
73    )]
74    pub min_weight: f64,
75    /// Maximum concurrent sub-queries (default: min(cpus, 8)).
76    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
77    pub max_concurrency: Option<usize>,
78    /// Timeout per sub-query in seconds.
79    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
80    pub timeout: u64,
81    /// Include full memory bodies in results.
82    #[arg(
83        long,
84        default_value_t = false,
85        help = "Include full memory bodies in results"
86    )]
87    pub with_bodies: bool,
88    /// Maximum results after deduplication.
89    #[arg(
90        long,
91        default_value_t = 50,
92        help = "Maximum results after deduplication"
93    )]
94    pub max_results: usize,
95    /// RRF k parameter controlling score smoothing (higher = less weight on top ranks).
96    #[arg(
97        long,
98        default_value_t = 60.0,
99        help = "RRF k parameter (higher = less weight on top ranks)"
100    )]
101    pub rrf_k: f64,
102    /// Decay factor applied to graph scores per hop (score = seed_score * decay^hop).
103    #[arg(
104        long,
105        default_value_t = 0.7,
106        help = "Graph score decay factor per hop (0.0-1.0)"
107    )]
108    pub graph_decay: f64,
109    /// Minimum score threshold for graph-expanded results (filters noise).
110    #[arg(
111        long,
112        default_value_t = 0.05,
113        help = "Minimum score threshold for graph-expanded results"
114    )]
115    pub graph_min_score: f64,
116    /// Limit top-k neighbours followed per entity per hop (None = unlimited).
117    #[arg(
118        long,
119        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
120    )]
121    pub max_neighbors_per_hop: Option<usize>,
122    /// Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global).
123    #[arg(
124        long,
125        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
126    )]
127    pub namespace: Option<String>,
128    /// Research mode: `none` (local heuristic, default), `claude-code`, `codex` (v1.1.0).
129    #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
130    pub mode: String,
131    /// Maximum LLM cost in USD (effective with --mode claude-code/codex, reserved for v1.1.0).
132    #[arg(
133        long,
134        value_name = "USD",
135        help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
136    )]
137    pub max_cost_usd: Option<f64>,
138    /// JSON output (always on, kept for consistency).
139    #[arg(long, hide = true)]
140    pub json: bool,
141    /// Database path.
142    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
143    pub db: Option<String>,
144    #[command(flatten)]
145    pub daemon: crate::cli::DaemonOpts,
146}
147
148#[derive(Serialize)]
149struct SubQuery {
150    id: usize,
151    text: String,
152    source: &'static str,
153}
154
155#[derive(Serialize)]
156struct DeepResult {
157    name: String,
158    score: f64,
159    source: String,
160    sub_query_ids: Vec<usize>,
161    snippet: String,
162    #[serde(skip_serializing_if = "Option::is_none")]
163    body: Option<String>,
164    hop_distance: Option<usize>,
165}
166
167/// A node in a reconstructed evidence path.
168#[derive(Serialize, Clone)]
169struct EvidenceNode {
170    entity: String,
171    #[serde(skip_serializing_if = "Option::is_none")]
172    relation: Option<String>,
173    #[serde(skip_serializing_if = "Option::is_none")]
174    weight: Option<f64>,
175}
176
177/// A directed evidence chain reconstructed from BFS predecessors.
178///
179/// Fields:
180/// - `from`: name of the seed (source) entity.
181/// - `to`: name of the terminal (target) entity.
182/// - `path`: ordered list of intermediate nodes from `from` to `to`.
183/// - `total_weight`: product of edge weights along the path.
184/// - `sub_query_ids`: which sub-queries produced this chain.
185#[derive(Serialize)]
186struct EvidenceChain {
187    from: String,
188    to: String,
189    path: Vec<EvidenceNode>,
190    total_weight: f64,
191    depth: usize,
192    sub_query_ids: Vec<usize>,
193}
194
195#[derive(Serialize)]
196struct ResearchStats {
197    sub_queries_total: usize,
198    sub_queries_completed: usize,
199    sub_queries_failed: usize,
200    sub_queries_timed_out: usize,
201    unique_memories_found: usize,
202    evidence_chains_found: usize,
203    elapsed_ms: u64,
204}
205
206#[derive(Serialize)]
207struct GraphContextEntity {
208    name: String,
209    entity_type: String,
210    degree: u32,
211}
212
213#[derive(Serialize)]
214struct GraphContextRel {
215    from: String,
216    to: String,
217    relation: String,
218    weight: f64,
219}
220
221#[derive(Serialize)]
222struct GraphContext {
223    entities: Vec<GraphContextEntity>,
224    relationships: Vec<GraphContextRel>,
225}
226
227#[derive(Serialize)]
228struct DeepResearchResponse {
229    query: String,
230    sub_queries: Vec<SubQuery>,
231    results: Vec<DeepResult>,
232    evidence_chains: Vec<EvidenceChain>,
233    #[serde(skip_serializing_if = "Option::is_none")]
234    graph_context: Option<GraphContext>,
235    stats: ResearchStats,
236}
237
238/// Aggregated hit data: (score, source_label, snippet, body, hop_distance, sub_query_ids).
239type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
240
241/// Intermediate result from a single sub-query execution.
242struct SubQueryResult {
243    sub_query_id: usize,
244    /// (memory_id, score, source_label, snippet, body, hop_distance)
245    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
246    /// Evidence chains reconstructed from BFS.
247    chains: Vec<EvidenceChain>,
248}
249
250/// Sync entry point — builds a tokio runtime for the async fan-out.
251#[tracing::instrument(skip_all, level = "debug", name = "deep_research")]
252pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
253    tracing::debug!(target: "deep_research", query = %args.query, k = args.k, "starting deep research");
254    let rt = tokio::runtime::Builder::new_multi_thread()
255        .worker_threads(2)
256        .enable_all()
257        .build()
258        .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
259    rt.block_on(run_async(args))
260}
261
262/// Main async logic: decompose, fan-out, assemble, emit JSON.
263async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
264    let start = std::time::Instant::now();
265
266    if args.query.trim().is_empty() {
267        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
268    }
269
270    if args.max_cost_usd.is_some() && args.mode == "none" {
271        tracing::warn!(target: "deep_research", "--max-cost-usd has no effect without --mode claude-code/codex");
272    }
273
274    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
275    let paths = AppPaths::resolve(args.db.as_deref())?;
276    crate::storage::connection::ensure_db_ready(&paths)?;
277
278    // Phase 1: Query decomposition (sync, pure logic).
279    let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
280    let sub_queries: Vec<SubQuery> = sub_query_texts
281        .iter()
282        .enumerate()
283        .map(|(i, text)| SubQuery {
284            id: i,
285            text: text.clone(),
286            source: if sub_query_texts.len() == 1 {
287                "original"
288            } else {
289                "decomposed"
290            },
291        })
292        .collect();
293
294    // GAP-07 FIX: compute ONE embedding PER sub-query text (sequential — daemon serialises).
295    // The previous code used a single embedding for args.query shared across all sub-queries,
296    // making decomposition cosmetic.  We now build a Vec<Arc<Vec<f32>>> indexed by sub-query.
297    output::emit_progress_i18n(
298        "Computing per-sub-query embeddings...",
299        "Calculando embeddings por sub-consulta...",
300    );
301    let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
302    for sq_text in &sub_query_texts {
303        let emb = crate::daemon::embed_query_or_local(
304            &paths.models,
305            sq_text,
306            args.daemon.autostart_daemon,
307        )?;
308        sub_embeddings.push(Arc::new(emb));
309    }
310
311    // Phase 2: Fan-out — parallel sub-query execution.
312    let cpu_count = std::thread::available_parallelism()
313        .map(|n| n.get())
314        .unwrap_or(4);
315    let permits = args
316        .max_concurrency
317        .unwrap_or_else(|| cpu_count.min(8))
318        .min(sub_queries.len())
319        .max(1);
320    let semaphore = Arc::new(Semaphore::new(permits));
321    let timeout_dur = std::time::Duration::from_secs(args.timeout);
322
323    let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
324
325    for (idx, sq_text) in sub_query_texts.iter().enumerate() {
326        let sem = Arc::clone(&semaphore);
327        // GAP-07 FIX: pass embedding for THIS specific sub-query.
328        let emb = Arc::clone(&sub_embeddings[idx]);
329        let ns = namespace.clone();
330        let db_path = paths.db.clone();
331        let query_text = sq_text.clone();
332        let k = args.k;
333        let max_hops = args.max_hops;
334        let min_weight = args.min_weight;
335        let rrf_k = args.rrf_k;
336        let graph_decay = args.graph_decay;
337        let graph_min_score = args.graph_min_score;
338        let max_neighbors_per_hop = args.max_neighbors_per_hop;
339
340        join_set.spawn(async move {
341            let _permit = sem
342                .acquire_owned()
343                .await
344                .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
345
346            // Dereference the Arc to obtain a &[f32] slice for the sync function.
347            let result = tokio::time::timeout(timeout_dur, async move {
348                execute_sub_query(
349                    idx,
350                    &query_text,
351                    emb.as_slice(),
352                    &ns,
353                    &db_path,
354                    k,
355                    max_hops,
356                    min_weight,
357                    rrf_k,
358                    graph_decay,
359                    graph_min_score,
360                    max_neighbors_per_hop,
361                )
362            })
363            .await;
364
365            match result {
366                Ok(inner) => inner.map_err(|e| (idx, e)),
367                Err(_) => Err((idx, "timeout".to_string())),
368            }
369        });
370    }
371
372    // Collect results incrementally.
373    let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
374    let mut failed_count = 0usize;
375    let mut timed_out_count = 0usize;
376
377    while let Some(join_result) = join_set.join_next().await {
378        match join_result {
379            Ok(Ok(sqr)) => sub_query_results.push(sqr),
380            Ok(Err((_idx, reason))) => {
381                if reason == "timeout" {
382                    timed_out_count += 1;
383                } else {
384                    failed_count += 1;
385                }
386                tracing::warn!(target: "deep_research", sub_query_id = _idx, reason = %reason, "sub-query failed");
387            }
388            Err(join_err) => {
389                failed_count += 1;
390                if join_err.is_panic() {
391                    tracing::error!(target: "deep_research", error = %join_err, "sub-query task panicked");
392                } else {
393                    tracing::warn!(target: "deep_research", error = %join_err, "sub-query task cancelled");
394                }
395            }
396        }
397    }
398
399    // Phase 3: Evidence assembly — merge, dedup, rank.
400    // Aggregate hits: memory_id -> (best_score, source, snippet, body, hop_distance, sub_query_ids)
401    let mut merged: crate::hash::AHashMap<i64, MergedHit> =
402        crate::hash::AHashMap::with_capacity_and_hasher(
403            sub_query_results.len() * args.k,
404            Default::default(),
405        );
406
407    for sqr in &sub_query_results {
408        for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
409            let entry = merged.entry(*mem_id).or_insert_with(|| {
410                (
411                    *score,
412                    source.clone(),
413                    snippet.clone(),
414                    body.clone(),
415                    *hop,
416                    Vec::new(),
417                )
418            });
419            // Keep best score.
420            if *score > entry.0 {
421                entry.0 = *score;
422                entry.1 = source.clone();
423                entry.2 = snippet.clone();
424                entry.3 = body.clone();
425                entry.4 = *hop;
426            }
427            if !entry.5.contains(&sqr.sub_query_id) {
428                entry.5.push(sqr.sub_query_id);
429            }
430        }
431    }
432
433    // Resolve memory names for merged results.
434    let conn = open_ro(&paths.db)?;
435    let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
436
437    // Sort by score descending.
438    let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
439    ranked.sort_by(|a, b| {
440        b.1 .0
441            .partial_cmp(&a.1 .0)
442            .unwrap_or(std::cmp::Ordering::Equal)
443    });
444    ranked.truncate(args.max_results);
445
446    for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
447        let name = match memories::read_full(&conn, mem_id)? {
448            Some(row) => row.name,
449            None => continue,
450        };
451        results.push(DeepResult {
452            name,
453            score,
454            source,
455            sub_query_ids: sq_ids,
456            snippet,
457            body: if args.with_bodies { Some(body) } else { None },
458            hop_distance: hop,
459        });
460    }
461
462    // GAP-09/10 FIX: Collect evidence chains from reconstructed BFS paths.
463    // The old code appended flat node pairs from a global SELECT; now each
464    // sub-query returns directed EvidenceChain structs (from, to, path).
465    let completed_count = sub_query_results.len();
466    let mut evidence_chains: Vec<EvidenceChain> = Vec::with_capacity(completed_count * 2);
467    let mut seen_chain_keys: HashSet<String> = HashSet::with_capacity(completed_count * 2);
468
469    for sqr in sub_query_results {
470        for chain in sqr.chains {
471            // Deduplicate chains by (from, to) pair.
472            let key = format!("{}->{}", chain.from, chain.to);
473            if seen_chain_keys.insert(key) {
474                evidence_chains.push(chain);
475            }
476        }
477    }
478
479    // Sort evidence chains by total_weight descending, discard single-hop trivial chains.
480    evidence_chains.retain(|c| c.depth >= 2);
481    evidence_chains.sort_by(|a, b| {
482        b.total_weight
483            .partial_cmp(&a.total_weight)
484            .unwrap_or(std::cmp::Ordering::Equal)
485    });
486
487    let unique_memories = results.len();
488    let evidence_count = evidence_chains.len();
489
490    // MEDIUM-01b: Build graph_context with entities and relationships from result memories.
491    let graph_context = if !results.is_empty() {
492        let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
493        let mut ctx_entities: Vec<GraphContextEntity> = Vec::with_capacity(results.len());
494        let mut ctx_rels: Vec<GraphContextRel> = Vec::with_capacity(results.len() * 2);
495        let mut seen_entity_ids: crate::hash::AHashSet<i64> =
496            crate::hash::AHashSet::with_capacity_and_hasher(results.len(), Default::default());
497
498        for name in &result_names {
499            if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
500                if seen_entity_ids.insert(eid) {
501                    let etype: String = conn
502                        .query_row(
503                            "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
504                            rusqlite::params![eid],
505                            |r| r.get(0),
506                        )
507                        .unwrap_or_else(|_| "concept".to_string());
508                    let degree: u32 = conn
509                        .query_row(
510                            "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
511                            rusqlite::params![eid],
512                            |r| r.get(0),
513                        )
514                        .unwrap_or(0);
515                    ctx_entities.push(GraphContextEntity {
516                        name: name.to_string(),
517                        entity_type: etype,
518                        degree,
519                    });
520                }
521            }
522        }
523
524        let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
525        if entity_ids.len() >= 2 {
526            let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
527            let sql = format!(
528                "SELECT s.name, t.name, r.relation, r.weight \
529                 FROM relationships r \
530                 JOIN entities s ON s.id = r.source_id \
531                 JOIN entities t ON t.id = r.target_id \
532                 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
533                 LIMIT 50"
534            );
535            if let Ok(mut stmt) = conn.prepare(&sql) {
536                let mut params: Vec<Box<dyn rusqlite::types::ToSql>> =
537                    Vec::with_capacity(entity_ids.len() * 2);
538                for id in &entity_ids {
539                    params.push(Box::new(*id));
540                }
541                for id in &entity_ids {
542                    params.push(Box::new(*id));
543                }
544                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
545                    params.iter().map(|p| p.as_ref()).collect();
546                if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
547                    Ok((
548                        r.get::<_, String>(0)?,
549                        r.get::<_, String>(1)?,
550                        r.get::<_, String>(2)?,
551                        r.get::<_, f64>(3)?,
552                    ))
553                }) {
554                    for row in rows.flatten() {
555                        ctx_rels.push(GraphContextRel {
556                            from: row.0,
557                            to: row.1,
558                            relation: row.2,
559                            weight: row.3,
560                        });
561                    }
562                }
563            }
564        }
565
566        if ctx_entities.is_empty() {
567            None
568        } else {
569            Some(GraphContext {
570                entities: ctx_entities,
571                relationships: ctx_rels,
572            })
573        }
574    } else {
575        None
576    };
577
578    tracing::debug!(target: "deep_research",
579        total_results = results.len(),
580        total_chains = evidence_chains.len(),
581        "assembly complete"
582    );
583
584    // Phase 4: JSON output.
585    output::emit_json(&DeepResearchResponse {
586        query: args.query,
587        sub_queries,
588        results,
589        evidence_chains,
590        graph_context,
591        stats: ResearchStats {
592            sub_queries_total: sub_query_texts.len(),
593            sub_queries_completed: completed_count,
594            sub_queries_failed: failed_count,
595            sub_queries_timed_out: timed_out_count,
596            unique_memories_found: unique_memories,
597            evidence_chains_found: evidence_count,
598            elapsed_ms: start.elapsed().as_millis() as u64,
599        },
600    })?;
601
602    Ok(())
603}
604
605/// Heuristic query decomposition: splits by conjunctions, commas, semicolons,
606/// relational phrases, and extracts explicit entities (kebab-case or quoted).
607fn decompose_query(query: &str, max: usize) -> Vec<String> {
608    if query.is_empty() {
609        return vec![query.to_string()];
610    }
611
612    let mut parts: Vec<String> = Vec::with_capacity(max);
613
614    // Split by relational phrases first (most specific).
615    let relational = [
616        " that caused ",
617        " depending on ",
618        " related to ",
619        " connected to ",
620        " linked to ",
621        " caused by ",
622        " followed by ",
623    ];
624    let mut text = query.to_string();
625    let mut did_relational_split = false;
626    for phrase in &relational {
627        if text.to_lowercase().contains(phrase) {
628            let lower = text.to_lowercase();
629            if let Some(pos) = lower.find(phrase) {
630                let left = text[..pos].trim().to_string();
631                let right = text[pos + phrase.len()..].trim().to_string();
632                if !left.is_empty() {
633                    parts.push(left);
634                }
635                if !right.is_empty() {
636                    text = right;
637                }
638                did_relational_split = true;
639            }
640        }
641    }
642    if did_relational_split && !text.is_empty() {
643        parts.push(text.clone());
644    }
645
646    // If no relational split, try conjunctions and delimiters.
647    if parts.is_empty() {
648        // Split by semicolons first.
649        let semi_parts: Vec<&str> = query.split(';').collect();
650        if semi_parts.len() > 1 {
651            for p in &semi_parts {
652                let trimmed = p.trim();
653                if !trimmed.is_empty() {
654                    parts.push(trimmed.to_string());
655                }
656            }
657        } else {
658            // Split by commas and conjunctions.
659            // Replace " and " and " e " (Portuguese) with comma, then split.
660            let normalized = query
661                .replace(" and ", ", ")
662                .replace(" AND ", ", ")
663                .replace(" e ", ", ")
664                .replace(" E ", ", ");
665            let comma_parts: Vec<&str> = normalized.split(',').collect();
666            if comma_parts.len() > 1 {
667                for p in &comma_parts {
668                    let trimmed = p.trim();
669                    if !trimmed.is_empty() {
670                        parts.push(trimmed.to_string());
671                    }
672                }
673            }
674        }
675    }
676
677    // If still no split, try word-pair decomposition for multi-word queries.
678    if parts.is_empty() {
679        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
680        if words.len() >= 3 {
681            parts.push(query.to_string());
682            parts.push(format!("{} {}", words[0], words[1]));
683            parts.push(format!(
684                "{} {}",
685                words[words.len() - 2],
686                words[words.len() - 1]
687            ));
688        }
689    }
690
691    if parts.is_empty() {
692        return vec![query.to_string()];
693    }
694
695    // Cap at max.
696    parts.truncate(max);
697    parts
698}
699
700/// Reconstruct a directed path from `target_entity_id` back to a seed using the
701/// predecessor map built by BFS.  Returns the path nodes from root to target
702/// plus the accumulated edge weights.
703fn reconstruct_path(
704    target_id: i64,
705    seed_entity_ids: &HashSet<i64>,
706    predecessor: &PredecessorMap,
707    entity_names: &crate::hash::AHashMap<i64, String>,
708) -> Option<(Vec<EvidenceNode>, f64)> {
709    let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::with_capacity(8);
710    let mut total_weight = 1.0_f64;
711    let mut current = target_id;
712
713    loop {
714        if seed_entity_ids.contains(&current) {
715            break;
716        }
717        let (parent, relation, weight) = predecessor.get(&current)?;
718        total_weight *= weight;
719        path_ids.push((current, Some(relation.clone()), Some(*weight)));
720        current = *parent;
721    }
722    // Push the seed entity (root).
723    path_ids.push((current, None, None));
724
725    // Reverse so path goes from seed → target.
726    path_ids.reverse();
727
728    let nodes: Vec<EvidenceNode> = path_ids
729        .into_iter()
730        .map(|(id, relation, weight)| EvidenceNode {
731            entity: entity_names
732                .get(&id)
733                .cloned()
734                .unwrap_or_else(|| format!("entity-{id}")),
735            relation,
736            weight,
737        })
738        .collect();
739
740    Some((nodes, total_weight))
741}
742
743/// Execute a single sub-query: hybrid search (KNN + FTS fused via RRF) + graph traversal.
744///
745/// GAP-07 fix: receives the embedding for THIS sub-query (not the shared original).
746/// GAP-08/11 fix: uses rrf_fuse() for proper score fusion instead of hardcoded 0.5.
747/// GAP-09/10 fix: builds directed evidence chains filtered to discovered entities.
748/// GAP-17: respects max_neighbors_per_hop cap in BFS.
749///
750/// Runs synchronously on a blocking thread (called from a tokio spawn context).
751/// Each call opens its own read-only SQLite connection to leverage WAL concurrency.
752#[allow(clippy::too_many_arguments)]
753fn execute_sub_query(
754    sub_query_id: usize,
755    query_text: &str,
756    embedding: &[f32],
757    namespace: &str,
758    db_path: &std::path::Path,
759    k: usize,
760    max_hops: usize,
761    min_weight: f64,
762    rrf_k: f64,
763    graph_decay: f64,
764    graph_min_score: f64,
765    max_neighbors_per_hop: Option<usize>,
766) -> Result<SubQueryResult, String> {
767    let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
768
769    let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
770        Vec::with_capacity(k * 2);
771    let mut seen_ids: crate::hash::AHashSet<i64> =
772        crate::hash::AHashSet::with_capacity_and_hasher(k * 2, Default::default());
773
774    // --- GAP-08/11 FIX: Use RRF fusion for KNN + FTS instead of hardcoded 0.5 ---
775
776    // 1. KNN vector search — collect ranked IDs.
777    let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
778        .map_err(|e| format!("knn_search failed: {e}"))?;
779    let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
780    tracing::debug!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), "KNN complete");
781
782    // Build distance map for score computation.
783    let knn_distance_map: crate::hash::AHashMap<i64, f64> = knn_results
784        .iter()
785        .map(|(id, dist)| (*id, *dist as f64))
786        .collect();
787
788    // 2. FTS5 search — collect ranked IDs.
789    let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
790        Ok(rows) => rows,
791        Err(e) => {
792            tracing::warn!(target: "deep_research",
793                sub_query_id,
794                "FTS5 search failed, continuing with KNN only: {e}"
795            );
796            vec![]
797        }
798    };
799    let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
800    tracing::debug!(target: "deep_research", sub_query_id, fts_count = fts_ids.len(), "FTS complete");
801
802    // 3. Fuse via RRF.
803    let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
804    let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
805
806    // 4. Sort fused results and build hits.
807    let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
808    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
809    fused.truncate(k * 2);
810    tracing::debug!(target: "deep_research",
811        sub_query_id,
812        fused_count = fused.len(),
813        "RRF fusion complete"
814    );
815
816    if fused.is_empty() && !knn_ids.is_empty() {
817        tracing::warn!(target: "deep_research", sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
818            "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
819    }
820
821    for (memory_id, combined_score) in &fused {
822        if seen_ids.insert(*memory_id) {
823            let normalized = if max_possible > 0.0 {
824                combined_score / max_possible
825            } else {
826                0.0
827            };
828            let score = normalized.clamp(0.0, 1.0);
829            let in_knn = knn_distance_map.contains_key(memory_id);
830            let in_fts = fts_ids.contains(memory_id);
831            let source = match (in_knn, in_fts) {
832                (true, true) => "hybrid",
833                (true, false) => "knn",
834                (false, true) => "fts",
835                (false, false) => "graph",
836            };
837            if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
838                let snippet: String = row.body.chars().take(300).collect();
839                hits.push((
840                    *memory_id,
841                    score,
842                    source.to_string(),
843                    snippet,
844                    row.body,
845                    None,
846                ));
847            }
848        }
849    }
850
851    // 5. Graph traversal from discovered memories.
852    // GAP-09/10 FIX: entity KNN also uses this sub-query's embedding.
853    let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
854    let mut chains: Vec<EvidenceChain> = Vec::with_capacity(memory_ids.len());
855
856    if !memory_ids.is_empty() && max_hops > 0 {
857        // Seed entities from KNN on entity vectors using THIS sub-query's embedding.
858        let entity_knn = entities::knn_search(&conn, embedding, namespace, 5)
859            .inspect_err(|e| tracing::warn!(target: "deep_research", error = %e, "entity KNN search failed, skipping graph seed"))
860            .unwrap_or_default();
861        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
862
863        // HIGH-01 FIX: limit seeds to top-5 memories by score to prevent
864        // BFS from starting at every node when k >= total memories.
865        let top_seed_count = 5.min(memory_ids.len());
866        let top_memory_ids = &memory_ids[..top_seed_count];
867        let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
868        for &mem_id in top_memory_ids {
869            let mut stmt = conn
870                .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
871                .map_err(|e| format!("prepare failed: {e}"))?;
872            let ids: Vec<i64> = stmt
873                .query_map(rusqlite::params![mem_id], |r| r.get(0))
874                .map_err(|e| format!("query failed: {e}"))?
875                .filter_map(|r| r.ok())
876                .collect();
877            seed_entity_ids.extend(ids);
878        }
879        seed_entity_ids.sort_unstable();
880        seed_entity_ids.dedup();
881        tracing::debug!(target: "deep_research",
882            sub_query_id,
883            seed_count = seed_entity_ids.len(),
884            "seed entities collected"
885        );
886
887        let all_seed_ids: Vec<i64> = memory_ids
888            .iter()
889            .chain(entity_ids.iter())
890            .copied()
891            .collect();
892
893        // Graph traversal with hop scores.
894        if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
895            &conn,
896            &all_seed_ids,
897            namespace,
898            min_weight,
899            max_hops as u32,
900            max_neighbors_per_hop,
901        ) {
902            // Build seed score map from RRF-fused scores for graph decay computation.
903            let seed_score_map: crate::hash::AHashMap<i64, f64> = fused
904                .iter()
905                .map(|(id, s)| {
906                    let normalized = if max_possible > 0.0 {
907                        s / max_possible
908                    } else {
909                        0.0
910                    };
911                    (*id, normalized.clamp(0.0, 1.0))
912                })
913                .collect();
914
915            for (graph_mem_id, hop) in graph_results {
916                if seen_ids.insert(graph_mem_id) {
917                    // GAP-08/11 FIX: graph score = seed_score * decay^hop * edge_weight.
918                    // For the seed score, use the best score among the seed memories that
919                    // transitively reached this graph memory (approximate with the average
920                    // seed score since we don't track the exact path yet).
921                    let avg_seed_score: f64 = if seed_score_map.is_empty() {
922                        0.5
923                    } else {
924                        let sum: f64 = seed_score_map.values().sum();
925                        sum / seed_score_map.len() as f64
926                    };
927                    let graph_score =
928                        (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
929
930                    if graph_score < graph_min_score {
931                        continue;
932                    }
933
934                    if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
935                        let snippet: String = row.body.chars().take(300).collect();
936                        hits.push((
937                            graph_mem_id,
938                            graph_score,
939                            "graph".to_string(),
940                            snippet,
941                            row.body,
942                            Some(hop as usize),
943                        ));
944                    }
945                }
946            }
947        }
948
949        // GAP-09/10 FIX: Build directed evidence chains using BFS with predecessor map,
950        // filtered to entities discovered in this sub-query.
951        if !seed_entity_ids.is_empty() {
952            let (entity_depth, predecessor) = bfs_with_predecessors(
953                &conn,
954                &seed_entity_ids,
955                namespace,
956                min_weight,
957                max_hops as u32,
958                max_neighbors_per_hop,
959            )
960            .unwrap_or_default();
961
962            tracing::debug!(target: "deep_research",
963                sub_query_id,
964                bfs_nodes = entity_depth.len(),
965                predecessors = predecessor.len(),
966                "BFS complete"
967            );
968
969            let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
970
971            // Collect entity IDs we need names for.
972            let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
973            let mut entity_names: crate::hash::AHashMap<i64, String> =
974                crate::hash::AHashMap::with_capacity_and_hasher(
975                    all_entity_ids.len(),
976                    ahash::RandomState::default(),
977                );
978            for &eid in &all_entity_ids {
979                let name_res: rusqlite::Result<String> = conn.query_row(
980                    "SELECT name FROM entities WHERE id = ?1",
981                    rusqlite::params![eid],
982                    |r| r.get(0),
983                );
984                if let Ok(name) = name_res {
985                    entity_names.insert(eid, name);
986                }
987            }
988
989            // Reconstruct a path for each non-seed entity that has a predecessor.
990            for (&target_id, &_hop) in &entity_depth {
991                if seed_entity_set.contains(&target_id) {
992                    continue;
993                }
994                if !predecessor.contains_key(&target_id) {
995                    continue;
996                }
997                if let Some((path_nodes, total_weight)) =
998                    reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
999                {
1000                    if path_nodes.len() < 2 {
1001                        continue;
1002                    }
1003                    let from = path_nodes
1004                        .first()
1005                        .map(|n| n.entity.clone())
1006                        .unwrap_or_default();
1007                    let to = path_nodes
1008                        .last()
1009                        .map(|n| n.entity.clone())
1010                        .unwrap_or_default();
1011                    let depth = path_nodes.len();
1012                    chains.push(EvidenceChain {
1013                        from,
1014                        to,
1015                        path: path_nodes,
1016                        total_weight,
1017                        depth,
1018                        sub_query_ids: vec![sub_query_id],
1019                    });
1020                }
1021            }
1022
1023            // Sort chains by total_weight descending and cap to avoid huge output.
1024            chains.sort_by(|a, b| {
1025                b.total_weight
1026                    .partial_cmp(&a.total_weight)
1027                    .unwrap_or(std::cmp::Ordering::Equal)
1028            });
1029            chains.truncate(20);
1030            tracing::debug!(target: "deep_research",
1031                sub_query_id,
1032                chains_count = chains.len(),
1033                "evidence chains built"
1034            );
1035        }
1036    }
1037
1038    Ok(SubQueryResult {
1039        sub_query_id,
1040        hits,
1041        chains,
1042    })
1043}
1044
1045// ────────────────────────────────────────────────────────────────────────────
1046// Re-export sub_query_results field initialisation for the stats counter.
1047// The field is moved out of run_async after the join loop; we need to shadow it.
1048// ────────────────────────────────────────────────────────────────────────────
1049
// Unit tests for the deep-research handler: query decomposition, JSON shape of
// the response types, RRF fusion wiring, evidence-chain path reconstruction and
// the neighbour cap of `bfs_with_predecessors`.
#[cfg(test)]
mod tests {
    // The glob also brings in the parent's own imports (`HashSet`,
    // `PredecessorMap`, `rrf_fuse`, `bfs_with_predecessors`, ...).
    use super::*;

    #[test]
    fn test_decompose_and_conjunction() {
        let result = decompose_query("A and B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    #[test]
    fn test_decompose_no_split() {
        let result = decompose_query("simple query", 7);
        assert_eq!(result, vec!["simple query"]);
    }

    #[test]
    fn test_decompose_three_parts() {
        let result = decompose_query("A, B and C", 7);
        assert_eq!(result, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_decompose_portuguese_conjunctions() {
        let result = decompose_query("A e B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    #[test]
    fn test_decompose_max_cap() {
        let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
        let query = parts.join(", ");
        let result = decompose_query(&query, 7);
        assert!(
            result.len() <= 7,
            "expected at most 7 sub-queries, got {}",
            result.len()
        );
    }

    #[test]
    fn test_decompose_empty_preserves_original() {
        let result = decompose_query("", 7);
        assert_eq!(result, vec![""]);
    }

    #[test]
    fn test_decompose_semicolons() {
        let result = decompose_query("auth design; deployment config; logging", 7);
        assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
    }

    #[test]
    fn test_decompose_relational_phrase() {
        let result = decompose_query("auth that caused deployment failure", 7);
        assert_eq!(result, vec!["auth", "deployment failure"]);
    }

    #[test]
    fn test_sub_query_serialization() {
        let sq = SubQuery {
            id: 0,
            text: "test query".to_string(),
            source: "original",
        };
        let json = serde_json::to_value(&sq).expect("serialization failed");
        assert_eq!(json["id"], 0);
        assert_eq!(json["text"], "test query");
        assert_eq!(json["source"], "original");
    }

    #[test]
    fn test_deep_result_omits_body_when_none() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0],
            snippet: "snippet".to_string(),
            body: None,
            hop_distance: None,
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(!json.contains("\"body\""), "body must be omitted when None");
    }

    #[test]
    fn test_deep_result_includes_body_when_some() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0, 1],
            snippet: "snippet".to_string(),
            body: Some("full body content".to_string()),
            hop_distance: Some(2),
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(json.contains("\"body\""), "body must be present when Some");
        assert!(json.contains("full body content"));
    }

    #[test]
    fn test_evidence_node_omits_none_fields() {
        let node = EvidenceNode {
            entity: "auth-module".to_string(),
            relation: None,
            weight: None,
        };
        let json = serde_json::to_string(&node).expect("serialization failed");
        assert!(
            !json.contains("\"relation\""),
            "relation must be omitted when None"
        );
        assert!(
            !json.contains("\"weight\""),
            "weight must be omitted when None"
        );
    }

    #[test]
    fn test_research_stats_serialization() {
        let stats = ResearchStats {
            sub_queries_total: 3,
            sub_queries_completed: 2,
            sub_queries_failed: 1,
            sub_queries_timed_out: 0,
            unique_memories_found: 10,
            evidence_chains_found: 2,
            elapsed_ms: 1234,
        };
        let json = serde_json::to_value(&stats).expect("serialization failed");
        assert_eq!(json["sub_queries_total"], 3);
        assert_eq!(json["sub_queries_completed"], 2);
        assert_eq!(json["sub_queries_failed"], 1);
        assert_eq!(json["elapsed_ms"], 1234);
    }

    #[test]
    fn test_deep_research_response_serialization() {
        let resp = DeepResearchResponse {
            query: "test query".to_string(),
            sub_queries: vec![SubQuery {
                id: 0,
                text: "test query".to_string(),
                source: "original",
            }],
            results: vec![],
            evidence_chains: vec![],
            graph_context: None,
            stats: ResearchStats {
                sub_queries_total: 1,
                sub_queries_completed: 1,
                sub_queries_failed: 0,
                sub_queries_timed_out: 0,
                unique_memories_found: 0,
                evidence_chains_found: 0,
                elapsed_ms: 42,
            },
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "test query");
        assert!(json["sub_queries"].is_array());
        assert!(json["results"].is_array());
        assert!(json["evidence_chains"].is_array());
        assert_eq!(json["stats"]["elapsed_ms"], 42);
    }

    // ---- GAP-07 regression: different sub-queries produce distinct embeddings ----
    // Distinct sub-query texts are the prerequisite for distinct embeddings (and
    // therefore distinct search results), so decomposition must hand each
    // sub-query its own text verbatim.
    #[test]
    fn test_distinct_sub_queries_produce_distinct_texts() {
        let queries = [
            "authentication design decisions",
            "deployment configuration and infrastructure",
        ];

        // Semicolon splitting must preserve each part exactly (the inner "and"
        // of the second part must not cause a further split).
        let decomposed = decompose_query(&queries.join("; "), 7);
        assert_eq!(decomposed.len(), 2);
        assert_eq!(decomposed, queries.to_vec());
        assert_ne!(decomposed[0], decomposed[1]);
    }

    // ---- GAP-08/11 regression: rrf_fuse integration via fusion module ----
    #[test]
    fn test_rrf_fuse_via_fusion_module() {
        let knn_ids: Vec<i64> = vec![1, 2, 3];
        let fts_ids: Vec<i64> = vec![2, 1, 4];
        let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);

        // Items appearing in both lists must score higher than items in only one list.
        let score_1 = scores[&1];
        let score_2 = scores[&2];
        let score_3 = scores[&3]; // knn only, rank 3
        let score_4 = scores[&4]; // fts only, rank 3

        assert!(
            score_1 > score_3,
            "id 1 (both lists) must beat id 3 (knn-only rank 3)"
        );
        assert!(
            score_2 > score_4,
            "id 2 (both lists) must beat id 4 (fts-only rank 3)"
        );
    }

    // ---- GAP-09/10 regression: evidence chains must be directed paths ----
    #[test]
    fn test_evidence_chain_has_from_to_and_path() {
        let chain = EvidenceChain {
            from: "auth-module".to_string(),
            to: "jwt-service".to_string(),
            path: vec![
                EvidenceNode {
                    entity: "auth-module".to_string(),
                    relation: None,
                    weight: None,
                },
                EvidenceNode {
                    entity: "token-validator".to_string(),
                    relation: Some("depends-on".to_string()),
                    weight: Some(0.9),
                },
                EvidenceNode {
                    entity: "jwt-service".to_string(),
                    relation: Some("uses".to_string()),
                    weight: Some(0.8),
                },
            ],
            total_weight: 0.72,
            depth: 3,
            sub_query_ids: vec![0],
        };

        let json = serde_json::to_value(&chain).expect("serialization failed");
        assert!(
            json["from"].is_string(),
            "evidence chain must have 'from' field"
        );
        assert!(
            json["to"].is_string(),
            "evidence chain must have 'to' field"
        );
        assert!(
            json["path"].is_array(),
            "evidence chain must have 'path' array"
        );
        assert_eq!(json["path"].as_array().unwrap().len(), 3);
        assert!(json["total_weight"].is_number(), "must have total_weight");
        assert_eq!(json["depth"], 3);
    }

    // ---- GAP-10 regression: reconstruct_path returns correct node order ----
    #[test]
    fn test_reconstruct_path_root_to_target_order() {
        // Build a simple chain: entity 10 (seed) -> entity 20 -> entity 30 (target)
        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
        let mut predecessor: PredecessorMap = std::collections::HashMap::new();
        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
        predecessor.insert(30, (20, "uses".to_string(), 0.8));
        let mut entity_names: crate::hash::AHashMap<i64, String> = crate::hash::AHashMap::default();
        entity_names.insert(10, "seed-entity".to_string());
        entity_names.insert(20, "middle-entity".to_string());
        entity_names.insert(30, "target-entity".to_string());

        let result = reconstruct_path(30, &seed_set, &predecessor, &entity_names);
        assert!(result.is_some(), "path must be reconstructed");
        let (nodes, weight) = result.unwrap();
        // Path must be [seed, middle, target]
        assert_eq!(nodes.len(), 3);
        assert_eq!(nodes[0].entity, "seed-entity");
        assert_eq!(nodes[1].entity, "middle-entity");
        assert_eq!(nodes[2].entity, "target-entity");
        // total_weight = 0.9 * 0.8
        assert!((weight - 0.72).abs() < 1e-6);
    }

    // ---- GAP-09/10 regression: the shortest real chain (one hop) has two nodes ----
    // The sub-query search drops paths with fewer than two nodes, so a single
    // seed -> neighbour hop must reconstruct to exactly [seed, neighbour] and
    // survive that guard.
    #[test]
    fn test_reconstruct_path_single_hop_yields_two_nodes() {
        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
        let mut predecessor: PredecessorMap = std::collections::HashMap::new();
        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
        let mut entity_names: crate::hash::AHashMap<i64, String> = crate::hash::AHashMap::default();
        entity_names.insert(10, "seed-entity".to_string());
        entity_names.insert(20, "neighbour-entity".to_string());

        let (nodes, weight) = reconstruct_path(20, &seed_set, &predecessor, &entity_names)
            .expect("one-hop path must be reconstructed");
        assert_eq!(nodes.len(), 2, "one hop must yield [seed, neighbour]");
        assert_eq!(nodes[0].entity, "seed-entity");
        assert_eq!(nodes[1].entity, "neighbour-entity");
        assert!((weight - 0.9).abs() < 1e-6);
    }

    // ---- GAP-09 regression: evidence chains must NOT be present for 1-hop trivial pairs ----
    // This mirrors the `path_nodes.len() < 2` guard in the sub-query search as a
    // specification of the rule; it filters a locally built chain rather than
    // driving the production code path.
    #[test]
    fn test_evidence_chains_single_hop_filtered_out() {
        // A chain of depth 1 (only root node) should be discarded.
        let chain = EvidenceChain {
            from: "a".to_string(),
            to: "a".to_string(),
            path: vec![EvidenceNode {
                entity: "a".to_string(),
                relation: None,
                weight: None,
            }],
            total_weight: 1.0,
            depth: 1,
            sub_query_ids: vec![0],
        };
        // Simulate the filter: retain chains with depth >= 2.
        let chains = vec![chain];
        let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
        assert!(retained.is_empty(), "depth-1 chains must be filtered out");
    }

    // ---- GAP-17 regression: bfs_with_predecessors honours max_neighbors_per_hop ----
    #[test]
    fn test_bfs_with_predecessors_respects_neighbor_cap() {
        use rusqlite::Connection;

        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related'
             );",
        )
        .unwrap();

        // Seed entity 1 has 5 neighbours.
        for target in 2i64..=6 {
            conn.execute(
                "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
                rusqlite::params![1i64, target, 1.0f64],
            )
            .unwrap();
        }

        // Without cap: the depth map holds the seed plus all 5 neighbours.
        // Compare full lengths (as the capped case does) so an unexpectedly
        // empty result fails the assertion instead of underflowing `len() - 1`.
        let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
        assert_eq!(
            depth_uncapped.len(),
            6,
            "uncapped must discover all 5 neighbours (plus seed)"
        );

        // With cap=2: only top-2 neighbours (by weight; all equal here so first 2 returned).
        let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
        // seed + 2 neighbours = 3 entries.
        assert_eq!(
            depth_capped.len(),
            3,
            "capped to 2 must yield seed + 2 neighbours"
        );
    }
}