Skip to main content

zeph_memory/graph/
activation.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! SYNAPSE spreading activation retrieval over the entity graph.
5//!
6//! Implements the spreading activation algorithm from arXiv 2601.02744, adapted for
7//! the zeph-memory graph schema. Seeds are matched via fuzzy entity search; activation
8//! propagates hop-by-hop with:
9//! - Exponential decay per hop (`decay_lambda`)
10//! - Edge confidence weighting
11//! - Temporal recency weighting (reuses `GraphConfig.temporal_decay_rate`)
12//! - Lateral inhibition (nodes above `inhibition_threshold` stop receiving activation)
13//! - Per-hop pruning to enforce `max_activated_nodes` bound (SA-INV-04)
14//! - MAGMA edge type filtering via `edge_types` parameter
15
16use std::collections::{HashMap, HashSet};
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity reached (or seeded) while spreading activation over the graph.
#[derive(Clone, Debug)]
pub struct ActivatedNode {
    /// Database id of the entity.
    pub entity_id: i64,
    /// Final activation score, within `[0.0, 1.0]`.
    pub activation: f32,
    /// Hop on which the node's highest activation arrived; seeds sit at `0`.
    pub depth: u32,
}
37
38/// A graph edge traversed during spreading activation, with its activation score.
39#[derive(Debug, Clone)]
40pub struct ActivatedFact {
41    /// The traversed edge.
42    pub edge: Edge,
43    /// Activation score of the source or target entity at time of traversal.
44    pub activation_score: f32,
45    /// `true` when this edge has a pending implicit conflict candidate (spec 004-17).
46    pub is_implicit_conflict: bool,
47    /// ID of the `implicit_conflict_candidates` row, if any.
48    pub conflict_candidate_id: Option<i64>,
49}
50
51pub use zeph_common::memory::SpreadingActivationParams;
52
53// ── HL-F5: HeLa-Mem spreading activation (#3346) ─────────────────────────────
54
55/// A graph edge surfaced by HL-F5 spreading activation (#3346), scored by
56/// `path_weight × max(cosine_query_to_endpoint, 0.0)`.
57///
58/// Mirrors [`ActivatedFact`] so callers can dispatch over a single
59/// `Vec<HelaFact>` ↔ `Vec<ActivatedFact>` ↔ `Vec<GraphFact>` shape at the
60/// strategy-selection site.
61#[derive(Debug, Clone)]
62pub struct HelaFact {
63    /// The edge by which the higher-scored endpoint was reached.
64    pub edge: Edge,
65    /// Final HL-F5 score: `path_weight × cosine_clamped`. Range: `[0.0, +∞)`.
66    pub score: f32,
67    /// BFS depth at which `edge` was traversed (`1..=spread_depth`).
68    /// `0` is reserved for the synthetic anchor edge in the isolated-anchor fallback.
69    pub depth: u32,
70    /// Multiplicative product of edge weights along the BFS path that reached
71    /// this edge's far endpoint. Range: `[0.0, +∞)`.
72    pub path_weight: f32,
73    /// Clamped cosine similarity of the far endpoint's entity embedding
74    /// to the query embedding, in `[0.0, 1.0]`. `None` when the endpoint
75    /// has no stored embedding (skipped from results in that case).
76    pub cosine: Option<f32>,
77}
78
79/// Parameters for HL-F5 spreading activation retrieval.
80///
81/// Build via [`Default`] and override individual fields:
82///
83/// ```rust
84/// use zeph_memory::graph::activation::HelaSpreadParams;
85///
86/// let params = HelaSpreadParams { spread_depth: 3, ..Default::default() };
87/// ```
88#[derive(Debug, Clone)]
89pub struct HelaSpreadParams {
90    /// BFS hops. Clamped to `[1, 6]` at runtime. Default: `2`.
91    pub spread_depth: u32,
92    /// MAGMA edge-type filter. Empty = all types. Default: `[]`.
93    pub edge_types: Vec<EdgeType>,
94    /// Soft upper bound on the visited-node set. Default: `200`.
95    pub max_visited: usize,
96    /// Per-step circuit breaker. Any internal step (anchor ANN, edges batch,
97    /// vectors batch) that exceeds this duration triggers an `Ok(Vec::new())`
98    /// fallback with a `WARN`. Default: `Some(8 ms)`.
99    pub step_budget: Option<std::time::Duration>,
100    /// Timeout for the initial query embedding call. `None` = no timeout.
101    /// Default: `Some(5 s)`.
102    pub embed_timeout: Option<std::time::Duration>,
103}
104
105impl Default for HelaSpreadParams {
106    fn default() -> Self {
107        Self {
108            spread_depth: 2,
109            edge_types: Vec::new(),
110            max_visited: 200,
111            step_budget: Some(std::time::Duration::from_millis(8)),
112            embed_timeout: Some(std::time::Duration::from_secs(5)),
113        }
114    }
115}
116
/// Process-global dim-mismatch sentinel for HL-F5.
///
/// Holds the name of the entity collection whose ANN search failed with what looked
/// like a vector-dimension mismatch. A `OnceLock` can be set only once and never
/// cleared, so:
/// - only the first mismatched collection name is recorded;
/// - once set, every later `hela_spreading_recall` call against that collection
///   short-circuits to an empty result until the process restarts (for example
///   after the collection is re-provisioned with the provider's dimension).
///
/// Storing the collection name, rather than a plain flag, lets the guard apply
/// only to the collection that actually mismatched.
///
/// Because the sentinel is process-global, a test that triggers a mismatch affects
/// every later call in the same test binary that targets the same collection name.
/// `HelaSpreadParams` carries no collection name, so tests cannot isolate
/// themselves through it.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
126
/// Cosine similarity of two equal-length slices.
///
/// Returns `0.0` when either vector has zero norm, or when the product of the norms
/// underflows to zero, instead of dividing by zero. The result is clamped to
/// `[-1.0, 1.0]` to absorb floating-point rounding.
///
/// Unlike flooring the denominator at `f32::EPSILON`, this keeps the measure
/// scale-invariant: two parallel vectors with tiny norms still score `1.0`.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine: slice length mismatch");
    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = norm_a * norm_b;
    if denom <= 0.0 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}
137
138/// HL-F5 BFS spreading activation from the top-1 ANN anchor node (#3346).
139///
140/// Algorithm overview:
141/// 1. Embed `query` → anchor via ANN search in the entity Qdrant collection.
142/// 2. BFS up to `params.spread_depth` hops, propagating multiplicative edge
143///    weights (`path_weight = Π edge.weight along path`). Multi-path convergence
144///    keeps the maximum `path_weight`.
145/// 3. Retrieve entity embeddings for all visited nodes via `get_points`.
146/// 4. Score each node: `score = path_weight × max(cosine(query, entity), 0.0)`.
147/// 5. Sort descending, truncate to `limit`, reinforce traversed edges via Hebbian
148///    update (when `hebbian_enabled`).
149///
150/// Fallback: when the anchor entity has no outgoing edges a single synthetic
151/// [`HelaFact`] with `edge.id == 0` and `score = anchor_cosine` is returned
152/// (the real ANN cosine, never a fabricated `1.0`).
153///
154/// Per-step circuit breaker: any individual step exceeding `params.step_budget`
155/// emits a `WARN` and returns `Ok(Vec::new())`.
156///
157/// Dim-mismatch resilience: a one-time dim probe on the first call guards against
158/// collection/provider configuration mismatches (#3382 pattern). Subsequent calls
159/// to a mismatched collection short-circuit immediately.
160///
161/// # Errors
162///
163/// Returns an error if the embed call or any database query fails.
164#[tracing::instrument(
165    name = "memory.graph.hela_spread",
166    skip_all,
167    fields(
168        depth = params.spread_depth,
169        limit,
170        anchor_id = tracing::field::Empty,
171        visited = tracing::field::Empty,
172        scored = tracing::field::Empty,
173        fallback = tracing::field::Empty,
174    )
175)]
176#[allow(clippy::too_many_arguments, clippy::too_many_lines)] // complex algorithm function; both suppressions justified until the function is decomposed in a future refactor
177pub async fn hela_spreading_recall(
178    store: &GraphStore,
179    embeddings: &EmbeddingStore,
180    provider: &zeph_llm::any::AnyProvider,
181    query: &str,
182    limit: usize,
183    params: &HelaSpreadParams,
184    hebbian_enabled: bool,
185    hebbian_lr: f32,
186) -> Result<Vec<HelaFact>, MemoryError> {
187    use zeph_llm::LlmProvider as _;
188
189    const ENTITY_COLLECTION: &str = "zeph_graph_entities";
190
191    if limit == 0 {
192        return Ok(Vec::new());
193    }
194
195    // ── Step 0: dim-mismatch guard ────────────────────────────────────────────
196    // MINOR-1: guard is keyed by collection name so re-provisioning recovers.
197    if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
198        tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
199        return Ok(Vec::new());
200    }
201
202    // ── Step 1: embed query ───────────────────────────────────────────────────
203    let q_vec = if let Some(timeout) = params.embed_timeout {
204        tokio::time::timeout(timeout, provider.embed(query))
205            .await
206            .map_err(|_| {
207                tracing::warn!(timeout_ms = timeout.as_millis(), "hela: embed timed out");
208                MemoryError::Timeout("hela embed".into())
209            })??
210    } else {
211        provider.embed(query).await?
212    };
213
214    // Dim probe: search with k=1 to catch dimension mismatch at the Qdrant layer.
215    let t_anchor = Instant::now();
216    let anchor_results = match embeddings
217        .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
218        .await
219    {
220        Ok(r) => r,
221        Err(e) => {
222            let msg = e.to_string();
223            if msg.contains("wrong vector dimension")
224                || msg.contains("InvalidArgument")
225                || msg.contains("dimension")
226            {
227                let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
228                tracing::warn!(
229                    collection = ENTITY_COLLECTION,
230                    error = %e,
231                    "hela: vector dimension mismatch — HL-F5 disabled for this collection"
232                );
233                return Ok(Vec::new());
234            }
235            return Err(e);
236        }
237    };
238
239    if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
240        tracing::warn!(
241            elapsed_ms = t_anchor.elapsed().as_millis(),
242            "hela: anchor ANN over budget"
243        );
244        return Ok(Vec::new());
245    }
246
247    let Some(anchor_point) = anchor_results.first() else {
248        tracing::debug!("hela: no anchor found, returning empty");
249        return Ok(Vec::new());
250    };
251    let Some(anchor_entity_id) = anchor_point
252        .payload
253        .get("entity_id")
254        .and_then(serde_json::Value::as_i64)
255    else {
256        tracing::warn!("hela: anchor point missing entity_id payload");
257        return Ok(Vec::new());
258    };
259    let anchor_cosine = anchor_point.score;
260
261    tracing::Span::current().record("anchor_id", anchor_entity_id);
262    tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
263
264    let spread_depth = params.spread_depth.clamp(1, 6);
265
266    // ── Step 2: BFS with multiplicative path-weight propagation ──────────────
267    // `visited`: entity_id → (depth, path_weight, edge_id_via_which_we_arrived)
268    let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
269    visited.insert(anchor_entity_id, (0, 1.0, None));
270
271    // Dedup edges keyed by id for Step 4 lookup (avoids N clones per frontier).
272    // MINOR-3 resolution: collect edges into a HashMap<id, Edge> outside the
273    // per-source loop to avoid 10K clones on a hub × 50-entity frontier.
274    let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
275    let mut frontier: Vec<i64> = vec![anchor_entity_id];
276
277    for hop in 0..spread_depth {
278        if frontier.is_empty() {
279            break;
280        }
281
282        tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
283
284        let t_step = Instant::now();
285        let edges = store
286            .edges_for_entities(&frontier, &params.edge_types)
287            .await?;
288        if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
289            tracing::warn!(
290                hop,
291                elapsed_ms = t_step.elapsed().as_millis(),
292                "hela: edge-fetch over budget"
293            );
294            return Ok(Vec::new());
295        }
296
297        let mut next_frontier: Vec<i64> = Vec::new();
298        let mut next_frontier_set: HashSet<i64> = HashSet::new();
299
300        for edge in &edges {
301            // Cache by edge id to avoid repeated clones per source in frontier.
302            edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
303
304            for &src_id in &frontier {
305                let neighbor = if edge.source_entity_id == src_id {
306                    edge.target_entity_id
307                } else if edge.target_entity_id == src_id {
308                    edge.source_entity_id
309                } else {
310                    continue;
311                };
312
313                let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
314                let new_pw = parent_pw * edge.weight;
315
316                // Multi-path resolution: keep MAX path_weight; lower depth as
317                // tie-break. MINOR-4 note: max_visited is a soft bound — the
318                // actual visited set may exceed it by O(edges_per_hop_step) for
319                // one frontier step before the outer break fires.
320                let entry = visited
321                    .entry(neighbor)
322                    .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
323                // Prefer strictly higher path weight; break ties in favour of shallower depth.
324                if new_pw > entry.1
325                    || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
326                {
327                    *entry = (hop + 1, new_pw, Some(edge.id));
328                    if next_frontier_set.insert(neighbor) {
329                        next_frontier.push(neighbor);
330                    }
331                }
332
333                if visited.len() >= params.max_visited {
334                    break;
335                }
336            }
337
338            if visited.len() >= params.max_visited {
339                break;
340            }
341        }
342
343        tracing::debug!(
344            hop,
345            edges_fetched = edges.len(),
346            visited = visited.len(),
347            next_frontier = next_frontier.len(),
348            "hela: hop complete"
349        );
350
351        frontier = next_frontier;
352        if visited.len() >= params.max_visited {
353            break;
354        }
355    }
356
357    // ── Isolated-anchor fallback ──────────────────────────────────────────────
358    // `visited.len() == 1` means no edges were traversed from the anchor.
359    if visited.len() == 1 {
360        tracing::Span::current().record("fallback", true);
361        tracing::debug!(
362            anchor_entity_id,
363            anchor_cosine,
364            "hela: anchor isolated, falling back to pure ANN"
365        );
366        let fact = HelaFact {
367            edge: Edge::synthetic_anchor(anchor_entity_id),
368            score: anchor_cosine,
369            depth: 0,
370            path_weight: 1.0,
371            cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
372        };
373        return Ok(vec![fact]);
374    }
375
376    // ── Step 3: retrieve entity embeddings ───────────────────────────────────
377    let entity_ids: Vec<i64> = visited.keys().copied().collect();
378    let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
379    let point_ids: Vec<String> = point_id_map.values().cloned().collect();
380
381    let t_vec = Instant::now();
382    let vec_map = embeddings
383        .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
384        .await?;
385    if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
386        tracing::warn!(
387            elapsed_ms = t_vec.elapsed().as_millis(),
388            "hela: vectors-batch over budget"
389        );
390        return Ok(Vec::new());
391    }
392
393    // ── Step 4: score per visited node ────────────────────────────────────────
394    // Cosine clamped to [0.0, 1.0]: anti-correlated neighbors score 0.0 so
395    // they are ranked below positively-correlated ones.  A negative cosine on a
396    // strongly-reinforced edge would otherwise invert the retrieval signal.
397    let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
398    for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
399        if entity_id == anchor_entity_id {
400            continue;
401        }
402        let Some(edge_id) = edge_id_opt else {
403            continue;
404        };
405        let Some(point_id) = point_id_map.get(&entity_id) else {
406            continue;
407        };
408        let Some(node_vec) = vec_map.get(point_id) else {
409            continue;
410        };
411        if node_vec.len() != q_vec.len() {
412            // Per-node dim mismatch — skip (defense-in-depth for legacy collections).
413            continue;
414        }
415        let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
416        let fact_score = path_weight * cosine_clamped;
417        let Some(edge) = edge_cache.get(&edge_id).cloned() else {
418            continue;
419        };
420        facts.push(HelaFact {
421            edge,
422            score: fact_score,
423            depth,
424            path_weight,
425            cosine: Some(cosine_clamped),
426        });
427    }
428
429    // ── Step 5: sort, truncate, Hebbian increment ─────────────────────────────
430    facts.sort_by(|a, b| b.score.total_cmp(&a.score));
431    facts.truncate(limit);
432
433    // HL-F2 reinforcement on edges that survived truncation (kept ≈ used).
434    // Hebbian on "kept edges only" — consistent with graph_recall_activated at
435    // graph/retrieval.rs:427-433. Note: SYNAPSE reinforces all traversed edges;
436    // this PR intentionally reinforces only surfaced edges. See MINOR-5.
437    if hebbian_enabled {
438        let edge_ids: Vec<i64> = facts
439            .iter()
440            .map(|f| f.edge.id)
441            .filter(|&id| id != 0) // skip synthetic anchor
442            .collect();
443        if !edge_ids.is_empty()
444            && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
445        {
446            tracing::warn!(error = %e, "hela: hebbian increment failed");
447        }
448    }
449
450    tracing::Span::current().record("visited", visited.len());
451    tracing::Span::current().record("scored", facts.len());
452
453    Ok(facts)
454}
455
456// ── SYNAPSE spreading activation ──────────────────────────────────────────────
457
458/// Spreading activation engine parameterized from [`SpreadingActivationParams`].
459pub struct SpreadingActivation {
460    params: SpreadingActivationParams,
461}
462
463impl SpreadingActivation {
464    /// Create a new spreading activation engine from explicit parameters.
465    ///
466    /// `params.temporal_decay_rate` is taken from `GraphConfig.temporal_decay_rate` so that
467    /// recency weighting reuses the same parameter as BFS recall (SA-INV-05).
468    #[must_use]
469    pub fn new(params: SpreadingActivationParams) -> Self {
470        Self { params }
471    }
472
473    /// Run spreading activation from `seeds` over the graph.
474    ///
475    /// Returns activated nodes sorted by activation score descending, along with
476    /// edges collected during propagation.
477    ///
478    /// # Parameters
479    ///
480    /// - `store`: graph database accessor
481    /// - `seeds`: `HashMap<entity_id, initial_activation>` — nodes to start from
482    /// - `edge_types`: MAGMA subgraph filter; when non-empty, only edges of these types
483    ///   are traversed (mirrors `bfs_typed` behaviour; SA-INV-08)
484    ///
485    /// # Errors
486    ///
487    /// Returns an error if any database query fails.
488    pub async fn spread(
489        &self,
490        store: &GraphStore,
491        seeds: HashMap<i64, f32>,
492        edge_types: &[EdgeType],
493    ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
494        if seeds.is_empty() {
495            return Ok((Vec::new(), Vec::new()));
496        }
497
498        // Compute `now_secs` once for consistent temporal recency weighting
499        // across all edges (matches the pattern in retrieval.rs:83-86).
500        let now_secs: i64 = SystemTime::now()
501            .duration_since(UNIX_EPOCH)
502            .map_or(0, |d| d.as_secs().cast_signed());
503
504        let mut activation = self.initialize_seeds(&seeds);
505        let mut activated_facts: Vec<ActivatedFact> = Vec::new();
506
507        for hop in 0..self.params.max_hops {
508            let active_nodes: Vec<(i64, f32)> = activation
509                .iter()
510                .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
511                .map(|(&id, &(score, _))| (id, score))
512                .collect();
513
514            if active_nodes.is_empty() {
515                break;
516            }
517
518            let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
519            let edges = store.edges_for_entities(&node_ids, edge_types).await?;
520            let edge_count = edges.len();
521
522            let next_activation =
523                self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
524
525            let pruned_count = self.merge_and_prune(&mut activation, next_activation);
526
527            tracing::debug!(
528                hop,
529                active_nodes = active_nodes.len(),
530                edges_fetched = edge_count,
531                after_merge = activation.len(),
532                pruned = pruned_count,
533                "spreading activation: hop complete"
534            );
535
536            self.collect_activated_facts(&edges, &activation, &mut activated_facts);
537        }
538
539        let result = self.finalize(activation);
540
541        tracing::info!(
542            activated = result.len(),
543            facts = activated_facts.len(),
544            "spreading activation: complete"
545        );
546
547        Ok((result, activated_facts))
548    }
549
550    /// Populate the activation map from seed scores, filtering seeds below threshold.
551    fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
552        let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
553        let mut seed_count = 0usize;
554        // Seeds bypass activation_threshold (they are query anchors per SYNAPSE semantics).
555        for (entity_id, match_score) in seeds {
556            if *match_score < self.params.activation_threshold {
557                tracing::debug!(
558                    entity_id,
559                    score = match_score,
560                    threshold = self.params.activation_threshold,
561                    "spreading activation: seed below threshold, skipping"
562                );
563                continue;
564            }
565            activation.insert(*entity_id, (*match_score, 0));
566            seed_count += 1;
567        }
568        tracing::debug!(
569            seeds = seed_count,
570            "spreading activation: initialized seeds"
571        );
572        activation
573    }
574
575    /// Compute the next-hop activation map by propagating through `edges`.
576    ///
577    /// Applies lateral inhibition (CRIT-02) and clamped multi-path convergence sums.
578    fn propagate_one_hop(
579        &self,
580        hop: u32,
581        active_nodes: &[(i64, f32)],
582        edges: &[Edge],
583        activation: &HashMap<i64, (f32, u32)>,
584        now_secs: i64,
585    ) -> HashMap<i64, (f32, u32)> {
586        let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
587
588        for edge in edges {
589            for &(active_id, node_score) in active_nodes {
590                let neighbor = if edge.source_entity_id == active_id {
591                    edge.target_entity_id
592                } else if edge.target_entity_id == active_id {
593                    edge.source_entity_id
594                } else {
595                    continue;
596                };
597
598                // Lateral inhibition: skip neighbor if it already has high activation
599                // in either the current map OR this hop's next_activation (CRIT-02 fix:
600                // checks both maps to match SYNAPSE paper semantics and prevent runaway
601                // activation when multiple paths converge in the same hop).
602                let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
603                let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
604                if current_score >= self.params.inhibition_threshold
605                    || next_score >= self.params.inhibition_threshold
606                {
607                    continue;
608                }
609
610                let recency = self.recency_weight(&edge.valid_from, now_secs);
611                // SYNAPSE blend: use Benna-Fusi fast/slow variables instead of raw confidence (#3709).
612                let blended = self.params.alpha * edge.confidence_fast
613                    + (1.0 - self.params.alpha) * edge.confidence_slow;
614                let edge_weight = evolved_weight(edge.retrieval_count, blended);
615                let type_w = edge_type_weight(edge.edge_type);
616                let spread_value =
617                    node_score * self.params.decay_lambda * edge_weight * recency * type_w;
618
619                if spread_value < self.params.activation_threshold {
620                    continue;
621                }
622
623                // Clamped sum preserves the multi-path convergence signal: nodes reachable
624                // via multiple paths receive proportionally higher activation (MAJOR-01).
625                let depth_at_max = hop + 1;
626                let entry = next_activation
627                    .entry(neighbor)
628                    .or_insert((0.0, depth_at_max));
629                let new_score = (entry.0 + spread_value).min(1.0);
630                if new_score > entry.0 {
631                    entry.0 = new_score;
632                    entry.1 = depth_at_max;
633                }
634            }
635        }
636
637        next_activation
638    }
639
640    /// Merge `next_activation` into `activation` and prune to `max_activated_nodes` (SA-INV-04).
641    ///
642    /// Returns the number of pruned nodes for tracing.
643    fn merge_and_prune(
644        &self,
645        activation: &mut HashMap<i64, (f32, u32)>,
646        next_activation: HashMap<i64, (f32, u32)>,
647    ) -> usize {
648        for (node_id, (new_score, new_depth)) in next_activation {
649            let entry = activation.entry(node_id).or_insert((0.0, new_depth));
650            if new_score > entry.0 {
651                entry.0 = new_score;
652                entry.1 = new_depth;
653            }
654        }
655
656        if activation.len() > self.params.max_activated_nodes {
657            let before = activation.len();
658            let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
659            entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
660            entries.truncate(self.params.max_activated_nodes);
661            *activation = entries.into_iter().collect();
662            before - self.params.max_activated_nodes
663        } else {
664            0
665        }
666    }
667
668    /// Append edges whose both endpoints are above threshold to `activated_facts`.
669    fn collect_activated_facts(
670        &self,
671        edges: &[Edge],
672        activation: &HashMap<i64, (f32, u32)>,
673        activated_facts: &mut Vec<ActivatedFact>,
674    ) {
675        for edge in edges {
676            let src_score = activation
677                .get(&edge.source_entity_id)
678                .map_or(0.0, |&(s, _)| s);
679            let tgt_score = activation
680                .get(&edge.target_entity_id)
681                .map_or(0.0, |&(s, _)| s);
682            if src_score >= self.params.activation_threshold
683                && tgt_score >= self.params.activation_threshold
684            {
685                let activation_score = src_score.max(tgt_score);
686                activated_facts.push(ActivatedFact {
687                    edge: edge.clone(),
688                    activation_score,
689                    is_implicit_conflict: false,
690                    conflict_candidate_id: None,
691                });
692            }
693        }
694    }
695
696    /// Collect nodes above threshold into `Vec<ActivatedNode>`, sorted descending by score.
697    fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
698        let mut result: Vec<ActivatedNode> = activation
699            .into_iter()
700            .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
701            .map(|(entity_id, (activation, depth))| ActivatedNode {
702                entity_id,
703                activation,
704                depth,
705            })
706            .collect();
707        result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
708        result
709    }
710
711    /// Compute temporal recency weight for an edge.
712    ///
713    /// Formula: `1.0 / (1.0 + age_days * temporal_decay_rate)`.
714    /// Returns `1.0` when `temporal_decay_rate = 0.0` (no temporal adjustment).
715    /// Reuses the same formula as `GraphFact::score_with_decay` (SA-INV-05).
716    #[allow(clippy::cast_precision_loss)]
717    fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
718        if self.params.temporal_decay_rate <= 0.0 {
719            return 1.0;
720        }
721        let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
722            return 1.0;
723        };
724        let age_secs = (now_secs - valid_from_secs).max(0);
725        let age_days = age_secs as f64 / 86_400.0;
726        let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
727        // cast f64 -> f32: safe, weight is in [0.0, 1.0]
728        #[allow(clippy::cast_possible_truncation)]
729        let w = weight as f32;
730        w
731    }
732}
733
/// Parse a `SQLite` `datetime('now')` string to Unix seconds.
///
/// Reads fixed byte offsets from `"YYYY-MM-DD HH:MM:SS"`, so ISO-8601
/// `"YYYY-MM-DDTHH:MM:SS"` works too. Separator characters are not validated.
/// Anything after the first 19 bytes is ignored — fractional seconds, `Z`, or a
/// `+HH:MM` offset — so the timestamp is always interpreted as UTC.
///
/// Returns `None` when:
/// - the string is shorter than 19 bytes;
/// - a field is not plain ASCII digits (signs are rejected);
/// - a field is out of range (month 1–12, day 1–31, hour 0–23, minute 0–59,
///   second 0–60).
///
/// Never panics, including on non-ASCII input.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // Parse the fixed-width decimal field at `range`. `str::get` yields `None`
    // (instead of panicking like `s[range]`) when the range splits a multi-byte
    // UTF-8 character; non-digit bytes such as `+`/`-` are rejected.
    fn field(s: &str, range: std::ops::Range<usize>) -> Option<i64> {
        let digits = s.get(range)?;
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    }

    if s.len() < 19 {
        return None;
    }
    let year = field(s, 0..4)?;
    let month = field(s, 5..7)?;
    let day = field(s, 8..10)?;
    let hour = field(s, 11..13)?;
    let min = field(s, 14..16)?;
    let sec = field(s, 17..19)?;

    // Out-of-range fields would otherwise yield a silently wrong timestamp.
    if !(1..=12).contains(&month) || !(1..=31).contains(&day) || hour > 23 || min > 59 || sec > 60
    {
        return None;
    }

    // Days since Unix epoch via civil calendar algorithm.
    // Reference: https://howardhinnant.github.io/date_algorithms.html#days_from_civil
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
765
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
            alpha: 0.3,
        }
    }

    // Test 1: empty graph (no edges) — seed entity is still returned as activated node,
    // but no facts (edges) are found. Spread does not validate entity existence in DB.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        // Seed node is returned as activated (activation=1.0, depth=0).
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        // No edges in empty graph, so no ActivatedFacts.
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    // Test 2: empty seeds returns empty
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    // Test 3: single seed with no edges returns only the seed
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    // Test 4: linear chain A->B->C with max_hops=3 — all activated, scores decay
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        // Scores must decay: score(A) > score(B) > score(C)
        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    // Test 5: linear chain with max_hops=1 — C not activated
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    // Test 6: diamond graph — D receives convergent activation from two paths
    // Graph: A -> B, A -> C, B -> D, C -> D
    // With clamped sum, D gets activation from both paths (convergence signal preserved).
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95; // raise inhibition to allow convergence
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        // D should be activated at depth 2
        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    // Test 7: inhibition threshold prevents runaway activation in dense cluster
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        // Create a hub node connected to many leaves
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            // Connect all leaves back to hub to create a dense cluster
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        // Seed hub with full activation — it should be inhibited after hop 1
        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        // However much activation flows back from the leaves, the hub's score must stay
        // bounded by 1.0 (no unbounded growth).
        let hub_node = nodes
            .iter()
            .find(|n| n.entity_id == hub)
            .expect("hub must be in results");
        assert!(
            hub_node.activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    // Test 8: max_activated_nodes cap — lowest activations pruned
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        // Create 20 leaf nodes connected to root
        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    // Test 9: temporal decay — recent edges produce higher activation
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap()
            .0;

        // Insert recent edge (default valid_from = now)
        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Insert old edge manually with a 1970 timestamp
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        // Use significant temporal decay rate to distinguish recent vs old
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    // Test 10: edge_type filtering — only edges of specified type are traversed
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        // Semantic edge from A
        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Causal edge from A (inserted with explicit edge_type)
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        // Spread with only semantic edges
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    // Test 11: large seed list (stress test for batch query)
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        // Create 100 seed entities — tests that edges_for_entities handles chunking correctly
        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        // Should complete without error even with 100 seeds (chunking handles SQLite limit)
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    // ── Temporal helpers ─────────────────────────────────────────────────────

    #[test]
    fn parse_sqlite_datetime_known_values() {
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"), Some(0));
        assert_eq!(parse_sqlite_datetime_to_unix("1969-12-31 23:59:59"), Some(-1));
        // 2000 is a leap year: Jan (31) + Feb (29) = 60 days after 2000-01-01.
        assert_eq!(
            parse_sqlite_datetime_to_unix("2000-03-01 00:00:00"),
            Some(951_868_800)
        );
        // Fractional seconds after the fixed-width prefix are ignored.
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-01-01 12:34:56.789"),
            Some(1_704_112_496)
        );
    }

    #[test]
    fn parse_sqlite_datetime_rejects_short_input() {
        assert_eq!(parse_sqlite_datetime_to_unix(""), None);
        assert_eq!(parse_sqlite_datetime_to_unix("2024-01-01"), None);
    }

    #[test]
    fn recency_weight_zero_rate_is_neutral() {
        // default_params() has temporal_decay_rate = 0.0 → no temporal adjustment.
        let sa = SpreadingActivation::new(default_params());
        let w = sa.recency_weight("1970-01-01 00:00:00", 1_000_000);
        assert!((w - 1.0).abs() < f32::EPSILON, "got {w}");
    }

    #[test]
    fn recency_weight_decays_with_age() {
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 1.0,
            ..default_params()
        });
        // One day old at rate 1.0 → 1 / (1 + 1) = 0.5.
        let one_day = sa.recency_weight("1970-01-01 00:00:00", 86_400);
        assert!((one_day - 0.5).abs() < 1e-6, "got {one_day}");
        // Timestamps in the future are clamped to age zero.
        let future = sa.recency_weight("1970-01-02 00:00:00", 0);
        assert!((future - 1.0).abs() < f32::EPSILON, "got {future}");
        // Unparseable timestamps leave the activation untouched.
        let garbage = sa.recency_weight("garbage", 86_400);
        assert!((garbage - 1.0).abs() < f32::EPSILON, "got {garbage}");
    }

    // ── HL-F5 unit tests ─────────────────────────────────────────────────────

    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        // Should not panic — denom is guarded by f32::EPSILON
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    #[test]
    fn hela_spread_params_default_embed_timeout_is_some() {
        let p = HelaSpreadParams::default();
        assert!(
            p.embed_timeout.is_some(),
            "default embed_timeout must be Some (5 s)"
        );
    }

    // Regression test for #4285: hela_spreading_recall must return
    // MemoryError::Timeout when the embed provider stalls beyond embed_timeout.
    #[tokio::test]
    async fn hela_spreading_recall_embed_timeout_returns_error() {
        use std::time::Duration;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::error::MemoryError;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        // Provider sleeps 500 ms; timeout is set to 50 ms → must fire.
        let mock = MockProvider::default().with_embed_delay(500);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: Some(Duration::from_millis(50)),
            ..Default::default()
        };

        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        assert!(
            matches!(result, Err(MemoryError::Timeout(_))),
            "expected Err(MemoryError::Timeout), got {result:?}"
        );
    }

    // When embed_timeout is None the embed call is not wrapped in a timeout. With a
    // zero-delay mock the call returns promptly and must never yield MemoryError::Timeout.
    #[tokio::test]
    async fn hela_spreading_recall_no_timeout_does_not_wrap() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        let mock = MockProvider::default().with_embed_delay(0);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: None,
            ..Default::default()
        };

        // The freshly created in-memory vector store holds no vectors, so the ANN search
        // is not expected to find anything.
        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        // NOTE(review): depending on the mock's embedding output the call may return
        // Ok(empty) or a non-Timeout error. Either outcome is acceptable; only a Timeout
        // would indicate the call was wrapped.
        assert!(
            !matches!(result, Err(crate::error::MemoryError::Timeout(_))),
            "embed_timeout: None must not produce a Timeout error, got {result:?}"
        );
    }

    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        // path_weight × cosine.max(0.0): negative cosine must contribute 0.0
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    #[test]
    fn hela_path_weight_multiplicative() {
        // Two-hop path with edge weights 0.8, 0.5 → path_weight = 0.4
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    #[test]
    fn hela_max_path_weight_on_multipath() {
        // When two paths reach the same node, keep the higher path_weight.
        let pw_a = 0.9_f32; // short direct path
        let pw_b = 0.3_f32; // longer indirect path
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        // Verify the formula used in hela_spreading_recall Step 4.
        assert!((expected - 0.6).abs() < 1e-5);
    }

    /// SYNAPSE alpha-blend: the traversal score must come from the blended weight
    /// `alpha*fast + (1-alpha)*slow`, not from the raw edge confidence.
    ///
    /// With diverged fast=0.7, slow=0.51, alpha=0.3 the expected blend is
    ///   0.3*0.7 + 0.7*0.51 = 0.21 + 0.357 = 0.567
    ///
    /// The blended value then also passes through `evolved_weight` and the spread itself,
    /// so this test only asserts that the resulting score clearly differs from the raw
    /// confidence 0.9. It does not check the exact blended value.
    #[tokio::test]
    async fn synapse_blend_uses_alpha_not_raw_confidence() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let gs = GraphStore::new(store.pool().clone()).with_benna_rates(0.5, 0.05);

        let alpha = 0.3_f32;
        let params = SpreadingActivationParams {
            alpha,
            ..default_params()
        };

        let src = gs
            .upsert_entity("Blend_src", "Blend_src", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let tgt = gs
            .upsert_entity("Blend_tgt", "Blend_tgt", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        // First insert — fast=slow=0.5.
        gs.insert_edge_typed(
            src,
            tgt,
            "knows",
            "Blend_src knows Blend_tgt",
            0.5,
            None,
            crate::graph::types::EdgeType::Semantic,
        )
        .await
        .unwrap();
        // Second insert — applies Benna-Fusi: fast≈0.7, slow≈0.51.
        gs.insert_edge_typed(
            src,
            tgt,
            "knows",
            "Blend_src knows Blend_tgt",
            0.9,
            None,
            crate::graph::types::EdgeType::Semantic,
        )
        .await
        .unwrap();

        let sa = SpreadingActivation::new(params);
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (_nodes, facts) = sa.spread(&gs, seeds, &[]).await.unwrap();

        // The blend formula is applied inside propagate_one_hop; the resulting edge weight
        // is evolved_weight(retrieval_count, blended). We verify that at least one fact is
        // returned with a score that is NOT equal to the raw confidence (0.9), confirming
        // the blend path was taken.
        assert!(!facts.is_empty(), "spread must return at least one fact");
        let raw_score = facts[0].activation_score;
        // blended = 0.3*0.7 + 0.7*0.51 ≈ 0.567 — different from raw confidence 0.9
        assert!(
            (raw_score - 0.9_f32).abs() > 0.05,
            "blend score {raw_score} must differ from raw confidence 0.9 — alpha blend not applied"
        );
    }
}