1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
/// Vector-store collection holding one embedded summary per session.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Vector-store collection holding individually embedded key facts.
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
18
/// Structured output requested from the LLM when summarizing a conversation
/// window (the JSON fields named in `build_summarization_prompt`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Prose summary of the summarized message range.
    pub summary: String,
    /// Discrete facts extracted for long-term recall.
    pub key_facts: Vec<String>,
    /// Named entities mentioned in the conversation.
    pub entities: Vec<String>,
}
25
/// A message returned by hybrid recall, paired with its combined score.
#[derive(Debug)]
pub struct RecalledMessage {
    /// The recalled message loaded from SQLite.
    pub message: Message,
    /// Combined (weighted vector + keyword) relevance score; higher is better.
    pub score: f32,
}
31
/// A stored conversation summary covering a contiguous message range.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Database id of the summary row.
    pub id: i64,
    /// Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    /// The summary text.
    pub content: String,
    /// First message covered by this summary.
    pub first_message_id: MessageId,
    /// Last message covered by this summary.
    pub last_message_id: MessageId,
    /// Approximate token count of `content`.
    pub token_estimate: i64,
}
41
/// A session summary returned from cross-conversation semantic search.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    /// The stored summary text.
    pub summary_text: String,
    /// Vector-similarity score reported by the search.
    pub score: f32,
    /// Conversation the summary was generated from.
    pub conversation_id: ConversationId,
}
48
/// Cosine similarity of two vectors; returns 0.0 for empty, mismatched-length,
/// or zero-norm inputs.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Degenerate inputs have no meaningful similarity.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
61
62fn apply_temporal_decay(
63 ranked: &mut [(MessageId, f64)],
64 timestamps: &std::collections::HashMap<MessageId, i64>,
65 half_life_days: u32,
66) {
67 if half_life_days == 0 {
68 return;
69 }
70 let now = std::time::SystemTime::now()
71 .duration_since(std::time::UNIX_EPOCH)
72 .unwrap_or_default()
73 .as_secs()
74 .cast_signed();
75 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
76
77 for (msg_id, score) in ranked.iter_mut() {
78 if let Some(&ts) = timestamps.get(msg_id) {
79 #[allow(clippy::cast_precision_loss)]
80 let age_days = (now - ts).max(0) as f64 / 86400.0;
81 *score *= (-lambda * age_days).exp();
82 }
83 }
84}
85
/// Re-ranks `ranked` with Maximal Marginal Relevance: greedily selects up to
/// `limit` items, each round maximizing
/// `lambda * relevance - (1 - lambda) * max_similarity_to_selected`.
///
/// Candidates without an entry in `vectors` are treated as having zero
/// similarity to everything already selected (no redundancy penalty).
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // First pick: take the top-ranked candidate as-is.
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Redundancy penalty: highest similarity between this
                // candidate and any already-selected item.
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    0.0
                };
                // NEG_INFINITY survives the fold only when no selected item
                // had a vector; treat that as zero redundancy too.
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                // Strict `>` keeps the earliest (highest-ranked) candidate on ties.
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
137
138fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
139 let mut prompt = String::from(
140 "Summarize the following conversation. Extract key facts, decisions, entities, \
141 and context needed to continue the conversation.\n\n\
142 Respond in JSON with fields: summary (string), key_facts (list of strings), \
143 entities (list of strings).\n\nConversation:\n",
144 );
145
146 for (_, role, content) in messages {
147 prompt.push_str(role);
148 prompt.push_str(": ");
149 prompt.push_str(content);
150 prompt.push('\n');
151 }
152
153 prompt
154}
155
/// Long-term memory combining a SQLite message store (durable storage plus
/// keyword search) with an optional vector store for semantic search.
pub struct SemanticMemory {
    /// Durable message/summary storage and keyword search.
    sqlite: SqliteStore,
    /// Vector store; `None` when unavailable, which disables semantic search.
    qdrant: Option<EmbeddingStore>,
    /// LLM provider used for embeddings, chat summarization, and fact extraction.
    provider: AnyProvider,
    /// Model name recorded alongside stored embeddings.
    embedding_model: String,
    /// Weight of the normalized vector score in hybrid ranking.
    vector_weight: f64,
    /// Weight of the normalized keyword score in hybrid ranking.
    keyword_weight: f64,
    /// When true, recall scores decay exponentially with message age.
    temporal_decay_enabled: bool,
    /// Half-life in days for temporal decay; 0 disables decay.
    temporal_decay_half_life_days: u32,
    /// When true, recall results are diversified via MMR re-ranking.
    mmr_enabled: bool,
    /// MMR relevance/diversity trade-off (1.0 = pure relevance).
    mmr_lambda: f32,
    /// Shared token counter used for summary token estimates.
    pub token_counter: Arc<TokenCounter>,
}
169
170impl SemanticMemory {
    /// Creates a semantic memory with the default hybrid-search weights
    /// (0.7 vector / 0.3 keyword) and the default pool size.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
186
    /// Creates a semantic memory with explicit hybrid-search weights and the
    /// default SQLite pool size of 5.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
211
212 pub async fn with_weights_and_pool_size(
218 sqlite_path: &str,
219 qdrant_url: &str,
220 provider: AnyProvider,
221 embedding_model: &str,
222 vector_weight: f64,
223 keyword_weight: f64,
224 pool_size: u32,
225 ) -> Result<Self, MemoryError> {
226 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
227 let pool = sqlite.pool().clone();
228
229 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
230 Ok(store) => Some(store),
231 Err(e) => {
232 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
233 None
234 }
235 };
236
237 Ok(Self {
238 sqlite,
239 qdrant,
240 provider,
241 embedding_model: embedding_model.into(),
242 vector_weight,
243 keyword_weight,
244 temporal_decay_enabled: false,
245 temporal_decay_half_life_days: 30,
246 mmr_enabled: false,
247 mmr_lambda: 0.7,
248 token_counter: Arc::new(TokenCounter::new()),
249 })
250 }
251
252 #[must_use]
254 pub fn with_ranking_options(
255 mut self,
256 temporal_decay_enabled: bool,
257 temporal_decay_half_life_days: u32,
258 mmr_enabled: bool,
259 mmr_lambda: f32,
260 ) -> Self {
261 self.temporal_decay_enabled = temporal_decay_enabled;
262 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
263 self.mmr_enabled = mmr_enabled;
264 self.mmr_lambda = mmr_lambda;
265 self
266 }
267
    /// Creates a semantic memory whose embeddings are stored in SQLite itself
    /// (no external vector-store dependency), with the default pool size of 5.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
290
    /// Creates a semantic memory whose embedding store is backed by the same
    /// SQLite pool, with a custom connection-pool size. Unlike the Qdrant
    /// constructor, the vector store here is always present.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(store),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        })
    }
322
323 pub async fn remember(
332 &self,
333 conversation_id: ConversationId,
334 role: &str,
335 content: &str,
336 ) -> Result<MessageId, MemoryError> {
337 let message_id = self
338 .sqlite
339 .save_message(conversation_id, role, content)
340 .await?;
341
342 if let Some(qdrant) = &self.qdrant
343 && self.provider.supports_embeddings()
344 {
345 match self.provider.embed(content).await {
346 Ok(vector) => {
347 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
349 if let Err(e) = qdrant.ensure_collection(vector_size).await {
350 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
351 } else if let Err(e) = qdrant
352 .store(
353 message_id,
354 conversation_id,
355 role,
356 vector,
357 MessageKind::Regular,
358 &self.embedding_model,
359 )
360 .await
361 {
362 tracing::warn!("Failed to store embedding: {e:#}");
363 }
364 }
365 Err(e) => {
366 tracing::warn!("Failed to generate embedding: {e:#}");
367 }
368 }
369 }
370
371 Ok(message_id)
372 }
373
374 pub async fn remember_with_parts(
383 &self,
384 conversation_id: ConversationId,
385 role: &str,
386 content: &str,
387 parts_json: &str,
388 ) -> Result<(MessageId, bool), MemoryError> {
389 let message_id = self
390 .sqlite
391 .save_message_with_parts(conversation_id, role, content, parts_json)
392 .await?;
393
394 let mut embedding_stored = false;
395
396 if let Some(qdrant) = &self.qdrant
397 && self.provider.supports_embeddings()
398 {
399 match self.provider.embed(content).await {
400 Ok(vector) => {
401 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
402 if let Err(e) = qdrant.ensure_collection(vector_size).await {
403 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
404 } else if let Err(e) = qdrant
405 .store(
406 message_id,
407 conversation_id,
408 role,
409 vector,
410 MessageKind::Regular,
411 &self.embedding_model,
412 )
413 .await
414 {
415 tracing::warn!("Failed to store embedding: {e:#}");
416 } else {
417 embedding_stored = true;
418 }
419 }
420 Err(e) => {
421 tracing::warn!("Failed to generate embedding: {e:#}");
422 }
423 }
424 }
425
426 Ok((message_id, embedding_stored))
427 }
428
    /// Persists a message (with its parts JSON) to SQLite without attempting
    /// to embed it.
    ///
    /// # Errors
    /// Returns an error if the SQLite write fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
447
448 #[allow(clippy::too_many_lines)]
458 pub async fn recall(
459 &self,
460 query: &str,
461 limit: usize,
462 filter: Option<SearchFilter>,
463 ) -> Result<Vec<RecalledMessage>, MemoryError> {
464 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
465
466 let keyword_results = match self
468 .sqlite
469 .keyword_search(query, limit * 2, conversation_id)
470 .await
471 {
472 Ok(results) => results,
473 Err(e) => {
474 tracing::warn!("FTS5 keyword search failed: {e:#}");
475 Vec::new()
476 }
477 };
478
479 let vector_results = if let Some(qdrant) = &self.qdrant
481 && self.provider.supports_embeddings()
482 {
483 let query_vector = self.provider.embed(query).await?;
484 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
485 qdrant.ensure_collection(vector_size).await?;
486 qdrant.search(&query_vector, limit * 2, filter).await?
487 } else {
488 Vec::new()
489 };
490
491 let mut scores: std::collections::HashMap<MessageId, f64> =
493 std::collections::HashMap::new();
494
495 if !vector_results.is_empty() {
496 let max_vs = vector_results
497 .iter()
498 .map(|r| r.score)
499 .fold(f32::NEG_INFINITY, f32::max);
500 let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
501 for r in &vector_results {
502 let normalized = f64::from(r.score / norm);
503 *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
504 }
505 }
506
507 if !keyword_results.is_empty() {
508 let max_ks = keyword_results
509 .iter()
510 .map(|r| r.1)
511 .fold(f64::NEG_INFINITY, f64::max);
512 let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
513 for &(msg_id, score) in &keyword_results {
514 let normalized = score / norm;
515 *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
516 }
517 }
518
519 if scores.is_empty() {
520 return Ok(Vec::new());
521 }
522
523 let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
525 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
526
527 if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
529 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
530 match self.sqlite.message_timestamps(&ids).await {
531 Ok(timestamps) => {
532 apply_temporal_decay(
533 &mut ranked,
534 ×tamps,
535 self.temporal_decay_half_life_days,
536 );
537 ranked
538 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
539 }
540 Err(e) => {
541 tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
542 }
543 }
544 }
545
546 if self.mmr_enabled && !vector_results.is_empty() {
548 if let Some(qdrant) = &self.qdrant {
549 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
550 match qdrant.get_vectors(&ids).await {
551 Ok(vec_map) if !vec_map.is_empty() => {
552 ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
553 }
554 Ok(_) => {
555 ranked.truncate(limit);
556 }
557 Err(e) => {
558 tracing::warn!("MMR: failed to fetch vectors: {e:#}");
559 ranked.truncate(limit);
560 }
561 }
562 } else {
563 ranked.truncate(limit);
564 }
565 } else {
566 ranked.truncate(limit);
567 }
568
569 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
570 let messages = self.sqlite.messages_by_ids(&ids).await?;
571 let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
572
573 let recalled = ranked
574 .iter()
575 .filter_map(|(msg_id, score)| {
576 msg_map.get(msg_id).map(|msg| RecalledMessage {
577 message: msg.clone(),
578 #[expect(clippy::cast_possible_truncation)]
579 score: *score as f32,
580 })
581 })
582 .collect();
583
584 Ok(recalled)
585 }
586
587 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
593 match &self.qdrant {
594 Some(qdrant) => qdrant.has_embedding(message_id).await,
595 None => Ok(false),
596 }
597 }
598
599 pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
608 let Some(qdrant) = &self.qdrant else {
609 return Ok(0);
610 };
611 if !self.provider.supports_embeddings() {
612 return Ok(0);
613 }
614
615 let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
616
617 if unembedded.is_empty() {
618 return Ok(0);
619 }
620
621 let probe = self.provider.embed("probe").await?;
622 let vector_size = u64::try_from(probe.len())?;
623 qdrant.ensure_collection(vector_size).await?;
624
625 let mut count = 0;
626 for (msg_id, conversation_id, role, content) in &unembedded {
627 match self.provider.embed(content).await {
628 Ok(vector) => {
629 if let Err(e) = qdrant
630 .store(
631 *msg_id,
632 *conversation_id,
633 role,
634 vector,
635 MessageKind::Regular,
636 &self.embedding_model,
637 )
638 .await
639 {
640 tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
641 continue;
642 }
643 count += 1;
644 }
645 Err(e) => {
646 tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
647 }
648 }
649 }
650
651 tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
652 Ok(count)
653 }
654
    /// Embeds `summary_text` and stores it in the session-summaries
    /// collection for later cross-conversation recall. No-op when the vector
    /// store is absent or the provider cannot embed.
    ///
    /// # Errors
    /// Returns an error if embedding or the vector-store writes fail.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Payload carries everything search_session_summaries reads back.
        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
693
    /// Semantic search over stored session summaries, optionally excluding
    /// one conversation (typically the current one). Returns empty when the
    /// vector store is absent or the provider cannot embed.
    ///
    /// # Errors
    /// Returns an error if embedding the query or the vector search fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // must_not filter drops summaries from the excluded conversation.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        // Points with a malformed payload are silently skipped.
        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
746
    /// Direct access to the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
752
753 pub async fn is_vector_store_connected(&self) -> bool {
758 match self.qdrant.as_ref() {
759 Some(store) => store.health_check().await,
760 None => false,
761 }
762 }
763
    /// Whether a vector store was configured at construction time
    /// (regardless of current connectivity).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
769
    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Returns an error if the SQLite query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
778
    /// Number of messages in `conversation_id` newer than the last message
    /// covered by the most recent summary (all messages when no summary
    /// exists yet).
    ///
    /// # Errors
    /// Returns an error if the SQLite queries fail.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: ConversationId,
    ) -> Result<i64, MemoryError> {
        // MessageId(0) sorts before any real id, so "no summary" counts everything.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
797
798 pub async fn load_summaries(
804 &self,
805 conversation_id: ConversationId,
806 ) -> Result<Vec<Summary>, MemoryError> {
807 let rows = self.sqlite.load_summaries(conversation_id).await?;
808 let summaries = rows
809 .into_iter()
810 .map(
811 |(
812 id,
813 conversation_id,
814 content,
815 first_message_id,
816 last_message_id,
817 token_estimate,
818 )| {
819 Summary {
820 id,
821 conversation_id,
822 content,
823 first_message_id,
824 last_message_id,
825 token_estimate,
826 }
827 },
828 )
829 .collect();
830 Ok(summaries)
831 }
832
    /// Summarizes the oldest `message_count` messages not yet covered by an
    /// existing summary, persisting the summary row plus (best-effort) its
    /// embedding and any extracted key facts.
    ///
    /// Returns `Ok(None)` when the conversation has no more than
    /// `message_count` messages in total, or when no uncovered messages
    /// remain; otherwise the new summary row id.
    ///
    /// # Errors
    /// Returns an error if SQLite access, the plain-text fallback chat call,
    /// or token counting fails. Embedding failures are only logged.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Strictly-greater check: exactly `message_count` messages is not enough.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume from the last message the previous summary covered.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer the structured (typed JSON) path; degrade to a plain chat
        // completion with no key facts/entities on failure.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // Safe: messages is non-empty (checked above).
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself so it participates in recall.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            // Summary embeddings reuse the summary row id as
                            // their message id, tagged with MessageKind::Summary.
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
946
    /// Best-effort storage of extracted key facts as individual embeddings in
    /// the key-facts collection. All failures are logged, never propagated.
    async fn store_key_facts(
        &self,
        conversation_id: ConversationId,
        source_summary_id: i64,
        key_facts: &[String],
    ) {
        let Some(qdrant) = &self.qdrant else {
            return;
        };
        if !self.provider.supports_embeddings() {
            return;
        }

        let Some(first_fact) = key_facts.first() else {
            return;
        };
        // The first fact is embedded up front: its vector length determines
        // the collection size before anything is written.
        let first_vector = match self.provider.embed(first_fact).await {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("Failed to embed key fact: {e:#}");
                return;
            }
        };
        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
        if let Err(e) = qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await
        {
            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
            return;
        }

        let first_payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "fact_text": first_fact,
            "source_summary_id": source_summary_id,
        });
        if let Err(e) = qdrant
            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
            .await
        {
            tracing::warn!("Failed to store key fact: {e:#}");
        }

        // Remaining facts: embed and store one by one, skipping failures.
        for fact in &key_facts[1..] {
            match self.provider.embed(fact).await {
                Ok(vector) => {
                    let payload = serde_json::json!({
                        "conversation_id": conversation_id.0,
                        "fact_text": fact,
                        "source_summary_id": source_summary_id,
                    });
                    if let Err(e) = qdrant
                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
                        .await
                    {
                        tracing::warn!("Failed to store key fact: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to embed key fact: {e:#}");
                }
            }
        }
    }
1012
1013 pub async fn search_key_facts(
1019 &self,
1020 query: &str,
1021 limit: usize,
1022 ) -> Result<Vec<String>, MemoryError> {
1023 let Some(qdrant) = &self.qdrant else {
1024 return Ok(Vec::new());
1025 };
1026 if !self.provider.supports_embeddings() {
1027 return Ok(Vec::new());
1028 }
1029
1030 let vector = self.provider.embed(query).await?;
1031 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
1032 qdrant
1033 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1034 .await?;
1035
1036 let points = qdrant
1037 .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
1038 .await?;
1039
1040 let facts = points
1041 .into_iter()
1042 .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
1043 .collect();
1044
1045 Ok(facts)
1046 }
1047}
1048
1049#[cfg(test)]
1050mod tests {
1051 use zeph_llm::mock::MockProvider;
1052 use zeph_llm::provider::Role;
1053
1054 use super::*;
1055
    // Builds the mock LLM provider used by every test in this module.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }
1059
    // Builds a SemanticMemory over an in-memory SQLite store with no vector
    // store attached. The flag is currently ignored — the mock provider's
    // capabilities are fixed.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        }
    }
1078
    // remember() must persist the message even with no vector store attached.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }
1093
    // remember_with_parts() persists the message; the embedding flag is
    // irrelevant here since no vector store is configured.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }
1111
    // recall() degrades to empty results when no vector store is configured.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }
1119
    // has_embedding() is false when no vector store is configured.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }
1127
    // embed_missing() is a no-op without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1135
    // The sqlite() accessor exposes the underlying store for direct use.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }
1152
    // has_vector_store() reflects construction-time configuration.
    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }
1158
    // The health check short-circuits to false with no vector store.
    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }
1164
    // recall() still succeeds (empty here) when embedding is impossible.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }
1172
    // embed_missing() reports zero when embedding is impossible, even with
    // unembedded messages present.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1187
    // A fresh conversation has zero messages.
    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }
1196
    // message_count() tracks messages saved via remember().
    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }
1208
    // Summarizing covers older messages, shrinking the unsummarized count
    // while the total message count stays the same.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }
1227
    // No summaries exist until summarize() has run.
    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }
1236
    // load_summaries() returns summaries in insertion order with their
    // content intact.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }
1264
    // No summary is produced while the conversation is within the threshold.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }
1275
    // Above the threshold, summarize() persists a non-empty summary row.
    #[tokio::test]
    async fn summarize_stores_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 3).await.unwrap();
        assert!(summary_id.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id.unwrap());
        assert!(!summaries[0].content.is_empty());
    }
1296
    // A second summary starts strictly after the range covered by the first.
    #[tokio::test]
    async fn summarize_respects_previous_summaries() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let s1 = memory.summarize(cid, 3).await.unwrap();
        assert!(s1.is_some());

        let s2 = memory.summarize(cid, 3).await.unwrap();
        assert!(s2.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
    }
1319
1320 #[tokio::test]
1321 async fn remember_multiple_messages_increments_ids() {
1322 let memory = test_semantic_memory(false).await;
1323 let cid = memory.sqlite.create_conversation().await.unwrap();
1324
1325 let id1 = memory.remember(cid, "user", "first").await.unwrap();
1326 let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1327 let id3 = memory.remember(cid, "user", "third").await.unwrap();
1328
1329 assert!(id1 < id2);
1330 assert!(id2 < id3);
1331 }
1332
1333 #[tokio::test]
1334 async fn message_count_across_conversations() {
1335 let memory = test_semantic_memory(false).await;
1336 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1337 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1338
1339 memory.remember(cid1, "user", "msg1").await.unwrap();
1340 memory.remember(cid1, "user", "msg2").await.unwrap();
1341 memory.remember(cid2, "user", "msg3").await.unwrap();
1342
1343 assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
1344 assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
1345 }
1346
1347 #[tokio::test]
1348 async fn summarize_exact_threshold_returns_none() {
1349 let memory = test_semantic_memory(false).await;
1350 let cid = memory.sqlite().create_conversation().await.unwrap();
1351
1352 for i in 0..3 {
1353 memory
1354 .remember(cid, "user", &format!("msg {i}"))
1355 .await
1356 .unwrap();
1357 }
1358
1359 let result = memory.summarize(cid, 3).await.unwrap();
1360 assert!(result.is_none());
1361 }
1362
1363 #[tokio::test]
1364 async fn summarize_one_above_threshold_produces_summary() {
1365 let memory = test_semantic_memory(false).await;
1366 let cid = memory.sqlite().create_conversation().await.unwrap();
1367
1368 for i in 0..4 {
1369 memory
1370 .remember(cid, "user", &format!("msg {i}"))
1371 .await
1372 .unwrap();
1373 }
1374
1375 let result = memory.summarize(cid, 3).await.unwrap();
1376 assert!(result.is_some());
1377 }
1378
1379 #[tokio::test]
1380 async fn summary_fields_populated() {
1381 let memory = test_semantic_memory(false).await;
1382 let cid = memory.sqlite().create_conversation().await.unwrap();
1383
1384 for i in 0..5 {
1385 memory
1386 .remember(cid, "user", &format!("msg {i}"))
1387 .await
1388 .unwrap();
1389 }
1390
1391 memory.summarize(cid, 3).await.unwrap();
1392 let summaries = memory.load_summaries(cid).await.unwrap();
1393 let s = &summaries[0];
1394
1395 assert_eq!(s.conversation_id, cid);
1396 assert!(s.first_message_id > MessageId(0));
1397 assert!(s.last_message_id >= s.first_message_id);
1398 assert!(s.token_estimate >= 0);
1399 assert!(!s.content.is_empty());
1400 }
1401
1402 #[test]
1403 fn build_summarization_prompt_format() {
1404 let messages = vec![
1405 (MessageId(1), "user".into(), "Hello".into()),
1406 (MessageId(2), "assistant".into(), "Hi there".into()),
1407 ];
1408 let prompt = build_summarization_prompt(&messages);
1409 assert!(prompt.contains("user: Hello"));
1410 assert!(prompt.contains("assistant: Hi there"));
1411 assert!(prompt.contains("key_facts"));
1412 }
1413
1414 #[test]
1415 fn build_summarization_prompt_empty() {
1416 let messages: Vec<(MessageId, String, String)> = vec![];
1417 let prompt = build_summarization_prompt(&messages);
1418 assert!(prompt.contains("key_facts"));
1419 }
1420
1421 #[test]
1422 fn structured_summary_deserialize() {
1423 let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
1424 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1425 assert_eq!(ss.summary, "s");
1426 assert_eq!(ss.key_facts.len(), 2);
1427 assert_eq!(ss.entities.len(), 1);
1428 }
1429
1430 #[test]
1431 fn structured_summary_empty_facts() {
1432 let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
1433 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1434 assert!(ss.key_facts.is_empty());
1435 assert!(ss.entities.is_empty());
1436 }
1437
1438 #[tokio::test]
1439 async fn search_key_facts_no_qdrant_empty() {
1440 let memory = test_semantic_memory(false).await;
1441 let facts = memory.search_key_facts("query", 5).await.unwrap();
1442 assert!(facts.is_empty());
1443 }
1444
1445 #[test]
1446 fn recalled_message_debug() {
1447 let recalled = RecalledMessage {
1448 message: Message {
1449 role: Role::User,
1450 content: "test".into(),
1451 parts: vec![],
1452 metadata: MessageMetadata::default(),
1453 },
1454 score: 0.95,
1455 };
1456 let dbg = format!("{recalled:?}");
1457 assert!(dbg.contains("RecalledMessage"));
1458 assert!(dbg.contains("0.95"));
1459 }
1460
1461 #[test]
1462 fn summary_clone() {
1463 let summary = Summary {
1464 id: 1,
1465 conversation_id: ConversationId(2),
1466 content: "test summary".into(),
1467 first_message_id: MessageId(1),
1468 last_message_id: MessageId(5),
1469 token_estimate: 10,
1470 };
1471 let cloned = summary.clone();
1472 assert_eq!(summary.id, cloned.id);
1473 assert_eq!(summary.content, cloned.content);
1474 }
1475
1476 #[tokio::test]
1477 async fn remember_preserves_role_mapping() {
1478 let memory = test_semantic_memory(false).await;
1479 let cid = memory.sqlite.create_conversation().await.unwrap();
1480
1481 memory.remember(cid, "user", "u").await.unwrap();
1482 memory.remember(cid, "assistant", "a").await.unwrap();
1483 memory.remember(cid, "system", "s").await.unwrap();
1484
1485 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1486 assert_eq!(history.len(), 3);
1487 assert_eq!(history[0].role, Role::User);
1488 assert_eq!(history[1].role, Role::Assistant);
1489 assert_eq!(history[2].role, Role::System);
1490 }
1491
1492 #[tokio::test]
1493 async fn new_with_invalid_qdrant_url_graceful() {
1494 let mut mock = MockProvider::default();
1495 mock.supports_embeddings = true;
1496 let provider = AnyProvider::Mock(mock);
1497 let result =
1498 SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
1499 assert!(result.is_ok());
1500 }
1501
1502 #[tokio::test]
1503 async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
1504 let mut mock = MockProvider::default();
1506 mock.supports_embeddings = true;
1507 let provider = AnyProvider::Mock(mock);
1510
1511 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1512 let pool = sqlite.pool().clone();
1513 let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));
1514
1515 let memory = SemanticMemory {
1516 sqlite,
1517 qdrant,
1518 provider,
1519 embedding_model: "test-model".into(),
1520 vector_weight: 0.7,
1521 keyword_weight: 0.3,
1522 temporal_decay_enabled: false,
1523 temporal_decay_half_life_days: 30,
1524 mmr_enabled: false,
1525 mmr_lambda: 0.7,
1526 token_counter: Arc::new(TokenCounter::new()),
1527 };
1528
1529 let cid = memory.sqlite().create_conversation().await.unwrap();
1530
1531 let id1 = memory
1533 .remember(cid, "user", "rust async programming")
1534 .await
1535 .unwrap();
1536 let id2 = memory
1537 .remember(cid, "assistant", "use tokio for async")
1538 .await
1539 .unwrap();
1540 assert!(id1 < id2);
1541
1542 let recalled = memory.recall("rust", 5, None).await.unwrap();
1544 assert!(
1545 !recalled.is_empty(),
1546 "recall must return at least one result"
1547 );
1548
1549 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1551 assert_eq!(history.len(), 2);
1552 assert_eq!(history[0].content, "rust async programming");
1553 }
1554
1555 #[tokio::test]
1556 async fn remember_with_embeddings_supported_but_no_qdrant() {
1557 let memory = test_semantic_memory(true).await;
1558 let cid = memory.sqlite.create_conversation().await.unwrap();
1559
1560 let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1561 assert!(msg_id > MessageId(0));
1562
1563 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1564 assert_eq!(history.len(), 1);
1565 assert_eq!(history[0].content, "hello embed");
1566 }
1567
1568 #[tokio::test]
1569 async fn remember_verifies_content_via_load_history() {
1570 let memory = test_semantic_memory(false).await;
1571 let cid = memory.sqlite.create_conversation().await.unwrap();
1572
1573 memory.remember(cid, "user", "alpha").await.unwrap();
1574 memory.remember(cid, "assistant", "beta").await.unwrap();
1575 memory.remember(cid, "user", "gamma").await.unwrap();
1576
1577 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1578 assert_eq!(history.len(), 3);
1579 assert_eq!(history[0].content, "alpha");
1580 assert_eq!(history[1].content, "beta");
1581 assert_eq!(history[2].content, "gamma");
1582 }
1583
1584 #[tokio::test]
1585 async fn message_count_multiple_conversations_isolated() {
1586 let memory = test_semantic_memory(false).await;
1587 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1588 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1589 let cid3 = memory.sqlite().create_conversation().await.unwrap();
1590
1591 for _ in 0..5 {
1592 memory.remember(cid1, "user", "msg").await.unwrap();
1593 }
1594 for _ in 0..3 {
1595 memory.remember(cid2, "user", "msg").await.unwrap();
1596 }
1597
1598 assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
1599 assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
1600 assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
1601 }
1602
1603 #[tokio::test]
1604 async fn summarize_empty_messages_range_returns_none() {
1605 let memory = test_semantic_memory(false).await;
1606 let cid = memory.sqlite().create_conversation().await.unwrap();
1607
1608 for i in 0..6 {
1609 memory
1610 .remember(cid, "user", &format!("msg {i}"))
1611 .await
1612 .unwrap();
1613 }
1614
1615 memory.summarize(cid, 3).await.unwrap();
1616 memory.summarize(cid, 3).await.unwrap();
1617
1618 let summaries = memory.load_summaries(cid).await.unwrap();
1619 assert_eq!(summaries.len(), 2);
1620 }
1621
1622 #[tokio::test]
1623 async fn summarize_token_estimate_populated() {
1624 let memory = test_semantic_memory(false).await;
1625 let cid = memory.sqlite().create_conversation().await.unwrap();
1626
1627 for i in 0..5 {
1628 memory
1629 .remember(cid, "user", &format!("message {i}"))
1630 .await
1631 .unwrap();
1632 }
1633
1634 memory.summarize(cid, 3).await.unwrap();
1635 let summaries = memory.load_summaries(cid).await.unwrap();
1636 let token_est = summaries[0].token_estimate;
1637 assert!(token_est > 0);
1638 }
1639
1640 #[tokio::test]
1641 async fn summarize_fails_when_provider_chat_fails() {
1642 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1643 let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
1644 "http://127.0.0.1:1",
1645 "test".into(),
1646 "embed".into(),
1647 ));
1648 let memory = SemanticMemory {
1649 sqlite,
1650 qdrant: None,
1651 provider,
1652 embedding_model: "test".into(),
1653 vector_weight: 0.7,
1654 keyword_weight: 0.3,
1655 temporal_decay_enabled: false,
1656 temporal_decay_half_life_days: 30,
1657 mmr_enabled: false,
1658 mmr_lambda: 0.7,
1659 token_counter: Arc::new(TokenCounter::new()),
1660 };
1661 let cid = memory.sqlite().create_conversation().await.unwrap();
1662
1663 for i in 0..5 {
1664 memory
1665 .remember(cid, "user", &format!("msg {i}"))
1666 .await
1667 .unwrap();
1668 }
1669
1670 let result = memory.summarize(cid, 3).await;
1671 assert!(result.is_err());
1672 }
1673
1674 #[tokio::test]
1675 async fn embed_missing_without_embedding_support_returns_zero() {
1676 let memory = test_semantic_memory(false).await;
1677 let cid = memory.sqlite().create_conversation().await.unwrap();
1678 memory
1679 .sqlite()
1680 .save_message(cid, "user", "test message")
1681 .await
1682 .unwrap();
1683
1684 let count = memory.embed_missing().await.unwrap();
1685 assert_eq!(count, 0);
1686 }
1687
1688 #[tokio::test]
1689 async fn has_embedding_returns_false_when_no_qdrant() {
1690 let memory = test_semantic_memory(false).await;
1691 let cid = memory.sqlite.create_conversation().await.unwrap();
1692 let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1693 assert!(!memory.has_embedding(msg_id).await.unwrap());
1694 }
1695
1696 #[tokio::test]
1697 async fn recall_empty_without_qdrant_regardless_of_filter() {
1698 let memory = test_semantic_memory(true).await;
1699 let filter = SearchFilter {
1700 conversation_id: Some(ConversationId(1)),
1701 role: None,
1702 };
1703 let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1704 assert!(recalled.is_empty());
1705 }
1706
1707 #[tokio::test]
1708 async fn summarize_message_range_bounds() {
1709 let memory = test_semantic_memory(false).await;
1710 let cid = memory.sqlite().create_conversation().await.unwrap();
1711
1712 for i in 0..8 {
1713 memory
1714 .remember(cid, "user", &format!("msg {i}"))
1715 .await
1716 .unwrap();
1717 }
1718
1719 let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1720 let summaries = memory.load_summaries(cid).await.unwrap();
1721 assert_eq!(summaries.len(), 1);
1722 assert_eq!(summaries[0].id, summary_id);
1723 assert!(summaries[0].first_message_id >= MessageId(1));
1724 assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1725 }
1726
1727 #[test]
1728 fn build_summarization_prompt_preserves_order() {
1729 let messages = vec![
1730 (MessageId(1), "user".into(), "first".into()),
1731 (MessageId(2), "assistant".into(), "second".into()),
1732 (MessageId(3), "user".into(), "third".into()),
1733 ];
1734 let prompt = build_summarization_prompt(&messages);
1735 let first_pos = prompt.find("user: first").unwrap();
1736 let second_pos = prompt.find("assistant: second").unwrap();
1737 let third_pos = prompt.find("user: third").unwrap();
1738 assert!(first_pos < second_pos);
1739 assert!(second_pos < third_pos);
1740 }
1741
1742 #[test]
1743 fn summary_debug() {
1744 let summary = Summary {
1745 id: 1,
1746 conversation_id: ConversationId(2),
1747 content: "test".into(),
1748 first_message_id: MessageId(1),
1749 last_message_id: MessageId(5),
1750 token_estimate: 10,
1751 };
1752 let dbg = format!("{summary:?}");
1753 assert!(dbg.contains("Summary"));
1754 }
1755
1756 #[tokio::test]
1757 async fn message_count_nonexistent_conversation() {
1758 let memory = test_semantic_memory(false).await;
1759 let count = memory.message_count(ConversationId(999)).await.unwrap();
1760 assert_eq!(count, 0);
1761 }
1762
1763 #[tokio::test]
1764 async fn load_summaries_nonexistent_conversation() {
1765 let memory = test_semantic_memory(false).await;
1766 let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
1767 assert!(summaries.is_empty());
1768 }
1769
1770 #[tokio::test]
1771 async fn store_session_summary_no_qdrant_noop() {
1772 let memory = test_semantic_memory(true).await;
1773 let result = memory
1774 .store_session_summary(ConversationId(1), "test summary")
1775 .await;
1776 assert!(result.is_ok());
1777 }
1778
1779 #[tokio::test]
1780 async fn store_session_summary_no_embeddings_noop() {
1781 let memory = test_semantic_memory(false).await;
1782 let result = memory
1783 .store_session_summary(ConversationId(1), "test summary")
1784 .await;
1785 assert!(result.is_ok());
1786 }
1787
1788 #[tokio::test]
1789 async fn search_session_summaries_no_qdrant_empty() {
1790 let memory = test_semantic_memory(true).await;
1791 let results = memory
1792 .search_session_summaries("query", 5, None)
1793 .await
1794 .unwrap();
1795 assert!(results.is_empty());
1796 }
1797
1798 #[tokio::test]
1799 async fn search_session_summaries_no_embeddings_empty() {
1800 let memory = test_semantic_memory(false).await;
1801 let results = memory
1802 .search_session_summaries("query", 5, Some(ConversationId(1)))
1803 .await
1804 .unwrap();
1805 assert!(results.is_empty());
1806 }
1807
1808 #[test]
1809 fn session_summary_result_debug() {
1810 let result = SessionSummaryResult {
1811 summary_text: "test".into(),
1812 score: 0.9,
1813 conversation_id: ConversationId(1),
1814 };
1815 let dbg = format!("{result:?}");
1816 assert!(dbg.contains("SessionSummaryResult"));
1817 }
1818
1819 #[test]
1820 fn session_summary_result_clone() {
1821 let result = SessionSummaryResult {
1822 summary_text: "test".into(),
1823 score: 0.9,
1824 conversation_id: ConversationId(1),
1825 };
1826 let cloned = result.clone();
1827 assert_eq!(result.summary_text, cloned.summary_text);
1828 assert_eq!(result.conversation_id, cloned.conversation_id);
1829 }
1830
1831 #[tokio::test]
1832 async fn recall_fts5_fallback_without_qdrant() {
1833 let memory = test_semantic_memory(false).await;
1834 let cid = memory.sqlite.create_conversation().await.unwrap();
1835
1836 memory
1837 .remember(cid, "user", "rust programming guide")
1838 .await
1839 .unwrap();
1840 memory
1841 .remember(cid, "assistant", "python tutorial")
1842 .await
1843 .unwrap();
1844 memory
1845 .remember(cid, "user", "advanced rust patterns")
1846 .await
1847 .unwrap();
1848
1849 let recalled = memory.recall("rust", 5, None).await.unwrap();
1850 assert_eq!(recalled.len(), 2);
1851 assert!(recalled[0].score >= recalled[1].score);
1852 }
1853
1854 #[tokio::test]
1855 async fn recall_fts5_fallback_with_filter() {
1856 let memory = test_semantic_memory(false).await;
1857 let cid1 = memory.sqlite.create_conversation().await.unwrap();
1858 let cid2 = memory.sqlite.create_conversation().await.unwrap();
1859
1860 memory.remember(cid1, "user", "hello world").await.unwrap();
1861 memory
1862 .remember(cid2, "user", "hello universe")
1863 .await
1864 .unwrap();
1865
1866 let filter = SearchFilter {
1867 conversation_id: Some(cid1),
1868 role: None,
1869 };
1870 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1871 assert_eq!(recalled.len(), 1);
1872 }
1873
1874 #[tokio::test]
1875 async fn recall_fts5_no_matches_returns_empty() {
1876 let memory = test_semantic_memory(false).await;
1877 let cid = memory.sqlite.create_conversation().await.unwrap();
1878
1879 memory.remember(cid, "user", "hello world").await.unwrap();
1880
1881 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1882 assert!(recalled.is_empty());
1883 }
1884
1885 #[tokio::test]
1886 async fn recall_fts5_respects_limit() {
1887 let memory = test_semantic_memory(false).await;
1888 let cid = memory.sqlite.create_conversation().await.unwrap();
1889
1890 for i in 0..10 {
1891 memory
1892 .remember(cid, "user", &format!("test message number {i}"))
1893 .await
1894 .unwrap();
1895 }
1896
1897 let recalled = memory.recall("test", 3, None).await.unwrap();
1898 assert_eq!(recalled.len(), 3);
1899 }
1900
1901 #[tokio::test]
1904 async fn summarize_fallback_to_plain_text_when_structured_fails() {
1905 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1913 let mut mock = MockProvider::default();
1914 mock.default_response = "plain text summary".into();
1916 let provider = AnyProvider::Mock(mock);
1917
1918 let memory = SemanticMemory {
1919 sqlite,
1920 qdrant: None,
1921 provider,
1922 embedding_model: "test".into(),
1923 vector_weight: 0.7,
1924 keyword_weight: 0.3,
1925 temporal_decay_enabled: false,
1926 temporal_decay_half_life_days: 30,
1927 mmr_enabled: false,
1928 mmr_lambda: 0.7,
1929 token_counter: Arc::new(TokenCounter::new()),
1930 };
1931
1932 let cid = memory.sqlite().create_conversation().await.unwrap();
1933 for i in 0..5 {
1934 memory
1935 .remember(cid, "user", &format!("msg {i}"))
1936 .await
1937 .unwrap();
1938 }
1939
1940 let result = memory.summarize(cid, 3).await;
1941 assert!(result.is_ok());
1947 let summaries = memory.load_summaries(cid).await.unwrap();
1948 assert_eq!(summaries.len(), 1);
1949 assert!(!summaries[0].content.is_empty());
1950 }
1951
1952 #[test]
1955 fn temporal_decay_disabled_leaves_scores_unchanged() {
1956 let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
1957 let timestamps = std::collections::HashMap::new();
1958 apply_temporal_decay(&mut ranked, ×tamps, 30);
1959 assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
1960 assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
1961 }
1962
1963 #[test]
1964 fn temporal_decay_zero_age_preserves_score() {
1965 let now = std::time::SystemTime::now()
1966 .duration_since(std::time::UNIX_EPOCH)
1967 .unwrap_or_default()
1968 .as_secs()
1969 .cast_signed();
1970 let mut ranked = vec![(MessageId(1), 1.0f64)];
1971 let mut timestamps = std::collections::HashMap::new();
1972 timestamps.insert(MessageId(1), now);
1973 apply_temporal_decay(&mut ranked, ×tamps, 30);
1974 assert!((ranked[0].1 - 1.0).abs() < 0.01);
1976 }
1977
1978 #[test]
1979 fn temporal_decay_half_life_halves_score() {
1980 let half_life = 30u32;
1982 let age_secs = i64::from(half_life) * 86400;
1983 let now = std::time::SystemTime::now()
1984 .duration_since(std::time::UNIX_EPOCH)
1985 .unwrap_or_default()
1986 .as_secs()
1987 .cast_signed();
1988 let ts = now - age_secs;
1989 let mut ranked = vec![(MessageId(1), 1.0f64)];
1990 let mut timestamps = std::collections::HashMap::new();
1991 timestamps.insert(MessageId(1), ts);
1992 apply_temporal_decay(&mut ranked, ×tamps, half_life);
1993 assert!(
1995 (ranked[0].1 - 0.5).abs() < 0.01,
1996 "score was {}",
1997 ranked[0].1
1998 );
1999 }
2000
2001 #[test]
2004 fn mmr_empty_input_returns_empty() {
2005 let ranked = vec![];
2006 let vectors = std::collections::HashMap::new();
2007 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2008 assert!(result.is_empty());
2009 }
2010
2011 #[test]
2012 fn mmr_returns_up_to_limit() {
2013 let ranked = vec![
2014 (MessageId(1), 1.0f64),
2015 (MessageId(2), 0.9f64),
2016 (MessageId(3), 0.8f64),
2017 ];
2018 let mut vectors = std::collections::HashMap::new();
2019 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2020 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2021 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2022 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2023 assert_eq!(result.len(), 2);
2024 }
2025
2026 #[test]
2027 fn mmr_without_vectors_picks_by_relevance() {
2028 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2029 let vectors = std::collections::HashMap::new();
2030 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2031 assert_eq!(result.len(), 2);
2032 assert_eq!(result[0].0, MessageId(1));
2033 }
2034
2035 #[test]
2036 fn mmr_prefers_diverse_over_redundant() {
2037 let ranked = vec![
2039 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2043 let mut vectors = std::collections::HashMap::new();
2044 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2045 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2048 assert_eq!(result.len(), 2);
2049 assert_eq!(result[0].0, MessageId(1));
2050 assert_eq!(result[1].0, MessageId(2));
2052 }
2053
2054 #[test]
2055 fn temporal_decay_half_life_zero_is_noop() {
2056 let now = std::time::SystemTime::now()
2057 .duration_since(std::time::UNIX_EPOCH)
2058 .unwrap_or_default()
2059 .as_secs()
2060 .cast_signed();
2061 let age_secs = 30i64 * 86400;
2062 let ts = now - age_secs;
2063 let mut ranked = vec![(MessageId(1), 1.0f64)];
2064 let mut timestamps = std::collections::HashMap::new();
2065 timestamps.insert(MessageId(1), ts);
2066 apply_temporal_decay(&mut ranked, ×tamps, 0);
2068 assert!(
2069 (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2070 "score was {}",
2071 ranked[0].1
2072 );
2073 }
2074
2075 #[test]
2076 fn temporal_decay_huge_age_near_zero() {
2077 let now = std::time::SystemTime::now()
2078 .duration_since(std::time::UNIX_EPOCH)
2079 .unwrap_or_default()
2080 .as_secs()
2081 .cast_signed();
2082 let age_secs = 3650i64 * 86400;
2084 let ts = now - age_secs;
2085 let mut ranked = vec![(MessageId(1), 1.0f64)];
2086 let mut timestamps = std::collections::HashMap::new();
2087 timestamps.insert(MessageId(1), ts);
2088 apply_temporal_decay(&mut ranked, ×tamps, 30);
2089 assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2091 }
2092
2093 #[test]
2094 fn temporal_decay_small_half_life() {
2095 let now = std::time::SystemTime::now()
2097 .duration_since(std::time::UNIX_EPOCH)
2098 .unwrap_or_default()
2099 .as_secs()
2100 .cast_signed();
2101 let ts = now - 7 * 86400i64;
2102 let mut ranked = vec![(MessageId(1), 1.0f64)];
2103 let mut timestamps = std::collections::HashMap::new();
2104 timestamps.insert(MessageId(1), ts);
2105 apply_temporal_decay(&mut ranked, ×tamps, 1);
2106 assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2107 }
2108
2109 #[test]
2110 fn mmr_lambda_zero_max_diversity() {
2111 let ranked = vec![
2113 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2117 let mut vectors = std::collections::HashMap::new();
2118 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2119 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2122 assert_eq!(result.len(), 3);
2123 assert_eq!(result[1].0, MessageId(2));
2125 }
2126
2127 #[test]
2128 fn mmr_lambda_one_pure_relevance() {
2129 let ranked = vec![
2131 (MessageId(1), 1.0f64),
2132 (MessageId(2), 0.8f64),
2133 (MessageId(3), 0.6f64),
2134 ];
2135 let mut vectors = std::collections::HashMap::new();
2136 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2137 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2138 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2139 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2140 assert_eq!(result.len(), 3);
2141 assert_eq!(result[0].0, MessageId(1));
2142 assert_eq!(result[1].0, MessageId(2));
2143 assert_eq!(result[2].0, MessageId(3));
2144 }
2145
2146 #[test]
2147 fn mmr_limit_zero_returns_empty() {
2148 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2149 let mut vectors = std::collections::HashMap::new();
2150 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2151 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2152 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2153 assert!(result.is_empty());
2154 }
2155
2156 #[test]
2157 fn mmr_duplicate_vectors_penalizes_second() {
2158 let ranked = vec![
2160 (MessageId(1), 1.0f64),
2161 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2164 let mut vectors = std::collections::HashMap::new();
2165 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2166 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2169 assert_eq!(result.len(), 3);
2170 assert_eq!(result[0].0, MessageId(1));
2171 assert_eq!(result[1].0, MessageId(3));
2173 }
2174
2175 use proptest::prelude::*;
2178
    proptest! {
        // Property: token counting is total — it must never panic for any input
        // string, including empty strings, multi-byte UTF-8, and control chars.
        // The result is intentionally discarded; only absence of panic matters.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2186}