1use chrono::{DateTime, Utc};
36use serde::{Deserialize, Serialize};
37use std::collections::{HashMap, HashSet, VecDeque};
38use std::sync::Arc;
39use tokio::sync::RwLock;
40use uuid::Uuid;
41
42use crate::embeddings::{NerEntity, NeuralNer};
43use crate::graph_memory::GraphMemory;
44use crate::memory::{Experience, ExperienceType, MemorySystem, Query as MemoryQuery};
45use crate::similarity::cosine_similarity;
46
/// Returns `true` if `needle` occurs in `haystack`, comparing ASCII letters
/// case-insensitively.
///
/// Both sides are folded, so `contains_ignore_ascii_case("Connection ERROR", "Error")`
/// is `true`. Non-ASCII bytes must match exactly. An empty `needle` always matches.
#[inline]
pub fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    // `windows(0)` panics, so the empty needle is handled up front.
    if needle.is_empty() {
        return true;
    }
    // A haystack shorter than the needle yields no windows, hence `false`.
    haystack
        .as_bytes()
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}
72
/// Case- and surrounding-whitespace-insensitive fingerprint of `content`.
///
/// Used for exact-duplicate detection of buffered messages and for recognising
/// repeated context-injection requests. The input is trimmed and ASCII letters are
/// lowercased before hashing; non-ASCII bytes are hashed unchanged. Lowercasing goes
/// through a fixed 64-byte stack buffer, so no heap allocation is needed.
///
/// `DefaultHasher`'s algorithm is unspecified by std and may change between Rust
/// releases, so these values should not be persisted.
#[inline]
pub fn content_hash(content: &str) -> u64 {
    use std::hash::Hasher;

    const CHUNK_SIZE: usize = 64;

    let mut state = std::collections::hash_map::DefaultHasher::new();
    let mut scratch = [0u8; CHUNK_SIZE];

    // Full 64-byte pieces first, then the (possibly empty) tail.
    for piece in content.trim().as_bytes().chunks(CHUNK_SIZE) {
        let folded = &mut scratch[..piece.len()];
        folded.copy_from_slice(piece);
        folded.make_ascii_lowercase();
        state.write(folded);
    }

    state.finish()
}
108
109#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
111#[serde(rename_all = "snake_case")]
112pub enum StreamMode {
113 Conversation,
115 Sensor,
117 Event,
119}
120
121impl Default for StreamMode {
122 fn default() -> Self {
123 StreamMode::Conversation
124 }
125}
126
/// Per-session extraction settings supplied in the stream handshake.
///
/// Every field has a serde default, so clients may send any subset.
/// `StreamSession::new` normalizes the values with `validate_and_clamp`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionConfig {
    /// Messages whose importance (explicit, or computed by `calculate_importance`)
    /// is below this value are not stored. Clamped to `[0.0, 1.0]`.
    #[serde(default = "default_min_importance")]
    pub min_importance: f32,

    /// Drop a message whose trimmed, ASCII-lowercased content hash was already seen
    /// in this session.
    #[serde(default = "default_true")]
    pub auto_dedupe: bool,

    /// Clamped to `[0.0, 1.0]`.
    /// NOTE(review): not read anywhere else in this file; presumably intended for
    /// similarity-based dedupe. Confirm before relying on it.
    #[serde(default = "default_dedupe_threshold")]
    pub dedupe_threshold: f32,

    /// Time-based extraction interval in milliseconds; `0` disables it.
    /// Non-zero values are clamped to `[100, 3_600_000]`.
    #[serde(default = "default_checkpoint_interval")]
    pub checkpoint_interval_ms: u64,

    /// Number of buffered entries that triggers an extraction. `0` is replaced by
    /// the default and values are capped at `10_000`.
    #[serde(default = "default_max_buffer_size")]
    pub max_buffer_size: usize,

    /// Run neural NER on each stored message.
    #[serde(default = "default_true")]
    pub extract_entities: bool,

    /// NOTE(review): not read anywhere else in this file; confirm whether it is
    /// consumed elsewhere.
    #[serde(default = "default_true")]
    pub create_relationships: bool,

    /// Append a message to the previous buffered entry when both have the same
    /// `source`.
    #[serde(default = "default_true")]
    pub merge_consecutive: bool,

    /// Event names, matched ASCII-case-insensitively. A matching event gets
    /// importance 0.8 and triggers an immediate extraction. Truncated to 100 entries.
    #[serde(default = "default_trigger_events")]
    pub trigger_events: Vec<String>,

    /// Whether `inject_context` may surface related memories for this session.
    #[serde(default = "default_true")]
    pub enable_context_injection: bool,

    /// Minimum recall score a memory needs to be surfaced. Clamped to `[0.0, 1.0]`.
    #[serde(default = "default_injection_min_relevance")]
    pub injection_min_relevance: f32,

    /// Maximum number of memories surfaced per injection. Clamped to `[1, 10]`.
    #[serde(default = "default_injection_max_memories")]
    pub injection_max_memories: usize,

    /// Seconds before the same memory may be surfaced again. Capped at 3600.
    #[serde(default = "default_injection_cooldown")]
    pub injection_cooldown_secs: u64,
}
187
// Serde `default = "..."` providers for `ExtractionConfig`. The `Default` impl for
// `ExtractionConfig` calls them too.

/// Event names that trigger an immediate extraction unless the client overrides them.
const DEFAULT_TRIGGER_EVENTS: [&str; 4] = ["error", "decision", "discovery", "learning"];

/// Minimum importance a message needs to be stored.
fn default_min_importance() -> f32 {
    0.3_f32
}

/// Shared `true` default for the boolean feature flags.
fn default_true() -> bool {
    true
}

/// Default similarity threshold for deduplication.
fn default_dedupe_threshold() -> f32 {
    0.85_f32
}

/// Default time-based checkpoint interval: five seconds.
fn default_checkpoint_interval() -> u64 {
    5_000
}

/// Default number of buffered entries that forces an extraction.
fn default_max_buffer_size() -> usize {
    50
}

/// Owned copy of [`DEFAULT_TRIGGER_EVENTS`].
fn default_trigger_events() -> Vec<String> {
    DEFAULT_TRIGGER_EVENTS
        .iter()
        .map(|name| (*name).to_owned())
        .collect()
}

/// Default minimum recall score for context injection.
fn default_injection_min_relevance() -> f32 {
    0.70_f32
}

/// Default number of memories surfaced per injection.
fn default_injection_max_memories() -> usize {
    3
}

/// Default per-memory injection cooldown: three minutes.
fn default_injection_cooldown() -> u64 {
    3 * 60
}
220
221const MIN_CHECKPOINT_INTERVAL_MS: u64 = 100; const MAX_CHECKPOINT_INTERVAL_MS: u64 = 3_600_000; const MAX_BUFFER_SIZE: usize = 10_000; const MAX_TRIGGER_EVENTS: usize = 100; const MAX_SEEN_HASHES: usize = 10_000; impl ExtractionConfig {
229 pub fn validate_and_clamp(&mut self) {
232 self.min_importance = self.min_importance.clamp(0.0, 1.0);
234
235 self.dedupe_threshold = self.dedupe_threshold.clamp(0.0, 1.0);
237
238 if self.checkpoint_interval_ms > 0 {
240 self.checkpoint_interval_ms = self
241 .checkpoint_interval_ms
242 .clamp(MIN_CHECKPOINT_INTERVAL_MS, MAX_CHECKPOINT_INTERVAL_MS);
243 }
244
245 if self.max_buffer_size == 0 {
247 self.max_buffer_size = default_max_buffer_size();
248 }
249 self.max_buffer_size = self.max_buffer_size.min(MAX_BUFFER_SIZE);
250
251 if self.trigger_events.len() > MAX_TRIGGER_EVENTS {
253 self.trigger_events.truncate(MAX_TRIGGER_EVENTS);
254 }
255
256 self.injection_min_relevance = self.injection_min_relevance.clamp(0.0, 1.0);
258 self.injection_max_memories = self.injection_max_memories.clamp(1, 10);
259 self.injection_cooldown_secs = self.injection_cooldown_secs.clamp(0, 3600);
260 }
261}
262
263impl Default for ExtractionConfig {
264 fn default() -> Self {
265 Self {
266 min_importance: default_min_importance(),
267 auto_dedupe: true,
268 dedupe_threshold: default_dedupe_threshold(),
269 checkpoint_interval_ms: default_checkpoint_interval(),
270 max_buffer_size: default_max_buffer_size(),
271 extract_entities: true,
272 create_relationships: true,
273 merge_consecutive: true,
274 trigger_events: default_trigger_events(),
275 enable_context_injection: true,
277 injection_min_relevance: default_injection_min_relevance(),
278 injection_max_memories: default_injection_max_memories(),
279 injection_cooldown_secs: default_injection_cooldown(),
280 }
281 }
282}
283
/// First message a client sends to open a streaming session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamHandshake {
    /// Identifier of the user that owns the session.
    pub user_id: String,

    /// Input interpretation mode; `StreamMode::Conversation` when omitted.
    #[serde(default)]
    pub mode: StreamMode,

    /// Extraction settings; `ExtractionConfig::default()` when omitted. Normalized
    /// by `validate_and_clamp` when the session is created.
    #[serde(default)]
    pub extraction_config: ExtractionConfig,

    /// Optional client-chosen session id; a random UUIDv4 is generated when absent.
    /// NOTE(review): `create_session` inserts without checking whether this id is
    /// already open, so a reused id replaces the existing session.
    pub session_id: Option<String>,

    /// Key/value pairs added (JSON-stringified) to the metadata of every memory
    /// stored for this session.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
305
/// A client-to-server message on an open stream.
///
/// Internally tagged by a `"type"` field with snake_case variant names, e.g.
/// `{"type": "flush"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamMessage {
    /// Free-form content to buffer for extraction.
    Content {
        /// Text to remember.
        content: String,

        /// Origin label; consecutive messages with the same source may be merged.
        #[serde(default)]
        source: Option<String>,

        /// Event time; the server's receipt time is used when omitted.
        #[serde(default)]
        timestamp: Option<DateTime<Utc>>,

        /// Explicit importance that overrides the heuristic score.
        #[serde(default)]
        importance: Option<f32>,

        /// Tags stored with the memory.
        #[serde(default)]
        tags: Vec<String>,

        /// Per-message metadata, merged over the session metadata.
        #[serde(default)]
        metadata: HashMap<String, serde_json::Value>,
    },

    /// A set of sensor readings, rendered as `[sensor_id] key=value<unit>, ...`.
    Sensor {
        /// Sensor identifier; also used as a tag and in the `sensor:<id>` source.
        sensor_id: String,

        /// Reading name to value.
        values: HashMap<String, f64>,

        /// Reading time; the server's receipt time is used when omitted.
        #[serde(default)]
        timestamp: Option<DateTime<Utc>>,

        /// Optional unit suffix per reading name, appended directly after the value.
        #[serde(default)]
        units: HashMap<String, String>,
    },

    /// An application event, rendered as `[severity] event: description`.
    Event {
        /// Event name; compared against the session's `trigger_events`.
        event: String,

        /// Human-readable description.
        description: String,

        /// Event time; the server's receipt time is used when omitted.
        #[serde(default)]
        timestamp: Option<DateTime<Utc>>,

        /// Severity label shown in brackets; empty when omitted.
        #[serde(default)]
        severity: Option<String>,

        /// Extra data stored as metadata (plus an `event_type` entry).
        #[serde(default)]
        data: HashMap<String, serde_json::Value>,
    },

    /// Extract everything currently buffered.
    Flush,

    /// Keep-alive; answered with an `Ack`.
    Ping,

    /// Extract what remains, then end the session.
    Close,
}
383
/// A server-to-client response, internally tagged by a snake_case `"type"` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ExtractionResult {
    /// Outcome of draining a session buffer into the memory system.
    Extraction {
        /// Number of memories successfully stored.
        memories_created: usize,

        /// Ids of the stored memories.
        memory_ids: Vec<String>,

        /// Entities found by NER across the processed messages.
        entities_detected: Vec<DetectedEntity>,

        /// Messages skipped because their importance fell below `min_importance`.
        /// Despite the name, exact duplicates dropped while buffering are not counted.
        dedupe_skipped: usize,

        /// Wall time spent in the extraction, in milliseconds.
        processing_time_ms: u64,

        /// When the result was produced.
        timestamp: DateTime<Utc>,
    },

    /// Related memories surfaced for the current context.
    ContextInjection {
        /// Surfaced memories, best recall score first.
        memories: Vec<SurfacedStreamMemory>,

        /// `content_hash` of the context text that produced this injection.
        context_hash: u64,

        /// Wall time spent, in milliseconds.
        processing_time_ms: u64,

        /// When the result was produced.
        timestamp: DateTime<Utc>,
    },

    /// Acknowledgement that a message was accepted without extraction.
    Ack {
        /// Kind of message acknowledged (`"content"`, `"event"`, `"sensor"`, `"ping"`).
        message_type: String,
        /// When the acknowledgement was produced.
        timestamp: DateTime<Utc>,
    },

    /// A failure to process a message.
    Error {
        /// Machine-readable code, e.g. `"SESSION_NOT_FOUND"`.
        code: String,
        /// Human-readable description.
        message: String,
        /// Whether the client should treat the session as unusable.
        fatal: bool,
        /// When the error was produced.
        timestamp: DateTime<Utc>,
    },

    /// The session was closed.
    Closed {
        /// Why the session ended, e.g. `"client_requested"`.
        reason: String,
        /// Memories stored over the session's lifetime.
        total_memories_created: usize,
        /// When the session was closed.
        timestamp: DateTime<Utc>,
    },
}
452
/// A stored memory surfaced by context injection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SurfacedStreamMemory {
    /// Memory id.
    pub id: String,

    /// Stored memory text.
    pub content: String,

    /// `Debug` rendering of the memory's experience type.
    pub memory_type: String,

    /// Recall score used for ranking and for the `injection_min_relevance` cut-off.
    pub relevance: f32,

    /// Informational component scores; these are not used for ranking.
    pub relevance_breakdown: RelevanceBreakdown,

    /// When the memory was created.
    pub created_at: DateTime<Utc>,

    /// Filled from the memory's entity list.
    pub tags: Vec<String>,
}
477
/// Component signals reported alongside a surfaced memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceBreakdown {
    /// Cosine similarity between the context embedding and the memory embedding.
    pub semantic: f32,

    /// `exp(-0.01 * hours_old)` recency decay.
    pub recency: f32,

    /// Hebbian strength reported by graph memory; `0.0` when unavailable.
    pub strength: f32,
}
490
/// An entity found by NER in streamed content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedEntity {
    /// Entity surface text.
    pub text: String,
    /// String form of the NER entity type (e.g. `"ORG"`).
    pub entity_type: String,
    /// NER confidence score.
    pub confidence: f32,
    /// Always `false` when built from a `NerEntity` in this file.
    pub existing: bool,
}
503
504impl From<&NerEntity> for DetectedEntity {
505 fn from(ner: &NerEntity) -> Self {
506 Self {
507 text: ner.text.clone(),
508 entity_type: ner.entity_type.as_str().to_string(),
509 confidence: ner.confidence,
510 existing: false,
511 }
512 }
513}
514
515#[derive(Debug, Clone)]
517pub struct BufferedMessage {
518 pub content: String,
519 pub source: Option<String>,
520 #[allow(dead_code)]
521 pub timestamp: DateTime<Utc>,
522 pub importance: Option<f32>,
523 pub tags: Vec<String>,
524 pub metadata: HashMap<String, serde_json::Value>,
525}
526
/// Upper bound on simultaneously open sessions. `create_session` refuses new ones
/// beyond this.
const MAX_CONCURRENT_SESSIONS: usize = 1_000;

/// Seconds of inactivity after which a session is considered stale (one hour).
const SESSION_TIMEOUT_SECS: i64 = 60 * 60;
/// Server-side state for one streaming connection.
pub struct StreamSession {
    /// Key of this session in the extractor's session map.
    pub session_id: String,

    /// User that opened the session (from the handshake).
    pub user_id: String,

    /// Input interpretation mode chosen at handshake time.
    pub mode: StreamMode,

    /// Extraction settings, normalized in `StreamSession::new`.
    pub config: ExtractionConfig,

    /// Handshake metadata, added (JSON-stringified) to every stored memory.
    pub metadata: HashMap<String, serde_json::Value>,

    /// Messages waiting for the next extraction.
    buffer: VecDeque<BufferedMessage>,

    /// When the buffer was last drained; drives time-based extraction.
    last_extraction: DateTime<Utc>,

    /// When a message was last processed; drives stale-session cleanup.
    last_activity: DateTime<Utc>,

    /// Running count of memories stored for this session.
    total_memories_created: usize,

    /// `content_hash` values of previously buffered messages, for exact dedupe.
    seen_hashes: HashSet<u64>,

    /// NOTE(review): allocated in `new` but never read or written anywhere else in
    /// this file, which is why `dead_code` is allowed.
    #[allow(dead_code)]
    recent_embeddings: VecDeque<(String, Vec<f32>)>,

    /// Memory id mapped to the time it was last surfaced by context injection.
    injection_cooldowns: HashMap<String, DateTime<Utc>>,

    /// Hashes of the most recent (up to 20) injection contexts that surfaced
    /// memories.
    recent_context_hashes: VecDeque<u64>,
}
577
578impl StreamSession {
579 pub fn new(handshake: StreamHandshake) -> Self {
580 let session_id = handshake
581 .session_id
582 .unwrap_or_else(|| Uuid::new_v4().to_string());
583
584 let mut config = handshake.extraction_config;
586 config.validate_and_clamp();
587
588 let now = Utc::now();
589 Self {
590 session_id,
591 user_id: handshake.user_id,
592 mode: handshake.mode,
593 config,
594 metadata: handshake.metadata,
595 buffer: VecDeque::with_capacity(64),
596 last_extraction: now,
597 last_activity: now,
598 total_memories_created: 0,
599 seen_hashes: HashSet::with_capacity(1024),
600 recent_embeddings: VecDeque::with_capacity(100),
601 injection_cooldowns: HashMap::new(),
602 recent_context_hashes: VecDeque::with_capacity(20),
603 }
604 }
605
606 fn mark_injected(&mut self, memory_id: &str) {
608 self.injection_cooldowns
609 .insert(memory_id.to_string(), Utc::now());
610 }
611
612 fn cleanup_injection_cooldowns(&mut self) {
614 let threshold = self.config.injection_cooldown_secs as i64 * 2;
615 let cutoff = Utc::now() - chrono::Duration::seconds(threshold);
616 self.injection_cooldowns.retain(|_, ts| *ts > cutoff);
617 }
618
619 fn should_extract_by_time(&self) -> bool {
621 if self.config.checkpoint_interval_ms == 0 {
622 return false;
623 }
624
625 let elapsed = Utc::now()
626 .signed_duration_since(self.last_extraction)
627 .num_milliseconds() as u64;
628
629 elapsed >= self.config.checkpoint_interval_ms
630 }
631
632 fn should_extract_by_size(&self) -> bool {
634 self.buffer.len() >= self.config.max_buffer_size
635 }
636
637 #[inline]
639 fn hash_content(content: &str) -> u64 {
640 content_hash(content)
641 }
642
643 fn is_exact_duplicate(&self, content: &str) -> bool {
645 let hash = Self::hash_content(content);
646 self.seen_hashes.contains(&hash)
647 }
648
649 fn mark_seen(&mut self, content: &str) {
653 if self.seen_hashes.len() >= MAX_SEEN_HASHES {
654 let target = MAX_SEEN_HASHES / 2;
655 let mut kept = 0usize;
656 self.seen_hashes.retain(|_| {
657 kept += 1;
658 kept <= target
659 });
660 }
661 let hash = Self::hash_content(content);
662 self.seen_hashes.insert(hash);
663 }
664
665 pub fn buffer_message(&mut self, msg: BufferedMessage) -> bool {
667 if self.config.auto_dedupe && self.is_exact_duplicate(&msg.content) {
669 return false;
670 }
671
672 if self.config.merge_consecutive && !self.buffer.is_empty() {
674 if let Some(last) = self.buffer.back_mut() {
675 if last.source == msg.source {
676 last.content.push('\n');
677 last.content.push_str(&msg.content);
678 last.tags.extend(msg.tags);
679 for (k, v) in msg.metadata {
680 last.metadata.insert(k, v);
681 }
682 return true;
683 }
684 }
685 }
686
687 self.mark_seen(&msg.content);
688 self.buffer.push_back(msg);
689 true
690 }
691
692 fn drain_buffer(&mut self) -> Vec<BufferedMessage> {
694 self.last_extraction = Utc::now();
695 self.buffer.drain(..).collect()
696 }
697
698 fn touch(&mut self) {
700 self.last_activity = Utc::now();
701 }
702
703 fn is_stale(&self) -> bool {
705 let elapsed = Utc::now()
706 .signed_duration_since(self.last_activity)
707 .num_seconds();
708 elapsed > SESSION_TIMEOUT_SECS
709 }
710}
711
/// Turns streamed messages into stored memories, one `StreamSession` per connection.
pub struct StreamingMemoryExtractor {
    /// NER model used when a session enables `extract_entities`.
    neural_ner: Arc<NeuralNer>,

    /// Open sessions keyed by session id, behind an async (tokio) `RwLock`.
    sessions: Arc<RwLock<HashMap<String, StreamSession>>>,
}
721
722impl StreamingMemoryExtractor {
723 pub fn new(neural_ner: Arc<NeuralNer>) -> Self {
724 Self {
725 neural_ner,
726 sessions: Arc::new(RwLock::new(HashMap::new())),
727 }
728 }
729
730 pub async fn create_session(&self, handshake: StreamHandshake) -> Result<String, String> {
732 self.cleanup_stale_sessions().await;
734
735 let mut sessions = self.sessions.write().await;
736
737 if sessions.len() >= MAX_CONCURRENT_SESSIONS {
739 return Err(format!(
740 "Maximum concurrent sessions ({}) reached. Try again later.",
741 MAX_CONCURRENT_SESSIONS
742 ));
743 }
744
745 let session = StreamSession::new(handshake);
746 let session_id = session.session_id.clone();
747 sessions.insert(session_id.clone(), session);
748
749 Ok(session_id)
750 }
751
752 pub async fn cleanup_stale_sessions(&self) -> usize {
754 let mut sessions = self.sessions.write().await;
755 let before_count = sessions.len();
756
757 sessions.retain(|_id, session| !session.is_stale());
758
759 let removed = before_count - sessions.len();
760 if removed > 0 {
761 tracing::info!("Cleaned up {} stale streaming sessions", removed);
762 }
763 removed
764 }
765
766 pub async fn session_count(&self) -> usize {
768 self.sessions.read().await.len()
769 }
770
771 pub async fn process_message(
773 &self,
774 session_id: &str,
775 message: StreamMessage,
776 memory_system: Arc<parking_lot::RwLock<MemorySystem>>,
777 ) -> ExtractionResult {
778 let mut sessions = self.sessions.write().await;
779
780 let session = match sessions.get_mut(session_id) {
781 Some(s) => s,
782 None => {
783 return ExtractionResult::Error {
784 code: "SESSION_NOT_FOUND".to_string(),
785 message: format!("Session {} not found", session_id),
786 fatal: true,
787 timestamp: Utc::now(),
788 }
789 }
790 };
791
792 session.touch();
794
795 match message {
796 StreamMessage::Content {
797 content,
798 source,
799 timestamp,
800 importance,
801 tags,
802 metadata,
803 } => {
804 let msg = BufferedMessage {
805 content,
806 source,
807 timestamp: timestamp.unwrap_or_else(Utc::now),
808 importance,
809 tags,
810 metadata,
811 };
812
813 let buffered = session.buffer_message(msg);
814
815 let should_extract = session.should_extract_by_time()
817 || session.should_extract_by_size()
818 || !buffered; if should_extract {
821 drop(sessions);
822 return self.extract_memories(session_id, memory_system).await;
823 }
824
825 ExtractionResult::Ack {
826 message_type: "content".to_string(),
827 timestamp: Utc::now(),
828 }
829 }
830
831 StreamMessage::Event {
832 event,
833 description,
834 timestamp,
835 severity,
836 data,
837 } => {
838 let is_trigger = {
840 let sessions = self.sessions.read().await;
841 sessions
842 .get(session_id)
843 .map(|s| {
844 s.config
845 .trigger_events
846 .iter()
847 .any(|t| t.eq_ignore_ascii_case(&event))
848 })
849 .unwrap_or(false)
850 };
851
852 let content = format!(
853 "[{}] {}: {}",
854 severity.unwrap_or_default(),
855 event,
856 description
857 );
858 let mut metadata: HashMap<String, serde_json::Value> = data;
859 metadata.insert("event_type".to_string(), serde_json::json!(event));
860
861 let msg = BufferedMessage {
862 content,
863 source: Some("event".to_string()),
864 timestamp: timestamp.unwrap_or_else(Utc::now),
865 importance: if is_trigger { Some(0.8) } else { None },
866 tags: vec![event.clone()],
867 metadata,
868 };
869
870 {
871 let mut sessions = self.sessions.write().await;
872 if let Some(session) = sessions.get_mut(session_id) {
873 session.buffer_message(msg);
874 }
875 }
876
877 if is_trigger {
878 return self.extract_memories(session_id, memory_system).await;
879 }
880
881 ExtractionResult::Ack {
882 message_type: "event".to_string(),
883 timestamp: Utc::now(),
884 }
885 }
886
887 StreamMessage::Sensor {
888 sensor_id,
889 values,
890 timestamp,
891 units,
892 } => {
893 let mut parts: Vec<String> = Vec::new();
895 for (key, value) in &values {
896 let unit = units.get(key).map(|u| u.as_str()).unwrap_or("");
897 parts.push(format!("{}={}{}", key, value, unit));
898 }
899 let content = format!("[{}] {}", sensor_id, parts.join(", "));
900
901 let msg = BufferedMessage {
902 content,
903 source: Some(format!("sensor:{}", sensor_id)),
904 timestamp: timestamp.unwrap_or_else(Utc::now),
905 importance: None,
906 tags: vec!["sensor".to_string(), sensor_id],
907 metadata: HashMap::new(),
908 };
909
910 let should_extract = {
911 let mut sessions = self.sessions.write().await;
912 if let Some(session) = sessions.get_mut(session_id) {
913 session.buffer_message(msg);
914 session.should_extract_by_time() || session.should_extract_by_size()
915 } else {
916 return ExtractionResult::Error {
917 code: "SESSION_NOT_FOUND".to_string(),
918 message: format!("Session '{}' not found", session_id),
919 fatal: true,
920 timestamp: Utc::now(),
921 };
922 }
923 };
924
925 if should_extract {
927 return self.extract_memories(session_id, memory_system).await;
928 }
929
930 ExtractionResult::Ack {
931 message_type: "sensor".to_string(),
932 timestamp: Utc::now(),
933 }
934 }
935
936 StreamMessage::Flush => {
937 drop(sessions);
938 self.extract_memories(session_id, memory_system).await
939 }
940
941 StreamMessage::Ping => ExtractionResult::Ack {
942 message_type: "ping".to_string(),
943 timestamp: Utc::now(),
944 },
945
946 StreamMessage::Close => {
947 drop(sessions);
949 let final_result = self.extract_memories(session_id, memory_system).await;
950
951 let mut sessions = self.sessions.write().await;
953 let total = sessions
954 .get(session_id)
955 .map(|s| s.total_memories_created)
956 .unwrap_or(0);
957 sessions.remove(session_id);
958
959 ExtractionResult::Closed {
960 reason: "client_requested".to_string(),
961 total_memories_created: total
962 + match &final_result {
963 ExtractionResult::Extraction {
964 memories_created, ..
965 } => *memories_created,
966 _ => 0,
967 },
968 timestamp: Utc::now(),
969 }
970 }
971 }
972 }
973
974 async fn extract_memories(
976 &self,
977 session_id: &str,
978 memory_system: Arc<parking_lot::RwLock<MemorySystem>>,
979 ) -> ExtractionResult {
980 let start = std::time::Instant::now();
981
982 let (messages, config, user_metadata, mode) = {
984 let mut sessions = self.sessions.write().await;
985 let session = match sessions.get_mut(session_id) {
986 Some(s) => s,
987 None => {
988 return ExtractionResult::Error {
989 code: "SESSION_NOT_FOUND".to_string(),
990 message: format!("Session {} not found", session_id),
991 fatal: true,
992 timestamp: Utc::now(),
993 }
994 }
995 };
996
997 let messages = session.drain_buffer();
998 let config = session.config.clone();
999 let metadata = session.metadata.clone();
1000 let mode = session.mode;
1001
1002 (messages, config, metadata, mode)
1003 };
1004
1005 if messages.is_empty() {
1006 return ExtractionResult::Extraction {
1007 memories_created: 0,
1008 memory_ids: vec![],
1009 entities_detected: vec![],
1010 dedupe_skipped: 0,
1011 processing_time_ms: start.elapsed().as_millis() as u64,
1012 timestamp: Utc::now(),
1013 };
1014 }
1015
1016 let mut memory_ids = Vec::new();
1017 let mut all_entities = Vec::new();
1018 let mut dedupe_skipped = 0;
1019
1020 for msg in messages {
1022 let importance = msg
1024 .importance
1025 .unwrap_or_else(|| Self::calculate_importance(&msg.content, mode, &config));
1026
1027 if importance < config.min_importance {
1029 dedupe_skipped += 1;
1030 continue;
1031 }
1032
1033 let entities: Vec<NerEntity> = if config.extract_entities {
1035 match self.neural_ner.extract(&msg.content) {
1036 Ok(ents) => ents,
1037 Err(e) => {
1038 tracing::debug!("NER extraction failed: {}", e);
1039 Vec::new()
1040 }
1041 }
1042 } else {
1043 Vec::new()
1044 };
1045
1046 for ent in &entities {
1048 all_entities.push(DetectedEntity::from(ent));
1049 }
1050
1051 let experience_type = Self::determine_experience_type(mode, &msg);
1053
1054 let mut string_metadata: HashMap<String, String> = HashMap::new();
1056 for (k, v) in user_metadata.iter() {
1057 string_metadata.insert(k.clone(), v.to_string());
1058 }
1059 for (k, v) in msg.metadata {
1060 string_metadata.insert(k, v.to_string());
1061 }
1062 let tags: Vec<String> = msg.tags.clone();
1064
1065 for tag in &tags {
1067 string_metadata.insert(format!("tag:{}", tag), "true".to_string());
1068 }
1069
1070 let mut all_entity_names: Vec<String> =
1072 entities.iter().map(|e| e.text.clone()).collect();
1073 for tag in &tags {
1074 if !all_entity_names.iter().any(|e| e.eq_ignore_ascii_case(tag)) {
1075 all_entity_names.push(tag.clone());
1076 }
1077 }
1078
1079 let experience = Experience {
1081 content: msg.content,
1082 experience_type,
1083 entities: all_entity_names,
1084 metadata: string_metadata,
1085 embeddings: None, tags,
1087 ..Default::default()
1088 };
1089
1090 let memory_sys = memory_system.read();
1092 match memory_sys.remember(experience, Some(msg.timestamp)) {
1093 Ok(memory_id) => {
1095 memory_ids.push(memory_id.0.to_string());
1097 }
1098 Err(e) => {
1099 tracing::warn!("Failed to store streaming memory: {}", e);
1100 }
1101 }
1102 }
1103
1104 {
1106 let mut sessions = self.sessions.write().await;
1107 if let Some(session) = sessions.get_mut(session_id) {
1108 session.total_memories_created += memory_ids.len();
1109 }
1110 }
1111
1112 ExtractionResult::Extraction {
1113 memories_created: memory_ids.len(),
1114 memory_ids,
1115 entities_detected: all_entities,
1116 dedupe_skipped,
1117 processing_time_ms: start.elapsed().as_millis() as u64,
1118 timestamp: Utc::now(),
1119 }
1120 }
1121
1122 fn calculate_importance(content: &str, mode: StreamMode, _config: &ExtractionConfig) -> f32 {
1124 let mut importance: f32 = 0.5;
1125
1126 let word_count = content.split_whitespace().count();
1128 if word_count > 50 {
1129 importance += 0.1;
1130 } else if word_count < 10 {
1131 importance -= 0.1;
1132 }
1133
1134 match mode {
1136 StreamMode::Conversation => {
1137 if content.contains('?') {
1139 importance += 0.15;
1140 }
1141 if content.contains("```") || content.contains("fn ") || content.contains("def ") {
1143 importance += 0.2;
1144 }
1145 if contains_ignore_ascii_case(content, "error")
1147 || contains_ignore_ascii_case(content, "failed")
1148 {
1149 importance += 0.2;
1150 }
1151 }
1152 StreamMode::Sensor => {
1153 importance = 0.4;
1156 }
1157 StreamMode::Event => {
1158 if contains_ignore_ascii_case(content, "error") {
1160 importance += 0.3;
1161 } else if contains_ignore_ascii_case(content, "warning") {
1162 importance += 0.15;
1163 }
1164 }
1165 }
1166
1167 importance.clamp(0.0, 1.0)
1168 }
1169
1170 fn determine_experience_type(mode: StreamMode, msg: &BufferedMessage) -> ExperienceType {
1172 for tag in &msg.tags {
1174 if contains_ignore_ascii_case(tag, "error") {
1175 return ExperienceType::Error;
1176 }
1177 if contains_ignore_ascii_case(tag, "decision") {
1178 return ExperienceType::Decision;
1179 }
1180 if contains_ignore_ascii_case(tag, "learning") {
1181 return ExperienceType::Learning;
1182 }
1183 if contains_ignore_ascii_case(tag, "discovery") {
1184 return ExperienceType::Discovery;
1185 }
1186 }
1187
1188 match mode {
1190 StreamMode::Conversation => ExperienceType::Conversation,
1191 StreamMode::Sensor => ExperienceType::Observation,
1192 StreamMode::Event => ExperienceType::Observation,
1193 }
1194 }
1195
1196 pub async fn close_session(&self, session_id: &str) -> Option<usize> {
1198 let mut sessions = self.sessions.write().await;
1199 sessions
1200 .remove(session_id)
1201 .map(|s| s.total_memories_created)
1202 }
1203
1204 pub async fn inject_context(
1210 &self,
1211 session_id: &str,
1212 content: &str,
1213 memory_system: Arc<parking_lot::RwLock<MemorySystem>>,
1214 graph_memory: Arc<parking_lot::RwLock<GraphMemory>>,
1215 ) -> Option<ExtractionResult> {
1216 let start = std::time::Instant::now();
1217
1218 let (config, _user_id) = {
1220 let sessions = self.sessions.read().await;
1221 let session = sessions.get(session_id)?;
1222 if !session.config.enable_context_injection {
1223 return None;
1224 }
1225 (session.config.clone(), session.user_id.clone())
1226 };
1227
1228 let context_hash = content_hash(content);
1230
1231 {
1233 let sessions = self.sessions.read().await;
1234 if let Some(session) = sessions.get(session_id) {
1235 if session.recent_context_hashes.contains(&context_hash) {
1236 return None; }
1238 }
1239 }
1240
1241 let content_for_embed = content.to_string();
1243 let memory_for_embed = memory_system.clone();
1244 let context_embedding: Vec<f32> = tokio::task::spawn_blocking(move || {
1245 let guard = memory_for_embed.read();
1246 guard
1247 .compute_embedding(&content_for_embed)
1248 .unwrap_or_else(|_| vec![0.0; 384])
1249 })
1250 .await
1251 .ok()?;
1252
1253 let min_relevance = config.injection_min_relevance;
1256 let max_per_message = config.injection_max_memories;
1257 let cooldown_seconds = config.injection_cooldown_secs;
1258
1259 let content_for_query = content.to_string();
1261 let max_results = max_per_message * 2; let context_emb = context_embedding.clone();
1263
1264 let cooldown_snapshot: HashSet<String> = {
1267 let sessions_guard = self.sessions.read().await;
1268 if let Some(session) = sessions_guard.get(session_id) {
1269 session
1270 .injection_cooldowns
1271 .iter()
1272 .filter(|(_, ts)| {
1273 let elapsed = Utc::now().signed_duration_since(**ts).num_seconds() as u64;
1274 elapsed < cooldown_seconds
1275 })
1276 .map(|(id, _)| id.clone())
1277 .collect()
1278 } else {
1279 HashSet::new()
1280 }
1281 };
1282
1283 let surfaced: Vec<SurfacedStreamMemory> = {
1284 let memory = memory_system.clone();
1285 let graph = graph_memory.clone();
1286
1287 tokio::task::spawn_blocking(move || {
1288 let memory_guard = memory.read();
1289 let graph_guard = graph.read();
1290 let now = Utc::now();
1291
1292 let query = MemoryQuery {
1294 query_text: Some(content_for_query),
1295 max_results,
1296 ..Default::default()
1297 };
1298 let results = memory_guard.recall(&query).unwrap_or_default();
1299
1300 const RECENCY_DECAY_RATE: f32 = 0.01; let mut candidates: Vec<(_, f32, f32, f32, f32)> = results
1305 .into_iter()
1306 .filter_map(|m| {
1307 let memory_embedding = m.experience.embeddings.as_ref()?.clone();
1308
1309 let score = m.get_score().unwrap_or(0.0);
1311
1312 let semantic = cosine_similarity(&memory_embedding, &context_emb);
1314 let hours_old = (now - m.created_at).num_hours().max(0) as f32;
1315 let recency = (-RECENCY_DECAY_RATE * hours_old).exp();
1316 let hebbian_strength = graph_guard
1317 .get_memory_hebbian_strength(&m.id)
1318 .unwrap_or(0.0);
1319
1320 Some((m, score, semantic, recency, hebbian_strength))
1321 })
1322 .collect();
1323
1324 candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
1326
1327 candidates
1329 .into_iter()
1330 .filter(|(m, score, _, _, _)| {
1331 if *score < min_relevance {
1332 return false;
1333 }
1334 if cooldown_snapshot.contains(&m.id.0.to_string()) {
1336 return false;
1337 }
1338 true
1339 })
1340 .take(max_per_message)
1341 .map(
1342 |(m, score, semantic, recency, strength)| SurfacedStreamMemory {
1343 id: m.id.0.to_string(),
1344 content: m.experience.content.clone(),
1345 memory_type: format!("{:?}", m.experience.experience_type),
1346 relevance: score,
1347 relevance_breakdown: RelevanceBreakdown {
1348 semantic,
1349 recency,
1350 strength,
1351 },
1352 created_at: m.created_at,
1353 tags: m.experience.entities.clone(),
1354 },
1355 )
1356 .collect()
1357 })
1358 .await
1359 .ok()?
1360 };
1361
1362 if surfaced.is_empty() {
1363 return None;
1364 }
1365
1366 {
1368 let mut sessions = self.sessions.write().await;
1369 if let Some(session) = sessions.get_mut(session_id) {
1370 for mem in &surfaced {
1371 session.mark_injected(&mem.id);
1372 }
1373 session.recent_context_hashes.push_back(context_hash);
1374 if session.recent_context_hashes.len() > 20 {
1375 session.recent_context_hashes.pop_front();
1376 }
1377 session.cleanup_injection_cooldowns();
1378 }
1379 }
1380
1381 Some(ExtractionResult::ContextInjection {
1382 memories: surfaced,
1383 context_hash,
1384 processing_time_ms: start.elapsed().as_millis() as u64,
1385 timestamp: Utc::now(),
1386 })
1387 }
1388
1389 pub async fn get_session_stats(&self, session_id: &str) -> Option<SessionStats> {
1390 let sessions = self.sessions.read().await;
1391 sessions.get(session_id).map(|s| SessionStats {
1392 session_id: s.session_id.clone(),
1393 user_id: s.user_id.clone(),
1394 mode: s.mode,
1395 buffer_size: s.buffer.len(),
1396 total_memories_created: s.total_memories_created,
1397 last_extraction: s.last_extraction,
1398 })
1399 }
1400}
1401
/// Point-in-time counters for one streaming session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStats {
    /// Session id.
    pub session_id: String,
    /// Owning user.
    pub user_id: String,
    /// Session mode.
    pub mode: StreamMode,
    /// Number of entries currently buffered (merged messages count once).
    pub buffer_size: usize,
    /// Memories stored so far.
    pub total_memories_created: usize,
    /// When the buffer was last drained (initially the creation time).
    pub last_extraction: DateTime<Utc>,
}
1412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::NerEntityType;

    #[test]
    fn test_extraction_config_defaults() {
        let config = ExtractionConfig::default();
        assert_eq!(config.min_importance, 0.3);
        assert!(config.auto_dedupe);
        assert_eq!(config.checkpoint_interval_ms, 5000);
        assert_eq!(config.max_buffer_size, 50);

        assert!(config.enable_context_injection);
        assert_eq!(config.injection_min_relevance, 0.70);
        assert_eq!(config.injection_max_memories, 3);
        assert_eq!(config.injection_cooldown_secs, 180);
    }

    #[test]
    fn test_validate_and_clamp_bounds_values() {
        let mut config = ExtractionConfig {
            min_importance: 1.5,
            dedupe_threshold: -0.2,
            checkpoint_interval_ms: 10,
            max_buffer_size: 0,
            trigger_events: vec!["x".to_string(); 150],
            injection_min_relevance: 2.0,
            injection_max_memories: 0,
            injection_cooldown_secs: 10_000,
            ..ExtractionConfig::default()
        };
        config.validate_and_clamp();
        assert_eq!(config.min_importance, 1.0);
        assert_eq!(config.dedupe_threshold, 0.0);
        assert_eq!(config.checkpoint_interval_ms, 100);
        assert_eq!(config.max_buffer_size, 50);
        assert_eq!(config.trigger_events.len(), 100);
        assert_eq!(config.injection_min_relevance, 1.0);
        assert_eq!(config.injection_max_memories, 1);
        assert_eq!(config.injection_cooldown_secs, 3600);

        // 0 means "time-based extraction disabled" and must survive validation.
        let mut disabled = ExtractionConfig {
            checkpoint_interval_ms: 0,
            ..ExtractionConfig::default()
        };
        disabled.validate_and_clamp();
        assert_eq!(disabled.checkpoint_interval_ms, 0);
    }

    #[test]
    fn test_stream_mode_default() {
        let mode = StreamMode::default();
        assert_eq!(mode, StreamMode::Conversation);
    }

    #[test]
    fn test_content_hash_consistency() {
        let h1 = content_hash("Hello World");
        let h2 = content_hash("hello world");
        let h3 = content_hash(" hello world ");

        assert_eq!(h1, h2);
        assert_eq!(h2, h3);
        assert_ne!(content_hash("hello"), content_hash("world"));
    }

    #[test]
    fn test_contains_ignore_ascii_case_lowercase_needle() {
        assert!(contains_ignore_ascii_case("Connection ERROR", "error"));
        assert!(contains_ignore_ascii_case("anything", ""));
        assert!(!contains_ignore_ascii_case("err", "error"));
    }

    #[test]
    fn test_buffer_message_rejects_exact_duplicates() {
        let handshake: StreamHandshake = serde_json::from_str(r#"{"user_id": "u"}"#).unwrap();
        let mut session = StreamSession::new(handshake);
        let msg = |content: &str, source: &str| BufferedMessage {
            content: content.to_string(),
            source: Some(source.to_string()),
            timestamp: Utc::now(),
            importance: None,
            tags: vec![],
            metadata: HashMap::new(),
        };

        assert!(session.buffer_message(msg("Hello", "a")));
        // Same content after trimming/lowercasing, so it is rejected.
        assert!(!session.buffer_message(msg("  hello ", "b")));
        // New content from a different source starts a new entry.
        assert!(session.buffer_message(msg("other", "b")));
        assert_eq!(session.buffer.len(), 2);
    }

    #[test]
    fn test_calculate_importance_conversation() {
        let config = ExtractionConfig::default();

        let short =
            StreamingMemoryExtractor::calculate_importance("ok", StreamMode::Conversation, &config);
        assert!(short < 0.5);

        let question = StreamingMemoryExtractor::calculate_importance(
            "How do I implement streaming in Rust?",
            StreamMode::Conversation,
            &config,
        );
        assert!(question > 0.5);

        let error = StreamingMemoryExtractor::calculate_importance(
            "Error: connection failed to database server unexpectedly while processing request",
            StreamMode::Conversation,
            &config,
        );
        assert!(error > 0.6);
    }

    #[test]
    fn test_determine_experience_type() {
        let msg_error = BufferedMessage {
            content: "test".to_string(),
            source: None,
            timestamp: Utc::now(),
            importance: None,
            tags: vec!["error".to_string()],
            metadata: HashMap::new(),
        };
        assert_eq!(
            StreamingMemoryExtractor::determine_experience_type(
                StreamMode::Conversation,
                &msg_error
            ),
            ExperienceType::Error
        );

        let msg_default = BufferedMessage {
            content: "test".to_string(),
            source: None,
            timestamp: Utc::now(),
            importance: None,
            tags: vec![],
            metadata: HashMap::new(),
        };
        assert_eq!(
            StreamingMemoryExtractor::determine_experience_type(
                StreamMode::Conversation,
                &msg_default
            ),
            ExperienceType::Conversation
        );
        assert_eq!(
            StreamingMemoryExtractor::determine_experience_type(StreamMode::Sensor, &msg_default),
            ExperienceType::Observation
        );
    }

    #[test]
    fn test_stream_handshake_deserialization() {
        let json = r#"{
            "user_id": "test-user",
            "mode": "conversation",
            "extraction_config": {
                "min_importance": 0.5,
                "checkpoint_interval_ms": 10000
            }
        }"#;

        let handshake: StreamHandshake = serde_json::from_str(json).unwrap();
        assert_eq!(handshake.user_id, "test-user");
        assert_eq!(handshake.mode, StreamMode::Conversation);
        assert_eq!(handshake.extraction_config.min_importance, 0.5);
        assert_eq!(handshake.extraction_config.checkpoint_interval_ms, 10000);
        assert!(handshake.extraction_config.auto_dedupe);
    }

    #[test]
    fn test_stream_message_variants() {
        // `matches!` only evaluates to a bool; it must be wrapped in `assert!`
        // for these checks to fail when the variant is wrong.
        let content_json = r#"{
            "type": "content",
            "content": "Hello world",
            "source": "user",
            "tags": ["greeting"]
        }"#;
        let msg: StreamMessage = serde_json::from_str(content_json).unwrap();
        assert!(matches!(msg, StreamMessage::Content { .. }));

        let event_json = r#"{
            "type": "event",
            "event": "error",
            "description": "Database connection failed",
            "severity": "error"
        }"#;
        let msg: StreamMessage = serde_json::from_str(event_json).unwrap();
        assert!(matches!(msg, StreamMessage::Event { .. }));

        let flush_json = r#"{"type": "flush"}"#;
        let msg: StreamMessage = serde_json::from_str(flush_json).unwrap();
        assert!(matches!(msg, StreamMessage::Flush));
    }

    #[test]
    fn test_detected_entity_from_ner() {
        let ner_entity = NerEntity {
            text: "Microsoft".to_string(),
            entity_type: NerEntityType::Organization,
            confidence: 0.95,
            start: 0,
            end: 9,
        };

        let detected = DetectedEntity::from(&ner_entity);
        assert_eq!(detected.text, "Microsoft");
        assert_eq!(detected.entity_type, "ORG");
        assert_eq!(detected.confidence, 0.95);
        assert!(!detected.existing);
    }
}