Skip to main content

shodh_memory/memory/
injection.rs

1//! Proactive Context Injection System
2//!
3//! Implements truly proactive memory injection - surfacing relevant memories
4//! without explicit agent action. Based on multi-signal relevance scoring
5//! with user-adaptive thresholds.
6//!
7//! # Enhanced Relevance Model (MEMO-1)
8//!
9//! ```text
10//! R(m, c) = α·semantic + β·recency + γ·strength + δ·entity_overlap + ε·type_boost + ζ·file_match - η·suppression
11//! ```
12//!
13//! Where:
14//! - semantic: cosine similarity between memory and context embeddings
15//! - recency: exponential decay based on memory age
16//! - strength: Hebbian edge weight from knowledge graph
17//! - entity_overlap: Jaccard similarity of entities between memory and context
18//! - type_boost: Weight bonus based on memory type (Decision > Learning > Context)
19//! - file_match: Boost when memory mentions files in current context
20//! - suppression: Penalty for memories with negative feedback momentum
21//!
22//! # Feedback Loop
23//!
24//! The system learns from implicit feedback:
25//! - Positive: injected memory referenced in next turn
26//! - Negative: user indicates irrelevance
27//! - Neutral: memory ignored (no adjustment)
28
29use chrono::{DateTime, Utc};
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::time::Instant;
33
34use super::types::MemoryId;
35
36// =============================================================================
37// CONFIGURATION
38// =============================================================================
39
40/// Weights for composite relevance scoring
41///
42/// Enhanced with entity_overlap, type_boost, file_match, and suppression (MEMO-1).
43/// New fields default to 0.0 for backwards compatibility.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RelevanceWeights {
46    /// Weight for semantic similarity (cosine distance)
47    pub semantic: f32,
48    /// Weight for recency (exponential decay)
49    pub recency: f32,
50    /// Weight for Hebbian strength from graph
51    pub strength: f32,
52    /// Weight for entity overlap between memory and context (MEMO-1)
53    #[serde(default)]
54    pub entity_overlap: f32,
55    /// Weight for memory type boost (Decision > Learning > Context) (MEMO-1)
56    #[serde(default)]
57    pub type_boost: f32,
58    /// Weight for file path matching (MEMO-1)
59    #[serde(default)]
60    pub file_match: f32,
61    /// Weight for negative feedback suppression (MEMO-1)
62    #[serde(default)]
63    pub suppression: f32,
64    /// Weight for episode coherence boost (SHO-temporal)
65    /// Memories from the same episode as the query get boosted
66    #[serde(default)]
67    pub episode_coherence: f32,
68    /// Weight for graph activation from spreading activation traversal
69    /// Higher activation = stronger association in knowledge graph
70    #[serde(default)]
71    pub graph_activation: f32,
72    /// Weight for linguistic score from query analysis
73    /// Focal entity matches, modifier matches, etc.
74    #[serde(default)]
75    pub linguistic_score: f32,
76}
77
78impl Default for RelevanceWeights {
79    fn default() -> Self {
80        Self {
81            semantic: 0.40,          // Primary signal - semantic similarity
82            recency: 0.08,           // Recent memories get boost
83            strength: 0.08,          // Hebbian edge weight (from graph)
84            entity_overlap: 0.08,    // Entity Jaccard similarity
85            type_boost: 0.06,        // Decision/Learning type boost
86            file_match: 0.04,        // File path matching
87            suppression: 0.02,       // Negative feedback penalty
88            episode_coherence: 0.06, // Same-episode boost (prevents bleeding)
89            graph_activation: 0.10,  // Spreading activation from graph traversal
90            linguistic_score: 0.08,  // Query analysis (focal entities, modifiers)
91        }
92    }
93}
94
95impl RelevanceWeights {
96    /// Legacy weights for backwards compatibility (original 3-signal model)
97    pub fn legacy() -> Self {
98        Self {
99            semantic: 0.5,
100            recency: 0.3,
101            strength: 0.2,
102            entity_overlap: 0.0,
103            type_boost: 0.0,
104            file_match: 0.0,
105            suppression: 0.0,
106            episode_coherence: 0.0,
107            graph_activation: 0.0,
108            linguistic_score: 0.0,
109        }
110    }
111}
112
113/// Configuration for proactive injection behavior
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct InjectionConfig {
116    /// Minimum relevance score to trigger injection (0.0 - 1.0)
117    pub min_relevance: f32,
118
119    /// Maximum memories to inject per message
120    pub max_per_message: usize,
121
122    /// Cooldown in seconds before re-injecting same memory
123    pub cooldown_seconds: u64,
124
125    /// Weights for relevance score components
126    pub weights: RelevanceWeights,
127
128    /// Decay rate for recency calculation (λ in e^(-λt))
129    /// Higher = faster decay. Default 0.01 means ~50% at 70 hours
130    pub recency_decay_rate: f32,
131}
132
133impl Default for InjectionConfig {
134    fn default() -> Self {
135        Self {
136            min_relevance: 0.50, // Raised from 0.35 - require stronger semantic match
137            max_per_message: 3,
138            cooldown_seconds: 180,
139            weights: RelevanceWeights::default(),
140            recency_decay_rate: 0.01,
141        }
142    }
143}
144
145impl InjectionConfig {
146    /// Legacy config for backwards compatibility
147    pub fn legacy() -> Self {
148        Self {
149            min_relevance: 0.70,
150            max_per_message: 3,
151            cooldown_seconds: 180,
152            weights: RelevanceWeights::legacy(),
153            recency_decay_rate: 0.01,
154        }
155    }
156}
157
158// Note: RelevanceInput and compute_relevance removed - using unified 5-layer pipeline
159
160// =============================================================================
161// INJECTION ENGINE
162// =============================================================================
163
164/// Candidate memory for injection with computed relevance
165#[derive(Debug, Clone)]
166pub struct InjectionCandidate {
167    pub memory_id: MemoryId,
168    pub relevance_score: f32,
169}
170
171/// Engine that decides which memories to inject
172pub struct InjectionEngine {
173    config: InjectionConfig,
174    /// Tracks last injection time per memory for cooldown
175    cooldowns: HashMap<MemoryId, Instant>,
176}
177
178impl InjectionEngine {
179    pub fn new(config: InjectionConfig) -> Self {
180        Self {
181            config,
182            cooldowns: HashMap::new(),
183        }
184    }
185
186    pub fn with_default_config() -> Self {
187        Self::new(InjectionConfig::default())
188    }
189
190    /// Check if a memory is on cooldown
191    fn on_cooldown(&self, memory_id: &MemoryId) -> bool {
192        if let Some(last) = self.cooldowns.get(memory_id) {
193            last.elapsed().as_secs() < self.config.cooldown_seconds
194        } else {
195            false
196        }
197    }
198
199    /// Select memories for injection from candidates
200    ///
201    /// Filters by:
202    /// 1. Minimum relevance threshold
203    /// 2. Cooldown (recently injected memories excluded)
204    /// 3. Max count limit
205    ///
206    /// Returns memory IDs sorted by relevance (highest first)
207    pub fn select_for_injection(
208        &mut self,
209        mut candidates: Vec<InjectionCandidate>,
210    ) -> Vec<MemoryId> {
211        // Sort by relevance descending
212        candidates.sort_by(|a, b| b.relevance_score.total_cmp(&a.relevance_score));
213
214        let selected: Vec<MemoryId> = candidates
215            .into_iter()
216            .filter(|c| {
217                c.relevance_score >= self.config.min_relevance && !self.on_cooldown(&c.memory_id)
218            })
219            .take(self.config.max_per_message)
220            .map(|c| c.memory_id)
221            .collect();
222
223        // Record injection time for cooldown
224        let now = Instant::now();
225        for id in &selected {
226            self.cooldowns.insert(id.clone(), now);
227        }
228
229        selected
230    }
231
232    /// Clear expired cooldowns to prevent memory leak
233    pub fn cleanup_cooldowns(&mut self) {
234        let threshold = self.config.cooldown_seconds;
235        self.cooldowns
236            .retain(|_, last| last.elapsed().as_secs() < threshold * 2);
237    }
238
239    /// Get current configuration
240    pub fn config(&self) -> &InjectionConfig {
241        &self.config
242    }
243
244    /// Update configuration
245    pub fn set_config(&mut self, config: InjectionConfig) {
246        self.config = config;
247    }
248}
249
250// =============================================================================
251// INJECTION TRACKING (for feedback loop)
252// =============================================================================
253
254/// Record of an injection for feedback tracking
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct InjectionRecord {
257    pub memory_id: MemoryId,
258    pub injected_at: DateTime<Utc>,
259    pub relevance_score: f32,
260    pub context_signature: u64,
261}
262
263/// Tracks injections for feedback learning
264#[derive(Debug, Default)]
265pub struct InjectionTracker {
266    /// Recent injections awaiting feedback
267    pending: Vec<InjectionRecord>,
268    /// Max pending records to keep
269    max_pending: usize,
270}
271
272impl InjectionTracker {
273    pub fn new(max_pending: usize) -> Self {
274        Self {
275            pending: Vec::new(),
276            max_pending,
277        }
278    }
279
280    /// Record a new injection
281    pub fn record_injection(
282        &mut self,
283        memory_id: MemoryId,
284        relevance_score: f32,
285        context_signature: u64,
286    ) {
287        let record = InjectionRecord {
288            memory_id,
289            injected_at: Utc::now(),
290            relevance_score,
291            context_signature,
292        };
293
294        self.pending.push(record);
295
296        // Trim old records
297        if self.pending.len() > self.max_pending {
298            self.pending.remove(0);
299        }
300    }
301
302    /// Get pending injections for feedback analysis
303    pub fn pending_injections(&self) -> &[InjectionRecord] {
304        &self.pending
305    }
306
307    /// Clear injections older than given duration
308    pub fn clear_old(&mut self, max_age_seconds: i64) {
309        let cutoff = Utc::now() - chrono::Duration::seconds(max_age_seconds);
310        self.pending.retain(|r| r.injected_at > cutoff);
311    }
312
313    /// Remove specific injection after feedback processed
314    pub fn mark_processed(&mut self, memory_id: &MemoryId) {
315        self.pending.retain(|r| &r.memory_id != memory_id);
316    }
317}
318
319// =============================================================================
320// USER PROFILE (adaptive thresholds)
321// =============================================================================
322
/// Implicit feedback about an injected memory, used to adapt a user's threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSignal {
    /// The memory was referenced or used; the threshold may be lowered.
    Positive,
    /// The memory was explicitly rejected; the threshold is raised.
    Negative,
    /// The memory was ignored; the threshold is left unchanged.
    Neutral,
}
333
334/// Per-user adaptive injection profile
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct UserInjectionProfile {
337    pub user_id: String,
338    /// Effective threshold (starts at default, adapts over time)
339    pub effective_threshold: f32,
340    /// Count of positive signals received
341    pub positive_signals: u32,
342    /// Count of negative signals received
343    pub negative_signals: u32,
344    /// Last update timestamp
345    pub updated_at: DateTime<Utc>,
346}
347
348impl UserInjectionProfile {
349    pub fn new(user_id: String) -> Self {
350        Self {
351            user_id,
352            effective_threshold: InjectionConfig::default().min_relevance,
353            positive_signals: 0,
354            negative_signals: 0,
355            updated_at: Utc::now(),
356        }
357    }
358
359    /// Adjust threshold based on feedback signal
360    ///
361    /// - Positive: lower threshold by 0.01 (min 0.50)
362    /// - Negative: raise threshold by 0.02 (max 0.90)
363    /// - Neutral: no change
364    ///
365    /// Asymmetric adjustment: we're more cautious about noise
366    pub fn adjust(&mut self, signal: FeedbackSignal) {
367        match signal {
368            FeedbackSignal::Positive => {
369                self.positive_signals += 1;
370                self.effective_threshold = (self.effective_threshold - 0.01).max(0.50);
371            }
372            FeedbackSignal::Negative => {
373                self.negative_signals += 1;
374                self.effective_threshold = (self.effective_threshold + 0.02).min(0.90);
375            }
376            FeedbackSignal::Neutral => {}
377        }
378        self.updated_at = Utc::now();
379    }
380
381    /// Get signal ratio (positive / total)
382    pub fn signal_ratio(&self) -> f32 {
383        let total = self.positive_signals + self.negative_signals;
384        if total == 0 {
385            0.5 // No data yet
386        } else {
387            self.positive_signals as f32 / total as f32
388        }
389    }
390}
391
392// =============================================================================
393// TESTS
394// =============================================================================
395
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    #[test]
    fn test_injection_engine_filtering() {
        let mut engine = InjectionEngine::with_default_config();

        let high = MemoryId(Uuid::new_v4());
        let low = MemoryId(Uuid::new_v4());
        let mid = MemoryId(Uuid::new_v4());

        let candidates = vec![
            InjectionCandidate {
                memory_id: high.clone(),
                relevance_score: 0.85,
            },
            InjectionCandidate {
                memory_id: low,
                relevance_score: 0.45, // Below threshold (0.50)
            },
            InjectionCandidate {
                memory_id: mid.clone(),
                relevance_score: 0.75,
            },
        ];

        let selected = engine.select_for_injection(candidates);

        // Only 0.85 and 0.75 pass the 0.50 threshold, returned highest first.
        assert_eq!(selected, vec![high, mid]);
    }

    #[test]
    fn test_injection_cooldown_blocks_reinjection() {
        let mut engine = InjectionEngine::with_default_config();
        let id = MemoryId(Uuid::new_v4());

        let first = engine.select_for_injection(vec![InjectionCandidate {
            memory_id: id.clone(),
            relevance_score: 0.9,
        }]);
        assert_eq!(first.len(), 1);

        // Same memory again within the 180 s default cooldown: nothing is injected.
        let second = engine.select_for_injection(vec![InjectionCandidate {
            memory_id: id,
            relevance_score: 0.9,
        }]);
        assert!(second.is_empty());
    }

    #[test]
    fn test_user_profile_adjustment() {
        let mut profile = UserInjectionProfile::new("test-user".to_string());

        assert_eq!(profile.effective_threshold, 0.50);

        // The 0.50 floor equals the starting threshold, so a positive signal on a fresh
        // profile is counted but cannot lower the threshold.
        profile.adjust(FeedbackSignal::Positive);
        assert!((profile.effective_threshold - 0.50).abs() < 1e-6);
        assert_eq!(profile.positive_signals, 1);

        profile.adjust(FeedbackSignal::Negative);
        assert!((profile.effective_threshold - 0.52).abs() < 1e-6);

        // Once raised, a positive signal lowers it again by 0.01.
        profile.adjust(FeedbackSignal::Positive);
        assert!((profile.effective_threshold - 0.51).abs() < 1e-6);

        // Many negatives should cap at 0.90
        for _ in 0..20 {
            profile.adjust(FeedbackSignal::Negative);
        }
        assert_eq!(profile.effective_threshold, 0.90);

        // Neutral leaves the threshold untouched.
        profile.adjust(FeedbackSignal::Neutral);
        assert_eq!(profile.effective_threshold, 0.90);
    }

    #[test]
    fn test_signal_ratio() {
        let mut profile = UserInjectionProfile::new("ratio-user".to_string());
        assert_eq!(profile.signal_ratio(), 0.5);

        profile.adjust(FeedbackSignal::Positive);
        profile.adjust(FeedbackSignal::Positive);
        profile.adjust(FeedbackSignal::Positive);
        profile.adjust(FeedbackSignal::Negative);
        profile.adjust(FeedbackSignal::Neutral);
        assert_eq!(profile.signal_ratio(), 0.75);
    }
}