1use std::collections::VecDeque;
28
29use crate::graph::belief_revision::cosine_similarity;
30
/// Maximum number of turn embeddings retained by [`RpeRouter::push_embedding`]; once the
/// buffer holds this many, the oldest embedding is dropped before a new one is stored.
pub const RPE_EMBEDDING_BUFFER_SIZE: usize = 10;
33
/// Maximum number of entity names kept in the router's history; pushing a name while the
/// history is full evicts the oldest entry first.
const ENTITY_HISTORY_SIZE: usize = 200;
36
/// Outcome of one [`RpeRouter::compute`] call: the surprise score for a turn, its two
/// components, and the resulting routing decision.
#[derive(Debug, Clone, PartialEq)]
pub struct RpeSignal {
    /// Combined score, `0.5 * (1 - context_similarity) + 0.5 * entity_novelty`.
    /// Reported as `1.0` when extraction is forced (cold start or safety valve).
    pub rpe_score: f32,
    /// Highest cosine similarity between the turn and the buffered embeddings, floored at
    /// `0.0`. Reported as `0.0` when extraction is forced.
    pub context_similarity: f32,
    /// Fraction of the candidate entities that are absent from the entity history
    /// (`0.0` when there are no candidates). Reported as `1.0` when extraction is forced.
    pub entity_novelty: f32,
    /// `true` when the turn should be sent to graph extraction.
    pub should_extract: bool,
}
45
/// Stateful router that decides, turn by turn, whether graph extraction should run.
///
/// It keeps a bounded window of recent turn embeddings and entity names and counts how many
/// turns in a row were skipped.
#[derive(Debug)]
pub struct RpeRouter {
    /// Most recent turn embeddings, oldest first; bounded by `RPE_EMBEDDING_BUFFER_SIZE`.
    recent_embeddings: VecDeque<Vec<f32>>,
    /// Most recently pushed entity names, oldest first; bounded by `ENTITY_HISTORY_SIZE`.
    entity_history: VecDeque<String>,
    /// Number of consecutive `compute` calls that decided to skip extraction.
    consecutive_skips: u32,
    /// Minimum `rpe_score` at which a turn is extracted.
    pub threshold: f32,
    /// Once this many turns in a row were skipped, the next turn is extracted unconditionally.
    pub max_skip_turns: u32,
}
59
60impl RpeRouter {
61 #[must_use]
62 pub fn new(threshold: f32, max_skip_turns: u32) -> Self {
63 Self {
64 recent_embeddings: VecDeque::with_capacity(RPE_EMBEDDING_BUFFER_SIZE),
65 entity_history: VecDeque::with_capacity(ENTITY_HISTORY_SIZE),
66 consecutive_skips: 0,
67 threshold,
68 max_skip_turns,
69 }
70 }
71
72 pub fn push_embedding(&mut self, embedding: Vec<f32>) {
75 if self.recent_embeddings.len() >= RPE_EMBEDDING_BUFFER_SIZE {
76 self.recent_embeddings.pop_front();
77 }
78 self.recent_embeddings.push_back(embedding);
79 }
80
81 pub fn push_entities(&mut self, names: &[String]) {
83 for name in names {
84 if self.entity_history.len() >= ENTITY_HISTORY_SIZE {
85 self.entity_history.pop_front();
86 }
87 self.entity_history.push_back(name.clone());
88 }
89 }
90
91 #[must_use]
99 pub fn compute(&mut self, turn_embedding: &[f32], candidate_entities: &[String]) -> RpeSignal {
100 if self.consecutive_skips >= self.max_skip_turns {
102 tracing::debug!(
103 consecutive_skips = self.consecutive_skips,
104 "D-MEM RPE: safety valve triggered, forcing extraction"
105 );
106 self.consecutive_skips = 0;
107 return RpeSignal {
108 rpe_score: 1.0,
109 context_similarity: 0.0,
110 entity_novelty: 1.0,
111 should_extract: true,
112 };
113 }
114
115 if self.recent_embeddings.is_empty() {
117 return RpeSignal {
118 rpe_score: 1.0,
119 context_similarity: 0.0,
120 entity_novelty: 1.0,
121 should_extract: true,
122 };
123 }
124
125 let context_similarity = self
127 .recent_embeddings
128 .iter()
129 .map(|emb| cosine_similarity(turn_embedding, emb))
130 .fold(0.0f32, f32::max);
131
132 let entity_novelty = if candidate_entities.is_empty() {
134 0.0
135 } else {
136 let novel = candidate_entities
137 .iter()
138 .filter(|e| !self.entity_history.contains(e))
139 .count();
140 #[allow(clippy::cast_precision_loss)]
141 let ratio = novel as f32 / candidate_entities.len() as f32;
142 ratio
143 };
144
145 let rpe_score = 0.5 * (1.0 - context_similarity) + 0.5 * entity_novelty;
146 let should_extract = rpe_score >= self.threshold;
147
148 if should_extract {
149 self.consecutive_skips = 0;
150 } else {
151 self.consecutive_skips += 1;
152 tracing::debug!(
153 rpe_score,
154 context_similarity,
155 entity_novelty,
156 consecutive_skips = self.consecutive_skips,
157 "D-MEM RPE: low surprise, skipping graph extraction"
158 );
159 }
160
161 RpeSignal {
162 rpe_score,
163 context_similarity,
164 entity_novelty,
165 should_extract,
166 }
167 }
168}
169
/// Lowercase technology names that `extract_candidate_entities` recognises anywhere in the
/// text, independent of capitalisation.
const TECH_TERMS: &[&str] = &[
    "rust",
    "python",
    "go",
    "java",
    "kotlin",
    "swift",
    "ruby",
    "scala",
    "elixir",
    "haskell",
    "typescript",
    "javascript",
    "c",
    "c++",
    "cpp",
    "zig",
    "nim",
    "odin",
    "docker",
    "kubernetes",
    "k8s",
    "postgres",
    "sqlite",
    "redis",
    "kafka",
    "nginx",
    "linux",
    "macos",
    "windows",
    "android",
    "ios",
    "git",
    "cargo",
    "npm",
    "pip",
    "gradle",
    "cmake",
];

/// Returns `true` when `neighbour` cannot extend an identifier-like word: either there is no
/// byte at that position (edge of the text) or the byte is neither ASCII alphanumeric nor `_`.
fn is_term_boundary(neighbour: Option<&u8>) -> bool {
    match neighbour {
        Some(byte) => !(byte.is_ascii_alphanumeric() || *byte == b'_'),
        None => true,
    }
}

/// Extracts candidate entity names from free text using two cheap heuristics.
///
/// 1. **Capitalised words**: every whitespace-separated word is stripped down to its
///    alphanumeric, `_` and `-` characters. It is kept (lowercased) when it starts with an
///    uppercase letter, is at least 3 bytes long, is not an all-uppercase token of 5 bytes or
///    fewer, and does not open a sentence. A word opens a sentence when it is the first word
///    or follows a word ending in `.`, `!` or `?`.
/// 2. **Known tech terms**: each entry of `TECH_TERMS` found in the lowercased text is kept
///    when the bytes directly before and after the match are not ASCII alphanumerics or `_`.
///
/// The result contains no duplicates: capitalised words come first in order of first
/// appearance, followed by newly found tech terms in `TECH_TERMS` order.
#[must_use]
pub fn extract_candidate_entities(text: &str) -> Vec<String> {
    let mut candidates: Vec<String> = Vec::new();
    // Everything already emitted; gives O(1) de-duplication while `candidates` keeps order.
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    // Pass 1: capitalised words that are not at the start of a sentence.
    let mut at_sentence_start = true;
    for word in text.split_whitespace() {
        let opens_sentence = at_sentence_start;
        // Judged on the raw word, so trailing punctuation is still present here.
        at_sentence_start = matches!(word.chars().next_back(), Some('.' | '!' | '?'));
        if opens_sentence {
            continue;
        }

        let clean: String = word
            .chars()
            .filter(|c| c.is_alphanumeric() || matches!(*c, '_' | '-'))
            .collect();
        if clean.len() < 3 {
            continue;
        }
        // Short all-caps tokens are treated as acronyms and ignored.
        if clean.len() <= 5 && clean.chars().all(char::is_uppercase) {
            continue;
        }
        if clean.starts_with(char::is_uppercase) {
            let name = clean.to_lowercase();
            if seen.insert(name.clone()) {
                candidates.push(name);
            }
        }
    }

    // Pass 2: known technology names, matched case-insensitively on word boundaries.
    let lowered = text.to_lowercase();
    let bytes = lowered.as_bytes();
    for &term in TECH_TERMS {
        if seen.contains(term) {
            continue;
        }
        let mut from = 0;
        // Slicing at `from` is safe: it is either 0 or one past the first byte of a match,
        // and every term starts with a single-byte ASCII character.
        while let Some(offset) = lowered[from..].find(term) {
            let begin = from + offset;
            let before = begin.checked_sub(1).and_then(|i| bytes.get(i));
            let after = bytes.get(begin + term.len());
            if is_term_boundary(before) && is_term_boundary(after) {
                seen.insert(term.to_owned());
                candidates.push(term.to_owned());
                // Further occurrences of the same term cannot add anything.
                break;
            }
            from = begin + 1;
        }
    }

    candidates
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding of `len` components, all equal to `val`.
    fn make_embedding(val: f32, len: usize) -> Vec<f32> {
        vec![val; len]
    }

    /// With no buffered embeddings the router must force extraction with a score of 1.0.
    #[test]
    fn rpe_cold_start_returns_one() {
        let mut router = RpeRouter::new(0.3, 5);
        let emb = make_embedding(0.5, 4);
        let signal = router.compute(&emb, &[]);
        assert!(signal.should_extract);
        assert!((signal.rpe_score - 1.0).abs() < 1e-6);
    }

    /// An identical embedding plus an already-known entity yields a score below the threshold.
    #[test]
    fn rpe_high_similarity_low_novelty_skips() {
        let mut router = RpeRouter::new(0.3, 5);
        let emb = make_embedding(1.0, 4);
        router.push_embedding(emb.clone());
        router.push_entities(&["rust".to_string()]);

        let signal = router.compute(&emb, &["rust".to_string()]);
        assert!(!signal.should_extract, "low-RPE turn should be skipped");
        assert!(signal.rpe_score < 0.3);
    }

    /// An orthogonal embedding plus an unseen entity yields the maximum score.
    #[test]
    fn rpe_low_similarity_high_novelty_extracts() {
        let mut router = RpeRouter::new(0.3, 5);
        let prev = vec![1.0f32, 0.0, 0.0, 0.0];
        router.push_embedding(prev);

        let curr = vec![0.0f32, 1.0, 0.0, 0.0];
        let signal = router.compute(&curr, &["NewFramework".to_string()]);
        assert!(signal.should_extract);
        assert!((signal.rpe_score - 1.0).abs() < 1e-6);
    }

    /// Once the skip counter reaches `max_skip_turns`, extraction is forced and the counter
    /// is reset.
    #[test]
    fn rpe_max_skip_turns_forces_extraction() {
        let mut router = RpeRouter::new(0.3, 3);
        let emb = make_embedding(1.0, 4);
        router.push_embedding(emb.clone());
        router.push_entities(&["rust".to_string()]);

        router.consecutive_skips = 3;
        let signal = router.compute(&emb, &["rust".to_string()]);
        assert!(signal.should_extract, "safety valve must force extraction");
        assert_eq!(
            router.consecutive_skips, 0,
            "counter must reset after safety valve"
        );
    }

    /// Every skipped turn bumps the skip counter by one.
    #[test]
    fn rpe_consecutive_skips_increments() {
        let mut router = RpeRouter::new(0.9, 10);
        let emb = make_embedding(1.0, 4);
        router.push_embedding(emb.clone());
        router.push_entities(&["rust".to_string()]);

        // Entity novelty is 0 and context similarity is floored at 0, so the score cannot
        // exceed 0.5 and the 0.9 threshold is never met: the turn must be skipped. Asserting
        // this unconditionally keeps the counter check from passing vacuously.
        let first = router.compute(&emb, &["rust".to_string()]);
        assert!(!first.should_extract, "turn below threshold must be skipped");
        assert_eq!(router.consecutive_skips, 1);

        let second = router.compute(&emb, &["rust".to_string()]);
        assert!(!second.should_extract, "turn below threshold must be skipped");
        assert_eq!(router.consecutive_skips, 2);
    }

    /// Capitalised words in the middle of a sentence are reported in lowercase.
    #[test]
    fn extract_candidate_entities_captures_capitalized() {
        let text = "I use Tokio and Axum for async web development.";
        let entities = extract_candidate_entities(text);
        assert!(
            entities.contains(&"tokio".to_string()),
            "expected tokio, got {entities:?}"
        );
        assert!(
            entities.contains(&"axum".to_string()),
            "expected axum, got {entities:?}"
        );
    }

    /// Known tech terms are found even when written in lowercase.
    #[test]
    fn extract_candidate_entities_captures_tech_terms() {
        let text = "I write code in rust and use docker for deployment.";
        let entities = extract_candidate_entities(text);
        assert!(
            entities.contains(&"rust".to_string()),
            "expected rust, got {entities:?}"
        );
        assert!(
            entities.contains(&"docker".to_string()),
            "expected docker, got {entities:?}"
        );
    }

    /// Words that open a sentence are not treated as entities just because they are
    /// capitalised.
    #[test]
    fn extract_candidate_entities_ignores_sentence_start() {
        let text = "The project uses Rust. The team is growing.";
        let entities = extract_candidate_entities(text);
        assert!(!entities.contains(&"the".to_string()));
    }

    /// A term mentioned several times is reported once.
    #[test]
    fn extract_candidate_entities_no_duplicates() {
        let text = "I use rust and I love rust and rust is great.";
        let entities = extract_candidate_entities(text);
        let count = entities.iter().filter(|e| e.as_str() == "rust").count();
        assert_eq!(
            count, 1,
            "rust should appear exactly once, got {entities:?}"
        );
    }
}