Skip to main content

zeph_memory/graph/
activation.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! SYNAPSE spreading activation retrieval over the entity graph.
5//!
6//! Implements the spreading activation algorithm from arXiv 2601.02744, adapted for
7//! the zeph-memory graph schema. Seeds are matched via fuzzy entity search; activation
8//! propagates hop-by-hop with:
9//! - Exponential decay per hop (`decay_lambda`)
10//! - Edge confidence weighting
11//! - Temporal recency weighting (reuses `GraphConfig.temporal_decay_rate`)
12//! - Lateral inhibition (nodes above `inhibition_threshold` stop receiving activation)
13//! - Per-hop pruning to enforce `max_activated_nodes` bound (SA-INV-04)
14//! - MAGMA edge type filtering via `edge_types` parameter
15
16use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// Entity reached by SYNAPSE spreading activation.
///
/// Produced by [`SpreadingActivation::spread`]: one value per entity whose final
/// activation is at or above the configured activation threshold.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Primary key of the entity row in the graph store.
    pub entity_id: i64,
    /// Final activation score. Propagated activation is capped at `1.0`; seed
    /// scores are used exactly as supplied by the caller.
    pub activation: f32,
    /// Hop that delivered the node's best activation; seeds carry `0`.
    pub depth: u32,
}
37
38/// A graph edge traversed during spreading activation, with its activation score.
39#[derive(Debug, Clone)]
40pub struct ActivatedFact {
41    /// The traversed edge.
42    pub edge: Edge,
43    /// Activation score of the source or target entity at time of traversal.
44    pub activation_score: f32,
45}
46
/// Parameters for spreading activation. Mirrors `SpreadingActivationConfig` but lives
/// in `zeph-memory` so the crate does not depend on `zeph-config`.
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Multiplicative decay applied to the propagated activation on every hop.
    pub decay_lambda: f32,
    /// Maximum number of propagation hops starting from the seeds.
    pub max_hops: u32,
    /// Minimum activation for a seed to be used, for a propagated contribution to
    /// count, for a node to keep propagating, and for a node to be returned.
    pub activation_threshold: f32,
    /// Lateral inhibition: a node whose activation is at or above this value
    /// receives no further activation.
    pub inhibition_threshold: f32,
    /// Upper bound on the number of activated nodes, enforced after every hop
    /// by keeping only the highest-scoring ones (SA-INV-04).
    pub max_activated_nodes: usize,
    /// Recency decay rate per day of edge age, applied as
    /// `1 / (1 + age_days * rate)`; values `<= 0.0` disable recency weighting.
    /// Taken from `GraphConfig.temporal_decay_rate` (SA-INV-05).
    pub temporal_decay_rate: f64,
    /// Weight of structural score in hybrid seed ranking. Range: [0.0, 1.0]. Default: 0.4.
    pub seed_structural_weight: f32,
    /// Maximum seeds per community ID. 0 = unlimited. Default: 3.
    pub seed_community_cap: usize,
}
62
63// ── HL-F5: HeLa-Mem spreading activation (#3346) ─────────────────────────────
64
65/// A graph edge surfaced by HL-F5 spreading activation (#3346), scored by
66/// `path_weight × max(cosine_query_to_endpoint, 0.0)`.
67///
68/// Mirrors [`ActivatedFact`] so callers can dispatch over a single
69/// `Vec<HelaFact>` ↔ `Vec<ActivatedFact>` ↔ `Vec<GraphFact>` shape at the
70/// strategy-selection site.
71#[derive(Debug, Clone)]
72pub struct HelaFact {
73    /// The edge by which the higher-scored endpoint was reached.
74    pub edge: Edge,
75    /// Final HL-F5 score: `path_weight × cosine_clamped`. Range: `[0.0, +∞)`.
76    pub score: f32,
77    /// BFS depth at which `edge` was traversed (`1..=spread_depth`).
78    /// `0` is reserved for the synthetic anchor edge in the isolated-anchor fallback.
79    pub depth: u32,
80    /// Multiplicative product of edge weights along the BFS path that reached
81    /// this edge's far endpoint. Range: `[0.0, +∞)`.
82    pub path_weight: f32,
83    /// Clamped cosine similarity of the far endpoint's entity embedding
84    /// to the query embedding, in `[0.0, 1.0]`. `None` when the endpoint
85    /// has no stored embedding (skipped from results in that case).
86    pub cosine: Option<f32>,
87}
88
89/// Parameters for HL-F5 spreading activation retrieval.
90///
91/// Build via [`Default`] and override individual fields:
92///
93/// ```rust
94/// use zeph_memory::graph::activation::HelaSpreadParams;
95///
96/// let params = HelaSpreadParams { spread_depth: 3, ..Default::default() };
97/// ```
98#[derive(Debug, Clone)]
99pub struct HelaSpreadParams {
100    /// BFS hops. Clamped to `[1, 6]` at runtime. Default: `2`.
101    pub spread_depth: u32,
102    /// MAGMA edge-type filter. Empty = all types. Default: `[]`.
103    pub edge_types: Vec<EdgeType>,
104    /// Soft upper bound on the visited-node set. Default: `200`.
105    pub max_visited: usize,
106    /// Per-step circuit breaker. Any internal step (anchor ANN, edges batch,
107    /// vectors batch) that exceeds this duration triggers an `Ok(Vec::new())`
108    /// fallback with a `WARN`. Default: `Some(8 ms)`.
109    pub step_budget: Option<std::time::Duration>,
110}
111
112impl Default for HelaSpreadParams {
113    fn default() -> Self {
114        Self {
115            spread_depth: 2,
116            edge_types: Vec::new(),
117            max_visited: 200,
118            step_budget: Some(std::time::Duration::from_millis(8)),
119        }
120    }
121}
122
/// Process-global dim-mismatch sentinel for HL-F5. It holds the name of the
/// collection whose anchor ANN search failed with a dimension-related error.
///
/// Once set, [`hela_spreading_recall`] short-circuits with `Ok(Vec::new())` for
/// that collection for the rest of the process lifetime. The value is never
/// cleared, so re-provisioning the collection with a new dimension only takes
/// effect after a process restart. Because this is a `OnceLock`, only the first
/// mismatching collection is recorded; later `set` calls are ignored.
///
/// MINOR-1: the sentinel stores a collection name rather than a bare flag, so the
/// guard applies only to the collection that actually mismatched. The collection
/// used by `hela_spreading_recall` is currently a fixed constant. A test that
/// triggers a mismatch therefore disables HL-F5 for every later call in the same
/// test process.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
132
/// Cosine similarity between two vectors of equal length.
///
/// The denominator is floored at `f32::EPSILON`, so a zero-norm input yields
/// `0.0` instead of dividing by zero. Extra trailing elements of the longer
/// slice (if any) are ignored by the dot product.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot_product / (l2_norm(a) * l2_norm(b)).max(f32::EPSILON)
}
143
144/// HL-F5 BFS spreading activation from the top-1 ANN anchor node (#3346).
145///
146/// Algorithm overview:
147/// 1. Embed `query` → anchor via ANN search in the entity Qdrant collection.
148/// 2. BFS up to `params.spread_depth` hops, propagating multiplicative edge
149///    weights (`path_weight = Π edge.weight along path`). Multi-path convergence
150///    keeps the maximum `path_weight`.
151/// 3. Retrieve entity embeddings for all visited nodes via `get_points`.
152/// 4. Score each node: `score = path_weight × max(cosine(query, entity), 0.0)`.
153/// 5. Sort descending, truncate to `limit`, reinforce traversed edges via Hebbian
154///    update (when `hebbian_enabled`).
155///
156/// Fallback: when the anchor entity has no outgoing edges a single synthetic
157/// [`HelaFact`] with `edge.id == 0` and `score = anchor_cosine` is returned
158/// (the real ANN cosine, never a fabricated `1.0`).
159///
160/// Per-step circuit breaker: any individual step exceeding `params.step_budget`
161/// emits a `WARN` and returns `Ok(Vec::new())`.
162///
163/// Dim-mismatch resilience: a one-time dim probe on the first call guards against
164/// collection/provider configuration mismatches (#3382 pattern). Subsequent calls
165/// to a mismatched collection short-circuit immediately.
166///
167/// # Errors
168///
169/// Returns an error if the embed call or any database query fails.
170#[tracing::instrument(
171    name = "memory.graph.hela_spread",
172    skip_all,
173    fields(
174        depth = params.spread_depth,
175        limit,
176        anchor_id = tracing::field::Empty,
177        visited = tracing::field::Empty,
178        scored = tracing::field::Empty,
179        fallback = tracing::field::Empty,
180    )
181)]
182#[allow(clippy::too_many_arguments, clippy::too_many_lines)] // complex algorithm function; both suppressions justified until the function is decomposed in a future refactor
183pub async fn hela_spreading_recall(
184    store: &GraphStore,
185    embeddings: &EmbeddingStore,
186    provider: &zeph_llm::any::AnyProvider,
187    query: &str,
188    limit: usize,
189    params: &HelaSpreadParams,
190    hebbian_enabled: bool,
191    hebbian_lr: f32,
192) -> Result<Vec<HelaFact>, MemoryError> {
193    use zeph_llm::LlmProvider as _;
194
195    const ENTITY_COLLECTION: &str = "zeph_graph_entities";
196
197    if limit == 0 {
198        return Ok(Vec::new());
199    }
200
201    // ── Step 0: dim-mismatch guard ────────────────────────────────────────────
202    // MINOR-1: guard is keyed by collection name so re-provisioning recovers.
203    if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
204        tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
205        return Ok(Vec::new());
206    }
207
208    // ── Step 1: embed query ───────────────────────────────────────────────────
209    let q_vec = provider.embed(query).await?;
210
211    // Dim probe: search with k=1 to catch dimension mismatch at the Qdrant layer.
212    let t_anchor = Instant::now();
213    let anchor_results = match embeddings
214        .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
215        .await
216    {
217        Ok(r) => r,
218        Err(e) => {
219            let msg = e.to_string();
220            if msg.contains("wrong vector dimension")
221                || msg.contains("InvalidArgument")
222                || msg.contains("dimension")
223            {
224                let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
225                tracing::warn!(
226                    collection = ENTITY_COLLECTION,
227                    error = %e,
228                    "hela: vector dimension mismatch — HL-F5 disabled for this collection"
229                );
230                return Ok(Vec::new());
231            }
232            return Err(e);
233        }
234    };
235
236    if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
237        tracing::warn!(
238            elapsed_ms = t_anchor.elapsed().as_millis(),
239            "hela: anchor ANN over budget"
240        );
241        return Ok(Vec::new());
242    }
243
244    let Some(anchor_point) = anchor_results.first() else {
245        tracing::debug!("hela: no anchor found, returning empty");
246        return Ok(Vec::new());
247    };
248    let Some(anchor_entity_id) = anchor_point
249        .payload
250        .get("entity_id")
251        .and_then(serde_json::Value::as_i64)
252    else {
253        tracing::warn!("hela: anchor point missing entity_id payload");
254        return Ok(Vec::new());
255    };
256    let anchor_cosine = anchor_point.score;
257
258    tracing::Span::current().record("anchor_id", anchor_entity_id);
259    tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
260
261    let spread_depth = params.spread_depth.clamp(1, 6);
262
263    // ── Step 2: BFS with multiplicative path-weight propagation ──────────────
264    // `visited`: entity_id → (depth, path_weight, edge_id_via_which_we_arrived)
265    let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
266    visited.insert(anchor_entity_id, (0, 1.0, None));
267
268    // Dedup edges keyed by id for Step 4 lookup (avoids N clones per frontier).
269    // MINOR-3 resolution: collect edges into a HashMap<id, Edge> outside the
270    // per-source loop to avoid 10K clones on a hub × 50-entity frontier.
271    let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
272    let mut frontier: Vec<i64> = vec![anchor_entity_id];
273
274    for hop in 0..spread_depth {
275        if frontier.is_empty() {
276            break;
277        }
278
279        tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
280
281        let t_step = Instant::now();
282        let edges = store
283            .edges_for_entities(&frontier, &params.edge_types)
284            .await?;
285        if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
286            tracing::warn!(
287                hop,
288                elapsed_ms = t_step.elapsed().as_millis(),
289                "hela: edge-fetch over budget"
290            );
291            return Ok(Vec::new());
292        }
293
294        let mut next_frontier: Vec<i64> = Vec::new();
295
296        for edge in &edges {
297            // Cache by edge id to avoid repeated clones per source in frontier.
298            edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
299
300            for &src_id in &frontier {
301                let neighbor = if edge.source_entity_id == src_id {
302                    edge.target_entity_id
303                } else if edge.target_entity_id == src_id {
304                    edge.source_entity_id
305                } else {
306                    continue;
307                };
308
309                let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
310                let new_pw = parent_pw * edge.weight;
311
312                // Multi-path resolution: keep MAX path_weight; lower depth as
313                // tie-break. MINOR-4 note: max_visited is a soft bound — the
314                // actual visited set may exceed it by O(edges_per_hop_step) for
315                // one frontier step before the outer break fires.
316                let entry = visited
317                    .entry(neighbor)
318                    .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
319                // Prefer strictly higher path weight; break ties in favour of shallower depth.
320                if new_pw > entry.1
321                    || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
322                {
323                    *entry = (hop + 1, new_pw, Some(edge.id));
324                    if !next_frontier.contains(&neighbor) {
325                        next_frontier.push(neighbor);
326                    }
327                }
328
329                if visited.len() >= params.max_visited {
330                    break;
331                }
332            }
333
334            if visited.len() >= params.max_visited {
335                break;
336            }
337        }
338
339        tracing::debug!(
340            hop,
341            edges_fetched = edges.len(),
342            visited = visited.len(),
343            next_frontier = next_frontier.len(),
344            "hela: hop complete"
345        );
346
347        frontier = next_frontier;
348        if visited.len() >= params.max_visited {
349            break;
350        }
351    }
352
353    // ── Isolated-anchor fallback ──────────────────────────────────────────────
354    // `visited.len() == 1` means no edges were traversed from the anchor.
355    if visited.len() == 1 {
356        tracing::Span::current().record("fallback", true);
357        tracing::debug!(
358            anchor_entity_id,
359            anchor_cosine,
360            "hela: anchor isolated, falling back to pure ANN"
361        );
362        let fact = HelaFact {
363            edge: Edge::synthetic_anchor(anchor_entity_id),
364            score: anchor_cosine,
365            depth: 0,
366            path_weight: 1.0,
367            cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
368        };
369        return Ok(vec![fact]);
370    }
371
372    // ── Step 3: retrieve entity embeddings ───────────────────────────────────
373    let entity_ids: Vec<i64> = visited.keys().copied().collect();
374    let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
375    let point_ids: Vec<String> = point_id_map.values().cloned().collect();
376
377    let t_vec = Instant::now();
378    let vec_map = embeddings
379        .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
380        .await?;
381    if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
382        tracing::warn!(
383            elapsed_ms = t_vec.elapsed().as_millis(),
384            "hela: vectors-batch over budget"
385        );
386        return Ok(Vec::new());
387    }
388
389    // ── Step 4: score per visited node ────────────────────────────────────────
390    // Cosine clamped to [0.0, 1.0]: anti-correlated neighbors score 0.0 so
391    // they are ranked below positively-correlated ones.  A negative cosine on a
392    // strongly-reinforced edge would otherwise invert the retrieval signal.
393    let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
394    for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
395        if entity_id == anchor_entity_id {
396            continue;
397        }
398        let Some(edge_id) = edge_id_opt else {
399            continue;
400        };
401        let Some(point_id) = point_id_map.get(&entity_id) else {
402            continue;
403        };
404        let Some(node_vec) = vec_map.get(point_id) else {
405            continue;
406        };
407        if node_vec.len() != q_vec.len() {
408            // Per-node dim mismatch — skip (defense-in-depth for legacy collections).
409            continue;
410        }
411        let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
412        let fact_score = path_weight * cosine_clamped;
413        let Some(edge) = edge_cache.get(&edge_id).cloned() else {
414            continue;
415        };
416        facts.push(HelaFact {
417            edge,
418            score: fact_score,
419            depth,
420            path_weight,
421            cosine: Some(cosine_clamped),
422        });
423    }
424
425    // ── Step 5: sort, truncate, Hebbian increment ─────────────────────────────
426    facts.sort_by(|a, b| b.score.total_cmp(&a.score));
427    facts.truncate(limit);
428
429    // HL-F2 reinforcement on edges that survived truncation (kept ≈ used).
430    // Hebbian on "kept edges only" — consistent with graph_recall_activated at
431    // graph/retrieval.rs:427-433. Note: SYNAPSE reinforces all traversed edges;
432    // this PR intentionally reinforces only surfaced edges. See MINOR-5.
433    if hebbian_enabled {
434        let edge_ids: Vec<i64> = facts
435            .iter()
436            .map(|f| f.edge.id)
437            .filter(|&id| id != 0) // skip synthetic anchor
438            .collect();
439        if !edge_ids.is_empty()
440            && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
441        {
442            tracing::warn!(error = %e, "hela: hebbian increment failed");
443        }
444    }
445
446    tracing::Span::current().record("visited", visited.len());
447    tracing::Span::current().record("scored", facts.len());
448
449    Ok(facts)
450}
451
452// ── SYNAPSE spreading activation ──────────────────────────────────────────────
453
454/// Spreading activation engine parameterized from [`SpreadingActivationParams`].
455pub struct SpreadingActivation {
456    params: SpreadingActivationParams,
457}
458
459impl SpreadingActivation {
460    /// Create a new spreading activation engine from explicit parameters.
461    ///
462    /// `params.temporal_decay_rate` is taken from `GraphConfig.temporal_decay_rate` so that
463    /// recency weighting reuses the same parameter as BFS recall (SA-INV-05).
464    #[must_use]
465    pub fn new(params: SpreadingActivationParams) -> Self {
466        Self { params }
467    }
468
469    /// Run spreading activation from `seeds` over the graph.
470    ///
471    /// Returns activated nodes sorted by activation score descending, along with
472    /// edges collected during propagation.
473    ///
474    /// # Parameters
475    ///
476    /// - `store`: graph database accessor
477    /// - `seeds`: `HashMap<entity_id, initial_activation>` — nodes to start from
478    /// - `edge_types`: MAGMA subgraph filter; when non-empty, only edges of these types
479    ///   are traversed (mirrors `bfs_typed` behaviour; SA-INV-08)
480    ///
481    /// # Errors
482    ///
483    /// Returns an error if any database query fails.
484    pub async fn spread(
485        &self,
486        store: &GraphStore,
487        seeds: HashMap<i64, f32>,
488        edge_types: &[EdgeType],
489    ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
490        if seeds.is_empty() {
491            return Ok((Vec::new(), Vec::new()));
492        }
493
494        // Compute `now_secs` once for consistent temporal recency weighting
495        // across all edges (matches the pattern in retrieval.rs:83-86).
496        let now_secs: i64 = SystemTime::now()
497            .duration_since(UNIX_EPOCH)
498            .map_or(0, |d| d.as_secs().cast_signed());
499
500        let mut activation = self.initialize_seeds(&seeds);
501        let mut activated_facts: Vec<ActivatedFact> = Vec::new();
502
503        for hop in 0..self.params.max_hops {
504            let active_nodes: Vec<(i64, f32)> = activation
505                .iter()
506                .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
507                .map(|(&id, &(score, _))| (id, score))
508                .collect();
509
510            if active_nodes.is_empty() {
511                break;
512            }
513
514            let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
515            let edges = store.edges_for_entities(&node_ids, edge_types).await?;
516            let edge_count = edges.len();
517
518            let next_activation =
519                self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
520
521            let pruned_count = self.merge_and_prune(&mut activation, next_activation);
522
523            tracing::debug!(
524                hop,
525                active_nodes = active_nodes.len(),
526                edges_fetched = edge_count,
527                after_merge = activation.len(),
528                pruned = pruned_count,
529                "spreading activation: hop complete"
530            );
531
532            self.collect_activated_facts(&edges, &activation, &mut activated_facts);
533        }
534
535        let result = self.finalize(activation);
536
537        tracing::info!(
538            activated = result.len(),
539            facts = activated_facts.len(),
540            "spreading activation: complete"
541        );
542
543        Ok((result, activated_facts))
544    }
545
546    /// Populate the activation map from seed scores, filtering seeds below threshold.
547    fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
548        let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
549        let mut seed_count = 0usize;
550        // Seeds bypass activation_threshold (they are query anchors per SYNAPSE semantics).
551        for (entity_id, match_score) in seeds {
552            if *match_score < self.params.activation_threshold {
553                tracing::debug!(
554                    entity_id,
555                    score = match_score,
556                    threshold = self.params.activation_threshold,
557                    "spreading activation: seed below threshold, skipping"
558                );
559                continue;
560            }
561            activation.insert(*entity_id, (*match_score, 0));
562            seed_count += 1;
563        }
564        tracing::debug!(
565            seeds = seed_count,
566            "spreading activation: initialized seeds"
567        );
568        activation
569    }
570
571    /// Compute the next-hop activation map by propagating through `edges`.
572    ///
573    /// Applies lateral inhibition (CRIT-02) and clamped multi-path convergence sums.
574    fn propagate_one_hop(
575        &self,
576        hop: u32,
577        active_nodes: &[(i64, f32)],
578        edges: &[Edge],
579        activation: &HashMap<i64, (f32, u32)>,
580        now_secs: i64,
581    ) -> HashMap<i64, (f32, u32)> {
582        let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
583
584        for edge in edges {
585            for &(active_id, node_score) in active_nodes {
586                let neighbor = if edge.source_entity_id == active_id {
587                    edge.target_entity_id
588                } else if edge.target_entity_id == active_id {
589                    edge.source_entity_id
590                } else {
591                    continue;
592                };
593
594                // Lateral inhibition: skip neighbor if it already has high activation
595                // in either the current map OR this hop's next_activation (CRIT-02 fix:
596                // checks both maps to match SYNAPSE paper semantics and prevent runaway
597                // activation when multiple paths converge in the same hop).
598                let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
599                let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
600                if current_score >= self.params.inhibition_threshold
601                    || next_score >= self.params.inhibition_threshold
602                {
603                    continue;
604                }
605
606                let recency = self.recency_weight(&edge.valid_from, now_secs);
607                let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
608                let type_w = edge_type_weight(edge.edge_type);
609                let spread_value =
610                    node_score * self.params.decay_lambda * edge_weight * recency * type_w;
611
612                if spread_value < self.params.activation_threshold {
613                    continue;
614                }
615
616                // Clamped sum preserves the multi-path convergence signal: nodes reachable
617                // via multiple paths receive proportionally higher activation (MAJOR-01).
618                let depth_at_max = hop + 1;
619                let entry = next_activation
620                    .entry(neighbor)
621                    .or_insert((0.0, depth_at_max));
622                let new_score = (entry.0 + spread_value).min(1.0);
623                if new_score > entry.0 {
624                    entry.0 = new_score;
625                    entry.1 = depth_at_max;
626                }
627            }
628        }
629
630        next_activation
631    }
632
633    /// Merge `next_activation` into `activation` and prune to `max_activated_nodes` (SA-INV-04).
634    ///
635    /// Returns the number of pruned nodes for tracing.
636    fn merge_and_prune(
637        &self,
638        activation: &mut HashMap<i64, (f32, u32)>,
639        next_activation: HashMap<i64, (f32, u32)>,
640    ) -> usize {
641        for (node_id, (new_score, new_depth)) in next_activation {
642            let entry = activation.entry(node_id).or_insert((0.0, new_depth));
643            if new_score > entry.0 {
644                entry.0 = new_score;
645                entry.1 = new_depth;
646            }
647        }
648
649        if activation.len() > self.params.max_activated_nodes {
650            let before = activation.len();
651            let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
652            entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
653            entries.truncate(self.params.max_activated_nodes);
654            *activation = entries.into_iter().collect();
655            before - self.params.max_activated_nodes
656        } else {
657            0
658        }
659    }
660
661    /// Append edges whose both endpoints are above threshold to `activated_facts`.
662    fn collect_activated_facts(
663        &self,
664        edges: &[Edge],
665        activation: &HashMap<i64, (f32, u32)>,
666        activated_facts: &mut Vec<ActivatedFact>,
667    ) {
668        for edge in edges {
669            let src_score = activation
670                .get(&edge.source_entity_id)
671                .map_or(0.0, |&(s, _)| s);
672            let tgt_score = activation
673                .get(&edge.target_entity_id)
674                .map_or(0.0, |&(s, _)| s);
675            if src_score >= self.params.activation_threshold
676                && tgt_score >= self.params.activation_threshold
677            {
678                let activation_score = src_score.max(tgt_score);
679                activated_facts.push(ActivatedFact {
680                    edge: edge.clone(),
681                    activation_score,
682                });
683            }
684        }
685    }
686
687    /// Collect nodes above threshold into `Vec<ActivatedNode>`, sorted descending by score.
688    fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
689        let mut result: Vec<ActivatedNode> = activation
690            .into_iter()
691            .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
692            .map(|(entity_id, (activation, depth))| ActivatedNode {
693                entity_id,
694                activation,
695                depth,
696            })
697            .collect();
698        result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
699        result
700    }
701
702    /// Compute temporal recency weight for an edge.
703    ///
704    /// Formula: `1.0 / (1.0 + age_days * temporal_decay_rate)`.
705    /// Returns `1.0` when `temporal_decay_rate = 0.0` (no temporal adjustment).
706    /// Reuses the same formula as `GraphFact::score_with_decay` (SA-INV-05).
707    #[allow(clippy::cast_precision_loss)]
708    fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
709        if self.params.temporal_decay_rate <= 0.0 {
710            return 1.0;
711        }
712        let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
713            return 1.0;
714        };
715        let age_secs = (now_secs - valid_from_secs).max(0);
716        let age_days = age_secs as f64 / 86_400.0;
717        let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
718        // cast f64 -> f32: safe, weight is in [0.0, 1.0]
719        #[allow(clippy::cast_possible_truncation)]
720        let w = weight as f32;
721        w
722    }
723}
724
/// Parse a `SQLite` `datetime('now')` string to Unix seconds.
///
/// Accepts `"YYYY-MM-DD HH:MM:SS"`, also with the ISO-8601 `T` date/time
/// separator. Any trailing suffix (fractional seconds, `Z`, a UTC offset, ...) is
/// accepted but **ignored**: the timestamp is always interpreted as UTC, which is
/// what `SQLite`'s `datetime('now')` produces.
///
/// Parsing works on raw bytes, so non-ASCII input can never cause a slicing
/// panic. Returns `None` in any of these cases:
/// - the string is shorter than 19 bytes;
/// - a separator is wrong;
/// - a numeric field contains anything but ASCII digits (signs included);
/// - a component is out of range (month `1..=12`, day `1..=31`, hour `0..=23`,
///   minute `0..=59`, second `0..=60`).
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // Parse an all-ASCII-digit field; rejects signs, spaces and non-ASCII bytes.
    fn digits(bytes: &[u8], range: std::ops::Range<usize>) -> Option<i64> {
        bytes.get(range)?.iter().try_fold(0_i64, |acc, &b| {
            b.is_ascii_digit().then(|| acc * 10 + i64::from(b - b'0'))
        })
    }

    let bytes = s.as_bytes();
    if bytes.len() < 19
        || bytes[4] != b'-'
        || bytes[7] != b'-'
        || !matches!(bytes[10], b' ' | b'T')
        || bytes[13] != b':'
        || bytes[16] != b':'
    {
        return None;
    }

    let year = digits(bytes, 0..4)?;
    let month = digits(bytes, 5..7)?;
    let day = digits(bytes, 8..10)?;
    let hour = digits(bytes, 11..13)?;
    let min = digits(bytes, 14..16)?;
    let sec = digits(bytes, 17..19)?;

    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || !(0..=23).contains(&hour)
        || !(0..=59).contains(&min)
        || !(0..=60).contains(&sec)
    {
        return None;
    }

    // Days since Unix epoch via civil calendar algorithm.
    // Reference: https://howardhinnant.github.io/date_algorithms.html#days_from_civil
    // Shift the year so it starts in March, putting the leap day at the end.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
756
// Tests for `SpreadingActivation::spread` (run against a `GraphStore` over an in-memory
// `SqliteStore`) plus the HL-F5 helpers (`cosine`, `HelaSpreadParams`, `Edge::synthetic_anchor`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Build a `GraphStore` over a freshly opened `SqliteStore` at `":memory:"`.
    ///
    /// Each test calls this itself. Because this is a test fixture, a failure to open the
    /// store panics via `unwrap()`.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline spreading-activation parameters shared by the tests below.
    ///
    /// Individual tests override fields, either by mutation or with `..default_params()`.
    /// `temporal_decay_rate: 0.0` makes the recency weight `1 / (1 + age_days * rate)` exactly
    /// 1.0, so edge age has no effect unless a test sets a positive rate (see Test 9).
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    // Test 1: empty graph (no edges) — seed entity is still returned as activated node,
    // but no facts (edges) are found. Spread does not validate entity existence in DB.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        // Entity id 1 was never inserted; the seed map alone drives activation.
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        // Seed node is returned as activated (activation=1.0, depth=0).
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        // No edges in empty graph, so no ActivatedFacts.
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    // Test 2: empty seeds returns empty
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    // Test 3: single seed with no edges returns only the seed
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        // Only the seed, at hop 0, with its initial activation unchanged.
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    // Test 4: linear chain A->B->C with max_hops=3 — all activated, scores decay
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        // Scores must decay: score(A) > score(B) > score(C)
        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    // Test 5: linear chain with max_hops=1 — C not activated
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    // Test 6: diamond graph — D receives convergent activation from two paths
    // Graph: A -> B, A -> C, B -> D, C -> D
    // With clamped sum, D gets activation from both paths (convergence signal preserved).
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95; // raise inhibition to allow convergence
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        // D should be activated at depth 2
        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
        // NOTE(review): only reachability and depth are asserted; the claim that D's
        // activation reflects *both* paths is not checked numerically.
    }

    // Test 7: inhibition threshold prevents runaway activation in dense cluster
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        // Create a hub node connected to many leaves
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            // Connect all leaves back to hub to create a dense cluster
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        // Seed hub with full activation — it should be inhibited after hop 1
        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        // Hub should remain at initial activation (1.0), not grow unbounded
        // NOTE(review): only the upper bound (<= 1.0) is asserted below; a regression that
        // *lowers* the hub's activation, or one where the leaves overshoot, would still pass.
        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    // Test 8: max_activated_nodes cap — lowest activations pruned
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();

        // Create 20 leaf nodes connected to root
        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        // 21 reachable nodes (root + 20 leaves) against a cap of 5.
        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        // Only the size bound is asserted, not which nodes survive pruning.
        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    // Test 9: temporal decay — recent edges produce higher activation
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();

        // Insert recent edge (default valid_from = now)
        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Insert old edge manually with a 1970 timestamp
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        // Use significant temporal decay rate to distinguish recent vs old
        // (with rate 0.5 the weight 1 / (1 + age_days * 0.5) of a 1970 edge is close to 0).
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        // A node missing from the result (e.g. fell below the activation threshold) scores 0.0.
        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    // Test 10: edge_type filtering — only edges of specified type are traversed
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();

        // Semantic edge from A
        // NOTE(review): relies on `insert_edge` storing a semantic edge_type by default — confirm.
        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Causal edge from A (inserted with explicit edge_type)
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        // Spread with only semantic edges
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    // Test 11: large seed list (stress test for batch query)
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        // Create 100 seed entities — tests that edges_for_entities handles chunking correctly
        // (the seeds have no edges; the test only checks that the batched lookup succeeds).
        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        // Should complete without error even with 100 seeds (chunking handles SQLite limit)
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    // ── HL-F5 unit tests ─────────────────────────────────────────────────────
    // Synchronous tests for the `cosine` helper, `HelaSpreadParams::default()` and
    // `Edge::synthetic_anchor`, followed by tests that restate the HL-F5 scoring arithmetic.

    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        // Should not panic — denom is guarded by f32::EPSILON
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    // Pins the `HelaSpreadParams::default()` values.
    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        // The anchor is a self-loop on the given entity.
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    // NOTE(review): only `cosine` is production code here; the clamping and multiplication
    // are re-done locally, so a regression in the real scoring path would not be caught.
    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        // path_weight × cosine.max(0.0): negative cosine must contribute 0.0
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    // NOTE(review): exercises only local f32 arithmetic, not the production path-weight
    // computation — TODO call the real function once it is exposed to tests.
    #[test]
    fn hela_path_weight_multiplicative() {
        // Two-hop path with edge weights 0.8, 0.5 → path_weight = 0.4
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    // NOTE(review): documents the intended multi-path rule (keep the max) but only calls
    // `f32::max` on local constants; it does not exercise the traversal code.
    #[test]
    fn hela_max_path_weight_on_multipath() {
        // When two paths reach the same node, keep the higher path_weight.
        let pw_a = 0.9_f32; // short direct path
        let pw_b = 0.3_f32; // longer indirect path
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    // NOTE(review): despite the comment below, `hela_spreading_recall` is not called; this
    // only checks that 0.8 × 0.75 ≈ 0.6.
    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        // Verify the formula used in hela_spreading_recall Step 4.
        assert!((expected - 0.6).abs() < 1e-5);
    }
}