1use chrono::{DateTime, Duration, Utc};
9use rocksdb::{ColumnFamily, ColumnFamilyDescriptor, IteratorMode, Options, WriteBatch, DB};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::path::Path;
13use std::sync::Arc;
14
15use crate::memory::types::{ExperienceType, MemoryId};
16
/// Column family (in the shared RocksDB) that stores all feedback state.
pub(crate) const CF_FEEDBACK: &str = "feedback";

/// Sliding-window length of raw signals kept per memory for trend analysis.
const MAX_RECENT_SIGNALS: usize = 20;

/// Cap on stored context fingerprints per helpful/misleading bucket.
const MAX_CONTEXT_FINGERPRINTS: usize = 100;

/// Entity-overlap ratio at/above which a memory counts as strongly used.
const OVERLAP_STRONG_THRESHOLD: f32 = 0.4;
/// Entity-overlap ratio at/above which a memory counts as weakly used.
const OVERLAP_WEAK_THRESHOLD: f32 = 0.1;

/// Semantic similarity at/above which a memory counts as strongly related.
const SEMANTIC_STRONG_THRESHOLD: f32 = 0.6;
/// Semantic similarity at/above which a memory counts as weakly related.
const SEMANTIC_WEAK_THRESHOLD: f32 = 0.3;

/// Scale applied to a strong overlap/similarity when producing a signal value.
const SIGNAL_STRONG_MULTIPLIER: f32 = 0.8;
/// Scale applied to a weak overlap/similarity.
const SIGNAL_WEAK_MULTIPLIER: f32 = 0.3;
/// Signal value when a surfaced memory shows no meaningful overlap.
const SIGNAL_NO_OVERLAP_PENALTY: f32 = -0.2;
/// Signal value added when the user's follow-up contains negative phrasing.
const SIGNAL_NEGATIVE_KEYWORD_PENALTY: f32 = -0.5;

/// Penalty when the user repeats their question (the memories didn't help).
const SIGNAL_REPETITION_PENALTY: f32 = -0.4;
/// Boost when the conversation moves on to a new topic.
const SIGNAL_TOPIC_CHANGE_BOOST: f32 = 0.2;
/// Floor applied to weakly negative signals for apparently ignored memories.
const SIGNAL_IGNORED_PENALTY: f32 = -0.2;
/// Weight of the entity-overlap component in the combined signal.
const ENTITY_WEIGHT: f32 = 0.4;
/// Weight of the semantic-similarity component in the combined signal.
const SEMANTIC_WEIGHT: f32 = 0.6;

/// Stability gained per direction-consistent signal.
const STABILITY_INCREMENT: f32 = 0.05;
/// Stability lost per unit of contradiction strength.
const STABILITY_DECREMENT_MULTIPLIER: f32 = 0.1;

/// Regression slope above which the recent trend counts as improving.
const TREND_IMPROVING_THRESHOLD: f32 = 0.1;
/// Regression slope below which the recent trend counts as declining.
const TREND_DECLINING_THRESHOLD: f32 = -0.1;

/// Half-life (days) used to decay a memory's EMA toward zero between signals.
const DECAY_HALF_LIFE_DAYS: f32 = 14.0;

/// Phrases in a user follow-up that indicate the surfaced memories were
/// wrong, unhelpful, redundant, or unwanted (matched case-insensitively).
const NEGATIVE_KEYWORDS: &[&str] = &[
    "wrong",
    "incorrect",
    "not correct",
    "nope",
    "not what i meant",
    "that's not right",
    "that's wrong",
    "i already said",
    "i told you",
    "i already told",
    "already mentioned",
    "not helpful",
    "not relevant",
    "not useful",
    "irrelevant",
    "useless",
    "doesn't help",
    "didn't help",
    "not related",
    "doesn't work",
    "didn't work",
    "broken",
    "still broken",
    "that failed",
    "forget that",
    "ignore that",
    "disregard",
    "stop suggesting",
    "don't show",
];
105
/// What caused a feedback signal to be generated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SignalTrigger {
    /// Share of the memory's entities that appeared in the response.
    EntityOverlap { overlap_ratio: f32 },

    /// Cosine similarity between the memory and response embeddings.
    SemanticSimilarity { similarity: f32 },

    /// Explicit negative phrases detected in the user's follow-up.
    NegativeKeywords { keywords: Vec<String> },

    /// The user asked (nearly) the same thing again.
    UserRepetition { similarity: f32 },

    /// The conversation moved on to something else.
    TopicChange { similarity: f32 },

    /// Memory was surfaced but apparently not used.
    Ignored { overlap_ratio: f32 },

    /// Breakdown of how the response's entities relate to the memory's.
    EntityFlow {
        /// Fraction of response entities that came from the memory.
        derived_ratio: f32,
        /// Fraction of response entities not present in the memory.
        novel_ratio: f32,
        /// Absolute count of memory entities reused in the response.
        memory_entities_used: usize,
        /// Total entities extracted from the response.
        response_entities_total: usize,
    },
}
145
/// One observed feedback signal for a memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignalRecord {
    /// When the signal was observed.
    pub timestamp: DateTime<Utc>,

    /// Signal strength in [-1.0, 1.0]; positive means helpful.
    pub value: f32,

    /// How much to trust this signal, in [0.0, 1.0].
    pub confidence: f32,

    /// What produced the signal.
    pub trigger: SignalTrigger,
}
161
162impl SignalRecord {
163 pub fn new(value: f32, confidence: f32, trigger: SignalTrigger) -> Self {
164 Self {
165 timestamp: Utc::now(),
166 value: value.clamp(-1.0, 1.0),
167 confidence: confidence.clamp(0.0, 1.0),
168 trigger,
169 }
170 }
171
172 pub fn from_entity_overlap(overlap_ratio: f32) -> Self {
174 let (value, confidence) = if overlap_ratio >= OVERLAP_STRONG_THRESHOLD {
175 (SIGNAL_STRONG_MULTIPLIER * overlap_ratio, 0.9)
176 } else if overlap_ratio >= OVERLAP_WEAK_THRESHOLD {
177 (SIGNAL_WEAK_MULTIPLIER * overlap_ratio, 0.6)
178 } else {
179 (SIGNAL_NO_OVERLAP_PENALTY, 0.4)
180 };
181
182 Self::new(
183 value,
184 confidence,
185 SignalTrigger::EntityOverlap { overlap_ratio },
186 )
187 }
188
189 pub fn from_negative_keywords(keywords: Vec<String>) -> Self {
191 Self::new(
192 SIGNAL_NEGATIVE_KEYWORD_PENALTY,
193 0.95, SignalTrigger::NegativeKeywords { keywords },
195 )
196 }
197}
198
/// Direction of a memory's recent feedback signals.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Trend {
    /// Signal values are rising over the window.
    Improving,
    /// Roughly flat.
    Stable,
    /// Signal values are falling.
    Declining,
    /// Fewer than three samples — no trend can be computed yet.
    Insufficient,
}
215
216impl Trend {
217 pub fn from_signals(signals: &VecDeque<SignalRecord>) -> Self {
219 if signals.len() < 3 {
220 return Trend::Insufficient;
221 }
222
223 let n = signals.len() as f32;
224 let mut sum_x = 0.0;
225 let mut sum_y = 0.0;
226 let mut sum_xy = 0.0;
227 let mut sum_xx = 0.0;
228
229 for (i, signal) in signals.iter().enumerate() {
230 let x = i as f32;
231 let y = signal.value;
232 sum_x += x;
233 sum_y += y;
234 sum_xy += x * y;
235 sum_xx += x * x;
236 }
237
238 let denominator = n * sum_xx - sum_x * sum_x;
240 if denominator.abs() < f32::EPSILON {
241 return Trend::Stable;
242 }
243
244 let slope = (n * sum_xy - sum_x * sum_y) / denominator;
245
246 if slope > TREND_IMPROVING_THRESHOLD {
247 Trend::Improving
248 } else if slope < TREND_DECLINING_THRESHOLD {
249 Trend::Declining
250 } else {
251 Trend::Stable
252 }
253 }
254}
255
/// Compact summary of the conversational context in which a memory was
/// surfaced, used to recognize similar situations later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextFingerprint {
    /// Entities present in the context.
    pub entities: Vec<String>,

    /// 16-value down-sample of the context embedding.
    pub embedding_signature: [f32; 16],

    /// When the fingerprint was captured.
    pub timestamp: DateTime<Utc>,

    /// Whether the memory proved helpful in this context.
    pub was_helpful: bool,
}
276
277impl ContextFingerprint {
278 pub fn new(entities: Vec<String>, embedding: &[f32], was_helpful: bool) -> Self {
279 let mut signature = [0.0f32; 16];
281 if !embedding.is_empty() {
282 let step = embedding.len() / 16;
283 for (i, sig) in signature.iter_mut().enumerate() {
284 let idx = (i * step).min(embedding.len() - 1);
285 *sig = embedding[idx];
286 }
287 }
288
289 Self {
290 entities,
291 embedding_signature: signature,
292 timestamp: Utc::now(),
293 was_helpful,
294 }
295 }
296
297 pub fn similarity(&self, other: &ContextFingerprint) -> f32 {
299 let self_set: HashSet<_> = self.entities.iter().collect();
301 let other_set: HashSet<_> = other.entities.iter().collect();
302 let intersection = self_set.intersection(&other_set).count() as f32;
303 let union = self_set.union(&other_set).count() as f32;
304 let entity_sim = if union > 0.0 {
305 intersection / union
306 } else {
307 0.0
308 };
309
310 let mut dot = 0.0;
312 let mut norm_a = 0.0;
313 let mut norm_b = 0.0;
314 for i in 0..16 {
315 dot += self.embedding_signature[i] * other.embedding_signature[i];
316 norm_a += self.embedding_signature[i] * self.embedding_signature[i];
317 norm_b += other.embedding_signature[i] * other.embedding_signature[i];
318 }
319 let embed_sim = if norm_a > 0.0 && norm_b > 0.0 {
320 dot / (norm_a.sqrt() * norm_b.sqrt())
321 } else {
322 0.0
323 };
324
325 entity_sim * 0.6 + embed_sim * 0.4
327 }
328}
329
/// Adaptive feedback state for a single memory.
///
/// Tracks an exponential moving average of feedback signals whose learning
/// rate depends on memory type, age, history depth, and stability, plus the
/// contexts in which the memory helped or misled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackMomentum {
    /// Memory this state belongs to.
    pub memory_id: MemoryId,

    /// Type of the memory; drives base inertia.
    pub memory_type: ExperienceType,

    /// Exponential moving average of signal values, in [-1.0, 1.0].
    pub ema: f32,

    /// Total number of signals ever folded in.
    pub signal_count: u32,

    /// Consistency of recent signal direction, in [0.0, 1.0].
    pub stability: f32,

    /// When the first signal arrived (`None` before any signal).
    pub first_signal_at: Option<DateTime<Utc>>,

    /// When the most recent signal arrived.
    pub last_signal_at: Option<DateTime<Utc>>,

    /// Bounded FIFO window of raw signals used for trend analysis.
    pub recent_signals: VecDeque<SignalRecord>,

    /// Fingerprints of contexts where this memory proved helpful.
    pub helpful_contexts: Vec<ContextFingerprint>,

    /// Fingerprints of contexts where this memory misled.
    pub misleading_contexts: Vec<ContextFingerprint>,
}
370
impl FeedbackMomentum {
    /// Fresh momentum state: neutral EMA, mid-range stability, no history.
    pub fn new(memory_id: MemoryId, memory_type: ExperienceType) -> Self {
        Self {
            memory_id,
            memory_type,
            ema: 0.0,
            signal_count: 0,
            // Neutral prior: no evidence yet for consistency or volatility.
            stability: 0.5,
            first_signal_at: None,
            last_signal_at: None,
            recent_signals: VecDeque::with_capacity(MAX_RECENT_SIGNALS),
            helpful_contexts: Vec::new(),
            misleading_contexts: Vec::new(),
        }
    }

    /// Base resistance to change by memory type: durable knowledge
    /// (learnings, decisions) updates slowly; ephemeral events (errors,
    /// conversations) update quickly.
    pub fn base_inertia(&self) -> f32 {
        match self.memory_type {
            ExperienceType::Learning => 0.95,
            ExperienceType::Decision => 0.90,
            ExperienceType::Pattern => 0.85,
            ExperienceType::Discovery => 0.75,
            ExperienceType::Context => 0.60,
            ExperienceType::Task => 0.50,
            ExperienceType::Observation => 0.40,
            ExperienceType::Conversation => 0.30,
            ExperienceType::Error => 0.20,
            ExperienceType::CodeEdit => 0.50,
            ExperienceType::FileAccess => 0.40,
            ExperienceType::Search => 0.35,
            ExperienceType::Command => 0.35,
            ExperienceType::Intention => 0.60,
        }
    }

    /// Inertia multiplier by age since the first signal: young entries
    /// adapt faster (<1.0), month-old entries slightly resist change (1.1).
    pub fn age_factor(&self) -> f32 {
        let age_days = self
            .first_signal_at
            .map(|first| {
                let duration = Utc::now() - first;
                duration.num_days() as f32
            })
            .unwrap_or(0.0);

        if age_days < 1.0 {
            0.8
        } else if age_days < 7.0 {
            0.9
        } else if age_days < 30.0 {
            1.0
        } else {
            1.1
        }
    }

    /// Inertia multiplier by sample size: sparse histories adapt faster.
    pub fn history_factor(&self) -> f32 {
        match self.signal_count {
            0..=2 => 0.7,
            3..=9 => 0.9,
            10..=49 => 1.0,
            _ => 1.1,
        }
    }

    /// Maps stability in [0, 1] linearly to a multiplier in [0.8, 1.2].
    pub fn stability_factor(&self) -> f32 {
        0.8 + (self.stability * 0.4)
    }

    /// Combined inertia (type × age × history × stability), clamped so a
    /// memory is never fully frozen (0.99) nor fully volatile (0.5).
    pub fn effective_inertia(&self) -> f32 {
        let inertia = self.base_inertia()
            * self.age_factor()
            * self.history_factor()
            * self.stability_factor();

        inertia.clamp(0.5, 0.99)
    }

    /// Weight of a new signal by time since the previous one: rapid-fire
    /// signals count fully, stale follow-ups count less. First-ever signal
    /// (no `last_signal_at`) gets full weight.
    pub fn recency_weight(&self, signal_time: DateTime<Utc>) -> f32 {
        let time_since_last = self
            .last_signal_at
            .map(|last| signal_time - last)
            .unwrap_or_else(Duration::zero);

        if time_since_last < Duration::hours(1) {
            1.0
        } else if time_since_last < Duration::days(1) {
            0.9
        } else if time_since_last < Duration::days(7) {
            0.7
        } else {
            0.5
        }
    }

    /// Folds one signal into the EMA and updates stability and bookkeeping.
    pub fn update(&mut self, signal: SignalRecord) {
        let now = signal.timestamp;

        if self.first_signal_at.is_none() {
            self.first_signal_at = Some(now);
        }

        let effective_inertia = self.effective_inertia();
        let recency = self.recency_weight(now);

        // Learning rate: low inertia, recent timing, and high confidence
        // all make the new signal count for more.
        let alpha = (1.0 - effective_inertia) * recency * signal.confidence;

        let old_ema = self.ema;

        self.ema = old_ema * (1.0 - alpha) + signal.value * alpha;

        // A signal "agrees" when it points the same direction as the EMA
        // (or when the EMA is effectively zero).
        let direction_matches =
            (signal.value > 0.0) == (old_ema > 0.0) || old_ema.abs() < f32::EPSILON;

        if direction_matches {
            self.stability = (self.stability + STABILITY_INCREMENT).min(1.0);
        } else {
            // Bigger contradictions erode stability faster.
            let contradiction_strength = (signal.value - old_ema).abs();
            self.stability =
                (self.stability - STABILITY_DECREMENT_MULTIPLIER * contradiction_strength).max(0.0);
        }

        // Keep a bounded FIFO window of raw signals for trend analysis.
        self.recent_signals.push_back(signal);
        if self.recent_signals.len() > MAX_RECENT_SIGNALS {
            self.recent_signals.pop_front();
        }

        self.signal_count += 1;
        self.last_signal_at = Some(now);
    }

    /// Trend over the retained signal window.
    pub fn trend(&self) -> Trend {
        Trend::from_signals(&self.recent_signals)
    }

    /// Files a context fingerprint under the helpful or misleading bucket,
    /// evicting the oldest entry once the per-bucket cap is reached.
    pub fn add_context(&mut self, fingerprint: ContextFingerprint) {
        let target = if fingerprint.was_helpful {
            &mut self.helpful_contexts
        } else {
            &mut self.misleading_contexts
        };

        target.push(fingerprint);

        if target.len() > MAX_CONTEXT_FINGERPRINTS {
            target.remove(0);
        }
    }

    /// Best similarity between `current` and any recorded helpful context,
    /// or `None` when no helpful contexts are stored.
    pub fn matches_helpful_pattern(&self, current: &ContextFingerprint) -> Option<f32> {
        self.helpful_contexts
            .iter()
            .map(|fp| fp.similarity(current))
            .max_by(|a, b| a.total_cmp(b))
    }

    /// Best similarity between `current` and any recorded misleading context.
    pub fn matches_misleading_pattern(&self, current: &ContextFingerprint) -> Option<f32> {
        self.misleading_contexts
            .iter()
            .map(|fp| fp.similarity(current))
            .max_by(|a, b| a.total_cmp(b))
    }

    /// EMA attenuated toward zero with an exponential half-life of
    /// `DECAY_HALF_LIFE_DAYS`, measured from the last signal; very fresh
    /// values (< 0.1 day, ~2.4h) pass through unchanged.
    pub fn ema_with_decay(&self) -> f32 {
        let days_since_last = self
            .last_signal_at
            .map(|last| {
                let duration = Utc::now() - last;
                duration.num_hours() as f32 / 24.0
            })
            .unwrap_or(0.0);

        if days_since_last < 0.1 {
            return self.ema;
        }

        let decay_factor = 0.5_f32.powf(days_since_last / DECAY_HALF_LIFE_DAYS);

        self.ema * decay_factor
    }
}
587
/// Snapshot of one memory at the moment it was surfaced to the user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SurfacedMemoryInfo {
    /// Surfaced memory's id.
    pub id: MemoryId,
    /// Entities associated with the memory's content.
    pub entities: HashSet<String>,
    /// Short excerpt of the memory content.
    pub content_preview: String,
    /// Score the memory carried when it was surfaced.
    pub score: f32,
    /// Memory embedding; `#[serde(default)]` keeps records persisted
    /// before this field existed deserializable (as an empty vec).
    #[serde(default)]
    pub embedding: Vec<f32>,
}
603
/// A batch of memories surfaced to a user, awaiting implicit feedback from
/// the next assistant response and user follow-up.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingFeedback {
    /// User the memories were surfaced to.
    pub user_id: String,
    /// When the memories were surfaced.
    pub surfaced_at: DateTime<Utc>,
    /// The surfaced memories themselves.
    pub surfaced_memories: Vec<SurfacedMemoryInfo>,
    /// Conversation context at surfacing time.
    pub context: String,
    /// Embedding of that context.
    pub context_embedding: Vec<f32>,
}
613
614impl PendingFeedback {
615 pub fn new(
616 user_id: String,
617 context: String,
618 context_embedding: Vec<f32>,
619 memories: Vec<SurfacedMemoryInfo>,
620 ) -> Self {
621 Self {
622 user_id,
623 surfaced_at: Utc::now(),
624 surfaced_memories: memories,
625 context,
626 context_embedding,
627 }
628 }
629
630 pub fn is_expired(&self) -> bool {
632 Utc::now() - self.surfaced_at > Duration::hours(1)
633 }
634}
635
/// Naive entity extraction: lowercase the text, split on every character
/// that is neither alphanumeric nor '_', and keep tokens longer than two
/// characters. Returns a deduplicated set.
pub fn extract_entities_simple(text: &str) -> HashSet<String> {
    let lowered = text.to_lowercase();
    lowered
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|token| token.len() > 2)
        .map(str::to_string)
        .collect()
}
649
/// Fraction of `memory_entities` that also appear in `response_entities`.
///
/// An empty memory entity set yields 0.0 (nothing available to overlap).
pub fn calculate_entity_overlap(
    memory_entities: &HashSet<String>,
    response_entities: &HashSet<String>,
) -> f32 {
    if memory_entities.is_empty() {
        return 0.0;
    }

    let shared = memory_entities
        .iter()
        .filter(|entity| response_entities.contains(*entity))
        .count();
    shared as f32 / memory_entities.len() as f32
}
662
/// Cosine similarity of two equal-length vectors, clamped to [-1, 1].
///
/// Mismatched lengths, empty inputs, or a zero-magnitude vector yield 0.0.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
679
680fn signal_from_semantic_similarity(similarity: f32) -> (f32, f32) {
682 if similarity >= SEMANTIC_STRONG_THRESHOLD {
683 (SIGNAL_STRONG_MULTIPLIER * similarity, 0.9)
684 } else if similarity >= SEMANTIC_WEAK_THRESHOLD {
685 (SIGNAL_WEAK_MULTIPLIER * similarity, 0.6)
686 } else {
687 (SIGNAL_NO_OVERLAP_PENALTY * 0.5, 0.3) }
689}
690
691pub fn detect_negative_keywords(text: &str) -> Vec<String> {
693 let lower = text.to_lowercase();
694 NEGATIVE_KEYWORDS
695 .iter()
696 .filter(|&&kw| lower.contains(kw))
697 .map(|&s| s.to_string())
698 .collect()
699}
700
/// Splits a response's entity set into the portion "derived" from the
/// surfaced memory (present in both sets) and the "novel" remainder.
///
/// Returns `(derived_ratio, novel_ratio, derived_count, response_total)`.
/// An empty response entity set yields all zeros.
pub fn calculate_entity_flow(
    memory_entities: &HashSet<String>,
    response_entities: &HashSet<String>,
) -> (f32, f32, usize, usize) {
    if response_entities.is_empty() {
        return (0.0, 0.0, 0, 0);
    }

    let total = response_entities.len();

    // Count the intersection directly instead of materializing it: the
    // previous version cloned every shared entity into a temporary set
    // only to take its length.
    let derived_count = response_entities.intersection(memory_entities).count();
    let novel_count = total - derived_count;

    let derived_ratio = derived_count as f32 / total as f32;
    let novel_ratio = novel_count as f32 / total as f32;

    (derived_ratio, novel_ratio, derived_count, total)
}
737
738pub fn signal_from_entity_flow(
740 derived_ratio: f32,
741 novel_ratio: f32,
742 memory_entities_used: usize,
743 response_entities_total: usize,
744) -> SignalRecord {
745 let value = if derived_ratio >= 0.5 {
749 0.6 + (derived_ratio - 0.5) * 0.4
751 } else if derived_ratio >= 0.2 {
752 derived_ratio * 1.5
754 } else if novel_ratio >= 0.8 {
755 -0.1
757 } else {
758 0.0
760 };
761
762 let confidence = if response_entities_total >= 3 {
763 0.8 } else {
765 0.5 };
767
768 SignalRecord::new(
769 value,
770 confidence,
771 SignalTrigger::EntityFlow {
772 derived_ratio,
773 novel_ratio,
774 memory_entities_used,
775 response_entities_total,
776 },
777 )
778}
779
/// Convenience wrapper around [`process_implicit_feedback_with_semantics`]
/// for callers without a response embedding (entity-overlap signals only).
pub fn process_implicit_feedback(
    pending: &PendingFeedback,
    response_text: &str,
    user_followup: Option<&str>,
) -> Vec<(MemoryId, SignalRecord)> {
    process_implicit_feedback_with_semantics(pending, response_text, user_followup, None)
}
790
791pub fn process_implicit_feedback_with_semantics(
797 pending: &PendingFeedback,
798 response_text: &str,
799 user_followup: Option<&str>,
800 response_embedding: Option<&[f32]>,
801) -> Vec<(MemoryId, SignalRecord)> {
802 let response_entities = extract_entities_simple(response_text);
803 let mut signals = Vec::new();
804
805 for memory in &pending.surfaced_memories {
807 let entity_overlap = calculate_entity_overlap(&memory.entities, &response_entities);
809 let (entity_value, entity_conf) = if entity_overlap >= OVERLAP_STRONG_THRESHOLD {
810 (SIGNAL_STRONG_MULTIPLIER * entity_overlap, 0.9)
811 } else if entity_overlap >= OVERLAP_WEAK_THRESHOLD {
812 (SIGNAL_WEAK_MULTIPLIER * entity_overlap, 0.6)
813 } else {
814 (SIGNAL_NO_OVERLAP_PENALTY, 0.4)
815 };
816
817 let (semantic_value, semantic_conf, has_semantic) =
819 if let Some(resp_emb) = response_embedding {
820 if !memory.embedding.is_empty() {
821 let similarity = cosine_similarity(&memory.embedding, resp_emb);
822 let (val, conf) = signal_from_semantic_similarity(similarity);
823 (val, conf, true)
824 } else {
825 (0.0, 0.0, false)
826 }
827 } else {
828 (0.0, 0.0, false)
829 };
830
831 let (combined_value, combined_confidence, trigger) = if has_semantic {
833 let value = (ENTITY_WEIGHT * entity_value) + (SEMANTIC_WEIGHT * semantic_value);
834 let confidence = (ENTITY_WEIGHT * entity_conf) + (SEMANTIC_WEIGHT * semantic_conf);
835
836 let similarity = if let Some(resp_emb) = response_embedding {
838 cosine_similarity(&memory.embedding, resp_emb)
839 } else {
840 0.0
841 };
842 (
843 value,
844 confidence,
845 SignalTrigger::SemanticSimilarity { similarity },
846 )
847 } else {
848 (
850 entity_value,
851 entity_conf,
852 SignalTrigger::EntityOverlap {
853 overlap_ratio: entity_overlap,
854 },
855 )
856 };
857
858 let mut signal = SignalRecord::new(combined_value, combined_confidence, trigger);
859
860 if let Some(followup) = user_followup {
862 let negative = detect_negative_keywords(followup);
863 if !negative.is_empty() {
864 signal.value += SIGNAL_NEGATIVE_KEYWORD_PENALTY;
865 signal.value = signal.value.clamp(-1.0, 1.0);
866 signal.confidence = 0.95; }
868 }
869
870 signals.push((memory.id.clone(), signal));
871 }
872
873 signals
874}
875
/// Post-processes a batch of implicit signals using conversation-level
/// patterns: user repetition (penalize), topic change (boost), and a final
/// "ignored memory" floor for weakly negative signals.
///
/// Mutates each signal's value, trigger, and confidence in place.
pub fn apply_context_pattern_signals(
    signals: &mut [(MemoryId, SignalRecord)],
    is_repetition: bool,
    is_topic_change: bool,
    _context_similarity: f32,
) {
    for (memory_id, signal) in signals.iter_mut() {
        if is_repetition {
            // Repetition suggests the surfaced memories did not help; only
            // penalize signals that were not already clearly positive.
            if signal.value < 0.15 {
                signal.value += SIGNAL_REPETITION_PENALTY;
                signal.value = signal.value.clamp(-1.0, 1.0);
                signal.trigger = SignalTrigger::UserRepetition {
                    similarity: _context_similarity,
                };
                signal.confidence = 0.85;
                tracing::debug!(
                    "Repetition detected for memory {:?}: applied penalty",
                    memory_id
                );
            }
        } else if is_topic_change {
            // Moving on implies the memory likely resolved the topic; only
            // boost signals that were already (mildly) positive.
            if signal.value > 0.05 {
                signal.value += SIGNAL_TOPIC_CHANGE_BOOST;
                signal.value = signal.value.clamp(-1.0, 1.0);
                signal.trigger = SignalTrigger::TopicChange {
                    similarity: _context_similarity,
                };
                signal.confidence = 0.7;
                tracing::debug!(
                    "Topic change detected for memory {:?}: applied boost",
                    memory_id
                );
            }
        }

        // Weakly negative signals in (-0.3, -0.05) are floored at the
        // ignored penalty (min picks the more negative of the two), and
        // retagged as Ignored unless repetition already claimed them.
        if signal.value < -0.05 && signal.value > -0.3 {
            signal.value = SIGNAL_IGNORED_PENALTY.min(signal.value);
            if !matches!(signal.trigger, SignalTrigger::UserRepetition { .. }) {
                signal.trigger = SignalTrigger::Ignored {
                    // Preserve the original overlap ratio when available.
                    overlap_ratio: match &signal.trigger {
                        SignalTrigger::EntityOverlap { overlap_ratio } => *overlap_ratio,
                        _ => 0.0,
                    },
                };
            }
        }
    }
}
947
/// The last conversation context recorded for a user, used to detect
/// repetition and topic changes on the next turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreviousContext {
    /// Raw context text.
    pub context: String,
    /// Embedding of the context.
    pub embedding: Vec<f32>,
    /// When the context was recorded.
    pub timestamp: DateTime<Utc>,
    /// Memories surfaced alongside this context.
    pub surfaced_memory_ids: Vec<MemoryId>,
}
964
/// Central store for per-memory feedback momentum, optionally backed by a
/// RocksDB column family for persistence.
pub struct FeedbackStore {
    /// In-memory cache of per-memory momentum state.
    pub momentum: HashMap<MemoryId, FeedbackMomentum>,

    /// Surfaced-memory batches awaiting feedback, keyed by user id.
    pending: HashMap<String, PendingFeedback>,

    /// Last conversation context per user, for repetition/topic-change
    /// detection.
    previous_context: HashMap<String, PreviousContext>,

    /// Optional persistence backend; `None` means in-memory only.
    db: Option<Arc<DB>>,

    /// Momentum entries modified since the last flush.
    dirty: HashSet<MemoryId>,
}
983
impl std::fmt::Debug for FeedbackStore {
    // Manual impl: the RocksDB handle is not Debug, so summarize the store
    // by counts instead of dumping contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FeedbackStore")
            .field("momentum_count", &self.momentum.len())
            .field("pending_count", &self.pending.len())
            .field("previous_context_count", &self.previous_context.len())
            .field("has_db", &self.db.is_some())
            .field("dirty_count", &self.dirty.len())
            .finish()
    }
}
995
996impl Default for FeedbackStore {
997 fn default() -> Self {
998 Self {
999 momentum: HashMap::new(),
1000 pending: HashMap::new(),
1001 previous_context: HashMap::new(),
1002 db: None,
1003 dirty: HashSet::new(),
1004 }
1005 }
1006}
1007
impl FeedbackStore {
    /// Purely in-memory store with no persistence backend.
    pub fn new() -> Self {
        Self::default()
    }

    /// Handle to the feedback column family, when a DB is attached.
    fn feedback_cf(&self) -> Option<&ColumnFamily> {
        self.db.as_ref().and_then(|db| db.cf_handle(CF_FEEDBACK))
    }

    /// Builds a store over an already-open shared RocksDB instance.
    ///
    /// First migrates any legacy standalone feedback DB found under
    /// `base_path`, then bulk-loads momentum, pending feedback, and
    /// previous-context state from the feedback column family.
    pub fn with_shared_db(db: Arc<DB>, base_path: &Path) -> anyhow::Result<Self> {
        Self::migrate_from_separate_db(base_path, &db)?;

        let cf = db.cf_handle(CF_FEEDBACK).expect("feedback CF must exist");

        // Momentum entries live under "momentum:<id>" keys.
        let mut momentum = HashMap::new();
        let iter = db.prefix_iterator_cf(cf, b"momentum:");
        for item in iter {
            if let Ok((key, value)) = item {
                if let Ok(key_str) = std::str::from_utf8(&key) {
                    // The prefix iterator can walk past the prefix; stop at
                    // the first non-matching key.
                    if !key_str.starts_with("momentum:") {
                        break;
                    }
                    if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&value) {
                        momentum.insert(m.memory_id.clone(), m);
                    }
                }
            }
        }

        // Pending feedback per user; expired entries are skipped and
        // best-effort deleted from disk on the way in.
        let mut pending = HashMap::new();
        let iter = db.prefix_iterator_cf(cf, b"pending:");
        for item in iter {
            if let Ok((key, value)) = item {
                if let Ok(key_str) = std::str::from_utf8(&key) {
                    if !key_str.starts_with("pending:") {
                        break;
                    }
                    if let Ok(p) = serde_json::from_slice::<PendingFeedback>(&value) {
                        if !p.is_expired() {
                            pending.insert(p.user_id.clone(), p);
                        } else {
                            let _ = db.delete_cf(cf, key_str.as_bytes());
                        }
                    }
                }
            }
        }

        // Previous conversation context, keyed "prev_ctx:<user_id>".
        let mut previous_context = HashMap::new();
        let iter = db.prefix_iterator_cf(cf, b"prev_ctx:");
        for item in iter {
            if let Ok((key, value)) = item {
                if let Ok(key_str) = std::str::from_utf8(&key) {
                    if !key_str.starts_with("prev_ctx:") {
                        break;
                    }
                    if let Ok(ctx) = serde_json::from_slice::<PreviousContext>(&value) {
                        let user_id = key_str.strip_prefix("prev_ctx:").unwrap_or("");
                        previous_context.insert(user_id.to_string(), ctx);
                    }
                }
            }
        }

        tracing::info!(
            "Loaded {} momentum, {} pending, {} previous context from shared feedback CF",
            momentum.len(),
            pending.len(),
            previous_context.len()
        );

        Ok(Self {
            momentum,
            pending,
            previous_context,
            db: Some(db),
            dirty: HashSet::new(),
        })
    }

    /// One-time migration: copies every key/value from the legacy
    /// standalone `feedback` DB directory into the shared DB's feedback
    /// CF (in bounded batches), then renames the old directory aside as
    /// a backup. A no-op when no legacy directory exists.
    fn migrate_from_separate_db(base_path: &Path, db: &DB) -> anyhow::Result<()> {
        let old_dir = base_path.join("feedback");
        if !old_dir.is_dir() {
            return Ok(());
        }

        let cf = db.cf_handle(CF_FEEDBACK).expect("feedback CF must exist");
        let old_opts = Options::default();
        match DB::open_for_read_only(&old_opts, &old_dir, false) {
            Ok(old_db) => {
                let mut batch = WriteBatch::default();
                let mut count = 0usize;
                for item in old_db.iterator(IteratorMode::Start) {
                    if let Ok((key, value)) = item {
                        batch.put_cf(cf, &key, &value);
                        count += 1;
                        // Flush in chunks so the write batch stays bounded.
                        if count % 10_000 == 0 {
                            db.write(std::mem::take(&mut batch))?;
                        }
                    }
                }
                if !batch.is_empty() {
                    db.write(batch)?;
                }
                drop(old_db);
                tracing::info!(" feedback: migrated {count} entries to {CF_FEEDBACK} CF");

                // Keep the old data as a backup rather than deleting it.
                let backup = base_path.join("feedback.pre_cf_migration");
                if backup.exists() {
                    let _ = std::fs::remove_dir_all(&backup);
                }
                if let Err(e) = std::fs::rename(&old_dir, &backup) {
                    tracing::warn!("Could not rename old feedback dir: {e}");
                }
            }
            // Migration is best-effort: log and continue with the new CF.
            Err(e) => tracing::warn!("Could not open old feedback DB for migration: {e}"),
        }
        Ok(())
    }

    /// Opens (or creates) a dedicated RocksDB at `path` with the feedback
    /// CF, then loads the same three key spaces as `with_shared_db`.
    pub fn with_persistence<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
        let mut opts = Options::default();
        opts.create_if_missing(true);
        opts.create_missing_column_families(true);
        opts.set_compression_type(rocksdb::DBCompressionType::Lz4);

        let cfs = vec![
            ColumnFamilyDescriptor::new("default", Options::default()),
            ColumnFamilyDescriptor::new(CF_FEEDBACK, {
                let mut cf_opts = Options::default();
                cf_opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
                cf_opts
            }),
        ];
        let db = DB::open_cf_descriptors(&opts, path.as_ref(), cfs)?;
        let db = Arc::new(db);

        let cf = db.cf_handle(CF_FEEDBACK).expect("feedback CF must exist");

        // Same loading logic as `with_shared_db`: momentum entries first.
        let mut momentum = HashMap::new();
        let iter = db.prefix_iterator_cf(cf, b"momentum:");
        for item in iter {
            if let Ok((key, value)) = item {
                if let Ok(key_str) = std::str::from_utf8(&key) {
                    if !key_str.starts_with("momentum:") {
                        break;
                    }
                    if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&value) {
                        momentum.insert(m.memory_id.clone(), m);
                    }
                }
            }
        }

        // Pending feedback, dropping and deleting expired entries.
        let mut pending = HashMap::new();
        let iter = db.prefix_iterator_cf(cf, b"pending:");
        for item in iter {
            if let Ok((key, value)) = item {
                if let Ok(key_str) = std::str::from_utf8(&key) {
                    if !key_str.starts_with("pending:") {
                        break;
                    }
                    if let Ok(p) = serde_json::from_slice::<PendingFeedback>(&value) {
                        if !p.is_expired() {
                            pending.insert(p.user_id.clone(), p);
                        } else {
                            let _ = db.delete_cf(cf, key_str.as_bytes());
                        }
                    }
                }
            }
        }

        // Previous conversation context per user.
        let mut previous_context = HashMap::new();
        let iter = db.prefix_iterator_cf(cf, b"prev_ctx:");
        for item in iter {
            if let Ok((key, value)) = item {
                if let Ok(key_str) = std::str::from_utf8(&key) {
                    if !key_str.starts_with("prev_ctx:") {
                        break;
                    }
                    if let Ok(ctx) = serde_json::from_slice::<PreviousContext>(&value) {
                        let user_id = key_str.strip_prefix("prev_ctx:").unwrap_or("");
                        previous_context.insert(user_id.to_string(), ctx);
                    }
                }
            }
        }

        tracing::info!(
            "Loaded {} momentum, {} pending, {} previous context from feedback CF",
            momentum.len(),
            pending.len(),
            previous_context.len()
        );

        Ok(Self {
            momentum,
            pending,
            previous_context,
            db: Some(db),
            dirty: HashSet::new(),
        })
    }

    /// Returns a mutable momentum entry for `memory_id`, hydrating it from
    /// disk on a cache miss or creating a fresh one (which is marked dirty
    /// so it gets persisted on the next flush).
    pub fn get_or_create_momentum(
        &mut self,
        memory_id: MemoryId,
        memory_type: ExperienceType,
    ) -> &mut FeedbackMomentum {
        // Cache miss: try loading from the DB before creating anew.
        if !self.momentum.contains_key(&memory_id) {
            if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
                let key = format!("momentum:{}", memory_id.0);
                if let Ok(Some(data)) = db.get_cf(cf, key.as_bytes()) {
                    if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&data) {
                        self.momentum.insert(memory_id.clone(), m);
                    }
                }
            }
        }

        // Only brand-new entries are marked dirty here; loaded entries are
        // already persisted.
        self.momentum.entry(memory_id.clone()).or_insert_with(|| {
            self.dirty.insert(memory_id.clone());
            FeedbackMomentum::new(memory_id, memory_type)
        })
    }

    /// Read-only lookup: checks the in-memory cache first, then the DB.
    /// Returns a clone; does not populate the cache.
    pub fn get_momentum(&self, memory_id: &MemoryId) -> Option<FeedbackMomentum> {
        if let Some(m) = self.momentum.get(memory_id) {
            return Some(m.clone());
        }
        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
            let key = format!("momentum:{}", memory_id.0);
            if let Ok(Some(data)) = db.get_cf(cf, key.as_bytes()) {
                if let Ok(m) = serde_json::from_slice::<FeedbackMomentum>(&data) {
                    return Some(m);
                }
            }
        }
        None
    }

    /// Flags a momentum entry for persistence on the next flush.
    pub fn mark_dirty(&mut self, memory_id: &MemoryId) {
        self.dirty.insert(memory_id.clone());
    }

    /// Records (and immediately persists, if a DB is attached) a pending
    /// feedback batch, replacing any existing one for the same user.
    pub fn set_pending(&mut self, pending: PendingFeedback) {
        let user_id = pending.user_id.clone();
        self.pending.insert(user_id.clone(), pending.clone());

        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
            let key = format!("pending:{}", user_id);
            if let Ok(value) = serde_json::to_vec(&pending) {
                if let Err(e) = db.put_cf(cf, key.as_bytes(), &value) {
                    tracing::warn!("Failed to persist pending feedback: {}", e);
                }
            }
        }
    }

    /// Removes and returns the pending feedback for a user, deleting the
    /// persisted copy as well.
    pub fn take_pending(&mut self, user_id: &str) -> Option<PendingFeedback> {
        let result = self.pending.remove(user_id);

        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
            let key = format!("pending:{}", user_id);
            let _ = db.delete_cf(cf, key.as_bytes());
        }

        result
    }

    /// Borrows the pending feedback for a user without consuming it.
    pub fn get_pending(&self, user_id: &str) -> Option<&PendingFeedback> {
        self.pending.get(user_id)
    }

    /// Drops expired pending batches from memory (disk copies are cleaned
    /// up lazily at load time or via `take_pending`).
    pub fn cleanup_expired(&mut self) {
        self.pending.retain(|_, p| !p.is_expired());
    }

    /// Records (and persists, if a DB is attached) the latest conversation
    /// context for a user, for next-turn pattern detection.
    pub fn set_previous_context(
        &mut self,
        user_id: &str,
        context: String,
        embedding: Vec<f32>,
        surfaced_memory_ids: Vec<MemoryId>,
    ) {
        let prev_ctx = PreviousContext {
            context,
            embedding,
            timestamp: Utc::now(),
            surfaced_memory_ids,
        };

        self.previous_context
            .insert(user_id.to_string(), prev_ctx.clone());

        if let (Some(db), Some(cf)) = (&self.db, self.feedback_cf()) {
            let key = format!("prev_ctx:{}", user_id);
            if let Ok(value) = serde_json::to_vec(&prev_ctx) {
                if let Err(e) = db.put_cf(cf, key.as_bytes(), &value) {
                    tracing::warn!("Failed to persist previous context: {}", e);
                }
            }
        }
    }

    /// Borrows the last recorded context for a user, if any.
    pub fn get_previous_context(&self, user_id: &str) -> Option<&PreviousContext> {
        self.previous_context.get(user_id)
    }

    /// Compares the current context embedding to the user's previous one.
    ///
    /// Returns `(is_repetition, is_topic_change, similarity)`, or `None`
    /// when there is no usable previous context or either embedding is
    /// empty.
    pub fn detect_context_pattern(
        &self,
        user_id: &str,
        current_embedding: &[f32],
    ) -> Option<(bool, bool, f32)> {
        let prev = self.previous_context.get(user_id)?;

        if prev.embedding.is_empty() || current_embedding.is_empty() {
            return None;
        }

        let similarity = cosine_similarity(&prev.embedding, current_embedding);

        // Thresholds: >0.8 looks like the user re-asking the same thing;
        // <0.3 looks like a topic switch.
        let is_repetition = similarity > 0.8;
        let is_topic_change = similarity < 0.3;
        Some((is_repetition, is_topic_change, similarity))
    }

    /// Persists dirty momentum entries and all pending feedback, then
    /// forces a RocksDB flush of the feedback CF. Returns the number of
    /// momentum entries written. A no-op returning 0 without a DB.
    pub fn flush(&mut self) -> anyhow::Result<usize> {
        let Some(ref db) = self.db else {
            return Ok(0);
        };
        let Some(cf) = db.cf_handle(CF_FEEDBACK) else {
            return Ok(0);
        };

        // NOTE(review): the dirty set is drained before the writes; if a
        // put below fails via `?`, dirtiness for the remaining ids is lost
        // until they are marked dirty again — confirm this is acceptable.
        let dirty: Vec<MemoryId> = self.dirty.drain().collect();

        let mut flushed = 0;
        for memory_id in &dirty {
            if let Some(momentum) = self.momentum.get(memory_id) {
                let key = format!("momentum:{}", memory_id.0);
                let value = serde_json::to_vec(momentum)?;
                db.put_cf(cf, key.as_bytes(), &value)?;
                flushed += 1;
            }
        }

        // Pending feedback is small; rewrite all of it on every flush.
        for (user_id, pending) in &self.pending {
            let key = format!("pending:{}", user_id);
            let value = serde_json::to_vec(pending)?;
            db.put_cf(cf, key.as_bytes(), &value)?;
        }

        use rocksdb::FlushOptions;
        // Block until the memtable flush completes.
        let mut flush_opts = FlushOptions::default();
        flush_opts.set_wait(true);
        db.flush_cf_opt(cf, &flush_opts)
            .map_err(|e| anyhow::anyhow!("Failed to flush feedback CF: {e}"))?;

        if flushed > 0 {
            tracing::debug!("Flushed {} feedback momentum entries to disk", flushed);
        }

        Ok(flushed)
    }

    /// Borrows the underlying DB handle, if any.
    pub fn database(&self) -> Option<&Arc<DB>> {
        self.db.as_ref()
    }

    /// Aggregate statistics over the in-memory momentum cache.
    pub fn stats(&self) -> FeedbackStoreStats {
        FeedbackStoreStats {
            total_momentum_entries: self.momentum.len(),
            total_pending: self.pending.len(),
            avg_ema: if self.momentum.is_empty() {
                0.0
            } else {
                self.momentum.values().map(|m| m.ema).sum::<f32>() / self.momentum.len() as f32
            },
            avg_stability: if self.momentum.is_empty() {
                0.0
            } else {
                self.momentum.values().map(|m| m.stability).sum::<f32>()
                    / self.momentum.len() as f32
            },
        }
    }
}
1448
/// Aggregate snapshot of a `FeedbackStore` for diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackStoreStats {
    /// Number of momentum entries held in memory.
    pub total_momentum_entries: usize,
    /// Number of pending feedback batches.
    pub total_pending: usize,
    /// Mean EMA across all in-memory momentum entries (0.0 when empty).
    pub avg_ema: f32,
    /// Mean stability across all in-memory momentum entries (0.0 when empty).
    pub avg_stability: f32,
}
1457
1458#[cfg(test)]
1459mod tests {
1460 use super::*;
1461 use uuid::Uuid;
1462
    #[test]
    fn test_signal_from_entity_overlap() {
        // Strong overlap: clearly positive value, high confidence.
        let signal = SignalRecord::from_entity_overlap(0.7);
        assert!(signal.value > 0.5);
        assert!(signal.confidence > 0.8);

        // Weak overlap: positive but muted.
        let signal = SignalRecord::from_entity_overlap(0.3);
        assert!(signal.value > 0.0);
        assert!(signal.value < 0.5);

        // Below the weak threshold: treated as a penalty.
        let signal = SignalRecord::from_entity_overlap(0.05);
        assert!(signal.value < 0.0);
    }
1479
1480 #[test]
1481 fn test_momentum_inertia_by_type() {
1482 let learning = FeedbackMomentum::new(MemoryId(Uuid::new_v4()), ExperienceType::Learning);
1483 let conversation =
1484 FeedbackMomentum::new(MemoryId(Uuid::new_v4()), ExperienceType::Conversation);
1485
1486 assert!(learning.base_inertia() > conversation.base_inertia());
1487 assert!(learning.base_inertia() >= 0.9);
1488 assert!(conversation.base_inertia() <= 0.4);
1489 }
1490
1491 #[test]
1492 fn test_momentum_update_with_inertia() {
1493 let mut momentum = FeedbackMomentum::new(
1494 MemoryId(Uuid::new_v4()),
1495 ExperienceType::Learning, );
1497
1498 momentum.update(SignalRecord::new(
1500 1.0,
1501 1.0,
1502 SignalTrigger::EntityOverlap { overlap_ratio: 1.0 },
1503 ));
1504
1505 assert!(momentum.ema > 0.0);
1507 assert!(momentum.ema < 0.5); for _ in 0..20 {
1511 momentum.update(SignalRecord::new(
1512 1.0,
1513 1.0,
1514 SignalTrigger::EntityOverlap { overlap_ratio: 1.0 },
1515 ));
1516 }
1517
1518 assert!(momentum.ema > 0.5);
1520 assert!(momentum.stability > 0.7);
1522 }
1523
1524 #[test]
1525 fn test_trend_detection() {
1526 let mut signals = VecDeque::new();
1527
1528 assert_eq!(Trend::from_signals(&signals), Trend::Insufficient);
1530
1531 for i in 0..10 {
1533 signals.push_back(SignalRecord::new(
1534 i as f32 * 0.15, 1.0,
1536 SignalTrigger::TopicChange { similarity: 0.2 },
1537 ));
1538 }
1539 assert_eq!(Trend::from_signals(&signals), Trend::Improving);
1540
1541 signals.clear();
1543 for i in (0..10).rev() {
1544 signals.push_back(SignalRecord::new(
1545 i as f32 * 0.15, 1.0,
1547 SignalTrigger::TopicChange { similarity: 0.2 },
1548 ));
1549 }
1550 assert_eq!(Trend::from_signals(&signals), Trend::Declining);
1551 }
1552
1553 #[test]
1554 fn test_entity_overlap() {
1555 let memory: HashSet<String> = ["rust", "async", "tokio"]
1556 .iter()
1557 .map(|s| s.to_string())
1558 .collect();
1559 let response: HashSet<String> = ["rust", "tokio", "spawn"]
1560 .iter()
1561 .map(|s| s.to_string())
1562 .collect();
1563
1564 let overlap = calculate_entity_overlap(&memory, &response);
1565 assert!((overlap - 0.666).abs() < 0.01); }
1567
1568 #[test]
1569 fn test_negative_keyword_detection() {
1570 let text = "No, that's not what I meant";
1572 let keywords = detect_negative_keywords(text);
1573 assert!(keywords.contains(&"not what i meant".to_string()));
1574
1575 let text2 = "That's not helpful at all, it's irrelevant";
1577 let keywords2 = detect_negative_keywords(text2);
1578 assert!(keywords2.contains(&"not helpful".to_string()));
1579 assert!(keywords2.contains(&"irrelevant".to_string()));
1580
1581 let text3 = "Please forget that, it doesn't work";
1583 let keywords3 = detect_negative_keywords(text3);
1584 assert!(keywords3.contains(&"forget that".to_string()));
1585 assert!(keywords3.contains(&"doesn't work".to_string()));
1586
1587 let text4 = "Can you help me debug this function?";
1589 let keywords4 = detect_negative_keywords(text4);
1590 assert!(keywords4.is_empty());
1591 }
1592
1593 #[test]
1594 fn test_feedback_store_pending() {
1595 let mut store = FeedbackStore::new();
1596 let user_id = "test-user";
1597
1598 assert!(store.get_pending(user_id).is_none());
1600
1601 let pending = PendingFeedback::new(
1603 user_id.to_string(),
1604 "test context".to_string(),
1605 vec![0.1; 384],
1606 vec![SurfacedMemoryInfo {
1607 id: MemoryId(Uuid::new_v4()),
1608 entities: ["rust", "memory"].iter().map(|s| s.to_string()).collect(),
1609 content_preview: "Test memory".to_string(),
1610 score: 0.8,
1611 embedding: Vec::new(),
1612 }],
1613 );
1614 store.set_pending(pending);
1615
1616 assert!(store.get_pending(user_id).is_some());
1618 assert_eq!(
1619 store.get_pending(user_id).unwrap().surfaced_memories.len(),
1620 1
1621 );
1622
1623 let taken = store.take_pending(user_id);
1625 assert!(taken.is_some());
1626 assert!(store.get_pending(user_id).is_none());
1627 }
1628
1629 #[test]
1630 fn test_feedback_store_momentum() {
1631 let mut store = FeedbackStore::new();
1632 let memory_id = MemoryId(Uuid::new_v4());
1633
1634 let momentum = store.get_or_create_momentum(memory_id.clone(), ExperienceType::Context);
1636 assert_eq!(momentum.signal_count, 0);
1637 assert_eq!(momentum.ema, 0.0);
1638
1639 momentum.update(SignalRecord::new(
1641 0.8,
1642 1.0,
1643 SignalTrigger::EntityOverlap { overlap_ratio: 0.8 },
1644 ));
1645 assert!(momentum.ema > 0.0);
1646 assert_eq!(momentum.signal_count, 1);
1647
1648 let momentum2 = store.get_momentum(&memory_id);
1650 assert!(momentum2.is_some());
1651 assert_eq!(momentum2.unwrap().signal_count, 1);
1652 }
1653
1654 #[test]
1655 fn test_process_implicit_feedback_full() {
1656 let memory_id1 = MemoryId(Uuid::new_v4());
1657 let memory_id2 = MemoryId(Uuid::new_v4());
1658
1659 let pending = PendingFeedback::new(
1660 "user1".to_string(),
1661 "How do I use async in Rust?".to_string(),
1662 vec![0.1; 384],
1663 vec![
1664 SurfacedMemoryInfo {
1665 id: memory_id1.clone(),
1666 entities: ["rust", "async", "tokio"]
1667 .iter()
1668 .map(|s| s.to_string())
1669 .collect(),
1670 content_preview: "Rust async with tokio".to_string(),
1671 score: 0.9,
1672 embedding: Vec::new(),
1673 },
1674 SurfacedMemoryInfo {
1675 id: memory_id2.clone(),
1676 entities: ["python", "django"].iter().map(|s| s.to_string()).collect(),
1677 content_preview: "Python Django web".to_string(),
1678 score: 0.3,
1679 embedding: Vec::new(),
1680 },
1681 ],
1682 );
1683
1684 let response =
1686 "To use async in Rust, you can use tokio runtime. Here is an example with async await.";
1687 let signals = process_implicit_feedback(&pending, response, None);
1688
1689 assert_eq!(signals.len(), 2);
1690
1691 let (id1, sig1) = &signals[0];
1693 assert_eq!(id1, &memory_id1);
1694 assert!(sig1.value > 0.0);
1695
1696 let (id2, sig2) = &signals[1];
1698 assert_eq!(id2, &memory_id2);
1699 assert!(sig2.value <= 0.0);
1700 }
1701
1702 #[test]
1703 fn test_process_implicit_feedback_with_negative_keywords() {
1704 let memory_id = MemoryId(Uuid::new_v4());
1705
1706 let pending = PendingFeedback::new(
1707 "user1".to_string(),
1708 "How do I use async?".to_string(),
1709 vec![0.1; 384],
1710 vec![SurfacedMemoryInfo {
1711 id: memory_id.clone(),
1712 entities: ["async", "code"].iter().map(|s| s.to_string()).collect(),
1713 content_preview: "Async code".to_string(),
1714 score: 0.9,
1715 embedding: Vec::new(),
1716 }],
1717 );
1718
1719 let response = "Here is the async code pattern";
1721
1722 let signals1 = process_implicit_feedback(&pending, response, None);
1724 let value_without = signals1[0].1.value;
1725
1726 let signals2 = process_implicit_feedback(&pending, response, Some("No, that is wrong!"));
1728 let value_with = signals2[0].1.value;
1729
1730 assert!(value_with < value_without);
1732 }
1733
1734 #[test]
1735 fn test_context_fingerprint_similarity() {
1736 let embedding: Vec<f32> = (0..384).map(|i| (i as f32) * 0.01).collect();
1737 let fp1 = ContextFingerprint::new(
1738 vec!["rust".to_string(), "memory".to_string()],
1739 &embedding,
1740 true,
1741 );
1742 let fp2 = ContextFingerprint::new(
1743 vec!["rust".to_string(), "async".to_string()],
1744 &embedding,
1745 false,
1746 );
1747 let different_embedding: Vec<f32> = (0..384).map(|i| 1.0 - (i as f32) * 0.01).collect();
1748 let fp3 = ContextFingerprint::new(
1749 vec!["python".to_string(), "django".to_string()],
1750 &different_embedding,
1751 true,
1752 );
1753
1754 let sim12 = fp1.similarity(&fp2);
1756 let sim13 = fp1.similarity(&fp3);
1758
1759 assert!(sim12 > sim13);
1760 }
1761
1762 #[test]
1763 fn test_feedback_store_stats() {
1764 let mut store = FeedbackStore::new();
1765
1766 let stats = store.stats();
1768 assert_eq!(stats.total_momentum_entries, 0);
1769 assert_eq!(stats.total_pending, 0);
1770
1771 for i in 0..5 {
1773 let mut momentum =
1774 FeedbackMomentum::new(MemoryId(Uuid::new_v4()), ExperienceType::Context);
1775 momentum.ema = i as f32 * 0.2; store.momentum.insert(momentum.memory_id.clone(), momentum);
1777 }
1778
1779 let stats = store.stats();
1780 assert_eq!(stats.total_momentum_entries, 5);
1781 assert!((stats.avg_ema - 0.4).abs() < 0.01); }
1783
1784 #[test]
1785 fn test_process_feedback_with_semantic_similarity() {
1786 let memory_id1 = MemoryId(Uuid::new_v4());
1787 let memory_id2 = MemoryId(Uuid::new_v4());
1788
1789 let rust_embedding: Vec<f32> = (0..384).map(|i| (i as f32) * 0.01).collect();
1791 let python_embedding: Vec<f32> = (0..384).map(|i| 1.0 - (i as f32) * 0.01).collect();
1792
1793 let pending = PendingFeedback::new(
1794 "user1".to_string(),
1795 "How do I use async in Rust?".to_string(),
1796 vec![0.1; 384],
1797 vec![
1798 SurfacedMemoryInfo {
1799 id: memory_id1.clone(),
1800 entities: ["rust", "async", "tokio"]
1801 .iter()
1802 .map(|s| s.to_string())
1803 .collect(),
1804 content_preview: "Rust async with tokio".to_string(),
1805 score: 0.9,
1806 embedding: rust_embedding.clone(),
1807 },
1808 SurfacedMemoryInfo {
1809 id: memory_id2.clone(),
1810 entities: ["python", "django"].iter().map(|s| s.to_string()).collect(),
1811 content_preview: "Python Django web".to_string(),
1812 score: 0.3,
1813 embedding: python_embedding.clone(),
1814 },
1815 ],
1816 );
1817
1818 let response = "Here is how to use async/await in Rust with tokio runtime.";
1820 let response_embedding = rust_embedding; let signals_entity_only = process_implicit_feedback(&pending, response, None);
1824
1825 let signals_with_semantic = process_implicit_feedback_with_semantics(
1827 &pending,
1828 response,
1829 None,
1830 Some(&response_embedding),
1831 );
1832
1833 let (id1, _sig1_entity) = &signals_entity_only[0];
1835 let (_, sig1_semantic) = &signals_with_semantic[0];
1836 assert_eq!(id1, &memory_id1);
1837
1838 match &sig1_semantic.trigger {
1840 SignalTrigger::SemanticSimilarity { similarity } => {
1841 assert!(*similarity > 0.9); }
1843 _ => panic!("Expected SemanticSimilarity trigger"),
1844 }
1845
1846 let (id2, sig2_semantic) = &signals_with_semantic[1];
1848 assert_eq!(id2, &memory_id2);
1849 match &sig2_semantic.trigger {
1850 SignalTrigger::SemanticSimilarity { similarity } => {
1851 assert!(*similarity < 0.5); }
1853 _ => panic!("Expected SemanticSimilarity trigger"),
1854 }
1855 }
1856
1857 #[test]
1858 fn test_cosine_similarity_basic() {
1859 let a = vec![1.0, 0.0, 0.0];
1861 let b = vec![1.0, 0.0, 0.0];
1862 assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1863
1864 let c = vec![0.0, 1.0, 0.0];
1866 assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
1867
1868 let d = vec![-1.0, 0.0, 0.0];
1870 assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.001);
1871
1872 assert!((cosine_similarity(&[], &[]) - 0.0).abs() < 0.001);
1874 }
1875
1876 #[test]
1877 fn test_calculate_entity_flow() {
1878 use std::collections::HashSet;
1879
1880 let memory_entities: HashSet<String> = ["rust", "async", "tokio", "futures"]
1882 .iter()
1883 .map(|s| s.to_string())
1884 .collect();
1885 let response_entities: HashSet<String> = ["rust", "async", "tokio", "runtime"]
1886 .iter()
1887 .map(|s| s.to_string())
1888 .collect();
1889
1890 let (derived_ratio, novel_ratio, derived_count, total) =
1891 calculate_entity_flow(&memory_entities, &response_entities);
1892
1893 assert_eq!(derived_count, 3); assert_eq!(total, 4);
1895 assert!((derived_ratio - 0.75).abs() < 0.01);
1896 assert!((novel_ratio - 0.25).abs() < 0.01);
1897
1898 let response_novel: HashSet<String> = ["python", "django", "flask", "web"]
1900 .iter()
1901 .map(|s| s.to_string())
1902 .collect();
1903
1904 let (derived_ratio2, novel_ratio2, derived_count2, _) =
1905 calculate_entity_flow(&memory_entities, &response_novel);
1906
1907 assert_eq!(derived_count2, 0);
1908 assert!((derived_ratio2 - 0.0).abs() < 0.01);
1909 assert!((novel_ratio2 - 1.0).abs() < 0.01);
1910
1911 let empty: HashSet<String> = HashSet::new();
1913 let (dr, nr, dc, total) = calculate_entity_flow(&memory_entities, &empty);
1914 assert_eq!(dc, 0);
1915 assert_eq!(total, 0);
1916 assert!((dr - 0.0).abs() < 0.01);
1917 assert!((nr - 0.0).abs() < 0.01);
1918 }
1919
1920 #[test]
1921 fn test_signal_from_entity_flow() {
1922 let sig1 = signal_from_entity_flow(0.75, 0.25, 3, 4);
1924 assert!(sig1.value > 0.5); assert!((sig1.confidence - 0.8).abs() < 0.01); let sig2 = signal_from_entity_flow(0.3, 0.7, 1, 4);
1929 assert!(sig2.value > 0.0 && sig2.value <= 0.5); assert!((sig2.confidence - 0.8).abs() < 0.01);
1931
1932 let sig3 = signal_from_entity_flow(0.1, 0.9, 0, 4);
1934 assert!(sig3.value < 0.0); assert!((sig3.value - (-0.1)).abs() < 0.01);
1936
1937 let sig4 = signal_from_entity_flow(0.5, 0.5, 1, 2);
1939 assert!((sig4.confidence - 0.5).abs() < 0.01);
1940
1941 match sig1.trigger {
1943 SignalTrigger::EntityFlow {
1944 derived_ratio,
1945 novel_ratio,
1946 memory_entities_used,
1947 response_entities_total,
1948 } => {
1949 assert!((derived_ratio - 0.75).abs() < 0.01);
1950 assert!((novel_ratio - 0.25).abs() < 0.01);
1951 assert_eq!(memory_entities_used, 3);
1952 assert_eq!(response_entities_total, 4);
1953 }
1954 _ => panic!("Expected EntityFlow trigger"),
1955 }
1956 }
1957}