Skip to main content

zeph_memory/graph/
rpe.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! D-MEM RPE-based tiered graph extraction routing.
//!
//! Computes a heuristic "reward prediction error" (RPE) signal for each incoming turn.
//! Low-RPE turns (predictable, topically continuous, no new entities) skip the expensive
//! MAGMA LLM extraction pipeline. High-RPE turns proceed to full extraction.
//!
//! ## RPE formula
//!
//! ```text
//! RPE = 0.5 * (1 - max_cosine_similarity) + 0.5 * entity_novelty_ratio
//! ```
//!
//! Where:
//! - `max_cosine_similarity` = max cosine similarity between current turn embedding and last N
//!   turn embeddings. High = topically predictable.
//! - `entity_novelty_ratio` = fraction of candidate entity names not seen in recent history.
//!   0.0 if no entities extracted.
//!
//! ## Safety valve
//!
//! To prevent unbounded skipping, `consecutive_skips` is tracked. When it reaches
//! `max_skip_turns`, extraction is forced regardless of RPE score.

27use std::collections::VecDeque;
28
29use zeph_common::math::cosine_similarity;
30
/// Maximum number of recent turn embeddings to keep for context similarity computation.
pub const RPE_EMBEDDING_BUFFER_SIZE: usize = 10;

/// Maximum number of recent entity names to keep in the novelty history.
const ENTITY_HISTORY_SIZE: usize = 200;
36
/// RPE computation result for a single turn.
///
/// Produced by `RpeRouter::compute`. When extraction is forced (cold start or safety valve)
/// the fields carry the fixed values `rpe_score = 1.0`, `context_similarity = 0.0`,
/// `entity_novelty = 1.0`, `should_extract = true`.
#[derive(Debug, Clone)]
pub struct RpeSignal {
    /// Combined surprise score: `0.5 * (1 - context_similarity) + 0.5 * entity_novelty`.
    pub rpe_score: f32,
    /// Highest cosine similarity between the current turn embedding and the buffered recent
    /// embeddings, floored at `0.0`.
    pub context_similarity: f32,
    /// Fraction of the turn's candidate entities that are absent from the entity history;
    /// `0.0` when the turn has no candidates.
    pub entity_novelty: f32,
    /// Whether the turn should go through graph extraction.
    pub should_extract: bool,
}
45
/// Stateful RPE router. Tracks recent embeddings and entity history.
///
/// Protected by the caller's synchronization (typically held behind `Arc<Mutex<...>>`
/// at the `SemanticMemory` layer).
pub struct RpeRouter {
    /// Embeddings of the most recent turns, oldest first. Bounded by
    /// `RPE_EMBEDDING_BUFFER_SIZE`; the oldest entry is evicted when a new one arrives.
    recent_embeddings: VecDeque<Vec<f32>>,
    /// Recently seen entity names, oldest first. Bounded by `ENTITY_HISTORY_SIZE`.
    entity_history: VecDeque<String>,
    /// Number of turns skipped in a row since the last extraction.
    consecutive_skips: u32,
    /// RPE below this value → skip extraction. Range: `[0.0, 1.0]`.
    pub threshold: f32,
    /// Force extraction after this many consecutive skips. Default: 5 (as documented; `new`
    /// itself applies no default and takes the value from the caller).
    pub max_skip_turns: u32,
}
59
60impl RpeRouter {
61    #[must_use]
62    pub fn new(threshold: f32, max_skip_turns: u32) -> Self {
63        Self {
64            recent_embeddings: VecDeque::with_capacity(RPE_EMBEDDING_BUFFER_SIZE),
65            entity_history: VecDeque::with_capacity(ENTITY_HISTORY_SIZE),
66            consecutive_skips: 0,
67            threshold,
68            max_skip_turns,
69        }
70    }
71
72    /// Record a turn embedding. Called even when extraction is skipped, so context similarity
73    /// remains up-to-date for the next turn.
74    pub fn push_embedding(&mut self, embedding: Vec<f32>) {
75        if self.recent_embeddings.len() >= RPE_EMBEDDING_BUFFER_SIZE {
76            self.recent_embeddings.pop_front();
77        }
78        self.recent_embeddings.push_back(embedding);
79    }
80
81    /// Record entity names extracted (or candidate names from text) for novelty tracking.
82    pub fn push_entities(&mut self, names: &[String]) {
83        for name in names {
84            if self.entity_history.len() >= ENTITY_HISTORY_SIZE {
85                self.entity_history.pop_front();
86            }
87            self.entity_history.push_back(name.clone());
88        }
89    }
90
91    /// Compute the RPE signal for the current turn.
92    ///
93    /// `turn_embedding` — embedding of the current message.
94    /// `candidate_entities` — entity names extracted from the current message text (may be empty).
95    ///
96    /// Returns the RPE signal. When `recent_embeddings` is empty (cold start), returns
97    /// `rpe_score = 1.0` and `should_extract = true`.
98    #[must_use]
99    pub fn compute(&mut self, turn_embedding: &[f32], candidate_entities: &[String]) -> RpeSignal {
100        // Safety valve: force extraction after max_skip_turns consecutive skips.
101        if self.consecutive_skips >= self.max_skip_turns {
102            tracing::debug!(
103                consecutive_skips = self.consecutive_skips,
104                "D-MEM RPE: safety valve triggered, forcing extraction"
105            );
106            self.consecutive_skips = 0;
107            return RpeSignal {
108                rpe_score: 1.0,
109                context_similarity: 0.0,
110                entity_novelty: 1.0,
111                should_extract: true,
112            };
113        }
114
115        // Cold start: no history yet → always extract.
116        if self.recent_embeddings.is_empty() {
117            return RpeSignal {
118                rpe_score: 1.0,
119                context_similarity: 0.0,
120                entity_novelty: 1.0,
121                should_extract: true,
122            };
123        }
124
125        // Context similarity: max cosine similarity to recent embeddings.
126        let context_similarity = self
127            .recent_embeddings
128            .iter()
129            .map(|emb| cosine_similarity(turn_embedding, emb))
130            .fold(0.0f32, f32::max);
131
132        // Entity novelty: fraction of candidate entities not in history.
133        let entity_novelty = if candidate_entities.is_empty() {
134            0.0
135        } else {
136            let novel = candidate_entities
137                .iter()
138                .filter(|e| !self.entity_history.contains(e))
139                .count();
140            #[allow(clippy::cast_precision_loss)]
141            let ratio = novel as f32 / candidate_entities.len() as f32;
142            ratio
143        };
144
145        let rpe_score = 0.5 * (1.0 - context_similarity) + 0.5 * entity_novelty;
146        let should_extract = rpe_score >= self.threshold;
147
148        if should_extract {
149            self.consecutive_skips = 0;
150        } else {
151            self.consecutive_skips += 1;
152            tracing::debug!(
153                rpe_score,
154                context_similarity,
155                entity_novelty,
156                consecutive_skips = self.consecutive_skips,
157                "D-MEM RPE: low surprise, skipping graph extraction"
158            );
159        }
160
161        RpeSignal {
162            rpe_score,
163            context_similarity,
164            entity_novelty,
165            should_extract,
166        }
167    }
168}
169
/// Lowercased known tech-domain terms that would be missed by the capitalization heuristic.
const TECH_TERMS: &[&str] = &[
    "rust",
    "python",
    "go",
    "java",
    "kotlin",
    "swift",
    "ruby",
    "scala",
    "elixir",
    "haskell",
    "typescript",
    "javascript",
    "c",
    "c++",
    "cpp",
    "zig",
    "nim",
    "odin",
    "docker",
    "kubernetes",
    "k8s",
    "postgres",
    "sqlite",
    "redis",
    "kafka",
    "nginx",
    "linux",
    "macos",
    "windows",
    "android",
    "ios",
    "git",
    "cargo",
    "npm",
    "pip",
    "gradle",
    "cmake",
];

/// Returns `true` when the byte adjacent to a term match does not continue a word.
///
/// A missing neighbor (start or end of the text) is a boundary; ASCII alphanumerics and `_`
/// are not.
fn is_term_boundary(neighbor: Option<u8>) -> bool {
    match neighbor {
        Some(byte) => !(byte.is_ascii_alphanumeric() || byte == b'_'),
        None => true,
    }
}

/// Returns `true` when `term` occurs in `haystack` with a word boundary on both sides.
fn contains_standalone_term(haystack: &str, term: &str) -> bool {
    // Advance by the width of the term's first char so the next search offset always lands
    // on a char boundary. An empty term can never be a standalone word.
    let Some(step) = term.chars().next().map(char::len_utf8) else {
        return false;
    };
    let bytes = haystack.as_bytes();
    let mut start = 0;
    while let Some(offset) = haystack[start..].find(term) {
        let begin = start + offset;
        let end = begin + term.len();
        if is_term_boundary(bytes[..begin].last().copied())
            && is_term_boundary(bytes.get(end).copied())
        {
            return true;
        }
        start = begin + step;
    }
    false
}

/// Extract candidate entity names from text using simple heuristics.
///
/// Captures capitalized tokens (at least 3 characters) that do NOT start a sentence, skipping
/// all-uppercase acronyms of up to 5 characters (API, HTTP, JSON). Lengths are counted in
/// characters, not bytes, so non-ASCII text is held to the same limits as ASCII text.
/// Also captures lowercase technical terms known to be common entity types (programming
/// languages, tools) when they appear as standalone words. This is intentionally cheap — no
/// LLM involved.
///
/// Returns lowercased names for comparison against stored canonical names, without
/// duplicates: capitalized tokens in order of first appearance, followed by matched tech
/// terms in `TECH_TERMS` order.
#[must_use]
pub fn extract_candidate_entities(text: &str) -> Vec<String> {
    let mut candidates: Vec<String> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    // Collect capitalized non-sentence-start words. The first word is always a sentence
    // start; after that, a word starts a sentence when the previous word ended with
    // terminal punctuation. This avoids capturing "The", "This", etc.
    let mut at_sentence_start = true;
    for word in text.split_whitespace() {
        let starts_sentence = at_sentence_start;
        at_sentence_start = matches!(word.chars().next_back(), Some('.' | '!' | '?'));
        if starts_sentence {
            continue;
        }

        let clean: String = word
            .chars()
            .filter(|c| c.is_alphanumeric() || *c == '_' || *c == '-')
            .collect();
        // Count characters rather than bytes: `String::len` would treat a two-letter
        // non-ASCII word as "long enough" and a short non-ASCII acronym as "too long".
        let char_count = clean.chars().count();
        if char_count < 3 {
            continue;
        }
        // Skip pure-uppercase acronyms (API, HTTP, JSON).
        if char_count <= 5 && clean.chars().all(char::is_uppercase) {
            continue;
        }
        if clean.chars().next().is_some_and(char::is_uppercase) {
            let lowered = clean.to_lowercase();
            if seen.insert(lowered.clone()) {
                candidates.push(lowered);
            }
        }
    }

    // Add tech-domain terms found in the text (case-insensitive, word-boundary check).
    let text_lower = text.to_lowercase();
    for term in TECH_TERMS {
        if !seen.contains(*term) && contains_standalone_term(&text_lower, term) {
            seen.insert((*term).to_owned());
            candidates.push((*term).to_owned());
        }
    }

    candidates
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a constant-valued embedding of the given length.
    fn make_embedding(val: f32, len: usize) -> Vec<f32> {
        vec![val; len]
    }

    #[test]
    fn rpe_cold_start_returns_one() {
        let mut router = RpeRouter::new(0.3, 5);
        let emb = make_embedding(0.5, 4);
        let signal = router.compute(&emb, &[]);
        assert!(signal.should_extract);
        assert!((signal.rpe_score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn rpe_high_similarity_low_novelty_skips() {
        let mut router = RpeRouter::new(0.3, 5);
        let emb = make_embedding(1.0, 4);
        // Seed history with identical embedding.
        router.push_embedding(emb.clone());
        router.push_entities(&["rust".to_string()]);

        // Turn with same embedding and known entity → RPE near 0.
        let signal = router.compute(&emb, &["rust".to_string()]);
        // context_similarity = 1.0, entity_novelty = 0.0 → RPE = 0.0
        assert!(!signal.should_extract, "low-RPE turn should be skipped");
        assert!(signal.rpe_score < 0.3);
    }

    #[test]
    fn rpe_low_similarity_high_novelty_extracts() {
        let mut router = RpeRouter::new(0.3, 5);
        let prev = vec![1.0f32, 0.0, 0.0, 0.0];
        router.push_embedding(prev);

        // Orthogonal embedding + all-new entities.
        let curr = vec![0.0f32, 1.0, 0.0, 0.0];
        let signal = router.compute(&curr, &["NewFramework".to_string()]);
        // context_similarity = 0.0, entity_novelty = 1.0 → RPE = 1.0
        assert!(signal.should_extract);
        assert!((signal.rpe_score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn rpe_max_skip_turns_forces_extraction() {
        let mut router = RpeRouter::new(0.3, 3);
        let emb = make_embedding(1.0, 4);
        router.push_embedding(emb.clone());
        router.push_entities(&["rust".to_string()]);

        // Force 3 skips.
        router.consecutive_skips = 3;
        let signal = router.compute(&emb, &["rust".to_string()]);
        assert!(signal.should_extract, "safety valve must force extraction");
        assert_eq!(
            router.consecutive_skips, 0,
            "counter must reset after safety valve"
        );
    }

    #[test]
    fn rpe_consecutive_skips_increments() {
        let mut router = RpeRouter::new(0.9, 10); // high threshold → easy to skip
        let emb = make_embedding(1.0, 4);
        router.push_embedding(emb.clone());
        router.push_entities(&["rust".to_string()]);

        // Identical embedding + known entity → RPE near 0, far below the 0.9 threshold.
        // Assert the skip unconditionally so the counter check can never pass vacuously.
        let first = router.compute(&emb, &["rust".to_string()]);
        assert!(!first.should_extract, "low-RPE turn should be skipped");
        assert_eq!(router.consecutive_skips, 1);

        let second = router.compute(&emb, &["rust".to_string()]);
        assert!(!second.should_extract, "low-RPE turn should be skipped");
        assert_eq!(
            router.consecutive_skips, 2,
            "counter must keep incrementing across consecutive skips"
        );
    }

    #[test]
    fn extract_candidate_entities_captures_capitalized() {
        let text = "I use Tokio and Axum for async web development.";
        let entities = extract_candidate_entities(text);
        // "Tokio" and "Axum" are mid-sentence capitalized.
        assert!(
            entities.contains(&"tokio".to_string()),
            "expected tokio, got {entities:?}"
        );
        assert!(
            entities.contains(&"axum".to_string()),
            "expected axum, got {entities:?}"
        );
    }

    #[test]
    fn extract_candidate_entities_captures_tech_terms() {
        let text = "I write code in rust and use docker for deployment.";
        let entities = extract_candidate_entities(text);
        assert!(
            entities.contains(&"rust".to_string()),
            "expected rust, got {entities:?}"
        );
        assert!(
            entities.contains(&"docker".to_string()),
            "expected docker, got {entities:?}"
        );
    }

    #[test]
    fn extract_candidate_entities_ignores_sentence_start() {
        let text = "The project uses Rust. The team is growing.";
        let entities = extract_candidate_entities(text);
        // "The" appears at sentence start and should not be captured.
        assert!(!entities.contains(&"the".to_string()));
    }

    #[test]
    fn extract_candidate_entities_no_duplicates() {
        let text = "I use rust and I love rust and rust is great.";
        let entities = extract_candidate_entities(text);
        let count = entities.iter().filter(|e| e.as_str() == "rust").count();
        assert_eq!(
            count, 1,
            "rust should appear exactly once, got {entities:?}"
        );
    }
}