1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
/// Vector-store collection holding per-conversation session summaries.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Vector-store collection holding key facts extracted during summarization.
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Vector-store collection holding embeddings of user corrections.
const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
19
/// Structured output requested from the LLM during summarization.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Prose summary of the summarized conversation window.
    pub summary: String,
    /// Discrete facts worth persisting for later semantic lookup.
    pub key_facts: Vec<String>,
    /// Named entities mentioned in the conversation.
    pub entities: Vec<String>,
}
26
/// A message returned by `recall`, paired with its blended relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    /// Hybrid (vector + keyword) score; higher is more relevant.
    pub score: f32,
}
32
/// A conversation-summary row loaded from SQLite.
#[derive(Debug, Clone)]
pub struct Summary {
    pub id: i64,
    pub conversation_id: ConversationId,
    pub content: String,
    /// First message covered by this summary.
    pub first_message_id: MessageId,
    /// Last message covered by this summary.
    pub last_message_id: MessageId,
    /// Approximate token count of `content`.
    pub token_estimate: i64,
}
42
/// A hit from the cross-conversation session-summaries collection.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    pub summary_text: String,
    /// Similarity score reported by the vector store.
    pub score: f32,
    pub conversation_id: ConversationId,
}
49
/// Cosine similarity of two vectors; returns 0.0 for mismatched lengths,
/// empty inputs, or a zero-magnitude vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Dimension mismatch or empty input has no meaningful similarity.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared magnitudes.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    // A zero vector has undefined direction; report no similarity.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
62
63fn apply_temporal_decay(
64 ranked: &mut [(MessageId, f64)],
65 timestamps: &std::collections::HashMap<MessageId, i64>,
66 half_life_days: u32,
67) {
68 if half_life_days == 0 {
69 return;
70 }
71 let now = std::time::SystemTime::now()
72 .duration_since(std::time::UNIX_EPOCH)
73 .unwrap_or_default()
74 .as_secs()
75 .cast_signed();
76 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
77
78 for (msg_id, score) in ranked.iter_mut() {
79 if let Some(&ts) = timestamps.get(msg_id) {
80 #[allow(clippy::cast_precision_loss)]
81 let age_days = (now - ts).max(0) as f64 / 86400.0;
82 *score *= (-lambda * age_days).exp();
83 }
84 }
85}
86
/// Re-ranks candidates with Maximal Marginal Relevance: greedily selects up
/// to `limit` items, each time picking the one that best trades relevance
/// against similarity to what has already been selected.
///
/// `lambda` weights relevance vs. diversity (1.0 = pure relevance). Items
/// whose vectors are missing from `vectors` are treated as maximally diverse.
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // First pick: `ranked` is already relevance-sorted by the caller,
            // so the head is the most relevant candidate.
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Highest similarity between this candidate and any selected item.
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    // No vector for the candidate: treat as no overlap.
                    0.0
                };
                // NEG_INFINITY means no selected item had a vector either.
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                // Strict `>` keeps the earliest (most relevant) item on ties.
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
138
139fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
140 let mut prompt = String::from(
141 "Summarize the following conversation. Extract key facts, decisions, entities, \
142 and context needed to continue the conversation.\n\n\
143 Respond in JSON with fields: summary (string), key_facts (list of strings), \
144 entities (list of strings).\n\nConversation:\n",
145 );
146
147 for (_, role, content) in messages {
148 prompt.push_str(role);
149 prompt.push_str(": ");
150 prompt.push_str(content);
151 prompt.push('\n');
152 }
153
154 prompt
155}
156
/// Hybrid conversation memory: durable message storage in SQLite plus an
/// optional embedding store (Qdrant or SQLite-backed) for semantic recall.
pub struct SemanticMemory {
    sqlite: SqliteStore,
    // `None` when the vector store is unavailable; semantic features degrade
    // to no-ops rather than failing.
    qdrant: Option<EmbeddingStore>,
    provider: AnyProvider,
    embedding_model: String,
    // Relative weights used by `recall` when blending vector and keyword scores.
    vector_weight: f64,
    keyword_weight: f64,
    temporal_decay_enabled: bool,
    temporal_decay_half_life_days: u32,
    mmr_enabled: bool,
    mmr_lambda: f32,
    pub token_counter: Arc<TokenCounter>,
}
170
171impl SemanticMemory {
    /// Creates a memory backed by SQLite at `sqlite_path` and Qdrant at
    /// `qdrant_url`, using the default 0.7 vector / 0.3 keyword weighting.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
187
    /// Like [`Self::new`] but with explicit vector/keyword blend weights;
    /// uses the default SQLite pool size of 5.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
212
    /// Fully-configurable Qdrant-backed constructor.
    ///
    /// A Qdrant connection failure is non-fatal: it is logged and the memory
    /// runs with semantic search disabled. Ranking extras (temporal decay,
    /// MMR) start disabled; see [`Self::with_ranking_options`].
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_weights_and_pool_size(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();

        // Degrade gracefully: no vector store means keyword-only recall.
        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
            Ok(store) => Some(store),
            Err(e) => {
                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
                None
            }
        };

        Ok(Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        })
    }
252
    /// Builder-style setter for recall ranking extras: temporal decay
    /// (half-life in days) and MMR diversification (`mmr_lambda` trades
    /// relevance against diversity; 1.0 = pure relevance).
    #[must_use]
    pub fn with_ranking_options(
        mut self,
        temporal_decay_enabled: bool,
        temporal_decay_half_life_days: u32,
        mmr_enabled: bool,
        mmr_lambda: f32,
    ) -> Self {
        self.temporal_decay_enabled = temporal_decay_enabled;
        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
        self.mmr_enabled = mmr_enabled;
        self.mmr_lambda = mmr_lambda;
        self
    }
268
    /// Assembles a memory from pre-built parts, skipping all connection
    /// setup. Intended for tests and the `mock` feature. Ranking extras
    /// start at their defaults (decay and MMR disabled).
    #[cfg(any(test, feature = "mock"))]
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<EmbeddingStore>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter,
        }
    }
297
    /// Creates a memory that stores embeddings in SQLite itself (no Qdrant),
    /// with the default pool size of 5.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
320
    /// SQLite-only constructor with explicit pool size: the same database
    /// backs both message storage and the embedding store, so the vector
    /// store is always present (no external service to fail).
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(store),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        })
    }
352
    /// Persists a message to SQLite and, when a vector store and an
    /// embedding-capable provider are available, stores its embedding.
    /// Embedding problems are logged and never fail the write.
    ///
    /// # Errors
    /// Returns an error only if the SQLite write fails.
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> Result<MessageId, MemoryError> {
        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Collection is sized from the actual embedding; 896 is a fallback.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok(message_id)
    }
403
    /// Persists a message together with its serialized parts JSON, then
    /// best-effort embeds it. Returns the new message id and whether the
    /// embedding was actually stored.
    ///
    /// # Errors
    /// Returns an error only if the SQLite write fails; embedding problems
    /// are logged and reported via the returned flag instead.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Collection is sized from the actual embedding; 896 is a fallback.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
458
    /// Persists a message (with its serialized parts JSON) to SQLite without
    /// attempting to embed it; for callers that manage embeddings themselves.
    ///
    /// # Errors
    /// Propagates SQLite failures.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
477
    /// Hybrid semantic + keyword recall over stored messages.
    ///
    /// Runs FTS5 keyword search and (when available) vector search, each over
    /// `limit * 2` candidates, normalizes each score set by its maximum, and
    /// blends them with `vector_weight` / `keyword_weight`. Optional temporal
    /// decay and MMR re-ranking are applied before truncating to `limit` and
    /// hydrating the surviving ids back into full messages.
    ///
    /// # Errors
    /// Returns an error if embedding the query, the vector search, or message
    /// hydration fails; keyword-search failures are logged and treated as an
    /// empty result set.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // Best-effort keyword search: an FTS5 syntax error shouldn't kill recall.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Blend both result sets into a single score per message id.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Non-positive maxima fall back to 1.0 so normalization never divides by zero.
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Decay changes relative order; re-sort before truncation/MMR.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // MMR needs the candidate vectors; fall back to plain truncation when
        // they can't be fetched.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate the winning ids back into full messages, preserving rank order.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
616
    /// Returns whether an embedding exists for `message_id`; always `false`
    /// when no vector store is configured.
    ///
    /// # Errors
    /// Propagates vector-store lookup failures.
    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
        match &self.qdrant {
            Some(qdrant) => qdrant.has_embedding(message_id).await,
            None => Ok(false),
        }
    }
628
    /// Backfills embeddings for messages saved without one, up to 1000 per
    /// call. Returns how many embeddings were stored; 0 when no vector store
    /// or embedding support is available.
    ///
    /// # Errors
    /// Returns an error if listing unembedded messages, the probe embedding,
    /// or collection creation fails. Per-message failures are logged and skipped.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // A probe embedding determines the vector size for the collection.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
684
    /// Embeds `summary_text` and stores it in the shared session-summaries
    /// collection for cross-conversation recall. No-op when the vector store
    /// or embedding support is missing.
    ///
    /// # Errors
    /// Propagates embedding and vector-store failures.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        // Collection is sized from the actual embedding; 896 is a fallback.
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
723
    /// Semantic search over stored session summaries, optionally excluding a
    /// conversation (typically the caller's current one).
    ///
    /// Returns an empty list when the vector store or embedding support is
    /// missing. Points lacking the expected payload fields are skipped.
    ///
    /// # Errors
    /// Propagates embedding and vector-store failures.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Exclusion is expressed as a must_not filter on conversation_id.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
776
    /// Direct access to the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
782
    /// Actively probes the vector store's health; `false` when no store is
    /// configured or the health check fails.
    pub async fn is_vector_store_connected(&self) -> bool {
        match self.qdrant.as_ref() {
            Some(store) => store.health_check().await,
            None => false,
        }
    }
793
    /// Whether a vector store was configured (does not check connectivity;
    /// see [`Self::is_vector_store_connected`]).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
799
    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Propagates SQLite failures.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
808
    /// Number of messages newer than the latest summary's last covered
    /// message (all messages when nothing has been summarized yet).
    ///
    /// # Errors
    /// Propagates SQLite failures.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: ConversationId,
    ) -> Result<i64, MemoryError> {
        // MessageId(0) sorts before any real id, so "no summary yet" counts everything.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
827
    /// Loads all summaries for a conversation, mapping the raw SQLite tuples
    /// into [`Summary`] values.
    ///
    /// # Errors
    /// Propagates SQLite failures.
    pub async fn load_summaries(
        &self,
        conversation_id: ConversationId,
    ) -> Result<Vec<Summary>, MemoryError> {
        let rows = self.sqlite.load_summaries(conversation_id).await?;
        let summaries = rows
            .into_iter()
            .map(
                |(
                    id,
                    conversation_id,
                    content,
                    first_message_id,
                    last_message_id,
                    token_estimate,
                )| {
                    Summary {
                        id,
                        conversation_id,
                        content,
                        first_message_id,
                        last_message_id,
                        token_estimate,
                    }
                },
            )
            .collect();
        Ok(summaries)
    }
862
    /// Summarizes the oldest unsummarized messages once the conversation
    /// holds more than `message_count` messages.
    ///
    /// Requests a [`StructuredSummary`] from the provider (falling back to a
    /// plain-text summary on failure), persists the summary row, best-effort
    /// embeds it as a [`MessageKind::Summary`] point, and stores any extracted
    /// key facts. Returns the new summary row id, or `None` when below the
    /// threshold or nothing is left to summarize.
    ///
    /// # Errors
    /// Returns an error if SQLite access, the fallback chat call, or token
    /// counting fails. Embedding failures are logged, not propagated.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message (0 = nothing summarized yet).
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer structured output; degrade to plain text with no facts/entities.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // `messages` is non-empty here, so indexing the ends is safe.
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself so recall can surface it.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
976
    /// Best-effort persistence of extracted key facts into the key-facts
    /// collection; every failure is logged and swallowed (the summary that
    /// produced these facts has already been saved).
    async fn store_key_facts(
        &self,
        conversation_id: ConversationId,
        source_summary_id: i64,
        key_facts: &[String],
    ) {
        let Some(qdrant) = &self.qdrant else {
            return;
        };
        if !self.provider.supports_embeddings() {
            return;
        }

        let Some(first_fact) = key_facts.first() else {
            return;
        };
        // The first fact is embedded up front so the collection can be created
        // with the right vector size; if this embed fails the batch is dropped.
        let first_vector = match self.provider.embed(first_fact).await {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("Failed to embed key fact: {e:#}");
                return;
            }
        };
        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
        if let Err(e) = qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await
        {
            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
            return;
        }

        let first_payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "fact_text": first_fact,
            "source_summary_id": source_summary_id,
        });
        if let Err(e) = qdrant
            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
            .await
        {
            tracing::warn!("Failed to store key fact: {e:#}");
        }

        // Remaining facts are embedded and stored independently; one failure
        // doesn't stop the rest.
        for fact in &key_facts[1..] {
            match self.provider.embed(fact).await {
                Ok(vector) => {
                    let payload = serde_json::json!({
                        "conversation_id": conversation_id.0,
                        "fact_text": fact,
                        "source_summary_id": source_summary_id,
                    });
                    if let Err(e) = qdrant
                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
                        .await
                    {
                        tracing::warn!("Failed to store key fact: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to embed key fact: {e:#}");
                }
            }
        }
    }
1042
    /// Semantic search over stored key facts; returns the matching fact
    /// texts. Empty when the vector store or embedding support is missing.
    ///
    /// # Errors
    /// Propagates embedding and vector-store failures.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        // Points without a string "fact_text" payload field are skipped.
        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }
1077
    /// Semantic search over an arbitrary, externally-managed document
    /// collection. Returns empty results when the vector store or embedding
    /// support is missing, or when the collection does not exist (this method
    /// never creates collections).
    ///
    /// # Errors
    /// Propagates embedding and vector-store failures.
    pub async fn search_document_collection(
        &self,
        collection: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        if !qdrant.collection_exists(collection).await? {
            return Ok(Vec::new());
        }
        let vector = self.provider.embed(query).await?;
        qdrant
            .search_collection(collection, &vector, limit, None)
            .await
    }
1107
1108 pub async fn store_correction_embedding(
1116 &self,
1117 correction_id: i64,
1118 correction_text: &str,
1119 ) -> Result<(), MemoryError> {
1120 let Some(ref store) = self.qdrant else {
1121 return Ok(());
1122 };
1123 if !self.provider.supports_embeddings() {
1124 return Ok(());
1125 }
1126 let embedding = self
1127 .provider
1128 .embed(correction_text)
1129 .await
1130 .map_err(|e| MemoryError::Other(e.to_string()))?;
1131 let payload = serde_json::json!({ "correction_id": correction_id });
1132 store
1133 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
1134 .await?;
1135 Ok(())
1136 }
1137
    /// Finds stored corrections semantically similar to `query`, keeping only
    /// hits scoring at or above `min_score`, and hydrates them from SQLite.
    ///
    /// Vector-search failures are treated as "no matches" (best-effort);
    /// SQLite hydration failures are propagated.
    ///
    /// # Errors
    /// Returns an error if embedding the query or loading correction rows fails.
    pub async fn retrieve_similar_corrections(
        &self,
        query: &str,
        limit: usize,
        min_score: f32,
    ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
        let Some(ref store) = self.qdrant else {
            return Ok(vec![]);
        };
        if !self.provider.supports_embeddings() {
            return Ok(vec![]);
        }
        let embedding = self
            .provider
            .embed(query)
            .await
            .map_err(|e| MemoryError::Other(e.to_string()))?;
        // Deliberate best-effort: a failed search yields an empty result set.
        let scored = store
            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
            .await
            .unwrap_or_default();

        let mut results = Vec::new();
        for point in scored {
            if point.score < min_score {
                continue;
            }
            // Points without an integer "correction_id" payload are skipped.
            if let Some(id_val) = point.payload.get("correction_id")
                && let Some(id) = id_val.as_i64()
            {
                let rows = self.sqlite.load_corrections_for_id(id).await?;
                results.extend(rows);
            }
        }
        Ok(results)
    }
1182}
1183
1184#[cfg(test)]
1185mod tests {
1186 use zeph_llm::mock::MockProvider;
1187 use zeph_llm::provider::Role;
1188
1189 use super::*;
1190
    /// A mock LLM provider so tests never touch the network.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Builds an in-memory `SemanticMemory` with no vector store attached.
    // NOTE(review): the `_supports_embeddings` flag is currently ignored;
    // embedding capability comes from `MockProvider` itself — confirm intent.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        }
    }

    // `remember` round-trips through SQLite even without a vector store.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }

    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }

    // Without a vector store, semantic operations degrade to empty no-ops.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }

    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // The `sqlite()` accessor exposes the working store.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }

    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }

    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }

    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // Message counting and summarization bookkeeping.
    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        // Summarization covers older messages but never deletes them.
        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Summaries come back in insertion order with their content intact.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }

    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }
1410
1411 #[tokio::test]
1412 async fn summarize_stores_summary() {
1413 let memory = test_semantic_memory(false).await;
1414 let cid = memory.sqlite().create_conversation().await.unwrap();
1415
1416 for i in 0..5 {
1417 memory
1418 .remember(cid, "user", &format!("message {i}"))
1419 .await
1420 .unwrap();
1421 }
1422
1423 let summary_id = memory.summarize(cid, 3).await.unwrap();
1424 assert!(summary_id.is_some());
1425
1426 let summaries = memory.load_summaries(cid).await.unwrap();
1427 assert_eq!(summaries.len(), 1);
1428 assert_eq!(summaries[0].id, summary_id.unwrap());
1429 assert!(!summaries[0].content.is_empty());
1430 }
1431
1432 #[tokio::test]
1433 async fn summarize_respects_previous_summaries() {
1434 let memory = test_semantic_memory(false).await;
1435 let cid = memory.sqlite().create_conversation().await.unwrap();
1436
1437 for i in 0..10 {
1438 memory
1439 .remember(cid, "user", &format!("message {i}"))
1440 .await
1441 .unwrap();
1442 }
1443
1444 let s1 = memory.summarize(cid, 3).await.unwrap();
1445 assert!(s1.is_some());
1446
1447 let s2 = memory.summarize(cid, 3).await.unwrap();
1448 assert!(s2.is_some());
1449
1450 let summaries = memory.load_summaries(cid).await.unwrap();
1451 assert_eq!(summaries.len(), 2);
1452 assert!(summaries[0].last_message_id < summaries[1].first_message_id);
1453 }
1454
1455 #[tokio::test]
1456 async fn remember_multiple_messages_increments_ids() {
1457 let memory = test_semantic_memory(false).await;
1458 let cid = memory.sqlite.create_conversation().await.unwrap();
1459
1460 let id1 = memory.remember(cid, "user", "first").await.unwrap();
1461 let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1462 let id3 = memory.remember(cid, "user", "third").await.unwrap();
1463
1464 assert!(id1 < id2);
1465 assert!(id2 < id3);
1466 }
1467
1468 #[tokio::test]
1469 async fn message_count_across_conversations() {
1470 let memory = test_semantic_memory(false).await;
1471 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1472 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1473
1474 memory.remember(cid1, "user", "msg1").await.unwrap();
1475 memory.remember(cid1, "user", "msg2").await.unwrap();
1476 memory.remember(cid2, "user", "msg3").await.unwrap();
1477
1478 assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
1479 assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
1480 }
1481
1482 #[tokio::test]
1483 async fn summarize_exact_threshold_returns_none() {
1484 let memory = test_semantic_memory(false).await;
1485 let cid = memory.sqlite().create_conversation().await.unwrap();
1486
1487 for i in 0..3 {
1488 memory
1489 .remember(cid, "user", &format!("msg {i}"))
1490 .await
1491 .unwrap();
1492 }
1493
1494 let result = memory.summarize(cid, 3).await.unwrap();
1495 assert!(result.is_none());
1496 }
1497
1498 #[tokio::test]
1499 async fn summarize_one_above_threshold_produces_summary() {
1500 let memory = test_semantic_memory(false).await;
1501 let cid = memory.sqlite().create_conversation().await.unwrap();
1502
1503 for i in 0..4 {
1504 memory
1505 .remember(cid, "user", &format!("msg {i}"))
1506 .await
1507 .unwrap();
1508 }
1509
1510 let result = memory.summarize(cid, 3).await.unwrap();
1511 assert!(result.is_some());
1512 }
1513
1514 #[tokio::test]
1515 async fn summary_fields_populated() {
1516 let memory = test_semantic_memory(false).await;
1517 let cid = memory.sqlite().create_conversation().await.unwrap();
1518
1519 for i in 0..5 {
1520 memory
1521 .remember(cid, "user", &format!("msg {i}"))
1522 .await
1523 .unwrap();
1524 }
1525
1526 memory.summarize(cid, 3).await.unwrap();
1527 let summaries = memory.load_summaries(cid).await.unwrap();
1528 let s = &summaries[0];
1529
1530 assert_eq!(s.conversation_id, cid);
1531 assert!(s.first_message_id > MessageId(0));
1532 assert!(s.last_message_id >= s.first_message_id);
1533 assert!(s.token_estimate >= 0);
1534 assert!(!s.content.is_empty());
1535 }
1536
1537 #[test]
1538 fn build_summarization_prompt_format() {
1539 let messages = vec![
1540 (MessageId(1), "user".into(), "Hello".into()),
1541 (MessageId(2), "assistant".into(), "Hi there".into()),
1542 ];
1543 let prompt = build_summarization_prompt(&messages);
1544 assert!(prompt.contains("user: Hello"));
1545 assert!(prompt.contains("assistant: Hi there"));
1546 assert!(prompt.contains("key_facts"));
1547 }
1548
1549 #[test]
1550 fn build_summarization_prompt_empty() {
1551 let messages: Vec<(MessageId, String, String)> = vec![];
1552 let prompt = build_summarization_prompt(&messages);
1553 assert!(prompt.contains("key_facts"));
1554 }
1555
1556 #[test]
1557 fn structured_summary_deserialize() {
1558 let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
1559 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1560 assert_eq!(ss.summary, "s");
1561 assert_eq!(ss.key_facts.len(), 2);
1562 assert_eq!(ss.entities.len(), 1);
1563 }
1564
1565 #[test]
1566 fn structured_summary_empty_facts() {
1567 let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
1568 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1569 assert!(ss.key_facts.is_empty());
1570 assert!(ss.entities.is_empty());
1571 }
1572
1573 #[tokio::test]
1574 async fn search_key_facts_no_qdrant_empty() {
1575 let memory = test_semantic_memory(false).await;
1576 let facts = memory.search_key_facts("query", 5).await.unwrap();
1577 assert!(facts.is_empty());
1578 }
1579
1580 #[test]
1581 fn recalled_message_debug() {
1582 let recalled = RecalledMessage {
1583 message: Message {
1584 role: Role::User,
1585 content: "test".into(),
1586 parts: vec![],
1587 metadata: MessageMetadata::default(),
1588 },
1589 score: 0.95,
1590 };
1591 let dbg = format!("{recalled:?}");
1592 assert!(dbg.contains("RecalledMessage"));
1593 assert!(dbg.contains("0.95"));
1594 }
1595
1596 #[test]
1597 fn summary_clone() {
1598 let summary = Summary {
1599 id: 1,
1600 conversation_id: ConversationId(2),
1601 content: "test summary".into(),
1602 first_message_id: MessageId(1),
1603 last_message_id: MessageId(5),
1604 token_estimate: 10,
1605 };
1606 let cloned = summary.clone();
1607 assert_eq!(summary.id, cloned.id);
1608 assert_eq!(summary.content, cloned.content);
1609 }
1610
1611 #[tokio::test]
1612 async fn remember_preserves_role_mapping() {
1613 let memory = test_semantic_memory(false).await;
1614 let cid = memory.sqlite.create_conversation().await.unwrap();
1615
1616 memory.remember(cid, "user", "u").await.unwrap();
1617 memory.remember(cid, "assistant", "a").await.unwrap();
1618 memory.remember(cid, "system", "s").await.unwrap();
1619
1620 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1621 assert_eq!(history.len(), 3);
1622 assert_eq!(history[0].role, Role::User);
1623 assert_eq!(history[1].role, Role::Assistant);
1624 assert_eq!(history[2].role, Role::System);
1625 }
1626
1627 #[tokio::test]
1628 async fn new_with_invalid_qdrant_url_graceful() {
1629 let mut mock = MockProvider::default();
1630 mock.supports_embeddings = true;
1631 let provider = AnyProvider::Mock(mock);
1632 let result =
1633 SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
1634 assert!(result.is_ok());
1635 }
1636
1637 #[tokio::test]
1638 async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
1639 let mut mock = MockProvider::default();
1641 mock.supports_embeddings = true;
1642 let provider = AnyProvider::Mock(mock);
1645
1646 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1647 let pool = sqlite.pool().clone();
1648 let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));
1649
1650 let memory = SemanticMemory {
1651 sqlite,
1652 qdrant,
1653 provider,
1654 embedding_model: "test-model".into(),
1655 vector_weight: 0.7,
1656 keyword_weight: 0.3,
1657 temporal_decay_enabled: false,
1658 temporal_decay_half_life_days: 30,
1659 mmr_enabled: false,
1660 mmr_lambda: 0.7,
1661 token_counter: Arc::new(TokenCounter::new()),
1662 };
1663
1664 let cid = memory.sqlite().create_conversation().await.unwrap();
1665
1666 let id1 = memory
1668 .remember(cid, "user", "rust async programming")
1669 .await
1670 .unwrap();
1671 let id2 = memory
1672 .remember(cid, "assistant", "use tokio for async")
1673 .await
1674 .unwrap();
1675 assert!(id1 < id2);
1676
1677 let recalled = memory.recall("rust", 5, None).await.unwrap();
1679 assert!(
1680 !recalled.is_empty(),
1681 "recall must return at least one result"
1682 );
1683
1684 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1686 assert_eq!(history.len(), 2);
1687 assert_eq!(history[0].content, "rust async programming");
1688 }
1689
1690 #[tokio::test]
1691 async fn remember_with_embeddings_supported_but_no_qdrant() {
1692 let memory = test_semantic_memory(true).await;
1693 let cid = memory.sqlite.create_conversation().await.unwrap();
1694
1695 let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1696 assert!(msg_id > MessageId(0));
1697
1698 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1699 assert_eq!(history.len(), 1);
1700 assert_eq!(history[0].content, "hello embed");
1701 }
1702
1703 #[tokio::test]
1704 async fn remember_verifies_content_via_load_history() {
1705 let memory = test_semantic_memory(false).await;
1706 let cid = memory.sqlite.create_conversation().await.unwrap();
1707
1708 memory.remember(cid, "user", "alpha").await.unwrap();
1709 memory.remember(cid, "assistant", "beta").await.unwrap();
1710 memory.remember(cid, "user", "gamma").await.unwrap();
1711
1712 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1713 assert_eq!(history.len(), 3);
1714 assert_eq!(history[0].content, "alpha");
1715 assert_eq!(history[1].content, "beta");
1716 assert_eq!(history[2].content, "gamma");
1717 }
1718
1719 #[tokio::test]
1720 async fn message_count_multiple_conversations_isolated() {
1721 let memory = test_semantic_memory(false).await;
1722 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1723 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1724 let cid3 = memory.sqlite().create_conversation().await.unwrap();
1725
1726 for _ in 0..5 {
1727 memory.remember(cid1, "user", "msg").await.unwrap();
1728 }
1729 for _ in 0..3 {
1730 memory.remember(cid2, "user", "msg").await.unwrap();
1731 }
1732
1733 assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
1734 assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
1735 assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
1736 }
1737
1738 #[tokio::test]
1739 async fn summarize_empty_messages_range_returns_none() {
1740 let memory = test_semantic_memory(false).await;
1741 let cid = memory.sqlite().create_conversation().await.unwrap();
1742
1743 for i in 0..6 {
1744 memory
1745 .remember(cid, "user", &format!("msg {i}"))
1746 .await
1747 .unwrap();
1748 }
1749
1750 memory.summarize(cid, 3).await.unwrap();
1751 memory.summarize(cid, 3).await.unwrap();
1752
1753 let summaries = memory.load_summaries(cid).await.unwrap();
1754 assert_eq!(summaries.len(), 2);
1755 }
1756
1757 #[tokio::test]
1758 async fn summarize_token_estimate_populated() {
1759 let memory = test_semantic_memory(false).await;
1760 let cid = memory.sqlite().create_conversation().await.unwrap();
1761
1762 for i in 0..5 {
1763 memory
1764 .remember(cid, "user", &format!("message {i}"))
1765 .await
1766 .unwrap();
1767 }
1768
1769 memory.summarize(cid, 3).await.unwrap();
1770 let summaries = memory.load_summaries(cid).await.unwrap();
1771 let token_est = summaries[0].token_estimate;
1772 assert!(token_est > 0);
1773 }
1774
1775 #[tokio::test]
1776 async fn summarize_fails_when_provider_chat_fails() {
1777 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1778 let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
1779 "http://127.0.0.1:1",
1780 "test".into(),
1781 "embed".into(),
1782 ));
1783 let memory = SemanticMemory {
1784 sqlite,
1785 qdrant: None,
1786 provider,
1787 embedding_model: "test".into(),
1788 vector_weight: 0.7,
1789 keyword_weight: 0.3,
1790 temporal_decay_enabled: false,
1791 temporal_decay_half_life_days: 30,
1792 mmr_enabled: false,
1793 mmr_lambda: 0.7,
1794 token_counter: Arc::new(TokenCounter::new()),
1795 };
1796 let cid = memory.sqlite().create_conversation().await.unwrap();
1797
1798 for i in 0..5 {
1799 memory
1800 .remember(cid, "user", &format!("msg {i}"))
1801 .await
1802 .unwrap();
1803 }
1804
1805 let result = memory.summarize(cid, 3).await;
1806 assert!(result.is_err());
1807 }
1808
1809 #[tokio::test]
1810 async fn embed_missing_without_embedding_support_returns_zero() {
1811 let memory = test_semantic_memory(false).await;
1812 let cid = memory.sqlite().create_conversation().await.unwrap();
1813 memory
1814 .sqlite()
1815 .save_message(cid, "user", "test message")
1816 .await
1817 .unwrap();
1818
1819 let count = memory.embed_missing().await.unwrap();
1820 assert_eq!(count, 0);
1821 }
1822
1823 #[tokio::test]
1824 async fn has_embedding_returns_false_when_no_qdrant() {
1825 let memory = test_semantic_memory(false).await;
1826 let cid = memory.sqlite.create_conversation().await.unwrap();
1827 let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1828 assert!(!memory.has_embedding(msg_id).await.unwrap());
1829 }
1830
1831 #[tokio::test]
1832 async fn recall_empty_without_qdrant_regardless_of_filter() {
1833 let memory = test_semantic_memory(true).await;
1834 let filter = SearchFilter {
1835 conversation_id: Some(ConversationId(1)),
1836 role: None,
1837 };
1838 let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1839 assert!(recalled.is_empty());
1840 }
1841
1842 #[tokio::test]
1843 async fn summarize_message_range_bounds() {
1844 let memory = test_semantic_memory(false).await;
1845 let cid = memory.sqlite().create_conversation().await.unwrap();
1846
1847 for i in 0..8 {
1848 memory
1849 .remember(cid, "user", &format!("msg {i}"))
1850 .await
1851 .unwrap();
1852 }
1853
1854 let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1855 let summaries = memory.load_summaries(cid).await.unwrap();
1856 assert_eq!(summaries.len(), 1);
1857 assert_eq!(summaries[0].id, summary_id);
1858 assert!(summaries[0].first_message_id >= MessageId(1));
1859 assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1860 }
1861
1862 #[test]
1863 fn build_summarization_prompt_preserves_order() {
1864 let messages = vec![
1865 (MessageId(1), "user".into(), "first".into()),
1866 (MessageId(2), "assistant".into(), "second".into()),
1867 (MessageId(3), "user".into(), "third".into()),
1868 ];
1869 let prompt = build_summarization_prompt(&messages);
1870 let first_pos = prompt.find("user: first").unwrap();
1871 let second_pos = prompt.find("assistant: second").unwrap();
1872 let third_pos = prompt.find("user: third").unwrap();
1873 assert!(first_pos < second_pos);
1874 assert!(second_pos < third_pos);
1875 }
1876
1877 #[test]
1878 fn summary_debug() {
1879 let summary = Summary {
1880 id: 1,
1881 conversation_id: ConversationId(2),
1882 content: "test".into(),
1883 first_message_id: MessageId(1),
1884 last_message_id: MessageId(5),
1885 token_estimate: 10,
1886 };
1887 let dbg = format!("{summary:?}");
1888 assert!(dbg.contains("Summary"));
1889 }
1890
1891 #[tokio::test]
1892 async fn message_count_nonexistent_conversation() {
1893 let memory = test_semantic_memory(false).await;
1894 let count = memory.message_count(ConversationId(999)).await.unwrap();
1895 assert_eq!(count, 0);
1896 }
1897
1898 #[tokio::test]
1899 async fn load_summaries_nonexistent_conversation() {
1900 let memory = test_semantic_memory(false).await;
1901 let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
1902 assert!(summaries.is_empty());
1903 }
1904
1905 #[tokio::test]
1906 async fn store_session_summary_no_qdrant_noop() {
1907 let memory = test_semantic_memory(true).await;
1908 let result = memory
1909 .store_session_summary(ConversationId(1), "test summary")
1910 .await;
1911 assert!(result.is_ok());
1912 }
1913
1914 #[tokio::test]
1915 async fn store_session_summary_no_embeddings_noop() {
1916 let memory = test_semantic_memory(false).await;
1917 let result = memory
1918 .store_session_summary(ConversationId(1), "test summary")
1919 .await;
1920 assert!(result.is_ok());
1921 }
1922
1923 #[tokio::test]
1924 async fn search_session_summaries_no_qdrant_empty() {
1925 let memory = test_semantic_memory(true).await;
1926 let results = memory
1927 .search_session_summaries("query", 5, None)
1928 .await
1929 .unwrap();
1930 assert!(results.is_empty());
1931 }
1932
1933 #[tokio::test]
1934 async fn search_session_summaries_no_embeddings_empty() {
1935 let memory = test_semantic_memory(false).await;
1936 let results = memory
1937 .search_session_summaries("query", 5, Some(ConversationId(1)))
1938 .await
1939 .unwrap();
1940 assert!(results.is_empty());
1941 }
1942
1943 #[test]
1944 fn session_summary_result_debug() {
1945 let result = SessionSummaryResult {
1946 summary_text: "test".into(),
1947 score: 0.9,
1948 conversation_id: ConversationId(1),
1949 };
1950 let dbg = format!("{result:?}");
1951 assert!(dbg.contains("SessionSummaryResult"));
1952 }
1953
1954 #[test]
1955 fn session_summary_result_clone() {
1956 let result = SessionSummaryResult {
1957 summary_text: "test".into(),
1958 score: 0.9,
1959 conversation_id: ConversationId(1),
1960 };
1961 let cloned = result.clone();
1962 assert_eq!(result.summary_text, cloned.summary_text);
1963 assert_eq!(result.conversation_id, cloned.conversation_id);
1964 }
1965
1966 #[tokio::test]
1967 async fn recall_fts5_fallback_without_qdrant() {
1968 let memory = test_semantic_memory(false).await;
1969 let cid = memory.sqlite.create_conversation().await.unwrap();
1970
1971 memory
1972 .remember(cid, "user", "rust programming guide")
1973 .await
1974 .unwrap();
1975 memory
1976 .remember(cid, "assistant", "python tutorial")
1977 .await
1978 .unwrap();
1979 memory
1980 .remember(cid, "user", "advanced rust patterns")
1981 .await
1982 .unwrap();
1983
1984 let recalled = memory.recall("rust", 5, None).await.unwrap();
1985 assert_eq!(recalled.len(), 2);
1986 assert!(recalled[0].score >= recalled[1].score);
1987 }
1988
1989 #[tokio::test]
1990 async fn recall_fts5_fallback_with_filter() {
1991 let memory = test_semantic_memory(false).await;
1992 let cid1 = memory.sqlite.create_conversation().await.unwrap();
1993 let cid2 = memory.sqlite.create_conversation().await.unwrap();
1994
1995 memory.remember(cid1, "user", "hello world").await.unwrap();
1996 memory
1997 .remember(cid2, "user", "hello universe")
1998 .await
1999 .unwrap();
2000
2001 let filter = SearchFilter {
2002 conversation_id: Some(cid1),
2003 role: None,
2004 };
2005 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
2006 assert_eq!(recalled.len(), 1);
2007 }
2008
2009 #[tokio::test]
2010 async fn recall_fts5_no_matches_returns_empty() {
2011 let memory = test_semantic_memory(false).await;
2012 let cid = memory.sqlite.create_conversation().await.unwrap();
2013
2014 memory.remember(cid, "user", "hello world").await.unwrap();
2015
2016 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
2017 assert!(recalled.is_empty());
2018 }
2019
2020 #[tokio::test]
2021 async fn recall_fts5_respects_limit() {
2022 let memory = test_semantic_memory(false).await;
2023 let cid = memory.sqlite.create_conversation().await.unwrap();
2024
2025 for i in 0..10 {
2026 memory
2027 .remember(cid, "user", &format!("test message number {i}"))
2028 .await
2029 .unwrap();
2030 }
2031
2032 let recalled = memory.recall("test", 3, None).await.unwrap();
2033 assert_eq!(recalled.len(), 3);
2034 }
2035
2036 #[tokio::test]
2039 async fn summarize_fallback_to_plain_text_when_structured_fails() {
2040 let sqlite = SqliteStore::new(":memory:").await.unwrap();
2048 let mut mock = MockProvider::default();
2049 mock.default_response = "plain text summary".into();
2051 let provider = AnyProvider::Mock(mock);
2052
2053 let memory = SemanticMemory {
2054 sqlite,
2055 qdrant: None,
2056 provider,
2057 embedding_model: "test".into(),
2058 vector_weight: 0.7,
2059 keyword_weight: 0.3,
2060 temporal_decay_enabled: false,
2061 temporal_decay_half_life_days: 30,
2062 mmr_enabled: false,
2063 mmr_lambda: 0.7,
2064 token_counter: Arc::new(TokenCounter::new()),
2065 };
2066
2067 let cid = memory.sqlite().create_conversation().await.unwrap();
2068 for i in 0..5 {
2069 memory
2070 .remember(cid, "user", &format!("msg {i}"))
2071 .await
2072 .unwrap();
2073 }
2074
2075 let result = memory.summarize(cid, 3).await;
2076 assert!(result.is_ok());
2082 let summaries = memory.load_summaries(cid).await.unwrap();
2083 assert_eq!(summaries.len(), 1);
2084 assert!(!summaries[0].content.is_empty());
2085 }
2086
2087 #[test]
2090 fn temporal_decay_disabled_leaves_scores_unchanged() {
2091 let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2092 let timestamps = std::collections::HashMap::new();
2093 apply_temporal_decay(&mut ranked, ×tamps, 30);
2094 assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
2095 assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
2096 }
2097
2098 #[test]
2099 fn temporal_decay_zero_age_preserves_score() {
2100 let now = std::time::SystemTime::now()
2101 .duration_since(std::time::UNIX_EPOCH)
2102 .unwrap_or_default()
2103 .as_secs()
2104 .cast_signed();
2105 let mut ranked = vec![(MessageId(1), 1.0f64)];
2106 let mut timestamps = std::collections::HashMap::new();
2107 timestamps.insert(MessageId(1), now);
2108 apply_temporal_decay(&mut ranked, ×tamps, 30);
2109 assert!((ranked[0].1 - 1.0).abs() < 0.01);
2111 }
2112
2113 #[test]
2114 fn temporal_decay_half_life_halves_score() {
2115 let half_life = 30u32;
2117 let age_secs = i64::from(half_life) * 86400;
2118 let now = std::time::SystemTime::now()
2119 .duration_since(std::time::UNIX_EPOCH)
2120 .unwrap_or_default()
2121 .as_secs()
2122 .cast_signed();
2123 let ts = now - age_secs;
2124 let mut ranked = vec![(MessageId(1), 1.0f64)];
2125 let mut timestamps = std::collections::HashMap::new();
2126 timestamps.insert(MessageId(1), ts);
2127 apply_temporal_decay(&mut ranked, ×tamps, half_life);
2128 assert!(
2130 (ranked[0].1 - 0.5).abs() < 0.01,
2131 "score was {}",
2132 ranked[0].1
2133 );
2134 }
2135
2136 #[test]
2139 fn mmr_empty_input_returns_empty() {
2140 let ranked = vec![];
2141 let vectors = std::collections::HashMap::new();
2142 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2143 assert!(result.is_empty());
2144 }
2145
2146 #[test]
2147 fn mmr_returns_up_to_limit() {
2148 let ranked = vec![
2149 (MessageId(1), 1.0f64),
2150 (MessageId(2), 0.9f64),
2151 (MessageId(3), 0.8f64),
2152 ];
2153 let mut vectors = std::collections::HashMap::new();
2154 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2155 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2156 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2157 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2158 assert_eq!(result.len(), 2);
2159 }
2160
2161 #[test]
2162 fn mmr_without_vectors_picks_by_relevance() {
2163 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2164 let vectors = std::collections::HashMap::new();
2165 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2166 assert_eq!(result.len(), 2);
2167 assert_eq!(result[0].0, MessageId(1));
2168 }
2169
2170 #[test]
2171 fn mmr_prefers_diverse_over_redundant() {
2172 let ranked = vec![
2174 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2178 let mut vectors = std::collections::HashMap::new();
2179 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2180 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2183 assert_eq!(result.len(), 2);
2184 assert_eq!(result[0].0, MessageId(1));
2185 assert_eq!(result[1].0, MessageId(2));
2187 }
2188
2189 #[test]
2190 fn temporal_decay_half_life_zero_is_noop() {
2191 let now = std::time::SystemTime::now()
2192 .duration_since(std::time::UNIX_EPOCH)
2193 .unwrap_or_default()
2194 .as_secs()
2195 .cast_signed();
2196 let age_secs = 30i64 * 86400;
2197 let ts = now - age_secs;
2198 let mut ranked = vec![(MessageId(1), 1.0f64)];
2199 let mut timestamps = std::collections::HashMap::new();
2200 timestamps.insert(MessageId(1), ts);
2201 apply_temporal_decay(&mut ranked, ×tamps, 0);
2203 assert!(
2204 (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2205 "score was {}",
2206 ranked[0].1
2207 );
2208 }
2209
2210 #[test]
2211 fn temporal_decay_huge_age_near_zero() {
2212 let now = std::time::SystemTime::now()
2213 .duration_since(std::time::UNIX_EPOCH)
2214 .unwrap_or_default()
2215 .as_secs()
2216 .cast_signed();
2217 let age_secs = 3650i64 * 86400;
2219 let ts = now - age_secs;
2220 let mut ranked = vec![(MessageId(1), 1.0f64)];
2221 let mut timestamps = std::collections::HashMap::new();
2222 timestamps.insert(MessageId(1), ts);
2223 apply_temporal_decay(&mut ranked, ×tamps, 30);
2224 assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2226 }
2227
2228 #[test]
2229 fn temporal_decay_small_half_life() {
2230 let now = std::time::SystemTime::now()
2232 .duration_since(std::time::UNIX_EPOCH)
2233 .unwrap_or_default()
2234 .as_secs()
2235 .cast_signed();
2236 let ts = now - 7 * 86400i64;
2237 let mut ranked = vec![(MessageId(1), 1.0f64)];
2238 let mut timestamps = std::collections::HashMap::new();
2239 timestamps.insert(MessageId(1), ts);
2240 apply_temporal_decay(&mut ranked, ×tamps, 1);
2241 assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2242 }
2243
2244 #[test]
2245 fn mmr_lambda_zero_max_diversity() {
2246 let ranked = vec![
2248 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2252 let mut vectors = std::collections::HashMap::new();
2253 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2254 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2257 assert_eq!(result.len(), 3);
2258 assert_eq!(result[1].0, MessageId(2));
2260 }
2261
2262 #[test]
2263 fn mmr_lambda_one_pure_relevance() {
2264 let ranked = vec![
2266 (MessageId(1), 1.0f64),
2267 (MessageId(2), 0.8f64),
2268 (MessageId(3), 0.6f64),
2269 ];
2270 let mut vectors = std::collections::HashMap::new();
2271 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2272 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2273 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2274 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2275 assert_eq!(result.len(), 3);
2276 assert_eq!(result[0].0, MessageId(1));
2277 assert_eq!(result[1].0, MessageId(2));
2278 assert_eq!(result[2].0, MessageId(3));
2279 }
2280
2281 #[test]
2282 fn mmr_limit_zero_returns_empty() {
2283 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2284 let mut vectors = std::collections::HashMap::new();
2285 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2286 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2287 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2288 assert!(result.is_empty());
2289 }
2290
2291 #[test]
2292 fn mmr_duplicate_vectors_penalizes_second() {
2293 let ranked = vec![
2295 (MessageId(1), 1.0f64),
2296 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2299 let mut vectors = std::collections::HashMap::new();
2300 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2301 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2304 assert_eq!(result.len(), 3);
2305 assert_eq!(result[0].0, MessageId(1));
2306 assert_eq!(result[1].0, MessageId(3));
2308 }
2309
    use proptest::prelude::*;

    proptest! {
        // Property: token counting is total — it must not panic on any input
        // string, including empty and arbitrary Unicode (".*" strategy).
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2321}