1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
11use crate::error::MemoryError;
12use crate::sqlite::SqliteStore;
13use crate::token_counter::TokenCounter;
14use crate::types::{ConversationId, MessageId};
15use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
16
// Names of the auxiliary vector-store collections kept alongside the main
// message-embedding collection.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
20
/// JSON shape requested from the LLM by the summarization prompt
/// (see `build_summarization_prompt`); `JsonSchema` enables the typed
/// `chat_typed_erased` call in `summarize`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Prose summary of the conversation window.
    pub summary: String,
    // Discrete facts persisted to the key-facts collection.
    pub key_facts: Vec<String>,
    // Entity names extracted from the window.
    pub entities: Vec<String>,
}
27
/// A message returned by recall, paired with its merged relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    // Weighted hybrid score (vector + keyword), truncated to f32.
    pub score: f32,
}
33
/// A stored conversation summary row, as loaded via `load_summaries`.
#[derive(Debug, Clone)]
pub struct Summary {
    pub id: i64,
    pub conversation_id: ConversationId,
    pub content: String,
    // First and last message ids covered by this summary.
    pub first_message_id: MessageId,
    pub last_message_id: MessageId,
    // Estimated token count of `content` (from `TokenCounter`).
    pub token_estimate: i64,
}
43
/// A hit from `search_session_summaries`: the stored summary text, the
/// vector-similarity score, and the conversation it came from.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    pub summary_text: String,
    pub score: f32,
    pub conversation_id: ConversationId,
}
50
51use crate::math::cosine_similarity;
52
53fn apply_temporal_decay(
54 ranked: &mut [(MessageId, f64)],
55 timestamps: &std::collections::HashMap<MessageId, i64>,
56 half_life_days: u32,
57) {
58 if half_life_days == 0 {
59 return;
60 }
61 let now = std::time::SystemTime::now()
62 .duration_since(std::time::UNIX_EPOCH)
63 .unwrap_or_default()
64 .as_secs()
65 .cast_signed();
66 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
67
68 for (msg_id, score) in ranked.iter_mut() {
69 if let Some(&ts) = timestamps.get(msg_id) {
70 #[allow(clippy::cast_precision_loss)]
71 let age_days = (now - ts).max(0) as f64 / 86400.0;
72 *score *= (-lambda * age_days).exp();
73 }
74 }
75}
76
77fn apply_mmr(
78 ranked: &[(MessageId, f64)],
79 vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
80 lambda: f32,
81 limit: usize,
82) -> Vec<(MessageId, f64)> {
83 if ranked.is_empty() || limit == 0 {
84 return Vec::new();
85 }
86
87 tracing::debug!(
88 candidates = ranked.len(),
89 limit,
90 lambda = %lambda,
91 "mmr: starting re-ranking"
92 );
93
94 let lambda = f64::from(lambda);
95 let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
96 let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();
97
98 while selected.len() < limit && !remaining.is_empty() {
99 let best_idx = if selected.is_empty() {
100 0
102 } else {
103 let mut best = 0usize;
104 let mut best_score = f64::NEG_INFINITY;
105
106 for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
107 let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
108 selected
109 .iter()
110 .filter_map(|(sel_id, _)| vectors.get(sel_id))
111 .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
112 .fold(f64::NEG_INFINITY, f64::max)
113 } else {
114 0.0
115 };
116 let max_sim = if max_sim == f64::NEG_INFINITY {
117 0.0
118 } else {
119 max_sim
120 };
121 let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
122 if mmr_score > best_score {
123 best_score = mmr_score;
124 best = i;
125 }
126 }
127 best
128 };
129
130 selected.push(remaining.remove(best_idx));
131 }
132
133 tracing::debug!(selected = selected.len(), "mmr: re-ranking complete");
134
135 selected
136}
137
138fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
139 let mut prompt = String::from(
140 "Summarize the following conversation. Extract key facts, decisions, entities, \
141 and context needed to continue the conversation.\n\n\
142 Respond in JSON with fields: summary (string), key_facts (list of strings), \
143 entities (list of strings).\n\nConversation:\n",
144 );
145
146 for (_, role, content) in messages {
147 prompt.push_str(role);
148 prompt.push_str(": ");
149 prompt.push_str(content);
150 prompt.push('\n');
151 }
152
153 prompt
154}
155
/// Hybrid conversational memory: durable messages and FTS5 keyword search in
/// SQLite, plus optional vector-store semantic search, summarization, key
/// facts, and graph-based recall.
pub struct SemanticMemory {
    // Durable message/summary storage and keyword search.
    sqlite: SqliteStore,
    // Vector store; `None` disables all semantic/embedding features.
    qdrant: Option<Arc<EmbeddingStore>>,
    // LLM provider used for embeddings, summarization, and chat.
    provider: AnyProvider,
    // Model name recorded alongside stored embeddings.
    embedding_model: String,
    // Weight applied to normalized vector scores when merging results.
    vector_weight: f64,
    // Weight applied to normalized keyword scores when merging results.
    keyword_weight: f64,
    // When true, recall scores decay exponentially with message age.
    temporal_decay_enabled: bool,
    // Half-life in days for temporal decay; 0 disables it.
    temporal_decay_half_life_days: u32,
    // When true, recall results are re-ranked via MMR for diversity.
    mmr_enabled: bool,
    // MMR trade-off: higher favors relevance, lower favors diversity.
    mmr_lambda: f32,
    pub token_counter: Arc<TokenCounter>,
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    // Metric counters; incremented elsewhere, read via the accessors below.
    community_detection_failures: Arc<AtomicU64>,
    graph_extraction_count: Arc<AtomicU64>,
    graph_extraction_failures: Arc<AtomicU64>,
}
173
174impl SemanticMemory {
    /// Opens the SQLite store at `sqlite_path` and connects to Qdrant at
    /// `qdrant_url`, using default merge weights (0.7 vector / 0.3 keyword).
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
193
    /// Like [`Self::new`] but with explicit vector/keyword merge weights,
    /// using the default SQLite pool size of 5.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
221
    /// Core constructor: opens SQLite with `pool_size` connections and
    /// attempts to connect the vector store at `qdrant_url`.
    ///
    /// If Qdrant is unavailable the memory still works — semantic search is
    /// simply disabled (a warning is logged and the store is left `None`).
    /// Temporal decay and MMR default to off; see `with_ranking_options`.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_weights_and_pool_size(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();

        // Degrade gracefully: a missing vector store is not a hard error.
        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
            Ok(store) => Some(Arc::new(store)),
            Err(e) => {
                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
                None
            }
        };

        Ok(Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        })
    }
268
    /// Builds a memory whose vector store wraps a caller-supplied `QdrantOps`
    /// implementation (e.g. for injecting a custom or test backend).
    ///
    /// Unlike `with_weights_and_pool_size`, the vector store is always
    /// present here — no connectivity probe is performed.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_qdrant_ops(
        sqlite_path: &str,
        ops: crate::QdrantOps,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::with_store(Box::new(ops), pool);

        Ok(Self {
            sqlite,
            qdrant: Some(Arc::new(store)),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        })
    }
308
    /// Attaches a graph store, enabling `recall_graph` (builder style).
    #[must_use]
    pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
        self.graph_store = Some(store);
        self
    }
318
    /// Number of community-detection failures recorded so far (relaxed read).
    #[must_use]
    pub fn community_detection_failures(&self) -> u64 {
        self.community_detection_failures.load(Ordering::Relaxed)
    }
324
    /// Number of graph extractions attempted so far (relaxed read).
    #[must_use]
    pub fn graph_extraction_count(&self) -> u64 {
        self.graph_extraction_count.load(Ordering::Relaxed)
    }
330
    /// Number of graph-extraction failures recorded so far (relaxed read).
    #[must_use]
    pub fn graph_extraction_failures(&self) -> u64 {
        self.graph_extraction_failures.load(Ordering::Relaxed)
    }
336
    /// Configures recall re-ranking (builder style): exponential temporal
    /// decay with the given half-life, and MMR diversity re-ranking with the
    /// given lambda (higher lambda favors relevance over diversity).
    #[must_use]
    pub fn with_ranking_options(
        mut self,
        temporal_decay_enabled: bool,
        temporal_decay_half_life_days: u32,
        mmr_enabled: bool,
        mmr_lambda: f32,
    ) -> Self {
        self.temporal_decay_enabled = temporal_decay_enabled;
        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
        self.mmr_enabled = mmr_enabled;
        self.mmr_lambda = mmr_lambda;
        self
    }
352
    /// Assembles a memory from pre-built components (no I/O); ranking extras
    /// (temporal decay, MMR, graph store) default to off.
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<Arc<EmbeddingStore>>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter,
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        }
    }
384
    /// Builds a memory that uses SQLite itself as the vector backend (no
    /// external Qdrant), with the default pool size of 5.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
407
    /// Like [`Self::with_sqlite_backend`] but with an explicit SQLite pool
    /// size. The embedding store shares the same connection pool.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(Arc::new(store)),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        })
    }
443
    /// Persists a message to SQLite and, best-effort, stores its embedding in
    /// the vector store.
    ///
    /// Embedding runs only when a vector store is configured and the provider
    /// supports embeddings; every embedding/storage failure is logged and
    /// swallowed so the durable SQLite write always wins.
    ///
    /// # Errors
    /// Returns an error only if the SQLite insert fails.
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> Result<MessageId, MemoryError> {
        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Collection dimensionality comes from the actual vector;
                    // 896 is an effectively-unreachable fallback (presumably
                    // the provider's default size — TODO confirm).
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok(message_id)
    }
494
    /// Like [`Self::remember`] but also persists the message's serialized
    /// parts JSON, and reports whether the embedding was actually stored.
    ///
    /// Embedding failures are logged and swallowed; the boolean in the return
    /// value is `true` only when the vector store accepted the embedding.
    ///
    /// # Errors
    /// Returns an error only if the SQLite insert fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Dimensionality from the actual vector; 896 fallback is
                    // effectively unreachable.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
549
    /// Persists a message (with its parts JSON) to SQLite without attempting
    /// any embedding — contrast with [`Self::remember_with_parts`].
    ///
    /// # Errors
    /// Returns an error if the SQLite insert fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
568
    /// Hybrid recall: runs FTS5 keyword search and vector-similarity search,
    /// then merges and ranks the results, returning up to `limit` messages.
    ///
    /// Keyword-search failures degrade gracefully (logged, treated as empty);
    /// vector-search/embedding failures propagate as errors.
    ///
    /// # Errors
    /// Returns an error if embedding the query, searching the vector store,
    /// or the final message fetch fails.
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        tracing::debug!(
            query_len = query.len(),
            limit,
            has_filter = filter.is_some(),
            conversation_id = conversation_id.map(|c| c.0),
            has_qdrant = self.qdrant.is_some(),
            "recall: starting hybrid search"
        );

        // Over-fetch (limit * 2) so the merge has candidates to choose from.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
623
    /// Raw FTS5 keyword search; over-fetches `limit * 2` candidates so the
    /// downstream merge has room to re-rank.
    async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
    }
639
640 async fn recall_vectors_raw(
647 &self,
648 query: &str,
649 limit: usize,
650 filter: Option<SearchFilter>,
651 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
652 let Some(qdrant) = &self.qdrant else {
653 return Ok(Vec::new());
654 };
655 if !self.provider.supports_embeddings() {
656 return Ok(Vec::new());
657 }
658 let query_vector = self.provider.embed(query).await?;
659 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
660 qdrant.ensure_collection(vector_size).await?;
661 qdrant.search(&query_vector, limit * 2, filter).await
662 }
663
    /// Merges keyword and vector hits into one ranked list, then applies the
    /// optional temporal-decay and MMR re-ranking stages and hydrates the
    /// surviving message ids from SQLite.
    ///
    /// Scoring: each result set is normalized by its own max score, then
    /// summed per message with `vector_weight` / `keyword_weight`.
    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Normalize by the max vector score so both sources land on a
            // comparable 0..=1-ish scale before weighting.
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            // Same normalization for the keyword side.
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            tracing::debug!("recall: empty merge, no overlapping scores");
            return Ok(Vec::new());
        }

        // Sort descending; NaN-safe comparison falls back to Equal.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        tracing::debug!(
            merged = ranked.len(),
            top_score = ranked.first().map(|r| r.1),
            bottom_score = ranked.last().map(|r| r.1),
            vector_weight = %self.vector_weight,
            keyword_weight = %self.keyword_weight,
            "recall: weighted merge complete"
        );

        // Stage 1 (optional): decay scores by message age, then re-sort.
        // Missing timestamps are non-fatal — decay is simply skipped.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        half_life_days = self.temporal_decay_half_life_days,
                        top_score_after = ranked.first().map(|r| r.1),
                        "recall: temporal decay applied"
                    );
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Stage 2 (optional): MMR diversity re-ranking, which also truncates
        // to `limit`. Every fallback path truncates instead.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        let ranked_len_before = ranked.len();
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                        tracing::debug!(
                            before = ranked_len_before,
                            after = ranked.len(),
                            lambda = %self.mmr_lambda,
                            "recall: mmr re-ranked"
                        );
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate message bodies, preserving rank order; ids that no longer
        // resolve to a message are silently dropped.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled: Vec<RecalledMessage> = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        tracing::debug!(final_count = recalled.len(), "recall: final results");

        Ok(recalled)
    }
801
    /// Recall that lets `router` pick the retrieval strategy per query:
    /// keyword-only, semantic-only, or hybrid. The `Graph` route is currently
    /// served with a hybrid search here (graph expansion lives in
    /// `recall_graph`).
    ///
    /// On the hybrid/graph routes keyword failures are logged and degraded to
    /// empty; vector failures propagate.
    ///
    /// # Errors
    /// Returns an error if the chosen search path or the final merge fails.
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Graph => {
                // Graph route falls back to hybrid retrieval in this path.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
874
    /// Graph-based recall: retrieves facts related to `query` from the
    /// attached graph store, expanding up to `max_hops` hops. Returns an
    /// empty list when no graph store is attached.
    ///
    /// # Errors
    /// Returns an error if the underlying graph retrieval fails.
    pub async fn recall_graph(
        &self,
        query: &str,
        limit: usize,
        max_hops: u32,
    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            max_hops,
            "graph: starting recall"
        );

        let results = crate::graph::retrieval::graph_recall(
            store,
            self.qdrant.as_deref(),
            &self.provider,
            query,
            limit,
            max_hops,
        )
        .await?;

        tracing::debug!(result_count = results.len(), "graph: recall complete");

        Ok(results)
    }
913
914 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
920 match &self.qdrant {
921 Some(qdrant) => qdrant.has_embedding(message_id).await,
922 None => Ok(false),
923 }
924 }
925
    /// Backfills embeddings for up to 1000 messages that have none, returning
    /// how many were embedded. No-op (returns 0) without a vector store or
    /// embedding support.
    ///
    /// # Errors
    /// Returns an error if listing unembedded messages, the dimensionality
    /// probe, or collection creation fails. Per-message embed/store failures
    /// are only logged and skipped.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Embed a tiny probe string just to learn the vector dimensionality
        // before ensuring the collection exists.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
981
    /// Embeds `summary_text` and stores it in the session-summaries
    /// collection, tagged with its conversation id. Silently a no-op when the
    /// vector store is absent or the provider cannot embed.
    ///
    /// # Errors
    /// Returns an error if embedding, collection creation, or the store
    /// operation fails.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Payload keys are read back by `search_session_summaries`.
        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
1020
    /// Semantic search over stored session summaries, optionally excluding a
    /// conversation (typically the current one). Returns an empty list when
    /// the vector store is absent or the provider cannot embed.
    ///
    /// # Errors
    /// Returns an error if embedding, collection creation, or the search
    /// fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            tracing::debug!("session-summaries: skipped, no vector store");
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            tracing::debug!("session-summaries: skipped, no embedding support");
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Translate the exclusion into a must_not filter on conversation_id.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        tracing::debug!(
            results = points.len(),
            limit,
            exclude_conversation_id = exclude_conversation_id.map(|c| c.0),
            "session-summaries: search complete"
        );

        // Points missing either payload field are silently dropped.
        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
1082
    /// Borrow the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
1088
1089 pub async fn is_vector_store_connected(&self) -> bool {
1094 match self.qdrant.as_ref() {
1095 Some(store) => store.health_check().await,
1096 None => false,
1097 }
1098 }
1099
    /// Whether a vector store is configured (does not check connectivity).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
1105
    /// Borrow the embedding store, if one is configured.
    #[must_use]
    pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
        self.qdrant.as_ref()
    }
1111
    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Returns an error if the SQLite query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
1120
    /// Counts messages newer than the last summarized message; when no
    /// summary exists yet, counts from message id 0 (i.e. all messages).
    ///
    /// # Errors
    /// Returns an error if either SQLite query fails.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: ConversationId,
    ) -> Result<i64, MemoryError> {
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
1139
1140 pub async fn load_summaries(
1146 &self,
1147 conversation_id: ConversationId,
1148 ) -> Result<Vec<Summary>, MemoryError> {
1149 let rows = self.sqlite.load_summaries(conversation_id).await?;
1150 let summaries = rows
1151 .into_iter()
1152 .map(
1153 |(
1154 id,
1155 conversation_id,
1156 content,
1157 first_message_id,
1158 last_message_id,
1159 token_estimate,
1160 )| {
1161 Summary {
1162 id,
1163 conversation_id,
1164 content,
1165 first_message_id,
1166 last_message_id,
1167 token_estimate,
1168 }
1169 },
1170 )
1171 .collect();
1172 Ok(summaries)
1173 }
1174
    /// Summarizes the oldest unsummarized window of a conversation once it
    /// has grown beyond `message_count` messages.
    ///
    /// Flow: ask the provider for a `StructuredSummary` (falling back to a
    /// plain-text chat on failure), persist the summary row in SQLite,
    /// best-effort embed it as a `MessageKind::Summary` vector, and store any
    /// extracted key facts. Returns the new summary row id, or `None` when
    /// nothing needed summarizing.
    ///
    /// # Errors
    /// Returns an error if SQLite access, the fallback chat call, or token
    /// counting fails. Embedding/key-fact failures are logged, not returned.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Nothing to do until the conversation exceeds the window size.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last message covered by an existing summary.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer structured output; fall back to a plain chat completion
        // wrapped in an empty StructuredSummary on failure.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // `messages` is non-empty here, so indexing is safe.
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself (stored under the summary's
        // row id with role "system") so it participates in semantic recall.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
1288
1289 async fn store_key_facts(
1290 &self,
1291 conversation_id: ConversationId,
1292 source_summary_id: i64,
1293 key_facts: &[String],
1294 ) {
1295 let Some(qdrant) = &self.qdrant else {
1296 return;
1297 };
1298 if !self.provider.supports_embeddings() {
1299 return;
1300 }
1301
1302 let Some(first_fact) = key_facts.first() else {
1303 return;
1304 };
1305 let first_vector = match self.provider.embed(first_fact).await {
1306 Ok(v) => v,
1307 Err(e) => {
1308 tracing::warn!("Failed to embed key fact: {e:#}");
1309 return;
1310 }
1311 };
1312 let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
1313 if let Err(e) = qdrant
1314 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1315 .await
1316 {
1317 tracing::warn!("Failed to ensure key_facts collection: {e:#}");
1318 return;
1319 }
1320
1321 let first_payload = serde_json::json!({
1322 "conversation_id": conversation_id.0,
1323 "fact_text": first_fact,
1324 "source_summary_id": source_summary_id,
1325 });
1326 if let Err(e) = qdrant
1327 .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
1328 .await
1329 {
1330 tracing::warn!("Failed to store key fact: {e:#}");
1331 }
1332
1333 for fact in &key_facts[1..] {
1334 match self.provider.embed(fact).await {
1335 Ok(vector) => {
1336 let payload = serde_json::json!({
1337 "conversation_id": conversation_id.0,
1338 "fact_text": fact,
1339 "source_summary_id": source_summary_id,
1340 });
1341 if let Err(e) = qdrant
1342 .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1343 .await
1344 {
1345 tracing::warn!("Failed to store key fact: {e:#}");
1346 }
1347 }
1348 Err(e) => {
1349 tracing::warn!("Failed to embed key fact: {e:#}");
1350 }
1351 }
1352 }
1353 }
1354
    /// Semantic search over stored key facts, returning up to `limit` fact
    /// strings. Returns an empty list when the vector store is absent or the
    /// provider cannot embed.
    ///
    /// # Errors
    /// Returns an error if embedding, collection creation, or the search
    /// fails.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            tracing::debug!("key-facts: skipped, no vector store");
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            tracing::debug!("key-facts: skipped, no embedding support");
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        tracing::debug!(results = points.len(), limit, "key-facts: search complete");

        // Points without a string `fact_text` payload field are dropped.
        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }
1393
1394 pub async fn search_document_collection(
1404 &self,
1405 collection: &str,
1406 query: &str,
1407 limit: usize,
1408 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
1409 let Some(qdrant) = &self.qdrant else {
1410 return Ok(Vec::new());
1411 };
1412 if !self.provider.supports_embeddings() {
1413 return Ok(Vec::new());
1414 }
1415 if !qdrant.collection_exists(collection).await? {
1416 return Ok(Vec::new());
1417 }
1418 let vector = self.provider.embed(query).await?;
1419 let results = qdrant
1420 .search_collection(collection, &vector, limit, None)
1421 .await?;
1422
1423 tracing::debug!(
1424 results = results.len(),
1425 limit,
1426 collection,
1427 "document-collection: search complete"
1428 );
1429
1430 Ok(results)
1431 }
1432
1433 pub async fn store_correction_embedding(
1441 &self,
1442 correction_id: i64,
1443 correction_text: &str,
1444 ) -> Result<(), MemoryError> {
1445 let Some(ref store) = self.qdrant else {
1446 return Ok(());
1447 };
1448 if !self.provider.supports_embeddings() {
1449 return Ok(());
1450 }
1451 let embedding = self
1452 .provider
1453 .embed(correction_text)
1454 .await
1455 .map_err(|e| MemoryError::Other(e.to_string()))?;
1456 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
1457 store
1458 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
1459 .await?;
1460 let payload = serde_json::json!({ "correction_id": correction_id });
1461 store
1462 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
1463 .await?;
1464 Ok(())
1465 }
1466
1467 pub async fn retrieve_similar_corrections(
1476 &self,
1477 query: &str,
1478 limit: usize,
1479 min_score: f32,
1480 ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
1481 let Some(ref store) = self.qdrant else {
1482 tracing::debug!("corrections: skipped, no vector store");
1483 return Ok(vec![]);
1484 };
1485 if !self.provider.supports_embeddings() {
1486 tracing::debug!("corrections: skipped, no embedding support");
1487 return Ok(vec![]);
1488 }
1489 let embedding = self
1490 .provider
1491 .embed(query)
1492 .await
1493 .map_err(|e| MemoryError::Other(e.to_string()))?;
1494 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
1495 store
1496 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
1497 .await?;
1498 let scored = store
1499 .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
1500 .await
1501 .unwrap_or_default();
1502
1503 tracing::debug!(
1504 candidates = scored.len(),
1505 min_score = %min_score,
1506 limit,
1507 "corrections: search complete"
1508 );
1509
1510 let mut results = Vec::new();
1511 for point in scored {
1512 if point.score < min_score {
1513 continue;
1514 }
1515 if let Some(id_val) = point.payload.get("correction_id")
1516 && let Some(id) = id_val.as_i64()
1517 {
1518 let rows = self.sqlite.load_corrections_for_id(id).await?;
1519 results.extend(rows);
1520 }
1521 }
1522
1523 tracing::debug!(
1524 retained = results.len(),
1525 "corrections: after min_score filter"
1526 );
1527
1528 Ok(results)
1529 }
1530
    /// Fire-and-forget background extraction of graph entities/edges from
    /// `content`, with `context_messages` as surrounding context.
    ///
    /// Spawns a detached task that runs [`extract_and_store`] under a
    /// timeout of `config.extraction_timeout_secs`, bumping the success or
    /// failure counters accordingly. On success, when
    /// `config.community_refresh_interval > 0` and the persistent extraction
    /// count is a multiple of that interval, a second detached task refreshes
    /// community detection and then runs graph eviction.
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
    ) {
        // Clone everything the detached task needs up front; `self` is not
        // captured by the spawned future.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();

        tokio::spawn(async move {
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            // Three outcomes: completed, failed, or timed out. Only a clean
            // completion counts toward the in-process success counter.
            let extraction_ok = match tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                ),
            )
            .await
            {
                Ok(Ok(stats)) => {
                    tracing::debug!(
                        entities = stats.entities_upserted,
                        edges = stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    true
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    false
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    false
                }
            };

            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                // Note: this re-reads the *persistent* extraction count from
                // the store (shadowing the atomic counter above), so the
                // refresh schedule is shared across processes.
                let store = GraphStore::new(pool.clone());
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    // Second detached task: detection and eviction run
                    // sequentially; a detection failure does not skip eviction.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }
                    });
                }
            }
        });
    }
1641}
1642
/// Tuning knobs for background graph extraction and community maintenance.
#[derive(Debug, Clone, Default)]
pub struct GraphExtractionConfig {
    /// Entity budget handed to the extractor for one run.
    pub max_entities: usize,
    /// Edge budget handed to the extractor for one run.
    pub max_edges: usize,
    /// Hard timeout for a single extraction, in seconds.
    pub extraction_timeout_secs: u64,
    /// Run community detection every N successful extractions; 0 disables it.
    pub community_refresh_interval: usize,
    /// How long expired edges are retained before eviction, in days.
    pub expired_edge_retention_days: u32,
    /// Entity cap enforced during graph eviction.
    pub max_entities_cap: usize,
    /// Byte budget for the community-summary prompt.
    pub community_summary_max_prompt_bytes: usize,
    /// How many community summaries may be generated concurrently.
    pub community_summary_concurrency: usize,
    /// Edge chunk size passed to community detection (LPA — label propagation,
    /// presumably; confirm against `crate::graph::community`).
    pub lpa_edge_chunk_size: usize,
}
1659
/// Counters returned by [`extract_and_store`] describing one extraction run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved and upserted into the graph store.
    pub entities_upserted: usize,
    /// Edges inserted between resolved entities.
    pub edges_inserted: usize,
}
1666
1667pub async fn extract_and_store(
1675 content: String,
1676 context_messages: Vec<String>,
1677 provider: AnyProvider,
1678 pool: sqlx::SqlitePool,
1679 config: GraphExtractionConfig,
1680) -> Result<ExtractionStats, MemoryError> {
1681 use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
1682
1683 let extractor = GraphExtractor::new(provider, config.max_entities, config.max_edges);
1684 let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
1685
1686 let store = GraphStore::new(pool);
1687
1688 let pool = store.pool();
1691 sqlx::query(
1692 "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
1693 ON CONFLICT(key) DO NOTHING",
1694 )
1695 .execute(pool)
1696 .await?;
1697 sqlx::query(
1698 "UPDATE graph_metadata
1699 SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
1700 WHERE key = 'extraction_count'",
1701 )
1702 .execute(pool)
1703 .await?;
1704
1705 let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
1706 return Ok(ExtractionStats::default());
1707 };
1708
1709 let resolver = EntityResolver::new(&store);
1710
1711 let mut entities_upserted = 0usize;
1712 let mut entity_ids: std::collections::HashMap<String, i64> = std::collections::HashMap::new();
1713
1714 for entity in &result.entities {
1715 match resolver
1716 .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
1717 .await
1718 {
1719 Ok((id, _outcome)) => {
1720 entity_ids.insert(entity.name.clone(), id);
1721 entities_upserted += 1;
1722 }
1723 Err(e) => {
1724 tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
1725 }
1726 }
1727 }
1728
1729 let mut edges_inserted = 0usize;
1730 for edge in &result.edges {
1731 let (Some(&src_id), Some(&tgt_id)) =
1732 (entity_ids.get(&edge.source), entity_ids.get(&edge.target))
1733 else {
1734 tracing::debug!(
1735 "graph: skipping edge {:?}->{:?}: entity not resolved",
1736 edge.source,
1737 edge.target
1738 );
1739 continue;
1740 };
1741 match resolver
1742 .resolve_edge(src_id, tgt_id, &edge.relation, &edge.fact, 0.8, None)
1743 .await
1744 {
1745 Ok(Some(_)) => edges_inserted += 1,
1746 Ok(None) => {} Err(e) => {
1748 tracing::debug!("graph: skipping edge: {e:#}");
1749 }
1750 }
1751 }
1752
1753 Ok(ExtractionStats {
1754 entities_upserted,
1755 edges_inserted,
1756 })
1757}
1758
1759#[cfg(test)]
1760mod tests {
1761 use zeph_llm::mock::MockProvider;
1762 use zeph_llm::provider::Role;
1763
1764 use super::*;
1765
    // Wraps the default mock LLM provider used by most tests below.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }
1769
1770 async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
1771 let provider = test_provider();
1772 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1773
1774 SemanticMemory {
1775 sqlite,
1776 qdrant: None,
1777 provider,
1778 embedding_model: "test-model".into(),
1779 vector_weight: 0.7,
1780 keyword_weight: 0.3,
1781 temporal_decay_enabled: false,
1782 temporal_decay_half_life_days: 30,
1783 mmr_enabled: false,
1784 mmr_lambda: 0.7,
1785 token_counter: Arc::new(TokenCounter::new()),
1786 graph_store: None,
1787 community_detection_failures: Arc::new(AtomicU64::new(0)),
1788 graph_extraction_count: Arc::new(AtomicU64::new(0)),
1789 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
1790 }
1791 }
1792
    // Construction must not eagerly dial the (unreachable) qdrant endpoint.
    #[tokio::test]
    async fn with_qdrant_ops_constructs_successfully() {
        let ops = crate::QdrantOps::new("http://127.0.0.1:1").unwrap();
        let provider = test_provider();
        let result =
            SemanticMemory::with_qdrant_ops(":memory:", ops, provider, "test-model", 0.7, 0.3, 1)
                .await;
        assert!(
            result.is_ok(),
            "with_qdrant_ops must succeed (lazy TCP connect)"
        );
    }
1805
    // `remember` must persist the message to SQLite with the mapped role.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }
1820
    // Structured message parts (JSON) round-trip alongside the plain content.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }
1838
    // Without a vector store, recall degrades to an empty result, not an error.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }
1846
    // No vector store means no embedding can exist for any message id.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }
1854
    // Backfilling embeddings is a no-op (0 embedded) without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1862
    // The `sqlite()` accessor exposes the live store, not a copy.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }
1879
    // `has_vector_store` reflects the absence of a configured qdrant backend.
    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }
1885
    // Connectivity check must be false when no backend is configured at all.
    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }
1891
    // A provider without embedding support makes recall return empty.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }
1899
    // Even with pending messages, backfill embeds nothing without support.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1914
    // A fresh conversation starts with zero messages.
    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }
1923
    // The count tracks every remembered message, regardless of role.
    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }
1935
    // Summarizing consumes messages from the unsummarized tally without
    // deleting them from the conversation itself.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }
1954
    // A conversation with no summaries yields an empty list, not an error.
    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }
1963
    // Summaries come back in insertion order with their saved content intact.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }
1991
    // Fewer unsummarized messages than the threshold -> no summary produced.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }
2002
    // Crossing the threshold persists a non-empty summary retrievable by id.
    #[tokio::test]
    async fn summarize_stores_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 3).await.unwrap();
        assert!(summary_id.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id.unwrap());
        assert!(!summaries[0].content.is_empty());
    }
2023
    // A second summarize pass must start after the range the first one covered.
    #[tokio::test]
    async fn summarize_respects_previous_summaries() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let s1 = memory.summarize(cid, 3).await.unwrap();
        assert!(s1.is_some());

        let s2 = memory.summarize(cid, 3).await.unwrap();
        assert!(s2.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
    }
2046
    // Message ids are strictly increasing within a conversation.
    #[tokio::test]
    async fn remember_multiple_messages_increments_ids() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let id1 = memory.remember(cid, "user", "first").await.unwrap();
        let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
        let id3 = memory.remember(cid, "user", "third").await.unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);
    }
2059
    // Counts are scoped per conversation, not global.
    #[tokio::test]
    async fn message_count_across_conversations() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid1, "user", "msg1").await.unwrap();
        memory.remember(cid1, "user", "msg2").await.unwrap();
        memory.remember(cid2, "user", "msg3").await.unwrap();

        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
    }
2073
    // The threshold is exclusive: exactly `threshold` messages do not trigger
    // a summary.
    #[tokio::test]
    async fn summarize_exact_threshold_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..3 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_none());
    }
2089
    // One message past the threshold is enough to produce a summary.
    #[tokio::test]
    async fn summarize_one_above_threshold_produces_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..4 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_some());
    }
2105
    // Every field of a stored summary carries a sensible, consistent value.
    #[tokio::test]
    async fn summary_fields_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let s = &summaries[0];

        assert_eq!(s.conversation_id, cid);
        assert!(s.first_message_id > MessageId(0));
        assert!(s.last_message_id >= s.first_message_id);
        assert!(s.token_estimate >= 0);
        assert!(!s.content.is_empty());
    }
2128
    // The prompt embeds each message as "role: content" and asks for
    // structured output (the "key_facts" field).
    #[test]
    fn build_summarization_prompt_format() {
        let messages = vec![
            (MessageId(1), "user".into(), "Hello".into()),
            (MessageId(2), "assistant".into(), "Hi there".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("user: Hello"));
        assert!(prompt.contains("assistant: Hi there"));
        assert!(prompt.contains("key_facts"));
    }
2140
    // Even an empty transcript yields a prompt with the structured-output spec.
    #[test]
    fn build_summarization_prompt_empty() {
        let messages: Vec<(MessageId, String, String)> = vec![];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("key_facts"));
    }
2147
    // StructuredSummary round-trips from the JSON shape the LLM is asked for.
    #[test]
    fn structured_summary_deserialize() {
        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert_eq!(ss.summary, "s");
        assert_eq!(ss.key_facts.len(), 2);
        assert_eq!(ss.entities.len(), 1);
    }
2156
    // Empty fact/entity arrays deserialize cleanly (no required-minimum).
    #[test]
    fn structured_summary_empty_facts() {
        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert!(ss.key_facts.is_empty());
        assert!(ss.entities.is_empty());
    }
2164
    // Key-fact search degrades to empty without a vector store.
    #[tokio::test]
    async fn search_key_facts_no_qdrant_empty() {
        let memory = test_semantic_memory(false).await;
        let facts = memory.search_key_facts("query", 5).await.unwrap();
        assert!(facts.is_empty());
    }
2171
    // RecalledMessage's Debug output exposes the type name and score.
    #[test]
    fn recalled_message_debug() {
        let recalled = RecalledMessage {
            message: Message {
                role: Role::User,
                content: "test".into(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 0.95,
        };
        let dbg = format!("{recalled:?}");
        assert!(dbg.contains("RecalledMessage"));
        assert!(dbg.contains("0.95"));
    }
2187
    // Summary's Clone impl produces a field-equal copy.
    #[test]
    fn summary_clone() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test summary".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let cloned = summary.clone();
        assert_eq!(summary.id, cloned.id);
        assert_eq!(summary.content, cloned.content);
    }
2202
    // String roles map onto the Role enum on the way back out of storage.
    #[tokio::test]
    async fn remember_preserves_role_mapping() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "u").await.unwrap();
        memory.remember(cid, "assistant", "a").await.unwrap();
        memory.remember(cid, "system", "s").await.unwrap();

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[2].role, Role::System);
    }
2218
    // An unreachable qdrant URL must not fail construction (lazy connect).
    #[tokio::test]
    async fn new_with_invalid_qdrant_url_graceful() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);
        let result =
            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
        assert!(result.is_ok());
    }
2228
    // End-to-end roundtrip using the SQLite-backed embedding store: remember
    // embeds + stores, recall retrieves at least one hit for a related query.
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(Arc::new(
            crate::embedding_store::EmbeddingStore::new_sqlite(pool),
        ));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        assert!(id1 < id2);

        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }
2287
    // Embedding support without a vector store must still persist to SQLite.
    #[tokio::test]
    async fn remember_with_embeddings_supported_but_no_qdrant() {
        let memory = test_semantic_memory(true).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hello embed");
    }
2300
    // History preserves message content and chronological order.
    #[tokio::test]
    async fn remember_verifies_content_via_load_history() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "alpha").await.unwrap();
        memory.remember(cid, "assistant", "beta").await.unwrap();
        memory.remember(cid, "user", "gamma").await.unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "alpha");
        assert_eq!(history[1].content, "beta");
        assert_eq!(history[2].content, "gamma");
    }
2316
    // Counts stay isolated across three conversations, including an empty one.
    #[tokio::test]
    async fn message_count_multiple_conversations_isolated() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();
        let cid3 = memory.sqlite().create_conversation().await.unwrap();

        for _ in 0..5 {
            memory.remember(cid1, "user", "msg").await.unwrap();
        }
        for _ in 0..3 {
            memory.remember(cid2, "user", "msg").await.unwrap();
        }

        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
    }
2335
    // Back-to-back summarize calls on 6 messages produce two summaries
    // (3 unsummarized remain above the threshold after the first pass).
    #[tokio::test]
    async fn summarize_empty_messages_range_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..6 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        memory.summarize(cid, 3).await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }
2354
    // A produced summary carries a positive token estimate.
    #[tokio::test]
    async fn summarize_token_estimate_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let token_est = summaries[0].token_estimate;
        assert!(token_est > 0);
    }
2372
    // Summarization surfaces provider chat failures (unreachable Ollama
    // endpoint) as an error instead of swallowing them.
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }
2410
    // Backfill embeds nothing when the provider lacks embedding support.
    #[tokio::test]
    async fn embed_missing_without_embedding_support_returns_zero() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
2424
    // A remembered message has no embedding when no vector store exists.
    #[tokio::test]
    async fn has_embedding_returns_false_when_no_qdrant() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "test").await.unwrap();
        assert!(!memory.has_embedding(msg_id).await.unwrap());
    }
2432
    // Supplying a SearchFilter does not change the no-vector-store shortcut.
    #[tokio::test]
    async fn recall_empty_without_qdrant_regardless_of_filter() {
        let memory = test_semantic_memory(true).await;
        let filter = SearchFilter {
            conversation_id: Some(ConversationId(1)),
            role: None,
        };
        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
        assert!(recalled.is_empty());
    }
2443
    // The stored summary's message-id range is well-formed (first <= last).
    #[tokio::test]
    async fn summarize_message_range_bounds() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..8 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id);
        assert!(summaries[0].first_message_id >= MessageId(1));
        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
    }
2463
2464 #[test]
2465 fn build_summarization_prompt_preserves_order() {
2466 let messages = vec![
2467 (MessageId(1), "user".into(), "first".into()),
2468 (MessageId(2), "assistant".into(), "second".into()),
2469 (MessageId(3), "user".into(), "third".into()),
2470 ];
2471 let prompt = build_summarization_prompt(&messages);
2472 let first_pos = prompt.find("user: first").unwrap();
2473 let second_pos = prompt.find("assistant: second").unwrap();
2474 let third_pos = prompt.find("user: third").unwrap();
2475 assert!(first_pos < second_pos);
2476 assert!(second_pos < third_pos);
2477 }
2478
2479 #[test]
2480 fn summary_debug() {
2481 let summary = Summary {
2482 id: 1,
2483 conversation_id: ConversationId(2),
2484 content: "test".into(),
2485 first_message_id: MessageId(1),
2486 last_message_id: MessageId(5),
2487 token_estimate: 10,
2488 };
2489 let dbg = format!("{summary:?}");
2490 assert!(dbg.contains("Summary"));
2491 }
2492
2493 #[tokio::test]
2494 async fn message_count_nonexistent_conversation() {
2495 let memory = test_semantic_memory(false).await;
2496 let count = memory.message_count(ConversationId(999)).await.unwrap();
2497 assert_eq!(count, 0);
2498 }
2499
2500 #[tokio::test]
2501 async fn load_summaries_nonexistent_conversation() {
2502 let memory = test_semantic_memory(false).await;
2503 let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
2504 assert!(summaries.is_empty());
2505 }
2506
2507 #[tokio::test]
2508 async fn store_session_summary_no_qdrant_noop() {
2509 let memory = test_semantic_memory(true).await;
2510 let result = memory
2511 .store_session_summary(ConversationId(1), "test summary")
2512 .await;
2513 assert!(result.is_ok());
2514 }
2515
2516 #[tokio::test]
2517 async fn store_session_summary_no_embeddings_noop() {
2518 let memory = test_semantic_memory(false).await;
2519 let result = memory
2520 .store_session_summary(ConversationId(1), "test summary")
2521 .await;
2522 assert!(result.is_ok());
2523 }
2524
2525 #[tokio::test]
2526 async fn search_session_summaries_no_qdrant_empty() {
2527 let memory = test_semantic_memory(true).await;
2528 let results = memory
2529 .search_session_summaries("query", 5, None)
2530 .await
2531 .unwrap();
2532 assert!(results.is_empty());
2533 }
2534
2535 #[tokio::test]
2536 async fn search_session_summaries_no_embeddings_empty() {
2537 let memory = test_semantic_memory(false).await;
2538 let results = memory
2539 .search_session_summaries("query", 5, Some(ConversationId(1)))
2540 .await
2541 .unwrap();
2542 assert!(results.is_empty());
2543 }
2544
2545 #[tokio::test]
2546 async fn store_correction_embedding_no_qdrant_noop() {
2547 let memory = test_semantic_memory(true).await;
2548 let result = memory.store_correction_embedding(1, "bad response").await;
2549 assert!(result.is_ok());
2550 }
2551
2552 #[tokio::test]
2553 async fn store_correction_embedding_no_embeddings_noop() {
2554 let memory = test_semantic_memory(false).await;
2555 let result = memory.store_correction_embedding(1, "bad response").await;
2556 assert!(result.is_ok());
2557 }
2558
2559 #[tokio::test]
2560 async fn retrieve_similar_corrections_no_qdrant_empty() {
2561 let memory = test_semantic_memory(true).await;
2562 let results = memory
2563 .retrieve_similar_corrections("query", 5, 0.0)
2564 .await
2565 .unwrap();
2566 assert!(results.is_empty());
2567 }
2568
2569 #[tokio::test]
2570 async fn retrieve_similar_corrections_no_embeddings_empty() {
2571 let memory = test_semantic_memory(false).await;
2572 let results = memory
2573 .retrieve_similar_corrections("query", 5, 0.0)
2574 .await
2575 .unwrap();
2576 assert!(results.is_empty());
2577 }
2578
2579 #[tokio::test]
2580 async fn store_correction_embedding_sqlite_clean_db_roundtrip() {
2581 let mut mock = MockProvider::default();
2582 mock.supports_embeddings = true;
2583 let provider = AnyProvider::Mock(mock);
2584
2585 let sqlite = SqliteStore::new(":memory:").await.unwrap();
2586 let pool = sqlite.pool().clone();
2587 let qdrant = Some(Arc::new(
2588 crate::embedding_store::EmbeddingStore::new_sqlite(pool),
2589 ));
2590
2591 let memory = SemanticMemory {
2592 sqlite,
2593 qdrant,
2594 provider,
2595 embedding_model: "test-model".into(),
2596 vector_weight: 0.7,
2597 keyword_weight: 0.3,
2598 temporal_decay_enabled: false,
2599 temporal_decay_half_life_days: 30,
2600 mmr_enabled: false,
2601 mmr_lambda: 0.7,
2602 token_counter: Arc::new(TokenCounter::new()),
2603 graph_store: None,
2604 community_detection_failures: Arc::new(AtomicU64::new(0)),
2605 graph_extraction_count: Arc::new(AtomicU64::new(0)),
2606 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
2607 };
2608
2609 memory
2612 .store_correction_embedding(1, "bad response")
2613 .await
2614 .unwrap();
2615
2616 let results = memory
2618 .retrieve_similar_corrections("bad", 5, 0.0)
2619 .await
2620 .unwrap();
2621 assert!(results.is_empty());
2622 }
2623
2624 #[test]
2625 fn session_summary_result_debug() {
2626 let result = SessionSummaryResult {
2627 summary_text: "test".into(),
2628 score: 0.9,
2629 conversation_id: ConversationId(1),
2630 };
2631 let dbg = format!("{result:?}");
2632 assert!(dbg.contains("SessionSummaryResult"));
2633 }
2634
2635 #[test]
2636 fn session_summary_result_clone() {
2637 let result = SessionSummaryResult {
2638 summary_text: "test".into(),
2639 score: 0.9,
2640 conversation_id: ConversationId(1),
2641 };
2642 let cloned = result.clone();
2643 assert_eq!(result.summary_text, cloned.summary_text);
2644 assert_eq!(result.conversation_id, cloned.conversation_id);
2645 }
2646
2647 #[tokio::test]
2648 async fn recall_fts5_fallback_without_qdrant() {
2649 let memory = test_semantic_memory(false).await;
2650 let cid = memory.sqlite.create_conversation().await.unwrap();
2651
2652 memory
2653 .remember(cid, "user", "rust programming guide")
2654 .await
2655 .unwrap();
2656 memory
2657 .remember(cid, "assistant", "python tutorial")
2658 .await
2659 .unwrap();
2660 memory
2661 .remember(cid, "user", "advanced rust patterns")
2662 .await
2663 .unwrap();
2664
2665 let recalled = memory.recall("rust", 5, None).await.unwrap();
2666 assert_eq!(recalled.len(), 2);
2667 assert!(recalled[0].score >= recalled[1].score);
2668 }
2669
2670 #[tokio::test]
2671 async fn recall_fts5_fallback_with_filter() {
2672 let memory = test_semantic_memory(false).await;
2673 let cid1 = memory.sqlite.create_conversation().await.unwrap();
2674 let cid2 = memory.sqlite.create_conversation().await.unwrap();
2675
2676 memory.remember(cid1, "user", "hello world").await.unwrap();
2677 memory
2678 .remember(cid2, "user", "hello universe")
2679 .await
2680 .unwrap();
2681
2682 let filter = SearchFilter {
2683 conversation_id: Some(cid1),
2684 role: None,
2685 };
2686 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
2687 assert_eq!(recalled.len(), 1);
2688 }
2689
2690 #[tokio::test]
2691 async fn recall_fts5_no_matches_returns_empty() {
2692 let memory = test_semantic_memory(false).await;
2693 let cid = memory.sqlite.create_conversation().await.unwrap();
2694
2695 memory.remember(cid, "user", "hello world").await.unwrap();
2696
2697 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
2698 assert!(recalled.is_empty());
2699 }
2700
2701 #[tokio::test]
2702 async fn recall_fts5_respects_limit() {
2703 let memory = test_semantic_memory(false).await;
2704 let cid = memory.sqlite.create_conversation().await.unwrap();
2705
2706 for i in 0..10 {
2707 memory
2708 .remember(cid, "user", &format!("test message number {i}"))
2709 .await
2710 .unwrap();
2711 }
2712
2713 let recalled = memory.recall("test", 3, None).await.unwrap();
2714 assert_eq!(recalled.len(), 3);
2715 }
2716
2717 #[tokio::test]
2720 async fn summarize_fallback_to_plain_text_when_structured_fails() {
2721 let sqlite = SqliteStore::new(":memory:").await.unwrap();
2729 let mut mock = MockProvider::default();
2730 mock.default_response = "plain text summary".into();
2732 let provider = AnyProvider::Mock(mock);
2733
2734 let memory = SemanticMemory {
2735 sqlite,
2736 qdrant: None,
2737 provider,
2738 embedding_model: "test".into(),
2739 vector_weight: 0.7,
2740 keyword_weight: 0.3,
2741 temporal_decay_enabled: false,
2742 temporal_decay_half_life_days: 30,
2743 mmr_enabled: false,
2744 mmr_lambda: 0.7,
2745 token_counter: Arc::new(TokenCounter::new()),
2746 graph_store: None,
2747 community_detection_failures: Arc::new(AtomicU64::new(0)),
2748 graph_extraction_count: Arc::new(AtomicU64::new(0)),
2749 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
2750 };
2751
2752 let cid = memory.sqlite().create_conversation().await.unwrap();
2753 for i in 0..5 {
2754 memory
2755 .remember(cid, "user", &format!("msg {i}"))
2756 .await
2757 .unwrap();
2758 }
2759
2760 let result = memory.summarize(cid, 3).await;
2761 assert!(result.is_ok());
2767 let summaries = memory.load_summaries(cid).await.unwrap();
2768 assert_eq!(summaries.len(), 1);
2769 assert!(!summaries[0].content.is_empty());
2770 }
2771
2772 #[test]
2775 fn temporal_decay_disabled_leaves_scores_unchanged() {
2776 let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2777 let timestamps = std::collections::HashMap::new();
2778 apply_temporal_decay(&mut ranked, ×tamps, 30);
2779 assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
2780 assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
2781 }
2782
2783 #[test]
2784 fn temporal_decay_zero_age_preserves_score() {
2785 let now = std::time::SystemTime::now()
2786 .duration_since(std::time::UNIX_EPOCH)
2787 .unwrap_or_default()
2788 .as_secs()
2789 .cast_signed();
2790 let mut ranked = vec![(MessageId(1), 1.0f64)];
2791 let mut timestamps = std::collections::HashMap::new();
2792 timestamps.insert(MessageId(1), now);
2793 apply_temporal_decay(&mut ranked, ×tamps, 30);
2794 assert!((ranked[0].1 - 1.0).abs() < 0.01);
2796 }
2797
2798 #[test]
2799 fn temporal_decay_half_life_halves_score() {
2800 let half_life = 30u32;
2802 let age_secs = i64::from(half_life) * 86400;
2803 let now = std::time::SystemTime::now()
2804 .duration_since(std::time::UNIX_EPOCH)
2805 .unwrap_or_default()
2806 .as_secs()
2807 .cast_signed();
2808 let ts = now - age_secs;
2809 let mut ranked = vec![(MessageId(1), 1.0f64)];
2810 let mut timestamps = std::collections::HashMap::new();
2811 timestamps.insert(MessageId(1), ts);
2812 apply_temporal_decay(&mut ranked, ×tamps, half_life);
2813 assert!(
2815 (ranked[0].1 - 0.5).abs() < 0.01,
2816 "score was {}",
2817 ranked[0].1
2818 );
2819 }
2820
2821 #[test]
2824 fn mmr_empty_input_returns_empty() {
2825 let ranked = vec![];
2826 let vectors = std::collections::HashMap::new();
2827 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2828 assert!(result.is_empty());
2829 }
2830
2831 #[test]
2832 fn mmr_returns_up_to_limit() {
2833 let ranked = vec![
2834 (MessageId(1), 1.0f64),
2835 (MessageId(2), 0.9f64),
2836 (MessageId(3), 0.8f64),
2837 ];
2838 let mut vectors = std::collections::HashMap::new();
2839 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2840 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2841 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2842 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2843 assert_eq!(result.len(), 2);
2844 }
2845
2846 #[test]
2847 fn mmr_without_vectors_picks_by_relevance() {
2848 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2849 let vectors = std::collections::HashMap::new();
2850 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2851 assert_eq!(result.len(), 2);
2852 assert_eq!(result[0].0, MessageId(1));
2853 }
2854
2855 #[test]
2856 fn mmr_prefers_diverse_over_redundant() {
2857 let ranked = vec![
2859 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2863 let mut vectors = std::collections::HashMap::new();
2864 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2865 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2868 assert_eq!(result.len(), 2);
2869 assert_eq!(result[0].0, MessageId(1));
2870 assert_eq!(result[1].0, MessageId(2));
2872 }
2873
2874 #[test]
2875 fn temporal_decay_half_life_zero_is_noop() {
2876 let now = std::time::SystemTime::now()
2877 .duration_since(std::time::UNIX_EPOCH)
2878 .unwrap_or_default()
2879 .as_secs()
2880 .cast_signed();
2881 let age_secs = 30i64 * 86400;
2882 let ts = now - age_secs;
2883 let mut ranked = vec![(MessageId(1), 1.0f64)];
2884 let mut timestamps = std::collections::HashMap::new();
2885 timestamps.insert(MessageId(1), ts);
2886 apply_temporal_decay(&mut ranked, ×tamps, 0);
2888 assert!(
2889 (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2890 "score was {}",
2891 ranked[0].1
2892 );
2893 }
2894
2895 #[test]
2896 fn temporal_decay_huge_age_near_zero() {
2897 let now = std::time::SystemTime::now()
2898 .duration_since(std::time::UNIX_EPOCH)
2899 .unwrap_or_default()
2900 .as_secs()
2901 .cast_signed();
2902 let age_secs = 3650i64 * 86400;
2904 let ts = now - age_secs;
2905 let mut ranked = vec![(MessageId(1), 1.0f64)];
2906 let mut timestamps = std::collections::HashMap::new();
2907 timestamps.insert(MessageId(1), ts);
2908 apply_temporal_decay(&mut ranked, ×tamps, 30);
2909 assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2911 }
2912
2913 #[test]
2914 fn temporal_decay_small_half_life() {
2915 let now = std::time::SystemTime::now()
2917 .duration_since(std::time::UNIX_EPOCH)
2918 .unwrap_or_default()
2919 .as_secs()
2920 .cast_signed();
2921 let ts = now - 7 * 86400i64;
2922 let mut ranked = vec![(MessageId(1), 1.0f64)];
2923 let mut timestamps = std::collections::HashMap::new();
2924 timestamps.insert(MessageId(1), ts);
2925 apply_temporal_decay(&mut ranked, ×tamps, 1);
2926 assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2927 }
2928
2929 #[test]
2930 fn mmr_lambda_zero_max_diversity() {
2931 let ranked = vec![
2933 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2937 let mut vectors = std::collections::HashMap::new();
2938 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2939 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2942 assert_eq!(result.len(), 3);
2943 assert_eq!(result[1].0, MessageId(2));
2945 }
2946
2947 #[test]
2948 fn mmr_lambda_one_pure_relevance() {
2949 let ranked = vec![
2951 (MessageId(1), 1.0f64),
2952 (MessageId(2), 0.8f64),
2953 (MessageId(3), 0.6f64),
2954 ];
2955 let mut vectors = std::collections::HashMap::new();
2956 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2957 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2958 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2959 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2960 assert_eq!(result.len(), 3);
2961 assert_eq!(result[0].0, MessageId(1));
2962 assert_eq!(result[1].0, MessageId(2));
2963 assert_eq!(result[2].0, MessageId(3));
2964 }
2965
2966 #[test]
2967 fn mmr_limit_zero_returns_empty() {
2968 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2969 let mut vectors = std::collections::HashMap::new();
2970 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2971 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2972 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2973 assert!(result.is_empty());
2974 }
2975
2976 #[test]
2977 fn mmr_duplicate_vectors_penalizes_second() {
2978 let ranked = vec![
2980 (MessageId(1), 1.0f64),
2981 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2984 let mut vectors = std::collections::HashMap::new();
2985 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2986 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2989 assert_eq!(result.len(), 3);
2990 assert_eq!(result[0].0, MessageId(1));
2991 assert_eq!(result[1].0, MessageId(3));
2993 }
2994
2995 #[tokio::test]
2998 async fn recall_routed_keyword_route_returns_fts5_results() {
2999 use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};
3000
3001 let memory = test_semantic_memory(false).await;
3002 let cid = memory.sqlite.create_conversation().await.unwrap();
3003
3004 memory
3005 .remember(cid, "user", "rust programming guide")
3006 .await
3007 .unwrap();
3008 memory
3009 .remember(cid, "assistant", "python tutorial")
3010 .await
3011 .unwrap();
3012
3013 let router = HeuristicRouter;
3015 assert_eq!(router.route("rust_guide"), MemoryRoute::Keyword);
3016
3017 let recalled = memory
3018 .recall_routed("rust_guide", 5, None, &router)
3019 .await
3020 .unwrap();
3021 assert!(recalled.len() <= 2);
3023 }
3024
3025 #[tokio::test]
3026 async fn recall_routed_semantic_route_without_qdrant_returns_empty_vectors() {
3027 use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};
3028
3029 let memory = test_semantic_memory(false).await;
3030 let cid = memory.sqlite.create_conversation().await.unwrap();
3031
3032 memory
3033 .remember(cid, "user", "how does the agent loop work")
3034 .await
3035 .unwrap();
3036
3037 let router = HeuristicRouter;
3039 assert_eq!(
3040 router.route("how does the agent loop work"),
3041 MemoryRoute::Semantic
3042 );
3043
3044 let recalled = memory
3046 .recall_routed("how does the agent loop work", 5, None, &router)
3047 .await
3048 .unwrap();
3049 assert!(recalled.is_empty(), "no Qdrant → empty semantic recall");
3050 }
3051
3052 #[tokio::test]
3053 async fn recall_routed_hybrid_route_falls_back_to_fts5_on_no_qdrant() {
3054 use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};
3055
3056 let memory = test_semantic_memory(false).await;
3057 let cid = memory.sqlite.create_conversation().await.unwrap();
3058
3059 memory
3060 .remember(cid, "user", "context window token budget")
3061 .await
3062 .unwrap();
3063
3064 let router = HeuristicRouter;
3066 assert_eq!(
3067 router.route("context window token budget"),
3068 MemoryRoute::Hybrid
3069 );
3070
3071 let recalled = memory
3073 .recall_routed("context window token budget", 5, None, &router)
3074 .await
3075 .unwrap();
3076 assert!(!recalled.is_empty(), "FTS5 should find the stored message");
3078 }
3079
    /// Tests for graph-backed recall (`recall_graph`) and background entity
    /// extraction (`extract_and_store` / `spawn_graph_extraction`).
    mod graph_extraction_tests {
        use super::*;
        use crate::graph::{EntityType, GraphStore};

        /// Builds a `SemanticMemory` without embedding support, wired to a
        /// fresh `GraphStore` that shares the same sqlite pool.
        async fn graph_memory() -> SemanticMemory {
            let mem = test_semantic_memory(false).await;
            let store = std::sync::Arc::new(GraphStore::new(mem.sqlite.pool().clone()));
            mem.with_graph_store(store)
        }

        /// An empty graph must produce an empty fact list, not an error.
        #[tokio::test]
        async fn recall_graph_returns_empty_when_no_entities() {
            let memory = graph_memory().await;
            let facts = memory.recall_graph("rust", 10, 2).await.unwrap();
            assert!(facts.is_empty(), "empty graph must return empty vec");
        }

        /// An entity with one outgoing edge is recalled as at least one fact
        /// carrying the source entity name and the edge relation.
        #[tokio::test]
        async fn recall_graph_returns_facts_for_known_entity() {
            let memory = graph_memory().await;
            // A second GraphStore handle over the same pool sees the same data.
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let rust_id = store
                .upsert_entity("rust", "rust", EntityType::Language, Some("a language"))
                .await
                .unwrap();
            let tokio_id = store
                .upsert_entity("tokio", "tokio", EntityType::Tool, Some("async runtime"))
                .await
                .unwrap();
            store
                .insert_edge(
                    rust_id,
                    tokio_id,
                    "uses",
                    "Rust uses tokio for async",
                    0.9,
                    None,
                )
                .await
                .unwrap();

            let facts = memory.recall_graph("rust", 10, 2).await.unwrap();
            assert!(!facts.is_empty(), "should return at least one fact");
            assert_eq!(facts[0].entity_name, "rust");
            assert_eq!(facts[0].relation, "uses");
        }

        /// Facts must come back sorted by descending composite score.
        #[tokio::test]
        async fn recall_graph_sorted_by_composite_score() {
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let a_id = store
                .upsert_entity("entity_a", "entity_a", EntityType::Concept, None)
                .await
                .unwrap();
            let b_id = store
                .upsert_entity("entity_b", "entity_b", EntityType::Concept, None)
                .await
                .unwrap();
            let c_id = store
                .upsert_entity("entity_c", "entity_c", EntityType::Concept, None)
                .await
                .unwrap();
            // Two edges with different weights so an ordering is observable.
            store
                .insert_edge(a_id, b_id, "relates", "a relates b", 0.9, None)
                .await
                .unwrap();
            store
                .insert_edge(a_id, c_id, "relates", "a relates c", 0.5, None)
                .await
                .unwrap();

            let facts = memory.recall_graph("entity_a", 10, 1).await.unwrap();
            // Ordering is only checkable when both facts are returned.
            if facts.len() >= 2 {
                assert!(
                    facts[0].composite_score() >= facts[1].composite_score(),
                    "facts must be sorted descending by composite score"
                );
            }
        }

        /// Empty content must short-circuit extraction with zeroed stats.
        #[tokio::test]
        async fn extract_and_store_returns_zero_stats_for_empty_content() {
            let memory = graph_memory().await;
            let pool = memory.sqlite.pool().clone();
            let provider = test_provider();

            let stats = extract_and_store(
                String::new(),
                vec![],
                provider,
                pool,
                GraphExtractionConfig {
                    max_entities: 10,
                    max_edges: 10,
                    extraction_timeout_secs: 5,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
            assert_eq!(stats.entities_upserted, 0);
            assert_eq!(stats.edges_inserted, 0);
        }

        /// Two extraction attempts must bump the persisted counter to exactly
        /// 2 — no lost updates, no double counting.
        #[tokio::test]
        async fn extraction_count_increments_atomically() {
            let memory = graph_memory().await;
            let pool = memory.sqlite.pool().clone();
            let provider = test_provider();

            for _ in 0..2 {
                // Result intentionally ignored: the counter must advance even
                // when extraction itself fails.
                let _ = extract_and_store(
                    "I use Rust for systems programming".to_owned(),
                    vec![],
                    provider.clone(),
                    pool.clone(),
                    GraphExtractionConfig {
                        max_entities: 5,
                        max_edges: 5,
                        extraction_timeout_secs: 5,
                        ..Default::default()
                    },
                )
                .await;
            }

            let store = GraphStore::new(pool);
            let count = store.get_metadata("extraction_count").await.unwrap();
            assert_eq!(
                count.as_deref(),
                Some("2"),
                "extraction_count must be exactly 2 after two extraction attempts"
            );
        }

        /// `recall_graph` must never return more facts than the given limit.
        #[tokio::test]
        async fn recall_graph_truncates_to_limit() {
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let root_id = store
                .upsert_entity("root", "root", EntityType::Concept, None)
                .await
                .unwrap();
            // Five outgoing edges, but recall below asks for at most three.
            for i in 0..5 {
                let name = format!("target_{i}");
                let tid = store
                    .upsert_entity(&name, &name, EntityType::Concept, None)
                    .await
                    .unwrap();
                store
                    .insert_edge(
                        root_id,
                        tid,
                        "links",
                        &format!("root links {name}"),
                        0.7,
                        None,
                    )
                    .await
                    .unwrap();
            }

            let facts = memory.recall_graph("root", 3, 1).await.unwrap();
            assert!(facts.len() <= 3, "recall_graph must respect limit");
        }

        /// Multi-hop traversal: hop=2 must reach the b→c edge that hop=1
        /// cannot see from a.
        #[tokio::test]
        async fn recall_graph_multi_hop_traverses_two_hops() {
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let a_id = store
                .upsert_entity("a_entity", "a_entity", EntityType::Person, None)
                .await
                .unwrap();
            let b_id = store
                .upsert_entity("b_entity", "b_entity", EntityType::Person, None)
                .await
                .unwrap();
            let c_id = store
                .upsert_entity("c_entity", "c_entity", EntityType::Concept, None)
                .await
                .unwrap();

            // Chain: a --knows--> b --uses--> c.
            store
                .insert_edge(a_id, b_id, "knows", "a knows b", 0.9, None)
                .await
                .unwrap();
            store
                .insert_edge(b_id, c_id, "uses", "b uses c", 0.8, None)
                .await
                .unwrap();

            let facts_1hop = memory.recall_graph("a_entity", 10, 1).await.unwrap();
            assert!(!facts_1hop.is_empty(), "hop=1 must find direct edge");

            let facts_2hop = memory.recall_graph("a_entity", 10, 2).await.unwrap();
            assert!(
                facts_2hop.len() >= facts_1hop.len(),
                "hop=2 must find at least as many facts as hop=1"
            );
            // The second hop must surface a fact connecting b and c.
            let has_bc = facts_2hop.iter().any(|f| {
                (f.entity_name.contains("b_entity") || f.target_name.contains("b_entity"))
                    && (f.entity_name.contains("c_entity") || f.target_name.contains("c_entity"))
            });
            assert!(has_bc, "hop=2 BFS must traverse to c_entity via b_entity");
        }

        /// A zero timeout must abort the background task without panicking.
        #[tokio::test]
        async fn spawn_graph_extraction_zero_timeout_returns_without_panic() {
            let memory = graph_memory().await;
            let cfg = GraphExtractionConfig {
                max_entities: 5,
                max_edges: 5,
                extraction_timeout_secs: 0,
                ..Default::default()
            };
            memory.spawn_graph_extraction(
                "I use Rust for systems programming".to_owned(),
                vec![],
                cfg,
            );
            // Give the detached task a moment to run before the test's runtime
            // shuts down; absence of a panic is the assertion here.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
    }
3323
3324 use proptest::prelude::*;
3327
    // Property-based check: token counting must be total over arbitrary
    // (including empty and non-ASCII) input strings.
    proptest! {
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            // Only absence of a panic matters; the count itself is discarded.
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
3335}