shodh_memory/memory/injection.rs
1//! Proactive Context Injection System
2//!
3//! Implements truly proactive memory injection - surfacing relevant memories
4//! without explicit agent action. Based on multi-signal relevance scoring
5//! with user-adaptive thresholds.
6//!
7//! # Enhanced Relevance Model (MEMO-1)
8//!
9//! ```text
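//! R(m, c) = α·semantic + β·recency + γ·strength + δ·entity_overlap + ε·type_boost
//!         + ζ·file_match + θ·episode_coherence + ι·graph_activation + κ·linguistic_score
//!         - η·suppression
//! ```
//!
//! Where (the Greek coefficients are the fields of `RelevanceWeights`):
//! - semantic: cosine similarity between memory and context embeddings
//! - recency: exponential decay based on memory age
//! - strength: Hebbian edge weight from knowledge graph
//! - entity_overlap: Jaccard similarity of entities between memory and context
//! - type_boost: Weight bonus based on memory type (Decision > Learning > Context)
//! - file_match: Boost when memory mentions files in current context
//! - episode_coherence: Boost for memories from the same episode as the query
//! - graph_activation: Activation level from spreading-activation graph traversal
//! - linguistic_score: Query-analysis score (focal entity matches, modifier matches)
//! - suppression: Penalty for memories with negative feedback momentum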
21//!
22//! # Feedback Loop
23//!
24//! The system learns from implicit feedback:
25//! - Positive: injected memory referenced in next turn
26//! - Negative: user indicates irrelevance
27//! - Neutral: memory ignored (no adjustment)
28
29use chrono::{DateTime, Utc};
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::time::Instant;
33
34use super::types::MemoryId;
35
36// =============================================================================
37// CONFIGURATION
38// =============================================================================
39
/// Weights for composite relevance scoring.
///
/// Each field scales one signal of the composite relevance score. The first
/// three fields (`semantic`, `recency`, `strength`) form the original 3-signal
/// model and have no `#[serde(default)]`, so they must be present when
/// deserializing. Every later field (MEMO-1 and follow-ups) carries
/// `#[serde(default)]`, so configs serialized before those fields existed
/// deserialize with the new weights set to 0.0 (signal disabled).
///
/// In the module-level formula, `suppression` is subtracted (a penalty), while
/// every other weight scales a bonus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceWeights {
    /// Weight for semantic similarity (cosine similarity of embeddings)
    pub semantic: f32,
    /// Weight for recency (exponential decay)
    pub recency: f32,
    /// Weight for Hebbian strength from graph
    pub strength: f32,
    /// Weight for entity overlap between memory and context (MEMO-1)
    #[serde(default)]
    pub entity_overlap: f32,
    /// Weight for memory type boost (Decision > Learning > Context) (MEMO-1)
    #[serde(default)]
    pub type_boost: f32,
    /// Weight for file path matching (MEMO-1)
    #[serde(default)]
    pub file_match: f32,
    /// Weight for negative feedback suppression (MEMO-1); applied as a penalty
    #[serde(default)]
    pub suppression: f32,
    /// Weight for episode coherence boost (SHO-temporal)
    /// Memories from the same episode as the query get boosted
    #[serde(default)]
    pub episode_coherence: f32,
    /// Weight for graph activation from spreading activation traversal
    /// Higher activation = stronger association in knowledge graph
    #[serde(default)]
    pub graph_activation: f32,
    /// Weight for linguistic score from query analysis
    /// Focal entity matches, modifier matches, etc.
    #[serde(default)]
    pub linguistic_score: f32,
}
76}
77
78impl Default for RelevanceWeights {
79 fn default() -> Self {
80 Self {
81 semantic: 0.40, // Primary signal - semantic similarity
82 recency: 0.08, // Recent memories get boost
83 strength: 0.08, // Hebbian edge weight (from graph)
84 entity_overlap: 0.08, // Entity Jaccard similarity
85 type_boost: 0.06, // Decision/Learning type boost
86 file_match: 0.04, // File path matching
87 suppression: 0.02, // Negative feedback penalty
88 episode_coherence: 0.06, // Same-episode boost (prevents bleeding)
89 graph_activation: 0.10, // Spreading activation from graph traversal
90 linguistic_score: 0.08, // Query analysis (focal entities, modifiers)
91 }
92 }
93}
94
95impl RelevanceWeights {
96 /// Legacy weights for backwards compatibility (original 3-signal model)
97 pub fn legacy() -> Self {
98 Self {
99 semantic: 0.5,
100 recency: 0.3,
101 strength: 0.2,
102 entity_overlap: 0.0,
103 type_boost: 0.0,
104 file_match: 0.0,
105 suppression: 0.0,
106 episode_coherence: 0.0,
107 graph_activation: 0.0,
108 linguistic_score: 0.0,
109 }
110 }
111}
112
113/// Configuration for proactive injection behavior
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct InjectionConfig {
116 /// Minimum relevance score to trigger injection (0.0 - 1.0)
117 pub min_relevance: f32,
118
119 /// Maximum memories to inject per message
120 pub max_per_message: usize,
121
122 /// Cooldown in seconds before re-injecting same memory
123 pub cooldown_seconds: u64,
124
125 /// Weights for relevance score components
126 pub weights: RelevanceWeights,
127
128 /// Decay rate for recency calculation (λ in e^(-λt))
129 /// Higher = faster decay. Default 0.01 means ~50% at 70 hours
130 pub recency_decay_rate: f32,
131}
132
133impl Default for InjectionConfig {
134 fn default() -> Self {
135 Self {
136 min_relevance: 0.50, // Raised from 0.35 - require stronger semantic match
137 max_per_message: 3,
138 cooldown_seconds: 180,
139 weights: RelevanceWeights::default(),
140 recency_decay_rate: 0.01,
141 }
142 }
143}
144
145impl InjectionConfig {
146 /// Legacy config for backwards compatibility
147 pub fn legacy() -> Self {
148 Self {
149 min_relevance: 0.70,
150 max_per_message: 3,
151 cooldown_seconds: 180,
152 weights: RelevanceWeights::legacy(),
153 recency_decay_rate: 0.01,
154 }
155 }
156}
157
158// Note: RelevanceInput and compute_relevance removed - using unified 5-layer pipeline
159
160// =============================================================================
161// INJECTION ENGINE
162// =============================================================================
163
164/// Candidate memory for injection with computed relevance
165#[derive(Debug, Clone)]
166pub struct InjectionCandidate {
167 pub memory_id: MemoryId,
168 pub relevance_score: f32,
169}
170
171/// Engine that decides which memories to inject
172pub struct InjectionEngine {
173 config: InjectionConfig,
174 /// Tracks last injection time per memory for cooldown
175 cooldowns: HashMap<MemoryId, Instant>,
176}
177
178impl InjectionEngine {
179 pub fn new(config: InjectionConfig) -> Self {
180 Self {
181 config,
182 cooldowns: HashMap::new(),
183 }
184 }
185
186 pub fn with_default_config() -> Self {
187 Self::new(InjectionConfig::default())
188 }
189
190 /// Check if a memory is on cooldown
191 fn on_cooldown(&self, memory_id: &MemoryId) -> bool {
192 if let Some(last) = self.cooldowns.get(memory_id) {
193 last.elapsed().as_secs() < self.config.cooldown_seconds
194 } else {
195 false
196 }
197 }
198
199 /// Select memories for injection from candidates
200 ///
201 /// Filters by:
202 /// 1. Minimum relevance threshold
203 /// 2. Cooldown (recently injected memories excluded)
204 /// 3. Max count limit
205 ///
206 /// Returns memory IDs sorted by relevance (highest first)
207 pub fn select_for_injection(
208 &mut self,
209 mut candidates: Vec<InjectionCandidate>,
210 ) -> Vec<MemoryId> {
211 // Sort by relevance descending
212 candidates.sort_by(|a, b| b.relevance_score.total_cmp(&a.relevance_score));
213
214 let selected: Vec<MemoryId> = candidates
215 .into_iter()
216 .filter(|c| {
217 c.relevance_score >= self.config.min_relevance && !self.on_cooldown(&c.memory_id)
218 })
219 .take(self.config.max_per_message)
220 .map(|c| c.memory_id)
221 .collect();
222
223 // Record injection time for cooldown
224 let now = Instant::now();
225 for id in &selected {
226 self.cooldowns.insert(id.clone(), now);
227 }
228
229 selected
230 }
231
232 /// Clear expired cooldowns to prevent memory leak
233 pub fn cleanup_cooldowns(&mut self) {
234 let threshold = self.config.cooldown_seconds;
235 self.cooldowns
236 .retain(|_, last| last.elapsed().as_secs() < threshold * 2);
237 }
238
239 /// Get current configuration
240 pub fn config(&self) -> &InjectionConfig {
241 &self.config
242 }
243
244 /// Update configuration
245 pub fn set_config(&mut self, config: InjectionConfig) {
246 self.config = config;
247 }
248}
249
250// =============================================================================
251// INJECTION TRACKING (for feedback loop)
252// =============================================================================
253
254/// Record of an injection for feedback tracking
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct InjectionRecord {
257 pub memory_id: MemoryId,
258 pub injected_at: DateTime<Utc>,
259 pub relevance_score: f32,
260 pub context_signature: u64,
261}
262
263/// Tracks injections for feedback learning
264#[derive(Debug, Default)]
265pub struct InjectionTracker {
266 /// Recent injections awaiting feedback
267 pending: Vec<InjectionRecord>,
268 /// Max pending records to keep
269 max_pending: usize,
270}
271
272impl InjectionTracker {
273 pub fn new(max_pending: usize) -> Self {
274 Self {
275 pending: Vec::new(),
276 max_pending,
277 }
278 }
279
280 /// Record a new injection
281 pub fn record_injection(
282 &mut self,
283 memory_id: MemoryId,
284 relevance_score: f32,
285 context_signature: u64,
286 ) {
287 let record = InjectionRecord {
288 memory_id,
289 injected_at: Utc::now(),
290 relevance_score,
291 context_signature,
292 };
293
294 self.pending.push(record);
295
296 // Trim old records
297 if self.pending.len() > self.max_pending {
298 self.pending.remove(0);
299 }
300 }
301
302 /// Get pending injections for feedback analysis
303 pub fn pending_injections(&self) -> &[InjectionRecord] {
304 &self.pending
305 }
306
307 /// Clear injections older than given duration
308 pub fn clear_old(&mut self, max_age_seconds: i64) {
309 let cutoff = Utc::now() - chrono::Duration::seconds(max_age_seconds);
310 self.pending.retain(|r| r.injected_at > cutoff);
311 }
312
313 /// Remove specific injection after feedback processed
314 pub fn mark_processed(&mut self, memory_id: &MemoryId) {
315 self.pending.retain(|r| &r.memory_id != memory_id);
316 }
317}
318
319// =============================================================================
320// USER PROFILE (adaptive thresholds)
321// =============================================================================
322
323/// Feedback signal type for learning
324#[derive(Debug, Clone, Copy, PartialEq, Eq)]
325pub enum FeedbackSignal {
326 /// Memory was referenced/used - lower threshold
327 Positive,
328 /// Memory was explicitly rejected - raise threshold
329 Negative,
330 /// Memory was ignored - no change
331 Neutral,
332}
333
334/// Per-user adaptive injection profile
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct UserInjectionProfile {
337 pub user_id: String,
338 /// Effective threshold (starts at default, adapts over time)
339 pub effective_threshold: f32,
340 /// Count of positive signals received
341 pub positive_signals: u32,
342 /// Count of negative signals received
343 pub negative_signals: u32,
344 /// Last update timestamp
345 pub updated_at: DateTime<Utc>,
346}
347
348impl UserInjectionProfile {
349 pub fn new(user_id: String) -> Self {
350 Self {
351 user_id,
352 effective_threshold: InjectionConfig::default().min_relevance,
353 positive_signals: 0,
354 negative_signals: 0,
355 updated_at: Utc::now(),
356 }
357 }
358
359 /// Adjust threshold based on feedback signal
360 ///
361 /// - Positive: lower threshold by 0.01 (min 0.50)
362 /// - Negative: raise threshold by 0.02 (max 0.90)
363 /// - Neutral: no change
364 ///
365 /// Asymmetric adjustment: we're more cautious about noise
366 pub fn adjust(&mut self, signal: FeedbackSignal) {
367 match signal {
368 FeedbackSignal::Positive => {
369 self.positive_signals += 1;
370 self.effective_threshold = (self.effective_threshold - 0.01).max(0.50);
371 }
372 FeedbackSignal::Negative => {
373 self.negative_signals += 1;
374 self.effective_threshold = (self.effective_threshold + 0.02).min(0.90);
375 }
376 FeedbackSignal::Neutral => {}
377 }
378 self.updated_at = Utc::now();
379 }
380
381 /// Get signal ratio (positive / total)
382 pub fn signal_ratio(&self) -> f32 {
383 let total = self.positive_signals + self.negative_signals;
384 if total == 0 {
385 0.5 // No data yet
386 } else {
387 self.positive_signals as f32 / total as f32
388 }
389 }
390}
391
392// =============================================================================
393// TESTS
394// =============================================================================
395
396#[cfg(test)]
397mod tests {
398 use super::*;
399 use uuid::Uuid;
400
401 #[test]
402 fn test_injection_engine_filtering() {
403 let mut engine = InjectionEngine::with_default_config();
404
405 let candidates = vec![
406 InjectionCandidate {
407 memory_id: MemoryId(Uuid::new_v4()),
408 relevance_score: 0.85,
409 },
410 InjectionCandidate {
411 memory_id: MemoryId(Uuid::new_v4()),
412 relevance_score: 0.45, // Below threshold (0.50)
413 },
414 InjectionCandidate {
415 memory_id: MemoryId(Uuid::new_v4()),
416 relevance_score: 0.75,
417 },
418 ];
419
420 let selected = engine.select_for_injection(candidates);
421
422 assert_eq!(selected.len(), 2); // Only 0.85 and 0.75 pass threshold (0.50)
423 }
424
425 #[test]
426 fn test_user_profile_adjustment() {
427 let mut profile = UserInjectionProfile::new("test-user".to_string());
428
429 assert_eq!(profile.effective_threshold, 0.50);
430
431 profile.adjust(FeedbackSignal::Positive);
432 assert!((profile.effective_threshold - 0.49).abs() < 0.01);
433
434 profile.adjust(FeedbackSignal::Negative);
435 assert!((profile.effective_threshold - 0.51).abs() < 0.01);
436
437 // Many negatives should cap at 0.90
438 for _ in 0..20 {
439 profile.adjust(FeedbackSignal::Negative);
440 }
441 assert_eq!(profile.effective_threshold, 0.90);
442 }
443}