Skip to main content

shodh_memory/memory/
feedback.rs

1//! Implicit Feedback System for Memory Reinforcement
2//!
3//! Extracts feedback signals from agent behavior without explicit ratings.
4//! Uses entity overlap, semantic similarity, and user corrections to
5//! determine memory usefulness. Implements momentum-based updates with
6//! type-dependent inertia to prevent noise from destabilizing useful memories.
7
8use chrono::{DateTime, Duration, Utc};
9use rocksdb::{ColumnFamily, ColumnFamilyDescriptor, IteratorMode, Options, WriteBatch, DB};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::path::Path;
13use std::sync::Arc;
14
15use crate::memory::types::{ExperienceType, MemoryId};
16
17// =============================================================================
18// CONSTANTS
19// =============================================================================
20
/// Column family under which feedback data is stored in the shared RocksDB instance.
pub(crate) const CF_FEEDBACK: &str = "feedback";

/// Upper bound on the signals retained per memory for trend detection.
const MAX_RECENT_SIGNALS: usize = 20;

/// Upper bound on context fingerprints kept per memory (per helpful/misleading list).
const MAX_CONTEXT_FINGERPRINTS: usize = 100;

// ---- Entity overlap thresholds -------------------------------------------
// FBK-4: lowered so weak signals (roughly the 0.06-0.15 value range) still
// influence learning.

/// Entity-overlap ratio at or above which a memory counts as strongly used.
const OVERLAP_STRONG_THRESHOLD: f32 = 0.4;
/// Entity-overlap ratio at or above which a memory counts as weakly used.
const OVERLAP_WEAK_THRESHOLD: f32 = 0.1;

// ---- Semantic similarity thresholds (FBK-4: lowered to catch more signals) ----

/// Cosine similarity at or above which a memory counts as strongly related.
const SEMANTIC_STRONG_THRESHOLD: f32 = 0.6;
/// Cosine similarity at or above which a memory counts as weakly related.
const SEMANTIC_WEAK_THRESHOLD: f32 = 0.3;

// ---- Signal value multipliers (ACT-R inspired) ---------------------------

/// Scale applied to the ratio/similarity of a strong match.
const SIGNAL_STRONG_MULTIPLIER: f32 = 0.8;
/// Scale applied to the ratio/similarity of a weak match.
const SIGNAL_WEAK_MULTIPLIER: f32 = 0.3;
/// Fixed value used below the weak threshold (FBK-3: strengthened from -0.1).
const SIGNAL_NO_OVERLAP_PENALTY: f32 = -0.2;
/// Value for explicit correction language in the user's follow-up.
const SIGNAL_NEGATIVE_KEYWORD_PENALTY: f32 = -0.5;

// ---- Action-based signals (FBK-1, FBK-2) ---------------------------------

/// The user asked again, so the surfaced memories failed.
const SIGNAL_REPETITION_PENALTY: f32 = -0.4;
/// The user moved on, so the task might be complete.
const SIGNAL_TOPIC_CHANGE_BOOST: f32 = 0.2;
/// The memory was shown but completely unused.
const SIGNAL_IGNORED_PENALTY: f32 = -0.2;

// ---- Blend weights for entity vs. semantic signals (sum to 1.0) -----------

/// Weight of the entity-overlap component.
const ENTITY_WEIGHT: f32 = 0.4;
/// Weight of the semantic-similarity component.
const SEMANTIC_WEIGHT: f32 = 0.6;

// ---- Stability adjustment rates ------------------------------------------

/// Added to stability when a signal agrees with the current EMA direction.
const STABILITY_INCREMENT: f32 = 0.05;
/// Scales the disagreement that is subtracted from stability on a contradiction.
const STABILITY_DECREMENT_MULTIPLIER: f32 = 0.1;

// ---- Trend detection (regression slope per signal position) --------------

/// Slope above which a memory is classified as improving.
const TREND_IMPROVING_THRESHOLD: f32 = 0.1;
/// Slope below which a memory is classified as declining.
const TREND_DECLINING_THRESHOLD: f32 = -0.1;

// ---- Time decay (AUD-6): momentum drifts to 0 when not reinforced --------

/// Half-life, in days, of momentum decay.
const DECAY_HALF_LIFE_DAYS: f32 = 14.0;
66
/// Phrases in a user's follow-up that indicate a correction or a failure.
///
/// Matching is a plain substring search on the lowercased follow-up, so:
/// - every entry must itself be lowercase, or it can never match;
/// - entries are tested in list order (single words and multi-word phrases are
///   interleaved, not phrases first), and matches are reported in that order;
/// - overlapping entries all match: "that's wrong" also triggers "wrong", and
///   "still broken" also triggers "broken";
/// - apostrophes here are ASCII `'`, so typographic apostrophes (’) in user
///   text only match if the input is normalized to ASCII first.
const NEGATIVE_KEYWORDS: &[&str] = &[
    // Direct negation / correction
    "wrong",
    "incorrect",
    "not correct",
    "nope",
    // Frustration / repetition
    "not what i meant",
    "that's not right",
    "that's wrong",
    "i already said",
    "i told you",
    "i already told",
    "already mentioned",
    // Irrelevance / unhelpfulness
    "not helpful",
    "not relevant",
    "not useful",
    "irrelevant",
    "useless",
    "doesn't help",
    "didn't help",
    "not related",
    // Failure / broken
    "doesn't work",
    "didn't work",
    "broken",
    "still broken",
    "that failed",
    // Explicit rejection
    "forget that",
    "ignore that",
    "disregard",
    "stop suggesting",
    "don't show",
];
105
106// =============================================================================
107// SIGNAL TYPES
108// =============================================================================
109
110/// What triggered a feedback signal
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum SignalTrigger {
113    /// Entity overlap between memory and agent response
114    EntityOverlap { overlap_ratio: f32 },
115
116    /// Semantic similarity between memory and response
117    SemanticSimilarity { similarity: f32 },
118
119    /// Negative keywords detected in user's followup
120    NegativeKeywords { keywords: Vec<String> },
121
122    /// User repeated the same question (retrieval failed)
123    /// Action: user asked again → memories didn't help
124    UserRepetition { similarity: f32 },
125
126    /// Topic changed successfully (task completed)
127    /// Action: user moved on → memories may have helped
128    TopicChange { similarity: f32 },
129
130    /// Memory was surfaced but completely ignored
131    /// Action: response has no relation to memory
132    Ignored { overlap_ratio: f32 },
133
134    /// FBK-8: Entity flow tracking
135    /// Measures how response builds on memory entities
136    /// - derived_ratio: proportion of response entities that came from memory
137    /// - novel_ratio: proportion of response entities that are new
138    EntityFlow {
139        derived_ratio: f32,
140        novel_ratio: f32,
141        memory_entities_used: usize,
142        response_entities_total: usize,
143    },
144}
145
146/// A single feedback signal
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct SignalRecord {
149    /// When the signal was recorded
150    pub timestamp: DateTime<Utc>,
151
152    /// Signal value: -1.0 (misleading) to +1.0 (helpful)
153    pub value: f32,
154
155    /// Confidence in this signal (0.0 to 1.0)
156    pub confidence: f32,
157
158    /// What triggered this signal
159    pub trigger: SignalTrigger,
160}
161
162impl SignalRecord {
163    pub fn new(value: f32, confidence: f32, trigger: SignalTrigger) -> Self {
164        Self {
165            timestamp: Utc::now(),
166            value: value.clamp(-1.0, 1.0),
167            confidence: confidence.clamp(0.0, 1.0),
168            trigger,
169        }
170    }
171
172    /// Create signal from entity overlap ratio
173    pub fn from_entity_overlap(overlap_ratio: f32) -> Self {
174        let (value, confidence) = if overlap_ratio >= OVERLAP_STRONG_THRESHOLD {
175            (SIGNAL_STRONG_MULTIPLIER * overlap_ratio, 0.9)
176        } else if overlap_ratio >= OVERLAP_WEAK_THRESHOLD {
177            (SIGNAL_WEAK_MULTIPLIER * overlap_ratio, 0.6)
178        } else {
179            (SIGNAL_NO_OVERLAP_PENALTY, 0.4)
180        };
181
182        Self::new(
183            value,
184            confidence,
185            SignalTrigger::EntityOverlap { overlap_ratio },
186        )
187    }
188
189    /// Create signal from negative keyword detection
190    pub fn from_negative_keywords(keywords: Vec<String>) -> Self {
191        Self::new(
192            SIGNAL_NEGATIVE_KEYWORD_PENALTY,
193            0.95, // High confidence - explicit correction
194            SignalTrigger::NegativeKeywords { keywords },
195        )
196    }
197}
198
199// =============================================================================
200// TREND DETECTION
201// =============================================================================
202
203/// Trend direction for a memory
204#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
205pub enum Trend {
206    /// Memory is becoming more useful over time
207    Improving,
208    /// Memory usefulness is stable
209    Stable,
210    /// Memory is becoming less useful (possibly outdated)
211    Declining,
212    /// Not enough data to determine trend
213    Insufficient,
214}
215
216impl Trend {
217    /// Calculate trend from recent signals using linear regression
218    pub fn from_signals(signals: &VecDeque<SignalRecord>) -> Self {
219        if signals.len() < 3 {
220            return Trend::Insufficient;
221        }
222
223        let n = signals.len() as f32;
224        let mut sum_x = 0.0;
225        let mut sum_y = 0.0;
226        let mut sum_xy = 0.0;
227        let mut sum_xx = 0.0;
228
229        for (i, signal) in signals.iter().enumerate() {
230            let x = i as f32;
231            let y = signal.value;
232            sum_x += x;
233            sum_y += y;
234            sum_xy += x * y;
235            sum_xx += x * x;
236        }
237
238        // Linear regression slope: (n*Σxy - Σx*Σy) / (n*Σxx - Σx²)
239        let denominator = n * sum_xx - sum_x * sum_x;
240        if denominator.abs() < f32::EPSILON {
241            return Trend::Stable;
242        }
243
244        let slope = (n * sum_xy - sum_x * sum_y) / denominator;
245
246        if slope > TREND_IMPROVING_THRESHOLD {
247            Trend::Improving
248        } else if slope < TREND_DECLINING_THRESHOLD {
249            Trend::Declining
250        } else {
251            Trend::Stable
252        }
253    }
254}
255
256// =============================================================================
257// CONTEXT FINGERPRINT
258// =============================================================================
259
260/// Fingerprint of a context for pattern detection
261/// Tracks which contexts a memory was helpful vs misleading in
262#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct ContextFingerprint {
264    /// Top entities in the context
265    pub entities: Vec<String>,
266
267    /// Compressed embedding signature (top 16 components)
268    pub embedding_signature: [f32; 16],
269
270    /// When this context occurred
271    pub timestamp: DateTime<Utc>,
272
273    /// Was the memory helpful in this context?
274    pub was_helpful: bool,
275}
276
277impl ContextFingerprint {
278    pub fn new(entities: Vec<String>, embedding: &[f32], was_helpful: bool) -> Self {
279        // Compress embedding to 16 components by taking evenly spaced samples
280        let mut signature = [0.0f32; 16];
281        if !embedding.is_empty() {
282            let step = embedding.len() / 16;
283            for (i, sig) in signature.iter_mut().enumerate() {
284                let idx = (i * step).min(embedding.len() - 1);
285                *sig = embedding[idx];
286            }
287        }
288
289        Self {
290            entities,
291            embedding_signature: signature,
292            timestamp: Utc::now(),
293            was_helpful,
294        }
295    }
296
297    /// Calculate similarity to another fingerprint
298    pub fn similarity(&self, other: &ContextFingerprint) -> f32 {
299        // Entity Jaccard similarity
300        let self_set: HashSet<_> = self.entities.iter().collect();
301        let other_set: HashSet<_> = other.entities.iter().collect();
302        let intersection = self_set.intersection(&other_set).count() as f32;
303        let union = self_set.union(&other_set).count() as f32;
304        let entity_sim = if union > 0.0 {
305            intersection / union
306        } else {
307            0.0
308        };
309
310        // Embedding cosine similarity
311        let mut dot = 0.0;
312        let mut norm_a = 0.0;
313        let mut norm_b = 0.0;
314        for i in 0..16 {
315            dot += self.embedding_signature[i] * other.embedding_signature[i];
316            norm_a += self.embedding_signature[i] * self.embedding_signature[i];
317            norm_b += other.embedding_signature[i] * other.embedding_signature[i];
318        }
319        let embed_sim = if norm_a > 0.0 && norm_b > 0.0 {
320            dot / (norm_a.sqrt() * norm_b.sqrt())
321        } else {
322            0.0
323        };
324
325        // Weighted combination
326        entity_sim * 0.6 + embed_sim * 0.4
327    }
328}
329
330// =============================================================================
331// FEEDBACK MOMENTUM
332// =============================================================================
333
334/// Tracks feedback history for a single memory
335/// Implements momentum-based updates with type-dependent inertia
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct FeedbackMomentum {
338    /// Memory this momentum belongs to
339    pub memory_id: MemoryId,
340
341    /// Memory type (for inertia calculation)
342    pub memory_type: ExperienceType,
343
344    /// Exponential moving average of feedback signals
345    /// Range: -1.0 (always misleading) to +1.0 (always helpful)
346    pub ema: f32,
347
348    /// How many feedback signals have we received?
349    pub signal_count: u32,
350
351    /// Stability score: how consistent is the feedback?
352    /// High stability = resistant to change
353    pub stability: f32,
354
355    /// When did we first evaluate this memory?
356    pub first_signal_at: Option<DateTime<Utc>>,
357
358    /// When was the last signal?
359    pub last_signal_at: Option<DateTime<Utc>>,
360
361    /// Recent signals for trend detection
362    pub recent_signals: VecDeque<SignalRecord>,
363
364    /// Contexts where this memory was helpful
365    pub helpful_contexts: Vec<ContextFingerprint>,
366
367    /// Contexts where this memory was misleading
368    pub misleading_contexts: Vec<ContextFingerprint>,
369}
370
371impl FeedbackMomentum {
372    pub fn new(memory_id: MemoryId, memory_type: ExperienceType) -> Self {
373        Self {
374            memory_id,
375            memory_type,
376            ema: 0.0,
377            signal_count: 0,
378            stability: 0.5, // Start neutral
379            first_signal_at: None,
380            last_signal_at: None,
381            recent_signals: VecDeque::with_capacity(MAX_RECENT_SIGNALS),
382            helpful_contexts: Vec::new(),
383            misleading_contexts: Vec::new(),
384        }
385    }
386
387    /// Get base inertia for memory type
388    /// Higher inertia = more resistant to change
389    pub fn base_inertia(&self) -> f32 {
390        match self.memory_type {
391            ExperienceType::Learning => 0.95,
392            ExperienceType::Decision => 0.90,
393            ExperienceType::Pattern => 0.85,
394            ExperienceType::Discovery => 0.75,
395            ExperienceType::Context => 0.60,
396            ExperienceType::Task => 0.50,
397            ExperienceType::Observation => 0.40,
398            ExperienceType::Conversation => 0.30,
399            ExperienceType::Error => 0.20,
400            // Others default to medium
401            ExperienceType::CodeEdit => 0.50,
402            ExperienceType::FileAccess => 0.40,
403            ExperienceType::Search => 0.35,
404            ExperienceType::Command => 0.35,
405            ExperienceType::Intention => 0.60,
406        }
407    }
408
409    /// Calculate age factor for inertia
410    /// Older memories are more stable
411    pub fn age_factor(&self) -> f32 {
412        let age_days = self
413            .first_signal_at
414            .map(|first| {
415                let duration = Utc::now() - first;
416                duration.num_days() as f32
417            })
418            .unwrap_or(0.0);
419
420        if age_days < 1.0 {
421            0.8 // New, still malleable
422        } else if age_days < 7.0 {
423            0.9 // Consolidating
424        } else if age_days < 30.0 {
425            1.0 // Consolidated
426        } else {
427            1.1 // Deeply encoded
428        }
429    }
430
431    /// Calculate history factor for inertia
432    /// More evaluations = more confidence = more inertia
433    pub fn history_factor(&self) -> f32 {
434        match self.signal_count {
435            0..=2 => 0.7,   // Not enough data
436            3..=9 => 0.9,   // Some history
437            10..=49 => 1.0, // Good history
438            _ => 1.1,       // Very well tested
439        }
440    }
441
442    /// Calculate stability factor for inertia
443    /// Consistent history = resist change
444    pub fn stability_factor(&self) -> f32 {
445        // Map stability 0.0-1.0 to factor 0.8-1.2
446        0.8 + (self.stability * 0.4)
447    }
448
449    /// Calculate effective inertia combining all factors
450    pub fn effective_inertia(&self) -> f32 {
451        let inertia = self.base_inertia()
452            * self.age_factor()
453            * self.history_factor()
454            * self.stability_factor();
455
456        // Clamp to valid range - never fully frozen, never fully fluid
457        inertia.clamp(0.5, 0.99)
458    }
459
460    /// Calculate recency weight for a signal
461    pub fn recency_weight(&self, signal_time: DateTime<Utc>) -> f32 {
462        let time_since_last = self
463            .last_signal_at
464            .map(|last| signal_time - last)
465            .unwrap_or_else(Duration::zero);
466
467        if time_since_last < Duration::hours(1) {
468            1.0
469        } else if time_since_last < Duration::days(1) {
470            0.9
471        } else if time_since_last < Duration::days(7) {
472            0.7
473        } else {
474            0.5
475        }
476    }
477
478    /// Update momentum with a new signal
479    pub fn update(&mut self, signal: SignalRecord) {
480        let now = signal.timestamp;
481
482        // Initialize first signal time if needed
483        if self.first_signal_at.is_none() {
484            self.first_signal_at = Some(now);
485        }
486
487        // Calculate effective inertia before update
488        let effective_inertia = self.effective_inertia();
489        let recency = self.recency_weight(now);
490
491        // Alpha = how much new signal affects EMA
492        // High inertia = low alpha = resistant to change
493        let alpha = (1.0 - effective_inertia) * recency * signal.confidence;
494
495        // Store old EMA for stability calculation
496        let old_ema = self.ema;
497
498        // Update EMA
499        self.ema = old_ema * (1.0 - alpha) + signal.value * alpha;
500
501        // Update stability
502        let direction_matches =
503            (signal.value > 0.0) == (old_ema > 0.0) || old_ema.abs() < f32::EPSILON;
504
505        if direction_matches {
506            // Consistent feedback: increase stability
507            self.stability = (self.stability + STABILITY_INCREMENT).min(1.0);
508        } else {
509            // Contradictory feedback: decrease stability
510            let contradiction_strength = (signal.value - old_ema).abs();
511            self.stability =
512                (self.stability - STABILITY_DECREMENT_MULTIPLIER * contradiction_strength).max(0.0);
513        }
514
515        // Record signal
516        self.recent_signals.push_back(signal);
517        if self.recent_signals.len() > MAX_RECENT_SIGNALS {
518            self.recent_signals.pop_front();
519        }
520
521        self.signal_count += 1;
522        self.last_signal_at = Some(now);
523    }
524
525    /// Get current trend
526    pub fn trend(&self) -> Trend {
527        Trend::from_signals(&self.recent_signals)
528    }
529
530    /// Add context fingerprint
531    pub fn add_context(&mut self, fingerprint: ContextFingerprint) {
532        let target = if fingerprint.was_helpful {
533            &mut self.helpful_contexts
534        } else {
535            &mut self.misleading_contexts
536        };
537
538        target.push(fingerprint);
539
540        // Trim to max size, keeping most recent
541        if target.len() > MAX_CONTEXT_FINGERPRINTS {
542            target.remove(0);
543        }
544    }
545
546    /// Check if current context matches helpful pattern
547    pub fn matches_helpful_pattern(&self, current: &ContextFingerprint) -> Option<f32> {
548        self.helpful_contexts
549            .iter()
550            .map(|fp| fp.similarity(current))
551            .max_by(|a, b| a.total_cmp(b))
552    }
553
554    /// Check if current context matches misleading pattern
555    pub fn matches_misleading_pattern(&self, current: &ContextFingerprint) -> Option<f32> {
556        self.misleading_contexts
557            .iter()
558            .map(|fp| fp.similarity(current))
559            .max_by(|a, b| a.total_cmp(b))
560    }
561
562    /// Apply time-based decay to momentum (AUD-6)
563    /// Returns the decayed EMA value without mutating the struct.
564    /// Momentum decays towards 0 when not reinforced by feedback.
565    pub fn ema_with_decay(&self) -> f32 {
566        let days_since_last = self
567            .last_signal_at
568            .map(|last| {
569                let duration = Utc::now() - last;
570                duration.num_hours() as f32 / 24.0
571            })
572            .unwrap_or(0.0);
573
574        if days_since_last < 0.1 {
575            // Very recent signal, no decay
576            return self.ema;
577        }
578
579        // Exponential decay with half-life
580        // decay_factor = 0.5^(days / half_life)
581        let decay_factor = 0.5_f32.powf(days_since_last / DECAY_HALF_LIFE_DAYS);
582
583        // Decay towards 0
584        self.ema * decay_factor
585    }
586}
587
588// =============================================================================
589// PENDING FEEDBACK
590// =============================================================================
591
592/// Information about a surfaced memory awaiting feedback
593#[derive(Debug, Clone, Serialize, Deserialize)]
594pub struct SurfacedMemoryInfo {
595    pub id: MemoryId,
596    pub entities: HashSet<String>,
597    pub content_preview: String,
598    pub score: f32,
599    /// Memory embedding for semantic similarity feedback
600    #[serde(default)]
601    pub embedding: Vec<f32>,
602}
603
604/// Pending feedback for a user - tracks what was surfaced, awaiting response
605#[derive(Debug, Clone, Serialize, Deserialize)]
606pub struct PendingFeedback {
607    pub user_id: String,
608    pub surfaced_at: DateTime<Utc>,
609    pub surfaced_memories: Vec<SurfacedMemoryInfo>,
610    pub context: String,
611    pub context_embedding: Vec<f32>,
612}
613
614impl PendingFeedback {
615    pub fn new(
616        user_id: String,
617        context: String,
618        context_embedding: Vec<f32>,
619        memories: Vec<SurfacedMemoryInfo>,
620    ) -> Self {
621        Self {
622            user_id,
623            surfaced_at: Utc::now(),
624            surfaced_memories: memories,
625            context,
626            context_embedding,
627        }
628    }
629
630    /// Check if this pending feedback has expired (older than 1 hour)
631    pub fn is_expired(&self) -> bool {
632        Utc::now() - self.surfaced_at > Duration::hours(1)
633    }
634}
635
636// =============================================================================
637// SIGNAL EXTRACTION
638// =============================================================================
639
/// Extracts candidate entities from `text` with a simple tokenizer.
///
/// The text is lowercased and split on every character that is neither
/// alphanumeric nor `_`. Tokens of three or more characters are kept. Length
/// is counted in Unicode scalar values, not bytes, so short non-ASCII tokens
/// such as "né" (2 chars, 3 bytes) are dropped just like "ne".
///
/// TODO: Use NER model for better extraction
pub fn extract_entities_simple(text: &str) -> HashSet<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .filter(|word| word.chars().count() > 2)
        .map(str::to_string)
        .collect()
}
649
/// Fraction of `memory_entities` that also appear in `response_entities`.
///
/// Returns a value in `[0.0, 1.0]`, or `0.0` when the memory has no entities.
pub fn calculate_entity_overlap(
    memory_entities: &HashSet<String>,
    response_entities: &HashSet<String>,
) -> f32 {
    let total = memory_entities.len();
    if total == 0 {
        return 0.0;
    }

    let shared = memory_entities
        .iter()
        .filter(|entity| response_entities.contains(*entity))
        .count();

    shared as f32 / total as f32
}
662
/// Cosine similarity of two equally sized vectors, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the lengths differ, when the inputs are empty, or when
/// either vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
}
679
680/// Create signal from semantic similarity
681fn signal_from_semantic_similarity(similarity: f32) -> (f32, f32) {
682    if similarity >= SEMANTIC_STRONG_THRESHOLD {
683        (SIGNAL_STRONG_MULTIPLIER * similarity, 0.9)
684    } else if similarity >= SEMANTIC_WEAK_THRESHOLD {
685        (SIGNAL_WEAK_MULTIPLIER * similarity, 0.6)
686    } else {
687        (SIGNAL_NO_OVERLAP_PENALTY * 0.5, 0.3) // Lighter penalty for semantic
688    }
689}
690
691/// Detect negative keywords in user's followup message
692pub fn detect_negative_keywords(text: &str) -> Vec<String> {
693    let lower = text.to_lowercase();
694    NEGATIVE_KEYWORDS
695        .iter()
696        .filter(|&&kw| lower.contains(kw))
697        .map(|&s| s.to_string())
698        .collect()
699}
700
/// FBK-8: Measures how a response builds on a memory's entities.
///
/// Returns `(derived_ratio, novel_ratio, derived_count, response_total)`, where:
/// - `derived_ratio` is the fraction of response entities also in the memory;
/// - `novel_ratio` is the fraction of response entities not in the memory
///   (the two ratios sum to 1.0);
/// - `derived_count` is the number of shared entities;
/// - `response_total` is the number of response entities.
///
/// A high derived ratio suggests the memory's knowledge was used. A high novel
/// ratio with a low derived ratio suggests the memory may not have been
/// relevant. An empty response yields `(0.0, 0.0, 0, 0)`.
pub fn calculate_entity_flow(
    memory_entities: &HashSet<String>,
    response_entities: &HashSet<String>,
) -> (f32, f32, usize, usize) {
    let total = response_entities.len();
    if total == 0 {
        return (0.0, 0.0, 0, 0);
    }

    // Only the size of the intersection is needed, so count it directly
    // instead of cloning every shared String into a throwaway set.
    let derived_count = response_entities.intersection(memory_entities).count();
    let novel_count = total - derived_count;

    (
        derived_count as f32 / total as f32,
        novel_count as f32 / total as f32,
        derived_count,
        total,
    )
}
737
738/// FBK-8: Create signal from entity flow analysis
739pub fn signal_from_entity_flow(
740    derived_ratio: f32,
741    novel_ratio: f32,
742    memory_entities_used: usize,
743    response_entities_total: usize,
744) -> SignalRecord {
745    // Signal value based on how much the response builds on memory
746    // High derived ratio = memory was useful
747    // Low derived ratio with high novel = memory might be irrelevant
748    let value = if derived_ratio >= 0.5 {
749        // Response heavily uses memory entities - strong positive
750        0.6 + (derived_ratio - 0.5) * 0.4
751    } else if derived_ratio >= 0.2 {
752        // Response somewhat uses memory entities - weak positive
753        derived_ratio * 1.5
754    } else if novel_ratio >= 0.8 {
755        // Response mostly novel, memory barely used - slight negative
756        -0.1
757    } else {
758        // Mixed - neutral
759        0.0
760    };
761
762    let confidence = if response_entities_total >= 3 {
763        0.8 // Good sample size
764    } else {
765        0.5 // Small sample, lower confidence
766    };
767
768    SignalRecord::new(
769        value,
770        confidence,
771        SignalTrigger::EntityFlow {
772            derived_ratio,
773            novel_ratio,
774            memory_entities_used,
775            response_entities_total,
776        },
777    )
778}
779
780/// Process feedback for surfaced memories based on agent response
781/// Uses both entity overlap and semantic similarity for more accurate signals
782pub fn process_implicit_feedback(
783    pending: &PendingFeedback,
784    response_text: &str,
785    user_followup: Option<&str>,
786) -> Vec<(MemoryId, SignalRecord)> {
787    // For backwards compatibility, call enhanced version with no response embedding
788    process_implicit_feedback_with_semantics(pending, response_text, user_followup, None)
789}
790
791/// Enhanced feedback processing using both entity overlap and semantic similarity
792///
793/// When response_embedding is provided, combines entity overlap (40%) with
794/// semantic similarity (60%) for a more robust feedback signal. This helps
795/// detect when a memory was genuinely useful vs just sharing some words.
796pub fn process_implicit_feedback_with_semantics(
797    pending: &PendingFeedback,
798    response_text: &str,
799    user_followup: Option<&str>,
800    response_embedding: Option<&[f32]>,
801) -> Vec<(MemoryId, SignalRecord)> {
802    let response_entities = extract_entities_simple(response_text);
803    let mut signals = Vec::new();
804
805    // Calculate combined signals for each memory
806    for memory in &pending.surfaced_memories {
807        // Entity overlap signal
808        let entity_overlap = calculate_entity_overlap(&memory.entities, &response_entities);
809        let (entity_value, entity_conf) = if entity_overlap >= OVERLAP_STRONG_THRESHOLD {
810            (SIGNAL_STRONG_MULTIPLIER * entity_overlap, 0.9)
811        } else if entity_overlap >= OVERLAP_WEAK_THRESHOLD {
812            (SIGNAL_WEAK_MULTIPLIER * entity_overlap, 0.6)
813        } else {
814            (SIGNAL_NO_OVERLAP_PENALTY, 0.4)
815        };
816
817        // Semantic similarity signal (if embeddings available)
818        let (semantic_value, semantic_conf, has_semantic) =
819            if let Some(resp_emb) = response_embedding {
820                if !memory.embedding.is_empty() {
821                    let similarity = cosine_similarity(&memory.embedding, resp_emb);
822                    let (val, conf) = signal_from_semantic_similarity(similarity);
823                    (val, conf, true)
824                } else {
825                    (0.0, 0.0, false)
826                }
827            } else {
828                (0.0, 0.0, false)
829            };
830
831        // Combine signals with weights
832        let (combined_value, combined_confidence, trigger) = if has_semantic {
833            let value = (ENTITY_WEIGHT * entity_value) + (SEMANTIC_WEIGHT * semantic_value);
834            let confidence = (ENTITY_WEIGHT * entity_conf) + (SEMANTIC_WEIGHT * semantic_conf);
835
836            // Use semantic similarity as trigger since it's the primary signal
837            let similarity = if let Some(resp_emb) = response_embedding {
838                cosine_similarity(&memory.embedding, resp_emb)
839            } else {
840                0.0
841            };
842            (
843                value,
844                confidence,
845                SignalTrigger::SemanticSimilarity { similarity },
846            )
847        } else {
848            // Fallback to entity-only signal
849            (
850                entity_value,
851                entity_conf,
852                SignalTrigger::EntityOverlap {
853                    overlap_ratio: entity_overlap,
854                },
855            )
856        };
857
858        let mut signal = SignalRecord::new(combined_value, combined_confidence, trigger);
859
860        // Apply negative keyword penalty if detected in followup
861        if let Some(followup) = user_followup {
862            let negative = detect_negative_keywords(followup);
863            if !negative.is_empty() {
864                signal.value += SIGNAL_NEGATIVE_KEYWORD_PENALTY;
865                signal.value = signal.value.clamp(-1.0, 1.0);
866                signal.confidence = 0.95; // High confidence on explicit correction
867            }
868        }
869
870        signals.push((memory.id.clone(), signal));
871    }
872
873    signals
874}
875
876/// Apply context pattern signals (repetition/topic change) to existing signals
877///
878/// This function modifies signal values based on detected user actions:
879/// - Repetition (user asked same thing again): negative signal (memories failed)
880/// - Topic change (user moved on): positive signal (task might be complete)
881/// - Ignored (memory shown but no overlap): negative signal
882///
883/// # Arguments
884/// - `signals`: Existing signals from process_implicit_feedback
885/// - `is_repetition`: User is asking the same question again
886/// - `is_topic_change`: User has moved to a different topic
887/// - `context_similarity`: Similarity between current and previous context
888pub fn apply_context_pattern_signals(
889    signals: &mut [(MemoryId, SignalRecord)],
890    is_repetition: bool,
891    is_topic_change: bool,
892    _context_similarity: f32,
893) {
894    for (memory_id, signal) in signals.iter_mut() {
895        if is_repetition {
896            // User asked the same thing again - memories didn't help
897            // Apply penalty proportional to how irrelevant the memory was
898            // FBK-4: Lowered threshold from 0.3 to 0.15 so more signals affect learning
899            if signal.value < 0.15 {
900                // Memory wasn't used in response AND user is re-asking
901                signal.value += SIGNAL_REPETITION_PENALTY;
902                signal.value = signal.value.clamp(-1.0, 1.0);
903                signal.trigger = SignalTrigger::UserRepetition {
904                    similarity: _context_similarity,
905                };
906                signal.confidence = 0.85; // High confidence - clear action signal
907                tracing::debug!(
908                    "Repetition detected for memory {:?}: applied penalty",
909                    memory_id
910                );
911            }
912        } else if is_topic_change {
913            // User moved on to different topic - task might be complete
914            // Apply boost to memories that were used in the response
915            // FBK-4: Lowered threshold from 0.1 to 0.05 so more signals affect learning
916            if signal.value > 0.05 {
917                // Memory was somewhat used - boost it
918                signal.value += SIGNAL_TOPIC_CHANGE_BOOST;
919                signal.value = signal.value.clamp(-1.0, 1.0);
920                signal.trigger = SignalTrigger::TopicChange {
921                    similarity: _context_similarity,
922                };
923                signal.confidence = 0.7; // Moderate confidence
924                tracing::debug!(
925                    "Topic change detected for memory {:?}: applied boost",
926                    memory_id
927                );
928            }
929        }
930
931        // Apply ignored penalty for memories with very low overlap
932        // regardless of repetition/topic change
933        if signal.value < -0.05 && signal.value > -0.3 {
934            // Memory was surfaced but not used - strengthen the penalty
935            signal.value = SIGNAL_IGNORED_PENALTY.min(signal.value);
936            if !matches!(signal.trigger, SignalTrigger::UserRepetition { .. }) {
937                signal.trigger = SignalTrigger::Ignored {
938                    overlap_ratio: match &signal.trigger {
939                        SignalTrigger::EntityOverlap { overlap_ratio } => *overlap_ratio,
940                        _ => 0.0,
941                    },
942                };
943            }
944        }
945    }
946}
947
948// =============================================================================
949// FEEDBACK STORE
950// =============================================================================
951
/// Previous context for a user - used for repetition/topic change detection.
///
/// Kept per user id by `FeedbackStore` and persisted as JSON under the
/// `prev_ctx:{user_id}` key by `FeedbackStore::set_previous_context`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreviousContext {
    /// The query/context text
    pub context: String,
    /// Embedding of the context; compared by cosine similarity against the next
    /// query's embedding in `FeedbackStore::detect_context_pattern` (an empty
    /// embedding disables detection)
    pub embedding: Vec<f32>,
    /// When this context was recorded (`Utc::now()` at the time it was set)
    pub timestamp: DateTime<Utc>,
    /// Memory IDs that were surfaced for this context
    pub surfaced_memory_ids: Vec<MemoryId>,
}
964
/// Persistent store for feedback momentum with in-memory cache.
///
/// Created without a DB (`FeedbackStore::new`) everything lives only in memory.
/// With a DB, entries are stored in the `CF_FEEDBACK` column family under the
/// `momentum:`, `pending:` and `prev_ctx:` key prefixes.
pub struct FeedbackStore {
    /// In-memory cache: memory_id -> FeedbackMomentum.
    /// `flush` only writes ids that are also in `dirty`, so callers that change an
    /// existing entry should call `mark_dirty` for it to be persisted.
    pub momentum: HashMap<MemoryId, FeedbackMomentum>,

    /// Pending feedback per user: user_id -> PendingFeedback.
    /// Also written to the DB under `pending:{user_id}` by `set_pending` and `flush`
    /// (when a DB is configured); expired entries are dropped when loading.
    pending: HashMap<String, PendingFeedback>,

    /// Previous context per user: for repetition/topic change detection
    /// Tracks what the user asked last time to detect patterns
    previous_context: HashMap<String, PreviousContext>,

    /// Persistent storage for momentum data (`None` for an in-memory store)
    db: Option<Arc<DB>>,

    /// Memory ids whose momentum entries must be written on the next `flush`
    dirty: HashSet<MemoryId>,
}
983
984impl std::fmt::Debug for FeedbackStore {
985    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
986        f.debug_struct("FeedbackStore")
987            .field("momentum_count", &self.momentum.len())
988            .field("pending_count", &self.pending.len())
989            .field("previous_context_count", &self.previous_context.len())
990            .field("has_db", &self.db.is_some())
991            .field("dirty_count", &self.dirty.len())
992            .finish()
993    }
994}
995
996impl Default for FeedbackStore {
997    fn default() -> Self {
998        Self {
999            momentum: HashMap::new(),
1000            pending: HashMap::new(),
1001            previous_context: HashMap::new(),
1002            db: None,
1003            dirty: HashSet::new(),
1004        }
1005    }
1006}
1007
1008impl FeedbackStore {
1009    /// Create in-memory only store (no persistence)
1010    pub fn new() -> Self {
1011        Self::default()
1012    }
1013
1014    /// Get a reference to the feedback column family handle.
1015    /// Returns `None` when running in-memory only or when the CF is missing.
1016    fn feedback_cf(&self) -> Option<&ColumnFamily> {
1017        self.db.as_ref().and_then(|db| db.cf_handle(CF_FEEDBACK))
1018    }
1019
1020    /// Create persistent store backed by a shared RocksDB instance.
1021    ///
1022    /// The caller is responsible for opening the DB with the `CF_FEEDBACK` column
1023    /// family already declared. On first use this constructor migrates data from
1024    /// the legacy standalone `feedback/` DB directory into the shared CF.
1025    pub fn with_shared_db(db: Arc<DB>, base_path: &Path) -> anyhow::Result<Self> {
1026        Self::migrate_from_separate_db(base_path, &db)?;
1027
1028        let cf = db.cf_handle(CF_FEEDBACK).expect("feedback CF must exist");
1029
1030        // Load all momentum entries from the feedback CF
1031        let mut momentum = HashMap::new();
1032        let iter = db.prefix_iterator_cf(cf, b"momentum:");
1033        for item in iter {
1034            if let Ok((key, value)) = item {
1035                if let Ok(key_str) = std::str::from_utf8(&key) {
1036                    if !key_str.starts_with("momentum:") {
1037                        break;
1038                    }
1039                    if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&value) {
1040                        momentum.insert(m.memory_id.clone(), m);
1041                    }
1042                }
1043            }
1044        }
1045
1046        let mut pending = HashMap::new();
1047        let iter = db.prefix_iterator_cf(cf, b"pending:");
1048        for item in iter {
1049            if let Ok((key, value)) = item {
1050                if let Ok(key_str) = std::str::from_utf8(&key) {
1051                    if !key_str.starts_with("pending:") {
1052                        break;
1053                    }
1054                    if let Ok(p) = serde_json::from_slice::<PendingFeedback>(&value) {
1055                        if !p.is_expired() {
1056                            pending.insert(p.user_id.clone(), p);
1057                        } else {
1058                            let _ = db.delete_cf(cf, key_str.as_bytes());
1059                        }
1060                    }
1061                }
1062            }
1063        }
1064
1065        let mut previous_context = HashMap::new();
1066        let iter = db.prefix_iterator_cf(cf, b"prev_ctx:");
1067        for item in iter {
1068            if let Ok((key, value)) = item {
1069                if let Ok(key_str) = std::str::from_utf8(&key) {
1070                    if !key_str.starts_with("prev_ctx:") {
1071                        break;
1072                    }
1073                    if let Ok(ctx) = serde_json::from_slice::<PreviousContext>(&value) {
1074                        let user_id = key_str.strip_prefix("prev_ctx:").unwrap_or("");
1075                        previous_context.insert(user_id.to_string(), ctx);
1076                    }
1077                }
1078            }
1079        }
1080
1081        tracing::info!(
1082            "Loaded {} momentum, {} pending, {} previous context from shared feedback CF",
1083            momentum.len(),
1084            pending.len(),
1085            previous_context.len()
1086        );
1087
1088        Ok(Self {
1089            momentum,
1090            pending,
1091            previous_context,
1092            db: Some(db),
1093            dirty: HashSet::new(),
1094        })
1095    }
1096
1097    /// Migrate data from the legacy standalone `feedback/` RocksDB directory
1098    /// into the `CF_FEEDBACK` column family of the shared DB.
1099    ///
1100    /// The old directory is renamed to `feedback.pre_cf_migration` so it can be
1101    /// restored manually if needed.
1102    fn migrate_from_separate_db(base_path: &Path, db: &DB) -> anyhow::Result<()> {
1103        let old_dir = base_path.join("feedback");
1104        if !old_dir.is_dir() {
1105            return Ok(());
1106        }
1107
1108        let cf = db.cf_handle(CF_FEEDBACK).expect("feedback CF must exist");
1109        let old_opts = Options::default();
1110        match DB::open_for_read_only(&old_opts, &old_dir, false) {
1111            Ok(old_db) => {
1112                let mut batch = WriteBatch::default();
1113                let mut count = 0usize;
1114                for item in old_db.iterator(IteratorMode::Start) {
1115                    if let Ok((key, value)) = item {
1116                        batch.put_cf(cf, &key, &value);
1117                        count += 1;
1118                        if count % 10_000 == 0 {
1119                            db.write(std::mem::take(&mut batch))?;
1120                        }
1121                    }
1122                }
1123                if !batch.is_empty() {
1124                    db.write(batch)?;
1125                }
1126                drop(old_db);
1127                tracing::info!("  feedback: migrated {count} entries to {CF_FEEDBACK} CF");
1128
1129                let backup = base_path.join("feedback.pre_cf_migration");
1130                if backup.exists() {
1131                    let _ = std::fs::remove_dir_all(&backup);
1132                }
1133                if let Err(e) = std::fs::rename(&old_dir, &backup) {
1134                    tracing::warn!("Could not rename old feedback dir: {e}");
1135                }
1136            }
1137            Err(e) => tracing::warn!("Could not open old feedback DB for migration: {e}"),
1138        }
1139        Ok(())
1140    }
1141
1142    /// Create persistent store with its own standalone RocksDB instance.
1143    ///
1144    /// Primarily useful for tests and standalone operation. In production, prefer
1145    /// [`with_shared_db`](Self::with_shared_db) to share a single DB instance.
1146    pub fn with_persistence<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
1147        let mut opts = Options::default();
1148        opts.create_if_missing(true);
1149        opts.create_missing_column_families(true);
1150        opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
1151
1152        let cfs = vec![
1153            ColumnFamilyDescriptor::new("default", Options::default()),
1154            ColumnFamilyDescriptor::new(CF_FEEDBACK, {
1155                let mut cf_opts = Options::default();
1156                cf_opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
1157                cf_opts
1158            }),
1159        ];
1160        let db = DB::open_cf_descriptors(&opts, path.as_ref(), cfs)?;
1161        let db = Arc::new(db);
1162
1163        let cf = db.cf_handle(CF_FEEDBACK).expect("feedback CF must exist");
1164
1165        // Load all momentum entries from the feedback CF
1166        let mut momentum = HashMap::new();
1167        let iter = db.prefix_iterator_cf(cf, b"momentum:");
1168        for item in iter {
1169            if let Ok((key, value)) = item {
1170                if let Ok(key_str) = std::str::from_utf8(&key) {
1171                    if !key_str.starts_with("momentum:") {
1172                        break;
1173                    }
1174                    if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&value) {
1175                        momentum.insert(m.memory_id.clone(), m);
1176                    }
1177                }
1178            }
1179        }
1180
1181        // Also load pending feedback entries (filter expired ones)
1182        let mut pending = HashMap::new();
1183        let iter = db.prefix_iterator_cf(cf, b"pending:");
1184        for item in iter {
1185            if let Ok((key, value)) = item {
1186                if let Ok(key_str) = std::str::from_utf8(&key) {
1187                    if !key_str.starts_with("pending:") {
1188                        break;
1189                    }
1190                    if let Ok(p) = serde_json::from_slice::<PendingFeedback>(&value) {
1191                        if !p.is_expired() {
1192                            pending.insert(p.user_id.clone(), p);
1193                        } else {
1194                            // Clean up expired pending feedback from disk
1195                            let _ = db.delete_cf(cf, key_str.as_bytes());
1196                        }
1197                    }
1198                }
1199            }
1200        }
1201
1202        // Load previous context entries
1203        let mut previous_context = HashMap::new();
1204        let iter = db.prefix_iterator_cf(cf, b"prev_ctx:");
1205        for item in iter {
1206            if let Ok((key, value)) = item {
1207                if let Ok(key_str) = std::str::from_utf8(&key) {
1208                    if !key_str.starts_with("prev_ctx:") {
1209                        break;
1210                    }
1211                    if let Ok(ctx) = serde_json::from_slice::<PreviousContext>(&value) {
1212                        let user_id = key_str.strip_prefix("prev_ctx:").unwrap_or("");
1213                        previous_context.insert(user_id.to_string(), ctx);
1214                    }
1215                }
1216            }
1217        }
1218
1219        tracing::info!(
1220            "Loaded {} momentum, {} pending, {} previous context from feedback CF",
1221            momentum.len(),
1222            pending.len(),
1223            previous_context.len()
1224        );
1225
1226        Ok(Self {
1227            momentum,
1228            pending,
1229            previous_context,
1230            db: Some(db),
1231            dirty: HashSet::new(),
1232        })
1233    }
1234
1235    /// Get or create momentum for a memory
1236    pub fn get_or_create_momentum(
1237        &mut self,
1238        memory_id: MemoryId,
1239        memory_type: ExperienceType,
1240    ) -> &mut FeedbackMomentum {
1241        // Check if we need to load from disk
1242        if !self.momentum.contains_key(&memory_id) {
1243            if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
1244                let key = format!("momentum:{}", memory_id.0);
1245                if let Ok(Some(data)) = db.get_cf(cf, key.as_bytes()) {
1246                    if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&data) {
1247                        self.momentum.insert(memory_id.clone(), m);
1248                    }
1249                }
1250            }
1251        }
1252
1253        self.momentum.entry(memory_id.clone()).or_insert_with(|| {
1254            self.dirty.insert(memory_id.clone());
1255            FeedbackMomentum::new(memory_id, memory_type)
1256        })
1257    }
1258
1259    /// Get momentum for a memory (if exists in-memory), with disk fallback.
1260    /// Checks the in-memory HashMap first, then falls back to RocksDB.
1261    pub fn get_momentum(&self, memory_id: &MemoryId) -> Option<FeedbackMomentum> {
1262        if let Some(m) = self.momentum.get(memory_id) {
1263            return Some(m.clone());
1264        }
1265        // Fall back to disk lookup
1266        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
1267            let key = format!("momentum:{}", memory_id.0);
1268            if let Ok(Some(data)) = db.get_cf(cf, key.as_bytes()) {
1269                if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&data) {
1270                    return Some(m);
1271                }
1272            }
1273        }
1274        None
1275    }
1276
1277    /// Mark a memory as dirty (needs persistence)
1278    pub fn mark_dirty(&mut self, memory_id: &MemoryId) {
1279        self.dirty.insert(memory_id.clone());
1280    }
1281
1282    /// Set pending feedback for a user (also persists to disk)
1283    pub fn set_pending(&mut self, pending: PendingFeedback) {
1284        let user_id = pending.user_id.clone();
1285        self.pending.insert(user_id.clone(), pending.clone());
1286
1287        // Persist to disk
1288        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
1289            let key = format!("pending:{}", user_id);
1290            if let Ok(value) = serde_json::to_vec(&pending) {
1291                if let Err(e) = db.put_cf(cf, key.as_bytes(), &value) {
1292                    tracing::warn!("Failed to persist pending feedback: {}", e);
1293                }
1294            }
1295        }
1296    }
1297
1298    /// Take pending feedback for a user (removes from store and disk)
1299    pub fn take_pending(&mut self, user_id: &str) -> Option<PendingFeedback> {
1300        let result = self.pending.remove(user_id);
1301
1302        // Remove from disk
1303        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
1304            let key = format!("pending:{}", user_id);
1305            let _ = db.delete_cf(cf, key.as_bytes());
1306        }
1307
1308        result
1309    }
1310
1311    /// Get pending feedback for a user (without removing)
1312    pub fn get_pending(&self, user_id: &str) -> Option<&PendingFeedback> {
1313        self.pending.get(user_id)
1314    }
1315
1316    /// Clean up expired pending feedback
1317    pub fn cleanup_expired(&mut self) {
1318        self.pending.retain(|_, p| !p.is_expired());
1319    }
1320
1321    /// Set previous context for a user (for repetition/topic change detection)
1322    /// Called when memories are surfaced to track what the user asked
1323    pub fn set_previous_context(
1324        &mut self,
1325        user_id: &str,
1326        context: String,
1327        embedding: Vec<f32>,
1328        surfaced_memory_ids: Vec<MemoryId>,
1329    ) {
1330        let prev_ctx = PreviousContext {
1331            context,
1332            embedding,
1333            timestamp: Utc::now(),
1334            surfaced_memory_ids,
1335        };
1336
1337        self.previous_context
1338            .insert(user_id.to_string(), prev_ctx.clone());
1339
1340        // Persist to disk
1341        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
1342            let key = format!("prev_ctx:{}", user_id);
1343            if let Ok(value) = serde_json::to_vec(&prev_ctx) {
1344                if let Err(e) = db.put_cf(cf, key.as_bytes(), &value) {
1345                    tracing::warn!("Failed to persist previous context: {}", e);
1346                }
1347            }
1348        }
1349    }
1350
1351    /// Get previous context for a user
1352    pub fn get_previous_context(&self, user_id: &str) -> Option<&PreviousContext> {
1353        self.previous_context.get(user_id)
1354    }
1355
1356    /// Compare current context to previous and detect action patterns
1357    /// Returns: (is_repetition, is_topic_change, similarity)
1358    /// - Repetition: similarity > 0.8 means user is asking same thing again (memories failed)
1359    /// - Topic change: similarity < 0.3 means user moved on (task might be complete)
1360    pub fn detect_context_pattern(
1361        &self,
1362        user_id: &str,
1363        current_embedding: &[f32],
1364    ) -> Option<(bool, bool, f32)> {
1365        let prev = self.previous_context.get(user_id)?;
1366
1367        if prev.embedding.is_empty() || current_embedding.is_empty() {
1368            return None;
1369        }
1370
1371        let similarity = cosine_similarity(&prev.embedding, current_embedding);
1372
1373        // ACT-R inspired thresholds
1374        let is_repetition = similarity > 0.8; // High similarity = re-asking
1375        let is_topic_change = similarity < 0.3; // Low similarity = moved on
1376
1377        Some((is_repetition, is_topic_change, similarity))
1378    }
1379
1380    /// Flush dirty entries to disk and ensure WAL is persisted
1381    pub fn flush(&mut self) -> anyhow::Result<usize> {
1382        let Some(ref db) = self.db else {
1383            return Ok(0);
1384        };
1385        let Some(cf) = db.cf_handle(CF_FEEDBACK) else {
1386            return Ok(0);
1387        };
1388
1389        // Drain dirty set first so the mutable borrow is released before we
1390        // take shared references to self.momentum / self.pending below.
1391        let dirty: Vec<MemoryId> = self.dirty.drain().collect();
1392
1393        let mut flushed = 0;
1394        for memory_id in &dirty {
1395            if let Some(momentum) = self.momentum.get(memory_id) {
1396                let key = format!("momentum:{}", memory_id.0);
1397                let value = serde_json::to_vec(momentum)?;
1398                db.put_cf(cf, key.as_bytes(), &value)?;
1399                flushed += 1;
1400            }
1401        }
1402
1403        // Also persist any pending feedback entries
1404        for (user_id, pending) in &self.pending {
1405            let key = format!("pending:{}", user_id);
1406            let value = serde_json::to_vec(pending)?;
1407            db.put_cf(cf, key.as_bytes(), &value)?;
1408        }
1409
1410        // Flush the feedback CF to ensure data persistence (critical for graceful shutdown)
1411        use rocksdb::FlushOptions;
1412        let mut flush_opts = FlushOptions::default();
1413        flush_opts.set_wait(true);
1414        db.flush_cf_opt(cf, &flush_opts)
1415            .map_err(|e| anyhow::anyhow!("Failed to flush feedback CF: {e}"))?;
1416
1417        if flushed > 0 {
1418            tracing::debug!("Flushed {} feedback momentum entries to disk", flushed);
1419        }
1420
1421        Ok(flushed)
1422    }
1423
1424    /// Get reference to the RocksDB database for backup (if available)
1425    pub fn database(&self) -> Option<&Arc<DB>> {
1426        self.db.as_ref()
1427    }
1428
1429    /// Get statistics
1430    pub fn stats(&self) -> FeedbackStoreStats {
1431        FeedbackStoreStats {
1432            total_momentum_entries: self.momentum.len(),
1433            total_pending: self.pending.len(),
1434            avg_ema: if self.momentum.is_empty() {
1435                0.0
1436            } else {
1437                self.momentum.values().map(|m| m.ema).sum::<f32>() / self.momentum.len() as f32
1438            },
1439            avg_stability: if self.momentum.is_empty() {
1440                0.0
1441            } else {
1442                self.momentum.values().map(|m| m.stability).sum::<f32>()
1443                    / self.momentum.len() as f32
1444            },
1445        }
1446    }
1447}
1448
/// Statistics about the feedback store, as computed by `FeedbackStore::stats`
/// over its in-memory caches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackStoreStats {
    /// Number of momentum entries currently cached in memory
    pub total_momentum_entries: usize,
    /// Number of users with pending feedback in memory (may include expired entries)
    pub total_pending: usize,
    /// Mean `ema` across cached momentum entries (0.0 when there are none)
    pub avg_ema: f32,
    /// Mean `stability` across cached momentum entries (0.0 when there are none)
    pub avg_stability: f32,
}
1457
1458#[cfg(test)]
1459mod tests {
1460    use super::*;
1461    use uuid::Uuid;
1462
1463    #[test]
1464    fn test_signal_from_entity_overlap() {
1465        // Strong overlap (>= 0.4 after FBK-4 threshold adjustment)
1466        let signal = SignalRecord::from_entity_overlap(0.7);
1467        assert!(signal.value > 0.5);
1468        assert!(signal.confidence > 0.8);
1469
1470        // Weak overlap (>= 0.1 after FBK-4 threshold adjustment)
1471        let signal = SignalRecord::from_entity_overlap(0.3);
1472        assert!(signal.value > 0.0);
1473        assert!(signal.value < 0.5);
1474
1475        // No overlap (< 0.1 after FBK-4 threshold adjustment)
1476        let signal = SignalRecord::from_entity_overlap(0.05);
1477        assert!(signal.value < 0.0);
1478    }
1479
1480    #[test]
1481    fn test_momentum_inertia_by_type() {
1482        let learning = FeedbackMomentum::new(MemoryId(Uuid::new_v4()), ExperienceType::Learning);
1483        let conversation =
1484            FeedbackMomentum::new(MemoryId(Uuid::new_v4()), ExperienceType::Conversation);
1485
1486        assert!(learning.base_inertia() > conversation.base_inertia());
1487        assert!(learning.base_inertia() >= 0.9);
1488        assert!(conversation.base_inertia() <= 0.4);
1489    }
1490
1491    #[test]
1492    fn test_momentum_update_with_inertia() {
1493        let mut momentum = FeedbackMomentum::new(
1494            MemoryId(Uuid::new_v4()),
1495            ExperienceType::Learning, // High inertia
1496        );
1497
1498        // Apply positive signal
1499        momentum.update(SignalRecord::new(
1500            1.0,
1501            1.0,
1502            SignalTrigger::EntityOverlap { overlap_ratio: 1.0 },
1503        ));
1504
1505        // EMA should move slowly due to high inertia
1506        assert!(momentum.ema > 0.0);
1507        assert!(momentum.ema < 0.5); // Not too fast
1508
1509        // Apply many positive signals
1510        for _ in 0..20 {
1511            momentum.update(SignalRecord::new(
1512                1.0,
1513                1.0,
1514                SignalTrigger::EntityOverlap { overlap_ratio: 1.0 },
1515            ));
1516        }
1517
1518        // Now EMA should be higher
1519        assert!(momentum.ema > 0.5);
1520        // Stability should be high after consistent signals
1521        assert!(momentum.stability > 0.7);
1522    }
1523
1524    #[test]
1525    fn test_trend_detection() {
1526        let mut signals = VecDeque::new();
1527
1528        // Not enough data
1529        assert_eq!(Trend::from_signals(&signals), Trend::Insufficient);
1530
1531        // Add improving signals (steeper slope > 0.1 threshold)
1532        for i in 0..10 {
1533            signals.push_back(SignalRecord::new(
1534                i as f32 * 0.15, // 0, 0.15, 0.3, ... gives slope ~0.15
1535                1.0,
1536                SignalTrigger::TopicChange { similarity: 0.2 },
1537            ));
1538        }
1539        assert_eq!(Trend::from_signals(&signals), Trend::Improving);
1540
1541        // Add declining signals (steeper slope < -0.1 threshold)
1542        signals.clear();
1543        for i in (0..10).rev() {
1544            signals.push_back(SignalRecord::new(
1545                i as f32 * 0.15, // 1.35, 1.2, ... 0 gives slope ~-0.15
1546                1.0,
1547                SignalTrigger::TopicChange { similarity: 0.2 },
1548            ));
1549        }
1550        assert_eq!(Trend::from_signals(&signals), Trend::Declining);
1551    }
1552
1553    #[test]
1554    fn test_entity_overlap() {
1555        let memory: HashSet<String> = ["rust", "async", "tokio"]
1556            .iter()
1557            .map(|s| s.to_string())
1558            .collect();
1559        let response: HashSet<String> = ["rust", "tokio", "spawn"]
1560            .iter()
1561            .map(|s| s.to_string())
1562            .collect();
1563
1564        let overlap = calculate_entity_overlap(&memory, &response);
1565        assert!((overlap - 0.666).abs() < 0.01); // 2/3
1566    }
1567
1568    #[test]
1569    fn test_negative_keyword_detection() {
1570        // Multi-word phrase detection
1571        let text = "No, that's not what I meant";
1572        let keywords = detect_negative_keywords(text);
1573        assert!(keywords.contains(&"not what i meant".to_string()));
1574
1575        // Irrelevance signals
1576        let text2 = "That's not helpful at all, it's irrelevant";
1577        let keywords2 = detect_negative_keywords(text2);
1578        assert!(keywords2.contains(&"not helpful".to_string()));
1579        assert!(keywords2.contains(&"irrelevant".to_string()));
1580
1581        // Explicit rejection
1582        let text3 = "Please forget that, it doesn't work";
1583        let keywords3 = detect_negative_keywords(text3);
1584        assert!(keywords3.contains(&"forget that".to_string()));
1585        assert!(keywords3.contains(&"doesn't work".to_string()));
1586
1587        // No false positives on neutral text
1588        let text4 = "Can you help me debug this function?";
1589        let keywords4 = detect_negative_keywords(text4);
1590        assert!(keywords4.is_empty());
1591    }
1592
1593    #[test]
1594    fn test_feedback_store_pending() {
1595        let mut store = FeedbackStore::new();
1596        let user_id = "test-user";
1597
1598        // Initially no pending
1599        assert!(store.get_pending(user_id).is_none());
1600
1601        // Set pending feedback
1602        let pending = PendingFeedback::new(
1603            user_id.to_string(),
1604            "test context".to_string(),
1605            vec![0.1; 384],
1606            vec![SurfacedMemoryInfo {
1607                id: MemoryId(Uuid::new_v4()),
1608                entities: ["rust", "memory"].iter().map(|s| s.to_string()).collect(),
1609                content_preview: "Test memory".to_string(),
1610                score: 0.8,
1611                embedding: Vec::new(),
1612            }],
1613        );
1614        store.set_pending(pending);
1615
1616        // Should have pending now
1617        assert!(store.get_pending(user_id).is_some());
1618        assert_eq!(
1619            store.get_pending(user_id).unwrap().surfaced_memories.len(),
1620            1
1621        );
1622
1623        // Take should remove it
1624        let taken = store.take_pending(user_id);
1625        assert!(taken.is_some());
1626        assert!(store.get_pending(user_id).is_none());
1627    }
1628
1629    #[test]
1630    fn test_feedback_store_momentum() {
1631        let mut store = FeedbackStore::new();
1632        let memory_id = MemoryId(Uuid::new_v4());
1633
1634        // Get or create momentum
1635        let momentum = store.get_or_create_momentum(memory_id.clone(), ExperienceType::Context);
1636        assert_eq!(momentum.signal_count, 0);
1637        assert_eq!(momentum.ema, 0.0);
1638
1639        // Update it
1640        momentum.update(SignalRecord::new(
1641            0.8,
1642            1.0,
1643            SignalTrigger::EntityOverlap { overlap_ratio: 0.8 },
1644        ));
1645        assert!(momentum.ema > 0.0);
1646        assert_eq!(momentum.signal_count, 1);
1647
1648        // Get should return existing
1649        let momentum2 = store.get_momentum(&memory_id);
1650        assert!(momentum2.is_some());
1651        assert_eq!(momentum2.unwrap().signal_count, 1);
1652    }
1653
1654    #[test]
1655    fn test_process_implicit_feedback_full() {
1656        let memory_id1 = MemoryId(Uuid::new_v4());
1657        let memory_id2 = MemoryId(Uuid::new_v4());
1658
1659        let pending = PendingFeedback::new(
1660            "user1".to_string(),
1661            "How do I use async in Rust?".to_string(),
1662            vec![0.1; 384],
1663            vec![
1664                SurfacedMemoryInfo {
1665                    id: memory_id1.clone(),
1666                    entities: ["rust", "async", "tokio"]
1667                        .iter()
1668                        .map(|s| s.to_string())
1669                        .collect(),
1670                    content_preview: "Rust async with tokio".to_string(),
1671                    score: 0.9,
1672                    embedding: Vec::new(),
1673                },
1674                SurfacedMemoryInfo {
1675                    id: memory_id2.clone(),
1676                    entities: ["python", "django"].iter().map(|s| s.to_string()).collect(),
1677                    content_preview: "Python Django web".to_string(),
1678                    score: 0.3,
1679                    embedding: Vec::new(),
1680                },
1681            ],
1682        );
1683
1684        // Response that uses Rust async terminology
1685        let response =
1686            "To use async in Rust, you can use tokio runtime. Here is an example with async await.";
1687        let signals = process_implicit_feedback(&pending, response, None);
1688
1689        assert_eq!(signals.len(), 2);
1690
1691        // First memory should have positive signal (high entity overlap)
1692        let (id1, sig1) = &signals[0];
1693        assert_eq!(id1, &memory_id1);
1694        assert!(sig1.value > 0.0);
1695
1696        // Second memory should have negative/low signal (no overlap)
1697        let (id2, sig2) = &signals[1];
1698        assert_eq!(id2, &memory_id2);
1699        assert!(sig2.value <= 0.0);
1700    }
1701
1702    #[test]
1703    fn test_process_implicit_feedback_with_negative_keywords() {
1704        let memory_id = MemoryId(Uuid::new_v4());
1705
1706        let pending = PendingFeedback::new(
1707            "user1".to_string(),
1708            "How do I use async?".to_string(),
1709            vec![0.1; 384],
1710            vec![SurfacedMemoryInfo {
1711                id: memory_id.clone(),
1712                entities: ["async", "code"].iter().map(|s| s.to_string()).collect(),
1713                content_preview: "Async code".to_string(),
1714                score: 0.9,
1715                embedding: Vec::new(),
1716            }],
1717        );
1718
1719        // Response uses entities
1720        let response = "Here is the async code pattern";
1721
1722        // Process without negative keywords
1723        let signals1 = process_implicit_feedback(&pending, response, None);
1724        let value_without = signals1[0].1.value;
1725
1726        // Process with negative keywords in followup
1727        let signals2 = process_implicit_feedback(&pending, response, Some("No, that is wrong!"));
1728        let value_with = signals2[0].1.value;
1729
1730        // Negative keywords should decrease the signal
1731        assert!(value_with < value_without);
1732    }
1733
1734    #[test]
1735    fn test_context_fingerprint_similarity() {
1736        let embedding: Vec<f32> = (0..384).map(|i| (i as f32) * 0.01).collect();
1737        let fp1 = ContextFingerprint::new(
1738            vec!["rust".to_string(), "memory".to_string()],
1739            &embedding,
1740            true,
1741        );
1742        let fp2 = ContextFingerprint::new(
1743            vec!["rust".to_string(), "async".to_string()],
1744            &embedding,
1745            false,
1746        );
1747        let different_embedding: Vec<f32> = (0..384).map(|i| 1.0 - (i as f32) * 0.01).collect();
1748        let fp3 = ContextFingerprint::new(
1749            vec!["python".to_string(), "django".to_string()],
1750            &different_embedding,
1751            true,
1752        );
1753
1754        // fp1 and fp2 share "rust" entity and same embedding
1755        let sim12 = fp1.similarity(&fp2);
1756        // fp1 and fp3 have no entity overlap and different embedding
1757        let sim13 = fp1.similarity(&fp3);
1758
1759        assert!(sim12 > sim13);
1760    }
1761
1762    #[test]
1763    fn test_feedback_store_stats() {
1764        let mut store = FeedbackStore::new();
1765
1766        // Empty stats
1767        let stats = store.stats();
1768        assert_eq!(stats.total_momentum_entries, 0);
1769        assert_eq!(stats.total_pending, 0);
1770
1771        // Add some momentum entries
1772        for i in 0..5 {
1773            let mut momentum =
1774                FeedbackMomentum::new(MemoryId(Uuid::new_v4()), ExperienceType::Context);
1775            momentum.ema = i as f32 * 0.2; // 0, 0.2, 0.4, 0.6, 0.8
1776            store.momentum.insert(momentum.memory_id.clone(), momentum);
1777        }
1778
1779        let stats = store.stats();
1780        assert_eq!(stats.total_momentum_entries, 5);
1781        assert!((stats.avg_ema - 0.4).abs() < 0.01); // (0+0.2+0.4+0.6+0.8)/5 = 0.4
1782    }
1783
1784    #[test]
1785    fn test_process_feedback_with_semantic_similarity() {
1786        let memory_id1 = MemoryId(Uuid::new_v4());
1787        let memory_id2 = MemoryId(Uuid::new_v4());
1788
1789        // Create embeddings: similar embeddings for related content
1790        let rust_embedding: Vec<f32> = (0..384).map(|i| (i as f32) * 0.01).collect();
1791        let python_embedding: Vec<f32> = (0..384).map(|i| 1.0 - (i as f32) * 0.01).collect();
1792
1793        let pending = PendingFeedback::new(
1794            "user1".to_string(),
1795            "How do I use async in Rust?".to_string(),
1796            vec![0.1; 384],
1797            vec![
1798                SurfacedMemoryInfo {
1799                    id: memory_id1.clone(),
1800                    entities: ["rust", "async", "tokio"]
1801                        .iter()
1802                        .map(|s| s.to_string())
1803                        .collect(),
1804                    content_preview: "Rust async with tokio".to_string(),
1805                    score: 0.9,
1806                    embedding: rust_embedding.clone(),
1807                },
1808                SurfacedMemoryInfo {
1809                    id: memory_id2.clone(),
1810                    entities: ["python", "django"].iter().map(|s| s.to_string()).collect(),
1811                    content_preview: "Python Django web".to_string(),
1812                    score: 0.3,
1813                    embedding: python_embedding.clone(),
1814                },
1815            ],
1816        );
1817
1818        // Response embedding similar to rust_embedding
1819        let response = "Here is how to use async/await in Rust with tokio runtime.";
1820        let response_embedding = rust_embedding; // Similar to memory 1
1821
1822        // Process without semantic (backwards compat)
1823        let signals_entity_only = process_implicit_feedback(&pending, response, None);
1824
1825        // Process with semantic similarity
1826        let signals_with_semantic = process_implicit_feedback_with_semantics(
1827            &pending,
1828            response,
1829            None,
1830            Some(&response_embedding),
1831        );
1832
1833        // First memory should score higher with semantic (response embedding matches memory embedding)
1834        let (id1, _sig1_entity) = &signals_entity_only[0];
1835        let (_, sig1_semantic) = &signals_with_semantic[0];
1836        assert_eq!(id1, &memory_id1);
1837
1838        // Semantic signal should use SemanticSimilarity trigger
1839        match &sig1_semantic.trigger {
1840            SignalTrigger::SemanticSimilarity { similarity } => {
1841                assert!(*similarity > 0.9); // High similarity since embeddings are same
1842            }
1843            _ => panic!("Expected SemanticSimilarity trigger"),
1844        }
1845
1846        // Second memory (python) should have low semantic score since embedding is different
1847        let (id2, sig2_semantic) = &signals_with_semantic[1];
1848        assert_eq!(id2, &memory_id2);
1849        match &sig2_semantic.trigger {
1850            SignalTrigger::SemanticSimilarity { similarity } => {
1851                assert!(*similarity < 0.5); // Low similarity - different embeddings
1852            }
1853            _ => panic!("Expected SemanticSimilarity trigger"),
1854        }
1855    }
1856
1857    #[test]
1858    fn test_cosine_similarity_basic() {
1859        // Identical vectors = 1.0
1860        let a = vec![1.0, 0.0, 0.0];
1861        let b = vec![1.0, 0.0, 0.0];
1862        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1863
1864        // Orthogonal vectors = 0.0
1865        let c = vec![0.0, 1.0, 0.0];
1866        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
1867
1868        // Opposite vectors = -1.0
1869        let d = vec![-1.0, 0.0, 0.0];
1870        assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.001);
1871
1872        // Empty vectors = 0.0
1873        assert!((cosine_similarity(&[], &[]) - 0.0).abs() < 0.001);
1874    }
1875
1876    #[test]
1877    fn test_calculate_entity_flow() {
1878        use std::collections::HashSet;
1879
1880        // Case 1: Response heavily derived from memory
1881        let memory_entities: HashSet<String> = ["rust", "async", "tokio", "futures"]
1882            .iter()
1883            .map(|s| s.to_string())
1884            .collect();
1885        let response_entities: HashSet<String> = ["rust", "async", "tokio", "runtime"]
1886            .iter()
1887            .map(|s| s.to_string())
1888            .collect();
1889
1890        let (derived_ratio, novel_ratio, derived_count, total) =
1891            calculate_entity_flow(&memory_entities, &response_entities);
1892
1893        assert_eq!(derived_count, 3); // rust, async, tokio
1894        assert_eq!(total, 4);
1895        assert!((derived_ratio - 0.75).abs() < 0.01);
1896        assert!((novel_ratio - 0.25).abs() < 0.01);
1897
1898        // Case 2: Response mostly novel (memory not used)
1899        let response_novel: HashSet<String> = ["python", "django", "flask", "web"]
1900            .iter()
1901            .map(|s| s.to_string())
1902            .collect();
1903
1904        let (derived_ratio2, novel_ratio2, derived_count2, _) =
1905            calculate_entity_flow(&memory_entities, &response_novel);
1906
1907        assert_eq!(derived_count2, 0);
1908        assert!((derived_ratio2 - 0.0).abs() < 0.01);
1909        assert!((novel_ratio2 - 1.0).abs() < 0.01);
1910
1911        // Case 3: Empty response
1912        let empty: HashSet<String> = HashSet::new();
1913        let (dr, nr, dc, total) = calculate_entity_flow(&memory_entities, &empty);
1914        assert_eq!(dc, 0);
1915        assert_eq!(total, 0);
1916        assert!((dr - 0.0).abs() < 0.01);
1917        assert!((nr - 0.0).abs() < 0.01);
1918    }
1919
1920    #[test]
1921    fn test_signal_from_entity_flow() {
1922        // Case 1: High derived ratio (>=0.5) = strong positive
1923        let sig1 = signal_from_entity_flow(0.75, 0.25, 3, 4);
1924        assert!(sig1.value > 0.5); // Strong positive
1925        assert!((sig1.confidence - 0.8).abs() < 0.01); // Good sample size
1926
1927        // Case 2: Medium derived ratio (0.2 to 0.5) = weak positive
1928        let sig2 = signal_from_entity_flow(0.3, 0.7, 1, 4);
1929        assert!(sig2.value > 0.0 && sig2.value <= 0.5); // Weak positive
1930        assert!((sig2.confidence - 0.8).abs() < 0.01);
1931
1932        // Case 3: Low derived, high novel = slight negative
1933        let sig3 = signal_from_entity_flow(0.1, 0.9, 0, 4);
1934        assert!(sig3.value < 0.0); // Negative
1935        assert!((sig3.value - (-0.1)).abs() < 0.01);
1936
1937        // Case 4: Small sample size = lower confidence
1938        let sig4 = signal_from_entity_flow(0.5, 0.5, 1, 2);
1939        assert!((sig4.confidence - 0.5).abs() < 0.01);
1940
1941        // Verify trigger variant
1942        match sig1.trigger {
1943            SignalTrigger::EntityFlow {
1944                derived_ratio,
1945                novel_ratio,
1946                memory_entities_used,
1947                response_entities_total,
1948            } => {
1949                assert!((derived_ratio - 0.75).abs() < 0.01);
1950                assert!((novel_ratio - 0.25).abs() < 0.01);
1951                assert_eq!(memory_entities_used, 3);
1952                assert_eq!(response_entities_total, 4);
1953            }
1954            _ => panic!("Expected EntityFlow trigger"),
1955        }
1956    }
1957}