1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
11use crate::error::MemoryError;
12use crate::sqlite::SqliteStore;
13use crate::token_counter::TokenCounter;
14use crate::types::{ConversationId, MessageId};
15use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
16
// Vector-store collection holding one embedding per stored session summary.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
// Vector-store collection holding one embedding per extracted key fact.
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
// Vector-store collection holding embeddings of user corrections.
const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
20
/// JSON-structured summarization result requested from the LLM provider
/// (see `build_summarization_prompt`); deserialized via `chat_typed_erased`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Free-text summary of the conversation window.
    pub summary: String,
    /// Discrete facts extracted from the conversation; embedded individually.
    pub key_facts: Vec<String>,
    /// Entities mentioned in the conversation.
    pub entities: Vec<String>,
}
27
/// A message returned by recall, paired with its final combined ranking score.
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    /// Combined (normalized, weighted) relevance score.
    pub score: f32,
}
33
/// A persisted conversation summary row, covering the inclusive message range
/// `first_message_id..=last_message_id`.
#[derive(Debug, Clone)]
pub struct Summary {
    pub id: i64,
    pub conversation_id: ConversationId,
    pub content: String,
    /// First message covered by this summary.
    pub first_message_id: MessageId,
    /// Last message covered by this summary.
    pub last_message_id: MessageId,
    /// Estimated token count of `content` (from `TokenCounter`).
    pub token_estimate: i64,
}
43
/// A session-summary hit from the `zeph_session_summaries` collection.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    pub summary_text: String,
    /// Vector similarity score reported by the store.
    pub score: f32,
    pub conversation_id: ConversationId,
}
50
/// Cosine similarity of two vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either vector
/// has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 { 0.0 } else { dot / denom }
}
63
64fn apply_temporal_decay(
65 ranked: &mut [(MessageId, f64)],
66 timestamps: &std::collections::HashMap<MessageId, i64>,
67 half_life_days: u32,
68) {
69 if half_life_days == 0 {
70 return;
71 }
72 let now = std::time::SystemTime::now()
73 .duration_since(std::time::UNIX_EPOCH)
74 .unwrap_or_default()
75 .as_secs()
76 .cast_signed();
77 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
78
79 for (msg_id, score) in ranked.iter_mut() {
80 if let Some(&ts) = timestamps.get(msg_id) {
81 #[allow(clippy::cast_precision_loss)]
82 let age_days = (now - ts).max(0) as f64 / 86400.0;
83 *score *= (-lambda * age_days).exp();
84 }
85 }
86}
87
88fn apply_mmr(
89 ranked: &[(MessageId, f64)],
90 vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
91 lambda: f32,
92 limit: usize,
93) -> Vec<(MessageId, f64)> {
94 if ranked.is_empty() || limit == 0 {
95 return Vec::new();
96 }
97
98 let lambda = f64::from(lambda);
99 let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
100 let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();
101
102 while selected.len() < limit && !remaining.is_empty() {
103 let best_idx = if selected.is_empty() {
104 0
106 } else {
107 let mut best = 0usize;
108 let mut best_score = f64::NEG_INFINITY;
109
110 for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
111 let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
112 selected
113 .iter()
114 .filter_map(|(sel_id, _)| vectors.get(sel_id))
115 .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
116 .fold(f64::NEG_INFINITY, f64::max)
117 } else {
118 0.0
119 };
120 let max_sim = if max_sim == f64::NEG_INFINITY {
121 0.0
122 } else {
123 max_sim
124 };
125 let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
126 if mmr_score > best_score {
127 best_score = mmr_score;
128 best = i;
129 }
130 }
131 best
132 };
133
134 selected.push(remaining.remove(best_idx));
135 }
136
137 selected
138}
139
140fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
141 let mut prompt = String::from(
142 "Summarize the following conversation. Extract key facts, decisions, entities, \
143 and context needed to continue the conversation.\n\n\
144 Respond in JSON with fields: summary (string), key_facts (list of strings), \
145 entities (list of strings).\n\nConversation:\n",
146 );
147
148 for (_, role, content) in messages {
149 prompt.push_str(role);
150 prompt.push_str(": ");
151 prompt.push_str(content);
152 prompt.push('\n');
153 }
154
155 prompt
156}
157
/// Hybrid conversation memory: durable storage in SQLite (messages, FTS5
/// keyword index, summaries) plus optional vector search via an
/// [`EmbeddingStore`] and optional graph extraction.
pub struct SemanticMemory {
    sqlite: SqliteStore,
    /// `None` when the vector store was unreachable at construction;
    /// semantic search is then disabled and recall is keyword-only.
    qdrant: Option<Arc<EmbeddingStore>>,
    provider: AnyProvider,
    /// Model name recorded alongside each stored embedding.
    embedding_model: String,
    /// Weight applied to normalized vector-search scores during ranking.
    vector_weight: f64,
    /// Weight applied to normalized keyword-search scores during ranking.
    keyword_weight: f64,
    temporal_decay_enabled: bool,
    temporal_decay_half_life_days: u32,
    mmr_enabled: bool,
    /// MMR relevance/diversity trade-off (see `apply_mmr`).
    mmr_lambda: f32,
    pub token_counter: Arc<TokenCounter>,
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    // Metrics counters shared with spawned background graph tasks.
    community_detection_failures: Arc<AtomicU64>,
    graph_extraction_count: Arc<AtomicU64>,
    graph_extraction_failures: Arc<AtomicU64>,
}
175
176impl SemanticMemory {
    /// Construct with default ranking weights (vector 0.7, keyword 0.3).
    ///
    /// # Errors
    /// Propagates errors from [`Self::with_weights`] (SQLite open failure).
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
192
    /// Construct with explicit ranking weights and the default SQLite
    /// connection pool size (5).
    ///
    /// # Errors
    /// Propagates errors from [`Self::with_weights_and_pool_size`].
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
217
    /// Construct with explicit ranking weights and SQLite pool size.
    ///
    /// The vector store is optional: if it cannot be created from
    /// `qdrant_url`, a warning is logged and semantic search is disabled
    /// (keyword-only recall). Temporal decay and MMR start disabled with
    /// defaults of a 30-day half-life and lambda 0.7.
    ///
    /// # Errors
    /// Returns an error only if the SQLite store cannot be opened.
    pub async fn with_weights_and_pool_size(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();

        // Vector-store failure is non-fatal by design: degrade gracefully.
        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
            Ok(store) => Some(Arc::new(store)),
            Err(e) => {
                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
                None
            }
        };

        Ok(Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        })
    }
261
    /// Builder-style setter attaching a graph store (enables graph recall).
    #[must_use]
    pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
        self.graph_store = Some(store);
        self
    }
271
    /// Number of community-detection failures recorded so far (relaxed read).
    #[must_use]
    pub fn community_detection_failures(&self) -> u64 {
        self.community_detection_failures.load(Ordering::Relaxed)
    }
277
    /// Number of successful graph extractions recorded so far (relaxed read).
    #[must_use]
    pub fn graph_extraction_count(&self) -> u64 {
        self.graph_extraction_count.load(Ordering::Relaxed)
    }
283
    /// Number of failed or timed-out graph extractions so far (relaxed read).
    #[must_use]
    pub fn graph_extraction_failures(&self) -> u64 {
        self.graph_extraction_failures.load(Ordering::Relaxed)
    }
289
    /// Builder-style setter configuring ranking extras: exponential temporal
    /// decay (see `apply_temporal_decay`) and MMR diversification
    /// (see `apply_mmr`).
    #[must_use]
    pub fn with_ranking_options(
        mut self,
        temporal_decay_enabled: bool,
        temporal_decay_half_life_days: u32,
        mmr_enabled: bool,
        mmr_lambda: f32,
    ) -> Self {
        self.temporal_decay_enabled = temporal_decay_enabled;
        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
        self.mmr_enabled = mmr_enabled;
        self.mmr_lambda = mmr_lambda;
        self
    }
305
    /// Assemble an instance from pre-built components, sharing an existing
    /// token counter. Ranking extras start disabled (30-day half-life,
    /// lambda 0.7 defaults) and no graph store is attached.
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<Arc<EmbeddingStore>>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter,
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        }
    }
337
    /// Construct using SQLite itself as the embedding backend (no external
    /// vector store), with the default pool size (5).
    ///
    /// # Errors
    /// Propagates errors from [`Self::with_sqlite_backend_and_pool_size`].
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
360
    /// Construct using SQLite itself as the embedding backend
    /// (`EmbeddingStore::new_sqlite`), so vector search is always available
    /// without an external service.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(Arc::new(store)),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        })
    }
396
397 pub async fn remember(
406 &self,
407 conversation_id: ConversationId,
408 role: &str,
409 content: &str,
410 ) -> Result<MessageId, MemoryError> {
411 let message_id = self
412 .sqlite
413 .save_message(conversation_id, role, content)
414 .await?;
415
416 if let Some(qdrant) = &self.qdrant
417 && self.provider.supports_embeddings()
418 {
419 match self.provider.embed(content).await {
420 Ok(vector) => {
421 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
423 if let Err(e) = qdrant.ensure_collection(vector_size).await {
424 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
425 } else if let Err(e) = qdrant
426 .store(
427 message_id,
428 conversation_id,
429 role,
430 vector,
431 MessageKind::Regular,
432 &self.embedding_model,
433 )
434 .await
435 {
436 tracing::warn!("Failed to store embedding: {e:#}");
437 }
438 }
439 Err(e) => {
440 tracing::warn!("Failed to generate embedding: {e:#}");
441 }
442 }
443 }
444
445 Ok(message_id)
446 }
447
    /// Persist a message together with its structured `parts_json` payload,
    /// then best-effort store its embedding.
    ///
    /// Returns the new message id plus a flag that is `true` only when the
    /// embedding was actually written to the vector store (`false` when the
    /// store is absent, the provider lacks embeddings, or any embedding step
    /// failed — those failures are logged, never propagated).
    ///
    /// # Errors
    /// Returns an error only if the SQLite write fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // 896 fallback is practically unreachable (len always fits u64).
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
502
    /// Persist a message (with parts) to SQLite without attempting any
    /// embedding — use when the caller handles embeddings separately.
    ///
    /// # Errors
    /// Returns an error if the SQLite write fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
521
522 pub async fn recall(
532 &self,
533 query: &str,
534 limit: usize,
535 filter: Option<SearchFilter>,
536 ) -> Result<Vec<RecalledMessage>, MemoryError> {
537 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
538
539 let keyword_results = match self
541 .sqlite
542 .keyword_search(query, limit * 2, conversation_id)
543 .await
544 {
545 Ok(results) => results,
546 Err(e) => {
547 tracing::warn!("FTS5 keyword search failed: {e:#}");
548 Vec::new()
549 }
550 };
551
552 let vector_results = if let Some(qdrant) = &self.qdrant
554 && self.provider.supports_embeddings()
555 {
556 let query_vector = self.provider.embed(query).await?;
557 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
558 qdrant.ensure_collection(vector_size).await?;
559 qdrant.search(&query_vector, limit * 2, filter).await?
560 } else {
561 Vec::new()
562 };
563
564 self.recall_merge_and_rank(keyword_results, vector_results, limit)
565 .await
566 }
567
    /// Raw FTS5 keyword search, over-fetching `limit * 2` candidates so the
    /// downstream merge/rank step has room to blend with vector hits.
    async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
    }
583
    /// Raw vector search, over-fetching `limit * 2` candidates.
    ///
    /// Returns an empty list (not an error) when there is no vector store or
    /// the provider does not support embeddings; otherwise embeds the query,
    /// ensures the collection exists, and searches.
    async fn recall_vectors_raw(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        let query_vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
        qdrant.ensure_collection(vector_size).await?;
        qdrant.search(&query_vector, limit * 2, filter).await
    }
607
608 #[allow(clippy::cast_possible_truncation)]
617 async fn recall_merge_and_rank(
618 &self,
619 keyword_results: Vec<(MessageId, f64)>,
620 vector_results: Vec<crate::embedding_store::SearchResult>,
621 limit: usize,
622 ) -> Result<Vec<RecalledMessage>, MemoryError> {
623 let mut scores: std::collections::HashMap<MessageId, f64> =
624 std::collections::HashMap::new();
625
626 if !vector_results.is_empty() {
627 let max_vs = vector_results
628 .iter()
629 .map(|r| r.score)
630 .fold(f32::NEG_INFINITY, f32::max);
631 let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
632 for r in &vector_results {
633 let normalized = f64::from(r.score / norm);
634 *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
635 }
636 }
637
638 if !keyword_results.is_empty() {
639 let max_ks = keyword_results
640 .iter()
641 .map(|r| r.1)
642 .fold(f64::NEG_INFINITY, f64::max);
643 let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
644 for &(msg_id, score) in &keyword_results {
645 let normalized = score / norm;
646 *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
647 }
648 }
649
650 if scores.is_empty() {
651 return Ok(Vec::new());
652 }
653
654 let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
655 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
656
657 if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
658 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
659 match self.sqlite.message_timestamps(&ids).await {
660 Ok(timestamps) => {
661 apply_temporal_decay(
662 &mut ranked,
663 ×tamps,
664 self.temporal_decay_half_life_days,
665 );
666 ranked
667 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
668 }
669 Err(e) => {
670 tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
671 }
672 }
673 }
674
675 if self.mmr_enabled && !vector_results.is_empty() {
676 if let Some(qdrant) = &self.qdrant {
677 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
678 match qdrant.get_vectors(&ids).await {
679 Ok(vec_map) if !vec_map.is_empty() => {
680 ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
681 }
682 Ok(_) => {
683 ranked.truncate(limit);
684 }
685 Err(e) => {
686 tracing::warn!("MMR: failed to fetch vectors: {e:#}");
687 ranked.truncate(limit);
688 }
689 }
690 } else {
691 ranked.truncate(limit);
692 }
693 } else {
694 ranked.truncate(limit);
695 }
696
697 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
698 let messages = self.sqlite.messages_by_ids(&ids).await?;
699 let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
700
701 let recalled = ranked
702 .iter()
703 .filter_map(|(msg_id, score)| {
704 msg_map.get(msg_id).map(|msg| RecalledMessage {
705 message: msg.clone(),
706 #[expect(clippy::cast_possible_truncation)]
707 score: *score as f32,
708 })
709 })
710 .collect();
711
712 Ok(recalled)
713 }
714
    /// Recall with a routing decision: the router classifies the query and
    /// selects keyword-only, semantic-only, or hybrid retrieval.
    ///
    /// `MemoryRoute::Graph` currently degrades to hybrid retrieval in this
    /// path; dedicated graph recall is [`Self::recall_graph`]. In hybrid and
    /// graph modes, keyword-search failures are logged and treated as empty;
    /// vector-search errors always propagate.
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                // Keyword arm is best-effort in hybrid mode.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Graph => {
                // Graph route falls back to hybrid retrieval here.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
781
    /// Graph-based recall: delegate to `graph::retrieval::graph_recall` with
    /// up to `max_hops` traversal depth.
    ///
    /// Returns an empty list when no graph store is attached.
    pub async fn recall_graph(
        &self,
        query: &str,
        limit: usize,
        max_hops: u32,
    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };
        crate::graph::retrieval::graph_recall(
            store,
            self.qdrant.as_deref(),
            &self.provider,
            query,
            limit,
            max_hops,
        )
        .await
    }
808
809 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
815 match &self.qdrant {
816 Some(qdrant) => qdrant.has_embedding(message_id).await,
817 None => Ok(false),
818 }
819 }
820
    /// Backfill embeddings for messages that were stored without one, in
    /// batches of up to 1000 ids.
    ///
    /// Embeds a short probe string first to learn the vector dimension and
    /// ensure the collection exists. Per-message embed/store failures are
    /// logged and skipped. Returns the number successfully embedded (0 when
    /// there is no vector store, no embedding support, or nothing pending).
    ///
    /// # Errors
    /// Returns an error if listing unembedded ids, the probe embedding, or
    /// collection setup fails.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Probe once to discover the embedding dimension.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
876
    /// Embed a session summary and store it in the dedicated session-summary
    /// collection with its conversation id in the payload.
    ///
    /// No-op (returns `Ok`) when there is no vector store or the provider
    /// does not support embeddings.
    ///
    /// # Errors
    /// Returns an error if embedding, collection setup, or the store write fails.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Payload carries what search_session_summaries reads back.
        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
915
    /// Semantic search over stored session summaries, optionally excluding
    /// one conversation (e.g. the current one) via a `must_not` filter.
    ///
    /// Returns an empty list when there is no vector store or no embedding
    /// support. Hits with malformed payloads are silently skipped.
    ///
    /// # Errors
    /// Returns an error if embedding, collection setup, or the search fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Exclude the caller's own conversation from the results if asked.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
968
    /// Borrow the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
974
975 pub async fn is_vector_store_connected(&self) -> bool {
980 match self.qdrant.as_ref() {
981 Some(store) => store.health_check().await,
982 None => false,
983 }
984 }
985
    /// Whether a vector store was configured (does not check connectivity;
    /// see [`Self::is_vector_store_connected`]).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
991
    /// Borrow the embedding store handle, if one is configured.
    #[must_use]
    pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
        self.qdrant.as_ref()
    }
997
    /// Total number of messages stored for a conversation.
    ///
    /// # Errors
    /// Returns an error if the SQLite query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
1006
    /// Number of messages added after the most recent summary's last covered
    /// message (all messages when nothing was summarized yet, via the
    /// `MessageId(0)` sentinel).
    ///
    /// # Errors
    /// Returns an error if either SQLite query fails.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: ConversationId,
    ) -> Result<i64, MemoryError> {
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
1025
1026 pub async fn load_summaries(
1032 &self,
1033 conversation_id: ConversationId,
1034 ) -> Result<Vec<Summary>, MemoryError> {
1035 let rows = self.sqlite.load_summaries(conversation_id).await?;
1036 let summaries = rows
1037 .into_iter()
1038 .map(
1039 |(
1040 id,
1041 conversation_id,
1042 content,
1043 first_message_id,
1044 last_message_id,
1045 token_estimate,
1046 )| {
1047 Summary {
1048 id,
1049 conversation_id,
1050 content,
1051 first_message_id,
1052 last_message_id,
1053 token_estimate,
1054 }
1055 },
1056 )
1057 .collect();
1058 Ok(summaries)
1059 }
1060
    /// Summarize the oldest unsummarized window of a conversation once its
    /// total message count exceeds `message_count`.
    ///
    /// Requests a structured (JSON) summary from the provider, falling back
    /// to a plain-text chat response (with empty key facts/entities) on
    /// failure. The summary is saved to SQLite, best-effort embedded into the
    /// vector store, and any extracted key facts are stored separately.
    /// Returns the new summary row id, or `None` when there is nothing to
    /// summarize yet.
    ///
    /// # Errors
    /// Returns an error on SQLite failures, on LLM failure of both the
    /// structured and plain-text attempts, or on numeric conversion failure.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once the conversation has grown past the window.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message (0 = nothing summarized yet).
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer a typed/structured response; degrade to plain chat on failure.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself so it can participate in recall.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        // Key facts are stored best-effort; store_key_facts never fails.
        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
1174
1175 async fn store_key_facts(
1176 &self,
1177 conversation_id: ConversationId,
1178 source_summary_id: i64,
1179 key_facts: &[String],
1180 ) {
1181 let Some(qdrant) = &self.qdrant else {
1182 return;
1183 };
1184 if !self.provider.supports_embeddings() {
1185 return;
1186 }
1187
1188 let Some(first_fact) = key_facts.first() else {
1189 return;
1190 };
1191 let first_vector = match self.provider.embed(first_fact).await {
1192 Ok(v) => v,
1193 Err(e) => {
1194 tracing::warn!("Failed to embed key fact: {e:#}");
1195 return;
1196 }
1197 };
1198 let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
1199 if let Err(e) = qdrant
1200 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1201 .await
1202 {
1203 tracing::warn!("Failed to ensure key_facts collection: {e:#}");
1204 return;
1205 }
1206
1207 let first_payload = serde_json::json!({
1208 "conversation_id": conversation_id.0,
1209 "fact_text": first_fact,
1210 "source_summary_id": source_summary_id,
1211 });
1212 if let Err(e) = qdrant
1213 .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
1214 .await
1215 {
1216 tracing::warn!("Failed to store key fact: {e:#}");
1217 }
1218
1219 for fact in &key_facts[1..] {
1220 match self.provider.embed(fact).await {
1221 Ok(vector) => {
1222 let payload = serde_json::json!({
1223 "conversation_id": conversation_id.0,
1224 "fact_text": fact,
1225 "source_summary_id": source_summary_id,
1226 });
1227 if let Err(e) = qdrant
1228 .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1229 .await
1230 {
1231 tracing::warn!("Failed to store key fact: {e:#}");
1232 }
1233 }
1234 Err(e) => {
1235 tracing::warn!("Failed to embed key fact: {e:#}");
1236 }
1237 }
1238 }
1239 }
1240
    /// Semantic search over stored key facts, returning just the fact texts.
    ///
    /// Returns an empty list when there is no vector store or no embedding
    /// support. Hits missing a `fact_text` payload are skipped.
    ///
    /// # Errors
    /// Returns an error if embedding, collection setup, or the search fails.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }
1275
    /// Semantic search against an arbitrary named collection (e.g. ingested
    /// documents), returning raw scored points.
    ///
    /// Returns an empty list when there is no vector store, no embedding
    /// support, or the collection does not exist (it is never created here).
    ///
    /// # Errors
    /// Returns an error if the existence check, embedding, or search fails.
    pub async fn search_document_collection(
        &self,
        collection: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        if !qdrant.collection_exists(collection).await? {
            return Ok(Vec::new());
        }
        let vector = self.provider.embed(query).await?;
        qdrant
            .search_collection(collection, &vector, limit, None)
            .await
    }
1305
    /// Embed a user-correction text and store it in the corrections
    /// collection, keyed back to its SQLite row via `correction_id`.
    ///
    /// No-op (returns `Ok`) when there is no vector store or no embedding
    /// support.
    ///
    /// # Errors
    /// Returns an error if embedding, collection setup, or the write fails.
    pub async fn store_correction_embedding(
        &self,
        correction_id: i64,
        correction_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(ref store) = self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }
        let embedding = self
            .provider
            .embed(correction_text)
            .await
            .map_err(|e| MemoryError::Other(e.to_string()))?;
        let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
        store
            .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
            .await?;
        // Payload carries only the SQLite row id; the text lives in SQLite.
        let payload = serde_json::json!({ "correction_id": correction_id });
        store
            .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
            .await?;
        Ok(())
    }
1339
    /// Find past user corrections semantically similar to `query`, keeping
    /// only hits scoring at least `min_score`, and hydrate the matching rows
    /// from SQLite.
    ///
    /// Returns an empty list when there is no vector store or no embedding
    /// support. Vector-search failures are swallowed (treated as no matches)
    /// via `unwrap_or_default`.
    ///
    /// # Errors
    /// Returns an error if embedding, collection setup, or loading the
    /// correction rows from SQLite fails.
    pub async fn retrieve_similar_corrections(
        &self,
        query: &str,
        limit: usize,
        min_score: f32,
    ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
        let Some(ref store) = self.qdrant else {
            return Ok(vec![]);
        };
        if !self.provider.supports_embeddings() {
            return Ok(vec![]);
        }
        let embedding = self
            .provider
            .embed(query)
            .await
            .map_err(|e| MemoryError::Other(e.to_string()))?;
        let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
        store
            .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
            .await?;
        let scored = store
            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
            .await
            .unwrap_or_default();

        let mut results = Vec::new();
        for point in scored {
            // Similarity threshold filter.
            if point.score < min_score {
                continue;
            }
            if let Some(id_val) = point.payload.get("correction_id")
                && let Some(id) = id_val.as_i64()
            {
                let rows = self.sqlite.load_corrections_for_id(id).await?;
                results.extend(rows);
            }
        }
        Ok(results)
    }
1388
    /// Fires off graph extraction for `content` on a detached Tokio task
    /// and returns immediately; nothing awaits the task.
    ///
    /// The task bounds the whole extract-and-store pipeline with
    /// `config.extraction_timeout_secs` and updates the shared
    /// success/failure counters. On success it may spawn a second detached
    /// task that runs community detection and graph eviction, gated on the
    /// persistent extraction count modulo `config.community_refresh_interval`.
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
    ) {
        // Clone everything the detached task needs: the future must be 'static.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();

        tokio::spawn(async move {
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            // Three outcomes: Ok(Ok) success, Ok(Err) pipeline failure,
            // Err(_) timeout. Failures are logged and counted, never
            // propagated (there is no caller to propagate to).
            let extraction_ok = match tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                ),
            )
            .await
            {
                Ok(Ok(stats)) => {
                    tracing::debug!(
                        entities = stats.entities_upserted,
                        edges = stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    true
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    false
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    false
                }
            };

            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                // Persistent counter read from the store; intentionally
                // shadows the in-process atomic counter cloned above.
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                // Refresh communities only every Nth extraction.
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    // Detached so slow community detection never holds up
                    // the extraction task itself.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs regardless of the detection outcome.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }
                    });
                }
            }
        });
    }
1497}
1498
/// Tuning knobs for background graph extraction and community maintenance.
#[derive(Debug, Clone, Default)]
pub struct GraphExtractionConfig {
    /// Cap on entities handed to the extractor per run.
    pub max_entities: usize,
    /// Cap on edges handed to the extractor per run.
    pub max_edges: usize,
    /// Hard timeout for one extract-and-store run, in seconds.
    pub extraction_timeout_secs: u64,
    /// Run community detection every N extractions; 0 disables it.
    pub community_refresh_interval: usize,
    /// Retention passed to graph eviction for expired edges (days).
    pub expired_edge_retention_days: u32,
    /// Entity cap passed to graph eviction.
    pub max_entities_cap: usize,
    /// Byte budget for the community-summary prompt.
    pub community_summary_max_prompt_bytes: usize,
    /// How many community summaries are generated concurrently.
    pub community_summary_concurrency: usize,
}
1514
/// Counters reported by a single [`extract_and_store`] run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved into the graph store.
    pub entities_upserted: usize,
    /// Edges successfully inserted between resolved entities.
    pub edges_inserted: usize,
}
1521
1522pub async fn extract_and_store(
1530 content: String,
1531 context_messages: Vec<String>,
1532 provider: AnyProvider,
1533 pool: sqlx::SqlitePool,
1534 config: GraphExtractionConfig,
1535) -> Result<ExtractionStats, MemoryError> {
1536 use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
1537
1538 let extractor = GraphExtractor::new(provider, config.max_entities, config.max_edges);
1539 let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
1540
1541 let store = GraphStore::new(pool);
1542
1543 let pool = store.pool();
1546 sqlx::query(
1547 "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
1548 ON CONFLICT(key) DO NOTHING",
1549 )
1550 .execute(pool)
1551 .await?;
1552 sqlx::query(
1553 "UPDATE graph_metadata
1554 SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
1555 WHERE key = 'extraction_count'",
1556 )
1557 .execute(pool)
1558 .await?;
1559
1560 let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
1561 return Ok(ExtractionStats::default());
1562 };
1563
1564 let resolver = EntityResolver::new(&store);
1565
1566 let mut entities_upserted = 0usize;
1567 let mut entity_ids: std::collections::HashMap<String, i64> = std::collections::HashMap::new();
1568
1569 for entity in &result.entities {
1570 match resolver
1571 .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
1572 .await
1573 {
1574 Ok((id, _outcome)) => {
1575 entity_ids.insert(entity.name.clone(), id);
1576 entities_upserted += 1;
1577 }
1578 Err(e) => {
1579 tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
1580 }
1581 }
1582 }
1583
1584 let mut edges_inserted = 0usize;
1585 for edge in &result.edges {
1586 let (Some(&src_id), Some(&tgt_id)) =
1587 (entity_ids.get(&edge.source), entity_ids.get(&edge.target))
1588 else {
1589 tracing::debug!(
1590 "graph: skipping edge {:?}->{:?}: entity not resolved",
1591 edge.source,
1592 edge.target
1593 );
1594 continue;
1595 };
1596 match resolver
1597 .resolve_edge(src_id, tgt_id, &edge.relation, &edge.fact, 0.8, None)
1598 .await
1599 {
1600 Ok(Some(_)) => edges_inserted += 1,
1601 Ok(None) => {} Err(e) => {
1603 tracing::debug!("graph: skipping edge: {e:#}");
1604 }
1605 }
1606 }
1607
1608 Ok(ExtractionStats {
1609 entities_upserted,
1610 edges_inserted,
1611 })
1612}
1613
1614#[cfg(test)]
1615mod tests {
1616 use zeph_llm::mock::MockProvider;
1617 use zeph_llm::provider::Role;
1618
1619 use super::*;
1620
1621 fn test_provider() -> AnyProvider {
1622 AnyProvider::Mock(MockProvider::default())
1623 }
1624
    // Builds a SemanticMemory over an in-memory SQLite DB with no vector
    // store and a mock provider.
    // NOTE(review): the flag is currently ignored — presumably the default
    // mock does not support embeddings either way (other tests set
    // `mock.supports_embeddings = true` explicitly); confirm.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        }
    }
1647
    // remember() persists a message and assigns the first id.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }

    // remember_with_parts() stores the structured parts JSON alongside text.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }

    // recall() degrades to empty when no vector store is configured.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // has_embedding() is false when no vector store is configured.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }

    // embed_missing() is a no-op without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // sqlite() accessor exposes the underlying store for direct use.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }

    // Vector-store availability checks report false when unconfigured.
    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }

    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }

    // recall() also degrades to empty when the provider lacks embeddings.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // embed_missing() does nothing when the provider lacks embeddings.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // message_count() on a fresh conversation is zero.
    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }

    // message_count() tracks saved messages.
    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }

    // Summarizing lowers the unsummarized count without deleting messages.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }

    // load_summaries() on a fresh conversation is empty.
    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    // load_summaries() returns summaries in insertion order.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }
1833
    // summarize() declines when there are too few messages.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }

    // summarize() stores a non-empty summary once above the threshold.
    #[tokio::test]
    async fn summarize_stores_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 3).await.unwrap();
        assert!(summary_id.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id.unwrap());
        assert!(!summaries[0].content.is_empty());
    }

    // A second summary starts after the range the first one covered.
    #[tokio::test]
    async fn summarize_respects_previous_summaries() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let s1 = memory.summarize(cid, 3).await.unwrap();
        assert!(s1.is_some());

        let s2 = memory.summarize(cid, 3).await.unwrap();
        assert!(s2.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
    }

    // Message ids increase monotonically within a conversation.
    #[tokio::test]
    async fn remember_multiple_messages_increments_ids() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let id1 = memory.remember(cid, "user", "first").await.unwrap();
        let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
        let id3 = memory.remember(cid, "user", "third").await.unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    // Counts are tracked separately per conversation.
    #[tokio::test]
    async fn message_count_across_conversations() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid1, "user", "msg1").await.unwrap();
        memory.remember(cid1, "user", "msg2").await.unwrap();
        memory.remember(cid2, "user", "msg3").await.unwrap();

        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
    }

    // Exactly at the threshold, summarize() still declines.
    #[tokio::test]
    async fn summarize_exact_threshold_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..3 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_none());
    }

    // One message past the threshold triggers a summary.
    #[tokio::test]
    async fn summarize_one_above_threshold_produces_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..4 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_some());
    }

    // All Summary fields come back populated and self-consistent.
    #[tokio::test]
    async fn summary_fields_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let s = &summaries[0];

        assert_eq!(s.conversation_id, cid);
        assert!(s.first_message_id > MessageId(0));
        assert!(s.last_message_id >= s.first_message_id);
        assert!(s.token_estimate >= 0);
        assert!(!s.content.is_empty());
    }

    // The prompt embeds role-prefixed lines and the structured-output keys.
    #[test]
    fn build_summarization_prompt_format() {
        let messages = vec![
            (MessageId(1), "user".into(), "Hello".into()),
            (MessageId(2), "assistant".into(), "Hi there".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("user: Hello"));
        assert!(prompt.contains("assistant: Hi there"));
        assert!(prompt.contains("key_facts"));
    }

    // Even an empty transcript produces a prompt with the schema keys.
    #[test]
    fn build_summarization_prompt_empty() {
        let messages: Vec<(MessageId, String, String)> = vec![];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("key_facts"));
    }

    // StructuredSummary deserializes from the expected JSON shape.
    #[test]
    fn structured_summary_deserialize() {
        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert_eq!(ss.summary, "s");
        assert_eq!(ss.key_facts.len(), 2);
        assert_eq!(ss.entities.len(), 1);
    }

    // Empty fact/entity arrays are accepted.
    #[test]
    fn structured_summary_empty_facts() {
        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert!(ss.key_facts.is_empty());
        assert!(ss.entities.is_empty());
    }
2006
    // search_key_facts() degrades to empty without a vector store.
    #[tokio::test]
    async fn search_key_facts_no_qdrant_empty() {
        let memory = test_semantic_memory(false).await;
        let facts = memory.search_key_facts("query", 5).await.unwrap();
        assert!(facts.is_empty());
    }

    // Debug formatting of RecalledMessage includes the type name and score.
    #[test]
    fn recalled_message_debug() {
        let recalled = RecalledMessage {
            message: Message {
                role: Role::User,
                content: "test".into(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 0.95,
        };
        let dbg = format!("{recalled:?}");
        assert!(dbg.contains("RecalledMessage"));
        assert!(dbg.contains("0.95"));
    }

    // Summary's Clone copies fields verbatim.
    #[test]
    fn summary_clone() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test summary".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let cloned = summary.clone();
        assert_eq!(summary.id, cloned.id);
        assert_eq!(summary.content, cloned.content);
    }

    // Role strings round-trip through storage into Role variants.
    #[tokio::test]
    async fn remember_preserves_role_mapping() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "u").await.unwrap();
        memory.remember(cid, "assistant", "a").await.unwrap();
        memory.remember(cid, "system", "s").await.unwrap();

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[2].role, Role::System);
    }

    // Construction succeeds even when the vector-store URL is unreachable.
    #[tokio::test]
    async fn new_with_invalid_qdrant_url_graceful() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);
        let result =
            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
        assert!(result.is_ok());
    }

    // Remember/recall roundtrip against the SQLite-backed embedding store.
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(Arc::new(
            crate::embedding_store::EmbeddingStore::new_sqlite(pool),
        ));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        assert!(id1 < id2);

        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }
2129
    // Without a vector store, remember() still persists the message even
    // when the helper is asked for embedding support.
    #[tokio::test]
    async fn remember_with_embeddings_supported_but_no_qdrant() {
        let memory = test_semantic_memory(true).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hello embed");
    }

    // Stored contents come back in insertion order.
    #[tokio::test]
    async fn remember_verifies_content_via_load_history() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "alpha").await.unwrap();
        memory.remember(cid, "assistant", "beta").await.unwrap();
        memory.remember(cid, "user", "gamma").await.unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "alpha");
        assert_eq!(history[1].content, "beta");
        assert_eq!(history[2].content, "gamma");
    }

    // Message counts are isolated per conversation.
    #[tokio::test]
    async fn message_count_multiple_conversations_isolated() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();
        let cid3 = memory.sqlite().create_conversation().await.unwrap();

        for _ in 0..5 {
            memory.remember(cid1, "user", "msg").await.unwrap();
        }
        for _ in 0..3 {
            memory.remember(cid2, "user", "msg").await.unwrap();
        }

        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
    }

    // NOTE(review): despite the name, this asserts that two consecutive
    // summarize() calls both succeed and produce two summaries — confirm
    // the intended scenario.
    #[tokio::test]
    async fn summarize_empty_messages_range_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..6 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        memory.summarize(cid, 3).await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    // token_estimate is positive for a non-empty summary.
    #[tokio::test]
    async fn summarize_token_estimate_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let token_est = summaries[0].token_estimate;
        assert!(token_est > 0);
    }

    // summarize() surfaces provider chat failures as errors.
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        // Ollama provider pointed at an unreachable address.
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }

    // embed_missing() returns zero when embeddings are unsupported.
    #[tokio::test]
    async fn embed_missing_without_embedding_support_returns_zero() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // has_embedding() is false for a stored message without a vector store.
    #[tokio::test]
    async fn has_embedding_returns_false_when_no_qdrant() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "test").await.unwrap();
        assert!(!memory.has_embedding(msg_id).await.unwrap());
    }

    // recall() stays empty without a vector store even with a filter set.
    #[tokio::test]
    async fn recall_empty_without_qdrant_regardless_of_filter() {
        let memory = test_semantic_memory(true).await;
        let filter = SearchFilter {
            conversation_id: Some(ConversationId(1)),
            role: None,
        };
        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
        assert!(recalled.is_empty());
    }
2285
    // Summary message-id bounds are consistent with stored messages.
    #[tokio::test]
    async fn summarize_message_range_bounds() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..8 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id);
        assert!(summaries[0].first_message_id >= MessageId(1));
        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
    }

    // The prompt lists messages in their original order.
    #[test]
    fn build_summarization_prompt_preserves_order() {
        let messages = vec![
            (MessageId(1), "user".into(), "first".into()),
            (MessageId(2), "assistant".into(), "second".into()),
            (MessageId(3), "user".into(), "third".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        let first_pos = prompt.find("user: first").unwrap();
        let second_pos = prompt.find("assistant: second").unwrap();
        let third_pos = prompt.find("user: third").unwrap();
        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
    }

    // Debug formatting of Summary includes the type name.
    #[test]
    fn summary_debug() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let dbg = format!("{summary:?}");
        assert!(dbg.contains("Summary"));
    }

    // Queries against unknown conversation ids yield empty results.
    #[tokio::test]
    async fn message_count_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let count = memory.message_count(ConversationId(999)).await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn load_summaries_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Session-summary storage is a no-op without a vector store or
    // embedding support.
    #[tokio::test]
    async fn store_session_summary_no_qdrant_noop() {
        let memory = test_semantic_memory(true).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn store_session_summary_no_embeddings_noop() {
        let memory = test_semantic_memory(false).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    // Session-summary search degrades to empty in the same conditions.
    #[tokio::test]
    async fn search_session_summaries_no_qdrant_empty() {
        let memory = test_semantic_memory(true).await;
        let results = memory
            .search_session_summaries("query", 5, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_session_summaries_no_embeddings_empty() {
        let memory = test_semantic_memory(false).await;
        let results = memory
            .search_session_summaries("query", 5, Some(ConversationId(1)))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Correction-embedding storage is a no-op in the same conditions.
    #[tokio::test]
    async fn store_correction_embedding_no_qdrant_noop() {
        let memory = test_semantic_memory(true).await;
        let result = memory.store_correction_embedding(1, "bad response").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn store_correction_embedding_no_embeddings_noop() {
        let memory = test_semantic_memory(false).await;
        let result = memory.store_correction_embedding(1, "bad response").await;
        assert!(result.is_ok());
    }

    // Correction retrieval degrades to empty in the same conditions.
    #[tokio::test]
    async fn retrieve_similar_corrections_no_qdrant_empty() {
        let memory = test_semantic_memory(true).await;
        let results = memory
            .retrieve_similar_corrections("query", 5, 0.0)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn retrieve_similar_corrections_no_embeddings_empty() {
        let memory = test_semantic_memory(false).await;
        let results = memory
            .retrieve_similar_corrections("query", 5, 0.0)
            .await
            .unwrap();
        assert!(results.is_empty());
    }
2420
    // Storing a correction embedding succeeds against the SQLite-backed
    // embedding store; retrieval comes back empty — presumably because no
    // correction row with id 1 exists in the clean DB to hydrate (TODO
    // confirm against load_corrections_for_id).
    #[tokio::test]
    async fn store_correction_embedding_sqlite_clean_db_roundtrip() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(Arc::new(
            crate::embedding_store::EmbeddingStore::new_sqlite(pool),
        ));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
        };

        memory
            .store_correction_embedding(1, "bad response")
            .await
            .unwrap();

        let results = memory
            .retrieve_similar_corrections("bad", 5, 0.0)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Debug formatting of SessionSummaryResult includes the type name.
    #[test]
    fn session_summary_result_debug() {
        let result = SessionSummaryResult {
            summary_text: "test".into(),
            score: 0.9,
            conversation_id: ConversationId(1),
        };
        let dbg = format!("{result:?}");
        assert!(dbg.contains("SessionSummaryResult"));
    }

    // SessionSummaryResult's Clone copies fields verbatim.
    #[test]
    fn session_summary_result_clone() {
        let result = SessionSummaryResult {
            summary_text: "test".into(),
            score: 0.9,
            conversation_id: ConversationId(1),
        };
        let cloned = result.clone();
        assert_eq!(result.summary_text, cloned.summary_text);
        assert_eq!(result.conversation_id, cloned.conversation_id);
    }
2488
    #[tokio::test]
    async fn recall_fts5_fallback_without_qdrant() {
        // Without a vector store, recall() falls back to SQLite FTS5 keyword
        // search: only the two messages containing "rust" match.
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory
            .remember(cid, "user", "rust programming guide")
            .await
            .unwrap();
        memory
            .remember(cid, "assistant", "python tutorial")
            .await
            .unwrap();
        memory
            .remember(cid, "user", "advanced rust patterns")
            .await
            .unwrap();

        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert_eq!(recalled.len(), 2);
        // Results must be ordered best-first.
        assert!(recalled[0].score >= recalled[1].score);
    }
2511
    #[tokio::test]
    async fn recall_fts5_fallback_with_filter() {
        // A SearchFilter scoped to one conversation must exclude matching
        // messages that live in other conversations.
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite.create_conversation().await.unwrap();
        let cid2 = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid1, "user", "hello world").await.unwrap();
        memory
            .remember(cid2, "user", "hello universe")
            .await
            .unwrap();

        let filter = SearchFilter {
            conversation_id: Some(cid1),
            role: None,
        };
        // "hello" appears in both conversations, but only cid1's hit counts.
        let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
        assert_eq!(recalled.len(), 1);
    }
2531
2532 #[tokio::test]
2533 async fn recall_fts5_no_matches_returns_empty() {
2534 let memory = test_semantic_memory(false).await;
2535 let cid = memory.sqlite.create_conversation().await.unwrap();
2536
2537 memory.remember(cid, "user", "hello world").await.unwrap();
2538
2539 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
2540 assert!(recalled.is_empty());
2541 }
2542
2543 #[tokio::test]
2544 async fn recall_fts5_respects_limit() {
2545 let memory = test_semantic_memory(false).await;
2546 let cid = memory.sqlite.create_conversation().await.unwrap();
2547
2548 for i in 0..10 {
2549 memory
2550 .remember(cid, "user", &format!("test message number {i}"))
2551 .await
2552 .unwrap();
2553 }
2554
2555 let recalled = memory.recall("test", 3, None).await.unwrap();
2556 assert_eq!(recalled.len(), 3);
2557 }
2558
2559 #[tokio::test]
2562 async fn summarize_fallback_to_plain_text_when_structured_fails() {
2563 let sqlite = SqliteStore::new(":memory:").await.unwrap();
2571 let mut mock = MockProvider::default();
2572 mock.default_response = "plain text summary".into();
2574 let provider = AnyProvider::Mock(mock);
2575
2576 let memory = SemanticMemory {
2577 sqlite,
2578 qdrant: None,
2579 provider,
2580 embedding_model: "test".into(),
2581 vector_weight: 0.7,
2582 keyword_weight: 0.3,
2583 temporal_decay_enabled: false,
2584 temporal_decay_half_life_days: 30,
2585 mmr_enabled: false,
2586 mmr_lambda: 0.7,
2587 token_counter: Arc::new(TokenCounter::new()),
2588 graph_store: None,
2589 community_detection_failures: Arc::new(AtomicU64::new(0)),
2590 graph_extraction_count: Arc::new(AtomicU64::new(0)),
2591 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
2592 };
2593
2594 let cid = memory.sqlite().create_conversation().await.unwrap();
2595 for i in 0..5 {
2596 memory
2597 .remember(cid, "user", &format!("msg {i}"))
2598 .await
2599 .unwrap();
2600 }
2601
2602 let result = memory.summarize(cid, 3).await;
2603 assert!(result.is_ok());
2609 let summaries = memory.load_summaries(cid).await.unwrap();
2610 assert_eq!(summaries.len(), 1);
2611 assert!(!summaries[0].content.is_empty());
2612 }
2613
    #[test]
    fn temporal_decay_disabled_leaves_scores_unchanged() {
        // NOTE(review): a 30-day half-life IS passed here; the scores survive
        // because the timestamp map is empty, so presumably no entry has an
        // age to decay by — the "disabled" in the name refers to the effect,
        // not the half-life argument. Confirm against apply_temporal_decay.
        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
        let timestamps = std::collections::HashMap::new();
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
    }
2624
    #[test]
    fn temporal_decay_zero_age_preserves_score() {
        // A message timestamped "now" has zero age, so its score must be
        // preserved (within a loose floating-point tolerance).
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), now);
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!((ranked[0].1 - 1.0).abs() < 0.01);
    }
2639
    #[test]
    fn temporal_decay_half_life_halves_score() {
        // A message exactly one half-life old (30 days) must decay to ~0.5.
        let half_life = 30u32;
        let age_secs = i64::from(half_life) * 86400;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let ts = now - age_secs;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, half_life);
        assert!(
            (ranked[0].1 - 0.5).abs() < 0.01,
            "score was {}",
            ranked[0].1
        );
    }
2662
2663 #[test]
2666 fn mmr_empty_input_returns_empty() {
2667 let ranked = vec![];
2668 let vectors = std::collections::HashMap::new();
2669 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2670 assert!(result.is_empty());
2671 }
2672
2673 #[test]
2674 fn mmr_returns_up_to_limit() {
2675 let ranked = vec![
2676 (MessageId(1), 1.0f64),
2677 (MessageId(2), 0.9f64),
2678 (MessageId(3), 0.8f64),
2679 ];
2680 let mut vectors = std::collections::HashMap::new();
2681 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2682 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2683 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2684 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2685 assert_eq!(result.len(), 2);
2686 }
2687
2688 #[test]
2689 fn mmr_without_vectors_picks_by_relevance() {
2690 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2691 let vectors = std::collections::HashMap::new();
2692 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2693 assert_eq!(result.len(), 2);
2694 assert_eq!(result[0].0, MessageId(1));
2695 }
2696
2697 #[test]
2698 fn mmr_prefers_diverse_over_redundant() {
2699 let ranked = vec![
2701 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2705 let mut vectors = std::collections::HashMap::new();
2706 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2707 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2710 assert_eq!(result.len(), 2);
2711 assert_eq!(result[0].0, MessageId(1));
2712 assert_eq!(result[1].0, MessageId(2));
2714 }
2715
    #[test]
    fn temporal_decay_half_life_zero_is_noop() {
        // half_life_days == 0 disables decay entirely (early return in
        // apply_temporal_decay), so even a 30-day-old entry keeps its score.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let age_secs = 30i64 * 86400;
        let ts = now - age_secs;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, 0);
        assert!(
            (ranked[0].1 - 1.0).abs() < f64::EPSILON,
            "score was {}",
            ranked[0].1
        );
    }
2736
    #[test]
    fn temporal_decay_huge_age_near_zero() {
        // Ten years of age against a 30-day half-life is ~121 half-lives:
        // the decayed score must be vanishingly small.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let age_secs = 3650i64 * 86400;
        let ts = now - age_secs;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
    }
2754
    #[test]
    fn temporal_decay_small_half_life() {
        // Seven days of age with a one-day half-life: 2^-7 ≈ 0.0078 < 0.01.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let ts = now - 7 * 86400i64;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, 1);
        assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
    }
2770
2771 #[test]
2772 fn mmr_lambda_zero_max_diversity() {
2773 let ranked = vec![
2775 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2779 let mut vectors = std::collections::HashMap::new();
2780 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2781 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2784 assert_eq!(result.len(), 3);
2785 assert_eq!(result[1].0, MessageId(2));
2787 }
2788
2789 #[test]
2790 fn mmr_lambda_one_pure_relevance() {
2791 let ranked = vec![
2793 (MessageId(1), 1.0f64),
2794 (MessageId(2), 0.8f64),
2795 (MessageId(3), 0.6f64),
2796 ];
2797 let mut vectors = std::collections::HashMap::new();
2798 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2799 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2800 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2801 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2802 assert_eq!(result.len(), 3);
2803 assert_eq!(result[0].0, MessageId(1));
2804 assert_eq!(result[1].0, MessageId(2));
2805 assert_eq!(result[2].0, MessageId(3));
2806 }
2807
2808 #[test]
2809 fn mmr_limit_zero_returns_empty() {
2810 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2811 let mut vectors = std::collections::HashMap::new();
2812 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2813 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2814 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2815 assert!(result.is_empty());
2816 }
2817
2818 #[test]
2819 fn mmr_duplicate_vectors_penalizes_second() {
2820 let ranked = vec![
2822 (MessageId(1), 1.0f64),
2823 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2826 let mut vectors = std::collections::HashMap::new();
2827 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2828 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2831 assert_eq!(result.len(), 3);
2832 assert_eq!(result[0].0, MessageId(1));
2833 assert_eq!(result[1].0, MessageId(3));
2835 }
2836
    #[tokio::test]
    async fn recall_routed_keyword_route_returns_fts5_results() {
        use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};

        // "rust_guide" routes to keyword search, served by FTS5 when no
        // vector store is available.
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory
            .remember(cid, "user", "rust programming guide")
            .await
            .unwrap();
        memory
            .remember(cid, "assistant", "python tutorial")
            .await
            .unwrap();

        // Sanity-check the routing decision before exercising recall.
        let router = HeuristicRouter;
        assert_eq!(router.route("rust_guide"), MemoryRoute::Keyword);

        let recalled = memory
            .recall_routed("rust_guide", 5, None, &router)
            .await
            .unwrap();
        // Only an upper bound is asserted: FTS5 tokenization of the
        // underscore query may or may not match the stored messages.
        assert!(recalled.len() <= 2);
    }
2866
    #[tokio::test]
    async fn recall_routed_semantic_route_without_qdrant_returns_empty_vectors() {
        use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};

        // A natural-language question routes to semantic search; with no
        // vector store the semantic path has nothing to search and returns
        // an empty result set rather than falling back.
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory
            .remember(cid, "user", "how does the agent loop work")
            .await
            .unwrap();

        // Sanity-check the routing decision before exercising recall.
        let router = HeuristicRouter;
        assert_eq!(
            router.route("how does the agent loop work"),
            MemoryRoute::Semantic
        );

        let recalled = memory
            .recall_routed("how does the agent loop work", 5, None, &router)
            .await
            .unwrap();
        assert!(recalled.is_empty(), "no Qdrant → empty semantic recall");
    }
2893
    #[tokio::test]
    async fn recall_routed_hybrid_route_falls_back_to_fts5_on_no_qdrant() {
        use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};

        // A hybrid-routed query must still return results without a vector
        // store, because the keyword (FTS5) half of the hybrid path works.
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory
            .remember(cid, "user", "context window token budget")
            .await
            .unwrap();

        // Sanity-check the routing decision before exercising recall.
        let router = HeuristicRouter;
        assert_eq!(
            router.route("context window token budget"),
            MemoryRoute::Hybrid
        );

        let recalled = memory
            .recall_routed("context window token budget", 5, None, &router)
            .await
            .unwrap();
        assert!(!recalled.is_empty(), "FTS5 should find the stored message");
    }
2921
    /// Tests for the optional knowledge-graph layer: entity/edge storage,
    /// multi-hop recall, extraction bookkeeping, and background extraction.
    mod graph_extraction_tests {
        use super::*;
        use crate::graph::{EntityType, GraphStore};

        /// Builds a `SemanticMemory` (no embeddings) wired to a `GraphStore`
        /// that shares the same SQLite pool.
        async fn graph_memory() -> SemanticMemory {
            let mem = test_semantic_memory(false).await;
            let store = std::sync::Arc::new(GraphStore::new(mem.sqlite.pool().clone()));
            mem.with_graph_store(store)
        }

        #[tokio::test]
        async fn recall_graph_returns_empty_when_no_entities() {
            // Querying an empty graph must not error — just no facts.
            let memory = graph_memory().await;
            let facts = memory.recall_graph("rust", 10, 2).await.unwrap();
            assert!(facts.is_empty(), "empty graph must return empty vec");
        }

        #[tokio::test]
        async fn recall_graph_returns_facts_for_known_entity() {
            // One edge rust -uses-> tokio; querying "rust" must surface it.
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let rust_id = store
                .upsert_entity("rust", "rust", EntityType::Language, Some("a language"))
                .await
                .unwrap();
            let tokio_id = store
                .upsert_entity("tokio", "tokio", EntityType::Tool, Some("async runtime"))
                .await
                .unwrap();
            store
                .insert_edge(
                    rust_id,
                    tokio_id,
                    "uses",
                    "Rust uses tokio for async",
                    0.9,
                    None,
                )
                .await
                .unwrap();

            let facts = memory.recall_graph("rust", 10, 2).await.unwrap();
            assert!(!facts.is_empty(), "should return at least one fact");
            assert_eq!(facts[0].entity_name, "rust");
            assert_eq!(facts[0].relation, "uses");
        }

        #[tokio::test]
        async fn recall_graph_sorted_by_composite_score() {
            // Two edges with different weights (0.9 vs 0.5) from the same
            // source entity: results must come back best-first.
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let a_id = store
                .upsert_entity("entity_a", "entity_a", EntityType::Concept, None)
                .await
                .unwrap();
            let b_id = store
                .upsert_entity("entity_b", "entity_b", EntityType::Concept, None)
                .await
                .unwrap();
            let c_id = store
                .upsert_entity("entity_c", "entity_c", EntityType::Concept, None)
                .await
                .unwrap();
            store
                .insert_edge(a_id, b_id, "relates", "a relates b", 0.9, None)
                .await
                .unwrap();
            store
                .insert_edge(a_id, c_id, "relates", "a relates c", 0.5, None)
                .await
                .unwrap();

            let facts = memory.recall_graph("entity_a", 10, 1).await.unwrap();
            // Ordering is only checkable when both facts are returned.
            if facts.len() >= 2 {
                assert!(
                    facts[0].composite_score() >= facts[1].composite_score(),
                    "facts must be sorted descending by composite score"
                );
            }
        }

        #[tokio::test]
        async fn extract_and_store_returns_zero_stats_for_empty_content() {
            // Empty input must short-circuit: nothing upserted or inserted.
            let memory = graph_memory().await;
            let pool = memory.sqlite.pool().clone();
            let provider = test_provider();

            let stats = extract_and_store(
                String::new(),
                vec![],
                provider,
                pool,
                GraphExtractionConfig {
                    max_entities: 10,
                    max_edges: 10,
                    extraction_timeout_secs: 5,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
            assert_eq!(stats.entities_upserted, 0);
            assert_eq!(stats.edges_inserted, 0);
        }

        #[tokio::test]
        async fn extraction_count_increments_atomically() {
            // Two extraction attempts must bump the persisted counter to
            // exactly "2" (results of the extraction itself are ignored).
            let memory = graph_memory().await;
            let pool = memory.sqlite.pool().clone();
            let provider = test_provider();

            for _ in 0..2 {
                let _ = extract_and_store(
                    "I use Rust for systems programming".to_owned(),
                    vec![],
                    provider.clone(),
                    pool.clone(),
                    GraphExtractionConfig {
                        max_entities: 5,
                        max_edges: 5,
                        extraction_timeout_secs: 5,
                        ..Default::default()
                    },
                )
                .await;
            }

            let store = GraphStore::new(pool);
            let count = store.get_metadata("extraction_count").await.unwrap();
            assert_eq!(
                count.as_deref(),
                Some("2"),
                "extraction_count must be exactly 2 after two extraction attempts"
            );
        }

        #[tokio::test]
        async fn recall_graph_truncates_to_limit() {
            // Five outgoing edges but a limit of three: at most three facts.
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let root_id = store
                .upsert_entity("root", "root", EntityType::Concept, None)
                .await
                .unwrap();
            for i in 0..5 {
                let name = format!("target_{i}");
                let tid = store
                    .upsert_entity(&name, &name, EntityType::Concept, None)
                    .await
                    .unwrap();
                store
                    .insert_edge(
                        root_id,
                        tid,
                        "links",
                        &format!("root links {name}"),
                        0.7,
                        None,
                    )
                    .await
                    .unwrap();
            }

            let facts = memory.recall_graph("root", 3, 1).await.unwrap();
            assert!(facts.len() <= 3, "recall_graph must respect limit");
        }

        #[tokio::test]
        async fn recall_graph_multi_hop_traverses_two_hops() {
            // Chain a -> b -> c: hop depth 1 sees only the direct edge,
            // depth 2 must also reach the b-c edge via traversal.
            let memory = graph_memory().await;
            let store = GraphStore::new(memory.sqlite.pool().clone());

            let a_id = store
                .upsert_entity("a_entity", "a_entity", EntityType::Person, None)
                .await
                .unwrap();
            let b_id = store
                .upsert_entity("b_entity", "b_entity", EntityType::Person, None)
                .await
                .unwrap();
            let c_id = store
                .upsert_entity("c_entity", "c_entity", EntityType::Concept, None)
                .await
                .unwrap();

            store
                .insert_edge(a_id, b_id, "knows", "a knows b", 0.9, None)
                .await
                .unwrap();
            store
                .insert_edge(b_id, c_id, "uses", "b uses c", 0.8, None)
                .await
                .unwrap();

            let facts_1hop = memory.recall_graph("a_entity", 10, 1).await.unwrap();
            assert!(!facts_1hop.is_empty(), "hop=1 must find direct edge");

            let facts_2hop = memory.recall_graph("a_entity", 10, 2).await.unwrap();
            assert!(
                facts_2hop.len() >= facts_1hop.len(),
                "hop=2 must find at least as many facts as hop=1"
            );
            // The second hop's fact must connect b_entity and c_entity.
            let has_bc = facts_2hop.iter().any(|f| {
                (f.entity_name.contains("b_entity") || f.target_name.contains("b_entity"))
                    && (f.entity_name.contains("c_entity") || f.target_name.contains("c_entity"))
            });
            assert!(has_bc, "hop=2 BFS must traverse to c_entity via b_entity");
        }

        #[tokio::test]
        async fn spawn_graph_extraction_zero_timeout_returns_without_panic() {
            // A zero-second timeout must not panic the background task; the
            // sleep gives the spawned extraction a chance to run (and fail
            // gracefully) before the test ends.
            let memory = graph_memory().await;
            let cfg = GraphExtractionConfig {
                max_entities: 5,
                max_edges: 5,
                extraction_timeout_secs: 0,
                ..Default::default()
            };
            memory.spawn_graph_extraction(
                "I use Rust for systems programming".to_owned(),
                vec![],
                cfg,
            );
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
    }
3165
3166 use proptest::prelude::*;
3169
    proptest! {
        // Property: token counting is total — it must never panic for any
        // input string, including empty and multi-byte sequences.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
3177}