Skip to main content

zeph_memory/graph/
activation.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! SYNAPSE spreading activation retrieval over the entity graph.
5//!
6//! Implements the spreading activation algorithm from arXiv 2601.02744, adapted for
7//! the zeph-memory graph schema. Seeds are matched via fuzzy entity search; activation
8//! propagates hop-by-hop with:
9//! - Exponential decay per hop (`decay_lambda`)
10//! - Edge confidence weighting
11//! - Temporal recency weighting (reuses `GraphConfig.temporal_decay_rate`)
12//! - Lateral inhibition (nodes above `inhibition_threshold` stop receiving activation)
13//! - Per-hop pruning to enforce `max_activated_nodes` bound (SA-INV-04)
14//! - MAGMA edge type filtering via `edge_types` parameter
15
16use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity reached by spreading activation, together with its final score.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Database ID of the entity this record describes.
    pub entity_id: i64,
    /// Final activation score in `[0.0, 1.0]`.
    pub activation: f32,
    /// Hop on which the entity received its highest activation; seeds are `0`.
    pub depth: u32,
}
37
38/// A graph edge traversed during spreading activation, with its activation score.
39#[derive(Debug, Clone)]
40pub struct ActivatedFact {
41    /// The traversed edge.
42    pub edge: Edge,
43    /// Activation score of the source or target entity at time of traversal.
44    pub activation_score: f32,
45}
46
47pub use zeph_common::memory::SpreadingActivationParams;
48
49// ── HL-F5: HeLa-Mem spreading activation (#3346) ─────────────────────────────
50
51/// A graph edge surfaced by HL-F5 spreading activation (#3346), scored by
52/// `path_weight × max(cosine_query_to_endpoint, 0.0)`.
53///
54/// Mirrors [`ActivatedFact`] so callers can dispatch over a single
55/// `Vec<HelaFact>` ↔ `Vec<ActivatedFact>` ↔ `Vec<GraphFact>` shape at the
56/// strategy-selection site.
57#[derive(Debug, Clone)]
58pub struct HelaFact {
59    /// The edge by which the higher-scored endpoint was reached.
60    pub edge: Edge,
61    /// Final HL-F5 score: `path_weight × cosine_clamped`. Range: `[0.0, +∞)`.
62    pub score: f32,
63    /// BFS depth at which `edge` was traversed (`1..=spread_depth`).
64    /// `0` is reserved for the synthetic anchor edge in the isolated-anchor fallback.
65    pub depth: u32,
66    /// Multiplicative product of edge weights along the BFS path that reached
67    /// this edge's far endpoint. Range: `[0.0, +∞)`.
68    pub path_weight: f32,
69    /// Clamped cosine similarity of the far endpoint's entity embedding
70    /// to the query embedding, in `[0.0, 1.0]`. `None` when the endpoint
71    /// has no stored embedding (skipped from results in that case).
72    pub cosine: Option<f32>,
73}
74
75/// Parameters for HL-F5 spreading activation retrieval.
76///
77/// Build via [`Default`] and override individual fields:
78///
79/// ```rust
80/// use zeph_memory::graph::activation::HelaSpreadParams;
81///
82/// let params = HelaSpreadParams { spread_depth: 3, ..Default::default() };
83/// ```
84#[derive(Debug, Clone)]
85pub struct HelaSpreadParams {
86    /// BFS hops. Clamped to `[1, 6]` at runtime. Default: `2`.
87    pub spread_depth: u32,
88    /// MAGMA edge-type filter. Empty = all types. Default: `[]`.
89    pub edge_types: Vec<EdgeType>,
90    /// Soft upper bound on the visited-node set. Default: `200`.
91    pub max_visited: usize,
92    /// Per-step circuit breaker. Any internal step (anchor ANN, edges batch,
93    /// vectors batch) that exceeds this duration triggers an `Ok(Vec::new())`
94    /// fallback with a `WARN`. Default: `Some(8 ms)`.
95    pub step_budget: Option<std::time::Duration>,
96    /// Timeout for the initial query embedding call. `None` = no timeout.
97    /// Default: `Some(5 s)`.
98    pub embed_timeout: Option<std::time::Duration>,
99}
100
101impl Default for HelaSpreadParams {
102    fn default() -> Self {
103        Self {
104            spread_depth: 2,
105            edge_types: Vec::new(),
106            max_visited: 200,
107            step_budget: Some(std::time::Duration::from_millis(8)),
108            embed_timeout: Some(std::time::Duration::from_secs(5)),
109        }
110    }
111}
112
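/// Process-global dim-mismatch sentinel for HL-F5.
///
/// Holds the name of the collection for which a vector-dimension mismatch was
/// detected. MINOR-1 resolution: the sentinel lives only in process memory, so
/// re-provisioning the collection with a matching dimension recovers after a
/// process restart. A per-`SemanticMemory` guard would require passing state
/// down; a process-global string sentinel is the least-invasive approach that
/// avoids a lockout persisting across restarts.
///
/// Because this is a `OnceLock`, it can be set at most once per process and is
/// never cleared: once a mismatch is recorded, HL-F5 stays disabled for that
/// collection until the process restarts.
///
/// NOTE(review): test isolation — the collection name is a constant inside
/// `hela_spreading_recall` and `HelaSpreadParams` has no collection field, so
/// tests cannot vary the key; tests in the same process share this state.
/// Confirm how isolation is intended to work.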
121static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
122
/// Cosine similarity between two vectors of the same length.
///
/// The denominator is floored at `f32::EPSILON`, so a zero-norm input yields
/// `0.0` rather than a division by zero.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean (L2) norm of a slice.
    let l2 = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    let numerator = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    numerator / f32::max(l2(a) * l2(b), f32::EPSILON)
}
133
134/// HL-F5 BFS spreading activation from the top-1 ANN anchor node (#3346).
135///
136/// Algorithm overview:
137/// 1. Embed `query` → anchor via ANN search in the entity Qdrant collection.
138/// 2. BFS up to `params.spread_depth` hops, propagating multiplicative edge
139///    weights (`path_weight = Π edge.weight along path`). Multi-path convergence
140///    keeps the maximum `path_weight`.
141/// 3. Retrieve entity embeddings for all visited nodes via `get_points`.
142/// 4. Score each node: `score = path_weight × max(cosine(query, entity), 0.0)`.
143/// 5. Sort descending, truncate to `limit`, reinforce traversed edges via Hebbian
144///    update (when `hebbian_enabled`).
145///
146/// Fallback: when the anchor entity has no outgoing edges a single synthetic
147/// [`HelaFact`] with `edge.id == 0` and `score = anchor_cosine` is returned
148/// (the real ANN cosine, never a fabricated `1.0`).
149///
150/// Per-step circuit breaker: any individual step exceeding `params.step_budget`
151/// emits a `WARN` and returns `Ok(Vec::new())`.
152///
153/// Dim-mismatch resilience: a one-time dim probe on the first call guards against
154/// collection/provider configuration mismatches (#3382 pattern). Subsequent calls
155/// to a mismatched collection short-circuit immediately.
156///
157/// # Errors
158///
159/// Returns an error if the embed call or any database query fails.
160#[tracing::instrument(
161    name = "memory.graph.hela_spread",
162    skip_all,
163    fields(
164        depth = params.spread_depth,
165        limit,
166        anchor_id = tracing::field::Empty,
167        visited = tracing::field::Empty,
168        scored = tracing::field::Empty,
169        fallback = tracing::field::Empty,
170    )
171)]
172#[allow(clippy::too_many_arguments, clippy::too_many_lines)] // complex algorithm function; both suppressions justified until the function is decomposed in a future refactor
173pub async fn hela_spreading_recall(
174    store: &GraphStore,
175    embeddings: &EmbeddingStore,
176    provider: &zeph_llm::any::AnyProvider,
177    query: &str,
178    limit: usize,
179    params: &HelaSpreadParams,
180    hebbian_enabled: bool,
181    hebbian_lr: f32,
182) -> Result<Vec<HelaFact>, MemoryError> {
183    use zeph_llm::LlmProvider as _;
184
185    const ENTITY_COLLECTION: &str = "zeph_graph_entities";
186
187    if limit == 0 {
188        return Ok(Vec::new());
189    }
190
191    // ── Step 0: dim-mismatch guard ────────────────────────────────────────────
192    // MINOR-1: guard is keyed by collection name so re-provisioning recovers.
193    if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
194        tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
195        return Ok(Vec::new());
196    }
197
198    // ── Step 1: embed query ───────────────────────────────────────────────────
199    let q_vec = if let Some(timeout) = params.embed_timeout {
200        tokio::time::timeout(timeout, provider.embed(query))
201            .await
202            .map_err(|_| {
203                tracing::warn!(timeout_ms = timeout.as_millis(), "hela: embed timed out");
204                MemoryError::Timeout("hela embed".into())
205            })??
206    } else {
207        provider.embed(query).await?
208    };
209
210    // Dim probe: search with k=1 to catch dimension mismatch at the Qdrant layer.
211    let t_anchor = Instant::now();
212    let anchor_results = match embeddings
213        .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
214        .await
215    {
216        Ok(r) => r,
217        Err(e) => {
218            let msg = e.to_string();
219            if msg.contains("wrong vector dimension")
220                || msg.contains("InvalidArgument")
221                || msg.contains("dimension")
222            {
223                let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
224                tracing::warn!(
225                    collection = ENTITY_COLLECTION,
226                    error = %e,
227                    "hela: vector dimension mismatch — HL-F5 disabled for this collection"
228                );
229                return Ok(Vec::new());
230            }
231            return Err(e);
232        }
233    };
234
235    if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
236        tracing::warn!(
237            elapsed_ms = t_anchor.elapsed().as_millis(),
238            "hela: anchor ANN over budget"
239        );
240        return Ok(Vec::new());
241    }
242
243    let Some(anchor_point) = anchor_results.first() else {
244        tracing::debug!("hela: no anchor found, returning empty");
245        return Ok(Vec::new());
246    };
247    let Some(anchor_entity_id) = anchor_point
248        .payload
249        .get("entity_id")
250        .and_then(serde_json::Value::as_i64)
251    else {
252        tracing::warn!("hela: anchor point missing entity_id payload");
253        return Ok(Vec::new());
254    };
255    let anchor_cosine = anchor_point.score;
256
257    tracing::Span::current().record("anchor_id", anchor_entity_id);
258    tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
259
260    let spread_depth = params.spread_depth.clamp(1, 6);
261
262    // ── Step 2: BFS with multiplicative path-weight propagation ──────────────
263    // `visited`: entity_id → (depth, path_weight, edge_id_via_which_we_arrived)
264    let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
265    visited.insert(anchor_entity_id, (0, 1.0, None));
266
267    // Dedup edges keyed by id for Step 4 lookup (avoids N clones per frontier).
268    // MINOR-3 resolution: collect edges into a HashMap<id, Edge> outside the
269    // per-source loop to avoid 10K clones on a hub × 50-entity frontier.
270    let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
271    let mut frontier: Vec<i64> = vec![anchor_entity_id];
272
273    for hop in 0..spread_depth {
274        if frontier.is_empty() {
275            break;
276        }
277
278        tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
279
280        let t_step = Instant::now();
281        let edges = store
282            .edges_for_entities(&frontier, &params.edge_types)
283            .await?;
284        if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
285            tracing::warn!(
286                hop,
287                elapsed_ms = t_step.elapsed().as_millis(),
288                "hela: edge-fetch over budget"
289            );
290            return Ok(Vec::new());
291        }
292
293        let mut next_frontier: Vec<i64> = Vec::new();
294
295        for edge in &edges {
296            // Cache by edge id to avoid repeated clones per source in frontier.
297            edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
298
299            for &src_id in &frontier {
300                let neighbor = if edge.source_entity_id == src_id {
301                    edge.target_entity_id
302                } else if edge.target_entity_id == src_id {
303                    edge.source_entity_id
304                } else {
305                    continue;
306                };
307
308                let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
309                let new_pw = parent_pw * edge.weight;
310
311                // Multi-path resolution: keep MAX path_weight; lower depth as
312                // tie-break. MINOR-4 note: max_visited is a soft bound — the
313                // actual visited set may exceed it by O(edges_per_hop_step) for
314                // one frontier step before the outer break fires.
315                let entry = visited
316                    .entry(neighbor)
317                    .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
318                // Prefer strictly higher path weight; break ties in favour of shallower depth.
319                if new_pw > entry.1
320                    || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
321                {
322                    *entry = (hop + 1, new_pw, Some(edge.id));
323                    if !next_frontier.contains(&neighbor) {
324                        next_frontier.push(neighbor);
325                    }
326                }
327
328                if visited.len() >= params.max_visited {
329                    break;
330                }
331            }
332
333            if visited.len() >= params.max_visited {
334                break;
335            }
336        }
337
338        tracing::debug!(
339            hop,
340            edges_fetched = edges.len(),
341            visited = visited.len(),
342            next_frontier = next_frontier.len(),
343            "hela: hop complete"
344        );
345
346        frontier = next_frontier;
347        if visited.len() >= params.max_visited {
348            break;
349        }
350    }
351
352    // ── Isolated-anchor fallback ──────────────────────────────────────────────
353    // `visited.len() == 1` means no edges were traversed from the anchor.
354    if visited.len() == 1 {
355        tracing::Span::current().record("fallback", true);
356        tracing::debug!(
357            anchor_entity_id,
358            anchor_cosine,
359            "hela: anchor isolated, falling back to pure ANN"
360        );
361        let fact = HelaFact {
362            edge: Edge::synthetic_anchor(anchor_entity_id),
363            score: anchor_cosine,
364            depth: 0,
365            path_weight: 1.0,
366            cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
367        };
368        return Ok(vec![fact]);
369    }
370
371    // ── Step 3: retrieve entity embeddings ───────────────────────────────────
372    let entity_ids: Vec<i64> = visited.keys().copied().collect();
373    let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
374    let point_ids: Vec<String> = point_id_map.values().cloned().collect();
375
376    let t_vec = Instant::now();
377    let vec_map = embeddings
378        .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
379        .await?;
380    if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
381        tracing::warn!(
382            elapsed_ms = t_vec.elapsed().as_millis(),
383            "hela: vectors-batch over budget"
384        );
385        return Ok(Vec::new());
386    }
387
388    // ── Step 4: score per visited node ────────────────────────────────────────
389    // Cosine clamped to [0.0, 1.0]: anti-correlated neighbors score 0.0 so
390    // they are ranked below positively-correlated ones.  A negative cosine on a
391    // strongly-reinforced edge would otherwise invert the retrieval signal.
392    let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
393    for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
394        if entity_id == anchor_entity_id {
395            continue;
396        }
397        let Some(edge_id) = edge_id_opt else {
398            continue;
399        };
400        let Some(point_id) = point_id_map.get(&entity_id) else {
401            continue;
402        };
403        let Some(node_vec) = vec_map.get(point_id) else {
404            continue;
405        };
406        if node_vec.len() != q_vec.len() {
407            // Per-node dim mismatch — skip (defense-in-depth for legacy collections).
408            continue;
409        }
410        let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
411        let fact_score = path_weight * cosine_clamped;
412        let Some(edge) = edge_cache.get(&edge_id).cloned() else {
413            continue;
414        };
415        facts.push(HelaFact {
416            edge,
417            score: fact_score,
418            depth,
419            path_weight,
420            cosine: Some(cosine_clamped),
421        });
422    }
423
424    // ── Step 5: sort, truncate, Hebbian increment ─────────────────────────────
425    facts.sort_by(|a, b| b.score.total_cmp(&a.score));
426    facts.truncate(limit);
427
428    // HL-F2 reinforcement on edges that survived truncation (kept ≈ used).
429    // Hebbian on "kept edges only" — consistent with graph_recall_activated at
430    // graph/retrieval.rs:427-433. Note: SYNAPSE reinforces all traversed edges;
431    // this PR intentionally reinforces only surfaced edges. See MINOR-5.
432    if hebbian_enabled {
433        let edge_ids: Vec<i64> = facts
434            .iter()
435            .map(|f| f.edge.id)
436            .filter(|&id| id != 0) // skip synthetic anchor
437            .collect();
438        if !edge_ids.is_empty()
439            && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
440        {
441            tracing::warn!(error = %e, "hela: hebbian increment failed");
442        }
443    }
444
445    tracing::Span::current().record("visited", visited.len());
446    tracing::Span::current().record("scored", facts.len());
447
448    Ok(facts)
449}
450
451// ── SYNAPSE spreading activation ──────────────────────────────────────────────
452
453/// Spreading activation engine parameterized from [`SpreadingActivationParams`].
454pub struct SpreadingActivation {
455    params: SpreadingActivationParams,
456}
457
458impl SpreadingActivation {
459    /// Create a new spreading activation engine from explicit parameters.
460    ///
461    /// `params.temporal_decay_rate` is taken from `GraphConfig.temporal_decay_rate` so that
462    /// recency weighting reuses the same parameter as BFS recall (SA-INV-05).
463    #[must_use]
464    pub fn new(params: SpreadingActivationParams) -> Self {
465        Self { params }
466    }
467
468    /// Run spreading activation from `seeds` over the graph.
469    ///
470    /// Returns activated nodes sorted by activation score descending, along with
471    /// edges collected during propagation.
472    ///
473    /// # Parameters
474    ///
475    /// - `store`: graph database accessor
476    /// - `seeds`: `HashMap<entity_id, initial_activation>` — nodes to start from
477    /// - `edge_types`: MAGMA subgraph filter; when non-empty, only edges of these types
478    ///   are traversed (mirrors `bfs_typed` behaviour; SA-INV-08)
479    ///
480    /// # Errors
481    ///
482    /// Returns an error if any database query fails.
483    pub async fn spread(
484        &self,
485        store: &GraphStore,
486        seeds: HashMap<i64, f32>,
487        edge_types: &[EdgeType],
488    ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
489        if seeds.is_empty() {
490            return Ok((Vec::new(), Vec::new()));
491        }
492
493        // Compute `now_secs` once for consistent temporal recency weighting
494        // across all edges (matches the pattern in retrieval.rs:83-86).
495        let now_secs: i64 = SystemTime::now()
496            .duration_since(UNIX_EPOCH)
497            .map_or(0, |d| d.as_secs().cast_signed());
498
499        let mut activation = self.initialize_seeds(&seeds);
500        let mut activated_facts: Vec<ActivatedFact> = Vec::new();
501
502        for hop in 0..self.params.max_hops {
503            let active_nodes: Vec<(i64, f32)> = activation
504                .iter()
505                .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
506                .map(|(&id, &(score, _))| (id, score))
507                .collect();
508
509            if active_nodes.is_empty() {
510                break;
511            }
512
513            let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
514            let edges = store.edges_for_entities(&node_ids, edge_types).await?;
515            let edge_count = edges.len();
516
517            let next_activation =
518                self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
519
520            let pruned_count = self.merge_and_prune(&mut activation, next_activation);
521
522            tracing::debug!(
523                hop,
524                active_nodes = active_nodes.len(),
525                edges_fetched = edge_count,
526                after_merge = activation.len(),
527                pruned = pruned_count,
528                "spreading activation: hop complete"
529            );
530
531            self.collect_activated_facts(&edges, &activation, &mut activated_facts);
532        }
533
534        let result = self.finalize(activation);
535
536        tracing::info!(
537            activated = result.len(),
538            facts = activated_facts.len(),
539            "spreading activation: complete"
540        );
541
542        Ok((result, activated_facts))
543    }
544
545    /// Populate the activation map from seed scores, filtering seeds below threshold.
546    fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
547        let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
548        let mut seed_count = 0usize;
549        // Seeds bypass activation_threshold (they are query anchors per SYNAPSE semantics).
550        for (entity_id, match_score) in seeds {
551            if *match_score < self.params.activation_threshold {
552                tracing::debug!(
553                    entity_id,
554                    score = match_score,
555                    threshold = self.params.activation_threshold,
556                    "spreading activation: seed below threshold, skipping"
557                );
558                continue;
559            }
560            activation.insert(*entity_id, (*match_score, 0));
561            seed_count += 1;
562        }
563        tracing::debug!(
564            seeds = seed_count,
565            "spreading activation: initialized seeds"
566        );
567        activation
568    }
569
570    /// Compute the next-hop activation map by propagating through `edges`.
571    ///
572    /// Applies lateral inhibition (CRIT-02) and clamped multi-path convergence sums.
573    fn propagate_one_hop(
574        &self,
575        hop: u32,
576        active_nodes: &[(i64, f32)],
577        edges: &[Edge],
578        activation: &HashMap<i64, (f32, u32)>,
579        now_secs: i64,
580    ) -> HashMap<i64, (f32, u32)> {
581        let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
582
583        for edge in edges {
584            for &(active_id, node_score) in active_nodes {
585                let neighbor = if edge.source_entity_id == active_id {
586                    edge.target_entity_id
587                } else if edge.target_entity_id == active_id {
588                    edge.source_entity_id
589                } else {
590                    continue;
591                };
592
593                // Lateral inhibition: skip neighbor if it already has high activation
594                // in either the current map OR this hop's next_activation (CRIT-02 fix:
595                // checks both maps to match SYNAPSE paper semantics and prevent runaway
596                // activation when multiple paths converge in the same hop).
597                let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
598                let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
599                if current_score >= self.params.inhibition_threshold
600                    || next_score >= self.params.inhibition_threshold
601                {
602                    continue;
603                }
604
605                let recency = self.recency_weight(&edge.valid_from, now_secs);
606                let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
607                let type_w = edge_type_weight(edge.edge_type);
608                let spread_value =
609                    node_score * self.params.decay_lambda * edge_weight * recency * type_w;
610
611                if spread_value < self.params.activation_threshold {
612                    continue;
613                }
614
615                // Clamped sum preserves the multi-path convergence signal: nodes reachable
616                // via multiple paths receive proportionally higher activation (MAJOR-01).
617                let depth_at_max = hop + 1;
618                let entry = next_activation
619                    .entry(neighbor)
620                    .or_insert((0.0, depth_at_max));
621                let new_score = (entry.0 + spread_value).min(1.0);
622                if new_score > entry.0 {
623                    entry.0 = new_score;
624                    entry.1 = depth_at_max;
625                }
626            }
627        }
628
629        next_activation
630    }
631
632    /// Merge `next_activation` into `activation` and prune to `max_activated_nodes` (SA-INV-04).
633    ///
634    /// Returns the number of pruned nodes for tracing.
635    fn merge_and_prune(
636        &self,
637        activation: &mut HashMap<i64, (f32, u32)>,
638        next_activation: HashMap<i64, (f32, u32)>,
639    ) -> usize {
640        for (node_id, (new_score, new_depth)) in next_activation {
641            let entry = activation.entry(node_id).or_insert((0.0, new_depth));
642            if new_score > entry.0 {
643                entry.0 = new_score;
644                entry.1 = new_depth;
645            }
646        }
647
648        if activation.len() > self.params.max_activated_nodes {
649            let before = activation.len();
650            let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
651            entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
652            entries.truncate(self.params.max_activated_nodes);
653            *activation = entries.into_iter().collect();
654            before - self.params.max_activated_nodes
655        } else {
656            0
657        }
658    }
659
660    /// Append edges whose both endpoints are above threshold to `activated_facts`.
661    fn collect_activated_facts(
662        &self,
663        edges: &[Edge],
664        activation: &HashMap<i64, (f32, u32)>,
665        activated_facts: &mut Vec<ActivatedFact>,
666    ) {
667        for edge in edges {
668            let src_score = activation
669                .get(&edge.source_entity_id)
670                .map_or(0.0, |&(s, _)| s);
671            let tgt_score = activation
672                .get(&edge.target_entity_id)
673                .map_or(0.0, |&(s, _)| s);
674            if src_score >= self.params.activation_threshold
675                && tgt_score >= self.params.activation_threshold
676            {
677                let activation_score = src_score.max(tgt_score);
678                activated_facts.push(ActivatedFact {
679                    edge: edge.clone(),
680                    activation_score,
681                });
682            }
683        }
684    }
685
686    /// Collect nodes above threshold into `Vec<ActivatedNode>`, sorted descending by score.
687    fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
688        let mut result: Vec<ActivatedNode> = activation
689            .into_iter()
690            .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
691            .map(|(entity_id, (activation, depth))| ActivatedNode {
692                entity_id,
693                activation,
694                depth,
695            })
696            .collect();
697        result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
698        result
699    }
700
701    /// Compute temporal recency weight for an edge.
702    ///
703    /// Formula: `1.0 / (1.0 + age_days * temporal_decay_rate)`.
704    /// Returns `1.0` when `temporal_decay_rate = 0.0` (no temporal adjustment).
705    /// Reuses the same formula as `GraphFact::score_with_decay` (SA-INV-05).
706    #[allow(clippy::cast_precision_loss)]
707    fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
708        if self.params.temporal_decay_rate <= 0.0 {
709            return 1.0;
710        }
711        let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
712            return 1.0;
713        };
714        let age_secs = (now_secs - valid_from_secs).max(0);
715        let age_days = age_secs as f64 / 86_400.0;
716        let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
717        // cast f64 -> f32: safe, weight is in [0.0, 1.0]
718        #[allow(clippy::cast_possible_truncation)]
719        let w = weight as f32;
720        w
721    }
722}
723
/// Parse a `SQLite` `datetime('now')` string to Unix seconds.
///
/// Accepts `"YYYY-MM-DD HH:MM:SS"`. Only the first 19 bytes are examined, so
/// variants with fractional seconds or a timezone suffix are accepted, but the
/// suffix is ignored and the timestamp is computed as if it were UTC. The
/// separator characters themselves are not checked.
///
/// Returns `None` if the string is shorter than 19 bytes, if any numeric field
/// is not made purely of ASCII digits (this includes non-ASCII input, which
/// never panics), or if a field is outside its calendar range (month `1..=12`,
/// day `1..=31`, hour `0..=23`, minute and second `0..=59`). Day-of-month is
/// not validated against the specific month.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // Fixed-width numeric field. `str::get` yields `None` instead of panicking
    // when the range is out of bounds or cuts through a multi-byte character;
    // the digit check rejects signs such as "+1" that `parse` would accept.
    let field = |range: std::ops::Range<usize>| -> Option<i64> {
        let digits = s.get(range)?;
        if digits.bytes().all(|b| b.is_ascii_digit()) {
            digits.parse().ok()
        } else {
            None
        }
    };

    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let min = field(14..16)?;
    let sec = field(17..19)?;

    let in_range = (1..=12).contains(&month)
        && (1..=31).contains(&day)
        && hour <= 23
        && min <= 59
        && sec <= 59;
    if !in_range {
        return None;
    }

    // Days since Unix epoch via civil calendar algorithm.
    // Reference: https://howardhinnant.github.io/date_algorithms.html#days_from_civil
    // The year is shifted to start in March so the leap day falls at its end.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
755
756#[cfg(test)]
757mod tests {
758    use super::*;
759    use crate::graph::GraphStore;
760    use crate::graph::types::EntityType;
761    use crate::store::SqliteStore;
762
763    async fn setup_store() -> GraphStore {
764        let store = SqliteStore::new(":memory:").await.unwrap();
765        GraphStore::new(store.pool().clone())
766    }
767
768    fn default_params() -> SpreadingActivationParams {
769        SpreadingActivationParams {
770            decay_lambda: 0.85,
771            max_hops: 3,
772            activation_threshold: 0.1,
773            inhibition_threshold: 0.8,
774            max_activated_nodes: 50,
775            temporal_decay_rate: 0.0,
776            seed_structural_weight: 0.4,
777            seed_community_cap: 3,
778        }
779    }
780
781    // Test 1: empty graph (no edges) — seed entity is still returned as activated node,
782    // but no facts (edges) are found. Spread does not validate entity existence in DB.
783    #[tokio::test]
784    async fn spread_empty_graph_no_edges_no_facts() {
785        let store = setup_store().await;
786        let sa = SpreadingActivation::new(default_params());
787        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
788        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
789        // Seed node is returned as activated (activation=1.0, depth=0).
790        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
791        assert_eq!(nodes[0].entity_id, 1);
792        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
793        // No edges in empty graph, so no ActivatedFacts.
794        assert!(
795            facts.is_empty(),
796            "expected no activated facts on empty graph"
797        );
798    }
799
800    // Test 2: empty seeds returns empty
801    #[tokio::test]
802    async fn spread_empty_seeds_returns_empty() {
803        let store = setup_store().await;
804        let sa = SpreadingActivation::new(default_params());
805        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
806        assert!(nodes.is_empty());
807        assert!(facts.is_empty());
808    }
809
810    // Test 3: single seed with no edges returns only the seed
811    #[tokio::test]
812    async fn spread_single_seed_no_edges_returns_seed() {
813        let store = setup_store().await;
814        let alice = store
815            .upsert_entity("Alice", "Alice", EntityType::Person, None)
816            .await
817            .unwrap()
818            .0;
819
820        let sa = SpreadingActivation::new(default_params());
821        let seeds = HashMap::from([(alice, 1.0_f32)]);
822        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
823        assert_eq!(nodes.len(), 1);
824        assert_eq!(nodes[0].entity_id, alice);
825        assert_eq!(nodes[0].depth, 0);
826        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
827    }
828
829    // Test 4: linear chain A->B->C with max_hops=3 — all activated, scores decay
830    #[tokio::test]
831    async fn spread_linear_chain_all_activated_with_decay() {
832        let store = setup_store().await;
833        let a = store
834            .upsert_entity("A", "A", EntityType::Person, None)
835            .await
836            .unwrap()
837            .0;
838        let b = store
839            .upsert_entity("B", "B", EntityType::Person, None)
840            .await
841            .unwrap()
842            .0;
843        let c = store
844            .upsert_entity("C", "C", EntityType::Person, None)
845            .await
846            .unwrap()
847            .0;
848        store
849            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
850            .await
851            .unwrap();
852        store
853            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
854            .await
855            .unwrap();
856
857        let mut cfg = default_params();
858        cfg.max_hops = 3;
859        cfg.decay_lambda = 0.9;
860        let sa = SpreadingActivation::new(cfg);
861        let seeds = HashMap::from([(a, 1.0_f32)]);
862        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
863
864        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
865        assert!(ids.contains(&a), "A (seed) must be activated");
866        assert!(ids.contains(&b), "B (hop 1) must be activated");
867        assert!(ids.contains(&c), "C (hop 2) must be activated");
868
869        // Scores must decay: score(A) > score(B) > score(C)
870        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
871        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
872        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
873        assert!(
874            score_a > score_b,
875            "seed A should have higher activation than hop-1 B"
876        );
877        assert!(
878            score_b > score_c,
879            "hop-1 B should have higher activation than hop-2 C"
880        );
881    }
882
883    // Test 5: linear chain with max_hops=1 — C not activated
884    #[tokio::test]
885    async fn spread_linear_chain_max_hops_limits_reach() {
886        let store = setup_store().await;
887        let a = store
888            .upsert_entity("A", "A", EntityType::Person, None)
889            .await
890            .unwrap()
891            .0;
892        let b = store
893            .upsert_entity("B", "B", EntityType::Person, None)
894            .await
895            .unwrap()
896            .0;
897        let c = store
898            .upsert_entity("C", "C", EntityType::Person, None)
899            .await
900            .unwrap()
901            .0;
902        store
903            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
904            .await
905            .unwrap();
906        store
907            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
908            .await
909            .unwrap();
910
911        let mut cfg = default_params();
912        cfg.max_hops = 1;
913        let sa = SpreadingActivation::new(cfg);
914        let seeds = HashMap::from([(a, 1.0_f32)]);
915        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
916
917        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
918        assert!(ids.contains(&a), "A must be activated (seed)");
919        assert!(ids.contains(&b), "B must be activated (hop 1)");
920        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
921    }
922
923    // Test 6: diamond graph — D receives convergent activation from two paths
924    // Graph: A -> B, A -> C, B -> D, C -> D
925    // With clamped sum, D gets activation from both paths (convergence signal preserved).
926    #[tokio::test]
927    async fn spread_diamond_graph_convergence() {
928        let store = setup_store().await;
929        let a = store
930            .upsert_entity("A", "A", EntityType::Person, None)
931            .await
932            .unwrap()
933            .0;
934        let b = store
935            .upsert_entity("B", "B", EntityType::Person, None)
936            .await
937            .unwrap()
938            .0;
939        let c = store
940            .upsert_entity("C", "C", EntityType::Person, None)
941            .await
942            .unwrap()
943            .0;
944        let d = store
945            .upsert_entity("D", "D", EntityType::Person, None)
946            .await
947            .unwrap()
948            .0;
949        store
950            .insert_edge(a, b, "rel", "A-B", 1.0, None)
951            .await
952            .unwrap();
953        store
954            .insert_edge(a, c, "rel", "A-C", 1.0, None)
955            .await
956            .unwrap();
957        store
958            .insert_edge(b, d, "rel", "B-D", 1.0, None)
959            .await
960            .unwrap();
961        store
962            .insert_edge(c, d, "rel", "C-D", 1.0, None)
963            .await
964            .unwrap();
965
966        let mut cfg = default_params();
967        cfg.max_hops = 3;
968        cfg.decay_lambda = 0.9;
969        cfg.inhibition_threshold = 0.95; // raise inhibition to allow convergence
970        let sa = SpreadingActivation::new(cfg);
971        let seeds = HashMap::from([(a, 1.0_f32)]);
972        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
973
974        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
975        assert!(ids.contains(&d), "D must be activated via diamond paths");
976
977        // D should be activated at depth 2
978        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
979        assert_eq!(node_d.depth, 2, "D should be at depth 2");
980    }
981
982    // Test 7: inhibition threshold prevents runaway activation in dense cluster
983    #[tokio::test]
984    async fn spread_inhibition_prevents_runaway() {
985        let store = setup_store().await;
986        // Create a hub node connected to many leaves
987        let hub = store
988            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
989            .await
990            .unwrap()
991            .0;
992
993        for i in 0..5 {
994            let leaf = store
995                .upsert_entity(
996                    &format!("Leaf{i}"),
997                    &format!("Leaf{i}"),
998                    EntityType::Concept,
999                    None,
1000                )
1001                .await
1002                .unwrap()
1003                .0;
1004            store
1005                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
1006                .await
1007                .unwrap();
1008            // Connect all leaves back to hub to create a dense cluster
1009            store
1010                .insert_edge(
1011                    leaf,
1012                    hub,
1013                    "part_of",
1014                    &format!("Leaf{i} part_of Hub"),
1015                    1.0,
1016                    None,
1017                )
1018                .await
1019                .unwrap();
1020        }
1021
1022        // Seed hub with full activation — it should be inhibited after hop 1
1023        let mut cfg = default_params();
1024        cfg.inhibition_threshold = 0.8;
1025        cfg.max_hops = 3;
1026        let sa = SpreadingActivation::new(cfg);
1027        let seeds = HashMap::from([(hub, 1.0_f32)]);
1028        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
1029
1030        // Hub should remain at initial activation (1.0), not grow unbounded
1031        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
1032        assert!(hub_node.is_some(), "hub must be in results");
1033        assert!(
1034            hub_node.unwrap().activation <= 1.0,
1035            "activation must not exceed 1.0"
1036        );
1037    }
1038
1039    // Test 8: max_activated_nodes cap — lowest activations pruned
1040    #[tokio::test]
1041    async fn spread_max_activated_nodes_cap_enforced() {
1042        let store = setup_store().await;
1043        let root = store
1044            .upsert_entity("Root", "Root", EntityType::Person, None)
1045            .await
1046            .unwrap()
1047            .0;
1048
1049        // Create 20 leaf nodes connected to root
1050        for i in 0..20 {
1051            let leaf = store
1052                .upsert_entity(
1053                    &format!("Node{i}"),
1054                    &format!("Node{i}"),
1055                    EntityType::Concept,
1056                    None,
1057                )
1058                .await
1059                .unwrap()
1060                .0;
1061            store
1062                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
1063                .await
1064                .unwrap();
1065        }
1066
1067        let max_nodes = 5;
1068        let cfg = SpreadingActivationParams {
1069            max_activated_nodes: max_nodes,
1070            max_hops: 2,
1071            ..default_params()
1072        };
1073        let sa = SpreadingActivation::new(cfg);
1074        let seeds = HashMap::from([(root, 1.0_f32)]);
1075        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
1076
1077        assert!(
1078            nodes.len() <= max_nodes,
1079            "activation must be capped at {max_nodes} nodes, got {}",
1080            nodes.len()
1081        );
1082    }
1083
1084    // Test 9: temporal decay — recent edges produce higher activation
1085    #[tokio::test]
1086    async fn spread_temporal_decay_recency_effect() {
1087        let store = setup_store().await;
1088        let src = store
1089            .upsert_entity("Src", "Src", EntityType::Person, None)
1090            .await
1091            .unwrap()
1092            .0;
1093        let recent = store
1094            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
1095            .await
1096            .unwrap()
1097            .0;
1098        let old = store
1099            .upsert_entity("Old", "Old", EntityType::Tool, None)
1100            .await
1101            .unwrap()
1102            .0;
1103
1104        // Insert recent edge (default valid_from = now)
1105        store
1106            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
1107            .await
1108            .unwrap();
1109
1110        // Insert old edge manually with a 1970 timestamp
1111        zeph_db::query(
1112            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
1113             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
1114        )
1115        .bind(src)
1116        .bind(old)
1117        .execute(store.pool())
1118        .await
1119        .unwrap();
1120
1121        let mut cfg = default_params();
1122        cfg.max_hops = 2;
1123        // Use significant temporal decay rate to distinguish recent vs old
1124        let sa = SpreadingActivation::new(SpreadingActivationParams {
1125            temporal_decay_rate: 0.5,
1126            ..cfg
1127        });
1128        let seeds = HashMap::from([(src, 1.0_f32)]);
1129        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
1130
1131        let score_recent = nodes
1132            .iter()
1133            .find(|n| n.entity_id == recent)
1134            .map_or(0.0, |n| n.activation);
1135        let score_old = nodes
1136            .iter()
1137            .find(|n| n.entity_id == old)
1138            .map_or(0.0, |n| n.activation);
1139
1140        assert!(
1141            score_recent > score_old,
1142            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
1143        );
1144    }
1145
1146    // Test 10: edge_type filtering — only edges of specified type are traversed
1147    #[tokio::test]
1148    async fn spread_edge_type_filter_excludes_other_types() {
1149        let store = setup_store().await;
1150        let a = store
1151            .upsert_entity("A", "A", EntityType::Person, None)
1152            .await
1153            .unwrap()
1154            .0;
1155        let b_semantic = store
1156            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
1157            .await
1158            .unwrap()
1159            .0;
1160        let c_causal = store
1161            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
1162            .await
1163            .unwrap()
1164            .0;
1165
1166        // Semantic edge from A
1167        store
1168            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
1169            .await
1170            .unwrap();
1171
1172        // Causal edge from A (inserted with explicit edge_type)
1173        zeph_db::query(
1174            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
1175             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
1176        )
1177        .bind(a)
1178        .bind(c_causal)
1179        .execute(store.pool())
1180        .await
1181        .unwrap();
1182
1183        let cfg = default_params();
1184        let sa = SpreadingActivation::new(cfg);
1185
1186        // Spread with only semantic edges
1187        let seeds = HashMap::from([(a, 1.0_f32)]);
1188        let (nodes, _) = sa
1189            .spread(&store, seeds, &[EdgeType::Semantic])
1190            .await
1191            .unwrap();
1192
1193        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
1194        assert!(
1195            ids.contains(&b_semantic),
1196            "BSemantic must be activated via semantic edge"
1197        );
1198        assert!(
1199            !ids.contains(&c_causal),
1200            "CCausal must NOT be activated when filtering to semantic only"
1201        );
1202    }
1203
1204    // Test 11: large seed list (stress test for batch query)
1205    #[tokio::test]
1206    async fn spread_large_seed_list() {
1207        let store = setup_store().await;
1208        let mut seeds = HashMap::new();
1209
1210        // Create 100 seed entities — tests that edges_for_entities handles chunking correctly
1211        for i in 0..100i64 {
1212            let id = store
1213                .upsert_entity(
1214                    &format!("Entity{i}"),
1215                    &format!("entity{i}"),
1216                    EntityType::Concept,
1217                    None,
1218                )
1219                .await
1220                .unwrap()
1221                .0;
1222            seeds.insert(id, 1.0_f32);
1223        }
1224
1225        let cfg = default_params();
1226        let sa = SpreadingActivation::new(cfg);
1227        // Should complete without error even with 100 seeds (chunking handles SQLite limit)
1228        let result = sa.spread(&store, seeds, &[]).await;
1229        assert!(
1230            result.is_ok(),
1231            "large seed list must not error: {:?}",
1232            result.err()
1233        );
1234    }
1235
1236    // ── HL-F5 unit tests ─────────────────────────────────────────────────────
1237
1238    #[test]
1239    fn hela_cosine_identical_vectors() {
1240        let v = vec![1.0_f32, 0.0, 0.0];
1241        assert!(
1242            (cosine(&v, &v) - 1.0).abs() < 1e-6,
1243            "identical vectors → cosine 1.0"
1244        );
1245    }
1246
1247    #[test]
1248    fn hela_cosine_orthogonal_vectors() {
1249        let a = vec![1.0_f32, 0.0];
1250        let b = vec![0.0_f32, 1.0];
1251        assert!(
1252            cosine(&a, &b).abs() < 1e-6,
1253            "orthogonal vectors → cosine 0.0"
1254        );
1255    }
1256
1257    #[test]
1258    fn hela_cosine_anti_correlated() {
1259        let a = vec![1.0_f32, 0.0];
1260        let b = vec![-1.0_f32, 0.0];
1261        assert!(
1262            cosine(&a, &b) < 0.0,
1263            "anti-correlated vectors → negative cosine"
1264        );
1265    }
1266
1267    #[test]
1268    fn hela_cosine_zero_vector_no_panic() {
1269        let a = vec![0.0_f32, 0.0];
1270        let b = vec![1.0_f32, 0.0];
1271        // Should not panic — denom is guarded by f32::EPSILON
1272        let result = cosine(&a, &b);
1273        assert!(
1274            result.is_finite(),
1275            "zero-norm vector must yield finite cosine"
1276        );
1277    }
1278
1279    #[test]
1280    fn hela_spread_params_default_depth_is_two() {
1281        let p = HelaSpreadParams::default();
1282        assert_eq!(p.spread_depth, 2);
1283        assert!(p.step_budget.is_some());
1284        assert!(p.edge_types.is_empty());
1285        assert_eq!(p.max_visited, 200);
1286    }
1287
1288    #[test]
1289    fn hela_spread_params_default_embed_timeout_is_some() {
1290        let p = HelaSpreadParams::default();
1291        assert!(
1292            p.embed_timeout.is_some(),
1293            "default embed_timeout must be Some (5 s)"
1294        );
1295    }
1296
1297    // Regression test for #4285: hela_spreading_recall must return
1298    // MemoryError::Timeout when the embed provider stalls beyond embed_timeout.
1299    #[tokio::test]
1300    async fn hela_spreading_recall_embed_timeout_returns_error() {
1301        use std::time::Duration;
1302        use zeph_llm::any::AnyProvider;
1303        use zeph_llm::mock::MockProvider;
1304
1305        use crate::embedding_store::EmbeddingStore;
1306        use crate::error::MemoryError;
1307        use crate::in_memory_store::InMemoryVectorStore;
1308
1309        let store = setup_store().await;
1310
1311        // Provider sleeps 500 ms; timeout is set to 50 ms → must fire.
1312        let mock = MockProvider::default().with_embed_delay(500);
1313        let provider = AnyProvider::Mock(mock);
1314
1315        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
1316            .await
1317            .unwrap();
1318        let embeddings =
1319            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());
1320
1321        let params = HelaSpreadParams {
1322            embed_timeout: Some(Duration::from_millis(50)),
1323            ..Default::default()
1324        };
1325
1326        let result = hela_spreading_recall(
1327            &store,
1328            &embeddings,
1329            &provider,
1330            "test query",
1331            5,
1332            &params,
1333            false,
1334            0.0,
1335        )
1336        .await;
1337
1338        assert!(
1339            matches!(result, Err(MemoryError::Timeout(_))),
1340            "expected Err(MemoryError::Timeout), got {result:?}"
1341        );
1342    }
1343
1344    // When embed_timeout is None the embed call is not wrapped; the (fast) mock
1345    // returns immediately and the function must succeed.
1346    #[tokio::test]
1347    async fn hela_spreading_recall_no_timeout_does_not_wrap() {
1348        use zeph_llm::any::AnyProvider;
1349        use zeph_llm::mock::MockProvider;
1350
1351        use crate::embedding_store::EmbeddingStore;
1352        use crate::in_memory_store::InMemoryVectorStore;
1353
1354        let store = setup_store().await;
1355
1356        let mock = MockProvider::default().with_embed_delay(0);
1357        let provider = AnyProvider::Mock(mock);
1358
1359        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
1360            .await
1361            .unwrap();
1362        let embeddings =
1363            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());
1364
1365        let params = HelaSpreadParams {
1366            embed_timeout: None,
1367            ..Default::default()
1368        };
1369
1370        // embed returns a zero-dimension vector (embed not configured for 384-dim),
1371        // so Qdrant search finds nothing — the function returns Ok(Vec::new()).
1372        let result = hela_spreading_recall(
1373            &store,
1374            &embeddings,
1375            &provider,
1376            "test query",
1377            5,
1378            &params,
1379            false,
1380            0.0,
1381        )
1382        .await;
1383
1384        // The mock embed returns an empty vec by default; the ANN search will
1385        // find no results — the expected outcome is Ok(empty) or a non-Timeout error.
1386        assert!(
1387            !matches!(result, Err(crate::error::MemoryError::Timeout(_))),
1388            "embed_timeout: None must not produce a Timeout error, got {result:?}"
1389        );
1390    }
1391
1392    #[test]
1393    fn hela_synthetic_anchor_edge_id_is_zero() {
1394        let edge = Edge::synthetic_anchor(42);
1395        assert_eq!(
1396            edge.id, 0,
1397            "synthetic anchor must have id = 0 to be excluded from Hebbian"
1398        );
1399        assert_eq!(edge.source_entity_id, 42);
1400        assert_eq!(edge.target_entity_id, 42);
1401    }
1402
1403    #[test]
1404    fn hela_negative_cosine_clamped_to_zero_in_score() {
1405        // path_weight × cosine.max(0.0): negative cosine must contribute 0.0
1406        let anti = vec![-1.0_f32, 0.0];
1407        let query = vec![1.0_f32, 0.0];
1408        let cosine_raw = cosine(&query, &anti);
1409        assert!(cosine_raw < 0.0);
1410        let clamped = cosine_raw.max(0.0);
1411        let fact_score = 0.9_f32 * clamped;
1412        assert!(
1413            fact_score < f32::EPSILON,
1414            "anti-correlated score must be 0.0"
1415        );
1416    }
1417
1418    #[test]
1419    fn hela_path_weight_multiplicative() {
1420        // Two-hop path with edge weights 0.8, 0.5 → path_weight = 0.4
1421        let w1 = 0.8_f32;
1422        let w2 = 0.5_f32;
1423        let expected = w1 * w2;
1424        assert!((expected - 0.4).abs() < 1e-6);
1425    }
1426
1427    #[test]
1428    fn hela_max_path_weight_on_multipath() {
1429        // When two paths reach the same node, keep the higher path_weight.
1430        let pw_a = 0.9_f32; // short direct path
1431        let pw_b = 0.3_f32; // longer indirect path
1432        let kept = pw_a.max(pw_b);
1433        assert!(
1434            (kept - 0.9).abs() < 1e-6,
1435            "multi-path resolution must keep maximum path_weight"
1436        );
1437    }
1438
1439    #[test]
1440    fn hela_fact_score_formula() {
1441        let path_weight = 0.8_f32;
1442        let cosine_clamped = 0.75_f32;
1443        let expected = path_weight * cosine_clamped;
1444        // Verify the formula used in hela_spreading_recall Step 4.
1445        assert!((expected - 0.6).abs() < 1e-5);
1446    }
1447}