Skip to main content

sqlite_graphrag/commands/
deep_research.rs

1//! Handler for the `deep-research` CLI subcommand.
2//!
3//! Orchestrates parallel multi-hop GraphRAG search via query decomposition.
4//! The workload is I/O-bound (SQLite WAL reads), so tokio is used instead of
5//! rayon. Each sub-query opens its own read-only connection.
6
7use crate::errors::AppError;
8use crate::graph::{
9    bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
10};
11use crate::output;
12use crate::paths::AppPaths;
13use crate::storage::connection::open_ro;
14use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
15use crate::storage::{entities, memories};
16
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use tokio::sync::Semaphore;
21use tokio::task::JoinSet;
22
/// Arguments for the `deep-research` subcommand.
// NOTE(review): clap's derive macro can turn `///` comments on this struct and its
// fields into help text, so maintenance notes below deliberately use plain `//`
// comments and leave every `///` line untouched.
#[derive(clap::Args)]
#[command(
    about = "Deep parallel multi-hop GraphRAG research via query decomposition",
    after_long_help = "EXAMPLES:\n  \
        # Basic deep research\n  \
        sqlite-graphrag deep-research \"auth architecture decisions\"\n\n  \
        # With custom parameters\n  \
        sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n  \
        # Include full memory bodies in output\n  \
        sqlite-graphrag deep-research \"auth\" --with-bodies\n\n  \
        # Tune RRF and graph scoring\n  \
        sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
    /// Research query to decompose and search.
    // Positional argument; `run_async` rejects it when it is empty after trimming.
    #[arg(value_name = "QUERY", help = "Research query to decompose and search")]
    pub query: String,
    /// Results per sub-query (Recall@20 captures 95%+ relevant hits).
    // Forwarded to every sub-query as `k`.
    #[arg(
        long,
        short,
        default_value_t = 20,
        help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
    )]
    pub k: usize,
    /// Maximum sub-queries from decomposition (covers complex multi-hop queries).
    // Passed to `decompose_query` as the cap on the number of parts.
    #[arg(
        long,
        default_value_t = 7,
        help = "Maximum sub-queries (covers complex multi-hop queries)"
    )]
    pub max_sub_queries: usize,
    /// Multi-hop graph traversal depth (sweet spot: 2-3 hops).
    #[arg(
        long,
        default_value_t = 3,
        help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
    )]
    pub max_hops: usize,
    /// Minimum edge weight for graph traversal.
    #[arg(
        long,
        default_value_t = 0.3,
        help = "Minimum edge weight for graph traversal"
    )]
    pub min_weight: f64,
    /// Maximum concurrent sub-queries (default: min(cpus, 8)).
    // When absent, `run_async` uses min(available parallelism, 8); either way the
    // value is then clamped to the number of sub-queries and to at least 1.
    #[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
    pub max_concurrency: Option<usize>,
    /// Timeout per sub-query in seconds.
    // Converted with `Duration::from_secs` in `run_async`.
    #[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
    pub timeout: u64,
    /// Include full memory bodies in results.
    // When false, `DeepResult::body` is `None` and is omitted from the JSON.
    #[arg(
        long,
        default_value_t = false,
        help = "Include full memory bodies in results"
    )]
    pub with_bodies: bool,
    /// Maximum results after deduplication.
    // Applied as a truncation of the score-ranked, merged hit list.
    #[arg(
        long,
        default_value_t = 50,
        help = "Maximum results after deduplication"
    )]
    pub max_results: usize,
    /// RRF k parameter controlling score smoothing (higher = less weight on top ranks).
    #[arg(
        long,
        default_value_t = 60.0,
        help = "RRF k parameter (higher = less weight on top ranks)"
    )]
    pub rrf_k: f64,
    /// Decay factor applied to graph scores per hop (score = seed_score * decay^hop).
    #[arg(
        long,
        default_value_t = 0.7,
        help = "Graph score decay factor per hop (0.0-1.0)"
    )]
    pub graph_decay: f64,
    /// Minimum score threshold for graph-expanded results (filters noise).
    #[arg(
        long,
        default_value_t = 0.05,
        help = "Minimum score threshold for graph-expanded results"
    )]
    pub graph_min_score: f64,
    /// Limit top-k neighbours followed per entity per hop (None = unlimited).
    #[arg(
        long,
        help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
    )]
    pub max_neighbors_per_hop: Option<usize>,
    /// Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global).
    // Resolved through `crate::namespace::resolve_namespace` in `run_async`.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    /// Research mode: `none` (local heuristic, default), `claude-code`, `codex` (v1.1.0).
    // The value parser currently accepts only "none" and the flag is hidden from help.
    #[arg(long, default_value = "none", value_parser = ["none"], hide = true)]
    pub mode: String,
    /// Maximum LLM cost in USD (effective with --mode claude-code/codex, reserved for v1.1.0).
    // `run_async` only logs a warning when this is set while `mode` is "none".
    #[arg(
        long,
        value_name = "USD",
        help = "Max LLM cost in USD (effective with --mode claude-code/codex)"
    )]
    pub max_cost_usd: Option<f64>,
    /// JSON output (always on, kept for consistency).
    // Not read by `run_async`; the response is always emitted as JSON.
    #[arg(long, hide = true)]
    pub json: bool,
    /// Database path.
    // Passed to `AppPaths::resolve`; may also come from SQLITE_GRAPHRAG_DB_PATH.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Shared daemon options; `autostart_daemon` is forwarded to the embedding call.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
142
/// One sub-query produced by query decomposition, as serialized in the response.
#[derive(Serialize)]
struct SubQuery {
    /// Zero-based position of the sub-query in the decomposed list.
    id: usize,
    /// Text that is embedded and searched for this sub-query.
    text: String,
    /// `"original"` when decomposition yielded a single query, otherwise `"decomposed"`.
    source: &'static str,
}
149
/// One ranked memory in the final response, merged across all sub-queries.
#[derive(Serialize)]
struct DeepResult {
    /// Memory name, read back from storage by memory id.
    name: String,
    /// Best score this memory received from any sub-query.
    score: f64,
    /// Source label of the best-scoring hit (`"hybrid"`, `"knn"`, `"fts"` or `"graph"`).
    source: String,
    /// Ids of every sub-query that returned this memory.
    sub_query_ids: Vec<usize>,
    /// Short excerpt of the memory body (first 300 characters for fused search hits).
    snippet: String,
    /// Full memory body; only populated when `--with-bodies` is set, otherwise omitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    body: Option<String>,
    /// `None` for direct search hits; presumably the hop count for graph-expanded
    /// hits (NOTE(review): confirm against the graph-expansion code).
    hop_distance: Option<usize>,
}
161
/// A node in a reconstructed evidence path.
#[derive(Serialize, Clone)]
struct EvidenceNode {
    /// Entity name, or `entity-{id}` when the id has no known name.
    entity: String,
    /// Relation on the edge that leads into this node; `None` (and omitted from
    /// JSON) for the seed node at the start of the path.
    #[serde(skip_serializing_if = "Option::is_none")]
    relation: Option<String>,
    /// Weight of the edge that leads into this node; `None` for the seed node.
    #[serde(skip_serializing_if = "Option::is_none")]
    weight: Option<f64>,
}
171
/// A directed evidence chain reconstructed from BFS predecessors.
///
/// Fields:
/// - `from`: name of the seed (source) entity.
/// - `to`: name of the terminal (target) entity.
/// - `path`: ordered list of nodes leading from `from` to `to`.
/// - `total_weight`: product of edge weights along the path.
/// - `depth`: length of the chain; chains with `depth < 2` are dropped from the
///   response as trivial (presumably a hop count — confirm where chains are built).
/// - `sub_query_ids`: which sub-queries produced this chain.
#[derive(Serialize)]
struct EvidenceChain {
    from: String,
    to: String,
    path: Vec<EvidenceNode>,
    total_weight: f64,
    depth: usize,
    sub_query_ids: Vec<usize>,
}
189
/// Execution statistics reported alongside the results.
#[derive(Serialize)]
struct ResearchStats {
    /// Number of sub-queries produced by decomposition.
    sub_queries_total: usize,
    /// Sub-queries that returned a result set.
    sub_queries_completed: usize,
    /// Sub-queries that failed for a reason other than a timeout, including
    /// tasks that panicked or were cancelled.
    sub_queries_failed: usize,
    /// Sub-queries that were reported as timed out.
    sub_queries_timed_out: usize,
    /// Number of entries in `results` (after truncation to `--max-results`).
    unique_memories_found: usize,
    /// Number of evidence chains kept in the response.
    evidence_chains_found: usize,
    /// Wall-clock time from the start of `run_async` to response assembly, in ms.
    elapsed_ms: u64,
}
200
/// An entity whose name matches one of the result memories' names.
#[derive(Serialize)]
struct GraphContextEntity {
    /// Entity name (identical to the matching result's name).
    name: String,
    /// Value of the entity's `type` column; falls back to `"concept"` when the
    /// column is NULL or the lookup fails.
    entity_type: String,
    /// Number of relationship rows in which the entity is source or target
    /// (0 when the count query fails).
    degree: u32,
}
207
/// A relationship whose two endpoints are both graph-context entities.
#[derive(Serialize)]
struct GraphContextRel {
    /// Name of the source entity.
    from: String,
    /// Name of the target entity.
    to: String,
    /// Relation label stored on the edge.
    relation: String,
    /// Edge weight stored on the relationship row.
    weight: f64,
}
215
/// Graph neighbourhood of the result set: matched entities plus the
/// relationships that connect them to each other.
#[derive(Serialize)]
struct GraphContext {
    /// Entities found for the result names; never empty when a context is emitted.
    entities: Vec<GraphContextEntity>,
    /// Relationships among those entities (the query is capped at 50 rows); may be empty.
    relationships: Vec<GraphContextRel>,
}
221
/// Top-level JSON document emitted by the `deep-research` subcommand.
#[derive(Serialize)]
struct DeepResearchResponse {
    /// The original query string as given on the command line.
    query: String,
    /// Sub-queries derived from `query`.
    sub_queries: Vec<SubQuery>,
    /// Merged, score-ranked memories (at most `--max-results`).
    results: Vec<DeepResult>,
    /// Evidence chains, sorted by `total_weight` descending.
    evidence_chains: Vec<EvidenceChain>,
    /// Omitted from JSON when no result name maps to an entity.
    #[serde(skip_serializing_if = "Option::is_none")]
    graph_context: Option<GraphContext>,
    /// Counters and timing for this run.
    stats: ResearchStats,
}
232
/// Aggregated hit data: (score, source_label, snippet, body, hop_distance, sub_query_ids).
///
/// Tuple positions:
/// - `.0` best score seen for the memory across sub-queries
/// - `.1` source label of that best-scoring hit
/// - `.2` snippet of that hit
/// - `.3` full body of that hit
/// - `.4` hop distance of that hit (`None` for direct search hits)
/// - `.5` ids of all sub-queries that returned the memory
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
235
/// Intermediate result from a single sub-query execution.
///
/// Produced by `execute_sub_query` and consumed by the merge step in `run_async`.
struct SubQueryResult {
    /// Id of the sub-query that produced these hits (its index in the decomposed list).
    sub_query_id: usize,
    /// (memory_id, score, source_label, snippet, body, hop_distance)
    hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
    /// Evidence chains reconstructed from BFS.
    chains: Vec<EvidenceChain>,
}
244
245/// Sync entry point — builds a tokio runtime for the async fan-out.
246pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
247    let rt = tokio::runtime::Builder::new_multi_thread()
248        .worker_threads(2)
249        .enable_all()
250        .build()
251        .map_err(|e| AppError::Internal(anyhow::anyhow!("failed to build tokio runtime: {e}")))?;
252    rt.block_on(run_async(args))
253}
254
255/// Main async logic: decompose, fan-out, assemble, emit JSON.
256async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
257    let start = std::time::Instant::now();
258
259    if args.query.trim().is_empty() {
260        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
261    }
262
263    if args.max_cost_usd.is_some() && args.mode == "none" {
264        tracing::warn!("--max-cost-usd has no effect without --mode claude-code/codex");
265    }
266
267    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
268    let paths = AppPaths::resolve(args.db.as_deref())?;
269    crate::storage::connection::ensure_db_ready(&paths)?;
270
271    // Phase 1: Query decomposition (sync, pure logic).
272    let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
273    let sub_queries: Vec<SubQuery> = sub_query_texts
274        .iter()
275        .enumerate()
276        .map(|(i, text)| SubQuery {
277            id: i,
278            text: text.clone(),
279            source: if sub_query_texts.len() == 1 {
280                "original"
281            } else {
282                "decomposed"
283            },
284        })
285        .collect();
286
287    // GAP-07 FIX: compute ONE embedding PER sub-query text (sequential — daemon serialises).
288    // The previous code used a single embedding for args.query shared across all sub-queries,
289    // making decomposition cosmetic.  We now build a Vec<Arc<Vec<f32>>> indexed by sub-query.
290    output::emit_progress_i18n(
291        "Computing per-sub-query embeddings...",
292        "Calculando embeddings por sub-consulta...",
293    );
294    let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
295    for sq_text in &sub_query_texts {
296        let emb = crate::daemon::embed_query_or_local(
297            &paths.models,
298            sq_text,
299            args.daemon.autostart_daemon,
300        )?;
301        sub_embeddings.push(Arc::new(emb));
302    }
303
304    // Phase 2: Fan-out — parallel sub-query execution.
305    let cpu_count = std::thread::available_parallelism()
306        .map(|n| n.get())
307        .unwrap_or(4);
308    let permits = args
309        .max_concurrency
310        .unwrap_or_else(|| cpu_count.min(8))
311        .min(sub_queries.len())
312        .max(1);
313    let semaphore = Arc::new(Semaphore::new(permits));
314    let timeout_dur = std::time::Duration::from_secs(args.timeout);
315
316    let mut join_set: JoinSet<Result<SubQueryResult, (usize, String)>> = JoinSet::new();
317
318    for (idx, sq_text) in sub_query_texts.iter().enumerate() {
319        let sem = Arc::clone(&semaphore);
320        // GAP-07 FIX: pass embedding for THIS specific sub-query.
321        let emb = Arc::clone(&sub_embeddings[idx]);
322        let ns = namespace.clone();
323        let db_path = paths.db.clone();
324        let query_text = sq_text.clone();
325        let k = args.k;
326        let max_hops = args.max_hops;
327        let min_weight = args.min_weight;
328        let rrf_k = args.rrf_k;
329        let graph_decay = args.graph_decay;
330        let graph_min_score = args.graph_min_score;
331        let max_neighbors_per_hop = args.max_neighbors_per_hop;
332
333        join_set.spawn(async move {
334            let _permit = sem
335                .acquire_owned()
336                .await
337                .map_err(|e| (idx, format!("semaphore closed: {e}")))?;
338
339            // Dereference the Arc to obtain a &[f32] slice for the sync function.
340            let result = tokio::time::timeout(timeout_dur, async move {
341                execute_sub_query(
342                    idx,
343                    &query_text,
344                    emb.as_slice(),
345                    &ns,
346                    &db_path,
347                    k,
348                    max_hops,
349                    min_weight,
350                    rrf_k,
351                    graph_decay,
352                    graph_min_score,
353                    max_neighbors_per_hop,
354                )
355            })
356            .await;
357
358            match result {
359                Ok(inner) => inner.map_err(|e| (idx, e)),
360                Err(_) => Err((idx, "timeout".to_string())),
361            }
362        });
363    }
364
365    // Collect results incrementally.
366    let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
367    let mut failed_count = 0usize;
368    let mut timed_out_count = 0usize;
369
370    while let Some(join_result) = join_set.join_next().await {
371        match join_result {
372            Ok(Ok(sqr)) => sub_query_results.push(sqr),
373            Ok(Err((_idx, reason))) => {
374                if reason == "timeout" {
375                    timed_out_count += 1;
376                } else {
377                    failed_count += 1;
378                }
379                tracing::warn!(sub_query_id = _idx, reason = %reason, "sub-query failed");
380            }
381            Err(join_err) => {
382                failed_count += 1;
383                if join_err.is_panic() {
384                    tracing::error!("sub-query task panicked: {join_err}");
385                } else {
386                    tracing::warn!("sub-query task cancelled: {join_err}");
387                }
388            }
389        }
390    }
391
392    // Phase 3: Evidence assembly — merge, dedup, rank.
393    // Aggregate hits: memory_id -> (best_score, source, snippet, body, hop_distance, sub_query_ids)
394    let mut merged: HashMap<i64, MergedHit> = HashMap::new();
395
396    for sqr in &sub_query_results {
397        for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
398            let entry = merged.entry(*mem_id).or_insert_with(|| {
399                (
400                    *score,
401                    source.clone(),
402                    snippet.clone(),
403                    body.clone(),
404                    *hop,
405                    Vec::new(),
406                )
407            });
408            // Keep best score.
409            if *score > entry.0 {
410                entry.0 = *score;
411                entry.1 = source.clone();
412                entry.2 = snippet.clone();
413                entry.3 = body.clone();
414                entry.4 = *hop;
415            }
416            if !entry.5.contains(&sqr.sub_query_id) {
417                entry.5.push(sqr.sub_query_id);
418            }
419        }
420    }
421
422    // Resolve memory names for merged results.
423    let conn = open_ro(&paths.db)?;
424    let mut results: Vec<DeepResult> = Vec::with_capacity(merged.len().min(args.max_results));
425
426    // Sort by score descending.
427    let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
428    ranked.sort_by(|a, b| {
429        b.1 .0
430            .partial_cmp(&a.1 .0)
431            .unwrap_or(std::cmp::Ordering::Equal)
432    });
433    ranked.truncate(args.max_results);
434
435    for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
436        let name = match memories::read_full(&conn, mem_id)? {
437            Some(row) => row.name,
438            None => continue,
439        };
440        results.push(DeepResult {
441            name,
442            score,
443            source,
444            sub_query_ids: sq_ids,
445            snippet,
446            body: if args.with_bodies { Some(body) } else { None },
447            hop_distance: hop,
448        });
449    }
450
451    // GAP-09/10 FIX: Collect evidence chains from reconstructed BFS paths.
452    // The old code appended flat node pairs from a global SELECT; now each
453    // sub-query returns directed EvidenceChain structs (from, to, path).
454    let completed_count = sub_query_results.len();
455    let mut evidence_chains: Vec<EvidenceChain> = Vec::new();
456    let mut seen_chain_keys: HashSet<String> = HashSet::new();
457
458    for sqr in sub_query_results {
459        for chain in sqr.chains {
460            // Deduplicate chains by (from, to) pair.
461            let key = format!("{}->{}", chain.from, chain.to);
462            if seen_chain_keys.insert(key) {
463                evidence_chains.push(chain);
464            }
465        }
466    }
467
468    // Sort evidence chains by total_weight descending, discard single-hop trivial chains.
469    evidence_chains.retain(|c| c.depth >= 2);
470    evidence_chains.sort_by(|a, b| {
471        b.total_weight
472            .partial_cmp(&a.total_weight)
473            .unwrap_or(std::cmp::Ordering::Equal)
474    });
475
476    let unique_memories = results.len();
477    let evidence_count = evidence_chains.len();
478
479    // MEDIUM-01b: Build graph_context with entities and relationships from result memories.
480    let graph_context = if !results.is_empty() {
481        let result_names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
482        let mut ctx_entities: Vec<GraphContextEntity> = Vec::new();
483        let mut ctx_rels: Vec<GraphContextRel> = Vec::new();
484        let mut seen_entity_ids: HashSet<i64> = HashSet::new();
485
486        for name in &result_names {
487            if let Ok(Some(eid)) = entities::find_entity_id(&conn, &namespace, name) {
488                if seen_entity_ids.insert(eid) {
489                    let etype: String = conn
490                        .query_row(
491                            "SELECT COALESCE(type,'concept') FROM entities WHERE id = ?1",
492                            rusqlite::params![eid],
493                            |r| r.get(0),
494                        )
495                        .unwrap_or_else(|_| "concept".to_string());
496                    let degree: u32 = conn
497                        .query_row(
498                            "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
499                            rusqlite::params![eid],
500                            |r| r.get(0),
501                        )
502                        .unwrap_or(0);
503                    ctx_entities.push(GraphContextEntity {
504                        name: name.to_string(),
505                        entity_type: etype,
506                        degree,
507                    });
508                }
509            }
510        }
511
512        let entity_ids: Vec<i64> = seen_entity_ids.iter().copied().collect();
513        if entity_ids.len() >= 2 {
514            let placeholders: String = entity_ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
515            let sql = format!(
516                "SELECT s.name, t.name, r.relation, r.weight \
517                 FROM relationships r \
518                 JOIN entities s ON s.id = r.source_id \
519                 JOIN entities t ON t.id = r.target_id \
520                 WHERE r.source_id IN ({placeholders}) AND r.target_id IN ({placeholders}) \
521                 LIMIT 50"
522            );
523            if let Ok(mut stmt) = conn.prepare(&sql) {
524                let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
525                for id in &entity_ids {
526                    params.push(Box::new(*id));
527                }
528                for id in &entity_ids {
529                    params.push(Box::new(*id));
530                }
531                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
532                    params.iter().map(|p| p.as_ref()).collect();
533                if let Ok(rows) = stmt.query_map(param_refs.as_slice(), |r| {
534                    Ok((
535                        r.get::<_, String>(0)?,
536                        r.get::<_, String>(1)?,
537                        r.get::<_, String>(2)?,
538                        r.get::<_, f64>(3)?,
539                    ))
540                }) {
541                    for row in rows.flatten() {
542                        ctx_rels.push(GraphContextRel {
543                            from: row.0,
544                            to: row.1,
545                            relation: row.2,
546                            weight: row.3,
547                        });
548                    }
549                }
550            }
551        }
552
553        if ctx_entities.is_empty() {
554            None
555        } else {
556            Some(GraphContext {
557                entities: ctx_entities,
558                relationships: ctx_rels,
559            })
560        }
561    } else {
562        None
563    };
564
565    tracing::debug!(
566        total_results = results.len(),
567        total_chains = evidence_chains.len(),
568        "assembly complete"
569    );
570
571    // Phase 4: JSON output.
572    output::emit_json(&DeepResearchResponse {
573        query: args.query,
574        sub_queries,
575        results,
576        evidence_chains,
577        graph_context,
578        stats: ResearchStats {
579            sub_queries_total: sub_query_texts.len(),
580            sub_queries_completed: completed_count,
581            sub_queries_failed: failed_count,
582            sub_queries_timed_out: timed_out_count,
583            unique_memories_found: unique_memories,
584            evidence_chains_found: evidence_count,
585            elapsed_ms: start.elapsed().as_millis() as u64,
586        },
587    })?;
588
589    Ok(())
590}
591
/// Heuristic query decomposition: splits by relational phrases, semicolons,
/// commas and conjunctions, and falls back to word pairs for multi-word queries.
///
/// Strategies are tried in order; the first one that produces parts wins:
/// 1. Relational phrases (" related to ", " caused by ", ...), matched
///    ASCII-case-insensitively. Phrases are tried in list order and each one
///    splits the remaining right-hand text.
/// 2. Semicolons.
/// 3. Commas and the conjunctions " and " / " AND " / " e " / " E ".
/// 4. For queries with at least three words longer than two bytes: the full
///    query, its first two such words, and its last two such words.
///
/// Returns at most `max` parts and never fewer than one: a query that cannot be
/// decomposed (or an empty query) is returned unchanged as the single element,
/// and `max == 0` is treated as 1 so callers always get something to search.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    // Relational connectors, most specific first. All ASCII and lowercase.
    const RELATIONAL: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    let mut parts: Vec<String> = Vec::new();

    // Split by relational phrases first (most specific).
    let mut text = query.to_string();
    let mut did_relational_split = false;
    for &phrase in &RELATIONAL {
        // ASCII-only lowercasing keeps byte offsets identical to `text`, so the
        // match position can be used to slice the original string safely. Full
        // Unicode lowercasing may change byte lengths and yield offsets that are
        // wrong or fall inside a multi-byte character.
        let lower = text.to_ascii_lowercase();
        if let Some(pos) = lower.find(phrase) {
            let left = text[..pos].trim().to_string();
            let right = text[pos + phrase.len()..].trim().to_string();
            if !left.is_empty() {
                parts.push(left);
            }
            if !right.is_empty() {
                text = right;
            }
            did_relational_split = true;
        }
    }
    if did_relational_split && !text.is_empty() {
        parts.push(text);
    }

    // If no relational split, try conjunctions and delimiters.
    if parts.is_empty() {
        if query.contains(';') {
            // Split by semicolons first.
            parts.extend(
                query
                    .split(';')
                    .map(str::trim)
                    .filter(|p| !p.is_empty())
                    .map(str::to_string),
            );
        } else {
            // Replace " and " and " e " (Portuguese) with a comma, then split.
            let normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            if normalized.contains(',') {
                parts.extend(
                    normalized
                        .split(',')
                        .map(str::trim)
                        .filter(|p| !p.is_empty())
                        .map(str::to_string),
                );
            }
        }
    }

    // If still no split, try word-pair decomposition for multi-word queries.
    if parts.is_empty() {
        let words: Vec<&str> = query.split_whitespace().filter(|w| w.len() > 2).collect();
        if words.len() >= 3 {
            parts.push(query.to_string());
            parts.push(format!("{} {}", words[0], words[1]));
            parts.push(format!(
                "{} {}",
                words[words.len() - 2],
                words[words.len() - 1]
            ));
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }

    // Cap at max, but always keep at least one sub-query.
    parts.truncate(max.max(1));
    parts
}
686
687/// Reconstruct a directed path from `target_entity_id` back to a seed using the
688/// predecessor map built by BFS.  Returns the path nodes from root to target
689/// plus the accumulated edge weights.
690fn reconstruct_path(
691    target_id: i64,
692    seed_entity_ids: &HashSet<i64>,
693    predecessor: &PredecessorMap,
694    entity_names: &HashMap<i64, String>,
695) -> Option<(Vec<EvidenceNode>, f64)> {
696    let mut path_ids: Vec<(i64, Option<String>, Option<f64>)> = Vec::new();
697    let mut total_weight = 1.0_f64;
698    let mut current = target_id;
699
700    loop {
701        if seed_entity_ids.contains(&current) {
702            break;
703        }
704        let (parent, relation, weight) = predecessor.get(&current)?;
705        total_weight *= weight;
706        path_ids.push((current, Some(relation.clone()), Some(*weight)));
707        current = *parent;
708    }
709    // Push the seed entity (root).
710    path_ids.push((current, None, None));
711
712    // Reverse so path goes from seed → target.
713    path_ids.reverse();
714
715    let nodes: Vec<EvidenceNode> = path_ids
716        .into_iter()
717        .map(|(id, relation, weight)| EvidenceNode {
718            entity: entity_names
719                .get(&id)
720                .cloned()
721                .unwrap_or_else(|| format!("entity-{id}")),
722            relation,
723            weight,
724        })
725        .collect();
726
727    Some((nodes, total_weight))
728}
729
730/// Execute a single sub-query: hybrid search (KNN + FTS fused via RRF) + graph traversal.
731///
732/// GAP-07 fix: receives the embedding for THIS sub-query (not the shared original).
733/// GAP-08/11 fix: uses rrf_fuse() for proper score fusion instead of hardcoded 0.5.
734/// GAP-09/10 fix: builds directed evidence chains filtered to discovered entities.
735/// GAP-17: respects max_neighbors_per_hop cap in BFS.
736///
737/// Runs synchronously on a blocking thread (called from a tokio spawn context).
738/// Each call opens its own read-only SQLite connection to leverage WAL concurrency.
739#[allow(clippy::too_many_arguments)]
740fn execute_sub_query(
741    sub_query_id: usize,
742    query_text: &str,
743    embedding: &[f32],
744    namespace: &str,
745    db_path: &std::path::Path,
746    k: usize,
747    max_hops: usize,
748    min_weight: f64,
749    rrf_k: f64,
750    graph_decay: f64,
751    graph_min_score: f64,
752    max_neighbors_per_hop: Option<usize>,
753) -> Result<SubQueryResult, String> {
754    let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;
755
756    let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
757        Vec::with_capacity(k * 2);
758    let mut seen_ids: HashSet<i64> = HashSet::new();
759
760    // --- GAP-08/11 FIX: Use RRF fusion for KNN + FTS instead of hardcoded 0.5 ---
761
762    // 1. KNN vector search — collect ranked IDs.
763    let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
764        .map_err(|e| format!("knn_search failed: {e}"))?;
765    let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
766    tracing::debug!(sub_query_id, knn_count = knn_ids.len(), "KNN complete");
767
768    // Build distance map for score computation.
769    let knn_distance_map: HashMap<i64, f64> = knn_results
770        .iter()
771        .map(|(id, dist)| (*id, *dist as f64))
772        .collect();
773
774    // 2. FTS5 search — collect ranked IDs.
775    let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
776        Ok(rows) => rows,
777        Err(e) => {
778            tracing::warn!(
779                sub_query_id,
780                "FTS5 search failed, continuing with KNN only: {e}"
781            );
782            vec![]
783        }
784    };
785    let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();
786    tracing::debug!(sub_query_id, fts_count = fts_ids.len(), "FTS complete");
787
788    // 3. Fuse via RRF.
789    let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
790    let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
791
792    // 4. Sort fused results and build hits.
793    let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
794    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
795    fused.truncate(k * 2);
796    tracing::debug!(
797        sub_query_id,
798        fused_count = fused.len(),
799        "RRF fusion complete"
800    );
801
802    if fused.is_empty() && !knn_ids.is_empty() {
803        tracing::warn!(sub_query_id, knn_count = knn_ids.len(), fts_count = fts_ids.len(),
804            "RRF fusion returned 0 results despite KNN/FTS hits; consider lowering --graph-min-score");
805    }
806
807    for (memory_id, combined_score) in &fused {
808        if seen_ids.insert(*memory_id) {
809            let normalized = if max_possible > 0.0 {
810                combined_score / max_possible
811            } else {
812                0.0
813            };
814            let score = normalized.clamp(0.0, 1.0);
815            let in_knn = knn_distance_map.contains_key(memory_id);
816            let in_fts = fts_ids.contains(memory_id);
817            let source = match (in_knn, in_fts) {
818                (true, true) => "hybrid",
819                (true, false) => "knn",
820                (false, true) => "fts",
821                (false, false) => "graph",
822            };
823            if let Ok(Some(row)) = memories::read_full(&conn, *memory_id) {
824                let snippet: String = row.body.chars().take(300).collect();
825                hits.push((
826                    *memory_id,
827                    score,
828                    source.to_string(),
829                    snippet,
830                    row.body,
831                    None,
832                ));
833            }
834        }
835    }
836
837    // 5. Graph traversal from discovered memories.
838    // GAP-09/10 FIX: entity KNN also uses this sub-query's embedding.
839    let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
840    let mut chains: Vec<EvidenceChain> = Vec::new();
841
842    if !memory_ids.is_empty() && max_hops > 0 {
843        // Seed entities from KNN on entity vectors using THIS sub-query's embedding.
844        let entity_knn = entities::knn_search(&conn, embedding, namespace, 5).unwrap_or_default();
845        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
846
847        // HIGH-01 FIX: limit seeds to top-5 memories by score to prevent
848        // BFS from starting at every node when k >= total memories.
849        let top_seed_count = 5.min(memory_ids.len());
850        let top_memory_ids = &memory_ids[..top_seed_count];
851        let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
852        for &mem_id in top_memory_ids {
853            let mut stmt = conn
854                .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
855                .map_err(|e| format!("prepare failed: {e}"))?;
856            let ids: Vec<i64> = stmt
857                .query_map(rusqlite::params![mem_id], |r| r.get(0))
858                .map_err(|e| format!("query failed: {e}"))?
859                .filter_map(|r| r.ok())
860                .collect();
861            seed_entity_ids.extend(ids);
862        }
863        seed_entity_ids.sort_unstable();
864        seed_entity_ids.dedup();
865        tracing::debug!(
866            sub_query_id,
867            seed_count = seed_entity_ids.len(),
868            "seed entities collected"
869        );
870
871        let all_seed_ids: Vec<i64> = memory_ids
872            .iter()
873            .chain(entity_ids.iter())
874            .copied()
875            .collect();
876
877        // Graph traversal with hop scores.
878        if let Ok(graph_results) = traverse_from_memories_with_hops_capped(
879            &conn,
880            &all_seed_ids,
881            namespace,
882            min_weight,
883            max_hops as u32,
884            max_neighbors_per_hop,
885        ) {
886            // Build seed score map from RRF-fused scores for graph decay computation.
887            let seed_score_map: HashMap<i64, f64> = fused
888                .iter()
889                .map(|(id, s)| {
890                    let normalized = if max_possible > 0.0 {
891                        s / max_possible
892                    } else {
893                        0.0
894                    };
895                    (*id, normalized.clamp(0.0, 1.0))
896                })
897                .collect();
898
899            for (graph_mem_id, hop) in graph_results {
900                if seen_ids.insert(graph_mem_id) {
901                    // GAP-08/11 FIX: graph score = seed_score * decay^hop * edge_weight.
902                    // For the seed score, use the best score among the seed memories that
903                    // transitively reached this graph memory (approximate with the average
904                    // seed score since we don't track the exact path yet).
905                    let avg_seed_score: f64 = if seed_score_map.is_empty() {
906                        0.5
907                    } else {
908                        let sum: f64 = seed_score_map.values().sum();
909                        sum / seed_score_map.len() as f64
910                    };
911                    let graph_score =
912                        (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
913
914                    if graph_score < graph_min_score {
915                        continue;
916                    }
917
918                    if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
919                        let snippet: String = row.body.chars().take(300).collect();
920                        hits.push((
921                            graph_mem_id,
922                            graph_score,
923                            "graph".to_string(),
924                            snippet,
925                            row.body,
926                            Some(hop as usize),
927                        ));
928                    }
929                }
930            }
931        }
932
933        // GAP-09/10 FIX: Build directed evidence chains using BFS with predecessor map,
934        // filtered to entities discovered in this sub-query.
935        if !seed_entity_ids.is_empty() {
936            let (entity_depth, predecessor) = bfs_with_predecessors(
937                &conn,
938                &seed_entity_ids,
939                namespace,
940                min_weight,
941                max_hops as u32,
942                max_neighbors_per_hop,
943            )
944            .unwrap_or_default();
945
946            tracing::debug!(
947                sub_query_id,
948                bfs_nodes = entity_depth.len(),
949                predecessors = predecessor.len(),
950                "BFS complete"
951            );
952
953            let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();
954
955            // Collect entity IDs we need names for.
956            let all_entity_ids: Vec<i64> = entity_depth.keys().copied().collect();
957            let mut entity_names: HashMap<i64, String> = HashMap::new();
958            for &eid in &all_entity_ids {
959                let name_res: rusqlite::Result<String> = conn.query_row(
960                    "SELECT name FROM entities WHERE id = ?1",
961                    rusqlite::params![eid],
962                    |r| r.get(0),
963                );
964                if let Ok(name) = name_res {
965                    entity_names.insert(eid, name);
966                }
967            }
968
969            // Reconstruct a path for each non-seed entity that has a predecessor.
970            for (&target_id, &_hop) in &entity_depth {
971                if seed_entity_set.contains(&target_id) {
972                    continue;
973                }
974                if !predecessor.contains_key(&target_id) {
975                    continue;
976                }
977                if let Some((path_nodes, total_weight)) =
978                    reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
979                {
980                    if path_nodes.len() < 2 {
981                        continue;
982                    }
983                    let from = path_nodes
984                        .first()
985                        .map(|n| n.entity.clone())
986                        .unwrap_or_default();
987                    let to = path_nodes
988                        .last()
989                        .map(|n| n.entity.clone())
990                        .unwrap_or_default();
991                    let depth = path_nodes.len();
992                    chains.push(EvidenceChain {
993                        from,
994                        to,
995                        path: path_nodes,
996                        total_weight,
997                        depth,
998                        sub_query_ids: vec![sub_query_id],
999                    });
1000                }
1001            }
1002
1003            // Sort chains by total_weight descending and cap to avoid huge output.
1004            chains.sort_by(|a, b| {
1005                b.total_weight
1006                    .partial_cmp(&a.total_weight)
1007                    .unwrap_or(std::cmp::Ordering::Equal)
1008            });
1009            chains.truncate(20);
1010            tracing::debug!(
1011                sub_query_id,
1012                chains_count = chains.len(),
1013                "evidence chains built"
1014            );
1015        }
1016    }
1017
1018    Ok(SubQueryResult {
1019        sub_query_id,
1020        hits,
1021        chains,
1022    })
1023}
1024
1025// ────────────────────────────────────────────────────────────────────────────
1026// Re-export sub_query_results field initialisation for the stats counter.
1027// The field is moved out of run_async after the join loop; we need to shadow it.
1028// ────────────────────────────────────────────────────────────────────────────
1029
#[cfg(test)]
mod tests {
    //! Unit tests for query decomposition, response serialization, evidence-chain
    //! reconstruction and the neighbour cap of the predecessor-tracking BFS.

    use super::*;

    // An English "and" between two terms yields two sub-queries.
    #[test]
    fn test_decompose_and_conjunction() {
        let result = decompose_query("A and B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    // A query without any separator is returned unchanged as a single sub-query.
    #[test]
    fn test_decompose_no_split() {
        let result = decompose_query("simple query", 7);
        assert_eq!(result, vec!["simple query"]);
    }

    // Commas and "and" can be mixed in one query and produce three parts.
    #[test]
    fn test_decompose_three_parts() {
        let result = decompose_query("A, B and C", 7);
        assert_eq!(result, vec!["A", "B", "C"]);
    }

    // The Portuguese conjunction "e" splits just like "and".
    #[test]
    fn test_decompose_portuguese_conjunctions() {
        let result = decompose_query("A e B", 7);
        assert_eq!(result, vec!["A", "B"]);
    }

    // Ten comma-separated parts with a cap of 7 must never yield more than 7 sub-queries.
    #[test]
    fn test_decompose_max_cap() {
        let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
        let query = parts.join(", ");
        let result = decompose_query(&query, 7);
        assert!(
            result.len() <= 7,
            "expected at most 7 sub-queries, got {}",
            result.len()
        );
    }

    // An empty query is passed through as a single empty sub-query rather than dropped.
    #[test]
    fn test_decompose_empty_preserves_original() {
        let result = decompose_query("", 7);
        assert_eq!(result, vec![""]);
    }

    // Semicolons separate sub-queries; surrounding whitespace is trimmed from each part.
    #[test]
    fn test_decompose_semicolons() {
        let result = decompose_query("auth design; deployment config; logging", 7);
        assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
    }

    // The relational phrase "that caused" acts as a separator and is removed from the output.
    #[test]
    fn test_decompose_relational_phrase() {
        let result = decompose_query("auth that caused deployment failure", 7);
        assert_eq!(result, vec!["auth", "deployment failure"]);
    }

    // `SubQuery` serializes its three fields under their Rust field names.
    #[test]
    fn test_sub_query_serialization() {
        let sq = SubQuery {
            id: 0,
            text: "test query".to_string(),
            source: "original",
        };
        let json = serde_json::to_value(&sq).expect("serialization failed");
        assert_eq!(json["id"], 0);
        assert_eq!(json["text"], "test query");
        assert_eq!(json["source"], "original");
    }

    // A `DeepResult` with `body: None` must not emit a "body" key at all.
    #[test]
    fn test_deep_result_omits_body_when_none() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0],
            snippet: "snippet".to_string(),
            body: None,
            hop_distance: None,
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(!json.contains("\"body\""), "body must be omitted when None");
    }

    // A `DeepResult` with `body: Some(..)` emits both the key and its content.
    #[test]
    fn test_deep_result_includes_body_when_some() {
        let result = DeepResult {
            name: "test".to_string(),
            score: 0.9,
            source: "knn".to_string(),
            sub_query_ids: vec![0, 1],
            snippet: "snippet".to_string(),
            body: Some("full body content".to_string()),
            hop_distance: Some(2),
        };
        let json = serde_json::to_string(&result).expect("serialization failed");
        assert!(json.contains("\"body\""), "body must be present when Some");
        assert!(json.contains("full body content"));
    }

    // Optional `relation` / `weight` fields of an `EvidenceNode` are omitted when `None`.
    #[test]
    fn test_evidence_node_omits_none_fields() {
        let node = EvidenceNode {
            entity: "auth-module".to_string(),
            relation: None,
            weight: None,
        };
        let json = serde_json::to_string(&node).expect("serialization failed");
        assert!(
            !json.contains("\"relation\""),
            "relation must be omitted when None"
        );
        assert!(
            !json.contains("\"weight\""),
            "weight must be omitted when None"
        );
    }

    // Spot-checks that `ResearchStats` counters serialize under their field names.
    #[test]
    fn test_research_stats_serialization() {
        let stats = ResearchStats {
            sub_queries_total: 3,
            sub_queries_completed: 2,
            sub_queries_failed: 1,
            sub_queries_timed_out: 0,
            unique_memories_found: 10,
            evidence_chains_found: 2,
            elapsed_ms: 1234,
        };
        let json = serde_json::to_value(&stats).expect("serialization failed");
        assert_eq!(json["sub_queries_total"], 3);
        assert_eq!(json["sub_queries_completed"], 2);
        assert_eq!(json["sub_queries_failed"], 1);
        assert_eq!(json["elapsed_ms"], 1234);
    }

    // The top-level response exposes the query, three arrays and the nested stats object.
    #[test]
    fn test_deep_research_response_serialization() {
        let resp = DeepResearchResponse {
            query: "test query".to_string(),
            sub_queries: vec![SubQuery {
                id: 0,
                text: "test query".to_string(),
                source: "original",
            }],
            results: vec![],
            evidence_chains: vec![],
            graph_context: None,
            stats: ResearchStats {
                sub_queries_total: 1,
                sub_queries_completed: 1,
                sub_queries_failed: 0,
                sub_queries_timed_out: 0,
                unique_memories_found: 0,
                evidence_chains_found: 0,
                elapsed_ms: 42,
            },
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "test query");
        assert!(json["sub_queries"].is_array());
        assert!(json["results"].is_array());
        assert!(json["evidence_chains"].is_array());
        assert_eq!(json["stats"]["elapsed_ms"], 42);
    }

    // ---- GAP-07 regression: different sub-queries produce distinct embeddings ----
    // We test decompose_query returns texts that *would* produce distinct embeddings
    // (different text inputs → different embedding inputs → different search results).
    // No embedding is computed here; only the decomposed texts are compared.
    #[test]
    fn test_distinct_sub_queries_produce_distinct_texts() {
        let queries = [
            "authentication design decisions",
            "deployment configuration and infrastructure",
        ];
        // These two texts must be different strings (prerequisite for distinct embeddings).
        assert_ne!(queries[0], queries[1]);

        // decompose_query with semicolons must preserve distinct texts. The expected
        // length of 2 also pins that the "and" inside the second part does not cause
        // a further split when semicolons are present.
        let decomposed = decompose_query(
            "authentication design decisions; deployment configuration and infrastructure",
            7,
        );
        assert_eq!(decomposed.len(), 2);
        assert_ne!(decomposed[0], decomposed[1]);
    }

    // ---- GAP-08/11 regression: rrf_fuse integration via fusion module ----
    #[test]
    fn test_rrf_fuse_via_fusion_module() {
        use crate::storage::fusion::rrf_fuse;

        let knn_ids: Vec<i64> = vec![1, 2, 3];
        let fts_ids: Vec<i64> = vec![2, 1, 4];
        // Both ranked lists carry weight 1.0; 60.0 is the RRF k constant.
        let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);

        // Items appearing in both lists must score higher than items in only one list.
        let score_1 = scores[&1];
        let score_2 = scores[&2];
        let score_3 = scores[&3]; // knn only, rank 3
        let score_4 = scores[&4]; // fts only, rank 3

        assert!(
            score_1 > score_3,
            "id 1 (both lists) must beat id 3 (knn-only rank 3)"
        );
        assert!(
            score_2 > score_4,
            "id 2 (both lists) must beat id 4 (fts-only rank 3)"
        );
    }

    // ---- GAP-09/10 regression: evidence chains must be directed paths ----
    // Checks the serialized shape only: `from`, `to`, `path`, `total_weight`, `depth`.
    #[test]
    fn test_evidence_chain_has_from_to_and_path() {
        let chain = EvidenceChain {
            from: "auth-module".to_string(),
            to: "jwt-service".to_string(),
            path: vec![
                EvidenceNode {
                    entity: "auth-module".to_string(),
                    relation: None,
                    weight: None,
                },
                EvidenceNode {
                    entity: "token-validator".to_string(),
                    relation: Some("depends-on".to_string()),
                    weight: Some(0.9),
                },
                EvidenceNode {
                    entity: "jwt-service".to_string(),
                    relation: Some("uses".to_string()),
                    weight: Some(0.8),
                },
            ],
            total_weight: 0.72,
            depth: 3,
            sub_query_ids: vec![0],
        };

        let json = serde_json::to_value(&chain).expect("serialization failed");
        assert!(
            json["from"].is_string(),
            "evidence chain must have 'from' field"
        );
        assert!(
            json["to"].is_string(),
            "evidence chain must have 'to' field"
        );
        assert!(
            json["path"].is_array(),
            "evidence chain must have 'path' array"
        );
        assert_eq!(json["path"].as_array().unwrap().len(), 3);
        assert!(json["total_weight"].is_number(), "must have total_weight");
        assert_eq!(json["depth"], 3);
    }

    // ---- GAP-10 regression: reconstruct_path returns correct node order ----
    #[test]
    fn test_reconstruct_path_root_to_target_order() {
        // Build a simple chain: entity 10 (seed) -> entity 20 -> entity 30 (target).
        // The predecessor map is keyed by child and stores (parent, relation, weight).
        let seed_set: HashSet<i64> = [10i64].into_iter().collect();
        let mut predecessor: HashMap<i64, (i64, String, f64)> = HashMap::new();
        predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
        predecessor.insert(30, (20, "uses".to_string(), 0.8));
        let mut entity_names: HashMap<i64, String> = HashMap::new();
        entity_names.insert(10, "seed-entity".to_string());
        entity_names.insert(20, "middle-entity".to_string());
        entity_names.insert(30, "target-entity".to_string());

        let result = reconstruct_path(30, &seed_set, &predecessor, &entity_names);
        assert!(result.is_some(), "path must be reconstructed");
        let (nodes, weight) = result.unwrap();
        // Path must be [seed, middle, target]
        assert_eq!(nodes.len(), 3);
        assert_eq!(nodes[0].entity, "seed-entity");
        assert_eq!(nodes[1].entity, "middle-entity");
        assert_eq!(nodes[2].entity, "target-entity");
        // total_weight = 0.9 * 0.8 (compared with a tolerance because it is an f64 product)
        assert!((weight - 0.72).abs() < 1e-6);
    }

    // ---- GAP-09 regression: single-node (depth 1) evidence chains must be discarded ----
    // NOTE: this test applies a local restatement of the filter (`depth >= 2`) to a
    // hand-built chain; it does not call the production chain-building code.
    #[test]
    fn test_evidence_chains_single_hop_filtered_out() {
        // A chain of depth 1 (only root node) should be discarded.
        let chain = EvidenceChain {
            from: "a".to_string(),
            to: "a".to_string(),
            path: vec![EvidenceNode {
                entity: "a".to_string(),
                relation: None,
                weight: None,
            }],
            total_weight: 1.0,
            depth: 1,
            sub_query_ids: vec![0],
        };
        // Simulate the filter: retain chains with depth >= 2.
        let chains = vec![chain];
        let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
        assert!(retained.is_empty(), "depth-1 chains must be filtered out");
    }

    // ---- GAP-17 regression: bfs_with_predecessors honours max_neighbors_per_hop ----
    #[test]
    fn test_bfs_with_predecessors_respects_neighbor_cap() {
        use crate::graph::bfs_with_predecessors;
        use rusqlite::Connection;

        // Minimal in-memory schema: only the `relationships` table is created here.
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related'
             );",
        )
        .unwrap();

        // Seed entity 1 has 5 neighbours.
        for target in 2i64..=6 {
            conn.execute(
                "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
                rusqlite::params![1i64, target, 1.0f64],
            )
            .unwrap();
        }

        // Without cap: all 5 neighbours reached.
        let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
        // seed + 5 neighbours = 6 entries. The total is compared directly (instead of
        // `len() - 1`) so an unexpectedly empty map fails with this assertion message
        // rather than a usize subtraction overflow.
        assert_eq!(
            depth_uncapped.len(),
            6,
            "uncapped must discover all 5 neighbours (plus seed)"
        );

        // With cap=2: only top-2 neighbours (by weight; all equal here so first 2 returned).
        let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
        // seed + 2 neighbours = 3 entries.
        assert_eq!(
            depth_capped.len(),
            3,
            "capped to 2 must yield seed + 2 neighbours"
        );
    }
}